diff --git a/dimos/msgs/geometry_msgs/Pose.py b/dimos/msgs/geometry_msgs/Pose.py index 0706a144f6..55ecaeca9e 100644 --- a/dimos/msgs/geometry_msgs/Pose.py +++ b/dimos/msgs/geometry_msgs/Pose.py @@ -14,13 +14,13 @@ from __future__ import annotations -import struct -import traceback -from io import BytesIO -from typing import BinaryIO, TypeAlias +from typing import TypeAlias from dimos_lcm.geometry_msgs import Pose as LCMPose from dimos_lcm.geometry_msgs import Transform as LCMTransform +from geometry_msgs.msg import Pose as ROSPose +from geometry_msgs.msg import Point as ROSPoint +from geometry_msgs.msg import Quaternion as ROSQuaternion from plum import dispatch from dimos.msgs.geometry_msgs.Quaternion import Quaternion, QuaternionConvertable @@ -207,6 +207,43 @@ def __add__(self, other: "Pose" | PoseConvertable | LCMTransform | Transform) -> return Pose(new_position, new_orientation) + @classmethod + def from_ros_msg(cls, ros_msg: ROSPose) -> "Pose": + """Create a Pose from a ROS geometry_msgs/Pose message. + + Args: + ros_msg: ROS Pose message + + Returns: + Pose instance + """ + position = Vector3(ros_msg.position.x, ros_msg.position.y, ros_msg.position.z) + orientation = Quaternion( + ros_msg.orientation.x, + ros_msg.orientation.y, + ros_msg.orientation.z, + ros_msg.orientation.w, + ) + return cls(position, orientation) + + def to_ros_msg(self) -> ROSPose: + """Convert to a ROS geometry_msgs/Pose message. + + Returns: + ROS Pose message + """ + ros_msg = ROSPose() + ros_msg.position = ROSPoint( + x=float(self.position.x), y=float(self.position.y), z=float(self.position.z) + ) + ros_msg.orientation = ROSQuaternion( + x=float(self.orientation.x), + y=float(self.orientation.y), + z=float(self.orientation.z), + w=float(self.orientation.w), + ) + return ros_msg + @dispatch def to_pose(value: "Pose") -> "Pose": diff --git a/dimos/msgs/geometry_msgs/PoseStamped.py b/dimos/msgs/geometry_msgs/PoseStamped.py index ea1198818d..2927247d89 100644 --- a/dimos/msgs/geometry_msgs/PoseStamped.py +++ b/dimos/msgs/geometry_msgs/PoseStamped.py @@ -22,6 +22,7 @@ from dimos_lcm.geometry_msgs import PoseStamped as LCMPoseStamped from dimos_lcm.std_msgs import Header as LCMHeader from dimos_lcm.std_msgs import Time as LCMTime +from geometry_msgs.msg import PoseStamped as ROSPoseStamped from plum import dispatch from dimos.msgs.geometry_msgs.Pose import Pose @@ -109,3 +110,44 @@ def find_transform(self, other: PoseStamped) -> Transform: translation=local_translation, rotation=relative_rotation, ) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSPoseStamped) -> "PoseStamped": + """Create a PoseStamped from a ROS geometry_msgs/PoseStamped message. + + Args: + ros_msg: ROS PoseStamped message + + Returns: + PoseStamped instance + """ + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert pose + pose = Pose.from_ros_msg(ros_msg.pose) + + return cls( + ts=ts, + frame_id=ros_msg.header.frame_id, + position=pose.position, + orientation=pose.orientation, + ) + + def to_ros_msg(self) -> ROSPoseStamped: + """Convert to a ROS geometry_msgs/PoseStamped message. + + Returns: + ROS PoseStamped message + """ + ros_msg = ROSPoseStamped() + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set pose + ros_msg.pose = Pose.to_ros_msg(self) + + return ros_msg diff --git a/dimos/msgs/geometry_msgs/PoseWithCovariance.py b/dimos/msgs/geometry_msgs/PoseWithCovariance.py new file mode 100644 index 0000000000..e4c93cede9 --- /dev/null +++ b/dimos/msgs/geometry_msgs/PoseWithCovariance.py @@ -0,0 +1,219 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TypeAlias + +import numpy as np +from dimos_lcm.geometry_msgs import PoseWithCovariance as LCMPoseWithCovariance +from geometry_msgs.msg import PoseWithCovariance as ROSPoseWithCovariance +from plum import dispatch + +from dimos.msgs.geometry_msgs.Pose import Pose, PoseConvertable +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + +# Types that can be converted to/from PoseWithCovariance +PoseWithCovarianceConvertable: TypeAlias = ( + tuple[PoseConvertable, list[float] | np.ndarray] + | LCMPoseWithCovariance + | dict[str, PoseConvertable | list[float] | np.ndarray] +) + + +class PoseWithCovariance(LCMPoseWithCovariance): + pose: Pose + msg_name = "geometry_msgs.PoseWithCovariance" + + @dispatch + def __init__(self) -> None: + """Initialize with default pose and zero covariance.""" + self.pose = Pose() + self.covariance = np.zeros(36) + + @dispatch + def __init__( + self, pose: Pose | PoseConvertable, covariance: list[float] | np.ndarray | None = None + ) -> None: + """Initialize with pose and optional covariance.""" + self.pose = Pose(pose) if not isinstance(pose, Pose) else pose + if covariance is None: + self.covariance = np.zeros(36) + else: + self.covariance = np.array(covariance, dtype=float).reshape(36) + + @dispatch + def __init__(self, pose_with_cov: PoseWithCovariance) -> None: + """Initialize from another PoseWithCovariance (copy constructor).""" + self.pose = Pose(pose_with_cov.pose) + self.covariance = np.array(pose_with_cov.covariance).copy() + + @dispatch + def __init__(self, lcm_pose_with_cov: LCMPoseWithCovariance) -> None: + """Initialize from an LCM PoseWithCovariance.""" + self.pose = Pose(lcm_pose_with_cov.pose) + self.covariance = np.array(lcm_pose_with_cov.covariance) + + @dispatch + def __init__(self, pose_dict: dict[str, PoseConvertable | list[float] | np.ndarray]) -> None: + """Initialize from a dictionary with 'pose' and 'covariance' keys.""" + self.pose = Pose(pose_dict["pose"]) + covariance = pose_dict.get("covariance") + if covariance is None: + self.covariance = np.zeros(36) + else: + self.covariance = np.array(covariance, dtype=float).reshape(36) + + @dispatch + def __init__(self, pose_tuple: tuple[PoseConvertable, list[float] | np.ndarray]) -> None: + """Initialize from a tuple of (pose, covariance).""" + self.pose = Pose(pose_tuple[0]) + self.covariance = np.array(pose_tuple[1], dtype=float).reshape(36) + + def __getattribute__(self, name): + """Override to ensure covariance is always returned as numpy array.""" + if name == "covariance": + cov = object.__getattribute__(self, "covariance") + if not isinstance(cov, np.ndarray): + return np.array(cov, dtype=float) + return cov + return super().__getattribute__(name) + + def __setattr__(self, name, value): + """Override to ensure covariance is stored as numpy array.""" + if name == "covariance": + if not isinstance(value, np.ndarray): + value = np.array(value, dtype=float).reshape(36) + super().__setattr__(name, value) + + @property + def x(self) -> float: + """X coordinate of position.""" + return self.pose.x + + @property + def y(self) -> float: + """Y coordinate of position.""" + return self.pose.y + + @property + def z(self) -> float: + """Z coordinate of position.""" + return self.pose.z + + @property + def position(self) -> Vector3: + """Position vector.""" + return self.pose.position + + @property + def orientation(self) -> Quaternion: + """Orientation quaternion.""" + return self.pose.orientation + + @property + def roll(self) -> float: + """Roll angle in radians.""" + return self.pose.roll + + @property + def pitch(self) -> float: + """Pitch angle in radians.""" + return self.pose.pitch + + @property + def yaw(self) -> float: + """Yaw angle in radians.""" + return self.pose.yaw + + @property + def covariance_matrix(self) -> np.ndarray: + """Get covariance as 6x6 matrix.""" + return self.covariance.reshape(6, 6) + + @covariance_matrix.setter + def covariance_matrix(self, value: np.ndarray) -> None: + """Set covariance from 6x6 matrix.""" + self.covariance = np.array(value).reshape(36) + + def __repr__(self) -> str: + return f"PoseWithCovariance(pose={self.pose!r}, covariance=<{self.covariance.shape[0] if isinstance(self.covariance, np.ndarray) else len(self.covariance)} elements>)" + + def __str__(self) -> str: + return ( + f"PoseWithCovariance(pos=[{self.x:.3f}, {self.y:.3f}, {self.z:.3f}], " + f"euler=[{self.roll:.3f}, {self.pitch:.3f}, {self.yaw:.3f}], " + f"cov_trace={np.trace(self.covariance_matrix):.3f})" + ) + + def __eq__(self, other) -> bool: + """Check if two PoseWithCovariance are equal.""" + if not isinstance(other, PoseWithCovariance): + return False + return self.pose == other.pose and np.allclose(self.covariance, other.covariance) + + def lcm_encode(self) -> bytes: + """Encode to LCM binary format.""" + lcm_msg = LCMPoseWithCovariance() + lcm_msg.pose = self.pose + # LCM expects list, not numpy array + if isinstance(self.covariance, np.ndarray): + lcm_msg.covariance = self.covariance.tolist() + else: + lcm_msg.covariance = list(self.covariance) + return lcm_msg.lcm_encode() + + @classmethod + def lcm_decode(cls, data: bytes) -> "PoseWithCovariance": + """Decode from LCM binary format.""" + lcm_msg = LCMPoseWithCovariance.lcm_decode(data) + pose = Pose( + position=[lcm_msg.pose.position.x, lcm_msg.pose.position.y, lcm_msg.pose.position.z], + orientation=[ + lcm_msg.pose.orientation.x, + lcm_msg.pose.orientation.y, + lcm_msg.pose.orientation.z, + lcm_msg.pose.orientation.w, + ], + ) + return cls(pose, lcm_msg.covariance) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSPoseWithCovariance) -> "PoseWithCovariance": + """Create a PoseWithCovariance from a ROS geometry_msgs/PoseWithCovariance message. + + Args: + ros_msg: ROS PoseWithCovariance message + + Returns: + PoseWithCovariance instance + """ + pose = Pose.from_ros_msg(ros_msg.pose) + return cls(pose, list(ros_msg.covariance)) + + def to_ros_msg(self) -> ROSPoseWithCovariance: + """Convert to a ROS geometry_msgs/PoseWithCovariance message. + + Returns: + ROS PoseWithCovariance message + """ + ros_msg = ROSPoseWithCovariance() + ros_msg.pose = self.pose.to_ros_msg() + # ROS expects list, not numpy array + if isinstance(self.covariance, np.ndarray): + ros_msg.covariance = self.covariance.tolist() + else: + ros_msg.covariance = list(self.covariance) + return ros_msg diff --git a/dimos/msgs/geometry_msgs/PoseWithCovarianceStamped.py b/dimos/msgs/geometry_msgs/PoseWithCovarianceStamped.py new file mode 100644 index 0000000000..9f48d8e2dc --- /dev/null +++ b/dimos/msgs/geometry_msgs/PoseWithCovarianceStamped.py @@ -0,0 +1,155 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import TypeAlias + +import numpy as np +from dimos_lcm.geometry_msgs import PoseWithCovarianceStamped as LCMPoseWithCovarianceStamped +from geometry_msgs.msg import PoseWithCovarianceStamped as ROSPoseWithCovarianceStamped +from plum import dispatch + +from dimos.msgs.geometry_msgs.Pose import Pose, PoseConvertable +from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance +from dimos.types.timestamped import Timestamped + +# Types that can be converted to/from PoseWithCovarianceStamped +PoseWithCovarianceStampedConvertable: TypeAlias = ( + tuple[PoseConvertable, list[float] | np.ndarray] + | LCMPoseWithCovarianceStamped + | dict[str, PoseConvertable | list[float] | np.ndarray | float | str] +) + + +def sec_nsec(ts): + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class PoseWithCovarianceStamped(PoseWithCovariance, Timestamped): + msg_name = "geometry_msgs.PoseWithCovarianceStamped" + ts: float + frame_id: str + + @dispatch + def __init__(self, ts: float = 0.0, frame_id: str = "", **kwargs) -> None: + """Initialize with timestamp and frame_id.""" + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + super().__init__(**kwargs) + + @dispatch + def __init__( + self, + ts: float = 0.0, + frame_id: str = "", + pose: Pose | PoseConvertable | None = None, + covariance: list[float] | np.ndarray | None = None, + ) -> None: + """Initialize with timestamp, frame_id, pose and covariance.""" + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + if pose is None: + super().__init__() + else: + super().__init__(pose, covariance) + + def lcm_encode(self) -> bytes: + lcm_msg = LCMPoseWithCovarianceStamped() + lcm_msg.pose.pose = self.pose + # LCM expects list, not numpy array + if isinstance(self.covariance, np.ndarray): + lcm_msg.pose.covariance = self.covariance.tolist() + else: + lcm_msg.pose.covariance = list(self.covariance) + [lcm_msg.header.stamp.sec, lcm_msg.header.stamp.nsec] = sec_nsec(self.ts) + lcm_msg.header.frame_id = self.frame_id + return lcm_msg.lcm_encode() + + @classmethod + def lcm_decode(cls, data: bytes) -> PoseWithCovarianceStamped: + lcm_msg = LCMPoseWithCovarianceStamped.lcm_decode(data) + return cls( + ts=lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000), + frame_id=lcm_msg.header.frame_id, + pose=Pose( + position=[ + lcm_msg.pose.pose.position.x, + lcm_msg.pose.pose.position.y, + lcm_msg.pose.pose.position.z, + ], + orientation=[ + lcm_msg.pose.pose.orientation.x, + lcm_msg.pose.pose.orientation.y, + lcm_msg.pose.pose.orientation.z, + lcm_msg.pose.pose.orientation.w, + ], + ), + covariance=lcm_msg.pose.covariance, + ) + + def __str__(self) -> str: + return ( + f"PoseWithCovarianceStamped(pos=[{self.x:.3f}, {self.y:.3f}, {self.z:.3f}], " + f"euler=[{self.roll:.3f}, {self.pitch:.3f}, {self.yaw:.3f}], " + f"cov_trace={np.trace(self.covariance_matrix):.3f})" + ) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSPoseWithCovarianceStamped) -> "PoseWithCovarianceStamped": + """Create a PoseWithCovarianceStamped from a ROS geometry_msgs/PoseWithCovarianceStamped message. + + Args: + ros_msg: ROS PoseWithCovarianceStamped message + + Returns: + PoseWithCovarianceStamped instance + """ + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert pose with covariance + pose_with_cov = PoseWithCovariance.from_ros_msg(ros_msg.pose) + + return cls( + ts=ts, + frame_id=ros_msg.header.frame_id, + pose=pose_with_cov.pose, + covariance=pose_with_cov.covariance, + ) + + def to_ros_msg(self) -> ROSPoseWithCovarianceStamped: + """Convert to a ROS geometry_msgs/PoseWithCovarianceStamped message. + + Returns: + ROS PoseWithCovarianceStamped message + """ + ros_msg = ROSPoseWithCovarianceStamped() + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set pose with covariance + ros_msg.pose.pose = self.pose.to_ros_msg() + # ROS expects list, not numpy array + if isinstance(self.covariance, np.ndarray): + ros_msg.pose.covariance = self.covariance.tolist() + else: + ros_msg.pose.covariance = list(self.covariance) + + return ros_msg diff --git a/dimos/msgs/geometry_msgs/Transform.py b/dimos/msgs/geometry_msgs/Transform.py index a47c58337c..1f6121f6cf 100644 --- a/dimos/msgs/geometry_msgs/Transform.py +++ b/dimos/msgs/geometry_msgs/Transform.py @@ -19,6 +19,10 @@ from dimos_lcm.geometry_msgs import Transform as LCMTransform from dimos_lcm.geometry_msgs import TransformStamped as LCMTransformStamped +from geometry_msgs.msg import TransformStamped as ROSTransformStamped +from geometry_msgs.msg import Transform as ROSTransform +from geometry_msgs.msg import Vector3 as ROSVector3 +from geometry_msgs.msg import Quaternion as ROSQuaternion from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3 @@ -137,6 +141,68 @@ def inverse(self) -> "Transform": ts=self.ts, ) + @classmethod + def from_ros_transform_stamped(cls, ros_msg: ROSTransformStamped) -> "Transform": + """Create a Transform from a ROS geometry_msgs/TransformStamped message. + + Args: + ros_msg: ROS TransformStamped message + + Returns: + Transform instance + """ + # Convert timestamp + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert translation + translation = Vector3( + ros_msg.transform.translation.x, + ros_msg.transform.translation.y, + ros_msg.transform.translation.z, + ) + + # Convert rotation + rotation = Quaternion( + ros_msg.transform.rotation.x, + ros_msg.transform.rotation.y, + ros_msg.transform.rotation.z, + ros_msg.transform.rotation.w, + ) + + return cls( + translation=translation, + rotation=rotation, + frame_id=ros_msg.header.frame_id, + child_frame_id=ros_msg.child_frame_id, + ts=ts, + ) + + def to_ros_transform_stamped(self) -> ROSTransformStamped: + """Convert to a ROS geometry_msgs/TransformStamped message. + + Returns: + ROS TransformStamped message + """ + ros_msg = ROSTransformStamped() + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set child frame + ros_msg.child_frame_id = self.child_frame_id + + # Set transform + ros_msg.transform.translation = ROSVector3( + x=self.translation.x, y=self.translation.y, z=self.translation.z + ) + ros_msg.transform.rotation = ROSQuaternion( + x=self.rotation.x, y=self.rotation.y, z=self.rotation.z, w=self.rotation.w + ) + + return ros_msg + def __neg__(self) -> "Transform": """Unary minus operator returns the inverse transform.""" return self.inverse() diff --git a/dimos/msgs/geometry_msgs/Twist.py b/dimos/msgs/geometry_msgs/Twist.py index 10e07eaeb5..fe951bff09 100644 --- a/dimos/msgs/geometry_msgs/Twist.py +++ b/dimos/msgs/geometry_msgs/Twist.py @@ -19,6 +19,8 @@ from typing import BinaryIO from dimos_lcm.geometry_msgs import Twist as LCMTwist +from geometry_msgs.msg import Twist as ROSTwist +from geometry_msgs.msg import Vector3 as ROSVector3 from plum import dispatch from dimos.msgs.geometry_msgs.Quaternion import Quaternion @@ -100,3 +102,28 @@ def __bool__(self) -> bool: False if twist is zero, True otherwise """ return not self.is_zero() + + @classmethod + def from_ros_msg(cls, ros_msg: ROSTwist) -> "Twist": + """Create a Twist from a ROS geometry_msgs/Twist message. + + Args: + ros_msg: ROS Twist message + + Returns: + Twist instance + """ + linear = Vector3(ros_msg.linear.x, ros_msg.linear.y, ros_msg.linear.z) + angular = Vector3(ros_msg.angular.x, ros_msg.angular.y, ros_msg.angular.z) + return cls(linear, angular) + + def to_ros_msg(self) -> ROSTwist: + """Convert to a ROS geometry_msgs/Twist message. + + Returns: + ROS Twist message + """ + ros_msg = ROSTwist() + ros_msg.linear = ROSVector3(x=self.linear.x, y=self.linear.y, z=self.linear.z) + ros_msg.angular = ROSVector3(x=self.angular.x, y=self.angular.y, z=self.angular.z) + return ros_msg diff --git a/dimos/msgs/geometry_msgs/TwistStamped.py b/dimos/msgs/geometry_msgs/TwistStamped.py new file mode 100644 index 0000000000..0afa7e662f --- /dev/null +++ b/dimos/msgs/geometry_msgs/TwistStamped.py @@ -0,0 +1,116 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import struct +import time +from io import BytesIO +from typing import BinaryIO, TypeAlias + +from dimos_lcm.geometry_msgs import TwistStamped as LCMTwistStamped +from dimos_lcm.std_msgs import Header as LCMHeader +from dimos_lcm.std_msgs import Time as LCMTime +from geometry_msgs.msg import TwistStamped as ROSTwistStamped +from plum import dispatch + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorConvertable +from dimos.types.timestamped import Timestamped + +# Types that can be converted to/from TwistStamped +TwistConvertable: TypeAlias = ( + tuple[VectorConvertable, VectorConvertable] | LCMTwistStamped | dict[str, VectorConvertable] +) + + +def sec_nsec(ts): + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class TwistStamped(Twist, Timestamped): + msg_name = "geometry_msgs.TwistStamped" + ts: float + frame_id: str + + @dispatch + def __init__(self, ts: float = 0.0, frame_id: str = "", **kwargs) -> None: + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + super().__init__(**kwargs) + + def lcm_encode(self) -> bytes: + lcm_msg = LCMTwistStamped() + lcm_msg.twist = self + [lcm_msg.header.stamp.sec, lcm_msg.header.stamp.nsec] = sec_nsec(self.ts) + lcm_msg.header.frame_id = self.frame_id + return lcm_msg.lcm_encode() + + @classmethod + def lcm_decode(cls, data: bytes | BinaryIO) -> TwistStamped: + lcm_msg = LCMTwistStamped.lcm_decode(data) + return cls( + ts=lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000), + frame_id=lcm_msg.header.frame_id, + linear=[lcm_msg.twist.linear.x, lcm_msg.twist.linear.y, lcm_msg.twist.linear.z], + angular=[lcm_msg.twist.angular.x, lcm_msg.twist.angular.y, lcm_msg.twist.angular.z], + ) + + def __str__(self) -> str: + return ( + f"TwistStamped(linear=[{self.linear.x:.3f}, {self.linear.y:.3f}, {self.linear.z:.3f}], " + f"angular=[{self.angular.x:.3f}, {self.angular.y:.3f}, {self.angular.z:.3f}])" + ) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSTwistStamped) -> "TwistStamped": + """Create a TwistStamped from a ROS geometry_msgs/TwistStamped message. + + Args: + ros_msg: ROS TwistStamped message + + Returns: + TwistStamped instance + """ + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert twist + twist = Twist.from_ros_msg(ros_msg.twist) + + return cls( + ts=ts, + frame_id=ros_msg.header.frame_id, + linear=twist.linear, + angular=twist.angular, + ) + + def to_ros_msg(self) -> ROSTwistStamped: + """Convert to a ROS geometry_msgs/TwistStamped message. + + Returns: + ROS TwistStamped message + """ + ros_msg = ROSTwistStamped() + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set twist + ros_msg.twist = Twist.to_ros_msg(self) + + return ros_msg diff --git a/dimos/msgs/geometry_msgs/TwistWithCovariance.py b/dimos/msgs/geometry_msgs/TwistWithCovariance.py new file mode 100644 index 0000000000..81be1c3874 --- /dev/null +++ b/dimos/msgs/geometry_msgs/TwistWithCovariance.py @@ -0,0 +1,219 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TypeAlias + +import numpy as np +from dimos_lcm.geometry_msgs import TwistWithCovariance as LCMTwistWithCovariance +from geometry_msgs.msg import TwistWithCovariance as ROSTwistWithCovariance +from plum import dispatch + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorConvertable + +# Types that can be converted to/from TwistWithCovariance +TwistWithCovarianceConvertable: TypeAlias = ( + tuple[Twist | tuple[VectorConvertable, VectorConvertable], list[float] | np.ndarray] + | LCMTwistWithCovariance + | dict[str, Twist | tuple[VectorConvertable, VectorConvertable] | list[float] | np.ndarray] +) + + +class TwistWithCovariance(LCMTwistWithCovariance): + twist: Twist + msg_name = "geometry_msgs.TwistWithCovariance" + + @dispatch + def __init__(self) -> None: + """Initialize with default twist and zero covariance.""" + self.twist = Twist() + self.covariance = np.zeros(36) + + @dispatch + def __init__( + self, + twist: Twist | tuple[VectorConvertable, VectorConvertable], + covariance: list[float] | np.ndarray | None = None, + ) -> None: + """Initialize with twist and optional covariance.""" + if isinstance(twist, Twist): + self.twist = twist + else: + # Assume it's a tuple of (linear, angular) + self.twist = Twist(twist[0], twist[1]) + + if covariance is None: + self.covariance = np.zeros(36) + else: + self.covariance = np.array(covariance, dtype=float).reshape(36) + + @dispatch + def __init__(self, twist_with_cov: TwistWithCovariance) -> None: + """Initialize from another TwistWithCovariance (copy constructor).""" + self.twist = Twist(twist_with_cov.twist) + self.covariance = np.array(twist_with_cov.covariance).copy() + + @dispatch + def __init__(self, lcm_twist_with_cov: LCMTwistWithCovariance) -> None: + """Initialize from an LCM TwistWithCovariance.""" + self.twist = Twist(lcm_twist_with_cov.twist) + self.covariance = np.array(lcm_twist_with_cov.covariance) + + @dispatch + def __init__( + self, + twist_dict: dict[ + str, Twist | tuple[VectorConvertable, VectorConvertable] | list[float] | np.ndarray + ], + ) -> None: + """Initialize from a dictionary with 'twist' and 'covariance' keys.""" + twist = twist_dict["twist"] + if isinstance(twist, Twist): + self.twist = twist + else: + # Assume it's a tuple of (linear, angular) + self.twist = Twist(twist[0], twist[1]) + + covariance = twist_dict.get("covariance") + if covariance is None: + self.covariance = np.zeros(36) + else: + self.covariance = np.array(covariance, dtype=float).reshape(36) + + @dispatch + def __init__( + self, + twist_tuple: tuple[ + Twist | tuple[VectorConvertable, VectorConvertable], list[float] | np.ndarray + ], + ) -> None: + """Initialize from a tuple of (twist, covariance).""" + twist = twist_tuple[0] + if isinstance(twist, Twist): + self.twist = twist + else: + # Assume it's a tuple of (linear, angular) + self.twist = Twist(twist[0], twist[1]) + self.covariance = np.array(twist_tuple[1], dtype=float).reshape(36) + + def __getattribute__(self, name): + """Override to ensure covariance is always returned as numpy array.""" + if name == "covariance": + cov = object.__getattribute__(self, "covariance") + if not isinstance(cov, np.ndarray): + return np.array(cov, dtype=float) + return cov + return super().__getattribute__(name) + + def __setattr__(self, name, value): + """Override to ensure covariance is stored as numpy array.""" + if name == "covariance": + if not isinstance(value, np.ndarray): + value = np.array(value, dtype=float).reshape(36) + super().__setattr__(name, value) + + @property + def linear(self) -> Vector3: + """Linear velocity vector.""" + return self.twist.linear + + @property + def angular(self) -> Vector3: + """Angular velocity vector.""" + return self.twist.angular + + @property + def covariance_matrix(self) -> np.ndarray: + """Get covariance as 6x6 matrix.""" + return self.covariance.reshape(6, 6) + + @covariance_matrix.setter + def covariance_matrix(self, value: np.ndarray) -> None: + """Set covariance from 6x6 matrix.""" + self.covariance = np.array(value).reshape(36) + + def __repr__(self) -> str: + return f"TwistWithCovariance(twist={self.twist!r}, covariance=<{self.covariance.shape[0] if isinstance(self.covariance, np.ndarray) else len(self.covariance)} elements>)" + + def __str__(self) -> str: + return ( + f"TwistWithCovariance(linear=[{self.linear.x:.3f}, {self.linear.y:.3f}, {self.linear.z:.3f}], " + f"angular=[{self.angular.x:.3f}, {self.angular.y:.3f}, {self.angular.z:.3f}], " + f"cov_trace={np.trace(self.covariance_matrix):.3f})" + ) + + def __eq__(self, other) -> bool: + """Check if two TwistWithCovariance are equal.""" + if not isinstance(other, TwistWithCovariance): + return False + return self.twist == other.twist and np.allclose(self.covariance, other.covariance) + + def is_zero(self) -> bool: + """Check if this is a zero twist (no linear or angular velocity).""" + return self.twist.is_zero() + + def __bool__(self) -> bool: + """Boolean conversion - False if zero twist, True otherwise.""" + return not self.is_zero() + + def lcm_encode(self) -> bytes: + """Encode to LCM binary format.""" + lcm_msg = LCMTwistWithCovariance() + lcm_msg.twist = self.twist + # LCM expects list, not numpy array + if isinstance(self.covariance, np.ndarray): + lcm_msg.covariance = self.covariance.tolist() + else: + lcm_msg.covariance = list(self.covariance) + return lcm_msg.lcm_encode() + + @classmethod + def lcm_decode(cls, data: bytes) -> "TwistWithCovariance": + """Decode from LCM binary format.""" + lcm_msg = LCMTwistWithCovariance.lcm_decode(data) + twist = Twist( + linear=[lcm_msg.twist.linear.x, lcm_msg.twist.linear.y, lcm_msg.twist.linear.z], + angular=[lcm_msg.twist.angular.x, lcm_msg.twist.angular.y, lcm_msg.twist.angular.z], + ) + return cls(twist, lcm_msg.covariance) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSTwistWithCovariance) -> "TwistWithCovariance": + """Create a TwistWithCovariance from a ROS geometry_msgs/TwistWithCovariance message. + + Args: + ros_msg: ROS TwistWithCovariance message + + Returns: + TwistWithCovariance instance + """ + twist = Twist.from_ros_msg(ros_msg.twist) + return cls(twist, list(ros_msg.covariance)) + + def to_ros_msg(self) -> ROSTwistWithCovariance: + """Convert to a ROS geometry_msgs/TwistWithCovariance message. + + Returns: + ROS TwistWithCovariance message + """ + ros_msg = ROSTwistWithCovariance() + ros_msg.twist = self.twist.to_ros_msg() + # ROS expects list, not numpy array + if isinstance(self.covariance, np.ndarray): + ros_msg.covariance = self.covariance.tolist() + else: + ros_msg.covariance = list(self.covariance) + return ros_msg diff --git a/dimos/msgs/geometry_msgs/TwistWithCovarianceStamped.py b/dimos/msgs/geometry_msgs/TwistWithCovarianceStamped.py new file mode 100644 index 0000000000..f199eb60c9 --- /dev/null +++ b/dimos/msgs/geometry_msgs/TwistWithCovarianceStamped.py @@ -0,0 +1,163 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import TypeAlias + +import numpy as np +from dimos_lcm.geometry_msgs import TwistWithCovarianceStamped as LCMTwistWithCovarianceStamped +from geometry_msgs.msg import TwistWithCovarianceStamped as ROSTwistWithCovarianceStamped +from plum import dispatch + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance +from dimos.msgs.geometry_msgs.Vector3 import VectorConvertable +from dimos.types.timestamped import Timestamped + +# Types that can be converted to/from TwistWithCovarianceStamped +TwistWithCovarianceStampedConvertable: TypeAlias = ( + tuple[Twist | tuple[VectorConvertable, VectorConvertable], list[float] | np.ndarray] + | LCMTwistWithCovarianceStamped + | dict[ + str, + Twist + | tuple[VectorConvertable, VectorConvertable] + | list[float] + | np.ndarray + | float + | str, + ] +) + + +def sec_nsec(ts): + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class TwistWithCovarianceStamped(TwistWithCovariance, Timestamped): + msg_name = "geometry_msgs.TwistWithCovarianceStamped" + ts: float + frame_id: str + + @dispatch + def __init__(self, ts: float = 0.0, frame_id: str = "", **kwargs) -> None: + """Initialize with timestamp and frame_id.""" + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + super().__init__(**kwargs) + + @dispatch + def __init__( + self, + ts: float = 0.0, + frame_id: str = "", + twist: Twist | tuple[VectorConvertable, VectorConvertable] | None = None, + covariance: list[float] | np.ndarray | None = None, + ) -> None: + """Initialize with timestamp, frame_id, twist and covariance.""" + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + if twist is None: + super().__init__() + else: + super().__init__(twist, covariance) + + def lcm_encode(self) -> bytes: + lcm_msg = LCMTwistWithCovarianceStamped() + lcm_msg.twist.twist = self.twist + # LCM expects list, not numpy array + if isinstance(self.covariance, np.ndarray): + lcm_msg.twist.covariance = self.covariance.tolist() + else: + lcm_msg.twist.covariance = list(self.covariance) + [lcm_msg.header.stamp.sec, lcm_msg.header.stamp.nsec] = sec_nsec(self.ts) + lcm_msg.header.frame_id = self.frame_id + return lcm_msg.lcm_encode() + + @classmethod + def lcm_decode(cls, data: bytes) -> TwistWithCovarianceStamped: + lcm_msg = LCMTwistWithCovarianceStamped.lcm_decode(data) + return cls( + ts=lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000), + frame_id=lcm_msg.header.frame_id, + twist=Twist( + linear=[ + lcm_msg.twist.twist.linear.x, + lcm_msg.twist.twist.linear.y, + lcm_msg.twist.twist.linear.z, + ], + angular=[ + lcm_msg.twist.twist.angular.x, + lcm_msg.twist.twist.angular.y, + lcm_msg.twist.twist.angular.z, + ], + ), + covariance=lcm_msg.twist.covariance, + ) + + def __str__(self) -> str: + return ( + f"TwistWithCovarianceStamped(linear=[{self.linear.x:.3f}, {self.linear.y:.3f}, {self.linear.z:.3f}], " + f"angular=[{self.angular.x:.3f}, {self.angular.y:.3f}, {self.angular.z:.3f}], " + f"cov_trace={np.trace(self.covariance_matrix):.3f})" + ) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSTwistWithCovarianceStamped) -> "TwistWithCovarianceStamped": + """Create a TwistWithCovarianceStamped from a ROS geometry_msgs/TwistWithCovarianceStamped message. + + Args: + ros_msg: ROS TwistWithCovarianceStamped message + + Returns: + TwistWithCovarianceStamped instance + """ + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert twist with covariance + twist_with_cov = TwistWithCovariance.from_ros_msg(ros_msg.twist) + + return cls( + ts=ts, + frame_id=ros_msg.header.frame_id, + twist=twist_with_cov.twist, + covariance=twist_with_cov.covariance, + ) + + def to_ros_msg(self) -> ROSTwistWithCovarianceStamped: + """Convert to a ROS geometry_msgs/TwistWithCovarianceStamped message. + + Returns: + ROS TwistWithCovarianceStamped message + """ + ros_msg = ROSTwistWithCovarianceStamped() + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set twist with covariance + ros_msg.twist.twist = self.twist.to_ros_msg() + # ROS expects list, not numpy array + if isinstance(self.covariance, np.ndarray): + ros_msg.twist.covariance = self.covariance.tolist() + else: + ros_msg.twist.covariance = list(self.covariance) + + return ros_msg diff --git a/dimos/msgs/geometry_msgs/__init__.py b/dimos/msgs/geometry_msgs/__init__.py index 86d25bb843..137b113a4d 100644 --- a/dimos/msgs/geometry_msgs/__init__.py +++ b/dimos/msgs/geometry_msgs/__init__.py @@ -3,4 +3,5 @@ from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorLike diff --git a/dimos/msgs/geometry_msgs/test_Pose.py b/dimos/msgs/geometry_msgs/test_Pose.py index 2524f12faa..31f05934c9 100644 --- a/dimos/msgs/geometry_msgs/test_Pose.py +++ b/dimos/msgs/geometry_msgs/test_Pose.py @@ -17,6 +17,9 @@ import numpy as np import pytest from dimos_lcm.geometry_msgs import Pose as LCMPose +from geometry_msgs.msg import Pose as ROSPose +from geometry_msgs.msg import Point as ROSPoint +from geometry_msgs.msg import Quaternion as ROSQuaternion from dimos.msgs.geometry_msgs.Pose import Pose, to_pose from dimos.msgs.geometry_msgs.Quaternion import Quaternion @@ -747,3 +750,52 @@ def test_pose_addition_3d_rotation(): assert np.isclose(result.position.x, 1.0, atol=1e-10) # X unchanged assert np.isclose(result.position.y, cos45 - sin45, atol=1e-10) assert np.isclose(result.position.z, sin45 + cos45, atol=1e-10) + + +def test_pose_from_ros_msg(): + """Test creating a Pose from a ROS Pose message.""" + ros_msg = ROSPose() + ros_msg.position = ROSPoint(x=1.0, y=2.0, z=3.0) + ros_msg.orientation = ROSQuaternion(x=0.1, y=0.2, z=0.3, w=0.9) + + pose = Pose.from_ros_msg(ros_msg) + + assert pose.position.x == 1.0 + assert pose.position.y == 2.0 + assert pose.position.z == 3.0 + assert pose.orientation.x == 0.1 + assert pose.orientation.y == 0.2 + assert pose.orientation.z == 0.3 + assert pose.orientation.w == 0.9 + + +def test_pose_to_ros_msg(): + """Test converting a Pose to a ROS Pose message.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + + ros_msg = pose.to_ros_msg() + + assert isinstance(ros_msg, ROSPose) + assert ros_msg.position.x == 1.0 + assert ros_msg.position.y == 2.0 + assert ros_msg.position.z == 3.0 + assert ros_msg.orientation.x == 0.1 + assert ros_msg.orientation.y == 0.2 + assert ros_msg.orientation.z == 0.3 + assert ros_msg.orientation.w == 0.9 + + +def test_pose_ros_roundtrip(): + """Test round-trip conversion between Pose and ROS Pose.""" + original = Pose(1.5, 2.5, 3.5, 0.15, 0.25, 0.35, 0.85) + + ros_msg = original.to_ros_msg() + restored = Pose.from_ros_msg(ros_msg) + + assert restored.position.x == original.position.x + assert restored.position.y == original.position.y + assert restored.position.z == original.position.z + assert restored.orientation.x == original.orientation.x + assert restored.orientation.y == original.orientation.y + assert restored.orientation.z == original.orientation.z + assert restored.orientation.w == original.orientation.w diff --git a/dimos/msgs/geometry_msgs/test_PoseStamped.py b/dimos/msgs/geometry_msgs/test_PoseStamped.py index 86dbf72bdc..33ddee1fc3 100644 --- a/dimos/msgs/geometry_msgs/test_PoseStamped.py +++ b/dimos/msgs/geometry_msgs/test_PoseStamped.py @@ -15,6 +15,8 @@ import pickle import time +from geometry_msgs.msg import PoseStamped as ROSPoseStamped + from dimos.msgs.geometry_msgs import PoseStamped @@ -53,3 +55,77 @@ def test_pickle_encode_decode(): assert isinstance(pose_dest, PoseStamped) assert pose_dest is not pose_source assert pose_dest == pose_source + + +def test_pose_stamped_from_ros_msg(): + """Test creating a PoseStamped from a ROS PoseStamped message.""" + ros_msg = ROSPoseStamped() + ros_msg.header.frame_id = "world" + ros_msg.header.stamp.sec = 123 + ros_msg.header.stamp.nanosec = 456000000 + ros_msg.pose.position.x = 1.0 + ros_msg.pose.position.y = 2.0 + ros_msg.pose.position.z = 3.0 + ros_msg.pose.orientation.x = 0.1 + ros_msg.pose.orientation.y = 0.2 + ros_msg.pose.orientation.z = 0.3 + ros_msg.pose.orientation.w = 0.9 + + pose_stamped = PoseStamped.from_ros_msg(ros_msg) + + assert pose_stamped.frame_id == "world" + assert pose_stamped.ts == 123.456 + assert pose_stamped.position.x == 1.0 + assert pose_stamped.position.y == 2.0 + assert pose_stamped.position.z == 3.0 + assert pose_stamped.orientation.x == 0.1 + assert pose_stamped.orientation.y == 0.2 + assert pose_stamped.orientation.z == 0.3 + assert pose_stamped.orientation.w == 0.9 + + +def test_pose_stamped_to_ros_msg(): + """Test converting a PoseStamped to a ROS PoseStamped message.""" + pose_stamped = PoseStamped( + ts=123.456, + frame_id="base_link", + position=(1.0, 2.0, 3.0), + orientation=(0.1, 0.2, 0.3, 0.9), + ) + + ros_msg = pose_stamped.to_ros_msg() + + assert isinstance(ros_msg, ROSPoseStamped) + assert ros_msg.header.frame_id == "base_link" + assert ros_msg.header.stamp.sec == 123 + assert ros_msg.header.stamp.nanosec == 456000000 + assert ros_msg.pose.position.x == 1.0 + assert ros_msg.pose.position.y == 2.0 + assert ros_msg.pose.position.z == 3.0 + assert ros_msg.pose.orientation.x == 0.1 + assert ros_msg.pose.orientation.y == 0.2 + assert ros_msg.pose.orientation.z == 0.3 + assert ros_msg.pose.orientation.w == 0.9 + + +def test_pose_stamped_ros_roundtrip(): + """Test round-trip conversion between PoseStamped and ROS PoseStamped.""" + original = PoseStamped( + ts=123.789, + frame_id="odom", + position=(1.5, 2.5, 3.5), + orientation=(0.15, 0.25, 0.35, 0.85), + ) + + ros_msg = original.to_ros_msg() + restored = PoseStamped.from_ros_msg(ros_msg) + + assert restored.frame_id == original.frame_id + assert restored.ts == original.ts + assert restored.position.x == original.position.x + assert restored.position.y == original.position.y + assert restored.position.z == original.position.z + assert restored.orientation.x == original.orientation.x + assert restored.orientation.y == original.orientation.y + assert restored.orientation.z == original.orientation.z + assert restored.orientation.w == original.orientation.w diff --git a/dimos/msgs/geometry_msgs/test_PoseWithCovariance.py b/dimos/msgs/geometry_msgs/test_PoseWithCovariance.py new file mode 100644 index 0000000000..d35946cc4a --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_PoseWithCovariance.py @@ -0,0 +1,378 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +from dimos_lcm.geometry_msgs import PoseWithCovariance as LCMPoseWithCovariance +from geometry_msgs.msg import PoseWithCovariance as ROSPoseWithCovariance +from geometry_msgs.msg import Pose as ROSPose +from geometry_msgs.msg import Point as ROSPoint +from geometry_msgs.msg import Quaternion as ROSQuaternion + +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + +def test_pose_with_covariance_default_init(): + """Test that default initialization creates a pose at origin with zero covariance.""" + pose_cov = PoseWithCovariance() + + # Pose should be at origin with identity orientation + assert pose_cov.pose.position.x == 0.0 + assert pose_cov.pose.position.y == 0.0 + assert pose_cov.pose.position.z == 0.0 + assert pose_cov.pose.orientation.x == 0.0 + assert pose_cov.pose.orientation.y == 0.0 + assert pose_cov.pose.orientation.z == 0.0 + assert pose_cov.pose.orientation.w == 1.0 + + # Covariance should be all zeros + assert np.all(pose_cov.covariance == 0.0) + assert pose_cov.covariance.shape == (36,) + + +def test_pose_with_covariance_pose_init(): + """Test initialization with a Pose object.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + pose_cov = PoseWithCovariance(pose) + + # Pose should match + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + assert pose_cov.pose.orientation.x == 0.1 + assert pose_cov.pose.orientation.y == 0.2 + assert pose_cov.pose.orientation.z == 0.3 + assert pose_cov.pose.orientation.w == 0.9 + + # Covariance should be zeros by default + assert np.all(pose_cov.covariance == 0.0) + + +def test_pose_with_covariance_pose_and_covariance_init(): + """Test initialization with pose and covariance.""" + pose = Pose(1.0, 2.0, 3.0) + covariance = np.arange(36, dtype=float) + pose_cov = PoseWithCovariance(pose, covariance) + + # Pose should match + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + + # Covariance should match + assert np.array_equal(pose_cov.covariance, covariance) + + +def test_pose_with_covariance_list_covariance(): + """Test initialization with covariance as a list.""" + pose = Pose(1.0, 2.0, 3.0) + covariance_list = list(range(36)) + pose_cov = PoseWithCovariance(pose, covariance_list) + + # Covariance should be converted to numpy array + assert isinstance(pose_cov.covariance, np.ndarray) + assert np.array_equal(pose_cov.covariance, np.array(covariance_list)) + + +def test_pose_with_covariance_copy_init(): + """Test copy constructor.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + original = PoseWithCovariance(pose, covariance) + copy = PoseWithCovariance(original) + + # Should be equal but not the same object + assert copy == original + assert copy is not original + assert copy.pose is not original.pose + assert copy.covariance is not original.covariance + + # Modify original to ensure they're independent + original.covariance[0] = 999.0 + assert copy.covariance[0] != 999.0 + + +def test_pose_with_covariance_lcm_init(): + """Test initialization from LCM message.""" + lcm_msg = LCMPoseWithCovariance() + lcm_msg.pose.position.x = 1.0 + lcm_msg.pose.position.y = 2.0 + lcm_msg.pose.position.z = 3.0 + lcm_msg.pose.orientation.x = 0.1 + lcm_msg.pose.orientation.y = 0.2 + lcm_msg.pose.orientation.z = 0.3 + lcm_msg.pose.orientation.w = 0.9 + lcm_msg.covariance = list(range(36)) + + pose_cov = PoseWithCovariance(lcm_msg) + + # Pose should match + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + assert pose_cov.pose.orientation.x == 0.1 + assert pose_cov.pose.orientation.y == 0.2 + assert pose_cov.pose.orientation.z == 0.3 + assert pose_cov.pose.orientation.w == 0.9 + + # Covariance should match + assert np.array_equal(pose_cov.covariance, np.arange(36)) + + +def test_pose_with_covariance_dict_init(): + """Test initialization from dictionary.""" + pose_dict = {"pose": Pose(1.0, 2.0, 3.0), "covariance": list(range(36))} + pose_cov = PoseWithCovariance(pose_dict) + + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + assert np.array_equal(pose_cov.covariance, np.arange(36)) + + +def test_pose_with_covariance_dict_init_no_covariance(): + """Test initialization from dictionary without covariance.""" + pose_dict = {"pose": Pose(1.0, 2.0, 3.0)} + pose_cov = PoseWithCovariance(pose_dict) + + assert pose_cov.pose.position.x == 1.0 + assert np.all(pose_cov.covariance == 0.0) + + +def test_pose_with_covariance_tuple_init(): + """Test initialization from tuple.""" + pose = Pose(1.0, 2.0, 3.0) + covariance = np.arange(36, dtype=float) + pose_tuple = (pose, covariance) + pose_cov = PoseWithCovariance(pose_tuple) + + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + assert np.array_equal(pose_cov.covariance, covariance) + + +def test_pose_with_covariance_properties(): + """Test convenience properties.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + pose_cov = PoseWithCovariance(pose) + + # Position properties + assert pose_cov.x == 1.0 + assert pose_cov.y == 2.0 + assert pose_cov.z == 3.0 + assert pose_cov.position.x == 1.0 + assert pose_cov.position.y == 2.0 + assert pose_cov.position.z == 3.0 + + # Orientation properties + assert pose_cov.orientation.x == 0.1 + assert pose_cov.orientation.y == 0.2 + assert pose_cov.orientation.z == 0.3 + assert pose_cov.orientation.w == 0.9 + + # Euler angle properties + assert pose_cov.roll == pose.roll + assert pose_cov.pitch == pose.pitch + assert pose_cov.yaw == pose.yaw + + +def test_pose_with_covariance_matrix_property(): + """Test covariance matrix property.""" + pose = Pose() + covariance_array = np.arange(36, dtype=float) + pose_cov = PoseWithCovariance(pose, covariance_array) + + # Get as matrix + cov_matrix = pose_cov.covariance_matrix + assert cov_matrix.shape == (6, 6) + assert cov_matrix[0, 0] == 0.0 + assert cov_matrix[5, 5] == 35.0 + + # Set from matrix + new_matrix = np.eye(6) * 2.0 + pose_cov.covariance_matrix = new_matrix + assert np.array_equal(pose_cov.covariance[:6], [2.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + + +def test_pose_with_covariance_repr(): + """Test string representation.""" + pose = Pose(1.234, 2.567, 3.891) + pose_cov = PoseWithCovariance(pose) + + repr_str = repr(pose_cov) + assert "PoseWithCovariance" in repr_str + assert "pose=" in repr_str + assert "covariance=" in repr_str + assert "36 elements" in repr_str + + +def test_pose_with_covariance_str(): + """Test string formatting.""" + pose = Pose(1.234, 2.567, 3.891) + covariance = np.eye(6).flatten() + pose_cov = PoseWithCovariance(pose, covariance) + + str_repr = str(pose_cov) + assert "PoseWithCovariance" in str_repr + assert "1.234" in str_repr + assert "2.567" in str_repr + assert "3.891" in str_repr + assert "cov_trace" in str_repr + assert "6.000" in str_repr # Trace of identity matrix is 6 + + +def test_pose_with_covariance_equality(): + """Test equality comparison.""" + pose1 = Pose(1.0, 2.0, 3.0) + cov1 = np.arange(36, dtype=float) + pose_cov1 = PoseWithCovariance(pose1, cov1) + + pose2 = Pose(1.0, 2.0, 3.0) + cov2 = np.arange(36, dtype=float) + pose_cov2 = PoseWithCovariance(pose2, cov2) + + # Equal + assert pose_cov1 == pose_cov2 + + # Different pose + pose3 = Pose(1.1, 2.0, 3.0) + pose_cov3 = PoseWithCovariance(pose3, cov1) + assert pose_cov1 != pose_cov3 + + # Different covariance + cov3 = np.arange(36, dtype=float) + 1 + pose_cov4 = PoseWithCovariance(pose1, cov3) + assert pose_cov1 != pose_cov4 + + # Different type + assert pose_cov1 != "not a pose" + assert pose_cov1 != None + + +def test_pose_with_covariance_lcm_encode_decode(): + """Test LCM encoding and decoding.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + source = PoseWithCovariance(pose, covariance) + + # Encode and decode + binary_msg = source.lcm_encode() + decoded = PoseWithCovariance.lcm_decode(binary_msg) + + # Should be equal + assert decoded == source + assert isinstance(decoded, PoseWithCovariance) + assert isinstance(decoded.pose, Pose) + assert isinstance(decoded.covariance, np.ndarray) + + +def test_pose_with_covariance_from_ros_msg(): + """Test creating from ROS message.""" + ros_msg = ROSPoseWithCovariance() + ros_msg.pose.position = ROSPoint(x=1.0, y=2.0, z=3.0) + ros_msg.pose.orientation = ROSQuaternion(x=0.1, y=0.2, z=0.3, w=0.9) + ros_msg.covariance = [float(i) for i in range(36)] + + pose_cov = PoseWithCovariance.from_ros_msg(ros_msg) + + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + assert pose_cov.pose.orientation.x == 0.1 + assert pose_cov.pose.orientation.y == 0.2 + assert pose_cov.pose.orientation.z == 0.3 + assert pose_cov.pose.orientation.w == 0.9 + assert np.array_equal(pose_cov.covariance, np.arange(36)) + + +def test_pose_with_covariance_to_ros_msg(): + """Test converting to ROS message.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + pose_cov = PoseWithCovariance(pose, covariance) + + ros_msg = pose_cov.to_ros_msg() + + assert isinstance(ros_msg, ROSPoseWithCovariance) + assert ros_msg.pose.position.x == 1.0 + assert ros_msg.pose.position.y == 2.0 + assert ros_msg.pose.position.z == 3.0 + assert ros_msg.pose.orientation.x == 0.1 + assert ros_msg.pose.orientation.y == 0.2 + assert ros_msg.pose.orientation.z == 0.3 + assert ros_msg.pose.orientation.w == 0.9 + assert list(ros_msg.covariance) == list(range(36)) + + +def test_pose_with_covariance_ros_roundtrip(): + """Test round-trip conversion with ROS messages.""" + pose = Pose(1.5, 2.5, 3.5, 0.15, 0.25, 0.35, 0.85) + covariance = np.random.rand(36) + original = PoseWithCovariance(pose, covariance) + + ros_msg = original.to_ros_msg() + restored = PoseWithCovariance.from_ros_msg(ros_msg) + + assert restored == original + + +def test_pose_with_covariance_zero_covariance(): + """Test with zero covariance matrix.""" + pose = Pose(1.0, 2.0, 3.0) + pose_cov = PoseWithCovariance(pose) + + assert np.all(pose_cov.covariance == 0.0) + assert np.trace(pose_cov.covariance_matrix) == 0.0 + + +def test_pose_with_covariance_diagonal_covariance(): + """Test with diagonal covariance matrix.""" + pose = Pose() + covariance = np.zeros(36) + # Set diagonal elements + for i in range(6): + covariance[i * 6 + i] = i + 1 + + pose_cov = PoseWithCovariance(pose, covariance) + + cov_matrix = pose_cov.covariance_matrix + assert np.trace(cov_matrix) == sum(range(1, 7)) # 1+2+3+4+5+6 = 21 + + # Check diagonal elements + for i in range(6): + assert cov_matrix[i, i] == i + 1 + + # Check off-diagonal elements are zero + for i in range(6): + for j in range(6): + if i != j: + assert cov_matrix[i, j] == 0.0 + + +@pytest.mark.parametrize( + "x,y,z", + [(0.0, 0.0, 0.0), (1.0, 2.0, 3.0), (-1.0, -2.0, -3.0), (100.0, -100.0, 0.0)], +) +def test_pose_with_covariance_parametrized_positions(x, y, z): + """Parametrized test for various position values.""" + pose = Pose(x, y, z) + pose_cov = PoseWithCovariance(pose) + + assert pose_cov.x == x + assert pose_cov.y == y + assert pose_cov.z == z diff --git a/dimos/msgs/geometry_msgs/test_PoseWithCovarianceStamped.py b/dimos/msgs/geometry_msgs/test_PoseWithCovarianceStamped.py new file mode 100644 index 0000000000..f6b7560e16 --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_PoseWithCovarianceStamped.py @@ -0,0 +1,342 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import numpy as np +import pytest +from dimos_lcm.geometry_msgs import PoseWithCovarianceStamped as LCMPoseWithCovarianceStamped +from dimos_lcm.std_msgs import Header as LCMHeader +from dimos_lcm.std_msgs import Time as LCMTime +from geometry_msgs.msg import PoseWithCovarianceStamped as ROSPoseWithCovarianceStamped +from geometry_msgs.msg import PoseWithCovariance as ROSPoseWithCovariance +from geometry_msgs.msg import Pose as ROSPose +from geometry_msgs.msg import Point as ROSPoint +from geometry_msgs.msg import Quaternion as ROSQuaternion +from std_msgs.msg import Header as ROSHeader +from builtin_interfaces.msg import Time as ROSTime + +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance +from dimos.msgs.geometry_msgs.PoseWithCovarianceStamped import PoseWithCovarianceStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + +def test_pose_with_covariance_stamped_default_init(): + """Test default initialization.""" + pose_cov_stamped = PoseWithCovarianceStamped() + + # Should have current timestamp + assert pose_cov_stamped.ts > 0 + assert pose_cov_stamped.frame_id == "" + + # Pose should be at origin with identity orientation + assert pose_cov_stamped.pose.position.x == 0.0 + assert pose_cov_stamped.pose.position.y == 0.0 + assert pose_cov_stamped.pose.position.z == 0.0 + assert pose_cov_stamped.pose.orientation.w == 1.0 + + # Covariance should be all zeros + assert np.all(pose_cov_stamped.covariance == 0.0) + + +def test_pose_with_covariance_stamped_with_timestamp(): + """Test initialization with specific timestamp.""" + ts = 1234567890.123456 + frame_id = "base_link" + pose_cov_stamped = PoseWithCovarianceStamped(ts=ts, frame_id=frame_id) + + assert pose_cov_stamped.ts == ts + assert pose_cov_stamped.frame_id == frame_id + + +def test_pose_with_covariance_stamped_with_pose(): + """Test initialization with pose.""" + ts = 1234567890.123456 + frame_id = "map" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + + pose_cov_stamped = PoseWithCovarianceStamped( + ts=ts, frame_id=frame_id, pose=pose, covariance=covariance + ) + + assert pose_cov_stamped.ts == ts + assert pose_cov_stamped.frame_id == frame_id + assert pose_cov_stamped.pose.position.x == 1.0 + assert pose_cov_stamped.pose.position.y == 2.0 + assert pose_cov_stamped.pose.position.z == 3.0 + assert np.array_equal(pose_cov_stamped.covariance, covariance) + + +def test_pose_with_covariance_stamped_properties(): + """Test convenience properties.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.eye(6).flatten() + pose_cov_stamped = PoseWithCovarianceStamped( + ts=1234567890.0, frame_id="odom", pose=pose, covariance=covariance + ) + + # Position properties + assert pose_cov_stamped.x == 1.0 + assert pose_cov_stamped.y == 2.0 + assert pose_cov_stamped.z == 3.0 + + # Orientation properties + assert pose_cov_stamped.orientation.x == 0.1 + assert pose_cov_stamped.orientation.y == 0.2 + assert pose_cov_stamped.orientation.z == 0.3 + assert pose_cov_stamped.orientation.w == 0.9 + + # Euler angles + assert pose_cov_stamped.roll == pose.roll + assert pose_cov_stamped.pitch == pose.pitch + assert pose_cov_stamped.yaw == pose.yaw + + # Covariance matrix + cov_matrix = pose_cov_stamped.covariance_matrix + assert cov_matrix.shape == (6, 6) + assert np.trace(cov_matrix) == 6.0 + + +def test_pose_with_covariance_stamped_str(): + """Test string representation.""" + pose = Pose(1.234, 2.567, 3.891) + covariance = np.eye(6).flatten() * 2.0 + pose_cov_stamped = PoseWithCovarianceStamped( + ts=1234567890.0, frame_id="world", pose=pose, covariance=covariance + ) + + str_repr = str(pose_cov_stamped) + assert "PoseWithCovarianceStamped" in str_repr + assert "1.234" in str_repr + assert "2.567" in str_repr + assert "3.891" in str_repr + assert "cov_trace" in str_repr + assert "12.000" in str_repr # Trace of 2*identity is 12 + + +def test_pose_with_covariance_stamped_lcm_encode_decode(): + """Test LCM encoding and decoding.""" + ts = 1234567890.123456 + frame_id = "camera_link" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + + source = PoseWithCovarianceStamped(ts=ts, frame_id=frame_id, pose=pose, covariance=covariance) + + # Encode and decode + binary_msg = source.lcm_encode() + decoded = PoseWithCovarianceStamped.lcm_decode(binary_msg) + + # Check timestamp (may lose some precision) + assert abs(decoded.ts - ts) < 1e-6 + assert decoded.frame_id == frame_id + + # Check pose + assert decoded.pose.position.x == 1.0 + assert decoded.pose.position.y == 2.0 + assert decoded.pose.position.z == 3.0 + assert decoded.pose.orientation.x == 0.1 + assert decoded.pose.orientation.y == 0.2 + assert decoded.pose.orientation.z == 0.3 + assert decoded.pose.orientation.w == 0.9 + + # Check covariance + assert np.array_equal(decoded.covariance, covariance) + + +def test_pose_with_covariance_stamped_from_ros_msg(): + """Test creating from ROS message.""" + ros_msg = ROSPoseWithCovarianceStamped() + + # Set header + ros_msg.header = ROSHeader() + ros_msg.header.stamp = ROSTime() + ros_msg.header.stamp.sec = 1234567890 + ros_msg.header.stamp.nanosec = 123456000 + ros_msg.header.frame_id = "laser" + + # Set pose with covariance + ros_msg.pose = ROSPoseWithCovariance() + ros_msg.pose.pose = ROSPose() + ros_msg.pose.pose.position = ROSPoint(x=1.0, y=2.0, z=3.0) + ros_msg.pose.pose.orientation = ROSQuaternion(x=0.1, y=0.2, z=0.3, w=0.9) + ros_msg.pose.covariance = [float(i) for i in range(36)] + + pose_cov_stamped = PoseWithCovarianceStamped.from_ros_msg(ros_msg) + + assert pose_cov_stamped.ts == 1234567890.123456 + assert pose_cov_stamped.frame_id == "laser" + assert pose_cov_stamped.pose.position.x == 1.0 + assert pose_cov_stamped.pose.position.y == 2.0 + assert pose_cov_stamped.pose.position.z == 3.0 + assert pose_cov_stamped.pose.orientation.x == 0.1 + assert pose_cov_stamped.pose.orientation.y == 0.2 + assert pose_cov_stamped.pose.orientation.z == 0.3 + assert pose_cov_stamped.pose.orientation.w == 0.9 + assert np.array_equal(pose_cov_stamped.covariance, np.arange(36)) + + +def test_pose_with_covariance_stamped_to_ros_msg(): + """Test converting to ROS message.""" + ts = 1234567890.567890 + frame_id = "imu" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + + pose_cov_stamped = PoseWithCovarianceStamped( + ts=ts, frame_id=frame_id, pose=pose, covariance=covariance + ) + + ros_msg = pose_cov_stamped.to_ros_msg() + + assert isinstance(ros_msg, ROSPoseWithCovarianceStamped) + assert ros_msg.header.frame_id == frame_id + assert ros_msg.header.stamp.sec == 1234567890 + assert abs(ros_msg.header.stamp.nanosec - 567890000) < 100 # Allow small rounding error + + assert ros_msg.pose.pose.position.x == 1.0 + assert ros_msg.pose.pose.position.y == 2.0 + assert ros_msg.pose.pose.position.z == 3.0 + assert ros_msg.pose.pose.orientation.x == 0.1 + assert ros_msg.pose.pose.orientation.y == 0.2 + assert ros_msg.pose.pose.orientation.z == 0.3 + assert ros_msg.pose.pose.orientation.w == 0.9 + assert list(ros_msg.pose.covariance) == list(range(36)) + + +def test_pose_with_covariance_stamped_ros_roundtrip(): + """Test round-trip conversion with ROS messages.""" + ts = 2147483647.987654 # Max int32 value for ROS Time.sec + frame_id = "robot_base" + pose = Pose(1.5, 2.5, 3.5, 0.15, 0.25, 0.35, 0.85) + covariance = np.random.rand(36) + + original = PoseWithCovarianceStamped(ts=ts, frame_id=frame_id, pose=pose, covariance=covariance) + + ros_msg = original.to_ros_msg() + restored = PoseWithCovarianceStamped.from_ros_msg(ros_msg) + + # Check timestamp (loses some precision in conversion) + assert abs(restored.ts - ts) < 1e-6 + assert restored.frame_id == frame_id + + # Check pose + assert restored.pose.position.x == original.pose.position.x + assert restored.pose.position.y == original.pose.position.y + assert restored.pose.position.z == original.pose.position.z + assert restored.pose.orientation.x == original.pose.orientation.x + assert restored.pose.orientation.y == original.pose.orientation.y + assert restored.pose.orientation.z == original.pose.orientation.z + assert restored.pose.orientation.w == original.pose.orientation.w + + # Check covariance + assert np.allclose(restored.covariance, original.covariance) + + +def test_pose_with_covariance_stamped_zero_timestamp(): + """Test that zero timestamp gets replaced with current time.""" + pose_cov_stamped = PoseWithCovarianceStamped(ts=0.0) + + # Should have been replaced with current time + assert pose_cov_stamped.ts > 0 + assert pose_cov_stamped.ts <= time.time() + + +def test_pose_with_covariance_stamped_inheritance(): + """Test that it properly inherits from PoseWithCovariance and Timestamped.""" + pose = Pose(1.0, 2.0, 3.0) + covariance = np.eye(6).flatten() + pose_cov_stamped = PoseWithCovarianceStamped( + ts=1234567890.0, frame_id="test", pose=pose, covariance=covariance + ) + + # Should be instance of parent classes + assert isinstance(pose_cov_stamped, PoseWithCovariance) + + # Should have Timestamped attributes + assert hasattr(pose_cov_stamped, "ts") + assert hasattr(pose_cov_stamped, "frame_id") + + # Should have PoseWithCovariance attributes + assert hasattr(pose_cov_stamped, "pose") + assert hasattr(pose_cov_stamped, "covariance") + + +def test_pose_with_covariance_stamped_sec_nsec(): + """Test the sec_nsec helper function.""" + from dimos.msgs.geometry_msgs.PoseWithCovarianceStamped import sec_nsec + + # Test integer seconds + s, ns = sec_nsec(1234567890.0) + assert s == 1234567890 + assert ns == 0 + + # Test fractional seconds + s, ns = sec_nsec(1234567890.123456789) + assert s == 1234567890 + assert abs(ns - 123456789) < 100 # Allow small rounding error + + # Test small fractional seconds + s, ns = sec_nsec(0.000000001) + assert s == 0 + assert ns == 1 + + # Test large timestamp + s, ns = sec_nsec(9999999999.999999999) + # Due to floating point precision, this might round to 10000000000 + assert s in [9999999999, 10000000000] + if s == 9999999999: + assert abs(ns - 999999999) < 10 + else: + assert ns == 0 + + +@pytest.mark.parametrize( + "frame_id", + ["", "map", "odom", "base_link", "camera_optical_frame", "sensor/lidar/front"], +) +def test_pose_with_covariance_stamped_frame_ids(frame_id): + """Test various frame ID values.""" + pose_cov_stamped = PoseWithCovarianceStamped(frame_id=frame_id) + assert pose_cov_stamped.frame_id == frame_id + + # Test roundtrip through ROS + ros_msg = pose_cov_stamped.to_ros_msg() + assert ros_msg.header.frame_id == frame_id + + restored = PoseWithCovarianceStamped.from_ros_msg(ros_msg) + assert restored.frame_id == frame_id + + +def test_pose_with_covariance_stamped_different_covariances(): + """Test with different covariance patterns.""" + pose = Pose(1.0, 2.0, 3.0) + + # Zero covariance + zero_cov = np.zeros(36) + pose_cov1 = PoseWithCovarianceStamped(pose=pose, covariance=zero_cov) + assert np.all(pose_cov1.covariance == 0.0) + + # Identity covariance + identity_cov = np.eye(6).flatten() + pose_cov2 = PoseWithCovarianceStamped(pose=pose, covariance=identity_cov) + assert np.trace(pose_cov2.covariance_matrix) == 6.0 + + # Full covariance + full_cov = np.random.rand(36) + pose_cov3 = PoseWithCovarianceStamped(pose=pose, covariance=full_cov) + assert np.array_equal(pose_cov3.covariance, full_cov) diff --git a/dimos/msgs/geometry_msgs/test_Transform.py b/dimos/msgs/geometry_msgs/test_Transform.py index e069305bca..d6c695c858 100644 --- a/dimos/msgs/geometry_msgs/test_Transform.py +++ b/dimos/msgs/geometry_msgs/test_Transform.py @@ -19,6 +19,7 @@ import pytest from dimos_lcm.geometry_msgs import Transform as LCMTransform from dimos_lcm.geometry_msgs import TransformStamped as LCMTransformStamped +from geometry_msgs.msg import TransformStamped as ROSTransformStamped from dimos.msgs.geometry_msgs import Pose, PoseStamped, Quaternion, Transform, Vector3 @@ -421,3 +422,83 @@ def test_transform_from_pose_invalid_type(): with pytest.raises(TypeError): Transform.from_pose(None) + + +def test_transform_from_ros_transform_stamped(): + """Test creating a Transform from a ROS TransformStamped message.""" + ros_msg = ROSTransformStamped() + ros_msg.header.frame_id = "world" + ros_msg.header.stamp.sec = 123 + ros_msg.header.stamp.nanosec = 456000000 + ros_msg.child_frame_id = "robot" + ros_msg.transform.translation.x = 1.0 + ros_msg.transform.translation.y = 2.0 + ros_msg.transform.translation.z = 3.0 + ros_msg.transform.rotation.x = 0.1 + ros_msg.transform.rotation.y = 0.2 + ros_msg.transform.rotation.z = 0.3 + ros_msg.transform.rotation.w = 0.9 + + transform = Transform.from_ros_transform_stamped(ros_msg) + + assert transform.frame_id == "world" + assert transform.child_frame_id == "robot" + assert transform.ts == 123.456 + assert transform.translation.x == 1.0 + assert transform.translation.y == 2.0 + assert transform.translation.z == 3.0 + assert transform.rotation.x == 0.1 + assert transform.rotation.y == 0.2 + assert transform.rotation.z == 0.3 + assert transform.rotation.w == 0.9 + + +def test_transform_to_ros_transform_stamped(): + """Test converting a Transform to a ROS TransformStamped message.""" + transform = Transform( + translation=Vector3(4.0, 5.0, 6.0), + rotation=Quaternion(0.15, 0.25, 0.35, 0.85), + frame_id="base_link", + child_frame_id="sensor", + ts=124.789, + ) + + ros_msg = transform.to_ros_transform_stamped() + + assert isinstance(ros_msg, ROSTransformStamped) + assert ros_msg.header.frame_id == "base_link" + assert ros_msg.child_frame_id == "sensor" + assert ros_msg.header.stamp.sec == 124 + assert ros_msg.header.stamp.nanosec == 789000000 + assert ros_msg.transform.translation.x == 4.0 + assert ros_msg.transform.translation.y == 5.0 + assert ros_msg.transform.translation.z == 6.0 + assert ros_msg.transform.rotation.x == 0.15 + assert ros_msg.transform.rotation.y == 0.25 + assert ros_msg.transform.rotation.z == 0.35 + assert ros_msg.transform.rotation.w == 0.85 + + +def test_transform_ros_roundtrip(): + """Test round-trip conversion between Transform and ROS TransformStamped.""" + original = Transform( + translation=Vector3(7.5, 8.5, 9.5), + rotation=Quaternion(0.0, 0.0, 0.383, 0.924), # ~45 degrees around Z + frame_id="odom", + child_frame_id="base_footprint", + ts=99.123, + ) + + ros_msg = original.to_ros_transform_stamped() + restored = Transform.from_ros_transform_stamped(ros_msg) + + assert restored.frame_id == original.frame_id + assert restored.child_frame_id == original.child_frame_id + assert restored.ts == original.ts + assert restored.translation.x == original.translation.x + assert restored.translation.y == original.translation.y + assert restored.translation.z == original.translation.z + assert restored.rotation.x == original.rotation.x + assert restored.rotation.y == original.rotation.y + assert restored.rotation.z == original.rotation.z + assert restored.rotation.w == original.rotation.w diff --git a/dimos/msgs/geometry_msgs/test_Twist.py b/dimos/msgs/geometry_msgs/test_Twist.py index 2e57523826..6cf3fe0f03 100644 --- a/dimos/msgs/geometry_msgs/test_Twist.py +++ b/dimos/msgs/geometry_msgs/test_Twist.py @@ -15,6 +15,8 @@ import numpy as np import pytest from dimos_lcm.geometry_msgs import Twist as LCMTwist +from geometry_msgs.msg import Twist as ROSTwist +from geometry_msgs.msg import Vector3 as ROSVector3 from dimos.msgs.geometry_msgs import Quaternion, Twist, Vector3 @@ -198,3 +200,92 @@ def test_twist_with_lists(): tw2 = Twist(linear=np.array([4, 5, 6]), angular=np.array([0.4, 0.5, 0.6])) assert tw2.linear == Vector3(4, 5, 6) assert tw2.angular == Vector3(0.4, 0.5, 0.6) + + +def test_twist_from_ros_msg(): + """Test Twist.from_ros_msg conversion.""" + # Create ROS message + ros_msg = ROSTwist() + ros_msg.linear = ROSVector3(x=10.0, y=20.0, z=30.0) + ros_msg.angular = ROSVector3(x=1.0, y=2.0, z=3.0) + + # Convert to LCM + lcm_msg = Twist.from_ros_msg(ros_msg) + + assert isinstance(lcm_msg, Twist) + assert lcm_msg.linear.x == 10.0 + assert lcm_msg.linear.y == 20.0 + assert lcm_msg.linear.z == 30.0 + assert lcm_msg.angular.x == 1.0 + assert lcm_msg.angular.y == 2.0 + assert lcm_msg.angular.z == 3.0 + + +def test_twist_to_ros_msg(): + """Test Twist.to_ros_msg conversion.""" + # Create LCM message + lcm_msg = Twist(linear=Vector3(40.0, 50.0, 60.0), angular=Vector3(4.0, 5.0, 6.0)) + + # Convert to ROS + ros_msg = lcm_msg.to_ros_msg() + + assert isinstance(ros_msg, ROSTwist) + assert ros_msg.linear.x == 40.0 + assert ros_msg.linear.y == 50.0 + assert ros_msg.linear.z == 60.0 + assert ros_msg.angular.x == 4.0 + assert ros_msg.angular.y == 5.0 + assert ros_msg.angular.z == 6.0 + + +def test_ros_zero_twist_conversion(): + """Test conversion of zero twist messages between ROS and LCM.""" + # Test ROS to LCM with zero twist + ros_zero = ROSTwist() + lcm_zero = Twist.from_ros_msg(ros_zero) + assert lcm_zero.is_zero() + + # Test LCM to ROS with zero twist + lcm_zero2 = Twist.zero() + ros_zero2 = lcm_zero2.to_ros_msg() + assert ros_zero2.linear.x == 0.0 + assert ros_zero2.linear.y == 0.0 + assert ros_zero2.linear.z == 0.0 + assert ros_zero2.angular.x == 0.0 + assert ros_zero2.angular.y == 0.0 + assert ros_zero2.angular.z == 0.0 + + +def test_ros_negative_values_conversion(): + """Test ROS conversion with negative values.""" + # Create ROS message with negative values + ros_msg = ROSTwist() + ros_msg.linear = ROSVector3(x=-1.5, y=-2.5, z=-3.5) + ros_msg.angular = ROSVector3(x=-0.1, y=-0.2, z=-0.3) + + # Convert to LCM and back + lcm_msg = Twist.from_ros_msg(ros_msg) + ros_msg2 = lcm_msg.to_ros_msg() + + assert ros_msg2.linear.x == -1.5 + assert ros_msg2.linear.y == -2.5 + assert ros_msg2.linear.z == -3.5 + assert ros_msg2.angular.x == -0.1 + assert ros_msg2.angular.y == -0.2 + assert ros_msg2.angular.z == -0.3 + + +def test_ros_roundtrip_conversion(): + """Test round-trip conversion maintains data integrity.""" + # LCM -> ROS -> LCM + original_lcm = Twist(linear=Vector3(1.234, 5.678, 9.012), angular=Vector3(0.111, 0.222, 0.333)) + ros_intermediate = original_lcm.to_ros_msg() + final_lcm = Twist.from_ros_msg(ros_intermediate) + + assert final_lcm == original_lcm + assert final_lcm.linear.x == 1.234 + assert final_lcm.linear.y == 5.678 + assert final_lcm.linear.z == 9.012 + assert final_lcm.angular.x == 0.111 + assert final_lcm.angular.y == 0.222 + assert final_lcm.angular.z == 0.333 diff --git a/dimos/msgs/geometry_msgs/test_TwistStamped.py b/dimos/msgs/geometry_msgs/test_TwistStamped.py new file mode 100644 index 0000000000..c84cba8cf2 --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_TwistStamped.py @@ -0,0 +1,151 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pickle +import time + +from geometry_msgs.msg import TwistStamped as ROSTwistStamped + +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped + + +def test_lcm_encode_decode(): + """Test encoding and decoding of TwistStamped to/from binary LCM format.""" + + twist_source = TwistStamped( + ts=time.time(), + linear=(1.0, 2.0, 3.0), + angular=(0.1, 0.2, 0.3), + ) + binary_msg = twist_source.lcm_encode() + twist_dest = TwistStamped.lcm_decode(binary_msg) + + assert isinstance(twist_dest, TwistStamped) + assert twist_dest is not twist_source + + print(twist_source.linear) + print(twist_source.angular) + + print(twist_dest.linear) + print(twist_dest.angular) + assert twist_dest == twist_source + + +def test_pickle_encode_decode(): + """Test encoding and decoding of TwistStamped to/from binary pickle format.""" + + twist_source = TwistStamped( + ts=time.time(), + linear=(1.0, 2.0, 3.0), + angular=(0.1, 0.2, 0.3), + ) + binary_msg = pickle.dumps(twist_source) + twist_dest = pickle.loads(binary_msg) + assert isinstance(twist_dest, TwistStamped) + assert twist_dest is not twist_source + assert twist_dest == twist_source + + +def test_twist_stamped_from_ros_msg(): + """Test creating a TwistStamped from a ROS TwistStamped message.""" + ros_msg = ROSTwistStamped() + ros_msg.header.frame_id = "world" + ros_msg.header.stamp.sec = 123 + ros_msg.header.stamp.nanosec = 456000000 + ros_msg.twist.linear.x = 1.0 + ros_msg.twist.linear.y = 2.0 + ros_msg.twist.linear.z = 3.0 + ros_msg.twist.angular.x = 0.1 + ros_msg.twist.angular.y = 0.2 + ros_msg.twist.angular.z = 0.3 + + twist_stamped = TwistStamped.from_ros_msg(ros_msg) + + assert twist_stamped.frame_id == "world" + assert twist_stamped.ts == 123.456 + assert twist_stamped.linear.x == 1.0 + assert twist_stamped.linear.y == 2.0 + assert twist_stamped.linear.z == 3.0 + assert twist_stamped.angular.x == 0.1 + assert twist_stamped.angular.y == 0.2 + assert twist_stamped.angular.z == 0.3 + + +def test_twist_stamped_to_ros_msg(): + """Test converting a TwistStamped to a ROS TwistStamped message.""" + twist_stamped = TwistStamped( + ts=123.456, + frame_id="base_link", + linear=(1.0, 2.0, 3.0), + angular=(0.1, 0.2, 0.3), + ) + + ros_msg = twist_stamped.to_ros_msg() + + assert isinstance(ros_msg, ROSTwistStamped) + assert ros_msg.header.frame_id == "base_link" + assert ros_msg.header.stamp.sec == 123 + assert ros_msg.header.stamp.nanosec == 456000000 + assert ros_msg.twist.linear.x == 1.0 + assert ros_msg.twist.linear.y == 2.0 + assert ros_msg.twist.linear.z == 3.0 + assert ros_msg.twist.angular.x == 0.1 + assert ros_msg.twist.angular.y == 0.2 + assert ros_msg.twist.angular.z == 0.3 + + +def test_twist_stamped_ros_roundtrip(): + """Test round-trip conversion between TwistStamped and ROS TwistStamped.""" + original = TwistStamped( + ts=123.789, + frame_id="odom", + linear=(1.5, 2.5, 3.5), + angular=(0.15, 0.25, 0.35), + ) + + ros_msg = original.to_ros_msg() + restored = TwistStamped.from_ros_msg(ros_msg) + + assert restored.frame_id == original.frame_id + assert restored.ts == original.ts + assert restored.linear.x == original.linear.x + assert restored.linear.y == original.linear.y + assert restored.linear.z == original.linear.z + assert restored.angular.x == original.angular.x + assert restored.angular.y == original.angular.y + assert restored.angular.z == original.angular.z + + +if __name__ == "__main__": + print("Running test_lcm_encode_decode...") + test_lcm_encode_decode() + print("✓ test_lcm_encode_decode passed") + + print("Running test_pickle_encode_decode...") + test_pickle_encode_decode() + print("✓ test_pickle_encode_decode passed") + + print("Running test_twist_stamped_from_ros_msg...") + test_twist_stamped_from_ros_msg() + print("✓ test_twist_stamped_from_ros_msg passed") + + print("Running test_twist_stamped_to_ros_msg...") + test_twist_stamped_to_ros_msg() + print("✓ test_twist_stamped_to_ros_msg passed") + + print("Running test_twist_stamped_ros_roundtrip...") + test_twist_stamped_ros_roundtrip() + print("✓ test_twist_stamped_ros_roundtrip passed") + + print("\nAll tests passed!") diff --git a/dimos/msgs/geometry_msgs/test_TwistWithCovariance.py b/dimos/msgs/geometry_msgs/test_TwistWithCovariance.py new file mode 100644 index 0000000000..0a2dbdff06 --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_TwistWithCovariance.py @@ -0,0 +1,407 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +from dimos_lcm.geometry_msgs import TwistWithCovariance as LCMTwistWithCovariance +from geometry_msgs.msg import TwistWithCovariance as ROSTwistWithCovariance +from geometry_msgs.msg import Twist as ROSTwist +from geometry_msgs.msg import Vector3 as ROSVector3 + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + +def test_twist_with_covariance_default_init(): + """Test that default initialization creates a zero twist with zero covariance.""" + twist_cov = TwistWithCovariance() + + # Twist should be zero + assert twist_cov.twist.linear.x == 0.0 + assert twist_cov.twist.linear.y == 0.0 + assert twist_cov.twist.linear.z == 0.0 + assert twist_cov.twist.angular.x == 0.0 + assert twist_cov.twist.angular.y == 0.0 + assert twist_cov.twist.angular.z == 0.0 + + # Covariance should be all zeros + assert np.all(twist_cov.covariance == 0.0) + assert twist_cov.covariance.shape == (36,) + + +def test_twist_with_covariance_twist_init(): + """Test initialization with a Twist object.""" + linear = Vector3(1.0, 2.0, 3.0) + angular = Vector3(0.1, 0.2, 0.3) + twist = Twist(linear, angular) + twist_cov = TwistWithCovariance(twist) + + # Twist should match + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert twist_cov.twist.angular.x == 0.1 + assert twist_cov.twist.angular.y == 0.2 + assert twist_cov.twist.angular.z == 0.3 + + # Covariance should be zeros by default + assert np.all(twist_cov.covariance == 0.0) + + +def test_twist_with_covariance_twist_and_covariance_init(): + """Test initialization with twist and covariance.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + twist_cov = TwistWithCovariance(twist, covariance) + + # Twist should match + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + + # Covariance should match + assert np.array_equal(twist_cov.covariance, covariance) + + +def test_twist_with_covariance_tuple_init(): + """Test initialization with tuple of (linear, angular) velocities.""" + linear = [1.0, 2.0, 3.0] + angular = [0.1, 0.2, 0.3] + covariance = np.arange(36, dtype=float) + twist_cov = TwistWithCovariance((linear, angular), covariance) + + # Twist should match + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert twist_cov.twist.angular.x == 0.1 + assert twist_cov.twist.angular.y == 0.2 + assert twist_cov.twist.angular.z == 0.3 + + # Covariance should match + assert np.array_equal(twist_cov.covariance, covariance) + + +def test_twist_with_covariance_list_covariance(): + """Test initialization with covariance as a list.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance_list = list(range(36)) + twist_cov = TwistWithCovariance(twist, covariance_list) + + # Covariance should be converted to numpy array + assert isinstance(twist_cov.covariance, np.ndarray) + assert np.array_equal(twist_cov.covariance, np.array(covariance_list)) + + +def test_twist_with_covariance_copy_init(): + """Test copy constructor.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + original = TwistWithCovariance(twist, covariance) + copy = TwistWithCovariance(original) + + # Should be equal but not the same object + assert copy == original + assert copy is not original + assert copy.twist is not original.twist + assert copy.covariance is not original.covariance + + # Modify original to ensure they're independent + original.covariance[0] = 999.0 + assert copy.covariance[0] != 999.0 + + +def test_twist_with_covariance_lcm_init(): + """Test initialization from LCM message.""" + lcm_msg = LCMTwistWithCovariance() + lcm_msg.twist.linear.x = 1.0 + lcm_msg.twist.linear.y = 2.0 + lcm_msg.twist.linear.z = 3.0 + lcm_msg.twist.angular.x = 0.1 + lcm_msg.twist.angular.y = 0.2 + lcm_msg.twist.angular.z = 0.3 + lcm_msg.covariance = list(range(36)) + + twist_cov = TwistWithCovariance(lcm_msg) + + # Twist should match + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert twist_cov.twist.angular.x == 0.1 + assert twist_cov.twist.angular.y == 0.2 + assert twist_cov.twist.angular.z == 0.3 + + # Covariance should match + assert np.array_equal(twist_cov.covariance, np.arange(36)) + + +def test_twist_with_covariance_dict_init(): + """Test initialization from dictionary.""" + twist_dict = { + "twist": Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)), + "covariance": list(range(36)), + } + twist_cov = TwistWithCovariance(twist_dict) + + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert np.array_equal(twist_cov.covariance, np.arange(36)) + + +def test_twist_with_covariance_dict_init_no_covariance(): + """Test initialization from dictionary without covariance.""" + twist_dict = {"twist": Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3))} + twist_cov = TwistWithCovariance(twist_dict) + + assert twist_cov.twist.linear.x == 1.0 + assert np.all(twist_cov.covariance == 0.0) + + +def test_twist_with_covariance_tuple_of_tuple_init(): + """Test initialization from tuple of (twist_tuple, covariance).""" + twist_tuple = ([1.0, 2.0, 3.0], [0.1, 0.2, 0.3]) + covariance = np.arange(36, dtype=float) + twist_cov = TwistWithCovariance((twist_tuple, covariance)) + + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert twist_cov.twist.angular.x == 0.1 + assert twist_cov.twist.angular.y == 0.2 + assert twist_cov.twist.angular.z == 0.3 + assert np.array_equal(twist_cov.covariance, covariance) + + +def test_twist_with_covariance_properties(): + """Test convenience properties.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + twist_cov = TwistWithCovariance(twist) + + # Linear and angular properties + assert twist_cov.linear.x == 1.0 + assert twist_cov.linear.y == 2.0 + assert twist_cov.linear.z == 3.0 + assert twist_cov.angular.x == 0.1 + assert twist_cov.angular.y == 0.2 + assert twist_cov.angular.z == 0.3 + + +def test_twist_with_covariance_matrix_property(): + """Test covariance matrix property.""" + twist = Twist() + covariance_array = np.arange(36, dtype=float) + twist_cov = TwistWithCovariance(twist, covariance_array) + + # Get as matrix + cov_matrix = twist_cov.covariance_matrix + assert cov_matrix.shape == (6, 6) + assert cov_matrix[0, 0] == 0.0 + assert cov_matrix[5, 5] == 35.0 + + # Set from matrix + new_matrix = np.eye(6) * 2.0 + twist_cov.covariance_matrix = new_matrix + assert np.array_equal(twist_cov.covariance[:6], [2.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + + +def test_twist_with_covariance_repr(): + """Test string representation.""" + twist = Twist(Vector3(1.234, 2.567, 3.891), Vector3(0.1, 0.2, 0.3)) + twist_cov = TwistWithCovariance(twist) + + repr_str = repr(twist_cov) + assert "TwistWithCovariance" in repr_str + assert "twist=" in repr_str + assert "covariance=" in repr_str + assert "36 elements" in repr_str + + +def test_twist_with_covariance_str(): + """Test string formatting.""" + twist = Twist(Vector3(1.234, 2.567, 3.891), Vector3(0.1, 0.2, 0.3)) + covariance = np.eye(6).flatten() + twist_cov = TwistWithCovariance(twist, covariance) + + str_repr = str(twist_cov) + assert "TwistWithCovariance" in str_repr + assert "1.234" in str_repr + assert "2.567" in str_repr + assert "3.891" in str_repr + assert "cov_trace" in str_repr + assert "6.000" in str_repr # Trace of identity matrix is 6 + + +def test_twist_with_covariance_equality(): + """Test equality comparison.""" + twist1 = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + cov1 = np.arange(36, dtype=float) + twist_cov1 = TwistWithCovariance(twist1, cov1) + + twist2 = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + cov2 = np.arange(36, dtype=float) + twist_cov2 = TwistWithCovariance(twist2, cov2) + + # Equal + assert twist_cov1 == twist_cov2 + + # Different twist + twist3 = Twist(Vector3(1.1, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + twist_cov3 = TwistWithCovariance(twist3, cov1) + assert twist_cov1 != twist_cov3 + + # Different covariance + cov3 = np.arange(36, dtype=float) + 1 + twist_cov4 = TwistWithCovariance(twist1, cov3) + assert twist_cov1 != twist_cov4 + + # Different type + assert twist_cov1 != "not a twist" + assert twist_cov1 != None + + +def test_twist_with_covariance_is_zero(): + """Test is_zero method.""" + # Zero twist + twist_cov1 = TwistWithCovariance() + assert twist_cov1.is_zero() + assert not twist_cov1 # Boolean conversion + + # Non-zero twist + twist = Twist(Vector3(1.0, 0.0, 0.0), Vector3(0.0, 0.0, 0.0)) + twist_cov2 = TwistWithCovariance(twist) + assert not twist_cov2.is_zero() + assert twist_cov2 # Boolean conversion + + +def test_twist_with_covariance_lcm_encode_decode(): + """Test LCM encoding and decoding.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + source = TwistWithCovariance(twist, covariance) + + # Encode and decode + binary_msg = source.lcm_encode() + decoded = TwistWithCovariance.lcm_decode(binary_msg) + + # Should be equal + assert decoded == source + assert isinstance(decoded, TwistWithCovariance) + assert isinstance(decoded.twist, Twist) + assert isinstance(decoded.covariance, np.ndarray) + + +def test_twist_with_covariance_from_ros_msg(): + """Test creating from ROS message.""" + ros_msg = ROSTwistWithCovariance() + ros_msg.twist.linear = ROSVector3(x=1.0, y=2.0, z=3.0) + ros_msg.twist.angular = ROSVector3(x=0.1, y=0.2, z=0.3) + ros_msg.covariance = [float(i) for i in range(36)] + + twist_cov = TwistWithCovariance.from_ros_msg(ros_msg) + + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert twist_cov.twist.angular.x == 0.1 + assert twist_cov.twist.angular.y == 0.2 + assert twist_cov.twist.angular.z == 0.3 + assert np.array_equal(twist_cov.covariance, np.arange(36)) + + +def test_twist_with_covariance_to_ros_msg(): + """Test converting to ROS message.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + twist_cov = TwistWithCovariance(twist, covariance) + + ros_msg = twist_cov.to_ros_msg() + + assert isinstance(ros_msg, ROSTwistWithCovariance) + assert ros_msg.twist.linear.x == 1.0 + assert ros_msg.twist.linear.y == 2.0 + assert ros_msg.twist.linear.z == 3.0 + assert ros_msg.twist.angular.x == 0.1 + assert ros_msg.twist.angular.y == 0.2 + assert ros_msg.twist.angular.z == 0.3 + assert list(ros_msg.covariance) == list(range(36)) + + +def test_twist_with_covariance_ros_roundtrip(): + """Test round-trip conversion with ROS messages.""" + twist = Twist(Vector3(1.5, 2.5, 3.5), Vector3(0.15, 0.25, 0.35)) + covariance = np.random.rand(36) + original = TwistWithCovariance(twist, covariance) + + ros_msg = original.to_ros_msg() + restored = TwistWithCovariance.from_ros_msg(ros_msg) + + assert restored == original + + +def test_twist_with_covariance_zero_covariance(): + """Test with zero covariance matrix.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + twist_cov = TwistWithCovariance(twist) + + assert np.all(twist_cov.covariance == 0.0) + assert np.trace(twist_cov.covariance_matrix) == 0.0 + + +def test_twist_with_covariance_diagonal_covariance(): + """Test with diagonal covariance matrix.""" + twist = Twist() + covariance = np.zeros(36) + # Set diagonal elements + for i in range(6): + covariance[i * 6 + i] = i + 1 + + twist_cov = TwistWithCovariance(twist, covariance) + + cov_matrix = twist_cov.covariance_matrix + assert np.trace(cov_matrix) == sum(range(1, 7)) # 1+2+3+4+5+6 = 21 + + # Check diagonal elements + for i in range(6): + assert cov_matrix[i, i] == i + 1 + + # Check off-diagonal elements are zero + for i in range(6): + for j in range(6): + if i != j: + assert cov_matrix[i, j] == 0.0 + + +@pytest.mark.parametrize( + "linear,angular", + [ + ([0.0, 0.0, 0.0], [0.0, 0.0, 0.0]), + ([1.0, 2.0, 3.0], [0.1, 0.2, 0.3]), + ([-1.0, -2.0, -3.0], [-0.1, -0.2, -0.3]), + ([100.0, -100.0, 0.0], [3.14, -3.14, 0.0]), + ], +) +def test_twist_with_covariance_parametrized_velocities(linear, angular): + """Parametrized test for various velocity values.""" + twist = Twist(linear, angular) + twist_cov = TwistWithCovariance(twist) + + assert twist_cov.linear.x == linear[0] + assert twist_cov.linear.y == linear[1] + assert twist_cov.linear.z == linear[2] + assert twist_cov.angular.x == angular[0] + assert twist_cov.angular.y == angular[1] + assert twist_cov.angular.z == angular[2] diff --git a/dimos/msgs/geometry_msgs/test_TwistWithCovarianceStamped.py b/dimos/msgs/geometry_msgs/test_TwistWithCovarianceStamped.py new file mode 100644 index 0000000000..7abc3f689b --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_TwistWithCovarianceStamped.py @@ -0,0 +1,367 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import numpy as np +import pytest +from dimos_lcm.geometry_msgs import TwistWithCovarianceStamped as LCMTwistWithCovarianceStamped +from dimos_lcm.std_msgs import Header as LCMHeader +from dimos_lcm.std_msgs import Time as LCMTime +from geometry_msgs.msg import TwistWithCovarianceStamped as ROSTwistWithCovarianceStamped +from geometry_msgs.msg import TwistWithCovariance as ROSTwistWithCovariance +from geometry_msgs.msg import Twist as ROSTwist +from geometry_msgs.msg import Vector3 as ROSVector3 +from std_msgs.msg import Header as ROSHeader +from builtin_interfaces.msg import Time as ROSTime + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance +from dimos.msgs.geometry_msgs.TwistWithCovarianceStamped import TwistWithCovarianceStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + +def test_twist_with_covariance_stamped_default_init(): + """Test default initialization.""" + twist_cov_stamped = TwistWithCovarianceStamped() + + # Should have current timestamp + assert twist_cov_stamped.ts > 0 + assert twist_cov_stamped.frame_id == "" + + # Twist should be zero + assert twist_cov_stamped.twist.linear.x == 0.0 + assert twist_cov_stamped.twist.linear.y == 0.0 + assert twist_cov_stamped.twist.linear.z == 0.0 + assert twist_cov_stamped.twist.angular.x == 0.0 + assert twist_cov_stamped.twist.angular.y == 0.0 + assert twist_cov_stamped.twist.angular.z == 0.0 + + # Covariance should be all zeros + assert np.all(twist_cov_stamped.covariance == 0.0) + + +def test_twist_with_covariance_stamped_with_timestamp(): + """Test initialization with specific timestamp.""" + ts = 1234567890.123456 + frame_id = "base_link" + twist_cov_stamped = TwistWithCovarianceStamped(ts=ts, frame_id=frame_id) + + assert twist_cov_stamped.ts == ts + assert twist_cov_stamped.frame_id == frame_id + + +def test_twist_with_covariance_stamped_with_twist(): + """Test initialization with twist.""" + ts = 1234567890.123456 + frame_id = "odom" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + + twist_cov_stamped = TwistWithCovarianceStamped( + ts=ts, frame_id=frame_id, twist=twist, covariance=covariance + ) + + assert twist_cov_stamped.ts == ts + assert twist_cov_stamped.frame_id == frame_id + assert twist_cov_stamped.twist.linear.x == 1.0 + assert twist_cov_stamped.twist.linear.y == 2.0 + assert twist_cov_stamped.twist.linear.z == 3.0 + assert np.array_equal(twist_cov_stamped.covariance, covariance) + + +def test_twist_with_covariance_stamped_with_tuple(): + """Test initialization with tuple of velocities.""" + ts = 1234567890.123456 + frame_id = "robot_base" + linear = [1.0, 2.0, 3.0] + angular = [0.1, 0.2, 0.3] + covariance = np.arange(36, dtype=float) + + twist_cov_stamped = TwistWithCovarianceStamped( + ts=ts, frame_id=frame_id, twist=(linear, angular), covariance=covariance + ) + + assert twist_cov_stamped.ts == ts + assert twist_cov_stamped.frame_id == frame_id + assert twist_cov_stamped.twist.linear.x == 1.0 + assert twist_cov_stamped.twist.angular.x == 0.1 + assert np.array_equal(twist_cov_stamped.covariance, covariance) + + +def test_twist_with_covariance_stamped_properties(): + """Test convenience properties.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.eye(6).flatten() + twist_cov_stamped = TwistWithCovarianceStamped( + ts=1234567890.0, frame_id="cmd_vel", twist=twist, covariance=covariance + ) + + # Linear and angular properties + assert twist_cov_stamped.linear.x == 1.0 + assert twist_cov_stamped.linear.y == 2.0 + assert twist_cov_stamped.linear.z == 3.0 + assert twist_cov_stamped.angular.x == 0.1 + assert twist_cov_stamped.angular.y == 0.2 + assert twist_cov_stamped.angular.z == 0.3 + + # Covariance matrix + cov_matrix = twist_cov_stamped.covariance_matrix + assert cov_matrix.shape == (6, 6) + assert np.trace(cov_matrix) == 6.0 + + +def test_twist_with_covariance_stamped_str(): + """Test string representation.""" + twist = Twist(Vector3(1.234, 2.567, 3.891), Vector3(0.111, 0.222, 0.333)) + covariance = np.eye(6).flatten() * 2.0 + twist_cov_stamped = TwistWithCovarianceStamped( + ts=1234567890.0, frame_id="world", twist=twist, covariance=covariance + ) + + str_repr = str(twist_cov_stamped) + assert "TwistWithCovarianceStamped" in str_repr + assert "1.234" in str_repr + assert "2.567" in str_repr + assert "3.891" in str_repr + assert "cov_trace" in str_repr + assert "12.000" in str_repr # Trace of 2*identity is 12 + + +def test_twist_with_covariance_stamped_lcm_encode_decode(): + """Test LCM encoding and decoding.""" + ts = 1234567890.123456 + frame_id = "camera_link" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + + source = TwistWithCovarianceStamped( + ts=ts, frame_id=frame_id, twist=twist, covariance=covariance + ) + + # Encode and decode + binary_msg = source.lcm_encode() + decoded = TwistWithCovarianceStamped.lcm_decode(binary_msg) + + # Check timestamp (may lose some precision) + assert abs(decoded.ts - ts) < 1e-6 + assert decoded.frame_id == frame_id + + # Check twist + assert decoded.twist.linear.x == 1.0 + assert decoded.twist.linear.y == 2.0 + assert decoded.twist.linear.z == 3.0 + assert decoded.twist.angular.x == 0.1 + assert decoded.twist.angular.y == 0.2 + assert decoded.twist.angular.z == 0.3 + + # Check covariance + assert np.array_equal(decoded.covariance, covariance) + + +def test_twist_with_covariance_stamped_from_ros_msg(): + """Test creating from ROS message.""" + ros_msg = ROSTwistWithCovarianceStamped() + + # Set header + ros_msg.header = ROSHeader() + ros_msg.header.stamp = ROSTime() + ros_msg.header.stamp.sec = 1234567890 + ros_msg.header.stamp.nanosec = 123456000 + ros_msg.header.frame_id = "laser" + + # Set twist with covariance + ros_msg.twist = ROSTwistWithCovariance() + ros_msg.twist.twist = ROSTwist() + ros_msg.twist.twist.linear = ROSVector3(x=1.0, y=2.0, z=3.0) + ros_msg.twist.twist.angular = ROSVector3(x=0.1, y=0.2, z=0.3) + ros_msg.twist.covariance = [float(i) for i in range(36)] + + twist_cov_stamped = TwistWithCovarianceStamped.from_ros_msg(ros_msg) + + assert twist_cov_stamped.ts == 1234567890.123456 + assert twist_cov_stamped.frame_id == "laser" + assert twist_cov_stamped.twist.linear.x == 1.0 + assert twist_cov_stamped.twist.linear.y == 2.0 + assert twist_cov_stamped.twist.linear.z == 3.0 + assert twist_cov_stamped.twist.angular.x == 0.1 + assert twist_cov_stamped.twist.angular.y == 0.2 + assert twist_cov_stamped.twist.angular.z == 0.3 + assert np.array_equal(twist_cov_stamped.covariance, np.arange(36)) + + +def test_twist_with_covariance_stamped_to_ros_msg(): + """Test converting to ROS message.""" + ts = 1234567890.567890 + frame_id = "imu" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + + twist_cov_stamped = TwistWithCovarianceStamped( + ts=ts, frame_id=frame_id, twist=twist, covariance=covariance + ) + + ros_msg = twist_cov_stamped.to_ros_msg() + + assert isinstance(ros_msg, ROSTwistWithCovarianceStamped) + assert ros_msg.header.frame_id == frame_id + assert ros_msg.header.stamp.sec == 1234567890 + assert abs(ros_msg.header.stamp.nanosec - 567890000) < 100 # Allow small rounding error + + assert ros_msg.twist.twist.linear.x == 1.0 + assert ros_msg.twist.twist.linear.y == 2.0 + assert ros_msg.twist.twist.linear.z == 3.0 + assert ros_msg.twist.twist.angular.x == 0.1 + assert ros_msg.twist.twist.angular.y == 0.2 + assert ros_msg.twist.twist.angular.z == 0.3 + assert list(ros_msg.twist.covariance) == list(range(36)) + + +def test_twist_with_covariance_stamped_ros_roundtrip(): + """Test round-trip conversion with ROS messages.""" + ts = 2147483647.987654 # Max int32 value for ROS Time.sec + frame_id = "robot_base" + twist = Twist(Vector3(1.5, 2.5, 3.5), Vector3(0.15, 0.25, 0.35)) + covariance = np.random.rand(36) + + original = TwistWithCovarianceStamped( + ts=ts, frame_id=frame_id, twist=twist, covariance=covariance + ) + + ros_msg = original.to_ros_msg() + restored = TwistWithCovarianceStamped.from_ros_msg(ros_msg) + + # Check timestamp (loses some precision in conversion) + assert abs(restored.ts - ts) < 1e-6 + assert restored.frame_id == frame_id + + # Check twist + assert restored.twist.linear.x == original.twist.linear.x + assert restored.twist.linear.y == original.twist.linear.y + assert restored.twist.linear.z == original.twist.linear.z + assert restored.twist.angular.x == original.twist.angular.x + assert restored.twist.angular.y == original.twist.angular.y + assert restored.twist.angular.z == original.twist.angular.z + + # Check covariance + assert np.allclose(restored.covariance, original.covariance) + + +def test_twist_with_covariance_stamped_zero_timestamp(): + """Test that zero timestamp gets replaced with current time.""" + twist_cov_stamped = TwistWithCovarianceStamped(ts=0.0) + + # Should have been replaced with current time + assert twist_cov_stamped.ts > 0 + assert twist_cov_stamped.ts <= time.time() + + +def test_twist_with_covariance_stamped_inheritance(): + """Test that it properly inherits from TwistWithCovariance and Timestamped.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.eye(6).flatten() + twist_cov_stamped = TwistWithCovarianceStamped( + ts=1234567890.0, frame_id="test", twist=twist, covariance=covariance + ) + + # Should be instance of parent classes + assert isinstance(twist_cov_stamped, TwistWithCovariance) + + # Should have Timestamped attributes + assert hasattr(twist_cov_stamped, "ts") + assert hasattr(twist_cov_stamped, "frame_id") + + # Should have TwistWithCovariance attributes + assert hasattr(twist_cov_stamped, "twist") + assert hasattr(twist_cov_stamped, "covariance") + + +def test_twist_with_covariance_stamped_is_zero(): + """Test is_zero method inheritance.""" + # Zero twist + twist_cov_stamped1 = TwistWithCovarianceStamped() + assert twist_cov_stamped1.is_zero() + assert not twist_cov_stamped1 # Boolean conversion + + # Non-zero twist + twist = Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.0)) + twist_cov_stamped2 = TwistWithCovarianceStamped(twist=twist) + assert not twist_cov_stamped2.is_zero() + assert twist_cov_stamped2 # Boolean conversion + + +def test_twist_with_covariance_stamped_sec_nsec(): + """Test the sec_nsec helper function.""" + from dimos.msgs.geometry_msgs.TwistWithCovarianceStamped import sec_nsec + + # Test integer seconds + s, ns = sec_nsec(1234567890.0) + assert s == 1234567890 + assert ns == 0 + + # Test fractional seconds + s, ns = sec_nsec(1234567890.123456789) + assert s == 1234567890 + assert abs(ns - 123456789) < 100 # Allow small rounding error + + # Test small fractional seconds + s, ns = sec_nsec(0.000000001) + assert s == 0 + assert ns == 1 + + # Test large timestamp + s, ns = sec_nsec(9999999999.999999999) + # Due to floating point precision, this might round to 10000000000 + assert s in [9999999999, 10000000000] + if s == 9999999999: + assert abs(ns - 999999999) < 10 + else: + assert ns == 0 + + +@pytest.mark.parametrize( + "frame_id", + ["", "map", "odom", "base_link", "cmd_vel", "sensor/velocity/front"], +) +def test_twist_with_covariance_stamped_frame_ids(frame_id): + """Test various frame ID values.""" + twist_cov_stamped = TwistWithCovarianceStamped(frame_id=frame_id) + assert twist_cov_stamped.frame_id == frame_id + + # Test roundtrip through ROS + ros_msg = twist_cov_stamped.to_ros_msg() + assert ros_msg.header.frame_id == frame_id + + restored = TwistWithCovarianceStamped.from_ros_msg(ros_msg) + assert restored.frame_id == frame_id + + +def test_twist_with_covariance_stamped_different_covariances(): + """Test with different covariance patterns.""" + twist = Twist(Vector3(1.0, 0.0, 0.0), Vector3(0.0, 0.0, 0.5)) + + # Zero covariance + zero_cov = np.zeros(36) + twist_cov1 = TwistWithCovarianceStamped(twist=twist, covariance=zero_cov) + assert np.all(twist_cov1.covariance == 0.0) + + # Identity covariance + identity_cov = np.eye(6).flatten() + twist_cov2 = TwistWithCovarianceStamped(twist=twist, covariance=identity_cov) + assert np.trace(twist_cov2.covariance_matrix) == 6.0 + + # Full covariance + full_cov = np.random.rand(36) + twist_cov3 = TwistWithCovarianceStamped(twist=twist, covariance=full_cov) + assert np.array_equal(twist_cov3.covariance, full_cov) diff --git a/dimos/msgs/nav_msgs/Odometry.py b/dimos/msgs/nav_msgs/Odometry.py new file mode 100644 index 0000000000..d5c875db20 --- /dev/null +++ b/dimos/msgs/nav_msgs/Odometry.py @@ -0,0 +1,373 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import TypeAlias + +import numpy as np +from dimos_lcm.nav_msgs import Odometry as LCMOdometry +from nav_msgs.msg import Odometry as ROSOdometry +from plum import dispatch + +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.types.timestamped import Timestamped + +# Types that can be converted to/from Odometry +OdometryConvertable: TypeAlias = ( + LCMOdometry | dict[str, float | str | PoseWithCovariance | TwistWithCovariance | Pose | Twist] +) + + +def sec_nsec(ts): + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class Odometry(LCMOdometry, Timestamped): + pose: PoseWithCovariance + twist: TwistWithCovariance + msg_name = "nav_msgs.Odometry" + ts: float + frame_id: str + child_frame_id: str + + @dispatch + def __init__( + self, + ts: float = 0.0, + frame_id: str = "", + child_frame_id: str = "", + pose: PoseWithCovariance | Pose | None = None, + twist: TwistWithCovariance | Twist | None = None, + ) -> None: + """Initialize with timestamp, frame IDs, pose and twist. + + Args: + ts: Timestamp in seconds (defaults to current time if 0) + frame_id: Reference frame ID (e.g., "odom", "map") + child_frame_id: Child frame ID (e.g., "base_link", "base_footprint") + pose: Pose with covariance (or just Pose, covariance will be zero) + twist: Twist with covariance (or just Twist, covariance will be zero) + """ + self.ts = ts if ts != 0 else time.time() + self.frame_id = frame_id + self.child_frame_id = child_frame_id + + # Handle pose + if pose is None: + self.pose = PoseWithCovariance() + elif isinstance(pose, PoseWithCovariance): + self.pose = pose + elif isinstance(pose, Pose): + self.pose = PoseWithCovariance(pose) + else: + self.pose = PoseWithCovariance(Pose(pose)) + + # Handle twist + if twist is None: + self.twist = TwistWithCovariance() + elif isinstance(twist, TwistWithCovariance): + self.twist = twist + elif isinstance(twist, Twist): + self.twist = TwistWithCovariance(twist) + else: + self.twist = TwistWithCovariance(Twist(twist)) + + @dispatch + def __init__(self, odometry: Odometry) -> None: + """Initialize from another Odometry (copy constructor).""" + self.ts = odometry.ts + self.frame_id = odometry.frame_id + self.child_frame_id = odometry.child_frame_id + self.pose = PoseWithCovariance(odometry.pose) + self.twist = TwistWithCovariance(odometry.twist) + + @dispatch + def __init__(self, lcm_odometry: LCMOdometry) -> None: + """Initialize from an LCM Odometry.""" + self.ts = lcm_odometry.header.stamp.sec + (lcm_odometry.header.stamp.nsec / 1_000_000_000) + self.frame_id = lcm_odometry.header.frame_id + self.child_frame_id = lcm_odometry.child_frame_id + self.pose = PoseWithCovariance(lcm_odometry.pose) + self.twist = TwistWithCovariance(lcm_odometry.twist) + + @dispatch + def __init__( + self, + odometry_dict: dict[ + str, float | str | PoseWithCovariance | TwistWithCovariance | Pose | Twist + ], + ) -> None: + """Initialize from a dictionary.""" + self.ts = odometry_dict.get("ts", odometry_dict.get("timestamp", time.time())) + self.frame_id = odometry_dict.get("frame_id", "") + self.child_frame_id = odometry_dict.get("child_frame_id", "") + + # Handle pose + pose = odometry_dict.get("pose") + if pose is None: + self.pose = PoseWithCovariance() + elif isinstance(pose, PoseWithCovariance): + self.pose = pose + elif isinstance(pose, Pose): + self.pose = PoseWithCovariance(pose) + else: + self.pose = PoseWithCovariance(Pose(pose)) + + # Handle twist + twist = odometry_dict.get("twist") + if twist is None: + self.twist = TwistWithCovariance() + elif isinstance(twist, TwistWithCovariance): + self.twist = twist + elif isinstance(twist, Twist): + self.twist = TwistWithCovariance(twist) + else: + self.twist = TwistWithCovariance(Twist(twist)) + + @property + def position(self) -> Vector3: + """Get position from pose.""" + return self.pose.position + + @property + def orientation(self): + """Get orientation from pose.""" + return self.pose.orientation + + @property + def linear_velocity(self) -> Vector3: + """Get linear velocity from twist.""" + return self.twist.linear + + @property + def angular_velocity(self) -> Vector3: + """Get angular velocity from twist.""" + return self.twist.angular + + @property + def x(self) -> float: + """X position.""" + return self.pose.x + + @property + def y(self) -> float: + """Y position.""" + return self.pose.y + + @property + def z(self) -> float: + """Z position.""" + return self.pose.z + + @property + def vx(self) -> float: + """Linear velocity in X.""" + return self.twist.linear.x + + @property + def vy(self) -> float: + """Linear velocity in Y.""" + return self.twist.linear.y + + @property + def vz(self) -> float: + """Linear velocity in Z.""" + return self.twist.linear.z + + @property + def wx(self) -> float: + """Angular velocity around X (roll rate).""" + return self.twist.angular.x + + @property + def wy(self) -> float: + """Angular velocity around Y (pitch rate).""" + return self.twist.angular.y + + @property + def wz(self) -> float: + """Angular velocity around Z (yaw rate).""" + return self.twist.angular.z + + @property + def roll(self) -> float: + """Roll angle in radians.""" + return self.pose.roll + + @property + def pitch(self) -> float: + """Pitch angle in radians.""" + return self.pose.pitch + + @property + def yaw(self) -> float: + """Yaw angle in radians.""" + return self.pose.yaw + + def __repr__(self) -> str: + return ( + f"Odometry(ts={self.ts:.6f}, frame_id='{self.frame_id}', " + f"child_frame_id='{self.child_frame_id}', pose={self.pose!r}, twist={self.twist!r})" + ) + + def __str__(self) -> str: + return ( + f"Odometry:\n" + f" Timestamp: {self.ts:.6f}\n" + f" Frame: {self.frame_id} -> {self.child_frame_id}\n" + f" Position: [{self.x:.3f}, {self.y:.3f}, {self.z:.3f}]\n" + f" Orientation: [roll={self.roll:.3f}, pitch={self.pitch:.3f}, yaw={self.yaw:.3f}]\n" + f" Linear Velocity: [{self.vx:.3f}, {self.vy:.3f}, {self.vz:.3f}]\n" + f" Angular Velocity: [{self.wx:.3f}, {self.wy:.3f}, {self.wz:.3f}]" + ) + + def __eq__(self, other) -> bool: + """Check if two Odometry messages are equal.""" + if not isinstance(other, Odometry): + return False + return ( + abs(self.ts - other.ts) < 1e-6 + and self.frame_id == other.frame_id + and self.child_frame_id == other.child_frame_id + and self.pose == other.pose + and self.twist == other.twist + ) + + def lcm_encode(self) -> bytes: + """Encode to LCM binary format.""" + lcm_msg = LCMOdometry() + + # Set header + [lcm_msg.header.stamp.sec, lcm_msg.header.stamp.nsec] = sec_nsec(self.ts) + lcm_msg.header.frame_id = self.frame_id + lcm_msg.child_frame_id = self.child_frame_id + + # Set pose with covariance + lcm_msg.pose.pose = self.pose.pose + if isinstance(self.pose.covariance, np.ndarray): + lcm_msg.pose.covariance = self.pose.covariance.tolist() + else: + lcm_msg.pose.covariance = list(self.pose.covariance) + + # Set twist with covariance + lcm_msg.twist.twist = self.twist.twist + if isinstance(self.twist.covariance, np.ndarray): + lcm_msg.twist.covariance = self.twist.covariance.tolist() + else: + lcm_msg.twist.covariance = list(self.twist.covariance) + + return lcm_msg.lcm_encode() + + @classmethod + def lcm_decode(cls, data: bytes) -> "Odometry": + """Decode from LCM binary format.""" + lcm_msg = LCMOdometry.lcm_decode(data) + + # Extract timestamp + ts = lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000) + + # Create pose with covariance + pose = Pose( + position=[ + lcm_msg.pose.pose.position.x, + lcm_msg.pose.pose.position.y, + lcm_msg.pose.pose.position.z, + ], + orientation=[ + lcm_msg.pose.pose.orientation.x, + lcm_msg.pose.pose.orientation.y, + lcm_msg.pose.pose.orientation.z, + lcm_msg.pose.pose.orientation.w, + ], + ) + pose_with_cov = PoseWithCovariance(pose, lcm_msg.pose.covariance) + + # Create twist with covariance + twist = Twist( + linear=[ + lcm_msg.twist.twist.linear.x, + lcm_msg.twist.twist.linear.y, + lcm_msg.twist.twist.linear.z, + ], + angular=[ + lcm_msg.twist.twist.angular.x, + lcm_msg.twist.twist.angular.y, + lcm_msg.twist.twist.angular.z, + ], + ) + twist_with_cov = TwistWithCovariance(twist, lcm_msg.twist.covariance) + + return cls( + ts=ts, + frame_id=lcm_msg.header.frame_id, + child_frame_id=lcm_msg.child_frame_id, + pose=pose_with_cov, + twist=twist_with_cov, + ) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSOdometry) -> "Odometry": + """Create an Odometry from a ROS nav_msgs/Odometry message. + + Args: + ros_msg: ROS Odometry message + + Returns: + Odometry instance + """ + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert pose and twist with covariance + pose_with_cov = PoseWithCovariance.from_ros_msg(ros_msg.pose) + twist_with_cov = TwistWithCovariance.from_ros_msg(ros_msg.twist) + + return cls( + ts=ts, + frame_id=ros_msg.header.frame_id, + child_frame_id=ros_msg.child_frame_id, + pose=pose_with_cov, + twist=twist_with_cov, + ) + + def to_ros_msg(self) -> ROSOdometry: + """Convert to a ROS nav_msgs/Odometry message. + + Returns: + ROS Odometry message + """ + ros_msg = ROSOdometry() + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set child frame ID + ros_msg.child_frame_id = self.child_frame_id + + # Set pose with covariance + ros_msg.pose = self.pose.to_ros_msg() + + # Set twist with covariance + ros_msg.twist = self.twist.to_ros_msg() + + return ros_msg diff --git a/dimos/msgs/nav_msgs/Path.py b/dimos/msgs/nav_msgs/Path.py index 8fca1cf25f..bb0b509369 100644 --- a/dimos/msgs/nav_msgs/Path.py +++ b/dimos/msgs/nav_msgs/Path.py @@ -26,6 +26,7 @@ from dimos_lcm.nav_msgs import Path as LCMPath from dimos_lcm.std_msgs import Header as LCMHeader from dimos_lcm.std_msgs import Time as LCMTime +from nav_msgs.msg import Path as ROSPath from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped @@ -187,3 +188,42 @@ def reverse(self) -> "Path": def clear(self) -> None: """Clear all poses from this path (mutable).""" self.poses.clear() + + @classmethod + def from_ros_msg(cls, ros_msg: ROSPath) -> "Path": + """Create a Path from a ROS nav_msgs/Path message. + + Args: + ros_msg: ROS Path message + + Returns: + Path instance + """ + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert poses + poses = [] + for ros_pose_stamped in ros_msg.poses: + poses.append(PoseStamped.from_ros_msg(ros_pose_stamped)) + + return cls(ts=ts, frame_id=ros_msg.header.frame_id, poses=poses) + + def to_ros_msg(self) -> ROSPath: + """Convert to a ROS nav_msgs/Path message. + + Returns: + ROS Path message + """ + ros_msg = ROSPath() + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Convert poses + for pose in self.poses: + ros_msg.poses.append(pose.to_ros_msg()) + + return ros_msg diff --git a/dimos/msgs/nav_msgs/__init__.py b/dimos/msgs/nav_msgs/__init__.py index 3e4241daa0..9ea87f3f78 100644 --- a/dimos/msgs/nav_msgs/__init__.py +++ b/dimos/msgs/nav_msgs/__init__.py @@ -1,4 +1,5 @@ from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, MapMetaData, OccupancyGrid from dimos.msgs.nav_msgs.Path import Path +from dimos.msgs.nav_msgs.Odometry import Odometry -__all__ = ["Path", "OccupancyGrid", "MapMetaData", "CostValues"] +__all__ = ["Path", "OccupancyGrid", "MapMetaData", "CostValues", "Odometry"] diff --git a/dimos/msgs/nav_msgs/test_Odometry.py b/dimos/msgs/nav_msgs/test_Odometry.py new file mode 100644 index 0000000000..6961cc2bed --- /dev/null +++ b/dimos/msgs/nav_msgs/test_Odometry.py @@ -0,0 +1,466 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import numpy as np +import pytest +from dimos_lcm.nav_msgs import Odometry as LCMOdometry +from nav_msgs.msg import Odometry as ROSOdometry +from geometry_msgs.msg import PoseWithCovariance as ROSPoseWithCovariance +from geometry_msgs.msg import TwistWithCovariance as ROSTwistWithCovariance +from geometry_msgs.msg import Pose as ROSPose +from geometry_msgs.msg import Twist as ROSTwist +from geometry_msgs.msg import Point as ROSPoint +from geometry_msgs.msg import Quaternion as ROSQuaternion +from geometry_msgs.msg import Vector3 as ROSVector3 +from std_msgs.msg import Header as ROSHeader +from builtin_interfaces.msg import Time as ROSTime + +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.geometry_msgs.Quaternion import Quaternion + + +def test_odometry_default_init(): + """Test default initialization.""" + odom = Odometry() + + # Should have current timestamp + assert odom.ts > 0 + assert odom.frame_id == "" + assert odom.child_frame_id == "" + + # Pose should be at origin with identity orientation + assert odom.pose.position.x == 0.0 + assert odom.pose.position.y == 0.0 + assert odom.pose.position.z == 0.0 + assert odom.pose.orientation.w == 1.0 + + # Twist should be zero + assert odom.twist.linear.x == 0.0 + assert odom.twist.linear.y == 0.0 + assert odom.twist.linear.z == 0.0 + assert odom.twist.angular.x == 0.0 + assert odom.twist.angular.y == 0.0 + assert odom.twist.angular.z == 0.0 + + # Covariances should be zero + assert np.all(odom.pose.covariance == 0.0) + assert np.all(odom.twist.covariance == 0.0) + + +def test_odometry_with_frames(): + """Test initialization with frame IDs.""" + ts = 1234567890.123456 + frame_id = "odom" + child_frame_id = "base_link" + + odom = Odometry(ts=ts, frame_id=frame_id, child_frame_id=child_frame_id) + + assert odom.ts == ts + assert odom.frame_id == frame_id + assert odom.child_frame_id == child_frame_id + + +def test_odometry_with_pose_and_twist(): + """Test initialization with pose and twist.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + twist = Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)) + + odom = Odometry(ts=1000.0, frame_id="odom", child_frame_id="base_link", pose=pose, twist=twist) + + assert odom.pose.pose.position.x == 1.0 + assert odom.pose.pose.position.y == 2.0 + assert odom.pose.pose.position.z == 3.0 + assert odom.twist.twist.linear.x == 0.5 + assert odom.twist.twist.angular.z == 0.1 + + +def test_odometry_with_covariances(): + """Test initialization with pose and twist with covariances.""" + pose = Pose(1.0, 2.0, 3.0) + pose_cov = np.arange(36, dtype=float) + pose_with_cov = PoseWithCovariance(pose, pose_cov) + + twist = Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)) + twist_cov = np.arange(36, 72, dtype=float) + twist_with_cov = TwistWithCovariance(twist, twist_cov) + + odom = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_link", + pose=pose_with_cov, + twist=twist_with_cov, + ) + + assert odom.pose.position.x == 1.0 + assert np.array_equal(odom.pose.covariance, pose_cov) + assert odom.twist.linear.x == 0.5 + assert np.array_equal(odom.twist.covariance, twist_cov) + + +def test_odometry_copy_constructor(): + """Test copy constructor.""" + original = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_link", + pose=Pose(1.0, 2.0, 3.0), + twist=Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + ) + + copy = Odometry(original) + + assert copy == original + assert copy is not original + assert copy.pose is not original.pose + assert copy.twist is not original.twist + + +def test_odometry_dict_init(): + """Test initialization from dictionary.""" + odom_dict = { + "ts": 1000.0, + "frame_id": "odom", + "child_frame_id": "base_link", + "pose": Pose(1.0, 2.0, 3.0), + "twist": Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + } + + odom = Odometry(odom_dict) + + assert odom.ts == 1000.0 + assert odom.frame_id == "odom" + assert odom.child_frame_id == "base_link" + assert odom.pose.position.x == 1.0 + assert odom.twist.linear.x == 0.5 + + +def test_odometry_properties(): + """Test convenience properties.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + twist = Twist(Vector3(0.5, 0.6, 0.7), Vector3(0.1, 0.2, 0.3)) + + odom = Odometry(ts=1000.0, frame_id="odom", child_frame_id="base_link", pose=pose, twist=twist) + + # Position properties + assert odom.x == 1.0 + assert odom.y == 2.0 + assert odom.z == 3.0 + assert odom.position.x == 1.0 + assert odom.position.y == 2.0 + assert odom.position.z == 3.0 + + # Orientation properties + assert odom.orientation.x == 0.1 + assert odom.orientation.y == 0.2 + assert odom.orientation.z == 0.3 + assert odom.orientation.w == 0.9 + + # Velocity properties + assert odom.vx == 0.5 + assert odom.vy == 0.6 + assert odom.vz == 0.7 + assert odom.linear_velocity.x == 0.5 + assert odom.linear_velocity.y == 0.6 + assert odom.linear_velocity.z == 0.7 + + # Angular velocity properties + assert odom.wx == 0.1 + assert odom.wy == 0.2 + assert odom.wz == 0.3 + assert odom.angular_velocity.x == 0.1 + assert odom.angular_velocity.y == 0.2 + assert odom.angular_velocity.z == 0.3 + + # Euler angles + assert odom.roll == pose.roll + assert odom.pitch == pose.pitch + assert odom.yaw == pose.yaw + + +def test_odometry_str_repr(): + """Test string representations.""" + odom = Odometry( + ts=1234567890.123456, + frame_id="odom", + child_frame_id="base_link", + pose=Pose(1.234, 2.567, 3.891), + twist=Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + ) + + repr_str = repr(odom) + assert "Odometry" in repr_str + assert "1234567890.123456" in repr_str + assert "odom" in repr_str + assert "base_link" in repr_str + + str_repr = str(odom) + assert "Odometry" in str_repr + assert "odom -> base_link" in str_repr + assert "1.234" in str_repr + assert "0.500" in str_repr + + +def test_odometry_equality(): + """Test equality comparison.""" + odom1 = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_link", + pose=Pose(1.0, 2.0, 3.0), + twist=Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + ) + + odom2 = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_link", + pose=Pose(1.0, 2.0, 3.0), + twist=Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + ) + + odom3 = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_link", + pose=Pose(1.1, 2.0, 3.0), # Different position + twist=Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + ) + + assert odom1 == odom2 + assert odom1 != odom3 + assert odom1 != "not an odometry" + + +def test_odometry_lcm_encode_decode(): + """Test LCM encoding and decoding.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + pose_cov = np.arange(36, dtype=float) + twist = Twist(Vector3(0.5, 0.6, 0.7), Vector3(0.1, 0.2, 0.3)) + twist_cov = np.arange(36, 72, dtype=float) + + source = Odometry( + ts=1234567890.123456, + frame_id="odom", + child_frame_id="base_link", + pose=PoseWithCovariance(pose, pose_cov), + twist=TwistWithCovariance(twist, twist_cov), + ) + + # Encode and decode + binary_msg = source.lcm_encode() + decoded = Odometry.lcm_decode(binary_msg) + + # Check values (allowing for timestamp precision loss) + assert abs(decoded.ts - source.ts) < 1e-6 + assert decoded.frame_id == source.frame_id + assert decoded.child_frame_id == source.child_frame_id + assert decoded.pose == source.pose + assert decoded.twist == source.twist + + +def test_odometry_from_ros_msg(): + """Test creating from ROS message.""" + ros_msg = ROSOdometry() + + # Set header + ros_msg.header = ROSHeader() + ros_msg.header.stamp = ROSTime() + ros_msg.header.stamp.sec = 1234567890 + ros_msg.header.stamp.nanosec = 123456000 + ros_msg.header.frame_id = "odom" + ros_msg.child_frame_id = "base_link" + + # Set pose with covariance + ros_msg.pose = ROSPoseWithCovariance() + ros_msg.pose.pose = ROSPose() + ros_msg.pose.pose.position = ROSPoint(x=1.0, y=2.0, z=3.0) + ros_msg.pose.pose.orientation = ROSQuaternion(x=0.1, y=0.2, z=0.3, w=0.9) + ros_msg.pose.covariance = [float(i) for i in range(36)] + + # Set twist with covariance + ros_msg.twist = ROSTwistWithCovariance() + ros_msg.twist.twist = ROSTwist() + ros_msg.twist.twist.linear = ROSVector3(x=0.5, y=0.6, z=0.7) + ros_msg.twist.twist.angular = ROSVector3(x=0.1, y=0.2, z=0.3) + ros_msg.twist.covariance = [float(i) for i in range(36, 72)] + + odom = Odometry.from_ros_msg(ros_msg) + + assert odom.ts == 1234567890.123456 + assert odom.frame_id == "odom" + assert odom.child_frame_id == "base_link" + assert odom.pose.position.x == 1.0 + assert odom.twist.linear.x == 0.5 + assert np.array_equal(odom.pose.covariance, np.arange(36)) + assert np.array_equal(odom.twist.covariance, np.arange(36, 72)) + + +def test_odometry_to_ros_msg(): + """Test converting to ROS message.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + pose_cov = np.arange(36, dtype=float) + twist = Twist(Vector3(0.5, 0.6, 0.7), Vector3(0.1, 0.2, 0.3)) + twist_cov = np.arange(36, 72, dtype=float) + + odom = Odometry( + ts=1234567890.567890, + frame_id="odom", + child_frame_id="base_link", + pose=PoseWithCovariance(pose, pose_cov), + twist=TwistWithCovariance(twist, twist_cov), + ) + + ros_msg = odom.to_ros_msg() + + assert isinstance(ros_msg, ROSOdometry) + assert ros_msg.header.frame_id == "odom" + assert ros_msg.header.stamp.sec == 1234567890 + assert abs(ros_msg.header.stamp.nanosec - 567890000) < 100 # Allow small rounding error + assert ros_msg.child_frame_id == "base_link" + + # Check pose + assert ros_msg.pose.pose.position.x == 1.0 + assert ros_msg.pose.pose.position.y == 2.0 + assert ros_msg.pose.pose.position.z == 3.0 + assert ros_msg.pose.pose.orientation.x == 0.1 + assert ros_msg.pose.pose.orientation.y == 0.2 + assert ros_msg.pose.pose.orientation.z == 0.3 + assert ros_msg.pose.pose.orientation.w == 0.9 + assert list(ros_msg.pose.covariance) == list(range(36)) + + # Check twist + assert ros_msg.twist.twist.linear.x == 0.5 + assert ros_msg.twist.twist.linear.y == 0.6 + assert ros_msg.twist.twist.linear.z == 0.7 + assert ros_msg.twist.twist.angular.x == 0.1 + assert ros_msg.twist.twist.angular.y == 0.2 + assert ros_msg.twist.twist.angular.z == 0.3 + assert list(ros_msg.twist.covariance) == list(range(36, 72)) + + +def test_odometry_ros_roundtrip(): + """Test round-trip conversion with ROS messages.""" + pose = Pose(1.5, 2.5, 3.5, 0.15, 0.25, 0.35, 0.85) + pose_cov = np.random.rand(36) + twist = Twist(Vector3(0.55, 0.65, 0.75), Vector3(0.15, 0.25, 0.35)) + twist_cov = np.random.rand(36) + + original = Odometry( + ts=2147483647.987654, # Max int32 value for ROS Time.sec + frame_id="world", + child_frame_id="robot", + pose=PoseWithCovariance(pose, pose_cov), + twist=TwistWithCovariance(twist, twist_cov), + ) + + ros_msg = original.to_ros_msg() + restored = Odometry.from_ros_msg(ros_msg) + + # Check values (allowing for timestamp precision loss) + assert abs(restored.ts - original.ts) < 1e-6 + assert restored.frame_id == original.frame_id + assert restored.child_frame_id == original.child_frame_id + assert restored.pose == original.pose + assert restored.twist == original.twist + + +def test_odometry_zero_timestamp(): + """Test that zero timestamp gets replaced with current time.""" + odom = Odometry(ts=0.0) + + # Should have been replaced with current time + assert odom.ts > 0 + assert odom.ts <= time.time() + + +def test_odometry_with_just_pose(): + """Test initialization with just a Pose (no covariance).""" + pose = Pose(1.0, 2.0, 3.0) + + odom = Odometry(pose=pose) + + assert odom.pose.position.x == 1.0 + assert odom.pose.position.y == 2.0 + assert odom.pose.position.z == 3.0 + assert np.all(odom.pose.covariance == 0.0) # Should have zero covariance + assert np.all(odom.twist.covariance == 0.0) # Twist should also be zero + + +def test_odometry_with_just_twist(): + """Test initialization with just a Twist (no covariance).""" + twist = Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)) + + odom = Odometry(twist=twist) + + assert odom.twist.linear.x == 0.5 + assert odom.twist.angular.z == 0.1 + assert np.all(odom.twist.covariance == 0.0) # Should have zero covariance + assert np.all(odom.pose.covariance == 0.0) # Pose should also be zero + + +@pytest.mark.parametrize( + "frame_id,child_frame_id", + [ + ("odom", "base_link"), + ("map", "odom"), + ("world", "robot"), + ("base_link", "camera_link"), + ("", ""), # Empty frames + ], +) +def test_odometry_frame_combinations(frame_id, child_frame_id): + """Test various frame ID combinations.""" + odom = Odometry(frame_id=frame_id, child_frame_id=child_frame_id) + + assert odom.frame_id == frame_id + assert odom.child_frame_id == child_frame_id + + # Test roundtrip through ROS + ros_msg = odom.to_ros_msg() + assert ros_msg.header.frame_id == frame_id + assert ros_msg.child_frame_id == child_frame_id + + restored = Odometry.from_ros_msg(ros_msg) + assert restored.frame_id == frame_id + assert restored.child_frame_id == child_frame_id + + +def test_odometry_typical_robot_scenario(): + """Test a typical robot odometry scenario.""" + # Robot moving forward at 0.5 m/s with slight rotation + odom = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_footprint", + pose=Pose(10.0, 5.0, 0.0, 0.0, 0.0, np.sin(0.1), np.cos(0.1)), # 0.2 rad yaw + twist=Twist( + Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.05) + ), # Moving forward, turning slightly + ) + + # Check we can access all the typical properties + assert odom.x == 10.0 + assert odom.y == 5.0 + assert odom.z == 0.0 + assert abs(odom.yaw - 0.2) < 0.01 # Approximately 0.2 radians + assert odom.vx == 0.5 # Forward velocity + assert odom.wz == 0.05 # Yaw rate diff --git a/dimos/msgs/nav_msgs/test_Path.py b/dimos/msgs/nav_msgs/test_Path.py index 0a34245448..e7156d48c0 100644 --- a/dimos/msgs/nav_msgs/test_Path.py +++ b/dimos/msgs/nav_msgs/test_Path.py @@ -15,6 +15,7 @@ import time import pytest +from nav_msgs.msg import Path as ROSPath from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion @@ -288,3 +289,97 @@ def test_str_representation(): path.push_mut(create_test_pose(1, 1, 0)) path.push_mut(create_test_pose(2, 2, 0)) assert str(path) == "Path(frame_id='map', poses=2)" + + +def test_path_from_ros_msg(): + """Test creating a Path from a ROS Path message.""" + ros_msg = ROSPath() + ros_msg.header.frame_id = "map" + ros_msg.header.stamp.sec = 123 + ros_msg.header.stamp.nanosec = 456000000 + + # Add some poses + for i in range(3): + from geometry_msgs.msg import PoseStamped as ROSPoseStamped + + ros_pose = ROSPoseStamped() + ros_pose.header.frame_id = "map" + ros_pose.header.stamp.sec = 123 + i + ros_pose.header.stamp.nanosec = 0 + ros_pose.pose.position.x = float(i) + ros_pose.pose.position.y = float(i * 2) + ros_pose.pose.position.z = float(i * 3) + ros_pose.pose.orientation.x = 0.0 + ros_pose.pose.orientation.y = 0.0 + ros_pose.pose.orientation.z = 0.0 + ros_pose.pose.orientation.w = 1.0 + ros_msg.poses.append(ros_pose) + + path = Path.from_ros_msg(ros_msg) + + assert path.frame_id == "map" + assert path.ts == 123.456 + assert len(path.poses) == 3 + + for i, pose in enumerate(path.poses): + assert pose.position.x == float(i) + assert pose.position.y == float(i * 2) + assert pose.position.z == float(i * 3) + assert pose.orientation.w == 1.0 + + +def test_path_to_ros_msg(): + """Test converting a Path to a ROS Path message.""" + poses = [ + PoseStamped( + ts=124.0 + i, frame_id="odom", position=[i, i * 2, i * 3], orientation=[0, 0, 0, 1] + ) + for i in range(3) + ] + + path = Path(ts=123.456, frame_id="odom", poses=poses) + + ros_msg = path.to_ros_msg() + + assert isinstance(ros_msg, ROSPath) + assert ros_msg.header.frame_id == "odom" + assert ros_msg.header.stamp.sec == 123 + assert ros_msg.header.stamp.nanosec == 456000000 + assert len(ros_msg.poses) == 3 + + for i, ros_pose in enumerate(ros_msg.poses): + assert ros_pose.pose.position.x == float(i) + assert ros_pose.pose.position.y == float(i * 2) + assert ros_pose.pose.position.z == float(i * 3) + assert ros_pose.pose.orientation.w == 1.0 + + +def test_path_ros_roundtrip(): + """Test round-trip conversion between Path and ROS Path.""" + poses = [ + PoseStamped( + ts=100.0 + i * 0.1, + frame_id="world", + position=[i * 1.5, i * 2.5, i * 3.5], + orientation=[0.1, 0.2, 0.3, 0.9], + ) + for i in range(3) + ] + + original = Path(ts=99.789, frame_id="world", poses=poses) + + ros_msg = original.to_ros_msg() + restored = Path.from_ros_msg(ros_msg) + + assert restored.frame_id == original.frame_id + assert restored.ts == original.ts + assert len(restored.poses) == len(original.poses) + + for orig_pose, rest_pose in zip(original.poses, restored.poses): + assert rest_pose.position.x == orig_pose.position.x + assert rest_pose.position.y == orig_pose.position.y + assert rest_pose.position.z == orig_pose.position.z + assert rest_pose.orientation.x == orig_pose.orientation.x + assert rest_pose.orientation.y == orig_pose.orientation.y + assert rest_pose.orientation.z == orig_pose.orientation.z + assert rest_pose.orientation.w == orig_pose.orientation.w diff --git a/dimos/msgs/tf2_msgs/TFMessage.py b/dimos/msgs/tf2_msgs/TFMessage.py index 9ccba615b2..3d61c37a16 100644 --- a/dimos/msgs/tf2_msgs/TFMessage.py +++ b/dimos/msgs/tf2_msgs/TFMessage.py @@ -34,6 +34,8 @@ from dimos_lcm.std_msgs import Header as LCMHeader from dimos_lcm.std_msgs import Time as LCMTime from dimos_lcm.tf2_msgs import TFMessage as LCMTFMessage +from tf2_msgs.msg import TFMessage as ROSTFMessage +from geometry_msgs.msg import TransformStamped as ROSTransformStamped from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.geometry_msgs.Vector3 import Vector3 @@ -119,3 +121,35 @@ def __str__(self) -> str: for i, transform in enumerate(self.transforms): lines.append(f" [{i}] {transform.frame_id} @ {transform.ts:.3f}") return "\n".join(lines) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSTFMessage) -> "TFMessage": + """Create a TFMessage from a ROS tf2_msgs/TFMessage message. + + Args: + ros_msg: ROS TFMessage message + + Returns: + TFMessage instance + """ + transforms = [] + for ros_transform_stamped in ros_msg.transforms: + # Convert from ROS TransformStamped to our Transform + transform = Transform.from_ros_transform_stamped(ros_transform_stamped) + transforms.append(transform) + + return cls(*transforms) + + def to_ros_msg(self) -> ROSTFMessage: + """Convert to a ROS tf2_msgs/TFMessage message. + + Returns: + ROS TFMessage message + """ + ros_msg = ROSTFMessage() + + # Convert each Transform to ROS TransformStamped + for transform in self.transforms: + ros_msg.transforms.append(transform.to_ros_transform_stamped()) + + return ros_msg diff --git a/dimos/msgs/tf2_msgs/test_TFMessage.py b/dimos/msgs/tf2_msgs/test_TFMessage.py index 6eec4cbdcc..4bb5b2f3b0 100644 --- a/dimos/msgs/tf2_msgs/test_TFMessage.py +++ b/dimos/msgs/tf2_msgs/test_TFMessage.py @@ -14,6 +14,7 @@ import pytest from dimos_lcm.tf2_msgs import TFMessage as LCMTFMessage +from tf2_msgs.msg import TFMessage as ROSTFMessage from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 from dimos.msgs.tf2_msgs import TFMessage @@ -107,3 +108,153 @@ def test_tfmessage_lcm_encode_decode(): assert ts2.child_frame_id == "target" assert ts2.transform.rotation.z == 0.707 assert ts2.transform.rotation.w == 0.707 + + +def test_tfmessage_from_ros_msg(): + """Test creating a TFMessage from a ROS TFMessage message.""" + from geometry_msgs.msg import TransformStamped as ROSTransformStamped + + ros_msg = ROSTFMessage() + + # Add first transform + tf1 = ROSTransformStamped() + tf1.header.frame_id = "world" + tf1.header.stamp.sec = 123 + tf1.header.stamp.nanosec = 456000000 + tf1.child_frame_id = "robot" + tf1.transform.translation.x = 1.0 + tf1.transform.translation.y = 2.0 + tf1.transform.translation.z = 3.0 + tf1.transform.rotation.x = 0.0 + tf1.transform.rotation.y = 0.0 + tf1.transform.rotation.z = 0.0 + tf1.transform.rotation.w = 1.0 + ros_msg.transforms.append(tf1) + + # Add second transform + tf2 = ROSTransformStamped() + tf2.header.frame_id = "robot" + tf2.header.stamp.sec = 124 + tf2.header.stamp.nanosec = 567000000 + tf2.child_frame_id = "sensor" + tf2.transform.translation.x = 4.0 + tf2.transform.translation.y = 5.0 + tf2.transform.translation.z = 6.0 + tf2.transform.rotation.x = 0.0 + tf2.transform.rotation.y = 0.0 + tf2.transform.rotation.z = 0.707 + tf2.transform.rotation.w = 0.707 + ros_msg.transforms.append(tf2) + + # Convert to TFMessage + tfmsg = TFMessage.from_ros_msg(ros_msg) + + assert len(tfmsg) == 2 + + # Check first transform + assert tfmsg[0].frame_id == "world" + assert tfmsg[0].child_frame_id == "robot" + assert tfmsg[0].ts == 123.456 + assert tfmsg[0].translation.x == 1.0 + assert tfmsg[0].translation.y == 2.0 + assert tfmsg[0].translation.z == 3.0 + assert tfmsg[0].rotation.w == 1.0 + + # Check second transform + assert tfmsg[1].frame_id == "robot" + assert tfmsg[1].child_frame_id == "sensor" + assert tfmsg[1].ts == 124.567 + assert tfmsg[1].translation.x == 4.0 + assert tfmsg[1].translation.y == 5.0 + assert tfmsg[1].translation.z == 6.0 + assert tfmsg[1].rotation.z == 0.707 + assert tfmsg[1].rotation.w == 0.707 + + +def test_tfmessage_to_ros_msg(): + """Test converting a TFMessage to a ROS TFMessage message.""" + # Create transforms + tf1 = Transform( + translation=Vector3(1.0, 2.0, 3.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="map", + child_frame_id="base_link", + ts=123.456, + ) + tf2 = Transform( + translation=Vector3(7.0, 8.0, 9.0), + rotation=Quaternion(0.1, 0.2, 0.3, 0.9), + frame_id="base_link", + child_frame_id="lidar", + ts=125.789, + ) + + tfmsg = TFMessage(tf1, tf2) + + # Convert to ROS message + ros_msg = tfmsg.to_ros_msg() + + assert isinstance(ros_msg, ROSTFMessage) + assert len(ros_msg.transforms) == 2 + + # Check first transform + assert ros_msg.transforms[0].header.frame_id == "map" + assert ros_msg.transforms[0].child_frame_id == "base_link" + assert ros_msg.transforms[0].header.stamp.sec == 123 + assert ros_msg.transforms[0].header.stamp.nanosec == 456000000 + assert ros_msg.transforms[0].transform.translation.x == 1.0 + assert ros_msg.transforms[0].transform.translation.y == 2.0 + assert ros_msg.transforms[0].transform.translation.z == 3.0 + assert ros_msg.transforms[0].transform.rotation.w == 1.0 + + # Check second transform + assert ros_msg.transforms[1].header.frame_id == "base_link" + assert ros_msg.transforms[1].child_frame_id == "lidar" + assert ros_msg.transforms[1].header.stamp.sec == 125 + assert ros_msg.transforms[1].header.stamp.nanosec == 789000000 + assert ros_msg.transforms[1].transform.translation.x == 7.0 + assert ros_msg.transforms[1].transform.translation.y == 8.0 + assert ros_msg.transforms[1].transform.translation.z == 9.0 + assert ros_msg.transforms[1].transform.rotation.x == 0.1 + assert ros_msg.transforms[1].transform.rotation.y == 0.2 + assert ros_msg.transforms[1].transform.rotation.z == 0.3 + assert ros_msg.transforms[1].transform.rotation.w == 0.9 + + +def test_tfmessage_ros_roundtrip(): + """Test round-trip conversion between TFMessage and ROS TFMessage.""" + # Create transforms with various properties + tf1 = Transform( + translation=Vector3(1.5, 2.5, 3.5), + rotation=Quaternion(0.15, 0.25, 0.35, 0.85), + frame_id="odom", + child_frame_id="base_footprint", + ts=100.123, + ) + tf2 = Transform( + translation=Vector3(0.1, 0.2, 0.3), + rotation=Quaternion(0.0, 0.0, 0.383, 0.924), + frame_id="base_footprint", + child_frame_id="camera", + ts=100.456, + ) + + original = TFMessage(tf1, tf2) + + # Convert to ROS and back + ros_msg = original.to_ros_msg() + restored = TFMessage.from_ros_msg(ros_msg) + + assert len(restored) == len(original) + + for orig_tf, rest_tf in zip(original, restored): + assert rest_tf.frame_id == orig_tf.frame_id + assert rest_tf.child_frame_id == orig_tf.child_frame_id + assert rest_tf.ts == orig_tf.ts + assert rest_tf.translation.x == orig_tf.translation.x + assert rest_tf.translation.y == orig_tf.translation.y + assert rest_tf.translation.z == orig_tf.translation.z + assert rest_tf.rotation.x == orig_tf.rotation.x + assert rest_tf.rotation.y == orig_tf.rotation.y + assert rest_tf.rotation.z == orig_tf.rotation.z + assert rest_tf.rotation.w == orig_tf.rotation.w diff --git a/dimos/robot/ros_bridge.py b/dimos/robot/ros_bridge.py new file mode 100644 index 0000000000..55dea29442 --- /dev/null +++ b/dimos/robot/ros_bridge.py @@ -0,0 +1,193 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import threading +import time +from typing import Dict, Any, Type, Literal, Optional +from enum import Enum + +import rclpy +from rclpy.executors import MultiThreadedExecutor +from rclpy.node import Node +from rclpy.qos import QoSProfile, QoSReliabilityPolicy, QoSHistoryPolicy, QoSDurabilityPolicy + +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.robot.ros_bridge", level=logging.INFO) + + +class BridgeDirection(Enum): + """Direction of message bridging.""" + + ROS_TO_DIMOS = "ros_to_dimos" + DIMOS_TO_ROS = "dimos_to_ros" + + +class ROSBridge: + """Unidirectional bridge between ROS and DIMOS for message passing.""" + + def __init__(self, node_name: str = "dimos_ros_bridge"): + """Initialize the ROS-DIMOS bridge. + + Args: + node_name: Name for the ROS node (default: "dimos_ros_bridge") + """ + if not rclpy.ok(): + rclpy.init() + + self.node = Node(node_name) + self.lcm = LCM() + self.lcm.start() + + self._executor = MultiThreadedExecutor() + self._executor.add_node(self.node) + + self._spin_thread = threading.Thread(target=self._ros_spin, daemon=True) + self._spin_thread.start() + + self._bridges: Dict[str, Dict[str, Any]] = {} + + self._qos = QoSProfile( + reliability=QoSReliabilityPolicy.RELIABLE, + history=QoSHistoryPolicy.KEEP_LAST, + durability=QoSDurabilityPolicy.VOLATILE, + depth=10, + ) + + logger.info(f"ROSBridge initialized with node name: {node_name}") + + def _ros_spin(self): + """Background thread for spinning ROS executor.""" + try: + self._executor.spin() + finally: + self._executor.shutdown() + + def add_topic( + self, + topic_name: str, + dimos_type: Type, + ros_type: Type, + direction: BridgeDirection, + remap_topic: Optional[str] = None, + ) -> None: + """Add unidirectional bridging for a topic. + + Args: + topic_name: Name of the topic (e.g., "/cmd_vel") + dimos_type: DIMOS message type (e.g., dimos.msgs.geometry_msgs.Twist) + ros_type: ROS message type (e.g., geometry_msgs.msg.Twist) + direction: Direction of bridging (ROS_TO_DIMOS or DIMOS_TO_ROS) + remap_topic: Optional remapped topic name for the other side + """ + if topic_name in self._bridges: + logger.warning(f"Topic {topic_name} already bridged") + return + + # Determine actual topic names for each side + ros_topic_name = topic_name + dimos_topic_name = topic_name + + if remap_topic: + if direction == BridgeDirection.ROS_TO_DIMOS: + dimos_topic_name = remap_topic + else: # DIMOS_TO_ROS + ros_topic_name = remap_topic + + # Create DIMOS/LCM topic + dimos_topic = Topic(dimos_topic_name, dimos_type) + + ros_subscription = None + ros_publisher = None + dimos_subscription = None + + if direction == BridgeDirection.ROS_TO_DIMOS: + + def ros_callback(msg): + self._ros_to_dimos(msg, dimos_topic, dimos_type, topic_name) + + ros_subscription = self.node.create_subscription( + ros_type, ros_topic_name, ros_callback, self._qos + ) + logger.info(f" ROS → DIMOS: Subscribing to ROS topic {ros_topic_name}") + + elif direction == BridgeDirection.DIMOS_TO_ROS: + ros_publisher = self.node.create_publisher(ros_type, ros_topic_name, self._qos) + + def dimos_callback(msg, _topic): + self._dimos_to_ros(msg, ros_publisher, topic_name) + + dimos_subscription = self.lcm.subscribe(dimos_topic, dimos_callback) + logger.info(f" DIMOS → ROS: Subscribing to DIMOS topic {dimos_topic_name}") + else: + raise ValueError(f"Invalid bridge direction: {direction}") + + self._bridges[topic_name] = { + "dimos_topic": dimos_topic, + "dimos_type": dimos_type, + "ros_type": ros_type, + "ros_subscription": ros_subscription, + "ros_publisher": ros_publisher, + "dimos_subscription": dimos_subscription, + "direction": direction, + "ros_topic_name": ros_topic_name, + "dimos_topic_name": dimos_topic_name, + } + + direction_str = { + BridgeDirection.ROS_TO_DIMOS: "ROS → DIMOS", + BridgeDirection.DIMOS_TO_ROS: "DIMOS → ROS", + }[direction] + + logger.info(f"Bridged topic: {topic_name} ({direction_str})") + if remap_topic: + logger.info(f" Remapped: ROS '{ros_topic_name}' ↔ DIMOS '{dimos_topic_name}'") + logger.info(f" DIMOS type: {dimos_type.__name__}, ROS type: {ros_type.__name__}") + + def _ros_to_dimos( + self, ros_msg: Any, dimos_topic: Topic, dimos_type: Type, _topic_name: str + ) -> None: + """Convert ROS message to DIMOS and publish. + + Args: + ros_msg: ROS message + dimos_topic: DIMOS topic to publish to + dimos_type: DIMOS message type + topic_name: Name of the topic for tracking + """ + dimos_msg = dimos_type.from_ros_msg(ros_msg) + self.lcm.publish(dimos_topic, dimos_msg) + + def _dimos_to_ros(self, dimos_msg: Any, ros_publisher, _topic_name: str) -> None: + """Convert DIMOS message to ROS and publish. + + Args: + dimos_msg: DIMOS message + ros_publisher: ROS publisher to use + _topic_name: Name of the topic (unused, kept for consistency) + """ + ros_msg = dimos_msg.to_ros_msg() + ros_publisher.publish(ros_msg) + + def shutdown(self): + """Shutdown the bridge and clean up resources.""" + self._executor.shutdown() + self.node.destroy_node() + + if rclpy.ok(): + rclpy.shutdown() + + logger.info("ROSBridge shutdown complete") diff --git a/dimos/robot/test_ros_bridge.py b/dimos/robot/test_ros_bridge.py new file mode 100644 index 0000000000..a02967e62a --- /dev/null +++ b/dimos/robot/test_ros_bridge.py @@ -0,0 +1,909 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock, patch + +import pytest +from geometry_msgs.msg import Twist as ROSTwist +from geometry_msgs.msg import Vector3 as ROSVector3 +from geometry_msgs.msg import PoseStamped as ROSPoseStamped +from geometry_msgs.msg import Pose as ROSPose +from geometry_msgs.msg import Point as ROSPoint +from geometry_msgs.msg import Quaternion as ROSQuaternion +from nav_msgs.msg import Path as ROSPath +from sensor_msgs.msg import LaserScan as ROSLaserScan +from sensor_msgs.msg import PointCloud2 as ROSPointCloud2 +from std_msgs.msg import String as ROSString +from std_msgs.msg import Header as ROSHeader + +from dimos.msgs.geometry_msgs import Twist, Vector3, PoseStamped, Pose, Quaternion +from dimos.msgs.nav_msgs import Path +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.robot.ros_bridge import ROSBridge, BridgeDirection + + +@pytest.fixture +def bridge(): + """Create a ROSBridge instance with mocked internals.""" + with ( + patch("dimos.robot.ros_bridge.rclpy") as mock_rclpy, + patch("dimos.robot.ros_bridge.Node") as mock_node_class, + patch("dimos.robot.ros_bridge.LCM") as mock_lcm_class, + patch("dimos.robot.ros_bridge.MultiThreadedExecutor") as mock_executor_class, + ): + mock_rclpy.ok.return_value = False + mock_node = MagicMock() + mock_node.create_subscription = MagicMock(return_value=MagicMock()) + mock_node.create_publisher = MagicMock(return_value=MagicMock()) + mock_node_class.return_value = mock_node + + mock_lcm = MagicMock() + mock_lcm.subscribe = MagicMock(return_value=MagicMock()) + mock_lcm.publish = MagicMock() + mock_lcm_class.return_value = mock_lcm + + mock_executor = MagicMock() + mock_executor_class.return_value = mock_executor + + bridge = ROSBridge("test_bridge") + + bridge._mock_rclpy = mock_rclpy + bridge._mock_node_class = mock_node_class + bridge._mock_lcm_class = mock_lcm_class + + return bridge + + +def test_bridge_initialization(): + """Test that the bridge initializes correctly with its own instances.""" + with ( + patch("dimos.robot.ros_bridge.rclpy") as mock_rclpy, + patch("dimos.robot.ros_bridge.Node") as mock_node_class, + patch("dimos.robot.ros_bridge.LCM") as mock_lcm_class, + patch("dimos.robot.ros_bridge.MultiThreadedExecutor"), + ): + mock_rclpy.ok.return_value = False + + bridge = ROSBridge("test_bridge") + + mock_rclpy.init.assert_called_once() + mock_node_class.assert_called_once_with("test_bridge") + mock_lcm_class.assert_called_once() + bridge.lcm.start.assert_called_once() + + assert bridge._bridges == {} + assert bridge._qos is not None + + +def test_add_topic_ros_to_dimos(bridge): + """Test that add_topic creates ROS subscription for ROS->DIMOS direction.""" + topic_name = "/cmd_vel" + + bridge.add_topic(topic_name, Twist, ROSTwist, direction=BridgeDirection.ROS_TO_DIMOS) + + bridge.node.create_subscription.assert_called_once() + call_args = bridge.node.create_subscription.call_args + assert call_args[0][0] == ROSTwist + assert call_args[0][1] == topic_name + + bridge.node.create_publisher.assert_not_called() + bridge.lcm.subscribe.assert_not_called() + + assert topic_name in bridge._bridges + assert "dimos_topic" in bridge._bridges[topic_name] + assert "dimos_type" in bridge._bridges[topic_name] + assert "ros_type" in bridge._bridges[topic_name] + assert bridge._bridges[topic_name]["dimos_type"] == Twist + assert bridge._bridges[topic_name]["ros_type"] == ROSTwist + assert bridge._bridges[topic_name]["direction"] == BridgeDirection.ROS_TO_DIMOS + + +def test_add_topic_dimos_to_ros(bridge): + """Test that add_topic creates ROS publisher for DIMOS->ROS direction.""" + topic_name = "/cmd_vel" + + bridge.add_topic(topic_name, Twist, ROSTwist, direction=BridgeDirection.DIMOS_TO_ROS) + + bridge.node.create_subscription.assert_not_called() + bridge.node.create_publisher.assert_called_once_with(ROSTwist, topic_name, bridge._qos) + bridge.lcm.subscribe.assert_called_once() + + assert topic_name in bridge._bridges + assert "dimos_topic" in bridge._bridges[topic_name] + assert "dimos_type" in bridge._bridges[topic_name] + assert "ros_type" in bridge._bridges[topic_name] + assert bridge._bridges[topic_name]["dimos_type"] == Twist + assert bridge._bridges[topic_name]["ros_type"] == ROSTwist + assert bridge._bridges[topic_name]["direction"] == BridgeDirection.DIMOS_TO_ROS + + +def test_ros_to_dimos_conversion(bridge): + """Test ROS to DIMOS message conversion and publishing.""" + # Create a ROS Twist message + ros_msg = ROSTwist() + ros_msg.linear = ROSVector3(x=1.0, y=2.0, z=3.0) + ros_msg.angular = ROSVector3(x=0.1, y=0.2, z=0.3) + + # Create DIMOS topic + dimos_topic = Topic("/test", Twist) + + # Call the conversion method with type + bridge._ros_to_dimos(ros_msg, dimos_topic, Twist, "/test") + + # Verify DIMOS publish was called + bridge.lcm.publish.assert_called_once() + + # Get the published message + published_topic, published_msg = bridge.lcm.publish.call_args[0] + + assert published_topic == dimos_topic + assert isinstance(published_msg, Twist) + assert published_msg.linear.x == 1.0 + assert published_msg.linear.y == 2.0 + assert published_msg.linear.z == 3.0 + assert published_msg.angular.x == 0.1 + assert published_msg.angular.y == 0.2 + assert published_msg.angular.z == 0.3 + + +def test_dimos_to_ros_conversion(bridge): + """Test DIMOS to ROS message conversion and publishing.""" + # Create a DIMOS Twist message + dimos_msg = Twist(linear=Vector3(4.0, 5.0, 6.0), angular=Vector3(0.4, 0.5, 0.6)) + + # Create mock ROS publisher + ros_pub = MagicMock() + + # Call the conversion method + bridge._dimos_to_ros(dimos_msg, ros_pub, "/test") + + # Verify ROS publish was called + ros_pub.publish.assert_called_once() + + # Get the published message + published_msg = ros_pub.publish.call_args[0][0] + + assert isinstance(published_msg, ROSTwist) + assert published_msg.linear.x == 4.0 + assert published_msg.linear.y == 5.0 + assert published_msg.linear.z == 6.0 + assert published_msg.angular.x == 0.4 + assert published_msg.angular.y == 0.5 + assert published_msg.angular.z == 0.6 + + +def test_unidirectional_flow_ros_to_dimos(bridge): + """Test that messages flow from ROS to DIMOS when configured.""" + topic_name = "/cmd_vel" + + bridge.add_topic(topic_name, Twist, ROSTwist, direction=BridgeDirection.ROS_TO_DIMOS) + + ros_callback = bridge.node.create_subscription.call_args[0][2] + + ros_msg = ROSTwist() + ros_msg.linear = ROSVector3(x=1.5, y=2.5, z=3.5) + ros_msg.angular = ROSVector3(x=0.15, y=0.25, z=0.35) + + ros_callback(ros_msg) + + bridge.lcm.publish.assert_called_once() + _, published_msg = bridge.lcm.publish.call_args[0] + assert isinstance(published_msg, Twist) + assert published_msg.linear.x == 1.5 + + +def test_unidirectional_flow_dimos_to_ros(bridge): + """Test that messages flow from DIMOS to ROS when configured.""" + topic_name = "/cmd_vel" + + bridge.add_topic(topic_name, Twist, ROSTwist, direction=BridgeDirection.DIMOS_TO_ROS) + + dimos_callback = bridge.lcm.subscribe.call_args[0][1] + + dimos_msg = Twist(linear=Vector3(7.0, 8.0, 9.0), angular=Vector3(0.7, 0.8, 0.9)) + + ros_publisher = bridge.node.create_publisher.return_value + + dimos_callback(dimos_msg, None) + + ros_publisher.publish.assert_called_once() + published_ros_msg = ros_publisher.publish.call_args[0][0] + assert isinstance(published_ros_msg, ROSTwist) + assert published_ros_msg.linear.x == 7.0 + + +def test_multiple_topics(bridge): + """Test that multiple topics can be bridged simultaneously.""" + topics = [ + ("/cmd_vel", BridgeDirection.ROS_TO_DIMOS), + ("/teleop", BridgeDirection.DIMOS_TO_ROS), + ("/nav_cmd", BridgeDirection.ROS_TO_DIMOS), + ] + + for topic, direction in topics: + bridge.add_topic(topic, Twist, ROSTwist, direction=direction) + + assert len(bridge._bridges) == 3 + for topic, _ in topics: + assert topic in bridge._bridges + + assert bridge.node.create_subscription.call_count == 2 + assert bridge.node.create_publisher.call_count == 1 + assert bridge.lcm.subscribe.call_count == 1 + + +def test_stress_ros_to_dimos_100_messages(bridge): + """Test publishing 100 ROS messages and verify DIMOS receives them all.""" + topic_name = "/stress_test" + num_messages = 100 + + bridge.add_topic(topic_name, Twist, ROSTwist, direction=BridgeDirection.ROS_TO_DIMOS) + + ros_callback = bridge.node.create_subscription.call_args[0][2] + + for i in range(num_messages): + ros_msg = ROSTwist() + ros_msg.linear = ROSVector3(x=float(i), y=float(i * 2), z=float(i * 3)) + ros_msg.angular = ROSVector3(x=float(i * 0.1), y=float(i * 0.2), z=float(i * 0.3)) + + ros_callback(ros_msg) + + assert bridge.lcm.publish.call_count == num_messages + + last_call = bridge.lcm.publish.call_args_list[-1] + _, last_msg = last_call[0] + assert isinstance(last_msg, Twist) + assert last_msg.linear.x == 99.0 + assert last_msg.linear.y == 198.0 + assert last_msg.linear.z == 297.0 + assert abs(last_msg.angular.x - 9.9) < 0.01 + assert abs(last_msg.angular.y - 19.8) < 0.01 + assert abs(last_msg.angular.z - 29.7) < 0.01 + + +def test_stress_dimos_to_ros_100_messages(bridge): + """Test publishing 100 DIMOS messages and verify ROS receives them all.""" + topic_name = "/stress_test_reverse" + num_messages = 100 + + bridge.add_topic(topic_name, Twist, ROSTwist, direction=BridgeDirection.DIMOS_TO_ROS) + + dimos_callback = bridge.lcm.subscribe.call_args[0][1] + ros_publisher = bridge.node.create_publisher.return_value + + for i in range(num_messages): + dimos_msg = Twist( + linear=Vector3(float(i * 10), float(i * 20), float(i * 30)), + angular=Vector3(float(i * 0.01), float(i * 0.02), float(i * 0.03)), + ) + + dimos_callback(dimos_msg, None) + + assert ros_publisher.publish.call_count == num_messages + + last_call = ros_publisher.publish.call_args_list[-1] + last_ros_msg = last_call[0][0] + assert isinstance(last_ros_msg, ROSTwist) + assert last_ros_msg.linear.x == 990.0 + assert last_ros_msg.linear.y == 1980.0 + assert last_ros_msg.linear.z == 2970.0 + assert abs(last_ros_msg.angular.x - 0.99) < 0.001 + assert abs(last_ros_msg.angular.y - 1.98) < 0.001 + assert abs(last_ros_msg.angular.z - 2.97) < 0.001 + + +def test_two_topics_different_directions(bridge): + """Test two topics with different directions handling messages.""" + topic_r2d = "/ros_to_dimos" + topic_d2r = "/dimos_to_ros" + + bridge.add_topic(topic_r2d, Twist, ROSTwist, direction=BridgeDirection.ROS_TO_DIMOS) + bridge.add_topic(topic_d2r, Twist, ROSTwist, direction=BridgeDirection.DIMOS_TO_ROS) + + ros_callback = bridge.node.create_subscription.call_args[0][2] + dimos_callback = bridge.lcm.subscribe.call_args[0][1] + ros_publisher = bridge.node.create_publisher.return_value + + for i in range(50): + ros_msg = ROSTwist() + ros_msg.linear = ROSVector3(x=float(i), y=0.0, z=0.0) + ros_msg.angular = ROSVector3(x=0.0, y=0.0, z=float(i * 0.1)) + ros_callback(ros_msg) + + dimos_msg = Twist( + linear=Vector3(0.0, float(i), 0.0), angular=Vector3(0.0, 0.0, float(i * 0.2)) + ) + dimos_callback(dimos_msg, None) + + assert bridge.lcm.publish.call_count == 50 + assert ros_publisher.publish.call_count == 50 + + last_dimos_call = bridge.lcm.publish.call_args_list[-1] + _, last_dimos_msg = last_dimos_call[0] + assert last_dimos_msg.linear.x == 49.0 + + last_ros_call = ros_publisher.publish.call_args_list[-1] + last_ros_msg = last_ros_call[0][0] + assert last_ros_msg.linear.y == 49.0 + + +def test_high_frequency_burst(bridge): + """Test handling a burst of 1000 messages to ensure no drops.""" + topic_name = "/burst_test" + burst_size = 1000 + + bridge.add_topic(topic_name, Twist, ROSTwist, direction=BridgeDirection.ROS_TO_DIMOS) + + ros_callback = bridge.node.create_subscription.call_args[0][2] + + messages_sent = [] + for i in range(burst_size): + ros_msg = ROSTwist() + ros_msg.linear = ROSVector3(x=float(i), y=float(i), z=float(i)) + ros_msg.angular = ROSVector3(x=0.0, y=0.0, z=0.0) + messages_sent.append(i) + ros_callback(ros_msg) + + assert bridge.lcm.publish.call_count == burst_size + + for idx, call in enumerate(bridge.lcm.publish.call_args_list): + _, msg = call[0] + assert msg.linear.x == float(idx) + + +def test_multiple_topics_with_different_rates(bridge): + """Test multiple topics receiving messages at different rates.""" + topics = { + "/fast_topic": 100, # 100 messages + "/medium_topic": 50, # 50 messages + "/slow_topic": 10, # 10 messages + } + + for topic_name in topics: + bridge.add_topic(topic_name, Twist, ROSTwist, direction=BridgeDirection.ROS_TO_DIMOS) + + callbacks = [] + for i in range(3): + callbacks.append(bridge.node.create_subscription.call_args_list[i][0][2]) + + bridge.lcm.publish.reset_mock() + + for topic_idx, (topic_name, msg_count) in enumerate(topics.items()): + for i in range(msg_count): + ros_msg = ROSTwist() + ros_msg.linear = ROSVector3(x=float(topic_idx), y=float(i), z=0.0) + callbacks[topic_idx](ros_msg) + + total_expected = sum(topics.values()) + assert bridge.lcm.publish.call_count == total_expected + + +def test_pose_stamped_bridging(bridge): + """Test bridging PoseStamped messages.""" + topic_name = "/robot_pose" + + # Test ROS to DIMOS + bridge.add_topic( + topic_name, PoseStamped, ROSPoseStamped, direction=BridgeDirection.ROS_TO_DIMOS + ) + + ros_callback = bridge.node.create_subscription.call_args[0][2] + + ros_msg = ROSPoseStamped() + ros_msg.header.frame_id = "map" + ros_msg.header.stamp.sec = 100 + ros_msg.header.stamp.nanosec = 500000000 + ros_msg.pose.position.x = 10.0 + ros_msg.pose.position.y = 20.0 + ros_msg.pose.position.z = 30.0 + ros_msg.pose.orientation.x = 0.0 + ros_msg.pose.orientation.y = 0.0 + ros_msg.pose.orientation.z = 0.707 + ros_msg.pose.orientation.w = 0.707 + + ros_callback(ros_msg) + + bridge.lcm.publish.assert_called_once() + _, published_msg = bridge.lcm.publish.call_args[0] + assert hasattr(published_msg, "frame_id") + assert hasattr(published_msg, "position") + assert hasattr(published_msg, "orientation") + + +def test_path_bridging(bridge): + """Test bridging Path messages.""" + topic_name = "/planned_path" + + # Test DIMOS to ROS + bridge.add_topic(topic_name, Path, ROSPath, direction=BridgeDirection.DIMOS_TO_ROS) + + dimos_callback = bridge.lcm.subscribe.call_args[0][1] + ros_publisher = bridge.node.create_publisher.return_value + + # Create a DIMOS Path with multiple poses + poses = [] + for i in range(5): + pose = PoseStamped( + ts=100.0 + i, + frame_id="map", + position=Vector3(float(i), float(i * 2), 0.0), + orientation=Quaternion(0, 0, 0, 1), + ) + poses.append(pose) + + dimos_path = Path(frame_id="map", poses=poses) + + dimos_callback(dimos_path, None) + + ros_publisher.publish.assert_called_once() + published_ros_msg = ros_publisher.publish.call_args[0][0] + assert isinstance(published_ros_msg, ROSPath) + + +def test_multiple_message_types(bridge): + """Test bridging multiple different message types simultaneously.""" + topics = [ + ("/cmd_vel", Twist, ROSTwist, BridgeDirection.ROS_TO_DIMOS), + ("/robot_pose", PoseStamped, ROSPoseStamped, BridgeDirection.DIMOS_TO_ROS), + ("/global_path", Path, ROSPath, BridgeDirection.DIMOS_TO_ROS), + ("/local_path", Path, ROSPath, BridgeDirection.ROS_TO_DIMOS), + ("/teleop_twist", Twist, ROSTwist, BridgeDirection.ROS_TO_DIMOS), + ] + + for topic_name, dimos_type, ros_type, direction in topics: + bridge.add_topic(topic_name, dimos_type, ros_type, direction=direction) + + assert len(bridge._bridges) == 5 + + # Count subscriptions and publishers + ros_to_dimos_count = sum(1 for _, _, _, d in topics if d == BridgeDirection.ROS_TO_DIMOS) + dimos_to_ros_count = sum(1 for _, _, _, d in topics if d == BridgeDirection.DIMOS_TO_ROS) + + assert bridge.node.create_subscription.call_count == ros_to_dimos_count + assert bridge.node.create_publisher.call_count == dimos_to_ros_count + assert bridge.lcm.subscribe.call_count == dimos_to_ros_count + + +def test_topic_remapping_ros_to_dimos(bridge): + """Test remapping topic names for ROS to DIMOS direction.""" + ros_topic = "/cmd_vel" + dimos_remapped = "/robot/velocity_command" + + bridge.add_topic( + ros_topic, + Twist, + ROSTwist, + direction=BridgeDirection.ROS_TO_DIMOS, + remap_topic=dimos_remapped, + ) + + # Verify ROS subscribes to original topic + bridge.node.create_subscription.assert_called_once() + call_args = bridge.node.create_subscription.call_args + assert call_args[0][1] == ros_topic # ROS side uses original topic + + # Verify bridge metadata contains both names + assert ros_topic in bridge._bridges + bridge_info = bridge._bridges[ros_topic] + assert bridge_info["ros_topic_name"] == ros_topic + assert bridge_info["dimos_topic_name"] == dimos_remapped + assert bridge_info["dimos_topic"].topic == dimos_remapped + + +def test_topic_remapping_dimos_to_ros(bridge): + """Test remapping topic names for DIMOS to ROS direction.""" + dimos_topic = "/velocity_command" + ros_remapped = "/mobile_base/cmd_vel" + + bridge.add_topic( + dimos_topic, + Twist, + ROSTwist, + direction=BridgeDirection.DIMOS_TO_ROS, + remap_topic=ros_remapped, + ) + + # Verify ROS publishes to remapped topic + bridge.node.create_publisher.assert_called_once_with(ROSTwist, ros_remapped, bridge._qos) + + # Verify DIMOS subscribes to original topic + bridge.lcm.subscribe.assert_called_once() + dimos_topic_arg = bridge.lcm.subscribe.call_args[0][0] + assert dimos_topic_arg.topic == dimos_topic + + # Verify bridge metadata + assert dimos_topic in bridge._bridges + bridge_info = bridge._bridges[dimos_topic] + assert bridge_info["ros_topic_name"] == ros_remapped + assert bridge_info["dimos_topic_name"] == dimos_topic + + +def test_remapped_message_flow_ros_to_dimos(bridge): + """Test message flow with remapped topics from ROS to DIMOS.""" + ros_topic = "/ros/cmd_vel" + dimos_remapped = "/dimos/velocity" + + bridge.add_topic( + ros_topic, + Twist, + ROSTwist, + direction=BridgeDirection.ROS_TO_DIMOS, + remap_topic=dimos_remapped, + ) + + # Get the ROS callback + ros_callback = bridge.node.create_subscription.call_args[0][2] + + # Send a ROS message + ros_msg = ROSTwist() + ros_msg.linear = ROSVector3(x=2.0, y=3.0, z=4.0) + ros_msg.angular = ROSVector3(x=0.2, y=0.3, z=0.4) + + ros_callback(ros_msg) + + # Verify DIMOS publishes to remapped topic + bridge.lcm.publish.assert_called_once() + published_topic, published_msg = bridge.lcm.publish.call_args[0] + assert published_topic.topic == dimos_remapped + assert isinstance(published_msg, Twist) + assert published_msg.linear.x == 2.0 + + +def test_remapped_message_flow_dimos_to_ros(bridge): + """Test message flow with remapped topics from DIMOS to ROS.""" + dimos_topic = "/dimos/velocity" + ros_remapped = "/ros/cmd_vel" + + bridge.add_topic( + dimos_topic, + Twist, + ROSTwist, + direction=BridgeDirection.DIMOS_TO_ROS, + remap_topic=ros_remapped, + ) + + # Get the DIMOS callback + dimos_callback = bridge.lcm.subscribe.call_args[0][1] + + # Send a DIMOS message + dimos_msg = Twist(linear=Vector3(5.0, 6.0, 7.0), angular=Vector3(0.5, 0.6, 0.7)) + + ros_publisher = bridge.node.create_publisher.return_value + dimos_callback(dimos_msg, None) + + # Verify ROS publishes to remapped topic + ros_publisher.publish.assert_called_once() + published_msg = ros_publisher.publish.call_args[0][0] + assert isinstance(published_msg, ROSTwist) + assert published_msg.linear.x == 5.0 + + +def test_multiple_remapped_topics(bridge): + """Test multiple topics with remapping.""" + topic_configs = [ + # (original_name, dimos_type, ros_type, direction, remap_name) + ("/cmd_vel", Twist, ROSTwist, BridgeDirection.ROS_TO_DIMOS, "/robot/velocity"), + ("/odom", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS, "/robot/odometry"), + ("/path", Path, ROSPath, BridgeDirection.DIMOS_TO_ROS, "/navigation/global_path"), + ( + "/goal", + PoseStamped, + ROSPoseStamped, + BridgeDirection.DIMOS_TO_ROS, + "/navigation/goal_pose", + ), + ] + + for topic, dimos_type, ros_type, direction, remap in topic_configs: + bridge.add_topic(topic, dimos_type, ros_type, direction=direction, remap_topic=remap) + + assert len(bridge._bridges) == 4 + + # Verify ROS to DIMOS remapping + assert bridge._bridges["/cmd_vel"]["dimos_topic_name"] == "/robot/velocity" + assert bridge._bridges["/odom"]["dimos_topic_name"] == "/robot/odometry" + + # Verify DIMOS to ROS remapping + assert bridge._bridges["/path"]["ros_topic_name"] == "/navigation/global_path" + assert bridge._bridges["/goal"]["ros_topic_name"] == "/navigation/goal_pose" + + +def test_no_remapping_when_none(bridge): + """Test that topics work normally when remap_topic is None.""" + topic = "/cmd_vel" + + bridge.add_topic( + topic, Twist, ROSTwist, direction=BridgeDirection.ROS_TO_DIMOS, remap_topic=None + ) + + bridge_info = bridge._bridges[topic] + assert bridge_info["ros_topic_name"] == topic + assert bridge_info["dimos_topic_name"] == topic + + +def test_stress_remapped_topics(bridge): + """Test stress scenario with remapped topics.""" + num_messages = 100 + ros_topic = "/ros/high_freq" + dimos_remapped = "/dimos/data_stream" + + bridge.add_topic( + ros_topic, + Twist, + ROSTwist, + direction=BridgeDirection.ROS_TO_DIMOS, + remap_topic=dimos_remapped, + ) + + ros_callback = bridge.node.create_subscription.call_args[0][2] + + for i in range(num_messages): + ros_msg = ROSTwist() + ros_msg.linear = ROSVector3(x=float(i), y=float(i * 2), z=float(i * 3)) + ros_callback(ros_msg) + + assert bridge.lcm.publish.call_count == num_messages + + # Verify all published to remapped topic + for call in bridge.lcm.publish.call_args_list: + topic, _ = call[0] + assert topic.topic == dimos_remapped + + +def test_navigation_stack_topics(bridge): + """Test common navigation stack topics.""" + nav_topics = [ + ("/move_base/goal", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS), + ("/move_base/global_plan", Path, ROSPath, BridgeDirection.DIMOS_TO_ROS), + ("/move_base/local_plan", Path, ROSPath, BridgeDirection.DIMOS_TO_ROS), + ("/cmd_vel", Twist, ROSTwist, BridgeDirection.DIMOS_TO_ROS), + ("/odom", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS), + ("/robot_pose", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS), + ] + + for topic_name, dimos_type, ros_type, direction in nav_topics: + bridge.add_topic(topic_name, dimos_type, ros_type, direction=direction) + + assert len(bridge._bridges) == len(nav_topics) + + # Verify each topic is configured correctly + for topic_name, dimos_type, ros_type, direction in nav_topics: + assert topic_name in bridge._bridges + assert bridge._bridges[topic_name]["dimos_type"] == dimos_type + assert bridge._bridges[topic_name]["ros_type"] == ros_type + assert bridge._bridges[topic_name]["direction"] == direction + + +def test_control_topics(bridge): + """Test control system topics.""" + control_topics = [ + ("/joint_commands", Twist, ROSTwist, BridgeDirection.DIMOS_TO_ROS), + ("/joint_states", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS), + ("/trajectory", Path, ROSPath, BridgeDirection.DIMOS_TO_ROS), + ("/feedback", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS), + ] + + for topic_name, dimos_type, ros_type, direction in control_topics: + bridge.add_topic(topic_name, dimos_type, ros_type, direction=direction) + + assert len(bridge._bridges) == len(control_topics) + + +def test_perception_topics(bridge): + """Test perception system topics.""" + perception_topics = [ + ("/detected_pose", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS), + ("/tracked_path", Path, ROSPath, BridgeDirection.ROS_TO_DIMOS), + ("/vision_pose", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS), + ] + + for topic_name, dimos_type, ros_type, direction in perception_topics: + bridge.add_topic(topic_name, dimos_type, ros_type, direction=direction) + + # All perception topics are ROS to DIMOS + assert bridge.node.create_subscription.call_count == len(perception_topics) + assert bridge.node.create_publisher.call_count == 0 + + +def test_mixed_frequency_topics(bridge): + """Test topics with different expected frequencies.""" + # High frequency (100Hz+) + high_freq_topics = [ + ("/imu/data", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS), + ("/joint_states", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS), + ] + + # Medium frequency (10-50Hz) + medium_freq_topics = [ + ("/cmd_vel", Twist, ROSTwist, BridgeDirection.DIMOS_TO_ROS), + ("/odom", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS), + ] + + # Low frequency (1-5Hz) + low_freq_topics = [ + ("/global_path", Path, ROSPath, BridgeDirection.DIMOS_TO_ROS), + ("/goal", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS), + ] + + all_topics = high_freq_topics + medium_freq_topics + low_freq_topics + + for topic_name, dimos_type, ros_type, direction in all_topics: + bridge.add_topic(topic_name, dimos_type, ros_type, direction=direction) + + assert len(bridge._bridges) == len(all_topics) + + # Test high frequency message handling + for topic_name, _, _, direction in high_freq_topics: + if direction == BridgeDirection.ROS_TO_DIMOS: + # Find the callback for this topic + for i, call in enumerate(bridge.node.create_subscription.call_args_list): + if call[0][1] == topic_name: + callback = call[0][2] + # Send 100 messages rapidly + for j in range(100): + ros_msg = ROSPoseStamped() + ros_msg.header.stamp.sec = j + callback(ros_msg) + break + + +def test_bidirectional_prevention(bridge): + """Test that the same topic cannot be added in both directions.""" + topic_name = "/cmd_vel" + + # Add topic in one direction + bridge.add_topic(topic_name, Twist, ROSTwist, direction=BridgeDirection.ROS_TO_DIMOS) + + # Try to add the same topic in opposite direction should not create duplicate + # The bridge should handle this gracefully + initial_bridges = len(bridge._bridges) + bridge.add_topic(topic_name, Twist, ROSTwist, direction=BridgeDirection.DIMOS_TO_ROS) + + # Should still have the same number of bridges (topic gets reconfigured, not duplicated) + assert len(bridge._bridges) == initial_bridges + + +def test_robot_arm_topics(bridge): + """Test robot arm control topics.""" + arm_topics = [ + ("/arm/joint_trajectory", Path, ROSPath, BridgeDirection.DIMOS_TO_ROS), + ("/arm/joint_states", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS), + ("/arm/end_effector_pose", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS), + ("/arm/gripper_cmd", Twist, ROSTwist, BridgeDirection.DIMOS_TO_ROS), + ("/arm/cartesian_trajectory", Path, ROSPath, BridgeDirection.DIMOS_TO_ROS), + ] + + for topic_name, dimos_type, ros_type, direction in arm_topics: + bridge.add_topic(topic_name, dimos_type, ros_type, direction=direction) + + assert len(bridge._bridges) == len(arm_topics) + + # Check that arm control commands go from DIMOS to ROS + dimos_to_ros = [t for t in arm_topics if t[3] == BridgeDirection.DIMOS_TO_ROS] + ros_to_dimos = [t for t in arm_topics if t[3] == BridgeDirection.ROS_TO_DIMOS] + + assert bridge.node.create_publisher.call_count == len(dimos_to_ros) + assert bridge.node.create_subscription.call_count == len(ros_to_dimos) + + +def test_mobile_base_topics(bridge): + """Test mobile robot base topics.""" + base_topics = [ + ("/base/cmd_vel", Twist, ROSTwist, BridgeDirection.DIMOS_TO_ROS), + ("/base/odom", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS), + ("/base/global_pose", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS), + ("/base/path", Path, ROSPath, BridgeDirection.DIMOS_TO_ROS), + ("/base/local_plan", Path, ROSPath, BridgeDirection.DIMOS_TO_ROS), + ] + + for topic_name, dimos_type, ros_type, direction in base_topics: + bridge.add_topic(topic_name, dimos_type, ros_type, direction=direction) + + # Verify topics are properly categorized + for topic_name, dimos_type, ros_type, direction in base_topics: + bridge_info = bridge._bridges[topic_name] + assert bridge_info["direction"] == direction + assert bridge_info["dimos_type"] == dimos_type + assert bridge_info["ros_type"] == ros_type + + +def test_autonomous_vehicle_topics(bridge): + """Test autonomous vehicle topics.""" + av_topics = [ + ("/vehicle/steering_cmd", Twist, ROSTwist, BridgeDirection.DIMOS_TO_ROS), + ("/vehicle/throttle_cmd", Twist, ROSTwist, BridgeDirection.DIMOS_TO_ROS), + ("/vehicle/brake_cmd", Twist, ROSTwist, BridgeDirection.DIMOS_TO_ROS), + ("/vehicle/pose", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS), + ("/vehicle/planned_trajectory", Path, ROSPath, BridgeDirection.DIMOS_TO_ROS), + ("/vehicle/current_path", Path, ROSPath, BridgeDirection.ROS_TO_DIMOS), + ] + + for topic_name, dimos_type, ros_type, direction in av_topics: + bridge.add_topic(topic_name, dimos_type, ros_type, direction=direction) + + assert len(bridge._bridges) == len(av_topics) + + # Count control vs feedback topics + control_topics = [t for t in av_topics if "cmd" in t[0] or "planned" in t[0]] + feedback_topics = [t for t in av_topics if "pose" in t[0] or "current" in t[0]] + + assert len(control_topics) == 4 # steering, throttle, brake, planned_trajectory + assert len(feedback_topics) == 2 # pose, current_path + + +def test_remapping_with_navigation_stack(bridge): + """Test remapping with common navigation stack patterns.""" + # Map ROS2 Nav2 topics to custom DIMOS topics + nav_remapping = [ + ("/cmd_vel", Twist, ROSTwist, BridgeDirection.DIMOS_TO_ROS, "/nav2/cmd_vel"), + ("/odom", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS, "/robot/odometry"), + ("/global_plan", Path, ROSPath, BridgeDirection.DIMOS_TO_ROS, "/nav2/plan"), + ("/local_plan", Path, ROSPath, BridgeDirection.DIMOS_TO_ROS, "/nav2/local_plan"), + ("/goal_pose", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS, "/robot/goal"), + ] + + for topic, dimos_type, ros_type, direction, remap in nav_remapping: + bridge.add_topic(topic, dimos_type, ros_type, direction=direction, remap_topic=remap) + + # Verify DIMOS to ROS remappings + assert bridge._bridges["/cmd_vel"]["ros_topic_name"] == "/nav2/cmd_vel" + assert bridge._bridges["/global_plan"]["ros_topic_name"] == "/nav2/plan" + assert bridge._bridges["/local_plan"]["ros_topic_name"] == "/nav2/local_plan" + + # Verify ROS to DIMOS remappings + assert bridge._bridges["/odom"]["dimos_topic_name"] == "/robot/odometry" + assert bridge._bridges["/goal_pose"]["dimos_topic_name"] == "/robot/goal" + + +def test_remapping_with_robot_namespace(bridge): + """Test remapping for multi-robot systems with namespaces.""" + robot_id = "robot1" + + # Remap topics to include robot namespace + topics_with_namespace = [ + ("/cmd_vel", Twist, ROSTwist, BridgeDirection.DIMOS_TO_ROS, f"/{robot_id}/cmd_vel"), + ("/pose", PoseStamped, ROSPoseStamped, BridgeDirection.ROS_TO_DIMOS, f"/{robot_id}/pose"), + ("/path", Path, ROSPath, BridgeDirection.DIMOS_TO_ROS, f"/{robot_id}/path"), + ] + + for topic, dimos_type, ros_type, direction, remap in topics_with_namespace: + bridge.add_topic(topic, dimos_type, ros_type, direction=direction, remap_topic=remap) + + # Verify all topics are properly namespaced + assert bridge._bridges["/cmd_vel"]["ros_topic_name"] == "/robot1/cmd_vel" + assert bridge._bridges["/pose"]["dimos_topic_name"] == "/robot1/pose" + assert bridge._bridges["/path"]["ros_topic_name"] == "/robot1/path" + + +def test_remapping_preserves_original_key(bridge): + """Test that remapping preserves the original topic name as the key.""" + original_topic = "/original_topic" + remapped_name = "/remapped_topic" + + bridge.add_topic( + original_topic, + Twist, + ROSTwist, + direction=BridgeDirection.ROS_TO_DIMOS, + remap_topic=remapped_name, + ) + + # Original topic name should be the key + assert original_topic in bridge._bridges + assert remapped_name not in bridge._bridges + + # Bridge info should contain both names + bridge_info = bridge._bridges[original_topic] + assert bridge_info["ros_topic_name"] == original_topic + assert bridge_info["dimos_topic_name"] == remapped_name diff --git a/dimos/robot/unitree_webrtc/connection.py b/dimos/robot/unitree_webrtc/connection.py index faad61833d..d05676d3e4 100644 --- a/dimos/robot/unitree_webrtc/connection.py +++ b/dimos/robot/unitree_webrtc/connection.py @@ -73,6 +73,8 @@ class UnitreeWebRTCConnection: def __init__(self, ip: str, mode: str = "ai"): self.ip = ip self.mode = mode + self.stop_timer = None + self.cmd_vel_timeout = 0.2 self.conn = Go2WebRTCConnection(WebRTCConnectionMethod.LocalSTA, ip=self.ip) self.connect() @@ -127,7 +129,7 @@ def move(self, twist: Twist, duration: float = 0.0) -> bool: async def async_move(): self.conn.datachannel.pub_sub.publish_without_callback( RTC_TOPIC["WIRELESS_CONTROLLER"], - data={"lx": y, "ly": x, "rx": -yaw, "ry": 0}, + data={"lx": -y, "ly": x, "rx": -yaw, "ry": 0}, ) async def async_move_duration(): @@ -139,6 +141,15 @@ async def async_move_duration(): await async_move() await asyncio.sleep(sleep_time) + # Cancel existing timer and start a new one + if self.stop_timer: + self.stop_timer.cancel() + + # Auto-stop after 0.5 seconds if no new commands + self.stop_timer = threading.Timer(self.cmd_vel_timeout, self.stop) + self.stop_timer.daemon = True + self.stop_timer.start() + try: if duration > 0: # Send continuous move commands for the duration @@ -323,10 +334,20 @@ def stop(self) -> bool: Returns: bool: True if stop command was sent successfully """ + # Cancel timer since we're explicitly stopping + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + return self.move(Twist()) def disconnect(self) -> None: """Disconnect from the robot and clean up resources.""" + # Cancel timer + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + if hasattr(self, "task") and self.task: self.task.cancel() if hasattr(self, "conn"): diff --git a/dimos/robot/unitree_webrtc/unitree_b1/README.md b/dimos/robot/unitree_webrtc/unitree_b1/README.md index 040a0a6da9..8616fc286a 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/README.md +++ b/dimos/robot/unitree_webrtc/unitree_b1/README.md @@ -167,6 +167,25 @@ External Machine (Client) B1 Robot (Server) └─────────────────────┘ └──────────────────┘ ``` +## Setting up ROS Navigation stack with Unitree B1 + +### Setup external Wireless USB Adapter on onboard hardware +This is because the onboard hardware (mini PC, jetson, etc.) needs to connect to both the B1 wifi AP network to send cmd_vel messages over UDP, as well as the network running dimensional + + +Plug in wireless adapter +```bash +nmcli device status +nmcli device wifi list ifname *DEVICE_NAME* +# Connect to b1 network +nmcli device wifi connect "Unitree_B1-251" password "00000000" ifname *DEVICE_NAME* +# Verify connection +nmcli connection show --active +``` + +### *TODO: add more docs* + + ## Troubleshooting ### Cannot connect to B1 diff --git a/dimos/robot/unitree_webrtc/unitree_b1/b1_command.py b/dimos/robot/unitree_webrtc/unitree_b1/b1_command.py index ac56978e72..ab547dade2 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/b1_command.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/b1_command.py @@ -51,22 +51,27 @@ def from_twist(cls, twist, mode: int = 2): Returns: B1Command configured for the given Twist """ + # Max velocities from ROS needed to clamp to joystick ranges properly + MAX_LINEAR_VEL = 1.0 # m/s + MAX_ANGULAR_VEL = 2.0 # rad/s + if mode == 2: # WALK mode - velocity control return cls( - lx=-twist.angular.z, # ROS yaw → B1 turn (negated for correct direction) - ly=twist.linear.x, # ROS forward → B1 forward - rx=-twist.linear.y, # ROS lateral → B1 strafe (negated for correct direction) + # Scale and clamp to joystick range [-1, 1] + lx=max(-1.0, min(1.0, -twist.angular.z / MAX_ANGULAR_VEL)), + ly=max(-1.0, min(1.0, twist.linear.x / MAX_LINEAR_VEL)), + rx=max(-1.0, min(1.0, -twist.linear.y / MAX_LINEAR_VEL)), ry=0.0, # No pitch control in walk mode mode=mode, ) elif mode == 1: # STAND mode - body pose control # Map Twist pose controls to B1 joystick axes - # G1 cpp server maps: ly→bodyHeight, lx→euler[2], rx→euler[0], ry→euler[1] + # Already in normalized units, just clamp to [-1, 1] return cls( - lx=-twist.angular.z, # ROS yaw → B1 yaw (euler[2]) - ly=twist.linear.z, # ROS height → B1 bodyHeight - rx=-twist.angular.x, # ROS roll → B1 roll (euler[0]) - ry=twist.angular.y, # ROS pitch → B1 pitch (euler[1]) + lx=max(-1.0, min(1.0, -twist.angular.z)), # ROS yaw → B1 yaw + ly=max(-1.0, min(1.0, twist.linear.z)), # ROS height → B1 bodyHeight + rx=max(-1.0, min(1.0, -twist.angular.x)), # ROS roll → B1 roll + ry=max(-1.0, min(1.0, twist.angular.y)), # ROS pitch → B1 pitch mode=mode, ) else: diff --git a/dimos/robot/unitree_webrtc/unitree_b1/connection.py b/dimos/robot/unitree_webrtc/unitree_b1/connection.py index 9cbcd2f7d7..e06ddf501b 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/connection.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/connection.py @@ -17,16 +17,22 @@ """B1 Connection Module that accepts standard Twist commands and converts to UDP packets.""" +import logging import socket import threading import time from typing import Optional -from dimos.core import In, Module, rpc -from dimos.msgs.geometry_msgs import Twist +from dimos.core import In, Out, Module, rpc +from dimos.msgs.geometry_msgs import Twist, TwistStamped, PoseStamped +from dimos.msgs.nav_msgs.Odometry import Odometry from dimos.msgs.std_msgs import Int32 +from dimos.utils.logging_config import setup_logger from .b1_command import B1Command +# Setup logger with DEBUG level for troubleshooting +logger = setup_logger("dimos.robot.unitree_webrtc.unitree_b1.connection", level=logging.DEBUG) + class B1ConnectionModule(Module): """UDP connection module for B1 robot with standard Twist interface. @@ -35,9 +41,11 @@ class B1ConnectionModule(Module): internally converts to B1Command format, and sends UDP packets at 50Hz. """ - # Module inputs - cmd_vel: In[Twist] = None # Standard velocity commands + cmd_vel: In[TwistStamped] = None # Timestamped velocity commands from ROS mode_cmd: In[Int32] = None # Mode changes + odom_in: In[Odometry] = None # External odometry from ROS SLAM/lidar + + odom_pose: Out[PoseStamped] = None # Converted pose for internal use def __init__( self, ip: str = "192.168.12.1", port: int = 9090, test_mode: bool = False, *args, **kwargs @@ -63,7 +71,10 @@ def __init__( self.socket = None self.packet_count = 0 self.last_command_time = time.time() - self.command_timeout = 0.1 # 100ms timeout matching C++ server + self.command_timeout = 0.2 # 200ms safety timeout + self.watchdog_thread = None + self.watchdog_running = False + self.timeout_active = False @rpc def start(self): @@ -72,26 +83,40 @@ def start(self): # Setup UDP socket (unless in test mode) if not self.test_mode: self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - print(f"B1 Connection started - UDP to {self.ip}:{self.port} at 50Hz") + logger.info(f"B1 Connection started - UDP to {self.ip}:{self.port} at 50Hz") else: - print(f"[TEST MODE] B1 Connection started - would send to {self.ip}:{self.port}") + logger.info(f"[TEST MODE] B1 Connection started - would send to {self.ip}:{self.port}") # Subscribe to input streams if self.cmd_vel: - self.cmd_vel.subscribe(self.handle_twist) + self.cmd_vel.subscribe(self.handle_twist_stamped) if self.mode_cmd: self.mode_cmd.subscribe(self.handle_mode) + if self.odom_in: + self.odom_in.subscribe(self._publish_odom_pose) - # Start 50Hz sending thread + # Start threads self.running = True + self.watchdog_running = True + + # Start 50Hz sending thread self.send_thread = threading.Thread(target=self._send_loop, daemon=True) self.send_thread.start() + # Start watchdog thread + self.watchdog_thread = threading.Thread(target=self._watchdog_loop, daemon=True) + self.watchdog_thread.start() + return True @rpc def stop(self): """Stop the connection and send stop commands.""" + # Cancel timer since we're explicitly stopping + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + self.set_mode(0) # IDLE self._current_cmd = B1Command(mode=0) # Zero all velocities @@ -104,8 +129,12 @@ def stop(self): time.sleep(0.02) self.running = False + self.watchdog_running = False + if self.send_thread: self.send_thread.join(timeout=0.5) + if self.watchdog_thread: + self.watchdog_thread.join(timeout=0.5) if self.socket: self.socket.close() @@ -113,26 +142,58 @@ def stop(self): return True - def handle_twist(self, twist: Twist): - """Handle standard Twist message and convert to B1Command. + def handle_twist_stamped(self, twist_stamped: TwistStamped): + """Handle timestamped Twist message and convert to B1Command. This is called automatically when messages arrive on cmd_vel input. """ + # Extract Twist from TwistStamped + twist = Twist(linear=twist_stamped.linear, angular=twist_stamped.angular) + + logger.debug( + f"Received cmd_vel: linear=({twist.linear.x:.3f}, {twist.linear.y:.3f}, {twist.linear.z:.3f}), angular=({twist.angular.x:.3f}, {twist.angular.y:.3f}, {twist.angular.z:.3f})" + ) + + # In STAND mode (1), all twist values control body pose, not movement + # W/S: height (linear.z), A/D: yaw (angular.z), J/L: roll (angular.x), I/K: pitch (angular.y) + if self.current_mode == 1: + # In STAND mode, don't auto-switch since all inputs are valid body pose controls + has_movement = False + else: + # In other modes, consider linear x/y and angular.z as movement + has_movement = ( + abs(twist.linear.x) > 0.01 + or abs(twist.linear.y) > 0.01 + or abs(twist.angular.z) > 0.01 + ) + + if has_movement and self.current_mode not in (1, 2): + logger.info("Auto-switching to WALK mode for ROS control") + self.set_mode(2) + elif not has_movement and self.current_mode == 2: + logger.info("Auto-switching to IDLE mode (zero velocities)") + self.set_mode(0) + if self.test_mode: - print( - f"[TEST] Received Twist: linear=({twist.linear.x:.2f}, {twist.linear.y:.2f}), angular.z={twist.angular.z:.2f}" + logger.info( + f"[TEST] Received TwistStamped: linear=({twist.linear.x:.2f}, {twist.linear.y:.2f}), angular.z={twist.angular.z:.2f}" ) - # Convert Twist to B1Command + self._current_cmd = B1Command.from_twist(twist, self.current_mode) + + logger.debug(f"Converted to B1Command: {self._current_cmd}") + self.last_command_time = time.time() + self.timeout_active = False # Reset timeout state since we got a new command def handle_mode(self, mode_msg: Int32): """Handle mode change message. This is called automatically when messages arrive on mode_cmd input. """ + logger.debug(f"Received mode change: {mode_msg.data}") if self.test_mode: - print(f"[TEST] Received mode change: {mode_msg.data}") + logger.info(f"[TEST] Received mode change: {mode_msg.data}") self.set_mode(mode_msg.data) @rpc @@ -148,44 +209,34 @@ def set_mode(self, mode: int): self._current_cmd.rx = 0.0 self._current_cmd.ry = 0.0 + mode_names = {0: "IDLE", 1: "STAND", 2: "WALK", 6: "RECOVERY"} + logger.info(f"Mode changed to: {mode_names.get(mode, mode)}") if self.test_mode: - mode_names = {0: "IDLE", 1: "STAND", 2: "WALK", 6: "RECOVERY"} - print(f"[TEST] Mode changed to: {mode_names.get(mode, mode)}") + logger.info(f"[TEST] Mode changed to: {mode_names.get(mode, mode)}") return True def _send_loop(self): - """Continuously send current command at 50Hz with safety timeout.""" - timeout_warned = False + """Continuously send current command at 50Hz. + The watchdog thread handles timeout and zeroing commands, so this loop + just sends whatever is in self._current_cmd at 50Hz. + """ while self.running: try: - # Safety check: If no command received recently, send zeros - time_since_last_cmd = time.time() - self.last_command_time - - if time_since_last_cmd > self.command_timeout: - # Command is stale - send zero velocities for safety - if not timeout_warned: - if self.test_mode: - print( - f"[TEST] Command timeout ({time_since_last_cmd:.1f}s) - sending zeros" - ) - timeout_warned = True - - # Create safe idle command - safe_cmd = B1Command(mode=self.current_mode) - safe_cmd.lx = 0.0 - safe_cmd.ly = 0.0 - safe_cmd.rx = 0.0 - safe_cmd.ry = 0.0 - cmd_to_send = safe_cmd - else: - # Send command if fresh - if timeout_warned: - if self.test_mode: - print("[TEST] Commands resumed - control restored") - timeout_warned = False - cmd_to_send = self._current_cmd + # Watchdog handles timeout, we just send current command + cmd_to_send = self._current_cmd + + # Log status every second (50 packets) + if self.packet_count % 50 == 0: + logger.info( + f"Sending B1 commands at 50Hz | Mode: {self.current_mode} | Count: {self.packet_count}" + ) + if not self.test_mode: + logger.debug(f"Current B1Command: {self._current_cmd}") + data = cmd_to_send.to_bytes() + hex_str = " ".join(f"{b:02x}" for b in data) + logger.debug(f"UDP packet ({len(data)} bytes): {hex_str}") if self.socket: data = cmd_to_send.to_bytes() @@ -193,12 +244,67 @@ def _send_loop(self): self.packet_count += 1 - # Maintain 50Hz rate (20ms between packets) + # 50Hz rate (20ms between packets) time.sleep(0.020) except Exception as e: if self.running: - print(f"Send error: {e}") + logger.error(f"Send error: {e}") + + def _publish_odom_pose(self, msg: Odometry): + """Convert and publish odometry as PoseStamped. + + This matches G1's approach of receiving external odometry. + """ + if self.odom_pose: + pose_stamped = PoseStamped( + ts=msg.ts, + frame_id=msg.frame_id, + position=msg.pose.pose.position, + orientation=msg.pose.pose.orientation, + ) + self.odom_pose.publish(pose_stamped) + + def _watchdog_loop(self): + """Single watchdog thread that monitors command freshness. + + This is more efficient than creating Timer threads for every command. + Checks every 50ms if commands are stale and zeros them if needed. + """ + while self.watchdog_running: + try: + time_since_last_cmd = time.time() - self.last_command_time + + if time_since_last_cmd > self.command_timeout: + if not self.timeout_active: + # First time detecting timeout + logger.warning( + f"Watchdog timeout ({time_since_last_cmd:.1f}s) - zeroing commands" + ) + if self.test_mode: + logger.info(f"[TEST] Watchdog timeout - zeroing commands") + + # Zero velocities but maintain mode + self._current_cmd.lx = 0.0 + self._current_cmd.ly = 0.0 + self._current_cmd.rx = 0.0 + self._current_cmd.ry = 0.0 + + self.timeout_active = True + else: + if self.timeout_active: + # Commands resumed + logger.info("Watchdog: Commands resumed - control restored") + if self.test_mode: + logger.info("[TEST] Watchdog: Commands resumed") + self.timeout_active = False + + # Check every 50ms + time.sleep(0.05) + + except Exception as e: + if self.watchdog_running: + logger.error(f"Watchdog error: {e}") @rpc def idle(self): @@ -225,18 +331,21 @@ def recovery(self): return True @rpc - def move(self, twist: Twist, duration: float = 0.0): - """Direct RPC method for sending Twist commands. + def move(self, twist_stamped: TwistStamped, duration: float = 0.0): + """Direct RPC method for sending TwistStamped commands. Args: - twist: Velocity command + twist_stamped: Timestamped velocity command duration: Not used, kept for compatibility """ - self.handle_twist(twist) + self.handle_twist_stamped(twist_stamped) return True def cleanup(self): """Clean up resources when module is destroyed.""" + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None self.stop() @@ -257,18 +366,20 @@ def _send_loop(self): # Show timeout transitions if is_timeout and not timeout_warned: - print(f"[TEST] Command timeout! Sending zeros after {time_since_last_cmd:.1f}s") + logger.info( + f"[TEST] Command timeout! Sending zeros after {time_since_last_cmd:.1f}s" + ) timeout_warned = True elif not is_timeout and timeout_warned: - print("[TEST] Commands resumed - control restored") + logger.info("[TEST] Commands resumed - control restored") timeout_warned = False # Print current state every 0.5 seconds if self.packet_count % 25 == 0: if is_timeout: - print(f"[TEST] B1Cmd[ZEROS] (timeout) | Count: {self.packet_count}") + logger.info(f"[TEST] B1Cmd[ZEROS] (timeout) | Count: {self.packet_count}") else: - print(f"[TEST] {self._current_cmd} | Count: {self.packet_count}") + logger.info(f"[TEST] {self._current_cmd} | Count: {self.packet_count}") self.packet_count += 1 time.sleep(0.020) diff --git a/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py b/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py index e8857e31e2..34fb5d79c8 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/joystick_module.py @@ -23,19 +23,20 @@ # Force X11 driver to avoid OpenGL threading issues os.environ["SDL_VIDEODRIVER"] = "x11" +import time from dimos.core import Module, Out, rpc -from dimos.msgs.geometry_msgs import Twist, Vector3 +from dimos.msgs.geometry_msgs import Twist, TwistStamped, Vector3 from dimos.msgs.std_msgs import Int32 class JoystickModule(Module): """Pygame-based joystick control module for B1 testing. - Outputs standard Twist messages on /cmd_vel and mode changes on /b1/mode. + Outputs timestamped Twist messages on /cmd_vel and mode changes on /b1/mode. This allows testing the same interface that navigation will use. """ - twist_out: Out[Twist] = None # Standard velocity commands + twist_out: Out[TwistStamped] = None # Timestamped velocity commands mode_out: Out[Int32] = None # Mode changes def __init__(self, *args, **kwargs): @@ -119,7 +120,13 @@ def _pygame_loop(self): stop_twist = Twist() stop_twist.linear = Vector3(0, 0, 0) stop_twist.angular = Vector3(0, 0, 0) - self.twist_out.publish(stop_twist) + stop_twist_stamped = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=stop_twist.linear, + angular=stop_twist.angular, + ) + self.twist_out.publish(stop_twist_stamped) print("EMERGENCY STOP!") elif event.key == pygame.K_ESCAPE: # ESC still quits for development convenience @@ -178,8 +185,10 @@ def _pygame_loop(self): if pygame.K_k in self.keys_held: twist.angular.y = -1.0 # Pitch backward - # Always publish twist at 50Hz (matching working client behavior) - self.twist_out.publish(twist) + twist_stamped = TwistStamped( + ts=time.time(), frame_id="base_link", linear=twist.linear, angular=twist.angular + ) + self.twist_out.publish(twist_stamped) # Update pygame display self._update_display(twist) @@ -253,7 +262,13 @@ def stop(self): self.running = False # Send stop command stop_twist = Twist() - self.twist_out.publish(stop_twist) + stop_twist_stamped = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=stop_twist.linear, + angular=stop_twist.angular, + ) + self.twist_out.publish(stop_twist_stamped) return True def cleanup(self): diff --git a/dimos/robot/unitree_webrtc/unitree_b1/test_connection.py b/dimos/robot/unitree_webrtc/unitree_b1/test_connection.py new file mode 100644 index 0000000000..a5cfe7976c --- /dev/null +++ b/dimos/robot/unitree_webrtc/unitree_b1/test_connection.py @@ -0,0 +1,413 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2025 Dimensional Inc. + +"""Comprehensive tests for Unitree B1 connection module Timer implementation.""" + +import time +import threading + +from .connection import TestB1ConnectionModule +from dimos.msgs.geometry_msgs import TwistStamped, Vector3 +from dimos.msgs.std_msgs.Int32 import Int32 + + +class TestB1Connection: + """Test suite for B1 connection module with Timer implementation.""" + + def test_watchdog_actually_zeros_commands(self): + """Test that watchdog thread zeros commands after timeout.""" + conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Send a forward command + twist_stamped = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(1.0, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist_stamped) + + # Verify command is set + assert conn._current_cmd.ly == 1.0 + assert conn._current_cmd.mode == 2 + assert not conn.timeout_active + + # Wait for watchdog timeout (200ms + buffer) + time.sleep(0.3) + + # Verify commands were zeroed by watchdog + assert conn._current_cmd.ly == 0.0 + assert conn._current_cmd.lx == 0.0 + assert conn._current_cmd.rx == 0.0 + assert conn._current_cmd.ry == 0.0 + assert conn._current_cmd.mode == 2 # Mode maintained + assert conn.timeout_active + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + + def test_watchdog_resets_on_new_command(self): + """Test that watchdog timeout resets when new command arrives.""" + conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Send first command + twist1 = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(1.0, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist1) + assert conn._current_cmd.ly == 1.0 + + # Wait 150ms (not enough to trigger timeout) + time.sleep(0.15) + + # Send second command before timeout + twist2 = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(0.5, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist2) + + # Command should be updated and no timeout + assert conn._current_cmd.ly == 0.5 + assert not conn.timeout_active + + # Wait another 150ms (total 300ms from second command) + time.sleep(0.15) + # Should still not timeout since we reset the timer + assert not conn.timeout_active + assert conn._current_cmd.ly == 0.5 + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + + def test_watchdog_thread_efficiency(self): + """Test that watchdog uses only one thread regardless of command rate.""" + conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Count threads before sending commands + initial_thread_count = threading.active_count() + + # Send many commands rapidly (would create many Timer threads in old implementation) + for i in range(50): + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(i * 0.01, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + time.sleep(0.01) # 100Hz command rate + + # Thread count should be same (no new threads created) + final_thread_count = threading.active_count() + assert final_thread_count == initial_thread_count, "No new threads should be created" + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + + def test_watchdog_with_send_loop_blocking(self): + """Test that watchdog still works if send loop blocks.""" + conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + + # Mock the send loop to simulate blocking + original_send_loop = conn._send_loop + block_event = threading.Event() + + def blocking_send_loop(): + # Block immediately + block_event.wait() + # Then run normally + original_send_loop() + + conn._send_loop = blocking_send_loop + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Send command + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(1.0, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + assert conn._current_cmd.ly == 1.0 + + # Wait for watchdog timeout + time.sleep(0.3) + + # Watchdog should have zeroed commands despite blocked send loop + assert conn._current_cmd.ly == 0.0 + assert conn.timeout_active + + # Unblock send loop + block_event.set() + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + + def test_continuous_commands_prevent_timeout(self): + """Test that continuous commands prevent watchdog timeout.""" + conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Send commands continuously for 500ms (should prevent timeout) + start = time.time() + commands_sent = 0 + while time.time() - start < 0.5: + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(0.5, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + commands_sent += 1 + time.sleep(0.05) # 50ms between commands (well under 200ms timeout) + + # Should never timeout + assert not conn.timeout_active, "Should not timeout with continuous commands" + assert conn._current_cmd.ly == 0.5, "Commands should still be active" + assert commands_sent >= 9, f"Should send at least 9 commands in 500ms, sent {commands_sent}" + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + + def test_watchdog_timing_accuracy(self): + """Test that watchdog zeros commands at approximately 200ms.""" + conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Send command and record time + start_time = time.time() + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(1.0, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + + # Wait for timeout checking periodically + timeout_time = None + while time.time() - start_time < 0.5: + if conn.timeout_active: + timeout_time = time.time() + break + time.sleep(0.01) + + assert timeout_time is not None, "Watchdog should timeout within 500ms" + + # Check timing (should be close to 200ms + up to 50ms watchdog interval) + elapsed = timeout_time - start_time + print(f"\nWatchdog timeout occurred at exactly {elapsed:.3f} seconds") + assert 0.19 <= elapsed <= 0.26, f"Watchdog timed out at {elapsed:.3f}s, expected ~0.2-0.25s" + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + + def test_mode_changes_with_watchdog(self): + """Test that mode changes work correctly with watchdog.""" + conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Send walk command + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(1.0, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + assert conn.current_mode == 2 + assert conn._current_cmd.ly == 1.0 + + # Wait for timeout first + time.sleep(0.25) + assert conn.timeout_active + assert conn._current_cmd.ly == 0.0 # Watchdog zeroed it + + # Now change mode to STAND + mode_msg = Int32() + mode_msg.data = 1 # STAND + conn.handle_mode(mode_msg) + assert conn.current_mode == 1 + assert conn._current_cmd.mode == 1 + # timeout_active stays true since we didn't send new movement commands + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + + def test_watchdog_stops_movement_when_commands_stop(self): + """Verify watchdog zeros commands when packets stop being sent.""" + conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Simulate sending movement commands for a while + for i in range(5): + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(1.0, 0, 0), + angular=Vector3(0, 0, 0.5), # Forward and turning + ) + conn.handle_twist_stamped(twist) + time.sleep(0.05) # Send at 20Hz + + # Verify robot is moving + assert conn._current_cmd.ly == 1.0 + assert conn._current_cmd.lx == -0.25 # angular.z * 0.5 -> lx (for turning) + assert conn.current_mode == 2 # WALK mode + assert not conn.timeout_active + + # Wait for watchdog to detect timeout (200ms + buffer) + time.sleep(0.3) + + assert conn.timeout_active, "Watchdog should have detected timeout" + assert conn._current_cmd.ly == 0.0, "Forward velocity should be zeroed" + assert conn._current_cmd.lx == 0.0, "Lateral velocity should be zeroed" + assert conn._current_cmd.rx == 0.0, "Rotation X should be zeroed" + assert conn._current_cmd.ry == 0.0, "Rotation Y should be zeroed" + assert conn.current_mode == 2, "Mode should stay as WALK" + + # Verify recovery works - send new command + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(0.5, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + + # Give watchdog time to detect recovery + time.sleep(0.1) + + assert not conn.timeout_active, "Should recover from timeout" + assert conn._current_cmd.ly == 0.5, "Should accept new commands" + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) + + def test_rapid_command_thread_safety(self): + """Test thread safety with rapid commands from multiple threads.""" + conn = TestB1ConnectionModule(ip="127.0.0.1", port=9090) + conn.running = True + conn.watchdog_running = True + conn.send_thread = threading.Thread(target=conn._send_loop, daemon=True) + conn.send_thread.start() + conn.watchdog_thread = threading.Thread(target=conn._watchdog_loop, daemon=True) + conn.watchdog_thread.start() + + # Count initial threads + initial_threads = threading.active_count() + + # Send commands from multiple threads rapidly + def send_commands(thread_id): + for i in range(10): + twist = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(thread_id * 0.1, 0, 0), + angular=Vector3(0, 0, 0), + ) + conn.handle_twist_stamped(twist) + time.sleep(0.01) + + threads = [] + for i in range(3): + t = threading.Thread(target=send_commands, args=(i,)) + threads.append(t) + t.start() + + for t in threads: + t.join() + + # Thread count should only increase by the 3 sender threads we created + # No additional Timer threads should be created + final_threads = threading.active_count() + assert final_threads <= initial_threads, "No extra threads should be created by watchdog" + + # Commands should still work correctly + assert conn._current_cmd.ly >= 0, "Last command should be set" + assert not conn.timeout_active, "Should not be in timeout with recent commands" + + conn.running = False + conn.watchdog_running = False + conn.send_thread.join(timeout=0.5) + conn.watchdog_thread.join(timeout=0.5) diff --git a/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py b/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py index a81d75cc2f..2f93d6f973 100644 --- a/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py +++ b/dimos/robot/unitree_webrtc/unitree_b1/unitree_b1.py @@ -25,10 +25,13 @@ from typing import Optional from dimos import core -from dimos.msgs.geometry_msgs import Twist +from dimos.msgs.geometry_msgs import Twist, TwistStamped, PoseStamped +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.tf2_msgs.TFMessage import TFMessage from dimos.msgs.std_msgs import Int32 from dimos.protocol.pubsub.lcmpubsub import LCM from dimos.robot.robot import Robot +from dimos.robot.ros_bridge import ROSBridge, BridgeDirection from dimos.robot.unitree_webrtc.unitree_b1.connection import ( B1ConnectionModule, TestB1ConnectionModule, @@ -36,6 +39,9 @@ from dimos.skills.skills import SkillLibrary from dimos.types.robot_capabilities import RobotCapability from dimos.utils.logging_config import setup_logger +from geometry_msgs.msg import TwistStamped as ROSTwistStamped +from nav_msgs.msg import Odometry as ROSOdometry +from tf2_msgs.msg import TFMessage as ROSTFMessage logger = setup_logger("dimos.robot.unitree_webrtc.unitree_b1", level=logging.INFO) @@ -56,6 +62,7 @@ def __init__( output_dir: str = None, skill_library: Optional[SkillLibrary] = None, enable_joystick: bool = False, + enable_ros_bridge: bool = True, test_mode: bool = False, ): """Initialize the B1 robot. @@ -66,6 +73,7 @@ def __init__( output_dir: Directory for saving outputs skill_library: Skill library instance (optional) enable_joystick: Enable pygame joystick control module + enable_ros_bridge: Enable ROS bridge for external control test_mode: Test mode - print commands instead of sending UDP """ super().__init__() @@ -73,10 +81,12 @@ def __init__( self.port = port self.output_dir = output_dir or os.path.join(os.getcwd(), "assets", "output") self.enable_joystick = enable_joystick + self.enable_ros_bridge = enable_ros_bridge self.test_mode = test_mode self.capabilities = [RobotCapability.LOCOMOTION] self.connection = None self.joystick = None + self.ros_bridge = None os.makedirs(self.output_dir, exist_ok=True) logger.info(f"Robot outputs will be saved to: {self.output_dir}") @@ -93,16 +103,18 @@ def start(self): else: self.connection = self.dimos.deploy(B1ConnectionModule, self.ip, self.port) - # Configure LCM transports for connection - self.connection.cmd_vel.transport = core.LCMTransport("/cmd_vel", Twist) + # Configure LCM transports for connection (matching G1 pattern) + self.connection.cmd_vel.transport = core.LCMTransport("/cmd_vel", TwistStamped) self.connection.mode_cmd.transport = core.LCMTransport("/b1/mode", Int32) + self.connection.odom_in.transport = core.LCMTransport("/state_estimation", Odometry) + self.connection.odom_pose.transport = core.LCMTransport("/odom", PoseStamped) # Deploy joystick move_vel control if self.enable_joystick: from dimos.robot.unitree_webrtc.unitree_b1.joystick_module import JoystickModule self.joystick = self.dimos.deploy(JoystickModule) - self.joystick.twist_out.transport = core.LCMTransport("/cmd_vel", Twist) + self.joystick.twist_out.transport = core.LCMTransport("/cmd_vel", TwistStamped) self.joystick.mode_out.transport = core.LCMTransport("/b1/mode", Int32) logger.info("Joystick module deployed - pygame window will open") @@ -113,20 +125,47 @@ def start(self): if self.joystick: self.joystick.start() + # Deploy ROS bridge if enabled (matching G1 pattern) + if self.enable_ros_bridge: + self._deploy_ros_bridge() + logger.info(f"UnitreeB1 initialized - UDP control to {self.ip}:{self.port}") if self.enable_joystick: logger.info("Pygame joystick module enabled for testing") + if self.enable_ros_bridge: + logger.info("ROS bridge enabled for external control") + + def _deploy_ros_bridge(self): + """Deploy and configure ROS bridge (matching G1 implementation).""" + self.ros_bridge = ROSBridge("b1_ros_bridge") + + # Add /cmd_vel topic from ROS to DIMOS + self.ros_bridge.add_topic( + "/cmd_vel", TwistStamped, ROSTwistStamped, direction=BridgeDirection.ROS_TO_DIMOS + ) + + # Add /state_estimation topic from ROS to DIMOS (external odometry) + self.ros_bridge.add_topic( + "/state_estimation", Odometry, ROSOdometry, direction=BridgeDirection.ROS_TO_DIMOS + ) + + # Add /tf topic from ROS to DIMOS + self.ros_bridge.add_topic( + "/tf", TFMessage, ROSTFMessage, direction=BridgeDirection.ROS_TO_DIMOS + ) + + logger.info("ROS bridge deployed: /cmd_vel, /state_estimation, /tf (ROS → DIMOS)") # Robot control methods (standard interface) - def move(self, twist: Twist, duration: float = 0.0): - """Send movement command to robot using standard Twist. + def move(self, twist_stamped: TwistStamped, duration: float = 0.0): + """Send movement command to robot using timestamped Twist. Args: - twist: Twist message with linear and angular velocities + twist_stamped: TwistStamped message with linear and angular velocities duration: How long to move (not used for B1) """ if self.connection: - self.connection.move(twist, duration) + self.connection.move(twist_stamped, duration) def stop(self): """Stop all robot movement.""" @@ -151,22 +190,34 @@ def idle(self): self.connection.idle() logger.info("B1 switched to IDLE mode") - def cleanup(self): - """Clean up robot resources.""" - logger.info("Cleaning up B1 robot...") + def shutdown(self): + """Shutdown the robot and clean up resources.""" + logger.info("Shutting down UnitreeB1...") # Stop robot movement self.stop() - # Clean up modules + # Shutdown ROS bridge if it exists + if self.ros_bridge is not None: + try: + self.ros_bridge.shutdown() + logger.info("ROS bridge shut down successfully") + except Exception as e: + logger.error(f"Error shutting down ROS bridge: {e}") + + # Clean up connection module if self.connection: self.connection.cleanup() - logger.info("B1 cleanup complete") + logger.info("UnitreeB1 shutdown complete") + + def cleanup(self): + """Clean up robot resources (calls shutdown for consistency).""" + self.shutdown() def __del__(self): """Destructor to ensure cleanup.""" - self.cleanup() + self.shutdown() def main(): @@ -177,6 +228,10 @@ def main(): parser.add_argument("--ip", default="192.168.12.1", help="Robot IP address") parser.add_argument("--port", type=int, default=9090, help="UDP port") parser.add_argument("--joystick", action="store_true", help="Enable pygame joystick control") + parser.add_argument("--ros-bridge", action="store_true", default=True, help="Enable ROS bridge") + parser.add_argument( + "--no-ros-bridge", dest="ros_bridge", action="store_false", help="Disable ROS bridge" + ) parser.add_argument("--output-dir", help="Output directory for logs/data") parser.add_argument( "--test", action="store_true", help="Test mode - print commands instead of UDP" @@ -189,6 +244,7 @@ def main(): port=args.port, output_dir=args.output_dir, enable_joystick=args.joystick, + enable_ros_bridge=args.ros_bridge, test_mode=args.test, ) @@ -216,7 +272,10 @@ def main(): # Manual control example print("\nB1 Robot ready for commands") print("Use robot.idle(), robot.stand(), robot.walk() to change modes") - print("Use robot.move(Twist(...)) to send velocity commands") + if args.ros_bridge: + print("ROS bridge active - listening for /cmd_vel and /state_estimation") + else: + print("Use robot.move(TwistStamped(...)) to send velocity commands") print("Press Ctrl+C to exit\n") import time diff --git a/dimos/robot/unitree_webrtc/unitree_g1.py b/dimos/robot/unitree_webrtc/unitree_g1.py index 2049d41c7c..1ef2709dc5 100644 --- a/dimos/robot/unitree_webrtc/unitree_g1.py +++ b/dimos/robot/unitree_webrtc/unitree_g1.py @@ -14,8 +14,8 @@ # limitations under the License. """ -Unitree G1 humanoid robot with ZED camera integration. -Minimal implementation using WebRTC connection for robot control and ZED for vision. +Unitree G1 humanoid robot. +Minimal implementation using WebRTC connection for robot control. """ import os @@ -25,18 +25,25 @@ from dimos import core from dimos.core import Module, In, Out, rpc -from dimos.hardware.zed_camera import ZEDModule -from dimos.msgs.geometry_msgs import PoseStamped, Transform, Twist, Vector3, Quaternion +from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped +from dimos.msgs.nav_msgs.Odometry import Odometry from dimos.msgs.sensor_msgs import Image from dimos_lcm.sensor_msgs import CameraInfo +from dimos.msgs.tf2_msgs.TFMessage import TFMessage from dimos.protocol import pubsub from dimos.protocol.pubsub.lcmpubsub import LCM -from dimos.protocol.tf import TF from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills +from dimos.robot.ros_bridge import ROSBridge, BridgeDirection +from geometry_msgs.msg import TwistStamped as ROSTwistStamped +from nav_msgs.msg import Odometry as ROSOdometry +from tf2_msgs.msg import TFMessage as ROSTFMessage from dimos.skills.skills import SkillLibrary from dimos.robot.robot import Robot + +# from dimos.hardware.zed_camera import ZEDModule from dimos.types.robot_capabilities import RobotCapability from dimos.utils.logging_config import setup_logger @@ -51,19 +58,18 @@ class G1ConnectionModule(Module): - """Simplified connection module for G1 - uses WebRTC for control, no video.""" + """Simplified connection module for G1 - uses WebRTC for control.""" + + movecmd: In[TwistStamped] = None + odom_in: In[Odometry] = None - movecmd: In[Twist] = None - odom: Out[PoseStamped] = None + odom_pose: Out[PoseStamped] = None ip: str connection_type: str = "webrtc" - _odom: PoseStamped = None - def __init__(self, ip: str = None, connection_type: str = "webrtc", *args, **kwargs): self.ip = ip self.connection_type = connection_type - self.tf = TF() self.connection = None Module.__init__(self, *args, **kwargs) @@ -72,47 +78,25 @@ def start(self): """Start the connection and subscribe to sensor streams.""" # Use the exact same UnitreeWebRTCConnection as Go2 self.connection = UnitreeWebRTCConnection(self.ip) - - # Subscribe only to odometry (no video/lidar for G1) - self.connection.odom_stream().subscribe(self._publish_tf) self.movecmd.subscribe(self.move) - - def _publish_tf(self, msg): - """Publish odometry and TF transforms.""" - self._odom = msg - self.odom.publish(msg) - self.tf.publish(Transform.from_pose("base_link", msg)) - - # Publish ZED camera transform relative to robot base - zed_transform = Transform( - translation=Vector3(0.0, 0.0, 1.5), # ZED mounted at ~1.5m height on G1 - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="base_link", - child_frame_id="zed_camera", - ts=time.time(), + self.odom_in.subscribe(self._publish_odom_pose) + + def _publish_odom_pose(self, msg: Odometry): + self.odom_pose.publish( + PoseStamped( + ts=msg.ts, + frame_id=msg.frame_id, + position=msg.pose.pose.position, + orientation=msg.pose.orientation, + ) ) - self.tf.publish(zed_transform) - - @rpc - def get_odom(self) -> Optional[PoseStamped]: - """Get the robot's odometry.""" - return self._odom @rpc - def move(self, twist: Twist, duration: float = 0.0): + def move(self, twist_stamped: TwistStamped, duration: float = 0.0): """Send movement command to robot.""" + twist = Twist(linear=twist_stamped.linear, angular=twist_stamped.angular) self.connection.move(twist, duration) - @rpc - def standup(self): - """Make the robot stand up.""" - return self.connection.standup() - - @rpc - def liedown(self): - """Make the robot lie down.""" - return self.connection.liedown() - @rpc def publish_request(self, topic: str, data: dict): """Forward WebRTC publish requests to connection.""" @@ -120,26 +104,34 @@ def publish_request(self, topic: str, data: dict): class UnitreeG1(Robot): - """Unitree G1 humanoid robot with ZED camera for vision.""" + """Unitree G1 humanoid robot.""" def __init__( self, ip: str, output_dir: str = None, + websocket_port: int = 7779, skill_library: Optional[SkillLibrary] = None, recording_path: str = None, replay_path: str = None, enable_joystick: bool = False, + enable_connection: bool = True, + enable_ros_bridge: bool = True, + enable_camera: bool = False, ): """Initialize the G1 robot. Args: ip: Robot IP address output_dir: Directory for saving outputs + websocket_port: Port for web visualization skill_library: Skill library instance recording_path: Path to save recordings (if recording) replay_path: Path to replay recordings from (if replaying) enable_joystick: Enable pygame joystick control + enable_connection: Enable robot connection module + enable_ros_bridge: Enable ROS bridge + enable_camera: Enable ZED camera module """ super().__init__() self.ip = ip @@ -147,6 +139,10 @@ def __init__( self.recording_path = recording_path self.replay_path = replay_path self.enable_joystick = enable_joystick + self.enable_connection = enable_connection + self.enable_ros_bridge = enable_ros_bridge + self.enable_camera = enable_camera + self.websocket_port = websocket_port self.lcm = LCM() # Initialize skill library with G1 robot type @@ -157,14 +153,16 @@ def __init__( self.skill_library = skill_library # Set robot capabilities - self.capabilities = [RobotCapability.LOCOMOTION, RobotCapability.VISION] + self.capabilities = [RobotCapability.LOCOMOTION] # Module references self.dimos = None self.connection = None - self.zed_camera = None + self.websocket_vis = None self.foxglove_bridge = None self.joystick = None + self.ros_bridge = None + self.zed_camera = None self._setup_directories() @@ -175,32 +173,37 @@ def _setup_directories(self): def start(self): """Start the robot system with all modules.""" - self.dimos = core.start( - 3 if self.enable_joystick else 2 - ) # Extra worker for joystick if enabled + self.dimos = core.start(4) # 2 workers for connection and visualization + + if self.enable_connection: + self._deploy_connection() - self._deploy_connection() - self._deploy_camera() self._deploy_visualization() + if self.enable_camera: + self._deploy_camera() + if self.enable_joystick: self._deploy_joystick() + if self.enable_ros_bridge: + self._deploy_ros_bridge() + self._start_modules() self.lcm.start() logger.info("UnitreeG1 initialized and started") - logger.info("ZED camera module deployed for vision") + logger.info(f"WebSocket visualization available at http://localhost:{self.websocket_port}") def _deploy_connection(self): """Deploy and configure the connection module.""" self.connection = self.dimos.deploy(G1ConnectionModule, self.ip) # Configure LCM transports - self.connection.odom.transport = core.LCMTransport("/g1/odom", PoseStamped) - # Use standard /cmd_vel topic for compatibility with joystick and navigation - self.connection.movecmd.transport = core.LCMTransport("/cmd_vel", Twist) + self.connection.movecmd.transport = core.LCMTransport("/cmd_vel", TwistStamped) + self.connection.odom_in.transport = core.LCMTransport("/state_estimation", Odometry) + self.connection.odom_pose.transport = core.LCMTransport("/odom", PoseStamped) def _deploy_camera(self): """Deploy and configure the ZED camera module (real or fake based on replay_path).""" @@ -241,7 +244,14 @@ def _deploy_camera(self): logger.info("ZED camera module configured") def _deploy_visualization(self): - """Deploy visualization tools.""" + """Deploy and configure visualization modules.""" + # Deploy WebSocket visualization module + self.websocket_vis = self.dimos.deploy(WebsocketVisModule, port=self.websocket_port) + self.websocket_vis.movecmd_stamped.transport = core.LCMTransport("/cmd_vel", TwistStamped) + + # Note: robot_pose connection removed since odom was removed from G1ConnectionModule + + # Deploy Foxglove bridge self.foxglove_bridge = FoxgloveBridge() def _deploy_joystick(self): @@ -253,10 +263,32 @@ def _deploy_joystick(self): self.joystick.twist_out.transport = core.LCMTransport("/cmd_vel", Twist) logger.info("Joystick module deployed - pygame window will open") + def _deploy_ros_bridge(self): + """Deploy and configure ROS bridge.""" + self.ros_bridge = ROSBridge("g1_ros_bridge") + + # Add /cmd_vel topic from ROS to DIMOS + self.ros_bridge.add_topic( + "/cmd_vel", TwistStamped, ROSTwistStamped, direction=BridgeDirection.ROS_TO_DIMOS + ) + + # Add /state_estimation topic from ROS to DIMOS + self.ros_bridge.add_topic( + "/state_estimation", Odometry, ROSOdometry, direction=BridgeDirection.ROS_TO_DIMOS + ) + + # Add /tf topic from ROS to DIMOS + self.ros_bridge.add_topic( + "/tf", TFMessage, ROSTFMessage, direction=BridgeDirection.ROS_TO_DIMOS + ) + + logger.info("ROS bridge deployed: /cmd_vel, /state_estimation, /tf (ROS → DIMOS)") + def _start_modules(self): """Start all deployed modules.""" - self.connection.start() - self.zed_camera.start() + if self.connection: + self.connection.start() + self.websocket_vis.start() self.foxglove_bridge.start() if self.joystick: @@ -272,28 +304,35 @@ def _start_modules(self): self.skill_library.init() self.skill_library.initialize_skills() - def get_single_rgb_frame(self, timeout: float = 2.0) -> Image: - """Get a single RGB frame from the ZED camera.""" - from dimos.protocol.pubsub.lcmpubsub import Topic - - topic = Topic("/zed/color_image", Image) - return self.lcm.wait_for_message(topic, timeout=timeout) - - def move(self, twist: Twist, duration: float = 0.0): + def move(self, twist_stamped: TwistStamped, duration: float = 0.0): """Send movement command to robot.""" - self.connection.move(twist, duration) + self.connection.move(twist_stamped, duration) def get_odom(self) -> PoseStamped: """Get the robot's odometry.""" - return self.connection.get_odom() + # Note: odom functionality removed from G1ConnectionModule + return None + + def shutdown(self): + """Shutdown the robot and clean up resources.""" + logger.info("Shutting down UnitreeG1...") + + # Shutdown ROS bridge if it exists + if self.ros_bridge is not None: + try: + self.ros_bridge.shutdown() + logger.info("ROS bridge shut down successfully") + except Exception as e: + logger.error(f"Error shutting down ROS bridge: {e}") - def standup(self): - """Make the robot stand up.""" - return self.connection.standup() + # Stop other modules if needed + if self.websocket_vis: + try: + self.websocket_vis.stop() + except Exception as e: + logger.error(f"Error stopping websocket vis: {e}") - def liedown(self): - """Make the robot lie down.""" - return self.connection.liedown() + logger.info("UnitreeG1 shutdown complete") def main(): @@ -307,6 +346,7 @@ def main(): parser = argparse.ArgumentParser(description="Unitree G1 Humanoid Robot Control") parser.add_argument("--ip", default=os.getenv("ROBOT_IP"), help="Robot IP address") parser.add_argument("--joystick", action="store_true", help="Enable pygame joystick control") + parser.add_argument("--camera", action="store_true", help="Enable ZED camera module") parser.add_argument("--output-dir", help="Output directory for logs/data") parser.add_argument("--record", help="Path to save recording") parser.add_argument("--replay", help="Path to replay recording from") @@ -325,6 +365,9 @@ def main(): recording_path=args.record, replay_path=args.replay, enable_joystick=args.joystick, + enable_camera=args.camera, + enable_connection=False, + enable_ros_bridge=True, ) robot.start() @@ -346,6 +389,7 @@ def main(): time.sleep(1) except KeyboardInterrupt: logger.info("Shutting down...") + robot.shutdown() if __name__ == "__main__": diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index 25141a85ec..760edc60e6 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -20,6 +20,7 @@ import asyncio import threading +import time from typing import Any, Dict, Optional import base64 import numpy as np @@ -32,7 +33,7 @@ from dimos.core import Module, In, Out, rpc from dimos_lcm.std_msgs import Bool -from dimos.msgs.geometry_msgs import PoseStamped, Twist, Vector3 +from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped, Vector3 from dimos.msgs.nav_msgs import OccupancyGrid, Path from dimos.utils.logging_config import setup_logger @@ -68,6 +69,7 @@ class WebsocketVisModule(Module): explore_cmd: Out[Bool] = None stop_explore_cmd: Out[Bool] = None movecmd: Out[Twist] = None + movecmd_stamped: Out[TwistStamped] = None def __init__(self, port: int = 7779, **kwargs): """Initialize the WebSocket visualization module. @@ -111,9 +113,13 @@ def start(self): self.server_thread = threading.Thread(target=self._run_server, daemon=True) self.server_thread.start() - self.robot_pose.subscribe(self._on_robot_pose) - self.path.subscribe(self._on_path) - self.global_costmap.subscribe(self._on_global_costmap) + # Only subscribe to connected topics + if self.robot_pose.connection is not None: + self.robot_pose.subscribe(self._on_robot_pose) + if self.path.connection is not None: + self.path.subscribe(self._on_path) + if self.global_costmap.connection is not None: + self.global_costmap.subscribe(self._on_global_costmap) logger.info(f"WebSocket server started on http://localhost:{self.port}") @@ -167,11 +173,27 @@ async def stop_explore(sid): @self.sio.event async def move_command(sid, data): - twist = Twist( - linear=Vector3(data["linear"]["x"], data["linear"]["y"], data["linear"]["z"]), - angular=Vector3(data["angular"]["x"], data["angular"]["y"], data["angular"]["z"]), - ) - self.movecmd.publish(twist) + # Publish Twist if transport is configured + if self.movecmd and self.movecmd.transport: + twist = Twist( + linear=Vector3(data["linear"]["x"], data["linear"]["y"], data["linear"]["z"]), + angular=Vector3( + data["angular"]["x"], data["angular"]["y"], data["angular"]["z"] + ), + ) + self.movecmd.publish(twist) + + # Publish TwistStamped if transport is configured + if self.movecmd_stamped and self.movecmd_stamped.transport: + twist_stamped = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(data["linear"]["x"], data["linear"]["y"], data["linear"]["z"]), + angular=Vector3( + data["angular"]["x"], data["angular"]["y"], data["angular"]["z"] + ), + ) + self.movecmd_stamped.publish(twist_stamped) def _run_server(self): uvicorn.run(