diff --git a/README.md b/README.md index 85b1e39..08b8f7d 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,10 @@ ## 🔔 News +- [Oct 26, 2025] Our NeurIPS camera-ready [paper](https://arxiv.org/abs/2506.05282v2) is available! 🎉 + - We include additional experiments on generalizability and a new **anchor-free** model, which aligns more with practical assembly assumptions. + - We release **Version 1.1** to support the anchor-free model; see the [PR](https://github.com/GradientSpaces/Rectified-Point-Flow/pull/25) for more details. + - [Sept 18, 2025] Our paper has been accepted to **NeurIPS 2025 (Spotlight)**; see you in San Diego! - [July 22, 2025] **Version 1.0**: We strongly recommend updating to this version, which includes: @@ -292,6 +296,7 @@ Define parameters for Lightning's [Trainer](https://lightning.ai/docs/pytorch/la **Dataloader workers killed**: Usually this is a signal of insufficient CPU memory or stack. You may try to reduce the `num_workers`. + > [!NOTE] > Please don't hesitate to open an [issue](/issues) if you encounter any problems or bugs! diff --git a/config/RPF_base_main_10k.yaml b/config/RPF_base_main_10k.yaml deleted file mode 100644 index de70d45..0000000 --- a/config/RPF_base_main_10k.yaml +++ /dev/null @@ -1,35 +0,0 @@ -# Training Rectified Point Flow - -defaults: - - model: rectified_point_flow - - data: ikea_partnet_everyday_twobytwo_modelnet_tudl - - trainer: main - - loggers: wandb - - _self_ - -# Random seed for reproducibility -seed: 42 - -# Data root -data_root: "../dataset" -data: - num_points_to_sample: 10000 - -# Experiment name and log directory -experiment_name: RPF_base -log_dir: ./output/${experiment_name} -ckpt_path: ${log_dir}/last.ckpt -hydra: - run: - dir: ${log_dir} - -# Model settings -model: - encoder_ckpt: null - flow_model_ckpt: null - - flow_model: - # For 10k points, we replace QK norm by softcapping for speeding up. 
- attn_dtype: "bfloat16" - softcap: 50.0 - qk_norm: False diff --git a/config/RPF_base_predict_overlap.yaml b/config/RPF_base_predict_overlap.yaml index fe856e3..990a96a 100644 --- a/config/RPF_base_predict_overlap.yaml +++ b/config/RPF_base_predict_overlap.yaml @@ -23,3 +23,4 @@ ckpt_path: null # when null, the checkpoint will be downloaded from Hug # Model settings model: compute_overlap_points: true + build_overlap_head: true \ No newline at end of file diff --git a/config/RPF_base_pretrain.yaml b/config/RPF_base_pretrain.yaml index 433a020..cfafcc6 100644 --- a/config/RPF_base_pretrain.yaml +++ b/config/RPF_base_pretrain.yaml @@ -23,6 +23,7 @@ hydra: model: compute_overlap_points: true + build_overlap_head: true data: limit_val_samples: 1000 diff --git a/config/data/ikea.yaml b/config/data/ikea.yaml index 4bce540..f430f00 100644 --- a/config/data/ikea.yaml +++ b/config/data/ikea.yaml @@ -3,7 +3,8 @@ num_points_to_sample: 5000 min_parts: 2 max_parts: 64 min_points_per_part: 20 -multi_anchor: true +anchor_free: true +multi_anchor: false data_root: ${data_root} dataset_names: ["ikea"] diff --git a/config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl.yaml b/config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl.yaml index 01f52eb..25eb181 100644 --- a/config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl.yaml +++ b/config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl.yaml @@ -3,7 +3,8 @@ num_points_to_sample: 5000 min_parts: 2 max_parts: 64 min_points_per_part: 20 -multi_anchor: true +anchor_free: true +multi_anchor: false data_root: ${data_root} dataset_names: ["ikea", "partnet", "everyday", "twobytwo", "modelnet", "tudl"] diff --git a/config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl_objverse.yaml b/config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl_objverse.yaml index 45e2ba0..4c8fdd3 100644 --- a/config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl_objverse.yaml +++ 
b/config/data/ikea_partnet_everyday_twobytwo_modelnet_tudl_objverse.yaml @@ -3,7 +3,8 @@ num_points_to_sample: 5000 min_parts: 2 max_parts: 64 min_points_per_part: 20 -multi_anchor: true +anchor_free: true +multi_anchor: false data_root: ${data_root} dataset_names: ["ikea", "partnet", "everyday", "twobytwo", "modelnet", "tudl", "objaverse_v1"] diff --git a/config/data/ikea_twobytwo.yaml b/config/data/ikea_twobytwo.yaml index 1af24b8..abcf458 100644 --- a/config/data/ikea_twobytwo.yaml +++ b/config/data/ikea_twobytwo.yaml @@ -3,6 +3,7 @@ num_points_to_sample: 5000 min_parts: 2 max_parts: 64 min_points_per_part: 20 +anchor_free: true multi_anchor: false data_root: ${data_root} diff --git a/config/model/rectified_point_flow.yaml b/config/model/rectified_point_flow.yaml index 209ec87..2bceeed 100644 --- a/config/model/rectified_point_flow.yaml +++ b/config/model/rectified_point_flow.yaml @@ -22,3 +22,4 @@ timestep_sampling: "u_shaped" inference_sampler: "euler" inference_sampling_steps: 50 n_generations: 1 +anchor_free: true \ No newline at end of file diff --git a/rectified_point_flow/data/datamodule.py b/rectified_point_flow/data/datamodule.py index 32609df..570603a 100644 --- a/rectified_point_flow/data/datamodule.py +++ b/rectified_point_flow/data/datamodule.py @@ -33,6 +33,7 @@ def __init__( up_axis: dict[str, str] = {}, min_parts: int = 2, max_parts: int = 64, + anchor_free: bool = True, num_points_to_sample: int = 5000, min_points_per_part: int = 20, min_dataset_size: int = 2000, @@ -51,6 +52,9 @@ def __init__( If not provided, the up axis is assumed to be 'y'. This only affects the visualization. min_parts: Minimum number of parts in a point cloud. max_parts: Maximum number of parts in a point cloud. + anchor_free: Whether to use anchor-free mode. + If True, the anchor part is centered and randomly rotated, like the non-anchor parts (default). 
+ If False, the anchor part is not centered and thus its pose in the CoM frame of the GT point cloud is given (align with GARF). num_points_to_sample: Number of points to sample from each point cloud. min_points_per_part: Minimum number of points per part. min_dataset_size: Minimum number of point clouds in a dataset. @@ -65,6 +69,7 @@ def __init__( self.up_axis = up_axis self.min_parts = min_parts self.max_parts = max_parts + self.anchor_free = anchor_free self.num_points_to_sample = num_points_to_sample self.min_points_per_part = min_points_per_part self.batch_size = batch_size @@ -120,6 +125,7 @@ def setup(self, stage: str): num_points_to_sample=self.num_points_to_sample, min_points_per_part=self.min_points_per_part, min_dataset_size=self.min_dataset_size, + anchor_free=self.anchor_free, random_scale_range=self.random_scale_range, multi_anchor=self.multi_anchor, ) @@ -136,6 +142,7 @@ def setup(self, stage: str): dataset_name=dataset_name, min_parts=self.min_parts, max_parts=self.max_parts, + anchor_free=self.anchor_free, num_points_to_sample=self.num_points_to_sample, min_points_per_part=self.min_points_per_part, limit_val_samples=self.limit_val_samples, @@ -146,6 +153,7 @@ def setup(self, stage: str): logger.info(make_line()) logger.info("Total Train Samples: " + str(self.train_dataset.cumulative_sizes[-1])) logger.info("Total Val Samples: " + str(self.val_dataset.cumulative_sizes[-1])) + logger.info("Anchor-free Mode: " + str(self.anchor_free)) elif stage == "validate": self.val_dataset = ConcatDataset( @@ -157,6 +165,7 @@ def setup(self, stage: str): up_axis=self.up_axis.get(dataset_name, "y"), min_parts=self.min_parts, max_parts=self.max_parts, + anchor_free=self.anchor_free, num_points_to_sample=self.num_points_to_sample, min_points_per_part=self.min_points_per_part, limit_val_samples=self.limit_val_samples, @@ -166,6 +175,7 @@ def setup(self, stage: str): ) logger.info(make_line()) logger.info("Total Val Samples: " + 
str(self.val_dataset.cumulative_sizes[-1])) + logger.info("Anchor-free Mode: " + str(self.anchor_free)) elif stage in ["test", "predict"]: self.test_dataset = [ @@ -176,6 +186,7 @@ def setup(self, stage: str): up_axis=self.up_axis.get(dataset_name, "y"), min_parts=self.min_parts, max_parts=self.max_parts, + anchor_free=self.anchor_free, num_points_to_sample=self.num_points_to_sample, min_points_per_part=self.min_points_per_part, limit_val_samples=self.limit_val_samples, @@ -184,6 +195,7 @@ def setup(self, stage: str): ] logger.info(make_line()) logger.info("Total Test Samples: " + str(sum(len(dataset) for dataset in self.test_dataset))) + logger.info("Anchor-free Mode: " + str(self.anchor_free)) def train_dataloader(self): """Get training dataloader.""" diff --git a/rectified_point_flow/data/dataset.py b/rectified_point_flow/data/dataset.py index ae1384d..ae15fbf 100644 --- a/rectified_point_flow/data/dataset.py +++ b/rectified_point_flow/data/dataset.py @@ -43,6 +43,7 @@ def __init__( up_axis: str = "y", min_parts: int = 2, max_parts: int = 64, + anchor_free: bool = True, num_points_to_sample: int = 5000, min_points_per_part: int = 20, random_scale_range: tuple[float, float] | None = None, @@ -58,6 +59,7 @@ def __init__( self.up_axis = up_axis.lower() self.min_parts = min_parts self.max_parts = max_parts + self.anchor_free = anchor_free self.num_points_to_sample = num_points_to_sample self.min_points_per_part = min_points_per_part self.random_scale_range = random_scale_range @@ -279,6 +281,12 @@ def _transform(self, data: dict) -> dict: pts_gt = np.concatenate(pcs_gt) normals_gt = np.concatenate(pns_gt) + # Use the largest part as the anchor part + anchor = np.zeros(self.max_parts, bool) + anchor_idx = np.argmax(counts) + anchor[anchor_idx] = True + + # Global centering pts_gt, _ = center_pcd(pts_gt) # Rotate point clouds to y-up @@ -290,7 +298,7 @@ def _transform(self, data: dict) -> dict: scale *= np.random.uniform(*self.random_scale_range) pts_gt /= scale - # 
Initial rotation to remove the pose prior (e.g., y-up) during training + # Initial global rotation to remove the pose prior (e.g., y-up) during training if self.split == "train": pts_gt, normals_gt, init_rot = rotate_pcd(pts_gt, normals_gt) else: @@ -302,11 +310,31 @@ def _proc_part(i): """Process one part: center, rotate, and shuffle.""" st, ed = offsets[i], offsets[i+1] - # Center the point cloud - part, trans = center_pcd(pts_gt[st:ed]) - - # Random rotate the point cloud - part, norms, rot = rotate_pcd(part, normals_gt[st:ed]) + # Center and rotate the part. + # In anchor-free mode (default): + # - Center all parts including the anchor part. + # - Additionally randomly rotate the non-anchor parts. Anchor part keeps its orientation from the initial global rotation. + # + # In anchor-fixed mode (align with GARF): + # - Only center and additionally randomly rotate the non-anchor parts. + # * Note: In anchor-fixed mode, the anchor part's pose in the CoM frame of the GT point cloud is given. 
+ + if self.anchor_free: + part, trans = center_pcd(pts_gt[st:ed]) + if i != anchor_idx: + part, norms, rot = rotate_pcd(part, normals_gt[st:ed]) + else: + rot = np.eye(3) + norms = normals_gt[st:ed] + else: + if i != anchor_idx: + part, trans = center_pcd(pts_gt[st:ed]) + part, norms, rot = rotate_pcd(part, normals_gt[st:ed]) + else: + part = pts_gt[st:ed] + trans = np.zeros(3) + rot = np.eye(3) + norms = normals_gt[st:ed] # Random shuffle point order _order = np.random.permutation(len(part)) @@ -314,7 +342,6 @@ def _proc_part(i): normals[st: ed] = norms[_order] pts_gt[st:ed] = pts_gt[st:ed][_order] normals_gt[st:ed] = normals_gt[st:ed][_order] - return rot, trans results = list(self.pool.map(_proc_part, range(n_parts))) @@ -325,35 +352,33 @@ def _proc_part(i): rots = pad_data(np.stack(rots), self.max_parts) trans = pad_data(np.stack(trans), self.max_parts) - # Use the largest part as the anchor part - anchor = np.zeros(self.max_parts, bool) - primary = np.argmax(counts) - anchor[primary] = True - rots[primary] = np.eye(3) - trans[primary] = np.zeros(3) - - # Select extra parts if multi_anchor is enabled - if self.multi_anchor and n_parts > 2 and np.random.rand() > 1 / n_parts: - candidates = counts[:n_parts] > self.num_points_to_sample * 0.05 - candidates[primary] = False - if candidates.any(): - extra_n = np.random.randint( - 1, min(candidates.sum() + 1, n_parts - 1) - ) - extra_idx = np.random.choice( - np.where(candidates)[0], extra_n, replace=False - ) - anchor[extra_idx] = True - rots[extra_idx] = np.eye(3) - trans[extra_idx] = np.zeros(3) - - # Broadcast anchor part to points - anchor_indices = np.zeros(self.num_points_to_sample, bool) + # In anchor-fixed mode (align with GARF), the anchor's motion is fixed. 
+ if not self.anchor_free: + assert np.allclose(rots[anchor_idx], np.eye(3)), f"rots[anchor_idx] is not the identity matrix: {rots[anchor_idx]}" + assert np.allclose(trans[anchor_idx], np.zeros(3)), f"trans[anchor_idx] is not the zero vector: {trans[anchor_idx]}" + + # Select extra parts if multi_anchor is enabled + if self.multi_anchor and n_parts > 2 and np.random.rand() > 1 / n_parts: + candidates = counts[:n_parts] > self.num_points_to_sample * 0.05 + candidates[anchor_idx] = False + if candidates.any(): + extra_n = np.random.randint( + 1, min(candidates.sum() + 1, n_parts - 1) + ) + extra_idx = np.random.choice( + np.where(candidates)[0], extra_n, replace=False + ) + anchor[extra_idx] = True + rots[extra_idx] = np.eye(3) + trans[extra_idx] = np.zeros(3) + + # Broadcast anchor flag to a per-point boolean mask + anchor_mask = np.zeros(self.num_points_to_sample, bool) for i in range(n_parts): if anchor[i]: st, ed = offsets[i], offsets[i + 1] - anchor_indices[st:ed] = True - + anchor_mask[st:ed] = True + results = {} for key in ["index", "name", "overlap_threshold"]: results[key] = data[key] @@ -369,7 +394,7 @@ def _proc_part(i): results["points_per_part"] = pts_per_part.astype(np.int64) results["scales"] = np.array(scale, dtype=np.float32) results["anchor_parts"] = anchor.astype(bool) - results["anchor_indices"] = anchor_indices.astype(bool) + results["anchor_indices"] = anchor_mask.astype(bool) results["init_rotation"] = init_rot.astype(np.float32) return results diff --git a/rectified_point_flow/encoder/point_cloud_encoder.py b/rectified_point_flow/encoder/point_cloud_encoder.py index 36607bc..3b13c9e 100644 --- a/rectified_point_flow/encoder/point_cloud_encoder.py +++ b/rectified_point_flow/encoder/point_cloud_encoder.py @@ -24,6 +24,7 @@ def __init__( grid_size: float = 0.02, overlap_head_intermediate_dim: int = 16, compute_overlap_points: bool = False, + build_overlap_head: bool = True, ): super().__init__() self.pc_feat_dim = pc_feat_dim @@ -33,9 +34,12 @@ 
def __init__( self.grid_size = grid_size self.overlap_head_intermediate_dim = overlap_head_intermediate_dim self.compute_overlap_points = compute_overlap_points - self._build_model() + self.build_overlap_head = build_overlap_head + if build_overlap_head: + self._build_overlap_head() + self._init_weights() - def _build_model(self): + def _build_overlap_head(self): """Build the overlap-aware pretraining model components.""" self.norm = nn.LayerNorm(self.pc_feat_dim) self.overlap_head = nn.Sequential( @@ -110,30 +114,27 @@ def _extract_point_features(self, batch: Dict[str, torch.Tensor]) -> tuple: "grid_size": torch.tensor(self.grid_size).to(part_coords.device), }) point["normal"] = part_normals - features = self.norm(point["feat"]) - + features = point["feat"] return features, point, super_point, n_valid_parts def forward(self, batch: Dict[str, torch.Tensor], batch_idx: Optional[int] = None) -> Dict[str, torch.Tensor]: """Forward pass of the model.""" # Extract features point_features, point_data, super_point_data, _ = self._extract_point_features(batch) - - # Overlap prediction - overlap_logits = self.overlap_head(point_features) - overlap_prob = torch.sigmoid(overlap_logits) - output = { - "overlap_logits": overlap_logits, - "overlap_prob": overlap_prob, - "point": point_data, - "super_point": super_point_data, - } + output = {"point": point_data, "super_point": super_point_data} - # Compute overlap points GT for pretraining stage + # Overlap prediction (optional) + if self.build_overlap_head: + point_features = self.norm(point_features) + overlap_logits = self.overlap_head(point_features) + overlap_prob = torch.sigmoid(overlap_logits) + output["overlap_logits"] = overlap_logits + output["overlap_prob"] = overlap_prob + + # Compute overlap points GT for pretraining stage (optional) if self.compute_overlap_points: with torch.no_grad(): output["overlap_mask"] = self._compute_overlap_points(batch, point_data) - return output def loss(self, predictions: Dict[str, 
torch.Tensor], batch: Dict[str, torch.Tensor]) -> tuple: diff --git a/rectified_point_flow/eval/evaluator.py b/rectified_point_flow/eval/evaluator.py index 0ffe699..04c3eeb 100644 --- a/rectified_point_flow/eval/evaluator.py +++ b/rectified_point_flow/eval/evaluator.py @@ -5,7 +5,7 @@ import torch import lightning as L -from .metrics import compute_object_cd, compute_part_acc, compute_transform_errors +from .metrics import compute_object_cd, compute_part_acc, compute_transform_errors, align_anchor class Evaluator: @@ -34,6 +34,10 @@ def _compute_metrics( pts_gt_rescaled = pts_gt * scales.view(B, 1, 1) pts_pred_rescaled = pointclouds_pred * scales.view(B, 1, 1) + # Align the predicted anchor parts to the ground truth anchor parts using ICP (only used in anchor-free mode) + if self.model.anchor_free: + pts_pred_rescaled = align_anchor(pts_gt_rescaled, pts_pred_rescaled, points_per_part, anchor_parts) + object_cd = compute_object_cd(pts_gt_rescaled, pts_pred_rescaled) part_acc, matched_parts = compute_part_acc(pts_gt_rescaled, pts_pred_rescaled, points_per_part) metrics = { diff --git a/rectified_point_flow/eval/metrics.py b/rectified_point_flow/eval/metrics.py index e30671a..1efe115 100644 --- a/rectified_point_flow/eval/metrics.py +++ b/rectified_point_flow/eval/metrics.py @@ -9,6 +9,49 @@ from ..utils.point_clouds import split_parts +def align_anchor( + pointclouds_gt: torch.Tensor, + pointclouds_pred: torch.Tensor, + points_per_part: torch.Tensor, + anchor_parts: torch.Tensor, +) -> torch.Tensor: + """Align the predicted anchor parts to the ground truth anchor parts using ICP. + + Args: + pointclouds_gt (B, N, 3): Ground truth point clouds. + pointclouds_pred (B, N, 3): Sampled point clouds. + points_per_part (B, P): Number of points in each part. + anchor_parts (B, P): Whether the part is an anchor part; we use the first part with the flag of True as the anchor part. + + Returns: + pointclouds_pred_aligned (B, N, 3): Aligned sampled point clouds. 
+ """ + B, P = anchor_parts.shape + device = pointclouds_pred.device + pointclouds_pred_aligned = pointclouds_pred.clone() + + with torch.amp.autocast(device_type=device.type, dtype=torch.float32): + for b in range(B): + pts_count = 0 + for p in range(P): + if points_per_part[b, p] == 0: + continue + if anchor_parts[b, p]: + ed = pts_count + points_per_part[b, p] + anchor_align_icp = iterative_closest_point(pointclouds_pred[b, pts_count:ed].unsqueeze(0), pointclouds_gt[b, pts_count:ed].unsqueeze(0)).RTs + break + + pts_count = 0 + for p in range(P): + if points_per_part[b, p] == 0: + continue + ed = pts_count + points_per_part[b, p] + pointclouds_pred_aligned[b, pts_count:ed] = pointclouds_pred[b, pts_count:ed] @ anchor_align_icp.R[0].T + anchor_align_icp.T[0] + pts_count = ed + + return pointclouds_pred_aligned + + def compute_object_cd( pointclouds_gt: torch.Tensor, pointclouds_pred: torch.Tensor, diff --git a/rectified_point_flow/flow_model/norm.py b/rectified_point_flow/flow_model/norm.py index 88dc865..03034a1 100644 --- a/rectified_point_flow/flow_model/norm.py +++ b/rectified_point_flow/flow_model/norm.py @@ -28,7 +28,9 @@ def __init__(self, dim: int, heads: int = 1): def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply multi-head RMS normalization.""" - return F.normalize(x, dim=-1) * self.gamma * self.scale + orig = x.dtype + x = F.normalize(x.float(), dim=-1, eps=1e-6) + return (x * self.gamma * self.scale).to(orig) class AdaptiveLayerNorm(nn.Module): @@ -68,4 +70,4 @@ def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor: emb = self.timestep_embedder(self.timestep_proj(timestep)) # (B, dim) emb = self.linear(self.activation(emb)) # (B, dim * 2) scale, shift = emb.unsqueeze(1).chunk(2, dim=-1) # (B, 1, dim) for both - return self.norm(x) * (1 + scale) + shift \ No newline at end of file + return self.norm(x) * (1 + scale) + shift diff --git a/rectified_point_flow/modeling.py b/rectified_point_flow/modeling.py index 
f8675da..75111b2 100644 --- a/rectified_point_flow/modeling.py +++ b/rectified_point_flow/modeling.py @@ -27,8 +27,10 @@ def __init__( lr_scheduler: "partial[torch.optim.lr_scheduler._LRScheduler]" = None, encoder_ckpt: str = None, flow_model_ckpt: str = None, + frozen_encoder: bool = False, + anchor_free: bool = True, loss_type: str = "mse", - timestep_sampling: str = "u-shaped", + timestep_sampling: str = "u_shaped", inference_sampling_steps: int = 20, inference_sampler: str = "euler", n_generations: int = 1, @@ -40,6 +42,8 @@ def __init__( self.flow_model = flow_model self.optimizer = optimizer self.lr_scheduler = lr_scheduler + self.frozen_encoder = frozen_encoder + self.anchor_free = anchor_free self.loss_type = loss_type self.timestep_sampling = timestep_sampling self.inference_sampling_steps = inference_sampling_steps @@ -69,12 +73,19 @@ def __init__( self.meter = MetricsMeter(self) self._freeze_encoder() - def _freeze_encoder(self): - self.feature_extractor.eval() - for module in self.feature_extractor.modules(): - module.eval() - for param in self.feature_extractor.parameters(): - param.requires_grad = False + def _freeze_encoder(self, eval_mode: bool = False): + if self.frozen_encoder or eval_mode: + self.feature_extractor.eval() + for module in self.feature_extractor.modules(): + module.eval() + for param in self.feature_extractor.parameters(): + param.requires_grad = False + else: + self.feature_extractor.train() + for module in self.feature_extractor.modules(): + module.train() + for param in self.feature_extractor.parameters(): + param.requires_grad = True def on_train_epoch_start(self): super().on_train_epoch_start() @@ -82,11 +93,11 @@ def on_train_epoch_start(self): def on_validation_epoch_start(self): super().on_validation_epoch_start() - self._freeze_encoder() + self._freeze_encoder(eval_mode=True) def on_test_epoch_start(self): super().on_test_epoch_start() - self._freeze_encoder() + self._freeze_encoder(eval_mode=True) def _sample_timesteps( 
self, @@ -120,7 +131,7 @@ def _sample_timesteps( def _encode(self, data_dict: dict): """Extract features from input data using FP16.""" - with torch.inference_mode(): + with torch.inference_mode(self.frozen_encoder): with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=True): out_dict = self.feature_extractor(data_dict) points = out_dict["point"] @@ -161,9 +172,10 @@ def forward(self, data_dict: dict): x_1 = torch.randn_like(x_0) # (B, N, 3) x_t, v_t = self._compute_flow_target(x_0, x_1, timesteps) # (B, N, 3) each - # Apply anchor part constraints - x_t[anchor_indices] = x_0[anchor_indices] - v_t[anchor_indices] = 0.0 + # Apply anchor part constraints (only used in anchor-fixed mode) + if not self.anchor_free: + x_t[anchor_indices] = x_0[anchor_indices] + v_t[anchor_indices] = 0.0 # Predict velocity field v_pred = self.flow_model( @@ -318,7 +330,7 @@ def _flow_model_fn(x: torch.Tensor, t: float) -> torch.Tensor: flow_model_fn=_flow_model_fn, x_1=x_1, x_0=x_0, - anchor_indices=anchor_indices, + anchor_indices=anchor_indices if not self.anchor_free else None, # None => skip anchor constraints num_steps=self.inference_sampling_steps, return_trajectory=return_tarjectory, ) diff --git a/rectified_point_flow/sampler.py b/rectified_point_flow/sampler.py index cffb3ae..7caa1c1 100644 --- a/rectified_point_flow/sampler.py +++ b/rectified_point_flow/sampler.py @@ -4,6 +4,16 @@ from typing import Callable from functools import partial + +# Anchor helper + +def _reset_anchor(x_t: torch.Tensor, x_0: torch.Tensor, anchor_indices: torch.Tensor | None = None) -> torch.Tensor: + """Reset anchor parts to the ground truth anchor parts if anchor_indices is not None.""" + if anchor_indices is not None: + x_t[anchor_indices] = x_0[anchor_indices] + return x_t + + # Base sampler def flow_sampler( @@ -11,7 +21,7 @@ def flow_sampler( flow_model_fn: Callable, x_1: torch.Tensor, x_0: torch.Tensor, - anchor_indices: torch.Tensor, + anchor_indices: torch.Tensor | 
None = None, num_steps: int = 20, return_trajectory: bool = False, ) -> torch.Tensor: @@ -22,7 +32,7 @@ def flow_sampler( flow_model_fn: Partial flow model function that takes (x, timesteps) and returns velocity. x_1: Initial noise (B, N, 3). x_0: Ground truth anchor points (B, N, 3). - anchor_indices: Anchor point indices (B, N). + anchor_indices: Anchor point indices (B, N). If None, no anchor part constraints are applied (used in anchor-free mode). num_steps: Number of integration steps, default 20. return_trajectory: Whether to return full trajectory, default False. @@ -31,7 +41,7 @@ def flow_sampler( """ dt = 1.0 / num_steps x_t = x_1.clone() - x_t[anchor_indices] = x_0[anchor_indices] + x_t = _reset_anchor(x_t, x_0, anchor_indices) if return_trajectory: trajectory = torch.empty((num_steps, *x_1.shape), device=x_1.device) @@ -60,7 +70,7 @@ def euler_step( """Euler integration step.""" v = flow_model_fn(x_t, t) x_t = x_t - dt * v - x_t[anchor_indices] = x_0[anchor_indices] + x_t = _reset_anchor(x_t, x_0, anchor_indices) return x_t def rk2_step( @@ -77,13 +87,13 @@ def rk2_step( # K2 x_mid = x_t - 0.5 * dt * v1 - x_mid[anchor_indices] = x_0[anchor_indices] + x_mid = _reset_anchor(x_mid, x_0, anchor_indices) t_next = max(0, t - 0.5 * dt) v2 = flow_model_fn(x_mid, t_next) # RK2 update x_t = x_t - dt * (v1 + v2) / 2 - x_t[anchor_indices] = x_0[anchor_indices] + x_t = _reset_anchor(x_t, x_0, anchor_indices) return x_t def rk4_step( @@ -100,24 +110,24 @@ def rk4_step( # K2 x_temp = x_t - dt * v1 / 2 - x_temp[anchor_indices] = x_0[anchor_indices] + x_temp = _reset_anchor(x_temp, x_0, anchor_indices) t_half = max(0, t - dt / 2) v2 = flow_model_fn(x_temp, t_half) # K3 x_temp = x_t - dt * v2 / 2 - x_temp[anchor_indices] = x_0[anchor_indices] + x_temp = _reset_anchor(x_temp, x_0, anchor_indices) v3 = flow_model_fn(x_temp, t_half) # K4 x_temp = x_t - dt * v3 - x_temp[anchor_indices] = x_0[anchor_indices] + x_temp = _reset_anchor(x_temp, x_0, anchor_indices) t_next = max(0, 
t - dt) v4 = flow_model_fn(x_temp, t_next) # RK4 update x_t = x_t - dt * (v1 + 2 * v2 + 2 * v3 + v4) / 6 - x_t[anchor_indices] = x_0[anchor_indices] + x_t = _reset_anchor(x_t, x_0, anchor_indices) return x_t @@ -130,7 +140,7 @@ def get_sampler(sampler_name: str): sampler_name: Name of the sampler ('euler', 'rk2', 'rk4') Returns: - Sampler function + Sampler function with input arguments (step_fn, flow_model_fn, x_1, x_0, anchor_indices, num_steps, return_trajectory=False) """ step_fns = { 'euler': euler_step, diff --git a/sample.py b/sample.py index 5f0e05d..0c0fdc0 100644 --- a/sample.py +++ b/sample.py @@ -22,7 +22,7 @@ torch.backends.cudnn.allow_tf32 = True torch.backends.cudnn.benchmark = True -DEFAULT_CKPT_PATH_HF = "RPF_base_full_ep2000.ckpt" +DEFAULT_CKPT_PATH_HF = "RPF_base_full_anchorfree_ep2000.ckpt" def setup(cfg: DictConfig):