diff --git a/embodichain/lab/sim/utility/action_utils.py b/embodichain/lab/sim/utility/action_utils.py index feb7311f..060201fa 100644 --- a/embodichain/lab/sim/utility/action_utils.py +++ b/embodichain/lab/sim/utility/action_utils.py @@ -245,7 +245,7 @@ def get_trajectory_object_offset_qpos( return is_success, key_qpos_offset -def interpolate_with_distance_warp( +def interpolate_with_distance( trajectory: torch.Tensor, # expected shape [B, N, M], float or convertible to float interp_num: int, # T device=torch.device("cuda"), @@ -258,7 +258,7 @@ def interpolate_with_distance_warp( Args: trajectory: Torch.Tensor of shape [B, N, M]. interp_num: Target number of samples T. - device: Warp device string ('cpu', 'cuda', 'cuda:0', ...). + device: Torch device string ('cpu', 'cuda', 'cuda:0', ...). dtype: Working dtype (wp.float32 or wp.float64). Defaults to wp.float32. Returns: @@ -335,3 +335,75 @@ def interpolate_with_distance_warp( # wp.synchronize_device(device) interp_trajectory = wp.to_torch(out).view(B, T, M) return interp_trajectory + + +def interpolate_with_nums( + trajectory: torch.Tensor, # expected shape [B, N, M], float or convertible to float + interp_nums: torch.Tensor, # expected shape [N - 1], interp_num in each segment + device=torch.device("cuda"), +) -> torch.Tensor: + """ + Each entry ``interp_nums[i] = k`` controls segment ``i`` between + ``trajectory[:, i, :]`` and ``trajectory[:, i + 1, :]``. For that segment, + ``k`` samples are generated with interpolation factors + ``alpha = 0, 1/k, 2/k, ..., (k-1)/k`` (i.e., including the segment start + and excluding the segment end). The final endpoint + ``trajectory[:, -1, :]`` is appended once at the end of the result, so + intermediate segment endpoints are not duplicated. + + Args: + trajectory: Torch.Tensor of shape [B, N, M]. + interp_nums: Torch.Tensor of shape [N - 1] specifying the number of + samples per segment, including each segment start and excluding + its end. Values must be non-negative; a value of 0 means that + no samples are drawn from that segment (other than the final + overall endpoint that is always appended once). + device: Torch device string ('cpu', 'cuda', 'cuda:0', ...). + + Returns: + Torch.Tensor of interpolated trajectories. + """ + trajectory = trajectory.to(device) + if not torch.is_floating_point(trajectory): + trajectory = trajectory.float() + + B, N, M = trajectory.shape + if N == 0: + return trajectory.new_empty((B, 0, M)) + + interp_nums_tensor = torch.as_tensor(interp_nums, device="cpu").reshape(-1) + if interp_nums_tensor.numel() != max(N - 1, 0): + raise ValueError("`interp_nums` must have shape (N - 1,).") + + if N == 1: + return trajectory[:, :1, :] + + interp_nums_list = interp_nums_tensor.to(torch.int64).tolist() + + # Always seed the output with the first waypoint so it is never dropped, + # even when leading segments have zero samples. + segments = [trajectory[:, :1, :]] + for i, count in enumerate(interp_nums_list): + if count < 0: + raise ValueError("`interp_nums` values must be non-negative.") + p0 = trajectory[:, i : i + 1, :] + p1 = trajectory[:, i + 1 : i + 2, :] + if count == 0: + # No interpolated samples for this segment, but ensure the endpoint + # waypoint is still present so zero-sample segments don't remove it. + segments.append(p1) + continue + # Generate linearly spaced interpolation parameters from 0 to 1 + # (inclusive), then drop the first value (t = 0) because p0 is + # already the last point in `segments`. This appends exactly + # `count` new points per segment and preserves all endpoints. + alpha = torch.linspace( + 0.0, + 1.0, + steps=count + 1, + device=device, + dtype=trajectory.dtype, + ).view(1, count + 1, 1) + seg = p0 + (p1 - p0) * alpha + segments.append(seg[:, 1:, :]) + return torch.cat(segments, dim=1) diff --git a/examples/sim/demo/grasp_cup_to_caffe.py b/examples/sim/demo/grasp_cup_to_caffe.py index 135da5f7..c2c69ab6 100644 --- a/examples/sim/demo/grasp_cup_to_caffe.py +++ b/examples/sim/demo/grasp_cup_to_caffe.py @@ -34,7 +34,7 @@ RigidBodyAttributesCfg, ArticulationCfg, ) -from embodichain.lab.sim.utility.action_utils import interpolate_with_distance_warp +from embodichain.lab.sim.utility.action_utils import interpolate_with_distance from embodichain.lab.sim.shapes import MeshCfg from embodichain.data import get_data_path from embodichain.utils import logger @@ -374,7 +374,7 @@ def create_trajectory( ) all_trajectory = torch.cat([arm_trajectory, hand_trajectory], dim=-1) # trajetory with shape [n_envs, n_waypoint, dof] - interp_trajectory = interpolate_with_distance_warp( + interp_trajectory = interpolate_with_distance( trajectory=all_trajectory, interp_num=150, device=sim.device ) return interp_trajectory diff --git a/examples/sim/demo/press_softbody.py b/examples/sim/demo/press_softbody.py index 98c85705..12883116 100644 --- a/examples/sim/demo/press_softbody.py +++ b/examples/sim/demo/press_softbody.py @@ -28,7 +28,7 @@ from embodichain.lab.sim import SimulationManager, SimulationManagerCfg from embodichain.lab.sim.objects import Robot, SoftObject -from embodichain.lab.sim.utility.action_utils import interpolate_with_distance_warp +from embodichain.lab.sim.utility.action_utils import interpolate_with_distance from embodichain.lab.sim.shapes import MeshCfg from embodichain.lab.sim.solvers import PytorchSolverCfg from embodichain.data import get_data_path @@ -181,7 +181,7 @@ def press_cow(sim: SimulationManager, robot: Robot): ) arm_trajectory = torch.concatenate([arm_start_qpos, approach_qpos]) - interp_trajectory = interpolate_with_distance_warp( + interp_trajectory = interpolate_with_distance( trajectory=arm_trajectory[None, :, :], interp_num=50, device=sim.device ) interp_trajectory = interp_trajectory[0] diff --git a/examples/sim/demo/scoop_ice.py b/examples/sim/demo/scoop_ice.py index d8fbaed8..00e05d77 100644 --- a/examples/sim/demo/scoop_ice.py +++ b/examples/sim/demo/scoop_ice.py @@ -39,7 +39,7 @@ LightCfg, ) from embodichain.lab.sim.material import VisualMaterialCfg -from embodichain.lab.sim.utility.action_utils import interpolate_with_distance_warp +from embodichain.lab.sim.utility.action_utils import interpolate_with_distance from embodichain.lab.sim.shapes import MeshCfg, CubeCfg from embodichain.lab.sim.solvers import PytorchSolverCfg from embodichain.data import get_data_path @@ -515,7 +515,7 @@ def scoop_ice(sim: SimulationManager, robot: Robot, scoop: RigidObject): ) all_trajectory = torch.hstack([arm_trajectory, hand_trajectory]) - interp_trajectory = interpolate_with_distance_warp( + interp_trajectory = interpolate_with_distance( trajectory=all_trajectory[None, :, :], interp_num=200, device=sim.device ) interp_trajectory = interp_trajectory[0]