diff --git a/embodichain/lab/sim/objects/robot.py b/embodichain/lab/sim/objects/robot.py index 49c330fd..22332c66 100644 --- a/embodichain/lab/sim/objects/robot.py +++ b/embodichain/lab/sim/objects/robot.py @@ -25,7 +25,7 @@ from embodichain.lab.sim.solvers import SolverCfg, BaseSolver from embodichain.lab.sim.objects import Articulation from embodichain.lab.sim.utility.tensor import to_tensor -from embodichain.utils.math import quat_from_matrix +from embodichain.utils.math import quat_from_matrix, matrix_from_quat from embodichain.utils.string import ( is_regular_expression, resolve_matching_names_values, @@ -561,22 +561,14 @@ def compute_ik( if pose.shape[-1] == 7 and pose.dim() == 2: # Convert pose from (batch, 7) to (batch, 4, 4) - pose = torch.cat( - ( - pose[:, :3].unsqueeze(-1), # Position - quat_from_matrix(pose[:, 3:]).unsqueeze(-1), # Quaternion - ), - dim=-1, - ) - pose = torch.cat( - ( - pose, - torch.tensor([[0, 0, 0, 1]], device=pose.device).expand( - pose.shape[0], -1, -1 - ), - ), - dim=1, - ) + pos = pose[:, :3] + quat = pose[:, 3:] + # Convert quaternion to rotation matrix + rot = matrix_from_quat(quat) + # Build homogeneous transformation matrix efficiently + pose = torch.eye(4, device=pose.device).repeat(pose.shape[0], 1, 1) + pose[:, :3, :3] = rot + pose[:, :3, 3] = pos base_pose = self.get_link_pose( link_name=solver.root_link_name, env_ids=local_env_ids, to_matrix=True @@ -721,23 +713,17 @@ def compute_batch_ik( if pose.shape[-1] == 7 and pose.dim() == 3: # Convert pose from (n_envs, n_batch, 7) to (n_envs * n_batch, 4, 4) - pose_batch = torch.reshape(-1, 7) - pose_batch = torch.cat( - ( - pose_batch[:, :3].unsqueeze(-1), # Position - quat_from_matrix(pose_batch[:, 3:]).unsqueeze(-1), # Quaternion - ), - dim=-1, - ) - pose_batch = torch.cat( - ( - pose_batch, - torch.tensor([[0, 0, 0, 1]], device=pose_batch.device).expand( - pose_batch.shape[0], -1, -1 - ), - ), - dim=1, + pose_batch = pose.reshape(-1, 7) + pos = pose_batch[:, :3] + quat = pose_batch[:, 3:] + # Convert quaternion to rotation matrix + rot = matrix_from_quat(quat) + # Build homogeneous transformation matrix efficiently + pose_batch = torch.eye(4, device=pose.device).repeat( + pose_batch.shape[0], 1, 1 ) + pose_batch[:, :3, :3] = rot + pose_batch[:, :3, 3] = pos else: # Convert pose from (n_envs, n_batch, 4, 4) to (n_envs * n_batch, 4, 4) pose_batch = pose.reshape(-1, 4, 4) diff --git a/tests/sim/solvers/test_opw_solver.py b/tests/sim/solvers/test_opw_solver.py index 938cb907..ee892884 100644 --- a/tests/sim/solvers/test_opw_solver.py +++ b/tests/sim/solvers/test_opw_solver.py @@ -80,11 +80,22 @@ def test_ik(self, arm_name: str): ) fk_xpos = self.robot.compute_fk(qpos=qpos_fk, name=arm_name, to_matrix=True) + fk_xpos_xyzquat = self.robot.compute_fk( + qpos=qpos_fk, name=arm_name, to_matrix=False + ) res, ik_qpos = self.robot.compute_ik( pose=fk_xpos, joint_seed=qpos_fk, name=arm_name ) + res, ik_qpos_xyzquat = self.robot.compute_ik( + pose=fk_xpos_xyzquat, joint_seed=qpos_fk, name=arm_name + ) + + assert torch.allclose( + ik_qpos, ik_qpos_xyzquat, atol=1e-4, rtol=1e-4 + ), "IK results do not match for different pose formats" + res, ik_qpos = self.robot.compute_ik(pose=fk_xpos, name=arm_name) if ik_qpos.dim() == 3: