Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@


## 🔔 News
- [Oct 26, 2025] Our NeurIPS camera-ready [paper](https://arxiv.org/abs/2506.05282v2) is available! 🎉
- We include additional experiments on generalizability and a new **anchor-free** model, which aligns more with practical assembly assumptions.
- We release **Version 1.1** to support the anchor-free model; see the [PR](https://github.com/GradientSpaces/Rectified-Point-Flow/pull/25) for more details.

- [Sept 18, 2025] Our paper has been accepted to **NeurIPS 2025 (Spotlight)**; see you in San Diego!

- [July 22, 2025] **Version 1.0**: We strongly recommend updating to this version, which includes:
Expand Down Expand Up @@ -292,6 +296,7 @@ Define parameters for Lightning's [Trainer](https://lightning.ai/docs/pytorch/la

**Dataloader workers killed**: Usually this is a signal of insufficient CPU memory or stack. You may try to reduce the `num_workers`.


> [!NOTE]
> Please don't hesitate to open an [issue](/issues) if you encounter any problems or bugs!

Expand Down
35 changes: 0 additions & 35 deletions config/RPF_base_main_10k.yaml

This file was deleted.

1 change: 1 addition & 0 deletions config/RPF_base_predict_overlap.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,4 @@ ckpt_path: null # when null, the checkpoint will be downloaded from Hug
# Model settings
model:
compute_overlap_points: true
build_overlap_head: true
1 change: 1 addition & 0 deletions config/RPF_base_pretrain.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ hydra:

model:
compute_overlap_points: true
build_overlap_head: true

data:
limit_val_samples: 1000
3 changes: 2 additions & 1 deletion config/data/ikea.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ num_points_to_sample: 5000
min_parts: 2
max_parts: 64
min_points_per_part: 20
multi_anchor: true
anchor_free: true
multi_anchor: false

data_root: ${data_root}
dataset_names: ["ikea"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ num_points_to_sample: 5000
min_parts: 2
max_parts: 64
min_points_per_part: 20
multi_anchor: true
anchor_free: true
multi_anchor: false

data_root: ${data_root}
dataset_names: ["ikea", "partnet", "everyday", "twobytwo", "modelnet", "tudl"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ num_points_to_sample: 5000
min_parts: 2
max_parts: 64
min_points_per_part: 20
multi_anchor: true
anchor_free: true
multi_anchor: false

data_root: ${data_root}
dataset_names: ["ikea", "partnet", "everyday", "twobytwo", "modelnet", "tudl", "objaverse_v1"]
Expand Down
1 change: 1 addition & 0 deletions config/data/ikea_twobytwo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ num_points_to_sample: 5000
min_parts: 2
max_parts: 64
min_points_per_part: 20
anchor_free: true
multi_anchor: false

data_root: ${data_root}
Expand Down
1 change: 1 addition & 0 deletions config/model/rectified_point_flow.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ timestep_sampling: "u_shaped"
inference_sampler: "euler"
inference_sampling_steps: 50
n_generations: 1
anchor_free: true
12 changes: 12 additions & 0 deletions rectified_point_flow/data/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def __init__(
up_axis: dict[str, str] = {},
min_parts: int = 2,
max_parts: int = 64,
anchor_free: bool = True,
num_points_to_sample: int = 5000,
min_points_per_part: int = 20,
min_dataset_size: int = 2000,
Expand All @@ -51,6 +52,9 @@ def __init__(
If not provided, the up axis is assumed to be 'y'. This only affects the visualization.
min_parts: Minimum number of parts in a point cloud.
max_parts: Maximum number of parts in a point cloud.
anchor_free: Whether to use anchor-free mode.
If True, the anchor part is centered and randomly rotated, like the non-anchor parts (default).
If False, the anchor part is not centered and thus its pose in the CoM frame of the GT point cloud is given (align with GARF).
num_points_to_sample: Number of points to sample from each point cloud.
min_points_per_part: Minimum number of points per part.
min_dataset_size: Minimum number of point clouds in a dataset.
Expand All @@ -65,6 +69,7 @@ def __init__(
self.up_axis = up_axis
self.min_parts = min_parts
self.max_parts = max_parts
self.anchor_free = anchor_free
self.num_points_to_sample = num_points_to_sample
self.min_points_per_part = min_points_per_part
self.batch_size = batch_size
Expand Down Expand Up @@ -120,6 +125,7 @@ def setup(self, stage: str):
num_points_to_sample=self.num_points_to_sample,
min_points_per_part=self.min_points_per_part,
min_dataset_size=self.min_dataset_size,
anchor_free=self.anchor_free,
random_scale_range=self.random_scale_range,
multi_anchor=self.multi_anchor,
)
Expand All @@ -136,6 +142,7 @@ def setup(self, stage: str):
dataset_name=dataset_name,
min_parts=self.min_parts,
max_parts=self.max_parts,
anchor_free=self.anchor_free,
num_points_to_sample=self.num_points_to_sample,
min_points_per_part=self.min_points_per_part,
limit_val_samples=self.limit_val_samples,
Expand All @@ -146,6 +153,7 @@ def setup(self, stage: str):
logger.info(make_line())
logger.info("Total Train Samples: " + str(self.train_dataset.cumulative_sizes[-1]))
logger.info("Total Val Samples: " + str(self.val_dataset.cumulative_sizes[-1]))
logger.info("Anchor-free Mode: " + str(self.anchor_free))

elif stage == "validate":
self.val_dataset = ConcatDataset(
Expand All @@ -157,6 +165,7 @@ def setup(self, stage: str):
up_axis=self.up_axis.get(dataset_name, "y"),
min_parts=self.min_parts,
max_parts=self.max_parts,
anchor_free=self.anchor_free,
num_points_to_sample=self.num_points_to_sample,
min_points_per_part=self.min_points_per_part,
limit_val_samples=self.limit_val_samples,
Expand All @@ -166,6 +175,7 @@ def setup(self, stage: str):
)
logger.info(make_line())
logger.info("Total Val Samples: " + str(self.val_dataset.cumulative_sizes[-1]))
logger.info("Anchor-free Mode: " + str(self.anchor_free))

elif stage in ["test", "predict"]:
self.test_dataset = [
Expand All @@ -176,6 +186,7 @@ def setup(self, stage: str):
up_axis=self.up_axis.get(dataset_name, "y"),
min_parts=self.min_parts,
max_parts=self.max_parts,
anchor_free=self.anchor_free,
num_points_to_sample=self.num_points_to_sample,
min_points_per_part=self.min_points_per_part,
limit_val_samples=self.limit_val_samples,
Expand All @@ -184,6 +195,7 @@ def setup(self, stage: str):
]
logger.info(make_line())
logger.info("Total Test Samples: " + str(sum(len(dataset) for dataset in self.test_dataset)))
logger.info("Anchor-free Mode: " + str(self.anchor_free))

def train_dataloader(self):
"""Get training dataloader."""
Expand Down
93 changes: 59 additions & 34 deletions rectified_point_flow/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
up_axis: str = "y",
min_parts: int = 2,
max_parts: int = 64,
anchor_free: bool = True,
num_points_to_sample: int = 5000,
min_points_per_part: int = 20,
random_scale_range: tuple[float, float] | None = None,
Expand All @@ -58,6 +59,7 @@ def __init__(
self.up_axis = up_axis.lower()
self.min_parts = min_parts
self.max_parts = max_parts
self.anchor_free = anchor_free
self.num_points_to_sample = num_points_to_sample
self.min_points_per_part = min_points_per_part
self.random_scale_range = random_scale_range
Expand Down Expand Up @@ -279,6 +281,12 @@ def _transform(self, data: dict) -> dict:
pts_gt = np.concatenate(pcs_gt)
normals_gt = np.concatenate(pns_gt)

# Use the largest part as the anchor part
anchor = np.zeros(self.max_parts, bool)
anchor_idx = np.argmax(counts)
anchor[anchor_idx] = True

# Global centering
pts_gt, _ = center_pcd(pts_gt)

# Rotate point clouds to y-up
Expand All @@ -290,7 +298,7 @@ def _transform(self, data: dict) -> dict:
scale *= np.random.uniform(*self.random_scale_range)
pts_gt /= scale

# Initial rotation to remove the pose prior (e.g., y-up) during training
# Initial global rotation to remove the pose prior (e.g., y-up) during training
if self.split == "train":
pts_gt, normals_gt, init_rot = rotate_pcd(pts_gt, normals_gt)
else:
Expand All @@ -302,19 +310,38 @@ def _proc_part(i):
"""Process one part: center, rotate, and shuffle."""
st, ed = offsets[i], offsets[i+1]

# Center the point cloud
part, trans = center_pcd(pts_gt[st:ed])

# Random rotate the point cloud
part, norms, rot = rotate_pcd(part, normals_gt[st:ed])
# Center and rotate the part.
# In anchor-free mode (default):
# - Center all parts including the anchor part.
# - Additionally randomly rotate the non-anchor parts. Anchor part keeps its orientation from the initial global rotation.
#
# In anchor-fixed mode (align with GARF):
# - Only center and additionally randomly rotate the non-anchor parts.
# * Note: In anchor-fixed mode, the anchor part's pose in the CoM frame of the GT point cloud is given.

if self.anchor_free:
part, trans = center_pcd(pts_gt[st:ed])
if i != anchor_idx:
part, norms, rot = rotate_pcd(part, normals_gt[st:ed])
else:
rot = np.eye(3)
norms = normals_gt[st:ed]
else:
if i != anchor_idx:
part, trans = center_pcd(pts_gt[st:ed])
part, norms, rot = rotate_pcd(part, normals_gt[st:ed])
else:
part = pts_gt[st:ed]
trans = np.zeros(3)
rot = np.eye(3)
norms = normals_gt[st:ed]

# Random shuffle point order
_order = np.random.permutation(len(part))
pts[st: ed] = part[_order]
normals[st: ed] = norms[_order]
pts_gt[st:ed] = pts_gt[st:ed][_order]
normals_gt[st:ed] = normals_gt[st:ed][_order]

return rot, trans

results = list(self.pool.map(_proc_part, range(n_parts)))
Expand All @@ -325,35 +352,33 @@ def _proc_part(i):
rots = pad_data(np.stack(rots), self.max_parts)
trans = pad_data(np.stack(trans), self.max_parts)

# Use the largest part as the anchor part
anchor = np.zeros(self.max_parts, bool)
primary = np.argmax(counts)
anchor[primary] = True
rots[primary] = np.eye(3)
trans[primary] = np.zeros(3)

# Select extra parts if multi_anchor is enabled
if self.multi_anchor and n_parts > 2 and np.random.rand() > 1 / n_parts:
candidates = counts[:n_parts] > self.num_points_to_sample * 0.05
candidates[primary] = False
if candidates.any():
extra_n = np.random.randint(
1, min(candidates.sum() + 1, n_parts - 1)
)
extra_idx = np.random.choice(
np.where(candidates)[0], extra_n, replace=False
)
anchor[extra_idx] = True
rots[extra_idx] = np.eye(3)
trans[extra_idx] = np.zeros(3)

# Broadcast anchor part to points
anchor_indices = np.zeros(self.num_points_to_sample, bool)
# In anchor-fixed mode (align with GARF), the anchor's motion is fixed.
if not self.anchor_free:
assert np.allclose(rots[anchor_idx], np.eye(3)), f"rots[anchor_idx] is not the identity matrix: {rots[anchor_idx]}"
assert np.allclose(trans[anchor_idx], np.zeros(3)), f"trans[anchor_idx] is not the zero vector: {trans[anchor_idx]}"

# Select extra parts if multi_anchor is enabled
if self.multi_anchor and n_parts > 2 and np.random.rand() > 1 / n_parts:
candidates = counts[:n_parts] > self.num_points_to_sample * 0.05
candidates[anchor_idx] = False
if candidates.any():
extra_n = np.random.randint(
1, min(candidates.sum() + 1, n_parts - 1)
)
extra_idx = np.random.choice(
np.where(candidates)[0], extra_n, replace=False
)
anchor[extra_idx] = True
rots[extra_idx] = np.eye(3)
trans[extra_idx] = np.zeros(3)

# Broadcast anchor flag to a per-point boolean mask
anchor_mask = np.zeros(self.num_points_to_sample, bool)
for i in range(n_parts):
if anchor[i]:
st, ed = offsets[i], offsets[i + 1]
anchor_indices[st:ed] = True

anchor_mask[st:ed] = True
results = {}
for key in ["index", "name", "overlap_threshold"]:
results[key] = data[key]
Expand All @@ -369,7 +394,7 @@ def _proc_part(i):
results["points_per_part"] = pts_per_part.astype(np.int64)
results["scales"] = np.array(scale, dtype=np.float32)
results["anchor_parts"] = anchor.astype(bool)
results["anchor_indices"] = anchor_indices.astype(bool)
results["anchor_indices"] = anchor_mask.astype(bool)
results["init_rotation"] = init_rot.astype(np.float32)

return results
Expand Down
Loading