diff --git a/dimos/msgs/geometry_msgs/Pose.py b/dimos/msgs/geometry_msgs/Pose.py index 0706a144f6..1cf6c95442 100644 --- a/dimos/msgs/geometry_msgs/Pose.py +++ b/dimos/msgs/geometry_msgs/Pose.py @@ -14,13 +14,20 @@ from __future__ import annotations -import struct -import traceback -from io import BytesIO -from typing import BinaryIO, TypeAlias +from typing import TypeAlias from dimos_lcm.geometry_msgs import Pose as LCMPose from dimos_lcm.geometry_msgs import Transform as LCMTransform + +try: + from geometry_msgs.msg import Pose as ROSPose + from geometry_msgs.msg import Point as ROSPoint + from geometry_msgs.msg import Quaternion as ROSQuaternion +except ImportError: + ROSPose = None + ROSPoint = None + ROSQuaternion = None + from plum import dispatch from dimos.msgs.geometry_msgs.Quaternion import Quaternion, QuaternionConvertable @@ -207,6 +214,43 @@ def __add__(self, other: "Pose" | PoseConvertable | LCMTransform | Transform) -> return Pose(new_position, new_orientation) + @classmethod + def from_ros_msg(cls, ros_msg: ROSPose) -> "Pose": + """Create a Pose from a ROS geometry_msgs/Pose message. + + Args: + ros_msg: ROS Pose message + + Returns: + Pose instance + """ + position = Vector3(ros_msg.position.x, ros_msg.position.y, ros_msg.position.z) + orientation = Quaternion( + ros_msg.orientation.x, + ros_msg.orientation.y, + ros_msg.orientation.z, + ros_msg.orientation.w, + ) + return cls(position, orientation) + + def to_ros_msg(self) -> ROSPose: + """Convert to a ROS geometry_msgs/Pose message. + + Returns: + ROS Pose message + """ + ros_msg = ROSPose() + ros_msg.position = ROSPoint( + x=float(self.position.x), y=float(self.position.y), z=float(self.position.z) + ) + ros_msg.orientation = ROSQuaternion( + x=float(self.orientation.x), + y=float(self.orientation.y), + z=float(self.orientation.z), + w=float(self.orientation.w), + ) + return ros_msg + @dispatch def to_pose(value: "Pose") -> "Pose": diff --git a/dimos/msgs/geometry_msgs/PoseStamped.py b/dimos/msgs/geometry_msgs/PoseStamped.py index ea1198818d..c44c9cd4ff 100644 --- a/dimos/msgs/geometry_msgs/PoseStamped.py +++ b/dimos/msgs/geometry_msgs/PoseStamped.py @@ -22,6 +22,12 @@ from dimos_lcm.geometry_msgs import PoseStamped as LCMPoseStamped from dimos_lcm.std_msgs import Header as LCMHeader from dimos_lcm.std_msgs import Time as LCMTime + +try: + from geometry_msgs.msg import PoseStamped as ROSPoseStamped +except ImportError: + ROSPoseStamped = None + from plum import dispatch from dimos.msgs.geometry_msgs.Pose import Pose @@ -109,3 +115,44 @@ def find_transform(self, other: PoseStamped) -> Transform: translation=local_translation, rotation=relative_rotation, ) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSPoseStamped) -> "PoseStamped": + """Create a PoseStamped from a ROS geometry_msgs/PoseStamped message. + + Args: + ros_msg: ROS PoseStamped message + + Returns: + PoseStamped instance + """ + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert pose + pose = Pose.from_ros_msg(ros_msg.pose) + + return cls( + ts=ts, + frame_id=ros_msg.header.frame_id, + position=pose.position, + orientation=pose.orientation, + ) + + def to_ros_msg(self) -> ROSPoseStamped: + """Convert to a ROS geometry_msgs/PoseStamped message. + + Returns: + ROS PoseStamped message + """ + ros_msg = ROSPoseStamped() + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set pose + ros_msg.pose = Pose.to_ros_msg(self) + + return ros_msg diff --git a/dimos/msgs/geometry_msgs/PoseWithCovariance.py b/dimos/msgs/geometry_msgs/PoseWithCovariance.py new file mode 100644 index 0000000000..3a49522653 --- /dev/null +++ b/dimos/msgs/geometry_msgs/PoseWithCovariance.py @@ -0,0 +1,225 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TypeAlias + +import numpy as np +from dimos_lcm.geometry_msgs import PoseWithCovariance as LCMPoseWithCovariance +from plum import dispatch + +try: + from geometry_msgs.msg import PoseWithCovariance as ROSPoseWithCovariance +except ImportError: + ROSPoseWithCovariance = None + +from dimos.msgs.geometry_msgs.Pose import Pose, PoseConvertable +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + +# Types that can be converted to/from PoseWithCovariance +PoseWithCovarianceConvertable: TypeAlias = ( + tuple[PoseConvertable, list[float] | np.ndarray] + | LCMPoseWithCovariance + | dict[str, PoseConvertable | list[float] | np.ndarray] +) + + +class PoseWithCovariance(LCMPoseWithCovariance): + pose: Pose + msg_name = "geometry_msgs.PoseWithCovariance" + + @dispatch + def __init__(self) -> None: + """Initialize with default pose and zero covariance.""" + self.pose = Pose() + self.covariance = np.zeros(36) + + @dispatch + def __init__( + self, pose: Pose | PoseConvertable, covariance: list[float] | np.ndarray | None = None + ) -> None: + """Initialize with pose and optional covariance.""" + self.pose = Pose(pose) if not isinstance(pose, Pose) else pose + if covariance is None: + self.covariance = np.zeros(36) + else: + self.covariance = np.array(covariance, dtype=float).reshape(36) + + @dispatch + def __init__(self, pose_with_cov: PoseWithCovariance) -> None: + """Initialize from another PoseWithCovariance (copy constructor).""" + self.pose = Pose(pose_with_cov.pose) + self.covariance = np.array(pose_with_cov.covariance).copy() + + @dispatch + def __init__(self, lcm_pose_with_cov: LCMPoseWithCovariance) -> None: + """Initialize from an LCM PoseWithCovariance.""" + self.pose = Pose(lcm_pose_with_cov.pose) + self.covariance = np.array(lcm_pose_with_cov.covariance) + + @dispatch + def __init__(self, pose_dict: dict[str, PoseConvertable | list[float] | np.ndarray]) -> None: + """Initialize from a dictionary with 'pose' and 'covariance' keys.""" + self.pose = Pose(pose_dict["pose"]) + covariance = pose_dict.get("covariance") + if covariance is None: + self.covariance = np.zeros(36) + else: + self.covariance = np.array(covariance, dtype=float).reshape(36) + + @dispatch + def __init__(self, pose_tuple: tuple[PoseConvertable, list[float] | np.ndarray]) -> None: + """Initialize from a tuple of (pose, covariance).""" + self.pose = Pose(pose_tuple[0]) + self.covariance = np.array(pose_tuple[1], dtype=float).reshape(36) + + def __getattribute__(self, name): + """Override to ensure covariance is always returned as numpy array.""" + if name == "covariance": + cov = object.__getattribute__(self, "covariance") + if not isinstance(cov, np.ndarray): + return np.array(cov, dtype=float) + return cov + return super().__getattribute__(name) + + def __setattr__(self, name, value): + """Override to ensure covariance is stored as numpy array.""" + if name == "covariance": + if not isinstance(value, np.ndarray): + value = np.array(value, dtype=float).reshape(36) + super().__setattr__(name, value) + + @property + def x(self) -> float: + """X coordinate of position.""" + return self.pose.x + + @property + def y(self) -> float: + """Y coordinate of position.""" + return self.pose.y + + @property + def z(self) -> float: + """Z coordinate of position.""" + return self.pose.z + + @property + def position(self) -> Vector3: + """Position vector.""" + return self.pose.position + + @property + def orientation(self) -> Quaternion: + """Orientation quaternion.""" + return self.pose.orientation + + @property + def roll(self) -> float: + """Roll angle in radians.""" + return self.pose.roll + + @property + def pitch(self) -> float: + """Pitch angle in radians.""" + return self.pose.pitch + + @property + def yaw(self) -> float: + """Yaw angle in radians.""" + return self.pose.yaw + + @property + def covariance_matrix(self) -> np.ndarray: + """Get covariance as 6x6 matrix.""" + return self.covariance.reshape(6, 6) + + @covariance_matrix.setter + def covariance_matrix(self, value: np.ndarray) -> None: + """Set covariance from 6x6 matrix.""" + self.covariance = np.array(value).reshape(36) + + def __repr__(self) -> str: + return f"PoseWithCovariance(pose={self.pose!r}, covariance=<{self.covariance.shape[0] if isinstance(self.covariance, np.ndarray) else len(self.covariance)} elements>)" + + def __str__(self) -> str: + return ( + f"PoseWithCovariance(pos=[{self.x:.3f}, {self.y:.3f}, {self.z:.3f}], " + f"euler=[{self.roll:.3f}, {self.pitch:.3f}, {self.yaw:.3f}], " + f"cov_trace={np.trace(self.covariance_matrix):.3f})" + ) + + def __eq__(self, other) -> bool: + """Check if two PoseWithCovariance are equal.""" + if not isinstance(other, PoseWithCovariance): + return False + return self.pose == other.pose and np.allclose(self.covariance, other.covariance) + + def lcm_encode(self) -> bytes: + """Encode to LCM binary format.""" + lcm_msg = LCMPoseWithCovariance() + lcm_msg.pose = self.pose + # LCM expects list, not numpy array + if isinstance(self.covariance, np.ndarray): + lcm_msg.covariance = self.covariance.tolist() + else: + lcm_msg.covariance = list(self.covariance) + return lcm_msg.lcm_encode() + + @classmethod + def lcm_decode(cls, data: bytes) -> "PoseWithCovariance": + """Decode from LCM binary format.""" + lcm_msg = LCMPoseWithCovariance.lcm_decode(data) + pose = Pose( + position=[lcm_msg.pose.position.x, lcm_msg.pose.position.y, lcm_msg.pose.position.z], + orientation=[ + lcm_msg.pose.orientation.x, + lcm_msg.pose.orientation.y, + lcm_msg.pose.orientation.z, + lcm_msg.pose.orientation.w, + ], + ) + return cls(pose, lcm_msg.covariance) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSPoseWithCovariance) -> "PoseWithCovariance": + """Create a PoseWithCovariance from a ROS geometry_msgs/PoseWithCovariance message. + + Args: + ros_msg: ROS PoseWithCovariance message + + Returns: + PoseWithCovariance instance + """ + + pose = Pose.from_ros_msg(ros_msg.pose) + return cls(pose, list(ros_msg.covariance)) + + def to_ros_msg(self) -> ROSPoseWithCovariance: + """Convert to a ROS geometry_msgs/PoseWithCovariance message. + + Returns: + ROS PoseWithCovariance message + """ + + ros_msg = ROSPoseWithCovariance() + ros_msg.pose = self.pose.to_ros_msg() + # ROS expects list, not numpy array + if isinstance(self.covariance, np.ndarray): + ros_msg.covariance = self.covariance.tolist() + else: + ros_msg.covariance = list(self.covariance) + return ros_msg diff --git a/dimos/msgs/geometry_msgs/PoseWithCovarianceStamped.py b/dimos/msgs/geometry_msgs/PoseWithCovarianceStamped.py new file mode 100644 index 0000000000..05e1847734 --- /dev/null +++ b/dimos/msgs/geometry_msgs/PoseWithCovarianceStamped.py @@ -0,0 +1,161 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import TypeAlias + +import numpy as np +from dimos_lcm.geometry_msgs import PoseWithCovarianceStamped as LCMPoseWithCovarianceStamped +from plum import dispatch + +try: + from geometry_msgs.msg import PoseWithCovarianceStamped as ROSPoseWithCovarianceStamped +except ImportError: + ROSPoseWithCovarianceStamped = None + +from dimos.msgs.geometry_msgs.Pose import Pose, PoseConvertable +from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance +from dimos.types.timestamped import Timestamped + +# Types that can be converted to/from PoseWithCovarianceStamped +PoseWithCovarianceStampedConvertable: TypeAlias = ( + tuple[PoseConvertable, list[float] | np.ndarray] + | LCMPoseWithCovarianceStamped + | dict[str, PoseConvertable | list[float] | np.ndarray | float | str] +) + + +def sec_nsec(ts): + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class PoseWithCovarianceStamped(PoseWithCovariance, Timestamped): + msg_name = "geometry_msgs.PoseWithCovarianceStamped" + ts: float + frame_id: str + + @dispatch + def __init__(self, ts: float = 0.0, frame_id: str = "", **kwargs) -> None: + """Initialize with timestamp and frame_id.""" + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + super().__init__(**kwargs) + + @dispatch + def __init__( + self, + ts: float = 0.0, + frame_id: str = "", + pose: Pose | PoseConvertable | None = None, + covariance: list[float] | np.ndarray | None = None, + ) -> None: + """Initialize with timestamp, frame_id, pose and covariance.""" + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + if pose is None: + super().__init__() + else: + super().__init__(pose, covariance) + + def lcm_encode(self) -> bytes: + lcm_msg = LCMPoseWithCovarianceStamped() + lcm_msg.pose.pose = self.pose + # LCM expects list, not numpy array + if isinstance(self.covariance, np.ndarray): + lcm_msg.pose.covariance = self.covariance.tolist() + else: + lcm_msg.pose.covariance = list(self.covariance) + [lcm_msg.header.stamp.sec, lcm_msg.header.stamp.nsec] = sec_nsec(self.ts) + lcm_msg.header.frame_id = self.frame_id + return lcm_msg.lcm_encode() + + @classmethod + def lcm_decode(cls, data: bytes) -> PoseWithCovarianceStamped: + lcm_msg = LCMPoseWithCovarianceStamped.lcm_decode(data) + return cls( + ts=lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000), + frame_id=lcm_msg.header.frame_id, + pose=Pose( + position=[ + lcm_msg.pose.pose.position.x, + lcm_msg.pose.pose.position.y, + lcm_msg.pose.pose.position.z, + ], + orientation=[ + lcm_msg.pose.pose.orientation.x, + lcm_msg.pose.pose.orientation.y, + lcm_msg.pose.pose.orientation.z, + lcm_msg.pose.pose.orientation.w, + ], + ), + covariance=lcm_msg.pose.covariance, + ) + + def __str__(self) -> str: + return ( + f"PoseWithCovarianceStamped(pos=[{self.x:.3f}, {self.y:.3f}, {self.z:.3f}], " + f"euler=[{self.roll:.3f}, {self.pitch:.3f}, {self.yaw:.3f}], " + f"cov_trace={np.trace(self.covariance_matrix):.3f})" + ) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSPoseWithCovarianceStamped) -> "PoseWithCovarianceStamped": + """Create a PoseWithCovarianceStamped from a ROS geometry_msgs/PoseWithCovarianceStamped message. + + Args: + ros_msg: ROS PoseWithCovarianceStamped message + + Returns: + PoseWithCovarianceStamped instance + """ + + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert pose with covariance + pose_with_cov = PoseWithCovariance.from_ros_msg(ros_msg.pose) + + return cls( + ts=ts, + frame_id=ros_msg.header.frame_id, + pose=pose_with_cov.pose, + covariance=pose_with_cov.covariance, + ) + + def to_ros_msg(self) -> ROSPoseWithCovarianceStamped: + """Convert to a ROS geometry_msgs/PoseWithCovarianceStamped message. + + Returns: + ROS PoseWithCovarianceStamped message + """ + + ros_msg = ROSPoseWithCovarianceStamped() + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set pose with covariance + ros_msg.pose.pose = self.pose.to_ros_msg() + # ROS expects list, not numpy array + if isinstance(self.covariance, np.ndarray): + ros_msg.pose.covariance = self.covariance.tolist() + else: + ros_msg.pose.covariance = list(self.covariance) + + return ros_msg diff --git a/dimos/msgs/geometry_msgs/Transform.py b/dimos/msgs/geometry_msgs/Transform.py index a47c58337c..61951d34b5 100644 --- a/dimos/msgs/geometry_msgs/Transform.py +++ b/dimos/msgs/geometry_msgs/Transform.py @@ -20,6 +20,17 @@ from dimos_lcm.geometry_msgs import Transform as LCMTransform from dimos_lcm.geometry_msgs import TransformStamped as LCMTransformStamped +try: + from geometry_msgs.msg import TransformStamped as ROSTransformStamped + from geometry_msgs.msg import Transform as ROSTransform + from geometry_msgs.msg import Vector3 as ROSVector3 + from geometry_msgs.msg import Quaternion as ROSQuaternion +except ImportError: + ROSTransformStamped = None + ROSTransform = None + ROSVector3 = None + ROSQuaternion = None + from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.std_msgs import Header @@ -137,6 +148,70 @@ def inverse(self) -> "Transform": ts=self.ts, ) + @classmethod + def from_ros_transform_stamped(cls, ros_msg: ROSTransformStamped) -> "Transform": + """Create a Transform from a ROS geometry_msgs/TransformStamped message. + + Args: + ros_msg: ROS TransformStamped message + + Returns: + Transform instance + """ + + # Convert timestamp + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert translation + translation = Vector3( + ros_msg.transform.translation.x, + ros_msg.transform.translation.y, + ros_msg.transform.translation.z, + ) + + # Convert rotation + rotation = Quaternion( + ros_msg.transform.rotation.x, + ros_msg.transform.rotation.y, + ros_msg.transform.rotation.z, + ros_msg.transform.rotation.w, + ) + + return cls( + translation=translation, + rotation=rotation, + frame_id=ros_msg.header.frame_id, + child_frame_id=ros_msg.child_frame_id, + ts=ts, + ) + + def to_ros_transform_stamped(self) -> ROSTransformStamped: + """Convert to a ROS geometry_msgs/TransformStamped message. + + Returns: + ROS TransformStamped message + """ + + ros_msg = ROSTransformStamped() + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set child frame + ros_msg.child_frame_id = self.child_frame_id + + # Set transform + ros_msg.transform.translation = ROSVector3( + x=self.translation.x, y=self.translation.y, z=self.translation.z + ) + ros_msg.transform.rotation = ROSQuaternion( + x=self.rotation.x, y=self.rotation.y, z=self.rotation.z, w=self.rotation.w + ) + + return ros_msg + def __neg__(self) -> "Transform": """Unary minus operator returns the inverse transform.""" return self.inverse() diff --git a/dimos/msgs/geometry_msgs/Twist.py b/dimos/msgs/geometry_msgs/Twist.py index 10e07eaeb5..2b7b4206a3 100644 --- a/dimos/msgs/geometry_msgs/Twist.py +++ b/dimos/msgs/geometry_msgs/Twist.py @@ -21,6 +21,13 @@ from dimos_lcm.geometry_msgs import Twist as LCMTwist from plum import dispatch +try: + from geometry_msgs.msg import Twist as ROSTwist + from geometry_msgs.msg import Vector3 as ROSVector3 +except ImportError: + ROSTwist = None + ROSVector3 = None + from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorLike @@ -100,3 +107,30 @@ def __bool__(self) -> bool: False if twist is zero, True otherwise """ return not self.is_zero() + + @classmethod + def from_ros_msg(cls, ros_msg: ROSTwist) -> "Twist": + """Create a Twist from a ROS geometry_msgs/Twist message. + + Args: + ros_msg: ROS Twist message + + Returns: + Twist instance + """ + + linear = Vector3(ros_msg.linear.x, ros_msg.linear.y, ros_msg.linear.z) + angular = Vector3(ros_msg.angular.x, ros_msg.angular.y, ros_msg.angular.z) + return cls(linear, angular) + + def to_ros_msg(self) -> ROSTwist: + """Convert to a ROS geometry_msgs/Twist message. + + Returns: + ROS Twist message + """ + + ros_msg = ROSTwist() + ros_msg.linear = ROSVector3(x=self.linear.x, y=self.linear.y, z=self.linear.z) + ros_msg.angular = ROSVector3(x=self.angular.x, y=self.angular.y, z=self.angular.z) + return ros_msg diff --git a/dimos/msgs/geometry_msgs/TwistStamped.py b/dimos/msgs/geometry_msgs/TwistStamped.py new file mode 100644 index 0000000000..5c464dfa17 --- /dev/null +++ b/dimos/msgs/geometry_msgs/TwistStamped.py @@ -0,0 +1,122 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import struct +import time +from io import BytesIO +from typing import BinaryIO, TypeAlias + +from dimos_lcm.geometry_msgs import TwistStamped as LCMTwistStamped +from dimos_lcm.std_msgs import Header as LCMHeader +from dimos_lcm.std_msgs import Time as LCMTime +from plum import dispatch + +try: + from geometry_msgs.msg import TwistStamped as ROSTwistStamped +except ImportError: + ROSTwistStamped = None + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorConvertable +from dimos.types.timestamped import Timestamped + +# Types that can be converted to/from TwistStamped +TwistConvertable: TypeAlias = ( + tuple[VectorConvertable, VectorConvertable] | LCMTwistStamped | dict[str, VectorConvertable] +) + + +def sec_nsec(ts): + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class TwistStamped(Twist, Timestamped): + msg_name = "geometry_msgs.TwistStamped" + ts: float + frame_id: str + + @dispatch + def __init__(self, ts: float = 0.0, frame_id: str = "", **kwargs) -> None: + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + super().__init__(**kwargs) + + def lcm_encode(self) -> bytes: + lcm_msg = LCMTwistStamped() + lcm_msg.twist = self + [lcm_msg.header.stamp.sec, lcm_msg.header.stamp.nsec] = sec_nsec(self.ts) + lcm_msg.header.frame_id = self.frame_id + return lcm_msg.lcm_encode() + + @classmethod + def lcm_decode(cls, data: bytes | BinaryIO) -> TwistStamped: + lcm_msg = LCMTwistStamped.lcm_decode(data) + return cls( + ts=lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000), + frame_id=lcm_msg.header.frame_id, + linear=[lcm_msg.twist.linear.x, lcm_msg.twist.linear.y, lcm_msg.twist.linear.z], + angular=[lcm_msg.twist.angular.x, lcm_msg.twist.angular.y, lcm_msg.twist.angular.z], + ) + + def __str__(self) -> str: + return ( + f"TwistStamped(linear=[{self.linear.x:.3f}, {self.linear.y:.3f}, {self.linear.z:.3f}], " + f"angular=[{self.angular.x:.3f}, {self.angular.y:.3f}, {self.angular.z:.3f}])" + ) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSTwistStamped) -> "TwistStamped": + """Create a TwistStamped from a ROS geometry_msgs/TwistStamped message. + + Args: + ros_msg: ROS TwistStamped message + + Returns: + TwistStamped instance + """ + + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert twist + twist = Twist.from_ros_msg(ros_msg.twist) + + return cls( + ts=ts, + frame_id=ros_msg.header.frame_id, + linear=twist.linear, + angular=twist.angular, + ) + + def to_ros_msg(self) -> ROSTwistStamped: + """Convert to a ROS geometry_msgs/TwistStamped message. + + Returns: + ROS TwistStamped message + """ + + ros_msg = ROSTwistStamped() + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set twist + ros_msg.twist = Twist.to_ros_msg(self) + + return ros_msg diff --git a/dimos/msgs/geometry_msgs/TwistWithCovariance.py b/dimos/msgs/geometry_msgs/TwistWithCovariance.py new file mode 100644 index 0000000000..18237cf7b9 --- /dev/null +++ b/dimos/msgs/geometry_msgs/TwistWithCovariance.py @@ -0,0 +1,225 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TypeAlias + +import numpy as np +from dimos_lcm.geometry_msgs import TwistWithCovariance as LCMTwistWithCovariance +from plum import dispatch + +try: + from geometry_msgs.msg import TwistWithCovariance as ROSTwistWithCovariance +except ImportError: + ROSTwistWithCovariance = None + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorConvertable + +# Types that can be converted to/from TwistWithCovariance +TwistWithCovarianceConvertable: TypeAlias = ( + tuple[Twist | tuple[VectorConvertable, VectorConvertable], list[float] | np.ndarray] + | LCMTwistWithCovariance + | dict[str, Twist | tuple[VectorConvertable, VectorConvertable] | list[float] | np.ndarray] +) + + +class TwistWithCovariance(LCMTwistWithCovariance): + twist: Twist + msg_name = "geometry_msgs.TwistWithCovariance" + + @dispatch + def __init__(self) -> None: + """Initialize with default twist and zero covariance.""" + self.twist = Twist() + self.covariance = np.zeros(36) + + @dispatch + def __init__( + self, + twist: Twist | tuple[VectorConvertable, VectorConvertable], + covariance: list[float] | np.ndarray | None = None, + ) -> None: + """Initialize with twist and optional covariance.""" + if isinstance(twist, Twist): + self.twist = twist + else: + # Assume it's a tuple of (linear, angular) + self.twist = Twist(twist[0], twist[1]) + + if covariance is None: + self.covariance = np.zeros(36) + else: + self.covariance = np.array(covariance, dtype=float).reshape(36) + + @dispatch + def __init__(self, twist_with_cov: TwistWithCovariance) -> None: + """Initialize from another TwistWithCovariance (copy constructor).""" + self.twist = Twist(twist_with_cov.twist) + self.covariance = np.array(twist_with_cov.covariance).copy() + + @dispatch + def __init__(self, lcm_twist_with_cov: LCMTwistWithCovariance) -> None: + """Initialize from an LCM TwistWithCovariance.""" + self.twist = Twist(lcm_twist_with_cov.twist) + self.covariance = np.array(lcm_twist_with_cov.covariance) + + @dispatch + def __init__( + self, + twist_dict: dict[ + str, Twist | tuple[VectorConvertable, VectorConvertable] | list[float] | np.ndarray + ], + ) -> None: + """Initialize from a dictionary with 'twist' and 'covariance' keys.""" + twist = twist_dict["twist"] + if isinstance(twist, Twist): + self.twist = twist + else: + # Assume it's a tuple of (linear, angular) + self.twist = Twist(twist[0], twist[1]) + + covariance = twist_dict.get("covariance") + if covariance is None: + self.covariance = np.zeros(36) + else: + self.covariance = np.array(covariance, dtype=float).reshape(36) + + @dispatch + def __init__( + self, + twist_tuple: tuple[ + Twist | tuple[VectorConvertable, VectorConvertable], list[float] | np.ndarray + ], + ) -> None: + """Initialize from a tuple of (twist, covariance).""" + twist = twist_tuple[0] + if isinstance(twist, Twist): + self.twist = twist + else: + # Assume it's a tuple of (linear, angular) + self.twist = Twist(twist[0], twist[1]) + self.covariance = np.array(twist_tuple[1], dtype=float).reshape(36) + + def __getattribute__(self, name): + """Override to ensure covariance is always returned as numpy array.""" + if name == "covariance": + cov = object.__getattribute__(self, "covariance") + if not isinstance(cov, np.ndarray): + return np.array(cov, dtype=float) + return cov + return super().__getattribute__(name) + + def __setattr__(self, name, value): + """Override to ensure covariance is stored as numpy array.""" + if name == "covariance": + if not isinstance(value, np.ndarray): + value = np.array(value, dtype=float).reshape(36) + super().__setattr__(name, value) + + @property + def linear(self) -> Vector3: + """Linear velocity vector.""" + return self.twist.linear + + @property + def angular(self) -> Vector3: + """Angular velocity vector.""" + return self.twist.angular + + @property + def covariance_matrix(self) -> np.ndarray: + """Get covariance as 6x6 matrix.""" + return self.covariance.reshape(6, 6) + + @covariance_matrix.setter + def covariance_matrix(self, value: np.ndarray) -> None: + """Set covariance from 6x6 matrix.""" + self.covariance = np.array(value).reshape(36) + + def __repr__(self) -> str: + return f"TwistWithCovariance(twist={self.twist!r}, covariance=<{self.covariance.shape[0] if isinstance(self.covariance, np.ndarray) else len(self.covariance)} elements>)" + + def __str__(self) -> str: + return ( + f"TwistWithCovariance(linear=[{self.linear.x:.3f}, {self.linear.y:.3f}, {self.linear.z:.3f}], " + f"angular=[{self.angular.x:.3f}, {self.angular.y:.3f}, {self.angular.z:.3f}], " + f"cov_trace={np.trace(self.covariance_matrix):.3f})" + ) + + def __eq__(self, other) -> bool: + """Check if two TwistWithCovariance are equal.""" + if not isinstance(other, TwistWithCovariance): + return False + return self.twist == other.twist and np.allclose(self.covariance, other.covariance) + + def is_zero(self) -> bool: + """Check if this is a zero twist (no linear or angular velocity).""" + return self.twist.is_zero() + + def __bool__(self) -> bool: + """Boolean conversion - False if zero twist, True otherwise.""" + return not self.is_zero() + + def lcm_encode(self) -> bytes: + """Encode to LCM binary format.""" + lcm_msg = LCMTwistWithCovariance() + lcm_msg.twist = self.twist + # LCM expects list, not numpy array + if isinstance(self.covariance, np.ndarray): + lcm_msg.covariance = self.covariance.tolist() + else: + lcm_msg.covariance = list(self.covariance) + return lcm_msg.lcm_encode() + + @classmethod + def lcm_decode(cls, data: bytes) -> "TwistWithCovariance": + """Decode from LCM binary format.""" + lcm_msg = LCMTwistWithCovariance.lcm_decode(data) + twist = Twist( + linear=[lcm_msg.twist.linear.x, lcm_msg.twist.linear.y, lcm_msg.twist.linear.z], + angular=[lcm_msg.twist.angular.x, lcm_msg.twist.angular.y, lcm_msg.twist.angular.z], + ) + return cls(twist, lcm_msg.covariance) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSTwistWithCovariance) -> "TwistWithCovariance": + """Create a TwistWithCovariance from a ROS geometry_msgs/TwistWithCovariance message. + + Args: + ros_msg: ROS TwistWithCovariance message + + Returns: + TwistWithCovariance instance + """ + + twist = Twist.from_ros_msg(ros_msg.twist) + return cls(twist, list(ros_msg.covariance)) + + def to_ros_msg(self) -> ROSTwistWithCovariance: + """Convert to a ROS geometry_msgs/TwistWithCovariance message. + + Returns: + ROS TwistWithCovariance message + """ + + ros_msg = ROSTwistWithCovariance() + ros_msg.twist = self.twist.to_ros_msg() + # ROS expects list, not numpy array + if isinstance(self.covariance, np.ndarray): + ros_msg.covariance = self.covariance.tolist() + else: + ros_msg.covariance = list(self.covariance) + return ros_msg diff --git a/dimos/msgs/geometry_msgs/TwistWithCovarianceStamped.py b/dimos/msgs/geometry_msgs/TwistWithCovarianceStamped.py new file mode 100644 index 0000000000..1cc4c010a5 --- /dev/null +++ b/dimos/msgs/geometry_msgs/TwistWithCovarianceStamped.py @@ -0,0 +1,169 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import TypeAlias + +import numpy as np +from dimos_lcm.geometry_msgs import TwistWithCovarianceStamped as LCMTwistWithCovarianceStamped +from plum import dispatch + +try: + from geometry_msgs.msg import TwistWithCovarianceStamped as ROSTwistWithCovarianceStamped +except ImportError: + ROSTwistWithCovarianceStamped = None + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance +from dimos.msgs.geometry_msgs.Vector3 import VectorConvertable +from dimos.types.timestamped import Timestamped + +# Types that can be converted to/from TwistWithCovarianceStamped +TwistWithCovarianceStampedConvertable: TypeAlias = ( + tuple[Twist | tuple[VectorConvertable, VectorConvertable], list[float] | np.ndarray] + | LCMTwistWithCovarianceStamped + | dict[ + str, + Twist + | tuple[VectorConvertable, VectorConvertable] + | list[float] + | np.ndarray + | float + | str, + ] +) + + +def sec_nsec(ts): + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class TwistWithCovarianceStamped(TwistWithCovariance, Timestamped): + msg_name = "geometry_msgs.TwistWithCovarianceStamped" + ts: float + frame_id: str + + @dispatch + def __init__(self, ts: float = 0.0, frame_id: str = "", **kwargs) -> None: + """Initialize with timestamp and frame_id.""" + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + super().__init__(**kwargs) + + @dispatch + def __init__( + self, + ts: float = 0.0, + frame_id: str = "", + twist: Twist | tuple[VectorConvertable, VectorConvertable] | None = None, + covariance: list[float] | np.ndarray | None = None, + ) -> None: + """Initialize with timestamp, frame_id, twist and covariance.""" + self.frame_id = frame_id + self.ts = ts if ts != 0 else time.time() + if twist is None: + super().__init__() + else: + super().__init__(twist, covariance) + + def lcm_encode(self) -> bytes: + lcm_msg = LCMTwistWithCovarianceStamped() + lcm_msg.twist.twist = self.twist + # LCM expects list, not numpy array + if isinstance(self.covariance, np.ndarray): + lcm_msg.twist.covariance = self.covariance.tolist() + else: + lcm_msg.twist.covariance = list(self.covariance) + [lcm_msg.header.stamp.sec, lcm_msg.header.stamp.nsec] = sec_nsec(self.ts) + lcm_msg.header.frame_id = self.frame_id + return lcm_msg.lcm_encode() + + @classmethod + def lcm_decode(cls, data: bytes) -> TwistWithCovarianceStamped: + lcm_msg = LCMTwistWithCovarianceStamped.lcm_decode(data) + return cls( + ts=lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000), + frame_id=lcm_msg.header.frame_id, + twist=Twist( + linear=[ + lcm_msg.twist.twist.linear.x, + lcm_msg.twist.twist.linear.y, + lcm_msg.twist.twist.linear.z, + ], + angular=[ + lcm_msg.twist.twist.angular.x, + lcm_msg.twist.twist.angular.y, + lcm_msg.twist.twist.angular.z, + ], + ), + covariance=lcm_msg.twist.covariance, + ) + + def __str__(self) -> str: + return ( + f"TwistWithCovarianceStamped(linear=[{self.linear.x:.3f}, {self.linear.y:.3f}, {self.linear.z:.3f}], " + f"angular=[{self.angular.x:.3f}, {self.angular.y:.3f}, {self.angular.z:.3f}], " + f"cov_trace={np.trace(self.covariance_matrix):.3f})" + ) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSTwistWithCovarianceStamped) -> "TwistWithCovarianceStamped": + """Create a TwistWithCovarianceStamped from a ROS geometry_msgs/TwistWithCovarianceStamped message. + + Args: + ros_msg: ROS TwistWithCovarianceStamped message + + Returns: + TwistWithCovarianceStamped instance + """ + + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert twist with covariance + twist_with_cov = TwistWithCovariance.from_ros_msg(ros_msg.twist) + + return cls( + ts=ts, + frame_id=ros_msg.header.frame_id, + twist=twist_with_cov.twist, + covariance=twist_with_cov.covariance, + ) + + def to_ros_msg(self) -> ROSTwistWithCovarianceStamped: + """Convert to a ROS geometry_msgs/TwistWithCovarianceStamped message. + + Returns: + ROS TwistWithCovarianceStamped message + """ + + ros_msg = ROSTwistWithCovarianceStamped() + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set twist with covariance + ros_msg.twist.twist = self.twist.to_ros_msg() + # ROS expects list, not numpy array + if isinstance(self.covariance, np.ndarray): + ros_msg.twist.covariance = self.covariance.tolist() + else: + ros_msg.twist.covariance = list(self.covariance) + + return ros_msg diff --git a/dimos/msgs/geometry_msgs/__init__.py b/dimos/msgs/geometry_msgs/__init__.py index 86d25bb843..de46a0a079 100644 --- a/dimos/msgs/geometry_msgs/__init__.py +++ b/dimos/msgs/geometry_msgs/__init__.py @@ -1,6 +1,11 @@ from dimos.msgs.geometry_msgs.Pose import Pose, PoseLike, to_pose from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped +from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance +from dimos.msgs.geometry_msgs.PoseWithCovarianceStamped import PoseWithCovarianceStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped +from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance +from dimos.msgs.geometry_msgs.TwistWithCovarianceStamped import TwistWithCovarianceStamped from dimos.msgs.geometry_msgs.Vector3 import Vector3, VectorLike diff --git a/dimos/msgs/geometry_msgs/test_Pose.py b/dimos/msgs/geometry_msgs/test_Pose.py index 2524f12faa..6d9c10b1c2 100644 --- a/dimos/msgs/geometry_msgs/test_Pose.py +++ b/dimos/msgs/geometry_msgs/test_Pose.py @@ -18,6 +18,15 @@ import pytest from dimos_lcm.geometry_msgs import Pose as LCMPose +try: + from geometry_msgs.msg import Pose as ROSPose + from geometry_msgs.msg import Point as ROSPoint + from geometry_msgs.msg import Quaternion as ROSQuaternion +except ImportError: + ROSPose = None + ROSPoint = None + ROSQuaternion = None + from dimos.msgs.geometry_msgs.Pose import Pose, to_pose from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.geometry_msgs.Vector3 import Vector3 @@ -747,3 +756,55 @@ def test_pose_addition_3d_rotation(): assert np.isclose(result.position.x, 1.0, atol=1e-10) # X unchanged assert np.isclose(result.position.y, cos45 - sin45, atol=1e-10) assert np.isclose(result.position.z, sin45 + cos45, atol=1e-10) + + +@pytest.mark.ros +def test_pose_from_ros_msg(): + """Test creating a Pose from a ROS Pose message.""" + ros_msg = ROSPose() + ros_msg.position = ROSPoint(x=1.0, y=2.0, z=3.0) + ros_msg.orientation = ROSQuaternion(x=0.1, y=0.2, z=0.3, w=0.9) + + pose = Pose.from_ros_msg(ros_msg) + + assert pose.position.x == 1.0 + assert pose.position.y == 2.0 + assert pose.position.z == 3.0 + assert pose.orientation.x == 0.1 + assert pose.orientation.y == 0.2 + assert pose.orientation.z == 0.3 + assert pose.orientation.w == 0.9 + + +@pytest.mark.ros +def test_pose_to_ros_msg(): + """Test converting a Pose to a ROS Pose message.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + + ros_msg = pose.to_ros_msg() + + assert isinstance(ros_msg, ROSPose) + assert ros_msg.position.x == 1.0 + assert ros_msg.position.y == 2.0 + assert ros_msg.position.z == 3.0 + assert ros_msg.orientation.x == 0.1 + assert ros_msg.orientation.y == 0.2 + assert ros_msg.orientation.z == 0.3 + assert ros_msg.orientation.w == 0.9 + + +@pytest.mark.ros +def test_pose_ros_roundtrip(): + """Test round-trip conversion between Pose and ROS Pose.""" + original = Pose(1.5, 2.5, 3.5, 0.15, 0.25, 0.35, 0.85) + + ros_msg = original.to_ros_msg() + restored = Pose.from_ros_msg(ros_msg) + + assert restored.position.x == original.position.x + assert restored.position.y == original.position.y + assert restored.position.z == original.position.z + assert restored.orientation.x == original.orientation.x + assert restored.orientation.y == original.orientation.y + assert restored.orientation.z == original.orientation.z + assert restored.orientation.w == original.orientation.w diff --git a/dimos/msgs/geometry_msgs/test_PoseStamped.py b/dimos/msgs/geometry_msgs/test_PoseStamped.py index 86dbf72bdc..cbc0c26876 100644 --- a/dimos/msgs/geometry_msgs/test_PoseStamped.py +++ b/dimos/msgs/geometry_msgs/test_PoseStamped.py @@ -15,6 +15,13 @@ import pickle import time +import pytest + +try: + from geometry_msgs.msg import PoseStamped as ROSPoseStamped +except ImportError: + ROSPoseStamped = None + from dimos.msgs.geometry_msgs import PoseStamped @@ -53,3 +60,80 @@ def test_pickle_encode_decode(): assert isinstance(pose_dest, PoseStamped) assert pose_dest is not pose_source assert pose_dest == pose_source + + +@pytest.mark.ros +def test_pose_stamped_from_ros_msg(): + """Test creating a PoseStamped from a ROS PoseStamped message.""" + ros_msg = ROSPoseStamped() + ros_msg.header.frame_id = "world" + ros_msg.header.stamp.sec = 123 + ros_msg.header.stamp.nanosec = 456000000 + ros_msg.pose.position.x = 1.0 + ros_msg.pose.position.y = 2.0 + ros_msg.pose.position.z = 3.0 + ros_msg.pose.orientation.x = 0.1 + ros_msg.pose.orientation.y = 0.2 + ros_msg.pose.orientation.z = 0.3 + ros_msg.pose.orientation.w = 0.9 + + pose_stamped = PoseStamped.from_ros_msg(ros_msg) + + assert pose_stamped.frame_id == "world" + assert pose_stamped.ts == 123.456 + assert pose_stamped.position.x == 1.0 + assert pose_stamped.position.y == 2.0 + assert pose_stamped.position.z == 3.0 + assert pose_stamped.orientation.x == 0.1 + assert pose_stamped.orientation.y == 0.2 + assert pose_stamped.orientation.z == 0.3 + assert pose_stamped.orientation.w == 0.9 + + +@pytest.mark.ros +def test_pose_stamped_to_ros_msg(): + """Test converting a PoseStamped to a ROS PoseStamped message.""" + pose_stamped = PoseStamped( + ts=123.456, + frame_id="base_link", + position=(1.0, 2.0, 3.0), + orientation=(0.1, 0.2, 0.3, 0.9), + ) + + ros_msg = pose_stamped.to_ros_msg() + + assert isinstance(ros_msg, ROSPoseStamped) + assert ros_msg.header.frame_id == "base_link" + assert ros_msg.header.stamp.sec == 123 + assert ros_msg.header.stamp.nanosec == 456000000 + assert ros_msg.pose.position.x == 1.0 + assert ros_msg.pose.position.y == 2.0 + assert ros_msg.pose.position.z == 3.0 + assert ros_msg.pose.orientation.x == 0.1 + assert ros_msg.pose.orientation.y == 0.2 + assert ros_msg.pose.orientation.z == 0.3 + assert ros_msg.pose.orientation.w == 0.9 + + +@pytest.mark.ros +def test_pose_stamped_ros_roundtrip(): + """Test round-trip conversion between PoseStamped and ROS PoseStamped.""" + original = PoseStamped( + ts=123.789, + frame_id="odom", + position=(1.5, 2.5, 3.5), + orientation=(0.15, 0.25, 0.35, 0.85), + ) + + ros_msg = original.to_ros_msg() + restored = PoseStamped.from_ros_msg(ros_msg) + + assert restored.frame_id == original.frame_id + assert restored.ts == original.ts + assert restored.position.x == original.position.x + assert restored.position.y == original.position.y + assert restored.position.z == original.position.z + assert restored.orientation.x == original.orientation.x + assert restored.orientation.y == original.orientation.y + assert restored.orientation.z == original.orientation.z + assert restored.orientation.w == original.orientation.w diff --git a/dimos/msgs/geometry_msgs/test_PoseWithCovariance.py b/dimos/msgs/geometry_msgs/test_PoseWithCovariance.py new file mode 100644 index 0000000000..dd254104a5 --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_PoseWithCovariance.py @@ -0,0 +1,388 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest +from dimos_lcm.geometry_msgs import PoseWithCovariance as LCMPoseWithCovariance + +try: + from geometry_msgs.msg import PoseWithCovariance as ROSPoseWithCovariance + from geometry_msgs.msg import Pose as ROSPose + from geometry_msgs.msg import Point as ROSPoint + from geometry_msgs.msg import Quaternion as ROSQuaternion +except ImportError: + ROSPoseWithCovariance = None + ROSPose = None + ROSPoint = None + ROSQuaternion = None + +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + +def test_pose_with_covariance_default_init(): + """Test that default initialization creates a pose at origin with zero covariance.""" + pose_cov = PoseWithCovariance() + + # Pose should be at origin with identity orientation + assert pose_cov.pose.position.x == 0.0 + assert pose_cov.pose.position.y == 0.0 + assert pose_cov.pose.position.z == 0.0 + assert pose_cov.pose.orientation.x == 0.0 + assert pose_cov.pose.orientation.y == 0.0 + assert pose_cov.pose.orientation.z == 0.0 + assert pose_cov.pose.orientation.w == 1.0 + + # Covariance should be all zeros + assert np.all(pose_cov.covariance == 0.0) + assert pose_cov.covariance.shape == (36,) + + +def test_pose_with_covariance_pose_init(): + """Test initialization with a Pose object.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + pose_cov = PoseWithCovariance(pose) + + # Pose should match + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + assert pose_cov.pose.orientation.x == 0.1 + assert pose_cov.pose.orientation.y == 0.2 + assert pose_cov.pose.orientation.z == 0.3 + assert pose_cov.pose.orientation.w == 0.9 + + # Covariance should be zeros by default + assert np.all(pose_cov.covariance == 0.0) + + +def test_pose_with_covariance_pose_and_covariance_init(): + """Test initialization with pose and covariance.""" + pose = Pose(1.0, 2.0, 3.0) + covariance = np.arange(36, dtype=float) + pose_cov = PoseWithCovariance(pose, covariance) + + # Pose should match + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + + # Covariance should match + assert np.array_equal(pose_cov.covariance, covariance) + + +def test_pose_with_covariance_list_covariance(): + """Test initialization with covariance as a list.""" + pose = Pose(1.0, 2.0, 3.0) + covariance_list = list(range(36)) + pose_cov = PoseWithCovariance(pose, covariance_list) + + # Covariance should be converted to numpy array + assert isinstance(pose_cov.covariance, np.ndarray) + assert np.array_equal(pose_cov.covariance, np.array(covariance_list)) + + +def test_pose_with_covariance_copy_init(): + """Test copy constructor.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + original = PoseWithCovariance(pose, covariance) + copy = PoseWithCovariance(original) + + # Should be equal but not the same object + assert copy == original + assert copy is not original + assert copy.pose is not original.pose + assert copy.covariance is not original.covariance + + # Modify original to ensure they're independent + original.covariance[0] = 999.0 + assert copy.covariance[0] != 999.0 + + +def test_pose_with_covariance_lcm_init(): + """Test initialization from LCM message.""" + lcm_msg = LCMPoseWithCovariance() + lcm_msg.pose.position.x = 1.0 + lcm_msg.pose.position.y = 2.0 + lcm_msg.pose.position.z = 3.0 + lcm_msg.pose.orientation.x = 0.1 + lcm_msg.pose.orientation.y = 0.2 + lcm_msg.pose.orientation.z = 0.3 + lcm_msg.pose.orientation.w = 0.9 + lcm_msg.covariance = list(range(36)) + + pose_cov = PoseWithCovariance(lcm_msg) + + # Pose should match + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + assert pose_cov.pose.orientation.x == 0.1 + assert pose_cov.pose.orientation.y == 0.2 + assert pose_cov.pose.orientation.z == 0.3 + assert pose_cov.pose.orientation.w == 0.9 + + # Covariance should match + assert np.array_equal(pose_cov.covariance, np.arange(36)) + + +def test_pose_with_covariance_dict_init(): + """Test initialization from dictionary.""" + pose_dict = {"pose": Pose(1.0, 2.0, 3.0), "covariance": list(range(36))} + pose_cov = PoseWithCovariance(pose_dict) + + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + assert np.array_equal(pose_cov.covariance, np.arange(36)) + + +def test_pose_with_covariance_dict_init_no_covariance(): + """Test initialization from dictionary without covariance.""" + pose_dict = {"pose": Pose(1.0, 2.0, 3.0)} + pose_cov = PoseWithCovariance(pose_dict) + + assert pose_cov.pose.position.x == 1.0 + assert np.all(pose_cov.covariance == 0.0) + + +def test_pose_with_covariance_tuple_init(): + """Test initialization from tuple.""" + pose = Pose(1.0, 2.0, 3.0) + covariance = np.arange(36, dtype=float) + pose_tuple = (pose, covariance) + pose_cov = PoseWithCovariance(pose_tuple) + + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + assert np.array_equal(pose_cov.covariance, covariance) + + +def test_pose_with_covariance_properties(): + """Test convenience properties.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + pose_cov = PoseWithCovariance(pose) + + # Position properties + assert pose_cov.x == 1.0 + assert pose_cov.y == 2.0 + assert pose_cov.z == 3.0 + assert pose_cov.position.x == 1.0 + assert pose_cov.position.y == 2.0 + assert pose_cov.position.z == 3.0 + + # Orientation properties + assert pose_cov.orientation.x == 0.1 + assert pose_cov.orientation.y == 0.2 + assert pose_cov.orientation.z == 0.3 + assert pose_cov.orientation.w == 0.9 + + # Euler angle properties + assert pose_cov.roll == pose.roll + assert pose_cov.pitch == pose.pitch + assert pose_cov.yaw == pose.yaw + + +def test_pose_with_covariance_matrix_property(): + """Test covariance matrix property.""" + pose = Pose() + covariance_array = np.arange(36, dtype=float) + pose_cov = PoseWithCovariance(pose, covariance_array) + + # Get as matrix + cov_matrix = pose_cov.covariance_matrix + assert cov_matrix.shape == (6, 6) + assert cov_matrix[0, 0] == 0.0 + assert cov_matrix[5, 5] == 35.0 + + # Set from matrix + new_matrix = np.eye(6) * 2.0 + pose_cov.covariance_matrix = new_matrix + assert np.array_equal(pose_cov.covariance[:6], [2.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + + +def test_pose_with_covariance_repr(): + """Test string representation.""" + pose = Pose(1.234, 2.567, 3.891) + pose_cov = PoseWithCovariance(pose) + + repr_str = repr(pose_cov) + assert "PoseWithCovariance" in repr_str + assert "pose=" in repr_str + assert "covariance=" in repr_str + assert "36 elements" in repr_str + + +def test_pose_with_covariance_str(): + """Test string formatting.""" + pose = Pose(1.234, 2.567, 3.891) + covariance = np.eye(6).flatten() + pose_cov = PoseWithCovariance(pose, covariance) + + str_repr = str(pose_cov) + assert "PoseWithCovariance" in str_repr + assert "1.234" in str_repr + assert "2.567" in str_repr + assert "3.891" in str_repr + assert "cov_trace" in str_repr + assert "6.000" in str_repr # Trace of identity matrix is 6 + + +def test_pose_with_covariance_equality(): + """Test equality comparison.""" + pose1 = Pose(1.0, 2.0, 3.0) + cov1 = np.arange(36, dtype=float) + pose_cov1 = PoseWithCovariance(pose1, cov1) + + pose2 = Pose(1.0, 2.0, 3.0) + cov2 = np.arange(36, dtype=float) + pose_cov2 = PoseWithCovariance(pose2, cov2) + + # Equal + assert pose_cov1 == pose_cov2 + + # Different pose + pose3 = Pose(1.1, 2.0, 3.0) + pose_cov3 = PoseWithCovariance(pose3, cov1) + assert pose_cov1 != pose_cov3 + + # Different covariance + cov3 = np.arange(36, dtype=float) + 1 + pose_cov4 = PoseWithCovariance(pose1, cov3) + assert pose_cov1 != pose_cov4 + + # Different type + assert pose_cov1 != "not a pose" + assert pose_cov1 != None + + +def test_pose_with_covariance_lcm_encode_decode(): + """Test LCM encoding and decoding.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + source = PoseWithCovariance(pose, covariance) + + # Encode and decode + binary_msg = source.lcm_encode() + decoded = PoseWithCovariance.lcm_decode(binary_msg) + + # Should be equal + assert decoded == source + assert isinstance(decoded, PoseWithCovariance) + assert isinstance(decoded.pose, Pose) + assert isinstance(decoded.covariance, np.ndarray) + + +@pytest.mark.ros +def test_pose_with_covariance_from_ros_msg(): + """Test creating from ROS message.""" + ros_msg = ROSPoseWithCovariance() + ros_msg.pose.position = ROSPoint(x=1.0, y=2.0, z=3.0) + ros_msg.pose.orientation = ROSQuaternion(x=0.1, y=0.2, z=0.3, w=0.9) + ros_msg.covariance = [float(i) for i in range(36)] + + pose_cov = PoseWithCovariance.from_ros_msg(ros_msg) + + assert pose_cov.pose.position.x == 1.0 + assert pose_cov.pose.position.y == 2.0 + assert pose_cov.pose.position.z == 3.0 + assert pose_cov.pose.orientation.x == 0.1 + assert pose_cov.pose.orientation.y == 0.2 + assert pose_cov.pose.orientation.z == 0.3 + assert pose_cov.pose.orientation.w == 0.9 + assert np.array_equal(pose_cov.covariance, np.arange(36)) + + +@pytest.mark.ros +def test_pose_with_covariance_to_ros_msg(): + """Test converting to ROS message.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + pose_cov = PoseWithCovariance(pose, covariance) + + ros_msg = pose_cov.to_ros_msg() + + assert isinstance(ros_msg, ROSPoseWithCovariance) + assert ros_msg.pose.position.x == 1.0 + assert ros_msg.pose.position.y == 2.0 + assert ros_msg.pose.position.z == 3.0 + assert ros_msg.pose.orientation.x == 0.1 + assert ros_msg.pose.orientation.y == 0.2 + assert ros_msg.pose.orientation.z == 0.3 + assert ros_msg.pose.orientation.w == 0.9 + assert list(ros_msg.covariance) == list(range(36)) + + +@pytest.mark.ros +def test_pose_with_covariance_ros_roundtrip(): + """Test round-trip conversion with ROS messages.""" + pose = Pose(1.5, 2.5, 3.5, 0.15, 0.25, 0.35, 0.85) + covariance = np.random.rand(36) + original = PoseWithCovariance(pose, covariance) + + ros_msg = original.to_ros_msg() + restored = PoseWithCovariance.from_ros_msg(ros_msg) + + assert restored == original + + +def test_pose_with_covariance_zero_covariance(): + """Test with zero covariance matrix.""" + pose = Pose(1.0, 2.0, 3.0) + pose_cov = PoseWithCovariance(pose) + + assert np.all(pose_cov.covariance == 0.0) + assert np.trace(pose_cov.covariance_matrix) == 0.0 + + +def test_pose_with_covariance_diagonal_covariance(): + """Test with diagonal covariance matrix.""" + pose = Pose() + covariance = np.zeros(36) + # Set diagonal elements + for i in range(6): + covariance[i * 6 + i] = i + 1 + + pose_cov = PoseWithCovariance(pose, covariance) + + cov_matrix = pose_cov.covariance_matrix + assert np.trace(cov_matrix) == sum(range(1, 7)) # 1+2+3+4+5+6 = 21 + + # Check diagonal elements + for i in range(6): + assert cov_matrix[i, i] == i + 1 + + # Check off-diagonal elements are zero + for i in range(6): + for j in range(6): + if i != j: + assert cov_matrix[i, j] == 0.0 + + +@pytest.mark.parametrize( + "x,y,z", + [(0.0, 0.0, 0.0), (1.0, 2.0, 3.0), (-1.0, -2.0, -3.0), (100.0, -100.0, 0.0)], +) +def test_pose_with_covariance_parametrized_positions(x, y, z): + """Parametrized test for various position values.""" + pose = Pose(x, y, z) + pose_cov = PoseWithCovariance(pose) + + assert pose_cov.x == x + assert pose_cov.y == y + assert pose_cov.z == z diff --git a/dimos/msgs/geometry_msgs/test_PoseWithCovarianceStamped.py b/dimos/msgs/geometry_msgs/test_PoseWithCovarianceStamped.py new file mode 100644 index 0000000000..139279add3 --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_PoseWithCovarianceStamped.py @@ -0,0 +1,371 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import numpy as np +import pytest + +try: + from geometry_msgs.msg import PoseWithCovarianceStamped as ROSPoseWithCovarianceStamped + from geometry_msgs.msg import PoseWithCovariance as ROSPoseWithCovariance + from geometry_msgs.msg import Pose as ROSPose + from geometry_msgs.msg import Point as ROSPoint + from geometry_msgs.msg import Quaternion as ROSQuaternion + from std_msgs.msg import Header as ROSHeader + from builtin_interfaces.msg import Time as ROSTime +except ImportError: + ROSHeader = None + ROSPoseWithCovarianceStamped = None + ROSPose = None + ROSQuaternion = None + ROSPoint = None + ROSTime = None + ROSPoseWithCovariance = None + +from dimos_lcm.geometry_msgs import PoseWithCovarianceStamped as LCMPoseWithCovarianceStamped +from dimos_lcm.std_msgs import Header as LCMHeader +from dimos_lcm.std_msgs import Time as LCMTime + +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance +from dimos.msgs.geometry_msgs.PoseWithCovarianceStamped import PoseWithCovarianceStamped +from dimos.msgs.geometry_msgs.Quaternion import Quaternion +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + +def test_pose_with_covariance_stamped_default_init(): + """Test default initialization.""" + if ROSPoseWithCovariance is None: + pytest.skip("ROS not available") + if ROSTime is None: + pytest.skip("ROS not available") + if ROSPoint is None: + pytest.skip("ROS not available") + if ROSQuaternion is None: + pytest.skip("ROS not available") + if ROSPose is None: + pytest.skip("ROS not available") + if ROSPoseWithCovarianceStamped is None: + pytest.skip("ROS not available") + if ROSHeader is None: + pytest.skip("ROS not available") + pose_cov_stamped = PoseWithCovarianceStamped() + + # Should have current timestamp + assert pose_cov_stamped.ts > 0 + assert pose_cov_stamped.frame_id == "" + + # Pose should be at origin with identity orientation + assert pose_cov_stamped.pose.position.x == 0.0 + assert pose_cov_stamped.pose.position.y == 0.0 + assert pose_cov_stamped.pose.position.z == 0.0 + assert pose_cov_stamped.pose.orientation.w == 1.0 + + # Covariance should be all zeros + assert np.all(pose_cov_stamped.covariance == 0.0) + + +def test_pose_with_covariance_stamped_with_timestamp(): + """Test initialization with specific timestamp.""" + ts = 1234567890.123456 + frame_id = "base_link" + pose_cov_stamped = PoseWithCovarianceStamped(ts=ts, frame_id=frame_id) + + assert pose_cov_stamped.ts == ts + assert pose_cov_stamped.frame_id == frame_id + + +def test_pose_with_covariance_stamped_with_pose(): + """Test initialization with pose.""" + ts = 1234567890.123456 + frame_id = "map" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + + pose_cov_stamped = PoseWithCovarianceStamped( + ts=ts, frame_id=frame_id, pose=pose, covariance=covariance + ) + + assert pose_cov_stamped.ts == ts + assert pose_cov_stamped.frame_id == frame_id + assert pose_cov_stamped.pose.position.x == 1.0 + assert pose_cov_stamped.pose.position.y == 2.0 + assert pose_cov_stamped.pose.position.z == 3.0 + assert np.array_equal(pose_cov_stamped.covariance, covariance) + + +def test_pose_with_covariance_stamped_properties(): + """Test convenience properties.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.eye(6).flatten() + pose_cov_stamped = PoseWithCovarianceStamped( + ts=1234567890.0, frame_id="odom", pose=pose, covariance=covariance + ) + + # Position properties + assert pose_cov_stamped.x == 1.0 + assert pose_cov_stamped.y == 2.0 + assert pose_cov_stamped.z == 3.0 + + # Orientation properties + assert pose_cov_stamped.orientation.x == 0.1 + assert pose_cov_stamped.orientation.y == 0.2 + assert pose_cov_stamped.orientation.z == 0.3 + assert pose_cov_stamped.orientation.w == 0.9 + + # Euler angles + assert pose_cov_stamped.roll == pose.roll + assert pose_cov_stamped.pitch == pose.pitch + assert pose_cov_stamped.yaw == pose.yaw + + # Covariance matrix + cov_matrix = pose_cov_stamped.covariance_matrix + assert cov_matrix.shape == (6, 6) + assert np.trace(cov_matrix) == 6.0 + + +def test_pose_with_covariance_stamped_str(): + """Test string representation.""" + pose = Pose(1.234, 2.567, 3.891) + covariance = np.eye(6).flatten() * 2.0 + pose_cov_stamped = PoseWithCovarianceStamped( + ts=1234567890.0, frame_id="world", pose=pose, covariance=covariance + ) + + str_repr = str(pose_cov_stamped) + assert "PoseWithCovarianceStamped" in str_repr + assert "1.234" in str_repr + assert "2.567" in str_repr + assert "3.891" in str_repr + assert "cov_trace" in str_repr + assert "12.000" in str_repr # Trace of 2*identity is 12 + + +def test_pose_with_covariance_stamped_lcm_encode_decode(): + """Test LCM encoding and decoding.""" + ts = 1234567890.123456 + frame_id = "camera_link" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + + source = PoseWithCovarianceStamped(ts=ts, frame_id=frame_id, pose=pose, covariance=covariance) + + # Encode and decode + binary_msg = source.lcm_encode() + decoded = PoseWithCovarianceStamped.lcm_decode(binary_msg) + + # Check timestamp (may lose some precision) + assert abs(decoded.ts - ts) < 1e-6 + assert decoded.frame_id == frame_id + + # Check pose + assert decoded.pose.position.x == 1.0 + assert decoded.pose.position.y == 2.0 + assert decoded.pose.position.z == 3.0 + assert decoded.pose.orientation.x == 0.1 + assert decoded.pose.orientation.y == 0.2 + assert decoded.pose.orientation.z == 0.3 + assert decoded.pose.orientation.w == 0.9 + + # Check covariance + assert np.array_equal(decoded.covariance, covariance) + + +@pytest.mark.ros +def test_pose_with_covariance_stamped_from_ros_msg(): + """Test creating from ROS message.""" + ros_msg = ROSPoseWithCovarianceStamped() + + # Set header + ros_msg.header = ROSHeader() + ros_msg.header.stamp = ROSTime() + ros_msg.header.stamp.sec = 1234567890 + ros_msg.header.stamp.nanosec = 123456000 + ros_msg.header.frame_id = "laser" + + # Set pose with covariance + ros_msg.pose = ROSPoseWithCovariance() + ros_msg.pose.pose = ROSPose() + ros_msg.pose.pose.position = ROSPoint(x=1.0, y=2.0, z=3.0) + ros_msg.pose.pose.orientation = ROSQuaternion(x=0.1, y=0.2, z=0.3, w=0.9) + ros_msg.pose.covariance = [float(i) for i in range(36)] + + pose_cov_stamped = PoseWithCovarianceStamped.from_ros_msg(ros_msg) + + assert pose_cov_stamped.ts == 1234567890.123456 + assert pose_cov_stamped.frame_id == "laser" + assert pose_cov_stamped.pose.position.x == 1.0 + assert pose_cov_stamped.pose.position.y == 2.0 + assert pose_cov_stamped.pose.position.z == 3.0 + assert pose_cov_stamped.pose.orientation.x == 0.1 + assert pose_cov_stamped.pose.orientation.y == 0.2 + assert pose_cov_stamped.pose.orientation.z == 0.3 + assert pose_cov_stamped.pose.orientation.w == 0.9 + assert np.array_equal(pose_cov_stamped.covariance, np.arange(36)) + + +@pytest.mark.ros +def test_pose_with_covariance_stamped_to_ros_msg(): + """Test converting to ROS message.""" + ts = 1234567890.567890 + frame_id = "imu" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + covariance = np.arange(36, dtype=float) + + pose_cov_stamped = PoseWithCovarianceStamped( + ts=ts, frame_id=frame_id, pose=pose, covariance=covariance + ) + + ros_msg = pose_cov_stamped.to_ros_msg() + + assert isinstance(ros_msg, ROSPoseWithCovarianceStamped) + assert ros_msg.header.frame_id == frame_id + assert ros_msg.header.stamp.sec == 1234567890 + assert abs(ros_msg.header.stamp.nanosec - 567890000) < 100 # Allow small rounding error + + assert ros_msg.pose.pose.position.x == 1.0 + assert ros_msg.pose.pose.position.y == 2.0 + assert ros_msg.pose.pose.position.z == 3.0 + assert ros_msg.pose.pose.orientation.x == 0.1 + assert ros_msg.pose.pose.orientation.y == 0.2 + assert ros_msg.pose.pose.orientation.z == 0.3 + assert ros_msg.pose.pose.orientation.w == 0.9 + assert list(ros_msg.pose.covariance) == list(range(36)) + + +@pytest.mark.ros +def test_pose_with_covariance_stamped_ros_roundtrip(): + """Test round-trip conversion with ROS messages.""" + ts = 2147483647.987654 # Max int32 value for ROS Time.sec + frame_id = "robot_base" + pose = Pose(1.5, 2.5, 3.5, 0.15, 0.25, 0.35, 0.85) + covariance = np.random.rand(36) + + original = PoseWithCovarianceStamped(ts=ts, frame_id=frame_id, pose=pose, covariance=covariance) + + ros_msg = original.to_ros_msg() + restored = PoseWithCovarianceStamped.from_ros_msg(ros_msg) + + # Check timestamp (loses some precision in conversion) + assert abs(restored.ts - ts) < 1e-6 + assert restored.frame_id == frame_id + + # Check pose + assert restored.pose.position.x == original.pose.position.x + assert restored.pose.position.y == original.pose.position.y + assert restored.pose.position.z == original.pose.position.z + assert restored.pose.orientation.x == original.pose.orientation.x + assert restored.pose.orientation.y == original.pose.orientation.y + assert restored.pose.orientation.z == original.pose.orientation.z + assert restored.pose.orientation.w == original.pose.orientation.w + + # Check covariance + assert np.allclose(restored.covariance, original.covariance) + + +def test_pose_with_covariance_stamped_zero_timestamp(): + """Test that zero timestamp gets replaced with current time.""" + pose_cov_stamped = PoseWithCovarianceStamped(ts=0.0) + + # Should have been replaced with current time + assert pose_cov_stamped.ts > 0 + assert pose_cov_stamped.ts <= time.time() + + +def test_pose_with_covariance_stamped_inheritance(): + """Test that it properly inherits from PoseWithCovariance and Timestamped.""" + pose = Pose(1.0, 2.0, 3.0) + covariance = np.eye(6).flatten() + pose_cov_stamped = PoseWithCovarianceStamped( + ts=1234567890.0, frame_id="test", pose=pose, covariance=covariance + ) + + # Should be instance of parent classes + assert isinstance(pose_cov_stamped, PoseWithCovariance) + + # Should have Timestamped attributes + assert hasattr(pose_cov_stamped, "ts") + assert hasattr(pose_cov_stamped, "frame_id") + + # Should have PoseWithCovariance attributes + assert hasattr(pose_cov_stamped, "pose") + assert hasattr(pose_cov_stamped, "covariance") + + +def test_pose_with_covariance_stamped_sec_nsec(): + """Test the sec_nsec helper function.""" + from dimos.msgs.geometry_msgs.PoseWithCovarianceStamped import sec_nsec + + # Test integer seconds + s, ns = sec_nsec(1234567890.0) + assert s == 1234567890 + assert ns == 0 + + # Test fractional seconds + s, ns = sec_nsec(1234567890.123456789) + assert s == 1234567890 + assert abs(ns - 123456789) < 100 # Allow small rounding error + + # Test small fractional seconds + s, ns = sec_nsec(0.000000001) + assert s == 0 + assert ns == 1 + + # Test large timestamp + s, ns = sec_nsec(9999999999.999999999) + # Due to floating point precision, this might round to 10000000000 + assert s in [9999999999, 10000000000] + if s == 9999999999: + assert abs(ns - 999999999) < 10 + else: + assert ns == 0 + + +@pytest.mark.ros +@pytest.mark.parametrize( + "frame_id", + ["", "map", "odom", "base_link", "camera_optical_frame", "sensor/lidar/front"], +) +def test_pose_with_covariance_stamped_frame_ids(frame_id): + """Test various frame ID values.""" + pose_cov_stamped = PoseWithCovarianceStamped(frame_id=frame_id) + assert pose_cov_stamped.frame_id == frame_id + + # Test roundtrip through ROS + ros_msg = pose_cov_stamped.to_ros_msg() + assert ros_msg.header.frame_id == frame_id + + restored = PoseWithCovarianceStamped.from_ros_msg(ros_msg) + assert restored.frame_id == frame_id + + +def test_pose_with_covariance_stamped_different_covariances(): + """Test with different covariance patterns.""" + pose = Pose(1.0, 2.0, 3.0) + + # Zero covariance + zero_cov = np.zeros(36) + pose_cov1 = PoseWithCovarianceStamped(pose=pose, covariance=zero_cov) + assert np.all(pose_cov1.covariance == 0.0) + + # Identity covariance + identity_cov = np.eye(6).flatten() + pose_cov2 = PoseWithCovarianceStamped(pose=pose, covariance=identity_cov) + assert np.trace(pose_cov2.covariance_matrix) == 6.0 + + # Full covariance + full_cov = np.random.rand(36) + pose_cov3 = PoseWithCovarianceStamped(pose=pose, covariance=full_cov) + assert np.array_equal(pose_cov3.covariance, full_cov) diff --git a/dimos/msgs/geometry_msgs/test_Transform.py b/dimos/msgs/geometry_msgs/test_Transform.py index e069305bca..f09f0c2966 100644 --- a/dimos/msgs/geometry_msgs/test_Transform.py +++ b/dimos/msgs/geometry_msgs/test_Transform.py @@ -17,6 +17,12 @@ import numpy as np import pytest + +try: + from geometry_msgs.msg import TransformStamped as ROSTransformStamped +except ImportError: + ROSTransformStamped = None + from dimos_lcm.geometry_msgs import Transform as LCMTransform from dimos_lcm.geometry_msgs import TransformStamped as LCMTransformStamped @@ -421,3 +427,86 @@ def test_transform_from_pose_invalid_type(): with pytest.raises(TypeError): Transform.from_pose(None) + + +@pytest.mark.ros +def test_transform_from_ros_transform_stamped(): + """Test creating a Transform from a ROS TransformStamped message.""" + ros_msg = ROSTransformStamped() + ros_msg.header.frame_id = "world" + ros_msg.header.stamp.sec = 123 + ros_msg.header.stamp.nanosec = 456000000 + ros_msg.child_frame_id = "robot" + ros_msg.transform.translation.x = 1.0 + ros_msg.transform.translation.y = 2.0 + ros_msg.transform.translation.z = 3.0 + ros_msg.transform.rotation.x = 0.1 + ros_msg.transform.rotation.y = 0.2 + ros_msg.transform.rotation.z = 0.3 + ros_msg.transform.rotation.w = 0.9 + + transform = Transform.from_ros_transform_stamped(ros_msg) + + assert transform.frame_id == "world" + assert transform.child_frame_id == "robot" + assert transform.ts == 123.456 + assert transform.translation.x == 1.0 + assert transform.translation.y == 2.0 + assert transform.translation.z == 3.0 + assert transform.rotation.x == 0.1 + assert transform.rotation.y == 0.2 + assert transform.rotation.z == 0.3 + assert transform.rotation.w == 0.9 + + +@pytest.mark.ros +def test_transform_to_ros_transform_stamped(): + """Test converting a Transform to a ROS TransformStamped message.""" + transform = Transform( + translation=Vector3(4.0, 5.0, 6.0), + rotation=Quaternion(0.15, 0.25, 0.35, 0.85), + frame_id="base_link", + child_frame_id="sensor", + ts=124.789, + ) + + ros_msg = transform.to_ros_transform_stamped() + + assert isinstance(ros_msg, ROSTransformStamped) + assert ros_msg.header.frame_id == "base_link" + assert ros_msg.child_frame_id == "sensor" + assert ros_msg.header.stamp.sec == 124 + assert ros_msg.header.stamp.nanosec == 789000000 + assert ros_msg.transform.translation.x == 4.0 + assert ros_msg.transform.translation.y == 5.0 + assert ros_msg.transform.translation.z == 6.0 + assert ros_msg.transform.rotation.x == 0.15 + assert ros_msg.transform.rotation.y == 0.25 + assert ros_msg.transform.rotation.z == 0.35 + assert ros_msg.transform.rotation.w == 0.85 + + +@pytest.mark.ros +def test_transform_ros_roundtrip(): + """Test round-trip conversion between Transform and ROS TransformStamped.""" + original = Transform( + translation=Vector3(7.5, 8.5, 9.5), + rotation=Quaternion(0.0, 0.0, 0.383, 0.924), # ~45 degrees around Z + frame_id="odom", + child_frame_id="base_footprint", + ts=99.123, + ) + + ros_msg = original.to_ros_transform_stamped() + restored = Transform.from_ros_transform_stamped(ros_msg) + + assert restored.frame_id == original.frame_id + assert restored.child_frame_id == original.child_frame_id + assert restored.ts == original.ts + assert restored.translation.x == original.translation.x + assert restored.translation.y == original.translation.y + assert restored.translation.z == original.translation.z + assert restored.rotation.x == original.rotation.x + assert restored.rotation.y == original.rotation.y + assert restored.rotation.z == original.rotation.z + assert restored.rotation.w == original.rotation.w diff --git a/dimos/msgs/geometry_msgs/test_Twist.py b/dimos/msgs/geometry_msgs/test_Twist.py index 2e57523826..5f463d0bac 100644 --- a/dimos/msgs/geometry_msgs/test_Twist.py +++ b/dimos/msgs/geometry_msgs/test_Twist.py @@ -14,6 +14,14 @@ import numpy as np import pytest + +try: + from geometry_msgs.msg import Twist as ROSTwist + from geometry_msgs.msg import Vector3 as ROSVector3 +except ImportError: + ROSTwist = None + ROSVector3 = None + from dimos_lcm.geometry_msgs import Twist as LCMTwist from dimos.msgs.geometry_msgs import Quaternion, Twist, Vector3 @@ -198,3 +206,97 @@ def test_twist_with_lists(): tw2 = Twist(linear=np.array([4, 5, 6]), angular=np.array([0.4, 0.5, 0.6])) assert tw2.linear == Vector3(4, 5, 6) assert tw2.angular == Vector3(0.4, 0.5, 0.6) + + +@pytest.mark.ros +def test_twist_from_ros_msg(): + """Test Twist.from_ros_msg conversion.""" + # Create ROS message + ros_msg = ROSTwist() + ros_msg.linear = ROSVector3(x=10.0, y=20.0, z=30.0) + ros_msg.angular = ROSVector3(x=1.0, y=2.0, z=3.0) + + # Convert to LCM + lcm_msg = Twist.from_ros_msg(ros_msg) + + assert isinstance(lcm_msg, Twist) + assert lcm_msg.linear.x == 10.0 + assert lcm_msg.linear.y == 20.0 + assert lcm_msg.linear.z == 30.0 + assert lcm_msg.angular.x == 1.0 + assert lcm_msg.angular.y == 2.0 + assert lcm_msg.angular.z == 3.0 + + +@pytest.mark.ros +def test_twist_to_ros_msg(): + """Test Twist.to_ros_msg conversion.""" + # Create LCM message + lcm_msg = Twist(linear=Vector3(40.0, 50.0, 60.0), angular=Vector3(4.0, 5.0, 6.0)) + + # Convert to ROS + ros_msg = lcm_msg.to_ros_msg() + + assert isinstance(ros_msg, ROSTwist) + assert ros_msg.linear.x == 40.0 + assert ros_msg.linear.y == 50.0 + assert ros_msg.linear.z == 60.0 + assert ros_msg.angular.x == 4.0 + assert ros_msg.angular.y == 5.0 + assert ros_msg.angular.z == 6.0 + + +@pytest.mark.ros +def test_ros_zero_twist_conversion(): + """Test conversion of zero twist messages between ROS and LCM.""" + # Test ROS to LCM with zero twist + ros_zero = ROSTwist() + lcm_zero = Twist.from_ros_msg(ros_zero) + assert lcm_zero.is_zero() + + # Test LCM to ROS with zero twist + lcm_zero2 = Twist.zero() + ros_zero2 = lcm_zero2.to_ros_msg() + assert ros_zero2.linear.x == 0.0 + assert ros_zero2.linear.y == 0.0 + assert ros_zero2.linear.z == 0.0 + assert ros_zero2.angular.x == 0.0 + assert ros_zero2.angular.y == 0.0 + assert ros_zero2.angular.z == 0.0 + + +@pytest.mark.ros +def test_ros_negative_values_conversion(): + """Test ROS conversion with negative values.""" + # Create ROS message with negative values + ros_msg = ROSTwist() + ros_msg.linear = ROSVector3(x=-1.5, y=-2.5, z=-3.5) + ros_msg.angular = ROSVector3(x=-0.1, y=-0.2, z=-0.3) + + # Convert to LCM and back + lcm_msg = Twist.from_ros_msg(ros_msg) + ros_msg2 = lcm_msg.to_ros_msg() + + assert ros_msg2.linear.x == -1.5 + assert ros_msg2.linear.y == -2.5 + assert ros_msg2.linear.z == -3.5 + assert ros_msg2.angular.x == -0.1 + assert ros_msg2.angular.y == -0.2 + assert ros_msg2.angular.z == -0.3 + + +@pytest.mark.ros +def test_ros_roundtrip_conversion(): + """Test round-trip conversion maintains data integrity.""" + # LCM -> ROS -> LCM + original_lcm = Twist(linear=Vector3(1.234, 5.678, 9.012), angular=Vector3(0.111, 0.222, 0.333)) + ros_intermediate = original_lcm.to_ros_msg() + final_lcm = Twist.from_ros_msg(ros_intermediate) + + assert final_lcm == original_lcm + assert final_lcm.linear.x == 1.234 + assert final_lcm.linear.y == 5.678 + assert final_lcm.linear.z == 9.012 + assert final_lcm.angular.x == 0.111 + assert final_lcm.angular.y == 0.222 + assert final_lcm.angular.z == 0.333 diff --git a/dimos/msgs/geometry_msgs/test_TwistStamped.py b/dimos/msgs/geometry_msgs/test_TwistStamped.py new file mode 100644 index 0000000000..8414d4480a --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_TwistStamped.py @@ -0,0 +1,158 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import pickle +import time + + +try: + from geometry_msgs.msg import TwistStamped as ROSTwistStamped +except ImportError: + ROSTwistStamped = None + +from dimos.msgs.geometry_msgs.TwistStamped import TwistStamped + + +def test_lcm_encode_decode(): + """Test encoding and decoding of TwistStamped to/from binary LCM format.""" + twist_source = TwistStamped( + ts=time.time(), + linear=(1.0, 2.0, 3.0), + angular=(0.1, 0.2, 0.3), + ) + binary_msg = twist_source.lcm_encode() + twist_dest = TwistStamped.lcm_decode(binary_msg) + + assert isinstance(twist_dest, TwistStamped) + assert twist_dest is not twist_source + + print(twist_source.linear) + print(twist_source.angular) + + print(twist_dest.linear) + print(twist_dest.angular) + assert twist_dest == twist_source + + +def test_pickle_encode_decode(): + """Test encoding and decoding of TwistStamped to/from binary pickle format.""" + + twist_source = TwistStamped( + ts=time.time(), + linear=(1.0, 2.0, 3.0), + angular=(0.1, 0.2, 0.3), + ) + binary_msg = pickle.dumps(twist_source) + twist_dest = pickle.loads(binary_msg) + assert isinstance(twist_dest, TwistStamped) + assert twist_dest is not twist_source + assert twist_dest == twist_source + + +@pytest.mark.ros +def test_twist_stamped_from_ros_msg(): + """Test creating a TwistStamped from a ROS TwistStamped message.""" + ros_msg = ROSTwistStamped() + ros_msg.header.frame_id = "world" + ros_msg.header.stamp.sec = 123 + ros_msg.header.stamp.nanosec = 456000000 + ros_msg.twist.linear.x = 1.0 + ros_msg.twist.linear.y = 2.0 + ros_msg.twist.linear.z = 3.0 + ros_msg.twist.angular.x = 0.1 + ros_msg.twist.angular.y = 0.2 + ros_msg.twist.angular.z = 0.3 + + twist_stamped = TwistStamped.from_ros_msg(ros_msg) + + assert twist_stamped.frame_id == "world" + assert twist_stamped.ts == 123.456 + assert twist_stamped.linear.x == 1.0 + assert twist_stamped.linear.y == 2.0 + assert twist_stamped.linear.z == 3.0 + assert twist_stamped.angular.x == 0.1 + assert twist_stamped.angular.y == 0.2 + assert twist_stamped.angular.z == 0.3 + + +@pytest.mark.ros +def test_twist_stamped_to_ros_msg(): + """Test converting a TwistStamped to a ROS TwistStamped message.""" + twist_stamped = TwistStamped( + ts=123.456, + frame_id="base_link", + linear=(1.0, 2.0, 3.0), + angular=(0.1, 0.2, 0.3), + ) + + ros_msg = twist_stamped.to_ros_msg() + + assert isinstance(ros_msg, ROSTwistStamped) + assert ros_msg.header.frame_id == "base_link" + assert ros_msg.header.stamp.sec == 123 + assert ros_msg.header.stamp.nanosec == 456000000 + assert ros_msg.twist.linear.x == 1.0 + assert ros_msg.twist.linear.y == 2.0 + assert ros_msg.twist.linear.z == 3.0 + assert ros_msg.twist.angular.x == 0.1 + assert ros_msg.twist.angular.y == 0.2 + assert ros_msg.twist.angular.z == 0.3 + + +@pytest.mark.ros +def test_twist_stamped_ros_roundtrip(): + """Test round-trip conversion between TwistStamped and ROS TwistStamped.""" + original = TwistStamped( + ts=123.789, + frame_id="odom", + linear=(1.5, 2.5, 3.5), + angular=(0.15, 0.25, 0.35), + ) + + ros_msg = original.to_ros_msg() + restored = TwistStamped.from_ros_msg(ros_msg) + + assert restored.frame_id == original.frame_id + assert restored.ts == original.ts + assert restored.linear.x == original.linear.x + assert restored.linear.y == original.linear.y + assert restored.linear.z == original.linear.z + assert restored.angular.x == original.angular.x + assert restored.angular.y == original.angular.y + assert restored.angular.z == original.angular.z + + +if __name__ == "__main__": + print("Running test_lcm_encode_decode...") + test_lcm_encode_decode() + print("✓ test_lcm_encode_decode passed") + + print("Running test_pickle_encode_decode...") + test_pickle_encode_decode() + print("✓ test_pickle_encode_decode passed") + + print("Running test_twist_stamped_from_ros_msg...") + test_twist_stamped_from_ros_msg() + print("✓ test_twist_stamped_from_ros_msg passed") + + print("Running test_twist_stamped_to_ros_msg...") + test_twist_stamped_to_ros_msg() + print("✓ test_twist_stamped_to_ros_msg passed") + + print("Running test_twist_stamped_ros_roundtrip...") + test_twist_stamped_ros_roundtrip() + print("✓ test_twist_stamped_ros_roundtrip passed") + + print("\nAll tests passed!") diff --git a/dimos/msgs/geometry_msgs/test_TwistWithCovariance.py b/dimos/msgs/geometry_msgs/test_TwistWithCovariance.py new file mode 100644 index 0000000000..d001482062 --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_TwistWithCovariance.py @@ -0,0 +1,421 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytest + +try: + from geometry_msgs.msg import TwistWithCovariance as ROSTwistWithCovariance + from geometry_msgs.msg import Twist as ROSTwist + from geometry_msgs.msg import Vector3 as ROSVector3 +except ImportError: + ROSTwist = None + ROSTwistWithCovariance = None + ROSVector3 = None + +from dimos_lcm.geometry_msgs import TwistWithCovariance as LCMTwistWithCovariance + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + +def test_twist_with_covariance_default_init(): + """Test that default initialization creates a zero twist with zero covariance.""" + if ROSVector3 is None: + pytest.skip("ROS not available") + if ROSTwistWithCovariance is None: + pytest.skip("ROS not available") + twist_cov = TwistWithCovariance() + + # Twist should be zero + assert twist_cov.twist.linear.x == 0.0 + assert twist_cov.twist.linear.y == 0.0 + assert twist_cov.twist.linear.z == 0.0 + assert twist_cov.twist.angular.x == 0.0 + assert twist_cov.twist.angular.y == 0.0 + assert twist_cov.twist.angular.z == 0.0 + + # Covariance should be all zeros + assert np.all(twist_cov.covariance == 0.0) + assert twist_cov.covariance.shape == (36,) + + +def test_twist_with_covariance_twist_init(): + """Test initialization with a Twist object.""" + linear = Vector3(1.0, 2.0, 3.0) + angular = Vector3(0.1, 0.2, 0.3) + twist = Twist(linear, angular) + twist_cov = TwistWithCovariance(twist) + + # Twist should match + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert twist_cov.twist.angular.x == 0.1 + assert twist_cov.twist.angular.y == 0.2 + assert twist_cov.twist.angular.z == 0.3 + + # Covariance should be zeros by default + assert np.all(twist_cov.covariance == 0.0) + + +def test_twist_with_covariance_twist_and_covariance_init(): + """Test initialization with twist and covariance.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + twist_cov = TwistWithCovariance(twist, covariance) + + # Twist should match + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + + # Covariance should match + assert np.array_equal(twist_cov.covariance, covariance) + + +def test_twist_with_covariance_tuple_init(): + """Test initialization with tuple of (linear, angular) velocities.""" + linear = [1.0, 2.0, 3.0] + angular = [0.1, 0.2, 0.3] + covariance = np.arange(36, dtype=float) + twist_cov = TwistWithCovariance((linear, angular), covariance) + + # Twist should match + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert twist_cov.twist.angular.x == 0.1 + assert twist_cov.twist.angular.y == 0.2 + assert twist_cov.twist.angular.z == 0.3 + + # Covariance should match + assert np.array_equal(twist_cov.covariance, covariance) + + +def test_twist_with_covariance_list_covariance(): + """Test initialization with covariance as a list.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance_list = list(range(36)) + twist_cov = TwistWithCovariance(twist, covariance_list) + + # Covariance should be converted to numpy array + assert isinstance(twist_cov.covariance, np.ndarray) + assert np.array_equal(twist_cov.covariance, np.array(covariance_list)) + + +def test_twist_with_covariance_copy_init(): + """Test copy constructor.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + original = TwistWithCovariance(twist, covariance) + copy = TwistWithCovariance(original) + + # Should be equal but not the same object + assert copy == original + assert copy is not original + assert copy.twist is not original.twist + assert copy.covariance is not original.covariance + + # Modify original to ensure they're independent + original.covariance[0] = 999.0 + assert copy.covariance[0] != 999.0 + + +def test_twist_with_covariance_lcm_init(): + """Test initialization from LCM message.""" + lcm_msg = LCMTwistWithCovariance() + lcm_msg.twist.linear.x = 1.0 + lcm_msg.twist.linear.y = 2.0 + lcm_msg.twist.linear.z = 3.0 + lcm_msg.twist.angular.x = 0.1 + lcm_msg.twist.angular.y = 0.2 + lcm_msg.twist.angular.z = 0.3 + lcm_msg.covariance = list(range(36)) + + twist_cov = TwistWithCovariance(lcm_msg) + + # Twist should match + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert twist_cov.twist.angular.x == 0.1 + assert twist_cov.twist.angular.y == 0.2 + assert twist_cov.twist.angular.z == 0.3 + + # Covariance should match + assert np.array_equal(twist_cov.covariance, np.arange(36)) + + +def test_twist_with_covariance_dict_init(): + """Test initialization from dictionary.""" + twist_dict = { + "twist": Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)), + "covariance": list(range(36)), + } + twist_cov = TwistWithCovariance(twist_dict) + + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert np.array_equal(twist_cov.covariance, np.arange(36)) + + +def test_twist_with_covariance_dict_init_no_covariance(): + """Test initialization from dictionary without covariance.""" + twist_dict = {"twist": Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3))} + twist_cov = TwistWithCovariance(twist_dict) + + assert twist_cov.twist.linear.x == 1.0 + assert np.all(twist_cov.covariance == 0.0) + + +def test_twist_with_covariance_tuple_of_tuple_init(): + """Test initialization from tuple of (twist_tuple, covariance).""" + twist_tuple = ([1.0, 2.0, 3.0], [0.1, 0.2, 0.3]) + covariance = np.arange(36, dtype=float) + twist_cov = TwistWithCovariance((twist_tuple, covariance)) + + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert twist_cov.twist.angular.x == 0.1 + assert twist_cov.twist.angular.y == 0.2 + assert twist_cov.twist.angular.z == 0.3 + assert np.array_equal(twist_cov.covariance, covariance) + + +def test_twist_with_covariance_properties(): + """Test convenience properties.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + twist_cov = TwistWithCovariance(twist) + + # Linear and angular properties + assert twist_cov.linear.x == 1.0 + assert twist_cov.linear.y == 2.0 + assert twist_cov.linear.z == 3.0 + assert twist_cov.angular.x == 0.1 + assert twist_cov.angular.y == 0.2 + assert twist_cov.angular.z == 0.3 + + +def test_twist_with_covariance_matrix_property(): + """Test covariance matrix property.""" + twist = Twist() + covariance_array = np.arange(36, dtype=float) + twist_cov = TwistWithCovariance(twist, covariance_array) + + # Get as matrix + cov_matrix = twist_cov.covariance_matrix + assert cov_matrix.shape == (6, 6) + assert cov_matrix[0, 0] == 0.0 + assert cov_matrix[5, 5] == 35.0 + + # Set from matrix + new_matrix = np.eye(6) * 2.0 + twist_cov.covariance_matrix = new_matrix + assert np.array_equal(twist_cov.covariance[:6], [2.0, 0.0, 0.0, 0.0, 0.0, 0.0]) + + +def test_twist_with_covariance_repr(): + """Test string representation.""" + twist = Twist(Vector3(1.234, 2.567, 3.891), Vector3(0.1, 0.2, 0.3)) + twist_cov = TwistWithCovariance(twist) + + repr_str = repr(twist_cov) + assert "TwistWithCovariance" in repr_str + assert "twist=" in repr_str + assert "covariance=" in repr_str + assert "36 elements" in repr_str + + +def test_twist_with_covariance_str(): + """Test string formatting.""" + twist = Twist(Vector3(1.234, 2.567, 3.891), Vector3(0.1, 0.2, 0.3)) + covariance = np.eye(6).flatten() + twist_cov = TwistWithCovariance(twist, covariance) + + str_repr = str(twist_cov) + assert "TwistWithCovariance" in str_repr + assert "1.234" in str_repr + assert "2.567" in str_repr + assert "3.891" in str_repr + assert "cov_trace" in str_repr + assert "6.000" in str_repr # Trace of identity matrix is 6 + + +def test_twist_with_covariance_equality(): + """Test equality comparison.""" + twist1 = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + cov1 = np.arange(36, dtype=float) + twist_cov1 = TwistWithCovariance(twist1, cov1) + + twist2 = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + cov2 = np.arange(36, dtype=float) + twist_cov2 = TwistWithCovariance(twist2, cov2) + + # Equal + assert twist_cov1 == twist_cov2 + + # Different twist + twist3 = Twist(Vector3(1.1, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + twist_cov3 = TwistWithCovariance(twist3, cov1) + assert twist_cov1 != twist_cov3 + + # Different covariance + cov3 = np.arange(36, dtype=float) + 1 + twist_cov4 = TwistWithCovariance(twist1, cov3) + assert twist_cov1 != twist_cov4 + + # Different type + assert twist_cov1 != "not a twist" + assert twist_cov1 != None + + +def test_twist_with_covariance_is_zero(): + """Test is_zero method.""" + # Zero twist + twist_cov1 = TwistWithCovariance() + assert twist_cov1.is_zero() + assert not twist_cov1 # Boolean conversion + + # Non-zero twist + twist = Twist(Vector3(1.0, 0.0, 0.0), Vector3(0.0, 0.0, 0.0)) + twist_cov2 = TwistWithCovariance(twist) + assert not twist_cov2.is_zero() + assert twist_cov2 # Boolean conversion + + +def test_twist_with_covariance_lcm_encode_decode(): + """Test LCM encoding and decoding.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + source = TwistWithCovariance(twist, covariance) + + # Encode and decode + binary_msg = source.lcm_encode() + decoded = TwistWithCovariance.lcm_decode(binary_msg) + + # Should be equal + assert decoded == source + assert isinstance(decoded, TwistWithCovariance) + assert isinstance(decoded.twist, Twist) + assert isinstance(decoded.covariance, np.ndarray) + + +@pytest.mark.ros +def test_twist_with_covariance_from_ros_msg(): + """Test creating from ROS message.""" + ros_msg = ROSTwistWithCovariance() + ros_msg.twist.linear = ROSVector3(x=1.0, y=2.0, z=3.0) + ros_msg.twist.angular = ROSVector3(x=0.1, y=0.2, z=0.3) + ros_msg.covariance = [float(i) for i in range(36)] + + twist_cov = TwistWithCovariance.from_ros_msg(ros_msg) + + assert twist_cov.twist.linear.x == 1.0 + assert twist_cov.twist.linear.y == 2.0 + assert twist_cov.twist.linear.z == 3.0 + assert twist_cov.twist.angular.x == 0.1 + assert twist_cov.twist.angular.y == 0.2 + assert twist_cov.twist.angular.z == 0.3 + assert np.array_equal(twist_cov.covariance, np.arange(36)) + + +@pytest.mark.ros +def test_twist_with_covariance_to_ros_msg(): + """Test converting to ROS message.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + twist_cov = TwistWithCovariance(twist, covariance) + + ros_msg = twist_cov.to_ros_msg() + + assert isinstance(ros_msg, ROSTwistWithCovariance) + assert ros_msg.twist.linear.x == 1.0 + assert ros_msg.twist.linear.y == 2.0 + assert ros_msg.twist.linear.z == 3.0 + assert ros_msg.twist.angular.x == 0.1 + assert ros_msg.twist.angular.y == 0.2 + assert ros_msg.twist.angular.z == 0.3 + assert list(ros_msg.covariance) == list(range(36)) + + +@pytest.mark.ros +def test_twist_with_covariance_ros_roundtrip(): + """Test round-trip conversion with ROS messages.""" + twist = Twist(Vector3(1.5, 2.5, 3.5), Vector3(0.15, 0.25, 0.35)) + covariance = np.random.rand(36) + original = TwistWithCovariance(twist, covariance) + + ros_msg = original.to_ros_msg() + restored = TwistWithCovariance.from_ros_msg(ros_msg) + + assert restored == original + + +def test_twist_with_covariance_zero_covariance(): + """Test with zero covariance matrix.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + twist_cov = TwistWithCovariance(twist) + + assert np.all(twist_cov.covariance == 0.0) + assert np.trace(twist_cov.covariance_matrix) == 0.0 + + +def test_twist_with_covariance_diagonal_covariance(): + """Test with diagonal covariance matrix.""" + twist = Twist() + covariance = np.zeros(36) + # Set diagonal elements + for i in range(6): + covariance[i * 6 + i] = i + 1 + + twist_cov = TwistWithCovariance(twist, covariance) + + cov_matrix = twist_cov.covariance_matrix + assert np.trace(cov_matrix) == sum(range(1, 7)) # 1+2+3+4+5+6 = 21 + + # Check diagonal elements + for i in range(6): + assert cov_matrix[i, i] == i + 1 + + # Check off-diagonal elements are zero + for i in range(6): + for j in range(6): + if i != j: + assert cov_matrix[i, j] == 0.0 + + +@pytest.mark.parametrize( + "linear,angular", + [ + ([0.0, 0.0, 0.0], [0.0, 0.0, 0.0]), + ([1.0, 2.0, 3.0], [0.1, 0.2, 0.3]), + ([-1.0, -2.0, -3.0], [-0.1, -0.2, -0.3]), + ([100.0, -100.0, 0.0], [3.14, -3.14, 0.0]), + ], +) +def test_twist_with_covariance_parametrized_velocities(linear, angular): + """Parametrized test for various velocity values.""" + twist = Twist(linear, angular) + twist_cov = TwistWithCovariance(twist) + + assert twist_cov.linear.x == linear[0] + assert twist_cov.linear.y == linear[1] + assert twist_cov.linear.z == linear[2] + assert twist_cov.angular.x == angular[0] + assert twist_cov.angular.y == angular[1] + assert twist_cov.angular.z == angular[2] diff --git a/dimos/msgs/geometry_msgs/test_TwistWithCovarianceStamped.py b/dimos/msgs/geometry_msgs/test_TwistWithCovarianceStamped.py new file mode 100644 index 0000000000..4174814c78 --- /dev/null +++ b/dimos/msgs/geometry_msgs/test_TwistWithCovarianceStamped.py @@ -0,0 +1,393 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import numpy as np +import pytest + +try: + from geometry_msgs.msg import TwistWithCovarianceStamped as ROSTwistWithCovarianceStamped + from geometry_msgs.msg import TwistWithCovariance as ROSTwistWithCovariance + from geometry_msgs.msg import Twist as ROSTwist + from geometry_msgs.msg import Vector3 as ROSVector3 + from std_msgs.msg import Header as ROSHeader + from builtin_interfaces.msg import Time as ROSTime +except ImportError: + ROSTwistWithCovarianceStamped = None + ROSTwist = None + ROSHeader = None + ROSTime = None + ROSTwistWithCovariance = None + ROSVector3 = None + +from dimos_lcm.geometry_msgs import TwistWithCovarianceStamped as LCMTwistWithCovarianceStamped +from dimos_lcm.std_msgs import Header as LCMHeader +from dimos_lcm.std_msgs import Time as LCMTime + +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance +from dimos.msgs.geometry_msgs.TwistWithCovarianceStamped import TwistWithCovarianceStamped +from dimos.msgs.geometry_msgs.Vector3 import Vector3 + + +def test_twist_with_covariance_stamped_default_init(): + """Test default initialization.""" + if ROSVector3 is None: + pytest.skip("ROS not available") + if ROSTwistWithCovariance is None: + pytest.skip("ROS not available") + if ROSTime is None: + pytest.skip("ROS not available") + if ROSHeader is None: + pytest.skip("ROS not available") + if ROSTwist is None: + pytest.skip("ROS not available") + if ROSTwistWithCovarianceStamped is None: + pytest.skip("ROS not available") + twist_cov_stamped = TwistWithCovarianceStamped() + + # Should have current timestamp + assert twist_cov_stamped.ts > 0 + assert twist_cov_stamped.frame_id == "" + + # Twist should be zero + assert twist_cov_stamped.twist.linear.x == 0.0 + assert twist_cov_stamped.twist.linear.y == 0.0 + assert twist_cov_stamped.twist.linear.z == 0.0 + assert twist_cov_stamped.twist.angular.x == 0.0 + assert twist_cov_stamped.twist.angular.y == 0.0 + assert twist_cov_stamped.twist.angular.z == 0.0 + + # Covariance should be all zeros + assert np.all(twist_cov_stamped.covariance == 0.0) + + +def test_twist_with_covariance_stamped_with_timestamp(): + """Test initialization with specific timestamp.""" + ts = 1234567890.123456 + frame_id = "base_link" + twist_cov_stamped = TwistWithCovarianceStamped(ts=ts, frame_id=frame_id) + + assert twist_cov_stamped.ts == ts + assert twist_cov_stamped.frame_id == frame_id + + +def test_twist_with_covariance_stamped_with_twist(): + """Test initialization with twist.""" + ts = 1234567890.123456 + frame_id = "odom" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + + twist_cov_stamped = TwistWithCovarianceStamped( + ts=ts, frame_id=frame_id, twist=twist, covariance=covariance + ) + + assert twist_cov_stamped.ts == ts + assert twist_cov_stamped.frame_id == frame_id + assert twist_cov_stamped.twist.linear.x == 1.0 + assert twist_cov_stamped.twist.linear.y == 2.0 + assert twist_cov_stamped.twist.linear.z == 3.0 + assert np.array_equal(twist_cov_stamped.covariance, covariance) + + +def test_twist_with_covariance_stamped_with_tuple(): + """Test initialization with tuple of velocities.""" + ts = 1234567890.123456 + frame_id = "robot_base" + linear = [1.0, 2.0, 3.0] + angular = [0.1, 0.2, 0.3] + covariance = np.arange(36, dtype=float) + + twist_cov_stamped = TwistWithCovarianceStamped( + ts=ts, frame_id=frame_id, twist=(linear, angular), covariance=covariance + ) + + assert twist_cov_stamped.ts == ts + assert twist_cov_stamped.frame_id == frame_id + assert twist_cov_stamped.twist.linear.x == 1.0 + assert twist_cov_stamped.twist.angular.x == 0.1 + assert np.array_equal(twist_cov_stamped.covariance, covariance) + + +def test_twist_with_covariance_stamped_properties(): + """Test convenience properties.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.eye(6).flatten() + twist_cov_stamped = TwistWithCovarianceStamped( + ts=1234567890.0, frame_id="cmd_vel", twist=twist, covariance=covariance + ) + + # Linear and angular properties + assert twist_cov_stamped.linear.x == 1.0 + assert twist_cov_stamped.linear.y == 2.0 + assert twist_cov_stamped.linear.z == 3.0 + assert twist_cov_stamped.angular.x == 0.1 + assert twist_cov_stamped.angular.y == 0.2 + assert twist_cov_stamped.angular.z == 0.3 + + # Covariance matrix + cov_matrix = twist_cov_stamped.covariance_matrix + assert cov_matrix.shape == (6, 6) + assert np.trace(cov_matrix) == 6.0 + + +def test_twist_with_covariance_stamped_str(): + """Test string representation.""" + twist = Twist(Vector3(1.234, 2.567, 3.891), Vector3(0.111, 0.222, 0.333)) + covariance = np.eye(6).flatten() * 2.0 + twist_cov_stamped = TwistWithCovarianceStamped( + ts=1234567890.0, frame_id="world", twist=twist, covariance=covariance + ) + + str_repr = str(twist_cov_stamped) + assert "TwistWithCovarianceStamped" in str_repr + assert "1.234" in str_repr + assert "2.567" in str_repr + assert "3.891" in str_repr + assert "cov_trace" in str_repr + assert "12.000" in str_repr # Trace of 2*identity is 12 + + +def test_twist_with_covariance_stamped_lcm_encode_decode(): + """Test LCM encoding and decoding.""" + ts = 1234567890.123456 + frame_id = "camera_link" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + + source = TwistWithCovarianceStamped( + ts=ts, frame_id=frame_id, twist=twist, covariance=covariance + ) + + # Encode and decode + binary_msg = source.lcm_encode() + decoded = TwistWithCovarianceStamped.lcm_decode(binary_msg) + + # Check timestamp (may lose some precision) + assert abs(decoded.ts - ts) < 1e-6 + assert decoded.frame_id == frame_id + + # Check twist + assert decoded.twist.linear.x == 1.0 + assert decoded.twist.linear.y == 2.0 + assert decoded.twist.linear.z == 3.0 + assert decoded.twist.angular.x == 0.1 + assert decoded.twist.angular.y == 0.2 + assert decoded.twist.angular.z == 0.3 + + # Check covariance + assert np.array_equal(decoded.covariance, covariance) + + +@pytest.mark.ros +def test_twist_with_covariance_stamped_from_ros_msg(): + """Test creating from ROS message.""" + ros_msg = ROSTwistWithCovarianceStamped() + + # Set header + ros_msg.header = ROSHeader() + ros_msg.header.stamp = ROSTime() + ros_msg.header.stamp.sec = 1234567890 + ros_msg.header.stamp.nanosec = 123456000 + ros_msg.header.frame_id = "laser" + + # Set twist with covariance + ros_msg.twist = ROSTwistWithCovariance() + ros_msg.twist.twist = ROSTwist() + ros_msg.twist.twist.linear = ROSVector3(x=1.0, y=2.0, z=3.0) + ros_msg.twist.twist.angular = ROSVector3(x=0.1, y=0.2, z=0.3) + ros_msg.twist.covariance = [float(i) for i in range(36)] + + twist_cov_stamped = TwistWithCovarianceStamped.from_ros_msg(ros_msg) + + assert twist_cov_stamped.ts == 1234567890.123456 + assert twist_cov_stamped.frame_id == "laser" + assert twist_cov_stamped.twist.linear.x == 1.0 + assert twist_cov_stamped.twist.linear.y == 2.0 + assert twist_cov_stamped.twist.linear.z == 3.0 + assert twist_cov_stamped.twist.angular.x == 0.1 + assert twist_cov_stamped.twist.angular.y == 0.2 + assert twist_cov_stamped.twist.angular.z == 0.3 + assert np.array_equal(twist_cov_stamped.covariance, np.arange(36)) + + +@pytest.mark.ros +def test_twist_with_covariance_stamped_to_ros_msg(): + """Test converting to ROS message.""" + ts = 1234567890.567890 + frame_id = "imu" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.arange(36, dtype=float) + + twist_cov_stamped = TwistWithCovarianceStamped( + ts=ts, frame_id=frame_id, twist=twist, covariance=covariance + ) + + ros_msg = twist_cov_stamped.to_ros_msg() + + assert isinstance(ros_msg, ROSTwistWithCovarianceStamped) + assert ros_msg.header.frame_id == frame_id + assert ros_msg.header.stamp.sec == 1234567890 + assert abs(ros_msg.header.stamp.nanosec - 567890000) < 100 # Allow small rounding error + + assert ros_msg.twist.twist.linear.x == 1.0 + assert ros_msg.twist.twist.linear.y == 2.0 + assert ros_msg.twist.twist.linear.z == 3.0 + assert ros_msg.twist.twist.angular.x == 0.1 + assert ros_msg.twist.twist.angular.y == 0.2 + assert ros_msg.twist.twist.angular.z == 0.3 + assert list(ros_msg.twist.covariance) == list(range(36)) + + +@pytest.mark.ros +def test_twist_with_covariance_stamped_ros_roundtrip(): + """Test round-trip conversion with ROS messages.""" + ts = 2147483647.987654 # Max int32 value for ROS Time.sec + frame_id = "robot_base" + twist = Twist(Vector3(1.5, 2.5, 3.5), Vector3(0.15, 0.25, 0.35)) + covariance = np.random.rand(36) + + original = TwistWithCovarianceStamped( + ts=ts, frame_id=frame_id, twist=twist, covariance=covariance + ) + + ros_msg = original.to_ros_msg() + restored = TwistWithCovarianceStamped.from_ros_msg(ros_msg) + + # Check timestamp (loses some precision in conversion) + assert abs(restored.ts - ts) < 1e-6 + assert restored.frame_id == frame_id + + # Check twist + assert restored.twist.linear.x == original.twist.linear.x + assert restored.twist.linear.y == original.twist.linear.y + assert restored.twist.linear.z == original.twist.linear.z + assert restored.twist.angular.x == original.twist.angular.x + assert restored.twist.angular.y == original.twist.angular.y + assert restored.twist.angular.z == original.twist.angular.z + + # Check covariance + assert np.allclose(restored.covariance, original.covariance) + + +def test_twist_with_covariance_stamped_zero_timestamp(): + """Test that zero timestamp gets replaced with current time.""" + twist_cov_stamped = TwistWithCovarianceStamped(ts=0.0) + + # Should have been replaced with current time + assert twist_cov_stamped.ts > 0 + assert twist_cov_stamped.ts <= time.time() + + +def test_twist_with_covariance_stamped_inheritance(): + """Test that it properly inherits from TwistWithCovariance and Timestamped.""" + twist = Twist(Vector3(1.0, 2.0, 3.0), Vector3(0.1, 0.2, 0.3)) + covariance = np.eye(6).flatten() + twist_cov_stamped = TwistWithCovarianceStamped( + ts=1234567890.0, frame_id="test", twist=twist, covariance=covariance + ) + + # Should be instance of parent classes + assert isinstance(twist_cov_stamped, TwistWithCovariance) + + # Should have Timestamped attributes + assert hasattr(twist_cov_stamped, "ts") + assert hasattr(twist_cov_stamped, "frame_id") + + # Should have TwistWithCovariance attributes + assert hasattr(twist_cov_stamped, "twist") + assert hasattr(twist_cov_stamped, "covariance") + + +def test_twist_with_covariance_stamped_is_zero(): + """Test is_zero method inheritance.""" + # Zero twist + twist_cov_stamped1 = TwistWithCovarianceStamped() + assert twist_cov_stamped1.is_zero() + assert not twist_cov_stamped1 # Boolean conversion + + # Non-zero twist + twist = Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.0)) + twist_cov_stamped2 = TwistWithCovarianceStamped(twist=twist) + assert not twist_cov_stamped2.is_zero() + assert twist_cov_stamped2 # Boolean conversion + + +def test_twist_with_covariance_stamped_sec_nsec(): + """Test the sec_nsec helper function.""" + from dimos.msgs.geometry_msgs.TwistWithCovarianceStamped import sec_nsec + + # Test integer seconds + s, ns = sec_nsec(1234567890.0) + assert s == 1234567890 + assert ns == 0 + + # Test fractional seconds + s, ns = sec_nsec(1234567890.123456789) + assert s == 1234567890 + assert abs(ns - 123456789) < 100 # Allow small rounding error + + # Test small fractional seconds + s, ns = sec_nsec(0.000000001) + assert s == 0 + assert ns == 1 + + # Test large timestamp + s, ns = sec_nsec(9999999999.999999999) + # Due to floating point precision, this might round to 10000000000 + assert s in [9999999999, 10000000000] + if s == 9999999999: + assert abs(ns - 999999999) < 10 + else: + assert ns == 0 + + +@pytest.mark.ros +@pytest.mark.parametrize( + "frame_id", + ["", "map", "odom", "base_link", "cmd_vel", "sensor/velocity/front"], +) +def test_twist_with_covariance_stamped_frame_ids(frame_id): + """Test various frame ID values.""" + twist_cov_stamped = TwistWithCovarianceStamped(frame_id=frame_id) + assert twist_cov_stamped.frame_id == frame_id + + # Test roundtrip through ROS + ros_msg = twist_cov_stamped.to_ros_msg() + assert ros_msg.header.frame_id == frame_id + + restored = TwistWithCovarianceStamped.from_ros_msg(ros_msg) + assert restored.frame_id == frame_id + + +def test_twist_with_covariance_stamped_different_covariances(): + """Test with different covariance patterns.""" + twist = Twist(Vector3(1.0, 0.0, 0.0), Vector3(0.0, 0.0, 0.5)) + + # Zero covariance + zero_cov = np.zeros(36) + twist_cov1 = TwistWithCovarianceStamped(twist=twist, covariance=zero_cov) + assert np.all(twist_cov1.covariance == 0.0) + + # Identity covariance + identity_cov = np.eye(6).flatten() + twist_cov2 = TwistWithCovarianceStamped(twist=twist, covariance=identity_cov) + assert np.trace(twist_cov2.covariance_matrix) == 6.0 + + # Full covariance + full_cov = np.random.rand(36) + twist_cov3 = TwistWithCovarianceStamped(twist=twist, covariance=full_cov) + assert np.array_equal(twist_cov3.covariance, full_cov) diff --git a/dimos/msgs/nav_msgs/Odometry.py b/dimos/msgs/nav_msgs/Odometry.py new file mode 100644 index 0000000000..6e8b6c27fc --- /dev/null +++ b/dimos/msgs/nav_msgs/Odometry.py @@ -0,0 +1,379 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import TypeAlias + +import numpy as np +from dimos_lcm.nav_msgs import Odometry as LCMOdometry +from plum import dispatch + +try: + from nav_msgs.msg import Odometry as ROSOdometry +except ImportError: + ROSOdometry = None + +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.types.timestamped import Timestamped + +# Types that can be converted to/from Odometry +OdometryConvertable: TypeAlias = ( + LCMOdometry | dict[str, float | str | PoseWithCovariance | TwistWithCovariance | Pose | Twist] +) + + +def sec_nsec(ts): + s = int(ts) + return [s, int((ts - s) * 1_000_000_000)] + + +class Odometry(LCMOdometry, Timestamped): + pose: PoseWithCovariance + twist: TwistWithCovariance + msg_name = "nav_msgs.Odometry" + ts: float + frame_id: str + child_frame_id: str + + @dispatch + def __init__( + self, + ts: float = 0.0, + frame_id: str = "", + child_frame_id: str = "", + pose: PoseWithCovariance | Pose | None = None, + twist: TwistWithCovariance | Twist | None = None, + ) -> None: + """Initialize with timestamp, frame IDs, pose and twist. + + Args: + ts: Timestamp in seconds (defaults to current time if 0) + frame_id: Reference frame ID (e.g., "odom", "map") + child_frame_id: Child frame ID (e.g., "base_link", "base_footprint") + pose: Pose with covariance (or just Pose, covariance will be zero) + twist: Twist with covariance (or just Twist, covariance will be zero) + """ + self.ts = ts if ts != 0 else time.time() + self.frame_id = frame_id + self.child_frame_id = child_frame_id + + # Handle pose + if pose is None: + self.pose = PoseWithCovariance() + elif isinstance(pose, PoseWithCovariance): + self.pose = pose + elif isinstance(pose, Pose): + self.pose = PoseWithCovariance(pose) + else: + self.pose = PoseWithCovariance(Pose(pose)) + + # Handle twist + if twist is None: + self.twist = TwistWithCovariance() + elif isinstance(twist, TwistWithCovariance): + self.twist = twist + elif isinstance(twist, Twist): + self.twist = TwistWithCovariance(twist) + else: + self.twist = TwistWithCovariance(Twist(twist)) + + @dispatch + def __init__(self, odometry: Odometry) -> None: + """Initialize from another Odometry (copy constructor).""" + self.ts = odometry.ts + self.frame_id = odometry.frame_id + self.child_frame_id = odometry.child_frame_id + self.pose = PoseWithCovariance(odometry.pose) + self.twist = TwistWithCovariance(odometry.twist) + + @dispatch + def __init__(self, lcm_odometry: LCMOdometry) -> None: + """Initialize from an LCM Odometry.""" + self.ts = lcm_odometry.header.stamp.sec + (lcm_odometry.header.stamp.nsec / 1_000_000_000) + self.frame_id = lcm_odometry.header.frame_id + self.child_frame_id = lcm_odometry.child_frame_id + self.pose = PoseWithCovariance(lcm_odometry.pose) + self.twist = TwistWithCovariance(lcm_odometry.twist) + + @dispatch + def __init__( + self, + odometry_dict: dict[ + str, float | str | PoseWithCovariance | TwistWithCovariance | Pose | Twist + ], + ) -> None: + """Initialize from a dictionary.""" + self.ts = odometry_dict.get("ts", odometry_dict.get("timestamp", time.time())) + self.frame_id = odometry_dict.get("frame_id", "") + self.child_frame_id = odometry_dict.get("child_frame_id", "") + + # Handle pose + pose = odometry_dict.get("pose") + if pose is None: + self.pose = PoseWithCovariance() + elif isinstance(pose, PoseWithCovariance): + self.pose = pose + elif isinstance(pose, Pose): + self.pose = PoseWithCovariance(pose) + else: + self.pose = PoseWithCovariance(Pose(pose)) + + # Handle twist + twist = odometry_dict.get("twist") + if twist is None: + self.twist = TwistWithCovariance() + elif isinstance(twist, TwistWithCovariance): + self.twist = twist + elif isinstance(twist, Twist): + self.twist = TwistWithCovariance(twist) + else: + self.twist = TwistWithCovariance(Twist(twist)) + + @property + def position(self) -> Vector3: + """Get position from pose.""" + return self.pose.position + + @property + def orientation(self): + """Get orientation from pose.""" + return self.pose.orientation + + @property + def linear_velocity(self) -> Vector3: + """Get linear velocity from twist.""" + return self.twist.linear + + @property + def angular_velocity(self) -> Vector3: + """Get angular velocity from twist.""" + return self.twist.angular + + @property + def x(self) -> float: + """X position.""" + return self.pose.x + + @property + def y(self) -> float: + """Y position.""" + return self.pose.y + + @property + def z(self) -> float: + """Z position.""" + return self.pose.z + + @property + def vx(self) -> float: + """Linear velocity in X.""" + return self.twist.linear.x + + @property + def vy(self) -> float: + """Linear velocity in Y.""" + return self.twist.linear.y + + @property + def vz(self) -> float: + """Linear velocity in Z.""" + return self.twist.linear.z + + @property + def wx(self) -> float: + """Angular velocity around X (roll rate).""" + return self.twist.angular.x + + @property + def wy(self) -> float: + """Angular velocity around Y (pitch rate).""" + return self.twist.angular.y + + @property + def wz(self) -> float: + """Angular velocity around Z (yaw rate).""" + return self.twist.angular.z + + @property + def roll(self) -> float: + """Roll angle in radians.""" + return self.pose.roll + + @property + def pitch(self) -> float: + """Pitch angle in radians.""" + return self.pose.pitch + + @property + def yaw(self) -> float: + """Yaw angle in radians.""" + return self.pose.yaw + + def __repr__(self) -> str: + return ( + f"Odometry(ts={self.ts:.6f}, frame_id='{self.frame_id}', " + f"child_frame_id='{self.child_frame_id}', pose={self.pose!r}, twist={self.twist!r})" + ) + + def __str__(self) -> str: + return ( + f"Odometry:\n" + f" Timestamp: {self.ts:.6f}\n" + f" Frame: {self.frame_id} -> {self.child_frame_id}\n" + f" Position: [{self.x:.3f}, {self.y:.3f}, {self.z:.3f}]\n" + f" Orientation: [roll={self.roll:.3f}, pitch={self.pitch:.3f}, yaw={self.yaw:.3f}]\n" + f" Linear Velocity: [{self.vx:.3f}, {self.vy:.3f}, {self.vz:.3f}]\n" + f" Angular Velocity: [{self.wx:.3f}, {self.wy:.3f}, {self.wz:.3f}]" + ) + + def __eq__(self, other) -> bool: + """Check if two Odometry messages are equal.""" + if not isinstance(other, Odometry): + return False + return ( + abs(self.ts - other.ts) < 1e-6 + and self.frame_id == other.frame_id + and self.child_frame_id == other.child_frame_id + and self.pose == other.pose + and self.twist == other.twist + ) + + def lcm_encode(self) -> bytes: + """Encode to LCM binary format.""" + lcm_msg = LCMOdometry() + + # Set header + [lcm_msg.header.stamp.sec, lcm_msg.header.stamp.nsec] = sec_nsec(self.ts) + lcm_msg.header.frame_id = self.frame_id + lcm_msg.child_frame_id = self.child_frame_id + + # Set pose with covariance + lcm_msg.pose.pose = self.pose.pose + if isinstance(self.pose.covariance, np.ndarray): + lcm_msg.pose.covariance = self.pose.covariance.tolist() + else: + lcm_msg.pose.covariance = list(self.pose.covariance) + + # Set twist with covariance + lcm_msg.twist.twist = self.twist.twist + if isinstance(self.twist.covariance, np.ndarray): + lcm_msg.twist.covariance = self.twist.covariance.tolist() + else: + lcm_msg.twist.covariance = list(self.twist.covariance) + + return lcm_msg.lcm_encode() + + @classmethod + def lcm_decode(cls, data: bytes) -> "Odometry": + """Decode from LCM binary format.""" + lcm_msg = LCMOdometry.lcm_decode(data) + + # Extract timestamp + ts = lcm_msg.header.stamp.sec + (lcm_msg.header.stamp.nsec / 1_000_000_000) + + # Create pose with covariance + pose = Pose( + position=[ + lcm_msg.pose.pose.position.x, + lcm_msg.pose.pose.position.y, + lcm_msg.pose.pose.position.z, + ], + orientation=[ + lcm_msg.pose.pose.orientation.x, + lcm_msg.pose.pose.orientation.y, + lcm_msg.pose.pose.orientation.z, + lcm_msg.pose.pose.orientation.w, + ], + ) + pose_with_cov = PoseWithCovariance(pose, lcm_msg.pose.covariance) + + # Create twist with covariance + twist = Twist( + linear=[ + lcm_msg.twist.twist.linear.x, + lcm_msg.twist.twist.linear.y, + lcm_msg.twist.twist.linear.z, + ], + angular=[ + lcm_msg.twist.twist.angular.x, + lcm_msg.twist.twist.angular.y, + lcm_msg.twist.twist.angular.z, + ], + ) + twist_with_cov = TwistWithCovariance(twist, lcm_msg.twist.covariance) + + return cls( + ts=ts, + frame_id=lcm_msg.header.frame_id, + child_frame_id=lcm_msg.child_frame_id, + pose=pose_with_cov, + twist=twist_with_cov, + ) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSOdometry) -> "Odometry": + """Create an Odometry from a ROS nav_msgs/Odometry message. + + Args: + ros_msg: ROS Odometry message + + Returns: + Odometry instance + """ + + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert pose and twist with covariance + pose_with_cov = PoseWithCovariance.from_ros_msg(ros_msg.pose) + twist_with_cov = TwistWithCovariance.from_ros_msg(ros_msg.twist) + + return cls( + ts=ts, + frame_id=ros_msg.header.frame_id, + child_frame_id=ros_msg.child_frame_id, + pose=pose_with_cov, + twist=twist_with_cov, + ) + + def to_ros_msg(self) -> ROSOdometry: + """Convert to a ROS nav_msgs/Odometry message. + + Returns: + ROS Odometry message + """ + + ros_msg = ROSOdometry() + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Set child frame ID + ros_msg.child_frame_id = self.child_frame_id + + # Set pose with covariance + ros_msg.pose = self.pose.to_ros_msg() + + # Set twist with covariance + ros_msg.twist = self.twist.to_ros_msg() + + return ros_msg diff --git a/dimos/msgs/nav_msgs/Path.py b/dimos/msgs/nav_msgs/Path.py index 8fca1cf25f..18a2fb07ee 100644 --- a/dimos/msgs/nav_msgs/Path.py +++ b/dimos/msgs/nav_msgs/Path.py @@ -27,6 +27,11 @@ from dimos_lcm.std_msgs import Header as LCMHeader from dimos_lcm.std_msgs import Time as LCMTime +try: + from nav_msgs.msg import Path as ROSPath +except ImportError: + ROSPath = None + from dimos.msgs.geometry_msgs.Pose import Pose from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion, QuaternionConvertable @@ -187,3 +192,44 @@ def reverse(self) -> "Path": def clear(self) -> None: """Clear all poses from this path (mutable).""" self.poses.clear() + + @classmethod + def from_ros_msg(cls, ros_msg: ROSPath) -> "Path": + """Create a Path from a ROS nav_msgs/Path message. + + Args: + ros_msg: ROS Path message + + Returns: + Path instance + """ + + # Convert timestamp from ROS header + ts = ros_msg.header.stamp.sec + (ros_msg.header.stamp.nanosec / 1_000_000_000) + + # Convert poses + poses = [] + for ros_pose_stamped in ros_msg.poses: + poses.append(PoseStamped.from_ros_msg(ros_pose_stamped)) + + return cls(ts=ts, frame_id=ros_msg.header.frame_id, poses=poses) + + def to_ros_msg(self) -> ROSPath: + """Convert to a ROS nav_msgs/Path message. + + Returns: + ROS Path message + """ + + ros_msg = ROSPath() + + # Set header + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1_000_000_000) + + # Convert poses + for pose in self.poses: + ros_msg.poses.append(pose.to_ros_msg()) + + return ros_msg diff --git a/dimos/msgs/nav_msgs/__init__.py b/dimos/msgs/nav_msgs/__init__.py index 3e4241daa0..9ea87f3f78 100644 --- a/dimos/msgs/nav_msgs/__init__.py +++ b/dimos/msgs/nav_msgs/__init__.py @@ -1,4 +1,5 @@ from dimos.msgs.nav_msgs.OccupancyGrid import CostValues, MapMetaData, OccupancyGrid from dimos.msgs.nav_msgs.Path import Path +from dimos.msgs.nav_msgs.Odometry import Odometry -__all__ = ["Path", "OccupancyGrid", "MapMetaData", "CostValues"] +__all__ = ["Path", "OccupancyGrid", "MapMetaData", "CostValues", "Odometry"] diff --git a/dimos/msgs/nav_msgs/test_Odometry.py b/dimos/msgs/nav_msgs/test_Odometry.py new file mode 100644 index 0000000000..2fee199b1b --- /dev/null +++ b/dimos/msgs/nav_msgs/test_Odometry.py @@ -0,0 +1,504 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time + +import numpy as np +import pytest + +try: + from nav_msgs.msg import Odometry as ROSOdometry + from geometry_msgs.msg import PoseWithCovariance as ROSPoseWithCovariance + from geometry_msgs.msg import TwistWithCovariance as ROSTwistWithCovariance + from geometry_msgs.msg import Pose as ROSPose + from geometry_msgs.msg import Twist as ROSTwist + from geometry_msgs.msg import Point as ROSPoint + from geometry_msgs.msg import Quaternion as ROSQuaternion + from geometry_msgs.msg import Vector3 as ROSVector3 + from std_msgs.msg import Header as ROSHeader + from builtin_interfaces.msg import Time as ROSTime +except ImportError: + ROSTwist = None + ROSHeader = None + ROSPose = None + ROSPoseWithCovariance = None + ROSQuaternion = None + ROSOdometry = None + ROSPoint = None + ROSTime = None + ROSTwistWithCovariance = None + ROSVector3 = None + +from dimos_lcm.nav_msgs import Odometry as LCMOdometry + +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.geometry_msgs.Pose import Pose +from dimos.msgs.geometry_msgs.PoseWithCovariance import PoseWithCovariance +from dimos.msgs.geometry_msgs.Twist import Twist +from dimos.msgs.geometry_msgs.TwistWithCovariance import TwistWithCovariance +from dimos.msgs.geometry_msgs.Vector3 import Vector3 +from dimos.msgs.geometry_msgs.Quaternion import Quaternion + + +def test_odometry_default_init(): + """Test default initialization.""" + if ROSVector3 is None: + pytest.skip("ROS not available") + if ROSTwistWithCovariance is None: + pytest.skip("ROS not available") + if ROSTime is None: + pytest.skip("ROS not available") + if ROSPoint is None: + pytest.skip("ROS not available") + if ROSOdometry is None: + pytest.skip("ROS not available") + if ROSQuaternion is None: + pytest.skip("ROS not available") + if ROSPoseWithCovariance is None: + pytest.skip("ROS not available") + if ROSPose is None: + pytest.skip("ROS not available") + if ROSHeader is None: + pytest.skip("ROS not available") + if ROSTwist is None: + pytest.skip("ROS not available") + odom = Odometry() + + # Should have current timestamp + assert odom.ts > 0 + assert odom.frame_id == "" + assert odom.child_frame_id == "" + + # Pose should be at origin with identity orientation + assert odom.pose.position.x == 0.0 + assert odom.pose.position.y == 0.0 + assert odom.pose.position.z == 0.0 + assert odom.pose.orientation.w == 1.0 + + # Twist should be zero + assert odom.twist.linear.x == 0.0 + assert odom.twist.linear.y == 0.0 + assert odom.twist.linear.z == 0.0 + assert odom.twist.angular.x == 0.0 + assert odom.twist.angular.y == 0.0 + assert odom.twist.angular.z == 0.0 + + # Covariances should be zero + assert np.all(odom.pose.covariance == 0.0) + assert np.all(odom.twist.covariance == 0.0) + + +def test_odometry_with_frames(): + """Test initialization with frame IDs.""" + ts = 1234567890.123456 + frame_id = "odom" + child_frame_id = "base_link" + + odom = Odometry(ts=ts, frame_id=frame_id, child_frame_id=child_frame_id) + + assert odom.ts == ts + assert odom.frame_id == frame_id + assert odom.child_frame_id == child_frame_id + + +def test_odometry_with_pose_and_twist(): + """Test initialization with pose and twist.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + twist = Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)) + + odom = Odometry(ts=1000.0, frame_id="odom", child_frame_id="base_link", pose=pose, twist=twist) + + assert odom.pose.pose.position.x == 1.0 + assert odom.pose.pose.position.y == 2.0 + assert odom.pose.pose.position.z == 3.0 + assert odom.twist.twist.linear.x == 0.5 + assert odom.twist.twist.angular.z == 0.1 + + +def test_odometry_with_covariances(): + """Test initialization with pose and twist with covariances.""" + pose = Pose(1.0, 2.0, 3.0) + pose_cov = np.arange(36, dtype=float) + pose_with_cov = PoseWithCovariance(pose, pose_cov) + + twist = Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)) + twist_cov = np.arange(36, 72, dtype=float) + twist_with_cov = TwistWithCovariance(twist, twist_cov) + + odom = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_link", + pose=pose_with_cov, + twist=twist_with_cov, + ) + + assert odom.pose.position.x == 1.0 + assert np.array_equal(odom.pose.covariance, pose_cov) + assert odom.twist.linear.x == 0.5 + assert np.array_equal(odom.twist.covariance, twist_cov) + + +def test_odometry_copy_constructor(): + """Test copy constructor.""" + original = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_link", + pose=Pose(1.0, 2.0, 3.0), + twist=Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + ) + + copy = Odometry(original) + + assert copy == original + assert copy is not original + assert copy.pose is not original.pose + assert copy.twist is not original.twist + + +def test_odometry_dict_init(): + """Test initialization from dictionary.""" + odom_dict = { + "ts": 1000.0, + "frame_id": "odom", + "child_frame_id": "base_link", + "pose": Pose(1.0, 2.0, 3.0), + "twist": Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + } + + odom = Odometry(odom_dict) + + assert odom.ts == 1000.0 + assert odom.frame_id == "odom" + assert odom.child_frame_id == "base_link" + assert odom.pose.position.x == 1.0 + assert odom.twist.linear.x == 0.5 + + +def test_odometry_properties(): + """Test convenience properties.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + twist = Twist(Vector3(0.5, 0.6, 0.7), Vector3(0.1, 0.2, 0.3)) + + odom = Odometry(ts=1000.0, frame_id="odom", child_frame_id="base_link", pose=pose, twist=twist) + + # Position properties + assert odom.x == 1.0 + assert odom.y == 2.0 + assert odom.z == 3.0 + assert odom.position.x == 1.0 + assert odom.position.y == 2.0 + assert odom.position.z == 3.0 + + # Orientation properties + assert odom.orientation.x == 0.1 + assert odom.orientation.y == 0.2 + assert odom.orientation.z == 0.3 + assert odom.orientation.w == 0.9 + + # Velocity properties + assert odom.vx == 0.5 + assert odom.vy == 0.6 + assert odom.vz == 0.7 + assert odom.linear_velocity.x == 0.5 + assert odom.linear_velocity.y == 0.6 + assert odom.linear_velocity.z == 0.7 + + # Angular velocity properties + assert odom.wx == 0.1 + assert odom.wy == 0.2 + assert odom.wz == 0.3 + assert odom.angular_velocity.x == 0.1 + assert odom.angular_velocity.y == 0.2 + assert odom.angular_velocity.z == 0.3 + + # Euler angles + assert odom.roll == pose.roll + assert odom.pitch == pose.pitch + assert odom.yaw == pose.yaw + + +def test_odometry_str_repr(): + """Test string representations.""" + odom = Odometry( + ts=1234567890.123456, + frame_id="odom", + child_frame_id="base_link", + pose=Pose(1.234, 2.567, 3.891), + twist=Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + ) + + repr_str = repr(odom) + assert "Odometry" in repr_str + assert "1234567890.123456" in repr_str + assert "odom" in repr_str + assert "base_link" in repr_str + + str_repr = str(odom) + assert "Odometry" in str_repr + assert "odom -> base_link" in str_repr + assert "1.234" in str_repr + assert "0.500" in str_repr + + +def test_odometry_equality(): + """Test equality comparison.""" + odom1 = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_link", + pose=Pose(1.0, 2.0, 3.0), + twist=Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + ) + + odom2 = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_link", + pose=Pose(1.0, 2.0, 3.0), + twist=Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + ) + + odom3 = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_link", + pose=Pose(1.1, 2.0, 3.0), # Different position + twist=Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)), + ) + + assert odom1 == odom2 + assert odom1 != odom3 + assert odom1 != "not an odometry" + + +def test_odometry_lcm_encode_decode(): + """Test LCM encoding and decoding.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + pose_cov = np.arange(36, dtype=float) + twist = Twist(Vector3(0.5, 0.6, 0.7), Vector3(0.1, 0.2, 0.3)) + twist_cov = np.arange(36, 72, dtype=float) + + source = Odometry( + ts=1234567890.123456, + frame_id="odom", + child_frame_id="base_link", + pose=PoseWithCovariance(pose, pose_cov), + twist=TwistWithCovariance(twist, twist_cov), + ) + + # Encode and decode + binary_msg = source.lcm_encode() + decoded = Odometry.lcm_decode(binary_msg) + + # Check values (allowing for timestamp precision loss) + assert abs(decoded.ts - source.ts) < 1e-6 + assert decoded.frame_id == source.frame_id + assert decoded.child_frame_id == source.child_frame_id + assert decoded.pose == source.pose + assert decoded.twist == source.twist + + +@pytest.mark.ros +def test_odometry_from_ros_msg(): + """Test creating from ROS message.""" + ros_msg = ROSOdometry() + + # Set header + ros_msg.header = ROSHeader() + ros_msg.header.stamp = ROSTime() + ros_msg.header.stamp.sec = 1234567890 + ros_msg.header.stamp.nanosec = 123456000 + ros_msg.header.frame_id = "odom" + ros_msg.child_frame_id = "base_link" + + # Set pose with covariance + ros_msg.pose = ROSPoseWithCovariance() + ros_msg.pose.pose = ROSPose() + ros_msg.pose.pose.position = ROSPoint(x=1.0, y=2.0, z=3.0) + ros_msg.pose.pose.orientation = ROSQuaternion(x=0.1, y=0.2, z=0.3, w=0.9) + ros_msg.pose.covariance = [float(i) for i in range(36)] + + # Set twist with covariance + ros_msg.twist = ROSTwistWithCovariance() + ros_msg.twist.twist = ROSTwist() + ros_msg.twist.twist.linear = ROSVector3(x=0.5, y=0.6, z=0.7) + ros_msg.twist.twist.angular = ROSVector3(x=0.1, y=0.2, z=0.3) + ros_msg.twist.covariance = [float(i) for i in range(36, 72)] + + odom = Odometry.from_ros_msg(ros_msg) + + assert odom.ts == 1234567890.123456 + assert odom.frame_id == "odom" + assert odom.child_frame_id == "base_link" + assert odom.pose.position.x == 1.0 + assert odom.twist.linear.x == 0.5 + assert np.array_equal(odom.pose.covariance, np.arange(36)) + assert np.array_equal(odom.twist.covariance, np.arange(36, 72)) + + +@pytest.mark.ros +def test_odometry_to_ros_msg(): + """Test converting to ROS message.""" + pose = Pose(1.0, 2.0, 3.0, 0.1, 0.2, 0.3, 0.9) + pose_cov = np.arange(36, dtype=float) + twist = Twist(Vector3(0.5, 0.6, 0.7), Vector3(0.1, 0.2, 0.3)) + twist_cov = np.arange(36, 72, dtype=float) + + odom = Odometry( + ts=1234567890.567890, + frame_id="odom", + child_frame_id="base_link", + pose=PoseWithCovariance(pose, pose_cov), + twist=TwistWithCovariance(twist, twist_cov), + ) + + ros_msg = odom.to_ros_msg() + + assert isinstance(ros_msg, ROSOdometry) + assert ros_msg.header.frame_id == "odom" + assert ros_msg.header.stamp.sec == 1234567890 + assert abs(ros_msg.header.stamp.nanosec - 567890000) < 100 # Allow small rounding error + assert ros_msg.child_frame_id == "base_link" + + # Check pose + assert ros_msg.pose.pose.position.x == 1.0 + assert ros_msg.pose.pose.position.y == 2.0 + assert ros_msg.pose.pose.position.z == 3.0 + assert ros_msg.pose.pose.orientation.x == 0.1 + assert ros_msg.pose.pose.orientation.y == 0.2 + assert ros_msg.pose.pose.orientation.z == 0.3 + assert ros_msg.pose.pose.orientation.w == 0.9 + assert list(ros_msg.pose.covariance) == list(range(36)) + + # Check twist + assert ros_msg.twist.twist.linear.x == 0.5 + assert ros_msg.twist.twist.linear.y == 0.6 + assert ros_msg.twist.twist.linear.z == 0.7 + assert ros_msg.twist.twist.angular.x == 0.1 + assert ros_msg.twist.twist.angular.y == 0.2 + assert ros_msg.twist.twist.angular.z == 0.3 + assert list(ros_msg.twist.covariance) == list(range(36, 72)) + + +@pytest.mark.ros +def test_odometry_ros_roundtrip(): + """Test round-trip conversion with ROS messages.""" + pose = Pose(1.5, 2.5, 3.5, 0.15, 0.25, 0.35, 0.85) + pose_cov = np.random.rand(36) + twist = Twist(Vector3(0.55, 0.65, 0.75), Vector3(0.15, 0.25, 0.35)) + twist_cov = np.random.rand(36) + + original = Odometry( + ts=2147483647.987654, # Max int32 value for ROS Time.sec + frame_id="world", + child_frame_id="robot", + pose=PoseWithCovariance(pose, pose_cov), + twist=TwistWithCovariance(twist, twist_cov), + ) + + ros_msg = original.to_ros_msg() + restored = Odometry.from_ros_msg(ros_msg) + + # Check values (allowing for timestamp precision loss) + assert abs(restored.ts - original.ts) < 1e-6 + assert restored.frame_id == original.frame_id + assert restored.child_frame_id == original.child_frame_id + assert restored.pose == original.pose + assert restored.twist == original.twist + + +def test_odometry_zero_timestamp(): + """Test that zero timestamp gets replaced with current time.""" + odom = Odometry(ts=0.0) + + # Should have been replaced with current time + assert odom.ts > 0 + assert odom.ts <= time.time() + + +def test_odometry_with_just_pose(): + """Test initialization with just a Pose (no covariance).""" + pose = Pose(1.0, 2.0, 3.0) + + odom = Odometry(pose=pose) + + assert odom.pose.position.x == 1.0 + assert odom.pose.position.y == 2.0 + assert odom.pose.position.z == 3.0 + assert np.all(odom.pose.covariance == 0.0) # Should have zero covariance + assert np.all(odom.twist.covariance == 0.0) # Twist should also be zero + + +def test_odometry_with_just_twist(): + """Test initialization with just a Twist (no covariance).""" + twist = Twist(Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.1)) + + odom = Odometry(twist=twist) + + assert odom.twist.linear.x == 0.5 + assert odom.twist.angular.z == 0.1 + assert np.all(odom.twist.covariance == 0.0) # Should have zero covariance + assert np.all(odom.pose.covariance == 0.0) # Pose should also be zero + + +@pytest.mark.ros +@pytest.mark.parametrize( + "frame_id,child_frame_id", + [ + ("odom", "base_link"), + ("map", "odom"), + ("world", "robot"), + ("base_link", "camera_link"), + ("", ""), # Empty frames + ], +) +def test_odometry_frame_combinations(frame_id, child_frame_id): + """Test various frame ID combinations.""" + odom = Odometry(frame_id=frame_id, child_frame_id=child_frame_id) + + assert odom.frame_id == frame_id + assert odom.child_frame_id == child_frame_id + + # Test roundtrip through ROS + ros_msg = odom.to_ros_msg() + assert ros_msg.header.frame_id == frame_id + assert ros_msg.child_frame_id == child_frame_id + + restored = Odometry.from_ros_msg(ros_msg) + assert restored.frame_id == frame_id + assert restored.child_frame_id == child_frame_id + + +def test_odometry_typical_robot_scenario(): + """Test a typical robot odometry scenario.""" + # Robot moving forward at 0.5 m/s with slight rotation + odom = Odometry( + ts=1000.0, + frame_id="odom", + child_frame_id="base_footprint", + pose=Pose(10.0, 5.0, 0.0, 0.0, 0.0, np.sin(0.1), np.cos(0.1)), # 0.2 rad yaw + twist=Twist( + Vector3(0.5, 0.0, 0.0), Vector3(0.0, 0.0, 0.05) + ), # Moving forward, turning slightly + ) + + # Check we can access all the typical properties + assert odom.x == 10.0 + assert odom.y == 5.0 + assert odom.z == 0.0 + assert abs(odom.yaw - 0.2) < 0.01 # Approximately 0.2 radians + assert odom.vx == 0.5 # Forward velocity + assert odom.wz == 0.05 # Yaw rate diff --git a/dimos/msgs/nav_msgs/test_Path.py b/dimos/msgs/nav_msgs/test_Path.py index 0a34245448..94028d7959 100644 --- a/dimos/msgs/nav_msgs/test_Path.py +++ b/dimos/msgs/nav_msgs/test_Path.py @@ -16,6 +16,14 @@ import pytest + +try: + from nav_msgs.msg import Path as ROSPath + from geometry_msgs.msg import PoseStamped as ROSPoseStamped +except ImportError: + ROSPoseStamped = None + ROSPath = None + from dimos.msgs.geometry_msgs.PoseStamped import PoseStamped from dimos.msgs.geometry_msgs.Quaternion import Quaternion from dimos.msgs.nav_msgs.Path import Path @@ -288,3 +296,98 @@ def test_str_representation(): path.push_mut(create_test_pose(1, 1, 0)) path.push_mut(create_test_pose(2, 2, 0)) assert str(path) == "Path(frame_id='map', poses=2)" + + +@pytest.mark.ros +def test_path_from_ros_msg(): + """Test creating a Path from a ROS Path message.""" + ros_msg = ROSPath() + ros_msg.header.frame_id = "map" + ros_msg.header.stamp.sec = 123 + ros_msg.header.stamp.nanosec = 456000000 + + # Add some poses + for i in range(3): + ros_pose = ROSPoseStamped() + ros_pose.header.frame_id = "map" + ros_pose.header.stamp.sec = 123 + i + ros_pose.header.stamp.nanosec = 0 + ros_pose.pose.position.x = float(i) + ros_pose.pose.position.y = float(i * 2) + ros_pose.pose.position.z = float(i * 3) + ros_pose.pose.orientation.x = 0.0 + ros_pose.pose.orientation.y = 0.0 + ros_pose.pose.orientation.z = 0.0 + ros_pose.pose.orientation.w = 1.0 + ros_msg.poses.append(ros_pose) + + path = Path.from_ros_msg(ros_msg) + + assert path.frame_id == "map" + assert path.ts == 123.456 + assert len(path.poses) == 3 + + for i, pose in enumerate(path.poses): + assert pose.position.x == float(i) + assert pose.position.y == float(i * 2) + assert pose.position.z == float(i * 3) + assert pose.orientation.w == 1.0 + + +@pytest.mark.ros +def test_path_to_ros_msg(): + """Test converting a Path to a ROS Path message.""" + poses = [ + PoseStamped( + ts=124.0 + i, frame_id="odom", position=[i, i * 2, i * 3], orientation=[0, 0, 0, 1] + ) + for i in range(3) + ] + + path = Path(ts=123.456, frame_id="odom", poses=poses) + + ros_msg = path.to_ros_msg() + + assert isinstance(ros_msg, ROSPath) + assert ros_msg.header.frame_id == "odom" + assert ros_msg.header.stamp.sec == 123 + assert ros_msg.header.stamp.nanosec == 456000000 + assert len(ros_msg.poses) == 3 + + for i, ros_pose in enumerate(ros_msg.poses): + assert ros_pose.pose.position.x == float(i) + assert ros_pose.pose.position.y == float(i * 2) + assert ros_pose.pose.position.z == float(i * 3) + assert ros_pose.pose.orientation.w == 1.0 + + +@pytest.mark.ros +def test_path_ros_roundtrip(): + """Test round-trip conversion between Path and ROS Path.""" + poses = [ + PoseStamped( + ts=100.0 + i * 0.1, + frame_id="world", + position=[i * 1.5, i * 2.5, i * 3.5], + orientation=[0.1, 0.2, 0.3, 0.9], + ) + for i in range(3) + ] + + original = Path(ts=99.789, frame_id="world", poses=poses) + + ros_msg = original.to_ros_msg() + restored = Path.from_ros_msg(ros_msg) + + assert restored.frame_id == original.frame_id + assert restored.ts == original.ts + assert len(restored.poses) == len(original.poses) + + for orig_pose, rest_pose in zip(original.poses, restored.poses): + assert rest_pose.position.x == orig_pose.position.x + assert rest_pose.position.y == orig_pose.position.y + assert rest_pose.position.z == orig_pose.position.z + assert rest_pose.orientation.x == orig_pose.orientation.x + assert rest_pose.orientation.y == orig_pose.orientation.y + assert rest_pose.orientation.z == orig_pose.orientation.z + assert rest_pose.orientation.w == orig_pose.orientation.w diff --git a/dimos/msgs/sensor_msgs/CameraInfo.py b/dimos/msgs/sensor_msgs/CameraInfo.py new file mode 100644 index 0000000000..70a99e35ec --- /dev/null +++ b/dimos/msgs/sensor_msgs/CameraInfo.py @@ -0,0 +1,328 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import time +from typing import List, Optional + +import numpy as np + +# Import LCM types +from dimos_lcm.sensor_msgs import CameraInfo as LCMCameraInfo +from dimos_lcm.std_msgs.Header import Header + +# Import ROS types +try: + from sensor_msgs.msg import CameraInfo as ROSCameraInfo + from std_msgs.msg import Header as ROSHeader + from sensor_msgs.msg import RegionOfInterest as ROSRegionOfInterest + + ROS_AVAILABLE = True +except ImportError: + ROS_AVAILABLE = False + +from dimos.types.timestamped import Timestamped + + +class CameraInfo(Timestamped): + """Camera calibration information message.""" + + msg_name = "sensor_msgs.CameraInfo" + + def __init__( + self, + height: int = 0, + width: int = 0, + distortion_model: str = "", + D: Optional[List[float]] = None, + K: Optional[List[float]] = None, + R: Optional[List[float]] = None, + P: Optional[List[float]] = None, + binning_x: int = 0, + binning_y: int = 0, + frame_id: str = "", + ts: Optional[float] = None, + ): + """Initialize CameraInfo. + + Args: + height: Image height + width: Image width + distortion_model: Name of distortion model (e.g., "plumb_bob") + D: Distortion coefficients + K: 3x3 intrinsic camera matrix + R: 3x3 rectification matrix + P: 3x4 projection matrix + binning_x: Horizontal binning + binning_y: Vertical binning + frame_id: Frame ID + ts: Timestamp + """ + self.ts = ts if ts is not None else time.time() + self.frame_id = frame_id + self.height = height + self.width = width + self.distortion_model = distortion_model + + # Initialize distortion coefficients + self.D = D if D is not None else [] + + # Initialize 3x3 intrinsic camera matrix (row-major) + self.K = K if K is not None else [0.0] * 9 + + # Initialize 3x3 rectification matrix (row-major) + self.R = R if R is not None else [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0] + + # Initialize 3x4 projection matrix (row-major) + self.P = P if P is not None else [0.0] * 12 + + self.binning_x = binning_x + self.binning_y = binning_y + + # Region of interest (not used in basic implementation) + self.roi_x_offset = 0 + self.roi_y_offset = 0 + self.roi_height = 0 + self.roi_width = 0 + self.roi_do_rectify = False + + def get_K_matrix(self) -> np.ndarray: + """Get intrinsic matrix as numpy array.""" + return np.array(self.K, dtype=np.float64).reshape(3, 3) + + def get_P_matrix(self) -> np.ndarray: + """Get projection matrix as numpy array.""" + return np.array(self.P, dtype=np.float64).reshape(3, 4) + + def get_R_matrix(self) -> np.ndarray: + """Get rectification matrix as numpy array.""" + return np.array(self.R, dtype=np.float64).reshape(3, 3) + + def get_D_coeffs(self) -> np.ndarray: + """Get distortion coefficients as numpy array.""" + return np.array(self.D, dtype=np.float64) + + def set_K_matrix(self, K: np.ndarray): + """Set intrinsic matrix from numpy array.""" + if K.shape != (3, 3): + raise ValueError(f"K matrix must be 3x3, got {K.shape}") + self.K = K.flatten().tolist() + + def set_P_matrix(self, P: np.ndarray): + """Set projection matrix from numpy array.""" + if P.shape != (3, 4): + raise ValueError(f"P matrix must be 3x4, got {P.shape}") + self.P = P.flatten().tolist() + + def set_R_matrix(self, R: np.ndarray): + """Set rectification matrix from numpy array.""" + if R.shape != (3, 3): + raise ValueError(f"R matrix must be 3x3, got {R.shape}") + self.R = R.flatten().tolist() + + def set_D_coeffs(self, D: np.ndarray): + """Set distortion coefficients from numpy array.""" + self.D = D.flatten().tolist() + + def lcm_encode(self) -> bytes: + """Convert to LCM CameraInfo message.""" + msg = LCMCameraInfo() + + # Header + msg.header = Header() + msg.header.seq = 0 + msg.header.frame_id = self.frame_id + msg.header.stamp.sec = int(self.ts) + msg.header.stamp.nsec = int((self.ts - int(self.ts)) * 1e9) + + # Image dimensions + msg.height = self.height + msg.width = self.width + + # Distortion model + msg.distortion_model = self.distortion_model + + # Distortion coefficients + msg.D_length = len(self.D) + msg.D = self.D + + # Camera matrices (all stored as row-major) + msg.K = self.K + msg.R = self.R + msg.P = self.P + + # Binning + msg.binning_x = self.binning_x + msg.binning_y = self.binning_y + + # ROI + msg.roi.x_offset = self.roi_x_offset + msg.roi.y_offset = self.roi_y_offset + msg.roi.height = self.roi_height + msg.roi.width = self.roi_width + msg.roi.do_rectify = self.roi_do_rectify + + return msg.lcm_encode() + + @classmethod + def lcm_decode(cls, data: bytes) -> "CameraInfo": + """Decode from LCM CameraInfo bytes.""" + msg = LCMCameraInfo.lcm_decode(data) + + # Extract timestamp + ts = msg.header.stamp.sec + msg.header.stamp.nsec / 1e9 if hasattr(msg, "header") else None + + camera_info = cls( + height=msg.height, + width=msg.width, + distortion_model=msg.distortion_model, + D=list(msg.D) if msg.D_length > 0 else [], + K=list(msg.K), + R=list(msg.R), + P=list(msg.P), + binning_x=msg.binning_x, + binning_y=msg.binning_y, + frame_id=msg.header.frame_id if hasattr(msg, "header") else "", + ts=ts, + ) + + # Set ROI if present + if hasattr(msg, "roi"): + camera_info.roi_x_offset = msg.roi.x_offset + camera_info.roi_y_offset = msg.roi.y_offset + camera_info.roi_height = msg.roi.height + camera_info.roi_width = msg.roi.width + camera_info.roi_do_rectify = msg.roi.do_rectify + + return camera_info + + @classmethod + def from_ros_msg(cls, ros_msg: "ROSCameraInfo") -> "CameraInfo": + """Create CameraInfo from ROS sensor_msgs/CameraInfo message. + + Args: + ros_msg: ROS CameraInfo message + + Returns: + CameraInfo instance + """ + if not ROS_AVAILABLE: + raise ImportError("ROS packages not available. Cannot convert from ROS message.") + + # Extract timestamp + ts = ros_msg.header.stamp.sec + ros_msg.header.stamp.nanosec / 1e9 + + camera_info = cls( + height=ros_msg.height, + width=ros_msg.width, + distortion_model=ros_msg.distortion_model, + D=list(ros_msg.d), + K=list(ros_msg.k), + R=list(ros_msg.r), + P=list(ros_msg.p), + binning_x=ros_msg.binning_x, + binning_y=ros_msg.binning_y, + frame_id=ros_msg.header.frame_id, + ts=ts, + ) + + # Set ROI + camera_info.roi_x_offset = ros_msg.roi.x_offset + camera_info.roi_y_offset = ros_msg.roi.y_offset + camera_info.roi_height = ros_msg.roi.height + camera_info.roi_width = ros_msg.roi.width + camera_info.roi_do_rectify = ros_msg.roi.do_rectify + + return camera_info + + def to_ros_msg(self) -> "ROSCameraInfo": + """Convert to ROS sensor_msgs/CameraInfo message. + + Returns: + ROS CameraInfo message + """ + if not ROS_AVAILABLE: + raise ImportError("ROS packages not available. Cannot convert to ROS message.") + + ros_msg = ROSCameraInfo() + + # Set header + ros_msg.header = ROSHeader() + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1e9) + + # Image dimensions + ros_msg.height = self.height + ros_msg.width = self.width + + # Distortion model and coefficients + ros_msg.distortion_model = self.distortion_model + ros_msg.d = self.D + + # Camera matrices (all row-major) + ros_msg.k = self.K + ros_msg.r = self.R + ros_msg.p = self.P + + # Binning + ros_msg.binning_x = self.binning_x + ros_msg.binning_y = self.binning_y + + # ROI + ros_msg.roi = ROSRegionOfInterest() + ros_msg.roi.x_offset = self.roi_x_offset + ros_msg.roi.y_offset = self.roi_y_offset + ros_msg.roi.height = self.roi_height + ros_msg.roi.width = self.roi_width + ros_msg.roi.do_rectify = self.roi_do_rectify + + return ros_msg + + def __repr__(self) -> str: + """String representation.""" + return ( + f"CameraInfo(height={self.height}, width={self.width}, " + f"distortion_model='{self.distortion_model}', " + f"frame_id='{self.frame_id}', ts={self.ts})" + ) + + def __str__(self) -> str: + """Human-readable string.""" + return ( + f"CameraInfo:\n" + f" Resolution: {self.width}x{self.height}\n" + f" Distortion model: {self.distortion_model}\n" + f" Frame ID: {self.frame_id}\n" + f" Binning: {self.binning_x}x{self.binning_y}" + ) + + def __eq__(self, other) -> bool: + """Check if two CameraInfo messages are equal.""" + if not isinstance(other, CameraInfo): + return False + + return ( + self.height == other.height + and self.width == other.width + and self.distortion_model == other.distortion_model + and self.D == other.D + and self.K == other.K + and self.R == other.R + and self.P == other.P + and self.binning_x == other.binning_x + and self.binning_y == other.binning_y + and self.frame_id == other.frame_id + ) diff --git a/dimos/msgs/sensor_msgs/PointCloud2.py b/dimos/msgs/sensor_msgs/PointCloud2.py index 2238b31025..b5352ed6cb 100644 --- a/dimos/msgs/sensor_msgs/PointCloud2.py +++ b/dimos/msgs/sensor_msgs/PointCloud2.py @@ -28,6 +28,16 @@ from dimos_lcm.sensor_msgs.PointField import PointField from dimos_lcm.std_msgs.Header import Header +# Import ROS types +try: + from sensor_msgs.msg import PointCloud2 as ROSPointCloud2 + from sensor_msgs.msg import PointField as ROSPointField + from std_msgs.msg import Header as ROSHeader + + ROS_AVAILABLE = True +except ImportError: + ROS_AVAILABLE = False + from dimos.types.timestamped import Timestamped @@ -211,3 +221,170 @@ def __len__(self) -> int: def __repr__(self) -> str: """String representation.""" return f"PointCloud(points={len(self)}, frame_id='{self.frame_id}', ts={self.ts})" + + @classmethod + def from_ros_msg(cls, ros_msg: "ROSPointCloud2") -> "PointCloud2": + """Convert from ROS sensor_msgs/PointCloud2 message. + + Args: + ros_msg: ROS PointCloud2 message + + Returns: + PointCloud2 instance + """ + if not ROS_AVAILABLE: + raise ImportError("ROS packages not available. Cannot convert from ROS message.") + + # Handle empty point cloud + if ros_msg.width == 0 or ros_msg.height == 0: + pc = o3d.geometry.PointCloud() + return cls( + pointcloud=pc, + frame_id=ros_msg.header.frame_id, + ts=ros_msg.header.stamp.sec + ros_msg.header.stamp.nanosec / 1e9, + ) + + # Parse field information to find X, Y, Z offsets + x_offset = y_offset = z_offset = None + for field in ros_msg.fields: + if field.name == "x": + x_offset = field.offset + elif field.name == "y": + y_offset = field.offset + elif field.name == "z": + z_offset = field.offset + + if any(offset is None for offset in [x_offset, y_offset, z_offset]): + raise ValueError("PointCloud2 message missing X, Y, or Z fields") + + # Extract points from binary data using numpy for bulk conversion + num_points = ros_msg.width * ros_msg.height + data = ros_msg.data + point_step = ros_msg.point_step + + # Determine byte order + byte_order = ">" if ros_msg.is_bigendian else "<" + + # Check if we can use fast numpy path (common case: sequential float32 x,y,z) + if ( + x_offset == 0 + and y_offset == 4 + and z_offset == 8 + and point_step >= 12 + and not ros_msg.is_bigendian + ): + # Fast path: direct numpy reshape for tightly packed float32 x,y,z + # This is the most common case for point clouds + if point_step == 12: + # Perfectly packed x,y,z with no padding + points = np.frombuffer(data, dtype=np.float32).reshape(-1, 3) + else: + # Has additional fields after x,y,z, need to extract with stride + dt = np.dtype( + [("x", " 0: + dt_fields.append(("_pad_x", f"V{x_offset}")) + dt_fields.append(("x", f"{byte_order}f4")) + + # Add padding between x and y if needed + gap_xy = y_offset - x_offset - 4 + if gap_xy > 0: + dt_fields.append(("_pad_xy", f"V{gap_xy}")) + dt_fields.append(("y", f"{byte_order}f4")) + + # Add padding between y and z if needed + gap_yz = z_offset - y_offset - 4 + if gap_yz > 0: + dt_fields.append(("_pad_yz", f"V{gap_yz}")) + dt_fields.append(("z", f"{byte_order}f4")) + + # Add padding at the end to match point_step + remaining = point_step - z_offset - 4 + if remaining > 0: + dt_fields.append(("_pad_end", f"V{remaining}")) + + dt = np.dtype(dt_fields) + structured = np.frombuffer(data, dtype=dt, count=num_points) + points = np.column_stack((structured["x"], structured["y"], structured["z"])) + + # Filter out NaN and Inf values if not dense + if not ros_msg.is_dense: + mask = np.isfinite(points).all(axis=1) + points = points[mask] + + # Create Open3D point cloud + pc = o3d.geometry.PointCloud() + pc.points = o3d.utility.Vector3dVector(points) + + # Extract timestamp + ts = ros_msg.header.stamp.sec + ros_msg.header.stamp.nanosec / 1e9 + + return cls( + pointcloud=pc, + frame_id=ros_msg.header.frame_id, + ts=ts, + ) + + def to_ros_msg(self) -> "ROSPointCloud2": + """Convert to ROS sensor_msgs/PointCloud2 message. + + Returns: + ROS PointCloud2 message + """ + if not ROS_AVAILABLE: + raise ImportError("ROS packages not available. Cannot convert to ROS message.") + + ros_msg = ROSPointCloud2() + + # Set header + ros_msg.header = ROSHeader() + ros_msg.header.frame_id = self.frame_id + ros_msg.header.stamp.sec = int(self.ts) + ros_msg.header.stamp.nanosec = int((self.ts - int(self.ts)) * 1e9) + + points = self.as_numpy() + + if len(points) == 0: + # Empty point cloud + ros_msg.height = 0 + ros_msg.width = 0 + ros_msg.fields = [] + ros_msg.is_bigendian = False + ros_msg.point_step = 0 + ros_msg.row_step = 0 + ros_msg.data = b"" + ros_msg.is_dense = True + return ros_msg + + # Set dimensions + ros_msg.height = 1 # Unorganized point cloud + ros_msg.width = len(points) + + # Define fields (X, Y, Z as float32) + ros_msg.fields = [ + ROSPointField(name="x", offset=0, datatype=ROSPointField.FLOAT32, count=1), + ROSPointField(name="y", offset=4, datatype=ROSPointField.FLOAT32, count=1), + ROSPointField(name="z", offset=8, datatype=ROSPointField.FLOAT32, count=1), + ] + + # Set point step and row step + ros_msg.point_step = 12 # 3 floats * 4 bytes each + ros_msg.row_step = ros_msg.point_step * ros_msg.width + + # Convert points to bytes (little endian float32) + ros_msg.data = points.astype(np.float32).tobytes() + + # Set properties + ros_msg.is_bigendian = False # Little endian + ros_msg.is_dense = True # No invalid points + + return ros_msg diff --git a/dimos/msgs/sensor_msgs/__init__.py b/dimos/msgs/sensor_msgs/__init__.py index 434ec75afb..a7afafe2f2 100644 --- a/dimos/msgs/sensor_msgs/__init__.py +++ b/dimos/msgs/sensor_msgs/__init__.py @@ -1,2 +1,3 @@ from dimos.msgs.sensor_msgs.Image import Image, ImageFormat from dimos.msgs.sensor_msgs.PointCloud2 import PointCloud2 +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo diff --git a/dimos/msgs/sensor_msgs/test_CameraInfo.py b/dimos/msgs/sensor_msgs/test_CameraInfo.py new file mode 100644 index 0000000000..0c755f74f5 --- /dev/null +++ b/dimos/msgs/sensor_msgs/test_CameraInfo.py @@ -0,0 +1,409 @@ +#!/usr/bin/env python3 +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import numpy as np + + +try: + from sensor_msgs.msg import CameraInfo as ROSCameraInfo + from sensor_msgs.msg import RegionOfInterest as ROSRegionOfInterest + from std_msgs.msg import Header as ROSHeader +except ImportError: + ROSCameraInfo = None + ROSRegionOfInterest = None + ROSHeader = None + +from dimos.msgs.sensor_msgs.CameraInfo import CameraInfo + +# Try to import ROS types for testing +try: + ROS_AVAILABLE = True +except ImportError: + ROS_AVAILABLE = False + + +def test_lcm_encode_decode(): + """Test LCM encode/decode preserves CameraInfo data.""" + print("Testing CameraInfo LCM encode/decode...") + + # Create test camera info with sample calibration data + original = CameraInfo( + height=480, + width=640, + distortion_model="plumb_bob", + D=[-0.1, 0.05, 0.001, -0.002, 0.0], # 5 distortion coefficients + K=[ + 500.0, + 0.0, + 320.0, # fx, 0, cx + 0.0, + 500.0, + 240.0, # 0, fy, cy + 0.0, + 0.0, + 1.0, + ], # 0, 0, 1 + R=[1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0], + P=[ + 500.0, + 0.0, + 320.0, + 0.0, # fx, 0, cx, Tx + 0.0, + 500.0, + 240.0, + 0.0, # 0, fy, cy, Ty + 0.0, + 0.0, + 1.0, + 0.0, + ], # 0, 0, 1, 0 + binning_x=2, + binning_y=2, + frame_id="camera_optical_frame", + ts=1234567890.123456, + ) + + # Set ROI + original.roi_x_offset = 100 + original.roi_y_offset = 50 + original.roi_height = 200 + original.roi_width = 300 + original.roi_do_rectify = True + + # Encode and decode + binary_msg = original.lcm_encode() + decoded = CameraInfo.lcm_decode(binary_msg) + + # Check basic properties + assert original.height == decoded.height, ( + f"Height mismatch: {original.height} vs {decoded.height}" + ) + assert original.width == decoded.width, f"Width mismatch: {original.width} vs {decoded.width}" + print(f"✓ Image dimensions preserved: {decoded.width}x{decoded.height}") + + assert original.distortion_model == decoded.distortion_model, ( + f"Distortion model mismatch: '{original.distortion_model}' vs '{decoded.distortion_model}'" + ) + print(f"✓ Distortion model preserved: '{decoded.distortion_model}'") + + # Check distortion coefficients + assert len(original.D) == len(decoded.D), ( + f"D length mismatch: {len(original.D)} vs {len(decoded.D)}" + ) + np.testing.assert_allclose( + original.D, decoded.D, rtol=1e-9, atol=1e-9, err_msg="Distortion coefficients don't match" + ) + print(f"✓ Distortion coefficients preserved: {len(decoded.D)} coefficients") + + # Check camera matrices + np.testing.assert_allclose( + original.K, decoded.K, rtol=1e-9, atol=1e-9, err_msg="K matrix doesn't match" + ) + print("✓ Intrinsic matrix K preserved") + + np.testing.assert_allclose( + original.R, decoded.R, rtol=1e-9, atol=1e-9, err_msg="R matrix doesn't match" + ) + print("✓ Rectification matrix R preserved") + + np.testing.assert_allclose( + original.P, decoded.P, rtol=1e-9, atol=1e-9, err_msg="P matrix doesn't match" + ) + print("✓ Projection matrix P preserved") + + # Check binning + assert original.binning_x == decoded.binning_x, ( + f"Binning X mismatch: {original.binning_x} vs {decoded.binning_x}" + ) + assert original.binning_y == decoded.binning_y, ( + f"Binning Y mismatch: {original.binning_y} vs {decoded.binning_y}" + ) + print(f"✓ Binning preserved: {decoded.binning_x}x{decoded.binning_y}") + + # Check ROI + assert original.roi_x_offset == decoded.roi_x_offset, "ROI x_offset mismatch" + assert original.roi_y_offset == decoded.roi_y_offset, "ROI y_offset mismatch" + assert original.roi_height == decoded.roi_height, "ROI height mismatch" + assert original.roi_width == decoded.roi_width, "ROI width mismatch" + assert original.roi_do_rectify == decoded.roi_do_rectify, "ROI do_rectify mismatch" + print("✓ ROI preserved") + + # Check metadata + assert original.frame_id == decoded.frame_id, ( + f"Frame ID mismatch: '{original.frame_id}' vs '{decoded.frame_id}'" + ) + print(f"✓ Frame ID preserved: '{decoded.frame_id}'") + + assert abs(original.ts - decoded.ts) < 1e-6, ( + f"Timestamp mismatch: {original.ts} vs {decoded.ts}" + ) + print(f"✓ Timestamp preserved: {decoded.ts}") + + print("✓ LCM encode/decode test passed - all properties preserved!") + + +def test_numpy_matrix_operations(): + """Test numpy matrix getter/setter operations.""" + print("\nTesting numpy matrix operations...") + + camera_info = CameraInfo() + + # Test K matrix + K = np.array([[525.0, 0.0, 319.5], [0.0, 525.0, 239.5], [0.0, 0.0, 1.0]]) + camera_info.set_K_matrix(K) + K_retrieved = camera_info.get_K_matrix() + np.testing.assert_allclose(K, K_retrieved, rtol=1e-9, atol=1e-9) + print("✓ K matrix setter/getter works") + + # Test P matrix + P = np.array([[525.0, 0.0, 319.5, 0.0], [0.0, 525.0, 239.5, 0.0], [0.0, 0.0, 1.0, 0.0]]) + camera_info.set_P_matrix(P) + P_retrieved = camera_info.get_P_matrix() + np.testing.assert_allclose(P, P_retrieved, rtol=1e-9, atol=1e-9) + print("✓ P matrix setter/getter works") + + # Test R matrix + R = np.eye(3) + camera_info.set_R_matrix(R) + R_retrieved = camera_info.get_R_matrix() + np.testing.assert_allclose(R, R_retrieved, rtol=1e-9, atol=1e-9) + print("✓ R matrix setter/getter works") + + # Test D coefficients + D = np.array([-0.2, 0.1, 0.001, -0.002, 0.05]) + camera_info.set_D_coeffs(D) + D_retrieved = camera_info.get_D_coeffs() + np.testing.assert_allclose(D, D_retrieved, rtol=1e-9, atol=1e-9) + print("✓ D coefficients setter/getter works") + + print("✓ All numpy matrix operations passed!") + + +@pytest.mark.ros +def test_ros_conversion(): + """Test ROS message conversion preserves CameraInfo data.""" + if not ROS_AVAILABLE: + print("\nROS packages not available - skipping ROS conversion test") + return + + print("\nTesting ROS CameraInfo conversion...") + + # Create test camera info + original = CameraInfo( + height=720, + width=1280, + distortion_model="rational_polynomial", + D=[0.1, -0.2, 0.001, 0.002, -0.05, 0.01, -0.02, 0.003], # 8 coefficients + K=[600.0, 0.0, 640.0, 0.0, 600.0, 360.0, 0.0, 0.0, 1.0], + R=[0.999, -0.01, 0.02, 0.01, 0.999, -0.01, -0.02, 0.01, 0.999], + P=[ + 600.0, + 0.0, + 640.0, + -60.0, # Stereo baseline of 0.1m + 0.0, + 600.0, + 360.0, + 0.0, + 0.0, + 0.0, + 1.0, + 0.0, + ], + binning_x=1, + binning_y=1, + frame_id="left_camera_optical", + ts=1234567890.987654, + ) + + # Set ROI + original.roi_x_offset = 200 + original.roi_y_offset = 100 + original.roi_height = 400 + original.roi_width = 800 + original.roi_do_rectify = False + + # Test 1: Convert to ROS and back + ros_msg = original.to_ros_msg() + converted = CameraInfo.from_ros_msg(ros_msg) + + # Check all properties + assert original.height == converted.height, ( + f"Height mismatch: {original.height} vs {converted.height}" + ) + assert original.width == converted.width, ( + f"Width mismatch: {original.width} vs {converted.width}" + ) + print(f"✓ Dimensions preserved: {converted.width}x{converted.height}") + + assert original.distortion_model == converted.distortion_model, ( + f"Distortion model mismatch: '{original.distortion_model}' vs '{converted.distortion_model}'" + ) + print(f"✓ Distortion model preserved: '{converted.distortion_model}'") + + np.testing.assert_allclose( + original.D, + converted.D, + rtol=1e-9, + atol=1e-9, + err_msg="D coefficients don't match after ROS conversion", + ) + print(f"✓ Distortion coefficients preserved: {len(converted.D)} coefficients") + + np.testing.assert_allclose( + original.K, + converted.K, + rtol=1e-9, + atol=1e-9, + err_msg="K matrix doesn't match after ROS conversion", + ) + print("✓ K matrix preserved") + + np.testing.assert_allclose( + original.R, + converted.R, + rtol=1e-9, + atol=1e-9, + err_msg="R matrix doesn't match after ROS conversion", + ) + print("✓ R matrix preserved") + + np.testing.assert_allclose( + original.P, + converted.P, + rtol=1e-9, + atol=1e-9, + err_msg="P matrix doesn't match after ROS conversion", + ) + print("✓ P matrix preserved") + + assert original.binning_x == converted.binning_x, "Binning X mismatch" + assert original.binning_y == converted.binning_y, "Binning Y mismatch" + print(f"✓ Binning preserved: {converted.binning_x}x{converted.binning_y}") + + assert original.roi_x_offset == converted.roi_x_offset, "ROI x_offset mismatch" + assert original.roi_y_offset == converted.roi_y_offset, "ROI y_offset mismatch" + assert original.roi_height == converted.roi_height, "ROI height mismatch" + assert original.roi_width == converted.roi_width, "ROI width mismatch" + assert original.roi_do_rectify == converted.roi_do_rectify, "ROI do_rectify mismatch" + print("✓ ROI preserved") + + assert original.frame_id == converted.frame_id, ( + f"Frame ID mismatch: '{original.frame_id}' vs '{converted.frame_id}'" + ) + print(f"✓ Frame ID preserved: '{converted.frame_id}'") + + assert abs(original.ts - converted.ts) < 1e-6, ( + f"Timestamp mismatch: {original.ts} vs {converted.ts}" + ) + print(f"✓ Timestamp preserved: {converted.ts}") + + # Test 2: Create ROS message directly and convert to DIMOS + ros_msg2 = ROSCameraInfo() + ros_msg2.header = ROSHeader() + ros_msg2.header.frame_id = "test_camera" + ros_msg2.header.stamp.sec = 1234567890 + ros_msg2.header.stamp.nanosec = 500000000 + + ros_msg2.height = 1080 + ros_msg2.width = 1920 + ros_msg2.distortion_model = "plumb_bob" + ros_msg2.d = [-0.3, 0.15, 0.0, 0.0, 0.0] + ros_msg2.k = [1000.0, 0.0, 960.0, 0.0, 1000.0, 540.0, 0.0, 0.0, 1.0] + ros_msg2.r = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0] + ros_msg2.p = [1000.0, 0.0, 960.0, 0.0, 0.0, 1000.0, 540.0, 0.0, 0.0, 0.0, 1.0, 0.0] + ros_msg2.binning_x = 4 + ros_msg2.binning_y = 4 + + ros_msg2.roi = ROSRegionOfInterest() + ros_msg2.roi.x_offset = 10 + ros_msg2.roi.y_offset = 20 + ros_msg2.roi.height = 100 + ros_msg2.roi.width = 200 + ros_msg2.roi.do_rectify = True + + # Convert to DIMOS + dimos_info = CameraInfo.from_ros_msg(ros_msg2) + + assert dimos_info.height == 1080, ( + f"Height not preserved: expected 1080, got {dimos_info.height}" + ) + assert dimos_info.width == 1920, f"Width not preserved: expected 1920, got {dimos_info.width}" + assert dimos_info.frame_id == "test_camera", ( + f"Frame ID not preserved: expected 'test_camera', got '{dimos_info.frame_id}'" + ) + assert dimos_info.distortion_model == "plumb_bob", f"Distortion model not preserved" + assert len(dimos_info.D) == 5, ( + f"Wrong number of distortion coefficients: expected 5, got {len(dimos_info.D)}" + ) + print("✓ ROS to DIMOS conversion works correctly") + + # Test 3: Empty/minimal CameraInfo + minimal = CameraInfo(frame_id="minimal_camera", ts=1234567890.0) + minimal_ros = minimal.to_ros_msg() + minimal_converted = CameraInfo.from_ros_msg(minimal_ros) + + assert minimal.frame_id == minimal_converted.frame_id, ( + "Minimal CameraInfo frame_id not preserved" + ) + assert len(minimal_converted.D) == 0, "Minimal CameraInfo should have empty D" + print("✓ Minimal CameraInfo handling works") + + print("\n✓ All ROS conversion tests passed!") + + +def test_equality(): + """Test CameraInfo equality comparison.""" + print("\nTesting CameraInfo equality...") + + info1 = CameraInfo( + height=480, + width=640, + distortion_model="plumb_bob", + D=[-0.1, 0.05, 0.0, 0.0, 0.0], + frame_id="camera1", + ) + + info2 = CameraInfo( + height=480, + width=640, + distortion_model="plumb_bob", + D=[-0.1, 0.05, 0.0, 0.0, 0.0], + frame_id="camera1", + ) + + info3 = CameraInfo( + height=720, + width=1280, # Different resolution + distortion_model="plumb_bob", + D=[-0.1, 0.05, 0.0, 0.0, 0.0], + frame_id="camera1", + ) + + assert info1 == info2, "Identical CameraInfo objects should be equal" + assert info1 != info3, "Different CameraInfo objects should not be equal" + assert info1 != "not_camera_info", "CameraInfo should not equal non-CameraInfo object" + + print("✓ Equality comparison works correctly") + + +if __name__ == "__main__": + test_lcm_encode_decode() + test_numpy_matrix_operations() + test_ros_conversion() + test_equality() + print("\n✓✓✓ All CameraInfo tests passed! ✓✓✓") diff --git a/dimos/msgs/sensor_msgs/test_PointCloud2.py b/dimos/msgs/sensor_msgs/test_PointCloud2.py index eee1778680..ac115462be 100644 --- a/dimos/msgs/sensor_msgs/test_PointCloud2.py +++ b/dimos/msgs/sensor_msgs/test_PointCloud2.py @@ -13,12 +13,30 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pytest import numpy as np +import struct + + +try: + from sensor_msgs.msg import PointCloud2 as ROSPointCloud2 + from sensor_msgs.msg import PointField as ROSPointField + from std_msgs.msg import Header as ROSHeader +except ImportError: + ROSPointCloud2 = None + ROSPointField = None + ROSHeader = None from dimos.msgs.sensor_msgs import PointCloud2 from dimos.robot.unitree_webrtc.type.lidar import LidarMessage from dimos.utils.testing import SensorReplay +# Try to import ROS types for testing +try: + ROS_AVAILABLE = True +except ImportError: + ROS_AVAILABLE = False + def test_lcm_encode_decode(): """Test LCM encode/decode preserves pointcloud data.""" @@ -79,3 +97,143 @@ def test_lcm_encode_decode(): print(f" - Mean: {decoded_points.mean(axis=0)}") print("✓ LCM encode/decode test passed - all properties preserved!") + + +@pytest.mark.ros +def test_ros_conversion(): + """Test ROS message conversion preserves pointcloud data.""" + if not ROS_AVAILABLE: + print("ROS packages not available - skipping ROS conversion test") + return + + print("\nTesting ROS PointCloud2 conversion...") + + # Create a simple test point cloud + import open3d as o3d + + points = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [-1.0, -2.0, -3.0], + [0.5, 0.5, 0.5], + ], + dtype=np.float32, + ) + + pc = o3d.geometry.PointCloud() + pc.points = o3d.utility.Vector3dVector(points) + + # Create DIMOS PointCloud2 + original = PointCloud2( + pointcloud=pc, + frame_id="test_frame", + ts=1234567890.123456, + ) + + # Test 1: Convert to ROS and back + ros_msg = original.to_ros_msg() + converted = PointCloud2.from_ros_msg(ros_msg) + + # Check points are preserved + original_points = original.as_numpy() + converted_points = converted.as_numpy() + + assert len(original_points) == len(converted_points), ( + f"Point count mismatch: {len(original_points)} vs {len(converted_points)}" + ) + + np.testing.assert_allclose( + original_points, + converted_points, + rtol=1e-6, + atol=1e-6, + err_msg="Points don't match after ROS conversion", + ) + print(f"✓ Points preserved: {len(converted_points)} points match") + + # Check metadata + assert original.frame_id == converted.frame_id, ( + f"Frame ID mismatch: '{original.frame_id}' vs '{converted.frame_id}'" + ) + print(f"✓ Frame ID preserved: '{converted.frame_id}'") + + assert abs(original.ts - converted.ts) < 1e-6, ( + f"Timestamp mismatch: {original.ts} vs {converted.ts}" + ) + print(f"✓ Timestamp preserved: {converted.ts}") + + # Test 2: Create ROS message directly and convert to DIMOS + ros_msg2 = ROSPointCloud2() + ros_msg2.header = ROSHeader() + ros_msg2.header.frame_id = "ros_test_frame" + ros_msg2.header.stamp.sec = 1234567890 + ros_msg2.header.stamp.nanosec = 123456000 + + # Set up point cloud data + ros_msg2.height = 1 + ros_msg2.width = 3 + ros_msg2.fields = [ + ROSPointField(name="x", offset=0, datatype=ROSPointField.FLOAT32, count=1), + ROSPointField(name="y", offset=4, datatype=ROSPointField.FLOAT32, count=1), + ROSPointField(name="z", offset=8, datatype=ROSPointField.FLOAT32, count=1), + ] + ros_msg2.is_bigendian = False + ros_msg2.point_step = 12 + ros_msg2.row_step = 36 + + # Pack test points + test_points = np.array( + [ + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + ], + dtype=np.float32, + ) + ros_msg2.data = test_points.tobytes() + ros_msg2.is_dense = True + + # Convert to DIMOS + dimos_pc = PointCloud2.from_ros_msg(ros_msg2) + + assert dimos_pc.frame_id == "ros_test_frame", ( + f"Frame ID not preserved: expected 'ros_test_frame', got '{dimos_pc.frame_id}'" + ) + + decoded_points = dimos_pc.as_numpy() + assert len(decoded_points) == 3, ( + f"Wrong number of points: expected 3, got {len(decoded_points)}" + ) + + np.testing.assert_allclose( + test_points, + decoded_points, + rtol=1e-6, + atol=1e-6, + err_msg="Points from ROS message don't match", + ) + print("✓ ROS to DIMOS conversion works correctly") + + # Test 3: Empty point cloud + empty_pc = PointCloud2( + pointcloud=o3d.geometry.PointCloud(), + frame_id="empty_frame", + ts=1234567890.0, + ) + + empty_ros = empty_pc.to_ros_msg() + assert empty_ros.width == 0, "Empty cloud should have width 0" + assert empty_ros.height == 0, "Empty cloud should have height 0" + assert len(empty_ros.data) == 0, "Empty cloud should have no data" + + empty_converted = PointCloud2.from_ros_msg(empty_ros) + assert len(empty_converted) == 0, "Empty cloud conversion failed" + print("✓ Empty point cloud handling works") + + print("\n✓ All ROS conversion tests passed!") + + +if __name__ == "__main__": + test_lcm_encode_decode() + test_ros_conversion() diff --git a/dimos/msgs/tf2_msgs/TFMessage.py b/dimos/msgs/tf2_msgs/TFMessage.py index 9ccba615b2..d2bb018c34 100644 --- a/dimos/msgs/tf2_msgs/TFMessage.py +++ b/dimos/msgs/tf2_msgs/TFMessage.py @@ -35,6 +35,13 @@ from dimos_lcm.std_msgs import Time as LCMTime from dimos_lcm.tf2_msgs import TFMessage as LCMTFMessage +try: + from tf2_msgs.msg import TFMessage as ROSTFMessage + from geometry_msgs.msg import TransformStamped as ROSTransformStamped +except ImportError: + ROSTFMessage = None + ROSTransformStamped = None + from dimos.msgs.geometry_msgs.Transform import Transform from dimos.msgs.geometry_msgs.Vector3 import Vector3 from dimos.msgs.geometry_msgs.Quaternion import Quaternion @@ -119,3 +126,35 @@ def __str__(self) -> str: for i, transform in enumerate(self.transforms): lines.append(f" [{i}] {transform.frame_id} @ {transform.ts:.3f}") return "\n".join(lines) + + @classmethod + def from_ros_msg(cls, ros_msg: ROSTFMessage) -> "TFMessage": + """Create a TFMessage from a ROS tf2_msgs/TFMessage message. + + Args: + ros_msg: ROS TFMessage message + + Returns: + TFMessage instance + """ + transforms = [] + for ros_transform_stamped in ros_msg.transforms: + # Convert from ROS TransformStamped to our Transform + transform = Transform.from_ros_transform_stamped(ros_transform_stamped) + transforms.append(transform) + + return cls(*transforms) + + def to_ros_msg(self) -> ROSTFMessage: + """Convert to a ROS tf2_msgs/TFMessage message. + + Returns: + ROS TFMessage message + """ + ros_msg = ROSTFMessage() + + # Convert each Transform to ROS TransformStamped + for transform in self.transforms: + ros_msg.transforms.append(transform.to_ros_transform_stamped()) + + return ros_msg diff --git a/dimos/msgs/tf2_msgs/test_TFMessage.py b/dimos/msgs/tf2_msgs/test_TFMessage.py index 6eec4cbdcc..dfe3400e1c 100644 --- a/dimos/msgs/tf2_msgs/test_TFMessage.py +++ b/dimos/msgs/tf2_msgs/test_TFMessage.py @@ -13,6 +13,14 @@ # limitations under the License. import pytest + +try: + from tf2_msgs.msg import TFMessage as ROSTFMessage + from geometry_msgs.msg import TransformStamped as ROSTransformStamped +except ImportError: + ROSTransformStamped = None + ROSTFMessage = None + from dimos_lcm.tf2_msgs import TFMessage as LCMTFMessage from dimos.msgs.geometry_msgs import Quaternion, Transform, Vector3 @@ -107,3 +115,155 @@ def test_tfmessage_lcm_encode_decode(): assert ts2.child_frame_id == "target" assert ts2.transform.rotation.z == 0.707 assert ts2.transform.rotation.w == 0.707 + + +@pytest.mark.ros +def test_tfmessage_from_ros_msg(): + """Test creating a TFMessage from a ROS TFMessage message.""" + + ros_msg = ROSTFMessage() + + # Add first transform + tf1 = ROSTransformStamped() + tf1.header.frame_id = "world" + tf1.header.stamp.sec = 123 + tf1.header.stamp.nanosec = 456000000 + tf1.child_frame_id = "robot" + tf1.transform.translation.x = 1.0 + tf1.transform.translation.y = 2.0 + tf1.transform.translation.z = 3.0 + tf1.transform.rotation.x = 0.0 + tf1.transform.rotation.y = 0.0 + tf1.transform.rotation.z = 0.0 + tf1.transform.rotation.w = 1.0 + ros_msg.transforms.append(tf1) + + # Add second transform + tf2 = ROSTransformStamped() + tf2.header.frame_id = "robot" + tf2.header.stamp.sec = 124 + tf2.header.stamp.nanosec = 567000000 + tf2.child_frame_id = "sensor" + tf2.transform.translation.x = 4.0 + tf2.transform.translation.y = 5.0 + tf2.transform.translation.z = 6.0 + tf2.transform.rotation.x = 0.0 + tf2.transform.rotation.y = 0.0 + tf2.transform.rotation.z = 0.707 + tf2.transform.rotation.w = 0.707 + ros_msg.transforms.append(tf2) + + # Convert to TFMessage + tfmsg = TFMessage.from_ros_msg(ros_msg) + + assert len(tfmsg) == 2 + + # Check first transform + assert tfmsg[0].frame_id == "world" + assert tfmsg[0].child_frame_id == "robot" + assert tfmsg[0].ts == 123.456 + assert tfmsg[0].translation.x == 1.0 + assert tfmsg[0].translation.y == 2.0 + assert tfmsg[0].translation.z == 3.0 + assert tfmsg[0].rotation.w == 1.0 + + # Check second transform + assert tfmsg[1].frame_id == "robot" + assert tfmsg[1].child_frame_id == "sensor" + assert tfmsg[1].ts == 124.567 + assert tfmsg[1].translation.x == 4.0 + assert tfmsg[1].translation.y == 5.0 + assert tfmsg[1].translation.z == 6.0 + assert tfmsg[1].rotation.z == 0.707 + assert tfmsg[1].rotation.w == 0.707 + + +@pytest.mark.ros +def test_tfmessage_to_ros_msg(): + """Test converting a TFMessage to a ROS TFMessage message.""" + # Create transforms + tf1 = Transform( + translation=Vector3(1.0, 2.0, 3.0), + rotation=Quaternion(0.0, 0.0, 0.0, 1.0), + frame_id="map", + child_frame_id="base_link", + ts=123.456, + ) + tf2 = Transform( + translation=Vector3(7.0, 8.0, 9.0), + rotation=Quaternion(0.1, 0.2, 0.3, 0.9), + frame_id="base_link", + child_frame_id="lidar", + ts=125.789, + ) + + tfmsg = TFMessage(tf1, tf2) + + # Convert to ROS message + ros_msg = tfmsg.to_ros_msg() + + assert isinstance(ros_msg, ROSTFMessage) + assert len(ros_msg.transforms) == 2 + + # Check first transform + assert ros_msg.transforms[0].header.frame_id == "map" + assert ros_msg.transforms[0].child_frame_id == "base_link" + assert ros_msg.transforms[0].header.stamp.sec == 123 + assert ros_msg.transforms[0].header.stamp.nanosec == 456000000 + assert ros_msg.transforms[0].transform.translation.x == 1.0 + assert ros_msg.transforms[0].transform.translation.y == 2.0 + assert ros_msg.transforms[0].transform.translation.z == 3.0 + assert ros_msg.transforms[0].transform.rotation.w == 1.0 + + # Check second transform + assert ros_msg.transforms[1].header.frame_id == "base_link" + assert ros_msg.transforms[1].child_frame_id == "lidar" + assert ros_msg.transforms[1].header.stamp.sec == 125 + assert ros_msg.transforms[1].header.stamp.nanosec == 789000000 + assert ros_msg.transforms[1].transform.translation.x == 7.0 + assert ros_msg.transforms[1].transform.translation.y == 8.0 + assert ros_msg.transforms[1].transform.translation.z == 9.0 + assert ros_msg.transforms[1].transform.rotation.x == 0.1 + assert ros_msg.transforms[1].transform.rotation.y == 0.2 + assert ros_msg.transforms[1].transform.rotation.z == 0.3 + assert ros_msg.transforms[1].transform.rotation.w == 0.9 + + +@pytest.mark.ros +def test_tfmessage_ros_roundtrip(): + """Test round-trip conversion between TFMessage and ROS TFMessage.""" + # Create transforms with various properties + tf1 = Transform( + translation=Vector3(1.5, 2.5, 3.5), + rotation=Quaternion(0.15, 0.25, 0.35, 0.85), + frame_id="odom", + child_frame_id="base_footprint", + ts=100.123, + ) + tf2 = Transform( + translation=Vector3(0.1, 0.2, 0.3), + rotation=Quaternion(0.0, 0.0, 0.383, 0.924), + frame_id="base_footprint", + child_frame_id="camera", + ts=100.456, + ) + + original = TFMessage(tf1, tf2) + + # Convert to ROS and back + ros_msg = original.to_ros_msg() + restored = TFMessage.from_ros_msg(ros_msg) + + assert len(restored) == len(original) + + for orig_tf, rest_tf in zip(original, restored): + assert rest_tf.frame_id == orig_tf.frame_id + assert rest_tf.child_frame_id == orig_tf.child_frame_id + assert rest_tf.ts == orig_tf.ts + assert rest_tf.translation.x == orig_tf.translation.x + assert rest_tf.translation.y == orig_tf.translation.y + assert rest_tf.translation.z == orig_tf.translation.z + assert rest_tf.rotation.x == orig_tf.rotation.x + assert rest_tf.rotation.y == orig_tf.rotation.y + assert rest_tf.rotation.z == orig_tf.rotation.z + assert rest_tf.rotation.w == orig_tf.rotation.w diff --git a/dimos/robot/ros_bridge.py b/dimos/robot/ros_bridge.py new file mode 100644 index 0000000000..7e845e08d0 --- /dev/null +++ b/dimos/robot/ros_bridge.py @@ -0,0 +1,202 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import threading +import time +from typing import Dict, Any, Type, Literal, Optional +from enum import Enum + +try: + import rclpy + from rclpy.executors import SingleThreadedExecutor + from rclpy.node import Node + from rclpy.qos import QoSProfile, QoSReliabilityPolicy, QoSHistoryPolicy, QoSDurabilityPolicy +except ImportError: + rclpy = None + SingleThreadedExecutor = None + Node = None + QoSProfile = None + QoSReliabilityPolicy = None + QoSHistoryPolicy = None + QoSDurabilityPolicy = None + +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.utils.logging_config import setup_logger + +logger = setup_logger("dimos.robot.ros_bridge", level=logging.INFO) + + +class BridgeDirection(Enum): + """Direction of message bridging.""" + + ROS_TO_DIMOS = "ros_to_dimos" + DIMOS_TO_ROS = "dimos_to_ros" + + +class ROSBridge: + """Unidirectional bridge between ROS and DIMOS for message passing.""" + + def __init__(self, node_name: str = "dimos_ros_bridge"): + """Initialize the ROS-DIMOS bridge. + + Args: + node_name: Name for the ROS node (default: "dimos_ros_bridge") + """ + if not rclpy.ok(): + rclpy.init() + + self.node = Node(node_name) + self.lcm = LCM() + self.lcm.start() + + self._executor = SingleThreadedExecutor() + self._executor.add_node(self.node) + + self._spin_thread = threading.Thread(target=self._ros_spin, daemon=True) + self._spin_thread.start() + + self._bridges: Dict[str, Dict[str, Any]] = {} + + self._qos = QoSProfile( + reliability=QoSReliabilityPolicy.RELIABLE, + history=QoSHistoryPolicy.KEEP_LAST, + durability=QoSDurabilityPolicy.VOLATILE, + depth=10, + ) + + logger.info(f"ROSBridge initialized with node name: {node_name}") + + def _ros_spin(self): + """Background thread for spinning ROS executor.""" + try: + self._executor.spin() + finally: + self._executor.shutdown() + + def add_topic( + self, + topic_name: str, + dimos_type: Type, + ros_type: Type, + direction: BridgeDirection, + remap_topic: Optional[str] = None, + ) -> None: + """Add unidirectional bridging for a topic. + + Args: + topic_name: Name of the topic (e.g., "/cmd_vel") + dimos_type: DIMOS message type (e.g., dimos.msgs.geometry_msgs.Twist) + ros_type: ROS message type (e.g., geometry_msgs.msg.Twist) + direction: Direction of bridging (ROS_TO_DIMOS or DIMOS_TO_ROS) + remap_topic: Optional remapped topic name for the other side + """ + if topic_name in self._bridges: + logger.warning(f"Topic {topic_name} already bridged") + return + + # Determine actual topic names for each side + ros_topic_name = topic_name + dimos_topic_name = topic_name + + if remap_topic: + if direction == BridgeDirection.ROS_TO_DIMOS: + dimos_topic_name = remap_topic + else: # DIMOS_TO_ROS + ros_topic_name = remap_topic + + # Create DIMOS/LCM topic + dimos_topic = Topic(dimos_topic_name, dimos_type) + + ros_subscription = None + ros_publisher = None + dimos_subscription = None + + if direction == BridgeDirection.ROS_TO_DIMOS: + + def ros_callback(msg): + self._ros_to_dimos(msg, dimos_topic, dimos_type, topic_name) + + ros_subscription = self.node.create_subscription( + ros_type, ros_topic_name, ros_callback, self._qos + ) + logger.info(f" ROS → DIMOS: Subscribing to ROS topic {ros_topic_name}") + + elif direction == BridgeDirection.DIMOS_TO_ROS: + ros_publisher = self.node.create_publisher(ros_type, ros_topic_name, self._qos) + + def dimos_callback(msg, _topic): + self._dimos_to_ros(msg, ros_publisher, topic_name) + + dimos_subscription = self.lcm.subscribe(dimos_topic, dimos_callback) + logger.info(f" DIMOS → ROS: Subscribing to DIMOS topic {dimos_topic_name}") + else: + raise ValueError(f"Invalid bridge direction: {direction}") + + self._bridges[topic_name] = { + "dimos_topic": dimos_topic, + "dimos_type": dimos_type, + "ros_type": ros_type, + "ros_subscription": ros_subscription, + "ros_publisher": ros_publisher, + "dimos_subscription": dimos_subscription, + "direction": direction, + "ros_topic_name": ros_topic_name, + "dimos_topic_name": dimos_topic_name, + } + + direction_str = { + BridgeDirection.ROS_TO_DIMOS: "ROS → DIMOS", + BridgeDirection.DIMOS_TO_ROS: "DIMOS → ROS", + }[direction] + + logger.info(f"Bridged topic: {topic_name} ({direction_str})") + if remap_topic: + logger.info(f" Remapped: ROS '{ros_topic_name}' ↔ DIMOS '{dimos_topic_name}'") + logger.info(f" DIMOS type: {dimos_type.__name__}, ROS type: {ros_type.__name__}") + + def _ros_to_dimos( + self, ros_msg: Any, dimos_topic: Topic, dimos_type: Type, _topic_name: str + ) -> None: + """Convert ROS message to DIMOS and publish. + + Args: + ros_msg: ROS message + dimos_topic: DIMOS topic to publish to + dimos_type: DIMOS message type + topic_name: Name of the topic for tracking + """ + dimos_msg = dimos_type.from_ros_msg(ros_msg) + self.lcm.publish(dimos_topic, dimos_msg) + + def _dimos_to_ros(self, dimos_msg: Any, ros_publisher, _topic_name: str) -> None: + """Convert DIMOS message to ROS and publish. + + Args: + dimos_msg: DIMOS message + ros_publisher: ROS publisher to use + _topic_name: Name of the topic (unused, kept for consistency) + """ + ros_msg = dimos_msg.to_ros_msg() + ros_publisher.publish(ros_msg) + + def shutdown(self): + """Shutdown the bridge and clean up resources.""" + self._executor.shutdown() + self.node.destroy_node() + + if rclpy.ok(): + rclpy.shutdown() + + logger.info("ROSBridge shutdown complete") diff --git a/dimos/robot/test_ros_bridge.py b/dimos/robot/test_ros_bridge.py new file mode 100644 index 0000000000..9f83913768 --- /dev/null +++ b/dimos/robot/test_ros_bridge.py @@ -0,0 +1,436 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +import threading +import unittest +import numpy as np + +import pytest + +try: + import rclpy + from rclpy.node import Node + from geometry_msgs.msg import TwistStamped as ROSTwistStamped + from sensor_msgs.msg import PointCloud2 as ROSPointCloud2 + from sensor_msgs.msg import PointField + from tf2_msgs.msg import TFMessage as ROSTFMessage + from geometry_msgs.msg import TransformStamped +except ImportError: + rclpy = None + Node = None + ROSTwistStamped = None + ROSPointCloud2 = None + PointField = None + ROSTFMessage = None + TransformStamped = None + +from dimos.protocol.pubsub.lcmpubsub import LCM, Topic +from dimos.msgs.geometry_msgs import TwistStamped +from dimos.msgs.sensor_msgs import PointCloud2 +from dimos.msgs.tf2_msgs import TFMessage +from dimos.robot.ros_bridge import ROSBridge, BridgeDirection + + +@pytest.mark.ros +class TestROSBridge(unittest.TestCase): + """Test suite for ROS-DIMOS bridge.""" + + def setUp(self): + """Set up test fixtures.""" + # Skip if ROS is not available + if rclpy is None: + self.skipTest("ROS not available") + + # Initialize ROS if not already done + if not rclpy.ok(): + rclpy.init() + + # Create test bridge + self.bridge = ROSBridge("test_ros_bridge") + + # Create test node for publishing/subscribing + self.test_node = Node("test_node") + + # Track received messages + self.ros_messages = [] + self.dimos_messages = [] + self.message_timestamps = {"ros": [], "dimos": []} + + def tearDown(self): + """Clean up test fixtures.""" + self.test_node.destroy_node() + self.bridge.shutdown() + if rclpy.ok(): + rclpy.try_shutdown() + + def test_ros_to_dimos_twist(self): + """Test ROS TwistStamped to DIMOS conversion and transmission.""" + # Set up bridge + self.bridge.add_topic( + "/test_twist", TwistStamped, ROSTwistStamped, BridgeDirection.ROS_TO_DIMOS + ) + + # Subscribe to DIMOS side + lcm = LCM() + lcm.start() + topic = Topic("/test_twist", TwistStamped) + + def dimos_callback(msg, _topic): + self.dimos_messages.append(msg) + self.message_timestamps["dimos"].append(time.time()) + + lcm.subscribe(topic, dimos_callback) + + # Publish from ROS side + ros_pub = self.test_node.create_publisher(ROSTwistStamped, "/test_twist", 10) + + # Send test messages + for i in range(10): + msg = ROSTwistStamped() + msg.header.stamp = self.test_node.get_clock().now().to_msg() + msg.header.frame_id = f"frame_{i}" + msg.twist.linear.x = float(i) + msg.twist.linear.y = float(i * 2) + msg.twist.angular.z = float(i * 0.1) + + ros_pub.publish(msg) + self.message_timestamps["ros"].append(time.time()) + time.sleep(0.01) # 100Hz + + # Allow time for processing + time.sleep(0.5) + + # Verify messages received + self.assertEqual(len(self.dimos_messages), 10, "Should receive all 10 messages") + + # Verify message content + for i, msg in enumerate(self.dimos_messages): + self.assertEqual(msg.frame_id, f"frame_{i}") + self.assertAlmostEqual(msg.linear.x, float(i), places=5) + self.assertAlmostEqual(msg.linear.y, float(i * 2), places=5) + self.assertAlmostEqual(msg.angular.z, float(i * 0.1), places=5) + + def test_dimos_to_ros_twist(self): + """Test DIMOS TwistStamped to ROS conversion and transmission.""" + # Set up bridge + self.bridge.add_topic( + "/test_twist_reverse", TwistStamped, ROSTwistStamped, BridgeDirection.DIMOS_TO_ROS + ) + + # Subscribe to ROS side + def ros_callback(msg): + self.ros_messages.append(msg) + self.message_timestamps["ros"].append(time.time()) + + self.test_node.create_subscription(ROSTwistStamped, "/test_twist_reverse", ros_callback, 10) + + # Use the bridge's LCM instance for publishing + topic = Topic("/test_twist_reverse", TwistStamped) + + # Send test messages + for i in range(10): + msg = TwistStamped(ts=time.time(), frame_id=f"dimos_frame_{i}") + msg.linear.x = float(i * 3) + msg.linear.y = float(i * 4) + msg.angular.z = float(i * 0.2) + + self.bridge.lcm.publish(topic, msg) + self.message_timestamps["dimos"].append(time.time()) + time.sleep(0.01) # 100Hz + + # Allow time for processing and spin the test node + for _ in range(50): # Spin for 0.5 seconds + rclpy.spin_once(self.test_node, timeout_sec=0.01) + + # Verify messages received + self.assertEqual(len(self.ros_messages), 10, "Should receive all 10 messages") + + # Verify message content + for i, msg in enumerate(self.ros_messages): + self.assertEqual(msg.header.frame_id, f"dimos_frame_{i}") + self.assertAlmostEqual(msg.twist.linear.x, float(i * 3), places=5) + self.assertAlmostEqual(msg.twist.linear.y, float(i * 4), places=5) + self.assertAlmostEqual(msg.twist.angular.z, float(i * 0.2), places=5) + + def test_frequency_preservation(self): + """Test that message frequencies are preserved through the bridge.""" + # Set up bridge + self.bridge.add_topic( + "/test_freq", TwistStamped, ROSTwistStamped, BridgeDirection.ROS_TO_DIMOS + ) + + # Subscribe to DIMOS side + lcm = LCM() + lcm.start() + topic = Topic("/test_freq", TwistStamped) + + receive_times = [] + + def dimos_callback(_msg, _topic): + receive_times.append(time.time()) + + lcm.subscribe(topic, dimos_callback) + + # Publish from ROS at specific frequencies + ros_pub = self.test_node.create_publisher(ROSTwistStamped, "/test_freq", 10) + + # Test different frequencies + test_frequencies = [10, 50, 100] # Hz + + for target_freq in test_frequencies: + receive_times.clear() + send_times = [] + period = 1.0 / target_freq + + # Send messages at target frequency + start_time = time.time() + while time.time() - start_time < 1.0: # Run for 1 second + msg = ROSTwistStamped() + msg.header.stamp = self.test_node.get_clock().now().to_msg() + msg.twist.linear.x = 1.0 + + ros_pub.publish(msg) + send_times.append(time.time()) + time.sleep(period) + + # Allow processing time + time.sleep(0.2) + + # Calculate actual frequencies + if len(send_times) > 1: + send_intervals = np.diff(send_times) + send_freq = 1.0 / np.mean(send_intervals) + else: + send_freq = 0 + + if len(receive_times) > 1: + receive_intervals = np.diff(receive_times) + receive_freq = 1.0 / np.mean(receive_intervals) + else: + receive_freq = 0 + + # Verify frequency preservation (within 10% tolerance) + self.assertAlmostEqual( + receive_freq, + send_freq, + delta=send_freq * 0.1, + msg=f"Frequency not preserved for {target_freq}Hz: sent={send_freq:.1f}Hz, received={receive_freq:.1f}Hz", + ) + + def test_pointcloud_conversion(self): + """Test PointCloud2 message conversion with numpy optimization.""" + # Set up bridge + self.bridge.add_topic( + "/test_cloud", PointCloud2, ROSPointCloud2, BridgeDirection.ROS_TO_DIMOS + ) + + # Subscribe to DIMOS side + lcm = LCM() + lcm.start() + topic = Topic("/test_cloud", PointCloud2) + + received_cloud = [] + + def dimos_callback(msg, _topic): + received_cloud.append(msg) + + lcm.subscribe(topic, dimos_callback) + + # Create test point cloud + ros_pub = self.test_node.create_publisher(ROSPointCloud2, "/test_cloud", 10) + + # Generate test points + num_points = 1000 + points = np.random.randn(num_points, 3).astype(np.float32) + + # Create ROS PointCloud2 message + msg = ROSPointCloud2() + msg.header.stamp = self.test_node.get_clock().now().to_msg() + msg.header.frame_id = "test_frame" + msg.height = 1 + msg.width = num_points + msg.fields = [ + PointField(name="x", offset=0, datatype=PointField.FLOAT32, count=1), + PointField(name="y", offset=4, datatype=PointField.FLOAT32, count=1), + PointField(name="z", offset=8, datatype=PointField.FLOAT32, count=1), + ] + msg.is_bigendian = False + msg.point_step = 12 + msg.row_step = msg.point_step * msg.width + msg.data = points.tobytes() + msg.is_dense = True + + # Send point cloud + ros_pub.publish(msg) + + # Allow processing time + time.sleep(0.5) + + # Verify reception + self.assertEqual(len(received_cloud), 1, "Should receive point cloud") + + # Verify point data + received_points = received_cloud[0].as_numpy() + self.assertEqual(received_points.shape, points.shape) + np.testing.assert_array_almost_equal(received_points, points, decimal=5) + + def test_tf_high_frequency(self): + """Test TF message handling at high frequency.""" + # Set up bridge + self.bridge.add_topic("/test_tf", TFMessage, ROSTFMessage, BridgeDirection.ROS_TO_DIMOS) + + # Subscribe to DIMOS side + lcm = LCM() + lcm.start() + topic = Topic("/test_tf", TFMessage) + + received_tfs = [] + receive_times = [] + + def dimos_callback(msg, _topic): + received_tfs.append(msg) + receive_times.append(time.time()) + + lcm.subscribe(topic, dimos_callback) + + # Publish TF at high frequency (200Hz) + ros_pub = self.test_node.create_publisher(ROSTFMessage, "/test_tf", 100) + + target_freq = 200 # Hz + period = 1.0 / target_freq + num_messages = 200 # 1 second worth + + send_times = [] + for i in range(num_messages): + msg = ROSTFMessage() + transform = TransformStamped() + transform.header.stamp = self.test_node.get_clock().now().to_msg() + transform.header.frame_id = "world" + transform.child_frame_id = f"link_{i}" + transform.transform.translation.x = float(i) + transform.transform.rotation.w = 1.0 + msg.transforms = [transform] + + ros_pub.publish(msg) + send_times.append(time.time()) + time.sleep(period) + + # Allow processing time + time.sleep(0.5) + + # Check message count (allow 5% loss tolerance) + min_expected = int(num_messages * 0.95) + self.assertGreaterEqual( + len(received_tfs), + min_expected, + f"Should receive at least {min_expected} of {num_messages} TF messages", + ) + + # Check frequency preservation + if len(receive_times) > 1: + receive_intervals = np.diff(receive_times) + receive_freq = 1.0 / np.mean(receive_intervals) + + # For high frequency, allow 20% tolerance + self.assertAlmostEqual( + receive_freq, + target_freq, + delta=target_freq * 0.2, + msg=f"High frequency TF not preserved: expected={target_freq}Hz, got={receive_freq:.1f}Hz", + ) + + def test_bidirectional_bridge(self): + """Test simultaneous bidirectional message flow.""" + # Set up bidirectional bridges for same topic type + self.bridge.add_topic( + "/ros_to_dimos", TwistStamped, ROSTwistStamped, BridgeDirection.ROS_TO_DIMOS + ) + + self.bridge.add_topic( + "/dimos_to_ros", TwistStamped, ROSTwistStamped, BridgeDirection.DIMOS_TO_ROS + ) + + dimos_received = [] + ros_received = [] + + # DIMOS subscriber - use bridge's LCM + topic_r2d = Topic("/ros_to_dimos", TwistStamped) + self.bridge.lcm.subscribe(topic_r2d, lambda msg, _: dimos_received.append(msg)) + + # ROS subscriber + self.test_node.create_subscription( + ROSTwistStamped, "/dimos_to_ros", lambda msg: ros_received.append(msg), 10 + ) + + # Set up publishers + ros_pub = self.test_node.create_publisher(ROSTwistStamped, "/ros_to_dimos", 10) + topic_d2r = Topic("/dimos_to_ros", TwistStamped) + + # Keep track of whether threads should continue + stop_spinning = threading.Event() + + # Spin the test node in background to receive messages + def spin_test_node(): + while not stop_spinning.is_set(): + rclpy.spin_once(self.test_node, timeout_sec=0.01) + + spin_thread = threading.Thread(target=spin_test_node, daemon=True) + spin_thread.start() + + # Send messages in both directions simultaneously + def send_ros_messages(): + for i in range(50): + msg = ROSTwistStamped() + msg.header.stamp = self.test_node.get_clock().now().to_msg() + msg.twist.linear.x = float(i) + ros_pub.publish(msg) + time.sleep(0.02) # 50Hz + + def send_dimos_messages(): + for i in range(50): + msg = TwistStamped(ts=time.time()) + msg.linear.y = float(i * 2) + self.bridge.lcm.publish(topic_d2r, msg) + time.sleep(0.02) # 50Hz + + # Run both senders in parallel + ros_thread = threading.Thread(target=send_ros_messages) + dimos_thread = threading.Thread(target=send_dimos_messages) + + ros_thread.start() + dimos_thread.start() + + ros_thread.join() + dimos_thread.join() + + # Allow processing time + time.sleep(0.5) + stop_spinning.set() + spin_thread.join(timeout=1.0) + + # Verify both directions worked + self.assertGreaterEqual(len(dimos_received), 45, "Should receive most ROS->DIMOS messages") + self.assertGreaterEqual(len(ros_received), 45, "Should receive most DIMOS->ROS messages") + + # Verify message integrity + for i, msg in enumerate(dimos_received[:45]): + self.assertAlmostEqual(msg.linear.x, float(i), places=5) + + for i, msg in enumerate(ros_received[:45]): + self.assertAlmostEqual(msg.twist.linear.y, float(i * 2), places=5) + + +if __name__ == "__main__": + unittest.main() diff --git a/dimos/robot/unitree_webrtc/connection.py b/dimos/robot/unitree_webrtc/connection.py index faad61833d..d05676d3e4 100644 --- a/dimos/robot/unitree_webrtc/connection.py +++ b/dimos/robot/unitree_webrtc/connection.py @@ -73,6 +73,8 @@ class UnitreeWebRTCConnection: def __init__(self, ip: str, mode: str = "ai"): self.ip = ip self.mode = mode + self.stop_timer = None + self.cmd_vel_timeout = 0.2 self.conn = Go2WebRTCConnection(WebRTCConnectionMethod.LocalSTA, ip=self.ip) self.connect() @@ -127,7 +129,7 @@ def move(self, twist: Twist, duration: float = 0.0) -> bool: async def async_move(): self.conn.datachannel.pub_sub.publish_without_callback( RTC_TOPIC["WIRELESS_CONTROLLER"], - data={"lx": y, "ly": x, "rx": -yaw, "ry": 0}, + data={"lx": -y, "ly": x, "rx": -yaw, "ry": 0}, ) async def async_move_duration(): @@ -139,6 +141,15 @@ async def async_move_duration(): await async_move() await asyncio.sleep(sleep_time) + # Cancel existing timer and start a new one + if self.stop_timer: + self.stop_timer.cancel() + + # Auto-stop after 0.5 seconds if no new commands + self.stop_timer = threading.Timer(self.cmd_vel_timeout, self.stop) + self.stop_timer.daemon = True + self.stop_timer.start() + try: if duration > 0: # Send continuous move commands for the duration @@ -323,10 +334,20 @@ def stop(self) -> bool: Returns: bool: True if stop command was sent successfully """ + # Cancel timer since we're explicitly stopping + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + return self.move(Twist()) def disconnect(self) -> None: """Disconnect from the robot and clean up resources.""" + # Cancel timer + if self.stop_timer: + self.stop_timer.cancel() + self.stop_timer = None + if hasattr(self, "task") and self.task: self.task.cancel() if hasattr(self, "conn"): diff --git a/dimos/robot/unitree_webrtc/unitree_g1.py b/dimos/robot/unitree_webrtc/unitree_g1.py index 2049d41c7c..2963c2cfc6 100644 --- a/dimos/robot/unitree_webrtc/unitree_g1.py +++ b/dimos/robot/unitree_webrtc/unitree_g1.py @@ -14,8 +14,8 @@ # limitations under the License. """ -Unitree G1 humanoid robot with ZED camera integration. -Minimal implementation using WebRTC connection for robot control and ZED for vision. +Unitree G1 humanoid robot. +Minimal implementation using WebRTC connection for robot control. """ import os @@ -25,23 +25,35 @@ from dimos import core from dimos.core import Module, In, Out, rpc -from dimos.hardware.zed_camera import ZEDModule -from dimos.msgs.geometry_msgs import PoseStamped, Transform, Twist, Vector3, Quaternion -from dimos.msgs.sensor_msgs import Image -from dimos_lcm.sensor_msgs import CameraInfo +from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped +from dimos.msgs.nav_msgs.Odometry import Odometry +from dimos.msgs.sensor_msgs import Image, CameraInfo, PointCloud2 +from dimos.msgs.tf2_msgs.TFMessage import TFMessage from dimos.protocol import pubsub from dimos.protocol.pubsub.lcmpubsub import LCM -from dimos.protocol.tf import TF from dimos.robot.foxglove_bridge import FoxgloveBridge +from dimos.web.websocket_vis.websocket_vis_module import WebsocketVisModule from dimos.robot.unitree_webrtc.connection import UnitreeWebRTCConnection from dimos.robot.unitree_webrtc.unitree_skills import MyUnitreeSkills +from dimos.robot.ros_bridge import ROSBridge, BridgeDirection +from geometry_msgs.msg import TwistStamped as ROSTwistStamped +from nav_msgs.msg import Odometry as ROSOdometry +from sensor_msgs.msg import PointCloud2 as ROSPointCloud2 +from tf2_msgs.msg import TFMessage as ROSTFMessage from dimos.skills.skills import SkillLibrary from dimos.robot.robot import Robot + from dimos.types.robot_capabilities import RobotCapability from dimos.utils.logging_config import setup_logger logger = setup_logger("dimos.robot.unitree_webrtc.unitree_g1", level=logging.INFO) +# try: +# from dimos.hardware.zed_camera import ZEDModule +# except ImportError: +# logger.warning("ZEDModule not found. Please install pyzed to use ZED camera functionality.") +# ZEDModule = None + # Suppress verbose loggers logging.getLogger("aiortc.codecs.h264").setLevel(logging.ERROR) logging.getLogger("lcm_foxglove_bridge").setLevel(logging.ERROR) @@ -51,19 +63,18 @@ class G1ConnectionModule(Module): - """Simplified connection module for G1 - uses WebRTC for control, no video.""" + """Simplified connection module for G1 - uses WebRTC for control.""" + + movecmd: In[TwistStamped] = None + odom_in: In[Odometry] = None - movecmd: In[Twist] = None - odom: Out[PoseStamped] = None + odom_pose: Out[PoseStamped] = None ip: str connection_type: str = "webrtc" - _odom: PoseStamped = None - def __init__(self, ip: str = None, connection_type: str = "webrtc", *args, **kwargs): self.ip = ip self.connection_type = connection_type - self.tf = TF() self.connection = None Module.__init__(self, *args, **kwargs) @@ -72,47 +83,25 @@ def start(self): """Start the connection and subscribe to sensor streams.""" # Use the exact same UnitreeWebRTCConnection as Go2 self.connection = UnitreeWebRTCConnection(self.ip) - - # Subscribe only to odometry (no video/lidar for G1) - self.connection.odom_stream().subscribe(self._publish_tf) self.movecmd.subscribe(self.move) - - def _publish_tf(self, msg): - """Publish odometry and TF transforms.""" - self._odom = msg - self.odom.publish(msg) - self.tf.publish(Transform.from_pose("base_link", msg)) - - # Publish ZED camera transform relative to robot base - zed_transform = Transform( - translation=Vector3(0.0, 0.0, 1.5), # ZED mounted at ~1.5m height on G1 - rotation=Quaternion(0.0, 0.0, 0.0, 1.0), - frame_id="base_link", - child_frame_id="zed_camera", - ts=time.time(), + self.odom_in.subscribe(self._publish_odom_pose) + + def _publish_odom_pose(self, msg: Odometry): + self.odom_pose.publish( + PoseStamped( + ts=msg.ts, + frame_id=msg.frame_id, + position=msg.pose.pose.position, + orientation=msg.pose.orientation, + ) ) - self.tf.publish(zed_transform) - - @rpc - def get_odom(self) -> Optional[PoseStamped]: - """Get the robot's odometry.""" - return self._odom @rpc - def move(self, twist: Twist, duration: float = 0.0): + def move(self, twist_stamped: TwistStamped, duration: float = 0.0): """Send movement command to robot.""" + twist = Twist(linear=twist_stamped.linear, angular=twist_stamped.angular) self.connection.move(twist, duration) - @rpc - def standup(self): - """Make the robot stand up.""" - return self.connection.standup() - - @rpc - def liedown(self): - """Make the robot lie down.""" - return self.connection.liedown() - @rpc def publish_request(self, topic: str, data: dict): """Forward WebRTC publish requests to connection.""" @@ -120,26 +109,34 @@ def publish_request(self, topic: str, data: dict): class UnitreeG1(Robot): - """Unitree G1 humanoid robot with ZED camera for vision.""" + """Unitree G1 humanoid robot.""" def __init__( self, ip: str, output_dir: str = None, + websocket_port: int = 7779, skill_library: Optional[SkillLibrary] = None, recording_path: str = None, replay_path: str = None, enable_joystick: bool = False, + enable_connection: bool = True, + enable_ros_bridge: bool = True, + enable_camera: bool = False, ): """Initialize the G1 robot. Args: ip: Robot IP address output_dir: Directory for saving outputs + websocket_port: Port for web visualization skill_library: Skill library instance recording_path: Path to save recordings (if recording) replay_path: Path to replay recordings from (if replaying) enable_joystick: Enable pygame joystick control + enable_connection: Enable robot connection module + enable_ros_bridge: Enable ROS bridge + enable_camera: Enable ZED camera module """ super().__init__() self.ip = ip @@ -147,6 +144,10 @@ def __init__( self.recording_path = recording_path self.replay_path = replay_path self.enable_joystick = enable_joystick + self.enable_connection = enable_connection + self.enable_ros_bridge = enable_ros_bridge + self.enable_camera = enable_camera + self.websocket_port = websocket_port self.lcm = LCM() # Initialize skill library with G1 robot type @@ -157,14 +158,16 @@ def __init__( self.skill_library = skill_library # Set robot capabilities - self.capabilities = [RobotCapability.LOCOMOTION, RobotCapability.VISION] + self.capabilities = [RobotCapability.LOCOMOTION] # Module references self.dimos = None self.connection = None - self.zed_camera = None + self.websocket_vis = None self.foxglove_bridge = None self.joystick = None + self.ros_bridge = None + self.zed_camera = None self._setup_directories() @@ -175,32 +178,37 @@ def _setup_directories(self): def start(self): """Start the robot system with all modules.""" - self.dimos = core.start( - 3 if self.enable_joystick else 2 - ) # Extra worker for joystick if enabled + self.dimos = core.start(4) # 2 workers for connection and visualization + + if self.enable_connection: + self._deploy_connection() - self._deploy_connection() - self._deploy_camera() self._deploy_visualization() + if self.enable_camera: + self._deploy_camera() + if self.enable_joystick: self._deploy_joystick() + if self.enable_ros_bridge: + self._deploy_ros_bridge() + self._start_modules() self.lcm.start() logger.info("UnitreeG1 initialized and started") - logger.info("ZED camera module deployed for vision") + logger.info(f"WebSocket visualization available at http://localhost:{self.websocket_port}") def _deploy_connection(self): """Deploy and configure the connection module.""" self.connection = self.dimos.deploy(G1ConnectionModule, self.ip) # Configure LCM transports - self.connection.odom.transport = core.LCMTransport("/g1/odom", PoseStamped) - # Use standard /cmd_vel topic for compatibility with joystick and navigation - self.connection.movecmd.transport = core.LCMTransport("/cmd_vel", Twist) + self.connection.movecmd.transport = core.LCMTransport("/cmd_vel", TwistStamped) + self.connection.odom_in.transport = core.LCMTransport("/state_estimation", Odometry) + self.connection.odom_pose.transport = core.LCMTransport("/odom", PoseStamped) def _deploy_camera(self): """Deploy and configure the ZED camera module (real or fake based on replay_path).""" @@ -241,7 +249,14 @@ def _deploy_camera(self): logger.info("ZED camera module configured") def _deploy_visualization(self): - """Deploy visualization tools.""" + """Deploy and configure visualization modules.""" + # Deploy WebSocket visualization module + self.websocket_vis = self.dimos.deploy(WebsocketVisModule, port=self.websocket_port) + self.websocket_vis.movecmd_stamped.transport = core.LCMTransport("/cmd_vel", TwistStamped) + + # Note: robot_pose connection removed since odom was removed from G1ConnectionModule + + # Deploy Foxglove bridge self.foxglove_bridge = FoxgloveBridge() def _deploy_joystick(self): @@ -253,10 +268,39 @@ def _deploy_joystick(self): self.joystick.twist_out.transport = core.LCMTransport("/cmd_vel", Twist) logger.info("Joystick module deployed - pygame window will open") + def _deploy_ros_bridge(self): + """Deploy and configure ROS bridge.""" + self.ros_bridge = ROSBridge("g1_ros_bridge") + + # Add /cmd_vel topic from ROS to DIMOS + self.ros_bridge.add_topic( + "/cmd_vel", TwistStamped, ROSTwistStamped, direction=BridgeDirection.ROS_TO_DIMOS + ) + + # Add /state_estimation topic from ROS to DIMOS + self.ros_bridge.add_topic( + "/state_estimation", Odometry, ROSOdometry, direction=BridgeDirection.ROS_TO_DIMOS + ) + + # Add /tf topic from ROS to DIMOS + self.ros_bridge.add_topic( + "/tf", TFMessage, ROSTFMessage, direction=BridgeDirection.ROS_TO_DIMOS + ) + + # Add /registered_scan topic from ROS to DIMOS + self.ros_bridge.add_topic( + "/registered_scan", PointCloud2, ROSPointCloud2, direction=BridgeDirection.ROS_TO_DIMOS + ) + + logger.info( + "ROS bridge deployed: /cmd_vel, /state_estimation, /tf, /registered_scan (ROS → DIMOS)" + ) + def _start_modules(self): """Start all deployed modules.""" - self.connection.start() - self.zed_camera.start() + if self.connection: + self.connection.start() + self.websocket_vis.start() self.foxglove_bridge.start() if self.joystick: @@ -272,28 +316,35 @@ def _start_modules(self): self.skill_library.init() self.skill_library.initialize_skills() - def get_single_rgb_frame(self, timeout: float = 2.0) -> Image: - """Get a single RGB frame from the ZED camera.""" - from dimos.protocol.pubsub.lcmpubsub import Topic - - topic = Topic("/zed/color_image", Image) - return self.lcm.wait_for_message(topic, timeout=timeout) - - def move(self, twist: Twist, duration: float = 0.0): + def move(self, twist_stamped: TwistStamped, duration: float = 0.0): """Send movement command to robot.""" - self.connection.move(twist, duration) + self.connection.move(twist_stamped, duration) def get_odom(self) -> PoseStamped: """Get the robot's odometry.""" - return self.connection.get_odom() + # Note: odom functionality removed from G1ConnectionModule + return None + + def shutdown(self): + """Shutdown the robot and clean up resources.""" + logger.info("Shutting down UnitreeG1...") - def standup(self): - """Make the robot stand up.""" - return self.connection.standup() + # Shutdown ROS bridge if it exists + if self.ros_bridge is not None: + try: + self.ros_bridge.shutdown() + logger.info("ROS bridge shut down successfully") + except Exception as e: + logger.error(f"Error shutting down ROS bridge: {e}") - def liedown(self): - """Make the robot lie down.""" - return self.connection.liedown() + # Stop other modules if needed + if self.websocket_vis: + try: + self.websocket_vis.stop() + except Exception as e: + logger.error(f"Error stopping websocket vis: {e}") + + logger.info("UnitreeG1 shutdown complete") def main(): @@ -307,16 +358,13 @@ def main(): parser = argparse.ArgumentParser(description="Unitree G1 Humanoid Robot Control") parser.add_argument("--ip", default=os.getenv("ROBOT_IP"), help="Robot IP address") parser.add_argument("--joystick", action="store_true", help="Enable pygame joystick control") + parser.add_argument("--camera", action="store_true", help="Enable ZED camera module") parser.add_argument("--output-dir", help="Output directory for logs/data") parser.add_argument("--record", help="Path to save recording") parser.add_argument("--replay", help="Path to replay recording from") args = parser.parse_args() - if not args.ip: - logger.error("Robot IP not set. Use --ip or set ROBOT_IP environment variable") - return - pubsub.lcm.autoconf() robot = UnitreeG1( @@ -325,6 +373,9 @@ def main(): recording_path=args.record, replay_path=args.replay, enable_joystick=args.joystick, + enable_camera=args.camera, + enable_connection=os.getenv("ROBOT_IP") is not None, + enable_ros_bridge=True, ) robot.start() @@ -346,6 +397,7 @@ def main(): time.sleep(1) except KeyboardInterrupt: logger.info("Shutting down...") + robot.shutdown() if __name__ == "__main__": diff --git a/dimos/web/websocket_vis/websocket_vis_module.py b/dimos/web/websocket_vis/websocket_vis_module.py index 25141a85ec..760edc60e6 100644 --- a/dimos/web/websocket_vis/websocket_vis_module.py +++ b/dimos/web/websocket_vis/websocket_vis_module.py @@ -20,6 +20,7 @@ import asyncio import threading +import time from typing import Any, Dict, Optional import base64 import numpy as np @@ -32,7 +33,7 @@ from dimos.core import Module, In, Out, rpc from dimos_lcm.std_msgs import Bool -from dimos.msgs.geometry_msgs import PoseStamped, Twist, Vector3 +from dimos.msgs.geometry_msgs import PoseStamped, Twist, TwistStamped, Vector3 from dimos.msgs.nav_msgs import OccupancyGrid, Path from dimos.utils.logging_config import setup_logger @@ -68,6 +69,7 @@ class WebsocketVisModule(Module): explore_cmd: Out[Bool] = None stop_explore_cmd: Out[Bool] = None movecmd: Out[Twist] = None + movecmd_stamped: Out[TwistStamped] = None def __init__(self, port: int = 7779, **kwargs): """Initialize the WebSocket visualization module. @@ -111,9 +113,13 @@ def start(self): self.server_thread = threading.Thread(target=self._run_server, daemon=True) self.server_thread.start() - self.robot_pose.subscribe(self._on_robot_pose) - self.path.subscribe(self._on_path) - self.global_costmap.subscribe(self._on_global_costmap) + # Only subscribe to connected topics + if self.robot_pose.connection is not None: + self.robot_pose.subscribe(self._on_robot_pose) + if self.path.connection is not None: + self.path.subscribe(self._on_path) + if self.global_costmap.connection is not None: + self.global_costmap.subscribe(self._on_global_costmap) logger.info(f"WebSocket server started on http://localhost:{self.port}") @@ -167,11 +173,27 @@ async def stop_explore(sid): @self.sio.event async def move_command(sid, data): - twist = Twist( - linear=Vector3(data["linear"]["x"], data["linear"]["y"], data["linear"]["z"]), - angular=Vector3(data["angular"]["x"], data["angular"]["y"], data["angular"]["z"]), - ) - self.movecmd.publish(twist) + # Publish Twist if transport is configured + if self.movecmd and self.movecmd.transport: + twist = Twist( + linear=Vector3(data["linear"]["x"], data["linear"]["y"], data["linear"]["z"]), + angular=Vector3( + data["angular"]["x"], data["angular"]["y"], data["angular"]["z"] + ), + ) + self.movecmd.publish(twist) + + # Publish TwistStamped if transport is configured + if self.movecmd_stamped and self.movecmd_stamped.transport: + twist_stamped = TwistStamped( + ts=time.time(), + frame_id="base_link", + linear=Vector3(data["linear"]["x"], data["linear"]["y"], data["linear"]["z"]), + angular=Vector3( + data["angular"]["x"], data["angular"]["y"], data["angular"]["z"] + ), + ) + self.movecmd_stamped.publish(twist_stamped) def _run_server(self): uvicorn.run(