From 641dd31781df7ee3620b4f08218399af419b9d88 Mon Sep 17 00:00:00 2001 From: Xuanwo Date: Thu, 26 Mar 2026 01:09:36 +0800 Subject: [PATCH 1/4] feat: support sampling selected fragments --- rust/lance/benches/take.rs | 2 +- rust/lance/src/dataset.rs | 75 +++++++++++-- rust/lance/src/dataset/scanner.rs | 2 +- rust/lance/src/dataset/take.rs | 7 +- rust/lance/src/dataset/tests/dataset_io.rs | 112 +++++++++++++++++++ rust/lance/src/dataset/write/merge_insert.rs | 2 +- rust/lance/src/index/vector/utils.rs | 2 +- 7 files changed, 187 insertions(+), 15 deletions(-) diff --git a/rust/lance/benches/take.rs b/rust/lance/benches/take.rs index 68d9c963ef9..ec078d0f636 100644 --- a/rust/lance/benches/take.rs +++ b/rust/lance/benches/take.rs @@ -376,7 +376,7 @@ fn bench_sample(c: &mut Criterion) { let schema = schema.clone(); let dataset = dataset.clone(); async move { - dataset.sample(sample_size, &schema).await.unwrap(); + dataset.sample(sample_size, &schema, None).await.unwrap(); } }) }, diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 02f4b28e047..404bf6849d1 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -54,7 +54,7 @@ use roaring::RoaringBitmap; use rowids::get_row_id_index; use serde::{Deserialize, Serialize}; use std::borrow::Cow; -use std::collections::{BTreeMap, HashMap, HashSet}; +use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use std::fmt::Debug; use std::ops::Range; use std::pin::Pin; @@ -1466,7 +1466,8 @@ impl Dataset { row_indices: &[u64], column: impl AsRef, ) -> Result> { - let row_addrs = row_offsets_to_row_addresses(self, row_indices).await?; + let fragments = self.get_fragments(); + let row_addrs = row_offsets_to_row_addresses(&fragments, row_indices).await?; blob::take_blobs_by_addresses(self, &row_addrs, column.as_ref()).await } @@ -1484,14 +1485,74 @@ impl Dataset { /// Randomly sample `n` rows from the dataset. 
/// + /// If `fragment_ids` is provided, sampling is limited to rows from those + /// fragments in the current dataset version. + /// /// The returned rows are in row-id order (not random order), which allows /// the underlying take operation to use an efficient sorted code path. - pub async fn sample(&self, n: usize, projection: &Schema) -> Result { + pub async fn sample( + &self, + n: usize, + projection: &Schema, + fragment_ids: Option<&[u32]>, + ) -> Result { use rand::seq::IteratorRandom; - let num_rows = self.count_rows(None).await?; - let mut ids = (0..num_rows as u64).choose_multiple(&mut rand::rng(), n); - ids.sort_unstable(); - self.take(&ids, projection.clone()).await + + match fragment_ids { + None => { + let num_rows = self.count_rows(None).await?; + let mut ids = (0..num_rows as u64).choose_multiple(&mut rand::rng(), n); + ids.sort_unstable(); + self.take(&ids, projection.clone()).await + } + Some(fragment_ids) => { + if fragment_ids.is_empty() { + return Err(Error::invalid_input( + "Dataset::sample does not accept an empty fragment_ids list".to_string(), + )); + } + + let selected_fragment_ids = fragment_ids.iter().copied().collect::>(); + let selected_fragments = self + .get_fragments() + .into_iter() + .filter(|fragment| selected_fragment_ids.contains(&(fragment.id() as u32))) + .collect::>(); + + if selected_fragments.len() != selected_fragment_ids.len() { + let present_fragment_ids = selected_fragments + .iter() + .map(|fragment| fragment.id() as u32) + .collect::>(); + let missing_fragment_ids = selected_fragment_ids + .into_iter() + .filter(|fragment_id| !present_fragment_ids.contains(fragment_id)) + .collect::>(); + return Err(Error::invalid_input(format!( + "Dataset::sample received fragment ids that are not part of the current dataset version: {missing_fragment_ids:?}", + ))); + } + + let num_rows = stream::iter(selected_fragments.iter().cloned()) + .map(|fragment| async move { fragment.count_rows(None).await }) + .buffer_unordered(16) + 
.try_fold(0_u64, |acc, rows| async move { Ok(acc + rows as u64) }) + .await?; + + let mut offsets = (0..num_rows).choose_multiple(&mut rand::rng(), n); + offsets.sort_unstable(); + + let row_addrs = row_offsets_to_row_addresses(&selected_fragments, &offsets).await?; + let dataset = Arc::new(self.clone()); + let projection = Arc::new( + ProjectionRequest::from(projection.clone()) + .into_projection_plan(dataset.clone())?, + ); + TakeBuilder::try_new_from_addresses(dataset, row_addrs, projection)? + .execute() + .await + } + } } /// Delete rows based on a predicate. diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 63bd7884879..8c7590c452f 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -2718,7 +2718,7 @@ impl Scanner { TakeOperation::RowAddrs(addrs) => self.u64s_as_take_input(addrs), TakeOperation::RowOffsets(offsets) => { let mut addrs = - row_offsets_to_row_addresses(self.dataset.as_ref(), &offsets).await?; + row_offsets_to_row_addresses(&self.dataset.get_fragments(), &offsets).await?; addrs.retain(|addr| *addr != RowAddress::TOMBSTONE_ROW); self.u64s_as_take_input(addrs) } diff --git a/rust/lance/src/dataset/take.rs b/rust/lance/src/dataset/take.rs index 73625a171e0..68121410f01 100644 --- a/rust/lance/src/dataset/take.rs +++ b/rust/lance/src/dataset/take.rs @@ -44,11 +44,9 @@ use super::{Dataset, fragment::FileFragment, scanner::DatasetRecordBatchStream}; /// /// If any offsets are beyond the end of the dataset, they will be mapped to a tombstone row address. 
pub(super) async fn row_offsets_to_row_addresses( - dataset: &Dataset, + fragments: &[FileFragment], row_indices: &[u64], ) -> Result> { - let fragments = dataset.get_fragments(); - let mut perm = permutation::sort(row_indices); let sorted_offsets = perm.apply_slice(row_indices); @@ -115,7 +113,8 @@ pub async fn take( } // First, convert the dataset offsets into row addresses - let addrs = row_offsets_to_row_addresses(dataset, offsets).await?; + let fragments = dataset.get_fragments(); + let addrs = row_offsets_to_row_addresses(&fragments, offsets).await?; let builder = TakeBuilder::try_new_from_addresses( Arc::new(dataset.clone()), diff --git a/rust/lance/src/dataset/tests/dataset_io.rs b/rust/lance/src/dataset/tests/dataset_io.rs index e438e0801ea..2c094e7dc5c 100644 --- a/rust/lance/src/dataset/tests/dataset_io.rs +++ b/rust/lance/src/dataset/tests/dataset_io.rs @@ -1437,6 +1437,118 @@ async fn test_fast_count_rows( ); } +#[rstest] +#[tokio::test] +async fn test_sample_with_fragment_ids( + #[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)] + data_storage_version: LanceFileVersion, +) { + let test_uri = TempStrDir::default(); + let data = gen_batch() + .col("i", array::step::()) + .into_reader_rows(RowCount::from(12), BatchCount::from(1)); + let mut dataset = Dataset::write( + data, + &test_uri, + Some(WriteParams { + max_rows_per_file: 4, + max_rows_per_group: 2, + data_storage_version: Some(data_storage_version), + ..Default::default() + }), + ) + .await + .unwrap(); + + dataset.delete("i IN (1, 9)").await.unwrap(); + + let projection = dataset.schema().project(&["i"]).unwrap(); + let sampled = dataset + .sample(8, &projection, Some(&[0, 0, 2])) + .await + .unwrap(); + let sampled_values = sampled + .column_by_name("i") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .values() + .to_vec(); + + assert_eq!(sampled_values, vec![0, 2, 3, 8, 10, 11]); +} + +#[rstest] +#[tokio::test] +async fn test_sample_with_empty_fragment_ids_rejected( + 
#[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)] + data_storage_version: LanceFileVersion, +) { + let test_uri = TempStrDir::default(); + let data = gen_batch() + .col("i", array::step::()) + .into_reader_rows(RowCount::from(8), BatchCount::from(1)); + let dataset = Dataset::write( + data, + &test_uri, + Some(WriteParams { + max_rows_per_file: 4, + max_rows_per_group: 2, + data_storage_version: Some(data_storage_version), + ..Default::default() + }), + ) + .await + .unwrap(); + + let projection = dataset.schema().project(&["i"]).unwrap(); + let err = dataset.sample(1, &projection, Some(&[])).await.unwrap_err(); + + assert!(matches!(err, Error::InvalidInput { .. })); + assert!( + err.to_string() + .contains("does not accept an empty fragment_ids list") + ); +} + +#[rstest] +#[tokio::test] +async fn test_sample_with_unknown_fragment_ids_rejected( + #[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)] + data_storage_version: LanceFileVersion, +) { + let test_uri = TempStrDir::default(); + let data = gen_batch() + .col("i", array::step::()) + .into_reader_rows(RowCount::from(8), BatchCount::from(1)); + let dataset = Dataset::write( + data, + &test_uri, + Some(WriteParams { + max_rows_per_file: 4, + max_rows_per_group: 2, + data_storage_version: Some(data_storage_version), + ..Default::default() + }), + ) + .await + .unwrap(); + + let projection = dataset.schema().project(&["i"]).unwrap(); + let err = dataset + .sample(1, &projection, Some(&[0, 999])) + .await + .unwrap_err(); + + assert!(matches!(err, Error::InvalidInput { .. 
})); + assert!( + err.to_string() + .contains("not part of the current dataset version") + ); + assert!(err.to_string().contains("999")); +} + #[rstest] #[tokio::test] async fn test_bfloat16_roundtrip( diff --git a/rust/lance/src/dataset/write/merge_insert.rs b/rust/lance/src/dataset/write/merge_insert.rs index cfc1e8f0dca..bdbdfda3cda 100644 --- a/rust/lance/src/dataset/write/merge_insert.rs +++ b/rust/lance/src/dataset/write/merge_insert.rs @@ -3317,7 +3317,7 @@ mod tests { // Sample 2048 random indices and then paste on a column of 9999999's let some_indices = ds - .sample(2048, &(&just_index_col).try_into().unwrap()) + .sample(2048, &(&just_index_col).try_into().unwrap(), None) .await .unwrap(); let some_indices = some_indices.column(0).clone(); diff --git a/rust/lance/src/index/vector/utils.rs b/rust/lance/src/index/vector/utils.rs index 83e010dc1a4..b20d659d6f3 100644 --- a/rust/lance/src/index/vector/utils.rs +++ b/rust/lance/src/index/vector/utils.rs @@ -96,7 +96,7 @@ async fn estimate_multivector_vectors_per_row( // Try a few random samples first (fast path). 
let sample_batch_size = std::cmp::min(64, num_rows); for _ in 0..8 { - let batch = dataset.sample(sample_batch_size, &projection).await?; + let batch = dataset.sample(sample_batch_size, &projection, None).await?; let array = get_column_from_batch(&batch, column)?; let list_array = array.as_list::(); for i in 0..list_array.len() { From c5aed6cab4feab76024e292e7c975584a6bb4b66 Mon Sep 17 00:00:00 2001 From: Xuanwo Date: Thu, 26 Mar 2026 02:12:32 +0800 Subject: [PATCH 2/4] feat: expose fragment-scoped vector training to python --- python/python/lance/indices/builder.py | 36 +++++++- .../python/lance/lance/indices/__init__.pyi | 2 + python/python/tests/test_indices.py | 52 ++++++++++++ python/src/indices.rs | 13 ++- rust/lance/src/index/create.rs | 5 +- rust/lance/src/index/vector.rs | 4 +- rust/lance/src/index/vector/builder.rs | 4 +- rust/lance/src/index/vector/ivf.rs | 19 ++++- rust/lance/src/index/vector/pq.rs | 15 +++- rust/lance/src/index/vector/utils.rs | 83 +++++++++++++++++-- 10 files changed, 212 insertions(+), 21 deletions(-) diff --git a/python/python/lance/indices/builder.py b/python/python/lance/indices/builder.py index ca033780a0e..201f70a4a35 100644 --- a/python/python/lance/indices/builder.py +++ b/python/python/lance/indices/builder.py @@ -65,6 +65,7 @@ def train_ivf( accelerator: Optional[Union[str, "torch.Device"]] = None, sample_rate: int = 256, max_iters: int = 50, + fragment_ids: Optional[list[int]] = None, ) -> IvfModel: """ Train IVF centroids for the given vector column. @@ -105,8 +106,10 @@ def train_ivf( some cases, k-means will not converge but will cycle between various possible minima. In these cases we must terminate or run forever. The max_iters parameter defines a cutoff at which we terminate training. + fragment_ids: list[int], optional + If provided, train using only the specified fragments from the dataset. 
""" - num_rows = self.dataset.count_rows() + num_rows = self._count_rows(fragment_ids) num_partitions = self._determine_num_partitions(num_partitions, num_rows) self._verify_ivf_sample_rate(sample_rate, num_partitions, num_rows) distance_type = self._normalize_distance_type(distance_type) @@ -123,9 +126,14 @@ def train_ivf( distance_type, sample_rate, max_iters, + fragment_ids, ) return IvfModel(ivf_centroids, distance_type) else: + if fragment_ids is not None: + raise NotImplementedError( + "fragment_ids is not supported with accelerator IVF training" + ) # Use accelerator to train ivf centroids from lance.vector import train_ivf_centroids_on_accelerator @@ -153,6 +161,7 @@ def train_pq( *, sample_rate: int = 256, max_iters: int = 50, + fragment_ids: Optional[list[int]] = None, ) -> PqModel: """ Train a PQ model for a given column. @@ -183,10 +192,12 @@ def train_pq( This parameter is used in the same way as in the IVF model. max_iters: int This parameter is used in the same way as in the IVF model. + fragment_ids: list[int], optional + If provided, train using only the specified fragments from the dataset. """ from lance.lance import indices - num_rows = self.dataset.count_rows() + num_rows = self._count_rows(fragment_ids) self.dataset.schema.field(self.column[0]).type.list_size num_subvectors = self._normalize_pq_params(num_subvectors, self.dimension) self._verify_pq_sample_rate(num_rows, sample_rate) @@ -200,6 +211,7 @@ def train_pq( sample_rate, max_iters, ivf_model.centroids, + fragment_ids, ) return PqModel(num_subvectors, pq_codebook) @@ -212,11 +224,17 @@ def prepare_global_ivf_pq( accelerator: Optional[Union[str, "torch.Device"]] = None, sample_rate: int = 256, max_iters: int = 50, + fragment_ids: Optional[list[int]] = None, ) -> dict: """ Perform global training for IVF+PQ using existing CPU training paths and return preprocessed artifacts for distributed builds. 
+ Parameters + ---------- + fragment_ids: list[int], optional + If provided, train using only the specified fragments from the dataset. + Returns ------- dict @@ -238,6 +256,7 @@ def prepare_global_ivf_pq( accelerator=accelerator, # None by default (CPU path) sample_rate=sample_rate, max_iters=max_iters, + fragment_ids=fragment_ids, ) # Global PQ training using IVF residuals @@ -246,6 +265,7 @@ def prepare_global_ivf_pq( num_subvectors, sample_rate=sample_rate, max_iters=max_iters, + fragment_ids=fragment_ids, ) return {"ivf_centroids": ivf_model.centroids, "pq_codebook": pq_model.codebook} @@ -458,6 +478,18 @@ def _determine_num_partitions(self, num_partitions: Optional[int], num_rows: int return round(math.sqrt(num_rows)) return num_partitions + def _count_rows(self, fragment_ids: Optional[list[int]] = None) -> int: + if fragment_ids is None: + return self.dataset.count_rows() + + num_rows = 0 + for fragment_id in fragment_ids: + fragment = self.dataset.get_fragment(fragment_id) + if fragment is None: + raise ValueError(f"Fragment id does not exist: {fragment_id}") + num_rows += fragment.count_rows() + return num_rows + def _normalize_pq_params(self, num_subvectors: int, dimension: int): if num_subvectors is None: if dimension % 16 == 0: diff --git a/python/python/lance/lance/indices/__init__.pyi b/python/python/lance/lance/indices/__init__.pyi index e1282b675a6..3369b61c619 100644 --- a/python/python/lance/lance/indices/__init__.pyi +++ b/python/python/lance/lance/indices/__init__.pyi @@ -46,6 +46,7 @@ def train_ivf_model( distance_type: str, sample_rate: int, max_iters: int, + fragment_ids: Optional[list[int]] = None, ) -> pa.Array: ... def train_pq_model( dataset, @@ -56,6 +57,7 @@ def train_pq_model( sample_rate: int, max_iters: int, ivf_model: pa.Array, + fragment_ids: Optional[list[int]] = None, ) -> pa.Array: ... 
def transform_vectors( dataset, diff --git a/python/python/tests/test_indices.py b/python/python/tests/test_indices.py index e29f02705e2..c0d26053f96 100644 --- a/python/python/tests/test_indices.py +++ b/python/python/tests/test_indices.py @@ -159,6 +159,58 @@ def test_gen_pq(tmpdir, rand_dataset, rand_ivf): assert pq.codebook == reloaded.codebook +def test_ivf_centroids_fragment_ids(tmpdir): + rows_per_fragment = 32 + vectors = np.concatenate( + [ + np.zeros((rows_per_fragment, DIMENSION), dtype=np.float32), + np.full((rows_per_fragment, DIMENSION), 10.0, dtype=np.float32), + ], + axis=0, + ) + vectors.shape = -1 + table = pa.Table.from_arrays( + [pa.FixedSizeListArray.from_arrays(vectors, DIMENSION)], names=["vectors"] + ) + ds = lance.write_dataset( + table, + pathlib.Path(tmpdir) / "fragment_ivf", + max_rows_per_file=rows_per_fragment, + ) + fragment_ids = [fragment.fragment_id for fragment in ds.get_fragments()] + + first_ivf = IndicesBuilder(ds, "vectors").train_ivf( + num_partitions=1, sample_rate=2, fragment_ids=[fragment_ids[0]] + ) + second_ivf = IndicesBuilder(ds, "vectors").train_ivf( + num_partitions=1, sample_rate=2, fragment_ids=[fragment_ids[1]] + ) + + first_centroid = first_ivf.centroids.values.to_numpy().reshape(-1, DIMENSION)[0] + second_centroid = second_ivf.centroids.values.to_numpy().reshape(-1, DIMENSION)[0] + + assert np.allclose(first_centroid, 0.0, atol=1e-4) + assert np.allclose(second_centroid, 10.0, atol=1e-4) + + +def test_pq_fragment_ids(rand_dataset): + fragment_id = rand_dataset.get_fragments()[0].fragment_id + ivf = IndicesBuilder(rand_dataset, "vectors").train_ivf( + num_partitions=4, + sample_rate=16, + fragment_ids=[fragment_id], + ) + + pq = IndicesBuilder(rand_dataset, "vectors").train_pq( + ivf, + sample_rate=2, + fragment_ids=[fragment_id], + ) + + assert pq.dimension == DIMENSION + assert pq.num_subvectors == NUM_SUBVECTORS + + def test_pq_invalid_sub_vectors(tmpdir, rand_dataset, rand_ivf): with pytest.raises( 
ValueError, diff --git a/python/src/indices.rs b/python/src/indices.rs index f1d42918962..585f6950267 100644 --- a/python/src/indices.rs +++ b/python/src/indices.rs @@ -11,6 +11,7 @@ use arrow_data::ArrayData; use chrono::{DateTime, Utc}; use lance::dataset::Dataset as LanceDataset; use lance::index::vector::ivf::builder::write_vector_storage; +use lance::index::vector::pq::build_pq_model_in_fragments; use lance::io::ObjectStore; use lance_index::progress::NoopIndexBuildProgress; use lance_index::vector::ivf::shuffler::{IvfShuffler, shuffle_vectors}; @@ -205,6 +206,7 @@ async fn do_train_ivf_model( distance_type: &str, sample_rate: u32, max_iters: u32, + fragment_ids: Option>, ) -> PyResult { // We verify distance_type earlier so can unwrap here let distance_type = DistanceType::try_from(distance_type).unwrap(); @@ -220,6 +222,7 @@ async fn do_train_ivf_model( dimension, distance_type, &params, + fragment_ids.as_deref(), Arc::new(NoopIndexBuildProgress), ) .await @@ -230,6 +233,7 @@ async fn do_train_ivf_model( #[pyfunction] #[allow(clippy::too_many_arguments)] +#[pyo3(signature=(dataset, column, dimension, num_partitions, distance_type, sample_rate, max_iters, fragment_ids=None))] fn train_ivf_model<'py>( py: Python<'py>, dataset: &Dataset, @@ -239,6 +243,7 @@ fn train_ivf_model<'py>( distance_type: &str, sample_rate: u32, max_iters: u32, + fragment_ids: Option>, ) -> PyResult> { let centroids = rt().block_on( Some(py), @@ -250,6 +255,7 @@ fn train_ivf_model<'py>( distance_type, sample_rate, max_iters, + fragment_ids, ), )??; centroids.to_pyarrow(py) @@ -265,6 +271,7 @@ async fn do_train_pq_model( sample_rate: u32, max_iters: u32, ivf_model: IvfModel, + fragment_ids: Option>, ) -> PyResult { // We verify distance_type earlier so can unwrap here let distance_type = DistanceType::try_from(distance_type).unwrap(); @@ -275,13 +282,14 @@ async fn do_train_pq_model( sample_rate: sample_rate as usize, ..Default::default() }; - let pq_model = 
lance::index::vector::pq::build_pq_model( + let pq_model = build_pq_model_in_fragments( dataset.ds.as_ref(), column, dimension, distance_type, &params, Some(&ivf_model), + fragment_ids.as_deref(), ) .await .infer_error()?; @@ -290,6 +298,7 @@ async fn do_train_pq_model( #[pyfunction] #[allow(clippy::too_many_arguments)] +#[pyo3(signature=(dataset, column, dimension, num_subvectors, distance_type, sample_rate, max_iters, ivf_centroids, fragment_ids=None))] fn train_pq_model<'py>( py: Python<'py>, dataset: &Dataset, @@ -300,6 +309,7 @@ fn train_pq_model<'py>( sample_rate: u32, max_iters: u32, ivf_centroids: PyArrowType, + fragment_ids: Option>, ) -> PyResult> { let ivf_centroids = ivf_centroids.0; let ivf_centroids = FixedSizeListArray::from(ivf_centroids); @@ -320,6 +330,7 @@ fn train_pq_model<'py>( sample_rate, max_iters, ivf_model, + fragment_ids, ), )??; codebook.to_pyarrow(py) diff --git a/rust/lance/src/index/create.rs b/rust/lance/src/index/create.rs index a394f52258e..d3c9b3380ce 100644 --- a/rust/lance/src/index/create.rs +++ b/rust/lance/src/index/create.rs @@ -638,8 +638,9 @@ mod tests { use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount}; use arrow::datatypes::{Float32Type, Int32Type}; use arrow_array::cast::AsArray; - use arrow_array::{FixedSizeListArray, RecordBatchIterator}; - use arrow_array::{Int32Array, RecordBatch, StringArray}; + use arrow_array::{ + FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterator, StringArray, + }; use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; use lance_arrow::FixedSizeListArrayExt; use lance_core::utils::tempfile::TempStrDir; diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index a19f472a9bf..a09105f8a22 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -571,7 +571,6 @@ pub(crate) async fn build_distributed_vector_index( stages ))); }; - IvfIndexBuilder::::new( filtered_dataset, column.to_owned(), @@ -666,7 
+665,6 @@ pub(crate) async fn build_distributed_vector_index( stages ))); }; - IvfIndexBuilder::::new( filtered_dataset, column.to_owned(), @@ -2143,6 +2141,7 @@ mod tests { dim, MetricType::L2, &ivf_params, + None, noop_progress(), ) .await @@ -2195,6 +2194,7 @@ mod tests { dim, MetricType::L2, &ivf_params, + None, noop_progress(), ) .await diff --git a/rust/lance/src/index/vector/builder.rs b/rust/lance/src/index/vector/builder.rs index 1f871046b71..017d50319f8 100644 --- a/rust/lance/src/index/vector/builder.rs +++ b/rust/lance/src/index/vector/builder.rs @@ -405,6 +405,7 @@ impl IvfIndexBuilder dim, self.distance_type, ivf_params, + None, self.progress.clone(), ) .await @@ -434,7 +435,8 @@ impl IvfIndexBuilder sample_size_hint ); let training_data = - utils::maybe_sample_training_data(dataset, &self.column, sample_size_hint).await?; + utils::maybe_sample_training_data(dataset, &self.column, sample_size_hint, None) + .await?; info!( "Finished loading training data in {:02} seconds", start.elapsed().as_secs_f32() diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 8e47b77cdfd..e747dd4fd25 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -1222,6 +1222,7 @@ pub async fn build_ivf_model( dim: usize, metric_type: MetricType, params: &IvfBuildParams, + fragment_ids: Option<&[u32]>, progress: std::sync::Arc, ) -> Result { let num_partitions = params.num_partitions.unwrap(); @@ -1244,7 +1245,8 @@ pub async fn build_ivf_model( "Loading training data for IVF. 
Sample size: {}", sample_size_hint ); - let training_data = maybe_sample_training_data(dataset, column, sample_size_hint).await?; + let training_data = + maybe_sample_training_data(dataset, column, sample_size_hint, fragment_ids).await?; info!( "Finished loading training data in {:02} seconds", start.elapsed().as_secs_f32() @@ -1301,8 +1303,16 @@ async fn build_ivf_model_and_pq( get_vector_type(dataset.schema(), column)?; let dim = get_vector_dim(dataset.schema(), column)?; - let ivf_model = - build_ivf_model(dataset, column, dim, metric_type, ivf_params, progress).await?; + let ivf_model = build_ivf_model( + dataset, + column, + dim, + metric_type, + ivf_params, + None, + progress, + ) + .await?; let ivf_residual = if matches!(metric_type, MetricType::Cosine | MetricType::L2) { Some(&ivf_model) @@ -3286,6 +3296,7 @@ mod tests { DIM, MetricType::L2, &ivf_params, + None, lance_index::progress::noop_progress(), ) .await @@ -3321,6 +3332,7 @@ mod tests { DIM, MetricType::Cosine, &ivf_params, + None, lance_index::progress::noop_progress(), ) .await @@ -3880,6 +3892,7 @@ mod tests { DIM, MetricType::L2, &ivf_params, + None, progress, ) .await diff --git a/rust/lance/src/index/vector/pq.rs b/rust/lance/src/index/vector/pq.rs index d89999d921c..615f1b9c829 100644 --- a/rust/lance/src/index/vector/pq.rs +++ b/rust/lance/src/index/vector/pq.rs @@ -502,6 +502,18 @@ pub async fn build_pq_model( metric_type: MetricType, params: &PQBuildParams, ivf: Option<&IvfModel>, +) -> Result { + build_pq_model_in_fragments(dataset, column, dim, metric_type, params, ivf, None).await +} + +pub async fn build_pq_model_in_fragments( + dataset: &Dataset, + column: &str, + dim: usize, + metric_type: MetricType, + params: &PQBuildParams, + ivf: Option<&IvfModel>, + fragment_ids: Option<&[u32]>, ) -> Result { let num_codes = 2_usize.pow(params.num_bits as u32); @@ -542,7 +554,7 @@ pub async fn build_pq_model( ); let start = std::time::Instant::now(); let mut training_data = - 
maybe_sample_training_data(dataset, column, expected_sample_size).await?; + maybe_sample_training_data(dataset, column, expected_sample_size, fragment_ids).await?; info!( "Finished loading training data in {:02} seconds", start.elapsed().as_secs_f32() @@ -712,6 +724,7 @@ mod tests { DIM, MetricType::Cosine, &ivf_params, + None, lance_index::progress::noop_progress(), ) .await diff --git a/rust/lance/src/index/vector/utils.rs b/rust/lance/src/index/vector/utils.rs index b20d659d6f3..f00a81b764d 100644 --- a/rust/lance/src/index/vector/utils.rs +++ b/rust/lance/src/index/vector/utils.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use arrow::array::ArrayData; use arrow::datatypes::DataType; use arrow_array::new_empty_array; -use arrow_array::{Array, ArrayRef, FixedSizeListArray, RecordBatch, cast::AsArray}; +use arrow_array::{Array, ArrayRef, FixedSizeListArray, RecordBatch, UInt32Array, cast::AsArray}; use arrow_buffer::{Buffer, MutableBuffer}; use futures::StreamExt; use lance_arrow::DataTypeExt; @@ -86,6 +86,7 @@ async fn estimate_multivector_vectors_per_row( dataset: &Dataset, column: &str, num_rows: usize, + fragments: Option<&[u32]>, ) -> Result { if num_rows == 0 { return Ok(1030); @@ -96,7 +97,9 @@ async fn estimate_multivector_vectors_per_row( // Try a few random samples first (fast path). let sample_batch_size = std::cmp::min(64, num_rows); for _ in 0..8 { - let batch = dataset.sample(sample_batch_size, &projection, None).await?; + let batch = dataset + .sample(sample_batch_size, &projection, fragments) + .await?; let array = get_column_from_batch(&batch, column)?; let list_array = array.as_list::(); for i in 0..list_array.len() { @@ -114,6 +117,9 @@ async fn estimate_multivector_vectors_per_row( // flakiness when values are extremely sparse. 
let mut scanner = dataset.scan(); scanner.project(&[column])?; + if let Some(fragments) = fragments { + scanner.with_fragments(resolve_scan_fragments(dataset, fragments)?); + } let column_expr = lance_datafusion::logical_expr::field_path_to_expr(column)?; scanner.filter_expr(column_expr.is_not_null()); scanner.limit(Some(std::cmp::min(num_rows, 1024) as i64), None)?; @@ -261,8 +267,15 @@ pub async fn maybe_sample_training_data( dataset: &Dataset, column: &str, sample_size_hint: usize, + fragment_ids: Option<&[u32]>, ) -> Result { - let num_rows = dataset.count_rows(None).await?; + let num_rows = if let Some(fragment_ids) = fragment_ids { + let mut scanner = dataset.scan(); + scanner.with_fragments(resolve_scan_fragments(dataset, fragment_ids)?); + scanner.count_rows().await? as usize + } else { + dataset.count_rows(None).await? + }; let vector_field = dataset.schema().field(column).ok_or(Error::index(format!( "Sample training data: column {} does not exist in schema", @@ -291,7 +304,8 @@ pub async fn maybe_sample_training_data( // Set a minimum sample size of 128 to avoid too small samples, // it's not a problem because 128 multivectors is just about 64 MiB let vectors_per_row = - estimate_multivector_vectors_per_row(dataset, column, num_rows).await?; + estimate_multivector_vectors_per_row(dataset, column, num_rows, fragment_ids) + .await?; sample_size_hint.div_ceil(vectors_per_row).max(128) } _ => sample_size_hint, @@ -306,11 +320,12 @@ pub async fn maybe_sample_training_data( num_rows, vector_field, is_nullable, + fragment_ids, ) .await } else { // too small to require sampling - let batch = scan_all_training_data(dataset, column, is_nullable).await?; + let batch = scan_all_training_data(dataset, column, is_nullable, fragment_ids).await?; vector_column_to_fsl(&batch, column) } } @@ -378,9 +393,13 @@ async fn scan_all_training_data( dataset: &Dataset, column: &str, is_nullable: bool, + fragment_ids: Option<&[u32]>, ) -> Result { let mut scanner = dataset.scan(); 
scanner.project(&[column])?; + if let Some(fragment_ids) = fragment_ids { + scanner.with_fragments(resolve_scan_fragments(dataset, fragment_ids)?); + } if is_nullable { let column_expr = lance_datafusion::logical_expr::field_path_to_expr(column)?; scanner.filter_expr(column_expr.is_not_null()); @@ -406,14 +425,38 @@ async fn sample_training_data( num_rows: usize, vector_field: &lance_core::datatypes::Field, is_nullable: bool, + fragment_ids: Option<&[u32]>, ) -> Result { + if fragment_ids.is_some() { + if !is_nullable { + let projection = dataset.schema().project(&[column])?; + let batch = dataset + .sample(sample_size_hint, &projection, fragment_ids) + .await?; + return vector_column_to_fsl(&batch, column); + } + + let batch = scan_all_training_data(dataset, column, is_nullable, fragment_ids).await?; + let training_data = vector_column_to_fsl(&batch, column)?; + if training_data.len() <= sample_size_hint { + return Ok(training_data); + } + let indices = UInt32Array::from_iter_values( + generate_random_indices(training_data.len(), sample_size_hint) + .into_iter() + .map(|index| index as u32), + ); + let sampled = arrow_select::take::take(&training_data, &indices, None)?; + return Ok(sampled.as_fixed_size_list().clone()); + } + let byte_width = vector_field .data_type() .byte_width_opt() .unwrap_or(4 * 1024); match vector_field.data_type() { - DataType::FixedSizeList(_, _) if !is_nullable => { + DataType::FixedSizeList(_, _) if !is_nullable && fragment_ids.is_none() => { sample_fsl_uniform( dataset, column, @@ -454,6 +497,28 @@ fn sample_training_data_scan( )) } +fn resolve_scan_fragments( + dataset: &Dataset, + fragment_ids: &[u32], +) -> Result> { + let mut ordered_ids = fragment_ids.to_vec(); + ordered_ids.sort_unstable(); + let fragments = dataset.get_frags_from_ordered_ids(&ordered_ids); + if let Some(missing_id) = fragments + .iter() + .zip(ordered_ids.iter()) + .find_map(|(fragment, fragment_id)| fragment.is_none().then_some(*fragment_id)) + { + return 
Err(Error::invalid_input(format!( + "Unknown fragment id {missing_id} in training fragment filter" + ))); + } + Ok(fragments + .into_iter() + .map(|fragment| fragment.unwrap().metadata().clone()) + .collect()) +} + /// Build a FixedSizeListArray from raw flat value bytes. fn fsl_values_to_array( field: &lance_core::datatypes::Field, @@ -821,7 +886,7 @@ mod tests { .await .unwrap(); - let training_data = maybe_sample_training_data(&dataset, "mv", 1000) + let training_data = maybe_sample_training_data(&dataset, "mv", 1000, None) .await .unwrap(); assert_eq!(training_data.len(), 1000); @@ -897,7 +962,7 @@ mod tests { .await .unwrap(); - let training_data = maybe_sample_training_data(&dataset, "vec", sample_size) + let training_data = maybe_sample_training_data(&dataset, "vec", sample_size, None) .await .unwrap(); @@ -983,7 +1048,7 @@ mod tests { .await .unwrap(); - let n = estimate_multivector_vectors_per_row(&dataset, "mv", nrows) + let n = estimate_multivector_vectors_per_row(&dataset, "mv", nrows, None) .await .unwrap(); assert_eq!(n, 1030); From 4cefb3e05ca36c520d2d18c51cae7e0a7d2f666e Mon Sep 17 00:00:00 2001 From: Xuanwo Date: Thu, 26 Mar 2026 02:29:47 +0800 Subject: [PATCH 3/4] fix: update jni IVF trainer for fragment-scoped API --- java/lance-jni/src/vector_trainer.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/java/lance-jni/src/vector_trainer.rs b/java/lance-jni/src/vector_trainer.rs index 92b6afa084b..9ea164d3586 100755 --- a/java/lance-jni/src/vector_trainer.rs +++ b/java/lance-jni/src/vector_trainer.rs @@ -117,6 +117,7 @@ fn inner_train_ivf_centroids<'local>( dim, metric_type, &ivf_params, + None, Arc::new(NoopIndexBuildProgress), ))?; From 3c27866f814015d0b9a0b514a4cdcbd0dc3734c7 Mon Sep 17 00:00:00 2001 From: Xuanwo Date: Fri, 27 Mar 2026 17:36:14 +0800 Subject: [PATCH 4/4] fix: allow private python IVF training helper arity --- python/src/indices.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/python/src/indices.rs b/python/src/indices.rs 
index 78342e309c4..895ceb19cf9 100644 --- a/python/src/indices.rs +++ b/python/src/indices.rs @@ -199,6 +199,7 @@ fn get_ivf_model(py: Python<'_>, dataset: &Dataset, index_name: &str) -> PyResul Py::new(py, PyIvfModel { inner: ivf_model }) } +#[allow(clippy::too_many_arguments)] async fn do_train_ivf_model( dataset: &Dataset, column: &str,