-
Notifications
You must be signed in to change notification settings - Fork 8
fix opw solver #152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fix opw solver #152
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,6 +23,45 @@ | |
| from embodichain.lab.sim.robots import CobotMagicCfg | ||
|
|
||
|
|
||
| def grid_sample_qpos_from_limits( | ||
| qpos_limits: torch.Tensor, | ||
| steps_per_joint: int = 4, | ||
| device=None, | ||
| max_samples: int = 4096, | ||
| ) -> torch.Tensor: | ||
| """Generate grid samples for qpos from qpos_limits. | ||
|
|
||
| Args: | ||
| qpos_limits: tensor of shape (1, n, 2) or (n, 2) where each row is [low, high]. | ||
| steps_per_joint: number of values per joint (defaults to 2: low and high). | ||
| device: torch device to place the samples on. | ||
| max_samples: cap the number of returned samples (take first N if grid is larger). | ||
|
|
||
| Returns: | ||
| Tensor of shape (N, n) where N <= max_samples. | ||
| """ | ||
| if device is None: | ||
| device = qpos_limits.device | ||
|
|
||
| limits = qpos_limits.squeeze(0) if qpos_limits.dim() == 3 else qpos_limits | ||
| lows = limits[:, 0].to(device) | ||
| highs = limits[:, 1].to(device) | ||
|
|
||
| # create per-joint linspaces | ||
| grids = [ | ||
| torch.linspace(l.item(), h.item(), steps_per_joint, device=device) | ||
| for l, h in zip(lows, highs) | ||
| ] | ||
|
|
||
| # meshgrid and stack | ||
| mesh = torch.meshgrid(*grids, indexing="ij") | ||
| stacked = torch.stack([m.reshape(-1) for m in mesh], dim=1) | ||
|
|
||
| if stacked.shape[0] > max_samples: | ||
| return stacked[:max_samples] | ||
| return stacked | ||
|
|
||
|
|
||
| # Base test class for OPWSolver | ||
| class BaseSolverTest: | ||
| sim = None # Define as a class attribute | ||
|
|
@@ -75,42 +114,38 @@ def setup_simulation(self, sim_device): | |
| def test_ik(self, arm_name: str): | ||
| # Test inverse kinematics (IK) with a 1x4x4 homogeneous matrix pose and a joint_seed | ||
|
|
||
| test_qpos = torch.tensor( | ||
| [[0.0, np.pi / 4, -np.pi / 4, 0.0, np.pi / 4, 0.0]], | ||
| dtype=torch.float32, | ||
| device=self.robot.device, | ||
| qpos_limit = self.robot.get_qpos_limits(name=arm_name) | ||
| # generate a small grid of qpos samples from the joint limits (low/high) | ||
| sample_qpos = grid_sample_qpos_from_limits( | ||
| qpos_limit, steps_per_joint=8, device=self.robot.device, max_samples=65536 | ||
| ) | ||
| sample_qpos = sample_qpos[None, :, :] | ||
|
|
||
| fk_xpos = self.robot.compute_fk(qpos=test_qpos, name=arm_name, to_matrix=True) | ||
| fk_xpos_xyzquat = self.robot.compute_fk( | ||
| qpos=test_qpos, name=arm_name, to_matrix=False | ||
| fk_xpos = self.robot.compute_batch_fk( | ||
| qpos=sample_qpos, name=arm_name, to_matrix=True | ||
| ) | ||
| fk_xpos_xyzquat = self.robot.compute_batch_fk( | ||
| qpos=sample_qpos, name=arm_name, to_matrix=False | ||
| ) | ||
|
|
||
| res, ik_qpos = self.robot.compute_ik( | ||
| pose=fk_xpos, joint_seed=test_qpos, name=arm_name | ||
| res, ik_qpos = self.robot.compute_batch_ik( | ||
| pose=fk_xpos, joint_seed=sample_qpos, name=arm_name | ||
| ) | ||
|
Comment on lines
+131
to
133
|
||
|
|
||
| res, ik_qpos_xyzquat = self.robot.compute_ik( | ||
| pose=fk_xpos_xyzquat, joint_seed=test_qpos, name=arm_name | ||
| res, ik_qpos_xyzquat = self.robot.compute_batch_ik( | ||
| pose=fk_xpos_xyzquat, joint_seed=sample_qpos, name=arm_name | ||
| ) | ||
|
|
||
| assert torch.allclose( | ||
| ik_qpos, ik_qpos_xyzquat, atol=1e-4, rtol=1e-4 | ||
| ), "IK results do not match for different pose formats" | ||
|
|
||
| res, ik_qpos = self.robot.compute_ik(pose=fk_xpos, name=arm_name) | ||
|
|
||
| if ik_qpos_xyzquat.dim() == 3: | ||
| ik_xpos = self.robot.compute_fk( | ||
| qpos=ik_qpos_xyzquat[0][0], name=arm_name, to_matrix=True | ||
| ) | ||
| else: | ||
| ik_xpos = self.robot.compute_fk( | ||
| qpos=ik_qpos_xyzquat, name=arm_name, to_matrix=True | ||
| ) | ||
| ik_xpos = self.robot.compute_batch_fk( | ||
| qpos=ik_qpos_xyzquat, name=arm_name, to_matrix=True | ||
| ) | ||
|
|
||
| assert torch.allclose( | ||
| test_qpos, ik_qpos, atol=5e-3, rtol=5e-3 | ||
| sample_qpos, ik_qpos, atol=5e-3, rtol=5e-3 | ||
| ), f"FK and IK qpos do not match for {arm_name}" | ||
|
|
||
| assert torch.allclose( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This test generates up to 65,536 IK/FK samples per arm (
steps_per_joint=8on a 6-DOF grid, then capped). That is a large workload for a unit test and can significantly slow CI or lead to timeouts. Consider loweringsteps_per_joint/max_samples, or sampling a smaller randomized subset that still covers the edge cases you care about.