From f8bef48cf85f129b87d4be7380d9474afce6ff71 Mon Sep 17 00:00:00 2001 From: lesh Date: Mon, 29 Sep 2025 18:13:58 -0700 Subject: [PATCH 01/40] nav/pointcloud spec proposal --- dimos/map/spec.py | 31 +++++++++++++++++++++++++++++++ dimos/navigation/spec.py | 32 ++++++++++++++++++++++++++++++++ dimos/perception/spec.py | 22 ++++++++++++++++++++++ 3 files changed, 85 insertions(+) create mode 100644 dimos/map/spec.py create mode 100644 dimos/navigation/spec.py create mode 100644 dimos/perception/spec.py diff --git a/dimos/map/spec.py b/dimos/map/spec.py new file mode 100644 index 0000000000..0733b8ce33 --- /dev/null +++ b/dimos/map/spec.py @@ -0,0 +1,31 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC + +from dimos.core import Out +from dimos.msgs.nav_msgs import OccupancyGrid +from dimos.msgs.sensor_msgs import PointCloud2 + + +class GlobalPointcloudSpec(ABC): + global_pointcloud: Out[PointCloud2] = None + + +class GlobalMapSpec(ABC): + global_map: Out[OccupancyGrid] = None + + +class GlobalCostmapSpec(ABC): + global_costmap: Out[OccupancyGrid] = None diff --git a/dimos/navigation/spec.py b/dimos/navigation/spec.py new file mode 100644 index 0000000000..14370ab53d --- /dev/null +++ b/dimos/navigation/spec.py @@ -0,0 +1,32 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC + +from dimos.core import In, Out +from dimos.msgs.geometry_msgs import Path, PoseStamped, Twist + + +class NavSpec(ABC): + goal_req: In[PoseStamped] = None # type: ignore + goal_active: Out[PoseStamped] = None # type: ignore + path_active: Out[Path] = None # type: ignore + ctrl: Out[Twist] = None # type: ignore + + # identity quaternion (Quaternion(0,0,0,1)) represents "no rotation requested" + def goto(self, target: PoseStamped) -> None: + pass + + def stop(self) -> None: + pass diff --git a/dimos/perception/spec.py b/dimos/perception/spec.py new file mode 100644 index 0000000000..0b73750d53 --- /dev/null +++ b/dimos/perception/spec.py @@ -0,0 +1,22 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC + +from dimos.core import Out +from dimos.msgs.sensor_msgs import PointCloud2 + + +class PointcloudPerception(ABC): + pointcloud: Out[PointCloud2] = None # type: ignore From 6a669b69ad0c338fc53bfa8c6238be0c6a946e2e Mon Sep 17 00:00:00 2001 From: lesh Date: Mon, 29 Sep 2025 18:19:12 -0700 Subject: [PATCH 02/40] rosnav spec --- dimos/navigation/rosnav.py | 40 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 40 insertions(+) create mode 100644 dimos/navigation/rosnav.py diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py new file mode 100644 index 0000000000..86b1558895 --- /dev/null +++ b/dimos/navigation/rosnav.py @@ -0,0 +1,40 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.core import In, Module, Out +from dimos.msgs.geometry_msgs import Path, PoseStamped, Twist +from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.navigation.spec import NavSpec +from dimos.perception.pointcloud.spec import GlobalPointcloudPerception, PointcloudPerception + + +class Config: + global_frame_id: str = "world" + + +class RosNav(Module, PointcloudPerception, GlobalPointcloudPerception, NavSpec): + goal_req: In[PoseStamped] = None # type: ignore + goal_active: Out[PoseStamped] = None # type: ignore + path_active: Out[Path] = None # type: ignore + + ctrl: Out[Twist] = None # type: ignore + + pointcloud: Out[PointCloud2] = None # type: ignore + global_pointcloud: Out[PointCloud2] = None # type: ignore + + config: Config + + def __init__(self, *args, **kwargs): + self.config = Config(**kwargs) + super().__init__() From e5571276bfa27308f986458e5602fb5a3b44ef50 Mon Sep 17 00:00:00 2001 From: lesh Date: Mon, 29 Sep 2025 18:22:08 -0700 Subject: [PATCH 03/40] cleaner nav spec --- dimos/{map => mapping}/spec.py | 2 +- dimos/navigation/rosnav.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) rename dimos/{map => mapping}/spec.py (96%) diff --git a/dimos/map/spec.py b/dimos/mapping/spec.py similarity index 96% rename from dimos/map/spec.py rename to dimos/mapping/spec.py index 0733b8ce33..c8675df3f9 100644 --- a/dimos/map/spec.py +++ b/dimos/mapping/spec.py @@ -19,7 +19,7 @@ from dimos.msgs.sensor_msgs import PointCloud2 -class GlobalPointcloudSpec(ABC): +class Global3DMapSpec(ABC): global_pointcloud: Out[PointCloud2] = None diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index 86b1558895..09be7c5096 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -13,17 +13,18 @@ # limitations under the License. from dimos.core import In, Module, Out +from dimos.mapping.spec import Global3DMapSpec from dimos.msgs.geometry_msgs import Path, PoseStamped, Twist from dimos.msgs.sensor_msgs import PointCloud2 from dimos.navigation.spec import NavSpec -from dimos.perception.pointcloud.spec import GlobalPointcloudPerception, PointcloudPerception +from dimos.perception.pointcloud.spec import PointcloudPerception class Config: global_frame_id: str = "world" -class RosNav(Module, PointcloudPerception, GlobalPointcloudPerception, NavSpec): +class RosNav(Module, PointcloudPerception, Global3DMapSpec, NavSpec): goal_req: In[PoseStamped] = None # type: ignore goal_active: Out[PoseStamped] = None # type: ignore path_active: Out[Path] = None # type: ignore From 37c84e56394db81e32dc923a6c08d4a792dae12f Mon Sep 17 00:00:00 2001 From: lesh Date: Tue, 30 Sep 2025 14:41:50 -0700 Subject: [PATCH 04/40] PR comments --- dimos/navigation/rosnav.py | 10 ---------- dimos/navigation/spec.py | 4 ++-- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index 09be7c5096..b3104eee5b 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -20,10 +20,6 @@ from dimos.perception.pointcloud.spec import PointcloudPerception -class Config: - global_frame_id: str = "world" - - class RosNav(Module, PointcloudPerception, Global3DMapSpec, NavSpec): goal_req: In[PoseStamped] = None # type: ignore goal_active: Out[PoseStamped] = None # type: ignore @@ -33,9 +29,3 @@ class RosNav(Module, PointcloudPerception, Global3DMapSpec, NavSpec): pointcloud: Out[PointCloud2] = None # type: ignore global_pointcloud: Out[PointCloud2] = None # type: ignore - - config: Config - - def __init__(self, *args, **kwargs): - self.config = Config(**kwargs) - super().__init__() diff --git a/dimos/navigation/spec.py b/dimos/navigation/spec.py index 14370ab53d..8c752c8af1 100644 --- a/dimos/navigation/spec.py +++ b/dimos/navigation/spec.py @@ -25,8 +25,8 @@ class NavSpec(ABC): ctrl: Out[Twist] = None # type: ignore # identity quaternion (Quaternion(0,0,0,1)) represents "no rotation requested" - def goto(self, target: PoseStamped) -> None: + def navigate_to_target(self, target: PoseStamped) -> None: pass - def stop(self) -> None: + def stop_navigating(self) -> None: pass From 0e4740f39e8d32843f04132cc630aa5c98cf9d1b Mon Sep 17 00:00:00 2001 From: lesh Date: Tue, 30 Sep 2025 14:45:03 -0700 Subject: [PATCH 05/40] type check --- dimos/navigation/rosnav.py | 3 ++- dimos/navigation/spec.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index b3104eee5b..bea3a7d542 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -14,7 +14,8 @@ from dimos.core import In, Module, Out from dimos.mapping.spec import Global3DMapSpec -from dimos.msgs.geometry_msgs import Path, PoseStamped, Twist +from dimos.msgs.geometry_msgs import PoseStamped, Twist +from dimos.msgs.nav_msgs import Path from dimos.msgs.sensor_msgs import PointCloud2 from dimos.navigation.spec import NavSpec from dimos.perception.pointcloud.spec import PointcloudPerception diff --git a/dimos/navigation/spec.py b/dimos/navigation/spec.py index 8c752c8af1..69aa7b2409 100644 --- a/dimos/navigation/spec.py +++ b/dimos/navigation/spec.py @@ -15,7 +15,8 @@ from abc import ABC from dimos.core import In, Out -from dimos.msgs.geometry_msgs import Path, PoseStamped, Twist +from dimos.msgs.geometry_msgs import PoseStamped, Twist +from dimos.msgs.nav_msgs import Path class NavSpec(ABC): From e9f5057911029725b6f7f1ebfa58e1a5cbd0afdd Mon Sep 17 00:00:00 2001 From: lesh Date: Tue, 30 Sep 2025 15:11:24 -0700 Subject: [PATCH 06/40] switched the spec to protocols --- dimos/mapping/spec.py | 14 +++++++------- dimos/navigation/rosnav.py | 17 ++++++++++++----- dimos/navigation/spec.py | 18 ++++++++---------- dimos/perception/spec.py | 6 +++--- 4 files changed, 30 insertions(+), 25 deletions(-) diff --git a/dimos/mapping/spec.py b/dimos/mapping/spec.py index c8675df3f9..3d82cea0cc 100644 --- a/dimos/mapping/spec.py +++ b/dimos/mapping/spec.py @@ -12,20 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC +from typing import Protocol from dimos.core import Out from dimos.msgs.nav_msgs import OccupancyGrid from dimos.msgs.sensor_msgs import PointCloud2 -class Global3DMapSpec(ABC): - global_pointcloud: Out[PointCloud2] = None +class Global3DMapSpec(Protocol): + global_pointcloud: Out[PointCloud2] -class GlobalMapSpec(ABC): - global_map: Out[OccupancyGrid] = None +class GlobalMapSpec(Protocol): + global_map: Out[OccupancyGrid] -class GlobalCostmapSpec(ABC): - global_costmap: Out[OccupancyGrid] = None +class GlobalCostmapSpec(Protocol): + global_costmap: Out[OccupancyGrid] diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index bea3a7d542..9bdee3fe3a 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -13,20 +13,27 @@ # limitations under the License. from dimos.core import In, Module, Out -from dimos.mapping.spec import Global3DMapSpec from dimos.msgs.geometry_msgs import PoseStamped, Twist from dimos.msgs.nav_msgs import Path from dimos.msgs.sensor_msgs import PointCloud2 -from dimos.navigation.spec import NavSpec -from dimos.perception.pointcloud.spec import PointcloudPerception -class RosNav(Module, PointcloudPerception, Global3DMapSpec, NavSpec): +class ROSNav(Module): goal_req: In[PoseStamped] = None # type: ignore goal_active: Out[PoseStamped] = None # type: ignore path_active: Out[Path] = None # type: ignore - ctrl: Out[Twist] = None # type: ignore + # PointcloudPerception attributes pointcloud: Out[PointCloud2] = None # type: ignore + + # Global3DMapSpec attributes global_pointcloud: Out[PointCloud2] = None # type: ignore + + def navigate_to_target(self, target: PoseStamped) -> None: + # TODO: Implement navigation logic + pass + + def stop_navigating(self) -> None: + # TODO: Implement stop logic + pass diff --git a/dimos/navigation/spec.py b/dimos/navigation/spec.py index 69aa7b2409..69bfdac262 100644 --- a/dimos/navigation/spec.py +++ b/dimos/navigation/spec.py @@ -12,22 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC +from typing import Protocol from dimos.core import In, Out from dimos.msgs.geometry_msgs import PoseStamped, Twist from dimos.msgs.nav_msgs import Path -class NavSpec(ABC): - goal_req: In[PoseStamped] = None # type: ignore - goal_active: Out[PoseStamped] = None # type: ignore - path_active: Out[Path] = None # type: ignore - ctrl: Out[Twist] = None # type: ignore +class NavSpec(Protocol): + goal_req: In[PoseStamped] + goal_active: Out[PoseStamped] + path_active: Out[Path] + ctrl: Out[Twist] # identity quaternion (Quaternion(0,0,0,1)) represents "no rotation requested" - def navigate_to_target(self, target: PoseStamped) -> None: - pass + def navigate_to_target(self, target: PoseStamped) -> None: ... - def stop_navigating(self) -> None: - pass + def stop_navigating(self) -> None: ... diff --git a/dimos/perception/spec.py b/dimos/perception/spec.py index 0b73750d53..de53ce9bd7 100644 --- a/dimos/perception/spec.py +++ b/dimos/perception/spec.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC +from typing import Protocol from dimos.core import Out from dimos.msgs.sensor_msgs import PointCloud2 -class PointcloudPerception(ABC): - pointcloud: Out[PointCloud2] = None # type: ignore +class PointcloudPerception(Protocol): + pointcloud: Out[PointCloud2] From f913f57a09a2d869dbc22b2afbae14e142cfca5b Mon Sep 17 00:00:00 2001 From: lesh Date: Tue, 30 Sep 2025 15:13:19 -0700 Subject: [PATCH 07/40] rosnav type check --- dimos/navigation/test_rosnav.py | 37 +++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 dimos/navigation/test_rosnav.py diff --git a/dimos/navigation/test_rosnav.py b/dimos/navigation/test_rosnav.py new file mode 100644 index 0000000000..5de1c0e6ab --- /dev/null +++ b/dimos/navigation/test_rosnav.py @@ -0,0 +1,37 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +from dimos.mapping.spec import Global3DMapSpec +from dimos.navigation.rosnav import ROSNav +from dimos.navigation.spec import NavSpec +from dimos.perception.spec import PointcloudPerception + + +class RosNavSpec(NavSpec, PointcloudPerception, Global3DMapSpec, Protocol): + """Combined protocol for navigation components.""" + + pass + + +def accepts_combined_protocol(nav: RosNavSpec) -> None: + """Function that accepts all navigation protocols at once.""" + pass + + +def test_typing_prototypes(): + """Test that ROSNav correctly implements all required protocols.""" + rosnav = ROSNav() + accepts_combined_protocol(rosnav) From ccc7e733c78ef067dafa7b5ec5b5f90fbfa9f238 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 17 Oct 2025 16:30:16 -0700 Subject: [PATCH 08/40] rewrote ros nav --- dimos/core/__init__.py | 62 ++- dimos/hardware/camera/module.py | 2 +- dimos/navigation/rosnav.py | 318 ++++++++++++- dimos/navigation/rosnav/__init__.py | 2 - dimos/navigation/rosnav/nav_bot.py | 423 ------------------ dimos/navigation/rosnav/rosnav.py | 47 -- .../unitree_webrtc/connection/__init__.py | 1 + .../{ => connection}/connection.py | 17 +- dimos/robot/unitree_webrtc/connection/g1.py | 69 +++ .../robot/unitree_webrtc/modular/__init__.py | 4 +- dimos/robot/unitree_webrtc/modular/ivan_g1.py | 91 ++++ dimos/robot/unitree_webrtc/modular/misc.py | 33 ++ dimos/utils/logging_config.py | 11 +- 13 files changed, 579 insertions(+), 501 deletions(-) delete mode 100644 dimos/navigation/rosnav/__init__.py delete mode 100644 dimos/navigation/rosnav/nav_bot.py delete mode 100644 dimos/navigation/rosnav/rosnav.py create mode 100644 dimos/robot/unitree_webrtc/connection/__init__.py rename dimos/robot/unitree_webrtc/{ => connection}/connection.py (97%) create mode 100644 dimos/robot/unitree_webrtc/connection/g1.py create mode 100644 dimos/robot/unitree_webrtc/modular/ivan_g1.py create mode 100644 dimos/robot/unitree_webrtc/modular/misc.py diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 0bd3603126..9bc954f3b0 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -1,7 +1,8 @@ from __future__ import annotations import multiprocessing as mp -from typing import Optional +import time +from typing import Any, Optional, Protocol from dask.distributed import Client, LocalCluster from rich.console import Console @@ -10,7 +11,6 @@ from dimos.core.core import rpc from dimos.core.module import Module, ModuleBase, ModuleConfig from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport -from dimos.utils.actor_registry import ActorRegistry from dimos.core.transport import ( LCMTransport, SHMTransport, @@ -21,6 +21,7 @@ from dimos.protocol.rpc.lcmrpc import LCMRPC from dimos.protocol.rpc.spec import RPCSpec from dimos.protocol.tf import LCMTF, TF, PubSubTF, TFConfig, TFSpec +from dimos.utils.actor_registry import ActorRegistry __all__ = [ "DimosCluster", @@ -154,7 +155,42 @@ def rpc_call(*args, **kwargs): return self.actor_instance.__getattr__(name) -DimosCluster = Client +class DimosCluster(Protocol): + """Extended Dask Client with DimOS-specific methods. + + This protocol defines the interface of a Dask Client that has been + patched with additional methods via patchdask(). + """ + + def deploy( + self, + actor_class: type, + *args: Any, + **kwargs: Any, + ) -> RPCClient: + """Deploy an actor to the cluster and return an RPC client. + + Args: + actor_class: The actor class to deploy + *args: Positional arguments to pass to the actor constructor + **kwargs: Keyword arguments to pass to the actor constructor + + Returns: + RPCClient: A client for making RPC calls to the deployed actor + """ + ... + + def check_worker_memory(self) -> None: + """Check and display memory usage of all workers.""" + ... + + def stop(self) -> None: + """Stop the client (alias for close).""" + ... + + def close_all(self) -> None: + """Close all resources including cluster, client, and shared memory transports.""" + ... def patchdask(dask_client: Client, local_cluster: LocalCluster) -> DimosCluster: @@ -244,9 +280,10 @@ def close_all(): # Stop all SharedMemory transports before closing Dask # This prevents the "leaked shared_memory objects" warning and hangs try: - from dimos.protocol.pubsub import shmpubsub import gc + from dimos.protocol.pubsub import shmpubsub + for obj in gc.get_objects(): if isinstance(obj, (shmpubsub.SharedMemory, shmpubsub.PickleSharedMemory)): try: @@ -299,18 +336,21 @@ def close_all(): dask_client.check_worker_memory = check_worker_memory dask_client.stop = lambda: dask_client.close() dask_client.close_all = close_all - return dask_client + return dask_client # type: ignore[return-value] -def start(n: Optional[int] = None, memory_limit: str = "auto") -> Client: +def start(n: Optional[int] = None, memory_limit: str = "auto") -> DimosCluster: """Start a Dask LocalCluster with specified workers and memory limits. Args: n: Number of workers (defaults to CPU count) memory_limit: Memory limit per worker (e.g., '4GB', '2GiB', or 'auto' for Dask's default) + + Returns: + DimosCluster: A patched Dask client with deploy(), check_worker_memory(), stop(), and close_all() methods """ - import signal import atexit + import signal console = Console() if not n: @@ -358,3 +398,11 @@ def signal_handler(sig, frame): signal.signal(signal.SIGTERM, signal_handler) return patched_client + + +def wait_exit(): + while True: + try: + time.sleep(1) + except KeyboardInterrupt: + print("exiting...") diff --git a/dimos/hardware/camera/module.py b/dimos/hardware/camera/module.py index 2b2880b80a..875b6f66a7 100644 --- a/dimos/hardware/camera/module.py +++ b/dimos/hardware/camera/module.py @@ -107,7 +107,7 @@ def video_stream(self) -> Image: for image in iter(_queue.get, None): yield image - def camera_info_stream(self, frequency: float = 5.0) -> Observable[CameraInfo]: + def camera_info_stream(self, frequency: float = 1.0) -> Observable[CameraInfo]: def camera_info(_) -> CameraInfo: self.hardware.camera_info.ts = time.time() return self.hardware.camera_info diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index 9bdee3fe3a..1036739f25 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -1,3 +1,4 @@ +#!/usr/bin/env python3 # Copyright 2025 Dimensional Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -12,28 +13,313 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.core import In, Module, Out -from dimos.msgs.geometry_msgs import PoseStamped, Twist +""" +NavBot class for navigation-related functionality. +Encapsulates ROS bridge and topic remapping for Unitree robots. +""" + +import logging +import threading +import time + +import rclpy +from geometry_msgs.msg import PointStamped as ROSPointStamped +from geometry_msgs.msg import PoseStamped as ROSPoseStamped + +# ROS2 message imports +from geometry_msgs.msg import TwistStamped as ROSTwistStamped +from nav_msgs.msg import Path as ROSPath +from rclpy.node import Node +from sensor_msgs.msg import Joy as ROSJoy +from sensor_msgs.msg import PointCloud2 as ROSPointCloud2 +from std_msgs.msg import Bool as ROSBool +from std_msgs.msg import Int8 as ROSInt8 +from tf2_msgs.msg import TFMessage as ROSTFMessage + +from dimos.core import DimosCluster, In, LCMTransport, Module, Out, pSHMTransport, rpc +from dimos.msgs.geometry_msgs import ( + PoseStamped, + Quaternion, + Transform, + TwistStamped, + Vector3, +) from dimos.msgs.nav_msgs import Path from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.std_msgs import Bool +from dimos.msgs.tf2_msgs.TFMessage import TFMessage +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import euler_to_quaternion + +logger = setup_logger("dimos.robot.unitree_webrtc.nav_bot", level=logging.INFO) class ROSNav(Module): - goal_req: In[PoseStamped] = None # type: ignore - goal_active: Out[PoseStamped] = None # type: ignore - path_active: Out[Path] = None # type: ignore - ctrl: Out[Twist] = None # type: ignore + goal_req: In[PoseStamped] = None + + pointcloud: Out[PointCloud2] = None + global_pointcloud: Out[PointCloud2] = None + + goal_active: Out[PoseStamped] = None + path_active: Out[Path] = None + cmd_vel: Out[TwistStamped] = None + + def __init__(self, sensor_to_base_link_transform=None, *args, **kwargs): + super().__init__(*args, **kwargs) + + if not rclpy.ok(): + rclpy.init() + self._node = Node("navigation_module") + + self.goal_reach = None + self.sensor_to_base_link_transform = sensor_to_base_link_transform or [ + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ] + self.spin_thread = None + + # ROS2 Publishers + self.goal_pose_pub = self._node.create_publisher(ROSPoseStamped, "/goal_pose", 10) + self.soft_stop_pub = self._node.create_publisher(ROSInt8, "/soft_stop", 10) + self.joy_pub = self._node.create_publisher(ROSJoy, "/joy", 10) + + # ROS2 Subscribers + self.goal_reached_sub = self._node.create_subscription( + ROSBool, "/goal_reached", self._on_ros_goal_reached, 10 + ) + self.cmd_vel_sub = self._node.create_subscription( + ROSTwistStamped, "/cmd_vel", self._on_ros_cmd_vel, 10 + ) + self.goal_waypoint_sub = self._node.create_subscription( + ROSPointStamped, "/way_point", self._on_ros_goal_waypoint, 10 + ) + self.registered_scan_sub = self._node.create_subscription( + ROSPointCloud2, "/registered_scan", self._on_ros_registered_scan, 10 + ) + + self.global_pointcloud_sub = self._node.create_subscription( + ROSPointCloud2, "/terrain_map_ext", self._on_ros_global_pointcloud, 10 + ) + + self.path_sub = self._node.create_subscription(ROSPath, "/path", self._on_ros_path, 10) + self.tf_sub = self._node.create_subscription(ROSTFMessage, "/tf", self._on_ros_tf, 10) + + logger.info("NavigationModule initialized with ROS2 node") + + @rpc + def start(self): + self._running = True + self.spin_thread = threading.Thread(target=self._spin_node, daemon=True) + self.spin_thread.start() + + def broadcast_lidar(): + while self._running: + if not hasattr(self, "_local_pointcloud"): + return + self.pointcloud.publish(PointCloud2.from_ros_msg(self._local_pointcloud)) + time.sleep(0.5) + + def broadcast_map(): + while self._running: + if not hasattr(self, "_global_pointcloud"): + return + self.global_pointcloud.publish(PointCloud2.from_ros_msg(self.global_pointcloud)) + time.sleep(1.0) + + self.map_broadcast_thread = threading.Thread(target=broadcast_map, daemon=True) + self.lidar_broadcast_thread = threading.Thread(target=broadcast_lidar, daemon=True) + + self.goal_req.subscribe(self._on_goal_pose) + + logger.info("NavigationModule started with ROS2 spinning") + + def _spin_node(self): + while self._running and rclpy.ok(): + try: + rclpy.spin_once(self._node, timeout_sec=0.1) + except Exception as e: + if self._running: + logger.error(f"ROS2 spin error: {e}") + + def _on_ros_goal_reached(self, msg: ROSBool): + self.goal_reach = msg.data + + def _on_ros_goal_waypoint(self, msg: ROSPointStamped): + dimos_pose = PoseStamped( + ts=time.time(), + frame_id=msg.header.frame_id, + position=Vector3(msg.point.x, msg.point.y, msg.point.z), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), + ) + self.goal_active.publish(dimos_pose) + + def _on_ros_cmd_vel(self, msg: ROSTwistStamped): + self.cmd_vel.publish(TwistStamped.from_ros_msg(msg)) + + def _on_ros_registered_scan(self, msg: ROSPointCloud2): + self._local_pointcloud = msg + + def _on_ros_global_pointcloud(self, msg: ROSPointCloud2): + self._global_pointcloud = msg + + def _on_ros_path(self, msg: ROSPath): + dimos_path = Path.from_ros_msg(msg) + dimos_path.frame_id = "base_link" + self.path_active.publish(dimos_path) + + def _on_ros_tf(self, msg: ROSTFMessage): + ros_tf = TFMessage.from_ros_msg(msg) + + translation = Vector3( + self.sensor_to_base_link_transform[0], + self.sensor_to_base_link_transform[1], + self.sensor_to_base_link_transform[2], + ) + euler_angles = Vector3( + self.sensor_to_base_link_transform[3], + self.sensor_to_base_link_transform[4], + self.sensor_to_base_link_transform[5], + ) + rotation = euler_to_quaternion(euler_angles) + + sensor_to_base_link_tf = Transform( + translation=translation, + rotation=rotation, + frame_id="sensor", + child_frame_id="base_link", + ts=time.time(), + ) + + map_to_world_tf = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=euler_to_quaternion(Vector3(0.0, 0.0, 0.0)), + frame_id="map", + child_frame_id="world", + ts=time.time(), + ) + + self.tf.publish(sensor_to_base_link_tf, map_to_world_tf, *ros_tf.transforms) + + def _on_goal_pose(self, msg: PoseStamped): + self.navigate_to(msg) + + def _on_cancel_goal(self, msg: Bool): + if msg.data: + self.stop() + + def _set_autonomy_mode(self): + joy_msg = ROSJoy() + joy_msg.axes = [ + 0.0, # axis 0 + 0.0, # axis 1 + -1.0, # axis 2 + 0.0, # axis 3 + 1.0, # axis 4 + 1.0, # axis 5 + 0.0, # axis 6 + 0.0, # axis 7 + ] + joy_msg.buttons = [ + 0, # button 0 + 0, # button 1 + 0, # button 2 + 0, # button 3 + 0, # button 4 + 0, # button 5 + 0, # button 6 + 1, # button 7 - controls autonomy mode + 0, # button 8 + 0, # button 9 + 0, # button 10 + ] + self.joy_pub.publish(joy_msg) + logger.info("Setting autonomy mode via Joy message") + + @rpc + def navigate_to(self, pose: PoseStamped, timeout: float = 60.0) -> bool: + """ + Navigate to a target pose by publishing to ROS topics. + + Args: + pose: Target pose to navigate to + timeout: Maximum time to wait for goal (seconds) + + Returns: + True if navigation was successful + """ + logger.info( + f"Navigating to goal: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f})" + ) + + self.goal_reach = None + self._set_autonomy_mode() + + # Enable soft stop (0 = enable) + soft_stop_msg = ROSInt8() + soft_stop_msg.data = 0 + self.soft_stop_pub.publish(soft_stop_msg) + + ros_pose = pose.to_ros_msg() + self.goal_pose_pub.publish(ros_pose) + + # Wait for goal to be reached + start_time = time.time() + while time.time() - start_time < timeout: + if self.goal_reach is not None: + soft_stop_msg.data = 2 + self.soft_stop_pub.publish(soft_stop_msg) + return self.goal_reach + time.sleep(0.1) + + self.stop_navigation() + logger.warning(f"Navigation timed out after {timeout} seconds") + return False + + @rpc + def stop_navigation(self) -> bool: + """ + Stop current navigation by publishing to ROS topics. + + Returns: + True if stop command was sent successfully + """ + logger.info("Stopping navigation") + + cancel_msg = ROSBool() + cancel_msg.data = True + + soft_stop_msg = ROSInt8() + soft_stop_msg.data = 2 + self.soft_stop_pub.publish(soft_stop_msg) + + return True - # PointcloudPerception attributes - pointcloud: Out[PointCloud2] = None # type: ignore + @rpc + def stop(self): + try: + self._running = False + if self.spin_thread: + self.spin_thread.join(timeout=1) + self._node.destroy_node() + except Exception as e: + logger.error(f"Error during shutdown: {e}") - # Global3DMapSpec attributes - global_pointcloud: Out[PointCloud2] = None # type: ignore - def navigate_to_target(self, target: PoseStamped) -> None: - # TODO: Implement navigation logic - pass +def deploy(dimos: DimosCluster): + nav = dimos.deploy(ROSNav) + # nav.pointcloud.transport = pSHMTransport("/lidar") + # nav.global_pointcloud.transport = pSHMTransport("/map") + nav.pointcloud.transport = LCMTransport("/lidar", PointCloud2) + nav.global_pointcloud.transport = LCMTransport("/map", PointCloud2) - def stop_navigating(self) -> None: - # TODO: Implement stop logic - pass + nav.goal_req.transport = LCMTransport("/goal_req", PoseStamped) + nav.goal_req.transport = LCMTransport("/goal_req", PoseStamped) + nav.goal_active.transport = LCMTransport("/goal_active", PoseStamped) + nav.path_active.transport = LCMTransport("/path_active", Path) + nav.cmd_vel.transport = LCMTransport("/cmd_vel", TwistStamped) + nav.start() + return nav diff --git a/dimos/navigation/rosnav/__init__.py b/dimos/navigation/rosnav/__init__.py deleted file mode 100644 index a88bffeb43..0000000000 --- a/dimos/navigation/rosnav/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from dimos.navigation.rosnav.rosnav import ROSNav -from dimos.navigation.rosnav.nav_bot import ROSNavigationModule, NavBot diff --git a/dimos/navigation/rosnav/nav_bot.py b/dimos/navigation/rosnav/nav_bot.py deleted file mode 100644 index 4a5ca0c45a..0000000000 --- a/dimos/navigation/rosnav/nav_bot.py +++ /dev/null @@ -1,423 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -NavBot class for navigation-related functionality. -Encapsulates ROS bridge and topic remapping for Unitree robots. -""" - -import logging -import time -import threading - -import rclpy -from rclpy.node import Node -from rclpy.executors import SingleThreadedExecutor - -from dimos import core -from dimos.protocol import pubsub -from dimos.core import In, Out, rpc -from dimos.msgs.geometry_msgs import PoseStamped, Twist, Transform, Vector3, Quaternion -from dimos.msgs.nav_msgs import Odometry, Path -from dimos.msgs.sensor_msgs import PointCloud2, Joy -from dimos.msgs.std_msgs import Bool -from dimos.msgs.tf2_msgs.TFMessage import TFMessage -from dimos.utils.transform_utils import euler_to_quaternion -from dimos.utils.logging_config import setup_logger -from dimos.navigation.rosnav import ROSNav - -# ROS2 message imports -from geometry_msgs.msg import TwistStamped as ROSTwistStamped -from geometry_msgs.msg import PoseStamped as ROSPoseStamped -from geometry_msgs.msg import PointStamped as ROSPointStamped -from nav_msgs.msg import Odometry as ROSOdometry -from nav_msgs.msg import Path as ROSPath -from sensor_msgs.msg import PointCloud2 as ROSPointCloud2, Joy as ROSJoy -from std_msgs.msg import Bool as ROSBool, Int8 as ROSInt8 -from tf2_msgs.msg import TFMessage as ROSTFMessage - -logger = setup_logger("dimos.robot.unitree_webrtc.nav_bot", level=logging.INFO) - - -class ROSNavigationModule(ROSNav): - """ - Handles navigation control and odometry remapping. - """ - - goal_req: In[PoseStamped] = None - cancel_goal: In[Bool] = None - - pointcloud: Out[PointCloud2] = None - global_pointcloud: Out[PointCloud2] = None - - goal_active: Out[PoseStamped] = None - path_active: Out[Path] = None - goal_reached: Out[Bool] = None - odom: Out[Odometry] = None - cmd_vel: Out[Twist] = None - odom_pose: Out[PoseStamped] = None - - def __init__(self, sensor_to_base_link_transform=None, *args, **kwargs): - super().__init__(*args, **kwargs) - if not rclpy.ok(): - rclpy.init() - self._node = Node("navigation_module") - - self.goal_reach = None - self.sensor_to_base_link_transform = sensor_to_base_link_transform or [ - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - ] - self.spin_thread = None - - # ROS2 Publishers - self.goal_pose_pub = self._node.create_publisher(ROSPoseStamped, "/goal_pose", 10) - self.cancel_goal_pub = self._node.create_publisher(ROSBool, "/cancel_goal", 10) - self.soft_stop_pub = self._node.create_publisher(ROSInt8, "/soft_stop", 10) - self.joy_pub = self._node.create_publisher(ROSJoy, "/joy", 10) - - # ROS2 Subscribers - self.goal_reached_sub = self._node.create_subscription( - ROSBool, "/goal_reached", self._on_ros_goal_reached, 10 - ) - self.odom_sub = self._node.create_subscription( - ROSOdometry, "/state_estimation", self._on_ros_odom, 10 - ) - self.cmd_vel_sub = self._node.create_subscription( - ROSTwistStamped, "/cmd_vel", self._on_ros_cmd_vel, 10 - ) - self.goal_waypoint_sub = self._node.create_subscription( - ROSPointStamped, "/way_point", self._on_ros_goal_waypoint, 10 - ) - self.registered_scan_sub = self._node.create_subscription( - ROSPointCloud2, "/registered_scan", self._on_ros_registered_scan, 10 - ) - self.global_pointcloud_sub = self._node.create_subscription( - ROSPointCloud2, "/terrain_map_ext", self._on_ros_global_pointcloud, 10 - ) - self.path_sub = self._node.create_subscription(ROSPath, "/path", self._on_ros_path, 10) - self.tf_sub = self._node.create_subscription(ROSTFMessage, "/tf", self._on_ros_tf, 10) - - logger.info("NavigationModule initialized with ROS2 node") - - @rpc - def start(self): - self._running = True - self.spin_thread = threading.Thread(target=self._spin_node, daemon=True) - self.spin_thread.start() - - self.goal_req.subscribe(self._on_goal_pose) - self.cancel_goal.subscribe(self._on_cancel_goal) - - logger.info("NavigationModule started with ROS2 spinning") - - def _spin_node(self): - while self._running and rclpy.ok(): - try: - rclpy.spin_once(self._node, timeout_sec=0.1) - except Exception as e: - if self._running: - logger.error(f"ROS2 spin error: {e}") - - def _on_ros_goal_reached(self, msg: ROSBool): - self.goal_reach = msg.data - dimos_bool = Bool(data=msg.data) - self.goal_reached.publish(dimos_bool) - - def _on_ros_goal_waypoint(self, msg: ROSPointStamped): - dimos_pose = PoseStamped( - ts=time.time(), - frame_id=msg.header.frame_id, - position=Vector3(msg.point.x, msg.point.y, msg.point.z), - orientation=Quaternion(0.0, 0.0, 0.0, 1.0), - ) - self.goal_active.publish(dimos_pose) - - def _on_ros_cmd_vel(self, msg: ROSTwistStamped): - # Extract the twist from the stamped message - dimos_twist = Twist( - linear=Vector3(msg.twist.linear.x, msg.twist.linear.y, msg.twist.linear.z), - angular=Vector3(msg.twist.angular.x, msg.twist.angular.y, msg.twist.angular.z), - ) - self.cmd_vel.publish(dimos_twist) - - def _on_ros_odom(self, msg: ROSOdometry): - dimos_odom = Odometry.from_ros_msg(msg) - self.odom.publish(dimos_odom) - - dimos_pose = PoseStamped( - ts=dimos_odom.ts, - frame_id=dimos_odom.frame_id, - position=dimos_odom.pose.pose.position, - orientation=dimos_odom.pose.pose.orientation, - ) - self.odom_pose.publish(dimos_pose) - - def _on_ros_registered_scan(self, msg: ROSPointCloud2): - dimos_pointcloud = PointCloud2.from_ros_msg(msg) - self.pointcloud.publish(dimos_pointcloud) - - def _on_ros_global_pointcloud(self, msg: ROSPointCloud2): - dimos_pointcloud = PointCloud2.from_ros_msg(msg) - self.global_pointcloud.publish(dimos_pointcloud) - - def _on_ros_path(self, msg: ROSPath): - dimos_path = Path.from_ros_msg(msg) - self.path_active.publish(dimos_path) - - def _on_ros_tf(self, msg: ROSTFMessage): - ros_tf = TFMessage.from_ros_msg(msg) - - translation = Vector3( - self.sensor_to_base_link_transform[0], - self.sensor_to_base_link_transform[1], - self.sensor_to_base_link_transform[2], - ) - euler_angles = Vector3( - self.sensor_to_base_link_transform[3], - self.sensor_to_base_link_transform[4], - self.sensor_to_base_link_transform[5], - ) - rotation = euler_to_quaternion(euler_angles) - - sensor_to_base_link_tf = Transform( - translation=translation, - rotation=rotation, - frame_id="sensor", - child_frame_id="base_link", - ts=time.time(), - ) - - map_to_world_tf = Transform( - translation=Vector3(0.0, 0.0, 0.0), - rotation=euler_to_quaternion(Vector3(0.0, 0.0, 0.0)), - frame_id="map", - child_frame_id="world", - ts=time.time(), - ) - - self.tf.publish(sensor_to_base_link_tf, map_to_world_tf, *ros_tf.transforms) - - def _on_goal_pose(self, msg: PoseStamped): - self.navigate_to(msg) - - def _on_cancel_goal(self, msg: Bool): - if msg.data: - self.stop() - - def _set_autonomy_mode(self): - joy_msg = ROSJoy() - joy_msg.axes = [ - 0.0, # axis 0 - 0.0, # axis 1 - -1.0, # axis 2 - 0.0, # axis 3 - 1.0, # axis 4 - 1.0, # axis 5 - 0.0, # axis 6 - 0.0, # axis 7 - ] - joy_msg.buttons = [ - 0, # button 0 - 0, # button 1 - 0, # button 2 - 0, # button 3 - 0, # button 4 - 0, # button 5 - 0, # button 6 - 1, # button 7 - controls autonomy mode - 0, # button 8 - 0, # button 9 - 0, # button 10 - ] - self.joy_pub.publish(joy_msg) - logger.info("Setting autonomy mode via Joy message") - - @rpc - def navigate_to(self, pose: PoseStamped, timeout: float = 60.0) -> bool: - """ - Navigate to a target pose by publishing to ROS topics. - - Args: - pose: Target pose to navigate to - timeout: Maximum time to wait for goal (seconds) - - Returns: - True if navigation was successful - """ - logger.info( - f"Navigating to goal: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f})" - ) - - self.goal_reach = None - self._set_autonomy_mode() - - # Enable soft stop (0 = enable) - soft_stop_msg = ROSInt8() - soft_stop_msg.data = 0 - self.soft_stop_pub.publish(soft_stop_msg) - - ros_pose = pose.to_ros_msg() - self.goal_pose_pub.publish(ros_pose) - - # Wait for goal to be reached - start_time = time.time() - while time.time() - start_time < timeout: - if self.goal_reach is not None: - soft_stop_msg.data = 2 - self.soft_stop_pub.publish(soft_stop_msg) - return self.goal_reach - time.sleep(0.1) - - self.stop_navigation() - logger.warning(f"Navigation timed out after {timeout} seconds") - return False - - @rpc - def stop_navigation(self) -> bool: - """ - Stop current navigation by publishing to ROS topics. - - Returns: - True if stop command was sent successfully - """ - logger.info("Stopping navigation") - - cancel_msg = ROSBool() - cancel_msg.data = True - self.cancel_goal_pub.publish(cancel_msg) - - soft_stop_msg = ROSInt8() - soft_stop_msg.data = 2 - self.soft_stop_pub.publish(soft_stop_msg) - - return True - - @rpc - def stop(self): - try: - self._running = False - if self.spin_thread: - self.spin_thread.join(timeout=1) - self._node.destroy_node() - except Exception as e: - logger.error(f"Error during shutdown: {e}") - - -class NavBot: - """ - NavBot wrapper that deploys NavigationModule with proper DIMOS/ROS2 integration. - """ - - def __init__(self, dimos=None, sensor_to_base_link_transform=None): - """ - Initialize NavBot. - - Args: - dimos: DIMOS instance (creates new one if None) - sensor_to_base_link_transform: Optional [x, y, z, roll, pitch, yaw] transform - """ - if dimos is None: - self.dimos = core.start(2) - else: - self.dimos = dimos - - self.sensor_to_base_link_transform = sensor_to_base_link_transform or [ - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - ] - self.navigation_module = None - - def start(self): - logger.info("Deploying navigation module...") - self.navigation_module = self.dimos.deploy( - ROSNavigationModule, sensor_to_base_link_transform=self.sensor_to_base_link_transform - ) - - self.navigation_module.goal_req.transport = core.LCMTransport("/goal", PoseStamped) - self.navigation_module.cancel_goal.transport = core.LCMTransport("/cancel_goal", Bool) - - self.navigation_module.pointcloud.transport = core.LCMTransport( - "/pointcloud_map", PointCloud2 - ) - self.navigation_module.global_pointcloud.transport = core.LCMTransport( - "/global_pointcloud", PointCloud2 - ) - self.navigation_module.goal_active.transport = core.LCMTransport( - "/goal_active", PoseStamped - ) - self.navigation_module.path_active.transport = core.LCMTransport("/path_active", Path) - self.navigation_module.goal_reached.transport = core.LCMTransport("/goal_reached", Bool) - self.navigation_module.odom.transport = core.LCMTransport("/odom", Odometry) - self.navigation_module.cmd_vel.transport = core.LCMTransport("/cmd_vel", Twist) - self.navigation_module.odom_pose.transport = core.LCMTransport("/odom_pose", PoseStamped) - - self.navigation_module.start() - - def shutdown(self): - logger.info("Shutting down NavBot...") - - if self.navigation_module: - self.navigation_module.stop() - - if rclpy.ok(): - rclpy.shutdown() - - logger.info("NavBot shutdown complete") - - -def main(): - pubsub.lcm.autoconf() - nav_bot = NavBot() - nav_bot.start() - - logger.info("\nTesting navigation in 2 seconds...") - time.sleep(2) - - test_pose = PoseStamped( - ts=time.time(), - frame_id="map", - position=Vector3(1.0, 1.0, 0.0), - orientation=Quaternion(0.0, 0.0, 0.0, 0.0), - ) - - logger.info(f"Sending navigation goal to: (1.0, 1.0, 0.0)") - - if nav_bot.navigation_module: - success = nav_bot.navigation_module.navigate_to(test_pose, timeout=30.0) - if success: - logger.info("✓ Navigation goal reached!") - else: - logger.warning("✗ Navigation failed or timed out") - - try: - logger.info("\nNavBot running. Press Ctrl+C to stop.") - while True: - time.sleep(1) - except KeyboardInterrupt: - logger.info("\nShutting down...") - nav_bot.shutdown() - - -if __name__ == "__main__": - main() diff --git a/dimos/navigation/rosnav/rosnav.py b/dimos/navigation/rosnav/rosnav.py deleted file mode 100644 index 440a0f4269..0000000000 --- a/dimos/navigation/rosnav/rosnav.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.core import In, Module, Out -from dimos.msgs.geometry_msgs import PoseStamped, Twist -from dimos.msgs.nav_msgs import Path -from dimos.msgs.sensor_msgs import PointCloud2 -from dimos.msgs.std_msgs import Bool - - -class ROSNav(Module): - goal_req: In[PoseStamped] = None # type: ignore - goal_active: Out[PoseStamped] = None # type: ignore - path_active: Out[Path] = None # type: ignore - cancel_goal: In[Bool] = None # type: ignore - cmd_vel: Out[Twist] = None # type: ignore - - # PointcloudPerception attributes - pointcloud: Out[PointCloud2] = None # type: ignore - - # Global3DMapSpec attributes - global_pointcloud: Out[PointCloud2] = None # type: ignore - - def start(self) -> None: - pass - - def stop(self) -> None: - pass - - def navigate_to(self, target: PoseStamped) -> None: - # TODO: Implement navigation logic - pass - - def stop_navigation(self) -> None: - # TODO: Implement stop logic - pass diff --git a/dimos/robot/unitree_webrtc/connection/__init__.py b/dimos/robot/unitree_webrtc/connection/__init__.py new file mode 100644 index 0000000000..cd93ef78ac --- /dev/null +++ b/dimos/robot/unitree_webrtc/connection/__init__.py @@ -0,0 +1 @@ +import dimos.robot.unitree_webrtc.connection.g1 as g1 diff --git a/dimos/robot/unitree_webrtc/connection.py b/dimos/robot/unitree_webrtc/connection/connection.py similarity index 97% rename from dimos/robot/unitree_webrtc/connection.py rename to dimos/robot/unitree_webrtc/connection/connection.py index 8ddc77ac63..abfba92fa9 100644 --- a/dimos/robot/unitree_webrtc/connection.py +++ b/dimos/robot/unitree_webrtc/connection/connection.py @@ -30,7 +30,7 @@ from reactivex.observable import Observable from reactivex.subject import Subject -from dimos.core import In, Module, Out, rpc +from dimos.core import DimosCluster, In, Module, Out, rpc from dimos.core.resource import Resource from dimos.msgs.geometry_msgs import Pose, Transform, Twist, Vector3 from dimos.msgs.sensor_msgs import Image @@ -402,3 +402,18 @@ async def async_disconnect(): if hasattr(self, "thread") and self.thread.is_alive(): self.thread.join(timeout=2.0) + + +def deploy(dimos: DimosCluster, ip: str) -> None: + from dimos.robot.foxglove_bridge import FoxgloveBridge + + connection = dimos.deploy(UnitreeWebRTCConnection, ip=ip) + + bridge = FoxgloveBridge( + shm_channels=[ + "/image#sensor_msgs.Image", + "/lidar#sensor_msgs.PointCloud2", + ] + ) + bridge.start() + connection.start() diff --git a/dimos/robot/unitree_webrtc/connection/g1.py b/dimos/robot/unitree_webrtc/connection/g1.py new file mode 100644 index 0000000000..af3f07cf69 --- /dev/null +++ b/dimos/robot/unitree_webrtc/connection/g1.py @@ -0,0 +1,69 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +from reactivex.disposable import Disposable + +from dimos.core import DimosCluster, In, LCMTransport, Module, Out, rpc +from dimos.msgs.geometry_msgs import ( + Twist, + TwistStamped, +) +from dimos.robot.unitree_webrtc.connection.connection import UnitreeWebRTCConnection + + +class G1Connection(Module): + cmd_vel: In[TwistStamped] = None + ip: str + + def __init__(self, ip: str = None, **kwargs): + super().__init__(**kwargs) + self.ip = ip + + @rpc + def start(self): + super().start() + self.connection = UnitreeWebRTCConnection(self.ip) + self.connection.start() + + unsub = self.cmd_vel.subscribe(self.move) + self._disposables.add(Disposable(unsub)) + + @rpc + def stop(self) -> None: + self.connection.stop() + super().stop() + + @rpc + def move(self, twist_stamped: TwistStamped, duration: float = 0.0): + """Send movement command to robot.""" + twist = Twist(linear=twist_stamped.linear, angular=twist_stamped.angular) + self.connection.move(twist, duration) + + @rpc + def publish_request(self, topic: str, data: dict): + """Forward WebRTC publish requests to connection.""" + return self.connection.publish_request(topic, data) + + +class LocalPlanner(Protocol): + cmd_vel: Out[TwistStamped] + + +def deploy(dimos: DimosCluster, ip: str, local_planner: LocalPlanner) -> G1Connection: + connection = dimos.deploy(G1Connection, ip) + connection.cmd_vel.connect(local_planner.cmd_vel) + connection.start() + return connection diff --git a/dimos/robot/unitree_webrtc/modular/__init__.py b/dimos/robot/unitree_webrtc/modular/__init__.py index d823cd796e..21d37d2dbd 100644 --- a/dimos/robot/unitree_webrtc/modular/__init__.py +++ b/dimos/robot/unitree_webrtc/modular/__init__.py @@ -1,2 +1,2 @@ -from dimos.robot.unitree_webrtc.modular.connection_module import deploy_connection -from dimos.robot.unitree_webrtc.modular.navigation import deploy_navigation +# from dimos.robot.unitree_webrtc.modular.connection_module import deploy_connection +# from dimos.robot.unitree_webrtc.modular.navigation import deploy_navigation diff --git a/dimos/robot/unitree_webrtc/modular/ivan_g1.py b/dimos/robot/unitree_webrtc/modular/ivan_g1.py new file mode 100644 index 0000000000..cd83dd1468 --- /dev/null +++ b/dimos/robot/unitree_webrtc/modular/ivan_g1.py @@ -0,0 +1,91 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE +from dimos.core import DimosCluster, LCMTransport, pSHMTransport, start, wait_exit +from dimos.hardware.camera import zed +from dimos.hardware.camera.module import CameraModule +from dimos.hardware.camera.webcam import Webcam +from dimos.msgs.geometry_msgs import ( + PoseStamped, + Quaternion, + Transform, + Vector3, +) +from dimos.msgs.sensor_msgs import CameraInfo +from dimos.navigation import rosnav +from dimos.robot.unitree_webrtc.connection import g1 +from dimos.robot.unitree_webrtc.modular.misc import deploy_foxglove +from dimos.utils.logging_config import setup_logger + +logger = setup_logger(__name__) + + +def deploy_monozed(dimos) -> CameraModule: + camera = dimos.deploy( + CameraModule, + transform=Transform( + translation=Vector3(0.05, 0.0, 0.0), + rotation=Quaternion.from_euler(Vector3(0.0, 0.2, 0.0)), + frame_id="sensor", + child_frame_id="camera_link", + ), + hardware=lambda: Webcam( + camera_index=0, + frequency=5, + stereo_slice="left", + camera_info=zed.CameraInfo.SingleWebcam, + ), + ) + camera.image.transport = pSHMTransport("/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE) + camera.camera_info.transport = LCMTransport("/camera_info", CameraInfo) + camera.start() + return camera + + +def ivan_g1(dimos: DimosCluster, ip: str) -> None: + nav = rosnav.deploy(dimos) + connection = g1.deploy(dimos, ip, nav) + zed = deploy_monozed(dimos) + fg = deploy_foxglove(dimos) + + time.sleep(5) + + test_pose = PoseStamped( + ts=time.time(), + frame_id="map", + position=Vector3(0.0, 0.0, 0.0), + orientation=Quaternion(0.0, 0.0, 0.0, 0.0), + ) + + nav.navigate_to(test_pose) + wait_exit() + dimos.close_all() + + +if __name__ == "__main__": + import argparse + import os + + from dotenv import load_dotenv + + load_dotenv() + + parser = argparse.ArgumentParser(description="Unitree G1 Humanoid Robot Control") + parser.add_argument("--ip", default=os.getenv("ROBOT_IP"), help="Robot IP address") + + args = parser.parse_args() + ivan_g1(start(8), args.ip) diff --git a/dimos/robot/unitree_webrtc/modular/misc.py b/dimos/robot/unitree_webrtc/modular/misc.py new file mode 100644 index 0000000000..7880426a6f --- /dev/null +++ b/dimos/robot/unitree_webrtc/modular/misc.py @@ -0,0 +1,33 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging + +from dimos.core import DimosCluster +from dimos.robot.foxglove_bridge import FoxgloveBridge + +logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) + + +def deploy_foxglove(dimos: DimosCluster) -> FoxgloveBridge: + foxglove_bridge = dimos.deploy( + FoxgloveBridge, + shm_channels=[ + "/image#sensor_msgs.Image", + # "/lidar#sensor_msgs.PointCloud2", + # "/map#sensor_msgs.PointCloud2", + ], + ) + foxglove_bridge.start() + return foxglove_bridge diff --git a/dimos/utils/logging_config.py b/dimos/utils/logging_config.py index a1e1a25ca4..a0a6a5fc4a 100644 --- a/dimos/utils/logging_config.py +++ b/dimos/utils/logging_config.py @@ -17,13 +17,20 @@ This module sets up a logger with color output for different log levels. """ -import os import logging -import colorlog +import os from typing import Optional +import colorlog + logging.basicConfig(format="%(name)s - %(levelname)s - %(message)s") +logging.getLogger("aiortc.codecs.h264").setLevel(logging.ERROR) +logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) +logging.getLogger("websockets.server").setLevel(logging.ERROR) +logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) +logging.getLogger("asyncio").setLevel(logging.ERROR) + def setup_logger( name: str, level: Optional[int] = None, log_format: Optional[str] = None From 13456126e55efa87cbc93b31cc70c7c6bed2b51a Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 17 Oct 2025 16:42:56 -0700 Subject: [PATCH 09/40] rosnav pointcloud frequency --- dimos/navigation/rosnav.py | 34 +++++++++++++--------- dimos/robot/unitree_webrtc/modular/misc.py | 4 +-- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index 1036739f25..5316cec0dc 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -21,6 +21,7 @@ import logging import threading import time +from typing import Optional import rclpy from geometry_msgs.msg import PointStamped as ROSPointStamped @@ -64,6 +65,9 @@ class ROSNav(Module): path_active: Out[Path] = None cmd_vel: Out[TwistStamped] = None + _local_pointcloud: Optional[ROSPointCloud2] = None + _global_pointcloud: Optional[ROSPointCloud2] = None + def __init__(self, sensor_to_base_link_transform=None, *args, **kwargs): super().__init__(*args, **kwargs) @@ -113,28 +117,32 @@ def __init__(self, sensor_to_base_link_transform=None, *args, **kwargs): @rpc def start(self): self._running = True - self.spin_thread = threading.Thread(target=self._spin_node, daemon=True) - self.spin_thread.start() + # TODO these should be rxpy streams, rxpy has a way to convert callbacks to streams def broadcast_lidar(): while self._running: - if not hasattr(self, "_local_pointcloud"): - return - self.pointcloud.publish(PointCloud2.from_ros_msg(self._local_pointcloud)) + if self._local_pointcloud: + self.pointcloud.publish( + PointCloud2.from_ros_msg(self._local_pointcloud), + ) time.sleep(0.5) def broadcast_map(): while self._running: - if not hasattr(self, "_global_pointcloud"): - return - self.global_pointcloud.publish(PointCloud2.from_ros_msg(self.global_pointcloud)) + if self._global_pointcloud: + self.global_pointcloud.publish( + PointCloud2.from_ros_msg(self._global_pointcloud) + ) time.sleep(1.0) self.map_broadcast_thread = threading.Thread(target=broadcast_map, daemon=True) self.lidar_broadcast_thread = threading.Thread(target=broadcast_lidar, daemon=True) + self.map_broadcast_thread.start() + self.lidar_broadcast_thread.start() + self.spin_thread = threading.Thread(target=self._spin_node, daemon=True) + self.spin_thread.start() self.goal_req.subscribe(self._on_goal_pose) - logger.info("NavigationModule started with ROS2 spinning") def _spin_node(self): @@ -311,10 +319,10 @@ def stop(self): def deploy(dimos: DimosCluster): nav = dimos.deploy(ROSNav) - # nav.pointcloud.transport = pSHMTransport("/lidar") - # nav.global_pointcloud.transport = pSHMTransport("/map") - nav.pointcloud.transport = LCMTransport("/lidar", PointCloud2) - nav.global_pointcloud.transport = LCMTransport("/map", PointCloud2) + nav.pointcloud.transport = pSHMTransport("/lidar") + nav.global_pointcloud.transport = pSHMTransport("/map") + # nav.pointcloud.transport = LCMTransport("/lidar", PointCloud2) + # nav.global_pointcloud.transport = LCMTransport("/map", PointCloud2) nav.goal_req.transport = LCMTransport("/goal_req", PoseStamped) nav.goal_req.transport = LCMTransport("/goal_req", PoseStamped) diff --git a/dimos/robot/unitree_webrtc/modular/misc.py b/dimos/robot/unitree_webrtc/modular/misc.py index 7880426a6f..7df99237c7 100644 --- a/dimos/robot/unitree_webrtc/modular/misc.py +++ b/dimos/robot/unitree_webrtc/modular/misc.py @@ -25,8 +25,8 @@ def deploy_foxglove(dimos: DimosCluster) -> FoxgloveBridge: FoxgloveBridge, shm_channels=[ "/image#sensor_msgs.Image", - # "/lidar#sensor_msgs.PointCloud2", - # "/map#sensor_msgs.PointCloud2", + "/lidar#sensor_msgs.PointCloud2", + "/map#sensor_msgs.PointCloud2", ], ) foxglove_bridge.start() From f3d604fec33b72ec58aa4153fb66bd7cf54b8ece Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 17 Oct 2025 16:48:07 -0700 Subject: [PATCH 10/40] camera frequency adjustment --- dimos/hardware/camera/module.py | 8 ++------ dimos/robot/unitree_webrtc/modular/ivan_g1.py | 1 + 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/dimos/hardware/camera/module.py b/dimos/hardware/camera/module.py index 875b6f66a7..e75f06a92d 100644 --- a/dimos/hardware/camera/module.py +++ b/dimos/hardware/camera/module.py @@ -47,6 +47,7 @@ class CameraModuleConfig(ModuleConfig): frame_id: str = "camera_link" transform: Optional[Transform] = field(default_factory=default_transform) hardware: Callable[[], CameraHardware] | CameraHardware = Webcam + frequency: float = 5.0 class CameraModule(Module): @@ -60,9 +61,6 @@ class CameraModule(Module): default_config = CameraModuleConfig - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - @rpc def start(self): if callable(self.config.hardware): @@ -73,9 +71,7 @@ def start(self): if self._module_subscription: return "already started" - stream = self.hardware.image_stream().pipe(sharpness_barrier(5)) - - # camera_info_stream = self.camera_info_stream(frequency=5.0) + stream = self.hardware.image_stream().pipe(sharpness_barrier(self.config.frequency)) def publish_info(camera_info: CameraInfo): self.camera_info.publish(camera_info) diff --git a/dimos/robot/unitree_webrtc/modular/ivan_g1.py b/dimos/robot/unitree_webrtc/modular/ivan_g1.py index cd83dd1468..505c2adca7 100644 --- a/dimos/robot/unitree_webrtc/modular/ivan_g1.py +++ b/dimos/robot/unitree_webrtc/modular/ivan_g1.py @@ -37,6 +37,7 @@ def deploy_monozed(dimos) -> CameraModule: camera = dimos.deploy( CameraModule, + frequency=4.0, transform=Transform( translation=Vector3(0.05, 0.0, 0.0), rotation=Quaternion.from_euler(Vector3(0.0, 0.2, 0.0)), From 13cd96398fd6c095597124d0f3053674203881dc Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 17 Oct 2025 18:02:29 -0700 Subject: [PATCH 11/40] detection module deployment --- dimos/perception/detection/module2D.py | 59 ++++++++-------- dimos/perception/detection/module3D.py | 67 +++++++++++++++++-- dimos/perception/detection/moduleDB.py | 42 ++++++++++-- dimos/spec/__init__.py | 1 + .../spec.py => spec/perception.py} | 11 ++- 5 files changed, 137 insertions(+), 43 deletions(-) create mode 100644 dimos/spec/__init__.py rename dimos/{perception/spec.py => spec/perception.py} (82%) diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py index c4b0ba5a43..a52745452c 100644 --- a/dimos/perception/detection/module2D.py +++ b/dimos/perception/detection/module2D.py @@ -22,7 +22,8 @@ from reactivex.observable import Observable from reactivex.subject import Subject -from dimos.core import In, Module, Out, rpc +from dimos import spec +from dimos.core import DimosCluster, In, Module, Out, rpc from dimos.core.module import ModuleConfig from dimos.msgs.geometry_msgs import Transform, Vector3 from dimos.msgs.sensor_msgs import Image @@ -42,7 +43,7 @@ class Config(ModuleConfig): max_freq: float = 10 detector: Optional[Callable[[Any], Detector]] = YoloPersonDetector - camera_info: CameraInfo = CameraInfo() + publish_detection_images: bool = True class Detection2DModule(Module): @@ -83,33 +84,6 @@ def sharp_image_stream(self) -> Observable[Image]: def detection_stream_2d(self) -> Observable[ImageDetections2D]: return backpressure(self.image.observable().pipe(ops.map(self.process_image_frame))) - def pixel_to_3d( - self, - pixel: Tuple[int, int], - camera_info: CameraInfo, - assumed_depth: float = 1.0, - ) -> Vector3: - """Unproject 2D pixel coordinates to 3D position in camera optical frame. - - Args: - camera_info: Camera calibration information - assumed_depth: Assumed depth in meters (default 1.0m from camera) - - Returns: - Vector3 position in camera optical frame coordinates - """ - # Extract camera intrinsics - fx, fy = camera_info.K[0], camera_info.K[4] - cx, cy = camera_info.K[2], camera_info.K[5] - - # Unproject pixel to normalized camera coordinates - x_norm = (pixel[0] - cx) / fx - y_norm = (pixel[1] - cy) / fy - - # Create 3D point at assumed depth in camera optical frame - # Camera optical frame: X right, Y down, Z forward - return Vector3(x_norm * assumed_depth, y_norm * assumed_depth, assumed_depth) - def track(self, detections: ImageDetections2D): sensor_frame = self.tf.get("sensor", "camera_optical", detections.image.ts, 5.0) @@ -166,7 +140,32 @@ def publish_cropped_images(detections: ImageDetections2D): image_topic = getattr(self, "detected_image_" + str(index)) image_topic.publish(detection.cropped_image()) - self.detection_stream_2d().subscribe(publish_cropped_images) + if self.config.publish_detection_images: + self.detection_stream_2d().subscribe(publish_cropped_images) @rpc def stop(self): ... + + +def deploy( + dimos: DimosCluster, + camera_info: CameraInfo, + camera: spec.Camera, + prefix: str = "/detector2d", + **kwargs, +) -> Detection2DModule: + from dimos.core import LCMTransport + + detector = Detection2DModule(camera_info=camera.config.camera_info, **kwargs) + + detector.image.connect(camera.image) + + detector.annotations.transport = LCMTransport(f"{prefix}/annotations", ImageAnnotations) + detector.detections.transport = LCMTransport(f"{prefix}/detections", Detection2DArray) + + detector.detected_image_0.transport = LCMTransport(f"{prefix}/image/0", Image) + detector.detected_image_1.transport = LCMTransport(f"{prefix}/image/1", Image) + detector.detected_image_2.transport = LCMTransport(f"{prefix}/image/2", Image) + + detector.start() + return detector diff --git a/dimos/perception/detection/module3D.py b/dimos/perception/detection/module3D.py index b8fe42da9a..548f294e17 100644 --- a/dimos/perception/detection/module3D.py +++ b/dimos/perception/detection/module3D.py @@ -13,17 +13,18 @@ # limitations under the License. -from typing import Optional +from typing import Optional, Tuple from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations from lcm_msgs.foxglove_msgs import SceneUpdate from reactivex import operators as ops from reactivex.observable import Observable +from dimos import spec from dimos.agents2 import skill -from dimos.core import In, Out, rpc -from dimos.msgs.geometry_msgs import Transform -from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.core import DimosCluster, In, Out, rpc +from dimos.msgs.geometry_msgs import Transform, Vector3 +from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 from dimos.msgs.vision_msgs import Detection2DArray from dimos.perception.detection.module2D import Config as Module2DConfig from dimos.perception.detection.module2D import Detection2DModule @@ -82,6 +83,32 @@ def process_frame( return ImageDetections3DPC(detections.image, detection3d_list) + def pixel_to_3d( + self, + pixel: Tuple[int, int], + assumed_depth: float = 1.0, + ) -> Vector3: + """Unproject 2D pixel coordinates to 3D position in camera optical frame. + + Args: + camera_info: Camera calibration information + assumed_depth: Assumed depth in meters (default 1.0m from camera) + + Returns: + Vector3 position in camera optical frame coordinates + """ + # Extract camera intrinsics + fx, fy = self.camera_info.K[0], self.camera_info.K[4] + cx, cy = self.camera_info.K[2], self.camera_info.K[5] + + # Unproject pixel to normalized camera coordinates + x_norm = (pixel[0] - cx) / fx + y_norm = (pixel[1] - cy) / fy + + # Create 3D point at assumed depth in camera optical frame + # Camera optical frame: X right, Y down, Z forward + return Vector3(x_norm * assumed_depth, y_norm * assumed_depth, assumed_depth) + @skill # type: ignore[arg-type] def ask_vlm(self, question: str) -> str | ImageDetections3DPC: """ @@ -134,3 +161,35 @@ def _publish_detections(self, detections: ImageDetections3DPC): for index, detection in enumerate(detections[:3]): pointcloud_topic = getattr(self, "detected_pointcloud_" + str(index)) pointcloud_topic.publish(detection.pointcloud) + + +def deploy( + dimos: DimosCluster, + camera_info: CameraInfo, + lidar: spec.Pointcloud, + camera: spec.Camera, + prefix: str = "/detector3d", +) -> Detection3DModule: + from dimos.core import LCMTransport + + detector = Detection3DModule( + camera_info=camera.config.camera_info, + ) + + detector.image.connect(camera.image) + detector.pointcloud.connect(lidar.pointcloud) + + detector.annotations.transport = LCMTransport(f"{prefix}/annotations", ImageAnnotations) + detector.detections.transport = LCMTransport(f"{prefix}/detections", Detection2DArray) + detector.scene_update.transport = LCMTransport(f"{prefix}/scene_update", SceneUpdate) + + detector.detected_image_0.transport = LCMTransport(f"{prefix}/image/0", Image) + detector.detected_image_1.transport = LCMTransport(f"{prefix}/image/1", Image) + detector.detected_image_2.transport = LCMTransport(f"{prefix}/image/2", Image) + + detector.detected_pointcloud_0.transport = LCMTransport(f"{prefix}/pointcloud/0", PointCloud2) + detector.detected_pointcloud_1.transport = LCMTransport(f"{prefix}/pointcloud/1", PointCloud2) + detector.detected_pointcloud_2.transport = LCMTransport(f"{prefix}/pointcloud/2", PointCloud2) + + detector.start() + return detector diff --git a/dimos/perception/detection/moduleDB.py b/dimos/perception/detection/moduleDB.py index ccc14d96f5..6cdde0335a 100644 --- a/dimos/perception/detection/moduleDB.py +++ b/dimos/perception/detection/moduleDB.py @@ -20,17 +20,14 @@ from lcm_msgs.foxglove_msgs import SceneUpdate from reactivex.observable import Observable -from dimos.agents2 import Agent, Output, Reducer, Stream, skill -from dimos.core import In, Out, rpc +from dimos import spec +from dimos.core import DimosCluster, In, Out, rpc from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 from dimos.msgs.vision_msgs import Detection2DArray from dimos.perception.detection.module3D import Detection3DModule -from dimos.perception.detection.type import Detection3D, ImageDetections3DPC, TableStr +from dimos.perception.detection.type import ImageDetections3DPC, TableStr from dimos.perception.detection.type.detection3d import Detection3DPC -from dimos.protocol.skill.skill import skill -from dimos.protocol.skill.type import Output, Reducer, Stream -from dimos.types.timestamped import to_datetime # Represents an object in space, as collection of 3d detections over time @@ -309,3 +306,34 @@ def to_foxglove_scene_update(self) -> "SceneUpdate": def __len__(self): return len(self.objects.values()) + + +def deploy( + dimos: DimosCluster, + camera_info: CameraInfo, + lidar: spec.Pointcloud, + camera: spec.Camera, + prefix: str = "/objectdb", + **kwargs, +) -> ObjectDBModule: + from dimos.core import LCMTransport + + detector = ObjectDBModule(camera_info=camera.config.camera_info, **kwargs) + + detector.image.connect(camera.image) + detector.pointcloud.connect(lidar.pointcloud) + + detector.annotations.transport = LCMTransport(f"{prefix}/annotations", ImageAnnotations) + detector.detections.transport = LCMTransport(f"{prefix}/detections", Detection2DArray) + detector.scene_update.transport = LCMTransport(f"{prefix}/scene_update", SceneUpdate) + + detector.detected_image_0.transport = LCMTransport(f"{prefix}/image/0", Image) + detector.detected_image_1.transport = LCMTransport(f"{prefix}/image/1", Image) + detector.detected_image_2.transport = LCMTransport(f"{prefix}/image/2", Image) + + detector.detected_pointcloud_0.transport = LCMTransport(f"{prefix}/pointcloud/0", PointCloud2) + detector.detected_pointcloud_1.transport = LCMTransport(f"{prefix}/pointcloud/1", PointCloud2) + detector.detected_pointcloud_2.transport = LCMTransport(f"{prefix}/pointcloud/2", PointCloud2) + + detector.start() + return detector diff --git a/dimos/spec/__init__.py b/dimos/spec/__init__.py new file mode 100644 index 0000000000..556faa5561 --- /dev/null +++ b/dimos/spec/__init__.py @@ -0,0 +1 @@ +from dimos.spec.perception import Camera, Image, PointCloud diff --git a/dimos/perception/spec.py b/dimos/spec/perception.py similarity index 82% rename from dimos/perception/spec.py rename to dimos/spec/perception.py index de53ce9bd7..3a1ef05686 100644 --- a/dimos/perception/spec.py +++ b/dimos/spec/perception.py @@ -15,8 +15,15 @@ from typing import Protocol from dimos.core import Out -from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.sensor_msgs import Image, PointCloud2 -class PointcloudPerception(Protocol): +class Image(Protocol): + image: Out[Image] + + +Camera = Image + + +class Pointcloud(Protocol): pointcloud: Out[PointCloud2] From 401ac434f4cc72a8a31cbe522fde5632d6da2f85 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 17 Oct 2025 19:13:56 -0700 Subject: [PATCH 12/40] fixing run files --- dimos/agents2/__init__.py | 2 +- dimos/agents2/agent.py | 30 +- dimos/perception/detection/module2D.py | 19 +- dimos/perception/detection/module3D.py | 5 +- dimos/perception/detection/moduleDB.py | 7 +- dimos/robot/unitree_webrtc/connection/g1.py | 7 +- dimos/robot/unitree_webrtc/connection/go2.py | 299 ++++++++++++++++++ dimos/robot/unitree_webrtc/modular/ivan_g1.py | 19 +- .../robot/unitree_webrtc/modular/ivan_go2.py | 59 ++++ .../unitree_webrtc/modular/ivan_unitree.py | 139 -------- dimos/spec/__init__.py | 3 +- dimos/spec/perception.py | 4 +- 12 files changed, 422 insertions(+), 171 deletions(-) create mode 100644 dimos/robot/unitree_webrtc/connection/go2.py create mode 100644 dimos/robot/unitree_webrtc/modular/ivan_go2.py delete mode 100644 dimos/robot/unitree_webrtc/modular/ivan_unitree.py diff --git a/dimos/agents2/__init__.py b/dimos/agents2/__init__.py index 28a48430b6..c817bb3aee 100644 --- a/dimos/agents2/__init__.py +++ b/dimos/agents2/__init__.py @@ -7,7 +7,7 @@ ToolMessage, ) -from dimos.agents2.agent import Agent +from dimos.agents2.agent import Agent, deploy from dimos.agents2.spec import AgentSpec from dimos.protocol.skill.skill import skill from dimos.protocol.skill.type import Output, Reducer, Stream diff --git a/dimos/agents2/agent.py b/dimos/agents2/agent.py index 94f418acc2..51952d4b4d 100644 --- a/dimos/agents2/agent.py +++ b/dimos/agents2/agent.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio -import json import datetime +import json import os import uuid from operator import itemgetter @@ -28,9 +28,8 @@ ToolMessage, ) -from dimos.agents2.spec import AgentSpec -from dimos.core import rpc -from dimos.msgs.sensor_msgs import Image +from dimos.agents2.spec import AgentSpec, Model, Provider +from dimos.core import DimosCluster, rpc from dimos.protocol.skill.coordinator import SkillCoordinator, SkillState, SkillStateDict from dimos.protocol.skill.type import Output from dimos.utils.logging_config import setup_logger @@ -346,3 +345,26 @@ def _write_debug_history_file(self): with open(file_path, "w") as f: json.dump(history, f, default=lambda x: repr(x), indent=2) + + +def deploy( + dimos: DimosCluster, + system_prompt="You are a helpful assistant for controlling a Unitree Go2 robot.", + model: Model = Model.GPT_4O, + provider: Provider = Provider.OPENAI, +) -> Agent: + from dimos.agents2.cli.human import HumanInput + + agent = dimos.deploy( + Agent, + system_prompt=system_prompt, + model=model, + provider=provider, + ) + + human_input = dimos.deploy(HumanInput) + agent.register_skills(human_input) + + agent.start() + + return agent diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py index a52745452c..913e84bd7a 100644 --- a/dimos/perception/detection/module2D.py +++ b/dimos/perception/detection/module2D.py @@ -12,12 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Any, Callable, Optional, Tuple +from typing import Any, Callable, Optional from dimos_lcm.foxglove_msgs.ImageAnnotations import ( ImageAnnotations, ) -from dimos_lcm.sensor_msgs import CameraInfo from reactivex import operators as ops from reactivex.observable import Observable from reactivex.subject import Subject @@ -26,7 +25,7 @@ from dimos.core import DimosCluster, In, Module, Out, rpc from dimos.core.module import ModuleConfig from dimos.msgs.geometry_msgs import Transform, Vector3 -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs import CameraInfo, Image from dimos.msgs.sensor_msgs.Image import sharpness_barrier from dimos.msgs.vision_msgs import Detection2DArray from dimos.perception.detection.detectors import Detector @@ -42,8 +41,9 @@ @dataclass class Config(ModuleConfig): max_freq: float = 10 - detector: Optional[Callable[[Any], Detector]] = YoloPersonDetector + detector: Optional[Callable[[Any], Detector]] = Yolo2DDetector publish_detection_images: bool = True + camera_info: CameraInfo = None # type: ignore class Detection2DModule(Module): @@ -64,7 +64,6 @@ class Detection2DModule(Module): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.config: Config = Config(**kwargs) self.detector = self.config.detector() self.vlm_detections_subject = Subject() self.previous_detection_count = 0 @@ -82,7 +81,7 @@ def sharp_image_stream(self) -> Observable[Image]: @simple_mcache def detection_stream_2d(self) -> Observable[ImageDetections2D]: - return backpressure(self.image.observable().pipe(ops.map(self.process_image_frame))) + return backpressure(self.sharp_image_stream().pipe(ops.map(self.process_image_frame))) def track(self, detections: ImageDetections2D): sensor_frame = self.tf.get("sensor", "camera_optical", detections.image.ts, 5.0) @@ -125,7 +124,7 @@ def track(self, detections: ImageDetections2D): @rpc def start(self): - self.detection_stream_2d().subscribe(self.track) + # self.detection_stream_2d().subscribe(self.track) self.detection_stream_2d().subscribe( lambda det: self.detections.publish(det.to_ros_detection2d_array()) @@ -144,19 +143,19 @@ def publish_cropped_images(detections: ImageDetections2D): self.detection_stream_2d().subscribe(publish_cropped_images) @rpc - def stop(self): ... + def stop(self): + return super().stop() def deploy( dimos: DimosCluster, - camera_info: CameraInfo, camera: spec.Camera, prefix: str = "/detector2d", **kwargs, ) -> Detection2DModule: from dimos.core import LCMTransport - detector = Detection2DModule(camera_info=camera.config.camera_info, **kwargs) + detector = Detection2DModule(**kwargs) detector.image.connect(camera.image) diff --git a/dimos/perception/detection/module3D.py b/dimos/perception/detection/module3D.py index 548f294e17..e27cd2a930 100644 --- a/dimos/perception/detection/module3D.py +++ b/dimos/perception/detection/module3D.py @@ -169,12 +169,11 @@ def deploy( lidar: spec.Pointcloud, camera: spec.Camera, prefix: str = "/detector3d", + **kwargs, ) -> Detection3DModule: from dimos.core import LCMTransport - detector = Detection3DModule( - camera_info=camera.config.camera_info, - ) + detector = dimos.deploy(Detection3DModule, camera_info=camera_info, **kwargs) detector.image.connect(camera.image) detector.pointcloud.connect(lidar.pointcloud) diff --git a/dimos/perception/detection/moduleDB.py b/dimos/perception/detection/moduleDB.py index 6cdde0335a..4f38bfffec 100644 --- a/dimos/perception/detection/moduleDB.py +++ b/dimos/perception/detection/moduleDB.py @@ -267,6 +267,10 @@ def scene_thread(): self.detection_stream_3d.subscribe(update_objects) + @rpc + def stop(self): + return super().stop() + def goto_object(self, object_id: str) -> Optional[Object3D]: """Go to object by id.""" return self.objects.get(object_id, None) @@ -310,7 +314,6 @@ def __len__(self): def deploy( dimos: DimosCluster, - camera_info: CameraInfo, lidar: spec.Pointcloud, camera: spec.Camera, prefix: str = "/objectdb", @@ -318,7 +321,7 @@ def deploy( ) -> ObjectDBModule: from dimos.core import LCMTransport - detector = ObjectDBModule(camera_info=camera.config.camera_info, **kwargs) + detector = ObjectDBModule(camera_info=camera.camera_info, **kwargs) detector.image.connect(camera.image) detector.pointcloud.connect(lidar.pointcloud) diff --git a/dimos/robot/unitree_webrtc/connection/g1.py b/dimos/robot/unitree_webrtc/connection/g1.py index af3f07cf69..b1b82a2dff 100644 --- a/dimos/robot/unitree_webrtc/connection/g1.py +++ b/dimos/robot/unitree_webrtc/connection/g1.py @@ -16,6 +16,7 @@ from reactivex.disposable import Disposable +from dimos import spec from dimos.core import DimosCluster, In, LCMTransport, Module, Out, rpc from dimos.msgs.geometry_msgs import ( Twist, @@ -58,11 +59,7 @@ def publish_request(self, topic: str, data: dict): return self.connection.publish_request(topic, data) -class LocalPlanner(Protocol): - cmd_vel: Out[TwistStamped] - - -def deploy(dimos: DimosCluster, ip: str, local_planner: LocalPlanner) -> G1Connection: +def deploy(dimos: DimosCluster, ip: str, local_planner: spec.LocalPlanner) -> G1Connection: connection = dimos.deploy(G1Connection, ip) connection.cmd_vel.connect(local_planner.cmd_vel) connection.start() diff --git a/dimos/robot/unitree_webrtc/connection/go2.py b/dimos/robot/unitree_webrtc/connection/go2.py new file mode 100644 index 0000000000..04eabc9884 --- /dev/null +++ b/dimos/robot/unitree_webrtc/connection/go2.py @@ -0,0 +1,299 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from typing import List, Optional, Protocol + +from dimos_lcm.sensor_msgs import CameraInfo +from reactivex.disposable import Disposable + +from dimos.core import DimosCluster, In, LCMTransport, Module, Out, pSHMTransport, rpc +from dimos.msgs.geometry_msgs import ( + PoseStamped, + Quaternion, + Transform, + Twist, + TwistStamped, + Vector3, +) +from dimos.msgs.nav_msgs import OccupancyGrid, Path +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.std_msgs import Header +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.robot.unitree_webrtc.connection.connection import UnitreeWebRTCConnection +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.data import get_data +from dimos.utils.decorators.decorators import simple_mcache +from dimos.utils.logging_config import setup_logger +from dimos.utils.testing import TimedSensorReplay + +logger = setup_logger(__file__, level=logging.INFO) + + +def _camera_info() -> CameraInfo: + fx, fy, cx, cy = (819.553492, 820.646595, 625.284099, 336.808987) + width, height = (1280, 720) + + # Camera matrix K (3x3) + K = [fx, 0, cx, 0, fy, cy, 0, 0, 1] + + # No distortion coefficients for now + D = [0.0, 0.0, 0.0, 0.0, 0.0] + + # Identity rotation matrix + R = [1, 0, 0, 0, 1, 0, 0, 0, 1] + + # Projection matrix P (3x4) + P = [fx, 0, cx, 0, 0, fy, cy, 0, 0, 0, 1, 0] + + base_msg = { + "D_length": len(D), + "height": height, + "width": width, + "distortion_model": "plumb_bob", + "D": D, + "K": K, + "R": R, + "P": P, + "binning_x": 0, + "binning_y": 0, + } + + return CameraInfo(**base_msg, header=Header("camera_optical")) + + +camera_info = _camera_info() + + +class FakeRTC(UnitreeWebRTCConnection): + dir_name = "unitree_go2_office_walk2" + + # we don't want UnitreeWebRTCConnection to init + def __init__( + self, + **kwargs, + ): + get_data(self.dir_name) + self.replay_config = { + "loop": kwargs.get("loop"), + "seek": kwargs.get("seek"), + "duration": kwargs.get("duration"), + } + + def connect(self): + pass + + def start(self): + pass + + def standup(self): + print("standup suppressed") + + def liedown(self): + print("liedown suppressed") + + @simple_mcache + def lidar_stream(self): + print("lidar stream start") + lidar_store = TimedSensorReplay(f"{self.dir_name}/lidar") + return lidar_store.stream(**self.replay_config) + + @simple_mcache + def odom_stream(self): + print("odom stream start") + odom_store = TimedSensorReplay(f"{self.dir_name}/odom") + return odom_store.stream(**self.replay_config) + + # we don't have raw video stream in the data set + @simple_mcache + def video_stream(self): + print("video stream start") + video_store = TimedSensorReplay(f"{self.dir_name}/video") + + return video_store.stream(**self.replay_config) + + def move(self, vector: Twist, duration: float = 0.0): + pass + + def publish_request(self, topic: str, data: dict): + """Fake publish request for testing.""" + return {"status": "ok", "message": "Fake publish"} + + +class GO2Connection(Module): + cmd_vel: In[Twist] = None + pointcloud: Out[LidarMessage] = None + image: Out[Image] = None + camera_info: Out[CameraInfo] = None + connection_type: str = "webrtc" + + ip: str + + def __init__( + self, + ip: str = None, + connection_type: str = "webrtc", + rectify_image: bool = True, + *args, + **kwargs, + ): + self.ip = ip + self.connection = None + Module.__init__(self, *args, **kwargs) + + @rpc + def start(self) -> None: + """Start the connection and subscribe to sensor streams.""" + super().start() + + match self.ip: + case None | "fake" | "": + self.connection = FakeRTC() + case "mujoco": + from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection + + self.connection = MujocoConnection() + case _: + self.connection = UnitreeWebRTCConnection(self.ip) + + self.connection.start() + + self._disposables.add( + self.connection.lidar_stream().subscribe(self.pointcloud.publish), + ) + + self._disposables.add( + self.connection.odom_stream().subscribe(self._publish_tf), + ) + + self._disposables.add( + self.connection.video_stream().subscribe(self.image.publish), + ) + + self._disposables.add( + self.cmd_vel.subscribe(self.move), + ) + + # Start publishing camera info at 1 Hz + from threading import Thread + + self._camera_info_thread = Thread( + target=self.publish_camera_info, + daemon=True, + ) + self._camera_info_thread.start() + + @rpc + def stop(self) -> None: + if self.connection: + self.connection.stop() + if hasattr(self, "_camera_info_thread"): + self._camera_info_thread.join(timeout=1.0) + super().stop() + + @classmethod + def _odom_to_tf(self, odom: PoseStamped) -> List[Transform]: + camera_link = Transform( + translation=Vector3(0.3, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + child_frame_id="camera_link", + ts=odom.ts, + ) + + camera_optical = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), + frame_id="camera_link", + child_frame_id="camera_optical", + ts=odom.ts, + ) + + sensor = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="world", + child_frame_id="sensor", + ts=odom.ts, + ) + + return [ + Transform.from_pose("base_link", odom), + camera_link, + camera_optical, + sensor, + ] + + def _publish_tf(self, msg): + self.tf.publish(*self._odom_to_tf(msg)) + + def publish_camera_info(self): + while True: + self.camera_info.publish(camera_info) + time.sleep(1.0) + + @rpc + def get_odom(self) -> Optional[PoseStamped]: + """Get the robot's odometry. + + Returns: + The robot's odometry + """ + return self._odom + + @rpc + def move(self, twist: Twist, duration: float = 0.0): + """Send movement command to robot.""" + self.connection.move(twist, duration) + + @rpc + def standup(self): + """Make the robot stand up.""" + return self.connection.standup() + + @rpc + def liedown(self): + """Make the robot lie down.""" + return self.connection.liedown() + + @rpc + def publish_request(self, topic: str, data: dict): + """Publish a request to the WebRTC connection. + Args: + topic: The RTC topic to publish to + data: The data dictionary to publish + Returns: + The result of the publish request + """ + return self.connection.publish_request(topic, data) + + +def deploy(dimos: DimosCluster, ip: str, prefix="") -> GO2Connection: + from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE + + connection = dimos.deploy(GO2Connection, ip) + + connection.pointcloud.transport = pSHMTransport( + f"{prefix}/lidar", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ) + connection.image.transport = pSHMTransport( + f"{prefix}/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ) + connection.cmd_vel.transport = LCMTransport(f"{prefix}/cmd_vel", TwistStamped) + connection.camera_info.transport = LCMTransport(f"{prefix}/camera_info", CameraInfo) + connection.start() + return connection diff --git a/dimos/robot/unitree_webrtc/modular/ivan_g1.py b/dimos/robot/unitree_webrtc/modular/ivan_g1.py index 505c2adca7..0ff922388a 100644 --- a/dimos/robot/unitree_webrtc/modular/ivan_g1.py +++ b/dimos/robot/unitree_webrtc/modular/ivan_g1.py @@ -57,11 +57,12 @@ def deploy_monozed(dimos) -> CameraModule: return camera -def ivan_g1(dimos: DimosCluster, ip: str) -> None: +def deploy(dimos: DimosCluster, ip: str) -> None: nav = rosnav.deploy(dimos) connection = g1.deploy(dimos, ip, nav) zed = deploy_monozed(dimos) - fg = deploy_foxglove(dimos) + + deploy_foxglove(dimos) time.sleep(5) @@ -73,8 +74,12 @@ def ivan_g1(dimos: DimosCluster, ip: str) -> None: ) nav.navigate_to(test_pose) - wait_exit() - dimos.close_all() + + return { + "nav": nav, + "connection": connection, + "zed": zed, + } if __name__ == "__main__": @@ -89,4 +94,8 @@ def ivan_g1(dimos: DimosCluster, ip: str) -> None: parser.add_argument("--ip", default=os.getenv("ROBOT_IP"), help="Robot IP address") args = parser.parse_args() - ivan_g1(start(8), args.ip) + + dimos = start(8) + deploy(dimos, args.ip) + wait_exit() + dimos.close_all() diff --git a/dimos/robot/unitree_webrtc/modular/ivan_go2.py b/dimos/robot/unitree_webrtc/modular/ivan_go2.py new file mode 100644 index 0000000000..f6f0f83adc --- /dev/null +++ b/dimos/robot/unitree_webrtc/modular/ivan_go2.py @@ -0,0 +1,59 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +from dimos import agents2 +from dimos.core import DimosCluster, start, wait_exit +from dimos.perception.detection import module3D as module3D +from dimos.robot.unitree_webrtc.connection import go2 +from dimos.robot.unitree_webrtc.modular.misc import deploy_foxglove +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.robot.unitree_webrtc.unitree_go2", level=logging.INFO) + + +def deploy(dimos: DimosCluster, ip: str): + connection = go2.deploy(dimos, ip) + deploy_foxglove(dimos) + + detector = module3D.deploy( + dimos, + go2.camera_info, + camera=connection, + lidar=connection, + ) + + agent = agents2.deploy(dimos) + agent.register_skills(detector) + + +if __name__ == "__main__": + import argparse + import os + + from dotenv import load_dotenv + + load_dotenv() + + parser = argparse.ArgumentParser(description="Unitree G1 Humanoid Robot Control") + parser.add_argument("--ip", default=os.getenv("ROBOT_IP"), help="Robot IP address") + + args = parser.parse_args() + + dimos = start(8) + deploy(dimos, args.ip) + wait_exit() + dimos.close_all() diff --git a/dimos/robot/unitree_webrtc/modular/ivan_unitree.py b/dimos/robot/unitree_webrtc/modular/ivan_unitree.py deleted file mode 100644 index 948dccaa16..0000000000 --- a/dimos/robot/unitree_webrtc/modular/ivan_unitree.py +++ /dev/null @@ -1,139 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import time - -from dimos_lcm.foxglove_msgs import SceneUpdate - -from dimos.agents2.spec import Model, Provider -from dimos.core import LCMTransport, start - -# from dimos.msgs.detection2d import Detection2DArray -from dimos.msgs.foxglove_msgs import ImageAnnotations -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.msgs.vision_msgs import Detection2DArray -from dimos.perception.detection.module2D import Detection2DModule -from dimos.perception.detection.module3D import Detection3DModule -from dimos.perception.detection.person_tracker import PersonTracker -from dimos.perception.detection.reid import ReidModule -from dimos.protocol.pubsub import lcm -from dimos.robot.foxglove_bridge import FoxgloveBridge -from dimos.robot.unitree_webrtc.modular import deploy_connection, deploy_navigation -from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.robot.unitree_webrtc.unitree_go2", level=logging.INFO) - - -def detection_unitree(): - dimos = start(8) - connection = deploy_connection(dimos) - - def goto(pose): - print("NAVIGATION REQUESTED:", pose) - return True - - detector = dimos.deploy( - Detection2DModule, - # goto=goto, - camera_info=ConnectionModule._camera_info(), - ) - - detector.image.connect(connection.video) - # detector.pointcloud.connect(mapper.global_map) - # detector.pointcloud.connect(connection.lidar) - - detector.annotations.transport = LCMTransport("/annotations", ImageAnnotations) - detector.detections.transport = LCMTransport("/detections", Detection2DArray) - - # detector.detected_pointcloud_0.transport = LCMTransport("/detected/pointcloud/0", PointCloud2) - # detector.detected_pointcloud_1.transport = LCMTransport("/detected/pointcloud/1", PointCloud2) - # detector.detected_pointcloud_2.transport = LCMTransport("/detected/pointcloud/2", PointCloud2) - - detector.detected_image_0.transport = LCMTransport("/detected/image/0", Image) - detector.detected_image_1.transport = LCMTransport("/detected/image/1", Image) - detector.detected_image_2.transport = LCMTransport("/detected/image/2", Image) - # detector.scene_update.transport = LCMTransport("/scene_update", SceneUpdate) - - # reidModule = dimos.deploy(ReidModule) - - # reidModule.image.connect(connection.video) - # reidModule.detections.connect(detector.detections) - # reidModule.annotations.transport = LCMTransport("/reid/annotations", ImageAnnotations) - - # nav = deploy_navigation(dimos, connection) - - # person_tracker = dimos.deploy(PersonTracker, cameraInfo=ConnectionModule._camera_info()) - # person_tracker.image.connect(connection.video) - # person_tracker.detections.connect(detector.detections) - # person_tracker.target.transport = LCMTransport("/goal_request", PoseStamped) - - reid = dimos.deploy(ReidModule) - - reid.image.connect(connection.video) - reid.detections.connect(detector.detections) - reid.annotations.transport = LCMTransport("/reid/annotations", ImageAnnotations) - - detector.start() - # person_tracker.start() - connection.start() - reid.start() - - from dimos.agents2 import Agent, Output, Reducer, Stream, skill - from dimos.agents2.cli.human import HumanInput - - agent = Agent( - system_prompt="You are a helpful assistant for controlling a Unitree Go2 robot.", - model=Model.GPT_4O, # Could add CLAUDE models to enum - provider=Provider.OPENAI, # Would need ANTHROPIC provider - ) - - human_input = dimos.deploy(HumanInput) - agent.register_skills(human_input) - # agent.register_skills(connection) - agent.register_skills(detector) - - bridge = FoxgloveBridge( - shm_channels=[ - "/image#sensor_msgs.Image", - "/lidar#sensor_msgs.PointCloud2", - ] - ) - # bridge = FoxgloveBridge() - time.sleep(1) - bridge.start() - - # agent.run_implicit_skill("video_stream_tool") - # agent.run_implicit_skill("human") - - # agent.start() - # agent.loop_thread() - - try: - while True: - time.sleep(1) - except KeyboardInterrupt: - connection.stop() - logger.info("Shutting down...") - - -def main(): - lcm.autoconf() - detection_unitree() - - -if __name__ == "__main__": - main() diff --git a/dimos/spec/__init__.py b/dimos/spec/__init__.py index 556faa5561..d7a18b190c 100644 --- a/dimos/spec/__init__.py +++ b/dimos/spec/__init__.py @@ -1 +1,2 @@ -from dimos.spec.perception import Camera, Image, PointCloud +from dimos.spec.control import LocalPlanner +from dimos.spec.perception import Camera, Image, Pointcloud diff --git a/dimos/spec/perception.py b/dimos/spec/perception.py index 3a1ef05686..dba9feb67c 100644 --- a/dimos/spec/perception.py +++ b/dimos/spec/perception.py @@ -14,6 +14,8 @@ from typing import Protocol +from dimos_lcm.sensor_msgs import CameraInfo + from dimos.core import Out from dimos.msgs.sensor_msgs import Image, PointCloud2 @@ -22,7 +24,7 @@ class Image(Protocol): image: Out[Image] -Camera = Image +class Camera(Image): ... class Pointcloud(Protocol): From f89bd3d1c99d776181140077ab8318bdf7bfe894 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 17 Oct 2025 19:19:17 -0700 Subject: [PATCH 13/40] module3d scene update --- dimos/perception/detection/module3D.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/dimos/perception/detection/module3D.py b/dimos/perception/detection/module3D.py index e27cd2a930..45a9baa956 100644 --- a/dimos/perception/detection/module3D.py +++ b/dimos/perception/detection/module3D.py @@ -162,6 +162,8 @@ def _publish_detections(self, detections: ImageDetections3DPC): pointcloud_topic = getattr(self, "detected_pointcloud_" + str(index)) pointcloud_topic.publish(detection.pointcloud) + self.scene_update.publish(detections.to_foxglove_scene_update()) + def deploy( dimos: DimosCluster, From 3b81baed3e680662fa5052e4fed7c2052dc1d2d9 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 17 Oct 2025 19:23:33 -0700 Subject: [PATCH 14/40] moduledb deploy --- dimos/perception/detection/moduleDB.py | 11 ++++++----- dimos/robot/unitree_webrtc/modular/ivan_go2.py | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/dimos/perception/detection/moduleDB.py b/dimos/perception/detection/moduleDB.py index 4f38bfffec..6e79cf87ec 100644 --- a/dimos/perception/detection/moduleDB.py +++ b/dimos/perception/detection/moduleDB.py @@ -158,9 +158,9 @@ class ObjectDBModule(Detection3DModule, TableStr): remembered_locations: Dict[str, PoseStamped] - def __init__(self, goto: Callable[[PoseStamped], Any], *args, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.goto = goto + self.goto = None self.objects = {} self.remembered_locations = {} @@ -314,14 +314,15 @@ def __len__(self): def deploy( dimos: DimosCluster, + camera_info: CameraInfo, lidar: spec.Pointcloud, camera: spec.Camera, - prefix: str = "/objectdb", + prefix: str = "/detectorDB", **kwargs, -) -> ObjectDBModule: +) -> Detection3DModule: from dimos.core import LCMTransport - detector = ObjectDBModule(camera_info=camera.camera_info, **kwargs) + detector = dimos.deploy(ObjectDBModule, camera_info=camera_info, **kwargs) detector.image.connect(camera.image) detector.pointcloud.connect(lidar.pointcloud) diff --git a/dimos/robot/unitree_webrtc/modular/ivan_go2.py b/dimos/robot/unitree_webrtc/modular/ivan_go2.py index f6f0f83adc..81238d4268 100644 --- a/dimos/robot/unitree_webrtc/modular/ivan_go2.py +++ b/dimos/robot/unitree_webrtc/modular/ivan_go2.py @@ -17,7 +17,7 @@ from dimos import agents2 from dimos.core import DimosCluster, start, wait_exit -from dimos.perception.detection import module3D as module3D +from dimos.perception.detection import module3D, moduleDB from dimos.robot.unitree_webrtc.connection import go2 from dimos.robot.unitree_webrtc.modular.misc import deploy_foxglove from dimos.utils.logging_config import setup_logger @@ -29,7 +29,7 @@ def deploy(dimos: DimosCluster, ip: str): connection = go2.deploy(dimos, ip) deploy_foxglove(dimos) - detector = module3D.deploy( + detector = moduleDB.deploy( dimos, go2.camera_info, camera=connection, From 6e1c38b919b473c2e9ef940c89d0b79a1f1c5b63 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 17 Oct 2025 23:50:21 -0700 Subject: [PATCH 15/40] spatial mem, nav, skills --- dimos/agents2/agent.py | 16 +- dimos/agents2/cli/human.py | 3 +- dimos/agents2/skills/navigation.py | 127 ++++++---------- dimos/navigation/rosnav.py | 58 +++++++- dimos/perception/detection/module3D.py | 28 +++- dimos/perception/detection/moduleDB.py | 64 ++++---- dimos/perception/spatial_perception.py | 80 +++++----- dimos/robot/foxglove_bridge.py | 23 ++- dimos/robot/unitree_webrtc/modular/ivan_g1.py | 55 +++++-- .../robot/unitree_webrtc/modular/ivan_go2.py | 4 +- .../unitree_webrtc/modular/ivan_unitree.py | 139 ++++++++++++++++++ dimos/robot/unitree_webrtc/modular/misc.py | 33 ----- 12 files changed, 413 insertions(+), 217 deletions(-) create mode 100644 dimos/robot/unitree_webrtc/modular/ivan_unitree.py delete mode 100644 dimos/robot/unitree_webrtc/modular/misc.py diff --git a/dimos/agents2/agent.py b/dimos/agents2/agent.py index 51952d4b4d..430873c396 100644 --- a/dimos/agents2/agent.py +++ b/dimos/agents2/agent.py @@ -30,7 +30,12 @@ from dimos.agents2.spec import AgentSpec, Model, Provider from dimos.core import DimosCluster, rpc -from dimos.protocol.skill.coordinator import SkillCoordinator, SkillState, SkillStateDict +from dimos.protocol.skill.coordinator import ( + SkillContainer, + SkillCoordinator, + SkillState, + SkillStateDict, +) from dimos.protocol.skill.type import Output from dimos.utils.logging_config import setup_logger @@ -352,6 +357,7 @@ def deploy( system_prompt="You are a helpful assistant for controlling a Unitree Go2 robot.", model: Model = Model.GPT_4O, provider: Provider = Provider.OPENAI, + skill_containers: Optional[List[SkillContainer]] = [], ) -> Agent: from dimos.agents2.cli.human import HumanInput @@ -363,8 +369,16 @@ def deploy( ) human_input = dimos.deploy(HumanInput) + human_input.start() + agent.register_skills(human_input) + for skill_container in skill_containers: + print("Registering skill container:", skill_container) + agent.register_skills(skill_container) + + agent.run_implicit_skill("human") agent.start() + agent.loop_thread() return agent diff --git a/dimos/agents2/cli/human.py b/dimos/agents2/cli/human.py index 5a20abb388..9da594c085 100644 --- a/dimos/agents2/cli/human.py +++ b/dimos/agents2/cli/human.py @@ -14,9 +14,10 @@ import queue +from reactivex.disposable import Disposable + from dimos.agents2 import Output, Reducer, Stream, skill from dimos.core import Module, pLCMTransport, rpc -from reactivex.disposable import Disposable class HumanInput(Module): diff --git a/dimos/agents2/skills/navigation.py b/dimos/agents2/skills/navigation.py index 18558515e6..ae57995b18 100644 --- a/dimos/agents2/skills/navigation.py +++ b/dimos/agents2/skills/navigation.py @@ -18,54 +18,50 @@ from reactivex import Observable from reactivex.disposable import CompositeDisposable, Disposable +from dimos.core import Module from dimos.core.resource import Resource from dimos.models.qwen.video_query import BBox from dimos.models.vl.qwen import QwenVlModel from dimos.msgs.geometry_msgs import PoseStamped from dimos.msgs.geometry_msgs.Vector3 import make_vector3 from dimos.msgs.sensor_msgs import Image +from dimos.navigation.bt_navigator.navigator import NavigatorState from dimos.navigation.visual.query import get_object_bbox_from_image from dimos.protocol.skill.skill import SkillContainer, skill from dimos.robot.robot import UnitreeRobot from dimos.types.robot_location import RobotLocation from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import euler_to_quaternion, quaternion_to_euler -from dimos.navigation.bt_navigator.navigator import NavigatorState logger = setup_logger(__file__) -class NavigationSkillContainer(SkillContainer, Resource): - _robot: UnitreeRobot +class NavigationSkillContainer(Module): _disposables: CompositeDisposable _latest_image: Optional[Image] _video_stream: Observable[Image] - _started: bool + _started: bool = False + + def __init__(self, spatial_memory, nav, detection_module): + self.nav = nav + self.spatial_memory = spatial_memory + self.detection_module = detection_module - def __init__(self, robot: UnitreeRobot, video_stream: Observable[Image]): super().__init__() - self._robot = robot - self._disposables = CompositeDisposable() - self._latest_image = None - self._video_stream = video_stream self._similarity_threshold = 0.23 - self._started = False self._vl_model = QwenVlModel() def start(self) -> None: - unsub = self._video_stream.subscribe(self._on_video) - self._disposables.add(Disposable(unsub) if callable(unsub) else unsub) + # unsub = self._video_stream.subscribe(self._on_video) + # self._disposables.add(Disposable(unsub) if callable(unsub) else unsub) self._started = True def stop(self) -> None: self._disposables.dispose() super().stop() - def _on_video(self, image: Image) -> None: - self._latest_image = image - @skill() - def tag_location_in_spatial_memory(self, location_name: str) -> str: + def tag_location(self, location_name: str) -> str: """Tag this location in the spatial memory with a name. This associates the current location with the given name in the spatial memory, allowing you to navigate back to it. @@ -79,10 +75,12 @@ def tag_location_in_spatial_memory(self, location_name: str) -> str: if not self._started: raise ValueError(f"{self} has not been started.") + tf = self.tf.get("map", "base_link", time_tolerance=2.0) + if not tf: + return "Could not get the robot's current transform." - pose_data = self._robot.get_odom() - position = pose_data.position - rotation = quaternion_to_euler(pose_data.orientation) + position = tf.translation + rotation = tf.rotation.to_euler() location = RobotLocation( name=location_name, @@ -90,11 +88,18 @@ def tag_location_in_spatial_memory(self, location_name: str) -> str: rotation=(rotation.x, rotation.y, rotation.z), ) - if not self._robot.spatial_memory.tag_location(location): + if not self.spatial_memory.tag_location(location): return f"Failed to store '{location_name}' in the spatial memory" logger.info(f"Tagged {location}") - return f"The current location has been tagged as '{location_name}'." + return f"Tagged '{location_name}': ({position.x},{position.y})." + + def _navigate_to_object(self, query: str) -> Optional[str]: + position = self.detection_module.nav_vlm(query) + if not position: + return None + self.nav.navigate_to(position) + return f"Arrived to object matching '{query}' in view." @skill() def navigate_with_text(self, query: str) -> str: @@ -111,7 +116,6 @@ def navigate_with_text(self, query: str) -> str: if not self._started: raise ValueError(f"{self} has not been started.") - success_msg = self._navigate_by_tagged_location(query) if success_msg: return success_msg @@ -131,72 +135,25 @@ def navigate_with_text(self, query: str) -> str: return f"No tagged location called '{query}'. No object in view matching '{query}'. No matching location found in semantic map for '{query}'." def _navigate_by_tagged_location(self, query: str) -> Optional[str]: - robot_location = self._robot.spatial_memory.query_tagged_location(query) + robot_location = self.spatial_memory.query_tagged_location(query) if not robot_location: return None + print("Found tagged location:", robot_location) goal_pose = PoseStamped( position=make_vector3(*robot_location.position), orientation=euler_to_quaternion(make_vector3(*robot_location.rotation)), - frame_id="world", - ) - - result = self._robot.navigate_to(goal_pose, blocking=True) - if not result: - return None - - return ( - f"Successfuly arrived at location tagged '{robot_location.name}' from query '{query}'." + frame_id="map", ) - def _navigate_to_object(self, query: str) -> Optional[str]: - try: - bbox = self._get_bbox_for_current_frame(query) - except Exception: - logger.error(f"Failed to get bbox for {query}", exc_info=True) - return None + print("Goal pose for tagged location nav:", goal_pose) - if bbox is None: + result = self.nav.navigate_to(goal_pose) + if not result: return None - logger.info(f"Found {query} at {bbox}") - - # Start tracking - BBoxNavigationModule automatically generates goals - self._robot.object_tracker.track(bbox) - - start_time = time.time() - timeout = 30.0 - goal_set = False - - while time.time() - start_time < timeout: - # Check if navigator finished - if self._robot.navigator.get_state() == NavigatorState.IDLE and goal_set: - logger.info("Waiting for goal result") - time.sleep(1.0) - if not self._robot.navigator.is_goal_reached(): - logger.info(f"Goal cancelled, tracking '{query}' failed") - self._robot.object_tracker.stop_track() - return None - else: - logger.info(f"Reached '{query}'") - self._robot.object_tracker.stop_track() - return f"Successfully arrived at '{query}'" - - # If goal set and tracking lost, just continue (tracker will resume or timeout) - if goal_set and not self._robot.object_tracker.is_tracking(): - continue - - # BBoxNavigationModule automatically sends goals when tracker publishes - # Just check if we have any detections to mark goal_set - if self._robot.object_tracker.is_tracking(): - goal_set = True - - time.sleep(0.25) - - logger.warning(f"Navigation to '{query}' timed out after {timeout}s") - self._robot.object_tracker.stop_track() - return None + return f"Arrived to '{robot_location.name}' from query '{query}'." def _get_bbox_for_current_frame(self, query: str) -> Optional[BBox]: if self._latest_image is None: @@ -205,7 +162,7 @@ def _get_bbox_for_current_frame(self, query: str) -> Optional[BBox]: return get_object_bbox_from_image(self._vl_model, self._latest_image, query) def _navigate_using_semantic_map(self, query: str) -> str: - results = self._robot.spatial_memory.query_by_text(query) + results = self.spatial_memory.query_by_text(query) if not results: return f"No matching location found in semantic map for '{query}'" @@ -214,33 +171,34 @@ def _navigate_using_semantic_map(self, query: str) -> str: goal_pose = self._get_goal_pose_from_result(best_match) + print("Goal pose for semantic nav:", goal_pose) if not goal_pose: return f"Found a result for '{query}' but it didn't have a valid position." - result = self._robot.navigate_to(goal_pose, blocking=True) + result = self.nav.navigate_to(goal_pose) if not result: return f"Failed to navigate for '{query}'" return f"Successfuly arrived at '{query}'" - @skill() + # @skill() def follow_human(self, person: str) -> str: """Follow a specific person""" return "Not implemented yet." - @skill() + # @skill() def stop_movement(self) -> str: """Immediatly stop moving.""" if not self._started: raise ValueError(f"{self} has not been started.") - self._robot.stop_exploration() + # self._robot.stop_exploration() return "Stopped" - @skill() + # @skill() def start_exploration(self, timeout: float = 240.0) -> str: """A skill that performs autonomous frontier exploration. @@ -286,8 +244,9 @@ def _get_goal_pose_from_result(self, result: dict[str, Any]) -> Optional[PoseSta metadata = result.get("metadata") if not metadata: return None - + print(metadata) first = metadata[0] + print(first) pos_x = first.get("pos_x", 0) pos_y = first.get("pos_y", 0) theta = first.get("rot_z", 0) @@ -295,5 +254,5 @@ def _get_goal_pose_from_result(self, result: dict[str, Any]) -> Optional[PoseSta return PoseStamped( position=make_vector3(pos_x, pos_y, 0), orientation=euler_to_quaternion(make_vector3(0, 0, theta)), - frame_id="world", + frame_id="map", ) diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index 5316cec0dc..487ceff89f 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -37,6 +37,7 @@ from std_msgs.msg import Int8 as ROSInt8 from tf2_msgs.msg import TFMessage as ROSTFMessage +from dimos.agents2 import Output, Reducer, Stream, skill from dimos.core import DimosCluster, In, LCMTransport, Module, Out, pSHMTransport, rpc from dimos.msgs.geometry_msgs import ( PoseStamped, @@ -68,6 +69,8 @@ class ROSNav(Module): _local_pointcloud: Optional[ROSPointCloud2] = None _global_pointcloud: Optional[ROSPointCloud2] = None + _current_position_running: bool = False + def __init__(self, sensor_to_base_link_transform=None, *args, **kwargs): super().__init__(*args, **kwargs) @@ -247,6 +250,59 @@ def _set_autonomy_mode(self): self.joy_pub.publish(joy_msg) logger.info("Setting autonomy mode via Joy message") + @skill(stream=Stream.passive, reducer=Reducer.latest) + def current_position(self): + """passively stream the current position of the robot every second""" + if self._current_position_running: + return "already running" + while True: + self._current_position_running = True + time.sleep(1.0) + tf = self.tf.get("map", "base_link") + if not tf: + continue + yield f"current position {tf.translation.x}, {tf.translation.y}" + + @skill(stream=Stream.call_agent, reducer=Reducer.string) + def goto(self, x: float, y: float): + """ + move the robot in relative coordinates + x is forward, y is left + + goto(1, 0) will move the robot forward by 1 meter + """ + pose_to = PoseStamped( + position=Vector3(x, y, 0), + orientation=Quaternion(0.0, 0.0, 0.0, 0.0), + frame_id="base_link", + ts=time.time(), + ) + + yield "moving, please wait..." + self.navigate_to(pose_to) + yield "arrived" + + @skill(stream=Stream.call_agent, reducer=Reducer.string) + def goto_global(self, x: float, y: float) -> bool: + """ + go to coordinates x,y in the map frame + 0,0 is your starting position + """ + target = PoseStamped( + ts=time.time(), + frame_id="map", + position=Vector3(x, y, 0.0), + orientation=Quaternion(0.0, 0.0, 0.0, 0.0), + ) + + pos = self.tf.get("base_link", "map").translation + + yield f"moving from {pos.x:.2f}, {pos.y:.2f} to {x:.2f}, {y:.2f}, please wait..." + + self.navigate_to(target) + + yield "arrived to {x:.2f}, {y:.2f}" + @rpc def navigate_to(self, pose: PoseStamped, timeout: float = 60.0) -> bool: """ @@ -260,7 +316,7 @@ def navigate_to(self, pose: PoseStamped, timeout: float = 60.0) -> bool: True if navigation was successful """ logger.info( - f"Navigating to goal: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f})" + f"Navigating to goal: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f} @ {pose.frame_id})" ) self.goal_reach = None diff --git a/dimos/perception/detection/module3D.py b/dimos/perception/detection/module3D.py index 45a9baa956..56ca66f940 100644 --- a/dimos/perception/detection/module3D.py +++ b/dimos/perception/detection/module3D.py @@ -23,7 +23,7 @@ from dimos import spec from dimos.agents2 import skill from dimos.core import DimosCluster, In, Out, rpc -from dimos.msgs.geometry_msgs import Transform, Vector3 +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 from dimos.msgs.vision_msgs import Detection2DArray from dimos.perception.detection.module2D import Config as Module2DConfig @@ -109,8 +109,9 @@ def pixel_to_3d( # Camera optical frame: X right, Y down, Z forward return Vector3(x_norm * assumed_depth, y_norm * assumed_depth, assumed_depth) - @skill # type: ignore[arg-type] - def ask_vlm(self, question: str) -> str | ImageDetections3DPC: + # @skill # type: ignore[arg-type] + @rpc + def nav_vlm(self, question: str) -> str: """ query visual model about the view in front of the camera you can ask to mark objects like: @@ -128,9 +129,28 @@ def ask_vlm(self, question: str) -> str | ImageDetections3DPC: return "No detections" detections: ImageDetections2D = result + + print(detections) + if not len(detections): + print("No 2d detections") + return None + pc = self.pointcloud.get_next() transform = self.tf.get("camera_optical", pc.frame_id, detections.image.ts, 5.0) - return self.process_frame(detections, pc, transform) + + detections3d = self.process_frame(detections, pc, transform) + + if len(detections3d): + return detections3d[0].pose + print("No 3d detections, projecting 2d") + + center = detections[0].get_bbox_center() + return PoseStamped( + ts=detections.image.ts, + frame_id="world", + position=self.pixel_to_3d(center, assumed_depth=1.5), + orientation=Quaternion(0.0, 0.0, 0.0, 0.0), + ) @rpc def start(self): diff --git a/dimos/perception/detection/moduleDB.py b/dimos/perception/detection/moduleDB.py index 6e79cf87ec..0428b79275 100644 --- a/dimos/perception/detection/moduleDB.py +++ b/dimos/perception/detection/moduleDB.py @@ -214,35 +214,36 @@ def agent_encode(self) -> str: return "No objects detected yet." return "\n".join(ret) - def vlm_query(self, description: str) -> Optional[Object3D]: # type: ignore[override] - imageDetections2D = super().ask_vlm(description) - print("VLM query found", imageDetections2D, "detections") - time.sleep(3) - - if not imageDetections2D.detections: - return None - - ret = [] - for obj in self.objects.values(): - if obj.ts != imageDetections2D.ts: - print( - "Skipping", - obj.track_id, - "ts", - obj.ts, - "!=", - imageDetections2D.ts, - ) - continue - if obj.class_id != -100: - continue - if obj.name != imageDetections2D.detections[0].name: - print("Skipping", obj.name, "!=", imageDetections2D.detections[0].name) - continue - ret.append(obj) - ret.sort(key=lambda x: x.ts) - - return ret[0] if ret else None + # @rpc + # def vlm_query(self, description: str) -> Optional[Object3D]: # type: ignore[override] + # imageDetections2D = super().ask_vlm(description) + # print("VLM query found", imageDetections2D, "detections") + # time.sleep(3) + + # if not imageDetections2D.detections: + # return None + + # ret = [] + # for obj in self.objects.values(): + # if obj.ts != imageDetections2D.ts: + # print( + # "Skipping", + # obj.track_id, + # "ts", + # obj.ts, + # "!=", + # imageDetections2D.ts, + # ) + # continue + # if obj.class_id != -100: + # continue + # if obj.name != imageDetections2D.detections[0].name: + # print("Skipping", obj.name, "!=", imageDetections2D.detections[0].name) + # continue + # ret.append(obj) + # ret.sort(key=lambda x: x.ts) + + # return ret[0] if ret else None def lookup(self, label: str) -> List[Detection3DPC]: """Look up a detection by label.""" @@ -254,8 +255,9 @@ def start(self): def update_objects(imageDetections: ImageDetections3DPC): for detection in imageDetections.detections: - # print(detection) - return self.add_detection(detection) + if detection.name == "person": + continue + self.add_detection(detection) def scene_thread(): while True: diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py index 7d93e2e174..5d374a6088 100644 --- a/dimos/perception/spatial_perception.py +++ b/dimos/perception/spatial_perception.py @@ -16,27 +16,28 @@ Spatial Memory module for creating a semantic map of the environment. """ -import uuid -import time import os -from typing import Dict, List, Optional, Any +import time +import uuid +from datetime import datetime +from typing import Any, Dict, List, Optional -import numpy as np import cv2 +import numpy as np from reactivex import Observable, disposable, interval from reactivex import operators as ops -from datetime import datetime from reactivex.disposable import Disposable -from dimos.core import In, Module, rpc -from dimos.msgs.sensor_msgs import Image -from dimos.msgs.geometry_msgs import Vector3, Pose, PoseStamped -from dimos.utils.logging_config import setup_logger -from dimos.agents.memory.spatial_vector_db import SpatialVectorDB +from dimos import spec from dimos.agents.memory.image_embedding import ImageEmbeddingProvider +from dimos.agents.memory.spatial_vector_db import SpatialVectorDB from dimos.agents.memory.visual_memory import VisualMemory -from dimos.types.vector import Vector +from dimos.core import DimosCluster, In, Module, rpc +from dimos.msgs.geometry_msgs import Pose, PoseStamped, Vector3 +from dimos.msgs.sensor_msgs import Image from dimos.types.robot_location import RobotLocation +from dimos.types.vector import Vector +from dimos.utils.logging_config import setup_logger logger = setup_logger(__file__) @@ -52,8 +53,7 @@ class SpatialMemory(Module): """ # LCM inputs - color_image: In[Image] = None - odom: In[PoseStamped] = None + image: In[Image] = None def __init__( self, @@ -120,8 +120,8 @@ def __init__( except Exception as e: logger.error(f"Error clearing ChromaDB directory: {e}") - from chromadb.config import Settings import chromadb + from chromadb.config import Settings self._chroma_client = chromadb.PersistentClient( path=db_path, settings=Settings(anonymized_telemetry=False) @@ -169,7 +169,6 @@ def __init__( # Track latest data for processing self._latest_video_frame: Optional[np.ndarray] = None - self._latest_odom: Optional[PoseStamped] = None self._process_interval = 1 logger.info(f"SpatialMemory initialized with model {embedding_model}") @@ -188,13 +187,7 @@ def set_video(image_msg: Image): else: logger.warning("Received image message without data attribute") - def set_odom(odom_msg: PoseStamped): - self._latest_odom = odom_msg - - unsub = self.color_image.subscribe(set_video) - self._disposables.add(Disposable(unsub)) - - unsub = self.odom.subscribe(set_odom) + unsub = self.image.subscribe(set_video) self._disposables.add(Disposable(unsub)) # Start periodic processing using interval @@ -215,17 +208,13 @@ def stop(self): def _process_frame(self): """Process the latest frame with pose data if available.""" - if self._latest_video_frame is None or self._latest_odom is None: + tf = self.tf.get("map", "base_link") + if self._latest_video_frame is None or tf is None: return - # Extract position and rotation from odometry - position = self._latest_odom.position - orientation = self._latest_odom.orientation - + # print("Processing frame for spatial memory...", tf) # Create Pose object with position and orientation - current_pose = Pose( - position=Vector3(position.x, position.y, position.z), orientation=orientation - ) + current_pose = tf.to_pose() # Process the frame directly try: @@ -261,9 +250,10 @@ def _process_frame(self): frame_embedding = self.embedding_provider.get_embedding(self._latest_video_frame) frame_id = f"frame_{datetime.now().strftime('%Y%m%d_%H%M%S')}_{uuid.uuid4().hex[:8]}" - # Get euler angles from quaternion orientation for metadata - euler = orientation.to_euler() + euler = tf.rotation.to_euler() + + print(f"Storing frame {frame_id} at position {current_pose}...") # Create metadata dictionary with primitive types only metadata = { @@ -572,24 +562,16 @@ def add_named_location( Returns: True if successfully added, False otherwise """ - # Use current position/rotation if not provided - if position is None and self._latest_odom is not None: - pos = self._latest_odom.position - position = [pos.x, pos.y, pos.z] - - if rotation is None and self._latest_odom is not None: - euler = self._latest_odom.orientation.to_euler() - rotation = [euler.x, euler.y, euler.z] - - if position is None: + tf = self.tf.get("map", "base_link") + if not tf: logger.error("No position available for robot location") return False # Create RobotLocation object location = RobotLocation( name=name, - position=Vector(position), - rotation=Vector(rotation) if rotation else Vector([0, 0, 0]), + position=tf.translation, + rotation=tf.rotation.to_euler(), description=description or f"Location: {name}", timestamp=time.time(), ) @@ -649,3 +631,13 @@ def query_tagged_location(self, query: str) -> Optional[RobotLocation]: if semantic_distance < 0.3: return location return None + + +def deploy( + dimos: DimosCluster, + camera: spec.Camera, +): + spatial_memory = dimos.deploy(SpatialMemory, db_path="/tmp/spatial_memory_db") + spatial_memory.image.connect(camera.image) + spatial_memory.start() + return spatial_memory diff --git a/dimos/robot/foxglove_bridge.py b/dimos/robot/foxglove_bridge.py index 18211f65c2..91102a1ae3 100644 --- a/dimos/robot/foxglove_bridge.py +++ b/dimos/robot/foxglove_bridge.py @@ -13,12 +13,17 @@ # limitations under the License. import asyncio +import logging import threading +from typing import List, Optional # this is missing, I'm just trying to import lcm_foxglove_bridge.py from dimos_lcm from dimos_lcm.foxglove_bridge import FoxgloveBridge as LCMFoxgloveBridge -from dimos.core import Module, rpc +from dimos.core import DimosCluster, Module, rpc + +logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) +logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) class FoxgloveBridge(Module): @@ -58,3 +63,19 @@ def stop(self): self._thread.join(timeout=2) super().stop() + + +def deploy( + dimos: DimosCluster, + shm_channels: Optional[List[str]] = [ + "/image#sensor_msgs.Image", + "/lidar#sensor_msgs.PointCloud2", + "/map#sensor_msgs.PointCloud2", + ], +) -> FoxgloveBridge: + foxglove_bridge = dimos.deploy( + FoxgloveBridge, + shm_channels=shm_channels, + ) + foxglove_bridge.start() + return foxglove_bridge diff --git a/dimos/robot/unitree_webrtc/modular/ivan_g1.py b/dimos/robot/unitree_webrtc/modular/ivan_g1.py index 0ff922388a..274e5c34a0 100644 --- a/dimos/robot/unitree_webrtc/modular/ivan_g1.py +++ b/dimos/robot/unitree_webrtc/modular/ivan_g1.py @@ -12,23 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time - +from dimos import agents2 +from dimos.agents2.skills.navigation import NavigationSkillContainer from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE -from dimos.core import DimosCluster, LCMTransport, pSHMTransport, start, wait_exit +from dimos.core import DimosCluster, LCMTransport, Module, pSHMTransport, start, wait_exit from dimos.hardware.camera import zed from dimos.hardware.camera.module import CameraModule from dimos.hardware.camera.webcam import Webcam from dimos.msgs.geometry_msgs import ( - PoseStamped, Quaternion, Transform, Vector3, ) from dimos.msgs.sensor_msgs import CameraInfo from dimos.navigation import rosnav +from dimos.perception import spatial_perception +from dimos.perception.detection import module3D, moduleDB +from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector +from dimos.protocol.skill.skill import SkillContainer, skill +from dimos.robot import foxglove_bridge from dimos.robot.unitree_webrtc.connection import g1 -from dimos.robot.unitree_webrtc.modular.misc import deploy_foxglove from dimos.utils.logging_config import setup_logger logger = setup_logger(__name__) @@ -60,25 +63,47 @@ def deploy_monozed(dimos) -> CameraModule: def deploy(dimos: DimosCluster, ip: str) -> None: nav = rosnav.deploy(dimos) connection = g1.deploy(dimos, ip, nav) - zed = deploy_monozed(dimos) + zedcam = deploy_monozed(dimos) + spatialmem = spatial_perception.deploy(dimos, zedcam) + + person_detector = module3D.deploy( + dimos, + zed.CameraInfo.SingleWebcam, + camera=zedcam, + lidar=nav, + detector=YoloPersonDetector, + ) - deploy_foxglove(dimos) + detector = moduleDB.deploy( + dimos, + zed.CameraInfo.SingleWebcam, + camera=zedcam, + lidar=nav, + ) - time.sleep(5) + foxglove_bridge.deploy(dimos) - test_pose = PoseStamped( - ts=time.time(), - frame_id="map", - position=Vector3(0.0, 0.0, 0.0), - orientation=Quaternion(0.0, 0.0, 0.0, 0.0), + navskills = dimos.deploy( + NavigationSkillContainer, + spatialmem, + nav, + detector, ) + navskills.start() - nav.navigate_to(test_pose) + agent = agents2.deploy( + dimos, + "You are controling a humanoid robot", + skill_containers=[connection, nav, zedcam, spatialmem, navskills], + ) + # asofkasfkaslfks + agent.run_implicit_skill("current_position") + # agent.run_implicit_skill("video_stream") return { "nav": nav, "connection": connection, - "zed": zed, + "zed": zedcam, } diff --git a/dimos/robot/unitree_webrtc/modular/ivan_go2.py b/dimos/robot/unitree_webrtc/modular/ivan_go2.py index 81238d4268..16caf06e60 100644 --- a/dimos/robot/unitree_webrtc/modular/ivan_go2.py +++ b/dimos/robot/unitree_webrtc/modular/ivan_go2.py @@ -18,8 +18,8 @@ from dimos import agents2 from dimos.core import DimosCluster, start, wait_exit from dimos.perception.detection import module3D, moduleDB +from dimos.robot import foxglove_bridge from dimos.robot.unitree_webrtc.connection import go2 -from dimos.robot.unitree_webrtc.modular.misc import deploy_foxglove from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.robot.unitree_webrtc.unitree_go2", level=logging.INFO) @@ -27,7 +27,7 @@ def deploy(dimos: DimosCluster, ip: str): connection = go2.deploy(dimos, ip) - deploy_foxglove(dimos) + foxglove_bridge.deploy(dimos) detector = moduleDB.deploy( dimos, diff --git a/dimos/robot/unitree_webrtc/modular/ivan_unitree.py b/dimos/robot/unitree_webrtc/modular/ivan_unitree.py new file mode 100644 index 0000000000..948dccaa16 --- /dev/null +++ b/dimos/robot/unitree_webrtc/modular/ivan_unitree.py @@ -0,0 +1,139 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +from dimos_lcm.foxglove_msgs import SceneUpdate + +from dimos.agents2.spec import Model, Provider +from dimos.core import LCMTransport, start + +# from dimos.msgs.detection2d import Detection2DArray +from dimos.msgs.foxglove_msgs import ImageAnnotations +from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.perception.detection.module2D import Detection2DModule +from dimos.perception.detection.module3D import Detection3DModule +from dimos.perception.detection.person_tracker import PersonTracker +from dimos.perception.detection.reid import ReidModule +from dimos.protocol.pubsub import lcm +from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.robot.unitree_webrtc.modular import deploy_connection, deploy_navigation +from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.robot.unitree_webrtc.unitree_go2", level=logging.INFO) + + +def detection_unitree(): + dimos = start(8) + connection = deploy_connection(dimos) + + def goto(pose): + print("NAVIGATION REQUESTED:", pose) + return True + + detector = dimos.deploy( + Detection2DModule, + # goto=goto, + camera_info=ConnectionModule._camera_info(), + ) + + detector.image.connect(connection.video) + # detector.pointcloud.connect(mapper.global_map) + # detector.pointcloud.connect(connection.lidar) + + detector.annotations.transport = LCMTransport("/annotations", ImageAnnotations) + detector.detections.transport = LCMTransport("/detections", Detection2DArray) + + # detector.detected_pointcloud_0.transport = LCMTransport("/detected/pointcloud/0", PointCloud2) + # detector.detected_pointcloud_1.transport = LCMTransport("/detected/pointcloud/1", PointCloud2) + # detector.detected_pointcloud_2.transport = LCMTransport("/detected/pointcloud/2", PointCloud2) + + detector.detected_image_0.transport = LCMTransport("/detected/image/0", Image) + detector.detected_image_1.transport = LCMTransport("/detected/image/1", Image) + detector.detected_image_2.transport = LCMTransport("/detected/image/2", Image) + # detector.scene_update.transport = LCMTransport("/scene_update", SceneUpdate) + + # reidModule = dimos.deploy(ReidModule) + + # reidModule.image.connect(connection.video) + # reidModule.detections.connect(detector.detections) + # reidModule.annotations.transport = LCMTransport("/reid/annotations", ImageAnnotations) + + # nav = deploy_navigation(dimos, connection) + + # person_tracker = dimos.deploy(PersonTracker, cameraInfo=ConnectionModule._camera_info()) + # person_tracker.image.connect(connection.video) + # person_tracker.detections.connect(detector.detections) + # person_tracker.target.transport = LCMTransport("/goal_request", PoseStamped) + + reid = dimos.deploy(ReidModule) + + reid.image.connect(connection.video) + reid.detections.connect(detector.detections) + reid.annotations.transport = LCMTransport("/reid/annotations", ImageAnnotations) + + detector.start() + # person_tracker.start() + connection.start() + reid.start() + + from dimos.agents2 import Agent, Output, Reducer, Stream, skill + from dimos.agents2.cli.human import HumanInput + + agent = Agent( + system_prompt="You are a helpful assistant for controlling a Unitree Go2 robot.", + model=Model.GPT_4O, # Could add CLAUDE models to enum + provider=Provider.OPENAI, # Would need ANTHROPIC provider + ) + + human_input = dimos.deploy(HumanInput) + agent.register_skills(human_input) + # agent.register_skills(connection) + agent.register_skills(detector) + + bridge = FoxgloveBridge( + shm_channels=[ + "/image#sensor_msgs.Image", + "/lidar#sensor_msgs.PointCloud2", + ] + ) + # bridge = FoxgloveBridge() + time.sleep(1) + bridge.start() + + # agent.run_implicit_skill("video_stream_tool") + # agent.run_implicit_skill("human") + + # agent.start() + # agent.loop_thread() + + try: + while True: + time.sleep(1) + except KeyboardInterrupt: + connection.stop() + logger.info("Shutting down...") + + +def main(): + lcm.autoconf() + detection_unitree() + + +if __name__ == "__main__": + main() diff --git a/dimos/robot/unitree_webrtc/modular/misc.py b/dimos/robot/unitree_webrtc/modular/misc.py deleted file mode 100644 index 7df99237c7..0000000000 --- a/dimos/robot/unitree_webrtc/modular/misc.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from dimos.core import DimosCluster -from dimos.robot.foxglove_bridge import FoxgloveBridge - -logging.getLogger("FoxgloveServer").setLevel(logging.ERROR) - - -def deploy_foxglove(dimos: DimosCluster) -> FoxgloveBridge: - foxglove_bridge = dimos.deploy( - FoxgloveBridge, - shm_channels=[ - "/image#sensor_msgs.Image", - "/lidar#sensor_msgs.PointCloud2", - "/map#sensor_msgs.PointCloud2", - ], - ) - foxglove_bridge.start() - return foxglove_bridge From 225c50432b690527ad406b1521af9b7baa05c4cc Mon Sep 17 00:00:00 2001 From: lesh Date: Sat, 18 Oct 2025 01:56:45 -0700 Subject: [PATCH 16/40] wrap --- dimos/agents2/agent.py | 4 ++-- dimos/agents2/skills/navigation.py | 1 + dimos/perception/detection/module3D.py | 7 +++++-- dimos/perception/detection/moduleDB.py | 6 +++--- dimos/robot/unitree_webrtc/modular/ivan_g1.py | 15 +++++++-------- 5 files changed, 18 insertions(+), 15 deletions(-) diff --git a/dimos/agents2/agent.py b/dimos/agents2/agent.py index 430873c396..6b448567aa 100644 --- a/dimos/agents2/agent.py +++ b/dimos/agents2/agent.py @@ -285,8 +285,8 @@ def _get_state() -> str: if msg.tool_calls: self.execute_tool_calls(msg.tool_calls) - print(self) - print(self.coordinator) + # print(self) + # print(self.coordinator) self._write_debug_history_file() diff --git a/dimos/agents2/skills/navigation.py b/dimos/agents2/skills/navigation.py index ae57995b18..938a8b2684 100644 --- a/dimos/agents2/skills/navigation.py +++ b/dimos/agents2/skills/navigation.py @@ -96,6 +96,7 @@ def tag_location(self, location_name: str) -> str: def _navigate_to_object(self, query: str) -> Optional[str]: position = self.detection_module.nav_vlm(query) + print("Object position from VLM:", position) if not position: return None self.nav.navigate_to(position) diff --git a/dimos/perception/detection/module3D.py b/dimos/perception/detection/module3D.py index 56ca66f940..18d99396c0 100644 --- a/dimos/perception/detection/module3D.py +++ b/dimos/perception/detection/module3D.py @@ -123,10 +123,13 @@ def nav_vlm(self, question: str) -> str: from dimos.models.vl.qwen import QwenVlModel model = QwenVlModel() - result = model.query(self.image.get_next(), question) + image = self.image.get_next() + result = model.query_detections(image, question) + + print("VLM result:", result, "for", image, "and question", question) if isinstance(result, str) or not result or not len(result): - return "No detections" + return None detections: ImageDetections2D = result diff --git a/dimos/perception/detection/moduleDB.py b/dimos/perception/detection/moduleDB.py index 0428b79275..ff3fecd279 100644 --- a/dimos/perception/detection/moduleDB.py +++ b/dimos/perception/detection/moduleDB.py @@ -255,7 +255,7 @@ def start(self): def update_objects(imageDetections: ImageDetections3DPC): for detection in imageDetections.detections: - if detection.name == "person": + if detection.class_id == 1: continue self.add_detection(detection) @@ -293,8 +293,8 @@ def to_foxglove_scene_update(self) -> "SceneUpdate": for obj in copy(self.objects).values(): # we need at least 3 detectieons to consider it a valid object # for this to be serious we need a ratio of detections within the window of observations - # if obj.class_id != -100 and obj.detections < 2: - # continue + if obj.class_id != -100 and obj.detections < 4: + continue # print( # f"Object {obj.track_id}: {len(obj.detections)} detections, confidence {obj.confidence}" diff --git a/dimos/robot/unitree_webrtc/modular/ivan_g1.py b/dimos/robot/unitree_webrtc/modular/ivan_g1.py index 274e5c34a0..7e3fe98e05 100644 --- a/dimos/robot/unitree_webrtc/modular/ivan_g1.py +++ b/dimos/robot/unitree_webrtc/modular/ivan_g1.py @@ -66,13 +66,13 @@ def deploy(dimos: DimosCluster, ip: str) -> None: zedcam = deploy_monozed(dimos) spatialmem = spatial_perception.deploy(dimos, zedcam) - person_detector = module3D.deploy( - dimos, - zed.CameraInfo.SingleWebcam, - camera=zedcam, - lidar=nav, - detector=YoloPersonDetector, - ) + # person_detector = module3D.deploy( + # dimos, + # zed.CameraInfo.SingleWebcam, + # camera=zedcam, + # lidar=nav, + # detector=YoloPersonDetector, + # ) detector = moduleDB.deploy( dimos, @@ -96,7 +96,6 @@ def deploy(dimos: DimosCluster, ip: str) -> None: "You are controling a humanoid robot", skill_containers=[connection, nav, zedcam, spatialmem, navskills], ) - # asofkasfkaslfks agent.run_implicit_skill("current_position") # agent.run_implicit_skill("video_stream") From 812318b05379004d6c280adab84e570c030f1391 Mon Sep 17 00:00:00 2001 From: lesh Date: Sat, 18 Oct 2025 13:44:24 -0700 Subject: [PATCH 17/40] bugfix --- dimos/models/vl/test_models.py | 1 - dimos/robot/unitree_webrtc/connection/g1.py | 11 ++++------- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/dimos/models/vl/test_models.py b/dimos/models/vl/test_models.py index 3871626ae1..adc49798e9 100644 --- a/dimos/models/vl/test_models.py +++ b/dimos/models/vl/test_models.py @@ -8,7 +8,6 @@ from dimos.models.vl.moondream import MoondreamVlModel from dimos.models.vl.qwen import QwenVlModel from dimos.msgs.sensor_msgs import Image -from dimos.perception.detection.detectors.yolo import Yolo2DDetector from dimos.perception.detection.type import ImageDetections2D from dimos.utils.data import get_data diff --git a/dimos/robot/unitree_webrtc/connection/g1.py b/dimos/robot/unitree_webrtc/connection/g1.py index b1b82a2dff..9b4e9a87fa 100644 --- a/dimos/robot/unitree_webrtc/connection/g1.py +++ b/dimos/robot/unitree_webrtc/connection/g1.py @@ -12,12 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Protocol - -from reactivex.disposable import Disposable - from dimos import spec -from dimos.core import DimosCluster, In, LCMTransport, Module, Out, rpc +from dimos.core import DimosCluster, In, Module, rpc from dimos.msgs.geometry_msgs import ( Twist, TwistStamped, @@ -39,8 +35,9 @@ def start(self): self.connection = UnitreeWebRTCConnection(self.ip) self.connection.start() - unsub = self.cmd_vel.subscribe(self.move) - self._disposables.add(Disposable(unsub)) + self._disposables.add( + self.cmd_vel.subscribe(self.move), + ) @rpc def stop(self) -> None: From 28ca16cce4046bdd13ec19c0a9368b67f1bab037 Mon Sep 17 00:00:00 2001 From: lesh Date: Sun, 19 Oct 2025 20:21:41 -0700 Subject: [PATCH 18/40] modular g1 run files --- dimos/perception/detection/module3D.py | 17 ++++- dimos/perception/detection/moduleDB.py | 6 +- dimos/robot/unitree_webrtc/modular/g1agent.py | 67 +++++++++++++++++++ .../unitree_webrtc/modular/g1detector.py | 62 +++++++++++++++++ .../modular/{ivan_g1.py => g1zed.py} | 43 +----------- .../unitree_webrtc/modular/ivan_unitree.py | 4 +- 6 files changed, 150 insertions(+), 49 deletions(-) create mode 100644 dimos/robot/unitree_webrtc/modular/g1agent.py create mode 100644 dimos/robot/unitree_webrtc/modular/g1detector.py rename dimos/robot/unitree_webrtc/modular/{ivan_g1.py => g1zed.py} (67%) diff --git a/dimos/perception/detection/module3D.py b/dimos/perception/detection/module3D.py index 18d99396c0..bb5c828332 100644 --- a/dimos/perception/detection/module3D.py +++ b/dimos/perception/detection/module3D.py @@ -98,8 +98,8 @@ def pixel_to_3d( Vector3 position in camera optical frame coordinates """ # Extract camera intrinsics - fx, fy = self.camera_info.K[0], self.camera_info.K[4] - cx, cy = self.camera_info.K[2], self.camera_info.K[5] + fx, fy = self.config.camera_info.K[0], self.config.camera_info.K[4] + cx, cy = self.config.camera_info.K[2], self.config.camera_info.K[5] # Unproject pixel to normalized camera coordinates x_norm = (pixel[0] - cx) / fx @@ -109,6 +109,17 @@ def pixel_to_3d( # Camera optical frame: X right, Y down, Z forward return Vector3(x_norm * assumed_depth, y_norm * assumed_depth, assumed_depth) + @skill() + def ask_vlm(self, question: str) -> str: + """asks a visual model about the view of the robot, for example + is the bannana in the trunk? + """ + from dimos.models.vl.qwen import QwenVlModel + + model = QwenVlModel() + image = self.image.get_next() + return model.query(image, question) + # @skill # type: ignore[arg-type] @rpc def nav_vlm(self, question: str) -> str: @@ -152,7 +163,7 @@ def nav_vlm(self, question: str) -> str: ts=detections.image.ts, frame_id="world", position=self.pixel_to_3d(center, assumed_depth=1.5), - orientation=Quaternion(0.0, 0.0, 0.0, 0.0), + orientation=Quaternion(0.0, 0.0, 0.0, 1.0), ) @rpc diff --git a/dimos/perception/detection/moduleDB.py b/dimos/perception/detection/moduleDB.py index ff3fecd279..8485ff8416 100644 --- a/dimos/perception/detection/moduleDB.py +++ b/dimos/perception/detection/moduleDB.py @@ -207,8 +207,8 @@ def agent_encode(self) -> str: for obj in copy(self.objects).values(): # we need at least 3 detectieons to consider it a valid object # for this to be serious we need a ratio of detections within the window of observations - # if len(obj.detections) < 3: - # continue + if len(obj.detections) < 4: + continue ret.append(str(obj.agent_encode())) if not ret: return "No objects detected yet." @@ -255,8 +255,6 @@ def start(self): def update_objects(imageDetections: ImageDetections3DPC): for detection in imageDetections.detections: - if detection.class_id == 1: - continue self.add_detection(detection) def scene_thread(): diff --git a/dimos/robot/unitree_webrtc/modular/g1agent.py b/dimos/robot/unitree_webrtc/modular/g1agent.py new file mode 100644 index 0000000000..2f280d6b98 --- /dev/null +++ b/dimos/robot/unitree_webrtc/modular/g1agent.py @@ -0,0 +1,67 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos import agents2 +from dimos.agents2.skills.navigation import NavigationSkillContainer +from dimos.core import DimosCluster, start, wait_exit +from dimos.perception import spatial_perception +from dimos.robot.unitree_webrtc.modular import g1detector + + +def deploy(dimos: DimosCluster, ip: str) -> None: + g1 = g1detector.deploy(dimos, ip) + + nav = g1.get("nav") + camera = g1.get("camera") + detector3d = g1.get("detector3d") + connection = g1.get("connection") + + spatialmem = spatial_perception.deploy(dimos, camera) + + navskills = dimos.deploy( + NavigationSkillContainer, + spatialmem, + nav, + detector3d, + ) + navskills.start() + + agent = agents2.deploy( + dimos, + "You are controling a humanoid robot", + skill_containers=[connection, nav, camera, spatialmem, navskills], + ) + agent.run_implicit_skill("current_position") + agent.run_implicit_skill("video_stream") + + return {"agent": agent, "spatialmem": spatialmem} + g1 + + +if __name__ == "__main__": + import argparse + import os + + from dotenv import load_dotenv + + load_dotenv() + + parser = argparse.ArgumentParser(description="Unitree G1 Humanoid Robot Control") + parser.add_argument("--ip", default=os.getenv("ROBOT_IP"), help="Robot IP address") + + args = parser.parse_args() + + dimos = start(8) + deploy(dimos, args.ip) + wait_exit() + dimos.close_all() diff --git a/dimos/robot/unitree_webrtc/modular/g1detector.py b/dimos/robot/unitree_webrtc/modular/g1detector.py new file mode 100644 index 0000000000..272b2a4c4d --- /dev/null +++ b/dimos/robot/unitree_webrtc/modular/g1detector.py @@ -0,0 +1,62 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.core import DimosCluster, start, wait_exit +from dimos.perception.detection import module3D, moduleDB +from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector +from dimos.robot.unitree_webrtc.modular import g1zed + + +def deploy(dimos: DimosCluster, ip: str) -> None: + g1 = g1zed.deploy(dimos, ip) + + nav = g1.get("nav") + camera = g1.get("camera") + camerainfo = g1.get("camerainfo") + + person_detector = module3D.deploy( + dimos, + camerainfo, + camera=camera, + lidar=nav, + detector=YoloPersonDetector, + ) + + detector3d = moduleDB.deploy( + dimos, + camerainfo, + camera=camera, + lidar=nav, + ) + + return {"person_detector": person_detector, "detector3d": detector3d} + g1 + + +if __name__ == "__main__": + import argparse + import os + + from dotenv import load_dotenv + + load_dotenv() + + parser = argparse.ArgumentParser(description="Unitree G1 Humanoid Robot Control") + parser.add_argument("--ip", default=os.getenv("ROBOT_IP"), help="Robot IP address") + + args = parser.parse_args() + + dimos = start(8) + deploy(dimos, args.ip) + wait_exit() + dimos.close_all() diff --git a/dimos/robot/unitree_webrtc/modular/ivan_g1.py b/dimos/robot/unitree_webrtc/modular/g1zed.py similarity index 67% rename from dimos/robot/unitree_webrtc/modular/ivan_g1.py rename to dimos/robot/unitree_webrtc/modular/g1zed.py index 7e3fe98e05..ba4caa3d38 100644 --- a/dimos/robot/unitree_webrtc/modular/ivan_g1.py +++ b/dimos/robot/unitree_webrtc/modular/g1zed.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos import agents2 -from dimos.agents2.skills.navigation import NavigationSkillContainer from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE -from dimos.core import DimosCluster, LCMTransport, Module, pSHMTransport, start, wait_exit +from dimos.core import DimosCluster, LCMTransport, pSHMTransport, start, wait_exit from dimos.hardware.camera import zed from dimos.hardware.camera.module import CameraModule from dimos.hardware.camera.webcam import Webcam @@ -26,10 +24,6 @@ ) from dimos.msgs.sensor_msgs import CameraInfo from dimos.navigation import rosnav -from dimos.perception import spatial_perception -from dimos.perception.detection import module3D, moduleDB -from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector -from dimos.protocol.skill.skill import SkillContainer, skill from dimos.robot import foxglove_bridge from dimos.robot.unitree_webrtc.connection import g1 from dimos.utils.logging_config import setup_logger @@ -64,45 +58,14 @@ def deploy(dimos: DimosCluster, ip: str) -> None: nav = rosnav.deploy(dimos) connection = g1.deploy(dimos, ip, nav) zedcam = deploy_monozed(dimos) - spatialmem = spatial_perception.deploy(dimos, zedcam) - - # person_detector = module3D.deploy( - # dimos, - # zed.CameraInfo.SingleWebcam, - # camera=zedcam, - # lidar=nav, - # detector=YoloPersonDetector, - # ) - - detector = moduleDB.deploy( - dimos, - zed.CameraInfo.SingleWebcam, - camera=zedcam, - lidar=nav, - ) foxglove_bridge.deploy(dimos) - navskills = dimos.deploy( - NavigationSkillContainer, - spatialmem, - nav, - detector, - ) - navskills.start() - - agent = agents2.deploy( - dimos, - "You are controling a humanoid robot", - skill_containers=[connection, nav, zedcam, spatialmem, navskills], - ) - agent.run_implicit_skill("current_position") - # agent.run_implicit_skill("video_stream") - return { "nav": nav, "connection": connection, - "zed": zedcam, + "camera": zedcam, + "camerainfo": zed.CameraInfo.SingleWebcam, } diff --git a/dimos/robot/unitree_webrtc/modular/ivan_unitree.py b/dimos/robot/unitree_webrtc/modular/ivan_unitree.py index 948dccaa16..a9aa986e2e 100644 --- a/dimos/robot/unitree_webrtc/modular/ivan_unitree.py +++ b/dimos/robot/unitree_webrtc/modular/ivan_unitree.py @@ -31,7 +31,7 @@ from dimos.perception.detection.reid import ReidModule from dimos.protocol.pubsub import lcm from dimos.robot.foxglove_bridge import FoxgloveBridge -from dimos.robot.unitree_webrtc.modular import deploy_connection, deploy_navigation +from dimos.robot.unitree_webrtc.connection import go2 from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule from dimos.utils.logging_config import setup_logger @@ -40,7 +40,7 @@ def detection_unitree(): dimos = start(8) - connection = deploy_connection(dimos) + connection = go2.deploy(dimos) def goto(pose): print("NAVIGATION REQUESTED:", pose) From c1d102ad9177b77834a27edf996804a2a2a534df Mon Sep 17 00:00:00 2001 From: lesh Date: Sun, 19 Oct 2025 20:24:46 -0700 Subject: [PATCH 19/40] good run files --- dimos/robot/unitree_webrtc/modular/g1agent.py | 19 ---- .../unitree_webrtc/modular/g1detector.py | 19 ---- dimos/robot/unitree_webrtc/modular/g1zed.py | 19 ---- dimos/robot/unitree_webrtc/modular/run.py | 92 +++++++++++++++++++ 4 files changed, 92 insertions(+), 57 deletions(-) create mode 100644 dimos/robot/unitree_webrtc/modular/run.py diff --git a/dimos/robot/unitree_webrtc/modular/g1agent.py b/dimos/robot/unitree_webrtc/modular/g1agent.py index 2f280d6b98..284b58e80f 100644 --- a/dimos/robot/unitree_webrtc/modular/g1agent.py +++ b/dimos/robot/unitree_webrtc/modular/g1agent.py @@ -46,22 +46,3 @@ def deploy(dimos: DimosCluster, ip: str) -> None: agent.run_implicit_skill("video_stream") return {"agent": agent, "spatialmem": spatialmem} + g1 - - -if __name__ == "__main__": - import argparse - import os - - from dotenv import load_dotenv - - load_dotenv() - - parser = argparse.ArgumentParser(description="Unitree G1 Humanoid Robot Control") - parser.add_argument("--ip", default=os.getenv("ROBOT_IP"), help="Robot IP address") - - args = parser.parse_args() - - dimos = start(8) - deploy(dimos, args.ip) - wait_exit() - dimos.close_all() diff --git a/dimos/robot/unitree_webrtc/modular/g1detector.py b/dimos/robot/unitree_webrtc/modular/g1detector.py index 272b2a4c4d..ec1ef73a95 100644 --- a/dimos/robot/unitree_webrtc/modular/g1detector.py +++ b/dimos/robot/unitree_webrtc/modular/g1detector.py @@ -41,22 +41,3 @@ def deploy(dimos: DimosCluster, ip: str) -> None: ) return {"person_detector": person_detector, "detector3d": detector3d} + g1 - - -if __name__ == "__main__": - import argparse - import os - - from dotenv import load_dotenv - - load_dotenv() - - parser = argparse.ArgumentParser(description="Unitree G1 Humanoid Robot Control") - parser.add_argument("--ip", default=os.getenv("ROBOT_IP"), help="Robot IP address") - - args = parser.parse_args() - - dimos = start(8) - deploy(dimos, args.ip) - wait_exit() - dimos.close_all() diff --git a/dimos/robot/unitree_webrtc/modular/g1zed.py b/dimos/robot/unitree_webrtc/modular/g1zed.py index ba4caa3d38..3fd41d633b 100644 --- a/dimos/robot/unitree_webrtc/modular/g1zed.py +++ b/dimos/robot/unitree_webrtc/modular/g1zed.py @@ -67,22 +67,3 @@ def deploy(dimos: DimosCluster, ip: str) -> None: "camera": zedcam, "camerainfo": zed.CameraInfo.SingleWebcam, } - - -if __name__ == "__main__": - import argparse - import os - - from dotenv import load_dotenv - - load_dotenv() - - parser = argparse.ArgumentParser(description="Unitree G1 Humanoid Robot Control") - parser.add_argument("--ip", default=os.getenv("ROBOT_IP"), help="Robot IP address") - - args = parser.parse_args() - - dimos = start(8) - deploy(dimos, args.ip) - wait_exit() - dimos.close_all() diff --git a/dimos/robot/unitree_webrtc/modular/run.py b/dimos/robot/unitree_webrtc/modular/run.py new file mode 100644 index 0000000000..0fa8511d24 --- /dev/null +++ b/dimos/robot/unitree_webrtc/modular/run.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Centralized runner for modular G1 deployment scripts. + +Usage: + python run.py g1agent --ip 192.168.1.100 + python run.py g1zed + python run.py g1detector --ip $ROBOT_IP +""" + +import argparse +import importlib +import os +import sys + +from dotenv import load_dotenv + +from dimos.core import start, wait_exit + + +def main(): + load_dotenv() + + parser = argparse.ArgumentParser(description="Unitree G1 Modular Deployment Runner") + parser.add_argument( + "module", + help="Module name to run (e.g., g1agent, g1zed, g1detector)", + ) + parser.add_argument( + "--ip", + default=os.getenv("ROBOT_IP"), + help="Robot IP address (default: ROBOT_IP from .env)", + ) + parser.add_argument( + "--workers", + type=int, + default=8, + help="Number of worker threads for DimosCluster (default: 8)", + ) + + args = parser.parse_args() + + # Validate IP address + if not args.ip: + print("ERROR: Robot IP address not provided") + print("Please provide --ip or set ROBOT_IP in .env") + sys.exit(1) + + # Import the module + try: + # Try importing from current package first + module = importlib.import_module( + f".{args.module}", package="dimos.robot.unitree_webrtc.modular" + ) + except ImportError as e: + print(f"ERROR: Could not import module '{args.module}'") + print(f"Make sure the module exists in dimos/robot/unitree_webrtc/modular/") + print(f"Import error: {e}") + sys.exit(1) + + # Verify deploy function exists + if not hasattr(module, "deploy"): + print(f"ERROR: Module '{args.module}' does not have a 'deploy' function") + sys.exit(1) + + print(f"Running {args.module}.deploy() with IP {args.ip}") + + # Run the standard deployment pattern + dimos = start(args.workers) + try: + module.deploy(dimos, args.ip) + wait_exit() + finally: + dimos.close_all() + + +if __name__ == "__main__": + main() From a05520cccf40bb970ebbc64be3f1b7ba1edd33dd Mon Sep 17 00:00:00 2001 From: lesh Date: Mon, 20 Oct 2025 13:17:30 -0700 Subject: [PATCH 20/40] cleanup --- dimos/perception/detection/module2D.py | 1 - dimos/robot/unitree_webrtc/modular/g1agent.py | 2 +- dimos/robot/unitree_webrtc/modular/g1detector.py | 5 ++++- dimos/robot/unitree_webrtc/modular/g1zed.py | 1 + 4 files changed, 6 insertions(+), 3 deletions(-) diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py index 913e84bd7a..26f33f7e95 100644 --- a/dimos/perception/detection/module2D.py +++ b/dimos/perception/detection/module2D.py @@ -156,7 +156,6 @@ def deploy( from dimos.core import LCMTransport detector = Detection2DModule(**kwargs) - detector.image.connect(camera.image) detector.annotations.transport = LCMTransport(f"{prefix}/annotations", ImageAnnotations) diff --git a/dimos/robot/unitree_webrtc/modular/g1agent.py b/dimos/robot/unitree_webrtc/modular/g1agent.py index 284b58e80f..895b3fe1b6 100644 --- a/dimos/robot/unitree_webrtc/modular/g1agent.py +++ b/dimos/robot/unitree_webrtc/modular/g1agent.py @@ -14,7 +14,7 @@ from dimos import agents2 from dimos.agents2.skills.navigation import NavigationSkillContainer -from dimos.core import DimosCluster, start, wait_exit +from dimos.core import DimosCluster from dimos.perception import spatial_perception from dimos.robot.unitree_webrtc.modular import g1detector diff --git a/dimos/robot/unitree_webrtc/modular/g1detector.py b/dimos/robot/unitree_webrtc/modular/g1detector.py index ec1ef73a95..9cbfbbf897 100644 --- a/dimos/robot/unitree_webrtc/modular/g1detector.py +++ b/dimos/robot/unitree_webrtc/modular/g1detector.py @@ -40,4 +40,7 @@ def deploy(dimos: DimosCluster, ip: str) -> None: lidar=nav, ) - return {"person_detector": person_detector, "detector3d": detector3d} + g1 + return { + "person_detector": person_detector, + "detector3d": detector3d, + } + g1 diff --git a/dimos/robot/unitree_webrtc/modular/g1zed.py b/dimos/robot/unitree_webrtc/modular/g1zed.py index 3fd41d633b..6aee2276e8 100644 --- a/dimos/robot/unitree_webrtc/modular/g1zed.py +++ b/dimos/robot/unitree_webrtc/modular/g1zed.py @@ -48,6 +48,7 @@ def deploy_monozed(dimos) -> CameraModule: camera_info=zed.CameraInfo.SingleWebcam, ), ) + camera.image.transport = pSHMTransport("/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE) camera.camera_info.transport = LCMTransport("/camera_info", CameraInfo) camera.start() From 24c192bf79aa80654b13b7969089694333f44a28 Mon Sep 17 00:00:00 2001 From: lesh Date: Tue, 21 Oct 2025 14:08:35 -0700 Subject: [PATCH 21/40] small changes --- dimos/robot/unitree_webrtc/modular/g1detector.py | 2 +- dimos/robot/unitree_webrtc/modular/g1zed.py | 4 ++-- dimos/robot/unitree_webrtc/modular/run.py | 7 ++++++- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/dimos/robot/unitree_webrtc/modular/g1detector.py b/dimos/robot/unitree_webrtc/modular/g1detector.py index 9cbfbbf897..90a8e3f44b 100644 --- a/dimos/robot/unitree_webrtc/modular/g1detector.py +++ b/dimos/robot/unitree_webrtc/modular/g1detector.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.core import DimosCluster, start, wait_exit +from dimos.core import DimosCluster from dimos.perception.detection import module3D, moduleDB from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector from dimos.robot.unitree_webrtc.modular import g1zed diff --git a/dimos/robot/unitree_webrtc/modular/g1zed.py b/dimos/robot/unitree_webrtc/modular/g1zed.py index 6aee2276e8..3af5b53118 100644 --- a/dimos/robot/unitree_webrtc/modular/g1zed.py +++ b/dimos/robot/unitree_webrtc/modular/g1zed.py @@ -31,7 +31,7 @@ logger = setup_logger(__name__) -def deploy_monozed(dimos) -> CameraModule: +def deploy_g1_monozed(dimos) -> CameraModule: camera = dimos.deploy( CameraModule, frequency=4.0, @@ -58,7 +58,7 @@ def deploy_monozed(dimos) -> CameraModule: def deploy(dimos: DimosCluster, ip: str) -> None: nav = rosnav.deploy(dimos) connection = g1.deploy(dimos, ip, nav) - zedcam = deploy_monozed(dimos) + zedcam = deploy_g1_monozed(dimos) foxglove_bridge.deploy(dimos) diff --git a/dimos/robot/unitree_webrtc/modular/run.py b/dimos/robot/unitree_webrtc/modular/run.py index 0fa8511d24..aa6ca2af14 100644 --- a/dimos/robot/unitree_webrtc/modular/run.py +++ b/dimos/robot/unitree_webrtc/modular/run.py @@ -67,9 +67,14 @@ def main(): f".{args.module}", package="dimos.robot.unitree_webrtc.modular" ) except ImportError as e: - print(f"ERROR: Could not import module '{args.module}'") + import traceback + + traceback.print_exc() + + print(f"\nERROR: Could not import module '{args.module}'") print(f"Make sure the module exists in dimos/robot/unitree_webrtc/modular/") print(f"Import error: {e}") + sys.exit(1) # Verify deploy function exists From f90ff5f22841cfc018d5ddfa375d9a8210462ca9 Mon Sep 17 00:00:00 2001 From: lesh Date: Tue, 21 Oct 2025 16:47:34 -0700 Subject: [PATCH 22/40] detection module fixes, g1 run files work --- dimos/perception/detection/module2D.py | 16 ++++++-- dimos/perception/detection/module3D.py | 4 -- dimos/perception/detection/moduleDB.py | 41 ++++++++++++------- dimos/perception/detection/type/__init__.py | 2 + .../detection/type/detection2d/__init__.py | 2 +- .../detection/type/detection2d/base.py | 5 ++- .../detection/type/imageDetections.py | 18 +++++++- dimos/robot/unitree_webrtc/modular/g1agent.py | 2 +- .../unitree_webrtc/modular/g1detector.py | 8 ++-- dimos/robot/unitree_webrtc/modular/g1zed.py | 2 +- dimos/spec/perception.py | 2 - 11 files changed, 68 insertions(+), 34 deletions(-) diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py index 26f33f7e95..aec2850e3e 100644 --- a/dimos/perception/detection/module2D.py +++ b/dimos/perception/detection/module2D.py @@ -31,9 +31,7 @@ from dimos.perception.detection.detectors import Detector from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector from dimos.perception.detection.detectors.yolo import Yolo2DDetector -from dimos.perception.detection.type import ( - ImageDetections2D, -) +from dimos.perception.detection.type import Detection2D, Filter2D, ImageDetections2D from dimos.utils.decorators.decorators import simple_mcache from dimos.utils.reactive import backpressure @@ -44,6 +42,13 @@ class Config(ModuleConfig): detector: Optional[Callable[[Any], Detector]] = Yolo2DDetector publish_detection_images: bool = True camera_info: CameraInfo = None # type: ignore + filter: list[Filter2D] | Filter2D | None = None + + def __post_init__(self): + if self.filter is None: + self.filter = [] + elif not isinstance(self.filter, list): + self.filter = [self.filter] class Detection2DModule(Module): @@ -69,7 +74,10 @@ def __init__(self, *args, **kwargs): self.previous_detection_count = 0 def process_image_frame(self, image: Image) -> ImageDetections2D: - return self.detector.process_image(image) + imageDetections = self.detector.process_image(image) + if not self.config.filter: + return imageDetections + return imageDetections.filter(*self.config.filter) @simple_mcache def sharp_image_stream(self) -> Observable[Image]: diff --git a/dimos/perception/detection/module3D.py b/dimos/perception/detection/module3D.py index bb5c828332..792acb1969 100644 --- a/dimos/perception/detection/module3D.py +++ b/dimos/perception/detection/module3D.py @@ -26,7 +26,6 @@ from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 from dimos.msgs.vision_msgs import Detection2DArray -from dimos.perception.detection.module2D import Config as Module2DConfig from dimos.perception.detection.module2D import Detection2DModule from dimos.perception.detection.type import ( ImageDetections2D, @@ -37,9 +36,6 @@ from dimos.utils.reactive import backpressure -class Config(Module2DConfig): ... - - class Detection3DModule(Detection2DModule): image: In[Image] = None # type: ignore pointcloud: In[PointCloud2] = None # type: ignore diff --git a/dimos/perception/detection/moduleDB.py b/dimos/perception/detection/moduleDB.py index 8485ff8416..959e3a6138 100644 --- a/dimos/perception/detection/moduleDB.py +++ b/dimos/perception/detection/moduleDB.py @@ -158,6 +158,25 @@ class ObjectDBModule(Detection3DModule, TableStr): remembered_locations: Dict[str, PoseStamped] + @rpc + def start(self): + Detection3DModule.start(self) + + def update_objects(imageDetections: ImageDetections3DPC): + for detection in imageDetections.detections: + self.add_detection(detection) + + def scene_thread(): + while True: + print(self) + scene_update = self.to_foxglove_scene_update() + self.scene_update.publish(scene_update) + time.sleep(1.0) + + threading.Thread(target=scene_thread, daemon=True).start() + + self.detection_stream_3d.subscribe(update_objects) + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.goto = None @@ -259,6 +278,7 @@ def update_objects(imageDetections: ImageDetections3DPC): def scene_thread(): while True: + print(self) scene_update = self.to_foxglove_scene_update() self.scene_update.publish(scene_update) time.sleep(1.0) @@ -288,22 +308,13 @@ def to_foxglove_scene_update(self) -> "SceneUpdate": scene_update.deletions = [] scene_update.entities = [] - for obj in copy(self.objects).values(): - # we need at least 3 detectieons to consider it a valid object - # for this to be serious we need a ratio of detections within the window of observations - if obj.class_id != -100 and obj.detections < 4: - continue - - # print( - # f"Object {obj.track_id}: {len(obj.detections)} detections, confidence {obj.confidence}" - # ) - # print(obj.to_pose()) - - scene_update.entities.append( - obj.to_foxglove_scene_entity( - entity_id=f"object_{obj.name}_{obj.track_id}_{obj.detections}" + for obj in self.objects: + try: + scene_update.entities.append( + obj.to_foxglove_scene_entity(entity_id=f"{obj.name}_{obj.track_id}") ) - ) + except Exception as e: + pass scene_update.entities_length = len(scene_update.entities) return scene_update diff --git a/dimos/perception/detection/type/__init__.py b/dimos/perception/detection/type/__init__.py index d8f36d79dc..bc44d984fd 100644 --- a/dimos/perception/detection/type/__init__.py +++ b/dimos/perception/detection/type/__init__.py @@ -2,6 +2,7 @@ Detection2D, Detection2DBBox, Detection2DPerson, + Filter2D, ImageDetections2D, ) from dimos.perception.detection.type.detection3d import ( @@ -21,6 +22,7 @@ __all__ = [ # 2D Detection types "Detection2D", + "Filter2D", "Detection2DBBox", "Detection2DPerson", "ImageDetections2D", diff --git a/dimos/perception/detection/type/detection2d/__init__.py b/dimos/perception/detection/type/detection2d/__init__.py index 1096abda9c..197c7a55e2 100644 --- a/dimos/perception/detection/type/detection2d/__init__.py +++ b/dimos/perception/detection/type/detection2d/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.perception.detection.type.detection2d.base import Detection2D +from dimos.perception.detection.type.detection2d.base import Detection2D, Filter2D from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox from dimos.perception.detection.type.detection2d.imageDetections2D import ImageDetections2D from dimos.perception.detection.type.detection2d.person import Detection2DPerson diff --git a/dimos/perception/detection/type/detection2d/base.py b/dimos/perception/detection/type/detection2d/base.py index e89bf65409..ea57acb911 100644 --- a/dimos/perception/detection/type/detection2d/base.py +++ b/dimos/perception/detection/type/detection2d/base.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import abstractmethod -from typing import List +from typing import Callable, List from dimos_lcm.foxglove_msgs.ImageAnnotations import PointsAnnotation, TextAnnotation from dimos_lcm.vision_msgs import Detection2D as ROSDetection2D @@ -50,3 +50,6 @@ def to_points_annotation(self) -> List[PointsAnnotation]: def to_ros_detection2d(self) -> ROSDetection2D: """Convert detection to ROS Detection2D message.""" ... + + +Filter2D = Callable[[Detection2D], bool] diff --git a/dimos/perception/detection/type/imageDetections.py b/dimos/perception/detection/type/imageDetections.py index 994c939e4d..0a1ce8cf56 100644 --- a/dimos/perception/detection/type/imageDetections.py +++ b/dimos/perception/detection/type/imageDetections.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Generic, List, Optional, TypeVar +from typing import TYPE_CHECKING, Callable, Generic, List, Optional, TypeVar from dimos_lcm.vision_msgs import Detection2DArray @@ -57,6 +57,22 @@ def __iter__(self): def __getitem__(self, index): return self.detections[index] + def filter(self, *predicates: Callable[[T], bool]) -> ImageDetections[T]: + """Filter detections using one or more predicate functions. + + Multiple predicates are applied in cascade (all must return True). + + Args: + *predicates: Functions that take a detection and return True to keep it + + Returns: + A new ImageDetections instance with filtered detections + """ + filtered_detections = self.detections + for predicate in predicates: + filtered_detections = [det for det in filtered_detections if predicate(det)] + return ImageDetections(self.image, filtered_detections) + def to_ros_detection2d_array(self) -> Detection2DArray: return Detection2DArray( detections_length=len(self.detections), diff --git a/dimos/robot/unitree_webrtc/modular/g1agent.py b/dimos/robot/unitree_webrtc/modular/g1agent.py index 895b3fe1b6..06da0ec950 100644 --- a/dimos/robot/unitree_webrtc/modular/g1agent.py +++ b/dimos/robot/unitree_webrtc/modular/g1agent.py @@ -45,4 +45,4 @@ def deploy(dimos: DimosCluster, ip: str) -> None: agent.run_implicit_skill("current_position") agent.run_implicit_skill("video_stream") - return {"agent": agent, "spatialmem": spatialmem} + g1 + return {"agent": agent, "spatialmem": spatialmem, **g1} diff --git a/dimos/robot/unitree_webrtc/modular/g1detector.py b/dimos/robot/unitree_webrtc/modular/g1detector.py index 90a8e3f44b..d058c64825 100644 --- a/dimos/robot/unitree_webrtc/modular/g1detector.py +++ b/dimos/robot/unitree_webrtc/modular/g1detector.py @@ -38,9 +38,9 @@ def deploy(dimos: DimosCluster, ip: str) -> None: camerainfo, camera=camera, lidar=nav, + filter=lambda det: det.class_id != 0, ) - return { - "person_detector": person_detector, - "detector3d": detector3d, - } + g1 + # return {"detector3d": detector3d, **g1} + + return {"person_detector": person_detector, "detector3d": detector3d, **g1} diff --git a/dimos/robot/unitree_webrtc/modular/g1zed.py b/dimos/robot/unitree_webrtc/modular/g1zed.py index 3af5b53118..c33d71e2ad 100644 --- a/dimos/robot/unitree_webrtc/modular/g1zed.py +++ b/dimos/robot/unitree_webrtc/modular/g1zed.py @@ -37,7 +37,7 @@ def deploy_g1_monozed(dimos) -> CameraModule: frequency=4.0, transform=Transform( translation=Vector3(0.05, 0.0, 0.0), - rotation=Quaternion.from_euler(Vector3(0.0, 0.2, 0.0)), + rotation=Quaternion.from_euler(Vector3(0.0, 0.0, 0.0)), frame_id="sensor", child_frame_id="camera_link", ), diff --git a/dimos/spec/perception.py b/dimos/spec/perception.py index dba9feb67c..09a0d18524 100644 --- a/dimos/spec/perception.py +++ b/dimos/spec/perception.py @@ -14,8 +14,6 @@ from typing import Protocol -from dimos_lcm.sensor_msgs import CameraInfo - from dimos.core import Out from dimos.msgs.sensor_msgs import Image, PointCloud2 From c7cc70c9313800a0607372e8071dc40c357df91e Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 24 Oct 2025 03:28:20 +0200 Subject: [PATCH 23/40] go2 clean --- .../robot/unitree_webrtc/modular/ivan_go2.py | 23 ++----------------- 1 file changed, 2 insertions(+), 21 deletions(-) diff --git a/dimos/robot/unitree_webrtc/modular/ivan_go2.py b/dimos/robot/unitree_webrtc/modular/ivan_go2.py index 16caf06e60..d4d7a89704 100644 --- a/dimos/robot/unitree_webrtc/modular/ivan_go2.py +++ b/dimos/robot/unitree_webrtc/modular/ivan_go2.py @@ -16,8 +16,8 @@ import time from dimos import agents2 -from dimos.core import DimosCluster, start, wait_exit -from dimos.perception.detection import module3D, moduleDB +from dimos.core import DimosCluster +from dimos.perception.detection import moduleDB from dimos.robot import foxglove_bridge from dimos.robot.unitree_webrtc.connection import go2 from dimos.utils.logging_config import setup_logger @@ -38,22 +38,3 @@ def deploy(dimos: DimosCluster, ip: str): agent = agents2.deploy(dimos) agent.register_skills(detector) - - -if __name__ == "__main__": - import argparse - import os - - from dotenv import load_dotenv - - load_dotenv() - - parser = argparse.ArgumentParser(description="Unitree G1 Humanoid Robot Control") - parser.add_argument("--ip", default=os.getenv("ROBOT_IP"), help="Robot IP address") - - args = parser.parse_args() - - dimos = start(8) - deploy(dimos, args.ip) - wait_exit() - dimos.close_all() From bfe4689c7b36a2336d1dd7a5f4ee9e12fbefed04 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 24 Oct 2025 09:46:08 +0300 Subject: [PATCH 24/40] camera cleanup --- dimos/hardware/camera/module.py | 14 +++--- dimos/robot/unitree_webrtc/unitree_go2.py | 55 +++++++---------------- 2 files changed, 21 insertions(+), 48 deletions(-) diff --git a/dimos/hardware/camera/module.py b/dimos/hardware/camera/module.py index e75f06a92d..18aff8d91b 100644 --- a/dimos/hardware/camera/module.py +++ b/dimos/hardware/camera/module.py @@ -15,7 +15,7 @@ import queue import time from dataclasses import dataclass, field -from typing import Any, Callable, Generic, Literal, Optional, Protocol, TypeVar +from typing import Callable, Optional import reactivex as rx from dimos_lcm.sensor_msgs import CameraInfo @@ -24,12 +24,9 @@ from reactivex.observable import Observable from dimos.agents2 import Output, Reducer, Stream, skill -from dimos.core import Module, Out, rpc -from dimos.core.module import Module, ModuleConfig -from dimos.hardware.camera.spec import ( - CameraHardware, -) -from dimos.hardware.camera.webcam import Webcam, WebcamConfig +from dimos.core import Module, ModuleConfig, Out, rpc +from dimos.hardware.camera.spec import CameraHardware +from dimos.hardware.camera.webcam import Webcam from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 from dimos.msgs.sensor_msgs import Image from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier @@ -54,7 +51,7 @@ class CameraModule(Module): image: Out[Image] = None camera_info: Out[CameraInfo] = None - hardware: CameraHardware = None + hardware: Callable[[], CameraHardware] | CameraHardware = None _module_subscription: Optional[Disposable] = None _camera_info_subscription: Optional[Disposable] = None _skill_stream: Optional[Observable[Image]] = None @@ -117,6 +114,7 @@ def stop(self): if self._camera_info_subscription: self._camera_info_subscription.dispose() self._camera_info_subscription = None + # Also stop the hardware if it has a stop method if self.hardware and hasattr(self.hardware, "stop"): self.hardware.stop() diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py index a3109e24f3..fbe9117c4a 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/unitree_go2.py @@ -22,6 +22,8 @@ import warnings from typing import Optional +from dimos_lcm.sensor_msgs import CameraInfo +from dimos_lcm.std_msgs import Bool, String from reactivex import Observable from reactivex.disposable import CompositeDisposable @@ -31,44 +33,40 @@ from dimos.core.dimos import Dimos from dimos.core.resource import Resource from dimos.mapping.types import LatLon -from dimos.msgs.std_msgs import Header -from dimos.msgs.geometry_msgs import PoseStamped, Transform, Twist, Vector3, Quaternion +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 from dimos.msgs.nav_msgs import OccupancyGrid, Path from dimos.msgs.sensor_msgs import Image +from dimos.msgs.std_msgs import Header from dimos.msgs.vision_msgs import Detection2DArray -from dimos_lcm.std_msgs import String -from dimos_lcm.sensor_msgs import CameraInfo -from dimos.perception.spatial_perception import SpatialMemory +from dimos.navigation.bbox_navigation import BBoxNavigationModule +from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator, NavigatorState +from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer +from dimos.navigation.global_planner import AstarPlanner +from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner from dimos.perception.common.utils import ( load_camera_info, load_camera_info_opencv, rectify_image, ) +from dimos.perception.object_tracker_2d import ObjectTracker2D +from dimos.perception.spatial_perception import SpatialMemory from dimos.protocol import pubsub from dimos.protocol.pubsub.lcmpubsub import LCM, Topic from dimos.protocol.tf import TF from dimos.robot.foxglove_bridge import FoxgloveBridge -from dimos.utils.monitoring import UtilizationModule -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule -from dimos.navigation.global_planner import AstarPlanner -from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner -from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator, NavigatorState -from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer +from dimos.robot.robot import UnitreeRobot from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.map import Map from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills from dimos.skills.skills import AbstractRobotSkill, SkillLibrary +from dimos.types.robot_capabilities import RobotCapability from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger +from dimos.utils.monitoring import UtilizationModule from dimos.utils.testing import TimedSensorReplay -from dimos.perception.object_tracker_2d import ObjectTracker2D -from dimos.navigation.bbox_navigation import BBoxNavigationModule -from dimos_lcm.std_msgs import Bool -from dimos.robot.robot import UnitreeRobot -from dimos.types.robot_capabilities import RobotCapability - +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule logger = setup_logger(__file__, level=logging.INFO) @@ -678,26 +676,3 @@ def get_odom(self) -> PoseStamped: The robot's odometry """ return self.connection.get_odom() - - -def main(): - """Main entry point.""" - ip = os.getenv("ROBOT_IP") - connection_type = os.getenv("CONNECTION_TYPE", "webrtc") - - pubsub.lcm.autoconf() - - robot = UnitreeGo2(ip=ip, websocket_port=7779, connection_type=connection_type) - robot.start() - - try: - while True: - time.sleep(0.1) - except KeyboardInterrupt: - pass - finally: - robot.stop() - - -if __name__ == "__main__": - main() From 5e9e2c3869e87b34fb88760fd22658a041376fcb Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 24 Oct 2025 17:32:55 +0300 Subject: [PATCH 25/40] moved new run stuff to unitree/ --- dimos/conftest.py | 11 + dimos/navigation/test_rosnav.py | 12 +- dimos/perception/detection/conftest.py | 6 +- dimos/perception/detection/test_moduleDB.py | 6 +- dimos/robot/unitree/README.md | 25 -- dimos/robot/unitree/__init__.py | 0 dimos/robot/unitree/connection/__init__.py | 4 + dimos/robot/unitree/connection/connection.py | 418 ++++++++++++++++++ dimos/robot/unitree/connection/g1.py | 66 +++ dimos/robot/unitree/connection/go2.py | 288 ++++++++++++ dimos/robot/unitree/g1/g1agent.py | 48 ++ dimos/robot/unitree/g1/g1detector.py | 46 ++ dimos/robot/unitree/g1/g1zed.py | 70 +++ .../ivan_go2.py => unitree/go2/go2.py} | 5 +- dimos/robot/unitree/run.py | 95 ++++ dimos/robot/unitree/unitree_go2.py | 208 --------- dimos/robot/unitree/unitree_ros_control.py | 157 ------- dimos/robot/unitree/unitree_skills.py | 314 ------------- dimos/robot/unitree_webrtc/connection.py | 404 +++++++++++++++++ .../unitree_webrtc/connection/__init__.py | 3 + .../robot/unitree_webrtc/modular/__init__.py | 4 +- .../unitree_webrtc/modular/ivan_unitree.py | 4 +- dimos/robot/unitree_webrtc/unitree_go2.py | 55 ++- 23 files changed, 1511 insertions(+), 738 deletions(-) delete mode 100644 dimos/robot/unitree/README.md delete mode 100644 dimos/robot/unitree/__init__.py create mode 100644 dimos/robot/unitree/connection/__init__.py create mode 100644 dimos/robot/unitree/connection/connection.py create mode 100644 dimos/robot/unitree/connection/g1.py create mode 100644 dimos/robot/unitree/connection/go2.py create mode 100644 dimos/robot/unitree/g1/g1agent.py create mode 100644 dimos/robot/unitree/g1/g1detector.py create mode 100644 dimos/robot/unitree/g1/g1zed.py rename dimos/robot/{unitree_webrtc/modular/ivan_go2.py => unitree/go2/go2.py} (88%) create mode 100644 dimos/robot/unitree/run.py delete mode 100644 dimos/robot/unitree/unitree_go2.py delete mode 100644 dimos/robot/unitree/unitree_ros_control.py delete mode 100644 dimos/robot/unitree/unitree_skills.py create mode 100644 dimos/robot/unitree_webrtc/connection.py diff --git a/dimos/conftest.py b/dimos/conftest.py index 495afa8a24..cbfbcbcbf6 100644 --- a/dimos/conftest.py +++ b/dimos/conftest.py @@ -33,6 +33,17 @@ def event_loop(): _skip_for = ["lcm", "heavy", "ros"] +@pytest.fixture(scope="module") +def dimos_cluster(): + from dimos.core import start + + dimos = start(4) + try: + yield dimos + finally: + dimos.stop() + + @pytest.hookimpl() def pytest_sessionfinish(session): """Track threads that exist at session start - these are not leaks.""" diff --git a/dimos/navigation/test_rosnav.py b/dimos/navigation/test_rosnav.py index 5de1c0e6ab..bb803b783c 100644 --- a/dimos/navigation/test_rosnav.py +++ b/dimos/navigation/test_rosnav.py @@ -14,24 +14,26 @@ from typing import Protocol +import pytest + from dimos.mapping.spec import Global3DMapSpec -from dimos.navigation.rosnav import ROSNav from dimos.navigation.spec import NavSpec from dimos.perception.spec import PointcloudPerception class RosNavSpec(NavSpec, PointcloudPerception, Global3DMapSpec, Protocol): - """Combined protocol for navigation components.""" - pass def accepts_combined_protocol(nav: RosNavSpec) -> None: - """Function that accepts all navigation protocols at once.""" pass +# this is just a typing test; no runtime behavior is tested +@pytest.mark.skip def test_typing_prototypes(): - """Test that ROSNav correctly implements all required protocols.""" + from dimos.navigation.rosnav import ROSNav + rosnav = ROSNav() accepts_combined_protocol(rosnav) + rosnav.stop() diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index cdd15c1f92..e7812558ab 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -35,7 +35,7 @@ ImageDetections3DPC, ) from dimos.protocol.tf import TF -from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule +from dimos.robot.unitree_webrtc.connection import go2 from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.utils.data import get_data @@ -101,10 +101,10 @@ def moment_provider(**kwargs) -> Moment: if odom_frame is None: raise ValueError("No odom frame found") - transforms = ConnectionModule._odom_to_tf(odom_frame) + transforms = go2._odom_to_tf(odom_frame) tf.receive_transform(*transforms) - camera_info_out = ConnectionModule._camera_info() + camera_info_out = go2._camera_info() # ConnectionModule._camera_info() returns Out[CameraInfo], extract the value from typing import cast diff --git a/dimos/perception/detection/test_moduleDB.py b/dimos/perception/detection/test_moduleDB.py index 1ede53f172..4eec932dce 100644 --- a/dimos/perception/detection/test_moduleDB.py +++ b/dimos/perception/detection/test_moduleDB.py @@ -22,14 +22,13 @@ from dimos.msgs.sensor_msgs import Image, PointCloud2 from dimos.msgs.vision_msgs import Detection2DArray from dimos.perception.detection.moduleDB import ObjectDBModule -from dimos.protocol.service import lcmservice as lcm -from dimos.robot.unitree_webrtc.modular import deploy_connection, deploy_navigation +from dimos.robot.unitree_webrtc.connection import go2 from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule @pytest.mark.module def test_moduleDB(dimos_cluster): - connection = deploy_connection(dimos_cluster) + connection = go2.deploy(dimos_cluster, "fake") moduleDB = dimos_cluster.deploy( ObjectDBModule, @@ -57,6 +56,5 @@ def test_moduleDB(dimos_cluster): moduleDB.start() time.sleep(4) - print("STARTING QUERY!!") print("VLM RES", moduleDB.navigate_to_object_in_view("white floor")) time.sleep(30) diff --git a/dimos/robot/unitree/README.md b/dimos/robot/unitree/README.md deleted file mode 100644 index 5ee389cb31..0000000000 --- a/dimos/robot/unitree/README.md +++ /dev/null @@ -1,25 +0,0 @@ -## Unitree Go2 ROS Control Setup - -Install unitree ros2 workspace as per instructions in https://github.com/dimensionalOS/go2_ros2_sdk/blob/master/README.md - -Run the following command to source the workspace and add dimos to the python path: - -``` -source /home/ros/unitree_ros2_ws/install/setup.bash - -export PYTHONPATH=/home/stash/dimensional/dimos:$PYTHONPATH -``` - -Run the following command to start the ROS control node: - -``` -ros2 launch go2_robot_sdk robot.launch.py -``` - -Run the following command to start the agent: - -``` -python3 dimos/robot/unitree/run_go2_ros.py -``` - - diff --git a/dimos/robot/unitree/__init__.py b/dimos/robot/unitree/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/dimos/robot/unitree/connection/__init__.py b/dimos/robot/unitree/connection/__init__.py new file mode 100644 index 0000000000..5c1dff1922 --- /dev/null +++ b/dimos/robot/unitree/connection/__init__.py @@ -0,0 +1,4 @@ +import dimos.robot.unitree.connection.g1 as g1 +import dimos.robot.unitree.connection.go2 as go2 + +__all__ = ["g1", "go2"] diff --git a/dimos/robot/unitree/connection/connection.py b/dimos/robot/unitree/connection/connection.py new file mode 100644 index 0000000000..6fc2657318 --- /dev/null +++ b/dimos/robot/unitree/connection/connection.py @@ -0,0 +1,418 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import functools +import threading +import time +from dataclasses import dataclass +from typing import Literal, Optional, TypeAlias + +import numpy as np +from aiortc import MediaStreamTrack +from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD, VUI_COLOR +from go2_webrtc_driver.webrtc_driver import ( # type: ignore[import-not-found] + Go2WebRTCConnection, + WebRTCConnectionMethod, +) +from reactivex import operators as ops +from reactivex.observable import Observable +from reactivex.subject import Subject + +from dimos.core import DimosCluster, In, Module, Out, rpc +from dimos.core.resource import Resource +from dimos.msgs.geometry_msgs import Pose, Transform, Twist, Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.decorators.decorators import simple_mcache +from dimos.utils.reactive import backpressure, callback_to_observable + +VideoMessage: TypeAlias = np.ndarray[tuple[int, int, Literal[3]], np.uint8] + + +@dataclass +class SerializableVideoFrame: + """Pickleable wrapper for av.VideoFrame with all metadata""" + + data: np.ndarray + pts: Optional[int] = None + time: Optional[float] = None + dts: Optional[int] = None + width: Optional[int] = None + height: Optional[int] = None + format: Optional[str] = None + + @classmethod + def from_av_frame(cls, frame): + return cls( + data=frame.to_ndarray(format="rgb24"), + pts=frame.pts, + time=frame.time, + dts=frame.dts, + width=frame.width, + height=frame.height, + format=frame.format.name if hasattr(frame, "format") and frame.format else None, + ) + + def to_ndarray(self, format=None): + return self.data + + +class UnitreeWebRTCConnection(Resource): + def __init__(self, ip: str, mode: str = "ai"): + self.ip = ip + self.mode = mode + self.stop_timer = None + self.cmd_vel_timeout = 0.2 + self.conn = Go2WebRTCConnection(WebRTCConnectionMethod.LocalSTA, ip=self.ip) + self.connect() + + def connect(self): + self.loop = asyncio.new_event_loop() + self.task = None + self.connected_event = asyncio.Event() + self.connection_ready = threading.Event() + + async def async_connect(): + await self.conn.connect() + await self.conn.datachannel.disableTrafficSaving(True) + + self.conn.datachannel.set_decoder(decoder_type="native") + + await self.conn.datachannel.pub_sub.publish_request_new( + RTC_TOPIC["MOTION_SWITCHER"], {"api_id": 1002, "parameter": {"name": self.mode}} + ) + + self.connected_event.set() + self.connection_ready.set() + + while True: + await asyncio.sleep(1) + + def start_background_loop(): + asyncio.set_event_loop(self.loop) + self.task = self.loop.create_task(async_connect()) + self.loop.run_forever() + + self.loop = asyncio.new_event_loop() + self.thread = threading.Thread(target=start_background_loop, daemon=True) + self.thread.start() + self.connection_ready.wait() + + def start(self) -> None: + pass + + def stop(self) -> None: + # Cancel timer + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + + if self.task: + self.task.cancel() + + async def async_disconnect() -> None: + try: + await self.conn.disconnect() + except Exception: + pass + + if self.loop.is_running(): + asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) + + self.loop.call_soon_threadsafe(self.loop.stop) + + if self.thread.is_alive(): + self.thread.join(timeout=2.0) + + def move(self, twist: Twist, duration: float = 0.0) -> bool: + """Send movement command to the robot using Twist commands. + + Args: + twist: Twist message with linear and angular velocities + duration: How long to move (seconds). If 0, command is continuous + + Returns: + bool: True if command was sent successfully + """ + x, y, yaw = twist.linear.x, twist.linear.y, twist.angular.z + + # WebRTC coordinate mapping: + # x - Positive right, negative left + # y - positive forward, negative backwards + # yaw - Positive rotate right, negative rotate left + async def async_move(): + self.conn.datachannel.pub_sub.publish_without_callback( + RTC_TOPIC["WIRELESS_CONTROLLER"], + data={"lx": -y, "ly": x, "rx": -yaw, "ry": 0}, + ) + + async def async_move_duration(): + """Send movement commands continuously for the specified duration.""" + start_time = time.time() + sleep_time = 0.01 + + while time.time() - start_time < duration: + await async_move() + await asyncio.sleep(sleep_time) + + # Cancel existing timer and start a new one + if self.stop_timer: + self.stop_timer.cancel() + + # Auto-stop after 0.5 seconds if no new commands + self.stop_timer = threading.Timer(self.cmd_vel_timeout, self.stop) + self.stop_timer.daemon = True + self.stop_timer.start() + + try: + if duration > 0: + # Send continuous move commands for the duration + future = asyncio.run_coroutine_threadsafe(async_move_duration(), self.loop) + future.result() + # Stop after duration + self.stop() + else: + # Single command for continuous movement + future = asyncio.run_coroutine_threadsafe(async_move(), self.loop) + future.result() + return True + except Exception as e: + print(f"Failed to send movement command: {e}") + return False + + # Generic conversion of unitree subscription to Subject (used for all subs) + def unitree_sub_stream(self, topic_name: str): + def subscribe_in_thread(cb): + # Run the subscription in the background thread that has the event loop + def run_subscription(): + self.conn.datachannel.pub_sub.subscribe(topic_name, cb) + + # Use call_soon_threadsafe to run in the background thread + self.loop.call_soon_threadsafe(run_subscription) + + def unsubscribe_in_thread(cb): + # Run the unsubscription in the background thread that has the event loop + def run_unsubscription(): + self.conn.datachannel.pub_sub.unsubscribe(topic_name) + + # Use call_soon_threadsafe to run in the background thread + self.loop.call_soon_threadsafe(run_unsubscription) + + return callback_to_observable( + start=subscribe_in_thread, + stop=unsubscribe_in_thread, + ) + + # Generic sync API call (we jump into the client thread) + def publish_request(self, topic: str, data: dict): + future = asyncio.run_coroutine_threadsafe( + self.conn.datachannel.pub_sub.publish_request_new(topic, data), self.loop + ) + return future.result() + + @simple_mcache + def raw_lidar_stream(self) -> Subject[LidarMessage]: + return backpressure(self.unitree_sub_stream(RTC_TOPIC["ULIDAR_ARRAY"])) + + @simple_mcache + def raw_odom_stream(self) -> Subject[Pose]: + return backpressure(self.unitree_sub_stream(RTC_TOPIC["ROBOTODOM"])) + + @simple_mcache + def lidar_stream(self) -> Subject[LidarMessage]: + return backpressure( + self.raw_lidar_stream().pipe( + ops.map(lambda raw_frame: LidarMessage.from_msg(raw_frame, ts=time.time())) + ) + ) + + @simple_mcache + def tf_stream(self) -> Subject[Transform]: + base_link = functools.partial(Transform.from_pose, "base_link") + return backpressure(self.odom_stream().pipe(ops.map(base_link))) + + @simple_mcache + def odom_stream(self) -> Subject[Pose]: + return backpressure(self.raw_odom_stream().pipe(ops.map(Odometry.from_msg))) + + @simple_mcache + def video_stream(self) -> Observable[Image]: + return backpressure( + self.raw_video_stream().pipe( + ops.filter(lambda frame: frame is not None), + ops.map( + lambda frame: Image.from_numpy( + # np.ascontiguousarray(frame.to_ndarray("rgb24")), + frame.to_ndarray(format="rgb24"), + frame_id="camera_optical", + ) + ), + ) + ) + + @simple_mcache + def lowstate_stream(self) -> Subject[LowStateMsg]: + return backpressure(self.unitree_sub_stream(RTC_TOPIC["LOW_STATE"])) + + def standup_ai(self): + return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]}) + + def standup_normal(self): + self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]}) + time.sleep(0.5) + self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["RecoveryStand"]}) + return True + + @rpc + def standup(self): + if self.mode == "ai": + return self.standup_ai() + else: + return self.standup_normal() + + @rpc + def liedown(self): + return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) + + async def handstand(self): + return self.publish_request( + RTC_TOPIC["SPORT_MOD"], + {"api_id": SPORT_CMD["Standup"], "parameter": {"data": True}}, + ) + + @rpc + def color(self, color: VUI_COLOR = VUI_COLOR.RED, colortime: int = 60) -> bool: + return self.publish_request( + RTC_TOPIC["VUI"], + { + "api_id": 1001, + "parameter": { + "color": color, + "time": colortime, + }, + }, + ) + + @simple_mcache + def raw_video_stream(self) -> Observable[VideoMessage]: + subject: Subject[VideoMessage] = Subject() + stop_event = threading.Event() + + async def accept_track(track: MediaStreamTrack) -> VideoMessage: + while True: + if stop_event.is_set(): + return + frame = await track.recv() + serializable_frame = SerializableVideoFrame.from_av_frame(frame) + subject.on_next(serializable_frame) + + self.conn.video.add_track_callback(accept_track) + + # Run the video channel switching in the background thread + def switch_video_channel(): + self.conn.video.switchVideoChannel(True) + + self.loop.call_soon_threadsafe(switch_video_channel) + + def stop(): + stop_event.set() # Signal the loop to stop + self.conn.video.track_callbacks.remove(accept_track) + + # Run the video channel switching off in the background thread + def switch_video_channel_off(): + self.conn.video.switchVideoChannel(False) + + self.loop.call_soon_threadsafe(switch_video_channel_off) + + return subject.pipe(ops.finally_action(stop)) + + def get_video_stream(self, fps: int = 30) -> Observable[VideoMessage]: + """Get the video stream from the robot's camera. + + Implements the AbstractRobot interface method. + + Args: + fps: Frames per second. This parameter is included for API compatibility, + but doesn't affect the actual frame rate which is determined by the camera. + + Returns: + Observable: An observable stream of video frames or None if video is not available. + """ + try: + print("Starting WebRTC video stream...") + stream = self.video_stream() + if stream is None: + print("Warning: Video stream is not available") + return stream + + except Exception as e: + print(f"Error getting video stream: {e}") + return None + + def stop(self) -> bool: + """Stop the robot's movement. + + Returns: + bool: True if stop command was sent successfully + """ + # Cancel timer since we're explicitly stopping + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + + return self.move(Twist()) + + def disconnect(self) -> None: + """Disconnect from the robot and clean up resources.""" + # Cancel timer + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + + if hasattr(self, "task") and self.task: + self.task.cancel() + if hasattr(self, "conn"): + + async def async_disconnect(): + try: + await self.conn.disconnect() + except: + pass + + if hasattr(self, "loop") and self.loop.is_running(): + asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) + + if hasattr(self, "loop") and self.loop.is_running(): + self.loop.call_soon_threadsafe(self.loop.stop) + + if hasattr(self, "thread") and self.thread.is_alive(): + self.thread.join(timeout=2.0) + + +def deploy(dimos: DimosCluster, ip: str) -> None: + from dimos.robot.foxglove_bridge import FoxgloveBridge + + connection = dimos.deploy(UnitreeWebRTCConnection, ip=ip) + + bridge = FoxgloveBridge( + shm_channels=[ + "/image#sensor_msgs.Image", + "/lidar#sensor_msgs.PointCloud2", + ] + ) + bridge.start() + connection.start() diff --git a/dimos/robot/unitree/connection/g1.py b/dimos/robot/unitree/connection/g1.py new file mode 100644 index 0000000000..299631179a --- /dev/null +++ b/dimos/robot/unitree/connection/g1.py @@ -0,0 +1,66 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from dimos import spec +from dimos.core import DimosCluster, In, Module, rpc +from dimos.msgs.geometry_msgs import ( + Twist, + TwistStamped, +) +from dimos.robot.unitree.connection.connection import UnitreeWebRTCConnection + + +class G1Connection(Module): + cmd_vel: In[TwistStamped] = None # type: ignore + ip: str + + def __init__(self, ip: Optional[str] = None, **kwargs): + super().__init__(**kwargs) + self.ip = ip + self.connection: Optional[UnitreeWebRTCConnection] = None + + @rpc + def start(self): + super().start() + self.connection = UnitreeWebRTCConnection(self.ip) + self.connection.start() + + self._disposables.add( + self.cmd_vel.subscribe(self.move), + ) + + @rpc + def stop(self) -> None: + self.connection.stop() + super().stop() + + @rpc + def move(self, twist_stamped: TwistStamped, duration: float = 0.0): + """Send movement command to robot.""" + twist = Twist(linear=twist_stamped.linear, angular=twist_stamped.angular) + self.connection.move(twist, duration) + + @rpc + def publish_request(self, topic: str, data: dict): + """Forward WebRTC publish requests to connection.""" + return self.connection.publish_request(topic, data) + + +def deploy(dimos: DimosCluster, ip: str, local_planner: spec.LocalPlanner) -> G1Connection: + connection = dimos.deploy(G1Connection, ip) + connection.cmd_vel.connect(local_planner.cmd_vel) + connection.start() + return connection diff --git a/dimos/robot/unitree/connection/go2.py b/dimos/robot/unitree/connection/go2.py new file mode 100644 index 0000000000..dade12cf0e --- /dev/null +++ b/dimos/robot/unitree/connection/go2.py @@ -0,0 +1,288 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time +from typing import List, Optional + +from dimos_lcm.sensor_msgs import CameraInfo + +from dimos.core import DimosCluster, In, LCMTransport, Module, Out, pSHMTransport, rpc +from dimos.msgs.geometry_msgs import ( + PoseStamped, + Quaternion, + Transform, + Twist, + TwistStamped, + Vector3, +) +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.std_msgs import Header +from dimos.robot.unitree.connection.connection import UnitreeWebRTCConnection +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.utils.data import get_data +from dimos.utils.decorators.decorators import simple_mcache +from dimos.utils.logging_config import setup_logger +from dimos.utils.testing import TimedSensorReplay + +logger = setup_logger(__file__, level=logging.INFO) + + +def _camera_info() -> CameraInfo: + fx, fy, cx, cy = (819.553492, 820.646595, 625.284099, 336.808987) + width, height = (1280, 720) + + # Camera matrix K (3x3) + K = [fx, 0, cx, 0, fy, cy, 0, 0, 1] + + # No distortion coefficients for now + D = [0.0, 0.0, 0.0, 0.0, 0.0] + + # Identity rotation matrix + R = [1, 0, 0, 0, 1, 0, 0, 0, 1] + + # Projection matrix P (3x4) + P = [fx, 0, cx, 0, 0, fy, cy, 0, 0, 0, 1, 0] + + base_msg = { + "D_length": len(D), + "height": height, + "width": width, + "distortion_model": "plumb_bob", + "D": D, + "K": K, + "R": R, + "P": P, + "binning_x": 0, + "binning_y": 0, + } + + return CameraInfo(**base_msg, header=Header("camera_optical")) + + +camera_info = _camera_info() + + +class FakeRTC(UnitreeWebRTCConnection): + dir_name = "unitree_go2_office_walk2" + + # we don't want UnitreeWebRTCConnection to init + def __init__( + self, + **kwargs, + ): + get_data(self.dir_name) + self.replay_config = { + "loop": kwargs.get("loop"), + "seek": kwargs.get("seek"), + "duration": kwargs.get("duration"), + } + + def connect(self): + pass + + def start(self): + pass + + def standup(self): + print("standup suppressed") + + def liedown(self): + print("liedown suppressed") + + @simple_mcache + def lidar_stream(self): + print("lidar stream start") + lidar_store = TimedSensorReplay(f"{self.dir_name}/lidar") + return lidar_store.stream(**self.replay_config) + + @simple_mcache + def odom_stream(self): + print("odom stream start") + odom_store = TimedSensorReplay(f"{self.dir_name}/odom") + return odom_store.stream(**self.replay_config) + + # we don't have raw video stream in the data set + @simple_mcache + def video_stream(self): + print("video stream start") + video_store = TimedSensorReplay(f"{self.dir_name}/video") + + return video_store.stream(**self.replay_config) + + def move(self, vector: Twist, duration: float = 0.0): + pass + + def publish_request(self, topic: str, data: dict): + """Fake publish request for testing.""" + return {"status": "ok", "message": "Fake publish"} + + +class GO2Connection(Module): + cmd_vel: In[Twist] = None + pointcloud: Out[LidarMessage] = None + image: Out[Image] = None + camera_info: Out[CameraInfo] = None + connection_type: str = "webrtc" + + ip: str + + def __init__( + self, + ip: Optional[str] = None, + connection_type: str = "webrtc", + rectify_image: bool = True, + *args, + **kwargs, + ): + self.ip = ip + self.connection: Optional[UnitreeWebRTCConnection] = None + Module.__init__(self, *args, **kwargs) + + @rpc + def start(self) -> None: + """Start the connection and subscribe to sensor streams.""" + super().start() + + match self.ip: + case None | "fake" | "": + self.connection = FakeRTC() + case "mujoco": + from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection + + self.connection = MujocoConnection() + case _: + self.connection = UnitreeWebRTCConnection(self.ip) + + self.connection.start() + + self._disposables.add( + self.connection.lidar_stream().subscribe(self.pointcloud.publish), + ) + + self._disposables.add( + self.connection.odom_stream().subscribe(self._publish_tf), + ) + + self._disposables.add( + self.connection.video_stream().subscribe(self.image.publish), + ) + + self._disposables.add( + self.cmd_vel.subscribe(self.move), + ) + + # Start publishing camera info at 1 Hz + from threading import Thread + + self._camera_info_thread = Thread( + target=self.publish_camera_info, + daemon=True, + ) + self._camera_info_thread.start() + + self.standup() + + @rpc + def stop(self) -> None: + self.liedown() + if self.connection: + self.connection.stop() + if hasattr(self, "_camera_info_thread"): + self._camera_info_thread.join(timeout=1.0) + super().stop() + + @classmethod + def _odom_to_tf(self, odom: PoseStamped) -> List[Transform]: + camera_link = Transform( + translation=Vector3(0.3, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + child_frame_id="camera_link", + ts=odom.ts, + ) + + camera_optical = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), + frame_id="camera_link", + child_frame_id="camera_optical", + ts=odom.ts, + ) + + sensor = Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="world", + child_frame_id="sensor", + ts=odom.ts, + ) + + return [ + Transform.from_pose("base_link", odom), + camera_link, + camera_optical, + sensor, + ] + + def _publish_tf(self, msg): + self.tf.publish(*self._odom_to_tf(msg)) + + def publish_camera_info(self): + while True: + self.camera_info.publish(camera_info) + time.sleep(1.0) + + @rpc + def move(self, twist: Twist, duration: float = 0.0): + """Send movement command to robot.""" + self.connection.move(twist, duration) + + @rpc + def standup(self): + """Make the robot stand up.""" + return self.connection.standup() + + @rpc + def liedown(self): + """Make the robot lie down.""" + return self.connection.liedown() + + @rpc + def publish_request(self, topic: str, data: dict): + """Publish a request to the WebRTC connection. + Args: + topic: The RTC topic to publish to + data: The data dictionary to publish + Returns: + The result of the publish request + """ + return self.connection.publish_request(topic, data) + + +def deploy(dimos: DimosCluster, ip: str, prefix="") -> GO2Connection: + from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE + + connection = dimos.deploy(GO2Connection, ip) + + connection.pointcloud.transport = pSHMTransport( + f"{prefix}/lidar", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ) + connection.image.transport = pSHMTransport( + f"{prefix}/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE + ) + connection.cmd_vel.transport = LCMTransport(f"{prefix}/cmd_vel", TwistStamped) + connection.camera_info.transport = LCMTransport(f"{prefix}/camera_info", CameraInfo) + connection.start() + return connection diff --git a/dimos/robot/unitree/g1/g1agent.py b/dimos/robot/unitree/g1/g1agent.py new file mode 100644 index 0000000000..d537d41f65 --- /dev/null +++ b/dimos/robot/unitree/g1/g1agent.py @@ -0,0 +1,48 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos import agents2 +from dimos.agents2.skills.navigation import NavigationSkillContainer +from dimos.core import DimosCluster +from dimos.perception import spatial_perception +from dimos.robot.unitree.g1 import g1detector + + +def deploy(dimos: DimosCluster, ip: str) -> None: + g1 = g1detector.deploy(dimos, ip) + + nav = g1.get("nav") + camera = g1.get("camera") + detector3d = g1.get("detector3d") + connection = g1.get("connection") + + spatialmem = spatial_perception.deploy(dimos, camera) + + navskills = dimos.deploy( + NavigationSkillContainer, + spatialmem, + nav, + detector3d, + ) + navskills.start() + + agent = agents2.deploy( + dimos, + "You are controling a humanoid robot", + skill_containers=[connection, nav, camera, spatialmem, navskills], + ) + agent.run_implicit_skill("current_position") + agent.run_implicit_skill("video_stream") + + return {"agent": agent, "spatialmem": spatialmem, **g1} diff --git a/dimos/robot/unitree/g1/g1detector.py b/dimos/robot/unitree/g1/g1detector.py new file mode 100644 index 0000000000..f7324f691b --- /dev/null +++ b/dimos/robot/unitree/g1/g1detector.py @@ -0,0 +1,46 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.core import DimosCluster +from dimos.perception.detection import module3D, moduleDB +from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector +from dimos.robot.unitree.g1 import g1zed + + +def deploy(dimos: DimosCluster, ip: str) -> None: + g1 = g1zed.deploy(dimos, ip) + + nav = g1.get("nav") + camera = g1.get("camera") + camerainfo = g1.get("camerainfo") + + person_detector = module3D.deploy( + dimos, + camerainfo, + camera=camera, + lidar=nav, + detector=YoloPersonDetector, + ) + + detector3d = moduleDB.deploy( + dimos, + camerainfo, + camera=camera, + lidar=nav, + filter=lambda det: det.class_id != 0, + ) + + # return {"detector3d": detector3d, **g1} + + return {"person_detector": person_detector, "detector3d": detector3d, **g1} diff --git a/dimos/robot/unitree/g1/g1zed.py b/dimos/robot/unitree/g1/g1zed.py new file mode 100644 index 0000000000..1919eb3c49 --- /dev/null +++ b/dimos/robot/unitree/g1/g1zed.py @@ -0,0 +1,70 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE +from dimos.core import DimosCluster, LCMTransport, pSHMTransport +from dimos.hardware.camera import zed +from dimos.hardware.camera.module import CameraModule +from dimos.hardware.camera.webcam import Webcam +from dimos.msgs.geometry_msgs import ( + Quaternion, + Transform, + Vector3, +) +from dimos.msgs.sensor_msgs import CameraInfo +from dimos.navigation import rosnav +from dimos.robot import foxglove_bridge +from dimos.robot.unitree.connection import g1 +from dimos.utils.logging_config import setup_logger + +logger = setup_logger(__name__) + + +def deploy_g1_monozed(dimos) -> CameraModule: + camera = dimos.deploy( + CameraModule, + frequency=4.0, + transform=Transform( + translation=Vector3(0.05, 0.0, 0.0), + rotation=Quaternion.from_euler(Vector3(0.0, 0.0, 0.0)), + frame_id="sensor", + child_frame_id="camera_link", + ), + hardware=lambda: Webcam( + camera_index=0, + frequency=5, + stereo_slice="left", + camera_info=zed.CameraInfo.SingleWebcam, + ), + ) + + camera.image.transport = pSHMTransport("/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE) + camera.camera_info.transport = LCMTransport("/camera_info", CameraInfo) + camera.start() + return camera + + +def deploy(dimos: DimosCluster, ip: str) -> None: + nav = rosnav.deploy(dimos) + connection = g1.deploy(dimos, ip, nav) + zedcam = deploy_g1_monozed(dimos) + + foxglove_bridge.deploy(dimos) + + return { + "nav": nav, + "connection": connection, + "camera": zedcam, + "camerainfo": zed.CameraInfo.SingleWebcam, + } diff --git a/dimos/robot/unitree_webrtc/modular/ivan_go2.py b/dimos/robot/unitree/go2/go2.py similarity index 88% rename from dimos/robot/unitree_webrtc/modular/ivan_go2.py rename to dimos/robot/unitree/go2/go2.py index d4d7a89704..251afdb5b3 100644 --- a/dimos/robot/unitree_webrtc/modular/ivan_go2.py +++ b/dimos/robot/unitree/go2/go2.py @@ -13,16 +13,15 @@ # limitations under the License. import logging -import time from dimos import agents2 from dimos.core import DimosCluster from dimos.perception.detection import moduleDB from dimos.robot import foxglove_bridge -from dimos.robot.unitree_webrtc.connection import go2 +from dimos.robot.unitree.connection import go2 from dimos.utils.logging_config import setup_logger -logger = setup_logger("dimos.robot.unitree_webrtc.unitree_go2", level=logging.INFO) +logger = setup_logger(__name__, level=logging.INFO) def deploy(dimos: DimosCluster, ip: str): diff --git a/dimos/robot/unitree/run.py b/dimos/robot/unitree/run.py new file mode 100644 index 0000000000..af822232f5 --- /dev/null +++ b/dimos/robot/unitree/run.py @@ -0,0 +1,95 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Centralized runner for modular G1 deployment scripts. + +Usage: + python run.py g1agent --ip 192.168.1.100 + python run.py g1zed + python run.py g1detector --ip $ROBOT_IP +""" + +import argparse +import importlib +import os +import sys + +from dotenv import load_dotenv + +from dimos.core import start, wait_exit + + +def main(): + load_dotenv() + + parser = argparse.ArgumentParser(description="Unitree G1 Modular Deployment Runner") + parser.add_argument( + "module", + help="Module name to run (e.g., g1agent, g1zed, g1detector)", + ) + parser.add_argument( + "--ip", + default=os.getenv("ROBOT_IP"), + help="Robot IP address (default: ROBOT_IP from .env)", + ) + parser.add_argument( + "--workers", + type=int, + default=8, + help="Number of worker threads for DimosCluster (default: 8)", + ) + + args = parser.parse_args() + + # Validate IP address + if not args.ip: + print("ERROR: Robot IP address not provided") + print("Please provide --ip or set ROBOT_IP in .env") + sys.exit(1) + + # Import the module + try: + # Try importing from current package first + module = importlib.import_module(f".{args.module}", package="dimos.robot.unitree.g1") + except ImportError as e: + import traceback + + traceback.print_exc() + + print(f"\nERROR: Could not import module '{args.module}'") + print(f"Make sure the module exists in dimos/robot/unitree/g1/") + print(f"Import error: {e}") + + sys.exit(1) + + # Verify deploy function exists + if not hasattr(module, "deploy"): + print(f"ERROR: Module '{args.module}' does not have a 'deploy' function") + sys.exit(1) + + print(f"Running {args.module}.deploy() with IP {args.ip}") + + # Run the standard deployment pattern + dimos = start(args.workers) + try: + module.deploy(dimos, args.ip) + wait_exit() + finally: + dimos.close_all() + + +if __name__ == "__main__": + main() diff --git a/dimos/robot/unitree/unitree_go2.py b/dimos/robot/unitree/unitree_go2.py deleted file mode 100644 index ca878e7134..0000000000 --- a/dimos/robot/unitree/unitree_go2.py +++ /dev/null @@ -1,208 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import multiprocessing -from typing import Optional, Union, List -import numpy as np -from dimos.robot.robot import Robot -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.skills.skills import AbstractRobotSkill, AbstractSkill, SkillLibrary -from reactivex.disposable import CompositeDisposable -import logging -import os -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from reactivex.scheduler import ThreadPoolScheduler -from dimos.utils.logging_config import setup_logger -from dimos.perception.person_tracker import PersonTrackingStream -from dimos.perception.object_tracker import ObjectTrackingStream -from dimos.robot.local_planner.local_planner import navigate_path_local -from dimos.robot.local_planner.vfh_local_planner import VFHPurePursuitPlanner -from dimos.robot.global_planner.planner import AstarPlanner -from dimos.types.costmap import Costmap -from dimos.types.robot_capabilities import RobotCapability -from dimos.types.vector import Vector - -# Set up logging -logger = setup_logger("dimos.robot.unitree.unitree_go2", level=logging.DEBUG) - -# UnitreeGo2 Print Colors (Magenta) -UNITREE_GO2_PRINT_COLOR = "\033[35m" -UNITREE_GO2_RESET_COLOR = "\033[0m" - - -class UnitreeGo2(Robot): - """Unitree Go2 robot implementation using ROS2 control interface. - - This class extends the base Robot class to provide specific functionality - for the Unitree Go2 quadruped robot using ROS2 for communication and control. - """ - - def __init__( - self, - video_provider=None, - output_dir: str = os.path.join(os.getcwd(), "assets", "output"), - skill_library: SkillLibrary = None, - robot_capabilities: List[RobotCapability] = None, - spatial_memory_collection: str = "spatial_memory", - new_memory: bool = False, - disable_video_stream: bool = False, - mock_connection: bool = False, - enable_perception: bool = True, - ): - """Initialize UnitreeGo2 robot with ROS control interface. - - Args: - video_provider: Provider for video streams - output_dir: Directory for output files - skill_library: Library of robot skills - robot_capabilities: List of robot capabilities - spatial_memory_collection: Collection name for spatial memory - new_memory: Whether to create new memory collection - disable_video_stream: Whether to disable video streaming - mock_connection: Whether to use mock connection for testing - enable_perception: Whether to enable perception streams and spatial memory - """ - # Create ROS control interface - ros_control = UnitreeROSControl( - node_name="unitree_go2", - video_provider=video_provider, - disable_video_stream=disable_video_stream, - mock_connection=mock_connection, - ) - - # Initialize skill library if not provided - if skill_library is None: - skill_library = MyUnitreeSkills() - - # Initialize base robot with connection interface - super().__init__( - connection_interface=ros_control, - output_dir=output_dir, - skill_library=skill_library, - capabilities=robot_capabilities - or [ - RobotCapability.LOCOMOTION, - RobotCapability.VISION, - RobotCapability.AUDIO, - ], - spatial_memory_collection=spatial_memory_collection, - new_memory=new_memory, - enable_perception=enable_perception, - ) - - if self.skill_library is not None: - for skill in self.skill_library: - if isinstance(skill, AbstractRobotSkill): - self.skill_library.create_instance(skill.__name__, robot=self) - if isinstance(self.skill_library, MyUnitreeSkills): - self.skill_library._robot = self - self.skill_library.init() - self.skill_library.initialize_skills() - - # Camera stuff - self.camera_intrinsics = [819.553492, 820.646595, 625.284099, 336.808987] - self.camera_pitch = np.deg2rad(0) # negative for downward pitch - self.camera_height = 0.44 # meters - - # Initialize UnitreeGo2-specific attributes - self.disposables = CompositeDisposable() - self.main_stream_obs = None - - # Initialize thread pool scheduler - self.optimal_thread_count = multiprocessing.cpu_count() - self.thread_pool_scheduler = ThreadPoolScheduler(self.optimal_thread_count // 2) - - # Initialize visual servoing if enabled - if not disable_video_stream: - self.video_stream_ros = self.get_video_stream(fps=8) - if enable_perception: - self.person_tracker = PersonTrackingStream( - camera_intrinsics=self.camera_intrinsics, - camera_pitch=self.camera_pitch, - camera_height=self.camera_height, - ) - self.object_tracker = ObjectTrackingStream( - camera_intrinsics=self.camera_intrinsics, - camera_pitch=self.camera_pitch, - camera_height=self.camera_height, - ) - person_tracking_stream = self.person_tracker.create_stream(self.video_stream_ros) - object_tracking_stream = self.object_tracker.create_stream(self.video_stream_ros) - - self.person_tracking_stream = person_tracking_stream - self.object_tracking_stream = object_tracking_stream - else: - # Video stream is available but perception tracking is disabled - self.person_tracker = None - self.object_tracker = None - self.person_tracking_stream = None - self.object_tracking_stream = None - else: - # Video stream is disabled - self.video_stream_ros = None - self.person_tracker = None - self.object_tracker = None - self.person_tracking_stream = None - self.object_tracking_stream = None - - # Initialize the local planner and create BEV visualization stream - # Note: These features require ROS-specific methods that may not be available on all connection interfaces - if hasattr(self.connection_interface, "topic_latest") and hasattr( - self.connection_interface, "transform_euler" - ): - self.local_planner = VFHPurePursuitPlanner( - get_costmap=self.connection_interface.topic_latest( - "/local_costmap/costmap", Costmap - ), - transform=self.connection_interface, - move_vel_control=self.connection_interface.move_vel_control, - robot_width=0.36, # Unitree Go2 width in meters - robot_length=0.6, # Unitree Go2 length in meters - max_linear_vel=0.5, - lookahead_distance=2.0, - visualization_size=500, # 500x500 pixel visualization - ) - - self.global_planner = AstarPlanner( - conservativism=20, # how close to obstacles robot is allowed to path plan - set_local_nav=lambda path, stop_event=None, goal_theta=None: navigate_path_local( - self, path, timeout=120.0, goal_theta=goal_theta, stop_event=stop_event - ), - get_costmap=self.connection_interface.topic_latest("map", Costmap), - get_robot_pos=lambda: self.connection_interface.transform_euler_pos("base_link"), - ) - - # Create the visualization stream at 5Hz - self.local_planner_viz_stream = self.local_planner.create_stream(frequency_hz=5.0) - else: - self.local_planner = None - self.global_planner = None - self.local_planner_viz_stream = None - - def get_skills(self) -> Optional[SkillLibrary]: - return self.skill_library - - def get_pose(self) -> dict: - """ - Get the current pose (position and rotation) of the robot in the map frame. - - Returns: - Dictionary containing: - - position: Vector (x, y, z) - - rotation: Vector (roll, pitch, yaw) in radians - """ - position_tuple, orientation_tuple = self.connection_interface.get_pose_odom_transform() - position = Vector(position_tuple[0], position_tuple[1], position_tuple[2]) - rotation = Vector(orientation_tuple[0], orientation_tuple[1], orientation_tuple[2]) - return {"position": position, "rotation": rotation} diff --git a/dimos/robot/unitree/unitree_ros_control.py b/dimos/robot/unitree/unitree_ros_control.py deleted file mode 100644 index 56e83cb30f..0000000000 --- a/dimos/robot/unitree/unitree_ros_control.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from go2_interfaces.msg import Go2State, IMU -from unitree_go.msg import WebRtcReq -from typing import Type -from sensor_msgs.msg import Image, CompressedImage, CameraInfo -from dimos.robot.ros_control import ROSControl, RobotMode -from dimos.utils.logging_config import setup_logger - -logger = setup_logger("dimos.robot.unitree.unitree_ros_control") - - -class UnitreeROSControl(ROSControl): - """Hardware interface for Unitree Go2 robot using ROS2""" - - # ROS Camera Topics - CAMERA_TOPICS = { - "raw": {"topic": "camera/image_raw", "type": Image}, - "compressed": {"topic": "camera/compressed", "type": CompressedImage}, - "info": {"topic": "camera/camera_info", "type": CameraInfo}, - } - # Hard coded ROS Message types and Topic names for Unitree Go2 - DEFAULT_STATE_MSG_TYPE = Go2State - DEFAULT_IMU_MSG_TYPE = IMU - DEFAULT_WEBRTC_MSG_TYPE = WebRtcReq - DEFAULT_STATE_TOPIC = "go2_states" - DEFAULT_IMU_TOPIC = "imu" - DEFAULT_WEBRTC_TOPIC = "webrtc_req" - DEFAULT_CMD_VEL_TOPIC = "cmd_vel_out" - DEFAULT_POSE_TOPIC = "pose_cmd" - DEFAULT_ODOM_TOPIC = "odom" - DEFAULT_COSTMAP_TOPIC = "local_costmap/costmap" - DEFAULT_MAX_LINEAR_VELOCITY = 1.0 - DEFAULT_MAX_ANGULAR_VELOCITY = 2.0 - - # Hard coded WebRTC API parameters for Unitree Go2 - DEFAULT_WEBRTC_API_TOPIC = "rt/api/sport/request" - - def __init__( - self, - node_name: str = "unitree_hardware_interface", - state_topic: str = None, - imu_topic: str = None, - webrtc_topic: str = None, - webrtc_api_topic: str = None, - move_vel_topic: str = None, - pose_topic: str = None, - odom_topic: str = None, - costmap_topic: str = None, - state_msg_type: Type = None, - imu_msg_type: Type = None, - webrtc_msg_type: Type = None, - max_linear_velocity: float = None, - max_angular_velocity: float = None, - use_raw: bool = False, - debug: bool = False, - disable_video_stream: bool = False, - mock_connection: bool = False, - ): - """ - Initialize Unitree ROS control interface with default values for Unitree Go2 - - Args: - node_name: Name for the ROS node - state_topic: ROS Topic name for robot state (defaults to DEFAULT_STATE_TOPIC) - imu_topic: ROS Topic name for IMU data (defaults to DEFAULT_IMU_TOPIC) - webrtc_topic: ROS Topic for WebRTC commands (defaults to DEFAULT_WEBRTC_TOPIC) - cmd_vel_topic: ROS Topic for direct movement velocity commands (defaults to DEFAULT_CMD_VEL_TOPIC) - pose_topic: ROS Topic for pose commands (defaults to DEFAULT_POSE_TOPIC) - odom_topic: ROS Topic for odometry data (defaults to DEFAULT_ODOM_TOPIC) - costmap_topic: ROS Topic for local costmap data (defaults to DEFAULT_COSTMAP_TOPIC) - state_msg_type: ROS Message type for state data (defaults to DEFAULT_STATE_MSG_TYPE) - imu_msg_type: ROS message type for IMU data (defaults to DEFAULT_IMU_MSG_TYPE) - webrtc_msg_type: ROS message type for webrtc data (defaults to DEFAULT_WEBRTC_MSG_TYPE) - max_linear_velocity: Maximum linear velocity in m/s (defaults to DEFAULT_MAX_LINEAR_VELOCITY) - max_angular_velocity: Maximum angular velocity in rad/s (defaults to DEFAULT_MAX_ANGULAR_VELOCITY) - use_raw: Whether to use raw camera topics (defaults to False) - debug: Whether to enable debug logging - disable_video_stream: Whether to run without video stream for testing. - mock_connection: Whether to run without active ActionClient servers for testing. - """ - - logger.info("Initializing Unitree ROS control interface") - # Select which camera topics to use - active_camera_topics = None - if not disable_video_stream: - active_camera_topics = {"main": self.CAMERA_TOPICS["raw" if use_raw else "compressed"]} - - # Use default values if not provided - state_topic = state_topic or self.DEFAULT_STATE_TOPIC - imu_topic = imu_topic or self.DEFAULT_IMU_TOPIC - webrtc_topic = webrtc_topic or self.DEFAULT_WEBRTC_TOPIC - move_vel_topic = move_vel_topic or self.DEFAULT_CMD_VEL_TOPIC - pose_topic = pose_topic or self.DEFAULT_POSE_TOPIC - odom_topic = odom_topic or self.DEFAULT_ODOM_TOPIC - costmap_topic = costmap_topic or self.DEFAULT_COSTMAP_TOPIC - webrtc_api_topic = webrtc_api_topic or self.DEFAULT_WEBRTC_API_TOPIC - state_msg_type = state_msg_type or self.DEFAULT_STATE_MSG_TYPE - imu_msg_type = imu_msg_type or self.DEFAULT_IMU_MSG_TYPE - webrtc_msg_type = webrtc_msg_type or self.DEFAULT_WEBRTC_MSG_TYPE - max_linear_velocity = max_linear_velocity or self.DEFAULT_MAX_LINEAR_VELOCITY - max_angular_velocity = max_angular_velocity or self.DEFAULT_MAX_ANGULAR_VELOCITY - - super().__init__( - node_name=node_name, - camera_topics=active_camera_topics, - mock_connection=mock_connection, - state_topic=state_topic, - imu_topic=imu_topic, - state_msg_type=state_msg_type, - imu_msg_type=imu_msg_type, - webrtc_msg_type=webrtc_msg_type, - webrtc_topic=webrtc_topic, - webrtc_api_topic=webrtc_api_topic, - move_vel_topic=move_vel_topic, - pose_topic=pose_topic, - odom_topic=odom_topic, - costmap_topic=costmap_topic, - max_linear_velocity=max_linear_velocity, - max_angular_velocity=max_angular_velocity, - debug=debug, - ) - - # Unitree-specific RobotMode State update conditons - def _update_mode(self, msg: Go2State): - """ - Implementation of abstract method to update robot mode - - Logic: - - If progress is 0 and mode is 1, then state is IDLE - - If progress is 1 OR mode is NOT equal to 1, then state is MOVING - """ - # Direct access to protected instance variables from the parent class - mode = msg.mode - progress = msg.progress - - if progress == 0 and mode == 1: - self._mode = RobotMode.IDLE - logger.debug("Robot mode set to IDLE (progress=0, mode=1)") - elif progress == 1 or mode != 1: - self._mode = RobotMode.MOVING - logger.debug(f"Robot mode set to MOVING (progress={progress}, mode={mode})") - else: - self._mode = RobotMode.UNKNOWN - logger.debug(f"Robot mode set to UNKNOWN (progress={progress}, mode={mode})") diff --git a/dimos/robot/unitree/unitree_skills.py b/dimos/robot/unitree/unitree_skills.py deleted file mode 100644 index 5029123ed1..0000000000 --- a/dimos/robot/unitree/unitree_skills.py +++ /dev/null @@ -1,314 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -from typing import TYPE_CHECKING, List, Optional, Tuple, Union -import time -from pydantic import Field - -if TYPE_CHECKING: - from dimos.robot.robot import Robot, MockRobot -else: - Robot = "Robot" - MockRobot = "MockRobot" - -from dimos.skills.skills import AbstractRobotSkill, AbstractSkill, SkillLibrary -from dimos.types.constants import Colors -from dimos.types.vector import Vector - -# Module-level constant for Unitree ROS control definitions -UNITREE_ROS_CONTROLS: List[Tuple[str, int, str]] = [ - ("Damp", 1001, "Lowers the robot to the ground fully."), - ( - "BalanceStand", - 1002, - "Activates a mode that maintains the robot in a balanced standing position.", - ), - ( - "StandUp", - 1004, - "Commands the robot to transition from a sitting or prone position to a standing posture.", - ), - ( - "StandDown", - 1005, - "Instructs the robot to move from a standing position to a sitting or prone posture.", - ), - ( - "RecoveryStand", - 1006, - "Recovers the robot to a state from which it can take more commands. Useful to run after multiple dynamic commands like front flips.", - ), - # ( - # "Euler", - # 1007, - # "Adjusts the robot's orientation using Euler angles, providing precise control over its rotation.", - # ), - # ("Move", 1008, "Move the robot using velocity commands."), # Intentionally omitted - ("Sit", 1009, "Commands the robot to sit down from a standing or moving stance."), - # ( - # "RiseSit", - # 1010, - # "Commands the robot to rise back to a standing position from a sitting posture.", - # ), - # ( - # "SwitchGait", - # 1011, - # "Switches the robot's walking pattern or style dynamically, suitable for different terrains or speeds.", - # ), - # ("Trigger", 1012, "Triggers a specific action or custom routine programmed into the robot."), - # ( - # "BodyHeight", - # 1013, - # "Adjusts the height of the robot's body from the ground, useful for navigating various obstacles.", - # ), - # ( - # "FootRaiseHeight", - # 1014, - # "Controls how high the robot lifts its feet during movement, which can be adjusted for different surfaces.", - # ), - ( - "SpeedLevel", - 1015, - "Sets or adjusts the speed at which the robot moves, with various levels available for different operational needs.", - ), - ( - "ShakeHand", - 1016, - "Performs a greeting action, which could involve a wave or other friendly gesture.", - ), - ("Stretch", 1017, "Engages the robot in a stretching routine."), - # ( - # "TrajectoryFollow", - # 1018, - # "Directs the robot to follow a predefined trajectory, which could involve complex paths or maneuvers.", - # ), - # ( - # "ContinuousGait", - # 1019, - # "Enables a mode for continuous walking or running, ideal for long-distance travel.", - # ), - ("Content", 1020, "To display or trigger when the robot is happy."), - ("Wallow", 1021, "The robot falls onto its back and rolls around."), - ( - "Dance1", - 1022, - "Performs a predefined dance routine 1, programmed for entertainment or demonstration.", - ), - ("Dance2", 1023, "Performs another variant of a predefined dance routine 2."), - # ("GetBodyHeight", 1024, "Retrieves the current height of the robot's body from the ground."), - # ( - # "GetFootRaiseHeight", - # 1025, - # "Retrieves the current height at which the robot's feet are being raised during movement.", - # ), - # ("GetSpeedLevel", 1026, "Returns the current speed level at which the robot is operating."), - # ( - # "SwitchJoystick", - # 1027, - # "Toggles the control mode to joystick input, allowing for manual direction of the robot's movements.", - # ), - ( - "Pose", - 1028, - "Directs the robot to take a specific pose or stance, which could be used for tasks or performances.", - ), - ( - "Scrape", - 1029, - "Robot falls to its hind legs and makes scraping motions with its front legs.", - ), - ("FrontFlip", 1030, "Executes a front flip, a complex and dynamic maneuver."), - ("FrontJump", 1031, "Commands the robot to perform a forward jump."), - ( - "FrontPounce", - 1032, - "Initiates a pouncing movement forward, mimicking animal-like pouncing behavior.", - ), - # ("WiggleHips", 1033, "Causes the robot to wiggle its hips."), - # ( - # "GetState", - # 1034, - # "Retrieves the current operational state of the robot, including status reports or diagnostic information.", - # ), - # ( - # "EconomicGait", - # 1035, - # "Engages a more energy-efficient walking or running mode to conserve battery life.", - # ), - # ("FingerHeart", 1036, "Performs a finger heart gesture while on its hind legs."), - # ( - # "Handstand", - # 1301, - # "Commands the robot to perform a handstand, demonstrating balance and control.", - # ), - # ( - # "CrossStep", - # 1302, - # "Engages the robot in a cross-stepping routine, useful for complex locomotion or dance moves.", - # ), - # ( - # "OnesidedStep", - # 1303, - # "Commands the robot to perform a stepping motion that predominantly uses one side.", - # ), - # ( - # "Bound", - # 1304, - # "Initiates a bounding motion, similar to a light, repetitive hopping or leaping.", - # ), - # ( - # "LeadFollow", - # 1045, - # "Engages follow-the-leader behavior, where the robot follows a designated leader or follows a signal.", - # ), - # ("LeftFlip", 1042, "Executes a flip towards the left side."), - # ("RightFlip", 1043, "Performs a flip towards the right side."), - # ("Backflip", 1044, "Executes a backflip, a complex and dynamic maneuver."), -] - -# region MyUnitreeSkills - - -class MyUnitreeSkills(SkillLibrary): - """My Unitree Skills.""" - - _robot: Optional[Robot] = None - - @classmethod - def register_skills(cls, skill_classes: Union["AbstractSkill", list["AbstractSkill"]]): - """Add multiple skill classes as class attributes. - - Args: - skill_classes: List of skill classes to add - """ - if isinstance(skill_classes, list): - for skill_class in skill_classes: - setattr(cls, skill_class.__name__, skill_class) - else: - setattr(cls, skill_classes.__name__, skill_classes) - - def __init__(self, robot: Optional[Robot] = None): - super().__init__() - self._robot: Robot = None - - # Add dynamic skills to this class - self.register_skills(self.create_skills_live()) - - if robot is not None: - self._robot = robot - self.initialize_skills() - - def initialize_skills(self): - # Create the skills and add them to the list of skills - self.register_skills(self.create_skills_live()) - - # Provide the robot instance to each skill - for skill_class in self: - print( - f"{Colors.GREEN_PRINT_COLOR}Creating instance for skill: {skill_class}{Colors.RESET_COLOR}" - ) - self.create_instance(skill_class.__name__, robot=self._robot) - - # Refresh the class skills - self.refresh_class_skills() - - def create_skills_live(self) -> List[AbstractRobotSkill]: - # ================================================ - # Procedurally created skills - # ================================================ - class BaseUnitreeSkill(AbstractRobotSkill): - """Base skill for dynamic skill creation.""" - - def __call__(self): - string = f"{Colors.GREEN_PRINT_COLOR}This is a base skill, created for the specific skill: {self._app_id}{Colors.RESET_COLOR}" - print(string) - super().__call__() - if self._app_id is None: - raise RuntimeError( - f"{Colors.RED_PRINT_COLOR}" - f"No App ID provided to {self.__class__.__name__} Skill" - f"{Colors.RESET_COLOR}" - ) - else: - self._robot.webrtc_req(api_id=self._app_id) - string = f"{Colors.GREEN_PRINT_COLOR}{self.__class__.__name__} was successful: id={self._app_id}{Colors.RESET_COLOR}" - print(string) - return string - - skills_classes = [] - for name, app_id, description in UNITREE_ROS_CONTROLS: - skill_class = type( - name, # Name of the class - (BaseUnitreeSkill,), # Base classes - {"__doc__": description, "_app_id": app_id}, - ) - skills_classes.append(skill_class) - - return skills_classes - - # region Class-based Skills - - class Move(AbstractRobotSkill): - """Move the robot using direct velocity commands. Determine duration required based on user distance instructions.""" - - x: float = Field(..., description="Forward velocity (m/s).") - y: float = Field(default=0.0, description="Left/right velocity (m/s)") - yaw: float = Field(default=0.0, description="Rotational velocity (rad/s)") - duration: float = Field(default=0.0, description="How long to move (seconds).") - - def __call__(self): - super().__call__() - return self._robot.move(Vector(self.x, self.y, self.yaw), duration=self.duration) - - class Reverse(AbstractRobotSkill): - """Reverse the robot using direct velocity commands. Determine duration required based on user distance instructions.""" - - x: float = Field(..., description="Backward velocity (m/s). Positive values move backward.") - y: float = Field(default=0.0, description="Left/right velocity (m/s)") - yaw: float = Field(default=0.0, description="Rotational velocity (rad/s)") - duration: float = Field(default=0.0, description="How long to move (seconds).") - - def __call__(self): - super().__call__() - # Use move with negative x for backward movement - return self._robot.move(Vector(-self.x, self.y, self.yaw), duration=self.duration) - - class SpinLeft(AbstractRobotSkill): - """Spin the robot left using degree commands.""" - - degrees: float = Field(..., description="Distance to spin left in degrees") - - def __call__(self): - super().__call__() - return self._robot.spin(degrees=self.degrees) # Spinning left is positive degrees - - class SpinRight(AbstractRobotSkill): - """Spin the robot right using degree commands.""" - - degrees: float = Field(..., description="Distance to spin right in degrees") - - def __call__(self): - super().__call__() - return self._robot.spin(degrees=-self.degrees) # Spinning right is negative degrees - - class Wait(AbstractSkill): - """Wait for a specified amount of time.""" - - seconds: float = Field(..., description="Seconds to wait") - - def __call__(self): - time.sleep(self.seconds) - return f"Wait completed with length={self.seconds}s" diff --git a/dimos/robot/unitree_webrtc/connection.py b/dimos/robot/unitree_webrtc/connection.py new file mode 100644 index 0000000000..8ddc77ac63 --- /dev/null +++ b/dimos/robot/unitree_webrtc/connection.py @@ -0,0 +1,404 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import functools +import threading +import time +from dataclasses import dataclass +from typing import Literal, Optional, TypeAlias + +import numpy as np +from aiortc import MediaStreamTrack +from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD, VUI_COLOR +from go2_webrtc_driver.webrtc_driver import ( # type: ignore[import-not-found] + Go2WebRTCConnection, + WebRTCConnectionMethod, +) +from reactivex import operators as ops +from reactivex.observable import Observable +from reactivex.subject import Subject + +from dimos.core import In, Module, Out, rpc +from dimos.core.resource import Resource +from dimos.msgs.geometry_msgs import Pose, Transform, Twist, Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.robot.connection_interface import ConnectionInterface +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.utils.decorators.decorators import simple_mcache +from dimos.utils.reactive import backpressure, callback_to_observable + +VideoMessage: TypeAlias = np.ndarray[tuple[int, int, Literal[3]], np.uint8] + + +@dataclass +class SerializableVideoFrame: + """Pickleable wrapper for av.VideoFrame with all metadata""" + + data: np.ndarray + pts: Optional[int] = None + time: Optional[float] = None + dts: Optional[int] = None + width: Optional[int] = None + height: Optional[int] = None + format: Optional[str] = None + + @classmethod + def from_av_frame(cls, frame): + return cls( + data=frame.to_ndarray(format="rgb24"), + pts=frame.pts, + time=frame.time, + dts=frame.dts, + width=frame.width, + height=frame.height, + format=frame.format.name if hasattr(frame, "format") and frame.format else None, + ) + + def to_ndarray(self, format=None): + return self.data + + +class UnitreeWebRTCConnection(Resource): + def __init__(self, ip: str, mode: str = "ai"): + self.ip = ip + self.mode = mode + self.stop_timer = None + self.cmd_vel_timeout = 0.2 + self.conn = Go2WebRTCConnection(WebRTCConnectionMethod.LocalSTA, ip=self.ip) + self.connect() + + def connect(self): + self.loop = asyncio.new_event_loop() + self.task = None + self.connected_event = asyncio.Event() + self.connection_ready = threading.Event() + + async def async_connect(): + await self.conn.connect() + await self.conn.datachannel.disableTrafficSaving(True) + + self.conn.datachannel.set_decoder(decoder_type="native") + + await self.conn.datachannel.pub_sub.publish_request_new( + RTC_TOPIC["MOTION_SWITCHER"], {"api_id": 1002, "parameter": {"name": self.mode}} + ) + + self.connected_event.set() + self.connection_ready.set() + + while True: + await asyncio.sleep(1) + + def start_background_loop(): + asyncio.set_event_loop(self.loop) + self.task = self.loop.create_task(async_connect()) + self.loop.run_forever() + + self.loop = asyncio.new_event_loop() + self.thread = threading.Thread(target=start_background_loop, daemon=True) + self.thread.start() + self.connection_ready.wait() + + def start(self) -> None: + pass + + def stop(self) -> None: + # Cancel timer + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + + if self.task: + self.task.cancel() + + async def async_disconnect() -> None: + try: + await self.conn.disconnect() + except Exception: + pass + + if self.loop.is_running(): + asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) + + self.loop.call_soon_threadsafe(self.loop.stop) + + if self.thread.is_alive(): + self.thread.join(timeout=2.0) + + def move(self, twist: Twist, duration: float = 0.0) -> bool: + """Send movement command to the robot using Twist commands. + + Args: + twist: Twist message with linear and angular velocities + duration: How long to move (seconds). If 0, command is continuous + + Returns: + bool: True if command was sent successfully + """ + x, y, yaw = twist.linear.x, twist.linear.y, twist.angular.z + + # WebRTC coordinate mapping: + # x - Positive right, negative left + # y - positive forward, negative backwards + # yaw - Positive rotate right, negative rotate left + async def async_move(): + self.conn.datachannel.pub_sub.publish_without_callback( + RTC_TOPIC["WIRELESS_CONTROLLER"], + data={"lx": -y, "ly": x, "rx": -yaw, "ry": 0}, + ) + + async def async_move_duration(): + """Send movement commands continuously for the specified duration.""" + start_time = time.time() + sleep_time = 0.01 + + while time.time() - start_time < duration: + await async_move() + await asyncio.sleep(sleep_time) + + # Cancel existing timer and start a new one + if self.stop_timer: + self.stop_timer.cancel() + + # Auto-stop after 0.5 seconds if no new commands + self.stop_timer = threading.Timer(self.cmd_vel_timeout, self.stop) + self.stop_timer.daemon = True + self.stop_timer.start() + + try: + if duration > 0: + # Send continuous move commands for the duration + future = asyncio.run_coroutine_threadsafe(async_move_duration(), self.loop) + future.result() + # Stop after duration + self.stop() + else: + # Single command for continuous movement + future = asyncio.run_coroutine_threadsafe(async_move(), self.loop) + future.result() + return True + except Exception as e: + print(f"Failed to send movement command: {e}") + return False + + # Generic conversion of unitree subscription to Subject (used for all subs) + def unitree_sub_stream(self, topic_name: str): + def subscribe_in_thread(cb): + # Run the subscription in the background thread that has the event loop + def run_subscription(): + self.conn.datachannel.pub_sub.subscribe(topic_name, cb) + + # Use call_soon_threadsafe to run in the background thread + self.loop.call_soon_threadsafe(run_subscription) + + def unsubscribe_in_thread(cb): + # Run the unsubscription in the background thread that has the event loop + def run_unsubscription(): + self.conn.datachannel.pub_sub.unsubscribe(topic_name) + + # Use call_soon_threadsafe to run in the background thread + self.loop.call_soon_threadsafe(run_unsubscription) + + return callback_to_observable( + start=subscribe_in_thread, + stop=unsubscribe_in_thread, + ) + + # Generic sync API call (we jump into the client thread) + def publish_request(self, topic: str, data: dict): + future = asyncio.run_coroutine_threadsafe( + self.conn.datachannel.pub_sub.publish_request_new(topic, data), self.loop + ) + return future.result() + + @simple_mcache + def raw_lidar_stream(self) -> Subject[LidarMessage]: + return backpressure(self.unitree_sub_stream(RTC_TOPIC["ULIDAR_ARRAY"])) + + @simple_mcache + def raw_odom_stream(self) -> Subject[Pose]: + return backpressure(self.unitree_sub_stream(RTC_TOPIC["ROBOTODOM"])) + + @simple_mcache + def lidar_stream(self) -> Subject[LidarMessage]: + return backpressure( + self.raw_lidar_stream().pipe( + ops.map(lambda raw_frame: LidarMessage.from_msg(raw_frame, ts=time.time())) + ) + ) + + @simple_mcache + def tf_stream(self) -> Subject[Transform]: + base_link = functools.partial(Transform.from_pose, "base_link") + return backpressure(self.odom_stream().pipe(ops.map(base_link))) + + @simple_mcache + def odom_stream(self) -> Subject[Pose]: + return backpressure(self.raw_odom_stream().pipe(ops.map(Odometry.from_msg))) + + @simple_mcache + def video_stream(self) -> Observable[Image]: + return backpressure( + self.raw_video_stream().pipe( + ops.filter(lambda frame: frame is not None), + ops.map( + lambda frame: Image.from_numpy( + # np.ascontiguousarray(frame.to_ndarray("rgb24")), + frame.to_ndarray(format="rgb24"), + frame_id="camera_optical", + ) + ), + ) + ) + + @simple_mcache + def lowstate_stream(self) -> Subject[LowStateMsg]: + return backpressure(self.unitree_sub_stream(RTC_TOPIC["LOW_STATE"])) + + def standup_ai(self): + return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]}) + + def standup_normal(self): + self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]}) + time.sleep(0.5) + self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["RecoveryStand"]}) + return True + + @rpc + def standup(self): + if self.mode == "ai": + return self.standup_ai() + else: + return self.standup_normal() + + @rpc + def liedown(self): + return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) + + async def handstand(self): + return self.publish_request( + RTC_TOPIC["SPORT_MOD"], + {"api_id": SPORT_CMD["Standup"], "parameter": {"data": True}}, + ) + + @rpc + def color(self, color: VUI_COLOR = VUI_COLOR.RED, colortime: int = 60) -> bool: + return self.publish_request( + RTC_TOPIC["VUI"], + { + "api_id": 1001, + "parameter": { + "color": color, + "time": colortime, + }, + }, + ) + + @simple_mcache + def raw_video_stream(self) -> Observable[VideoMessage]: + subject: Subject[VideoMessage] = Subject() + stop_event = threading.Event() + + async def accept_track(track: MediaStreamTrack) -> VideoMessage: + while True: + if stop_event.is_set(): + return + frame = await track.recv() + serializable_frame = SerializableVideoFrame.from_av_frame(frame) + subject.on_next(serializable_frame) + + self.conn.video.add_track_callback(accept_track) + + # Run the video channel switching in the background thread + def switch_video_channel(): + self.conn.video.switchVideoChannel(True) + + self.loop.call_soon_threadsafe(switch_video_channel) + + def stop(): + stop_event.set() # Signal the loop to stop + self.conn.video.track_callbacks.remove(accept_track) + + # Run the video channel switching off in the background thread + def switch_video_channel_off(): + self.conn.video.switchVideoChannel(False) + + self.loop.call_soon_threadsafe(switch_video_channel_off) + + return subject.pipe(ops.finally_action(stop)) + + def get_video_stream(self, fps: int = 30) -> Observable[VideoMessage]: + """Get the video stream from the robot's camera. + + Implements the AbstractRobot interface method. + + Args: + fps: Frames per second. This parameter is included for API compatibility, + but doesn't affect the actual frame rate which is determined by the camera. + + Returns: + Observable: An observable stream of video frames or None if video is not available. + """ + try: + print("Starting WebRTC video stream...") + stream = self.video_stream() + if stream is None: + print("Warning: Video stream is not available") + return stream + + except Exception as e: + print(f"Error getting video stream: {e}") + return None + + def stop(self) -> bool: + """Stop the robot's movement. + + Returns: + bool: True if stop command was sent successfully + """ + # Cancel timer since we're explicitly stopping + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + + return self.move(Twist()) + + def disconnect(self) -> None: + """Disconnect from the robot and clean up resources.""" + # Cancel timer + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + + if hasattr(self, "task") and self.task: + self.task.cancel() + if hasattr(self, "conn"): + + async def async_disconnect(): + try: + await self.conn.disconnect() + except: + pass + + if hasattr(self, "loop") and self.loop.is_running(): + asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) + + if hasattr(self, "loop") and self.loop.is_running(): + self.loop.call_soon_threadsafe(self.loop.stop) + + if hasattr(self, "thread") and self.thread.is_alive(): + self.thread.join(timeout=2.0) diff --git a/dimos/robot/unitree_webrtc/connection/__init__.py b/dimos/robot/unitree_webrtc/connection/__init__.py index cd93ef78ac..2a6a983761 100644 --- a/dimos/robot/unitree_webrtc/connection/__init__.py +++ b/dimos/robot/unitree_webrtc/connection/__init__.py @@ -1 +1,4 @@ import dimos.robot.unitree_webrtc.connection.g1 as g1 +import dimos.robot.unitree_webrtc.connection.go2 as go2 + +__all__ = ["g1", "go2"] diff --git a/dimos/robot/unitree_webrtc/modular/__init__.py b/dimos/robot/unitree_webrtc/modular/__init__.py index 21d37d2dbd..d823cd796e 100644 --- a/dimos/robot/unitree_webrtc/modular/__init__.py +++ b/dimos/robot/unitree_webrtc/modular/__init__.py @@ -1,2 +1,2 @@ -# from dimos.robot.unitree_webrtc.modular.connection_module import deploy_connection -# from dimos.robot.unitree_webrtc.modular.navigation import deploy_navigation +from dimos.robot.unitree_webrtc.modular.connection_module import deploy_connection +from dimos.robot.unitree_webrtc.modular.navigation import deploy_navigation diff --git a/dimos/robot/unitree_webrtc/modular/ivan_unitree.py b/dimos/robot/unitree_webrtc/modular/ivan_unitree.py index a9aa986e2e..948dccaa16 100644 --- a/dimos/robot/unitree_webrtc/modular/ivan_unitree.py +++ b/dimos/robot/unitree_webrtc/modular/ivan_unitree.py @@ -31,7 +31,7 @@ from dimos.perception.detection.reid import ReidModule from dimos.protocol.pubsub import lcm from dimos.robot.foxglove_bridge import FoxgloveBridge -from dimos.robot.unitree_webrtc.connection import go2 +from dimos.robot.unitree_webrtc.modular import deploy_connection, deploy_navigation from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule from dimos.utils.logging_config import setup_logger @@ -40,7 +40,7 @@ def detection_unitree(): dimos = start(8) - connection = go2.deploy(dimos) + connection = deploy_connection(dimos) def goto(pose): print("NAVIGATION REQUESTED:", pose) diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py index fbe9117c4a..a3109e24f3 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/unitree_go2.py @@ -22,8 +22,6 @@ import warnings from typing import Optional -from dimos_lcm.sensor_msgs import CameraInfo -from dimos_lcm.std_msgs import Bool, String from reactivex import Observable from reactivex.disposable import CompositeDisposable @@ -33,40 +31,44 @@ from dimos.core.dimos import Dimos from dimos.core.resource import Resource from dimos.mapping.types import LatLon -from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 +from dimos.msgs.std_msgs import Header +from dimos.msgs.geometry_msgs import PoseStamped, Transform, Twist, Vector3, Quaternion from dimos.msgs.nav_msgs import OccupancyGrid, Path from dimos.msgs.sensor_msgs import Image -from dimos.msgs.std_msgs import Header from dimos.msgs.vision_msgs import Detection2DArray -from dimos.navigation.bbox_navigation import BBoxNavigationModule -from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator, NavigatorState -from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer -from dimos.navigation.global_planner import AstarPlanner -from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner +from dimos_lcm.std_msgs import String +from dimos_lcm.sensor_msgs import CameraInfo +from dimos.perception.spatial_perception import SpatialMemory from dimos.perception.common.utils import ( load_camera_info, load_camera_info_opencv, rectify_image, ) -from dimos.perception.object_tracker_2d import ObjectTracker2D -from dimos.perception.spatial_perception import SpatialMemory from dimos.protocol import pubsub from dimos.protocol.pubsub.lcmpubsub import LCM, Topic from dimos.protocol.tf import TF from dimos.robot.foxglove_bridge import FoxgloveBridge -from dimos.robot.robot import UnitreeRobot +from dimos.utils.monitoring import UtilizationModule +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos.navigation.global_planner import AstarPlanner +from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner +from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator, NavigatorState +from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.map import Map from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills from dimos.skills.skills import AbstractRobotSkill, SkillLibrary -from dimos.types.robot_capabilities import RobotCapability from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger -from dimos.utils.monitoring import UtilizationModule from dimos.utils.testing import TimedSensorReplay -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule +from dimos.perception.object_tracker_2d import ObjectTracker2D +from dimos.navigation.bbox_navigation import BBoxNavigationModule +from dimos_lcm.std_msgs import Bool +from dimos.robot.robot import UnitreeRobot +from dimos.types.robot_capabilities import RobotCapability + logger = setup_logger(__file__, level=logging.INFO) @@ -676,3 +678,26 @@ def get_odom(self) -> PoseStamped: The robot's odometry """ return self.connection.get_odom() + + +def main(): + """Main entry point.""" + ip = os.getenv("ROBOT_IP") + connection_type = os.getenv("CONNECTION_TYPE", "webrtc") + + pubsub.lcm.autoconf() + + robot = UnitreeGo2(ip=ip, websocket_port=7779, connection_type=connection_type) + robot.start() + + try: + while True: + time.sleep(0.1) + except KeyboardInterrupt: + pass + finally: + robot.stop() + + +if __name__ == "__main__": + main() From 510ea78de92e8faef873cefa736fd06af8dfcef7 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 24 Oct 2025 18:06:30 +0300 Subject: [PATCH 26/40] typing fixes in progress --- dimos/robot/unitree/connection/connection.py | 42 ++++++------- dimos/robot/unitree/connection/g1.py | 18 +++--- dimos/robot/unitree/connection/go2.py | 66 +++++++++++--------- dimos/robot/unitree/g1/g1zed.py | 44 ++++++++----- dimos/robot/unitree_webrtc/type/lidar.py | 4 +- 5 files changed, 97 insertions(+), 77 deletions(-) diff --git a/dimos/robot/unitree/connection/connection.py b/dimos/robot/unitree/connection/connection.py index 6fc2657318..c31bbf6d73 100644 --- a/dimos/robot/unitree/connection/connection.py +++ b/dimos/robot/unitree/connection/connection.py @@ -17,9 +17,10 @@ import threading import time from dataclasses import dataclass -from typing import Literal, Optional, TypeAlias +from typing import Literal, Optional, Type, TypeAlias import numpy as np +from numpy.typing import NDArray from aiortc import MediaStreamTrack from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD, VUI_COLOR from go2_webrtc_driver.webrtc_driver import ( # type: ignore[import-not-found] @@ -40,7 +41,7 @@ from dimos.utils.decorators.decorators import simple_mcache from dimos.utils.reactive import backpressure, callback_to_observable -VideoMessage: TypeAlias = np.ndarray[tuple[int, int, Literal[3]], np.uint8] +VideoMessage: TypeAlias = NDArray[np.uint8] # Shape: (height, width, 3) @dataclass @@ -75,7 +76,7 @@ class UnitreeWebRTCConnection(Resource): def __init__(self, ip: str, mode: str = "ai"): self.ip = ip self.mode = mode - self.stop_timer = None + self.stop_timer: Optional[threading.Timer] = None self.cmd_vel_timeout = 0.2 self.conn = Go2WebRTCConnection(WebRTCConnectionMethod.LocalSTA, ip=self.ip) self.connect() @@ -126,6 +127,7 @@ def stop(self) -> None: async def async_disconnect() -> None: try: + self.move(Twist()) await self.conn.disconnect() except Exception: pass @@ -225,28 +227,28 @@ def publish_request(self, topic: str, data: dict): return future.result() @simple_mcache - def raw_lidar_stream(self) -> Subject[LidarMessage]: + def raw_lidar_stream(self) -> Observable[LidarMessage]: return backpressure(self.unitree_sub_stream(RTC_TOPIC["ULIDAR_ARRAY"])) @simple_mcache - def raw_odom_stream(self) -> Subject[Pose]: + def raw_odom_stream(self) -> Observable[Pose]: return backpressure(self.unitree_sub_stream(RTC_TOPIC["ROBOTODOM"])) @simple_mcache - def lidar_stream(self) -> Subject[LidarMessage]: + def lidar_stream(self) -> Observable[LidarMessage]: return backpressure( self.raw_lidar_stream().pipe( - ops.map(lambda raw_frame: LidarMessage.from_msg(raw_frame, ts=time.time())) + ops.map(lambda raw_frame: LidarMessage.from_msg(raw_frame, ts=time.time())) # type: ignore[arg-type] ) ) @simple_mcache - def tf_stream(self) -> Subject[Transform]: + def tf_stream(self) -> Observable[Transform]: base_link = functools.partial(Transform.from_pose, "base_link") return backpressure(self.odom_stream().pipe(ops.map(base_link))) @simple_mcache - def odom_stream(self) -> Subject[Pose]: + def odom_stream(self) -> Observable[Pose]: return backpressure(self.raw_odom_stream().pipe(ops.map(Odometry.from_msg))) @simple_mcache @@ -257,7 +259,7 @@ def video_stream(self) -> Observable[Image]: ops.map( lambda frame: Image.from_numpy( # np.ascontiguousarray(frame.to_ndarray("rgb24")), - frame.to_ndarray(format="rgb24"), + frame.to_ndarray(format="rgb24"), # type: ignore[attr-defined] frame_id="camera_optical", ) ), @@ -265,7 +267,7 @@ def video_stream(self) -> Observable[Image]: ) @simple_mcache - def lowstate_stream(self) -> Subject[LowStateMsg]: + def lowstate_stream(self) -> Observable[LowStateMsg]: return backpressure(self.unitree_sub_stream(RTC_TOPIC["LOW_STATE"])) def standup_ai(self): @@ -312,7 +314,7 @@ def raw_video_stream(self) -> Observable[VideoMessage]: subject: Subject[VideoMessage] = Subject() stop_event = threading.Event() - async def accept_track(track: MediaStreamTrack) -> VideoMessage: + async def accept_track(track: MediaStreamTrack) -> None: while True: if stop_event.is_set(): return @@ -352,16 +354,9 @@ def get_video_stream(self, fps: int = 30) -> Observable[VideoMessage]: Returns: Observable: An observable stream of video frames or None if video is not available. """ - try: - print("Starting WebRTC video stream...") - stream = self.video_stream() - if stream is None: - print("Warning: Video stream is not available") - return stream - - except Exception as e: - print(f"Error getting video stream: {e}") - return None + print("Starting WebRTC video stream...") + stream = self.video_stream() + return stream def stop(self) -> bool: """Stop the robot's movement. @@ -373,8 +368,7 @@ def stop(self) -> bool: if self.stop_timer: self.stop_timer.cancel() self.stop_timer = None - - return self.move(Twist()) + return True def disconnect(self) -> None: """Disconnect from the robot and clean up resources.""" diff --git a/dimos/robot/unitree/connection/g1.py b/dimos/robot/unitree/connection/g1.py index 299631179a..8437404852 100644 --- a/dimos/robot/unitree/connection/g1.py +++ b/dimos/robot/unitree/connection/g1.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Optional, cast from dimos import spec -from dimos.core import DimosCluster, In, Module, rpc +from dimos.core import DimosCluster, In, Module, RPCClient, rpc from dimos.msgs.geometry_msgs import ( Twist, TwistStamped, @@ -25,19 +25,21 @@ class G1Connection(Module): cmd_vel: In[TwistStamped] = None # type: ignore - ip: str + ip: Optional[str] + + connection: UnitreeWebRTCConnection def __init__(self, ip: Optional[str] = None, **kwargs): super().__init__(**kwargs) - self.ip = ip - self.connection: Optional[UnitreeWebRTCConnection] = None + + if ip is None: + raise ValueError("IP address must be provided for G1") + self.connection = UnitreeWebRTCConnection(ip) @rpc def start(self): super().start() - self.connection = UnitreeWebRTCConnection(self.ip) self.connection.start() - self._disposables.add( self.cmd_vel.subscribe(self.move), ) @@ -60,7 +62,7 @@ def publish_request(self, topic: str, data: dict): def deploy(dimos: DimosCluster, ip: str, local_planner: spec.LocalPlanner) -> G1Connection: - connection = dimos.deploy(G1Connection, ip) + connection = cast(G1Connection, dimos.deploy(G1Connection, ip)) connection.cmd_vel.connect(local_planner.cmd_vel) connection.start() return connection diff --git a/dimos/robot/unitree/connection/go2.py b/dimos/robot/unitree/connection/go2.py index dade12cf0e..8ea903df92 100644 --- a/dimos/robot/unitree/connection/go2.py +++ b/dimos/robot/unitree/connection/go2.py @@ -14,9 +14,11 @@ import logging import time -from typing import List, Optional +from threading import Thread +from typing import List, Optional, Protocol from dimos_lcm.sensor_msgs import CameraInfo +from reactivex.observable import Observable from dimos.core import DimosCluster, In, LCMTransport, Module, Out, pSHMTransport, rpc from dimos.msgs.geometry_msgs import ( @@ -39,6 +41,20 @@ logger = setup_logger(__file__, level=logging.INFO) +class Go2ConnectionProtocol(Protocol): + """Protocol defining the interface for Go2 robot connections.""" + + def start(self) -> None: ... + def stop(self) -> None: ... + def lidar_stream(self) -> Observable: ... + def odom_stream(self) -> Observable: ... + def video_stream(self) -> Observable: ... + def move(self, twist: Twist, duration: float = 0.0) -> bool: ... + def standup(self) -> None: ... + def liedown(self) -> None: ... + def publish_request(self, topic: str, data: dict) -> dict: ... + + def _camera_info() -> CameraInfo: fx, fy, cx, cy = (819.553492, 820.646595, 625.284099, 336.808987) width, height = (1280, 720) @@ -74,7 +90,7 @@ def _camera_info() -> CameraInfo: camera_info = _camera_info() -class FakeRTC(UnitreeWebRTCConnection): +class ReplayConnection(UnitreeWebRTCConnection): dir_name = "unitree_go2_office_walk2" # we don't want UnitreeWebRTCConnection to init @@ -130,24 +146,32 @@ def publish_request(self, topic: str, data: dict): class GO2Connection(Module): - cmd_vel: In[Twist] = None - pointcloud: Out[LidarMessage] = None - image: Out[Image] = None - camera_info: Out[CameraInfo] = None + cmd_vel: In[Twist] = None # type: ignore + pointcloud: Out[LidarMessage] = None # type: ignore + image: Out[Image] = None # type: ignore + camera_info: Out[CameraInfo] = None # type: ignore connection_type: str = "webrtc" - ip: str + connection: Go2ConnectionProtocol + + ip: Optional[str] def __init__( self, ip: Optional[str] = None, - connection_type: str = "webrtc", - rectify_image: bool = True, *args, **kwargs, ): - self.ip = ip - self.connection: Optional[UnitreeWebRTCConnection] = None + match ip: + case None | "fake" | "mock" | "replay": + self.connection = ReplayConnection() + case "mujoco": + from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection + + self.connection = MujocoConnection() + case _: + self.connection = UnitreeWebRTCConnection(ip) + Module.__init__(self, *args, **kwargs) @rpc @@ -155,16 +179,6 @@ def start(self) -> None: """Start the connection and subscribe to sensor streams.""" super().start() - match self.ip: - case None | "fake" | "": - self.connection = FakeRTC() - case "mujoco": - from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection - - self.connection = MujocoConnection() - case _: - self.connection = UnitreeWebRTCConnection(self.ip) - self.connection.start() self._disposables.add( @@ -179,12 +193,7 @@ def start(self) -> None: self.connection.video_stream().subscribe(self.image.publish), ) - self._disposables.add( - self.cmd_vel.subscribe(self.move), - ) - - # Start publishing camera info at 1 Hz - from threading import Thread + self.cmd_vel.subscribe(self.move) self._camera_info_thread = Thread( target=self.publish_camera_info, @@ -285,4 +294,5 @@ def deploy(dimos: DimosCluster, ip: str, prefix="") -> GO2Connection: connection.cmd_vel.transport = LCMTransport(f"{prefix}/cmd_vel", TwistStamped) connection.camera_info.transport = LCMTransport(f"{prefix}/camera_info", CameraInfo) connection.start() - return connection + + return connection # type: ignore diff --git a/dimos/robot/unitree/g1/g1zed.py b/dimos/robot/unitree/g1/g1zed.py index 1919eb3c49..5641691c20 100644 --- a/dimos/robot/unitree/g1/g1zed.py +++ b/dimos/robot/unitree/g1/g1zed.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any, Dict, TypedDict, cast + from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE -from dimos.core import DimosCluster, LCMTransport, pSHMTransport +from dimos.core import DimosCluster, LCMTransport, RPCClient, pSHMTransport from dimos.hardware.camera import zed from dimos.hardware.camera.module import CameraModule from dimos.hardware.camera.webcam import Webcam @@ -24,28 +26,40 @@ ) from dimos.msgs.sensor_msgs import CameraInfo from dimos.navigation import rosnav +from dimos.navigation.rosnav import ROSNav from dimos.robot import foxglove_bridge from dimos.robot.unitree.connection import g1 +from dimos.robot.unitree.connection.g1 import G1Connection from dimos.utils.logging_config import setup_logger logger = setup_logger(__name__) -def deploy_g1_monozed(dimos) -> CameraModule: - camera = dimos.deploy( +class G1ZedDeployResult(TypedDict): + nav: ROSNav + connection: G1Connection + camera: CameraModule + camerainfo: CameraInfo + + +def deploy_g1_monozed(dimos: DimosCluster) -> CameraModule: + camera = cast( CameraModule, - frequency=4.0, - transform=Transform( - translation=Vector3(0.05, 0.0, 0.0), - rotation=Quaternion.from_euler(Vector3(0.0, 0.0, 0.0)), - frame_id="sensor", - child_frame_id="camera_link", - ), - hardware=lambda: Webcam( - camera_index=0, - frequency=5, - stereo_slice="left", - camera_info=zed.CameraInfo.SingleWebcam, + dimos.deploy( + CameraModule, + frequency=4.0, + transform=Transform( + translation=Vector3(0.05, 0.0, 0.0), + rotation=Quaternion.from_euler(Vector3(0.0, 0.0, 0.0)), + frame_id="sensor", + child_frame_id="camera_link", + ), + hardware=lambda: Webcam( + camera_index=0, + frequency=5, + stereo_slice="left", + camera_info=zed.CameraInfo.SingleWebcam, + ), ), ) diff --git a/dimos/robot/unitree_webrtc/type/lidar.py b/dimos/robot/unitree_webrtc/type/lidar.py index aefd9654e1..30fe3c587e 100644 --- a/dimos/robot/unitree_webrtc/type/lidar.py +++ b/dimos/robot/unitree_webrtc/type/lidar.py @@ -14,7 +14,7 @@ import time from copy import copy -from typing import List, Optional, TypedDict +from typing import List, Optional, Type, TypedDict import numpy as np import open3d as o3d @@ -65,7 +65,7 @@ def __init__(self, **kwargs): self.resolution = kwargs.get("resolution", 0.05) @classmethod - def from_msg(cls: "LidarMessage", raw_message: RawLidarMsg, **kwargs) -> "LidarMessage": + def from_msg(cls: Type["LidarMessage"], raw_message: RawLidarMsg, **kwargs) -> "LidarMessage": data = raw_message["data"] points = data["data"]["points"] pointcloud = o3d.geometry.PointCloud() From cf8b9134abd9d4994ca3fcde5efb3e11344e44d5 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 24 Oct 2025 18:12:02 +0300 Subject: [PATCH 27/40] run.py fix --- dimos/robot/unitree/run.py | 53 ++++++++++++++++++++++++++------------ 1 file changed, 37 insertions(+), 16 deletions(-) diff --git a/dimos/robot/unitree/run.py b/dimos/robot/unitree/run.py index af822232f5..17f1226fd8 100644 --- a/dimos/robot/unitree/run.py +++ b/dimos/robot/unitree/run.py @@ -14,18 +14,20 @@ # limitations under the License. """ -Centralized runner for modular G1 deployment scripts. +Centralized runner for modular Unitree robot deployment scripts. Usage: python run.py g1agent --ip 192.168.1.100 - python run.py g1zed - python run.py g1detector --ip $ROBOT_IP + python run.py g1/g1zed --ip $ROBOT_IP + python run.py go2/go2.py --ip $ROBOT_IP + python run.py connection/g1.py --ip $ROBOT_IP """ import argparse import importlib import os import sys +from pathlib import Path from dotenv import load_dotenv @@ -35,10 +37,10 @@ def main(): load_dotenv() - parser = argparse.ArgumentParser(description="Unitree G1 Modular Deployment Runner") + parser = argparse.ArgumentParser(description="Unitree Robot Modular Deployment Runner") parser.add_argument( "module", - help="Module name to run (e.g., g1agent, g1zed, g1detector)", + help="Module name/path to run (e.g., g1agent, g1/g1zed, go2/go2.py)", ) parser.add_argument( "--ip", @@ -60,20 +62,39 @@ def main(): print("Please provide --ip or set ROBOT_IP in .env") sys.exit(1) - # Import the module - try: - # Try importing from current package first - module = importlib.import_module(f".{args.module}", package="dimos.robot.unitree.g1") - except ImportError as e: - import traceback + # Parse the module path + module_path = args.module - traceback.print_exc() + # Remove .py extension if present + if module_path.endswith(".py"): + module_path = module_path[:-3] - print(f"\nERROR: Could not import module '{args.module}'") - print(f"Make sure the module exists in dimos/robot/unitree/g1/") - print(f"Import error: {e}") + # Convert path separators to dots for import + module_path = module_path.replace("/", ".") - sys.exit(1) + # Import the module + try: + # Build the full import path + full_module_path = f"dimos.robot.unitree.{module_path}" + print(f"Importing module: {full_module_path}") + module = importlib.import_module(full_module_path) + except ImportError as e: + # Try as a relative import from the unitree package + try: + module = importlib.import_module(f".{module_path}", package="dimos.robot.unitree") + except ImportError as e2: + import traceback + + traceback.print_exc() + + print(f"\nERROR: Could not import module '{args.module}'") + print(f"Tried importing as:") + print(f" 1. {full_module_path}") + print(f" 2. Relative import from dimos.robot.unitree") + print(f"Make sure the module exists in dimos/robot/unitree/") + print(f"Import error: {e2}") + + sys.exit(1) # Verify deploy function exists if not hasattr(module, "deploy"): From 547a56d2e147483cb0b9394005de26a6d70c3fa5 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 24 Oct 2025 18:48:54 +0300 Subject: [PATCH 28/40] type fixes finished --- dimos/core/__init__.py | 10 ++++--- dimos/core/dimos.py | 2 +- dimos/core/stream.py | 14 ++++++++++ dimos/hardware/camera/module.py | 9 +++++-- dimos/perception/detection/module3D.py | 3 +-- dimos/perception/detection/moduleDB.py | 3 +-- dimos/robot/unitree/connection/connection.py | 28 ++++++++++---------- dimos/robot/unitree/connection/g1.py | 4 +-- dimos/robot/unitree/connection/go2.py | 17 +++++++----- dimos/robot/unitree/g1/g1agent.py | 2 +- dimos/robot/unitree/g1/g1detector.py | 7 +---- dimos/robot/unitree/g1/g1zed.py | 5 ++-- dimos/robot/unitree/go2/go2.py | 1 - dimos/spec/perception.py | 9 ++++--- 14 files changed, 68 insertions(+), 46 deletions(-) diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 9bc954f3b0..0bab0b8b84 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -2,13 +2,14 @@ import multiprocessing as mp import time -from typing import Any, Optional, Protocol +from typing import Any, Optional, Protocol, Type, TypeVar from dask.distributed import Client, LocalCluster from rich.console import Console import dimos.core.colors as colors from dimos.core.core import rpc +from dimos.core.dimos import Dimos from dimos.core.module import Module, ModuleBase, ModuleConfig from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport from dimos.core.transport import ( @@ -155,6 +156,9 @@ def rpc_call(*args, **kwargs): return self.actor_instance.__getattr__(name) +T = TypeVar("T", bound="Module") + + class DimosCluster(Protocol): """Extended Dask Client with DimOS-specific methods. @@ -164,10 +168,10 @@ class DimosCluster(Protocol): def deploy( self, - actor_class: type, + actor_class: Type[T], *args: Any, **kwargs: Any, - ) -> RPCClient: + ) -> T: """Deploy an actor to the cluster and return an RPC client. Args: diff --git a/dimos/core/dimos.py b/dimos/core/dimos.py index d286284fec..be3ad11daa 100644 --- a/dimos/core/dimos.py +++ b/dimos/core/dimos.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Type, TypeVar +from typing import Cast, Optional, Type, TypeVar from dimos import core from dimos.core import DimosCluster, Module diff --git a/dimos/core/stream.py b/dimos/core/stream.py index 0a7f5fb17c..672ea4316e 100644 --- a/dimos/core/stream.py +++ b/dimos/core/stream.py @@ -138,6 +138,11 @@ def __init__(self, *argv, **kwargs): def transport(self) -> Transport[T]: return self._transport + @transport.setter + def transport(self, value: Transport[T]) -> None: + # just for type checking + ... + @property def state(self) -> State: # noqa: D401 return State.UNBOUND if self.owner is None else State.READY @@ -210,6 +215,15 @@ def transport(self) -> Transport[T]: self._transport = self.connection.transport return self._transport + @transport.setter + def transport(self, value: Transport[T]) -> None: + # just for type checking + ... + + def connect(self, value: Out[T]) -> None: + # just for type checking + ... + @property def state(self) -> State: # noqa: D401 return State.UNBOUND if self.owner is None else State.READY diff --git a/dimos/hardware/camera/module.py b/dimos/hardware/camera/module.py index 18aff8d91b..0dda51804d 100644 --- a/dimos/hardware/camera/module.py +++ b/dimos/hardware/camera/module.py @@ -23,6 +23,7 @@ from reactivex.disposable import Disposable from reactivex.observable import Observable +from dimos import spec from dimos.agents2 import Output, Reducer, Stream, skill from dimos.core import Module, ModuleConfig, Out, rpc from dimos.hardware.camera.spec import CameraHardware @@ -47,9 +48,9 @@ class CameraModuleConfig(ModuleConfig): frequency: float = 5.0 -class CameraModule(Module): +class CameraModule(Module, spec.Camera): image: Out[Image] = None - camera_info: Out[CameraInfo] = None + camera_info_stream: Out[CameraInfo] = None hardware: Callable[[], CameraHardware] | CameraHardware = None _module_subscription: Optional[Disposable] = None @@ -58,6 +59,10 @@ class CameraModule(Module): default_config = CameraModuleConfig + @property + def camera_info(self) -> CameraInfo: + return self.hardware.camera_info + @rpc def start(self): if callable(self.config.hardware): diff --git a/dimos/perception/detection/module3D.py b/dimos/perception/detection/module3D.py index 792acb1969..e46569c07a 100644 --- a/dimos/perception/detection/module3D.py +++ b/dimos/perception/detection/module3D.py @@ -197,7 +197,6 @@ def _publish_detections(self, detections: ImageDetections3DPC): def deploy( dimos: DimosCluster, - camera_info: CameraInfo, lidar: spec.Pointcloud, camera: spec.Camera, prefix: str = "/detector3d", @@ -205,7 +204,7 @@ def deploy( ) -> Detection3DModule: from dimos.core import LCMTransport - detector = dimos.deploy(Detection3DModule, camera_info=camera_info, **kwargs) + detector = dimos.deploy(Detection3DModule, camera_info=camera.camera_info, **kwargs) detector.image.connect(camera.image) detector.pointcloud.connect(lidar.pointcloud) diff --git a/dimos/perception/detection/moduleDB.py b/dimos/perception/detection/moduleDB.py index 959e3a6138..df0ecda0d7 100644 --- a/dimos/perception/detection/moduleDB.py +++ b/dimos/perception/detection/moduleDB.py @@ -325,7 +325,6 @@ def __len__(self): def deploy( dimos: DimosCluster, - camera_info: CameraInfo, lidar: spec.Pointcloud, camera: spec.Camera, prefix: str = "/detectorDB", @@ -333,7 +332,7 @@ def deploy( ) -> Detection3DModule: from dimos.core import LCMTransport - detector = dimos.deploy(ObjectDBModule, camera_info=camera_info, **kwargs) + detector = dimos.deploy(ObjectDBModule, camera_info=camera.camera_info, **kwargs) detector.image.connect(camera.image) detector.pointcloud.connect(lidar.pointcloud) diff --git a/dimos/robot/unitree/connection/connection.py b/dimos/robot/unitree/connection/connection.py index c31bbf6d73..327e4b8410 100644 --- a/dimos/robot/unitree/connection/connection.py +++ b/dimos/robot/unitree/connection/connection.py @@ -20,20 +20,20 @@ from typing import Literal, Optional, Type, TypeAlias import numpy as np -from numpy.typing import NDArray from aiortc import MediaStreamTrack from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD, VUI_COLOR from go2_webrtc_driver.webrtc_driver import ( # type: ignore[import-not-found] Go2WebRTCConnection, WebRTCConnectionMethod, ) +from numpy.typing import NDArray from reactivex import operators as ops from reactivex.observable import Observable from reactivex.subject import Subject -from dimos.core import DimosCluster, In, Module, Out, rpc +from dimos.core import rpc from dimos.core.resource import Resource -from dimos.msgs.geometry_msgs import Pose, Transform, Twist, Vector3 +from dimos.msgs.geometry_msgs import Pose, Transform, Twist from dimos.msgs.sensor_msgs import Image from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg @@ -397,16 +397,16 @@ async def async_disconnect(): self.thread.join(timeout=2.0) -def deploy(dimos: DimosCluster, ip: str) -> None: - from dimos.robot.foxglove_bridge import FoxgloveBridge +# def deploy(dimos: DimosCluster, ip: str) -> None: +# from dimos.robot.foxglove_bridge import FoxgloveBridge - connection = dimos.deploy(UnitreeWebRTCConnection, ip=ip) +# connection = dimos.deploy(UnitreeWebRTCConnection, ip=ip) - bridge = FoxgloveBridge( - shm_channels=[ - "/image#sensor_msgs.Image", - "/lidar#sensor_msgs.PointCloud2", - ] - ) - bridge.start() - connection.start() +# bridge = FoxgloveBridge( +# shm_channels=[ +# "/image#sensor_msgs.Image", +# "/lidar#sensor_msgs.PointCloud2", +# ] +# ) +# bridge.start() +# connection.start() diff --git a/dimos/robot/unitree/connection/g1.py b/dimos/robot/unitree/connection/g1.py index 8437404852..88386a59ed 100644 --- a/dimos/robot/unitree/connection/g1.py +++ b/dimos/robot/unitree/connection/g1.py @@ -15,7 +15,7 @@ from typing import Optional, cast from dimos import spec -from dimos.core import DimosCluster, In, Module, RPCClient, rpc +from dimos.core import DimosCluster, In, Module, rpc from dimos.msgs.geometry_msgs import ( Twist, TwistStamped, @@ -62,7 +62,7 @@ def publish_request(self, topic: str, data: dict): def deploy(dimos: DimosCluster, ip: str, local_planner: spec.LocalPlanner) -> G1Connection: - connection = cast(G1Connection, dimos.deploy(G1Connection, ip)) + connection = dimos.deploy(G1Connection, ip) connection.cmd_vel.connect(local_planner.cmd_vel) connection.start() return connection diff --git a/dimos/robot/unitree/connection/go2.py b/dimos/robot/unitree/connection/go2.py index 8ea903df92..7baf647b08 100644 --- a/dimos/robot/unitree/connection/go2.py +++ b/dimos/robot/unitree/connection/go2.py @@ -20,6 +20,7 @@ from dimos_lcm.sensor_msgs import CameraInfo from reactivex.observable import Observable +from dimos import spec from dimos.core import DimosCluster, In, LCMTransport, Module, Out, pSHMTransport, rpc from dimos.msgs.geometry_msgs import ( PoseStamped, @@ -29,7 +30,7 @@ TwistStamped, Vector3, ) -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.sensor_msgs import Image, PointCloud2 from dimos.msgs.std_msgs import Header from dimos.robot.unitree.connection.connection import UnitreeWebRTCConnection from dimos.robot.unitree_webrtc.type.lidar import LidarMessage @@ -145,17 +146,19 @@ def publish_request(self, topic: str, data: dict): return {"status": "ok", "message": "Fake publish"} -class GO2Connection(Module): +class GO2Connection(Module, spec.Camera, spec.Pointcloud): cmd_vel: In[Twist] = None # type: ignore - pointcloud: Out[LidarMessage] = None # type: ignore + pointcloud: Out[PointCloud2] = None # type: ignore image: Out[Image] = None # type: ignore - camera_info: Out[CameraInfo] = None # type: ignore + camera_info_stream: Out[CameraInfo] = None # type: ignore connection_type: str = "webrtc" connection: Go2ConnectionProtocol ip: Optional[str] + camera_info: CameraInfo = camera_info + def __init__( self, ip: Optional[str] = None, @@ -291,8 +294,10 @@ def deploy(dimos: DimosCluster, ip: str, prefix="") -> GO2Connection: connection.image.transport = pSHMTransport( f"{prefix}/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE ) - connection.cmd_vel.transport = LCMTransport(f"{prefix}/cmd_vel", TwistStamped) + + # connection.cmd_vel.transport = LCMTransport(f"{prefix}/cmd_vel", TwistStamped) + connection.camera_info.transport = LCMTransport(f"{prefix}/camera_info", CameraInfo) connection.start() - return connection # type: ignore + return connection diff --git a/dimos/robot/unitree/g1/g1agent.py b/dimos/robot/unitree/g1/g1agent.py index d537d41f65..826a3c4ad8 100644 --- a/dimos/robot/unitree/g1/g1agent.py +++ b/dimos/robot/unitree/g1/g1agent.py @@ -19,7 +19,7 @@ from dimos.robot.unitree.g1 import g1detector -def deploy(dimos: DimosCluster, ip: str) -> None: +def deploy(dimos: DimosCluster, ip: str): g1 = g1detector.deploy(dimos, ip) nav = g1.get("nav") diff --git a/dimos/robot/unitree/g1/g1detector.py b/dimos/robot/unitree/g1/g1detector.py index f7324f691b..b743aaac6e 100644 --- a/dimos/robot/unitree/g1/g1detector.py +++ b/dimos/robot/unitree/g1/g1detector.py @@ -18,16 +18,14 @@ from dimos.robot.unitree.g1 import g1zed -def deploy(dimos: DimosCluster, ip: str) -> None: +def deploy(dimos: DimosCluster, ip: str): g1 = g1zed.deploy(dimos, ip) nav = g1.get("nav") camera = g1.get("camera") - camerainfo = g1.get("camerainfo") person_detector = module3D.deploy( dimos, - camerainfo, camera=camera, lidar=nav, detector=YoloPersonDetector, @@ -35,12 +33,9 @@ def deploy(dimos: DimosCluster, ip: str) -> None: detector3d = moduleDB.deploy( dimos, - camerainfo, camera=camera, lidar=nav, filter=lambda det: det.class_id != 0, ) - # return {"detector3d": detector3d, **g1} - return {"person_detector": person_detector, "detector3d": detector3d, **g1} diff --git a/dimos/robot/unitree/g1/g1zed.py b/dimos/robot/unitree/g1/g1zed.py index 5641691c20..6818ddd83e 100644 --- a/dimos/robot/unitree/g1/g1zed.py +++ b/dimos/robot/unitree/g1/g1zed.py @@ -64,12 +64,12 @@ def deploy_g1_monozed(dimos: DimosCluster) -> CameraModule: ) camera.image.transport = pSHMTransport("/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE) - camera.camera_info.transport = LCMTransport("/camera_info", CameraInfo) + camera.camera_info_stream.transport = LCMTransport("/camera_info", CameraInfo) camera.start() return camera -def deploy(dimos: DimosCluster, ip: str) -> None: +def deploy(dimos: DimosCluster, ip: str): nav = rosnav.deploy(dimos) connection = g1.deploy(dimos, ip, nav) zedcam = deploy_g1_monozed(dimos) @@ -80,5 +80,4 @@ def deploy(dimos: DimosCluster, ip: str) -> None: "nav": nav, "connection": connection, "camera": zedcam, - "camerainfo": zed.CameraInfo.SingleWebcam, } diff --git a/dimos/robot/unitree/go2/go2.py b/dimos/robot/unitree/go2/go2.py index 251afdb5b3..0712a933df 100644 --- a/dimos/robot/unitree/go2/go2.py +++ b/dimos/robot/unitree/go2/go2.py @@ -30,7 +30,6 @@ def deploy(dimos: DimosCluster, ip: str): detector = moduleDB.deploy( dimos, - go2.camera_info, camera=connection, lidar=connection, ) diff --git a/dimos/spec/perception.py b/dimos/spec/perception.py index 09a0d18524..774492106b 100644 --- a/dimos/spec/perception.py +++ b/dimos/spec/perception.py @@ -15,14 +15,17 @@ from typing import Protocol from dimos.core import Out -from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.sensor_msgs import CameraInfo, PointCloud2 +from dimos.msgs.sensor_msgs import Image as ImageMsg class Image(Protocol): - image: Out[Image] + image: Out[ImageMsg] -class Camera(Image): ... +class Camera(Image): + camera_info: Out[CameraInfo] + _camera_info: CameraInfo class Pointcloud(Protocol): From 4f987eaf4606c4e70033f781f7723eeb692ebde8 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 24 Oct 2025 19:00:49 +0300 Subject: [PATCH 29/40] import issue cleanup --- dimos/core/dimos.py | 10 +++--- dimos/mapping/spec.py | 31 ------------------- dimos/navigation/spec.py | 31 ------------------- dimos/navigation/test_rosnav.py | 6 ++-- .../unitree_webrtc/connection/__init__.py | 3 +- dimos/spec/__init__.py | 13 ++++++++ 6 files changed, 23 insertions(+), 71 deletions(-) delete mode 100644 dimos/mapping/spec.py delete mode 100644 dimos/navigation/spec.py diff --git a/dimos/core/dimos.py b/dimos/core/dimos.py index be3ad11daa..e9fd683d66 100644 --- a/dimos/core/dimos.py +++ b/dimos/core/dimos.py @@ -12,20 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Cast, Optional, Type, TypeVar +from typing import TYPE_CHECKING, cast, Optional, Type, TypeVar from dimos import core -from dimos.core import DimosCluster, Module from dimos.core.resource import Resource +if TYPE_CHECKING: + from dimos.core import DimosCluster, Module + T = TypeVar("T", bound="Module") class Dimos(Resource): - _client: Optional[DimosCluster] = None + _client: Optional["DimosCluster"] = None _n: Optional[int] = None _memory_limit: str = "auto" - _deployed_modules: dict[Type[Module], Module] = {} + _deployed_modules: dict[Type["Module"], "Module"] = {} def __init__(self, n: Optional[int] = None, memory_limit: str = "auto"): self._n = n diff --git a/dimos/mapping/spec.py b/dimos/mapping/spec.py deleted file mode 100644 index 3d82cea0cc..0000000000 --- a/dimos/mapping/spec.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Protocol - -from dimos.core import Out -from dimos.msgs.nav_msgs import OccupancyGrid -from dimos.msgs.sensor_msgs import PointCloud2 - - -class Global3DMapSpec(Protocol): - global_pointcloud: Out[PointCloud2] - - -class GlobalMapSpec(Protocol): - global_map: Out[OccupancyGrid] - - -class GlobalCostmapSpec(Protocol): - global_costmap: Out[OccupancyGrid] diff --git a/dimos/navigation/spec.py b/dimos/navigation/spec.py deleted file mode 100644 index 69bfdac262..0000000000 --- a/dimos/navigation/spec.py +++ /dev/null @@ -1,31 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Protocol - -from dimos.core import In, Out -from dimos.msgs.geometry_msgs import PoseStamped, Twist -from dimos.msgs.nav_msgs import Path - - -class NavSpec(Protocol): - goal_req: In[PoseStamped] - goal_active: Out[PoseStamped] - path_active: Out[Path] - ctrl: Out[Twist] - - # identity quaternion (Quaternion(0,0,0,1)) represents "no rotation requested" - def navigate_to_target(self, target: PoseStamped) -> None: ... - - def stop_navigating(self) -> None: ... diff --git a/dimos/navigation/test_rosnav.py b/dimos/navigation/test_rosnav.py index bb803b783c..9b4b37c96e 100644 --- a/dimos/navigation/test_rosnav.py +++ b/dimos/navigation/test_rosnav.py @@ -16,12 +16,10 @@ import pytest -from dimos.mapping.spec import Global3DMapSpec -from dimos.navigation.spec import NavSpec -from dimos.perception.spec import PointcloudPerception +from dimos.spec import Global3DMap, Nav, Pointcloud -class RosNavSpec(NavSpec, PointcloudPerception, Global3DMapSpec, Protocol): +class RosNavSpec(Nav, Pointcloud, Global3DMap, Protocol): pass diff --git a/dimos/robot/unitree_webrtc/connection/__init__.py b/dimos/robot/unitree_webrtc/connection/__init__.py index 2a6a983761..603901a9ef 100644 --- a/dimos/robot/unitree_webrtc/connection/__init__.py +++ b/dimos/robot/unitree_webrtc/connection/__init__.py @@ -1,4 +1,5 @@ import dimos.robot.unitree_webrtc.connection.g1 as g1 import dimos.robot.unitree_webrtc.connection.go2 as go2 +from dimos.robot.unitree_webrtc.connection.connection import UnitreeWebRTCConnection -__all__ = ["g1", "go2"] +__all__ = ["g1", "go2", "UnitreeWebRTCConnection"] diff --git a/dimos/spec/__init__.py b/dimos/spec/__init__.py index d7a18b190c..06b9b2243a 100644 --- a/dimos/spec/__init__.py +++ b/dimos/spec/__init__.py @@ -1,2 +1,15 @@ from dimos.spec.control import LocalPlanner +from dimos.spec.map import Global3DMap, GlobalCostmap, GlobalMap +from dimos.spec.nav import Nav from dimos.spec.perception import Camera, Image, Pointcloud + +__all__ = [ + "Image", + "Camera", + "Pointcloud", + "Global3DMap", + "GlobalMap", + "GlobalCostmap", + "LocalPlanner", + "Nav", +] From 2e98745d3d1fcc8fb2e7cb97ea0e6d245ae36f0f Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 24 Oct 2025 19:01:09 +0300 Subject: [PATCH 30/40] spec files --- dimos/spec/control.py | 22 ++++++++++++++++++++++ dimos/spec/map.py | 31 +++++++++++++++++++++++++++++++ dimos/spec/nav.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 84 insertions(+) create mode 100644 dimos/spec/control.py create mode 100644 dimos/spec/map.py create mode 100644 dimos/spec/nav.py diff --git a/dimos/spec/control.py b/dimos/spec/control.py new file mode 100644 index 0000000000..405c10880d --- /dev/null +++ b/dimos/spec/control.py @@ -0,0 +1,22 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +from dimos.core import Out +from dimos.msgs.geometry_msgs import Twist + + +class LocalPlanner(Protocol): + cmd_vel: Out[Twist] diff --git a/dimos/spec/map.py b/dimos/spec/map.py new file mode 100644 index 0000000000..c087d5f3fc --- /dev/null +++ b/dimos/spec/map.py @@ -0,0 +1,31 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +from dimos.core import Out +from dimos.msgs.nav_msgs import OccupancyGrid +from dimos.msgs.sensor_msgs import PointCloud2 + + +class Global3DMap(Protocol): + global_pointcloud: Out[PointCloud2] + + +class GlobalMap(Protocol): + global_map: Out[OccupancyGrid] + + +class GlobalCostmap(Protocol): + global_costmap: Out[OccupancyGrid] diff --git a/dimos/spec/nav.py b/dimos/spec/nav.py new file mode 100644 index 0000000000..feb98aebf4 --- /dev/null +++ b/dimos/spec/nav.py @@ -0,0 +1,31 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +from dimos.core import In, Out +from dimos.msgs.geometry_msgs import PoseStamped, Twist +from dimos.msgs.nav_msgs import Path + + +class Nav(Protocol): + goal_req: In[PoseStamped] + goal_active: Out[PoseStamped] + path_active: Out[Path] + ctrl: Out[Twist] + + # identity quaternion (Quaternion(0,0,0,1)) represents "no rotation requested" + def navigate_to_target(self, target: PoseStamped) -> None: ... + + def stop_navigating(self) -> None: ... From 3226db31aa6a60607c447952023dc120f1e6fbce Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 24 Oct 2025 20:08:17 +0300 Subject: [PATCH 31/40] tests cleanuo, removed unitree_webrtc extra files --- dimos/agents2/skills/conftest.py | 16 +- dimos/agents2/skills/navigation.py | 5 +- dimos/agents2/skills/test_navigation.py | 1 - dimos/perception/detection/conftest.py | 17 +- dimos/perception/detection/module3D.py | 2 +- dimos/perception/detection/test_moduleDB.py | 5 +- .../unitree_webrtc/connection/__init__.py | 5 - .../unitree_webrtc/connection/connection.py | 419 ------------------ dimos/robot/unitree_webrtc/connection/g1.py | 63 --- dimos/robot/unitree_webrtc/connection/go2.py | 299 ------------- dimos/robot/unitree_webrtc/modular/g1agent.py | 48 -- .../unitree_webrtc/modular/g1detector.py | 46 -- dimos/robot/unitree_webrtc/modular/g1zed.py | 70 --- dimos/robot/unitree_webrtc/modular/run.py | 97 ---- dimos/robot/unitree_webrtc/type/lidar.py | 4 +- 15 files changed, 24 insertions(+), 1073 deletions(-) delete mode 100644 dimos/robot/unitree_webrtc/connection/__init__.py delete mode 100644 dimos/robot/unitree_webrtc/connection/connection.py delete mode 100644 dimos/robot/unitree_webrtc/connection/g1.py delete mode 100644 dimos/robot/unitree_webrtc/connection/go2.py delete mode 100644 dimos/robot/unitree_webrtc/modular/g1agent.py delete mode 100644 dimos/robot/unitree_webrtc/modular/g1detector.py delete mode 100644 dimos/robot/unitree_webrtc/modular/g1zed.py delete mode 100644 dimos/robot/unitree_webrtc/modular/run.py diff --git a/dimos/agents2/skills/conftest.py b/dimos/agents2/skills/conftest.py index 7ea89e320a..78524419ae 100644 --- a/dimos/agents2/skills/conftest.py +++ b/dimos/agents2/skills/conftest.py @@ -12,19 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial + import pytest import reactivex as rx -from functools import partial from reactivex.scheduler import ThreadPoolScheduler +from dimos.agents2.skills.google_maps_skill_container import GoogleMapsSkillContainer from dimos.agents2.skills.gps_nav_skill import GpsNavSkillContainer from dimos.agents2.skills.navigation import NavigationSkillContainer -from dimos.agents2.skills.google_maps_skill_container import GoogleMapsSkillContainer from dimos.mapping.types import LatLon +from dimos.msgs.sensor_msgs import Image from dimos.robot.robot import GpsRobot from dimos.robot.unitree_webrtc.run_agents2 import SYSTEM_PROMPT from dimos.utils.data import get_data -from dimos.msgs.sensor_msgs import Image @pytest.fixture(autouse=True) @@ -65,8 +66,13 @@ def fake_gps_position_stream(): @pytest.fixture -def navigation_skill_container(fake_robot, fake_video_stream): - container = NavigationSkillContainer(fake_robot, fake_video_stream) +def fake_detection_module(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def navigation_skill_container(fake_robot, fake_video_stream, fake_detection_module): + container = NavigationSkillContainer(fake_robot, fake_video_stream, fake_detection_module) container.start() yield container container.stop() diff --git a/dimos/agents2/skills/navigation.py b/dimos/agents2/skills/navigation.py index 938a8b2684..eebfcf6fac 100644 --- a/dimos/agents2/skills/navigation.py +++ b/dimos/agents2/skills/navigation.py @@ -22,7 +22,7 @@ from dimos.core.resource import Resource from dimos.models.qwen.video_query import BBox from dimos.models.vl.qwen import QwenVlModel -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 from dimos.msgs.geometry_msgs.Vector3 import make_vector3 from dimos.msgs.sensor_msgs import Image from dimos.navigation.bt_navigator.navigator import NavigatorState @@ -31,7 +31,6 @@ from dimos.robot.robot import UnitreeRobot from dimos.types.robot_location import RobotLocation from dimos.utils.logging_config import setup_logger -from dimos.utils.transform_utils import euler_to_quaternion, quaternion_to_euler logger = setup_logger(__file__) @@ -144,7 +143,7 @@ def _navigate_by_tagged_location(self, query: str) -> Optional[str]: print("Found tagged location:", robot_location) goal_pose = PoseStamped( position=make_vector3(*robot_location.position), - orientation=euler_to_quaternion(make_vector3(*robot_location.rotation)), + orientation=Quaternion.from_euler(Vector3(*robot_location.rotation)), frame_id="map", ) diff --git a/dimos/agents2/skills/test_navigation.py b/dimos/agents2/skills/test_navigation.py index f90f8a2d19..ae456e3bcf 100644 --- a/dimos/agents2/skills/test_navigation.py +++ b/dimos/agents2/skills/test_navigation.py @@ -20,7 +20,6 @@ def test_stop_movement(fake_robot, create_navigation_agent): agent = create_navigation_agent(fixture="test_stop_movement.json") agent.query("stop") - fake_robot.stop_exploration.assert_called_once_with() diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index e7812558ab..69481c2fb0 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -13,7 +13,7 @@ # limitations under the License. import functools -from typing import Callable, Generator, Optional, TypedDict, Union +from typing import Callable, Generator, Optional, TypedDict import pytest from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations @@ -29,13 +29,12 @@ from dimos.perception.detection.moduleDB import ObjectDBModule from dimos.perception.detection.type import ( Detection2D, - Detection3D, Detection3DPC, ImageDetections2D, ImageDetections3DPC, ) from dimos.protocol.tf import TF -from dimos.robot.unitree_webrtc.connection import go2 +from dimos.robot.unitree.connection import go2 from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.utils.data import get_data @@ -101,11 +100,10 @@ def moment_provider(**kwargs) -> Moment: if odom_frame is None: raise ValueError("No odom frame found") - transforms = go2._odom_to_tf(odom_frame) + transforms = go2.GO2Connection._odom_to_tf(odom_frame) tf.receive_transform(*transforms) - camera_info_out = go2._camera_info() - # ConnectionModule._camera_info() returns Out[CameraInfo], extract the value + camera_info_out = go2.camera_info from typing import cast camera_info = cast(CameraInfo, camera_info_out) @@ -265,11 +263,8 @@ def object_db_module(get_moment): from dimos.perception.detection.detectors import Yolo2DDetector module2d = Detection2DModule(detector=lambda: Yolo2DDetector(device="cpu")) - module3d = Detection3DModule(camera_info=ConnectionModule._camera_info()) - moduleDB = ObjectDBModule( - camera_info=ConnectionModule._camera_info(), - goto=lambda obj_id: None, # No-op for testing - ) + module3d = Detection3DModule(camera_info=go2.camera_info) + moduleDB = ObjectDBModule(camera_info=go2.camera_info) # Process 5 frames to build up object history for i in range(5): diff --git a/dimos/perception/detection/module3D.py b/dimos/perception/detection/module3D.py index e46569c07a..c218704600 100644 --- a/dimos/perception/detection/module3D.py +++ b/dimos/perception/detection/module3D.py @@ -24,7 +24,7 @@ from dimos.agents2 import skill from dimos.core import DimosCluster, In, Out, rpc from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 +from dimos.msgs.sensor_msgs import Image, PointCloud2 from dimos.msgs.vision_msgs import Detection2DArray from dimos.perception.detection.module2D import Detection2DModule from dimos.perception.detection.type import ( diff --git a/dimos/perception/detection/test_moduleDB.py b/dimos/perception/detection/test_moduleDB.py index 4eec932dce..97598b6ee2 100644 --- a/dimos/perception/detection/test_moduleDB.py +++ b/dimos/perception/detection/test_moduleDB.py @@ -22,8 +22,7 @@ from dimos.msgs.sensor_msgs import Image, PointCloud2 from dimos.msgs.vision_msgs import Detection2DArray from dimos.perception.detection.moduleDB import ObjectDBModule -from dimos.robot.unitree_webrtc.connection import go2 -from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule +from dimos.robot.unitree.connection import go2 @pytest.mark.module @@ -32,7 +31,7 @@ def test_moduleDB(dimos_cluster): moduleDB = dimos_cluster.deploy( ObjectDBModule, - camera_info=ConnectionModule._camera_info(), + camera_info=go2.camera_info, goto=lambda obj_id: print(f"Going to {obj_id}"), ) moduleDB.image.connect(connection.video) diff --git a/dimos/robot/unitree_webrtc/connection/__init__.py b/dimos/robot/unitree_webrtc/connection/__init__.py deleted file mode 100644 index 603901a9ef..0000000000 --- a/dimos/robot/unitree_webrtc/connection/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -import dimos.robot.unitree_webrtc.connection.g1 as g1 -import dimos.robot.unitree_webrtc.connection.go2 as go2 -from dimos.robot.unitree_webrtc.connection.connection import UnitreeWebRTCConnection - -__all__ = ["g1", "go2", "UnitreeWebRTCConnection"] diff --git a/dimos/robot/unitree_webrtc/connection/connection.py b/dimos/robot/unitree_webrtc/connection/connection.py deleted file mode 100644 index abfba92fa9..0000000000 --- a/dimos/robot/unitree_webrtc/connection/connection.py +++ /dev/null @@ -1,419 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import asyncio -import functools -import threading -import time -from dataclasses import dataclass -from typing import Literal, Optional, TypeAlias - -import numpy as np -from aiortc import MediaStreamTrack -from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD, VUI_COLOR -from go2_webrtc_driver.webrtc_driver import ( # type: ignore[import-not-found] - Go2WebRTCConnection, - WebRTCConnectionMethod, -) -from reactivex import operators as ops -from reactivex.observable import Observable -from reactivex.subject import Subject - -from dimos.core import DimosCluster, In, Module, Out, rpc -from dimos.core.resource import Resource -from dimos.msgs.geometry_msgs import Pose, Transform, Twist, Vector3 -from dimos.msgs.sensor_msgs import Image -from dimos.robot.connection_interface import ConnectionInterface -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.utils.decorators.decorators import simple_mcache -from dimos.utils.reactive import backpressure, callback_to_observable - -VideoMessage: TypeAlias = np.ndarray[tuple[int, int, Literal[3]], np.uint8] - - -@dataclass -class SerializableVideoFrame: - """Pickleable wrapper for av.VideoFrame with all metadata""" - - data: np.ndarray - pts: Optional[int] = None - time: Optional[float] = None - dts: Optional[int] = None - width: Optional[int] = None - height: Optional[int] = None - format: Optional[str] = None - - @classmethod - def from_av_frame(cls, frame): - return cls( - data=frame.to_ndarray(format="rgb24"), - pts=frame.pts, - time=frame.time, - dts=frame.dts, - width=frame.width, - height=frame.height, - format=frame.format.name if hasattr(frame, "format") and frame.format else None, - ) - - def to_ndarray(self, format=None): - return self.data - - -class UnitreeWebRTCConnection(Resource): - def __init__(self, ip: str, mode: str = "ai"): - self.ip = ip - self.mode = mode - self.stop_timer = None - self.cmd_vel_timeout = 0.2 - self.conn = Go2WebRTCConnection(WebRTCConnectionMethod.LocalSTA, ip=self.ip) - self.connect() - - def connect(self): - self.loop = asyncio.new_event_loop() - self.task = None - self.connected_event = asyncio.Event() - self.connection_ready = threading.Event() - - async def async_connect(): - await self.conn.connect() - await self.conn.datachannel.disableTrafficSaving(True) - - self.conn.datachannel.set_decoder(decoder_type="native") - - await self.conn.datachannel.pub_sub.publish_request_new( - RTC_TOPIC["MOTION_SWITCHER"], {"api_id": 1002, "parameter": {"name": self.mode}} - ) - - self.connected_event.set() - self.connection_ready.set() - - while True: - await asyncio.sleep(1) - - def start_background_loop(): - asyncio.set_event_loop(self.loop) - self.task = self.loop.create_task(async_connect()) - self.loop.run_forever() - - self.loop = asyncio.new_event_loop() - self.thread = threading.Thread(target=start_background_loop, daemon=True) - self.thread.start() - self.connection_ready.wait() - - def start(self) -> None: - pass - - def stop(self) -> None: - # Cancel timer - if self.stop_timer: - self.stop_timer.cancel() - self.stop_timer = None - - if self.task: - self.task.cancel() - - async def async_disconnect() -> None: - try: - await self.conn.disconnect() - except Exception: - pass - - if self.loop.is_running(): - asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) - - self.loop.call_soon_threadsafe(self.loop.stop) - - if self.thread.is_alive(): - self.thread.join(timeout=2.0) - - def move(self, twist: Twist, duration: float = 0.0) -> bool: - """Send movement command to the robot using Twist commands. - - Args: - twist: Twist message with linear and angular velocities - duration: How long to move (seconds). If 0, command is continuous - - Returns: - bool: True if command was sent successfully - """ - x, y, yaw = twist.linear.x, twist.linear.y, twist.angular.z - - # WebRTC coordinate mapping: - # x - Positive right, negative left - # y - positive forward, negative backwards - # yaw - Positive rotate right, negative rotate left - async def async_move(): - self.conn.datachannel.pub_sub.publish_without_callback( - RTC_TOPIC["WIRELESS_CONTROLLER"], - data={"lx": -y, "ly": x, "rx": -yaw, "ry": 0}, - ) - - async def async_move_duration(): - """Send movement commands continuously for the specified duration.""" - start_time = time.time() - sleep_time = 0.01 - - while time.time() - start_time < duration: - await async_move() - await asyncio.sleep(sleep_time) - - # Cancel existing timer and start a new one - if self.stop_timer: - self.stop_timer.cancel() - - # Auto-stop after 0.5 seconds if no new commands - self.stop_timer = threading.Timer(self.cmd_vel_timeout, self.stop) - self.stop_timer.daemon = True - self.stop_timer.start() - - try: - if duration > 0: - # Send continuous move commands for the duration - future = asyncio.run_coroutine_threadsafe(async_move_duration(), self.loop) - future.result() - # Stop after duration - self.stop() - else: - # Single command for continuous movement - future = asyncio.run_coroutine_threadsafe(async_move(), self.loop) - future.result() - return True - except Exception as e: - print(f"Failed to send movement command: {e}") - return False - - # Generic conversion of unitree subscription to Subject (used for all subs) - def unitree_sub_stream(self, topic_name: str): - def subscribe_in_thread(cb): - # Run the subscription in the background thread that has the event loop - def run_subscription(): - self.conn.datachannel.pub_sub.subscribe(topic_name, cb) - - # Use call_soon_threadsafe to run in the background thread - self.loop.call_soon_threadsafe(run_subscription) - - def unsubscribe_in_thread(cb): - # Run the unsubscription in the background thread that has the event loop - def run_unsubscription(): - self.conn.datachannel.pub_sub.unsubscribe(topic_name) - - # Use call_soon_threadsafe to run in the background thread - self.loop.call_soon_threadsafe(run_unsubscription) - - return callback_to_observable( - start=subscribe_in_thread, - stop=unsubscribe_in_thread, - ) - - # Generic sync API call (we jump into the client thread) - def publish_request(self, topic: str, data: dict): - future = asyncio.run_coroutine_threadsafe( - self.conn.datachannel.pub_sub.publish_request_new(topic, data), self.loop - ) - return future.result() - - @simple_mcache - def raw_lidar_stream(self) -> Subject[LidarMessage]: - return backpressure(self.unitree_sub_stream(RTC_TOPIC["ULIDAR_ARRAY"])) - - @simple_mcache - def raw_odom_stream(self) -> Subject[Pose]: - return backpressure(self.unitree_sub_stream(RTC_TOPIC["ROBOTODOM"])) - - @simple_mcache - def lidar_stream(self) -> Subject[LidarMessage]: - return backpressure( - self.raw_lidar_stream().pipe( - ops.map(lambda raw_frame: LidarMessage.from_msg(raw_frame, ts=time.time())) - ) - ) - - @simple_mcache - def tf_stream(self) -> Subject[Transform]: - base_link = functools.partial(Transform.from_pose, "base_link") - return backpressure(self.odom_stream().pipe(ops.map(base_link))) - - @simple_mcache - def odom_stream(self) -> Subject[Pose]: - return backpressure(self.raw_odom_stream().pipe(ops.map(Odometry.from_msg))) - - @simple_mcache - def video_stream(self) -> Observable[Image]: - return backpressure( - self.raw_video_stream().pipe( - ops.filter(lambda frame: frame is not None), - ops.map( - lambda frame: Image.from_numpy( - # np.ascontiguousarray(frame.to_ndarray("rgb24")), - frame.to_ndarray(format="rgb24"), - frame_id="camera_optical", - ) - ), - ) - ) - - @simple_mcache - def lowstate_stream(self) -> Subject[LowStateMsg]: - return backpressure(self.unitree_sub_stream(RTC_TOPIC["LOW_STATE"])) - - def standup_ai(self): - return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]}) - - def standup_normal(self): - self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]}) - time.sleep(0.5) - self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["RecoveryStand"]}) - return True - - @rpc - def standup(self): - if self.mode == "ai": - return self.standup_ai() - else: - return self.standup_normal() - - @rpc - def liedown(self): - return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandDown"]}) - - async def handstand(self): - return self.publish_request( - RTC_TOPIC["SPORT_MOD"], - {"api_id": SPORT_CMD["Standup"], "parameter": {"data": True}}, - ) - - @rpc - def color(self, color: VUI_COLOR = VUI_COLOR.RED, colortime: int = 60) -> bool: - return self.publish_request( - RTC_TOPIC["VUI"], - { - "api_id": 1001, - "parameter": { - "color": color, - "time": colortime, - }, - }, - ) - - @simple_mcache - def raw_video_stream(self) -> Observable[VideoMessage]: - subject: Subject[VideoMessage] = Subject() - stop_event = threading.Event() - - async def accept_track(track: MediaStreamTrack) -> VideoMessage: - while True: - if stop_event.is_set(): - return - frame = await track.recv() - serializable_frame = SerializableVideoFrame.from_av_frame(frame) - subject.on_next(serializable_frame) - - self.conn.video.add_track_callback(accept_track) - - # Run the video channel switching in the background thread - def switch_video_channel(): - self.conn.video.switchVideoChannel(True) - - self.loop.call_soon_threadsafe(switch_video_channel) - - def stop(): - stop_event.set() # Signal the loop to stop - self.conn.video.track_callbacks.remove(accept_track) - - # Run the video channel switching off in the background thread - def switch_video_channel_off(): - self.conn.video.switchVideoChannel(False) - - self.loop.call_soon_threadsafe(switch_video_channel_off) - - return subject.pipe(ops.finally_action(stop)) - - def get_video_stream(self, fps: int = 30) -> Observable[VideoMessage]: - """Get the video stream from the robot's camera. - - Implements the AbstractRobot interface method. - - Args: - fps: Frames per second. This parameter is included for API compatibility, - but doesn't affect the actual frame rate which is determined by the camera. - - Returns: - Observable: An observable stream of video frames or None if video is not available. - """ - try: - print("Starting WebRTC video stream...") - stream = self.video_stream() - if stream is None: - print("Warning: Video stream is not available") - return stream - - except Exception as e: - print(f"Error getting video stream: {e}") - return None - - def stop(self) -> bool: - """Stop the robot's movement. - - Returns: - bool: True if stop command was sent successfully - """ - # Cancel timer since we're explicitly stopping - if self.stop_timer: - self.stop_timer.cancel() - self.stop_timer = None - - return self.move(Twist()) - - def disconnect(self) -> None: - """Disconnect from the robot and clean up resources.""" - # Cancel timer - if self.stop_timer: - self.stop_timer.cancel() - self.stop_timer = None - - if hasattr(self, "task") and self.task: - self.task.cancel() - if hasattr(self, "conn"): - - async def async_disconnect(): - try: - await self.conn.disconnect() - except: - pass - - if hasattr(self, "loop") and self.loop.is_running(): - asyncio.run_coroutine_threadsafe(async_disconnect(), self.loop) - - if hasattr(self, "loop") and self.loop.is_running(): - self.loop.call_soon_threadsafe(self.loop.stop) - - if hasattr(self, "thread") and self.thread.is_alive(): - self.thread.join(timeout=2.0) - - -def deploy(dimos: DimosCluster, ip: str) -> None: - from dimos.robot.foxglove_bridge import FoxgloveBridge - - connection = dimos.deploy(UnitreeWebRTCConnection, ip=ip) - - bridge = FoxgloveBridge( - shm_channels=[ - "/image#sensor_msgs.Image", - "/lidar#sensor_msgs.PointCloud2", - ] - ) - bridge.start() - connection.start() diff --git a/dimos/robot/unitree_webrtc/connection/g1.py b/dimos/robot/unitree_webrtc/connection/g1.py deleted file mode 100644 index 9b4e9a87fa..0000000000 --- a/dimos/robot/unitree_webrtc/connection/g1.py +++ /dev/null @@ -1,63 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos import spec -from dimos.core import DimosCluster, In, Module, rpc -from dimos.msgs.geometry_msgs import ( - Twist, - TwistStamped, -) -from dimos.robot.unitree_webrtc.connection.connection import UnitreeWebRTCConnection - - -class G1Connection(Module): - cmd_vel: In[TwistStamped] = None - ip: str - - def __init__(self, ip: str = None, **kwargs): - super().__init__(**kwargs) - self.ip = ip - - @rpc - def start(self): - super().start() - self.connection = UnitreeWebRTCConnection(self.ip) - self.connection.start() - - self._disposables.add( - self.cmd_vel.subscribe(self.move), - ) - - @rpc - def stop(self) -> None: - self.connection.stop() - super().stop() - - @rpc - def move(self, twist_stamped: TwistStamped, duration: float = 0.0): - """Send movement command to robot.""" - twist = Twist(linear=twist_stamped.linear, angular=twist_stamped.angular) - self.connection.move(twist, duration) - - @rpc - def publish_request(self, topic: str, data: dict): - """Forward WebRTC publish requests to connection.""" - return self.connection.publish_request(topic, data) - - -def deploy(dimos: DimosCluster, ip: str, local_planner: spec.LocalPlanner) -> G1Connection: - connection = dimos.deploy(G1Connection, ip) - connection.cmd_vel.connect(local_planner.cmd_vel) - connection.start() - return connection diff --git a/dimos/robot/unitree_webrtc/connection/go2.py b/dimos/robot/unitree_webrtc/connection/go2.py deleted file mode 100644 index 04eabc9884..0000000000 --- a/dimos/robot/unitree_webrtc/connection/go2.py +++ /dev/null @@ -1,299 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import time -from typing import List, Optional, Protocol - -from dimos_lcm.sensor_msgs import CameraInfo -from reactivex.disposable import Disposable - -from dimos.core import DimosCluster, In, LCMTransport, Module, Out, pSHMTransport, rpc -from dimos.msgs.geometry_msgs import ( - PoseStamped, - Quaternion, - Transform, - Twist, - TwistStamped, - Vector3, -) -from dimos.msgs.nav_msgs import OccupancyGrid, Path -from dimos.msgs.sensor_msgs import Image -from dimos.msgs.std_msgs import Header -from dimos.msgs.vision_msgs import Detection2DArray -from dimos.robot.unitree_webrtc.connection.connection import UnitreeWebRTCConnection -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.map import Map -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from dimos.utils.data import get_data -from dimos.utils.decorators.decorators import simple_mcache -from dimos.utils.logging_config import setup_logger -from dimos.utils.testing import TimedSensorReplay - -logger = setup_logger(__file__, level=logging.INFO) - - -def _camera_info() -> CameraInfo: - fx, fy, cx, cy = (819.553492, 820.646595, 625.284099, 336.808987) - width, height = (1280, 720) - - # Camera matrix K (3x3) - K = [fx, 0, cx, 0, fy, cy, 0, 0, 1] - - # No distortion coefficients for now - D = [0.0, 0.0, 0.0, 0.0, 0.0] - - # Identity rotation matrix - R = [1, 0, 0, 0, 1, 0, 0, 0, 1] - - # Projection matrix P (3x4) - P = [fx, 0, cx, 0, 0, fy, cy, 0, 0, 0, 1, 0] - - base_msg = { - "D_length": len(D), - "height": height, - "width": width, - "distortion_model": "plumb_bob", - "D": D, - "K": K, - "R": R, - "P": P, - "binning_x": 0, - "binning_y": 0, - } - - return CameraInfo(**base_msg, header=Header("camera_optical")) - - -camera_info = _camera_info() - - -class FakeRTC(UnitreeWebRTCConnection): - dir_name = "unitree_go2_office_walk2" - - # we don't want UnitreeWebRTCConnection to init - def __init__( - self, - **kwargs, - ): - get_data(self.dir_name) - self.replay_config = { - "loop": kwargs.get("loop"), - "seek": kwargs.get("seek"), - "duration": kwargs.get("duration"), - } - - def connect(self): - pass - - def start(self): - pass - - def standup(self): - print("standup suppressed") - - def liedown(self): - print("liedown suppressed") - - @simple_mcache - def lidar_stream(self): - print("lidar stream start") - lidar_store = TimedSensorReplay(f"{self.dir_name}/lidar") - return lidar_store.stream(**self.replay_config) - - @simple_mcache - def odom_stream(self): - print("odom stream start") - odom_store = TimedSensorReplay(f"{self.dir_name}/odom") - return odom_store.stream(**self.replay_config) - - # we don't have raw video stream in the data set - @simple_mcache - def video_stream(self): - print("video stream start") - video_store = TimedSensorReplay(f"{self.dir_name}/video") - - return video_store.stream(**self.replay_config) - - def move(self, vector: Twist, duration: float = 0.0): - pass - - def publish_request(self, topic: str, data: dict): - """Fake publish request for testing.""" - return {"status": "ok", "message": "Fake publish"} - - -class GO2Connection(Module): - cmd_vel: In[Twist] = None - pointcloud: Out[LidarMessage] = None - image: Out[Image] = None - camera_info: Out[CameraInfo] = None - connection_type: str = "webrtc" - - ip: str - - def __init__( - self, - ip: str = None, - connection_type: str = "webrtc", - rectify_image: bool = True, - *args, - **kwargs, - ): - self.ip = ip - self.connection = None - Module.__init__(self, *args, **kwargs) - - @rpc - def start(self) -> None: - """Start the connection and subscribe to sensor streams.""" - super().start() - - match self.ip: - case None | "fake" | "": - self.connection = FakeRTC() - case "mujoco": - from dimos.robot.unitree_webrtc.mujoco_connection import MujocoConnection - - self.connection = MujocoConnection() - case _: - self.connection = UnitreeWebRTCConnection(self.ip) - - self.connection.start() - - self._disposables.add( - self.connection.lidar_stream().subscribe(self.pointcloud.publish), - ) - - self._disposables.add( - self.connection.odom_stream().subscribe(self._publish_tf), - ) - - self._disposables.add( - self.connection.video_stream().subscribe(self.image.publish), - ) - - self._disposables.add( - self.cmd_vel.subscribe(self.move), - ) - - # Start publishing camera info at 1 Hz - from threading import Thread - - self._camera_info_thread = Thread( - target=self.publish_camera_info, - daemon=True, - ) - self._camera_info_thread.start() - - @rpc - def stop(self) -> None: - if self.connection: - self.connection.stop() - if hasattr(self, "_camera_info_thread"): - self._camera_info_thread.join(timeout=1.0) - super().stop() - - @classmethod - def _odom_to_tf(self, odom: PoseStamped) -> List[Transform]: - camera_link = Transform( - translation=Vector3(0.3, 0.0, 0.0), - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="base_link", - child_frame_id="camera_link", - ts=odom.ts, - ) - - camera_optical = Transform( - translation=Vector3(0.0, 0.0, 0.0), - rotation=Quaternion(-0.5, 0.5, -0.5, 0.5), - frame_id="camera_link", - child_frame_id="camera_optical", - ts=odom.ts, - ) - - sensor = Transform( - translation=Vector3(0.0, 0.0, 0.0), - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="world", - child_frame_id="sensor", - ts=odom.ts, - ) - - return [ - Transform.from_pose("base_link", odom), - camera_link, - camera_optical, - sensor, - ] - - def _publish_tf(self, msg): - self.tf.publish(*self._odom_to_tf(msg)) - - def publish_camera_info(self): - while True: - self.camera_info.publish(camera_info) - time.sleep(1.0) - - @rpc - def get_odom(self) -> Optional[PoseStamped]: - """Get the robot's odometry. - - Returns: - The robot's odometry - """ - return self._odom - - @rpc - def move(self, twist: Twist, duration: float = 0.0): - """Send movement command to robot.""" - self.connection.move(twist, duration) - - @rpc - def standup(self): - """Make the robot stand up.""" - return self.connection.standup() - - @rpc - def liedown(self): - """Make the robot lie down.""" - return self.connection.liedown() - - @rpc - def publish_request(self, topic: str, data: dict): - """Publish a request to the WebRTC connection. - Args: - topic: The RTC topic to publish to - data: The data dictionary to publish - Returns: - The result of the publish request - """ - return self.connection.publish_request(topic, data) - - -def deploy(dimos: DimosCluster, ip: str, prefix="") -> GO2Connection: - from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE - - connection = dimos.deploy(GO2Connection, ip) - - connection.pointcloud.transport = pSHMTransport( - f"{prefix}/lidar", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE - ) - connection.image.transport = pSHMTransport( - f"{prefix}/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE - ) - connection.cmd_vel.transport = LCMTransport(f"{prefix}/cmd_vel", TwistStamped) - connection.camera_info.transport = LCMTransport(f"{prefix}/camera_info", CameraInfo) - connection.start() - return connection diff --git a/dimos/robot/unitree_webrtc/modular/g1agent.py b/dimos/robot/unitree_webrtc/modular/g1agent.py deleted file mode 100644 index 06da0ec950..0000000000 --- a/dimos/robot/unitree_webrtc/modular/g1agent.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos import agents2 -from dimos.agents2.skills.navigation import NavigationSkillContainer -from dimos.core import DimosCluster -from dimos.perception import spatial_perception -from dimos.robot.unitree_webrtc.modular import g1detector - - -def deploy(dimos: DimosCluster, ip: str) -> None: - g1 = g1detector.deploy(dimos, ip) - - nav = g1.get("nav") - camera = g1.get("camera") - detector3d = g1.get("detector3d") - connection = g1.get("connection") - - spatialmem = spatial_perception.deploy(dimos, camera) - - navskills = dimos.deploy( - NavigationSkillContainer, - spatialmem, - nav, - detector3d, - ) - navskills.start() - - agent = agents2.deploy( - dimos, - "You are controling a humanoid robot", - skill_containers=[connection, nav, camera, spatialmem, navskills], - ) - agent.run_implicit_skill("current_position") - agent.run_implicit_skill("video_stream") - - return {"agent": agent, "spatialmem": spatialmem, **g1} diff --git a/dimos/robot/unitree_webrtc/modular/g1detector.py b/dimos/robot/unitree_webrtc/modular/g1detector.py deleted file mode 100644 index d058c64825..0000000000 --- a/dimos/robot/unitree_webrtc/modular/g1detector.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.core import DimosCluster -from dimos.perception.detection import module3D, moduleDB -from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector -from dimos.robot.unitree_webrtc.modular import g1zed - - -def deploy(dimos: DimosCluster, ip: str) -> None: - g1 = g1zed.deploy(dimos, ip) - - nav = g1.get("nav") - camera = g1.get("camera") - camerainfo = g1.get("camerainfo") - - person_detector = module3D.deploy( - dimos, - camerainfo, - camera=camera, - lidar=nav, - detector=YoloPersonDetector, - ) - - detector3d = moduleDB.deploy( - dimos, - camerainfo, - camera=camera, - lidar=nav, - filter=lambda det: det.class_id != 0, - ) - - # return {"detector3d": detector3d, **g1} - - return {"person_detector": person_detector, "detector3d": detector3d, **g1} diff --git a/dimos/robot/unitree_webrtc/modular/g1zed.py b/dimos/robot/unitree_webrtc/modular/g1zed.py deleted file mode 100644 index c33d71e2ad..0000000000 --- a/dimos/robot/unitree_webrtc/modular/g1zed.py +++ /dev/null @@ -1,70 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE -from dimos.core import DimosCluster, LCMTransport, pSHMTransport, start, wait_exit -from dimos.hardware.camera import zed -from dimos.hardware.camera.module import CameraModule -from dimos.hardware.camera.webcam import Webcam -from dimos.msgs.geometry_msgs import ( - Quaternion, - Transform, - Vector3, -) -from dimos.msgs.sensor_msgs import CameraInfo -from dimos.navigation import rosnav -from dimos.robot import foxglove_bridge -from dimos.robot.unitree_webrtc.connection import g1 -from dimos.utils.logging_config import setup_logger - -logger = setup_logger(__name__) - - -def deploy_g1_monozed(dimos) -> CameraModule: - camera = dimos.deploy( - CameraModule, - frequency=4.0, - transform=Transform( - translation=Vector3(0.05, 0.0, 0.0), - rotation=Quaternion.from_euler(Vector3(0.0, 0.0, 0.0)), - frame_id="sensor", - child_frame_id="camera_link", - ), - hardware=lambda: Webcam( - camera_index=0, - frequency=5, - stereo_slice="left", - camera_info=zed.CameraInfo.SingleWebcam, - ), - ) - - camera.image.transport = pSHMTransport("/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE) - camera.camera_info.transport = LCMTransport("/camera_info", CameraInfo) - camera.start() - return camera - - -def deploy(dimos: DimosCluster, ip: str) -> None: - nav = rosnav.deploy(dimos) - connection = g1.deploy(dimos, ip, nav) - zedcam = deploy_g1_monozed(dimos) - - foxglove_bridge.deploy(dimos) - - return { - "nav": nav, - "connection": connection, - "camera": zedcam, - "camerainfo": zed.CameraInfo.SingleWebcam, - } diff --git a/dimos/robot/unitree_webrtc/modular/run.py b/dimos/robot/unitree_webrtc/modular/run.py deleted file mode 100644 index aa6ca2af14..0000000000 --- a/dimos/robot/unitree_webrtc/modular/run.py +++ /dev/null @@ -1,97 +0,0 @@ -#!/usr/bin/env python3 -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Centralized runner for modular G1 deployment scripts. - -Usage: - python run.py g1agent --ip 192.168.1.100 - python run.py g1zed - python run.py g1detector --ip $ROBOT_IP -""" - -import argparse -import importlib -import os -import sys - -from dotenv import load_dotenv - -from dimos.core import start, wait_exit - - -def main(): - load_dotenv() - - parser = argparse.ArgumentParser(description="Unitree G1 Modular Deployment Runner") - parser.add_argument( - "module", - help="Module name to run (e.g., g1agent, g1zed, g1detector)", - ) - parser.add_argument( - "--ip", - default=os.getenv("ROBOT_IP"), - help="Robot IP address (default: ROBOT_IP from .env)", - ) - parser.add_argument( - "--workers", - type=int, - default=8, - help="Number of worker threads for DimosCluster (default: 8)", - ) - - args = parser.parse_args() - - # Validate IP address - if not args.ip: - print("ERROR: Robot IP address not provided") - print("Please provide --ip or set ROBOT_IP in .env") - sys.exit(1) - - # Import the module - try: - # Try importing from current package first - module = importlib.import_module( - f".{args.module}", package="dimos.robot.unitree_webrtc.modular" - ) - except ImportError as e: - import traceback - - traceback.print_exc() - - print(f"\nERROR: Could not import module '{args.module}'") - print(f"Make sure the module exists in dimos/robot/unitree_webrtc/modular/") - print(f"Import error: {e}") - - sys.exit(1) - - # Verify deploy function exists - if not hasattr(module, "deploy"): - print(f"ERROR: Module '{args.module}' does not have a 'deploy' function") - sys.exit(1) - - print(f"Running {args.module}.deploy() with IP {args.ip}") - - # Run the standard deployment pattern - dimos = start(args.workers) - try: - module.deploy(dimos, args.ip) - wait_exit() - finally: - dimos.close_all() - - -if __name__ == "__main__": - main() diff --git a/dimos/robot/unitree_webrtc/type/lidar.py b/dimos/robot/unitree_webrtc/type/lidar.py index 30fe3c587e..aefd9654e1 100644 --- a/dimos/robot/unitree_webrtc/type/lidar.py +++ b/dimos/robot/unitree_webrtc/type/lidar.py @@ -14,7 +14,7 @@ import time from copy import copy -from typing import List, Optional, Type, TypedDict +from typing import List, Optional, TypedDict import numpy as np import open3d as o3d @@ -65,7 +65,7 @@ def __init__(self, **kwargs): self.resolution = kwargs.get("resolution", 0.05) @classmethod - def from_msg(cls: Type["LidarMessage"], raw_message: RawLidarMsg, **kwargs) -> "LidarMessage": + def from_msg(cls: "LidarMessage", raw_message: RawLidarMsg, **kwargs) -> "LidarMessage": data = raw_message["data"] points = data["data"]["points"] pointcloud = o3d.geometry.PointCloud() From 0a6ac3b49bb5c923ab20079a7f7bbd6740db72f6 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 24 Oct 2025 20:22:53 +0300 Subject: [PATCH 32/40] import fixes --- dimos/core/__init__.py | 1 - dimos/robot/foxglove_bridge.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 2af905c4e0..8d767779c4 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -10,7 +10,6 @@ import dimos.core.colors as colors from dimos.core.core import rpc -from dimos.core.dimos import Dimos from dimos.core.module import Module, ModuleBase, ModuleConfig from dimos.core.rpc_client import RPCClient from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport diff --git a/dimos/robot/foxglove_bridge.py b/dimos/robot/foxglove_bridge.py index a847077cf8..fa87653624 100644 --- a/dimos/robot/foxglove_bridge.py +++ b/dimos/robot/foxglove_bridge.py @@ -15,6 +15,7 @@ import asyncio import logging import threading +from typing import List, Optional # this is missing, I'm just trying to import lcm_foxglove_bridge.py from dimos_lcm from dimos_lcm.foxglove_bridge import FoxgloveBridge as LCMFoxgloveBridge From 5ae2f97a77af3dc6f409737b5a2359e4c669c4f5 Mon Sep 17 00:00:00 2001 From: lesh Date: Fri, 24 Oct 2025 20:27:51 +0300 Subject: [PATCH 33/40] test fix --- dimos/agents2/skills/navigation.py | 6 +++--- dimos/agents2/skills/test_navigation.py | 3 +++ 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/dimos/agents2/skills/navigation.py b/dimos/agents2/skills/navigation.py index c367a00d9f..09c6c074ba 100644 --- a/dimos/agents2/skills/navigation.py +++ b/dimos/agents2/skills/navigation.py @@ -352,12 +352,12 @@ def _navigate_using_semantic_map(self, query: str) -> str: return f"Successfuly arrived at '{query}'" - # @skill() + @skill() def follow_human(self, person: str) -> str: """Follow a specific person""" return "Not implemented yet." - # @skill() + @skill() def stop_movement(self) -> str: """Immediatly stop moving.""" @@ -438,7 +438,7 @@ def _get_goal_pose_from_result(self, result: dict[str, Any]) -> Optional[PoseSta return PoseStamped( position=make_vector3(pos_x, pos_y, 0), - orientation=euler_to_quaternion(make_vector3(0, 0, theta)), + orientation=Quaternion.from_euler(make_vector3(0, 0, theta)), frame_id="map", ) diff --git a/dimos/agents2/skills/test_navigation.py b/dimos/agents2/skills/test_navigation.py index 5d70fa2bc5..f612095f29 100644 --- a/dimos/agents2/skills/test_navigation.py +++ b/dimos/agents2/skills/test_navigation.py @@ -13,10 +13,13 @@ # limitations under the License. +import pytest + from dimos.msgs.geometry_msgs import PoseStamped, Vector3 from dimos.utils.transform_utils import euler_to_quaternion +# @pytest.mark.skip def test_stop_movement(create_navigation_agent, navigation_skill_container, mocker): navigation_skill_container._cancel_goal = mocker.Mock() navigation_skill_container._stop_exploration = mocker.Mock() From 08bea73ccbccb0e0d92956c965fdf10e4ec83403 Mon Sep 17 00:00:00 2001 From: lesh Date: Sun, 26 Oct 2025 15:53:03 +0200 Subject: [PATCH 34/40] tests pass --- dimos/navigation/rosnav.py | 19 ++++++------ dimos/navigation/test_rosnav.py | 37 ------------------------ dimos/robot/unitree/connection/go2.py | 1 - dimos/robot/unitree_webrtc/type/lidar.py | 2 +- 4 files changed, 11 insertions(+), 48 deletions(-) delete mode 100644 dimos/navigation/test_rosnav.py diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index 487ceff89f..4121bcc345 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -21,7 +21,7 @@ import logging import threading import time -from typing import Optional +from typing import Generator, Optional import rclpy from geometry_msgs.msg import PointStamped as ROSPointStamped @@ -37,6 +37,7 @@ from std_msgs.msg import Int8 as ROSInt8 from tf2_msgs.msg import TFMessage as ROSTFMessage +from dimos import spec from dimos.agents2 import Output, Reducer, Stream, skill from dimos.core import DimosCluster, In, LCMTransport, Module, Out, pSHMTransport, rpc from dimos.msgs.geometry_msgs import ( @@ -56,15 +57,15 @@ logger = setup_logger("dimos.robot.unitree_webrtc.nav_bot", level=logging.INFO) -class ROSNav(Module): - goal_req: In[PoseStamped] = None +class ROSNav(Module, spec.Nav, spec.Global3DMap, spec.Pointcloud, spec.LocalPlanner): + goal_req: In[PoseStamped] = None # type: ignore - pointcloud: Out[PointCloud2] = None - global_pointcloud: Out[PointCloud2] = None + pointcloud: Out[PointCloud2] = None # type: ignore + global_pointcloud: Out[PointCloud2] = None # type: ignore - goal_active: Out[PoseStamped] = None - path_active: Out[Path] = None - cmd_vel: Out[TwistStamped] = None + goal_active: Out[PoseStamped] = None # type: ignore + path_active: Out[Path] = None # type: ignore + cmd_vel: Out[TwistStamped] = None # type: ignore _local_pointcloud: Optional[ROSPointCloud2] = None _global_pointcloud: Optional[ROSPointCloud2] = None @@ -283,7 +284,7 @@ def goto(self, x: float, y: float): yield "arrived" @skill(stream=Stream.call_agent, reducer=Reducer.string) - def goto_global(self, x: float, y: float) -> bool: + def goto_global(self, x: float, y: float) -> Generator[str, None, None]: """ go to coordinates x,y in the map frame 0,0 is your starting position diff --git a/dimos/navigation/test_rosnav.py b/dimos/navigation/test_rosnav.py deleted file mode 100644 index 9b4b37c96e..0000000000 --- a/dimos/navigation/test_rosnav.py +++ /dev/null @@ -1,37 +0,0 @@ -# Copyright 2025 Dimensional Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Protocol - -import pytest - -from dimos.spec import Global3DMap, Nav, Pointcloud - - -class RosNavSpec(Nav, Pointcloud, Global3DMap, Protocol): - pass - - -def accepts_combined_protocol(nav: RosNavSpec) -> None: - pass - - -# this is just a typing test; no runtime behavior is tested -@pytest.mark.skip -def test_typing_prototypes(): - from dimos.navigation.rosnav import ROSNav - - rosnav = ROSNav() - accepts_combined_protocol(rosnav) - rosnav.stop() diff --git a/dimos/robot/unitree/connection/go2.py b/dimos/robot/unitree/connection/go2.py index 7baf647b08..a0a2bd7a85 100644 --- a/dimos/robot/unitree/connection/go2.py +++ b/dimos/robot/unitree/connection/go2.py @@ -33,7 +33,6 @@ from dimos.msgs.sensor_msgs import Image, PointCloud2 from dimos.msgs.std_msgs import Header from dimos.robot.unitree.connection.connection import UnitreeWebRTCConnection -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.utils.data import get_data from dimos.utils.decorators.decorators import simple_mcache from dimos.utils.logging_config import setup_logger diff --git a/dimos/robot/unitree_webrtc/type/lidar.py b/dimos/robot/unitree_webrtc/type/lidar.py index aefd9654e1..e21c7ddd00 100644 --- a/dimos/robot/unitree_webrtc/type/lidar.py +++ b/dimos/robot/unitree_webrtc/type/lidar.py @@ -65,7 +65,7 @@ def __init__(self, **kwargs): self.resolution = kwargs.get("resolution", 0.05) @classmethod - def from_msg(cls: "LidarMessage", raw_message: RawLidarMsg, **kwargs) -> "LidarMessage": + def from_msg(cls: type["LidarMessage"], raw_message: RawLidarMsg, **kwargs) -> "LidarMessage": data = raw_message["data"] points = data["data"]["points"] pointcloud = o3d.geometry.PointCloud() From bb485179eae173632a6bbf072dbbe81558a39535 Mon Sep 17 00:00:00 2001 From: lesh Date: Sun, 26 Oct 2025 16:37:37 +0200 Subject: [PATCH 35/40] standard configuration for rosnav --- dimos/msgs/geometry_msgs/Transform.py | 14 ++- dimos/navigation/rosnav.py | 151 +++++++++++++------------ dimos/perception/detection/module2D.py | 3 +- 3 files changed, 93 insertions(+), 75 deletions(-) diff --git a/dimos/msgs/geometry_msgs/Transform.py b/dimos/msgs/geometry_msgs/Transform.py index 4db4c929a7..88ee8627ae 100644 --- a/dimos/msgs/geometry_msgs/Transform.py +++ b/dimos/msgs/geometry_msgs/Transform.py @@ -21,10 +21,10 @@ from dimos_lcm.geometry_msgs import TransformStamped as LCMTransformStamped try: - from geometry_msgs.msg import TransformStamped as ROSTransformStamped + from geometry_msgs.msg import Quaternion as ROSQuaternion from geometry_msgs.msg import Transform as ROSTransform + from geometry_msgs.msg import TransformStamped as ROSTransformStamped from geometry_msgs.msg import Vector3 as ROSVector3 - from geometry_msgs.msg import Quaternion as ROSQuaternion except ImportError: ROSTransformStamped = None ROSTransform = None @@ -60,6 +60,16 @@ def __init__( self.translation = translation if translation is not None else Vector3() self.rotation = rotation if rotation is not None else Quaternion() + def now(self) -> "Transform": + """Return a copy of this Transform with the current timestamp.""" + return Transform( + translation=self.translation, + rotation=self.rotation, + frame_id=self.frame_id, + child_frame_id=self.child_frame_id, + ts=time.time(), + ) + def __repr__(self) -> str: return f"Transform(translation={self.translation!r}, rotation={self.rotation!r})" diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index 4121bcc345..d74da612d8 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -21,6 +21,7 @@ import logging import threading import time +from dataclasses import dataclass from typing import Generator, Optional import rclpy @@ -31,6 +32,8 @@ from geometry_msgs.msg import TwistStamped as ROSTwistStamped from nav_msgs.msg import Path as ROSPath from rclpy.node import Node +from reactivex import operators as ops +from reactivex.subject import Subject from sensor_msgs.msg import Joy as ROSJoy from sensor_msgs.msg import PointCloud2 as ROSPointCloud2 from std_msgs.msg import Bool as ROSBool @@ -38,8 +41,9 @@ from tf2_msgs.msg import TFMessage as ROSTFMessage from dimos import spec -from dimos.agents2 import Output, Reducer, Stream, skill +from dimos.agents2 import Reducer, Stream, skill from dimos.core import DimosCluster, In, LCMTransport, Module, Out, pSHMTransport, rpc +from dimos.core.module import ModuleConfig from dimos.msgs.geometry_msgs import ( PoseStamped, Quaternion, @@ -57,7 +61,19 @@ logger = setup_logger("dimos.robot.unitree_webrtc.nav_bot", level=logging.INFO) +@dataclass +class Config(ModuleConfig): + local_pointcloud_freq: float = 2.0 + global_pointcloud_freq: float = 1.0 + sensor_to_base_link_transform: Transform = Transform( + frame_id="sensor", child_frame_id="base_link" + ) + + class ROSNav(Module, spec.Nav, spec.Global3DMap, spec.Pointcloud, spec.LocalPlanner): + config: Config + default_config = Config + goal_req: In[PoseStamped] = None # type: ignore pointcloud: Out[PointCloud2] = None # type: ignore @@ -67,28 +83,25 @@ class ROSNav(Module, spec.Nav, spec.Global3DMap, spec.Pointcloud, spec.LocalPlan path_active: Out[Path] = None # type: ignore cmd_vel: Out[TwistStamped] = None # type: ignore - _local_pointcloud: Optional[ROSPointCloud2] = None - _global_pointcloud: Optional[ROSPointCloud2] = None + # Using RxPY Subjects for reactive data flow instead of storing state + _local_pointcloud_subject: Subject + _global_pointcloud_subject: Subject _current_position_running: bool = False + _spin_thread: Optional[threading.Thread] = None + _goal_reach: Optional[bool] = None - def __init__(self, sensor_to_base_link_transform=None, *args, **kwargs): + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) + # Initialize RxPY Subjects for streaming data + self._local_pointcloud_subject = Subject() + self._global_pointcloud_subject = Subject() + if not rclpy.ok(): rclpy.init() - self._node = Node("navigation_module") - self.goal_reach = None - self.sensor_to_base_link_transform = sensor_to_base_link_transform or [ - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - ] - self.spin_thread = None + self._node = Node("navigation_module") # ROS2 Publishers self.goal_pose_pub = self._node.create_publisher(ROSPoseStamped, "/goal_pose", 10) @@ -122,32 +135,34 @@ def __init__(self, sensor_to_base_link_transform=None, *args, **kwargs): def start(self): self._running = True - # TODO these should be rxpy streams, rxpy has a way to convert callbacks to streams - def broadcast_lidar(): - while self._running: - if self._local_pointcloud: - self.pointcloud.publish( - PointCloud2.from_ros_msg(self._local_pointcloud), - ) - time.sleep(0.5) - - def broadcast_map(): - while self._running: - if self._global_pointcloud: - self.global_pointcloud.publish( - PointCloud2.from_ros_msg(self._global_pointcloud) - ) - time.sleep(1.0) - - self.map_broadcast_thread = threading.Thread(target=broadcast_map, daemon=True) - self.lidar_broadcast_thread = threading.Thread(target=broadcast_lidar, daemon=True) - self.map_broadcast_thread.start() - self.lidar_broadcast_thread.start() - self.spin_thread = threading.Thread(target=self._spin_node, daemon=True) - self.spin_thread.start() + self._disposables.add( + self._local_pointcloud_subject.pipe( + ops.sample(1.0 / self.config.local_pointcloud_freq), # Sample at desired frequency + ops.map(lambda msg: PointCloud2.from_ros_msg(msg)), + ).subscribe( + on_next=self.pointcloud.publish, + on_error=lambda e: logger.error(f"Lidar stream error: {e}"), + ) + ) + + self._disposables.add( + self._global_pointcloud_subject.pipe( + ops.sample(1.0 / self.config.global_pointcloud_freq), # Sample at desired frequency + ops.map(lambda msg: PointCloud2.from_ros_msg(msg)), + ).subscribe( + on_next=self.global_pointcloud.publish, + on_error=lambda e: logger.error(f"Map stream error: {e}"), + ) + ) + + # Create and start the spin thread for ROS2 node spinning + self._spin_thread = threading.Thread( + target=self._spin_node, daemon=True, name="ROS2SpinThread" + ) + self._spin_thread.start() self.goal_req.subscribe(self._on_goal_pose) - logger.info("NavigationModule started with ROS2 spinning") + logger.info("NavigationModule started with ROS2 spinning and RxPY streams") def _spin_node(self): while self._running and rclpy.ok(): @@ -158,7 +173,7 @@ def _spin_node(self): logger.error(f"ROS2 spin error: {e}") def _on_ros_goal_reached(self, msg: ROSBool): - self.goal_reach = msg.data + self._goal_reach = msg.data def _on_ros_goal_waypoint(self, msg: ROSPointStamped): dimos_pose = PoseStamped( @@ -173,10 +188,10 @@ def _on_ros_cmd_vel(self, msg: ROSTwistStamped): self.cmd_vel.publish(TwistStamped.from_ros_msg(msg)) def _on_ros_registered_scan(self, msg: ROSPointCloud2): - self._local_pointcloud = msg + self._local_pointcloud_subject.on_next(msg) def _on_ros_global_pointcloud(self, msg: ROSPointCloud2): - self._global_pointcloud = msg + self._global_pointcloud_subject.on_next(msg) def _on_ros_path(self, msg: ROSPath): dimos_path = Path.from_ros_msg(msg) @@ -186,26 +201,6 @@ def _on_ros_path(self, msg: ROSPath): def _on_ros_tf(self, msg: ROSTFMessage): ros_tf = TFMessage.from_ros_msg(msg) - translation = Vector3( - self.sensor_to_base_link_transform[0], - self.sensor_to_base_link_transform[1], - self.sensor_to_base_link_transform[2], - ) - euler_angles = Vector3( - self.sensor_to_base_link_transform[3], - self.sensor_to_base_link_transform[4], - self.sensor_to_base_link_transform[5], - ) - rotation = euler_to_quaternion(euler_angles) - - sensor_to_base_link_tf = Transform( - translation=translation, - rotation=rotation, - frame_id="sensor", - child_frame_id="base_link", - ts=time.time(), - ) - map_to_world_tf = Transform( translation=Vector3(0.0, 0.0, 0.0), rotation=euler_to_quaternion(Vector3(0.0, 0.0, 0.0)), @@ -214,7 +209,11 @@ def _on_ros_tf(self, msg: ROSTFMessage): ts=time.time(), ) - self.tf.publish(sensor_to_base_link_tf, map_to_world_tf, *ros_tf.transforms) + self.tf.publish( + self.config.sensor_to_base_link_transform.now(), + map_to_world_tf, + *ros_tf.transforms, + ) def _on_goal_pose(self, msg: PoseStamped): self.navigate_to(msg) @@ -320,7 +319,7 @@ def navigate_to(self, pose: PoseStamped, timeout: float = 60.0) -> bool: f"Navigating to goal: ({pose.position.x:.2f}, {pose.position.y:.2f}, {pose.position.z:.2f} @ {pose.frame_id})" ) - self.goal_reach = None + self._goal_reach = None self._set_autonomy_mode() # Enable soft stop (0 = enable) @@ -334,10 +333,10 @@ def navigate_to(self, pose: PoseStamped, timeout: float = 60.0) -> bool: # Wait for goal to be reached start_time = time.time() while time.time() - start_time < timeout: - if self.goal_reach is not None: + if self._goal_reach is not None: soft_stop_msg.data = 2 self.soft_stop_pub.publish(soft_stop_msg) - return self.goal_reach + return self._goal_reach time.sleep(0.1) self.stop_navigation() @@ -365,21 +364,31 @@ def stop_navigation(self) -> bool: @rpc def stop(self): + """Stop the navigation module and clean up resources.""" + self.stop_navigation() try: self._running = False - if self.spin_thread: - self.spin_thread.join(timeout=1) - self._node.destroy_node() + + self._local_pointcloud_subject.on_completed() + self._global_pointcloud_subject.on_completed() + + if self._spin_thread and self._spin_thread.is_alive(): + self._spin_thread.join(timeout=1.0) + + if hasattr(self, "_node") and self._node: + self._node.destroy_node() + except Exception as e: logger.error(f"Error during shutdown: {e}") + finally: + super().stop() def deploy(dimos: DimosCluster): nav = dimos.deploy(ROSNav) + nav.pointcloud.transport = pSHMTransport("/lidar") nav.global_pointcloud.transport = pSHMTransport("/map") - # nav.pointcloud.transport = LCMTransport("/lidar", PointCloud2) - # nav.global_pointcloud.transport = LCMTransport("/map", PointCloud2) nav.goal_req.transport = LCMTransport("/goal_req", PoseStamped) nav.goal_req.transport = LCMTransport("/goal_req", PoseStamped) diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py index aec2850e3e..cc2790e0df 100644 --- a/dimos/perception/detection/module2D.py +++ b/dimos/perception/detection/module2D.py @@ -29,9 +29,8 @@ from dimos.msgs.sensor_msgs.Image import sharpness_barrier from dimos.msgs.vision_msgs import Detection2DArray from dimos.perception.detection.detectors import Detector -from dimos.perception.detection.detectors.person.yolo import YoloPersonDetector from dimos.perception.detection.detectors.yolo import Yolo2DDetector -from dimos.perception.detection.type import Detection2D, Filter2D, ImageDetections2D +from dimos.perception.detection.type import Filter2D, ImageDetections2D from dimos.utils.decorators.decorators import simple_mcache from dimos.utils.reactive import backpressure From 1781f64d70d048f6e1f2a78b2c90978c95142bf9 Mon Sep 17 00:00:00 2001 From: lesh Date: Sun, 26 Oct 2025 16:39:04 +0200 Subject: [PATCH 36/40] sensor transform for G1 head --- dimos/robot/unitree/g1/g1zed.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dimos/robot/unitree/g1/g1zed.py b/dimos/robot/unitree/g1/g1zed.py index 6818ddd83e..c4fb9d90c1 100644 --- a/dimos/robot/unitree/g1/g1zed.py +++ b/dimos/robot/unitree/g1/g1zed.py @@ -70,7 +70,12 @@ def deploy_g1_monozed(dimos: DimosCluster) -> CameraModule: def deploy(dimos: DimosCluster, ip: str): - nav = rosnav.deploy(dimos) + nav = rosnav.deploy( + dimos, + sensor_to_base_link_transform=Transform( + frame_id="sensor", child_frame_id="base_link", translation=Vector3(0.0, 0.0, 1.5) + ), + ) connection = g1.deploy(dimos, ip, nav) zedcam = deploy_g1_monozed(dimos) From 0c3aa3023d290aad4c5471546c5298cd32de78e2 Mon Sep 17 00:00:00 2001 From: lesh Date: Sun, 26 Oct 2025 18:04:35 +0200 Subject: [PATCH 37/40] fixing mujoco, twiststamped --- dimos/perception/detection/moduleDB.py | 20 ----------- dimos/robot/unitree/connection/connection.py | 8 ++--- dimos/robot/unitree/connection/go2.py | 15 ++++---- dimos/robot/unitree/go2/go2.py | 15 ++++---- dimos/robot/unitree_webrtc/type/map.py | 13 ++++++- .../web/websocket_vis/websocket_vis_module.py | 36 ++++++++++++------- 6 files changed, 54 insertions(+), 53 deletions(-) diff --git a/dimos/perception/detection/moduleDB.py b/dimos/perception/detection/moduleDB.py index df0ecda0d7..620d15cec3 100644 --- a/dimos/perception/detection/moduleDB.py +++ b/dimos/perception/detection/moduleDB.py @@ -168,7 +168,6 @@ def update_objects(imageDetections: ImageDetections3DPC): def scene_thread(): while True: - print(self) scene_update = self.to_foxglove_scene_update() self.scene_update.publish(scene_update) time.sleep(1.0) @@ -268,25 +267,6 @@ def lookup(self, label: str) -> List[Detection3DPC]: """Look up a detection by label.""" return [] - @rpc - def start(self): - Detection3DModule.start(self) - - def update_objects(imageDetections: ImageDetections3DPC): - for detection in imageDetections.detections: - self.add_detection(detection) - - def scene_thread(): - while True: - print(self) - scene_update = self.to_foxglove_scene_update() - self.scene_update.publish(scene_update) - time.sleep(1.0) - - threading.Thread(target=scene_thread, daemon=True).start() - - self.detection_stream_3d.subscribe(update_objects) - @rpc def stop(self): return super().stop() diff --git a/dimos/robot/unitree/connection/connection.py b/dimos/robot/unitree/connection/connection.py index 327e4b8410..fc9714c8ba 100644 --- a/dimos/robot/unitree/connection/connection.py +++ b/dimos/robot/unitree/connection/connection.py @@ -17,7 +17,7 @@ import threading import time from dataclasses import dataclass -from typing import Literal, Optional, Type, TypeAlias +from typing import Optional, TypeAlias import numpy as np from aiortc import MediaStreamTrack @@ -33,7 +33,7 @@ from dimos.core import rpc from dimos.core.resource import Resource -from dimos.msgs.geometry_msgs import Pose, Transform, Twist +from dimos.msgs.geometry_msgs import Pose, Transform, TwistStamped from dimos.msgs.sensor_msgs import Image from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg @@ -127,7 +127,7 @@ def stop(self) -> None: async def async_disconnect() -> None: try: - self.move(Twist()) + self.move(TwistStamped()) await self.conn.disconnect() except Exception: pass @@ -140,7 +140,7 @@ async def async_disconnect() -> None: if self.thread.is_alive(): self.thread.join(timeout=2.0) - def move(self, twist: Twist, duration: float = 0.0) -> bool: + def move(self, twist: TwistStamped, duration: float = 0.0) -> bool: """Send movement command to the robot using Twist commands. Args: diff --git a/dimos/robot/unitree/connection/go2.py b/dimos/robot/unitree/connection/go2.py index a0a2bd7a85..a55d8a8bdd 100644 --- a/dimos/robot/unitree/connection/go2.py +++ b/dimos/robot/unitree/connection/go2.py @@ -26,7 +26,6 @@ PoseStamped, Quaternion, Transform, - Twist, TwistStamped, Vector3, ) @@ -49,7 +48,7 @@ def stop(self) -> None: ... def lidar_stream(self) -> Observable: ... def odom_stream(self) -> Observable: ... def video_stream(self) -> Observable: ... - def move(self, twist: Twist, duration: float = 0.0) -> bool: ... + def move(self, twist: TwistStamped, duration: float = 0.0) -> bool: ... def standup(self) -> None: ... def liedown(self) -> None: ... def publish_request(self, topic: str, data: dict) -> dict: ... @@ -137,7 +136,7 @@ def video_stream(self): return video_store.stream(**self.replay_config) - def move(self, vector: Twist, duration: float = 0.0): + def move(self, twist: TwistStamped, duration: float = 0.0): pass def publish_request(self, topic: str, data: dict): @@ -146,7 +145,7 @@ def publish_request(self, topic: str, data: dict): class GO2Connection(Module, spec.Camera, spec.Pointcloud): - cmd_vel: In[Twist] = None # type: ignore + cmd_vel: In[TwistStamped] = None # type: ignore pointcloud: Out[PointCloud2] = None # type: ignore image: Out[Image] = None # type: ignore camera_info_stream: Out[CameraInfo] = None # type: ignore @@ -252,11 +251,11 @@ def _publish_tf(self, msg): def publish_camera_info(self): while True: - self.camera_info.publish(camera_info) + self.camera_info_stream.publish(camera_info) time.sleep(1.0) @rpc - def move(self, twist: Twist, duration: float = 0.0): + def move(self, twist: TwistStamped, duration: float = 0.0): """Send movement command to robot.""" self.connection.move(twist, duration) @@ -294,9 +293,9 @@ def deploy(dimos: DimosCluster, ip: str, prefix="") -> GO2Connection: f"{prefix}/image", default_capacity=DEFAULT_CAPACITY_COLOR_IMAGE ) - # connection.cmd_vel.transport = LCMTransport(f"{prefix}/cmd_vel", TwistStamped) + connection.cmd_vel.transport = LCMTransport(f"{prefix}/cmd_vel", TwistStamped) - connection.camera_info.transport = LCMTransport(f"{prefix}/camera_info", CameraInfo) + connection.camera_info_stream.transport = LCMTransport(f"{prefix}/camera_info", CameraInfo) connection.start() return connection diff --git a/dimos/robot/unitree/go2/go2.py b/dimos/robot/unitree/go2/go2.py index 0712a933df..05c05e7a8e 100644 --- a/dimos/robot/unitree/go2/go2.py +++ b/dimos/robot/unitree/go2/go2.py @@ -28,11 +28,12 @@ def deploy(dimos: DimosCluster, ip: str): connection = go2.deploy(dimos, ip) foxglove_bridge.deploy(dimos) - detector = moduleDB.deploy( - dimos, - camera=connection, - lidar=connection, - ) + # detector = moduleDB.deploy( + # dimos, + # camera=connection, + # lidar=connection, + # ) - agent = agents2.deploy(dimos) - agent.register_skills(detector) + # agent = agents2.deploy(dimos) + # agent.register_skills(detector) + return connection diff --git a/dimos/robot/unitree_webrtc/type/map.py b/dimos/robot/unitree_webrtc/type/map.py index 068048fb8b..61eaa83d0f 100644 --- a/dimos/robot/unitree_webrtc/type/map.py +++ b/dimos/robot/unitree_webrtc/type/map.py @@ -20,10 +20,11 @@ from reactivex import interval from reactivex.disposable import Disposable -from dimos.core import In, Module, Out, rpc +from dimos.core import DimosCluster, In, LCMTransport, Module, Out, rpc from dimos.core.global_config import GlobalConfig from dimos.msgs.nav_msgs import OccupancyGrid from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.robot.unitree.connection.go2 import Go2ConnectionProtocol from dimos.robot.unitree_webrtc.type.lidar import LidarMessage @@ -171,4 +172,14 @@ def splice_cylinder( mapper = Map.blueprint +def deploy(dimos: DimosCluster, connection: Go2ConnectionProtocol): + mapper = dimos.deploy(Map, global_publish_interval=1.0) + mapper.global_map.transport = LCMTransport("/global_map", LidarMessage) + mapper.global_costmap.transport = LCMTransport("/global_costmap", OccupancyGrid) + mapper.local_costmap.transport = LCMTransport("/local_costmap", OccupancyGrid) + mapper.lidar.connect(connection.pointcloud) + mapper.start() + return mapper + + __all__ = ["Map", "mapper"] diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index b33b874ecc..af1cb3bdd5 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -19,25 +19,26 @@ """ import asyncio +import base64 import threading import time from typing import Any, Dict, Optional -import base64 -import numpy as np +import numpy as np import socketio import uvicorn +from dimos_lcm.std_msgs import Bool +from reactivex.disposable import Disposable from starlette.applications import Starlette from starlette.responses import HTMLResponse from starlette.routing import Route -from dimos.core import Module, In, Out, rpc -from dimos_lcm.std_msgs import Bool +from dimos.core import In, Module, Out, rpc from dimos.mapping.types import LatLon from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped, Vector3 from dimos.msgs.nav_msgs import OccupancyGrid, Path from dimos.utils.logging_config import setup_logger -from reactivex.disposable import Disposable + from .optimized_costmap import OptimizedCostmapEncoder logger = setup_logger("dimos.web.websocket_vis") @@ -124,14 +125,23 @@ def start(self): self._uvicorn_server_thread = threading.Thread(target=self._run_uvicorn_server, daemon=True) self._uvicorn_server_thread.start() - unsub = self.odom.subscribe(self._on_robot_pose) - self._disposables.add(Disposable(unsub)) - - unsub = self.gps_location.subscribe(self._on_gps_location) - self._disposables.add(Disposable(unsub)) - - unsub = self.path.subscribe(self._on_path) - self._disposables.add(Disposable(unsub)) + try: + unsub = self.odom.subscribe(self._on_robot_pose) + self._disposables.add(Disposable(unsub)) + except Exception as e: + ... + + try: + unsub = self.gps_location.subscribe(self._on_gps_location) + self._disposables.add(Disposable(unsub)) + except Exception as e: + ... + + try: + unsub = self.path.subscribe(self._on_path) + self._disposables.add(Disposable(unsub)) + except Exception as e: + ... unsub = self.global_costmap.subscribe(self._on_global_costmap) self._disposables.add(Disposable(unsub)) From e22c0db89f03ac89efdb2619b0a3d7cccc87ed9a Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Tue, 28 Oct 2025 23:49:56 +0200 Subject: [PATCH 38/40] redo ruff/mypy chages --- .pre-commit-config.yaml | 7 +- bin/filter-errors-after-date | 5 +- bin/filter-errors-for-user | 3 +- dimos/agents/agent.py | 98 ++++++------ dimos/agents/agent_config.py | 10 +- dimos/agents/agent_ctransformers_gguf.py | 41 ++--- dimos/agents/agent_huggingface_local.py | 41 ++--- dimos/agents/agent_huggingface_remote.py | 39 ++--- dimos/agents/agent_message.py | 15 +- dimos/agents/agent_types.py | 42 +++-- dimos/agents/cerebras_agent.py | 77 +++++----- dimos/agents/claude_agent.py | 67 ++++---- dimos/agents/memory/base.py | 7 +- dimos/agents/memory/chroma_impl.py | 33 ++-- dimos/agents/memory/image_embedding.py | 7 +- dimos/agents/memory/spatial_vector_db.py | 23 +-- dimos/agents/memory/test_image_embedding.py | 15 +- dimos/agents/memory/visual_memory.py | 20 +-- dimos/agents/modules/agent_pool.py | 32 ++-- dimos/agents/modules/base.py | 40 ++--- dimos/agents/modules/base_agent.py | 36 ++--- dimos/agents/modules/gateway/client.py | 45 ++++-- .../modules/gateway/tensorzero_embedded.py | 31 ++-- .../modules/gateway/tensorzero_simple.py | 4 +- dimos/agents/modules/gateway/utils.py | 11 +- dimos/agents/modules/simple_vision_agent.py | 27 ++-- dimos/agents/planning_agent.py | 29 ++-- dimos/agents/prompt_builder/impl.py | 21 +-- dimos/agents/test_agent_image_message.py | 16 +- dimos/agents/test_agent_message_streams.py | 43 +++--- dimos/agents/test_agent_pool.py | 33 ++-- dimos/agents/test_agent_tools.py | 49 +++--- dimos/agents/test_agent_with_modules.py | 22 ++- dimos/agents/test_base_agent_text.py | 62 ++++---- dimos/agents/test_conversation_history.py | 41 ++--- dimos/agents/test_gateway.py | 12 +- dimos/agents/test_simple_agent_module.py | 33 ++-- dimos/agents/tokenizer/base.py | 6 +- .../agents/tokenizer/huggingface_tokenizer.py | 13 +- dimos/agents/tokenizer/openai_tokenizer.py | 13 +- dimos/agents2/agent.py | 42 ++--- dimos/agents2/cli/human.py | 3 +- dimos/agents2/conftest.py | 5 +- dimos/agents2/constants.py | 1 - .../skills/google_maps_skill_container.py | 12 +- dimos/agents2/skills/gps_nav_skill.py | 12 +- dimos/agents2/skills/navigation.py | 19 ++- dimos/agents2/skills/osm.py | 15 +- dimos/agents2/skills/ros_navigation.py | 11 +- .../test_google_maps_skill_container.py | 7 +- dimos/agents2/skills/test_gps_nav_skills.py | 4 +- dimos/agents2/skills/test_navigation.py | 10 +- dimos/agents2/spec.py | 18 +-- dimos/agents2/system_prompt.py | 2 +- dimos/agents2/temp/run_unitree_agents2.py | 10 +- dimos/agents2/temp/run_unitree_async.py | 15 +- .../agents2/temp/test_unitree_agent_query.py | 11 +- .../temp/test_unitree_skill_container.py | 4 +- dimos/agents2/temp/webcam_agent.py | 13 +- dimos/agents2/test_agent.py | 2 +- dimos/agents2/test_agent_direct.py | 2 +- dimos/agents2/test_agent_fake.py | 6 +- dimos/agents2/test_mock_agent.py | 8 +- dimos/agents2/test_stash_agent.py | 3 +- dimos/agents2/testing.py | 47 +++--- dimos/core/__init__.py | 26 ++-- dimos/core/blueprints.py | 16 +- dimos/core/core.py | 8 +- dimos/core/global_config.py | 1 + dimos/core/module.py | 46 +++--- dimos/core/module_coordinator.py | 18 +-- dimos/core/o3dpickle.py | 2 +- dimos/core/rpc_client.py | 15 +- dimos/core/skill_module.py | 6 +- dimos/core/stream.py | 42 ++--- dimos/core/test_blueprints.py | 12 +- dimos/core/test_core.py | 14 +- dimos/core/test_modules.py | 43 +++--- dimos/core/test_rpcstress.py | 14 +- dimos/core/test_stream.py | 44 +++--- dimos/core/testing.py | 12 +- dimos/core/transport.py | 45 +++--- dimos/environment/agent_environment.py | 28 ++-- dimos/environment/colmap_environment.py | 16 +- dimos/environment/environment.py | 16 +- dimos/exceptions/agent_memory_exceptions.py | 18 ++- dimos/hardware/camera/module.py | 37 ++--- dimos/hardware/camera/spec.py | 4 +- dimos/hardware/camera/test_webcam.py | 4 +- dimos/hardware/camera/webcam.py | 22 +-- dimos/hardware/camera/zed/__init__.py | 9 +- dimos/hardware/camera/zed/camera.py | 60 ++++---- dimos/hardware/camera/zed/test_zed.py | 2 +- dimos/hardware/end_effector.py | 2 +- dimos/hardware/fake_zed_module.py | 16 +- dimos/hardware/gstreamer_camera.py | 20 +-- .../hardware/gstreamer_camera_test_script.py | 8 +- dimos/hardware/gstreamer_sender.py | 12 +- dimos/hardware/piper_arm.py | 102 ++++++------- dimos/hardware/sensor.py | 2 +- dimos/hardware/ufactory.py | 4 +- dimos/manipulation/manip_aio_pipeline.py | 70 +++++---- dimos/manipulation/manip_aio_processer.py | 46 +++--- dimos/manipulation/manipulation_history.py | 33 ++-- dimos/manipulation/manipulation_interface.py | 60 ++++---- .../manipulation/test_manipulation_history.py | 47 +++--- .../visual_servoing/detection3d.py | 56 +++---- .../visual_servoing/manipulation_module.py | 103 ++++++------- dimos/manipulation/visual_servoing/pbvs.py | 37 ++--- dimos/manipulation/visual_servoing/utils.py | 55 +++---- dimos/mapping/google_maps/conftest.py | 2 +- dimos/mapping/google_maps/google_maps.py | 29 ++-- dimos/mapping/google_maps/test_google_maps.py | 8 +- dimos/mapping/google_maps/types.py | 26 ++-- dimos/mapping/osm/current_location_map.py | 12 +- dimos/mapping/osm/demo_osm.py | 6 +- dimos/mapping/osm/osm.py | 14 +- dimos/mapping/osm/query.py | 6 +- dimos/mapping/osm/test_osm.py | 10 +- dimos/mapping/types.py | 4 +- .../models/Detic/configs/BoxSup_ViLD_200e.py | 23 ++- dimos/models/Detic/configs/Detic_ViLD_200e.py | 33 ++-- dimos/models/Detic/demo.py | 21 ++- dimos/models/Detic/detic/__init__.py | 15 +- dimos/models/Detic/detic/config.py | 2 +- dimos/models/Detic/detic/custom_solver.py | 11 +- .../detic/data/custom_build_augmentation.py | 4 +- .../detic/data/custom_dataset_dataloader.py | 92 +++++------ .../Detic/detic/data/custom_dataset_mapper.py | 49 +++--- dimos/models/Detic/detic/data/datasets/cc.py | 1 + .../detic/data/datasets/coco_zeroshot.py | 3 +- .../Detic/detic/data/datasets/imagenet.py | 3 +- .../data/datasets/lvis_22k_categories.py | 2 +- .../Detic/detic/data/datasets/lvis_v1.py | 20 ++- .../Detic/detic/data/datasets/objects365.py | 3 +- dimos/models/Detic/detic/data/datasets/oid.py | 5 +- .../Detic/detic/data/datasets/register_oid.py | 17 +-- dimos/models/Detic/detic/data/tar_dataset.py | 27 ++-- .../transforms/custom_augmentation_impl.py | 5 +- .../detic/data/transforms/custom_transform.py | 11 +- .../detic/evaluation/custom_coco_eval.py | 30 ++-- .../models/Detic/detic/evaluation/oideval.py | 91 ++++++----- .../modeling/backbone/swintransformer.py | 143 ++++++++--------- .../Detic/detic/modeling/backbone/timm.py | 43 +++--- dimos/models/Detic/detic/modeling/debug.py | 64 ++++---- .../detic/modeling/meta_arch/custom_rcnn.py | 51 ++++--- .../modeling/meta_arch/d2_deformable_detr.py | 30 ++-- .../modeling/roi_heads/detic_fast_rcnn.py | 102 ++++++------- .../modeling/roi_heads/detic_roi_heads.py | 35 ++--- .../modeling/roi_heads/res5_roi_heads.py | 22 +-- .../roi_heads/zero_shot_classifier.py | 6 +- .../Detic/detic/modeling/text/text_encoder.py | 33 ++-- dimos/models/Detic/detic/modeling/utils.py | 11 +- dimos/models/Detic/detic/predictor.py | 26 ++-- dimos/models/Detic/lazy_train_net.py | 6 +- dimos/models/Detic/predict.py | 17 ++- .../CenterNet2/centernet/__init__.py | 20 ++- .../CenterNet2/centernet/config.py | 2 +- .../data/custom_build_augmentation.py | 3 +- .../data/custom_dataset_dataloader.py | 61 ++++---- .../centernet/data/datasets/coco.py | 8 +- .../centernet/data/datasets/nuimages.py | 3 +- .../centernet/data/datasets/objects365.py | 3 +- .../transforms/custom_augmentation_impl.py | 5 +- .../data/transforms/custom_transform.py | 11 +- .../centernet/modeling/backbone/bifpn.py | 144 +++++++++--------- .../centernet/modeling/backbone/bifpn_fcos.py | 53 ++++--- .../centernet/modeling/backbone/dla.py | 99 ++++++------ .../centernet/modeling/backbone/dlafpn.py | 107 +++++++------ .../centernet/modeling/backbone/fpn_p5.py | 12 +- .../centernet/modeling/backbone/res2net.py | 76 ++++----- .../CenterNet2/centernet/modeling/debug.py | 44 +++--- .../modeling/dense_heads/centernet.py | 118 +++++++------- .../modeling/dense_heads/centernet_head.py | 35 ++--- .../centernet/modeling/dense_heads/utils.py | 6 +- .../centernet/modeling/layers/deform_conv.py | 27 ++-- .../modeling/layers/heatmap_focal_loss.py | 4 +- .../centernet/modeling/layers/iou_loss.py | 6 +- .../centernet/modeling/layers/ml_nms.py | 2 +- .../modeling/meta_arch/centernet_detector.py | 14 +- .../modeling/roi_heads/custom_fast_rcnn.py | 18 ++- .../modeling/roi_heads/custom_roi_heads.py | 59 +++---- .../centernet/modeling/roi_heads/fed_loss.py | 9 +- .../Detic/third_party/CenterNet2/demo.py | 9 +- .../Detic/third_party/CenterNet2/predictor.py | 24 +-- .../CenterNet2/tools/analyze_model.py | 23 ++- .../third_party/CenterNet2/tools/benchmark.py | 30 ++-- .../tools/convert-torchvision-to-d2.py | 5 +- .../CenterNet2/tools/deploy/export_model.py | 13 +- .../CenterNet2/tools/lazyconfig_train_net.py | 6 +- .../CenterNet2/tools/lightning_train_net.py | 35 ++--- .../CenterNet2/tools/plain_train_net.py | 24 +-- .../third_party/CenterNet2/tools/train_net.py | 12 +- .../CenterNet2/tools/visualize_data.py | 18 ++- .../tools/visualize_json_results.py | 12 +- .../Detic/third_party/CenterNet2/train_net.py | 46 +++--- .../third_party/Deformable-DETR/benchmark.py | 9 +- .../Deformable-DETR/datasets/__init__.py | 2 +- .../Deformable-DETR/datasets/coco.py | 27 ++-- .../Deformable-DETR/datasets/coco_eval.py | 29 ++-- .../Deformable-DETR/datasets/coco_panoptic.py | 17 +-- .../datasets/data_prefetcher.py | 6 +- .../Deformable-DETR/datasets/panoptic_eval.py | 8 +- .../Deformable-DETR/datasets/samplers.py | 28 ++-- .../datasets/torchvision_datasets/coco.py | 23 +-- .../Deformable-DETR/datasets/transforms.py | 64 ++++---- .../third_party/Deformable-DETR/engine.py | 14 +- .../Detic/third_party/Deformable-DETR/main.py | 26 ++-- .../Deformable-DETR/models/backbone.py | 31 ++-- .../Deformable-DETR/models/deformable_detr.py | 57 +++---- .../models/deformable_transformer.py | 75 +++++---- .../Deformable-DETR/models/matcher.py | 5 +- .../ops/functions/ms_deform_attn_func.py | 8 +- .../models/ops/modules/ms_deform_attn.py | 23 ++- .../Deformable-DETR/models/ops/setup.py | 11 +- .../Deformable-DETR/models/ops/test.py | 15 +- .../models/position_encoding.py | 8 +- .../Deformable-DETR/models/segmentation.py | 34 ++--- .../Deformable-DETR/tools/launch.py | 6 +- .../third_party/Deformable-DETR/util/misc.py | 98 ++++++------ .../Deformable-DETR/util/plot_utils.py | 22 +-- ...nvert-thirdparty-pretrained-model-to-d2.py | 1 + .../Detic/tools/create_imagenetlvis_json.py | 9 +- dimos/models/Detic/tools/create_lvis_21k.py | 10 +- dimos/models/Detic/tools/download_cc.py | 15 +- .../models/Detic/tools/dump_clip_features.py | 19 +-- dimos/models/Detic/tools/fix_o365_names.py | 6 +- dimos/models/Detic/tools/fix_o365_path.py | 5 +- dimos/models/Detic/tools/get_cc_tags.py | 5 +- .../Detic/tools/get_coco_zeroshot_oriorder.py | 4 +- .../tools/get_imagenet_21k_full_tar_json.py | 7 +- dimos/models/Detic/tools/get_lvis_cat_info.py | 2 +- dimos/models/Detic/tools/merge_lvis_coco.py | 6 +- .../Detic/tools/preprocess_imagenet22k.py | 21 +-- dimos/models/Detic/tools/remove_lvis_rare.py | 2 +- .../models/Detic/tools/unzip_imagenet_lvis.py | 2 +- dimos/models/Detic/train_net.py | 52 +++---- dimos/models/depth/metric3d.py | 16 +- dimos/models/embedding/base.py | 16 +- dimos/models/embedding/clip.py | 7 +- .../embedding_models_disabled_tests.py | 32 ++-- dimos/models/embedding/mobileclip.py | 4 +- dimos/models/embedding/treid.py | 2 +- dimos/models/labels/llava-34b.py | 11 +- .../contact_graspnet_pytorch/inference.py | 74 ++++----- .../test_contact_graspnet.py | 33 ++-- dimos/models/pointcloud/pointcloud_utils.py | 9 +- dimos/models/qwen/video_query.py | 16 +- dimos/models/segmentation/clipseg.py | 2 +- dimos/models/segmentation/sam.py | 4 +- dimos/models/segmentation/segment_utils.py | 8 +- dimos/models/vl/base.py | 2 +- dimos/models/vl/moondream.py | 9 +- dimos/models/vl/qwen.py | 7 +- dimos/models/vl/test_base.py | 4 +- dimos/models/vl/test_models.py | 9 +- dimos/msgs/foxglove_msgs/Color.py | 1 + dimos/msgs/geometry_msgs/Pose.py | 21 +-- dimos/msgs/geometry_msgs/PoseStamped.py | 8 +- .../msgs/geometry_msgs/PoseWithCovariance.py | 18 ++- .../PoseWithCovarianceStamped.py | 4 +- dimos/msgs/geometry_msgs/Quaternion.py | 10 +- dimos/msgs/geometry_msgs/Transform.py | 34 +++-- dimos/msgs/geometry_msgs/Twist.py | 17 +-- dimos/msgs/geometry_msgs/TwistStamped.py | 8 +- .../msgs/geometry_msgs/TwistWithCovariance.py | 10 +- .../TwistWithCovarianceStamped.py | 4 +- dimos/msgs/geometry_msgs/Vector3.py | 14 +- dimos/msgs/geometry_msgs/test_Pose.py | 92 ++++++----- dimos/msgs/geometry_msgs/test_PoseStamped.py | 10 +- .../geometry_msgs/test_PoseWithCovariance.py | 58 +++---- .../test_PoseWithCovarianceStamped.py | 47 +++--- dimos/msgs/geometry_msgs/test_Quaternion.py | 38 ++--- dimos/msgs/geometry_msgs/test_Transform.py | 34 ++--- dimos/msgs/geometry_msgs/test_Twist.py | 29 ++-- dimos/msgs/geometry_msgs/test_TwistStamped.py | 12 +- .../geometry_msgs/test_TwistWithCovariance.py | 56 +++---- .../test_TwistWithCovarianceStamped.py | 47 +++--- dimos/msgs/geometry_msgs/test_Vector3.py | 88 +++++------ dimos/msgs/geometry_msgs/test_publish.py | 8 +- dimos/msgs/nav_msgs/OccupancyGrid.py | 37 +++-- dimos/msgs/nav_msgs/Odometry.py | 12 +- dimos/msgs/nav_msgs/Path.py | 46 +++--- dimos/msgs/nav_msgs/__init__.py | 4 +- dimos/msgs/nav_msgs/test_OccupancyGrid.py | 40 ++--- dimos/msgs/nav_msgs/test_Odometry.py | 58 +++---- dimos/msgs/nav_msgs/test_Path.py | 50 +++--- dimos/msgs/sensor_msgs/CameraInfo.py | 31 ++-- dimos/msgs/sensor_msgs/Image.py | 68 +++++---- dimos/msgs/sensor_msgs/Joy.py | 18 +-- dimos/msgs/sensor_msgs/PointCloud2.py | 32 ++-- dimos/msgs/sensor_msgs/__init__.py | 4 +- .../sensor_msgs/image_impls/AbstractImage.py | 14 +- .../msgs/sensor_msgs/image_impls/CudaImage.py | 62 ++++---- .../sensor_msgs/image_impls/NumpyImage.py | 27 ++-- .../image_impls/test_image_backend_utils.py | 38 +++-- .../image_impls/test_image_backends.py | 43 +++--- dimos/msgs/sensor_msgs/test_CameraInfo.py | 17 +-- dimos/msgs/sensor_msgs/test_Joy.py | 16 +- dimos/msgs/sensor_msgs/test_PointCloud2.py | 42 +++-- dimos/msgs/sensor_msgs/test_image.py | 18 +-- dimos/msgs/std_msgs/Bool.py | 4 +- dimos/msgs/std_msgs/Header.py | 5 +- dimos/msgs/std_msgs/Int32.py | 3 +- dimos/msgs/std_msgs/Int8.py | 3 +- dimos/msgs/std_msgs/__init__.py | 4 +- dimos/msgs/std_msgs/test_header.py | 10 +- dimos/msgs/tf2_msgs/TFMessage.py | 17 +-- dimos/msgs/tf2_msgs/test_TFMessage.py | 18 +-- dimos/msgs/tf2_msgs/test_TFMessage_lcmpub.py | 5 +- dimos/navigation/bbox_navigation.py | 18 ++- .../navigation/bt_navigator/goal_validator.py | 15 +- dimos/navigation/bt_navigator/navigator.py | 40 ++--- .../bt_navigator/recovery_server.py | 4 +- .../test_wavefront_frontier_goal_selector.py | 26 ++-- .../navigation/frontier_exploration/utils.py | 13 +- .../wavefront_frontier_goal_selector.py | 63 ++++---- dimos/navigation/global_planner/__init__.py | 4 +- dimos/navigation/global_planner/algo.py | 6 +- dimos/navigation/global_planner/planner.py | 21 +-- dimos/navigation/local_planner/__init__.py | 2 +- .../local_planner/holonomic_local_planner.py | 10 +- .../navigation/local_planner/local_planner.py | 40 ++--- .../local_planner/test_base_local_planner.py | 26 ++-- dimos/navigation/rosnav.py | 53 ++++--- dimos/navigation/visual/query.py | 3 +- dimos/perception/common/__init__.py | 2 +- .../perception/common/detection2d_tracker.py | 49 +++--- dimos/perception/common/export_tensorrt.py | 3 +- dimos/perception/common/ibvs.py | 10 +- dimos/perception/common/utils.py | 86 ++++++----- dimos/perception/detection/conftest.py | 19 +-- dimos/perception/detection/detectors/detic.py | 26 ++-- .../detectors/person/test_person_detectors.py | 16 +- .../detection/detectors/person/yolo.py | 9 +- .../detectors/test_bbox_detectors.py | 18 +-- dimos/perception/detection/detectors/yolo.py | 9 +- dimos/perception/detection/module2D.py | 15 +- dimos/perception/detection/module3D.py | 10 +- dimos/perception/detection/moduleDB.py | 49 +++--- dimos/perception/detection/person_tracker.py | 11 +- dimos/perception/detection/reid/__init__.py | 6 +- .../detection/reid/embedding_id_system.py | 17 ++- dimos/perception/detection/reid/module.py | 8 +- .../reid/test_embedding_id_system.py | 24 +-- .../perception/detection/reid/test_module.py | 2 +- dimos/perception/detection/test_moduleDB.py | 4 +- dimos/perception/detection/type/__init__.py | 10 +- .../detection/type/detection2d/__init__.py | 2 +- .../detection/type/detection2d/base.py | 6 +- .../detection/type/detection2d/bbox.py | 36 +++-- .../type/detection2d/imageDetections2D.py | 22 +-- .../detection/type/detection2d/person.py | 36 ++--- .../detection/type/detection2d/test_bbox.py | 2 +- .../detection2d/test_imageDetections2D.py | 6 +- .../detection/type/detection2d/test_person.py | 4 +- .../detection/type/detection3d/__init__.py | 2 +- .../detection/type/detection3d/base.py | 12 +- .../detection/type/detection3d/bbox.py | 20 +-- .../type/detection3d/imageDetections3DPC.py | 2 +- .../detection/type/detection3d/pointcloud.py | 28 ++-- .../type/detection3d/pointcloud_filters.py | 22 +-- .../detection3d/test_imageDetections3DPC.py | 6 +- .../type/detection3d/test_pointcloud.py | 2 +- .../detection/type/imageDetections.py | 14 +- .../detection/type/test_detection3d.py | 8 +- .../detection/type/test_object3d.py | 13 +- dimos/perception/detection/type/utils.py | 2 +- dimos/perception/detection2d/utils.py | 23 +-- .../grasp_generation/grasp_generation.py | 23 +-- dimos/perception/grasp_generation/utils.py | 32 ++-- dimos/perception/object_detection_stream.py | 34 +++-- dimos/perception/object_tracker.py | 72 ++++----- dimos/perception/object_tracker_2d.py | 48 +++--- dimos/perception/object_tracker_3d.py | 41 +++-- dimos/perception/person_tracker.py | 37 +++-- dimos/perception/pointcloud/__init__.py | 2 +- dimos/perception/pointcloud/cuboid_fit.py | 24 +-- .../pointcloud/pointcloud_filtering.py | 41 +++-- .../pointcloud/test_pointcloud_filtering.py | 28 ++-- dimos/perception/pointcloud/utils.py | 54 +++---- dimos/perception/segmentation/__init__.py | 2 +- .../perception/segmentation/image_analyzer.py | 17 ++- dimos/perception/segmentation/sam_2d_seg.py | 40 ++--- .../segmentation/test_sam_2d_seg.py | 16 +- dimos/perception/segmentation/utils.py | 44 ++++-- dimos/perception/spatial_perception.py | 60 ++++---- dimos/perception/test_spatial_memory.py | 16 +- .../perception/test_spatial_memory_module.py | 25 ++- dimos/protocol/encode/__init__.py | 4 +- dimos/protocol/pubsub/lcmpubsub.py | 27 ++-- dimos/protocol/pubsub/memory.py | 7 +- dimos/protocol/pubsub/redispubsub.py | 25 +-- dimos/protocol/pubsub/shm/ipc_factory.py | 33 ++-- dimos/protocol/pubsub/shmpubsub.py | 43 +++--- dimos/protocol/pubsub/spec.py | 19 +-- dimos/protocol/pubsub/test_encoder.py | 24 +-- dimos/protocol/pubsub/test_lcmpubsub.py | 19 ++- dimos/protocol/pubsub/test_spec.py | 32 ++-- dimos/protocol/rpc/off_test_pubsubrpc.py | 28 ++-- dimos/protocol/rpc/pubsubrpc.py | 17 ++- dimos/protocol/rpc/spec.py | 22 ++- dimos/protocol/rpc/test_lcmrpc.py | 4 +- dimos/protocol/rpc/test_lcmrpc_timeout.py | 10 +- dimos/protocol/service/lcmservice.py | 24 +-- dimos/protocol/service/spec.py | 4 +- dimos/protocol/service/test_lcmservice.py | 45 +++--- dimos/protocol/service/test_spec.py | 12 +- dimos/protocol/skill/comms.py | 16 +- dimos/protocol/skill/coordinator.py | 48 +++--- dimos/protocol/skill/schema.py | 8 +- dimos/protocol/skill/skill.py | 12 +- dimos/protocol/skill/test_coordinator.py | 20 +-- dimos/protocol/skill/test_utils.py | 24 +-- dimos/protocol/skill/type.py | 34 ++--- dimos/protocol/tf/__init__.py | 4 +- dimos/protocol/tf/test_tf.py | 40 ++--- dimos/protocol/tf/tf.py | 51 +++---- dimos/protocol/tf/tflcmcpp.py | 18 +-- dimos/robot/agilex/piper_arm.py | 22 ++- dimos/robot/agilex/run.py | 13 +- dimos/robot/all_blueprints.py | 1 - dimos/robot/cli/dimos_robot.py | 11 +- dimos/robot/connection_interface.py | 5 +- dimos/robot/foxglove_bridge.py | 23 +-- dimos/robot/position_stream.py | 21 ++- dimos/robot/recorder.py | 17 ++- dimos/robot/robot.py | 9 +- dimos/robot/ros_bridge.py | 24 +-- dimos/robot/ros_command_queue.py | 36 ++--- dimos/robot/ros_control.py | 122 ++++++++------- dimos/robot/ros_observable_topic.py | 32 ++-- dimos/robot/ros_transform.py | 17 ++- dimos/robot/test_ros_bridge.py | 46 +++--- dimos/robot/test_ros_observable_topic.py | 30 ++-- dimos/robot/unitree/connection/connection.py | 50 +++--- dimos/robot/unitree/connection/g1.py | 9 +- dimos/robot/unitree/connection/go2.py | 32 ++-- dimos/robot/unitree/g1/g1zed.py | 6 +- dimos/robot/unitree/go2/go2.py | 2 - dimos/robot/unitree/run.py | 11 +- dimos/robot/unitree_webrtc/connection.py | 53 ++++--- dimos/robot/unitree_webrtc/depth_module.py | 27 ++-- .../unitree_webrtc/g1_joystick_module.py | 8 +- dimos/robot/unitree_webrtc/g1_run.py | 12 +- .../modular/connection_module.py | 38 +++-- dimos/robot/unitree_webrtc/modular/detect.py | 4 +- .../unitree_webrtc/modular/ivan_unitree.py | 17 +-- .../unitree_webrtc/modular/navigation.py | 2 +- .../robot/unitree_webrtc/mujoco_connection.py | 32 ++-- dimos/robot/unitree_webrtc/rosnav.py | 30 +--- .../test_unitree_go2_integration.py | 19 +-- dimos/robot/unitree_webrtc/testing/helpers.py | 8 +- dimos/robot/unitree_webrtc/testing/mock.py | 25 +-- .../robot/unitree_webrtc/testing/multimock.py | 22 ++- .../unitree_webrtc/testing/test_actors.py | 14 +- .../robot/unitree_webrtc/testing/test_mock.py | 10 +- .../unitree_webrtc/testing/test_tooling.py | 6 +- dimos/robot/unitree_webrtc/type/lidar.py | 13 +- dimos/robot/unitree_webrtc/type/lowstate.py | 14 +- dimos/robot/unitree_webrtc/type/map.py | 9 +- dimos/robot/unitree_webrtc/type/odometry.py | 3 - dimos/robot/unitree_webrtc/type/test_lidar.py | 7 +- dimos/robot/unitree_webrtc/type/test_map.py | 12 +- .../unitree_webrtc/type/test_odometry.py | 7 +- .../unitree_webrtc/type/test_timeseries.py | 12 +- dimos/robot/unitree_webrtc/type/timeseries.py | 15 +- dimos/robot/unitree_webrtc/type/vector.py | 24 +-- .../unitree_webrtc/unitree_b1/b1_command.py | 4 +- .../unitree_webrtc/unitree_b1/connection.py | 35 ++--- .../unitree_b1/joystick_module.py | 9 +- .../unitree_b1/test_connection.py | 26 ++-- .../unitree_webrtc/unitree_b1/unitree_b1.py | 24 ++- dimos/robot/unitree_webrtc/unitree_g1.py | 53 ++++--- .../unitree_g1_skill_container.py | 21 ++- dimos/robot/unitree_webrtc/unitree_go2.py | 101 ++++++------ .../unitree_webrtc/unitree_go2_blueprints.py | 32 ++-- .../unitree_webrtc/unitree_skill_container.py | 19 ++- dimos/robot/unitree_webrtc/unitree_skills.py | 34 +++-- dimos/robot/utils/robot_debugger.py | 2 +- dimos/simulation/__init__.py | 2 +- dimos/simulation/base/simulator_base.py | 7 +- dimos/simulation/base/stream_base.py | 10 +- dimos/simulation/genesis/simulator.py | 19 +-- dimos/simulation/genesis/stream.py | 23 +-- dimos/simulation/isaac/simulator.py | 11 +- dimos/simulation/isaac/stream.py | 21 +-- dimos/simulation/mujoco/depth_camera.py | 1 + dimos/simulation/mujoco/model.py | 7 +- dimos/simulation/mujoco/mujoco.py | 21 ++- dimos/simulation/mujoco/policy.py | 2 +- dimos/skills/kill_skill.py | 3 +- .../abstract_manipulation_skill.py | 10 +- .../manipulation/force_constraint_skill.py | 5 +- dimos/skills/manipulation/manipulate_skill.py | 19 +-- dimos/skills/manipulation/pick_and_place.py | 29 ++-- .../manipulation/rotation_constraint_skill.py | 13 +- .../translation_constraint_skill.py | 12 +- dimos/skills/rest/rest.py | 6 +- dimos/skills/skills.py | 40 ++--- dimos/skills/speak.py | 28 ++-- dimos/skills/unitree/unitree_speak.py | 30 ++-- dimos/skills/visual_navigation_skills.py | 16 +- dimos/spec/__init__.py | 6 +- dimos/spec/perception.py | 3 +- dimos/stream/audio/base.py | 7 +- dimos/stream/audio/node_key_recorder.py | 33 ++-- dimos/stream/audio/node_microphone.py | 22 +-- dimos/stream/audio/node_normalizer.py | 20 +-- dimos/stream/audio/node_output.py | 21 +-- dimos/stream/audio/node_simulated.py | 21 +-- dimos/stream/audio/node_volume_monitor.py | 13 +- dimos/stream/audio/pipelines.py | 6 +- dimos/stream/audio/stt/node_whisper.py | 19 ++- dimos/stream/audio/text/base.py | 1 + dimos/stream/audio/text/node_stdout.py | 4 +- dimos/stream/audio/tts/node_openai.py | 21 +-- dimos/stream/audio/tts/node_pytts.py | 5 +- dimos/stream/audio/utils.py | 2 +- dimos/stream/audio/volume.py | 5 +- dimos/stream/data_provider.py | 25 ++- dimos/stream/frame_processor.py | 21 +-- dimos/stream/ros_video_provider.py | 11 +- dimos/stream/rtsp_video_provider.py | 9 +- dimos/stream/stream_merger.py | 8 +- dimos/stream/video_operators.py | 76 ++++----- dimos/stream/video_provider.py | 11 +- dimos/stream/video_providers/unitree.py | 27 ++-- dimos/stream/videostream.py | 8 +- dimos/types/label.py | 6 +- dimos/types/manipulation.py | 56 +++---- dimos/types/robot_location.py | 18 +-- dimos/types/ros_polyfill.py | 6 +- dimos/types/sample.py | 50 +++--- dimos/types/segmentation.py | 7 +- dimos/types/test_timestamped.py | 40 ++--- dimos/types/test_vector.py | 86 +++++------ dimos/types/test_weaklist.py | 16 +- dimos/types/timestamped.py | 39 ++--- dimos/types/vector.py | 23 +-- dimos/types/weaklist.py | 7 +- dimos/utils/actor_registry.py | 11 +- dimos/utils/cli/agentspy/agentspy.py | 38 ++--- dimos/utils/cli/agentspy/demo_agentspy.py | 6 +- dimos/utils/cli/boxglove/boxglove.py | 36 ++--- dimos/utils/cli/boxglove/connection.py | 7 +- .../foxglove_bridge/run_foxglove_bridge.py | 6 +- dimos/utils/cli/human/humancli.py | 21 +-- dimos/utils/cli/lcmspy/lcmspy.py | 32 ++-- dimos/utils/cli/lcmspy/run_lcmspy.py | 26 +--- dimos/utils/cli/lcmspy/test_lcmspy.py | 20 ++- dimos/utils/cli/skillspy/demo_skillspy.py | 9 +- dimos/utils/cli/skillspy/skillspy.py | 45 +++--- dimos/utils/cli/theme.py | 2 +- dimos/utils/data.py | 15 +- dimos/utils/decorators/accumulators.py | 16 +- dimos/utils/decorators/decorators.py | 12 +- dimos/utils/decorators/test_decorators.py | 46 +++--- dimos/utils/deprecation.py | 2 +- dimos/utils/extract_frames.py | 7 +- dimos/utils/generic.py | 10 +- dimos/utils/generic_subscriber.py | 29 ++-- dimos/utils/gpu_utils.py | 1 - dimos/utils/llm_utils.py | 3 +- dimos/utils/logging_config.py | 3 +- dimos/utils/monitoring.py | 43 +++--- dimos/utils/reactive.py | 19 +-- dimos/utils/s3_utils.py | 15 +- dimos/utils/simple_controller.py | 16 +- dimos/utils/test_data.py | 6 +- dimos/utils/test_foxglove_bridge.py | 13 +- dimos/utils/test_generic.py | 1 + dimos/utils/test_llm_utils.py | 16 +- dimos/utils/test_reactive.py | 31 ++-- dimos/utils/test_testing.py | 30 ++-- dimos/utils/test_transform_utils.py | 138 ++++++++--------- dimos/utils/testing.py | 67 ++++---- dimos/utils/transform_utils.py | 5 +- dimos/web/dimos_interface/api/server.py | 53 +++---- dimos/web/edge_io.py | 4 +- dimos/web/fastapi_server.py | 39 ++--- dimos/web/flask_server.py | 17 ++- dimos/web/robot_web_interface.py | 2 +- dimos/web/websocket_vis/costmap_viz.py | 6 +- dimos/web/websocket_vis/optimized_costmap.py | 19 +-- dimos/web/websocket_vis/path_history.py | 9 +- .../web/websocket_vis/websocket_vis_module.py | 54 ++++--- pyproject.toml | 10 ++ setup.py | 2 +- tests/agent_manip_flow_fastapi_test.py | 14 +- tests/agent_manip_flow_flask_test.py | 15 +- tests/agent_memory_test.py | 4 - tests/genesissim/stream_camera.py | 2 +- tests/isaacsim/stream_camera.py | 4 +- tests/run.py | 49 +++--- tests/run_go2_ros.py | 4 +- tests/run_navigation_only.py | 25 +-- tests/simple_agent_test.py | 7 +- tests/test_agent.py | 3 - tests/test_agent_alibaba.py | 10 +- tests/test_agent_ctransformers_gguf.py | 2 - tests/test_agent_huggingface_local.py | 10 +- tests/test_agent_huggingface_local_jetson.py | 10 +- tests/test_agent_huggingface_remote.py | 9 +- tests/test_audio_agent.py | 4 +- tests/test_audio_robot_agent.py | 9 +- tests/test_cerebras_unitree_ros.py | 22 +-- tests/test_claude_agent_query.py | 3 +- tests/test_claude_agent_skills_query.py | 23 ++- tests/test_command_pose_unitree.py | 8 +- tests/test_header.py | 4 +- tests/test_huggingface_llm_agent.py | 2 - tests/test_manipulation_agent.py | 48 ++---- .../test_manipulation_perception_pipeline.py | 16 +- ...est_manipulation_perception_pipeline.py.py | 16 +- ...test_manipulation_pipeline_single_frame.py | 17 ++- ..._manipulation_pipeline_single_frame_lcm.py | 38 ++--- tests/test_move_vel_unitree.py | 7 +- tests/test_navigate_to_object_robot.py | 13 +- tests/test_navigation_skills.py | 10 +- ...bject_detection_agent_data_query_stream.py | 20 +-- tests/test_object_detection_stream.py | 19 ++- tests/test_object_tracking_module.py | 13 +- tests/test_object_tracking_webcam.py | 9 +- tests/test_object_tracking_with_qwen.py | 15 +- tests/test_person_following_robot.py | 7 +- tests/test_person_following_webcam.py | 9 +- tests/test_pick_and_place_module.py | 12 +- tests/test_pick_and_place_skill.py | 2 +- tests/test_planning_agent_web_interface.py | 6 +- tests/test_planning_robot_agent.py | 7 +- tests/test_pointcloud_filtering.py | 16 +- tests/test_qwen_image_query.py | 2 + tests/test_robot.py | 9 +- tests/test_rtsp_video_provider.py | 12 +- tests/test_semantic_seg_robot.py | 18 +-- tests/test_semantic_seg_robot_agent.py | 16 +- tests/test_semantic_seg_webcam.py | 9 +- tests/test_skills.py | 7 +- tests/test_skills_rest.py | 15 +- tests/test_spatial_memory.py | 15 +- tests/test_spatial_memory_query.py | 14 +- tests/test_standalone_chromadb.py | 5 +- tests/test_standalone_fastapi.py | 8 +- tests/test_standalone_hugging_face.py | 28 +--- tests/test_standalone_openai_json.py | 4 +- tests/test_standalone_openai_json_struct.py | 7 +- ...test_standalone_openai_json_struct_func.py | 9 +- ...lone_openai_json_struct_func_playground.py | 27 +--- tests/test_standalone_project_out.py | 8 +- tests/test_standalone_rxpy_01.py | 39 +++-- tests/test_unitree_agent.py | 3 +- tests/test_unitree_agent_queries_fastapi.py | 3 +- tests/test_unitree_ros_v0.0.4.py | 27 ++-- tests/test_webrtc_queue.py | 5 +- tests/test_websocketvis.py | 17 ++- tests/test_zed_module.py | 17 +-- tests/test_zed_setup.py | 7 +- tests/visualization_script.py | 93 ++++------- tests/zed_neural_depth_demo.py | 16 +- 660 files changed, 7082 insertions(+), 7275 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7a807e203b..67544f7f29 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,12 +19,13 @@ repos: - --use-current-year - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.11 + rev: v0.14.1 hooks: - #- id: ruff-check - # args: [--fix] - id: ruff-format stages: [pre-commit] + - id: ruff-check + args: [--fix, --unsafe-fixes] + - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.6.0 hooks: diff --git a/bin/filter-errors-after-date b/bin/filter-errors-after-date index 5a0c46408e..03c7de0ca7 100755 --- a/bin/filter-errors-after-date +++ b/bin/filter-errors-after-date @@ -3,11 +3,10 @@ # Used to filter errors to only show lines committed on or after a specific date # Can be chained with filter-errors-for-user -import sys +from datetime import datetime import re import subprocess -from datetime import datetime - +import sys _blame = {} diff --git a/bin/filter-errors-for-user b/bin/filter-errors-for-user index 78247a9bb2..045b30b293 100755 --- a/bin/filter-errors-for-user +++ b/bin/filter-errors-for-user @@ -2,10 +2,9 @@ # Used when running `./bin/mypy-strict --for-me` -import sys import re import subprocess - +import sys _blame = {} diff --git a/dimos/agents/agent.py b/dimos/agents/agent.py index 1ce2216fe7..62765ef706 100644 --- a/dimos/agents/agent.py +++ b/dimos/agents/agent.py @@ -30,29 +30,32 @@ import json import os import threading -from typing import Any, Tuple, Optional, Union +from typing import TYPE_CHECKING, Any # Third-party imports from dotenv import load_dotenv from openai import NOT_GIVEN, OpenAI from pydantic import BaseModel -from reactivex import Observer, create, Observable, empty, operators as RxOps, just +from reactivex import Observable, Observer, create, empty, just, operators as RxOps from reactivex.disposable import CompositeDisposable, Disposable -from reactivex.scheduler import ThreadPoolScheduler from reactivex.subject import Subject # Local imports -from dimos.agents.memory.base import AbstractAgentSemanticMemory from dimos.agents.memory.chroma_impl import OpenAISemanticMemory from dimos.agents.prompt_builder.impl import PromptBuilder -from dimos.agents.tokenizer.base import AbstractTokenizer from dimos.agents.tokenizer.openai_tokenizer import OpenAITokenizer from dimos.skills.skills import AbstractSkill, SkillLibrary from dimos.stream.frame_processor import FrameProcessor from dimos.stream.stream_merger import create_stream_merger from dimos.stream.video_operators import Operators as MyOps, VideoOperators as MyVidOps -from dimos.utils.threadpool import get_scheduler from dimos.utils.logging_config import setup_logger +from dimos.utils.threadpool import get_scheduler + +if TYPE_CHECKING: + from reactivex.scheduler import ThreadPoolScheduler + + from dimos.agents.memory.base import AbstractAgentSemanticMemory + from dimos.agents.tokenizer.base import AbstractTokenizer # Initialize environment variables load_dotenv() @@ -75,9 +78,9 @@ def __init__( self, dev_name: str = "NA", agent_type: str = "Base", - agent_memory: Optional[AbstractAgentSemanticMemory] = None, - pool_scheduler: Optional[ThreadPoolScheduler] = None, - ): + agent_memory: AbstractAgentSemanticMemory | None = None, + pool_scheduler: ThreadPoolScheduler | None = None, + ) -> None: """ Initializes a new instance of the Agent. @@ -94,7 +97,7 @@ def __init__( self.disposables = CompositeDisposable() self.pool_scheduler = pool_scheduler if pool_scheduler else get_scheduler() - def dispose_all(self): + def dispose_all(self) -> None: """Disposes of all active subscriptions managed by this agent.""" if self.disposables: self.disposables.dispose() @@ -145,16 +148,16 @@ def __init__( self, dev_name: str = "NA", agent_type: str = "LLM", - agent_memory: Optional[AbstractAgentSemanticMemory] = None, - pool_scheduler: Optional[ThreadPoolScheduler] = None, + agent_memory: AbstractAgentSemanticMemory | None = None, + pool_scheduler: ThreadPoolScheduler | None = None, process_all_inputs: bool = False, - system_query: Optional[str] = None, + system_query: str | None = None, max_output_tokens_per_request: int = 16384, max_input_tokens_per_request: int = 128000, - input_query_stream: Optional[Observable] = None, - input_data_stream: Optional[Observable] = None, - input_video_stream: Optional[Observable] = None, - ): + input_query_stream: Observable | None = None, + input_data_stream: Observable | None = None, + input_video_stream: Observable | None = None, + ) -> None: """ Initializes a new instance of the LLMAgent. @@ -169,9 +172,9 @@ def __init__( """ super().__init__(dev_name, agent_type, agent_memory, pool_scheduler) # These attributes can be configured by a subclass if needed. - self.query: Optional[str] = None - self.prompt_builder: Optional[PromptBuilder] = None - self.system_query: Optional[str] = system_query + self.query: str | None = None + self.prompt_builder: PromptBuilder | None = None + self.system_query: str | None = system_query self.image_detail: str = "low" self.max_input_tokens_per_request: int = max_input_tokens_per_request self.max_output_tokens_per_request: int = max_output_tokens_per_request @@ -180,7 +183,7 @@ def __init__( ) self.rag_query_n: int = 4 self.rag_similarity_threshold: float = 0.45 - self.frame_processor: Optional[FrameProcessor] = None + self.frame_processor: FrameProcessor | None = None self.output_dir: str = os.path.join(os.getcwd(), "assets", "agent") self.process_all_inputs: bool = process_all_inputs os.makedirs(self.output_dir, exist_ok=True) @@ -225,8 +228,11 @@ def __init__( ) logger.info("Subscribing to merged input stream...") + # Define a query extractor for the merged stream - query_extractor = lambda emission: (emission[0], emission[1][0]) + def query_extractor(emission): + return (emission[0], emission[1][0]) + self.disposables.add( self.subscribe_to_image_processing( self.merged_stream, query_extractor=query_extractor @@ -241,7 +247,7 @@ def __init__( logger.info("Subscribing to input query stream...") self.disposables.add(self.subscribe_to_query_processing(self.input_query_stream)) - def _update_query(self, incoming_query: Optional[str]) -> None: + def _update_query(self, incoming_query: str | None) -> None: """Updates the query if an incoming query is provided. Args: @@ -250,7 +256,7 @@ def _update_query(self, incoming_query: Optional[str]) -> None: if incoming_query is not None: self.query = incoming_query - def _get_rag_context(self) -> Tuple[str, str]: + def _get_rag_context(self) -> tuple[str, str]: """Queries the agent memory to retrieve RAG context. Returns: @@ -273,8 +279,8 @@ def _get_rag_context(self) -> Tuple[str, str]: def _build_prompt( self, - base64_image: Optional[str], - dimensions: Optional[Tuple[int, int]], + base64_image: str | None, + dimensions: tuple[int, int] | None, override_token_limit: bool, condensed_results: str, ) -> list: @@ -370,10 +376,10 @@ def _tooling_callback(message, messages, response_message, skill_library: SkillL def _observable_query( self, observer: Observer, - base64_image: Optional[str] = None, - dimensions: Optional[Tuple[int, int]] = None, + base64_image: str | None = None, + dimensions: tuple[int, int] | None = None, override_token_limit: bool = False, - incoming_query: Optional[str] = None, + incoming_query: str | None = None, ): """Prepares and sends a query to the LLM, emitting the response to the observer. @@ -449,7 +455,7 @@ def _send_query(self, messages: list) -> Any: """ raise NotImplementedError("Subclasses must implement _send_query method.") - def _log_response_to_file(self, response, output_dir: str = None): + def _log_response_to_file(self, response, output_dir: str | None = None) -> None: """Logs the LLM response to a file. Args: @@ -670,7 +676,7 @@ def run_observable_query(self, query_text: str, **kwargs) -> Observable: ) ) - def dispose_all(self): + def dispose_all(self) -> None: """Disposes of all active subscriptions managed by this agent.""" super().dispose_all() self.response_subject.on_completed() @@ -695,27 +701,27 @@ def __init__( dev_name: str, agent_type: str = "Vision", query: str = "What do you see?", - input_query_stream: Optional[Observable] = None, - input_data_stream: Optional[Observable] = None, - input_video_stream: Optional[Observable] = None, + input_query_stream: Observable | None = None, + input_data_stream: Observable | None = None, + input_video_stream: Observable | None = None, output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), - agent_memory: Optional[AbstractAgentSemanticMemory] = None, - system_query: Optional[str] = None, + agent_memory: AbstractAgentSemanticMemory | None = None, + system_query: str | None = None, max_input_tokens_per_request: int = 128000, max_output_tokens_per_request: int = 16384, model_name: str = "gpt-4o", - prompt_builder: Optional[PromptBuilder] = None, - tokenizer: Optional[AbstractTokenizer] = None, + prompt_builder: PromptBuilder | None = None, + tokenizer: AbstractTokenizer | None = None, rag_query_n: int = 4, rag_similarity_threshold: float = 0.45, - skills: Optional[Union[AbstractSkill, list[AbstractSkill], SkillLibrary]] = None, - response_model: Optional[BaseModel] = None, - frame_processor: Optional[FrameProcessor] = None, + skills: AbstractSkill | list[AbstractSkill] | SkillLibrary | None = None, + response_model: BaseModel | None = None, + frame_processor: FrameProcessor | None = None, image_detail: str = "low", - pool_scheduler: Optional[ThreadPoolScheduler] = None, - process_all_inputs: Optional[bool] = None, - openai_client: Optional[OpenAI] = None, - ): + pool_scheduler: ThreadPoolScheduler | None = None, + process_all_inputs: bool | None = None, + openai_client: OpenAI | None = None, + ) -> None: """ Initializes a new instance of the OpenAIAgent. @@ -803,7 +809,7 @@ def __init__( logger.info("OpenAI Agent Initialized.") - def _add_context_to_memory(self): + def _add_context_to_memory(self) -> None: """Adds initial context to the agent's memory.""" context_data = [ ( diff --git a/dimos/agents/agent_config.py b/dimos/agents/agent_config.py index 0ffbcd2983..5b9027b072 100644 --- a/dimos/agents/agent_config.py +++ b/dimos/agents/agent_config.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List + from dimos.agents.agent import Agent class AgentConfig: - def __init__(self, agents: List[Agent] = None): + def __init__(self, agents: list[Agent] | None = None) -> None: """ Initialize an AgentConfig with a list of agents. @@ -26,7 +26,7 @@ def __init__(self, agents: List[Agent] = None): """ self.agents = agents if agents is not None else [] - def add_agent(self, agent: Agent): + def add_agent(self, agent: Agent) -> None: """ Add an agent to the configuration. @@ -35,7 +35,7 @@ def add_agent(self, agent: Agent): """ self.agents.append(agent) - def remove_agent(self, agent: Agent): + def remove_agent(self, agent: Agent) -> None: """ Remove an agent from the configuration. @@ -45,7 +45,7 @@ def remove_agent(self, agent: Agent): if agent in self.agents: self.agents.remove(agent) - def get_agents(self) -> List[Agent]: + def get_agents(self) -> list[Agent]: """ Get the list of configured agents. diff --git a/dimos/agents/agent_ctransformers_gguf.py b/dimos/agents/agent_ctransformers_gguf.py index 32d6fc59ca..17d233437d 100644 --- a/dimos/agents/agent_ctransformers_gguf.py +++ b/dimos/agents/agent_ctransformers_gguf.py @@ -17,18 +17,15 @@ # Standard library imports import logging import os -from typing import Any, Optional +from typing import TYPE_CHECKING, Any # Third-party imports from dotenv import load_dotenv from reactivex import Observable, create -from reactivex.scheduler import ThreadPoolScheduler -from reactivex.subject import Subject import torch # Local imports from dimos.agents.agent import LLMAgent -from dimos.agents.memory.base import AbstractAgentSemanticMemory from dimos.agents.prompt_builder.impl import PromptBuilder from dimos.utils.logging_config import setup_logger @@ -40,30 +37,38 @@ from ctransformers import AutoModelForCausalLM as CTransformersModel +if TYPE_CHECKING: + from reactivex.scheduler import ThreadPoolScheduler + from reactivex.subject import Subject + + from dimos.agents.memory.base import AbstractAgentSemanticMemory + class CTransformersTokenizerAdapter: - def __init__(self, model): + def __init__(self, model) -> None: self.model = model - def encode(self, text, **kwargs): + def encode(self, text: str, **kwargs): return self.model.tokenize(text) def decode(self, token_ids, **kwargs): return self.model.detokenize(token_ids) - def token_count(self, text): + def token_count(self, text: str): return len(self.tokenize_text(text)) if text else 0 - def tokenize_text(self, text): + def tokenize_text(self, text: str): return self.model.tokenize(text) def detokenize_text(self, tokenized_text): try: return self.model.detokenize(tokenized_text) except Exception as e: - raise ValueError(f"Failed to detokenize text. Error: {str(e)}") + raise ValueError(f"Failed to detokenize text. Error: {e!s}") - def apply_chat_template(self, conversation, tokenize=False, add_generation_prompt=True): + def apply_chat_template( + self, conversation, tokenize: bool = False, add_generation_prompt: bool = True + ): prompt = "" for message in conversation: role = message["role"] @@ -91,17 +96,17 @@ def __init__( gpu_layers: int = 50, device: str = "auto", query: str = "How many r's are in the word 'strawberry'?", - input_query_stream: Optional[Observable] = None, - input_video_stream: Optional[Observable] = None, + input_query_stream: Observable | None = None, + input_video_stream: Observable | None = None, output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), - agent_memory: Optional[AbstractAgentSemanticMemory] = None, - system_query: Optional[str] = "You are a helpful assistant.", + agent_memory: AbstractAgentSemanticMemory | None = None, + system_query: str | None = "You are a helpful assistant.", max_output_tokens_per_request: int = 10, max_input_tokens_per_request: int = 250, - prompt_builder: Optional[PromptBuilder] = None, - pool_scheduler: Optional[ThreadPoolScheduler] = None, - process_all_inputs: Optional[bool] = None, - ): + prompt_builder: PromptBuilder | None = None, + pool_scheduler: ThreadPoolScheduler | None = None, + process_all_inputs: bool | None = None, + ) -> None: # Determine appropriate default for process_all_inputs if not provided if process_all_inputs is None: # Default to True for text queries, False for video streams diff --git a/dimos/agents/agent_huggingface_local.py b/dimos/agents/agent_huggingface_local.py index 14f970c3bc..69d02bb1d2 100644 --- a/dimos/agents/agent_huggingface_local.py +++ b/dimos/agents/agent_huggingface_local.py @@ -17,25 +17,28 @@ # Standard library imports import logging import os -from typing import Any, Optional +from typing import TYPE_CHECKING, Any # Third-party imports from dotenv import load_dotenv from reactivex import Observable, create -from reactivex.scheduler import ThreadPoolScheduler -from reactivex.subject import Subject import torch from transformers import AutoModelForCausalLM # Local imports from dimos.agents.agent import LLMAgent -from dimos.agents.memory.base import AbstractAgentSemanticMemory from dimos.agents.memory.chroma_impl import LocalSemanticMemory from dimos.agents.prompt_builder.impl import PromptBuilder -from dimos.agents.tokenizer.base import AbstractTokenizer from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer from dimos.utils.logging_config import setup_logger +if TYPE_CHECKING: + from reactivex.scheduler import ThreadPoolScheduler + from reactivex.subject import Subject + + from dimos.agents.memory.base import AbstractAgentSemanticMemory + from dimos.agents.tokenizer.base import AbstractTokenizer + # Initialize environment variables load_dotenv() @@ -52,19 +55,19 @@ def __init__( model_name: str = "Qwen/Qwen2.5-3B", device: str = "auto", query: str = "How many r's are in the word 'strawberry'?", - input_query_stream: Optional[Observable] = None, - input_video_stream: Optional[Observable] = None, + input_query_stream: Observable | None = None, + input_video_stream: Observable | None = None, output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), - agent_memory: Optional[AbstractAgentSemanticMemory] = None, - system_query: Optional[str] = None, - max_output_tokens_per_request: int = None, - max_input_tokens_per_request: int = None, - prompt_builder: Optional[PromptBuilder] = None, - tokenizer: Optional[AbstractTokenizer] = None, + agent_memory: AbstractAgentSemanticMemory | None = None, + system_query: str | None = None, + max_output_tokens_per_request: int | None = None, + max_input_tokens_per_request: int | None = None, + prompt_builder: PromptBuilder | None = None, + tokenizer: AbstractTokenizer | None = None, image_detail: str = "low", - pool_scheduler: Optional[ThreadPoolScheduler] = None, - process_all_inputs: Optional[bool] = None, - ): + pool_scheduler: ThreadPoolScheduler | None = None, + process_all_inputs: bool | None = None, + ) -> None: # Determine appropriate default for process_all_inputs if not provided if process_all_inputs is None: # Default to True for text queries, False for video streams @@ -134,7 +137,7 @@ def _send_query(self, messages: list) -> Any: try: # Log the incoming messages - print(f"{_BLUE_PRINT_COLOR}Messages: {str(messages)}{_RESET_COLOR}") + print(f"{_BLUE_PRINT_COLOR}Messages: {messages!s}{_RESET_COLOR}") # Process with chat template try: @@ -163,7 +166,9 @@ def _send_query(self, messages: list) -> Any: print("Processing generated output...") generated_ids = [ output_ids[len(input_ids) :] - for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) + for input_ids, output_ids in zip( + model_inputs.input_ids, generated_ids, strict=False + ) ] # Convert tokens back to text diff --git a/dimos/agents/agent_huggingface_remote.py b/dimos/agents/agent_huggingface_remote.py index d98b277706..5bb5b293d3 100644 --- a/dimos/agents/agent_huggingface_remote.py +++ b/dimos/agents/agent_huggingface_remote.py @@ -17,23 +17,26 @@ # Standard library imports import logging import os -from typing import Any, Optional +from typing import TYPE_CHECKING, Any # Third-party imports from dotenv import load_dotenv from huggingface_hub import InferenceClient -from reactivex import create, Observable -from reactivex.scheduler import ThreadPoolScheduler -from reactivex.subject import Subject +from reactivex import Observable, create # Local imports from dimos.agents.agent import LLMAgent -from dimos.agents.memory.base import AbstractAgentSemanticMemory from dimos.agents.prompt_builder.impl import PromptBuilder -from dimos.agents.tokenizer.base import AbstractTokenizer from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer from dimos.utils.logging_config import setup_logger +if TYPE_CHECKING: + from reactivex.scheduler import ThreadPoolScheduler + from reactivex.subject import Subject + + from dimos.agents.memory.base import AbstractAgentSemanticMemory + from dimos.agents.tokenizer.base import AbstractTokenizer + # Initialize environment variables load_dotenv() @@ -49,21 +52,21 @@ def __init__( agent_type: str = "HF-LLM", model_name: str = "Qwen/QwQ-32B", query: str = "How many r's are in the word 'strawberry'?", - input_query_stream: Optional[Observable] = None, - input_video_stream: Optional[Observable] = None, + input_query_stream: Observable | None = None, + input_video_stream: Observable | None = None, output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), - agent_memory: Optional[AbstractAgentSemanticMemory] = None, - system_query: Optional[str] = None, + agent_memory: AbstractAgentSemanticMemory | None = None, + system_query: str | None = None, max_output_tokens_per_request: int = 16384, - prompt_builder: Optional[PromptBuilder] = None, - tokenizer: Optional[AbstractTokenizer] = None, + prompt_builder: PromptBuilder | None = None, + tokenizer: AbstractTokenizer | None = None, image_detail: str = "low", - pool_scheduler: Optional[ThreadPoolScheduler] = None, - process_all_inputs: Optional[bool] = None, - api_key: Optional[str] = None, - hf_provider: Optional[str] = None, - hf_base_url: Optional[str] = None, - ): + pool_scheduler: ThreadPoolScheduler | None = None, + process_all_inputs: bool | None = None, + api_key: str | None = None, + hf_provider: str | None = None, + hf_base_url: str | None = None, + ) -> None: # Determine appropriate default for process_all_inputs if not provided if process_all_inputs is None: # Default to True for text queries, False for video streams diff --git a/dimos/agents/agent_message.py b/dimos/agents/agent_message.py index 5baa3c11f0..cecd8092c1 100644 --- a/dimos/agents/agent_message.py +++ b/dimos/agents/agent_message.py @@ -15,11 +15,10 @@ """AgentMessage type for multimodal agent communication.""" from dataclasses import dataclass, field -from typing import List, Optional, Union import time -from dimos.msgs.sensor_msgs.Image import Image from dimos.agents.agent_types import AgentImage +from dimos.msgs.sensor_msgs.Image import Image @dataclass @@ -33,9 +32,9 @@ class AgentMessage: into a single message when sent to the LLM. """ - messages: List[str] = field(default_factory=list) - images: List[AgentImage] = field(default_factory=list) - sender_id: Optional[str] = None + messages: list[str] = field(default_factory=list) + images: list[AgentImage] = field(default_factory=list) + sender_id: str | None = None timestamp: float = field(default_factory=time.time) def add_text(self, text: str) -> None: @@ -43,7 +42,7 @@ def add_text(self, text: str) -> None: if text: # Only add non-empty text self.messages.append(text) - def add_image(self, image: Union[Image, AgentImage]) -> None: + def add_image(self, image: Image | AgentImage) -> None: """Add an image. Converts Image to AgentImage if needed.""" if isinstance(image, Image): # Convert to AgentImage @@ -72,11 +71,11 @@ def is_multimodal(self) -> bool: """Check if message contains both text and images.""" return self.has_text() and self.has_images() - def get_primary_text(self) -> Optional[str]: + def get_primary_text(self) -> str | None: """Get the first text message, if any.""" return self.messages[0] if self.messages else None - def get_primary_image(self) -> Optional[AgentImage]: + def get_primary_image(self) -> AgentImage | None: """Get the first image, if any.""" return self.images[0] if self.images else None diff --git a/dimos/agents/agent_types.py b/dimos/agents/agent_types.py index e57f4dec84..db41acbafb 100644 --- a/dimos/agents/agent_types.py +++ b/dimos/agents/agent_types.py @@ -15,10 +15,10 @@ """Agent-specific types for message passing.""" from dataclasses import dataclass, field -from typing import List, Optional, Dict, Any, Union +import json import threading import time -import json +from typing import Any @dataclass @@ -30,9 +30,9 @@ class AgentImage: """ base64_jpeg: str - width: Optional[int] = None - height: Optional[int] = None - metadata: Dict[str, Any] = field(default_factory=dict) + width: int | None = None + height: int | None = None + metadata: dict[str, Any] = field(default_factory=dict) def __repr__(self) -> str: return f"AgentImage(size={self.width}x{self.height}, metadata={list(self.metadata.keys())})" @@ -44,7 +44,7 @@ class ToolCall: id: str name: str - arguments: Dict[str, Any] + arguments: dict[str, Any] status: str = "pending" # pending, executing, completed, failed def __repr__(self) -> str: @@ -60,9 +60,9 @@ class AgentResponse: content: str role: str = "assistant" - tool_calls: Optional[List[ToolCall]] = None + tool_calls: list[ToolCall] | None = None requires_follow_up: bool = False # Indicates if tool execution is needed - metadata: Dict[str, Any] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) timestamp: float = field(default_factory=time.time) def __repr__(self) -> str: @@ -80,13 +80,13 @@ class ConversationMessage: """ role: str # "system", "user", "assistant", "tool" - content: Union[str, List[Dict[str, Any]]] # Text or content blocks - tool_calls: Optional[List[ToolCall]] = None - tool_call_id: Optional[str] = None # For tool responses - name: Optional[str] = None # For tool messages (function name) + content: str | list[dict[str, Any]] # Text or content blocks + tool_calls: list[ToolCall] | None = None + tool_call_id: str | None = None # For tool responses + name: str | None = None # For tool messages (function name) timestamp: float = field(default_factory=time.time) - def to_openai_format(self) -> Dict[str, Any]: + def to_openai_format(self) -> dict[str, Any]: """Convert to OpenAI API format.""" msg = {"role": self.role} @@ -136,17 +136,17 @@ class ConversationHistory: LLM providers and automatic trimming. """ - def __init__(self, max_size: int = 20): + def __init__(self, max_size: int = 20) -> None: """Initialize conversation history. Args: max_size: Maximum number of messages to keep """ - self._messages: List[ConversationMessage] = [] + self._messages: list[ConversationMessage] = [] self._lock = threading.Lock() self.max_size = max_size - def add_user_message(self, content: Union[str, List[Dict[str, Any]]]) -> None: + def add_user_message(self, content: str | list[dict[str, Any]]) -> None: """Add user message to history. Args: @@ -156,9 +156,7 @@ def add_user_message(self, content: Union[str, List[Dict[str, Any]]]) -> None: self._messages.append(ConversationMessage(role="user", content=content)) self._trim() - def add_assistant_message( - self, content: str, tool_calls: Optional[List[ToolCall]] = None - ) -> None: + def add_assistant_message(self, content: str, tool_calls: list[ToolCall] | None = None) -> None: """Add assistant response to history. Args: @@ -171,7 +169,7 @@ def add_assistant_message( ) self._trim() - def add_tool_result(self, tool_call_id: str, content: str, name: Optional[str] = None) -> None: + def add_tool_result(self, tool_call_id: str, content: str, name: str | None = None) -> None: """Add tool execution result to history. Args: @@ -187,7 +185,7 @@ def add_tool_result(self, tool_call_id: str, content: str, name: Optional[str] = ) self._trim() - def add_raw_message(self, message: Dict[str, Any]) -> None: + def add_raw_message(self, message: dict[str, Any]) -> None: """Add a raw message dict to history. Args: @@ -223,7 +221,7 @@ def add_raw_message(self, message: Dict[str, Any]) -> None: ) self._trim() - def to_openai_format(self) -> List[Dict[str, Any]]: + def to_openai_format(self) -> list[dict[str, Any]]: """Export history in OpenAI format. Returns: diff --git a/dimos/agents/cerebras_agent.py b/dimos/agents/cerebras_agent.py index 854beb848d..e58de812d0 100644 --- a/dimos/agents/cerebras_agent.py +++ b/dimos/agents/cerebras_agent.py @@ -20,31 +20,32 @@ from __future__ import annotations -import os -import threading import copy -from typing import Any, Dict, List, Optional, Union, Tuple -import logging import json -import re +import os +import threading import time +from typing import TYPE_CHECKING from cerebras.cloud.sdk import Cerebras from dotenv import load_dotenv -from pydantic import BaseModel -from reactivex import Observable -from reactivex.observer import Observer -from reactivex.scheduler import ThreadPoolScheduler # Local imports from dimos.agents.agent import LLMAgent -from dimos.agents.memory.base import AbstractAgentSemanticMemory from dimos.agents.prompt_builder.impl import PromptBuilder -from dimos.agents.tokenizer.base import AbstractTokenizer +from dimos.agents.tokenizer.openai_tokenizer import OpenAITokenizer from dimos.skills.skills import AbstractSkill, SkillLibrary -from dimos.stream.frame_processor import FrameProcessor from dimos.utils.logging_config import setup_logger -from dimos.agents.tokenizer.openai_tokenizer import OpenAITokenizer + +if TYPE_CHECKING: + from pydantic import BaseModel + from reactivex import Observable + from reactivex.observer import Observer + from reactivex.scheduler import ThreadPoolScheduler + + from dimos.agents.memory.base import AbstractAgentSemanticMemory + from dimos.agents.tokenizer.base import AbstractTokenizer + from dimos.stream.frame_processor import FrameProcessor # Initialize environment variables load_dotenv() @@ -57,9 +58,9 @@ class CerebrasResponseMessage(dict): def __init__( self, - content="", + content: str = "", tool_calls=None, - ): + ) -> None: self.content = content self.tool_calls = tool_calls or [] self.parsed = None @@ -67,7 +68,7 @@ def __init__( # Initialize as dict with the proper structure super().__init__(self.to_dict()) - def __str__(self): + def __str__(self) -> str: # Return a string representation for logging if self.content: return self.content @@ -115,24 +116,24 @@ def __init__( dev_name: str, agent_type: str = "Vision", query: str = "What do you see?", - input_query_stream: Optional[Observable] = None, - input_video_stream: Optional[Observable] = None, - input_data_stream: Optional[Observable] = None, + input_query_stream: Observable | None = None, + input_video_stream: Observable | None = None, + input_data_stream: Observable | None = None, output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), - agent_memory: Optional[AbstractAgentSemanticMemory] = None, - system_query: Optional[str] = None, + agent_memory: AbstractAgentSemanticMemory | None = None, + system_query: str | None = None, max_input_tokens_per_request: int = 128000, max_output_tokens_per_request: int = 16384, model_name: str = "llama-4-scout-17b-16e-instruct", - skills: Optional[Union[AbstractSkill, list[AbstractSkill], SkillLibrary]] = None, - response_model: Optional[BaseModel] = None, - frame_processor: Optional[FrameProcessor] = None, + skills: AbstractSkill | list[AbstractSkill] | SkillLibrary | None = None, + response_model: BaseModel | None = None, + frame_processor: FrameProcessor | None = None, image_detail: str = "low", - pool_scheduler: Optional[ThreadPoolScheduler] = None, - process_all_inputs: Optional[bool] = None, - tokenizer: Optional[AbstractTokenizer] = None, - prompt_builder: Optional[PromptBuilder] = None, - ): + pool_scheduler: ThreadPoolScheduler | None = None, + process_all_inputs: bool | None = None, + tokenizer: AbstractTokenizer | None = None, + prompt_builder: PromptBuilder | None = None, + ) -> None: """ Initializes a new instance of the CerebrasAgent. @@ -229,7 +230,7 @@ def __init__( logger.info("Cerebras Agent Initialized.") - def _add_context_to_memory(self): + def _add_context_to_memory(self) -> None: """Adds initial context to the agent's memory.""" context_data = [ ( @@ -256,8 +257,8 @@ def _add_context_to_memory(self): def _build_prompt( self, messages: list, - base64_image: Optional[Union[str, List[str]]] = None, - dimensions: Optional[Tuple[int, int]] = None, + base64_image: str | list[str] | None = None, + dimensions: tuple[int, int] | None = None, override_token_limit: bool = False, condensed_results: str = "", ) -> list: @@ -405,7 +406,11 @@ def clean_cerebras_schema(self, schema: dict) -> dict: return cleaned def create_tool_call( - self, name: str = None, arguments: dict = None, call_id: str = None, content: str = None + self, + name: str | None = None, + arguments: dict | None = None, + call_id: str | None = None, + content: str | None = None, ): """Create a tool call object from either direct parameters or JSON content.""" # If content is provided, parse it as JSON @@ -520,10 +525,10 @@ def _send_query(self, messages: list) -> CerebrasResponseMessage: def _observable_query( self, observer: Observer, - base64_image: Optional[str] = None, - dimensions: Optional[Tuple[int, int]] = None, + base64_image: str | None = None, + dimensions: tuple[int, int] | None = None, override_token_limit: bool = False, - incoming_query: Optional[str] = None, + incoming_query: str | None = None, reset_conversation: bool = False, ): """Main query handler that manages conversation history and Cerebras interactions. diff --git a/dimos/agents/claude_agent.py b/dimos/agents/claude_agent.py index e87b1f47b4..c8163de162 100644 --- a/dimos/agents/claude_agent.py +++ b/dimos/agents/claude_agent.py @@ -23,22 +23,25 @@ import json import os -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any import anthropic from dotenv import load_dotenv -from pydantic import BaseModel -from reactivex import Observable -from reactivex.scheduler import ThreadPoolScheduler # Local imports from dimos.agents.agent import LLMAgent -from dimos.agents.memory.base import AbstractAgentSemanticMemory -from dimos.agents.prompt_builder.impl import PromptBuilder from dimos.skills.skills import AbstractSkill, SkillLibrary from dimos.stream.frame_processor import FrameProcessor from dimos.utils.logging_config import setup_logger +if TYPE_CHECKING: + from pydantic import BaseModel + from reactivex import Observable + from reactivex.scheduler import ThreadPoolScheduler + + from dimos.agents.memory.base import AbstractAgentSemanticMemory + from dimos.agents.prompt_builder.impl import PromptBuilder + # Initialize environment variables load_dotenv() @@ -48,13 +51,13 @@ # Response object compatible with LLMAgent class ResponseMessage: - def __init__(self, content="", tool_calls=None, thinking_blocks=None): + def __init__(self, content: str = "", tool_calls=None, thinking_blocks=None) -> None: self.content = content self.tool_calls = tool_calls or [] self.thinking_blocks = thinking_blocks or [] self.parsed = None - def __str__(self): + def __str__(self) -> str: # Return a string representation for logging parts = [] @@ -82,26 +85,26 @@ def __init__( dev_name: str, agent_type: str = "Vision", query: str = "What do you see?", - input_query_stream: Optional[Observable] = None, - input_video_stream: Optional[Observable] = None, - input_data_stream: Optional[Observable] = None, + input_query_stream: Observable | None = None, + input_video_stream: Observable | None = None, + input_data_stream: Observable | None = None, output_dir: str = os.path.join(os.getcwd(), "assets", "agent"), - agent_memory: Optional[AbstractAgentSemanticMemory] = None, - system_query: Optional[str] = None, + agent_memory: AbstractAgentSemanticMemory | None = None, + system_query: str | None = None, max_input_tokens_per_request: int = 128000, max_output_tokens_per_request: int = 16384, model_name: str = "claude-3-7-sonnet-20250219", - prompt_builder: Optional[PromptBuilder] = None, + prompt_builder: PromptBuilder | None = None, rag_query_n: int = 4, rag_similarity_threshold: float = 0.45, - skills: Optional[AbstractSkill] = None, - response_model: Optional[BaseModel] = None, - frame_processor: Optional[FrameProcessor] = None, + skills: AbstractSkill | None = None, + response_model: BaseModel | None = None, + frame_processor: FrameProcessor | None = None, image_detail: str = "low", - pool_scheduler: Optional[ThreadPoolScheduler] = None, - process_all_inputs: Optional[bool] = None, - thinking_budget_tokens: Optional[int] = 2000, - ): + pool_scheduler: ThreadPoolScheduler | None = None, + process_all_inputs: bool | None = None, + thinking_budget_tokens: int | None = 2000, + ) -> None: """ Initializes a new instance of the ClaudeAgent. @@ -192,7 +195,7 @@ def __init__( logger.info("Claude Agent Initialized.") - def _add_context_to_memory(self): + def _add_context_to_memory(self) -> None: """Adds initial context to the agent's memory.""" context_data = [ ( @@ -216,7 +219,7 @@ def _add_context_to_memory(self): for doc_id, text in context_data: self.agent_memory.add_vector(doc_id, text) - def _convert_tools_to_claude_format(self, tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + def _convert_tools_to_claude_format(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]: """ Converts DIMOS tools to Claude format. @@ -258,11 +261,11 @@ def _convert_tools_to_claude_format(self, tools: List[Dict[str, Any]]) -> List[D def _build_prompt( self, messages: list, - base64_image: Optional[Union[str, List[str]]] = None, - dimensions: Optional[Tuple[int, int]] = None, + base64_image: str | list[str] | None = None, + dimensions: tuple[int, int] | None = None, override_token_limit: bool = False, rag_results: str = "", - thinking_budget_tokens: int = None, + thinking_budget_tokens: int | None = None, ) -> list: """Builds a prompt message specifically for Claude API, using local messages copy.""" """Builds a prompt message specifically for Claude API. @@ -535,13 +538,13 @@ def _send_query(self, messages: list, claude_params: dict) -> Any: def _observable_query( self, observer: Observer, - base64_image: Optional[str] = None, - dimensions: Optional[Tuple[int, int]] = None, + base64_image: str | None = None, + dimensions: tuple[int, int] | None = None, override_token_limit: bool = False, - incoming_query: Optional[str] = None, + incoming_query: str | None = None, reset_conversation: bool = False, - thinking_budget_tokens: int = None, - ): + thinking_budget_tokens: int | None = None, + ) -> None: """Main query handler that manages conversation history and Claude interactions. This is the primary method for handling all queries, whether they come through @@ -695,7 +698,7 @@ def _handle_tooling(self, response_message, messages): } ) - def _tooling_callback(self, response_message): + def _tooling_callback(self, response_message) -> None: """Runs the observable query for each tool call in the current response_message""" if not hasattr(response_message, "tool_calls") or not response_message.tool_calls: return diff --git a/dimos/agents/memory/base.py b/dimos/agents/memory/base.py index af8cbf689f..eb48dcca44 100644 --- a/dimos/agents/memory/base.py +++ b/dimos/agents/memory/base.py @@ -13,9 +13,10 @@ # limitations under the License. from abc import abstractmethod + from dimos.exceptions.agent_memory_exceptions import ( - UnknownConnectionTypeError, AgentMemoryConnectionError, + UnknownConnectionTypeError, ) from dimos.utils.logging_config import setup_logger @@ -27,7 +28,7 @@ class AbstractAgentSemanticMemory: # AbstractAgentMemory): - def __init__(self, connection_type="local", **kwargs): + def __init__(self, connection_type: str = "local", **kwargs) -> None: """ Initialize with dynamic connection parameters. Args: @@ -87,7 +88,7 @@ def get_vector(self, vector_id): """ @abstractmethod - def query(self, query_texts, n_results=4, similarity_threshold=None): + def query(self, query_texts, n_results: int = 4, similarity_threshold=None): """Performs a semantic search in the vector database. Args: diff --git a/dimos/agents/memory/chroma_impl.py b/dimos/agents/memory/chroma_impl.py index 06f6989355..b238b616d8 100644 --- a/dimos/agents/memory/chroma_impl.py +++ b/dimos/agents/memory/chroma_impl.py @@ -12,18 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.agents.memory.base import AbstractAgentSemanticMemory +from collections.abc import Sequence +import os -from langchain_openai import OpenAIEmbeddings from langchain_chroma import Chroma -import os +from langchain_openai import OpenAIEmbeddings import torch +from dimos.agents.memory.base import AbstractAgentSemanticMemory + class ChromaAgentSemanticMemory(AbstractAgentSemanticMemory): """Base class for Chroma-based semantic memory implementations.""" - def __init__(self, collection_name="my_collection"): + def __init__(self, collection_name: str = "my_collection") -> None: """Initialize the connection to the local Chroma DB.""" self.collection_name = collection_name self.db_connection = None @@ -54,7 +56,7 @@ def get_vector(self, vector_id): result = self.db_connection.get(include=["embeddings"], ids=[vector_id]) return result - def query(self, query_texts, n_results=4, similarity_threshold=None): + def query(self, query_texts, n_results: int = 4, similarity_threshold=None): """Query the collection with a specific text and return up to n results.""" if not self.db_connection: raise Exception("Collection not initialized. Call connect() first.") @@ -84,8 +86,11 @@ class OpenAISemanticMemory(ChromaAgentSemanticMemory): """Semantic memory implementation using OpenAI's embedding API.""" def __init__( - self, collection_name="my_collection", model="text-embedding-3-large", dimensions=1024 - ): + self, + collection_name: str = "my_collection", + model: str = "text-embedding-3-large", + dimensions: int = 1024, + ) -> None: """Initialize OpenAI-based semantic memory. Args: @@ -123,8 +128,10 @@ class LocalSemanticMemory(ChromaAgentSemanticMemory): """Semantic memory implementation using local models.""" def __init__( - self, collection_name="my_collection", model_name="sentence-transformers/all-MiniLM-L6-v2" - ): + self, + collection_name: str = "my_collection", + model_name: str = "sentence-transformers/all-MiniLM-L6-v2", + ) -> None: """Initialize the local semantic memory using SentenceTransformer. Args: @@ -135,7 +142,7 @@ def __init__( self.model_name = model_name super().__init__(collection_name=collection_name) - def create(self): + def create(self) -> None: """Create local embedding model and initialize the ChromaDB client.""" # Load the sentence transformer model # Use CUDA if available, otherwise fall back to CPU @@ -145,14 +152,14 @@ def create(self): # Create a custom embedding class that implements the embed_query method class SentenceTransformerEmbeddings: - def __init__(self, model): + def __init__(self, model) -> None: self.model = model - def embed_query(self, text): + def embed_query(self, text: str): """Embed a single query text.""" return self.model.encode(text, normalize_embeddings=True).tolist() - def embed_documents(self, texts): + def embed_documents(self, texts: Sequence[str]): """Embed multiple documents/texts.""" return self.model.encode(texts, normalize_embeddings=True).tolist() diff --git a/dimos/agents/memory/image_embedding.py b/dimos/agents/memory/image_embedding.py index 142839abd9..7b6dd88515 100644 --- a/dimos/agents/memory/image_embedding.py +++ b/dimos/agents/memory/image_embedding.py @@ -22,7 +22,6 @@ import base64 import io import os -from typing import Union import cv2 import numpy as np @@ -42,7 +41,7 @@ class ImageEmbeddingProvider: that can be stored in a vector database and used for similarity search. """ - def __init__(self, model_name: str = "clip", dimensions: int = 512): + def __init__(self, model_name: str = "clip", dimensions: int = 512) -> None: """ Initialize the image embedding provider. @@ -94,7 +93,7 @@ def _initialize_model(self): self.processor = None raise - def get_embedding(self, image: Union[np.ndarray, str, bytes]) -> np.ndarray: + def get_embedding(self, image: np.ndarray | str | bytes) -> np.ndarray: """ Generate an embedding vector for the provided image. @@ -234,7 +233,7 @@ def get_text_embedding(self, text: str) -> np.ndarray: logger.error(f"Error generating text embedding: {e}") return np.random.randn(self.dimensions).astype(np.float32) - def _prepare_image(self, image: Union[np.ndarray, str, bytes]) -> Image.Image: + def _prepare_image(self, image: np.ndarray | str | bytes) -> Image.Image: """ Convert the input image to PIL format required by the models. diff --git a/dimos/agents/memory/spatial_vector_db.py b/dimos/agents/memory/spatial_vector_db.py index a4eefb792b..ac5dcc026a 100644 --- a/dimos/agents/memory/spatial_vector_db.py +++ b/dimos/agents/memory/spatial_vector_db.py @@ -19,9 +19,10 @@ their XY locations and querying by location or image similarity. """ -import numpy as np -from typing import List, Dict, Optional, Tuple, Any +from typing import Any + import chromadb +import numpy as np from dimos.agents.memory.visual_memory import VisualMemory from dimos.types.robot_location import RobotLocation @@ -44,7 +45,7 @@ def __init__( chroma_client=None, visual_memory=None, embedding_provider=None, - ): + ) -> None: """ Initialize the spatial vector database. @@ -104,11 +105,11 @@ def __init__( logger.info(f"Created NEW {client_type} collection '{collection_name}'") except Exception as e: logger.info( - f"Initialized {client_type} collection '{collection_name}' (count error: {str(e)})" + f"Initialized {client_type} collection '{collection_name}' (count error: {e!s})" ) def add_image_vector( - self, vector_id: str, image: np.ndarray, embedding: np.ndarray, metadata: Dict[str, Any] + self, vector_id: str, image: np.ndarray, embedding: np.ndarray, metadata: dict[str, Any] ) -> None: """ Add an image with its embedding and metadata to the vector database. @@ -129,7 +130,7 @@ def add_image_vector( logger.info(f"Added image vector {vector_id} with metadata: {metadata}") - def query_by_embedding(self, embedding: np.ndarray, limit: int = 5) -> List[Dict]: + def query_by_embedding(self, embedding: np.ndarray, limit: int = 5) -> list[dict]: """ Query the vector database for images similar to the provided embedding. @@ -149,7 +150,7 @@ def query_by_embedding(self, embedding: np.ndarray, limit: int = 5) -> List[Dict # TODO: implement efficient nearest neighbor search def query_by_location( self, x: float, y: float, radius: float = 2.0, limit: int = 5 - ) -> List[Dict]: + ) -> list[dict]: """ Query the vector database for images near the specified location. @@ -192,7 +193,7 @@ def query_by_location( return self._process_query_results(filtered_results) - def _process_query_results(self, results) -> List[Dict]: + def _process_query_results(self, results) -> list[dict]: """Process query results to include decoded images.""" if not results or not results["ids"]: return [] @@ -227,7 +228,7 @@ def _process_query_results(self, results) -> List[Dict]: return processed_results - def query_by_text(self, text: str, limit: int = 5) -> List[Dict]: + def query_by_text(self, text: str, limit: int = 5) -> list[dict]: """ Query the vector database for images matching the provided text description. @@ -259,7 +260,7 @@ def query_by_text(self, text: str, limit: int = 5) -> List[Dict]: ) return self._process_query_results(results) - def get_all_locations(self) -> List[Tuple[float, float, float]]: + def get_all_locations(self) -> list[tuple[float, float, float]]: """Get all locations stored in the database.""" # Get all items from the collection without embeddings results = self.image_collection.get(include=["metadatas"]) @@ -301,7 +302,7 @@ def tag_location(self, location: RobotLocation) -> None: ids=[location_id], documents=[location.name], metadatas=[metadata] ) - def query_tagged_location(self, query: str) -> Tuple[Optional[RobotLocation], float]: + def query_tagged_location(self, query: str) -> tuple[RobotLocation | None, float]: """ Query for a tagged location using semantic text search. diff --git a/dimos/agents/memory/test_image_embedding.py b/dimos/agents/memory/test_image_embedding.py index 0a28ac11b7..b1e7cabf09 100644 --- a/dimos/agents/memory/test_image_embedding.py +++ b/dimos/agents/memory/test_image_embedding.py @@ -21,7 +21,6 @@ import numpy as np import pytest -import reactivex as rx from reactivex import operators as ops from dimos.agents.memory.image_embedding import ImageEmbeddingProvider @@ -33,7 +32,7 @@ class TestImageEmbedding: """Test class for CLIP image embedding functionality.""" @pytest.mark.tofix - def test_clip_embedding_initialization(self): + def test_clip_embedding_initialization(self) -> None: """Test CLIP embedding provider initializes correctly.""" try: # Initialize the embedding provider with CLIP model @@ -46,7 +45,7 @@ def test_clip_embedding_initialization(self): pytest.skip(f"Skipping test due to model initialization error: {e}") @pytest.mark.tofix - def test_clip_embedding_process_video(self): + def test_clip_embedding_process_video(self) -> None: """Test CLIP embedding provider can process video frames and return embeddings.""" try: from dimos.utils.data import get_data @@ -80,7 +79,7 @@ def process_frame(frame): frames_processed = 0 target_frames = 10 - def on_next(result): + def on_next(result) -> None: nonlocal frames_processed, results if not result: # Skip None results return @@ -92,10 +91,10 @@ def on_next(result): if frames_processed >= target_frames: subscription.dispose() - def on_error(error): + def on_error(error) -> None: pytest.fail(f"Error in embedding stream: {error}") - def on_completed(): + def on_completed() -> None: pass # Subscribe and wait for results @@ -143,7 +142,7 @@ def on_completed(): "embedding1": results[0]["embedding"], "embedding2": results[1]["embedding"] if len(results) > 1 else None, } - print(f"Saved embeddings for similarity testing") + print("Saved embeddings for similarity testing") print("CLIP embedding test passed successfully!") @@ -151,7 +150,7 @@ def on_completed(): pytest.fail(f"Test failed with error: {e}") @pytest.mark.tofix - def test_clip_embedding_similarity(self): + def test_clip_embedding_similarity(self) -> None: """Test CLIP embedding similarity search and text-to-image queries.""" try: # Skip if previous test didn't generate embeddings diff --git a/dimos/agents/memory/visual_memory.py b/dimos/agents/memory/visual_memory.py index 0087a4fe9b..90f1272fef 100644 --- a/dimos/agents/memory/visual_memory.py +++ b/dimos/agents/memory/visual_memory.py @@ -16,13 +16,13 @@ Visual memory storage for managing image data persistence and retrieval """ +import base64 import os import pickle -import base64 -import numpy as np + import cv2 +import numpy as np -from typing import Optional from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.agents.memory.visual_memory") @@ -37,7 +37,7 @@ class VisualMemory: load the image data from disk. """ - def __init__(self, output_dir: str = None): + def __init__(self, output_dir: str | None = None) -> None: """ Initialize the visual memory system. @@ -74,7 +74,7 @@ def add(self, image_id: str, image: np.ndarray) -> None: self.images[image_id] = b64_encoded logger.debug(f"Added image {image_id} to visual memory") - def get(self, image_id: str) -> Optional[np.ndarray]: + def get(self, image_id: str) -> np.ndarray | None: """ Retrieve an image from visual memory. @@ -97,7 +97,7 @@ def get(self, image_id: str) -> Optional[np.ndarray]: image = cv2.imdecode(image_array, cv2.IMREAD_COLOR) return image except Exception as e: - logger.warning(f"Failed to decode image for ID {image_id}: {str(e)}") + logger.warning(f"Failed to decode image for ID {image_id}: {e!s}") return None def contains(self, image_id: str) -> bool: @@ -121,7 +121,7 @@ def count(self) -> int: """ return len(self.images) - def save(self, filename: Optional[str] = None) -> str: + def save(self, filename: str | None = None) -> str: """ Save the visual memory to disk. @@ -146,11 +146,11 @@ def save(self, filename: Optional[str] = None) -> str: logger.info(f"Saved {len(self.images)} images to {output_path}") return output_path except Exception as e: - logger.error(f"Failed to save visual memory: {str(e)}") + logger.error(f"Failed to save visual memory: {e!s}") return "" @classmethod - def load(cls, path: str, output_dir: Optional[str] = None) -> "VisualMemory": + def load(cls, path: str, output_dir: str | None = None) -> "VisualMemory": """ Load visual memory from disk. @@ -173,7 +173,7 @@ def load(cls, path: str, output_dir: Optional[str] = None) -> "VisualMemory": logger.info(f"Loaded {len(instance.images)} images from {path}") return instance except Exception as e: - logger.error(f"Failed to load visual memory: {str(e)}") + logger.error(f"Failed to load visual memory: {e!s}") return instance def clear(self) -> None: diff --git a/dimos/agents/modules/agent_pool.py b/dimos/agents/modules/agent_pool.py index c5b466159f..08ef943765 100644 --- a/dimos/agents/modules/agent_pool.py +++ b/dimos/agents/modules/agent_pool.py @@ -14,14 +14,14 @@ """Agent pool module for managing multiple agents.""" -from typing import Any, Dict, List, Union +from typing import Any from reactivex import operators as ops from reactivex.subject import Subject -from dimos.core import Module, In, Out, rpc from dimos.agents.modules.base_agent import BaseAgentModule from dimos.agents.modules.unified_agent import UnifiedAgentModule +from dimos.core import In, Module, Out, rpc from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.agents.modules.agent_pool") @@ -38,10 +38,12 @@ class AgentPoolModule(Module): """ # Module I/O - query_in: In[Dict[str, Any]] = None # {agent_id: str, query: str, ...} - response_out: Out[Dict[str, Any]] = None # {agent_id: str, response: str, ...} + query_in: In[dict[str, Any]] = None # {agent_id: str, query: str, ...} + response_out: Out[dict[str, Any]] = None # {agent_id: str, response: str, ...} - def __init__(self, agents_config: Dict[str, Dict[str, Any]], default_agent: str = None): + def __init__( + self, agents_config: dict[str, dict[str, Any]], default_agent: str | None = None + ) -> None: """Initialize agent pool. Args: @@ -66,7 +68,7 @@ def __init__(self, agents_config: Dict[str, Dict[str, Any]], default_agent: str self._response_subject = Subject() @rpc - def start(self): + def start(self) -> None: """Deploy and start all agents.""" super().start() logger.info(f"Starting agent pool with {len(self._config)} agents") @@ -103,7 +105,7 @@ def start(self): logger.info("Agent pool started") @rpc - def stop(self): + def stop(self) -> None: """Stop all agents.""" logger.info("Stopping agent pool") @@ -119,7 +121,7 @@ def stop(self): super().stop() @rpc - def add_agent(self, agent_id: str, config: Dict[str, Any]): + def add_agent(self, agent_id: str, config: dict[str, Any]) -> None: """Add a new agent to the pool.""" if agent_id in self._agents: logger.warning(f"Agent {agent_id} already exists") @@ -142,7 +144,7 @@ def add_agent(self, agent_id: str, config: Dict[str, Any]): logger.info(f"Added agent: {agent_id}") @rpc - def remove_agent(self, agent_id: str): + def remove_agent(self, agent_id: str) -> None: """Remove an agent from the pool.""" if agent_id not in self._agents: logger.warning(f"Agent {agent_id} not found") @@ -156,7 +158,7 @@ def remove_agent(self, agent_id: str): logger.info(f"Removed agent: {agent_id}") @rpc - def list_agents(self) -> List[Dict[str, Any]]: + def list_agents(self) -> list[dict[str, Any]]: """List all agents and their configurations.""" return [ {"id": agent_id, "type": info["type"], "model": info["config"].get("model", "unknown")} @@ -164,7 +166,7 @@ def list_agents(self) -> List[Dict[str, Any]]: ] @rpc - def broadcast_query(self, query: str, exclude: List[str] = None): + def broadcast_query(self, query: str, exclude: list[str] | None = None) -> None: """Send query to all agents (except excluded ones).""" exclude = exclude or [] @@ -175,12 +177,12 @@ def broadcast_query(self, query: str, exclude: List[str] = None): logger.info(f"Broadcasted query to {len(self._agents) - len(exclude)} agents") def _setup_agent_routing( - self, agent_id: str, agent: Union[BaseAgentModule, UnifiedAgentModule] - ): + self, agent_id: str, agent: BaseAgentModule | UnifiedAgentModule + ) -> None: """Setup response routing for an agent.""" # Subscribe to agent responses and tag with agent_id - def tag_response(response: str) -> Dict[str, Any]: + def tag_response(response: str) -> dict[str, Any]: return { "agent_id": agent_id, "response": response, @@ -193,7 +195,7 @@ def tag_response(response: str) -> Dict[str, Any]: .subscribe(self._response_subject.on_next) ) - def _route_query(self, msg: Dict[str, Any]): + def _route_query(self, msg: dict[str, Any]) -> None: """Route incoming query to appropriate agent(s).""" # Extract routing info agent_id = msg.get("agent_id", self._default_agent) diff --git a/dimos/agents/modules/base.py b/dimos/agents/modules/base.py index ef778e2da4..9caaac49cc 100644 --- a/dimos/agents/modules/base.py +++ b/dimos/agents/modules/base.py @@ -15,18 +15,18 @@ """Base agent class with all features (non-module).""" import asyncio -import json from concurrent.futures import ThreadPoolExecutor -from typing import Any, Dict, List, Optional, Union +import json +from typing import Any from reactivex.subject import Subject +from dimos.agents.agent_message import AgentMessage +from dimos.agents.agent_types import AgentResponse, ConversationHistory, ToolCall from dimos.agents.memory.base import AbstractAgentSemanticMemory from dimos.agents.memory.chroma_impl import OpenAISemanticMemory from dimos.skills.skills import AbstractSkill, SkillLibrary from dimos.utils.logging_config import setup_logger -from dimos.agents.agent_message import AgentMessage -from dimos.agents.agent_types import AgentResponse, ToolCall, ConversationHistory try: from .gateway import UnifiedGatewayClient @@ -66,21 +66,21 @@ class BaseAgent: def __init__( self, model: str = "openai::gpt-4o-mini", - system_prompt: Optional[str] = None, - skills: Optional[Union[SkillLibrary, List[AbstractSkill], AbstractSkill]] = None, - memory: Optional[AbstractAgentSemanticMemory] = None, + system_prompt: str | None = None, + skills: SkillLibrary | list[AbstractSkill] | AbstractSkill | None = None, + memory: AbstractAgentSemanticMemory | None = None, temperature: float = 0.0, max_tokens: int = 4096, max_input_tokens: int = 128000, max_history: int = 20, rag_n: int = 4, rag_threshold: float = 0.45, - seed: Optional[int] = None, + seed: int | None = None, # Legacy compatibility dev_name: str = "BaseAgent", agent_type: str = "LLM", **kwargs, - ): + ) -> None: """Initialize the base agent with all features. Args: @@ -155,7 +155,7 @@ def max_history(self) -> int: return self._max_history @max_history.setter - def max_history(self, value: int): + def max_history(self, value: int) -> None: """Set max history size and update conversation.""" self._max_history = value self.conversation.max_size = value @@ -164,7 +164,7 @@ def _check_vision_support(self) -> bool: """Check if the model supports vision.""" return self.model in VISION_MODELS - def _initialize_memory(self): + def _initialize_memory(self) -> None: """Initialize memory with default context.""" try: contexts = [ @@ -252,7 +252,7 @@ async def _process_query_async(self, agent_msg: AgentMessage) -> AgentResponse: # Check for tool calls tool_calls = None - if "tool_calls" in message and message["tool_calls"]: + if message.get("tool_calls"): tool_calls = [ ToolCall( id=tc["id"], @@ -319,7 +319,7 @@ def _get_rag_context(self, query: str) -> str: def _build_messages( self, agent_msg: AgentMessage, rag_context: str = "" - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Build messages list from AgentMessage.""" messages = [] @@ -376,9 +376,9 @@ def _build_messages( async def _handle_tool_calls( self, - tool_calls: List[ToolCall], - messages: List[Dict[str, Any]], - user_message: Dict[str, Any], + tool_calls: list[ToolCall], + messages: list[dict[str, Any]], + user_message: dict[str, Any], ) -> str: """Handle tool calls from LLM (blocking mode by default).""" try: @@ -424,7 +424,7 @@ async def _handle_tool_calls( tool_result = { "role": "tool", "tool_call_id": tool_call.id, - "content": f"Error: {str(e)}", + "content": f"Error: {e!s}", "name": tool_call.name, } tool_results.append(tool_result) @@ -472,9 +472,9 @@ async def _handle_tool_calls( except Exception as e: logger.error(f"Error handling tool calls: {e}") - return f"Error executing tools: {str(e)}" + return f"Error executing tools: {e!s}" - def query(self, message: Union[str, AgentMessage]) -> AgentResponse: + def query(self, message: str | AgentMessage) -> AgentResponse: """Synchronous query method for direct usage. Args: @@ -498,7 +498,7 @@ def query(self, message: Union[str, AgentMessage]) -> AgentResponse: finally: loop.close() - async def aquery(self, message: Union[str, AgentMessage]) -> AgentResponse: + async def aquery(self, message: str | AgentMessage) -> AgentResponse: """Asynchronous query method. Args: diff --git a/dimos/agents/modules/base_agent.py b/dimos/agents/modules/base_agent.py index 3c83214f6c..0bceb1112e 100644 --- a/dimos/agents/modules/base_agent.py +++ b/dimos/agents/modules/base_agent.py @@ -15,12 +15,12 @@ """Base agent module that wraps BaseAgent for DimOS module usage.""" import threading -from typing import Any, Dict, List, Optional, Union +from typing import Any -from dimos.core import Module, In, Out, rpc -from dimos.agents.memory.base import AbstractAgentSemanticMemory from dimos.agents.agent_message import AgentMessage from dimos.agents.agent_types import AgentResponse +from dimos.agents.memory.base import AbstractAgentSemanticMemory +from dimos.core import In, Module, Out, rpc from dimos.skills.skills import AbstractSkill, SkillLibrary from dimos.utils.logging_config import setup_logger @@ -46,9 +46,9 @@ class BaseAgentModule(BaseAgent, Module): def __init__( self, model: str = "openai::gpt-4o-mini", - system_prompt: Optional[str] = None, - skills: Optional[Union[SkillLibrary, List[AbstractSkill], AbstractSkill]] = None, - memory: Optional[AbstractAgentSemanticMemory] = None, + system_prompt: str | None = None, + skills: SkillLibrary | list[AbstractSkill] | AbstractSkill | None = None, + memory: AbstractAgentSemanticMemory | None = None, temperature: float = 0.0, max_tokens: int = 4096, max_input_tokens: int = 128000, @@ -57,7 +57,7 @@ def __init__( rag_threshold: float = 0.45, process_all_inputs: bool = False, **kwargs, - ): + ) -> None: """Initialize the agent module. Args: @@ -107,7 +107,7 @@ def __init__( self._data_lock = threading.Lock() @rpc - def start(self): + def start(self) -> None: """Start the agent module and connect streams.""" super().start() logger.info(f"Starting agent module with model: {self.model}") @@ -132,7 +132,7 @@ def start(self): logger.info("Agent module started") @rpc - def stop(self): + def stop(self) -> None: """Stop the agent module.""" logger.info("Stopping agent module") @@ -148,31 +148,31 @@ def stop(self): super().stop() @rpc - def clear_history(self): + def clear_history(self) -> None: """Clear conversation history.""" with self._history_lock: self.history = [] logger.info("Conversation history cleared") @rpc - def add_skill(self, skill: AbstractSkill): + def add_skill(self, skill: AbstractSkill) -> None: """Add a skill to the agent.""" self.skills.add(skill) logger.info(f"Added skill: {skill.__class__.__name__}") @rpc - def set_system_prompt(self, prompt: str): + def set_system_prompt(self, prompt: str) -> None: """Update system prompt.""" self.system_prompt = prompt logger.info("System prompt updated") @rpc - def get_conversation_history(self) -> List[Dict[str, Any]]: + def get_conversation_history(self) -> list[dict[str, Any]]: """Get current conversation history.""" with self._history_lock: return self.history.copy() - def _handle_agent_message(self, message: AgentMessage): + def _handle_agent_message(self, message: AgentMessage) -> None: """Handle AgentMessage from module input.""" # Process through BaseAgent query method try: @@ -183,7 +183,7 @@ def _handle_agent_message(self, message: AgentMessage): logger.error(f"Agent message processing error: {e}") self.response_subject.on_error(e) - def _handle_module_query(self, query: str): + def _handle_module_query(self, query: str) -> None: """Handle legacy query from module input.""" # For simple text queries, just convert to AgentMessage agent_msg = AgentMessage() @@ -192,17 +192,17 @@ def _handle_module_query(self, query: str): # Process through unified handler self._handle_agent_message(agent_msg) - def _update_latest_data(self, data: Dict[str, Any]): + def _update_latest_data(self, data: dict[str, Any]) -> None: """Update latest data context.""" with self._data_lock: self._latest_data = data - def _update_latest_image(self, img: Any): + def _update_latest_image(self, img: Any) -> None: """Update latest image.""" with self._image_lock: self._latest_image = img - def _format_data_context(self, data: Dict[str, Any]) -> str: + def _format_data_context(self, data: dict[str, Any]) -> str: """Format data dictionary as context string.""" # Simple formatting - can be customized parts = [] diff --git a/dimos/agents/modules/gateway/client.py b/dimos/agents/modules/gateway/client.py index f873f0ec64..6d8abf5e14 100644 --- a/dimos/agents/modules/gateway/client.py +++ b/dimos/agents/modules/gateway/client.py @@ -15,9 +15,12 @@ """Unified gateway client for LLM access.""" import asyncio +from collections.abc import AsyncIterator, Iterator import logging import os -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union +from types import TracebackType +from typing import Any + import httpx from tenacity import retry, stop_after_attempt, wait_exponential @@ -34,8 +37,8 @@ class UnifiedGatewayClient: """ def __init__( - self, gateway_url: Optional[str] = None, timeout: float = 60.0, use_simple: bool = False - ): + self, gateway_url: str | None = None, timeout: float = 60.0, use_simple: bool = False + ) -> None: """Initialize the gateway client. Args: @@ -82,13 +85,13 @@ def _get_async_client(self) -> httpx.AsyncClient: def inference( self, model: str, - messages: List[Dict[str, Any]], - tools: Optional[List[Dict[str, Any]]] = None, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, temperature: float = 0.0, - max_tokens: Optional[int] = None, + max_tokens: int | None = None, stream: bool = False, **kwargs, - ) -> Union[Dict[str, Any], Iterator[Dict[str, Any]]]: + ) -> dict[str, Any] | Iterator[dict[str, Any]]: """Synchronous inference call. Args: @@ -117,13 +120,13 @@ def inference( async def ainference( self, model: str, - messages: List[Dict[str, Any]], - tools: Optional[List[Dict[str, Any]]] = None, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, temperature: float = 0.0, - max_tokens: Optional[int] = None, + max_tokens: int | None = None, stream: bool = False, **kwargs, - ) -> Union[Dict[str, Any], AsyncIterator[Dict[str, Any]]]: + ) -> dict[str, Any] | AsyncIterator[dict[str, Any]]: """Asynchronous inference call. Args: @@ -148,7 +151,7 @@ async def ainference( **kwargs, ) - def close(self): + def close(self) -> None: """Close the HTTP clients.""" if self._client: self._client.close() @@ -159,14 +162,14 @@ def close(self): pass self._tensorzero_client.close() - async def aclose(self): + async def aclose(self) -> None: """Async close method.""" if self._async_client: await self._async_client.aclose() self._async_client = None await self._tensorzero_client.aclose() - def __del__(self): + def __del__(self) -> None: """Cleanup on deletion.""" self.close() if self._async_client: @@ -185,7 +188,12 @@ def __enter__(self): """Context manager entry.""" return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: """Context manager exit.""" self.close() @@ -193,6 +201,11 @@ async def __aenter__(self): """Async context manager entry.""" return self - async def __aexit__(self, exc_type, exc_val, exc_tb): + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: """Async context manager exit.""" await self.aclose() diff --git a/dimos/agents/modules/gateway/tensorzero_embedded.py b/dimos/agents/modules/gateway/tensorzero_embedded.py index af04ec099b..90d30fe82d 100644 --- a/dimos/agents/modules/gateway/tensorzero_embedded.py +++ b/dimos/agents/modules/gateway/tensorzero_embedded.py @@ -14,11 +14,10 @@ """TensorZero embedded gateway client with correct config format.""" -import os -import json +from collections.abc import AsyncIterator, Iterator import logging -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union from pathlib import Path +from typing import Any logger = logging.getLogger(__name__) @@ -26,14 +25,14 @@ class TensorZeroEmbeddedGateway: """TensorZero embedded gateway using patch_openai_client.""" - def __init__(self): + def __init__(self) -> None: """Initialize TensorZero embedded gateway.""" self._client = None self._config_path = None self._setup_config() self._initialize_client() - def _setup_config(self): + def _setup_config(self) -> None: """Create TensorZero configuration with correct format.""" config_dir = Path("/tmp/tensorzero_embedded") config_dir.mkdir(exist_ok=True) @@ -81,7 +80,7 @@ def _setup_config(self): # Cerebras Models - disabled for CI (no API key) # [models.llama_3_3_70b] # routing = ["cerebras"] -# +# # [models.llama_3_3_70b.providers.cerebras] # type = "openai" # model_name = "llama-3.3-70b" @@ -180,13 +179,13 @@ def _map_model_to_tensorzero(self, model: str) -> str: def inference( self, model: str, - messages: List[Dict[str, Any]], - tools: Optional[List[Dict[str, Any]]] = None, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, temperature: float = 0.0, - max_tokens: Optional[int] = None, + max_tokens: int | None = None, stream: bool = False, **kwargs, - ) -> Union[Dict[str, Any], Iterator[Dict[str, Any]]]: + ) -> dict[str, Any] | Iterator[dict[str, Any]]: """Synchronous inference call through TensorZero.""" # Map model to TensorZero function @@ -233,13 +232,13 @@ def stream_generator(): async def ainference( self, model: str, - messages: List[Dict[str, Any]], - tools: Optional[List[Dict[str, Any]]] = None, + messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, temperature: float = 0.0, - max_tokens: Optional[int] = None, + max_tokens: int | None = None, stream: bool = False, **kwargs, - ) -> Union[Dict[str, Any], AsyncIterator[Dict[str, Any]]]: + ) -> dict[str, Any] | AsyncIterator[dict[str, Any]]: """Async inference with streaming support.""" import asyncio @@ -270,12 +269,12 @@ async def stream_generator(): ) return result - def close(self): + def close(self) -> None: """Close the client.""" # TensorZero embedded doesn't need explicit cleanup pass - async def aclose(self): + async def aclose(self) -> None: """Async close.""" # TensorZero embedded doesn't need explicit cleanup pass diff --git a/dimos/agents/modules/gateway/tensorzero_simple.py b/dimos/agents/modules/gateway/tensorzero_simple.py index 21809bdef5..a2cc57e2fb 100644 --- a/dimos/agents/modules/gateway/tensorzero_simple.py +++ b/dimos/agents/modules/gateway/tensorzero_simple.py @@ -15,11 +15,11 @@ """Minimal TensorZero test to get it working.""" -import os from pathlib import Path + +from dotenv import load_dotenv from openai import OpenAI from tensorzero import patch_openai_client -from dotenv import load_dotenv load_dotenv() diff --git a/dimos/agents/modules/gateway/utils.py b/dimos/agents/modules/gateway/utils.py index e95a4dad04..ac9dc3e364 100644 --- a/dimos/agents/modules/gateway/utils.py +++ b/dimos/agents/modules/gateway/utils.py @@ -14,14 +14,13 @@ """Utility functions for gateway operations.""" -from typing import Any, Dict, List, Optional, Union -import json import logging +from typing import Any logger = logging.getLogger(__name__) -def convert_tools_to_standard_format(tools: List[Dict[str, Any]]) -> List[Dict[str, Any]]: +def convert_tools_to_standard_format(tools: list[dict[str, Any]]) -> list[dict[str, Any]]: """Convert DimOS tool format to standard format accepted by gateways. DimOS tools come from pydantic_function_tool and have this format: @@ -47,7 +46,7 @@ def convert_tools_to_standard_format(tools: List[Dict[str, Any]]) -> List[Dict[s return tools -def parse_streaming_response(chunk: Dict[str, Any]) -> Dict[str, Any]: +def parse_streaming_response(chunk: dict[str, Any]) -> dict[str, Any]: """Parse a streaming response chunk into a standard format. Args: @@ -103,7 +102,7 @@ def parse_streaming_response(chunk: Dict[str, Any]) -> Dict[str, Any]: return {"type": "unknown", "content": chunk, "metadata": {}} -def create_tool_response(tool_id: str, result: Any, is_error: bool = False) -> Dict[str, Any]: +def create_tool_response(tool_id: str, result: Any, is_error: bool = False) -> dict[str, Any]: """Create a properly formatted tool response. Args: @@ -124,7 +123,7 @@ def create_tool_response(tool_id: str, result: Any, is_error: bool = False) -> D } -def extract_image_from_message(message: Dict[str, Any]) -> Optional[Dict[str, Any]]: +def extract_image_from_message(message: dict[str, Any]) -> dict[str, Any] | None: """Extract image data from a message if present. Args: diff --git a/dimos/agents/modules/simple_vision_agent.py b/dimos/agents/modules/simple_vision_agent.py index 9bb6fb9894..b4888fd073 100644 --- a/dimos/agents/modules/simple_vision_agent.py +++ b/dimos/agents/modules/simple_vision_agent.py @@ -18,16 +18,15 @@ import base64 import io import threading -from typing import Optional import numpy as np from PIL import Image as PILImage +from reactivex.disposable import Disposable -from dimos.core import Module, In, Out, rpc +from dimos.agents.modules.gateway import UnifiedGatewayClient +from dimos.core import In, Module, Out, rpc from dimos.msgs.sensor_msgs import Image from dimos.utils.logging_config import setup_logger -from dimos.agents.modules.gateway import UnifiedGatewayClient -from reactivex.disposable import Disposable logger = setup_logger(__file__) @@ -46,10 +45,10 @@ class SimpleVisionAgentModule(Module): def __init__( self, model: str = "openai::gpt-4o-mini", - system_prompt: str = None, + system_prompt: str | None = None, temperature: float = 0.0, max_tokens: int = 4096, - ): + ) -> None: """Initialize the vision agent. Args: @@ -72,7 +71,7 @@ def __init__( self._lock = threading.Lock() @rpc - def start(self): + def start(self) -> None: """Initialize and start the agent.""" super().start() @@ -93,21 +92,21 @@ def start(self): logger.info("Simple vision agent started") @rpc - def stop(self): + def stop(self) -> None: logger.info("Stopping simple vision agent") if self.gateway: self.gateway.close() super().stop() - def _handle_image(self, image: Image): + def _handle_image(self, image: Image) -> None: """Handle incoming image.""" logger.info( f"Received new image: {image.data.shape if hasattr(image, 'data') else 'unknown shape'}" ) self._latest_image = image - def _handle_query(self, query: str): + def _handle_query(self, query: str) -> None: """Handle text query.""" with self._lock: if self._processing: @@ -120,11 +119,11 @@ def _handle_query(self, query: str): thread.daemon = True thread.start() - def _run_async_query(self, query: str): + def _run_async_query(self, query: str) -> None: """Run async query in new event loop.""" asyncio.run(self._process_query(query)) - async def _process_query(self, query: str): + async def _process_query(self, query: str) -> None: """Process the query.""" try: logger.info(f"Processing query: {query}") @@ -206,12 +205,12 @@ async def _process_query(self, query: str): traceback.print_exc() if self.response_out: - self.response_out.publish(f"Error: {str(e)}") + self.response_out.publish(f"Error: {e!s}") finally: with self._lock: self._processing = False - def _encode_image(self, image: Image) -> Optional[str]: + def _encode_image(self, image: Image) -> str | None: """Encode image to base64.""" try: # Convert to numpy array if needed diff --git a/dimos/agents/planning_agent.py b/dimos/agents/planning_agent.py index 52971e770a..6dbdbf5866 100644 --- a/dimos/agents/planning_agent.py +++ b/dimos/agents/planning_agent.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from textwrap import dedent import threading -from typing import List, Optional, Literal -from reactivex import Observable -from reactivex import operators as ops import time -from dimos.skills.skills import AbstractSkill +from typing import Literal + +from pydantic import BaseModel +from reactivex import Observable, operators as ops + from dimos.agents.agent import OpenAIAgent +from dimos.skills.skills import AbstractSkill from dimos.utils.logging_config import setup_logger -from textwrap import dedent -from pydantic import BaseModel logger = setup_logger("dimos.agents.planning_agent") @@ -29,7 +30,7 @@ # For response validation class PlanningAgentResponse(BaseModel): type: Literal["dialogue", "plan"] - content: List[str] + content: list[str] needs_confirmation: bool @@ -50,10 +51,10 @@ def __init__( self, dev_name: str = "PlanningAgent", model_name: str = "gpt-4", - input_query_stream: Optional[Observable] = None, + input_query_stream: Observable | None = None, use_terminal: bool = False, - skills: Optional[AbstractSkill] = None, - ): + skills: AbstractSkill | None = None, + ) -> None: """Initialize the planning agent. Args: @@ -192,9 +193,9 @@ def _send_query(self, messages: list) -> PlanningAgentResponse: try: return super()._send_query(messages) except Exception as e: - logger.error(f"Caught exception in _send_query: {str(e)}") + logger.error(f"Caught exception in _send_query: {e!s}") return PlanningAgentResponse( - type="dialogue", content=f"Error: {str(e)}", needs_confirmation=False + type="dialogue", content=f"Error: {e!s}", needs_confirmation=False ) def process_user_input(self, user_input: str) -> None: @@ -244,7 +245,7 @@ def process_user_input(self, user_input: str) -> None: response = self._send_query(messages) self._handle_response(response) - def start_terminal_interface(self): + def start_terminal_interface(self) -> None: """Start the terminal interface for input/output.""" time.sleep(5) # buffer time for clean terminal interface printing @@ -298,7 +299,7 @@ def get_response_observable(self) -> Observable: Observable: An observable that emits plan steps from the agent. """ - def extract_content(response) -> List[str]: + def extract_content(response) -> list[str]: if isinstance(response, PlanningAgentResponse): if response.type == "plan": return response.content # List of steps to be emitted individually diff --git a/dimos/agents/prompt_builder/impl.py b/dimos/agents/prompt_builder/impl.py index 0e66191837..9cd532fea9 100644 --- a/dimos/agents/prompt_builder/impl.py +++ b/dimos/agents/prompt_builder/impl.py @@ -14,7 +14,7 @@ from textwrap import dedent -from typing import Optional + from dimos.agents.tokenizer.base import AbstractTokenizer from dimos.agents.tokenizer.openai_tokenizer import OpenAITokenizer @@ -24,9 +24,9 @@ class PromptBuilder: DEFAULT_SYSTEM_PROMPT = dedent(""" - You are an AI assistant capable of understanding and analyzing both visual and textual information. - Your task is to provide accurate and insightful responses based on the data provided to you. - Use the following information to assist the user with their query. Do not rely on any internal + You are an AI assistant capable of understanding and analyzing both visual and textual information. + Your task is to provide accurate and insightful responses based on the data provided to you. + Use the following information to assist the user with their query. Do not rely on any internal knowledge or make assumptions beyond the provided data. Visual Context: You may have been given an image to analyze. Use the visual details to enhance your response. @@ -39,8 +39,11 @@ class PromptBuilder: """) def __init__( - self, model_name="gpt-4o", max_tokens=128000, tokenizer: Optional[AbstractTokenizer] = None - ): + self, + model_name: str = "gpt-4o", + max_tokens: int = 128000, + tokenizer: AbstractTokenizer | None = None, + ) -> None: """ Initialize the prompt builder. Args: @@ -52,7 +55,7 @@ def __init__( self.max_tokens = max_tokens self.tokenizer: AbstractTokenizer = tokenizer or OpenAITokenizer(model_name=self.model_name) - def truncate_tokens(self, text, max_tokens, strategy): + def truncate_tokens(self, text: str, max_tokens, strategy): """ Truncate text to fit within max_tokens using a specified strategy. Args: @@ -88,11 +91,11 @@ def build( base64_image=None, image_width=None, image_height=None, - image_detail="low", + image_detail: str = "low", rag_context=None, budgets=None, policies=None, - override_token_limit=False, + override_token_limit: bool = False, ): """ Builds a dynamic prompt tailored to token limits, respecting budgets and policies. diff --git a/dimos/agents/test_agent_image_message.py b/dimos/agents/test_agent_image_message.py index 5f30dcf9cd..c7f84bcefe 100644 --- a/dimos/agents/test_agent_image_message.py +++ b/dimos/agents/test_agent_image_message.py @@ -18,9 +18,9 @@ import logging import os +from dotenv import load_dotenv import numpy as np import pytest -from dotenv import load_dotenv from dimos.agents.agent_message import AgentMessage from dimos.agents.modules.base import BaseAgent @@ -34,7 +34,7 @@ @pytest.mark.tofix -def test_agent_single_image(): +def test_agent_single_image() -> None: """Test agent with single image in AgentMessage.""" load_dotenv() @@ -95,7 +95,7 @@ def test_agent_single_image(): @pytest.mark.tofix -def test_agent_multiple_images(): +def test_agent_multiple_images() -> None: """Test agent with multiple images in AgentMessage.""" load_dotenv() @@ -163,7 +163,7 @@ def test_agent_multiple_images(): @pytest.mark.tofix -def test_agent_image_with_context(): +def test_agent_image_with_context() -> None: """Test agent maintaining context with image queries.""" load_dotenv() @@ -212,7 +212,7 @@ def test_agent_image_with_context(): @pytest.mark.tofix -def test_agent_mixed_content(): +def test_agent_mixed_content() -> None: """Test agent with mixed text-only and image queries.""" load_dotenv() @@ -290,7 +290,7 @@ def test_agent_mixed_content(): @pytest.mark.tofix -def test_agent_empty_image_message(): +def test_agent_empty_image_message() -> None: """Test edge case with empty parts of AgentMessage.""" load_dotenv() @@ -338,7 +338,7 @@ def test_agent_empty_image_message(): @pytest.mark.tofix -def test_agent_non_vision_model_with_images(): +def test_agent_non_vision_model_with_images() -> None: """Test that non-vision models handle image input gracefully.""" load_dotenv() @@ -375,7 +375,7 @@ def test_agent_non_vision_model_with_images(): @pytest.mark.tofix -def test_mock_agent_with_images(): +def test_mock_agent_with_images() -> None: """Test mock agent with images for CI.""" # This test doesn't need API keys diff --git a/dimos/agents/test_agent_message_streams.py b/dimos/agents/test_agent_message_streams.py index a84a0ed48e..22d33b46de 100644 --- a/dimos/agents/test_agent_message_streams.py +++ b/dimos/agents/test_agent_message_streams.py @@ -17,23 +17,22 @@ import asyncio import os -import time -from dotenv import load_dotenv -import pytest import pickle +from dotenv import load_dotenv +import pytest from reactivex import operators as ops from dimos import core -from dimos.core import Module, In, Out, rpc -from dimos.agents.modules.base_agent import BaseAgentModule from dimos.agents.agent_message import AgentMessage from dimos.agents.agent_types import AgentResponse +from dimos.agents.modules.base_agent import BaseAgentModule +from dimos.core import In, Module, Out, rpc from dimos.msgs.sensor_msgs import Image from dimos.protocol import pubsub from dimos.utils.data import get_data -from dimos.utils.testing import TimedSensorReplay from dimos.utils.logging_config import setup_logger +from dimos.utils.testing import TimedSensorReplay logger = setup_logger("test_agent_message_streams") @@ -43,14 +42,14 @@ class VideoMessageSender(Module): message_out: Out[AgentMessage] = None - def __init__(self, video_path: str): + def __init__(self, video_path: str) -> None: super().__init__() self.video_path = video_path self._subscription = None self._frame_count = 0 @rpc - def start(self): + def start(self) -> None: """Start sending video messages.""" # Use TimedSensorReplay to replay video frames video_replay = TimedSensorReplay(self.video_path, autocast=Image.from_numpy) @@ -83,12 +82,12 @@ def _create_message(self, frame: Image) -> AgentMessage: logger.info(f"Created message with frame {self._frame_count}") return msg - def _send_message(self, msg: AgentMessage): + def _send_message(self, msg: AgentMessage) -> None: """Send the message and test pickling.""" # Test that message can be pickled (for module communication) try: pickled = pickle.dumps(msg) - unpickled = pickle.loads(pickled) + pickle.loads(pickled) logger.info(f"Message pickling test passed - size: {len(pickled)} bytes") except Exception as e: logger.error(f"Message pickling failed: {e}") @@ -96,7 +95,7 @@ def _send_message(self, msg: AgentMessage): self.message_out.publish(msg) @rpc - def stop(self): + def stop(self) -> None: """Stop streaming.""" if self._subscription: self._subscription.dispose() @@ -108,13 +107,13 @@ class MultiImageMessageSender(Module): message_out: Out[AgentMessage] = None - def __init__(self, video_path: str): + def __init__(self, video_path: str) -> None: super().__init__() self.video_path = video_path self.frames = [] @rpc - def start(self): + def start(self) -> None: """Collect some frames.""" video_replay = TimedSensorReplay(self.video_path, autocast=Image.from_numpy) @@ -124,13 +123,13 @@ def start(self): on_completed=self._send_multi_image_query, ) - def _send_multi_image_query(self): + def _send_multi_image_query(self) -> None: """Send query with multiple images.""" if len(self.frames) >= 2: msg = AgentMessage() msg.add_text("Compare these images and describe what changed between them.") - for i, frame in enumerate(self.frames[:2]): + for _i, frame in enumerate(self.frames[:2]): msg.add_image(frame) logger.info(f"Sending multi-image message with {len(msg.images)} images") @@ -150,15 +149,15 @@ class ResponseCollector(Module): response_in: In[AgentResponse] = None - def __init__(self): + def __init__(self) -> None: super().__init__() self.responses = [] @rpc - def start(self): + def start(self) -> None: self.response_in.subscribe(self._on_response) - def _on_response(self, resp: AgentResponse): + def _on_response(self, resp: AgentResponse) -> None: logger.info(f"Collected response: {resp.content[:100] if resp.content else 'None'}...") self.responses.append(resp) @@ -170,7 +169,7 @@ def get_responses(self): @pytest.mark.tofix @pytest.mark.module @pytest.mark.asyncio -async def test_agent_message_video_stream(): +async def test_agent_message_video_stream() -> None: """Test BaseAgentModule with AgentMessage containing video frames.""" load_dotenv() @@ -254,7 +253,7 @@ async def test_agent_message_video_stream(): @pytest.mark.tofix @pytest.mark.module @pytest.mark.asyncio -async def test_agent_message_multi_image(): +async def test_agent_message_multi_image() -> None: """Test BaseAgentModule with AgentMessage containing multiple images.""" load_dotenv() @@ -330,7 +329,7 @@ async def test_agent_message_multi_image(): @pytest.mark.tofix -def test_agent_message_text_only(): +def test_agent_message_text_only() -> None: """Test BaseAgent with text-only AgentMessage.""" load_dotenv() @@ -354,7 +353,7 @@ def test_agent_message_text_only(): msg.add_text("of France?") response = agent.query(msg) - assert "Paris" in response.content, f"Expected 'Paris' in response" + assert "Paris" in response.content, "Expected 'Paris' in response" # Test pickling of AgentMessage pickled = pickle.dumps(msg) diff --git a/dimos/agents/test_agent_pool.py b/dimos/agents/test_agent_pool.py index 9c0b530b68..b3576b80e2 100644 --- a/dimos/agents/test_agent_pool.py +++ b/dimos/agents/test_agent_pool.py @@ -16,12 +16,13 @@ import asyncio import os -import pytest + from dotenv import load_dotenv +import pytest from dimos import core -from dimos.core import Module, Out, In, rpc from dimos.agents.modules.base_agent import BaseAgentModule +from dimos.core import In, Module, Out, rpc from dimos.protocol import pubsub @@ -34,10 +35,10 @@ class PoolRouter(Module): agent3_out: Out[str] = None @rpc - def start(self): + def start(self) -> None: self.query_in.subscribe(self._route) - def _route(self, msg: dict): + def _route(self, msg: dict) -> None: agent_id = msg.get("agent_id", "agent1") query = msg.get("query", "") @@ -66,7 +67,7 @@ class PoolAggregator(Module): response_out: Out[dict] = None @rpc - def start(self): + def start(self) -> None: if self.agent1_in: self.agent1_in.subscribe(lambda r: self._handle_response("agent1", r)) if self.agent2_in: @@ -74,7 +75,7 @@ def start(self): if self.agent3_in: self.agent3_in.subscribe(lambda r: self._handle_response("agent3", r)) - def _handle_response(self, agent_id: str, response: str): + def _handle_response(self, agent_id: str, response: str) -> None: if self.response_out: self.response_out.publish({"agent_id": agent_id, "response": response}) @@ -85,11 +86,11 @@ class PoolController(Module): query_out: Out[dict] = None @rpc - def send_to_agent(self, agent_id: str, query: str): + def send_to_agent(self, agent_id: str, query: str) -> None: self.query_out.publish({"agent_id": agent_id, "query": query}) @rpc - def broadcast(self, query: str): + def broadcast(self, query: str) -> None: self.query_out.publish({"agent_id": "all", "query": query}) @@ -98,12 +99,12 @@ class PoolCollector(Module): response_in: In[dict] = None - def __init__(self): + def __init__(self) -> None: super().__init__() self.responses = [] @rpc - def start(self): + def start(self) -> None: self.response_in.subscribe(lambda r: self.responses.append(r)) @rpc @@ -118,7 +119,7 @@ def get_by_agent(self, agent_id: str) -> list: @pytest.mark.skip("Skipping pool tests for now") @pytest.mark.module @pytest.mark.asyncio -async def test_agent_pool(): +async def test_agent_pool() -> None: """Test agent pool with multiple agents.""" load_dotenv() pubsub.lcm.autoconf() @@ -211,7 +212,7 @@ async def test_agent_pool(): await asyncio.sleep(3) # Test direct routing - for i, model_id in enumerate(models[:2]): # Test first 2 agents + for _i, model_id in enumerate(models[:2]): # Test first 2 agents controller.send_to_agent(model_id, f"Say hello from {model_id}") await asyncio.sleep(0.5) @@ -252,7 +253,7 @@ async def test_agent_pool(): @pytest.mark.skip("Skipping pool tests for now") @pytest.mark.module @pytest.mark.asyncio -async def test_mock_agent_pool(): +async def test_mock_agent_pool() -> None: """Test agent pool with mock agents.""" pubsub.lcm.autoconf() @@ -262,15 +263,15 @@ class MockPoolAgent(Module): query_in: In[str] = None response_out: Out[str] = None - def __init__(self, agent_id: str): + def __init__(self, agent_id: str) -> None: super().__init__() self.agent_id = agent_id @rpc - def start(self): + def start(self) -> None: self.query_in.subscribe(self._handle_query) - def _handle_query(self, query: str): + def _handle_query(self, query: str) -> None: if "1+1" in query: self.response_out.publish(f"{self.agent_id}: The answer is 2") else: diff --git a/dimos/agents/test_agent_tools.py b/dimos/agents/test_agent_tools.py index 5e3c021772..fd485ac015 100644 --- a/dimos/agents/test_agent_tools.py +++ b/dimos/agents/test_agent_tools.py @@ -14,20 +14,21 @@ """Production test for BaseAgent tool handling functionality.""" -import pytest import asyncio import os + from dotenv import load_dotenv from pydantic import Field +import pytest -from dimos.agents.modules.base import BaseAgent -from dimos.agents.modules.base_agent import BaseAgentModule +from dimos import core from dimos.agents.agent_message import AgentMessage from dimos.agents.agent_types import AgentResponse -from dimos.skills.skills import AbstractSkill, SkillLibrary -from dimos import core -from dimos.core import Module, Out, In, rpc +from dimos.agents.modules.base import BaseAgent +from dimos.agents.modules.base_agent import BaseAgentModule +from dimos.core import In, Module, Out, rpc from dimos.protocol import pubsub +from dimos.skills.skills import AbstractSkill, SkillLibrary from dimos.utils.logging_config import setup_logger logger = setup_logger("test_agent_tools") @@ -45,7 +46,7 @@ def __call__(self) -> str: result = eval(self.expression) return f"The result is {result}" except Exception as e: - return f"Error calculating: {str(e)}" + return f"Error calculating: {e!s}" class WeatherSkill(AbstractSkill): @@ -80,7 +81,7 @@ class ToolTestController(Module): message_out: Out[AgentMessage] = None @rpc - def send_query(self, query: str): + def send_query(self, query: str) -> None: msg = AgentMessage() msg.add_text(query) self.message_out.publish(msg) @@ -91,17 +92,17 @@ class ResponseCollector(Module): response_in: In[AgentResponse] = None - def __init__(self): + def __init__(self) -> None: super().__init__() self.responses = [] @rpc - def start(self): + def start(self) -> None: logger.info("ResponseCollector starting subscription") self.response_in.subscribe(self._on_response) logger.info("ResponseCollector subscription active") - def _on_response(self, response): + def _on_response(self, response) -> None: logger.info(f"ResponseCollector received response #{len(self.responses) + 1}: {response}") self.responses.append(response) @@ -113,7 +114,7 @@ def get_responses(self): @pytest.mark.tofix @pytest.mark.module @pytest.mark.asyncio -async def test_agent_module_with_tools(): +async def test_agent_module_with_tools() -> None: """Test BaseAgentModule with tool execution.""" load_dotenv() @@ -188,9 +189,9 @@ async def test_agent_module_with_tools(): # Verify weather details assert isinstance(response, AgentResponse), "Expected AgentResponse object" - assert "new york" in response.content.lower(), f"Expected 'New York' in response" - assert "72" in response.content, f"Expected temperature '72' in response" - assert "sunny" in response.content.lower(), f"Expected 'sunny' in response" + assert "new york" in response.content.lower(), "Expected 'New York' in response" + assert "72" in response.content, "Expected temperature '72' in response" + assert "sunny" in response.content.lower(), "Expected 'sunny' in response" # Test 3: Navigation (potentially long-running) logger.info("\n=== Test 3: Navigation Tool ===") @@ -240,7 +241,7 @@ async def test_agent_module_with_tools(): @pytest.mark.tofix -def test_base_agent_direct_tools(): +def test_base_agent_direct_tools() -> None: """Test BaseAgent direct usage with tools.""" load_dotenv() @@ -295,9 +296,9 @@ def test_base_agent_direct_tools(): logger.info(f"Tool calls: {response2.tool_calls}") assert response2.content is not None - assert "london" in response2.content.lower(), f"Expected 'London' in response" - assert "72" in response2.content, f"Expected temperature '72' in response" - assert "sunny" in response2.content.lower(), f"Expected 'sunny' in response" + assert "london" in response2.content.lower(), "Expected 'London' in response" + assert "72" in response2.content, "Expected temperature '72' in response" + assert "sunny" in response2.content.lower(), "Expected 'sunny' in response" # Verify tool was called if response2.tool_calls is not None: @@ -316,7 +317,7 @@ def test_base_agent_direct_tools(): class MockToolAgent(BaseAgent): """Mock agent for CI testing without API calls.""" - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: # Skip gateway initialization self.model = kwargs.get("model", "mock::test") self.system_prompt = kwargs.get("system_prompt", "Mock agent") @@ -330,8 +331,8 @@ def __init__(self, **kwargs): async def _process_query_async(self, agent_msg, base64_image=None, base64_images=None): """Mock tool execution.""" - from dimos.agents.agent_types import AgentResponse, ToolCall from dimos.agents.agent_message import AgentMessage + from dimos.agents.agent_types import AgentResponse, ToolCall # Get text from AgentMessage if isinstance(agent_msg, AgentMessage): @@ -362,12 +363,12 @@ async def _process_query_async(self, agent_msg, base64_image=None, base64_images # Default response return AgentResponse(content=f"Mock response to: {query}") - def dispose(self): + def dispose(self) -> None: pass @pytest.mark.tofix -def test_mock_agent_tools(): +def test_mock_agent_tools() -> None: """Test mock agent with tools for CI.""" # Create skill library skill_library = SkillLibrary() @@ -384,7 +385,7 @@ def test_mock_agent_tools(): logger.info(f"Mock tool calls: {response.tool_calls}") assert response.content is not None - assert "42" in response.content, f"Expected '42' in response" + assert "42" in response.content, "Expected '42' in response" assert response.tool_calls is not None, "Expected tool calls" assert len(response.tool_calls) == 1, "Expected exactly one tool call" assert response.tool_calls[0].name == "CalculateSkill", "Expected CalculateSkill" diff --git a/dimos/agents/test_agent_with_modules.py b/dimos/agents/test_agent_with_modules.py index 5eefd92efe..1a4ac70f65 100644 --- a/dimos/agents/test_agent_with_modules.py +++ b/dimos/agents/test_agent_with_modules.py @@ -15,17 +15,15 @@ """Test agent module with proper module connections.""" import asyncio -import os -import pytest -import threading -import time + from dotenv import load_dotenv +import pytest from dimos import core -from dimos.core import Module, Out, In, rpc -from dimos.agents.modules.base_agent import BaseAgentModule from dimos.agents.agent_message import AgentMessage from dimos.agents.agent_types import AgentResponse +from dimos.agents.modules.base_agent import BaseAgentModule +from dimos.core import In, Module, Out, rpc from dimos.protocol import pubsub @@ -35,11 +33,11 @@ class QuerySender(Module): message_out: Out[AgentMessage] = None - def __init__(self): + def __init__(self) -> None: super().__init__() @rpc - def send_query(self, query: str): + def send_query(self, query: str) -> None: """Send a query.""" print(f"Sending query: {query}") msg = AgentMessage() @@ -53,16 +51,16 @@ class ResponseCollector(Module): response_in: In[AgentResponse] = None - def __init__(self): + def __init__(self) -> None: super().__init__() self.responses = [] @rpc - def start(self): + def start(self) -> None: """Start collecting.""" self.response_in.subscribe(self._on_response) - def _on_response(self, msg: AgentResponse): + def _on_response(self, msg: AgentResponse) -> None: print(f"Received response: {msg.content if msg.content else msg}") self.responses.append(msg) @@ -75,7 +73,7 @@ def get_responses(self): @pytest.mark.tofix @pytest.mark.module @pytest.mark.asyncio -async def test_agent_module_connections(): +async def test_agent_module_connections() -> None: """Test agent module with proper connections.""" load_dotenv() pubsub.lcm.autoconf() diff --git a/dimos/agents/test_base_agent_text.py b/dimos/agents/test_base_agent_text.py index af0dd6ae4b..022bea9cd2 100644 --- a/dimos/agents/test_base_agent_text.py +++ b/dimos/agents/test_base_agent_text.py @@ -14,17 +14,18 @@ """Test BaseAgent text functionality.""" -import pytest import asyncio import os + from dotenv import load_dotenv +import pytest -from dimos.agents.modules.base import BaseAgent -from dimos.agents.modules.base_agent import BaseAgentModule +from dimos import core from dimos.agents.agent_message import AgentMessage from dimos.agents.agent_types import AgentResponse -from dimos import core -from dimos.core import Module, Out, In, rpc +from dimos.agents.modules.base import BaseAgent +from dimos.agents.modules.base_agent import BaseAgentModule +from dimos.core import In, Module, Out, rpc from dimos.protocol import pubsub @@ -34,14 +35,14 @@ class QuerySender(Module): message_out: Out[AgentMessage] = None # New AgentMessage output @rpc - def send_query(self, query: str): + def send_query(self, query: str) -> None: """Send a query as AgentMessage.""" msg = AgentMessage() msg.add_text(query) self.message_out.publish(msg) @rpc - def send_message(self, message: AgentMessage): + def send_message(self, message: AgentMessage) -> None: """Send an AgentMessage.""" self.message_out.publish(message) @@ -51,16 +52,16 @@ class ResponseCollector(Module): response_in: In[AgentResponse] = None - def __init__(self): + def __init__(self) -> None: super().__init__() self.responses = [] @rpc - def start(self): + def start(self) -> None: """Start collecting.""" self.response_in.subscribe(self._on_response) - def _on_response(self, msg): + def _on_response(self, msg) -> None: self.responses.append(msg) @rpc @@ -70,7 +71,7 @@ def get_responses(self): @pytest.mark.tofix -def test_base_agent_direct_text(): +def test_base_agent_direct_text() -> None: """Test BaseAgent direct text usage.""" load_dotenv() @@ -100,7 +101,7 @@ def test_base_agent_direct_text(): print(f"[Test] Query: 'What is 3+3?' -> Response: '{response.content}'") assert response.content is not None assert "6" in response.content or "six" in response.content.lower(), ( - f"Expected '6' or 'six' in response" + "Expected '6' or 'six' in response" ) # Test conversation history @@ -119,7 +120,7 @@ def test_base_agent_direct_text(): @pytest.mark.tofix @pytest.mark.asyncio -async def test_base_agent_async_text(): +async def test_base_agent_async_text() -> None: """Test BaseAgent async text usage.""" load_dotenv() @@ -137,14 +138,14 @@ async def test_base_agent_async_text(): # Test async query with string response = await agent.aquery("What is the capital of France?") assert response.content is not None - assert "Paris" in response.content, f"Expected 'Paris' in response" + assert "Paris" in response.content, "Expected 'Paris' in response" # Test async query with AgentMessage msg = AgentMessage() msg.add_text("What is the capital of Germany?") response = await agent.aquery(msg) assert response.content is not None - assert "Berlin" in response.content, f"Expected 'Berlin' in response" + assert "Berlin" in response.content, "Expected 'Berlin' in response" # Clean up agent.dispose() @@ -153,7 +154,7 @@ async def test_base_agent_async_text(): @pytest.mark.tofix @pytest.mark.module @pytest.mark.asyncio -async def test_base_agent_module_text(): +async def test_base_agent_module_text() -> None: """Test BaseAgentModule with text via DimOS.""" load_dotenv() @@ -208,7 +209,7 @@ async def test_base_agent_module_text(): assert len(responses) >= 2, "Should have at least two responses" resp = responses[1] assert isinstance(resp, AgentResponse), "Expected AgentResponse object" - assert "blue" in resp.content.lower(), f"Expected 'blue' in response" + assert "blue" in resp.content.lower(), "Expected 'blue' in response" # Test conversation history sender.send_query("What was my first question?") @@ -218,7 +219,7 @@ async def test_base_agent_module_text(): assert len(responses) >= 3, "Should have at least three responses" resp = responses[2] assert isinstance(resp, AgentResponse), "Expected AgentResponse object" - assert "2+2" in resp.content or "2" in resp.content, f"Expected reference to first question" + assert "2+2" in resp.content or "2" in resp.content, "Expected reference to first question" # Stop modules agent.stop() @@ -237,7 +238,7 @@ async def test_base_agent_module_text(): ], ) @pytest.mark.tofix -def test_base_agent_providers(model, provider): +def test_base_agent_providers(model, provider) -> None: """Test BaseAgent with different providers.""" load_dotenv() @@ -271,7 +272,7 @@ def test_base_agent_providers(model, provider): @pytest.mark.tofix -def test_base_agent_memory(): +def test_base_agent_memory() -> None: """Test BaseAgent with memory/RAG.""" load_dotenv() @@ -299,7 +300,7 @@ def test_base_agent_memory(): response = agent.query(msg) assert response.content is not None assert "framework" in response.content.lower() or "robotic" in response.content.lower(), ( - f"Expected context about DimOS in response" + "Expected context about DimOS in response" ) # Clean up @@ -309,7 +310,7 @@ def test_base_agent_memory(): class MockAgent(BaseAgent): """Mock agent for testing without API calls.""" - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: # Don't call super().__init__ to avoid gateway initialization from dimos.agents.agent_types import ConversationHistory @@ -319,7 +320,7 @@ def __init__(self, **kwargs): self._supports_vision = False self.response_subject = None # Simplified - async def _process_query_async(self, query: str, base64_image=None): + async def _process_query_async(self, query: str, base64_image=None) -> str: """Mock response.""" if "2+2" in query: return "The answer is 4" @@ -365,13 +366,13 @@ async def aquery(self, message) -> AgentResponse: self.conversation.add_assistant_message(response) return AgentResponse(content=response) - def dispose(self): + def dispose(self) -> None: """Mock dispose.""" pass @pytest.mark.tofix -def test_mock_agent(): +def test_mock_agent() -> None: """Test mock agent for CI without API keys.""" # Create mock agent agent = MockAgent(model="mock::test", system_prompt="Mock assistant") @@ -400,7 +401,7 @@ def test_mock_agent(): @pytest.mark.tofix -def test_base_agent_conversation_history(): +def test_base_agent_conversation_history() -> None: """Test that conversation history is properly maintained.""" load_dotenv() @@ -428,7 +429,7 @@ def test_base_agent_conversation_history(): # Test 2: Reference previous context response2 = agent.query("What is my name?") - assert "Alice" in response2.content, f"Agent should remember the name" + assert "Alice" in response2.content, "Agent should remember the name" # Conversation history should now have 4 messages assert agent.conversation.size() == 4 @@ -450,7 +451,7 @@ def test_base_agent_conversation_history(): # Test 4: History trimming (set low limit) agent.max_history = 4 - response4 = agent.query("What was my first message?") + agent.query("What was my first message?") # Conversation history should be trimmed to 4 messages assert agent.conversation.size() == 4 @@ -463,16 +464,17 @@ def test_base_agent_conversation_history(): @pytest.mark.tofix -def test_base_agent_history_with_tools(): +def test_base_agent_history_with_tools() -> None: """Test conversation history with tool calls.""" load_dotenv() if not os.getenv("OPENAI_API_KEY"): pytest.skip("No OPENAI_API_KEY found") - from dimos.skills.skills import AbstractSkill, SkillLibrary from pydantic import Field + from dimos.skills.skills import AbstractSkill, SkillLibrary + class CalculatorSkill(AbstractSkill): """Perform calculations.""" diff --git a/dimos/agents/test_conversation_history.py b/dimos/agents/test_conversation_history.py index b80892f304..95b28fbc0b 100644 --- a/dimos/agents/test_conversation_history.py +++ b/dimos/agents/test_conversation_history.py @@ -15,25 +15,26 @@ """Comprehensive conversation history tests for agents.""" -import os import asyncio -import pytest -import numpy as np +import logging +import os + from dotenv import load_dotenv +import numpy as np +from pydantic import Field +import pytest -from dimos.agents.modules.base import BaseAgent from dimos.agents.agent_message import AgentMessage -from dimos.agents.agent_types import AgentResponse, ConversationHistory +from dimos.agents.agent_types import AgentResponse +from dimos.agents.modules.base import BaseAgent from dimos.msgs.sensor_msgs import Image from dimos.skills.skills import AbstractSkill, SkillLibrary -from pydantic import Field -import logging logger = logging.getLogger(__name__) @pytest.mark.tofix -def test_conversation_history_basic(): +def test_conversation_history_basic() -> None: """Test basic conversation history functionality.""" load_dotenv() @@ -90,7 +91,7 @@ def test_conversation_history_basic(): @pytest.mark.tofix -def test_conversation_history_with_images(): +def test_conversation_history_with_images() -> None: """Test conversation history with multimodal content.""" load_dotenv() @@ -106,7 +107,7 @@ def test_conversation_history_with_images(): try: # Send text message - response1 = agent.query("I'm going to show you some colors") + agent.query("I'm going to show you some colors") assert agent.conversation.size() == 2 # Send image with text @@ -115,7 +116,7 @@ def test_conversation_history_with_images(): red_img = Image(data=np.full((100, 100, 3), [255, 0, 0], dtype=np.uint8)) msg.add_image(red_img) - response2 = agent.query(msg) + agent.query(msg) assert agent.conversation.size() == 4 # Ask about the image @@ -132,7 +133,7 @@ def test_conversation_history_with_images(): blue_img = Image(data=np.full((100, 100, 3), [0, 0, 255], dtype=np.uint8)) msg2.add_image(blue_img) - response4 = agent.query(msg2) + agent.query(msg2) assert agent.conversation.size() == 8 # Ask about all images @@ -151,7 +152,7 @@ def test_conversation_history_with_images(): @pytest.mark.tofix -def test_conversation_history_trimming(): +def test_conversation_history_trimming() -> None: """Test that conversation history is trimmed to max size.""" load_dotenv() @@ -194,7 +195,7 @@ def test_conversation_history_trimming(): assert size == 3, f"After Message 5, size should still be 3, got {size}" # Early messages should be trimmed - response = agent.query("What was the first fruit I mentioned?") + agent.query("What was the first fruit I mentioned?") size = agent.conversation.size() assert size == 3, f"After question, size should still be 3, got {size}" @@ -210,7 +211,7 @@ def test_conversation_history_trimming(): @pytest.mark.tofix -def test_conversation_history_with_tools(): +def test_conversation_history_with_tools() -> None: """Test conversation history with tool calls.""" load_dotenv() @@ -244,7 +245,7 @@ class TestSkillLibrary(SkillLibrary): try: # Initial query - response1 = agent.query("Hello, I need help with math") + agent.query("Hello, I need help with math") assert agent.conversation.size() == 2 # Force tool use explicitly @@ -269,7 +270,7 @@ class TestSkillLibrary(SkillLibrary): @pytest.mark.tofix -def test_conversation_thread_safety(): +def test_conversation_thread_safety() -> None: """Test that conversation history is thread-safe.""" load_dotenv() @@ -280,7 +281,7 @@ def test_conversation_thread_safety(): try: - async def query_async(text): + async def query_async(text: str): """Async wrapper for query.""" return await agent.aquery(text) @@ -303,7 +304,7 @@ async def run_concurrent(): @pytest.mark.tofix -def test_conversation_history_formats(): +def test_conversation_history_formats() -> None: """Test ConversationHistory formatting methods.""" load_dotenv() @@ -363,7 +364,7 @@ def test_conversation_history_formats(): @pytest.mark.tofix @pytest.mark.timeout(30) # Add timeout to prevent hanging -def test_conversation_edge_cases(): +def test_conversation_edge_cases() -> None: """Test edge cases in conversation history.""" load_dotenv() diff --git a/dimos/agents/test_gateway.py b/dimos/agents/test_gateway.py index d962ec46ad..2c54d5d1ac 100644 --- a/dimos/agents/test_gateway.py +++ b/dimos/agents/test_gateway.py @@ -17,15 +17,15 @@ import asyncio import os -import pytest from dotenv import load_dotenv +import pytest from dimos.agents.modules.gateway import UnifiedGatewayClient @pytest.mark.tofix @pytest.mark.asyncio -async def test_gateway_basic(): +async def test_gateway_basic() -> None: """Test basic gateway functionality.""" load_dotenv() @@ -72,7 +72,7 @@ async def test_gateway_basic(): @pytest.mark.tofix @pytest.mark.asyncio -async def test_gateway_streaming(): +async def test_gateway_streaming() -> None: """Test gateway streaming functionality.""" load_dotenv() @@ -99,7 +99,7 @@ async def test_gateway_streaming(): # Reconstruct content content = "" for chunk in chunks: - if "choices" in chunk and chunk["choices"]: + if chunk.get("choices"): delta = chunk["choices"][0].get("delta", {}) chunk_content = delta.get("content") if chunk_content is not None: @@ -113,7 +113,7 @@ async def test_gateway_streaming(): @pytest.mark.tofix @pytest.mark.asyncio -async def test_gateway_tools(): +async def test_gateway_tools() -> None: """Test gateway can pass tool definitions to LLM and get responses.""" load_dotenv() @@ -158,7 +158,7 @@ async def test_gateway_tools(): @pytest.mark.tofix @pytest.mark.asyncio -async def test_gateway_providers(): +async def test_gateway_providers() -> None: """Test gateway with different providers.""" load_dotenv() diff --git a/dimos/agents/test_simple_agent_module.py b/dimos/agents/test_simple_agent_module.py index 2da67540d6..bd374877dd 100644 --- a/dimos/agents/test_simple_agent_module.py +++ b/dimos/agents/test_simple_agent_module.py @@ -16,14 +16,15 @@ import asyncio import os -import pytest + from dotenv import load_dotenv +import pytest from dimos import core -from dimos.core import Module, Out, In, rpc -from dimos.agents.modules.base_agent import BaseAgentModule from dimos.agents.agent_message import AgentMessage from dimos.agents.agent_types import AgentResponse +from dimos.agents.modules.base_agent import BaseAgentModule +from dimos.core import In, Module, Out, rpc from dimos.protocol import pubsub @@ -33,7 +34,7 @@ class QuerySender(Module): message_out: Out[AgentMessage] = None @rpc - def send_query(self, query: str): + def send_query(self, query: str) -> None: """Send a query.""" msg = AgentMessage() msg.add_text(query) @@ -45,16 +46,16 @@ class ResponseCollector(Module): response_in: In[AgentResponse] = None - def __init__(self): + def __init__(self) -> None: super().__init__() self.responses = [] @rpc - def start(self): + def start(self) -> None: """Start collecting.""" self.response_in.subscribe(self._on_response) - def _on_response(self, response: AgentResponse): + def _on_response(self, response: AgentResponse) -> None: """Handle response.""" self.responses.append(response) @@ -64,7 +65,7 @@ def get_responses(self) -> list: return self.responses @rpc - def clear(self): + def clear(self) -> None: """Clear responses.""" self.responses = [] @@ -81,19 +82,19 @@ def clear(self): ("qwen::qwen-turbo", "Qwen"), ], ) -async def test_simple_agent_module(model, provider): +async def test_simple_agent_module(model, provider) -> None: """Test simple agent module with different providers.""" load_dotenv() # Skip if no API key if provider == "OpenAI" and not os.getenv("OPENAI_API_KEY"): - pytest.skip(f"No OpenAI API key found") + pytest.skip("No OpenAI API key found") elif provider == "Claude" and not os.getenv("ANTHROPIC_API_KEY"): - pytest.skip(f"No Anthropic API key found") + pytest.skip("No Anthropic API key found") elif provider == "Cerebras" and not os.getenv("CEREBRAS_API_KEY"): - pytest.skip(f"No Cerebras API key found") + pytest.skip("No Cerebras API key found") elif provider == "Qwen" and not os.getenv("ALIBABA_API_KEY"): - pytest.skip(f"No Qwen API key found") + pytest.skip("No Qwen API key found") pubsub.lcm.autoconf() @@ -154,7 +155,7 @@ async def test_simple_agent_module(model, provider): @pytest.mark.tofix @pytest.mark.module @pytest.mark.asyncio -async def test_mock_agent_module(): +async def test_mock_agent_module() -> None: """Test agent module with mock responses (no API needed).""" pubsub.lcm.autoconf() @@ -165,10 +166,10 @@ class MockAgentModule(Module): response_out: Out[AgentResponse] = None @rpc - def start(self): + def start(self) -> None: self.message_in.subscribe(self._handle_message) - def _handle_message(self, msg: AgentMessage): + def _handle_message(self, msg: AgentMessage) -> None: query = msg.get_combined_text() if "2+2" in query: self.response_out.publish(AgentResponse(content="4")) diff --git a/dimos/agents/tokenizer/base.py b/dimos/agents/tokenizer/base.py index b7e96de71f..7957c896fa 100644 --- a/dimos/agents/tokenizer/base.py +++ b/dimos/agents/tokenizer/base.py @@ -21,7 +21,7 @@ class AbstractTokenizer(ABC): @abstractmethod - def tokenize_text(self, text): + def tokenize_text(self, text: str): pass @abstractmethod @@ -29,9 +29,9 @@ def detokenize_text(self, tokenized_text): pass @abstractmethod - def token_count(self, text): + def token_count(self, text: str): pass @abstractmethod - def image_token_count(self, image_width, image_height, image_detail="low"): + def image_token_count(self, image_width, image_height, image_detail: str = "low"): pass diff --git a/dimos/agents/tokenizer/huggingface_tokenizer.py b/dimos/agents/tokenizer/huggingface_tokenizer.py index 2a7b0d2283..34ace64fb0 100644 --- a/dimos/agents/tokenizer/huggingface_tokenizer.py +++ b/dimos/agents/tokenizer/huggingface_tokenizer.py @@ -13,12 +13,13 @@ # limitations under the License. from transformers import AutoTokenizer + from dimos.agents.tokenizer.base import AbstractTokenizer from dimos.utils.logging_config import setup_logger class HuggingFaceTokenizer(AbstractTokenizer): - def __init__(self, model_name: str = "Qwen/Qwen2.5-0.5B", **kwargs): + def __init__(self, model_name: str = "Qwen/Qwen2.5-0.5B", **kwargs) -> None: super().__init__(**kwargs) # Initilize the tokenizer for the huggingface models @@ -27,10 +28,10 @@ def __init__(self, model_name: str = "Qwen/Qwen2.5-0.5B", **kwargs): self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) except Exception as e: raise ValueError( - f"Failed to initialize tokenizer for model {self.model_name}. Error: {str(e)}" + f"Failed to initialize tokenizer for model {self.model_name}. Error: {e!s}" ) - def tokenize_text(self, text): + def tokenize_text(self, text: str): """ Tokenize a text string using the openai tokenizer. """ @@ -43,16 +44,16 @@ def detokenize_text(self, tokenized_text): try: return self.tokenizer.decode(tokenized_text, errors="ignore") except Exception as e: - raise ValueError(f"Failed to detokenize text. Error: {str(e)}") + raise ValueError(f"Failed to detokenize text. Error: {e!s}") - def token_count(self, text): + def token_count(self, text: str): """ Gets the token count of a text string using the openai tokenizer. """ return len(self.tokenize_text(text)) if text else 0 @staticmethod - def image_token_count(image_width, image_height, image_detail="high"): + def image_token_count(image_width, image_height, image_detail: str = "high"): """ Calculate the number of tokens in an image. Low detail is 85 tokens, high detail is 170 tokens per 512x512 square. """ diff --git a/dimos/agents/tokenizer/openai_tokenizer.py b/dimos/agents/tokenizer/openai_tokenizer.py index 7517ae5e72..7fe5017241 100644 --- a/dimos/agents/tokenizer/openai_tokenizer.py +++ b/dimos/agents/tokenizer/openai_tokenizer.py @@ -13,12 +13,13 @@ # limitations under the License. import tiktoken + from dimos.agents.tokenizer.base import AbstractTokenizer from dimos.utils.logging_config import setup_logger class OpenAITokenizer(AbstractTokenizer): - def __init__(self, model_name: str = "gpt-4o", **kwargs): + def __init__(self, model_name: str = "gpt-4o", **kwargs) -> None: super().__init__(**kwargs) # Initilize the tokenizer for the openai set of models @@ -27,10 +28,10 @@ def __init__(self, model_name: str = "gpt-4o", **kwargs): self.tokenizer = tiktoken.encoding_for_model(self.model_name) except Exception as e: raise ValueError( - f"Failed to initialize tokenizer for model {self.model_name}. Error: {str(e)}" + f"Failed to initialize tokenizer for model {self.model_name}. Error: {e!s}" ) - def tokenize_text(self, text): + def tokenize_text(self, text: str): """ Tokenize a text string using the openai tokenizer. """ @@ -43,16 +44,16 @@ def detokenize_text(self, tokenized_text): try: return self.tokenizer.decode(tokenized_text, errors="ignore") except Exception as e: - raise ValueError(f"Failed to detokenize text. Error: {str(e)}") + raise ValueError(f"Failed to detokenize text. Error: {e!s}") - def token_count(self, text): + def token_count(self, text: str): """ Gets the token count of a text string using the openai tokenizer. """ return len(self.tokenize_text(text)) if text else 0 @staticmethod - def image_token_count(image_width, image_height, image_detail="high"): + def image_token_count(image_width, image_height, image_detail: str = "high"): """ Calculate the number of tokens in an image. Low detail is 85 tokens, high detail is 170 tokens per 512x512 square. """ diff --git a/dimos/agents2/agent.py b/dimos/agents2/agent.py index d03b848dd2..3869983d70 100644 --- a/dimos/agents2/agent.py +++ b/dimos/agents2/agent.py @@ -14,10 +14,10 @@ import asyncio import datetime import json +from operator import itemgetter import os +from typing import Any, TypedDict import uuid -from operator import itemgetter -from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union from langchain.chat_models import init_chat_model from langchain_core.messages import ( @@ -96,8 +96,8 @@ def _custom_json_serializers(obj): # and builds messages to be sent to an agent def snapshot_to_messages( state: SkillStateDict, - tool_calls: List[ToolCall], -) -> Tuple[List[ToolMessage], Optional[AIMessage]]: + tool_calls: list[ToolCall], +) -> tuple[list[ToolMessage], AIMessage | None]: # builds a set of tool call ids from a previous agent request tool_call_ids = set( map(itemgetter("id"), tool_calls), @@ -107,15 +107,15 @@ def snapshot_to_messages( tool_msgs: list[ToolMessage] = [] # build a general skill state overview (for longer running skills) - state_overview: list[Dict[str, SkillStateSummary]] = [] + state_overview: list[dict[str, SkillStateSummary]] = [] # for special skills that want to return a separate message # (images for example, requires to be a HumanMessage) - special_msgs: List[HumanMessage] = [] + special_msgs: list[HumanMessage] = [] # for special skills that want to return a separate message that should # stay in history, like actual human messages, critical events - history_msgs: List[HumanMessage] = [] + history_msgs: list[HumanMessage] = [] # Initialize state_msg state_msg = None @@ -162,13 +162,13 @@ def snapshot_to_messages( # Agent class job is to glue skill coordinator state to an agent, builds langchain messages class Agent(AgentSpec): system_message: SystemMessage - state_messages: List[Union[AIMessage, HumanMessage]] + state_messages: list[AIMessage | HumanMessage] def __init__( self, *args, **kwargs, - ): + ) -> None: AgentSpec.__init__(self, *args, **kwargs) self.state_messages = [] @@ -201,30 +201,30 @@ def get_agent_id(self) -> str: return self._agent_id @rpc - def start(self): + def start(self) -> None: super().start() self.coordinator.start() @rpc - def stop(self): + def stop(self) -> None: self.coordinator.stop() self._agent_stopped = True super().stop() - def clear_history(self): + def clear_history(self) -> None: self._history.clear() - def append_history(self, *msgs: List[Union[AIMessage, HumanMessage]]): + def append_history(self, *msgs: list[AIMessage | HumanMessage]) -> None: for msg in msgs: self.publish(msg) self._history.extend(msgs) def history(self): - return [self.system_message] + self._history + self.state_messages + return [self.system_message, *self._history, *self.state_messages] # Used by agent to execute tool calls - def execute_tool_calls(self, tool_calls: List[ToolCall]) -> None: + def execute_tool_calls(self, tool_calls: list[ToolCall]) -> None: """Execute a list of tool calls from the agent.""" if self._agent_stopped: logger.warning("Agent is stopped, cannot execute tool calls.") @@ -325,7 +325,7 @@ def _get_state() -> str: traceback.print_exc() @rpc - def loop_thread(self): + def loop_thread(self) -> bool: asyncio.run_coroutine_threadsafe(self.agent_loop(), self._loop) return True @@ -351,7 +351,7 @@ def register_skills(self, container, run_implicit_name: str | None = None): def get_tools(self): return self.coordinator.get_tools() - def _write_debug_history_file(self): + def _write_debug_history_file(self) -> None: file_path = os.getenv("DEBUG_AGENT_HISTORY_FILE") if not file_path: return @@ -378,13 +378,15 @@ def stop(self) -> None: def deploy( dimos: DimosCluster, - system_prompt="You are a helpful assistant for controlling a Unitree Go2 robot.", + system_prompt: str = "You are a helpful assistant for controlling a Unitree Go2 robot.", model: Model = Model.GPT_4O, provider: Provider = Provider.OPENAI, - skill_containers: Optional[List[SkillContainer]] = [], + skill_containers: list[SkillContainer] | None = None, ) -> Agent: from dimos.agents2.cli.human import HumanInput + if skill_containers is None: + skill_containers = [] agent = dimos.deploy( Agent, system_prompt=system_prompt, @@ -408,4 +410,4 @@ def deploy( return agent -__all__ = ["Agent", "llm_agent", "deploy"] +__all__ = ["Agent", "deploy", "llm_agent"] diff --git a/dimos/agents2/cli/human.py b/dimos/agents2/cli/human.py index 8256520db3..15727d87b8 100644 --- a/dimos/agents2/cli/human.py +++ b/dimos/agents2/cli/human.py @@ -36,8 +36,7 @@ def human(self): msg_queue = queue.Queue() unsub = transport.subscribe(msg_queue.put) self._disposables.add(Disposable(unsub)) - for message in iter(msg_queue.get, None): - yield message + yield from iter(msg_queue.get, None) @rpc def start(self) -> None: diff --git a/dimos/agents2/conftest.py b/dimos/agents2/conftest.py index de805afdcf..769523f8c5 100644 --- a/dimos/agents2/conftest.py +++ b/dimos/agents2/conftest.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest from pathlib import Path +import pytest + from dimos.agents2.agent import Agent from dimos.agents2.testing import MockModel from dimos.protocol.skill.test_coordinator import SkillContainerTest @@ -26,7 +27,7 @@ def fixture_dir(): @pytest.fixture -def potato_system_prompt(): +def potato_system_prompt() -> str: return "Your name is Mr. Potato, potatoes are bad at math. Use a tools if asked to calculate" diff --git a/dimos/agents2/constants.py b/dimos/agents2/constants.py index 1608a635f8..0d7d4832a0 100644 --- a/dimos/agents2/constants.py +++ b/dimos/agents2/constants.py @@ -14,5 +14,4 @@ from dimos.constants import DIMOS_PROJECT_ROOT - AGENT_SYSTEM_PROMPT_PATH = DIMOS_PROJECT_ROOT / "assets/agent/prompt_agents2.txt" diff --git a/dimos/agents2/skills/google_maps_skill_container.py b/dimos/agents2/skills/google_maps_skill_container.py index ddf64cbef0..433914a5e3 100644 --- a/dimos/agents2/skills/google_maps_skill_container.py +++ b/dimos/agents2/skills/google_maps_skill_container.py @@ -13,8 +13,10 @@ # limitations under the License. import json -from typing import Any, Optional, Union +from typing import Any + from reactivex import Observable +from reactivex.disposable import CompositeDisposable from dimos.core.resource import Resource from dimos.mapping.google_maps.google_maps import GoogleMaps @@ -24,20 +26,18 @@ from dimos.robot.robot import Robot from dimos.utils.logging_config import setup_logger -from reactivex.disposable import CompositeDisposable - logger = setup_logger(__file__) class GoogleMapsSkillContainer(SkillContainer, Resource): _robot: Robot _disposables: CompositeDisposable - _latest_location: Optional[LatLon] + _latest_location: LatLon | None _position_stream: Observable[LatLon] _current_location_map: CurrentLocationMap _started: bool - def __init__(self, robot: Robot, position_stream: Observable[LatLon]): + def __init__(self, robot: Robot, position_stream: Observable[LatLon]) -> None: super().__init__() self._robot = robot self._disposables = CompositeDisposable() @@ -110,7 +110,7 @@ def get_gps_position_for_queries(self, *queries: str) -> str: location = self._get_latest_location() - results: list[Union[dict[str, Any], str]] = [] + results: list[dict[str, Any] | str] = [] for query in queries: try: diff --git a/dimos/agents2/skills/gps_nav_skill.py b/dimos/agents2/skills/gps_nav_skill.py index dedda933ca..80e346790a 100644 --- a/dimos/agents2/skills/gps_nav_skill.py +++ b/dimos/agents2/skills/gps_nav_skill.py @@ -13,8 +13,9 @@ # limitations under the License. import json -from typing import Optional + from reactivex import Observable +from reactivex.disposable import CompositeDisposable from dimos.core.resource import Resource from dimos.mapping.google_maps.google_maps import GoogleMaps @@ -25,22 +26,19 @@ from dimos.robot.robot import Robot from dimos.utils.logging_config import setup_logger -from reactivex.disposable import CompositeDisposable - - logger = setup_logger(__file__) class GpsNavSkillContainer(SkillContainer, Resource): _robot: Robot _disposables: CompositeDisposable - _latest_location: Optional[LatLon] + _latest_location: LatLon | None _position_stream: Observable[LatLon] _current_location_map: CurrentLocationMap _started: bool _max_valid_distance: int - def __init__(self, robot: Robot, position_stream: Observable[LatLon]): + def __init__(self, robot: Robot, position_stream: Observable[LatLon]) -> None: super().__init__() self._robot = robot self._disposables = CompositeDisposable() @@ -92,7 +90,7 @@ def set_gps_travel_points(self, *points: dict[str, float]) -> str: return "I've successfully set the travel points." - def _convert_point(self, point: dict[str, float]) -> Optional[LatLon]: + def _convert_point(self, point: dict[str, float]) -> LatLon | None: if not isinstance(point, dict): return None lat = point.get("lat") diff --git a/dimos/agents2/skills/navigation.py b/dimos/agents2/skills/navigation.py index 09c6c074ba..9a7b91d68a 100644 --- a/dimos/agents2/skills/navigation.py +++ b/dimos/agents2/skills/navigation.py @@ -13,8 +13,7 @@ # limitations under the License. import time -from functools import partial -from typing import Any, Optional +from typing import Any from dimos.core.core import rpc from dimos.core.rpc_client import RpcCall @@ -35,8 +34,8 @@ class NavigationSkillContainer(SkillModule): - _latest_image: Optional[Image] = None - _latest_odom: Optional[PoseStamped] = None + _latest_image: Image | None = None + _latest_odom: PoseStamped | None = None _skill_started: bool = False _similarity_threshold: float = 0.23 @@ -57,7 +56,7 @@ class NavigationSkillContainer(SkillModule): color_image: In[Image] = None odom: In[PoseStamped] = None - def __init__(self): + def __init__(self) -> None: super().__init__() self._skill_started = False self._vl_model = QwenVlModel() @@ -178,7 +177,7 @@ def tag_location(self, location_name: str) -> str: logger.info(f"Tagged {location}") return f"Tagged '{location_name}': ({position.x},{position.y})." - def _navigate_to_object(self, query: str) -> Optional[str]: + def _navigate_to_object(self, query: str) -> str | None: position = self.detection_module.nav_vlm(query) print("Object position from VLM:", position) if not position: @@ -219,7 +218,7 @@ def navigate_with_text(self, query: str) -> str: return f"No tagged location called '{query}'. No object in view matching '{query}'. No matching location found in semantic map for '{query}'." - def _navigate_by_tagged_location(self, query: str) -> Optional[str]: + def _navigate_by_tagged_location(self, query: str) -> str | None: if not self._query_tagged_location: logger.warning("SpatialMemory module not connected, cannot query tagged locations") return None @@ -266,7 +265,7 @@ def _navigate_to(self, pose: PoseStamped) -> bool: logger.info("Navigation goal reached") return True - def _navigate_to_object(self, query: str) -> Optional[str]: + def _navigate_to_object(self, query: str) -> str | None: try: bbox = self._get_bbox_for_current_frame(query) except Exception: @@ -322,7 +321,7 @@ def _navigate_to_object(self, query: str) -> Optional[str]: self._stop_track() return None - def _get_bbox_for_current_frame(self, query: str) -> Optional[BBox]: + def _get_bbox_for_current_frame(self, query: str) -> BBox | None: if self._latest_image is None: return None @@ -418,7 +417,7 @@ def _start_exploration(self, timeout: float) -> str: return "Exploration completed successfuly" - def _get_goal_pose_from_result(self, result: dict[str, Any]) -> Optional[PoseStamped]: + def _get_goal_pose_from_result(self, result: dict[str, Any]) -> PoseStamped | None: similarity = 1.0 - (result.get("distance") or 1) if similarity < self._similarity_threshold: logger.warning( diff --git a/dimos/agents2/skills/osm.py b/dimos/agents2/skills/osm.py index eaaef41858..ae721bea81 100644 --- a/dimos/agents2/skills/osm.py +++ b/dimos/agents2/skills/osm.py @@ -12,43 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional -from dimos.core.core import rpc -from dimos.core.module import Module -from dimos.core.rpc_client import RPCClient, RpcCall from dimos.core.skill_module import SkillModule from dimos.core.stream import In from dimos.mapping.osm.current_location_map import CurrentLocationMap -from dimos.mapping.utils.distance import distance_in_meters from dimos.mapping.types import LatLon +from dimos.mapping.utils.distance import distance_in_meters from dimos.models.vl.qwen import QwenVlModel from dimos.protocol.skill.skill import skill from dimos.utils.logging_config import setup_logger - logger = setup_logger(__file__) class OsmSkill(SkillModule): - _latest_location: Optional[LatLon] + _latest_location: LatLon | None _current_location_map: CurrentLocationMap _skill_started: bool gps_location: In[LatLon] = None - def __init__(self): + def __init__(self) -> None: super().__init__() self._latest_location = None self._current_location_map = CurrentLocationMap(QwenVlModel()) self._skill_started = False - def start(self): + def start(self) -> None: super().start() self._skill_started = True self._disposables.add(self.gps_location.subscribe(self._on_gps_location)) - def stop(self): + def stop(self) -> None: super().stop() def _on_gps_location(self, location: LatLon) -> None: diff --git a/dimos/agents2/skills/ros_navigation.py b/dimos/agents2/skills/ros_navigation.py index e751a5d7aa..973cdcc10f 100644 --- a/dimos/agents2/skills/ros_navigation.py +++ b/dimos/agents2/skills/ros_navigation.py @@ -13,15 +13,14 @@ # limitations under the License. import time -from typing import Any, Optional +from typing import TYPE_CHECKING, Any + +from dimos.core.resource import Resource from dimos.msgs.geometry_msgs import PoseStamped from dimos.msgs.geometry_msgs.Vector3 import make_vector3 from dimos.protocol.skill.skill import SkillContainer, skill from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import euler_to_quaternion -from dimos.core.resource import Resource - -from typing import TYPE_CHECKING if TYPE_CHECKING: from dimos.robot.unitree_webrtc.unitree_g1 import UnitreeG1 @@ -33,7 +32,7 @@ class RosNavigation(SkillContainer, Resource): _robot: "UnitreeG1" _started: bool - def __init__(self, robot: "UnitreeG1"): + def __init__(self, robot: "UnitreeG1") -> None: self._robot = robot self._similarity_threshold = 0.23 self._started = False @@ -97,7 +96,7 @@ def stop_movement(self) -> str: return "Stopped" - def _get_goal_pose_from_result(self, result: dict[str, Any]) -> Optional[PoseStamped]: + def _get_goal_pose_from_result(self, result: dict[str, Any]) -> PoseStamped | None: similarity = 1.0 - (result.get("distance") or 1) if similarity < self._similarity_threshold: logger.warning( diff --git a/dimos/agents2/skills/test_google_maps_skill_container.py b/dimos/agents2/skills/test_google_maps_skill_container.py index ff7a396a84..27a9dadb8f 100644 --- a/dimos/agents2/skills/test_google_maps_skill_container.py +++ b/dimos/agents2/skills/test_google_maps_skill_container.py @@ -13,10 +13,11 @@ # limitations under the License. import re + from dimos.mapping.google_maps.types import Coordinates, LocationContext, Position -def test_where_am_i(create_google_maps_agent, google_maps_skill_container): +def test_where_am_i(create_google_maps_agent, google_maps_skill_container) -> None: google_maps_skill_container._client.get_location_context.return_value = LocationContext( street="Bourbon Street", coordinates=Coordinates(lat=37.782654, lon=-122.413273) ) @@ -27,7 +28,9 @@ def test_where_am_i(create_google_maps_agent, google_maps_skill_container): assert "bourbon" in response.lower() -def test_get_gps_position_for_queries(create_google_maps_agent, google_maps_skill_container): +def test_get_gps_position_for_queries( + create_google_maps_agent, google_maps_skill_container +) -> None: google_maps_skill_container._client.get_position.side_effect = [ Position(lat=37.782601, lon=-122.413201, description="address 1"), Position(lat=37.782602, lon=-122.413202, description="address 2"), diff --git a/dimos/agents2/skills/test_gps_nav_skills.py b/dimos/agents2/skills/test_gps_nav_skills.py index 5f5593609f..9e8090b169 100644 --- a/dimos/agents2/skills/test_gps_nav_skills.py +++ b/dimos/agents2/skills/test_gps_nav_skills.py @@ -16,7 +16,7 @@ from dimos.mapping.types import LatLon -def test_set_gps_travel_points(fake_gps_robot, create_gps_nav_agent): +def test_set_gps_travel_points(fake_gps_robot, create_gps_nav_agent) -> None: agent = create_gps_nav_agent(fixture="test_set_gps_travel_points.json") agent.query("go to lat: 37.782654, lon: -122.413273") @@ -26,7 +26,7 @@ def test_set_gps_travel_points(fake_gps_robot, create_gps_nav_agent): ) -def test_set_gps_travel_points_multiple(fake_gps_robot, create_gps_nav_agent): +def test_set_gps_travel_points_multiple(fake_gps_robot, create_gps_nav_agent) -> None: agent = create_gps_nav_agent(fixture="test_set_gps_travel_points_multiple.json") agent.query( diff --git a/dimos/agents2/skills/test_navigation.py b/dimos/agents2/skills/test_navigation.py index f612095f29..9d4f3b7eff 100644 --- a/dimos/agents2/skills/test_navigation.py +++ b/dimos/agents2/skills/test_navigation.py @@ -13,14 +13,12 @@ # limitations under the License. -import pytest - from dimos.msgs.geometry_msgs import PoseStamped, Vector3 from dimos.utils.transform_utils import euler_to_quaternion # @pytest.mark.skip -def test_stop_movement(create_navigation_agent, navigation_skill_container, mocker): +def test_stop_movement(create_navigation_agent, navigation_skill_container, mocker) -> None: navigation_skill_container._cancel_goal = mocker.Mock() navigation_skill_container._stop_exploration = mocker.Mock() agent = create_navigation_agent(fixture="test_stop_movement.json") @@ -31,7 +29,7 @@ def test_stop_movement(create_navigation_agent, navigation_skill_container, mock navigation_skill_container._stop_exploration.assert_called_once_with() -def test_take_a_look_around(create_navigation_agent, navigation_skill_container, mocker): +def test_take_a_look_around(create_navigation_agent, navigation_skill_container, mocker) -> None: navigation_skill_container._explore = mocker.Mock() navigation_skill_container._is_exploration_active = mocker.Mock() mocker.patch("dimos.agents2.skills.navigation.time.sleep") @@ -42,7 +40,9 @@ def test_take_a_look_around(create_navigation_agent, navigation_skill_container, navigation_skill_container._explore.assert_called_once_with() -def test_go_to_semantic_location(create_navigation_agent, navigation_skill_container, mocker): +def test_go_to_semantic_location( + create_navigation_agent, navigation_skill_container, mocker +) -> None: mocker.patch( "dimos.agents2.skills.navigation.NavigationSkillContainer._navigate_by_tagged_location", return_value=None, diff --git a/dimos/agents2/spec.py b/dimos/agents2/spec.py index 889092bad3..9973b05356 100644 --- a/dimos/agents2/spec.py +++ b/dimos/agents2/spec.py @@ -17,16 +17,14 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from enum import Enum -from typing import Any, List, Optional, Tuple, Union +from typing import Any, Union from langchain.chat_models.base import _SUPPORTED_PROVIDERS from langchain_core.language_models.chat_models import BaseChatModel from langchain_core.messages import ( AIMessage, HumanMessage, - MessageLikeRepresentation, SystemMessage, - ToolCall, ToolMessage, ) from rich.console import Console @@ -131,13 +129,13 @@ class Model(str, Enum): @dataclass class AgentConfig(ModuleConfig): - system_prompt: Optional[str | SystemMessage] = None - skills: Optional[SkillContainer | list[SkillContainer]] = None + system_prompt: str | SystemMessage | None = None + skills: SkillContainer | list[SkillContainer] | None = None # we can provide model/provvider enums or instantiated model_instance model: Model = Model.GPT_4O provider: Provider = Provider.OPENAI - model_instance: Optional[BaseChatModel] = None + model_instance: BaseChatModel | None = None agent_transport: type[PubSub] = lcm.PickleLCM agent_topic: Any = field(default_factory=lambda: lcm.Topic("/agent")) @@ -149,14 +147,14 @@ class AgentConfig(ModuleConfig): class AgentSpec(Service[AgentConfig], Module, ABC): default_config: type[AgentConfig] = AgentConfig - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: Service.__init__(self, *args, **kwargs) Module.__init__(self, *args, **kwargs) if self.config.agent_transport: self.transport = self.config.agent_transport() - def publish(self, msg: AnyMessage): + def publish(self, msg: AnyMessage) -> None: if self.transport: self.transport.publish(self.config.agent_topic, msg) @@ -171,10 +169,10 @@ def stop(self) -> None: def clear_history(self): ... @abstractmethod - def append_history(self, *msgs: List[Union[AIMessage, HumanMessage]]): ... + def append_history(self, *msgs: list[AIMessage | HumanMessage]): ... @abstractmethod - def history(self) -> List[AnyMessage]: ... + def history(self) -> list[AnyMessage]: ... @rpc @abstractmethod diff --git a/dimos/agents2/system_prompt.py b/dimos/agents2/system_prompt.py index 5168ed96d0..6b14f3e193 100644 --- a/dimos/agents2/system_prompt.py +++ b/dimos/agents2/system_prompt.py @@ -20,6 +20,6 @@ def get_system_prompt() -> str: global _SYSTEM_PROMPT if _SYSTEM_PROMPT is None: - with open(AGENT_SYSTEM_PROMPT_PATH, "r") as f: + with open(AGENT_SYSTEM_PROMPT_PATH) as f: _SYSTEM_PROMPT = f.read() return _SYSTEM_PROMPT diff --git a/dimos/agents2/temp/run_unitree_agents2.py b/dimos/agents2/temp/run_unitree_agents2.py index 29b9d4c978..aacfd1b5f4 100644 --- a/dimos/agents2/temp/run_unitree_agents2.py +++ b/dimos/agents2/temp/run_unitree_agents2.py @@ -19,9 +19,9 @@ """ import os +from pathlib import Path import sys import time -from pathlib import Path from dotenv import load_dotenv @@ -52,7 +52,7 @@ class UnitreeAgentRunner: """Manages the Unitree robot with the new agents2 framework.""" - def __init__(self): + def __init__(self) -> None: self.robot = None self.agent = None self.agent_thread = None @@ -99,7 +99,7 @@ def setup_agent(self, skillcontainers, system_prompt: str) -> Agent: agent.loop_thread() return agent - def run(self): + def run(self) -> None: """Main run loop.""" print("\n" + "=" * 60) print("Unitree Go2 Robot with agents2 Framework") @@ -157,7 +157,7 @@ def run(self): # finally: # self.shutdown() - def shutdown(self): + def shutdown(self) -> None: logger.info("Shutting down...") self.running = False @@ -178,7 +178,7 @@ def shutdown(self): logger.info("Shutdown complete") -def main(): +def main() -> None: runner = UnitreeAgentRunner() runner.run() diff --git a/dimos/agents2/temp/run_unitree_async.py b/dimos/agents2/temp/run_unitree_async.py index cb870096da..29213c1c90 100644 --- a/dimos/agents2/temp/run_unitree_async.py +++ b/dimos/agents2/temp/run_unitree_async.py @@ -20,17 +20,18 @@ import asyncio import os -import sys from pathlib import Path +import sys + from dotenv import load_dotenv # Add parent directories to path sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) -from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer from dimos.agents2 import Agent from dimos.agents2.spec import Model, Provider +from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer from dimos.utils.logging_config import setup_logger logger = setup_logger("run_unitree_async") @@ -70,10 +71,10 @@ async def handle_query(agent, query_text): return "Query timeout" except Exception as e: logger.error(f"Error processing query: {e}") - return f"Error: {str(e)}" + return f"Error: {e!s}" -async def interactive_loop(agent): +async def interactive_loop(agent) -> None: """Run an interactive query loop.""" print("\n" + "=" * 60) print("Interactive Agent Mode") @@ -101,7 +102,7 @@ async def interactive_loop(agent): logger.error(f"Error in interactive loop: {e}") -async def main(): +async def main() -> None: """Main async function.""" print("\n" + "=" * 60) print("Unitree Go2 Robot with agents2 Framework (Async)") @@ -115,7 +116,7 @@ async def main(): # Load system prompt try: - with open(SYSTEM_PROMPT_PATH, "r") as f: + with open(SYSTEM_PROMPT_PATH) as f: system_prompt = f.read() except FileNotFoundError: system_prompt = """You are a helpful robot assistant controlling a Unitree Go2 robot. diff --git a/dimos/agents2/temp/test_unitree_agent_query.py b/dimos/agents2/temp/test_unitree_agent_query.py index bd2843ac19..4990940e6c 100644 --- a/dimos/agents2/temp/test_unitree_agent_query.py +++ b/dimos/agents2/temp/test_unitree_agent_query.py @@ -20,17 +20,18 @@ import asyncio import os +from pathlib import Path import sys import time -from pathlib import Path + from dotenv import load_dotenv # Add parent directories to path sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) -from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer from dimos.agents2 import Agent from dimos.agents2.spec import Model, Provider +from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer from dimos.utils.logging_config import setup_logger logger = setup_logger("test_agent_query") @@ -80,7 +81,7 @@ async def test_async_query(): return future -def test_sync_query_with_thread(): +def test_sync_query_with_thread() -> None: """Test agent query using threading for the event loop.""" print("\n=== Testing Sync Query with Thread ===\n") @@ -111,7 +112,7 @@ def test_sync_query_with_thread(): logger.warning("Agent's event loop is NOT running - this is the problem!") # Try to run the loop in a thread - def run_loop(): + def run_loop() -> None: asyncio.set_event_loop(agent._loop) agent._loop.run_forever() @@ -189,7 +190,7 @@ def run_loop(): # dimos.stop() -def main(): +def main() -> None: """Run tests based on available API key.""" if not os.getenv("OPENAI_API_KEY"): diff --git a/dimos/agents2/temp/test_unitree_skill_container.py b/dimos/agents2/temp/test_unitree_skill_container.py index 3b127e2ca0..16502004ff 100644 --- a/dimos/agents2/temp/test_unitree_skill_container.py +++ b/dimos/agents2/temp/test_unitree_skill_container.py @@ -18,9 +18,9 @@ Tests skill registration and basic functionality. """ +from pathlib import Path import sys import time -from pathlib import Path # Add parent directories to path sys.path.insert(0, str(Path(__file__).parent.parent.parent.parent)) @@ -96,7 +96,7 @@ def test_agent_with_skills(): time.sleep(0.1) -def test_skill_schemas(): +def test_skill_schemas() -> None: """Test that skill schemas are properly generated for LangChain.""" print("\n=== Testing Skill Schemas ===") diff --git a/dimos/agents2/temp/webcam_agent.py b/dimos/agents2/temp/webcam_agent.py index 17a68a55ad..485684d9e0 100644 --- a/dimos/agents2/temp/webcam_agent.py +++ b/dimos/agents2/temp/webcam_agent.py @@ -18,8 +18,8 @@ This is the migrated version using the new LangChain-based agent system. """ -import time from threading import Thread +import time import reactivex as rx import reactivex.operators as ops @@ -27,12 +27,11 @@ from dimos.agents2 import Agent, Output, Reducer, Stream, skill from dimos.agents2.cli.human import HumanInput from dimos.agents2.spec import Model, Provider -from dimos.core import LCMTransport, Module, start, rpc +from dimos.core import LCMTransport, Module, rpc, start from dimos.hardware.camera import zed from dimos.hardware.camera.module import CameraModule from dimos.hardware.camera.webcam import Webcam from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 - from dimos.msgs.sensor_msgs import CameraInfo, Image from dimos.protocol.skill.test_coordinator import SkillContainerTest from dimos.web.robot_web_interface import RobotWebInterface @@ -47,13 +46,13 @@ class WebModule(Module): _human_messages_running = False - def __init__(self): + def __init__(self) -> None: super().__init__() self.agent_response = rx.subject.Subject() self.human_query = rx.subject.Subject() @rpc - def start(self): + def start(self) -> None: super().start() text_streams = { @@ -73,7 +72,7 @@ def start(self): self.thread.start() @rpc - def stop(self): + def stop(self) -> None: if self.web_interface: self.web_interface.stop() if self.thread: @@ -96,7 +95,7 @@ def human_messages(self): yield message -def main(): +def main() -> None: dimos = start(4) # Create agent agent = Agent( diff --git a/dimos/agents2/test_agent.py b/dimos/agents2/test_agent.py index e1cd9adbcd..447d02e6e3 100644 --- a/dimos/agents2/test_agent.py +++ b/dimos/agents2/test_agent.py @@ -149,7 +149,7 @@ async def agent_context(request): # @pytest.mark.timeout(40) @pytest.mark.tool @pytest.mark.asyncio -async def test_agent_init(agent_context): +async def test_agent_init(agent_context) -> None: """Test agent initialization and basic functionality across different configurations""" agent, testcontainer = agent_context diff --git a/dimos/agents2/test_agent_direct.py b/dimos/agents2/test_agent_direct.py index 8466eb4070..ee3f9aa091 100644 --- a/dimos/agents2/test_agent_direct.py +++ b/dimos/agents2/test_agent_direct.py @@ -79,7 +79,7 @@ def full(): testcontainer.stop() -def check_agent(agent_context): +def check_agent(agent_context) -> None: """Test agent initialization and basic functionality across different configurations""" with agent_context() as [agent, testcontainer]: agent.register_skills(testcontainer) diff --git a/dimos/agents2/test_agent_fake.py b/dimos/agents2/test_agent_fake.py index a282ed3794..14e28cd89c 100644 --- a/dimos/agents2/test_agent_fake.py +++ b/dimos/agents2/test_agent_fake.py @@ -13,13 +13,13 @@ # limitations under the License. -def test_what_is_your_name(create_potato_agent): +def test_what_is_your_name(create_potato_agent) -> None: agent = create_potato_agent(fixture="test_what_is_your_name.json") response = agent.query("hi there, please tell me what's your name?") assert "Mr. Potato" in response -def test_how_much_is_124181112_plus_124124(create_potato_agent): +def test_how_much_is_124181112_plus_124124(create_potato_agent) -> None: agent = create_potato_agent(fixture="test_how_much_is_124181112_plus_124124.json") response = agent.query("how much is 124181112 + 124124?") @@ -29,7 +29,7 @@ def test_how_much_is_124181112_plus_124124(create_potato_agent): assert "999000000" in response.replace(",", "") -def test_what_do_you_see_in_this_picture(create_potato_agent): +def test_what_do_you_see_in_this_picture(create_potato_agent) -> None: agent = create_potato_agent(fixture="test_what_do_you_see_in_this_picture.json") response = agent.query("take a photo and tell me what do you see") diff --git a/dimos/agents2/test_mock_agent.py b/dimos/agents2/test_mock_agent.py index 5ade99f9ab..4b113b45a0 100644 --- a/dimos/agents2/test_mock_agent.py +++ b/dimos/agents2/test_mock_agent.py @@ -16,9 +16,9 @@ import time -import pytest from dimos_lcm.sensor_msgs import CameraInfo from langchain_core.messages import AIMessage, HumanMessage +import pytest from dimos.agents2.agent import Agent from dimos.agents2.testing import MockModel @@ -30,7 +30,7 @@ from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -def test_tool_call(): +def test_tool_call() -> None: """Test agent initialization and tool call execution.""" # Create a fake model that will respond with tool calls fake_model = MockModel( @@ -74,7 +74,7 @@ def test_tool_call(): agent.stop() -def test_image_tool_call(): +def test_image_tool_call() -> None: """Test agent with image tool call execution.""" dimos = start(2) # Create a fake model that will respond with image tool calls @@ -131,7 +131,7 @@ def test_image_tool_call(): @pytest.mark.tool -def test_tool_call_implicit_detections(): +def test_tool_call_implicit_detections() -> None: """Test agent with image tool call execution.""" dimos = start(2) # Create a fake model that will respond with image tool calls diff --git a/dimos/agents2/test_stash_agent.py b/dimos/agents2/test_stash_agent.py index 715e24b513..8e2972568a 100644 --- a/dimos/agents2/test_stash_agent.py +++ b/dimos/agents2/test_stash_agent.py @@ -15,13 +15,12 @@ import pytest from dimos.agents2.agent import Agent -from dimos.core import start from dimos.protocol.skill.test_coordinator import SkillContainerTest @pytest.mark.tool @pytest.mark.asyncio -async def test_agent_init(): +async def test_agent_init() -> None: system_prompt = ( "Your name is Mr. Potato, potatoes are bad at math. Use a tools if asked to calculate" ) diff --git a/dimos/agents2/testing.py b/dimos/agents2/testing.py index 8b173ecfd3..b729c13d50 100644 --- a/dimos/agents2/testing.py +++ b/dimos/agents2/testing.py @@ -14,10 +14,11 @@ """Testing utilities for agents.""" +from collections.abc import Iterator, Sequence import json import os from pathlib import Path -from typing import Any, Dict, Iterator, List, Optional, Sequence, Union +from typing import Any from langchain.chat_models import init_chat_model from langchain_core.callbacks.manager import CallbackManagerForLLMRun @@ -39,14 +40,14 @@ class MockModel(SimpleChatModel): 2. Record mode: Uses a real LLM and saves responses to a JSON file """ - responses: List[Union[str, AIMessage]] = [] + responses: list[str | AIMessage] = [] i: int = 0 - json_path: Optional[Path] = None + json_path: Path | None = None record: bool = False - real_model: Optional[Any] = None - recorded_messages: List[Dict[str, Any]] = [] + real_model: Any | None = None + recorded_messages: list[dict[str, Any]] = [] - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: # Extract custom parameters before calling super().__init__ responses = kwargs.pop("responses", []) json_path = kwargs.pop("json_path", None) @@ -58,7 +59,7 @@ def __init__(self, **kwargs): self.json_path = Path(json_path) if json_path else None self.record = bool(os.getenv("RECORD")) self.i = 0 - self._bound_tools: Optional[Sequence[Any]] = None + self._bound_tools: Sequence[Any] | None = None self.recorded_messages = [] if self.record: @@ -76,8 +77,8 @@ def __init__(self, **kwargs): def _llm_type(self) -> str: return "tool-call-fake-chat-model" - def _load_responses_from_json(self) -> List[AIMessage]: - with open(self.json_path, "r") as f: + def _load_responses_from_json(self) -> list[AIMessage]: + with open(self.json_path) as f: data = json.load(f) responses = [] @@ -92,7 +93,7 @@ def _load_responses_from_json(self) -> List[AIMessage]: responses.append(msg) return responses - def _save_responses_to_json(self): + def _save_responses_to_json(self) -> None: if not self.json_path: return @@ -112,9 +113,9 @@ def _save_responses_to_json(self): def _call( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> str: """Not used in _generate.""" @@ -122,9 +123,9 @@ def _call( def _generate( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> ChatResult: if self.record: @@ -146,7 +147,7 @@ def _generate( else: # Playback mode - use predefined responses if not self.responses: - raise ValueError(f"No responses available for playback. ") + raise ValueError("No responses available for playback. ") if self.i >= len(self.responses): # Don't wrap around - stay at last response @@ -165,9 +166,9 @@ def _generate( def _stream( self, - messages: List[BaseMessage], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any, ) -> Iterator[ChatGenerationChunk]: """Stream not implemented for testing.""" @@ -178,9 +179,9 @@ def _stream( def bind_tools( self, - tools: Sequence[Union[dict[str, Any], type, Any]], + tools: Sequence[dict[str, Any] | type | Any], *, - tool_choice: Optional[str] = None, + tool_choice: str | None = None, **kwargs: Any, ) -> Runnable: """Store tools and return self.""" @@ -191,6 +192,6 @@ def bind_tools( return self @property - def tools(self) -> Optional[Sequence[Any]]: + def tools(self) -> Sequence[Any] | None: """Get bound tools for inspection.""" return self._bound_tools diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py index 8d767779c4..514d0cf3a6 100644 --- a/dimos/core/__init__.py +++ b/dimos/core/__init__.py @@ -26,10 +26,11 @@ from dimos.utils.actor_registry import ActorRegistry __all__ = [ - "DimosCluster", - "In", "LCMRPC", "LCMTF", + "TF", + "DimosCluster", + "In", "LCMTransport", "Module", "ModuleBase", @@ -40,7 +41,6 @@ "RemoteIn", "RemoteOut", "SHMTransport", - "TF", "TFConfig", "TFSpec", "Transport", @@ -55,11 +55,11 @@ class CudaCleanupPlugin: """Dask worker plugin to cleanup CUDA resources on shutdown.""" - def setup(self, worker): + def setup(self, worker) -> None: """Called when worker starts.""" pass - def teardown(self, worker): + def teardown(self, worker) -> None: """Clean up CUDA resources when worker shuts down.""" try: import sys @@ -79,7 +79,7 @@ def teardown(self, worker): pass -def patch_actor(actor, cls): ... +def patch_actor(actor, cls) -> None: ... DimosCluster = Client @@ -101,14 +101,14 @@ def deploy( ).result() worker = actor.set_ref(actor).result() - print((f"deployed: {colors.blue(actor)} @ {colors.orange('worker ' + str(worker))}")) + print(f"deployed: {colors.blue(actor)} @ {colors.orange('worker ' + str(worker))}") # Register actor deployment in shared memory ActorRegistry.update(str(actor), str(worker)) return RPCClient(actor, actor_class) - def check_worker_memory(): + def check_worker_memory() -> None: """Check memory usage of all workers.""" info = dask_client.scheduler_info() console = Console() @@ -130,7 +130,7 @@ def check_worker_memory(): memory_used_gb = memory_used / 1e9 memory_limit_gb = memory_limit / 1e9 managed_gb = managed_bytes / 1e9 - spilled_gb = spilled / 1e9 + spilled / 1e9 total_memory_used += memory_used total_memory_limit += memory_limit @@ -161,7 +161,7 @@ def check_worker_memory(): f"[bold]Total: {total_used_gb:.2f}/{total_limit_gb:.2f}GB ({total_percentage:.1f}%) across {total_workers} workers[/bold]" ) - def close_all(): + def close_all() -> None: # Prevents multiple calls to close_all if hasattr(dask_client, "_closed") and dask_client._closed: return @@ -227,7 +227,7 @@ def close_all(): return dask_client # type: ignore[return-value] -def start(n: Optional[int] = None, memory_limit: str = "auto") -> DimosCluster: +def start(n: int | None = None, memory_limit: str = "auto") -> DimosCluster: """Start a Dask LocalCluster with specified workers and memory limits. Args: @@ -260,7 +260,7 @@ def start(n: Optional[int] = None, memory_limit: str = "auto") -> DimosCluster: patched_client._shutting_down = False # Signal handler with proper exit handling - def signal_handler(sig, frame): + def signal_handler(sig, frame) -> None: # If already shutting down, force exit if patched_client._shutting_down: import os @@ -286,7 +286,7 @@ def signal_handler(sig, frame): return patched_client -def wait_exit(): +def wait_exit() -> None: while True: try: time.sleep(1) diff --git a/dimos/core/blueprints.py b/dimos/core/blueprints.py index 53f20a0bfb..793402088b 100644 --- a/dimos/core/blueprints.py +++ b/dimos/core/blueprints.py @@ -12,16 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass, field from collections import defaultdict -from functools import cached_property +from collections.abc import Mapping +from dataclasses import dataclass, field +from functools import cached_property, reduce import inspect +import operator from types import MappingProxyType -from typing import Any, Literal, Mapping, get_origin, get_args +from typing import Any, Literal, get_args, get_origin -from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.global_config import GlobalConfig from dimos.core.module import Module +from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.stream import In, Out from dimos.core.transport import LCMTransport, pLCMTransport from dimos.utils.generic import short_id @@ -164,9 +166,11 @@ def create_module_blueprint(module: type[Module], *args: Any, **kwargs: Any) -> def autoconnect(*blueprints: ModuleBlueprintSet) -> ModuleBlueprintSet: all_blueprints = tuple(_eliminate_duplicates([bp for bs in blueprints for bp in bs.blueprints])) - all_transports = dict(sum([list(x.transports.items()) for x in blueprints], [])) + all_transports = dict( + reduce(operator.iadd, [list(x.transports.items()) for x in blueprints], []) + ) all_config_overrides = dict( - sum([list(x.global_config_overrides.items()) for x in blueprints], []) + reduce(operator.iadd, [list(x.global_config_overrides.items()) for x in blueprints], []) ) return ModuleBlueprintSet( diff --git a/dimos/core/core.py b/dimos/core/core.py index 6a30f18d9e..57e49e555d 100644 --- a/dimos/core/core.py +++ b/dimos/core/core.py @@ -15,17 +15,17 @@ from __future__ import annotations -import traceback from typing import ( + TYPE_CHECKING, Any, - Callable, - List, TypeVar, ) -import dimos.core.colors as colors from dimos.core.o3dpickle import register_picklers +if TYPE_CHECKING: + from collections.abc import Callable + # injects pickling system into o3d register_picklers() T = TypeVar("T") diff --git a/dimos/core/global_config.py b/dimos/core/global_config.py index e25184c351..9c155ecfc5 100644 --- a/dimos/core/global_config.py +++ b/dimos/core/global_config.py @@ -13,6 +13,7 @@ # limitations under the License. from functools import cached_property + from pydantic_settings import BaseSettings, SettingsConfigDict diff --git a/dimos/core/module.py b/dimos/core/module.py index aa65c1479f..4d8cb6ef5f 100644 --- a/dimos/core/module.py +++ b/dimos/core/module.py @@ -12,25 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio +from collections.abc import Callable +from dataclasses import dataclass from functools import partial import inspect import threading -from dataclasses import dataclass from typing import ( Any, - Callable, - Optional, get_args, get_origin, get_type_hints, ) -from reactivex.disposable import CompositeDisposable from dask.distributed import Actor, get_worker +from reactivex.disposable import CompositeDisposable from dimos.core import colors from dimos.core.core import T, rpc -from dimos.core.global_config import GlobalConfig from dimos.core.resource import Resource from dimos.core.stream import In, Out, RemoteIn, RemoteOut, Transport from dimos.protocol.rpc import LCMRPC, RPCSpec @@ -40,7 +38,7 @@ from dimos.utils.generic import classproperty -def get_loop() -> tuple[asyncio.AbstractEventLoop, Optional[threading.Thread]]: +def get_loop() -> tuple[asyncio.AbstractEventLoop, threading.Thread | None]: # we are actually instantiating a new loop here # to not interfere with an existing dask loop @@ -75,15 +73,15 @@ class ModuleConfig: class ModuleBase(Configurable[ModuleConfig], SkillContainer, Resource): - _rpc: Optional[RPCSpec] = None - _tf: Optional[TFSpec] = None - _loop: Optional[asyncio.AbstractEventLoop] = None - _loop_thread: Optional[threading.Thread] + _rpc: RPCSpec | None = None + _tf: TFSpec | None = None + _loop: asyncio.AbstractEventLoop | None = None + _loop_thread: threading.Thread | None _disposables: CompositeDisposable default_config = ModuleConfig - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._loop, self._loop_thread = get_loop() self._disposables = CompositeDisposable() @@ -107,7 +105,7 @@ def stop(self) -> None: self._close_module() super().stop() - def _close_module(self): + def _close_module(self) -> None: self._close_rpc() if hasattr(self, "_loop") and self._loop_thread: if self._loop_thread.is_alive(): @@ -121,7 +119,7 @@ def _close_module(self): if hasattr(self, "_disposables"): self._disposables.dispose() - def _close_rpc(self): + def _close_rpc(self) -> None: # Using hasattr is needed because SkillCoordinator skips ModuleBase.__init__ and self.rpc is never set. if hasattr(self, "rpc") and self.rpc: self.rpc.stop() @@ -138,7 +136,7 @@ def __getstate__(self): state.pop("_tf", None) return state - def __setstate__(self, state): + def __setstate__(self, state) -> None: """Restore object from pickled state.""" self.__dict__.update(state) # Reinitialize runtime attributes @@ -156,7 +154,7 @@ def tf(self): return self._tf @tf.setter - def tf(self, value): + def tf(self, value) -> None: import warnings warnings.warn( @@ -197,9 +195,9 @@ def rpcs(cls) -> dict[str, Callable]: def io(self) -> str: def _box(name: str) -> str: return [ - f"┌┴" + "─" * (len(name) + 1) + "┐", + "┌┴" + "─" * (len(name) + 1) + "┐", f"│ {name} │", - f"└┬" + "─" * (len(name) + 1) + "┘", + "└┬" + "─" * (len(name) + 1) + "┘", ] # can't modify __str__ on a function like we are doing for I/O @@ -241,18 +239,18 @@ def repr_rpc(fn: Callable) -> str: return "\n".join(ret) @classproperty - def blueprint(cls): + def blueprint(self): # Here to prevent circular imports. from dimos.core.blueprints import create_module_blueprint - return partial(create_module_blueprint, cls) + return partial(create_module_blueprint, self) class DaskModule(ModuleBase): ref: Actor worker: int - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: self.ref = None for name, ann in get_type_hints(self, include_extras=True).items(): @@ -273,11 +271,11 @@ def set_ref(self, ref) -> int: self.worker = worker.name return worker.name - def __str__(self): + def __str__(self) -> str: return f"{self.__class__.__name__}" # called from remote - def set_transport(self, stream_name: str, transport: Transport): + def set_transport(self, stream_name: str, transport: Transport) -> bool: stream = getattr(self, stream_name, None) if not stream: raise ValueError(f"{stream_name} not found in {self.__class__.__name__}") @@ -297,10 +295,10 @@ def connect_stream(self, input_name: str, remote_stream: RemoteOut[T]): raise TypeError(f"Input {input_name} is not a valid stream") input_stream.connection = remote_stream - def dask_receive_msg(self, input_name: str, msg: Any): + def dask_receive_msg(self, input_name: str, msg: Any) -> None: getattr(self, input_name).transport.dask_receive_msg(msg) - def dask_register_subscriber(self, output_name: str, subscriber: RemoteIn[T]): + def dask_register_subscriber(self, output_name: str, subscriber: RemoteIn[T]) -> None: getattr(self, output_name).transport.dask_register_subscriber(subscriber) diff --git a/dimos/core/module_coordinator.py b/dimos/core/module_coordinator.py index 47081a0d71..477ba4b651 100644 --- a/dimos/core/module_coordinator.py +++ b/dimos/core/module_coordinator.py @@ -13,7 +13,7 @@ # limitations under the License. import time -from typing import TYPE_CHECKING, Optional, Type, TypeVar +from typing import TYPE_CHECKING, TypeVar from dimos import core from dimos.core import DimosCluster, Module @@ -21,23 +21,23 @@ from dimos.core.resource import Resource if TYPE_CHECKING: - from dimos.core import DimosCluster, Module + from dimos.core import Module T = TypeVar("T", bound="Module") class ModuleCoordinator(Resource): - _client: Optional[DimosCluster] = None - _n: Optional[int] = None + _client: DimosCluster | None = None + _n: int | None = None _memory_limit: str = "auto" - _deployed_modules: dict[Type["Module"], "Module"] = {} + _deployed_modules: dict[type["Module"], "Module"] = {} def __init__( self, - n: Optional[int] = None, + n: int | None = None, memory_limit: str = "auto", global_config: GlobalConfig | None = None, - ): + ) -> None: cfg = global_config or GlobalConfig() self._n = n if n is not None else cfg.n_dask_workers self._memory_limit = memory_limit @@ -51,7 +51,7 @@ def stop(self) -> None: self._client.close_all() - def deploy(self, module_class: Type[T], *args, **kwargs) -> T: + def deploy(self, module_class: type[T], *args, **kwargs) -> T: if not self._client: raise ValueError("Not started") @@ -63,7 +63,7 @@ def start_all_modules(self) -> None: for module in self._deployed_modules.values(): module.start() - def get_instance(self, module: Type[T]) -> T | None: + def get_instance(self, module: type[T]) -> T | None: return self._deployed_modules.get(module) def wait_until_shutdown(self) -> None: diff --git a/dimos/core/o3dpickle.py b/dimos/core/o3dpickle.py index a18916a06c..8e0f13dbf0 100644 --- a/dimos/core/o3dpickle.py +++ b/dimos/core/o3dpickle.py @@ -31,7 +31,7 @@ def reconstruct_pointcloud(points_array): return pc -def register_picklers(): +def register_picklers() -> None: # Register for the actual PointCloud class that gets instantiated # We need to create a dummy PointCloud to get its actual class _dummy_pc = o3d.geometry.PointCloud() diff --git a/dimos/core/rpc_client.py b/dimos/core/rpc_client.py index dce1d704af..bfcec5bb71 100644 --- a/dimos/core/rpc_client.py +++ b/dimos/core/rpc_client.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Callable - +from collections.abc import Callable +from typing import Any from dimos.protocol.rpc.lcmrpc import LCMRPC from dimos.utils.logging_config import setup_logger - logger = setup_logger(__file__) @@ -38,7 +37,7 @@ def __init__( remote_name: str, unsub_fns: list, stop_client: Callable[[], None] | None = None, - ): + ) -> None: self._original_method = original_method self._rpc = rpc self._name = name @@ -51,7 +50,7 @@ def __init__( self.__name__ = original_method.__name__ self.__qualname__ = f"{self.__class__.__name__}.{original_method.__name__}" - def set_rpc(self, rpc: LCMRPC): + def set_rpc(self, rpc: LCMRPC) -> None: self._rpc = rpc def __call__(self, *args, **kwargs): @@ -74,7 +73,7 @@ def __call__(self, *args, **kwargs): def __getstate__(self): return (self._original_method, self._name, self._remote_name) - def __setstate__(self, state): + def __setstate__(self, state) -> None: self._original_method, self._name, self._remote_name = state self._unsub_fns = [] self._rpc = None @@ -82,7 +81,7 @@ def __setstate__(self, state): class RPCClient: - def __init__(self, actor_instance, actor_class): + def __init__(self, actor_instance, actor_class) -> None: self.rpc = LCMRPC() self.actor_class = actor_class self.remote_name = actor_class.__name__ @@ -91,7 +90,7 @@ def __init__(self, actor_instance, actor_class): self.rpc.start() self._unsub_fns = [] - def stop_rpc_client(self): + def stop_rpc_client(self) -> None: for unsub in self._unsub_fns: try: unsub() diff --git a/dimos/core/skill_module.py b/dimos/core/skill_module.py index f432b48861..4c6a42fa5b 100644 --- a/dimos/core/skill_module.py +++ b/dimos/core/skill_module.py @@ -13,7 +13,7 @@ # limitations under the License. from dimos.core.module import Module -from dimos.core.rpc_client import RPCClient, RpcCall +from dimos.core.rpc_client import RpcCall, RPCClient from dimos.protocol.skill.skill import rpc @@ -25,8 +25,8 @@ def set_LlmAgent_register_skills(self, callable: RpcCall) -> None: callable.set_rpc(self.rpc) callable(RPCClient(self, self.__class__)) - def __getstate__(self): + def __getstate__(self) -> None: pass - def __setstate__(self, _state): + def __setstate__(self, _state) -> None: pass diff --git a/dimos/core/stream.py b/dimos/core/stream.py index 672ea4316e..a8843b0989 100644 --- a/dimos/core/stream.py +++ b/dimos/core/stream.py @@ -16,15 +16,14 @@ import enum from typing import ( + TYPE_CHECKING, Any, - Callable, Generic, - Optional, TypeVar, ) -import reactivex as rx from dask.distributed import Actor +import reactivex as rx from reactivex import operators as ops from reactivex.disposable import Disposable @@ -32,13 +31,16 @@ import dimos.utils.reactive as reactive from dimos.utils.reactive import backpressure +if TYPE_CHECKING: + from collections.abc import Callable + T = TypeVar("T") class ObservableMixin(Generic[T]): # subscribes and returns the first value it receives # might be nicer to write without rxpy but had this snippet ready - def get_next(self, timeout=10.0) -> T: + def get_next(self, timeout: float = 10.0) -> T: try: return ( self.observable() @@ -73,9 +75,9 @@ class State(enum.Enum): class Transport(ObservableMixin[T]): # used by local Output - def broadcast(self, selfstream: Out[T], value: T): ... + def broadcast(self, selfstream: Out[T], value: T) -> None: ... - def publish(self, msg: T): + def publish(self, msg: T) -> None: self.broadcast(None, msg) # used by local Input @@ -83,15 +85,15 @@ def subscribe(self, selfstream: In[T], callback: Callable[[T], any]) -> None: .. class Stream(Generic[T]): - _transport: Optional[Transport] + _transport: Transport | None def __init__( self, type: type[T], name: str, - owner: Optional[Any] = None, - transport: Optional[Transport] = None, - ): + owner: Any | None = None, + transport: Transport | None = None, + ) -> None: self.name = name self.owner = owner self.type = type @@ -113,7 +115,7 @@ def _color_fn(self) -> Callable[[str], str]: return colors.green return lambda s: s - def __str__(self) -> str: # noqa: D401 + def __str__(self) -> str: return ( self.__class__.__name__ + " " @@ -131,7 +133,7 @@ def __str__(self) -> str: # noqa: D401 class Out(Stream[T]): _transport: Transport - def __init__(self, *argv, **kwargs): + def __init__(self, *argv, **kwargs) -> None: super().__init__(*argv, **kwargs) @property @@ -144,10 +146,10 @@ def transport(self, value: Transport[T]) -> None: ... @property - def state(self) -> State: # noqa: D401 + def state(self) -> State: return State.UNBOUND if self.owner is None else State.READY - def __reduce__(self): # noqa: D401 + def __reduce__(self): if self.owner is None or not hasattr(self.owner, "ref"): raise ValueError("Cannot serialise Out without an owner ref") return ( @@ -168,7 +170,7 @@ def publish(self, msg): class RemoteStream(Stream[T]): @property - def state(self) -> State: # noqa: D401 + def state(self) -> State: return State.UNBOUND if self.owner is None else State.READY # this won't work but nvm @@ -193,10 +195,10 @@ def subscribe(self, cb) -> Callable[[], None]: # representation of Input # as views from inside of the module class In(Stream[T], ObservableMixin[T]): - connection: Optional[RemoteOut[T]] = None + connection: RemoteOut[T] | None = None _transport: Transport - def __str__(self): + def __str__(self) -> str: mystr = super().__str__() if not self.connection: @@ -204,7 +206,7 @@ def __str__(self): return (mystr + " ◀─").ljust(60, "─") + f" {self.connection}" - def __reduce__(self): # noqa: D401 + def __reduce__(self): if self.owner is None or not hasattr(self.owner, "ref"): raise ValueError("Cannot serialise Out without an owner ref") return (RemoteIn, (self.type, self.name, self.owner.ref, self._transport)) @@ -225,7 +227,7 @@ def connect(self, value: Out[T]) -> None: ... @property - def state(self) -> State: # noqa: D401 + def state(self) -> State: return State.UNBOUND if self.owner is None else State.READY # returns unsubscribe function @@ -244,7 +246,7 @@ def connect(self, other: RemoteOut[T]) -> None: def transport(self) -> Transport[T]: return self._transport - def publish(self, msg): + def publish(self, msg) -> None: self.transport.broadcast(self, msg) @transport.setter diff --git a/dimos/core/test_blueprints.py b/dimos/core/test_blueprints.py index edce54f2e1..da39ef467d 100644 --- a/dimos/core/test_blueprints.py +++ b/dimos/core/test_blueprints.py @@ -17,8 +17,8 @@ ModuleBlueprintSet, ModuleConnection, _make_module_blueprint, + autoconnect, ) -from dimos.core.blueprints import autoconnect from dimos.core.core import rpc from dimos.core.global_config import GlobalConfig from dimos.core.module import Module @@ -91,7 +91,7 @@ class ModuleC(Module): module_c = ModuleC.blueprint -def test_get_connection_set(): +def test_get_connection_set() -> None: assert _make_module_blueprint(CatModule, args=("arg1"), kwargs={"k": "v"}) == ModuleBlueprint( module=CatModule, connections=( @@ -103,7 +103,7 @@ def test_get_connection_set(): ) -def test_autoconnect(): +def test_autoconnect() -> None: blueprint_set = autoconnect(module_a(), module_b()) assert blueprint_set == ModuleBlueprintSet( @@ -131,7 +131,7 @@ def test_autoconnect(): ) -def test_with_transports(): +def test_with_transports() -> None: custom_transport = LCMTransport("/custom_topic", Data1) blueprint_set = autoconnect(module_a(), module_b()).with_transports( {("data1", Data1): custom_transport} @@ -141,7 +141,7 @@ def test_with_transports(): assert blueprint_set.transports[("data1", Data1)] == custom_transport -def test_with_global_config(): +def test_with_global_config() -> None: blueprint_set = autoconnect(module_a(), module_b()).with_global_config(option1=True, option2=42) assert "option1" in blueprint_set.global_config_overrides @@ -150,7 +150,7 @@ def test_with_global_config(): assert blueprint_set.global_config_overrides["option2"] == 42 -def test_build_happy_path(): +def test_build_happy_path() -> None: pubsub.lcm.autoconf() blueprint_set = autoconnect(module_a(), module_b(), module_c()) diff --git a/dimos/core/test_core.py b/dimos/core/test_core.py index 1acf87f078..97f09a4182 100644 --- a/dimos/core/test_core.py +++ b/dimos/core/test_core.py @@ -15,6 +15,7 @@ import time import pytest +from reactivex.disposable import Disposable from dimos.core import ( In, @@ -29,7 +30,6 @@ from dimos.msgs.geometry_msgs import Vector3 from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.odometry import Odometry -from reactivex.disposable import Disposable assert dimos @@ -46,12 +46,12 @@ class Navigation(Module): @rpc def navigate_to(self, target: Vector3) -> bool: ... - def __init__(self): + def __init__(self) -> None: super().__init__() @rpc - def start(self): - def _odom(msg): + def start(self) -> None: + def _odom(msg) -> None: self.odom_msg_count += 1 print("RCV:", (time.perf_counter() - msg.pubtime) * 1000, msg) self.mov.publish(msg.position) @@ -59,7 +59,7 @@ def _odom(msg): unsub = self.odometry.subscribe(_odom) self._disposables.add(Disposable(unsub)) - def _lidar(msg): + def _lidar(msg) -> None: self.lidar_msg_count += 1 if hasattr(msg, "pubtime"): print("RCV:", (time.perf_counter() - msg.pubtime) * 1000, msg) @@ -70,7 +70,7 @@ def _lidar(msg): self._disposables.add(Disposable(unsub)) -def test_classmethods(): +def test_classmethods() -> None: # Test class property access class_rpcs = Navigation.rpcs print("Class rpcs:", class_rpcs) @@ -103,7 +103,7 @@ def test_classmethods(): @pytest.mark.module -def test_basic_deployment(dimos): +def test_basic_deployment(dimos) -> None: robot = dimos.deploy(MockRobotClient) print("\n") diff --git a/dimos/core/test_modules.py b/dimos/core/test_modules.py index 42112f2415..d1f925aff2 100644 --- a/dimos/core/test_modules.py +++ b/dimos/core/test_modules.py @@ -17,7 +17,6 @@ import ast import inspect from pathlib import Path -from typing import Dict, List, Set, Tuple import pytest @@ -27,13 +26,13 @@ class ModuleVisitor(ast.NodeVisitor): """AST visitor to find classes and their base classes.""" - def __init__(self, filepath: str): + def __init__(self, filepath: str) -> None: self.filepath = filepath - self.classes: List[ - Tuple[str, List[str], Set[str]] + self.classes: list[ + tuple[str, list[str], set[str]] ] = [] # (class_name, base_classes, methods) - def visit_ClassDef(self, node: ast.ClassDef): + def visit_ClassDef(self, node: ast.ClassDef) -> None: """Visit a class definition.""" # Get base class names base_classes = [] @@ -61,7 +60,7 @@ def visit_ClassDef(self, node: ast.ClassDef): self.generic_visit(node) -def get_import_aliases(tree: ast.AST) -> Dict[str, str]: +def get_import_aliases(tree: ast.AST) -> dict[str, str]: """Extract import aliases from the AST.""" aliases = {} @@ -81,10 +80,10 @@ def get_import_aliases(tree: ast.AST) -> Dict[str, str]: def is_module_subclass( - base_classes: List[str], - aliases: Dict[str, str], - class_hierarchy: Dict[str, List[str]] = None, - current_module_path: str = None, + base_classes: list[str], + aliases: dict[str, str], + class_hierarchy: dict[str, list[str]] | None = None, + current_module_path: str | None = None, ) -> bool: """Check if any base class is or resolves to dimos.core.Module or its variants (recursively).""" target_classes = { @@ -99,7 +98,7 @@ def is_module_subclass( "dimos.core.module.DaskModule", } - def find_qualified_name(base: str, context_module: str = None) -> str: + def find_qualified_name(base: str, context_module: str | None = None) -> str: """Find the qualified name for a base class, using import context if available.""" if not class_hierarchy: return base @@ -126,7 +125,9 @@ def find_qualified_name(base: str, context_module: str = None) -> str: # Otherwise return the base as-is return base - def check_base(base: str, visited: Set[str] = None, context_module: str = None) -> bool: + def check_base( + base: str, visited: set[str] | None = None, context_module: str | None = None + ) -> bool: if visited is None: visited = set() @@ -168,8 +169,10 @@ def check_base(base: str, visited: Set[str] = None, context_module: str = None) def scan_file( - filepath: Path, class_hierarchy: Dict[str, List[str]] = None, root_path: Path = None -) -> List[Tuple[str, str, bool, bool, Set[str]]]: + filepath: Path, + class_hierarchy: dict[str, list[str]] | None = None, + root_path: Path | None = None, +) -> list[tuple[str, str, bool, bool, set[str]]]: """ Scan a Python file for Module subclasses. @@ -179,7 +182,7 @@ def scan_file( forbidden_method_names = {"acquire", "release", "open", "close", "shutdown", "clean", "cleanup"} try: - with open(filepath, "r", encoding="utf-8") as f: + with open(filepath, encoding="utf-8") as f: content = f.read() tree = ast.parse(content, filename=str(filepath)) @@ -215,7 +218,7 @@ def scan_file( return [] -def build_class_hierarchy(root_path: Path) -> Dict[str, List[str]]: +def build_class_hierarchy(root_path: Path) -> dict[str, list[str]]: """Build a complete class hierarchy by scanning all Python files.""" hierarchy = {} @@ -225,7 +228,7 @@ def build_class_hierarchy(root_path: Path) -> Dict[str, List[str]]: continue try: - with open(filepath, "r", encoding="utf-8") as f: + with open(filepath, encoding="utf-8") as f: content = f.read() tree = ast.parse(content, filename=str(filepath)) @@ -257,7 +260,7 @@ def build_class_hierarchy(root_path: Path) -> Dict[str, List[str]]: return hierarchy -def scan_directory(root_path: Path) -> List[Tuple[str, str, bool, bool, Set[str]]]: +def scan_directory(root_path: Path) -> list[tuple[str, str, bool, bool, set[str]]]: """Scan all Python files in the directory tree.""" # First, build the complete class hierarchy class_hierarchy = build_class_hierarchy(root_path) @@ -305,7 +308,9 @@ def get_all_module_subclasses(): get_all_module_subclasses(), ids=lambda val: val[0] if isinstance(val, str) else str(val), ) -def test_module_has_start_and_stop(class_name, filepath, has_start, has_stop, forbidden_methods): +def test_module_has_start_and_stop( + class_name: str, filepath, has_start, has_stop, forbidden_methods +) -> None: """Test that Module subclasses implement start and stop methods and don't use forbidden methods.""" # Get relative path for better error messages try: diff --git a/dimos/core/test_rpcstress.py b/dimos/core/test_rpcstress.py index 8f7a0dac40..fc00a95854 100644 --- a/dimos/core/test_rpcstress.py +++ b/dimos/core/test_rpcstress.py @@ -23,7 +23,7 @@ class Counter(Module): count_stream: Out[int] = None - def __init__(self): + def __init__(self) -> None: super().__init__() self.current_count = 0 @@ -40,7 +40,7 @@ class CounterValidator(Module): count_in: In[int] = None - def __init__(self, increment_func): + def __init__(self, increment_func) -> None: super().__init__() self.increment_func = increment_func self.last_seen = 0 @@ -53,7 +53,7 @@ def __init__(self, increment_func): self.waiting_for_response = False @rpc - def start(self): + def start(self) -> None: """Start the validator.""" self.count_in.subscribe(self._on_count_received) self.running = True @@ -61,13 +61,13 @@ def start(self): self.call_thread.start() @rpc - def stop(self): + def stop(self) -> None: """Stop the validator.""" self.running = False if self.call_thread: self.call_thread.join() - def _on_count_received(self, count: int): + def _on_count_received(self, count: int) -> None: """Check if we received all numbers in sequence and trigger next call.""" # Calculate round trip time if self.call_start_time: @@ -83,7 +83,7 @@ def _on_count_received(self, count: int): # Signal that we can make the next call self.waiting_for_response = False - def _call_loop(self): + def _call_loop(self) -> None: """Call increment only after receiving response from previous call.""" while self.running: if not self.waiting_for_response: @@ -159,7 +159,7 @@ def get_stats(self): # Get stats before stopping stats = validator.get_stats() - print(f"\n[MAIN] Final statistics:") + print("\n[MAIN] Final statistics:") print(f" - Total calls made: {stats['call_count']}") print(f" - Last number seen: {stats['last_seen']}") print(f" - Missing numbers: {stats['missing_count']}") diff --git a/dimos/core/test_stream.py b/dimos/core/test_stream.py index 59fa806716..91091e42af 100644 --- a/dimos/core/test_stream.py +++ b/dimos/core/test_stream.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable import time -from typing import Callable, Optional import pytest @@ -34,16 +34,16 @@ class SubscriberBase(Module): sub1_msgs: list[Odometry] = None sub2_msgs: list[Odometry] = None - def __init__(self): + def __init__(self) -> None: self.sub1_msgs = [] self.sub2_msgs = [] super().__init__() @rpc - def sub1(self): ... + def sub1(self) -> None: ... @rpc - def sub2(self): ... + def sub2(self) -> None: ... @rpc def active_subscribers(self): @@ -60,19 +60,19 @@ def sub2_msgs_len(self) -> int: class ClassicSubscriber(SubscriberBase): odom: In[Odometry] = None - unsub: Optional[Callable[[], None]] = None - unsub2: Optional[Callable[[], None]] = None + unsub: Callable[[], None] | None = None + unsub2: Callable[[], None] | None = None @rpc - def sub1(self): + def sub1(self) -> None: self.unsub = self.odom.subscribe(self.sub1_msgs.append) @rpc - def sub2(self): + def sub2(self) -> None: self.unsub2 = self.odom.subscribe(self.sub2_msgs.append) @rpc - def stop(self): + def stop(self) -> None: if self.unsub: self.unsub() self.unsub = None @@ -83,21 +83,21 @@ def stop(self): class RXPYSubscriber(SubscriberBase): odom: In[Odometry] = None - unsub: Optional[Callable[[], None]] = None - unsub2: Optional[Callable[[], None]] = None + unsub: Callable[[], None] | None = None + unsub2: Callable[[], None] | None = None - hot: Optional[Callable[[], None]] = None + hot: Callable[[], None] | None = None @rpc - def sub1(self): + def sub1(self) -> None: self.unsub = self.odom.observable().subscribe(self.sub1_msgs.append) @rpc - def sub2(self): + def sub2(self) -> None: self.unsub2 = self.odom.observable().subscribe(self.sub2_msgs.append) @rpc - def stop(self): + def stop(self) -> None: if self.unsub: self.unsub.dispose() self.unsub = None @@ -110,11 +110,11 @@ def get_next(self): return self.odom.get_next() @rpc - def start_hot_getter(self): + def start_hot_getter(self) -> None: self.hot = self.odom.hot_latest() @rpc - def stop_hot_getter(self): + def stop_hot_getter(self) -> None: self.hot.dispose() @rpc @@ -128,7 +128,7 @@ class SpyLCMTransport(LCMTransport): def __reduce__(self): return (SpyLCMTransport, (self.topic.topic, self.topic.lcm_type)) - def __init__(self, topic: str, type: type, **kwargs): + def __init__(self, topic: str, type: type, **kwargs) -> None: super().__init__(topic, type, **kwargs) self._subscriber_map = {} # Maps unsubscribe functions to track active subs @@ -139,7 +139,7 @@ def subscribe(self, selfstream: In, callback: Callable) -> Callable[[], None]: # Increment counter self.active_subscribers += 1 - def wrapped_unsubscribe(): + def wrapped_unsubscribe() -> None: # Create wrapper that decrements counter when called if wrapped_unsubscribe in self._subscriber_map: self.active_subscribers -= 1 @@ -154,7 +154,7 @@ def wrapped_unsubscribe(): @pytest.mark.parametrize("subscriber_class", [ClassicSubscriber, RXPYSubscriber]) @pytest.mark.module -def test_subscription(dimos, subscriber_class): +def test_subscription(dimos, subscriber_class) -> None: robot = dimos.deploy(MockRobotClient) robot.lidar.transport = SpyLCMTransport("/lidar", LidarMessage) @@ -192,7 +192,7 @@ def test_subscription(dimos, subscriber_class): @pytest.mark.module -def test_get_next(dimos): +def test_get_next(dimos) -> None: robot = dimos.deploy(MockRobotClient) robot.lidar.transport = SpyLCMTransport("/lidar", LidarMessage) @@ -221,7 +221,7 @@ def test_get_next(dimos): @pytest.mark.module -def test_hot_getter(dimos): +def test_hot_getter(dimos) -> None: robot = dimos.deploy(MockRobotClient) robot.lidar.transport = SpyLCMTransport("/lidar", LidarMessage) diff --git a/dimos/core/testing.py b/dimos/core/testing.py index e17b25f41e..92f6d6b497 100644 --- a/dimos/core/testing.py +++ b/dimos/core/testing.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time from threading import Event, Thread +import time import pytest -from dimos.core import In, Module, Out, start, rpc +from dimos.core import In, Module, Out, rpc, start from dimos.msgs.geometry_msgs import Vector3 from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.odometry import Odometry @@ -39,16 +39,16 @@ class MockRobotClient(Module): mov_msg_count = 0 - def mov_callback(self, msg): + def mov_callback(self, msg) -> None: self.mov_msg_count += 1 - def __init__(self): + def __init__(self) -> None: super().__init__() self._stop_event = Event() self._thread = None @rpc - def start(self): + def start(self) -> None: super().start() self._thread = Thread(target=self.odomloop) @@ -63,7 +63,7 @@ def stop(self) -> None: super().stop() - def odomloop(self): + def odomloop(self) -> None: odomdata = SensorReplay("raw_odometry_rotate_walk", autocast=Odometry.from_msg) lidardata = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) diff --git a/dimos/core/transport.py b/dimos/core/transport.py index 77f471bafe..1e770515b8 100644 --- a/dimos/core/transport.py +++ b/dimos/core/transport.py @@ -15,32 +15,23 @@ from __future__ import annotations import traceback -from typing import Any, Callable, Generic, List, Optional, Protocol, TypeVar +from typing import TypeVar import dimos.core.colors as colors T = TypeVar("T") -import traceback from typing import ( - Any, - Callable, - Dict, - Generic, - List, - Optional, - Protocol, + TYPE_CHECKING, TypeVar, - get_args, - get_origin, - get_type_hints, ) -import dimos.core.colors as colors from dimos.core.stream import In, RemoteIn, Transport -from dimos.protocol.pubsub.lcmpubsub import LCM, PickleLCM -from dimos.protocol.pubsub.lcmpubsub import Topic as LCMTopic -from dimos.protocol.pubsub.shmpubsub import SharedMemory, PickleSharedMemory +from dimos.protocol.pubsub.lcmpubsub import LCM, PickleLCM, Topic as LCMTopic +from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory, SharedMemory + +if TYPE_CHECKING: + from collections.abc import Callable T = TypeVar("T") @@ -48,7 +39,7 @@ class PubSubTransport(Transport[T]): topic: any - def __init__(self, topic: any): + def __init__(self, topic: any) -> None: self.topic = topic def __str__(self) -> str: @@ -62,14 +53,14 @@ def __str__(self) -> str: class pLCMTransport(PubSubTransport[T]): _started: bool = False - def __init__(self, topic: str, **kwargs): + def __init__(self, topic: str, **kwargs) -> None: super().__init__(topic) self.lcm = PickleLCM(**kwargs) def __reduce__(self): return (pLCMTransport, (self.topic,)) - def broadcast(self, _, msg): + def broadcast(self, _, msg) -> None: if not self._started: self.lcm.start() self._started = True @@ -86,14 +77,14 @@ def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> class LCMTransport(PubSubTransport[T]): _started: bool = False - def __init__(self, topic: str, type: type, **kwargs): + def __init__(self, topic: str, type: type, **kwargs) -> None: super().__init__(LCMTopic(topic, type)) self.lcm = LCM(**kwargs) def __reduce__(self): return (LCMTransport, (self.topic.topic, self.topic.lcm_type)) - def broadcast(self, _, msg): + def broadcast(self, _, msg) -> None: if not self._started: self.lcm.start() self._started = True @@ -110,14 +101,14 @@ def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> class pSHMTransport(PubSubTransport[T]): _started: bool = False - def __init__(self, topic: str, **kwargs): + def __init__(self, topic: str, **kwargs) -> None: super().__init__(topic) self.shm = PickleSharedMemory(**kwargs) def __reduce__(self): return (pSHMTransport, (self.topic,)) - def broadcast(self, _, msg): + def broadcast(self, _, msg) -> None: if not self._started: self.shm.start() self._started = True @@ -134,14 +125,14 @@ def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> class SHMTransport(PubSubTransport[T]): _started: bool = False - def __init__(self, topic: str, **kwargs): + def __init__(self, topic: str, **kwargs) -> None: super().__init__(topic) self.shm = SharedMemory(**kwargs) def __reduce__(self): return (SHMTransport, (self.topic,)) - def broadcast(self, _, msg): + def broadcast(self, _, msg) -> None: if not self._started: self.shm.start() self._started = True @@ -156,10 +147,10 @@ def subscribe(self, callback: Callable[[T], None], selfstream: In[T] = None) -> class DaskTransport(Transport[T]): - subscribers: List[Callable[[T], None]] + subscribers: list[Callable[[T], None]] _started: bool = False - def __init__(self): + def __init__(self) -> None: self.subscribers = [] def __str__(self) -> str: diff --git a/dimos/environment/agent_environment.py b/dimos/environment/agent_environment.py index 861a1f429b..a5dab0e272 100644 --- a/dimos/environment/agent_environment.py +++ b/dimos/environment/agent_environment.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path + import cv2 import numpy as np -from pathlib import Path -from typing import List, Union + from .environment import Environment class AgentEnvironment(Environment): - def __init__(self): + def __init__(self) -> None: super().__init__() self.environment_type = "agent" self.frames = [] @@ -29,7 +30,7 @@ def __init__(self): self._segmentations = [] self._point_clouds = [] - def initialize_from_images(self, images: Union[List[str], List[np.ndarray]]) -> bool: + def initialize_from_images(self, images: list[str] | list[np.ndarray]) -> bool: """Initialize environment from a list of image paths or numpy arrays. Args: @@ -88,37 +89,42 @@ def initialize_from_directory(self, directory_path: str) -> bool: # TODO: Implement directory initialization raise NotImplementedError("Directory initialization not yet implemented") - def label_objects(self) -> List[str]: + def label_objects(self) -> list[str]: """Implementation of abstract method to label objects.""" # TODO: Implement object labeling using a detection model raise NotImplementedError("Object labeling not yet implemented") def generate_segmentations( - self, model: str = None, objects: List[str] = None, *args, **kwargs - ) -> List[np.ndarray]: + self, model: str | None = None, objects: list[str] | None = None, *args, **kwargs + ) -> list[np.ndarray]: """Generate segmentations for the current frame.""" # TODO: Implement segmentation generation using specified model raise NotImplementedError("Segmentation generation not yet implemented") - def get_segmentations(self) -> List[np.ndarray]: + def get_segmentations(self) -> list[np.ndarray]: """Return pre-computed segmentations for the current frame.""" if self._segmentations: return self._segmentations[self.current_frame_idx] return [] - def generate_point_cloud(self, object: str = None, *args, **kwargs) -> np.ndarray: + def generate_point_cloud(self, object: str | None = None, *args, **kwargs) -> np.ndarray: """Generate point cloud from the current frame.""" # TODO: Implement point cloud generation raise NotImplementedError("Point cloud generation not yet implemented") - def get_point_cloud(self, object: str = None) -> np.ndarray: + def get_point_cloud(self, object: str | None = None) -> np.ndarray: """Return pre-computed point cloud.""" if self._point_clouds: return self._point_clouds[self.current_frame_idx] return np.array([]) def generate_depth_map( - self, stereo: bool = None, monocular: bool = None, model: str = None, *args, **kwargs + self, + stereo: bool | None = None, + monocular: bool | None = None, + model: str | None = None, + *args, + **kwargs, ) -> np.ndarray: """Generate depth map for the current frame.""" # TODO: Implement depth map generation using specified method diff --git a/dimos/environment/colmap_environment.py b/dimos/environment/colmap_environment.py index 9981e50098..f1b0986c77 100644 --- a/dimos/environment/colmap_environment.py +++ b/dimos/environment/colmap_environment.py @@ -14,9 +14,11 @@ # UNDER DEVELOPMENT 🚧🚧🚧 +from pathlib import Path + import cv2 import pycolmap -from pathlib import Path + from dimos.environment.environment import Environment @@ -58,7 +60,7 @@ def initialize_from_video(self, video_path, frame_output_dir): # Initialize from the extracted frames return self.initialize_from_images(frame_output_dir) - def _extract_frames_from_video(self, video_path, frame_output_dir): + def _extract_frames_from_video(self, video_path, frame_output_dir) -> None: """Extract frames from a video and save them to a directory.""" cap = cv2.VideoCapture(str(video_path)) frame_count = 0 @@ -73,17 +75,17 @@ def _extract_frames_from_video(self, video_path, frame_output_dir): cap.release() - def label_objects(self): + def label_objects(self) -> None: pass - def get_visualization(self, format_type): + def get_visualization(self, format_type) -> None: pass - def get_segmentations(self): + def get_segmentations(self) -> None: pass - def get_point_cloud(self, object_id=None): + def get_point_cloud(self, object_id=None) -> None: pass - def get_depth_map(self): + def get_depth_map(self) -> None: pass diff --git a/dimos/environment/environment.py b/dimos/environment/environment.py index 0770b0f2ce..8b0068cbae 100644 --- a/dimos/environment/environment.py +++ b/dimos/environment/environment.py @@ -13,11 +13,12 @@ # limitations under the License. from abc import ABC, abstractmethod + import numpy as np class Environment(ABC): - def __init__(self): + def __init__(self) -> None: self.environment_type = None self.graph = None @@ -38,7 +39,7 @@ def get_visualization(self, format_type): @abstractmethod def generate_segmentations( - self, model: str = None, objects: list[str] = None, *args, **kwargs + self, model: str | None = None, objects: list[str] | None = None, *args, **kwargs ) -> list[np.ndarray]: """ Generate object segmentations of objects[] using neural methods. @@ -70,7 +71,7 @@ def get_segmentations(self) -> list[np.ndarray]: pass @abstractmethod - def generate_point_cloud(self, object: str = None, *args, **kwargs) -> np.ndarray: + def generate_point_cloud(self, object: str | None = None, *args, **kwargs) -> np.ndarray: """ Generate a point cloud for the entire environment or a specific object. @@ -90,7 +91,7 @@ def generate_point_cloud(self, object: str = None, *args, **kwargs) -> np.ndarra pass @abstractmethod - def get_point_cloud(self, object: str = None) -> np.ndarray: + def get_point_cloud(self, object: str | None = None) -> np.ndarray: """ Return point clouds of the entire environment or a specific object. @@ -105,7 +106,12 @@ def get_point_cloud(self, object: str = None) -> np.ndarray: @abstractmethod def generate_depth_map( - self, stereo: bool = None, monocular: bool = None, model: str = None, *args, **kwargs + self, + stereo: bool | None = None, + monocular: bool | None = None, + model: str | None = None, + *args, + **kwargs, ) -> np.ndarray: """ Generate a depth map using monocular or stereo camera methods. diff --git a/dimos/exceptions/agent_memory_exceptions.py b/dimos/exceptions/agent_memory_exceptions.py index cbf3460754..073e56c643 100644 --- a/dimos/exceptions/agent_memory_exceptions.py +++ b/dimos/exceptions/agent_memory_exceptions.py @@ -24,7 +24,7 @@ class AgentMemoryError(Exception): message (str): Human-readable message describing the error. """ - def __init__(self, message="Error in AgentMemory operation"): + def __init__(self, message: str = "Error in AgentMemory operation") -> None: super().__init__(message) @@ -38,14 +38,14 @@ class AgentMemoryConnectionError(AgentMemoryError): cause (Exception, optional): Original exception, if any, that led to this error. """ - def __init__(self, message="Failed to connect to the database", cause=None): + def __init__(self, message: str = "Failed to connect to the database", cause=None) -> None: super().__init__(message) if cause: self.cause = cause self.traceback = traceback.format_exc() if cause else None - def __str__(self): - return f"{self.message}\nCaused by: {repr(self.cause)}" if self.cause else self.message + def __str__(self) -> str: + return f"{self.message}\nCaused by: {self.cause!r}" if self.cause else self.message class UnknownConnectionTypeError(AgentMemoryConnectionError): @@ -56,7 +56,9 @@ class UnknownConnectionTypeError(AgentMemoryConnectionError): message (str): Human-readable message explaining that an unknown connection type was used. """ - def __init__(self, message="Unknown connection type used in AgentMemory connection"): + def __init__( + self, message: str = "Unknown connection type used in AgentMemory connection" + ) -> None: super().__init__(message) @@ -69,7 +71,9 @@ class DataRetrievalError(AgentMemoryError): message (str): Human-readable message describing the data retrieval error. """ - def __init__(self, message="Error in retrieving data during AgentMemory operation"): + def __init__( + self, message: str = "Error in retrieving data during AgentMemory operation" + ) -> None: super().__init__(message) @@ -83,7 +87,7 @@ class DataNotFoundError(DataRetrievalError): message (str, optional): Human-readable message providing more detail. If not provided, a default message is generated. """ - def __init__(self, vector_id, message=None): + def __init__(self, vector_id, message=None) -> None: message = message or f"Requested data for vector ID {vector_id} was not found." super().__init__(message) self.vector_id = vector_id diff --git a/dimos/hardware/camera/module.py b/dimos/hardware/camera/module.py index 0dda51804d..0ac241959e 100644 --- a/dimos/hardware/camera/module.py +++ b/dimos/hardware/camera/module.py @@ -12,13 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable +from dataclasses import dataclass, field import queue import time -from dataclasses import dataclass, field -from typing import Callable, Optional -import reactivex as rx from dimos_lcm.sensor_msgs import CameraInfo +import reactivex as rx from reactivex import operators as ops from reactivex.disposable import Disposable from reactivex.observable import Observable @@ -32,18 +32,20 @@ from dimos.msgs.sensor_msgs import Image from dimos.msgs.sensor_msgs.Image import Image, sharpness_barrier -default_transform = lambda: Transform( - translation=Vector3(0.0, 0.0, 0.0), - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="base_link", - child_frame_id="camera_link", -) + +def default_transform(): + return Transform( + translation=Vector3(0.0, 0.0, 0.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="base_link", + child_frame_id="camera_link", + ) @dataclass class CameraModuleConfig(ModuleConfig): frame_id: str = "camera_link" - transform: Optional[Transform] = field(default_factory=default_transform) + transform: Transform | None = field(default_factory=default_transform) hardware: Callable[[], CameraHardware] | CameraHardware = Webcam frequency: float = 5.0 @@ -53,9 +55,9 @@ class CameraModule(Module, spec.Camera): camera_info_stream: Out[CameraInfo] = None hardware: Callable[[], CameraHardware] | CameraHardware = None - _module_subscription: Optional[Disposable] = None - _camera_info_subscription: Optional[Disposable] = None - _skill_stream: Optional[Observable[Image]] = None + _module_subscription: Disposable | None = None + _camera_info_subscription: Disposable | None = None + _skill_stream: Observable[Image] | None = None default_config = CameraModuleConfig @@ -64,7 +66,7 @@ def camera_info(self) -> CameraInfo: return self.hardware.camera_info @rpc - def start(self): + def start(self) -> str: if callable(self.config.hardware): self.hardware = self.config.hardware() else: @@ -75,7 +77,7 @@ def start(self): stream = self.hardware.image_stream().pipe(sharpness_barrier(self.config.frequency)) - def publish_info(camera_info: CameraInfo): + def publish_info(camera_info: CameraInfo) -> None: self.camera_info.publish(camera_info) if self.config.transform is None: @@ -102,8 +104,7 @@ def video_stream(self) -> Image: _queue = queue.Queue(maxsize=1) self.hardware.image_stream().subscribe(_queue.put) - for image in iter(_queue.get, None): - yield image + yield from iter(_queue.get, None) def camera_info_stream(self, frequency: float = 1.0) -> Observable[CameraInfo]: def camera_info(_) -> CameraInfo: @@ -112,7 +113,7 @@ def camera_info(_) -> CameraInfo: return rx.interval(1.0 / frequency).pipe(ops.map(camera_info)) - def stop(self): + def stop(self) -> None: if self._module_subscription: self._module_subscription.dispose() self._module_subscription = None diff --git a/dimos/hardware/camera/spec.py b/dimos/hardware/camera/spec.py index cc69db5d1c..b9722d6cd2 100644 --- a/dimos/hardware/camera/spec.py +++ b/dimos/hardware/camera/spec.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import ABC, abstractmethod, abstractproperty -from typing import Generic, Optional, Protocol, TypeVar +from typing import Generic, Protocol, TypeVar from dimos_lcm.sensor_msgs import CameraInfo from reactivex.observable import Observable @@ -23,7 +23,7 @@ class CameraConfig(Protocol): - frame_id_prefix: Optional[str] + frame_id_prefix: str | None CameraConfigT = TypeVar("CameraConfigT", bound=CameraConfig) diff --git a/dimos/hardware/camera/test_webcam.py b/dimos/hardware/camera/test_webcam.py index 0f6a509084..e2f99e85dd 100644 --- a/dimos/hardware/camera/test_webcam.py +++ b/dimos/hardware/camera/test_webcam.py @@ -25,7 +25,7 @@ @pytest.mark.tool -def test_streaming_single(): +def test_streaming_single() -> None: dimos = core.start(1) camera = dimos.deploy( @@ -57,7 +57,7 @@ def test_streaming_single(): @pytest.mark.tool -def test_streaming_double(): +def test_streaming_double() -> None: dimos = core.start(2) camera1 = dimos.deploy( diff --git a/dimos/hardware/camera/webcam.py b/dimos/hardware/camera/webcam.py index 7f9c9940a7..0f68989002 100644 --- a/dimos/hardware/camera/webcam.py +++ b/dimos/hardware/camera/webcam.py @@ -12,20 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import threading -import time from dataclasses import dataclass, field from functools import cache -from typing import Literal, Optional +import threading +import time +from typing import Literal import cv2 from dimos_lcm.sensor_msgs import CameraInfo from reactivex import create from reactivex.observable import Observable +from dimos.hardware.camera.spec import CameraConfig, CameraHardware from dimos.msgs.sensor_msgs import Image from dimos.msgs.sensor_msgs.Image import ImageFormat -from dimos.hardware.camera.spec import CameraConfig, CameraHardware from dimos.utils.reactive import backpressure @@ -36,14 +36,14 @@ class WebcamConfig(CameraConfig): frame_height: int = 480 frequency: int = 15 camera_info: CameraInfo = field(default_factory=CameraInfo) - frame_id_prefix: Optional[str] = None - stereo_slice: Optional[Literal["left", "right"]] = None # For stereo cameras + frame_id_prefix: str | None = None + stereo_slice: Literal["left", "right"] | None = None # For stereo cameras class Webcam(CameraHardware[WebcamConfig]): default_config = WebcamConfig - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._capture = None self._capture_thread = None @@ -66,7 +66,7 @@ def subscribe(observer, scheduler=None): return # Return a dispose function to stop camera when unsubscribed - def dispose(): + def dispose() -> None: self._observer = None self.stop() @@ -92,7 +92,7 @@ def start(self): self._capture_thread = threading.Thread(target=self._capture_loop, daemon=True) self._capture_thread.start() - def stop(self): + def stop(self) -> None: """Stop capturing frames""" # Signal thread to stop self._stop_event.set() @@ -140,7 +140,7 @@ def capture_frame(self) -> Image: return image - def _capture_loop(self): + def _capture_loop(self) -> None: """Capture frames at the configured frequency""" frame_interval = 1.0 / self.config.frequency next_frame_time = time.time() @@ -167,4 +167,4 @@ def _capture_loop(self): def camera_info(self) -> CameraInfo: return self.config.camera_info - def emit(self, image: Image): ... + def emit(self, image: Image) -> None: ... diff --git a/dimos/hardware/camera/zed/__init__.py b/dimos/hardware/camera/zed/__init__.py index 3c39045606..d7b70a1319 100644 --- a/dimos/hardware/camera/zed/__init__.py +++ b/dimos/hardware/camera/zed/__init__.py @@ -15,6 +15,7 @@ """ZED camera hardware interfaces.""" from pathlib import Path + from dimos.msgs.sensor_msgs.CameraInfo import CalibrationProvider # Check if ZED SDK is available @@ -31,13 +32,13 @@ else: # Provide stub classes when SDK is not available class ZEDCamera: - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: raise ImportError( "ZED SDK not installed. Please install pyzed package to use ZED camera functionality." ) class ZEDModule: - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: raise ImportError( "ZED SDK not installed. Please install pyzed package to use ZED camera functionality." ) @@ -48,8 +49,8 @@ def __init__(self, *args, **kwargs): CameraInfo = CalibrationProvider(CALIBRATION_DIR) __all__ = [ - "ZEDCamera", - "ZEDModule", "HAS_ZED_SDK", "CameraInfo", + "ZEDCamera", + "ZEDModule", ] diff --git a/dimos/hardware/camera/zed/camera.py b/dimos/hardware/camera/zed/camera.py index e9f029c845..fdcd93f731 100644 --- a/dimos/hardware/camera/zed/camera.py +++ b/dimos/hardware/camera/zed/camera.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional, Tuple +from types import TracebackType +from typing import Any import cv2 +from dimos_lcm.sensor_msgs import CameraInfo import numpy as np import open3d as o3d import pyzed.sl as sl -from dimos_lcm.sensor_msgs import CameraInfo from reactivex import interval from dimos.core import Module, Out, rpc @@ -43,7 +44,7 @@ def __init__( depth_mode: sl.DEPTH_MODE = sl.DEPTH_MODE.NEURAL, fps: int = 30, **kwargs, - ): + ) -> None: """ Initialize ZED Camera. @@ -126,7 +127,7 @@ def enable_positional_tracking( enable_pose_smoothing: bool = True, enable_imu_fusion: bool = True, set_floor_as_origin: bool = False, - initial_world_transform: Optional[sl.Transform] = None, + initial_world_transform: sl.Transform | None = None, ) -> bool: """ Enable positional tracking on the ZED camera. @@ -169,7 +170,7 @@ def enable_positional_tracking( logger.error(f"Error enabling positional tracking: {e}") return False - def disable_positional_tracking(self): + def disable_positional_tracking(self) -> None: """Disable positional tracking.""" if self.tracking_enabled: self.zed.disable_positional_tracking() @@ -178,7 +179,7 @@ def disable_positional_tracking(self): def get_pose( self, reference_frame: sl.REFERENCE_FRAME = sl.REFERENCE_FRAME.WORLD - ) -> Optional[Dict[str, Any]]: + ) -> dict[str, Any] | None: """ Get the current camera pose. @@ -229,7 +230,7 @@ def get_pose( logger.error(f"Error getting pose: {e}") return None - def get_imu_data(self) -> Optional[Dict[str, Any]]: + def get_imu_data(self) -> dict[str, Any] | None: """ Get IMU sensor data if available. @@ -277,7 +278,7 @@ def get_imu_data(self) -> Optional[Dict[str, Any]]: def capture_frame( self, - ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: + ) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None]: """ Capture a frame from ZED camera. @@ -312,7 +313,7 @@ def capture_frame( logger.error(f"Error capturing frame: {e}") return None, None, None - def capture_pointcloud(self) -> Optional[o3d.geometry.PointCloud]: + def capture_pointcloud(self) -> o3d.geometry.PointCloud | None: """ Capture point cloud from ZED camera. @@ -330,7 +331,7 @@ def capture_pointcloud(self) -> Optional[o3d.geometry.PointCloud]: point_cloud_data = self.point_cloud.get_data() # Convert to numpy array format - height, width = point_cloud_data.shape[:2] + _height, _width = point_cloud_data.shape[:2] points = point_cloud_data.reshape(-1, 4) # Extract XYZ coordinates @@ -372,9 +373,7 @@ def capture_pointcloud(self) -> Optional[o3d.geometry.PointCloud]: def capture_frame_with_pose( self, - ) -> Tuple[ - Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray], Optional[Dict[str, Any]] - ]: + ) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None, dict[str, Any] | None]: """ Capture a frame with synchronized pose data. @@ -405,7 +404,7 @@ def capture_frame_with_pose( logger.error(f"Error capturing frame with pose: {e}") return None, None, None, None - def close(self): + def close(self) -> None: """Close the ZED camera.""" if self.is_opened: # Disable tracking if enabled @@ -416,7 +415,7 @@ def close(self): self.is_opened = False logger.info("ZED camera closed") - def get_camera_info(self) -> Dict[str, Any]: + def get_camera_info(self) -> dict[str, Any]: """Get ZED camera information and calibration parameters.""" if not self.is_opened: return {} @@ -434,8 +433,6 @@ def get_camera_info(self) -> Dict[str, Any]: else: # Method 2: Calculate from left and right camera positions # The baseline is the distance between left and right cameras - left_cam = calibration.left_cam - right_cam = calibration.right_cam # Try different ways to get baseline in SDK 4.0+ if hasattr(info.camera_configuration, "calibration_parameters_raw"): @@ -513,7 +510,12 @@ def __enter__(self): raise RuntimeError("Failed to open ZED camera") return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: """Context manager exit.""" self.close() @@ -546,9 +548,9 @@ def __init__( set_floor_as_origin: bool = True, publish_rate: float = 30.0, frame_id: str = "zed_camera", - recording_path: str = None, + recording_path: str | None = None, **kwargs, - ): + ) -> None: """ Initialize ZED Module. @@ -604,7 +606,7 @@ def __init__( logger.info(f"ZEDModule initialized for camera {camera_id}") @rpc - def start(self): + def start(self) -> None: """Start the ZED module and begin publishing data.""" if self._running: logger.warning("ZED module already running") @@ -656,7 +658,7 @@ def start(self): self._running = False @rpc - def stop(self): + def stop(self) -> None: """Stop the ZED module.""" if not self._running: return @@ -675,7 +677,7 @@ def stop(self): super().stop() - def _capture_and_publish(self): + def _capture_and_publish(self) -> None: """Capture frame and publish all data.""" if not self._running or not self.zed_camera: return @@ -719,7 +721,7 @@ def _capture_and_publish(self): except Exception as e: logger.error(f"Error in capture and publish: {e}") - def _publish_color_image(self, image: np.ndarray, header: Header): + def _publish_color_image(self, image: np.ndarray, header: Header) -> None: """Publish color image as LCM message.""" try: # Convert BGR to RGB if needed @@ -741,7 +743,7 @@ def _publish_color_image(self, image: np.ndarray, header: Header): except Exception as e: logger.error(f"Error publishing color image: {e}") - def _publish_depth_image(self, depth: np.ndarray, header: Header): + def _publish_depth_image(self, depth: np.ndarray, header: Header) -> None: """Publish depth image as LCM message.""" try: # Depth is float32 in meters @@ -756,7 +758,7 @@ def _publish_depth_image(self, depth: np.ndarray, header: Header): except Exception as e: logger.error(f"Error publishing depth image: {e}") - def _publish_camera_info(self): + def _publish_camera_info(self) -> None: """Publish camera calibration information.""" try: info = self.zed_camera.get_camera_info() @@ -834,7 +836,7 @@ def _publish_camera_info(self): except Exception as e: logger.error(f"Error publishing camera info: {e}") - def _publish_pose(self, pose_data: Dict[str, Any], header: Header): + def _publish_pose(self, pose_data: dict[str, Any], header: Header) -> None: """Publish camera pose as PoseStamped message and TF transform.""" try: position = pose_data.get("position", [0, 0, 0]) @@ -858,14 +860,14 @@ def _publish_pose(self, pose_data: Dict[str, Any], header: Header): logger.error(f"Error publishing pose: {e}") @rpc - def get_camera_info(self) -> Dict[str, Any]: + def get_camera_info(self) -> dict[str, Any]: """Get camera information and calibration parameters.""" if self.zed_camera: return self.zed_camera.get_camera_info() return {} @rpc - def get_pose(self) -> Optional[Dict[str, Any]]: + def get_pose(self) -> dict[str, Any] | None: """Get current camera pose if tracking is enabled.""" if self.zed_camera and self.enable_tracking: return self.zed_camera.get_pose() diff --git a/dimos/hardware/camera/zed/test_zed.py b/dimos/hardware/camera/zed/test_zed.py index ce1bef0b54..33810d3c2a 100644 --- a/dimos/hardware/camera/zed/test_zed.py +++ b/dimos/hardware/camera/zed/test_zed.py @@ -16,7 +16,7 @@ from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo -def test_zed_import_and_calibration_access(): +def test_zed_import_and_calibration_access() -> None: """Test that zed module can be imported and calibrations accessed.""" # Import zed module from camera from dimos.hardware.camera import zed diff --git a/dimos/hardware/end_effector.py b/dimos/hardware/end_effector.py index 373408003d..1c5eb08281 100644 --- a/dimos/hardware/end_effector.py +++ b/dimos/hardware/end_effector.py @@ -14,7 +14,7 @@ class EndEffector: - def __init__(self, effector_type=None): + def __init__(self, effector_type=None) -> None: self.effector_type = effector_type def get_effector_type(self): diff --git a/dimos/hardware/fake_zed_module.py b/dimos/hardware/fake_zed_module.py index b0a246ef12..c4c46c33b3 100644 --- a/dimos/hardware/fake_zed_module.py +++ b/dimos/hardware/fake_zed_module.py @@ -19,16 +19,17 @@ import functools import logging + +from dimos_lcm.sensor_msgs import CameraInfo import numpy as np from dimos.core import Module, Out, rpc from dimos.msgs.geometry_msgs import PoseStamped from dimos.msgs.sensor_msgs import Image, ImageFormat -from dimos_lcm.sensor_msgs import CameraInfo from dimos.msgs.std_msgs import Header -from dimos.utils.testing import TimedSensorReplay -from dimos.utils.logging_config import setup_logger from dimos.protocol.tf import TF +from dimos.utils.logging_config import setup_logger +from dimos.utils.testing import TimedSensorReplay logger = setup_logger(__name__, level=logging.INFO) @@ -44,7 +45,7 @@ class FakeZEDModule(Module): camera_info: Out[CameraInfo] = None pose: Out[PoseStamped] = None - def __init__(self, recording_path: str, frame_id: str = "zed_camera", **kwargs): + def __init__(self, recording_path: str, frame_id: str = "zed_camera", **kwargs) -> None: """ Initialize FakeZEDModule with recording path. @@ -197,7 +198,7 @@ def camera_info_autocast(x): return info_replay.stream() @rpc - def start(self): + def start(self) -> None: """Start replaying recorded data.""" super().start() @@ -261,15 +262,16 @@ def stop(self) -> None: super().stop() - def _publish_pose(self, msg): + def _publish_pose(self, msg) -> None: """Publish pose and TF transform.""" if msg: self.pose.publish(msg) # Publish TF transform from world to camera - from dimos.msgs.geometry_msgs import Transform, Vector3, Quaternion import time + from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 + transform = Transform( translation=Vector3(*msg.position), rotation=Quaternion(*msg.orientation), diff --git a/dimos/hardware/gstreamer_camera.py b/dimos/hardware/gstreamer_camera.py index 32c2e8304b..38ede23ee1 100644 --- a/dimos/hardware/gstreamer_camera.py +++ b/dimos/hardware/gstreamer_camera.py @@ -33,7 +33,7 @@ gi.require_version("Gst", "1.0") gi.require_version("GstApp", "1.0") -from gi.repository import Gst, GLib +from gi.repository import GLib, Gst logger = setup_logger("dimos.hardware.gstreamer_camera", level=logging.INFO) @@ -54,7 +54,7 @@ def __init__( reconnect_interval: float = 5.0, *args, **kwargs, - ): + ) -> None: """Initialize the GStreamer TCP camera module. Args: @@ -83,7 +83,7 @@ def __init__( Module.__init__(self, *args, **kwargs) @rpc - def start(self): + def start(self) -> None: if self.running: logger.warning("GStreamer camera module is already running") return @@ -128,12 +128,12 @@ def _connect(self) -> None: logger.error(f"Failed to connect to {self.host}:{self.port}: {e}") self._schedule_reconnect() - def _cleanup_reconnect_timer(self): + def _cleanup_reconnect_timer(self) -> None: if self.reconnect_timer_id: GLib.source_remove(self.reconnect_timer_id) self.reconnect_timer_id = None - def _schedule_reconnect(self): + def _schedule_reconnect(self) -> None: if not self.should_reconnect: return @@ -143,14 +143,14 @@ def _schedule_reconnect(self): int(self.reconnect_interval), self._reconnect_timeout ) - def _reconnect_timeout(self): + def _reconnect_timeout(self) -> bool: self.reconnect_timer_id = None if self.should_reconnect: logger.info("Attempting to reconnect...") self._connect() return False # Don't repeat the timeout - def _handle_disconnect(self): + def _handle_disconnect(self) -> None: if not self.should_reconnect: return @@ -205,13 +205,13 @@ def _start_pipeline(self): bus.add_signal_watch() bus.connect("message", self._on_bus_message) - def _run_main_loop(self): + def _run_main_loop(self) -> None: try: self.main_loop.run() except Exception as e: logger.error(f"Main loop error: {e}") - def _on_bus_message(self, bus, message): + def _on_bus_message(self, bus, message) -> None: t = message.type if t == Gst.MessageType.EOS: @@ -226,7 +226,7 @@ def _on_bus_message(self, bus, message): logger.warning(f"GStreamer warning: {warn}, {debug}") elif t == Gst.MessageType.STATE_CHANGED: if message.src == self.pipeline: - old_state, new_state, pending_state = message.parse_state_changed() + _old_state, new_state, _pending_state = message.parse_state_changed() if new_state == Gst.State.PLAYING: logger.info("Pipeline is now playing - connected to TCP server") diff --git a/dimos/hardware/gstreamer_camera_test_script.py b/dimos/hardware/gstreamer_camera_test_script.py index fd0e154904..f815579c0d 100755 --- a/dimos/hardware/gstreamer_camera_test_script.py +++ b/dimos/hardware/gstreamer_camera_test_script.py @@ -18,16 +18,16 @@ import logging import time -from dimos.hardware.gstreamer_camera import GstreamerCameraModule from dimos import core -from dimos.protocol import pubsub +from dimos.hardware.gstreamer_camera import GstreamerCameraModule from dimos.msgs.sensor_msgs import Image +from dimos.protocol import pubsub logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) -def main(): +def main() -> None: parser = argparse.ArgumentParser(description="Test script for GStreamer TCP camera module") # Network options @@ -82,7 +82,7 @@ def main(): last_log_time = [time.time()] first_timestamp = [None] - def on_frame(msg): + def on_frame(msg) -> None: frame_count[0] += 1 current_time = time.time() diff --git a/dimos/hardware/gstreamer_sender.py b/dimos/hardware/gstreamer_sender.py index 5b526609e1..ce7c1d6145 100755 --- a/dimos/hardware/gstreamer_sender.py +++ b/dimos/hardware/gstreamer_sender.py @@ -52,7 +52,7 @@ def __init__( host: str = "0.0.0.0", port: int = 5000, single_camera: bool = False, - ): + ) -> None: """Initialize the GStreamer TCP sender. Args: @@ -200,7 +200,7 @@ def _inject_absolute_timestamp(self, pad, info, user_data): self.frame_count += 1 return Gst.PadProbeReturn.OK - def _on_bus_message(self, bus, message): + def _on_bus_message(self, bus, message) -> None: t = message.type if t == Gst.MessageType.EOS: @@ -215,7 +215,7 @@ def _on_bus_message(self, bus, message): logger.warning(f"Pipeline warning: {warn}, {debug}") elif t == Gst.MessageType.STATE_CHANGED: if message.src == self.pipeline: - old_state, new_state, pending_state = message.parse_state_changed() + old_state, new_state, _pending_state = message.parse_state_changed() logger.debug( f"Pipeline state changed: {old_state.value_nick} -> {new_state.value_nick}" ) @@ -261,7 +261,7 @@ def start(self): finally: self.stop() - def stop(self): + def stop(self) -> None: if not self.running: return @@ -282,7 +282,7 @@ def stop(self): logger.info("TCP video sender stopped") -def main(): +def main() -> None: parser = argparse.ArgumentParser( description="GStreamer TCP video sender with absolute timestamps" ) @@ -340,7 +340,7 @@ def main(): ) # Handle signals gracefully - def signal_handler(sig, frame): + def signal_handler(sig, frame) -> None: logger.info(f"Received signal {sig}, shutting down...") sender.stop() sys.exit(0) diff --git a/dimos/hardware/piper_arm.py b/dimos/hardware/piper_arm.py index 71ce4bf04f..d27d1df394 100644 --- a/dimos/hardware/piper_arm.py +++ b/dimos/hardware/piper_arm.py @@ -14,34 +14,32 @@ # dimos/hardware/piper_arm.py -from reactivex.disposable import Disposable -from typing import Tuple -from piper_sdk import * # from the official Piper SDK -import numpy as np -import time -import kinpy as kp +import select import sys import termios -import tty -import select -from scipy.spatial.transform import Rotation as R -from dimos.utils.transform_utils import euler_to_quaternion, quaternion_to_euler -from dimos.utils.logging_config import setup_logger - import threading +import time +import tty +from dimos_lcm.geometry_msgs import Pose, Twist, Vector3 +import kinpy as kp +import numpy as np +from piper_sdk import * # from the official Piper SDK import pytest +from reactivex.disposable import Disposable +from scipy.spatial.transform import Rotation as R import dimos.core as core -import dimos.protocol.service.lcmservice as lcmservice from dimos.core import In, Module, rpc -from dimos_lcm.geometry_msgs import Pose, Vector3, Twist +import dimos.protocol.service.lcmservice as lcmservice +from dimos.utils.logging_config import setup_logger +from dimos.utils.transform_utils import euler_to_quaternion, quaternion_to_euler logger = setup_logger(__file__) class PiperArm: - def __init__(self, arm_name: str = "arm"): + def __init__(self, arm_name: str = "arm") -> None: self.arm = C_PiperInterface_V2() self.arm.ConnectPort() self.resetArm() @@ -54,7 +52,7 @@ def __init__(self, arm_name: str = "arm"): time.sleep(1) self.init_vel_controller() - def enable(self): + def enable(self) -> None: while not self.arm.EnablePiper(): pass time.sleep(0.01) @@ -67,7 +65,7 @@ def enable(self): # ) self.arm.MotionCtrl_2(0x01, 0x01, 80, 0xAD) - def gotoZero(self): + def gotoZero(self) -> None: factor = 1000 position = [57.0, 0.0, 215.0, 0, 90.0, 0, 0] X = round(position[0] * factor) @@ -76,13 +74,13 @@ def gotoZero(self): RX = round(position[3] * factor) RY = round(position[4] * factor) RZ = round(position[5] * factor) - joint_6 = round(position[6] * factor) + round(position[6] * factor) logger.debug(f"Going to zero position: X={X}, Y={Y}, Z={Z}, RX={RX}, RY={RY}, RZ={RZ}") self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) self.arm.EndPoseCtrl(X, Y, Z, RX, RY, RZ) self.arm.GripperCtrl(0, 1000, 0x01, 0) - def gotoObserve(self): + def gotoObserve(self) -> None: factor = 1000 position = [57.0, 0.0, 280.0, 0, 120.0, 0, 0] X = round(position[0] * factor) @@ -91,12 +89,12 @@ def gotoObserve(self): RX = round(position[3] * factor) RY = round(position[4] * factor) RZ = round(position[5] * factor) - joint_6 = round(position[6] * factor) + round(position[6] * factor) logger.debug(f"Going to zero position: X={X}, Y={Y}, Z={Z}, RX={RX}, RY={RY}, RZ={RZ}") self.arm.MotionCtrl_2(0x01, 0x00, 100, 0x00) self.arm.EndPoseCtrl(X, Y, Z, RX, RY, RZ) - def softStop(self): + def softStop(self) -> None: self.gotoZero() time.sleep(1) self.arm.MotionCtrl_2( @@ -107,7 +105,7 @@ def softStop(self): self.arm.MotionCtrl_1(0x01, 0, 0) time.sleep(3) - def cmd_ee_pose_values(self, x, y, z, r, p, y_, line_mode=False): + def cmd_ee_pose_values(self, x, y, z, r, p, y_, line_mode: bool = False) -> None: """Command end-effector to target pose in space (position + Euler angles)""" factor = 1000 pose = [ @@ -123,7 +121,7 @@ def cmd_ee_pose_values(self, x, y, z, r, p, y_, line_mode=False): int(pose[0]), int(pose[1]), int(pose[2]), int(pose[3]), int(pose[4]), int(pose[5]) ) - def cmd_ee_pose(self, pose: Pose, line_mode=False): + def cmd_ee_pose(self, pose: Pose, line_mode: bool = False) -> None: """Command end-effector to target pose using Pose message""" # Convert quaternion to euler angles euler = quaternion_to_euler(pose.orientation, degrees=True) @@ -160,7 +158,7 @@ def get_ee_pose(self): return Pose(position, orientation) - def cmd_gripper_ctrl(self, position, effort=0.25): + def cmd_gripper_ctrl(self, position, effort: float = 0.25) -> None: """Command end-effector gripper""" factor = 1000 position = position * factor * factor # meters @@ -169,7 +167,7 @@ def cmd_gripper_ctrl(self, position, effort=0.25): self.arm.GripperCtrl(abs(round(position)), abs(round(effort)), 0x01, 0) logger.debug(f"Commanding gripper position: {position}mm") - def enable_gripper(self): + def enable_gripper(self) -> None: """Enable the gripper using the initialization sequence""" logger.info("Enabling gripper...") while not self.arm.EnablePiper(): @@ -178,12 +176,12 @@ def enable_gripper(self): self.arm.GripperCtrl(0, 1000, 0x01, 0) logger.info("Gripper enabled") - def release_gripper(self): + def release_gripper(self) -> None: """Release gripper by opening to 100mm (10cm)""" logger.info("Releasing gripper (opening to 100mm)") self.cmd_gripper_ctrl(0.1) # 0.1m = 100mm = 10cm - def get_gripper_feedback(self) -> Tuple[float, float]: + def get_gripper_feedback(self) -> tuple[float, float]: """ Get current gripper feedback. @@ -221,7 +219,7 @@ def gripper_object_detected(self, commanded_effort: float = 0.25) -> bool: True if object is detected in gripper, False otherwise """ # Get gripper feedback - angle_degrees, actual_effort = self.get_gripper_feedback() + _angle_degrees, actual_effort = self.get_gripper_feedback() # Check if object is grasped (effort > 80% of commanded effort) effort_threshold = 0.8 * commanded_effort @@ -234,12 +232,12 @@ def gripper_object_detected(self, commanded_effort: float = 0.25) -> bool: return object_present - def resetArm(self): + def resetArm(self) -> None: self.arm.MotionCtrl_1(0x02, 0, 0) self.arm.MotionCtrl_2(0, 0, 0, 0x00) logger.info("Resetting arm") - def init_vel_controller(self): + def init_vel_controller(self) -> None: self.chain = kp.build_serial_chain_from_urdf( open("dimos/hardware/piper_description.urdf"), "gripper_base" ) @@ -247,7 +245,7 @@ def init_vel_controller(self): self.J_pinv = np.linalg.pinv(self.J) self.dt = 0.01 - def cmd_vel(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot): + def cmd_vel(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot) -> None: joint_state = self.arm.GetArmJointMsgs().joint_state # print(f"[PiperArm] Current Joints (direct): {joint_state}", type(joint_state)) joint_angles = np.array( @@ -283,17 +281,17 @@ def cmd_vel(self, x_dot, y_dot, z_dot, R_dot, P_dot, Y_dot): self.arm.MotionCtrl_2(0x01, 0x01, 100, 0xAD) self.arm.JointCtrl( - int(round(newq[0])), - int(round(newq[1])), - int(round(newq[2])), - int(round(newq[3])), - int(round(newq[4])), - int(round(newq[5])), + round(newq[0]), + round(newq[1]), + round(newq[2]), + round(newq[3]), + round(newq[4]), + round(newq[5]), ) time.sleep(self.dt) # print(f"[PiperArm] Moving to Joints to : {newq}") - def cmd_vel_ee(self, x_dot, y_dot, z_dot, RX_dot, PY_dot, YZ_dot): + def cmd_vel_ee(self, x_dot, y_dot, z_dot, RX_dot, PY_dot, YZ_dot) -> None: factor = 1000 x_dot = x_dot * factor y_dot = y_dot * factor @@ -341,7 +339,7 @@ def cmd_vel_ee(self, x_dot, y_dot, z_dot, RX_dot, PY_dot, YZ_dot): ) time.sleep(self.dt) - def disable(self): + def disable(self) -> None: self.softStop() while self.arm.DisablePiper(): @@ -353,7 +351,7 @@ def disable(self): class VelocityController(Module): cmd_vel: In[Twist] = None - def __init__(self, arm, period=0.01, *args, **kwargs): + def __init__(self, arm, period: float = 0.01, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.arm = arm self.period = period @@ -362,13 +360,13 @@ def __init__(self, arm, period=0.01, *args, **kwargs): self._thread = None @rpc - def start(self): + def start(self) -> None: super().start() unsub = self.cmd_vel.subscribe(self.handle_cmd_vel) self._disposables.add(Disposable(unsub)) - def control_loop(): + def control_loop() -> None: while True: # Check for timeout (1 second) if self.last_cmd_time and (time.time() - self.last_cmd_time) > 1.0: @@ -426,12 +424,12 @@ def control_loop(): self.arm.MotionCtrl_2(0x01, 0x01, 100, 0xAD) self.arm.JointCtrl( - int(round(newq[0])), - int(round(newq[1])), - int(round(newq[2])), - int(round(newq[3])), - int(round(newq[4])), - int(round(newq[5])), + round(newq[0]), + round(newq[1]), + round(newq[2]), + round(newq[3]), + round(newq[4]), + round(newq[5]), ) time.sleep(self.period) @@ -445,13 +443,13 @@ def stop(self) -> None: self._thread.join(2) super().stop() - def handle_cmd_vel(self, cmd_vel: Twist): + def handle_cmd_vel(self, cmd_vel: Twist) -> None: self.latest_cmd = cmd_vel self.last_cmd_time = time.time() @pytest.mark.tool -def run_velocity_controller(): +def run_velocity_controller() -> None: lcmservice.autoconf() dimos = core.start(2) @@ -470,7 +468,7 @@ def run_velocity_controller(): if __name__ == "__main__": arm = PiperArm() - def get_key(timeout=0.1): + def get_key(timeout: float = 0.1): """Non-blocking key reader for arrow keys.""" fd = sys.stdin.fileno() old_settings = termios.tcgetattr(fd) @@ -490,7 +488,7 @@ def get_key(timeout=0.1): finally: termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) - def teleop_linear_vel(arm): + def teleop_linear_vel(arm) -> None: print("Use arrow keys to control linear velocity (x/y/z). Press 'q' to quit.") print("Up/Down: +x/-x, Left/Right: +y/-y, 'w'/'s': +z/-z") x_dot, y_dot, z_dot = 0.0, 0.0, 0.0 diff --git a/dimos/hardware/sensor.py b/dimos/hardware/sensor.py index 3dc7b3850e..aa39f25ec6 100644 --- a/dimos/hardware/sensor.py +++ b/dimos/hardware/sensor.py @@ -16,7 +16,7 @@ class AbstractSensor(ABC): - def __init__(self, sensor_type=None): + def __init__(self, sensor_type=None) -> None: self.sensor_type = sensor_type @abstractmethod diff --git a/dimos/hardware/ufactory.py b/dimos/hardware/ufactory.py index cf4e139ccb..57caf2e3bd 100644 --- a/dimos/hardware/ufactory.py +++ b/dimos/hardware/ufactory.py @@ -16,7 +16,7 @@ class UFactoryEndEffector(EndEffector): - def __init__(self, model=None, **kwargs): + def __init__(self, model=None, **kwargs) -> None: super().__init__(**kwargs) self.model = model @@ -25,7 +25,7 @@ def get_model(self): class UFactory7DOFArm: - def __init__(self, arm_length=None): + def __init__(self, arm_length=None) -> None: self.arm_length = arm_length def get_arm_length(self): diff --git a/dimos/manipulation/manip_aio_pipeline.py b/dimos/manipulation/manip_aio_pipeline.py index 7c69e562cf..14c5d62afe 100644 --- a/dimos/manipulation/manip_aio_pipeline.py +++ b/dimos/manipulation/manip_aio_pipeline.py @@ -18,24 +18,22 @@ import asyncio import json -import logging import threading import time -import traceback -import websockets -from typing import Dict, List, Optional, Any + +import cv2 import numpy as np import reactivex as rx import reactivex.operators as ops -from dimos.utils.logging_config import setup_logger +import websockets + +from dimos.perception.common.utils import colorize_depth from dimos.perception.detection2d.detic_2d_det import Detic2DDetector -from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering -from dimos.perception.object_detection_stream import ObjectDetectionStream from dimos.perception.grasp_generation.utils import draw_grasps_on_image +from dimos.perception.object_detection_stream import ObjectDetectionStream +from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering from dimos.perception.pointcloud.utils import create_point_cloud_overlay_visualization -from dimos.perception.common.utils import colorize_depth from dimos.utils.logging_config import setup_logger -import cv2 logger = setup_logger("dimos.perception.manip_aio_pipeline") @@ -51,13 +49,13 @@ class ManipulationPipeline: def __init__( self, - camera_intrinsics: List[float], # [fx, fy, cx, cy] + camera_intrinsics: list[float], # [fx, fy, cx, cy] min_confidence: float = 0.6, max_objects: int = 10, - vocabulary: Optional[str] = None, - grasp_server_url: Optional[str] = None, + vocabulary: str | None = None, + grasp_server_url: str | None = None, enable_grasp_generation: bool = False, - ): + ) -> None: """ Initialize the manipulation pipeline. @@ -81,14 +79,14 @@ def __init__( self.grasp_loop_thread = None # Storage for grasp results and filtered objects - self.latest_grasps: List[dict] = [] # Simplified: just a list of grasps + self.latest_grasps: list[dict] = [] # Simplified: just a list of grasps self.grasps_consumed = False self.latest_filtered_objects = [] self.latest_rgb_for_grasps = None # Store RGB image for grasp overlay self.grasp_lock = threading.Lock() # Track pending requests - simplified to single task - self.grasp_task: Optional[asyncio.Task] = None + self.grasp_task: asyncio.Task | None = None # Reactive subjects for streaming filtered objects and grasps self.filtered_objects_subject = rx.subject.Subject() @@ -111,7 +109,7 @@ def __init__( logger.info(f"Initialized ManipulationPipeline with confidence={min_confidence}") - def create_streams(self, zed_stream: rx.Observable) -> Dict[str, rx.Observable]: + def create_streams(self, zed_stream: rx.Observable) -> dict[str, rx.Observable]: """ Create streams using exact old main logic. """ @@ -140,7 +138,7 @@ def create_streams(self, zed_stream: rx.Observable) -> Dict[str, rx.Observable]: frame_lock = threading.Lock() # Subscribe to combined ZED frames (from old main) - def on_zed_frame(zed_data): + def on_zed_frame(zed_data) -> None: nonlocal latest_rgb, latest_depth if zed_data is not None: with frame_lock: @@ -167,9 +165,9 @@ def get_depth_or_overlay(zed_data): ) # Process object detection results with point cloud filtering (from old main) - def on_detection_next(result): + def on_detection_next(result) -> None: nonlocal latest_point_cloud_overlay - if "objects" in result and result["objects"]: + if result.get("objects"): # Get latest RGB and depth frames with frame_lock: rgb = latest_rgb @@ -210,12 +208,12 @@ def on_detection_next(result): task = self.request_scene_grasps(filtered_objects) if task: # Check for results after a delay - def check_grasps_later(): + def check_grasps_later() -> None: time.sleep(2.0) # Wait for grasp processing # Wait for task to complete if hasattr(self, "grasp_task") and self.grasp_task: try: - result = self.grasp_task.result( + self.grasp_task.result( timeout=3.0 ) # Get result with timeout except Exception as e: @@ -258,13 +256,13 @@ def check_grasps_later(): with frame_lock: latest_point_cloud_overlay = None - def on_error(error): + def on_error(error) -> None: logger.error(f"Error in stream: {error}") - def on_completed(): + def on_completed() -> None: logger.info("Stream completed") - def start_subscriptions(): + def start_subscriptions() -> None: """Start subscriptions in background thread (from old main)""" # Subscribe to combined ZED frames zed_frame_stream.subscribe(on_next=on_zed_frame) @@ -303,10 +301,10 @@ def start_subscriptions(): "grasp_overlay": grasp_overlay_stream, } - def _start_grasp_loop(self): + def _start_grasp_loop(self) -> None: """Start asyncio event loop in a background thread for WebSocket communication.""" - def run_loop(): + def run_loop() -> None: self.grasp_loop = asyncio.new_event_loop() asyncio.set_event_loop(self.grasp_loop) self.grasp_loop.run_forever() @@ -319,8 +317,8 @@ def run_loop(): time.sleep(0.01) async def _send_grasp_request( - self, points: np.ndarray, colors: Optional[np.ndarray] - ) -> Optional[List[dict]]: + self, points: np.ndarray, colors: np.ndarray | None + ) -> list[dict] | None: """Send grasp request to Dimensional Grasp server.""" try: # Comprehensive client-side validation to prevent server errors @@ -419,7 +417,7 @@ async def _send_grasp_request( return None - def request_scene_grasps(self, objects: List[dict]) -> Optional[asyncio.Task]: + def request_scene_grasps(self, objects: list[dict]) -> asyncio.Task | None: """Request grasps for entire scene by combining all object point clouds.""" if not self.grasp_loop or not objects: return None @@ -428,7 +426,7 @@ def request_scene_grasps(self, objects: List[dict]) -> Optional[asyncio.Task]: all_colors = [] valid_objects = 0 - for i, obj in enumerate(objects): + for _i, obj in enumerate(objects): # Validate point cloud data if "point_cloud_numpy" not in obj or obj["point_cloud_numpy"] is None: continue @@ -494,11 +492,11 @@ def request_scene_grasps(self, objects: List[dict]) -> Optional[asyncio.Task]: self.grasp_task = task return task - except Exception as e: + except Exception: logger.warning("Failed to create grasp task") return None - def get_latest_grasps(self, timeout: float = 5.0) -> Optional[List[dict]]: + def get_latest_grasps(self, timeout: float = 5.0) -> list[dict] | None: """Get latest grasp results, waiting for new ones if current ones have been consumed.""" # Mark current grasps as consumed and get a reference with self.grasp_lock: @@ -525,7 +523,7 @@ def clear_grasps(self) -> None: with self.grasp_lock: self.latest_grasps = [] - def _prepare_colors(self, colors: Optional[np.ndarray]) -> Optional[np.ndarray]: + def _prepare_colors(self, colors: np.ndarray | None) -> np.ndarray | None: """Prepare colors array, converting from various formats if needed.""" if colors is None: return None @@ -535,7 +533,7 @@ def _prepare_colors(self, colors: Optional[np.ndarray]) -> Optional[np.ndarray]: return colors - def _convert_grasp_format(self, grasps: List[dict]) -> List[dict]: + def _convert_grasp_format(self, grasps: list[dict]) -> list[dict]: """Convert Grasp format to our visualization format.""" converted = [] @@ -559,7 +557,7 @@ def _convert_grasp_format(self, grasps: List[dict]) -> List[dict]: return converted - def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> Dict[str, float]: + def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> dict[str, float]: """Convert rotation matrix to Euler angles (in radians).""" sy = np.sqrt(rotation_matrix[0, 0] ** 2 + rotation_matrix[1, 0] ** 2) @@ -576,7 +574,7 @@ def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> Dict[str, fl return {"roll": x, "pitch": y, "yaw": z} - def cleanup(self): + def cleanup(self) -> None: """Clean up resources.""" if hasattr(self.detector, "cleanup"): self.detector.cleanup() diff --git a/dimos/manipulation/manip_aio_processer.py b/dimos/manipulation/manip_aio_processer.py index aa439d2814..e0bfc73256 100644 --- a/dimos/manipulation/manip_aio_processer.py +++ b/dimos/manipulation/manip_aio_processer.py @@ -16,28 +16,28 @@ Sequential manipulation processor for single-frame processing without reactive streams. """ -import logging import time -from typing import Dict, List, Optional, Any, Tuple -import numpy as np +from typing import Any + import cv2 +import numpy as np -from dimos.utils.logging_config import setup_logger +from dimos.perception.common.utils import ( + colorize_depth, + combine_object_data, + detection_results_to_object_data, +) from dimos.perception.detection2d.detic_2d_det import Detic2DDetector -from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering -from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter from dimos.perception.grasp_generation.grasp_generation import HostedGraspGenerator from dimos.perception.grasp_generation.utils import create_grasp_overlay +from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering from dimos.perception.pointcloud.utils import ( create_point_cloud_overlay_visualization, extract_and_cluster_misc_points, overlay_point_clouds_on_image, ) -from dimos.perception.common.utils import ( - colorize_depth, - detection_results_to_object_data, - combine_object_data, -) +from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter +from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.perception.manip_aio_processor") @@ -52,14 +52,14 @@ class ManipulationProcessor: def __init__( self, - camera_intrinsics: List[float], # [fx, fy, cx, cy] + camera_intrinsics: list[float], # [fx, fy, cx, cy] min_confidence: float = 0.6, max_objects: int = 20, - vocabulary: Optional[str] = None, + vocabulary: str | None = None, enable_grasp_generation: bool = False, - grasp_server_url: Optional[str] = None, # Required when enable_grasp_generation=True + grasp_server_url: str | None = None, # Required when enable_grasp_generation=True enable_segmentation: bool = True, - ): + ) -> None: """ Initialize the manipulation processor. @@ -119,8 +119,8 @@ def __init__( ) def process_frame( - self, rgb_image: np.ndarray, depth_image: np.ndarray, generate_grasps: bool = None - ) -> Dict[str, Any]: + self, rgb_image: np.ndarray, depth_image: np.ndarray, generate_grasps: bool | None = None + ) -> dict[str, Any]: """ Process a single RGB-D frame through the complete pipeline. @@ -295,7 +295,7 @@ def process_frame( return results - def run_object_detection(self, rgb_image: np.ndarray) -> Dict[str, Any]: + def run_object_detection(self, rgb_image: np.ndarray) -> dict[str, Any]: """Run object detection on RGB image.""" try: # Convert RGB to BGR for Detic detector @@ -329,8 +329,8 @@ def run_object_detection(self, rgb_image: np.ndarray) -> Dict[str, Any]: return {"objects": [], "viz_frame": rgb_image.copy()} def run_pointcloud_filtering( - self, rgb_image: np.ndarray, depth_image: np.ndarray, objects: List[Dict] - ) -> List[Dict]: + self, rgb_image: np.ndarray, depth_image: np.ndarray, objects: list[dict] + ) -> list[dict]: """Run point cloud filtering on detected objects.""" try: filtered_objects = self.pointcloud_filter.process_images( @@ -341,7 +341,7 @@ def run_pointcloud_filtering( logger.error(f"Point cloud filtering failed: {e}") return [] - def run_segmentation(self, rgb_image: np.ndarray) -> Dict[str, Any]: + def run_segmentation(self, rgb_image: np.ndarray) -> dict[str, Any]: """Run semantic segmentation on RGB image.""" if not self.segmenter: return {"objects": [], "viz_frame": rgb_image.copy()} @@ -380,7 +380,7 @@ def run_segmentation(self, rgb_image: np.ndarray) -> Dict[str, Any]: logger.error(f"Segmentation failed: {e}") return {"objects": [], "viz_frame": rgb_image.copy()} - def run_grasp_generation(self, filtered_objects: List[Dict], full_pcd) -> Optional[List[Dict]]: + def run_grasp_generation(self, filtered_objects: list[dict], full_pcd) -> list[dict] | None: """Run grasp generation using the configured generator.""" if not self.grasp_generator: logger.warning("Grasp generation requested but no generator available") @@ -397,7 +397,7 @@ def run_grasp_generation(self, filtered_objects: List[Dict], full_pcd) -> Option logger.error(f"Grasp generation failed: {e}") return None - def cleanup(self): + def cleanup(self) -> None: """Clean up resources.""" if hasattr(self.detector, "cleanup"): self.detector.cleanup() diff --git a/dimos/manipulation/manipulation_history.py b/dimos/manipulation/manipulation_history.py index 8404b225c1..a77900ba30 100644 --- a/dimos/manipulation/manipulation_history.py +++ b/dimos/manipulation/manipulation_history.py @@ -28,20 +28,16 @@ """Module for manipulation history tracking and search.""" -from typing import Dict, List, Optional, Any, Tuple, Union, Set, Callable from dataclasses import dataclass, field -import time from datetime import datetime -import os import json +import os import pickle -import uuid +import time +from typing import Any from dimos.types.manipulation import ( ManipulationTask, - AbstractConstraint, - ManipulationTaskConstraint, - ManipulationMetadata, ) from dimos.utils.logging_config import setup_logger @@ -61,8 +57,8 @@ class ManipulationHistoryEntry: task: ManipulationTask timestamp: float = field(default_factory=time.time) - result: Dict[str, Any] = field(default_factory=dict) - manipulation_response: Optional[str] = ( + result: dict[str, Any] = field(default_factory=dict) + manipulation_response: str | None = ( None # Any elaborative response from the motion planner / manipulation executor ) @@ -78,14 +74,14 @@ class ManipulationHistory: focusing on quick lookups and flexible search capabilities. """ - def __init__(self, output_dir: str = None, new_memory: bool = False): + def __init__(self, output_dir: str | None = None, new_memory: bool = False) -> None: """Initialize a new manipulation history. Args: output_dir: Directory to save history to new_memory: If True, creates a new memory instead of loading existing one """ - self._history: List[ManipulationHistoryEntry] = [] + self._history: list[ManipulationHistoryEntry] = [] self._output_dir = output_dir if output_dir and not new_memory: @@ -192,7 +188,7 @@ def load_from_dir(self, directory: str) -> None: except Exception as e: logger.error(f"Failed to load history: {e}") - def get_all_entries(self) -> List[ManipulationHistoryEntry]: + def get_all_entries(self) -> list[ManipulationHistoryEntry]: """Get all entries in chronological order. Returns: @@ -200,7 +196,7 @@ def get_all_entries(self) -> List[ManipulationHistoryEntry]: """ return self._history.copy() - def get_entry_by_index(self, index: int) -> Optional[ManipulationHistoryEntry]: + def get_entry_by_index(self, index: int) -> ManipulationHistoryEntry | None: """Get an entry by its index. Args: @@ -215,7 +211,7 @@ def get_entry_by_index(self, index: int) -> Optional[ManipulationHistoryEntry]: def get_entries_by_timerange( self, start_time: float, end_time: float - ) -> List[ManipulationHistoryEntry]: + ) -> list[ManipulationHistoryEntry]: """Get entries within a specific time range. Args: @@ -227,7 +223,7 @@ def get_entries_by_timerange( """ return [entry for entry in self._history if start_time <= entry.timestamp <= end_time] - def get_entries_by_object(self, object_name: str) -> List[ManipulationHistoryEntry]: + def get_entries_by_object(self, object_name: str) -> list[ManipulationHistoryEntry]: """Get entries related to a specific object. Args: @@ -239,7 +235,10 @@ def get_entries_by_object(self, object_name: str) -> List[ManipulationHistoryEnt return [entry for entry in self._history if entry.task.target_object == object_name] def create_task_entry( - self, task: ManipulationTask, result: Dict[str, Any] = None, agent_response: str = None + self, + task: ManipulationTask, + result: dict[str, Any] | None = None, + agent_response: str | None = None, ) -> ManipulationHistoryEntry: """Create a new manipulation history entry. @@ -257,7 +256,7 @@ def create_task_entry( self.add_entry(entry) return entry - def search(self, **kwargs) -> List[ManipulationHistoryEntry]: + def search(self, **kwargs) -> list[ManipulationHistoryEntry]: """Flexible search method that can search by any field in ManipulationHistoryEntry using dot notation. This method supports dot notation to access nested fields. String values automatically use diff --git a/dimos/manipulation/manipulation_interface.py b/dimos/manipulation/manipulation_interface.py index 68d3924a99..ae63eb79ed 100644 --- a/dimos/manipulation/manipulation_interface.py +++ b/dimos/manipulation/manipulation_interface.py @@ -20,29 +20,23 @@ metadata streams. """ -from typing import Dict, List, Optional, Any, Tuple, Union -from dataclasses import dataclass import os -import time -from datetime import datetime -from reactivex.disposable import Disposable +from typing import TYPE_CHECKING, Any + +from dimos.manipulation.manipulation_history import ( + ManipulationHistory, +) from dimos.perception.object_detection_stream import ObjectDetectionStream from dimos.types.manipulation import ( AbstractConstraint, - TranslationConstraint, - RotationConstraint, - ForceConstraint, - ManipulationTaskConstraint, ManipulationTask, - ManipulationMetadata, ObjectData, ) -from dimos.manipulation.manipulation_history import ( - ManipulationHistory, - ManipulationHistoryEntry, -) from dimos.utils.logging_config import setup_logger +if TYPE_CHECKING: + from reactivex.disposable import Disposable + logger = setup_logger("dimos.robot.manipulation_interface") @@ -60,7 +54,7 @@ def __init__( output_dir: str, new_memory: bool = False, perception_stream: ObjectDetectionStream = None, - ): + ) -> None: """ Initialize a new ManipulationInterface instance. @@ -81,12 +75,12 @@ def __init__( ) # List of constraints generated by the Agent via constraint generation skills - self.agent_constraints: List[AbstractConstraint] = [] + self.agent_constraints: list[AbstractConstraint] = [] # Initialize object detection stream and related properties self.perception_stream = perception_stream - self.latest_objects: List[ObjectData] = [] - self.stream_subscription: Optional[Disposable] = None + self.latest_objects: list[ObjectData] = [] + self.stream_subscription: Disposable | None = None # Set up subscription to perception stream if available self._setup_perception_subscription() @@ -103,7 +97,7 @@ def add_constraint(self, constraint: AbstractConstraint) -> None: self.agent_constraints.append(constraint) logger.info(f"Added agent constraint: {constraint}") - def get_constraints(self) -> List[AbstractConstraint]: + def get_constraints(self) -> list[AbstractConstraint]: """ Get all constraints generated by the Agent via constraint generation skills. @@ -112,7 +106,7 @@ def get_constraints(self) -> List[AbstractConstraint]: """ return self.agent_constraints - def get_constraint(self, constraint_id: str) -> Optional[AbstractConstraint]: + def get_constraint(self, constraint_id: str) -> AbstractConstraint | None: """ Get a specific constraint by its ID. @@ -131,7 +125,7 @@ def get_constraint(self, constraint_id: str) -> Optional[AbstractConstraint]: return None def add_manipulation_task( - self, task: ManipulationTask, manipulation_response: Optional[str] = None + self, task: ManipulationTask, manipulation_response: str | None = None ) -> None: """ Add a manipulation task to ManipulationHistory. @@ -146,7 +140,7 @@ def add_manipulation_task( task=task, result=None, notes=None, manipulation_response=manipulation_response ) - def get_manipulation_task(self, task_id: str) -> Optional[ManipulationTask]: + def get_manipulation_task(self, task_id: str) -> ManipulationTask | None: """ Get a manipulation task by its ID. @@ -158,7 +152,7 @@ def get_manipulation_task(self, task_id: str) -> Optional[ManipulationTask]: """ return self.history.get_manipulation_task(task_id) - def get_all_manipulation_tasks(self) -> List[ManipulationTask]: + def get_all_manipulation_tasks(self) -> list[ManipulationTask]: """ Get all manipulation tasks. @@ -168,8 +162,8 @@ def get_all_manipulation_tasks(self) -> List[ManipulationTask]: return self.history.get_all_manipulation_tasks() def update_task_status( - self, task_id: str, status: str, result: Optional[Dict[str, Any]] = None - ) -> Optional[ManipulationTask]: + self, task_id: str, status: str, result: dict[str, Any] | None = None + ) -> ManipulationTask | None: """ Update the status and result of a manipulation task. @@ -185,7 +179,7 @@ def update_task_status( # === Perception stream methods === - def _setup_perception_subscription(self): + def _setup_perception_subscription(self) -> None: """ Set up subscription to perception stream if available. """ @@ -197,7 +191,7 @@ def _setup_perception_subscription(self): ) logger.info("Subscribed to perception stream") - def _update_latest_objects(self, data): + def _update_latest_objects(self, data) -> None: """ Update the latest detected objects. @@ -207,7 +201,7 @@ def _update_latest_objects(self, data): if "objects" in data: self.latest_objects = data["objects"] - def get_latest_objects(self) -> List[ObjectData]: + def get_latest_objects(self) -> list[ObjectData]: """ Get the latest detected objects from the stream. @@ -216,7 +210,7 @@ def get_latest_objects(self) -> List[ObjectData]: """ return self.latest_objects - def get_object_by_id(self, object_id: int) -> Optional[ObjectData]: + def get_object_by_id(self, object_id: int) -> ObjectData | None: """ Get a specific object by its tracking ID. @@ -231,7 +225,7 @@ def get_object_by_id(self, object_id: int) -> Optional[ObjectData]: return obj return None - def get_objects_by_label(self, label: str) -> List[ObjectData]: + def get_objects_by_label(self, label: str) -> list[ObjectData]: """ Get all objects with a specific label. @@ -243,7 +237,7 @@ def get_objects_by_label(self, label: str) -> List[ObjectData]: """ return [obj for obj in self.latest_objects if obj["label"] == label] - def set_perception_stream(self, perception_stream): + def set_perception_stream(self, perception_stream) -> None: """ Set or update the perception stream. @@ -257,7 +251,7 @@ def set_perception_stream(self, perception_stream): self.perception_stream = perception_stream self._setup_perception_subscription() - def cleanup_perception_subscription(self): + def cleanup_perception_subscription(self) -> None: """ Clean up the stream subscription. """ @@ -285,7 +279,7 @@ def __str__(self) -> str: has_stream = self.perception_stream is not None return f"ManipulationInterface(history={self.manipulation_history}, agent_constraints={len(self.agent_constraints)}, perception_stream={has_stream}, detected_objects={len(self.latest_objects)})" - def __del__(self): + def __del__(self) -> None: """ Clean up resources on deletion. """ diff --git a/dimos/manipulation/test_manipulation_history.py b/dimos/manipulation/test_manipulation_history.py index 239a04a86f..141c9365aa 100644 --- a/dimos/manipulation/test_manipulation_history.py +++ b/dimos/manipulation/test_manipulation_history.py @@ -27,20 +27,17 @@ # limitations under the License. import os -import time import tempfile +import time + import pytest -from typing import Dict, List, Optional, Any, Tuple from dimos.manipulation.manipulation_history import ManipulationHistory, ManipulationHistoryEntry from dimos.types.manipulation import ( + ForceConstraint, ManipulationTask, - AbstractConstraint, - TranslationConstraint, RotationConstraint, - ForceConstraint, - ManipulationTaskConstraint, - ManipulationMetadata, + TranslationConstraint, ) from dimos.types.vector import Vector @@ -159,7 +156,7 @@ def populated_history(sample_task, sample_task_with_constraints): return history -def test_manipulation_history_init(): +def test_manipulation_history_init() -> None: """Test initialization of ManipulationHistory.""" # Default initialization history = ManipulationHistory() @@ -173,7 +170,7 @@ def test_manipulation_history_init(): assert os.path.exists(temp_dir) -def test_manipulation_history_add_entry(sample_task): +def test_manipulation_history_add_entry(sample_task) -> None: """Test adding entries to ManipulationHistory.""" history = ManipulationHistory() @@ -187,7 +184,7 @@ def test_manipulation_history_add_entry(sample_task): assert history.get_entry_by_index(0) == entry -def test_manipulation_history_create_task_entry(sample_task): +def test_manipulation_history_create_task_entry(sample_task) -> None: """Test creating a task entry directly.""" history = ManipulationHistory() @@ -201,11 +198,11 @@ def test_manipulation_history_create_task_entry(sample_task): assert entry.manipulation_response == "Task completed" -def test_manipulation_history_save_load(temp_output_dir, sample_task): +def test_manipulation_history_save_load(temp_output_dir, sample_task) -> None: """Test saving and loading history from disk.""" # Create history and add entry history = ManipulationHistory(output_dir=temp_output_dir) - entry = history.create_task_entry( + history.create_task_entry( task=sample_task, result={"status": "success"}, agent_response="Task completed" ) @@ -221,7 +218,7 @@ def test_manipulation_history_save_load(temp_output_dir, sample_task): assert loaded_history.get_entry_by_index(0).task.description == sample_task.description -def test_manipulation_history_clear(populated_history): +def test_manipulation_history_clear(populated_history) -> None: """Test clearing the history.""" assert len(populated_history) > 0 @@ -230,7 +227,7 @@ def test_manipulation_history_clear(populated_history): assert str(populated_history) == "ManipulationHistory(empty)" -def test_manipulation_history_get_methods(populated_history): +def test_manipulation_history_get_methods(populated_history) -> None: """Test various getter methods of ManipulationHistory.""" # get_all_entries entries = populated_history.get_all_entries() @@ -259,7 +256,7 @@ def test_manipulation_history_get_methods(populated_history): assert bottle_entries[0].task.task_id == "task2" -def test_manipulation_history_search_basic(populated_history): +def test_manipulation_history_search_basic(populated_history) -> None: """Test basic search functionality.""" # Search by exact match on top-level fields results = populated_history.search(timestamp=populated_history.get_entry_by_index(0).timestamp) @@ -281,7 +278,7 @@ def test_manipulation_history_search_basic(populated_history): assert results[0].task.task_id == "task1" -def test_manipulation_history_search_nested(populated_history): +def test_manipulation_history_search_nested(populated_history) -> None: """Test search with nested field paths.""" # Search by nested metadata fields results = populated_history.search( @@ -304,7 +301,7 @@ def test_manipulation_history_search_nested(populated_history): assert results[0].task.task_id == "task1" -def test_manipulation_history_search_wildcards(populated_history): +def test_manipulation_history_search_wildcards(populated_history) -> None: """Test search with wildcard patterns.""" # Search for any object with label "cup" results = populated_history.search(**{"task.metadata.objects.*.label": "cup"}) @@ -322,7 +319,7 @@ def test_manipulation_history_search_wildcards(populated_history): assert results[0].task.task_id == "task2" -def test_manipulation_history_search_constraints(populated_history): +def test_manipulation_history_search_constraints(populated_history) -> None: """Test search by constraint properties.""" # Find entries with any TranslationConstraint with y-axis results = populated_history.search(**{"task.constraints.*.translation_axis": "y"}) @@ -335,7 +332,7 @@ def test_manipulation_history_search_constraints(populated_history): assert results[0].task.task_id == "task2" -def test_manipulation_history_search_string_contains(populated_history): +def test_manipulation_history_search_string_contains(populated_history) -> None: """Test string contains searching.""" # Basic string contains results = populated_history.search(**{"task.description": "Pick"}) @@ -348,7 +345,7 @@ def test_manipulation_history_search_string_contains(populated_history): assert results[0].task.task_id == "task2" -def test_manipulation_history_search_multiple_criteria(populated_history): +def test_manipulation_history_search_multiple_criteria(populated_history) -> None: """Test search with multiple criteria.""" # Multiple criteria - all must match results = populated_history.search(**{"task.target_object": "cup", "result.status": "success"}) @@ -367,7 +364,7 @@ def test_manipulation_history_search_multiple_criteria(populated_history): assert results[0].task.task_id == "task2" -def test_manipulation_history_search_nonexistent_fields(populated_history): +def test_manipulation_history_search_nonexistent_fields(populated_history) -> None: """Test search with fields that don't exist.""" # Search by nonexistent field results = populated_history.search(nonexistent_field="value") @@ -382,7 +379,7 @@ def test_manipulation_history_search_nonexistent_fields(populated_history): assert len(results) == 0 -def test_manipulation_history_search_timestamp_ranges(populated_history): +def test_manipulation_history_search_timestamp_ranges(populated_history) -> None: """Test searching by timestamp ranges.""" # Get reference timestamps entry1_time = populated_history.get_entry_by_index(0).task.metadata["timestamp"] @@ -406,7 +403,7 @@ def test_manipulation_history_search_timestamp_ranges(populated_history): assert results[1].task.task_id == "task2" -def test_manipulation_history_search_vector_fields(populated_history): +def test_manipulation_history_search_vector_fields(populated_history) -> None: """Test searching by vector components in constraints.""" # Search by reference point components results = populated_history.search(**{"task.constraints.*.reference_point.x": 2.5}) @@ -424,7 +421,7 @@ def test_manipulation_history_search_vector_fields(populated_history): assert results[0].task.task_id == "task2" -def test_manipulation_history_search_execution_details(populated_history): +def test_manipulation_history_search_execution_details(populated_history) -> None: """Test searching by execution time and error patterns.""" # Search by execution time results = populated_history.search(**{"result.execution_time": 2.5}) @@ -442,7 +439,7 @@ def test_manipulation_history_search_execution_details(populated_history): assert results[0].task.task_id == "task1" -def test_manipulation_history_search_multiple_criteria(populated_history): +def test_manipulation_history_search_multiple_criteria(populated_history) -> None: """Test search with multiple criteria.""" # Multiple criteria - all must match results = populated_history.search(**{"task.target_object": "cup", "result.status": "success"}) diff --git a/dimos/manipulation/visual_servoing/detection3d.py b/dimos/manipulation/visual_servoing/detection3d.py index 0b78f3518c..f7371f531a 100644 --- a/dimos/manipulation/visual_servoing/detection3d.py +++ b/dimos/manipulation/visual_servoing/detection3d.py @@ -16,34 +16,32 @@ Real-time 3D object detection processor that extracts object poses from RGB-D data. """ -from typing import List, Optional, Tuple -import numpy as np import cv2 - -from dimos.utils.logging_config import setup_logger -from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter -from dimos.perception.pointcloud.utils import extract_centroids_from_masks -from dimos.perception.detection2d.utils import calculate_object_size_from_bbox -from dimos.perception.common.utils import bbox2d_to_corners - -from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion -from dimos.msgs.std_msgs import Header -from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray from dimos_lcm.vision_msgs import ( - Detection3D, + BoundingBox2D, BoundingBox3D, - ObjectHypothesisWithPose, - ObjectHypothesis, Detection2D, - BoundingBox2D, - Pose2D, + Detection3D, + ObjectHypothesis, + ObjectHypothesisWithPose, Point2D, + Pose2D, ) +import numpy as np + from dimos.manipulation.visual_servoing.utils import ( estimate_object_depth, - visualize_detections_3d, transform_pose, + visualize_detections_3d, ) +from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.msgs.std_msgs import Header +from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray +from dimos.perception.common.utils import bbox2d_to_corners +from dimos.perception.detection2d.utils import calculate_object_size_from_bbox +from dimos.perception.pointcloud.utils import extract_centroids_from_masks +from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter +from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.manipulation.visual_servoing.detection3d") @@ -58,12 +56,12 @@ class Detection3DProcessor: def __init__( self, - camera_intrinsics: List[float], # [fx, fy, cx, cy] + camera_intrinsics: list[float], # [fx, fy, cx, cy] min_confidence: float = 0.6, min_points: int = 30, max_depth: float = 1.0, max_object_size: float = 0.15, - ): + ) -> None: """ Initialize the real-time 3D detection processor. @@ -93,8 +91,8 @@ def __init__( ) def process_frame( - self, rgb_image: np.ndarray, depth_image: np.ndarray, transform: Optional[np.ndarray] = None - ) -> Tuple[Detection3DArray, Detection2DArray]: + self, rgb_image: np.ndarray, depth_image: np.ndarray, transform: np.ndarray | None = None + ) -> tuple[Detection3DArray, Detection2DArray]: """ Process a single RGB-D frame to extract 3D object detections. @@ -138,7 +136,9 @@ def process_frame( detections_2d = [] pose_dict = {p["mask_idx"]: p for p in poses if p["centroid"][2] < self.max_depth} - for i, (bbox, name, prob, track_id) in enumerate(zip(bboxes, names, probs, track_ids)): + for i, (bbox, name, prob, track_id) in enumerate( + zip(bboxes, names, probs, track_ids, strict=False) + ): if i not in pose_dict: continue @@ -234,8 +234,8 @@ def process_frame( def visualize_detections( self, rgb_image: np.ndarray, - detections_3d: List[Detection3D], - detections_2d: List[Detection2D], + detections_3d: list[Detection3D], + detections_2d: list[Detection2D], show_coordinates: bool = True, ) -> np.ndarray: """ @@ -261,8 +261,8 @@ def visualize_detections( return visualize_detections_3d(rgb_image, detections_3d, show_coordinates, bboxes_2d) def get_closest_detection( - self, detections: List[Detection3D], class_filter: Optional[str] = None - ) -> Optional[Detection3D]: + self, detections: list[Detection3D], class_filter: str | None = None + ) -> Detection3D | None: """ Get the closest detection with valid 3D data. @@ -292,7 +292,7 @@ def get_z_coord(d): return min(valid_detections, key=get_z_coord) - def cleanup(self): + def cleanup(self) -> None: """Clean up resources.""" if hasattr(self.detector, "cleanup"): self.detector.cleanup() diff --git a/dimos/manipulation/visual_servoing/manipulation_module.py b/dimos/manipulation/visual_servoing/manipulation_module.py index 9d2d77a0fa..a89d43ed7b 100644 --- a/dimos/manipulation/visual_servoing/manipulation_module.py +++ b/dimos/manipulation/visual_servoing/manipulation_module.py @@ -17,40 +17,39 @@ Handles grasping logic, state machine, and hardware coordination as a Dimos module. """ -import cv2 -import time -import threading -from typing import Optional, Tuple, Any, Dict -from enum import Enum from collections import deque +from enum import Enum +import threading +import time +from typing import Any +import cv2 +from dimos_lcm.sensor_msgs import CameraInfo import numpy as np - from reactivex.disposable import Disposable -from dimos.core import Module, In, Out, rpc -from dimos.msgs.sensor_msgs import Image -from dimos.msgs.geometry_msgs import Vector3, Pose, Quaternion -from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray -from dimos_lcm.sensor_msgs import CameraInfo +from dimos.core import In, Module, Out, rpc from dimos.hardware.piper_arm import PiperArm from dimos.manipulation.visual_servoing.detection3d import Detection3DProcessor from dimos.manipulation.visual_servoing.pbvs import PBVS -from dimos.perception.common.utils import find_clicked_detection from dimos.manipulation.visual_servoing.utils import ( create_manipulation_visualization, + is_target_reached, select_points_from_depth, transform_points_3d, update_target_grasp_pose, - is_target_reached, ) +from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray +from dimos.perception.common.utils import find_clicked_detection +from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import ( - pose_to_matrix, - matrix_to_pose, - create_transform_from_6dof, compose_transforms, + create_transform_from_6dof, + matrix_to_pose, + pose_to_matrix, ) -from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.manipulation.visual_servoing.manipulation_module") @@ -73,13 +72,13 @@ def __init__( self, grasp_stage: GraspStage, target_tracked: bool, - current_executed_pose: Optional[Pose] = None, - current_ee_pose: Optional[Pose] = None, - current_camera_pose: Optional[Pose] = None, - target_pose: Optional[Pose] = None, + current_executed_pose: Pose | None = None, + current_ee_pose: Pose | None = None, + current_camera_pose: Pose | None = None, + target_pose: Pose | None = None, waiting_for_reach: bool = False, - success: Optional[bool] = None, - ): + success: bool | None = None, + ) -> None: self.grasp_stage = grasp_stage self.target_tracked = target_tracked self.current_executed_pose = current_executed_pose @@ -117,9 +116,9 @@ class ManipulationModule(Module): def __init__( self, - ee_to_camera_6dof: Optional[list] = None, + ee_to_camera_6dof: list | None = None, **kwargs, - ): + ) -> None: """ Initialize manipulation module. @@ -220,7 +219,7 @@ def __init__( self.arm.gotoObserve() @rpc - def start(self): + def start(self) -> None: """Start the manipulation module.""" unsub = self.rgb_image.subscribe(self._on_rgb_image) @@ -235,7 +234,7 @@ def start(self): logger.info("Manipulation module started") @rpc - def stop(self): + def stop(self) -> None: """Stop the manipulation module.""" # Stop any running task self.stop_event.set() @@ -250,21 +249,21 @@ def stop(self): logger.info("Manipulation module stopped") - def _on_rgb_image(self, msg: Image): + def _on_rgb_image(self, msg: Image) -> None: """Handle RGB image messages.""" try: self.latest_rgb = msg.data except Exception as e: logger.error(f"Error processing RGB image: {e}") - def _on_depth_image(self, msg: Image): + def _on_depth_image(self, msg: Image) -> None: """Handle depth image messages.""" try: self.latest_depth = msg.data except Exception as e: logger.error(f"Error processing depth image: {e}") - def _on_camera_info(self, msg: CameraInfo): + def _on_camera_info(self, msg: CameraInfo) -> None: """Handle camera info messages.""" try: self.camera_intrinsics = [msg.K[0], msg.K[4], msg.K[2], msg.K[5]] @@ -279,7 +278,7 @@ def _on_camera_info(self, msg: CameraInfo): logger.error(f"Error processing camera info: {e}") @rpc - def get_single_rgb_frame(self) -> Optional[np.ndarray]: + def get_single_rgb_frame(self) -> np.ndarray | None: """ get the latest rgb frame from the camera """ @@ -323,8 +322,12 @@ def handle_keyboard_command(self, key: str) -> str: @rpc def pick_and_place( - self, target_x: int = None, target_y: int = None, place_x: int = None, place_y: int = None - ) -> Dict[str, Any]: + self, + target_x: int | None = None, + target_y: int | None = None, + place_x: int | None = None, + place_y: int | None = None, + ) -> dict[str, Any]: """ Start a pick and place task. @@ -386,7 +389,7 @@ def pick_and_place( return {"status": "started", "message": "Pick and place task started"} - def _run_pick_and_place(self): + def _run_pick_and_place(self) -> None: """Run the pick and place task loop.""" self.task_running = True logger.info("Starting pick and place task") @@ -421,7 +424,7 @@ def _run_pick_and_place(self): self.task_running = False logger.info("Pick and place task ended") - def set_grasp_stage(self, stage: GraspStage): + def set_grasp_stage(self, stage: GraspStage) -> None: """Set the grasp stage.""" self.grasp_stage = stage logger.info(f"Grasp stage: {stage.value}") @@ -479,7 +482,7 @@ def check_within_workspace(self, target_pose: Pose) -> bool: return True - def _check_reach_timeout(self) -> Tuple[bool, float]: + def _check_reach_timeout(self) -> tuple[bool, float]: """Check if robot has exceeded timeout while reaching pose. Returns: @@ -536,7 +539,7 @@ def check_reach_and_adjust(self) -> bool: target_pose = self.current_executed_pose # Check for timeout - this will fail task and reset if timeout occurred - timed_out, time_elapsed = self._check_reach_timeout() + timed_out, _time_elapsed = self._check_reach_timeout() if timed_out: return False @@ -582,7 +585,7 @@ def check_reach_and_adjust(self) -> bool: return True return False - def _update_tracking(self, detection_3d_array: Optional[Detection3DArray]) -> bool: + def _update_tracking(self, detection_3d_array: Detection3DArray | None) -> bool: """Update tracking with new detections.""" if not detection_3d_array or not self.pbvs: return False @@ -593,7 +596,7 @@ def _update_tracking(self, detection_3d_array: Optional[Detection3DArray]) -> bo self.last_valid_target = self.pbvs.get_current_target() return target_tracked - def reset_to_idle(self): + def reset_to_idle(self) -> None: """Reset the manipulation system to IDLE state.""" if self.pbvs: self.pbvs.clear_target() @@ -616,11 +619,11 @@ def reset_to_idle(self): self.arm.gotoObserve() - def execute_idle(self): + def execute_idle(self) -> None: """Execute idle stage.""" pass - def execute_pre_grasp(self): + def execute_pre_grasp(self) -> None: """Execute pre-grasp stage: visual servoing to pre-grasp position.""" if self.waiting_for_reach: if self.check_reach_and_adjust(): @@ -667,7 +670,7 @@ def execute_pre_grasp(self): self.adjustment_count += 1 time.sleep(0.2) - def execute_grasp(self): + def execute_grasp(self) -> None: """Execute grasp stage: move to final grasp position.""" if self.waiting_for_reach: if self.check_reach_and_adjust() and not self.grasp_reached_time: @@ -712,7 +715,7 @@ def execute_grasp(self): self.waiting_for_reach = True self.waiting_start_time = time.time() - def execute_close_and_retract(self): + def execute_close_and_retract(self) -> None: """Execute the retraction sequence after gripper has been closed.""" if self.waiting_for_reach and self.final_pregrasp_pose: if self.check_reach_and_adjust(): @@ -738,7 +741,7 @@ def execute_close_and_retract(self): self.waiting_for_reach = True self.waiting_start_time = time.time() - def execute_place(self): + def execute_place(self) -> None: """Execute place stage: move to place position and release object.""" if self.waiting_for_reach: # Use the already executed pose instead of recalculating @@ -764,7 +767,7 @@ def execute_place(self): self.task_failed = True self.overall_success = False - def execute_retract(self): + def execute_retract(self) -> None: """Execute retract stage: retract from place position.""" if self.waiting_for_reach and self.retract_pose: if self.check_reach_and_adjust(): @@ -794,9 +797,7 @@ def execute_retract(self): def capture_and_process( self, - ) -> Tuple[ - Optional[np.ndarray], Optional[Detection3DArray], Optional[Detection2DArray], Optional[Pose] - ]: + ) -> tuple[np.ndarray | None, Detection3DArray | None, Detection2DArray | None, Pose | None]: """Capture frame from camera data and process detections.""" if self.latest_rgb is None or self.latest_depth is None or self.detector is None: return None, None, None, None @@ -845,7 +846,7 @@ def pick_target(self, x: int, y: int) -> bool: return True return False - def update(self) -> Optional[Dict[str, Any]]: + def update(self) -> dict[str, Any] | None: """Main update function that handles capture, processing, control, and visualization.""" rgb, detection_3d_array, detection_2d_array, camera_pose = self.capture_and_process() if rgb is None: @@ -898,7 +899,7 @@ def update(self) -> Optional[Dict[str, Any]]: return feedback - def _publish_visualization(self, viz_image: np.ndarray): + def _publish_visualization(self, viz_image: np.ndarray) -> None: """Publish visualization image to LCM.""" try: viz_rgb = cv2.cvtColor(viz_image, cv2.COLOR_BGR2RGB) @@ -918,7 +919,7 @@ def check_target_stabilized(self) -> bool: std_devs = np.std(positions, axis=0) return np.all(std_devs < self.pose_stabilization_threshold) - def get_place_target_pose(self) -> Optional[Pose]: + def get_place_target_pose(self) -> Pose | None: """Get the place target pose with z-offset applied based on object height.""" if self.place_target_position is None: return None diff --git a/dimos/manipulation/visual_servoing/pbvs.py b/dimos/manipulation/visual_servoing/pbvs.py index a8f5ce5621..77bf83396e 100644 --- a/dimos/manipulation/visual_servoing/pbvs.py +++ b/dimos/manipulation/visual_servoing/pbvs.py @@ -17,20 +17,21 @@ Supports both eye-in-hand and eye-to-hand configurations. """ -import numpy as np -from typing import Optional, Tuple, List from collections import deque -from scipy.spatial.transform import Rotation as R -from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion -from dimos.msgs.vision_msgs import Detection3DArray + from dimos_lcm.vision_msgs import Detection3D -from dimos.utils.logging_config import setup_logger +import numpy as np +from scipy.spatial.transform import Rotation as R + from dimos.manipulation.visual_servoing.utils import ( - update_target_grasp_pose, - find_best_object_match, create_pbvs_visualization, + find_best_object_match, is_target_reached, + update_target_grasp_pose, ) +from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.msgs.vision_msgs import Detection3DArray +from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.manipulation.pbvs") @@ -59,7 +60,7 @@ def __init__( max_tracking_distance_threshold: float = 0.12, # Max distance for target tracking (m) min_size_similarity: float = 0.6, # Min size similarity threshold (0.0-1.0) direct_ee_control: bool = True, # If True, output target poses instead of velocities - ): + ) -> None: """ Initialize PBVS system. @@ -127,7 +128,7 @@ def set_target(self, target_object: Detection3D) -> bool: return True return False - def clear_target(self): + def clear_target(self) -> None: """Clear the current target.""" self.current_target = None self.target_grasp_pose = None @@ -138,7 +139,7 @@ def clear_target(self): self.controller.clear_state() logger.info("Target cleared") - def get_current_target(self) -> Optional[Detection3D]: + def get_current_target(self) -> Detection3D | None: """ Get the current target object. @@ -147,7 +148,7 @@ def get_current_target(self) -> Optional[Detection3D]: """ return self.current_target - def update_tracking(self, new_detections: Optional[Detection3DArray] = None) -> bool: + def update_tracking(self, new_detections: Detection3DArray | None = None) -> bool: """ Update target tracking with new detections using a rolling window. If tracking is lost, keeps the old target pose. @@ -214,7 +215,7 @@ def compute_control( ee_pose: Pose, grasp_distance: float = 0.15, grasp_pitch_degrees: float = 45.0, - ) -> Tuple[Optional[Vector3], Optional[Vector3], bool, bool, Optional[Pose]]: + ) -> tuple[Vector3 | None, Vector3 | None, bool, bool, Pose | None]: """ Compute PBVS control with position and orientation servoing. @@ -265,7 +266,7 @@ def compute_control( return None, None, False, True, None else: # Velocity control mode - use controller - velocity_cmd, angular_velocity_cmd, controller_reached = ( + velocity_cmd, angular_velocity_cmd, _controller_reached = ( self.controller.compute_control(ee_pose, self.target_grasp_pose) ) # Return has_target=True since we have a target, regardless of tracking status @@ -314,7 +315,7 @@ def __init__( max_velocity: float = 0.1, # m/s max_angular_velocity: float = 0.5, # rad/s target_tolerance: float = 0.01, # 1cm - ): + ) -> None: """ Initialize PBVS controller. @@ -343,7 +344,7 @@ def __init__( f"target_tolerance={target_tolerance}m" ) - def clear_state(self): + def clear_state(self) -> None: """Clear controller state.""" self.last_position_error = None self.last_rotation_error = None @@ -353,7 +354,7 @@ def clear_state(self): def compute_control( self, ee_pose: Pose, grasp_pose: Pose - ) -> Tuple[Optional[Vector3], Optional[Vector3], bool]: + ) -> tuple[Vector3 | None, Vector3 | None, bool]: """ Compute PBVS control with position and orientation servoing. @@ -466,7 +467,7 @@ def _compute_angular_velocity(self, target_rot: Quaternion, current_pose: Pose) def create_status_overlay( self, image: np.ndarray, - current_target: Optional[Detection3D] = None, + current_target: Detection3D | None = None, ) -> np.ndarray: """ Create PBVS status overlay on image. diff --git a/dimos/manipulation/visual_servoing/utils.py b/dimos/manipulation/visual_servoing/utils.py index df78d85327..06479723f6 100644 --- a/dimos/manipulation/visual_servoing/utils.py +++ b/dimos/manipulation/visual_servoing/utils.py @@ -12,31 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np -from typing import Dict, Any, Optional, List, Tuple, Union from dataclasses import dataclass +from typing import Any -from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion -from dimos_lcm.vision_msgs import Detection3D, Detection2D import cv2 -from dimos.perception.detection2d.utils import plot_results +from dimos_lcm.vision_msgs import Detection2D, Detection3D +import numpy as np + +from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 from dimos.perception.common.utils import project_2d_points_to_3d +from dimos.perception.detection2d.utils import plot_results from dimos.utils.transform_utils import ( - optical_to_robot_frame, - robot_to_optical_frame, - pose_to_matrix, - matrix_to_pose, - euler_to_quaternion, compose_transforms, - yaw_towards_point, + euler_to_quaternion, get_distance, + matrix_to_pose, offset_distance, + optical_to_robot_frame, + pose_to_matrix, + robot_to_optical_frame, + yaw_towards_point, ) def match_detection_by_id( - detection_3d: Detection3D, detections_3d: List[Detection3D], detections_2d: List[Detection2D] -) -> Optional[Detection2D]: + detection_3d: Detection3D, detections_3d: list[Detection3D], detections_2d: list[Detection2D] +) -> Detection2D | None: """ Find the corresponding Detection2D for a given Detection3D. @@ -181,8 +182,8 @@ def transform_points_3d( def select_points_from_depth( depth_image: np.ndarray, - target_point: Tuple[int, int], - camera_intrinsics: Union[List[float], np.ndarray], + target_point: tuple[int, int], + camera_intrinsics: list[float] | np.ndarray, radius: int = 5, ) -> np.ndarray: """ @@ -230,7 +231,7 @@ def select_points_from_depth( def update_target_grasp_pose( target_pose: Pose, ee_pose: Pose, grasp_distance: float = 0.0, grasp_pitch_degrees: float = 45.0 -) -> Optional[Pose]: +) -> Pose | None: """ Update target grasp pose based on current target pose and EE pose. @@ -287,7 +288,7 @@ def is_target_reached(target_pose: Pose, current_pose: Pose, tolerance: float = class ObjectMatchResult: """Result of object matching with confidence metrics.""" - matched_object: Optional[Detection3D] + matched_object: Detection3D | None confidence: float distance: float size_similarity: float @@ -299,7 +300,7 @@ def calculate_object_similarity( candidate_obj: Detection3D, distance_weight: float = 0.6, size_weight: float = 0.4, -) -> Tuple[float, float, float]: +) -> tuple[float, float, float]: """ Calculate comprehensive similarity between two objects. @@ -335,7 +336,7 @@ def calculate_object_similarity( # Calculate similarity for each dimension pair dim_similarities = [] - for target_dim, candidate_dim in zip(target_dims, candidate_dims): + for target_dim, candidate_dim in zip(target_dims, candidate_dims, strict=False): if target_dim == 0.0 and candidate_dim == 0.0: dim_similarities.append(1.0) # Both dimensions are zero elif target_dim == 0.0 or candidate_dim == 0.0: @@ -358,7 +359,7 @@ def calculate_object_similarity( def find_best_object_match( target_obj: Detection3D, - candidates: List[Detection3D], + candidates: list[Detection3D], max_distance: float = 0.1, min_size_similarity: float = 0.4, distance_weight: float = 0.7, @@ -412,7 +413,7 @@ def find_best_object_match( ) -def parse_zed_pose(zed_pose_data: Dict[str, Any]) -> Optional[Pose]: +def parse_zed_pose(zed_pose_data: dict[str, Any]) -> Pose | None: """ Parse ZED pose data dictionary into a Pose object. @@ -439,7 +440,7 @@ def parse_zed_pose(zed_pose_data: Dict[str, Any]) -> Optional[Pose]: def estimate_object_depth( - depth_image: np.ndarray, segmentation_mask: Optional[np.ndarray], bbox: List[float] + depth_image: np.ndarray, segmentation_mask: np.ndarray | None, bbox: list[float] ) -> float: """ Estimate object depth dimension using segmentation mask and depth data. @@ -633,8 +634,8 @@ def create_pbvs_visualization( image: np.ndarray, current_target=None, position_error=None, - target_reached=False, - grasp_stage="idle", + target_reached: bool = False, + grasp_stage: str = "idle", ) -> np.ndarray: """ Create simple PBVS visualization overlay. @@ -720,9 +721,9 @@ def create_pbvs_visualization( def visualize_detections_3d( rgb_image: np.ndarray, - detections: List[Detection3D], + detections: list[Detection3D], show_coordinates: bool = True, - bboxes_2d: Optional[List[List[float]]] = None, + bboxes_2d: list[list[float]] | None = None, ) -> np.ndarray: """ Visualize detections with 3D position overlay next to bounding boxes. @@ -768,7 +769,7 @@ def visualize_detections_3d( pos_xyz = np.array([position.x, position.y, position.z]) # Get bounding box coordinates - x1, y1, x2, y2 = map(int, bbox) + _x1, y1, x2, _y2 = map(int, bbox) # Add position text next to bounding box (top-right corner) pos_text = f"({pos_xyz[0]:.2f}, {pos_xyz[1]:.2f}, {pos_xyz[2]:.2f})" diff --git a/dimos/mapping/google_maps/conftest.py b/dimos/mapping/google_maps/conftest.py index 48ba9ccf30..09a7843261 100644 --- a/dimos/mapping/google_maps/conftest.py +++ b/dimos/mapping/google_maps/conftest.py @@ -14,11 +14,11 @@ import json from pathlib import Path + import pytest from dimos.mapping.google_maps.google_maps import GoogleMaps - _FIXTURE_DIR = Path(__file__).parent / "fixtures" diff --git a/dimos/mapping/google_maps/google_maps.py b/dimos/mapping/google_maps/google_maps.py index 3c822e2131..e75de042f4 100644 --- a/dimos/mapping/google_maps/google_maps.py +++ b/dimos/mapping/google_maps/google_maps.py @@ -13,20 +13,19 @@ # limitations under the License. import os -from typing import List, Optional, Tuple + import googlemaps -from dimos.mapping.utils.distance import distance_in_meters -from dimos.mapping.types import LatLon -from dimos.utils.logging_config import setup_logger from dimos.mapping.google_maps.types import ( - Position, - PlacePosition, + Coordinates, LocationContext, NearbyPlace, - Coordinates, + PlacePosition, + Position, ) - +from dimos.mapping.types import LatLon +from dimos.mapping.utils.distance import distance_in_meters +from dimos.utils.logging_config import setup_logger logger = setup_logger(__file__) @@ -35,16 +34,14 @@ class GoogleMaps: _client: googlemaps.Client _max_nearby_places: int - def __init__(self, api_key: Optional[str] = None) -> None: + def __init__(self, api_key: str | None = None) -> None: api_key = api_key or os.environ.get("GOOGLE_MAPS_API_KEY") if not api_key: raise ValueError("GOOGLE_MAPS_API_KEY environment variable not set") self._client = googlemaps.Client(key=api_key) self._max_nearby_places = 6 - def get_position( - self, query: str, current_location: Optional[LatLon] = None - ) -> Optional[Position]: + def get_position(self, query: str, current_location: LatLon | None = None) -> Position | None: # Use location bias if current location is provided if current_location: geocode_results = self._client.geocode( @@ -77,8 +74,8 @@ def get_position( ) def get_position_with_places( - self, query: str, current_location: Optional[LatLon] = None - ) -> Optional[PlacePosition]: + self, query: str, current_location: LatLon | None = None + ) -> PlacePosition | None: # Use location bias if current location is provided if current_location: places_results = self._client.places( @@ -110,7 +107,7 @@ def get_position_with_places( def get_location_context( self, latlon: LatLon, radius: int = 100, n_nearby_places: int = 6 - ) -> Optional[LocationContext]: + ) -> LocationContext | None: reverse_geocode_results = self._client.reverse_geocode((latlon.lat, latlon.lon)) if not reverse_geocode_results: @@ -157,7 +154,7 @@ def get_location_context( def _get_nearby_places( self, latlon: LatLon, radius: int, n_nearby_places: int - ) -> Tuple[List[NearbyPlace], str]: + ) -> tuple[list[NearbyPlace], str]: nearby_places = [] place_types_count: dict[str, int] = {} diff --git a/dimos/mapping/google_maps/test_google_maps.py b/dimos/mapping/google_maps/test_google_maps.py index b1d6dd4c99..52e1493ec3 100644 --- a/dimos/mapping/google_maps/test_google_maps.py +++ b/dimos/mapping/google_maps/test_google_maps.py @@ -12,13 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest -from dimos.mapping.google_maps.google_maps import GoogleMaps from dimos.mapping.types import LatLon -def test_get_position(maps_client, maps_fixture): +def test_get_position(maps_client, maps_fixture) -> None: maps_client._client.geocode.return_value = maps_fixture("get_position.json") res = maps_client.get_position("golden gate bridge") @@ -30,7 +28,7 @@ def test_get_position(maps_client, maps_fixture): } -def test_get_position_with_places(maps_client, maps_fixture): +def test_get_position_with_places(maps_client, maps_fixture) -> None: maps_client._client.places.return_value = maps_fixture("get_position_with_places.json") res = maps_client.get_position_with_places("golden gate bridge") @@ -48,7 +46,7 @@ def test_get_position_with_places(maps_client, maps_fixture): } -def test_get_location_context(maps_client, maps_fixture): +def test_get_location_context(maps_client, maps_fixture) -> None: maps_client._client.reverse_geocode.return_value = maps_fixture( "get_location_context_reverse_geocode.json" ) diff --git a/dimos/mapping/google_maps/types.py b/dimos/mapping/google_maps/types.py index 909b1ad271..67713f55ee 100644 --- a/dimos/mapping/google_maps/types.py +++ b/dimos/mapping/google_maps/types.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional + from pydantic import BaseModel @@ -38,14 +38,14 @@ class PlacePosition(BaseModel): lon: float description: str address: str - types: List[str] + types: list[str] class NearbyPlace(BaseModel): """Information about a nearby place.""" name: str - types: List[str] + types: list[str] distance: float vicinity: str @@ -53,14 +53,14 @@ class NearbyPlace(BaseModel): class LocationContext(BaseModel): """Contextual information about a location.""" - formatted_address: Optional[str] = None - street_number: Optional[str] = None - street: Optional[str] = None - neighborhood: Optional[str] = None - locality: Optional[str] = None - admin_area: Optional[str] = None - country: Optional[str] = None - postal_code: Optional[str] = None - nearby_places: List[NearbyPlace] = [] - place_types_summary: Optional[str] = None + formatted_address: str | None = None + street_number: str | None = None + street: str | None = None + neighborhood: str | None = None + locality: str | None = None + admin_area: str | None = None + country: str | None = None + postal_code: str | None = None + nearby_places: list[NearbyPlace] = [] + place_types_summary: str | None = None coordinates: Coordinates diff --git a/dimos/mapping/osm/current_location_map.py b/dimos/mapping/osm/current_location_map.py index 3ddc5fb69a..88942935af 100644 --- a/dimos/mapping/osm/current_location_map.py +++ b/dimos/mapping/osm/current_location_map.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from dimos.mapping.osm.osm import MapImage, get_osm_map from dimos.mapping.osm.query import query_for_one_position, query_for_one_position_and_context @@ -20,16 +19,15 @@ from dimos.models.vl.base import VlModel from dimos.utils.logging_config import setup_logger - logger = setup_logger(__file__) class CurrentLocationMap: _vl_model: VlModel - _position: Optional[LatLon] - _map_image: Optional[MapImage] + _position: LatLon | None + _map_image: MapImage | None - def __init__(self, vl_model: VlModel): + def __init__(self, vl_model: VlModel) -> None: self._vl_model = vl_model self._position = None self._map_image = None @@ -41,12 +39,12 @@ def __init__(self, vl_model: VlModel): def update_position(self, position: LatLon) -> None: self._position = position - def query_for_one_position(self, query: str) -> Optional[LatLon]: + def query_for_one_position(self, query: str) -> LatLon | None: return query_for_one_position(self._vl_model, self._get_current_map(), query) def query_for_one_position_and_context( self, query: str, robot_position: LatLon - ) -> Optional[tuple[LatLon, str]]: + ) -> tuple[LatLon, str] | None: return query_for_one_position_and_context( self._vl_model, self._get_current_map(), query, robot_position ) diff --git a/dimos/mapping/osm/demo_osm.py b/dimos/mapping/osm/demo_osm.py index 46f6298591..cf907378f3 100644 --- a/dimos/mapping/osm/demo_osm.py +++ b/dimos/mapping/osm/demo_osm.py @@ -31,14 +31,14 @@ class DemoRobot(Module): gps_location: Out[LatLon] = None - def start(self): + def start(self) -> None: super().start() self._disposables.add(interval(1.0).subscribe(lambda _: self._publish_gps_location())) - def stop(self): + def stop(self) -> None: super().stop() - def _publish_gps_location(self): + def _publish_gps_location(self) -> None: self.gps_location.publish(LatLon(lat=37.78092426217621, lon=-122.40682866540769)) diff --git a/dimos/mapping/osm/osm.py b/dimos/mapping/osm/osm.py index 0890c0d17a..9f967046f6 100644 --- a/dimos/mapping/osm/osm.py +++ b/dimos/mapping/osm/osm.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass -import math import io -from typing import Tuple, Optional -from concurrent.futures import ThreadPoolExecutor, as_completed -import requests +import math + import numpy as np from PIL import Image as PILImage +import requests from dimos.mapping.types import ImageCoord, LatLon from dimos.msgs.sensor_msgs import Image, ImageFormat @@ -96,7 +96,7 @@ def latlon_to_pixel(self, position: LatLon) -> ImageCoord: return (pixel_x, pixel_y) -def _lat_lon_to_tile(lat: float, lon: float, zoom: int) -> Tuple[float, float]: +def _lat_lon_to_tile(lat: float, lon: float, zoom: int) -> tuple[float, float]: """Convert latitude/longitude to tile coordinates at given zoom level.""" n = 2**zoom x_tile = (lon + 180.0) / 360.0 * n @@ -106,8 +106,8 @@ def _lat_lon_to_tile(lat: float, lon: float, zoom: int) -> Tuple[float, float]: def _download_tile( - args: Tuple[int, int, int, int, int], -) -> Tuple[int, int, Optional[PILImage.Image]]: + args: tuple[int, int, int, int, int], +) -> tuple[int, int, PILImage.Image | None]: """Download a single tile. Args: diff --git a/dimos/mapping/osm/query.py b/dimos/mapping/osm/query.py index d4e7d97280..4501525880 100644 --- a/dimos/mapping/osm/query.py +++ b/dimos/mapping/osm/query.py @@ -13,7 +13,6 @@ # limitations under the License. import re -from typing import Optional, Tuple from dimos.mapping.osm.osm import MapImage from dimos.mapping.types import LatLon @@ -21,13 +20,12 @@ from dimos.utils.generic import extract_json_from_llm_response from dimos.utils.logging_config import setup_logger - _PROLOGUE = "This is an image of an open street map I'm on." _JSON = "Please only respond with valid JSON." logger = setup_logger(__name__) -def query_for_one_position(vl_model: VlModel, map_image: MapImage, query: str) -> Optional[LatLon]: +def query_for_one_position(vl_model: VlModel, map_image: MapImage, query: str) -> LatLon | None: full_query = f"{_PROLOGUE} {query} {_JSON} If there's a match return the x, y coordinates from the image. Example: `[123, 321]`. If there's no match return `null`." response = vl_model.query(map_image.image.data, full_query) coords = tuple(map(int, re.findall(r"\d+", response))) @@ -38,7 +36,7 @@ def query_for_one_position(vl_model: VlModel, map_image: MapImage, query: str) - def query_for_one_position_and_context( vl_model: VlModel, map_image: MapImage, query: str, robot_position: LatLon -) -> Optional[Tuple[LatLon, str]]: +) -> tuple[LatLon, str] | None: example = '{"coordinates": [123, 321], "description": "A Starbucks on 27th Street"}' x, y = map_image.latlon_to_pixel(robot_position) my_location = f"I'm currently at x={x}, y={y}." diff --git a/dimos/mapping/osm/test_osm.py b/dimos/mapping/osm/test_osm.py index 516d8bcfc1..0e993f3157 100644 --- a/dimos/mapping/osm/test_osm.py +++ b/dimos/mapping/osm/test_osm.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import requests_mock -import pytest +from collections.abc import Generator +from typing import Any + import cv2 import numpy as np -from typing import Any, Generator +import pytest from requests import Request +import requests_mock from dimos.mapping.osm.osm import get_osm_map from dimos.mapping.types import LatLon @@ -26,7 +28,7 @@ _fixture_dir = get_data("osm_map_test") -def _tile_callback(request: Request, context: Any) -> bytes: # noqa: ANN401 +def _tile_callback(request: Request, context: Any) -> bytes: parts = (request.url or "").split("/") zoom, x, y_png = parts[-3], parts[-2], parts[-1] y = y_png.removesuffix(".png") diff --git a/dimos/mapping/types.py b/dimos/mapping/types.py index 3ceb64c56b..9c39522011 100644 --- a/dimos/mapping/types.py +++ b/dimos/mapping/types.py @@ -14,14 +14,14 @@ from dataclasses import dataclass -from typing import Optional, TypeAlias +from typing import TypeAlias @dataclass(frozen=True) class LatLon: lat: float lon: float - alt: Optional[float] = None + alt: float | None = None ImageCoord: TypeAlias = tuple[int, int] diff --git a/dimos/models/Detic/configs/BoxSup_ViLD_200e.py b/dimos/models/Detic/configs/BoxSup_ViLD_200e.py index b0bc16c30b..b189c7b54f 100644 --- a/dimos/models/Detic/configs/BoxSup_ViLD_200e.py +++ b/dimos/models/Detic/configs/BoxSup_ViLD_200e.py @@ -1,24 +1,23 @@ # Copyright (c) Facebook, Inc. and its affiliates. import os -import torch -import detectron2.data.transforms as T from detectron2.config import LazyCall as L -from detectron2.layers import ShapeSpec from detectron2.data.samplers import RepeatFactorTrainingSampler +import detectron2.data.transforms as T from detectron2.evaluation.lvis_evaluation import LVISEvaluator +from detectron2.layers import ShapeSpec from detectron2.layers.batch_norm import NaiveSyncBatchNorm -from detectron2.solver import WarmupParamScheduler -from detectron2.solver.build import get_default_optimizer_params +from detectron2.model_zoo import get_config +from detectron2.modeling.box_regression import Box2BoxTransform from detectron2.modeling.matcher import Matcher from detectron2.modeling.roi_heads import FastRCNNConvFCHead -from detectron2.modeling.box_regression import Box2BoxTransform -from detectron2.model_zoo import get_config -from fvcore.common.param_scheduler import CosineParamScheduler - -from detic.modeling.roi_heads.zero_shot_classifier import ZeroShotClassifier -from detic.modeling.roi_heads.detic_roi_heads import DeticCascadeROIHeads +from detectron2.solver import WarmupParamScheduler +from detectron2.solver.build import get_default_optimizer_params from detic.modeling.roi_heads.detic_fast_rcnn import DeticFastRCNNOutputLayers +from detic.modeling.roi_heads.detic_roi_heads import DeticCascadeROIHeads +from detic.modeling.roi_heads.zero_shot_classifier import ZeroShotClassifier +from fvcore.common.param_scheduler import CosineParamScheduler +import torch default_configs = get_config("new_baselines/mask_rcnn_R_50_FPN_100ep_LSJ.py") dataloader = default_configs["dataloader"] @@ -106,4 +105,4 @@ ) train.checkpointer.period = 20000 // num_nodes -train.output_dir = "./output/Lazy/{}".format(os.path.basename(__file__)[:-3]) +train.output_dir = f"./output/Lazy/{os.path.basename(__file__)[:-3]}" diff --git a/dimos/models/Detic/configs/Detic_ViLD_200e.py b/dimos/models/Detic/configs/Detic_ViLD_200e.py index c0983e291c..470124a109 100644 --- a/dimos/models/Detic/configs/Detic_ViLD_200e.py +++ b/dimos/models/Detic/configs/Detic_ViLD_200e.py @@ -1,28 +1,29 @@ # Copyright (c) Facebook, Inc. and its affiliates. import os -import torch -import detectron2.data.transforms as T from detectron2.config import LazyCall as L -from detectron2.layers import ShapeSpec +import detectron2.data.transforms as T from detectron2.evaluation.lvis_evaluation import LVISEvaluator +from detectron2.layers import ShapeSpec from detectron2.layers.batch_norm import NaiveSyncBatchNorm -from detectron2.solver import WarmupParamScheduler -from detectron2.solver.build import get_default_optimizer_params +from detectron2.model_zoo import get_config +from detectron2.modeling.box_regression import Box2BoxTransform from detectron2.modeling.matcher import Matcher from detectron2.modeling.roi_heads import FastRCNNConvFCHead -from detectron2.modeling.box_regression import Box2BoxTransform -from detectron2.model_zoo import get_config -from fvcore.common.param_scheduler import CosineParamScheduler - -from detic.modeling.roi_heads.zero_shot_classifier import ZeroShotClassifier -from detic.modeling.roi_heads.detic_roi_heads import DeticCascadeROIHeads -from detic.modeling.roi_heads.detic_fast_rcnn import DeticFastRCNNOutputLayers +from detectron2.solver import WarmupParamScheduler +from detectron2.solver.build import get_default_optimizer_params +from detic.data.custom_dataset_dataloader import ( + MultiDatasetSampler, + build_custom_train_loader, + get_detection_dataset_dicts_with_source, +) from detic.data.custom_dataset_mapper import CustomDatasetMapper from detic.modeling.meta_arch.custom_rcnn import CustomRCNN -from detic.data.custom_dataset_dataloader import build_custom_train_loader -from detic.data.custom_dataset_dataloader import MultiDatasetSampler -from detic.data.custom_dataset_dataloader import get_detection_dataset_dicts_with_source +from detic.modeling.roi_heads.detic_fast_rcnn import DeticFastRCNNOutputLayers +from detic.modeling.roi_heads.detic_roi_heads import DeticCascadeROIHeads +from detic.modeling.roi_heads.zero_shot_classifier import ZeroShotClassifier +from fvcore.common.param_scheduler import CosineParamScheduler +import torch default_configs = get_config("new_baselines/mask_rcnn_R_50_FPN_100ep_LSJ.py") dataloader = default_configs["dataloader"] @@ -153,4 +154,4 @@ ) train.checkpointer.period = 20000 // num_nodes -train.output_dir = "./output/Lazy/{}".format(os.path.basename(__file__)[:-3]) +train.output_dir = f"./output/Lazy/{os.path.basename(__file__)[:-3]}" diff --git a/dimos/models/Detic/demo.py b/dimos/models/Detic/demo.py index 80efc99884..e982f745a5 100755 --- a/dimos/models/Detic/demo.py +++ b/dimos/models/Detic/demo.py @@ -2,30 +2,29 @@ import argparse import glob import multiprocessing as mp -import numpy as np import os +import sys import tempfile import time import warnings -import cv2 -import tqdm -import sys -import mss +import cv2 from detectron2.config import get_cfg from detectron2.data.detection_utils import read_image from detectron2.utils.logger import setup_logger +import mss +import numpy as np +import tqdm sys.path.insert(0, "third_party/CenterNet2/") from centernet.config import add_centernet_config from detic.config import add_detic_config - from detic.predictor import VisualizationDemo # Fake a video capture object OpenCV style - half width, half height of first screen using MSS class ScreenGrab: - def __init__(self): + def __init__(self) -> None: self.sct = mss.mss() m0 = self.sct.monitors[0] self.monitor = {"top": 0, "left": 0, "width": m0["width"] / 2, "height": m0["height"] / 2} @@ -35,10 +34,10 @@ def read(self): nf = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR) return (True, nf) - def isOpened(self): + def isOpened(self) -> bool: return True - def release(self): + def release(self) -> bool: return True @@ -112,7 +111,7 @@ def get_parser(): return parser -def test_opencv_video_format(codec, file_ext): +def test_opencv_video_format(codec, file_ext) -> bool: with tempfile.TemporaryDirectory(prefix="video_format_test") as dir: filename = os.path.join(dir, "test_file" + file_ext) writer = cv2.VideoWriter( @@ -196,7 +195,7 @@ def test_opencv_video_format(codec, file_ext): ("x264", ".mkv") if test_opencv_video_format("x264", ".mkv") else ("mp4v", ".mp4") ) if codec == ".mp4v": - warnings.warn("x264 codec not available, switching to mp4v") + warnings.warn("x264 codec not available, switching to mp4v", stacklevel=2) if args.output: if os.path.isdir(args.output): output_fname = os.path.join(args.output, basename) diff --git a/dimos/models/Detic/detic/__init__.py b/dimos/models/Detic/detic/__init__.py index ecf772726e..2f8aa0a44e 100644 --- a/dimos/models/Detic/detic/__init__.py +++ b/dimos/models/Detic/detic/__init__.py @@ -1,17 +1,8 @@ # Copyright (c) Facebook, Inc. and its affiliates. +from .data.datasets import cc, coco_zeroshot, imagenet, lvis_v1, objects365, oid +from .modeling.backbone import swintransformer, timm from .modeling.meta_arch import custom_rcnn -from .modeling.roi_heads import detic_roi_heads -from .modeling.roi_heads import res5_roi_heads -from .modeling.backbone import swintransformer -from .modeling.backbone import timm - - -from .data.datasets import lvis_v1 -from .data.datasets import imagenet -from .data.datasets import cc -from .data.datasets import objects365 -from .data.datasets import oid -from .data.datasets import coco_zeroshot +from .modeling.roi_heads import detic_roi_heads, res5_roi_heads try: from .modeling.meta_arch import d2_deformable_detr diff --git a/dimos/models/Detic/detic/config.py b/dimos/models/Detic/detic/config.py index eb8882f3b2..c053f0bd06 100644 --- a/dimos/models/Detic/detic/config.py +++ b/dimos/models/Detic/detic/config.py @@ -2,7 +2,7 @@ from detectron2.config import CfgNode as CN -def add_detic_config(cfg): +def add_detic_config(cfg) -> None: _C = cfg _C.WITH_IMAGE_LABELS = False # Turn on co-training with classification data diff --git a/dimos/models/Detic/detic/custom_solver.py b/dimos/models/Detic/detic/custom_solver.py index 99eb08ed86..1b7cdf1491 100644 --- a/dimos/models/Detic/detic/custom_solver.py +++ b/dimos/models/Detic/detic/custom_solver.py @@ -1,11 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import itertools -from typing import Any, Dict, List, Set -import torch +from typing import Any from detectron2.config import CfgNode - from detectron2.solver.build import maybe_add_gradient_clipping +import torch def match_name_keywords(n, name_keywords): @@ -21,8 +20,8 @@ def build_custom_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim. """ Build an optimizer from config. """ - params: List[Dict[str, Any]] = [] - memo: Set[torch.nn.parameter.Parameter] = set() + params: list[dict[str, Any]] = [] + memo: set[torch.nn.parameter.Parameter] = set() custom_multiplier_name = cfg.SOLVER.CUSTOM_MULTIPLIER_NAME optimizer_type = cfg.SOLVER.OPTIMIZER for key, value in model.named_parameters(recurse=True): @@ -54,7 +53,7 @@ def maybe_add_full_model_gradient_clipping(optim): # optim: the optimizer class ) class FullModelGradientClippingOptimizer(optim): - def step(self, closure=None): + def step(self, closure=None) -> None: all_params = itertools.chain(*[x["params"] for x in self.param_groups]) torch.nn.utils.clip_grad_norm_(all_params, clip_norm_val) super().step(closure=closure) diff --git a/dimos/models/Detic/detic/data/custom_build_augmentation.py b/dimos/models/Detic/detic/data/custom_build_augmentation.py index cd2bba42c2..e54093936b 100644 --- a/dimos/models/Detic/detic/data/custom_build_augmentation.py +++ b/dimos/models/Detic/detic/data/custom_build_augmentation.py @@ -1,11 +1,13 @@ # Copyright (c) Facebook, Inc. and its affiliates. + from detectron2.data import transforms as T + from .transforms.custom_augmentation_impl import EfficientDetResizeCrop -def build_custom_augmentation(cfg, is_train, scale=None, size=None, min_size=None, max_size=None): +def build_custom_augmentation(cfg, is_train: bool, scale=None, size: int | None=None, min_size: int | None=None, max_size: int | None=None): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. diff --git a/dimos/models/Detic/detic/data/custom_dataset_dataloader.py b/dimos/models/Detic/detic/data/custom_dataset_dataloader.py index bfbab55733..0116e04aec 100644 --- a/dimos/models/Detic/detic/data/custom_dataset_dataloader.py +++ b/dimos/models/Detic/detic/data/custom_dataset_dataloader.py @@ -1,26 +1,30 @@ # Copyright (c) Facebook, Inc. and its affiliates. # Part of the code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/data/multi_dataset_dataloader.py (Apache-2.0 License) +from collections import defaultdict +from collections.abc import Iterator, Sequence +import itertools +import math import operator -import torch -import torch.utils.data -from detectron2.utils.comm import get_world_size from detectron2.config import configurable -from torch.utils.data.sampler import Sampler +from detectron2.data.build import ( + build_batch_data_loader, + check_metadata_consistency, + filter_images_with_few_keypoints, + filter_images_with_only_crowd_annotations, + get_detection_dataset_dicts, + print_instances_class_histogram, + worker_init_reset_seed, +) +from detectron2.data.catalog import DatasetCatalog, MetadataCatalog from detectron2.data.common import DatasetFromList, MapDataset from detectron2.data.dataset_mapper import DatasetMapper -from detectron2.data.build import get_detection_dataset_dicts, build_batch_data_loader -from detectron2.data.samplers import TrainingSampler, RepeatFactorTrainingSampler -from detectron2.data.build import worker_init_reset_seed, print_instances_class_histogram -from detectron2.data.build import filter_images_with_only_crowd_annotations -from detectron2.data.build import filter_images_with_few_keypoints -from detectron2.data.build import check_metadata_consistency -from detectron2.data.catalog import MetadataCatalog, DatasetCatalog +from detectron2.data.samplers import RepeatFactorTrainingSampler, TrainingSampler from detectron2.utils import comm -import itertools -import math -from collections import defaultdict -from typing import Optional +from detectron2.utils.comm import get_world_size +import torch +import torch.utils.data +from torch.utils.data.sampler import Sampler def _custom_train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None): @@ -65,7 +69,7 @@ def _custom_train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler= ) sampler = RepeatFactorTrainingSampler(repeat_factors) else: - raise ValueError("Unknown training sampler: {}".format(sampler_name)) + raise ValueError(f"Unknown training sampler: {sampler_name}") return { "dataset": dataset_dicts, @@ -87,18 +91,20 @@ def build_custom_train_loader( *, mapper, sampler, - total_batch_size=16, - aspect_ratio_grouping=True, - num_workers=0, - num_datasets=1, - multi_dataset_grouping=False, - use_diff_bs_size=False, - dataset_bs=[], + total_batch_size: int=16, + aspect_ratio_grouping: bool=True, + num_workers: int=0, + num_datasets: int=1, + multi_dataset_grouping: bool=False, + use_diff_bs_size: bool=False, + dataset_bs=None, ): """ Modified from detectron2.data.build.build_custom_train_loader, but supports different samplers """ + if dataset_bs is None: + dataset_bs = [] if isinstance(dataset, list): dataset = DatasetFromList(dataset, copy=False) if mapper is not None: @@ -127,14 +133,12 @@ def build_custom_train_loader( def build_multi_dataset_batch_data_loader( - use_diff_bs_size, dataset_bs, dataset, sampler, total_batch_size, num_datasets, num_workers=0 + use_diff_bs_size: int, dataset_bs, dataset, sampler, total_batch_size: int, num_datasets: int, num_workers: int=0 ): """ """ world_size = get_world_size() assert total_batch_size > 0 and total_batch_size % world_size == 0, ( - "Total batch size ({}) must be divisible by the number of gpus ({}).".format( - total_batch_size, world_size - ) + f"Total batch size ({total_batch_size}) must be divisible by the number of gpus ({world_size})." ) batch_size = total_batch_size // world_size @@ -153,15 +157,15 @@ def build_multi_dataset_batch_data_loader( def get_detection_dataset_dicts_with_source( - dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None + dataset_names: Sequence[str], filter_empty: bool=True, min_keypoints: int=0, proposal_files=None ): assert len(dataset_names) dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names] - for dataset_name, dicts in zip(dataset_names, dataset_dicts): - assert len(dicts), "Dataset '{}' is empty!".format(dataset_name) + for dataset_name, dicts in zip(dataset_names, dataset_dicts, strict=False): + assert len(dicts), f"Dataset '{dataset_name}' is empty!" - for source_id, (dataset_name, dicts) in enumerate(zip(dataset_names, dataset_dicts)): - assert len(dicts), "Dataset '{}' is empty!".format(dataset_name) + for source_id, (dataset_name, dicts) in enumerate(zip(dataset_names, dataset_dicts, strict=False)): + assert len(dicts), f"Dataset '{dataset_name}' is empty!" for d in dicts: d["dataset_source"] = source_id @@ -193,9 +197,9 @@ def __init__( dataset_ratio, use_rfs, dataset_ann, - repeat_threshold=0.001, - seed: Optional[int] = None, - ): + repeat_threshold: float=0.001, + seed: int | None = None, + ) -> None: """ """ sizes = [0 for _ in range(len(dataset_ratio))] for d in dataset_dicts: @@ -203,9 +207,7 @@ def __init__( print("dataset sizes", sizes) self.sizes = sizes assert len(dataset_ratio) == len(sizes), ( - "length of dataset ratio {} should be equal to number if dataset {}".format( - len(dataset_ratio), len(sizes) - ) + f"length of dataset ratio {len(dataset_ratio)} should be equal to number if dataset {len(sizes)}" ) if seed is None: seed = comm.shared_random_seed() @@ -219,7 +221,7 @@ def __init__( dataset_weight = [ torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio) - for i, (r, s) in enumerate(zip(dataset_ratio, sizes)) + for i, (r, s) in enumerate(zip(dataset_ratio, sizes, strict=False)) ] dataset_weight = torch.cat(dataset_weight) @@ -242,7 +244,7 @@ def __init__( self.weights = dataset_weight * rfs_factors self.sample_epoch_size = len(self.weights) - def __iter__(self): + def __iter__(self) -> Iterator: start = self._rank yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) @@ -253,18 +255,18 @@ def _infinite_indices(self): ids = torch.multinomial( self.weights, self.sample_epoch_size, generator=g, replacement=True ) - nums = [(self.dataset_ids[ids] == i).sum().int().item() for i in range(len(self.sizes))] + [(self.dataset_ids[ids] == i).sum().int().item() for i in range(len(self.sizes))] yield from ids class MDAspectRatioGroupedDataset(torch.utils.data.IterableDataset): - def __init__(self, dataset, batch_size, num_datasets): + def __init__(self, dataset, batch_size: int, num_datasets: int) -> None: """ """ self.dataset = dataset self.batch_size = batch_size self._buckets = [[] for _ in range(2 * num_datasets)] - def __iter__(self): + def __iter__(self) -> Iterator: for d in self.dataset: w, h = d["width"], d["height"] aspect_ratio_bucket_id = 0 if w > h else 1 @@ -277,13 +279,13 @@ def __iter__(self): class DIFFMDAspectRatioGroupedDataset(torch.utils.data.IterableDataset): - def __init__(self, dataset, batch_sizes, num_datasets): + def __init__(self, dataset, batch_sizes: Sequence[int], num_datasets: int) -> None: """ """ self.dataset = dataset self.batch_sizes = batch_sizes self._buckets = [[] for _ in range(2 * num_datasets)] - def __iter__(self): + def __iter__(self) -> Iterator: for d in self.dataset: w, h = d["width"], d["height"] aspect_ratio_bucket_id = 0 if w > h else 1 diff --git a/dimos/models/Detic/detic/data/custom_dataset_mapper.py b/dimos/models/Detic/detic/data/custom_dataset_mapper.py index ed8e6ade59..46c86ffd84 100644 --- a/dimos/models/Detic/detic/data/custom_dataset_mapper.py +++ b/dimos/models/Detic/detic/data/custom_dataset_mapper.py @@ -1,14 +1,13 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import copy import logging -import numpy as np -import torch from detectron2.config import configurable - -from detectron2.data import detection_utils as utils -from detectron2.data import transforms as T +from detectron2.data import detection_utils as utils, transforms as T from detectron2.data.dataset_mapper import DatasetMapper +import numpy as np +import torch + from .custom_build_augmentation import build_custom_augmentation from .tar_dataset import DiskTarDataset @@ -20,19 +19,23 @@ class CustomDatasetMapper(DatasetMapper): def __init__( self, is_train: bool, - with_ann_type=False, - dataset_ann=[], - use_diff_bs_size=False, - dataset_augs=[], - is_debug=False, - use_tar_dataset=False, - tarfile_path="", - tar_index_dir="", + with_ann_type: bool=False, + dataset_ann=None, + use_diff_bs_size: bool=False, + dataset_augs=None, + is_debug: bool=False, + use_tar_dataset: bool=False, + tarfile_path: str="", + tar_index_dir: str="", **kwargs, - ): + ) -> None: """ add image labels """ + if dataset_augs is None: + dataset_augs = [] + if dataset_ann is None: + dataset_ann = [] self.with_ann_type = with_ann_type self.dataset_ann = dataset_ann self.use_diff_bs_size = use_diff_bs_size @@ -65,7 +68,7 @@ def from_config(cls, cfg, is_train: bool = True): dataset_sizes = cfg.DATALOADER.DATASET_INPUT_SIZE ret["dataset_augs"] = [ build_custom_augmentation(cfg, True, scale, size) - for scale, size in zip(dataset_scales, dataset_sizes) + for scale, size in zip(dataset_scales, dataset_sizes, strict=False) ] else: assert cfg.INPUT.CUSTOM_AUG == "ResizeShortestEdge" @@ -73,7 +76,7 @@ def from_config(cls, cfg, is_train: bool = True): max_sizes = cfg.DATALOADER.DATASET_MAX_SIZES ret["dataset_augs"] = [ build_custom_augmentation(cfg, True, min_size=mi, max_size=ma) - for mi, ma in zip(min_sizes, max_sizes) + for mi, ma in zip(min_sizes, max_sizes, strict=False) ] else: ret["dataset_augs"] = [] @@ -103,7 +106,7 @@ def __call__(self, dataset_dict): if self.is_debug: dataset_dict["dataset_source"] = 0 - not_full_labeled = ( + ( "dataset_source" in dataset_dict and self.with_ann_type and self.dataset_ann[dataset_dict["dataset_source"]] != "box" @@ -178,7 +181,7 @@ def __call__(self, dataset_dict): # DETR augmentation -def build_transform_gen(cfg, is_train): +def build_transform_gen(cfg, is_train: bool): """ """ if is_train: min_size = cfg.INPUT.MIN_SIZE_TRAIN @@ -189,9 +192,7 @@ def build_transform_gen(cfg, is_train): max_size = cfg.INPUT.MAX_SIZE_TEST sample_style = "choice" if sample_style == "range": - assert len(min_size) == 2, "more than 2 ({}) min_size(s) are provided for ranges".format( - len(min_size) - ) + assert len(min_size) == 2, f"more than 2 ({len(min_size)}) min_size(s) are provided for ranges" logger = logging.getLogger(__name__) tfm_gens = [] @@ -214,7 +215,7 @@ class DetrDatasetMapper: 4. Prepare image and annotation to Tensors """ - def __init__(self, cfg, is_train=True): + def __init__(self, cfg, is_train: bool=True) -> None: if cfg.INPUT.CROP.ENABLED and is_train: self.crop_gen = [ T.ResizeShortestEdge([400, 500, 600], sample_style="choice"), @@ -226,9 +227,7 @@ def __init__(self, cfg, is_train=True): self.mask_on = cfg.MODEL.MASK_ON self.tfm_gens = build_transform_gen(cfg, is_train) logging.getLogger(__name__).info( - "Full TransformGens used in training: {}, crop: {}".format( - str(self.tfm_gens), str(self.crop_gen) - ) + f"Full TransformGens used in training: {self.tfm_gens!s}, crop: {self.crop_gen!s}" ) self.img_format = cfg.INPUT.FORMAT diff --git a/dimos/models/Detic/detic/data/datasets/cc.py b/dimos/models/Detic/detic/data/datasets/cc.py index 706db88415..be9c7f4a8b 100644 --- a/dimos/models/Detic/detic/data/datasets/cc.py +++ b/dimos/models/Detic/detic/data/datasets/cc.py @@ -2,6 +2,7 @@ import os from detectron2.data.datasets.lvis import get_lvis_instances_meta + from .lvis_v1 import custom_register_lvis_instances _CUSTOM_SPLITS = { diff --git a/dimos/models/Detic/detic/data/datasets/coco_zeroshot.py b/dimos/models/Detic/detic/data/datasets/coco_zeroshot.py index caf169adc9..80c360593d 100644 --- a/dimos/models/Detic/detic/data/datasets/coco_zeroshot.py +++ b/dimos/models/Detic/detic/data/datasets/coco_zeroshot.py @@ -1,8 +1,9 @@ # Copyright (c) Facebook, Inc. and its affiliates. import os -from detectron2.data.datasets.register_coco import register_coco_instances from detectron2.data.datasets.builtin_meta import _get_builtin_metadata +from detectron2.data.datasets.register_coco import register_coco_instances + from .lvis_v1 import custom_register_lvis_instances categories_seen = [ diff --git a/dimos/models/Detic/detic/data/datasets/imagenet.py b/dimos/models/Detic/detic/data/datasets/imagenet.py index 9b893a704e..caa7aa8fe0 100644 --- a/dimos/models/Detic/detic/data/datasets/imagenet.py +++ b/dimos/models/Detic/detic/data/datasets/imagenet.py @@ -3,10 +3,11 @@ from detectron2.data import DatasetCatalog, MetadataCatalog from detectron2.data.datasets.lvis import get_lvis_instances_meta + from .lvis_v1 import custom_load_lvis_json, get_lvis_22k_meta -def custom_register_imagenet_instances(name, metadata, json_file, image_root): +def custom_register_imagenet_instances(name: str, metadata, json_file, image_root) -> None: """ """ DatasetCatalog.register(name, lambda: custom_load_lvis_json(json_file, image_root, name)) MetadataCatalog.get(name).set( diff --git a/dimos/models/Detic/detic/data/datasets/lvis_22k_categories.py b/dimos/models/Detic/detic/data/datasets/lvis_22k_categories.py index 2e10b5dd23..d1b3cc370a 100644 --- a/dimos/models/Detic/detic/data/datasets/lvis_22k_categories.py +++ b/dimos/models/Detic/detic/data/datasets/lvis_22k_categories.py @@ -22380,4 +22380,4 @@ {"id": 22045, "synset": "planking.n.01", "name": "planking"}, {"id": 22046, "synset": "chipboard.n.01", "name": "chipboard"}, {"id": 22047, "synset": "knothole.n.01", "name": "knothole"}, -] # noqa +] diff --git a/dimos/models/Detic/detic/data/datasets/lvis_v1.py b/dimos/models/Detic/detic/data/datasets/lvis_v1.py index 3eb88bb4a1..4cdd65876f 100644 --- a/dimos/models/Detic/detic/data/datasets/lvis_v1.py +++ b/dimos/models/Detic/detic/data/datasets/lvis_v1.py @@ -2,18 +2,18 @@ import logging import os -from fvcore.common.timer import Timer -from detectron2.structures import BoxMode -from fvcore.common.file_io import PathManager from detectron2.data import DatasetCatalog, MetadataCatalog from detectron2.data.datasets.lvis import get_lvis_instances_meta +from detectron2.structures import BoxMode +from fvcore.common.file_io import PathManager +from fvcore.common.timer import Timer logger = logging.getLogger(__name__) __all__ = ["custom_load_lvis_json", "custom_register_lvis_instances"] -def custom_register_lvis_instances(name, metadata, json_file, image_root): +def custom_register_lvis_instances(name: str, metadata, json_file, image_root) -> None: """ """ DatasetCatalog.register(name, lambda: custom_load_lvis_json(json_file, image_root, name)) MetadataCatalog.get(name).set( @@ -21,7 +21,7 @@ def custom_register_lvis_instances(name, metadata, json_file, image_root): ) -def custom_load_lvis_json(json_file, image_root, dataset_name=None): +def custom_load_lvis_json(json_file, image_root, dataset_name: str | None=None): """ Modifications: use `file_name` @@ -35,7 +35,7 @@ def custom_load_lvis_json(json_file, image_root, dataset_name=None): timer = Timer() lvis_api = LVIS(json_file) if timer.seconds() > 1: - logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds())) + logger.info(f"Loading {json_file} takes {timer.seconds():.2f} seconds.") catid2contid = { x["id"]: i @@ -49,12 +49,10 @@ def custom_load_lvis_json(json_file, image_root, dataset_name=None): anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids] ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image] - assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique".format( - json_file - ) + assert len(set(ann_ids)) == len(ann_ids), f"Annotation ids in '{json_file}' are not unique" - imgs_anns = list(zip(imgs, anns)) - logger.info("Loaded {} images in the LVIS v1 format from {}".format(len(imgs_anns), json_file)) + imgs_anns = list(zip(imgs, anns, strict=False)) + logger.info(f"Loaded {len(imgs_anns)} images in the LVIS v1 format from {json_file}") dataset_dicts = [] diff --git a/dimos/models/Detic/detic/data/datasets/objects365.py b/dimos/models/Detic/detic/data/datasets/objects365.py index 6e0a45044e..236e609287 100644 --- a/dimos/models/Detic/detic/data/datasets/objects365.py +++ b/dimos/models/Detic/detic/data/datasets/objects365.py @@ -1,7 +1,8 @@ # Copyright (c) Facebook, Inc. and its affiliates. -from detectron2.data.datasets.register_coco import register_coco_instances import os +from detectron2.data.datasets.register_coco import register_coco_instances + # categories_v2 = [ # {'id': 1, 'name': 'Person'}, # {'id': 2, 'name': 'Sneakers'}, diff --git a/dimos/models/Detic/detic/data/datasets/oid.py b/dimos/models/Detic/detic/data/datasets/oid.py index d3a6fd14b2..0308a8da1d 100644 --- a/dimos/models/Detic/detic/data/datasets/oid.py +++ b/dimos/models/Detic/detic/data/datasets/oid.py @@ -1,8 +1,9 @@ # Part of the code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/data/datasets/oid.py # Copyright (c) Facebook, Inc. and its affiliates. -from .register_oid import register_oid_instances import os +from .register_oid import register_oid_instances + categories = [ {"id": 1, "name": "Infant bed", "freebase_id": "/m/061hd_"}, {"id": 2, "name": "Rose", "freebase_id": "/m/06m11"}, @@ -508,7 +509,7 @@ def _get_builtin_metadata(cats): - id_to_name = {x["id"]: x["name"] for x in cats} + {x["id"]: x["name"] for x in cats} thing_dataset_id_to_contiguous_id = {i + 1: i for i in range(len(cats))} thing_classes = [x["name"] for x in sorted(cats, key=lambda x: x["id"])] return { diff --git a/dimos/models/Detic/detic/data/datasets/register_oid.py b/dimos/models/Detic/detic/data/datasets/register_oid.py index 59a4da9ab7..ded0d4ab29 100644 --- a/dimos/models/Detic/detic/data/datasets/register_oid.py +++ b/dimos/models/Detic/detic/data/datasets/register_oid.py @@ -1,15 +1,14 @@ # Copyright (c) Facebook, Inc. and its affiliates. # Modified by Xingyi Zhou from https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/datasets/coco.py +import contextlib import io import logging -import contextlib import os - -from fvcore.common.timer import Timer -from fvcore.common.file_io import PathManager -from detectron2.structures import BoxMode from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.structures import BoxMode +from fvcore.common.file_io import PathManager +from fvcore.common.timer import Timer logger = logging.getLogger(__name__) @@ -20,7 +19,7 @@ __all__ = ["register_coco_instances", "register_coco_panoptic_separated"] -def register_oid_instances(name, metadata, json_file, image_root): +def register_oid_instances(name: str, metadata, json_file, image_root) -> None: """ """ # 1. register a function which returns dicts DatasetCatalog.register(name, lambda: load_coco_json_mem_efficient(json_file, image_root, name)) @@ -33,7 +32,7 @@ def register_oid_instances(name, metadata, json_file, image_root): def load_coco_json_mem_efficient( - json_file, image_root, dataset_name=None, extra_annotation_keys=None + json_file, image_root, dataset_name: str | None=None, extra_annotation_keys=None ): """ Actually not mem efficient @@ -45,7 +44,7 @@ def load_coco_json_mem_efficient( with contextlib.redirect_stdout(io.StringIO()): coco_api = COCO(json_file) if timer.seconds() > 1: - logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds())) + logger.info(f"Loading {json_file} takes {timer.seconds():.2f} seconds.") id_map = None if dataset_name is not None: @@ -69,7 +68,7 @@ def load_coco_json_mem_efficient( # sort indices for reproducible results img_ids = sorted(coco_api.imgs.keys()) imgs = coco_api.loadImgs(img_ids) - logger.info("Loaded {} images in COCO format from {}".format(len(imgs), json_file)) + logger.info(f"Loaded {len(imgs)} images in COCO format from {json_file}") dataset_dicts = [] diff --git a/dimos/models/Detic/detic/data/tar_dataset.py b/dimos/models/Detic/detic/data/tar_dataset.py index 323ef7dbb1..8c87a056d1 100644 --- a/dimos/models/Detic/detic/data/tar_dataset.py +++ b/dimos/models/Detic/detic/data/tar_dataset.py @@ -1,9 +1,10 @@ #!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. -import os import gzip -import numpy as np import io +import os + +import numpy as np from PIL import Image from torch.utils.data import Dataset @@ -19,11 +20,11 @@ class DiskTarDataset(Dataset): def __init__( self, - tarfile_path="dataset/imagenet/ImageNet-21k/metadata/tar_files.npy", - tar_index_dir="dataset/imagenet/ImageNet-21k/metadata/tarindex_npy", - preload=False, - num_synsets="all", - ): + tarfile_path: str="dataset/imagenet/ImageNet-21k/metadata/tar_files.npy", + tar_index_dir: str="dataset/imagenet/ImageNet-21k/metadata/tarindex_npy", + preload: bool=False, + num_synsets: str="all", + ) -> None: """ - preload (bool): Recommend to set preload to False when using - num_synsets (integer or string "all"): set to small number for debugging @@ -55,7 +56,7 @@ def __init__( sI += self.dataset_lens[k] self.labels = labels - def __len__(self): + def __len__(self) -> int: return self.num_samples def __getitem__(self, index): @@ -87,13 +88,13 @@ def __getitem__(self, index): # label is the dataset (synset) we indexed into return image, d_index, index - def __repr__(self): + def __repr__(self) -> str: st = f"DiskTarDataset(subdatasets={len(self.dataset_lens)},samples={self.num_samples})" return st -class _TarDataset(object): - def __init__(self, filename, npy_index_dir, preload=False): +class _TarDataset: + def __init__(self, filename, npy_index_dir, preload: bool=False) -> None: # translated from # fbcode/experimental/deeplearning/matthijs/comp_descs/tardataset.lua self.filename = filename @@ -109,7 +110,7 @@ def __init__(self, filename, npy_index_dir, preload=False): else: self.data = None - def __len__(self): + def __len__(self) -> int: return self.num_samples def load_index(self): @@ -119,7 +120,7 @@ def load_index(self): offsets = np.load(os.path.join(self.npy_index_dir, f"{basename}_offsets.npy")) return names, offsets - def __getitem__(self, idx): + def __getitem__(self, idx: int): if self.data is None: self.data = np.memmap(self.filename, mode="r", dtype="uint8") _, self.offsets = self.load_index() diff --git a/dimos/models/Detic/detic/data/transforms/custom_augmentation_impl.py b/dimos/models/Detic/detic/data/transforms/custom_augmentation_impl.py index 895eebab79..7cabc91e0f 100644 --- a/dimos/models/Detic/detic/data/transforms/custom_augmentation_impl.py +++ b/dimos/models/Detic/detic/data/transforms/custom_augmentation_impl.py @@ -1,12 +1,11 @@ -# -*- coding: utf-8 -*- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py # Modified by Xingyi Zhou # The original code is under Apache-2.0 License +from detectron2.data.transforms.augmentation import Augmentation import numpy as np from PIL import Image -from detectron2.data.transforms.augmentation import Augmentation from .custom_transform import EfficientDetResizeCropTransform __all__ = [ @@ -20,7 +19,7 @@ class EfficientDetResizeCrop(Augmentation): If `max_size` is reached, then downscale so that the longer edge does not exceed max_size. """ - def __init__(self, size, scale, interp=Image.BILINEAR): + def __init__(self, size: int, scale, interp=Image.BILINEAR) -> None: """ """ super().__init__() self.target_size = (size, size) diff --git a/dimos/models/Detic/detic/data/transforms/custom_transform.py b/dimos/models/Detic/detic/data/transforms/custom_transform.py index a451c0ee85..2017c27a5f 100644 --- a/dimos/models/Detic/detic/data/transforms/custom_transform.py +++ b/dimos/models/Detic/detic/data/transforms/custom_transform.py @@ -1,18 +1,17 @@ -# -*- coding: utf-8 -*- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Part of the code is from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/data/transforms.py # Modified by Xingyi Zhou # The original code is under Apache-2.0 License -import numpy as np -import torch -import torch.nn.functional as F from fvcore.transforms.transform import ( Transform, ) +import numpy as np from PIL import Image +import torch +import torch.nn.functional as F try: - import cv2 # noqa + import cv2 except ImportError: # OpenCV is an optional dependency at the moment pass @@ -25,7 +24,7 @@ class EfficientDetResizeCropTransform(Transform): """ """ - def __init__(self, scaled_h, scaled_w, offset_y, offset_x, img_scale, target_size, interp=None): + def __init__(self, scaled_h, scaled_w, offset_y, offset_x, img_scale, target_size: int, interp=None) -> None: """ Args: h, w (int): original image size diff --git a/dimos/models/Detic/detic/evaluation/custom_coco_eval.py b/dimos/models/Detic/detic/evaluation/custom_coco_eval.py index b4bbc9fc94..ce9319dc67 100644 --- a/dimos/models/Detic/detic/evaluation/custom_coco_eval.py +++ b/dimos/models/Detic/detic/evaluation/custom_coco_eval.py @@ -1,15 +1,17 @@ # Copyright (c) Facebook, Inc. and its affiliates. +from collections.abc import Sequence import itertools -import numpy as np -from tabulate import tabulate from detectron2.evaluation.coco_evaluation import COCOEvaluator from detectron2.utils.logger import create_small_table +import numpy as np +from tabulate import tabulate + from ..data.datasets.coco_zeroshot import categories_seen, categories_unseen class CustomCOCOEvaluator(COCOEvaluator): - def _derive_coco_results(self, coco_eval, iou_type, class_names=None): + def _derive_coco_results(self, coco_eval, iou_type, class_names: Sequence[str] | None=None): """ Additionally plot mAP for 'seen classes' and 'unseen classes' """ @@ -30,7 +32,7 @@ def _derive_coco_results(self, coco_eval, iou_type, class_names=None): for idx, metric in enumerate(metrics) } self._logger.info( - "Evaluation results for {}: \n".format(iou_type) + create_small_table(results) + f"Evaluation results for {iou_type}: \n" + create_small_table(results) ) if not np.isfinite(sum(results.values())): self._logger.info("Some metrics cannot be computed and is shown as NaN.") @@ -38,7 +40,7 @@ def _derive_coco_results(self, coco_eval, iou_type, class_names=None): if class_names is None or len(class_names) <= 1: return results # Compute per-category AP - # from https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L222-L252 # noqa + # from https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L222-L252 precisions = coco_eval.eval["precision"] # precision has dims (iou, recall, cls, area range, max dets) assert len(class_names) == precisions.shape[2] @@ -55,11 +57,11 @@ def _derive_coco_results(self, coco_eval, iou_type, class_names=None): precision = precisions[:, :, idx, 0, -1] precision = precision[precision > -1] ap = np.mean(precision) if precision.size else float("nan") - results_per_category.append(("{}".format(name), float(ap * 100))) + results_per_category.append((f"{name}", float(ap * 100))) precision50 = precisions[0, :, idx, 0, -1] precision50 = precision50[precision50 > -1] ap50 = np.mean(precision50) if precision50.size else float("nan") - results_per_category50.append(("{}".format(name), float(ap50 * 100))) + results_per_category50.append((f"{name}", float(ap50 * 100))) if name in seen_names: results_per_category50_seen.append(float(ap50 * 100)) if name in unseen_names: @@ -76,7 +78,7 @@ def _derive_coco_results(self, coco_eval, iou_type, class_names=None): headers=["category", "AP"] * (N_COLS // 2), numalign="left", ) - self._logger.info("Per-category {} AP: \n".format(iou_type) + table) + self._logger.info(f"Per-category {iou_type} AP: \n" + table) N_COLS = min(6, len(results_per_category50) * 2) results_flatten = list(itertools.chain(*results_per_category50)) @@ -88,18 +90,12 @@ def _derive_coco_results(self, coco_eval, iou_type, class_names=None): headers=["category", "AP50"] * (N_COLS // 2), numalign="left", ) - self._logger.info("Per-category {} AP50: \n".format(iou_type) + table) + self._logger.info(f"Per-category {iou_type} AP50: \n" + table) self._logger.info( - "Seen {} AP50: {}".format( - iou_type, - sum(results_per_category50_seen) / len(results_per_category50_seen), - ) + f"Seen {iou_type} AP50: {sum(results_per_category50_seen) / len(results_per_category50_seen)}" ) self._logger.info( - "Unseen {} AP50: {}".format( - iou_type, - sum(results_per_category50_unseen) / len(results_per_category50_unseen), - ) + f"Unseen {iou_type} AP50: {sum(results_per_category50_unseen) / len(results_per_category50_unseen)}" ) results.update({"AP-" + name: ap for name, ap in results_per_category}) diff --git a/dimos/models/Detic/detic/evaluation/oideval.py b/dimos/models/Detic/detic/evaluation/oideval.py index d52a151371..3ba53ddfde 100644 --- a/dimos/models/Detic/detic/evaluation/oideval.py +++ b/dimos/models/Detic/detic/evaluation/oideval.py @@ -8,29 +8,27 @@ # The code is from https://github.com/xingyizhou/UniDet/blob/master/projects/UniDet/unidet/evaluation/oideval.py # The original code is under Apache-2.0 License # Copyright (c) Facebook, Inc. and its affiliates. -import os +from collections import OrderedDict, defaultdict +from collections.abc import Sequence +import copy import datetime -import logging import itertools -from collections import OrderedDict -from collections import defaultdict -import copy import json -import numpy as np -import torch -from tabulate import tabulate - -from lvis.lvis import LVIS -from lvis.results import LVISResults - -import pycocotools.mask as mask_utils +import logging +import os -from fvcore.common.file_io import PathManager -import detectron2.utils.comm as comm from detectron2.data import MetadataCatalog +from detectron2.evaluation import DatasetEvaluator from detectron2.evaluation.coco_evaluation import instances_to_coco_json +import detectron2.utils.comm as comm from detectron2.utils.logger import create_small_table -from detectron2.evaluation import DatasetEvaluator +from fvcore.common.file_io import PathManager +from lvis.lvis import LVIS +from lvis.results import LVISResults +import numpy as np +import pycocotools.mask as mask_utils +from tabulate import tabulate +import torch def compute_average_precision(precision, recall): @@ -81,10 +79,10 @@ def __init__( self, lvis_gt, lvis_dt, - iou_type="bbox", - expand_pred_label=False, - oid_hierarchy_path="./datasets/oid/annotations/challenge-2019-label500-hierarchy.json", - ): + iou_type: str="bbox", + expand_pred_label: bool=False, + oid_hierarchy_path: str="./datasets/oid/annotations/challenge-2019-label500-hierarchy.json", + ) -> None: """Constructor for OIDEval. Args: lvis_gt (LVIS class instance, or str containing path of annotation file) @@ -95,14 +93,14 @@ def __init__( self.logger = logging.getLogger(__name__) if iou_type not in ["bbox", "segm"]: - raise ValueError("iou_type: {} is not supported.".format(iou_type)) + raise ValueError(f"iou_type: {iou_type} is not supported.") if isinstance(lvis_gt, LVIS): self.lvis_gt = lvis_gt elif isinstance(lvis_gt, str): self.lvis_gt = LVIS(lvis_gt) else: - raise TypeError("Unsupported type {} of lvis_gt.".format(lvis_gt)) + raise TypeError(f"Unsupported type {lvis_gt} of lvis_gt.") if isinstance(lvis_dt, LVISResults): self.lvis_dt = lvis_dt @@ -110,20 +108,19 @@ def __init__( # self.lvis_dt = LVISResults(self.lvis_gt, lvis_dt, max_dets=-1) self.lvis_dt = LVISResults(self.lvis_gt, lvis_dt) else: - raise TypeError("Unsupported type {} of lvis_dt.".format(lvis_dt)) + raise TypeError(f"Unsupported type {lvis_dt} of lvis_dt.") if expand_pred_label: - oid_hierarchy = json.load(open(oid_hierarchy_path, "r")) + oid_hierarchy = json.load(open(oid_hierarchy_path)) cat_info = self.lvis_gt.dataset["categories"] freebase2id = {x["freebase_id"]: x["id"] for x in cat_info} - id2freebase = {x["id"]: x["freebase_id"] for x in cat_info} - id2name = {x["id"]: x["name"] for x in cat_info} + {x["id"]: x["freebase_id"] for x in cat_info} + {x["id"]: x["name"] for x in cat_info} fas = defaultdict(set) def dfs(hierarchy, cur_id): all_childs = set() - all_keyed_child = {} if "Subcategory" in hierarchy: for x in hierarchy["Subcategory"]: childs = dfs(x, freebase2id[x["LabelName"]]) @@ -168,12 +165,12 @@ def dfs(hierarchy, cur_id): self.params.img_ids = sorted(self.lvis_gt.get_img_ids()) self.params.cat_ids = sorted(self.lvis_gt.get_cat_ids()) - def _to_mask(self, anns, lvis): + def _to_mask(self, anns, lvis) -> None: for ann in anns: rle = lvis.ann_to_rle(ann) ann["segmentation"] = rle - def _prepare(self): + def _prepare(self) -> None: """Prepare self._gts and self._dts for evaluation based on params.""" cat_ids = self.params.cat_ids if self.params.cat_ids else None @@ -214,13 +211,13 @@ def _prepare(self): continue self._dts[img_id, cat_id].append(dt) - def evaluate(self): + def evaluate(self) -> None: """ Run per image evaluation on given images and store results (a list of dict) in self.eval_imgs. """ self.logger.info("Running per image evaluation.") - self.logger.info("Evaluate annotation type *{}*".format(self.params.iou_type)) + self.logger.info(f"Evaluate annotation type *{self.params.iou_type}*") self.params.img_ids = list(np.unique(self.params.img_ids)) @@ -322,7 +319,7 @@ def evaluate_img_google(self, img_id, cat_id, area_rng): tp_fp_labels = np.zeros(num_detected_boxes, dtype=bool) is_matched_to_group_of = np.zeros(num_detected_boxes, dtype=bool) - def compute_match_iou(iou): + def compute_match_iou(iou) -> None: max_overlap_gt_ids = np.argmax(iou, axis=1) is_gt_detected = np.zeros(iou.shape[1], dtype=bool) for i in range(num_detected_boxes): @@ -381,7 +378,7 @@ def compute_match_ioa(ioa): "num_gt": len(gt), } - def accumulate(self): + def accumulate(self) -> None: """Accumulate per image evaluation results and store the result in self.eval. """ @@ -444,7 +441,7 @@ def accumulate(self): "fps": fps, } - for iou_thr_idx, (tp, fp) in enumerate(zip(tp_sum, fp_sum)): + for iou_thr_idx, (tp, fp) in enumerate(zip(tp_sum, fp_sum, strict=False)): tp = np.array(tp) fp = np.array(fp) num_tp = len(tp) @@ -491,16 +488,15 @@ def summarize(self): if not self.eval: raise RuntimeError("Please run accumulate() first.") - max_dets = self.params.max_dets self.results["AP50"] = self._summarize("ap") - def run(self): + def run(self) -> None: """Wrapper function which calculates the results.""" self.evaluate() self.accumulate() self.summarize() - def print_results(self): + def print_results(self) -> None: template = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} catIds={:>3s}] = {:0.3f}" for key, value in self.results.items(): @@ -514,9 +510,9 @@ def print_results(self): if len(key) > 2 and key[2].isdigit(): iou_thr = float(key[2:]) / 100 - iou = "{:0.2f}".format(iou_thr) + iou = f"{iou_thr:0.2f}" else: - iou = "{:0.2f}:{:0.2f}".format(self.params.iou_thrs[0], self.params.iou_thrs[-1]) + iou = f"{self.params.iou_thrs[0]:0.2f}:{self.params.iou_thrs[-1]:0.2f}" cat_group_name = "all" area_rng = "all" @@ -530,7 +526,7 @@ def get_results(self): class Params: - def __init__(self, iou_type): + def __init__(self, iou_type) -> None: self.img_ids = [] self.cat_ids = [] # np.arange causes trouble. the data point on arange is slightly @@ -552,7 +548,7 @@ def __init__(self, iou_type): class OIDEvaluator(DatasetEvaluator): - def __init__(self, dataset_name, cfg, distributed, output_dir=None): + def __init__(self, dataset_name: str, cfg, distributed, output_dir=None) -> None: self._distributed = distributed self._output_dir = output_dir @@ -567,12 +563,12 @@ def __init__(self, dataset_name, cfg, distributed, output_dir=None): self._do_evaluation = len(self._oid_api.get_ann_ids()) > 0 self._mask_on = cfg.MODEL.MASK_ON - def reset(self): + def reset(self) -> None: self._predictions = [] self._oid_results = [] - def process(self, inputs, outputs): - for input, output in zip(inputs, outputs): + def process(self, inputs, outputs) -> None: + for input, output in zip(inputs, outputs, strict=False): prediction = {"image_id": input["image_id"]} instances = output["instances"].to(self._cpu_device) prediction["instances"] = instances_to_coco_json(instances, input["image_id"]) @@ -600,7 +596,7 @@ def evaluate(self): PathManager.mkdirs(self._output_dir) file_path = os.path.join(self._output_dir, "oid_instances_results.json") - self._logger.info("Saving results to {}".format(file_path)) + self._logger.info(f"Saving results to {file_path}") with PathManager.open(file_path, "w") as f: f.write(json.dumps(self._oid_results)) f.flush() @@ -624,9 +620,8 @@ def evaluate(self): return copy.deepcopy(self._results) -def _evaluate_predictions_on_oid(oid_gt, oid_results_path, eval_seg=False, class_names=None): +def _evaluate_predictions_on_oid(oid_gt, oid_results_path, eval_seg: bool=False, class_names: Sequence[str] | None=None): logger = logging.getLogger(__name__) - metrics = ["AP50", "AP50_expand"] results = {} oid_eval = OIDEval(oid_gt, oid_results_path, "bbox", expand_pred_label=False) @@ -661,7 +656,7 @@ def _evaluate_predictions_on_oid(oid_gt, oid_results_path, eval_seg=False, class ( "{} {}".format( name.replace(" ", "_"), - inst_num if inst_num < 1000 else "{:.1f}k".format(inst_num / 1000), + inst_num if inst_num < 1000 else f"{inst_num / 1000:.1f}k", ), float(ap * 100), ) diff --git a/dimos/models/Detic/detic/modeling/backbone/swintransformer.py b/dimos/models/Detic/detic/modeling/backbone/swintransformer.py index 541d3c99dc..5002c96a45 100644 --- a/dimos/models/Detic/detic/modeling/backbone/swintransformer.py +++ b/dimos/models/Detic/detic/modeling/backbone/swintransformer.py @@ -9,20 +9,21 @@ # Modified by Xingyi Zhou from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.utils.checkpoint as checkpoint -import numpy as np -from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from collections.abc import Sequence +from centernet.modeling.backbone.bifpn import BiFPN +from centernet.modeling.backbone.fpn_p5 import LastLevelP6P7_P5 from detectron2.layers import ShapeSpec from detectron2.modeling.backbone.backbone import Backbone from detectron2.modeling.backbone.build import BACKBONE_REGISTRY from detectron2.modeling.backbone.fpn import FPN +import numpy as np +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint -from centernet.modeling.backbone.fpn_p5 import LastLevelP6P7_P5 -from centernet.modeling.backbone.bifpn import BiFPN # from .checkpoint import load_checkpoint @@ -30,8 +31,8 @@ class Mlp(nn.Module): """Multilayer perceptron.""" def __init__( - self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0 - ): + self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop: float=0.0 + ) -> None: super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -49,7 +50,7 @@ def forward(self, x): return x -def window_partition(x, window_size): +def window_partition(x, window_size: int): """ Args: x: (B, H, W, C) @@ -63,7 +64,7 @@ def window_partition(x, window_size): return windows -def window_reverse(windows, window_size, H, W): +def window_reverse(windows, window_size: int, H, W): """ Args: windows: (num_windows*B, window_size, window_size, C) @@ -94,14 +95,14 @@ class WindowAttention(nn.Module): def __init__( self, - dim, - window_size, - num_heads, - qkv_bias=True, + dim: int, + window_size: int, + num_heads: int, + qkv_bias: bool=True, qk_scale=None, - attn_drop=0.0, - proj_drop=0.0, - ): + attn_drop: float=0.0, + proj_drop: float=0.0, + ) -> None: super().__init__() self.dim = dim self.window_size = window_size # Wh, Ww @@ -197,19 +198,19 @@ class SwinTransformerBlock(nn.Module): def __init__( self, - dim, - num_heads, - window_size=7, - shift_size=0, - mlp_ratio=4.0, - qkv_bias=True, + dim: int, + num_heads: int, + window_size: int=7, + shift_size: int=0, + mlp_ratio: float=4.0, + qkv_bias: bool=True, qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, + drop: float=0.0, + attn_drop: float=0.0, + drop_path: float=0.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm, - ): + ) -> None: super().__init__() self.dim = dim self.num_heads = num_heads @@ -309,7 +310,7 @@ class PatchMerging(nn.Module): norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm """ - def __init__(self, dim, norm_layer=nn.LayerNorm): + def __init__(self, dim: int, norm_layer=nn.LayerNorm) -> None: super().__init__() self.dim = dim self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) @@ -364,20 +365,20 @@ class BasicLayer(nn.Module): def __init__( self, - dim, - depth, - num_heads, - window_size=7, - mlp_ratio=4.0, - qkv_bias=True, + dim: int, + depth: int, + num_heads: int, + window_size: int=7, + mlp_ratio: float=4.0, + qkv_bias: bool=True, qk_scale=None, - drop=0.0, - attn_drop=0.0, - drop_path=0.0, + drop: float=0.0, + attn_drop: float=0.0, + drop_path: float=0.0, norm_layer=nn.LayerNorm, downsample=None, - use_checkpoint=False, - ): + use_checkpoint: bool=False, + ) -> None: super().__init__() self.window_size = window_size self.shift_size = window_size // 2 @@ -442,8 +443,8 @@ def forward(self, x, H, W): ) # nW, window_size, window_size, 1 mask_windows = mask_windows.view(-1, self.window_size * self.window_size) attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) - attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( - attn_mask == 0, float(0.0) + attn_mask = attn_mask.masked_fill(attn_mask != 0, (-100.0)).masked_fill( + attn_mask == 0, 0.0 ) for blk in self.blocks: @@ -469,7 +470,7 @@ class PatchEmbed(nn.Module): norm_layer (nn.Module, optional): Normalization layer. Default: None """ - def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + def __init__(self, patch_size: int=4, in_chans: int=3, embed_dim: int=96, norm_layer=None) -> None: super().__init__() patch_size = to_2tuple(patch_size) self.patch_size = patch_size @@ -532,26 +533,30 @@ class SwinTransformer(Backbone): def __init__( self, - pretrain_img_size=224, - patch_size=4, - in_chans=3, - embed_dim=96, - depths=[2, 2, 6, 2], - num_heads=[3, 6, 12, 24], - window_size=7, - mlp_ratio=4.0, - qkv_bias=True, + pretrain_img_size: int=224, + patch_size: int=4, + in_chans: int=3, + embed_dim: int=96, + depths: Sequence[int] | None=None, + num_heads: int | None=None, + window_size: int=7, + mlp_ratio: float=4.0, + qkv_bias: bool=True, qk_scale=None, - drop_rate=0.0, - attn_drop_rate=0.0, - drop_path_rate=0.2, + drop_rate: float=0.0, + attn_drop_rate: float=0.0, + drop_path_rate: float=0.2, norm_layer=nn.LayerNorm, - ape=False, - patch_norm=True, + ape: bool=False, + patch_norm: bool=True, out_indices=(0, 1, 2, 3), frozen_stages=-1, - use_checkpoint=False, - ): + use_checkpoint: bool=False, + ) -> None: + if num_heads is None: + num_heads = [3, 6, 12, 24] + if depths is None: + depths = [2, 2, 6, 2] super().__init__() self.pretrain_img_size = pretrain_img_size @@ -621,14 +626,14 @@ def __init__( self.add_module(layer_name, layer) self._freeze_stages() - self._out_features = ["swin{}".format(i) for i in self.out_indices] + self._out_features = [f"swin{i}" for i in self.out_indices] self._out_feature_channels = { - "swin{}".format(i): self.embed_dim * 2**i for i in self.out_indices + f"swin{i}": self.embed_dim * 2**i for i in self.out_indices } - self._out_feature_strides = {"swin{}".format(i): 2 ** (i + 2) for i in self.out_indices} + self._out_feature_strides = {f"swin{i}": 2 ** (i + 2) for i in self.out_indices} self._size_devisibility = 32 - def _freeze_stages(self): + def _freeze_stages(self) -> None: if self.frozen_stages >= 0: self.patch_embed.eval() for param in self.patch_embed.parameters(): @@ -645,14 +650,14 @@ def _freeze_stages(self): for param in m.parameters(): param.requires_grad = False - def init_weights(self, pretrained=None): + def init_weights(self, pretrained: bool | None=None): """Initialize the weights in backbone. Args: pretrained (str, optional): Path to pre-trained weights. Defaults to None. """ - def _init_weights(m): + def _init_weights(m) -> None: if isinstance(m, nn.Linear): trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: @@ -696,13 +701,13 @@ def forward(self, x): out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() # outs.append(out) - outs["swin{}".format(i)] = out + outs[f"swin{i}"] = out return outs - def train(self, mode=True): + def train(self, mode: bool=True) -> None: """Convert the model into training mode while keep layers freezed.""" - super(SwinTransformer, self).train(mode) + super().train(mode) self._freeze_stages() diff --git a/dimos/models/Detic/detic/modeling/backbone/timm.py b/dimos/models/Detic/detic/modeling/backbone/timm.py index 8b7dd00006..a15e03f875 100644 --- a/dimos/models/Detic/detic/modeling/backbone/timm.py +++ b/dimos/models/Detic/detic/modeling/backbone/timm.py @@ -1,28 +1,23 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- # Copyright (c) Facebook, Inc. and its affiliates. import copy -import torch -from torch import nn -import torch.nn.functional as F -import fvcore.nn.weight_init as weight_init - -from detectron2.modeling.backbone import FPN -from detectron2.modeling.backbone.build import BACKBONE_REGISTRY from detectron2.layers.batch_norm import FrozenBatchNorm2d -from detectron2.modeling.backbone import Backbone - +from detectron2.modeling.backbone import FPN, Backbone +from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +import fvcore.nn.weight_init as weight_init from timm import create_model +from timm.models.convnext import ConvNeXt, checkpoint_filter_fn, default_cfgs from timm.models.helpers import build_model_with_cfg from timm.models.registry import register_model -from timm.models.resnet import ResNet, Bottleneck -from timm.models.resnet import default_cfgs as default_cfgs_resnet -from timm.models.convnext import ConvNeXt, default_cfgs, checkpoint_filter_fn +from timm.models.resnet import Bottleneck, ResNet, default_cfgs as default_cfgs_resnet +import torch +from torch import nn +import torch.nn.functional as F @register_model -def convnext_tiny_21k(pretrained=False, **kwargs): +def convnext_tiny_21k(pretrained: bool=False, **kwargs): model_args = dict(depths=(3, 3, 9, 3), dims=(96, 192, 384, 768), **kwargs) cfg = default_cfgs["convnext_tiny"] cfg["url"] = "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth" @@ -39,7 +34,7 @@ def convnext_tiny_21k(pretrained=False, **kwargs): class CustomResNet(ResNet): - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: self.out_indices = kwargs.pop("out_indices") super().__init__(**kwargs) @@ -59,7 +54,7 @@ def forward(self, x): ret.append(x) return [ret[i] for i in self.out_indices] - def load_pretrained(self, cached_file): + def load_pretrained(self, cached_file) -> None: data = torch.load(cached_file, map_location="cpu") if "state_dict" in data: self.load_state_dict(data["state_dict"]) @@ -72,7 +67,7 @@ def load_pretrained(self, cached_file): } -def create_timm_resnet(variant, out_indices, pretrained=False, **kwargs): +def create_timm_resnet(variant, out_indices, pretrained: bool=False, **kwargs): params = model_params[variant] default_cfgs_resnet["resnet50_in21k"] = copy.deepcopy(default_cfgs_resnet["resnet50"]) default_cfgs_resnet["resnet50_in21k"]["url"] = ( @@ -95,7 +90,7 @@ def create_timm_resnet(variant, out_indices, pretrained=False, **kwargs): class LastLevelP6P7_P5(nn.Module): """ """ - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels, out_channels) -> None: super().__init__() self.num_levels = 2 self.in_feature = "p5" @@ -119,7 +114,7 @@ def freeze_module(x): class TIMM(Backbone): - def __init__(self, base_name, out_levels, freeze_at=0, norm="FrozenBN", pretrained=False): + def __init__(self, base_name: str, out_levels, freeze_at: int=0, norm: str="FrozenBN", pretrained: bool=False) -> None: super().__init__() out_indices = [x - 1 for x in out_levels] if base_name in model_params: @@ -143,12 +138,12 @@ def __init__(self, base_name, out_levels, freeze_at=0, norm="FrozenBN", pretrain dict(num_chs=f["num_chs"], reduction=f["reduction"]) for i, f in enumerate(self.base.feature_info) ] - self._out_features = ["layer{}".format(x) for x in out_levels] + self._out_features = [f"layer{x}" for x in out_levels] self._out_feature_channels = { - "layer{}".format(l): feature_info[l - 1]["num_chs"] for l in out_levels + f"layer{l}": feature_info[l - 1]["num_chs"] for l in out_levels } self._out_feature_strides = { - "layer{}".format(l): feature_info[l - 1]["reduction"] for l in out_levels + f"layer{l}": feature_info[l - 1]["reduction"] for l in out_levels } self._size_divisibility = max(self._out_feature_strides.values()) if "resnet" in base_name: @@ -156,7 +151,7 @@ def __init__(self, base_name, out_levels, freeze_at=0, norm="FrozenBN", pretrain if norm == "FrozenBN": self = FrozenBatchNorm2d.convert_frozen_batchnorm(self) - def freeze(self, freeze_at=0): + def freeze(self, freeze_at: int=0) -> None: """ """ if freeze_at >= 1: print("Frezing", self.base.conv1) @@ -167,7 +162,7 @@ def freeze(self, freeze_at=0): def forward(self, x): features = self.base(x) - ret = {k: v for k, v in zip(self._out_features, features)} + ret = {k: v for k, v in zip(self._out_features, features, strict=False)} return ret @property diff --git a/dimos/models/Detic/detic/modeling/debug.py b/dimos/models/Detic/detic/modeling/debug.py index 21136de2f0..5f0cc7c9fc 100644 --- a/dimos/models/Detic/detic/modeling/debug.py +++ b/dimos/models/Detic/detic/modeling/debug.py @@ -1,9 +1,11 @@ # Copyright (c) Facebook, Inc. and its affiliates. +from collections.abc import Sequence +import os + import cv2 import numpy as np import torch import torch.nn.functional as F -import os COLORS = ((np.random.rand(1300, 3) * 0.4 + 0.6) * 255).astype(np.uint8).reshape(1300, 1, 1, 3) @@ -20,13 +22,13 @@ def _get_color_image(heatmap): return color_map -def _blend_image(image, color_map, a=0.7): +def _blend_image(image, color_map, a: float=0.7): color_map = cv2.resize(color_map, (image.shape[1], image.shape[0])) ret = np.clip(image * (1 - a) + color_map * a, 0, 255).astype(np.uint8) return ret -def _blend_image_heatmaps(image, color_maps, a=0.7): +def _blend_image_heatmaps(image, color_maps, a: float=0.7): merges = np.zeros((image.shape[0], image.shape[1], 3), np.float32) for color_map in color_maps: color_map = cv2.resize(color_map, (image.shape[1], image.shape[0])) @@ -80,12 +82,12 @@ def debug_train( gt_instances, flattened_hms, reg_targets, - labels, + labels: Sequence[str], pos_inds, shapes_per_level, locations, - strides, -): + strides: Sequence[int], +) -> None: """ images: N x 3 x H x W flattened_hms: LNHiWi x C @@ -107,7 +109,7 @@ def debug_train( for l in range(len(gt_hms)): color_map = _get_color_image(gt_hms[l][i].detach().cpu().numpy()) color_maps.append(color_map) - cv2.imshow("gthm_{}".format(l), color_map) + cv2.imshow(f"gthm_{l}", color_map) blend = _blend_image_heatmaps(image.copy(), color_maps) if gt_instances is not None: bboxes = gt_instances[i].gt_boxes.tensor @@ -157,22 +159,26 @@ def debug_test( images, logits_pred, reg_pred, - agn_hm_pred=[], - preds=[], - vis_thresh=0.3, - debug_show_name=False, - mult_agn=False, -): + agn_hm_pred=None, + preds=None, + vis_thresh: float=0.3, + debug_show_name: bool=False, + mult_agn: bool=False, +) -> None: """ images: N x 3 x H x W class_target: LNHiWi x C cat_agn_heatmap: LNHiWi shapes_per_level: L x 2 [(H_i, W_i)] """ - N = len(images) + if preds is None: + preds = [] + if agn_hm_pred is None: + agn_hm_pred = [] + len(images) for i in range(len(images)): image = images[i].detach().cpu().numpy().transpose(1, 2, 0) - result = image.copy().astype(np.uint8) + image.copy().astype(np.uint8) pred_image = image.copy().astype(np.uint8) color_maps = [] L = len(logits_pred) @@ -191,7 +197,7 @@ def debug_test( logits_pred[l][i] = logits_pred[l][i] * agn_hm_pred[l][i] color_map = _get_color_image(logits_pred[l][i].detach().cpu().numpy()) color_maps.append(color_map) - cv2.imshow("predhm_{}".format(l), color_map) + cv2.imshow(f"predhm_{l}", color_map) if debug_show_name: from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES @@ -242,7 +248,7 @@ def debug_test( if agn_hm_pred[l] is not None: agn_hm_ = agn_hm_pred[l][i, 0, :, :, None].detach().cpu().numpy() agn_hm_ = (agn_hm_ * np.array([255, 255, 255]).reshape(1, 1, 3)).astype(np.uint8) - cv2.imshow("agn_hm_{}".format(l), agn_hm_) + cv2.imshow(f"agn_hm_{l}", agn_hm_) blend = _blend_image_heatmaps(image.copy(), color_maps) cv2.imshow("blend", blend) cv2.imshow("preds", pred_image) @@ -257,13 +263,15 @@ def debug_second_stage( images, instances, proposals=None, - vis_thresh=0.3, - save_debug=False, - debug_show_name=False, - image_labels=[], - save_debug_path="output/save_debug/", - bgr=False, -): + vis_thresh: float=0.3, + save_debug: bool=False, + debug_show_name: bool=False, + image_labels: Sequence[str] | None=None, + save_debug_path: str="output/save_debug/", + bgr: bool=False, +) -> None: + if image_labels is None: + image_labels = [] images = _imagelist_to_tensor(images) if "COCO" in save_debug_path: from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES @@ -358,7 +366,7 @@ def debug_second_stage( ) if selected[j] >= 0 and debug_show_name: cat = selected[j].item() - txt = "{}".format(cat2name[cat]) + txt = f"{cat2name[cat]}" font = cv2.FONT_HERSHEY_SIMPLEX cat_size = cv2.getTextSize(txt, font, 0.5, 2)[0] cv2.rectangle( @@ -384,13 +392,13 @@ def debug_second_stage( cnt = (cnt + 1) % 5000 if not os.path.exists(save_debug_path): os.mkdir(save_debug_path) - save_name = "{}/{:05d}.jpg".format(save_debug_path, cnt) + save_name = f"{save_debug_path}/{cnt:05d}.jpg" if i < len(image_labels): image_label = image_labels[i] - save_name = "{}/{:05d}".format(save_debug_path, cnt) + save_name = f"{save_debug_path}/{cnt:05d}" for x in image_label: class_name = cat2name[x] - save_name = save_name + "|{}".format(class_name) + save_name = save_name + f"|{class_name}" save_name = save_name + ".jpg" cv2.imwrite(save_name, proposal_image) else: diff --git a/dimos/models/Detic/detic/modeling/meta_arch/custom_rcnn.py b/dimos/models/Detic/detic/modeling/meta_arch/custom_rcnn.py index 5711c87beb..019f4e6f84 100644 --- a/dimos/models/Detic/detic/modeling/meta_arch/custom_rcnn.py +++ b/dimos/models/Detic/detic/modeling/meta_arch/custom_rcnn.py @@ -1,17 +1,16 @@ # Copyright (c) Facebook, Inc. and its affiliates. -from typing import Dict, List, Optional, Tuple -import torch -from detectron2.utils.events import get_event_storage -from detectron2.config import configurable -from detectron2.structures import Instances -import detectron2.utils.comm as comm +from detectron2.config import configurable from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY from detectron2.modeling.meta_arch.rcnn import GeneralizedRCNN - +from detectron2.structures import Instances +import detectron2.utils.comm as comm +from detectron2.utils.events import get_event_storage +import torch from torch.cuda.amp import autocast + from ..text.text_encoder import build_text_encoder -from ..utils import load_class_freq, get_fed_loss_inds +from ..utils import get_fed_loss_inds, load_class_freq @META_ARCH_REGISTRY.register() @@ -23,17 +22,19 @@ class CustomRCNN(GeneralizedRCNN): @configurable def __init__( self, - with_image_labels=False, - dataset_loss_weight=[], - fp16=False, - sync_caption_batch=False, - roi_head_name="", - cap_batch_ratio=4, - with_caption=False, - dynamic_classifier=False, + with_image_labels: bool=False, + dataset_loss_weight=None, + fp16: bool=False, + sync_caption_batch: bool=False, + roi_head_name: str="", + cap_batch_ratio: int=4, + with_caption: bool=False, + dynamic_classifier: bool=False, **kwargs, - ): + ) -> None: """ """ + if dataset_loss_weight is None: + dataset_loss_weight = [] self.with_image_labels = with_image_labels self.dataset_loss_weight = dataset_loss_weight self.fp16 = fp16 @@ -80,8 +81,8 @@ def from_config(cls, cfg): def inference( self, - batched_inputs: Tuple[Dict[str, torch.Tensor]], - detected_instances: Optional[List[Instances]] = None, + batched_inputs: tuple[dict[str, torch.Tensor]], + detected_instances: list[Instances] | None = None, do_postprocess: bool = True, ): assert not self.training @@ -97,7 +98,7 @@ def inference( else: return results - def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]): + def forward(self, batched_inputs: list[dict[str, torch.Tensor]]): """ Add ann_type Ignore proposal loss when training with image labels @@ -110,7 +111,7 @@ def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]): ann_type = "box" gt_instances = [x["instances"].to(self.device) for x in batched_inputs] if self.with_image_labels: - for inst, x in zip(gt_instances, batched_inputs): + for inst, x in zip(gt_instances, batched_inputs, strict=False): inst._ann_type = x["ann_type"] inst._pos_category_ids = x["pos_category_ids"] ann_types = [x["ann_type"] for x in batched_inputs] @@ -131,7 +132,7 @@ def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]): if self.with_caption and "caption" in ann_type: inds = [torch.randint(len(x["captions"]), (1,))[0].item() for x in batched_inputs] - caps = [x["captions"][ind] for ind, x in zip(inds, batched_inputs)] + caps = [x["captions"][ind] for ind, x in zip(inds, batched_inputs, strict=False)] caption_features = self.text_encoder(caps).float() if self.sync_caption_batch: caption_features = self._sync_caption_features( @@ -140,7 +141,7 @@ def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]): if self.dynamic_classifier and ann_type != "caption": cls_inds = self._sample_cls_inds(gt_instances, ann_type) # inds, inv_inds - ind_with_bg = cls_inds[0].tolist() + [-1] + ind_with_bg = [*cls_inds[0].tolist(), -1] cls_features = ( self.roi_heads.box_predictor[0] .cls_score.zs_weight[:, ind_with_bg] @@ -204,7 +205,7 @@ def _sync_caption_features(self, caption_features, ann_type, BS): ) # (NB) x (D + 1) return caption_features - def _sample_cls_inds(self, gt_instances, ann_type="box"): + def _sample_cls_inds(self, gt_instances, ann_type: str="box"): if ann_type == "box": gt_classes = torch.cat([x.gt_classes for x in gt_instances]) C = len(self.freq_weight) @@ -218,7 +219,7 @@ def _sample_cls_inds(self, gt_instances, ann_type="box"): ) C = self.num_classes freq_weight = None - assert gt_classes.max() < C, "{} {}".format(gt_classes.max(), C) + assert gt_classes.max() < C, f"{gt_classes.max()} {C}" inds = get_fed_loss_inds(gt_classes, self.num_sample_cats, C, weight=freq_weight) cls_id_map = gt_classes.new_full((self.num_classes + 1,), len(inds)) cls_id_map[inds] = torch.arange(len(inds), device=cls_id_map.device) diff --git a/dimos/models/Detic/detic/modeling/meta_arch/d2_deformable_detr.py b/dimos/models/Detic/detic/modeling/meta_arch/d2_deformable_detr.py index 636adb1f44..ad1bce7ed0 100644 --- a/dimos/models/Detic/detic/modeling/meta_arch/d2_deformable_detr.py +++ b/dimos/models/Detic/detic/modeling/meta_arch/d2_deformable_detr.py @@ -1,35 +1,35 @@ # Copyright (c) Facebook, Inc. and its affiliates. -import torch -import torch.nn.functional as F -from torch import nn +from collections.abc import Sequence from detectron2.modeling import META_ARCH_REGISTRY, build_backbone from detectron2.structures import Boxes, Instances -from ..utils import load_class_freq, get_fed_loss_inds - from models.backbone import Joiner from models.deformable_detr import DeformableDETR, SetCriterion +from models.deformable_transformer import DeformableTransformer from models.matcher import HungarianMatcher from models.position_encoding import PositionEmbeddingSine -from models.deformable_transformer import DeformableTransformer from models.segmentation import sigmoid_focal_loss +import torch +from torch import nn +import torch.nn.functional as F from util.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh from util.misc import NestedTensor, accuracy +from ..utils import get_fed_loss_inds, load_class_freq __all__ = ["DeformableDetr"] class CustomSetCriterion(SetCriterion): def __init__( - self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25, use_fed_loss=False - ): + self, num_classes: int, matcher, weight_dict, losses, focal_alpha: float=0.25, use_fed_loss: bool=False + ) -> None: super().__init__(num_classes, matcher, weight_dict, losses, focal_alpha) self.use_fed_loss = use_fed_loss if self.use_fed_loss: self.register_buffer("fed_loss_weight", load_class_freq(freq_weight=0.5)) - def loss_labels(self, outputs, targets, indices, num_boxes, log=True): + def loss_labels(self, outputs, targets, indices, num_boxes: int, log: bool=True): """Classification loss (NLL) targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] """ @@ -37,7 +37,7 @@ def loss_labels(self, outputs, targets, indices, num_boxes, log=True): src_logits = outputs["pred_logits"] idx = self._get_src_permutation_idx(indices) - target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices, strict=False)]) target_classes = torch.full( src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device ) @@ -87,7 +87,7 @@ def loss_labels(self, outputs, targets, indices, num_boxes, log=True): class MaskedBackbone(nn.Module): """This is a thin wrapper around D2's backbone to provide padding masking""" - def __init__(self, cfg): + def __init__(self, cfg) -> None: super().__init__() self.backbone = build_backbone(cfg) backbone_shape = self.backbone.output_shape() @@ -112,7 +112,7 @@ class DeformableDetr(nn.Module): Implement Deformable Detr """ - def __init__(self, cfg): + def __init__(self, cfg) -> None: super().__init__() self.with_image_labels = cfg.WITH_IMAGE_LABELS self.weak_weight = cfg.MODEL.DETR.WEAK_WEIGHT @@ -250,7 +250,7 @@ def prepare_targets(self, targets): new_targets[-1].update({"masks": gt_masks}) return new_targets - def post_process(self, outputs, target_sizes): + def post_process(self, outputs, target_sizes: Sequence[int]): """ """ out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"] assert len(out_logits) == len(target_sizes) @@ -272,7 +272,7 @@ def post_process(self, outputs, target_sizes): boxes = boxes * scale_fct[:, None, :] results = [] - for s, l, b, size in zip(scores, labels, boxes, target_sizes): + for s, l, b, size in zip(scores, labels, boxes, target_sizes, strict=False): r = Instances((size[0], size[1])) r.pred_boxes = Boxes(b) r.scores = s @@ -303,7 +303,7 @@ def _weak_loss(self, outputs, batched_inputs): loss = loss / len(batched_inputs) return loss - def _max_size_loss(self, logits, boxes, label): + def _max_size_loss(self, logits, boxes, label: str): """ Inputs: logits: L x N x C diff --git a/dimos/models/Detic/detic/modeling/roi_heads/detic_fast_rcnn.py b/dimos/models/Detic/detic/modeling/roi_heads/detic_fast_rcnn.py index 6d4d2e786e..64893840b6 100644 --- a/dimos/models/Detic/detic/modeling/roi_heads/detic_fast_rcnn.py +++ b/dimos/models/Detic/detic/modeling/roi_heads/detic_fast_rcnn.py @@ -1,19 +1,23 @@ # Copyright (c) Facebook, Inc. and its affiliates. +from collections.abc import Sequence import math -import torch -from fvcore.nn import giou_loss, smooth_l1_loss -from torch import nn -from torch.nn import functional as F -import fvcore.nn.weight_init as weight_init -import detectron2.utils.comm as comm + from detectron2.config import configurable from detectron2.layers import ShapeSpec, cat, nonzero_tuple +from detectron2.modeling.roi_heads.fast_rcnn import ( + FastRCNNOutputLayers, + _log_classification_stats, + fast_rcnn_inference, +) +import detectron2.utils.comm as comm from detectron2.utils.events import get_event_storage -from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers -from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference -from detectron2.modeling.roi_heads.fast_rcnn import _log_classification_stats +from fvcore.nn import giou_loss, smooth_l1_loss +import fvcore.nn.weight_init as weight_init +import torch +from torch import nn +from torch.nn import functional as F -from ..utils import load_class_freq, get_fed_loss_inds +from ..utils import get_fed_loss_inds, load_class_freq from .zero_shot_classifier import ZeroShotClassifier __all__ = ["DeticFastRCNNOutputLayers"] @@ -25,28 +29,28 @@ def __init__( self, input_shape: ShapeSpec, *, - mult_proposal_score=False, + mult_proposal_score: bool=False, cls_score=None, - sync_caption_batch=False, - use_sigmoid_ce=False, - use_fed_loss=False, - ignore_zero_cats=False, - fed_loss_num_cat=50, - dynamic_classifier=False, - image_label_loss="", - use_zeroshot_cls=False, - image_loss_weight=0.1, - with_softmax_prop=False, - caption_weight=1.0, - neg_cap_weight=1.0, - add_image_box=False, - debug=False, - prior_prob=0.01, - cat_freq_path="", - fed_loss_freq_weight=0.5, - softmax_weak_loss=False, + sync_caption_batch: bool=False, + use_sigmoid_ce: bool=False, + use_fed_loss: bool=False, + ignore_zero_cats: bool=False, + fed_loss_num_cat: int=50, + dynamic_classifier: bool=False, + image_label_loss: str="", + use_zeroshot_cls: bool=False, + image_loss_weight: float=0.1, + with_softmax_prop: bool=False, + caption_weight: float=1.0, + neg_cap_weight: float=1.0, + add_image_box: bool=False, + debug: bool=False, + prior_prob: float=0.01, + cat_freq_path: str="", + fed_loss_freq_weight: float=0.5, + softmax_weak_loss: bool=False, **kwargs, - ): + ) -> None: super().__init__( input_shape=input_shape, **kwargs, @@ -147,7 +151,7 @@ def from_config(cls, cfg, input_shape): return ret def losses( - self, predictions, proposals, use_advanced_loss=True, classifier_info=(None, None, None) + self, predictions, proposals, use_advanced_loss: bool=True, classifier_info=(None, None, None) ): """ enable advanced loss @@ -247,7 +251,7 @@ def softmax_cross_entropy_loss(self, pred_class_logits, gt_classes): loss = F.cross_entropy(pred_class_logits, gt_classes, reduction="mean") return loss - def box_reg_loss(self, proposal_boxes, gt_boxes, pred_deltas, gt_classes, num_classes=-1): + def box_reg_loss(self, proposal_boxes, gt_boxes, pred_deltas, gt_classes, num_classes: int=-1): """ Allow custom background index """ @@ -287,7 +291,7 @@ def inference(self, predictions, proposals): scores = self.predict_probs(predictions, proposals) if self.mult_proposal_score: proposal_scores = [p.get("objectness_logits") for p in proposals] - scores = [(s * ps[:, None]) ** 0.5 for s, ps in zip(scores, proposal_scores)] + scores = [(s * ps[:, None]) ** 0.5 for s, ps in zip(scores, proposal_scores, strict=False)] image_shapes = [x.image_size for x in proposals] return fast_rcnn_inference( boxes, @@ -315,9 +319,9 @@ def image_label_losses( self, predictions, proposals, - image_labels, + image_labels: Sequence[str], classifier_info=(None, None, None), - ann_type="image", + ann_type: str="image", ): """ Inputs: @@ -341,7 +345,7 @@ def image_label_losses( loss = scores[0].new_zeros([1])[0] caption_loss = scores[0].new_zeros([1])[0] for idx, (score, labels, prop_score, p) in enumerate( - zip(scores, image_labels, prop_scores, proposals) + zip(scores, image_labels, prop_scores, proposals, strict=False) ): if score.shape[0] == 0: loss += score.new_zeros([1])[0] @@ -449,7 +453,7 @@ def forward(self, x, classifier_info=(None, None, None)): else: return scores, proposal_deltas - def _caption_loss(self, score, classifier_info, idx, B): + def _caption_loss(self, score, classifier_info, idx: int, B): assert classifier_info[2] is not None assert self.add_image_box cls_and_cap_num = score.shape[1] @@ -464,13 +468,7 @@ def _caption_loss(self, score, classifier_info, idx, B): # caption_target: 1 x MB rank = comm.get_rank() global_idx = B * rank + idx - assert (classifier_info[2][global_idx, -1] - rank) ** 2 < 1e-8, "{} {} {} {} {}".format( - rank, - global_idx, - classifier_info[2][global_idx, -1], - classifier_info[2].shape, - classifier_info[2][:, -1], - ) + assert (classifier_info[2][global_idx, -1] - rank) ** 2 < 1e-8, f"{rank} {global_idx} {classifier_info[2][global_idx, -1]} {classifier_info[2].shape} {classifier_info[2][:, -1]}" caption_target[:, global_idx] = 1.0 else: assert caption_score.shape[1] == B @@ -480,7 +478,7 @@ def _caption_loss(self, score, classifier_info, idx, B): ) if self.sync_caption_batch: fg_mask = (caption_target > 0.5).float() - assert (fg_mask.sum().item() - 1.0) ** 2 < 1e-8, "{} {}".format(fg_mask.shape, fg_mask) + assert (fg_mask.sum().item() - 1.0) ** 2 < 1e-8, f"{fg_mask.shape} {fg_mask}" pos_loss = (caption_loss_img * fg_mask).sum() neg_loss = (caption_loss_img * (1.0 - fg_mask)).sum() caption_loss_img = pos_loss + self.neg_cap_weight * neg_loss @@ -488,7 +486,7 @@ def _caption_loss(self, score, classifier_info, idx, B): caption_loss_img = caption_loss_img.sum() return score, caption_loss_img - def _wsddn_loss(self, score, prop_score, label): + def _wsddn_loss(self, score, prop_score, label: str): assert prop_score is not None loss = 0 final_score = score.sigmoid() * F.softmax(prop_score, dim=0) # B x (C + 1) @@ -499,7 +497,7 @@ def _wsddn_loss(self, score, prop_score, label): ind = final_score[:, label].argmax() return loss, ind - def _max_score_loss(self, score, label): + def _max_score_loss(self, score, label: str): loss = 0 target = score.new_zeros(score.shape[1]) target[label] = 1.0 @@ -507,7 +505,7 @@ def _max_score_loss(self, score, label): loss += F.binary_cross_entropy_with_logits(score[ind], target, reduction="sum") return loss, ind - def _min_loss_loss(self, score, label): + def _min_loss_loss(self, score, label: str): loss = 0 target = score.new_zeros(score.shape) target[:, label] = 1.0 @@ -517,7 +515,7 @@ def _min_loss_loss(self, score, label): loss += F.binary_cross_entropy_with_logits(score[ind], target[0], reduction="sum") return loss, ind - def _first_loss(self, score, label): + def _first_loss(self, score, label: str): loss = 0 target = score.new_zeros(score.shape[1]) target[label] = 1.0 @@ -525,7 +523,7 @@ def _first_loss(self, score, label): loss += F.binary_cross_entropy_with_logits(score[ind], target, reduction="sum") return loss, ind - def _image_loss(self, score, label): + def _image_loss(self, score, label: str): assert self.add_image_box target = score.new_zeros(score.shape[1]) target[label] = 1.0 @@ -533,7 +531,7 @@ def _image_loss(self, score, label): loss = F.binary_cross_entropy_with_logits(score[ind], target, reduction="sum") return loss, ind - def _max_size_loss(self, score, label, p): + def _max_size_loss(self, score, label: str, p): loss = 0 target = score.new_zeros(score.shape[1]) target[label] = 1.0 @@ -550,7 +548,7 @@ def _max_size_loss(self, score, label, p): return loss, ind -def put_label_distribution(storage, hist_name, hist_counts, num_classes): +def put_label_distribution(storage, hist_name: str, hist_counts, num_classes: int) -> None: """ """ ht_min, ht_max = 0, num_classes hist_edges = torch.linspace( diff --git a/dimos/models/Detic/detic/modeling/roi_heads/detic_roi_heads.py b/dimos/models/Detic/detic/modeling/roi_heads/detic_roi_heads.py index 8fa0e3f538..7e319453df 100644 --- a/dimos/models/Detic/detic/modeling/roi_heads/detic_roi_heads.py +++ b/dimos/models/Detic/detic/modeling/roi_heads/detic_roi_heads.py @@ -1,14 +1,15 @@ # Copyright (c) Facebook, Inc. and its affiliates. -import torch +from collections.abc import Sequence from detectron2.config import configurable -from detectron2.structures import Boxes, Instances -from detectron2.utils.events import get_event_storage - from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads, _ScaleGradient from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY -from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads, _ScaleGradient +from detectron2.structures import Boxes, Instances +from detectron2.utils.events import get_event_storage +import torch + from .detic_fast_rcnn import DeticFastRCNNOutputLayers @@ -27,7 +28,7 @@ def __init__( mask_weight: float = 1.0, one_class_per_proposal: bool = False, **kwargs, - ): + ) -> None: super().__init__(**kwargs) self.mult_proposal_score = mult_proposal_score self.with_image_labels = with_image_labels @@ -56,12 +57,12 @@ def from_config(cls, cfg, input_shape): return ret @classmethod - def _init_box_head(self, cfg, input_shape): + def _init_box_head(cls, cfg, input_shape): ret = super()._init_box_head(cfg, input_shape) del ret["box_predictors"] cascade_bbox_reg_weights = cfg.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS box_predictors = [] - for box_head, bbox_reg_weights in zip(ret["box_heads"], cascade_bbox_reg_weights): + for box_head, bbox_reg_weights in zip(ret["box_heads"], cascade_bbox_reg_weights, strict=False): box_predictors.append( DeticFastRCNNOutputLayers( cfg, @@ -73,7 +74,7 @@ def _init_box_head(self, cfg, input_shape): return ret def _forward_box( - self, features, proposals, targets=None, ann_type="box", classifier_info=(None, None, None) + self, features, proposals, targets=None, ann_type: str="box", classifier_info=(None, None, None) ): """ Add mult proposal scores at testing @@ -107,7 +108,7 @@ def _forward_box( losses = {} storage = get_event_storage() for stage, (predictor, predictions, proposals) in enumerate(head_outputs): - with storage.name_scope("stage{}".format(stage)): + with storage.name_scope(f"stage{stage}"): if ann_type != "box": stage_losses = {} if ann_type in ["image", "caption", "captiontag"]: @@ -128,17 +129,17 @@ def _forward_box( ) if self.with_image_labels: stage_losses["image_loss"] = predictions[0].new_zeros([1])[0] - losses.update({k + "_stage{}".format(stage): v for k, v in stage_losses.items()}) + losses.update({k + f"_stage{stage}": v for k, v in stage_losses.items()}) return losses else: # Each is a list[Tensor] of length #image. Each tensor is Ri x (K+1) scores_per_stage = [h[0].predict_probs(h[1], h[2]) for h in head_outputs] scores = [ sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages) - for scores_per_image in zip(*scores_per_stage) + for scores_per_image in zip(*scores_per_stage, strict=False) ] if self.mult_proposal_score: - scores = [(s * ps[:, None]) ** 0.5 for s, ps in zip(scores, proposal_scores)] + scores = [(s * ps[:, None]) ** 0.5 for s, ps in zip(scores, proposal_scores, strict=False)] if self.one_class_per_proposal: scores = [s * (s == s[:, :-1].max(dim=1)[0][:, None]).float() for s in scores] predictor, predictions, proposals = head_outputs[-1] @@ -159,7 +160,7 @@ def forward( features, proposals, targets=None, - ann_type="box", + ann_type: str="box", classifier_info=(None, None, None), ): """ @@ -225,13 +226,13 @@ def _get_empty_mask_loss(self, features, proposals, device): else: return {} - def _create_proposals_from_boxes(self, boxes, image_sizes, logits): + def _create_proposals_from_boxes(self, boxes, image_sizes: Sequence[int], logits): """ Add objectness_logits """ boxes = [Boxes(b.detach()) for b in boxes] proposals = [] - for boxes_per_image, image_size, logit in zip(boxes, image_sizes, logits): + for boxes_per_image, image_size, logit in zip(boxes, image_sizes, logits, strict=False): boxes_per_image.clip(image_size) if self.training: inds = boxes_per_image.nonempty() @@ -253,6 +254,6 @@ def _run_stage(self, features, proposals, stage, classifier_info=(None, None, No box_features = self.box_head[stage](box_features) if self.add_feature_to_prop: feats_per_image = box_features.split([len(p) for p in proposals], dim=0) - for feat, p in zip(feats_per_image, proposals): + for feat, p in zip(feats_per_image, proposals, strict=False): p.feat = feat return self.box_predictor[stage](box_features, classifier_info=classifier_info) diff --git a/dimos/models/Detic/detic/modeling/roi_heads/res5_roi_heads.py b/dimos/models/Detic/detic/modeling/roi_heads/res5_roi_heads.py index d05a5d0537..642f889b5d 100644 --- a/dimos/models/Detic/detic/modeling/roi_heads/res5_roi_heads.py +++ b/dimos/models/Detic/detic/modeling/roi_heads/res5_roi_heads.py @@ -1,20 +1,18 @@ # Copyright (c) Facebook, Inc. and its affiliates. -import torch - from detectron2.config import configurable from detectron2.layers import ShapeSpec -from detectron2.structures import Boxes, Instances - from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, Res5ROIHeads +from detectron2.structures import Boxes, Instances +import torch -from .detic_fast_rcnn import DeticFastRCNNOutputLayers from ..debug import debug_second_stage +from .detic_fast_rcnn import DeticFastRCNNOutputLayers @ROI_HEADS_REGISTRY.register() class CustomRes5ROIHeads(Res5ROIHeads): @configurable - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: cfg = kwargs.pop("cfg") super().__init__(**kwargs) stage_channel_factor = 2**3 @@ -54,7 +52,7 @@ def forward( features, proposals, targets=None, - ann_type="box", + ann_type: str="box", classifier_info=(None, None, None), ): """ @@ -82,7 +80,7 @@ def forward( feats_per_image = box_features.mean(dim=[2, 3]).split( [len(p) for p in proposals], dim=0 ) - for feat, p in zip(feats_per_image, proposals): + for feat, p in zip(feats_per_image, proposals, strict=False): p.feat = feat if self.training: @@ -102,7 +100,8 @@ def forward( assert "image_loss" not in losses losses["image_loss"] = predictions[0].new_zeros([1])[0] if self.save_debug: - denormalizer = lambda x: x * self.pixel_std + self.pixel_mean + def denormalizer(x): + return x * self.pixel_std + self.pixel_mean if ann_type != "box": image_labels = [x._pos_category_ids for x in targets] else: @@ -123,7 +122,8 @@ def forward( pred_instances, _ = self.box_predictor.inference(predictions, proposals) pred_instances = self.forward_with_given_boxes(features, pred_instances) if self.save_debug: - denormalizer = lambda x: x * self.pixel_std + self.pixel_mean + def denormalizer(x): + return x * self.pixel_std + self.pixel_mean debug_second_stage( [denormalizer(x.clone()) for x in images], pred_instances, @@ -146,7 +146,7 @@ def get_top_proposals(self, proposals): proposals[i] = self._add_image_box(p) return proposals - def _add_image_box(self, p, use_score=False): + def _add_image_box(self, p, use_score: bool=False): image_box = Instances(p.image_size) n = 1 h, w = p.image_size diff --git a/dimos/models/Detic/detic/modeling/roi_heads/zero_shot_classifier.py b/dimos/models/Detic/detic/modeling/roi_heads/zero_shot_classifier.py index 7dfe0d7097..d436e6be34 100644 --- a/dimos/models/Detic/detic/modeling/roi_heads/zero_shot_classifier.py +++ b/dimos/models/Detic/detic/modeling/roi_heads/zero_shot_classifier.py @@ -1,10 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. +from detectron2.config import configurable +from detectron2.layers import ShapeSpec import numpy as np import torch from torch import nn from torch.nn import functional as F -from detectron2.config import configurable -from detectron2.layers import ShapeSpec class ZeroShotClassifier(nn.Module): @@ -19,7 +19,7 @@ def __init__( use_bias: float = 0.0, norm_weight: bool = True, norm_temperature: float = 50.0, - ): + ) -> None: super().__init__() if isinstance(input_shape, int): # some backward compatibility input_shape = ShapeSpec(channels=input_shape) diff --git a/dimos/models/Detic/detic/modeling/text/text_encoder.py b/dimos/models/Detic/detic/modeling/text/text_encoder.py index ff58592bd8..335ca659de 100644 --- a/dimos/models/Detic/detic/modeling/text/text_encoder.py +++ b/dimos/models/Detic/detic/modeling/text/text_encoder.py @@ -2,12 +2,11 @@ # Modified by Xingyi Zhou # The original code is under MIT license # Copyright (c) Facebook, Inc. and its affiliates. -from typing import Union, List from collections import OrderedDict -import torch -from torch import nn from clip.simple_tokenizer import SimpleTokenizer as _Tokenizer +import torch +from torch import nn __all__ = ["tokenize"] @@ -29,7 +28,7 @@ def forward(self, x: torch.Tensor): class ResidualAttentionBlock(nn.Module): - def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): + def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None) -> None: super().__init__() self.attn = nn.MultiheadAttention(d_model, n_head) @@ -61,7 +60,7 @@ def forward(self, x: torch.Tensor): class Transformer(nn.Module): - def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): + def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None) -> None: super().__init__() self.width = width self.layers = layers @@ -76,14 +75,14 @@ def forward(self, x: torch.Tensor): class CLIPTEXT(nn.Module): def __init__( self, - embed_dim=512, + embed_dim: int=512, # text - context_length=77, - vocab_size=49408, - transformer_width=512, - transformer_heads=8, - transformer_layers=12, - ): + context_length: int=77, + vocab_size: int=49408, + transformer_width: int=512, + transformer_heads: int=8, + transformer_layers: int=12, + ) -> None: super().__init__() self._tokenizer = _Tokenizer() @@ -108,7 +107,7 @@ def __init__( self.initialize_parameters() - def initialize_parameters(self): + def initialize_parameters(self) -> None: nn.init.normal_(self.token_embedding.weight, std=0.02) nn.init.normal_(self.positional_embedding, std=0.01) @@ -140,14 +139,14 @@ def device(self): def dtype(self): return self.text_projection.dtype - def tokenize(self, texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: + def tokenize(self, texts: str | list[str], context_length: int = 77) -> torch.LongTensor: """ """ if isinstance(texts, str): texts = [texts] sot_token = self._tokenizer.encoder["<|startoftext|>"] eot_token = self._tokenizer.encoder["<|endoftext|>"] - all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts] + all_tokens = [[sot_token, *self._tokenizer.encode(text), eot_token] for text in texts] result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) for i, tokens in enumerate(all_tokens): @@ -159,7 +158,7 @@ def tokenize(self, texts: Union[str, List[str]], context_length: int = 77) -> to return result - def encode_text(self, text): + def encode_text(self, text: str): x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] x = x + self.positional_embedding.type(self.dtype) x = x.permute(1, 0, 2) # NLD -> LND @@ -179,7 +178,7 @@ def forward(self, captions): return features -def build_text_encoder(pretrain=True): +def build_text_encoder(pretrain: bool=True): text_encoder = CLIPTEXT() if pretrain: import clip diff --git a/dimos/models/Detic/detic/modeling/utils.py b/dimos/models/Detic/detic/modeling/utils.py index a028e9246d..f24a0699a1 100644 --- a/dimos/models/Detic/detic/modeling/utils.py +++ b/dimos/models/Detic/detic/modeling/utils.py @@ -1,18 +1,19 @@ # Copyright (c) Facebook, Inc. and its affiliates. -import torch import json + import numpy as np +import torch from torch.nn import functional as F -def load_class_freq(path="datasets/metadata/lvis_v1_train_cat_info.json", freq_weight=1.0): - cat_info = json.load(open(path, "r")) +def load_class_freq(path: str="datasets/metadata/lvis_v1_train_cat_info.json", freq_weight: float=1.0): + cat_info = json.load(open(path)) cat_info = torch.tensor([c["image_count"] for c in sorted(cat_info, key=lambda x: x["id"])]) freq_weight = cat_info.float() ** freq_weight return freq_weight -def get_fed_loss_inds(gt_classes, num_sample_cats, C, weight=None): +def get_fed_loss_inds(gt_classes, num_sample_cats: int, C, weight=None): appeared = torch.unique(gt_classes) # C' prob = appeared.new_ones(C + 1).float() prob[-1] = 0 @@ -25,7 +26,7 @@ def get_fed_loss_inds(gt_classes, num_sample_cats, C, weight=None): return appeared -def reset_cls_test(model, cls_path, num_classes): +def reset_cls_test(model, cls_path, num_classes: int) -> None: model.roi_heads.num_classes = num_classes if type(cls_path) == str: print("Resetting zs_weight", cls_path) diff --git a/dimos/models/Detic/detic/predictor.py b/dimos/models/Detic/detic/predictor.py index 9985c2d854..a85941e25a 100644 --- a/dimos/models/Detic/detic/predictor.py +++ b/dimos/models/Detic/detic/predictor.py @@ -1,20 +1,20 @@ # Copyright (c) Facebook, Inc. and its affiliates. import atexit import bisect -import multiprocessing as mp from collections import deque -import cv2 -import torch +import multiprocessing as mp +import cv2 from detectron2.data import MetadataCatalog from detectron2.engine.defaults import DefaultPredictor from detectron2.utils.video_visualizer import VideoVisualizer from detectron2.utils.visualizer import ColorMode, Visualizer +import torch from .modeling.utils import reset_cls_test -def get_clip_embeddings(vocabulary, prompt="a "): +def get_clip_embeddings(vocabulary, prompt: str="a "): from detic.modeling.text.text_encoder import build_text_encoder text_encoder = build_text_encoder(pretrain=True) @@ -39,8 +39,8 @@ def get_clip_embeddings(vocabulary, prompt="a "): } -class VisualizationDemo(object): - def __init__(self, cfg, args, instance_mode=ColorMode.IMAGE, parallel=False): +class VisualizationDemo: + def __init__(self, cfg, args, instance_mode=ColorMode.IMAGE, parallel: bool=False) -> None: """ Args: cfg (CfgNode): @@ -174,13 +174,13 @@ class _StopToken: pass class _PredictWorker(mp.Process): - def __init__(self, cfg, task_queue, result_queue): + def __init__(self, cfg, task_queue, result_queue) -> None: self.cfg = cfg self.task_queue = task_queue self.result_queue = result_queue super().__init__() - def run(self): + def run(self) -> None: predictor = DefaultPredictor(self.cfg) while True: @@ -191,7 +191,7 @@ def run(self): result = predictor(data) self.result_queue.put((idx, result)) - def __init__(self, cfg, num_gpus: int = 1): + def __init__(self, cfg, num_gpus: int = 1) -> None: """ Args: cfg (CfgNode): @@ -204,7 +204,7 @@ def __init__(self, cfg, num_gpus: int = 1): for gpuid in range(max(num_gpus, 1)): cfg = cfg.clone() cfg.defrost() - cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu" + cfg.MODEL.DEVICE = f"cuda:{gpuid}" if num_gpus > 0 else "cpu" self.procs.append( AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue) ) @@ -218,7 +218,7 @@ def __init__(self, cfg, num_gpus: int = 1): p.start() atexit.register(self.shutdown) - def put(self, image): + def put(self, image) -> None: self.put_idx += 1 self.task_queue.put((self.put_idx, image)) @@ -238,14 +238,14 @@ def get(self): self.result_rank.insert(insert, idx) self.result_data.insert(insert, res) - def __len__(self): + def __len__(self) -> int: return self.put_idx - self.get_idx def __call__(self, image): self.put(image) return self.get() - def shutdown(self): + def shutdown(self) -> None: for _ in self.procs: self.task_queue.put(AsyncPredictor._StopToken()) diff --git a/dimos/models/Detic/lazy_train_net.py b/dimos/models/Detic/lazy_train_net.py index d6c4e7e841..3525a1f63a 100644 --- a/dimos/models/Detic/lazy_train_net.py +++ b/dimos/models/Detic/lazy_train_net.py @@ -42,7 +42,7 @@ def do_test(cfg, model): return ret -def do_train(args, cfg): +def do_train(args, cfg) -> None: """ Args: cfg: an object with the following attributes: @@ -63,7 +63,7 @@ def do_train(args, cfg): """ model = instantiate(cfg.model) logger = logging.getLogger("detectron2") - logger.info("Model:\n{}".format(model)) + logger.info(f"Model:\n{model}") model.to(cfg.train.device) cfg.optimizer.params.model = model @@ -105,7 +105,7 @@ def do_train(args, cfg): trainer.train(start_iter, cfg.train.max_iter) -def main(args): +def main(args) -> None: cfg = LazyConfig.load(args.config_file) cfg = LazyConfig.apply_overrides(cfg, args.opts) default_setup(cfg, args) diff --git a/dimos/models/Detic/predict.py b/dimos/models/Detic/predict.py index 4091bec3b9..bf71d007a1 100644 --- a/dimos/models/Detic/predict.py +++ b/dimos/models/Detic/predict.py @@ -1,26 +1,27 @@ +from pathlib import Path import sys -import cv2 import tempfile -from pathlib import Path -import cog import time +import cog +import cv2 +from detectron2.config import get_cfg +from detectron2.data import MetadataCatalog + # import some common detectron2 utilities from detectron2.engine import DefaultPredictor -from detectron2.config import get_cfg from detectron2.utils.visualizer import Visualizer -from detectron2.data import MetadataCatalog # Detic libraries sys.path.insert(0, "third_party/CenterNet2/") from centernet.config import add_centernet_config from detic.config import add_detic_config -from detic.modeling.utils import reset_cls_test from detic.modeling.text.text_encoder import build_text_encoder +from detic.modeling.utils import reset_cls_test class Predictor(cog.Predictor): - def setup(self): + def setup(self) -> None: cfg = get_cfg() add_centernet_config(cfg) add_detic_config(cfg) @@ -93,7 +94,7 @@ def predict(self, image, vocabulary, custom_vocabulary): return out_path -def get_clip_embeddings(vocabulary, prompt="a "): +def get_clip_embeddings(vocabulary, prompt: str="a "): text_encoder = build_text_encoder(pretrain=True) text_encoder.eval() texts = [prompt + x for x in vocabulary] diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/__init__.py b/dimos/models/Detic/third_party/CenterNet2/centernet/__init__.py index e17db317d9..5e2e7afac6 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/__init__.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/__init__.py @@ -1,14 +1,12 @@ -from .modeling.meta_arch.centernet_detector import CenterNetDetector -from .modeling.dense_heads.centernet import CenterNet -from .modeling.roi_heads.custom_roi_heads import CustomROIHeads, CustomCascadeROIHeads - -from .modeling.backbone.fpn_p5 import build_p67_resnet_fpn_backbone -from .modeling.backbone.dla import build_dla_backbone -from .modeling.backbone.dlafpn import build_dla_fpn3_backbone +from .data.datasets import nuimages +from .data.datasets.coco import _PREDEFINED_SPLITS_COCO +from .data.datasets.objects365 import categories_v1 from .modeling.backbone.bifpn import build_resnet_bifpn_backbone from .modeling.backbone.bifpn_fcos import build_fcos_resnet_bifpn_backbone +from .modeling.backbone.dla import build_dla_backbone +from .modeling.backbone.dlafpn import build_dla_fpn3_backbone +from .modeling.backbone.fpn_p5 import build_p67_resnet_fpn_backbone from .modeling.backbone.res2net import build_p67_res2net_fpn_backbone - -from .data.datasets.objects365 import categories_v1 -from .data.datasets.coco import _PREDEFINED_SPLITS_COCO -from .data.datasets import nuimages +from .modeling.dense_heads.centernet import CenterNet +from .modeling.meta_arch.centernet_detector import CenterNetDetector +from .modeling.roi_heads.custom_roi_heads import CustomCascadeROIHeads, CustomROIHeads diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/config.py b/dimos/models/Detic/third_party/CenterNet2/centernet/config.py index 3ff5c725c9..255eb36340 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/config.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/config.py @@ -1,7 +1,7 @@ from detectron2.config import CfgNode as CN -def add_centernet_config(cfg): +def add_centernet_config(cfg) -> None: _C = cfg _C.MODEL.CENTERNET = CN() diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/data/custom_build_augmentation.py b/dimos/models/Detic/third_party/CenterNet2/centernet/data/custom_build_augmentation.py index 72e399fa40..1bcb7cee66 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/data/custom_build_augmentation.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/data/custom_build_augmentation.py @@ -1,8 +1,9 @@ from detectron2.data import transforms as T + from .transforms.custom_augmentation_impl import EfficientDetResizeCrop -def build_custom_augmentation(cfg, is_train): +def build_custom_augmentation(cfg, is_train: bool): """ Create a list of default :class:`Augmentation` from config. Now it includes resizing and flipping. diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/data/custom_dataset_dataloader.py b/dimos/models/Detic/third_party/CenterNet2/centernet/data/custom_dataset_dataloader.py index b8776789cf..4e23e565a4 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/data/custom_dataset_dataloader.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/data/custom_dataset_dataloader.py @@ -1,21 +1,24 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +from collections import defaultdict +from collections.abc import Iterator, Sequence +import itertools import logging -import torch -import torch.utils.data -from torch.utils.data.sampler import Sampler +from detectron2.data.build import ( + build_batch_data_loader, + check_metadata_consistency, + filter_images_with_few_keypoints, + filter_images_with_only_crowd_annotations, + get_detection_dataset_dicts, + print_instances_class_histogram, +) +from detectron2.data.catalog import DatasetCatalog, MetadataCatalog from detectron2.data.common import DatasetFromList, MapDataset -from detectron2.data.build import get_detection_dataset_dicts, build_batch_data_loader -from detectron2.data.samplers import TrainingSampler, RepeatFactorTrainingSampler -from detectron2.data.build import print_instances_class_histogram -from detectron2.data.build import filter_images_with_only_crowd_annotations -from detectron2.data.build import filter_images_with_few_keypoints -from detectron2.data.build import check_metadata_consistency -from detectron2.data.catalog import MetadataCatalog, DatasetCatalog +from detectron2.data.samplers import RepeatFactorTrainingSampler, TrainingSampler from detectron2.utils import comm -import itertools -from collections import defaultdict -from typing import Optional +import torch +import torch.utils.data +from torch.utils.data.sampler import Sampler # from .custom_build_augmentation import build_custom_augmentation @@ -57,7 +60,7 @@ def build_custom_train_loader(cfg, mapper=None): sampler_name = cfg.DATALOADER.SAMPLER_TRAIN logger = logging.getLogger(__name__) - logger.info("Using training sampler {}".format(sampler_name)) + logger.info(f"Using training sampler {sampler_name}") # TODO avoid if-else? if sampler_name == "TrainingSampler": sampler = TrainingSampler(len(dataset)) @@ -72,7 +75,7 @@ def build_custom_train_loader(cfg, mapper=None): elif sampler_name == "ClassAwareSampler": sampler = ClassAwareSampler(dataset_dicts) else: - raise ValueError("Unknown training sampler: {}".format(sampler_name)) + raise ValueError(f"Unknown training sampler: {sampler_name}") return build_batch_data_loader( dataset, @@ -84,7 +87,7 @@ def build_custom_train_loader(cfg, mapper=None): class ClassAwareSampler(Sampler): - def __init__(self, dataset_dicts, seed: Optional[int] = None): + def __init__(self, dataset_dicts, seed: int | None = None) -> None: """ Args: size (int): the total number of data of the underlying dataset to sample from @@ -102,7 +105,7 @@ def __init__(self, dataset_dicts, seed: Optional[int] = None): self._world_size = comm.get_world_size() self.weights = self._get_class_balance_factor(dataset_dicts) - def __iter__(self): + def __iter__(self) -> Iterator: start = self._rank yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) @@ -113,7 +116,7 @@ def _infinite_indices(self): ids = torch.multinomial(self.weights, self._size, generator=g, replacement=True) yield from ids - def _get_class_balance_factor(self, dataset_dicts, l=1.0): + def _get_class_balance_factor(self, dataset_dicts, l: float=1.0): # 1. For each category c, compute the fraction of images that contain it: f(c) ret = [] category_freq = defaultdict(int) @@ -121,22 +124,22 @@ def _get_class_balance_factor(self, dataset_dicts, l=1.0): cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]} for cat_id in cat_ids: category_freq[cat_id] += 1 - for i, dataset_dict in enumerate(dataset_dicts): + for _i, dataset_dict in enumerate(dataset_dicts): cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]} ret.append(sum([1.0 / (category_freq[cat_id] ** l) for cat_id in cat_ids])) return torch.tensor(ret).float() def get_detection_dataset_dicts_with_source( - dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None + dataset_names: Sequence[str], filter_empty: bool=True, min_keypoints: int=0, proposal_files=None ): assert len(dataset_names) dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names] - for dataset_name, dicts in zip(dataset_names, dataset_dicts): - assert len(dicts), "Dataset '{}' is empty!".format(dataset_name) + for dataset_name, dicts in zip(dataset_names, dataset_dicts, strict=False): + assert len(dicts), f"Dataset '{dataset_name}' is empty!" - for source_id, (dataset_name, dicts) in enumerate(zip(dataset_names, dataset_dicts)): - assert len(dicts), "Dataset '{}' is empty!".format(dataset_name) + for source_id, (dataset_name, dicts) in enumerate(zip(dataset_names, dataset_dicts, strict=False)): + assert len(dicts), f"Dataset '{dataset_name}' is empty!" for d in dicts: d["dataset_source"] = source_id @@ -162,7 +165,7 @@ def get_detection_dataset_dicts_with_source( class MultiDatasetSampler(Sampler): - def __init__(self, cfg, sizes, dataset_dicts, seed: Optional[int] = None): + def __init__(self, cfg, sizes: Sequence[int], dataset_dicts, seed: int | None = None) -> None: """ Args: size (int): the total number of data of the underlying dataset to sample from @@ -174,9 +177,7 @@ def __init__(self, cfg, sizes, dataset_dicts, seed: Optional[int] = None): dataset_ratio = cfg.DATALOADER.DATASET_RATIO self._batch_size = cfg.SOLVER.IMS_PER_BATCH assert len(dataset_ratio) == len(sizes), ( - "length of dataset ratio {} should be equal to number if dataset {}".format( - len(dataset_ratio), len(sizes) - ) + f"length of dataset ratio {len(dataset_ratio)} should be equal to number if dataset {len(sizes)}" ) if seed is None: seed = comm.shared_random_seed() @@ -191,13 +192,13 @@ def __init__(self, cfg, sizes, dataset_dicts, seed: Optional[int] = None): dataset_weight = [ torch.ones(s) * max(sizes) / s * r / sum(dataset_ratio) - for i, (r, s) in enumerate(zip(dataset_ratio, sizes)) + for i, (r, s) in enumerate(zip(dataset_ratio, sizes, strict=False)) ] dataset_weight = torch.cat(dataset_weight) self.weights = dataset_weight self.sample_epoch_size = len(self.weights) - def __iter__(self): + def __iter__(self) -> Iterator: start = self._rank yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/coco.py b/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/coco.py index 93f0a13428..5825c40af0 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/coco.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/coco.py @@ -1,12 +1,12 @@ import os -from detectron2.data.datasets.register_coco import register_coco_instances -from detectron2.data.datasets.coco import load_coco_json -from detectron2.data.datasets.builtin_meta import _get_builtin_metadata from detectron2.data import DatasetCatalog, MetadataCatalog +from detectron2.data.datasets.builtin_meta import _get_builtin_metadata +from detectron2.data.datasets.coco import load_coco_json +from detectron2.data.datasets.register_coco import register_coco_instances -def register_distill_coco_instances(name, metadata, json_file, image_root): +def register_distill_coco_instances(name: str, metadata, json_file, image_root) -> None: """ add extra_annotation_keys """ diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/nuimages.py b/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/nuimages.py index 22b80828c0..fdcd40242f 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/nuimages.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/nuimages.py @@ -1,6 +1,7 @@ -from detectron2.data.datasets.register_coco import register_coco_instances import os +from detectron2.data.datasets.register_coco import register_coco_instances + categories = [ {"id": 0, "name": "car"}, {"id": 1, "name": "truck"}, diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/objects365.py b/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/objects365.py index 22a017444f..e3e8383a91 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/objects365.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/data/datasets/objects365.py @@ -1,6 +1,7 @@ -from detectron2.data.datasets.register_coco import register_coco_instances import os +from detectron2.data.datasets.register_coco import register_coco_instances + categories_v1 = [ {"id": 164, "name": "cutting/chopping board"}, {"id": 49, "name": "tie"}, diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/data/transforms/custom_augmentation_impl.py b/dimos/models/Detic/third_party/CenterNet2/centernet/data/transforms/custom_augmentation_impl.py index cc6f2ccc9f..f4ec0ad07f 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/data/transforms/custom_augmentation_impl.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/data/transforms/custom_augmentation_impl.py @@ -1,14 +1,13 @@ -# -*- coding: utf-8 -*- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Modified by Xingyi Zhou """ Implement many useful :class:`Augmentation`. """ +from detectron2.data.transforms.augmentation import Augmentation import numpy as np from PIL import Image -from detectron2.data.transforms.augmentation import Augmentation from .custom_transform import EfficientDetResizeCropTransform __all__ = [ @@ -22,7 +21,7 @@ class EfficientDetResizeCrop(Augmentation): If `max_size` is reached, then downscale so that the longer edge does not exceed max_size. """ - def __init__(self, size, scale, interp=Image.BILINEAR): + def __init__(self, size: int, scale, interp=Image.BILINEAR) -> None: """ Args: """ diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/data/transforms/custom_transform.py b/dimos/models/Detic/third_party/CenterNet2/centernet/data/transforms/custom_transform.py index bd0ce13dc0..6635a5999b 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/data/transforms/custom_transform.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/data/transforms/custom_transform.py @@ -1,18 +1,17 @@ -# -*- coding: utf-8 -*- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Modified by Xingyi Zhou # File: transform.py -import numpy as np -import torch -import torch.nn.functional as F from fvcore.transforms.transform import ( Transform, ) +import numpy as np from PIL import Image +import torch +import torch.nn.functional as F try: - import cv2 # noqa + import cv2 except ImportError: # OpenCV is an optional dependency at the moment pass @@ -25,7 +24,7 @@ class EfficientDetResizeCropTransform(Transform): """ """ - def __init__(self, scaled_h, scaled_w, offset_y, offset_x, img_scale, target_size, interp=None): + def __init__(self, scaled_h, scaled_w, offset_y, offset_x, img_scale, target_size: int, interp=None) -> None: """ Args: h, w (int): original image size diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/bifpn.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/bifpn.py index dd66c1f0c3..733b502da4 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/bifpn.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/bifpn.py @@ -1,20 +1,20 @@ # Modified from https://github.com/rwightman/efficientdet-pytorch/blob/master/effdet/efficientdet.py # The original file is under Apache-2.0 License -import math from collections import OrderedDict +import math +from detectron2.layers import Conv2d, ShapeSpec +from detectron2.layers.batch_norm import get_norm +from detectron2.modeling.backbone import Backbone +from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +from detectron2.modeling.backbone.resnet import build_resnet_backbone import torch from torch import nn -from detectron2.layers import ShapeSpec, Conv2d -from detectron2.modeling.backbone.resnet import build_resnet_backbone -from detectron2.modeling.backbone.build import BACKBONE_REGISTRY -from detectron2.layers.batch_norm import get_norm -from detectron2.modeling.backbone import Backbone from .dlafpn import dla34 -def get_fpn_config(base_reduction=8): +def get_fpn_config(base_reduction: int=8): """BiFPN config with sum.""" p = { "nodes": [ @@ -38,8 +38,8 @@ def swish(x, inplace: bool = False): class Swish(nn.Module): - def __init__(self, inplace: bool = False): - super(Swish, self).__init__() + def __init__(self, inplace: bool = False) -> None: + super().__init__() self.inplace = inplace def forward(self, x): @@ -47,8 +47,8 @@ def forward(self, x): class SequentialAppend(nn.Sequential): - def __init__(self, *args): - super(SequentialAppend, self).__init__(*args) + def __init__(self, *args) -> None: + super().__init__(*args) def forward(self, x): for module in self: @@ -57,8 +57,8 @@ def forward(self, x): class SequentialAppendLast(nn.Sequential): - def __init__(self, *args): - super(SequentialAppendLast, self).__init__(*args) + def __init__(self, *args) -> None: + super().__init__(*args) # def forward(self, x: List[torch.Tensor]): def forward(self, x): @@ -72,15 +72,15 @@ def __init__( self, in_channels, out_channels, - kernel_size, - stride=1, - dilation=1, - padding="", - bias=False, - norm="", + kernel_size: int, + stride: int=1, + dilation: int=1, + padding: str="", + bias: bool=False, + norm: str="", act_layer=Swish, - ): - super(ConvBnAct2d, self).__init__() + ) -> None: + super().__init__() # self.conv = create_conv2d( # in_channels, out_channels, kernel_size, stride=stride, dilation=dilation, padding=padding, bias=bias) self.conv = Conv2d( @@ -110,17 +110,17 @@ def __init__( self, in_channels, out_channels, - kernel_size=3, - stride=1, - dilation=1, - padding="", - bias=False, - channel_multiplier=1.0, - pw_kernel_size=1, + kernel_size: int=3, + stride: int=1, + dilation: int=1, + padding: str="", + bias: bool=False, + channel_multiplier: float=1.0, + pw_kernel_size: int=1, act_layer=Swish, - norm="", - ): - super(SeparableConv2d, self).__init__() + norm: str="", + ) -> None: + super().__init__() # self.conv_dw = create_conv2d( # in_channels, int(in_channels * channel_multiplier), kernel_size, @@ -166,15 +166,15 @@ def __init__( self, in_channels, out_channels, - reduction_ratio=1.0, - pad_type="", - pooling_type="max", - norm="", - apply_bn=False, - conv_after_downsample=False, - redundant_bias=False, - ): - super(ResampleFeatureMap, self).__init__() + reduction_ratio: float=1.0, + pad_type: str="", + pooling_type: str="max", + norm: str="", + apply_bn: bool=False, + conv_after_downsample: bool=False, + redundant_bias: bool=False, + ) -> None: + super().__init__() pooling_type = pooling_type or "max" self.in_channels = in_channels self.out_channels = out_channels @@ -222,20 +222,20 @@ def __init__( fpn_channels, inputs_offsets, target_reduction, - pad_type="", - pooling_type="max", - norm="", - apply_bn_for_resampling=False, - conv_after_downsample=False, - redundant_bias=False, - weight_method="attn", - ): - super(FpnCombine, self).__init__() + pad_type: str="", + pooling_type: str="max", + norm: str="", + apply_bn_for_resampling: bool=False, + conv_after_downsample: bool=False, + redundant_bias: bool=False, + weight_method: str="attn", + ) -> None: + super().__init__() self.inputs_offsets = inputs_offsets self.weight_method = weight_method self.resample = nn.ModuleDict() - for idx, offset in enumerate(inputs_offsets): + for _idx, offset in enumerate(inputs_offsets): in_channels = fpn_channels if offset < len(feature_info): in_channels = feature_info[offset]["num_chs"] @@ -284,7 +284,7 @@ def forward(self, x): elif self.weight_method == "sum": x = torch.stack(nodes, dim=-1) else: - raise ValueError("unknown weight_method {}".format(self.weight_method)) + raise ValueError(f"unknown weight_method {self.weight_method}") x = torch.sum(x, dim=-1) return x @@ -295,18 +295,18 @@ def __init__( feature_info, fpn_config, fpn_channels, - num_levels=5, - pad_type="", - pooling_type="max", - norm="", + num_levels: int=5, + pad_type: str="", + pooling_type: str="max", + norm: str="", act_layer=Swish, - apply_bn_for_resampling=False, - conv_after_downsample=True, - conv_bn_relu_pattern=False, - separable_conv=True, - redundant_bias=False, - ): - super(BiFpnLayer, self).__init__() + apply_bn_for_resampling: bool=False, + conv_after_downsample: bool=True, + conv_bn_relu_pattern: bool=False, + separable_conv: bool=True, + redundant_bias: bool=False, + ) -> None: + super().__init__() self.fpn_config = fpn_config self.num_levels = num_levels self.conv_bn_relu_pattern = False @@ -375,12 +375,12 @@ def __init__( bottom_up, in_features, out_channels, - norm="", - num_levels=5, - num_bifpn=4, - separable_conv=False, - ): - super(BiFPN, self).__init__() + norm: str="", + num_levels: int=5, + num_bifpn: int=4, + separable_conv: bool=False, + ) -> None: + super().__init__() assert isinstance(bottom_up, Backbone) # Feature map strides and channels from the bottom up network (e.g. ResNet) @@ -394,11 +394,11 @@ def __init__( self.in_features = in_features self._size_divisibility = 128 levels = [int(math.log2(s)) for s in in_strides] - self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in in_strides} + self._out_feature_strides = {f"p{int(math.log2(s))}": s for s in in_strides} if len(in_features) < num_levels: for l in range(num_levels - len(in_features)): s = l + levels[-1] - self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1) + self._out_feature_strides[f"p{s + 1}"] = 2 ** (s + 1) self._out_features = list(sorted(self._out_feature_strides.keys())) self._out_feature_channels = {k: out_channels for k in self._out_features} @@ -470,10 +470,10 @@ def forward(self, x): x = [bottom_up_features[f] for f in self.in_features] assert len(self.resample) == self.num_levels - len(x) x = self.resample(x) - shapes = [xx.shape for xx in x] + [xx.shape for xx in x] # print('resample shapes', shapes) x = self.cell(x) - out = {f: xx for f, xx in zip(self._out_features, x)} + out = {f: xx for f, xx in zip(self._out_features, x, strict=False)} # import pdb; pdb.set_trace() return out diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/bifpn_fcos.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/bifpn_fcos.py index 67c7b67b9e..bdfba2d05b 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/bifpn_fcos.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/bifpn_fcos.py @@ -1,13 +1,14 @@ # This file is modified from https://github.com/aim-uofa/AdelaiDet/blob/master/adet/modeling/backbone/bifpn.py # The original file is under 2-clause BSD License for academic use, and *non-commercial use*. -import torch -import torch.nn.functional as F -from torch import nn +from collections.abc import Sequence from detectron2.layers import Conv2d, ShapeSpec, get_norm - -from detectron2.modeling.backbone import Backbone, build_resnet_backbone from detectron2.modeling import BACKBONE_REGISTRY +from detectron2.modeling.backbone import Backbone, build_resnet_backbone +import torch +from torch import nn +import torch.nn.functional as F + from .dlafpn import dla34 __all__ = [] @@ -17,7 +18,7 @@ def swish(x): return x * x.sigmoid() -def split_name(name): +def split_name(name: str): for i, c in enumerate(name): if not c.isalpha(): return name[:i], int(name[i:]) @@ -25,8 +26,8 @@ def split_name(name): class FeatureMapResampler(nn.Module): - def __init__(self, in_channels, out_channels, stride, norm=""): - super(FeatureMapResampler, self).__init__() + def __init__(self, in_channels, out_channels, stride: int, norm: str="") -> None: + super().__init__() if in_channels != out_channels: self.reduction = Conv2d( in_channels, @@ -56,8 +57,8 @@ def forward(self, x): class BackboneWithTopLevels(Backbone): - def __init__(self, backbone, out_channels, num_top_levels, norm=""): - super(BackboneWithTopLevels, self).__init__() + def __init__(self, backbone, out_channels, num_top_levels: int, norm: str="") -> None: + super().__init__() self.backbone = backbone backbone_output_shape = backbone.output_shape() @@ -107,7 +108,7 @@ class SingleBiFPN(Backbone): It creates pyramid features built on top of some input feature maps. """ - def __init__(self, in_channels_list, out_channels, norm=""): + def __init__(self, in_channels_list, out_channels, norm: str="") -> None: """ Args: bottom_up (Backbone): module representing the bottom up subnetwork. @@ -121,7 +122,7 @@ def __init__(self, in_channels_list, out_channels, norm=""): out_channels (int): number of channels in the output feature maps. norm (str): the normalization to use. """ - super(SingleBiFPN, self).__init__() + super().__init__() self.out_channels = out_channels # build 5-levels bifpn @@ -161,12 +162,12 @@ def __init__(self, in_channels_list, out_channels, norm=""): lateral_conv = Conv2d( in_channels, out_channels, kernel_size=1, norm=get_norm(norm, out_channels) ) - self.add_module("lateral_{}_f{}".format(input_offset, feat_level), lateral_conv) + self.add_module(f"lateral_{input_offset}_f{feat_level}", lateral_conv) node_info.append(out_channels) num_output_connections.append(0) # generate attention weights - name = "weights_f{}_{}".format(feat_level, inputs_offsets_str) + name = f"weights_f{feat_level}_{inputs_offsets_str}" self.__setattr__( name, nn.Parameter( @@ -175,7 +176,7 @@ def __init__(self, in_channels_list, out_channels, norm=""): ) # generate convolutions after combination - name = "outputs_f{}_{}".format(feat_level, inputs_offsets_str) + name = f"outputs_f{feat_level}_{inputs_offsets_str}" self.add_module( name, Conv2d( @@ -215,7 +216,7 @@ def forward(self, feats): # reduction if input_node.size(1) != self.out_channels: - name = "lateral_{}_f{}".format(input_offset, feat_level) + name = f"lateral_{input_offset}_f{feat_level}" input_node = self.__getattr__(name)(input_node) # maybe downsample @@ -240,7 +241,7 @@ def forward(self, feats): input_nodes.append(input_node) # attention - name = "weights_f{}_{}".format(feat_level, inputs_offsets_str) + name = f"weights_f{feat_level}_{inputs_offsets_str}" weights = F.relu(self.__getattr__(name)) norm_weights = weights / (weights.sum() + 0.0001) @@ -248,7 +249,7 @@ def forward(self, feats): new_node = (norm_weights * new_node).sum(dim=-1) new_node = swish(new_node) - name = "outputs_f{}_{}".format(feat_level, inputs_offsets_str) + name = f"outputs_f{feat_level}_{inputs_offsets_str}" feats.append(self.__getattr__(name)(new_node)) num_output_connections.append(0) @@ -270,7 +271,7 @@ class BiFPN(Backbone): It creates pyramid features built on top of some input feature maps. """ - def __init__(self, bottom_up, in_features, out_channels, num_top_levels, num_repeats, norm=""): + def __init__(self, bottom_up, in_features, out_channels, num_top_levels: int, num_repeats: int, norm: str="") -> None: """ Args: bottom_up (Backbone): module representing the bottom up subnetwork. @@ -286,7 +287,7 @@ def __init__(self, bottom_up, in_features, out_channels, num_top_levels, num_rep num_repeats (int): the number of repeats of BiFPN. norm (str): the normalization to use. """ - super(BiFPN, self).__init__() + super().__init__() assert isinstance(bottom_up, Backbone) # add extra feature levels (i.e., 6 and 7) @@ -305,10 +306,10 @@ def __init__(self, bottom_up, in_features, out_channels, num_top_levels, num_rep self.in_features = in_features # generate output features - self._out_features = ["p{}".format(split_name(name)[1]) for name in in_features] + self._out_features = [f"p{split_name(name)[1]}" for name in in_features] self._out_feature_strides = { out_name: bottom_up_output_shapes[in_name].stride - for out_name, in_name in zip(self._out_features, in_features) + for out_name, in_name in zip(self._out_features, in_features, strict=False) } self._out_feature_channels = {k: out_channels for k in self._out_features} @@ -343,17 +344,15 @@ def forward(self, x): for bifpn in self.repeated_bifpn: feats = bifpn(feats) - return dict(zip(self._out_features, feats)) + return dict(zip(self._out_features, feats, strict=False)) -def _assert_strides_are_log2_contiguous(strides): +def _assert_strides_are_log2_contiguous(strides: Sequence[int]) -> None: """ Assert that each stride is 2x times its preceding stride, i.e. "contiguous in log2". """ for i, stride in enumerate(strides[1:], 1): - assert stride == 2 * strides[i - 1], "Strides {} {} are not log2 contiguous".format( - stride, strides[i - 1] - ) + assert stride == 2 * strides[i - 1], f"Strides {stride} {strides[i - 1]} are not log2 contiguous" @BACKBONE_REGISTRY.register() diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/dla.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/dla.py index 1cb2fa51e8..8b6464153b 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/dla.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/dla.py @@ -1,13 +1,6 @@ -import numpy as np import math from os.path import join -import fvcore.nn.weight_init as weight_init -import torch -import torch.nn.functional as F -from torch import nn -import torch.utils.model_zoo as model_zoo -from detectron2.modeling.backbone.resnet import BasicStem, BottleneckBlock, DeformBottleneckBlock from detectron2.layers import ( Conv2d, DeformConv, @@ -15,15 +8,21 @@ ShapeSpec, get_norm, ) - from detectron2.modeling.backbone.backbone import Backbone from detectron2.modeling.backbone.build import BACKBONE_REGISTRY from detectron2.modeling.backbone.fpn import FPN +from detectron2.modeling.backbone.resnet import BasicStem, BottleneckBlock, DeformBottleneckBlock +import fvcore.nn.weight_init as weight_init +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F +import torch.utils.model_zoo as model_zoo __all__ = [ + "BasicStem", "BottleneckBlock", "DeformBottleneckBlock", - "BasicStem", ] DCNV1 = False @@ -34,13 +33,13 @@ } -def get_model_url(data, name, hash): - return join("http://dl.yf.io/dla/models", data, "{}-{}.pth".format(name, hash)) +def get_model_url(data, name: str, hash): + return join("http://dl.yf.io/dla/models", data, f"{name}-{hash}.pth") class BasicBlock(nn.Module): - def __init__(self, inplanes, planes, stride=1, dilation=1, norm="BN"): - super(BasicBlock, self).__init__() + def __init__(self, inplanes, planes, stride: int=1, dilation: int=1, norm: str="BN") -> None: + super().__init__() self.conv1 = nn.Conv2d( inplanes, planes, @@ -78,8 +77,8 @@ def forward(self, x, residual=None): class Bottleneck(nn.Module): expansion = 2 - def __init__(self, inplanes, planes, stride=1, dilation=1, norm="BN"): - super(Bottleneck, self).__init__() + def __init__(self, inplanes, planes, stride: int=1, dilation: int=1, norm: str="BN") -> None: + super().__init__() expansion = Bottleneck.expansion bottle_planes = planes // expansion self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False) @@ -121,8 +120,8 @@ def forward(self, x, residual=None): class Root(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, residual, norm="BN"): - super(Root, self).__init__() + def __init__(self, in_channels, out_channels, kernel_size: int, residual, norm: str="BN") -> None: + super().__init__() self.conv = nn.Conv2d( in_channels, out_channels, 1, stride=1, bias=False, padding=(kernel_size - 1) // 2 ) @@ -148,15 +147,15 @@ def __init__( block, in_channels, out_channels, - stride=1, - level_root=False, - root_dim=0, - root_kernel_size=1, - dilation=1, - root_residual=False, - norm="BN", - ): - super(Tree, self).__init__() + stride: int=1, + level_root: bool=False, + root_dim: int=0, + root_kernel_size: int=1, + dilation: int=1, + root_residual: bool=False, + norm: str="BN", + ) -> None: + super().__init__() if root_dim == 0: root_dim = 2 * out_channels if level_root: @@ -221,12 +220,12 @@ def forward(self, x, residual=None, children=None): class DLA(nn.Module): def __init__( - self, num_layers, levels, channels, block=BasicBlock, residual_root=False, norm="BN" - ): + self, num_layers: int, levels, channels, block=BasicBlock, residual_root: bool=False, norm: str="BN" + ) -> None: """ Args: """ - super(DLA, self).__init__() + super().__init__() self.norm = norm self.channels = channels self.base_layer = nn.Sequential( @@ -277,10 +276,10 @@ def __init__( norm=norm, ) self.load_pretrained_model( - data="imagenet", name="dla{}".format(num_layers), hash=HASH[num_layers] + data="imagenet", name=f"dla{num_layers}", hash=HASH[num_layers] ) - def load_pretrained_model(self, data, name, hash): + def load_pretrained_model(self, data, name: str, hash) -> None: model_url = get_model_url(data, name, hash) model_weights = model_zoo.load_url(model_url) num_classes = len(model_weights[list(model_weights.keys())[-1]]) @@ -290,7 +289,7 @@ def load_pretrained_model(self, data, name, hash): print("Loading pretrained") self.load_state_dict(model_weights, strict=False) - def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1): + def _make_conv_level(self, inplanes, planes, convs, stride: int=1, dilation: int=1): modules = [] for i in range(convs): modules.extend( @@ -315,12 +314,12 @@ def forward(self, x): y = [] x = self.base_layer(x) for i in range(6): - x = getattr(self, "level{}".format(i))(x) + x = getattr(self, f"level{i}")(x) y.append(x) return y -def fill_up_weights(up): +def fill_up_weights(up) -> None: w = up.weight.data f = math.ceil(w.size(2) / 2) c = (2 * f - 1 - f % 2) / (2.0 * f) @@ -332,8 +331,8 @@ def fill_up_weights(up): class _DeformConv(nn.Module): - def __init__(self, chi, cho, norm="BN"): - super(_DeformConv, self).__init__() + def __init__(self, chi, cho, norm: str="BN") -> None: + super().__init__() self.actf = nn.Sequential(get_norm(norm, cho), nn.ReLU(inplace=True)) if DCNV1: self.offset = Conv2d(chi, 18, kernel_size=3, stride=1, padding=1, dilation=1) @@ -363,8 +362,8 @@ def forward(self, x): class IDAUp(nn.Module): - def __init__(self, o, channels, up_f, norm="BN"): - super(IDAUp, self).__init__() + def __init__(self, o, channels, up_f, norm: str="BN") -> None: + super().__init__() for i in range(1, len(channels)): c = channels[i] f = int(up_f[i]) @@ -380,7 +379,7 @@ def __init__(self, o, channels, up_f, norm="BN"): setattr(self, "up_" + str(i), up) setattr(self, "node_" + str(i), node) - def forward(self, layers, startp, endp): + def forward(self, layers, startp, endp) -> None: for i in range(startp + 1, endp): upsample = getattr(self, "up_" + str(i - startp)) project = getattr(self, "proj_" + str(i - startp)) @@ -390,8 +389,8 @@ def forward(self, layers, startp, endp): class DLAUp(nn.Module): - def __init__(self, startp, channels, scales, in_channels=None, norm="BN"): - super(DLAUp, self).__init__() + def __init__(self, startp, channels, scales, in_channels=None, norm: str="BN") -> None: + super().__init__() self.startp = startp if in_channels is None: in_channels = channels @@ -402,7 +401,7 @@ def __init__(self, startp, channels, scales, in_channels=None, norm="BN"): j = -i - 2 setattr( self, - "ida_{}".format(i), + f"ida_{i}", IDAUp(channels[j], in_channels[j:], scales[j:] // scales[j], norm=norm), ) scales[j + 1 :] = scales[j] @@ -411,7 +410,7 @@ def __init__(self, startp, channels, scales, in_channels=None, norm="BN"): def forward(self, layers): out = [layers[-1]] # start with 32 for i in range(len(layers) - self.startp - 1): - ida = getattr(self, "ida_{}".format(i)) + ida = getattr(self, f"ida_{i}") ida(layers, len(layers) - i - 2, len(layers)) out.insert(0, layers[-1]) return out @@ -424,8 +423,8 @@ def forward(self, layers): class DLASeg(Backbone): - def __init__(self, num_layers, out_features, use_dla_up=True, ms_output=False, norm="BN"): - super(DLASeg, self).__init__() + def __init__(self, num_layers: int, out_features, use_dla_up: bool=True, ms_output: bool=False, norm: str="BN") -> None: + super().__init__() # depth = 34 levels, channels, Block = DLA_CONFIGS[num_layers] self.base = DLA( @@ -449,8 +448,8 @@ def __init__(self, num_layers, out_features, use_dla_up=True, ms_output=False, n norm=norm, ) self._out_features = out_features - self._out_feature_channels = {"dla{}".format(i): channels[i] for i in range(6)} - self._out_feature_strides = {"dla{}".format(i): 2**i for i in range(6)} + self._out_feature_channels = {f"dla{i}": channels[i] for i in range(6)} + self._out_feature_strides = {f"dla{i}": 2**i for i in range(6)} self._size_divisibility = 32 @property @@ -468,14 +467,14 @@ def forward(self, x): self.ida_up(y, 0, len(y)) ret = {} for i in range(self.last_level - self.first_level): - out_feature = "dla{}".format(i) + out_feature = f"dla{i}" if out_feature in self._out_features: ret[out_feature] = y[i] else: ret = {} st = self.first_level if self.use_dla_up else 0 for i in range(self.last_level - st): - out_feature = "dla{}".format(i + st) + out_feature = f"dla{i + st}" if out_feature in self._out_features: ret[out_feature] = x[i] @@ -505,7 +504,7 @@ class LastLevelP6P7(nn.Module): C5 feature. """ - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels, out_channels) -> None: super().__init__() self.num_levels = 2 self.in_feature = "dla5" diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/dlafpn.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/dlafpn.py index 8cc478ece9..3e95697171 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/dlafpn.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/dlafpn.py @@ -1,39 +1,36 @@ #!/usr/bin/env python -# -*- coding: utf-8 -*- # this file is from https://github.com/ucbdrive/dla/blob/master/dla.py. import math from os.path import join -import numpy as np +from detectron2.layers import Conv2d, ModulatedDeformConv, ShapeSpec +from detectron2.layers.batch_norm import get_norm +from detectron2.modeling.backbone import FPN, Backbone +from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +import fvcore.nn.weight_init as weight_init +import numpy as np import torch from torch import nn -import torch.utils.model_zoo as model_zoo import torch.nn.functional as F -import fvcore.nn.weight_init as weight_init - -from detectron2.modeling.backbone import FPN -from detectron2.layers import ShapeSpec, ModulatedDeformConv, Conv2d -from detectron2.modeling.backbone.build import BACKBONE_REGISTRY -from detectron2.layers.batch_norm import get_norm -from detectron2.modeling.backbone import Backbone +import torch.utils.model_zoo as model_zoo WEB_ROOT = "http://dl.yf.io/dla/models" -def get_model_url(data, name, hash): - return join("http://dl.yf.io/dla/models", data, "{}-{}.pth".format(name, hash)) +def get_model_url(data, name: str, hash): + return join("http://dl.yf.io/dla/models", data, f"{name}-{hash}.pth") -def conv3x3(in_planes, out_planes, stride=1): +def conv3x3(in_planes, out_planes, stride: int=1): "3x3 convolution with padding" return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) class BasicBlock(nn.Module): - def __init__(self, cfg, inplanes, planes, stride=1, dilation=1): - super(BasicBlock, self).__init__() + def __init__(self, cfg, inplanes, planes, stride: int=1, dilation: int=1) -> None: + super().__init__() self.conv1 = nn.Conv2d( inplanes, planes, @@ -71,8 +68,8 @@ def forward(self, x, residual=None): class Bottleneck(nn.Module): expansion = 2 - def __init__(self, cfg, inplanes, planes, stride=1, dilation=1): - super(Bottleneck, self).__init__() + def __init__(self, cfg, inplanes, planes, stride: int=1, dilation: int=1) -> None: + super().__init__() expansion = Bottleneck.expansion bottle_planes = planes // expansion self.conv1 = nn.Conv2d(inplanes, bottle_planes, kernel_size=1, bias=False) @@ -114,8 +111,8 @@ def forward(self, x, residual=None): class Root(nn.Module): - def __init__(self, cfg, in_channels, out_channels, kernel_size, residual): - super(Root, self).__init__() + def __init__(self, cfg, in_channels, out_channels, kernel_size: int, residual) -> None: + super().__init__() self.conv = nn.Conv2d( in_channels, out_channels, @@ -147,14 +144,14 @@ def __init__( block, in_channels, out_channels, - stride=1, - level_root=False, - root_dim=0, - root_kernel_size=1, - dilation=1, - root_residual=False, - ): - super(Tree, self).__init__() + stride: int=1, + level_root: bool=False, + root_dim: int=0, + root_kernel_size: int=1, + dilation: int=1, + root_residual: bool=False, + ) -> None: + super().__init__() if root_dim == 0: root_dim = 2 * out_channels if level_root: @@ -220,12 +217,12 @@ def forward(self, x, residual=None, children=None): class DLA(Backbone): - def __init__(self, cfg, levels, channels, block=BasicBlock, residual_root=False): - super(DLA, self).__init__() + def __init__(self, cfg, levels, channels, block=BasicBlock, residual_root: bool=False) -> None: + super().__init__() self.cfg = cfg self.channels = channels - self._out_features = ["dla{}".format(i) for i in range(6)] + self._out_features = [f"dla{i}" for i in range(6)] self._out_feature_channels = {k: channels[i] for i, k in enumerate(self._out_features)} self._out_feature_strides = {k: 2**i for i, k in enumerate(self._out_features)} @@ -284,7 +281,7 @@ def __init__(self, cfg, levels, channels, block=BasicBlock, residual_root=False) self.load_pretrained_model(data="imagenet", name="dla34", hash="ba72cf86") - def load_pretrained_model(self, data, name, hash): + def load_pretrained_model(self, data, name: str, hash) -> None: model_url = get_model_url(data, name, hash) model_weights = model_zoo.load_url(model_url) del model_weights["fc.weight"] @@ -292,7 +289,7 @@ def load_pretrained_model(self, data, name, hash): print("Loading pretrained DLA!") self.load_state_dict(model_weights, strict=True) - def _make_conv_level(self, inplanes, planes, convs, stride=1, dilation=1): + def _make_conv_level(self, inplanes, planes, convs, stride: int=1, dilation: int=1): modules = [] for i in range(convs): modules.extend( @@ -317,13 +314,13 @@ def forward(self, x): y = {} x = self.base_layer(x) for i in range(6): - name = "level{}".format(i) + name = f"level{i}" x = getattr(self, name)(x) - y["dla{}".format(i)] = x + y[f"dla{i}"] = x return y -def fill_up_weights(up): +def fill_up_weights(up) -> None: w = up.weight.data f = math.ceil(w.size(2) / 2) c = (2 * f - 1 - f % 2) / (2.0 * f) @@ -335,8 +332,8 @@ def fill_up_weights(up): class Conv(nn.Module): - def __init__(self, chi, cho, norm): - super(Conv, self).__init__() + def __init__(self, chi, cho, norm) -> None: + super().__init__() self.conv = nn.Sequential( nn.Conv2d(chi, cho, kernel_size=1, stride=1, bias=False), get_norm(norm, cho), @@ -348,8 +345,8 @@ def forward(self, x): class DeformConv(nn.Module): - def __init__(self, chi, cho, norm): - super(DeformConv, self).__init__() + def __init__(self, chi, cho, norm) -> None: + super().__init__() self.actf = nn.Sequential(get_norm(norm, cho), nn.ReLU(inplace=True)) self.offset = Conv2d(chi, 27, kernel_size=3, stride=1, padding=1, dilation=1) self.conv = ModulatedDeformConv( @@ -369,8 +366,8 @@ def forward(self, x): class IDAUp(nn.Module): - def __init__(self, o, channels, up_f, norm="FrozenBN", node_type=Conv): - super(IDAUp, self).__init__() + def __init__(self, o, channels, up_f, norm: str="FrozenBN", node_type=Conv) -> None: + super().__init__() for i in range(1, len(channels)): c = channels[i] f = int(up_f[i]) @@ -386,7 +383,7 @@ def __init__(self, o, channels, up_f, norm="FrozenBN", node_type=Conv): setattr(self, "up_" + str(i), up) setattr(self, "node_" + str(i), node) - def forward(self, layers, startp, endp): + def forward(self, layers, startp, endp) -> None: for i in range(startp + 1, endp): upsample = getattr(self, "up_" + str(i - startp)) project = getattr(self, "proj_" + str(i - startp)) @@ -402,8 +399,8 @@ def forward(self, layers, startp, endp): class DLAUP(Backbone): - def __init__(self, bottom_up, in_features, norm, dlaup_node="conv"): - super(DLAUP, self).__init__() + def __init__(self, bottom_up, in_features, norm, dlaup_node: str="conv") -> None: + super().__init__() assert isinstance(bottom_up, Backbone) self.bottom_up = bottom_up input_shapes = bottom_up.output_shape() @@ -411,12 +408,12 @@ def __init__(self, bottom_up, in_features, norm, dlaup_node="conv"): in_channels = [input_shapes[f].channels for f in in_features] in_levels = [int(math.log2(input_shapes[f].stride)) for f in in_features] self.in_features = in_features - out_features = ["dlaup{}".format(l) for l in in_levels] + out_features = [f"dlaup{l}" for l in in_levels] self._out_features = out_features self._out_feature_channels = { - "dlaup{}".format(l): in_channels[i] for i, l in enumerate(in_levels) + f"dlaup{l}": in_channels[i] for i, l in enumerate(in_levels) } - self._out_feature_strides = {"dlaup{}".format(l): 2**l for l in in_levels} + self._out_feature_strides = {f"dlaup{l}": 2**l for l in in_levels} print("self._out_features", self._out_features) print("self._out_feature_channels", self._out_feature_channels) @@ -433,7 +430,7 @@ def __init__(self, bottom_up, in_features, norm, dlaup_node="conv"): j = -i - 2 setattr( self, - "ida_{}".format(i), + f"ida_{i}", IDAUp( channels[j], in_channels[j:], @@ -454,17 +451,17 @@ def forward(self, x): layers = [bottom_up_features[f] for f in self.in_features] out = [layers[-1]] # start with 32 for i in range(len(layers) - 1): - ida = getattr(self, "ida_{}".format(i)) + ida = getattr(self, f"ida_{i}") ida(layers, len(layers) - i - 2, len(layers)) out.insert(0, layers[-1]) ret = {} - for k, v in zip(self._out_features, out): + for k, v in zip(self._out_features, out, strict=False): ret[k] = v # import pdb; pdb.set_trace() return ret -def dla34(cfg, pretrained=None): # DLA-34 +def dla34(cfg, pretrained: bool | None=None): # DLA-34 model = DLA(cfg, [1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512], block=BasicBlock) return model @@ -475,7 +472,7 @@ class LastLevelP6P7(nn.Module): C5 feature. """ - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels, out_channels) -> None: super().__init__() self.num_levels = 2 self.in_feature = "dla5" @@ -500,7 +497,7 @@ def build_dla_fpn3_backbone(cfg, input_shape: ShapeSpec): """ depth_to_creator = {"dla34": dla34} - bottom_up = depth_to_creator["dla{}".format(cfg.MODEL.DLA.NUM_LAYERS)](cfg) + bottom_up = depth_to_creator[f"dla{cfg.MODEL.DLA.NUM_LAYERS}"](cfg) in_features = cfg.MODEL.FPN.IN_FEATURES out_channels = cfg.MODEL.FPN.OUT_CHANNELS @@ -526,7 +523,7 @@ def build_dla_fpn5_backbone(cfg, input_shape: ShapeSpec): """ depth_to_creator = {"dla34": dla34} - bottom_up = depth_to_creator["dla{}".format(cfg.MODEL.DLA.NUM_LAYERS)](cfg) + bottom_up = depth_to_creator[f"dla{cfg.MODEL.DLA.NUM_LAYERS}"](cfg) in_features = cfg.MODEL.FPN.IN_FEATURES out_channels = cfg.MODEL.FPN.OUT_CHANNELS in_channels_top = bottom_up.output_shape()["dla5"].channels @@ -553,7 +550,7 @@ def build_dlaup_backbone(cfg, input_shape: ShapeSpec): """ depth_to_creator = {"dla34": dla34} - bottom_up = depth_to_creator["dla{}".format(cfg.MODEL.DLA.NUM_LAYERS)](cfg) + bottom_up = depth_to_creator[f"dla{cfg.MODEL.DLA.NUM_LAYERS}"](cfg) backbone = DLAUP( bottom_up=bottom_up, diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/fpn_p5.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/fpn_p5.py index 228b822bbf..4ce285b6c6 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/fpn_p5.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/fpn_p5.py @@ -1,13 +1,11 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -import fvcore.nn.weight_init as weight_init -import torch.nn.functional as F -from torch import nn - from detectron2.layers import ShapeSpec - -from detectron2.modeling.backbone.fpn import FPN from detectron2.modeling.backbone.build import BACKBONE_REGISTRY +from detectron2.modeling.backbone.fpn import FPN from detectron2.modeling.backbone.resnet import build_resnet_backbone +import fvcore.nn.weight_init as weight_init +from torch import nn +import torch.nn.functional as F class LastLevelP6P7_P5(nn.Module): @@ -16,7 +14,7 @@ class LastLevelP6P7_P5(nn.Module): C5 feature. """ - def __init__(self, in_channels, out_channels): + def __init__(self, in_channels, out_channels) -> None: super().__init__() self.num_levels = 2 self.in_feature = "p5" diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/res2net.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/res2net.py index b35f9b2413..0532b20e02 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/res2net.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/backbone/res2net.py @@ -1,11 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # This file is modified from https://github.com/Res2Net/Res2Net-detectron2/blob/master/detectron2/modeling/backbone/resnet.py # The original file is under Apache-2.0 License -import numpy as np -import fvcore.nn.weight_init as weight_init -import torch -import torch.nn.functional as F -from torch import nn from detectron2.layers import ( CNNBlockBase, @@ -15,22 +10,27 @@ ShapeSpec, get_norm, ) - from detectron2.modeling.backbone import Backbone -from detectron2.modeling.backbone.fpn import FPN from detectron2.modeling.backbone.build import BACKBONE_REGISTRY -from .fpn_p5 import LastLevelP6P7_P5 +from detectron2.modeling.backbone.fpn import FPN +import fvcore.nn.weight_init as weight_init +import numpy as np +import torch +from torch import nn +import torch.nn.functional as F + from .bifpn import BiFPN +from .fpn_p5 import LastLevelP6P7_P5 __all__ = [ - "ResNetBlockBase", "BasicBlock", + "BasicStem", "BottleneckBlock", "DeformBottleneckBlock", - "BasicStem", "ResNet", - "make_stage", + "ResNetBlockBase", "build_res2net_backbone", + "make_stage", ] @@ -46,7 +46,7 @@ class BasicBlock(CNNBlockBase): and a projection shortcut if needed. """ - def __init__(self, in_channels, out_channels, *, stride=1, norm="BN"): + def __init__(self, in_channels, out_channels, *, stride: int=1, norm: str="BN") -> None: """ Args: in_channels (int): Number of input channels. @@ -119,14 +119,14 @@ def __init__( out_channels, *, bottleneck_channels, - stride=1, - num_groups=1, - norm="BN", - stride_in_1x1=False, - dilation=1, - basewidth=26, - scale=4, - ): + stride: int=1, + num_groups: int=1, + norm: str="BN", + stride_in_1x1: bool=False, + dilation: int=1, + basewidth: int=26, + scale: int=4, + ) -> None: """ Args: bottleneck_channels (int): number of output channels for the 3x3 @@ -180,7 +180,7 @@ def __init__( convs = [] bns = [] - for i in range(self.nums): + for _i in range(self.nums): convs.append( nn.Conv2d( width, @@ -278,16 +278,16 @@ def __init__( out_channels, *, bottleneck_channels, - stride=1, - num_groups=1, - norm="BN", - stride_in_1x1=False, - dilation=1, - deform_modulated=False, - deform_num_groups=1, - basewidth=26, - scale=4, - ): + stride: int=1, + num_groups: int=1, + norm: str="BN", + stride_in_1x1: bool=False, + dilation: int=1, + deform_modulated: bool=False, + deform_num_groups: int=1, + basewidth: int=26, + scale: int=4, + ) -> None: super().__init__(in_channels, out_channels, stride) self.deform_modulated = deform_modulated @@ -367,7 +367,7 @@ def __init__( conv2_offsets = [] convs = [] bns = [] - for i in range(self.nums): + for _i in range(self.nums): conv2_offsets.append( Conv2d( width, @@ -488,7 +488,7 @@ def forward(self, x): return out -def make_stage(block_class, num_blocks, first_stride, *, in_channels, out_channels, **kwargs): +def make_stage(block_class, num_blocks: int, first_stride, *, in_channels, out_channels, **kwargs): """ Create a list of blocks just like those in a ResNet stage. Args: @@ -521,7 +521,7 @@ class BasicStem(CNNBlockBase): The standard ResNet stem (layers before the first residual block). """ - def __init__(self, in_channels=3, out_channels=64, norm="BN"): + def __init__(self, in_channels: int=3, out_channels: int=64, norm: str="BN") -> None: """ Args: norm (str or callable): norm after the first conv layer. @@ -574,7 +574,7 @@ def forward(self, x): class ResNet(Backbone): - def __init__(self, stem, stages, num_classes=None, out_features=None): + def __init__(self, stem, stages, num_classes: int | None=None, out_features=None) -> None: """ Args: stem (nn.Module): a stem module @@ -586,7 +586,7 @@ def __init__(self, stem, stages, num_classes=None, out_features=None): be returned in forward. Can be anything in "stem", "linear", or "res2" ... If None, will return the output of the last layer. """ - super(ResNet, self).__init__() + super().__init__() self.stem = stem self.num_classes = num_classes @@ -654,7 +654,7 @@ def output_shape(self): for name in self._out_features } - def freeze(self, freeze_at=0): + def freeze(self, freeze_at: int=0): """ Freeze the first several stages of the ResNet. Commonly used in fine-tuning. @@ -705,7 +705,7 @@ def build_res2net_backbone(cfg, input_shape): deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS # fmt: on - assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation) + assert res5_dilation in {1, 2}, f"res5_dilation cannot be {res5_dilation}." num_blocks_per_stage = { 18: [2, 2, 2, 2], diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/debug.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/debug.py index 247653c23a..b64e6236dd 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/debug.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/debug.py @@ -1,3 +1,5 @@ +from collections.abc import Sequence + import cv2 import numpy as np import torch @@ -18,13 +20,13 @@ def _get_color_image(heatmap): return color_map -def _blend_image(image, color_map, a=0.7): +def _blend_image(image, color_map, a: float=0.7): color_map = cv2.resize(color_map, (image.shape[1], image.shape[0])) ret = np.clip(image * (1 - a) + color_map * a, 0, 255).astype(np.uint8) return ret -def _blend_image_heatmaps(image, color_maps, a=0.7): +def _blend_image_heatmaps(image, color_maps, a: float=0.7): merges = np.zeros((image.shape[0], image.shape[1], 3), np.float32) for color_map in color_maps: color_map = cv2.resize(color_map, (image.shape[1], image.shape[0])) @@ -78,12 +80,12 @@ def debug_train( gt_instances, flattened_hms, reg_targets, - labels, + labels: Sequence[str], pos_inds, shapes_per_level, locations, - strides, -): + strides: Sequence[int], +) -> None: """ images: N x 3 x H x W flattened_hms: LNHiWi x C @@ -105,7 +107,7 @@ def debug_train( for l in range(len(gt_hms)): color_map = _get_color_image(gt_hms[l][i].detach().cpu().numpy()) color_maps.append(color_map) - cv2.imshow("gthm_{}".format(l), color_map) + cv2.imshow(f"gthm_{l}", color_map) blend = _blend_image_heatmaps(image.copy(), color_maps) if gt_instances is not None: bboxes = gt_instances[i].gt_boxes.tensor @@ -155,22 +157,26 @@ def debug_test( images, logits_pred, reg_pred, - agn_hm_pred=[], - preds=[], - vis_thresh=0.3, - debug_show_name=False, - mult_agn=False, -): + agn_hm_pred=None, + preds=None, + vis_thresh: float=0.3, + debug_show_name: bool=False, + mult_agn: bool=False, +) -> None: """ images: N x 3 x H x W class_target: LNHiWi x C cat_agn_heatmap: LNHiWi shapes_per_level: L x 2 [(H_i, W_i)] """ - N = len(images) + if preds is None: + preds = [] + if agn_hm_pred is None: + agn_hm_pred = [] + len(images) for i in range(len(images)): image = images[i].detach().cpu().numpy().transpose(1, 2, 0) - result = image.copy().astype(np.uint8) + image.copy().astype(np.uint8) pred_image = image.copy().astype(np.uint8) color_maps = [] L = len(logits_pred) @@ -189,7 +195,7 @@ def debug_test( logits_pred[l][i] = logits_pred[l][i] * agn_hm_pred[l][i] color_map = _get_color_image(logits_pred[l][i].detach().cpu().numpy()) color_maps.append(color_map) - cv2.imshow("predhm_{}".format(l), color_map) + cv2.imshow(f"predhm_{l}", color_map) if debug_show_name: from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES @@ -240,7 +246,7 @@ def debug_test( if agn_hm_pred[l] is not None: agn_hm_ = agn_hm_pred[l][i, 0, :, :, None].detach().cpu().numpy() agn_hm_ = (agn_hm_ * np.array([255, 255, 255]).reshape(1, 1, 3)).astype(np.uint8) - cv2.imshow("agn_hm_{}".format(l), agn_hm_) + cv2.imshow(f"agn_hm_{l}", agn_hm_) blend = _blend_image_heatmaps(image.copy(), color_maps) cv2.imshow("blend", blend) cv2.imshow("preds", pred_image) @@ -252,8 +258,8 @@ def debug_test( def debug_second_stage( - images, instances, proposals=None, vis_thresh=0.3, save_debug=False, debug_show_name=False -): + images, instances, proposals=None, vis_thresh: float=0.3, save_debug: bool=False, debug_show_name: bool=False +) -> None: images = _imagelist_to_tensor(images) if debug_show_name: from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES @@ -332,5 +338,5 @@ def debug_second_stage( if save_debug: global cnt cnt += 1 - cv2.imwrite("output/save_debug/{}.jpg".format(cnt), proposal_image) + cv2.imwrite(f"output/save_debug/{cnt}.jpg", proposal_image) cv2.waitKey() diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py index 53b28eb18a..332b14bb17 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet.py @@ -1,19 +1,19 @@ -import torch -from torch import nn +from collections.abc import Sequence -from detectron2.modeling.proposal_generator.build import PROPOSAL_GENERATOR_REGISTRY +from detectron2.config import configurable from detectron2.layers import cat -from detectron2.structures import Instances, Boxes +from detectron2.modeling.proposal_generator.build import PROPOSAL_GENERATOR_REGISTRY +from detectron2.structures import Boxes, Instances from detectron2.utils.comm import get_world_size -from detectron2.config import configurable +import torch +from torch import nn -from ..layers.heatmap_focal_loss import heatmap_focal_loss_jit -from ..layers.heatmap_focal_loss import binary_heatmap_focal_loss_jit +from ..debug import debug_test, debug_train +from ..layers.heatmap_focal_loss import binary_heatmap_focal_loss_jit, heatmap_focal_loss_jit from ..layers.iou_loss import IOULoss from ..layers.ml_nms import ml_nms -from ..debug import debug_train, debug_test -from .utils import reduce_sum, _transpose from .centernet_head import CenterNetHead +from .utils import _transpose, reduce_sum __all__ = ["CenterNet"] @@ -26,48 +26,54 @@ class CenterNet(nn.Module): def __init__( self, # input_shape: Dict[str, ShapeSpec], - in_channels=256, + in_channels: int=256, *, - num_classes=80, + num_classes: int=80, in_features=("p3", "p4", "p5", "p6", "p7"), - strides=(8, 16, 32, 64, 128), - score_thresh=0.05, - hm_min_overlap=0.8, - loc_loss_type="giou", - min_radius=4, - hm_focal_alpha=0.25, - hm_focal_beta=4, - loss_gamma=2.0, - reg_weight=2.0, - not_norm_reg=True, - with_agn_hm=False, - only_proposal=False, - as_proposal=False, - not_nms=False, - pos_weight=1.0, - neg_weight=1.0, - sigmoid_clamp=1e-4, + strides: Sequence[int]=(8, 16, 32, 64, 128), + score_thresh: float=0.05, + hm_min_overlap: float=0.8, + loc_loss_type: str="giou", + min_radius: int=4, + hm_focal_alpha: float=0.25, + hm_focal_beta: int=4, + loss_gamma: float=2.0, + reg_weight: float=2.0, + not_norm_reg: bool=True, + with_agn_hm: bool=False, + only_proposal: bool=False, + as_proposal: bool=False, + not_nms: bool=False, + pos_weight: float=1.0, + neg_weight: float=1.0, + sigmoid_clamp: float=1e-4, ignore_high_fp=-1.0, - center_nms=False, - sizes_of_interest=[[0, 80], [64, 160], [128, 320], [256, 640], [512, 10000000]], - more_pos=False, - more_pos_thresh=0.2, - more_pos_topk=9, - pre_nms_topk_train=1000, - pre_nms_topk_test=1000, - post_nms_topk_train=100, - post_nms_topk_test=100, - nms_thresh_train=0.6, - nms_thresh_test=0.6, - no_reduce=False, - not_clamp_box=False, - debug=False, - vis_thresh=0.5, - pixel_mean=[103.530, 116.280, 123.675], - pixel_std=[1.0, 1.0, 1.0], - device="cuda", + center_nms: bool=False, + sizes_of_interest=None, + more_pos: bool=False, + more_pos_thresh: float=0.2, + more_pos_topk: int=9, + pre_nms_topk_train: int=1000, + pre_nms_topk_test: int=1000, + post_nms_topk_train: int=100, + post_nms_topk_test: int=100, + nms_thresh_train: float=0.6, + nms_thresh_test: float=0.6, + no_reduce: bool=False, + not_clamp_box: bool=False, + debug: bool=False, + vis_thresh: float=0.5, + pixel_mean=None, + pixel_std=None, + device: str="cuda", centernet_head=None, - ): + ) -> None: + if pixel_std is None: + pixel_std = [1.0, 1.0, 1.0] + if pixel_mean is None: + pixel_mean = [103.53, 116.28, 123.675] + if sizes_of_interest is None: + sizes_of_interest = [[0, 80], [64, 160], [128, 320], [256, 640], [512, 10000000]] super().__init__() self.num_classes = num_classes self.in_features = in_features @@ -245,7 +251,7 @@ def forward(self, images, features_dict, gt_instances): return proposals, losses def losses( - self, pos_inds, labels, reg_targets, flattened_hms, logits_pred, reg_pred, agn_hm_pred + self, pos_inds, labels: Sequence[str], reg_targets, flattened_hms, logits_pred, reg_pred, agn_hm_pred ): """ Inputs: @@ -556,7 +562,7 @@ def _get_reg_targets(self, reg_targets, dist, mask, area): reg_targets_per_im[min_dist == INF] = -INF return reg_targets_per_im - def _create_heatmaps_from_dist(self, dist, labels, channels): + def _create_heatmaps_from_dist(self, dist, labels: Sequence[str], channels): """ dist: M x N labels: N @@ -601,7 +607,7 @@ def _flatten_outputs(self, clss, reg_pred, agn_hm_pred): ) return clss, reg_pred, agn_hm_pred - def get_center3x3(self, locations, centers, strides): + def get_center3x3(self, locations, centers, strides: Sequence[int]): """ Inputs: locations: M x 2 @@ -658,7 +664,7 @@ def inference(self, images, clss_per_level, reg_pred_per_level, agn_hm_pred_per_ @torch.no_grad() def predict_instances( - self, grids, logits_pred, reg_pred, image_sizes, agn_hm_pred, is_proposal=False + self, grids, logits_pred, reg_pred, image_sizes: Sequence[int], agn_hm_pred, is_proposal: bool=False ): sampled_boxes = [] for l in range(len(grids)): @@ -673,14 +679,14 @@ def predict_instances( is_proposal=is_proposal, ) ) - boxlists = list(zip(*sampled_boxes)) + boxlists = list(zip(*sampled_boxes, strict=False)) boxlists = [Instances.cat(boxlist) for boxlist in boxlists] boxlists = self.nms_and_topK(boxlists, nms=not self.not_nms) return boxlists @torch.no_grad() def predict_single_level( - self, grids, heatmap, reg_pred, image_sizes, agn_hm, level, is_proposal=False + self, grids, heatmap, reg_pred, image_sizes: Sequence[int], agn_hm, level, is_proposal: bool=False ): N, C, H, W = heatmap.shape # put in the same format as grids @@ -746,7 +752,7 @@ def predict_single_level( return results @torch.no_grad() - def nms_and_topK(self, boxlists, nms=True): + def nms_and_topK(self, boxlists, nms: bool=True): num_images = len(boxlists) results = [] for i in range(num_images): @@ -795,7 +801,7 @@ def _add_more_pos(self, reg_pred, gt_instances, shapes_per_level): c33_reg_loss.view(N * L, K)[level_masks.view(N * L), 4] = 0 # real center c33_reg_loss = c33_reg_loss.view(N, L * K) if N == 0: - loss_thresh = c33_reg_loss.new_ones((N)).float() + loss_thresh = c33_reg_loss.new_ones(N).float() else: loss_thresh = torch.kthvalue(c33_reg_loss, self.more_pos_topk, dim=1)[0] # N loss_thresh[loss_thresh > self.more_pos_thresh] = self.more_pos_thresh # N @@ -899,7 +905,7 @@ def _get_c33_inds(self, gt_instances, shapes_per_level): c33_regs = torch.cat(c33_regs, dim=0) c33_masks = torch.cat(c33_masks, dim=0) else: - labels = shapes_per_level.new_zeros((0)).long() + labels = shapes_per_level.new_zeros(0).long() level_masks = shapes_per_level.new_zeros((0, L)).bool() c33_inds = shapes_per_level.new_zeros((0, L, K)).long() c33_regs = shapes_per_level.new_zeros((0, L, K, 4)).float() diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet_head.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet_head.py index 3f939233a1..e2e1852e27 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet_head.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/centernet_head.py @@ -1,18 +1,19 @@ import math + +from detectron2.config import configurable +from detectron2.layers import get_norm import torch from torch import nn from torch.nn import functional as F -from detectron2.layers import get_norm -from detectron2.config import configurable from ..layers.deform_conv import DFConv2d __all__ = ["CenterNetHead"] class Scale(nn.Module): - def __init__(self, init_value=1.0): - super(Scale, self).__init__() + def __init__(self, init_value: float=1.0) -> None: + super().__init__() self.scale = nn.Parameter(torch.FloatTensor([init_value])) def forward(self, input): @@ -25,18 +26,18 @@ def __init__( self, # input_shape: List[ShapeSpec], in_channels, - num_levels, + num_levels: int, *, - num_classes=80, - with_agn_hm=False, - only_proposal=False, - norm="GN", - num_cls_convs=4, - num_box_convs=4, - num_share_convs=0, - use_deformable=False, - prior_prob=0.01, - ): + num_classes: int=80, + with_agn_hm: bool=False, + only_proposal: bool=False, + norm: str="GN", + num_cls_convs: int=4, + num_box_convs: int=4, + num_share_convs: int=0, + use_deformable: bool=False, + prior_prob: float=0.01, + ) -> None: super().__init__() self.num_classes = num_classes self.with_agn_hm = with_agn_hm @@ -82,7 +83,7 @@ def __init__( elif norm != "": tower.append(get_norm(norm, channel)) tower.append(nn.ReLU()) - self.add_module("{}_tower".format(head), nn.Sequential(*tower)) + self.add_module(f"{head}_tower", nn.Sequential(*tower)) self.bbox_pred = nn.Conv2d( in_channels, 4, kernel_size=self.out_kernel, stride=1, padding=self.out_kernel // 2 @@ -129,7 +130,7 @@ def __init__( def from_config(cls, cfg, input_shape): ret = { # 'input_shape': input_shape, - "in_channels": [s.channels for s in input_shape][0], + "in_channels": next(s.channels for s in input_shape), "num_levels": len(input_shape), "num_classes": cfg.MODEL.CENTERNET.NUM_CLASSES, "with_agn_hm": cfg.MODEL.CENTERNET.WITH_AGN_HM, diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/utils.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/utils.py index 527d362d90..ea962943ca 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/utils.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/dense_heads/utils.py @@ -1,9 +1,9 @@ -import torch from detectron2.utils.comm import get_world_size +import torch # from .data import CenterNetCrop -__all__ = ["reduce_sum", "_transpose"] +__all__ = ["_transpose", "reduce_sum"] INF = 1000000000 @@ -18,7 +18,7 @@ def _transpose(training_targets, num_loc_list): training_targets[im_i] = torch.split(training_targets[im_i], num_loc_list, dim=0) targets_level_first = [] - for targets_per_level in zip(*training_targets): + for targets_per_level in zip(*training_targets, strict=False): targets_level_first.append(torch.cat(targets_per_level, dim=0)) return targets_level_first diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/deform_conv.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/deform_conv.py index 396aa9554a..2f938fe0ae 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/deform_conv.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/deform_conv.py @@ -1,8 +1,7 @@ +from detectron2.layers import Conv2d import torch from torch import nn -from detectron2.layers import Conv2d - class _NewEmptyTensorOp(torch.autograd.Function): @staticmethod @@ -23,16 +22,16 @@ def __init__( self, in_channels, out_channels, - with_modulated_dcn=True, - kernel_size=3, - stride=1, - groups=1, - dilation=1, - deformable_groups=1, - bias=False, + with_modulated_dcn: bool=True, + kernel_size: int=3, + stride: int=1, + groups: int=1, + dilation: int=1, + deformable_groups: int=1, + bias: bool=False, padding=None, - ): - super(DFConv2d, self).__init__() + ) -> None: + super().__init__() if isinstance(kernel_size, (list, tuple)): assert isinstance(stride, (list, tuple)) assert isinstance(dilation, (list, tuple)) @@ -91,7 +90,7 @@ def __init__( self.dilation = dilation self.offset_split = offset_base_channels * deformable_groups * 2 - def forward(self, x, return_offset=False): + def forward(self, x, return_offset: bool=False): if x.numel() > 0: if not self.with_modulated_dcn: offset_mask = self.offset(x) @@ -108,8 +107,8 @@ def forward(self, x, return_offset=False): output_shape = [ (i + 2 * p - (di * (k - 1) + 1)) // d + 1 for i, p, di, k, d in zip( - x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride + x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride, strict=False ) ] - output_shape = [x.shape[0], self.conv.weight.shape[0]] + output_shape + output_shape = [x.shape[0], self.conv.weight.shape[0], *output_shape] return _NewEmptyTensorOp.apply(x, output_shape) diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/heatmap_focal_loss.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/heatmap_focal_loss.py index 893fd9ffab..f8ec4afb5a 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/heatmap_focal_loss.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/heatmap_focal_loss.py @@ -1,3 +1,5 @@ +from collections.abc import Sequence + import torch @@ -6,7 +8,7 @@ def heatmap_focal_loss( inputs, targets, pos_inds, - labels, + labels: Sequence[str], alpha: float = -1, beta: float = 4, gamma: float = 2, diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/iou_loss.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/iou_loss.py index 9cfe00765c..55fa2a186d 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/iou_loss.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/iou_loss.py @@ -3,11 +3,11 @@ class IOULoss(nn.Module): - def __init__(self, loc_loss_type="iou"): - super(IOULoss, self).__init__() + def __init__(self, loc_loss_type: str="iou") -> None: + super().__init__() self.loc_loss_type = loc_loss_type - def forward(self, pred, target, weight=None, reduction="sum"): + def forward(self, pred, target, weight=None, reduction: str="sum"): pred_left = pred[:, 0] pred_top = pred[:, 1] pred_right = pred[:, 2] diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/ml_nms.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/ml_nms.py index 80029fa60b..429c986cfe 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/ml_nms.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/ml_nms.py @@ -1,7 +1,7 @@ from detectron2.layers import batched_nms -def ml_nms(boxlist, nms_thresh, max_proposals=-1, score_field="scores", label_field="labels"): +def ml_nms(boxlist, nms_thresh, max_proposals=-1, score_field: str="scores", label_field: str="labels"): """ Performs non-maximum suppression on a boxlist, with scores specified in a boxlist field via score_field. diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/meta_arch/centernet_detector.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/meta_arch/centernet_detector.py index 63a1cb13f9..02cd3da416 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/meta_arch/centernet_detector.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/meta_arch/centernet_detector.py @@ -1,15 +1,13 @@ -import torch -from torch import nn - +from detectron2.modeling import build_backbone, build_proposal_generator, detector_postprocess from detectron2.modeling.meta_arch.build import META_ARCH_REGISTRY -from detectron2.modeling import build_backbone, build_proposal_generator -from detectron2.modeling import detector_postprocess from detectron2.structures import ImageList +import torch +from torch import nn @META_ARCH_REGISTRY.register() class CenterNetDetector(nn.Module): - def __init__(self, cfg): + def __init__(self, cfg) -> None: super().__init__() self.mean, self.std = cfg.MODEL.PIXEL_MEAN, cfg.MODEL.PIXEL_STD self.register_buffer("pixel_mean", torch.Tensor(cfg.MODEL.PIXEL_MEAN).view(-1, 1, 1)) @@ -35,7 +33,7 @@ def device(self): return self.pixel_mean.device @torch.no_grad() - def inference(self, batched_inputs, do_postprocess=True): + def inference(self, batched_inputs, do_postprocess: bool=True): images = self.preprocess_image(batched_inputs) inp = images.tensor features = self.backbone(inp) @@ -43,7 +41,7 @@ def inference(self, batched_inputs, do_postprocess=True): processed_results = [] for results_per_image, input_per_image, image_size in zip( - proposals, batched_inputs, images.image_sizes + proposals, batched_inputs, images.image_sizes, strict=False ): if do_postprocess: height = input_per_image.get("height", image_size[0]) diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/custom_fast_rcnn.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/custom_fast_rcnn.py index a0c44fec3d..b48b5447ac 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/custom_fast_rcnn.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/custom_fast_rcnn.py @@ -1,21 +1,24 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Part of the code is from https://github.com/tztztztztz/eql.detectron2/blob/master/projects/EQL/eql/fast_rcnn.py import math + +from detectron2.layers import ShapeSpec, cat +from detectron2.modeling.roi_heads.fast_rcnn import ( + FastRCNNOutputLayers, + _log_classification_stats, + fast_rcnn_inference, +) import torch from torch import nn from torch.nn import functional as F -from detectron2.layers import ShapeSpec, cat -from detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers -from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference -from detectron2.modeling.roi_heads.fast_rcnn import _log_classification_stats -from .fed_loss import load_class_freq, get_fed_loss_inds +from .fed_loss import get_fed_loss_inds, load_class_freq __all__ = ["CustomFastRCNNOutputLayers"] class CustomFastRCNNOutputLayers(FastRCNNOutputLayers): - def __init__(self, cfg, input_shape: ShapeSpec, **kwargs): + def __init__(self, cfg, input_shape: ShapeSpec, **kwargs) -> None: super().__init__(cfg, input_shape, **kwargs) self.use_sigmoid_ce = cfg.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE if self.use_sigmoid_ce: @@ -43,7 +46,6 @@ def losses(self, predictions, proposals): gt_classes = ( cat([p.gt_classes for p in proposals], dim=0) if len(proposals) else torch.empty(0) ) - num_classes = self.num_classes _log_classification_stats(scores, gt_classes) if len(proposals): @@ -125,7 +127,7 @@ def inference(self, predictions, proposals): scores = self.predict_probs(predictions, proposals) if self.cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE: proposal_scores = [p.get("objectness_logits") for p in proposals] - scores = [(s * ps[:, None]) ** 0.5 for s, ps in zip(scores, proposal_scores)] + scores = [(s * ps[:, None]) ** 0.5 for s, ps in zip(scores, proposal_scores, strict=False)] image_shapes = [x.image_size for x in proposals] return fast_rcnn_inference( boxes, diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/custom_roi_heads.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/custom_roi_heads.py index aefd1d164e..d0478de2f3 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/custom_roi_heads.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/custom_roi_heads.py @@ -1,31 +1,30 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved -import torch - -from detectron2.utils.events import get_event_storage - from detectron2.modeling.box_regression import Box2BoxTransform +from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference from detectron2.modeling.roi_heads.roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads -from detectron2.modeling.roi_heads.cascade_rcnn import CascadeROIHeads +from detectron2.utils.events import get_event_storage +import torch + from .custom_fast_rcnn import CustomFastRCNNOutputLayers @ROI_HEADS_REGISTRY.register() class CustomROIHeads(StandardROIHeads): @classmethod - def _init_box_head(self, cfg, input_shape): + def _init_box_head(cls, cfg, input_shape): ret = super()._init_box_head(cfg, input_shape) del ret["box_predictor"] ret["box_predictor"] = CustomFastRCNNOutputLayers(cfg, ret["box_head"].output_shape) - self.debug = cfg.DEBUG - if self.debug: - self.debug_show_name = cfg.DEBUG_SHOW_NAME - self.save_debug = cfg.SAVE_DEBUG - self.vis_thresh = cfg.VIS_THRESH - self.pixel_mean = ( + cls.debug = cfg.DEBUG + if cls.debug: + cls.debug_show_name = cfg.DEBUG_SHOW_NAME + cls.save_debug = cfg.SAVE_DEBUG + cls.vis_thresh = cfg.VIS_THRESH + cls.pixel_mean = ( torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1) ) - self.pixel_std = ( + cls.pixel_std = ( torch.Tensor(cfg.MODEL.PIXEL_STD).to(torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1) ) return ret @@ -52,7 +51,8 @@ def forward(self, images, features, proposals, targets=None): if self.debug: from ..debug import debug_second_stage - denormalizer = lambda x: x * self.pixel_std + self.pixel_mean + def denormalizer(x): + return x * self.pixel_std + self.pixel_mean debug_second_stage( [denormalizer(images[0].clone())], pred_instances, @@ -65,13 +65,13 @@ def forward(self, images, features, proposals, targets=None): @ROI_HEADS_REGISTRY.register() class CustomCascadeROIHeads(CascadeROIHeads): @classmethod - def _init_box_head(self, cfg, input_shape): - self.mult_proposal_score = cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE + def _init_box_head(cls, cfg, input_shape): + cls.mult_proposal_score = cfg.MODEL.ROI_BOX_HEAD.MULT_PROPOSAL_SCORE ret = super()._init_box_head(cfg, input_shape) del ret["box_predictors"] cascade_bbox_reg_weights = cfg.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS box_predictors = [] - for box_head, bbox_reg_weights in zip(ret["box_heads"], cascade_bbox_reg_weights): + for box_head, bbox_reg_weights in zip(ret["box_heads"], cascade_bbox_reg_weights, strict=False): box_predictors.append( CustomFastRCNNOutputLayers( cfg, @@ -80,15 +80,15 @@ def _init_box_head(self, cfg, input_shape): ) ) ret["box_predictors"] = box_predictors - self.debug = cfg.DEBUG - if self.debug: - self.debug_show_name = cfg.DEBUG_SHOW_NAME - self.save_debug = cfg.SAVE_DEBUG - self.vis_thresh = cfg.VIS_THRESH - self.pixel_mean = ( + cls.debug = cfg.DEBUG + if cls.debug: + cls.debug_show_name = cfg.DEBUG_SHOW_NAME + cls.save_debug = cfg.SAVE_DEBUG + cls.vis_thresh = cfg.VIS_THRESH + cls.pixel_mean = ( torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1) ) - self.pixel_std = ( + cls.pixel_std = ( torch.Tensor(cfg.MODEL.PIXEL_STD).to(torch.device(cfg.MODEL.DEVICE)).view(3, 1, 1) ) return ret @@ -120,20 +120,20 @@ def _forward_box(self, features, proposals, targets=None): losses = {} storage = get_event_storage() for stage, (predictor, predictions, proposals) in enumerate(head_outputs): - with storage.name_scope("stage{}".format(stage)): + with storage.name_scope(f"stage{stage}"): stage_losses = predictor.losses(predictions, proposals) - losses.update({k + "_stage{}".format(stage): v for k, v in stage_losses.items()}) + losses.update({k + f"_stage{stage}": v for k, v in stage_losses.items()}) return losses else: # Each is a list[Tensor] of length #image. Each tensor is Ri x (K+1) scores_per_stage = [h[0].predict_probs(h[1], h[2]) for h in head_outputs] scores = [ sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages) - for scores_per_image in zip(*scores_per_stage) + for scores_per_image in zip(*scores_per_stage, strict=False) ] if self.mult_proposal_score: - scores = [(s * ps[:, None]) ** 0.5 for s, ps in zip(scores, proposal_scores)] + scores = [(s * ps[:, None]) ** 0.5 for s, ps in zip(scores, proposal_scores, strict=False)] predictor, predictions, proposals = head_outputs[-1] boxes = predictor.predict_boxes(predictions, proposals) @@ -169,7 +169,8 @@ def forward(self, images, features, proposals, targets=None): if self.debug: from ..debug import debug_second_stage - denormalizer = lambda x: x * self.pixel_std + self.pixel_mean + def denormalizer(x): + return x * self.pixel_std + self.pixel_mean debug_second_stage( [denormalizer(x.clone()) for x in images], pred_instances, diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/fed_loss.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/fed_loss.py index d10e826786..8a41607ea9 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/fed_loss.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/roi_heads/fed_loss.py @@ -1,15 +1,16 @@ -import torch import json +import torch + -def load_class_freq(path="datasets/lvis/lvis_v1_train_cat_info.json", freq_weight=0.5): - cat_info = json.load(open(path, "r")) +def load_class_freq(path: str="datasets/lvis/lvis_v1_train_cat_info.json", freq_weight: float=0.5): + cat_info = json.load(open(path)) cat_info = torch.tensor([c["image_count"] for c in sorted(cat_info, key=lambda x: x["id"])]) freq_weight = cat_info.float() ** freq_weight return freq_weight -def get_fed_loss_inds(gt_classes, num_sample_cats=50, C=1203, weight=None, fed_cls_inds=-1): +def get_fed_loss_inds(gt_classes, num_sample_cats: int=50, C: int=1203, weight=None, fed_cls_inds=-1): appeared = torch.unique(gt_classes) # C' prob = appeared.new_ones(C + 1).float() prob[-1] = 0 diff --git a/dimos/models/Detic/third_party/CenterNet2/demo.py b/dimos/models/Detic/third_party/CenterNet2/demo.py index 281063f61b..3177d838ac 100644 --- a/dimos/models/Detic/third_party/CenterNet2/demo.py +++ b/dimos/models/Detic/third_party/CenterNet2/demo.py @@ -4,22 +4,21 @@ import multiprocessing as mp import os import time -import cv2 -import tqdm +from centernet.config import add_centernet_config +import cv2 from detectron2.config import get_cfg from detectron2.data.detection_utils import read_image from detectron2.utils.logger import setup_logger - from predictor import VisualizationDemo -from centernet.config import add_centernet_config +import tqdm # constants WINDOW_NAME = "CenterNet2 detections" +from detectron2.data import MetadataCatalog from detectron2.utils.video_visualizer import VideoVisualizer from detectron2.utils.visualizer import ColorMode -from detectron2.data import MetadataCatalog def setup_cfg(args): diff --git a/dimos/models/Detic/third_party/CenterNet2/predictor.py b/dimos/models/Detic/third_party/CenterNet2/predictor.py index 990040fc03..0bdee56264 100644 --- a/dimos/models/Detic/third_party/CenterNet2/predictor.py +++ b/dimos/models/Detic/third_party/CenterNet2/predictor.py @@ -1,19 +1,19 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import atexit import bisect -import multiprocessing as mp from collections import deque -import cv2 -import torch +import multiprocessing as mp +import cv2 from detectron2.data import MetadataCatalog from detectron2.engine.defaults import DefaultPredictor from detectron2.utils.video_visualizer import VideoVisualizer from detectron2.utils.visualizer import ColorMode, Visualizer +import torch -class VisualizationDemo(object): - def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel=False): +class VisualizationDemo: + def __init__(self, cfg, instance_mode=ColorMode.IMAGE, parallel: bool=False) -> None: """ Args: cfg (CfgNode): @@ -161,13 +161,13 @@ class _StopToken: pass class _PredictWorker(mp.Process): - def __init__(self, cfg, task_queue, result_queue): + def __init__(self, cfg, task_queue, result_queue) -> None: self.cfg = cfg self.task_queue = task_queue self.result_queue = result_queue super().__init__() - def run(self): + def run(self) -> None: predictor = DefaultPredictor(self.cfg) while True: @@ -178,7 +178,7 @@ def run(self): result = predictor(data) self.result_queue.put((idx, result)) - def __init__(self, cfg, num_gpus: int = 1): + def __init__(self, cfg, num_gpus: int = 1) -> None: """ Args: cfg (CfgNode): @@ -191,7 +191,7 @@ def __init__(self, cfg, num_gpus: int = 1): for gpuid in range(max(num_gpus, 1)): cfg = cfg.clone() cfg.defrost() - cfg.MODEL.DEVICE = "cuda:{}".format(gpuid) if num_gpus > 0 else "cpu" + cfg.MODEL.DEVICE = f"cuda:{gpuid}" if num_gpus > 0 else "cpu" self.procs.append( AsyncPredictor._PredictWorker(cfg, self.task_queue, self.result_queue) ) @@ -205,7 +205,7 @@ def __init__(self, cfg, num_gpus: int = 1): p.start() atexit.register(self.shutdown) - def put(self, image): + def put(self, image) -> None: self.put_idx += 1 self.task_queue.put((self.put_idx, image)) @@ -225,14 +225,14 @@ def get(self): self.result_rank.insert(insert, idx) self.result_data.insert(insert, res) - def __len__(self): + def __len__(self) -> int: return self.put_idx - self.get_idx def __call__(self, image): self.put(image) return self.get() - def shutdown(self): + def shutdown(self) -> None: for _ in self.procs: self.task_queue.put(AsyncPredictor._StopToken()) diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/analyze_model.py b/dimos/models/Detic/third_party/CenterNet2/tools/analyze_model.py index 75a4a794df..7b7b9e3432 100755 --- a/dimos/models/Detic/third_party/CenterNet2/tools/analyze_model.py +++ b/dimos/models/Detic/third_party/CenterNet2/tools/analyze_model.py @@ -1,11 +1,7 @@ -# -*- coding: utf-8 -*- # Copyright (c) Facebook, Inc. and its affiliates. -import logging -import numpy as np from collections import Counter -import tqdm -from fvcore.nn import flop_count_table # can also try flop_count_str +import logging from detectron2.checkpoint import DetectionCheckpointer from detectron2.config import CfgNode, LazyConfig, get_cfg, instantiate @@ -18,6 +14,9 @@ parameter_count_table, ) from detectron2.utils.logger import setup_logger +from fvcore.nn import flop_count_table # can also try flop_count_str +import numpy as np +import tqdm logger = logging.getLogger("detectron2") @@ -37,7 +36,7 @@ def setup(args): return cfg -def do_flop(cfg): +def do_flop(cfg) -> None: if isinstance(cfg, CfgNode): data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0]) model = build_model(cfg) @@ -64,11 +63,11 @@ def do_flop(cfg): + str([(k, v / (idx + 1) / 1e9) for k, v in counts.items()]) ) logger.info( - "Total GFlops: {:.1f}±{:.1f}".format(np.mean(total_flops) / 1e9, np.std(total_flops) / 1e9) + f"Total GFlops: {np.mean(total_flops) / 1e9:.1f}±{np.std(total_flops) / 1e9:.1f}" ) -def do_activation(cfg): +def do_activation(cfg) -> None: if isinstance(cfg, CfgNode): data_loader = build_detection_test_loader(cfg, cfg.DATASETS.TEST[0]) model = build_model(cfg) @@ -91,13 +90,11 @@ def do_activation(cfg): + str([(k, v / idx) for k, v in counts.items()]) ) logger.info( - "Total (Million) Activations: {}±{}".format( - np.mean(total_activations), np.std(total_activations) - ) + f"Total (Million) Activations: {np.mean(total_activations)}±{np.std(total_activations)}" ) -def do_parameter(cfg): +def do_parameter(cfg) -> None: if isinstance(cfg, CfgNode): model = build_model(cfg) else: @@ -105,7 +102,7 @@ def do_parameter(cfg): logger.info("Parameter Count:\n" + parameter_count_table(model, max_depth=5)) -def do_structure(cfg): +def do_structure(cfg) -> None: if isinstance(cfg, CfgNode): model = build_model(cfg) else: diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/benchmark.py b/dimos/models/Detic/third_party/CenterNet2/tools/benchmark.py index c2d673fab1..48f398d83d 100755 --- a/dimos/models/Detic/third_party/CenterNet2/tools/benchmark.py +++ b/dimos/models/Detic/third_party/CenterNet2/tools/benchmark.py @@ -8,11 +8,6 @@ import itertools import logging -import psutil -import torch -import tqdm -from fvcore.common.timer import Timer -from torch.nn.parallel import DistributedDataParallel from detectron2.checkpoint import DetectionCheckpointer from detectron2.config import LazyConfig, get_cfg, instantiate @@ -29,6 +24,11 @@ from detectron2.utils.collect_env import collect_env_info from detectron2.utils.events import CommonMetricPrinter from detectron2.utils.logger import setup_logger +from fvcore.common.timer import Timer +import psutil +import torch +from torch.nn.parallel import DistributedDataParallel +import tqdm logger = logging.getLogger("detectron2") @@ -59,14 +59,12 @@ def create_data_benchmark(cfg, args): return instantiate(kwargs) -def RAM_msg(): +def RAM_msg() -> str: vram = psutil.virtual_memory() - return "RAM Usage: {:.2f}/{:.2f} GB".format( - (vram.total - vram.available) / 1024**3, vram.total / 1024**3 - ) + return f"RAM Usage: {(vram.total - vram.available) / 1024**3:.2f}/{vram.total / 1024**3:.2f} GB" -def benchmark_data(args): +def benchmark_data(args) -> None: cfg = setup(args) logger.info("After spawning " + RAM_msg()) @@ -78,7 +76,7 @@ def benchmark_data(args): benchmark.benchmark_distributed(250, 1) -def benchmark_data_advanced(args): +def benchmark_data_advanced(args) -> None: # benchmark dataloader with more details to help analyze performance bottleneck cfg = setup(args) benchmark = create_data_benchmark(cfg, args) @@ -94,10 +92,10 @@ def benchmark_data_advanced(args): benchmark.benchmark_distributed(100) -def benchmark_train(args): +def benchmark_train(args) -> None: cfg = setup(args) model = build_model(cfg) - logger.info("Model:\n{}".format(model)) + logger.info(f"Model:\n{model}") if comm.get_world_size() > 1: model = DistributedDataParallel( model, device_ids=[comm.get_local_rank()], broadcast_buffers=False @@ -131,7 +129,7 @@ def f(): @torch.no_grad() -def benchmark_eval(args): +def benchmark_eval(args) -> None: cfg = setup(args) if args.config_file.endswith(".yaml"): model = build_model(cfg) @@ -149,7 +147,7 @@ def benchmark_eval(args): data_loader = instantiate(cfg.dataloader.test) model.eval() - logger.info("Model:\n{}".format(model)) + logger.info(f"Model:\n{model}") dummy_data = DatasetFromList(list(itertools.islice(data_loader, 100)), copy=False) def f(): @@ -167,7 +165,7 @@ def f(): break model(d) pbar.update() - logger.info("{} iters in {} seconds.".format(max_iter, timer.seconds())) + logger.info(f"{max_iter} iters in {timer.seconds()} seconds.") if __name__ == "__main__": diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/convert-torchvision-to-d2.py b/dimos/models/Detic/third_party/CenterNet2/tools/convert-torchvision-to-d2.py index 4b827d960c..8bf0565d5e 100755 --- a/dimos/models/Detic/third_party/CenterNet2/tools/convert-torchvision-to-d2.py +++ b/dimos/models/Detic/third_party/CenterNet2/tools/convert-torchvision-to-d2.py @@ -3,6 +3,7 @@ import pickle as pkl import sys + import torch """ @@ -40,9 +41,9 @@ if "layer" not in k: k = "stem." + k for t in [1, 2, 3, 4]: - k = k.replace("layer{}".format(t), "res{}".format(t + 1)) + k = k.replace(f"layer{t}", f"res{t + 1}") for t in [1, 2, 3]: - k = k.replace("bn{}".format(t), "conv{}.norm".format(t)) + k = k.replace(f"bn{t}", f"conv{t}.norm") k = k.replace("downsample.0", "shortcut") k = k.replace("downsample.1", "shortcut.norm") print(old_k, "->", k) diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/deploy/export_model.py b/dimos/models/Detic/third_party/CenterNet2/tools/deploy/export_model.py index 067309f241..4d76a57b76 100755 --- a/dimos/models/Detic/third_party/CenterNet2/tools/deploy/export_model.py +++ b/dimos/models/Detic/third_party/CenterNet2/tools/deploy/export_model.py @@ -2,14 +2,11 @@ # Copyright (c) Facebook, Inc. and its affiliates. import argparse import os -from typing import Dict, List, Tuple -import torch -from torch import Tensor, nn -import detectron2.data.transforms as T from detectron2.checkpoint import DetectionCheckpointer from detectron2.config import get_cfg from detectron2.data import build_detection_test_loader, detection_utils +import detectron2.data.transforms as T from detectron2.evaluation import COCOEvaluator, inference_on_dataset, print_csv_format from detectron2.export import TracingAdapter, dump_torchscript_IR, scripting_with_instances from detectron2.modeling import GeneralizedRCNN, RetinaNet, build_model @@ -19,6 +16,8 @@ from detectron2.utils.env import TORCH_VERSION from detectron2.utils.file_io import PathManager from detectron2.utils.logger import setup_logger +import torch +from torch import Tensor, nn def setup_cfg(args): @@ -72,7 +71,7 @@ def export_scripting(torch_model): class ScriptableAdapterBase(nn.Module): # Use this adapter to workaround https://github.com/pytorch/pytorch/issues/46944 # by not retuning instances but dicts. Otherwise the exported model is not deployable - def __init__(self): + def __init__(self) -> None: super().__init__() self.model = torch_model self.eval() @@ -80,14 +79,14 @@ def __init__(self): if isinstance(torch_model, GeneralizedRCNN): class ScriptableAdapter(ScriptableAdapterBase): - def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]: + def forward(self, inputs: tuple[dict[str, torch.Tensor]]) -> list[dict[str, Tensor]]: instances = self.model.inference(inputs, do_postprocess=False) return [i.get_fields() for i in instances] else: class ScriptableAdapter(ScriptableAdapterBase): - def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]: + def forward(self, inputs: tuple[dict[str, torch.Tensor]]) -> list[dict[str, Tensor]]: instances = self.model(inputs) return [i.get_fields() for i in instances] diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/lazyconfig_train_net.py b/dimos/models/Detic/third_party/CenterNet2/tools/lazyconfig_train_net.py index 506e8baff6..8f40a40c39 100755 --- a/dimos/models/Detic/third_party/CenterNet2/tools/lazyconfig_train_net.py +++ b/dimos/models/Detic/third_party/CenterNet2/tools/lazyconfig_train_net.py @@ -42,7 +42,7 @@ def do_test(cfg, model): return ret -def do_train(args, cfg): +def do_train(args, cfg) -> None: """ Args: cfg: an object with the following attributes: @@ -63,7 +63,7 @@ def do_train(args, cfg): """ model = instantiate(cfg.model) logger = logging.getLogger("detectron2") - logger.info("Model:\n{}".format(model)) + logger.info(f"Model:\n{model}") model.to(cfg.train.device) cfg.optimizer.params.model = model @@ -105,7 +105,7 @@ def do_train(args, cfg): trainer.train(start_iter, cfg.train.max_iter) -def main(args): +def main(args) -> None: cfg = LazyConfig.load(args.config_file) cfg = LazyConfig.apply_overrides(cfg, args.opts) default_setup(cfg, args) diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/lightning_train_net.py b/dimos/models/Detic/third_party/CenterNet2/tools/lightning_train_net.py index 037957bac6..2b3ccc80b4 100644 --- a/dimos/models/Detic/third_party/CenterNet2/tools/lightning_train_net.py +++ b/dimos/models/Detic/third_party/CenterNet2/tools/lightning_train_net.py @@ -5,14 +5,13 @@ # Depending on how you launch the trainer, there are issues with processes terminating correctly # This module is still dependent on D2 logging, but could be transferred to use Lightning logging +from collections import OrderedDict import logging import os import time +from typing import Any import weakref -from collections import OrderedDict -from typing import Any, Dict, List -import detectron2.utils.comm as comm from detectron2.checkpoint import DetectionCheckpointer from detectron2.config import get_cfg from detectron2.data import build_detection_test_loader, build_detection_train_loader @@ -28,9 +27,9 @@ from detectron2.evaluation.testing import flatten_results_dict from detectron2.modeling import build_model from detectron2.solver import build_lr_scheduler, build_optimizer +import detectron2.utils.comm as comm from detectron2.utils.events import EventStorage from detectron2.utils.logger import setup_logger - import pytorch_lightning as pl # type: ignore from pytorch_lightning import LightningDataModule, LightningModule from train_net import build_evaluator @@ -40,7 +39,7 @@ class TrainingModule(LightningModule): - def __init__(self, cfg): + def __init__(self, cfg) -> None: super().__init__() if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2 setup_logger() @@ -51,14 +50,14 @@ def __init__(self, cfg): self.start_iter = 0 self.max_iter = cfg.SOLVER.MAX_ITER - def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None: checkpoint["iteration"] = self.storage.iter - def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]) -> None: + def on_load_checkpoint(self, checkpointed_state: dict[str, Any]) -> None: self.start_iter = checkpointed_state["iteration"] self.storage.iter = self.start_iter - def setup(self, stage: str): + def setup(self, stage: str) -> None: if self.cfg.MODEL.WEIGHTS: self.checkpointer = DetectionCheckpointer( # Assume you want to save checkpoints together with logs/statistics @@ -110,7 +109,7 @@ def training_step_end(self, training_step_outpus): self.data_start = time.perf_counter() return training_step_outpus - def training_epoch_end(self, training_step_outputs): + def training_epoch_end(self, training_step_outputs) -> None: self.iteration_timer.after_train() if comm.is_main_process(): self.checkpointer.save("model_final") @@ -127,17 +126,17 @@ def _process_dataset_evaluation_results(self) -> OrderedDict: print_csv_format(results[dataset_name]) if len(results) == 1: - results = list(results.values())[0] + results = next(iter(results.values())) return results - def _reset_dataset_evaluators(self): + def _reset_dataset_evaluators(self) -> None: self._evaluators = [] for dataset_name in self.cfg.DATASETS.TEST: evaluator = build_evaluator(self.cfg, dataset_name) evaluator.reset() self._evaluators.append(evaluator) - def on_validation_epoch_start(self, _outputs): + def on_validation_epoch_start(self, _outputs) -> None: self._reset_dataset_evaluators() def validation_epoch_end(self, _outputs): @@ -149,14 +148,12 @@ def validation_epoch_end(self, _outputs): v = float(v) except Exception as e: raise ValueError( - "[EvalHook] eval_function should return a nested dict of float. Got '{}: {}' instead.".format( - k, v - ) + f"[EvalHook] eval_function should return a nested dict of float. Got '{k}: {v}' instead." ) from e self.storage.put_scalars(**flattened_results, smoothing_hint=False) def validation_step(self, batch, batch_idx: int, dataloader_idx: int = 0) -> None: - if not isinstance(batch, List): + if not isinstance(batch, list): batch = [batch] outputs = self.model(batch) self._evaluators[dataloader_idx].process(batch, outputs) @@ -169,7 +166,7 @@ def configure_optimizers(self): class DataModule(LightningDataModule): - def __init__(self, cfg): + def __init__(self, cfg) -> None: super().__init__() self.cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size()) @@ -183,12 +180,12 @@ def val_dataloader(self): return dataloaders -def main(args): +def main(args) -> None: cfg = setup(args) train(cfg, args) -def train(cfg, args): +def train(cfg, args) -> None: trainer_params = { # training loop is bounded by max steps, use a large max_epochs to make # sure max_steps is met first diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/plain_train_net.py b/dimos/models/Detic/third_party/CenterNet2/tools/plain_train_net.py index 2ff9080f7f..a06d19aff2 100755 --- a/dimos/models/Detic/third_party/CenterNet2/tools/plain_train_net.py +++ b/dimos/models/Detic/third_party/CenterNet2/tools/plain_train_net.py @@ -19,13 +19,10 @@ It also includes fewer abstraction, therefore is easier to add custom logic. """ +from collections import OrderedDict import logging import os -from collections import OrderedDict -import torch -from torch.nn.parallel import DistributedDataParallel -import detectron2.utils.comm as comm from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer from detectron2.config import get_cfg from detectron2.data import ( @@ -48,12 +45,15 @@ ) from detectron2.modeling import build_model from detectron2.solver import build_lr_scheduler, build_optimizer +import detectron2.utils.comm as comm from detectron2.utils.events import EventStorage +import torch +from torch.nn.parallel import DistributedDataParallel logger = logging.getLogger("detectron2") -def get_evaluator(cfg, dataset_name, output_folder=None): +def get_evaluator(cfg, dataset_name: str, output_folder=None): """ Create evaluator(s) for a given dataset. This uses the special metadata "evaluator_type" associated with each builtin dataset. @@ -92,7 +92,7 @@ def get_evaluator(cfg, dataset_name, output_folder=None): return LVISEvaluator(dataset_name, cfg, True, output_folder) if len(evaluator_list) == 0: raise NotImplementedError( - "no Evaluator for the dataset {} with the type {}".format(dataset_name, evaluator_type) + f"no Evaluator for the dataset {dataset_name} with the type {evaluator_type}" ) if len(evaluator_list) == 1: return evaluator_list[0] @@ -109,14 +109,14 @@ def do_test(cfg, model): results_i = inference_on_dataset(model, data_loader, evaluator) results[dataset_name] = results_i if comm.is_main_process(): - logger.info("Evaluation results for {} in csv format:".format(dataset_name)) + logger.info(f"Evaluation results for {dataset_name} in csv format:") print_csv_format(results_i) if len(results) == 1: - results = list(results.values())[0] + results = next(iter(results.values())) return results -def do_train(cfg, model, resume=False): +def do_train(cfg, model, resume: bool=False) -> None: model.train() optimizer = build_optimizer(cfg, model) scheduler = build_lr_scheduler(cfg, optimizer) @@ -138,9 +138,9 @@ def do_train(cfg, model, resume=False): # compared to "train_net.py", we do not support accurate timing and # precise BN here, because they are not trivial to implement in a small training loop data_loader = build_detection_train_loader(cfg) - logger.info("Starting training from iteration {}".format(start_iter)) + logger.info(f"Starting training from iteration {start_iter}") with EventStorage(start_iter) as storage: - for data, iteration in zip(data_loader, range(start_iter, max_iter)): + for data, iteration in zip(data_loader, range(start_iter, max_iter), strict=False): storage.iter = iteration loss_dict = model(data) @@ -193,7 +193,7 @@ def main(args): cfg = setup(args) model = build_model(cfg) - logger.info("Model:\n{}".format(model)) + logger.info(f"Model:\n{model}") if args.eval_only: DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( cfg.MODEL.WEIGHTS, resume=args.resume diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/train_net.py b/dimos/models/Detic/third_party/CenterNet2/tools/train_net.py index 10334aa1d8..deb2ca6db8 100755 --- a/dimos/models/Detic/third_party/CenterNet2/tools/train_net.py +++ b/dimos/models/Detic/third_party/CenterNet2/tools/train_net.py @@ -16,12 +16,10 @@ You may want to write your own script with your datasets and other customizations. """ +from collections import OrderedDict import logging import os -from collections import OrderedDict -import torch -import detectron2.utils.comm as comm from detectron2.checkpoint import DetectionCheckpointer from detectron2.config import get_cfg from detectron2.data import MetadataCatalog @@ -38,9 +36,11 @@ verify_results, ) from detectron2.modeling import GeneralizedRCNNWithTTA +import detectron2.utils.comm as comm +import torch -def build_evaluator(cfg, dataset_name, output_folder=None): +def build_evaluator(cfg, dataset_name: str, output_folder=None): """ Create evaluator(s) for a given dataset. This uses the special metadata "evaluator_type" associated with each builtin dataset. @@ -79,7 +79,7 @@ def build_evaluator(cfg, dataset_name, output_folder=None): return LVISEvaluator(dataset_name, output_dir=output_folder) if len(evaluator_list) == 0: raise NotImplementedError( - "no Evaluator for the dataset {} with the type {}".format(dataset_name, evaluator_type) + f"no Evaluator for the dataset {dataset_name} with the type {evaluator_type}" ) elif len(evaluator_list) == 1: return evaluator_list[0] @@ -95,7 +95,7 @@ class Trainer(DefaultTrainer): """ @classmethod - def build_evaluator(cls, cfg, dataset_name, output_folder=None): + def build_evaluator(cls, cfg, dataset_name: str, output_folder=None): return build_evaluator(cfg, dataset_name, output_folder) @classmethod diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/visualize_data.py b/dimos/models/Detic/third_party/CenterNet2/tools/visualize_data.py index fd0ba8347b..99abfdff4e 100755 --- a/dimos/models/Detic/third_party/CenterNet2/tools/visualize_data.py +++ b/dimos/models/Detic/third_party/CenterNet2/tools/visualize_data.py @@ -1,17 +1,21 @@ #!/usr/bin/env python # Copyright (c) Facebook, Inc. and its affiliates. import argparse -import os from itertools import chain -import cv2 -import tqdm +import os +import cv2 from detectron2.config import get_cfg -from detectron2.data import DatasetCatalog, MetadataCatalog, build_detection_train_loader -from detectron2.data import detection_utils as utils +from detectron2.data import ( + DatasetCatalog, + MetadataCatalog, + build_detection_train_loader, + detection_utils as utils, +) from detectron2.data.build import filter_images_with_few_keypoints from detectron2.utils.logger import setup_logger from detectron2.utils.visualizer import Visualizer +import tqdm def setup(args): @@ -54,14 +58,14 @@ def parse_args(in_args=None): os.makedirs(dirname, exist_ok=True) metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0]) - def output(vis, fname): + def output(vis, fname) -> None: if args.show: print(fname) cv2.imshow("window", vis.get_image()[:, :, ::-1]) cv2.waitKey() else: filepath = os.path.join(dirname, fname) - print("Saving to {} ...".format(filepath)) + print(f"Saving to {filepath} ...") vis.save(filepath) scale = 1.0 diff --git a/dimos/models/Detic/third_party/CenterNet2/tools/visualize_json_results.py b/dimos/models/Detic/third_party/CenterNet2/tools/visualize_json_results.py index 472190e0b3..04dea72446 100755 --- a/dimos/models/Detic/third_party/CenterNet2/tools/visualize_json_results.py +++ b/dimos/models/Detic/third_party/CenterNet2/tools/visualize_json_results.py @@ -2,21 +2,21 @@ # Copyright (c) Facebook, Inc. and its affiliates. import argparse +from collections import defaultdict import json -import numpy as np import os -from collections import defaultdict -import cv2 -import tqdm +import cv2 from detectron2.data import DatasetCatalog, MetadataCatalog from detectron2.structures import Boxes, BoxMode, Instances from detectron2.utils.file_io import PathManager from detectron2.utils.logger import setup_logger from detectron2.utils.visualizer import Visualizer +import numpy as np +import tqdm -def create_instances(predictions, image_size): +def create_instances(predictions, image_size: int): ret = Instances(image_size) score = np.asarray([x["score"] for x in predictions]) @@ -71,7 +71,7 @@ def dataset_id_map(ds_id): return ds_id - 1 else: - raise ValueError("Unsupported dataset: {}".format(args.dataset)) + raise ValueError(f"Unsupported dataset: {args.dataset}") os.makedirs(args.output, exist_ok=True) diff --git a/dimos/models/Detic/third_party/CenterNet2/train_net.py b/dimos/models/Detic/third_party/CenterNet2/train_net.py index 1ca9f4cdd7..92859d7586 100644 --- a/dimos/models/Detic/third_party/CenterNet2/train_net.py +++ b/dimos/models/Detic/third_party/CenterNet2/train_net.py @@ -1,21 +1,20 @@ +from collections import OrderedDict +import datetime import logging import os -from collections import OrderedDict -import torch -from torch.nn.parallel import DistributedDataParallel import time -import datetime -from fvcore.common.timer import Timer -import detectron2.utils.comm as comm +from centernet.config import add_centernet_config +from centernet.data.custom_build_augmentation import build_custom_augmentation from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer from detectron2.config import get_cfg from detectron2.data import ( MetadataCatalog, build_detection_test_loader, ) +from detectron2.data.build import build_detection_train_loader +from detectron2.data.dataset_mapper import DatasetMapper from detectron2.engine import default_argument_parser, default_setup, launch - from detectron2.evaluation import ( COCOEvaluator, LVISEvaluator, @@ -23,19 +22,18 @@ print_csv_format, ) from detectron2.modeling import build_model +from detectron2.modeling.test_time_augmentation import GeneralizedRCNNWithTTA from detectron2.solver import build_lr_scheduler, build_optimizer +import detectron2.utils.comm as comm from detectron2.utils.events import ( CommonMetricPrinter, EventStorage, JSONWriter, TensorboardXWriter, ) -from detectron2.modeling.test_time_augmentation import GeneralizedRCNNWithTTA -from detectron2.data.dataset_mapper import DatasetMapper -from detectron2.data.build import build_detection_train_loader - -from centernet.config import add_centernet_config -from centernet.data.custom_build_augmentation import build_custom_augmentation +from fvcore.common.timer import Timer +import torch +from torch.nn.parallel import DistributedDataParallel logger = logging.getLogger("detectron2") @@ -49,7 +47,7 @@ def do_test(cfg, model): else DatasetMapper(cfg, False, augmentations=build_custom_augmentation(cfg, False)) ) data_loader = build_detection_test_loader(cfg, dataset_name, mapper=mapper) - output_folder = os.path.join(cfg.OUTPUT_DIR, "inference_{}".format(dataset_name)) + output_folder = os.path.join(cfg.OUTPUT_DIR, f"inference_{dataset_name}") evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type if evaluator_type == "lvis": @@ -61,14 +59,14 @@ def do_test(cfg, model): results[dataset_name] = inference_on_dataset(model, data_loader, evaluator) if comm.is_main_process(): - logger.info("Evaluation results for {} in csv format:".format(dataset_name)) + logger.info(f"Evaluation results for {dataset_name} in csv format:") print_csv_format(results[dataset_name]) if len(results) == 1: - results = list(results.values())[0] + results = next(iter(results.values())) return results -def do_train(cfg, model, resume=False): +def do_train(cfg, model, resume: bool=False) -> None: model.train() optimizer = build_optimizer(cfg, model) scheduler = build_lr_scheduler(cfg, optimizer) @@ -115,12 +113,12 @@ def do_train(cfg, model, resume=False): data_loader = build_custom_train_loader(cfg, mapper=mapper) - logger.info("Starting training from iteration {}".format(start_iter)) + logger.info(f"Starting training from iteration {start_iter}") with EventStorage(start_iter) as storage: step_timer = Timer() data_timer = Timer() start_time = time.perf_counter() - for data, iteration in zip(data_loader, range(start_iter, max_iter)): + for data, iteration in zip(data_loader, range(start_iter, max_iter), strict=False): data_time = data_timer.seconds() storage.put_scalars(data_time=data_time) step_timer.reset() @@ -162,7 +160,7 @@ def do_train(cfg, model, resume=False): total_time = time.perf_counter() - start_time logger.info( - "Total training time: {}".format(str(datetime.timedelta(seconds=int(total_time)))) + f"Total training time: {datetime.timedelta(seconds=int(total_time))!s}" ) @@ -176,8 +174,8 @@ def setup(args): cfg.merge_from_list(args.opts) if "/auto" in cfg.OUTPUT_DIR: file_name = os.path.basename(args.config_file)[:-5] - cfg.OUTPUT_DIR = cfg.OUTPUT_DIR.replace("/auto", "/{}".format(file_name)) - logger.info("OUTPUT_DIR: {}".format(cfg.OUTPUT_DIR)) + cfg.OUTPUT_DIR = cfg.OUTPUT_DIR.replace("/auto", f"/{file_name}") + logger.info(f"OUTPUT_DIR: {cfg.OUTPUT_DIR}") cfg.freeze() default_setup(cfg, args) return cfg @@ -187,7 +185,7 @@ def main(args): cfg = setup(args) model = build_model(cfg) - logger.info("Model:\n{}".format(model)) + logger.info(f"Model:\n{model}") if args.eval_only: DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( cfg.MODEL.WEIGHTS, resume=args.resume @@ -217,7 +215,7 @@ def main(args): args = args.parse_args() if args.manual_device != "": os.environ["CUDA_VISIBLE_DEVICES"] = args.manual_device - args.dist_url = "tcp://127.0.0.1:{}".format(torch.randint(11111, 60000, (1,))[0].item()) + args.dist_url = f"tcp://127.0.0.1:{torch.randint(11111, 60000, (1,))[0].item()}" print("Command Line Args:", args) launch( main, diff --git a/dimos/models/Detic/third_party/Deformable-DETR/benchmark.py b/dimos/models/Detic/third_party/Deformable-DETR/benchmark.py index 9830274aa6..3a4fcbd4e6 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/benchmark.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/benchmark.py @@ -8,15 +8,14 @@ Benchmark inference speed of Deformable DETR. """ +import argparse import os import time -import argparse - -import torch +from datasets import build_dataset from main import get_args_parser as get_main_args_parser from models import build_model -from datasets import build_dataset +import torch from util.misc import nested_tensor_from_tensor_list @@ -32,7 +31,7 @@ def get_benckmark_arg_parser(): @torch.no_grad() -def measure_average_inference_time(model, inputs, num_iters=100, warm_iters=5): +def measure_average_inference_time(model, inputs, num_iters: int=100, warm_iters: int=5): ts = [] for iter_ in range(num_iters): torch.cuda.synchronize() diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/__init__.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/__init__.py index d34b127147..870166e145 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/datasets/__init__.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/__init__.py @@ -8,9 +8,9 @@ # ------------------------------------------------------------------------ import torch.utils.data -from .torchvision_datasets import CocoDetection from .coco import build as build_coco +from .torchvision_datasets import CocoDetection def get_coco_api_from_dataset(dataset): diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco.py index 00e3d431ba..aa00ce49e3 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco.py @@ -15,14 +15,15 @@ from pathlib import Path +from pycocotools import mask as coco_mask import torch import torch.utils.data -from pycocotools import mask as coco_mask - -from .torchvision_datasets import CocoDetection as TvCocoDetection from util.misc import get_local_rank, get_local_size + import datasets.transforms as T +from .torchvision_datasets import CocoDetection as TvCocoDetection + class CocoDetection(TvCocoDetection): def __init__( @@ -31,11 +32,11 @@ def __init__( ann_file, transforms, return_masks, - cache_mode=False, - local_rank=0, - local_size=1, - ): - super(CocoDetection, self).__init__( + cache_mode: bool=False, + local_rank: int=0, + local_size: int=1, + ) -> None: + super().__init__( img_folder, ann_file, cache_mode=cache_mode, @@ -45,8 +46,8 @@ def __init__( self._transforms = transforms self.prepare = ConvertCocoPolysToMask(return_masks) - def __getitem__(self, idx): - img, target = super(CocoDetection, self).__getitem__(idx) + def __getitem__(self, idx: int): + img, target = super().__getitem__(idx) image_id = self.ids[idx] target = {"image_id": image_id, "annotations": target} img, target = self.prepare(img, target) @@ -55,7 +56,7 @@ def __getitem__(self, idx): return img, target -def convert_coco_poly_to_mask(segmentations, height, width): +def convert_coco_poly_to_mask(segmentations, height, width: int): masks = [] for polygons in segmentations: rles = coco_mask.frPyObjects(polygons, height, width) @@ -72,8 +73,8 @@ def convert_coco_poly_to_mask(segmentations, height, width): return masks -class ConvertCocoPolysToMask(object): - def __init__(self, return_masks=False): +class ConvertCocoPolysToMask: + def __init__(self, return_masks: bool=False) -> None: self.return_masks = return_masks def __call__(self, image, target): diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco_eval.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco_eval.py index b0b9a76d39..1714024b24 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco_eval.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco_eval.py @@ -15,21 +15,20 @@ in the end of the file, as python3 can suppress prints with contextlib """ -import os import contextlib import copy -import numpy as np -import torch +import os -from pycocotools.cocoeval import COCOeval +import numpy as np from pycocotools.coco import COCO +from pycocotools.cocoeval import COCOeval import pycocotools.mask as mask_util - +import torch from util.misc import all_gather -class CocoEvaluator(object): - def __init__(self, coco_gt, iou_types): +class CocoEvaluator: + def __init__(self, coco_gt, iou_types) -> None: assert isinstance(iou_types, (list, tuple)) coco_gt = copy.deepcopy(coco_gt) self.coco_gt = coco_gt @@ -42,7 +41,7 @@ def __init__(self, coco_gt, iou_types): self.img_ids = [] self.eval_imgs = {k: [] for k in iou_types} - def update(self, predictions): + def update(self, predictions) -> None: img_ids = list(np.unique(list(predictions.keys()))) self.img_ids.extend(img_ids) @@ -61,20 +60,20 @@ def update(self, predictions): self.eval_imgs[iou_type].append(eval_imgs) - def synchronize_between_processes(self): + def synchronize_between_processes(self) -> None: for iou_type in self.iou_types: self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) create_common_coco_eval( self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type] ) - def accumulate(self): + def accumulate(self) -> None: for coco_eval in self.coco_eval.values(): coco_eval.accumulate() - def summarize(self): + def summarize(self) -> None: for iou_type, coco_eval in self.coco_eval.items(): - print("IoU metric: {}".format(iou_type)) + print(f"IoU metric: {iou_type}") coco_eval.summarize() def prepare(self, predictions, iou_type): @@ -85,7 +84,7 @@ def prepare(self, predictions, iou_type): elif iou_type == "keypoints": return self.prepare_for_coco_keypoint(predictions) else: - raise ValueError("Unknown iou type {}".format(iou_type)) + raise ValueError(f"Unknown iou type {iou_type}") def prepare_for_coco_detection(self, predictions): coco_results = [] @@ -200,7 +199,7 @@ def merge(img_ids, eval_imgs): return merged_img_ids, merged_eval_imgs -def create_common_coco_eval(coco_eval, img_ids, eval_imgs): +def create_common_coco_eval(coco_eval, img_ids, eval_imgs) -> None: img_ids, eval_imgs = merge(img_ids, eval_imgs) img_ids = list(img_ids) eval_imgs = list(eval_imgs.flatten()) @@ -227,7 +226,7 @@ def evaluate(self): # add backward compatibility if useSegm is specified in params if p.useSegm is not None: p.iouType = "segm" if p.useSegm == 1 else "bbox" - print("useSegm (deprecated) is not None. Running {} evaluation".format(p.iouType)) + print(f"useSegm (deprecated) is not None. Running {p.iouType} evaluation") # print('Evaluate annotation type *{}*'.format(p.iouType)) p.imgIds = list(np.unique(p.imgIds)) if p.useCats: diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco_panoptic.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco_panoptic.py index f0697b63b2..d1dd9bda59 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco_panoptic.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/coco_panoptic.py @@ -11,18 +11,17 @@ from pathlib import Path import numpy as np -import torch -from PIL import Image - from panopticapi.utils import rgb2id +from PIL import Image +import torch from util.box_ops import masks_to_boxes from .coco import make_coco_transforms class CocoPanoptic: - def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_masks=True): - with open(ann_file, "r") as f: + def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_masks: bool=True) -> None: + with open(ann_file) as f: self.coco = json.load(f) # sort 'images' field so that they are aligned with 'annotations' @@ -30,7 +29,7 @@ def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_mas self.coco["images"] = sorted(self.coco["images"], key=lambda x: x["id"]) # sanity check if "annotations" in self.coco: - for img, ann in zip(self.coco["images"], self.coco["annotations"]): + for img, ann in zip(self.coco["images"], self.coco["annotations"], strict=False): assert img["file_name"][:-4] == ann["file_name"][:-4] self.img_folder = img_folder @@ -39,7 +38,7 @@ def __init__(self, img_folder, ann_folder, ann_file, transforms=None, return_mas self.transforms = transforms self.return_masks = return_masks - def __getitem__(self, idx): + def __getitem__(self, idx: int): ann_info = ( self.coco["annotations"][idx] if "annotations" in self.coco @@ -83,10 +82,10 @@ def __getitem__(self, idx): return img, target - def __len__(self): + def __len__(self) -> int: return len(self.coco["images"]) - def get_height_and_width(self, idx): + def get_height_and_width(self, idx: int): img_info = self.coco["images"][idx] height = img_info["height"] width = img_info["width"] diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/data_prefetcher.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/data_prefetcher.py index 731ebc19d4..4942500801 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/datasets/data_prefetcher.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/data_prefetcher.py @@ -14,7 +14,7 @@ def to_cuda(samples, targets, device): class data_prefetcher: - def __init__(self, loader, device, prefetch=True): + def __init__(self, loader, device, prefetch: bool=True) -> None: self.loader = iter(loader) self.prefetch = prefetch self.device = device @@ -22,7 +22,7 @@ def __init__(self, loader, device, prefetch=True): self.stream = torch.cuda.Stream() self.preload() - def preload(self): + def preload(self) -> None: try: self.next_samples, self.next_targets = next(self.loader) except StopIteration: @@ -61,7 +61,7 @@ def next(self): samples.record_stream(torch.cuda.current_stream()) if targets is not None: for t in targets: - for k, v in t.items(): + for _k, v in t.items(): v.record_stream(torch.cuda.current_stream()) self.preload() else: diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/panoptic_eval.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/panoptic_eval.py index ad606603a9..1a8ed7a82f 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/datasets/panoptic_eval.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/panoptic_eval.py @@ -18,8 +18,8 @@ pass -class PanopticEvaluator(object): - def __init__(self, ann_file, ann_folder, output_dir="panoptic_eval"): +class PanopticEvaluator: + def __init__(self, ann_file, ann_folder, output_dir: str="panoptic_eval") -> None: self.gt_json = ann_file self.gt_folder = ann_folder if utils.is_main_process(): @@ -28,14 +28,14 @@ def __init__(self, ann_file, ann_folder, output_dir="panoptic_eval"): self.output_dir = output_dir self.predictions = [] - def update(self, predictions): + def update(self, predictions) -> None: for p in predictions: with open(os.path.join(self.output_dir, p["file_name"]), "wb") as f: f.write(p.pop("png_string")) self.predictions += predictions - def synchronize_between_processes(self): + def synchronize_between_processes(self) -> None: all_predictions = utils.all_gather(self.predictions) merged_predictions = [] for p in all_predictions: diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/samplers.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/samplers.py index a8892f7561..b753d4ca3d 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/datasets/samplers.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/samplers.py @@ -6,8 +6,10 @@ # Modified from codes in torch.utils.data.distributed # ------------------------------------------------------------------------ -import os +from collections.abc import Iterator import math +import os + import torch import torch.distributed as dist from torch.utils.data.sampler import Sampler @@ -29,8 +31,8 @@ class DistributedSampler(Sampler): """ def __init__( - self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True - ): + self, dataset, num_replicas: int | None=None, rank=None, local_rank=None, local_size: int | None=None, shuffle: bool=True + ) -> None: if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") @@ -43,11 +45,11 @@ def __init__( self.num_replicas = num_replicas self.rank = rank self.epoch = 0 - self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.num_samples = math.ceil(len(self.dataset) * 1.0 / self.num_replicas) self.total_size = self.num_samples * self.num_replicas self.shuffle = shuffle - def __iter__(self): + def __iter__(self) -> Iterator: if self.shuffle: # deterministically shuffle based on epoch g = torch.Generator() @@ -67,10 +69,10 @@ def __iter__(self): return iter(indices) - def __len__(self): + def __len__(self) -> int: return self.num_samples - def set_epoch(self, epoch): + def set_epoch(self, epoch: int) -> None: self.epoch = epoch @@ -90,8 +92,8 @@ class NodeDistributedSampler(Sampler): """ def __init__( - self, dataset, num_replicas=None, rank=None, local_rank=None, local_size=None, shuffle=True - ): + self, dataset, num_replicas: int | None=None, rank=None, local_rank=None, local_size: int | None=None, shuffle: bool=True + ) -> None: if num_replicas is None: if not dist.is_available(): raise RuntimeError("Requires distributed package to be available") @@ -111,12 +113,12 @@ def __init__( self.rank = rank self.local_rank = local_rank self.epoch = 0 - self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.num_samples = math.ceil(len(self.dataset) * 1.0 / self.num_replicas) self.total_size = self.num_samples * self.num_replicas self.total_size_parts = self.num_samples * self.num_replicas // self.num_parts - def __iter__(self): + def __iter__(self) -> Iterator: if self.shuffle: # deterministically shuffle based on epoch g = torch.Generator() @@ -139,8 +141,8 @@ def __iter__(self): return iter(indices) - def __len__(self): + def __len__(self) -> int: return self.num_samples - def set_epoch(self, epoch): + def set_epoch(self, epoch: int) -> None: self.epoch = epoch diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/torchvision_datasets/coco.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/torchvision_datasets/coco.py index a634e37e47..65eb674294 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/datasets/torchvision_datasets/coco.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/torchvision_datasets/coco.py @@ -10,12 +10,13 @@ Copy-Paste from torchvision, but add utility of caching images on memory """ -from torchvision.datasets.vision import VisionDataset -from PIL import Image +from io import BytesIO import os import os.path + +from PIL import Image +from torchvision.datasets.vision import VisionDataset import tqdm -from io import BytesIO class CocoDetection(VisionDataset): @@ -38,11 +39,11 @@ def __init__( transform=None, target_transform=None, transforms=None, - cache_mode=False, - local_rank=0, - local_size=1, - ): - super(CocoDetection, self).__init__(root, transforms, transform, target_transform) + cache_mode: bool=False, + local_rank: int=0, + local_size: int=1, + ) -> None: + super().__init__(root, transforms, transform, target_transform) from pycocotools.coco import COCO self.coco = COCO(annFile) @@ -54,9 +55,9 @@ def __init__( self.cache = {} self.cache_images() - def cache_images(self): + def cache_images(self) -> None: self.cache = {} - for index, img_id in zip(tqdm.trange(len(self.ids)), self.ids): + for index, img_id in zip(tqdm.trange(len(self.ids)), self.ids, strict=False): if index % self.local_size != self.local_rank: continue path = self.coco.loadImgs(img_id)[0]["file_name"] @@ -91,5 +92,5 @@ def __getitem__(self, index): return img, target - def __len__(self): + def __len__(self) -> int: return len(self.ids) diff --git a/dimos/models/Detic/third_party/Deformable-DETR/datasets/transforms.py b/dimos/models/Detic/third_party/Deformable-DETR/datasets/transforms.py index 08a771d475..b10be480ee 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/datasets/transforms.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/datasets/transforms.py @@ -11,13 +11,13 @@ Transforms and data augmentation for both image + bbox. """ +from collections.abc import Sequence import random import PIL import torch import torchvision.transforms as T import torchvision.transforms.functional as F - from util.box_ops import box_xyxy_to_cxcywh from util.misc import interpolate @@ -68,7 +68,7 @@ def crop(image, target, region): def hflip(image, target): flipped_image = F.hflip(image) - w, h = image.size + w, _h = image.size target = target.copy() if "boxes" in target: @@ -84,16 +84,16 @@ def hflip(image, target): return flipped_image, target -def resize(image, target, size, max_size=None): +def resize(image, target, size: int, max_size: int | None=None): # size can be min_size (scalar) or (w, h) tuple - def get_size_with_aspect_ratio(image_size, size, max_size=None): + def get_size_with_aspect_ratio(image_size: int, size: int, max_size: int | None=None): w, h = image_size if max_size is not None: min_original_size = float(min((w, h))) max_original_size = float(max((w, h))) if max_original_size / min_original_size * size > max_size: - size = int(round(max_size * min_original_size / max_original_size)) + size = round(max_size * min_original_size / max_original_size) if (w <= h and w == size) or (h <= w and h == size): return (h, w) @@ -107,7 +107,7 @@ def get_size_with_aspect_ratio(image_size, size, max_size=None): return (oh, ow) - def get_size(image_size, size, max_size=None): + def get_size(image_size: int, size: int, max_size: int | None=None): if isinstance(size, (list, tuple)): return size[::-1] else: @@ -119,7 +119,7 @@ def get_size(image_size, size, max_size=None): if target is None: return rescaled_image, None - ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size)) + ratios = tuple(float(s) / float(s_orig) for s, s_orig in zip(rescaled_image.size, image.size, strict=False)) ratio_width, ratio_height = ratios target = target.copy() @@ -159,8 +159,8 @@ def pad(image, target, padding): return padded_image, target -class RandomCrop(object): - def __init__(self, size): +class RandomCrop: + def __init__(self, size: int) -> None: self.size = size def __call__(self, img, target): @@ -168,8 +168,8 @@ def __call__(self, img, target): return crop(img, target, region) -class RandomSizeCrop(object): - def __init__(self, min_size: int, max_size: int): +class RandomSizeCrop: + def __init__(self, min_size: int, max_size: int) -> None: self.min_size = min_size self.max_size = max_size @@ -180,20 +180,20 @@ def __call__(self, img: PIL.Image.Image, target: dict): return crop(img, target, region) -class CenterCrop(object): - def __init__(self, size): +class CenterCrop: + def __init__(self, size: int) -> None: self.size = size def __call__(self, img, target): image_width, image_height = img.size crop_height, crop_width = self.size - crop_top = int(round((image_height - crop_height) / 2.0)) - crop_left = int(round((image_width - crop_width) / 2.0)) + crop_top = round((image_height - crop_height) / 2.0) + crop_left = round((image_width - crop_width) / 2.0) return crop(img, target, (crop_top, crop_left, crop_height, crop_width)) -class RandomHorizontalFlip(object): - def __init__(self, p=0.5): +class RandomHorizontalFlip: + def __init__(self, p: float=0.5) -> None: self.p = p def __call__(self, img, target): @@ -202,8 +202,8 @@ def __call__(self, img, target): return img, target -class RandomResize(object): - def __init__(self, sizes, max_size=None): +class RandomResize: + def __init__(self, sizes: Sequence[int], max_size: int | None=None) -> None: assert isinstance(sizes, (list, tuple)) self.sizes = sizes self.max_size = max_size @@ -213,8 +213,8 @@ def __call__(self, img, target=None): return resize(img, target, size, self.max_size) -class RandomPad(object): - def __init__(self, max_pad): +class RandomPad: + def __init__(self, max_pad) -> None: self.max_pad = max_pad def __call__(self, img, target): @@ -223,13 +223,13 @@ def __call__(self, img, target): return pad(img, target, (pad_x, pad_y)) -class RandomSelect(object): +class RandomSelect: """ Randomly selects between transforms1 and transforms2, with probability p for transforms1 and (1 - p) for transforms2 """ - def __init__(self, transforms1, transforms2, p=0.5): + def __init__(self, transforms1, transforms2, p: float=0.5) -> None: self.transforms1 = transforms1 self.transforms2 = transforms2 self.p = p @@ -240,21 +240,21 @@ def __call__(self, img, target): return self.transforms2(img, target) -class ToTensor(object): +class ToTensor: def __call__(self, img, target): return F.to_tensor(img), target -class RandomErasing(object): - def __init__(self, *args, **kwargs): +class RandomErasing: + def __init__(self, *args, **kwargs) -> None: self.eraser = T.RandomErasing(*args, **kwargs) def __call__(self, img, target): return self.eraser(img), target -class Normalize(object): - def __init__(self, mean, std): +class Normalize: + def __init__(self, mean, std) -> None: self.mean = mean self.std = std @@ -272,8 +272,8 @@ def __call__(self, image, target=None): return image, target -class Compose(object): - def __init__(self, transforms): +class Compose: + def __init__(self, transforms) -> None: self.transforms = transforms def __call__(self, image, target): @@ -281,10 +281,10 @@ def __call__(self, image, target): image, target = t(image, target) return image, target - def __repr__(self): + def __repr__(self) -> str: format_string = self.__class__.__name__ + "(" for t in self.transforms: format_string += "\n" - format_string += " {0}".format(t) + format_string += f" {t}" format_string += "\n)" return format_string diff --git a/dimos/models/Detic/third_party/Deformable-DETR/engine.py b/dimos/models/Detic/third_party/Deformable-DETR/engine.py index f47471648c..9cee2a089b 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/engine.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/engine.py @@ -11,16 +11,16 @@ Train and eval functions used in main.py """ +from collections.abc import Iterable import math import os import sys -from typing import Iterable -import torch -import util.misc as utils from datasets.coco_eval import CocoEvaluator -from datasets.panoptic_eval import PanopticEvaluator from datasets.data_prefetcher import data_prefetcher +from datasets.panoptic_eval import PanopticEvaluator +import torch +import util.misc as utils def train_one_epoch( @@ -38,7 +38,7 @@ def train_one_epoch( metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}")) metric_logger.add_meter("class_error", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")) metric_logger.add_meter("grad_norm", utils.SmoothedValue(window_size=1, fmt="{value:.2f}")) - header = "Epoch: [{}]".format(epoch) + header = f"Epoch: [{epoch}]" print_freq = 10 prefetcher = data_prefetcher(data_loader, device, prefetch=True) @@ -62,7 +62,7 @@ def train_one_epoch( loss_value = losses_reduced_scaled.item() if not math.isfinite(loss_value): - print("Loss is {}, stopping training".format(loss_value)) + print(f"Loss is {loss_value}, stopping training") print(loss_dict_reduced) sys.exit(1) @@ -135,7 +135,7 @@ def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, out if "segm" in postprocessors.keys(): target_sizes = torch.stack([t["size"] for t in targets], dim=0) results = postprocessors["segm"](results, outputs, orig_target_sizes, target_sizes) - res = {target["image_id"].item(): output for target, output in zip(targets, results)} + res = {target["image_id"].item(): output for target, output in zip(targets, results, strict=False)} if coco_evaluator is not None: coco_evaluator.update(res) diff --git a/dimos/models/Detic/third_party/Deformable-DETR/main.py b/dimos/models/Detic/third_party/Deformable-DETR/main.py index ff91fd52a5..187b93a868 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/main.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/main.py @@ -11,19 +11,19 @@ import argparse import datetime import json +from pathlib import Path import random import time -from pathlib import Path -import numpy as np -import torch -from torch.utils.data import DataLoader import datasets -import util.misc as utils -import datasets.samplers as samplers from datasets import build_dataset, get_coco_api_from_dataset +import datasets.samplers as samplers from engine import evaluate, train_one_epoch from models import build_model +import numpy as np +import torch +from torch.utils.data import DataLoader +import util.misc as utils def get_args_parser(): @@ -168,9 +168,9 @@ def get_args_parser(): return parser -def main(args): +def main(args) -> None: utils.init_distributed_mode(args) - print("git:\n {}\n".format(utils.get_sha())) + print(f"git:\n {utils.get_sha()}\n") if args.frozen_weights is not None: assert args.masks, "Frozen training is meant for segmentation only" @@ -235,7 +235,7 @@ def match_name_keywords(n, name_keywords): break return out - for n, p in model_without_ddp.named_parameters(): + for n, _p in model_without_ddp.named_parameters(): print(n) param_dicts = [ @@ -306,9 +306,9 @@ def match_name_keywords(n, name_keywords): if not (k.endswith("total_params") or k.endswith("total_ops")) ] if len(missing_keys) > 0: - print("Missing Keys: {}".format(missing_keys)) + print(f"Missing Keys: {missing_keys}") if len(unexpected_keys) > 0: - print("Unexpected Keys: {}".format(unexpected_keys)) + print(f"Unexpected Keys: {unexpected_keys}") if ( not args.eval and "optimizer" in checkpoint @@ -319,7 +319,7 @@ def match_name_keywords(n, name_keywords): p_groups = copy.deepcopy(optimizer.param_groups) optimizer.load_state_dict(checkpoint["optimizer"]) - for pg, pg_old in zip(optimizer.param_groups, p_groups): + for pg, pg_old in zip(optimizer.param_groups, p_groups, strict=False): pg["lr"] = pg_old["lr"] pg["initial_lr"] = pg_old["initial_lr"] print(optimizer.param_groups) @@ -405,7 +405,7 @@ def match_name_keywords(n, name_keywords): total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) - print("Training time {}".format(total_time_str)) + print(f"Training time {total_time_str}") if __name__ == "__main__": diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/backbone.py b/dimos/models/Detic/third_party/Deformable-DETR/models/backbone.py index 341dac2bde..c2b7a526d1 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/models/backbone.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/backbone.py @@ -11,13 +11,12 @@ Backbone modules. """ + import torch +from torch import nn import torch.nn.functional as F import torchvision -from torch import nn from torchvision.models._utils import IntermediateLayerGetter -from typing import Dict, List - from util.misc import NestedTensor, is_main_process from .position_encoding import build_position_encoding @@ -32,8 +31,8 @@ class FrozenBatchNorm2d(torch.nn.Module): produce nans. """ - def __init__(self, n, eps=1e-5): - super(FrozenBatchNorm2d, self).__init__() + def __init__(self, n, eps: float=1e-5) -> None: + super().__init__() self.register_buffer("weight", torch.ones(n)) self.register_buffer("bias", torch.zeros(n)) self.register_buffer("running_mean", torch.zeros(n)) @@ -41,13 +40,13 @@ def __init__(self, n, eps=1e-5): self.eps = eps def _load_from_state_dict( - self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs - ): + self, state_dict, prefix: str, local_metadata, strict: bool, missing_keys, unexpected_keys, error_msgs + ) -> None: num_batches_tracked_key = prefix + "num_batches_tracked" if num_batches_tracked_key in state_dict: del state_dict[num_batches_tracked_key] - super(FrozenBatchNorm2d, self)._load_from_state_dict( + super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) @@ -65,14 +64,14 @@ def forward(self, x): class BackboneBase(nn.Module): - def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool): + def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_layers: bool) -> None: super().__init__() for name, parameter in backbone.named_parameters(): if ( not train_backbone - or "layer2" not in name + or ("layer2" not in name and "layer3" not in name - and "layer4" not in name + and "layer4" not in name) ): parameter.requires_grad_(False) if return_interm_layers: @@ -88,7 +87,7 @@ def __init__(self, backbone: nn.Module, train_backbone: bool, return_interm_laye def forward(self, tensor_list: NestedTensor): xs = self.body(tensor_list.tensors) - out: Dict[str, NestedTensor] = {} + out: dict[str, NestedTensor] = {} for name, x in xs.items(): m = tensor_list.mask assert m is not None @@ -100,7 +99,7 @@ def forward(self, tensor_list: NestedTensor): class Backbone(BackboneBase): """ResNet backbone with frozen BatchNorm.""" - def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool): + def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, dilation: bool) -> None: norm_layer = FrozenBatchNorm2d backbone = getattr(torchvision.models, name)( replace_stride_with_dilation=[False, False, dilation], @@ -114,16 +113,16 @@ def __init__(self, name: str, train_backbone: bool, return_interm_layers: bool, class Joiner(nn.Sequential): - def __init__(self, backbone, position_embedding): + def __init__(self, backbone, position_embedding) -> None: super().__init__(backbone, position_embedding) self.strides = backbone.strides self.num_channels = backbone.num_channels def forward(self, tensor_list: NestedTensor): xs = self[0](tensor_list) - out: List[NestedTensor] = [] + out: list[NestedTensor] = [] pos = [] - for name, x in sorted(xs.items()): + for _name, x in sorted(xs.items()): out.append(x) # position encoding diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/deformable_detr.py b/dimos/models/Detic/third_party/Deformable-DETR/models/deformable_detr.py index cce6571795..79ec0020f0 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/models/deformable_detr.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/deformable_detr.py @@ -11,23 +11,26 @@ Deformable DETR model and criterion classes. """ -import torch -import torch.nn.functional as F -from torch import nn +from collections.abc import Sequence +import copy import math +import torch +from torch import nn +import torch.nn.functional as F from util import box_ops from util.misc import ( NestedTensor, - nested_tensor_from_tensor_list, accuracy, get_world_size, interpolate, - is_dist_avail_and_initialized, inverse_sigmoid, + is_dist_avail_and_initialized, + nested_tensor_from_tensor_list, ) from .backbone import build_backbone +from .deformable_transformer import build_deforamble_transformer from .matcher import build_matcher from .segmentation import ( DETRsegm, @@ -36,8 +39,6 @@ dice_loss, sigmoid_focal_loss, ) -from .deformable_transformer import build_deforamble_transformer -import copy def _get_clones(module, N): @@ -51,13 +52,13 @@ def __init__( self, backbone, transformer, - num_classes, - num_queries, - num_feature_levels, - aux_loss=True, - with_box_refine=False, - two_stage=False, - ): + num_classes: int, + num_queries: int, + num_feature_levels: int, + aux_loss: bool=True, + with_box_refine: bool=False, + two_stage: bool=False, + ) -> None: """Initializes the model. Parameters: backbone: torch module of the backbone to be used. See backbone.py @@ -226,7 +227,7 @@ def _set_aux_loss(self, outputs_class, outputs_coord): # as a dict having both a Tensor and a list. return [ {"pred_logits": a, "pred_boxes": b} - for a, b in zip(outputs_class[:-1], outputs_coord[:-1]) + for a, b in zip(outputs_class[:-1], outputs_coord[:-1], strict=False) ] @@ -237,7 +238,7 @@ class SetCriterion(nn.Module): 2) we supervise each pair of matched ground-truth / prediction (supervise class and box) """ - def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25): + def __init__(self, num_classes: int, matcher, weight_dict, losses, focal_alpha: float=0.25) -> None: """Create the criterion. Parameters: num_classes: number of object categories, omitting the special no-object category @@ -253,7 +254,7 @@ def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25): self.losses = losses self.focal_alpha = focal_alpha - def loss_labels(self, outputs, targets, indices, num_boxes, log=True): + def loss_labels(self, outputs, targets, indices, num_boxes: int, log: bool=True): """Classification loss (NLL) targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes] """ @@ -261,7 +262,7 @@ def loss_labels(self, outputs, targets, indices, num_boxes, log=True): src_logits = outputs["pred_logits"] idx = self._get_src_permutation_idx(indices) - target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]) + target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices, strict=False)]) target_classes = torch.full( src_logits.shape[:2], self.num_classes, dtype=torch.int64, device=src_logits.device ) @@ -290,7 +291,7 @@ def loss_labels(self, outputs, targets, indices, num_boxes, log=True): return losses @torch.no_grad() - def loss_cardinality(self, outputs, targets, indices, num_boxes): + def loss_cardinality(self, outputs, targets, indices, num_boxes: int): """Compute the cardinality error, ie the absolute error in the number of predicted non-empty boxes This is not really a loss, it is intended for logging purposes only. It doesn't propagate gradients """ @@ -303,7 +304,7 @@ def loss_cardinality(self, outputs, targets, indices, num_boxes): losses = {"cardinality_error": card_err} return losses - def loss_boxes(self, outputs, targets, indices, num_boxes): + def loss_boxes(self, outputs, targets, indices, num_boxes: int): """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] The target boxes are expected in format (center_x, center_y, h, w), normalized by the image size. @@ -311,7 +312,7 @@ def loss_boxes(self, outputs, targets, indices, num_boxes): assert "pred_boxes" in outputs idx = self._get_src_permutation_idx(indices) src_boxes = outputs["pred_boxes"][idx] - target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0) + target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices, strict=False)], dim=0) loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none") @@ -326,7 +327,7 @@ def loss_boxes(self, outputs, targets, indices, num_boxes): losses["loss_giou"] = loss_giou.sum() / num_boxes return losses - def loss_masks(self, outputs, targets, indices, num_boxes): + def loss_masks(self, outputs, targets, indices, num_boxes: int): """Compute the losses related to the masks: the focal loss and the dice loss. targets dicts must contain the key "masks" containing a tensor of dim [nb_target_boxes, h, w] """ @@ -338,7 +339,7 @@ def loss_masks(self, outputs, targets, indices, num_boxes): src_masks = outputs["pred_masks"] # TODO use valid to mask invalid areas due to padding in loss - target_masks, valid = nested_tensor_from_tensor_list( + target_masks, _valid = nested_tensor_from_tensor_list( [t["masks"] for t in targets] ).decompose() target_masks = target_masks.to(src_masks) @@ -370,7 +371,7 @@ def _get_tgt_permutation_idx(self, indices): tgt_idx = torch.cat([tgt for (_, tgt) in indices]) return batch_idx, tgt_idx - def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs): + def get_loss(self, loss, outputs, targets, indices, num_boxes: int, **kwargs): loss_map = { "labels": self.loss_labels, "cardinality": self.loss_cardinality, @@ -450,7 +451,7 @@ class PostProcess(nn.Module): """This module converts the model's output into the format expected by the coco api""" @torch.no_grad() - def forward(self, outputs, target_sizes): + def forward(self, outputs, target_sizes: Sequence[int]): """Perform the computation Parameters: outputs: raw outputs of the model @@ -476,7 +477,7 @@ def forward(self, outputs, target_sizes): scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) boxes = boxes * scale_fct[:, None, :] - results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes)] + results = [{"scores": s, "labels": l, "boxes": b} for s, l, b in zip(scores, labels, boxes, strict=False)] return results @@ -484,12 +485,12 @@ def forward(self, outputs, target_sizes): class MLP(nn.Module): """Very simple multi-layer perceptron (also called FFN)""" - def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + def __init__(self, input_dim, hidden_dim, output_dim, num_layers: int) -> None: super().__init__() self.num_layers = num_layers h = [hidden_dim] * (num_layers - 1) self.layers = nn.ModuleList( - nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + nn.Linear(n, k) for n, k in zip([input_dim, *h], [*h, output_dim], strict=False) ) def forward(self, x): diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/deformable_transformer.py b/dimos/models/Detic/third_party/Deformable-DETR/models/deformable_transformer.py index 6e75127833..f3cde19e1b 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/models/deformable_transformer.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/deformable_transformer.py @@ -11,31 +11,31 @@ import math import torch -import torch.nn.functional as F from torch import nn -from torch.nn.init import xavier_uniform_, constant_, normal_ - +import torch.nn.functional as F +from torch.nn.init import constant_, normal_, xavier_uniform_ from util.misc import inverse_sigmoid + from models.ops.modules import MSDeformAttn class DeformableTransformer(nn.Module): def __init__( self, - d_model=256, - nhead=8, - num_encoder_layers=6, - num_decoder_layers=6, - dim_feedforward=1024, - dropout=0.1, - activation="relu", - return_intermediate_dec=False, - num_feature_levels=4, - dec_n_points=4, - enc_n_points=4, - two_stage=False, - two_stage_num_proposals=300, - ): + d_model: int=256, + nhead: int=8, + num_encoder_layers: int=6, + num_decoder_layers: int=6, + dim_feedforward: int=1024, + dropout: float=0.1, + activation: str="relu", + return_intermediate_dec: bool=False, + num_feature_levels: int=4, + dec_n_points: int=4, + enc_n_points: int=4, + two_stage: bool=False, + two_stage_num_proposals: int=300, + ) -> None: super().__init__() self.d_model = d_model @@ -67,7 +67,7 @@ def __init__( self._reset_parameters() - def _reset_parameters(self): + def _reset_parameters(self) -> None: for p in self.parameters(): if p.dim() > 1: nn.init.xavier_uniform_(p) @@ -96,7 +96,6 @@ def get_proposal_pos_embed(self, proposals): def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes): N_, S_, C_ = memory.shape - base_scale = 4.0 proposals = [] _cur = 0 for lvl, (H_, W_) in enumerate(spatial_shapes): @@ -149,7 +148,7 @@ def forward(self, srcs, masks, pos_embeds, query_embed=None): mask_flatten = [] lvl_pos_embed_flatten = [] spatial_shapes = [] - for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds)): + for lvl, (src, mask, pos_embed) in enumerate(zip(srcs, masks, pos_embeds, strict=False)): bs, c, h, w = src.shape spatial_shape = (h, w) spatial_shapes.append(spatial_shape) @@ -240,14 +239,14 @@ def forward(self, srcs, masks, pos_embeds, query_embed=None): class DeformableTransformerEncoderLayer(nn.Module): def __init__( self, - d_model=256, - d_ffn=1024, - dropout=0.1, - activation="relu", - n_levels=4, - n_heads=8, - n_points=4, - ): + d_model: int=256, + d_ffn: int=1024, + dropout: float=0.1, + activation: str="relu", + n_levels: int=4, + n_heads: int=8, + n_points: int=4, + ) -> None: super().__init__() # self attention @@ -295,7 +294,7 @@ def forward( class DeformableTransformerEncoder(nn.Module): - def __init__(self, encoder_layer, num_layers): + def __init__(self, encoder_layer, num_layers: int) -> None: super().__init__() self.layers = _get_clones(encoder_layer, num_layers) self.num_layers = num_layers @@ -334,14 +333,14 @@ def forward( class DeformableTransformerDecoderLayer(nn.Module): def __init__( self, - d_model=256, - d_ffn=1024, - dropout=0.1, - activation="relu", - n_levels=4, - n_heads=8, - n_points=4, - ): + d_model: int=256, + d_ffn: int=1024, + dropout: float=0.1, + activation: str="relu", + n_levels: int=4, + n_heads: int=8, + n_points: int=4, + ) -> None: super().__init__() # cross attention @@ -409,7 +408,7 @@ def forward( class DeformableTransformerDecoder(nn.Module): - def __init__(self, decoder_layer, num_layers, return_intermediate=False): + def __init__(self, decoder_layer, num_layers: int, return_intermediate: bool=False) -> None: super().__init__() self.layers = _get_clones(decoder_layer, num_layers) self.num_layers = num_layers diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/matcher.py b/dimos/models/Detic/third_party/Deformable-DETR/models/matcher.py index 29838972ab..7cbcf4a82e 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/models/matcher.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/matcher.py @@ -11,10 +11,9 @@ Modules to compute the matching cost and solve the corresponding LSAP. """ -import torch from scipy.optimize import linear_sum_assignment +import torch from torch import nn - from util.box_ops import box_cxcywh_to_xyxy, generalized_box_iou @@ -26,7 +25,7 @@ class HungarianMatcher(nn.Module): while the others are un-matched (and thus treated as non-objects). """ - def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1): + def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1) -> None: """Creates the matcher Params: diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/functions/ms_deform_attn_func.py b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/functions/ms_deform_attn_func.py index c18582590e..965811ed7f 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/functions/ms_deform_attn_func.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/functions/ms_deform_attn_func.py @@ -6,16 +6,12 @@ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ -from __future__ import absolute_import -from __future__ import print_function -from __future__ import division +import MultiScaleDeformableAttention as MSDA import torch -import torch.nn.functional as F from torch.autograd import Function from torch.autograd.function import once_differentiable - -import MultiScaleDeformableAttention as MSDA +import torch.nn.functional as F class MSDeformAttnFunction(Function): diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/modules/ms_deform_attn.py b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/modules/ms_deform_attn.py index bc02668b96..1d70af7cc4 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/modules/ms_deform_attn.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/modules/ms_deform_attn.py @@ -6,29 +6,26 @@ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ -from __future__ import absolute_import -from __future__ import print_function -from __future__ import division -import warnings import math +import warnings import torch from torch import nn import torch.nn.functional as F -from torch.nn.init import xavier_uniform_, constant_ +from torch.nn.init import constant_, xavier_uniform_ from ..functions import MSDeformAttnFunction def _is_power_of_2(n): if (not isinstance(n, int)) or (n < 0): - raise ValueError("invalid input for _is_power_of_2: {} (type: {})".format(n, type(n))) + raise ValueError(f"invalid input for _is_power_of_2: {n} (type: {type(n)})") return (n & (n - 1) == 0) and n != 0 class MSDeformAttn(nn.Module): - def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): + def __init__(self, d_model: int=256, n_levels: int=4, n_heads: int=8, n_points: int=4) -> None: """ Multi-Scale Deformable Attention Module :param d_model hidden dimension @@ -39,14 +36,14 @@ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): super().__init__() if d_model % n_heads != 0: raise ValueError( - "d_model must be divisible by n_heads, but got {} and {}".format(d_model, n_heads) + f"d_model must be divisible by n_heads, but got {d_model} and {n_heads}" ) _d_per_head = d_model // n_heads # you'd better set _d_per_head to a power of 2 which is more efficient in our CUDA implementation if not _is_power_of_2(_d_per_head): warnings.warn( "You'd better set d_model in MSDeformAttn to make the dimension of each attention head a power of 2 " - "which is more efficient in our CUDA implementation." + "which is more efficient in our CUDA implementation.", stacklevel=2 ) self.im2col_step = 64 @@ -63,7 +60,7 @@ def __init__(self, d_model=256, n_levels=4, n_heads=8, n_points=4): self._reset_parameters() - def _reset_parameters(self): + def _reset_parameters(self) -> None: constant_(self.sampling_offsets.weight.data, 0.0) thetas = torch.arange(self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) @@ -92,7 +89,7 @@ def forward( input_level_start_index, input_padding_mask=None, ): - """ + r""" :param query (N, Length_{query}, C) :param reference_points (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes @@ -136,9 +133,7 @@ def forward( ) else: raise ValueError( - "Last dim of reference_points must be 2 or 4, but get {} instead.".format( - reference_points.shape[-1] - ) + f"Last dim of reference_points must be 2 or 4, but get {reference_points.shape[-1]} instead." ) output = MSDeformAttnFunction.apply( value, diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/setup.py b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/setup.py index 7cf252f0cf..7a5560a83f 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/setup.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/setup.py @@ -6,17 +6,12 @@ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ -import os import glob +import os +from setuptools import find_packages, setup import torch - -from torch.utils.cpp_extension import CUDA_HOME -from torch.utils.cpp_extension import CppExtension -from torch.utils.cpp_extension import CUDAExtension - -from setuptools import find_packages -from setuptools import setup +from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension requirements = ["torch", "torchvision"] diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/test.py b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/test.py index 3fa3c7da6d..720d6473b2 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/models/ops/test.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/ops/test.py @@ -6,16 +6,11 @@ # Modified from https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/tree/pytorch_1.0.0 # ------------------------------------------------------------------------------------------------ -from __future__ import absolute_import -from __future__ import print_function -from __future__ import division +from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch import torch from torch.autograd import gradcheck -from functions.ms_deform_attn_func import MSDeformAttnFunction, ms_deform_attn_core_pytorch - - N, M, D = 1, 2, 2 Lq, L, P = 2, 2, 2 shapes = torch.as_tensor([(6, 4), (3, 2)], dtype=torch.long).cuda() @@ -27,7 +22,7 @@ @torch.no_grad() -def check_forward_equal_with_pytorch_double(): +def check_forward_equal_with_pytorch_double() -> None: value = torch.rand(N, S, M, D).cuda() * 0.01 sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 @@ -62,7 +57,7 @@ def check_forward_equal_with_pytorch_double(): @torch.no_grad() -def check_forward_equal_with_pytorch_float(): +def check_forward_equal_with_pytorch_float() -> None: value = torch.rand(N, S, M, D).cuda() * 0.01 sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 @@ -90,8 +85,8 @@ def check_forward_equal_with_pytorch_float(): def check_gradient_numerical( - channels=4, grad_value=True, grad_sampling_loc=True, grad_attn_weight=True -): + channels: int=4, grad_value: bool=True, grad_sampling_loc: bool=True, grad_attn_weight: bool=True +) -> None: value = torch.rand(N, S, M, channels).cuda() * 0.01 sampling_locations = torch.rand(N, Lq, M, L, P, 2).cuda() attention_weights = torch.rand(N, Lq, M, L, P).cuda() + 1e-5 diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/position_encoding.py b/dimos/models/Detic/third_party/Deformable-DETR/models/position_encoding.py index c0ab1b34c3..2ce5038e5e 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/models/position_encoding.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/position_encoding.py @@ -12,9 +12,9 @@ """ import math + import torch from torch import nn - from util.misc import NestedTensor @@ -24,7 +24,7 @@ class PositionEmbeddingSine(nn.Module): used by the Attention is all you need paper, generalized to work on images. """ - def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + def __init__(self, num_pos_feats: int=64, temperature: int=10000, normalize: bool=False, scale=None) -> None: super().__init__() self.num_pos_feats = num_pos_feats self.temperature = temperature @@ -67,13 +67,13 @@ class PositionEmbeddingLearned(nn.Module): Absolute pos embedding, learned. """ - def __init__(self, num_pos_feats=256): + def __init__(self, num_pos_feats: int=256) -> None: super().__init__() self.row_embed = nn.Embedding(50, num_pos_feats) self.col_embed = nn.Embedding(50, num_pos_feats) self.reset_parameters() - def reset_parameters(self): + def reset_parameters(self) -> None: nn.init.uniform_(self.row_embed.weight) nn.init.uniform_(self.col_embed.weight) diff --git a/dimos/models/Detic/third_party/Deformable-DETR/models/segmentation.py b/dimos/models/Detic/third_party/Deformable-DETR/models/segmentation.py index edb3f0a3c4..af68f7c1c7 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/models/segmentation.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/models/segmentation.py @@ -11,14 +11,14 @@ This file provides the definition of the convolutional heads used to predict masks, as well as the losses """ -import io from collections import defaultdict +from collections.abc import Sequence +import io +from PIL import Image import torch import torch.nn as nn import torch.nn.functional as F -from PIL import Image - import util.box_ops as box_ops from util.misc import NestedTensor, interpolate, nested_tensor_from_tensor_list @@ -29,7 +29,7 @@ class DETRsegm(nn.Module): - def __init__(self, detr, freeze_detr=False): + def __init__(self, detr, freeze_detr: bool=False) -> None: super().__init__() self.detr = detr @@ -58,7 +58,7 @@ def forward(self, samples: NestedTensor): if self.detr.aux_loss: out["aux_outputs"] = [ {"pred_logits": a, "pred_boxes": b} - for a, b in zip(outputs_class[:-1], outputs_coord[:-1]) + for a, b in zip(outputs_class[:-1], outputs_coord[:-1], strict=False) ] # FIXME h_boxes takes the last one computed, keep this in mind @@ -81,7 +81,7 @@ class MaskHeadSmallConv(nn.Module): Upsampling is done using a FPN approach """ - def __init__(self, dim, fpn_dims, context_dim): + def __init__(self, dim: int, fpn_dims, context_dim) -> None: super().__init__() inter_dims = [ @@ -116,7 +116,7 @@ def __init__(self, dim, fpn_dims, context_dim): nn.init.constant_(m.bias, 0) def forward(self, x, bbox_mask, fpns): - def expand(tensor, length): + def expand(tensor, length: int): return tensor.unsqueeze(1).repeat(1, int(length), 1, 1, 1).flatten(0, 1) x = torch.cat([expand(x, bbox_mask.shape[1]), bbox_mask.flatten(0, 1)], 1) @@ -159,7 +159,7 @@ def expand(tensor, length): class MHAttentionMap(nn.Module): """This is a 2D attention module, which only returns the attention softmax (no multiplication by value)""" - def __init__(self, query_dim, hidden_dim, num_heads, dropout=0, bias=True): + def __init__(self, query_dim, hidden_dim, num_heads: int, dropout: int=0, bias: bool=True) -> None: super().__init__() self.num_heads = num_heads self.hidden_dim = hidden_dim @@ -190,7 +190,7 @@ def forward(self, q, k, mask=None): return weights -def dice_loss(inputs, targets, num_boxes): +def dice_loss(inputs, targets, num_boxes: int): """ Compute the DICE loss, similar to generalized IOU for masks Args: @@ -208,7 +208,7 @@ def dice_loss(inputs, targets, num_boxes): return loss.sum() / num_boxes -def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): +def sigmoid_focal_loss(inputs, targets, num_boxes: int, alpha: float = 0.25, gamma: float = 2): """ Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. Args: @@ -237,12 +237,12 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f class PostProcessSegm(nn.Module): - def __init__(self, threshold=0.5): + def __init__(self, threshold: float=0.5) -> None: super().__init__() self.threshold = threshold @torch.no_grad() - def forward(self, results, outputs, orig_target_sizes, max_target_sizes): + def forward(self, results, outputs, orig_target_sizes: Sequence[int], max_target_sizes: Sequence[int]): assert len(orig_target_sizes) == len(max_target_sizes) max_h, max_w = max_target_sizes.max(0)[0].tolist() outputs_masks = outputs["pred_masks"].squeeze(2) @@ -252,7 +252,7 @@ def forward(self, results, outputs, orig_target_sizes, max_target_sizes): outputs_masks = (outputs_masks.sigmoid() > self.threshold).cpu() for i, (cur_mask, t, tt) in enumerate( - zip(outputs_masks, max_target_sizes, orig_target_sizes) + zip(outputs_masks, max_target_sizes, orig_target_sizes, strict=False) ): img_h, img_w = t[0], t[1] results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1) @@ -267,7 +267,7 @@ class PostProcessPanoptic(nn.Module): """This class converts the output of the model to the final panoptic result, in the format expected by the coco panoptic API""" - def __init__(self, is_thing_map, threshold=0.85): + def __init__(self, is_thing_map: bool, threshold: float=0.85) -> None: """ Parameters: is_thing_map: This is a whose keys are the class ids, and the values a boolean indicating whether @@ -278,7 +278,7 @@ def __init__(self, is_thing_map, threshold=0.85): self.threshold = threshold self.is_thing_map = is_thing_map - def forward(self, outputs, processed_sizes, target_sizes=None): + def forward(self, outputs, processed_sizes: Sequence[int], target_sizes: Sequence[int] | None=None): """This function computes the panoptic prediction from the model's predictions. Parameters: outputs: This is a dict coming directly from the model. See the model doc for the content. @@ -304,7 +304,7 @@ def to_tuple(tup): return tuple(tup.cpu().tolist()) for cur_logits, cur_masks, cur_boxes, size, target_size in zip( - out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes + out_logits, raw_masks, raw_boxes, processed_sizes, target_sizes, strict=False ): # we filter empty queries and detection below threshold scores, labels = cur_logits.softmax(-1).max(-1) @@ -327,7 +327,7 @@ def to_tuple(tup): if not self.is_thing_map[label.item()]: stuff_equiv_classes[label.item()].append(k) - def get_ids_area(masks, scores, dedup=False): + def get_ids_area(masks, scores, dedup: bool=False): # This helper function creates the final panoptic segmentation image # It also returns the area of the masks that appears on the image diff --git a/dimos/models/Detic/third_party/Deformable-DETR/tools/launch.py b/dimos/models/Detic/third_party/Deformable-DETR/tools/launch.py index 9e9fdfea2c..1d60ae4994 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/tools/launch.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/tools/launch.py @@ -103,9 +103,9 @@ how things can go wrong if you don't do this correctly. """ -import subprocess +from argparse import REMAINDER, ArgumentParser import os -from argparse import ArgumentParser, REMAINDER +import subprocess def parse_args(): @@ -189,7 +189,7 @@ def main(): current_env["RANK"] = str(dist_rank) current_env["LOCAL_RANK"] = str(local_rank) - cmd = [args.training_script] + args.training_script_args + cmd = [args.training_script, *args.training_script_args] process = subprocess.Popen(cmd, env=current_env) processes.append(process) diff --git a/dimos/models/Detic/third_party/Deformable-DETR/util/misc.py b/dimos/models/Detic/third_party/Deformable-DETR/util/misc.py index 661807da15..2acf190530 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/util/misc.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/util/misc.py @@ -13,26 +13,26 @@ Mostly copy-paste from torchvision references. """ -import os -import subprocess -import time from collections import defaultdict, deque import datetime +import os import pickle -from typing import Optional, List +import subprocess +import time import torch -import torch.distributed as dist from torch import Tensor +import torch.distributed as dist # needed due to empty tensor bug in pytorch and torchvision 0.5 import torchvision if float(torchvision.__version__[:3]) < 0.5: import math + from torchvision.ops.misc import _NewEmptyTensorOp - def _check_size_scale_factor(dim, size, scale_factor): + def _check_size_scale_factor(dim: int, size: int, scale_factor): # type: (int, Optional[List[int]], Optional[float]) -> None if size is None and scale_factor is None: raise ValueError("either size or scale_factor should be defined") @@ -40,12 +40,10 @@ def _check_size_scale_factor(dim, size, scale_factor): raise ValueError("only one of size or scale_factor should be defined") if not (scale_factor is not None and len(scale_factor) != dim): raise ValueError( - "scale_factor shape must match input shape. Input is {}D, scale_factor size is {}".format( - dim, len(scale_factor) - ) + f"scale_factor shape must match input shape. Input is {dim}D, scale_factor size is {len(scale_factor)}" ) - def _output_size(dim, input, size, scale_factor): + def _output_size(dim: int, input, size: int, scale_factor): # type: (int, Tensor, Optional[List[int]], Optional[float]) -> List[int] assert dim == 2 _check_size_scale_factor(dim, size, scale_factor) @@ -55,18 +53,18 @@ def _output_size(dim, input, size, scale_factor): assert scale_factor is not None and isinstance(scale_factor, (int, float)) scale_factors = [scale_factor, scale_factor] # math.floor might return float in py2.7 - return [int(math.floor(input.size(i + 2) * scale_factors[i])) for i in range(dim)] + return [math.floor(input.size(i + 2) * scale_factors[i]) for i in range(dim)] elif float(torchvision.__version__[:3]) < 0.7: from torchvision.ops import _new_empty_tensor from torchvision.ops.misc import _output_size -class SmoothedValue(object): +class SmoothedValue: """Track a series of values and provide access to smoothed values over a window or the global series average. """ - def __init__(self, window_size=20, fmt=None): + def __init__(self, window_size: int=20, fmt=None) -> None: if fmt is None: fmt = "{median:.4f} ({global_avg:.4f})" self.deque = deque(maxlen=window_size) @@ -74,12 +72,12 @@ def __init__(self, window_size=20, fmt=None): self.count = 0 self.fmt = fmt - def update(self, value, n=1): + def update(self, value, n: int=1) -> None: self.deque.append(value) self.count += n self.total += value * n - def synchronize_between_processes(self): + def synchronize_between_processes(self) -> None: """ Warning: does not synchronize the deque! """ @@ -114,7 +112,7 @@ def max(self): def value(self): return self.deque[-1] - def __str__(self): + def __str__(self) -> str: return self.fmt.format( median=self.median, avg=self.avg, @@ -160,14 +158,14 @@ def all_gather(data): dist.all_gather(tensor_list, tensor) data_list = [] - for size, tensor in zip(size_list, tensor_list): + for size, tensor in zip(size_list, tensor_list, strict=False): buffer = tensor.cpu().numpy().tobytes()[:size] data_list.append(pickle.loads(buffer)) return data_list -def reduce_dict(input_dict, average=True): +def reduce_dict(input_dict, average: bool=True): """ Args: input_dict (dict): all the values will be reduced @@ -190,16 +188,16 @@ def reduce_dict(input_dict, average=True): dist.all_reduce(values) if average: values /= world_size - reduced_dict = {k: v for k, v in zip(names, values)} + reduced_dict = {k: v for k, v in zip(names, values, strict=False)} return reduced_dict -class MetricLogger(object): - def __init__(self, delimiter="\t"): +class MetricLogger: + def __init__(self, delimiter: str="\t") -> None: self.meters = defaultdict(SmoothedValue) self.delimiter = delimiter - def update(self, **kwargs): + def update(self, **kwargs) -> None: for k, v in kwargs.items(): if isinstance(v, torch.Tensor): v = v.item() @@ -211,19 +209,19 @@ def __getattr__(self, attr): return self.meters[attr] if attr in self.__dict__: return self.__dict__[attr] - raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") - def __str__(self): + def __str__(self) -> str: loss_str = [] for name, meter in self.meters.items(): - loss_str.append("{}: {}".format(name, str(meter))) + loss_str.append(f"{name}: {meter!s}") return self.delimiter.join(loss_str) - def synchronize_between_processes(self): + def synchronize_between_processes(self) -> None: for meter in self.meters.values(): meter.synchronize_between_processes() - def add_meter(self, name, meter): + def add_meter(self, name: str, meter) -> None: self.meters[name] = meter def log_every(self, iterable, print_freq, header=None): @@ -294,9 +292,7 @@ def log_every(self, iterable, print_freq, header=None): total_time = time.time() - start_time total_time_str = str(datetime.timedelta(seconds=int(total_time))) print( - "{} Total time: {} ({:.4f} s / it)".format( - header, total_time_str, total_time / len(iterable) - ) + f"{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)" ) @@ -322,7 +318,7 @@ def _run(command): def collate_fn(batch): - batch = list(zip(*batch)) + batch = list(zip(*batch, strict=False)) batch[0] = nested_tensor_from_tensor_list(batch[0]) return tuple(batch) @@ -336,19 +332,19 @@ def _max_by_axis(the_list): return maxes -def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): +def nested_tensor_from_tensor_list(tensor_list: list[Tensor]): # TODO make this more general if tensor_list[0].ndim == 3: # TODO make it support different-sized images max_size = _max_by_axis([list(img.shape) for img in tensor_list]) # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) - batch_shape = [len(tensor_list)] + max_size - b, c, h, w = batch_shape + batch_shape = [len(tensor_list), *max_size] + b, _c, h, w = batch_shape dtype = tensor_list[0].dtype device = tensor_list[0].device tensor = torch.zeros(batch_shape, dtype=dtype, device=device) mask = torch.ones((b, h, w), dtype=torch.bool, device=device) - for img, pad_img, m in zip(tensor_list, tensor, mask): + for img, pad_img, m in zip(tensor_list, tensor, mask, strict=False): pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) m[: img.shape[1], : img.shape[2]] = False else: @@ -356,13 +352,13 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): return NestedTensor(tensor, mask) -class NestedTensor(object): - def __init__(self, tensors, mask: Optional[Tensor]): +class NestedTensor: + def __init__(self, tensors, mask: Tensor | None) -> None: self.tensors = tensors self.mask = mask - def to(self, device, non_blocking=False): - # type: (Device) -> NestedTensor # noqa + def to(self, device, non_blocking: bool=False): + # type: (Device) -> NestedTensor cast_tensor = self.tensors.to(device, non_blocking=non_blocking) mask = self.mask if mask is not None: @@ -372,7 +368,7 @@ def to(self, device, non_blocking=False): cast_mask = None return NestedTensor(cast_tensor, cast_mask) - def record_stream(self, *args, **kwargs): + def record_stream(self, *args, **kwargs) -> None: self.tensors.record_stream(*args, **kwargs) if self.mask is not None: self.mask.record_stream(*args, **kwargs) @@ -380,11 +376,11 @@ def record_stream(self, *args, **kwargs): def decompose(self): return self.tensors, self.mask - def __repr__(self): + def __repr__(self) -> str: return str(self.tensors) -def setup_for_distributed(is_master): +def setup_for_distributed(is_master: bool) -> None: """ This function disables printing when not in master process """ @@ -392,7 +388,7 @@ def setup_for_distributed(is_master): builtin_print = __builtin__.print - def print(*args, **kwargs): + def print(*args, **kwargs) -> None: force = kwargs.pop("force", False) if is_master or force: builtin_print(*args, **kwargs) @@ -400,7 +396,7 @@ def print(*args, **kwargs): __builtin__.print = print -def is_dist_avail_and_initialized(): +def is_dist_avail_and_initialized() -> bool: if not dist.is_available(): return False if not dist.is_initialized(): @@ -436,12 +432,12 @@ def is_main_process(): return get_rank() == 0 -def save_on_master(*args, **kwargs): +def save_on_master(*args, **kwargs) -> None: if is_main_process(): torch.save(*args, **kwargs) -def init_distributed_mode(args): +def init_distributed_mode(args) -> None: if "RANK" in os.environ and "WORLD_SIZE" in os.environ: args.rank = int(os.environ["RANK"]) args.world_size = int(os.environ["WORLD_SIZE"]) @@ -453,7 +449,7 @@ def init_distributed_mode(args): ntasks = int(os.environ["SLURM_NTASKS"]) node_list = os.environ["SLURM_NODELIST"] num_gpus = torch.cuda.device_count() - addr = subprocess.getoutput("scontrol show hostname {} | head -n1".format(node_list)) + addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500") os.environ["MASTER_ADDR"] = addr os.environ["WORLD_SIZE"] = str(ntasks) @@ -473,7 +469,7 @@ def init_distributed_mode(args): torch.cuda.set_device(args.gpu) args.dist_backend = "nccl" - print("| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True) + print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True) torch.distributed.init_process_group( backend=args.dist_backend, init_method=args.dist_url, @@ -503,7 +499,7 @@ def accuracy(output, target, topk=(1,)): return res -def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): +def interpolate(input, size: int | None=None, scale_factor=None, mode: str="nearest", align_corners=None): # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor """ Equivalent to nn.functional.interpolate, but with support for empty batch sizes. @@ -523,7 +519,7 @@ class can go away. return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners) -def get_total_grad_norm(parameters, norm_type=2): +def get_total_grad_norm(parameters, norm_type: int=2): parameters = list(filter(lambda p: p.grad is not None, parameters)) norm_type = float(norm_type) device = parameters[0].grad.device @@ -534,7 +530,7 @@ def get_total_grad_norm(parameters, norm_type=2): return total_norm -def inverse_sigmoid(x, eps=1e-5): +def inverse_sigmoid(x, eps: float=1e-5): x = x.clamp(min=0, max=1) x1 = x.clamp(min=eps) x2 = (1 - x).clamp(min=eps) diff --git a/dimos/models/Detic/third_party/Deformable-DETR/util/plot_utils.py b/dimos/models/Detic/third_party/Deformable-DETR/util/plot_utils.py index 3bbb97b3d1..710420f410 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/util/plot_utils.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/util/plot_utils.py @@ -11,16 +11,16 @@ Plotting utilities to visualize training logs. """ -import torch +from pathlib import Path, PurePath + +import matplotlib.pyplot as plt import pandas as pd import seaborn as sns -import matplotlib.pyplot as plt - -from pathlib import Path, PurePath +import torch def plot_logs( - logs, fields=("class_error", "loss_bbox_unscaled", "mAP"), ewm_col=0, log_name="log.txt" + logs, fields=("class_error", "loss_bbox_unscaled", "mAP"), ewm_col: int=0, log_name: str="log.txt" ): """ Function to plot specific fields from training log(s). Plots both training and test results. @@ -50,7 +50,7 @@ def plot_logs( ) # verify valid dir(s) and that every item in list is Path object - for i, dir in enumerate(logs): + for _i, dir in enumerate(logs): if not isinstance(dir, PurePath): raise ValueError( f"{func_name} - non-Path object in logs argument of {type(dir)}: \n{dir}" @@ -62,9 +62,9 @@ def plot_logs( # load log file(s) and plot dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] - fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) + _fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) - for df, color in zip(dfs, sns.color_palette(n_colors=len(logs))): + for df, color in zip(dfs, sns.color_palette(n_colors=len(logs)), strict=False): for j, field in enumerate(fields): if field == "mAP": coco_eval = ( @@ -80,12 +80,12 @@ def plot_logs( color=[color] * 2, style=["-", "--"], ) - for ax, field in zip(axs, fields): + for ax, field in zip(axs, fields, strict=False): ax.legend([Path(p).name for p in logs]) ax.set_title(field) -def plot_precision_recall(files, naming_scheme="iter"): +def plot_precision_recall(files, naming_scheme: str="iter"): if naming_scheme == "exp_id": # name becomes exp_id names = [f.parts[-3] for f in files] @@ -94,7 +94,7 @@ def plot_precision_recall(files, naming_scheme="iter"): else: raise ValueError(f"not supported {naming_scheme}") fig, axs = plt.subplots(ncols=2, figsize=(16, 5)) - for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names): + for f, color, name in zip(files, sns.color_palette("Blues", n_colors=len(files)), names, strict=False): data = torch.load(f) # precision is n_iou, n_points, n_cat, n_area, max_det precision = data["precision"] diff --git a/dimos/models/Detic/tools/convert-thirdparty-pretrained-model-to-d2.py b/dimos/models/Detic/tools/convert-thirdparty-pretrained-model-to-d2.py index 6b24b5b260..567e71f7c4 100644 --- a/dimos/models/Detic/tools/convert-thirdparty-pretrained-model-to-d2.py +++ b/dimos/models/Detic/tools/convert-thirdparty-pretrained-model-to-d2.py @@ -2,6 +2,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved import argparse import pickle + import torch """ diff --git a/dimos/models/Detic/tools/create_imagenetlvis_json.py b/dimos/models/Detic/tools/create_imagenetlvis_json.py index 54883d7337..4f53874421 100644 --- a/dimos/models/Detic/tools/create_imagenetlvis_json.py +++ b/dimos/models/Detic/tools/create_imagenetlvis_json.py @@ -2,8 +2,9 @@ import argparse import json import os -from nltk.corpus import wordnet + from detectron2.data.detection_utils import read_image +from nltk.corpus import wordnet if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -15,7 +16,7 @@ args = parser.parse_args() print("Loading LVIS meta") - data = json.load(open(args.lvis_meta_path, "r")) + data = json.load(open(args.lvis_meta_path)) print("Done") synset2cat = {x["synset"]: x for x in data["categories"]} count = 0 @@ -32,9 +33,9 @@ cat_images = [] for file in files: count = count + 1 - file_name = "{}/{}".format(folder, file) + file_name = f"{folder}/{file}" # img = cv2.imread('{}/{}'.format(args.imagenet_path, file_name)) - img = read_image("{}/{}".format(args.imagenet_path, file_name)) + img = read_image(f"{args.imagenet_path}/{file_name}") h, w = img.shape[:2] image = { "id": count, diff --git a/dimos/models/Detic/tools/create_lvis_21k.py b/dimos/models/Detic/tools/create_lvis_21k.py index 05e9530181..a1f24446ac 100644 --- a/dimos/models/Detic/tools/create_lvis_21k.py +++ b/dimos/models/Detic/tools/create_lvis_21k.py @@ -16,9 +16,9 @@ args = parser.parse_args() print("Loading", args.imagenet_path) - in_data = json.load(open(args.imagenet_path, "r")) + in_data = json.load(open(args.imagenet_path)) print("Loading", args.lvis_path) - lvis_data = json.load(open(args.lvis_path, "r")) + lvis_data = json.load(open(args.lvis_path)) categories = copy.deepcopy(lvis_data["categories"]) cat_count = max(x["id"] for x in categories) @@ -53,14 +53,14 @@ lvis_data["categories"] = categories if not args.not_save_imagenet: - in_out_path = args.imagenet_path[:-5] + "_{}.json".format(args.mark) + in_out_path = args.imagenet_path[:-5] + f"_{args.mark}.json" for k, v in in_data.items(): print("imagenet", k, len(v)) print("Saving Imagenet to", in_out_path) json.dump(in_data, open(in_out_path, "w")) if not args.not_save_lvis: - lvis_out_path = args.lvis_path[:-5] + "_{}.json".format(args.mark) + lvis_out_path = args.lvis_path[:-5] + f"_{args.mark}.json" for k, v in lvis_data.items(): print("lvis", k, len(v)) print("Saving LVIS to", lvis_out_path) @@ -72,4 +72,4 @@ if k in x: del x[k] CATEGORIES = repr(categories) + " # noqa" - open(args.save_categories, "wt").write(f"CATEGORIES = {CATEGORIES}") + open(args.save_categories, "w").write(f"CATEGORIES = {CATEGORIES}") diff --git a/dimos/models/Detic/tools/download_cc.py b/dimos/models/Detic/tools/download_cc.py index fb493c8edc..ef7b4b0f7d 100644 --- a/dimos/models/Detic/tools/download_cc.py +++ b/dimos/models/Detic/tools/download_cc.py @@ -1,9 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. -import os -import json import argparse -from PIL import Image +import json +import os + import numpy as np +from PIL import Image if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -13,7 +14,7 @@ parser.add_argument("--out_path", default="datasets/cc3m/train_image_info.json") parser.add_argument("--not_download_image", action="store_true") args = parser.parse_args() - categories = json.load(open(args.cat_info, "r"))["categories"] + categories = json.load(open(args.cat_info))["categories"] images = [] if not os.path.exists(args.save_image_path): os.makedirs(args.save_image_path) @@ -22,16 +23,16 @@ cap, path = line[:-1].split("\t") print(i, cap, path) if not args.not_download_image: - os.system("wget {} -O {}/{}.jpg".format(path, args.save_image_path, i + 1)) + os.system(f"wget {path} -O {args.save_image_path}/{i + 1}.jpg") try: - img = Image.open(open("{}/{}.jpg".format(args.save_image_path, i + 1), "rb")) + img = Image.open(open(f"{args.save_image_path}/{i + 1}.jpg", "rb")) img = np.asarray(img.convert("RGB")) h, w = img.shape[:2] except: continue image_info = { "id": i + 1, - "file_name": "{}.jpg".format(i + 1), + "file_name": f"{i + 1}.jpg", "height": h, "width": w, "captions": [cap], diff --git a/dimos/models/Detic/tools/dump_clip_features.py b/dimos/models/Detic/tools/dump_clip_features.py index 941fe221ed..31be161f6d 100644 --- a/dimos/models/Detic/tools/dump_clip_features.py +++ b/dimos/models/Detic/tools/dump_clip_features.py @@ -1,10 +1,11 @@ # Copyright (c) Facebook, Inc. and its affiliates. import argparse -import json -import torch -import numpy as np import itertools +import json + from nltk.corpus import wordnet +import numpy as np +import torch if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -20,7 +21,7 @@ args = parser.parse_args() print("Loading", args.ann) - data = json.load(open(args.ann, "r")) + data = json.load(open(args.ann)) cat_names = [x["name"] for x in sorted(data["categories"], key=lambda x: x["id"])] if "synonyms" in data["categories"][0]: if args.use_wn_name: @@ -48,12 +49,12 @@ sentences = [x for x in cat_names] sentences_synonyms = [[xx for xx in x] for x in synonyms] elif args.prompt == "photo": - sentences = ["a photo of a {}".format(x) for x in cat_names] - sentences_synonyms = [["a photo of a {}".format(xx) for xx in x] for x in synonyms] + sentences = [f"a photo of a {x}" for x in cat_names] + sentences_synonyms = [[f"a photo of a {xx}" for xx in x] for x in synonyms] elif args.prompt == "scene": - sentences = ["a photo of a {} in the scene".format(x) for x in cat_names] + sentences = [f"a photo of a {x} in the scene" for x in cat_names] sentences_synonyms = [ - ["a photo of a {} in the scene".format(xx) for xx in x] for x in synonyms + [f"a photo of a {xx} in the scene" for xx in x] for x in synonyms ] print("sentences_synonyms", len(sentences_synonyms), sum(len(x) for x in sentences_synonyms)) @@ -86,7 +87,7 @@ print("after stack", text_features.shape) text_features = text_features.cpu().numpy() elif args.model in ["bert", "roberta"]: - from transformers import AutoTokenizer, AutoModel + from transformers import AutoModel, AutoTokenizer if args.model == "bert": model_name = "bert-large-uncased" diff --git a/dimos/models/Detic/tools/fix_o365_names.py b/dimos/models/Detic/tools/fix_o365_names.py index 7b2ffad365..5aee27a14f 100644 --- a/dimos/models/Detic/tools/fix_o365_names.py +++ b/dimos/models/Detic/tools/fix_o365_names.py @@ -1,7 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. import argparse -import json import copy +import json if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -11,12 +11,12 @@ new_names = {} old_names = {} - with open(args.fix_name_map, "r") as f: + with open(args.fix_name_map) as f: for line in f: tmp = line.strip().split(",") old_names[int(tmp[0])] = tmp[1] new_names[int(tmp[0])] = tmp[2] - data = json.load(open(args.ann, "r")) + data = json.load(open(args.ann)) cat_info = copy.deepcopy(data["categories"]) diff --git a/dimos/models/Detic/tools/fix_o365_path.py b/dimos/models/Detic/tools/fix_o365_path.py index 8e0b476323..c43358fff0 100644 --- a/dimos/models/Detic/tools/fix_o365_path.py +++ b/dimos/models/Detic/tools/fix_o365_path.py @@ -1,9 +1,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. import argparse import json -import path import os +import path + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( @@ -13,7 +14,7 @@ args = parser.parse_args() print("Loading", args.ann) - data = json.load(open(args.ann, "r")) + data = json.load(open(args.ann)) images = [] count = 0 for x in data["images"]: diff --git a/dimos/models/Detic/tools/get_cc_tags.py b/dimos/models/Detic/tools/get_cc_tags.py index 52aa05445c..0a5cdab8ec 100644 --- a/dimos/models/Detic/tools/get_cc_tags.py +++ b/dimos/models/Detic/tools/get_cc_tags.py @@ -1,7 +1,8 @@ # Copyright (c) Facebook, Inc. and its affiliates. import argparse -import json from collections import defaultdict +import json + from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES # This mapping is extracted from the official LVIS mapping: @@ -110,7 +111,7 @@ def map_name(x): args = parser.parse_args() # lvis_data = json.load(open(args.lvis_ann, 'r')) - cc_data = json.load(open(args.cc_ann, "r")) + cc_data = json.load(open(args.cc_ann)) if args.convert_caption: num_caps = 0 caps = defaultdict(list) diff --git a/dimos/models/Detic/tools/get_coco_zeroshot_oriorder.py b/dimos/models/Detic/tools/get_coco_zeroshot_oriorder.py index 874d378d48..688b0a92e5 100644 --- a/dimos/models/Detic/tools/get_coco_zeroshot_oriorder.py +++ b/dimos/models/Detic/tools/get_coco_zeroshot_oriorder.py @@ -10,10 +10,10 @@ parser.add_argument("--cat_path", default="datasets/coco/annotations/instances_val2017.json") args = parser.parse_args() print("Loading", args.cat_path) - cat = json.load(open(args.cat_path, "r"))["categories"] + cat = json.load(open(args.cat_path))["categories"] print("Loading", args.data_path) - data = json.load(open(args.data_path, "r")) + data = json.load(open(args.data_path)) data["categories"] = cat out_path = args.data_path[:-5] + "_oriorder.json" print("Saving to", out_path) diff --git a/dimos/models/Detic/tools/get_imagenet_21k_full_tar_json.py b/dimos/models/Detic/tools/get_imagenet_21k_full_tar_json.py index 2f19a6cf91..00502db11f 100644 --- a/dimos/models/Detic/tools/get_imagenet_21k_full_tar_json.py +++ b/dimos/models/Detic/tools/get_imagenet_21k_full_tar_json.py @@ -1,13 +1,14 @@ # Copyright (c) Facebook, Inc. and its affiliates. import argparse import json -import numpy as np +import operator import sys import time + from nltk.corpus import wordnet -from tqdm import tqdm -import operator +import numpy as np import torch +from tqdm import tqdm sys.path.insert(0, "third_party/CenterNet2/") sys.path.insert(0, "third_party/Deformable-DETR") diff --git a/dimos/models/Detic/tools/get_lvis_cat_info.py b/dimos/models/Detic/tools/get_lvis_cat_info.py index 79d025300c..414a615b8a 100644 --- a/dimos/models/Detic/tools/get_lvis_cat_info.py +++ b/dimos/models/Detic/tools/get_lvis_cat_info.py @@ -11,7 +11,7 @@ args = parser.parse_args() print("Loading", args.ann) - data = json.load(open(args.ann, "r")) + data = json.load(open(args.ann)) cats = data["categories"] image_count = {x["id"]: set() for x in cats} ann_count = {x["id"]: 0 for x in cats} diff --git a/dimos/models/Detic/tools/merge_lvis_coco.py b/dimos/models/Detic/tools/merge_lvis_coco.py index 5ef480d28e..1a76a02f0b 100644 --- a/dimos/models/Detic/tools/merge_lvis_coco.py +++ b/dimos/models/Detic/tools/merge_lvis_coco.py @@ -1,9 +1,9 @@ # Copyright (c) Facebook, Inc. and its affiliates. from collections import defaultdict -import torch import json from detectron2.structures import Boxes, pairwise_iou +import torch COCO_PATH = "datasets/coco/annotations/instances_train2017.json" IMG_PATH = "datasets/coco/train2017/" @@ -110,8 +110,8 @@ def get_bbox(ann): if __name__ == "__main__": file_name_key = "file_name" if "v0.5" in LVIS_PATH else "coco_url" - coco_data = json.load(open(COCO_PATH, "r")) - lvis_data = json.load(open(LVIS_PATH, "r")) + coco_data = json.load(open(COCO_PATH)) + lvis_data = json.load(open(LVIS_PATH)) coco_cats = coco_data["categories"] lvis_cats = lvis_data["categories"] diff --git a/dimos/models/Detic/tools/preprocess_imagenet22k.py b/dimos/models/Detic/tools/preprocess_imagenet22k.py index f4ea6fcbfe..c5a5ad0d31 100644 --- a/dimos/models/Detic/tools/preprocess_imagenet22k.py +++ b/dimos/models/Detic/tools/preprocess_imagenet22k.py @@ -2,26 +2,28 @@ # Copyright (c) Facebook, Inc. and its affiliates. import os -import numpy as np import sys +import numpy as np + sys.path.insert(0, "third_party/CenterNet2/") sys.path.insert(0, "third_party/Deformable-DETR") -from detic.data.tar_dataset import _TarDataset -import io import gzip +import io import time +from detic.data.tar_dataset import _TarDataset + -class _RawTarDataset(object): - def __init__(self, filename, indexname, preload=False): +class _RawTarDataset: + def __init__(self, filename, indexname: str, preload: bool=False) -> None: self.filename = filename self.names = [] self.offsets = [] for l in open(indexname): ll = l.split() - a, b, c = ll[:3] + _a, b, c = ll[:3] offset = int(b[:-1]) if l.endswith("** Block of NULs **\n"): self.offsets.append(offset) @@ -38,10 +40,10 @@ def __init__(self, filename, indexname, preload=False): else: self.data = None - def __len__(self): + def __len__(self) -> int: return len(self.names) - def __getitem__(self, idx): + def __getitem__(self, idx: int): if self.data is None: self.data = np.memmap(self.filename, mode="r", dtype="uint8") ofs = self.offsets[idx] * 512 @@ -64,7 +66,7 @@ def __getitem__(self, idx): return sdata -def preprocess(): +def preprocess() -> None: # Follow https://github.com/Alibaba-MIIL/ImageNet21K/blob/main/dataset_preprocessing/processing_script.sh # Expect 12358684 samples with 11221 classes # ImageNet folder has 21841 classes (synsets) @@ -79,7 +81,6 @@ def preprocess(): log_files = os.listdir(i22ktarlogs) log_files = [x for x in log_files if x.endswith(".tarlog")] log_files.sort() - chunk_datasets = [] dataset_lens = [] min_count = 0 create_npy_tarlogs = True diff --git a/dimos/models/Detic/tools/remove_lvis_rare.py b/dimos/models/Detic/tools/remove_lvis_rare.py index 2e1705d50c..423dd6e6e2 100644 --- a/dimos/models/Detic/tools/remove_lvis_rare.py +++ b/dimos/models/Detic/tools/remove_lvis_rare.py @@ -8,7 +8,7 @@ args = parser.parse_args() print("Loading", args.ann) - data = json.load(open(args.ann, "r")) + data = json.load(open(args.ann)) catid2freq = {x["id"]: x["frequency"] for x in data["categories"]} print("ori #anns", len(data["annotations"])) exclude = ["r"] diff --git a/dimos/models/Detic/tools/unzip_imagenet_lvis.py b/dimos/models/Detic/tools/unzip_imagenet_lvis.py index d550db9980..fd969c28bb 100644 --- a/dimos/models/Detic/tools/unzip_imagenet_lvis.py +++ b/dimos/models/Detic/tools/unzip_imagenet_lvis.py @@ -1,6 +1,6 @@ # Copyright (c) Facebook, Inc. and its affiliates. -import os import argparse +import os if __name__ == "__main__": parser = argparse.ArgumentParser() diff --git a/dimos/models/Detic/train_net.py b/dimos/models/Detic/train_net.py index 53699045bd..54ab6136f4 100644 --- a/dimos/models/Detic/train_net.py +++ b/dimos/models/Detic/train_net.py @@ -1,56 +1,54 @@ # Copyright (c) Facebook, Inc. and its affiliates. +from collections import OrderedDict +import datetime import logging import os import sys -from collections import OrderedDict -import torch -from torch.nn.parallel import DistributedDataParallel import time -import datetime -from fvcore.common.timer import Timer -import detectron2.utils.comm as comm from detectron2.checkpoint import DetectionCheckpointer, PeriodicCheckpointer from detectron2.config import get_cfg from detectron2.data import ( MetadataCatalog, build_detection_test_loader, ) +from detectron2.data.build import build_detection_train_loader +from detectron2.data.dataset_mapper import DatasetMapper from detectron2.engine import default_argument_parser, default_setup, launch - from detectron2.evaluation import ( + COCOEvaluator, + LVISEvaluator, inference_on_dataset, print_csv_format, - LVISEvaluator, - COCOEvaluator, ) from detectron2.modeling import build_model from detectron2.solver import build_lr_scheduler, build_optimizer +import detectron2.utils.comm as comm from detectron2.utils.events import ( CommonMetricPrinter, EventStorage, JSONWriter, TensorboardXWriter, ) -from detectron2.data.dataset_mapper import DatasetMapper -from detectron2.data.build import build_detection_train_loader from detectron2.utils.logger import setup_logger +from fvcore.common.timer import Timer +import torch from torch.cuda.amp import GradScaler +from torch.nn.parallel import DistributedDataParallel sys.path.insert(0, "third_party/CenterNet2/") from centernet.config import add_centernet_config sys.path.insert(0, "third_party/Deformable-DETR") from detic.config import add_detic_config +from detic.custom_solver import build_custom_optimizer from detic.data.custom_build_augmentation import build_custom_augmentation from detic.data.custom_dataset_dataloader import build_custom_train_loader from detic.data.custom_dataset_mapper import CustomDatasetMapper, DetrDatasetMapper -from detic.custom_solver import build_custom_optimizer -from detic.evaluation.oideval import OIDEvaluator from detic.evaluation.custom_coco_eval import CustomCOCOEvaluator +from detic.evaluation.oideval import OIDEvaluator from detic.modeling.utils import reset_cls_test - logger = logging.getLogger("detectron2") @@ -65,7 +63,7 @@ def do_test(cfg, model): else DatasetMapper(cfg, False, augmentations=build_custom_augmentation(cfg, False)) ) data_loader = build_detection_test_loader(cfg, dataset_name, mapper=mapper) - output_folder = os.path.join(cfg.OUTPUT_DIR, "inference_{}".format(dataset_name)) + output_folder = os.path.join(cfg.OUTPUT_DIR, f"inference_{dataset_name}") evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type if evaluator_type == "lvis" or cfg.GEN_PSEDO_LABELS: @@ -83,14 +81,14 @@ def do_test(cfg, model): results[dataset_name] = inference_on_dataset(model, data_loader, evaluator) if comm.is_main_process(): - logger.info("Evaluation results for {} in csv format:".format(dataset_name)) + logger.info(f"Evaluation results for {dataset_name} in csv format:") print_csv_format(results[dataset_name]) if len(results) == 1: - results = list(results.values())[0] + results = next(iter(results.values())) return results -def do_train(cfg, model, resume=False): +def do_train(cfg, model, resume: bool=False) -> None: model.train() if cfg.SOLVER.USE_CUSTOM_SOLVER: optimizer = build_custom_optimizer(cfg, model) @@ -143,12 +141,12 @@ def do_train(cfg, model, resume=False): if cfg.FP16: scaler = GradScaler() - logger.info("Starting training from iteration {}".format(start_iter)) + logger.info(f"Starting training from iteration {start_iter}") with EventStorage(start_iter) as storage: step_timer = Timer() data_timer = Timer() start_time = time.perf_counter() - for data, iteration in zip(data_loader, range(start_iter, max_iter)): + for data, iteration in zip(data_loader, range(start_iter, max_iter), strict=False): data_time = data_timer.seconds() storage.put_scalars(data_time=data_time) step_timer.reset() @@ -195,7 +193,7 @@ def do_train(cfg, model, resume=False): total_time = time.perf_counter() - start_time logger.info( - "Total training time: {}".format(str(datetime.timedelta(seconds=int(total_time)))) + f"Total training time: {datetime.timedelta(seconds=int(total_time))!s}" ) @@ -210,8 +208,8 @@ def setup(args): cfg.merge_from_list(args.opts) if "/auto" in cfg.OUTPUT_DIR: file_name = os.path.basename(args.config_file)[:-5] - cfg.OUTPUT_DIR = cfg.OUTPUT_DIR.replace("/auto", "/{}".format(file_name)) - logger.info("OUTPUT_DIR: {}".format(cfg.OUTPUT_DIR)) + cfg.OUTPUT_DIR = cfg.OUTPUT_DIR.replace("/auto", f"/{file_name}") + logger.info(f"OUTPUT_DIR: {cfg.OUTPUT_DIR}") cfg.freeze() default_setup(cfg, args) setup_logger(output=cfg.OUTPUT_DIR, distributed_rank=comm.get_rank(), name="detic") @@ -222,7 +220,7 @@ def main(args): cfg = setup(args) model = build_model(cfg) - logger.info("Model:\n{}".format(model)) + logger.info(f"Model:\n{model}") if args.eval_only: DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load( cfg.MODEL.WEIGHTS, resume=args.resume @@ -247,16 +245,16 @@ def main(args): args = default_argument_parser() args = args.parse_args() if args.num_machines == 1: - args.dist_url = "tcp://127.0.0.1:{}".format(torch.randint(11111, 60000, (1,))[0].item()) + args.dist_url = f"tcp://127.0.0.1:{torch.randint(11111, 60000, (1,))[0].item()}" else: if args.dist_url == "host": args.dist_url = "tcp://{}:12345".format(os.environ["SLURM_JOB_NODELIST"]) elif not args.dist_url.startswith("tcp"): tmp = os.popen( - "echo $(scontrol show job {} | grep BatchHost)".format(args.dist_url) + f"echo $(scontrol show job {args.dist_url} | grep BatchHost)" ).read() tmp = tmp[tmp.find("=") + 1 : -1] - args.dist_url = "tcp://{}:12345".format(tmp) + args.dist_url = f"tcp://{tmp}:12345" print("Command Line Args:", args) launch( main, diff --git a/dimos/models/depth/metric3d.py b/dimos/models/depth/metric3d.py index b4f00718bc..1cd4f7495b 100644 --- a/dimos/models/depth/metric3d.py +++ b/dimos/models/depth/metric3d.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch -from PIL import Image import cv2 -import numpy as np +import torch # May need to add this back for import to work # external_path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'external', 'Metric3D')) @@ -24,7 +22,7 @@ class Metric3D: - def __init__(self, camera_intrinsics=None, gt_depth_scale=256.0): + def __init__(self, camera_intrinsics=None, gt_depth_scale: float=256.0) -> None: # self.conf = get_config("zoedepth", "infer") # self.depth_model = build_model(self.conf) self.depth_model = torch.hub.load( @@ -56,7 +54,7 @@ def update_intrinsic(self, intrinsic): self.intrinsic = intrinsic print(f"Intrinsics updated to: {self.intrinsic}") - def infer_depth(self, img, debug=False): + def infer_depth(self, img, debug: bool=False): if debug: print(f"Input image: {img}") try: @@ -72,14 +70,14 @@ def infer_depth(self, img, debug=False): img = self.rescale_input(img, self.rgb_origin) with torch.no_grad(): - pred_depth, confidence, output_dict = self.depth_model.inference({"input": img}) + pred_depth, _confidence, _output_dict = self.depth_model.inference({"input": img}) # Convert to PIL format depth_image = self.unpad_transform_depth(pred_depth) return depth_image.cpu().numpy() - def save_depth(self, pred_depth): + def save_depth(self, pred_depth) -> None: # Save the depth map to a file pred_depth_np = pred_depth.cpu().numpy() output_depth_file = "output_depth_map.png" @@ -154,10 +152,10 @@ def unpad_transform_depth(self, pred_depth): """Set new intrinsic value.""" - def update_intrinsic(self, intrinsic): + def update_intrinsic(self, intrinsic) -> None: self.intrinsic = intrinsic - def eval_predicted_depth(self, depth_file, pred_depth): + def eval_predicted_depth(self, depth_file, pred_depth) -> None: if depth_file is not None: gt_depth = cv2.imread(depth_file, -1) gt_depth = gt_depth / self.gt_depth_scale diff --git a/dimos/models/embedding/base.py b/dimos/models/embedding/base.py index 7f2e1896b9..f7c790ffbf 100644 --- a/dimos/models/embedding/base.py +++ b/dimos/models/embedding/base.py @@ -14,16 +14,18 @@ from __future__ import annotations -import time from abc import ABC, abstractmethod -from typing import Generic, Optional, TypeVar +import time +from typing import TYPE_CHECKING, Generic, TypeVar import numpy as np import torch -from dimos.msgs.sensor_msgs import Image from dimos.types.timestamped import Timestamped +if TYPE_CHECKING: + from dimos.msgs.sensor_msgs import Image + class Embedding(Timestamped): """Base class for embeddings with vector data. @@ -34,14 +36,14 @@ class Embedding(Timestamped): vector: torch.Tensor | np.ndarray - def __init__(self, vector: torch.Tensor | np.ndarray, timestamp: Optional[float] = None): + def __init__(self, vector: torch.Tensor | np.ndarray, timestamp: float | None = None) -> None: self.vector = vector if timestamp: self.timestamp = timestamp else: self.timestamp = time.time() - def __matmul__(self, other: "Embedding") -> float: + def __matmul__(self, other: Embedding) -> float: """Compute cosine similarity via @ operator.""" if isinstance(self.vector, torch.Tensor): other_tensor = other.to_torch(self.vector.device) @@ -65,7 +67,7 @@ def to_torch(self, device: str | torch.device | None = None) -> torch.Tensor: return self.vector.to(device) return self.vector - def to_cpu(self) -> "Embedding": + def to_cpu(self) -> Embedding: """Move embedding to CPU, returning self for chaining.""" if isinstance(self.vector, torch.Tensor): self.vector = self.vector.cpu() @@ -141,7 +143,7 @@ def query(self, query_emb: E, candidates: list[E], top_k: int = 5) -> list[tuple """ similarities = self.compare_one_to_many(query_emb, candidates) top_values, top_indices = similarities.topk(k=min(top_k, len(candidates))) - return [(idx.item(), val.item()) for idx, val in zip(top_indices, top_values)] + return [(idx.item(), val.item()) for idx, val in zip(top_indices, top_values, strict=False)] def warmup(self) -> None: """Optional warmup method to pre-load model.""" diff --git a/dimos/models/embedding/clip.py b/dimos/models/embedding/clip.py index e751e9ee33..23ab5e94f2 100644 --- a/dimos/models/embedding/clip.py +++ b/dimos/models/embedding/clip.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from PIL import Image as PILImage import torch import torch.nn.functional as F -from PIL import Image as PILImage -from transformers import CLIPModel as HFCLIPModel -from transformers import CLIPProcessor +from transformers import CLIPModel as HFCLIPModel, CLIPProcessor from dimos.models.embedding.base import Embedding, EmbeddingModel from dimos.msgs.sensor_msgs import Image @@ -35,7 +34,7 @@ def __init__( model_name: str = "openai/clip-vit-base-patch32", device: str | None = None, normalize: bool = False, - ): + ) -> None: """ Initialize CLIP model. diff --git a/dimos/models/embedding/embedding_models_disabled_tests.py b/dimos/models/embedding/embedding_models_disabled_tests.py index 52e9fd08af..bb1f038410 100644 --- a/dimos/models/embedding/embedding_models_disabled_tests.py +++ b/dimos/models/embedding/embedding_models_disabled_tests.py @@ -49,7 +49,7 @@ def test_image(): @pytest.mark.heavy -def test_single_image_embedding(embedding_model, test_image): +def test_single_image_embedding(embedding_model, test_image) -> None: """Test embedding a single image.""" embedding = embedding_model.embed(test_image) @@ -74,7 +74,7 @@ def test_single_image_embedding(embedding_model, test_image): @pytest.mark.heavy -def test_batch_image_embedding(embedding_model, test_image): +def test_batch_image_embedding(embedding_model, test_image) -> None: """Test embedding multiple images at once.""" embeddings = embedding_model.embed(test_image, test_image, test_image) @@ -92,7 +92,7 @@ def test_batch_image_embedding(embedding_model, test_image): @pytest.mark.heavy -def test_single_text_embedding(embedding_model): +def test_single_text_embedding(embedding_model) -> None: """Test embedding a single text string.""" import torch @@ -117,7 +117,7 @@ def test_single_text_embedding(embedding_model): @pytest.mark.heavy -def test_batch_text_embedding(embedding_model): +def test_batch_text_embedding(embedding_model) -> None: """Test embedding multiple text strings at once.""" import torch @@ -137,7 +137,7 @@ def test_batch_text_embedding(embedding_model): @pytest.mark.heavy -def test_text_image_similarity(embedding_model, test_image): +def test_text_image_similarity(embedding_model, test_image) -> None: """Test cross-modal text-image similarity using @ operator.""" if not hasattr(embedding_model, "embed_text"): pytest.skip("Model does not support text embeddings") @@ -150,7 +150,7 @@ def test_text_image_similarity(embedding_model, test_image): # Compute similarities using @ operator similarities = {} - for query, text_emb in zip(queries, text_embeddings): + for query, text_emb in zip(queries, text_embeddings, strict=False): similarity = img_embedding @ text_emb similarities[query] = similarity print(f"\n'{query}': {similarity:.4f}") @@ -161,7 +161,7 @@ def test_text_image_similarity(embedding_model, test_image): @pytest.mark.heavy -def test_cosine_distance(embedding_model, test_image): +def test_cosine_distance(embedding_model, test_image) -> None: """Test cosine distance computation (1 - similarity).""" emb1 = embedding_model.embed(test_image) emb2 = embedding_model.embed(test_image) @@ -180,7 +180,7 @@ def test_cosine_distance(embedding_model, test_image): @pytest.mark.heavy -def test_query_functionality(embedding_model, test_image): +def test_query_functionality(embedding_model, test_image) -> None: """Test query method for top-k retrieval.""" if not hasattr(embedding_model, "embed_text"): pytest.skip("Model does not support text embeddings") @@ -206,7 +206,7 @@ def test_query_functionality(embedding_model, test_image): @pytest.mark.heavy -def test_embedding_operator(embedding_model, test_image): +def test_embedding_operator(embedding_model, test_image) -> None: """Test that @ operator works on embeddings.""" emb1 = embedding_model.embed(test_image) emb2 = embedding_model.embed(test_image) @@ -220,7 +220,7 @@ def test_embedding_operator(embedding_model, test_image): @pytest.mark.heavy -def test_warmup(embedding_model): +def test_warmup(embedding_model) -> None: """Test that warmup runs without error.""" # Warmup is already called in fixture, but test it explicitly embedding_model.warmup() @@ -229,7 +229,7 @@ def test_warmup(embedding_model): @pytest.mark.heavy -def test_compare_one_to_many(embedding_model, test_image): +def test_compare_one_to_many(embedding_model, test_image) -> None: """Test GPU-accelerated one-to-many comparison.""" import torch @@ -253,7 +253,7 @@ def test_compare_one_to_many(embedding_model, test_image): @pytest.mark.heavy -def test_compare_many_to_many(embedding_model): +def test_compare_many_to_many(embedding_model) -> None: """Test GPU-accelerated many-to-many comparison.""" import torch @@ -280,7 +280,7 @@ def test_compare_many_to_many(embedding_model): @pytest.mark.heavy -def test_gpu_query_performance(embedding_model, test_image): +def test_gpu_query_performance(embedding_model, test_image) -> None: """Test that query method uses GPU acceleration.""" # Create a larger gallery gallery_size = 20 @@ -303,7 +303,7 @@ def test_gpu_query_performance(embedding_model, test_image): @pytest.mark.heavy -def test_embedding_performance(embedding_model): +def test_embedding_performance(embedding_model) -> None: """Measure embedding performance over multiple real video frames.""" import time @@ -317,7 +317,7 @@ def test_embedding_performance(embedding_model): # Collect 10 real frames from the video test_images = [] - for ts, frame in video_replay.iterate_ts(duration=1.0): + for _ts, frame in video_replay.iterate_ts(duration=1.0): test_images.append(frame.to_rgb()) if len(test_images) >= 10: break @@ -391,7 +391,7 @@ def test_embedding_performance(embedding_model): text_embeddings = embedding_model.embed_text(*test_queries) similarities = [] - for query, text_emb in zip(test_queries, text_embeddings): + for query, text_emb in zip(test_queries, text_embeddings, strict=False): sim = first_frame_emb @ text_emb similarities.append((query, sim)) diff --git a/dimos/models/embedding/mobileclip.py b/dimos/models/embedding/mobileclip.py index c0295a78ef..8ddefd3c87 100644 --- a/dimos/models/embedding/mobileclip.py +++ b/dimos/models/embedding/mobileclip.py @@ -15,9 +15,9 @@ from pathlib import Path import open_clip +from PIL import Image as PILImage import torch import torch.nn.functional as F -from PIL import Image as PILImage from dimos.models.embedding.base import Embedding, EmbeddingModel from dimos.msgs.sensor_msgs import Image @@ -35,7 +35,7 @@ def __init__( model_path: Path | str | None = None, device: str | None = None, normalize: bool = True, - ): + ) -> None: """ Initialize MobileCLIP model. diff --git a/dimos/models/embedding/treid.py b/dimos/models/embedding/treid.py index bdd00627a0..b00ad11250 100644 --- a/dimos/models/embedding/treid.py +++ b/dimos/models/embedding/treid.py @@ -36,7 +36,7 @@ def __init__( model_path: Path | str | None = None, device: str | None = None, normalize: bool = False, - ): + ) -> None: """ Initialize TorchReID model. diff --git a/dimos/models/labels/llava-34b.py b/dimos/models/labels/llava-34b.py index c59a5c8aa9..52e28ac24e 100644 --- a/dimos/models/labels/llava-34b.py +++ b/dimos/models/labels/llava-34b.py @@ -18,17 +18,16 @@ # llava v1.6 from llama_cpp import Llama from llama_cpp.llama_chat_format import Llava15ChatHandler - from vqasynth.datasets.utils import image_to_base64_data_uri class Llava: def __init__( self, - mmproj=f"{os.getcwd()}/models/mmproj-model-f16.gguf", - model_path=f"{os.getcwd()}/models/llava-v1.6-34b.Q4_K_M.gguf", - gpu=True, - ): + mmproj: str=f"{os.getcwd()}/models/mmproj-model-f16.gguf", + model_path: str=f"{os.getcwd()}/models/llava-v1.6-34b.Q4_K_M.gguf", + gpu: bool=True, + ) -> None: chat_handler = Llava15ChatHandler(clip_model_path=mmproj, verbose=True) n_gpu_layers = 0 if gpu: @@ -41,7 +40,7 @@ def __init__( n_gpu_layers=n_gpu_layers, ) - def run_inference(self, image, prompt, return_json=True): + def run_inference(self, image, prompt: str, return_json: bool=True): data_uri = image_to_base64_data_uri(image) res = self.llm.create_chat_completion( messages=[ diff --git a/dimos/models/manipulation/contact_graspnet_pytorch/inference.py b/dimos/models/manipulation/contact_graspnet_pytorch/inference.py index f09a4ee315..fe173dc017 100644 --- a/dimos/models/manipulation/contact_graspnet_pytorch/inference.py +++ b/dimos/models/manipulation/contact_graspnet_pytorch/inference.py @@ -1,34 +1,33 @@ +import argparse import glob import os -import argparse -import torch -import numpy as np -from contact_graspnet_pytorch.contact_grasp_estimator import GraspEstimator from contact_graspnet_pytorch import config_utils - -from contact_graspnet_pytorch.visualization_utils_o3d import visualize_grasps, show_image -from contact_graspnet_pytorch.checkpoints import CheckpointIO +from contact_graspnet_pytorch.checkpoints import CheckpointIO +from contact_graspnet_pytorch.contact_grasp_estimator import GraspEstimator from contact_graspnet_pytorch.data import load_available_input_data +import numpy as np + from dimos.utils.data import get_data -def inference(global_config, + +def inference(global_config, ckpt_dir, - input_paths, - local_regions=True, - filter_grasps=True, - skip_border_objects=False, - z_range = [0.2,1.8], - forward_passes=1, + input_paths, + local_regions: bool=True, + filter_grasps: bool=True, + skip_border_objects: bool=False, + z_range = None, + forward_passes: int=1, K=None,): """ Predict 6-DoF grasp distribution for given model and input data - + :param global_config: config.yaml from checkpoint directory :param checkpoint_dir: checkpoint directory :param input_paths: .png/.npz/.npy file paths that contain depth/pointcloud and optionally intrinsics/segmentation/rgb :param K: Camera Matrix with intrinsics to convert depth to point cloud - :param local_regions: Crop 3D local regions around given segments. + :param local_regions: Crop 3D local regions around given segments. :param skip_border_objects: When extracting local_regions, ignore segments at depth map boundary. :param filter_grasps: Filter and assign grasp contacts according to segmap. :param segmap_id: only return grasps from specified segmap_id. @@ -36,18 +35,19 @@ def inference(global_config, :param forward_passes: Number of forward passes to run on each point cloud. Default: 1 """ # Build the model + if z_range is None: + z_range = [0.2, 1.8] grasp_estimator = GraspEstimator(global_config) # Load the weights model_checkpoint_dir = get_data(ckpt_dir) checkpoint_io = CheckpointIO(checkpoint_dir=model_checkpoint_dir, model=grasp_estimator.model) try: - load_dict = checkpoint_io.load('model.pt') + checkpoint_io.load('model.pt') except FileExistsError: print('No model checkpoint found') - load_dict = {} - + os.makedirs('results', exist_ok=True) # Process example test scenes @@ -56,36 +56,36 @@ def inference(global_config, pc_segments = {} segmap, rgb, depth, cam_K, pc_full, pc_colors = load_available_input_data(p, K=K) - + if segmap is None and (local_regions or filter_grasps): raise ValueError('Need segmentation map to extract local regions or filter grasps') if pc_full is None: print('Converting depth to point cloud(s)...') pc_full, pc_segments, pc_colors = grasp_estimator.extract_point_clouds(depth, cam_K, segmap=segmap, rgb=rgb, - skip_border_objects=skip_border_objects, + skip_border_objects=skip_border_objects, z_range=z_range) - + print(pc_full.shape) print('Generating Grasps...') - pred_grasps_cam, scores, contact_pts, _ = grasp_estimator.predict_scene_grasps(pc_full, - pc_segments=pc_segments, - local_regions=local_regions, - filter_grasps=filter_grasps, - forward_passes=forward_passes) - + pred_grasps_cam, scores, contact_pts, _ = grasp_estimator.predict_scene_grasps(pc_full, + pc_segments=pc_segments, + local_regions=local_regions, + filter_grasps=filter_grasps, + forward_passes=forward_passes) + # Save results - np.savez('results/predictions_{}'.format(os.path.basename(p.replace('png','npz').replace('npy','npz'))), + np.savez('results/predictions_{}'.format(os.path.basename(p.replace('png','npz').replace('npy','npz'))), pc_full=pc_full, pred_grasps_cam=pred_grasps_cam, scores=scores, contact_pts=contact_pts, pc_colors=pc_colors) - # Visualize results + # Visualize results # show_image(rgb, segmap) # visualize_grasps(pc_full, pred_grasps_cam, scores, plot_opencv_cam=True, pc_colors=pc_colors) - + if not glob.glob(input_paths): print('No files found: ', input_paths) - + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -101,16 +101,16 @@ def inference(global_config, FLAGS = parser.parse_args() global_config = config_utils.load_config(FLAGS.ckpt_dir, batch_size=FLAGS.forward_passes, arg_configs=FLAGS.arg_configs) - + print(str(global_config)) - print('pid: %s'%(str(os.getpid()))) + print(f'pid: {os.getpid()!s}') - inference(global_config, + inference(global_config, FLAGS.ckpt_dir, - FLAGS.np_path, + FLAGS.np_path, local_regions=FLAGS.local_regions, filter_grasps=FLAGS.filter_grasps, skip_border_objects=FLAGS.skip_border_objects, z_range=eval(str(FLAGS.z_range)), forward_passes=FLAGS.forward_passes, - K=eval(str(FLAGS.K))) \ No newline at end of file + K=eval(str(FLAGS.K))) diff --git a/dimos/models/manipulation/contact_graspnet_pytorch/test_contact_graspnet.py b/dimos/models/manipulation/contact_graspnet_pytorch/test_contact_graspnet.py index 84f0343779..7964a24954 100644 --- a/dimos/models/manipulation/contact_graspnet_pytorch/test_contact_graspnet.py +++ b/dimos/models/manipulation/contact_graspnet_pytorch/test_contact_graspnet.py @@ -1,11 +1,11 @@ -import os -import sys import glob -import pytest -import importlib.util +import os + import numpy as np +import pytest + -def is_manipulation_installed(): +def is_manipulation_installed() -> bool: """Check if the manipulation extras are installed.""" try: import contact_graspnet_pytorch @@ -13,38 +13,39 @@ def is_manipulation_installed(): except ImportError: return False -@pytest.mark.skipif(not is_manipulation_installed(), +@pytest.mark.skipif(not is_manipulation_installed(), reason="This test requires 'pip install .[manipulation]' to be run") -def test_contact_graspnet_inference(): +def test_contact_graspnet_inference() -> None: """Test contact graspnet inference with local regions and filter grasps.""" # Skip test if manipulation dependencies not installed if not is_manipulation_installed(): pytest.skip("contact_graspnet_pytorch not installed. Run 'pip install .[manipulation]' first.") return - + try: - from dimos.utils.data import get_data from contact_graspnet_pytorch import config_utils + from dimos.models.manipulation.contact_graspnet_pytorch.inference import inference + from dimos.utils.data import get_data except ImportError: pytest.skip("Required modules could not be imported. Make sure you have run 'pip install .[manipulation]'.") return # Test data path - use the default test data path test_data_path = os.path.join(get_data("models_contact_graspnet"), "test_data/0.npy") - + # Check if test data exists test_files = glob.glob(test_data_path) if not test_files: pytest.fail(f"No test data found at {test_data_path}") - + # Load config with default values ckpt_dir = 'models_contact_graspnet' global_config = config_utils.load_config(ckpt_dir, batch_size=1) - + # Run inference function with the same params as the command line result_files_before = glob.glob('results/predictions_*.npz') - + inference( global_config=global_config, ckpt_dir=ckpt_dir, @@ -56,15 +57,15 @@ def test_contact_graspnet_inference(): forward_passes=1, K=None ) - + # Verify results were created result_files_after = glob.glob('results/predictions_*.npz') assert len(result_files_after) >= len(result_files_before), "No result files were generated" - + # Load at least one result file and verify it contains expected data if result_files_after: latest_result = sorted(result_files_after)[-1] result_data = np.load(latest_result, allow_pickle=True) expected_keys = ['pc_full', 'pred_grasps_cam', 'scores', 'contact_pts', 'pc_colors'] for key in expected_keys: - assert key in result_data.files, f"Expected key '{key}' not found in results" \ No newline at end of file + assert key in result_data.files, f"Expected key '{key}' not found in results" diff --git a/dimos/models/pointcloud/pointcloud_utils.py b/dimos/models/pointcloud/pointcloud_utils.py index c0951f44f2..33b4b59607 100644 --- a/dimos/models/pointcloud/pointcloud_utils.py +++ b/dimos/models/pointcloud/pointcloud_utils.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +import random + import numpy as np import open3d as o3d -import random -def save_pointcloud(pcd, file_path): +def save_pointcloud(pcd, file_path) -> None: """ Save a point cloud to a file using Open3D. """ @@ -52,7 +53,7 @@ def create_point_cloud_from_rgbd(rgb_image, depth_image, intrinsic_parameters): return pcd -def canonicalize_point_cloud(pcd, canonicalize_threshold=0.3): +def canonicalize_point_cloud(pcd, canonicalize_threshold: float=0.3): # Segment the largest plane, assumed to be the floor plane_model, inliers = pcd.segment_plane( distance_threshold=0.01, ransac_n=3, num_iterations=1000 @@ -95,7 +96,7 @@ def canonicalize_point_cloud(pcd, canonicalize_threshold=0.3): # Distance calculations -def human_like_distance(distance_meters): +def human_like_distance(distance_meters) -> str: # Define the choices with units included, focusing on the 0.1 to 10 meters range if distance_meters < 1: # For distances less than 1 meter choices = [ diff --git a/dimos/models/qwen/video_query.py b/dimos/models/qwen/video_query.py index c82ce0fc27..80bb078bac 100644 --- a/dimos/models/qwen/video_query.py +++ b/dimos/models/qwen/video_query.py @@ -1,8 +1,9 @@ """Utility functions for one-off video frame queries using Qwen model.""" +import json import os + import numpy as np -from typing import Optional, Tuple from openai import OpenAI from reactivex import Observable, operators as ops from reactivex.subject import Subject @@ -10,15 +11,14 @@ from dimos.agents.agent import OpenAIAgent from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer from dimos.utils.threadpool import get_scheduler -import json -BBox = Tuple[float, float, float, float] # (x1, y1, x2, y2) +BBox = tuple[float, float, float, float] # (x1, y1, x2, y2) def query_single_frame_observable( video_observable: Observable, query: str, - api_key: Optional[str] = None, + api_key: str | None = None, model_name: str = "qwen2.5-vl-72b-instruct", ) -> Observable: """Process a single frame from a video observable with Qwen model. @@ -89,7 +89,7 @@ def query_single_frame_observable( def query_single_frame( image: np.ndarray, query: str = "Return the center coordinates of the fridge handle as a tuple (x,y)", - api_key: Optional[str] = None, + api_key: str | None = None, model_name: str = "qwen2.5-vl-72b-instruct", ) -> str: """Process a single numpy image array with Qwen model. @@ -162,8 +162,8 @@ def query_single_frame( def get_bbox_from_qwen( - video_stream: Observable, object_name: Optional[str] = None -) -> Optional[Tuple[BBox, float]]: + video_stream: Observable, object_name: str | None = None +) -> tuple[BBox, float] | None: """Get bounding box coordinates from Qwen for a specific object or any object. Args: @@ -201,7 +201,7 @@ def get_bbox_from_qwen( return None -def get_bbox_from_qwen_frame(frame, object_name: Optional[str] = None) -> Optional[BBox]: +def get_bbox_from_qwen_frame(frame, object_name: str | None = None) -> BBox | None: """Get bounding box coordinates from Qwen for a specific object or any object using a single frame. Args: diff --git a/dimos/models/segmentation/clipseg.py b/dimos/models/segmentation/clipseg.py index 043cd194b0..ca8fbeb6fc 100644 --- a/dimos/models/segmentation/clipseg.py +++ b/dimos/models/segmentation/clipseg.py @@ -16,7 +16,7 @@ class CLIPSeg: - def __init__(self, model_name="CIDAS/clipseg-rd64-refined"): + def __init__(self, model_name: str="CIDAS/clipseg-rd64-refined") -> None: self.clipseg_processor = AutoProcessor.from_pretrained(model_name) self.clipseg_model = CLIPSegForImageSegmentation.from_pretrained(model_name) diff --git a/dimos/models/segmentation/sam.py b/dimos/models/segmentation/sam.py index 1efb07c484..96b23bf984 100644 --- a/dimos/models/segmentation/sam.py +++ b/dimos/models/segmentation/sam.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from transformers import SamModel, SamProcessor import torch +from transformers import SamModel, SamProcessor class SAM: - def __init__(self, model_name="facebook/sam-vit-huge", device="cuda"): + def __init__(self, model_name: str="facebook/sam-vit-huge", device: str="cuda") -> None: self.device = device self.sam_model = SamModel.from_pretrained(model_name).to(self.device) self.sam_processor = SamProcessor.from_pretrained(model_name) diff --git a/dimos/models/segmentation/segment_utils.py b/dimos/models/segmentation/segment_utils.py index 9808f5d4e4..59a805afaa 100644 --- a/dimos/models/segmentation/segment_utils.py +++ b/dimos/models/segmentation/segment_utils.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import torch import numpy as np +import torch -def find_medoid_and_closest_points(points, num_closest=5): +def find_medoid_and_closest_points(points, num_closest: int=5): """ Find the medoid from a collection of points and the closest points to the medoid. @@ -37,7 +37,7 @@ def find_medoid_and_closest_points(points, num_closest=5): return medoid, points[closest_indices] -def sample_points_from_heatmap(heatmap, original_size, num_points=5, percentile=0.95): +def sample_points_from_heatmap(heatmap, original_size: int, num_points: int=5, percentile: float=0.95): """ Sample points from the given heatmap, focusing on areas with higher values. """ @@ -53,7 +53,7 @@ def sample_points_from_heatmap(heatmap, original_size, num_points=5, percentile= ) sampled_coords = np.array(np.unravel_index(sampled_indices, attn.shape)).T - medoid, sampled_coords = find_medoid_and_closest_points(sampled_coords) + _medoid, sampled_coords = find_medoid_and_closest_points(sampled_coords) pts = [] for pt in sampled_coords.tolist(): x, y = pt diff --git a/dimos/models/vl/base.py b/dimos/models/vl/base.py index cde41bd8fc..7e162b3ccf 100644 --- a/dimos/models/vl/base.py +++ b/dimos/models/vl/base.py @@ -1,6 +1,6 @@ +from abc import ABC, abstractmethod import json import logging -from abc import ABC, abstractmethod from dimos.msgs.sensor_msgs import Image from dimos.perception.detection.type import Detection2DBBox, ImageDetections2D diff --git a/dimos/models/vl/moondream.py b/dimos/models/vl/moondream.py index a3b9f5fcca..781f1adbf1 100644 --- a/dimos/models/vl/moondream.py +++ b/dimos/models/vl/moondream.py @@ -1,10 +1,9 @@ -import warnings from functools import cached_property -from typing import Optional +import warnings import numpy as np -import torch from PIL import Image as PILImage +import torch from transformers import AutoModelForCausalLM from dimos.models.vl.base import VlModel @@ -20,9 +19,9 @@ class MoondreamVlModel(VlModel): def __init__( self, model_name: str = "vikhyatk/moondream2", - device: Optional[str] = None, + device: str | None = None, dtype: torch.dtype = torch.bfloat16, - ): + ) -> None: self._model_name = model_name self._device = device or ("cuda" if torch.cuda.is_available() else "cpu") self._dtype = dtype diff --git a/dimos/models/vl/qwen.py b/dimos/models/vl/qwen.py index c34f6f7964..773fcc35ad 100644 --- a/dimos/models/vl/qwen.py +++ b/dimos/models/vl/qwen.py @@ -1,6 +1,5 @@ -import os from functools import cached_property -from typing import Optional +import os import numpy as np from openai import OpenAI @@ -11,9 +10,9 @@ class QwenVlModel(VlModel): _model_name: str - _api_key: Optional[str] + _api_key: str | None - def __init__(self, api_key: Optional[str] = None, model_name: str = "qwen2.5-vl-72b-instruct"): + def __init__(self, api_key: str | None = None, model_name: str = "qwen2.5-vl-72b-instruct") -> None: self._model_name = model_name self._api_key = api_key diff --git a/dimos/models/vl/test_base.py b/dimos/models/vl/test_base.py index 302a588721..3d8575fab3 100644 --- a/dimos/models/vl/test_base.py +++ b/dimos/models/vl/test_base.py @@ -26,7 +26,7 @@ """ -def test_query_detections_mocked(): +def test_query_detections_mocked() -> None: """Test query_detections with mocked API response (no API key required).""" # Load test image image = Image.from_file(get_data("cafe.jpg")) @@ -76,7 +76,7 @@ def test_query_detections_mocked(): @pytest.mark.tool @pytest.mark.skipif(not os.getenv("ALIBABA_API_KEY"), reason="ALIBABA_API_KEY not set") -def test_query_detections_real(): +def test_query_detections_real() -> None: """Test query_detections with real API calls (requires API key).""" # Load test image image = Image.from_file(get_data("cafe.jpg")) diff --git a/dimos/models/vl/test_models.py b/dimos/models/vl/test_models.py index adc49798e9..b33e0905e6 100644 --- a/dimos/models/vl/test_models.py +++ b/dimos/models/vl/test_models.py @@ -1,16 +1,19 @@ import time +from typing import TYPE_CHECKING -import pytest from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations +import pytest from dimos.core import LCMTransport -from dimos.models.vl.base import VlModel from dimos.models.vl.moondream import MoondreamVlModel from dimos.models.vl.qwen import QwenVlModel from dimos.msgs.sensor_msgs import Image from dimos.perception.detection.type import ImageDetections2D from dimos.utils.data import get_data +if TYPE_CHECKING: + from dimos.models.vl.base import VlModel + @pytest.mark.parametrize( "model_class,model_name", @@ -21,7 +24,7 @@ ids=["moondream", "qwen"], ) @pytest.mark.gpu -def test_vlm(model_class, model_name): +def test_vlm(model_class, model_name: str) -> None: image = Image.from_file(get_data("cafe.jpg")).to_rgb() print(f"Testing {model_name}") diff --git a/dimos/msgs/foxglove_msgs/Color.py b/dimos/msgs/foxglove_msgs/Color.py index 59d60ccc35..ed19911eb7 100644 --- a/dimos/msgs/foxglove_msgs/Color.py +++ b/dimos/msgs/foxglove_msgs/Color.py @@ -15,6 +15,7 @@ from __future__ import annotations import hashlib + from dimos_lcm.foxglove_msgs import Color as LCMColor diff --git a/dimos/msgs/geometry_msgs/Pose.py b/dimos/msgs/geometry_msgs/Pose.py index 1cf6c95442..50072bdc70 100644 --- a/dimos/msgs/geometry_msgs/Pose.py +++ b/dimos/msgs/geometry_msgs/Pose.py @@ -16,13 +16,10 @@ from typing import TypeAlias -from dimos_lcm.geometry_msgs import Pose as LCMPose -from dimos_lcm.geometry_msgs import Transform as LCMTransform +from dimos_lcm.geometry_msgs import Pose as LCMPose, Transform as LCMTransform try: - from geometry_msgs.msg import Pose as ROSPose - from geometry_msgs.msg import Point as ROSPoint - from geometry_msgs.msg import Quaternion as ROSQuaternion + from geometry_msgs.msg import Point as ROSPoint, Pose as ROSPose, Quaternion as ROSQuaternion except ImportError: ROSPose = None ROSPoint = None @@ -78,10 +75,14 @@ def __init__( @dispatch def __init__( self, - position: VectorConvertable | Vector3 = [0, 0, 0], - orientation: QuaternionConvertable | Quaternion = [0, 0, 0, 1], + position: VectorConvertable | Vector3 = None, + orientation: QuaternionConvertable | Quaternion = None, ) -> None: """Initialize a pose with position and orientation.""" + if orientation is None: + orientation = [0, 0, 0, 1] + if position is None: + position = [0, 0, 0] self.position = Vector3(position) self.orientation = Quaternion(orientation) @@ -163,7 +164,7 @@ def __eq__(self, other) -> bool: def __matmul__(self, transform: LCMTransform | Transform) -> Pose: return self + transform - def __add__(self, other: "Pose" | PoseConvertable | LCMTransform | Transform) -> "Pose": + def __add__(self, other: Pose | PoseConvertable | LCMTransform | Transform) -> Pose: """Compose two poses or apply a transform (transform composition). The operation self + other represents applying transformation 'other' @@ -215,7 +216,7 @@ def __add__(self, other: "Pose" | PoseConvertable | LCMTransform | Transform) -> return Pose(new_position, new_orientation) @classmethod - def from_ros_msg(cls, ros_msg: ROSPose) -> "Pose": + def from_ros_msg(cls, ros_msg: ROSPose) -> Pose: """Create a Pose from a ROS geometry_msgs/Pose message. Args: @@ -253,7 +254,7 @@ def to_ros_msg(self) -> ROSPose: @dispatch -def to_pose(value: "Pose") -> "Pose": +def to_pose(value: Pose) -> Pose: """Pass through Pose objects.""" return value diff --git a/dimos/msgs/geometry_msgs/PoseStamped.py b/dimos/msgs/geometry_msgs/PoseStamped.py index c44c9cd4ff..770f41b641 100644 --- a/dimos/msgs/geometry_msgs/PoseStamped.py +++ b/dimos/msgs/geometry_msgs/PoseStamped.py @@ -14,14 +14,10 @@ from __future__ import annotations -import struct import time -from io import BytesIO from typing import BinaryIO, TypeAlias from dimos_lcm.geometry_msgs import PoseStamped as LCMPoseStamped -from dimos_lcm.std_msgs import Header as LCMHeader -from dimos_lcm.std_msgs import Time as LCMTime try: from geometry_msgs.msg import PoseStamped as ROSPoseStamped @@ -79,7 +75,7 @@ def lcm_decode(cls, data: bytes | BinaryIO) -> PoseStamped: lcm_msg.pose.orientation.y, lcm_msg.pose.orientation.z, lcm_msg.pose.orientation.w, - ], # noqa: E501, + ], ) def __str__(self) -> str: @@ -117,7 +113,7 @@ def find_transform(self, other: PoseStamped) -> Transform: ) @classmethod - def from_ros_msg(cls, ros_msg: ROSPoseStamped) -> "PoseStamped": + def from_ros_msg(cls, ros_msg: ROSPoseStamped) -> PoseStamped: """Create a PoseStamped from a ROS geometry_msgs/PoseStamped message. Args: diff --git a/dimos/msgs/geometry_msgs/PoseWithCovariance.py b/dimos/msgs/geometry_msgs/PoseWithCovariance.py index 3a49522653..ba2c360935 100644 --- a/dimos/msgs/geometry_msgs/PoseWithCovariance.py +++ b/dimos/msgs/geometry_msgs/PoseWithCovariance.py @@ -14,10 +14,10 @@ from __future__ import annotations -from typing import TypeAlias +from typing import TYPE_CHECKING, TypeAlias -import numpy as np from dimos_lcm.geometry_msgs import PoseWithCovariance as LCMPoseWithCovariance +import numpy as np from plum import dispatch try: @@ -26,8 +26,10 @@ ROSPoseWithCovariance = None from dimos.msgs.geometry_msgs.Pose import Pose, PoseConvertable -from dimos.msgs.geometry_msgs.Quaternion import Quaternion -from dimos.msgs.geometry_msgs.Vector3 import Vector3 + +if TYPE_CHECKING: + from dimos.msgs.geometry_msgs.Quaternion import Quaternion + from dimos.msgs.geometry_msgs.Vector3 import Vector3 # Types that can be converted to/from PoseWithCovariance PoseWithCovarianceConvertable: TypeAlias = ( @@ -86,7 +88,7 @@ def __init__(self, pose_tuple: tuple[PoseConvertable, list[float] | np.ndarray]) self.pose = Pose(pose_tuple[0]) self.covariance = np.array(pose_tuple[1], dtype=float).reshape(36) - def __getattribute__(self, name): + def __getattribute__(self, name: str): """Override to ensure covariance is always returned as numpy array.""" if name == "covariance": cov = object.__getattribute__(self, "covariance") @@ -95,7 +97,7 @@ def __getattribute__(self, name): return cov return super().__getattribute__(name) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value) -> None: """Override to ensure covariance is stored as numpy array.""" if name == "covariance": if not isinstance(value, np.ndarray): @@ -180,7 +182,7 @@ def lcm_encode(self) -> bytes: return lcm_msg.lcm_encode() @classmethod - def lcm_decode(cls, data: bytes) -> "PoseWithCovariance": + def lcm_decode(cls, data: bytes) -> PoseWithCovariance: """Decode from LCM binary format.""" lcm_msg = LCMPoseWithCovariance.lcm_decode(data) pose = Pose( @@ -195,7 +197,7 @@ def lcm_decode(cls, data: bytes) -> "PoseWithCovariance": return cls(pose, lcm_msg.covariance) @classmethod - def from_ros_msg(cls, ros_msg: ROSPoseWithCovariance) -> "PoseWithCovariance": + def from_ros_msg(cls, ros_msg: ROSPoseWithCovariance) -> PoseWithCovariance: """Create a PoseWithCovariance from a ROS geometry_msgs/PoseWithCovariance message. Args: diff --git a/dimos/msgs/geometry_msgs/PoseWithCovarianceStamped.py b/dimos/msgs/geometry_msgs/PoseWithCovarianceStamped.py index 05e1847734..3683a15fbd 100644 --- a/dimos/msgs/geometry_msgs/PoseWithCovarianceStamped.py +++ b/dimos/msgs/geometry_msgs/PoseWithCovarianceStamped.py @@ -17,8 +17,8 @@ import time from typing import TypeAlias -import numpy as np from dimos_lcm.geometry_msgs import PoseWithCovarianceStamped as LCMPoseWithCovarianceStamped +import numpy as np from plum import dispatch try: @@ -113,7 +113,7 @@ def __str__(self) -> str: ) @classmethod - def from_ros_msg(cls, ros_msg: ROSPoseWithCovarianceStamped) -> "PoseWithCovarianceStamped": + def from_ros_msg(cls, ros_msg: ROSPoseWithCovarianceStamped) -> PoseWithCovarianceStamped: """Create a PoseWithCovarianceStamped from a ROS geometry_msgs/PoseWithCovarianceStamped message. Args: diff --git a/dimos/msgs/geometry_msgs/Quaternion.py b/dimos/msgs/geometry_msgs/Quaternion.py index 9b51339537..6ce8c3bf2d 100644 --- a/dimos/msgs/geometry_msgs/Quaternion.py +++ b/dimos/msgs/geometry_msgs/Quaternion.py @@ -14,13 +14,13 @@ from __future__ import annotations -import struct from collections.abc import Sequence from io import BytesIO +import struct from typing import BinaryIO, TypeAlias -import numpy as np from dimos_lcm.geometry_msgs import Quaternion as LCMQuaternion +import numpy as np from plum import dispatch from scipy.spatial.transform import Rotation as R @@ -74,7 +74,7 @@ def __init__(self, sequence: Sequence[int | float] | np.ndarray) -> None: self.w = sequence[3] @dispatch - def __init__(self, quaternion: "Quaternion") -> None: + def __init__(self, quaternion: Quaternion) -> None: """Initialize from another Quaternion (copy constructor).""" self.x, self.y, self.z, self.w = quaternion.x, quaternion.y, quaternion.z, quaternion.w @@ -113,7 +113,7 @@ def to_radians(self) -> Vector3: return self.to_euler() @classmethod - def from_euler(cls, vector: Vector3) -> "Quaternion": + def from_euler(cls, vector: Vector3) -> Quaternion: """Convert Euler angles (roll, pitch, yaw) in radians to quaternion. Args: @@ -175,7 +175,7 @@ def __eq__(self, other) -> bool: return False return self.x == other.x and self.y == other.y and self.z == other.z and self.w == other.w - def __mul__(self, other: "Quaternion") -> "Quaternion": + def __mul__(self, other: Quaternion) -> Quaternion: """Multiply two quaternions (Hamilton product). The result represents the composition of rotations: diff --git a/dimos/msgs/geometry_msgs/Transform.py b/dimos/msgs/geometry_msgs/Transform.py index 88ee8627ae..fc22a30bf1 100644 --- a/dimos/msgs/geometry_msgs/Transform.py +++ b/dimos/msgs/geometry_msgs/Transform.py @@ -17,14 +17,18 @@ import time from typing import BinaryIO -from dimos_lcm.geometry_msgs import Transform as LCMTransform -from dimos_lcm.geometry_msgs import TransformStamped as LCMTransformStamped +from dimos_lcm.geometry_msgs import ( + Transform as LCMTransform, + TransformStamped as LCMTransformStamped, +) try: - from geometry_msgs.msg import Quaternion as ROSQuaternion - from geometry_msgs.msg import Transform as ROSTransform - from geometry_msgs.msg import TransformStamped as ROSTransformStamped - from geometry_msgs.msg import Vector3 as ROSVector3 + from geometry_msgs.msg import ( + Quaternion as ROSQuaternion, + Transform as ROSTransform, + TransformStamped as ROSTransformStamped, + Vector3 as ROSVector3, + ) except ImportError: ROSTransformStamped = None ROSTransform = None @@ -60,7 +64,7 @@ def __init__( self.translation = translation if translation is not None else Vector3() self.rotation = rotation if rotation is not None else Quaternion() - def now(self) -> "Transform": + def now(self) -> Transform: """Return a copy of this Transform with the current timestamp.""" return Transform( translation=self.translation, @@ -97,10 +101,10 @@ def lcm_transform(self) -> LCMTransformStamped: ), ) - def apply(self, other: "Transform") -> "Transform": + def apply(self, other: Transform) -> Transform: return self.__add__(other) - def __add__(self, other: "Transform") -> "Transform": + def __add__(self, other: Transform) -> Transform: """Compose two transforms (transform composition). The operation self + other represents applying transformation 'other' @@ -137,7 +141,7 @@ def __add__(self, other: "Transform") -> "Transform": ts=self.ts, ) - def inverse(self) -> "Transform": + def inverse(self) -> Transform: """Compute the inverse transform. The inverse transform reverses the direction of the transformation. @@ -162,7 +166,7 @@ def inverse(self) -> "Transform": ) @classmethod - def from_ros_transform_stamped(cls, ros_msg: ROSTransformStamped) -> "Transform": + def from_ros_transform_stamped(cls, ros_msg: ROSTransformStamped) -> Transform: """Create a Transform from a ROS geometry_msgs/TransformStamped message. Args: @@ -225,12 +229,12 @@ def to_ros_transform_stamped(self) -> ROSTransformStamped: return ros_msg - def __neg__(self) -> "Transform": + def __neg__(self) -> Transform: """Unary minus operator returns the inverse transform.""" return self.inverse() @classmethod - def from_pose(cls, frame_id: str, pose: "Pose | PoseStamped") -> "Transform": + def from_pose(cls, frame_id: str, pose: Pose | PoseStamped) -> Transform: """Create a Transform from a Pose or PoseStamped. Args: @@ -261,7 +265,7 @@ def from_pose(cls, frame_id: str, pose: "Pose | PoseStamped") -> "Transform": else: raise TypeError(f"Expected Pose or PoseStamped, got {type(pose).__name__}") - def to_pose(self, **kwargs) -> "PoseStamped": + def to_pose(self, **kwargs) -> PoseStamped: """Create a Transform from a Pose or PoseStamped. Args: @@ -283,7 +287,7 @@ def to_pose(self, **kwargs) -> "PoseStamped": **kwargs, ) - def to_matrix(self) -> "np.ndarray": + def to_matrix(self) -> np.ndarray: """Convert Transform to a 4x4 transformation matrix. Returns a homogeneous transformation matrix that represents both diff --git a/dimos/msgs/geometry_msgs/Twist.py b/dimos/msgs/geometry_msgs/Twist.py index 2b7b4206a3..a57f9bb3ff 100644 --- a/dimos/msgs/geometry_msgs/Twist.py +++ b/dimos/msgs/geometry_msgs/Twist.py @@ -14,23 +14,22 @@ from __future__ import annotations -import struct -from io import BytesIO -from typing import BinaryIO +from typing import TYPE_CHECKING from dimos_lcm.geometry_msgs import Twist as LCMTwist from plum import dispatch try: - from geometry_msgs.msg import Twist as ROSTwist - from geometry_msgs.msg import Vector3 as ROSVector3 + from geometry_msgs.msg import Twist as ROSTwist, Vector3 as ROSVector3 except ImportError: ROSTwist = None ROSVector3 = None -from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorLike +if TYPE_CHECKING: + from dimos.msgs.geometry_msgs.Quaternion import Quaternion + class Twist(LCMTwist): linear: Vector3 @@ -57,7 +56,7 @@ def __init__(self, linear: VectorLike, angular: Quaternion) -> None: self.angular = angular.to_euler() @dispatch - def __init__(self, twist: "Twist") -> None: + def __init__(self, twist: Twist) -> None: """Initialize from another Twist (copy constructor).""" self.linear = Vector3(twist.linear) self.angular = Vector3(twist.angular) @@ -69,7 +68,7 @@ def __init__(self, lcm_twist: LCMTwist) -> None: self.angular = Vector3(lcm_twist.angular) @dispatch - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: """Handle keyword arguments for LCM compatibility.""" linear = kwargs.get("linear", Vector3()) angular = kwargs.get("angular", Vector3()) @@ -109,7 +108,7 @@ def __bool__(self) -> bool: return not self.is_zero() @classmethod - def from_ros_msg(cls, ros_msg: ROSTwist) -> "Twist": + def from_ros_msg(cls, ros_msg: ROSTwist) -> Twist: """Create a Twist from a ROS geometry_msgs/Twist message. Args: diff --git a/dimos/msgs/geometry_msgs/TwistStamped.py b/dimos/msgs/geometry_msgs/TwistStamped.py index 5c464dfa17..1a14d8cb0d 100644 --- a/dimos/msgs/geometry_msgs/TwistStamped.py +++ b/dimos/msgs/geometry_msgs/TwistStamped.py @@ -14,14 +14,10 @@ from __future__ import annotations -import struct import time -from io import BytesIO from typing import BinaryIO, TypeAlias from dimos_lcm.geometry_msgs import TwistStamped as LCMTwistStamped -from dimos_lcm.std_msgs import Header as LCMHeader -from dimos_lcm.std_msgs import Time as LCMTime from plum import dispatch try: @@ -30,7 +26,7 @@ ROSTwistStamped = None from dimos.msgs.geometry_msgs.Twist import Twist -from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorConvertable +from dimos.msgs.geometry_msgs.Vector3 import VectorConvertable from dimos.types.timestamped import Timestamped # Types that can be converted to/from TwistStamped @@ -79,7 +75,7 @@ def __str__(self) -> str: ) @classmethod - def from_ros_msg(cls, ros_msg: ROSTwistStamped) -> "TwistStamped": + def from_ros_msg(cls, ros_msg: ROSTwistStamped) -> TwistStamped: """Create a TwistStamped from a ROS geometry_msgs/TwistStamped message. Args: diff --git a/dimos/msgs/geometry_msgs/TwistWithCovariance.py b/dimos/msgs/geometry_msgs/TwistWithCovariance.py index 18237cf7b9..53e77beaf7 100644 --- a/dimos/msgs/geometry_msgs/TwistWithCovariance.py +++ b/dimos/msgs/geometry_msgs/TwistWithCovariance.py @@ -16,8 +16,8 @@ from typing import TypeAlias -import numpy as np from dimos_lcm.geometry_msgs import TwistWithCovariance as LCMTwistWithCovariance +import numpy as np from plum import dispatch try: @@ -113,7 +113,7 @@ def __init__( self.twist = Twist(twist[0], twist[1]) self.covariance = np.array(twist_tuple[1], dtype=float).reshape(36) - def __getattribute__(self, name): + def __getattribute__(self, name: str): """Override to ensure covariance is always returned as numpy array.""" if name == "covariance": cov = object.__getattribute__(self, "covariance") @@ -122,7 +122,7 @@ def __getattribute__(self, name): return cov return super().__getattribute__(name) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value) -> None: """Override to ensure covariance is stored as numpy array.""" if name == "covariance": if not isinstance(value, np.ndarray): @@ -185,7 +185,7 @@ def lcm_encode(self) -> bytes: return lcm_msg.lcm_encode() @classmethod - def lcm_decode(cls, data: bytes) -> "TwistWithCovariance": + def lcm_decode(cls, data: bytes) -> TwistWithCovariance: """Decode from LCM binary format.""" lcm_msg = LCMTwistWithCovariance.lcm_decode(data) twist = Twist( @@ -195,7 +195,7 @@ def lcm_decode(cls, data: bytes) -> "TwistWithCovariance": return cls(twist, lcm_msg.covariance) @classmethod - def from_ros_msg(cls, ros_msg: ROSTwistWithCovariance) -> "TwistWithCovariance": + def from_ros_msg(cls, ros_msg: ROSTwistWithCovariance) -> TwistWithCovariance: """Create a TwistWithCovariance from a ROS geometry_msgs/TwistWithCovariance message. Args: diff --git a/dimos/msgs/geometry_msgs/TwistWithCovarianceStamped.py b/dimos/msgs/geometry_msgs/TwistWithCovarianceStamped.py index 1cc4c010a5..20684d9375 100644 --- a/dimos/msgs/geometry_msgs/TwistWithCovarianceStamped.py +++ b/dimos/msgs/geometry_msgs/TwistWithCovarianceStamped.py @@ -17,8 +17,8 @@ import time from typing import TypeAlias -import numpy as np from dimos_lcm.geometry_msgs import TwistWithCovarianceStamped as LCMTwistWithCovarianceStamped +import numpy as np from plum import dispatch try: @@ -121,7 +121,7 @@ def __str__(self) -> str: ) @classmethod - def from_ros_msg(cls, ros_msg: ROSTwistWithCovarianceStamped) -> "TwistWithCovarianceStamped": + def from_ros_msg(cls, ros_msg: ROSTwistWithCovarianceStamped) -> TwistWithCovarianceStamped: """Create a TwistWithCovarianceStamped from a ROS geometry_msgs/TwistWithCovarianceStamped message. Args: diff --git a/dimos/msgs/geometry_msgs/Vector3.py b/dimos/msgs/geometry_msgs/Vector3.py index 2eb204693b..05d3340a42 100644 --- a/dimos/msgs/geometry_msgs/Vector3.py +++ b/dimos/msgs/geometry_msgs/Vector3.py @@ -14,13 +14,11 @@ from __future__ import annotations -import struct from collections.abc import Sequence -from io import BytesIO -from typing import BinaryIO, TypeAlias +from typing import TypeAlias -import numpy as np from dimos_lcm.geometry_msgs import Vector3 as LCMVector3 +import numpy as np from plum import dispatch # Types that can be converted to/from Vector @@ -92,7 +90,7 @@ def __init__(self, array: np.ndarray) -> None: self.z = float(data[2]) @dispatch - def __init__(self, vector: "Vector3") -> None: + def __init__(self, vector: Vector3) -> None: """Initialize from another Vector3 (copy constructor).""" self.x = vector.x self.y = vector.y @@ -126,7 +124,7 @@ def data(self) -> np.ndarray: """Get the underlying numpy array.""" return np.array([self.x, self.y, self.z], dtype=float) - def __getitem__(self, idx): + def __getitem__(self, idx: int): if idx == 0: return self.x elif idx == 1: @@ -386,7 +384,7 @@ def __bool__(self) -> bool: @dispatch -def to_numpy(value: "Vector3") -> np.ndarray: +def to_numpy(value: Vector3) -> np.ndarray: """Convert a Vector3 to a numpy array.""" return value.to_numpy() @@ -404,7 +402,7 @@ def to_numpy(value: Sequence[int | float]) -> np.ndarray: @dispatch -def to_vector(value: "Vector3") -> Vector3: +def to_vector(value: Vector3) -> Vector3: """Pass through Vector3 objects.""" return value diff --git a/dimos/msgs/geometry_msgs/test_Pose.py b/dimos/msgs/geometry_msgs/test_Pose.py index 6d9c10b1c2..e5c373e166 100644 --- a/dimos/msgs/geometry_msgs/test_Pose.py +++ b/dimos/msgs/geometry_msgs/test_Pose.py @@ -14,14 +14,12 @@ import pickle +from dimos_lcm.geometry_msgs import Pose as LCMPose import numpy as np import pytest -from dimos_lcm.geometry_msgs import Pose as LCMPose try: - from geometry_msgs.msg import Pose as ROSPose - from geometry_msgs.msg import Point as ROSPoint - from geometry_msgs.msg import Quaternion as ROSQuaternion + from geometry_msgs.msg import Point as ROSPoint, Pose as ROSPose, Quaternion as ROSQuaternion except ImportError: ROSPose = None ROSPoint = None @@ -32,7 +30,7 @@ from dimos.msgs.geometry_msgs.Vector3 import Vector3 -def test_pose_default_init(): +def test_pose_default_init() -> None: """Test that default initialization creates a pose at origin with identity orientation.""" pose = Pose() @@ -53,7 +51,7 @@ def test_pose_default_init(): assert pose.z == 0.0 -def test_pose_pose_init(): +def test_pose_pose_init() -> None: """Test initialization with position coordinates only (identity orientation).""" pose_data = Pose(1.0, 2.0, 3.0) @@ -76,7 +74,7 @@ def test_pose_pose_init(): assert pose.z == 3.0 -def test_pose_position_init(): +def test_pose_position_init() -> None: """Test initialization with position coordinates only (identity orientation).""" pose = Pose(1.0, 2.0, 3.0) @@ -97,7 +95,7 @@ def test_pose_position_init(): assert pose.z == 3.0 -def test_pose_full_init(): +def test_pose_full_init() -> None: """Test initialization with position and orientation coordinates.""" pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) @@ -118,7 +116,7 @@ def test_pose_full_init(): assert pose.z == 3.0 -def test_pose_vector_position_init(): +def test_pose_vector_position_init() -> None: """Test initialization with Vector3 position (identity orientation).""" position = Vector3(4.0, 5.0, 6.0) pose = Pose(position) @@ -135,7 +133,7 @@ def test_pose_vector_position_init(): assert pose.orientation.w == 1.0 -def test_pose_vector_quaternion_init(): +def test_pose_vector_quaternion_init() -> None: """Test initialization with Vector3 position and Quaternion orientation.""" position = Vector3(1.0, 2.0, 3.0) orientation = Quaternion(0.1, 0.2, 0.3, 0.9) @@ -153,7 +151,7 @@ def test_pose_vector_quaternion_init(): assert pose.orientation.w == 0.9 -def test_pose_list_init(): +def test_pose_list_init() -> None: """Test initialization with lists for position and orientation.""" position_list = [1.0, 2.0, 3.0] orientation_list = [0.1, 0.2, 0.3, 0.9] @@ -171,7 +169,7 @@ def test_pose_list_init(): assert pose.orientation.w == 0.9 -def test_pose_tuple_init(): +def test_pose_tuple_init() -> None: """Test initialization from a tuple of (position, orientation).""" position = [1.0, 2.0, 3.0] orientation = [0.1, 0.2, 0.3, 0.9] @@ -190,7 +188,7 @@ def test_pose_tuple_init(): assert pose.orientation.w == 0.9 -def test_pose_dict_init(): +def test_pose_dict_init() -> None: """Test initialization from a dictionary with 'position' and 'orientation' keys.""" pose_dict = {"position": [1.0, 2.0, 3.0], "orientation": [0.1, 0.2, 0.3, 0.9]} pose = Pose(pose_dict) @@ -207,7 +205,7 @@ def test_pose_dict_init(): assert pose.orientation.w == 0.9 -def test_pose_copy_init(): +def test_pose_copy_init() -> None: """Test initialization from another Pose (copy constructor).""" original = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) copy = Pose(original) @@ -228,7 +226,7 @@ def test_pose_copy_init(): assert copy == original -def test_pose_lcm_init(): +def test_pose_lcm_init() -> None: """Test initialization from an LCM Pose.""" # Create LCM pose lcm_pose = LCMPose() @@ -254,7 +252,7 @@ def test_pose_lcm_init(): assert pose.orientation.w == 0.9 -def test_pose_properties(): +def test_pose_properties() -> None: """Test pose property access.""" pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) @@ -270,7 +268,7 @@ def test_pose_properties(): assert pose.yaw == euler.z -def test_pose_euler_properties_identity(): +def test_pose_euler_properties_identity() -> None: """Test pose Euler angle properties with identity orientation.""" pose = Pose(1.0, 2.0, 3.0) # Identity orientation @@ -285,7 +283,7 @@ def test_pose_euler_properties_identity(): assert np.isclose(pose.orientation.euler.z, 0.0, atol=1e-10) -def test_pose_repr(): +def test_pose_repr() -> None: """Test pose string representation.""" pose = Pose(1.234, 2.567, 3.891, 0.1, 0.2, 0.3, 0.9) @@ -301,7 +299,7 @@ def test_pose_repr(): assert "2.567" in repr_str or "2.57" in repr_str -def test_pose_str(): +def test_pose_str() -> None: """Test pose string formatting.""" pose = Pose(1.234, 2.567, 3.891, 0.1, 0.2, 0.3, 0.9) @@ -319,7 +317,7 @@ def test_pose_str(): assert str_repr.count("Pose") == 1 -def test_pose_equality(): +def test_pose_equality() -> None: """Test pose equality comparison.""" pose1 = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) pose2 = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) @@ -338,10 +336,10 @@ def test_pose_equality(): # Different types assert pose1 != "not a pose" assert pose1 != [1.0, 2.0, 3.0] - assert pose1 != None + assert pose1 is not None -def test_pose_with_numpy_arrays(): +def test_pose_with_numpy_arrays() -> None: """Test pose initialization with numpy arrays.""" position_array = np.array([1.0, 2.0, 3.0]) orientation_array = np.array([0.1, 0.2, 0.3, 0.9]) @@ -360,7 +358,7 @@ def test_pose_with_numpy_arrays(): assert pose.orientation.w == 0.9 -def test_pose_with_mixed_types(): +def test_pose_with_mixed_types() -> None: """Test pose initialization with mixed input types.""" # Position as tuple, orientation as list pose1 = Pose((1.0, 2.0, 3.0), [0.1, 0.2, 0.3, 0.9]) @@ -380,7 +378,7 @@ def test_pose_with_mixed_types(): assert pose1.orientation.w == pose2.orientation.w -def test_to_pose_passthrough(): +def test_to_pose_passthrough() -> None: """Test to_pose function with Pose input (passthrough).""" original = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) result = to_pose(original) @@ -389,7 +387,7 @@ def test_to_pose_passthrough(): assert result is original -def test_to_pose_conversion(): +def test_to_pose_conversion() -> None: """Test to_pose function with convertible inputs.""" # Note: The to_pose conversion function has type checking issues in the current implementation # Test direct construction instead to verify the intended functionality @@ -421,7 +419,7 @@ def test_to_pose_conversion(): assert result2.orientation.w == 0.9 -def test_pose_euler_roundtrip(): +def test_pose_euler_roundtrip() -> None: """Test conversion from Euler angles to quaternion and back.""" # Start with known Euler angles (small angles to avoid gimbal lock) roll = 0.1 @@ -444,7 +442,7 @@ def test_pose_euler_roundtrip(): assert np.isclose(result_euler.z, yaw, atol=1e-6) -def test_pose_zero_position(): +def test_pose_zero_position() -> None: """Test pose with zero position vector.""" # Use manual construction since Vector3.zeros has signature issues pose = Pose(0.0, 0.0, 0.0) # Position at origin with identity orientation @@ -457,7 +455,7 @@ def test_pose_zero_position(): assert np.isclose(pose.yaw, 0.0, atol=1e-10) -def test_pose_unit_vectors(): +def test_pose_unit_vectors() -> None: """Test pose with unit vector positions.""" # Test unit x vector position pose_x = Pose(Vector3.unit_x()) @@ -478,7 +476,7 @@ def test_pose_unit_vectors(): assert pose_z.z == 1.0 -def test_pose_negative_coordinates(): +def test_pose_negative_coordinates() -> None: """Test pose with negative coordinates.""" pose = Pose(-1.0, -2.0, -3.0, -0.1, -0.2, -0.3, 0.9) @@ -494,7 +492,7 @@ def test_pose_negative_coordinates(): assert pose.orientation.w == 0.9 -def test_pose_large_coordinates(): +def test_pose_large_coordinates() -> None: """Test pose with large coordinate values.""" large_value = 1000.0 pose = Pose(large_value, large_value, large_value) @@ -514,7 +512,7 @@ def test_pose_large_coordinates(): "x,y,z", [(0.0, 0.0, 0.0), (1.0, 2.0, 3.0), (-1.0, -2.0, -3.0), (0.5, -0.5, 1.5), (100.0, -100.0, 0.0)], ) -def test_pose_parametrized_positions(x, y, z): +def test_pose_parametrized_positions(x, y, z) -> None: """Parametrized test for various position values.""" pose = Pose(x, y, z) @@ -539,7 +537,7 @@ def test_pose_parametrized_positions(x, y, z): (0.5, 0.5, 0.5, 0.5), # Equal components ], ) -def test_pose_parametrized_orientations(qx, qy, qz, qw): +def test_pose_parametrized_orientations(qx, qy, qz, qw) -> None: """Parametrized test for various orientation values.""" pose = Pose(0.0, 0.0, 0.0, qx, qy, qz, qw) @@ -555,10 +553,10 @@ def test_pose_parametrized_orientations(qx, qy, qz, qw): assert pose.orientation.w == qw -def test_lcm_encode_decode(): +def test_lcm_encode_decode() -> None: """Test encoding and decoding of Pose to/from binary LCM format.""" - def encodepass(): + def encodepass() -> None: pose_source = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) binary_msg = pose_source.lcm_encode() pose_dest = Pose.lcm_decode(binary_msg) @@ -574,10 +572,10 @@ def encodepass(): print(f"{timeit.timeit(encodepass, number=1000)} ms per cycle") -def test_pickle_encode_decode(): +def test_pickle_encode_decode() -> None: """Test encoding and decoding of Pose to/from binary LCM format.""" - def encodepass(): + def encodepass() -> None: pose_source = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) binary_msg = pickle.dumps(pose_source) pose_dest = pickle.loads(binary_msg) @@ -590,7 +588,7 @@ def encodepass(): print(f"{timeit.timeit(encodepass, number=1000)} ms per cycle") -def test_pose_addition_translation_only(): +def test_pose_addition_translation_only() -> None: """Test pose addition with translation only (identity rotations).""" # Two poses with only translations pose1 = Pose(1.0, 2.0, 3.0) # First translation @@ -610,7 +608,7 @@ def test_pose_addition_translation_only(): assert result.orientation.w == 1.0 -def test_pose_addition_with_rotation(): +def test_pose_addition_with_rotation() -> None: """Test pose addition with rotation applied to translation.""" # First pose: at origin, rotated 90 degrees around Z (yaw) # 90 degree rotation quaternion around Z: (0, 0, sin(pi/4), cos(pi/4)) @@ -635,7 +633,7 @@ def test_pose_addition_with_rotation(): assert np.isclose(result.orientation.w, np.cos(angle / 2), atol=1e-10) -def test_pose_addition_rotation_composition(): +def test_pose_addition_rotation_composition() -> None: """Test that rotations are properly composed.""" # First pose: 45 degrees around Z angle1 = np.pi / 4 # 45 degrees @@ -657,7 +655,7 @@ def test_pose_addition_rotation_composition(): assert np.isclose(result.orientation.w, expected_qw, atol=1e-10) -def test_pose_addition_full_transform(): +def test_pose_addition_full_transform() -> None: """Test full pose composition with translation and rotation.""" # Robot pose: at (2, 1, 0), facing 90 degrees left (positive yaw) robot_yaw = np.pi / 2 # 90 degrees @@ -682,7 +680,7 @@ def test_pose_addition_full_transform(): assert np.isclose(object_in_world.yaw, robot_yaw, atol=1e-10) -def test_pose_addition_chain(): +def test_pose_addition_chain() -> None: """Test chaining multiple pose additions.""" # Create a chain of transformations pose1 = Pose(1.0, 0.0, 0.0) # Move 1 unit in X @@ -698,7 +696,7 @@ def test_pose_addition_chain(): assert result.position.z == 1.0 -def test_pose_addition_with_convertible(): +def test_pose_addition_with_convertible() -> None: """Test pose addition with convertible types.""" pose1 = Pose(1.0, 2.0, 3.0) @@ -717,7 +715,7 @@ def test_pose_addition_with_convertible(): assert result2.position.z == 3.0 -def test_pose_identity_addition(): +def test_pose_identity_addition() -> None: """Test that adding identity pose leaves pose unchanged.""" pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) identity = Pose() # Identity pose at origin @@ -734,7 +732,7 @@ def test_pose_identity_addition(): assert result.orientation.w == pose.orientation.w -def test_pose_addition_3d_rotation(): +def test_pose_addition_3d_rotation() -> None: """Test pose addition with 3D rotations.""" # First pose: rotated around X axis (roll) roll = np.pi / 4 # 45 degrees @@ -759,7 +757,7 @@ def test_pose_addition_3d_rotation(): @pytest.mark.ros -def test_pose_from_ros_msg(): +def test_pose_from_ros_msg() -> None: """Test creating a Pose from a ROS Pose message.""" ros_msg = ROSPose() ros_msg.position = ROSPoint(x=1.0, y=2.0, z=3.0) @@ -777,7 +775,7 @@ def test_pose_from_ros_msg(): @pytest.mark.ros -def test_pose_to_ros_msg(): +def test_pose_to_ros_msg() -> None: """Test converting a Pose to a ROS Pose message.""" pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) @@ -794,7 +792,7 @@ def test_pose_to_ros_msg(): @pytest.mark.ros -def test_pose_ros_roundtrip(): +def test_pose_ros_roundtrip() -> None: """Test round-trip conversion between Pose and ROS Pose.""" original = Pose(1.5, 2.5, 3.5, 0.15, 0.25, 0.35, 0.85) diff --git a/dimos/msgs/geometry_msgs/test_PoseStamped.py b/dimos/msgs/geometry_msgs/test_PoseStamped.py index cbc0c26876..6224b6548a 100644 --- a/dimos/msgs/geometry_msgs/test_PoseStamped.py +++ b/dimos/msgs/geometry_msgs/test_PoseStamped.py @@ -25,7 +25,7 @@ from dimos.msgs.geometry_msgs import PoseStamped -def test_lcm_encode_decode(): +def test_lcm_encode_decode() -> None: """Test encoding and decoding of Pose to/from binary LCM format.""" pose_source = PoseStamped( @@ -47,7 +47,7 @@ def test_lcm_encode_decode(): assert pose_dest == pose_source -def test_pickle_encode_decode(): +def test_pickle_encode_decode() -> None: """Test encoding and decoding of PoseStamped to/from binary LCM format.""" pose_source = PoseStamped( @@ -63,7 +63,7 @@ def test_pickle_encode_decode(): @pytest.mark.ros -def test_pose_stamped_from_ros_msg(): +def test_pose_stamped_from_ros_msg() -> None: """Test creating a PoseStamped from a ROS PoseStamped message.""" ros_msg = ROSPoseStamped() ros_msg.header.frame_id = "world" @@ -91,7 +91,7 @@ def test_pose_stamped_from_ros_msg(): @pytest.mark.ros -def test_pose_stamped_to_ros_msg(): +def test_pose_stamped_to_ros_msg() -> None: """Test converting a PoseStamped to a ROS PoseStamped message.""" pose_stamped = PoseStamped( ts=123.456, @@ -116,7 +116,7 @@ def test_pose_stamped_to_ros_msg(): @pytest.mark.ros -def test_pose_stamped_ros_roundtrip(): +def test_pose_stamped_ros_roundtrip() -> None: """Test round-trip conversion between PoseStamped and ROS PoseStamped.""" original = PoseStamped( ts=123.789, diff --git a/dimos/msgs/geometry_msgs/test_PoseWithCovariance.py b/dimos/msgs/geometry_msgs/test_PoseWithCovariance.py index dd254104a5..ea455ba488 100644 --- a/dimos/msgs/geometry_msgs/test_PoseWithCovariance.py +++ b/dimos/msgs/geometry_msgs/test_PoseWithCovariance.py @@ -12,15 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dimos_lcm.geometry_msgs import PoseWithCovariance as LCMPoseWithCovariance import numpy as np import pytest -from dimos_lcm.geometry_msgs import PoseWithCovariance as LCMPoseWithCovariance try: - from geometry_msgs.msg import PoseWithCovariance as ROSPoseWithCovariance - from geometry_msgs.msg import Pose as ROSPose - from geometry_msgs.msg import Point as ROSPoint - from geometry_msgs.msg import Quaternion as ROSQuaternion + from geometry_msgs.msg import ( + Point as ROSPoint, + Pose as ROSPose, + PoseWithCovariance as ROSPoseWithCovariance, + Quaternion as ROSQuaternion, + ) except ImportError: ROSPoseWithCovariance = None ROSPose = None @@ -29,11 +31,9 @@ from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance -from dimos.msgs.geometry_msgs.Quaternion import Quaternion -from dimos.msgs.geometry_msgs.Vector3 import Vector3 -def test_pose_with_covariance_default_init(): +def test_pose_with_covariance_default_init() -> None: """Test that default initialization creates a pose at origin with zero covariance.""" pose_cov = PoseWithCovariance() @@ -51,7 +51,7 @@ def test_pose_with_covariance_default_init(): assert pose_cov.covariance.shape == (36,) -def test_pose_with_covariance_pose_init(): +def test_pose_with_covariance_pose_init() -> None: """Test initialization with a Pose object.""" pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) pose_cov = PoseWithCovariance(pose) @@ -69,7 +69,7 @@ def test_pose_with_covariance_pose_init(): assert np.all(pose_cov.covariance == 0.0) -def test_pose_with_covariance_pose_and_covariance_init(): +def test_pose_with_covariance_pose_and_covariance_init() -> None: """Test initialization with pose and covariance.""" pose = Pose(1.0, 2.0, 3.0) covariance = np.arange(36, dtype=float) @@ -84,7 +84,7 @@ def test_pose_with_covariance_pose_and_covariance_init(): assert np.array_equal(pose_cov.covariance, covariance) -def test_pose_with_covariance_list_covariance(): +def test_pose_with_covariance_list_covariance() -> None: """Test initialization with covariance as a list.""" pose = Pose(1.0, 2.0, 3.0) covariance_list = list(range(36)) @@ -95,7 +95,7 @@ def test_pose_with_covariance_list_covariance(): assert np.array_equal(pose_cov.covariance, np.array(covariance_list)) -def test_pose_with_covariance_copy_init(): +def test_pose_with_covariance_copy_init() -> None: """Test copy constructor.""" pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) covariance = np.arange(36, dtype=float) @@ -113,7 +113,7 @@ def test_pose_with_covariance_copy_init(): assert copy.covariance[0] != 999.0 -def test_pose_with_covariance_lcm_init(): +def test_pose_with_covariance_lcm_init() -> None: """Test initialization from LCM message.""" lcm_msg = LCMPoseWithCovariance() lcm_msg.pose.position.x = 1.0 @@ -140,7 +140,7 @@ def test_pose_with_covariance_lcm_init(): assert np.array_equal(pose_cov.covariance, np.arange(36)) -def test_pose_with_covariance_dict_init(): +def test_pose_with_covariance_dict_init() -> None: """Test initialization from dictionary.""" pose_dict = {"pose": Pose(1.0, 2.0, 3.0), "covariance": list(range(36))} pose_cov = PoseWithCovariance(pose_dict) @@ -151,7 +151,7 @@ def test_pose_with_covariance_dict_init(): assert np.array_equal(pose_cov.covariance, np.arange(36)) -def test_pose_with_covariance_dict_init_no_covariance(): +def test_pose_with_covariance_dict_init_no_covariance() -> None: """Test initialization from dictionary without covariance.""" pose_dict = {"pose": Pose(1.0, 2.0, 3.0)} pose_cov = PoseWithCovariance(pose_dict) @@ -160,7 +160,7 @@ def test_pose_with_covariance_dict_init_no_covariance(): assert np.all(pose_cov.covariance == 0.0) -def test_pose_with_covariance_tuple_init(): +def test_pose_with_covariance_tuple_init() -> None: """Test initialization from tuple.""" pose = Pose(1.0, 2.0, 3.0) covariance = np.arange(36, dtype=float) @@ -173,7 +173,7 @@ def test_pose_with_covariance_tuple_init(): assert np.array_equal(pose_cov.covariance, covariance) -def test_pose_with_covariance_properties(): +def test_pose_with_covariance_properties() -> None: """Test convenience properties.""" pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) pose_cov = PoseWithCovariance(pose) @@ -198,7 +198,7 @@ def test_pose_with_covariance_properties(): assert pose_cov.yaw == pose.yaw -def test_pose_with_covariance_matrix_property(): +def test_pose_with_covariance_matrix_property() -> None: """Test covariance matrix property.""" pose = Pose() covariance_array = np.arange(36, dtype=float) @@ -216,7 +216,7 @@ def test_pose_with_covariance_matrix_property(): assert np.array_equal(pose_cov.covariance[:6], [2.0, 0.0, 0.0, 0.0, 0.0, 0.0]) -def test_pose_with_covariance_repr(): +def test_pose_with_covariance_repr() -> None: """Test string representation.""" pose = Pose(1.234, 2.567, 3.891) pose_cov = PoseWithCovariance(pose) @@ -228,7 +228,7 @@ def test_pose_with_covariance_repr(): assert "36 elements" in repr_str -def test_pose_with_covariance_str(): +def test_pose_with_covariance_str() -> None: """Test string formatting.""" pose = Pose(1.234, 2.567, 3.891) covariance = np.eye(6).flatten() @@ -243,7 +243,7 @@ def test_pose_with_covariance_str(): assert "6.000" in str_repr # Trace of identity matrix is 6 -def test_pose_with_covariance_equality(): +def test_pose_with_covariance_equality() -> None: """Test equality comparison.""" pose1 = Pose(1.0, 2.0, 3.0) cov1 = np.arange(36, dtype=float) @@ -268,10 +268,10 @@ def test_pose_with_covariance_equality(): # Different type assert pose_cov1 != "not a pose" - assert pose_cov1 != None + assert pose_cov1 is not None -def test_pose_with_covariance_lcm_encode_decode(): +def test_pose_with_covariance_lcm_encode_decode() -> None: """Test LCM encoding and decoding.""" pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) covariance = np.arange(36, dtype=float) @@ -289,7 +289,7 @@ def test_pose_with_covariance_lcm_encode_decode(): @pytest.mark.ros -def test_pose_with_covariance_from_ros_msg(): +def test_pose_with_covariance_from_ros_msg() -> None: """Test creating from ROS message.""" ros_msg = ROSPoseWithCovariance() ros_msg.pose.position = ROSPoint(x=1.0, y=2.0, z=3.0) @@ -309,7 +309,7 @@ def test_pose_with_covariance_from_ros_msg(): @pytest.mark.ros -def test_pose_with_covariance_to_ros_msg(): +def test_pose_with_covariance_to_ros_msg() -> None: """Test converting to ROS message.""" pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) covariance = np.arange(36, dtype=float) @@ -329,7 +329,7 @@ def test_pose_with_covariance_to_ros_msg(): @pytest.mark.ros -def test_pose_with_covariance_ros_roundtrip(): +def test_pose_with_covariance_ros_roundtrip() -> None: """Test round-trip conversion with ROS messages.""" pose = Pose(1.5, 2.5, 3.5, 0.15, 0.25, 0.35, 0.85) covariance = np.random.rand(36) @@ -341,7 +341,7 @@ def test_pose_with_covariance_ros_roundtrip(): assert restored == original -def test_pose_with_covariance_zero_covariance(): +def test_pose_with_covariance_zero_covariance() -> None: """Test with zero covariance matrix.""" pose = Pose(1.0, 2.0, 3.0) pose_cov = PoseWithCovariance(pose) @@ -350,7 +350,7 @@ def test_pose_with_covariance_zero_covariance(): assert np.trace(pose_cov.covariance_matrix) == 0.0 -def test_pose_with_covariance_diagonal_covariance(): +def test_pose_with_covariance_diagonal_covariance() -> None: """Test with diagonal covariance matrix.""" pose = Pose() covariance = np.zeros(36) @@ -378,7 +378,7 @@ def test_pose_with_covariance_diagonal_covariance(): "x,y,z", [(0.0, 0.0, 0.0), (1.0, 2.0, 3.0), (-1.0, -2.0, -3.0), (100.0, -100.0, 0.0)], ) -def test_pose_with_covariance_parametrized_positions(x, y, z): +def test_pose_with_covariance_parametrized_positions(x, y, z) -> None: """Parametrized test for various position values.""" pose = Pose(x, y, z) pose_cov = PoseWithCovariance(pose) diff --git a/dimos/msgs/geometry_msgs/test_PoseWithCovarianceStamped.py b/dimos/msgs/geometry_msgs/test_PoseWithCovarianceStamped.py index 139279add3..25a246495d 100644 --- a/dimos/msgs/geometry_msgs/test_PoseWithCovarianceStamped.py +++ b/dimos/msgs/geometry_msgs/test_PoseWithCovarianceStamped.py @@ -18,13 +18,15 @@ import pytest try: - from geometry_msgs.msg import PoseWithCovarianceStamped as ROSPoseWithCovarianceStamped - from geometry_msgs.msg import PoseWithCovariance as ROSPoseWithCovariance - from geometry_msgs.msg import Pose as ROSPose - from geometry_msgs.msg import Point as ROSPoint - from geometry_msgs.msg import Quaternion as ROSQuaternion - from std_msgs.msg import Header as ROSHeader from builtin_interfaces.msg import Time as ROSTime + from geometry_msgs.msg import ( + Point as ROSPoint, + Pose as ROSPose, + PoseWithCovariance as ROSPoseWithCovariance, + PoseWithCovarianceStamped as ROSPoseWithCovarianceStamped, + Quaternion as ROSQuaternion, + ) + from std_msgs.msg import Header as ROSHeader except ImportError: ROSHeader = None ROSPoseWithCovarianceStamped = None @@ -34,18 +36,13 @@ ROSTime = None ROSPoseWithCovariance = None -from dimos_lcm.geometry_msgs import PoseWithCovarianceStamped as LCMPoseWithCovarianceStamped -from dimos_lcm.std_msgs import Header as LCMHeader -from dimos_lcm.std_msgs import Time as LCMTime from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance from dimos.msgs.geometry_msgs.PoseWithCovarianceStamped import PoseWithCovarianceStamped -from dimos.msgs.geometry_msgs.Quaternion import Quaternion -from dimos.msgs.geometry_msgs.Vector3 import Vector3 -def test_pose_with_covariance_stamped_default_init(): +def test_pose_with_covariance_stamped_default_init() -> None: """Test default initialization.""" if ROSPoseWithCovariance is None: pytest.skip("ROS not available") @@ -77,7 +74,7 @@ def test_pose_with_covariance_stamped_default_init(): assert np.all(pose_cov_stamped.covariance == 0.0) -def test_pose_with_covariance_stamped_with_timestamp(): +def test_pose_with_covariance_stamped_with_timestamp() -> None: """Test initialization with specific timestamp.""" ts = 1234567890.123456 frame_id = "base_link" @@ -87,7 +84,7 @@ def test_pose_with_covariance_stamped_with_timestamp(): assert pose_cov_stamped.frame_id == frame_id -def test_pose_with_covariance_stamped_with_pose(): +def test_pose_with_covariance_stamped_with_pose() -> None: """Test initialization with pose.""" ts = 1234567890.123456 frame_id = "map" @@ -106,7 +103,7 @@ def test_pose_with_covariance_stamped_with_pose(): assert np.array_equal(pose_cov_stamped.covariance, covariance) -def test_pose_with_covariance_stamped_properties(): +def test_pose_with_covariance_stamped_properties() -> None: """Test convenience properties.""" pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) covariance = np.eye(6).flatten() @@ -136,7 +133,7 @@ def test_pose_with_covariance_stamped_properties(): assert np.trace(cov_matrix) == 6.0 -def test_pose_with_covariance_stamped_str(): +def test_pose_with_covariance_stamped_str() -> None: """Test string representation.""" pose = Pose(1.234, 2.567, 3.891) covariance = np.eye(6).flatten() * 2.0 @@ -153,7 +150,7 @@ def test_pose_with_covariance_stamped_str(): assert "12.000" in str_repr # Trace of 2*identity is 12 -def test_pose_with_covariance_stamped_lcm_encode_decode(): +def test_pose_with_covariance_stamped_lcm_encode_decode() -> None: """Test LCM encoding and decoding.""" ts = 1234567890.123456 frame_id = "camera_link" @@ -184,7 +181,7 @@ def test_pose_with_covariance_stamped_lcm_encode_decode(): @pytest.mark.ros -def test_pose_with_covariance_stamped_from_ros_msg(): +def test_pose_with_covariance_stamped_from_ros_msg() -> None: """Test creating from ROS message.""" ros_msg = ROSPoseWithCovarianceStamped() @@ -217,7 +214,7 @@ def test_pose_with_covariance_stamped_from_ros_msg(): @pytest.mark.ros -def test_pose_with_covariance_stamped_to_ros_msg(): +def test_pose_with_covariance_stamped_to_ros_msg() -> None: """Test converting to ROS message.""" ts = 1234567890.567890 frame_id = "imu" @@ -246,7 +243,7 @@ def test_pose_with_covariance_stamped_to_ros_msg(): @pytest.mark.ros -def test_pose_with_covariance_stamped_ros_roundtrip(): +def test_pose_with_covariance_stamped_ros_roundtrip() -> None: """Test round-trip conversion with ROS messages.""" ts = 2147483647.987654 # Max int32 value for ROS Time.sec frame_id = "robot_base" @@ -275,7 +272,7 @@ def test_pose_with_covariance_stamped_ros_roundtrip(): assert np.allclose(restored.covariance, original.covariance) -def test_pose_with_covariance_stamped_zero_timestamp(): +def test_pose_with_covariance_stamped_zero_timestamp() -> None: """Test that zero timestamp gets replaced with current time.""" pose_cov_stamped = PoseWithCovarianceStamped(ts=0.0) @@ -284,7 +281,7 @@ def test_pose_with_covariance_stamped_zero_timestamp(): assert pose_cov_stamped.ts <= time.time() -def test_pose_with_covariance_stamped_inheritance(): +def test_pose_with_covariance_stamped_inheritance() -> None: """Test that it properly inherits from PoseWithCovariance and Timestamped.""" pose = Pose(1.0, 2.0, 3.0) covariance = np.eye(6).flatten() @@ -304,7 +301,7 @@ def test_pose_with_covariance_stamped_inheritance(): assert hasattr(pose_cov_stamped, "covariance") -def test_pose_with_covariance_stamped_sec_nsec(): +def test_pose_with_covariance_stamped_sec_nsec() -> None: """Test the sec_nsec helper function.""" from dimos.msgs.geometry_msgs.PoseWithCovarianceStamped import sec_nsec @@ -338,7 +335,7 @@ def test_pose_with_covariance_stamped_sec_nsec(): "frame_id", ["", "map", "odom", "base_link", "camera_optical_frame", "sensor/lidar/front"], ) -def test_pose_with_covariance_stamped_frame_ids(frame_id): +def test_pose_with_covariance_stamped_frame_ids(frame_id) -> None: """Test various frame ID values.""" pose_cov_stamped = PoseWithCovarianceStamped(frame_id=frame_id) assert pose_cov_stamped.frame_id == frame_id @@ -351,7 +348,7 @@ def test_pose_with_covariance_stamped_frame_ids(frame_id): assert restored.frame_id == frame_id -def test_pose_with_covariance_stamped_different_covariances(): +def test_pose_with_covariance_stamped_different_covariances() -> None: """Test with different covariance patterns.""" pose = Pose(1.0, 2.0, 3.0) diff --git a/dimos/msgs/geometry_msgs/test_Quaternion.py b/dimos/msgs/geometry_msgs/test_Quaternion.py index 18f9e2c5ab..501f5a0271 100644 --- a/dimos/msgs/geometry_msgs/test_Quaternion.py +++ b/dimos/msgs/geometry_msgs/test_Quaternion.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dimos_lcm.geometry_msgs import Quaternion as LCMQuaternion import numpy as np import pytest -from dimos_lcm.geometry_msgs import Quaternion as LCMQuaternion from dimos.msgs.geometry_msgs.Quaternion import Quaternion -def test_quaternion_default_init(): +def test_quaternion_default_init() -> None: """Test that default initialization creates an identity quaternion (w=1, x=y=z=0).""" q = Quaternion() assert q.x == 0.0 @@ -29,7 +29,7 @@ def test_quaternion_default_init(): assert q.to_tuple() == (0.0, 0.0, 0.0, 1.0) -def test_quaternion_component_init(): +def test_quaternion_component_init() -> None: """Test initialization with four float components (x, y, z, w).""" q = Quaternion(0.5, 0.5, 0.5, 0.5) assert q.x == 0.5 @@ -60,7 +60,7 @@ def test_quaternion_component_init(): assert isinstance(q4.x, float) -def test_quaternion_sequence_init(): +def test_quaternion_sequence_init() -> None: """Test initialization from sequence (list, tuple) of 4 numbers.""" # From list q1 = Quaternion([0.1, 0.2, 0.3, 0.4]) @@ -91,7 +91,7 @@ def test_quaternion_sequence_init(): Quaternion([1, 2, 3, 4, 5]) # Too many components -def test_quaternion_numpy_init(): +def test_quaternion_numpy_init() -> None: """Test initialization from numpy array.""" # From numpy array arr = np.array([0.1, 0.2, 0.3, 0.4]) @@ -117,7 +117,7 @@ def test_quaternion_numpy_init(): Quaternion(np.array([1, 2, 3, 4, 5])) # Too many elements -def test_quaternion_copy_init(): +def test_quaternion_copy_init() -> None: """Test initialization from another Quaternion (copy constructor).""" original = Quaternion(0.1, 0.2, 0.3, 0.4) copy = Quaternion(original) @@ -132,7 +132,7 @@ def test_quaternion_copy_init(): assert copy == original -def test_quaternion_lcm_init(): +def test_quaternion_lcm_init() -> None: """Test initialization from LCM Quaternion.""" lcm_quat = LCMQuaternion() lcm_quat.x = 0.1 @@ -147,7 +147,7 @@ def test_quaternion_lcm_init(): assert q.w == 0.4 -def test_quaternion_properties(): +def test_quaternion_properties() -> None: """Test quaternion component properties.""" q = Quaternion(1.0, 2.0, 3.0, 4.0) @@ -161,7 +161,7 @@ def test_quaternion_properties(): assert q.to_tuple() == (1.0, 2.0, 3.0, 4.0) -def test_quaternion_indexing(): +def test_quaternion_indexing() -> None: """Test quaternion indexing support.""" q = Quaternion(1.0, 2.0, 3.0, 4.0) @@ -172,7 +172,7 @@ def test_quaternion_indexing(): assert q[3] == 4.0 -def test_quaternion_euler(): +def test_quaternion_euler() -> None: """Test quaternion to Euler angles conversion.""" # Test identity quaternion (should give zero angles) @@ -197,7 +197,7 @@ def test_quaternion_euler(): assert np.isclose(angles_x90.z, 0.0, atol=1e-10) # yaw should be 0 -def test_lcm_encode_decode(): +def test_lcm_encode_decode() -> None: """Test encoding and decoding of Quaternion to/from binary LCM format.""" q_source = Quaternion(1.0, 2.0, 3.0, 4.0) @@ -210,7 +210,7 @@ def test_lcm_encode_decode(): assert q_dest == q_source -def test_quaternion_multiplication(): +def test_quaternion_multiplication() -> None: """Test quaternion multiplication (Hamilton product).""" # Test identity multiplication q1 = Quaternion(0.5, 0.5, 0.5, 0.5) @@ -245,7 +245,7 @@ def test_quaternion_multiplication(): assert np.isclose(result.w, np.cos(expected_angle / 2), atol=1e-10) -def test_quaternion_conjugate(): +def test_quaternion_conjugate() -> None: """Test quaternion conjugate.""" q = Quaternion(0.1, 0.2, 0.3, 0.4) conj = q.conjugate() @@ -266,7 +266,7 @@ def test_quaternion_conjugate(): assert np.isclose(result.w, expected_w, atol=1e-10) -def test_quaternion_inverse(): +def test_quaternion_inverse() -> None: """Test quaternion inverse.""" # Test with unit quaternion q_unit = Quaternion(0, 0, 0, 1).normalize() # Already normalized but being explicit @@ -297,7 +297,7 @@ def test_quaternion_inverse(): assert np.isclose(result.w, 1, atol=1e-10) -def test_quaternion_normalize(): +def test_quaternion_normalize() -> None: """Test quaternion normalization.""" # Test non-unit quaternion q = Quaternion(1, 2, 3, 4) @@ -315,7 +315,7 @@ def test_quaternion_normalize(): assert np.isclose(q_norm.w, q.w / scale, atol=1e-10) -def test_quaternion_rotate_vector(): +def test_quaternion_rotate_vector() -> None: """Test rotating vectors with quaternions.""" from dimos.msgs.geometry_msgs.Vector3 import Vector3 @@ -360,7 +360,7 @@ def test_quaternion_rotate_vector(): assert np.isclose(v_rotated.z, v.z, atol=1e-10) -def test_quaternion_inverse_zero(): +def test_quaternion_inverse_zero() -> None: """Test that inverting zero quaternion raises error.""" q_zero = Quaternion(0, 0, 0, 0) @@ -368,7 +368,7 @@ def test_quaternion_inverse_zero(): q_zero.inverse() -def test_quaternion_normalize_zero(): +def test_quaternion_normalize_zero() -> None: """Test that normalizing zero quaternion raises error.""" q_zero = Quaternion(0, 0, 0, 0) @@ -376,7 +376,7 @@ def test_quaternion_normalize_zero(): q_zero.normalize() -def test_quaternion_multiplication_type_error(): +def test_quaternion_multiplication_type_error() -> None: """Test that multiplying quaternion with non-quaternion raises error.""" q = Quaternion(1, 0, 0, 0) diff --git a/dimos/msgs/geometry_msgs/test_Transform.py b/dimos/msgs/geometry_msgs/test_Transform.py index f09f0c2966..b61e92ae01 100644 --- a/dimos/msgs/geometry_msgs/test_Transform.py +++ b/dimos/msgs/geometry_msgs/test_Transform.py @@ -23,13 +23,11 @@ except ImportError: ROSTransformStamped = None -from dimos_lcm.geometry_msgs import Transform as LCMTransform -from dimos_lcm.geometry_msgs import TransformStamped as LCMTransformStamped from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Transform, Vector3 -def test_transform_initialization(): +def test_transform_initialization() -> None: # Test default initialization (identity transform) tf = Transform() assert tf.translation.x == 0.0 @@ -75,7 +73,7 @@ def test_transform_initialization(): assert tf9.rotation == Quaternion(0, 0, 1, 0) -def test_transform_identity(): +def test_transform_identity() -> None: # Test identity class method tf = Transform.identity() assert tf.translation.is_zero() @@ -88,7 +86,7 @@ def test_transform_identity(): assert tf == Transform() -def test_transform_equality(): +def test_transform_equality() -> None: tf1 = Transform(translation=Vector3(1, 2, 3), rotation=Quaternion(0, 0, 0, 1)) tf2 = Transform(translation=Vector3(1, 2, 3), rotation=Quaternion(0, 0, 0, 1)) tf3 = Transform(translation=Vector3(1, 2, 4), rotation=Quaternion(0, 0, 0, 1)) # Different z @@ -102,7 +100,7 @@ def test_transform_equality(): assert tf1 != "not a transform" -def test_transform_string_representations(): +def test_transform_string_representations() -> None: tf = Transform( translation=Vector3(1.5, -2.0, 3.14), rotation=Quaternion(0, 0, 0.707107, 0.707107) ) @@ -121,7 +119,7 @@ def test_transform_string_representations(): assert "Rotation:" in str_str -def test_pose_add_transform(): +def test_pose_add_transform() -> None: initial_pose = Pose(1.0, 0.0, 0.0) # 90 degree rotation around Z axis @@ -168,7 +166,7 @@ def test_pose_add_transform(): print(found_tf.rotation, found_tf.translation) -def test_pose_add_transform_with_rotation(): +def test_pose_add_transform_with_rotation() -> None: # Create a pose at (0, 0, 0) rotated 90 degrees around Z angle = np.pi / 2 initial_pose = Pose(0.0, 0.0, 0.0, 0.0, 0.0, np.sin(angle / 2), np.cos(angle / 2)) @@ -230,7 +228,7 @@ def test_pose_add_transform_with_rotation(): assert np.isclose(transformed_pose2.orientation.w, np.cos(total_angle2 / 2), atol=1e-10) -def test_lcm_encode_decode(): +def test_lcm_encode_decode() -> None: angle = np.pi / 2 transform = Transform( translation=Vector3(2.0, 1.0, 0.0), @@ -244,7 +242,7 @@ def test_lcm_encode_decode(): assert decoded_transform == transform -def test_transform_addition(): +def test_transform_addition() -> None: # Test 1: Simple translation addition (no rotation) t1 = Transform( translation=Vector3(1, 0, 0), @@ -320,7 +318,7 @@ def test_transform_addition(): t1 + "not a transform" -def test_transform_from_pose(): +def test_transform_from_pose() -> None: """Test converting Pose to Transform""" # Create a Pose with position and orientation pose = Pose( @@ -340,7 +338,7 @@ def test_transform_from_pose(): # validating results from example @ # https://foxglove.dev/blog/understanding-ros-transforms -def test_transform_from_ros(): +def test_transform_from_ros() -> None: """Test converting PoseStamped to Transform""" test_time = time.time() pose_stamped = PoseStamped( @@ -370,7 +368,7 @@ def test_transform_from_ros(): assert end_effector_global_pose.translation.y == pytest.approx(0.366, abs=1e-3) -def test_transform_from_pose_stamped(): +def test_transform_from_pose_stamped() -> None: """Test converting PoseStamped to Transform""" # Create a PoseStamped with position, orientation, timestamp and frame test_time = time.time() @@ -392,7 +390,7 @@ def test_transform_from_pose_stamped(): assert transform.child_frame_id == "robot_base" # passed as first argument -def test_transform_from_pose_variants(): +def test_transform_from_pose_variants() -> None: """Test from_pose with different Pose initialization methods""" # Test with Pose created from x,y,z pose1 = Pose(1.0, 2.0, 3.0) @@ -417,7 +415,7 @@ def test_transform_from_pose_variants(): assert transform3.translation.z == 12.0 -def test_transform_from_pose_invalid_type(): +def test_transform_from_pose_invalid_type() -> None: """Test that from_pose raises TypeError for invalid types""" with pytest.raises(TypeError): Transform.from_pose("not a pose") @@ -430,7 +428,7 @@ def test_transform_from_pose_invalid_type(): @pytest.mark.ros -def test_transform_from_ros_transform_stamped(): +def test_transform_from_ros_transform_stamped() -> None: """Test creating a Transform from a ROS TransformStamped message.""" ros_msg = ROSTransformStamped() ros_msg.header.frame_id = "world" @@ -460,7 +458,7 @@ def test_transform_from_ros_transform_stamped(): @pytest.mark.ros -def test_transform_to_ros_transform_stamped(): +def test_transform_to_ros_transform_stamped() -> None: """Test converting a Transform to a ROS TransformStamped message.""" transform = Transform( translation=Vector3(4.0, 5.0, 6.0), @@ -487,7 +485,7 @@ def test_transform_to_ros_transform_stamped(): @pytest.mark.ros -def test_transform_ros_roundtrip(): +def test_transform_ros_roundtrip() -> None: """Test round-trip conversion between Transform and ROS TransformStamped.""" original = Transform( translation=Vector3(7.5, 8.5, 9.5), diff --git a/dimos/msgs/geometry_msgs/test_Twist.py b/dimos/msgs/geometry_msgs/test_Twist.py index 5f463d0bac..49631a5372 100644 --- a/dimos/msgs/geometry_msgs/test_Twist.py +++ b/dimos/msgs/geometry_msgs/test_Twist.py @@ -16,8 +16,7 @@ import pytest try: - from geometry_msgs.msg import Twist as ROSTwist - from geometry_msgs.msg import Vector3 as ROSVector3 + from geometry_msgs.msg import Twist as ROSTwist, Vector3 as ROSVector3 except ImportError: ROSTwist = None ROSVector3 = None @@ -27,7 +26,7 @@ from dimos.msgs.geometry_msgs import Quaternion, Twist, Vector3 -def test_twist_initialization(): +def test_twist_initialization() -> None: # Test default initialization (zero twist) tw = Twist() assert tw.linear.x == 0.0 @@ -104,7 +103,7 @@ def test_twist_initialization(): assert tw11.angular.is_zero() # Identity quaternion -> zero euler angles -def test_twist_zero(): +def test_twist_zero() -> None: # Test zero class method tw = Twist.zero() assert tw.linear.is_zero() @@ -115,7 +114,7 @@ def test_twist_zero(): assert tw == Twist() -def test_twist_equality(): +def test_twist_equality() -> None: tw1 = Twist(Vector3(1, 2, 3), Vector3(0.1, 0.2, 0.3)) tw2 = Twist(Vector3(1, 2, 3), Vector3(0.1, 0.2, 0.3)) tw3 = Twist(Vector3(1, 2, 4), Vector3(0.1, 0.2, 0.3)) # Different linear z @@ -127,7 +126,7 @@ def test_twist_equality(): assert tw1 != "not a twist" -def test_twist_string_representations(): +def test_twist_string_representations() -> None: tw = Twist(Vector3(1.5, -2.0, 3.14), Vector3(0.1, -0.2, 0.3)) # Test repr @@ -145,7 +144,7 @@ def test_twist_string_representations(): assert "Angular:" in str_str -def test_twist_is_zero(): +def test_twist_is_zero() -> None: # Test zero twist tw1 = Twist() assert tw1.is_zero() @@ -163,7 +162,7 @@ def test_twist_is_zero(): assert not tw4.is_zero() -def test_twist_bool(): +def test_twist_bool() -> None: # Test zero twist is False tw1 = Twist() assert not tw1 @@ -179,7 +178,7 @@ def test_twist_bool(): assert tw4 -def test_twist_lcm_encoding(): +def test_twist_lcm_encoding() -> None: # Test encoding and decoding tw = Twist(Vector3(1.5, 2.5, 3.5), Vector3(0.1, 0.2, 0.3)) @@ -196,7 +195,7 @@ def test_twist_lcm_encoding(): assert decoded == tw -def test_twist_with_lists(): +def test_twist_with_lists() -> None: # Test initialization with lists instead of Vector3 tw1 = Twist(linear=[1, 2, 3], angular=[0.1, 0.2, 0.3]) assert tw1.linear == Vector3(1, 2, 3) @@ -209,7 +208,7 @@ def test_twist_with_lists(): @pytest.mark.ros -def test_twist_from_ros_msg(): +def test_twist_from_ros_msg() -> None: """Test Twist.from_ros_msg conversion.""" # Create ROS message ros_msg = ROSTwist() @@ -229,7 +228,7 @@ def test_twist_from_ros_msg(): @pytest.mark.ros -def test_twist_to_ros_msg(): +def test_twist_to_ros_msg() -> None: """Test Twist.to_ros_msg conversion.""" # Create LCM message lcm_msg = Twist(linear=Vector3(40.0, 50.0, 60.0), angular=Vector3(4.0, 5.0, 6.0)) @@ -247,7 +246,7 @@ def test_twist_to_ros_msg(): @pytest.mark.ros -def test_ros_zero_twist_conversion(): +def test_ros_zero_twist_conversion() -> None: """Test conversion of zero twist messages between ROS and LCM.""" # Test ROS to LCM with zero twist ros_zero = ROSTwist() @@ -266,7 +265,7 @@ def test_ros_zero_twist_conversion(): @pytest.mark.ros -def test_ros_negative_values_conversion(): +def test_ros_negative_values_conversion() -> None: """Test ROS conversion with negative values.""" # Create ROS message with negative values ros_msg = ROSTwist() @@ -286,7 +285,7 @@ def test_ros_negative_values_conversion(): @pytest.mark.ros -def test_ros_roundtrip_conversion(): +def test_ros_roundtrip_conversion() -> None: """Test round-trip conversion maintains data integrity.""" # LCM -> ROS -> LCM original_lcm = Twist(linear=Vector3(1.234, 5.678, 9.012), angular=Vector3(0.111, 0.222, 0.333)) diff --git a/dimos/msgs/geometry_msgs/test_TwistStamped.py b/dimos/msgs/geometry_msgs/test_TwistStamped.py index 8414d4480a..385523a284 100644 --- a/dimos/msgs/geometry_msgs/test_TwistStamped.py +++ b/dimos/msgs/geometry_msgs/test_TwistStamped.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import pickle import time +import pytest try: from geometry_msgs.msg import TwistStamped as ROSTwistStamped @@ -25,7 +25,7 @@ from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped -def test_lcm_encode_decode(): +def test_lcm_encode_decode() -> None: """Test encoding and decoding of TwistStamped to/from binary LCM format.""" twist_source = TwistStamped( ts=time.time(), @@ -46,7 +46,7 @@ def test_lcm_encode_decode(): assert twist_dest == twist_source -def test_pickle_encode_decode(): +def test_pickle_encode_decode() -> None: """Test encoding and decoding of TwistStamped to/from binary pickle format.""" twist_source = TwistStamped( @@ -62,7 +62,7 @@ def test_pickle_encode_decode(): @pytest.mark.ros -def test_twist_stamped_from_ros_msg(): +def test_twist_stamped_from_ros_msg() -> None: """Test creating a TwistStamped from a ROS TwistStamped message.""" ros_msg = ROSTwistStamped() ros_msg.header.frame_id = "world" @@ -88,7 +88,7 @@ def test_twist_stamped_from_ros_msg(): @pytest.mark.ros -def test_twist_stamped_to_ros_msg(): +def test_twist_stamped_to_ros_msg() -> None: """Test converting a TwistStamped to a ROS TwistStamped message.""" twist_stamped = TwistStamped( ts=123.456, @@ -112,7 +112,7 @@ def test_twist_stamped_to_ros_msg(): @pytest.mark.ros -def test_twist_stamped_ros_roundtrip(): +def test_twist_stamped_ros_roundtrip() -> None: """Test round-trip conversion between TwistStamped and ROS TwistStamped.""" original = TwistStamped( ts=123.789, diff --git a/dimos/msgs/geometry_msgs/test_TwistWithCovariance.py b/dimos/msgs/geometry_msgs/test_TwistWithCovariance.py index d001482062..19b992baf4 100644 --- a/dimos/msgs/geometry_msgs/test_TwistWithCovariance.py +++ b/dimos/msgs/geometry_msgs/test_TwistWithCovariance.py @@ -16,9 +16,11 @@ import pytest try: - from geometry_msgs.msg import TwistWithCovariance as ROSTwistWithCovariance - from geometry_msgs.msg import Twist as ROSTwist - from geometry_msgs.msg import Vector3 as ROSVector3 + from geometry_msgs.msg import ( + Twist as ROSTwist, + TwistWithCovariance as ROSTwistWithCovariance, + Vector3 as ROSVector3, + ) except ImportError: ROSTwist = None ROSTwistWithCovariance = None @@ -31,7 +33,7 @@ from dimos.msgs.geometry_msgs.Vector3 import Vector3 -def test_twist_with_covariance_default_init(): +def test_twist_with_covariance_default_init() -> None: """Test that default initialization creates a zero twist with zero covariance.""" if ROSVector3 is None: pytest.skip("ROS not available") @@ -52,7 +54,7 @@ def test_twist_with_covariance_default_init(): assert twist_cov.covariance.shape == (36,) -def test_twist_with_covariance_twist_init(): +def test_twist_with_covariance_twist_init() -> None: """Test initialization with a Twist object.""" linear = Vector3(1.0, 2.0, 3.0) angular = Vector3(0.1, 0.2, 0.3) @@ -71,7 +73,7 @@ def test_twist_with_covariance_twist_init(): assert np.all(twist_cov.covariance == 0.0) -def test_twist_with_covariance_twist_and_covariance_init(): +def test_twist_with_covariance_twist_and_covariance_init() -> None: """Test initialization with twist and covariance.""" twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) covariance = np.arange(36, dtype=float) @@ -86,7 +88,7 @@ def test_twist_with_covariance_twist_and_covariance_init(): assert np.array_equal(twist_cov.covariance, covariance) -def test_twist_with_covariance_tuple_init(): +def test_twist_with_covariance_tuple_init() -> None: """Test initialization with tuple of (linear, angular) velocities.""" linear = [1.0, 2.0, 3.0] angular = [0.1, 0.2, 0.3] @@ -105,7 +107,7 @@ def test_twist_with_covariance_tuple_init(): assert np.array_equal(twist_cov.covariance, covariance) -def test_twist_with_covariance_list_covariance(): +def test_twist_with_covariance_list_covariance() -> None: """Test initialization with covariance as a list.""" twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) covariance_list = list(range(36)) @@ -116,7 +118,7 @@ def test_twist_with_covariance_list_covariance(): assert np.array_equal(twist_cov.covariance, np.array(covariance_list)) -def test_twist_with_covariance_copy_init(): +def test_twist_with_covariance_copy_init() -> None: """Test copy constructor.""" twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) covariance = np.arange(36, dtype=float) @@ -134,7 +136,7 @@ def test_twist_with_covariance_copy_init(): assert copy.covariance[0] != 999.0 -def test_twist_with_covariance_lcm_init(): +def test_twist_with_covariance_lcm_init() -> None: """Test initialization from LCM message.""" lcm_msg = LCMTwistWithCovariance() lcm_msg.twist.linear.x = 1.0 @@ -159,7 +161,7 @@ def test_twist_with_covariance_lcm_init(): assert np.array_equal(twist_cov.covariance, np.arange(36)) -def test_twist_with_covariance_dict_init(): +def test_twist_with_covariance_dict_init() -> None: """Test initialization from dictionary.""" twist_dict = { "twist": Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)), @@ -173,7 +175,7 @@ def test_twist_with_covariance_dict_init(): assert np.array_equal(twist_cov.covariance, np.arange(36)) -def test_twist_with_covariance_dict_init_no_covariance(): +def test_twist_with_covariance_dict_init_no_covariance() -> None: """Test initialization from dictionary without covariance.""" twist_dict = {"twist": Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3))} twist_cov = TwistWithCovariance(twist_dict) @@ -182,7 +184,7 @@ def test_twist_with_covariance_dict_init_no_covariance(): assert np.all(twist_cov.covariance == 0.0) -def test_twist_with_covariance_tuple_of_tuple_init(): +def test_twist_with_covariance_tuple_of_tuple_init() -> None: """Test initialization from tuple of (twist_tuple, covariance).""" twist_tuple = ([1.0, 2.0, 3.0], [0.1, 0.2, 0.3]) covariance = np.arange(36, dtype=float) @@ -197,7 +199,7 @@ def test_twist_with_covariance_tuple_of_tuple_init(): assert np.array_equal(twist_cov.covariance, covariance) -def test_twist_with_covariance_properties(): +def test_twist_with_covariance_properties() -> None: """Test convenience properties.""" twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) twist_cov = TwistWithCovariance(twist) @@ -211,7 +213,7 @@ def test_twist_with_covariance_properties(): assert twist_cov.angular.z == 0.3 -def test_twist_with_covariance_matrix_property(): +def test_twist_with_covariance_matrix_property() -> None: """Test covariance matrix property.""" twist = Twist() covariance_array = np.arange(36, dtype=float) @@ -229,7 +231,7 @@ def test_twist_with_covariance_matrix_property(): assert np.array_equal(twist_cov.covariance[:6], [2.0, 0.0, 0.0, 0.0, 0.0, 0.0]) -def test_twist_with_covariance_repr(): +def test_twist_with_covariance_repr() -> None: """Test string representation.""" twist = Twist(Vector3(1.234, 2.567, 3.891), Vector3(0.1, 0.2, 0.3)) twist_cov = TwistWithCovariance(twist) @@ -241,7 +243,7 @@ def test_twist_with_covariance_repr(): assert "36 elements" in repr_str -def test_twist_with_covariance_str(): +def test_twist_with_covariance_str() -> None: """Test string formatting.""" twist = Twist(Vector3(1.234, 2.567, 3.891), Vector3(0.1, 0.2, 0.3)) covariance = np.eye(6).flatten() @@ -256,7 +258,7 @@ def test_twist_with_covariance_str(): assert "6.000" in str_repr # Trace of identity matrix is 6 -def test_twist_with_covariance_equality(): +def test_twist_with_covariance_equality() -> None: """Test equality comparison.""" twist1 = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) cov1 = np.arange(36, dtype=float) @@ -281,10 +283,10 @@ def test_twist_with_covariance_equality(): # Different type assert twist_cov1 != "not a twist" - assert twist_cov1 != None + assert twist_cov1 is not None -def test_twist_with_covariance_is_zero(): +def test_twist_with_covariance_is_zero() -> None: """Test is_zero method.""" # Zero twist twist_cov1 = TwistWithCovariance() @@ -298,7 +300,7 @@ def test_twist_with_covariance_is_zero(): assert twist_cov2 # Boolean conversion -def test_twist_with_covariance_lcm_encode_decode(): +def test_twist_with_covariance_lcm_encode_decode() -> None: """Test LCM encoding and decoding.""" twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) covariance = np.arange(36, dtype=float) @@ -316,7 +318,7 @@ def test_twist_with_covariance_lcm_encode_decode(): @pytest.mark.ros -def test_twist_with_covariance_from_ros_msg(): +def test_twist_with_covariance_from_ros_msg() -> None: """Test creating from ROS message.""" ros_msg = ROSTwistWithCovariance() ros_msg.twist.linear = ROSVector3(x=1.0, y=2.0, z=3.0) @@ -335,7 +337,7 @@ def test_twist_with_covariance_from_ros_msg(): @pytest.mark.ros -def test_twist_with_covariance_to_ros_msg(): +def test_twist_with_covariance_to_ros_msg() -> None: """Test converting to ROS message.""" twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) covariance = np.arange(36, dtype=float) @@ -354,7 +356,7 @@ def test_twist_with_covariance_to_ros_msg(): @pytest.mark.ros -def test_twist_with_covariance_ros_roundtrip(): +def test_twist_with_covariance_ros_roundtrip() -> None: """Test round-trip conversion with ROS messages.""" twist = Twist(Vector3(1.5, 2.5, 3.5), Vector3(0.15, 0.25, 0.35)) covariance = np.random.rand(36) @@ -366,7 +368,7 @@ def test_twist_with_covariance_ros_roundtrip(): assert restored == original -def test_twist_with_covariance_zero_covariance(): +def test_twist_with_covariance_zero_covariance() -> None: """Test with zero covariance matrix.""" twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) twist_cov = TwistWithCovariance(twist) @@ -375,7 +377,7 @@ def test_twist_with_covariance_zero_covariance(): assert np.trace(twist_cov.covariance_matrix) == 0.0 -def test_twist_with_covariance_diagonal_covariance(): +def test_twist_with_covariance_diagonal_covariance() -> None: """Test with diagonal covariance matrix.""" twist = Twist() covariance = np.zeros(36) @@ -408,7 +410,7 @@ def test_twist_with_covariance_diagonal_covariance(): ([100.0, -100.0, 0.0], [3.14, -3.14, 0.0]), ], ) -def test_twist_with_covariance_parametrized_velocities(linear, angular): +def test_twist_with_covariance_parametrized_velocities(linear, angular) -> None: """Parametrized test for various velocity values.""" twist = Twist(linear, angular) twist_cov = TwistWithCovariance(twist) diff --git a/dimos/msgs/geometry_msgs/test_TwistWithCovarianceStamped.py b/dimos/msgs/geometry_msgs/test_TwistWithCovarianceStamped.py index 4174814c78..93c7a7b23f 100644 --- a/dimos/msgs/geometry_msgs/test_TwistWithCovarianceStamped.py +++ b/dimos/msgs/geometry_msgs/test_TwistWithCovarianceStamped.py @@ -18,12 +18,14 @@ import pytest try: - from geometry_msgs.msg import TwistWithCovarianceStamped as ROSTwistWithCovarianceStamped - from geometry_msgs.msg import TwistWithCovariance as ROSTwistWithCovariance - from geometry_msgs.msg import Twist as ROSTwist - from geometry_msgs.msg import Vector3 as ROSVector3 - from std_msgs.msg import Header as ROSHeader from builtin_interfaces.msg import Time as ROSTime + from geometry_msgs.msg import ( + Twist as ROSTwist, + TwistWithCovariance as ROSTwistWithCovariance, + TwistWithCovarianceStamped as ROSTwistWithCovarianceStamped, + Vector3 as ROSVector3, + ) + from std_msgs.msg import Header as ROSHeader except ImportError: ROSTwistWithCovarianceStamped = None ROSTwist = None @@ -32,9 +34,6 @@ ROSTwistWithCovariance = None ROSVector3 = None -from dimos_lcm.geometry_msgs import TwistWithCovarianceStamped as LCMTwistWithCovarianceStamped -from dimos_lcm.std_msgs import Header as LCMHeader -from dimos_lcm.std_msgs import Time as LCMTime from dimos.msgs.geometry_msgs.Twist import Twist from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance @@ -42,7 +41,7 @@ from dimos.msgs.geometry_msgs.Vector3 import Vector3 -def test_twist_with_covariance_stamped_default_init(): +def test_twist_with_covariance_stamped_default_init() -> None: """Test default initialization.""" if ROSVector3 is None: pytest.skip("ROS not available") @@ -74,7 +73,7 @@ def test_twist_with_covariance_stamped_default_init(): assert np.all(twist_cov_stamped.covariance == 0.0) -def test_twist_with_covariance_stamped_with_timestamp(): +def test_twist_with_covariance_stamped_with_timestamp() -> None: """Test initialization with specific timestamp.""" ts = 1234567890.123456 frame_id = "base_link" @@ -84,7 +83,7 @@ def test_twist_with_covariance_stamped_with_timestamp(): assert twist_cov_stamped.frame_id == frame_id -def test_twist_with_covariance_stamped_with_twist(): +def test_twist_with_covariance_stamped_with_twist() -> None: """Test initialization with twist.""" ts = 1234567890.123456 frame_id = "odom" @@ -103,7 +102,7 @@ def test_twist_with_covariance_stamped_with_twist(): assert np.array_equal(twist_cov_stamped.covariance, covariance) -def test_twist_with_covariance_stamped_with_tuple(): +def test_twist_with_covariance_stamped_with_tuple() -> None: """Test initialization with tuple of velocities.""" ts = 1234567890.123456 frame_id = "robot_base" @@ -122,7 +121,7 @@ def test_twist_with_covariance_stamped_with_tuple(): assert np.array_equal(twist_cov_stamped.covariance, covariance) -def test_twist_with_covariance_stamped_properties(): +def test_twist_with_covariance_stamped_properties() -> None: """Test convenience properties.""" twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) covariance = np.eye(6).flatten() @@ -144,7 +143,7 @@ def test_twist_with_covariance_stamped_properties(): assert np.trace(cov_matrix) == 6.0 -def test_twist_with_covariance_stamped_str(): +def test_twist_with_covariance_stamped_str() -> None: """Test string representation.""" twist = Twist(Vector3(1.234, 2.567, 3.891), Vector3(0.111, 0.222, 0.333)) covariance = np.eye(6).flatten() * 2.0 @@ -161,7 +160,7 @@ def test_twist_with_covariance_stamped_str(): assert "12.000" in str_repr # Trace of 2*identity is 12 -def test_twist_with_covariance_stamped_lcm_encode_decode(): +def test_twist_with_covariance_stamped_lcm_encode_decode() -> None: """Test LCM encoding and decoding.""" ts = 1234567890.123456 frame_id = "camera_link" @@ -193,7 +192,7 @@ def test_twist_with_covariance_stamped_lcm_encode_decode(): @pytest.mark.ros -def test_twist_with_covariance_stamped_from_ros_msg(): +def test_twist_with_covariance_stamped_from_ros_msg() -> None: """Test creating from ROS message.""" ros_msg = ROSTwistWithCovarianceStamped() @@ -225,7 +224,7 @@ def test_twist_with_covariance_stamped_from_ros_msg(): @pytest.mark.ros -def test_twist_with_covariance_stamped_to_ros_msg(): +def test_twist_with_covariance_stamped_to_ros_msg() -> None: """Test converting to ROS message.""" ts = 1234567890.567890 frame_id = "imu" @@ -253,7 +252,7 @@ def test_twist_with_covariance_stamped_to_ros_msg(): @pytest.mark.ros -def test_twist_with_covariance_stamped_ros_roundtrip(): +def test_twist_with_covariance_stamped_ros_roundtrip() -> None: """Test round-trip conversion with ROS messages.""" ts = 2147483647.987654 # Max int32 value for ROS Time.sec frame_id = "robot_base" @@ -283,7 +282,7 @@ def test_twist_with_covariance_stamped_ros_roundtrip(): assert np.allclose(restored.covariance, original.covariance) -def test_twist_with_covariance_stamped_zero_timestamp(): +def test_twist_with_covariance_stamped_zero_timestamp() -> None: """Test that zero timestamp gets replaced with current time.""" twist_cov_stamped = TwistWithCovarianceStamped(ts=0.0) @@ -292,7 +291,7 @@ def test_twist_with_covariance_stamped_zero_timestamp(): assert twist_cov_stamped.ts <= time.time() -def test_twist_with_covariance_stamped_inheritance(): +def test_twist_with_covariance_stamped_inheritance() -> None: """Test that it properly inherits from TwistWithCovariance and Timestamped.""" twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) covariance = np.eye(6).flatten() @@ -312,7 +311,7 @@ def test_twist_with_covariance_stamped_inheritance(): assert hasattr(twist_cov_stamped, "covariance") -def test_twist_with_covariance_stamped_is_zero(): +def test_twist_with_covariance_stamped_is_zero() -> None: """Test is_zero method inheritance.""" # Zero twist twist_cov_stamped1 = TwistWithCovarianceStamped() @@ -326,7 +325,7 @@ def test_twist_with_covariance_stamped_is_zero(): assert twist_cov_stamped2 # Boolean conversion -def test_twist_with_covariance_stamped_sec_nsec(): +def test_twist_with_covariance_stamped_sec_nsec() -> None: """Test the sec_nsec helper function.""" from dimos.msgs.geometry_msgs.TwistWithCovarianceStamped import sec_nsec @@ -360,7 +359,7 @@ def test_twist_with_covariance_stamped_sec_nsec(): "frame_id", ["", "map", "odom", "base_link", "cmd_vel", "sensor/velocity/front"], ) -def test_twist_with_covariance_stamped_frame_ids(frame_id): +def test_twist_with_covariance_stamped_frame_ids(frame_id) -> None: """Test various frame ID values.""" twist_cov_stamped = TwistWithCovarianceStamped(frame_id=frame_id) assert twist_cov_stamped.frame_id == frame_id @@ -373,7 +372,7 @@ def test_twist_with_covariance_stamped_frame_ids(frame_id): assert restored.frame_id == frame_id -def test_twist_with_covariance_stamped_different_covariances(): +def test_twist_with_covariance_stamped_different_covariances() -> None: """Test with different covariance patterns.""" twist = Twist(Vector3(1.0, 0.0, 0.0), Vector3(0.0, 0.0, 0.5)) diff --git a/dimos/msgs/geometry_msgs/test_Vector3.py b/dimos/msgs/geometry_msgs/test_Vector3.py index 81325286f9..7ad4e67f16 100644 --- a/dimos/msgs/geometry_msgs/test_Vector3.py +++ b/dimos/msgs/geometry_msgs/test_Vector3.py @@ -18,7 +18,7 @@ from dimos.msgs.geometry_msgs.Vector3 import Vector3 -def test_vector_default_init(): +def test_vector_default_init() -> None: """Test that default initialization of Vector() has x,y,z components all zero.""" v = Vector3() assert v.x == 0.0 @@ -26,10 +26,10 @@ def test_vector_default_init(): assert v.z == 0.0 assert len(v.data) == 3 assert v.to_list() == [0.0, 0.0, 0.0] - assert v.is_zero() == True # Zero vector should be considered zero + assert v.is_zero() # Zero vector should be considered zero -def test_vector_specific_init(): +def test_vector_specific_init() -> None: """Test initialization with specific values and different input types.""" v1 = Vector3(1.0, 2.0) # 2D vector (now becomes 3D with z=0) @@ -67,7 +67,7 @@ def test_vector_specific_init(): assert v6 == original -def test_vector_addition(): +def test_vector_addition() -> None: """Test vector addition.""" v1 = Vector3(1.0, 2.0, 3.0) v2 = Vector3(4.0, 5.0, 6.0) @@ -78,7 +78,7 @@ def test_vector_addition(): assert v_add.z == 9.0 -def test_vector_subtraction(): +def test_vector_subtraction() -> None: """Test vector subtraction.""" v1 = Vector3(1.0, 2.0, 3.0) v2 = Vector3(4.0, 5.0, 6.0) @@ -89,7 +89,7 @@ def test_vector_subtraction(): assert v_sub.z == 3.0 -def test_vector_scalar_multiplication(): +def test_vector_scalar_multiplication() -> None: """Test vector multiplication by a scalar.""" v1 = Vector3(1.0, 2.0, 3.0) @@ -105,7 +105,7 @@ def test_vector_scalar_multiplication(): assert v_rmul.z == 6.0 -def test_vector_scalar_division(): +def test_vector_scalar_division() -> None: """Test vector division by a scalar.""" v2 = Vector3(4.0, 5.0, 6.0) @@ -115,7 +115,7 @@ def test_vector_scalar_division(): assert v_div.z == 3.0 -def test_vector_dot_product(): +def test_vector_dot_product() -> None: """Test vector dot product.""" v1 = Vector3(1.0, 2.0, 3.0) v2 = Vector3(4.0, 5.0, 6.0) @@ -124,7 +124,7 @@ def test_vector_dot_product(): assert dot == 32.0 -def test_vector_length(): +def test_vector_length() -> None: """Test vector length calculation.""" # 2D vector with length 5 (now 3D with z=0) v1 = Vector3(3.0, 4.0) @@ -139,10 +139,10 @@ def test_vector_length(): assert v2.length_squared() == 49.0 -def test_vector_normalize(): +def test_vector_normalize() -> None: """Test vector normalization.""" v = Vector3(2.0, 3.0, 6.0) - assert v.is_zero() == False + assert not v.is_zero() v_norm = v.normalize() length = v.length() @@ -154,19 +154,19 @@ def test_vector_normalize(): assert np.isclose(v_norm.y, expected_y) assert np.isclose(v_norm.z, expected_z) assert np.isclose(v_norm.length(), 1.0) - assert v_norm.is_zero() == False + assert not v_norm.is_zero() # Test normalizing a zero vector v_zero = Vector3(0.0, 0.0, 0.0) - assert v_zero.is_zero() == True + assert v_zero.is_zero() v_zero_norm = v_zero.normalize() assert v_zero_norm.x == 0.0 assert v_zero_norm.y == 0.0 assert v_zero_norm.z == 0.0 - assert v_zero_norm.is_zero() == True + assert v_zero_norm.is_zero() -def test_vector_to_2d(): +def test_vector_to_2d() -> None: """Test conversion to 2D vector.""" v = Vector3(2.0, 3.0, 6.0) @@ -183,7 +183,7 @@ def test_vector_to_2d(): assert v2_2d.z == 0.0 -def test_vector_distance(): +def test_vector_distance() -> None: """Test distance calculations between vectors.""" v1 = Vector3(1.0, 2.0, 3.0) v2 = Vector3(4.0, 6.0, 8.0) @@ -198,7 +198,7 @@ def test_vector_distance(): assert dist_sq == 50.0 # 9 + 16 + 25 -def test_vector_cross_product(): +def test_vector_cross_product() -> None: """Test vector cross product.""" v1 = Vector3(1.0, 0.0, 0.0) # Unit x vector v2 = Vector3(0.0, 1.0, 0.0) # Unit y vector @@ -230,17 +230,17 @@ def test_vector_cross_product(): assert cross_2d.z == -2.0 -def test_vector_zeros(): +def test_vector_zeros() -> None: """Test Vector3.zeros class method.""" # 3D zero vector v_zeros = Vector3.zeros() assert v_zeros.x == 0.0 assert v_zeros.y == 0.0 assert v_zeros.z == 0.0 - assert v_zeros.is_zero() == True + assert v_zeros.is_zero() -def test_vector_ones(): +def test_vector_ones() -> None: """Test Vector3.ones class method.""" # 3D ones vector v_ones = Vector3.ones() @@ -249,7 +249,7 @@ def test_vector_ones(): assert v_ones.z == 1.0 -def test_vector_conversion_methods(): +def test_vector_conversion_methods() -> None: """Test vector conversion methods (to_list, to_tuple, to_numpy).""" v = Vector3(1.0, 2.0, 3.0) @@ -265,7 +265,7 @@ def test_vector_conversion_methods(): assert np.array_equal(np_array, np.array([1.0, 2.0, 3.0])) -def test_vector_equality(): +def test_vector_equality() -> None: """Test vector equality.""" v1 = Vector3(1, 2, 3) v2 = Vector3(1, 2, 3) @@ -278,75 +278,75 @@ def test_vector_equality(): assert v1 != [1, 2, 3] -def test_vector_is_zero(): +def test_vector_is_zero() -> None: """Test is_zero method for vectors.""" # Default zero vector v0 = Vector3() - assert v0.is_zero() == True + assert v0.is_zero() # Explicit zero vector v1 = Vector3(0.0, 0.0, 0.0) - assert v1.is_zero() == True + assert v1.is_zero() # Zero vector with different initialization (now always 3D) v2 = Vector3(0.0, 0.0) # Becomes (0, 0, 0) - assert v2.is_zero() == True + assert v2.is_zero() # Non-zero vectors v3 = Vector3(1.0, 0.0, 0.0) - assert v3.is_zero() == False + assert not v3.is_zero() v4 = Vector3(0.0, 2.0, 0.0) - assert v4.is_zero() == False + assert not v4.is_zero() v5 = Vector3(0.0, 0.0, 3.0) - assert v5.is_zero() == False + assert not v5.is_zero() # Almost zero (within tolerance) v6 = Vector3(1e-10, 1e-10, 1e-10) - assert v6.is_zero() == True + assert v6.is_zero() # Almost zero (outside tolerance) v7 = Vector3(1e-6, 1e-6, 1e-6) - assert v7.is_zero() == False + assert not v7.is_zero() def test_vector_bool_conversion(): """Test boolean conversion of vectors.""" # Zero vectors should be False v0 = Vector3() - assert bool(v0) == False + assert not bool(v0) v1 = Vector3(0.0, 0.0, 0.0) - assert bool(v1) == False + assert not bool(v1) # Almost zero vectors should be False v2 = Vector3(1e-10, 1e-10, 1e-10) - assert bool(v2) == False + assert not bool(v2) # Non-zero vectors should be True v3 = Vector3(1.0, 0.0, 0.0) - assert bool(v3) == True + assert bool(v3) v4 = Vector3(0.0, 2.0, 0.0) - assert bool(v4) == True + assert bool(v4) v5 = Vector3(0.0, 0.0, 3.0) - assert bool(v5) == True + assert bool(v5) # Direct use in if statements if v0: - assert False, "Zero vector should be False in boolean context" + raise AssertionError("Zero vector should be False in boolean context") else: pass # Expected path if v3: pass # Expected path else: - assert False, "Non-zero vector should be True in boolean context" + raise AssertionError("Non-zero vector should be True in boolean context") -def test_vector_add(): +def test_vector_add() -> None: """Test vector addition operator.""" v1 = Vector3(1.0, 2.0, 3.0) v2 = Vector3(4.0, 5.0, 6.0) @@ -368,7 +368,7 @@ def test_vector_add(): assert (v1 + v_zero) == v1 -def test_vector_add_dim_mismatch(): +def test_vector_add_dim_mismatch() -> None: """Test vector addition with different input dimensions (now all vectors are 3D).""" v1 = Vector3(1.0, 2.0) # Becomes (1, 2, 0) v2 = Vector3(4.0, 5.0, 6.0) # (4, 5, 6) @@ -380,7 +380,7 @@ def test_vector_add_dim_mismatch(): assert v_add_op.z == 6.0 # 0 + 6 -def test_yaw_pitch_roll_accessors(): +def test_yaw_pitch_roll_accessors() -> None: """Test yaw, pitch, and roll accessor properties.""" # Test with a 3D vector v = Vector3(1.0, 2.0, 3.0) @@ -412,7 +412,7 @@ def test_yaw_pitch_roll_accessors(): assert v_neg.yaw == -3.5 -def test_vector_to_quaternion(): +def test_vector_to_quaternion() -> None: """Test vector to quaternion conversion.""" # Test with zero Euler angles (should produce identity quaternion) v_zero = Vector3(0.0, 0.0, 0.0) @@ -450,7 +450,7 @@ def test_vector_to_quaternion(): assert np.isclose(q_x_90.w, expected, atol=1e-10) -def test_lcm_encode_decode(): +def test_lcm_encode_decode() -> None: v_source = Vector3(1.0, 2.0, 3.0) binary_msg = v_source.lcm_encode() diff --git a/dimos/msgs/geometry_msgs/test_publish.py b/dimos/msgs/geometry_msgs/test_publish.py index 4e364dc19a..50578346ae 100644 --- a/dimos/msgs/geometry_msgs/test_publish.py +++ b/dimos/msgs/geometry_msgs/test_publish.py @@ -21,7 +21,7 @@ @pytest.mark.tool -def test_runpublish(): +def test_runpublish() -> None: for i in range(10): msg = Vector3(-5 + i, -5 + i, i) lc = lcm.LCM() @@ -31,16 +31,16 @@ def test_runpublish(): @pytest.mark.tool -def test_receive(): +def test_receive() -> None: lc = lcm.LCM() - def receive(bla, msg): + def receive(bla, msg) -> None: # print("receive", bla, msg) print(Vector3.decode(msg)) lc.subscribe("thing1_vector3#geometry_msgs.Vector3", receive) - def _loop(): + def _loop() -> None: while True: """LCM message handling loop""" try: diff --git a/dimos/msgs/nav_msgs/OccupancyGrid.py b/dimos/msgs/nav_msgs/OccupancyGrid.py index 4bb7495e86..3e144de74f 100644 --- a/dimos/msgs/nav_msgs/OccupancyGrid.py +++ b/dimos/msgs/nav_msgs/OccupancyGrid.py @@ -14,14 +14,13 @@ from __future__ import annotations -import time from enum import IntEnum -from typing import TYPE_CHECKING, BinaryIO, Optional +import time +from typing import TYPE_CHECKING, BinaryIO -import numpy as np -from dimos_lcm.nav_msgs import MapMetaData -from dimos_lcm.nav_msgs import OccupancyGrid as LCMOccupancyGrid +from dimos_lcm.nav_msgs import MapMetaData, OccupancyGrid as LCMOccupancyGrid from dimos_lcm.std_msgs import Time as LCMTime +import numpy as np from scipy import ndimage from dimos.msgs.geometry_msgs import Pose, Vector3, VectorLike @@ -61,11 +60,11 @@ class OccupancyGrid(Timestamped): def __init__( self, - grid: Optional[np.ndarray] = None, - width: Optional[int] = None, - height: Optional[int] = None, + grid: np.ndarray | None = None, + width: int | None = None, + height: int | None = None, resolution: float = 0.05, - origin: Optional[Pose] = None, + origin: Pose | None = None, frame_id: str = "world", ts: float = 0.0, ) -> None: @@ -173,7 +172,7 @@ def unknown_percent(self) -> float: """Percentage of cells that are unknown.""" return (self.unknown_cells / self.total_cells * 100) if self.total_cells > 0 else 0.0 - def inflate(self, radius: float) -> "OccupancyGrid": + def inflate(self, radius: float) -> OccupancyGrid: """Inflate obstacles by a given radius (binary inflation). Args: radius: Inflation radius in meters @@ -187,7 +186,7 @@ def inflate(self, radius: float) -> "OccupancyGrid": grid_array = self.grid # Create circular kernel for binary inflation - kernel_size = 2 * cell_radius + 1 + 2 * cell_radius + 1 y, x = np.ogrid[-cell_radius : cell_radius + 1, -cell_radius : cell_radius + 1] kernel = (x**2 + y**2 <= cell_radius**2).astype(np.uint8) @@ -300,7 +299,7 @@ def lcm_encode(self) -> bytes: return lcm_msg.lcm_encode() @classmethod - def lcm_decode(cls, data: bytes | BinaryIO) -> "OccupancyGrid": + def lcm_decode(cls, data: bytes | BinaryIO) -> OccupancyGrid: """Decode LCM bytes to OccupancyGrid.""" lcm_msg = LCMOccupancyGrid.lcm_decode(data) @@ -330,13 +329,13 @@ def lcm_decode(cls, data: bytes | BinaryIO) -> "OccupancyGrid": @classmethod def from_pointcloud( cls, - cloud: "PointCloud2", + cloud: PointCloud2, resolution: float = 0.05, min_height: float = 0.1, max_height: float = 2.0, - frame_id: Optional[str] = None, + frame_id: str | None = None, mark_free_radius: float = 0.4, - ) -> "OccupancyGrid": + ) -> OccupancyGrid: """Create an OccupancyGrid from a PointCloud2 message. Args: @@ -462,7 +461,7 @@ def from_pointcloud( return occupancy_grid - def gradient(self, obstacle_threshold: int = 50, max_distance: float = 2.0) -> "OccupancyGrid": + def gradient(self, obstacle_threshold: int = 50, max_distance: float = 2.0) -> OccupancyGrid: """Create a gradient OccupancyGrid for path planning. Creates a gradient where free space has value 0 and values increase near obstacles. @@ -522,7 +521,7 @@ def gradient(self, obstacle_threshold: int = 50, max_distance: float = 2.0) -> " return gradient_grid - def filter_above(self, threshold: int) -> "OccupancyGrid": + def filter_above(self, threshold: int) -> OccupancyGrid: """Create a new OccupancyGrid with only values above threshold. Args: @@ -553,7 +552,7 @@ def filter_above(self, threshold: int) -> "OccupancyGrid": return filtered - def filter_below(self, threshold: int) -> "OccupancyGrid": + def filter_below(self, threshold: int) -> OccupancyGrid: """Create a new OccupancyGrid with only values below threshold. Args: @@ -584,7 +583,7 @@ def filter_below(self, threshold: int) -> "OccupancyGrid": return filtered - def max(self) -> "OccupancyGrid": + def max(self) -> OccupancyGrid: """Create a new OccupancyGrid with all non-unknown cells set to maximum value. Returns: diff --git a/dimos/msgs/nav_msgs/Odometry.py b/dimos/msgs/nav_msgs/Odometry.py index 6e8b6c27fc..3a640b242d 100644 --- a/dimos/msgs/nav_msgs/Odometry.py +++ b/dimos/msgs/nav_msgs/Odometry.py @@ -15,10 +15,10 @@ from __future__ import annotations import time -from typing import TypeAlias +from typing import TYPE_CHECKING, TypeAlias -import numpy as np from dimos_lcm.nav_msgs import Odometry as LCMOdometry +import numpy as np from plum import dispatch try: @@ -30,9 +30,11 @@ from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance from dimos.msgs.geometry_msgs.Twist import Twist from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance -from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.types.timestamped import Timestamped +if TYPE_CHECKING: + from dimos.msgs.geometry_msgs.Vector3 import Vector3 + # Types that can be converted to/from Odometry OdometryConvertable: TypeAlias = ( LCMOdometry | dict[str, float | str | PoseWithCovariance | TwistWithCovariance | Pose | Twist] @@ -281,7 +283,7 @@ def lcm_encode(self) -> bytes: return lcm_msg.lcm_encode() @classmethod - def lcm_decode(cls, data: bytes) -> "Odometry": + def lcm_decode(cls, data: bytes) -> Odometry: """Decode from LCM binary format.""" lcm_msg = LCMOdometry.lcm_decode(data) @@ -328,7 +330,7 @@ def lcm_decode(cls, data: bytes) -> "Odometry": ) @classmethod - def from_ros_msg(cls, ros_msg: ROSOdometry) -> "Odometry": + def from_ros_msg(cls, ros_msg: ROSOdometry) -> Odometry: """Create an Odometry from a ROS nav_msgs/Odometry message. Args: diff --git a/dimos/msgs/nav_msgs/Path.py b/dimos/msgs/nav_msgs/Path.py index 18a2fb07ee..fa05ae4d6f 100644 --- a/dimos/msgs/nav_msgs/Path.py +++ b/dimos/msgs/nav_msgs/Path.py @@ -14,31 +14,29 @@ from __future__ import annotations -import struct import time -from io import BytesIO -from typing import BinaryIO, TypeAlias - -from dimos_lcm.geometry_msgs import Point as LCMPoint -from dimos_lcm.geometry_msgs import Pose as LCMPose -from dimos_lcm.geometry_msgs import PoseStamped as LCMPoseStamped -from dimos_lcm.geometry_msgs import Quaternion as LCMQuaternion +from typing import TYPE_CHECKING, BinaryIO + +from dimos_lcm.geometry_msgs import ( + Point as LCMPoint, + Pose as LCMPose, + PoseStamped as LCMPoseStamped, + Quaternion as LCMQuaternion, +) from dimos_lcm.nav_msgs import Path as LCMPath -from dimos_lcm.std_msgs import Header as LCMHeader -from dimos_lcm.std_msgs import Time as LCMTime +from dimos_lcm.std_msgs import Header as LCMHeader, Time as LCMTime try: from nav_msgs.msg import Path as ROSPath except ImportError: ROSPath = None -from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped -from dimos.msgs.geometry_msgs.Quaternion import Quaternion, QuaternionConvertable -from dimos.msgs.geometry_msgs.Transform import Transform -from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorConvertable from dimos.types.timestamped import Timestamped +if TYPE_CHECKING: + from collections.abc import Iterator + def sec_nsec(ts): s = int(ts) @@ -78,13 +76,13 @@ def last(self) -> PoseStamped | None: """Return the last pose in the path, or None if empty.""" return self.poses[-1] if self.poses else None - def tail(self) -> "Path": + def tail(self) -> Path: """Return a new Path with all poses except the first.""" return Path(ts=self.ts, frame_id=self.frame_id, poses=self.poses[1:] if self.poses else []) - def push(self, pose: PoseStamped) -> "Path": + def push(self, pose: PoseStamped) -> Path: """Return a new Path with the pose appended (immutable).""" - return Path(ts=self.ts, frame_id=self.frame_id, poses=self.poses + [pose]) + return Path(ts=self.ts, frame_id=self.frame_id, poses=[*self.poses, pose]) def push_mut(self, pose: PoseStamped) -> None: """Append a pose to this path (mutable).""" @@ -130,7 +128,7 @@ def lcm_encode(self) -> bytes: return lcm_msg.lcm_encode() @classmethod - def lcm_decode(cls, data: bytes | BinaryIO) -> "Path": + def lcm_decode(cls, data: bytes | BinaryIO) -> Path: """Decode LCM bytes to Path.""" lcm_msg = LCMPath.lcm_decode(data) @@ -169,23 +167,23 @@ def __getitem__(self, index: int | slice) -> PoseStamped | list[PoseStamped]: """Allow indexing and slicing of poses.""" return self.poses[index] - def __iter__(self): + def __iter__(self) -> Iterator: """Allow iteration over poses.""" return iter(self.poses) - def slice(self, start: int, end: int | None = None) -> "Path": + def slice(self, start: int, end: int | None = None) -> Path: """Return a new Path with a slice of poses.""" return Path(ts=self.ts, frame_id=self.frame_id, poses=self.poses[start:end]) - def extend(self, other: "Path") -> "Path": + def extend(self, other: Path) -> Path: """Return a new Path with poses from both paths (immutable).""" return Path(ts=self.ts, frame_id=self.frame_id, poses=self.poses + other.poses) - def extend_mut(self, other: "Path") -> None: + def extend_mut(self, other: Path) -> None: """Extend this path with poses from another path (mutable).""" self.poses.extend(other.poses) - def reverse(self) -> "Path": + def reverse(self) -> Path: """Return a new Path with poses in reverse order.""" return Path(ts=self.ts, frame_id=self.frame_id, poses=list(reversed(self.poses))) @@ -194,7 +192,7 @@ def clear(self) -> None: self.poses.clear() @classmethod - def from_ros_msg(cls, ros_msg: ROSPath) -> "Path": + def from_ros_msg(cls, ros_msg: ROSPath) -> Path: """Create a Path from a ROS nav_msgs/Path message. Args: diff --git a/dimos/msgs/nav_msgs/__init__.py b/dimos/msgs/nav_msgs/__init__.py index 9ea87f3f78..9df397c57c 100644 --- a/dimos/msgs/nav_msgs/__init__.py +++ b/dimos/msgs/nav_msgs/__init__.py @@ -1,5 +1,5 @@ from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, MapMetaData, OccupancyGrid -from dimos.msgs.nav_msgs.Path import Path from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.nav_msgs.Path import Path -__all__ = ["Path", "OccupancyGrid", "MapMetaData", "CostValues", "Odometry"] +__all__ = ["CostValues", "MapMetaData", "OccupancyGrid", "Odometry", "Path"] diff --git a/dimos/msgs/nav_msgs/test_OccupancyGrid.py b/dimos/msgs/nav_msgs/test_OccupancyGrid.py index 83277b54bc..a4cd36f9c0 100644 --- a/dimos/msgs/nav_msgs/test_OccupancyGrid.py +++ b/dimos/msgs/nav_msgs/test_OccupancyGrid.py @@ -27,7 +27,7 @@ from dimos.utils.testing import get_data -def test_empty_grid(): +def test_empty_grid() -> None: """Test creating an empty grid.""" grid = OccupancyGrid() assert grid.width == 0 @@ -37,7 +37,7 @@ def test_empty_grid(): assert grid.frame_id == "world" -def test_grid_with_dimensions(): +def test_grid_with_dimensions() -> None: """Test creating a grid with specified dimensions.""" grid = OccupancyGrid(width=10, height=10, resolution=0.1, frame_id="map") assert grid.width == 10 @@ -50,7 +50,7 @@ def test_grid_with_dimensions(): assert grid.unknown_percent == 100.0 -def test_grid_from_numpy_array(): +def test_grid_from_numpy_array() -> None: """Test creating a grid from a numpy array.""" data = np.zeros((20, 30), dtype=np.int8) data[5:10, 10:20] = 100 # Add some obstacles @@ -78,7 +78,7 @@ def test_grid_from_numpy_array(): assert abs(grid.unknown_percent - 1.5) < 0.1 -def test_world_grid_coordinate_conversion(): +def test_world_grid_coordinate_conversion() -> None: """Test converting between world and grid coordinates.""" data = np.zeros((20, 30), dtype=np.int8) origin = Pose(1.0, 2.0, 0.0) @@ -95,7 +95,7 @@ def test_world_grid_coordinate_conversion(): assert world_pos.y == 2.25 -def test_lcm_encode_decode(): +def test_lcm_encode_decode() -> None: """Test LCM encoding and decoding.""" data = np.zeros((20, 30), dtype=np.int8) data[5:10, 10:20] = 100 # Add some obstacles @@ -129,7 +129,7 @@ def test_lcm_encode_decode(): assert decoded.grid[5, 10] == 50 # Value we set should be preserved in grid -def test_string_representation(): +def test_string_representation() -> None: """Test string representations.""" grid = OccupancyGrid(width=10, height=10, resolution=0.1, frame_id="map") @@ -148,7 +148,7 @@ def test_string_representation(): assert "resolution=0.1" in repr_str -def test_grid_property_sync(): +def test_grid_property_sync() -> None: """Test that the grid property works correctly.""" grid = OccupancyGrid(width=5, height=5, resolution=0.1, frame_id="map") @@ -161,14 +161,14 @@ def test_grid_property_sync(): assert grid.grid[0, 0] == 50 -def test_invalid_grid_dimensions(): +def test_invalid_grid_dimensions() -> None: """Test handling of invalid grid dimensions.""" # Test with non-2D array with pytest.raises(ValueError, match="Grid must be a 2D array"): OccupancyGrid(grid=np.zeros(10), resolution=0.1) -def test_from_pointcloud(): +def test_from_pointcloud() -> None: """Test creating OccupancyGrid from PointCloud2.""" file_path = get_data("lcm_msgs") / "sensor_msgs/PointCloud2.pickle" with open(file_path, "rb") as f: @@ -191,7 +191,7 @@ def test_from_pointcloud(): assert occupancygrid.occupied_cells > 0 # Should have some occupied cells -def test_gradient(): +def test_gradient() -> None: """Test converting occupancy grid to gradient field.""" # Create a small test grid with an obstacle in the middle data = np.zeros((10, 10), dtype=np.int8) @@ -241,7 +241,7 @@ def test_gradient(): assert gradient_with_unknown.unknown_cells == 8 # All unknowns preserved -def test_filter_above(): +def test_filter_above() -> None: """Test filtering cells above threshold.""" # Create test grid with various values data = np.array( @@ -280,7 +280,7 @@ def test_filter_above(): assert filtered.frame_id == grid.frame_id -def test_filter_below(): +def test_filter_below() -> None: """Test filtering cells below threshold.""" # Create test grid with various values data = np.array( @@ -321,7 +321,7 @@ def test_filter_below(): assert filtered.frame_id == grid.frame_id -def test_max(): +def test_max() -> None: """Test setting all non-unknown cells to maximum.""" # Create test grid with various values data = np.array( @@ -366,7 +366,7 @@ def test_max(): @pytest.mark.lcm -def test_lcm_broadcast(): +def test_lcm_broadcast() -> None: """Test broadcasting OccupancyGrid and gradient over LCM.""" file_path = get_data("lcm_msgs") / "sensor_msgs/PointCloud2.pickle" with open(file_path, "rb") as f: @@ -412,13 +412,13 @@ def test_lcm_broadcast(): print("\nNo occupied cells found for sampling") # Check statistics - print(f"\nOriginal grid stats:") + print("\nOriginal grid stats:") print(f" Occupied (100): {np.sum(occupancygrid.grid == 100)} cells") print(f" Inflated (99): {np.sum(occupancygrid.grid == 99)} cells") print(f" Free (0): {np.sum(occupancygrid.grid == 0)} cells") print(f" Unknown (-1): {np.sum(occupancygrid.grid == -1)} cells") - print(f"\nGradient grid stats:") + print("\nGradient grid stats:") print(f" Max gradient (100): {np.sum(gradient_grid.grid == 100)} cells") print( f" High gradient (80-99): {np.sum((gradient_grid.grid >= 80) & (gradient_grid.grid < 100))} cells" @@ -461,11 +461,11 @@ def test_lcm_broadcast(): lcm.publish(Topic("/global_costmap", OccupancyGrid), occupancygrid) lcm.publish(Topic("/global_gradient", OccupancyGrid), gradient_grid) - print(f"\nPublished to LCM:") + print("\nPublished to LCM:") print(f" /global_map: PointCloud2 with {len(pointcloud)} points") print(f" /global_costmap: {occupancygrid}") print(f" /global_gradient: {gradient_grid}") - print(f"\nGradient info:") - print(f" Values: 0 (free far from obstacles) -> 100 (at obstacles)") + print("\nGradient info:") + print(" Values: 0 (free far from obstacles) -> 100 (at obstacles)") print(f" Unknown cells: {gradient_grid.unknown_cells} (preserved as -1)") - print(f" Max distance for gradient: 5.0 meters") + print(" Max distance for gradient: 5.0 meters") diff --git a/dimos/msgs/nav_msgs/test_Odometry.py b/dimos/msgs/nav_msgs/test_Odometry.py index 2fee199b1b..e61bb8e8da 100644 --- a/dimos/msgs/nav_msgs/test_Odometry.py +++ b/dimos/msgs/nav_msgs/test_Odometry.py @@ -18,16 +18,18 @@ import pytest try: + from builtin_interfaces.msg import Time as ROSTime + from geometry_msgs.msg import ( + Point as ROSPoint, + Pose as ROSPose, + PoseWithCovariance as ROSPoseWithCovariance, + Quaternion as ROSQuaternion, + Twist as ROSTwist, + TwistWithCovariance as ROSTwistWithCovariance, + Vector3 as ROSVector3, + ) from nav_msgs.msg import Odometry as ROSOdometry - from geometry_msgs.msg import PoseWithCovariance as ROSPoseWithCovariance - from geometry_msgs.msg import TwistWithCovariance as ROSTwistWithCovariance - from geometry_msgs.msg import Pose as ROSPose - from geometry_msgs.msg import Twist as ROSTwist - from geometry_msgs.msg import Point as ROSPoint - from geometry_msgs.msg import Quaternion as ROSQuaternion - from geometry_msgs.msg import Vector3 as ROSVector3 from std_msgs.msg import Header as ROSHeader - from builtin_interfaces.msg import Time as ROSTime except ImportError: ROSTwist = None ROSHeader = None @@ -40,18 +42,16 @@ ROSTwistWithCovariance = None ROSVector3 = None -from dimos_lcm.nav_msgs import Odometry as LCMOdometry -from dimos.msgs.nav_msgs.Odometry import Odometry from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance from dimos.msgs.geometry_msgs.Twist import Twist from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance from dimos.msgs.geometry_msgs.Vector3 import Vector3 -from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.nav_msgs.Odometry import Odometry -def test_odometry_default_init(): +def test_odometry_default_init() -> None: """Test default initialization.""" if ROSVector3 is None: pytest.skip("ROS not available") @@ -99,7 +99,7 @@ def test_odometry_default_init(): assert np.all(odom.twist.covariance == 0.0) -def test_odometry_with_frames(): +def test_odometry_with_frames() -> None: """Test initialization with frame IDs.""" ts = 1234567890.123456 frame_id = "odom" @@ -112,7 +112,7 @@ def test_odometry_with_frames(): assert odom.child_frame_id == child_frame_id -def test_odometry_with_pose_and_twist(): +def test_odometry_with_pose_and_twist() -> None: """Test initialization with pose and twist.""" pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) twist = Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)) @@ -126,7 +126,7 @@ def test_odometry_with_pose_and_twist(): assert odom.twist.twist.angular.z == 0.1 -def test_odometry_with_covariances(): +def test_odometry_with_covariances() -> None: """Test initialization with pose and twist with covariances.""" pose = Pose(1.0, 2.0, 3.0) pose_cov = np.arange(36, dtype=float) @@ -150,7 +150,7 @@ def test_odometry_with_covariances(): assert np.array_equal(odom.twist.covariance, twist_cov) -def test_odometry_copy_constructor(): +def test_odometry_copy_constructor() -> None: """Test copy constructor.""" original = Odometry( ts=1000.0, @@ -168,7 +168,7 @@ def test_odometry_copy_constructor(): assert copy.twist is not original.twist -def test_odometry_dict_init(): +def test_odometry_dict_init() -> None: """Test initialization from dictionary.""" odom_dict = { "ts": 1000.0, @@ -187,7 +187,7 @@ def test_odometry_dict_init(): assert odom.twist.linear.x == 0.5 -def test_odometry_properties(): +def test_odometry_properties() -> None: """Test convenience properties.""" pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) twist = Twist(Vector3(0.5, 0.6, 0.7), Vector3(0.1, 0.2, 0.3)) @@ -230,7 +230,7 @@ def test_odometry_properties(): assert odom.yaw == pose.yaw -def test_odometry_str_repr(): +def test_odometry_str_repr() -> None: """Test string representations.""" odom = Odometry( ts=1234567890.123456, @@ -253,7 +253,7 @@ def test_odometry_str_repr(): assert "0.500" in str_repr -def test_odometry_equality(): +def test_odometry_equality() -> None: """Test equality comparison.""" odom1 = Odometry( ts=1000.0, @@ -284,7 +284,7 @@ def test_odometry_equality(): assert odom1 != "not an odometry" -def test_odometry_lcm_encode_decode(): +def test_odometry_lcm_encode_decode() -> None: """Test LCM encoding and decoding.""" pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) pose_cov = np.arange(36, dtype=float) @@ -312,7 +312,7 @@ def test_odometry_lcm_encode_decode(): @pytest.mark.ros -def test_odometry_from_ros_msg(): +def test_odometry_from_ros_msg() -> None: """Test creating from ROS message.""" ros_msg = ROSOdometry() @@ -350,7 +350,7 @@ def test_odometry_from_ros_msg(): @pytest.mark.ros -def test_odometry_to_ros_msg(): +def test_odometry_to_ros_msg() -> None: """Test converting to ROS message.""" pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) pose_cov = np.arange(36, dtype=float) @@ -394,7 +394,7 @@ def test_odometry_to_ros_msg(): @pytest.mark.ros -def test_odometry_ros_roundtrip(): +def test_odometry_ros_roundtrip() -> None: """Test round-trip conversion with ROS messages.""" pose = Pose(1.5, 2.5, 3.5, 0.15, 0.25, 0.35, 0.85) pose_cov = np.random.rand(36) @@ -420,7 +420,7 @@ def test_odometry_ros_roundtrip(): assert restored.twist == original.twist -def test_odometry_zero_timestamp(): +def test_odometry_zero_timestamp() -> None: """Test that zero timestamp gets replaced with current time.""" odom = Odometry(ts=0.0) @@ -429,7 +429,7 @@ def test_odometry_zero_timestamp(): assert odom.ts <= time.time() -def test_odometry_with_just_pose(): +def test_odometry_with_just_pose() -> None: """Test initialization with just a Pose (no covariance).""" pose = Pose(1.0, 2.0, 3.0) @@ -442,7 +442,7 @@ def test_odometry_with_just_pose(): assert np.all(odom.twist.covariance == 0.0) # Twist should also be zero -def test_odometry_with_just_twist(): +def test_odometry_with_just_twist() -> None: """Test initialization with just a Twist (no covariance).""" twist = Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)) @@ -465,7 +465,7 @@ def test_odometry_with_just_twist(): ("", ""), # Empty frames ], ) -def test_odometry_frame_combinations(frame_id, child_frame_id): +def test_odometry_frame_combinations(frame_id, child_frame_id) -> None: """Test various frame ID combinations.""" odom = Odometry(frame_id=frame_id, child_frame_id=child_frame_id) @@ -482,7 +482,7 @@ def test_odometry_frame_combinations(frame_id, child_frame_id): assert restored.child_frame_id == child_frame_id -def test_odometry_typical_robot_scenario(): +def test_odometry_typical_robot_scenario() -> None: """Test a typical robot odometry scenario.""" # Robot moving forward at 0.5 m/s with slight rotation odom = Odometry( diff --git a/dimos/msgs/nav_msgs/test_Path.py b/dimos/msgs/nav_msgs/test_Path.py index 94028d7959..9f4c39b8a0 100644 --- a/dimos/msgs/nav_msgs/test_Path.py +++ b/dimos/msgs/nav_msgs/test_Path.py @@ -12,14 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time import pytest - try: - from nav_msgs.msg import Path as ROSPath from geometry_msgs.msg import PoseStamped as ROSPoseStamped + from nav_msgs.msg import Path as ROSPath except ImportError: ROSPoseStamped = None ROSPath = None @@ -38,7 +36,7 @@ def create_test_pose(x: float, y: float, z: float, frame_id: str = "map") -> Pos ) -def test_init_empty(): +def test_init_empty() -> None: """Test creating an empty path.""" path = Path(frame_id="map") assert path.frame_id == "map" @@ -47,7 +45,7 @@ def test_init_empty(): assert path.poses == [] -def test_init_with_poses(): +def test_init_with_poses() -> None: """Test creating a path with initial poses.""" poses = [create_test_pose(i, i, 0) for i in range(3)] path = Path(frame_id="map", poses=poses) @@ -56,7 +54,7 @@ def test_init_with_poses(): assert path.poses == poses -def test_head(): +def test_head() -> None: """Test getting the first pose.""" poses = [create_test_pose(i, i, 0) for i in range(3)] path = Path(poses=poses) @@ -67,7 +65,7 @@ def test_head(): assert empty_path.head() is None -def test_last(): +def test_last() -> None: """Test getting the last pose.""" poses = [create_test_pose(i, i, 0) for i in range(3)] path = Path(poses=poses) @@ -78,7 +76,7 @@ def test_last(): assert empty_path.last() is None -def test_tail(): +def test_tail() -> None: """Test getting all poses except the first.""" poses = [create_test_pose(i, i, 0) for i in range(3)] path = Path(poses=poses) @@ -96,7 +94,7 @@ def test_tail(): assert len(empty_path.tail()) == 0 -def test_push_immutable(): +def test_push_immutable() -> None: """Test immutable push operation.""" path = Path(frame_id="map") pose1 = create_test_pose(1, 1, 0) @@ -115,7 +113,7 @@ def test_push_immutable(): assert path3.poses == [pose1, pose2] -def test_push_mutable(): +def test_push_mutable() -> None: """Test mutable push operation.""" path = Path(frame_id="map") pose1 = create_test_pose(1, 1, 0) @@ -131,7 +129,7 @@ def test_push_mutable(): assert path.poses == [pose1, pose2] -def test_indexing(): +def test_indexing() -> None: """Test indexing and slicing.""" poses = [create_test_pose(i, i, 0) for i in range(5)] path = Path(poses=poses) @@ -146,7 +144,7 @@ def test_indexing(): assert path[3:] == poses[3:] -def test_iteration(): +def test_iteration() -> None: """Test iterating over poses.""" poses = [create_test_pose(i, i, 0) for i in range(3)] path = Path(poses=poses) @@ -157,7 +155,7 @@ def test_iteration(): assert collected == poses -def test_slice_method(): +def test_slice_method() -> None: """Test slice method.""" poses = [create_test_pose(i, i, 0) for i in range(5)] path = Path(frame_id="map", poses=poses) @@ -172,7 +170,7 @@ def test_slice_method(): assert sliced2.poses == poses[2:] -def test_extend_immutable(): +def test_extend_immutable() -> None: """Test immutable extend operation.""" poses1 = [create_test_pose(i, i, 0) for i in range(2)] poses2 = [create_test_pose(i + 2, i + 2, 0) for i in range(2)] @@ -187,7 +185,7 @@ def test_extend_immutable(): assert extended.frame_id == "map" # Keeps first path's frame -def test_extend_mutable(): +def test_extend_mutable() -> None: """Test mutable extend operation.""" poses1 = [create_test_pose(i, i, 0) for i in range(2)] poses2 = [create_test_pose(i + 2, i + 2, 0) for i in range(2)] @@ -198,13 +196,13 @@ def test_extend_mutable(): path1.extend_mut(path2) assert len(path1) == 4 # Check poses are the same as concatenation - for i, (p1, p2) in enumerate(zip(path1.poses, poses1 + poses2)): + for _i, (p1, p2) in enumerate(zip(path1.poses, poses1 + poses2, strict=False)): assert p1.x == p2.x assert p1.y == p2.y assert p1.z == p2.z -def test_reverse(): +def test_reverse() -> None: """Test reverse operation.""" poses = [create_test_pose(i, i, 0) for i in range(3)] path = Path(poses=poses) @@ -214,7 +212,7 @@ def test_reverse(): assert reversed_path.poses == list(reversed(poses)) -def test_clear(): +def test_clear() -> None: """Test clear operation.""" poses = [create_test_pose(i, i, 0) for i in range(3)] path = Path(poses=poses) @@ -224,7 +222,7 @@ def test_clear(): assert path.poses == [] -def test_lcm_encode_decode(): +def test_lcm_encode_decode() -> None: """Test encoding and decoding of Path to/from binary LCM format.""" # Create path with poses # Use timestamps that can be represented exactly in float64 @@ -258,7 +256,7 @@ def test_lcm_encode_decode(): # Check poses assert len(path_dest.poses) == len(path_source.poses) - for orig, decoded in zip(path_source.poses, path_dest.poses): + for orig, decoded in zip(path_source.poses, path_dest.poses, strict=False): # Check pose timestamps assert abs(decoded.ts - orig.ts) < 1e-6 # All poses should have the path's frame_id @@ -276,7 +274,7 @@ def test_lcm_encode_decode(): assert decoded.orientation.w == orig.orientation.w -def test_lcm_encode_decode_empty(): +def test_lcm_encode_decode_empty() -> None: """Test encoding and decoding of empty Path.""" path_source = Path(frame_id="base_link") @@ -288,7 +286,7 @@ def test_lcm_encode_decode_empty(): assert len(path_dest.poses) == 0 -def test_str_representation(): +def test_str_representation() -> None: """Test string representation.""" path = Path(frame_id="map") assert str(path) == "Path(frame_id='map', poses=0)" @@ -299,7 +297,7 @@ def test_str_representation(): @pytest.mark.ros -def test_path_from_ros_msg(): +def test_path_from_ros_msg() -> None: """Test creating a Path from a ROS Path message.""" ros_msg = ROSPath() ros_msg.header.frame_id = "map" @@ -335,7 +333,7 @@ def test_path_from_ros_msg(): @pytest.mark.ros -def test_path_to_ros_msg(): +def test_path_to_ros_msg() -> None: """Test converting a Path to a ROS Path message.""" poses = [ PoseStamped( @@ -362,7 +360,7 @@ def test_path_to_ros_msg(): @pytest.mark.ros -def test_path_ros_roundtrip(): +def test_path_ros_roundtrip() -> None: """Test round-trip conversion between Path and ROS Path.""" poses = [ PoseStamped( @@ -383,7 +381,7 @@ def test_path_ros_roundtrip(): assert restored.ts == original.ts assert len(restored.poses) == len(original.poses) - for orig_pose, rest_pose in zip(original.poses, restored.poses): + for orig_pose, rest_pose in zip(original.poses, restored.poses, strict=False): assert rest_pose.position.x == orig_pose.position.x assert rest_pose.position.y == orig_pose.position.y assert rest_pose.position.z == orig_pose.position.z diff --git a/dimos/msgs/sensor_msgs/CameraInfo.py b/dimos/msgs/sensor_msgs/CameraInfo.py index 5ce0f76353..3d2a118b0d 100644 --- a/dimos/msgs/sensor_msgs/CameraInfo.py +++ b/dimos/msgs/sensor_msgs/CameraInfo.py @@ -15,18 +15,15 @@ from __future__ import annotations import time -from typing import List, Optional - -import numpy as np # Import LCM types from dimos_lcm.sensor_msgs import CameraInfo as LCMCameraInfo from dimos_lcm.std_msgs.Header import Header +import numpy as np # Import ROS types try: - from sensor_msgs.msg import CameraInfo as ROSCameraInfo - from sensor_msgs.msg import RegionOfInterest as ROSRegionOfInterest + from sensor_msgs.msg import CameraInfo as ROSCameraInfo, RegionOfInterest as ROSRegionOfInterest from std_msgs.msg import Header as ROSHeader ROS_AVAILABLE = True @@ -46,15 +43,15 @@ def __init__( height: int = 0, width: int = 0, distortion_model: str = "", - D: Optional[List[float]] = None, - K: Optional[List[float]] = None, - R: Optional[List[float]] = None, - P: Optional[List[float]] = None, + D: list[float] | None = None, + K: list[float] | None = None, + R: list[float] | None = None, + P: list[float] | None = None, binning_x: int = 0, binning_y: int = 0, frame_id: str = "", - ts: Optional[float] = None, - ): + ts: float | None = None, + ) -> None: """Initialize CameraInfo. Args: @@ -110,7 +107,7 @@ def from_yaml(cls, yaml_file: str) -> CameraInfo: """ import yaml - with open(yaml_file, "r") as f: + with open(yaml_file) as f: data = yaml.safe_load(f) # Extract basic parameters @@ -177,7 +174,7 @@ def set_R_matrix(self, R: np.ndarray): raise ValueError(f"R matrix must be 3x3, got {R.shape}") self.R = R.flatten().tolist() - def set_D_coeffs(self, D: np.ndarray): + def set_D_coeffs(self, D: np.ndarray) -> None: """Set distortion coefficients from numpy array.""" self.D = D.flatten().tolist() @@ -222,7 +219,7 @@ def lcm_encode(self) -> bytes: return msg.lcm_encode() @classmethod - def lcm_decode(cls, data: bytes) -> "CameraInfo": + def lcm_decode(cls, data: bytes) -> CameraInfo: """Decode from LCM CameraInfo bytes.""" msg = LCMCameraInfo.lcm_decode(data) @@ -254,7 +251,7 @@ def lcm_decode(cls, data: bytes) -> "CameraInfo": return camera_info @classmethod - def from_ros_msg(cls, ros_msg: "ROSCameraInfo") -> "CameraInfo": + def from_ros_msg(cls, ros_msg: ROSCameraInfo) -> CameraInfo: """Create CameraInfo from ROS sensor_msgs/CameraInfo message. Args: @@ -292,7 +289,7 @@ def from_ros_msg(cls, ros_msg: "ROSCameraInfo") -> "CameraInfo": return camera_info - def to_ros_msg(self) -> "ROSCameraInfo": + def to_ros_msg(self) -> ROSCameraInfo: """Convert to ROS sensor_msgs/CameraInfo message. Returns: @@ -376,7 +373,7 @@ def __eq__(self, other) -> bool: class CalibrationProvider: """Provides lazy-loaded access to camera calibration YAML files in a directory.""" - def __init__(self, calibration_dir): + def __init__(self, calibration_dir) -> None: """Initialize with a directory containing calibration YAML files. Args: diff --git a/dimos/msgs/sensor_msgs/Image.py b/dimos/msgs/sensor_msgs/Image.py index 36f6f1d545..eb1bf2d049 100644 --- a/dimos/msgs/sensor_msgs/Image.py +++ b/dimos/msgs/sensor_msgs/Image.py @@ -15,23 +15,20 @@ from __future__ import annotations import base64 -import functools import time -from typing import Literal, Optional, TypedDict +from typing import TYPE_CHECKING, Literal, TypedDict import cv2 -import numpy as np -import reactivex as rx from dimos_lcm.sensor_msgs.Image import Image as LCMImage from dimos_lcm.std_msgs.Header import Header +import numpy as np +import reactivex as rx from reactivex import operators as ops -from reactivex.observable import Observable from dimos.msgs.sensor_msgs.image_impls.AbstractImage import ( HAS_CUDA, HAS_NVIMGCODEC, NVIMGCODEC_LAST_USED, - AbstractImage, ImageFormat, ) from dimos.msgs.sensor_msgs.image_impls.CudaImage import CudaImage @@ -39,6 +36,13 @@ from dimos.types.timestamped import Timestamped, TimestampedBufferCollection, to_human_readable from dimos.utils.reactive import quality_barrier +if TYPE_CHECKING: + from reactivex.observable import Observable + + from dimos.msgs.sensor_msgs.image_impls.AbstractImage import ( + AbstractImage, + ) + try: import cupy as cp # type: ignore except Exception: @@ -70,7 +74,7 @@ def __init__( format: ImageFormat | None = None, frame_id: str | None = None, ts: float | None = None, - ): + ) -> None: """Construct an Image facade. Usage: @@ -120,7 +124,7 @@ def __str__(self) -> str: ) @classmethod - def from_impl(cls, impl: AbstractImage) -> "Image": + def from_impl(cls, impl: AbstractImage) -> Image: return cls(impl) @classmethod @@ -130,7 +134,7 @@ def from_numpy( format: ImageFormat = ImageFormat.BGR, to_cuda: bool = False, **kwargs, - ) -> "Image": + ) -> Image: if kwargs.pop("to_gpu", False): to_cuda = True if to_cuda and HAS_CUDA: @@ -154,7 +158,7 @@ def from_numpy( @classmethod def from_file( cls, filepath: str, format: ImageFormat = ImageFormat.RGB, to_cuda: bool = False, **kwargs - ) -> "Image": + ) -> Image: if kwargs.pop("to_gpu", False): to_cuda = True arr = cv2.imread(filepath, cv2.IMREAD_UNCHANGED) @@ -173,7 +177,7 @@ def from_file( @classmethod def from_opencv( cls, cv_image: np.ndarray, format: ImageFormat = ImageFormat.BGR, **kwargs - ) -> "Image": + ) -> Image: """Construct from an OpenCV image (NumPy array).""" return cls( NumpyImage(cv_image, format, kwargs.get("frame_id", ""), kwargs.get("ts", time.time())) @@ -181,8 +185,8 @@ def from_opencv( @classmethod def from_depth( - cls, depth_data, frame_id: str = "", ts: float = None, to_cuda: bool = False - ) -> "Image": + cls, depth_data, frame_id: str = "", ts: float | None = None, to_cuda: bool = False + ) -> Image: arr = np.asarray(depth_data) if arr.dtype != np.float32: arr = arr.astype(np.float32) @@ -266,10 +270,10 @@ def shape(self): def dtype(self): return self._impl.dtype - def copy(self) -> "Image": + def copy(self) -> Image: return Image(self._impl.copy()) - def to_cpu(self) -> "Image": + def to_cpu(self) -> Image: if isinstance(self._impl, NumpyImage): return self.copy() @@ -284,7 +288,7 @@ def to_cpu(self) -> "Image": ) ) - def to_cupy(self) -> "Image": + def to_cupy(self) -> Image: if isinstance(self._impl, CudaImage): return self.copy() return Image( @@ -296,19 +300,19 @@ def to_cupy(self) -> "Image": def to_opencv(self) -> np.ndarray: return self._impl.to_opencv() - def to_rgb(self) -> "Image": + def to_rgb(self) -> Image: return Image(self._impl.to_rgb()) - def to_bgr(self) -> "Image": + def to_bgr(self) -> Image: return Image(self._impl.to_bgr()) - def to_grayscale(self) -> "Image": + def to_grayscale(self) -> Image: return Image(self._impl.to_grayscale()) - def resize(self, width: int, height: int, interpolation: int = cv2.INTER_LINEAR) -> "Image": + def resize(self, width: int, height: int, interpolation: int = cv2.INTER_LINEAR) -> Image: return Image(self._impl.resize(width, height, interpolation)) - def crop(self, x: int, y: int, width: int, height: int) -> "Image": + def crop(self, x: int, y: int, width: int, height: int) -> Image: return Image(self._impl.crop(x, y, width, height)) @property @@ -323,8 +327,8 @@ def to_base64( self, quality: int = 80, *, - max_width: Optional[int] = None, - max_height: Optional[int] = None, + max_width: int | None = None, + max_height: int | None = None, ) -> str: """Encode the image as a base64 JPEG string. @@ -346,8 +350,8 @@ def to_base64( scale = min(scale, max_height / height) if scale < 1.0: - new_width = max(1, int(round(width * scale))) - new_height = max(1, int(round(height * scale))) + new_width = max(1, round(width * scale)) + new_height = max(1, round(height * scale)) bgr_image = cv2.resize(bgr_image, (new_width, new_height), interpolation=cv2.INTER_AREA) encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), int(np.clip(quality, 0, 100))] @@ -366,7 +370,7 @@ def agent_encode(self) -> AgentImageMessage: ] # LCM encode/decode - def lcm_encode(self, frame_id: Optional[str] = None) -> bytes: + def lcm_encode(self, frame_id: str | None = None) -> bytes: """Convert to LCM Image message.""" msg = LCMImage() @@ -402,7 +406,7 @@ def lcm_encode(self, frame_id: Optional[str] = None) -> bytes: return msg.lcm_encode() @classmethod - def lcm_decode(cls, data: bytes, **kwargs) -> "Image": + def lcm_decode(cls, data: bytes, **kwargs) -> Image: msg = LCMImage.lcm_decode(data) fmt, dtype, channels = _parse_lcm_encoding(msg.encoding) arr = np.frombuffer(msg.data, dtype=dtype) @@ -442,7 +446,7 @@ def csrt_update(self, *args, **kwargs): return self._impl.csrt_update(*args, **kwargs) # type: ignore @classmethod - def from_ros_msg(cls, ros_msg: ROSImage) -> "Image": + def from_ros_msg(cls, ros_msg: ROSImage) -> Image: """Create an Image from a ROS sensor_msgs/Image message. Args: @@ -546,7 +550,7 @@ def __len__(self) -> int: def __getstate__(self): return {"data": self.data, "format": self.format, "frame_id": self.frame_id, "ts": self.ts} - def __setstate__(self, state): + def __setstate__(self, state) -> None: self.__init__( data=state.get("data"), format=state.get("format"), @@ -562,11 +566,11 @@ def __setstate__(self, state): HAS_NVIMGCODEC = HAS_NVIMGCODEC __all__ = [ "HAS_CUDA", - "ImageFormat", - "NVIMGCODEC_LAST_USED", "HAS_NVIMGCODEC", - "sharpness_window", + "NVIMGCODEC_LAST_USED", + "ImageFormat", "sharpness_barrier", + "sharpness_window", ] diff --git a/dimos/msgs/sensor_msgs/Joy.py b/dimos/msgs/sensor_msgs/Joy.py index e528b304b6..aa8611655a 100644 --- a/dimos/msgs/sensor_msgs/Joy.py +++ b/dimos/msgs/sensor_msgs/Joy.py @@ -15,7 +15,7 @@ from __future__ import annotations import time -from typing import List, TypeAlias +from typing import TypeAlias from dimos_lcm.sensor_msgs import Joy as LCMJoy @@ -30,7 +30,7 @@ # Types that can be converted to/from Joy JoyConvertable: TypeAlias = ( - tuple[List[float], List[int]] | dict[str, List[float] | List[int]] | LCMJoy + tuple[list[float], list[int]] | dict[str, list[float] | list[int]] | LCMJoy ) @@ -43,16 +43,16 @@ class Joy(Timestamped): msg_name = "sensor_msgs.Joy" ts: float frame_id: str - axes: List[float] - buttons: List[int] + axes: list[float] + buttons: list[int] @dispatch def __init__( self, ts: float = 0.0, frame_id: str = "", - axes: List[float] | None = None, - buttons: List[int] | None = None, + axes: list[float] | None = None, + buttons: list[int] | None = None, ) -> None: """Initialize a Joy message. @@ -68,7 +68,7 @@ def __init__( self.buttons = buttons if buttons is not None else [] @dispatch - def __init__(self, joy_tuple: tuple[List[float], List[int]]) -> None: + def __init__(self, joy_tuple: tuple[list[float], list[int]]) -> None: """Initialize from a tuple of (axes, buttons).""" self.ts = time.time() self.frame_id = "" @@ -76,7 +76,7 @@ def __init__(self, joy_tuple: tuple[List[float], List[int]]) -> None: self.buttons = list(joy_tuple[1]) @dispatch - def __init__(self, joy_dict: dict[str, List[float] | List[int]]) -> None: + def __init__(self, joy_dict: dict[str, list[float] | list[int]]) -> None: """Initialize from a dictionary with 'axes' and 'buttons' keys.""" self.ts = joy_dict.get("ts", time.time()) self.frame_id = joy_dict.get("frame_id", "") @@ -142,7 +142,7 @@ def __eq__(self, other) -> bool: ) @classmethod - def from_ros_msg(cls, ros_msg: ROSJoy) -> "Joy": + def from_ros_msg(cls, ros_msg: ROSJoy) -> Joy: """Create a Joy from a ROS sensor_msgs/Joy message. Args: diff --git a/dimos/msgs/sensor_msgs/PointCloud2.py b/dimos/msgs/sensor_msgs/PointCloud2.py index d81c8d0198..b8de431fa0 100644 --- a/dimos/msgs/sensor_msgs/PointCloud2.py +++ b/dimos/msgs/sensor_msgs/PointCloud2.py @@ -16,11 +16,6 @@ import functools import struct -import time -from typing import Optional - -import numpy as np -import open3d as o3d # Import LCM types from dimos_lcm.sensor_msgs.PointCloud2 import ( @@ -28,13 +23,14 @@ ) from dimos_lcm.sensor_msgs.PointField import PointField from dimos_lcm.std_msgs.Header import Header +import numpy as np +import open3d as o3d from dimos.msgs.geometry_msgs import Vector3 # Import ROS types try: - from sensor_msgs.msg import PointCloud2 as ROSPointCloud2 - from sensor_msgs.msg import PointField as ROSPointField + from sensor_msgs.msg import PointCloud2 as ROSPointCloud2, PointField as ROSPointField from std_msgs.msg import Header as ROSHeader ROS_AVAILABLE = True @@ -52,15 +48,15 @@ def __init__( self, pointcloud: o3d.geometry.PointCloud = None, frame_id: str = "world", - ts: Optional[float] = None, - ): + ts: float | None = None, + ) -> None: self.ts = ts self.pointcloud = pointcloud if pointcloud is not None else o3d.geometry.PointCloud() self.frame_id = frame_id @classmethod def from_numpy( - cls, points: np.ndarray, frame_id: str = "world", timestamp: Optional[float] = None + cls, points: np.ndarray, frame_id: str = "world", timestamp: float | None = None ) -> PointCloud2: """Create PointCloud2 from numpy array of shape (N, 3). @@ -131,7 +127,7 @@ def get_bounding_box_dimensions(self) -> tuple[float, float, float]: extent = bbox.get_extent() return tuple(extent) - def bounding_box_intersects(self, other: "PointCloud2") -> bool: + def bounding_box_intersects(self, other: PointCloud2) -> bool: # Get axis-aligned bounding boxes bbox1 = self.get_axis_aligned_bounding_box() bbox2 = other.get_axis_aligned_bounding_box() @@ -153,7 +149,7 @@ def bounding_box_intersects(self, other: "PointCloud2") -> bool: and max1[2] >= min2[2] ) - def lcm_encode(self, frame_id: Optional[str] = None) -> bytes: + def lcm_encode(self, frame_id: str | None = None) -> bytes: """Convert to LCM PointCloud2 message.""" msg = LCMPointCloud2() @@ -211,7 +207,7 @@ def lcm_encode(self, frame_id: Optional[str] = None) -> bytes: return msg.lcm_encode() @classmethod - def lcm_decode(cls, data: bytes) -> "PointCloud2": + def lcm_decode(cls, data: bytes) -> PointCloud2: msg = LCMPointCloud2.lcm_decode(data) if msg.width == 0 or msg.height == 0: @@ -313,9 +309,9 @@ def __len__(self) -> int: def filter_by_height( self, - min_height: Optional[float] = None, - max_height: Optional[float] = None, - ) -> "PointCloud2": + min_height: float | None = None, + max_height: float | None = None, + ) -> PointCloud2: """Filter points based on their height (z-coordinate). This method creates a new PointCloud2 containing only points within the specified @@ -388,7 +384,7 @@ def __repr__(self) -> str: return f"PointCloud(points={len(self)}, frame_id='{self.frame_id}', ts={self.ts})" @classmethod - def from_ros_msg(cls, ros_msg: "ROSPointCloud2") -> "PointCloud2": + def from_ros_msg(cls, ros_msg: ROSPointCloud2) -> PointCloud2: """Convert from ROS sensor_msgs/PointCloud2 message. Args: @@ -499,7 +495,7 @@ def from_ros_msg(cls, ros_msg: "ROSPointCloud2") -> "PointCloud2": ts=ts, ) - def to_ros_msg(self) -> "ROSPointCloud2": + def to_ros_msg(self) -> ROSPointCloud2: """Convert to ROS sensor_msgs/PointCloud2 message. Returns: diff --git a/dimos/msgs/sensor_msgs/__init__.py b/dimos/msgs/sensor_msgs/__init__.py index 9a8a7b54fe..56574e448d 100644 --- a/dimos/msgs/sensor_msgs/__init__.py +++ b/dimos/msgs/sensor_msgs/__init__.py @@ -1,4 +1,4 @@ -from dimos.msgs.sensor_msgs.Image import Image, ImageFormat -from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.msgs.sensor_msgs.Joy import Joy +from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 diff --git a/dimos/msgs/sensor_msgs/image_impls/AbstractImage.py b/dimos/msgs/sensor_msgs/image_impls/AbstractImage.py index 2f7da1d0d9..9dd0c647d2 100644 --- a/dimos/msgs/sensor_msgs/image_impls/AbstractImage.py +++ b/dimos/msgs/sensor_msgs/image_impls/AbstractImage.py @@ -14,10 +14,10 @@ from __future__ import annotations -import base64 -import os from abc import ABC, abstractmethod +import base64 from enum import Enum +import os from typing import Any import cv2 @@ -148,28 +148,28 @@ def to_opencv(self) -> np.ndarray: # pragma: no cover - abstract ... @abstractmethod - def to_rgb(self) -> "AbstractImage": # pragma: no cover - abstract + def to_rgb(self) -> AbstractImage: # pragma: no cover - abstract ... @abstractmethod - def to_bgr(self) -> "AbstractImage": # pragma: no cover - abstract + def to_bgr(self) -> AbstractImage: # pragma: no cover - abstract ... @abstractmethod - def to_grayscale(self) -> "AbstractImage": # pragma: no cover - abstract + def to_grayscale(self) -> AbstractImage: # pragma: no cover - abstract ... @abstractmethod def resize( self, width: int, height: int, interpolation: int = cv2.INTER_LINEAR - ) -> "AbstractImage": # pragma: no cover - abstract + ) -> AbstractImage: # pragma: no cover - abstract ... @abstractmethod def sharpness(self) -> float: # pragma: no cover - abstract ... - def copy(self) -> "AbstractImage": + def copy(self) -> AbstractImage: return self.__class__( data=self.data.copy(), format=self.format, frame_id=self.frame_id, ts=self.ts ) # type: ignore diff --git a/dimos/msgs/sensor_msgs/image_impls/CudaImage.py b/dimos/msgs/sensor_msgs/image_impls/CudaImage.py index 58ebaf621d..3067138d36 100644 --- a/dimos/msgs/sensor_msgs/image_impls/CudaImage.py +++ b/dimos/msgs/sensor_msgs/image_impls/CudaImage.py @@ -14,27 +14,27 @@ from __future__ import annotations -import time from dataclasses import dataclass, field -from typing import Optional, Tuple +import time import cv2 import numpy as np from dimos.msgs.sensor_msgs.image_impls.AbstractImage import ( + HAS_CUDA, AbstractImage, ImageFormat, - HAS_CUDA, + _ascontig, _is_cu, _to_cpu, - _ascontig, ) -from dimos.msgs.sensor_msgs.image_impls.NumpyImage import NumpyImage try: import cupy as cp # type: ignore - from cupyx.scipy import ndimage as cndimage # type: ignore - from cupyx.scipy import signal as csignal # type: ignore + from cupyx.scipy import ( + ndimage as cndimage, # type: ignore + signal as csignal, # type: ignore + ) except Exception: # pragma: no cover cp = None # type: ignore cndimage = None # type: ignore @@ -267,7 +267,7 @@ _pnp_kernel = _mod.get_function("pnp_gn_batch") -def _solve_pnp_cuda_kernel(obj, img, K, iterations=15, damping=1e-6): +def _solve_pnp_cuda_kernel(obj, img, K, iterations: int = 15, damping: float = 1e-6): if cp is None: raise RuntimeError("CuPy/CUDA not available") @@ -466,7 +466,7 @@ def _skew(v): I = I[None, :, :] if n == 1 else xp.broadcast_to(I, (n, 3, 3)) KK = xp.matmul(K, K) out = I + A * K + B * KK - return out.reshape(orig_shape + (3, 3)) if orig_shape else out[0] + return out.reshape((*orig_shape, 3, 3)) if orig_shape else out[0] mat = arr if mat.shape[-2:] != (3, 3): @@ -503,13 +503,13 @@ def _skew(v): axis = axis / axis_norm[:, None] r_pi = theta[:, None] * axis r = xp.where(pi_mask[:, None], r_pi, r) - out = r.reshape(orig_shape + (3,)) if orig_shape else r[0] + out = r.reshape((*orig_shape, 3)) if orig_shape else r[0] return out def _undistort_points_cuda( - img_px: "cp.ndarray", K: "cp.ndarray", dist: "cp.ndarray", iterations: int = 8 -) -> "cp.ndarray": + img_px: cp.ndarray, K: cp.ndarray, dist: cp.ndarray, iterations: int = 8 +) -> cp.ndarray: """Iteratively undistort pixel coordinates on device (Brown–Conrady). Returns pixel coordinates after undistortion (fx*xu+cx, fy*yu+cy). @@ -570,7 +570,7 @@ def to_opencv(self) -> np.ndarray: return _to_cpu(self.to_bgr().data) return _to_cpu(self.data) - def to_rgb(self) -> "CudaImage": + def to_rgb(self) -> CudaImage: if self.format == ImageFormat.RGB: return self.copy() # type: ignore if self.format == ImageFormat.BGR: @@ -588,7 +588,7 @@ def to_rgb(self) -> "CudaImage": return CudaImage(_gray_to_rgb_cuda(gray8), ImageFormat.RGB, self.frame_id, self.ts) return self.copy() # type: ignore - def to_bgr(self) -> "CudaImage": + def to_bgr(self) -> CudaImage: if self.format == ImageFormat.BGR: return self.copy() # type: ignore if self.format == ImageFormat.RGB: @@ -613,7 +613,7 @@ def to_bgr(self) -> "CudaImage": ) return self.copy() # type: ignore - def to_grayscale(self) -> "CudaImage": + def to_grayscale(self) -> CudaImage: if self.format in (ImageFormat.GRAY, ImageFormat.GRAY16, ImageFormat.DEPTH): return self.copy() # type: ignore if self.format == ImageFormat.BGR: @@ -634,12 +634,12 @@ def to_grayscale(self) -> "CudaImage": return CudaImage(_rgb_to_gray_cuda(rgb), ImageFormat.GRAY, self.frame_id, self.ts) raise ValueError(f"Unsupported format: {self.format}") - def resize(self, width: int, height: int, interpolation: int = cv2.INTER_LINEAR) -> "CudaImage": + def resize(self, width: int, height: int, interpolation: int = cv2.INTER_LINEAR) -> CudaImage: return CudaImage( _resize_bilinear_hwc_cuda(self.data, height, width), self.format, self.frame_id, self.ts ) - def crop(self, x: int, y: int, width: int, height: int) -> "CudaImage": + def crop(self, x: int, y: int, width: int, height: int) -> CudaImage: """Crop the image to the specified region. Args: @@ -710,7 +710,7 @@ def create_csrt_tracker(self, bbox: BBox): raise ValueError("Invalid bbox for CUDA tracker") return _CudaTemplateTracker(tmpl, x0=x, y0=y) - def csrt_update(self, tracker) -> Tuple[bool, Tuple[int, int, int, int]]: + def csrt_update(self, tracker) -> tuple[bool, tuple[int, int, int, int]]: if not isinstance(tracker, _CudaTemplateTracker): raise TypeError("Expected CUDA tracker instance") gray = self.to_grayscale().data.astype(cp.float32) @@ -723,9 +723,9 @@ def solve_pnp( object_points: np.ndarray, image_points: np.ndarray, camera_matrix: np.ndarray, - dist_coeffs: Optional[np.ndarray] = None, + dist_coeffs: np.ndarray | None = None, flags: int = cv2.SOLVEPNP_ITERATIVE, - ) -> Tuple[bool, np.ndarray, np.ndarray]: + ) -> tuple[bool, np.ndarray, np.ndarray]: if not HAS_CUDA or cp is None or (dist_coeffs is not None and np.any(dist_coeffs)): obj = np.asarray(object_points, dtype=np.float32).reshape(-1, 3) img = np.asarray(image_points, dtype=np.float32).reshape(-1, 2) @@ -743,10 +743,10 @@ def solve_pnp_batch( object_points_batch: np.ndarray, image_points_batch: np.ndarray, camera_matrix: np.ndarray, - dist_coeffs: Optional[np.ndarray] = None, + dist_coeffs: np.ndarray | None = None, iterations: int = 15, damping: float = 1e-6, - ) -> Tuple[np.ndarray, np.ndarray]: + ) -> tuple[np.ndarray, np.ndarray]: """Batched PnP (each block = one instance).""" if not HAS_CUDA or cp is None or (dist_coeffs is not None and np.any(dist_coeffs)): obj = np.asarray(object_points_batch, dtype=np.float32) @@ -792,12 +792,12 @@ def solve_pnp_ransac( object_points: np.ndarray, image_points: np.ndarray, camera_matrix: np.ndarray, - dist_coeffs: Optional[np.ndarray] = None, + dist_coeffs: np.ndarray | None = None, iterations_count: int = 100, reprojection_error: float = 3.0, confidence: float = 0.99, min_sample: int = 6, - ) -> Tuple[bool, np.ndarray, np.ndarray, np.ndarray]: + ) -> tuple[bool, np.ndarray, np.ndarray, np.ndarray]: """RANSAC with CUDA PnP solver.""" if not HAS_CUDA or cp is None or (dist_coeffs is not None and np.any(dist_coeffs)): obj = np.asarray(object_points, dtype=np.float32) @@ -829,7 +829,7 @@ def solve_pnp_ransac( N = obj.shape[0] rng = cp.random.RandomState(1234) best_inliers = -1 - best_r, best_t, best_mask = None, None, None + _best_r, _best_t, best_mask = None, None, None for _ in range(iterations_count): idx = rng.choice(N, size=min_sample, replace=False) @@ -843,7 +843,7 @@ def solve_pnp_ransac( mask = (err < reprojection_error).astype(cp.uint8) inliers = int(mask.sum()) if inliers > best_inliers: - best_inliers, best_r, best_t, best_mask = inliers, rvec, tvec, mask + best_inliers, _best_r, _best_t, best_mask = inliers, rvec, tvec, mask if inliers >= int(confidence * N): break @@ -857,13 +857,13 @@ def solve_pnp_ransac( class _CudaTemplateTracker: def __init__( self, - tmpl: "cp.ndarray", + tmpl: cp.ndarray, scale_step: float = 1.05, lr: float = 0.1, search_radius: int = 16, x0: int = 0, y0: int = 0, - ): + ) -> None: self.tmpl = tmpl.astype(cp.float32) self.h, self.w = int(tmpl.shape[0]), int(tmpl.shape[1]) self.scale_step = float(scale_step) @@ -877,7 +877,7 @@ def __init__( self.y = int(y0) self.x = int(x0) - def update(self, gray: "cp.ndarray"): + def update(self, gray: cp.ndarray): H, W = int(gray.shape[0]), int(gray.shape[1]) r = self.search_radius x0 = max(0, self.x - r) @@ -891,8 +891,8 @@ def update(self, gray: "cp.ndarray"): best = (self.x, self.y, self.w, self.h) best_score = -1e9 for s in (1.0 / self.scale_step, 1.0, self.scale_step): - th = max(1, int(round(self.h * s))) - tw = max(1, int(round(self.w * s))) + th = max(1, round(self.h * s)) + tw = max(1, round(self.w * s)) tmpl_s = _resize_bilinear_hwc_cuda(self.tmpl, th, tw) if tmpl_s.ndim == 3: tmpl_s = tmpl_s[..., 0] diff --git a/dimos/msgs/sensor_msgs/image_impls/NumpyImage.py b/dimos/msgs/sensor_msgs/image_impls/NumpyImage.py index 3431b11295..d75adc66ea 100644 --- a/dimos/msgs/sensor_msgs/image_impls/NumpyImage.py +++ b/dimos/msgs/sensor_msgs/image_impls/NumpyImage.py @@ -14,9 +14,8 @@ from __future__ import annotations -import time from dataclasses import dataclass, field -from typing import Optional, Tuple +import time import cv2 import numpy as np @@ -61,7 +60,7 @@ def to_opencv(self) -> np.ndarray: return arr raise ValueError(f"Unsupported format: {self.format}") - def to_rgb(self) -> "NumpyImage": + def to_rgb(self) -> NumpyImage: if self.format == ImageFormat.RGB: return self.copy() # type: ignore arr = self.data @@ -80,7 +79,7 @@ def to_rgb(self) -> "NumpyImage": return NumpyImage(rgb, ImageFormat.RGB, self.frame_id, self.ts) return self.copy() # type: ignore - def to_bgr(self) -> "NumpyImage": + def to_bgr(self) -> NumpyImage: if self.format == ImageFormat.BGR: return self.copy() # type: ignore arr = self.data @@ -103,7 +102,7 @@ def to_bgr(self) -> "NumpyImage": ) return self.copy() # type: ignore - def to_grayscale(self) -> "NumpyImage": + def to_grayscale(self) -> NumpyImage: if self.format in (ImageFormat.GRAY, ImageFormat.GRAY16, ImageFormat.DEPTH): return self.copy() # type: ignore if self.format == ImageFormat.BGR: @@ -127,9 +126,7 @@ def to_grayscale(self) -> "NumpyImage": ) raise ValueError(f"Unsupported format: {self.format}") - def resize( - self, width: int, height: int, interpolation: int = cv2.INTER_LINEAR - ) -> "NumpyImage": + def resize(self, width: int, height: int, interpolation: int = cv2.INTER_LINEAR) -> NumpyImage: return NumpyImage( cv2.resize(self.data, (width, height), interpolation=interpolation), self.format, @@ -137,7 +134,7 @@ def resize( self.ts, ) - def crop(self, x: int, y: int, width: int, height: int) -> "NumpyImage": + def crop(self, x: int, y: int, width: int, height: int) -> NumpyImage: """Crop the image to the specified region. Args: @@ -185,9 +182,9 @@ def solve_pnp( object_points: np.ndarray, image_points: np.ndarray, camera_matrix: np.ndarray, - dist_coeffs: Optional[np.ndarray] = None, + dist_coeffs: np.ndarray | None = None, flags: int = cv2.SOLVEPNP_ITERATIVE, - ) -> Tuple[bool, np.ndarray, np.ndarray]: + ) -> tuple[bool, np.ndarray, np.ndarray]: obj = np.asarray(object_points, dtype=np.float32).reshape(-1, 3) img = np.asarray(image_points, dtype=np.float32).reshape(-1, 2) K = np.asarray(camera_matrix, dtype=np.float64) @@ -195,7 +192,7 @@ def solve_pnp( ok, rvec, tvec = cv2.solvePnP(obj, img, K, dist, flags=flags) return bool(ok), rvec.astype(np.float64), tvec.astype(np.float64) - def create_csrt_tracker(self, bbox: Tuple[int, int, int, int]): + def create_csrt_tracker(self, bbox: tuple[int, int, int, int]): tracker = None if hasattr(cv2, "legacy") and hasattr(cv2.legacy, "TrackerCSRT_create"): tracker = cv2.legacy.TrackerCSRT_create() @@ -208,7 +205,7 @@ def create_csrt_tracker(self, bbox: Tuple[int, int, int, int]): raise RuntimeError("Failed to initialize CSRT tracker") return tracker - def csrt_update(self, tracker) -> Tuple[bool, Tuple[int, int, int, int]]: + def csrt_update(self, tracker) -> tuple[bool, tuple[int, int, int, int]]: ok, box = tracker.update(self.to_bgr().to_opencv()) if not ok: return False, (0, 0, 0, 0) @@ -220,12 +217,12 @@ def solve_pnp_ransac( object_points: np.ndarray, image_points: np.ndarray, camera_matrix: np.ndarray, - dist_coeffs: Optional[np.ndarray] = None, + dist_coeffs: np.ndarray | None = None, iterations_count: int = 100, reprojection_error: float = 3.0, confidence: float = 0.99, min_sample: int = 6, - ) -> Tuple[bool, np.ndarray, np.ndarray, np.ndarray]: + ) -> tuple[bool, np.ndarray, np.ndarray, np.ndarray]: obj = np.asarray(object_points, dtype=np.float32).reshape(-1, 3) img = np.asarray(image_points, dtype=np.float32).reshape(-1, 2) K = np.asarray(camera_matrix, dtype=np.float64) diff --git a/dimos/msgs/sensor_msgs/image_impls/test_image_backend_utils.py b/dimos/msgs/sensor_msgs/image_impls/test_image_backend_utils.py index 810cedf5f1..c226e36bf0 100644 --- a/dimos/msgs/sensor_msgs/image_impls/test_image_backend_utils.py +++ b/dimos/msgs/sensor_msgs/image_impls/test_image_backend_utils.py @@ -18,8 +18,6 @@ from dimos.msgs.sensor_msgs import Image, ImageFormat try: - import cupy as cp - HAS_CUDA = True print("Running image backend utils tests with CUDA/CuPy support (GPU mode)") except: @@ -27,13 +25,13 @@ print("Running image backend utils tests in CPU-only mode") from dimos.perception.common.utils import ( - rectify_image, - project_3d_points_to_2d, - project_2d_points_to_3d, colorize_depth, draw_bounding_box, - draw_segmentation_mask, draw_object_detection_visualization, + draw_segmentation_mask, + project_2d_points_to_3d, + project_3d_points_to_2d, + rectify_image, ) @@ -57,7 +55,7 @@ def _has_cupy() -> bool: @pytest.mark.parametrize( "shape,fmt", [((64, 64, 3), ImageFormat.BGR), ((64, 64), ImageFormat.GRAY)] ) -def test_rectify_image_cpu(shape, fmt): +def test_rectify_image_cpu(shape, fmt) -> None: arr = (np.random.rand(*shape) * (255 if fmt != ImageFormat.GRAY else 65535)).astype( np.uint8 if fmt != ImageFormat.GRAY else np.uint16 ) @@ -79,7 +77,7 @@ def test_rectify_image_cpu(shape, fmt): @pytest.mark.parametrize( "shape,fmt", [((32, 32, 3), ImageFormat.BGR), ((32, 32), ImageFormat.GRAY)] ) -def test_rectify_image_gpu_parity(shape, fmt): +def test_rectify_image_gpu_parity(shape, fmt) -> None: import cupy as cp # type: ignore arr_np = (np.random.rand(*shape) * (255 if fmt != ImageFormat.GRAY else 65535)).astype( @@ -102,7 +100,7 @@ def test_rectify_image_gpu_parity(shape, fmt): @pytest.mark.skipif(not _has_cupy(), reason="CuPy/CUDA not available") -def test_rectify_image_gpu_nonzero_dist_close(): +def test_rectify_image_gpu_nonzero_dist_close() -> None: import cupy as cp # type: ignore H, W = 64, 96 @@ -135,7 +133,7 @@ def test_rectify_image_gpu_nonzero_dist_close(): ) -def test_project_roundtrip_cpu(): +def test_project_roundtrip_cpu() -> None: pts3d = np.array([[0.1, 0.2, 1.0], [0.0, 0.0, 2.0], [0.5, -0.3, 3.0]], dtype=np.float32) fx, fy, cx, cy = 200.0, 220.0, 64.0, 48.0 K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1.0]], dtype=np.float64) @@ -149,7 +147,7 @@ def test_project_roundtrip_cpu(): @pytest.mark.skipif(not _has_cupy(), reason="CuPy/CUDA not available") -def test_project_parity_gpu_cpu(): +def test_project_parity_gpu_cpu() -> None: import cupy as cp # type: ignore pts3d_np = np.array([[0.1, 0.2, 1.0], [0.0, 0.0, 2.0], [0.5, -0.3, 3.0]], dtype=np.float32) @@ -168,7 +166,7 @@ def test_project_parity_gpu_cpu(): @pytest.mark.skipif(not _has_cupy(), reason="CuPy/CUDA not available") -def test_project_parity_gpu_cpu_random(): +def test_project_parity_gpu_cpu_random() -> None: import cupy as cp # type: ignore rng = np.random.RandomState(0) @@ -196,7 +194,7 @@ def test_project_parity_gpu_cpu_random(): assert pts3d_cpu.shape == cp.asnumpy(pts3d_gpu).shape -def test_colorize_depth_cpu(): +def test_colorize_depth_cpu() -> None: depth = np.zeros((32, 48), dtype=np.float32) depth[8:16, 12:24] = 1.5 out = colorize_depth(depth, max_depth=3.0, overlay_stats=False) @@ -206,7 +204,7 @@ def test_colorize_depth_cpu(): @pytest.mark.skipif(not _has_cupy(), reason="CuPy/CUDA not available") -def test_colorize_depth_gpu_parity(): +def test_colorize_depth_gpu_parity() -> None: import cupy as cp # type: ignore depth_np = np.zeros((16, 20), dtype=np.float32) @@ -216,7 +214,7 @@ def test_colorize_depth_gpu_parity(): np.testing.assert_array_equal(cp.asnumpy(out_gpu), out_cpu) -def test_draw_bounding_box_cpu(): +def test_draw_bounding_box_cpu() -> None: img = np.zeros((20, 30, 3), dtype=np.uint8) out = draw_bounding_box(img, [2, 3, 10, 12], color=(255, 0, 0), thickness=1) assert isinstance(out, np.ndarray) @@ -225,7 +223,7 @@ def test_draw_bounding_box_cpu(): @pytest.mark.skipif(not _has_cupy(), reason="CuPy/CUDA not available") -def test_draw_bounding_box_gpu_parity(): +def test_draw_bounding_box_gpu_parity() -> None: import cupy as cp # type: ignore img_np = np.zeros((20, 30, 3), dtype=np.uint8) @@ -235,7 +233,7 @@ def test_draw_bounding_box_gpu_parity(): np.testing.assert_array_equal(cp.asnumpy(out_gpu), out_cpu) -def test_draw_segmentation_mask_cpu(): +def test_draw_segmentation_mask_cpu() -> None: img = np.zeros((20, 30, 3), dtype=np.uint8) mask = np.zeros((20, 30), dtype=np.uint8) mask[5:10, 8:15] = 1 @@ -244,7 +242,7 @@ def test_draw_segmentation_mask_cpu(): @pytest.mark.skipif(not _has_cupy(), reason="CuPy/CUDA not available") -def test_draw_segmentation_mask_gpu_parity(): +def test_draw_segmentation_mask_gpu_parity() -> None: import cupy as cp # type: ignore img_np = np.zeros((20, 30, 3), dtype=np.uint8) @@ -257,7 +255,7 @@ def test_draw_segmentation_mask_gpu_parity(): np.testing.assert_array_equal(cp.asnumpy(out_gpu), out_cpu) -def test_draw_object_detection_visualization_cpu(): +def test_draw_object_detection_visualization_cpu() -> None: img = np.zeros((30, 40, 3), dtype=np.uint8) objects = [ { @@ -272,7 +270,7 @@ def test_draw_object_detection_visualization_cpu(): @pytest.mark.skipif(not _has_cupy(), reason="CuPy/CUDA not available") -def test_draw_object_detection_visualization_gpu_parity(): +def test_draw_object_detection_visualization_gpu_parity() -> None: import cupy as cp # type: ignore img_np = np.zeros((30, 40, 3), dtype=np.uint8) diff --git a/dimos/msgs/sensor_msgs/image_impls/test_image_backends.py b/dimos/msgs/sensor_msgs/image_impls/test_image_backends.py index 0e19a24167..d8012a8f53 100644 --- a/dimos/msgs/sensor_msgs/image_impls/test_image_backends.py +++ b/dimos/msgs/sensor_msgs/image_impls/test_image_backends.py @@ -87,7 +87,7 @@ def alloc_timer(request): """Helper fixture for adaptive testing with optional GPU support.""" def _alloc( - arr: np.ndarray, fmt: ImageFormat, *, to_cuda: bool = None, label: str | None = None + arr: np.ndarray, fmt: ImageFormat, *, to_cuda: bool | None = None, label: str | None = None ): tag = label or request.node.name @@ -126,7 +126,7 @@ def _alloc( ((64, 64), ImageFormat.GRAY), ], ) -def test_color_conversions(shape, fmt, alloc_timer): +def test_color_conversions(shape, fmt, alloc_timer) -> None: """Test color conversions with NumpyImage always, add CudaImage parity when available.""" arr = _prepare_image(fmt, shape) cpu, gpu, _, _ = alloc_timer(arr, fmt) @@ -147,7 +147,7 @@ def test_color_conversions(shape, fmt, alloc_timer): assert np.array_equal(cpu_round, gpu_round) -def test_grayscale(alloc_timer): +def test_grayscale(alloc_timer) -> None: """Test grayscale conversion with NumpyImage always, add CudaImage parity when available.""" arr = _prepare_image(ImageFormat.BGR, (48, 32, 3)) cpu, gpu, _, _ = alloc_timer(arr, ImageFormat.BGR) @@ -168,7 +168,7 @@ def test_grayscale(alloc_timer): @pytest.mark.parametrize("fmt", [ImageFormat.BGR, ImageFormat.RGB, ImageFormat.BGRA]) -def test_resize(fmt, alloc_timer): +def test_resize(fmt, alloc_timer) -> None: """Test resize with NumpyImage always, add CudaImage parity when available.""" shape = (60, 80, 3) if fmt in (ImageFormat.BGR, ImageFormat.RGB) else (60, 80, 4) arr = _prepare_image(fmt, shape) @@ -192,7 +192,7 @@ def test_resize(fmt, alloc_timer): assert np.max(np.abs(cpu_res.astype(np.int16) - gpu_res.astype(np.int16))) <= 1 -def test_perf_alloc(alloc_timer): +def test_perf_alloc(alloc_timer) -> None: """Test allocation performance with NumpyImage always, add CudaImage when available.""" arr = _prepare_image(ImageFormat.BGR, (480, 640, 3)) alloc_timer(arr, ImageFormat.BGR, label="test_perf_alloc-setup") @@ -218,7 +218,7 @@ def test_perf_alloc(alloc_timer): print(f"alloc (avg per call) cpu={cpu_t:.6f}s") -def test_sharpness(alloc_timer): +def test_sharpness(alloc_timer) -> None: """Test sharpness computation with NumpyImage always, add CudaImage parity when available.""" arr = _prepare_image(ImageFormat.BGR, (64, 64, 3)) cpu, gpu, _, _ = alloc_timer(arr, ImageFormat.BGR) @@ -235,7 +235,7 @@ def test_sharpness(alloc_timer): assert abs(s_cpu - s_gpu) < 5e-2 -def test_to_opencv(alloc_timer): +def test_to_opencv(alloc_timer) -> None: """Test to_opencv conversion with NumpyImage always, add CudaImage parity when available.""" # BGRA should drop alpha and produce BGR arr = _prepare_image(ImageFormat.BGRA, (32, 32, 4)) @@ -254,7 +254,7 @@ def test_to_opencv(alloc_timer): assert np.array_equal(cpu_bgr, gpu_bgr) -def test_solve_pnp(alloc_timer): +def test_solve_pnp(alloc_timer) -> None: """Test solve_pnp with NumpyImage always, add CudaImage parity when available.""" # Synthetic camera and 3D points K = np.array([[400.0, 0.0, 32.0], [0.0, 400.0, 24.0], [0.0, 0.0, 1.0]], dtype=np.float64) @@ -304,7 +304,7 @@ def test_solve_pnp(alloc_timer): assert err_gpu.max() < 1e-2 -def test_perf_grayscale(alloc_timer): +def test_perf_grayscale(alloc_timer) -> None: """Test grayscale performance with NumpyImage always, add CudaImage when available.""" arr = _prepare_image(ImageFormat.BGR, (480, 640, 3)) cpu, gpu, _, _ = alloc_timer(arr, ImageFormat.BGR, label="test_perf_grayscale-setup") @@ -330,7 +330,7 @@ def test_perf_grayscale(alloc_timer): print(f"grayscale (avg per call) cpu={cpu_t:.6f}s") -def test_perf_resize(alloc_timer): +def test_perf_resize(alloc_timer) -> None: """Test resize performance with NumpyImage always, add CudaImage when available.""" arr = _prepare_image(ImageFormat.BGR, (480, 640, 3)) cpu, gpu, _, _ = alloc_timer(arr, ImageFormat.BGR, label="test_perf_resize-setup") @@ -356,7 +356,7 @@ def test_perf_resize(alloc_timer): print(f"resize (avg per call) cpu={cpu_t:.6f}s") -def test_perf_sharpness(alloc_timer): +def test_perf_sharpness(alloc_timer) -> None: """Test sharpness performance with NumpyImage always, add CudaImage when available.""" arr = _prepare_image(ImageFormat.BGR, (480, 640, 3)) cpu, gpu, _, _ = alloc_timer(arr, ImageFormat.BGR, label="test_perf_sharpness-setup") @@ -382,7 +382,7 @@ def test_perf_sharpness(alloc_timer): print(f"sharpness (avg per call) cpu={cpu_t:.6f}s") -def test_perf_solvepnp(alloc_timer): +def test_perf_solvepnp(alloc_timer) -> None: """Test solve_pnp performance with NumpyImage always, add CudaImage when available.""" K = np.array([[600.0, 0.0, 320.0], [0.0, 600.0, 240.0], [0.0, 0.0, 1.0]], dtype=np.float64) dist = None @@ -419,7 +419,7 @@ def test_perf_solvepnp(alloc_timer): # this test is failing with # raise RuntimeError("OpenCV CSRT tracker not available") @pytest.mark.skip -def test_perf_tracker(alloc_timer): +def test_perf_tracker(alloc_timer) -> None: """Test tracker performance with NumpyImage always, add CudaImage when available.""" # Don't check - just let it fail if CSRT isn't available @@ -467,7 +467,7 @@ def test_perf_tracker(alloc_timer): # this test is failing with # raise RuntimeError("OpenCV CSRT tracker not available") @pytest.mark.skip -def test_csrt_tracker(alloc_timer): +def test_csrt_tracker(alloc_timer) -> None: """Test CSRT tracker with NumpyImage always, add CudaImage parity when available.""" # Don't check - just let it fail if CSRT isn't available @@ -500,7 +500,7 @@ def test_csrt_tracker(alloc_timer): # Compare to ground-truth expected bbox expected = (x0 + dx, y0 + dy, w0, h0) - err_cpu = sum(abs(a - b) for a, b in zip(bbox_cpu, expected)) + err_cpu = sum(abs(a - b) for a, b in zip(bbox_cpu, expected, strict=False)) assert err_cpu <= 8 # Optionally test GPU parity when CUDA is available @@ -509,11 +509,11 @@ def test_csrt_tracker(alloc_timer): ok_gpu, bbox_gpu = gpu2.csrt_update(trk_gpu) assert ok_gpu - err_gpu = sum(abs(a - b) for a, b in zip(bbox_gpu, expected)) + err_gpu = sum(abs(a - b) for a, b in zip(bbox_gpu, expected, strict=False)) assert err_gpu <= 10 # allow some slack for scale/window effects -def test_solve_pnp_ransac(alloc_timer): +def test_solve_pnp_ransac(alloc_timer) -> None: """Test solve_pnp_ransac with NumpyImage always, add CudaImage when available.""" # Camera with distortion K = np.array([[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]], dtype=np.float64) @@ -568,7 +568,7 @@ def test_solve_pnp_ransac(alloc_timer): assert err_gpu.max() < 4.0 -def test_solve_pnp_batch(alloc_timer): +def test_solve_pnp_batch(alloc_timer) -> None: """Test solve_pnp batch processing with NumpyImage always, add CudaImage when available.""" # Note: Batch processing is primarily a GPU feature, but we can still test CPU loop # Generate batched problems @@ -625,7 +625,7 @@ def test_solve_pnp_batch(alloc_timer): print(f"solvePnP-batch (avg per pose) cpu={cpu_t:.6f}s (GPU batch not available)") -def test_nvimgcodec_flag_and_fallback(monkeypatch): +def test_nvimgcodec_flag_and_fallback(monkeypatch) -> None: # Test that to_base64() works with and without nvimgcodec by patching runtime flags import dimos.msgs.sensor_msgs.image_impls.AbstractImage as AbstractImageMod @@ -668,7 +668,7 @@ def test_nvimgcodec_flag_and_fallback(monkeypatch): @pytest.mark.skipif(not HAS_CUDA, reason="CuPy/CUDA not available") -def test_nvimgcodec_gpu_path(monkeypatch): +def test_nvimgcodec_gpu_path(monkeypatch) -> None: """Test nvimgcodec GPU encoding path when CUDA is available. This test specifically verifies that when nvimgcodec is available, @@ -681,7 +681,6 @@ def test_nvimgcodec_gpu_path(monkeypatch): pytest.skip("nvimgcodec library not available") # Save original nvimgcodec module reference - original_nvimgcodec = AbstractImageMod.nvimgcodec # Create a CUDA image and encode using the actual nvimgcodec if available arr = _prepare_image(ImageFormat.BGR, (32, 32, 3)) @@ -709,7 +708,7 @@ def test_nvimgcodec_gpu_path(monkeypatch): @pytest.mark.skipif(not HAS_CUDA, reason="CuPy/CUDA not available") -def test_to_cpu_format_preservation(): +def test_to_cpu_format_preservation() -> None: """Test that to_cpu() preserves image format correctly. This tests the fix for the bug where to_cpu() was using to_opencv() diff --git a/dimos/msgs/sensor_msgs/test_CameraInfo.py b/dimos/msgs/sensor_msgs/test_CameraInfo.py index fe4076a325..c35145255b 100644 --- a/dimos/msgs/sensor_msgs/test_CameraInfo.py +++ b/dimos/msgs/sensor_msgs/test_CameraInfo.py @@ -17,8 +17,7 @@ import pytest try: - from sensor_msgs.msg import CameraInfo as ROSCameraInfo - from sensor_msgs.msg import RegionOfInterest as ROSRegionOfInterest + from sensor_msgs.msg import CameraInfo as ROSCameraInfo, RegionOfInterest as ROSRegionOfInterest from std_msgs.msg import Header as ROSHeader except ImportError: ROSCameraInfo = None @@ -29,7 +28,7 @@ from dimos.utils.path_utils import get_project_root -def test_lcm_encode_decode(): +def test_lcm_encode_decode() -> None: """Test LCM encode/decode preserves CameraInfo data.""" print("Testing CameraInfo LCM encode/decode...") @@ -150,7 +149,7 @@ def test_lcm_encode_decode(): print("✓ LCM encode/decode test passed - all properties preserved!") -def test_numpy_matrix_operations(): +def test_numpy_matrix_operations() -> None: """Test numpy matrix getter/setter operations.""" print("\nTesting numpy matrix operations...") @@ -188,7 +187,7 @@ def test_numpy_matrix_operations(): @pytest.mark.ros -def test_ros_conversion(): +def test_ros_conversion() -> None: """Test ROS message conversion preserves CameraInfo data.""" print("\nTesting ROS CameraInfo conversion...") @@ -336,7 +335,7 @@ def test_ros_conversion(): assert dimos_info.frame_id == "test_camera", ( f"Frame ID not preserved: expected 'test_camera', got '{dimos_info.frame_id}'" ) - assert dimos_info.distortion_model == "plumb_bob", f"Distortion model not preserved" + assert dimos_info.distortion_model == "plumb_bob", "Distortion model not preserved" assert len(dimos_info.D) == 5, ( f"Wrong number of distortion coefficients: expected 5, got {len(dimos_info.D)}" ) @@ -356,7 +355,7 @@ def test_ros_conversion(): print("\n✓ All ROS conversion tests passed!") -def test_equality(): +def test_equality() -> None: """Test CameraInfo equality comparison.""" print("\nTesting CameraInfo equality...") @@ -391,7 +390,7 @@ def test_equality(): print("✓ Equality comparison works correctly") -def test_camera_info_from_yaml(): +def test_camera_info_from_yaml() -> None: """Test loading CameraInfo from YAML file.""" # Get path to the single webcam YAML file @@ -427,7 +426,7 @@ def test_camera_info_from_yaml(): print("✓ CameraInfo loaded successfully from YAML file") -def test_calibration_provider(): +def test_calibration_provider() -> None: """Test CalibrationProvider lazy loading of YAML files.""" # Get the directory containing calibration files (not the file itself) calibration_dir = get_project_root() / "dimos" / "hardware" / "camera" / "zed" diff --git a/dimos/msgs/sensor_msgs/test_Joy.py b/dimos/msgs/sensor_msgs/test_Joy.py index 174b9d8908..ae1b4a6379 100644 --- a/dimos/msgs/sensor_msgs/test_Joy.py +++ b/dimos/msgs/sensor_msgs/test_Joy.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. + import pytest -import time try: from sensor_msgs.msg import Joy as ROSJoy @@ -29,7 +29,7 @@ from dimos.msgs.sensor_msgs.Joy import Joy -def test_lcm_encode_decode(): +def test_lcm_encode_decode() -> None: """Test LCM encode/decode preserves Joy data.""" print("Testing Joy LCM encode/decode...") @@ -58,7 +58,7 @@ def test_lcm_encode_decode(): print("✓ Joy LCM encode/decode test passed") -def test_initialization_methods(): +def test_initialization_methods() -> None: """Test various initialization methods for Joy.""" print("Testing Joy initialization methods...") @@ -98,7 +98,7 @@ def test_initialization_methods(): print("✓ Joy initialization methods test passed") -def test_equality(): +def test_equality() -> None: """Test Joy equality comparison.""" print("Testing Joy equality...") @@ -136,7 +136,7 @@ def test_equality(): print("✓ Joy equality test passed") -def test_string_representation(): +def test_string_representation() -> None: """Test Joy string representations.""" print("Testing Joy string representations...") @@ -166,7 +166,7 @@ def test_string_representation(): @pytest.mark.ros -def test_ros_conversion(): +def test_ros_conversion() -> None: """Test conversion to/from ROS Joy messages.""" print("Testing Joy ROS conversion...") @@ -197,7 +197,7 @@ def test_ros_conversion(): print("✓ Joy ROS conversion test passed") -def test_edge_cases(): +def test_edge_cases() -> None: """Test Joy with edge cases.""" print("Testing Joy edge cases...") @@ -220,7 +220,7 @@ def test_edge_cases(): decoded = Joy.lcm_decode(encoded) # Check axes with floating point tolerance assert len(decoded.axes) == len(many_axes) - for i, (a, b) in enumerate(zip(decoded.axes, many_axes)): + for i, (a, b) in enumerate(zip(decoded.axes, many_axes, strict=False)): assert abs(a - b) < 1e-6, f"Axis {i}: {a} != {b}" assert decoded.buttons == many_buttons diff --git a/dimos/msgs/sensor_msgs/test_PointCloud2.py b/dimos/msgs/sensor_msgs/test_PointCloud2.py index cb18d6fd9d..d51b827fa7 100644 --- a/dimos/msgs/sensor_msgs/test_PointCloud2.py +++ b/dimos/msgs/sensor_msgs/test_PointCloud2.py @@ -13,14 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest -import numpy as np -import struct +import numpy as np +import pytest try: - from sensor_msgs.msg import PointCloud2 as ROSPointCloud2 - from sensor_msgs.msg import PointField as ROSPointField + from sensor_msgs.msg import PointCloud2 as ROSPointCloud2, PointField as ROSPointField from std_msgs.msg import Header as ROSHeader except ImportError: ROSPointCloud2 = None @@ -38,7 +36,7 @@ ROS_AVAILABLE = False -def test_lcm_encode_decode(): +def test_lcm_encode_decode() -> None: """Test LCM encode/decode preserves pointcloud data.""" replay = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) lidar_msg: LidarMessage = replay.load_one("lidar_data_021") @@ -100,7 +98,7 @@ def test_lcm_encode_decode(): @pytest.mark.ros -def test_ros_conversion(): +def test_ros_conversion() -> None: """Test ROS message conversion preserves pointcloud data.""" if not ROS_AVAILABLE: print("ROS packages not available - skipping ROS conversion test") @@ -234,37 +232,37 @@ def test_ros_conversion(): print("\n✓ All ROS conversion tests passed!") -def test_bounding_box_intersects(): +def test_bounding_box_intersects() -> None: """Test bounding_box_intersects method with various scenarios.""" # Test 1: Overlapping boxes pc1 = PointCloud2.from_numpy(np.array([[0, 0, 0], [2, 2, 2]])) pc2 = PointCloud2.from_numpy(np.array([[1, 1, 1], [3, 3, 3]])) - assert pc1.bounding_box_intersects(pc2) == True - assert pc2.bounding_box_intersects(pc1) == True # Should be symmetric + assert pc1.bounding_box_intersects(pc2) + assert pc2.bounding_box_intersects(pc1) # Should be symmetric # Test 2: Non-overlapping boxes pc3 = PointCloud2.from_numpy(np.array([[0, 0, 0], [1, 1, 1]])) pc4 = PointCloud2.from_numpy(np.array([[2, 2, 2], [3, 3, 3]])) - assert pc3.bounding_box_intersects(pc4) == False - assert pc4.bounding_box_intersects(pc3) == False + assert not pc3.bounding_box_intersects(pc4) + assert not pc4.bounding_box_intersects(pc3) # Test 3: Touching boxes (edge case - should be True) pc5 = PointCloud2.from_numpy(np.array([[0, 0, 0], [1, 1, 1]])) pc6 = PointCloud2.from_numpy(np.array([[1, 1, 1], [2, 2, 2]])) - assert pc5.bounding_box_intersects(pc6) == True - assert pc6.bounding_box_intersects(pc5) == True + assert pc5.bounding_box_intersects(pc6) + assert pc6.bounding_box_intersects(pc5) # Test 4: One box completely inside another pc7 = PointCloud2.from_numpy(np.array([[0, 0, 0], [3, 3, 3]])) pc8 = PointCloud2.from_numpy(np.array([[1, 1, 1], [2, 2, 2]])) - assert pc7.bounding_box_intersects(pc8) == True - assert pc8.bounding_box_intersects(pc7) == True + assert pc7.bounding_box_intersects(pc8) + assert pc8.bounding_box_intersects(pc7) # Test 5: Boxes overlapping only in 2 dimensions (not all 3) pc9 = PointCloud2.from_numpy(np.array([[0, 0, 0], [2, 2, 1]])) pc10 = PointCloud2.from_numpy(np.array([[1, 1, 2], [3, 3, 3]])) - assert pc9.bounding_box_intersects(pc10) == False - assert pc10.bounding_box_intersects(pc9) == False + assert not pc9.bounding_box_intersects(pc10) + assert not pc10.bounding_box_intersects(pc9) # Test 6: Real-world detection scenario with floating point coordinates detection1_points = np.array( @@ -277,7 +275,7 @@ def test_bounding_box_intersects(): ) pc_det2 = PointCloud2.from_numpy(detection2_points) - assert pc_det1.bounding_box_intersects(pc_det2) == True + assert pc_det1.bounding_box_intersects(pc_det2) # Test 7: Single point clouds pc_single1 = PointCloud2.from_numpy(np.array([[1.0, 1.0, 1.0]])) @@ -285,14 +283,14 @@ def test_bounding_box_intersects(): pc_single3 = PointCloud2.from_numpy(np.array([[2.0, 2.0, 2.0]])) # Same point should intersect - assert pc_single1.bounding_box_intersects(pc_single2) == True + assert pc_single1.bounding_box_intersects(pc_single2) # Different points should not intersect - assert pc_single1.bounding_box_intersects(pc_single3) == False + assert not pc_single1.bounding_box_intersects(pc_single3) # Test 8: Empty point clouds pc_empty1 = PointCloud2.from_numpy(np.array([]).reshape(0, 3)) pc_empty2 = PointCloud2.from_numpy(np.array([]).reshape(0, 3)) - pc_nonempty = PointCloud2.from_numpy(np.array([[1.0, 1.0, 1.0]])) + PointCloud2.from_numpy(np.array([[1.0, 1.0, 1.0]])) # Empty clouds should handle gracefully (Open3D returns inf bounds) # This might raise an exception or return False - we should handle gracefully diff --git a/dimos/msgs/sensor_msgs/test_image.py b/dimos/msgs/sensor_msgs/test_image.py index 6fa0b9d37b..65237e4a6c 100644 --- a/dimos/msgs/sensor_msgs/test_image.py +++ b/dimos/msgs/sensor_msgs/test_image.py @@ -16,7 +16,7 @@ import pytest from reactivex import operators as ops -from dimos.msgs.sensor_msgs.Image import Image, ImageFormat, sharpness_barrier, sharpness_window +from dimos.msgs.sensor_msgs.Image import Image, ImageFormat, sharpness_barrier from dimos.utils.data import get_data from dimos.utils.testing import TimedSensorReplay @@ -27,7 +27,7 @@ def img(): return Image.from_file(str(image_file_path)) -def test_file_load(img: Image): +def test_file_load(img: Image) -> None: assert isinstance(img.data, np.ndarray) assert img.width == 1024 assert img.height == 771 @@ -41,7 +41,7 @@ def test_file_load(img: Image): assert img.data.flags["C_CONTIGUOUS"] -def test_lcm_encode_decode(img: Image): +def test_lcm_encode_decode(img: Image) -> None: binary_msg = img.lcm_encode() decoded_img = Image.lcm_decode(binary_msg) @@ -50,13 +50,13 @@ def test_lcm_encode_decode(img: Image): assert decoded_img == img -def test_rgb_bgr_conversion(img: Image): +def test_rgb_bgr_conversion(img: Image) -> None: rgb = img.to_rgb() assert not rgb == img assert rgb.to_bgr() == img -def test_opencv_conversion(img: Image): +def test_opencv_conversion(img: Image) -> None: ocv = img.to_opencv() decoded_img = Image.from_opencv(ocv) @@ -66,7 +66,7 @@ def test_opencv_conversion(img: Image): @pytest.mark.tool -def test_sharpness_stream(): +def test_sharpness_stream() -> None: get_data("unitree_office_walk") # Preload data for testing video_store = TimedSensorReplay( "unitree_office_walk/video", autocast=lambda x: Image.from_numpy(x).to_rgb() @@ -80,7 +80,7 @@ def test_sharpness_stream(): return -def test_sharpness_barrier(): +def test_sharpness_barrier() -> None: import time from unittest.mock import MagicMock @@ -110,7 +110,7 @@ def track_input(img): window_contents.append((relative_time, img)) return img - def track_output(img): + def track_output(img) -> None: """Track what sharpness_barrier emits""" emitted_images.append(img) @@ -118,8 +118,6 @@ def track_output(img): # Emit images at 100Hz to get ~5 per window from reactivex import from_iterable, interval - window_duration = 0.05 # 20Hz = 0.05s windows - source = from_iterable(mock_images).pipe( ops.zip(interval(0.01)), # 100Hz emission rate ops.map(lambda x: x[0]), # Extract just the image diff --git a/dimos/msgs/std_msgs/Bool.py b/dimos/msgs/std_msgs/Bool.py index 6af250277e..55751a41eb 100644 --- a/dimos/msgs/std_msgs/Bool.py +++ b/dimos/msgs/std_msgs/Bool.py @@ -15,8 +15,6 @@ """Bool message type.""" -from typing import ClassVar - from dimos_lcm.std_msgs import Bool as LCMBool try: @@ -30,7 +28,7 @@ class Bool(LCMBool): msg_name = "std_msgs.Bool" - def __init__(self, data: bool = False): + def __init__(self, data: bool = False) -> None: """Initialize Bool with data value.""" self.data = data diff --git a/dimos/msgs/std_msgs/Header.py b/dimos/msgs/std_msgs/Header.py index 7b48293a68..1d17913941 100644 --- a/dimos/msgs/std_msgs/Header.py +++ b/dimos/msgs/std_msgs/Header.py @@ -14,11 +14,10 @@ from __future__ import annotations -import time from datetime import datetime +import time -from dimos_lcm.std_msgs import Header as LCMHeader -from dimos_lcm.std_msgs import Time as LCMTime +from dimos_lcm.std_msgs import Header as LCMHeader, Time as LCMTime from plum import dispatch # Import the actual LCM header type that's returned from decoding diff --git a/dimos/msgs/std_msgs/Int32.py b/dimos/msgs/std_msgs/Int32.py index 910d7c375e..0ce2f03f60 100644 --- a/dimos/msgs/std_msgs/Int32.py +++ b/dimos/msgs/std_msgs/Int32.py @@ -18,6 +18,7 @@ """Int32 message type.""" from typing import ClassVar + from dimos_lcm.std_msgs import Int32 as LCMInt32 @@ -26,6 +27,6 @@ class Int32(LCMInt32): msg_name: ClassVar[str] = "std_msgs.Int32" - def __init__(self, data: int = 0): + def __init__(self, data: int = 0) -> None: """Initialize Int32 with data value.""" self.data = data diff --git a/dimos/msgs/std_msgs/Int8.py b/dimos/msgs/std_msgs/Int8.py index e4a4a24e17..d76b479d41 100644 --- a/dimos/msgs/std_msgs/Int8.py +++ b/dimos/msgs/std_msgs/Int8.py @@ -18,6 +18,7 @@ """Int32 message type.""" from typing import ClassVar + from dimos_lcm.std_msgs import Int8 as LCMInt8 try: @@ -31,7 +32,7 @@ class Int8(LCMInt8): msg_name: ClassVar[str] = "std_msgs.Int8" - def __init__(self, data: int = 0): + def __init__(self, data: int = 0) -> None: """Initialize Int8 with data value.""" self.data = data diff --git a/dimos/msgs/std_msgs/__init__.py b/dimos/msgs/std_msgs/__init__.py index d46e2ce9a3..e517ea1864 100644 --- a/dimos/msgs/std_msgs/__init__.py +++ b/dimos/msgs/std_msgs/__init__.py @@ -14,7 +14,7 @@ from .Bool import Bool from .Header import Header -from .Int32 import Int32 from .Int8 import Int8 +from .Int32 import Int32 -__all__ = ["Bool", "Header", "Int32", "Int8"] +__all__ = ["Bool", "Header", "Int8", "Int32"] diff --git a/dimos/msgs/std_msgs/test_header.py b/dimos/msgs/std_msgs/test_header.py index 85ffa0b8c6..314ee5cd37 100644 --- a/dimos/msgs/std_msgs/test_header.py +++ b/dimos/msgs/std_msgs/test_header.py @@ -12,15 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time from datetime import datetime - -import pytest +import time from dimos.msgs.std_msgs import Header -def test_header_initialization_methods(): +def test_header_initialization_methods() -> None: """Test various ways to initialize a Header.""" # Method 1: With timestamp and frame_id @@ -69,7 +67,7 @@ def test_header_initialization_methods(): assert header7.frame_id == "lidar" -def test_header_properties(): +def test_header_properties() -> None: """Test Header property accessors.""" header = Header(1234567890.123456789, "test") @@ -82,7 +80,7 @@ def test_header_properties(): assert abs(dt.timestamp() - 1234567890.123456789) < 1e-6 -def test_header_string_representation(): +def test_header_string_representation() -> None: """Test Header string representations.""" header = Header(100.5, "map", seq=10) diff --git a/dimos/msgs/tf2_msgs/TFMessage.py b/dimos/msgs/tf2_msgs/TFMessage.py index d2bb018c34..5aabfa4b23 100644 --- a/dimos/msgs/tf2_msgs/TFMessage.py +++ b/dimos/msgs/tf2_msgs/TFMessage.py @@ -27,24 +27,23 @@ from __future__ import annotations -from typing import BinaryIO +from typing import TYPE_CHECKING, BinaryIO -from dimos_lcm.geometry_msgs import Transform as LCMTransform -from dimos_lcm.geometry_msgs import TransformStamped as LCMTransformStamped -from dimos_lcm.std_msgs import Header as LCMHeader -from dimos_lcm.std_msgs import Time as LCMTime from dimos_lcm.tf2_msgs import TFMessage as LCMTFMessage try: - from tf2_msgs.msg import TFMessage as ROSTFMessage from geometry_msgs.msg import TransformStamped as ROSTransformStamped + from tf2_msgs.msg import TFMessage as ROSTFMessage except ImportError: ROSTFMessage = None ROSTransformStamped = None +from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.geometry_msgs.Vector3 import Vector3 -from dimos.msgs.geometry_msgs.Quaternion import Quaternion + +if TYPE_CHECKING: + from collections.abc import Iterator class TFMessage: @@ -114,7 +113,7 @@ def __getitem__(self, index: int) -> Transform: """Get transform by index.""" return self.transforms[index] - def __iter__(self): + def __iter__(self) -> Iterator: """Iterate over transforms.""" return iter(self.transforms) @@ -128,7 +127,7 @@ def __str__(self) -> str: return "\n".join(lines) @classmethod - def from_ros_msg(cls, ros_msg: ROSTFMessage) -> "TFMessage": + def from_ros_msg(cls, ros_msg: ROSTFMessage) -> TFMessage: """Create a TFMessage from a ROS tf2_msgs/TFMessage message. Args: diff --git a/dimos/msgs/tf2_msgs/test_TFMessage.py b/dimos/msgs/tf2_msgs/test_TFMessage.py index dfe3400e1c..26c0bac570 100644 --- a/dimos/msgs/tf2_msgs/test_TFMessage.py +++ b/dimos/msgs/tf2_msgs/test_TFMessage.py @@ -15,8 +15,8 @@ import pytest try: - from tf2_msgs.msg import TFMessage as ROSTFMessage from geometry_msgs.msg import TransformStamped as ROSTransformStamped + from tf2_msgs.msg import TFMessage as ROSTFMessage except ImportError: ROSTransformStamped = None ROSTFMessage = None @@ -27,7 +27,7 @@ from dimos.msgs.tf2_msgs import TFMessage -def test_tfmessage_initialization(): +def test_tfmessage_initialization() -> None: """Test TFMessage initialization with Transform objects.""" # Create some transforms tf1 = Transform( @@ -52,14 +52,14 @@ def test_tfmessage_initialization(): assert transforms == [tf1, tf2] -def test_tfmessage_empty(): +def test_tfmessage_empty() -> None: """Test empty TFMessage.""" msg = TFMessage() assert len(msg) == 0 assert list(msg) == [] -def test_tfmessage_add_transform(): +def test_tfmessage_add_transform() -> None: """Test adding transforms to TFMessage.""" msg = TFMessage() @@ -70,7 +70,7 @@ def test_tfmessage_add_transform(): assert msg[0] == tf -def test_tfmessage_lcm_encode_decode(): +def test_tfmessage_lcm_encode_decode() -> None: """Test encoding TFMessage to LCM bytes.""" # Create transforms tf1 = Transform( @@ -118,7 +118,7 @@ def test_tfmessage_lcm_encode_decode(): @pytest.mark.ros -def test_tfmessage_from_ros_msg(): +def test_tfmessage_from_ros_msg() -> None: """Test creating a TFMessage from a ROS TFMessage message.""" ros_msg = ROSTFMessage() @@ -179,7 +179,7 @@ def test_tfmessage_from_ros_msg(): @pytest.mark.ros -def test_tfmessage_to_ros_msg(): +def test_tfmessage_to_ros_msg() -> None: """Test converting a TFMessage to a ROS TFMessage message.""" # Create transforms tf1 = Transform( @@ -230,7 +230,7 @@ def test_tfmessage_to_ros_msg(): @pytest.mark.ros -def test_tfmessage_ros_roundtrip(): +def test_tfmessage_ros_roundtrip() -> None: """Test round-trip conversion between TFMessage and ROS TFMessage.""" # Create transforms with various properties tf1 = Transform( @@ -256,7 +256,7 @@ def test_tfmessage_ros_roundtrip(): assert len(restored) == len(original) - for orig_tf, rest_tf in zip(original, restored): + for orig_tf, rest_tf in zip(original, restored, strict=False): assert rest_tf.frame_id == orig_tf.frame_id assert rest_tf.child_frame_id == orig_tf.child_frame_id assert rest_tf.ts == orig_tf.ts diff --git a/dimos/msgs/tf2_msgs/test_TFMessage_lcmpub.py b/dimos/msgs/tf2_msgs/test_TFMessage_lcmpub.py index bd8259997f..9471673821 100644 --- a/dimos/msgs/tf2_msgs/test_TFMessage_lcmpub.py +++ b/dimos/msgs/tf2_msgs/test_TFMessage_lcmpub.py @@ -13,9 +13,7 @@ # limitations under the License. import time -from dataclasses import dataclass -import numpy as np import pytest from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 @@ -26,8 +24,7 @@ # Publishes a series of transforms representing a robot kinematic chain # to actual LCM messages, foxglove running in parallel should render this @pytest.mark.skip -def test_publish_transforms(): - import tf_lcm_py +def test_publish_transforms() -> None: from dimos_lcm.tf2_msgs import TFMessage as LCMTFMessage lcm = LCM(autoconf=True) diff --git a/dimos/navigation/bbox_navigation.py b/dimos/navigation/bbox_navigation.py index f498f2ec3f..db66ab8349 100644 --- a/dimos/navigation/bbox_navigation.py +++ b/dimos/navigation/bbox_navigation.py @@ -12,14 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.core import Module, In, Out, rpc -from dimos.msgs.vision_msgs import Detection2DArray -from dimos.msgs.geometry_msgs import PoseStamped, Vector3, Quaternion -from dimos_lcm.sensor_msgs import CameraInfo -from dimos.utils.logging_config import setup_logger import logging + +from dimos_lcm.sensor_msgs import CameraInfo from reactivex.disposable import Disposable +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.utils.logging_config import setup_logger + logger = setup_logger(__name__, level=logging.DEBUG) @@ -30,13 +32,13 @@ class BBoxNavigationModule(Module): camera_info: In[CameraInfo] = None goal_request: Out[PoseStamped] = None - def __init__(self, goal_distance: float = 1.0): + def __init__(self, goal_distance: float = 1.0) -> None: super().__init__() self.goal_distance = goal_distance self.camera_intrinsics = None @rpc - def start(self): + def start(self) -> None: unsub = self.camera_info.subscribe( lambda msg: setattr(self, "camera_intrinsics", [msg.K[0], msg.K[4], msg.K[2], msg.K[5]]) ) @@ -49,7 +51,7 @@ def start(self): def stop(self) -> None: super().stop() - def _on_detection(self, det: Detection2DArray): + def _on_detection(self, det: Detection2DArray) -> None: if det.detections_length == 0 or not self.camera_intrinsics: return fx, fy, cx, cy = self.camera_intrinsics diff --git a/dimos/navigation/bt_navigator/goal_validator.py b/dimos/navigation/bt_navigator/goal_validator.py index f43a45969c..f0c4a9ce37 100644 --- a/dimos/navigation/bt_navigator/goal_validator.py +++ b/dimos/navigation/bt_navigator/goal_validator.py @@ -13,10 +13,10 @@ # limitations under the License. from collections import deque -from typing import Optional, Tuple import numpy as np -from dimos.msgs.geometry_msgs import VectorLike, Vector3 + +from dimos.msgs.geometry_msgs import Vector3, VectorLike from dimos.msgs.nav_msgs import CostValues, OccupancyGrid @@ -28,7 +28,7 @@ def find_safe_goal( min_clearance: float = 0.3, max_search_distance: float = 5.0, connectivity_check_radius: int = 3, -) -> Optional[Vector3]: +) -> Vector3 | None: """ Find a safe goal position when the original goal is in collision or too close to obstacles. @@ -87,7 +87,7 @@ def _find_safe_goal_bfs( min_clearance: float, max_search_distance: float, connectivity_check_radius: int, -) -> Optional[Vector3]: +) -> Vector3 | None: """ BFS-based search for nearest safe goal position. This guarantees finding the closest valid position. @@ -151,7 +151,7 @@ def _find_safe_goal_spiral( min_clearance: float, max_search_distance: float, connectivity_check_radius: int, -) -> Optional[Vector3]: +) -> Vector3 | None: """ Spiral search pattern from goal outward. @@ -212,7 +212,7 @@ def _find_safe_goal_voronoi( cost_threshold: int, min_clearance: float, max_search_distance: float, -) -> Optional[Vector3]: +) -> Vector3 | None: """ Find safe position using Voronoi diagram (ridge points equidistant from obstacles). @@ -235,7 +235,6 @@ def _find_safe_goal_voronoi( gx, gy = int(goal_grid.x), int(goal_grid.y) # Create binary obstacle map - obstacle_map = costmap.grid >= cost_threshold free_map = (costmap.grid < cost_threshold) & (costmap.grid != CostValues.UNKNOWN) # Compute distance transform @@ -285,7 +284,7 @@ def _find_safe_goal_gradient( min_clearance: float, max_search_distance: float, connectivity_check_radius: int, -) -> Optional[Vector3]: +) -> Vector3 | None: """ Use gradient descent on the costmap to find a safe position. diff --git a/dimos/navigation/bt_navigator/navigator.py b/dimos/navigation/bt_navigator/navigator.py index df1d50cbf2..782e815bb3 100644 --- a/dimos/navigation/bt_navigator/navigator.py +++ b/dimos/navigation/bt_navigator/navigator.py @@ -18,22 +18,22 @@ Navigator module for coordinating global and local planning. """ +from collections.abc import Callable +from enum import Enum import threading import time -from enum import Enum -from typing import Callable, Optional -from dimos.core import Module, In, Out, rpc +from dimos_lcm.std_msgs import Bool, String +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc from dimos.core.rpc_client import RpcCall from dimos.msgs.geometry_msgs import PoseStamped from dimos.msgs.nav_msgs import OccupancyGrid -from dimos_lcm.std_msgs import String from dimos.navigation.bt_navigator.goal_validator import find_safe_goal from dimos.navigation.bt_navigator.recovery_server import RecoveryServer -from reactivex.disposable import Disposable from dimos.protocol.tf import TF from dimos.utils.logging_config import setup_logger -from dimos_lcm.std_msgs import Bool from dimos.utils.transform_utils import apply_transform logger = setup_logger("dimos.navigation.bt_navigator") @@ -74,10 +74,10 @@ class BehaviorTreeNavigator(Module): def __init__( self, publishing_frequency: float = 1.0, - reset_local_planner: Callable[[], None] = None, - check_goal_reached: Callable[[], bool] = None, + reset_local_planner: Callable[[], None] | None = None, + check_goal_reached: Callable[[], bool] | None = None, **kwargs, - ): + ) -> None: """Initialize the Navigator. Args: @@ -95,19 +95,19 @@ def __init__( self.state_lock = threading.Lock() # Current goal - self.current_goal: Optional[PoseStamped] = None - self.original_goal: Optional[PoseStamped] = None + self.current_goal: PoseStamped | None = None + self.original_goal: PoseStamped | None = None self.goal_lock = threading.Lock() # Goal reached state self._goal_reached = False # Latest data - self.latest_odom: Optional[PoseStamped] = None - self.latest_costmap: Optional[OccupancyGrid] = None + self.latest_odom: PoseStamped | None = None + self.latest_costmap: OccupancyGrid | None = None # Control thread - self.control_thread: Optional[threading.Thread] = None + self.control_thread: threading.Thread | None = None self.stop_event = threading.Event() # TF listener @@ -133,7 +133,7 @@ def set_HolonomicLocalPlanner_is_goal_reached(self, callable: RpcCall) -> None: self.check_goal_reached.set_rpc(self.rpc) @rpc - def start(self): + def start(self) -> None: super().start() # Subscribe to inputs @@ -209,22 +209,22 @@ def get_state(self) -> NavigatorState: """Get the current state of the navigator.""" return self.state - def _on_odom(self, msg: PoseStamped): + def _on_odom(self, msg: PoseStamped) -> None: """Handle incoming odometry messages.""" self.latest_odom = msg if self.state == NavigatorState.FOLLOWING_PATH: self.recovery_server.update_odom(msg) - def _on_goal_request(self, msg: PoseStamped): + def _on_goal_request(self, msg: PoseStamped) -> None: """Handle incoming goal requests.""" self.set_goal(msg) - def _on_costmap(self, msg: OccupancyGrid): + def _on_costmap(self, msg: OccupancyGrid) -> None: """Handle incoming costmap messages.""" self.latest_costmap = msg - def _transform_goal_to_odom_frame(self, goal: PoseStamped) -> Optional[PoseStamped]: + def _transform_goal_to_odom_frame(self, goal: PoseStamped) -> PoseStamped | None: """Transform goal pose to the odometry frame.""" if not goal.frame_id: return goal @@ -270,7 +270,7 @@ def _transform_goal_to_odom_frame(self, goal: PoseStamped) -> Optional[PoseStamp logger.error(f"Failed to transform goal: {e}") return None - def _control_loop(self): + def _control_loop(self) -> None: """Main control loop running in separate thread.""" while not self.stop_event.is_set(): with self.state_lock: diff --git a/dimos/navigation/bt_navigator/recovery_server.py b/dimos/navigation/bt_navigator/recovery_server.py index a5afa3b090..5b05d35de5 100644 --- a/dimos/navigation/bt_navigator/recovery_server.py +++ b/dimos/navigation/bt_navigator/recovery_server.py @@ -18,8 +18,6 @@ Recovery server for handling stuck detection and recovery behaviors. """ -from collections import deque - from dimos.msgs.geometry_msgs import PoseStamped from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import get_distance @@ -39,7 +37,7 @@ def __init__( self, position_threshold: float = 0.2, stuck_duration: float = 3.0, - ): + ) -> None: """Initialize the recovery server. Args: diff --git a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py index 64d238602d..ed5f364a74 100644 --- a/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/test_wavefront_frontier_goal_selector.py @@ -15,10 +15,10 @@ import time import numpy as np +from PIL import ImageDraw import pytest -from PIL import Image, ImageDraw -from dimos.msgs.geometry_msgs import PoseStamped, Vector3 +from dimos.msgs.geometry_msgs import Vector3 from dimos.msgs.nav_msgs import CostValues, OccupancyGrid from dimos.navigation.frontier_exploration.utils import costmap_to_pil_image from dimos.navigation.frontier_exploration.wavefront_frontier_goal_selector import ( @@ -71,13 +71,13 @@ def quick_costmap(): ) class MockLidar: - def __init__(self): + def __init__(self) -> None: self.origin = Vector3(0.0, 0.0, 0.0) return occupancy_grid, MockLidar() -def create_test_costmap(width=40, height=40, resolution=0.1): +def create_test_costmap(width: int = 40, height: int = 40, resolution: float = 0.1): """Create a simple test costmap with free, occupied, and unknown regions. Default size reduced from 100x100 to 40x40 for faster tests. @@ -114,13 +114,13 @@ def create_test_costmap(width=40, height=40, resolution=0.1): # Create a mock lidar message with origin class MockLidar: - def __init__(self): + def __init__(self) -> None: self.origin = Vector3(0.0, 0.0, 0.0) return occupancy_grid, MockLidar() -def test_frontier_detection_with_office_lidar(explorer, quick_costmap): +def test_frontier_detection_with_office_lidar(explorer, quick_costmap) -> None: """Test frontier detection using a test costmap.""" # Get test costmap costmap, first_lidar = quick_costmap @@ -164,7 +164,7 @@ def test_frontier_detection_with_office_lidar(explorer, quick_costmap): explorer.stop() # TODO: this should be a in try-finally -def test_exploration_goal_selection(explorer): +def test_exploration_goal_selection(explorer) -> None: """Test the complete exploration goal selection pipeline.""" # Get test costmap - use regular size for more realistic test costmap, first_lidar = create_test_costmap() @@ -198,7 +198,7 @@ def test_exploration_goal_selection(explorer): explorer.stop() # TODO: this should be a in try-finally -def test_exploration_session_reset(explorer): +def test_exploration_session_reset(explorer) -> None: """Test exploration session reset functionality.""" # Get test costmap costmap, first_lidar = create_test_costmap() @@ -229,7 +229,7 @@ def test_exploration_session_reset(explorer): explorer.stop() # TODO: this should be a in try-finally -def test_frontier_ranking(explorer): +def test_frontier_ranking(explorer) -> None: """Test frontier ranking and scoring logic.""" # Get test costmap costmap, first_lidar = create_test_costmap() @@ -275,7 +275,7 @@ def test_frontier_ranking(explorer): explorer.stop() # TODO: this should be a in try-finally -def test_exploration_with_no_gain_detection(): +def test_exploration_with_no_gain_detection() -> None: """Test information gain detection and exploration termination.""" # Get initial costmap costmap1, first_lidar = create_test_costmap() @@ -313,7 +313,7 @@ def test_exploration_with_no_gain_detection(): @pytest.mark.vis -def test_frontier_detection_visualization(): +def test_frontier_detection_visualization() -> None: """Test frontier detection with visualization (marked with @pytest.mark.vis).""" # Get test costmap costmap, first_lidar = create_test_costmap() @@ -398,7 +398,7 @@ def world_to_image_coords(world_pos: Vector3) -> tuple[int, int]: explorer.stop() -def test_performance_timing(): +def test_performance_timing() -> None: """Test performance by timing frontier detection operations.""" import time @@ -427,7 +427,7 @@ def test_performance_timing(): # Time goal selection start = time.time() - goal = explorer.get_exploration_goal(robot_pose, costmap) + explorer.get_exploration_goal(robot_pose, costmap) goal_time = time.time() - start results.append( diff --git a/dimos/navigation/frontier_exploration/utils.py b/dimos/navigation/frontier_exploration/utils.py index 680af142fb..d307749531 100644 --- a/dimos/navigation/frontier_exploration/utils.py +++ b/dimos/navigation/frontier_exploration/utils.py @@ -18,12 +18,9 @@ import numpy as np from PIL import Image, ImageDraw -from typing import List, Tuple -from dimos.msgs.nav_msgs import OccupancyGrid, CostValues + from dimos.msgs.geometry_msgs import Vector3 -import os -import pickle -import cv2 +from dimos.msgs.nav_msgs import CostValues, OccupancyGrid def costmap_to_pil_image(costmap: OccupancyGrid, scale_factor: int = 2) -> Image.Image: @@ -70,9 +67,9 @@ def costmap_to_pil_image(costmap: OccupancyGrid, scale_factor: int = 2) -> Image def draw_frontiers_on_image( image: Image.Image, costmap: OccupancyGrid, - frontiers: List[Vector3], + frontiers: list[Vector3], scale_factor: int = 2, - unfiltered_frontiers: List[Vector3] = None, + unfiltered_frontiers: list[Vector3] | None = None, ) -> Image.Image: """ Draw frontier points on the costmap image. @@ -90,7 +87,7 @@ def draw_frontiers_on_image( img_copy = image.copy() draw = ImageDraw.Draw(img_copy) - def world_to_image_coords(world_pos: Vector3) -> Tuple[int, int]: + def world_to_image_coords(world_pos: Vector3) -> tuple[int, int]: """Convert world coordinates to image pixel coordinates.""" grid_pos = costmap.world_to_grid(world_pos) # Flip Y coordinate and apply scaling diff --git a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py index a1ce4e8075..71677635f5 100644 --- a/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py +++ b/dimos/navigation/frontier_exploration/wavefront_frontier_goal_selector.py @@ -19,21 +19,20 @@ for autonomous navigation using the dimos Costmap and Vector types. """ -import threading from collections import deque from dataclasses import dataclass from enum import IntFlag -from typing import List, Optional, Tuple +import threading +from dimos_lcm.std_msgs import Bool import numpy as np +from reactivex.disposable import Disposable -from dimos.core import Module, In, Out, rpc +from dimos.core import In, Module, Out, rpc from dimos.msgs.geometry_msgs import PoseStamped, Vector3 -from dimos.msgs.nav_msgs import OccupancyGrid, CostValues +from dimos.msgs.nav_msgs import CostValues, OccupancyGrid from dimos.utils.logging_config import setup_logger -from dimos_lcm.std_msgs import Bool from dimos.utils.transform_utils import get_distance -from reactivex.disposable import Disposable logger = setup_logger("dimos.robot.unitree.frontier_exploration") @@ -60,7 +59,7 @@ class GridPoint: class FrontierCache: """Cache for grid points to avoid duplicate point creation.""" - def __init__(self): + def __init__(self) -> None: self.points = {} def get_point(self, x: int, y: int) -> GridPoint: @@ -70,7 +69,7 @@ def get_point(self, x: int, y: int) -> GridPoint: self.points[key] = GridPoint(x, y) return self.points[key] - def clear(self): + def clear(self) -> None: """Clear the point cache.""" self.points.clear() @@ -111,7 +110,7 @@ def __init__( num_no_gain_attempts: int = 2, goal_timeout: float = 15.0, **kwargs, - ): + ) -> None: """ Initialize the frontier explorer. @@ -138,21 +137,21 @@ def __init__( self.goal_timeout = goal_timeout # Latest data - self.latest_costmap: Optional[OccupancyGrid] = None - self.latest_odometry: Optional[PoseStamped] = None + self.latest_costmap: OccupancyGrid | None = None + self.latest_odometry: PoseStamped | None = None # Goal reached event self.goal_reached_event = threading.Event() # Exploration state self.exploration_active = False - self.exploration_thread: Optional[threading.Thread] = None + self.exploration_thread: threading.Thread | None = None self.stop_event = threading.Event() logger.info("WavefrontFrontierExplorer module initialized") @rpc - def start(self): + def start(self) -> None: super().start() unsub = self.global_costmap.subscribe(self._on_costmap) @@ -178,26 +177,26 @@ def stop(self) -> None: self.stop_exploration() super().stop() - def _on_costmap(self, msg: OccupancyGrid): + def _on_costmap(self, msg: OccupancyGrid) -> None: """Handle incoming costmap messages.""" self.latest_costmap = msg - def _on_odometry(self, msg: PoseStamped): + def _on_odometry(self, msg: PoseStamped) -> None: """Handle incoming odometry messages.""" self.latest_odometry = msg - def _on_goal_reached(self, msg: Bool): + def _on_goal_reached(self, msg: Bool) -> None: """Handle goal reached messages.""" if msg.data: self.goal_reached_event.set() - def _on_explore_cmd(self, msg: Bool): + def _on_explore_cmd(self, msg: Bool) -> None: """Handle exploration command messages.""" if msg.data: logger.info("Received exploration start command via LCM") self.explore() - def _on_stop_explore_cmd(self, msg: Bool): + def _on_stop_explore_cmd(self, msg: Bool) -> None: """Handle stop exploration command messages.""" if msg.data: logger.info("Received exploration stop command via LCM") @@ -217,7 +216,7 @@ def _count_costmap_information(self, costmap: OccupancyGrid) -> int: obstacle_count = np.sum(costmap.grid >= self.occupancy_threshold) return int(free_count + obstacle_count) - def _get_neighbors(self, point: GridPoint, costmap: OccupancyGrid) -> List[GridPoint]: + def _get_neighbors(self, point: GridPoint, costmap: OccupancyGrid) -> list[GridPoint]: """Get valid neighboring points for a given grid point.""" neighbors = [] @@ -263,7 +262,7 @@ def _is_frontier_point(self, point: GridPoint, costmap: OccupancyGrid) -> bool: def _find_free_space( self, start_x: int, start_y: int, costmap: OccupancyGrid - ) -> Tuple[int, int]: + ) -> tuple[int, int]: """ Find the nearest free space point using BFS from the starting position. """ @@ -289,7 +288,7 @@ def _find_free_space( # If no free space found, return original position return (start_x, start_y) - def _compute_centroid(self, frontier_points: List[Vector3]) -> Vector3: + def _compute_centroid(self, frontier_points: list[Vector3]) -> Vector3: """Compute the centroid of a list of frontier points.""" if not frontier_points: return Vector3(0.0, 0.0, 0.0) @@ -300,7 +299,7 @@ def _compute_centroid(self, frontier_points: List[Vector3]) -> Vector3: return Vector3(centroid[0], centroid[1], 0.0) - def detect_frontiers(self, robot_pose: Vector3, costmap: OccupancyGrid) -> List[Vector3]: + def detect_frontiers(self, robot_pose: Vector3, costmap: OccupancyGrid) -> list[Vector3]: """ Main frontier detection algorithm using wavefront exploration. @@ -418,8 +417,8 @@ def detect_frontiers(self, robot_pose: Vector3, costmap: OccupancyGrid) -> List[ return ranked_frontiers def _update_exploration_direction( - self, robot_pose: Vector3, goal_pose: Optional[Vector3] = None - ): + self, robot_pose: Vector3, goal_pose: Vector3 | None = None + ) -> None: """Update the current exploration direction based on robot movement or selected goal.""" if goal_pose is not None: # Calculate direction from robot to goal @@ -567,11 +566,11 @@ def _compute_comprehensive_frontier_score( def _rank_frontiers( self, - frontier_centroids: List[Vector3], - frontier_sizes: List[int], + frontier_centroids: list[Vector3], + frontier_sizes: list[int], robot_pose: Vector3, costmap: OccupancyGrid, - ) -> List[Vector3]: + ) -> list[Vector3]: """ Find the single best frontier using comprehensive scoring and filtering. @@ -609,9 +608,7 @@ def _rank_frontiers( # Extract just the frontiers (remove scores) and return as list return [frontier for frontier, _ in valid_frontiers] - def get_exploration_goal( - self, robot_pose: Vector3, costmap: OccupancyGrid - ) -> Optional[Vector3]: + def get_exploration_goal(self, robot_pose: Vector3, costmap: OccupancyGrid) -> Vector3 | None: """ Get the single best exploration goal using comprehensive frontier scoring. @@ -675,11 +672,11 @@ def get_exploration_goal( self.last_costmap = costmap return None - def mark_explored_goal(self, goal: Vector3): + def mark_explored_goal(self, goal: Vector3) -> None: """Mark a goal as explored.""" self.explored_goals.append(goal) - def reset_exploration_session(self): + def reset_exploration_session(self) -> None: """ Reset all exploration state variables for a new exploration session. @@ -746,7 +743,7 @@ def stop_exploration(self) -> bool: def is_exploration_active(self) -> bool: return self.exploration_active - def _exploration_loop(self): + def _exploration_loop(self) -> None: """Main exploration loop running in separate thread.""" # Track number of goals published goals_published = 0 diff --git a/dimos/navigation/global_planner/__init__.py b/dimos/navigation/global_planner/__init__.py index 9aaf52e11e..275619659b 100644 --- a/dimos/navigation/global_planner/__init__.py +++ b/dimos/navigation/global_planner/__init__.py @@ -1,4 +1,4 @@ -from dimos.navigation.global_planner.planner import AstarPlanner, astar_planner from dimos.navigation.global_planner.algo import astar +from dimos.navigation.global_planner.planner import AstarPlanner, astar_planner -__all__ = ["AstarPlanner", "astar_planner", "astar"] +__all__ = ["AstarPlanner", "astar", "astar_planner"] diff --git a/dimos/navigation/global_planner/algo.py b/dimos/navigation/global_planner/algo.py index 08cae6d147..16f8dc3600 100644 --- a/dimos/navigation/global_planner/algo.py +++ b/dimos/navigation/global_planner/algo.py @@ -13,8 +13,6 @@ # limitations under the License. import heapq -import math -from typing import Optional from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, VectorLike from dimos.msgs.nav_msgs import CostValues, OccupancyGrid, Path @@ -29,7 +27,7 @@ def astar( start: VectorLike = (0.0, 0.0), cost_threshold: int = 90, unknown_penalty: float = 0.8, -) -> Optional[Path]: +) -> Path | None: """ A* path planning algorithm from start to goal position. @@ -99,7 +97,7 @@ def heuristic(x1, y1, x2, y2): while open_set: # Get the node with the lowest f_score - current_f, current = heapq.heappop(open_set) + _current_f, current = heapq.heappop(open_set) current_x, current_y = current # Remove from open set hash diff --git a/dimos/navigation/global_planner/planner.py b/dimos/navigation/global_planner/planner.py index f9df988cfe..89ac134b08 100644 --- a/dimos/navigation/global_planner/planner.py +++ b/dimos/navigation/global_planner/planner.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional + +from reactivex.disposable import Disposable from dimos.core import In, Module, Out, rpc from dimos.msgs.geometry_msgs import Pose, PoseStamped @@ -20,11 +21,11 @@ from dimos.navigation.global_planner.algo import astar from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import euler_to_quaternion -from reactivex.disposable import Disposable logger = setup_logger(__file__) import math + from dimos.msgs.geometry_msgs import Quaternion, Vector3 @@ -148,15 +149,15 @@ class AstarPlanner(Module): # LCM outputs path: Out[Path] = None - def __init__(self): + def __init__(self) -> None: super().__init__() # Latest data - self.latest_costmap: Optional[OccupancyGrid] = None - self.latest_odom: Optional[PoseStamped] = None + self.latest_costmap: OccupancyGrid | None = None + self.latest_odom: PoseStamped | None = None @rpc - def start(self): + def start(self) -> None: super().start() unsub = self.target.subscribe(self._on_target) @@ -174,15 +175,15 @@ def start(self): def stop(self) -> None: super().stop() - def _on_costmap(self, msg: OccupancyGrid): + def _on_costmap(self, msg: OccupancyGrid) -> None: """Handle incoming costmap messages.""" self.latest_costmap = msg - def _on_odom(self, msg: PoseStamped): + def _on_odom(self, msg: PoseStamped) -> None: """Handle incoming odometry messages.""" self.latest_odom = msg - def _on_target(self, msg: PoseStamped): + def _on_target(self, msg: PoseStamped) -> None: """Handle incoming target messages and trigger planning.""" if self.latest_costmap is None or self.latest_odom is None: logger.warning("Cannot plan: missing costmap or odometry data") @@ -194,7 +195,7 @@ def _on_target(self, msg: PoseStamped): path = add_orientations_to_path(path, msg.orientation) self.path.publish(path) - def plan(self, goal: Pose) -> Optional[Path]: + def plan(self, goal: Pose) -> Path | None: """Plan a path from current position to goal.""" if self.latest_costmap is None or self.latest_odom is None: logger.warning("Cannot plan: missing costmap or odometry data") diff --git a/dimos/navigation/local_planner/__init__.py b/dimos/navigation/local_planner/__init__.py index f6b97d6762..9e0f62931a 100644 --- a/dimos/navigation/local_planner/__init__.py +++ b/dimos/navigation/local_planner/__init__.py @@ -1,2 +1,2 @@ -from dimos.navigation.local_planner.local_planner import BaseLocalPlanner from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner +from dimos.navigation.local_planner.local_planner import BaseLocalPlanner diff --git a/dimos/navigation/local_planner/holonomic_local_planner.py b/dimos/navigation/local_planner/holonomic_local_planner.py index 94624fc65e..bd41fe2a8d 100644 --- a/dimos/navigation/local_planner/holonomic_local_planner.py +++ b/dimos/navigation/local_planner/holonomic_local_planner.py @@ -18,13 +18,11 @@ Gradient-Augmented Look-Ahead Pursuit (GLAP) holonomic local planner. """ -from typing import Optional, Tuple - import numpy as np from dimos.msgs.geometry_msgs import Twist, Vector3 from dimos.navigation.local_planner import BaseLocalPlanner -from dimos.utils.transform_utils import quaternion_to_euler, normalize_angle, get_distance +from dimos.utils.transform_utils import get_distance, normalize_angle, quaternion_to_euler class HolonomicLocalPlanner(BaseLocalPlanner): @@ -54,7 +52,7 @@ def __init__( orientation_tolerance: float = 0.2, control_frequency: float = 10.0, **kwargs, - ): + ) -> None: """Initialize the GLAP planner with specified parameters.""" super().__init__( goal_tolerance=goal_tolerance, @@ -73,7 +71,7 @@ def __init__( # Previous velocity for filtering (vx, vy, vtheta) self.v_prev = np.array([0.0, 0.0, 0.0]) - def compute_velocity(self) -> Optional[Twist]: + def compute_velocity(self) -> Twist | None: """ Compute velocity commands using GLAP algorithm. @@ -216,7 +214,7 @@ def _compute_obstacle_repulsion(self, pose: np.ndarray, costmap: np.ndarray) -> def _find_closest_point_on_path( self, pose: np.ndarray, path: np.ndarray - ) -> Tuple[int, np.ndarray]: + ) -> tuple[int, np.ndarray]: """ Find the closest point on the path to current pose. diff --git a/dimos/navigation/local_planner/local_planner.py b/dimos/navigation/local_planner/local_planner.py index ac1a6ea744..0a569f00ed 100644 --- a/dimos/navigation/local_planner/local_planner.py +++ b/dimos/navigation/local_planner/local_planner.py @@ -19,17 +19,17 @@ Subscribes to local costmap, odometry, and path, publishes movement commands. """ +from abc import abstractmethod import threading import time -from abc import abstractmethod -from typing import Optional -from dimos.core import Module, In, Out, rpc -from dimos.msgs.geometry_msgs import Twist, PoseStamped +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped, Twist from dimos.msgs.nav_msgs import OccupancyGrid, Path from dimos.utils.logging_config import setup_logger -from dimos.utils.transform_utils import get_distance, quaternion_to_euler, normalize_angle -from reactivex.disposable import Disposable +from dimos.utils.transform_utils import get_distance, normalize_angle, quaternion_to_euler logger = setup_logger(__file__) @@ -61,7 +61,7 @@ def __init__( orientation_tolerance: float = 0.2, control_frequency: float = 10.0, **kwargs, - ): + ) -> None: """Initialize the local planner module. Args: @@ -78,18 +78,18 @@ def __init__( self.control_period = 1.0 / control_frequency # Latest data - self.latest_costmap: Optional[OccupancyGrid] = None - self.latest_odom: Optional[PoseStamped] = None - self.latest_path: Optional[Path] = None + self.latest_costmap: OccupancyGrid | None = None + self.latest_odom: PoseStamped | None = None + self.latest_path: Path | None = None # Control thread - self.planning_thread: Optional[threading.Thread] = None + self.planning_thread: threading.Thread | None = None self.stop_planning = threading.Event() logger.info("Local planner module initialized") @rpc - def start(self): + def start(self) -> None: super().start() unsub = self.local_costmap.subscribe(self._on_costmap) @@ -106,27 +106,27 @@ def stop(self) -> None: self.cancel_planning() super().stop() - def _on_costmap(self, msg: OccupancyGrid): + def _on_costmap(self, msg: OccupancyGrid) -> None: self.latest_costmap = msg - def _on_odom(self, msg: PoseStamped): + def _on_odom(self, msg: PoseStamped) -> None: self.latest_odom = msg - def _on_path(self, msg: Path): + def _on_path(self, msg: Path) -> None: self.latest_path = msg if msg and len(msg.poses) > 0: if self.planning_thread is None or not self.planning_thread.is_alive(): self._start_planning_thread() - def _start_planning_thread(self): + def _start_planning_thread(self) -> None: """Start the planning thread.""" self.stop_planning.clear() self.planning_thread = threading.Thread(target=self._follow_path_loop, daemon=True) self.planning_thread.start() logger.debug("Started follow path thread") - def _follow_path_loop(self): + def _follow_path_loop(self) -> None: """Main planning loop that runs in a separate thread.""" while not self.stop_planning.is_set(): if self.is_goal_reached(): @@ -140,7 +140,7 @@ def _follow_path_loop(self): time.sleep(self.control_period) - def _plan(self): + def _plan(self) -> None: """Compute and publish velocity command.""" cmd_vel = self.compute_velocity() @@ -148,7 +148,7 @@ def _plan(self): self.cmd_vel.publish(cmd_vel) @abstractmethod - def compute_velocity(self) -> Optional[Twist]: + def compute_velocity(self) -> Twist | None: """ Compute velocity commands based on current costmap, odometry, and path. Must be implemented by derived classes. @@ -189,7 +189,7 @@ def is_goal_reached(self) -> bool: return abs(yaw_error) < self.orientation_tolerance @rpc - def reset(self): + def reset(self) -> None: """Reset the local planner state, clearing the current path.""" # Clear the latest path self.latest_path = None diff --git a/dimos/navigation/local_planner/test_base_local_planner.py b/dimos/navigation/local_planner/test_base_local_planner.py index dc76bca83a..8786b1a925 100644 --- a/dimos/navigation/local_planner/test_base_local_planner.py +++ b/dimos/navigation/local_planner/test_base_local_planner.py @@ -22,7 +22,7 @@ import pytest from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion -from dimos.msgs.nav_msgs import Path, OccupancyGrid +from dimos.msgs.nav_msgs import OccupancyGrid, Path from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner @@ -55,7 +55,7 @@ def empty_costmap(self): costmap.origin.position.y = -5.0 return costmap - def test_straight_path_no_obstacles(self, planner, empty_costmap): + def test_straight_path_no_obstacles(self, planner, empty_costmap) -> None: """Test that planner follows straight path with no obstacles.""" # Set current position at origin planner.latest_odom = PoseStamped() @@ -84,7 +84,7 @@ def test_straight_path_no_obstacles(self, planner, empty_costmap): assert abs(vel.linear.y) < 0.1 # Near zero assert abs(vel.angular.z) < 0.1 # Small angular velocity when aligned with path - def test_obstacle_gradient_repulsion(self, planner): + def test_obstacle_gradient_repulsion(self, planner) -> None: """Test that obstacle gradients create repulsive forces.""" # Set position at origin planner.latest_odom = PoseStamped() @@ -116,7 +116,7 @@ def test_obstacle_gradient_repulsion(self, planner): assert vel is not None assert vel.linear.y > 0.1 # Repulsion pushes north - def test_lowpass_filter(self): + def test_lowpass_filter(self) -> None: """Test that low-pass filter smooths velocity commands.""" # Create planner with alpha=0.5 for filtering planner = HolonomicLocalPlanner( @@ -164,7 +164,7 @@ def test_lowpass_filter(self): assert 0 < vel2.linear.x <= planner.v_max # Should still be positive and within limits planner._close_module() - def test_no_path(self, planner, empty_costmap): + def test_no_path(self, planner, empty_costmap) -> None: """Test that planner returns None when no path is available.""" planner.latest_odom = PoseStamped() planner.latest_costmap = empty_costmap @@ -173,7 +173,7 @@ def test_no_path(self, planner, empty_costmap): vel = planner.compute_velocity() assert vel is None - def test_no_odometry(self, planner, empty_costmap): + def test_no_odometry(self, planner, empty_costmap) -> None: """Test that planner returns None when no odometry is available.""" planner.latest_odom = None planner.latest_costmap = empty_costmap @@ -188,7 +188,7 @@ def test_no_odometry(self, planner, empty_costmap): vel = planner.compute_velocity() assert vel is None - def test_no_costmap(self, planner): + def test_no_costmap(self, planner) -> None: """Test that planner returns None when no costmap is available.""" planner.latest_odom = PoseStamped() planner.latest_costmap = None @@ -203,7 +203,7 @@ def test_no_costmap(self, planner): vel = planner.compute_velocity() assert vel is None - def test_goal_reached(self, planner, empty_costmap): + def test_goal_reached(self, planner, empty_costmap) -> None: """Test velocity when robot is at goal.""" # Set robot at goal position planner.latest_odom = PoseStamped() @@ -229,7 +229,7 @@ def test_goal_reached(self, planner, empty_costmap): assert abs(vel.linear.x) < 0.1 assert abs(vel.linear.y) < 0.1 - def test_velocity_saturation(self, planner, empty_costmap): + def test_velocity_saturation(self, planner, empty_costmap) -> None: """Test that velocities are capped at v_max.""" # Set robot far from goal to maximize commanded velocity planner.latest_odom = PoseStamped() @@ -256,7 +256,7 @@ def test_velocity_saturation(self, planner, empty_costmap): assert abs(vel.linear.y) <= planner.v_max + 0.01 assert abs(vel.angular.z) <= planner.v_max + 0.01 - def test_lookahead_interpolation(self, planner, empty_costmap): + def test_lookahead_interpolation(self, planner, empty_costmap) -> None: """Test that lookahead point is correctly interpolated on path.""" # Set robot at origin planner.latest_odom = PoseStamped() @@ -283,7 +283,7 @@ def test_lookahead_interpolation(self, planner, empty_costmap): assert vel.linear.x > 0.5 # Moving forward assert abs(vel.linear.y) < 0.1 # Staying on path - def test_curved_path_following(self, planner, empty_costmap): + def test_curved_path_following(self, planner, empty_costmap) -> None: """Test following a curved path.""" # Set robot at origin planner.latest_odom = PoseStamped() @@ -315,7 +315,7 @@ def test_curved_path_following(self, planner, empty_costmap): total_linear = np.sqrt(vel.linear.x**2 + vel.linear.y**2) assert total_linear > 0.1 # Some reasonable movement - def test_robot_frame_transformation(self, empty_costmap): + def test_robot_frame_transformation(self, empty_costmap) -> None: """Test that velocities are correctly transformed to robot frame.""" # Create planner with no filtering for deterministic test planner = HolonomicLocalPlanner( @@ -359,7 +359,7 @@ def test_robot_frame_transformation(self, empty_costmap): assert abs(vel.linear.x) < abs(vel.linear.y) # Lateral movement dominates planner._close_module() - def test_angular_velocity_computation(self, empty_costmap): + def test_angular_velocity_computation(self, empty_costmap) -> None: """Test that angular velocity is computed to align with path.""" planner = HolonomicLocalPlanner( lookahead_dist=2.0, diff --git a/dimos/navigation/rosnav.py b/dimos/navigation/rosnav.py index d74da612d8..f0d04926d3 100644 --- a/dimos/navigation/rosnav.py +++ b/dimos/navigation/rosnav.py @@ -18,26 +18,25 @@ Encapsulates ROS bridge and topic remapping for Unitree robots. """ +from collections.abc import Generator +from dataclasses import dataclass import logging import threading import time -from dataclasses import dataclass -from typing import Generator, Optional - -import rclpy -from geometry_msgs.msg import PointStamped as ROSPointStamped -from geometry_msgs.msg import PoseStamped as ROSPoseStamped # ROS2 message imports -from geometry_msgs.msg import TwistStamped as ROSTwistStamped +from geometry_msgs.msg import ( + PointStamped as ROSPointStamped, + PoseStamped as ROSPoseStamped, + TwistStamped as ROSTwistStamped, +) from nav_msgs.msg import Path as ROSPath +import rclpy from rclpy.node import Node from reactivex import operators as ops from reactivex.subject import Subject -from sensor_msgs.msg import Joy as ROSJoy -from sensor_msgs.msg import PointCloud2 as ROSPointCloud2 -from std_msgs.msg import Bool as ROSBool -from std_msgs.msg import Int8 as ROSInt8 +from sensor_msgs.msg import Joy as ROSJoy, PointCloud2 as ROSPointCloud2 +from std_msgs.msg import Bool as ROSBool, Int8 as ROSInt8 from tf2_msgs.msg import TFMessage as ROSTFMessage from dimos import spec @@ -88,10 +87,10 @@ class ROSNav(Module, spec.Nav, spec.Global3DMap, spec.Pointcloud, spec.LocalPlan _global_pointcloud_subject: Subject _current_position_running: bool = False - _spin_thread: Optional[threading.Thread] = None - _goal_reach: Optional[bool] = None + _spin_thread: threading.Thread | None = None + _goal_reach: bool | None = None - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) # Initialize RxPY Subjects for streaming data @@ -132,7 +131,7 @@ def __init__(self, *args, **kwargs): logger.info("NavigationModule initialized with ROS2 node") @rpc - def start(self): + def start(self) -> None: self._running = True self._disposables.add( @@ -164,7 +163,7 @@ def start(self): self.goal_req.subscribe(self._on_goal_pose) logger.info("NavigationModule started with ROS2 spinning and RxPY streams") - def _spin_node(self): + def _spin_node(self) -> None: while self._running and rclpy.ok(): try: rclpy.spin_once(self._node, timeout_sec=0.1) @@ -172,10 +171,10 @@ def _spin_node(self): if self._running: logger.error(f"ROS2 spin error: {e}") - def _on_ros_goal_reached(self, msg: ROSBool): + def _on_ros_goal_reached(self, msg: ROSBool) -> None: self._goal_reach = msg.data - def _on_ros_goal_waypoint(self, msg: ROSPointStamped): + def _on_ros_goal_waypoint(self, msg: ROSPointStamped) -> None: dimos_pose = PoseStamped( ts=time.time(), frame_id=msg.header.frame_id, @@ -184,21 +183,21 @@ def _on_ros_goal_waypoint(self, msg: ROSPointStamped): ) self.goal_active.publish(dimos_pose) - def _on_ros_cmd_vel(self, msg: ROSTwistStamped): + def _on_ros_cmd_vel(self, msg: ROSTwistStamped) -> None: self.cmd_vel.publish(TwistStamped.from_ros_msg(msg)) - def _on_ros_registered_scan(self, msg: ROSPointCloud2): + def _on_ros_registered_scan(self, msg: ROSPointCloud2) -> None: self._local_pointcloud_subject.on_next(msg) - def _on_ros_global_pointcloud(self, msg: ROSPointCloud2): + def _on_ros_global_pointcloud(self, msg: ROSPointCloud2) -> None: self._global_pointcloud_subject.on_next(msg) - def _on_ros_path(self, msg: ROSPath): + def _on_ros_path(self, msg: ROSPath) -> None: dimos_path = Path.from_ros_msg(msg) dimos_path.frame_id = "base_link" self.path_active.publish(dimos_path) - def _on_ros_tf(self, msg: ROSTFMessage): + def _on_ros_tf(self, msg: ROSTFMessage) -> None: ros_tf = TFMessage.from_ros_msg(msg) map_to_world_tf = Transform( @@ -215,14 +214,14 @@ def _on_ros_tf(self, msg: ROSTFMessage): *ros_tf.transforms, ) - def _on_goal_pose(self, msg: PoseStamped): + def _on_goal_pose(self, msg: PoseStamped) -> None: self.navigate_to(msg) - def _on_cancel_goal(self, msg: Bool): + def _on_cancel_goal(self, msg: Bool) -> None: if msg.data: self.stop() - def _set_autonomy_mode(self): + def _set_autonomy_mode(self) -> None: joy_msg = ROSJoy() joy_msg.axes = [ 0.0, # axis 0 @@ -363,7 +362,7 @@ def stop_navigation(self) -> bool: return True @rpc - def stop(self): + def stop(self) -> None: """Stop the navigation module and clean up resources.""" self.stop_navigation() try: diff --git a/dimos/navigation/visual/query.py b/dimos/navigation/visual/query.py index 7f54664c31..45a0ede40d 100644 --- a/dimos/navigation/visual/query.py +++ b/dimos/navigation/visual/query.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from dimos.models.qwen.video_query import BBox from dimos.models.vl.base import VlModel @@ -22,7 +21,7 @@ def get_object_bbox_from_image( vl_model: VlModel, image: Image, object_description: str -) -> Optional[BBox]: +) -> BBox | None: prompt = ( f"Look at this image and find the '{object_description}'. " "Return ONLY a JSON object with format: {'name': 'object_name', 'bbox': [x1, y1, x2, y2]} " diff --git a/dimos/perception/common/__init__.py b/dimos/perception/common/__init__.py index e658a8734c..67481bc449 100644 --- a/dimos/perception/common/__init__.py +++ b/dimos/perception/common/__init__.py @@ -1,3 +1,3 @@ -from .detection2d_tracker import target2dTracker, get_tracked_results +from .detection2d_tracker import get_tracked_results, target2dTracker from .ibvs import * from .utils import * diff --git a/dimos/perception/common/detection2d_tracker.py b/dimos/perception/common/detection2d_tracker.py index 2e4582cc00..7645acd380 100644 --- a/dimos/perception/common/detection2d_tracker.py +++ b/dimos/perception/common/detection2d_tracker.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np from collections import deque +from collections.abc import Sequence + +import numpy as np def compute_iou(bbox1, bbox2): @@ -80,12 +82,12 @@ def __init__( initial_mask, initial_bbox, track_id, - prob, - name, + prob: float, + name: str, texture_value, target_id, - history_size=10, - ): + history_size: int = 10, + ) -> None: """ Args: initial_mask (torch.Tensor): Latest segmentation mask. @@ -111,7 +113,7 @@ def __init__( self.missed_frames = 0 # Consecutive frames when no detection was assigned. self.history_size = history_size - def update(self, mask, bbox, track_id, prob, name, texture_value): + def update(self, mask, bbox, track_id, prob: float, name: str, texture_value) -> None: """ Update the target with a new detection. """ @@ -126,7 +128,7 @@ def update(self, mask, bbox, track_id, prob, name, texture_value): self.frame_count.append(1) self.missed_frames = 0 - def mark_missed(self): + def mark_missed(self) -> None: """ Increment the count of consecutive frames where this target was not updated. """ @@ -139,7 +141,7 @@ def compute_score( min_area_ratio, max_area_ratio, texture_range=(0.0, 1.0), - border_safe_distance=50, + border_safe_distance: int = 50, weights=None, ): """ @@ -249,17 +251,17 @@ class target2dTracker: def __init__( self, - history_size=10, - score_threshold_start=0.5, - score_threshold_stop=0.3, - min_frame_count=10, - max_missed_frames=3, - min_area_ratio=0.001, - max_area_ratio=0.1, + history_size: int = 10, + score_threshold_start: float = 0.5, + score_threshold_stop: float = 0.3, + min_frame_count: int = 10, + max_missed_frames: int = 3, + min_area_ratio: float = 0.001, + max_area_ratio: float = 0.1, texture_range=(0.0, 1.0), - border_safe_distance=50, + border_safe_distance: int = 50, weights=None, - ): + ) -> None: """ Args: history_size (int): Maximum history length (number of frames) per target. @@ -291,7 +293,16 @@ def __init__( self.targets = {} # Dictionary mapping target_id -> target2d instance. self.next_target_id = 0 - def update(self, frame, masks, bboxes, track_ids, probs, names, texture_values): + def update( + self, + frame, + masks, + bboxes, + track_ids, + probs: Sequence[float], + names: Sequence[str], + texture_values, + ): """ Update the tracker with new detections from the current frame. @@ -313,7 +324,7 @@ def update(self, frame, masks, bboxes, track_ids, probs, names, texture_values): # For each detection, try to match with an existing target. for mask, bbox, det_tid, prob, name, texture in zip( - masks, bboxes, track_ids, probs, names, texture_values + masks, bboxes, track_ids, probs, names, texture_values, strict=False ): matched_target = None diff --git a/dimos/perception/common/export_tensorrt.py b/dimos/perception/common/export_tensorrt.py index 9c021eb0a0..9d73b4ae3f 100644 --- a/dimos/perception/common/export_tensorrt.py +++ b/dimos/perception/common/export_tensorrt.py @@ -13,6 +13,7 @@ # limitations under the License. import argparse + from ultralytics import YOLO, FastSAM @@ -39,7 +40,7 @@ def parse_args(): return parser.parse_args() -def main(): +def main() -> None: args = parse_args() half = args.precision == "fp16" int8 = args.precision == "int8" diff --git a/dimos/perception/common/ibvs.py b/dimos/perception/common/ibvs.py index d580c71b23..2978aff84f 100644 --- a/dimos/perception/common/ibvs.py +++ b/dimos/perception/common/ibvs.py @@ -16,7 +16,7 @@ class PersonDistanceEstimator: - def __init__(self, K, camera_pitch, camera_height): + def __init__(self, K, camera_pitch, camera_height) -> None: """ Initialize the distance estimator using ground plane constraint. @@ -49,7 +49,7 @@ def __init__(self, K, camera_pitch, camera_height): self.fx = K[0, 0] self.cx = K[0, 2] - def estimate_distance_angle(self, bbox: tuple, robot_pitch: float = None): + def estimate_distance_angle(self, bbox: tuple, robot_pitch: float | None = None): """ Estimate distance and angle to person using ground plane constraint. @@ -123,7 +123,7 @@ class ObjectDistanceEstimator: camera's intrinsic parameters to estimate the distance to a detected object. """ - def __init__(self, K, camera_pitch, camera_height): + def __init__(self, K, camera_pitch, camera_height) -> None: """ Initialize the distance estimator using ground plane constraint. @@ -170,7 +170,7 @@ def estimate_object_size(self, bbox: tuple, distance: float): Returns: estimated_size: Estimated physical height of the object (in meters) """ - x_min, y_min, x_max, y_max = bbox + _x_min, y_min, _x_max, y_max = bbox # Calculate object height in pixels object_height_px = y_max - y_min @@ -181,7 +181,7 @@ def estimate_object_size(self, bbox: tuple, distance: float): return estimated_size - def set_estimated_object_size(self, size: float): + def set_estimated_object_size(self, size: float) -> None: """ Set the estimated object size for future distance calculations. diff --git a/dimos/perception/common/utils.py b/dimos/perception/common/utils.py index 1ce3931c2f..de3da4c171 100644 --- a/dimos/perception/common/utils.py +++ b/dimos/perception/common/utils.py @@ -12,19 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Union + import cv2 +from dimos_lcm.sensor_msgs import CameraInfo +from dimos_lcm.vision_msgs import BoundingBox2D, Detection2D, Detection3D import numpy as np -from typing import List, Tuple, Optional, Any, Union +import torch +import yaml + +from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 +from dimos.msgs.sensor_msgs import Image +from dimos.msgs.std_msgs import Header from dimos.types.manipulation import ObjectData from dimos.types.vector import Vector from dimos.utils.logging_config import setup_logger -from dimos_lcm.vision_msgs import Detection3D, Detection2D, BoundingBox2D -from dimos_lcm.sensor_msgs import CameraInfo -from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 -from dimos.msgs.std_msgs import Header -from dimos.msgs.sensor_msgs import Image -import torch -import yaml logger = setup_logger("dimos.perception.common.utils") @@ -66,7 +68,7 @@ def load_camera_info(yaml_path: str, frame_id: str = "camera_link") -> CameraInf Returns: CameraInfo: LCM CameraInfo message with all calibration data """ - with open(yaml_path, "r") as f: + with open(yaml_path) as f: camera_info_data = yaml.safe_load(f) # Extract image dimensions @@ -112,7 +114,7 @@ def load_camera_info(yaml_path: str, frame_id: str = "camera_link") -> CameraInf ) -def load_camera_info_opencv(yaml_path: str) -> Tuple[np.ndarray, np.ndarray]: +def load_camera_info_opencv(yaml_path: str) -> tuple[np.ndarray, np.ndarray]: """ Load ROS-style camera_info YAML file and convert to OpenCV camera matrix and distortion coefficients. @@ -123,7 +125,7 @@ def load_camera_info_opencv(yaml_path: str) -> Tuple[np.ndarray, np.ndarray]: K: 3x3 camera intrinsic matrix dist: 1xN distortion coefficients array (for plumb_bob model) """ - with open(yaml_path, "r") as f: + with open(yaml_path) as f: camera_info = yaml.safe_load(f) # Extract camera matrix (K) @@ -288,7 +290,7 @@ def rectify_image(image: Image, camera_matrix: np.ndarray, dist_coeffs: np.ndarr def project_3d_points_to_2d_cuda( - points_3d: "cp.ndarray", camera_intrinsics: Union[List[float], "cp.ndarray"] + points_3d: "cp.ndarray", camera_intrinsics: Union[list[float], "cp.ndarray"] ) -> "cp.ndarray": xp = cp # type: ignore pts = points_3d.astype(xp.float64, copy=False) @@ -307,7 +309,7 @@ def project_3d_points_to_2d_cuda( def project_3d_points_to_2d_cpu( - points_3d: np.ndarray, camera_intrinsics: Union[List[float], np.ndarray] + points_3d: np.ndarray, camera_intrinsics: list[float] | np.ndarray ) -> np.ndarray: pts = np.asarray(points_3d, dtype=np.float64) valid_mask = pts[:, 2] > 0 @@ -326,7 +328,7 @@ def project_3d_points_to_2d_cpu( def project_3d_points_to_2d( points_3d: Union[np.ndarray, "cp.ndarray"], - camera_intrinsics: Union[List[float], np.ndarray, "cp.ndarray"], + camera_intrinsics: Union[list[float], np.ndarray, "cp.ndarray"], ) -> Union[np.ndarray, "cp.ndarray"]: """ Project 3D points to 2D image coordinates using camera intrinsics. @@ -357,7 +359,7 @@ def project_3d_points_to_2d( def project_2d_points_to_3d_cuda( points_2d: "cp.ndarray", depth_values: "cp.ndarray", - camera_intrinsics: Union[List[float], "cp.ndarray"], + camera_intrinsics: Union[list[float], "cp.ndarray"], ) -> "cp.ndarray": xp = cp # type: ignore pts = points_2d.astype(xp.float64, copy=False) @@ -380,7 +382,7 @@ def project_2d_points_to_3d_cuda( def project_2d_points_to_3d_cpu( points_2d: np.ndarray, depth_values: np.ndarray, - camera_intrinsics: Union[List[float], np.ndarray], + camera_intrinsics: list[float] | np.ndarray, ) -> np.ndarray: pts = np.asarray(points_2d, dtype=np.float64) depths = np.asarray(depth_values, dtype=np.float64) @@ -406,7 +408,7 @@ def project_2d_points_to_3d_cpu( def project_2d_points_to_3d( points_2d: Union[np.ndarray, "cp.ndarray"], depth_values: Union[np.ndarray, "cp.ndarray"], - camera_intrinsics: Union[List[float], np.ndarray, "cp.ndarray"], + camera_intrinsics: Union[list[float], np.ndarray, "cp.ndarray"], ) -> Union[np.ndarray, "cp.ndarray"]: """ Project 2D image points to 3D coordinates using depth values and camera intrinsics. @@ -440,7 +442,7 @@ def project_2d_points_to_3d( def colorize_depth( depth_img: Union[np.ndarray, "cp.ndarray"], max_depth: float = 5.0, overlay_stats: bool = True -) -> Optional[Union[np.ndarray, "cp.ndarray"]]: +) -> Union[np.ndarray, "cp.ndarray"] | None: """ Normalize and colorize depth image using COLORMAP_JET with optional statistics overlay. @@ -579,12 +581,12 @@ def colorize_depth( def draw_bounding_box( image: Union[np.ndarray, "cp.ndarray"], - bbox: List[float], - color: Tuple[int, int, int] = (0, 255, 0), + bbox: list[float], + color: tuple[int, int, int] = (0, 255, 0), thickness: int = 2, - label: Optional[str] = None, - confidence: Optional[float] = None, - object_id: Optional[int] = None, + label: str | None = None, + confidence: float | None = None, + object_id: int | None = None, font_scale: float = 0.6, ) -> Union[np.ndarray, "cp.ndarray"]: """ @@ -647,7 +649,7 @@ def draw_bounding_box( def draw_segmentation_mask( image: Union[np.ndarray, "cp.ndarray"], mask: Union[np.ndarray, "cp.ndarray"], - color: Tuple[int, int, int] = (0, 200, 200), + color: tuple[int, int, int] = (0, 200, 200), alpha: float = 0.5, draw_contours: bool = True, contour_thickness: int = 2, @@ -692,10 +694,10 @@ def draw_segmentation_mask( def draw_object_detection_visualization( image: Union[np.ndarray, "cp.ndarray"], - objects: List[ObjectData], + objects: list[ObjectData], draw_masks: bool = False, - bbox_color: Tuple[int, int, int] = (0, 255, 0), - mask_color: Tuple[int, int, int] = (0, 200, 200), + bbox_color: tuple[int, int, int] = (0, 255, 0), + mask_color: tuple[int, int, int] = (0, 200, 200), font_scale: float = 0.6, ) -> Union[np.ndarray, "cp.ndarray"]: """ @@ -751,14 +753,14 @@ def draw_object_detection_visualization( def detection_results_to_object_data( - bboxes: List[List[float]], - track_ids: List[int], - class_ids: List[int], - confidences: List[float], - names: List[str], - masks: Optional[List[np.ndarray]] = None, + bboxes: list[list[float]], + track_ids: list[int], + class_ids: list[int], + confidences: list[float], + names: list[str], + masks: list[np.ndarray] | None = None, source: str = "detection", -) -> List[ObjectData]: +) -> list[ObjectData]: """ Convert detection/segmentation results to ObjectData format. @@ -781,8 +783,8 @@ def detection_results_to_object_data( bbox = bboxes[i] width = bbox[2] - bbox[0] height = bbox[3] - bbox[1] - center_x = bbox[0] + width / 2 - center_y = bbox[1] + height / 2 + bbox[0] + width / 2 + bbox[1] + height / 2 # Create ObjectData object_data: ObjectData = { @@ -813,8 +815,8 @@ def detection_results_to_object_data( def combine_object_data( - list1: List[ObjectData], list2: List[ObjectData], overlap_threshold: float = 0.8 -) -> List[ObjectData]: + list1: list[ObjectData], list2: list[ObjectData], overlap_threshold: float = 0.8 +) -> list[ObjectData]: """ Combine two ObjectData lists, removing duplicates based on segmentation mask overlap. """ @@ -858,7 +860,7 @@ def combine_object_data( return combined -def point_in_bbox(point: Tuple[int, int], bbox: List[float]) -> bool: +def point_in_bbox(point: tuple[int, int], bbox: list[float]) -> bool: """ Check if a point is inside a bounding box. @@ -874,7 +876,7 @@ def point_in_bbox(point: Tuple[int, int], bbox: List[float]) -> bool: return x1 <= x <= x2 and y1 <= y <= y2 -def bbox2d_to_corners(bbox_2d: BoundingBox2D) -> Tuple[float, float, float, float]: +def bbox2d_to_corners(bbox_2d: BoundingBox2D) -> tuple[float, float, float, float]: """ Convert BoundingBox2D from center format to corner format. @@ -898,8 +900,8 @@ def bbox2d_to_corners(bbox_2d: BoundingBox2D) -> Tuple[float, float, float, floa def find_clicked_detection( - click_pos: Tuple[int, int], detections_2d: List[Detection2D], detections_3d: List[Detection3D] -) -> Optional[Detection3D]: + click_pos: tuple[int, int], detections_2d: list[Detection2D], detections_3d: list[Detection3D] +) -> Detection3D | None: """ Find which detection was clicked based on 2D bounding boxes. diff --git a/dimos/perception/detection/conftest.py b/dimos/perception/detection/conftest.py index 69481c2fb0..c6994382a2 100644 --- a/dimos/perception/detection/conftest.py +++ b/dimos/perception/detection/conftest.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable, Generator import functools -from typing import Callable, Generator, Optional, TypedDict +from typing import TypedDict -import pytest from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations from dimos_lcm.foxglove_msgs.SceneUpdate import SceneUpdate from dimos_lcm.visualization_msgs.MarkerArray import MarkerArray +import pytest from dimos.core import LCMTransport from dimos.msgs.geometry_msgs import Transform @@ -48,10 +49,10 @@ class Moment(TypedDict, total=False): camera_info: CameraInfo transforms: list[Transform] tf: TF - annotations: Optional[ImageAnnotations] - detections: Optional[ImageDetections3DPC] - markers: Optional[MarkerArray] - scene_update: Optional[SceneUpdate] + annotations: ImageAnnotations | None + detections: ImageDetections3DPC | None + markers: MarkerArray | None + scene_update: SceneUpdate | None class Moment2D(Moment): @@ -106,7 +107,7 @@ def moment_provider(**kwargs) -> Moment: camera_info_out = go2.camera_info from typing import cast - camera_info = cast(CameraInfo, camera_info_out) + camera_info = cast("CameraInfo", camera_info_out) return { "odom_frame": odom_frame, "lidar_frame": lidar_frame, @@ -121,7 +122,7 @@ def moment_provider(**kwargs) -> Moment: @pytest.fixture(scope="session") def publish_moment(): - def publisher(moment: Moment | Moment2D | Moment3D): + def publisher(moment: Moment | Moment2D | Moment3D) -> None: detections2d_val = moment.get("detections2d") if detections2d_val: # 2d annotations @@ -225,7 +226,7 @@ def moment_provider(**kwargs) -> Moment2D: @pytest.fixture(scope="session") def get_moment_3dpc(get_moment_2d) -> Generator[Callable[[], Moment3D], None, None]: - module: Optional[Detection3DModule] = None + module: Detection3DModule | None = None @functools.lru_cache(maxsize=1) def moment_provider(**kwargs) -> Moment3D: diff --git a/dimos/perception/detection/detectors/detic.py b/dimos/perception/detection/detectors/detic.py index db2d8bb634..4432988f28 100644 --- a/dimos/perception/detection/detectors/detic.py +++ b/dimos/perception/detection/detectors/detic.py @@ -12,18 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Sequence import os import sys import numpy as np +# Add Detic to Python path +from dimos.constants import DIMOS_PROJECT_ROOT from dimos.msgs.sensor_msgs import Image from dimos.perception.detection.detectors.types import Detector from dimos.perception.detection2d.utils import plot_results -# Add Detic to Python path -from dimos.constants import DIMOS_PROJECT_ROOT - detic_path = DIMOS_PROJECT_ROOT / "dimos/models/Detic" if str(detic_path) not in sys.path: sys.path.append(str(detic_path)) @@ -44,7 +44,7 @@ class SimpleTracker: """Simple IOU-based tracker implementation without external dependencies""" - def __init__(self, iou_threshold=0.3, max_age=5): + def __init__(self, iou_threshold: float = 0.3, max_age: int = 5) -> None: self.iou_threshold = iou_threshold self.max_age = max_age self.next_id = 1 @@ -161,7 +161,9 @@ def update(self, detections, masks): class Detic2DDetector(Detector): - def __init__(self, model_path=None, device="cuda", vocabulary=None, threshold=0.5): + def __init__( + self, model_path=None, device: str = "cuda", vocabulary=None, threshold: float = 0.5 + ) -> None: """ Initialize the Detic detector with open vocabulary support. @@ -278,7 +280,7 @@ def setup_vocabulary(self, vocabulary): if isinstance(vocabulary, str): # If it's a string but not a built-in dataset, treat as a file try: - with open(vocabulary, "r") as f: + with open(vocabulary) as f: class_names = [line.strip() for line in f if line.strip()] except: # Default to LVIS if there's an issue @@ -301,7 +303,7 @@ def setup_vocabulary(self, vocabulary): self.reset_cls_test(self.predictor.model, classifier, num_classes) return self.class_names - def _get_clip_embeddings(self, vocabulary, prompt="a "): + def _get_clip_embeddings(self, vocabulary, prompt: str = "a "): """ Generate CLIP embeddings for a vocabulary list. @@ -354,7 +356,7 @@ def process_image(self, image: Image): bboxes.append([x1, y1, x2, y2]) # Get class names - names = [self.class_names[class_id] for class_id in class_ids] + [self.class_names[class_id] for class_id in class_ids] # Apply tracking detections = [] @@ -362,7 +364,7 @@ def process_image(self, image: Image): for i, bbox in enumerate(bboxes): if scores[i] >= self.threshold: # Format for tracker: [x1, y1, x2, y2, score, class_id] - detections.append(bbox + [scores[i], class_ids[i]]) + detections.append([*bbox, scores[i], class_ids[i]]) filtered_masks.append(masks[i]) if not detections: @@ -396,7 +398,9 @@ def process_image(self, image: Image): # tracked_masks, ) - def visualize_results(self, image, bboxes, track_ids, class_ids, confidences, names): + def visualize_results( + self, image, bboxes, track_ids, class_ids, confidences, names: Sequence[str] + ): """ Generate visualization of detection results. @@ -414,7 +418,7 @@ def visualize_results(self, image, bboxes, track_ids, class_ids, confidences, na return plot_results(image, bboxes, track_ids, class_ids, confidences, names) - def cleanup(self): + def cleanup(self) -> None: """Clean up resources.""" # Nothing specific to clean up for Detic pass diff --git a/dimos/perception/detection/detectors/person/test_person_detectors.py b/dimos/perception/detection/detectors/person/test_person_detectors.py index bca39acbcd..d912bec3a0 100644 --- a/dimos/perception/detection/detectors/person/test_person_detectors.py +++ b/dimos/perception/detection/detectors/person/test_person_detectors.py @@ -27,7 +27,7 @@ def person(people): return people[0] -def test_person_detection(people): +def test_person_detection(people) -> None: """Test that we can detect people with pose keypoints.""" assert len(people) > 0 @@ -40,7 +40,7 @@ def test_person_detection(people): assert person.keypoint_scores.shape == (17,) -def test_person_properties(people): +def test_person_properties(people) -> None: """Test Detection2DPerson object properties and methods.""" person = people[0] @@ -62,7 +62,7 @@ def test_person_properties(people): assert all(0 <= conf <= 1 for _, _, conf in visible) -def test_person_normalized_coords(people): +def test_person_normalized_coords(people) -> None: """Test normalized coordinates if available.""" person = people[0] @@ -78,7 +78,7 @@ def test_person_normalized_coords(people): assert (person.bbox_normalized <= 1).all() -def test_multiple_people(people): +def test_multiple_people(people) -> None: """Test that multiple people can be detected.""" print(f"\nDetected {len(people)} people in test image") @@ -93,14 +93,14 @@ def test_multiple_people(people): print(f" {name}: ({xy[0]:.1f}, {xy[1]:.1f}) conf={conf:.3f}") -def test_image_detections2d_structure(people): +def test_image_detections2d_structure(people) -> None: """Test that process_image returns ImageDetections2D.""" assert isinstance(people, ImageDetections2D) assert len(people.detections) > 0 assert all(isinstance(d, Detection2DPerson) for d in people.detections) -def test_invalid_keypoint(test_image): +def test_invalid_keypoint(test_image) -> None: """Test error handling for invalid keypoint names.""" # Create a dummy Detection2DPerson import numpy as np @@ -123,7 +123,7 @@ def test_invalid_keypoint(test_image): person.get_keypoint("invalid_keypoint") -def test_person_annotations(person): +def test_person_annotations(person) -> None: # Test text annotations text_anns = person.to_text_annotation() print(f"\nText annotations: {len(text_anns)}") @@ -156,5 +156,5 @@ def test_person_annotations(person): assert img_anns.texts_length == len(text_anns) assert img_anns.points_length == len(points_anns) - print(f"\n✓ Person annotations working correctly!") + print("\n✓ Person annotations working correctly!") print(f" - {len(person.get_visible_keypoints(0.5))}/17 visible keypoints") diff --git a/dimos/perception/detection/detectors/person/yolo.py b/dimos/perception/detection/detectors/person/yolo.py index 05e79fa22f..6421ab7d1d 100644 --- a/dimos/perception/detection/detectors/person/yolo.py +++ b/dimos/perception/detection/detectors/person/yolo.py @@ -25,7 +25,12 @@ class YoloPersonDetector(Detector): - def __init__(self, model_path="models_yolo", model_name="yolo11n-pose.pt", device: str = None): + def __init__( + self, + model_path: str = "models_yolo", + model_name: str = "yolo11n-pose.pt", + device: str | None = None, + ) -> None: self.model = YOLO(get_data(model_path) / model_name, task="track") self.tracker = get_data(model_path) / "botsort.yaml" @@ -60,7 +65,7 @@ def process_image(self, image: Image) -> ImageDetections2D: ) return ImageDetections2D.from_ultralytics_result(image, results) - def stop(self): + def stop(self) -> None: """ Clean up resources used by the detector, including tracker threads. """ diff --git a/dimos/perception/detection/detectors/test_bbox_detectors.py b/dimos/perception/detection/detectors/test_bbox_detectors.py index d246ded8a3..a86690279f 100644 --- a/dimos/perception/detection/detectors/test_bbox_detectors.py +++ b/dimos/perception/detection/detectors/test_bbox_detectors.py @@ -29,7 +29,7 @@ def detections(detector, test_image): return detector.process_image(test_image) -def test_detection_basic(detections): +def test_detection_basic(detections) -> None: """Test that we can detect objects with all detectors.""" assert len(detections.detections) > 0 @@ -42,7 +42,7 @@ def test_detection_basic(detections): assert detection.name is not None -def test_detection_bbox_properties(detections): +def test_detection_bbox_properties(detections) -> None: """Test Detection2D bbox properties work for all detectors.""" detection = detections.detections[0] @@ -66,7 +66,7 @@ def test_detection_bbox_properties(detections): assert height == y2 - y1 -def test_detection_cropped_image(detections, test_image): +def test_detection_cropped_image(detections, test_image) -> None: """Test cropping image to detection bbox.""" detection = detections.detections[0] @@ -80,7 +80,7 @@ def test_detection_cropped_image(detections, test_image): assert cropped.shape[1] <= test_image.shape[1] -def test_detection_annotations(detections): +def test_detection_annotations(detections) -> None: """Test annotation generation for detections.""" detection = detections.detections[0] @@ -98,7 +98,7 @@ def test_detection_annotations(detections): assert annotations.points_length >= 1 -def test_detection_ros_conversion(detections): +def test_detection_ros_conversion(detections) -> None: """Test conversion to ROS Detection2D message.""" detection = detections.detections[0] @@ -117,7 +117,7 @@ def test_detection_ros_conversion(detections): assert ros_det.results[0].hypothesis.class_id == detection.class_id -def test_detection_is_valid(detections): +def test_detection_is_valid(detections) -> None: """Test bbox validation.""" detection = detections.detections[0] @@ -125,14 +125,14 @@ def test_detection_is_valid(detections): assert detection.is_valid() -def test_image_detections2d_structure(detections): +def test_image_detections2d_structure(detections) -> None: """Test that process_image returns ImageDetections2D.""" assert isinstance(detections, ImageDetections2D) assert len(detections.detections) > 0 assert all(isinstance(d, Detection2D) for d in detections.detections) -def test_multiple_detections(detections): +def test_multiple_detections(detections) -> None: """Test that multiple objects can be detected.""" print(f"\nDetected {len(detections.detections)} objects in test image") @@ -146,7 +146,7 @@ def test_multiple_detections(detections): print(f" Track ID: {detection.track_id}") -def test_detection_string_representation(detections): +def test_detection_string_representation(detections) -> None: """Test string representation of detections.""" detection = detections.detections[0] str_repr = str(detection) diff --git a/dimos/perception/detection/detectors/yolo.py b/dimos/perception/detection/detectors/yolo.py index a338d3c8de..64e56ad456 100644 --- a/dimos/perception/detection/detectors/yolo.py +++ b/dimos/perception/detection/detectors/yolo.py @@ -25,7 +25,12 @@ class Yolo2DDetector(Detector): - def __init__(self, model_path="models_yolo", model_name="yolo11n.pt", device: str = None): + def __init__( + self, + model_path: str = "models_yolo", + model_name: str = "yolo11n.pt", + device: str | None = None, + ) -> None: self.model = YOLO( get_data(model_path) / model_name, task="detect", @@ -63,7 +68,7 @@ def process_image(self, image: Image) -> ImageDetections2D: return ImageDetections2D.from_ultralytics_result(image, results) - def stop(self): + def stop(self) -> None: """ Clean up resources used by the detector, including tracker threads. """ diff --git a/dimos/perception/detection/module2D.py b/dimos/perception/detection/module2D.py index cc2790e0df..4b0ccc11aa 100644 --- a/dimos/perception/detection/module2D.py +++ b/dimos/perception/detection/module2D.py @@ -11,8 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable from dataclasses import dataclass -from typing import Any, Callable, Optional +from typing import Any from dimos_lcm.foxglove_msgs.ImageAnnotations import ( ImageAnnotations, @@ -38,12 +39,12 @@ @dataclass class Config(ModuleConfig): max_freq: float = 10 - detector: Optional[Callable[[Any], Detector]] = Yolo2DDetector + detector: Callable[[Any], Detector] | None = Yolo2DDetector publish_detection_images: bool = True camera_info: CameraInfo = None # type: ignore filter: list[Filter2D] | Filter2D | None = None - def __post_init__(self): + def __post_init__(self) -> None: if self.filter is None: self.filter = [] elif not isinstance(self.filter, list): @@ -66,7 +67,7 @@ class Detection2DModule(Module): cnt: int = 0 - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.detector = self.config.detector() self.vlm_detections_subject = Subject() @@ -90,7 +91,7 @@ def sharp_image_stream(self) -> Observable[Image]: def detection_stream_2d(self) -> Observable[ImageDetections2D]: return backpressure(self.sharp_image_stream().pipe(ops.map(self.process_image_frame))) - def track(self, detections: ImageDetections2D): + def track(self, detections: ImageDetections2D) -> None: sensor_frame = self.tf.get("sensor", "camera_optical", detections.image.ts, 5.0) if not sensor_frame: @@ -130,7 +131,7 @@ def track(self, detections: ImageDetections2D): self.tf.publish(*transforms) @rpc - def start(self): + def start(self) -> None: # self.detection_stream_2d().subscribe(self.track) self.detection_stream_2d().subscribe( @@ -141,7 +142,7 @@ def start(self): lambda det: self.annotations.publish(det.to_foxglove_annotations()) ) - def publish_cropped_images(detections: ImageDetections2D): + def publish_cropped_images(detections: ImageDetections2D) -> None: for index, detection in enumerate(detections[:3]): image_topic = getattr(self, "detected_image_" + str(index)) image_topic.publish(detection.cropped_image()) diff --git a/dimos/perception/detection/module3D.py b/dimos/perception/detection/module3D.py index c218704600..c457229066 100644 --- a/dimos/perception/detection/module3D.py +++ b/dimos/perception/detection/module3D.py @@ -13,8 +13,6 @@ # limitations under the License. -from typing import Optional, Tuple - from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations from lcm_msgs.foxglove_msgs import SceneUpdate from reactivex import operators as ops @@ -55,7 +53,7 @@ class Detection3DModule(Detection2DModule): detected_image_1: Out[Image] = None # type: ignore detected_image_2: Out[Image] = None # type: ignore - detection_3d_stream: Optional[Observable[ImageDetections3DPC]] = None + detection_3d_stream: Observable[ImageDetections3DPC] | None = None def process_frame( self, @@ -81,7 +79,7 @@ def process_frame( def pixel_to_3d( self, - pixel: Tuple[int, int], + pixel: tuple[int, int], assumed_depth: float = 1.0, ) -> Vector3: """Unproject 2D pixel coordinates to 3D position in camera optical frame. @@ -163,7 +161,7 @@ def nav_vlm(self, question: str) -> str: ) @rpc - def start(self): + def start(self) -> None: super().start() def detection2d_to_3d(args): @@ -184,7 +182,7 @@ def detection2d_to_3d(args): def stop(self) -> None: super().stop() - def _publish_detections(self, detections: ImageDetections3DPC): + def _publish_detections(self, detections: ImageDetections3DPC) -> None: if not detections: return diff --git a/dimos/perception/detection/moduleDB.py b/dimos/perception/detection/moduleDB.py index 620d15cec3..d3c9f51c23 100644 --- a/dimos/perception/detection/moduleDB.py +++ b/dimos/perception/detection/moduleDB.py @@ -11,10 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable +from copy import copy import threading import time -from copy import copy -from typing import Any, Callable, Dict, List, Optional +from typing import Any from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations from lcm_msgs.foxglove_msgs import SceneUpdate @@ -23,7 +24,7 @@ from dimos import spec from dimos.core import DimosCluster, In, Out, rpc from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Vector3 -from dimos.msgs.sensor_msgs import CameraInfo, Image, PointCloud2 +from dimos.msgs.sensor_msgs import Image, PointCloud2 from dimos.msgs.vision_msgs import Detection2DArray from dimos.perception.detection.module3D import Detection3DModule from dimos.perception.detection.type import ImageDetections3DPC, TableStr @@ -32,12 +33,12 @@ # Represents an object in space, as collection of 3d detections over time class Object3D(Detection3DPC): - best_detection: Optional[Detection3DPC] = None # type: ignore - center: Optional[Vector3] = None # type: ignore - track_id: Optional[str] = None # type: ignore + best_detection: Detection3DPC | None = None # type: ignore + center: Vector3 | None = None # type: ignore + track_id: str | None = None # type: ignore detections: int = 0 - def to_repr_dict(self) -> Dict[str, Any]: + def to_repr_dict(self) -> dict[str, Any]: if self.center is None: center_str = "None" else: @@ -50,7 +51,9 @@ def to_repr_dict(self) -> Dict[str, Any]: "center": center_str, } - def __init__(self, track_id: str, detection: Optional[Detection3DPC] = None, *args, **kwargs): + def __init__( + self, track_id: str, detection: Detection3DPC | None = None, *args, **kwargs + ) -> None: if detection is None: return self.ts = detection.ts @@ -89,7 +92,7 @@ def __add__(self, detection: Detection3DPC) -> "Object3D": return new_object - def get_image(self) -> Optional[Image]: + def get_image(self) -> Image | None: return self.best_detection.image if self.best_detection else None def scene_entity_label(self) -> str: @@ -100,7 +103,7 @@ def agent_encode(self): "id": self.track_id, "name": self.name, "detections": self.detections, - "last_seen": f"{round((time.time() - self.ts))}s ago", + "last_seen": f"{round(time.time() - self.ts)}s ago", # "position": self.to_pose().position.agent_encode(), } @@ -134,9 +137,9 @@ def to_pose(self) -> PoseStamped: class ObjectDBModule(Detection3DModule, TableStr): cnt: int = 0 objects: dict[str, Object3D] - object_stream: Optional[Observable[Object3D]] = None + object_stream: Observable[Object3D] | None = None - goto: Optional[Callable[[PoseStamped], Any]] = None + goto: Callable[[PoseStamped], Any] | None = None image: In[Image] = None # type: ignore pointcloud: In[PointCloud2] = None # type: ignore @@ -156,17 +159,17 @@ class ObjectDBModule(Detection3DModule, TableStr): target: Out[PoseStamped] = None # type: ignore - remembered_locations: Dict[str, PoseStamped] + remembered_locations: dict[str, PoseStamped] @rpc - def start(self): + def start(self) -> None: Detection3DModule.start(self) - def update_objects(imageDetections: ImageDetections3DPC): + def update_objects(imageDetections: ImageDetections3DPC) -> None: for detection in imageDetections.detections: self.add_detection(detection) - def scene_thread(): + def scene_thread() -> None: while True: scene_update = self.to_foxglove_scene_update() self.scene_update.publish(scene_update) @@ -176,13 +179,13 @@ def scene_thread(): self.detection_stream_3d.subscribe(update_objects) - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.goto = None self.objects = {} self.remembered_locations = {} - def closest_object(self, detection: Detection3DPC) -> Optional[Object3D]: + def closest_object(self, detection: Detection3DPC) -> Object3D | None: # Filter objects to only those with matching names matching_objects = [obj for obj in self.objects.values() if obj.name == detection.name] @@ -194,7 +197,7 @@ def closest_object(self, detection: Detection3DPC) -> Optional[Object3D]: return distances[0] - def add_detections(self, detections: List[Detection3DPC]) -> List[Object3D]: + def add_detections(self, detections: list[Detection3DPC]) -> list[Object3D]: return [ detection for detection in map(self.add_detection, detections) if detection is not None ] @@ -263,7 +266,7 @@ def agent_encode(self) -> str: # return ret[0] if ret else None - def lookup(self, label: str) -> List[Detection3DPC]: + def lookup(self, label: str) -> list[Detection3DPC]: """Look up a detection by label.""" return [] @@ -271,7 +274,7 @@ def lookup(self, label: str) -> List[Detection3DPC]: def stop(self): return super().stop() - def goto_object(self, object_id: str) -> Optional[Object3D]: + def goto_object(self, object_id: str) -> Object3D | None: """Go to object by id.""" return self.objects.get(object_id, None) @@ -293,13 +296,13 @@ def to_foxglove_scene_update(self) -> "SceneUpdate": scene_update.entities.append( obj.to_foxglove_scene_entity(entity_id=f"{obj.name}_{obj.track_id}") ) - except Exception as e: + except Exception: pass scene_update.entities_length = len(scene_update.entities) return scene_update - def __len__(self): + def __len__(self) -> int: return len(self.objects.values()) diff --git a/dimos/perception/detection/person_tracker.py b/dimos/perception/detection/person_tracker.py index fe69fbc15e..568214d972 100644 --- a/dimos/perception/detection/person_tracker.py +++ b/dimos/perception/detection/person_tracker.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Tuple from reactivex import operators as ops from reactivex.observable import Observable @@ -33,13 +32,13 @@ class PersonTracker(Module): camera_info: CameraInfo - def __init__(self, cameraInfo: CameraInfo, **kwargs): + def __init__(self, cameraInfo: CameraInfo, **kwargs) -> None: super().__init__(**kwargs) self.camera_info = cameraInfo def center_to_3d( self, - pixel: Tuple[int, int], + pixel: tuple[int, int], camera_info: CameraInfo, assumed_depth: float = 1.0, ) -> Vector3: @@ -85,14 +84,14 @@ def detections_stream(self) -> Observable[ImageDetections2D]: ) @rpc - def start(self): + def start(self) -> None: self.detections_stream().subscribe(self.track) @rpc - def stop(self): + def stop(self) -> None: super().stop() - def track(self, detections2D: ImageDetections2D): + def track(self, detections2D: ImageDetections2D) -> None: if len(detections2D) == 0: return diff --git a/dimos/perception/detection/reid/__init__.py b/dimos/perception/detection/reid/__init__.py index b76741a7eb..31d50a894b 100644 --- a/dimos/perception/detection/reid/__init__.py +++ b/dimos/perception/detection/reid/__init__.py @@ -1,13 +1,13 @@ -from dimos.perception.detection.reid.module import Config, ReidModule from dimos.perception.detection.reid.embedding_id_system import EmbeddingIDSystem +from dimos.perception.detection.reid.module import Config, ReidModule from dimos.perception.detection.reid.type import IDSystem, PassthroughIDSystem __all__ = [ + "Config", + "EmbeddingIDSystem", # ID Systems "IDSystem", "PassthroughIDSystem", - "EmbeddingIDSystem", # Module "ReidModule", - "Config", ] diff --git a/dimos/perception/detection/reid/embedding_id_system.py b/dimos/perception/detection/reid/embedding_id_system.py index 7fb0a2ba40..c1c406fe56 100644 --- a/dimos/perception/detection/reid/embedding_id_system.py +++ b/dimos/perception/detection/reid/embedding_id_system.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Dict, List, Literal, Set +from collections.abc import Callable +from typing import Literal import numpy as np @@ -39,7 +40,7 @@ def __init__( top_k: int = 30, max_embeddings_per_track: int = 500, min_embeddings_for_matching: int = 10, - ): + ) -> None: """Initialize track associator. Args: @@ -69,17 +70,17 @@ def __init__( self.min_embeddings_for_matching = min_embeddings_for_matching # Track embeddings (list of all embeddings as numpy arrays) - self.track_embeddings: Dict[int, List[np.ndarray]] = {} + self.track_embeddings: dict[int, list[np.ndarray]] = {} # Negative constraints (track_ids that co-occurred = different objects) - self.negative_pairs: Dict[int, Set[int]] = {} + self.negative_pairs: dict[int, set[int]] = {} # Track ID to long-term unique ID mapping - self.track_to_long_term: Dict[int, int] = {} + self.track_to_long_term: dict[int, int] = {} self.long_term_counter: int = 0 # Similarity history for optional adaptive thresholding - self.similarity_history: List[float] = [] + self.similarity_history: list[float] = [] def register_detection(self, detection: Detection2DBBox) -> int: """ @@ -128,7 +129,7 @@ def update_embedding(self, track_id: int, new_embedding: Embedding) -> None: embeddings.pop(0) # Remove oldest def _compute_group_similarity( - self, query_embeddings: List[np.ndarray], candidate_embeddings: List[np.ndarray] + self, query_embeddings: list[np.ndarray], candidate_embeddings: list[np.ndarray] ) -> float: """Compute similarity between two groups of embeddings. @@ -164,7 +165,7 @@ def _compute_group_similarity( else: raise ValueError(f"Unknown comparison mode: {self.comparison_mode}") - def add_negative_constraints(self, track_ids: List[int]) -> None: + def add_negative_constraints(self, track_ids: list[int]) -> None: """Record that these track_ids co-occurred in same frame (different objects). Args: diff --git a/dimos/perception/detection/reid/module.py b/dimos/perception/detection/reid/module.py index 64769b1038..3cef9f2ff2 100644 --- a/dimos/perception/detection/reid/module.py +++ b/dimos/perception/detection/reid/module.py @@ -42,7 +42,7 @@ class ReidModule(Module): image: In[Image] = None # type: ignore annotations: Out[ImageAnnotations] = None # type: ignore - def __init__(self, idsystem: IDSystem | None = None, **kwargs): + def __init__(self, idsystem: IDSystem | None = None, **kwargs) -> None: super().__init__(**kwargs) if idsystem is None: try: @@ -69,14 +69,14 @@ def detections_stream(self) -> Observable[ImageDetections2D]: ) @rpc - def start(self): + def start(self) -> None: self.detections_stream().subscribe(self.ingress) @rpc - def stop(self): + def stop(self) -> None: super().stop() - def ingress(self, imageDetections: ImageDetections2D): + def ingress(self, imageDetections: ImageDetections2D) -> None: text_annotations = [] for detection in imageDetections: diff --git a/dimos/perception/detection/reid/test_embedding_id_system.py b/dimos/perception/detection/reid/test_embedding_id_system.py index b2bc84bc55..840ecb2fb8 100644 --- a/dimos/perception/detection/reid/test_embedding_id_system.py +++ b/dimos/perception/detection/reid/test_embedding_id_system.py @@ -44,7 +44,7 @@ def test_image(): @pytest.mark.gpu -def test_update_embedding_single(track_associator, mobileclip_model, test_image): +def test_update_embedding_single(track_associator, mobileclip_model, test_image) -> None: """Test updating embedding for a single track.""" embedding = mobileclip_model.embed(test_image) @@ -63,7 +63,7 @@ def test_update_embedding_single(track_associator, mobileclip_model, test_image) @pytest.mark.gpu -def test_update_embedding_running_average(track_associator, mobileclip_model, test_image): +def test_update_embedding_running_average(track_associator, mobileclip_model, test_image) -> None: """Test running average of embeddings.""" embedding1 = mobileclip_model.embed(test_image) embedding2 = mobileclip_model.embed(test_image) @@ -88,7 +88,7 @@ def test_update_embedding_running_average(track_associator, mobileclip_model, te @pytest.mark.gpu -def test_negative_constraints(track_associator): +def test_negative_constraints(track_associator) -> None: """Test negative constraint recording.""" # Simulate frame with 3 tracks track_ids = [1, 2, 3] @@ -104,7 +104,7 @@ def test_negative_constraints(track_associator): @pytest.mark.gpu -def test_associate_new_track(track_associator, mobileclip_model, test_image): +def test_associate_new_track(track_associator, mobileclip_model, test_image) -> None: """Test associating a new track creates new long_term_id.""" embedding = mobileclip_model.embed(test_image) track_associator.update_embedding(track_id=1, new_embedding=embedding) @@ -118,7 +118,7 @@ def test_associate_new_track(track_associator, mobileclip_model, test_image): @pytest.mark.gpu -def test_associate_similar_tracks(track_associator, mobileclip_model, test_image): +def test_associate_similar_tracks(track_associator, mobileclip_model, test_image) -> None: """Test associating similar tracks to same long_term_id.""" # Create embeddings from same image (should be very similar) embedding1 = mobileclip_model.embed(test_image) @@ -138,7 +138,7 @@ def test_associate_similar_tracks(track_associator, mobileclip_model, test_image @pytest.mark.gpu -def test_associate_with_negative_constraint(track_associator, mobileclip_model, test_image): +def test_associate_with_negative_constraint(track_associator, mobileclip_model, test_image) -> None: """Test that negative constraints prevent association.""" # Create similar embeddings embedding1 = mobileclip_model.embed(test_image) @@ -163,7 +163,7 @@ def test_associate_with_negative_constraint(track_associator, mobileclip_model, @pytest.mark.gpu -def test_associate_different_objects(track_associator, mobileclip_model, test_image): +def test_associate_different_objects(track_associator, mobileclip_model, test_image) -> None: """Test that dissimilar embeddings get different long_term_ids.""" # Create embeddings for image and text (very different) image_emb = mobileclip_model.embed(test_image) @@ -183,7 +183,7 @@ def test_associate_different_objects(track_associator, mobileclip_model, test_im @pytest.mark.gpu -def test_associate_returns_cached(track_associator, mobileclip_model, test_image): +def test_associate_returns_cached(track_associator, mobileclip_model, test_image) -> None: """Test that repeated calls return same long_term_id.""" embedding = mobileclip_model.embed(test_image) track_associator.update_embedding(track_id=1, new_embedding=embedding) @@ -199,14 +199,14 @@ def test_associate_returns_cached(track_associator, mobileclip_model, test_image @pytest.mark.gpu -def test_associate_not_ready(track_associator): +def test_associate_not_ready(track_associator) -> None: """Test that associate returns -1 for track without embedding.""" long_term_id = track_associator.associate(track_id=999) assert long_term_id == -1, "Should return -1 for track without embedding" @pytest.mark.gpu -def test_gpu_performance(track_associator, mobileclip_model, test_image): +def test_gpu_performance(track_associator, mobileclip_model, test_image) -> None: """Test that embeddings stay on GPU for performance.""" embedding = mobileclip_model.embed(test_image) track_associator.update_embedding(track_id=1, new_embedding=embedding) @@ -227,7 +227,7 @@ def test_gpu_performance(track_associator, mobileclip_model, test_image): @pytest.mark.gpu -def test_similarity_threshold_configurable(mobileclip_model): +def test_similarity_threshold_configurable(mobileclip_model) -> None: """Test that similarity threshold is configurable.""" associator_strict = EmbeddingIDSystem(model=lambda: mobileclip_model, similarity_threshold=0.95) associator_loose = EmbeddingIDSystem(model=lambda: mobileclip_model, similarity_threshold=0.50) @@ -237,7 +237,7 @@ def test_similarity_threshold_configurable(mobileclip_model): @pytest.mark.gpu -def test_multi_track_scenario(track_associator, mobileclip_model, test_image): +def test_multi_track_scenario(track_associator, mobileclip_model, test_image) -> None: """Test realistic scenario with multiple tracks across frames.""" # Frame 1: Track 1 appears emb1 = mobileclip_model.embed(test_image) diff --git a/dimos/perception/detection/reid/test_module.py b/dimos/perception/detection/reid/test_module.py index 6c977e13a5..cd580a1111 100644 --- a/dimos/perception/detection/reid/test_module.py +++ b/dimos/perception/detection/reid/test_module.py @@ -21,7 +21,7 @@ @pytest.mark.tool -def test_reid_ingress(imageDetections2d): +def test_reid_ingress(imageDetections2d) -> None: try: from dimos.models.embedding import TorchReIDModel except Exception: diff --git a/dimos/perception/detection/test_moduleDB.py b/dimos/perception/detection/test_moduleDB.py index 97598b6ee2..4a801598b0 100644 --- a/dimos/perception/detection/test_moduleDB.py +++ b/dimos/perception/detection/test_moduleDB.py @@ -13,8 +13,8 @@ # limitations under the License. import time -import pytest from lcm_msgs.foxglove_msgs import SceneUpdate +import pytest from dimos.core import LCMTransport from dimos.msgs.foxglove_msgs import ImageAnnotations @@ -26,7 +26,7 @@ @pytest.mark.module -def test_moduleDB(dimos_cluster): +def test_moduleDB(dimos_cluster) -> None: connection = go2.deploy(dimos_cluster, "fake") moduleDB = dimos_cluster.deploy( diff --git a/dimos/perception/detection/type/__init__.py b/dimos/perception/detection/type/__init__.py index bc44d984fd..624784776f 100644 --- a/dimos/perception/detection/type/__init__.py +++ b/dimos/perception/detection/type/__init__.py @@ -22,22 +22,22 @@ __all__ = [ # 2D Detection types "Detection2D", - "Filter2D", "Detection2DBBox", "Detection2DPerson", - "ImageDetections2D", # 3D Detection types "Detection3D", "Detection3DBBox", "Detection3DPC", + "Filter2D", + # Base types + "ImageDetections", + "ImageDetections2D", "ImageDetections3DPC", # Point cloud filters "PointCloudFilter", + "TableStr", "height_filter", "radius_outlier", "raycast", "statistical", - # Base types - "ImageDetections", - "TableStr", ] diff --git a/dimos/perception/detection/type/detection2d/__init__.py b/dimos/perception/detection/type/detection2d/__init__.py index 197c7a55e2..a0e22546b0 100644 --- a/dimos/perception/detection/type/detection2d/__init__.py +++ b/dimos/perception/detection/type/detection2d/__init__.py @@ -20,6 +20,6 @@ __all__ = [ "Detection2D", "Detection2DBBox", - "ImageDetections2D", "Detection2DPerson", + "ImageDetections2D", ] diff --git a/dimos/perception/detection/type/detection2d/base.py b/dimos/perception/detection/type/detection2d/base.py index ea57acb911..11a4d729f6 100644 --- a/dimos/perception/detection/type/detection2d/base.py +++ b/dimos/perception/detection/type/detection2d/base.py @@ -13,7 +13,7 @@ # limitations under the License. from abc import abstractmethod -from typing import Callable, List +from collections.abc import Callable from dimos_lcm.foxglove_msgs.ImageAnnotations import PointsAnnotation, TextAnnotation from dimos_lcm.vision_msgs import Detection2D as ROSDetection2D @@ -37,12 +37,12 @@ def to_image_annotations(self) -> ImageAnnotations: ... @abstractmethod - def to_text_annotation(self) -> List[TextAnnotation]: + def to_text_annotation(self) -> list[TextAnnotation]: """Return text annotations for visualization.""" ... @abstractmethod - def to_points_annotation(self) -> List[PointsAnnotation]: + def to_points_annotation(self) -> list[PointsAnnotation]: """Return points/shape annotations for visualization.""" ... diff --git a/dimos/perception/detection/type/detection2d/bbox.py b/dimos/perception/detection/type/detection2d/bbox.py index 223e1bc018..46e8fe2cc7 100644 --- a/dimos/perception/detection/type/detection2d/bbox.py +++ b/dimos/perception/detection/type/detection2d/bbox.py @@ -14,12 +14,14 @@ from __future__ import annotations -import hashlib from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Sequence, Tuple, Union +import hashlib +from typing import TYPE_CHECKING, Any if TYPE_CHECKING: - from dimos.perception.detection.type.detection2d.person import Detection2DPerson + from ultralytics.engine.results import Results + + from dimos.msgs.sensor_msgs import Image from dimos_lcm.foxglove_msgs.ImageAnnotations import ( PointsAnnotation, @@ -28,28 +30,24 @@ from dimos_lcm.foxglove_msgs.Point2 import Point2 from dimos_lcm.vision_msgs import ( BoundingBox2D, + Detection2D as ROSDetection2D, ObjectHypothesis, ObjectHypothesisWithPose, Point2D, Pose2D, ) -from dimos_lcm.vision_msgs import ( - Detection2D as ROSDetection2D, -) from rich.console import Console from rich.text import Text -from ultralytics.engine.results import Results from dimos.msgs.foxglove_msgs import ImageAnnotations from dimos.msgs.foxglove_msgs.Color import Color -from dimos.msgs.sensor_msgs import Image from dimos.msgs.std_msgs import Header from dimos.perception.detection.type.detection2d.base import Detection2D from dimos.types.timestamped import to_ros_stamp, to_timestamp from dimos.utils.decorators.decorators import simple_mcache -Bbox = Tuple[float, float, float, float] -CenteredBbox = Tuple[float, float, float, float] +Bbox = tuple[float, float, float, float] +CenteredBbox = tuple[float, float, float, float] def _hash_to_color(name: str) -> str: @@ -88,7 +86,7 @@ class Detection2DBBox(Detection2D): ts: float image: Image - def to_repr_dict(self) -> Dict[str, Any]: + def to_repr_dict(self) -> dict[str, Any]: """Return a dictionary representation of the detection for display purposes.""" x1, y1, x2, y2 = self.bbox return { @@ -101,7 +99,7 @@ def to_repr_dict(self) -> Dict[str, Any]: def center_to_3d( self, - pixel: Tuple[int, int], + pixel: tuple[int, int], camera_info: CameraInfo, assumed_depth: float = 1.0, ) -> PoseStamped: @@ -141,7 +139,7 @@ def cropped_image(self, padding: int = 20) -> Image: x1 - padding, y1 - padding, x2 - x1 + 2 * padding, y2 - y1 + 2 * padding ) - def __str__(self): + def __str__(self) -> str: console = Console(force_terminal=True, legacy_windows=False) d = self.to_repr_dict() @@ -166,7 +164,7 @@ def __str__(self): return capture.get().strip() @property - def center_bbox(self) -> Tuple[float, float]: + def center_bbox(self) -> tuple[float, float]: """Get center point of bounding box.""" x1, y1, x2, y2 = self.bbox return ((x1 + x2) / 2, (y1 + y2) / 2) @@ -203,7 +201,7 @@ def is_valid(self) -> bool: return True @classmethod - def from_ultralytics_result(cls, result: Results, idx: int, image: Image) -> "Detection2DBBox": + def from_ultralytics_result(cls, result: Results, idx: int, image: Image) -> Detection2DBBox: """Create Detection2DBBox from ultralytics Results object. Args: @@ -274,8 +272,8 @@ def to_ros_bbox(self) -> BoundingBox2D: def lcm_encode(self): return self.to_image_annotations().lcm_encode() - def to_text_annotation(self) -> List[TextAnnotation]: - x1, y1, x2, y2 = self.bbox + def to_text_annotation(self) -> list[TextAnnotation]: + x1, y1, _x2, y2 = self.bbox font_size = self.image.width / 80 @@ -311,7 +309,7 @@ def to_text_annotation(self) -> List[TextAnnotation]: return annotations - def to_points_annotation(self) -> List[PointsAnnotation]: + def to_points_annotation(self) -> list[PointsAnnotation]: x1, y1, x2, y2 = self.bbox thickness = 1 @@ -351,7 +349,7 @@ def to_image_annotations(self) -> ImageAnnotations: ) @classmethod - def from_ros_detection2d(cls, ros_det: ROSDetection2D, **kwargs) -> "Detection2D": + def from_ros_detection2d(cls, ros_det: ROSDetection2D, **kwargs) -> Detection2D: """Convert from ROS Detection2D message to Detection2D object.""" # Extract bbox from ROS format center_x = ros_det.bbox.center.position.x diff --git a/dimos/perception/detection/type/detection2d/imageDetections2D.py b/dimos/perception/detection/type/detection2d/imageDetections2D.py index 74854dae47..0c505ae2b5 100644 --- a/dimos/perception/detection/type/detection2d/imageDetections2D.py +++ b/dimos/perception/detection/type/detection2d/imageDetections2D.py @@ -14,24 +14,26 @@ from __future__ import annotations -from typing import List +from typing import TYPE_CHECKING -from dimos_lcm.vision_msgs import Detection2DArray -from ultralytics.engine.results import Results - -from dimos.msgs.sensor_msgs import Image from dimos.perception.detection.type.detection2d.base import Detection2D from dimos.perception.detection.type.detection2d.bbox import Detection2DBBox from dimos.perception.detection.type.imageDetections import ImageDetections +if TYPE_CHECKING: + from dimos_lcm.vision_msgs import Detection2DArray + from ultralytics.engine.results import Results + + from dimos.msgs.sensor_msgs import Image + class ImageDetections2D(ImageDetections[Detection2D]): @classmethod def from_ros_detection2d_array( cls, image: Image, ros_detections: Detection2DArray, **kwargs - ) -> "ImageDetections2D": + ) -> ImageDetections2D: """Convert from ROS Detection2DArray message to ImageDetections2D object.""" - detections: List[Detection2D] = [] + detections: list[Detection2D] = [] for ros_det in ros_detections.detections: detection = Detection2DBBox.from_ros_detection2d(ros_det, image=image, **kwargs) if detection.is_valid(): # type: ignore[attr-defined] @@ -41,8 +43,8 @@ def from_ros_detection2d_array( @classmethod def from_ultralytics_result( - cls, image: Image, results: List[Results], **kwargs - ) -> "ImageDetections2D": + cls, image: Image, results: list[Results], **kwargs + ) -> ImageDetections2D: """Create ImageDetections2D from ultralytics Results. Dispatches to appropriate Detection2D subclass based on result type: @@ -59,7 +61,7 @@ def from_ultralytics_result( """ from dimos.perception.detection.type.detection2d.person import Detection2DPerson - detections: List[Detection2D] = [] + detections: list[Detection2D] = [] for result in results: if result.boxes is None: continue diff --git a/dimos/perception/detection/type/detection2d/person.py b/dimos/perception/detection/type/detection2d/person.py index 1c6fee5cae..1d84613051 100644 --- a/dimos/perception/detection/type/detection2d/person.py +++ b/dimos/perception/detection/type/detection2d/person.py @@ -15,11 +15,11 @@ from dataclasses import dataclass # Import for type checking only to avoid circular imports -from typing import TYPE_CHECKING, List, Optional, Tuple +from typing import TYPE_CHECKING -import numpy as np from dimos_lcm.foxglove_msgs.ImageAnnotations import PointsAnnotation, TextAnnotation from dimos_lcm.foxglove_msgs.Point2 import Point2 +import numpy as np from dimos.msgs.foxglove_msgs.Color import Color from dimos.msgs.sensor_msgs import Image @@ -40,12 +40,12 @@ class Detection2DPerson(Detection2DBBox): keypoint_scores: np.ndarray # [17] - confidence scores # Optional normalized coordinates - bbox_normalized: Optional[np.ndarray] = None # [x1, y1, x2, y2] in 0-1 range - keypoints_normalized: Optional[np.ndarray] = None # [17, 2] in 0-1 range + bbox_normalized: np.ndarray | None = None # [x1, y1, x2, y2] in 0-1 range + keypoints_normalized: np.ndarray | None = None # [17, 2] in 0-1 range # Image dimensions for context - image_width: Optional[int] = None - image_height: Optional[int] = None + image_width: int | None = None + image_height: int | None = None # Keypoint names (class attribute) KEYPOINT_NAMES = [ @@ -88,9 +88,9 @@ def from_ultralytics_result( # Validate that this is a pose detection result if not hasattr(result, "keypoints") or result.keypoints is None: raise ValueError( - f"Cannot create Detection2DPerson from result without keypoints. " - f"This appears to be a regular detection result, not a pose detection. " - f"Use Detection2DBBox.from_ultralytics_result() instead." + "Cannot create Detection2DPerson from result without keypoints. " + "This appears to be a regular detection result, not a pose detection. " + "Use Detection2DBBox.from_ultralytics_result() instead." ) if not hasattr(result, "boxes") or result.boxes is None: @@ -191,7 +191,7 @@ def from_ros_detection2d(cls, *args, **kwargs) -> "Detection2DPerson": "message format that includes pose keypoints." ) - def get_keypoint(self, name: str) -> Tuple[np.ndarray, float]: + def get_keypoint(self, name: str) -> tuple[np.ndarray, float]: """Get specific keypoint by name. Returns: Tuple of (xy_coordinates, confidence_score) @@ -202,13 +202,15 @@ def get_keypoint(self, name: str) -> Tuple[np.ndarray, float]: idx = self.KEYPOINT_NAMES.index(name) return self.keypoints[idx], self.keypoint_scores[idx] - def get_visible_keypoints(self, threshold: float = 0.5) -> List[Tuple[str, np.ndarray, float]]: + def get_visible_keypoints(self, threshold: float = 0.5) -> list[tuple[str, np.ndarray, float]]: """Get all keypoints above confidence threshold. Returns: List of tuples: (keypoint_name, xy_coordinates, confidence) """ visible = [] - for i, (name, score) in enumerate(zip(self.KEYPOINT_NAMES, self.keypoint_scores)): + for i, (name, score) in enumerate( + zip(self.KEYPOINT_NAMES, self.keypoint_scores, strict=False) + ): if score > threshold: visible.append((name, self.keypoints[i], score)) return visible @@ -231,12 +233,12 @@ def height(self) -> float: return y2 - y1 @property - def center(self) -> Tuple[float, float]: + def center(self) -> tuple[float, float]: """Get center point of bounding box.""" x1, y1, x2, y2 = self.bbox return ((x1 + x2) / 2, (y1 + y2) / 2) - def to_points_annotation(self) -> List[PointsAnnotation]: + def to_points_annotation(self) -> list[PointsAnnotation]: """Override to include keypoint visualizations along with bounding box.""" annotations = [] @@ -249,7 +251,7 @@ def to_points_annotation(self) -> List[PointsAnnotation]: # Create points for visible keypoints if visible_keypoints: keypoint_points = [] - for name, xy, conf in visible_keypoints: + for _name, xy, _conf in visible_keypoints: keypoint_points.append(Point2(float(xy[0]), float(xy[1]))) # Add keypoints as circles @@ -317,14 +319,14 @@ def to_points_annotation(self) -> List[PointsAnnotation]: return annotations - def to_text_annotation(self) -> List[TextAnnotation]: + def to_text_annotation(self) -> list[TextAnnotation]: """Override to include pose information in text annotations.""" # Get base annotations from parent annotations = super().to_text_annotation() # Add pose-specific info visible_count = len(self.get_visible_keypoints(threshold=0.5)) - x1, y1, x2, y2 = self.bbox + x1, _y1, _x2, y2 = self.bbox annotations.append( TextAnnotation( diff --git a/dimos/perception/detection/type/detection2d/test_bbox.py b/dimos/perception/detection/type/detection2d/test_bbox.py index 3bf37c0fb6..a12e4e0d76 100644 --- a/dimos/perception/detection/type/detection2d/test_bbox.py +++ b/dimos/perception/detection/type/detection2d/test_bbox.py @@ -14,7 +14,7 @@ import pytest -def test_detection2d(detection2d): +def test_detection2d(detection2d) -> None: # def test_detection_basic_properties(detection2d): """Test basic detection properties.""" assert detection2d.track_id >= 0 diff --git a/dimos/perception/detection/type/detection2d/test_imageDetections2D.py b/dimos/perception/detection/type/detection2d/test_imageDetections2D.py index 6731b7b0c7..120072cfb6 100644 --- a/dimos/perception/detection/type/detection2d/test_imageDetections2D.py +++ b/dimos/perception/detection/type/detection2d/test_imageDetections2D.py @@ -16,7 +16,7 @@ from dimos.perception.detection.type import ImageDetections2D -def test_from_ros_detection2d_array(get_moment_2d): +def test_from_ros_detection2d_array(get_moment_2d) -> None: moment = get_moment_2d() detections2d = moment["detections2d"] @@ -37,7 +37,7 @@ def test_from_ros_detection2d_array(get_moment_2d): recovered_det = recovered.detections[0] # Check bbox is approximately the same (allow 1 pixel tolerance due to float conversion) - for orig_val, rec_val in zip(original_det.bbox, recovered_det.bbox): + for orig_val, rec_val in zip(original_det.bbox, recovered_det.bbox, strict=False): assert orig_val == pytest.approx(rec_val, abs=1.0) # Check other properties @@ -45,7 +45,7 @@ def test_from_ros_detection2d_array(get_moment_2d): assert recovered_det.class_id == original_det.class_id assert recovered_det.confidence == pytest.approx(original_det.confidence, abs=0.01) - print(f"\nSuccessfully round-tripped detection through ROS format:") + print("\nSuccessfully round-tripped detection through ROS format:") print(f" Original bbox: {original_det.bbox}") print(f" Recovered bbox: {recovered_det.bbox}") print(f" Track ID: {recovered_det.track_id}") diff --git a/dimos/perception/detection/type/detection2d/test_person.py b/dimos/perception/detection/type/detection2d/test_person.py index ba930fd299..2ff1e81237 100644 --- a/dimos/perception/detection/type/detection2d/test_person.py +++ b/dimos/perception/detection/type/detection2d/test_person.py @@ -14,7 +14,7 @@ import pytest -def test_person_ros_confidence(): +def test_person_ros_confidence() -> None: """Test that Detection2DPerson preserves confidence when converting to ROS format.""" from dimos.msgs.sensor_msgs import Image @@ -58,7 +58,7 @@ def test_person_ros_confidence(): print(f" Visible keypoints: {len(person_det.get_visible_keypoints(threshold=0.3))}/17") -def test_person_from_ros_raises(): +def test_person_from_ros_raises() -> None: """Test that Detection2DPerson.from_ros_detection2d() raises NotImplementedError.""" from dimos.perception.detection.type.detection2d.person import Detection2DPerson diff --git a/dimos/perception/detection/type/detection3d/__init__.py b/dimos/perception/detection/type/detection3d/__init__.py index a8d11ca87f..0e765b175f 100644 --- a/dimos/perception/detection/type/detection3d/__init__.py +++ b/dimos/perception/detection/type/detection3d/__init__.py @@ -31,7 +31,7 @@ "ImageDetections3DPC", "PointCloudFilter", "height_filter", - "raycast", "radius_outlier", + "raycast", "statistical", ] diff --git a/dimos/perception/detection/type/detection3d/base.py b/dimos/perception/detection/type/detection3d/base.py index a82a50d474..7988c19a47 100644 --- a/dimos/perception/detection/type/detection3d/base.py +++ b/dimos/perception/detection/type/detection3d/base.py @@ -16,13 +16,15 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Optional +from typing import TYPE_CHECKING -from dimos_lcm.sensor_msgs import CameraInfo - -from dimos.msgs.geometry_msgs import Transform from dimos.perception.detection.type.detection2d import Detection2DBBox +if TYPE_CHECKING: + from dimos_lcm.sensor_msgs import CameraInfo + + from dimos.msgs.geometry_msgs import Transform + @dataclass class Detection3D(Detection2DBBox): @@ -39,6 +41,6 @@ def from_2d( distance: float, camera_info: CameraInfo, world_to_optical_transform: Transform, - ) -> Optional["Detection3D"]: + ) -> Detection3D | None: """Create a 3D detection from a 2D detection.""" ... diff --git a/dimos/perception/detection/type/detection3d/bbox.py b/dimos/perception/detection/type/detection3d/bbox.py index 2bc0c1c541..30ca882d16 100644 --- a/dimos/perception/detection/type/detection3d/bbox.py +++ b/dimos/perception/detection/type/detection3d/bbox.py @@ -14,24 +14,12 @@ from __future__ import annotations -import functools from dataclasses import dataclass -from typing import Any, Callable, Dict, Optional, TypeVar - -import numpy as np -from dimos_lcm.sensor_msgs import CameraInfo -from lcm_msgs.builtin_interfaces import Duration -from lcm_msgs.foxglove_msgs import CubePrimitive, SceneEntity, SceneUpdate, TextPrimitive -from lcm_msgs.geometry_msgs import Point, Pose, Quaternion -from lcm_msgs.geometry_msgs import Vector3 as LCMVector3 +import functools +from typing import Any -from dimos.msgs.foxglove_msgs.Color import Color from dimos.msgs.geometry_msgs import PoseStamped, Transform, Vector3 -from dimos.msgs.sensor_msgs import PointCloud2 -from dimos.perception.detection.type.detection2d import Detection2D, Detection2DBBox -from dimos.perception.detection.type.detection3d.base import Detection3D -from dimos.perception.detection.type.imageDetections import ImageDetections -from dimos.types.timestamped import to_ros_stamp +from dimos.perception.detection.type.detection2d import Detection2DBBox @dataclass @@ -60,7 +48,7 @@ def pose(self) -> PoseStamped: orientation=self.orientation, ) - def to_repr_dict(self) -> Dict[str, Any]: + def to_repr_dict(self) -> dict[str, Any]: # Calculate distance from camera camera_pos = self.transform.translation distance = (self.center - camera_pos).magnitude() diff --git a/dimos/perception/detection/type/detection3d/imageDetections3DPC.py b/dimos/perception/detection/type/detection3d/imageDetections3DPC.py index efad114a2c..f843fb96fd 100644 --- a/dimos/perception/detection/type/detection3d/imageDetections3DPC.py +++ b/dimos/perception/detection/type/detection3d/imageDetections3DPC.py @@ -23,7 +23,7 @@ class ImageDetections3DPC(ImageDetections[Detection3DPC]): """Specialized class for 3D detections in an image.""" - def to_foxglove_scene_update(self) -> "SceneUpdate": + def to_foxglove_scene_update(self) -> SceneUpdate: """Convert all detections to a Foxglove SceneUpdate message. Returns: diff --git a/dimos/perception/detection/type/detection3d/pointcloud.py b/dimos/perception/detection/type/detection3d/pointcloud.py index e5fb82549c..56423d2f29 100644 --- a/dimos/perception/detection/type/detection3d/pointcloud.py +++ b/dimos/perception/detection/type/detection3d/pointcloud.py @@ -14,21 +14,18 @@ from __future__ import annotations -import functools from dataclasses import dataclass -from typing import Any, Dict, Optional +import functools +from typing import TYPE_CHECKING, Any -import numpy as np -from dimos_lcm.sensor_msgs import CameraInfo from lcm_msgs.builtin_interfaces import Duration -from lcm_msgs.foxglove_msgs import CubePrimitive, SceneEntity, SceneUpdate, TextPrimitive -from lcm_msgs.geometry_msgs import Point, Pose, Quaternion -from lcm_msgs.geometry_msgs import Vector3 as LCMVector3 +from lcm_msgs.foxglove_msgs import CubePrimitive, SceneEntity, TextPrimitive +from lcm_msgs.geometry_msgs import Point, Pose, Quaternion, Vector3 as LCMVector3 +import numpy as np from dimos.msgs.foxglove_msgs.Color import Color from dimos.msgs.geometry_msgs import PoseStamped, Transform, Vector3 from dimos.msgs.sensor_msgs import PointCloud2 -from dimos.perception.detection.type.detection2d import Detection2DBBox from dimos.perception.detection.type.detection3d.base import Detection3D from dimos.perception.detection.type.detection3d.pointcloud_filters import ( PointCloudFilter, @@ -38,6 +35,11 @@ ) from dimos.types.timestamped import to_ros_stamp +if TYPE_CHECKING: + from dimos_lcm.sensor_msgs import CameraInfo + + from dimos.perception.detection.type.detection2d import Detection2DBBox + @dataclass class Detection3DPC(Detection3D): @@ -73,11 +75,11 @@ def get_bounding_box_dimensions(self) -> tuple[float, float, float]: """Get dimensions (width, height, depth) of the detection's bounding box.""" return self.pointcloud.get_bounding_box_dimensions() - def bounding_box_intersects(self, other: "Detection3DPC") -> bool: + def bounding_box_intersects(self, other: Detection3DPC) -> bool: """Check if this detection's bounding box intersects with another's.""" return self.pointcloud.bounding_box_intersects(other.pointcloud) - def to_repr_dict(self) -> Dict[str, Any]: + def to_repr_dict(self) -> dict[str, Any]: # Calculate distance from camera # The pointcloud is in world frame, and transform gives camera position in world center_world = self.center @@ -96,7 +98,7 @@ def to_repr_dict(self) -> Dict[str, Any]: "points": str(len(self.pointcloud)), } - def to_foxglove_scene_entity(self, entity_id: Optional[str] = None) -> "SceneEntity": + def to_foxglove_scene_entity(self, entity_id: str | None = None) -> SceneEntity: """Convert detection to a Foxglove SceneEntity with cube primitive and text label. Args: @@ -204,8 +206,8 @@ def from_2d( # type: ignore[override] world_to_optical_transform: Transform, # filters are to be adjusted based on the sensor noise characteristics if feeding # sensor data directly - filters: Optional[list[PointCloudFilter]] = None, - ) -> Optional["Detection3DPC"]: + filters: list[PointCloudFilter] | None = None, + ) -> Detection3DPC | None: """Create a Detection3D from a 2D detection by projecting world pointcloud. This method handles: diff --git a/dimos/perception/detection/type/detection3d/pointcloud_filters.py b/dimos/perception/detection/type/detection3d/pointcloud_filters.py index 51cf3d7f33..1c6085b690 100644 --- a/dimos/perception/detection/type/detection3d/pointcloud_filters.py +++ b/dimos/perception/detection/type/detection3d/pointcloud_filters.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import Callable, Optional +from collections.abc import Callable from dimos_lcm.sensor_msgs import CameraInfo @@ -24,24 +24,24 @@ # Filters take Detection2DBBox, PointCloud2, CameraInfo, Transform and return filtered PointCloud2 or None PointCloudFilter = Callable[ - [Detection2DBBox, PointCloud2, CameraInfo, Transform], Optional[PointCloud2] + [Detection2DBBox, PointCloud2, CameraInfo, Transform], PointCloud2 | None ] -def height_filter(height=0.1) -> PointCloudFilter: +def height_filter(height: float = 0.1) -> PointCloudFilter: return lambda det, pc, ci, tf: pc.filter_by_height(height) -def statistical(nb_neighbors=40, std_ratio=0.5) -> PointCloudFilter: +def statistical(nb_neighbors: int = 40, std_ratio: float = 0.5) -> PointCloudFilter: def filter_func( det: Detection2DBBox, pc: PointCloud2, ci: CameraInfo, tf: Transform - ) -> Optional[PointCloud2]: + ) -> PointCloud2 | None: try: - statistical, removed = pc.pointcloud.remove_statistical_outlier( + statistical, _removed = pc.pointcloud.remove_statistical_outlier( nb_neighbors=nb_neighbors, std_ratio=std_ratio ) return PointCloud2(statistical, pc.frame_id, pc.ts) - except Exception as e: + except Exception: # print("statistical filter failed:", e) return None @@ -51,14 +51,14 @@ def filter_func( def raycast() -> PointCloudFilter: def filter_func( det: Detection2DBBox, pc: PointCloud2, ci: CameraInfo, tf: Transform - ) -> Optional[PointCloud2]: + ) -> PointCloud2 | None: try: camera_pos = tf.inverse().translation camera_pos_np = camera_pos.to_numpy() _, visible_indices = pc.pointcloud.hidden_point_removal(camera_pos_np, radius=100.0) visible_pcd = pc.pointcloud.select_by_index(visible_indices) return PointCloud2(visible_pcd, pc.frame_id, pc.ts) - except Exception as e: + except Exception: # print("raycast filter failed:", e) return None @@ -73,8 +73,8 @@ def radius_outlier(min_neighbors: int = 20, radius: float = 0.3) -> PointCloudFi def filter_func( det: Detection2DBBox, pc: PointCloud2, ci: CameraInfo, tf: Transform - ) -> Optional[PointCloud2]: - filtered_pcd, removed = pc.pointcloud.remove_radius_outlier( + ) -> PointCloud2 | None: + filtered_pcd, _removed = pc.pointcloud.remove_radius_outlier( nb_points=min_neighbors, radius=radius ) return PointCloud2(filtered_pcd, pc.frame_id, pc.ts) diff --git a/dimos/perception/detection/type/detection3d/test_imageDetections3DPC.py b/dimos/perception/detection/type/detection3d/test_imageDetections3DPC.py index 31e44dad91..4ad2660738 100644 --- a/dimos/perception/detection/type/detection3d/test_imageDetections3DPC.py +++ b/dimos/perception/detection/type/detection3d/test_imageDetections3DPC.py @@ -16,7 +16,7 @@ @pytest.mark.skip -def test_to_foxglove_scene_update(detections3dpc): +def test_to_foxglove_scene_update(detections3dpc) -> None: # Convert to scene update scene_update = detections3dpc.to_foxglove_scene_update() @@ -28,7 +28,9 @@ def test_to_foxglove_scene_update(detections3dpc): assert len(scene_update.entities) == len(detections3dpc.detections) # Verify each entity corresponds to a detection - for i, (entity, detection) in enumerate(zip(scene_update.entities, detections3dpc.detections)): + for _i, (entity, detection) in enumerate( + zip(scene_update.entities, detections3dpc.detections, strict=False) + ): assert entity.id == str(detection.track_id) assert entity.frame_id == detection.frame_id assert entity.cubes_length == 1 diff --git a/dimos/perception/detection/type/detection3d/test_pointcloud.py b/dimos/perception/detection/type/detection3d/test_pointcloud.py index edeeaacb4b..f616fe7f33 100644 --- a/dimos/perception/detection/type/detection3d/test_pointcloud.py +++ b/dimos/perception/detection/type/detection3d/test_pointcloud.py @@ -16,7 +16,7 @@ import pytest -def test_detection3dpc(detection3dpc): +def test_detection3dpc(detection3dpc) -> None: # def test_oriented_bounding_box(detection3dpc): """Test oriented bounding box calculation and values.""" obb = detection3dpc.get_oriented_bounding_box() diff --git a/dimos/perception/detection/type/imageDetections.py b/dimos/perception/detection/type/imageDetections.py index 0a1ce8cf56..5ea2b61e45 100644 --- a/dimos/perception/detection/type/imageDetections.py +++ b/dimos/perception/detection/type/imageDetections.py @@ -14,16 +14,18 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Callable, Generic, List, Optional, TypeVar +from typing import TYPE_CHECKING, Generic, TypeVar from dimos_lcm.vision_msgs import Detection2DArray from dimos.msgs.foxglove_msgs import ImageAnnotations -from dimos.msgs.sensor_msgs import Image from dimos.msgs.std_msgs import Header from dimos.perception.detection.type.utils import TableStr if TYPE_CHECKING: + from collections.abc import Callable, Iterator + + from dimos.msgs.sensor_msgs import Image from dimos.perception.detection.type.detection2d.base import Detection2D T = TypeVar("T", bound=Detection2D) @@ -35,23 +37,23 @@ class ImageDetections(Generic[T], TableStr): image: Image - detections: List[T] + detections: list[T] @property def ts(self) -> float: return self.image.ts - def __init__(self, image: Image, detections: Optional[List[T]] = None): + def __init__(self, image: Image, detections: list[T] | None = None) -> None: self.image = image self.detections = detections or [] for det in self.detections: if not det.ts: det.ts = image.ts - def __len__(self): + def __len__(self) -> int: return len(self.detections) - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self.detections) def __getitem__(self, index): diff --git a/dimos/perception/detection/type/test_detection3d.py b/dimos/perception/detection/type/test_detection3d.py index 44413df1fe..031623afe3 100644 --- a/dimos/perception/detection/type/test_detection3d.py +++ b/dimos/perception/detection/type/test_detection3d.py @@ -14,19 +14,17 @@ import time -from dimos.perception.detection.type.detection3d import Detection3D - -def test_guess_projection(get_moment_2d, publish_moment): +def test_guess_projection(get_moment_2d, publish_moment) -> None: moment = get_moment_2d() for key, value in moment.items(): print(key, "====================================") print(value) - camera_info = moment.get("camera_info") + moment.get("camera_info") detection2d = moment.get("detections2d")[0] tf = moment.get("tf") - transform = tf.get("camera_optical", "world", detection2d.ts, 5.0) + tf.get("camera_optical", "world", detection2d.ts, 5.0) # for stash # detection3d = Detection3D.from_2d(detection2d, 1.5, camera_info, transform) diff --git a/dimos/perception/detection/type/test_object3d.py b/dimos/perception/detection/type/test_object3d.py index 1dc3cb6bd0..4acd2f1afa 100644 --- a/dimos/perception/detection/type/test_object3d.py +++ b/dimos/perception/detection/type/test_object3d.py @@ -14,14 +14,11 @@ import pytest -from dimos.perception.detection.module2D import Detection2DModule -from dimos.perception.detection.module3D import Detection3DModule -from dimos.perception.detection.moduleDB import Object3D, ObjectDBModule +from dimos.perception.detection.moduleDB import Object3D from dimos.perception.detection.type.detection3d import ImageDetections3DPC -from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule -def test_first_object(first_object): +def test_first_object(first_object) -> None: # def test_object3d_properties(first_object): """Test basic properties of an Object3D.""" assert first_object.track_id is not None @@ -46,7 +43,7 @@ def test_first_object(first_object): assert -10 < first_object.center.z < 10 -def test_object3d_repr_dict(first_object): +def test_object3d_repr_dict(first_object) -> None: """Test to_repr_dict method.""" repr_dict = first_object.to_repr_dict() @@ -91,7 +88,7 @@ def test_object3d_repr_dict(first_object): assert first_object.get_image() is first_object.best_detection.image -def test_all_objeects(all_objects): +def test_all_objeects(all_objects) -> None: # def test_object3d_multiple_detections(all_objects): """Test objects that have been built from multiple detections.""" # Find objects with multiple detections @@ -121,7 +118,7 @@ def test_all_objeects(all_objects): assert obj.detections >= 1 -def test_objectdb_module(object_db_module): +def test_objectdb_module(object_db_module) -> None: # def test_object_db_module_populated(object_db_module): """Test that ObjectDBModule is properly populated.""" assert len(object_db_module.objects) > 0, "Database should contain objects" diff --git a/dimos/perception/detection/type/utils.py b/dimos/perception/detection/type/utils.py index f1e2187015..89cf41b404 100644 --- a/dimos/perception/detection/type/utils.py +++ b/dimos/perception/detection/type/utils.py @@ -50,7 +50,7 @@ def _hash_to_color(name: str) -> str: class TableStr: """Mixin class that provides table-based string representation for detection collections.""" - def __str__(self): + def __str__(self) -> str: console = Console(force_terminal=True, legacy_windows=False) # Create a table for detections diff --git a/dimos/perception/detection2d/utils.py b/dimos/perception/detection2d/utils.py index 73e0eb5671..c44a013325 100644 --- a/dimos/perception/detection2d/utils.py +++ b/dimos/perception/detection2d/utils.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np +from collections.abc import Sequence + import cv2 -from dimos.types.vector import Vector +import numpy as np def filter_detections( @@ -22,7 +23,7 @@ def filter_detections( track_ids, class_ids, confidences, - names, + names: Sequence[str], class_filter=None, name_filter=None, track_id_filter=None, @@ -61,7 +62,7 @@ def filter_detections( # Filter detections for bbox, track_id, class_id, conf, name in zip( - bboxes, track_ids, class_ids, confidences, names + bboxes, track_ids, class_ids, confidences, names, strict=False ): # Check if detection passes all specified filters keep = True @@ -154,7 +155,9 @@ def extract_detection_results(result, class_filter=None, name_filter=None, track return bboxes, track_ids, class_ids, confidences, names -def plot_results(image, bboxes, track_ids, class_ids, confidences, names, alpha=0.5): +def plot_results( + image, bboxes, track_ids, class_ids, confidences, names: Sequence[str], alpha: float = 0.5 +): """ Draw bounding boxes and labels on the image. @@ -172,7 +175,7 @@ def plot_results(image, bboxes, track_ids, class_ids, confidences, names, alpha= """ vis_img = image.copy() - for bbox, track_id, conf, name in zip(bboxes, track_ids, confidences, names): + for bbox, track_id, conf, name in zip(bboxes, track_ids, confidences, names, strict=False): # Generate consistent color based on track_id or class name if track_id != -1: np.random.seed(track_id) @@ -242,7 +245,7 @@ def calculate_depth_from_bbox(depth_map, bbox): return None -def calculate_distance_angle_from_bbox(bbox, depth, camera_intrinsics): +def calculate_distance_angle_from_bbox(bbox, depth: int, camera_intrinsics): """ Calculate distance and angle to object center based on bbox and depth. @@ -258,12 +261,12 @@ def calculate_distance_angle_from_bbox(bbox, depth, camera_intrinsics): raise ValueError("Camera intrinsics required for distance calculation") # Extract camera parameters - fx, fy, cx, cy = camera_intrinsics + fx, _fy, cx, _cy = camera_intrinsics # Calculate center of bounding box in pixels x1, y1, x2, y2 = bbox center_x = (x1 + x2) / 2 - center_y = (y1 + y2) / 2 + (y1 + y2) / 2 # Calculate normalized image coordinates x_norm = (center_x - cx) / fx @@ -277,7 +280,7 @@ def calculate_distance_angle_from_bbox(bbox, depth, camera_intrinsics): return distance, angle -def calculate_object_size_from_bbox(bbox, depth, camera_intrinsics): +def calculate_object_size_from_bbox(bbox, depth: int, camera_intrinsics): """ Estimate physical width and height of object in meters. diff --git a/dimos/perception/grasp_generation/grasp_generation.py b/dimos/perception/grasp_generation/grasp_generation.py index 89e7a0036c..730ccd1aa2 100644 --- a/dimos/perception/grasp_generation/grasp_generation.py +++ b/dimos/perception/grasp_generation/grasp_generation.py @@ -17,13 +17,13 @@ """ import asyncio + import numpy as np import open3d as o3d -from typing import Dict, List, Optional +from dimos.perception.grasp_generation.utils import parse_grasp_results from dimos.types.manipulation import ObjectData from dimos.utils.logging_config import setup_logger -from dimos.perception.grasp_generation.utils import parse_grasp_results logger = setup_logger("dimos.perception.grasp_generation") @@ -33,7 +33,7 @@ class HostedGraspGenerator: Dimensional-hosted grasp generator using WebSocket communication. """ - def __init__(self, server_url: str): + def __init__(self, server_url: str) -> None: """ Initialize Dimensional-hosted grasp generator. @@ -44,8 +44,8 @@ def __init__(self, server_url: str): logger.info(f"Initialized grasp generator with server: {server_url}") def generate_grasps_from_objects( - self, objects: List[ObjectData], full_pcd: o3d.geometry.PointCloud - ) -> List[Dict]: + self, objects: list[ObjectData], full_pcd: o3d.geometry.PointCloud + ) -> list[dict]: """ Generate grasps from ObjectData objects using grasp generator. @@ -112,8 +112,8 @@ def generate_grasps_from_objects( return [] def _send_grasp_request_sync( - self, points: np.ndarray, colors: Optional[np.ndarray] - ) -> Optional[List[Dict]]: + self, points: np.ndarray, colors: np.ndarray | None + ) -> list[dict] | None: """Send synchronous grasp request to grasp server.""" try: @@ -149,9 +149,10 @@ def _send_grasp_request_sync( async def _async_grasp_request( self, points: np.ndarray, colors: np.ndarray - ) -> Optional[List[Dict]]: + ) -> list[dict] | None: """Async grasp request helper.""" import json + import websockets try: @@ -183,7 +184,7 @@ async def _async_grasp_request( logger.error(f"Async grasp request failed: {e}") return None - def _convert_grasp_format(self, grasps: List[dict]) -> List[dict]: + def _convert_grasp_format(self, grasps: list[dict]) -> list[dict]: """Convert Dimensional Grasp format to visualization format.""" converted = [] @@ -206,7 +207,7 @@ def _convert_grasp_format(self, grasps: List[dict]) -> List[dict]: converted.sort(key=lambda x: x["score"], reverse=True) return converted - def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> Dict[str, float]: + def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> dict[str, float]: """Convert rotation matrix to Euler angles (in radians).""" sy = np.sqrt(rotation_matrix[0, 0] ** 2 + rotation_matrix[1, 0] ** 2) @@ -223,6 +224,6 @@ def _rotation_matrix_to_euler(self, rotation_matrix: np.ndarray) -> Dict[str, fl return {"roll": x, "pitch": y, "yaw": z} - def cleanup(self): + def cleanup(self) -> None: """Clean up resources.""" logger.info("Grasp generator cleaned up") diff --git a/dimos/perception/grasp_generation/utils.py b/dimos/perception/grasp_generation/utils.py index ab0cfd0d15..d83d02e596 100644 --- a/dimos/perception/grasp_generation/utils.py +++ b/dimos/perception/grasp_generation/utils.py @@ -14,18 +14,18 @@ """Utilities for grasp generation and visualization.""" +import cv2 import numpy as np import open3d as o3d -import cv2 -from typing import List, Dict, Tuple, Optional, Union -from dimos.perception.common.utils import project_3d_points_to_2d, project_2d_points_to_3d + +from dimos.perception.common.utils import project_3d_points_to_2d def create_gripper_geometry( grasp_data: dict, finger_length: float = 0.08, finger_thickness: float = 0.004, -) -> List[o3d.geometry.TriangleMesh]: +) -> list[o3d.geometry.TriangleMesh]: """ Create a simple fork-like gripper geometry from grasp data. @@ -146,8 +146,8 @@ def create_gripper_geometry( def create_all_gripper_geometries( - grasp_list: List[dict], max_grasps: int = -1 -) -> List[o3d.geometry.TriangleMesh]: + grasp_list: list[dict], max_grasps: int = -1 +) -> list[o3d.geometry.TriangleMesh]: """ Create gripper geometries for multiple grasps. @@ -171,8 +171,8 @@ def create_all_gripper_geometries( def draw_grasps_on_image( image: np.ndarray, - grasp_data: Union[dict, Dict[Union[int, str], List[dict]], List[dict]], - camera_intrinsics: Union[List[float], np.ndarray], # [fx, fy, cx, cy] or 3x3 matrix + grasp_data: dict | dict[int | str, list[dict]] | list[dict], + camera_intrinsics: list[float] | np.ndarray, # [fx, fy, cx, cy] or 3x3 matrix max_grasps: int = -1, # -1 means show all grasps finger_length: float = 0.08, # Match 3D gripper finger_thickness: float = 0.004, # Match 3D gripper @@ -215,7 +215,7 @@ def draw_grasps_on_image( else: # Dictionary of grasps by object ID grasps_to_draw = [] - for obj_id, grasps in grasp_data.items(): + for _obj_id, grasps in grasp_data.items(): for i, grasp in enumerate(grasps): grasps_to_draw.append((grasp, i)) @@ -393,7 +393,7 @@ def transform_points(points): center_2d = project_3d_points_to_2d(translation.reshape(1, -1), camera_matrix)[0] cv2.circle(result, tuple(center_2d.astype(int)), 3, color, -1) - except Exception as e: + except Exception: # Skip this grasp if there's an error continue @@ -426,9 +426,9 @@ def get_standard_coordinate_transform(): def visualize_grasps_3d( point_cloud: o3d.geometry.PointCloud, - grasp_list: List[dict], + grasp_list: list[dict], max_grasps: int = -1, -): +) -> None: """ Visualize grasps in 3D with point cloud. @@ -459,7 +459,7 @@ def visualize_grasps_3d( o3d.visualization.draw_geometries(geometries, window_name="3D Grasp Visualization") -def parse_grasp_results(grasps: List[Dict]) -> List[Dict]: +def parse_grasp_results(grasps: list[dict]) -> list[dict]: """ Parse grasp results into visualization format. @@ -500,8 +500,8 @@ def parse_grasp_results(grasps: List[Dict]) -> List[Dict]: def create_grasp_overlay( rgb_image: np.ndarray, - grasps: List[Dict], - camera_intrinsics: Union[List[float], np.ndarray], + grasps: list[dict], + camera_intrinsics: list[float] | np.ndarray, ) -> np.ndarray: """ Create grasp visualization overlay on RGB image. @@ -524,5 +524,5 @@ def create_grasp_overlay( max_grasps=-1, ) return cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB) - except Exception as e: + except Exception: return rgb_image.copy() diff --git a/dimos/perception/object_detection_stream.py b/dimos/perception/object_detection_stream.py index 4fb8fc2691..a82cbe9db5 100644 --- a/dimos/perception/object_detection_stream.py +++ b/dimos/perception/object_detection_stream.py @@ -12,11 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cv2 -import time + import numpy as np -from reactivex import Observable -from reactivex import operators as ops +from reactivex import Observable, operators as ops from dimos.perception.detection2d.yolo_2d_det import Yolo2DDetector @@ -27,19 +25,22 @@ except (ModuleNotFoundError, ImportError): DETIC_AVAILABLE = False Detic2DDetector = None +from collections.abc import Callable +from typing import TYPE_CHECKING + from dimos.models.depth.metric3d import Metric3D +from dimos.perception.common.utils import draw_object_detection_visualization from dimos.perception.detection2d.utils import ( calculate_depth_from_bbox, calculate_object_size_from_bbox, calculate_position_rotation_from_bbox, ) -from dimos.perception.common.utils import draw_object_detection_visualization from dimos.types.vector import Vector -from typing import Optional, Union, Callable -from dimos.types.manipulation import ObjectData +from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import transform_robot_to_map -from dimos.utils.logging_config import setup_logger +if TYPE_CHECKING: + from dimos.types.manipulation import ObjectData # Initialize logger for the ObjectDetectionStream logger = setup_logger("dimos.perception.object_detection_stream") @@ -60,16 +61,16 @@ class ObjectDetectionStream: def __init__( self, camera_intrinsics=None, # [fx, fy, cx, cy] - device="cuda", - gt_depth_scale=1000.0, - min_confidence=0.7, + device: str = "cuda", + gt_depth_scale: float = 1000.0, + min_confidence: float = 0.7, class_filter=None, # Optional list of class names to filter (e.g., ["person", "car"]) - get_pose: Callable = None, # Optional function to transform coordinates to map frame - detector: Optional[Union[Detic2DDetector, Yolo2DDetector]] = None, + get_pose: Callable | None = None, # Optional function to transform coordinates to map frame + detector: Detic2DDetector | Yolo2DDetector | None = None, video_stream: Observable = None, disable_depth: bool = False, # Flag to disable monocular Metric3D depth estimation draw_masks: bool = False, # Flag to enable drawing segmentation masks - ): + ) -> None: """ Initialize the ObjectDetectionStream. @@ -155,7 +156,8 @@ def create_stream(self, video_stream: Observable) -> Observable: def process_frame(frame): # TODO: More modular detector output interface bboxes, track_ids, class_ids, confidences, names, *mask_data = ( - self.detector.process_image(frame) + ([],) + *self.detector.process_image(frame), + [], ) masks = ( @@ -311,6 +313,6 @@ def format_detection_data(result): # Return a new stream with the formatter applied return self.stream.pipe(ops.map(format_detection_data)) - def cleanup(self): + def cleanup(self) -> None: """Clean up resources.""" pass diff --git a/dimos/perception/object_tracker.py b/dimos/perception/object_tracker.py index 497b6933b3..f5fa48581a 100644 --- a/dimos/perception/object_tracker.py +++ b/dimos/perception/object_tracker.py @@ -12,20 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cv2 -import numpy as np -import time import threading -from typing import Dict, List, Optional +import time -from dimos.core import In, Out, Module, rpc -from dimos.msgs.std_msgs import Header -from dimos.msgs.sensor_msgs import Image, ImageFormat -from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray -from reactivex.disposable import Disposable -from dimos.msgs.geometry_msgs import Vector3, Quaternion, Transform, Pose -from dimos.protocol.tf import TF -from dimos.utils.logging_config import setup_logger +import cv2 +from dimos_lcm.sensor_msgs import CameraInfo # Import LCM messages from dimos_lcm.vision_msgs import ( @@ -33,14 +24,23 @@ Detection3D, ObjectHypothesisWithPose, ) -from dimos_lcm.sensor_msgs import CameraInfo +import numpy as np +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.manipulation.visual_servoing.utils import visualize_detections_3d +from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 +from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.std_msgs import Header +from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray +from dimos.protocol.tf import TF +from dimos.types.timestamped import align_timestamped +from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import ( - yaw_towards_point, - optical_to_robot_frame, euler_to_quaternion, + optical_to_robot_frame, + yaw_towards_point, ) -from dimos.manipulation.visual_servoing.utils import visualize_detections_3d -from dimos.types.timestamped import align_timestamped logger = setup_logger("dimos.perception.object_tracker") @@ -63,7 +63,7 @@ def __init__( reid_threshold: int = 10, reid_fail_tolerance: int = 5, frame_id: str = "camera_link", - ): + ) -> None: """ Initialize an object tracking module using OpenCV's CSRT tracker with ORB re-ID. @@ -99,12 +99,12 @@ def __init__( self.reid_warmup_frames = 3 # Number of frames before REID starts self._frame_lock = threading.Lock() - self._latest_rgb_frame: Optional[np.ndarray] = None - self._latest_depth_frame: Optional[np.ndarray] = None - self._latest_camera_info: Optional[CameraInfo] = None + self._latest_rgb_frame: np.ndarray | None = None + self._latest_depth_frame: np.ndarray | None = None + self._latest_camera_info: CameraInfo | None = None # Tracking thread control - self.tracking_thread: Optional[threading.Thread] = None + self.tracking_thread: threading.Thread | None = None self.stop_tracking = threading.Event() self.tracking_rate = 30.0 # Hz self.tracking_period = 1.0 / self.tracking_rate @@ -113,16 +113,16 @@ def __init__( self.tf = TF() # Store latest detections for RPC access - self._latest_detection2d: Optional[Detection2DArray] = None - self._latest_detection3d: Optional[Detection3DArray] = None + self._latest_detection2d: Detection2DArray | None = None + self._latest_detection3d: Detection3DArray | None = None self._detection_event = threading.Event() @rpc - def start(self): + def start(self) -> None: super().start() # Subscribe to aligned rgb and depth streams - def on_aligned_frames(frames_tuple): + def on_aligned_frames(frames_tuple) -> None: rgb_msg, depth_msg = frames_tuple with self._frame_lock: self._latest_rgb_frame = rgb_msg.data @@ -144,7 +144,7 @@ def on_aligned_frames(frames_tuple): self._disposables.add(unsub) # Subscribe to camera info stream separately (doesn't need alignment) - def on_camera_info(camera_info_msg: CameraInfo): + def on_camera_info(camera_info_msg: CameraInfo) -> None: self._latest_camera_info = camera_info_msg # Extract intrinsics from camera info K matrix # K is a 3x3 matrix in row-major order: [fx, 0, cx, 0, fy, cy, 0, 0, 1] @@ -172,8 +172,8 @@ def stop(self) -> None: @rpc def track( self, - bbox: List[float], - ) -> Dict: + bbox: list[float], + ) -> dict: """ Initialize tracking with a bounding box and process current frame. @@ -269,14 +269,14 @@ def reid(self, frame, current_bbox) -> bool: return good_matches >= self.reid_threshold - def _start_tracking_thread(self): + def _start_tracking_thread(self) -> None: """Start the tracking thread.""" self.stop_tracking.clear() self.tracking_thread = threading.Thread(target=self._tracking_loop, daemon=True) self.tracking_thread.start() logger.info("Started tracking thread") - def _tracking_loop(self): + def _tracking_loop(self) -> None: """Main tracking loop that runs in a separate thread.""" while not self.stop_tracking.is_set() and self.tracking_initialized: # Process tracking for current frame @@ -287,7 +287,7 @@ def _tracking_loop(self): logger.info("Tracking loop ended") - def _reset_tracking_state(self): + def _reset_tracking_state(self) -> None: """Reset tracking state without stopping the thread.""" self.tracker = None self.tracking_bbox = None @@ -346,7 +346,7 @@ def is_tracking(self) -> bool: """ return self.tracking_initialized and self.reid_confirmed - def _process_tracking(self): + def _process_tracking(self) -> None: """Process current frame for tracking and publish detections.""" if self.tracker is None or not self.tracking_initialized: return @@ -495,7 +495,7 @@ def _process_tracking(self): translation=robot_pose.position, rotation=robot_pose.orientation, frame_id=self.frame_id, # Use configured camera frame - child_frame_id=f"tracked_object", + child_frame_id="tracked_object", ts=header.ts, ) self.tf.publish(tracked_object_tf) @@ -550,7 +550,7 @@ def _draw_reid_matches(self, image: np.ndarray) -> np.ndarray: """Draw REID feature matches on the image.""" viz_image = image.copy() - x1, y1, x2, y2 = self.last_roi_bbox + x1, y1, _x2, _y2 = self.last_roi_bbox # Draw keypoints from current ROI in green for kp in self.last_roi_kps: @@ -590,7 +590,7 @@ def _draw_reid_matches(self, image: np.ndarray) -> np.ndarray: return viz_image - def _get_depth_from_bbox(self, bbox: List[int], depth_frame: np.ndarray) -> Optional[float]: + def _get_depth_from_bbox(self, bbox: list[int], depth_frame: np.ndarray) -> float | None: """Calculate depth from bbox using the 25th percentile of closest points. Args: diff --git a/dimos/perception/object_tracker_2d.py b/dimos/perception/object_tracker_2d.py index 84b823ce5e..0256b7beb9 100644 --- a/dimos/perception/object_tracker_2d.py +++ b/dimos/perception/object_tracker_2d.py @@ -12,19 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cv2 -import numpy as np -import time -import threading -from typing import Dict, List, Optional import logging +import threading +import time -from dimos.core import In, Out, Module, rpc -from dimos.msgs.std_msgs import Header -from dimos.msgs.sensor_msgs import Image, ImageFormat -from dimos.msgs.vision_msgs import Detection2DArray -from dimos.utils.logging_config import setup_logger -from reactivex.disposable import Disposable +import cv2 # Import LCM messages from dimos_lcm.vision_msgs import ( @@ -35,6 +27,14 @@ Point2D, Pose2D, ) +import numpy as np +from reactivex.disposable import Disposable + +from dimos.core import In, Module, Out, rpc +from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.std_msgs import Header +from dimos.msgs.vision_msgs import Detection2DArray +from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.perception.object_tracker_2d", level=logging.INFO) @@ -50,7 +50,7 @@ class ObjectTracker2D(Module): def __init__( self, frame_id: str = "camera_link", - ): + ) -> None: """ Initialize 2D object tracking module using OpenCV's CSRT tracker. @@ -73,23 +73,23 @@ def __init__( # Frame management self._frame_lock = threading.Lock() - self._latest_rgb_frame: Optional[np.ndarray] = None - self._frame_arrival_time: Optional[float] = None + self._latest_rgb_frame: np.ndarray | None = None + self._frame_arrival_time: float | None = None # Tracking thread control - self.tracking_thread: Optional[threading.Thread] = None + self.tracking_thread: threading.Thread | None = None self.stop_tracking_event = threading.Event() self.tracking_rate = 5.0 # Hz self.tracking_period = 1.0 / self.tracking_rate # Store latest detection for RPC access - self._latest_detection2d: Optional[Detection2DArray] = None + self._latest_detection2d: Detection2DArray | None = None @rpc - def start(self): + def start(self) -> None: super().start() - def on_frame(frame_msg: Image): + def on_frame(frame_msg: Image) -> None: arrival_time = time.perf_counter() with self._frame_lock: self._latest_rgb_frame = frame_msg.data @@ -109,7 +109,7 @@ def stop(self) -> None: super().stop() @rpc - def track(self, bbox: List[float]) -> Dict: + def track(self, bbox: list[float]) -> dict: """ Initialize tracking with a bounding box. @@ -151,21 +151,21 @@ def track(self, bbox: List[float]) -> Dict: return {"status": "tracking_started", "bbox": self.tracking_bbox} - def _start_tracking_thread(self): + def _start_tracking_thread(self) -> None: """Start the tracking thread.""" self.stop_tracking_event.clear() self.tracking_thread = threading.Thread(target=self._tracking_loop, daemon=True) self.tracking_thread.start() logger.info("Started tracking thread") - def _tracking_loop(self): + def _tracking_loop(self) -> None: """Main tracking loop that runs in a separate thread.""" while not self.stop_tracking_event.is_set() and self.tracking_initialized: self._process_tracking() time.sleep(self.tracking_period) logger.info("Tracking loop ended") - def _reset_tracking_state(self): + def _reset_tracking_state(self) -> None: """Reset tracking state without stopping the thread.""" self.tracker = None self.tracking_bbox = None @@ -212,7 +212,7 @@ def is_tracking(self) -> bool: """ return self.tracking_initialized - def _process_tracking(self): + def _process_tracking(self) -> None: """Process current frame for tracking and publish 2D detections.""" if self.tracker is None or not self.tracking_initialized: return @@ -290,7 +290,7 @@ def _process_tracking(self): viz_msg = Image.from_numpy(viz_copy, format=ImageFormat.RGB) self.tracked_overlay.publish(viz_msg) - def _draw_visualization(self, image: np.ndarray, bbox: List[int]) -> np.ndarray: + def _draw_visualization(self, image: np.ndarray, bbox: list[int]) -> np.ndarray: """Draw tracking visualization.""" viz_image = image.copy() x1, y1, x2, y2 = bbox diff --git a/dimos/perception/object_tracker_3d.py b/dimos/perception/object_tracker_3d.py index 20b5705c05..231ae26748 100644 --- a/dimos/perception/object_tracker_3d.py +++ b/dimos/perception/object_tracker_3d.py @@ -12,28 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. + +# Import LCM messages +from dimos_lcm.sensor_msgs import CameraInfo +from dimos_lcm.vision_msgs import Detection3D, ObjectHypothesisWithPose import numpy as np -from typing import List, Optional from dimos.core import In, Out, rpc -from dimos.msgs.std_msgs import Header +from dimos.manipulation.visual_servoing.utils import visualize_detections_3d +from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 from dimos.msgs.sensor_msgs import Image, ImageFormat +from dimos.msgs.std_msgs import Header from dimos.msgs.vision_msgs import Detection2DArray, Detection3DArray -from dimos.msgs.geometry_msgs import Vector3, Quaternion, Transform, Pose from dimos.perception.object_tracker_2d import ObjectTracker2D from dimos.protocol.tf import TF from dimos.types.timestamped import align_timestamped from dimos.utils.logging_config import setup_logger from dimos.utils.transform_utils import ( - yaw_towards_point, - optical_to_robot_frame, euler_to_quaternion, + optical_to_robot_frame, + yaw_towards_point, ) -from dimos.manipulation.visual_servoing.utils import visualize_detections_3d - -# Import LCM messages -from dimos_lcm.sensor_msgs import CameraInfo -from dimos_lcm.vision_msgs import Detection3D, ObjectHypothesisWithPose logger = setup_logger("dimos.perception.object_tracker_3d") @@ -48,7 +47,7 @@ class ObjectTracker3D(ObjectTracker2D): # Additional outputs (2D tracker already has detection2darray and tracked_overlay) detection3darray: Out[Detection3DArray] = None - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: """ Initialize 3D object tracking module. @@ -59,21 +58,21 @@ def __init__(self, **kwargs): # Additional state for 3D tracking self.camera_intrinsics = None - self._latest_depth_frame: Optional[np.ndarray] = None - self._latest_camera_info: Optional[CameraInfo] = None + self._latest_depth_frame: np.ndarray | None = None + self._latest_camera_info: CameraInfo | None = None # TF publisher for tracked object self.tf = TF() # Store latest 3D detection - self._latest_detection3d: Optional[Detection3DArray] = None + self._latest_detection3d: Detection3DArray | None = None @rpc - def start(self): + def start(self) -> None: super().start() # Subscribe to aligned RGB and depth streams - def on_aligned_frames(frames_tuple): + def on_aligned_frames(frames_tuple) -> None: rgb_msg, depth_msg = frames_tuple with self._frame_lock: self._latest_rgb_frame = rgb_msg.data @@ -95,7 +94,7 @@ def on_aligned_frames(frames_tuple): self._disposables.add(unsub) # Subscribe to camera info - def on_camera_info(camera_info_msg: CameraInfo): + def on_camera_info(camera_info_msg: CameraInfo) -> None: self._latest_camera_info = camera_info_msg # Extract intrinsics: K is [fx, 0, cx, 0, fy, cy, 0, 0, 1] self.camera_intrinsics = [ @@ -113,7 +112,7 @@ def on_camera_info(camera_info_msg: CameraInfo): def stop(self) -> None: super().stop() - def _process_tracking(self): + def _process_tracking(self) -> None: """Override to add 3D detection creation after 2D tracking.""" # Call parent 2D tracking super()._process_tracking() @@ -155,9 +154,7 @@ def _process_tracking(self): viz_msg = Image.from_numpy(viz_image) self.tracked_overlay.publish(viz_msg) - def _create_detection3d_from_2d( - self, detection2d: Detection2DArray - ) -> Optional[Detection3DArray]: + def _create_detection3d_from_2d(self, detection2d: Detection2DArray) -> Detection3DArray | None: """Create 3D detection from 2D detection using depth.""" if detection2d.detections_length == 0: return None @@ -243,7 +240,7 @@ def _create_detection3d_from_2d( return detection3darray - def _get_depth_from_bbox(self, bbox: List[int], depth_frame: np.ndarray) -> Optional[float]: + def _get_depth_from_bbox(self, bbox: list[int], depth_frame: np.ndarray) -> float | None: """ Calculate depth from bbox using the 25th percentile of closest points. diff --git a/dimos/perception/person_tracker.py b/dimos/perception/person_tracker.py index d5d3e2be09..915c241196 100644 --- a/dimos/perception/person_tracker.py +++ b/dimos/perception/person_tracker.py @@ -12,18 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.perception.detection2d.yolo_2d_det import Yolo2DDetector -from dimos.perception.detection2d.utils import filter_detections -from dimos.perception.common.ibvs import PersonDistanceEstimator -from reactivex import Observable, interval -from reactivex.disposable import Disposable -from reactivex import operators as ops -import numpy as np + import cv2 -from typing import Dict, Optional +import numpy as np +from reactivex import Observable, interval, operators as ops +from reactivex.disposable import Disposable -from dimos.core import In, Out, Module, rpc +from dimos.core import In, Module, Out, rpc from dimos.msgs.sensor_msgs import Image +from dimos.perception.common.ibvs import PersonDistanceEstimator +from dimos.perception.detection2d.utils import filter_detections +from dimos.perception.detection2d.yolo_2d_det import Yolo2DDetector from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.perception.person_tracker") @@ -36,14 +35,14 @@ class PersonTrackingStream(Module): video: In[Image] = None # LCM outputs - tracking_data: Out[Dict] = None + tracking_data: Out[dict] = None def __init__( self, camera_intrinsics=None, - camera_pitch=0.0, - camera_height=1.0, - ): + camera_pitch: float = 0.0, + camera_height: float = 1.0, + ) -> None: """ Initialize a person tracking stream using Yolo2DDetector and PersonDistanceEstimator. @@ -85,20 +84,20 @@ def __init__( ) # For tracking latest frame data - self._latest_frame: Optional[np.ndarray] = None + self._latest_frame: np.ndarray | None = None self._process_interval = 0.1 # Process at 10Hz # Tracking state - starts disabled self._tracking_enabled = False @rpc - def start(self): + def start(self) -> None: """Start the person tracking module and subscribe to LCM streams.""" super().start() # Subscribe to video stream - def set_video(image_msg: Image): + def set_video(image_msg: Image) -> None: if hasattr(image_msg, "data"): self._latest_frame = image_msg.data else: @@ -117,7 +116,7 @@ def set_video(image_msg: Image): def stop(self) -> None: super().stop() - def _process_frame(self): + def _process_frame(self) -> None: """Process the latest frame if available.""" if self._latest_frame is None: return @@ -179,7 +178,7 @@ def _process_tracking(self, frame): target_data["angle"] = angle # Add text to visualization - x1, y1, x2, y2 = map(int, bbox) + _x1, y1, x2, _y2 = map(int, bbox) dist_text = f"{distance:.2f}m, {np.rad2deg(angle):.1f} deg" # Add black background for better visibility @@ -237,7 +236,7 @@ def is_tracking_enabled(self) -> bool: return self._tracking_enabled @rpc - def get_tracking_data(self) -> Dict: + def get_tracking_data(self) -> dict: """Get the latest tracking data. Returns: diff --git a/dimos/perception/pointcloud/__init__.py b/dimos/perception/pointcloud/__init__.py index 1f282bb738..a380e2aadf 100644 --- a/dimos/perception/pointcloud/__init__.py +++ b/dimos/perception/pointcloud/__init__.py @@ -1,3 +1,3 @@ -from .utils import * from .cuboid_fit import * from .pointcloud_filtering import * +from .utils import * diff --git a/dimos/perception/pointcloud/cuboid_fit.py b/dimos/perception/pointcloud/cuboid_fit.py index d567f40395..376ae08da0 100644 --- a/dimos/perception/pointcloud/cuboid_fit.py +++ b/dimos/perception/pointcloud/cuboid_fit.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. + +import cv2 import numpy as np import open3d as o3d -import cv2 -from typing import Dict, Optional, Union, Tuple def fit_cuboid( - points: Union[np.ndarray, o3d.geometry.PointCloud], method: str = "minimal" -) -> Optional[Dict]: + points: np.ndarray | o3d.geometry.PointCloud, method: str = "minimal" +) -> dict | None: """ Fit a cuboid to a point cloud using Open3D's built-in methods. @@ -103,7 +103,7 @@ def fit_cuboid( return None -def fit_cuboid_simple(points: Union[np.ndarray, o3d.geometry.PointCloud]) -> Optional[Dict]: +def fit_cuboid_simple(points: np.ndarray | o3d.geometry.PointCloud) -> dict | None: """ Simple wrapper for minimal oriented bounding box fitting. @@ -190,11 +190,11 @@ def get_cuboid_corners( def visualize_cuboid_on_image( image: np.ndarray, - cuboid_params: Dict, + cuboid_params: dict, camera_matrix: np.ndarray, - extrinsic_rotation: Optional[np.ndarray] = None, - extrinsic_translation: Optional[np.ndarray] = None, - color: Tuple[int, int, int] = (0, 255, 0), + extrinsic_rotation: np.ndarray | None = None, + extrinsic_translation: np.ndarray | None = None, + color: tuple[int, int, int] = (0, 255, 0), thickness: int = 2, show_dimensions: bool = True, ) -> np.ndarray: @@ -320,7 +320,7 @@ def visualize_cuboid_on_image( return vis_img -def compute_cuboid_volume(cuboid_params: Dict) -> float: +def compute_cuboid_volume(cuboid_params: dict) -> float: """ Compute the volume of a cuboid. @@ -337,7 +337,7 @@ def compute_cuboid_volume(cuboid_params: Dict) -> float: return float(np.prod(dims)) -def compute_cuboid_surface_area(cuboid_params: Dict) -> float: +def compute_cuboid_surface_area(cuboid_params: dict) -> float: """ Compute the surface area of a cuboid. @@ -354,7 +354,7 @@ def compute_cuboid_surface_area(cuboid_params: Dict) -> float: return 2.0 * (dims[0] * dims[1] + dims[1] * dims[2] + dims[2] * dims[0]) -def check_cuboid_quality(cuboid_params: Dict, points: np.ndarray) -> Dict: +def check_cuboid_quality(cuboid_params: dict, points: np.ndarray) -> dict: """ Assess the quality of a cuboid fit. diff --git a/dimos/perception/pointcloud/pointcloud_filtering.py b/dimos/perception/pointcloud/pointcloud_filtering.py index 3de2f3ae6a..4ca8a0c84b 100644 --- a/dimos/perception/pointcloud/pointcloud_filtering.py +++ b/dimos/perception/pointcloud/pointcloud_filtering.py @@ -12,22 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np + import cv2 -import os -import torch +import numpy as np import open3d as o3d -import argparse -import pickle -from typing import Dict, List, Optional, Union -import time -from dimos.types.manipulation import ObjectData -from dimos.types.vector import Vector +import torch + +from dimos.perception.pointcloud.cuboid_fit import fit_cuboid from dimos.perception.pointcloud.utils import ( - load_camera_matrix_from_yaml, create_point_cloud_and_extract_masks, + load_camera_matrix_from_yaml, ) -from dimos.perception.pointcloud.cuboid_fit import fit_cuboid +from dimos.types.manipulation import ObjectData +from dimos.types.vector import Vector class PointcloudFiltering: @@ -40,8 +37,8 @@ class PointcloudFiltering: def __init__( self, - color_intrinsics: Optional[Union[str, List[float], np.ndarray]] = None, - depth_intrinsics: Optional[Union[str, List[float], np.ndarray]] = None, + color_intrinsics: str | list[float] | np.ndarray | None = None, + depth_intrinsics: str | list[float] | np.ndarray | None = None, color_weight: float = 0.3, enable_statistical_filtering: bool = True, statistical_neighbors: int = 20, @@ -55,7 +52,7 @@ def __init__( min_points_for_cuboid: int = 10, cuboid_method: str = "oriented", max_bbox_size_percent: float = 30.0, - ): + ) -> None: """ Initialize the point cloud filtering pipeline. @@ -117,13 +114,13 @@ def generate_color_from_id(self, object_id: int) -> np.ndarray: return color def _validate_inputs( - self, color_img: np.ndarray, depth_img: np.ndarray, objects: List[ObjectData] + self, color_img: np.ndarray, depth_img: np.ndarray, objects: list[ObjectData] ): """Validate input parameters.""" if color_img.shape[:2] != depth_img.shape: raise ValueError("Color and depth image dimensions don't match") - def _prepare_masks(self, masks: List[np.ndarray], target_shape: tuple) -> List[np.ndarray]: + def _prepare_masks(self, masks: list[np.ndarray], target_shape: tuple) -> list[np.ndarray]: """Prepare and validate masks to match target shape.""" processed_masks = [] for mask in masks: @@ -187,7 +184,7 @@ def _apply_subsampling(self, pcd: o3d.geometry.PointCloud) -> o3d.geometry.Point return pcd.voxel_down_sample(self.voxel_size) return pcd - def _extract_masks_from_objects(self, objects: List[ObjectData]) -> List[np.ndarray]: + def _extract_masks_from_objects(self, objects: list[ObjectData]) -> list[np.ndarray]: """Extract segmentation masks from ObjectData objects.""" return [obj["segmentation_mask"] for obj in objects] @@ -196,8 +193,8 @@ def get_full_point_cloud(self) -> o3d.geometry.PointCloud: return self._apply_subsampling(self.full_pcd) def process_images( - self, color_img: np.ndarray, depth_img: np.ndarray, objects: List[ObjectData] - ) -> List[ObjectData]: + self, color_img: np.ndarray, depth_img: np.ndarray, objects: list[ObjectData] + ) -> list[ObjectData]: """ Process color and depth images with object detection results to create filtered point clouds. @@ -276,7 +273,9 @@ def process_images( # Process each object and update ObjectData updated_objects = [] - for i, (obj, mask, pcd) in enumerate(zip(objects, processed_masks, masked_pcds)): + for i, (obj, _mask, pcd) in enumerate( + zip(objects, processed_masks, masked_pcds, strict=False) + ): # Skip empty point clouds if len(np.asarray(pcd.points)) == 0: continue @@ -353,7 +352,7 @@ def process_images( return updated_objects - def cleanup(self): + def cleanup(self) -> None: """Clean up resources.""" if torch.cuda.is_available(): torch.cuda.empty_cache() diff --git a/dimos/perception/pointcloud/test_pointcloud_filtering.py b/dimos/perception/pointcloud/test_pointcloud_filtering.py index 4b4e5c7c4f..719feeb984 100644 --- a/dimos/perception/pointcloud/test_pointcloud_filtering.py +++ b/dimos/perception/pointcloud/test_pointcloud_filtering.py @@ -13,30 +13,34 @@ # limitations under the License. import os +from typing import TYPE_CHECKING + import cv2 import numpy as np -import pytest import open3d as o3d +import pytest from dimos.perception.pointcloud.pointcloud_filtering import PointcloudFiltering from dimos.perception.pointcloud.utils import load_camera_matrix_from_yaml -from dimos.types.manipulation import ObjectData + +if TYPE_CHECKING: + from dimos.types.manipulation import ObjectData class TestPointcloudFiltering: - def test_pointcloud_filtering_initialization(self): + def test_pointcloud_filtering_initialization(self) -> None: """Test PointcloudFiltering initializes correctly with default parameters.""" try: filtering = PointcloudFiltering() assert filtering is not None assert filtering.color_weight == 0.3 - assert filtering.enable_statistical_filtering == True - assert filtering.enable_radius_filtering == True - assert filtering.enable_subsampling == True + assert filtering.enable_statistical_filtering + assert filtering.enable_radius_filtering + assert filtering.enable_subsampling except Exception as e: pytest.skip(f"Skipping test due to initialization error: {e}") - def test_pointcloud_filtering_with_custom_params(self): + def test_pointcloud_filtering_with_custom_params(self) -> None: """Test PointcloudFiltering with custom parameters.""" try: filtering = PointcloudFiltering( @@ -47,14 +51,14 @@ def test_pointcloud_filtering_with_custom_params(self): max_num_objects=5, ) assert filtering.color_weight == 0.5 - assert filtering.enable_statistical_filtering == False - assert filtering.enable_radius_filtering == False + assert not filtering.enable_statistical_filtering + assert not filtering.enable_radius_filtering assert filtering.voxel_size == 0.01 assert filtering.max_num_objects == 5 except Exception as e: pytest.skip(f"Skipping test due to initialization error: {e}") - def test_pointcloud_filtering_process_images(self): + def test_pointcloud_filtering_process_images(self) -> None: """Test PointcloudFiltering can process RGB-D images and return filtered point clouds.""" try: # Import data inside method to avoid pytest fixture confusion @@ -204,7 +208,7 @@ def test_pointcloud_filtering_process_images(self): except Exception as e: pytest.skip(f"Skipping test due to error: {e}") - def test_pointcloud_filtering_empty_objects(self): + def test_pointcloud_filtering_empty_objects(self) -> None: """Test PointcloudFiltering with empty object list.""" try: from dimos.utils.data import get_data @@ -234,7 +238,7 @@ def test_pointcloud_filtering_empty_objects(self): except Exception as e: pytest.skip(f"Skipping test due to error: {e}") - def test_color_generation_consistency(self): + def test_color_generation_consistency(self) -> None: """Test that color generation is consistent for the same object ID.""" try: filtering = PointcloudFiltering() diff --git a/dimos/perception/pointcloud/utils.py b/dimos/perception/pointcloud/utils.py index b3c395bfa3..d3fcb19ca6 100644 --- a/dimos/perception/pointcloud/utils.py +++ b/dimos/perception/pointcloud/utils.py @@ -19,19 +19,21 @@ from RGBD images using Open3D. """ -import numpy as np -import yaml import os +from typing import Any + import cv2 +import numpy as np import open3d as o3d -from typing import List, Optional, Tuple, Union, Dict, Any from scipy.spatial import cKDTree +import yaml + from dimos.perception.common.utils import project_3d_points_to_2d def load_camera_matrix_from_yaml( - camera_info: Optional[Union[str, List[float], np.ndarray, dict]], -) -> Optional[np.ndarray]: + camera_info: str | list[float] | np.ndarray | dict | None, +) -> np.ndarray | None: """ Load camera intrinsic matrix from various input formats. @@ -72,7 +74,7 @@ def load_camera_matrix_from_yaml( raise FileNotFoundError(f"Camera info file not found: {camera_info}") try: - with open(camera_info, "r") as f: + with open(camera_info) as f: data = yaml.safe_load(f) return _extract_matrix_from_dict(data) except Exception as e: @@ -199,11 +201,11 @@ def create_o3d_point_cloud_from_rgbd( def create_point_cloud_and_extract_masks( color_img: np.ndarray, depth_img: np.ndarray, - masks: List[np.ndarray], + masks: list[np.ndarray], intrinsic: np.ndarray, depth_scale: float = 1.0, depth_trunc: float = 3.0, -) -> Tuple[o3d.geometry.PointCloud, List[o3d.geometry.PointCloud]]: +) -> tuple[o3d.geometry.PointCloud, list[o3d.geometry.PointCloud]]: """ Efficiently create a point cloud once and extract multiple masked regions. @@ -267,7 +269,7 @@ def create_point_cloud_and_extract_masks( def filter_point_cloud_statistical( pcd: o3d.geometry.PointCloud, nb_neighbors: int = 20, std_ratio: float = 2.0 -) -> Tuple[o3d.geometry.PointCloud, np.ndarray]: +) -> tuple[o3d.geometry.PointCloud, np.ndarray]: """ Apply statistical outlier filtering to point cloud. @@ -287,7 +289,7 @@ def filter_point_cloud_statistical( def filter_point_cloud_radius( pcd: o3d.geometry.PointCloud, nb_points: int = 16, radius: float = 0.05 -) -> Tuple[o3d.geometry.PointCloud, np.ndarray]: +) -> tuple[o3d.geometry.PointCloud, np.ndarray]: """ Apply radius-based outlier filtering to point cloud. @@ -307,9 +309,9 @@ def filter_point_cloud_radius( def overlay_point_clouds_on_image( base_image: np.ndarray, - point_clouds: List[o3d.geometry.PointCloud], - camera_intrinsics: Union[List[float], np.ndarray], - colors: List[Tuple[int, int, int]], + point_clouds: list[o3d.geometry.PointCloud], + camera_intrinsics: list[float] | np.ndarray, + colors: list[tuple[int, int, int]], point_size: int = 2, alpha: float = 0.7, ) -> np.ndarray: @@ -384,7 +386,7 @@ def overlay_point_clouds_on_image( def create_point_cloud_overlay_visualization( base_image: np.ndarray, - objects: List[dict], + objects: list[dict], intrinsics: np.ndarray, ) -> np.ndarray: """ @@ -455,7 +457,7 @@ def create_point_cloud_overlay_visualization( return result -def create_3d_bounding_box_corners(position, rotation, size): +def create_3d_bounding_box_corners(position, rotation, size: int): """ Create 8 corners of a 3D bounding box from position, rotation, and size. @@ -526,7 +528,7 @@ def create_3d_bounding_box_corners(position, rotation, size): return rotated_corners -def draw_3d_bounding_box_on_image(image, corners_2d, color, thickness=2): +def draw_3d_bounding_box_on_image(image, corners_2d, color, thickness: int = 2) -> None: """ Draw a 3D bounding box on an image using projected 2D corners. @@ -561,12 +563,12 @@ def draw_3d_bounding_box_on_image(image, corners_2d, color, thickness=2): def extract_and_cluster_misc_points( full_pcd: o3d.geometry.PointCloud, - all_objects: List[dict], + all_objects: list[dict], eps: float = 0.03, min_points: int = 100, enable_filtering: bool = True, voxel_size: float = 0.02, -) -> Tuple[List[o3d.geometry.PointCloud], o3d.geometry.VoxelGrid]: +) -> tuple[list[o3d.geometry.PointCloud], o3d.geometry.VoxelGrid]: """ Extract miscellaneous/background points and cluster them using DBSCAN. @@ -726,7 +728,7 @@ def _create_voxel_grid_from_point_cloud( def _create_voxel_grid_from_clusters( - clusters: List[o3d.geometry.PointCloud], voxel_size: float = 0.02 + clusters: list[o3d.geometry.PointCloud], voxel_size: float = 0.02 ) -> o3d.geometry.VoxelGrid: """ Create a voxel grid from multiple clustered point clouds. @@ -761,7 +763,7 @@ def _create_voxel_grid_from_clusters( def _cluster_point_cloud_dbscan( pcd: o3d.geometry.PointCloud, eps: float = 0.05, min_points: int = 50 -) -> List[o3d.geometry.PointCloud]: +) -> list[o3d.geometry.PointCloud]: """ Cluster a point cloud using DBSCAN and return list of clustered point clouds. @@ -836,7 +838,7 @@ def get_standard_coordinate_transform(): def visualize_clustered_point_clouds( - clustered_pcds: List[o3d.geometry.PointCloud], + clustered_pcds: list[o3d.geometry.PointCloud], window_name: str = "Clustered Point Clouds", point_size: float = 2.0, show_coordinate_frame: bool = True, @@ -1000,8 +1002,8 @@ def visualize_voxel_grid( def combine_object_pointclouds( - point_clouds: Union[List[np.ndarray], List[o3d.geometry.PointCloud]], - colors: Optional[List[np.ndarray]] = None, + point_clouds: list[np.ndarray] | list[o3d.geometry.PointCloud], + colors: list[np.ndarray] | None = None, ) -> o3d.geometry.PointCloud: """ Combine multiple point clouds into a single Open3D point cloud. @@ -1044,9 +1046,9 @@ def combine_object_pointclouds( def extract_centroids_from_masks( rgb_image: np.ndarray, depth_image: np.ndarray, - masks: List[np.ndarray], - camera_intrinsics: Union[List[float], np.ndarray], -) -> List[Dict[str, Any]]: + masks: list[np.ndarray], + camera_intrinsics: list[float] | np.ndarray, +) -> list[dict[str, Any]]: """ Extract 3D centroids and orientations from segmentation masks. diff --git a/dimos/perception/segmentation/__init__.py b/dimos/perception/segmentation/__init__.py index a8f9a291ce..a48a76d6a4 100644 --- a/dimos/perception/segmentation/__init__.py +++ b/dimos/perception/segmentation/__init__.py @@ -1,2 +1,2 @@ -from .utils import * from .sam_2d_seg import * +from .utils import * diff --git a/dimos/perception/segmentation/image_analyzer.py b/dimos/perception/segmentation/image_analyzer.py index 1260e41fe7..074ee7d605 100644 --- a/dimos/perception/segmentation/image_analyzer.py +++ b/dimos/perception/segmentation/image_analyzer.py @@ -13,10 +13,11 @@ # limitations under the License. import base64 -from openai import OpenAI -import cv2 import os +import cv2 +from openai import OpenAI + NORMAL_PROMPT = "What are in these images? Give a short word answer with at most two words, \ if not sure, give a description of its shape or color like 'small tube', 'blue item'. \" \ if does not look like an object, say 'unknown'. Export objects as a list of strings \ @@ -33,7 +34,7 @@ class ImageAnalyzer: - def __init__(self): + def __init__(self) -> None: """ Initializes the ImageAnalyzer with OpenAI API credentials. """ @@ -52,7 +53,7 @@ def encode_image(self, image): _, buffer = cv2.imencode(".jpg", image) return base64.b64encode(buffer).decode("utf-8") - def analyze_images(self, images, detail="auto", prompt_type="normal"): + def analyze_images(self, images, detail: str = "auto", prompt_type: str = "normal"): """ Takes a list of cropped images and returns descriptions from OpenAI's Vision model. @@ -87,7 +88,7 @@ def analyze_images(self, images, detail="auto", prompt_type="normal"): messages=[ { "role": "user", - "content": [{"type": "text", "text": prompt}] + image_data, + "content": [{"type": "text", "text": prompt}, *image_data], } ], max_tokens=300, @@ -95,10 +96,10 @@ def analyze_images(self, images, detail="auto", prompt_type="normal"): ) # Accessing the content of the response using dot notation - return [choice.message.content for choice in response.choices][0] + return next(choice.message.content for choice in response.choices) -def main(): +def main() -> None: # Define the directory containing cropped images cropped_images_dir = "cropped_images" if not os.path.exists(cropped_images_dir): @@ -130,7 +131,7 @@ def main(): object_list = [item.strip()[2:] for item in results.split("\n")] # Overlay text on images and display them - for i, (img, obj) in enumerate(zip(images, object_list)): + for i, (img, obj) in enumerate(zip(images, object_list, strict=False)): if obj: # Only process non-empty lines # Add text to image font = cv2.FONT_HERSHEY_SIMPLEX diff --git a/dimos/perception/segmentation/sam_2d_seg.py b/dimos/perception/segmentation/sam_2d_seg.py index cb2acaf076..b13ebc4c65 100644 --- a/dimos/perception/segmentation/sam_2d_seg.py +++ b/dimos/perception/segmentation/sam_2d_seg.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import time from collections import deque +from collections.abc import Sequence from concurrent.futures import ThreadPoolExecutor +import os +import time import cv2 import onnxruntime @@ -32,7 +33,6 @@ from dimos.utils.data import get_data from dimos.utils.gpu_utils import is_cuda_available from dimos.utils.logging_config import setup_logger -from dimos.utils.path_utils import get_project_root logger = setup_logger("dimos.perception.segmentation.sam_2d_seg") @@ -40,14 +40,14 @@ class Sam2DSegmenter: def __init__( self, - model_path="models_fastsam", - model_name="FastSAM-s.onnx", - min_analysis_interval=5.0, - use_tracker=True, - use_analyzer=True, - use_rich_labeling=False, - use_filtering=True, - ): + model_path: str = "models_fastsam", + model_name: str = "FastSAM-s.onnx", + min_analysis_interval: float = 5.0, + use_tracker: bool = True, + use_analyzer: bool = True, + use_rich_labeling: bool = False, + use_filtering: bool = True, + ) -> None: if is_cuda_available(): logger.info("Using CUDA for SAM 2d segmenter") if hasattr(onnxruntime, "preload_dlls"): # Handles CUDA 11 / onnxruntime-gpu<=1.18 @@ -225,7 +225,7 @@ def check_analysis_status(self, tracked_target_ids): if results is not None: # Map results to track IDs object_list = eval(results) - for track_id, result in zip(self.current_queue_ids, object_list): + for track_id, result in zip(self.current_queue_ids, object_list, strict=False): self.object_names[track_id] = result except Exception as e: print(f"Queue analysis failed: {e}") @@ -255,7 +255,7 @@ def check_analysis_status(self, tracked_target_ids): return queue_indices, queue_ids return None, None - def run_analysis(self, frame, tracked_bboxes, tracked_target_ids): + def run_analysis(self, frame, tracked_bboxes, tracked_target_ids) -> None: """Run queue image analysis in background.""" if not self.use_analyzer: return @@ -278,27 +278,29 @@ def run_analysis(self, frame, tracked_bboxes, tracked_target_ids): self.image_analyzer.analyze_images, cropped_images, prompt_type=prompt_type ) - def get_object_names(self, track_ids, tracked_names): + def get_object_names(self, track_ids, tracked_names: Sequence[str]): """Get object names for the given track IDs, falling back to tracked names.""" if not self.use_analyzer: return tracked_names return [ self.object_names.get(track_id, tracked_name) - for track_id, tracked_name in zip(track_ids, tracked_names) + for track_id, tracked_name in zip(track_ids, tracked_names, strict=False) ] - def visualize_results(self, image, masks, bboxes, track_ids, probs, names): + def visualize_results( + self, image, masks, bboxes, track_ids, probs: Sequence[float], names: Sequence[str] + ): """Generate an overlay visualization with segmentation results and object names.""" return plot_results(image, masks, bboxes, track_ids, probs, names) - def cleanup(self): + def cleanup(self) -> None: """Cleanup resources.""" if self.use_analyzer: self.analysis_executor.shutdown() -def main(): +def main() -> None: # Example usage with different configurations cap = cv2.VideoCapture(0) @@ -328,7 +330,7 @@ def main(): if not ret: break - start_time = time.time() + time.time() # Process image and get results masks, bboxes, target_ids, probs, names = segmenter.process_image(frame) diff --git a/dimos/perception/segmentation/test_sam_2d_seg.py b/dimos/perception/segmentation/test_sam_2d_seg.py index 297b265415..23eaf02fa3 100644 --- a/dimos/perception/segmentation/test_sam_2d_seg.py +++ b/dimos/perception/segmentation/test_sam_2d_seg.py @@ -15,21 +15,18 @@ import os import time -import cv2 import numpy as np import pytest -import reactivex as rx from reactivex import operators as ops from dimos.perception.segmentation.sam_2d_seg import Sam2DSegmenter from dimos.perception.segmentation.utils import extract_masks_bboxes_probs_names -from dimos.stream import video_provider from dimos.stream.video_provider import VideoProvider @pytest.mark.heavy class TestSam2DSegmenter: - def test_sam_segmenter_initialization(self): + def test_sam_segmenter_initialization(self) -> None: """Test FastSAM segmenter initializes correctly with default model path.""" try: # Try to initialize with the default model path and existing device setting @@ -40,7 +37,7 @@ def test_sam_segmenter_initialization(self): # If the model file doesn't exist, the test should still pass with a warning pytest.skip(f"Skipping test due to model initialization error: {e}") - def test_sam_segmenter_process_image(self): + def test_sam_segmenter_process_image(self) -> None: """Test FastSAM segmenter can process video frames and return segmentation masks.""" # Import get data inside method to avoid pytest fixture confusion from dimos.utils.data import get_data @@ -53,7 +50,6 @@ def test_sam_segmenter_process_image(self): # Note: conf and iou are parameters for process_image, not constructor # We'll monkey patch the process_image method to use lower thresholds - original_process_image = segmenter.process_image def patched_process_image(image): results = segmenter.model.track( @@ -70,7 +66,7 @@ def patched_process_image(image): ) if len(results) > 0: - masks, bboxes, track_ids, probs, names, areas = ( + masks, bboxes, track_ids, probs, names, _areas = ( extract_masks_bboxes_probs_names(results[0]) ) return masks, bboxes, track_ids, probs, names @@ -114,7 +110,7 @@ def process_frame(frame): frames_processed = 0 target_frames = 5 - def on_next(result): + def on_next(result) -> None: nonlocal frames_processed, results if not result: return @@ -126,10 +122,10 @@ def on_next(result): if frames_processed >= target_frames: subscription.dispose() - def on_error(error): + def on_error(error) -> None: pytest.fail(f"Error in segmentation stream: {error}") - def on_completed(): + def on_completed() -> None: pass # Subscribe and wait for results diff --git a/dimos/perception/segmentation/utils.py b/dimos/perception/segmentation/utils.py index 4101edfa40..24d6ce4bf2 100644 --- a/dimos/perception/segmentation/utils.py +++ b/dimos/perception/segmentation/utils.py @@ -12,13 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np +from collections.abc import Sequence + import cv2 +import numpy as np import torch class SimpleTracker: - def __init__(self, history_size=100, min_count=10, count_window=20): + def __init__( + self, history_size: int = 100, min_count: int = 10, count_window: int = 20 + ) -> None: """ Simple temporal tracker that counts appearances in a fixed window. :param history_size: Number of past frames to remember @@ -43,12 +47,12 @@ def update(self, track_ids): # Compute occurrences efficiently using numpy unique_ids, counts = np.unique(all_tracks, return_counts=True) - id_counts = dict(zip(unique_ids, counts)) + id_counts = dict(zip(unique_ids, counts, strict=False)) # Update total counts but ensure it only contains IDs within the history size total_tracked_ids = np.concatenate(self.history) if self.history else np.array([]) unique_total_ids, total_counts = np.unique(total_tracked_ids, return_counts=True) - self.total_counts = dict(zip(unique_total_ids, total_counts)) + self.total_counts = dict(zip(unique_total_ids, total_counts, strict=False)) # Return IDs that appear often enough return [track_id for track_id, count in id_counts.items() if count >= self.min_count] @@ -58,7 +62,7 @@ def get_total_counts(self): return self.total_counts -def extract_masks_bboxes_probs_names(result, max_size=0.7): +def extract_masks_bboxes_probs_names(result, max_size: float = 0.7): """ Extracts masks, bounding boxes, probabilities, and class names from one Ultralytics result object. @@ -81,7 +85,7 @@ def extract_masks_bboxes_probs_names(result, max_size=0.7): total_area = result.masks.orig_shape[0] * result.masks.orig_shape[1] - for box, mask_data in zip(result.boxes, result.masks.data): + for box, mask_data in zip(result.boxes, result.masks.data, strict=False): mask_numpy = mask_data # Extract bounding box @@ -110,7 +114,7 @@ def extract_masks_bboxes_probs_names(result, max_size=0.7): return masks, bboxes, track_ids, probs, names, areas -def compute_texture_map(frame, blur_size=3): +def compute_texture_map(frame, blur_size: int = 3): """ Compute texture map using gradient statistics. Returns high values for textured regions and low values for smooth regions. @@ -149,7 +153,15 @@ def compute_texture_map(frame, blur_size=3): def filter_segmentation_results( - frame, masks, bboxes, track_ids, probs, names, areas, texture_threshold=0.07, size_filter=800 + frame, + masks, + bboxes, + track_ids, + probs: Sequence[float], + names: Sequence[str], + areas, + texture_threshold: float = 0.07, + size_filter: int = 800, ): """ Filters segmentation results using both overlap and saliency detection. @@ -228,7 +240,15 @@ def filter_segmentation_results( ) -def plot_results(image, masks, bboxes, track_ids, probs, names, alpha=0.5): +def plot_results( + image, + masks, + bboxes, + track_ids, + probs: Sequence[float], + names: Sequence[str], + alpha: float = 0.5, +): """ Draws bounding boxes, masks, and labels on the given image with enhanced visualization. Includes object names in the overlay and improved text visibility. @@ -236,7 +256,9 @@ def plot_results(image, masks, bboxes, track_ids, probs, names, alpha=0.5): h, w = image.shape[:2] overlay = image.copy() - for mask, bbox, track_id, prob, name in zip(masks, bboxes, track_ids, probs, names): + for mask, bbox, track_id, prob, name in zip( + masks, bboxes, track_ids, probs, names, strict=False + ): # Convert mask tensor to numpy if needed if isinstance(mask, torch.Tensor): mask = mask.cpu().numpy() @@ -291,7 +313,7 @@ def plot_results(image, masks, bboxes, track_ids, probs, names, alpha=0.5): return result -def crop_images_from_bboxes(image, bboxes, buffer=0): +def crop_images_from_bboxes(image, bboxes, buffer: int = 0): """ Crops regions from an image based on bounding boxes with an optional buffer. diff --git a/dimos/perception/spatial_perception.py b/dimos/perception/spatial_perception.py index 45faad5b12..7a96939431 100644 --- a/dimos/perception/spatial_perception.py +++ b/dimos/perception/spatial_perception.py @@ -16,16 +16,15 @@ Spatial Memory module for creating a semantic map of the environment. """ +from datetime import datetime import os import time +from typing import TYPE_CHECKING, Any, Optional import uuid -from datetime import datetime -from typing import Any, Dict, List, Optional import cv2 import numpy as np -from reactivex import Observable, disposable, interval -from reactivex import operators as ops +from reactivex import Observable, disposable, interval, operators as ops from reactivex.disposable import Disposable from dimos import spec @@ -34,12 +33,13 @@ from dimos.agents.memory.visual_memory import VisualMemory from dimos.constants import DIMOS_PROJECT_ROOT from dimos.core import DimosCluster, In, Module, rpc -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Vector3 from dimos.msgs.sensor_msgs import Image from dimos.types.robot_location import RobotLocation -from dimos.types.vector import Vector from dimos.utils.logging_config import setup_logger +if TYPE_CHECKING: + from dimos.msgs.geometry_msgs import Vector3 + _OUTPUT_DIR = DIMOS_PROJECT_ROOT / "assets" / "output" _MEMORY_DIR = _OUTPUT_DIR / "memory" _SPATIAL_MEMORY_DIR = _MEMORY_DIR / "spatial_memory" @@ -70,19 +70,19 @@ def __init__( embedding_dimensions: int = 512, min_distance_threshold: float = 0.01, # Min distance in meters to store a new frame min_time_threshold: float = 1.0, # Min time in seconds to record a new frame - db_path: Optional[str] = str(_DB_PATH), # Path for ChromaDB persistence - visual_memory_path: Optional[str] = str( + db_path: str | None = str(_DB_PATH), # Path for ChromaDB persistence + visual_memory_path: str | None = str( _VISUAL_MEMORY_PATH ), # Path for saving/loading visual memory new_memory: bool = True, # Whether to create a new memory from scratch - output_dir: Optional[str] = str( + output_dir: str | None = str( _SPATIAL_MEMORY_DIR ), # Directory for storing visual memory data chroma_client: Any = None, # Optional ChromaDB client for persistence visual_memory: Optional[ "VisualMemory" ] = None, # Optional VisualMemory instance for storing images - ): + ) -> None: """ Initialize the spatial perception system. @@ -167,8 +167,8 @@ def __init__( embedding_provider=self.embedding_provider, ) - self.last_position: Optional[Vector3] = None - self.last_record_time: Optional[float] = None + self.last_position: Vector3 | None = None + self.last_record_time: float | None = None self.frame_count: int = 0 self.stored_frame_count: int = 0 @@ -177,20 +177,20 @@ def __init__( self._subscription = None # List to store robot locations - self.robot_locations: List[RobotLocation] = [] + self.robot_locations: list[RobotLocation] = [] # Track latest data for processing - self._latest_video_frame: Optional[np.ndarray] = None + self._latest_video_frame: np.ndarray | None = None self._process_interval = 1 logger.info(f"SpatialMemory initialized with model {embedding_model}") @rpc - def start(self): + def start(self) -> None: super().start() # Subscribe to LCM streams - def set_video(image_msg: Image): + def set_video(image_msg: Image) -> None: # Convert Image message to numpy array if hasattr(image_msg, "data"): frame = image_msg.data @@ -207,7 +207,7 @@ def set_video(image_msg: Image): self._disposables.add(Disposable(unsub)) @rpc - def stop(self): + def stop(self) -> None: self.stop_continuous_processing() # Save data before shutdown @@ -218,7 +218,7 @@ def stop(self): super().stop() - def _process_frame(self): + def _process_frame(self) -> None: """Process the latest frame with pose data if available.""" tf = self.tf.get("map", "base_link") if self._latest_video_frame is None or tf is None: @@ -309,7 +309,7 @@ def _process_frame(self): @rpc def query_by_location( self, x: float, y: float, radius: float = 2.0, limit: int = 5 - ) -> List[Dict]: + ) -> list[dict]: """ Query the vector database for images near the specified location. @@ -374,7 +374,7 @@ def stop_continuous_processing(self) -> None: except Exception as e: logger.error(f"Error stopping spatial memory processing: {e}") - def _on_frame_processed(self, result: Dict[str, Any]) -> None: + def _on_frame_processed(self, result: dict[str, Any]) -> None: """ Handle updates from the spatial memory processing stream. """ @@ -501,7 +501,7 @@ def process_combined_data(data): ) @rpc - def query_by_image(self, image: np.ndarray, limit: int = 5) -> List[Dict]: + def query_by_image(self, image: np.ndarray, limit: int = 5) -> list[dict]: """ Query the vector database for images similar to the provided image. @@ -516,7 +516,7 @@ def query_by_image(self, image: np.ndarray, limit: int = 5) -> List[Dict]: return self.vector_db.query_by_embedding(embedding, limit) @rpc - def query_by_text(self, text: str, limit: int = 5) -> List[Dict]: + def query_by_text(self, text: str, limit: int = 5) -> list[dict]: """ Query the vector database for images matching the provided text description. @@ -558,9 +558,9 @@ def add_robot_location(self, location: RobotLocation) -> bool: def add_named_location( self, name: str, - position: Optional[List[float]] = None, - rotation: Optional[List[float]] = None, - description: Optional[str] = None, + position: list[float] | None = None, + rotation: list[float] | None = None, + description: str | None = None, ) -> bool: """ Add a named robot location to spatial memory using current or specified position. @@ -591,7 +591,7 @@ def add_named_location( return self.add_robot_location(location) @rpc - def get_robot_locations(self) -> List[RobotLocation]: + def get_robot_locations(self) -> list[RobotLocation]: """ Get all stored robot locations. @@ -601,7 +601,7 @@ def get_robot_locations(self) -> List[RobotLocation]: return self.robot_locations @rpc - def find_robot_location(self, name: str) -> Optional[RobotLocation]: + def find_robot_location(self, name: str) -> RobotLocation | None: """ Find a robot location by name. @@ -619,7 +619,7 @@ def find_robot_location(self, name: str) -> Optional[RobotLocation]: return None @rpc - def get_stats(self) -> Dict[str, int]: + def get_stats(self) -> dict[str, int]: """Get statistics about the spatial memory module. Returns: @@ -638,7 +638,7 @@ def tag_location(self, robot_location: RobotLocation) -> bool: return True @rpc - def query_tagged_location(self, query: str) -> Optional[RobotLocation]: + def query_tagged_location(self, query: str) -> RobotLocation | None: location, semantic_distance = self.vector_db.query_tagged_location(query) if semantic_distance < 0.3: return location @@ -657,4 +657,4 @@ def deploy( spatial_memory = SpatialMemory.blueprint -__all__ = ["SpatialMemory", "spatial_memory", "deploy"] +__all__ = ["SpatialMemory", "deploy", "spatial_memory"] diff --git a/dimos/perception/test_spatial_memory.py b/dimos/perception/test_spatial_memory.py index cde2b7d45c..f42638df73 100644 --- a/dimos/perception/test_spatial_memory.py +++ b/dimos/perception/test_spatial_memory.py @@ -17,13 +17,9 @@ import tempfile import time -import cv2 import numpy as np import pytest -import reactivex as rx -from reactivex import Observable from reactivex import operators as ops -from reactivex.subject import Subject from dimos.msgs.geometry_msgs import Pose from dimos.perception.spatial_perception import SpatialMemory @@ -57,14 +53,14 @@ def spatial_memory(self, temp_dir): # Clean up memory.stop() - def test_spatial_memory_initialization(self, spatial_memory): + def test_spatial_memory_initialization(self, spatial_memory) -> None: """Test SpatialMemory initializes correctly with CLIP model.""" # Use the shared spatial_memory fixture assert spatial_memory is not None assert spatial_memory.embedding_model == "clip" assert spatial_memory.embedding_provider is not None - def test_image_embedding(self, spatial_memory): + def test_image_embedding(self, spatial_memory) -> None: """Test generating image embeddings using CLIP.""" # Use the shared spatial_memory fixture # Create a test image - use a simple colored square @@ -89,7 +85,7 @@ def test_image_embedding(self, spatial_memory): assert text_embedding.shape[0] == spatial_memory.embedding_dimensions assert np.isclose(np.linalg.norm(text_embedding), 1.0, atol=1e-5) - def test_spatial_memory_processing(self, spatial_memory, temp_dir): + def test_spatial_memory_processing(self, spatial_memory, temp_dir) -> None: """Test processing video frames and building spatial memory with CLIP embeddings.""" try: # Use the shared spatial_memory fixture @@ -136,7 +132,7 @@ def process_frame(frame): frames_processed = 0 target_frames = 100 # Process more frames for thorough testing - def on_next(result): + def on_next(result) -> None: nonlocal results, frames_processed if not result: # Skip None results return @@ -148,10 +144,10 @@ def on_next(result): if frames_processed >= target_frames: subscription.dispose() - def on_error(error): + def on_error(error) -> None: pytest.fail(f"Error in spatial stream: {error}") - def on_completed(): + def on_completed() -> None: pass # Subscribe and wait for results diff --git a/dimos/perception/test_spatial_memory_module.py b/dimos/perception/test_spatial_memory_module.py index 5166ef2443..f89c975d89 100644 --- a/dimos/perception/test_spatial_memory_module.py +++ b/dimos/perception/test_spatial_memory_module.py @@ -14,26 +14,21 @@ import asyncio import os -import shutil import tempfile import time -from typing import Dict, List -import numpy as np import pytest from reactivex import operators as ops from dimos import core -from dimos.core import Module, In, Out, rpc +from dimos.core import Module, Out, rpc from dimos.msgs.sensor_msgs import Image -from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.perception.spatial_perception import SpatialMemory from dimos.protocol import pubsub +from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.utils.data import get_data -from dimos.utils.testing import TimedSensorReplay from dimos.utils.logging_config import setup_logger -from unittest.mock import patch, MagicMock -import warnings +from dimos.utils.testing import TimedSensorReplay logger = setup_logger("test_spatial_memory_module") @@ -45,13 +40,13 @@ class VideoReplayModule(Module): video_out: Out[Image] = None - def __init__(self, video_path: str): + def __init__(self, video_path: str) -> None: super().__init__() self.video_path = video_path self._subscription = None @rpc - def start(self): + def start(self) -> None: """Start replaying video data.""" # Use TimedSensorReplay to replay video frames video_replay = TimedSensorReplay(self.video_path, autocast=Image.from_numpy) @@ -69,7 +64,7 @@ def start(self): logger.info("VideoReplayModule started") @rpc - def stop(self): + def stop(self) -> None: """Stop replaying video data.""" if self._subscription: self._subscription.dispose() @@ -82,13 +77,13 @@ class OdometryReplayModule(Module): odom_out: Out[Odometry] = None - def __init__(self, odom_path: str): + def __init__(self, odom_path: str) -> None: super().__init__() self.odom_path = odom_path self._subscription = None @rpc - def start(self): + def start(self) -> None: """Start replaying odometry data.""" # Use TimedSensorReplay to replay odometry odom_replay = TimedSensorReplay(self.odom_path, autocast=Odometry.from_msg) @@ -106,7 +101,7 @@ def start(self): logger.info("OdometryReplayModule started") @rpc - def stop(self): + def stop(self) -> None: """Stop replaying odometry data.""" if self._subscription: self._subscription.dispose() @@ -189,7 +184,7 @@ async def test_spatial_memory_module_with_replay(self, temp_dir): logger.error( f"Timeout after {timeout}s - Frame count: {stats['frame_count']}, Stored: {stats['stored_frame_count']}" ) - assert False, f"No frames processed within {timeout} seconds" + raise AssertionError(f"No frames processed within {timeout} seconds") await asyncio.sleep(2) diff --git a/dimos/protocol/encode/__init__.py b/dimos/protocol/encode/__init__.py index cce141527f..66bbbbb21c 100644 --- a/dimos/protocol/encode/__init__.py +++ b/dimos/protocol/encode/__init__.py @@ -1,5 +1,5 @@ -import json from abc import ABC, abstractmethod +import json from typing import Generic, Protocol, TypeVar MsgT = TypeVar("MsgT") @@ -67,7 +67,7 @@ def decode(data: bytes) -> LCMMsgT: class LCMTypedEncoder(LCM, Generic[LCMMsgT]): """Typed LCM encoder for specific message types.""" - def __init__(self, message_type: type[LCMMsgT]): + def __init__(self, message_type: type[LCMMsgT]) -> None: self.message_type = message_type @staticmethod diff --git a/dimos/protocol/pubsub/lcmpubsub.py b/dimos/protocol/pubsub/lcmpubsub.py index 5fda6dbb83..70a4034a1e 100644 --- a/dimos/protocol/pubsub/lcmpubsub.py +++ b/dimos/protocol/pubsub/lcmpubsub.py @@ -14,21 +14,16 @@ from __future__ import annotations -import pickle -import subprocess -import sys -import threading -import traceback from dataclasses import dataclass -from typing import Any, Callable, Optional, Protocol, runtime_checkable - -import lcm +from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub, PubSubEncoderMixin -from dimos.protocol.service.lcmservice import LCMConfig, LCMService, autoconf, check_system -from dimos.utils.deprecation import deprecated +from dimos.protocol.service.lcmservice import LCMConfig, LCMService from dimos.utils.logging_config import setup_logger +if TYPE_CHECKING: + from collections.abc import Callable + import threading logger = setup_logger(__name__) @@ -38,7 +33,7 @@ class LCMMsg(Protocol): msg_name: str @classmethod - def lcm_decode(cls, data: bytes) -> "LCMMsg": + def lcm_decode(cls, data: bytes) -> LCMMsg: """Decode bytes into an LCM message instance.""" ... @@ -50,7 +45,7 @@ def lcm_encode(self) -> bytes: @dataclass class Topic: topic: str = "" - lcm_type: Optional[type[LCMMsg]] = None + lcm_type: type[LCMMsg] | None = None def __str__(self) -> str: if self.lcm_type is None: @@ -61,7 +56,7 @@ def __str__(self) -> str: class LCMPubSubBase(LCMService, PubSub[Topic, Any]): default_config = LCMConfig _stop_event: threading.Event - _thread: Optional[threading.Thread] + _thread: threading.Thread | None _callbacks: dict[str, list[Callable[[Any], None]]] def __init__(self, **kwargs) -> None: @@ -69,7 +64,7 @@ def __init__(self, **kwargs) -> None: super().__init__(**kwargs) self._callbacks = {} - def publish(self, topic: Topic, message: bytes): + def publish(self, topic: Topic, message: bytes) -> None: """Publish a message to the specified channel.""" if self.l is None: logger.error("Tried to publish after LCM was closed") @@ -82,14 +77,14 @@ def subscribe( if self.l is None: logger.error("Tried to subscribe after LCM was closed") - def noop(): + def noop() -> None: pass return noop lcm_subscription = self.l.subscribe(str(topic), lambda _, msg: callback(msg, topic)) - def unsubscribe(): + def unsubscribe() -> None: if self.l is None: return self.l.unsubscribe(lcm_subscription) diff --git a/dimos/protocol/pubsub/memory.py b/dimos/protocol/pubsub/memory.py index 35e93b0754..513dfd32cd 100644 --- a/dimos/protocol/pubsub/memory.py +++ b/dimos/protocol/pubsub/memory.py @@ -13,7 +13,8 @@ # limitations under the License. from collections import defaultdict -from typing import Any, Callable, DefaultDict, List +from collections.abc import Callable +from typing import Any from dimos.protocol import encode from dimos.protocol.pubsub.spec import PubSub, PubSubEncoderMixin @@ -21,7 +22,7 @@ class Memory(PubSub[str, Any]): def __init__(self) -> None: - self._map: DefaultDict[str, List[Callable[[Any, str], None]]] = defaultdict(list) + self._map: defaultdict[str, list[Callable[[Any, str], None]]] = defaultdict(list) def publish(self, topic: str, message: Any) -> None: for cb in self._map[topic]: @@ -30,7 +31,7 @@ def publish(self, topic: str, message: Any) -> None: def subscribe(self, topic: str, callback: Callable[[Any, str], None]) -> Callable[[], None]: self._map[topic].append(callback) - def unsubscribe(): + def unsubscribe() -> None: try: self._map[topic].remove(callback) if not self._map[topic]: diff --git a/dimos/protocol/pubsub/redispubsub.py b/dimos/protocol/pubsub/redispubsub.py index 42128e0d0c..7d6c798f2c 100644 --- a/dimos/protocol/pubsub/redispubsub.py +++ b/dimos/protocol/pubsub/redispubsub.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict +from collections.abc import Callable +from dataclasses import dataclass, field import json import threading import time -from collections import defaultdict -from dataclasses import dataclass, field -from typing import Any, Callable, Dict, List +from types import TracebackType +from typing import Any import redis @@ -30,7 +32,7 @@ class RedisConfig: host: str = "localhost" port: int = 6379 db: int = 0 - kwargs: Dict[str, Any] = field(default_factory=dict) + kwargs: dict[str, Any] = field(default_factory=dict) class Redis(PubSub[str, Any], Service[RedisConfig]): @@ -46,7 +48,7 @@ def __init__(self, **kwargs) -> None: self._pubsub = None # Subscription management - self._callbacks: Dict[str, List[Callable[[Any, str], None]]] = defaultdict(list) + self._callbacks: dict[str, list[Callable[[Any, str], None]]] = defaultdict(list) self._listener_thread = None self._running = False @@ -85,7 +87,7 @@ def _connect(self): f"Failed to connect to Redis at {self.config.host}:{self.config.port}: {e}" ) - def _listen_loop(self): + def _listen_loop(self) -> None: """Listen for messages from Redis and dispatch to callbacks.""" while self._running: try: @@ -141,7 +143,7 @@ def subscribe(self, topic: str, callback: Callable[[Any, str], None]) -> Callabl self._callbacks[topic].append(callback) # Return unsubscribe function - def unsubscribe(): + def unsubscribe() -> None: self.unsubscribe(topic, callback) return unsubscribe @@ -161,7 +163,7 @@ def unsubscribe(self, topic: str, callback: Callable[[Any, str], None]) -> None: except ValueError: pass # Callback wasn't in the list - def close(self): + def close(self) -> None: """Close Redis connections and stop listener thread.""" self._running = False @@ -187,5 +189,10 @@ def close(self): def __enter__(self): return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: self.close() diff --git a/dimos/protocol/pubsub/shm/ipc_factory.py b/dimos/protocol/pubsub/shm/ipc_factory.py index 3d6dbc17e3..9aedbfa1c4 100644 --- a/dimos/protocol/pubsub/shm/ipc_factory.py +++ b/dimos/protocol/pubsub/shm/ipc_factory.py @@ -14,15 +14,12 @@ # frame_ipc.py # Python 3.9+ -import base64 -import time from abc import ABC, abstractmethod +from multiprocessing.shared_memory import SharedMemory import os -from typing import Optional, Tuple +import time import numpy as np -from multiprocessing.shared_memory import SharedMemory -from multiprocessing.managers import SharedMemoryManager _UNLINK_ON_GC = os.getenv("DIMOS_IPC_UNLINK_ON_GC", "0").lower() not in ("0", "false", "no") @@ -99,10 +96,11 @@ def close(self) -> None: from multiprocessing.shared_memory import SharedMemory -import weakref, os +import os +import weakref -def _safe_unlink(name): +def _safe_unlink(name: str) -> None: try: shm = SharedMemory(name=name) shm.unlink() @@ -118,12 +116,19 @@ def _safe_unlink(name): class CpuShmChannel(FrameChannel): - def __init__(self, shape, dtype=np.uint8, *, data_name=None, ctrl_name=None): + def __init__( + self, + shape, + dtype=np.uint8, + *, + data_name: str | None = None, + ctrl_name: str | None = None, + ) -> None: self._shape = tuple(shape) self._dtype = np.dtype(dtype) self._nbytes = int(self._dtype.itemsize * np.prod(self._shape)) - def _create_or_open(name, size): + def _create_or_open(name: str, size: int): try: shm = SharedMemory(create=True, size=size, name=name) owner = True @@ -169,7 +174,7 @@ def descriptor(self): } @property - def device(self): + def device(self) -> str: return "cpu" @property @@ -180,7 +185,7 @@ def shape(self): def dtype(self): return self._dtype - def publish(self, frame): + def publish(self, frame) -> None: assert isinstance(frame, np.ndarray) assert frame.shape == self._shape and frame.dtype == self._dtype active = int(self._ctrl[2]) @@ -198,7 +203,7 @@ def publish(self, frame): self._ctrl[2] = inactive self._ctrl[0] += 1 - def read(self, last_seq: int = -1, require_new=True): + def read(self, last_seq: int = -1, require_new: bool = True): for _ in range(3): seq1 = int(self._ctrl[0]) idx = int(self._ctrl[2]) @@ -223,7 +228,7 @@ def descriptor(self): } @classmethod - def attach(cls, desc): + def attach(cls, desc: str): obj = object.__new__(cls) obj._shape = tuple(desc["shape"]) obj._dtype = np.dtype(desc["dtype"]) @@ -244,7 +249,7 @@ def attach(cls, desc): obj._finalizer_data = obj._finalizer_ctrl = None return obj - def close(self): + def close(self) -> None: if getattr(self, "_is_owner", False): try: self._shm_ctrl.close() diff --git a/dimos/protocol/pubsub/shmpubsub.py b/dimos/protocol/pubsub/shmpubsub.py index 2d643a32d8..ef67ffb885 100644 --- a/dimos/protocol/pubsub/shmpubsub.py +++ b/dimos/protocol/pubsub/shmpubsub.py @@ -19,22 +19,25 @@ from __future__ import annotations +from collections import defaultdict +from dataclasses import dataclass import hashlib import os import struct import threading import time +from typing import TYPE_CHECKING, Any import uuid -from collections import defaultdict -from dataclasses import dataclass -from typing import Any, Callable, Dict, Optional, Tuple import numpy as np -from dimos.protocol.pubsub.spec import PubSub, PubSubEncoderMixin, PickleEncoderMixin -from dimos.protocol.pubsub.shm.ipc_factory import CpuShmChannel, CPU_IPC_Factory +from dimos.protocol.pubsub.shm.ipc_factory import CpuShmChannel +from dimos.protocol.pubsub.spec import PickleEncoderMixin, PubSub, PubSubEncoderMixin from dimos.utils.logging_config import setup_logger +if TYPE_CHECKING: + from collections.abc import Callable + logger = setup_logger("dimos.protocol.pubsub.sharedmemory") @@ -72,32 +75,32 @@ class SharedMemoryPubSubBase(PubSub[str, Any]): # TODO: implement "is_cuda" below capacity, above cp class _TopicState: __slots__ = ( - "channel", - "subs", - "stop", - "thread", - "last_seq", - "shape", - "dtype", "capacity", + "channel", "cp", + "dtype", "last_local_payload", + "last_seq", + "shape", + "stop", + "subs", "suppress_counts", + "thread", ) - def __init__(self, channel, capacity: int, cp_mod): + def __init__(self, channel, capacity: int, cp_mod) -> None: self.channel = channel self.capacity = int(capacity) self.shape = (self.capacity + 20,) # +20 for header: length(4) + uuid(16) self.dtype = np.uint8 self.subs: list[Callable[[bytes, str], None]] = [] self.stop = threading.Event() - self.thread: Optional[threading.Thread] = None + self.thread: threading.Thread | None = None self.last_seq = 0 # start at 0 to avoid b"" on first poll # TODO: implement an initializer variable for is_cuda once CUDA IPC is in self.cp = cp_mod - self.last_local_payload: Optional[bytes] = None - self.suppress_counts: Dict[bytes, int] = defaultdict(int) # UUID bytes as key + self.last_local_payload: bytes | None = None + self.suppress_counts: dict[bytes, int] = defaultdict(int) # UUID bytes as key # ----- init / lifecycle ------------------------------------------------- @@ -115,7 +118,7 @@ def __init__( default_capacity=default_capacity, close_channels_on_stop=close_channels_on_stop, ) - self._topics: Dict[str, SharedMemoryPubSubBase._TopicState] = {} + self._topics: dict[str, SharedMemoryPubSubBase._TopicState] = {} self._lock = threading.Lock() def start(self) -> None: @@ -126,7 +129,7 @@ def start(self) -> None: def stop(self) -> None: with self._lock: - for topic, st in list(self._topics.items()): + for _topic, st in list(self._topics.items()): # stop fanout try: if st.thread: @@ -193,7 +196,7 @@ def subscribe(self, topic: str, callback: Callable[[bytes, str], Any]) -> Callab st.thread = threading.Thread(target=self._fanout_loop, args=(topic, st), daemon=True) st.thread.start() - def _unsub(): + def _unsub() -> None: try: st.subs.remove(callback) except ValueError: @@ -242,7 +245,7 @@ def _names_for_topic(topic: str, capacity: int) -> tuple[str, str]: def _fanout_loop(self, topic: str, st: _TopicState) -> None: while not st.stop.is_set(): - seq, ts_ns, view = st.channel.read(last_seq=st.last_seq, require_new=True) + seq, _ts_ns, view = st.channel.read(last_seq=st.last_seq, require_new=True) if view is None: time.sleep(0.001) continue diff --git a/dimos/protocol/pubsub/spec.py b/dimos/protocol/pubsub/spec.py index b6ce6695da..ef5a4f450f 100644 --- a/dimos/protocol/pubsub/spec.py +++ b/dimos/protocol/pubsub/spec.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -import pickle from abc import ABC, abstractmethod -from collections.abc import AsyncIterator +import asyncio +from collections.abc import AsyncIterator, Callable from contextlib import asynccontextmanager from dataclasses import dataclass -from typing import Any, Callable, Generic, TypeVar +import pickle +from typing import Any, Generic, TypeVar + from dimos.utils.logging_config import setup_logger MsgT = TypeVar("MsgT") @@ -57,7 +58,7 @@ def unsubscribe(self) -> None: def __enter__(self): return self - def __exit__(self, *exc): + def __exit__(self, *exc) -> None: self.unsubscribe() # public helper: returns disposable object @@ -69,7 +70,7 @@ def sub(self, topic: TopicT, cb: Callable[[MsgT, TopicT], None]) -> "_Subscripti async def aiter(self, topic: TopicT, *, max_pending: int | None = None) -> AsyncIterator[MsgT]: q: asyncio.Queue[MsgT] = asyncio.Queue(maxsize=max_pending or 0) - def _cb(msg: MsgT, topic: TopicT): + def _cb(msg: MsgT, topic: TopicT) -> None: q.put_nowait(msg) unsubscribe_fn = self.subscribe(topic, _cb) @@ -85,7 +86,7 @@ def _cb(msg: MsgT, topic: TopicT): async def queue(self, topic: TopicT, *, max_pending: int | None = None): q: asyncio.Queue[MsgT] = asyncio.Queue(maxsize=max_pending or 0) - def _queue_cb(msg: MsgT, topic: TopicT): + def _queue_cb(msg: MsgT, topic: TopicT) -> None: q.put_nowait(msg) unsubscribe_fn = self.subscribe(topic, _queue_cb) @@ -113,7 +114,7 @@ def encode(self, msg: MsgT, topic: TopicT) -> bytes: ... @abstractmethod def decode(self, msg: bytes, topic: TopicT) -> MsgT: ... - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self._encode_callback_map: dict = {} @@ -131,7 +132,7 @@ def subscribe( ) -> Callable[[], None]: """Subscribe with automatic decoding.""" - def wrapper_cb(encoded_data: bytes, topic: TopicT): + def wrapper_cb(encoded_data: bytes, topic: TopicT) -> None: decoded_message = self.decode(encoded_data, topic) callback(decoded_message, topic) diff --git a/dimos/protocol/pubsub/test_encoder.py b/dimos/protocol/pubsub/test_encoder.py index 4f2d23d7d2..9a47c14105 100644 --- a/dimos/protocol/pubsub/test_encoder.py +++ b/dimos/protocol/pubsub/test_encoder.py @@ -19,12 +19,12 @@ from dimos.protocol.pubsub.memory import Memory, MemoryWithJSONEncoder -def test_json_encoded_pubsub(): +def test_json_encoded_pubsub() -> None: """Test memory pubsub with JSON encoding.""" pubsub = MemoryWithJSONEncoder() received_messages = [] - def callback(message, topic): + def callback(message, topic) -> None: received_messages.append(message) # Subscribe to a topic @@ -47,16 +47,16 @@ def callback(message, topic): # Verify all messages were received and properly decoded assert len(received_messages) == len(test_messages) - for original, received in zip(test_messages, received_messages): + for original, received in zip(test_messages, received_messages, strict=False): assert original == received -def test_json_encoding_edge_cases(): +def test_json_encoding_edge_cases() -> None: """Test edge cases for JSON encoding.""" pubsub = MemoryWithJSONEncoder() received_messages = [] - def callback(message, topic): + def callback(message, topic) -> None: received_messages.append(message) pubsub.subscribe("edge_cases", callback) @@ -78,16 +78,16 @@ def callback(message, topic): assert received_messages == edge_cases -def test_multiple_subscribers_with_encoding(): +def test_multiple_subscribers_with_encoding() -> None: """Test that multiple subscribers work with encoding.""" pubsub = MemoryWithJSONEncoder() received_messages_1 = [] received_messages_2 = [] - def callback_1(message, topic): + def callback_1(message, topic) -> None: received_messages_1.append(message) - def callback_2(message, topic): + def callback_2(message, topic) -> None: received_messages_2.append(f"callback_2: {message}") pubsub.subscribe("json_topic", callback_1) @@ -123,16 +123,16 @@ def callback_2(message, topic): # assert received_messages_2 == ["only callback_2 should get this"] -def test_data_actually_encoded_in_transit(): +def test_data_actually_encoded_in_transit() -> None: """Validate that data is actually encoded in transit by intercepting raw bytes.""" # Create a spy memory that captures what actually gets published class SpyMemory(Memory): - def __init__(self): + def __init__(self) -> None: super().__init__() self.raw_messages_received = [] - def publish(self, topic: str, message): + def publish(self, topic: str, message) -> None: # Capture what actually gets published self.raw_messages_received.append((topic, message, type(message))) super().publish(topic, message) @@ -144,7 +144,7 @@ class SpyMemoryWithJSON(MemoryWithJSONEncoder, SpyMemory): pubsub = SpyMemoryWithJSON() received_decoded = [] - def callback(message, topic): + def callback(message, topic) -> None: received_decoded.append(message) pubsub.subscribe("test_topic", callback) diff --git a/dimos/protocol/pubsub/test_lcmpubsub.py b/dimos/protocol/pubsub/test_lcmpubsub.py index d8a39248bb..b089483164 100644 --- a/dimos/protocol/pubsub/test_lcmpubsub.py +++ b/dimos/protocol/pubsub/test_lcmpubsub.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import threading import time import pytest @@ -55,7 +54,7 @@ class MockLCMMessage: msg_name = "geometry_msgs.Mock" - def __init__(self, data): + def __init__(self, data) -> None: self.data = data def lcm_encode(self) -> bytes: @@ -69,7 +68,7 @@ def __eq__(self, other): return isinstance(other, MockLCMMessage) and self.data == other.data -def test_LCMPubSubBase_pubsub(lcm_pub_sub_base): +def test_LCMPubSubBase_pubsub(lcm_pub_sub_base) -> None: lcm = lcm_pub_sub_base received_messages = [] @@ -77,7 +76,7 @@ def test_LCMPubSubBase_pubsub(lcm_pub_sub_base): topic = Topic(topic="/test_topic", lcm_type=MockLCMMessage) test_message = MockLCMMessage("test_data") - def callback(msg, topic): + def callback(msg, topic) -> None: received_messages.append((msg, topic)) lcm.subscribe(topic, callback) @@ -98,13 +97,13 @@ def callback(msg, topic): assert received_topic == topic -def test_lcm_autodecoder_pubsub(lcm): +def test_lcm_autodecoder_pubsub(lcm) -> None: received_messages = [] topic = Topic(topic="/test_topic", lcm_type=MockLCMMessage) test_message = MockLCMMessage("test_data") - def callback(msg, topic): + def callback(msg, topic) -> None: received_messages.append((msg, topic)) lcm.subscribe(topic, callback) @@ -134,12 +133,12 @@ def callback(msg, topic): # passes some geometry types through LCM @pytest.mark.parametrize("test_message", test_msgs) -def test_lcm_geometry_msgs_pubsub(test_message, lcm): +def test_lcm_geometry_msgs_pubsub(test_message, lcm) -> None: received_messages = [] topic = Topic(topic="/test_topic", lcm_type=test_message.__class__) - def callback(msg, topic): + def callback(msg, topic) -> None: received_messages.append((msg, topic)) lcm.subscribe(topic, callback) @@ -165,13 +164,13 @@ def callback(msg, topic): # passes some geometry types through pickle LCM @pytest.mark.parametrize("test_message", test_msgs) -def test_lcm_geometry_msgs_autopickle_pubsub(test_message, pickle_lcm): +def test_lcm_geometry_msgs_autopickle_pubsub(test_message, pickle_lcm) -> None: lcm = pickle_lcm received_messages = [] topic = Topic(topic="/test_topic") - def callback(msg, topic): + def callback(msg, topic) -> None: received_messages.append((msg, topic)) lcm.subscribe(topic, callback) diff --git a/dimos/protocol/pubsub/test_spec.py b/dimos/protocol/pubsub/test_spec.py index 0f9486ec09..2bc8ae3ea1 100644 --- a/dimos/protocol/pubsub/test_spec.py +++ b/dimos/protocol/pubsub/test_spec.py @@ -15,10 +15,10 @@ # limitations under the License. import asyncio -import time -import traceback +from collections.abc import Callable from contextlib import contextmanager -from typing import Any, Callable, List, Tuple +import time +from typing import Any import pytest @@ -38,7 +38,7 @@ def memory_context(): # Use Any for context manager type to accommodate both Memory and Redis -testdata: List[Tuple[Callable[[], Any], Any, List[Any]]] = [ +testdata: list[tuple[Callable[[], Any], Any, list[Any]]] = [ (memory_context, "topic", ["value1", "value2", "value3"]), ] @@ -84,7 +84,7 @@ def lcm_context(): print("LCM not available") -from dimos.protocol.pubsub.shmpubsub import SharedMemory, PickleSharedMemory +from dimos.protocol.pubsub.shmpubsub import PickleSharedMemory @contextmanager @@ -105,13 +105,13 @@ def shared_memory_cpu_context(): @pytest.mark.parametrize("pubsub_context, topic, values", testdata) -def test_store(pubsub_context, topic, values): +def test_store(pubsub_context, topic, values) -> None: with pubsub_context() as x: # Create a list to capture received messages received_messages = [] # Define callback function that stores received messages - def callback(message, _): + def callback(message, _) -> None: received_messages.append(message) # Subscribe to the topic with our callback @@ -130,7 +130,7 @@ def callback(message, _): @pytest.mark.parametrize("pubsub_context, topic, values", testdata) -def test_multiple_subscribers(pubsub_context, topic, values): +def test_multiple_subscribers(pubsub_context, topic, values) -> None: """Test that multiple subscribers receive the same message.""" with pubsub_context() as x: # Create lists to capture received messages for each subscriber @@ -138,10 +138,10 @@ def test_multiple_subscribers(pubsub_context, topic, values): received_messages_2 = [] # Define callback functions - def callback_1(message, topic): + def callback_1(message, topic) -> None: received_messages_1.append(message) - def callback_2(message, topic): + def callback_2(message, topic) -> None: received_messages_2.append(message) # Subscribe both callbacks to the same topic @@ -162,14 +162,14 @@ def callback_2(message, topic): @pytest.mark.parametrize("pubsub_context, topic, values", testdata) -def test_unsubscribe(pubsub_context, topic, values): +def test_unsubscribe(pubsub_context, topic, values) -> None: """Test that unsubscribed callbacks don't receive messages.""" with pubsub_context() as x: # Create a list to capture received messages received_messages = [] # Define callback function - def callback(message, topic): + def callback(message, topic) -> None: received_messages.append(message) # Subscribe and get unsubscribe function @@ -189,14 +189,14 @@ def callback(message, topic): @pytest.mark.parametrize("pubsub_context, topic, values", testdata) -def test_multiple_messages(pubsub_context, topic, values): +def test_multiple_messages(pubsub_context, topic, values) -> None: """Test that subscribers receive multiple messages in order.""" with pubsub_context() as x: # Create a list to capture received messages received_messages = [] # Define callback function - def callback(message, topic): + def callback(message, topic) -> None: received_messages.append(message) # Subscribe to the topic @@ -217,7 +217,7 @@ def callback(message, topic): @pytest.mark.parametrize("pubsub_context, topic, values", testdata) @pytest.mark.asyncio -async def test_async_iterator(pubsub_context, topic, values): +async def test_async_iterator(pubsub_context, topic, values) -> None: """Test that async iterator receives messages correctly.""" with pubsub_context() as x: # Get the messages to send (using the rest of the values) @@ -228,7 +228,7 @@ async def test_async_iterator(pubsub_context, topic, values): async_iter = x.aiter(topic) # Create a task to consume messages from the async iterator - async def consume_messages(): + async def consume_messages() -> None: try: async for message in async_iter: received_messages.append(message) diff --git a/dimos/protocol/rpc/off_test_pubsubrpc.py b/dimos/protocol/rpc/off_test_pubsubrpc.py index 33d149ee11..940baad2f7 100644 --- a/dimos/protocol/rpc/off_test_pubsubrpc.py +++ b/dimos/protocol/rpc/off_test_pubsubrpc.py @@ -12,19 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import asyncio -import time +from collections.abc import Callable from contextlib import contextmanager -from typing import Any, Callable, List, Tuple +import time import pytest -from dimos.core import Module, rpc, start, stop +from dimos.core import Module, rpc, start from dimos.protocol.rpc.lcmrpc import LCMRPC -from dimos.protocol.rpc.spec import RPCClient, RPCServer from dimos.protocol.service.lcmservice import autoconf -testgrid: List[Callable] = [] +testgrid: list[Callable] = [] # test module we'll use for binding RPC methods @@ -84,7 +82,7 @@ def redis_rpc_context(): @pytest.mark.parametrize("rpc_context", testgrid) -def test_basics(rpc_context): +def test_basics(rpc_context) -> None: with rpc_context() as (server, client): def remote_function(a: int, b: int): @@ -99,7 +97,7 @@ def remote_function(a: int, b: int): msgs = [] - def receive_msg(response): + def receive_msg(response) -> None: msgs.append(response) print(f"Received response: {response}") @@ -110,7 +108,7 @@ def receive_msg(response): @pytest.mark.parametrize("rpc_context", testgrid) -def test_module_autobind(rpc_context): +def test_module_autobind(rpc_context) -> None: with rpc_context() as (server, client): module = MyModule() print("\n") @@ -132,7 +130,7 @@ def test_module_autobind(rpc_context): msgs = [] - def receive_msg(msg): + def receive_msg(msg) -> None: msgs.append(msg) client.call("MyModule/add", ([1, 2], {}), receive_msg) @@ -148,7 +146,7 @@ def receive_msg(msg): # # can do blocking calls @pytest.mark.parametrize("rpc_context", testgrid) -def test_sync(rpc_context): +def test_sync(rpc_context) -> None: with rpc_context() as (server, client): module = MyModule() print("\n") @@ -162,7 +160,7 @@ def test_sync(rpc_context): # # can do blocking calls @pytest.mark.parametrize("rpc_context", testgrid) -def test_kwargs(rpc_context): +def test_kwargs(rpc_context) -> None: with rpc_context() as (server, client): module = MyModule() print("\n") @@ -175,7 +173,7 @@ def test_kwargs(rpc_context): # or async calls as well @pytest.mark.parametrize("rpc_context", testgrid) @pytest.mark.asyncio -async def test_async(rpc_context): +async def test_async(rpc_context) -> None: with rpc_context() as (server, client): module = MyModule() print("\n") @@ -185,14 +183,14 @@ async def test_async(rpc_context): # or async calls as well @pytest.mark.module -def test_rpc_full_deploy(): +def test_rpc_full_deploy() -> None: autoconf() # test module we'll use for binding RPC methods class CallerModule(Module): remote: Callable[[int, int], int] - def __init__(self, remote: Callable[[int, int], int]): + def __init__(self, remote: Callable[[int, int], int]) -> None: self.remote = remote super().__init__() diff --git a/dimos/protocol/rpc/pubsubrpc.py b/dimos/protocol/rpc/pubsubrpc.py index ef4fb25aa4..033cb7a5e2 100644 --- a/dimos/protocol/rpc/pubsubrpc.py +++ b/dimos/protocol/rpc/pubsubrpc.py @@ -14,14 +14,13 @@ from __future__ import annotations -import time from abc import abstractmethod -from types import FunctionType +from collections.abc import Callable +import time from typing import ( + TYPE_CHECKING, Any, - Callable, Generic, - Optional, TypedDict, TypeVar, ) @@ -30,6 +29,8 @@ from dimos.protocol.rpc.spec import Args, RPCSpec from dimos.utils.logging_config import setup_logger +if TYPE_CHECKING: + from types import FunctionType logger = setup_logger(__file__) @@ -68,7 +69,7 @@ def _encodeRPCReq(self, res: RPCReq) -> MsgT: ... @abstractmethod def _encodeRPCRes(self, res: RPCRes) -> MsgT: ... - def call(self, name: str, arguments: Args, cb: Optional[Callable]): + def call(self, name: str, arguments: Args, cb: Callable | None): if cb is None: return self.call_nowait(name, arguments) @@ -81,7 +82,7 @@ def call_cb(self, name: str, arguments: Args, cb: Callable) -> Any: req: RPCReq = {"name": name, "args": arguments, "id": msg_id} - def receive_response(msg: MsgT, _: TopicT): + def receive_response(msg: MsgT, _: TopicT) -> None: res = self._decodeRPCRes(msg) if res.get("id") != msg_id: return @@ -100,7 +101,7 @@ def call_nowait(self, name: str, arguments: Args) -> None: req: RPCReq = {"name": name, "args": arguments, "id": None} self.publish(topic_req, self._encodeRPCReq(req)) - def serve_rpc(self, f: FunctionType, name: Optional[str] = None): + def serve_rpc(self, f: FunctionType, name: str | None = None): if not name: name = f.__name__ @@ -118,7 +119,7 @@ def receive_call(msg: MsgT, _: TopicT) -> None: # Execute RPC handler in a separate thread to avoid deadlock when # the handler makes nested RPC calls. - def execute_and_respond(): + def execute_and_respond() -> None: try: response = f(*args[0], **args[1]) req_id = req.get("id") diff --git a/dimos/protocol/rpc/spec.py b/dimos/protocol/rpc/spec.py index 461d60f8ae..283b84f1dd 100644 --- a/dimos/protocol/rpc/spec.py +++ b/dimos/protocol/rpc/spec.py @@ -13,15 +13,15 @@ # limitations under the License. import asyncio +from collections.abc import Callable import threading -import time -from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, overload +from typing import Any, Protocol, overload class Empty: ... -Args = Tuple[List, Dict[str, Any]] +Args = tuple[list, dict[str, Any]] # module that we can inspect for RPCs @@ -39,18 +39,16 @@ def call(self, name: str, arguments: Args, cb: None) -> None: ... @overload def call(self, name: str, arguments: Args, cb: Callable[[Any], None]) -> Callable[[], Any]: ... - def call( - self, name: str, arguments: Args, cb: Optional[Callable] - ) -> Optional[Callable[[], Any]]: ... + def call(self, name: str, arguments: Args, cb: Callable | None) -> Callable[[], Any] | None: ... # we expect to crash if we don't get a return value after 10 seconds # but callers can override this timeout for extra long functions def call_sync( - self, name: str, arguments: Args, rpc_timeout: Optional[float] = 30.0 - ) -> Tuple[Any, Callable[[], None]]: + self, name: str, arguments: Args, rpc_timeout: float | None = 30.0 + ) -> tuple[Any, Callable[[], None]]: event = threading.Event() - def receive_value(val): + def receive_value(val) -> None: event.result = val # attach to event event.set() @@ -63,7 +61,7 @@ async def call_async(self, name: str, arguments: Args) -> Any: loop = asyncio.get_event_loop() future = loop.create_future() - def receive_value(val): + def receive_value(val) -> None: try: loop.call_soon_threadsafe(future.set_result, val) except Exception as e: @@ -77,7 +75,7 @@ def receive_value(val): class RPCServer(Protocol): def serve_rpc(self, f: Callable, name: str) -> Callable[[], None]: ... - def serve_module_rpc(self, module: RPCInspectable, name: Optional[str] = None): + def serve_module_rpc(self, module: RPCInspectable, name: str | None = None) -> None: for fname in module.rpcs.keys(): if not name: name = module.__class__.__name__ @@ -86,7 +84,7 @@ def override_f(*args, fname=fname, **kwargs): return getattr(module, fname)(*args, **kwargs) topic = name + "/" + fname - unsub_fn = self.serve_rpc(override_f, topic) + self.serve_rpc(override_f, topic) class RPCSpec(RPCServer, RPCClient): ... diff --git a/dimos/protocol/rpc/test_lcmrpc.py b/dimos/protocol/rpc/test_lcmrpc.py index 02fe0a2d3a..6ee00b23e0 100644 --- a/dimos/protocol/rpc/test_lcmrpc.py +++ b/dimos/protocol/rpc/test_lcmrpc.py @@ -12,8 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest from collections.abc import Generator + +import pytest + from dimos.constants import LCM_MAX_CHANNEL_NAME_LENGTH from dimos.protocol.rpc.lcmrpc import LCMRPC diff --git a/dimos/protocol/rpc/test_lcmrpc_timeout.py b/dimos/protocol/rpc/test_lcmrpc_timeout.py index 88b5436269..74cf4963c7 100644 --- a/dimos/protocol/rpc/test_lcmrpc_timeout.py +++ b/dimos/protocol/rpc/test_lcmrpc_timeout.py @@ -50,7 +50,7 @@ def lcm_client(): client.stop() -def test_lcmrpc_timeout_no_reply(lcm_server, lcm_client): +def test_lcmrpc_timeout_no_reply(lcm_server, lcm_client) -> None: """Test that RPC calls timeout when no reply is received""" server = lcm_server client = lcm_client @@ -84,7 +84,7 @@ def never_responds(a: int, b: int): assert function_called.wait(0.5), "Server function was never called" -def test_lcmrpc_timeout_nonexistent_service(lcm_client): +def test_lcmrpc_timeout_nonexistent_service(lcm_client) -> None: """Test that RPC calls timeout when calling a non-existent service""" client = lcm_client @@ -103,7 +103,7 @@ def test_lcmrpc_timeout_nonexistent_service(lcm_client): assert elapsed < 0.3, f"Timeout took too long: {elapsed}s" -def test_lcmrpc_callback_with_timeout(lcm_server, lcm_client): +def test_lcmrpc_callback_with_timeout(lcm_server, lcm_client) -> None: """Test that callback-based RPC calls handle timeouts properly""" server = lcm_server client = lcm_client @@ -121,7 +121,7 @@ def never_responds(a: int, b: int): callback_called = threading.Event() received_value = [] - def callback(value): + def callback(value) -> None: received_value.append(value) callback_called.set() @@ -141,7 +141,7 @@ def callback(value): unsub() -def test_lcmrpc_normal_operation(lcm_server, lcm_client): +def test_lcmrpc_normal_operation(lcm_server, lcm_client) -> None: """Sanity check that normal RPC calls still work""" server = lcm_server client = lcm_client diff --git a/dimos/protocol/service/lcmservice.py b/dimos/protocol/service/lcmservice.py index f1cabbba3d..1b19a5cfeb 100644 --- a/dimos/protocol/service/lcmservice.py +++ b/dimos/protocol/service/lcmservice.py @@ -15,14 +15,14 @@ from __future__ import annotations from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass +from functools import cache import os import subprocess import sys import threading import traceback -from dataclasses import dataclass -from functools import cache -from typing import Optional, Protocol, runtime_checkable +from typing import Protocol, runtime_checkable import lcm @@ -69,7 +69,7 @@ def check_multicast() -> list[str]: return commands_needed -def check_buffers() -> tuple[list[str], Optional[int]]: +def check_buffers() -> tuple[list[str], int | None]: """Check if buffer configuration is needed and return required commands and current size. Returns: @@ -192,7 +192,7 @@ class LCMConfig: ttl: int = 0 url: str | None = None autoconf: bool = True - lcm: Optional[lcm.LCM] = None + lcm: lcm.LCM | None = None @runtime_checkable @@ -200,7 +200,7 @@ class LCMMsg(Protocol): msg_name: str @classmethod - def lcm_decode(cls, data: bytes) -> "LCMMsg": + def lcm_decode(cls, data: bytes) -> LCMMsg: """Decode bytes into an LCM message instance.""" ... @@ -212,7 +212,7 @@ def lcm_encode(self) -> bytes: @dataclass class Topic: topic: str = "" - lcm_type: Optional[type[LCMMsg]] = None + lcm_type: type[LCMMsg] | None = None def __str__(self) -> str: if self.lcm_type is None: @@ -222,10 +222,10 @@ def __str__(self) -> str: class LCMService(Service[LCMConfig]): default_config = LCMConfig - l: Optional[lcm.LCM] + l: lcm.LCM | None _stop_event: threading.Event _l_lock: threading.Lock - _thread: Optional[threading.Thread] + _thread: threading.Thread | None _call_thread_pool: ThreadPoolExecutor | None = None _call_thread_pool_lock: threading.RLock = threading.RLock() @@ -256,7 +256,7 @@ def __getstate__(self): state.pop("_call_thread_pool_lock", None) return state - def __setstate__(self, state): + def __setstate__(self, state) -> None: """Restore object from pickled state.""" self.__dict__.update(state) # Reinitialize runtime attributes @@ -267,7 +267,7 @@ def __setstate__(self, state): self._call_thread_pool = None self._call_thread_pool_lock = threading.RLock() - def start(self): + def start(self) -> None: # Reinitialize LCM if it's None (e.g., after unpickling) if self.l is None: if self.config.lcm: @@ -300,7 +300,7 @@ def _lcm_loop(self) -> None: stack_trace = traceback.format_exc() print(f"Error in LCM handling: {e}\n{stack_trace}") - def stop(self): + def stop(self) -> None: """Stop the LCM loop.""" self._stop_event.set() if self._thread is not None: diff --git a/dimos/protocol/service/spec.py b/dimos/protocol/service/spec.py index 5406e2151f..d55c1bfacf 100644 --- a/dimos/protocol/service/spec.py +++ b/dimos/protocol/service/spec.py @@ -13,14 +13,14 @@ # limitations under the License. from abc import ABC -from typing import Generic, Type, TypeVar +from typing import Generic, TypeVar # Generic type for service configuration ConfigT = TypeVar("ConfigT") class Configurable(Generic[ConfigT]): - default_config: Type[ConfigT] + default_config: type[ConfigT] def __init__(self, **kwargs) -> None: self.config: ConfigT = self.default_config(**kwargs) diff --git a/dimos/protocol/service/test_lcmservice.py b/dimos/protocol/service/test_lcmservice.py index 7065029b91..619513d46f 100644 --- a/dimos/protocol/service/test_lcmservice.py +++ b/dimos/protocol/service/test_lcmservice.py @@ -14,12 +14,10 @@ import os import subprocess -import time from unittest.mock import patch import pytest -from dimos.msgs.geometry_msgs import Pose, Quaternion, Vector3 from dimos.protocol.service.lcmservice import ( autoconf, check_buffers, @@ -33,7 +31,7 @@ def get_sudo_prefix() -> str: return "" if check_root() else "sudo " -def test_check_multicast_all_configured(): +def test_check_multicast_all_configured() -> None: """Test check_multicast when system is properly configured.""" with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: # Mock successful checks with realistic output format @@ -53,7 +51,7 @@ def test_check_multicast_all_configured(): assert result == [] -def test_check_multicast_missing_multicast_flag(): +def test_check_multicast_missing_multicast_flag() -> None: """Test check_multicast when loopback interface lacks multicast.""" with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: # Mock interface without MULTICAST flag (realistic current system state) @@ -74,7 +72,7 @@ def test_check_multicast_missing_multicast_flag(): assert result == [f"{sudo}ifconfig lo multicast"] -def test_check_multicast_missing_route(): +def test_check_multicast_missing_route() -> None: """Test check_multicast when multicast route is missing.""" with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: # Mock missing route - interface has multicast but no route @@ -95,7 +93,7 @@ def test_check_multicast_missing_route(): assert result == [f"{sudo}route add -net 224.0.0.0 netmask 240.0.0.0 dev lo"] -def test_check_multicast_all_missing(): +def test_check_multicast_all_missing() -> None: """Test check_multicast when both multicast flag and route are missing (current system state).""" with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: # Mock both missing - matches actual current system state @@ -120,7 +118,7 @@ def test_check_multicast_all_missing(): assert result == expected -def test_check_multicast_subprocess_exception(): +def test_check_multicast_subprocess_exception() -> None: """Test check_multicast when subprocess calls fail.""" with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: # Mock subprocess exceptions @@ -135,7 +133,7 @@ def test_check_multicast_subprocess_exception(): assert result == expected -def test_check_buffers_all_configured(): +def test_check_buffers_all_configured() -> None: """Test check_buffers when system is properly configured.""" with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: # Mock sufficient buffer sizes @@ -151,7 +149,7 @@ def test_check_buffers_all_configured(): assert buffer_size == 2097152 -def test_check_buffers_low_max_buffer(): +def test_check_buffers_low_max_buffer() -> None: """Test check_buffers when rmem_max is too low.""" with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: # Mock low rmem_max @@ -168,7 +166,7 @@ def test_check_buffers_low_max_buffer(): assert buffer_size == 1048576 -def test_check_buffers_low_default_buffer(): +def test_check_buffers_low_default_buffer() -> None: """Test check_buffers when rmem_default is too low.""" with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: # Mock low rmem_default @@ -185,7 +183,7 @@ def test_check_buffers_low_default_buffer(): assert buffer_size == 2097152 -def test_check_buffers_both_low(): +def test_check_buffers_both_low() -> None: """Test check_buffers when both buffer sizes are too low.""" with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: # Mock both low @@ -206,7 +204,7 @@ def test_check_buffers_both_low(): assert buffer_size == 1048576 -def test_check_buffers_subprocess_exception(): +def test_check_buffers_subprocess_exception() -> None: """Test check_buffers when subprocess calls fail.""" with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: # Mock subprocess exceptions @@ -222,7 +220,7 @@ def test_check_buffers_subprocess_exception(): assert buffer_size is None -def test_check_buffers_parsing_error(): +def test_check_buffers_parsing_error() -> None: """Test check_buffers when output parsing fails.""" with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: # Mock malformed output @@ -241,7 +239,7 @@ def test_check_buffers_parsing_error(): assert buffer_size is None -def test_check_buffers_dev_container(): +def test_check_buffers_dev_container() -> None: """Test check_buffers in dev container where sysctl fails.""" with patch("dimos.protocol.pubsub.lcmpubsub.subprocess.run") as mock_run: # Mock dev container behavior - sysctl returns non-zero @@ -274,7 +272,7 @@ def test_check_buffers_dev_container(): assert buffer_size is None -def test_autoconf_no_config_needed(): +def test_autoconf_no_config_needed() -> None: """Test autoconf when no configuration is needed.""" # Clear CI environment variable for this test with patch.dict(os.environ, {"CI": ""}, clear=False): @@ -310,7 +308,7 @@ def test_autoconf_no_config_needed(): mock_logger.warning.assert_not_called() -def test_autoconf_with_config_needed_success(): +def test_autoconf_with_config_needed_success() -> None: """Test autoconf when configuration is needed and commands succeed.""" # Clear CI environment variable for this test with patch.dict(os.environ, {"CI": ""}, clear=False): @@ -365,7 +363,7 @@ def test_autoconf_with_config_needed_success(): mock_logger.info.assert_has_calls(expected_info_calls) -def test_autoconf_with_command_failures(): +def test_autoconf_with_command_failures() -> None: """Test autoconf when some commands fail.""" # Clear CI environment variable for this test with patch.dict(os.environ, {"CI": ""}, clear=False): @@ -392,8 +390,17 @@ def test_autoconf_with_command_failures(): )(), # ifconfig lo multicast subprocess.CalledProcessError( 1, - get_sudo_prefix().split() - + ["route", "add", "-net", "224.0.0.0", "netmask", "240.0.0.0", "dev", "lo"], + [ + *get_sudo_prefix().split(), + "route", + "add", + "-net", + "224.0.0.0", + "netmask", + "240.0.0.0", + "dev", + "lo", + ], "Permission denied", "Operation not permitted", ), diff --git a/dimos/protocol/service/test_spec.py b/dimos/protocol/service/test_spec.py index 0706af5112..9842f9c49f 100644 --- a/dimos/protocol/service/test_spec.py +++ b/dimos/protocol/service/test_spec.py @@ -16,8 +16,6 @@ from dataclasses import dataclass -from typing_extensions import TypedDict - from dimos.protocol.service.spec import Service @@ -38,7 +36,7 @@ def start(self) -> None: ... def stop(self) -> None: ... -def test_default_configuration(): +def test_default_configuration() -> None: """Test that default configuration is applied correctly.""" service = DatabaseService() @@ -51,7 +49,7 @@ def test_default_configuration(): assert service.config.ssl_enabled is False -def test_partial_configuration_override(): +def test_partial_configuration_override() -> None: """Test that partial configuration correctly overrides defaults.""" service = DatabaseService(host="production-db", port=3306, ssl_enabled=True) @@ -66,7 +64,7 @@ def test_partial_configuration_override(): assert service.config.max_connections == 10 -def test_complete_configuration_override(): +def test_complete_configuration_override() -> None: """Test that all configuration values can be overridden.""" service = DatabaseService( host="custom-host", @@ -86,7 +84,7 @@ def test_complete_configuration_override(): assert service.config.ssl_enabled is True -def test_service_subclassing(): +def test_service_subclassing() -> None: @dataclass class ExtraConfig(DatabaseConfig): extra_param: str = "default_value" @@ -94,7 +92,7 @@ class ExtraConfig(DatabaseConfig): class ExtraDatabaseService(DatabaseService): default_config = ExtraConfig - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) bla = ExtraDatabaseService(host="custom-host2", extra_param="extra_value") diff --git a/dimos/protocol/skill/comms.py b/dimos/protocol/skill/comms.py index 09273c36c0..b0adecf5c5 100644 --- a/dimos/protocol/skill/comms.py +++ b/dimos/protocol/skill/comms.py @@ -15,13 +15,17 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Callable, Generic, Optional, TypeVar, Union +from typing import TYPE_CHECKING, Generic, TypeVar from dimos.protocol.pubsub.lcmpubsub import PickleLCM -from dimos.protocol.pubsub.spec import PubSub from dimos.protocol.service import Service from dimos.protocol.skill.type import SkillMsg +if TYPE_CHECKING: + from collections.abc import Callable + + from dimos.protocol.pubsub.spec import PubSub + # defines a protocol for communication between skills and agents # it has simple requirements of pub/sub semantics capable of sending and receiving SkillMsg objects @@ -46,8 +50,8 @@ def stop(self) -> None: ... @dataclass class PubSubCommsConfig(Generic[TopicT, MsgT]): - topic: Optional[TopicT] = None - pubsub: Union[type[PubSub[TopicT, MsgT]], PubSub[TopicT, MsgT], None] = None + topic: TopicT | None = None + pubsub: type[PubSub[TopicT, MsgT]] | PubSub[TopicT, MsgT] | None = None autostart: bool = True @@ -72,7 +76,7 @@ def __init__(self, **kwargs) -> None: def start(self) -> None: self.pubsub.start() - def stop(self): + def stop(self) -> None: self.pubsub.stop() def publish(self, msg: SkillMsg) -> None: @@ -85,7 +89,7 @@ def subscribe(self, cb: Callable[[SkillMsg], None]) -> None: @dataclass class LCMCommsConfig(PubSubCommsConfig[str, SkillMsg]): topic: str = "/skill" - pubsub: Union[type[PubSub], PubSub, None] = PickleLCM + pubsub: type[PubSub] | PubSub | None = PickleLCM # lcm needs to be started only if receiving # skill comms are broadcast only in modules so we don't autostart autostart: bool = False diff --git a/dimos/protocol/skill/coordinator.py b/dimos/protocol/skill/coordinator.py index e9c8680864..a672ceacee 100644 --- a/dimos/protocol/skill/coordinator.py +++ b/dimos/protocol/skill/coordinator.py @@ -13,13 +13,13 @@ # limitations under the License. import asyncio -import json -import threading -import time from copy import copy from dataclasses import dataclass from enum import Enum -from typing import Any, List, Literal, Optional, Union +import json +import threading +import time +from typing import Any, Literal from langchain_core.messages import ToolMessage from langchain_core.tools import tool as langchain_tool @@ -28,14 +28,12 @@ from rich.text import Text from dimos.core import rpc -from dimos.core.module import get_loop +from dimos.core.module import Module, get_loop from dimos.protocol.skill.comms import LCMSkillComms, SkillCommsSpec from dimos.protocol.skill.skill import SkillConfig, SkillContainer from dimos.protocol.skill.type import MsgType, Output, Reducer, Return, SkillMsg, Stream from dimos.protocol.skill.utils import interpret_tool_call_args from dimos.utils.logging_config import setup_logger -from dimos.core.module import Module - logger = setup_logger(__file__) @@ -76,9 +74,9 @@ class SkillState: end_msg: SkillMsg[Literal[MsgType.ret]] = None error_msg: SkillMsg[Literal[MsgType.error]] = None ret_msg: SkillMsg[Literal[MsgType.ret]] = None - reduced_stream_msg: List[SkillMsg[Literal[MsgType.reduced_stream]]] = None + reduced_stream_msg: list[SkillMsg[Literal[MsgType.reduced_stream]]] = None - def __init__(self, call_id: str, name: str, skill_config: Optional[SkillConfig] = None) -> None: + def __init__(self, call_id: str, name: str, skill_config: SkillConfig | None = None) -> None: super().__init__() self.skill_config = skill_config or SkillConfig( @@ -120,7 +118,7 @@ def content(self) -> dict[str, Any] | str | int | float | None: else: return self.error_msg.content - def agent_encode(self) -> Union[ToolMessage, str]: + def agent_encode(self) -> ToolMessage | str: # tool call can emit a single ToolMessage # subsequent messages are considered SituationalAwarenessMessages, # those are collapsed into a HumanMessage, that's artificially prepended to history @@ -249,7 +247,7 @@ def table(self) -> Table: states_table.add_row("", "[dim]No active skills[/dim]", "", "", "") return states_table - def __str__(self): + def __str__(self) -> str: console = Console(force_terminal=True, legacy_windows=False) # Render to string with title above @@ -272,10 +270,10 @@ class SkillCoordinator(Module): _dynamic_containers: list[SkillContainer] _skill_state: SkillStateDict # key is call_id, not skill_name _skills: dict[str, SkillConfig] - _updates_available: Optional[asyncio.Event] - _loop: Optional[asyncio.AbstractEventLoop] - _loop_thread: Optional[threading.Thread] - _agent_loop: Optional[asyncio.AbstractEventLoop] + _updates_available: asyncio.Event | None + _loop: asyncio.AbstractEventLoop | None + _loop_thread: threading.Thread | None + _agent_loop: asyncio.AbstractEventLoop | None def __init__(self) -> None: # TODO: Why isn't this super().__init__() ? @@ -357,7 +355,7 @@ def get_tools(self) -> list[dict]: # internal skill call def call_skill( - self, call_id: Union[str | Literal[False]], skill_name: str, args: dict[str, Any] + self, call_id: str | Literal[False], skill_name: str, args: dict[str, Any] ) -> None: if not call_id: call_id = str(time.time()) @@ -413,7 +411,7 @@ def handle_message(self, msg: SkillMsg) -> None: if should_notify: updates_available = self._ensure_updates_available() if updates_available is None: - print(f"[DEBUG] Event not created yet, deferring notification") + print("[DEBUG] Event not created yet, deferring notification") return try: @@ -462,7 +460,7 @@ def has_passive_skills(self) -> bool: return False return True - async def wait_for_updates(self, timeout: Optional[float] = None) -> True: + async def wait_for_updates(self, timeout: float | None = None) -> True: """Wait for skill updates to become available. This method should be called by the agent when it's ready to receive updates. @@ -503,17 +501,17 @@ async def wait_for_updates(self, timeout: Optional[float] = None) -> True: # print(f"[DEBUG] Waiting for event with timeout {timeout}") await asyncio.wait_for(updates_available.wait(), timeout=timeout) else: - print(f"[DEBUG] Waiting for event without timeout") + print("[DEBUG] Waiting for event without timeout") await updates_available.wait() - print(f"[DEBUG] Event was set! Returning True") + print("[DEBUG] Event was set! Returning True") return True except asyncio.TimeoutError: - print(f"[DEBUG] Timeout occurred while waiting for event") + print("[DEBUG] Timeout occurred while waiting for event") return False except RuntimeError as e: if "bound to a different event loop" in str(e): print( - f"[DEBUG] Event loop binding error detected, recreating event and returning False to retry" + "[DEBUG] Event loop binding error detected, recreating event and returning False to retry" ) # Recreate the event in the current loop current_loop = asyncio.get_running_loop() @@ -570,7 +568,7 @@ def generate_snapshot(self, clear: bool = True) -> SkillStateDict: return ret - def __str__(self): + def __str__(self) -> str: console = Console(force_terminal=True, legacy_windows=False) # Create main table without any header @@ -614,7 +612,7 @@ def __str__(self): # # Dynamic containers will be queried at runtime via # .skills() method - def register_skills(self, container: SkillContainer): + def register_skills(self, container: SkillContainer) -> None: self.empty = False if not container.dynamic_skills(): logger.info(f"Registering static skill container, {container}") @@ -625,7 +623,7 @@ def register_skills(self, container: SkillContainer): logger.info(f"Registering dynamic skill container, {container}") self._dynamic_containers.append(container) - def get_skill_config(self, skill_name: str) -> Optional[SkillConfig]: + def get_skill_config(self, skill_name: str) -> SkillConfig | None: skill_config = self._skills.get(skill_name) if not skill_config: skill_config = self.skills().get(skill_name) diff --git a/dimos/protocol/skill/schema.py b/dimos/protocol/skill/schema.py index 37a6e6fac1..49dc1caa37 100644 --- a/dimos/protocol/skill/schema.py +++ b/dimos/protocol/skill/schema.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect -from typing import Dict, List, Union, get_args, get_origin +from typing import Union, get_args, get_origin def python_type_to_json_schema(python_type) -> dict: @@ -37,14 +37,14 @@ def python_type_to_json_schema(python_type) -> dict: return {"anyOf": [python_type_to_json_schema(arg) for arg in args]} # Handle List/list types - if origin in (list, List): + if origin in (list, list): args = get_args(python_type) if args: return {"type": "array", "items": python_type_to_json_schema(args[0])} return {"type": "array"} # Handle Dict/dict types - if origin in (dict, Dict): + if origin in (dict, dict): return {"type": "object"} # Handle basic types @@ -65,7 +65,7 @@ def function_to_schema(func) -> dict: try: signature = inspect.signature(func) except ValueError as e: - raise ValueError(f"Failed to get signature for function {func.__name__}: {str(e)}") + raise ValueError(f"Failed to get signature for function {func.__name__}: {e!s}") properties = {} required = [] diff --git a/dimos/protocol/skill/skill.py b/dimos/protocol/skill/skill.py index 5008232554..7ad260eaa5 100644 --- a/dimos/protocol/skill/skill.py +++ b/dimos/protocol/skill/skill.py @@ -13,10 +13,10 @@ # limitations under the License. import asyncio -import threading +from collections.abc import Callable from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass -from typing import Any, Callable, Optional +from typing import Any # from dimos.core.core import rpc from dimos.protocol.skill.comms import LCMSkillComms, SkillCommsSpec @@ -148,18 +148,18 @@ def wrapper(self, *args, **kwargs): class SkillContainer: skill_transport_class: type[SkillCommsSpec] = LCMSkillComms - _skill_thread_pool: Optional[ThreadPoolExecutor] = None - _skill_transport: Optional[SkillCommsSpec] = None + _skill_thread_pool: ThreadPoolExecutor | None = None + _skill_transport: SkillCommsSpec | None = None @rpc - def dynamic_skills(self): + def dynamic_skills(self) -> bool: return False def __str__(self) -> str: return f"SkillContainer({self.__class__.__name__})" @rpc - def stop(self): + def stop(self) -> None: if self._skill_transport: self._skill_transport.stop() self._skill_transport = None diff --git a/dimos/protocol/skill/test_coordinator.py b/dimos/protocol/skill/test_coordinator.py index 65b45c50fa..e8d8c45a0c 100644 --- a/dimos/protocol/skill/test_coordinator.py +++ b/dimos/protocol/skill/test_coordinator.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio +from collections.abc import Generator import datetime import time -from typing import Generator, Optional import pytest @@ -28,11 +28,11 @@ class SkillContainerTest(Module): @rpc - def start(self): + def start(self) -> None: super().start() @rpc - def stop(self): + def stop(self) -> None: super().stop() @skill() @@ -48,7 +48,7 @@ def delayadd(self, x: int, y: int) -> int: return x + y @skill(stream=Stream.call_agent, reducer=Reducer.all) - def counter(self, count_to: int, delay: Optional[float] = 0.05) -> Generator[int, None, None]: + def counter(self, count_to: int, delay: float | None = 0.05) -> Generator[int, None, None]: """Counts from 1 to count_to, with an optional delay between counts.""" for i in range(1, count_to + 1): if delay > 0: @@ -57,7 +57,7 @@ def counter(self, count_to: int, delay: Optional[float] = 0.05) -> Generator[int @skill(stream=Stream.passive, reducer=Reducer.sum) def counter_passive_sum( - self, count_to: int, delay: Optional[float] = 0.05 + self, count_to: int, delay: float | None = 0.05 ) -> Generator[int, None, None]: """Counts from 1 to count_to, with an optional delay between counts.""" for i in range(1, count_to + 1): @@ -66,14 +66,14 @@ def counter_passive_sum( yield i @skill(stream=Stream.passive, reducer=Reducer.latest) - def current_time(self, frequency: Optional[float] = 10) -> Generator[str, None, None]: + def current_time(self, frequency: float | None = 10) -> Generator[str, None, None]: """Provides current time.""" while True: yield str(datetime.datetime.now()) time.sleep(1 / frequency) @skill(stream=Stream.passive, reducer=Reducer.latest) - def uptime_seconds(self, frequency: Optional[float] = 10) -> Generator[float, None, None]: + def uptime_seconds(self, frequency: float | None = 10) -> Generator[float, None, None]: """Provides current uptime.""" start_time = datetime.datetime.now() while True: @@ -81,7 +81,7 @@ def uptime_seconds(self, frequency: Optional[float] = 10) -> Generator[float, No time.sleep(1 / frequency) @skill() - def current_date(self, frequency: Optional[float] = 10) -> str: + def current_date(self, frequency: float | None = 10) -> str: """Provides current date.""" return datetime.datetime.now() @@ -95,7 +95,7 @@ def take_photo(self) -> str: @pytest.mark.asyncio -async def test_coordinator_parallel_calls(): +async def test_coordinator_parallel_calls() -> None: skillCoordinator = SkillCoordinator() skillCoordinator.register_skills(SkillContainerTest()) @@ -133,7 +133,7 @@ async def test_coordinator_parallel_calls(): @pytest.mark.asyncio -async def test_coordinator_generator(): +async def test_coordinator_generator() -> None: container = SkillContainerTest() skillCoordinator = SkillCoordinator() skillCoordinator.register_skills(container) diff --git a/dimos/protocol/skill/test_utils.py b/dimos/protocol/skill/test_utils.py index 57c16579f5..db332357fe 100644 --- a/dimos/protocol/skill/test_utils.py +++ b/dimos/protocol/skill/test_utils.py @@ -15,73 +15,73 @@ from dimos.protocol.skill.utils import interpret_tool_call_args -def test_list(): +def test_list() -> None: args, kwargs = interpret_tool_call_args([1, 2, 3]) assert args == [1, 2, 3] assert kwargs == {} -def test_none(): +def test_none() -> None: args, kwargs = interpret_tool_call_args(None) assert args == [] assert kwargs == {} -def test_none_nested(): +def test_none_nested() -> None: args, kwargs = interpret_tool_call_args({"args": None}) assert args == [] assert kwargs == {} -def test_non_dict(): +def test_non_dict() -> None: args, kwargs = interpret_tool_call_args("test") assert args == ["test"] assert kwargs == {} -def test_dict_with_args_and_kwargs(): +def test_dict_with_args_and_kwargs() -> None: args, kwargs = interpret_tool_call_args({"args": [1, 2], "kwargs": {"key": "value"}}) assert args == [1, 2] assert kwargs == {"key": "value"} -def test_dict_with_only_kwargs(): +def test_dict_with_only_kwargs() -> None: args, kwargs = interpret_tool_call_args({"kwargs": {"a": 1, "b": 2}}) assert args == [] assert kwargs == {"a": 1, "b": 2} -def test_dict_as_kwargs(): +def test_dict_as_kwargs() -> None: args, kwargs = interpret_tool_call_args({"x": 10, "y": 20}) assert args == [] assert kwargs == {"x": 10, "y": 20} -def test_dict_with_only_args_first_pass(): +def test_dict_with_only_args_first_pass() -> None: args, kwargs = interpret_tool_call_args({"args": [5, 6, 7]}) assert args == [5, 6, 7] assert kwargs == {} -def test_dict_with_only_args_nested(): +def test_dict_with_only_args_nested() -> None: args, kwargs = interpret_tool_call_args({"args": {"inner": "value"}}) assert args == [] assert kwargs == {"inner": "value"} -def test_empty_list(): +def test_empty_list() -> None: args, kwargs = interpret_tool_call_args([]) assert args == [] assert kwargs == {} -def test_empty_dict(): +def test_empty_dict() -> None: args, kwargs = interpret_tool_call_args({}) assert args == [] assert kwargs == {} -def test_integer(): +def test_integer() -> None: args, kwargs = interpret_tool_call_args(42) assert args == [42] assert kwargs == {} diff --git a/dimos/protocol/skill/type.py b/dimos/protocol/skill/type.py index 7ffbe13798..9b1c4ce5f5 100644 --- a/dimos/protocol/skill/type.py +++ b/dimos/protocol/skill/type.py @@ -13,11 +13,11 @@ # limitations under the License. from __future__ import annotations -import time -import os +from collections.abc import Callable from dataclasses import dataclass from enum import Enum -from typing import Any, Callable, Generic, Literal, Optional, TypeVar +import time +from typing import Any, Generic, Literal, TypeVar from dimos.types.timestamped import Timestamped from dimos.utils.generic import truncate_display_string @@ -54,7 +54,7 @@ class Return(Enum): @dataclass class SkillConfig: name: str - reducer: "ReducerF" + reducer: ReducerF stream: Stream ret: Return output: Output @@ -63,7 +63,7 @@ class SkillConfig: autostart: bool = False hide_skill: bool = False - def bind(self, f: Callable) -> "SkillConfig": + def bind(self, f: Callable) -> SkillConfig: self.f = f return self @@ -75,7 +75,7 @@ def call(self, call_id, *args, **kwargs) -> Any: return self.f(*args, **kwargs, call_id=call_id) - def __str__(self): + def __str__(self) -> str: parts = [f"name={self.name}"] # Only show reducer if stream is not none (streaming is happening) @@ -136,7 +136,7 @@ def end(self) -> bool: def start(self) -> bool: return self.type == MsgType.start - def __str__(self): + def __str__(self) -> str: time_ago = time.time() - self.ts if self.type == MsgType.start: @@ -156,7 +156,7 @@ def __str__(self): # typing looks complex but it's a standard reducer function signature, using SkillMsgs # (Optional[accumulator], msg) -> accumulator ReducerF = Callable[ - [Optional[SkillMsg[Literal[MsgType.reduced_stream]]], SkillMsg[Literal[MsgType.stream]]], + [SkillMsg[Literal[MsgType.reduced_stream]] | None, SkillMsg[Literal[MsgType.stream]]], SkillMsg[Literal[MsgType.reduced_stream]], ] @@ -164,7 +164,7 @@ def __str__(self): C = TypeVar("C") # content type A = TypeVar("A") # accumulator type # define a naive reducer function type that's generic in terms of the accumulator type -SimpleReducerF = Callable[[Optional[A], C], A] +SimpleReducerF = Callable[[A | None, C], A] def make_reducer(simple_reducer: SimpleReducerF) -> ReducerF: @@ -175,7 +175,7 @@ def make_reducer(simple_reducer: SimpleReducerF) -> ReducerF: """ def reducer( - accumulator: Optional[SkillMsg[Literal[MsgType.reduced_stream]]], + accumulator: SkillMsg[Literal[MsgType.reduced_stream]] | None, msg: SkillMsg[Literal[MsgType.stream]], ) -> SkillMsg[Literal[MsgType.reduced_stream]]: # Extract the content from the accumulator if it exists @@ -209,7 +209,7 @@ def _make_skill_msg( def sum_reducer( - accumulator: Optional[SkillMsg[Literal[MsgType.reduced_stream]]], + accumulator: SkillMsg[Literal[MsgType.reduced_stream]] | None, msg: SkillMsg[Literal[MsgType.stream]], ) -> SkillMsg[Literal[MsgType.reduced_stream]]: """Sum reducer that adds values together.""" @@ -219,7 +219,7 @@ def sum_reducer( def latest_reducer( - accumulator: Optional[SkillMsg[Literal[MsgType.reduced_stream]]], + accumulator: SkillMsg[Literal[MsgType.reduced_stream]] | None, msg: SkillMsg[Literal[MsgType.stream]], ) -> SkillMsg[Literal[MsgType.reduced_stream]]: """Latest reducer that keeps only the most recent value.""" @@ -227,17 +227,17 @@ def latest_reducer( def all_reducer( - accumulator: Optional[SkillMsg[Literal[MsgType.reduced_stream]]], + accumulator: SkillMsg[Literal[MsgType.reduced_stream]] | None, msg: SkillMsg[Literal[MsgType.stream]], ) -> SkillMsg[Literal[MsgType.reduced_stream]]: """All reducer that collects all values into a list.""" acc_value = accumulator.content if accumulator else None - new_value = acc_value + [msg.content] if acc_value else [msg.content] + new_value = [*acc_value, msg.content] if acc_value else [msg.content] return _make_skill_msg(msg, new_value) def accumulate_list( - accumulator: Optional[SkillMsg[Literal[MsgType.reduced_stream]]], + accumulator: SkillMsg[Literal[MsgType.reduced_stream]] | None, msg: SkillMsg[Literal[MsgType.stream]], ) -> SkillMsg[Literal[MsgType.reduced_stream]]: """All reducer that collects all values into a list.""" @@ -246,7 +246,7 @@ def accumulate_list( def accumulate_dict( - accumulator: Optional[SkillMsg[Literal[MsgType.reduced_stream]]], + accumulator: SkillMsg[Literal[MsgType.reduced_stream]] | None, msg: SkillMsg[Literal[MsgType.stream]], ) -> SkillMsg[Literal[MsgType.reduced_stream]]: """All reducer that collects all values into a list.""" @@ -255,7 +255,7 @@ def accumulate_dict( def accumulate_string( - accumulator: Optional[SkillMsg[Literal[MsgType.reduced_stream]]], + accumulator: SkillMsg[Literal[MsgType.reduced_stream]] | None, msg: SkillMsg[Literal[MsgType.stream]], ) -> SkillMsg[Literal[MsgType.reduced_stream]]: """All reducer that collects all values into a list.""" diff --git a/dimos/protocol/tf/__init__.py b/dimos/protocol/tf/__init__.py index 518a9b97f0..96cdbcf285 100644 --- a/dimos/protocol/tf/__init__.py +++ b/dimos/protocol/tf/__init__.py @@ -12,6 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.protocol.tf.tf import TF, LCMTF, PubSubTF, TFSpec, TFConfig, TBuffer, MultiTBuffer +from dimos.protocol.tf.tf import LCMTF, TF, MultiTBuffer, PubSubTF, TBuffer, TFConfig, TFSpec -__all__ = ["TF", "LCMTF", "PubSubTF", "TFSpec", "TFConfig", "TBuffer", "MultiTBuffer"] +__all__ = ["LCMTF", "TF", "MultiTBuffer", "PubSubTF", "TBuffer", "TFConfig", "TFSpec"] diff --git a/dimos/protocol/tf/test_tf.py b/dimos/protocol/tf/test_tf.py index 4d39e8764e..c25e1014f9 100644 --- a/dimos/protocol/tf/test_tf.py +++ b/dimos/protocol/tf/test_tf.py @@ -25,7 +25,7 @@ # from https://foxglove.dev/blog/understanding-ros-transforms -def test_tf_ros_example(): +def test_tf_ros_example() -> None: tf = TF() base_link_to_arm = Transform( @@ -55,7 +55,7 @@ def test_tf_ros_example(): tf.stop() -def test_tf_main(): +def test_tf_main() -> None: """Test TF broadcasting and querying between two TF instances. If you run foxglove-bridge this will show up in the UI""" @@ -184,7 +184,7 @@ def test_tf_main(): class TestTBuffer: - def test_add_transform(self): + def test_add_transform(self) -> None: buffer = TBuffer(buffer_size=10.0) transform = Transform( translation=Vector3(1.0, 2.0, 3.0), @@ -198,7 +198,7 @@ def test_add_transform(self): assert len(buffer) == 1 assert buffer[0] == transform - def test_get(self): + def test_get(self) -> None: buffer = TBuffer() base_time = time.time() @@ -226,7 +226,7 @@ def test_get(self): result = buffer.get(time_point=base_time + 10.0, time_tolerance=0.1) assert result is None # Outside tolerance - def test_buffer_pruning(self): + def test_buffer_pruning(self) -> None: buffer = TBuffer(buffer_size=1.0) # 1 second buffer # Add old transform @@ -254,7 +254,7 @@ def test_buffer_pruning(self): class TestMultiTBuffer: - def test_multiple_frame_pairs(self): + def test_multiple_frame_pairs(self) -> None: ttbuffer = MultiTBuffer(buffer_size=10.0) # Add transforms for different frame pairs @@ -279,7 +279,7 @@ def test_multiple_frame_pairs(self): assert ("world", "robot1") in ttbuffer.buffers assert ("world", "robot2") in ttbuffer.buffers - def test_graph(self): + def test_graph(self) -> None: ttbuffer = MultiTBuffer(buffer_size=10.0) # Add transforms for different frame pairs @@ -301,7 +301,7 @@ def test_graph(self): print(ttbuffer.graph()) - def test_get_latest_transform(self): + def test_get_latest_transform(self) -> None: ttbuffer = MultiTBuffer() # Add multiple transforms @@ -320,7 +320,7 @@ def test_get_latest_transform(self): assert latest is not None assert latest.translation.x == 2.0 - def test_get_transform_at_time(self): + def test_get_transform_at_time(self) -> None: ttbuffer = MultiTBuffer() base_time = time.time() @@ -342,7 +342,7 @@ def test_get_transform_at_time(self): # The implementation picks the later one when equidistant assert result.translation.x == 3.0 - def test_time_tolerance(self): + def test_time_tolerance(self) -> None: ttbuffer = MultiTBuffer() base_time = time.time() @@ -363,14 +363,14 @@ def test_time_tolerance(self): result = ttbuffer.get("world", "robot", time_point=base_time + 0.5, time_tolerance=0.1) assert result is None - def test_nonexistent_frame_pair(self): + def test_nonexistent_frame_pair(self) -> None: ttbuffer = MultiTBuffer() # Try to get transform for non-existent frame pair result = ttbuffer.get("foo", "bar") assert result is None - def test_get_transform_search_direct(self): + def test_get_transform_search_direct(self) -> None: ttbuffer = MultiTBuffer() base_time = time.time() @@ -389,7 +389,7 @@ def test_get_transform_search_direct(self): assert len(result) == 1 assert result[0].translation.x == 1.0 - def test_get_transform_search_chain(self): + def test_get_transform_search_chain(self) -> None: ttbuffer = MultiTBuffer() base_time = time.time() @@ -415,7 +415,7 @@ def test_get_transform_search_chain(self): assert result[0].translation.x == 1.0 # world -> robot assert result[1].translation.y == 2.0 # robot -> sensor - def test_get_transform_search_complex_chain(self): + def test_get_transform_search_complex_chain(self) -> None: ttbuffer = MultiTBuffer() base_time = time.time() @@ -466,7 +466,7 @@ def test_get_transform_search_complex_chain(self): assert result[1].child_frame_id == "arm" assert result[2].child_frame_id == "hand" - def test_get_transform_search_no_path(self): + def test_get_transform_search_no_path(self) -> None: ttbuffer = MultiTBuffer() base_time = time.time() @@ -479,7 +479,7 @@ def test_get_transform_search_no_path(self): result = ttbuffer.get_transform_search("world", "sensor") assert result is None - def test_get_transform_search_with_time(self): + def test_get_transform_search_with_time(self) -> None: ttbuffer = MultiTBuffer() base_time = time.time() @@ -509,7 +509,7 @@ def test_get_transform_search_with_time(self): ) assert result is None # Outside tolerance - def test_get_transform_search_shortest_path(self): + def test_get_transform_search_shortest_path(self) -> None: ttbuffer = MultiTBuffer() base_time = time.time() @@ -532,7 +532,7 @@ def test_get_transform_search_shortest_path(self): assert len(result) == 1 # Direct path, not the 3-hop path assert result[0].child_frame_id == "target" - def test_string_representations(self): + def test_string_representations(self) -> None: # Test empty buffers empty_buffer = TBuffer() assert str(empty_buffer) == "TBuffer(empty)" @@ -577,7 +577,7 @@ def test_string_representations(self): assert "TBuffer(world -> robot2, 1 msgs" in ttbuffer_str assert "TBuffer(robot1 -> sensor, 1 msgs" in ttbuffer_str - def test_get_with_transform_chain_composition(self): + def test_get_with_transform_chain_composition(self) -> None: ttbuffer = MultiTBuffer() base_time = time.time() @@ -627,7 +627,7 @@ def test_get_with_transform_chain_composition(self): assert result.frame_id == "world" assert result.child_frame_id == "sensor" - def test_get_with_longer_transform_chain(self): + def test_get_with_longer_transform_chain(self) -> None: ttbuffer = MultiTBuffer() base_time = time.time() diff --git a/dimos/protocol/tf/tf.py b/dimos/protocol/tf/tf.py index 0052ef4758..f60e216176 100644 --- a/dimos/protocol/tf/tf.py +++ b/dimos/protocol/tf/tf.py @@ -14,12 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time from abc import abstractmethod from collections import deque from dataclasses import dataclass, field from functools import reduce -from typing import Optional, TypeVar, Union +from typing import TypeVar from dimos.msgs.geometry_msgs import Transform from dimos.msgs.tf2_msgs import TFMessage @@ -40,7 +39,7 @@ class TFConfig: # generic specification for transform service class TFSpec(Service[TFConfig]): - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) @abstractmethod @@ -57,8 +56,8 @@ def get( self, parent_frame: str, child_frame: str, - time_point: Optional[float] = None, - time_tolerance: Optional[float] = None, + time_point: float | None = None, + time_tolerance: float | None = None, ): ... def receive_transform(self, *args: Transform) -> None: ... @@ -74,7 +73,7 @@ def receive_tfmessage(self, msg: TFMessage) -> None: # stores a single transform class TBuffer(TimestampedCollection[Transform]): - def __init__(self, buffer_size: float = 10.0): + def __init__(self, buffer_size: float = 10.0) -> None: super().__init__() self.buffer_size = buffer_size @@ -91,9 +90,7 @@ def _prune_old_transforms(self, current_time) -> None: while self._items and self._items[0].ts < cutoff_time: self._items.pop(0) - def get( - self, time_point: Optional[float] = None, time_tolerance: float = 1.0 - ) -> Optional[Transform]: + def get(self, time_point: float | None = None, time_tolerance: float = 1.0) -> Transform | None: """Get transform at specified time or latest if no time given.""" if time_point is None: # Return the latest transform @@ -137,7 +134,7 @@ def __str__(self) -> str: # stores multiple transform buffers # creates a new buffer on demand when new transform is detected class MultiTBuffer: - def __init__(self, buffer_size: float = 10.0): + def __init__(self, buffer_size: float = 10.0) -> None: self.buffers: dict[tuple[str, str], TBuffer] = {} self.buffer_size = buffer_size @@ -169,9 +166,9 @@ def get_transform( self, parent_frame: str, child_frame: str, - time_point: Optional[float] = None, - time_tolerance: Optional[float] = None, - ) -> Optional[Transform]: + time_point: float | None = None, + time_tolerance: float | None = None, + ) -> Transform | None: # Check forward direction key = (parent_frame, child_frame) if key in self.buffers: @@ -185,7 +182,7 @@ def get_transform( return None - def get(self, *args, **kwargs) -> Optional[Transform]: + def get(self, *args, **kwargs) -> Transform | None: simple = self.get_transform(*args, **kwargs) if simple is not None: return simple @@ -201,9 +198,9 @@ def get_transform_search( self, parent_frame: str, child_frame: str, - time_point: Optional[float] = None, - time_tolerance: Optional[float] = None, - ) -> Optional[list[Transform]]: + time_point: float | None = None, + time_tolerance: float | None = None, + ) -> list[Transform] | None: """Search for shortest transform chain between parent and child frames using BFS.""" # Check if direct transform exists (already checked in get_transform, but for clarity) direct = self.get_transform(parent_frame, child_frame, time_point, time_tolerance) @@ -232,14 +229,14 @@ def get_transform_search( current_frame, next_frame, time_point, time_tolerance ) if transform: - queue.append((next_frame, path + [transform])) + queue.append((next_frame, [*path, transform])) return None def graph(self) -> str: import subprocess - def connection_str(connection: tuple[str, str]): + def connection_str(connection: tuple[str, str]) -> str: (frame_from, frame_to) = connection return f"{frame_from} -> {frame_to}" @@ -269,8 +266,8 @@ def __str__(self) -> str: @dataclass class PubSubTFConfig(TFConfig): - topic: Optional[Topic] = None # Required field but needs default for dataclass inheritance - pubsub: Union[type[PubSub], PubSub, None] = None + topic: Topic | None = None # Required field but needs default for dataclass inheritance + pubsub: type[PubSub] | PubSub | None = None autostart: bool = True @@ -293,14 +290,14 @@ def __init__(self, **kwargs) -> None: if self.config.autostart: self.start() - def start(self, sub=True) -> None: + def start(self, sub: bool = True) -> None: self.pubsub.start() if sub: topic = getattr(self.config, "topic", None) if topic: self.pubsub.subscribe(topic, self.receive_msg) - def stop(self): + def stop(self) -> None: self.pubsub.stop() def publish(self, *args: Transform) -> None: @@ -332,9 +329,9 @@ def get( self, parent_frame: str, child_frame: str, - time_point: Optional[float] = None, - time_tolerance: Optional[float] = None, - ) -> Optional[Transform]: + time_point: float | None = None, + time_tolerance: float | None = None, + ) -> Transform | None: return super().get(parent_frame, child_frame, time_point, time_tolerance) def receive_msg(self, msg: TFMessage, topic: Topic) -> None: @@ -344,7 +341,7 @@ def receive_msg(self, msg: TFMessage, topic: Topic) -> None: @dataclass class LCMPubsubConfig(PubSubTFConfig): topic: Topic = field(default_factory=lambda: Topic("/tf", TFMessage)) - pubsub: Union[type[PubSub], PubSub, None] = LCM + pubsub: type[PubSub] | PubSub | None = LCM autostart: bool = True diff --git a/dimos/protocol/tf/tflcmcpp.py b/dimos/protocol/tf/tflcmcpp.py index e12877bdec..0d5b31b9b6 100644 --- a/dimos/protocol/tf/tflcmcpp.py +++ b/dimos/protocol/tf/tflcmcpp.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union from datetime import datetime -from dimos_lcm import tf +from typing import Union + +from dimos.msgs.geometry_msgs import Transform from dimos.protocol.service.lcmservice import LCMConfig, LCMService -from dimos.protocol.tf.tf import TFSpec, TFConfig -from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 +from dimos.protocol.tf.tf import TFConfig, TFSpec # this doesn't work due to tf_lcm_py package @@ -62,8 +62,8 @@ def lookup( self, parent_frame: str, child_frame: str, - time_point: Optional[float] = None, - time_tolerance: Optional[float] = None, + time_point: float | None = None, + time_tolerance: float | None = None, ): return self.buffer.lookup_transform( parent_frame, @@ -73,7 +73,7 @@ def lookup( ) def can_transform( - self, parent_frame: str, child_frame: str, time_point: Optional[float | datetime] = None + self, parent_frame: str, child_frame: str, time_point: float | datetime | None = None ) -> bool: if not time_point: time_point = datetime.now() @@ -86,8 +86,8 @@ def can_transform( def get_frames(self) -> set[str]: return set(self.buffer.get_all_frame_names()) - def start(self): + def start(self) -> None: super().start() ... - def stop(self): ... + def stop(self) -> None: ... diff --git a/dimos/robot/agilex/piper_arm.py b/dimos/robot/agilex/piper_arm.py index 7dbb2fcbfc..642d39c7cb 100644 --- a/dimos/robot/agilex/piper_arm.py +++ b/dimos/robot/agilex/piper_arm.py @@ -13,22 +13,20 @@ # limitations under the License. import asyncio -import logging -from typing import Optional, List + +# Import LCM message types +from dimos_lcm.sensor_msgs import CameraInfo from dimos import core from dimos.hardware.camera.zed import ZEDModule from dimos.manipulation.visual_servoing.manipulation_module import ManipulationModule from dimos.msgs.sensor_msgs import Image from dimos.protocol import pubsub +from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.robot.robot import Robot from dimos.skills.skills import SkillLibrary from dimos.types.robot_capabilities import RobotCapability -from dimos.robot.foxglove_bridge import FoxgloveBridge from dimos.utils.logging_config import setup_logger -from dimos.robot.robot import Robot - -# Import LCM message types -from dimos_lcm.sensor_msgs import CameraInfo logger = setup_logger("dimos.robot.agilex.piper_arm") @@ -36,7 +34,7 @@ class PiperArmRobot(Robot): """Piper Arm robot with ZED camera and manipulation capabilities.""" - def __init__(self, robot_capabilities: Optional[List[RobotCapability]] = None): + def __init__(self, robot_capabilities: list[RobotCapability] | None = None) -> None: super().__init__() self.dimos = None self.stereo_camera = None @@ -49,7 +47,7 @@ def __init__(self, robot_capabilities: Optional[List[RobotCapability]] = None): RobotCapability.MANIPULATION, ] - async def start(self): + async def start(self) -> None: """Start the robot modules.""" # Start Dimos self.dimos = core.start(2) # Need 2 workers for ZED and manipulation modules @@ -109,7 +107,7 @@ async def start(self): logger.info("PiperArmRobot initialized and started") def pick_and_place( - self, pick_x: int, pick_y: int, place_x: Optional[int] = None, place_y: Optional[int] = None + self, pick_x: int, pick_y: int, place_x: int | None = None, place_y: int | None = None ): """Execute pick and place task. @@ -143,7 +141,7 @@ def handle_keyboard_command(self, key: str): logger.error("Manipulation module not initialized") return None - def stop(self): + def stop(self) -> None: """Stop all modules and clean up.""" logger.info("Stopping PiperArmRobot...") @@ -163,7 +161,7 @@ def stop(self): logger.info("PiperArmRobot stopped") -async def run_piper_arm(): +async def run_piper_arm() -> None: """Run the Piper Arm robot.""" robot = PiperArmRobot() diff --git a/dimos/robot/agilex/run.py b/dimos/robot/agilex/run.py index a2db03c898..90258e5d82 100644 --- a/dimos/robot/agilex/run.py +++ b/dimos/robot/agilex/run.py @@ -21,19 +21,18 @@ import asyncio import os import sys -import time -from dotenv import load_dotenv +from dotenv import load_dotenv import reactivex as rx import reactivex.operators as ops -from dimos.robot.agilex.piper_arm import PiperArmRobot from dimos.agents.claude_agent import ClaudeAgent -from dimos.skills.manipulation.pick_and_place import PickAndPlace +from dimos.robot.agilex.piper_arm import PiperArmRobot from dimos.skills.kill_skill import KillSkill -from dimos.web.robot_web_interface import RobotWebInterface +from dimos.skills.manipulation.pick_and_place import PickAndPlace from dimos.stream.audio.pipelines import stt, tts from dimos.utils.logging_config import setup_logger +from dimos.web.robot_web_interface import RobotWebInterface logger = setup_logger("dimos.robot.agilex.run") @@ -64,7 +63,7 @@ - User: "Pick up the coffee mug" You: "I'll pick up the coffee mug for you." [Execute PickAndPlace with object_query="coffee mug"] -- User: "Put the toy on the table" +- User: "Put the toy on the table" You: "I'll place the toy on the table." [Execute PickAndPlace with object_query="toy", target_query="on the table"] - User: "What do you see?" @@ -161,7 +160,7 @@ def main(): logger.info("=" * 60) logger.info("Piper Arm Agent Ready!") - logger.info(f"Web interface available at: http://localhost:5555") + logger.info("Web interface available at: http://localhost:5555") logger.info("Foxglove visualization available at: ws://localhost:8765") logger.info("You can:") logger.info(" - Type commands in the web interface") diff --git a/dimos/robot/all_blueprints.py b/dimos/robot/all_blueprints.py index 2eef48855f..4d1a6e28bf 100644 --- a/dimos/robot/all_blueprints.py +++ b/dimos/robot/all_blueprints.py @@ -14,7 +14,6 @@ from dimos.core.blueprints import ModuleBlueprintSet - # The blueprints are defined as import strings so as not to trigger unnecessary imports. all_blueprints = { "unitree-go2": "dimos.robot.unitree_webrtc.unitree_go2_blueprints:standard", diff --git a/dimos/robot/cli/dimos_robot.py b/dimos/robot/cli/dimos_robot.py index 5b589b3d69..e3736f4665 100644 --- a/dimos/robot/cli/dimos_robot.py +++ b/dimos/robot/cli/dimos_robot.py @@ -12,17 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -import inspect from enum import Enum +import inspect from typing import Optional, get_args, get_origin import typer from dimos.core.blueprints import autoconnect from dimos.core.global_config import GlobalConfig -from dimos.robot.all_blueprints import all_blueprints, get_blueprint_by_name, get_module_by_name from dimos.protocol import pubsub - +from dimos.robot.all_blueprints import all_blueprints, get_blueprint_by_name, get_module_by_name RobotType = Enum("RobotType", {key.replace("-", "_").upper(): key for key in all_blueprints.keys()}) @@ -82,7 +81,7 @@ def create_dynamic_callback(): ) params.append(param) - def callback(**kwargs): + def callback(**kwargs) -> None: ctx = kwargs.pop("ctx") overrides = {k: v for k, v in kwargs.items() if v is not None} ctx.obj = GlobalConfig().model_copy(update=overrides) @@ -102,7 +101,7 @@ def run( extra_modules: list[str] = typer.Option( [], "--extra-module", help="Extra modules to add to the blueprint" ), -): +) -> None: """Run the robot with the specified configuration.""" config: GlobalConfig = ctx.obj pubsub.lcm.autoconf() @@ -117,7 +116,7 @@ def run( @main.command() -def show_config(ctx: typer.Context): +def show_config(ctx: typer.Context) -> None: """Show current configuration status.""" config: GlobalConfig = ctx.obj diff --git a/dimos/robot/connection_interface.py b/dimos/robot/connection_interface.py index 1f327a7939..6480827214 100644 --- a/dimos/robot/connection_interface.py +++ b/dimos/robot/connection_interface.py @@ -13,8 +13,9 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Optional + from reactivex.observable import Observable + from dimos.types.vector import Vector __all__ = ["ConnectionInterface"] @@ -44,7 +45,7 @@ def move(self, velocity: Vector, duration: float = 0.0) -> bool: pass @abstractmethod - def get_video_stream(self, fps: int = 30) -> Optional[Observable]: + def get_video_stream(self, fps: int = 30) -> Observable | None: """Get the video stream from the robot's camera. Args: diff --git a/dimos/robot/foxglove_bridge.py b/dimos/robot/foxglove_bridge.py index fa87653624..7a5ef5e33d 100644 --- a/dimos/robot/foxglove_bridge.py +++ b/dimos/robot/foxglove_bridge.py @@ -15,7 +15,6 @@ import asyncio import logging import threading -from typing import List, Optional # this is missing, I'm just trying to import lcm_foxglove_bridge.py from dimos_lcm from dimos_lcm.foxglove_bridge import FoxgloveBridge as LCMFoxgloveBridge @@ -30,15 +29,15 @@ class FoxgloveBridge(Module): _thread: threading.Thread _loop: asyncio.AbstractEventLoop - def __init__(self, *args, shm_channels=None, **kwargs): + def __init__(self, *args, shm_channels=None, **kwargs) -> None: super().__init__(*args, **kwargs) self.shm_channels = shm_channels or [] @rpc - def start(self): + def start(self) -> None: super().start() - def run_bridge(): + def run_bridge() -> None: self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) try: @@ -63,7 +62,7 @@ def run_bridge(): self._thread.start() @rpc - def stop(self): + def stop(self) -> None: if self._loop and self._loop.is_running(): self._loop.call_soon_threadsafe(self._loop.stop) self._thread.join(timeout=2) @@ -73,12 +72,14 @@ def stop(self): def deploy( dimos: DimosCluster, - shm_channels: Optional[List[str]] = [ - "/image#sensor_msgs.Image", - "/lidar#sensor_msgs.PointCloud2", - "/map#sensor_msgs.PointCloud2", - ], + shm_channels: list[str] | None = None, ) -> FoxgloveBridge: + if shm_channels is None: + shm_channels = [ + "/image#sensor_msgs.Image", + "/lidar#sensor_msgs.PointCloud2", + "/map#sensor_msgs.PointCloud2", + ] foxglove_bridge = dimos.deploy( FoxgloveBridge, shm_channels=shm_channels, @@ -90,4 +91,4 @@ def deploy( foxglove_bridge = FoxgloveBridge.blueprint -__all__ = ["FoxgloveBridge", "foxglove_bridge", "deploy"] +__all__ = ["FoxgloveBridge", "deploy", "foxglove_bridge"] diff --git a/dimos/robot/position_stream.py b/dimos/robot/position_stream.py index 05d80b8bcf..8cb5966b24 100644 --- a/dimos/robot/position_stream.py +++ b/dimos/robot/position_stream.py @@ -19,13 +19,12 @@ """ import logging -from typing import Tuple, Optional import time -from reactivex import Subject, Observable -from reactivex import operators as ops -from rclpy.node import Node + from geometry_msgs.msg import PoseStamped from nav_msgs.msg import Odometry +from rclpy.node import Node +from reactivex import Observable, Subject, operators as ops from dimos.utils.logging_config import setup_logger @@ -44,9 +43,9 @@ def __init__( self, ros_node: Node, odometry_topic: str = "/odom", - pose_topic: Optional[str] = None, + pose_topic: str | None = None, use_odometry: bool = True, - ): + ) -> None: """ Initialize the position stream provider. @@ -90,7 +89,7 @@ def _create_subscription(self): ) logger.info(f"Subscribed to pose topic: {self.pose_topic}") - def _odometry_callback(self, msg: Odometry): + def _odometry_callback(self, msg: Odometry) -> None: """ Process odometry messages and extract position. @@ -102,7 +101,7 @@ def _odometry_callback(self, msg: Odometry): self._update_position(x, y) - def _pose_callback(self, msg: PoseStamped): + def _pose_callback(self, msg: PoseStamped) -> None: """ Process pose messages and extract position. @@ -114,7 +113,7 @@ def _pose_callback(self, msg: PoseStamped): self._update_position(x, y) - def _update_position(self, x: float, y: float): + def _update_position(self, x: float, y: float) -> None: """ Update the current position and emit to subscribers. @@ -146,7 +145,7 @@ def get_position_stream(self) -> Observable: ops.share() # Share the stream among multiple subscribers ) - def get_current_position(self) -> Optional[Tuple[float, float]]: + def get_current_position(self) -> tuple[float, float] | None: """ Get the most recent position. @@ -155,7 +154,7 @@ def get_current_position(self) -> Optional[Tuple[float, float]]: """ return self.last_position - def cleanup(self): + def cleanup(self) -> None: """Clean up resources.""" if hasattr(self, "subscription") and self.subscription: self.ros_node.destroy_subscription(self.subscription) diff --git a/dimos/robot/recorder.py b/dimos/robot/recorder.py index 56b6cea888..acc9c0140e 100644 --- a/dimos/robot/recorder.py +++ b/dimos/robot/recorder.py @@ -14,10 +14,12 @@ # UNDER DEVELOPMENT 🚧🚧🚧, NEEDS TESTING +from collections.abc import Callable +from queue import Queue import threading import time -from queue import Queue -from typing import Callable, Literal +from types import TracebackType +from typing import Literal # from dimos.data.recording import Recorder @@ -41,7 +43,7 @@ def __init__( get_observation: Callable, prepare_action: Callable, frequency_hz: int = 5, - recorder_kwargs: dict = None, + recorder_kwargs: dict | None = None, on_static: Literal["record", "omit"] = "omit", ) -> None: """Initializes the RobotRecorder. @@ -78,11 +80,16 @@ def __init__( self._worker_thread = threading.Thread(target=self._process_queue, daemon=True) self._worker_thread.start() - def __enter__(self): + def __enter__(self) -> None: """Enter the context manager, starting the recording.""" self.start_recording(self.task) - def __exit__(self, exc_type, exc_value, traceback) -> None: + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: """Exit the context manager, stopping the recording.""" self.stop_recording() diff --git a/dimos/robot/robot.py b/dimos/robot/robot.py index 7a0bd27867..002dcb4710 100644 --- a/dimos/robot/robot.py +++ b/dimos/robot/robot.py @@ -15,7 +15,6 @@ """Minimal robot interface for DIMOS robots.""" from abc import ABC, abstractmethod -from typing import List, Optional from reactivex import Observable @@ -33,9 +32,9 @@ class Robot(ABC): can share, with no required methods - just common properties and helpers. """ - def __init__(self): + def __init__(self) -> None: """Initialize the robot with basic properties.""" - self.capabilities: List[RobotCapability] = [] + self.capabilities: list[RobotCapability] = [] self.skill_library = None def has_capability(self, capability: RobotCapability) -> bool: @@ -57,7 +56,7 @@ def get_skills(self): """ return self.skill_library - def cleanup(self): + def cleanup(self) -> None: """Clean up robot resources. Override this method to provide cleanup logic. @@ -81,7 +80,7 @@ def is_exploration_active(self) -> bool: ... @property @abstractmethod - def spatial_memory(self) -> Optional[SpatialMemory]: ... + def spatial_memory(self) -> SpatialMemory | None: ... # TODO: Delete diff --git a/dimos/robot/ros_bridge.py b/dimos/robot/ros_bridge.py index d77d5eb1fb..b067f88a22 100644 --- a/dimos/robot/ros_bridge.py +++ b/dimos/robot/ros_bridge.py @@ -12,16 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum import logging import threading -from typing import Dict, Any, Type, Optional -from enum import Enum +from typing import Any try: import rclpy from rclpy.executors import SingleThreadedExecutor from rclpy.node import Node - from rclpy.qos import QoSProfile, QoSReliabilityPolicy, QoSHistoryPolicy, QoSDurabilityPolicy + from rclpy.qos import QoSDurabilityPolicy, QoSHistoryPolicy, QoSProfile, QoSReliabilityPolicy except ImportError: rclpy = None SingleThreadedExecutor = None @@ -48,7 +48,7 @@ class BridgeDirection(Enum): class ROSBridge(Resource): """Unidirectional bridge between ROS and DIMOS for message passing.""" - def __init__(self, node_name: str = "dimos_ros_bridge"): + def __init__(self, node_name: str = "dimos_ros_bridge") -> None: """Initialize the ROS-DIMOS bridge. Args: @@ -67,7 +67,7 @@ def __init__(self, node_name: str = "dimos_ros_bridge"): self._spin_thread = threading.Thread(target=self._ros_spin, daemon=True) self._spin_thread.start() # TODO: don't forget to shut it down - self._bridges: Dict[str, Dict[str, Any]] = {} + self._bridges: dict[str, dict[str, Any]] = {} self._qos = QoSProfile( reliability=QoSReliabilityPolicy.RELIABLE, @@ -91,7 +91,7 @@ def stop(self) -> None: logger.info("ROSBridge shutdown complete") - def _ros_spin(self): + def _ros_spin(self) -> None: """Background thread for spinning ROS executor.""" try: self._executor.spin() @@ -101,10 +101,10 @@ def _ros_spin(self): def add_topic( self, topic_name: str, - dimos_type: Type, - ros_type: Type, + dimos_type: type, + ros_type: type, direction: BridgeDirection, - remap_topic: Optional[str] = None, + remap_topic: str | None = None, ) -> None: """Add unidirectional bridging for a topic. @@ -138,7 +138,7 @@ def add_topic( if direction == BridgeDirection.ROS_TO_DIMOS: - def ros_callback(msg): + def ros_callback(msg) -> None: self._ros_to_dimos(msg, dimos_topic, dimos_type, topic_name) ros_subscription = self.node.create_subscription( @@ -149,7 +149,7 @@ def ros_callback(msg): elif direction == BridgeDirection.DIMOS_TO_ROS: ros_publisher = self.node.create_publisher(ros_type, ros_topic_name, self._qos) - def dimos_callback(msg, _topic): + def dimos_callback(msg, _topic) -> None: self._dimos_to_ros(msg, ros_publisher, topic_name) dimos_subscription = self.lcm.subscribe(dimos_topic, dimos_callback) @@ -180,7 +180,7 @@ def dimos_callback(msg, _topic): logger.info(f" DIMOS type: {dimos_type.__name__}, ROS type: {ros_type.__name__}") def _ros_to_dimos( - self, ros_msg: Any, dimos_topic: Topic, dimos_type: Type, _topic_name: str + self, ros_msg: Any, dimos_topic: Topic, dimos_type: type, _topic_name: str ) -> None: """Convert ROS message to DIMOS and publish. diff --git a/dimos/robot/ros_command_queue.py b/dimos/robot/ros_command_queue.py index fc48ce5cde..770f44e1a6 100644 --- a/dimos/robot/ros_command_queue.py +++ b/dimos/robot/ros_command_queue.py @@ -20,12 +20,14 @@ Commands are processed sequentially and only when the robot is in IDLE state. """ +from collections.abc import Callable +from enum import Enum, auto +from queue import Empty, PriorityQueue import threading import time +from typing import Any, NamedTuple import uuid -from enum import Enum, auto -from queue import PriorityQueue, Empty -from typing import Callable, Optional, NamedTuple, Dict, Any + from dimos.utils.logging_config import setup_logger # Initialize logger for the ros command queue module @@ -56,7 +58,7 @@ class ROSCommand(NamedTuple): id: str # Unique ID for tracking cmd_type: CommandType # Type of command execute_func: Callable # Function to execute the command - params: Dict[str, Any] # Parameters for the command (for debugging/logging) + params: dict[str, Any] # Parameters for the command (for debugging/logging) priority: int # Priority level (lower is higher priority) timeout: float # How long to wait for this command to complete @@ -72,10 +74,10 @@ class ROSCommandQueue: def __init__( self, webrtc_func: Callable, - is_ready_func: Callable[[], bool] = None, - is_busy_func: Optional[Callable[[], bool]] = None, + is_ready_func: Callable[[], bool] | None = None, + is_busy_func: Callable[[], bool] | None = None, debug: bool = True, - ): + ) -> None: """ Initialize the ROSCommandQueue. @@ -116,7 +118,7 @@ def __init__( logger.info("ROSCommandQueue initialized") - def start(self): + def start(self) -> None: """Start the queue processing thread""" if self._queue_thread is not None and self._queue_thread.is_alive(): logger.warning("Queue processing thread already running") @@ -127,7 +129,7 @@ def start(self): self._queue_thread.start() logger.info("Queue processing thread started") - def stop(self, timeout=2.0): + def stop(self, timeout: float = 2.0) -> None: """ Stop the queue processing thread @@ -151,10 +153,10 @@ def stop(self, timeout=2.0): def queue_webrtc_request( self, api_id: int, - topic: str = None, + topic: str | None = None, parameter: str = "", - request_id: str = None, - data: Dict[str, Any] = None, + request_id: str | None = None, + data: dict[str, Any] | None = None, priority: int = 0, timeout: float = 30.0, ) -> str: @@ -176,7 +178,7 @@ def queue_webrtc_request( request_id = request_id or str(uuid.uuid4()) # Create a function that will execute this WebRTC request - def execute_webrtc(): + def execute_webrtc() -> bool: try: logger.info(f"Executing WebRTC request: {api_id} (ID: {request_id})") if self._debug: @@ -297,7 +299,7 @@ def queue_action_client_request( return request_id - def _process_queue(self): + def _process_queue(self) -> None: """Process commands in the queue""" logger.info("Starting queue processing") logger.info("[WebRTC Queue] Processing thread started") @@ -426,7 +428,7 @@ def _process_queue(self): logger.info("Queue processing stopped") - def _print_queue_status(self): + def _print_queue_status(self) -> None: """Print the current queue status""" current_time = time.time() @@ -435,7 +437,7 @@ def _print_queue_status(self): return is_ready = self._is_ready_func() - is_busy = self._is_busy_func() if self._is_busy_func else False + self._is_busy_func() if self._is_busy_func else False queue_size = self.queue_size # Get information about the current command @@ -466,6 +468,6 @@ def queue_size(self) -> int: return self._queue.qsize() @property - def current_command(self) -> Optional[ROSCommand]: + def current_command(self) -> ROSCommand | None: """Get the current command being processed""" return self._current_command diff --git a/dimos/robot/ros_control.py b/dimos/robot/ros_control.py index 6aa51fc3a8..2e9eb95204 100644 --- a/dimos/robot/ros_control.py +++ b/dimos/robot/ros_control.py @@ -12,42 +12,38 @@ # See the License for the specific language governing permissions and # limitations under the License. -import rclpy -from rclpy.node import Node -from rclpy.executors import MultiThreadedExecutor -from rclpy.action import ActionClient -from geometry_msgs.msg import Twist -from nav2_msgs.action import Spin - -from sensor_msgs.msg import Image, CompressedImage -from cv_bridge import CvBridge +from abc import ABC, abstractmethod from enum import Enum, auto +import math import threading import time -from typing import Optional, Dict, Any, Type -from abc import ABC, abstractmethod +from typing import Any + +from builtin_interfaces.msg import Duration +from cv_bridge import CvBridge +from geometry_msgs.msg import Point, Twist, Vector3 +from nav2_msgs.action import Spin +from nav_msgs.msg import OccupancyGrid, Odometry +import rclpy +from rclpy.action import ActionClient +from rclpy.executors import MultiThreadedExecutor +from rclpy.node import Node from rclpy.qos import ( + QoSDurabilityPolicy, + QoSHistoryPolicy, QoSProfile, QoSReliabilityPolicy, - QoSHistoryPolicy, - QoSDurabilityPolicy, ) -from dimos.stream.ros_video_provider import ROSVideoProvider -import math -from builtin_interfaces.msg import Duration -from geometry_msgs.msg import Point, Vector3 -from dimos.robot.ros_command_queue import ROSCommandQueue -from dimos.utils.logging_config import setup_logger - -from nav_msgs.msg import OccupancyGrid - +from sensor_msgs.msg import CompressedImage, Image import tf2_ros -from dimos.robot.ros_transform import ROSTransformAbility -from dimos.robot.ros_observable_topic import ROSObservableTopicAbility + from dimos.robot.connection_interface import ConnectionInterface +from dimos.robot.ros_command_queue import ROSCommandQueue +from dimos.robot.ros_observable_topic import ROSObservableTopicAbility +from dimos.robot.ros_transform import ROSTransformAbility +from dimos.stream.ros_video_provider import ROSVideoProvider from dimos.types.vector import Vector - -from nav_msgs.msg import Odometry +from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.robot.ros_control") @@ -70,24 +66,24 @@ class ROSControl(ROSTransformAbility, ROSObservableTopicAbility, ConnectionInter def __init__( self, node_name: str, - camera_topics: Dict[str, str] = None, + camera_topics: dict[str, str] | None = None, max_linear_velocity: float = 1.0, mock_connection: bool = False, max_angular_velocity: float = 2.0, - state_topic: str = None, - imu_topic: str = None, - state_msg_type: Type = None, - imu_msg_type: Type = None, - webrtc_topic: str = None, - webrtc_api_topic: str = None, - webrtc_msg_type: Type = None, - move_vel_topic: str = None, - pose_topic: str = None, + state_topic: str | None = None, + imu_topic: str | None = None, + state_msg_type: type | None = None, + imu_msg_type: type | None = None, + webrtc_topic: str | None = None, + webrtc_api_topic: str | None = None, + webrtc_msg_type: type | None = None, + move_vel_topic: str | None = None, + pose_topic: str | None = None, odom_topic: str = "/odom", global_costmap_topic: str = "map", costmap_topic: str = "/local_costmap/costmap", debug: bool = False, - ): + ) -> None: """ Initialize base ROS control interface Args: @@ -269,7 +265,7 @@ def __init__( logger.info(f"{node_name} initialized with multi-threaded executor") print(f"{node_name} initialized with multi-threaded executor") - def get_global_costmap(self) -> Optional[OccupancyGrid]: + def get_global_costmap(self) -> OccupancyGrid | None: """ Get current global_costmap data @@ -287,25 +283,25 @@ def get_global_costmap(self) -> Optional[OccupancyGrid]: else: return None - def _global_costmap_callback(self, msg): + def _global_costmap_callback(self, msg) -> None: """Callback for costmap data""" self._global_costmap_data = msg - def _imu_callback(self, msg): + def _imu_callback(self, msg) -> None: """Callback for IMU data""" self._imu_state = msg # Log IMU state (very verbose) # logger.debug(f"IMU state updated: {self._imu_state}") - def _odom_callback(self, msg): + def _odom_callback(self, msg) -> None: """Callback for odometry data""" self._odom_data = msg - def _costmap_callback(self, msg): + def _costmap_callback(self, msg) -> None: """Callback for costmap data""" self._costmap_data = msg - def _state_callback(self, msg): + def _state_callback(self, msg) -> None: """Callback for state messages to track mode and progress""" # Call the abstract method to update RobotMode enum based on the received state @@ -315,11 +311,11 @@ def _state_callback(self, msg): # logger.debug(f"Robot state updated: {self._robot_state}") @property - def robot_state(self) -> Optional[Any]: + def robot_state(self) -> Any | None: """Get the full robot state message""" return self._robot_state - def _ros_spin(self): + def _ros_spin(self) -> None: """Background thread for spinning the multi-threaded executor.""" self._executor.add_node(self._node) try: @@ -336,7 +332,7 @@ def _update_mode(self, *args, **kwargs): """Update robot mode based on state - to be implemented by child classes""" pass - def get_state(self) -> Optional[Any]: + def get_state(self) -> Any | None: """ Get current robot state @@ -352,7 +348,7 @@ def get_state(self) -> Optional[Any]: return self._robot_state - def get_imu_state(self) -> Optional[Any]: + def get_imu_state(self) -> Any | None: """ Get current IMU state @@ -367,7 +363,7 @@ def get_imu_state(self) -> Optional[Any]: return None return self._imu_state - def get_odometry(self) -> Optional[Odometry]: + def get_odometry(self) -> Odometry | None: """ Get current odometry data @@ -381,7 +377,7 @@ def get_odometry(self) -> Optional[Odometry]: return None return self._odom_data - def get_costmap(self) -> Optional[OccupancyGrid]: + def get_costmap(self) -> OccupancyGrid | None: """ Get current costmap data @@ -393,7 +389,7 @@ def get_costmap(self) -> Optional[OccupancyGrid]: return None return self._costmap_data - def _image_callback(self, msg): + def _image_callback(self, msg) -> None: """Convert ROS image to numpy array and push to data stream""" if self._video_provider and self._bridge: try: @@ -407,14 +403,14 @@ def _image_callback(self, msg): self._video_provider.push_data(frame) except Exception as e: logger.error(f"Error converting image: {e}") - print(f"Full conversion error: {str(e)}") + print(f"Full conversion error: {e!s}") @property - def video_provider(self) -> Optional[ROSVideoProvider]: + def video_provider(self) -> ROSVideoProvider | None: """Data provider property for streaming data""" return self._video_provider - def get_video_stream(self, fps: int = 30) -> Optional[Observable]: + def get_video_stream(self, fps: int = 30) -> Observable | None: """Get the video stream from the robot's camera. Args: @@ -428,7 +424,9 @@ def get_video_stream(self, fps: int = 30) -> Optional[Observable]: return self.video_provider.get_stream(fps=fps) - def _send_action_client_goal(self, client, goal_msg, description=None, time_allowance=20.0): + def _send_action_client_goal( + self, client, goal_msg, description: str | None = None, time_allowance: float = 20.0 + ) -> bool: """ Generic function to send any action client goal and wait for completion. @@ -656,7 +654,7 @@ def stop(self) -> bool: logger.error(f"Failed to stop movement: {e}") return False - def cleanup(self): + def cleanup(self) -> None: """Cleanup the executor, ROS node, and stop robot.""" self.stop() @@ -679,10 +677,10 @@ def disconnect(self) -> None: def webrtc_req( self, api_id: int, - topic: str = None, + topic: str | None = None, parameter: str = "", priority: int = 0, - request_id: str = None, + request_id: str | None = None, data=None, ) -> bool: """ @@ -725,7 +723,7 @@ def get_robot_mode(self) -> RobotMode: """ return self._mode - def print_robot_mode(self): + def print_robot_mode(self) -> None: """Print the current robot mode to the console""" mode = self.get_robot_mode() print(f"Current RobotMode: {mode.name}") @@ -734,11 +732,11 @@ def print_robot_mode(self): def queue_webrtc_req( self, api_id: int, - topic: str = None, + topic: str | None = None, parameter: str = "", priority: int = 0, timeout: float = 90.0, - request_id: str = None, + request_id: str | None = None, data=None, ) -> str: """ @@ -840,7 +838,7 @@ def get_position_stream(self): return position_provider.get_position_stream() - def _goal_response_callback(self, future): + def _goal_response_callback(self, future) -> None: """Handle the goal response.""" goal_handle = future.result() if not goal_handle.accepted: @@ -854,7 +852,7 @@ def _goal_response_callback(self, future): result_future = goal_handle.get_result_async() result_future.add_done_callback(self._goal_result_callback) - def _goal_result_callback(self, future): + def _goal_result_callback(self, future) -> None: """Handle the goal result.""" try: result = future.result().result diff --git a/dimos/robot/ros_observable_topic.py b/dimos/robot/ros_observable_topic.py index ef99ceadee..7cfc70fd8b 100644 --- a/dimos/robot/ros_observable_topic.py +++ b/dimos/robot/ros_observable_topic.py @@ -12,30 +12,30 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio -import functools +from collections.abc import Callable import enum +import functools +from typing import Any, Union + +from nav_msgs import msg +from rclpy.qos import ( + QoSDurabilityPolicy, + QoSHistoryPolicy, + QoSProfile, + QoSReliabilityPolicy, +) import reactivex as rx from reactivex import operators as ops from reactivex.disposable import Disposable from reactivex.scheduler import ThreadPoolScheduler from rxpy_backpressure import BackPressure -from nav_msgs import msg +from dimos.msgs.nav_msgs import OccupancyGrid +from dimos.types.vector import Vector from dimos.utils.logging_config import setup_logger from dimos.utils.threadpool import get_scheduler -from dimos.types.vector import Vector -from dimos.msgs.nav_msgs import OccupancyGrid - -from typing import Union, Callable, Any - -from rclpy.qos import ( - QoSProfile, - QoSReliabilityPolicy, - QoSHistoryPolicy, - QoSDurabilityPolicy, -) -__all__ = ["ROSObservableTopicAbility", "QOS"] +__all__ = ["QOS", "ROSObservableTopicAbility"] TopicType = Union[OccupancyGrid, msg.OccupancyGrid, msg.Odometry] @@ -97,7 +97,7 @@ def _sub_msg_type(self, msg_type): return msg_type - @functools.lru_cache(maxsize=None) + @functools.cache def topic( self, topic_name: str, @@ -219,7 +219,7 @@ async def topic_latest_async( core = self.topic(topic_name, msg_type, qos=qos) # single ROS callback - def _on_next(v): + def _on_next(v) -> None: cache["val"] = v if not first.done(): loop.call_soon_threadsafe(first.set_result, v) diff --git a/dimos/robot/ros_transform.py b/dimos/robot/ros_transform.py index b0c46fd275..d54eb8cd15 100644 --- a/dimos/robot/ros_transform.py +++ b/dimos/robot/ros_transform.py @@ -12,16 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import rclpy -from typing import Optional + from geometry_msgs.msg import TransformStamped -from tf2_ros import Buffer -import tf2_ros +import rclpy +from scipy.spatial.transform import Rotation as R from tf2_geometry_msgs import PointStamped -from dimos.utils.logging_config import setup_logger -from dimos.types.vector import Vector +import tf2_ros +from tf2_ros import Buffer + from dimos.types.path import Path -from scipy.spatial.transform import Rotation as R +from dimos.types.vector import Vector +from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.robot.ros_transform") @@ -70,7 +71,7 @@ def transform_euler(self, source_frame: str, target_frame: str = "map", timeout: def transform( self, source_frame: str, target_frame: str = "map", timeout: float = 1.0 - ) -> Optional[TransformStamped]: + ) -> TransformStamped | None: try: transform = self.tf_buffer.lookup_transform( target_frame, diff --git a/dimos/robot/test_ros_bridge.py b/dimos/robot/test_ros_bridge.py index a4c0c16ed7..435766b938 100644 --- a/dimos/robot/test_ros_bridge.py +++ b/dimos/robot/test_ros_bridge.py @@ -12,21 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time import threading +import time import unittest -import numpy as np +import numpy as np import pytest try: + from geometry_msgs.msg import TransformStamped, TwistStamped as ROSTwistStamped import rclpy from rclpy.node import Node - from geometry_msgs.msg import TwistStamped as ROSTwistStamped - from sensor_msgs.msg import PointCloud2 as ROSPointCloud2 - from sensor_msgs.msg import PointField + from sensor_msgs.msg import PointCloud2 as ROSPointCloud2, PointField from tf2_msgs.msg import TFMessage as ROSTFMessage - from geometry_msgs.msg import TransformStamped except ImportError: rclpy = None Node = None @@ -36,18 +34,18 @@ ROSTFMessage = None TransformStamped = None -from dimos.protocol.pubsub.lcmpubsub import LCM, Topic from dimos.msgs.geometry_msgs import TwistStamped from dimos.msgs.sensor_msgs import PointCloud2 from dimos.msgs.tf2_msgs import TFMessage -from dimos.robot.ros_bridge import ROSBridge, BridgeDirection +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.robot.ros_bridge import BridgeDirection, ROSBridge @pytest.mark.ros class TestROSBridge(unittest.TestCase): """Test suite for ROS-DIMOS bridge.""" - def setUp(self): + def setUp(self) -> None: """Set up test fixtures.""" # Skip if ROS is not available if rclpy is None: @@ -68,14 +66,14 @@ def setUp(self): self.dimos_messages = [] self.message_timestamps = {"ros": [], "dimos": []} - def tearDown(self): + def tearDown(self) -> None: """Clean up test fixtures.""" self.test_node.destroy_node() self.bridge.stop() if rclpy.ok(): rclpy.try_shutdown() - def test_ros_to_dimos_twist(self): + def test_ros_to_dimos_twist(self) -> None: """Test ROS TwistStamped to DIMOS conversion and transmission.""" # Set up bridge self.bridge.add_topic( @@ -87,7 +85,7 @@ def test_ros_to_dimos_twist(self): lcm.start() topic = Topic("/test_twist", TwistStamped) - def dimos_callback(msg, _topic): + def dimos_callback(msg, _topic) -> None: self.dimos_messages.append(msg) self.message_timestamps["dimos"].append(time.time()) @@ -122,7 +120,7 @@ def dimos_callback(msg, _topic): self.assertAlmostEqual(msg.linear.y, float(i * 2), places=5) self.assertAlmostEqual(msg.angular.z, float(i * 0.1), places=5) - def test_dimos_to_ros_twist(self): + def test_dimos_to_ros_twist(self) -> None: """Test DIMOS TwistStamped to ROS conversion and transmission.""" # Set up bridge self.bridge.add_topic( @@ -130,7 +128,7 @@ def test_dimos_to_ros_twist(self): ) # Subscribe to ROS side - def ros_callback(msg): + def ros_callback(msg) -> None: self.ros_messages.append(msg) self.message_timestamps["ros"].append(time.time()) @@ -164,7 +162,7 @@ def ros_callback(msg): self.assertAlmostEqual(msg.twist.linear.y, float(i * 4), places=5) self.assertAlmostEqual(msg.twist.angular.z, float(i * 0.2), places=5) - def test_frequency_preservation(self): + def test_frequency_preservation(self) -> None: """Test that message frequencies are preserved through the bridge.""" # Set up bridge self.bridge.add_topic( @@ -178,7 +176,7 @@ def test_frequency_preservation(self): receive_times = [] - def dimos_callback(_msg, _topic): + def dimos_callback(_msg, _topic) -> None: receive_times.append(time.time()) lcm.subscribe(topic, dimos_callback) @@ -229,7 +227,7 @@ def dimos_callback(_msg, _topic): msg=f"Frequency not preserved for {target_freq}Hz: sent={send_freq:.1f}Hz, received={receive_freq:.1f}Hz", ) - def test_pointcloud_conversion(self): + def test_pointcloud_conversion(self) -> None: """Test PointCloud2 message conversion with numpy optimization.""" # Set up bridge self.bridge.add_topic( @@ -243,7 +241,7 @@ def test_pointcloud_conversion(self): received_cloud = [] - def dimos_callback(msg, _topic): + def dimos_callback(msg, _topic) -> None: received_cloud.append(msg) lcm.subscribe(topic, dimos_callback) @@ -286,7 +284,7 @@ def dimos_callback(msg, _topic): self.assertEqual(received_points.shape, points.shape) np.testing.assert_array_almost_equal(received_points, points, decimal=5) - def test_tf_high_frequency(self): + def test_tf_high_frequency(self) -> None: """Test TF message handling at high frequency.""" # Set up bridge self.bridge.add_topic("/test_tf", TFMessage, ROSTFMessage, BridgeDirection.ROS_TO_DIMOS) @@ -299,7 +297,7 @@ def test_tf_high_frequency(self): received_tfs = [] receive_times = [] - def dimos_callback(msg, _topic): + def dimos_callback(msg, _topic) -> None: received_tfs.append(msg) receive_times.append(time.time()) @@ -351,7 +349,7 @@ def dimos_callback(msg, _topic): msg=f"High frequency TF not preserved: expected={target_freq}Hz, got={receive_freq:.1f}Hz", ) - def test_bidirectional_bridge(self): + def test_bidirectional_bridge(self) -> None: """Test simultaneous bidirectional message flow.""" # Set up bidirectional bridges for same topic type self.bridge.add_topic( @@ -382,7 +380,7 @@ def test_bidirectional_bridge(self): stop_spinning = threading.Event() # Spin the test node in background to receive messages - def spin_test_node(): + def spin_test_node() -> None: while not stop_spinning.is_set(): rclpy.spin_once(self.test_node, timeout_sec=0.01) @@ -390,7 +388,7 @@ def spin_test_node(): spin_thread.start() # Send messages in both directions simultaneously - def send_ros_messages(): + def send_ros_messages() -> None: for i in range(50): msg = ROSTwistStamped() msg.header.stamp = self.test_node.get_clock().now().to_msg() @@ -398,7 +396,7 @@ def send_ros_messages(): ros_pub.publish(msg) time.sleep(0.02) # 50Hz - def send_dimos_messages(): + def send_dimos_messages() -> None: for i in range(50): msg = TwistStamped(ts=time.time()) msg.linear.y = float(i * 2) diff --git a/dimos/robot/test_ros_observable_topic.py b/dimos/robot/test_ros_observable_topic.py index 71a1484de3..0ffed24d35 100644 --- a/dimos/robot/test_ros_observable_topic.py +++ b/dimos/robot/test_ros_observable_topic.py @@ -13,16 +13,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import asyncio import threading import time + import pytest -from dimos.utils.logging_config import setup_logger + from dimos.types.vector import Vector -import asyncio +from dimos.utils.logging_config import setup_logger class MockROSNode: - def __init__(self): + def __init__(self) -> None: self.logger = setup_logger("ROS") self.sub_id_cnt = 0 @@ -33,7 +35,7 @@ def _get_sub_id(self): self.sub_id_cnt += 1 return sub_id - def create_subscription(self, msg_type, topic_name, callback, qos): + def create_subscription(self, msg_type, topic_name: str, callback, qos): # Mock implementation of ROS subscription sub_id = self._get_sub_id() @@ -42,7 +44,7 @@ def create_subscription(self, msg_type, topic_name, callback, qos): self.logger.info(f"Subscribed {topic_name} subid {sub_id}") # Create message simulation thread - def simulate_messages(): + def simulate_messages() -> None: message_count = 0 while not stop_event.is_set(): message_count += 1 @@ -58,7 +60,7 @@ def simulate_messages(): thread.start() return sub_id - def destroy_subscription(self, subscription): + def destroy_subscription(self, subscription) -> None: if subscription in self.subs: self.subs[subscription].set() self.logger.info(f"Destroyed subscription: {subscription}") @@ -72,7 +74,7 @@ def robot(): from dimos.robot.ros_observable_topic import ROSObservableTopicAbility class MockRobot(ROSObservableTopicAbility): - def __init__(self): + def __init__(self) -> None: self.logger = setup_logger("ROBOT") # Initialize the mock ROS node self._node = MockROSNode() @@ -88,7 +90,7 @@ def __init__(self): # 4. that the system replays the last message to new observers, # before the new ROS sub starts producing @pytest.mark.ros -def test_parallel_and_cleanup(robot): +def test_parallel_and_cleanup(robot) -> None: from nav_msgs import msg received_messages = [] @@ -160,7 +162,7 @@ def test_parallel_and_cleanup(robot): # ├──► observe_on(pool) ─► backpressure.latest ─► sub2 (slow) # └──► observe_on(pool) ─► backpressure.latest ─► sub3 (slower) @pytest.mark.ros -def test_parallel_and_hog(robot): +def test_parallel_and_hog(robot) -> None: from nav_msgs import msg obs1 = robot.topic("/odom", msg.Odometry) @@ -200,7 +202,7 @@ def test_parallel_and_hog(robot): @pytest.mark.asyncio @pytest.mark.ros -async def test_topic_latest_async(robot): +async def test_topic_latest_async(robot) -> None: from nav_msgs import msg odom = await robot.topic_latest_async("/odom", msg.Odometry) @@ -213,14 +215,14 @@ async def test_topic_latest_async(robot): @pytest.mark.ros -def test_topic_auto_conversion(robot): +def test_topic_auto_conversion(robot) -> None: odom = robot.topic("/vector", Vector).subscribe(lambda x: print(x)) time.sleep(0.5) odom.dispose() @pytest.mark.ros -def test_topic_latest_sync(robot): +def test_topic_latest_sync(robot) -> None: from nav_msgs import msg odom = robot.topic_latest("/odom", msg.Odometry) @@ -233,13 +235,13 @@ def test_topic_latest_sync(robot): @pytest.mark.ros -def test_topic_latest_sync_benchmark(robot): +def test_topic_latest_sync_benchmark(robot) -> None: from nav_msgs import msg odom = robot.topic_latest("/odom", msg.Odometry) start_time = time.time() - for i in range(100): + for _i in range(100): odom() end_time = time.time() elapsed = end_time - start_time diff --git a/dimos/robot/unitree/connection/connection.py b/dimos/robot/unitree/connection/connection.py index fc9714c8ba..0d904df7c4 100644 --- a/dimos/robot/unitree/connection/connection.py +++ b/dimos/robot/unitree/connection/connection.py @@ -13,19 +13,19 @@ # limitations under the License. import asyncio +from dataclasses import dataclass import functools import threading import time -from dataclasses import dataclass -from typing import Optional, TypeAlias +from typing import TypeAlias -import numpy as np from aiortc import MediaStreamTrack from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD, VUI_COLOR from go2_webrtc_driver.webrtc_driver import ( # type: ignore[import-not-found] Go2WebRTCConnection, WebRTCConnectionMethod, ) +import numpy as np from numpy.typing import NDArray from reactivex import operators as ops from reactivex.observable import Observable @@ -49,12 +49,12 @@ class SerializableVideoFrame: """Pickleable wrapper for av.VideoFrame with all metadata""" data: np.ndarray - pts: Optional[int] = None - time: Optional[float] = None - dts: Optional[int] = None - width: Optional[int] = None - height: Optional[int] = None - format: Optional[str] = None + pts: int | None = None + time: float | None = None + dts: int | None = None + width: int | None = None + height: int | None = None + format: str | None = None @classmethod def from_av_frame(cls, frame): @@ -73,21 +73,21 @@ def to_ndarray(self, format=None): class UnitreeWebRTCConnection(Resource): - def __init__(self, ip: str, mode: str = "ai"): + def __init__(self, ip: str, mode: str = "ai") -> None: self.ip = ip self.mode = mode - self.stop_timer: Optional[threading.Timer] = None + self.stop_timer: threading.Timer | None = None self.cmd_vel_timeout = 0.2 self.conn = Go2WebRTCConnection(WebRTCConnectionMethod.LocalSTA, ip=self.ip) self.connect() - def connect(self): + def connect(self) -> None: self.loop = asyncio.new_event_loop() self.task = None self.connected_event = asyncio.Event() self.connection_ready = threading.Event() - async def async_connect(): + async def async_connect() -> None: await self.conn.connect() await self.conn.datachannel.disableTrafficSaving(True) @@ -103,7 +103,7 @@ async def async_connect(): while True: await asyncio.sleep(1) - def start_background_loop(): + def start_background_loop() -> None: asyncio.set_event_loop(self.loop) self.task = self.loop.create_task(async_connect()) self.loop.run_forever() @@ -156,13 +156,13 @@ def move(self, twist: TwistStamped, duration: float = 0.0) -> bool: # x - Positive right, negative left # y - positive forward, negative backwards # yaw - Positive rotate right, negative rotate left - async def async_move(): + async def async_move() -> None: self.conn.datachannel.pub_sub.publish_without_callback( RTC_TOPIC["WIRELESS_CONTROLLER"], data={"lx": -y, "ly": x, "rx": -yaw, "ry": 0}, ) - async def async_move_duration(): + async def async_move_duration() -> None: """Send movement commands continuously for the specified duration.""" start_time = time.time() sleep_time = 0.01 @@ -198,17 +198,17 @@ async def async_move_duration(): # Generic conversion of unitree subscription to Subject (used for all subs) def unitree_sub_stream(self, topic_name: str): - def subscribe_in_thread(cb): + def subscribe_in_thread(cb) -> None: # Run the subscription in the background thread that has the event loop - def run_subscription(): + def run_subscription() -> None: self.conn.datachannel.pub_sub.subscribe(topic_name, cb) # Use call_soon_threadsafe to run in the background thread self.loop.call_soon_threadsafe(run_subscription) - def unsubscribe_in_thread(cb): + def unsubscribe_in_thread(cb) -> None: # Run the unsubscription in the background thread that has the event loop - def run_unsubscription(): + def run_unsubscription() -> None: self.conn.datachannel.pub_sub.unsubscribe(topic_name) # Use call_soon_threadsafe to run in the background thread @@ -273,7 +273,7 @@ def lowstate_stream(self) -> Observable[LowStateMsg]: def standup_ai(self): return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]}) - def standup_normal(self): + def standup_normal(self) -> bool: self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]}) time.sleep(0.5) self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["RecoveryStand"]}) @@ -325,17 +325,17 @@ async def accept_track(track: MediaStreamTrack) -> None: self.conn.video.add_track_callback(accept_track) # Run the video channel switching in the background thread - def switch_video_channel(): + def switch_video_channel() -> None: self.conn.video.switchVideoChannel(True) self.loop.call_soon_threadsafe(switch_video_channel) - def stop(): + def stop() -> None: stop_event.set() # Signal the loop to stop self.conn.video.track_callbacks.remove(accept_track) # Run the video channel switching off in the background thread - def switch_video_channel_off(): + def switch_video_channel_off() -> None: self.conn.video.switchVideoChannel(False) self.loop.call_soon_threadsafe(switch_video_channel_off) @@ -381,7 +381,7 @@ def disconnect(self) -> None: self.task.cancel() if hasattr(self, "conn"): - async def async_disconnect(): + async def async_disconnect() -> None: try: await self.conn.disconnect() except: diff --git a/dimos/robot/unitree/connection/g1.py b/dimos/robot/unitree/connection/g1.py index 88386a59ed..8e63cbb40a 100644 --- a/dimos/robot/unitree/connection/g1.py +++ b/dimos/robot/unitree/connection/g1.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, cast from dimos import spec from dimos.core import DimosCluster, In, Module, rpc @@ -25,11 +24,11 @@ class G1Connection(Module): cmd_vel: In[TwistStamped] = None # type: ignore - ip: Optional[str] + ip: str | None connection: UnitreeWebRTCConnection - def __init__(self, ip: Optional[str] = None, **kwargs): + def __init__(self, ip: str | None = None, **kwargs) -> None: super().__init__(**kwargs) if ip is None: @@ -37,7 +36,7 @@ def __init__(self, ip: Optional[str] = None, **kwargs): self.connection = UnitreeWebRTCConnection(ip) @rpc - def start(self): + def start(self) -> None: super().start() self.connection.start() self._disposables.add( @@ -50,7 +49,7 @@ def stop(self) -> None: super().stop() @rpc - def move(self, twist_stamped: TwistStamped, duration: float = 0.0): + def move(self, twist_stamped: TwistStamped, duration: float = 0.0) -> None: """Send movement command to robot.""" twist = Twist(linear=twist_stamped.linear, angular=twist_stamped.angular) self.connection.move(twist, duration) diff --git a/dimos/robot/unitree/connection/go2.py b/dimos/robot/unitree/connection/go2.py index a55d8a8bdd..3dcda0f7d7 100644 --- a/dimos/robot/unitree/connection/go2.py +++ b/dimos/robot/unitree/connection/go2.py @@ -13,9 +13,9 @@ # limitations under the License. import logging -import time from threading import Thread -from typing import List, Optional, Protocol +import time +from typing import Protocol from dimos_lcm.sensor_msgs import CameraInfo from reactivex.observable import Observable @@ -96,7 +96,7 @@ class ReplayConnection(UnitreeWebRTCConnection): def __init__( self, **kwargs, - ): + ) -> None: get_data(self.dir_name) self.replay_config = { "loop": kwargs.get("loop"), @@ -104,16 +104,16 @@ def __init__( "duration": kwargs.get("duration"), } - def connect(self): + def connect(self) -> None: pass - def start(self): + def start(self) -> None: pass - def standup(self): + def standup(self) -> None: print("standup suppressed") - def liedown(self): + def liedown(self) -> None: print("liedown suppressed") @simple_mcache @@ -136,7 +136,7 @@ def video_stream(self): return video_store.stream(**self.replay_config) - def move(self, twist: TwistStamped, duration: float = 0.0): + def move(self, twist: TwistStamped, duration: float = 0.0) -> None: pass def publish_request(self, topic: str, data: dict): @@ -153,16 +153,16 @@ class GO2Connection(Module, spec.Camera, spec.Pointcloud): connection: Go2ConnectionProtocol - ip: Optional[str] + ip: str | None camera_info: CameraInfo = camera_info def __init__( self, - ip: Optional[str] = None, + ip: str | None = None, *args, **kwargs, - ): + ) -> None: match ip: case None | "fake" | "mock" | "replay": self.connection = ReplayConnection() @@ -214,7 +214,7 @@ def stop(self) -> None: super().stop() @classmethod - def _odom_to_tf(self, odom: PoseStamped) -> List[Transform]: + def _odom_to_tf(cls, odom: PoseStamped) -> list[Transform]: camera_link = Transform( translation=Vector3(0.3, 0.0, 0.0), rotation=Quaternion(0.0, 0.0, 0.0, 1.0), @@ -246,16 +246,16 @@ def _odom_to_tf(self, odom: PoseStamped) -> List[Transform]: sensor, ] - def _publish_tf(self, msg): + def _publish_tf(self, msg) -> None: self.tf.publish(*self._odom_to_tf(msg)) - def publish_camera_info(self): + def publish_camera_info(self) -> None: while True: self.camera_info_stream.publish(camera_info) time.sleep(1.0) @rpc - def move(self, twist: TwistStamped, duration: float = 0.0): + def move(self, twist: TwistStamped, duration: float = 0.0) -> None: """Send movement command to robot.""" self.connection.move(twist, duration) @@ -281,7 +281,7 @@ def publish_request(self, topic: str, data: dict): return self.connection.publish_request(topic, data) -def deploy(dimos: DimosCluster, ip: str, prefix="") -> GO2Connection: +def deploy(dimos: DimosCluster, ip: str, prefix: str = "") -> GO2Connection: from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE connection = dimos.deploy(GO2Connection, ip) diff --git a/dimos/robot/unitree/g1/g1zed.py b/dimos/robot/unitree/g1/g1zed.py index c4fb9d90c1..607ae3acb6 100644 --- a/dimos/robot/unitree/g1/g1zed.py +++ b/dimos/robot/unitree/g1/g1zed.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, TypedDict, cast +from typing import TypedDict, cast from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE -from dimos.core import DimosCluster, LCMTransport, RPCClient, pSHMTransport +from dimos.core import DimosCluster, LCMTransport, pSHMTransport from dimos.hardware.camera import zed from dimos.hardware.camera.module import CameraModule from dimos.hardware.camera.webcam import Webcam @@ -44,7 +44,7 @@ class G1ZedDeployResult(TypedDict): def deploy_g1_monozed(dimos: DimosCluster) -> CameraModule: camera = cast( - CameraModule, + "CameraModule", dimos.deploy( CameraModule, frequency=4.0, diff --git a/dimos/robot/unitree/go2/go2.py b/dimos/robot/unitree/go2/go2.py index 05c05e7a8e..0e78485adc 100644 --- a/dimos/robot/unitree/go2/go2.py +++ b/dimos/robot/unitree/go2/go2.py @@ -14,9 +14,7 @@ import logging -from dimos import agents2 from dimos.core import DimosCluster -from dimos.perception.detection import moduleDB from dimos.robot import foxglove_bridge from dimos.robot.unitree.connection import go2 from dimos.utils.logging_config import setup_logger diff --git a/dimos/robot/unitree/run.py b/dimos/robot/unitree/run.py index 17f1226fd8..43338c9353 100644 --- a/dimos/robot/unitree/run.py +++ b/dimos/robot/unitree/run.py @@ -27,14 +27,13 @@ import importlib import os import sys -from pathlib import Path from dotenv import load_dotenv from dimos.core import start, wait_exit -def main(): +def main() -> None: load_dotenv() parser = argparse.ArgumentParser(description="Unitree Robot Modular Deployment Runner") @@ -78,7 +77,7 @@ def main(): full_module_path = f"dimos.robot.unitree.{module_path}" print(f"Importing module: {full_module_path}") module = importlib.import_module(full_module_path) - except ImportError as e: + except ImportError: # Try as a relative import from the unitree package try: module = importlib.import_module(f".{module_path}", package="dimos.robot.unitree") @@ -88,10 +87,10 @@ def main(): traceback.print_exc() print(f"\nERROR: Could not import module '{args.module}'") - print(f"Tried importing as:") + print("Tried importing as:") print(f" 1. {full_module_path}") - print(f" 2. Relative import from dimos.robot.unitree") - print(f"Make sure the module exists in dimos/robot/unitree/") + print(" 2. Relative import from dimos.robot.unitree") + print("Make sure the module exists in dimos/robot/unitree/") print(f"Import error: {e2}") sys.exit(1) diff --git a/dimos/robot/unitree_webrtc/connection.py b/dimos/robot/unitree_webrtc/connection.py index 8ddc77ac63..4aee995c02 100644 --- a/dimos/robot/unitree_webrtc/connection.py +++ b/dimos/robot/unitree_webrtc/connection.py @@ -13,28 +13,27 @@ # limitations under the License. import asyncio +from dataclasses import dataclass import functools import threading import time -from dataclasses import dataclass -from typing import Literal, Optional, TypeAlias +from typing import Literal, TypeAlias -import numpy as np from aiortc import MediaStreamTrack from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD, VUI_COLOR from go2_webrtc_driver.webrtc_driver import ( # type: ignore[import-not-found] Go2WebRTCConnection, WebRTCConnectionMethod, ) +import numpy as np from reactivex import operators as ops from reactivex.observable import Observable from reactivex.subject import Subject -from dimos.core import In, Module, Out, rpc +from dimos.core import rpc from dimos.core.resource import Resource -from dimos.msgs.geometry_msgs import Pose, Transform, Twist, Vector3 +from dimos.msgs.geometry_msgs import Pose, Transform, Twist from dimos.msgs.sensor_msgs import Image -from dimos.robot.connection_interface import ConnectionInterface from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg from dimos.robot.unitree_webrtc.type.odometry import Odometry @@ -49,12 +48,12 @@ class SerializableVideoFrame: """Pickleable wrapper for av.VideoFrame with all metadata""" data: np.ndarray - pts: Optional[int] = None - time: Optional[float] = None - dts: Optional[int] = None - width: Optional[int] = None - height: Optional[int] = None - format: Optional[str] = None + pts: int | None = None + time: float | None = None + dts: int | None = None + width: int | None = None + height: int | None = None + format: str | None = None @classmethod def from_av_frame(cls, frame): @@ -73,7 +72,7 @@ def to_ndarray(self, format=None): class UnitreeWebRTCConnection(Resource): - def __init__(self, ip: str, mode: str = "ai"): + def __init__(self, ip: str, mode: str = "ai") -> None: self.ip = ip self.mode = mode self.stop_timer = None @@ -81,13 +80,13 @@ def __init__(self, ip: str, mode: str = "ai"): self.conn = Go2WebRTCConnection(WebRTCConnectionMethod.LocalSTA, ip=self.ip) self.connect() - def connect(self): + def connect(self) -> None: self.loop = asyncio.new_event_loop() self.task = None self.connected_event = asyncio.Event() self.connection_ready = threading.Event() - async def async_connect(): + async def async_connect() -> None: await self.conn.connect() await self.conn.datachannel.disableTrafficSaving(True) @@ -103,7 +102,7 @@ async def async_connect(): while True: await asyncio.sleep(1) - def start_background_loop(): + def start_background_loop() -> None: asyncio.set_event_loop(self.loop) self.task = self.loop.create_task(async_connect()) self.loop.run_forever() @@ -155,13 +154,13 @@ def move(self, twist: Twist, duration: float = 0.0) -> bool: # x - Positive right, negative left # y - positive forward, negative backwards # yaw - Positive rotate right, negative rotate left - async def async_move(): + async def async_move() -> None: self.conn.datachannel.pub_sub.publish_without_callback( RTC_TOPIC["WIRELESS_CONTROLLER"], data={"lx": -y, "ly": x, "rx": -yaw, "ry": 0}, ) - async def async_move_duration(): + async def async_move_duration() -> None: """Send movement commands continuously for the specified duration.""" start_time = time.time() sleep_time = 0.01 @@ -197,17 +196,17 @@ async def async_move_duration(): # Generic conversion of unitree subscription to Subject (used for all subs) def unitree_sub_stream(self, topic_name: str): - def subscribe_in_thread(cb): + def subscribe_in_thread(cb) -> None: # Run the subscription in the background thread that has the event loop - def run_subscription(): + def run_subscription() -> None: self.conn.datachannel.pub_sub.subscribe(topic_name, cb) # Use call_soon_threadsafe to run in the background thread self.loop.call_soon_threadsafe(run_subscription) - def unsubscribe_in_thread(cb): + def unsubscribe_in_thread(cb) -> None: # Run the unsubscription in the background thread that has the event loop - def run_unsubscription(): + def run_unsubscription() -> None: self.conn.datachannel.pub_sub.unsubscribe(topic_name) # Use call_soon_threadsafe to run in the background thread @@ -272,7 +271,7 @@ def lowstate_stream(self) -> Subject[LowStateMsg]: def standup_ai(self): return self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["BalanceStand"]}) - def standup_normal(self): + def standup_normal(self) -> bool: self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["StandUp"]}) time.sleep(0.5) self.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": SPORT_CMD["RecoveryStand"]}) @@ -324,17 +323,17 @@ async def accept_track(track: MediaStreamTrack) -> VideoMessage: self.conn.video.add_track_callback(accept_track) # Run the video channel switching in the background thread - def switch_video_channel(): + def switch_video_channel() -> None: self.conn.video.switchVideoChannel(True) self.loop.call_soon_threadsafe(switch_video_channel) - def stop(): + def stop() -> None: stop_event.set() # Signal the loop to stop self.conn.video.track_callbacks.remove(accept_track) # Run the video channel switching off in the background thread - def switch_video_channel_off(): + def switch_video_channel_off() -> None: self.conn.video.switchVideoChannel(False) self.loop.call_soon_threadsafe(switch_video_channel_off) @@ -388,7 +387,7 @@ def disconnect(self) -> None: self.task.cancel() if hasattr(self, "conn"): - async def async_disconnect(): + async def async_disconnect() -> None: try: await self.conn.disconnect() except: diff --git a/dimos/robot/unitree_webrtc/depth_module.py b/dimos/robot/unitree_webrtc/depth_module.py index 2e0bd77ee2..9e9b57b24b 100644 --- a/dimos/robot/unitree_webrtc/depth_module.py +++ b/dimos/robot/unitree_webrtc/depth_module.py @@ -14,16 +14,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time import threading -from typing import Optional +import time +from dimos_lcm.sensor_msgs import CameraInfo import numpy as np -from dimos.core import Module, In, Out, rpc +from dimos.core import In, Module, Out, rpc from dimos.core.global_config import GlobalConfig from dimos.msgs.sensor_msgs import Image, ImageFormat -from dimos_lcm.sensor_msgs import CameraInfo from dimos.utils.logging_config import setup_logger logger = setup_logger(__name__) @@ -53,7 +52,7 @@ def __init__( gt_depth_scale: float = 0.5, global_config: GlobalConfig | None = None, **kwargs, - ): + ) -> None: """ Initialize Depth Module. @@ -76,7 +75,7 @@ def __init__( self._cannot_process_depth = False # Threading - self._processing_thread: Optional[threading.Thread] = None + self._processing_thread: threading.Thread | None = None self._stop_processing = threading.Event() if global_config: @@ -84,7 +83,7 @@ def __init__( self.gt_depth_scale = 1.0 @rpc - def start(self): + def start(self) -> None: super().start() if self._running: @@ -104,7 +103,7 @@ def start(self): logger.info("Depth module started") @rpc - def stop(self): + def stop(self) -> None: if not self._running: return @@ -117,7 +116,7 @@ def stop(self): super().stop() - def _on_camera_info(self, msg: CameraInfo): + def _on_camera_info(self, msg: CameraInfo) -> None: """Process camera info to extract intrinsics.""" if self.metric3d is not None: return # Already initialized @@ -145,7 +144,7 @@ def _on_camera_info(self, msg: CameraInfo): except Exception as e: logger.error(f"Error processing camera info: {e}") - def _on_video(self, msg: Image): + def _on_video(self, msg: Image) -> None: """Store latest video frame for processing.""" if not self._running: return @@ -156,14 +155,14 @@ def _on_video(self, msg: Image): f"Received video frame: format={msg.format}, shape={msg.data.shape if hasattr(msg.data, 'shape') else 'unknown'}" ) - def _start_processing_thread(self): + def _start_processing_thread(self) -> None: """Start the processing thread.""" self._stop_processing.clear() self._processing_thread = threading.Thread(target=self._main_processing_loop, daemon=True) self._processing_thread.start() logger.info("Started depth processing thread") - def _main_processing_loop(self): + def _main_processing_loop(self) -> None: """Main processing loop that continuously processes latest frames.""" logger.info("Starting main processing loop") @@ -187,7 +186,7 @@ def _main_processing_loop(self): logger.info("Main processing loop stopped") - def _process_depth(self, img_array: np.ndarray): + def _process_depth(self, img_array: np.ndarray) -> None: """Process depth estimation using Metric3D.""" if self._cannot_process_depth: self._last_depth = None @@ -213,7 +212,7 @@ def _process_depth(self, img_array: np.ndarray): logger.error(f"Error processing depth: {e}") self._cannot_process_depth = True - def _publish_depth(self): + def _publish_depth(self) -> None: """Publish depth image.""" if not self._running: return diff --git a/dimos/robot/unitree_webrtc/g1_joystick_module.py b/dimos/robot/unitree_webrtc/g1_joystick_module.py index 156a0891a2..2c6a5e64e5 100644 --- a/dimos/robot/unitree_webrtc/g1_joystick_module.py +++ b/dimos/robot/unitree_webrtc/g1_joystick_module.py @@ -34,13 +34,13 @@ class G1JoystickModule(Module): twist_out: Out[Twist] = None # Standard velocity commands - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: Module.__init__(self, *args, **kwargs) self.pygame_ready = False self.running = False @rpc - def start(self): + def start(self) -> bool: """Initialize pygame and start control loop.""" super().start() @@ -75,7 +75,7 @@ def stop(self) -> None: self.twist_out.publish(stop_twist) - def _pygame_loop(self): + def _pygame_loop(self) -> None: """Main pygame event loop - ALL pygame operations happen here.""" import pygame @@ -142,7 +142,7 @@ def _pygame_loop(self): pygame.quit() print("G1 JoystickModule stopped") - def _update_display(self, twist): + def _update_display(self, twist) -> None: """Update pygame window with current status.""" import pygame diff --git a/dimos/robot/unitree_webrtc/g1_run.py b/dimos/robot/unitree_webrtc/g1_run.py index 1ac0914470..b8c0bc77c7 100644 --- a/dimos/robot/unitree_webrtc/g1_run.py +++ b/dimos/robot/unitree_webrtc/g1_run.py @@ -18,22 +18,22 @@ Provides interaction capabilities with natural language interface and ZED vision. """ +import argparse import os import sys import time -import argparse -from dotenv import load_dotenv +from dotenv import load_dotenv import reactivex as rx import reactivex.operators as ops +from dimos.agents.claude_agent import ClaudeAgent from dimos.robot.unitree_webrtc.unitree_g1 import UnitreeG1 from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills -from dimos.agents.claude_agent import ClaudeAgent from dimos.skills.kill_skill import KillSkill from dimos.skills.navigation import GetPose -from dimos.web.robot_web_interface import RobotWebInterface from dimos.utils.logging_config import setup_logger +from dimos.web.robot_web_interface import RobotWebInterface logger = setup_logger("dimos.robot.unitree_webrtc.g1_run") @@ -87,7 +87,7 @@ def main(): # Load system prompt try: - with open(SYSTEM_PROMPT_PATH, "r") as f: + with open(SYSTEM_PROMPT_PATH) as f: system_prompt = f.read() except FileNotFoundError: logger.error(f"System prompt file not found at {SYSTEM_PROMPT_PATH}") @@ -154,7 +154,7 @@ def main(): logger.info("=" * 60) logger.info("Unitree G1 Agent Ready!") - logger.info(f"Web interface available at: http://localhost:5555") + logger.info("Web interface available at: http://localhost:5555") logger.info("You can:") logger.info(" - Type commands in the web interface") logger.info(" - Use voice commands") diff --git a/dimos/robot/unitree_webrtc/modular/connection_module.py b/dimos/robot/unitree_webrtc/modular/connection_module.py index 5950282f0b..bad9af22a1 100644 --- a/dimos/robot/unitree_webrtc/modular/connection_module.py +++ b/dimos/robot/unitree_webrtc/modular/connection_module.py @@ -15,26 +15,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass import functools import logging import os import queue -import time import warnings -from dataclasses import dataclass -from typing import List, Optional -import reactivex as rx from dimos_lcm.sensor_msgs import CameraInfo +import reactivex as rx from reactivex import operators as ops from reactivex.observable import Observable -from dimos.agents2 import Agent, Output, Reducer, Stream, skill +from dimos.agents2 import Output, Reducer, Stream, skill from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE from dimos.core import DimosCluster, In, LCMTransport, Module, ModuleConfig, Out, pSHMTransport, rpc -from dimos.msgs.foxglove_msgs import ImageAnnotations from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 -from dimos.msgs.sensor_msgs.Image import Image, sharpness_window +from dimos.msgs.sensor_msgs.Image import Image from dimos.msgs.std_msgs import Header from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection from dimos.robot.unitree_webrtc.type.lidar import LidarMessage @@ -68,7 +65,7 @@ class FakeRTC(UnitreeWebRTCConnection): def __init__( self, **kwargs, - ): + ) -> None: get_data(self.dir_name) self.replay_config = { "loop": kwargs.get("loop"), @@ -76,16 +73,16 @@ def __init__( "duration": kwargs.get("duration"), } - def connect(self): + def connect(self) -> None: pass - def start(self): + def start(self) -> None: pass - def standup(self): + def standup(self) -> None: print("standup suppressed") - def liedown(self): + def liedown(self) -> None: print("liedown suppressed") @functools.cache @@ -108,7 +105,7 @@ def video_stream(self): return video_store.stream(**self.replay_config) - def move(self, vector: Twist, duration: float = 0.0): + def move(self, vector: Twist, duration: float = 0.0) -> None: pass def publish_request(self, topic: str, data: dict): @@ -118,7 +115,7 @@ def publish_request(self, topic: str, data: dict): @dataclass class ConnectionModuleConfig(ModuleConfig): - ip: Optional[str] = None + ip: str | None = None connection_type: str = "fake" # or "fake" or "mujoco" loop: bool = False # For fake connection speed: float = 1.0 # For fake connection @@ -139,7 +136,7 @@ class ConnectionModule(Module): # parallel calls video_running: bool = False - def __init__(self, connection_type: str = "webrtc", *args, **kwargs): + def __init__(self, connection_type: str = "webrtc", *args, **kwargs) -> None: self.connection_config = kwargs self.connection_type = connection_type Module.__init__(self, *args, **kwargs) @@ -153,11 +150,10 @@ def video_stream_tool(self) -> Image: _queue = queue.Queue(maxsize=1) self.connection.video_stream().subscribe(_queue.put) - for image in iter(_queue.get, None): - yield image + yield from iter(_queue.get, None) @rpc - def record(self, recording_name: str): + def record(self, recording_name: str) -> None: lidar_store: TimedSensorStorage = TimedSensorStorage(f"{recording_name}/lidar") lidar_store.save_stream(self.connection.lidar_stream()).subscribe(lambda x: x) @@ -219,7 +215,7 @@ def stop(self) -> None: super().stop() @classmethod - def _odom_to_tf(self, odom: PoseStamped) -> List[Transform]: + def _odom_to_tf(cls, odom: PoseStamped) -> list[Transform]: camera_link = Transform( translation=Vector3(0.3, 0.0, 0.0), rotation=Quaternion(0.0, 0.0, 0.0, 1.0), @@ -251,7 +247,7 @@ def _odom_to_tf(self, odom: PoseStamped) -> List[Transform]: sensor, ] - def _publish_tf(self, msg): + def _publish_tf(self, msg) -> None: self.odom.publish(msg) self.tf.publish(*self._odom_to_tf(msg)) @@ -267,7 +263,7 @@ def publish_request(self, topic: str, data: dict): return self.connection.publish_request(topic, data) @classmethod - def _camera_info(self) -> Out[CameraInfo]: + def _camera_info(cls) -> Out[CameraInfo]: fx, fy, cx, cy = list( map( lambda x: int(x / image_resize_factor), diff --git a/dimos/robot/unitree_webrtc/modular/detect.py b/dimos/robot/unitree_webrtc/modular/detect.py index 3f6c2c04b2..46f561b109 100644 --- a/dimos/robot/unitree_webrtc/modular/detect.py +++ b/dimos/robot/unitree_webrtc/modular/detect.py @@ -107,7 +107,7 @@ def broadcast( odom_frame: Odometry, detections, annotations, -): +) -> None: from dimos_lcm.foxglove_msgs.ImageAnnotations import ImageAnnotations from dimos.core import LCMTransport @@ -167,7 +167,7 @@ def attach_frame_id(image: Image) -> Image: return data -def main(): +def main() -> None: try: with open("filename.pkl", "rb") as file: data = pickle.load(file) diff --git a/dimos/robot/unitree_webrtc/modular/ivan_unitree.py b/dimos/robot/unitree_webrtc/modular/ivan_unitree.py index 948dccaa16..e7a2bcabc8 100644 --- a/dimos/robot/unitree_webrtc/modular/ivan_unitree.py +++ b/dimos/robot/unitree_webrtc/modular/ivan_unitree.py @@ -15,34 +15,29 @@ import logging import time -from dimos_lcm.foxglove_msgs import SceneUpdate - from dimos.agents2.spec import Model, Provider from dimos.core import LCMTransport, start # from dimos.msgs.detection2d import Detection2DArray from dimos.msgs.foxglove_msgs import ImageAnnotations -from dimos.msgs.geometry_msgs import PoseStamped -from dimos.msgs.sensor_msgs import Image, PointCloud2 +from dimos.msgs.sensor_msgs import Image from dimos.msgs.vision_msgs import Detection2DArray from dimos.perception.detection.module2D import Detection2DModule -from dimos.perception.detection.module3D import Detection3DModule -from dimos.perception.detection.person_tracker import PersonTracker from dimos.perception.detection.reid import ReidModule from dimos.protocol.pubsub import lcm from dimos.robot.foxglove_bridge import FoxgloveBridge -from dimos.robot.unitree_webrtc.modular import deploy_connection, deploy_navigation +from dimos.robot.unitree_webrtc.modular import deploy_connection from dimos.robot.unitree_webrtc.modular.connection_module import ConnectionModule from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.robot.unitree_webrtc.unitree_go2", level=logging.INFO) -def detection_unitree(): +def detection_unitree() -> None: dimos = start(8) connection = deploy_connection(dimos) - def goto(pose): + def goto(pose) -> bool: print("NAVIGATION REQUESTED:", pose) return True @@ -92,7 +87,7 @@ def goto(pose): connection.start() reid.start() - from dimos.agents2 import Agent, Output, Reducer, Stream, skill + from dimos.agents2 import Agent from dimos.agents2.cli.human import HumanInput agent = Agent( @@ -130,7 +125,7 @@ def goto(pose): logger.info("Shutting down...") -def main(): +def main() -> None: lcm.autoconf() detection_unitree() diff --git a/dimos/robot/unitree_webrtc/modular/navigation.py b/dimos/robot/unitree_webrtc/modular/navigation.py index f16fd29816..9aa03d104e 100644 --- a/dimos/robot/unitree_webrtc/modular/navigation.py +++ b/dimos/robot/unitree_webrtc/modular/navigation.py @@ -15,7 +15,7 @@ from dimos_lcm.std_msgs import Bool, String from dimos.core import LCMTransport -from dimos.msgs.geometry_msgs import PoseStamped, Twist, Vector3 +from dimos.msgs.geometry_msgs import PoseStamped, Twist from dimos.msgs.nav_msgs import OccupancyGrid, Path from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer diff --git a/dimos/robot/unitree_webrtc/mujoco_connection.py b/dimos/robot/unitree_webrtc/mujoco_connection.py index 64bfaf2b8e..b68097ea33 100644 --- a/dimos/robot/unitree_webrtc/mujoco_connection.py +++ b/dimos/robot/unitree_webrtc/mujoco_connection.py @@ -20,7 +20,6 @@ import logging import threading import time -from typing import List from reactivex import Observable @@ -29,7 +28,6 @@ from dimos.msgs.sensor_msgs import Image from dimos.utils.data import get_data - LIDAR_FREQUENCY = 10 ODOM_FREQUENCY = 50 VIDEO_FREQUENCY = 30 @@ -38,15 +36,15 @@ class MujocoConnection: - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: try: from dimos.simulation.mujoco.mujoco import MujocoThread except ImportError: raise ImportError("'mujoco' is not installed. Use `pip install -e .[sim]`") get_data("mujoco_sim") self.mujoco_thread = MujocoThread() - self._stream_threads: List[threading.Thread] = [] - self._stop_events: List[threading.Event] = [] + self._stream_threads: list[threading.Thread] = [] + self._stop_events: list[threading.Event] = [] self._is_cleaned_up = False # Register cleanup on exit @@ -89,10 +87,10 @@ def stop(self) -> None: if hasattr(self, "video_stream"): self.video_stream.cache_clear() - def standup(self): + def standup(self) -> None: print("standup supressed") - def liedown(self): + def liedown(self) -> None: print("liedown supressed") @functools.cache @@ -105,7 +103,7 @@ def on_subscribe(observer, scheduler): stop_event = threading.Event() self._stop_events.append(stop_event) - def run(): + def run() -> None: try: while not stop_event.is_set() and not self._is_cleaned_up: lidar_to_publish = self.mujoco_thread.get_lidar_message() @@ -123,7 +121,7 @@ def run(): self._stream_threads.append(thread) thread.start() - def dispose(): + def dispose() -> None: stop_event.set() return dispose @@ -140,7 +138,7 @@ def on_subscribe(observer, scheduler): stop_event = threading.Event() self._stop_events.append(stop_event) - def run(): + def run() -> None: try: while not stop_event.is_set() and not self._is_cleaned_up: odom_to_publish = self.mujoco_thread.get_odom_message() @@ -157,7 +155,7 @@ def run(): self._stream_threads.append(thread) thread.start() - def dispose(): + def dispose() -> None: stop_event.set() return dispose @@ -174,7 +172,7 @@ def on_subscribe(observer, scheduler): stop_event = threading.Event() self._stop_events.append(stop_event) - def run(): + def run() -> None: lat = 37.78092426217621 lon = -122.40682866540769 try: @@ -189,7 +187,7 @@ def run(): self._stream_threads.append(thread) thread.start() - def dispose(): + def dispose() -> None: stop_event.set() return dispose @@ -206,7 +204,7 @@ def on_subscribe(observer, scheduler): stop_event = threading.Event() self._stop_events.append(stop_event) - def run(): + def run() -> None: try: while not stop_event.is_set() and not self._is_cleaned_up: with self.mujoco_thread.pixels_lock: @@ -223,16 +221,16 @@ def run(): self._stream_threads.append(thread) thread.start() - def dispose(): + def dispose() -> None: stop_event.set() return dispose return Observable(on_subscribe) - def move(self, twist: Twist, duration: float = 0.0): + def move(self, twist: Twist, duration: float = 0.0) -> None: if not self._is_cleaned_up: self.mujoco_thread.move(twist, duration) - def publish_request(self, topic: str, data: dict): + def publish_request(self, topic: str, data: dict) -> None: pass diff --git a/dimos/robot/unitree_webrtc/rosnav.py b/dimos/robot/unitree_webrtc/rosnav.py index 969ddad950..bd91fafb90 100644 --- a/dimos/robot/unitree_webrtc/rosnav.py +++ b/dimos/robot/unitree_webrtc/rosnav.py @@ -16,25 +16,11 @@ import logging import time -from dimos import core -from dimos.core import Module, In, Out, rpc -from dimos.msgs.geometry_msgs import PoseStamped, TwistStamped, Transform, Vector3 -from dimos.msgs.nav_msgs import Odometry -from dimos.msgs.sensor_msgs import PointCloud2, Joy +from dimos.core import In, Module, Out, rpc +from dimos.msgs.geometry_msgs import PoseStamped +from dimos.msgs.sensor_msgs import Joy from dimos.msgs.std_msgs.Bool import Bool -from dimos.msgs.std_msgs.Header import Header -from dimos.msgs.tf2_msgs.TFMessage import TFMessage -from dimos.protocol.tf import TF -from dimos.robot.ros_bridge import ROSBridge, BridgeDirection -from dimos.utils.transform_utils import euler_to_quaternion -from geometry_msgs.msg import TwistStamped as ROSTwistStamped -from geometry_msgs.msg import PoseStamped as ROSPoseStamped -from nav_msgs.msg import Odometry as ROSOdometry -from sensor_msgs.msg import PointCloud2 as ROSPointCloud2 -from std_msgs.msg import Bool as ROSBool -from tf2_msgs.msg import TFMessage as ROSTFMessage from dimos.utils.logging_config import setup_logger -from dimos.protocol.pubsub.lcmpubsub import LCM, Topic logger = setup_logger("dimos.robot.unitree_webrtc.nav_bot", level=logging.INFO) @@ -45,23 +31,23 @@ class NavigationModule(Module): cancel_goal: Out[Bool] = None joy: Out[Joy] = None - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: """Initialize NavigationModule.""" Module.__init__(self, *args, **kwargs) self.goal_reach = None @rpc - def start(self): + def start(self) -> None: """Start the navigation module.""" if self.goal_reached: self.goal_reached.subscribe(self._on_goal_reached) logger.info("NavigationModule started") - def _on_goal_reached(self, msg: Bool): + def _on_goal_reached(self, msg: Bool) -> None: """Handle goal reached status messages.""" self.goal_reach = msg.data - def _set_autonomy_mode(self): + def _set_autonomy_mode(self) -> None: """ Set autonomy mode by publishing Joy message. """ @@ -95,7 +81,7 @@ def _set_autonomy_mode(self): if self.joy: self.joy.publish(joy_msg) - logger.info(f"Setting autonomy mode via Joy message") + logger.info("Setting autonomy mode via Joy message") @rpc def go_to(self, pose: PoseStamped, timeout: float = 60.0) -> bool: diff --git a/dimos/robot/unitree_webrtc/test_unitree_go2_integration.py b/dimos/robot/unitree_webrtc/test_unitree_go2_integration.py index 20871be4ce..7acdfc1980 100644 --- a/dimos/robot/unitree_webrtc/test_unitree_go2_integration.py +++ b/dimos/robot/unitree_webrtc/test_unitree_go2_integration.py @@ -18,17 +18,17 @@ from dimos import core from dimos.core import Module, Out, rpc -from dimos.msgs.geometry_msgs import PoseStamped, Twist, Vector3, Quaternion -from dimos.msgs.sensor_msgs import Image +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Twist, Vector3 from dimos.msgs.nav_msgs import OccupancyGrid -from dimos.protocol import pubsub +from dimos.msgs.sensor_msgs import Image +from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer from dimos.navigation.global_planner import AstarPlanner from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner -from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator -from dimos.robot.unitree_webrtc.unitree_go2 import ConnectionModule +from dimos.protocol import pubsub from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.map import Map +from dimos.robot.unitree_webrtc.unitree_go2 import ConnectionModule from dimos.utils.logging_config import setup_logger logger = setup_logger("test_unitree_go2_integration") @@ -41,12 +41,12 @@ class MovementControlModule(Module): movecmd: Out[Twist] = None - def __init__(self): + def __init__(self) -> None: super().__init__() self.commands_sent = [] @rpc - def send_move_command(self, x: float, y: float, yaw: float): + def send_move_command(self, x: float, y: float, yaw: float) -> None: """Send a movement command.""" cmd = Twist(linear=Vector3(x, y, 0.0), angular=Vector3(0.0, 0.0, yaw)) self.movecmd.publish(cmd) @@ -62,7 +62,7 @@ def get_command_count(self) -> int: @pytest.mark.module class TestUnitreeGo2CoreModules: @pytest.mark.asyncio - async def test_unitree_go2_navigation_stack(self): + async def test_unitree_go2_navigation_stack(self) -> None: """Test UnitreeGo2 core navigation modules without perception/visualization.""" # Start Dask @@ -94,9 +94,10 @@ async def test_unitree_go2_navigation_stack(self): navigator = dimos.deploy(BehaviorTreeNavigator, local_planner=local_planner) # Set up transports first - from dimos.msgs.nav_msgs import Path from dimos_lcm.std_msgs import Bool + from dimos.msgs.nav_msgs import Path + navigator.goal.transport = core.LCMTransport("/navigation_goal", PoseStamped) navigator.goal_request.transport = core.LCMTransport("/goal_request", PoseStamped) navigator.goal_reached.transport = core.LCMTransport("/goal_reached", Bool) diff --git a/dimos/robot/unitree_webrtc/testing/helpers.py b/dimos/robot/unitree_webrtc/testing/helpers.py index 8d01cb76cc..5159deab4c 100644 --- a/dimos/robot/unitree_webrtc/testing/helpers.py +++ b/dimos/robot/unitree_webrtc/testing/helpers.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable, Iterable import time +from typing import Any, Protocol + import open3d as o3d -from typing import Callable, Union, Any, Protocol, Iterable from reactivex.observable import Observable color1 = [1, 0.706, 0] @@ -28,7 +30,7 @@ # # (in case there is some preparation within the fuction and this time needs to be subtracted # from the benchmark target) -def benchmark(calls: int, targetf: Callable[[], Union[int, None]]) -> float: +def benchmark(calls: int, targetf: Callable[[], int | None]) -> float: start = time.time() timemod = 0 for _ in range(calls): @@ -89,8 +91,8 @@ def show3d_stream( Subsequent geometries update the visualizer. If no new geometry, just poll events. geometry_observable: Observable of objects with .o3d_geometry or Open3D geometry """ - import threading import queue + import threading import time from typing import Any diff --git a/dimos/robot/unitree_webrtc/testing/mock.py b/dimos/robot/unitree_webrtc/testing/mock.py index f929d33c5c..20eb357cc0 100644 --- a/dimos/robot/unitree_webrtc/testing/mock.py +++ b/dimos/robot/unitree_webrtc/testing/mock.py @@ -12,35 +12,36 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Iterator +import glob import os import pickle -import glob -from typing import Union, Iterator, cast, overload -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage, RawLidarMsg +from typing import cast, overload -from reactivex import operators as ops -from reactivex import interval, from_iterable +from reactivex import from_iterable, interval, operators as ops from reactivex.observable import Observable +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage, RawLidarMsg + class Mock: - def __init__(self, root="office", autocast: bool = True): + def __init__(self, root: str = "office", autocast: bool = True) -> None: current_dir = os.path.dirname(os.path.abspath(__file__)) self.root = os.path.join(current_dir, f"mockdata/{root}") self.autocast = autocast self.cnt = 0 @overload - def load(self, name: Union[int, str], /) -> LidarMessage: ... + def load(self, name: int | str, /) -> LidarMessage: ... @overload - def load(self, *names: Union[int, str]) -> list[LidarMessage]: ... + def load(self, *names: int | str) -> list[LidarMessage]: ... - def load(self, *names: Union[int, str]) -> Union[LidarMessage, list[LidarMessage]]: + def load(self, *names: int | str) -> LidarMessage | list[LidarMessage]: if len(names) == 1: return self.load_one(names[0]) return list(map(lambda name: self.load_one(name), names)) - def load_one(self, name: Union[int, str]) -> LidarMessage: + def load_one(self, name: int | str) -> LidarMessage: if isinstance(name, int): file_name = f"/lidar_data_{name:03d}.pickle" else: @@ -48,7 +49,7 @@ def load_one(self, name: Union[int, str]) -> LidarMessage: full_path = self.root + file_name with open(full_path, "rb") as f: - return LidarMessage.from_msg(cast(RawLidarMsg, pickle.load(f))) + return LidarMessage.from_msg(cast("RawLidarMsg", pickle.load(f))) def iterate(self) -> Iterator[LidarMessage]: pattern = os.path.join(self.root, "lidar_data_*.pickle") @@ -58,7 +59,7 @@ def iterate(self) -> Iterator[LidarMessage]: filename = os.path.splitext(basename)[0] yield self.load_one(filename) - def stream(self, rate_hz=10.0): + def stream(self, rate_hz: float = 10.0): sleep_time = 1.0 / rate_hz return from_iterable(self.iterate()).pipe( diff --git a/dimos/robot/unitree_webrtc/testing/multimock.py b/dimos/robot/unitree_webrtc/testing/multimock.py index cfc2688129..eab10e14bb 100644 --- a/dimos/robot/unitree_webrtc/testing/multimock.py +++ b/dimos/robot/unitree_webrtc/testing/multimock.py @@ -33,13 +33,19 @@ import os import pickle import time -from typing import Any, Generic, Iterator, List, Tuple, TypeVar, Union, Optional -from reactivex.scheduler import ThreadPoolScheduler +from typing import TYPE_CHECKING, Any, Generic, TypeVar from reactivex import from_iterable, interval, operators as ops -from reactivex.observable import Observable -from dimos.utils.threadpool import get_scheduler + from dimos.robot.unitree_webrtc.type.timeseries import TEvent, Timeseries +from dimos.utils.threadpool import get_scheduler + +if TYPE_CHECKING: + import builtins + from collections.abc import Iterator + + from reactivex.observable import Observable + from reactivex.scheduler import ThreadPoolScheduler T = TypeVar("T") @@ -80,11 +86,11 @@ def save_one(self, frame: Any) -> int: return self.cnt - def load(self, *names: Union[int, str]) -> List[Tuple[float, T]]: + def load(self, *names: int | str) -> builtins.list[tuple[float, T]]: """Load multiple items by name or index.""" return list(map(self.load_one, names)) - def load_one(self, name: Union[int, str]) -> TEvent[T]: + def load_one(self, name: int | str) -> TEvent[T]: """Load a single item by name or index.""" if isinstance(name, int): file_name = f"/{self.file_prefix}_{name:03d}.pickle" @@ -106,7 +112,7 @@ def iterate(self) -> Iterator[TEvent[T]]: timestamp, data = pickle.load(f) yield TEvent(timestamp, data) - def list(self) -> List[TEvent[T]]: + def list(self) -> builtins.list[TEvent[T]]: return list(self.iterate()) def interval_stream(self, rate_hz: float = 10.0) -> Observable[T]: @@ -120,7 +126,7 @@ def interval_stream(self, rate_hz: float = 10.0) -> Observable[T]: def stream( self, replay_speed: float = 1.0, - scheduler: Optional[ThreadPoolScheduler] = None, + scheduler: ThreadPoolScheduler | None = None, ) -> Observable[T]: def _generator(): prev_ts: float | None = None diff --git a/dimos/robot/unitree_webrtc/testing/test_actors.py b/dimos/robot/unitree_webrtc/testing/test_actors.py index 1b42412249..4612f45a79 100644 --- a/dimos/robot/unitree_webrtc/testing/test_actors.py +++ b/dimos/robot/unitree_webrtc/testing/test_actors.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio +from collections.abc import Callable import time -from typing import Callable import pytest @@ -36,12 +36,12 @@ def client(): class Consumer: testf: Callable[[int], int] - def __init__(self, counter=None): + def __init__(self, counter=None) -> None: self.testf = counter print("consumer init with", counter) async def waitcall(self, n: int): - async def task(): + async def task() -> None: await asyncio.sleep(n) print("sleep finished, calling") @@ -60,7 +60,7 @@ def addten(self, x: int): @pytest.mark.tool -def test_wait(client): +def test_wait(client) -> None: counter = client.submit(Counter, actor=True).result() async def addten(n): @@ -74,7 +74,7 @@ async def addten(n): @pytest.mark.tool -def test_basic(dimos): +def test_basic(dimos) -> None: counter = dimos.deploy(Counter) consumer = dimos.deploy( Consumer, @@ -93,7 +93,7 @@ def test_basic(dimos): @pytest.mark.tool -def test_mapper_start(dimos): +def test_mapper_start(dimos) -> None: mapper = dimos.deploy(Mapper) mapper.lidar.transport = core.LCMTransport("/lidar", LidarMessage) print("start res", mapper.start().result()) @@ -106,6 +106,6 @@ def test_mapper_start(dimos): @pytest.mark.tool -def test_counter(dimos): +def test_counter(dimos) -> None: counter = dimos.deploy(Counter) assert counter.addten(10) == 20 diff --git a/dimos/robot/unitree_webrtc/testing/test_mock.py b/dimos/robot/unitree_webrtc/testing/test_mock.py index 4852392943..73eeef05ba 100644 --- a/dimos/robot/unitree_webrtc/testing/test_mock.py +++ b/dimos/robot/unitree_webrtc/testing/test_mock.py @@ -14,13 +14,15 @@ # limitations under the License. import time + import pytest -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage + from dimos.robot.unitree_webrtc.testing.mock import Mock +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage @pytest.mark.needsdata -def test_mock_load_cast(): +def test_mock_load_cast() -> None: mock = Mock("test") # Load a frame with type casting @@ -39,7 +41,7 @@ def test_mock_load_cast(): @pytest.mark.needsdata -def test_mock_iterate(): +def test_mock_iterate() -> None: """Test the iterate method of the Mock class.""" mock = Mock("office") @@ -52,7 +54,7 @@ def test_mock_iterate(): @pytest.mark.needsdata -def test_mock_stream(): +def test_mock_stream() -> None: frames = [] sub1 = Mock("office").stream(rate_hz=30.0).subscribe(on_next=frames.append) time.sleep(0.1) diff --git a/dimos/robot/unitree_webrtc/testing/test_tooling.py b/dimos/robot/unitree_webrtc/testing/test_tooling.py index b68bed2f86..38a3dba593 100644 --- a/dimos/robot/unitree_webrtc/testing/test_tooling.py +++ b/dimos/robot/unitree_webrtc/testing/test_tooling.py @@ -16,8 +16,8 @@ import sys import time -import pytest from dotenv import load_dotenv +import pytest from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.odometry import Odometry @@ -26,7 +26,7 @@ @pytest.mark.tool -def test_record_all(): +def test_record_all() -> None: from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 load_dotenv() @@ -58,7 +58,7 @@ def test_record_all(): @pytest.mark.tool -def test_replay_all(): +def test_replay_all() -> None: lidar_store = TimedSensorReplay("unitree/lidar", autocast=LidarMessage.from_msg) odom_store = TimedSensorReplay("unitree/odom", autocast=Odometry.from_msg) video_store = TimedSensorReplay("unitree/video") diff --git a/dimos/robot/unitree_webrtc/type/lidar.py b/dimos/robot/unitree_webrtc/type/lidar.py index e21c7ddd00..a6595790ad 100644 --- a/dimos/robot/unitree_webrtc/type/lidar.py +++ b/dimos/robot/unitree_webrtc/type/lidar.py @@ -13,8 +13,7 @@ # limitations under the License. import time -from copy import copy -from typing import List, Optional, TypedDict +from typing import TypedDict import numpy as np import open3d as o3d @@ -32,11 +31,11 @@ class RawLidarData(TypedDict): """Data portion of the LIDAR message""" frame_id: str - origin: List[float] + origin: list[float] resolution: float src_size: int stamp: float - width: List[int] + width: list[int] data: RawLidarPoints @@ -51,10 +50,10 @@ class RawLidarMsg(TypedDict): class LidarMessage(PointCloud2): resolution: float # we lose resolution when encoding PointCloud2 origin: Vector3 - raw_msg: Optional[RawLidarMsg] + raw_msg: RawLidarMsg | None # _costmap: Optional[Costmap] = None # TODO: Fix after costmap migration - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__( pointcloud=kwargs.get("pointcloud"), ts=kwargs.get("ts"), @@ -87,7 +86,7 @@ def from_msg(cls: type["LidarMessage"], raw_message: RawLidarMsg, **kwargs) -> " } return cls(**cls_data) - def __repr__(self): + def __repr__(self) -> str: return f"LidarMessage(ts={to_human_readable(self.ts)}, origin={self.origin}, resolution={self.resolution}, {self.pointcloud})" def __iadd__(self, other: "LidarMessage") -> "LidarMessage": diff --git a/dimos/robot/unitree_webrtc/type/lowstate.py b/dimos/robot/unitree_webrtc/type/lowstate.py index 9c4d8edee5..c50504135c 100644 --- a/dimos/robot/unitree_webrtc/type/lowstate.py +++ b/dimos/robot/unitree_webrtc/type/lowstate.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TypedDict, List, Literal +from typing import Literal, TypedDict raw_odom_msg_sample = { "type": "msg", @@ -61,11 +61,11 @@ class MotorState(TypedDict): q: float temperature: int lost: int - reserve: List[int] + reserve: list[int] class ImuState(TypedDict): - rpy: List[float] + rpy: list[float] class BmsState(TypedDict): @@ -74,15 +74,15 @@ class BmsState(TypedDict): soc: int current: int cycle: int - bq_ntc: List[int] - mcu_ntc: List[int] + bq_ntc: list[int] + mcu_ntc: list[int] class LowStateData(TypedDict): imu_state: ImuState - motor_state: List[MotorState] + motor_state: list[MotorState] bms_state: BmsState - foot_force: List[int] + foot_force: list[int] temperature_ntc1: int power_v: float diff --git a/dimos/robot/unitree_webrtc/type/map.py b/dimos/robot/unitree_webrtc/type/map.py index 61eaa83d0f..452bcaf17c 100644 --- a/dimos/robot/unitree_webrtc/type/map.py +++ b/dimos/robot/unitree_webrtc/type/map.py @@ -13,7 +13,6 @@ # limitations under the License. import time -from typing import Optional import numpy as np import open3d as o3d @@ -40,12 +39,12 @@ def __init__( self, voxel_size: float = 0.05, cost_resolution: float = 0.05, - global_publish_interval: Optional[float] = None, + global_publish_interval: float | None = None, min_height: float = 0.15, max_height: float = 0.6, global_config: GlobalConfig | None = None, **kwargs, - ): + ) -> None: self.voxel_size = voxel_size self.cost_resolution = cost_resolution self.global_publish_interval = global_publish_interval @@ -59,13 +58,13 @@ def __init__( super().__init__(**kwargs) @rpc - def start(self): + def start(self) -> None: super().start() unsub = self.lidar.subscribe(self.add_frame) self._disposables.add(Disposable(unsub)) - def publish(_): + def publish(_) -> None: self.global_map.publish(self.to_lidar_message()) # temporary, not sure if it belogs in mapper diff --git a/dimos/robot/unitree_webrtc/type/odometry.py b/dimos/robot/unitree_webrtc/type/odometry.py index c307929a00..52a8544fbc 100644 --- a/dimos/robot/unitree_webrtc/type/odometry.py +++ b/dimos/robot/unitree_webrtc/type/odometry.py @@ -14,13 +14,10 @@ import time from typing import Literal, TypedDict -from scipy.spatial.transform import Rotation as R - from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Vector3 from dimos.robot.unitree_webrtc.type.timeseries import ( Timestamped, ) -from dimos.types.timestamped import to_human_readable, to_timestamp raw_odometry_msg_sample = { "type": "msg", diff --git a/dimos/robot/unitree_webrtc/type/test_lidar.py b/dimos/robot/unitree_webrtc/type/test_lidar.py index 75ceec88f8..93435e8e4b 100644 --- a/dimos/robot/unitree_webrtc/type/test_lidar.py +++ b/dimos/robot/unitree_webrtc/type/test_lidar.py @@ -14,17 +14,12 @@ # limitations under the License. import itertools -import time -import pytest - -from dimos.msgs.sensor_msgs import PointCloud2 -from dimos.protocol.pubsub.lcmpubsub import LCM, Topic from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.utils.testing import SensorReplay -def test_init(): +def test_init() -> None: lidar = SensorReplay("office_lidar") for raw_frame in itertools.islice(lidar.iterate(), 5): diff --git a/dimos/robot/unitree_webrtc/type/test_map.py b/dimos/robot/unitree_webrtc/type/test_map.py index ef2418c7f4..12ee8f832d 100644 --- a/dimos/robot/unitree_webrtc/type/test_map.py +++ b/dimos/robot/unitree_webrtc/type/test_map.py @@ -22,7 +22,7 @@ @pytest.mark.vis -def test_costmap_vis(): +def test_costmap_vis() -> None: map = Map() map.start() mock = Mock("office") @@ -39,7 +39,7 @@ def test_costmap_vis(): @pytest.mark.vis -def test_reconstruction_with_realtime_vis(): +def test_reconstruction_with_realtime_vis() -> None: map = Map() map.start() mock = Mock("office") @@ -52,7 +52,7 @@ def test_reconstruction_with_realtime_vis(): @pytest.mark.vis -def test_splice_vis(): +def test_splice_vis() -> None: mock = Mock("test") target = mock.load("a") insert = mock.load("b") @@ -60,7 +60,7 @@ def test_splice_vis(): @pytest.mark.vis -def test_robot_vis(): +def test_robot_vis() -> None: map = Map() map.start() mock = Mock("office") @@ -72,13 +72,13 @@ def test_robot_vis(): show3d(map.pointcloud, title="global dynamic map test").run() -def test_robot_mapping(): +def test_robot_mapping() -> None: lidar_replay = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) map = Map(voxel_size=0.5) # Mock the output streams to avoid publishing errors class MockStream: - def publish(self, msg): + def publish(self, msg) -> None: pass # Do nothing map.local_costmap = MockStream() diff --git a/dimos/robot/unitree_webrtc/type/test_odometry.py b/dimos/robot/unitree_webrtc/type/test_odometry.py index 0bd76f1900..b1a251b254 100644 --- a/dimos/robot/unitree_webrtc/type/test_odometry.py +++ b/dimos/robot/unitree_webrtc/type/test_odometry.py @@ -14,14 +14,13 @@ from __future__ import annotations +from operator import add, sub import os import threading -from operator import add, sub -from typing import Optional +from dotenv import load_dotenv import pytest import reactivex.operators as ops -from dotenv import load_dotenv from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.utils.testing import SensorReplay, SensorStorage @@ -57,7 +56,7 @@ def test_last_yaw_value() -> None: def test_total_rotation_travel_iterate() -> None: total_rad = 0.0 - prev_yaw: Optional[float] = None + prev_yaw: float | None = None for odom in SensorReplay(name="raw_odometry_rotate_walk", autocast=Odometry.from_msg).iterate(): yaw = odom.orientation.radians.z diff --git a/dimos/robot/unitree_webrtc/type/test_timeseries.py b/dimos/robot/unitree_webrtc/type/test_timeseries.py index fe96d75eaf..b7c955933d 100644 --- a/dimos/robot/unitree_webrtc/type/test_timeseries.py +++ b/dimos/robot/unitree_webrtc/type/test_timeseries.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from datetime import timedelta, datetime -from dimos.robot.unitree_webrtc.type.timeseries import TEvent, TList +from datetime import datetime, timedelta +from dimos.robot.unitree_webrtc.type.timeseries import TEvent, TList fixed_date = datetime(2025, 5, 13, 15, 2, 5).astimezone() start_event = TEvent(fixed_date, 1) @@ -23,22 +23,22 @@ sample_list = TList([start_event, TEvent(fixed_date + timedelta(seconds=2), 5), end_event]) -def test_repr(): +def test_repr() -> None: assert ( str(sample_list) == "Timeseries(date=2025-05-13, start=15:02:05, end=15:02:15, duration=0:00:10, events=3, freq=0.30Hz)" ) -def test_equals(): +def test_equals() -> None: assert start_event == TEvent(start_event.ts, 1) assert start_event != TEvent(start_event.ts, 2) assert start_event != TEvent(start_event.ts + timedelta(seconds=1), 1) -def test_range(): +def test_range() -> None: assert sample_list.time_range() == (start_event.ts, end_event.ts) -def test_duration(): +def test_duration() -> None: assert sample_list.duration() == timedelta(seconds=10) diff --git a/dimos/robot/unitree_webrtc/type/timeseries.py b/dimos/robot/unitree_webrtc/type/timeseries.py index 48dfddcac5..baf683c019 100644 --- a/dimos/robot/unitree_webrtc/type/timeseries.py +++ b/dimos/robot/unitree_webrtc/type/timeseries.py @@ -16,7 +16,10 @@ from abc import ABC, abstractmethod from datetime import datetime, timedelta, timezone -from typing import Generic, Iterable, Tuple, TypedDict, TypeVar, Union +from typing import TYPE_CHECKING, Generic, TypedDict, TypeVar, Union + +if TYPE_CHECKING: + from collections.abc import Iterable PAYLOAD = TypeVar("PAYLOAD") @@ -29,7 +32,7 @@ class RosStamp(TypedDict): EpochLike = Union[int, float, datetime, RosStamp] -def from_ros_stamp(stamp: dict[str, int], tz: timezone = None) -> datetime: +def from_ros_stamp(stamp: dict[str, int], tz: timezone | None = None) -> datetime: """Convert ROS-style timestamp {'sec': int, 'nanosec': int} to datetime.""" return datetime.fromtimestamp(stamp["sec"] + stamp["nanosec"] / 1e9, tz=tz) @@ -39,7 +42,7 @@ def to_human_readable(ts: EpochLike) -> str: return dt.strftime("%Y-%m-%d %H:%M:%S") -def to_datetime(ts: EpochLike, tz: timezone = None) -> datetime: +def to_datetime(ts: EpochLike, tz: timezone | None = None) -> datetime: if isinstance(ts, datetime): # if ts.tzinfo is None: # ts = ts.astimezone(tz) @@ -56,14 +59,14 @@ class Timestamped(ABC): ts: datetime - def __init__(self, ts: EpochLike): + def __init__(self, ts: EpochLike) -> None: self.ts = to_datetime(ts) class TEvent(Timestamped, Generic[PAYLOAD]): """Concrete class for an event with a timestamp and data.""" - def __init__(self, timestamp: EpochLike, data: PAYLOAD): + def __init__(self, timestamp: EpochLike, data: PAYLOAD) -> None: super().__init__(timestamp) self.data = data @@ -100,7 +103,7 @@ def frequency(self) -> float: """Calculate the frequency of events in Hz.""" return len(list(self)) / (self.duration().total_seconds() or 1) - def time_range(self) -> Tuple[datetime, datetime]: + def time_range(self) -> tuple[datetime, datetime]: """Return (earliest_ts, latest_ts). Empty input ⇒ ValueError.""" return self.start_time, self.end_time diff --git a/dimos/robot/unitree_webrtc/type/vector.py b/dimos/robot/unitree_webrtc/type/vector.py index 22b00a753d..be00e3403c 100644 --- a/dimos/robot/unitree_webrtc/type/vector.py +++ b/dimos/robot/unitree_webrtc/type/vector.py @@ -12,17 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import numpy as np +import builtins +from collections.abc import Iterable from typing import ( - Tuple, - List, - TypeVar, - Protocol, - runtime_checkable, Any, - Iterable, + Protocol, + TypeVar, Union, + runtime_checkable, ) + +import numpy as np from numpy.typing import NDArray T = TypeVar("T", bound="Vector") @@ -53,7 +53,7 @@ def yaw(self) -> float: return self.x @property - def tuple(self) -> Tuple[float, ...]: + def tuple(self) -> tuple[float, ...]: """Tuple representation of the vector.""" return tuple(self._data) @@ -269,11 +269,11 @@ def unit_z(cls: type[T], dim: int = 3) -> T: v[2] = 1.0 return cls(v) - def to_list(self) -> List[float]: + def to_list(self) -> list[float]: """Convert the vector to a list.""" return [float(x) for x in self._data] - def to_tuple(self) -> Tuple[float, ...]: + def to_tuple(self) -> builtins.tuple[float, ...]: """Convert the vector to a tuple.""" return tuple(self._data) @@ -324,7 +324,7 @@ def to_vector(value: VectorLike) -> Vector: return Vector(value) -def to_tuple(value: VectorLike) -> Tuple[float, ...]: +def to_tuple(value: VectorLike) -> tuple[float, ...]: """Convert a vector-compatible value to a tuple. Args: @@ -345,7 +345,7 @@ def to_tuple(value: VectorLike) -> Tuple[float, ...]: return tuple(float(x) for x in data) -def to_list(value: VectorLike) -> List[float]: +def to_list(value: VectorLike) -> list[float]: """Convert a vector-compatible value to a list. Args: diff --git a/dimos/robot/unitree_webrtc/unitree_b1/b1_command.py b/dimos/robot/unitree_webrtc/unitree_b1/b1_command.py index ab547dade2..82545fa2c6 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/b1_command.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/b1_command.py @@ -17,10 +17,10 @@ """Internal B1 command structure for UDP communication.""" -from pydantic import BaseModel, Field -from typing import Optional import struct +from pydantic import BaseModel, Field + class B1Command(BaseModel): """Internal B1 robot command matching UDP packet structure. diff --git a/dimos/robot/unitree_webrtc/unitree_b1/connection.py b/dimos/robot/unitree_webrtc/unitree_b1/connection.py index a458858040..73285b4d76 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/connection.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/connection.py @@ -22,6 +22,8 @@ import threading import time +from reactivex.disposable import Disposable + from dimos.core import In, Module, Out, rpc from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped from dimos.msgs.nav_msgs.Odometry import Odometry @@ -29,7 +31,6 @@ from dimos.utils.logging_config import setup_logger from .b1_command import B1Command -from reactivex.disposable import Disposable # Setup logger with DEBUG level for troubleshooting logger = setup_logger("dimos.robot.unitree_webrtc.unitree_b1.connection", level=logging.DEBUG) @@ -59,7 +60,7 @@ class B1ConnectionModule(Module): def __init__( self, ip: str = "192.168.12.1", port: int = 9090, test_mode: bool = False, *args, **kwargs - ): + ) -> None: """Initialize B1 connection module. Args: @@ -87,7 +88,7 @@ def __init__( self.timeout_active = False @rpc - def start(self): + def start(self) -> None: """Start the connection and subscribe to command streams.""" super().start() @@ -123,7 +124,7 @@ def start(self): self.watchdog_thread.start() @rpc - def stop(self): + def stop(self) -> None: """Stop the connection and send stop commands.""" self.set_mode(RobotMode.IDLE) # IDLE @@ -152,7 +153,7 @@ def stop(self): super().stop() - def handle_twist_stamped(self, twist_stamped: TwistStamped): + def handle_twist_stamped(self, twist_stamped: TwistStamped) -> None: """Handle timestamped Twist message and convert to B1Command. This is called automatically when messages arrive on cmd_vel input. @@ -197,7 +198,7 @@ def handle_twist_stamped(self, twist_stamped: TwistStamped): self.last_command_time = time.time() self.timeout_active = False # Reset timeout state since we got a new command - def handle_mode(self, mode_msg: Int32): + def handle_mode(self, mode_msg: Int32) -> None: """Handle mode change message. This is called automatically when messages arrive on mode_cmd input. @@ -208,7 +209,7 @@ def handle_mode(self, mode_msg: Int32): self.set_mode(mode_msg.data) @rpc - def set_mode(self, mode: int): + def set_mode(self, mode: int) -> bool: """Set robot mode (0=idle, 1=stand, 2=walk, 6=recovery).""" self.current_mode = mode with self.cmd_lock: @@ -233,7 +234,7 @@ def set_mode(self, mode: int): return True - def _send_loop(self): + def _send_loop(self) -> None: """Continuously send current command at 50Hz. The watchdog thread handles timeout and zeroing commands, so this loop @@ -269,7 +270,7 @@ def _send_loop(self): if self.running: logger.error(f"Send error: {e}") - def _publish_odom_pose(self, msg: Odometry): + def _publish_odom_pose(self, msg: Odometry) -> None: """Convert and publish odometry as PoseStamped. This matches G1's approach of receiving external odometry. @@ -283,7 +284,7 @@ def _publish_odom_pose(self, msg: Odometry): ) self.odom_pose.publish(pose_stamped) - def _watchdog_loop(self): + def _watchdog_loop(self) -> None: """Single watchdog thread that monitors command freshness.""" while self.watchdog_running: try: @@ -320,31 +321,31 @@ def _watchdog_loop(self): logger.error(f"Watchdog error: {e}") @rpc - def idle(self): + def idle(self) -> bool: """Set robot to idle mode.""" self.set_mode(RobotMode.IDLE) return True @rpc - def pose(self): + def pose(self) -> bool: """Set robot to stand/pose mode for reaching ground objects with manipulator.""" self.set_mode(RobotMode.STAND) return True @rpc - def walk(self): + def walk(self) -> bool: """Set robot to walk mode.""" self.set_mode(RobotMode.WALK) return True @rpc - def recovery(self): + def recovery(self) -> bool: """Set robot to recovery mode.""" self.set_mode(RobotMode.RECOVERY) return True @rpc - def move(self, twist_stamped: TwistStamped, duration: float = 0.0): + def move(self, twist_stamped: TwistStamped, duration: float = 0.0) -> bool: """Direct RPC method for sending TwistStamped commands. Args: @@ -358,11 +359,11 @@ def move(self, twist_stamped: TwistStamped, duration: float = 0.0): class MockB1ConnectionModule(B1ConnectionModule): """Test connection module that prints commands instead of sending UDP.""" - def __init__(self, ip: str = "127.0.0.1", port: int = 9090, *args, **kwargs): + def __init__(self, ip: str = "127.0.0.1", port: int = 9090, *args, **kwargs) -> None: """Initialize test connection without creating socket.""" super().__init__(ip, port, test_mode=True, *args, **kwargs) - def _send_loop(self): + def _send_loop(self) -> None: """Override to provide better test output with timeout detection.""" timeout_warned = False diff --git a/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py b/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py index 9edc27f3c3..9c3c09861c 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py @@ -24,6 +24,7 @@ os.environ["SDL_VIDEODRIVER"] = "x11" import time + from dimos.core import Module, Out, rpc from dimos.msgs.geometry_msgs import Twist, TwistStamped, Vector3 from dimos.msgs.std_msgs import Int32 @@ -39,14 +40,14 @@ class JoystickModule(Module): twist_out: Out[TwistStamped] = None # Timestamped velocity commands mode_out: Out[Int32] = None # Mode changes - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: Module.__init__(self, *args, **kwargs) self.pygame_ready = False self.running = False self.current_mode = 0 # Start in IDLE mode for safety @rpc - def start(self): + def start(self) -> bool: """Initialize pygame and start control loop.""" super().start() @@ -88,7 +89,7 @@ def stop(self) -> None: super().stop() - def _pygame_loop(self): + def _pygame_loop(self) -> None: """Main pygame event loop - ALL pygame operations happen here.""" import pygame @@ -223,7 +224,7 @@ def _pygame_loop(self): pygame.quit() print("JoystickModule stopped") - def _update_display(self, twist): + def _update_display(self, twist) -> None: """Update pygame window with current status.""" import pygame diff --git a/dimos/robot/unitree_webrtc/unitree_b1/test_connection.py b/dimos/robot/unitree_webrtc/unitree_b1/test_connection.py index 57227e6e23..49421c85e0 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/test_connection.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/test_connection.py @@ -34,7 +34,7 @@ class TestB1Connection: """Test suite for B1 connection module with Timer implementation.""" - def test_watchdog_actually_zeros_commands(self): + def test_watchdog_actually_zeros_commands(self) -> None: """Test that watchdog thread zeros commands after timeout.""" conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) conn.running = True @@ -75,7 +75,7 @@ def test_watchdog_actually_zeros_commands(self): conn.watchdog_thread.join(timeout=0.5) conn._close_module() - def test_watchdog_resets_on_new_command(self): + def test_watchdog_resets_on_new_command(self) -> None: """Test that watchdog timeout resets when new command arrives.""" conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) conn.running = True @@ -123,7 +123,7 @@ def test_watchdog_resets_on_new_command(self): conn.watchdog_thread.join(timeout=0.5) conn._close_module() - def test_watchdog_thread_efficiency(self): + def test_watchdog_thread_efficiency(self) -> None: """Test that watchdog uses only one thread regardless of command rate.""" conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) conn.running = True @@ -157,7 +157,7 @@ def test_watchdog_thread_efficiency(self): conn.watchdog_thread.join(timeout=0.5) conn._close_module() - def test_watchdog_with_send_loop_blocking(self): + def test_watchdog_with_send_loop_blocking(self) -> None: """Test that watchdog still works if send loop blocks.""" conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) @@ -165,7 +165,7 @@ def test_watchdog_with_send_loop_blocking(self): original_send_loop = conn._send_loop block_event = threading.Event() - def blocking_send_loop(): + def blocking_send_loop() -> None: # Block immediately block_event.wait() # Then run normally @@ -204,7 +204,7 @@ def blocking_send_loop(): conn.watchdog_thread.join(timeout=0.5) conn._close_module() - def test_continuous_commands_prevent_timeout(self): + def test_continuous_commands_prevent_timeout(self) -> None: """Test that continuous commands prevent watchdog timeout.""" conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) conn.running = True @@ -239,7 +239,7 @@ def test_continuous_commands_prevent_timeout(self): conn.watchdog_thread.join(timeout=0.5) conn._close_module() - def test_watchdog_timing_accuracy(self): + def test_watchdog_timing_accuracy(self) -> None: """Test that watchdog zeros commands at approximately 200ms.""" conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) conn.running = True @@ -280,7 +280,7 @@ def test_watchdog_timing_accuracy(self): conn.watchdog_thread.join(timeout=0.5) conn._close_module() - def test_mode_changes_with_watchdog(self): + def test_mode_changes_with_watchdog(self) -> None: """Test that mode changes work correctly with watchdog.""" conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) conn.running = True @@ -323,7 +323,7 @@ def test_mode_changes_with_watchdog(self): conn.watchdog_thread.join(timeout=0.5) conn._close_module() - def test_watchdog_stops_movement_when_commands_stop(self): + def test_watchdog_stops_movement_when_commands_stop(self) -> None: """Verify watchdog zeros commands when packets stop being sent.""" conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) conn.running = True @@ -334,7 +334,7 @@ def test_watchdog_stops_movement_when_commands_stop(self): conn.watchdog_thread.start() # Simulate sending movement commands for a while - for i in range(5): + for _i in range(5): twist = TwistStamped( ts=time.time(), frame_id="base_link", @@ -381,7 +381,7 @@ def test_watchdog_stops_movement_when_commands_stop(self): conn.watchdog_thread.join(timeout=0.5) conn._close_module() - def test_rapid_command_thread_safety(self): + def test_rapid_command_thread_safety(self) -> None: """Test thread safety with rapid commands from multiple threads.""" conn = MockB1ConnectionModule(ip="127.0.0.1", port=9090) conn.running = True @@ -395,8 +395,8 @@ def test_rapid_command_thread_safety(self): initial_threads = threading.active_count() # Send commands from multiple threads rapidly - def send_commands(thread_id): - for i in range(10): + def send_commands(thread_id) -> None: + for _i in range(10): twist = TwistStamped( ts=time.time(), frame_id="base_link", diff --git a/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py b/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py index 5501557820..04390c2e9e 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py @@ -22,16 +22,14 @@ import logging import os -from typing import Optional from dimos import core from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.resource import Resource -from dimos.msgs.geometry_msgs import TwistStamped, PoseStamped +from dimos.msgs.geometry_msgs import PoseStamped, TwistStamped from dimos.msgs.nav_msgs.Odometry import Odometry from dimos.msgs.std_msgs import Int32 from dimos.msgs.tf2_msgs.TFMessage import TFMessage -from dimos.protocol.pubsub.lcmpubsub import LCM from dimos.robot.robot import Robot from dimos.robot.ros_bridge import BridgeDirection, ROSBridge from dimos.robot.unitree_webrtc.unitree_b1.connection import ( @@ -71,12 +69,12 @@ def __init__( self, ip: str = "192.168.123.14", port: int = 9090, - output_dir: str = None, - skill_library: Optional[SkillLibrary] = None, + output_dir: str | None = None, + skill_library: SkillLibrary | None = None, enable_joystick: bool = False, enable_ros_bridge: bool = True, test_mode: bool = False, - ): + ) -> None: """Initialize the B1 robot. Args: @@ -104,7 +102,7 @@ def __init__( os.makedirs(self.output_dir, exist_ok=True) logger.info(f"Robot outputs will be saved to: {self.output_dir}") - def start(self): + def start(self) -> None: """Start the B1 robot - initialize DimOS, deploy modules, and start them.""" logger.info("Initializing DimOS...") @@ -151,7 +149,7 @@ def stop(self) -> None: if self.ros_bridge: self.ros_bridge.stop() - def _deploy_ros_bridge(self): + def _deploy_ros_bridge(self) -> None: """Deploy and configure ROS bridge (matching G1 implementation).""" self.ros_bridge = ROSBridge("b1_ros_bridge") @@ -175,7 +173,7 @@ def _deploy_ros_bridge(self): logger.info("ROS bridge deployed: /cmd_vel, /state_estimation, /tf (ROS → DIMOS)") # Robot control methods (standard interface) - def move(self, twist_stamped: TwistStamped, duration: float = 0.0): + def move(self, twist_stamped: TwistStamped, duration: float = 0.0) -> None: """Send movement command to robot using timestamped Twist. Args: @@ -185,26 +183,26 @@ def move(self, twist_stamped: TwistStamped, duration: float = 0.0): if self.connection: self.connection.move(twist_stamped, duration) - def stand(self): + def stand(self) -> None: """Put robot in stand mode.""" if self.connection: self.connection.stand() logger.info("B1 switched to STAND mode") - def walk(self): + def walk(self) -> None: """Put robot in walk mode.""" if self.connection: self.connection.walk() logger.info("B1 switched to WALK mode") - def idle(self): + def idle(self) -> None: """Put robot in idle mode.""" if self.connection: self.connection.idle() logger.info("B1 switched to IDLE mode") -def main(): +def main() -> None: """Main entry point for testing B1 robot.""" import argparse diff --git a/dimos/robot/unitree_webrtc/unitree_g1.py b/dimos/robot/unitree_webrtc/unitree_g1.py index d8f6975d27..fc148c54c3 100644 --- a/dimos/robot/unitree_webrtc/unitree_g1.py +++ b/dimos/robot/unitree_webrtc/unitree_g1.py @@ -21,15 +21,12 @@ import logging import os import time -from typing import Optional from dimos_lcm.foxglove_msgs import SceneUpdate -from geometry_msgs.msg import PoseStamped as ROSPoseStamped -from geometry_msgs.msg import TwistStamped as ROSTwistStamped +from geometry_msgs.msg import PoseStamped as ROSPoseStamped, TwistStamped as ROSTwistStamped from nav_msgs.msg import Odometry as ROSOdometry from reactivex.disposable import Disposable -from sensor_msgs.msg import Joy as ROSJoy -from sensor_msgs.msg import PointCloud2 as ROSPointCloud2 +from sensor_msgs.msg import Joy as ROSJoy, PointCloud2 as ROSPointCloud2 from tf2_msgs.msg import TFMessage as ROSTFMessage from dimos import core @@ -93,14 +90,16 @@ class G1ConnectionModule(Module): ip: str connection_type: str = "webrtc" - def __init__(self, ip: str = None, connection_type: str = "webrtc", *args, **kwargs): + def __init__( + self, ip: str | None = None, connection_type: str = "webrtc", *args, **kwargs + ) -> None: self.ip = ip self.connection_type = connection_type self.connection = None Module.__init__(self, *args, **kwargs) @rpc - def start(self): + def start(self) -> None: """Start the connection and subscribe to sensor streams.""" super().start() @@ -118,7 +117,7 @@ def stop(self) -> None: self.connection.stop() super().stop() - def _publish_odom_pose(self, msg: Odometry): + def _publish_odom_pose(self, msg: Odometry) -> None: self.odom_pose.publish( PoseStamped( ts=msg.ts, @@ -129,7 +128,7 @@ def _publish_odom_pose(self, msg: Odometry): ) @rpc - def move(self, twist_stamped: TwistStamped, duration: float = 0.0): + def move(self, twist_stamped: TwistStamped, duration: float = 0.0) -> None: """Send movement command to robot.""" twist = Twist(linear=twist_stamped.linear, angular=twist_stamped.angular) self.connection.move(twist, duration) @@ -146,17 +145,17 @@ class UnitreeG1(Robot, Resource): def __init__( self, ip: str, - output_dir: str = None, + output_dir: str | None = None, websocket_port: int = 7779, - skill_library: Optional[SkillLibrary] = None, - recording_path: str = None, - replay_path: str = None, + skill_library: SkillLibrary | None = None, + recording_path: str | None = None, + replay_path: str | None = None, enable_joystick: bool = False, enable_connection: bool = True, enable_ros_bridge: bool = True, enable_perception: bool = False, enable_camera: bool = False, - ): + ) -> None: """Initialize the G1 robot. Args: @@ -206,7 +205,7 @@ def __init__( self._ros_nav = None self._setup_directories() - def _setup_directories(self): + def _setup_directories(self) -> None: """Setup directories for spatial memory storage.""" os.makedirs(self.output_dir, exist_ok=True) logger.info(f"Robot outputs will be saved to: {self.output_dir}") @@ -225,7 +224,7 @@ def _setup_directories(self): os.makedirs(self.spatial_memory_dir, exist_ok=True) os.makedirs(self.db_path, exist_ok=True) - def _deploy_detection(self, goto): + def _deploy_detection(self, goto) -> None: detection = self._dimos.deploy( ObjectDBModule, goto=goto, camera_info=zed.CameraInfo.SingleWebcam ) @@ -254,7 +253,7 @@ def _deploy_detection(self, goto): self.detection = detection - def start(self): + def start(self) -> None: self.lcm.start() self._dimos.start() @@ -327,7 +326,7 @@ def stop(self) -> None: self._ros_nav.stop() self.lcm.stop() - def _deploy_connection(self): + def _deploy_connection(self) -> None: """Deploy and configure the connection module.""" self.connection = self._dimos.deploy(G1ConnectionModule, self.ip) @@ -336,7 +335,7 @@ def _deploy_connection(self): self.connection.odom_in.transport = core.LCMTransport("/state_estimation", Odometry) self.connection.odom_pose.transport = core.LCMTransport("/odom", PoseStamped) - def _deploy_camera(self): + def _deploy_camera(self) -> None: """Deploy and configure a standard webcam module.""" logger.info("Deploying standard webcam module...") @@ -360,7 +359,7 @@ def _deploy_camera(self): self.camera.camera_info.transport = core.LCMTransport("/camera_info", CameraInfo) logger.info("Webcam module configured") - def _deploy_visualization(self): + def _deploy_visualization(self) -> None: """Deploy and configure visualization modules.""" # Deploy WebSocket visualization module self.websocket_vis = self._dimos.deploy(WebsocketVisModule, port=self.websocket_port) @@ -377,7 +376,7 @@ def _deploy_visualization(self): ) self.foxglove_bridge.start() - def _deploy_perception(self): + def _deploy_perception(self) -> None: self.spatial_memory_module = self._dimos.deploy( SpatialMemory, collection_name=self.spatial_memory_collection, @@ -391,7 +390,7 @@ def _deploy_perception(self): logger.info("Spatial memory module deployed and connected") - def _deploy_joystick(self): + def _deploy_joystick(self) -> None: """Deploy joystick control module.""" from dimos.robot.unitree_webrtc.g1_joystick_module import G1JoystickModule @@ -400,7 +399,7 @@ def _deploy_joystick(self): self.joystick.twist_out.transport = core.LCMTransport("/cmd_vel", Twist) logger.info("Joystick module deployed - pygame window will open") - def _deploy_ros_bridge(self): + def _deploy_ros_bridge(self) -> None: """Deploy and configure ROS bridge.""" self.ros_bridge = ROSBridge("g1_ros_bridge") @@ -450,7 +449,7 @@ def _deploy_ros_bridge(self): "ROS bridge deployed: /cmd_vel, /state_estimation, /tf, /registered_scan (ROS → DIMOS)" ) - def _start_modules(self): + def _start_modules(self) -> None: """Start all deployed modules.""" self._dimos.start_all_modules() @@ -464,7 +463,7 @@ def _start_modules(self): self.skill_library.init() self.skill_library.initialize_skills() - def move(self, twist_stamped: TwistStamped, duration: float = 0.0): + def move(self, twist_stamped: TwistStamped, duration: float = 0.0) -> None: """Send movement command to robot.""" self.connection.move(twist_stamped, duration) @@ -474,11 +473,11 @@ def get_odom(self) -> PoseStamped: return None @property - def spatial_memory(self) -> Optional[SpatialMemory]: + def spatial_memory(self) -> SpatialMemory | None: return self.spatial_memory_module -def main(): +def main() -> None: """Main entry point for testing.""" import argparse import os diff --git a/dimos/robot/unitree_webrtc/unitree_g1_skill_container.py b/dimos/robot/unitree_webrtc/unitree_g1_skill_container.py index 2ca937dde3..170b577c21 100644 --- a/dimos/robot/unitree_webrtc/unitree_g1_skill_container.py +++ b/dimos/robot/unitree_webrtc/unitree_g1_skill_container.py @@ -19,14 +19,11 @@ from __future__ import annotations -import datetime -import time -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING from dimos.core.core import rpc -from dimos.msgs.geometry_msgs import Twist, TwistStamped, Vector3 +from dimos.msgs.geometry_msgs import TwistStamped, Vector3 from dimos.protocol.skill.skill import skill -from dimos.protocol.skill.type import Reducer, Stream from dimos.robot.unitree_webrtc.unitree_skill_container import UnitreeSkillContainer from dimos.utils.logging_config import setup_logger @@ -68,7 +65,7 @@ class UnitreeG1SkillContainer(UnitreeSkillContainer): Inherits all Go2 skills and adds G1-specific arm controls and movement modes. """ - def __init__(self, robot: Optional[Union[UnitreeG1, UnitreeGo2]] = None): + def __init__(self, robot: UnitreeG1 | UnitreeGo2 | None = None) -> None: """Initialize the skill container with robot reference. Args: @@ -89,7 +86,7 @@ def start(self) -> None: def stop(self) -> None: super().stop() - def _generate_arm_skills(self): + def _generate_arm_skills(self) -> None: """Dynamically generate arm control skills from G1_ARM_CONTROLS list.""" logger.info(f"Generating {len(G1_ARM_CONTROLS)} G1 arm control skills") @@ -97,7 +94,7 @@ def _generate_arm_skills(self): skill_name = self._convert_to_snake_case(name) self._create_arm_skill(skill_name, data_value, description, name) - def _generate_mode_skills(self): + def _generate_mode_skills(self) -> None: """Dynamically generate movement mode skills from G1_MODE_CONTROLS list.""" logger.info(f"Generating {len(G1_MODE_CONTROLS)} G1 movement mode skills") @@ -107,7 +104,7 @@ def _generate_mode_skills(self): def _create_arm_skill( self, skill_name: str, data_value: int, description: str, original_name: str - ): + ) -> None: """Create a dynamic arm control skill method with the @skill decorator. Args: @@ -138,7 +135,7 @@ def dynamic_skill_func(self) -> str: def _create_mode_skill( self, skill_name: str, data_value: int, description: str, original_name: str - ): + ) -> None: """Create a dynamic movement mode skill method with the @skill decorator. Args: @@ -200,7 +197,7 @@ def _execute_arm_command(self, data_value: int, name: str) -> str: return f"Error: Robot not connected (cannot execute {name})" try: - result = self._robot.connection.publish_request( + self._robot.connection.publish_request( "rt/api/arm/request", {"api_id": 7106, "parameter": {"data": data_value}} ) message = f"G1 arm action {name} executed successfully (data={data_value})" @@ -222,7 +219,7 @@ def _execute_mode_command(self, data_value: int, name: str) -> str: return f"Error: Robot not connected (cannot execute {name})" try: - result = self._robot.connection.publish_request( + self._robot.connection.publish_request( "rt/api/sport/request", {"api_id": 7101, "parameter": {"data": data_value}} ) message = f"G1 mode {name} activated successfully (data={data_value})" diff --git a/dimos/robot/unitree_webrtc/unitree_go2.py b/dimos/robot/unitree_webrtc/unitree_go2.py index 7bb544f52c..b91433ead8 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2.py +++ b/dimos/robot/unitree_webrtc/unitree_go2.py @@ -20,56 +20,53 @@ import os import time import warnings -from typing import Optional +from dimos_lcm.sensor_msgs import CameraInfo +from dimos_lcm.std_msgs import Bool, String from reactivex import Observable from reactivex.disposable import CompositeDisposable from dimos import core from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE from dimos.core import In, Module, Out, rpc -from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.global_config import GlobalConfig +from dimos.core.module_coordinator import ModuleCoordinator from dimos.core.resource import Resource from dimos.mapping.types import LatLon -from dimos.msgs.std_msgs import Header -from dimos.msgs.geometry_msgs import PoseStamped, Transform, Twist, Vector3, Quaternion +from dimos.msgs.geometry_msgs import PoseStamped, Quaternion, Transform, Twist, Vector3 from dimos.msgs.nav_msgs import OccupancyGrid, Path from dimos.msgs.sensor_msgs import Image +from dimos.msgs.std_msgs import Header from dimos.msgs.vision_msgs import Detection2DArray -from dimos_lcm.std_msgs import String -from dimos_lcm.sensor_msgs import CameraInfo -from dimos.perception.spatial_perception import SpatialMemory +from dimos.navigation.bbox_navigation import BBoxNavigationModule +from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator, NavigatorState +from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer +from dimos.navigation.global_planner import AstarPlanner +from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner from dimos.perception.common.utils import ( load_camera_info, load_camera_info_opencv, rectify_image, ) +from dimos.perception.object_tracker_2d import ObjectTracker2D +from dimos.perception.spatial_perception import SpatialMemory from dimos.protocol import pubsub -from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.protocol.pubsub.lcmpubsub import LCM from dimos.protocol.tf import TF from dimos.robot.foxglove_bridge import FoxgloveBridge -from dimos.utils.monitoring import UtilizationModule -from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule -from dimos.navigation.global_planner import AstarPlanner -from dimos.navigation.local_planner.holonomic_local_planner import HolonomicLocalPlanner -from dimos.navigation.bt_navigator.navigator import BehaviorTreeNavigator, NavigatorState -from dimos.navigation.frontier_exploration import WavefrontFrontierExplorer +from dimos.robot.robot import UnitreeRobot from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.robot.unitree_webrtc.type.map import Map from dimos.robot.unitree_webrtc.type.odometry import Odometry from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills from dimos.skills.skills import AbstractRobotSkill, SkillLibrary +from dimos.types.robot_capabilities import RobotCapability from dimos.utils.data import get_data from dimos.utils.logging_config import setup_logger +from dimos.utils.monitoring import UtilizationModule from dimos.utils.testing import TimedSensorReplay -from dimos.perception.object_tracker_2d import ObjectTracker2D -from dimos.navigation.bbox_navigation import BBoxNavigationModule -from dimos_lcm.std_msgs import Bool -from dimos.robot.robot import UnitreeRobot -from dimos.types.robot_capabilities import RobotCapability - +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule logger = setup_logger(__file__, level=logging.INFO) @@ -89,7 +86,7 @@ class ReplayRTC(Resource): """Replay WebRTC connection for testing with recorded data.""" - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: get_data("unitree_office_walk") # Preload data for testing def start(self) -> None: @@ -98,10 +95,10 @@ def start(self) -> None: def stop(self) -> None: pass - def standup(self): + def standup(self) -> None: print("standup suppressed") - def liedown(self): + def liedown(self) -> None: print("liedown suppressed") @functools.cache @@ -124,7 +121,7 @@ def video_stream(self): ) return video_store.stream() - def move(self, twist: Twist, duration: float = 0.0): + def move(self, twist: Twist, duration: float = 0.0) -> None: pass def publish_request(self, topic: str, data: dict): @@ -157,7 +154,7 @@ def __init__( global_config: GlobalConfig | None = None, *args, **kwargs, - ): + ) -> None: cfg = global_config or GlobalConfig() self.ip = ip if ip is not None else cfg.robot_ip self.connection_type = connection_type or cfg.unitree_connection_type @@ -230,7 +227,7 @@ def stop(self) -> None: self.connection.stop() super().stop() - def _on_video(self, msg: Image): + def _on_video(self, msg: Image) -> None: """Handle incoming video frames and publish synchronized camera data.""" # Apply rectification if enabled if self.rectify_image: @@ -246,10 +243,10 @@ def _on_video(self, msg: Image): self._publish_camera_info(timestamp) self._publish_camera_pose(timestamp) - def _publish_gps_location(self, msg: LatLon): + def _publish_gps_location(self, msg: LatLon) -> None: self.gps_location.publish(msg) - def _publish_tf(self, msg): + def _publish_tf(self, msg) -> None: self._odom = msg self.odom.publish(msg) self.tf.publish(Transform.from_pose("base_link", msg)) @@ -262,12 +259,12 @@ def _publish_tf(self, msg): ) self.tf.publish(camera_link) - def _publish_camera_info(self, timestamp: float): + def _publish_camera_info(self, timestamp: float) -> None: header = Header(timestamp, "camera_link") self.lcm_camera_info.header = header self.camera_info.publish(self.lcm_camera_info) - def _publish_camera_pose(self, timestamp: float): + def _publish_camera_pose(self, timestamp: float) -> None: """Publish camera pose from TF lookup.""" try: # Look up transform from world to camera_link @@ -293,7 +290,7 @@ def _publish_camera_pose(self, timestamp: float): logger.error(f"Error publishing camera pose: {e}") @rpc - def get_odom(self) -> Optional[PoseStamped]: + def get_odom(self) -> PoseStamped | None: """Get the robot's odometry. Returns: @@ -302,7 +299,7 @@ def get_odom(self) -> Optional[PoseStamped]: return self._odom @rpc - def move(self, twist: Twist, duration: float = 0.0): + def move(self, twist: Twist, duration: float = 0.0) -> None: """Send movement command to robot.""" self.connection.move(twist, duration) @@ -339,12 +336,12 @@ class UnitreeGo2(UnitreeRobot, Resource): def __init__( self, - ip: Optional[str], - output_dir: str = None, + ip: str | None, + output_dir: str | None = None, websocket_port: int = 7779, - skill_library: Optional[SkillLibrary] = None, - connection_type: Optional[str] = "webrtc", - ): + skill_library: SkillLibrary | None = None, + connection_type: str | None = "webrtc", + ) -> None: """Initialize the robot system. Args: @@ -386,7 +383,7 @@ def __init__( self._setup_directories() - def _setup_directories(self): + def _setup_directories(self) -> None: """Setup directories for spatial memory storage.""" os.makedirs(self.output_dir, exist_ok=True) logger.info(f"Robot outputs will be saved to: {self.output_dir}") @@ -405,7 +402,7 @@ def _setup_directories(self): os.makedirs(self.spatial_memory_dir, exist_ok=True) os.makedirs(self.db_path, exist_ok=True) - def start(self): + def start(self) -> None: self.lcm.start() self._dimos.start() @@ -427,7 +424,7 @@ def stop(self) -> None: self._dimos.stop() self.lcm.stop() - def _deploy_connection(self): + def _deploy_connection(self) -> None: """Deploy and configure the connection module.""" self.connection = self._dimos.deploy( ConnectionModule, self.ip, connection_type=self.connection_type @@ -443,7 +440,7 @@ def _deploy_connection(self): self.connection.camera_info.transport = core.LCMTransport("/go2/camera_info", CameraInfo) self.connection.camera_pose.transport = core.LCMTransport("/go2/camera_pose", PoseStamped) - def _deploy_mapping(self): + def _deploy_mapping(self) -> None: """Deploy and configure the mapping module.""" min_height = 0.3 if self.connection_type == "mujoco" else 0.15 self.mapper = self._dimos.deploy( @@ -456,7 +453,7 @@ def _deploy_mapping(self): self.mapper.lidar.connect(self.connection.lidar) - def _deploy_navigation(self): + def _deploy_navigation(self) -> None: """Deploy and configure navigation modules.""" self.global_planner = self._dimos.deploy(AstarPlanner) self.local_planner = self._dimos.deploy(HolonomicLocalPlanner) @@ -501,7 +498,7 @@ def _deploy_navigation(self): self.frontier_explorer.global_costmap.connect(self.mapper.global_costmap) self.frontier_explorer.odom.connect(self.connection.odom) - def _deploy_visualization(self): + def _deploy_visualization(self) -> None: """Deploy and configure visualization modules.""" self.websocket_vis = self._dimos.deploy(WebsocketVisModule, port=self.websocket_port) self.websocket_vis.goal_request.transport = core.LCMTransport("/goal_request", PoseStamped) @@ -515,7 +512,7 @@ def _deploy_visualization(self): self.websocket_vis.path.connect(self.global_planner.path) self.websocket_vis.global_costmap.connect(self.mapper.global_costmap) - def _deploy_foxglove_bridge(self): + def _deploy_foxglove_bridge(self) -> None: self.foxglove_bridge = FoxgloveBridge( shm_channels=[ "/go2/color_image#sensor_msgs.Image", @@ -524,7 +521,7 @@ def _deploy_foxglove_bridge(self): ) self.foxglove_bridge.start() - def _deploy_perception(self): + def _deploy_perception(self) -> None: """Deploy and configure perception modules.""" # Deploy spatial memory self.spatial_memory_module = self._dimos.deploy( @@ -568,7 +565,7 @@ def _deploy_perception(self): logger.info("Object tracker and bbox navigator modules deployed") - def _deploy_camera(self): + def _deploy_camera(self) -> None: """Deploy and configure the camera module.""" # Connect object tracker inputs if self.object_tracker: @@ -582,7 +579,7 @@ def _deploy_camera(self): self.bbox_navigator.goal_request.connect(self.navigator.goal_request) logger.info("BBox navigator connected") - def _start_modules(self): + def _start_modules(self) -> None: """Start all deployed modules in the correct order.""" self._dimos.start_all_modules() @@ -596,7 +593,7 @@ def _start_modules(self): self.skill_library.init() self.skill_library.initialize_skills() - def move(self, twist: Twist, duration: float = 0.0): + def move(self, twist: Twist, duration: float = 0.0) -> None: """Send movement command to robot.""" self.connection.move(twist, duration) @@ -608,7 +605,7 @@ def explore(self) -> bool: """ return self.frontier_explorer.explore() - def navigate_to(self, pose: PoseStamped, blocking: bool = True): + def navigate_to(self, pose: PoseStamped, blocking: bool = True) -> bool: """Navigate to a target pose. Args: @@ -661,7 +658,7 @@ def cancel_navigation(self) -> bool: return self.navigator.cancel_goal() @property - def spatial_memory(self) -> Optional[SpatialMemory]: + def spatial_memory(self) -> SpatialMemory | None: """Get the robot's spatial memory module. Returns: @@ -682,7 +679,7 @@ def get_odom(self) -> PoseStamped: return self.connection.get_odom() -def main(): +def main() -> None: """Main entry point.""" ip = os.getenv("ROBOT_IP") connection_type = os.getenv("CONNECTION_TYPE", "webrtc") @@ -705,4 +702,4 @@ def main(): main() -__all__ = ["ConnectionModule", "connection", "UnitreeGo2", "ReplayRTC"] +__all__ = ["ConnectionModule", "ReplayRTC", "UnitreeGo2", "connection"] diff --git a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py index af13dc20bc..4cdc8438f1 100644 --- a/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py +++ b/dimos/robot/unitree_webrtc/unitree_go2_blueprints.py @@ -14,34 +14,34 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dimos_lcm.sensor_msgs import CameraInfo + +from dimos.agents2.agent import llm_agent +from dimos.agents2.cli.human import human_input +from dimos.agents2.skills.navigation import navigation_skill from dimos.constants import DEFAULT_CAPACITY_COLOR_IMAGE, DEFAULT_CAPACITY_DEPTH_IMAGE from dimos.core.blueprints import autoconnect from dimos.core.transport import LCMTransport, pSHMTransport from dimos.msgs.geometry_msgs import PoseStamped from dimos.msgs.sensor_msgs import Image -from dimos_lcm.sensor_msgs import CameraInfo -from dimos.perception.spatial_perception import spatial_memory -from dimos.robot.foxglove_bridge import foxglove_bridge -from dimos.robot.unitree_webrtc.unitree_go2 import connection -from dimos.utils.monitoring import utilization -from dimos.web.websocket_vis.websocket_vis_module import websocket_vis -from dimos.navigation.global_planner import astar_planner -from dimos.navigation.local_planner.holonomic_local_planner import ( - holonomic_local_planner, -) from dimos.navigation.bt_navigator.navigator import ( behavior_tree_navigator, ) from dimos.navigation.frontier_exploration import ( wavefront_frontier_explorer, ) -from dimos.robot.unitree_webrtc.type.map import mapper -from dimos.robot.unitree_webrtc.depth_module import depth_module +from dimos.navigation.global_planner import astar_planner +from dimos.navigation.local_planner.holonomic_local_planner import ( + holonomic_local_planner, +) from dimos.perception.object_tracker import object_tracking -from dimos.agents2.agent import llm_agent -from dimos.agents2.cli.human import human_input -from dimos.agents2.skills.navigation import navigation_skill - +from dimos.perception.spatial_perception import spatial_memory +from dimos.robot.foxglove_bridge import foxglove_bridge +from dimos.robot.unitree_webrtc.depth_module import depth_module +from dimos.robot.unitree_webrtc.type.map import mapper +from dimos.robot.unitree_webrtc.unitree_go2 import connection +from dimos.utils.monitoring import utilization +from dimos.web.websocket_vis.websocket_vis_module import websocket_vis basic = ( autoconnect( diff --git a/dimos/robot/unitree_webrtc/unitree_skill_container.py b/dimos/robot/unitree_webrtc/unitree_skill_container.py index 61df7be2d7..e6179adcbb 100644 --- a/dimos/robot/unitree_webrtc/unitree_skill_container.py +++ b/dimos/robot/unitree_webrtc/unitree_skill_container.py @@ -21,16 +21,17 @@ import datetime import time -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING + +from go2_webrtc_driver.constants import RTC_TOPIC from dimos.core import Module from dimos.core.core import rpc from dimos.msgs.geometry_msgs import Twist, Vector3 from dimos.protocol.skill.skill import skill from dimos.protocol.skill.type import Reducer, Stream -from dimos.utils.logging_config import setup_logger from dimos.robot.unitree_webrtc.unitree_skills import UNITREE_WEBRTC_CONTROLS -from go2_webrtc_driver.constants import RTC_TOPIC +from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 @@ -41,7 +42,7 @@ class UnitreeSkillContainer(Module): """Container for Unitree Go2 robot skills using the new framework.""" - def __init__(self, robot: Optional[UnitreeGo2] = None): + def __init__(self, robot: UnitreeGo2 | None = None) -> None: """Initialize the skill container with robot reference. Args: @@ -62,7 +63,7 @@ def stop(self) -> None: # TODO: Do I need to clean up dynamic skills? super().stop() - def _generate_unitree_skills(self): + def _generate_unitree_skills(self) -> None: """Dynamically generate skills from the UNITREE_WEBRTC_CONTROLS list.""" logger.info(f"Generating {len(UNITREE_WEBRTC_CONTROLS)} dynamic Unitree skills") @@ -89,7 +90,7 @@ def _convert_to_snake_case(self, name: str) -> str: def _create_dynamic_skill( self, skill_name: str, api_id: int, description: str, original_name: str - ): + ) -> None: """Create a dynamic skill method with the @skill decorator. Args: @@ -161,7 +162,7 @@ def current_time(self): time.sleep(1) @skill() - def speak(self, text: str): + def speak(self, text: str) -> str: """Speak text out loud through the robot's speakers.""" return f"This is being said aloud: {text}" @@ -178,9 +179,7 @@ def _execute_sport_command(self, api_id: int, name: str) -> str: return f"Error: Robot not connected (cannot execute {name})" try: - result = self._robot.connection.publish_request( - RTC_TOPIC["SPORT_MOD"], {"api_id": api_id} - ) + self._robot.connection.publish_request(RTC_TOPIC["SPORT_MOD"], {"api_id": api_id}) message = f"{name} command executed successfully (id={api_id})" logger.info(message) return message diff --git a/dimos/robot/unitree_webrtc/unitree_skills.py b/dimos/robot/unitree_webrtc/unitree_skills.py index cb01426325..2bba4caa53 100644 --- a/dimos/robot/unitree_webrtc/unitree_skills.py +++ b/dimos/robot/unitree_webrtc/unitree_skills.py @@ -14,23 +14,25 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional, Tuple, Union import time +from typing import TYPE_CHECKING + from pydantic import Field if TYPE_CHECKING: - from dimos.robot.robot import Robot, MockRobot + from dimos.robot.robot import MockRobot, Robot else: Robot = "Robot" MockRobot = "MockRobot" +from go2_webrtc_driver.constants import RTC_TOPIC + +from dimos.msgs.geometry_msgs import Twist, Vector3 from dimos.skills.skills import AbstractRobotSkill, AbstractSkill, SkillLibrary from dimos.types.constants import Colors -from dimos.msgs.geometry_msgs import Twist, Vector3 -from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD # Module-level constant for Unitree Go2 WebRTC control definitions -UNITREE_WEBRTC_CONTROLS: List[Tuple[str, int, str]] = [ +UNITREE_WEBRTC_CONTROLS: list[tuple[str, int, str]] = [ ("Damp", 1001, "Lowers the robot to the ground fully."), ( "BalanceStand", @@ -180,7 +182,7 @@ # Module-level constants for Unitree G1 WebRTC control definitions # G1 Arm Actions - all use api_id 7106 on topic "rt/api/arm/request" -G1_ARM_CONTROLS: List[Tuple[str, int, str]] = [ +G1_ARM_CONTROLS: list[tuple[str, int, str]] = [ ("Handshake", 27, "Perform a handshake gesture with the right hand."), ("HighFive", 18, "Give a high five with the right hand."), ("Hug", 19, "Perform a hugging gesture with both arms."), @@ -198,7 +200,7 @@ ] # G1 Movement Modes - all use api_id 7101 on topic "rt/api/sport/request" -G1_MODE_CONTROLS: List[Tuple[str, int, str]] = [ +G1_MODE_CONTROLS: list[tuple[str, int, str]] = [ ("WalkMode", 500, "Switch to normal walking mode."), ("WalkControlWaist", 501, "Switch to walking mode with waist control."), ("RunMode", 801, "Switch to running mode."), @@ -210,7 +212,7 @@ class MyUnitreeSkills(SkillLibrary): """My Unitree Skills for WebRTC interface.""" - def __init__(self, robot: Optional[Robot] = None, robot_type: str = "go2"): + def __init__(self, robot: Robot | None = None, robot_type: str = "go2") -> None: """Initialize Unitree skills library. Args: @@ -229,7 +231,7 @@ def __init__(self, robot: Optional[Robot] = None, robot_type: str = "go2"): self.register_skills(dynamic_skills) @classmethod - def register_skills(cls, skill_classes: Union["AbstractSkill", list["AbstractSkill"]]): + def register_skills(cls, skill_classes: AbstractSkill | list[AbstractSkill]) -> None: """Add multiple skill classes as class attributes. Args: @@ -242,28 +244,28 @@ def register_skills(cls, skill_classes: Union["AbstractSkill", list["AbstractSki # Add to the class as a skill setattr(cls, skill_class.__name__, skill_class) - def initialize_skills(self): + def initialize_skills(self) -> None: for skill_class in self.get_class_skills(): self.create_instance(skill_class.__name__, robot=self._robot) # Refresh the class skills self.refresh_class_skills() - def create_skills_live(self) -> List[AbstractRobotSkill]: + def create_skills_live(self) -> list[AbstractRobotSkill]: # ================================================ # Procedurally created skills # ================================================ class BaseUnitreeSkill(AbstractRobotSkill): """Base skill for dynamic skill creation.""" - def __call__(self): + def __call__(self) -> str: super().__call__() # For Go2: Simple api_id based call if hasattr(self, "_app_id"): string = f"{Colors.GREEN_PRINT_COLOR}Executing Go2 skill: {self.__class__.__name__} with api_id={self._app_id}{Colors.RESET_COLOR}" print(string) - result = self._robot.connection.publish_request( + self._robot.connection.publish_request( RTC_TOPIC["SPORT_MOD"], {"api_id": self._app_id} ) return f"{self.__class__.__name__} executed successfully" @@ -272,7 +274,7 @@ def __call__(self): elif hasattr(self, "_data_value"): string = f"{Colors.GREEN_PRINT_COLOR}Executing G1 skill: {self.__class__.__name__} with data={self._data_value}{Colors.RESET_COLOR}" print(string) - result = self._robot.connection.publish_request( + self._robot.connection.publish_request( self._topic, {"api_id": self._api_id, "parameter": {"data": self._data_value}}, ) @@ -333,7 +335,7 @@ class Move(AbstractRobotSkill): yaw: float = Field(default=0.0, description="Rotational velocity (rad/s)") duration: float = Field(default=0.0, description="How long to move (seconds).") - def __call__(self): + def __call__(self) -> str: self._robot.move( Twist(linear=Vector3(self.x, self.y, 0.0), angular=Vector3(0.0, 0.0, self.yaw)), duration=self.duration, @@ -345,7 +347,7 @@ class Wait(AbstractSkill): seconds: float = Field(..., description="Seconds to wait") - def __call__(self): + def __call__(self) -> str: time.sleep(self.seconds) return f"Wait completed with length={self.seconds}s" diff --git a/dimos/robot/utils/robot_debugger.py b/dimos/robot/utils/robot_debugger.py index 74c174f9cd..b3cfb195ce 100644 --- a/dimos/robot/utils/robot_debugger.py +++ b/dimos/robot/utils/robot_debugger.py @@ -21,7 +21,7 @@ class RobotDebugger(Resource): - def __init__(self, robot): + def __init__(self, robot) -> None: self._robot = robot self._threaded_server = None diff --git a/dimos/simulation/__init__.py b/dimos/simulation/__init__.py index 3d25363b30..2b77f47097 100644 --- a/dimos/simulation/__init__.py +++ b/dimos/simulation/__init__.py @@ -12,4 +12,4 @@ GenesisSimulator = None # type: ignore GenesisStream = None # type: ignore -__all__ = ["IsaacSimulator", "IsaacStream", "GenesisSimulator", "GenesisStream"] +__all__ = ["GenesisSimulator", "GenesisStream", "IsaacSimulator", "IsaacStream"] diff --git a/dimos/simulation/base/simulator_base.py b/dimos/simulation/base/simulator_base.py index 91633bb53a..777893d74c 100644 --- a/dimos/simulation/base/simulator_base.py +++ b/dimos/simulation/base/simulator_base.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union, List, Dict from abc import ABC, abstractmethod @@ -23,9 +22,9 @@ class SimulatorBase(ABC): def __init__( self, headless: bool = True, - open_usd: Optional[str] = None, # Keep for Isaac compatibility - entities: Optional[List[Dict[str, Union[str, dict]]]] = None, # Add for Genesis - ): + open_usd: str | None = None, # Keep for Isaac compatibility + entities: list[dict[str, str | dict]] | None = None, # Add for Genesis + ) -> None: """Initialize the simulator. Args: diff --git a/dimos/simulation/base/stream_base.py b/dimos/simulation/base/stream_base.py index d20af296e2..1fb0e86add 100644 --- a/dimos/simulation/base/stream_base.py +++ b/dimos/simulation/base/stream_base.py @@ -13,9 +13,9 @@ # limitations under the License. from abc import ABC, abstractmethod -from typing import Literal, Optional, Union from pathlib import Path import subprocess +from typing import Literal AnnotatorType = Literal["rgb", "normals", "bounding_box_3d", "motion_vectors"] TransportType = Literal["tcp", "udp"] @@ -35,8 +35,8 @@ def __init__( annotator_type: AnnotatorType = "rgb", transport: TransportType = "tcp", rtsp_url: str = "rtsp://mediamtx:8554/stream", - usd_path: Optional[Union[str, Path]] = None, - ): + usd_path: str | Path | None = None, + ) -> None: """Initialize the stream. Args: @@ -61,7 +61,7 @@ def __init__( self.proc = None @abstractmethod - def _load_stage(self, usd_path: Union[str, Path]): + def _load_stage(self, usd_path: str | Path): """Load stage from file.""" pass @@ -70,7 +70,7 @@ def _setup_camera(self): """Setup and validate camera.""" pass - def _setup_ffmpeg(self): + def _setup_ffmpeg(self) -> None: """Setup FFmpeg process for streaming.""" command = [ "ffmpeg", diff --git a/dimos/simulation/genesis/simulator.py b/dimos/simulation/genesis/simulator.py index e531e6b422..f3a73be08b 100644 --- a/dimos/simulation/genesis/simulator.py +++ b/dimos/simulation/genesis/simulator.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Union, List, Dict + import genesis as gs # type: ignore + from ..base.simulator_base import SimulatorBase @@ -23,9 +24,9 @@ class GenesisSimulator(SimulatorBase): def __init__( self, headless: bool = True, - open_usd: Optional[str] = None, # Keep for compatibility - entities: Optional[List[Dict[str, Union[str, dict]]]] = None, - ): + open_usd: str | None = None, # Keep for compatibility + entities: list[dict[str, str | dict]] | None = None, + ) -> None: """Initialize the Genesis simulation. Args: @@ -73,7 +74,7 @@ def __init__( # Don't build scene yet - let stream add camera first self.is_built = False - def _load_entities(self, entities: List[Dict[str, Union[str, dict]]]): + def _load_entities(self, entities: list[dict[str, str | dict]]): """Load multiple entities into the scene.""" for entity in entities: entity_type = entity.get("type", "").lower() @@ -130,9 +131,9 @@ def _load_entities(self, entities: List[Dict[str, Union[str, dict]]]): raise ValueError(f"Unsupported entity type: {entity_type}") except Exception as e: - print(f"[Warning] Failed to load entity {entity}: {str(e)}") + print(f"[Warning] Failed to load entity {entity}: {e!s}") - def add_entity(self, entity_type: str, path: str = "", **params): + def add_entity(self, entity_type: str, path: str = "", **params) -> None: """Add a single entity to the scene. Args: @@ -146,13 +147,13 @@ def get_stage(self): """Get the current stage/scene.""" return self.scene - def build(self): + def build(self) -> None: """Build the scene if not already built.""" if not self.is_built: self.scene.build() self.is_built = True - def close(self): + def close(self) -> None: """Close the simulation.""" # Genesis handles cleanup automatically pass diff --git a/dimos/simulation/genesis/stream.py b/dimos/simulation/genesis/stream.py index fbb70fea13..d24b254b38 100644 --- a/dimos/simulation/genesis/stream.py +++ b/dimos/simulation/genesis/stream.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from pathlib import Path +import time + import cv2 import numpy as np -import time -from typing import Optional, Union -from pathlib import Path -from ..base.stream_base import StreamBase, AnnotatorType, TransportType + +from ..base.stream_base import AnnotatorType, StreamBase, TransportType class GenesisStream(StreamBase): @@ -33,8 +34,8 @@ def __init__( annotator_type: AnnotatorType = "rgb", transport: TransportType = "tcp", rtsp_url: str = "rtsp://mediamtx:8554/stream", - usd_path: Optional[Union[str, Path]] = None, - ): + usd_path: str | Path | None = None, + ) -> None: """Initialize the Genesis stream.""" super().__init__( simulator=simulator, @@ -60,12 +61,12 @@ def __init__( # Build scene after camera is set up simulator.build() - def _load_stage(self, usd_path: Union[str, Path]): + def _load_stage(self, usd_path: str | Path) -> None: """Load stage from file.""" # Genesis handles stage loading through simulator pass - def _setup_camera(self): + def _setup_camera(self) -> None: """Setup and validate camera.""" self.camera = self.scene.add_camera( res=(self.width, self.height), @@ -75,12 +76,12 @@ def _setup_camera(self): GUI=False, ) - def _setup_annotator(self): + def _setup_annotator(self) -> None: """Setup the specified annotator.""" # Genesis handles different render types through camera.render() pass - def stream(self): + def stream(self) -> None: """Start the streaming loop.""" try: print("[Stream] Starting Genesis camera stream...") @@ -129,7 +130,7 @@ def stream(self): finally: self.cleanup() - def cleanup(self): + def cleanup(self) -> None: """Cleanup resources.""" print("[Cleanup] Stopping FFmpeg process...") if hasattr(self, "proc"): diff --git a/dimos/simulation/isaac/simulator.py b/dimos/simulation/isaac/simulator.py index ba6fe319b4..0d49b9145e 100644 --- a/dimos/simulation/isaac/simulator.py +++ b/dimos/simulation/isaac/simulator.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, List, Dict, Union + from isaacsim import SimulationApp + from ..base.simulator_base import SimulatorBase @@ -23,9 +24,9 @@ class IsaacSimulator(SimulatorBase): def __init__( self, headless: bool = True, - open_usd: Optional[str] = None, - entities: Optional[List[Dict[str, Union[str, dict]]]] = None, # Add but ignore - ): + open_usd: str | None = None, + entities: list[dict[str, str | dict]] | None = None, # Add but ignore + ) -> None: """Initialize the Isaac Sim simulation.""" super().__init__(headless, open_usd) self.app = SimulationApp({"headless": headless, "open_usd": open_usd}) @@ -37,7 +38,7 @@ def get_stage(self): self.stage = omni.usd.get_context().get_stage() return self.stage - def close(self): + def close(self) -> None: """Close the simulation.""" if hasattr(self, "app"): self.app.close() diff --git a/dimos/simulation/isaac/stream.py b/dimos/simulation/isaac/stream.py index 44560783bd..eb85ba8815 100644 --- a/dimos/simulation/isaac/stream.py +++ b/dimos/simulation/isaac/stream.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cv2 -import time -from typing import Optional, Union from pathlib import Path -from ..base.stream_base import StreamBase, AnnotatorType, TransportType +import time + +import cv2 + +from ..base.stream_base import AnnotatorType, StreamBase, TransportType class IsaacStream(StreamBase): @@ -32,8 +33,8 @@ def __init__( annotator_type: AnnotatorType = "rgb", transport: TransportType = "tcp", rtsp_url: str = "rtsp://mediamtx:8554/stream", - usd_path: Optional[Union[str, Path]] = None, - ): + usd_path: str | Path | None = None, + ) -> None: """Initialize the Isaac Sim stream.""" super().__init__( simulator=simulator, @@ -59,7 +60,7 @@ def __init__( self._setup_ffmpeg() self._setup_annotator() - def _load_stage(self, usd_path: Union[str, Path]): + def _load_stage(self, usd_path: str | Path): """Load USD stage from file.""" import omni.usd @@ -80,12 +81,12 @@ def _setup_camera(self): self.camera_path, resolution=(self.width, self.height) ) - def _setup_annotator(self): + def _setup_annotator(self) -> None: """Setup the specified annotator.""" self.annotator = self.rep.AnnotatorRegistry.get_annotator(self.annotator_type) self.annotator.attach(self.render_product) - def stream(self): + def stream(self) -> None: """Start the streaming loop.""" try: print("[Stream] Starting camera stream loop...") @@ -125,7 +126,7 @@ def stream(self): finally: self.cleanup() - def cleanup(self): + def cleanup(self) -> None: """Cleanup resources.""" print("[Cleanup] Stopping FFmpeg process...") if hasattr(self, "proc"): diff --git a/dimos/simulation/mujoco/depth_camera.py b/dimos/simulation/mujoco/depth_camera.py index 3778d6f900..bb7cc34047 100644 --- a/dimos/simulation/mujoco/depth_camera.py +++ b/dimos/simulation/mujoco/depth_camera.py @@ -15,6 +15,7 @@ # limitations under the License. import math + import numpy as np import open3d as o3d diff --git a/dimos/simulation/mujoco/model.py b/dimos/simulation/mujoco/model.py index 1543a80364..12d97181b2 100644 --- a/dimos/simulation/mujoco/model.py +++ b/dimos/simulation/mujoco/model.py @@ -15,11 +15,10 @@ # limitations under the License. -import mujoco -import numpy as np from etils import epath +import mujoco from mujoco_playground._src import mjx_env - +import numpy as np from dimos.simulation.mujoco.policy import OnnxController from dimos.simulation.mujoco.types import InputController @@ -53,7 +52,7 @@ def load_model(input_device: InputController, model=None, data=None): ctrl_dt = 0.02 sim_dt = 0.01 - n_substeps = int(round(ctrl_dt / sim_dt)) + n_substeps = round(ctrl_dt / sim_dt) model.opt.timestep = sim_dt policy = OnnxController( diff --git a/dimos/simulation/mujoco/mujoco.py b/dimos/simulation/mujoco/mujoco.py index bf52277002..5e867a26d1 100644 --- a/dimos/simulation/mujoco/mujoco.py +++ b/dimos/simulation/mujoco/mujoco.py @@ -21,10 +21,9 @@ import time import mujoco +from mujoco import viewer import numpy as np import open3d as o3d -from mujoco import viewer - from dimos.msgs.geometry_msgs import Quaternion, Twist, Vector3 from dimos.robot.unitree_webrtc.type.lidar import LidarMessage @@ -42,7 +41,7 @@ class MujocoThread(threading.Thread): - def __init__(self): + def __init__(self) -> None: super().__init__(daemon=True) self.shared_pixels = None self.pixels_lock = threading.RLock() @@ -71,7 +70,7 @@ def __init__(self): # Register cleanup on exit atexit.register(self.cleanup) - def run(self): + def run(self) -> None: try: self.run_simulation() except Exception as e: @@ -79,7 +78,7 @@ def run(self): finally: self._cleanup_resources() - def run_simulation(self): + def run_simulation(self) -> None: self.model, self.data = load_model(self) camera_id = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_CAMERA, "head_camera") @@ -272,12 +271,12 @@ def get_odom_message(self) -> Odometry | None: ) return odom_to_publish - def _stop_move(self): + def _stop_move(self) -> None: with self._command_lock: self._command = np.zeros(3, dtype=np.float32) self._stop_timer = None - def move(self, twist: Twist, duration: float = 0.0): + def move(self, twist: Twist, duration: float = 0.0) -> None: if self._stop_timer: self._stop_timer.cancel() @@ -297,7 +296,7 @@ def get_command(self) -> np.ndarray: with self._command_lock: return self._command.copy() - def stop(self): + def stop(self) -> None: """Stop the simulation thread gracefully.""" self._is_running = False @@ -312,7 +311,7 @@ def stop(self): if self.is_alive(): logger.warning("MuJoCo thread did not stop gracefully within timeout") - def cleanup(self): + def cleanup(self) -> None: """Clean up all resources. Can be called multiple times safely.""" if self._cleanup_registered: return @@ -322,7 +321,7 @@ def cleanup(self): self.stop() self._cleanup_resources() - def _cleanup_resources(self): + def _cleanup_resources(self) -> None: """Internal method to clean up MuJoCo-specific resources.""" try: # Cancel any timers @@ -392,7 +391,7 @@ def _cleanup_resources(self): except Exception as e: logger.error(f"Error during resource cleanup: {e}") - def __del__(self): + def __del__(self) -> None: """Destructor to ensure cleanup on object deletion.""" try: self.cleanup() diff --git a/dimos/simulation/mujoco/policy.py b/dimos/simulation/mujoco/policy.py index 2ab78f6c4c..2ea974f0be 100644 --- a/dimos/simulation/mujoco/policy.py +++ b/dimos/simulation/mujoco/policy.py @@ -32,7 +32,7 @@ def __init__( n_substeps: int, action_scale: float, input_controller: InputController, - ): + ) -> None: self._output_names = ["continuous_actions"] self._policy = rt.InferenceSession(policy_path, providers=["CPUExecutionProvider"]) diff --git a/dimos/skills/kill_skill.py b/dimos/skills/kill_skill.py index f7eb63e807..b9d02729f5 100644 --- a/dimos/skills/kill_skill.py +++ b/dimos/skills/kill_skill.py @@ -19,7 +19,6 @@ particularly those running in separate threads like the monitor skill. """ -from typing import Optional from pydantic import Field from dimos.skills.skills import AbstractSkill, SkillLibrary @@ -39,7 +38,7 @@ class KillSkill(AbstractSkill): skill_name: str = Field(..., description="Name of the skill to terminate") - def __init__(self, skill_library: Optional[SkillLibrary] = None, **data): + def __init__(self, skill_library: SkillLibrary | None = None, **data) -> None: """ Initialize the kill skill. diff --git a/dimos/skills/manipulation/abstract_manipulation_skill.py b/dimos/skills/manipulation/abstract_manipulation_skill.py index 8881548540..e3f6e719fa 100644 --- a/dimos/skills/manipulation/abstract_manipulation_skill.py +++ b/dimos/skills/manipulation/abstract_manipulation_skill.py @@ -14,11 +14,9 @@ """Abstract base class for manipulation skills.""" -from typing import Optional - -from dimos.skills.skills import AbstractRobotSkill, Colors -from dimos.robot.robot import Robot from dimos.manipulation.manipulation_interface import ManipulationInterface +from dimos.robot.robot import Robot +from dimos.skills.skills import AbstractRobotSkill from dimos.types.robot_capabilities import RobotCapability @@ -28,7 +26,7 @@ class AbstractManipulationSkill(AbstractRobotSkill): This abstract class provides access to the robot's manipulation memory system. """ - def __init__(self, *args, robot: Optional[Robot] = None, **kwargs): + def __init__(self, *args, robot: Robot | None = None, **kwargs) -> None: """Initialize the manipulation skill. Args: @@ -42,7 +40,7 @@ def __init__(self, *args, robot: Optional[Robot] = None, **kwargs): ) @property - def manipulation_interface(self) -> Optional[ManipulationInterface]: + def manipulation_interface(self) -> ManipulationInterface | None: """Get the robot's manipulation interface. Returns: diff --git a/dimos/skills/manipulation/force_constraint_skill.py b/dimos/skills/manipulation/force_constraint_skill.py index d7a97287b2..72616c32a3 100644 --- a/dimos/skills/manipulation/force_constraint_skill.py +++ b/dimos/skills/manipulation/force_constraint_skill.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, List, Tuple + from pydantic import Field from dimos.skills.manipulation.abstract_manipulation_skill import AbstractManipulationSkill -from dimos.skills.skills import AbstractRobotSkill from dimos.types.manipulation import ForceConstraint, Vector from dimos.utils.logging_config import setup_logger @@ -37,7 +36,7 @@ class ForceConstraintSkill(AbstractManipulationSkill): max_force: float = Field(100.0, description="Maximum force magnitude in Newtons to apply") # Force direction as (x,y) tuple - force_direction: Optional[Tuple[float, float]] = Field( + force_direction: tuple[float, float] | None = Field( None, description="Force direction vector (x,y)" ) diff --git a/dimos/skills/manipulation/manipulate_skill.py b/dimos/skills/manipulation/manipulate_skill.py index efd923f8c6..7905d4f76c 100644 --- a/dimos/skills/manipulation/manipulate_skill.py +++ b/dimos/skills/manipulation/manipulate_skill.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Dict, Any, Optional, Union import time +from typing import Any import uuid from pydantic import Field @@ -21,12 +21,9 @@ from dimos.skills.manipulation.abstract_manipulation_skill import AbstractManipulationSkill from dimos.types.manipulation import ( AbstractConstraint, - TranslationConstraint, - RotationConstraint, - ForceConstraint, - ManipulationTaskConstraint, - ManipulationTask, ManipulationMetadata, + ManipulationTask, + ManipulationTaskConstraint, ) from dimos.utils.logging_config import setup_logger @@ -52,18 +49,18 @@ class Manipulate(AbstractManipulationSkill): ) # Constraints - can be set directly - constraints: List[str] = Field( + constraints: list[str] = Field( [], description="List of AbstractConstraint constraint IDs from AgentMemory to apply to the manipulation task", ) # Object movement tolerances - object_tolerances: Dict[str, float] = Field( + object_tolerances: dict[str, float] = Field( {}, # Empty dict as default description="Dictionary mapping object IDs to movement tolerances (0.0 = immovable, 1.0 = freely movable)", ) - def __call__(self) -> Dict[str, Any]: + def __call__(self) -> dict[str, Any]: """ Execute a manipulation task with the given constraints. @@ -122,7 +119,7 @@ def _build_manipulation_metadata(self) -> ManipulationMetadata: objects_by_id[obj_id] = dict(obj) # Make a copy to avoid modifying original # Create objects_data dictionary with tolerances applied - objects_data: Dict[str, Any] = {} + objects_data: dict[str, Any] = {} # First, apply all specified tolerances for object_id, tolerance in self.object_tolerances.items(): @@ -163,7 +160,7 @@ def _build_manipulation_constraint(self) -> ManipulationTaskConstraint: return constraint # TODO: Implement - def _execute_manipulation(self, task: ManipulationTask) -> Dict[str, Any]: + def _execute_manipulation(self, task: ManipulationTask) -> dict[str, Any]: """ Execute the manipulation with the given constraint. diff --git a/dimos/skills/manipulation/pick_and_place.py b/dimos/skills/manipulation/pick_and_place.py index 15570d5373..bb9cc32607 100644 --- a/dimos/skills/manipulation/pick_and_place.py +++ b/dimos/skills/manipulation/pick_and_place.py @@ -20,20 +20,21 @@ """ import json -import cv2 import os -from typing import Optional, Tuple, Dict, Any +from typing import Any + +import cv2 import numpy as np from pydantic import Field -from dimos.skills.skills import AbstractRobotSkill from dimos.models.qwen.video_query import query_single_frame +from dimos.skills.skills import AbstractRobotSkill from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.skills.manipulation.pick_and_place") -def parse_qwen_points_response(response: str) -> Optional[Tuple[Tuple[int, int], Tuple[int, int]]]: +def parse_qwen_points_response(response: str) -> tuple[tuple[int, int], tuple[int, int]] | None: """ Parse Qwen's response containing two points. @@ -75,8 +76,8 @@ def parse_qwen_points_response(response: str) -> Optional[Tuple[Tuple[int, int], def save_debug_image_with_points( image: np.ndarray, - pick_point: Optional[Tuple[int, int]] = None, - place_point: Optional[Tuple[int, int]] = None, + pick_point: tuple[int, int] | None = None, + place_point: tuple[int, int] | None = None, filename_prefix: str = "qwen_debug", ) -> str: """ @@ -133,7 +134,7 @@ def save_debug_image_with_points( return filepath -def parse_qwen_single_point_response(response: str) -> Optional[Tuple[int, int]]: +def parse_qwen_single_point_response(response: str) -> tuple[int, int] | None: """ Parse Qwen's response containing a single point. @@ -195,7 +196,7 @@ class PickAndPlace(AbstractRobotSkill): description="Natural language description of the object to pick (e.g., 'red mug', 'small box')", ) - target_query: Optional[str] = Field( + target_query: str | None = Field( None, description="Natural language description of where to place the object (e.g., 'on the table', 'in the basket'). If not provided, only pick operation will be performed.", ) @@ -204,7 +205,7 @@ class PickAndPlace(AbstractRobotSkill): "qwen2.5-vl-72b-instruct", description="Qwen model to use for visual queries" ) - def __init__(self, robot=None, **data): + def __init__(self, robot=None, **data) -> None: """ Initialize the PickAndPlace skill. @@ -214,7 +215,7 @@ def __init__(self, robot=None, **data): """ super().__init__(robot=robot, **data) - def _get_camera_frame(self) -> Optional[np.ndarray]: + def _get_camera_frame(self) -> np.ndarray | None: """ Get a single RGB frame from the robot's camera. @@ -237,7 +238,7 @@ def _get_camera_frame(self) -> Optional[np.ndarray]: def _query_pick_and_place_points( self, frame: np.ndarray - ) -> Optional[Tuple[Tuple[int, int], Tuple[int, int]]]: + ) -> tuple[tuple[int, int], tuple[int, int]] | None: """ Query Qwen to get both pick and place points in a single query. @@ -270,7 +271,7 @@ def _query_pick_and_place_points( def _query_single_point( self, frame: np.ndarray, query: str, point_type: str - ) -> Optional[Tuple[int, int]]: + ) -> tuple[int, int] | None: """ Query Qwen to get a single point location. @@ -315,7 +316,7 @@ def _query_single_point( logger.error(f"Error querying Qwen for {point_type} point: {e}") return None - def __call__(self) -> Dict[str, Any]: + def __call__(self) -> dict[str, Any]: """ Execute the pick and place operation. @@ -417,7 +418,7 @@ def __call__(self) -> Dict[str, Any]: logger.error(f"Error executing pick and place: {e}") return { "success": False, - "error": f"Execution error: {str(e)}", + "error": f"Execution error: {e!s}", "pick_point": pick_point, "place_point": place_point, } diff --git a/dimos/skills/manipulation/rotation_constraint_skill.py b/dimos/skills/manipulation/rotation_constraint_skill.py index a4973bf64d..ae1bdbb57d 100644 --- a/dimos/skills/manipulation/rotation_constraint_skill.py +++ b/dimos/skills/manipulation/rotation_constraint_skill.py @@ -12,13 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Dict, Any, Optional, Tuple, Literal +from typing import Literal + from pydantic import Field from dimos.skills.manipulation.abstract_manipulation_skill import AbstractManipulationSkill from dimos.types.manipulation import RotationConstraint -from dimos.utils.logging_config import setup_logger from dimos.types.vector import Vector +from dimos.utils.logging_config import setup_logger # Initialize logger logger = setup_logger("dimos.skills.rotation_constraint_skill") @@ -39,16 +40,16 @@ class RotationConstraintSkill(AbstractManipulationSkill): ) # Simple angle values for rotation (in degrees) - start_angle: Optional[float] = Field(None, description="Starting angle in degrees") - end_angle: Optional[float] = Field(None, description="Ending angle in degrees") + start_angle: float | None = Field(None, description="Starting angle in degrees") + end_angle: float | None = Field(None, description="Ending angle in degrees") # Pivot points as (x,y) tuples - pivot_point: Optional[Tuple[float, float]] = Field( + pivot_point: tuple[float, float] | None = Field( None, description="Pivot point (x,y) for rotation" ) # TODO: Secondary pivot point for more complex rotations - secondary_pivot_point: Optional[Tuple[float, float]] = Field( + secondary_pivot_point: tuple[float, float] | None = Field( None, description="Secondary pivot point (x,y) for double-pivot rotation" ) diff --git a/dimos/skills/manipulation/translation_constraint_skill.py b/dimos/skills/manipulation/translation_constraint_skill.py index 69c9f128e0..6e1808744f 100644 --- a/dimos/skills/manipulation/translation_constraint_skill.py +++ b/dimos/skills/manipulation/translation_constraint_skill.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, List, Tuple, Literal +from typing import Literal + from pydantic import Field from dimos.skills.manipulation.abstract_manipulation_skill import AbstractManipulationSkill -from dimos.skills.skills import AbstractRobotSkill from dimos.types.manipulation import TranslationConstraint, Vector from dimos.utils.logging_config import setup_logger @@ -37,19 +37,19 @@ class TranslationConstraintSkill(AbstractManipulationSkill): "x", description="Axis to translate along: 'x', 'y', or 'z'" ) - reference_point: Optional[Tuple[float, float]] = Field( + reference_point: tuple[float, float] | None = Field( None, description="Reference point (x,y) on the target object for translation constraining" ) - bounds_min: Optional[Tuple[float, float]] = Field( + bounds_min: tuple[float, float] | None = Field( None, description="Minimum bounds (x,y) for bounded translation" ) - bounds_max: Optional[Tuple[float, float]] = Field( + bounds_max: tuple[float, float] | None = Field( None, description="Maximum bounds (x,y) for bounded translation" ) - target_point: Optional[Tuple[float, float]] = Field( + target_point: tuple[float, float] | None = Field( None, description="Final target position (x,y) for translation constraining" ) diff --git a/dimos/skills/rest/rest.py b/dimos/skills/rest/rest.py index 3e7c7426cc..a8b5adfeb9 100644 --- a/dimos/skills/rest/rest.py +++ b/dimos/skills/rest/rest.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import logging + +from pydantic import Field import requests + from dimos.skills.skills import AbstractSkill -from pydantic import Field -import logging logger = logging.getLogger(__name__) diff --git a/dimos/skills/skills.py b/dimos/skills/skills.py index cb9f979281..6eabab5dad 100644 --- a/dimos/skills/skills.py +++ b/dimos/skills/skills.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Iterator import logging -from typing import Any, Optional -from pydantic import BaseModel +from typing import Any + from openai import pydantic_function_tool +from pydantic import BaseModel from dimos.types.constants import Colors @@ -30,14 +32,14 @@ class SkillLibrary: # ==== Flat Skill Library ==== - def __init__(self): - self.registered_skills: list["AbstractSkill"] = [] - self.class_skills: list["AbstractSkill"] = [] + def __init__(self) -> None: + self.registered_skills: list[AbstractSkill] = [] + self.class_skills: list[AbstractSkill] = [] self._running_skills = {} # {skill_name: (instance, subscription)} self.init() - def init(self): + def init(self) -> None: # Collect all skills from the parent class and update self.skills self.refresh_class_skills() @@ -74,7 +76,7 @@ def get_class_skills(self) -> list["AbstractSkill"]: return skills - def refresh_class_skills(self): + def refresh_class_skills(self) -> None: self.class_skills = self.get_class_skills() def add(self, skill: "AbstractSkill") -> None: @@ -93,7 +95,7 @@ def remove(self, skill: "AbstractSkill") -> None: def clear(self) -> None: self.registered_skills.clear() - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self.registered_skills) def __len__(self) -> int: @@ -109,7 +111,7 @@ def __getitem__(self, index): _instances: dict[str, dict] = {} - def create_instance(self, name, **kwargs): + def create_instance(self, name: str, **kwargs) -> None: # Key based only on the name key = name @@ -117,7 +119,7 @@ def create_instance(self, name, **kwargs): # Instead of creating an instance, store the args for later use self._instances[key] = kwargs - def call(self, name, **args): + def call(self, name: str, **args): try: # Get the stored args if available; otherwise, use an empty dict stored_args = self._instances.get(name, {}) @@ -144,7 +146,7 @@ def call(self, name, **args): # Call the instance directly return instance() except Exception as e: - error_msg = f"Error executing skill '{name}': {str(e)}" + error_msg = f"Error executing skill '{name}': {e!s}" logger.error(error_msg) return error_msg @@ -158,7 +160,7 @@ def get_tools(self) -> Any: def get_list_of_skills_as_json(self, list_of_skills: list["AbstractSkill"]) -> list[str]: return list(map(pydantic_function_tool, list_of_skills)) - def register_running_skill(self, name: str, instance: Any, subscription=None): + def register_running_skill(self, name: str, instance: Any, subscription=None) -> None: """ Register a running skill with its subscription. @@ -171,7 +173,7 @@ def register_running_skill(self, name: str, instance: Any, subscription=None): self._running_skills[name] = (instance, subscription) logger.info(f"Registered running skill: {name}") - def unregister_running_skill(self, name: str): + def unregister_running_skill(self, name: str) -> bool: """ Unregister a running skill. @@ -214,7 +216,7 @@ def terminate_skill(self, name: str): try: # Call the stop method if it exists if hasattr(instance, "stop") and callable(instance.stop): - result = instance.stop() + instance.stop() logger.info(f"Stopped skill: {name}") else: logger.warning(f"Skill {name} does not have a stop method") @@ -250,7 +252,7 @@ def terminate_skill(self, name: str): class AbstractSkill(BaseModel): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: print("Initializing AbstractSkill Class") super().__init__(*args, **kwargs) self._instances = {} @@ -260,7 +262,9 @@ def __init__(self, *args, **kwargs): def clone(self) -> "AbstractSkill": return AbstractSkill() - def register_as_running(self, name: str, skill_library: SkillLibrary, subscription=None): + def register_as_running( + self, name: str, skill_library: SkillLibrary, subscription=None + ) -> None: """ Register this skill as running in the skill library. @@ -271,7 +275,7 @@ def register_as_running(self, name: str, skill_library: SkillLibrary, subscripti """ skill_library.register_running_skill(name, self, subscription) - def unregister_as_running(self, name: str, skill_library: SkillLibrary): + def unregister_as_running(self, name: str, skill_library: SkillLibrary) -> None: """ Unregister this skill from the skill library. @@ -306,7 +310,7 @@ def get_list_of_skills_as_json(self, list_of_skills: list["AbstractSkill"]) -> l class AbstractRobotSkill(AbstractSkill): _robot: Robot = None - def __init__(self, *args, robot: Optional[Robot] = None, **kwargs): + def __init__(self, *args, robot: Robot | None = None, **kwargs) -> None: super().__init__(*args, **kwargs) self._robot = robot print( diff --git a/dimos/skills/speak.py b/dimos/skills/speak.py index e73b9e792a..a1e3abb078 100644 --- a/dimos/skills/speak.py +++ b/dimos/skills/speak.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.skills.skills import AbstractSkill +import queue +import threading +import time +from typing import Any + from pydantic import Field from reactivex import Subject -from typing import Optional, Any, List -import time -import threading -import queue + +from dimos.skills.skills import AbstractSkill from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.skills.speak") @@ -32,7 +34,7 @@ _queue_running = False -def _process_audio_queue(): +def _process_audio_queue() -> None: """Background thread to process audio requests sequentially""" global _queue_running @@ -55,7 +57,7 @@ def _process_audio_queue(): # Continue processing other tasks -def start_audio_queue_processor(): +def start_audio_queue_processor() -> None: """Start the background thread for processing audio requests""" global _queue_processor_thread, _queue_running @@ -77,12 +79,12 @@ class Speak(AbstractSkill): text: str = Field(..., description="Text to speak") - def __init__(self, tts_node: Optional[Any] = None, **data): + def __init__(self, tts_node: Any | None = None, **data) -> None: super().__init__(**data) self._tts_node = tts_node self._audio_complete = threading.Event() self._subscription = None - self._subscriptions: List = [] # Track all subscriptions + self._subscriptions: list = [] # Track all subscriptions def __call__(self): if not self._tts_node: @@ -93,7 +95,7 @@ def __call__(self): result_queue = queue.Queue(1) # Define the speech task to run in the audio queue - def speak_task(): + def speak_task() -> None: try: # Using a lock to ensure exclusive access to audio device with _audio_device_lock: @@ -102,12 +104,12 @@ def speak_task(): self._subscriptions = [] # This function will be called when audio processing is complete - def on_complete(): + def on_complete() -> None: logger.info(f"TTS audio playback completed for: {self.text}") self._audio_complete.set() # This function will be called if there's an error - def on_error(error): + def on_error(error) -> None: logger.error(f"Error in TTS processing: {error}") self._audio_complete.set() @@ -147,7 +149,7 @@ def on_error(error): result_queue.put(f"Spoke: {self.text} successfully") except Exception as e: logger.error(f"Error in speak task: {e}") - result_queue.put(f"Error speaking text: {str(e)}") + result_queue.put(f"Error speaking text: {e!s}") # Add our speech task to the global queue for sequential processing display_text = self.text[:50] + "..." if len(self.text) > 50 else self.text diff --git a/dimos/skills/unitree/unitree_speak.py b/dimos/skills/unitree/unitree_speak.py index f06666c30a..539ca0cd29 100644 --- a/dimos/skills/unitree/unitree_speak.py +++ b/dimos/skills/unitree/unitree_speak.py @@ -12,19 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.skills.skills import AbstractRobotSkill -from pydantic import Field -import time -import tempfile -import os -import json import base64 import hashlib -import soundfile as sf +import json +import os +import tempfile +import time + +from go2_webrtc_driver.constants import RTC_TOPIC import numpy as np from openai import OpenAI +from pydantic import Field +import soundfile as sf + +from dimos.skills.skills import AbstractRobotSkill from dimos.utils.logging_config import setup_logger -from go2_webrtc_driver.constants import RTC_TOPIC logger = setup_logger("dimos.skills.unitree.unitree_speak") @@ -56,7 +58,7 @@ class UnitreeSpeak(AbstractRobotSkill): default=False, description="Use megaphone mode for lower latency (experimental)" ) - def __init__(self, **data): + def __init__(self, **data) -> None: super().__init__(**data) self._openai_client = None @@ -76,7 +78,7 @@ def _generate_audio(self, text: str) -> bytes: logger.error(f"Error generating audio: {e}") raise - def _webrtc_request(self, api_id: int, parameter: dict = None): + def _webrtc_request(self, api_id: int, parameter: dict | None = None): if parameter is None: parameter = {} @@ -109,7 +111,7 @@ def _upload_audio_to_robot(self, audio_data: bytes, filename: str) -> str: } logger.debug(f"Sending chunk {i}/{total_chunks}") - response = self._webrtc_request(AUDIO_API["UPLOAD_AUDIO_FILE"], parameter) + self._webrtc_request(AUDIO_API["UPLOAD_AUDIO_FILE"], parameter) logger.info(f"Audio upload completed for '{filename}'") @@ -146,7 +148,7 @@ def _play_audio_on_robot(self, uuid: str): logger.error(f"Error playing audio on robot: {e}") raise - def _stop_audio_playback(self): + def _stop_audio_playback(self) -> None: try: logger.debug("Stopping audio playback") self._webrtc_request(AUDIO_API["PAUSE"], {}) @@ -201,7 +203,7 @@ def _upload_and_play_megaphone(self, audio_data: bytes, duration: float): except Exception as e: logger.warning(f"Error exiting megaphone mode: {e}") - def __call__(self): + def __call__(self) -> str: super().__call__() if not self._robot: @@ -275,4 +277,4 @@ def __call__(self): except Exception as e: logger.error(f"Error in speak skill: {e}") - return f"Error speaking text: {str(e)}" + return f"Error speaking text: {e!s}" diff --git a/dimos/skills/visual_navigation_skills.py b/dimos/skills/visual_navigation_skills.py index 96e21eb92d..8064f28cc9 100644 --- a/dimos/skills/visual_navigation_skills.py +++ b/dimos/skills/visual_navigation_skills.py @@ -19,16 +19,16 @@ and navigating to specific objects using computer vision. """ -import time import logging import threading -from typing import Optional, Tuple +import time -from dimos.skills.skills import AbstractRobotSkill -from dimos.utils.logging_config import setup_logger -from dimos.perception.visual_servoing import VisualServoing from pydantic import Field + +from dimos.perception.visual_servoing import VisualServoing +from dimos.skills.skills import AbstractRobotSkill from dimos.types.vector import Vector +from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.skills.visual_navigation", level=logging.DEBUG) @@ -47,11 +47,11 @@ class FollowHuman(AbstractRobotSkill): 1.5, description="Desired distance to maintain from the person in meters" ) timeout: float = Field(20.0, description="Maximum time to follow the person in seconds") - point: Optional[Tuple[int, int]] = Field( + point: tuple[int, int] | None = Field( None, description="Optional point to start tracking (x,y pixel coordinates)" ) - def __init__(self, robot=None, **data): + def __init__(self, robot=None, **data) -> None: super().__init__(robot=robot, **data) self._stop_event = threading.Event() self._visual_servoing = None @@ -129,7 +129,7 @@ def __call__(self): self._visual_servoing.stop_tracking() self._visual_servoing = None - def stop(self): + def stop(self) -> bool: """ Stop the human following process. diff --git a/dimos/spec/__init__.py b/dimos/spec/__init__.py index 06b9b2243a..03c1024d12 100644 --- a/dimos/spec/__init__.py +++ b/dimos/spec/__init__.py @@ -4,12 +4,12 @@ from dimos.spec.perception import Camera, Image, Pointcloud __all__ = [ - "Image", "Camera", - "Pointcloud", "Global3DMap", - "GlobalMap", "GlobalCostmap", + "GlobalMap", + "Image", "LocalPlanner", "Nav", + "Pointcloud", ] diff --git a/dimos/spec/perception.py b/dimos/spec/perception.py index 774492106b..1d38285d3f 100644 --- a/dimos/spec/perception.py +++ b/dimos/spec/perception.py @@ -15,8 +15,7 @@ from typing import Protocol from dimos.core import Out -from dimos.msgs.sensor_msgs import CameraInfo, PointCloud2 -from dimos.msgs.sensor_msgs import Image as ImageMsg +from dimos.msgs.sensor_msgs import CameraInfo, Image as ImageMsg, PointCloud2 class Image(Protocol): diff --git a/dimos/stream/audio/base.py b/dimos/stream/audio/base.py index a22e6606d6..43c3c13dec 100644 --- a/dimos/stream/audio/base.py +++ b/dimos/stream/audio/base.py @@ -13,8 +13,9 @@ # limitations under the License. from abc import ABC, abstractmethod -from reactivex import Observable + import numpy as np +from reactivex import Observable class AbstractAudioEmitter(ABC): @@ -58,7 +59,9 @@ class AbstractAudioTransform(AbstractAudioConsumer, AbstractAudioEmitter): class AudioEvent: """Class to represent an audio frame event with metadata.""" - def __init__(self, data: np.ndarray, sample_rate: int, timestamp: float, channels: int = 1): + def __init__( + self, data: np.ndarray, sample_rate: int, timestamp: float, channels: int = 1 + ) -> None: """ Initialize an AudioEvent. diff --git a/dimos/stream/audio/node_key_recorder.py b/dimos/stream/audio/node_key_recorder.py index 6494dcbef9..5e918bae5c 100644 --- a/dimos/stream/audio/node_key_recorder.py +++ b/dimos/stream/audio/node_key_recorder.py @@ -13,17 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List -import numpy as np -import time -import threading -import sys import select +import sys +import threading +import time + +import numpy as np from reactivex import Observable -from reactivex.subject import Subject, ReplaySubject +from reactivex.subject import ReplaySubject, Subject from dimos.stream.audio.base import AbstractAudioTransform, AudioEvent - from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.audio.key_recorder") @@ -39,7 +38,7 @@ def __init__( self, max_recording_time: float = 120.0, always_subscribe: bool = False, - ): + ) -> None: """ Initialize KeyRecorder. @@ -113,7 +112,7 @@ def emit_recording(self) -> Observable: """ return self._recording_subject - def stop(self): + def stop(self) -> None: """Stop recording and clean up resources.""" logger.info("Stopping audio recorder") @@ -131,7 +130,7 @@ def stop(self): if self._input_thread.is_alive(): self._input_thread.join(1.0) - def _input_monitor(self): + def _input_monitor(self) -> None: """Monitor for key presses to toggle recording.""" logger.info("Press Enter to start/stop recording...") @@ -148,7 +147,7 @@ def _input_monitor(self): # Sleep a bit to reduce CPU usage time.sleep(0.1) - def _start_recording(self): + def _start_recording(self) -> None: """Start recording audio and subscribe to the audio source if not always subscribed.""" if not self._audio_observable: logger.error("Cannot start recording: No audio source has been set") @@ -168,7 +167,7 @@ def _start_recording(self): self._audio_buffer = [] logger.info("Recording... (press Enter to stop)") - def _stop_recording(self): + def _stop_recording(self) -> None: """Stop recording, unsubscribe from audio source if not always subscribed, and emit the combined audio event.""" self._is_recording = False recording_duration = time.time() - self._recording_start_time @@ -188,7 +187,7 @@ def _stop_recording(self): else: logger.warning("No audio was recorded") - def _process_audio_event(self, audio_event): + def _process_audio_event(self, audio_event) -> None: """Process incoming audio events.""" # Only buffer if recording @@ -212,7 +211,7 @@ def _process_audio_event(self, audio_event): logger.warning(f"Max recording time ({self.max_recording_time}s) reached") self._stop_recording() - def _combine_audio_events(self, audio_events: List[AudioEvent]) -> AudioEvent: + def _combine_audio_events(self, audio_events: list[AudioEvent]) -> AudioEvent: """Combine multiple audio events into a single event.""" if not audio_events: logger.warning("Attempted to combine empty audio events list") @@ -287,11 +286,11 @@ def _combine_audio_events(self, audio_events: List[AudioEvent]) -> AudioEvent: logger.warning("Failed to create valid combined audio event") return None - def _handle_error(self, error): + def _handle_error(self, error) -> None: """Handle errors from the observable.""" logger.error(f"Error in audio observable: {error}") - def _handle_completion(self): + def _handle_completion(self) -> None: """Handle completion of the observable.""" logger.info("Audio observable completed") self.stop() @@ -301,9 +300,9 @@ def _handle_completion(self): from dimos.stream.audio.node_microphone import ( SounddeviceAudioSource, ) + from dimos.stream.audio.node_normalizer import AudioNormalizer from dimos.stream.audio.node_output import SounddeviceAudioOutput from dimos.stream.audio.node_volume_monitor import monitor - from dimos.stream.audio.node_normalizer import AudioNormalizer from dimos.stream.audio.utils import keepalive # Create microphone source, recorder, and audio output diff --git a/dimos/stream/audio/node_microphone.py b/dimos/stream/audio/node_microphone.py index bdb9b32180..1f4bf13499 100644 --- a/dimos/stream/audio/node_microphone.py +++ b/dimos/stream/audio/node_microphone.py @@ -13,17 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.stream.audio.base import ( - AbstractAudioEmitter, - AudioEvent, -) +import time +from typing import Any import numpy as np -from typing import Optional, List, Dict, Any from reactivex import Observable, create, disposable -import time import sounddevice as sd +from dimos.stream.audio.base import ( + AbstractAudioEmitter, + AudioEvent, +) from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.audio.node_microphone") @@ -34,12 +34,12 @@ class SounddeviceAudioSource(AbstractAudioEmitter): def __init__( self, - device_index: Optional[int] = None, + device_index: int | None = None, sample_rate: int = 16000, channels: int = 1, block_size: int = 1024, dtype: np.dtype = np.float32, - ): + ) -> None: """ Initialize SounddeviceAudioSource. @@ -69,7 +69,7 @@ def emit_audio(self) -> Observable: def on_subscribe(observer, scheduler): # Callback function to process audio data - def audio_callback(indata, frames, time_info, status): + def audio_callback(indata, frames, time_info, status) -> None: if status: logger.warning(f"Audio callback status: {status}") @@ -106,7 +106,7 @@ def audio_callback(indata, frames, time_info, status): observer.on_error(e) # Return a disposable to clean up resources - def dispose(): + def dispose() -> None: logger.info("Stopping audio capture") self._running = False if self._stream: @@ -118,7 +118,7 @@ def dispose(): return create(on_subscribe) - def get_available_devices(self) -> List[Dict[str, Any]]: + def get_available_devices(self) -> list[dict[str, Any]]: """Get a list of available audio input devices.""" return sd.query_devices() diff --git a/dimos/stream/audio/node_normalizer.py b/dimos/stream/audio/node_normalizer.py index db9557a5b1..064fc3cf6c 100644 --- a/dimos/stream/audio/node_normalizer.py +++ b/dimos/stream/audio/node_normalizer.py @@ -13,21 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable +from collections.abc import Callable import numpy as np from reactivex import Observable, create, disposable -from dimos.utils.logging_config import setup_logger -from dimos.stream.audio.volume import ( - calculate_rms_volume, - calculate_peak_volume, -) from dimos.stream.audio.base import ( AbstractAudioTransform, AudioEvent, ) - +from dimos.stream.audio.volume import ( + calculate_peak_volume, + calculate_rms_volume, +) +from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.stream.audio.node_normalizer") @@ -48,7 +47,7 @@ def __init__( decay_factor: float = 0.999, adapt_speed: float = 0.05, volume_func: Callable[[np.ndarray], float] = calculate_peak_volume, - ): + ) -> None: """ Initialize AudioNormalizer. @@ -156,7 +155,7 @@ def on_subscribe(observer, scheduler): ) # Return a disposable to clean up resources - def dispose(): + def dispose() -> None: logger.info("Stopping audio normalizer") audio_subscription.dispose() @@ -167,12 +166,13 @@ def dispose(): if __name__ == "__main__": import sys + from dimos.stream.audio.node_microphone import ( SounddeviceAudioSource, ) + from dimos.stream.audio.node_output import SounddeviceAudioOutput from dimos.stream.audio.node_simulated import SimulatedAudioSource from dimos.stream.audio.node_volume_monitor import monitor - from dimos.stream.audio.node_output import SounddeviceAudioOutput from dimos.stream.audio.utils import keepalive # Parse command line arguments diff --git a/dimos/stream/audio/node_output.py b/dimos/stream/audio/node_output.py index ee2e2c5ec2..3dc93d3757 100644 --- a/dimos/stream/audio/node_output.py +++ b/dimos/stream/audio/node_output.py @@ -13,15 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, List, Dict, Any +from typing import Any + import numpy as np -import sounddevice as sd from reactivex import Observable +import sounddevice as sd -from dimos.utils.logging_config import setup_logger from dimos.stream.audio.base import ( AbstractAudioTransform, ) +from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.stream.audio.node_output") @@ -37,12 +38,12 @@ class SounddeviceAudioOutput(AbstractAudioTransform): def __init__( self, - device_index: Optional[int] = None, + device_index: int | None = None, sample_rate: int = 16000, channels: int = 1, block_size: int = 1024, dtype: np.dtype = np.float32, - ): + ) -> None: """ Initialize SounddeviceAudioOutput. @@ -118,7 +119,7 @@ def emit_audio(self) -> Observable: return self.audio_observable - def stop(self): + def stop(self) -> None: """Stop audio output and clean up resources.""" logger.info("Stopping audio output") self._running = False @@ -132,7 +133,7 @@ def stop(self): self._stream.close() self._stream = None - def _play_audio_event(self, audio_event): + def _play_audio_event(self, audio_event) -> None: """Play audio from an AudioEvent.""" if not self._running or not self._stream: return @@ -150,11 +151,11 @@ def _play_audio_event(self, audio_event): except Exception as e: logger.error(f"Error playing audio: {e}") - def _handle_error(self, error): + def _handle_error(self, error) -> None: """Handle errors from the observable.""" logger.error(f"Error in audio observable: {error}") - def _handle_completion(self): + def _handle_completion(self) -> None: """Handle completion of the observable.""" logger.info("Audio observable completed") self._running = False @@ -163,7 +164,7 @@ def _handle_completion(self): self._stream.close() self._stream = None - def get_available_devices(self) -> List[Dict[str, Any]]: + def get_available_devices(self) -> list[dict[str, Any]]: """Get a list of available audio output devices.""" return sd.query_devices() diff --git a/dimos/stream/audio/node_simulated.py b/dimos/stream/audio/node_simulated.py index c9aff9a32d..82de718ced 100644 --- a/dimos/stream/audio/node_simulated.py +++ b/dimos/stream/audio/node_simulated.py @@ -12,15 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import threading +import time + +import numpy as np +from reactivex import Observable, create, disposable + from dimos.stream.audio.abstract import ( AbstractAudioEmitter, AudioEvent, ) -import numpy as np -from reactivex import Observable, create, disposable -import threading -import time - from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.stream.audio.node_simulated") @@ -40,7 +41,7 @@ def __init__( modulation_rate: float = 0.5, # Modulation rate in Hz volume_oscillation: bool = True, # Enable sinusoidal volume changes volume_oscillation_rate: float = 0.2, # Volume oscillation rate in Hz - ): + ) -> None: """ Initialize SimulatedAudioSource. @@ -132,7 +133,7 @@ def _generate_sine_wave(self, time_points: np.ndarray) -> np.ndarray: return wave - def _audio_thread(self, observer, interval: float): + def _audio_thread(self, observer, interval: float) -> None: """Thread function for simulated audio generation.""" try: sample_index = 0 @@ -197,7 +198,7 @@ def on_subscribe(observer, scheduler): ) # Return a disposable to clean up - def dispose(): + def dispose() -> None: logger.info("Stopping simulated audio") self._running = False if self._thread and self._thread.is_alive(): @@ -209,9 +210,9 @@ def dispose(): if __name__ == "__main__": - from dimos.stream.audio.utils import keepalive - from dimos.stream.audio.node_volume_monitor import monitor from dimos.stream.audio.node_output import SounddeviceAudioOutput + from dimos.stream.audio.node_volume_monitor import monitor + from dimos.stream.audio.utils import keepalive source = SimulatedAudioSource() speaker = SounddeviceAudioOutput() diff --git a/dimos/stream/audio/node_volume_monitor.py b/dimos/stream/audio/node_volume_monitor.py index 6510667307..e1c5b226a4 100644 --- a/dimos/stream/audio/node_volume_monitor.py +++ b/dimos/stream/audio/node_volume_monitor.py @@ -13,10 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable +from collections.abc import Callable + from reactivex import Observable, create, disposable -from dimos.stream.audio.base import AudioEvent, AbstractAudioConsumer +from dimos.stream.audio.base import AbstractAudioConsumer, AudioEvent from dimos.stream.audio.text.base import AbstractTextEmitter from dimos.stream.audio.text.node_stdout import TextPrinterNode from dimos.stream.audio.volume import calculate_peak_volume @@ -35,7 +36,7 @@ def __init__( threshold: float = 0.01, bar_length: int = 50, volume_func: Callable = calculate_peak_volume, - ): + ) -> None: """ Initialize VolumeMonitorNode. @@ -101,7 +102,7 @@ def on_subscribe(observer, scheduler): logger.info(f"Starting volume monitor (method: {self.func_name})") # Subscribe to the audio source - def on_audio_event(event: AudioEvent): + def on_audio_event(event: AudioEvent) -> None: try: # Calculate volume volume = self.volume_func(event.data) @@ -123,7 +124,7 @@ def on_audio_event(event: AudioEvent): ) # Return a disposable to clean up resources - def dispose(): + def dispose() -> None: logger.info("Stopping volume monitor") subscription.dispose() @@ -167,8 +168,8 @@ def monitor( if __name__ == "__main__": - from utils import keepalive from audio.node_simulated import SimulatedAudioSource + from utils import keepalive # Use the monitor function to create and connect the nodes volume_monitor = monitor(SimulatedAudioSource().emit_audio()) diff --git a/dimos/stream/audio/pipelines.py b/dimos/stream/audio/pipelines.py index ee2ae43316..ceaeb80fac 100644 --- a/dimos/stream/audio/pipelines.py +++ b/dimos/stream/audio/pipelines.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dimos.stream.audio.node_key_recorder import KeyRecorder from dimos.stream.audio.node_microphone import SounddeviceAudioSource from dimos.stream.audio.node_normalizer import AudioNormalizer -from dimos.stream.audio.node_volume_monitor import monitor -from dimos.stream.audio.node_key_recorder import KeyRecorder from dimos.stream.audio.node_output import SounddeviceAudioOutput +from dimos.stream.audio.node_volume_monitor import monitor from dimos.stream.audio.stt.node_whisper import WhisperNode -from dimos.stream.audio.tts.node_openai import OpenAITTSNode, Voice from dimos.stream.audio.text.node_stdout import TextPrinterNode +from dimos.stream.audio.tts.node_openai import OpenAITTSNode, Voice def stt(): diff --git a/dimos/stream/audio/stt/node_whisper.py b/dimos/stream/audio/stt/node_whisper.py index b5d8cc8a7b..05ec5274c8 100644 --- a/dimos/stream/audio/stt/node_whisper.py +++ b/dimos/stream/audio/stt/node_whisper.py @@ -13,13 +13,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Any +from typing import Any + from reactivex import Observable, create, disposable import whisper from dimos.stream.audio.base import ( - AudioEvent, AbstractAudioConsumer, + AudioEvent, ) from dimos.stream.audio.text.base import AbstractTextEmitter from dimos.utils.logging_config import setup_logger @@ -35,8 +36,10 @@ class WhisperNode(AbstractAudioConsumer, AbstractTextEmitter): def __init__( self, model: str = "base", - modelopts: Dict[str, Any] = {"language": "en", "fp16": False}, - ): + modelopts: dict[str, Any] | None = None, + ) -> None: + if modelopts is None: + modelopts = {"language": "en", "fp16": False} self.audio_observable = None self.modelopts = modelopts self.model = whisper.load_model(model) @@ -68,7 +71,7 @@ def on_subscribe(observer, scheduler): logger.info("Starting Whisper transcription service") # Subscribe to the audio source - def on_audio_event(event: AudioEvent): + def on_audio_event(event: AudioEvent) -> None: try: result = self.model.transcribe(event.data.flatten(), **self.modelopts) observer.on_next(result["text"].strip()) @@ -84,7 +87,7 @@ def on_audio_event(event: AudioEvent): ) # Return a disposable to clean up resources - def dispose(): + def dispose() -> None: subscription.dispose() return disposable.Disposable(dispose) @@ -93,13 +96,13 @@ def dispose(): if __name__ == "__main__": + from dimos.stream.audio.node_key_recorder import KeyRecorder from dimos.stream.audio.node_microphone import ( SounddeviceAudioSource, ) + from dimos.stream.audio.node_normalizer import AudioNormalizer from dimos.stream.audio.node_output import SounddeviceAudioOutput from dimos.stream.audio.node_volume_monitor import monitor - from dimos.stream.audio.node_normalizer import AudioNormalizer - from dimos.stream.audio.node_key_recorder import KeyRecorder from dimos.stream.audio.text.node_stdout import TextPrinterNode from dimos.stream.audio.tts.node_openai import OpenAITTSNode from dimos.stream.audio.utils import keepalive diff --git a/dimos/stream/audio/text/base.py b/dimos/stream/audio/text/base.py index fc27bfa901..b7305c0bcc 100644 --- a/dimos/stream/audio/text/base.py +++ b/dimos/stream/audio/text/base.py @@ -13,6 +13,7 @@ # limitations under the License. from abc import ABC, abstractmethod + from reactivex import Observable diff --git a/dimos/stream/audio/text/node_stdout.py b/dimos/stream/audio/text/node_stdout.py index dea454d294..b0a5fd4ac8 100644 --- a/dimos/stream/audio/text/node_stdout.py +++ b/dimos/stream/audio/text/node_stdout.py @@ -14,6 +14,7 @@ # limitations under the License. from reactivex import Observable + from dimos.stream.audio.text.base import AbstractTextConsumer from dimos.utils.logging_config import setup_logger @@ -25,7 +26,7 @@ class TextPrinterNode(AbstractTextConsumer): A node that subscribes to a text observable and prints the text. """ - def __init__(self, prefix: str = "", suffix: str = "", end: str = "\n"): + def __init__(self, prefix: str = "", suffix: str = "", end: str = "\n") -> None: """ Initialize TextPrinterNode. @@ -72,6 +73,7 @@ def consume_text(self, text_observable: Observable) -> "AbstractTextConsumer": if __name__ == "__main__": import time + from reactivex import Subject # Create a simple text subject that we can push values to diff --git a/dimos/stream/audio/tts/node_openai.py b/dimos/stream/audio/tts/node_openai.py index f65e0d50e2..211b2b0246 100644 --- a/dimos/stream/audio/tts/node_openai.py +++ b/dimos/stream/audio/tts/node_openai.py @@ -13,21 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from enum import Enum +import io import threading import time -from enum import Enum -from typing import Optional + +from openai import OpenAI from reactivex import Observable, Subject -import io import soundfile as sf -from openai import OpenAI -from dimos.stream.audio.text.base import AbstractTextConsumer, AbstractTextEmitter from dimos.stream.audio.base import ( AbstractAudioEmitter, AudioEvent, ) - +from dimos.stream.audio.text.base import AbstractTextConsumer, AbstractTextEmitter from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.stream.audio.tts.openai") @@ -55,12 +54,12 @@ class OpenAITTSNode(AbstractTextConsumer, AbstractAudioEmitter, AbstractTextEmit def __init__( self, - api_key: Optional[str] = None, + api_key: str | None = None, voice: Voice = Voice.ECHO, model: str = "tts-1", buffer_size: int = 1024, speed: float = 1.0, - ): + ) -> None: """ Initialize OpenAITTSNode. @@ -219,10 +218,12 @@ def dispose(self) -> None: if __name__ == "__main__": import time - from dimos.stream.audio.utils import keepalive + from reactivex import Subject + from dimos.stream.audio.node_output import SounddeviceAudioOutput from dimos.stream.audio.text.node_stdout import TextPrinterNode + from dimos.stream.audio.utils import keepalive # Create a simple text subject that we can push values to text_subject = Subject() @@ -247,7 +248,7 @@ def dispose(self) -> None: print("Starting OpenAI TTS test...") print("-" * 60) - for i, message in enumerate(test_messages): + for _i, message in enumerate(test_messages): text_subject.on_next(message) keepalive() diff --git a/dimos/stream/audio/tts/node_pytts.py b/dimos/stream/audio/tts/node_pytts.py index 818371a0f1..f1543331ef 100644 --- a/dimos/stream/audio/tts/node_pytts.py +++ b/dimos/stream/audio/tts/node_pytts.py @@ -13,11 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from reactivex import Observable, Subject import pyttsx3 +from reactivex import Observable, Subject from dimos.stream.audio.text.abstract import AbstractTextTransform - from dimos.utils.logging_config import setup_logger logger = setup_logger(__name__) @@ -31,7 +30,7 @@ class PyTTSNode(AbstractTextTransform): text observables, allowing it to be inserted into a text processing pipeline. """ - def __init__(self, rate: int = 200, volume: float = 1.0): + def __init__(self, rate: int = 200, volume: float = 1.0) -> None: """ Initialize PyTTSNode. diff --git a/dimos/stream/audio/utils.py b/dimos/stream/audio/utils.py index 712086ffd6..1a2991467c 100644 --- a/dimos/stream/audio/utils.py +++ b/dimos/stream/audio/utils.py @@ -15,7 +15,7 @@ import time -def keepalive(): +def keepalive() -> None: try: # Keep the program running print("Press Ctrl+C to exit") diff --git a/dimos/stream/audio/volume.py b/dimos/stream/audio/volume.py index f2e50ab72c..bd137172b3 100644 --- a/dimos/stream/audio/volume.py +++ b/dimos/stream/audio/volume.py @@ -69,6 +69,7 @@ def calculate_peak_volume(audio_data: np.ndarray) -> float: if __name__ == "__main__": # Example usage import time + from .node_simulated import SimulatedAudioSource # Create a simulated audio source @@ -77,7 +78,7 @@ def calculate_peak_volume(audio_data: np.ndarray) -> float: # Create observable and subscribe to get a single frame audio_observable = audio_source.capture_audio_as_observable() - def process_frame(frame): + def process_frame(frame) -> None: # Calculate and print both RMS and peak volumes rms_vol = calculate_rms_volume(frame.data) peak_vol = calculate_peak_volume(frame.data) @@ -89,7 +90,7 @@ def process_frame(frame): # Set a flag to track when processing is complete processed = {"done": False} - def process_frame_wrapper(frame): + def process_frame_wrapper(frame) -> None: # Process the frame process_frame(frame) # Mark as processed diff --git a/dimos/stream/data_provider.py b/dimos/stream/data_provider.py index 73e1ba0f20..f931857fda 100644 --- a/dimos/stream/data_provider.py +++ b/dimos/stream/data_provider.py @@ -13,14 +13,13 @@ # limitations under the License. from abc import ABC -from reactivex import Subject, Observable -from reactivex.subject import Subject -from reactivex.scheduler import ThreadPoolScheduler -import multiprocessing import logging +import multiprocessing import reactivex as rx -from reactivex import operators as ops +from reactivex import Observable, Subject, operators as ops +from reactivex.scheduler import ThreadPoolScheduler +from reactivex.subject import Subject logging.basicConfig(level=logging.INFO) @@ -31,7 +30,7 @@ class AbstractDataProvider(ABC): """Abstract base class for data providers using ReactiveX.""" - def __init__(self, dev_name: str = "NA"): + def __init__(self, dev_name: str = "NA") -> None: self.dev_name = dev_name self._data_subject = Subject() # Regular Subject, no initial None value @@ -40,11 +39,11 @@ def data_stream(self) -> Observable: """Get the data stream observable.""" return self._data_subject - def push_data(self, data): + def push_data(self, data) -> None: """Push new data to the stream.""" self._data_subject.on_next(data) - def dispose(self): + def dispose(self) -> None: """Cleanup resources.""" self._data_subject.dispose() @@ -52,17 +51,17 @@ def dispose(self): class ROSDataProvider(AbstractDataProvider): """ReactiveX data provider for ROS topics.""" - def __init__(self, dev_name: str = "ros_provider"): + def __init__(self, dev_name: str = "ros_provider") -> None: super().__init__(dev_name) self.logger = logging.getLogger(dev_name) - def push_data(self, data): + def push_data(self, data) -> None: """Push new data to the stream.""" print(f"ROSDataProvider pushing data of type: {type(data)}") super().push_data(data) print("Data pushed to subject") - def capture_data_as_observable(self, fps: int = None) -> Observable: + def capture_data_as_observable(self, fps: int | None = None) -> Observable: """Get the data stream as an observable. Args: @@ -115,7 +114,7 @@ class QueryDataProvider(AbstractDataProvider): logger (logging.Logger): Logger instance for logging messages. """ - def __init__(self, dev_name: str = "query_provider"): + def __init__(self, dev_name: str = "query_provider") -> None: """ Initializes the QueryDataProvider. @@ -127,7 +126,7 @@ def __init__(self, dev_name: str = "query_provider"): def start_query_stream( self, - query_template: str = None, + query_template: str | None = None, frequency: float = 3.0, start_count: int = 0, end_count: int = 5000, diff --git a/dimos/stream/frame_processor.py b/dimos/stream/frame_processor.py index b07a09118b..fda13ece61 100644 --- a/dimos/stream/frame_processor.py +++ b/dimos/stream/frame_processor.py @@ -12,17 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os + import cv2 import numpy as np -import os -from reactivex import Observable -from reactivex import operators as ops -from typing import Tuple, Optional +from reactivex import Observable, operators as ops # TODO: Reorganize, filenaming - Consider merger with VideoOperators class class FrameProcessor: - def __init__(self, output_dir=f"{os.getcwd()}/assets/output/frames", delete_on_init=False): + def __init__( + self, output_dir: str = f"{os.getcwd()}/assets/output/frames", delete_on_init: bool = False + ) -> None: """Initializes the FrameProcessor. Sets up the output directory for frame storage and optionally cleans up @@ -65,10 +66,10 @@ def to_grayscale(self, frame): def edge_detection(self, frame): return cv2.Canny(frame, 100, 200) - def resize(self, frame, scale=0.5): + def resize(self, frame, scale: float = 0.5): return cv2.resize(frame, None, fx=scale, fy=scale, interpolation=cv2.INTER_AREA) - def export_to_jpeg(self, frame, save_limit=100, loop=False, suffix=""): + def export_to_jpeg(self, frame, save_limit: int = 100, loop: bool = False, suffix: str = ""): if frame is None: print("Error: Attempted to save a None image.") return None @@ -92,10 +93,10 @@ def export_to_jpeg(self, frame, save_limit=100, loop=False, suffix=""): def compute_optical_flow( self, - acc: Tuple[np.ndarray, np.ndarray, Optional[float]], + acc: tuple[np.ndarray, np.ndarray, float | None], current_frame: np.ndarray, compute_relevancy: bool = True, - ) -> Tuple[np.ndarray, np.ndarray, Optional[float]]: + ) -> tuple[np.ndarray, np.ndarray, float | None]: """Computes optical flow between consecutive frames. Uses the Farneback algorithm to compute dense optical flow between the @@ -121,7 +122,7 @@ def compute_optical_flow( ValueError: If input frames have invalid dimensions or types. TypeError: If acc is not a tuple of correct types. """ - prev_frame, prev_flow, prev_relevancy = acc + prev_frame, _prev_flow, _prev_relevancy = acc if prev_frame is None: return (current_frame, None, None) diff --git a/dimos/stream/ros_video_provider.py b/dimos/stream/ros_video_provider.py index 7ca6fa4aa7..5182ca79f8 100644 --- a/dimos/stream/ros_video_provider.py +++ b/dimos/stream/ros_video_provider.py @@ -18,13 +18,12 @@ and makes them available as an Observable stream. """ -from reactivex import Subject, Observable -from reactivex import operators as ops -from reactivex.scheduler import ThreadPoolScheduler import logging import time -from typing import Optional + import numpy as np +from reactivex import Observable, Subject, operators as ops +from reactivex.scheduler import ThreadPoolScheduler from dimos.stream.video_provider import AbstractVideoProvider @@ -44,8 +43,8 @@ class ROSVideoProvider(AbstractVideoProvider): """ def __init__( - self, dev_name: str = "ros_video", pool_scheduler: Optional[ThreadPoolScheduler] = None - ): + self, dev_name: str = "ros_video", pool_scheduler: ThreadPoolScheduler | None = None + ) -> None: """Initialize the ROS video provider. Args: diff --git a/dimos/stream/rtsp_video_provider.py b/dimos/stream/rtsp_video_provider.py index 5926c4f676..3aeb651a4d 100644 --- a/dimos/stream/rtsp_video_provider.py +++ b/dimos/stream/rtsp_video_provider.py @@ -17,7 +17,6 @@ import subprocess import threading import time -from typing import Optional import ffmpeg # ffmpeg-python wrapper import numpy as np @@ -44,7 +43,7 @@ class RtspVideoProvider(AbstractVideoProvider): """ def __init__( - self, dev_name: str, rtsp_url: str, pool_scheduler: Optional[ThreadPoolScheduler] = None + self, dev_name: str, rtsp_url: str, pool_scheduler: ThreadPoolScheduler | None = None ) -> None: """Initializes the RTSP video provider. @@ -56,7 +55,7 @@ def __init__( super().__init__(dev_name, pool_scheduler) self.rtsp_url = rtsp_url # Holds the currently active ffmpeg process Popen object - self._ffmpeg_process: Optional[subprocess.Popen] = None + self._ffmpeg_process: subprocess.Popen | None = None # Lock to protect access to the ffmpeg process object self._lock = threading.Lock() @@ -170,11 +169,11 @@ def capture_video_as_observable(self, fps: int = 0) -> Observable: def emit_frames(observer, scheduler): """Function executed by rx.create to emit frames.""" - process: Optional[subprocess.Popen] = None + process: subprocess.Popen | None = None # Event to signal the processing loop should stop (e.g., on dispose) should_stop = threading.Event() - def cleanup_process(): + def cleanup_process() -> None: """Safely terminate the ffmpeg process if it's running.""" nonlocal process logger.debug(f"({self.dev_name}) Cleanup requested.") diff --git a/dimos/stream/stream_merger.py b/dimos/stream/stream_merger.py index 6f854b2d80..b59c78fa96 100644 --- a/dimos/stream/stream_merger.py +++ b/dimos/stream/stream_merger.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, TypeVar, Tuple -from reactivex import Observable -from reactivex import operators as ops +from typing import TypeVar + +from reactivex import Observable, operators as ops T = TypeVar("T") Q = TypeVar("Q") @@ -22,7 +22,7 @@ def create_stream_merger( data_input_stream: Observable[T], text_query_stream: Observable[Q] -) -> Observable[Tuple[Q, List[T]]]: +) -> Observable[tuple[Q, list[T]]]: """ Creates a merged stream that combines the latest value from data_input_stream with each value from text_query_stream. diff --git a/dimos/stream/video_operators.py b/dimos/stream/video_operators.py index 78ba7518a1..d7299f3dce 100644 --- a/dimos/stream/video_operators.py +++ b/dimos/stream/video_operators.py @@ -12,18 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import base64 +from collections.abc import Callable from datetime import datetime, timedelta +from enum import Enum +from typing import TYPE_CHECKING, Any + import cv2 import numpy as np -from reactivex import Observable, Observer, create -from reactivex import operators as ops -from typing import Any, Callable, Tuple, Optional - +from reactivex import Observable, Observer, create, operators as ops import zmq -import base64 -from enum import Enum -from dimos.stream.frame_processor import FrameProcessor +if TYPE_CHECKING: + from dimos.stream.frame_processor import FrameProcessor class VideoOperators: @@ -31,7 +32,7 @@ class VideoOperators: @staticmethod def with_fps_sampling( - fps: int = 25, *, sample_interval: Optional[timedelta] = None, use_latest: bool = True + fps: int = 25, *, sample_interval: timedelta | None = None, use_latest: bool = True ) -> Callable[[Observable], Observable]: """Creates an operator that samples frames at a specified rate. @@ -214,9 +215,9 @@ def with_optical_flow( @staticmethod def with_zmq_socket( - socket: zmq.Socket, scheduler: Optional[Any] = None + socket: zmq.Socket, scheduler: Any | None = None ) -> Callable[[Observable], Observable]: - def send_frame(frame, socket): + def send_frame(frame, socket) -> None: _, img_encoded = cv2.imencode(".jpg", frame) socket.send(img_encoded.tobytes()) # print(f"Frame received: {frame.shape}") @@ -243,7 +244,7 @@ def encode_image() -> Callable[[Observable], Observable]: """ def _operator(source: Observable) -> Observable: - def _encode_image(image: np.ndarray) -> Tuple[str, Tuple[int, int]]: + def _encode_image(image: np.ndarray) -> tuple[str, tuple[int, int]]: try: width, height = image.shape[:2] _, buffer = cv2.imencode(".jpg", image) @@ -259,10 +260,11 @@ def _encode_image(image: np.ndarray) -> Tuple[str, Tuple[int, int]]: return _operator -from reactivex.disposable import Disposable -from reactivex import Observable from threading import Lock +from reactivex import Observable +from reactivex.disposable import Disposable + class Operators: @staticmethod @@ -282,13 +284,13 @@ def _subscribe(observer, scheduler=None): upstream_disp = None active_inner_disp = None - def dispose_all(): + def dispose_all() -> None: if upstream_disp: upstream_disp.dispose() if active_inner_disp: active_inner_disp.dispose() - def on_next(value): + def on_next(value) -> None: nonlocal in_flight, active_inner_disp lock.acquire() try: @@ -308,16 +310,16 @@ def on_next(value): observer.on_error(ex) return - def inner_on_next(ivalue): + def inner_on_next(ivalue) -> None: observer.on_next(ivalue) - def inner_on_error(err): + def inner_on_error(err) -> None: nonlocal in_flight with lock: in_flight = False observer.on_error(err) - def inner_on_completed(): + def inner_on_completed() -> None: nonlocal in_flight with lock: in_flight = False @@ -333,11 +335,11 @@ def inner_on_completed(): scheduler=scheduler, ) - def on_error(err): + def on_error(err) -> None: dispose_all() observer.on_error(err) - def on_completed(): + def on_completed() -> None: nonlocal upstream_done with lock: upstream_done = True @@ -370,13 +372,13 @@ def _subscribe(observer, scheduler=None): upstream_disp = None active_inner_disp = None - def dispose_all(): + def dispose_all() -> None: if upstream_disp: upstream_disp.dispose() if active_inner_disp: active_inner_disp.dispose() - def on_next(value): + def on_next(value) -> None: nonlocal in_flight, active_inner_disp with lock: # If not busy, claim the slot @@ -395,17 +397,17 @@ def on_next(value): observer.on_error(ex) return - def inner_on_next(ivalue): + def inner_on_next(ivalue) -> None: observer.on_next(ivalue) - def inner_on_error(err): + def inner_on_error(err) -> None: nonlocal in_flight with lock: in_flight = False print("\033[34mError in inner on error.\033[0m") observer.on_error(err) - def inner_on_completed(): + def inner_on_completed() -> None: nonlocal in_flight with lock: in_flight = False @@ -422,11 +424,11 @@ def inner_on_completed(): scheduler=scheduler, ) - def on_error(e): + def on_error(e) -> None: dispose_all() observer.on_error(e) - def on_completed(): + def on_completed() -> None: nonlocal upstream_done with lock: upstream_done = True @@ -453,7 +455,7 @@ def _exhaust_map(source: Observable): def subscribe(observer, scheduler=None): is_processing = False - def on_next(item): + def on_next(item) -> None: nonlocal is_processing if not is_processing: is_processing = True @@ -471,7 +473,7 @@ def on_next(item): else: print("\033[35mSkipping item, already processing.\033[0m") - def set_not_processing(): + def set_not_processing() -> None: nonlocal is_processing is_processing = False print("\033[35mItem processed.\033[0m") @@ -491,7 +493,7 @@ def set_not_processing(): def with_lock(lock: Lock): def operator(source: Observable): def subscribe(observer, scheduler=None): - def on_next(item): + def on_next(item) -> None: if not lock.locked(): # Check if the lock is free if lock.acquire(blocking=False): # Non-blocking acquire try: @@ -504,10 +506,10 @@ def on_next(item): else: print("\033[34mLock busy, skipping item.\033[0m") - def on_error(error): + def on_error(error) -> None: observer.on_error(error) - def on_completed(): + def on_completed() -> None: observer.on_completed() return source.subscribe( @@ -525,7 +527,7 @@ def on_completed(): def with_lock_check(lock: Lock): # Renamed for clarity def operator(source: Observable): def subscribe(observer, scheduler=None): - def on_next(item): + def on_next(item) -> None: if not lock.locked(): # Check if the lock is held WITHOUT acquiring print(f"\033[32mLock is free, processing item: {item}\033[0m") observer.on_next(item) @@ -533,10 +535,10 @@ def on_next(item): print(f"\033[34mLock is busy, skipping item: {item}\033[0m") # observer.on_completed() - def on_error(error): + def on_error(error) -> None: observer.on_error(error) - def on_completed(): + def on_completed() -> None: observer.on_completed() return source.subscribe( @@ -565,7 +567,7 @@ class PrintColor(Enum): def print_emission( id: str, dev_name: str = "NA", - counts: dict = None, + counts: dict | None = None, color: "Operators.PrintColor" = None, enabled: bool = True, ): @@ -591,7 +593,7 @@ def print_emission( def _operator(source: Observable) -> Observable: def _subscribe(observer: Observer, scheduler=None): - def on_next(value): + def on_next(value) -> None: if counts is not None: # Initialize count if necessary if id not in counts: diff --git a/dimos/stream/video_provider.py b/dimos/stream/video_provider.py index 050905a024..0b7e815ae2 100644 --- a/dimos/stream/video_provider.py +++ b/dimos/stream/video_provider.py @@ -20,12 +20,11 @@ """ # Standard library imports +from abc import ABC, abstractmethod import logging import os -import time -from abc import ABC, abstractmethod from threading import Lock -from typing import Optional +import time # Third-party imports import cv2 @@ -60,7 +59,7 @@ class AbstractVideoProvider(ABC): """Abstract base class for video providers managing video capture resources.""" def __init__( - self, dev_name: str = "NA", pool_scheduler: Optional[ThreadPoolScheduler] = None + self, dev_name: str = "NA", pool_scheduler: ThreadPoolScheduler | None = None ) -> None: """Initializes the video provider with a device name. @@ -108,7 +107,7 @@ def __init__( self, dev_name: str, video_source: str = f"{os.getcwd()}/assets/video-f30-480p.mp4", - pool_scheduler: Optional[ThreadPoolScheduler] = None, + pool_scheduler: ThreadPoolScheduler | None = None, ) -> None: """Initializes the video provider with a device name and video source. @@ -163,7 +162,7 @@ def capture_video_as_observable(self, realtime: bool = True, fps: int = 30) -> O VideoFrameError: If frames cannot be read properly. """ - def emit_frames(observer, scheduler): + def emit_frames(observer, scheduler) -> None: try: self._initialize_capture() diff --git a/dimos/stream/video_providers/unitree.py b/dimos/stream/video_providers/unitree.py index e1a7587146..ba28cb1d6f 100644 --- a/dimos/stream/video_providers/unitree.py +++ b/dimos/stream/video_providers/unitree.py @@ -12,26 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.stream.video_provider import AbstractVideoProvider - -from queue import Queue -from go2_webrtc_driver.webrtc_driver import Go2WebRTCConnection, WebRTCConnectionMethod -from aiortc import MediaStreamTrack import asyncio -from reactivex import Observable, create, operators as ops import logging +from queue import Queue import threading import time +from aiortc import MediaStreamTrack +from go2_webrtc_driver.webrtc_driver import Go2WebRTCConnection, WebRTCConnectionMethod +from reactivex import Observable, create, operators as ops + +from dimos.stream.video_provider import AbstractVideoProvider + class UnitreeVideoProvider(AbstractVideoProvider): def __init__( self, dev_name: str = "UnitreeGo2", connection_method: WebRTCConnectionMethod = WebRTCConnectionMethod.LocalSTA, - serial_number: str = None, - ip: str = None, - ): + serial_number: str | None = None, + ip: str | None = None, + ) -> None: """Initialize the Unitree video stream with WebRTC connection. Args: @@ -60,7 +61,7 @@ def __init__( else: raise ValueError("Unsupported connection method") - async def _recv_camera_stream(self, track: MediaStreamTrack): + async def _recv_camera_stream(self, track: MediaStreamTrack) -> None: """Receive video frames from WebRTC and put them in the queue.""" while True: frame = await track.recv() @@ -68,7 +69,7 @@ async def _recv_camera_stream(self, track: MediaStreamTrack): img = frame.to_ndarray(format="bgr24") self.frame_queue.put(img) - def _run_asyncio_loop(self, loop): + def _run_asyncio_loop(self, loop) -> None: """Run the asyncio event loop in a separate thread.""" asyncio.set_event_loop(loop) @@ -115,7 +116,7 @@ def capture_video_as_observable(self, fps: int = 30) -> Observable: """ frame_interval = 1.0 / fps - def emit_frames(observer, scheduler): + def emit_frames(observer, scheduler) -> None: try: # Start asyncio loop if not already running if not self.loop: @@ -158,7 +159,7 @@ def emit_frames(observer, scheduler): ops.share() # Share the stream among multiple subscribers ) - def dispose_all(self): + def dispose_all(self) -> None: """Clean up resources.""" if self.loop: self.loop.call_soon_threadsafe(self.loop.stop) diff --git a/dimos/stream/videostream.py b/dimos/stream/videostream.py index ee63261ae6..9c99ddea3a 100644 --- a/dimos/stream/videostream.py +++ b/dimos/stream/videostream.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Iterator + import cv2 class VideoStream: - def __init__(self, source=0): + def __init__(self, source: int = 0) -> None: """ Initialize the video stream from a camera source. @@ -27,7 +29,7 @@ def __init__(self, source=0): if not self.capture.isOpened(): raise ValueError(f"Unable to open video source {source}") - def __iter__(self): + def __iter__(self) -> Iterator: return self def __next__(self): @@ -37,5 +39,5 @@ def __next__(self): raise StopIteration return frame - def release(self): + def release(self) -> None: self.capture.release() diff --git a/dimos/types/label.py b/dimos/types/label.py index ce037aed7a..83b91c8152 100644 --- a/dimos/types/label.py +++ b/dimos/types/label.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, Any +from typing import Any class LabelType: - def __init__(self, labels: Dict[str, Any], metadata: Any = None): + def __init__(self, labels: dict[str, Any], metadata: Any = None) -> None: """ Initializes a standardized label type. @@ -31,7 +31,7 @@ def get_label_descriptions(self): """Return a list of label descriptions.""" return [desc["description"] for desc in self.labels.values()] - def save_to_json(self, filepath: str): + def save_to_json(self, filepath: str) -> None: """Save the labels to a JSON file.""" import json diff --git a/dimos/types/manipulation.py b/dimos/types/manipulation.py index fee4e69ebb..0df62362a4 100644 --- a/dimos/types/manipulation.py +++ b/dimos/types/manipulation.py @@ -12,13 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from abc import ABC +from dataclasses import dataclass, field from enum import Enum -from typing import Dict, List, Optional, Any, Union, TypedDict, Tuple, Literal, TYPE_CHECKING -from dataclasses import dataclass, field, fields -from abc import ABC, abstractmethod +import time +from typing import TYPE_CHECKING, Any, Literal, TypedDict import uuid + import numpy as np -import time + from dimos.types.vector import Vector if TYPE_CHECKING: @@ -46,10 +48,10 @@ class TranslationConstraint(AbstractConstraint): """Constraint parameters for translational movement along a single axis.""" translation_axis: Literal["x", "y", "z"] = None # Axis to translate along - reference_point: Optional[Vector] = None - bounds_min: Optional[Vector] = None # For bounded translation - bounds_max: Optional[Vector] = None # For bounded translation - target_point: Optional[Vector] = None # For relative positioning + reference_point: Vector | None = None + bounds_min: Vector | None = None # For bounded translation + bounds_max: Vector | None = None # For bounded translation + target_point: Vector | None = None # For relative positioning @dataclass @@ -57,10 +59,10 @@ class RotationConstraint(AbstractConstraint): """Constraint parameters for rotational movement around a single axis.""" rotation_axis: Literal["roll", "pitch", "yaw"] = None # Axis to rotate around - start_angle: Optional[Vector] = None # Angle values applied to the specified rotation axis - end_angle: Optional[Vector] = None # Angle values applied to the specified rotation axis - pivot_point: Optional[Vector] = None # Point of rotation - secondary_pivot_point: Optional[Vector] = None # For double point rotations + start_angle: Vector | None = None # Angle values applied to the specified rotation axis + end_angle: Vector | None = None # Angle values applied to the specified rotation axis + pivot_point: Vector | None = None # Point of rotation + secondary_pivot_point: Vector | None = None # For double point rotations @dataclass @@ -69,7 +71,7 @@ class ForceConstraint(AbstractConstraint): max_force: float = 0.0 # Maximum force in newtons min_force: float = 0.0 # Minimum force in newtons - force_direction: Optional[Vector] = None # Direction of force application + force_direction: Vector | None = None # Direction of force application class ObjectData(TypedDict, total=False): @@ -77,7 +79,7 @@ class ObjectData(TypedDict, total=False): # Basic detection information object_id: int # Unique ID for the object - bbox: List[float] # Bounding box [x1, y1, x2, y2] + bbox: list[float] # Bounding box [x1, y1, x2, y2] depth: float # Depth in meters from Metric3d confidence: float # Detection confidence class_id: int # Class ID from the detector @@ -86,9 +88,9 @@ class ObjectData(TypedDict, total=False): segmentation_mask: np.ndarray # Binary mask of the object's pixels # 3D pose and dimensions - position: Union[Dict[str, float], Vector] # 3D position {x, y, z} or Vector - rotation: Union[Dict[str, float], Vector] # 3D rotation {roll, pitch, yaw} or Vector - size: Dict[str, float] # Object dimensions {width, height, depth} + position: dict[str, float] | Vector # 3D position {x, y, z} or Vector + rotation: dict[str, float] | Vector # 3D rotation {roll, pitch, yaw} or Vector + size: dict[str, float] # Object dimensions {width, height, depth} # Point cloud data point_cloud: "o3d.geometry.PointCloud" # Open3D point cloud object @@ -100,21 +102,21 @@ class ManipulationMetadata(TypedDict, total=False): """Typed metadata for manipulation constraints.""" timestamp: float - objects: Dict[str, ObjectData] + objects: dict[str, ObjectData] @dataclass class ManipulationTaskConstraint: """Set of constraints for a specific manipulation action.""" - constraints: List[AbstractConstraint] = field(default_factory=list) + constraints: list[AbstractConstraint] = field(default_factory=list) - def add_constraint(self, constraint: AbstractConstraint): + def add_constraint(self, constraint: AbstractConstraint) -> None: """Add a constraint to this set.""" if constraint not in self.constraints: self.constraints.append(constraint) - def get_constraints(self) -> List[AbstractConstraint]: + def get_constraints(self) -> list[AbstractConstraint]: """Get all constraints in this set.""" return self.constraints @@ -125,18 +127,18 @@ class ManipulationTask: description: str target_object: str # Semantic label of target object - target_point: Optional[Tuple[float, float]] = ( + target_point: tuple[float, float] | None = ( None # (X,Y) point in pixel-space of the point to manipulate on target object ) metadata: ManipulationMetadata = field(default_factory=dict) timestamp: float = field(default_factory=time.time) task_id: str = "" - result: Optional[Dict[str, Any]] = None # Any result data from the task execution - constraints: Union[List[AbstractConstraint], ManipulationTaskConstraint, AbstractConstraint] = ( - field(default_factory=list) + result: dict[str, Any] | None = None # Any result data from the task execution + constraints: list[AbstractConstraint] | ManipulationTaskConstraint | AbstractConstraint = field( + default_factory=list ) - def add_constraint(self, constraint: AbstractConstraint): + def add_constraint(self, constraint: AbstractConstraint) -> None: """Add a constraint to this manipulation task.""" # If constraints is a ManipulationTaskConstraint object if isinstance(self.constraints, ManipulationTaskConstraint): @@ -152,7 +154,7 @@ def add_constraint(self, constraint: AbstractConstraint): # This will also handle empty lists (the default case) self.constraints.append(constraint) - def get_constraints(self) -> List[AbstractConstraint]: + def get_constraints(self) -> list[AbstractConstraint]: """Get all constraints in this manipulation task.""" # If constraints is a ManipulationTaskConstraint object if isinstance(self.constraints, ManipulationTaskConstraint): diff --git a/dimos/types/robot_location.py b/dimos/types/robot_location.py index 54211b72f4..59a780daf5 100644 --- a/dimos/types/robot_location.py +++ b/dimos/types/robot_location.py @@ -17,8 +17,8 @@ """ from dataclasses import dataclass, field -from typing import Dict, Any, Optional, Tuple import time +from typing import Any import uuid @@ -41,14 +41,14 @@ class RobotLocation: """ name: str - position: Tuple[float, float, float] - rotation: Tuple[float, float, float] - frame_id: Optional[str] = None + position: tuple[float, float, float] + rotation: tuple[float, float, float] + frame_id: str | None = None timestamp: float = field(default_factory=time.time) location_id: str = field(default_factory=lambda: f"loc_{uuid.uuid4().hex[:8]}") - metadata: Dict[str, Any] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) - def __post_init__(self): + def __post_init__(self) -> None: """Validate and normalize the position and rotation tuples.""" # Ensure position is a tuple of 3 floats if len(self.position) == 2: @@ -62,7 +62,7 @@ def __post_init__(self): else: self.rotation = tuple(float(x) for x in self.rotation) - def to_vector_metadata(self) -> Dict[str, Any]: + def to_vector_metadata(self) -> dict[str, Any]: """ Convert the location to metadata format for storing in a vector database. @@ -89,7 +89,7 @@ def to_vector_metadata(self) -> Dict[str, Any]: return metadata @classmethod - def from_vector_metadata(cls, metadata: Dict[str, Any]) -> "RobotLocation": + def from_vector_metadata(cls, metadata: dict[str, Any]) -> "RobotLocation": """ Create a RobotLocation object from vector database metadata. @@ -134,5 +134,5 @@ def from_vector_metadata(cls, metadata: Dict[str, Any]) -> "RobotLocation": }, ) - def __str__(self): + def __str__(self) -> str: return f"[RobotPosition name:{self.name} pos:{self.position} rot:{self.rotation})]" diff --git a/dimos/types/ros_polyfill.py b/dimos/types/ros_polyfill.py index 1bb4ece7fb..fde0a832cb 100644 --- a/dimos/types/ros_polyfill.py +++ b/dimos/types/ros_polyfill.py @@ -15,13 +15,11 @@ try: from geometry_msgs.msg import Vector3 except ImportError: - from dimos.msgs.geometry_msgs import Vector3 # type: ignore[import] + pass # type: ignore[import] try: from geometry_msgs.msg import Point, Pose, Quaternion, Twist from nav_msgs.msg import OccupancyGrid, Odometry from std_msgs.msg import Header except ImportError: - from dimos_lcm.geometry_msgs import Point, Pose, Quaternion, Twist - from dimos_lcm.nav_msgs import OccupancyGrid, Odometry - from dimos_lcm.std_msgs import Header + pass diff --git a/dimos/types/sample.py b/dimos/types/sample.py index 5665f7a640..6d84942c55 100644 --- a/dimos/types/sample.py +++ b/dimos/types/sample.py @@ -12,24 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json -import logging +import builtins from collections import OrderedDict +from collections.abc import Sequence from enum import Enum +import json +import logging from pathlib import Path -from typing import Any, Dict, List, Literal, Sequence, Union, get_origin +from typing import Annotated, Any, Literal, Union, get_origin -import numpy as np from datasets import Dataset from gymnasium import spaces from jsonref import replace_refs +from mbodied.data.utils import to_features +from mbodied.utils.import_utils import smart_import +import numpy as np from pydantic import BaseModel, ConfigDict, ValidationError from pydantic.fields import FieldInfo from pydantic_core import from_json -from typing_extensions import Annotated - -from mbodied.data.utils import to_features -from mbodied.utils.import_utils import smart_import Flattenable = Annotated[Literal["dict", "np", "pt", "list"], "Numpy, PyTorch, list, or dict"] @@ -81,7 +81,7 @@ class Sample(BaseModel): arbitrary_types_allowed=True, ) - def __init__(self, datum=None, **data): + def __init__(self, datum=None, **data) -> None: """Accepts an arbitrary datum as well as keyword arguments.""" if datum is not None: if isinstance(datum, Sample): @@ -100,7 +100,7 @@ def __str__(self) -> str: """Return a string representation of the Sample instance.""" return f"{self.__class__.__name__}({', '.join([f'{k}={v}' for k, v in self.dict().items() if v is not None])})" - def dict(self, exclude_none=True, exclude: set[str] = None) -> Dict[str, Any]: + def dict(self, exclude_none: bool = True, exclude: set[str] | None = None) -> dict[str, Any]: """Return the Sample object as a dictionary with None values excluded. Args: @@ -142,7 +142,7 @@ def unflatten(cls, one_d_array_or_dict, schema=None) -> "Sample": else: flat_data = list(one_d_array_or_dict) - def unflatten_recursive(schema_part, index=0): + def unflatten_recursive(schema_part, index: int = 0): if schema_part["type"] == "object": result = {} for prop, prop_schema in schema_part["properties"].items(): @@ -165,10 +165,10 @@ def flatten( self, output_type: Flattenable = "dict", non_numerical: Literal["ignore", "forbid", "allow"] = "allow", - ) -> Dict[str, Any] | np.ndarray | "torch.Tensor" | List: + ) -> builtins.dict[str, Any] | np.ndarray | "torch.Tensor" | list: accumulator = {} if output_type == "dict" else [] - def flatten_recursive(obj, path=""): + def flatten_recursive(obj, path: str = "") -> None: if isinstance(obj, Sample): for k, v in obj.dict().items(): flatten_recursive(v, path + k + "/") @@ -208,7 +208,7 @@ def flatten_recursive(obj, path=""): return accumulator @staticmethod - def obj_to_schema(value: Any) -> Dict: + def obj_to_schema(value: Any) -> builtins.dict: """Generates a simplified JSON schema from a dictionary. Args: @@ -236,7 +236,9 @@ def obj_to_schema(value: Any) -> Dict: return {"type": "boolean"} return {} - def schema(self, resolve_refs: bool = True, include_descriptions=False) -> Dict: + def schema( + self, resolve_refs: bool = True, include_descriptions: bool = False + ) -> builtins.dict: """Returns a simplified json schema. Removing additionalProperties, @@ -406,10 +408,10 @@ def space_for( raise ValueError(f"Unsupported object {value} of type: {type(value)} for space generation") @classmethod - def init_from(cls, d: Any, pack=False) -> "Sample": + def init_from(cls, d: Any, pack: bool = False) -> "Sample": if isinstance(d, spaces.Space): return cls.from_space(d) - if isinstance(d, Union[Sequence, np.ndarray]): # noqa: UP007 + if isinstance(d, Union[Sequence, np.ndarray]): if pack: return cls.pack_from(d) return cls.unflatten(d) @@ -427,7 +429,9 @@ def init_from(cls, d: Any, pack=False) -> "Sample": return cls(d) @classmethod - def from_flat_dict(cls, flat_dict: Dict[str, Any], schema: Dict = None) -> "Sample": + def from_flat_dict( + cls, flat_dict: builtins.dict[str, Any], schema: builtins.dict | None = None + ) -> "Sample": """Initialize a Sample instance from a flattened dictionary.""" """ Reconstructs the original JSON object from a flattened dictionary using the provided schema. @@ -466,7 +470,7 @@ def from_space(cls, space: spaces.Space) -> "Sample": return cls(sampled) @classmethod - def pack_from(cls, samples: List[Union["Sample", Dict]]) -> "Sample": + def pack_from(cls, samples: list[Union["Sample", builtins.dict]]) -> "Sample": """Pack a list of samples into a single sample with lists for attributes. Args: @@ -496,7 +500,7 @@ def pack_from(cls, samples: List[Union["Sample", Dict]]) -> "Sample": aggregated[attr].append(getattr(sample, attr, None)) return cls(**aggregated) - def unpack(self, to_dicts=False) -> List[Union["Sample", Dict]]: + def unpack(self, to_dicts: bool = False) -> list[Union["Sample", builtins.dict]]: """Unpack the packed Sample object into a list of Sample objects or dictionaries.""" attributes = list(self.model_extra.keys()) + list(self.model_fields.keys()) attributes = [attr for attr in attributes if getattr(self, attr) is not None] @@ -525,7 +529,9 @@ def default_space(cls) -> spaces.Dict: return cls().space() @classmethod - def default_sample(cls, output_type="Sample") -> Union["Sample", Dict[str, Any]]: + def default_sample( + cls, output_type: str = "Sample" + ) -> Union["Sample", builtins.dict[str, Any]]: """Generate a default Sample instance from its class attributes. Useful for padding. This is the "no-op" instance and should be overriden as needed. @@ -554,7 +560,7 @@ def space(self) -> spaces.Dict: for key, value in self.dict().items(): logging.debug("Generating space for key: '%s', value: %s", key, value) info = self.model_field_info(key) - value = getattr(self, key) if hasattr(self, key) else value # noqa: PLW2901 + value = getattr(self, key) if hasattr(self, key) else value space_dict[key] = ( value.space() if isinstance(value, Sample) else self.space_for(value, info=info) ) diff --git a/dimos/types/segmentation.py b/dimos/types/segmentation.py index 5995f302f9..1f3c2a0773 100644 --- a/dimos/types/segmentation.py +++ b/dimos/types/segmentation.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Any +from typing import Any + import numpy as np class SegmentationType: - def __init__(self, masks: List[np.ndarray], metadata: Any = None): + def __init__(self, masks: list[np.ndarray], metadata: Any = None) -> None: """ Initializes a standardized segmentation type. @@ -35,7 +36,7 @@ def combine_masks(self): combined_mask = np.logical_or(combined_mask, mask) return combined_mask - def save_masks(self, directory: str): + def save_masks(self, directory: str) -> None: """Save each mask to a separate file.""" import os diff --git a/dimos/types/test_timestamped.py b/dimos/types/test_timestamped.py index e197f971a0..7eae7a8ad3 100644 --- a/dimos/types/test_timestamped.py +++ b/dimos/types/test_timestamped.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time from datetime import datetime, timezone +import time import pytest from reactivex import operators as ops @@ -33,7 +33,7 @@ from dimos.utils.reactive import backpressure -def test_timestamped_dt_method(): +def test_timestamped_dt_method() -> None: ts = 1751075203.4120464 timestamped = Timestamped(ts) dt = timestamped.dt() @@ -42,7 +42,7 @@ def test_timestamped_dt_method(): assert dt.tzinfo is not None, "datetime should be timezone-aware" -def test_to_ros_stamp(): +def test_to_ros_stamp() -> None: """Test the to_ros_stamp function with different input types.""" # Test with float timestamp @@ -65,7 +65,7 @@ def test_to_ros_stamp(): assert abs(result.nanosec - 123456000) < 1000 # Allow small rounding error -def test_to_datetime(): +def test_to_datetime() -> None: """Test the to_datetime function with different input types.""" # Test with float timestamp @@ -108,7 +108,7 @@ def test_to_datetime(): class SimpleTimestamped(Timestamped): - def __init__(self, ts: float, data: str): + def __init__(self, ts: float, data: str) -> None: super().__init__(ts) self.data = data @@ -138,7 +138,7 @@ def collection(sample_items): return TimestampedCollection(sample_items) -def test_empty_collection(): +def test_empty_collection() -> None: collection = TimestampedCollection() assert len(collection) == 0 assert collection.duration() == 0.0 @@ -146,7 +146,7 @@ def test_empty_collection(): assert collection.find_closest(1.0) is None -def test_add_items(): +def test_add_items() -> None: collection = TimestampedCollection() item1 = SimpleTimestamped(2.0, "two") item2 = SimpleTimestamped(1.0, "one") @@ -159,7 +159,7 @@ def test_add_items(): assert collection[1].data == "two" -def test_find_closest(collection): +def test_find_closest(collection) -> None: # Exact match assert collection.find_closest(3.0).data == "third" @@ -184,7 +184,7 @@ def test_find_closest(collection): assert collection.find_closest(10.0, tolerance=2.0) is None -def test_find_before_after(collection): +def test_find_before_after(collection) -> None: # Find before assert collection.find_before(2.0).data == "first" assert collection.find_before(5.5).data == "fifth" @@ -196,7 +196,7 @@ def test_find_before_after(collection): assert collection.find_after(7.0) is None # Nothing after last item -def test_merge_collections(): +def test_merge_collections() -> None: collection1 = TimestampedCollection( [ SimpleTimestamped(1.0, "a"), @@ -216,12 +216,12 @@ def test_merge_collections(): assert [item.data for item in merged] == ["a", "b", "c", "d"] -def test_duration_and_range(collection): +def test_duration_and_range(collection) -> None: assert collection.duration() == 6.0 # 7.0 - 1.0 assert collection.time_range() == (1.0, 7.0) -def test_slice_by_time(collection): +def test_slice_by_time(collection) -> None: # Slice inclusive of boundaries sliced = collection.slice_by_time(2.0, 6.0) assert len(sliced) == 2 @@ -237,19 +237,19 @@ def test_slice_by_time(collection): assert len(all_slice) == 4 -def test_iteration(collection): +def test_iteration(collection) -> None: items = list(collection) assert len(items) == 4 assert [item.ts for item in items] == [1.0, 3.0, 5.0, 7.0] -def test_single_item_collection(): +def test_single_item_collection() -> None: single = TimestampedCollection([SimpleTimestamped(5.0, "only")]) assert single.duration() == 0.0 assert single.time_range() == (5.0, 5.0) -def test_time_window_collection(): +def test_time_window_collection() -> None: # Create a collection with a 2-second window window = TimestampedBufferCollection[SimpleTimestamped](window_duration=2.0) @@ -278,7 +278,7 @@ def test_time_window_collection(): assert window.end_ts == 5.5 -def test_timestamp_alignment(test_scheduler): +def test_timestamp_alignment(test_scheduler) -> None: speed = 5.0 # ensure that lfs package is downloaded @@ -333,7 +333,7 @@ def process_video_frame(frame): assert len(aligned_frames) > 2 -def test_timestamp_alignment_primary_first(): +def test_timestamp_alignment_primary_first() -> None: """Test alignment when primary messages arrive before secondary messages.""" from reactivex import Subject @@ -394,7 +394,7 @@ def test_timestamp_alignment_primary_first(): secondary_subject.on_completed() -def test_timestamp_alignment_multiple_secondaries(): +def test_timestamp_alignment_multiple_secondaries() -> None: """Test alignment with multiple secondary observables.""" from reactivex import Subject @@ -464,7 +464,7 @@ def test_timestamp_alignment_multiple_secondaries(): secondary2_subject.on_completed() -def test_timestamp_alignment_delayed_secondary(): +def test_timestamp_alignment_delayed_secondary() -> None: """Test alignment when secondary messages arrive late but still within tolerance.""" from reactivex import Subject @@ -524,7 +524,7 @@ def test_timestamp_alignment_delayed_secondary(): secondary_subject.on_completed() -def test_timestamp_alignment_buffer_cleanup(): +def test_timestamp_alignment_buffer_cleanup() -> None: """Test that old buffered primaries are cleaned up.""" import time as time_module diff --git a/dimos/types/test_vector.py b/dimos/types/test_vector.py index 6a93d37afd..5462fda9a4 100644 --- a/dimos/types/test_vector.py +++ b/dimos/types/test_vector.py @@ -17,7 +17,7 @@ from dimos.types.vector import Vector -def test_vector_default_init(): +def test_vector_default_init() -> None: """Test that default initialization of Vector() has x,y,z components all zero.""" v = Vector() assert v.x == 0.0 @@ -26,10 +26,10 @@ def test_vector_default_init(): assert v.dim == 0 assert len(v.data) == 0 assert v.to_list() == [] - assert v.is_zero() == True # Empty vector should be considered zero + assert v.is_zero() # Empty vector should be considered zero -def test_vector_specific_init(): +def test_vector_specific_init() -> None: """Test initialization with specific values.""" # 2D vector v1 = Vector(1.0, 2.0) @@ -60,7 +60,7 @@ def test_vector_specific_init(): assert v4.dim == 3 -def test_vector_addition(): +def test_vector_addition() -> None: """Test vector addition.""" v1 = Vector(1.0, 2.0, 3.0) v2 = Vector(4.0, 5.0, 6.0) @@ -71,7 +71,7 @@ def test_vector_addition(): assert v_add.z == 9.0 -def test_vector_subtraction(): +def test_vector_subtraction() -> None: """Test vector subtraction.""" v1 = Vector(1.0, 2.0, 3.0) v2 = Vector(4.0, 5.0, 6.0) @@ -82,7 +82,7 @@ def test_vector_subtraction(): assert v_sub.z == 3.0 -def test_vector_scalar_multiplication(): +def test_vector_scalar_multiplication() -> None: """Test vector multiplication by a scalar.""" v1 = Vector(1.0, 2.0, 3.0) @@ -98,7 +98,7 @@ def test_vector_scalar_multiplication(): assert v_rmul.z == 6.0 -def test_vector_scalar_division(): +def test_vector_scalar_division() -> None: """Test vector division by a scalar.""" v2 = Vector(4.0, 5.0, 6.0) @@ -108,7 +108,7 @@ def test_vector_scalar_division(): assert v_div.z == 3.0 -def test_vector_dot_product(): +def test_vector_dot_product() -> None: """Test vector dot product.""" v1 = Vector(1.0, 2.0, 3.0) v2 = Vector(4.0, 5.0, 6.0) @@ -117,7 +117,7 @@ def test_vector_dot_product(): assert dot == 32.0 -def test_vector_length(): +def test_vector_length() -> None: """Test vector length calculation.""" # 2D vector with length 5 v1 = Vector(3.0, 4.0) @@ -132,10 +132,10 @@ def test_vector_length(): assert v2.length_squared() == 49.0 -def test_vector_normalize(): +def test_vector_normalize() -> None: """Test vector normalization.""" v = Vector(2.0, 3.0, 6.0) - assert v.is_zero() == False + assert not v.is_zero() v_norm = v.normalize() length = v.length() @@ -147,19 +147,19 @@ def test_vector_normalize(): assert np.isclose(v_norm.y, expected_y) assert np.isclose(v_norm.z, expected_z) assert np.isclose(v_norm.length(), 1.0) - assert v_norm.is_zero() == False + assert not v_norm.is_zero() # Test normalizing a zero vector v_zero = Vector(0.0, 0.0, 0.0) - assert v_zero.is_zero() == True + assert v_zero.is_zero() v_zero_norm = v_zero.normalize() assert v_zero_norm.x == 0.0 assert v_zero_norm.y == 0.0 assert v_zero_norm.z == 0.0 - assert v_zero_norm.is_zero() == True + assert v_zero_norm.is_zero() -def test_vector_to_2d(): +def test_vector_to_2d() -> None: """Test conversion to 2D vector.""" v = Vector(2.0, 3.0, 6.0) @@ -177,7 +177,7 @@ def test_vector_to_2d(): assert v2_2d.dim == 2 -def test_vector_distance(): +def test_vector_distance() -> None: """Test distance calculations between vectors.""" v1 = Vector(1.0, 2.0, 3.0) v2 = Vector(4.0, 6.0, 8.0) @@ -192,7 +192,7 @@ def test_vector_distance(): assert dist_sq == 50.0 # 9 + 16 + 25 -def test_vector_cross_product(): +def test_vector_cross_product() -> None: """Test vector cross product.""" v1 = Vector(1.0, 0.0, 0.0) # Unit x vector v2 = Vector(0.0, 1.0, 0.0) # Unit y vector @@ -220,7 +220,7 @@ def test_vector_cross_product(): v_2d.cross(v2) -def test_vector_zeros(): +def test_vector_zeros() -> None: """Test Vector.zeros class method.""" # 3D zero vector v_zeros = Vector.zeros(3) @@ -228,7 +228,7 @@ def test_vector_zeros(): assert v_zeros.y == 0.0 assert v_zeros.z == 0.0 assert v_zeros.dim == 3 - assert v_zeros.is_zero() == True + assert v_zeros.is_zero() # 2D zero vector v_zeros_2d = Vector.zeros(2) @@ -236,10 +236,10 @@ def test_vector_zeros(): assert v_zeros_2d.y == 0.0 assert v_zeros_2d.z == 0.0 assert v_zeros_2d.dim == 2 - assert v_zeros_2d.is_zero() == True + assert v_zeros_2d.is_zero() -def test_vector_ones(): +def test_vector_ones() -> None: """Test Vector.ones class method.""" # 3D ones vector v_ones = Vector.ones(3) @@ -256,7 +256,7 @@ def test_vector_ones(): assert v_ones_2d.dim == 2 -def test_vector_conversion_methods(): +def test_vector_conversion_methods() -> None: """Test vector conversion methods (to_list, to_tuple, to_numpy).""" v = Vector(1.0, 2.0, 3.0) @@ -272,7 +272,7 @@ def test_vector_conversion_methods(): assert np.array_equal(np_array, np.array([1.0, 2.0, 3.0])) -def test_vector_equality(): +def test_vector_equality() -> None: """Test vector equality.""" v1 = Vector(1, 2, 3) v2 = Vector(1, 2, 3) @@ -285,75 +285,75 @@ def test_vector_equality(): assert v1 != [1, 2, 3] -def test_vector_is_zero(): +def test_vector_is_zero() -> None: """Test is_zero method for vectors.""" # Default empty vector v0 = Vector() - assert v0.is_zero() == True + assert v0.is_zero() # Explicit zero vector v1 = Vector(0.0, 0.0, 0.0) - assert v1.is_zero() == True + assert v1.is_zero() # Zero vector with different dimensions v2 = Vector(0.0, 0.0) - assert v2.is_zero() == True + assert v2.is_zero() # Non-zero vectors v3 = Vector(1.0, 0.0, 0.0) - assert v3.is_zero() == False + assert not v3.is_zero() v4 = Vector(0.0, 2.0, 0.0) - assert v4.is_zero() == False + assert not v4.is_zero() v5 = Vector(0.0, 0.0, 3.0) - assert v5.is_zero() == False + assert not v5.is_zero() # Almost zero (within tolerance) v6 = Vector(1e-10, 1e-10, 1e-10) - assert v6.is_zero() == True + assert v6.is_zero() # Almost zero (outside tolerance) v7 = Vector(1e-6, 1e-6, 1e-6) - assert v7.is_zero() == False + assert not v7.is_zero() def test_vector_bool_conversion(): """Test boolean conversion of vectors.""" # Zero vectors should be False v0 = Vector() - assert bool(v0) == False + assert not bool(v0) v1 = Vector(0.0, 0.0, 0.0) - assert bool(v1) == False + assert not bool(v1) # Almost zero vectors should be False v2 = Vector(1e-10, 1e-10, 1e-10) - assert bool(v2) == False + assert not bool(v2) # Non-zero vectors should be True v3 = Vector(1.0, 0.0, 0.0) - assert bool(v3) == True + assert bool(v3) v4 = Vector(0.0, 2.0, 0.0) - assert bool(v4) == True + assert bool(v4) v5 = Vector(0.0, 0.0, 3.0) - assert bool(v5) == True + assert bool(v5) # Direct use in if statements if v0: - assert False, "Zero vector should be False in boolean context" + raise AssertionError("Zero vector should be False in boolean context") else: pass # Expected path if v3: pass # Expected path else: - assert False, "Non-zero vector should be True in boolean context" + raise AssertionError("Non-zero vector should be True in boolean context") -def test_vector_add(): +def test_vector_add() -> None: """Test vector addition operator.""" v1 = Vector(1.0, 2.0, 3.0) v2 = Vector(4.0, 5.0, 6.0) @@ -375,10 +375,10 @@ def test_vector_add(): assert (v1 + v_zero) == v1 -def test_vector_add_dim_mismatch(): +def test_vector_add_dim_mismatch() -> None: """Test vector addition operator.""" v1 = Vector(1.0, 2.0) v2 = Vector(4.0, 5.0, 6.0) # Using + operator - v_add_op = v1 + v2 + v1 + v2 diff --git a/dimos/types/test_weaklist.py b/dimos/types/test_weaklist.py index c4dfe27616..a37d893de9 100644 --- a/dimos/types/test_weaklist.py +++ b/dimos/types/test_weaklist.py @@ -24,14 +24,14 @@ class SampleObject: """Simple test object.""" - def __init__(self, value): + def __init__(self, value) -> None: self.value = value - def __repr__(self): + def __repr__(self) -> str: return f"SampleObject({self.value})" -def test_weaklist_basic_operations(): +def test_weaklist_basic_operations() -> None: """Test basic append, iterate, and length operations.""" wl = WeakList() @@ -54,7 +54,7 @@ def test_weaklist_basic_operations(): assert SampleObject(4) not in wl -def test_weaklist_auto_removal(): +def test_weaklist_auto_removal() -> None: """Test that objects are automatically removed when garbage collected.""" wl = WeakList() @@ -77,7 +77,7 @@ def test_weaklist_auto_removal(): assert list(wl) == [obj1, obj3] -def test_weaklist_explicit_remove(): +def test_weaklist_explicit_remove() -> None: """Test explicit removal of objects.""" wl = WeakList() @@ -98,7 +98,7 @@ def test_weaklist_explicit_remove(): wl.remove(SampleObject(3)) -def test_weaklist_indexing(): +def test_weaklist_indexing() -> None: """Test index access.""" wl = WeakList() @@ -119,7 +119,7 @@ def test_weaklist_indexing(): _ = wl[3] -def test_weaklist_clear(): +def test_weaklist_clear() -> None: """Test clearing the list.""" wl = WeakList() @@ -136,7 +136,7 @@ def test_weaklist_clear(): assert obj1 not in wl -def test_weaklist_iteration_during_modification(): +def test_weaklist_iteration_during_modification() -> None: """Test that iteration works even if objects are deleted during iteration.""" wl = WeakList() diff --git a/dimos/types/timestamped.py b/dimos/types/timestamped.py index 412ba08c03..2d3d6b6a20 100644 --- a/dimos/types/timestamped.py +++ b/dimos/types/timestamped.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from collections import defaultdict +from collections.abc import Iterable, Iterator from datetime import datetime, timezone -from typing import Generic, Iterable, List, Optional, Tuple, TypeVar, Union +from typing import Generic, TypeVar, Union from dimos_lcm.builtin_interfaces import Time as ROSTime from reactivex import create @@ -100,7 +101,7 @@ def to_datetime(ts: TimeLike, tz=None) -> datetime: class Timestamped: ts: float - def __init__(self, ts: float): + def __init__(self, ts: float) -> None: self.ts = ts def dt(self) -> datetime: @@ -119,14 +120,14 @@ def ros_timestamp(self) -> list[int]: class TimestampedCollection(Generic[T]): """A collection of timestamped objects with efficient time-based operations.""" - def __init__(self, items: Optional[Iterable[T]] = None): + def __init__(self, items: Iterable[T] | None = None) -> None: self._items = SortedKeyList(items or [], key=lambda x: x.ts) def add(self, item: T) -> None: """Add a timestamped item to the collection.""" self._items.add(item) - def find_closest(self, timestamp: float, tolerance: Optional[float] = None) -> Optional[T]: + def find_closest(self, timestamp: float, tolerance: float | None = None) -> T | None: """Find the timestamped object closest to the given timestamp.""" if not self._items: return None @@ -162,12 +163,12 @@ def find_closest(self, timestamp: float, tolerance: Optional[float] = None) -> O return self._items[closest_idx] - def find_before(self, timestamp: float) -> Optional[T]: + def find_before(self, timestamp: float) -> T | None: """Find the last item before the given timestamp.""" idx = self._items.bisect_key_left(timestamp) return self._items[idx - 1] if idx > 0 else None - def find_after(self, timestamp: float) -> Optional[T]: + def find_after(self, timestamp: float) -> T | None: """Find the first item after the given timestamp.""" idx = self._items.bisect_key_right(timestamp) return self._items[idx] if idx < len(self._items) else None @@ -184,7 +185,7 @@ def duration(self) -> float: return 0.0 return self._items[-1].ts - self._items[0].ts - def time_range(self) -> Optional[Tuple[float, float]]: + def time_range(self) -> tuple[float, float] | None: """Get the time range (start, end) of the collection.""" if not self._items: return None @@ -197,19 +198,19 @@ def slice_by_time(self, start: float, end: float) -> "TimestampedCollection[T]": return TimestampedCollection(self._items[start_idx:end_idx]) @property - def start_ts(self) -> Optional[float]: + def start_ts(self) -> float | None: """Get the start timestamp of the collection.""" return self._items[0].ts if self._items else None @property - def end_ts(self) -> Optional[float]: + def end_ts(self) -> float | None: """Get the end timestamp of the collection.""" return self._items[-1].ts if self._items else None def __len__(self) -> int: return len(self._items) - def __iter__(self): + def __iter__(self) -> Iterator: return iter(self._items) def __getitem__(self, idx: int) -> T: @@ -223,7 +224,7 @@ def __getitem__(self, idx: int) -> T: class TimestampedBufferCollection(TimestampedCollection[T]): """A timestamped collection that maintains a sliding time window, dropping old messages.""" - def __init__(self, window_duration: float, items: Optional[Iterable[T]] = None): + def __init__(self, window_duration: float, items: Iterable[T] | None = None) -> None: """ Initialize with a time window duration in seconds. @@ -270,12 +271,12 @@ class MatchContainer(Timestamped, Generic[PRIMARY, SECONDARY]): tracking which secondaries are still missing to avoid redundant searches. """ - def __init__(self, primary: PRIMARY, matches: List[Optional[SECONDARY]]): + def __init__(self, primary: PRIMARY, matches: list[SECONDARY | None]) -> None: super().__init__(primary.ts) self.primary = primary self.matches = matches # Direct list with None for missing matches - def message_received(self, secondary_idx: int, secondary_item: SECONDARY): + def message_received(self, secondary_idx: int, secondary_item: SECONDARY) -> None: """Process a secondary message and check if it matches this primary.""" if self.matches[secondary_idx] is None: self.matches[secondary_idx] = secondary_item @@ -284,7 +285,7 @@ def is_complete(self) -> bool: """Check if all secondary matches have been found.""" return all(match is not None for match in self.matches) - def get_tuple(self) -> Tuple[PRIMARY, ...]: + def get_tuple(self) -> tuple[PRIMARY, ...]: """Get the result tuple for emission.""" return (self.primary, *self.matches) @@ -294,7 +295,7 @@ def align_timestamped( *secondary_observables: Observable[SECONDARY], buffer_size: float = 1.0, # seconds match_tolerance: float = 0.1, # seconds -) -> Observable[Tuple[PRIMARY, ...]]: +) -> Observable[tuple[PRIMARY, ...]]: """Align a primary observable with one or more secondary observables. Args: @@ -312,7 +313,7 @@ def align_timestamped( def subscribe(observer, scheduler=None): # Create a timed buffer collection for each secondary observable - secondary_collections: List[TimestampedBufferCollection[SECONDARY]] = [ + secondary_collections: list[TimestampedBufferCollection[SECONDARY]] = [ TimestampedBufferCollection(buffer_size) for _ in secondary_observables ] @@ -331,13 +332,13 @@ def has_secondary_progressed_past(secondary_ts: float, primary_ts: float) -> boo """Check if secondary stream has progressed past the primary + tolerance.""" return secondary_ts > primary_ts + match_tolerance - def remove_stakeholder(stakeholder: MatchContainer): + def remove_stakeholder(stakeholder: MatchContainer) -> None: """Remove a stakeholder from all tracking structures.""" primary_buffer.remove(stakeholder) for weak_list in secondary_stakeholders.values(): weak_list.discard(stakeholder) - def on_secondary(i: int, secondary_item: SECONDARY): + def on_secondary(i: int, secondary_item: SECONDARY) -> None: # Add the secondary item to its collection secondary_collections[i].add(secondary_item) @@ -368,7 +369,7 @@ def on_secondary(i: int, secondary_item: SECONDARY): ) ) - def on_primary(primary_item: PRIMARY): + def on_primary(primary_item: PRIMARY) -> None: # Try to find matches in existing secondary collections matches = [None] * len(secondary_observables) diff --git a/dimos/types/vector.py b/dimos/types/vector.py index d980e28105..161084fc2c 100644 --- a/dimos/types/vector.py +++ b/dimos/types/vector.py @@ -12,21 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Tuple, TypeVar, Union, Sequence +import builtins +from collections.abc import Sequence +from typing import TypeVar, Union import numpy as np + from dimos.types.ros_polyfill import Vector3 T = TypeVar("T", bound="Vector") # Vector-like types that can be converted to/from Vector -VectorLike = Union[Sequence[Union[int, float]], Vector3, "Vector", np.ndarray] +VectorLike = Union[Sequence[int | float], Vector3, "Vector", np.ndarray] class Vector: """A wrapper around numpy arrays for vector operations with intuitive syntax.""" - def __init__(self, *args: VectorLike): + def __init__(self, *args: VectorLike) -> None: """Initialize a vector from components or another iterable. Examples: @@ -49,7 +52,7 @@ def yaw(self) -> float: return self.x @property - def tuple(self) -> Tuple[float, ...]: + def tuple(self) -> tuple[float, ...]: """Tuple representation of the vector.""" return tuple(self._data) @@ -78,7 +81,7 @@ def data(self) -> np.ndarray: """Get the underlying numpy array.""" return self._data - def __getitem__(self, idx): + def __getitem__(self, idx: int): return self._data[idx] def __repr__(self) -> str: @@ -103,7 +106,7 @@ def getArrow(): return f"{getArrow()} Vector {self.__repr__()}" - def serialize(self) -> Tuple: + def serialize(self) -> builtins.tuple: """Serialize the vector to a tuple.""" return {"type": "vector", "c": self._data.tolist()} @@ -261,11 +264,11 @@ def unit_z(cls: type[T], dim: int = 3) -> T: v[2] = 1.0 return cls(v) - def to_list(self) -> List[float]: + def to_list(self) -> list[float]: """Convert the vector to a list.""" return self._data.tolist() - def to_tuple(self) -> Tuple[float, ...]: + def to_tuple(self) -> builtins.tuple[float, ...]: """Convert the vector to a tuple.""" return tuple(self._data) @@ -327,7 +330,7 @@ def to_vector(value: VectorLike) -> Vector: return Vector(value) -def to_tuple(value: VectorLike) -> Tuple[float, ...]: +def to_tuple(value: VectorLike) -> tuple[float, ...]: """Convert a vector-compatible value to a tuple. Args: @@ -348,7 +351,7 @@ def to_tuple(value: VectorLike) -> Tuple[float, ...]: return tuple(value) -def to_list(value: VectorLike) -> List[float]: +def to_list(value: VectorLike) -> list[float]: """Convert a vector-compatible value to a list. Args: diff --git a/dimos/types/weaklist.py b/dimos/types/weaklist.py index 8722455c66..e09b36157c 100644 --- a/dimos/types/weaklist.py +++ b/dimos/types/weaklist.py @@ -14,8 +14,9 @@ """Weak reference list implementation that automatically removes dead references.""" +from collections.abc import Iterator +from typing import Any import weakref -from typing import Any, Iterator, Optional class WeakList: @@ -25,13 +26,13 @@ class WeakList: Supports iteration, append, remove, and length operations. """ - def __init__(self): + def __init__(self) -> None: self._refs = [] def append(self, obj: Any) -> None: """Add an object to the list (stored as weak reference).""" - def _cleanup(ref): + def _cleanup(ref) -> None: try: self._refs.remove(ref) except ValueError: diff --git a/dimos/utils/actor_registry.py b/dimos/utils/actor_registry.py index 3f1133fa4d..9cd589bed2 100644 --- a/dimos/utils/actor_registry.py +++ b/dimos/utils/actor_registry.py @@ -16,7 +16,6 @@ import json from multiprocessing import shared_memory -from typing import Dict class ActorRegistry: @@ -26,7 +25,7 @@ class ActorRegistry: SHM_SIZE = 65536 # 64KB should be enough for most deployments @staticmethod - def update(actor_name: str, worker_id: str): + def update(actor_name: str, worker_id: str) -> None: """Update registry with new actor deployment.""" try: shm = shared_memory.SharedMemory(name=ActorRegistry.SHM_NAME) @@ -46,7 +45,7 @@ def update(actor_name: str, worker_id: str): shm.close() @staticmethod - def get_all() -> Dict[str, str]: + def get_all() -> dict[str, str]: """Get all actor->worker mappings.""" try: shm = shared_memory.SharedMemory(name=ActorRegistry.SHM_NAME) @@ -57,7 +56,7 @@ def get_all() -> Dict[str, str]: return {} @staticmethod - def clear(): + def clear() -> None: """Clear the registry and free shared memory.""" try: shm = shared_memory.SharedMemory(name=ActorRegistry.SHM_NAME) @@ -68,7 +67,7 @@ def clear(): pass @staticmethod - def _read_from_shm(shm) -> Dict[str, str]: + def _read_from_shm(shm) -> dict[str, str]: """Read JSON data from shared memory.""" raw = bytes(shm.buf[:]).rstrip(b"\x00") if not raw: @@ -76,7 +75,7 @@ def _read_from_shm(shm) -> Dict[str, str]: return json.loads(raw.decode("utf-8")) @staticmethod - def _write_to_shm(shm, data: Dict[str, str]): + def _write_to_shm(shm, data: dict[str, str]): """Write JSON data to shared memory.""" json_bytes = json.dumps(data).encode("utf-8") if len(json_bytes) > ActorRegistry.SHM_SIZE: diff --git a/dimos/utils/cli/agentspy/agentspy.py b/dimos/utils/cli/agentspy/agentspy.py index 2d69e3537f..8aea6a1542 100644 --- a/dimos/utils/cli/agentspy/agentspy.py +++ b/dimos/utils/cli/agentspy/agentspy.py @@ -14,10 +14,10 @@ from __future__ import annotations -import time from collections import deque from dataclasses import dataclass -from typing import Any, Deque, List, Optional, Union +import time +from typing import Any, Union from langchain_core.messages import ( AIMessage, @@ -43,7 +43,7 @@ class MessageEntry: timestamp: float message: AnyMessage - def __post_init__(self): + def __post_init__(self) -> None: """Initialize timestamp if not provided.""" if self.timestamp is None: self.timestamp = time.time() @@ -52,25 +52,25 @@ def __post_init__(self): class AgentMessageMonitor: """Monitor agent messages published via LCM.""" - def __init__(self, topic: str = "/agent", max_messages: int = 1000): + def __init__(self, topic: str = "/agent", max_messages: int = 1000) -> None: self.topic = topic self.max_messages = max_messages - self.messages: Deque[MessageEntry] = deque(maxlen=max_messages) + self.messages: deque[MessageEntry] = deque(maxlen=max_messages) self.transport = PickleLCM() self.transport.start() - self.callbacks: List[callable] = [] + self.callbacks: list[callable] = [] pass - def start(self): + def start(self) -> None: """Start monitoring messages.""" self.transport.subscribe(self.topic, self._handle_message) - def stop(self): + def stop(self) -> None: """Stop monitoring.""" # PickleLCM doesn't have explicit stop method pass - def _handle_message(self, msg: Any, topic: str): + def _handle_message(self, msg: Any, topic: str) -> None: """Handle incoming messages.""" # Check if it's one of the message types we care about if isinstance(msg, (SystemMessage, ToolMessage, AIMessage, HumanMessage)): @@ -83,11 +83,11 @@ def _handle_message(self, msg: Any, topic: str): else: pass - def subscribe(self, callback: callable): + def subscribe(self, callback: callable) -> None: """Subscribe to new messages.""" self.callbacks.append(callback) - def get_messages(self) -> List[MessageEntry]: + def get_messages(self) -> list[MessageEntry]: """Get all stored messages.""" return list(self.messages) @@ -165,10 +165,10 @@ class AgentSpyApp(App): Binding("ctrl+c", "quit", show=False), ] - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.monitor = AgentMessageMonitor() - self.message_log: Optional[RichLog] = None + self.message_log: RichLog | None = None def compose(self) -> ComposeResult: """Compose the UI.""" @@ -176,7 +176,7 @@ def compose(self) -> ComposeResult: yield self.message_log yield Footer() - def on_mount(self): + def on_mount(self) -> None: """Start monitoring when app mounts.""" self.theme = "flexoki" @@ -188,11 +188,11 @@ def on_mount(self): for entry in self.monitor.get_messages(): self.on_new_message(entry) - def on_unmount(self): + def on_unmount(self) -> None: """Stop monitoring when app unmounts.""" self.monitor.stop() - def on_new_message(self, entry: MessageEntry): + def on_new_message(self, entry: MessageEntry) -> None: """Handle new messages.""" if self.message_log: msg = entry.message @@ -207,18 +207,18 @@ def on_new_message(self, entry: MessageEntry): f"[{style}]{content}[/{style}]" ) - def refresh_display(self): + def refresh_display(self) -> None: """Refresh the message display.""" # Not needed anymore as messages are written directly to the log - def action_clear(self): + def action_clear(self) -> None: """Clear message history.""" self.monitor.messages.clear() if self.message_log: self.message_log.clear() -def main(): +def main() -> None: """Main entry point for agentspy.""" import sys diff --git a/dimos/utils/cli/agentspy/demo_agentspy.py b/dimos/utils/cli/agentspy/demo_agentspy.py index 1e3a0d4f3b..100f22522d 100755 --- a/dimos/utils/cli/agentspy/demo_agentspy.py +++ b/dimos/utils/cli/agentspy/demo_agentspy.py @@ -16,17 +16,19 @@ """Demo script to test agent message publishing and agentspy reception.""" import time + from langchain_core.messages import ( AIMessage, HumanMessage, SystemMessage, ToolMessage, ) -from dimos.protocol.pubsub.lcmpubsub import PickleLCM + from dimos.protocol.pubsub import lcm +from dimos.protocol.pubsub.lcmpubsub import PickleLCM -def test_publish_messages(): +def test_publish_messages() -> None: """Publish test messages to verify agentspy is working.""" print("Starting agent message publisher demo...") diff --git a/dimos/utils/cli/boxglove/boxglove.py b/dimos/utils/cli/boxglove/boxglove.py index eabd13800b..1e0e09a277 100644 --- a/dimos/utils/cli/boxglove/boxglove.py +++ b/dimos/utils/cli/boxglove/boxglove.py @@ -14,29 +14,25 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import numpy as np import reactivex.operators as ops from rich.text import Text from textual.app import App, ComposeResult -from textual.color import Color from textual.containers import Container from textual.reactive import reactive -from textual.widgets import Footer, Header, Label, Static +from textual.widgets import Footer, Static from dimos import core -from dimos.msgs.geometry_msgs import Pose, PoseStamped, Transform, Vector3 from dimos.msgs.nav_msgs import OccupancyGrid -from dimos.msgs.sensor_msgs import Image, PointCloud2 -from dimos.robot.unitree_webrtc.multiprocess.unitree_go2_navonly import ConnectionModule from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.utils.cli.boxglove.connection import Connection if TYPE_CHECKING: from reactivex.disposable import Disposable from dimos.msgs.nav_msgs import OccupancyGrid + from dimos.utils.cli.boxglove.connection import Connection blocks = "█▗▖▝▘" @@ -64,7 +60,7 @@ class OccupancyGridApp(App): layout: vertical; overflow: hidden; } - + #grid-container { width: 100%; height: 1fr; @@ -72,14 +68,14 @@ class OccupancyGridApp(App): margin: 0; padding: 0; } - + #grid-display { width: 100%; height: 100%; margin: 0; padding: 0; } - + Footer { dock: bottom; height: 1; @@ -87,19 +83,19 @@ class OccupancyGridApp(App): """ # Reactive properties - grid_data: reactive[Optional["OccupancyGrid"]] = reactive(None) + grid_data: reactive[OccupancyGrid | None] = reactive(None) BINDINGS = [ ("q", "quit", "Quit"), ("ctrl+c", "quit", "Quit"), ] - def __init__(self, connection: Connection, *args, **kwargs): + def __init__(self, connection: Connection, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.connection = connection - self.subscription: Optional[Disposable] = None - self.grid_display: Optional[Static] = None - self.cached_grid: Optional["OccupancyGrid"] = None + self.subscription: Disposable | None = None + self.grid_display: Static | None = None + self.cached_grid: OccupancyGrid | None = None def compose(self) -> ComposeResult: """Create the app layout.""" @@ -115,7 +111,7 @@ def on_mount(self) -> None: self.theme = "flexoki" # Subscribe to the OccupancyGrid stream - def on_grid(grid: "OccupancyGrid") -> None: + def on_grid(grid: OccupancyGrid) -> None: self.grid_data = grid def on_error(error: Exception) -> None: @@ -128,7 +124,7 @@ async def on_unmount(self) -> None: if self.subscription: self.subscription.dispose() - def watch_grid_data(self, grid: Optional["OccupancyGrid"]) -> None: + def watch_grid_data(self, grid: OccupancyGrid | None) -> None: """Update display when new grid data arrives.""" if grid is None: return @@ -147,7 +143,7 @@ def on_resize(self, event) -> None: grid_text = self.render_grid(self.cached_grid) self.grid_display.update(grid_text) - def render_grid(self, grid: "OccupancyGrid") -> Text: + def render_grid(self, grid: OccupancyGrid) -> Text: """Render the OccupancyGrid as colored ASCII art, scaled to fit terminal.""" text = Text() @@ -177,7 +173,7 @@ def render_grid(self, grid: "OccupancyGrid") -> Text: render_height = min(int(grid.height / scale_float), terminal_height) # Store both integer and float scale for different uses - scale = int(np.ceil(scale_float)) # For legacy compatibility + int(np.ceil(scale_float)) # For legacy compatibility # Adjust render dimensions to use all available space # This reduces jumping by allowing fractional cell sizes @@ -276,7 +272,7 @@ def get_cell_char_and_style(grid_data: np.ndarray, x: int, y: int) -> tuple[str, return text -def main(): +def main() -> None: """Run the OccupancyGrid visualizer with a connection.""" # app = OccupancyGridApp(core.LCMTransport("/global_costmap", OccupancyGrid).observable) diff --git a/dimos/utils/cli/boxglove/connection.py b/dimos/utils/cli/boxglove/connection.py index 2c1f91469c..5d3b3f8806 100644 --- a/dimos/utils/cli/boxglove/connection.py +++ b/dimos/utils/cli/boxglove/connection.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable import pickle -import time -from typing import Callable import reactivex as rx from reactivex import operators as ops @@ -38,13 +37,13 @@ def subscribe(observer, scheduler=None): lcm.autoconf() l = lcm.LCM() - def on_message(grid: OccupancyGrid, _): + def on_message(grid: OccupancyGrid, _) -> None: observer.on_next(grid) l.subscribe(lcm.Topic("/global_costmap", OccupancyGrid), on_message) l.start() - def dispose(): + def dispose() -> None: l.stop() return Disposable(dispose) diff --git a/dimos/utils/cli/foxglove_bridge/run_foxglove_bridge.py b/dimos/utils/cli/foxglove_bridge/run_foxglove_bridge.py index a0cf07ffb6..8244d16d39 100644 --- a/dimos/utils/cli/foxglove_bridge/run_foxglove_bridge.py +++ b/dimos/utils/cli/foxglove_bridge/run_foxglove_bridge.py @@ -28,10 +28,10 @@ print(f"Using dimos_lcm from: {dimos_lcm_path}") -def run_bridge_example(): +def run_bridge_example() -> None: """Example of running the bridge in a separate thread""" - def bridge_thread(): + def bridge_thread() -> None: """Thread function to run the bridge""" loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) @@ -58,7 +58,7 @@ def bridge_thread(): print("Shutting down...") -def main(): +def main() -> None: run_bridge_example() diff --git a/dimos/utils/cli/human/humancli.py b/dimos/utils/cli/human/humancli.py index fb0ebc5fe2..4c474b88d2 100644 --- a/dimos/utils/cli/human/humancli.py +++ b/dimos/utils/cli/human/humancli.py @@ -14,10 +14,10 @@ from __future__ import annotations +from datetime import datetime import textwrap import threading -from datetime import datetime -from typing import Optional +from typing import TYPE_CHECKING from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolCall, ToolMessage from rich.highlighter import JSONHighlighter @@ -25,13 +25,14 @@ from textual.app import App, ComposeResult from textual.binding import Binding from textual.containers import Container -from textual.events import Key from textual.widgets import Input, RichLog from dimos.core import pLCMTransport from dimos.utils.cli import theme from dimos.utils.generic import truncate_display_string +if TYPE_CHECKING: + from textual.events import Key # Custom theme for JSON highlighting JSON_THEME = Theme( @@ -76,13 +77,13 @@ class HumanCLIApp(App): Binding("ctrl+l", "clear", "Clear chat"), ] - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.human_transport = pLCMTransport("/human_input") self.agent_transport = pLCMTransport("/agent") - self.chat_log: Optional[RichLog] = None - self.input_widget: Optional[Input] = None - self._subscription_thread: Optional[threading.Thread] = None + self.chat_log: RichLog | None = None + self.input_widget: Input | None = None + self._subscription_thread: threading.Thread | None = None self._running = False def compose(self) -> ComposeResult: @@ -132,7 +133,7 @@ def on_unmount(self) -> None: def _subscribe_to_agent(self) -> None: """Subscribe to agent messages in a separate thread.""" - def receive_msg(msg): + def receive_msg(msg) -> None: if not self._running: return @@ -275,7 +276,7 @@ def on_input_submitted(self, event: Input.Submitted) -> None: /help - Show this help message /exit - Exit the application /quit - Exit the application - + Tool calls are displayed in cyan with ▶ prefix""" self._add_system_message(help_text) return @@ -293,7 +294,7 @@ def action_quit(self) -> None: self.exit() -def main(): +def main() -> None: """Main entry point for the human CLI.""" import sys diff --git a/dimos/utils/cli/lcmspy/lcmspy.py b/dimos/utils/cli/lcmspy/lcmspy.py index 134051302c..42f811ffbc 100755 --- a/dimos/utils/cli/lcmspy/lcmspy.py +++ b/dimos/utils/cli/lcmspy/lcmspy.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import threading -import time from collections import deque from dataclasses import dataclass from enum import Enum +import threading +import time import lcm @@ -45,7 +45,7 @@ def human_readable_bytes(bytes_value: float, round_to: int = 2) -> tuple[float, class Topic: history_window: float = 60.0 - def __init__(self, name: str, history_window: float = 60.0): + def __init__(self, name: str, history_window: float = 60.0) -> None: self.name = name # Store (timestamp, data_size) tuples for statistics self.message_history = deque() @@ -53,14 +53,14 @@ def __init__(self, name: str, history_window: float = 60.0): # Total traffic accumulator (doesn't get cleaned up) self.total_traffic_bytes = 0 - def msg(self, data: bytes): + def msg(self, data: bytes) -> None: # print(f"> msg {self.__str__()} {len(data)} bytes") datalen = len(data) self.message_history.append((time.time(), datalen)) self.total_traffic_bytes += datalen self._cleanup_old_messages() - def _cleanup_old_messages(self, max_age: float = None): + def _cleanup_old_messages(self, max_age: float | None = None) -> None: """Remove messages older than max_age seconds""" current_time = time.time() while self.message_history and current_time - self.message_history[0][0] > ( @@ -114,7 +114,7 @@ def total_traffic_hr(self) -> tuple[float, BandwidthUnit]: total_bytes = self.total_traffic() return human_readable_bytes(total_bytes) - def __str__(self): + def __str__(self) -> str: return f"topic({self.name})" @@ -129,21 +129,21 @@ class LCMSpy(LCMService, Topic): graph_log_window: float = 1.0 topic_class: type[Topic] = Topic - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) Topic.__init__(self, name="total", history_window=self.config.topic_history_window) self.topic = {} self.l = lcm.LCM(self.config.url) if self.config.url else lcm.LCM() - def start(self): + def start(self) -> None: super().start() self.l.subscribe(".*", self.msg) - def stop(self): + def stop(self) -> None: """Stop the LCM spy and clean up resources""" super().stop() - def msg(self, topic, data): + def msg(self, topic, data) -> None: Topic.msg(self, data) if topic not in self.topic: @@ -155,12 +155,12 @@ def msg(self, topic, data): class GraphTopic(Topic): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.freq_history = deque(maxlen=20) self.bandwidth_history = deque(maxlen=20) - def update_graphs(self, step_window: float = 1.0): + def update_graphs(self, step_window: float = 1.0) -> None: """Update historical data for graphing""" freq = self.freq(step_window) kbps = self.kbps(step_window) @@ -180,23 +180,23 @@ class GraphLCMSpy(LCMSpy, GraphTopic): graph_log_stop_event: threading.Event = threading.Event() topic_class: type[Topic] = GraphTopic - def __init__(self, **kwargs): + def __init__(self, **kwargs) -> None: super().__init__(**kwargs) GraphTopic.__init__(self, name="total", history_window=self.config.topic_history_window) - def start(self): + def start(self) -> None: super().start() self.graph_log_thread = threading.Thread(target=self.graph_log, daemon=True) self.graph_log_thread.start() - def graph_log(self): + def graph_log(self) -> None: while not self.graph_log_stop_event.is_set(): self.update_graphs(self.config.graph_log_window) # Update global history for topic in self.topic.values(): topic.update_graphs(self.config.graph_log_window) time.sleep(self.config.graph_log_window) - def stop(self): + def stop(self) -> None: """Stop the graph logging and LCM spy""" self.graph_log_stop_event.set() if self.graph_log_thread and self.graph_log_thread.is_alive(): diff --git a/dimos/utils/cli/lcmspy/run_lcmspy.py b/dimos/utils/cli/lcmspy/run_lcmspy.py index 4faef02892..2e96156852 100644 --- a/dimos/utils/cli/lcmspy/run_lcmspy.py +++ b/dimos/utils/cli/lcmspy/run_lcmspy.py @@ -14,23 +14,13 @@ from __future__ import annotations -import math -import random -import threading -from typing import List - from rich.text import Text from textual.app import App, ComposeResult -from textual.binding import Binding from textual.color import Color -from textual.containers import Container -from textual.reactive import reactive -from textual.renderables.sparkline import Sparkline as SparklineRenderable -from textual.widgets import DataTable, Header, Label, Sparkline +from textual.widgets import DataTable from dimos.utils.cli import theme -from dimos.utils.cli.lcmspy.lcmspy import GraphLCMSpy -from dimos.utils.cli.lcmspy.lcmspy import GraphTopic as SpyTopic +from dimos.utils.cli.lcmspy.lcmspy import GraphLCMSpy, GraphTopic as SpyTopic def gradient(max_value: float, value: float) -> str: @@ -88,7 +78,7 @@ class LCMSpyApp(App): ("ctrl+c", "quit"), ] - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.spy = GraphLCMSpy(autoconf=True, graph_log_window=0.5) self.table: DataTable | None = None @@ -101,15 +91,15 @@ def compose(self) -> ComposeResult: self.table.add_column("Total Traffic") yield self.table - def on_mount(self): + def on_mount(self) -> None: self.spy.start() self.set_interval(self.refresh_interval, self.refresh_table) - async def on_unmount(self): + async def on_unmount(self) -> None: self.spy.stop() - def refresh_table(self): - topics: List[SpyTopic] = list(self.spy.topic.values()) + def refresh_table(self) -> None: + topics: list[SpyTopic] = list(self.spy.topic.values()) topics.sort(key=lambda t: t.total_traffic(), reverse=True) self.table.clear(columns=False) @@ -127,7 +117,7 @@ def refresh_table(self): ) -def main(): +def main() -> None: import sys if len(sys.argv) > 1 and sys.argv[1] == "web": diff --git a/dimos/utils/cli/lcmspy/test_lcmspy.py b/dimos/utils/cli/lcmspy/test_lcmspy.py index f72175ea10..56e8e72c3b 100644 --- a/dimos/utils/cli/lcmspy/test_lcmspy.py +++ b/dimos/utils/cli/lcmspy/test_lcmspy.py @@ -17,13 +17,11 @@ import pytest from dimos.protocol.pubsub.lcmpubsub import PickleLCM, Topic -from dimos.protocol.service.lcmservice import autoconf -from dimos.utils.cli.lcmspy.lcmspy import GraphLCMSpy, GraphTopic, LCMSpy -from dimos.utils.cli.lcmspy.lcmspy import Topic as TopicSpy +from dimos.utils.cli.lcmspy.lcmspy import GraphLCMSpy, GraphTopic, LCMSpy, Topic as TopicSpy @pytest.mark.lcm -def test_spy_basic(): +def test_spy_basic() -> None: lcm = PickleLCM(autoconf=True) lcm.start() @@ -82,7 +80,7 @@ def test_spy_basic(): @pytest.mark.lcm -def test_topic_statistics_direct(): +def test_topic_statistics_direct() -> None: """Test Topic statistics directly without LCM""" topic = TopicSpy("/test") @@ -90,7 +88,7 @@ def test_topic_statistics_direct(): # Add some test messages test_data = [b"small", b"medium sized message", b"very long message for testing purposes"] - for i, data in enumerate(test_data): + for _i, data in enumerate(test_data): topic.msg(data) time.sleep(0.1) # Simulate time passing @@ -108,7 +106,7 @@ def test_topic_statistics_direct(): print(f"Direct test - Avg size: {avg_size:.2f} bytes") -def test_topic_cleanup(): +def test_topic_cleanup() -> None: """Test that old messages are properly cleaned up""" topic = TopicSpy("/test") @@ -131,7 +129,7 @@ def test_topic_cleanup(): @pytest.mark.lcm -def test_graph_topic_basic(): +def test_graph_topic_basic() -> None: """Test GraphTopic basic functionality""" topic = GraphTopic("/test_graph") @@ -147,7 +145,7 @@ def test_graph_topic_basic(): @pytest.mark.lcm -def test_graph_lcmspy_basic(): +def test_graph_lcmspy_basic() -> None: """Test GraphLCMSpy basic functionality""" spy = GraphLCMSpy(autoconf=True, graph_log_window=0.1) spy.start() @@ -167,7 +165,7 @@ def test_graph_lcmspy_basic(): @pytest.mark.lcm -def test_lcmspy_global_totals(): +def test_lcmspy_global_totals() -> None: """Test that LCMSpy tracks global totals as a Topic itself""" spy = LCMSpy(autoconf=True) spy.start() @@ -197,7 +195,7 @@ def test_lcmspy_global_totals(): @pytest.mark.lcm -def test_graph_lcmspy_global_totals(): +def test_graph_lcmspy_global_totals() -> None: """Test that GraphLCMSpy tracks global totals with history""" spy = GraphLCMSpy(autoconf=True, graph_log_window=0.1) spy.start() diff --git a/dimos/utils/cli/skillspy/demo_skillspy.py b/dimos/utils/cli/skillspy/demo_skillspy.py index 20c5417a2e..f7d4875e01 100644 --- a/dimos/utils/cli/skillspy/demo_skillspy.py +++ b/dimos/utils/cli/skillspy/demo_skillspy.py @@ -15,8 +15,9 @@ """Demo script that runs skills in the background while agentspy monitors them.""" -import time import threading +import time + from dimos.protocol.skill.coordinator import SkillCoordinator from dimos.protocol.skill.skill import SkillContainer, skill @@ -25,7 +26,7 @@ class DemoSkills(SkillContainer): @skill() def count_to(self, n: int) -> str: """Count to n with delays.""" - for i in range(n): + for _i in range(n): time.sleep(0.5) return f"Counted to {n}" @@ -53,7 +54,7 @@ def quick_task(self, name: str) -> str: return f"Quick task '{name}' done!" -def run_demo_skills(): +def run_demo_skills() -> None: """Run demo skills in background.""" # Create and start agent interface agent_interface = SkillCoordinator() @@ -64,7 +65,7 @@ def run_demo_skills(): agent_interface.register_skills(demo_skills) # Run various skills periodically - def skill_runner(): + def skill_runner() -> None: counter = 0 while True: time.sleep(2) diff --git a/dimos/utils/cli/skillspy/skillspy.py b/dimos/utils/cli/skillspy/skillspy.py index bfb0a7edc8..769478b00e 100644 --- a/dimos/utils/cli/skillspy/skillspy.py +++ b/dimos/utils/cli/skillspy/skillspy.py @@ -14,32 +14,35 @@ from __future__ import annotations -import logging import threading import time -from typing import Callable, Dict, Optional +from typing import TYPE_CHECKING from rich.text import Text from textual.app import App, ComposeResult from textual.binding import Binding from textual.widgets import DataTable, Footer -from dimos.protocol.skill.comms import SkillMsg from dimos.protocol.skill.coordinator import SkillCoordinator, SkillState, SkillStateEnum from dimos.utils.cli import theme +if TYPE_CHECKING: + from collections.abc import Callable + + from dimos.protocol.skill.comms import SkillMsg + class AgentSpy: """Spy on agent skill executions via LCM messages.""" - def __init__(self): + def __init__(self) -> None: self.agent_interface = SkillCoordinator() - self.message_callbacks: list[Callable[[Dict[str, SkillState]], None]] = [] + self.message_callbacks: list[Callable[[dict[str, SkillState]], None]] = [] self._lock = threading.Lock() - self._latest_state: Dict[str, SkillState] = {} + self._latest_state: dict[str, SkillState] = {} self._running = False - def start(self): + def start(self) -> None: """Start spying on agent messages.""" self._running = True # Start the agent interface @@ -48,20 +51,20 @@ def start(self): # Subscribe to the agent interface's comms self.agent_interface.skill_transport.subscribe(self._handle_message) - def stop(self): + def stop(self) -> None: """Stop spying.""" self._running = False # Give threads a moment to finish processing time.sleep(0.2) self.agent_interface.stop() - def _handle_message(self, msg: SkillMsg): + def _handle_message(self, msg: SkillMsg) -> None: """Handle incoming skill messages.""" if not self._running: return # Small delay to ensure agent_interface has processed the message - def delayed_update(): + def delayed_update() -> None: time.sleep(0.1) if not self._running: return @@ -73,11 +76,11 @@ def delayed_update(): # Run in separate thread to not block LCM threading.Thread(target=delayed_update, daemon=True).start() - def subscribe(self, callback: Callable[[Dict[str, SkillState]], None]): + def subscribe(self, callback: Callable[[dict[str, SkillState]], None]) -> None: """Subscribe to state updates.""" self.message_callbacks.append(callback) - def get_state(self) -> Dict[str, SkillState]: + def get_state(self) -> dict[str, SkillState]: """Get current state snapshot.""" with self._lock: return self._latest_state.copy() @@ -137,10 +140,10 @@ class AgentSpyApp(App): Binding("ctrl+c", "quit", "Quit", show=False), ] - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super().__init__(*args, **kwargs) self.spy = AgentSpy() - self.table: Optional[DataTable] = None + self.table: DataTable | None = None self.skill_history: list[tuple[str, SkillState, float]] = [] # (call_id, state, start_time) def compose(self) -> ComposeResult: @@ -155,7 +158,7 @@ def compose(self) -> ComposeResult: yield self.table yield Footer() - def on_mount(self): + def on_mount(self) -> None: """Start the spy when app mounts.""" self.spy.subscribe(self.update_state) self.spy.start() @@ -163,11 +166,11 @@ def on_mount(self): # Set up periodic refresh to update durations self.set_interval(1.0, self.refresh_table) - def on_unmount(self): + def on_unmount(self) -> None: """Stop the spy when app unmounts.""" self.spy.stop() - def update_state(self, state: Dict[str, SkillState]): + def update_state(self, state: dict[str, SkillState]) -> None: """Update state from spy callback. State dict is keyed by call_id.""" # Update history with current state current_time = time.time() @@ -176,7 +179,7 @@ def update_state(self, state: Dict[str, SkillState]): for call_id, skill_state in state.items(): # Find if this call_id already in history found = False - for i, (existing_call_id, old_state, start_time) in enumerate(self.skill_history): + for i, (existing_call_id, _old_state, start_time) in enumerate(self.skill_history): if existing_call_id == call_id: # Update existing entry self.skill_history[i] = (call_id, skill_state, start_time) @@ -194,7 +197,7 @@ def update_state(self, state: Dict[str, SkillState]): # Schedule UI update self.call_from_thread(self.refresh_table) - def refresh_table(self): + def refresh_table(self) -> None: """Refresh the table display.""" if not self.table: return @@ -251,13 +254,13 @@ def refresh_table(self): Text(details, style=theme.FOREGROUND), ) - def action_clear(self): + def action_clear(self) -> None: """Clear the skill history.""" self.skill_history.clear() self.refresh_table() -def main(): +def main() -> None: """Main entry point for agentspy CLI.""" import sys diff --git a/dimos/utils/cli/theme.py b/dimos/utils/cli/theme.py index aa061bc43a..e3d98b07de 100644 --- a/dimos/utils/cli/theme.py +++ b/dimos/utils/cli/theme.py @@ -16,8 +16,8 @@ from __future__ import annotations -import re from pathlib import Path +import re def parse_tcss_colors(tcss_path: str | Path) -> dict[str, str]: diff --git a/dimos/utils/data.py b/dimos/utils/data.py index 0a2656ca82..8b70c2ad27 100644 --- a/dimos/utils/data.py +++ b/dimos/utils/data.py @@ -12,11 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import subprocess -import tarfile from functools import cache from pathlib import Path -from typing import Optional, Union +import subprocess +import tarfile @cache @@ -31,7 +30,7 @@ def _get_repo_root() -> Path: @cache -def _get_data_dir(extra_path: Optional[str] = None) -> Path: +def _get_data_dir(extra_path: str | None = None) -> Path: if extra_path: return _get_repo_root() / "data" / extra_path return _get_repo_root() / "data" @@ -59,7 +58,7 @@ def _is_lfs_pointer_file(file_path: Path) -> bool: if file_path.stat().st_size > 1024: # LFS pointers are much smaller return False - with open(file_path, "r", encoding="utf-8") as f: + with open(file_path, encoding="utf-8") as f: first_line = f.readline().strip() return first_line.startswith("version https://git-lfs.github.com/spec/") @@ -83,7 +82,7 @@ def _lfs_pull(file_path: Path, repo_root: Path) -> None: return None -def _decompress_archive(filename: Union[str, Path]) -> Path: +def _decompress_archive(filename: str | Path) -> Path: target_dir = _get_data_dir() filename_path = Path(filename) with tarfile.open(filename_path, "r:gz") as tar: @@ -91,7 +90,7 @@ def _decompress_archive(filename: Union[str, Path]) -> Path: return target_dir / filename_path.name.replace(".tar.gz", "") -def _pull_lfs_archive(filename: Union[str, Path]) -> Path: +def _pull_lfs_archive(filename: str | Path) -> Path: # Check Git LFS availability first _check_git_lfs_available() @@ -121,7 +120,7 @@ def _pull_lfs_archive(filename: Union[str, Path]) -> Path: return file_path -def get_data(filename: Union[str, Path]) -> Path: +def get_data(filename: str | Path) -> Path: """ Get the path to a test data, downloading from LFS if needed. diff --git a/dimos/utils/decorators/accumulators.py b/dimos/utils/decorators/accumulators.py index 4c57293b9f..7672ff7033 100644 --- a/dimos/utils/decorators/accumulators.py +++ b/dimos/utils/decorators/accumulators.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import threading from abc import ABC, abstractmethod -from typing import Any, Generic, Optional, TypeVar +import threading +from typing import Generic, TypeVar T = TypeVar("T") @@ -28,7 +28,7 @@ def add(self, *args, **kwargs) -> None: pass @abstractmethod - def get(self) -> Optional[tuple[tuple, dict]]: + def get(self) -> tuple[tuple, dict] | None: """Get the accumulated args and kwargs and reset the accumulator.""" pass @@ -41,15 +41,15 @@ def __len__(self) -> int: class LatestAccumulator(Accumulator[T]): """Simple accumulator that remembers only the latest args and kwargs.""" - def __init__(self): - self._latest: Optional[tuple[tuple, dict]] = None + def __init__(self) -> None: + self._latest: tuple[tuple, dict] | None = None self._lock = threading.Lock() def add(self, *args, **kwargs) -> None: with self._lock: self._latest = (args, kwargs) - def get(self) -> Optional[tuple[tuple, dict]]: + def get(self) -> tuple[tuple, dict] | None: with self._lock: result = self._latest self._latest = None @@ -67,7 +67,7 @@ class RollingAverageAccumulator(Accumulator[T]): a rolling average without storing individual values. """ - def __init__(self): + def __init__(self) -> None: self._sum: float = 0.0 self._count: int = 0 self._latest_kwargs: dict = {} @@ -86,7 +86,7 @@ def add(self, *args, **kwargs) -> None: except (TypeError, ValueError): raise TypeError(f"First argument must be numeric, got {type(args[0])}") - def get(self) -> Optional[tuple[tuple, dict]]: + def get(self) -> tuple[tuple, dict] | None: with self._lock: if self._count == 0: return None diff --git a/dimos/utils/decorators/decorators.py b/dimos/utils/decorators/decorators.py index 067251e5c6..4511aea309 100644 --- a/dimos/utils/decorators/decorators.py +++ b/dimos/utils/decorators/decorators.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable +from functools import wraps import threading import time -from functools import wraps -from typing import Callable, Optional, Type from .accumulators import Accumulator, LatestAccumulator -def limit(max_freq: float, accumulator: Optional[Accumulator] = None): +def limit(max_freq: float, accumulator: Accumulator | None = None): """ Decorator that limits function call frequency. @@ -46,9 +46,9 @@ def limit(max_freq: float, accumulator: Optional[Accumulator] = None): def decorator(func: Callable) -> Callable: last_call_time = 0.0 lock = threading.Lock() - timer: Optional[threading.Timer] = None + timer: threading.Timer | None = None - def execute_accumulated(): + def execute_accumulated() -> None: nonlocal last_call_time, timer with lock: if len(accumulator): @@ -145,7 +145,7 @@ def getter(self): return getter -def retry(max_retries: int = 3, on_exception: Type[Exception] = Exception, delay: float = 0.0): +def retry(max_retries: int = 3, on_exception: type[Exception] = Exception, delay: float = 0.0): """ Decorator that retries a function call if it raises an exception. diff --git a/dimos/utils/decorators/test_decorators.py b/dimos/utils/decorators/test_decorators.py index 133fab97c2..fdad670454 100644 --- a/dimos/utils/decorators/test_decorators.py +++ b/dimos/utils/decorators/test_decorators.py @@ -16,15 +16,15 @@ import pytest -from dimos.utils.decorators import LatestAccumulator, RollingAverageAccumulator, limit, retry +from dimos.utils.decorators import RollingAverageAccumulator, limit, retry -def test_limit(): +def test_limit() -> None: """Test limit decorator with keyword arguments.""" calls = [] @limit(20) # 20 Hz - def process(msg: str, keyword: int = 0): + def process(msg: str, keyword: int = 0) -> str: calls.append((msg, keyword)) return f"{msg}:{keyword}" @@ -49,14 +49,14 @@ def process(msg: str, keyword: int = 0): assert calls == [("first", 1), ("third", 3), ("fourth", 0)] -def test_latest_rolling_average(): +def test_latest_rolling_average() -> None: """Test RollingAverageAccumulator with limit decorator.""" calls = [] accumulator = RollingAverageAccumulator() @limit(20, accumulator=accumulator) # 20 Hz - def process(value: float, label: str = ""): + def process(value: float, label: str = "") -> str: calls.append((value, label)) return f"{value}:{label}" @@ -79,12 +79,12 @@ def process(value: float, label: str = ""): assert calls == [(10.0, "first"), (25.0, "third")] # (20+30)/2 = 25 -def test_retry_success_after_failures(): +def test_retry_success_after_failures() -> None: """Test that retry decorator retries on failure and eventually succeeds.""" attempts = [] @retry(max_retries=3) - def flaky_function(fail_times=2): + def flaky_function(fail_times: int = 2) -> str: attempts.append(len(attempts)) if len(attempts) <= fail_times: raise ValueError(f"Attempt {len(attempts)} failed") @@ -95,7 +95,7 @@ def flaky_function(fail_times=2): assert len(attempts) == 3 # Failed twice, succeeded on third attempt -def test_retry_exhausted(): +def test_retry_exhausted() -> None: """Test that retry decorator raises exception when retries are exhausted.""" attempts = [] @@ -111,12 +111,12 @@ def always_fails(): assert len(attempts) == 3 # Initial attempt + 2 retries -def test_retry_specific_exception(): +def test_retry_specific_exception() -> None: """Test that retry only catches specified exception types.""" attempts = [] @retry(max_retries=3, on_exception=ValueError) - def raises_different_exceptions(): + def raises_different_exceptions() -> str: attempts.append(len(attempts)) if len(attempts) == 1: raise ValueError("First attempt") @@ -132,12 +132,12 @@ def raises_different_exceptions(): assert len(attempts) == 2 # First attempt with ValueError, second with TypeError -def test_retry_no_failures(): +def test_retry_no_failures() -> None: """Test that retry decorator works when function succeeds immediately.""" attempts = [] @retry(max_retries=5) - def always_succeeds(): + def always_succeeds() -> str: attempts.append(len(attempts)) return "immediate success" @@ -146,13 +146,13 @@ def always_succeeds(): assert len(attempts) == 1 # Only one attempt needed -def test_retry_with_delay(): +def test_retry_with_delay() -> None: """Test that retry decorator applies delay between attempts.""" attempts = [] times = [] @retry(max_retries=2, delay=0.1) - def delayed_failures(): + def delayed_failures() -> str: times.append(time.time()) attempts.append(len(attempts)) if len(attempts) < 2: @@ -172,7 +172,7 @@ def delayed_failures(): assert times[1] - times[0] >= 0.1 -def test_retry_zero_retries(): +def test_retry_zero_retries() -> None: """Test retry with max_retries=0 (no retries, just one attempt).""" attempts = [] @@ -187,31 +187,31 @@ def single_attempt(): assert len(attempts) == 1 # Only the initial attempt -def test_retry_invalid_parameters(): +def test_retry_invalid_parameters() -> None: """Test that retry decorator validates parameters.""" with pytest.raises(ValueError): @retry(max_retries=-1) - def invalid_retries(): + def invalid_retries() -> None: pass with pytest.raises(ValueError): @retry(delay=-0.5) - def invalid_delay(): + def invalid_delay() -> None: pass -def test_retry_with_methods(): +def test_retry_with_methods() -> None: """Test that retry decorator works with class methods, instance methods, and static methods.""" class TestClass: - def __init__(self): + def __init__(self) -> None: self.instance_attempts = [] self.instance_value = 42 @retry(max_retries=3) - def instance_method(self, fail_times=2): + def instance_method(self, fail_times: int = 2) -> str: """Test retry on instance method.""" self.instance_attempts.append(len(self.instance_attempts)) if len(self.instance_attempts) <= fail_times: @@ -220,7 +220,7 @@ def instance_method(self, fail_times=2): @classmethod @retry(max_retries=2) - def class_method(cls, attempts_list, fail_times=1): + def class_method(cls, attempts_list, fail_times: int = 1) -> str: """Test retry on class method.""" attempts_list.append(len(attempts_list)) if len(attempts_list) <= fail_times: @@ -229,7 +229,7 @@ def class_method(cls, attempts_list, fail_times=1): @staticmethod @retry(max_retries=2) - def static_method(attempts_list, fail_times=1): + def static_method(attempts_list, fail_times: int = 1) -> str: """Test retry on static method.""" attempts_list.append(len(attempts_list)) if len(attempts_list) <= fail_times: diff --git a/dimos/utils/deprecation.py b/dimos/utils/deprecation.py index dca63d853f..3c4dd5929e 100644 --- a/dimos/utils/deprecation.py +++ b/dimos/utils/deprecation.py @@ -12,8 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import warnings import functools +import warnings def deprecated(reason: str): diff --git a/dimos/utils/extract_frames.py b/dimos/utils/extract_frames.py index ddff12f189..d57b0641cd 100644 --- a/dimos/utils/extract_frames.py +++ b/dimos/utils/extract_frames.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cv2 import argparse from pathlib import Path +import cv2 + -def extract_frames(video_path, output_dir, frame_rate): +def extract_frames(video_path, output_dir, frame_rate) -> None: """ Extract frames from a video file at a specified frame rate. @@ -40,7 +41,7 @@ def extract_frames(video_path, output_dir, frame_rate): return # Calculate the interval between frames to capture - frame_interval = int(round(original_frame_rate / frame_rate)) + frame_interval = round(original_frame_rate / frame_rate) if frame_interval == 0: frame_interval = 1 diff --git a/dimos/utils/generic.py b/dimos/utils/generic.py index 7c776e984e..adbb18988f 100644 --- a/dimos/utils/generic.py +++ b/dimos/utils/generic.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os +import hashlib import json -import uuid +import os import string -import hashlib -from typing import Any, Optional +from typing import Any +import uuid -def truncate_display_string(arg: Any, max: Optional[int] = None) -> str: +def truncate_display_string(arg: Any, max: int | None = None) -> str: """ If we print strings that are too long that potentially obscures more important logs. diff --git a/dimos/utils/generic_subscriber.py b/dimos/utils/generic_subscriber.py index 17e619c28c..5f687c494a 100644 --- a/dimos/utils/generic_subscriber.py +++ b/dimos/utils/generic_subscriber.py @@ -14,11 +14,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import threading import logging -from typing import Optional, Any +import threading +from typing import TYPE_CHECKING, Any + from reactivex import Observable -from reactivex.disposable import Disposable + +if TYPE_CHECKING: + from reactivex.disposable import Disposable logger = logging.getLogger(__name__) @@ -26,17 +29,17 @@ class GenericSubscriber: """Subscribes to an RxPy Observable stream and stores the latest message.""" - def __init__(self, stream: Observable): + def __init__(self, stream: Observable) -> None: """Initialize the subscriber and subscribe to the stream. Args: stream: The RxPy Observable stream to subscribe to. """ - self.latest_message: Optional[Any] = None + self.latest_message: Any | None = None self._lock = threading.Lock() - self._subscription: Optional[Disposable] = None + self._subscription: Disposable | None = None self._stream_completed = threading.Event() - self._stream_error: Optional[Exception] = None + self._stream_error: Exception | None = None if stream is not None: try: @@ -50,25 +53,25 @@ def __init__(self, stream: Observable): else: logger.warning("Initialized GenericSubscriber with a None stream.") - def _on_next(self, message: Any): + def _on_next(self, message: Any) -> None: """Callback for receiving a new message.""" with self._lock: self.latest_message = message # logger.debug("Received new message") # Can be noisy - def _on_error(self, error: Exception): + def _on_error(self, error: Exception) -> None: """Callback for stream error.""" logger.error(f"Stream error: {error}") with self._lock: self._stream_error = error self._stream_completed.set() # Signal completion/error - def _on_completed(self): + def _on_completed(self) -> None: """Callback for stream completion.""" logger.info("Stream completed.") self._stream_completed.set() - def get_data(self) -> Optional[Any]: + def get_data(self) -> Any | None: """Get the latest message received from the stream. Returns: @@ -89,7 +92,7 @@ def is_completed(self) -> bool: """Check if the stream has completed or encountered an error.""" return self._stream_completed.is_set() - def dispose(self): + def dispose(self) -> None: """Dispose of the subscription to stop receiving messages.""" if self._subscription is not None: try: @@ -100,6 +103,6 @@ def dispose(self): logger.error(f"Error disposing subscription: {e}") self._stream_completed.set() # Ensure completed flag is set on manual dispose - def __del__(self): + def __del__(self) -> None: """Ensure cleanup on object deletion.""" self.dispose() diff --git a/dimos/utils/gpu_utils.py b/dimos/utils/gpu_utils.py index e40516deec..e0a1a23734 100644 --- a/dimos/utils/gpu_utils.py +++ b/dimos/utils/gpu_utils.py @@ -16,7 +16,6 @@ def is_cuda_available(): try: import pycuda.driver as cuda - import pycuda.autoinit # implicitly initializes the CUDA driver cuda.init() return cuda.Device.count() > 0 diff --git a/dimos/utils/llm_utils.py b/dimos/utils/llm_utils.py index 05cc44ad24..124169e794 100644 --- a/dimos/utils/llm_utils.py +++ b/dimos/utils/llm_utils.py @@ -14,10 +14,9 @@ import json import re -from typing import Union -def extract_json(response: str) -> Union[dict, list]: +def extract_json(response: str) -> dict | list: """Extract JSON from potentially messy LLM response. Tries multiple strategies: diff --git a/dimos/utils/logging_config.py b/dimos/utils/logging_config.py index a0a6a5fc4a..e12b1e4828 100644 --- a/dimos/utils/logging_config.py +++ b/dimos/utils/logging_config.py @@ -19,7 +19,6 @@ import logging import os -from typing import Optional import colorlog @@ -33,7 +32,7 @@ def setup_logger( - name: str, level: Optional[int] = None, log_format: Optional[str] = None + name: str, level: int | None = None, log_format: str | None = None ) -> logging.Logger: """Set up a logger with color output. diff --git a/dimos/utils/monitoring.py b/dimos/utils/monitoring.py index abadbe591c..17415781b5 100644 --- a/dimos/utils/monitoring.py +++ b/dimos/utils/monitoring.py @@ -18,25 +18,24 @@ echo 0 | sudo tee /proc/sys/kernel/yama/ptrace_scope """ -import subprocess -import threading -import re +from functools import cache import os +import re import shutil -from functools import lru_cache, partial -from typing import Optional -from distributed.client import Client +import subprocess +import threading from distributed import get_client +from distributed.client import Client + from dimos.core import Module, rpc from dimos.utils.actor_registry import ActorRegistry from dimos.utils.logging_config import setup_logger - logger = setup_logger(__file__) -def print_data_table(data): +def print_data_table(data) -> None: headers = [ "cpu_percent", "active_percent", @@ -88,13 +87,13 @@ class UtilizationThread(threading.Thread): _stop_event: threading.Event _monitors: dict - def __init__(self, module): + def __init__(self, module) -> None: super().__init__(daemon=True) self._module = module self._stop_event = threading.Event() self._monitors = {} - def run(self): + def run(self) -> None: while not self._stop_event.is_set(): workers = self._module.client.scheduler_info()["workers"] pids = {pid: None for pid in get_worker_pids()} @@ -124,13 +123,13 @@ def run(self): print_data_table(data) self._stop_event.wait(1) - def stop(self): + def stop(self) -> None: self._stop_event.set() for monitor in self._monitors.values(): monitor.stop() monitor.join(timeout=2) - def _fix_missing_ids(self, data): + def _fix_missing_ids(self, data) -> None: """ Some worker IDs are None. But if we order the workers by PID and all non-None ids are in order, then we can deduce that the None ones are the @@ -142,10 +141,10 @@ def _fix_missing_ids(self, data): class UtilizationModule(Module): - client: Optional[Client] - _utilization_thread: Optional[UtilizationThread] + client: Client | None + _utilization_thread: UtilizationThread | None - def __init__(self): + def __init__(self) -> None: super().__init__() self.client = None self._utilization_thread = None @@ -171,14 +170,14 @@ def __init__(self): self._utilization_thread = UtilizationThread(self) @rpc - def start(self): + def start(self) -> None: super().start() if self._utilization_thread: self._utilization_thread.start() @rpc - def stop(self): + def stop(self) -> None: if self._utilization_thread: self._utilization_thread.stop() self._utilization_thread.join(timeout=2) @@ -201,7 +200,7 @@ def _can_use_py_spy(): return False -@lru_cache(maxsize=None) +@cache def get_pid_by_port(port: int) -> int | None: try: result = subprocess.run( @@ -219,7 +218,7 @@ def get_worker_pids(): if not pid.isdigit(): continue try: - with open(f"/proc/{pid}/cmdline", "r") as f: + with open(f"/proc/{pid}/cmdline") as f: cmdline = f.read().replace("\x00", " ") if "spawn_main" in cmdline: pids.append(int(pid)) @@ -234,7 +233,7 @@ class GilMonitorThread(threading.Thread): _stop_event: threading.Event _lock: threading.Lock - def __init__(self, pid): + def __init__(self, pid: int) -> None: super().__init__(daemon=True) self.pid = pid self._latest_values = (-1.0, -1.0, -1.0, -1) @@ -279,7 +278,7 @@ def run(self): active_percent, num_threads, ) - except (ValueError, IndexError) as e: + except (ValueError, IndexError): pass except Exception as e: logger.error(f"An error occurred in GilMonitorThread for PID {self.pid}: {e}") @@ -294,7 +293,7 @@ def get_values(self): with self._lock: return self._latest_values - def stop(self): + def stop(self) -> None: self._stop_event.set() diff --git a/dimos/utils/reactive.py b/dimos/utils/reactive.py index 74c7044648..f7885d3129 100644 --- a/dimos/utils/reactive.py +++ b/dimos/utils/reactive.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable import threading -from typing import Any, Callable, Generic, Optional, TypeVar +from typing import Any, Generic, TypeVar import reactivex as rx from reactivex import operators as ops @@ -32,7 +33,7 @@ # └──► observe_on(pool) ─► backpressure.latest ─► sub3 (slower) def backpressure( observable: Observable[T], - scheduler: Optional[ThreadPoolScheduler] = None, + scheduler: ThreadPoolScheduler | None = None, drop_unprocessed: bool = True, ) -> Observable[T]: if scheduler is None: @@ -65,7 +66,7 @@ def _subscribe(observer, sch=None): class LatestReader(Generic[T]): """A callable object that returns the latest value from an observable.""" - def __init__(self, initial_value: T, subscription, connection=None): + def __init__(self, initial_value: T, subscription, connection=None) -> None: self._value = initial_value self._subscription = subscription self._connection = connection @@ -81,21 +82,21 @@ def dispose(self) -> None: self._connection.dispose() -def getter_ondemand(observable: Observable[T], timeout: Optional[float] = 30.0) -> T: +def getter_ondemand(observable: Observable[T], timeout: float | None = 30.0) -> T: def getter(): result = [] error = [] event = threading.Event() - def on_next(value): + def on_next(value) -> None: result.append(value) event.set() - def on_error(e): + def on_error(e) -> None: error.append(e) event.set() - def on_completed(): + def on_completed() -> None: event.set() # Subscribe and wait for first value @@ -128,7 +129,7 @@ def on_completed(): def getter_streaming( source: Observable[T], - timeout: Optional[float] = 30.0, + timeout: float | None = 30.0, *, nonblocking: bool = False, ) -> LatestReader[T]: @@ -182,7 +183,7 @@ def callback_to_observable( stop: Callable[[CB[T]], Any], ) -> Observable[T]: def _subscribe(observer, _scheduler=None): - def _on_msg(value: T): + def _on_msg(value: T) -> None: observer.on_next(value) start(_on_msg) diff --git a/dimos/utils/s3_utils.py b/dimos/utils/s3_utils.py index b8f2c32b86..f4c3227a71 100644 --- a/dimos/utils/s3_utils.py +++ b/dimos/utils/s3_utils.py @@ -12,9 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import boto3 import os +import boto3 + try: import open3d as o3d except Exception as e: @@ -22,25 +23,25 @@ class S3Utils: - def __init__(self, bucket_name): + def __init__(self, bucket_name: str) -> None: self.s3 = boto3.client("s3") self.bucket_name = bucket_name - def download_file(self, s3_key, local_path): + def download_file(self, s3_key, local_path) -> None: try: self.s3.download_file(self.bucket_name, s3_key, local_path) print(f"Downloaded {s3_key} to {local_path}") except Exception as e: print(f"Error downloading {s3_key}: {e}") - def upload_file(self, local_path, s3_key): + def upload_file(self, local_path, s3_key) -> None: try: self.s3.upload_file(local_path, self.bucket_name, s3_key) print(f"Uploaded {local_path} to {s3_key}") except Exception as e: print(f"Error uploading {local_path}: {e}") - def save_pointcloud_to_s3(self, inlier_cloud, s3_key): + def save_pointcloud_to_s3(self, inlier_cloud, s3_key) -> None: try: temp_pcd_file = "/tmp/temp_pointcloud.pcd" o3d.io.write_point_cloud(temp_pcd_file, inlier_cloud) @@ -74,10 +75,10 @@ def restore_pointcloud_from_s3(self, pointcloud_paths): return restored_pointclouds @staticmethod - def upload_text_file(bucket_name, local_path, s3_key): + def upload_text_file(bucket_name: str, local_path, s3_key) -> None: s3 = boto3.client("s3") try: - with open(local_path, "r") as file: + with open(local_path) as file: content = file.read() # Ensure the s3_key includes the file name diff --git a/dimos/utils/simple_controller.py b/dimos/utils/simple_controller.py index 99260fa8b2..dd92ae0c55 100644 --- a/dimos/utils/simple_controller.py +++ b/dimos/utils/simple_controller.py @@ -15,7 +15,7 @@ import math -def normalize_angle(angle): +def normalize_angle(angle: float): """Normalize angle to the range [-pi, pi].""" return math.atan2(math.sin(angle), math.cos(angle)) @@ -27,14 +27,14 @@ class PIDController: def __init__( self, kp, - ki=0.0, - kd=0.0, + ki: float = 0.0, + kd: float = 0.0, output_limits=(None, None), integral_limit=None, - deadband=0.0, - output_deadband=0.0, - inverse_output=False, - ): + deadband: float = 0.0, + output_deadband: float = 0.0, + inverse_output: bool = False, + ) -> None: """ Initialize the PID controller. @@ -124,7 +124,7 @@ def _apply_deadband_compensation(self, error): # Visual Servoing Controller Class # ---------------------------- class VisualServoingController: - def __init__(self, distance_pid_params, angle_pid_params): + def __init__(self, distance_pid_params, angle_pid_params) -> None: """ Initialize the visual servoing controller using enhanced PID controllers. diff --git a/dimos/utils/test_data.py b/dimos/utils/test_data.py index c584e0cdcc..b6df8e1a12 100644 --- a/dimos/utils/test_data.py +++ b/dimos/utils/test_data.py @@ -18,12 +18,11 @@ import pytest -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.utils import data @pytest.mark.heavy -def test_pull_file(): +def test_pull_file() -> None: repo_root = data._get_repo_root() test_file_name = "cafe.jpg" test_file_compressed = data._get_lfs_dir() / (test_file_name + ".tar.gz") @@ -79,7 +78,7 @@ def test_pull_file(): @pytest.mark.heavy -def test_pull_dir(): +def test_pull_dir() -> None: repo_root = data._get_repo_root() test_dir_name = "ab_lidar_frames" test_dir_compressed = data._get_lfs_dir() / (test_dir_name + ".tar.gz") @@ -124,6 +123,7 @@ def test_pull_dir(): "6c3aaa9a79853ea4a7453c7db22820980ceb55035777f7460d05a0fa77b3b1b3", "456cc2c23f4ffa713b4e0c0d97143c27e48bbe6ef44341197b31ce84b3650e74", ], + strict=False, ): with file.open("rb") as f: sha256 = hashlib.sha256(f.read()).hexdigest() diff --git a/dimos/utils/test_foxglove_bridge.py b/dimos/utils/test_foxglove_bridge.py index b845622d88..ad597c8720 100644 --- a/dimos/utils/test_foxglove_bridge.py +++ b/dimos/utils/test_foxglove_bridge.py @@ -17,10 +17,7 @@ Test for foxglove bridge import and basic functionality """ -import threading -import time import warnings -from unittest.mock import MagicMock, patch import pytest @@ -28,7 +25,7 @@ warnings.filterwarnings("ignore", category=DeprecationWarning, module="websockets.legacy") -def test_foxglove_bridge_import(): +def test_foxglove_bridge_import() -> None: """Test that the foxglove bridge can be imported successfully.""" try: from dimos_lcm.foxglove_bridge import FoxgloveBridge @@ -36,7 +33,7 @@ def test_foxglove_bridge_import(): pytest.fail(f"Failed to import foxglove bridge: {e}") -def test_foxglove_bridge_runner_init(): +def test_foxglove_bridge_runner_init() -> None: """Test that LcmFoxgloveBridge can be initialized with default parameters.""" try: from dimos_lcm.foxglove_bridge import FoxgloveBridge @@ -50,7 +47,7 @@ def test_foxglove_bridge_runner_init(): pytest.fail(f"Failed to initialize LcmFoxgloveBridge: {e}") -def test_foxglove_bridge_runner_params(): +def test_foxglove_bridge_runner_params() -> None: """Test that LcmFoxgloveBridge accepts various parameter configurations.""" try: from dimos_lcm.foxglove_bridge import FoxgloveBridge @@ -69,7 +66,7 @@ def test_foxglove_bridge_runner_params(): pytest.fail(f"Failed to create runner with different configs: {e}") -def test_bridge_runner_has_run_method(): +def test_bridge_runner_has_run_method() -> None: """Test that the bridge runner has a run method that can be called.""" try: from dimos_lcm.foxglove_bridge import FoxgloveBridge @@ -78,7 +75,7 @@ def test_bridge_runner_has_run_method(): # Check that the run method exists assert hasattr(runner, "run") - assert callable(getattr(runner, "run")) + assert callable(runner.run) except Exception as e: pytest.fail(f"Failed to verify run method: {e}") diff --git a/dimos/utils/test_generic.py b/dimos/utils/test_generic.py index f85201d9bf..51e7a2007a 100644 --- a/dimos/utils/test_generic.py +++ b/dimos/utils/test_generic.py @@ -13,6 +13,7 @@ # limitations under the License. from uuid import UUID + from dimos.utils.generic import short_id diff --git a/dimos/utils/test_llm_utils.py b/dimos/utils/test_llm_utils.py index 4073fd8af2..2eb2da9867 100644 --- a/dimos/utils/test_llm_utils.py +++ b/dimos/utils/test_llm_utils.py @@ -21,14 +21,14 @@ from dimos.utils.llm_utils import extract_json -def test_extract_json_clean_response(): +def test_extract_json_clean_response() -> None: """Test extract_json with clean JSON response.""" clean_json = '[["object", 1, 2, 3, 4]]' result = extract_json(clean_json) assert result == [["object", 1, 2, 3, 4]] -def test_extract_json_with_text_before_after(): +def test_extract_json_with_text_before_after() -> None: """Test extract_json with text before and after JSON.""" messy = """Here's what I found: [ @@ -40,7 +40,7 @@ def test_extract_json_with_text_before_after(): assert result == [["person", 10, 20, 30, 40], ["car", 50, 60, 70, 80]] -def test_extract_json_with_emojis(): +def test_extract_json_with_emojis() -> None: """Test extract_json with emojis and markdown code blocks.""" messy = """Sure! 😊 Here are the detections: @@ -53,7 +53,7 @@ def test_extract_json_with_emojis(): assert result == [["human", 100, 200, 300, 400]] -def test_extract_json_multiple_json_blocks(): +def test_extract_json_multiple_json_blocks() -> None: """Test extract_json when there are multiple JSON blocks.""" messy = """First attempt (wrong format): {"error": "not what we want"} @@ -70,14 +70,14 @@ def test_extract_json_multiple_json_blocks(): assert result == [["cat", 10, 10, 50, 50], ["dog", 60, 60, 100, 100]] -def test_extract_json_object(): +def test_extract_json_object() -> None: """Test extract_json with JSON object instead of array.""" response = 'The result is: {"status": "success", "count": 5}' result = extract_json(response) assert result == {"status": "success", "count": 5} -def test_extract_json_nested_structures(): +def test_extract_json_nested_structures() -> None: """Test extract_json with nested arrays and objects.""" response = """Processing complete: [ @@ -91,7 +91,7 @@ def test_extract_json_nested_structures(): assert result[2] == ["label2", 5, 6, 7, 8] -def test_extract_json_invalid(): +def test_extract_json_invalid() -> None: """Test extract_json raises error when no valid JSON found.""" response = "This response has no valid JSON at all!" with pytest.raises(json.JSONDecodeError) as exc_info: @@ -114,7 +114,7 @@ def test_extract_json_invalid(): Hope this helps!😀😊 :)""" -def test_extract_json_with_real_llm_response(): +def test_extract_json_with_real_llm_response() -> None: """Test extract_json with actual messy LLM response.""" result = extract_json(MOCK_LLM_RESPONSE) assert isinstance(result, list) diff --git a/dimos/utils/test_reactive.py b/dimos/utils/test_reactive.py index 8c6d868e97..8fae6de0db 100644 --- a/dimos/utils/test_reactive.py +++ b/dimos/utils/test_reactive.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable import time -from typing import Any, Callable, TypeVar +from typing import Any, TypeVar import numpy as np import pytest @@ -46,13 +47,15 @@ def assert_time( return result -def min_time(func: Callable[[], Any], min_t: int, assert_fail_msg="Function returned too fast"): +def min_time( + func: Callable[[], Any], min_t: int, assert_fail_msg: str = "Function returned too fast" +): return assert_time( func, (lambda t: t >= min_t * 0.98), assert_fail_msg + f", min: {min_t} seconds" ) -def max_time(func: Callable[[], Any], max_t: int, assert_fail_msg="Function took too long"): +def max_time(func: Callable[[], Any], max_t: int, assert_fail_msg: str = "Function took too long"): return assert_time(func, (lambda t: t < max_t), assert_fail_msg + f", max: {max_t} seconds") @@ -66,7 +69,7 @@ def factory(observer, scheduler=None): state["active"] += 1 upstream = source.subscribe(observer, scheduler=scheduler) - def _dispose(): + def _dispose() -> None: upstream.dispose() state["active"] -= 1 @@ -78,7 +81,7 @@ def _dispose(): return proxy -def test_backpressure_handling(): +def test_backpressure_handling() -> None: # Create a dedicated scheduler for this test to avoid thread leaks test_scheduler = ThreadPoolScheduler(max_workers=8) try: @@ -137,7 +140,7 @@ def test_backpressure_handling(): test_scheduler.executor.shutdown(wait=True) -def test_getter_streaming_blocking(): +def test_getter_streaming_blocking() -> None: source = dispose_spy( rx.interval(0.2).pipe(ops.map(lambda i: np.array([i, i + 1, i + 2])), ops.take(50)) ) @@ -162,7 +165,7 @@ def test_getter_streaming_blocking(): assert source.is_disposed(), "Observable should be disposed" -def test_getter_streaming_blocking_timeout(): +def test_getter_streaming_blocking_timeout() -> None: source = dispose_spy(rx.interval(0.2).pipe(ops.take(50))) with pytest.raises(Exception): getter = getter_streaming(source, timeout=0.1) @@ -171,7 +174,7 @@ def test_getter_streaming_blocking_timeout(): assert source.is_disposed() -def test_getter_streaming_nonblocking(): +def test_getter_streaming_nonblocking() -> None: source = dispose_spy(rx.interval(0.2).pipe(ops.take(50))) getter = max_time( @@ -196,7 +199,7 @@ def test_getter_streaming_nonblocking(): assert source.is_disposed(), "Observable should be disposed" -def test_getter_streaming_nonblocking_timeout(): +def test_getter_streaming_nonblocking_timeout() -> None: source = dispose_spy(rx.interval(0.2).pipe(ops.take(50))) getter = getter_streaming(source, timeout=0.1, nonblocking=True) with pytest.raises(Exception): @@ -210,7 +213,7 @@ def test_getter_streaming_nonblocking_timeout(): assert source.is_disposed(), "Observable should be disposed after cleanup" -def test_getter_ondemand(): +def test_getter_ondemand() -> None: # Create a controlled scheduler to avoid thread leaks from rx.interval test_scheduler = ThreadPoolScheduler(max_workers=4) try: @@ -232,7 +235,7 @@ def test_getter_ondemand(): test_scheduler.executor.shutdown(wait=True) -def test_getter_ondemand_timeout(): +def test_getter_ondemand_timeout() -> None: source = dispose_spy(rx.interval(0.2).pipe(ops.take(50))) getter = getter_ondemand(source, timeout=0.1) with pytest.raises(Exception): @@ -242,13 +245,13 @@ def test_getter_ondemand_timeout(): time.sleep(0.3) -def test_callback_to_observable(): +def test_callback_to_observable() -> None: # Test converting a callback-based API to an Observable received = [] callback = None # Mock start function that captures the callback - def start_fn(cb): + def start_fn(cb) -> str: nonlocal callback callback = cb return "start_result" @@ -256,7 +259,7 @@ def start_fn(cb): # Mock stop function stop_called = False - def stop_fn(cb): + def stop_fn(cb) -> None: nonlocal stop_called stop_called = True diff --git a/dimos/utils/test_testing.py b/dimos/utils/test_testing.py index 017b267c1b..3684031170 100644 --- a/dimos/utils/test_testing.py +++ b/dimos/utils/test_testing.py @@ -12,12 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import hashlib -import os import re -import subprocess -import reactivex as rx from reactivex import operators as ops from dimos.robot.unitree_webrtc.type.lidar import LidarMessage @@ -26,7 +22,7 @@ from dimos.utils.data import get_data -def test_sensor_replay(): +def test_sensor_replay() -> None: counter = 0 for message in testing.SensorReplay(name="office_lidar").iterate(): counter += 1 @@ -34,7 +30,7 @@ def test_sensor_replay(): assert counter == 500 -def test_sensor_replay_cast(): +def test_sensor_replay_cast() -> None: counter = 0 for message in testing.SensorReplay( name="office_lidar", autocast=LidarMessage.from_msg @@ -44,7 +40,7 @@ def test_sensor_replay_cast(): assert counter == 500 -def test_timed_sensor_replay(): +def test_timed_sensor_replay() -> None: get_data("unitree_office_walk") odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) @@ -70,7 +66,7 @@ def test_timed_sensor_replay(): assert itermsgs[i] == timed_msgs[i] -def test_iterate_ts_no_seek(): +def test_iterate_ts_no_seek() -> None: """Test iterate_ts without seek (start_timestamp=None)""" odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) @@ -88,7 +84,7 @@ def test_iterate_ts_no_seek(): assert isinstance(msg, Odometry) -def test_iterate_ts_with_from_timestamp(): +def test_iterate_ts_with_from_timestamp() -> None: """Test iterate_ts with from_timestamp (absolute timestamp)""" odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) @@ -116,7 +112,7 @@ def test_iterate_ts_with_from_timestamp(): assert seeked_msgs[0][1] == all_msgs[4][1] -def test_iterate_ts_with_relative_seek(): +def test_iterate_ts_with_relative_seek() -> None: """Test iterate_ts with seek (relative seconds after first timestamp)""" odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) @@ -145,7 +141,7 @@ def test_iterate_ts_with_relative_seek(): assert seeked_msgs[0][0] > first_ts -def test_stream_with_seek(): +def test_stream_with_seek() -> None: """Test stream method with seek parameters""" odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) @@ -171,7 +167,7 @@ def test_stream_with_seek(): msgs_with_timestamp.append(msg) -def test_duration_with_loop(): +def test_duration_with_loop() -> None: """Test duration parameter with looping in TimedSensorReplay""" odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) @@ -180,7 +176,7 @@ def test_duration_with_loop(): duration = 0.3 # 300ms window # First pass: collect timestamps in the duration window - for ts, msg in odom_store.iterate_ts(duration=duration): + for ts, _msg in odom_store.iterate_ts(duration=duration): collected_ts.append(ts) if len(collected_ts) >= 100: # Safety limit break @@ -193,7 +189,7 @@ def test_duration_with_loop(): loop_count = 0 prev_ts = None - for ts, msg in odom_store.iterate_ts(duration=duration, loop=True): + for ts, _msg in odom_store.iterate_ts(duration=duration, loop=True): if prev_ts is not None and ts < prev_ts: # We've looped back to the beginning loop_count += 1 @@ -204,7 +200,7 @@ def test_duration_with_loop(): assert loop_count >= 2 # Verify we actually looped -def test_first_methods(): +def test_first_methods() -> None: """Test first() and first_timestamp() methods""" # Test SensorReplay.first() @@ -243,13 +239,13 @@ def test_first_methods(): assert isinstance(first_data, Odometry) -def test_find_closest(): +def test_find_closest() -> None: """Test find_closest method in TimedSensorReplay""" odom_store = testing.TimedSensorReplay("unitree_office_walk/odom", autocast=Odometry.from_msg) # Get some reference timestamps timestamps = [] - for ts, msg in odom_store.iterate_ts(): + for ts, _msg in odom_store.iterate_ts(): timestamps.append(ts) if len(timestamps) >= 10: break diff --git a/dimos/utils/test_transform_utils.py b/dimos/utils/test_transform_utils.py index 85128ac09c..8054971d3f 100644 --- a/dimos/utils/test_transform_utils.py +++ b/dimos/utils/test_transform_utils.py @@ -12,36 +12,36 @@ # See the License for the specific language governing permissions and # limitations under the License. -import pytest import numpy as np +import pytest from scipy.spatial.transform import Rotation as R +from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 from dimos.utils import transform_utils -from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion, Transform class TestNormalizeAngle: - def test_normalize_angle_zero(self): + def test_normalize_angle_zero(self) -> None: assert transform_utils.normalize_angle(0) == 0 - def test_normalize_angle_pi(self): + def test_normalize_angle_pi(self) -> None: assert np.isclose(transform_utils.normalize_angle(np.pi), np.pi) - def test_normalize_angle_negative_pi(self): + def test_normalize_angle_negative_pi(self) -> None: assert np.isclose(transform_utils.normalize_angle(-np.pi), -np.pi) - def test_normalize_angle_two_pi(self): + def test_normalize_angle_two_pi(self) -> None: # 2*pi should normalize to 0 assert np.isclose(transform_utils.normalize_angle(2 * np.pi), 0, atol=1e-10) - def test_normalize_angle_large_positive(self): + def test_normalize_angle_large_positive(self) -> None: # Large positive angle should wrap to [-pi, pi] angle = 5 * np.pi normalized = transform_utils.normalize_angle(angle) assert -np.pi <= normalized <= np.pi assert np.isclose(normalized, np.pi) - def test_normalize_angle_large_negative(self): + def test_normalize_angle_large_negative(self) -> None: # Large negative angle should wrap to [-pi, pi] angle = -5 * np.pi normalized = transform_utils.normalize_angle(angle) @@ -54,19 +54,19 @@ def test_normalize_angle_large_negative(self): class TestPoseToMatrix: - def test_identity_pose(self): + def test_identity_pose(self) -> None: pose = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) T = transform_utils.pose_to_matrix(pose) assert np.allclose(T, np.eye(4)) - def test_translation_only(self): + def test_translation_only(self) -> None: pose = Pose(Vector3(1, 2, 3), Quaternion(0, 0, 0, 1)) T = transform_utils.pose_to_matrix(pose) expected = np.eye(4) expected[:3, 3] = [1, 2, 3] assert np.allclose(T, expected) - def test_rotation_only_90_degrees_z(self): + def test_rotation_only_90_degrees_z(self) -> None: # 90 degree rotation around z-axis quat = R.from_euler("z", np.pi / 2).as_quat() pose = Pose(Vector3(0, 0, 0), Quaternion(quat[0], quat[1], quat[2], quat[3])) @@ -79,7 +79,7 @@ def test_rotation_only_90_degrees_z(self): # Check translation is zero assert np.allclose(T[:3, 3], [0, 0, 0]) - def test_translation_and_rotation(self): + def test_translation_and_rotation(self) -> None: quat = R.from_euler("xyz", [np.pi / 4, np.pi / 6, np.pi / 3]).as_quat() pose = Pose(Vector3(5, -3, 2), Quaternion(quat[0], quat[1], quat[2], quat[3])) T = transform_utils.pose_to_matrix(pose) @@ -94,7 +94,7 @@ def test_translation_and_rotation(self): # Check bottom row assert np.allclose(T[3, :], [0, 0, 0, 1]) - def test_zero_norm_quaternion(self): + def test_zero_norm_quaternion(self) -> None: # Test handling of zero norm quaternion pose = Pose(Vector3(1, 2, 3), Quaternion(0, 0, 0, 0)) T = transform_utils.pose_to_matrix(pose) @@ -106,7 +106,7 @@ def test_zero_norm_quaternion(self): class TestMatrixToPose: - def test_identity_matrix(self): + def test_identity_matrix(self) -> None: T = np.eye(4) pose = transform_utils.matrix_to_pose(T) assert pose.position.x == 0 @@ -117,7 +117,7 @@ def test_identity_matrix(self): assert np.isclose(pose.orientation.y, 0) assert np.isclose(pose.orientation.z, 0) - def test_translation_only(self): + def test_translation_only(self) -> None: T = np.eye(4) T[:3, 3] = [1, 2, 3] pose = transform_utils.matrix_to_pose(T) @@ -126,7 +126,7 @@ def test_translation_only(self): assert pose.position.z == 3 assert np.isclose(pose.orientation.w, 1) - def test_rotation_only(self): + def test_rotation_only(self) -> None: T = np.eye(4) T[:3, :3] = R.from_euler("z", np.pi / 2).as_matrix() pose = transform_utils.matrix_to_pose(T) @@ -141,7 +141,7 @@ def test_rotation_only(self): recovered_rot = R.from_quat(quat).as_matrix() assert np.allclose(recovered_rot, T[:3, :3]) - def test_round_trip_conversion(self): + def test_round_trip_conversion(self) -> None: # Test that pose -> matrix -> pose gives same result # Use a properly normalized quaternion quat = R.from_euler("xyz", [0.1, 0.2, 0.3]).as_quat() @@ -161,7 +161,7 @@ def test_round_trip_conversion(self): class TestApplyTransform: - def test_identity_transform(self): + def test_identity_transform(self) -> None: pose = Pose(Vector3(1, 2, 3), Quaternion(0, 0, 0, 1)) T_identity = np.eye(4) result = transform_utils.apply_transform(pose, T_identity) @@ -170,7 +170,7 @@ def test_identity_transform(self): assert np.isclose(result.position.y, pose.position.y) assert np.isclose(result.position.z, pose.position.z) - def test_translation_transform(self): + def test_translation_transform(self) -> None: pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) T = np.eye(4) T[:3, 3] = [2, 3, 4] @@ -180,7 +180,7 @@ def test_translation_transform(self): assert np.isclose(result.position.y, 3) # 3 + 0 assert np.isclose(result.position.z, 4) # 4 + 0 - def test_rotation_transform(self): + def test_rotation_transform(self) -> None: pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) T = np.eye(4) T[:3, :3] = R.from_euler("z", np.pi / 2).as_matrix() # 90 degree rotation @@ -191,7 +191,7 @@ def test_rotation_transform(self): assert np.isclose(result.position.y, 1) assert np.isclose(result.position.z, 0) - def test_transform_with_transform_object(self): + def test_transform_with_transform_object(self) -> None: pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) pose.frame_id = "base" @@ -206,7 +206,7 @@ def test_transform_with_transform_object(self): assert np.isclose(result.position.y, 3) assert np.isclose(result.position.z, 4) - def test_transform_frame_mismatch_raises(self): + def test_transform_frame_mismatch_raises(self) -> None: pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) pose.frame_id = "base" @@ -221,14 +221,14 @@ def test_transform_frame_mismatch_raises(self): class TestOpticalToRobotFrame: - def test_identity_at_origin(self): + def test_identity_at_origin(self) -> None: pose = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) result = transform_utils.optical_to_robot_frame(pose) assert result.position.x == 0 assert result.position.y == 0 assert result.position.z == 0 - def test_position_transformation(self): + def test_position_transformation(self) -> None: # Optical: X=right(1), Y=down(0), Z=forward(0) pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) result = transform_utils.optical_to_robot_frame(pose) @@ -238,7 +238,7 @@ def test_position_transformation(self): assert np.isclose(result.position.y, -1) # Left = -Camera X assert np.isclose(result.position.z, 0) # Up = -Camera Y - def test_forward_position(self): + def test_forward_position(self) -> None: # Optical: X=right(0), Y=down(0), Z=forward(2) pose = Pose(Vector3(0, 0, 2), Quaternion(0, 0, 0, 1)) result = transform_utils.optical_to_robot_frame(pose) @@ -248,7 +248,7 @@ def test_forward_position(self): assert np.isclose(result.position.y, 0) assert np.isclose(result.position.z, 0) - def test_down_position(self): + def test_down_position(self) -> None: # Optical: X=right(0), Y=down(3), Z=forward(0) pose = Pose(Vector3(0, 3, 0), Quaternion(0, 0, 0, 1)) result = transform_utils.optical_to_robot_frame(pose) @@ -258,7 +258,7 @@ def test_down_position(self): assert np.isclose(result.position.y, 0) assert np.isclose(result.position.z, -3) - def test_round_trip_optical_robot(self): + def test_round_trip_optical_robot(self) -> None: original_pose = Pose(Vector3(1, 2, 3), Quaternion(0.1, 0.2, 0.3, 0.9165151389911680)) robot_pose = transform_utils.optical_to_robot_frame(original_pose) recovered_pose = transform_utils.robot_to_optical_frame(robot_pose) @@ -269,7 +269,7 @@ def test_round_trip_optical_robot(self): class TestRobotToOpticalFrame: - def test_position_transformation(self): + def test_position_transformation(self) -> None: # Robot: X=forward(1), Y=left(0), Z=up(0) pose = Pose(Vector3(1, 0, 0), Quaternion(0, 0, 0, 1)) result = transform_utils.robot_to_optical_frame(pose) @@ -279,7 +279,7 @@ def test_position_transformation(self): assert np.isclose(result.position.y, 0) assert np.isclose(result.position.z, 1) - def test_left_position(self): + def test_left_position(self) -> None: # Robot: X=forward(0), Y=left(2), Z=up(0) pose = Pose(Vector3(0, 2, 0), Quaternion(0, 0, 0, 1)) result = transform_utils.robot_to_optical_frame(pose) @@ -289,7 +289,7 @@ def test_left_position(self): assert np.isclose(result.position.y, 0) assert np.isclose(result.position.z, 0) - def test_up_position(self): + def test_up_position(self) -> None: # Robot: X=forward(0), Y=left(0), Z=up(3) pose = Pose(Vector3(0, 0, 3), Quaternion(0, 0, 0, 1)) result = transform_utils.robot_to_optical_frame(pose) @@ -301,31 +301,31 @@ def test_up_position(self): class TestYawTowardsPoint: - def test_yaw_from_origin(self): + def test_yaw_from_origin(self) -> None: # Point at (1, 0) from origin should have yaw = 0 position = Vector3(1, 0, 0) yaw = transform_utils.yaw_towards_point(position) assert np.isclose(yaw, 0) - def test_yaw_ninety_degrees(self): + def test_yaw_ninety_degrees(self) -> None: # Point at (0, 1) from origin should have yaw = pi/2 position = Vector3(0, 1, 0) yaw = transform_utils.yaw_towards_point(position) assert np.isclose(yaw, np.pi / 2) - def test_yaw_negative_ninety_degrees(self): + def test_yaw_negative_ninety_degrees(self) -> None: # Point at (0, -1) from origin should have yaw = -pi/2 position = Vector3(0, -1, 0) yaw = transform_utils.yaw_towards_point(position) assert np.isclose(yaw, -np.pi / 2) - def test_yaw_forty_five_degrees(self): + def test_yaw_forty_five_degrees(self) -> None: # Point at (1, 1) from origin should have yaw = pi/4 position = Vector3(1, 1, 0) yaw = transform_utils.yaw_towards_point(position) assert np.isclose(yaw, np.pi / 4) - def test_yaw_with_custom_target(self): + def test_yaw_with_custom_target(self) -> None: # Point at (3, 2) from target (1, 1) position = Vector3(3, 2, 0) target = Vector3(1, 1, 0) @@ -339,13 +339,13 @@ def test_yaw_with_custom_target(self): class TestCreateTransformFrom6DOF: - def test_identity_transform(self): + def test_identity_transform(self) -> None: trans = Vector3(0, 0, 0) euler = Vector3(0, 0, 0) T = transform_utils.create_transform_from_6dof(trans, euler) assert np.allclose(T, np.eye(4)) - def test_translation_only(self): + def test_translation_only(self) -> None: trans = Vector3(1, 2, 3) euler = Vector3(0, 0, 0) T = transform_utils.create_transform_from_6dof(trans, euler) @@ -354,7 +354,7 @@ def test_translation_only(self): expected[:3, 3] = [1, 2, 3] assert np.allclose(T, expected) - def test_rotation_only(self): + def test_rotation_only(self) -> None: trans = Vector3(0, 0, 0) euler = Vector3(np.pi / 4, np.pi / 6, np.pi / 3) T = transform_utils.create_transform_from_6dof(trans, euler) @@ -364,7 +364,7 @@ def test_rotation_only(self): assert np.allclose(T[:3, 3], [0, 0, 0]) assert np.allclose(T[3, :], [0, 0, 0, 1]) - def test_translation_and_rotation(self): + def test_translation_and_rotation(self) -> None: trans = Vector3(5, -3, 2) euler = Vector3(0.1, 0.2, 0.3) T = transform_utils.create_transform_from_6dof(trans, euler) @@ -373,7 +373,7 @@ def test_translation_and_rotation(self): assert np.allclose(T[:3, :3], expected_rot) assert np.allclose(T[:3, 3], [5, -3, 2]) - def test_small_angles_threshold(self): + def test_small_angles_threshold(self) -> None: trans = Vector3(1, 2, 3) euler = Vector3(1e-7, 1e-8, 1e-9) # Very small angles T = transform_utils.create_transform_from_6dof(trans, euler) @@ -385,12 +385,12 @@ def test_small_angles_threshold(self): class TestInvertTransform: - def test_identity_inverse(self): + def test_identity_inverse(self) -> None: T = np.eye(4) T_inv = transform_utils.invert_transform(T) assert np.allclose(T_inv, np.eye(4)) - def test_translation_inverse(self): + def test_translation_inverse(self) -> None: T = np.eye(4) T[:3, 3] = [1, 2, 3] T_inv = transform_utils.invert_transform(T) @@ -400,7 +400,7 @@ def test_translation_inverse(self): expected[:3, 3] = [-1, -2, -3] assert np.allclose(T_inv, expected) - def test_rotation_inverse(self): + def test_rotation_inverse(self) -> None: T = np.eye(4) T[:3, :3] = R.from_euler("z", np.pi / 2).as_matrix() T_inv = transform_utils.invert_transform(T) @@ -410,7 +410,7 @@ def test_rotation_inverse(self): expected[:3, :3] = R.from_euler("z", -np.pi / 2).as_matrix() assert np.allclose(T_inv, expected) - def test_general_transform_inverse(self): + def test_general_transform_inverse(self) -> None: T = np.eye(4) T[:3, :3] = R.from_euler("xyz", [0.1, 0.2, 0.3]).as_matrix() T[:3, 3] = [1, 2, 3] @@ -427,17 +427,17 @@ def test_general_transform_inverse(self): class TestComposeTransforms: - def test_no_transforms(self): + def test_no_transforms(self) -> None: result = transform_utils.compose_transforms() assert np.allclose(result, np.eye(4)) - def test_single_transform(self): + def test_single_transform(self) -> None: T = np.eye(4) T[:3, 3] = [1, 2, 3] result = transform_utils.compose_transforms(T) assert np.allclose(result, T) - def test_two_translations(self): + def test_two_translations(self) -> None: T1 = np.eye(4) T1[:3, 3] = [1, 0, 0] @@ -450,7 +450,7 @@ def test_two_translations(self): expected[:3, 3] = [1, 2, 0] assert np.allclose(result, expected) - def test_three_transforms(self): + def test_three_transforms(self) -> None: T1 = np.eye(4) T1[:3, 3] = [1, 0, 0] @@ -466,7 +466,7 @@ def test_three_transforms(self): class TestEulerToQuaternion: - def test_zero_euler(self): + def test_zero_euler(self) -> None: euler = Vector3(0, 0, 0) quat = transform_utils.euler_to_quaternion(euler) assert np.isclose(quat.w, 1) @@ -474,7 +474,7 @@ def test_zero_euler(self): assert np.isclose(quat.y, 0) assert np.isclose(quat.z, 0) - def test_roll_only(self): + def test_roll_only(self) -> None: euler = Vector3(np.pi / 2, 0, 0) quat = transform_utils.euler_to_quaternion(euler) @@ -484,7 +484,7 @@ def test_roll_only(self): assert np.isclose(recovered[1], 0) assert np.isclose(recovered[2], 0) - def test_pitch_only(self): + def test_pitch_only(self) -> None: euler = Vector3(0, np.pi / 3, 0) quat = transform_utils.euler_to_quaternion(euler) @@ -493,7 +493,7 @@ def test_pitch_only(self): assert np.isclose(recovered[1], np.pi / 3) assert np.isclose(recovered[2], 0) - def test_yaw_only(self): + def test_yaw_only(self) -> None: euler = Vector3(0, 0, np.pi / 4) quat = transform_utils.euler_to_quaternion(euler) @@ -502,7 +502,7 @@ def test_yaw_only(self): assert np.isclose(recovered[1], 0) assert np.isclose(recovered[2], np.pi / 4) - def test_degrees_mode(self): + def test_degrees_mode(self) -> None: euler = Vector3(45, 30, 60) # degrees quat = transform_utils.euler_to_quaternion(euler, degrees=True) @@ -513,14 +513,14 @@ def test_degrees_mode(self): class TestQuaternionToEuler: - def test_identity_quaternion(self): + def test_identity_quaternion(self) -> None: quat = Quaternion(0, 0, 0, 1) euler = transform_utils.quaternion_to_euler(quat) assert np.isclose(euler.x, 0) assert np.isclose(euler.y, 0) assert np.isclose(euler.z, 0) - def test_90_degree_yaw(self): + def test_90_degree_yaw(self) -> None: # Create quaternion for 90 degree yaw rotation r = R.from_euler("z", np.pi / 2) q = r.as_quat() @@ -531,7 +531,7 @@ def test_90_degree_yaw(self): assert np.isclose(euler.y, 0) assert np.isclose(euler.z, np.pi / 2) - def test_round_trip_euler_quaternion(self): + def test_round_trip_euler_quaternion(self) -> None: original_euler = Vector3(0.3, 0.5, 0.7) quat = transform_utils.euler_to_quaternion(original_euler) recovered_euler = transform_utils.quaternion_to_euler(quat) @@ -540,7 +540,7 @@ def test_round_trip_euler_quaternion(self): assert np.isclose(recovered_euler.y, original_euler.y, atol=1e-10) assert np.isclose(recovered_euler.z, original_euler.z, atol=1e-10) - def test_degrees_mode(self): + def test_degrees_mode(self) -> None: # Create quaternion for 45 degree yaw rotation r = R.from_euler("z", 45, degrees=True) q = r.as_quat() @@ -551,7 +551,7 @@ def test_degrees_mode(self): assert np.isclose(euler.y, 0) assert np.isclose(euler.z, 45) - def test_angle_normalization(self): + def test_angle_normalization(self) -> None: # Test that angles are normalized to [-pi, pi] r = R.from_euler("xyz", [3 * np.pi, -3 * np.pi, 2 * np.pi]) q = r.as_quat() @@ -564,43 +564,43 @@ def test_angle_normalization(self): class TestGetDistance: - def test_same_pose(self): + def test_same_pose(self) -> None: pose1 = Pose(Vector3(1, 2, 3), Quaternion(0, 0, 0, 1)) pose2 = Pose(Vector3(1, 2, 3), Quaternion(0.1, 0.2, 0.3, 0.9)) distance = transform_utils.get_distance(pose1, pose2) assert np.isclose(distance, 0) - def test_vector_distance(self): + def test_vector_distance(self) -> None: pose1 = Vector3(1, 2, 3) pose2 = Vector3(4, 5, 6) distance = transform_utils.get_distance(pose1, pose2) assert np.isclose(distance, np.sqrt(3**2 + 3**2 + 3**2)) - def test_distance_x_axis(self): + def test_distance_x_axis(self) -> None: pose1 = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) pose2 = Pose(Vector3(5, 0, 0), Quaternion(0, 0, 0, 1)) distance = transform_utils.get_distance(pose1, pose2) assert np.isclose(distance, 5) - def test_distance_y_axis(self): + def test_distance_y_axis(self) -> None: pose1 = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) pose2 = Pose(Vector3(0, 3, 0), Quaternion(0, 0, 0, 1)) distance = transform_utils.get_distance(pose1, pose2) assert np.isclose(distance, 3) - def test_distance_z_axis(self): + def test_distance_z_axis(self) -> None: pose1 = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) pose2 = Pose(Vector3(0, 0, 4), Quaternion(0, 0, 0, 1)) distance = transform_utils.get_distance(pose1, pose2) assert np.isclose(distance, 4) - def test_3d_distance(self): + def test_3d_distance(self) -> None: pose1 = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1)) pose2 = Pose(Vector3(3, 4, 0), Quaternion(0, 0, 0, 1)) distance = transform_utils.get_distance(pose1, pose2) assert np.isclose(distance, 5) # 3-4-5 triangle - def test_negative_coordinates(self): + def test_negative_coordinates(self) -> None: pose1 = Pose(Vector3(-1, -2, -3), Quaternion(0, 0, 0, 1)) pose2 = Pose(Vector3(1, 2, 3), Quaternion(0, 0, 0, 1)) distance = transform_utils.get_distance(pose1, pose2) @@ -609,7 +609,7 @@ def test_negative_coordinates(self): class TestRetractDistance: - def test_retract_along_negative_z(self): + def test_retract_along_negative_z(self) -> None: # Default case: gripper approaches along -z axis # Positive distance moves away from the surface (opposite to approach direction) target_pose = Pose(Vector3(0, 0, 1), Quaternion(0, 0, 0, 1)) @@ -627,7 +627,7 @@ def test_retract_along_negative_z(self): assert retracted.orientation.z == target_pose.orientation.z assert retracted.orientation.w == target_pose.orientation.w - def test_retract_with_rotation(self): + def test_retract_with_rotation(self) -> None: # Test with a rotated pose (90 degrees around x-axis) r = R.from_euler("x", np.pi / 2) q = r.as_quat() @@ -640,7 +640,7 @@ def test_retract_with_rotation(self): assert np.isclose(retracted.position.y, 0.5) # Move along +y assert np.isclose(retracted.position.z, 1) - def test_retract_negative_distance(self): + def test_retract_negative_distance(self) -> None: # Negative distance should move forward (toward the approach direction) target_pose = Pose(Vector3(0, 0, 1), Quaternion(0, 0, 0, 1)) retracted = transform_utils.offset_distance(target_pose, -0.3) @@ -650,7 +650,7 @@ def test_retract_negative_distance(self): assert np.isclose(retracted.position.y, 0) assert np.isclose(retracted.position.z, 1.3) # 1 + (-0.3) * (-1) = 1.3 - def test_retract_arbitrary_pose(self): + def test_retract_arbitrary_pose(self) -> None: # Test with arbitrary position and rotation r = R.from_euler("xyz", [0.1, 0.2, 0.3]) q = r.as_quat() diff --git a/dimos/utils/testing.py b/dimos/utils/testing.py index c5984cf3fd..5e3725bc81 100644 --- a/dimos/utils/testing.py +++ b/dimos/utils/testing.py @@ -11,22 +11,21 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from collections.abc import Callable, Iterator import functools import glob -import logging import os +from pathlib import Path import pickle import re -import shutil import time -from pathlib import Path -from typing import Any, Callable, Generic, Iterator, Optional, Tuple, TypeVar, Union +from typing import Any, Generic, TypeVar from reactivex import ( from_iterable, interval, + operators as ops, ) -from reactivex import operators as ops from reactivex.observable import Observable from reactivex.scheduler import TimeoutScheduler @@ -44,16 +43,16 @@ class SensorReplay(Generic[T]): For example: lambda data: LidarMessage.from_msg(data) """ - def __init__(self, name: str, autocast: Optional[Callable[[Any], T]] = None): + def __init__(self, name: str, autocast: Callable[[Any], T] | None = None) -> None: self.root_dir = get_data(name) self.autocast = autocast - def load(self, *names: Union[int, str]) -> Union[T, Any, list[T], list[Any]]: + def load(self, *names: int | str) -> T | Any | list[T] | list[Any]: if len(names) == 1: return self.load_one(names[0]) return list(map(lambda name: self.load_one(name), names)) - def load_one(self, name: Union[int, str, Path]) -> Union[T, Any]: + def load_one(self, name: int | str | Path) -> T | Any: if isinstance(name, int): full_path = self.root_dir / f"/{name:03d}.pickle" elif isinstance(name, Path): @@ -67,7 +66,7 @@ def load_one(self, name: Union[int, str, Path]) -> Union[T, Any]: return self.autocast(data) return data - def first(self) -> Optional[Union[T, Any]]: + def first(self) -> T | Any | None: try: return next(self.iterate()) except StopIteration: @@ -86,16 +85,14 @@ def extract_number(filepath): key=extract_number, ) - def iterate(self, loop: bool = False) -> Iterator[Union[T, Any]]: + def iterate(self, loop: bool = False) -> Iterator[T | Any]: while True: for file_path in self.files: yield self.load_one(Path(file_path)) if not loop: break - def stream( - self, rate_hz: Optional[float] = None, loop: bool = False - ) -> Observable[Union[T, Any]]: + def stream(self, rate_hz: float | None = None, loop: bool = False) -> Observable[T | Any]: if rate_hz is None: return from_iterable(self.iterate(loop=loop)) @@ -117,7 +114,7 @@ class SensorStorage(Generic[T]): autocast: Optional function that takes data and returns a processed result before storage. """ - def __init__(self, name: str, autocast: Optional[Callable[[T], Any]] = None): + def __init__(self, name: str, autocast: Callable[[T], Any] | None = None) -> None: self.name = name self.autocast = autocast self.cnt = 0 @@ -137,11 +134,11 @@ def __init__(self, name: str, autocast: Optional[Callable[[T], Any]] = None): # Create the directory self.root_dir.mkdir(parents=True, exist_ok=True) - def consume_stream(self, observable: Observable[Union[T, Any]]) -> None: + def consume_stream(self, observable: Observable[T | Any]) -> None: """Consume an observable stream of sensor data without saving.""" return observable.subscribe(self.save_one) - def save_stream(self, observable: Observable[Union[T, Any]]) -> Observable[int]: + def save_stream(self, observable: Observable[T | Any]) -> Observable[int]: """Save an observable stream of sensor data to pickle files.""" return observable.pipe(ops.map(lambda frame: self.save_one(frame))) @@ -180,7 +177,7 @@ def save_one(self, frame: T) -> int: class TimedSensorReplay(SensorReplay[T]): - def load_one(self, name: Union[int, str, Path]) -> Union[T, Any]: + def load_one(self, name: int | str | Path) -> T | Any: if isinstance(name, int): full_path = self.root_dir / f"/{name:03d}.pickle" elif isinstance(name, Path): @@ -194,9 +191,7 @@ def load_one(self, name: Union[int, str, Path]) -> Union[T, Any]: return (data[0], self.autocast(data[1])) return data - def find_closest( - self, timestamp: float, tolerance: Optional[float] = None - ) -> Optional[Union[T, Any]]: + def find_closest(self, timestamp: float, tolerance: float | None = None) -> T | Any | None: """Find the frame closest to the given timestamp. Args: @@ -226,8 +221,8 @@ def find_closest( return closest_data def find_closest_seek( - self, relative_seconds: float, tolerance: Optional[float] = None - ) -> Optional[Union[T, Any]]: + self, relative_seconds: float, tolerance: float | None = None + ) -> T | Any | None: """Find the frame closest to a time relative to the start. Args: @@ -246,7 +241,7 @@ def find_closest_seek( target_timestamp = first_ts + relative_seconds return self.find_closest(target_timestamp, tolerance) - def first_timestamp(self) -> Optional[float]: + def first_timestamp(self) -> float | None: """Get the timestamp of the first item in the dataset. Returns: @@ -258,16 +253,16 @@ def first_timestamp(self) -> Optional[float]: except StopIteration: return None - def iterate(self, loop: bool = False) -> Iterator[Union[T, Any]]: + def iterate(self, loop: bool = False) -> Iterator[T | Any]: return (x[1] for x in super().iterate(loop=loop)) def iterate_ts( self, - seek: Optional[float] = None, - duration: Optional[float] = None, - from_timestamp: Optional[float] = None, + seek: float | None = None, + duration: float | None = None, + from_timestamp: float | None = None, loop: bool = False, - ) -> Iterator[Union[Tuple[float, T], Any]]: + ) -> Iterator[tuple[float, T] | Any]: first_ts = None if (seek is not None) or (duration is not None): first_ts = self.first_timestamp() @@ -292,12 +287,12 @@ def iterate_ts( def stream( self, - speed=1.0, - seek: Optional[float] = None, - duration: Optional[float] = None, - from_timestamp: Optional[float] = None, + speed: float = 1.0, + seek: float | None = None, + duration: float | None = None, + from_timestamp: float | None = None, loop: bool = False, - ) -> Observable[Union[T, Any]]: + ) -> Observable[T | Any]: def _subscribe(observer, scheduler=None): from reactivex.disposable import CompositeDisposable, Disposable @@ -330,7 +325,7 @@ def _subscribe(observer, scheduler=None): observer.on_completed() return disp - def schedule_emission(message): + def schedule_emission(message) -> None: nonlocal next_message, is_disposed if is_disposed: @@ -348,7 +343,7 @@ def schedule_emission(message): target_time = start_local_time + (ts - start_replay_time) / speed delay = max(0.0, target_time - time.time()) - def emit(): + def emit() -> None: if is_disposed: return observer.on_next(data) @@ -365,7 +360,7 @@ def emit(): schedule_emission(next_message) # Create a custom disposable that properly cleans up - def dispose(): + def dispose() -> None: nonlocal is_disposed is_disposed = True disp.dispose() diff --git a/dimos/utils/transform_utils.py b/dimos/utils/transform_utils.py index 5b49d285cc..21421b4390 100644 --- a/dimos/utils/transform_utils.py +++ b/dimos/utils/transform_utils.py @@ -12,10 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. + import numpy as np -from typing import Tuple from scipy.spatial.transform import Rotation as R -from dimos.msgs.geometry_msgs import Pose, Vector3, Quaternion, Transform + +from dimos.msgs.geometry_msgs import Pose, Quaternion, Transform, Vector3 def normalize_angle(angle: float) -> float: diff --git a/dimos/web/dimos_interface/api/server.py b/dimos/web/dimos_interface/api/server.py index bcc590ab46..4f9979c085 100644 --- a/dimos/web/dimos_interface/api/server.py +++ b/dimos/web/dimos_interface/api/server.py @@ -25,30 +25,31 @@ # browser like Safari. # Fast Api & Uvicorn -import cv2 -from dimos.web.edge_io import EdgeIO -from fastapi import FastAPI, Request, Form, HTTPException, UploadFile, File -from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse -from sse_starlette.sse import EventSourceResponse -from fastapi.templating import Jinja2Templates -import uvicorn -from threading import Lock -from pathlib import Path -from queue import Queue, Empty import asyncio -from reactivex.disposable import SingleAssignmentDisposable -from reactivex import operators as ops -import reactivex as rx -from fastapi.middleware.cors import CORSMiddleware - # For audio processing import io +from pathlib import Path +from queue import Empty, Queue +from threading import Lock import time -import numpy as np + +import cv2 +from fastapi import FastAPI, File, Form, HTTPException, Request, UploadFile +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse +from fastapi.templating import Jinja2Templates import ffmpeg +import numpy as np +import reactivex as rx +from reactivex import operators as ops +from reactivex.disposable import SingleAssignmentDisposable import soundfile as sf +from sse_starlette.sse import EventSourceResponse +import uvicorn + from dimos.stream.audio.base import AudioEvent +from dimos.web.edge_io import EdgeIO # TODO: Resolve threading, start/stop stream functionality. @@ -56,14 +57,14 @@ class FastAPIServer(EdgeIO): def __init__( self, - dev_name="FastAPI Server", - edge_type="Bidirectional", - host="0.0.0.0", - port=5555, + dev_name: str = "FastAPI Server", + edge_type: str = "Bidirectional", + host: str = "0.0.0.0", + port: int = 5555, text_streams=None, audio_subject=None, **streams, - ): + ) -> None: print("Starting FastAPIServer initialization...") # Debug print super().__init__(dev_name, edge_type) self.app = FastAPI() @@ -235,7 +236,7 @@ def _decode_audio(raw: bytes) -> tuple[np.ndarray, int]: print(f"ffmpeg decoding failed: {exc}") return None, None - def setup_routes(self): + def setup_routes(self) -> None: """Set up FastAPI routes.""" @self.app.get("/streams") @@ -275,7 +276,7 @@ async def submit_query(query: str = Form(...)): # Ensure we always return valid JSON even on error return JSONResponse( status_code=500, - content={"success": False, "message": f"Server error: {str(e)}"}, + content={"success": False, "message": f"Server error: {e!s}"}, ) @self.app.post("/upload_audio") @@ -335,10 +336,10 @@ async def unitree_command(request: Request): return JSONResponse(response) except Exception as e: - print(f"Error processing command: {str(e)}") + print(f"Error processing command: {e!s}") return JSONResponse( status_code=500, - content={"success": False, "message": f"Error processing command: {str(e)}"}, + content={"success": False, "message": f"Error processing command: {e!s}"}, ) @self.app.get("/text_stream/{key}") @@ -350,7 +351,7 @@ async def text_stream(key: str): for key in self.streams: self.app.get(f"/video_feed/{key}")(self.create_video_feed_route(key)) - def run(self): + def run(self) -> None: """Run the FastAPI server.""" uvicorn.run( self.app, host=self.host, port=self.port diff --git a/dimos/web/edge_io.py b/dimos/web/edge_io.py index 8511df2ce3..ad15614623 100644 --- a/dimos/web/edge_io.py +++ b/dimos/web/edge_io.py @@ -16,11 +16,11 @@ class EdgeIO: - def __init__(self, dev_name: str = "NA", edge_type: str = "Base"): + def __init__(self, dev_name: str = "NA", edge_type: str = "Base") -> None: self.dev_name = dev_name self.edge_type = edge_type self.disposables = CompositeDisposable() - def dispose_all(self): + def dispose_all(self) -> None: """Disposes of all active subscriptions managed by this agent.""" self.disposables.dispose() diff --git a/dimos/web/fastapi_server.py b/dimos/web/fastapi_server.py index 7dcd0f6d73..6c8a85344a 100644 --- a/dimos/web/fastapi_server.py +++ b/dimos/web/fastapi_server.py @@ -23,21 +23,22 @@ # browser like Safari. # Fast Api & Uvicorn +import asyncio +from pathlib import Path +from queue import Empty, Queue +from threading import Lock + import cv2 -from dimos.web.edge_io import EdgeIO -from fastapi import FastAPI, Request, Form, HTTPException -from fastapi.responses import HTMLResponse, StreamingResponse, JSONResponse -from sse_starlette.sse import EventSourceResponse +from fastapi import FastAPI, Form, HTTPException, Request +from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse from fastapi.templating import Jinja2Templates +import reactivex as rx +from reactivex import operators as ops +from reactivex.disposable import SingleAssignmentDisposable +from sse_starlette.sse import EventSourceResponse import uvicorn -from threading import Lock -from pathlib import Path -from queue import Queue, Empty -import asyncio -from reactivex.disposable import SingleAssignmentDisposable -from reactivex import operators as ops -import reactivex as rx +from dimos.web.edge_io import EdgeIO # TODO: Resolve threading, start/stop stream functionality. @@ -45,13 +46,13 @@ class FastAPIServer(EdgeIO): def __init__( self, - dev_name="FastAPI Server", - edge_type="Bidirectional", - host="0.0.0.0", - port=5555, + dev_name: str = "FastAPI Server", + edge_type: str = "Bidirectional", + host: str = "0.0.0.0", + port: int = 5555, text_streams=None, **streams, - ): + ) -> None: super().__init__(dev_name, edge_type) self.app = FastAPI() self.port = port @@ -176,7 +177,7 @@ async def text_stream_generator(self, key): finally: self.text_clients.remove(client_id) - def setup_routes(self): + def setup_routes(self) -> None: """Set up FastAPI routes.""" @self.app.get("/", response_class=HTMLResponse) @@ -205,7 +206,7 @@ async def submit_query(query: str = Form(...)): # Ensure we always return valid JSON even on error return JSONResponse( status_code=500, - content={"success": False, "message": f"Server error: {str(e)}"}, + content={"success": False, "message": f"Server error: {e!s}"}, ) @self.app.get("/text_stream/{key}") @@ -217,7 +218,7 @@ async def text_stream(key: str): for key in self.streams: self.app.get(f"/video_feed/{key}")(self.create_video_feed_route(key)) - def run(self): + def run(self) -> None: """Run the FastAPI server.""" uvicorn.run( self.app, host=self.host, port=self.port diff --git a/dimos/web/flask_server.py b/dimos/web/flask_server.py index 01d79f63cd..b0cf6fc143 100644 --- a/dimos/web/flask_server.py +++ b/dimos/web/flask_server.py @@ -12,17 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from flask import Flask, Response, render_template +from queue import Queue + import cv2 +from flask import Flask, Response, render_template from reactivex import operators as ops from reactivex.disposable import SingleAssignmentDisposable -from queue import Queue from dimos.web.edge_io import EdgeIO class FlaskServer(EdgeIO): - def __init__(self, dev_name="Flask Server", edge_type="Bidirectional", port=5555, **streams): + def __init__( + self, + dev_name: str = "Flask Server", + edge_type: str = "Bidirectional", + port: int = 5555, + **streams, + ) -> None: super().__init__(dev_name, edge_type) self.app = Flask(__name__) self.port = port @@ -44,7 +51,7 @@ def process_frame_flask(self, frame): _, buffer = cv2.imencode(".jpg", frame) return buffer.tobytes() - def setup_routes(self): + def setup_routes(self) -> None: @self.app.route("/") def index(): stream_keys = list(self.streams.keys()) # Get the keys from the streams dictionary @@ -90,6 +97,6 @@ def response_generator(): f"/video_feed/{key}", endpoint, view_func=make_response_generator(key) ) - def run(self, host="0.0.0.0", port=5555, threaded=True): + def run(self, host: str = "0.0.0.0", port: int = 5555, threaded: bool = True) -> None: self.port = port self.app.run(host=host, port=self.port, debug=False, threaded=threaded) diff --git a/dimos/web/robot_web_interface.py b/dimos/web/robot_web_interface.py index 33847c0056..0dc7636ac9 100644 --- a/dimos/web/robot_web_interface.py +++ b/dimos/web/robot_web_interface.py @@ -23,7 +23,7 @@ class RobotWebInterface(FastAPIServer): """Wrapper class for the dimos-interface FastAPI server.""" - def __init__(self, port=5555, text_streams=None, audio_subject=None, **streams): + def __init__(self, port: int = 5555, text_streams=None, audio_subject=None, **streams) -> None: super().__init__( dev_name="Robot Web Interface", edge_type="Bidirectional", diff --git a/dimos/web/websocket_vis/costmap_viz.py b/dimos/web/websocket_vis/costmap_viz.py index a1c6944d2b..ec2088b3b8 100644 --- a/dimos/web/websocket_vis/costmap_viz.py +++ b/dimos/web/websocket_vis/costmap_viz.py @@ -18,19 +18,19 @@ """ import numpy as np -from typing import Optional + from dimos.msgs.nav_msgs import OccupancyGrid class CostmapViz: """A wrapper around OccupancyGrid for visualization compatibility.""" - def __init__(self, occupancy_grid: Optional[OccupancyGrid] = None): + def __init__(self, occupancy_grid: OccupancyGrid | None = None) -> None: """Initialize from an OccupancyGrid.""" self.occupancy_grid = occupancy_grid @property - def data(self) -> Optional[np.ndarray]: + def data(self) -> np.ndarray | None: """Get the costmap data as a numpy array.""" if self.occupancy_grid: return self.occupancy_grid.grid diff --git a/dimos/web/websocket_vis/optimized_costmap.py b/dimos/web/websocket_vis/optimized_costmap.py index 30a226c66f..03307ff2c0 100644 --- a/dimos/web/websocket_vis/optimized_costmap.py +++ b/dimos/web/websocket_vis/optimized_costmap.py @@ -19,22 +19,23 @@ import base64 import hashlib import time -from typing import Dict, Any, Optional, Tuple -import numpy as np +from typing import Any import zlib +import numpy as np + class OptimizedCostmapEncoder: """Handles optimized encoding of costmaps with delta compression.""" - def __init__(self, chunk_size: int = 64): + def __init__(self, chunk_size: int = 64) -> None: self.chunk_size = chunk_size - self.last_full_grid: Optional[np.ndarray] = None + self.last_full_grid: np.ndarray | None = None self.last_full_sent_time: float = 0 # Track when last full update was sent - self.chunk_hashes: Dict[Tuple[int, int], str] = {} + self.chunk_hashes: dict[tuple[int, int], str] = {} self.full_update_interval = 3.0 # Send full update every 3 seconds - def encode_costmap(self, grid: np.ndarray, force_full: bool = False) -> Dict[str, Any]: + def encode_costmap(self, grid: np.ndarray, force_full: bool = False) -> dict[str, Any]: """Encode a costmap grid with optimizations. Args: @@ -59,7 +60,7 @@ def encode_costmap(self, grid: np.ndarray, force_full: bool = False) -> Dict[str else: return self._encode_delta(grid, current_time) - def _encode_full(self, grid: np.ndarray, current_time: float) -> Dict[str, Any]: + def _encode_full(self, grid: np.ndarray, current_time: float) -> dict[str, Any]: height, width = grid.shape # Convert to uint8 for better compression (costmap values are -1 to 100) @@ -88,7 +89,7 @@ def _encode_full(self, grid: np.ndarray, current_time: float) -> Dict[str, Any]: "data": encoded, } - def _encode_delta(self, grid: np.ndarray, current_time: float) -> Dict[str, Any]: + def _encode_delta(self, grid: np.ndarray, current_time: float) -> dict[str, Any]: height, width = grid.shape changed_chunks = [] @@ -145,7 +146,7 @@ def _encode_delta(self, grid: np.ndarray, current_time: float) -> Dict[str, Any] "chunks": changed_chunks, } - def _update_chunk_hashes(self, grid: np.ndarray): + def _update_chunk_hashes(self, grid: np.ndarray) -> None: """Update all chunk hashes for the grid.""" self.chunk_hashes.clear() height, width = grid.shape diff --git a/dimos/web/websocket_vis/path_history.py b/dimos/web/websocket_vis/path_history.py index 2bfa66a956..f60031bc51 100644 --- a/dimos/web/websocket_vis/path_history.py +++ b/dimos/web/websocket_vis/path_history.py @@ -17,16 +17,15 @@ This is a minimal implementation to support websocket visualization. """ -from typing import List, Optional, Union from dimos.msgs.geometry_msgs import Vector3 class PathHistory: """A simple container for storing a history of positions for visualization.""" - def __init__(self, points: Optional[List[Union[Vector3, tuple, list]]] = None): + def __init__(self, points: list[Vector3 | tuple | list] | None = None) -> None: """Initialize with optional list of points.""" - self.points: List[Vector3] = [] + self.points: list[Vector3] = [] if points: for p in points: if isinstance(p, Vector3): @@ -34,7 +33,7 @@ def __init__(self, points: Optional[List[Union[Vector3, tuple, list]]] = None): else: self.points.append(Vector3(*p)) - def ipush(self, point: Union[Vector3, tuple, list]) -> "PathHistory": + def ipush(self, point: Vector3 | tuple | list) -> "PathHistory": """Add a point to the history (in-place) and return self.""" if isinstance(point, Vector3): self.points.append(point) @@ -48,7 +47,7 @@ def iclip_tail(self, max_length: int) -> "PathHistory": self.points = self.points[-max_length:] return self - def last(self) -> Optional[Vector3]: + def last(self) -> Vector3 | None: """Return the last point in the history, or None if empty.""" return self.points[-1] if self.points else None diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index af1cb3bdd5..91e0428f33 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -19,19 +19,17 @@ """ import asyncio -import base64 import threading import time -from typing import Any, Dict, Optional +from typing import Any -import numpy as np -import socketio -import uvicorn from dimos_lcm.std_msgs import Bool from reactivex.disposable import Disposable +import socketio from starlette.applications import Starlette from starlette.responses import HTMLResponse from starlette.routing import Route +import uvicorn from dimos.core import In, Module, Out, rpc from dimos.mapping.types import LatLon @@ -77,7 +75,7 @@ class WebsocketVisModule(Module): cmd_vel: Out[Twist] = None movecmd_stamped: Out[TwistStamped] = None - def __init__(self, port: int = 7779, **kwargs): + def __init__(self, port: int = 7779, **kwargs) -> None: """Initialize the WebSocket visualization module. Args: @@ -86,12 +84,12 @@ def __init__(self, port: int = 7779, **kwargs): super().__init__(**kwargs) self.port = port - self._uvicorn_server_thread: Optional[threading.Thread] = None - self.sio: Optional[socketio.AsyncServer] = None + self._uvicorn_server_thread: threading.Thread | None = None + self.sio: socketio.AsyncServer | None = None self.app = None self._broadcast_loop = None self._broadcast_thread = None - self._uvicorn_server: Optional[uvicorn.Server] = None + self._uvicorn_server: uvicorn.Server | None = None self.vis_state = {} self.state_lock = threading.Lock() @@ -115,7 +113,7 @@ def websocket_vis_loop() -> None: self._broadcast_thread.start() @rpc - def start(self): + def start(self) -> None: super().start() self._create_server() @@ -128,32 +126,32 @@ def start(self): try: unsub = self.odom.subscribe(self._on_robot_pose) self._disposables.add(Disposable(unsub)) - except Exception as e: + except Exception: ... try: unsub = self.gps_location.subscribe(self._on_gps_location) self._disposables.add(Disposable(unsub)) - except Exception as e: + except Exception: ... try: unsub = self.path.subscribe(self._on_path) self._disposables.add(Disposable(unsub)) - except Exception as e: + except Exception: ... unsub = self.global_costmap.subscribe(self._on_global_costmap) self._disposables.add(Disposable(unsub)) @rpc - def stop(self): + def stop(self) -> None: if self._uvicorn_server: self._uvicorn_server.should_exit = True if self.sio and self._broadcast_loop and not self._broadcast_loop.is_closed(): - async def _disconnect_all(): + async def _disconnect_all() -> None: await self.sio.disconnect() asyncio.run_coroutine_threadsafe(_disconnect_all(), self._broadcast_loop) @@ -175,7 +173,7 @@ def set_gps_travel_goal_points(self, points: list[LatLon]) -> None: self.vis_state["gps_travel_goal_points"] = json_points self._emit("gps_travel_goal_points", json_points) - def _create_server(self): + def _create_server(self) -> None: # Create SocketIO server self.sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*") @@ -189,7 +187,7 @@ async def serve_index(request): # Register SocketIO event handlers @self.sio.event - async def connect(sid, environ): + async def connect(sid, environ) -> None: with self.state_lock: current_state = dict(self.vis_state) @@ -199,7 +197,7 @@ async def connect(sid, environ): await self.sio.emit("full_state", current_state, room=sid) @self.sio.event - async def click(sid, position): + async def click(sid, position) -> None: goal = PoseStamped( position=(position[0], position[1], 0), orientation=(0, 0, 0, 1), # Default orientation @@ -209,22 +207,22 @@ async def click(sid, position): logger.info(f"Click goal published: ({goal.position.x:.2f}, {goal.position.y:.2f})") @self.sio.event - async def gps_goal(sid, goal): + async def gps_goal(sid, goal) -> None: logger.info(f"Set GPS goal: {goal}") self.gps_goal.publish(LatLon(lat=goal["lat"], lon=goal["lon"])) @self.sio.event - async def start_explore(sid): + async def start_explore(sid) -> None: logger.info("Starting exploration") self.explore_cmd.publish(Bool(data=True)) @self.sio.event - async def stop_explore(sid): + async def stop_explore(sid) -> None: logger.info("Stopping exploration") self.stop_explore_cmd.publish(Bool(data=True)) @self.sio.event - async def move_command(sid, data): + async def move_command(sid, data) -> None: # Publish Twist if transport is configured if self.cmd_vel and self.cmd_vel.transport: twist = Twist( @@ -257,28 +255,28 @@ def _run_uvicorn_server(self) -> None: self._uvicorn_server = uvicorn.Server(config) self._uvicorn_server.run() - def _on_robot_pose(self, msg: PoseStamped): + def _on_robot_pose(self, msg: PoseStamped) -> None: pose_data = {"type": "vector", "c": [msg.position.x, msg.position.y, msg.position.z]} self.vis_state["robot_pose"] = pose_data self._emit("robot_pose", pose_data) - def _on_gps_location(self, msg: LatLon): + def _on_gps_location(self, msg: LatLon) -> None: pose_data = {"lat": msg.lat, "lon": msg.lon} self.vis_state["gps_location"] = pose_data self._emit("gps_location", pose_data) - def _on_path(self, msg: Path): + def _on_path(self, msg: Path) -> None: points = [[pose.position.x, pose.position.y] for pose in msg.poses] path_data = {"type": "path", "points": points} self.vis_state["path"] = path_data self._emit("path", path_data) - def _on_global_costmap(self, msg: OccupancyGrid): + def _on_global_costmap(self, msg: OccupancyGrid) -> None: costmap_data = self._process_costmap(msg) self.vis_state["costmap"] = costmap_data self._emit("costmap", costmap_data) - def _process_costmap(self, costmap: OccupancyGrid) -> Dict[str, Any]: + def _process_costmap(self, costmap: OccupancyGrid) -> dict[str, Any]: """Convert OccupancyGrid to visualization format.""" costmap = costmap.inflate(0.1).gradient(max_distance=1.0) grid_data = self.costmap_encoder.encode_costmap(costmap.grid) @@ -294,7 +292,7 @@ def _process_costmap(self, costmap: OccupancyGrid) -> Dict[str, Any]: "origin_theta": 0, # Assuming no rotation for now } - def _emit(self, event: str, data: Any): + def _emit(self, event: str, data: Any) -> None: if self._broadcast_loop and not self._broadcast_loop.is_closed(): asyncio.run_coroutine_threadsafe(self.sio.emit(event, data), self._broadcast_loop) diff --git a/pyproject.toml b/pyproject.toml index 2d3804c1fc..e4fae8dc12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -224,6 +224,16 @@ exclude = [ "src" ] +[tool.ruff.lint] +extend-select = ["E", "W", "F", "B", "UP", "N", "I", "C90", "A", "RUF", "TCH"] +# TODO: All of these should be fixed, but it's easier commit autofixes first +ignore = ["A001", "A002", "A004", "B008", "B017", "B018", "B019", "B023", "B024", "B026", "B027", "B904", "C901", "E402", "E501", "E721", "E722", "E741", "F401", "F403", "F405", "F811", "F821", "F821", "F821", "N801", "N802", "N803", "N806", "N812", "N813", "N813", "N816", "N817", "N999", "RUF001", "RUF002", "RUF003", "RUF006", "RUF009", "RUF012", "RUF034", "RUF043", "RUF059", "TC010", "UP007", "UP035"] + +[tool.ruff.lint.isort] +known-first-party = ["dimos"] +combine-as-imports = true +force-sort-within-sections = true + [tool.mypy] # mypy doesn't understand plum @dispatch decorator # so we gave up on this check globally diff --git a/setup.py b/setup.py index 0a77274dca..15fa5aa750 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from setuptools import setup, find_packages +from setuptools import find_packages, setup setup( packages=find_packages(), diff --git a/tests/agent_manip_flow_fastapi_test.py b/tests/agent_manip_flow_fastapi_test.py index c7dec66f74..f8b6df4244 100644 --- a/tests/agent_manip_flow_fastapi_test.py +++ b/tests/agent_manip_flow_fastapi_test.py @@ -17,23 +17,19 @@ It handles video capture, frame processing, and exposes the processed video streams via HTTP endpoints. """ -import tests.test_header -import os - # ----- - # Standard library imports import multiprocessing +import os + from dotenv import load_dotenv # Third-party imports -from fastapi import FastAPI from reactivex import operators as ops from reactivex.disposable import CompositeDisposable -from reactivex.scheduler import ThreadPoolScheduler, CurrentThreadScheduler, ImmediateScheduler +from reactivex.scheduler import ThreadPoolScheduler # Local application imports -from dimos.agents.agent import OpenAIAgent from dimos.stream.frame_processor import FrameProcessor from dimos.stream.video_operators import VideoOperators as vops from dimos.stream.video_provider import VideoProvider @@ -55,7 +51,7 @@ def main(): Raises: RuntimeError: If video sources are unavailable or processing fails. """ - disposables = CompositeDisposable() + CompositeDisposable() processor = FrameProcessor( output_dir=f"{os.getcwd()}/assets/output/frames", delete_on_init=True @@ -112,7 +108,7 @@ def main(): optical_flow_stream_obs = optical_flow_relevancy_stream_obs.pipe( ops.do_action(lambda result: print(f"Optical Flow Relevancy Score: {result[1]}")), vops.with_optical_flow_filtering(threshold=2.0), - ops.do_action(lambda _: print(f"Optical Flow Passed Threshold.")), + ops.do_action(lambda _: print("Optical Flow Passed Threshold.")), vops.with_jpeg_export(processor, suffix="optical"), ) diff --git a/tests/agent_manip_flow_flask_test.py b/tests/agent_manip_flow_flask_test.py index 2356eb74ae..e96c6f2d20 100644 --- a/tests/agent_manip_flow_flask_test.py +++ b/tests/agent_manip_flow_flask_test.py @@ -17,24 +17,21 @@ It handles video capture, frame processing, and exposes the processed video streams via HTTP endpoints. """ -import tests.test_header -import os - # ----- - # Standard library imports import multiprocessing +import os + from dotenv import load_dotenv # Third-party imports from flask import Flask -from reactivex import operators as ops -from reactivex import of, interval, zip +from reactivex import interval, operators as ops, zip from reactivex.disposable import CompositeDisposable -from reactivex.scheduler import ThreadPoolScheduler, CurrentThreadScheduler, ImmediateScheduler +from reactivex.scheduler import ThreadPoolScheduler # Local application imports -from dimos.agents.agent import PromptBuilder, OpenAIAgent +from dimos.agents.agent import OpenAIAgent from dimos.stream.frame_processor import FrameProcessor from dimos.stream.video_operators import VideoOperators as vops from dimos.stream.video_provider import VideoProvider @@ -92,7 +89,7 @@ def main(): # vops.with_jpeg_export(processor, suffix="raw_slowed"), ) - edge_detection_stream_obs = processor.process_stream_edge_detection(video_stream_obs).pipe( + processor.process_stream_edge_detection(video_stream_obs).pipe( # vops.with_jpeg_export(processor, suffix="edge"), ) diff --git a/tests/agent_memory_test.py b/tests/agent_memory_test.py index b662af18bd..c2c41ad502 100644 --- a/tests/agent_memory_test.py +++ b/tests/agent_memory_test.py @@ -12,13 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header -import os # ----- - from dotenv import load_dotenv -import os load_dotenv() diff --git a/tests/genesissim/stream_camera.py b/tests/genesissim/stream_camera.py index 56ad5c4286..9346f58595 100644 --- a/tests/genesissim/stream_camera.py +++ b/tests/genesissim/stream_camera.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os + from dimos.simulation.genesis import GenesisSimulator, GenesisStream diff --git a/tests/isaacsim/stream_camera.py b/tests/isaacsim/stream_camera.py index b641b3cbe3..7aa25e7e38 100644 --- a/tests/isaacsim/stream_camera.py +++ b/tests/isaacsim/stream_camera.py @@ -13,8 +13,8 @@ # limitations under the License. import os -from dimos.simulation.isaac import IsaacSimulator -from dimos.simulation.isaac import IsaacStream + +from dimos.simulation.isaac import IsaacSimulator, IsaacStream def main(): diff --git a/tests/run.py b/tests/run.py index 9ae6f81398..d64bbb11c0 100644 --- a/tests/run.py +++ b/tests/run.py @@ -12,41 +12,34 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header +import asyncio +import atexit +import logging import os - +import signal +import threading import time +import warnings + from dotenv import load_dotenv -from dimos.agents.cerebras_agent import CerebrasAgent +import reactivex as rx +import reactivex.operators as ops + from dimos.agents.claude_agent import ClaudeAgent -from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 +from dimos.perception.object_detection_stream import ObjectDetectionStream # from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.web.websocket_vis.server import WebsocketVis -from dimos.skills.observe_stream import ObserveStream -from dimos.skills.observe import Observe +from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 from dimos.skills.kill_skill import KillSkill -from dimos.skills.navigation import NavigateWithText, GetPose, NavigateToGoal, Explore -from dimos.skills.visual_navigation_skills import FollowHuman -import reactivex as rx -import reactivex.operators as ops -from dimos.stream.audio.pipelines import tts, stt -import threading -import json -from dimos.types.vector import Vector +from dimos.skills.navigation import Explore, GetPose, NavigateToGoal, NavigateWithText +from dimos.skills.observe import Observe +from dimos.skills.observe_stream import ObserveStream from dimos.skills.unitree.unitree_speak import UnitreeSpeak - -from dimos.perception.object_detection_stream import ObjectDetectionStream -from dimos.perception.detection2d.detic_2d_det import Detic2DDetector +from dimos.stream.audio.pipelines import stt +from dimos.types.vector import Vector from dimos.utils.reactive import backpressure -import asyncio -import atexit -import signal -import sys -import warnings -import logging +from dimos.web.robot_web_interface import RobotWebInterface +from dimos.web.websocket_vis.server import WebsocketVis # Filter out known WebRTC warnings that don't affect functionality warnings.filterwarnings("ignore", message="coroutine.*was never awaited") @@ -289,9 +282,7 @@ def combine_with_locations(object_detections): stt_node.consume_audio(audio_subject.pipe(ops.share())) # Read system query from prompt.txt file -with open( - os.path.join(os.path.dirname(os.path.dirname(__file__)), "assets/agent/prompt.txt"), "r" -) as f: +with open(os.path.join(os.path.dirname(os.path.dirname(__file__)), "assets/agent/prompt.txt")) as f: system_query = f.read() # Create a ClaudeAgent instance diff --git a/tests/run_go2_ros.py b/tests/run_go2_ros.py index 6bba1c1797..bc083a3a57 100644 --- a/tests/run_go2_ros.py +++ b/tests/run_go2_ros.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header - import os import time @@ -48,7 +46,7 @@ def get_env_var(var_name, default=None, required=False): connection_method = getattr(WebRTCConnectionMethod, connection_method) print("Initializing UnitreeGo2...") - print(f"Configuration:") + print("Configuration:") print(f" IP: {robot_ip}") print(f" Connection Method: {connection_method}") print(f" Serial Number: {serial_number if serial_number else 'Not provided'}") diff --git a/tests/run_navigation_only.py b/tests/run_navigation_only.py index 2995750e2b..947da9c3a2 100644 --- a/tests/run_navigation_only.py +++ b/tests/run_navigation_only.py @@ -11,22 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os -from dotenv import load_dotenv -from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree_webrtc.testing.helpers import show3d_stream -from dimos.web.websocket_vis.server import WebsocketVis -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.types.vector import Vector -import reactivex.operators as ops -import time -import threading import asyncio import atexit +import logging +import os import signal -import sys +import threading +import time import warnings -import logging + +from dotenv import load_dotenv +import reactivex.operators as ops + +from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 +from dimos.types.vector import Vector +from dimos.web.robot_web_interface import RobotWebInterface +from dimos.web.websocket_vis.server import WebsocketVis + # logging.basicConfig(level=logging.DEBUG) # Filter out known WebRTC warnings that don't affect functionality diff --git a/tests/simple_agent_test.py b/tests/simple_agent_test.py index 2534eac31b..f2cf8493d4 100644 --- a/tests/simple_agent_test.py +++ b/tests/simple_agent_test.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header +import os +from dimos.agents.agent import OpenAIAgent from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.agents.agent import OpenAIAgent -import os +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills # Initialize robot robot = UnitreeGo2( diff --git a/tests/test_agent.py b/tests/test_agent.py index e2c8f89f8e..e91345ff6a 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -12,12 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys import os -import tests.test_header # ----- - from dotenv import load_dotenv diff --git a/tests/test_agent_alibaba.py b/tests/test_agent_alibaba.py index 9519387b7b..fa4dfe80bf 100644 --- a/tests/test_agent_alibaba.py +++ b/tests/test_agent_alibaba.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header - import os -from dimos.agents.agent import OpenAIAgent + from openai import OpenAI -from dimos.stream.video_provider import VideoProvider -from dimos.utils.threadpool import get_scheduler + +from dimos.agents.agent import OpenAIAgent from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.stream.video_provider import VideoProvider +from dimos.utils.threadpool import get_scheduler # Initialize video stream video_stream = VideoProvider( diff --git a/tests/test_agent_ctransformers_gguf.py b/tests/test_agent_ctransformers_gguf.py index 6cd3405239..389a9c74c5 100644 --- a/tests/test_agent_ctransformers_gguf.py +++ b/tests/test_agent_ctransformers_gguf.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header - from dimos.agents.agent_ctransformers_gguf import CTransformersGGUFAgent system_query = "You are a robot with the following functions. Move(), Reverse(), Left(), Right(), Stop(). Given the following user comands return the correct function." diff --git a/tests/test_agent_huggingface_local.py b/tests/test_agent_huggingface_local.py index 4c4536a197..eb88dd9847 100644 --- a/tests/test_agent_huggingface_local.py +++ b/tests/test_agent_huggingface_local.py @@ -12,15 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.stream.data_provider import QueryDataProvider -import tests.test_header - import os -from dimos.stream.video_provider import VideoProvider -from dimos.utils.threadpool import get_scheduler -from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer + from dimos.agents.agent_huggingface_local import HuggingFaceLocalAgent from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.stream.data_provider import QueryDataProvider +from dimos.stream.video_provider import VideoProvider +from dimos.utils.threadpool import get_scheduler # Initialize video stream video_stream = VideoProvider( diff --git a/tests/test_agent_huggingface_local_jetson.py b/tests/test_agent_huggingface_local_jetson.py index 6d29b3903f..883a05be54 100644 --- a/tests/test_agent_huggingface_local_jetson.py +++ b/tests/test_agent_huggingface_local_jetson.py @@ -12,15 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.stream.data_provider import QueryDataProvider -import tests.test_header - import os -from dimos.stream.video_provider import VideoProvider -from dimos.utils.threadpool import get_scheduler -from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer + from dimos.agents.agent_huggingface_local import HuggingFaceLocalAgent from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.stream.data_provider import QueryDataProvider +from dimos.stream.video_provider import VideoProvider +from dimos.utils.threadpool import get_scheduler # Initialize video stream video_stream = VideoProvider( diff --git a/tests/test_agent_huggingface_remote.py b/tests/test_agent_huggingface_remote.py index 7129523bf0..ed99faa8a4 100644 --- a/tests/test_agent_huggingface_remote.py +++ b/tests/test_agent_huggingface_remote.py @@ -12,15 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.stream.data_provider import QueryDataProvider -import tests.test_header -import os -from dimos.stream.video_provider import VideoProvider -from dimos.utils.threadpool import get_scheduler -from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer from dimos.agents.agent_huggingface_remote import HuggingFaceRemoteAgent -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer +from dimos.stream.data_provider import QueryDataProvider # Initialize video stream # video_stream = VideoProvider( diff --git a/tests/test_audio_agent.py b/tests/test_audio_agent.py index 6caf24b9eb..d79d2040c2 100644 --- a/tests/test_audio_agent.py +++ b/tests/test_audio_agent.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from dimos.agents.agent import OpenAIAgent +from dimos.stream.audio.pipelines import stt, tts from dimos.stream.audio.utils import keepalive -from dimos.stream.audio.pipelines import tts, stt from dimos.utils.threadpool import get_scheduler -from dimos.agents.agent import OpenAIAgent def main(): diff --git a/tests/test_audio_robot_agent.py b/tests/test_audio_robot_agent.py index 411e4a56c1..27340fcd80 100644 --- a/tests/test_audio_robot_agent.py +++ b/tests/test_audio_robot_agent.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.utils.threadpool import get_scheduler import os + +from dimos.agents.agent import OpenAIAgent from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.agents.agent import OpenAIAgent -from dimos.stream.audio.pipelines import tts, stt +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills +from dimos.stream.audio.pipelines import stt, tts from dimos.stream.audio.utils import keepalive +from dimos.utils.threadpool import get_scheduler def main(): diff --git a/tests/test_cerebras_unitree_ros.py b/tests/test_cerebras_unitree_ros.py index cbb7c130db..60890a3d5c 100644 --- a/tests/test_cerebras_unitree_ros.py +++ b/tests/test_cerebras_unitree_ros.py @@ -12,29 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. -import sys import os -from dimos.robot.robot import MockRobot -import tests.test_header -import time from dotenv import load_dotenv +import reactivex as rx +import reactivex.operators as ops + from dimos.agents.cerebras_agent import CerebrasAgent from dimos.robot.unitree.unitree_go2 import UnitreeGo2 from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.skills.observe_stream import ObserveStream from dimos.skills.kill_skill import KillSkill -from dimos.skills.navigation import NavigateWithText, GetPose, NavigateToGoal -from dimos.skills.visual_navigation_skills import FollowHuman -import reactivex as rx -import reactivex.operators as ops -from dimos.stream.audio.pipelines import tts, stt -from dimos.web.websocket_vis.server import WebsocketVis -import threading -from dimos.types.vector import Vector +from dimos.skills.navigation import GetPose, NavigateToGoal, NavigateWithText +from dimos.skills.observe_stream import ObserveStream from dimos.skills.speak import Speak +from dimos.skills.visual_navigation_skills import FollowHuman +from dimos.stream.audio.pipelines import stt, tts +from dimos.web.robot_web_interface import RobotWebInterface # Load API key from environment load_dotenv() diff --git a/tests/test_claude_agent_query.py b/tests/test_claude_agent_query.py index aabd85bc12..05893a6b9d 100644 --- a/tests/test_claude_agent_query.py +++ b/tests/test_claude_agent_query.py @@ -12,9 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header - from dotenv import load_dotenv + from dimos.agents.claude_agent import ClaudeAgent # Load API key from environment diff --git a/tests/test_claude_agent_skills_query.py b/tests/test_claude_agent_skills_query.py index 1aaeb795f1..bb5753d2db 100644 --- a/tests/test_claude_agent_skills_query.py +++ b/tests/test_claude_agent_skills_query.py @@ -12,27 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header import os +import threading -import time from dotenv import load_dotenv +import reactivex as rx +import reactivex.operators as ops + from dimos.agents.claude_agent import ClaudeAgent from dimos.robot.unitree.unitree_go2 import UnitreeGo2 from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.skills.observe_stream import ObserveStream from dimos.skills.kill_skill import KillSkill -from dimos.skills.navigation import Navigate, BuildSemanticMap, GetPose, NavigateToGoal -from dimos.skills.visual_navigation_skills import NavigateToObject, FollowHuman -import reactivex as rx -import reactivex.operators as ops -from dimos.stream.audio.pipelines import tts, stt -from dimos.web.websocket_vis.server import WebsocketVis -import threading -from dimos.types.vector import Vector +from dimos.skills.navigation import BuildSemanticMap, GetPose, Navigate, NavigateToGoal +from dimos.skills.observe_stream import ObserveStream from dimos.skills.speak import Speak +from dimos.skills.visual_navigation_skills import FollowHuman, NavigateToObject +from dimos.stream.audio.pipelines import stt, tts +from dimos.types.vector import Vector +from dimos.web.robot_web_interface import RobotWebInterface +from dimos.web.websocket_vis.server import WebsocketVis # Load API key from environment load_dotenv() diff --git a/tests/test_command_pose_unitree.py b/tests/test_command_pose_unitree.py index 22cf0e82ed..f67b8c969f 100644 --- a/tests/test_command_pose_unitree.py +++ b/tests/test_command_pose_unitree.py @@ -18,12 +18,12 @@ # Add the parent directory to the Python path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl import os import time -import math + +from dimos.robot.unitree.unitree_go2 import UnitreeGo2 +from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills # Initialize robot robot = UnitreeGo2( diff --git a/tests/test_header.py b/tests/test_header.py index 48ea6dd509..05e6c3e21c 100644 --- a/tests/test_header.py +++ b/tests/test_header.py @@ -19,9 +19,9 @@ tests to import from the main application. """ -import sys -import os import inspect +import os +import sys # Add the parent directory of 'tests' to the Python path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) diff --git a/tests/test_huggingface_llm_agent.py b/tests/test_huggingface_llm_agent.py index e5914f1311..5d3c1f39a5 100644 --- a/tests/test_huggingface_llm_agent.py +++ b/tests/test_huggingface_llm_agent.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header - import os import time diff --git a/tests/test_manipulation_agent.py b/tests/test_manipulation_agent.py index 5062fd8446..bd09b23b5e 100644 --- a/tests/test_manipulation_agent.py +++ b/tests/test_manipulation_agent.py @@ -12,45 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.skills.skills import SkillLibrary -import tests.test_header +import datetime import os -import time +import cv2 from dotenv import load_dotenv -from dimos.agents.claude_agent import ClaudeAgent -from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.skills.observe_stream import ObserveStream -from dimos.skills.kill_skill import KillSkill -from dimos.skills.navigation import NavigateWithText, GetPose, NavigateToGoal -from dimos.skills.visual_navigation_skills import FollowHuman +from openai import OpenAI import reactivex as rx import reactivex.operators as ops -from dimos.stream.audio.pipelines import tts, stt -import threading -import json -import cv2 -import numpy as np -import os -import datetime -from dimos.types.vector import Vector -from dimos.skills.speak import Speak -from dimos.perception.object_detection_stream import ObjectDetectionStream +from reactivex.subject import BehaviorSubject + +from dimos.agents.claude_agent import ClaudeAgent from dimos.perception.detection2d.detic_2d_det import Detic2DDetector -from dimos.agents.agent import OpenAIAgent -from dimos.agents.tokenizer.huggingface_tokenizer import HuggingFaceTokenizer -from openai import OpenAI -from dimos.utils.reactive import backpressure -from dimos.stream.video_provider import VideoProvider -from reactivex.subject import Subject, BehaviorSubject -from dimos.utils.logging_config import setup_logger -from dimos.skills.manipulation.translation_constraint_skill import TranslationConstraintSkill -from dimos.skills.manipulation.rotation_constraint_skill import RotationConstraintSkill -from dimos.skills.manipulation.manipulate_skill import Manipulate +from dimos.perception.object_detection_stream import ObjectDetectionStream from dimos.robot.robot import MockManipulationRobot +from dimos.skills.manipulation.manipulate_skill import Manipulate +from dimos.skills.manipulation.rotation_constraint_skill import RotationConstraintSkill +from dimos.skills.manipulation.translation_constraint_skill import TranslationConstraintSkill +from dimos.skills.skills import SkillLibrary +from dimos.utils.logging_config import setup_logger +from dimos.utils.reactive import backpressure +from dimos.web.robot_web_interface import RobotWebInterface # Initialize logger for the agent module logger = setup_logger("dimos.tests.test_manipulation_agent") @@ -207,7 +189,7 @@ def combine_with_locations(object_detections): # Read system query from prompt.txt file with open( - os.path.join(os.path.dirname(os.path.dirname(__file__)), "assets", "agent", "prompt.txt"), "r" + os.path.join(os.path.dirname(os.path.dirname(__file__)), "assets", "agent", "prompt.txt") ) as f: system_query = f.read() diff --git a/tests/test_manipulation_perception_pipeline.py b/tests/test_manipulation_perception_pipeline.py index 227f991650..6f8755d3da 100644 --- a/tests/test_manipulation_perception_pipeline.py +++ b/tests/test_manipulation_perception_pipeline.py @@ -26,17 +26,15 @@ # limitations under the License. import sys -import time import threading -from reactivex import operators as ops - -import tests.test_header +import time from pyzed import sl +from reactivex import operators as ops + +from dimos.manipulation.manip_aio_pipeline import ManipulationPipeline from dimos.stream.stereo_camera_streams.zed import ZEDCameraStream from dimos.web.robot_web_interface import RobotWebInterface -from dimos.utils.logging_config import logger -from dimos.manipulation.manip_aio_pipeline import ManipulationPipeline def monitor_grasps(pipeline): @@ -138,10 +136,10 @@ def main(): ) grasp_monitor_thread.start() - print(f"\n Point Cloud + Grasp Generation Test Running:") + print("\n Point Cloud + Grasp Generation Test Running:") print(f" Web Interface: http://localhost:{web_port}") - print(f" Object Detection View: RGB with bounding boxes") - print(f" Point Cloud View: Depth with colored point clouds and 3D bounding boxes") + print(" Object Detection View: RGB with bounding boxes") + print(" Point Cloud View: Depth with colored point clouds and 3D bounding boxes") print(f" Confidence threshold: {min_confidence}") print(f" Grasp server: {grasp_server_url}") print(f" Available streams: {list(streams.keys())}") diff --git a/tests/test_manipulation_perception_pipeline.py.py b/tests/test_manipulation_perception_pipeline.py.py index 227f991650..6f8755d3da 100644 --- a/tests/test_manipulation_perception_pipeline.py.py +++ b/tests/test_manipulation_perception_pipeline.py.py @@ -26,17 +26,15 @@ # limitations under the License. import sys -import time import threading -from reactivex import operators as ops - -import tests.test_header +import time from pyzed import sl +from reactivex import operators as ops + +from dimos.manipulation.manip_aio_pipeline import ManipulationPipeline from dimos.stream.stereo_camera_streams.zed import ZEDCameraStream from dimos.web.robot_web_interface import RobotWebInterface -from dimos.utils.logging_config import logger -from dimos.manipulation.manip_aio_pipeline import ManipulationPipeline def monitor_grasps(pipeline): @@ -138,10 +136,10 @@ def main(): ) grasp_monitor_thread.start() - print(f"\n Point Cloud + Grasp Generation Test Running:") + print("\n Point Cloud + Grasp Generation Test Running:") print(f" Web Interface: http://localhost:{web_port}") - print(f" Object Detection View: RGB with bounding boxes") - print(f" Point Cloud View: Depth with colored point clouds and 3D bounding boxes") + print(" Object Detection View: RGB with bounding boxes") + print(" Point Cloud View: Depth with colored point clouds and 3D bounding boxes") print(f" Confidence threshold: {min_confidence}") print(f" Grasp server: {grasp_server_url}") print(f" Available streams: {list(streams.keys())}") diff --git a/tests/test_manipulation_pipeline_single_frame.py b/tests/test_manipulation_pipeline_single_frame.py index 629ba4dbee..c29b2b2607 100644 --- a/tests/test_manipulation_pipeline_single_frame.py +++ b/tests/test_manipulation_pipeline_single_frame.py @@ -14,12 +14,13 @@ """Test manipulation processor with direct visualization and grasp data output.""" +import argparse import os + import cv2 -import numpy as np -import argparse import matplotlib -import tests.test_header +import numpy as np + from dimos.utils.data import get_data # Try to use TkAgg backend for live display, fallback to Agg if not available @@ -33,17 +34,17 @@ import matplotlib.pyplot as plt import open3d as o3d -from dimos.perception.pointcloud.utils import visualize_clustered_point_clouds, visualize_voxel_grid from dimos.manipulation.manip_aio_processer import ManipulationProcessor +from dimos.perception.grasp_generation.utils import create_grasp_overlay, visualize_grasps_3d from dimos.perception.pointcloud.utils import ( + combine_object_pointclouds, load_camera_matrix_from_yaml, + visualize_clustered_point_clouds, visualize_pcd, - combine_object_pointclouds, + visualize_voxel_grid, ) from dimos.utils.logging_config import setup_logger -from dimos.perception.grasp_generation.utils import visualize_grasps_3d, create_grasp_overlay - logger = setup_logger("test_pipeline_viz") @@ -161,7 +162,7 @@ def main(): else: rows = 2 cols = (num_plots + 1) // 2 - fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows)) + _fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows)) if num_plots == 1: axes = [axes] diff --git a/tests/test_manipulation_pipeline_single_frame_lcm.py b/tests/test_manipulation_pipeline_single_frame_lcm.py index 7b57887ddc..0c2f2bc591 100644 --- a/tests/test_manipulation_pipeline_single_frame_lcm.py +++ b/tests/test_manipulation_pipeline_single_frame_lcm.py @@ -14,15 +14,13 @@ """Test manipulation processor with LCM topic subscription.""" -import os -import sys -import cv2 -import numpy as np import argparse -import threading import pickle +import threading + +import cv2 import matplotlib -import tests.test_header +import numpy as np # Try to use TkAgg backend for live display, fallback to Agg if not available try: @@ -32,19 +30,13 @@ matplotlib.use("Qt5Agg") except: matplotlib.use("Agg") # Fallback to non-interactive -import matplotlib.pyplot as plt -import open3d as o3d -from typing import Dict, List, Optional # LCM imports import lcm -from lcm_msgs.sensor_msgs import Image as LCMImage -from lcm_msgs.sensor_msgs import CameraInfo as LCMCameraInfo +from lcm_msgs.sensor_msgs import CameraInfo as LCMCameraInfo, Image as LCMImage +import open3d as o3d -from dimos.perception.pointcloud.utils import visualize_clustered_point_clouds, visualize_voxel_grid from dimos.manipulation.manip_aio_processer import ManipulationProcessor -from dimos.perception.grasp_generation.utils import visualize_grasps_3d -from dimos.perception.pointcloud.utils import visualize_pcd from dimos.utils.logging_config import setup_logger logger = setup_logger("test_pipeline_lcm") @@ -57,9 +49,9 @@ def __init__(self, lcm_url: str = "udpm://239.255.76.67:7667?ttl=1"): self.lcm = lcm.LCM(lcm_url) # Data storage - self.rgb_data: Optional[np.ndarray] = None - self.depth_data: Optional[np.ndarray] = None - self.camera_intrinsics: Optional[List[float]] = None + self.rgb_data: np.ndarray | None = None + self.depth_data: np.ndarray | None = None + self.camera_intrinsics: list[float] | None = None # Synchronization self.data_lock = threading.Lock() @@ -278,14 +270,14 @@ def main(): results = run_processor(color_img, depth_img, intrinsics) # Debug: Print what we received - print(f"\n✅ Processor Results:") + print("\n✅ Processor Results:") print(f" Available results: {list(results.keys())}") print(f" Processing time: {results.get('processing_time', 0):.3f}s") # Show timing breakdown if available if "timing_breakdown" in results: breakdown = results["timing_breakdown"] - print(f" Timing breakdown:") + print(" Timing breakdown:") print(f" - Detection: {breakdown.get('detection', 0):.3f}s") print(f" - Segmentation: {breakdown.get('segmentation', 0):.3f}s") print(f" - Point cloud: {breakdown.get('pointcloud', 0):.3f}s") @@ -299,17 +291,17 @@ def main(): print(f" All objects processed: {all_count}") # Print misc clusters information - if "misc_clusters" in results and results["misc_clusters"]: + if results.get("misc_clusters"): cluster_count = len(results["misc_clusters"]) total_misc_points = sum( len(np.asarray(cluster.points)) for cluster in results["misc_clusters"] ) print(f" Misc clusters: {cluster_count} clusters with {total_misc_points} total points") else: - print(f" Misc clusters: None") + print(" Misc clusters: None") # Print grasp summary - if "grasps" in results and results["grasps"]: + if results.get("grasps"): total_grasps = 0 best_score = 0 for grasp in results["grasps"]: @@ -414,7 +406,7 @@ def serialize_voxel_grid(voxel_grid): with open(pickle_path, "wb") as f: pickle.dump(pickle_data, f) - print(f"Results saved successfully with all 3D data serialized!") + print("Results saved successfully with all 3D data serialized!") print(f"Pickled data keys: {list(pickle_data['results'].keys())}") # Visualization code has been moved to visualization_script.py diff --git a/tests/test_move_vel_unitree.py b/tests/test_move_vel_unitree.py index fe4d09a8e1..4700c056aa 100644 --- a/tests/test_move_vel_unitree.py +++ b/tests/test_move_vel_unitree.py @@ -12,13 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header +import os +import time from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -import os -import time +from dimos.robot.unitree.unitree_skills import MyUnitreeSkills # Initialize robot robot = UnitreeGo2( diff --git a/tests/test_navigate_to_object_robot.py b/tests/test_navigate_to_object_robot.py index eb2767d6ca..ecf4fd4956 100644 --- a/tests/test_navigate_to_object_robot.py +++ b/tests/test_navigate_to_object_robot.py @@ -12,20 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -import time -import sys import argparse +import os import threading -from reactivex import Subject, operators as RxOps +import time + +from reactivex import operators as RxOps from dimos.robot.unitree.unitree_go2 import UnitreeGo2 from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.utils.logging_config import logger from dimos.skills.navigation import Navigate -import tests.test_header +from dimos.utils.logging_config import logger +from dimos.web.robot_web_interface import RobotWebInterface def parse_args(): diff --git a/tests/test_navigation_skills.py b/tests/test_navigation_skills.py index 9a91d1aba5..93497de691 100644 --- a/tests/test_navigation_skills.py +++ b/tests/test_navigation_skills.py @@ -25,16 +25,12 @@ python simple_navigation_test.py --skip-build --query "kitchen" """ -import os -import sys -import time -import logging import argparse -import threading -from reactivex import Subject, operators as RxOps import os +import threading +import time -import tests.test_header +from reactivex import operators as RxOps from dimos.robot.unitree.unitree_go2 import UnitreeGo2 from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl diff --git a/tests/test_object_detection_agent_data_query_stream.py b/tests/test_object_detection_agent_data_query_stream.py index 00e5625119..ca5671f78e 100644 --- a/tests/test_object_detection_agent_data_query_stream.py +++ b/tests/test_object_detection_agent_data_query_stream.py @@ -12,27 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import os -import time import sys -import argparse import threading -from typing import List, Dict, Any -from reactivex import Subject, operators as ops +from dotenv import load_dotenv +from reactivex import operators as ops + +from dimos.agents.claude_agent import ClaudeAgent +from dimos.perception.detection2d.detic_2d_det import Detic2DDetector +from dimos.perception.object_detection_stream import ObjectDetectionStream from dimos.robot.unitree.unitree_go2 import UnitreeGo2 from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.utils.logging_config import logger from dimos.stream.video_provider import VideoProvider -from dimos.perception.object_detection_stream import ObjectDetectionStream -from dimos.types.vector import Vector from dimos.utils.reactive import backpressure -from dimos.perception.detection2d.detic_2d_det import Detic2DDetector -from dimos.agents.claude_agent import ClaudeAgent - -from dotenv import load_dotenv +from dimos.web.robot_web_interface import RobotWebInterface def parse_args(): diff --git a/tests/test_object_detection_stream.py b/tests/test_object_detection_stream.py index 1cf8aeab01..2d45c261d5 100644 --- a/tests/test_object_detection_stream.py +++ b/tests/test_object_detection_stream.py @@ -12,22 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import os -import time import sys -import argparse import threading -from typing import List, Dict, Any -from reactivex import Subject, operators as ops +import time +from typing import Any + +from dotenv import load_dotenv +from reactivex import operators as ops +from dimos.perception.object_detection_stream import ObjectDetectionStream from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.utils.logging_config import logger from dimos.stream.video_provider import VideoProvider -from dimos.perception.object_detection_stream import ObjectDetectionStream -from dimos.types.vector import Vector from dimos.utils.reactive import backpressure -from dotenv import load_dotenv +from dimos.web.robot_web_interface import RobotWebInterface def parse_args(): @@ -58,7 +57,7 @@ def __init__(self, print_interval: float = 1.0): self.print_interval = print_interval self.last_print_time = 0 - def print_results(self, objects: List[Dict[str, Any]]): + def print_results(self, objects: list[dict[str, Any]]): """Print object detection results to console with rate limiting.""" current_time = time.time() diff --git a/tests/test_object_tracking_module.py b/tests/test_object_tracking_module.py index 0b4b1f1364..4fc1adac83 100755 --- a/tests/test_object_tracking_module.py +++ b/tests/test_object_tracking_module.py @@ -16,20 +16,21 @@ """Test script for Object Tracking module with ZED camera.""" import asyncio + import cv2 +from dimos_lcm.sensor_msgs import CameraInfo from dimos import core from dimos.hardware.zed_camera import ZEDModule -from dimos.perception.object_tracker import ObjectTracking -from dimos.protocol import pubsub -from dimos.utils.logging_config import setup_logger -from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.msgs.geometry_msgs import PoseStamped # Import message types from dimos.msgs.sensor_msgs import Image -from dimos_lcm.sensor_msgs import CameraInfo -from dimos.msgs.geometry_msgs import PoseStamped +from dimos.perception.object_tracker import ObjectTracking +from dimos.protocol import pubsub from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.utils.logging_config import setup_logger logger = setup_logger("test_object_tracking_module") diff --git a/tests/test_object_tracking_webcam.py b/tests/test_object_tracking_webcam.py index a9d792d51b..8fcfe7bacd 100644 --- a/tests/test_object_tracking_webcam.py +++ b/tests/test_object_tracking_webcam.py @@ -12,16 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cv2 -import numpy as np -import os -import sys import queue import threading -import tests.test_header -from dimos.stream.video_provider import VideoProvider +import cv2 + from dimos.perception.object_tracker import ObjectTrackingStream +from dimos.stream.video_provider import VideoProvider # Global variables for bounding box selection selecting_bbox = False diff --git a/tests/test_object_tracking_with_qwen.py b/tests/test_object_tracking_with_qwen.py index 959565ae55..e8fcd86a2b 100644 --- a/tests/test_object_tracking_with_qwen.py +++ b/tests/test_object_tracking_with_qwen.py @@ -13,21 +13,14 @@ # limitations under the License. import os -import sys -import time -import cv2 -import numpy as np import queue import threading -import json -from reactivex import Subject, operators as RxOps -from openai import OpenAI -import tests.test_header -from dimos.stream.video_provider import VideoProvider -from dimos.perception.object_tracker import ObjectTrackingStream +import cv2 + from dimos.models.qwen.video_query import get_bbox_from_qwen -from dimos.utils.logging_config import logger +from dimos.perception.object_tracker import ObjectTrackingStream +from dimos.stream.video_provider import VideoProvider # Global variables for tracking control object_size = 0.30 # Hardcoded object size in meters (adjust based on your tracking target) diff --git a/tests/test_person_following_robot.py b/tests/test_person_following_robot.py index 46f91cc7a3..f7ee6eaf0d 100644 --- a/tests/test_person_following_robot.py +++ b/tests/test_person_following_robot.py @@ -14,16 +14,15 @@ import os import time -import sys + from reactivex import operators as RxOps -import tests.test_header +from dimos.models.qwen.video_query import query_single_frame_observable from dimos.robot.unitree.unitree_go2 import UnitreeGo2 from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface from dimos.utils.logging_config import logger -from dimos.models.qwen.video_query import query_single_frame_observable +from dimos.web.robot_web_interface import RobotWebInterface def main(): diff --git a/tests/test_person_following_webcam.py b/tests/test_person_following_webcam.py index 2108c4cf95..20a6a7ca4d 100644 --- a/tests/test_person_following_webcam.py +++ b/tests/test_person_following_webcam.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cv2 -import numpy as np -import os -import sys import queue import threading -import tests.test_header +import cv2 +import numpy as np -from dimos.stream.video_provider import VideoProvider from dimos.perception.person_tracker import PersonTrackingStream from dimos.perception.visual_servoing import VisualServoing +from dimos.stream.video_provider import VideoProvider def main(): diff --git a/tests/test_pick_and_place_module.py b/tests/test_pick_and_place_module.py index 6a8470863e..1bce414a6e 100644 --- a/tests/test_pick_and_place_module.py +++ b/tests/test_pick_and_place_module.py @@ -18,13 +18,13 @@ Subscribes to visualization images and handles mouse/keyboard input. """ -import cv2 -import sys import asyncio +import sys import threading import time + +import cv2 import numpy as np -from typing import Optional try: import pyzed.sl as sl @@ -32,12 +32,12 @@ print("Error: ZED SDK not installed.") sys.exit(1) -from dimos.robot.agilex.piper_arm import PiperArmRobot -from dimos.utils.logging_config import setup_logger - # Import LCM message types from dimos_lcm.sensor_msgs import Image + from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.robot.agilex.piper_arm import PiperArmRobot +from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.tests.test_pick_and_place_module") diff --git a/tests/test_pick_and_place_skill.py b/tests/test_pick_and_place_skill.py index 40cf2c23b0..78eeb761fb 100644 --- a/tests/test_pick_and_place_skill.py +++ b/tests/test_pick_and_place_skill.py @@ -18,8 +18,8 @@ Uses hardcoded points and the PickAndPlace skill. """ -import sys import asyncio +import sys try: import pyzed.sl as sl # Required for ZED camera diff --git a/tests/test_planning_agent_web_interface.py b/tests/test_planning_agent_web_interface.py index 1d1e3fcd87..6c88919110 100644 --- a/tests/test_planning_agent_web_interface.py +++ b/tests/test_planning_agent_web_interface.py @@ -23,15 +23,13 @@ ROS_OUTPUT_DIR: Optional. Directory for ROS output files. """ -import tests.test_header import os import sys # ----- - from textwrap import dedent -import threading import time + import reactivex as rx import reactivex.operators as ops @@ -41,10 +39,10 @@ from dimos.robot.unitree.unitree_go2 import UnitreeGo2 from dimos.robot.unitree.unitree_skills import MyUnitreeSkills from dimos.utils.logging_config import logger +from dimos.utils.threadpool import make_single_thread_scheduler # from dimos.web.fastapi_server import FastAPIServer from dimos.web.robot_web_interface import RobotWebInterface -from dimos.utils.threadpool import make_single_thread_scheduler def main(): diff --git a/tests/test_planning_robot_agent.py b/tests/test_planning_robot_agent.py index 6e55e5de71..aa16a7cac7 100644 --- a/tests/test_planning_robot_agent.py +++ b/tests/test_planning_robot_agent.py @@ -24,14 +24,11 @@ USE_TERMINAL: Optional. If set to "true", use terminal interface instead of web. """ -import tests.test_header import os import sys # ----- - from textwrap import dedent -import threading import time # Local application imports @@ -40,8 +37,8 @@ from dimos.robot.unitree.unitree_go2 import UnitreeGo2 from dimos.robot.unitree.unitree_skills import MyUnitreeSkills from dimos.utils.logging_config import logger -from dimos.web.robot_web_interface import RobotWebInterface from dimos.utils.threadpool import make_single_thread_scheduler +from dimos.web.robot_web_interface import RobotWebInterface def main(): @@ -110,7 +107,7 @@ def main(): system_query = dedent( """ You are a robot execution agent that can execute tasks on a virtual - robot. You are given a task to execute and a list of skills that + robot. You are given a task to execute and a list of skills that you can use to execute the task. ONLY OUTPUT THE SKILLS TO EXECUTE, NOTHING ELSE. """ diff --git a/tests/test_pointcloud_filtering.py b/tests/test_pointcloud_filtering.py index 57a1cb5b00..8a9eb8665f 100644 --- a/tests/test_pointcloud_filtering.py +++ b/tests/test_pointcloud_filtering.py @@ -13,17 +13,13 @@ # limitations under the License. import sys -import time -import threading -from reactivex import operators as ops - -import tests.test_header from pyzed import sl +from reactivex import operators as ops + +from dimos.manipulation.manip_aio_pipeline import ManipulationPipeline from dimos.stream.stereo_camera_streams.zed import ZEDCameraStream from dimos.web.robot_web_interface import RobotWebInterface -from dimos.utils.logging_config import logger -from dimos.manipulation.manip_aio_pipeline import ManipulationPipeline def main(): @@ -78,10 +74,10 @@ def main(): pointcloud_stream=pointcloud_viz_stream, ) - print(f"\nPoint Cloud Filtering Test Running:") + print("\nPoint Cloud Filtering Test Running:") print(f"Web Interface: http://localhost:{web_port}") - print(f"Object Detection View: RGB with bounding boxes") - print(f"Point Cloud View: Depth with colored point clouds and 3D bounding boxes") + print("Object Detection View: RGB with bounding boxes") + print("Point Cloud View: Depth with colored point clouds and 3D bounding boxes") print(f"Confidence threshold: {min_confidence}") print("\nPress Ctrl+C to stop the test\n") diff --git a/tests/test_qwen_image_query.py b/tests/test_qwen_image_query.py index 634f9f6563..6a3aa9d8c6 100644 --- a/tests/test_qwen_image_query.py +++ b/tests/test_qwen_image_query.py @@ -15,9 +15,11 @@ """Test the Qwen image query functionality.""" import os + import cv2 import numpy as np from PIL import Image + from dimos.models.qwen.video_query import query_single_frame diff --git a/tests/test_robot.py b/tests/test_robot.py index 76289273f7..63439ce3d9 100644 --- a/tests/test_robot.py +++ b/tests/test_robot.py @@ -13,13 +13,14 @@ # limitations under the License. import os -import time import threading -from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 +import time + +from reactivex import operators as RxOps + from dimos.robot.local_planner.local_planner import navigate_to_goal_local +from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 from dimos.web.robot_web_interface import RobotWebInterface -from reactivex import operators as RxOps -import tests.test_header def main(): diff --git a/tests/test_rtsp_video_provider.py b/tests/test_rtsp_video_provider.py index e3824740a6..fb0f075750 100644 --- a/tests/test_rtsp_video_provider.py +++ b/tests/test_rtsp_video_provider.py @@ -12,11 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dimos.stream.rtsp_video_provider import RtspVideoProvider -from dimos.web.robot_web_interface import RobotWebInterface -import tests.test_header - -import logging import time import numpy as np @@ -24,15 +19,16 @@ from reactivex import operators as ops from dimos.stream.frame_processor import FrameProcessor +from dimos.stream.rtsp_video_provider import RtspVideoProvider from dimos.stream.video_operators import VideoOperators as vops from dimos.stream.video_provider import get_scheduler from dimos.utils.logging_config import setup_logger - +from dimos.web.robot_web_interface import RobotWebInterface logger = setup_logger("tests.test_rtsp_video_provider") -import sys import os +import sys # Load environment variables from .env file from dotenv import load_dotenv @@ -51,7 +47,7 @@ print("Example: python -m dimos.stream.rtsp_video_provider rtsp://...") sys.exit(1) -logger.info(f"Attempting to connect to provided RTSP URL.") +logger.info("Attempting to connect to provided RTSP URL.") provider = RtspVideoProvider(dev_name="TestRtspCam", rtsp_url=RTSP_URL) logger.info("Creating observable...") diff --git a/tests/test_semantic_seg_robot.py b/tests/test_semantic_seg_robot.py index eb5beb88e2..0a78bc371b 100644 --- a/tests/test_semantic_seg_robot.py +++ b/tests/test_semantic_seg_robot.py @@ -12,25 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cv2 -import numpy as np import os -import sys import queue +import sys import threading +import cv2 +import numpy as np + # Add the parent directory to the Python path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from dimos.stream.video_provider import VideoProvider +from reactivex import operators as RxOps + from dimos.perception.semantic_seg import SemanticSegmentationStream from dimos.robot.unitree.unitree_go2 import UnitreeGo2 from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.stream.video_operators import VideoOperators as MyVideoOps, Operators as MyOps from dimos.stream.frame_processor import FrameProcessor -from reactivex import operators as RxOps +from dimos.stream.video_operators import Operators as MyOps +from dimos.web.robot_web_interface import RobotWebInterface def main(): @@ -111,7 +111,7 @@ def on_completed(): "counts": {}, } - frame_processor = FrameProcessor(delete_on_init=True) + FrameProcessor(delete_on_init=True) subscription = segmentation_stream.pipe( MyOps.print_emission(id="A", **print_emission_args), RxOps.share(), diff --git a/tests/test_semantic_seg_robot_agent.py b/tests/test_semantic_seg_robot_agent.py index 8007e700a0..f35fdb53d4 100644 --- a/tests/test_semantic_seg_robot_agent.py +++ b/tests/test_semantic_seg_robot_agent.py @@ -12,22 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cv2 -import numpy as np import os -import sys -from dimos.stream.video_provider import VideoProvider +import cv2 +from reactivex import Subject, operators as RxOps + +from dimos.agents.agent import OpenAIAgent from dimos.perception.semantic_seg import SemanticSegmentationStream from dimos.robot.unitree.unitree_go2 import UnitreeGo2 from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.stream.video_operators import VideoOperators as MyVideoOps, Operators as MyOps from dimos.stream.frame_processor import FrameProcessor -from reactivex import Subject, operators as RxOps -from dimos.agents.agent import OpenAIAgent +from dimos.stream.video_operators import VideoOperators as MyVideoOps from dimos.utils.threadpool import get_scheduler +from dimos.web.robot_web_interface import RobotWebInterface def main(): @@ -54,7 +52,7 @@ def main(): # Throttling to slowdown SegmentationAgent calls # TODO: add Agent parameter to handle this called api_call_interval - frame_processor = FrameProcessor(delete_on_init=True) + FrameProcessor(delete_on_init=True) seg_stream = segmentation_stream.pipe( RxOps.share(), RxOps.map(lambda x: x.metadata["viz_frame"] if x is not None else None), diff --git a/tests/test_semantic_seg_webcam.py b/tests/test_semantic_seg_webcam.py index 083d1a0090..b7fc57073b 100644 --- a/tests/test_semantic_seg_webcam.py +++ b/tests/test_semantic_seg_webcam.py @@ -12,18 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import cv2 -import numpy as np import os -import sys import queue +import sys import threading +import cv2 +import numpy as np + # Add the parent directory to the Python path sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from dimos.stream.video_provider import VideoProvider from dimos.perception.semantic_seg import SemanticSegmentationStream +from dimos.stream.video_provider import VideoProvider def main(): diff --git a/tests/test_skills.py b/tests/test_skills.py index 0d4b7f2ff8..139a4efe59 100644 --- a/tests/test_skills.py +++ b/tests/test_skills.py @@ -17,13 +17,10 @@ import unittest from unittest import mock -import tests.test_header - -from dimos.skills.skills import AbstractSkill, SkillLibrary +from dimos.agents.agent import OpenAIAgent from dimos.robot.robot import MockRobot from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.types.constants import Colors -from dimos.agents.agent import OpenAIAgent +from dimos.skills.skills import AbstractSkill class TestSkill(AbstractSkill): diff --git a/tests/test_skills_rest.py b/tests/test_skills_rest.py index 70a15fcfd5..a9493e3c79 100644 --- a/tests/test_skills_rest.py +++ b/tests/test_skills_rest.py @@ -12,18 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header - from textwrap import dedent -from dimos.skills.skills import SkillLibrary from dotenv import load_dotenv -from dimos.agents.claude_agent import ClaudeAgent -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.skills.rest.rest import GenericRestSkill import reactivex as rx import reactivex.operators as ops +from dimos.agents.claude_agent import ClaudeAgent +from dimos.skills.rest.rest import GenericRestSkill +from dimos.skills.skills import SkillLibrary +from dimos.web.robot_web_interface import RobotWebInterface + # Load API key from environment load_dotenv() @@ -48,9 +47,9 @@ skills=skills, system_query=dedent( """ - You are a virtual agent. When given a query, respond by using + You are a virtual agent. When given a query, respond by using the appropriate tool calls if needed to execute commands on the robot. - + IMPORTANT: Only return the response directly asked of the user. E.G. if the user asks for the time, only return the time. If the user asks for the weather, only return the weather. diff --git a/tests/test_spatial_memory.py b/tests/test_spatial_memory.py index 16b1449509..8d1d88b468 100644 --- a/tests/test_spatial_memory.py +++ b/tests/test_spatial_memory.py @@ -13,25 +13,20 @@ # limitations under the License. import os -import sys import time -import pickle -import numpy as np + +import chromadb import cv2 -import matplotlib.pyplot as plt from matplotlib.patches import Circle +import matplotlib.pyplot as plt import reactivex from reactivex import operators as ops -import chromadb from dimos.agents.memory.visual_memory import VisualMemory - -import tests.test_header +from dimos.msgs.geometry_msgs import Quaternion, Vector3 # from dimos.robot.unitree_webrtc.unitree_go2 import UnitreeGo2 # Uncomment when properly configured from dimos.perception.spatial_perception import SpatialMemory -from dimos.types.vector import Vector -from dimos.msgs.geometry_msgs import Vector3, Quaternion def extract_pose_data(transform): @@ -146,7 +141,7 @@ def main(): def on_stored_frame(result): nonlocal stored_count # Only count actually stored frames (not debug frames) - if not result.get("stored", True) == False: + if not not result.get("stored", True): stored_count += 1 pos = result["position"] if isinstance(pos, tuple): diff --git a/tests/test_spatial_memory_query.py b/tests/test_spatial_memory_query.py index a0e77e9444..539f5f5eb0 100644 --- a/tests/test_spatial_memory_query.py +++ b/tests/test_spatial_memory_query.py @@ -20,18 +20,16 @@ python test_spatial_memory_query.py --query "robot" --limit 3 --save-one """ -import os -import sys import argparse -import numpy as np +from datetime import datetime +import os + +import chromadb import cv2 import matplotlib.pyplot as plt -import chromadb -from datetime import datetime -import tests.test_header -from dimos.perception.spatial_perception import SpatialMemory from dimos.agents.memory.visual_memory import VisualMemory +from dimos.perception.spatial_perception import SpatialMemory def setup_persistent_chroma_db(db_path): @@ -225,7 +223,7 @@ def visualize_spatial_memory_with_objects( x_coords = [loc[0] for loc in locations] y_coords = [loc[1] for loc in locations] else: - x_coords, y_coords = zip(*locations) + x_coords, y_coords = zip(*locations, strict=False) # Create figure plt.figure(figsize=(12, 10)) diff --git a/tests/test_standalone_chromadb.py b/tests/test_standalone_chromadb.py index a5dc0e9b73..d6e59e5237 100644 --- a/tests/test_standalone_chromadb.py +++ b/tests/test_standalone_chromadb.py @@ -12,14 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header import os # ----- - -import chromadb -from langchain_openai import OpenAIEmbeddings from langchain_chroma import Chroma +from langchain_openai import OpenAIEmbeddings OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") if not OPENAI_API_KEY: diff --git a/tests/test_standalone_fastapi.py b/tests/test_standalone_fastapi.py index 6fac013546..eb7a9a060a 100644 --- a/tests/test_standalone_fastapi.py +++ b/tests/test_standalone_fastapi.py @@ -12,17 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header -import os - import logging +import os logging.basicConfig(level=logging.DEBUG) -from fastapi import FastAPI, Response import cv2 -import uvicorn +from fastapi import FastAPI from starlette.responses import StreamingResponse +import uvicorn app = FastAPI() diff --git a/tests/test_standalone_hugging_face.py b/tests/test_standalone_hugging_face.py index d0b2e68e61..ad5f02d510 100644 --- a/tests/test_standalone_hugging_face.py +++ b/tests/test_standalone_hugging_face.py @@ -12,19 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header - # from transformers import AutoModelForCausalLM, AutoTokenizer - # model_name = "Qwen/QwQ-32B" - # model = AutoModelForCausalLM.from_pretrained( # model_name, # torch_dtype="auto", # device_map="auto" # ) # tokenizer = AutoTokenizer.from_pretrained(model_name) - # prompt = "How many r's are in the word \"strawberry\"" # messages = [ # {"role": "user", "content": prompt} @@ -34,9 +29,7 @@ # tokenize=False, # add_generation_prompt=True # ) - # model_inputs = tokenizer([text], return_tensors="pt").to(model.device) - # generated_ids = model.generate( # **model_inputs, # max_new_tokens=32768 @@ -44,31 +37,23 @@ # generated_ids = [ # output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids) # ] - # response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0] # print(response) - # ----------------------------------------------------------------------------- - # import requests # import json - # API_URL = "https://api-inference.huggingface.co/models/Qwen/QwQ-32B" # api_key = os.getenv('HUGGINGFACE_ACCESS_TOKEN') - # HEADERS = {"Authorization": f"Bearer {api_key}"} - # prompt = "How many r's are in the word \"strawberry\"" # messages = [ # {"role": "user", "content": prompt} # ] - # # Format the prompt in the desired chat format # chat_template = ( # f"{messages[0]['content']}\n" # "Assistant:" # ) - # payload = { # "inputs": chat_template, # "parameters": { @@ -76,28 +61,21 @@ # "temperature": 0.7 # } # } - # # API request # response = requests.post(API_URL, headers=HEADERS, json=payload) - # # Handle response # if response.status_code == 200: # output = response.json()[0]['generated_text'] # print(output.strip()) # else: # print(f"Error {response.status_code}: {response.text}") - # ----------------------------------------------------------------------------- - # import os # import requests # import time - # API_URL = "https://api-inference.huggingface.co/models/Qwen/QwQ-32B" # api_key = os.getenv('HUGGINGFACE_ACCESS_TOKEN') - # HEADERS = {"Authorization": f"Bearer {api_key}"} - # def query_with_retries(payload, max_retries=5, delay=15): # for attempt in range(max_retries): # response = requests.post(API_URL, headers=HEADERS, json=payload) @@ -110,22 +88,18 @@ # print(f"Error {response.status_code}: {response.text}") # break # return "Failed after multiple retries." - # prompt = "How many r's are in the word \"strawberry\"" # messages = [{"role": "user", "content": prompt}] # chat_template = f"{messages[0]['content']}\nAssistant:" - # payload = { # "inputs": chat_template, # "parameters": {"max_new_tokens": 32768, "temperature": 0.7} # } - # output = query_with_retries(payload) # print(output.strip()) - # ----------------------------------------------------------------------------- - import os + from huggingface_hub import InferenceClient # Use environment variable for API key diff --git a/tests/test_standalone_openai_json.py b/tests/test_standalone_openai_json.py index ef839ae85b..fe1a67ad78 100644 --- a/tests/test_standalone_openai_json.py +++ b/tests/test_standalone_openai_json.py @@ -12,17 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header -import os # ----- - import dotenv dotenv.load_dotenv() import json from textwrap import dedent + from openai import OpenAI from pydantic import BaseModel diff --git a/tests/test_standalone_openai_json_struct.py b/tests/test_standalone_openai_json_struct.py index 1b49aed8a7..b22f064e35 100644 --- a/tests/test_standalone_openai_json_struct.py +++ b/tests/test_standalone_openai_json_struct.py @@ -12,18 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header -import os # ----- -from typing import List, Union, Dict - import dotenv dotenv.load_dotenv() from textwrap import dedent + from openai import OpenAI from pydantic import BaseModel @@ -79,7 +76,7 @@ def get_math_solution(question: str): # If we were able to successfully parse the response back parsed_solution = solution.parsed if not parsed_solution: - print(f"Unable to Parse Solution") + print("Unable to Parse Solution") exit() # Print solution from class definitions diff --git a/tests/test_standalone_openai_json_struct_func.py b/tests/test_standalone_openai_json_struct_func.py index dcea40ffff..36f158cd20 100644 --- a/tests/test_standalone_openai_json_struct_func.py +++ b/tests/test_standalone_openai_json_struct_func.py @@ -12,22 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header -import os # ----- -from typing import List, Union, Dict - import dotenv dotenv.load_dotenv() import json -import requests from textwrap import dedent + from openai import OpenAI, pydantic_function_tool from pydantic import BaseModel, Field +import requests MODEL = "gpt-4o-2024-08-06" @@ -163,7 +160,7 @@ def get_math_solution(question: str): # If we were able to successfully parse the response back parsed_solution = solution.parsed if not parsed_solution: - print(f"Unable to Parse Solution") + print("Unable to Parse Solution") print(f"Solution: {solution}") break diff --git a/tests/test_standalone_openai_json_struct_func_playground.py b/tests/test_standalone_openai_json_struct_func_playground.py index f4554de6be..8dd687148d 100644 --- a/tests/test_standalone_openai_json_struct_func_playground.py +++ b/tests/test_standalone_openai_json_struct_func_playground.py @@ -12,61 +12,46 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header -import os - # ----- # # Milestone 1 - - # from typing import List, Dict, Optional # import requests # import json # from pydantic import BaseModel, Field # from openai import OpenAI, pydantic_function_tool - # # Environment setup # import dotenv # dotenv.load_dotenv() - # # Constants and prompts # MODEL = "gpt-4o-2024-08-06" # GENERAL_PROMPT = ''' # Follow the instructions. Output a step by step solution, along with a final answer. # Use the explanation field to detail the reasoning. # ''' - # # Initialize OpenAI client # client = OpenAI() - # # Models and functions # class Step(BaseModel): # explanation: str # output: str - # class MathReasoning(BaseModel): # steps: List[Step] # final_answer: str - # class GetWeather(BaseModel): # latitude: str = Field(..., description="Latitude e.g., Bogotá, Colombia") # longitude: str = Field(..., description="Longitude e.g., Bogotá, Colombia") - # def fetch_weather(latitude: str, longitude: str) -> Dict: # url = f"https://api.open-meteo.com/v1/forecast?latitude={latitude}&longitude={longitude}¤t=temperature_2m,wind_speed_10m&hourly=temperature_2m,relative_humidity_2m,wind_speed_10m&temperature_unit=fahrenheit" # response = requests.get(url) # return response.json().get('current', {}) - # # Tool management # def get_tools() -> List[BaseModel]: # return [pydantic_function_tool(GetWeather)] - # def handle_function_call(tool_call: Dict) -> Optional[str]: # if tool_call['name'] == "get_weather": # result = fetch_weather(**tool_call['args']) # return f"Temperature is {result['temperature_2m']}°F" # return None - # # Communication and processing with OpenAI # def process_message_with_openai(question: str) -> MathReasoning: # messages = [ @@ -80,11 +65,9 @@ # tools=get_tools() # ) # return response.choices[0].message - # def get_math_solution(question: str) -> MathReasoning: # solution = process_message_with_openai(question) # return solution - # # Example usage # def main(): # problems = [ @@ -93,32 +76,24 @@ # ] # problem = problems[1] # print(f"Problem: {problem}") - # solution = get_math_solution(problem) # if not solution: # print("Failed to get a solution.") # return - # if not solution.parsed: # print("Failed to get a parsed solution.") # print(f"Solution: {solution}") # return - # print(f"Steps: {solution.parsed.steps}") # print(f"Final Answer: {solution.parsed.final_answer}") - # if __name__ == "__main__": # main() - - # # Milestone 1 - # Milestone 2 import json -import os -import requests from dotenv import load_dotenv +import requests load_dotenv() diff --git a/tests/test_standalone_project_out.py b/tests/test_standalone_project_out.py index 22aec63bae..8fe99f0704 100644 --- a/tests/test_standalone_project_out.py +++ b/tests/test_standalone_project_out.py @@ -12,20 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header -import sys -import os - # ----- - import ast import inspect -import types import sys def extract_function_info(filename): - with open(filename, "r") as f: + with open(filename) as f: source = f.read() tree = ast.parse(source, filename=filename) diff --git a/tests/test_standalone_rxpy_01.py b/tests/test_standalone_rxpy_01.py index 733930d430..9be48f3eab 100644 --- a/tests/test_standalone_rxpy_01.py +++ b/tests/test_standalone_rxpy_01.py @@ -12,35 +12,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header -import os +import multiprocessing +from threading import Event # ----- - import reactivex from reactivex import operators as ops from reactivex.scheduler import ThreadPoolScheduler -import multiprocessing -from threading import Event which_test = 2 if which_test == 1: """ Test 1: Periodic Emission Test - This test creates a ThreadPoolScheduler that leverages as many threads as there are CPU + This test creates a ThreadPoolScheduler that leverages as many threads as there are CPU cores available, optimizing the execution across multiple threads. The core functionality - revolves around an observable, secondly_emission, which emits a value every second. - Each emission is an incrementing integer, which is then mapped to a message indicating - the number of seconds since the test began. The sequence is limited to 30 emissions, - each logged as it occurs, and accompanied by an additional message via the - emission_process function to indicate the value's emission. The test subscribes to the - observable to print each emitted value, handle any potential errors, and confirm + revolves around an observable, secondly_emission, which emits a value every second. + Each emission is an incrementing integer, which is then mapped to a message indicating + the number of seconds since the test began. The sequence is limited to 30 emissions, + each logged as it occurs, and accompanied by an additional message via the + emission_process function to indicate the value's emission. The test subscribes to the + observable to print each emitted value, handle any potential errors, and confirm completion of the emissions after 30 seconds. Key Components: • ThreadPoolScheduler: Manages concurrency with multiple threads. - • Observable Sequence: Emits every second, indicating progression with a specific + • Observable Sequence: Emits every second, indicating progression with a specific message format. • Subscription: Monitors and logs emissions, errors, and the completion event. """ @@ -73,14 +70,14 @@ def emission_process(value): In this test, a similar ThreadPoolScheduler setup is used to handle tasks across multiple CPU cores efficiently. This setup includes two observables. The first, secondly_emission, - emits an incrementing integer every second, indicating the passage of time. The second - observable, immediate_emission, emits a predefined sequence of characters (['a', 'b', - 'c', 'd', 'e']) repeatedly and immediately. These two streams are combined using the zip - operator, which synchronizes their emissions into pairs. Each combined pair is formatted - and logged, indicating both the time elapsed and the immediate value emitted at that + emits an incrementing integer every second, indicating the passage of time. The second + observable, immediate_emission, emits a predefined sequence of characters (['a', 'b', + 'c', 'd', 'e']) repeatedly and immediately. These two streams are combined using the zip + operator, which synchronizes their emissions into pairs. Each combined pair is formatted + and logged, indicating both the time elapsed and the immediate value emitted at that second. - A synchronization mechanism via an Event (completed_event) ensures that the main program + A synchronization mechanism via an Event (completed_event) ensures that the main program thread waits until all planned emissions are completed before exiting. This test not only checks the functionality of zipping different rhythmic emissions but also demonstrates handling of asynchronous task completion in Python using event-driven programming. @@ -88,9 +85,9 @@ def emission_process(value): Key Components: • Combined Observable Emissions: Synchronizes periodic and immediate emissions into a single stream. - • Event Synchronization: Uses a threading event to manage program lifecycle and + • Event Synchronization: Uses a threading event to manage program lifecycle and ensure that all emissions are processed before shutdown. - • Complex Subscription Management: Handles errors and completion, including + • Complex Subscription Management: Handles errors and completion, including setting an event to signal the end of task processing. """ diff --git a/tests/test_unitree_agent.py b/tests/test_unitree_agent.py index 34c5aa335d..5c4b6acb7b 100644 --- a/tests/test_unitree_agent.py +++ b/tests/test_unitree_agent.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header import os import time @@ -304,7 +303,7 @@ def stop(self): elif test_to_run == 4: myUnitreeAgentDemo.run_with_queries_and_fast_api() elif test_to_run < 0 or test_to_run >= 5: - assert False, f"Invalid test number: {test_to_run}" + raise AssertionError(f"Invalid test number: {test_to_run}") # Keep the program running to allow the Unitree Agent Demo to operate continuously try: diff --git a/tests/test_unitree_agent_queries_fastapi.py b/tests/test_unitree_agent_queries_fastapi.py index be95ea5de6..0671a53135 100644 --- a/tests/test_unitree_agent_queries_fastapi.py +++ b/tests/test_unitree_agent_queries_fastapi.py @@ -23,9 +23,9 @@ ROS_OUTPUT_DIR: Optional. Directory for ROS output files. """ -import tests.test_header import os import sys + import reactivex as rx import reactivex.operators as ops @@ -34,7 +34,6 @@ from dimos.robot.unitree.unitree_go2 import UnitreeGo2 from dimos.robot.unitree.unitree_skills import MyUnitreeSkills from dimos.utils.logging_config import logger -from dimos.web.robot_web_interface import RobotWebInterface from dimos.web.fastapi_server import FastAPIServer diff --git a/tests/test_unitree_ros_v0.0.4.py b/tests/test_unitree_ros_v0.0.4.py index e4086074cc..efb39be2bf 100644 --- a/tests/test_unitree_ros_v0.0.4.py +++ b/tests/test_unitree_ros_v0.0.4.py @@ -12,30 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header import os -import time from dotenv import load_dotenv +import reactivex as rx +import reactivex.operators as ops + from dimos.agents.claude_agent import ClaudeAgent +from dimos.perception.detection2d.detic_2d_det import Detic2DDetector +from dimos.perception.object_detection_stream import ObjectDetectionStream from dimos.robot.unitree.unitree_go2 import UnitreeGo2 -from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl from dimos.robot.unitree.unitree_skills import MyUnitreeSkills -from dimos.web.robot_web_interface import RobotWebInterface -from dimos.skills.observe_stream import ObserveStream from dimos.skills.kill_skill import KillSkill -from dimos.skills.navigation import NavigateWithText, GetPose, NavigateToGoal -from dimos.skills.visual_navigation_skills import FollowHuman -import reactivex as rx -import reactivex.operators as ops -from dimos.stream.audio.pipelines import tts, stt -import threading -import json -from dimos.types.vector import Vector +from dimos.skills.navigation import GetPose, NavigateWithText +from dimos.skills.observe_stream import ObserveStream from dimos.skills.speak import Speak -from dimos.perception.object_detection_stream import ObjectDetectionStream -from dimos.perception.detection2d.detic_2d_det import Detic2DDetector +from dimos.skills.visual_navigation_skills import FollowHuman +from dimos.stream.audio.pipelines import stt, tts from dimos.utils.reactive import backpressure +from dimos.web.robot_web_interface import RobotWebInterface # Load API key from environment load_dotenv() @@ -142,7 +137,7 @@ def combine_with_locations(object_detections): # Read system query from prompt.txt file with open( - os.path.join(os.path.dirname(os.path.dirname(__file__)), "assets", "agent", "prompt.txt"), "r" + os.path.join(os.path.dirname(os.path.dirname(__file__)), "assets", "agent", "prompt.txt") ) as f: system_query = f.read() diff --git a/tests/test_webrtc_queue.py b/tests/test_webrtc_queue.py index 11408df145..5e09ec1f9d 100644 --- a/tests/test_webrtc_queue.py +++ b/tests/test_webrtc_queue.py @@ -14,11 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tests.test_header - +import os import time + from dimos.robot.unitree.unitree_go2 import UnitreeGo2, WebRTCConnectionMethod -import os from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl diff --git a/tests/test_websocketvis.py b/tests/test_websocketvis.py index a400bd9d14..262555ce50 100644 --- a/tests/test_websocketvis.py +++ b/tests/test_websocketvis.py @@ -12,22 +12,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +import argparse import math import os -import time +import pickle import threading +import time + +from reactivex import operators as ops + +from dimos.robot.global_planner.planner import AstarPlanner from dimos.robot.unitree.unitree_go2 import UnitreeGo2 from dimos.robot.unitree.unitree_ros_control import UnitreeROSControl -from dimos.web.websocket_vis.server import WebsocketVis -from dimos.web.websocket_vis.helpers import vector_stream -from dimos.robot.global_planner.planner import AstarPlanner from dimos.types.costmap import Costmap from dimos.types.vector import Vector -from reactivex import operators as ops -import argparse -import pickle -import reactivex as rx from dimos.web.robot_web_interface import RobotWebInterface +from dimos.web.websocket_vis.helpers import vector_stream +from dimos.web.websocket_vis.server import WebsocketVis def parse_args(): diff --git a/tests/test_zed_module.py b/tests/test_zed_module.py index a8c5691b59..03a21ac65d 100644 --- a/tests/test_zed_module.py +++ b/tests/test_zed_module.py @@ -18,21 +18,20 @@ import asyncio import threading import time -from typing import Optional -import numpy as np + import cv2 +from dimos_lcm.geometry_msgs import PoseStamped + +# Import LCM message types +from dimos_lcm.sensor_msgs import CameraInfo, Image as LCMImage +import numpy as np from dimos import core from dimos.hardware.zed_camera import ZEDModule -from dimos.protocol import pubsub -from dimos.utils.logging_config import setup_logger from dimos.perception.common.utils import colorize_depth - -# Import LCM message types -from dimos_lcm.sensor_msgs import Image as LCMImage -from dimos_lcm.sensor_msgs import CameraInfo -from dimos_lcm.geometry_msgs import PoseStamped +from dimos.protocol import pubsub from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.utils.logging_config import setup_logger logger = setup_logger("test_zed_module") diff --git a/tests/test_zed_setup.py b/tests/test_zed_setup.py index ca50bb63fb..33aefb65eb 100755 --- a/tests/test_zed_setup.py +++ b/tests/test_zed_setup.py @@ -17,8 +17,8 @@ Simple test script to verify ZED camera setup and basic functionality. """ -import sys from pathlib import Path +import sys def test_imports(): @@ -108,11 +108,12 @@ def test_basic_functionality(): try: import pyzed.sl as sl + from dimos.hardware.zed_camera import ZEDCamera from dimos.perception.zed_visualizer import ZEDVisualizer # Test camera initialization (without opening) - camera = ZEDCamera( + ZEDCamera( camera_id=0, resolution=sl.RESOLUTION.HD720, depth_mode=sl.DEPTH_MODE.NEURAL, @@ -127,7 +128,7 @@ def test_basic_functionality(): dummy_rgb = np.zeros((480, 640, 3), dtype=np.uint8) dummy_depth = np.ones((480, 640), dtype=np.float32) * 2.0 - vis = visualizer.create_side_by_side_image(dummy_rgb, dummy_depth) + visualizer.create_side_by_side_image(dummy_rgb, dummy_depth) print("✓ Dummy visualization created successfully") return True diff --git a/tests/visualization_script.py b/tests/visualization_script.py index d0c4c6af84..a42b4bf06c 100644 --- a/tests/visualization_script.py +++ b/tests/visualization_script.py @@ -16,11 +16,11 @@ """Visualize pickled manipulation pipeline results.""" import os -import sys import pickle -import numpy as np -import json +import sys + import matplotlib +import numpy as np # Try to use TkAgg backend for live display, fallback to Agg if not available try: @@ -35,77 +35,42 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from dimos.perception.pointcloud.utils import visualize_clustered_point_clouds, visualize_voxel_grid -from dimos.perception.grasp_generation.utils import visualize_grasps_3d -from dimos.perception.pointcloud.utils import visualize_pcd -from dimos.utils.logging_config import setup_logger -import trimesh - -import tf_lcm_py -import cv2 -from contextlib import contextmanager -import lcm_msgs -from lcm_msgs.sensor_msgs import JointState, PointCloud2, CameraInfo, PointCloud2, PointField -from lcm_msgs.std_msgs import Header -from typing import List, Tuple, Optional import atexit +from contextlib import contextmanager from datetime import datetime import time +import lcm_msgs from pydrake.all import ( AddMultibodyPlantSceneGraph, - CoulombFriction, - Diagram, DiagramBuilder, - InverseKinematics, + JointIndex, MeshcatVisualizer, MeshcatVisualizerParams, - MultibodyPlant, Parser, RigidTransform, RollPitchYaw, RotationMatrix, - JointIndex, - Solve, StartMeshcat, ) +from pydrake.common import MemoryFile from pydrake.geometry import ( + Box, CollisionFilterDeclaration, + InMemoryMesh, Mesh, ProximityProperties, - InMemoryMesh, - Box, - Cylinder, ) from pydrake.math import RigidTransform as DrakeRigidTransform -from pydrake.common import MemoryFile +import tf_lcm_py +import trimesh -from pydrake.all import ( - MinimumDistanceLowerBoundConstraint, - MultibodyPlant, - Parser, - DiagramBuilder, - AddMultibodyPlantSceneGraph, - MeshcatVisualizer, - StartMeshcat, - RigidTransform, - Role, - RollPitchYaw, - RotationMatrix, - Solve, - InverseKinematics, - MeshcatVisualizerParams, - MinimumDistanceLowerBoundConstraint, - DoDifferentialInverseKinematics, - DifferentialInverseKinematicsStatus, - DifferentialInverseKinematicsParameters, - DepthImageToPointCloud, -) -from manipulation.scenarios import AddMultibodyTriad -from manipulation.meshcat_utils import ( # TODO(russt): switch to pydrake version - _MeshcatPoseSliders, +from dimos.perception.pointcloud.utils import ( + visualize_clustered_point_clouds, + visualize_pcd, + visualize_voxel_grid, ) -from manipulation.scenarios import AddIiwa, AddShape, AddWsg +from dimos.utils.logging_config import setup_logger logger = setup_logger("visualization_script") @@ -132,9 +97,9 @@ def deserialize_point_cloud(data): return None pcd = o3d.geometry.PointCloud() - if "points" in data and data["points"]: + if data.get("points"): pcd.points = o3d.utility.Vector3dVector(np.array(data["points"])) - if "colors" in data and data["colors"]: + if data.get("colors"): pcd.colors = o3d.utility.Vector3dVector(np.array(data["colors"])) return pcd @@ -177,9 +142,9 @@ def visualize_results(pickle_path="manipulation_results.pkl"): data = pickle.load(f) results = data["results"] - color_img = data["color_img"] - depth_img = data["depth_img"] - intrinsics = data["intrinsics"] + data["color_img"] + data["depth_img"] + data["intrinsics"] print(f"Loaded results with keys: {list(results.keys())}") @@ -229,7 +194,7 @@ def visualize_results(pickle_path="manipulation_results.pkl"): else: rows = 2 cols = (num_plots + 1) // 2 - fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows)) + _fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 5 * rows)) # Ensure axes is always a list for consistent indexing if num_plots == 1: @@ -284,7 +249,7 @@ def visualize_results(pickle_path="manipulation_results.pkl"): print("No full point cloud available for visualization") # Reconstruct misc clusters if available - if "misc_clusters" in results and results["misc_clusters"]: + if results.get("misc_clusters"): misc_clusters = [deserialize_point_cloud(cluster) for cluster in results["misc_clusters"]] cluster_count = len(misc_clusters) total_misc_points = sum(len(np.asarray(cluster.points)) for cluster in misc_clusters) @@ -333,8 +298,8 @@ class DrakeKinematicsEnv: def __init__( self, urdf_path: str, - kinematic_chain_joints: List[str], - links_to_ignore: Optional[List[str]] = None, + kinematic_chain_joints: list[str], + links_to_ignore: list[str] | None = None, ): self._resources_to_cleanup = [] @@ -544,7 +509,7 @@ def process_and_add_object_class(self, object_key: str, results: dict): def _create_clustered_convex_hulls( self, points: np.ndarray, object_id: int - ) -> List[o3d.geometry.TriangleMesh]: + ) -> list[o3d.geometry.TriangleMesh]: """ Create convex hulls from DBSCAN clusters of point cloud data. Fast approach: cluster points, then convex hull each cluster. @@ -579,7 +544,7 @@ def _create_clustered_convex_hulls( # Compute nearest neighbor distances for better eps estimation distances = pcd.compute_nearest_neighbor_distance() avg_nn_distance = np.mean(distances) - std_nn_distance = np.std(distances) + np.std(distances) print( f"Object {object_id}: {len(points)} points, avg_nn_dist={avg_nn_distance:.4f}" @@ -740,7 +705,7 @@ def set_joint_positions(self, joint_positions): print(f"Updated joint positions: {joint_positions}") def register_convex_hulls_as_collision( - self, meshes: List[o3d.geometry.TriangleMesh], hull_type: str + self, meshes: list[o3d.geometry.TriangleMesh], hull_type: str ): """Register convex hulls as collision and visual geometry""" if not meshes: @@ -864,7 +829,7 @@ def get_transform(self, target_frame, source_frame): timestamp = self.buffer.get_most_recent_timestamp() if attempts % 10 == 0: print(f"Using timestamp from buffer: {timestamp}") - except Exception as e: + except Exception: # Fall back to current time if get_most_recent_timestamp fails timestamp = datetime.now() if not hasattr(timestamp, "timestamp"): diff --git a/tests/zed_neural_depth_demo.py b/tests/zed_neural_depth_demo.py index 5edce9633f..86daf4107d 100755 --- a/tests/zed_neural_depth_demo.py +++ b/tests/zed_neural_depth_demo.py @@ -21,17 +21,17 @@ Press ESC or 'q' to quit. """ -import os -import sys -import time import argparse +from datetime import datetime import logging from pathlib import Path -import numpy as np +import sys +import time + import cv2 -import yaml -from datetime import datetime +import numpy as np import open3d as o3d +import yaml # Add the project root to Python path sys.path.append(str(Path(__file__).parent.parent)) @@ -44,7 +44,7 @@ sys.exit(1) from dimos.hardware.zed_camera import ZEDCamera -from dimos.perception.pointcloud.utils import visualize_pcd, visualize_clustered_point_clouds +from dimos.perception.pointcloud.utils import visualize_pcd # Configure logging logging.basicConfig( @@ -227,7 +227,7 @@ def visualize_captured_pointclouds(self): def update_display(self): """Update the live display with new frames.""" # Capture frame - left_img, right_img, depth_map = self.camera.capture_frame() + left_img, _right_img, depth_map = self.camera.capture_frame() if left_img is None or depth_map is None: return False, None, None From ae01fbbeac9bd04e780b27efe0a1df0b6142f851 Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Wed, 29 Oct 2025 01:37:21 +0200 Subject: [PATCH 39/40] fix missing imports --- dimos/agents2/agent.py | 2 +- dimos/skills/skills.py | 6 +++++- dimos/types/ros_polyfill.py | 2 +- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/dimos/agents2/agent.py b/dimos/agents2/agent.py index 7980cdee7d..04c08b0434 100644 --- a/dimos/agents2/agent.py +++ b/dimos/agents2/agent.py @@ -30,7 +30,7 @@ from dimos.agents2.spec import AgentSpec, Model, Provider from dimos.agents2.system_prompt import get_system_prompt -from dimos.core import rpc +from dimos.core import DimosCluster, rpc from dimos.protocol.skill.coordinator import ( SkillContainer, SkillCoordinator, diff --git a/dimos/skills/skills.py b/dimos/skills/skills.py index 27d07bf7fe..196fcf07b5 100644 --- a/dimos/skills/skills.py +++ b/dimos/skills/skills.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -from collections.abc import Iterator +from __future__ import annotations + import logging from typing import TYPE_CHECKING, Any @@ -21,6 +22,9 @@ from dimos.types.constants import Colors +if TYPE_CHECKING: + from collections.abc import Iterator + # Configure logging for the module logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) diff --git a/dimos/types/ros_polyfill.py b/dimos/types/ros_polyfill.py index 8e8c1e76be..c8919caec3 100644 --- a/dimos/types/ros_polyfill.py +++ b/dimos/types/ros_polyfill.py @@ -15,7 +15,7 @@ try: from geometry_msgs.msg import Vector3 except ImportError: - pass # type: ignore[import] + from dimos.msgs.geometry_msgs import Vector3 # type: ignore[import] try: from geometry_msgs.msg import Point, Pose, Quaternion, Twist From e0e631eb93ee47c1d2742c84f8e46ac973146808 Mon Sep 17 00:00:00 2001 From: lesh Date: Wed, 29 Oct 2025 17:45:16 +0100 Subject: [PATCH 40/40] detic ruff undo --- dimos/models/Detic/detic/modeling/roi_heads/detic_fast_rcnn.py | 1 - .../Detic/third_party/CenterNet2/centernet/modeling/debug.py | 2 -- .../CenterNet2/centernet/modeling/layers/heatmap_focal_loss.py | 2 -- dimos/models/Detic/third_party/Deformable-DETR/engine.py | 2 +- .../models/Detic/third_party/Deformable-DETR/util/plot_utils.py | 2 +- dimos/models/Detic/tools/preprocess_imagenet22k.py | 2 +- 6 files changed, 3 insertions(+), 8 deletions(-) diff --git a/dimos/models/Detic/detic/modeling/roi_heads/detic_fast_rcnn.py b/dimos/models/Detic/detic/modeling/roi_heads/detic_fast_rcnn.py index 5b54974a9e..aaa7ca233e 100644 --- a/dimos/models/Detic/detic/modeling/roi_heads/detic_fast_rcnn.py +++ b/dimos/models/Detic/detic/modeling/roi_heads/detic_fast_rcnn.py @@ -1,5 +1,4 @@ # Copyright (c) Facebook, Inc. and its affiliates. -from collections.abc import Sequence import math from detectron2.config import configurable diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/debug.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/debug.py index a9f9b4dc73..63186b05c5 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/debug.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/debug.py @@ -1,5 +1,3 @@ -from collections.abc import Sequence - import cv2 import numpy as np import torch diff --git a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/heatmap_focal_loss.py b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/heatmap_focal_loss.py index c9c5c9a2b6..50ccf371c9 100644 --- a/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/heatmap_focal_loss.py +++ b/dimos/models/Detic/third_party/CenterNet2/centernet/modeling/layers/heatmap_focal_loss.py @@ -1,5 +1,3 @@ -from collections.abc import Sequence - import torch from typing import Sequence diff --git a/dimos/models/Detic/third_party/Deformable-DETR/engine.py b/dimos/models/Detic/third_party/Deformable-DETR/engine.py index 9cee2a089b..7e6e7c2c20 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/engine.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/engine.py @@ -11,10 +11,10 @@ Train and eval functions used in main.py """ -from collections.abc import Iterable import math import os import sys +from typing import Iterable from datasets.coco_eval import CocoEvaluator from datasets.data_prefetcher import data_prefetcher diff --git a/dimos/models/Detic/third_party/Deformable-DETR/util/plot_utils.py b/dimos/models/Detic/third_party/Deformable-DETR/util/plot_utils.py index 710420f410..0af3b9e5e6 100644 --- a/dimos/models/Detic/third_party/Deformable-DETR/util/plot_utils.py +++ b/dimos/models/Detic/third_party/Deformable-DETR/util/plot_utils.py @@ -62,7 +62,7 @@ def plot_logs( # load log file(s) and plot dfs = [pd.read_json(Path(p) / log_name, lines=True) for p in logs] - _fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) + fig, axs = plt.subplots(ncols=len(fields), figsize=(16, 5)) for df, color in zip(dfs, sns.color_palette(n_colors=len(logs)), strict=False): for j, field in enumerate(fields): diff --git a/dimos/models/Detic/tools/preprocess_imagenet22k.py b/dimos/models/Detic/tools/preprocess_imagenet22k.py index c5a5ad0d31..edf2d2bbf7 100644 --- a/dimos/models/Detic/tools/preprocess_imagenet22k.py +++ b/dimos/models/Detic/tools/preprocess_imagenet22k.py @@ -23,7 +23,7 @@ def __init__(self, filename, indexname: str, preload: bool=False) -> None: for l in open(indexname): ll = l.split() - _a, b, c = ll[:3] + a, b, c = ll[:3] offset = int(b[:-1]) if l.endswith("** Block of NULs **\n"): self.offsets.append(offset)