From 02dc3d4f6b6fcf081d93bb51d6ec4335c53bab85 Mon Sep 17 00:00:00 2001 From: Tao Sun Date: Sun, 26 Oct 2025 15:06:00 -0700 Subject: [PATCH 1/7] update code for anchor-free RPF --- README.md | 4 + rectified_point_flow/data/datamodule.py | 12 +++ rectified_point_flow/data/dataset.py | 93 ++++++++++++------- .../encoder/point_cloud_encoder.py | 33 +++---- rectified_point_flow/eval/evaluator.py | 6 +- rectified_point_flow/eval/metrics.py | 43 +++++++++ rectified_point_flow/flow_model/norm.py | 6 +- rectified_point_flow/modeling.py | 40 +++++--- rectified_point_flow/sampler.py | 32 ++++--- 9 files changed, 191 insertions(+), 78 deletions(-) diff --git a/README.md b/README.md index 85b1e39..fbc2afd 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,10 @@ ## πŸ”” News +- [Oct 26, 2025] Our NeurIPS camera-ready paper is available on [arXiv](https://arxiv.org/abs/2506.05282v2)! πŸŽ‰πŸŽ‰πŸŽ‰ + - We include an additional **anchor-free** version for RPF, which aligns more with practical assembly tasks. We find that the anchor-free RPF can achieve similar performance to the anchor-fixed RPF, and shows better generalization ability. + - The code and checkpoints are updated to support the anchor-free mode. For more details, please see the [PR](https://github.com/GradientSpaces/Rectified-Point-Flow/pull/25). + - [Sept 18, 2025] Our paper has been accepted to **NeurIPS 2025 (Spotlight)**; see you in San Diego! 
- [July 22, 2025] **Version 1.0**: We strongly recommend updating to this version, which includes: diff --git a/rectified_point_flow/data/datamodule.py b/rectified_point_flow/data/datamodule.py index 32609df..570603a 100644 --- a/rectified_point_flow/data/datamodule.py +++ b/rectified_point_flow/data/datamodule.py @@ -33,6 +33,7 @@ def __init__( up_axis: dict[str, str] = {}, min_parts: int = 2, max_parts: int = 64, + anchor_free: bool = True, num_points_to_sample: int = 5000, min_points_per_part: int = 20, min_dataset_size: int = 2000, @@ -51,6 +52,9 @@ def __init__( If not provided, the up axis is assumed to be 'y'. This only affects the visualization. min_parts: Minimum number of parts in a point cloud. max_parts: Maximum number of parts in a point cloud. + anchor_free: Whether to use anchor-free mode. + If True, the anchor part is centered and randomly rotated, like the non-anchor parts (default). + If False, the anchor part is not centered and thus its pose in the CoM frame of the GT point cloud is given (align with GARF). num_points_to_sample: Number of points to sample from each point cloud. min_points_per_part: Minimum number of points per part. min_dataset_size: Minimum number of point clouds in a dataset. 
@@ -65,6 +69,7 @@ def __init__( self.up_axis = up_axis self.min_parts = min_parts self.max_parts = max_parts + self.anchor_free = anchor_free self.num_points_to_sample = num_points_to_sample self.min_points_per_part = min_points_per_part self.batch_size = batch_size @@ -120,6 +125,7 @@ def setup(self, stage: str): num_points_to_sample=self.num_points_to_sample, min_points_per_part=self.min_points_per_part, min_dataset_size=self.min_dataset_size, + anchor_free=self.anchor_free, random_scale_range=self.random_scale_range, multi_anchor=self.multi_anchor, ) @@ -136,6 +142,7 @@ def setup(self, stage: str): dataset_name=dataset_name, min_parts=self.min_parts, max_parts=self.max_parts, + anchor_free=self.anchor_free, num_points_to_sample=self.num_points_to_sample, min_points_per_part=self.min_points_per_part, limit_val_samples=self.limit_val_samples, @@ -146,6 +153,7 @@ def setup(self, stage: str): logger.info(make_line()) logger.info("Total Train Samples: " + str(self.train_dataset.cumulative_sizes[-1])) logger.info("Total Val Samples: " + str(self.val_dataset.cumulative_sizes[-1])) + logger.info("Anchor-free Mode: " + str(self.anchor_free)) elif stage == "validate": self.val_dataset = ConcatDataset( @@ -157,6 +165,7 @@ def setup(self, stage: str): up_axis=self.up_axis.get(dataset_name, "y"), min_parts=self.min_parts, max_parts=self.max_parts, + anchor_free=self.anchor_free, num_points_to_sample=self.num_points_to_sample, min_points_per_part=self.min_points_per_part, limit_val_samples=self.limit_val_samples, @@ -166,6 +175,7 @@ def setup(self, stage: str): ) logger.info(make_line()) logger.info("Total Val Samples: " + str(self.val_dataset.cumulative_sizes[-1])) + logger.info("Anchor-free Mode: " + str(self.anchor_free)) elif stage in ["test", "predict"]: self.test_dataset = [ @@ -176,6 +186,7 @@ def setup(self, stage: str): up_axis=self.up_axis.get(dataset_name, "y"), min_parts=self.min_parts, max_parts=self.max_parts, + anchor_free=self.anchor_free, 
num_points_to_sample=self.num_points_to_sample, min_points_per_part=self.min_points_per_part, limit_val_samples=self.limit_val_samples, @@ -184,6 +195,7 @@ def setup(self, stage: str): ] logger.info(make_line()) logger.info("Total Test Samples: " + str(sum(len(dataset) for dataset in self.test_dataset))) + logger.info("Anchor-free Mode: " + str(self.anchor_free)) def train_dataloader(self): """Get training dataloader.""" diff --git a/rectified_point_flow/data/dataset.py b/rectified_point_flow/data/dataset.py index ae1384d..ae15fbf 100644 --- a/rectified_point_flow/data/dataset.py +++ b/rectified_point_flow/data/dataset.py @@ -43,6 +43,7 @@ def __init__( up_axis: str = "y", min_parts: int = 2, max_parts: int = 64, + anchor_free: bool = True, num_points_to_sample: int = 5000, min_points_per_part: int = 20, random_scale_range: tuple[float, float] | None = None, @@ -58,6 +59,7 @@ def __init__( self.up_axis = up_axis.lower() self.min_parts = min_parts self.max_parts = max_parts + self.anchor_free = anchor_free self.num_points_to_sample = num_points_to_sample self.min_points_per_part = min_points_per_part self.random_scale_range = random_scale_range @@ -279,6 +281,12 @@ def _transform(self, data: dict) -> dict: pts_gt = np.concatenate(pcs_gt) normals_gt = np.concatenate(pns_gt) + # Use the largest part as the anchor part + anchor = np.zeros(self.max_parts, bool) + anchor_idx = np.argmax(counts) + anchor[anchor_idx] = True + + # Global centering pts_gt, _ = center_pcd(pts_gt) # Rotate point clouds to y-up @@ -290,7 +298,7 @@ def _transform(self, data: dict) -> dict: scale *= np.random.uniform(*self.random_scale_range) pts_gt /= scale - # Initial rotation to remove the pose prior (e.g., y-up) during training + # Initial global rotation to remove the pose prior (e.g., y-up) during training if self.split == "train": pts_gt, normals_gt, init_rot = rotate_pcd(pts_gt, normals_gt) else: @@ -302,11 +310,31 @@ def _proc_part(i): """Process one part: center, rotate, and shuffle.""" 
st, ed = offsets[i], offsets[i+1] - # Center the point cloud - part, trans = center_pcd(pts_gt[st:ed]) - - # Random rotate the point cloud - part, norms, rot = rotate_pcd(part, normals_gt[st:ed]) + # Center and rotate the part. + # In anchor-free mode (default): + # - Center all parts including the anchor part. + # - Additionally randomly rotate the non-anchor parts. Anchor part keeps its orientation from the initial global rotation. + # + # In anchor-fixed mode (align with GARF): + # - Only center and additionally randomly rotate the non-anchor parts. + # * Note: In anchor-fixed mode, the anchor part's pose in the CoM frame of the GT point cloud is given. + + if self.anchor_free: + part, trans = center_pcd(pts_gt[st:ed]) + if i != anchor_idx: + part, norms, rot = rotate_pcd(part, normals_gt[st:ed]) + else: + rot = np.eye(3) + norms = normals_gt[st:ed] + else: + if i != anchor_idx: + part, trans = center_pcd(pts_gt[st:ed]) + part, norms, rot = rotate_pcd(part, normals_gt[st:ed]) + else: + part = pts_gt[st:ed] + trans = np.zeros(3) + rot = np.eye(3) + norms = normals_gt[st:ed] # Random shuffle point order _order = np.random.permutation(len(part)) @@ -314,7 +342,6 @@ def _proc_part(i): normals[st: ed] = norms[_order] pts_gt[st:ed] = pts_gt[st:ed][_order] normals_gt[st:ed] = normals_gt[st:ed][_order] - return rot, trans results = list(self.pool.map(_proc_part, range(n_parts))) @@ -325,35 +352,33 @@ def _proc_part(i): rots = pad_data(np.stack(rots), self.max_parts) trans = pad_data(np.stack(trans), self.max_parts) - # Use the largest part as the anchor part - anchor = np.zeros(self.max_parts, bool) - primary = np.argmax(counts) - anchor[primary] = True - rots[primary] = np.eye(3) - trans[primary] = np.zeros(3) - - # Select extra parts if multi_anchor is enabled - if self.multi_anchor and n_parts > 2 and np.random.rand() > 1 / n_parts: - candidates = counts[:n_parts] > self.num_points_to_sample * 0.05 - candidates[primary] = False - if candidates.any(): - extra_n = 
np.random.randint( - 1, min(candidates.sum() + 1, n_parts - 1) - ) - extra_idx = np.random.choice( - np.where(candidates)[0], extra_n, replace=False - ) - anchor[extra_idx] = True - rots[extra_idx] = np.eye(3) - trans[extra_idx] = np.zeros(3) - - # Broadcast anchor part to points - anchor_indices = np.zeros(self.num_points_to_sample, bool) + # In anchor-fixed mode (align with GARF), the anchor's motion is fixed. + if not self.anchor_free: + assert np.allclose(rots[anchor_idx], np.eye(3)), f"rots[anchor_idx] is not the identity matrix: {rots[anchor_idx]}" + assert np.allclose(trans[anchor_idx], np.zeros(3)), f"trans[anchor_idx] is not the zero vector: {trans[anchor_idx]}" + + # Select extra parts if multi_anchor is enabled + if self.multi_anchor and n_parts > 2 and np.random.rand() > 1 / n_parts: + candidates = counts[:n_parts] > self.num_points_to_sample * 0.05 + candidates[anchor_idx] = False + if candidates.any(): + extra_n = np.random.randint( + 1, min(candidates.sum() + 1, n_parts - 1) + ) + extra_idx = np.random.choice( + np.where(candidates)[0], extra_n, replace=False + ) + anchor[extra_idx] = True + rots[extra_idx] = np.eye(3) + trans[extra_idx] = np.zeros(3) + + # Broadcast anchor flag to a per-point boolean mask + anchor_mask = np.zeros(self.num_points_to_sample, bool) for i in range(n_parts): if anchor[i]: st, ed = offsets[i], offsets[i + 1] - anchor_indices[st:ed] = True - + anchor_mask[st:ed] = True + results = {} for key in ["index", "name", "overlap_threshold"]: results[key] = data[key] @@ -369,7 +394,7 @@ def _proc_part(i): results["points_per_part"] = pts_per_part.astype(np.int64) results["scales"] = np.array(scale, dtype=np.float32) results["anchor_parts"] = anchor.astype(bool) - results["anchor_indices"] = anchor_indices.astype(bool) + results["anchor_indices"] = anchor_mask.astype(bool) results["init_rotation"] = init_rot.astype(np.float32) return results diff --git a/rectified_point_flow/encoder/point_cloud_encoder.py 
b/rectified_point_flow/encoder/point_cloud_encoder.py index 36607bc..3b13c9e 100644 --- a/rectified_point_flow/encoder/point_cloud_encoder.py +++ b/rectified_point_flow/encoder/point_cloud_encoder.py @@ -24,6 +24,7 @@ def __init__( grid_size: float = 0.02, overlap_head_intermediate_dim: int = 16, compute_overlap_points: bool = False, + build_overlap_head: bool = True, ): super().__init__() self.pc_feat_dim = pc_feat_dim @@ -33,9 +34,12 @@ def __init__( self.grid_size = grid_size self.overlap_head_intermediate_dim = overlap_head_intermediate_dim self.compute_overlap_points = compute_overlap_points - self._build_model() + self.build_overlap_head = build_overlap_head + if build_overlap_head: + self._build_overlap_head() + self._init_weights() - def _build_model(self): + def _build_overlap_head(self): """Build the overlap-aware pretraining model components.""" self.norm = nn.LayerNorm(self.pc_feat_dim) self.overlap_head = nn.Sequential( @@ -110,30 +114,27 @@ def _extract_point_features(self, batch: Dict[str, torch.Tensor]) -> tuple: "grid_size": torch.tensor(self.grid_size).to(part_coords.device), }) point["normal"] = part_normals - features = self.norm(point["feat"]) - + features = point["feat"] return features, point, super_point, n_valid_partsarts def forward(self, batch: Dict[str, torch.Tensor], batch_idx: Optional[int] = None) -> Dict[str, torch.Tensor]: """Forward pass of the model.""" # Extract features point_features, point_data, super_point_data, _ = self._extract_point_features(batch) - - # Overlap prediction - overlap_logits = self.overlap_head(point_features) - overlap_prob = torch.sigmoid(overlap_logits) - output = { - "overlap_logits": overlap_logits, - "overlap_prob": overlap_prob, - "point": point_data, - "super_point": super_point_data, - } + output = {"point": point_data, "super_point": super_point_data} - # Compute overlap points GT for pretraining stage + # Overlap prediction (optional) + if self.build_overlap_head: + point_features = 
self.norm(point_features) + overlap_logits = self.overlap_head(point_features) + overlap_prob = torch.sigmoid(overlap_logits) + output["overlap_logits"] = overlap_logits + output["overlap_prob"] = overlap_prob + + # Compute overlap points GT for pretraining stage (optional) if self.compute_overlap_points: with torch.no_grad(): output["overlap_mask"] = self._compute_overlap_points(batch, point_data) - return output def loss(self, predictions: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor]) -> tuple: diff --git a/rectified_point_flow/eval/evaluator.py b/rectified_point_flow/eval/evaluator.py index 0ffe699..04c3eeb 100644 --- a/rectified_point_flow/eval/evaluator.py +++ b/rectified_point_flow/eval/evaluator.py @@ -5,7 +5,7 @@ import torch import lightning as L -from .metrics import compute_object_cd, compute_part_acc, compute_transform_errors +from .metrics import compute_object_cd, compute_part_acc, compute_transform_errors, align_anchor class Evaluator: @@ -34,6 +34,10 @@ def _compute_metrics( pts_gt_rescaled = pts_gt * scales.view(B, 1, 1) pts_pred_rescaled = pointclouds_pred * scales.view(B, 1, 1) + # Align the predicted anchor parts to the ground truth anchor parts using ICP (only used in anchor-free mode) + if self.model.anchor_free: + pts_pred_rescaled = align_anchor(pts_gt_rescaled, pts_pred_rescaled, points_per_part, anchor_parts) + object_cd = compute_object_cd(pts_gt_rescaled, pts_pred_rescaled) part_acc, matched_parts = compute_part_acc(pts_gt_rescaled, pts_pred_rescaled, points_per_part) metrics = { diff --git a/rectified_point_flow/eval/metrics.py b/rectified_point_flow/eval/metrics.py index e30671a..1efe115 100644 --- a/rectified_point_flow/eval/metrics.py +++ b/rectified_point_flow/eval/metrics.py @@ -9,6 +9,49 @@ from ..utils.point_clouds import split_parts +def align_anchor( + pointclouds_gt: torch.Tensor, + pointclouds_pred: torch.Tensor, + points_per_part: torch.Tensor, + anchor_parts: torch.Tensor, +) -> torch.Tensor: + """Align the 
predicted anchor parts to the ground truth anchor parts using ICP. + + Args: + pointclouds_gt (B, N, 3): Ground truth point clouds. + pointclouds_pred (B, N, 3): Sampled point clouds. + points_per_part (B, P): Number of points in each part. + anchor_parts (B, P): Whether the part is an anchor part; we use the first part with the flag of True as the anchor part. + + Returns: + pointclouds_pred_aligned (B, N, 3): Aligned sampled point clouds. + """ + B, P = anchor_parts.shape + device = pointclouds_pred.device + pointclouds_pred_aligned = pointclouds_pred.clone() + + with torch.amp.autocast(device_type=device.type, dtype=torch.float32): + for b in range(B): + pts_count = 0 + for p in range(P): + if points_per_part[b, p] == 0: + continue + if anchor_parts[b, p]: + ed = pts_count + points_per_part[b, p] + anchor_align_icp = iterative_closest_point(pointclouds_pred[b, pts_count:ed].unsqueeze(0), pointclouds_gt[b, pts_count:ed].unsqueeze(0)).RTs + break + + pts_count = 0 + for p in range(P): + if points_per_part[b, p] == 0: + continue + ed = pts_count + points_per_part[b, p] + pointclouds_pred_aligned[b, pts_count:ed] = pointclouds_pred[b, pts_count:ed] @ anchor_align_icp.R[0].T + anchor_align_icp.T[0] + pts_count = ed + + return pointclouds_pred_aligned + + def compute_object_cd( pointclouds_gt: torch.Tensor, pointclouds_pred: torch.Tensor, diff --git a/rectified_point_flow/flow_model/norm.py b/rectified_point_flow/flow_model/norm.py index 88dc865..03034a1 100644 --- a/rectified_point_flow/flow_model/norm.py +++ b/rectified_point_flow/flow_model/norm.py @@ -28,7 +28,9 @@ def __init__(self, dim: int, heads: int = 1): def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply multi-head RMS normalization.""" - return F.normalize(x, dim=-1) * self.gamma * self.scale + orig = x.dtype + x = F.normalize(x.float(), dim=-1, eps=1e-6) + return (x * self.gamma * self.scale).to(orig) class AdaptiveLayerNorm(nn.Module): @@ -68,4 +70,4 @@ def forward(self, x: torch.Tensor, 
timestep: torch.Tensor) -> torch.Tensor: emb = self.timestep_embedder(self.timestep_proj(timestep)) # (B, dim) emb = self.linear(self.activation(emb)) # (B, dim * 2) scale, shift = emb.unsqueeze(1).chunk(2, dim=-1) # (B, 1, dim) for both - return self.norm(x) * (1 + scale) + shift \ No newline at end of file + return self.norm(x) * (1 + scale) + shift diff --git a/rectified_point_flow/modeling.py b/rectified_point_flow/modeling.py index f8675da..75111b2 100644 --- a/rectified_point_flow/modeling.py +++ b/rectified_point_flow/modeling.py @@ -27,8 +27,10 @@ def __init__( lr_scheduler: "partial[torch.optim.lr_scheduler._LRScheduler]" = None, encoder_ckpt: str = None, flow_model_ckpt: str = None, + frozen_encoder: bool = False, + anchor_free: bool = True, loss_type: str = "mse", - timestep_sampling: str = "u-shaped", + timestep_sampling: str = "u_shaped", inference_sampling_steps: int = 20, inference_sampler: str = "euler", n_generations: int = 1, @@ -40,6 +42,8 @@ def __init__( self.flow_model = flow_model self.optimizer = optimizer self.lr_scheduler = lr_scheduler + self.frozen_encoder = frozen_encoder + self.anchor_free = anchor_free self.loss_type = loss_type self.timestep_sampling = timestep_sampling self.inference_sampling_steps = inference_sampling_steps @@ -69,12 +73,19 @@ def __init__( self.meter = MetricsMeter(self) self._freeze_encoder() - def _freeze_encoder(self): - self.feature_extractor.eval() - for module in self.feature_extractor.modules(): - module.eval() - for param in self.feature_extractor.parameters(): - param.requires_grad = False + def _freeze_encoder(self, eval_mode: bool = False): + if self.frozen_encoder or eval_mode: + self.feature_extractor.eval() + for module in self.feature_extractor.modules(): + module.eval() + for param in self.feature_extractor.parameters(): + param.requires_grad = False + else: + self.feature_extractor.train() + for module in self.feature_extractor.modules(): + module.train() + for param in 
self.feature_extractor.parameters(): + param.requires_grad = True def on_train_epoch_start(self): super().on_train_epoch_start() @@ -82,11 +93,11 @@ def on_train_epoch_start(self): def on_validation_epoch_start(self): super().on_validation_epoch_start() - self._freeze_encoder() + self._freeze_encoder(eval_mode=True) def on_test_epoch_start(self): super().on_test_epoch_start() - self._freeze_encoder() + self._freeze_encoder(eval_mode=True) def _sample_timesteps( self, @@ -120,7 +131,7 @@ def _sample_timesteps( def _encode(self, data_dict: dict): """Extract features from input data using FP16.""" - with torch.inference_mode(): + with torch.inference_mode(self.frozen_encoder): with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=True): out_dict = self.feature_extractor(data_dict) points = out_dict["point"] @@ -161,9 +172,10 @@ def forward(self, data_dict: dict): x_1 = torch.randn_like(x_0) # (B, N, 3) x_t, v_t = self._compute_flow_target(x_0, x_1, timesteps) # (B, N, 3) each - # Apply anchor part constraints - x_t[anchor_indices] = x_0[anchor_indices] - v_t[anchor_indices] = 0.0 + # Apply anchor part constraints (only used in anchor-fixed mode) + if not self.anchor_free: + x_t[anchor_indices] = x_0[anchor_indices] + v_t[anchor_indices] = 0.0 # Predict velocity field v_pred = self.flow_model( @@ -318,7 +330,7 @@ def _flow_model_fn(x: torch.Tensor, t: float) -> torch.Tensor: flow_model_fn=_flow_model_fn, x_1=x_1, x_0=x_0, - anchor_indices=anchor_indices, + anchor_indices=anchor_indices if not self.anchor_free else None, # None => skip anchor constraints num_steps=self.inference_sampling_steps, return_trajectory=return_tarjectory, ) diff --git a/rectified_point_flow/sampler.py b/rectified_point_flow/sampler.py index cffb3ae..7caa1c1 100644 --- a/rectified_point_flow/sampler.py +++ b/rectified_point_flow/sampler.py @@ -4,6 +4,16 @@ from typing import Callable from functools import partial + +# Anchor helper + +def _reset_anchor(x_t: torch.Tensor, 
x_0: torch.Tensor, anchor_indices: torch.Tensor | None = None) -> torch.Tensor: + """Reset anchor parts to the ground truth anchor parts if anchor_indices is not None.""" + if anchor_indices is not None: + x_t[anchor_indices] = x_0[anchor_indices] + return x_t + + # Base sampler def flow_sampler( @@ -11,7 +21,7 @@ def flow_sampler( flow_model_fn: Callable, x_1: torch.Tensor, x_0: torch.Tensor, - anchor_indices: torch.Tensor, + anchor_indices: torch.Tensor | None = None, num_steps: int = 20, return_trajectory: bool = False, ) -> torch.Tensor: @@ -22,7 +32,7 @@ def flow_sampler( flow_model_fn: Partial flow model function that takes (x, timesteps) and returns velocity. x_1: Initial noise (B, N, 3). x_0: Ground truth anchor points (B, N, 3). - anchor_indices: Anchor point indices (B, N). + anchor_indices: Anchor point indices (B, N). If None, no anchor part constraints are applied (used in anchor-free mode). num_steps: Number of integration steps, default 20. return_trajectory: Whether to return full trajectory, default False. 
@@ -31,7 +41,7 @@ def flow_sampler( """ dt = 1.0 / num_steps x_t = x_1.clone() - x_t[anchor_indices] = x_0[anchor_indices] + x_t = _reset_anchor(x_t, x_0, anchor_indices) if return_trajectory: trajectory = torch.empty((num_steps, *x_1.shape), device=x_1.device) @@ -60,7 +70,7 @@ def euler_step( """Euler integration step.""" v = flow_model_fn(x_t, t) x_t = x_t - dt * v - x_t[anchor_indices] = x_0[anchor_indices] + x_t = _reset_anchor(x_t, x_0, anchor_indices) return x_t def rk2_step( @@ -77,13 +87,13 @@ def rk2_step( # K2 x_mid = x_t - 0.5 * dt * v1 - x_mid[anchor_indices] = x_0[anchor_indices] + x_mid = _reset_anchor(x_mid, x_0, anchor_indices) t_next = max(0, t - 0.5 * dt) v2 = flow_model_fn(x_mid, t_next) # RK2 update x_t = x_t - dt * (v1 + v2) / 2 - x_t[anchor_indices] = x_0[anchor_indices] + x_t = _reset_anchor(x_t, x_0, anchor_indices) return x_t def rk4_step( @@ -100,24 +110,24 @@ def rk4_step( # K2 x_temp = x_t - dt * v1 / 2 - x_temp[anchor_indices] = x_0[anchor_indices] + x_temp = _reset_anchor(x_temp, x_0, anchor_indices) t_half = max(0, t - dt / 2) v2 = flow_model_fn(x_temp, t_half) # K3 x_temp = x_t - dt * v2 / 2 - x_temp[anchor_indices] = x_0[anchor_indices] + x_temp = _reset_anchor(x_temp, x_0, anchor_indices) v3 = flow_model_fn(x_temp, t_half) # K4 x_temp = x_t - dt * v3 - x_temp[anchor_indices] = x_0[anchor_indices] + x_temp = _reset_anchor(x_temp, x_0, anchor_indices) t_next = max(0, t - dt) v4 = flow_model_fn(x_temp, t_next) # RK4 update x_t = x_t - dt * (v1 + 2 * v2 + 2 * v3 + v4) / 6 - x_t[anchor_indices] = x_0[anchor_indices] + x_t = _reset_anchor(x_t, x_0, anchor_indices) return x_t @@ -130,7 +140,7 @@ def get_sampler(sampler_name: str): sampler_name: Name of the sampler ('euler', 'rk2', 'rk4') Returns: - Sampler function + Sampler function with input arguments (step_fn, flow_model_fn, x_1, x_0, anchor_indices, num_steps, return_trajectory=False) """ step_fns = { 'euler': euler_step, From 82c236836d706ea01be51d3e8698c743c37ce5c7 Mon Sep 17 
00:00:00 2001 From: Tao Sun Date: Sun, 26 Oct 2025 16:12:04 -0700 Subject: [PATCH 2/7] update readme --- README.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index fbc2afd..36ff869 100644 --- a/README.md +++ b/README.md @@ -17,9 +17,9 @@ ## πŸ”” News -- [Oct 26, 2025] Our NeurIPS camera-ready paper is available on [arXiv](https://arxiv.org/abs/2506.05282v2)! πŸŽ‰πŸŽ‰πŸŽ‰ - - We include an additional **anchor-free** version for RPF, which aligns more with practical assembly tasks. We find that the anchor-free RPF can achieve similar performance to the anchor-fixed RPF, and shows better generalization ability. - - The code and checkpoints are updated to support the anchor-free mode. For more details, please see the [PR](https://github.com/GradientSpaces/Rectified-Point-Flow/pull/25). +- [Oct 26, 2025] Our NeurIPS camera-ready paper is available on [arXiv](https://arxiv.org/abs/2506.05282v2)! πŸŽ‰ + - We include an **anchor-free** version, which aligns more with practical assembly assumptions. We find that the anchor-free RPF achieves similar performance to the anchor-fixed RPF, and shows better generalization ability. + - The codes and checkpoints are updated for anchor-free mode. For more details, please check the [PR](https://github.com/GradientSpaces/Rectified-Point-Flow/pull/25). - [Sept 18, 2025] Our paper has been accepted to **NeurIPS 2025 (Spotlight)**; see you in San Diego! @@ -296,6 +296,7 @@ Define parameters for Lightning's [Trainer](https://lightning.ai/docs/pytorch/la **Dataloader workers killed**: Usually this is a signal of insufficient CPU memory or stack. You may try to reduce the `num_workers`. + > [!NOTE] > Please don't hesitate to open an [issue](/issues) if you encounter any problems or bugs! 
From 016f8305ee54eed03bdabec026603ee59142a7d6 Mon Sep 17 00:00:00 2001 From: Tao Sun Date: Sun, 26 Oct 2025 16:14:25 -0700 Subject: [PATCH 3/7] update readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 36ff869..ec5e86a 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,8 @@ ## πŸ”” News - [Oct 26, 2025] Our NeurIPS camera-ready paper is available on [arXiv](https://arxiv.org/abs/2506.05282v2)! πŸŽ‰ - - We include an **anchor-free** version, which aligns more with practical assembly assumptions. We find that the anchor-free RPF achieves similar performance to the anchor-fixed RPF, and shows better generalization ability. - - The codes and checkpoints are updated for anchor-free mode. For more details, please check the [PR](https://github.com/GradientSpaces/Rectified-Point-Flow/pull/25). + - We include an **anchor-free** version, which aligns more with practical assembly assumptions. We find that the anchor-free RPF achieves similar performance to the anchor-fixed version, and shows better generalization ability. + - The codes and checkpoints are updated for anchor-free mode. See the [PR](https://github.com/GradientSpaces/Rectified-Point-Flow/pull/25) for more details. - [Sept 18, 2025] Our paper has been accepted to **NeurIPS 2025 (Spotlight)**; see you in San Diego! 
From d48a1aefe1f48ffdfb23da1e08f4e048a5e8fa99 Mon Sep 17 00:00:00 2001 From: Tao Sun Date: Sun, 26 Oct 2025 16:55:08 -0700 Subject: [PATCH 4/7] udpate ckpt and config --- README.md | 39 +++++++++++++++++++ config/RPF_base_main_10k.yaml | 35 ----------------- config/RPF_base_predict_overlap.yaml | 1 + config/RPF_base_pretrain.yaml | 1 + config/data/ikea.yaml | 3 +- ...rtnet_everyday_twobytwo_modelnet_tudl.yaml | 3 +- ...ryday_twobytwo_modelnet_tudl_objverse.yaml | 3 +- config/data/ikea_twobytwo.yaml | 1 + config/model/rectified_point_flow.yaml | 1 + sample.py | 2 +- 10 files changed, 50 insertions(+), 39 deletions(-) delete mode 100644 config/RPF_base_main_10k.yaml diff --git a/README.md b/README.md index ec5e86a..68044c5 100644 --- a/README.md +++ b/README.md @@ -328,3 +328,42 @@ Some codes in this repo are borrowed from open-source projects, including [DiT]( ## License This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details. + + +This PR adds an anchor-free mode for the RPF training and inference. + +In anchor-free mode, the model is not given the anchor part’s pose in the assembled object CoM frame during training or inference, a more practical assumption for real-world assembly tasks. + +Comparison between Anchor-fixed and Anchor-free: + +Anchor Fixed Anchor Free +Training Anchor: Global Rotation. +Non-anchor: Global Rotation + Part Centring + Part Rotation. Anchor: Global Rotation + Part Centring. +Non-anchor: Global Rotation + Part Centring + Part Rotation. +Inference Sampling the flow with anchor's point cloud reset to GT every step. Sampling the flow without anchor resetting. + +### Evaluation in the Anchor-free Mode +Evaluation in the Anchor-free Mode. To keep evaluation comparable between anchor-fixed and anchor-free models, +we perform the alignment steps at evaluation time for anchor-free predictions: + +Align the predicted anchor part to the GT anchor part using ICP. 
+Apply the same rigid transformation to all predicted non-anchor parts. +Evaluate the aligned whole point cloud against ground truth. +Configuration. The anchor-free mode is enabled by model.anchor_free=true and data.anchor_free=true in the config. They are already enabled by default. + +Code Changes: + +Added an anchor_free parameter (default True) to the rectified_point_flow/data/{dataset, datamodule}.py. +Modified the dataset transformation logic to handle anchor parts differently depending on the mode. +Added an align_anchor function to rectified_point_flow/eval/metrics.py and used it in the evaluator to align predicted anchor parts with ground truth via ICP in anchor-free mode. +More details and results of the anchor-free model can be found in our paper (Appendix. B). + + +Q: Why does an anchor-fixed mode leak information of the GT? +A: During preprocessing, the full object is globally centered to its CoM (center of mass) frame and then normalized to unit scale. If the anchor part is not independently re-centered (i.e., Part Centering ), its coordinates implicitly encode the assembled object’s CoM. + +Q: Why don't we apply Part Rotation to the anchor part in anchor-free mode? +A: Anchor-free training already randomly rotates all non-anchor parts, so the anchor does not provide a stable orientation signal. Applying an extra Part Rotation to the anchor is essentially the same as we add this rotation to the Global Rotation, so we omit this redundant step. + +Q: If the anchor pose is removed, why do we still return the `anchor_indices` in anchor-free mode? +A: This is necessary for evaluation alignment: at test time we align the predicted anchor to the GT anchor using ICP and apply the same rigid transform to all predicted non-anchor parts before computing metrics. We don't reset the model's predicted anchor part to the GT in anchor-free mode. 
\ No newline at end of file diff --git a/config/RPF_base_main_10k.yaml b/config/RPF_base_main_10k.yaml deleted file mode 100644 index de70d45..0000000 --- a/config/RPF_base_main_10k.yaml +++ /dev/null @@ -1,35 +0,0 @@ -# Training Rectified Point Flow - -defaults: - - model: rectified_point_flow - - data: ikea_partnet_everyday_twobytwo_modelnet_tudl - - trainer: main - - loggers: wandb - - _self_ - -# Random seed for reproducibility -seed: 42 - -# Data root -data_root: "../dataset" -data: - num_points_to_sample: 10000 - -# Experiment name and log directory -experiment_name: RPF_base -log_dir: ./output/${experiment_name} -ckpt_path: ${log_dir}/last.ckpt -hydra: - run: - dir: ${log_dir} - -# Model settings -model: - encoder_ckpt: null - flow_model_ckpt: null - - flow_model: - # For 10k points, we replace QK norm by softcapping for speeding up. - attn_dtype: "bfloat16" - softcap: 50.0 - qk_norm: False diff --git a/config/RPF_base_predict_overlap.yaml b/config/RPF_base_predict_overlap.yaml index fe856e3..990a96a 100644 --- a/config/RPF_base_predict_overlap.yaml +++ b/config/RPF_base_predict_overlap.yaml @@ -23,3 +23,4 @@ ckpt_path: null # when null, the checkpoint will be downloaded from Hug # Model settings model: compute_overlap_points: true + build_overlap_head: true \ No newline at end of file diff --git a/config/RPF_base_pretrain.yaml b/config/RPF_base_pretrain.yaml index 433a020..cfafcc6 100644 --- a/config/RPF_base_pretrain.yaml +++ b/config/RPF_base_pretrain.yaml @@ -23,6 +23,7 @@ hydra: model: compute_overlap_points: true + build_overlap_head: true data: limit_val_samples: 1000 diff --git a/config/data/ikea.yaml b/config/data/ikea.yaml index 4bce540..f430f00 100644 --- a/config/data/ikea.yaml +++ b/config/data/ikea.yaml @@ -3,7 +3,8 @@ num_points_to_sample: 5000 min_parts: 2 max_parts: 64 min_points_per_part: 20 -multi_anchor: true +anchor_free: true +multi_anchor: false data_root: ${data_root} dataset_names: ["ikea"] diff --git 
a/config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl.yaml b/config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl.yaml index 01f52eb..25eb181 100644 --- a/config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl.yaml +++ b/config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl.yaml @@ -3,7 +3,8 @@ num_points_to_sample: 5000 min_parts: 2 max_parts: 64 min_points_per_part: 20 -multi_anchor: true +anchor_free: true +multi_anchor: false data_root: ${data_root} dataset_names: ["ikea", "partnet", "everyday", "twobytwo", "modelnet", "tudl"] diff --git a/config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl_objverse.yaml b/config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl_objverse.yaml index 45e2ba0..4c8fdd3 100644 --- a/config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl_objverse.yaml +++ b/config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl_objverse.yaml @@ -3,7 +3,8 @@ num_points_to_sample: 5000 min_parts: 2 max_parts: 64 min_points_per_part: 20 -multi_anchor: true +anchor_free: true +multi_anchor: false data_root: ${data_root} dataset_names: ["ikea", "partnet", "everyday", "twobytwo", "modelnet", "tudl", "objaverse_v1"] diff --git a/config/data/ikea_twobytwo.yaml b/config/data/ikea_twobytwo.yaml index 1af24b8..abcf458 100644 --- a/config/data/ikea_twobytwo.yaml +++ b/config/data/ikea_twobytwo.yaml @@ -3,6 +3,7 @@ num_points_to_sample: 5000 min_parts: 2 max_parts: 64 min_points_per_part: 20 +anchor_free: true multi_anchor: false data_root: ${data_root} diff --git a/config/model/rectified_point_flow.yaml b/config/model/rectified_point_flow.yaml index 209ec87..2bceeed 100644 --- a/config/model/rectified_point_flow.yaml +++ b/config/model/rectified_point_flow.yaml @@ -22,3 +22,4 @@ timestep_sampling: "u_shaped" inference_sampler: "euler" inference_sampling_steps: 50 n_generations: 1 +anchor_free: true \ No newline at end of file diff --git a/sample.py b/sample.py index 5f0e05d..0c0fdc0 100644 --- a/sample.py +++ b/sample.py @@ -22,7 +22,7 @@ 
torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.benchmark = True -DEFAULT_CKPT_PATH_HF = "RPF_base_full_ep2000.ckpt" +DEFAULT_CKPT_PATH_HF = "RPF_base_full_anchorfree_ep2000.ckpt" def setup(cfg: DictConfig): From 25cd01c513604851862e5215a6ef97b9d67b19f4 Mon Sep 17 00:00:00 2001 From: Tao Sun Date: Sun, 26 Oct 2025 17:00:35 -0700 Subject: [PATCH 5/7] update readme --- README.md | 57 +++---------------------------------------------------- 1 file changed, 3 insertions(+), 54 deletions(-) diff --git a/README.md b/README.md index 68044c5..d8bfe0b 100644 --- a/README.md +++ b/README.md @@ -17,21 +17,9 @@ ## πŸ”” News -- [Oct 26, 2025] Our NeurIPS camera-ready paper is available on [arXiv](https://arxiv.org/abs/2506.05282v2)! πŸŽ‰ - - We include an **anchor-free** version, which aligns more with practical assembly assumptions. We find that the anchor-free RPF achieves similar performance to the anchor-fixed version, and shows better generalization ability. - - The codes and checkpoints are updated for anchor-free mode. See the [PR](https://github.com/GradientSpaces/Rectified-Point-Flow/pull/25) for more details. - -- [Sept 18, 2025] Our paper has been accepted to **NeurIPS 2025 (Spotlight)**; see you in San Diego! - -- [July 22, 2025] **Version 1.0**: We strongly recommend updating to this version, which includes: - - Improved model speed (9-12% faster) and training stability. - - Fixed bugs in configs, RK2 sampler, and validation. - - Simplified point cloud packing and shaping. - - Checkpoints are compatible with the previous version. - -- [July 9, 2025] **Version 0.1**: Release training codes. - -- [July 1, 2025] Initial release of the model checkpoints and inference codes. +- [Oct 26, 2025] + - Our NeurIPS camera-ready paper is available on [arXiv](https://arxiv.org/abs/2506.05282v2)! πŸŽ‰ It includes an additional **anchor-free** version of RPF, which aligns more with practical assembly assumptions. 
We also provide additional experiments on generalizability to the paper. + - We release **Version 1.1** to support the anchor-free version. See the [PR](https://github.com/GradientSpaces/Rectified-Point-Flow/pull/25) for more details. ## Overview @@ -328,42 +316,3 @@ Some codes in this repo are borrowed from open-source projects, including [DiT]( ## License This project is licensed under the Apache License 2.0 - see the [LICENSE](LICENSE) file for details. - - -This PR adds an anchor-free mode for the RPF training and inference. - -In anchor-free mode, the model is not given the anchor part’s pose in the assembled object CoM frame during training or inference, a more practical assumption for real-world assembly tasks. - -Comparison between Anchor-fixed and Anchor-free: - -Anchor Fixed Anchor Free -Training Anchor: Global Rotation. -Non-anchor: Global Rotation + Part Centring + Part Rotation. Anchor: Global Rotation + Part Centring. -Non-anchor: Global Rotation + Part Centring + Part Rotation. -Inference Sampling the flow with anchor's point cloud reset to GT every step. Sampling the flow without anchor resetting. - -### Evaluation in the Anchor-free Mode -Evaluation in the Anchor-free Mode. To keep evaluation comparable between anchor-fixed and anchor-free models, -we perform the alignment steps at evaluation time for anchor-free predictions: - -Align the predicted anchor part to the GT anchor part using ICP. -Apply the same rigid transformation to all predicted non-anchor parts. -Evaluate the aligned whole point cloud against ground truth. -Configuration. The anchor-free mode is enabled by model.anchor_free=true and data.anchor_free=true in the config. They are already enabled by default. - -Code Changes: - -Added an anchor_free parameter (default True) to the rectified_point_flow/data/{dataset, datamodule}.py. -Modified the dataset transformation logic to handle anchor parts differently depending on the mode. 
-Added an align_anchor function to rectified_point_flow/eval/metrics.py and used it in the evaluator to align predicted anchor parts with ground truth via ICP in anchor-free mode. -More details and results of the anchor-free model can be found in our paper (Appendix. B). - - -Q: Why does an anchor-fixed mode leak information of the GT? -A: During preprocessing, the full object is globally centered to its CoM (center of mass) frame and then normalized to unit scale. If the anchor part is not independently re-centered (i.e., Part Centering ), its coordinates implicitly encode the assembled object’s CoM. - -Q: Why don't we apply Part Rotation to the anchor part in anchor-free mode? -A: Anchor-free training already randomly rotates all non-anchor parts, so the anchor does not provide a stable orientation signal. Applying an extra Part Rotation to the anchor is essentially the same as we add this rotation to the Global Rotation, so we omit this redundant step. - -Q: If the anchor pose is removed, why do we still return the `anchor_indices` in anchor-free mode? -A: This is necessary for evaluation alignment: at test time we align the predicted anchor to the GT anchor using ICP and apply the same rigid transform to all predicted non-anchor parts before computing metrics. We don't reset the model's predicted anchor part to the GT in anchor-free mode. \ No newline at end of file From 2fd4193a3c68de4dc5234dfd3508b4c9af3480cb Mon Sep 17 00:00:00 2001 From: Tao Sun Date: Sun, 26 Oct 2025 17:04:49 -0700 Subject: [PATCH 6/7] update readme --- README.md | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index d8bfe0b..e21073b 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,20 @@ ## πŸ”” News - [Oct 26, 2025] - - Our NeurIPS camera-ready paper is available on [arXiv](https://arxiv.org/abs/2506.05282v2)! πŸŽ‰ It includes an additional **anchor-free** version of RPF, which aligns more with practical assembly assumptions. 
We also provide additional experiments on generalizability to the paper. - - We release **Version 1.1** to support the anchor-free version. See the [PR](https://github.com/GradientSpaces/Rectified-Point-Flow/pull/25) for more details. + - Our NeurIPS camera-ready [paper](https://arxiv.org/abs/2506.05282v2) is available! πŸŽ‰ We include additional experiments on generalizability and a new **anchor-free** model, which aligns more with practical assembly assumptions. + - We release **Version 1.1** to support the anchor-free model; see the [PR](https://github.com/GradientSpaces/Rectified-Point-Flow/pull/25) for more details. + +- [Sept 18, 2025] Our paper has been accepted to **NeurIPS 2025 (Spotlight)**; see you in San Diego! + +- [July 22, 2025] **Version 1.0**: We strongly recommend updating to this version, which includes: + - Improved model speed (9-12% faster) and training stability. + - Fixed bugs in configs, RK2 sampler, and validation. + - Simplified point cloud packing and shaping. + - Checkpoints are compatible with the previous version. + +- [July 9, 2025] **Version 0.1**: Release training codes. + +- [July 1, 2025] Initial release of the model checkpoints and inference codes. ## Overview From 8fe3d15692c04d7de3fce3e85cb1ff060ff37a34 Mon Sep 17 00:00:00 2001 From: Tao Sun Date: Sun, 26 Oct 2025 17:05:23 -0700 Subject: [PATCH 7/7] update readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index e21073b..08b8f7d 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,8 @@ ## πŸ”” News -- [Oct 26, 2025] - - Our NeurIPS camera-ready [paper](https://arxiv.org/abs/2506.05282v2) is available! πŸŽ‰ We include additional experiments on generalizability and a new **anchor-free** model, which aligns more with practical assembly assumptions. +- [Oct 26, 2025] Our NeurIPS camera-ready [paper](https://arxiv.org/abs/2506.05282v2) is available! 
πŸŽ‰ + - We include additional experiments on generalizability and a new **anchor-free** model, which aligns more with practical assembly assumptions. - We release **Version 1.1** to support the anchor-free model; see the [PR](https://github.com/GradientSpaces/Rectified-Point-Flow/pull/25) for more details. - [Sept 18, 2025] Our paper has been accepted to **NeurIPS 2025 (Spotlight)**; see you in San Diego!