diff --git a/java/lance-jni/src/vector_trainer.rs b/java/lance-jni/src/vector_trainer.rs index 92b6afa084b..9ea164d3586 100755 --- a/java/lance-jni/src/vector_trainer.rs +++ b/java/lance-jni/src/vector_trainer.rs @@ -117,6 +117,7 @@ fn inner_train_ivf_centroids<'local>( dim, metric_type, &ivf_params, + None, Arc::new(NoopIndexBuildProgress), ))?; diff --git a/python/python/lance/indices/builder.py b/python/python/lance/indices/builder.py index 7105c9234bb..c31ea0a7a0c 100644 --- a/python/python/lance/indices/builder.py +++ b/python/python/lance/indices/builder.py @@ -66,6 +66,7 @@ def train_ivf( accelerator: Optional[Union[str, "torch.Device"]] = None, sample_rate: int = 256, max_iters: int = 50, + fragment_ids: Optional[list[int]] = None, ) -> IvfModel: """ Train IVF centroids for the given vector column. @@ -106,8 +107,10 @@ def train_ivf( some cases, k-means will not converge but will cycle between various possible minima. In these cases we must terminate or run forever. The max_iters parameter defines a cutoff at which we terminate training. + fragment_ids: list[int], optional + If provided, train using only the specified fragments from the dataset. 
""" - num_rows = self.dataset.count_rows() + num_rows = self._count_rows(fragment_ids) num_partitions = self._determine_num_partitions(num_partitions, num_rows) self._verify_ivf_sample_rate(sample_rate, num_partitions, num_rows) distance_type = self._normalize_distance_type(distance_type) @@ -124,9 +127,14 @@ def train_ivf( distance_type, sample_rate, max_iters, + fragment_ids, ) return IvfModel(ivf_centroids, distance_type) else: + if fragment_ids is not None: + raise NotImplementedError( + "fragment_ids is not supported with accelerator IVF training" + ) # Use accelerator to train ivf centroids from lance.vector import train_ivf_centroids_on_accelerator @@ -154,6 +162,7 @@ def train_pq( *, sample_rate: int = 256, max_iters: int = 50, + fragment_ids: Optional[list[int]] = None, ) -> PqModel: """ Train a PQ model for a given column. @@ -184,10 +193,12 @@ def train_pq( This parameter is used in the same way as in the IVF model. max_iters: int This parameter is used in the same way as in the IVF model. + fragment_ids: list[int], optional + If provided, train using only the specified fragments from the dataset. """ from lance.lance import indices - num_rows = self.dataset.count_rows() + num_rows = self._count_rows(fragment_ids) self.dataset.schema.field(self.column[0]).type.list_size num_subvectors = self._normalize_pq_params(num_subvectors, self.dimension) self._verify_pq_sample_rate(num_rows, sample_rate) @@ -201,6 +212,7 @@ def train_pq( sample_rate, max_iters, ivf_model.centroids, + fragment_ids, ) return PqModel(num_subvectors, pq_codebook) @@ -213,11 +225,17 @@ def prepare_global_ivf_pq( accelerator: Optional[Union[str, "torch.Device"]] = None, sample_rate: int = 256, max_iters: int = 50, + fragment_ids: Optional[list[int]] = None, ) -> dict: """ Perform global training for IVF+PQ using existing CPU training paths and return preprocessed artifacts for distributed builds. 
+ Parameters + ---------- + fragment_ids: list[int], optional + If provided, train using only the specified fragments from the dataset. + Returns ------- dict @@ -239,6 +257,7 @@ def prepare_global_ivf_pq( accelerator=accelerator, # None by default (CPU path) sample_rate=sample_rate, max_iters=max_iters, + fragment_ids=fragment_ids, ) # Global PQ training using IVF residuals @@ -247,6 +266,7 @@ def prepare_global_ivf_pq( num_subvectors, sample_rate=sample_rate, max_iters=max_iters, + fragment_ids=fragment_ids, ) return {"ivf_centroids": ivf_model.centroids, "pq_codebook": pq_model.codebook} @@ -459,6 +479,18 @@ def _determine_num_partitions(self, num_partitions: Optional[int], num_rows: int return round(math.sqrt(num_rows)) return num_partitions + def _count_rows(self, fragment_ids: Optional[list[int]] = None) -> int: + if fragment_ids is None: + return self.dataset.count_rows() + + num_rows = 0 + for fragment_id in fragment_ids: + fragment = self.dataset.get_fragment(fragment_id) + if fragment is None: + raise ValueError(f"Fragment id does not exist: {fragment_id}") + num_rows += fragment.count_rows() + return num_rows + def _normalize_pq_params(self, num_subvectors: int, dimension: int): if num_subvectors is None: if dimension % 16 == 0: diff --git a/python/python/lance/lance/indices/__init__.pyi b/python/python/lance/lance/indices/__init__.pyi index e1282b675a6..3369b61c619 100644 --- a/python/python/lance/lance/indices/__init__.pyi +++ b/python/python/lance/lance/indices/__init__.pyi @@ -46,6 +46,7 @@ def train_ivf_model( distance_type: str, sample_rate: int, max_iters: int, + fragment_ids: Optional[list[int]] = None, ) -> pa.Array: ... def train_pq_model( dataset, @@ -56,6 +57,7 @@ def train_pq_model( sample_rate: int, max_iters: int, ivf_model: pa.Array, + fragment_ids: Optional[list[int]] = None, ) -> pa.Array: ... 
def transform_vectors( dataset, diff --git a/python/python/tests/test_indices.py b/python/python/tests/test_indices.py index f0444e4abcb..88cae659561 100644 --- a/python/python/tests/test_indices.py +++ b/python/python/tests/test_indices.py @@ -184,6 +184,58 @@ def test_gen_pq(tmpdir, rand_dataset, rand_ivf): assert pq.codebook == reloaded.codebook +def test_ivf_centroids_fragment_ids(tmpdir): + rows_per_fragment = 32 + vectors = np.concatenate( + [ + np.zeros((rows_per_fragment, DIMENSION), dtype=np.float32), + np.full((rows_per_fragment, DIMENSION), 10.0, dtype=np.float32), + ], + axis=0, + ) + vectors.shape = -1 + table = pa.Table.from_arrays( + [pa.FixedSizeListArray.from_arrays(vectors, DIMENSION)], names=["vectors"] + ) + ds = lance.write_dataset( + table, + pathlib.Path(tmpdir) / "fragment_ivf", + max_rows_per_file=rows_per_fragment, + ) + fragment_ids = [fragment.fragment_id for fragment in ds.get_fragments()] + + first_ivf = IndicesBuilder(ds, "vectors").train_ivf( + num_partitions=1, sample_rate=2, fragment_ids=[fragment_ids[0]] + ) + second_ivf = IndicesBuilder(ds, "vectors").train_ivf( + num_partitions=1, sample_rate=2, fragment_ids=[fragment_ids[1]] + ) + + first_centroid = first_ivf.centroids.values.to_numpy().reshape(-1, DIMENSION)[0] + second_centroid = second_ivf.centroids.values.to_numpy().reshape(-1, DIMENSION)[0] + + assert np.allclose(first_centroid, 0.0, atol=1e-4) + assert np.allclose(second_centroid, 10.0, atol=1e-4) + + +def test_pq_fragment_ids(rand_dataset): + fragment_id = rand_dataset.get_fragments()[0].fragment_id + ivf = IndicesBuilder(rand_dataset, "vectors").train_ivf( + num_partitions=4, + sample_rate=16, + fragment_ids=[fragment_id], + ) + + pq = IndicesBuilder(rand_dataset, "vectors").train_pq( + ivf, + sample_rate=2, + fragment_ids=[fragment_id], + ) + + assert pq.dimension == DIMENSION + assert pq.num_subvectors == NUM_SUBVECTORS + + def test_pq_invalid_sub_vectors(tmpdir, rand_dataset, rand_ivf): with pytest.raises( 
ValueError, diff --git a/python/src/indices.rs b/python/src/indices.rs index 9651c6cc00e..895ceb19cf9 100644 --- a/python/src/indices.rs +++ b/python/src/indices.rs @@ -11,6 +11,7 @@ use arrow_data::ArrayData; use chrono::{DateTime, Utc}; use lance::dataset::Dataset as LanceDataset; use lance::index::vector::ivf::builder::write_vector_storage; +use lance::index::vector::pq::build_pq_model_in_fragments; use lance::index::{DatasetIndexExt, IndexSegment, IndexSegmentPlan}; use lance::io::ObjectStore; use lance_index::progress::NoopIndexBuildProgress; @@ -198,6 +199,7 @@ fn get_ivf_model(py: Python<'_>, dataset: &Dataset, index_name: &str) -> PyResul Py::new(py, PyIvfModel { inner: ivf_model }) } +#[allow(clippy::too_many_arguments)] async fn do_train_ivf_model( dataset: &Dataset, column: &str, @@ -206,6 +208,7 @@ async fn do_train_ivf_model( distance_type: &str, sample_rate: u32, max_iters: u32, + fragment_ids: Option<Vec<u32>>, ) -> PyResult<ArrayData> { // We verify distance_type earlier so can unwrap here let distance_type = DistanceType::try_from(distance_type).unwrap(); @@ -221,6 +224,7 @@ async fn do_train_ivf_model( dimension, distance_type, &params, + fragment_ids.as_deref(), Arc::new(NoopIndexBuildProgress), ) .await @@ -231,6 +235,7 @@ async fn do_train_ivf_model( #[pyfunction] #[allow(clippy::too_many_arguments)] +#[pyo3(signature=(dataset, column, dimension, num_partitions, distance_type, sample_rate, max_iters, fragment_ids=None))] fn train_ivf_model<'py>( py: Python<'py>, dataset: &Dataset, @@ -240,6 +245,7 @@ fn train_ivf_model<'py>( distance_type: &str, sample_rate: u32, max_iters: u32, + fragment_ids: Option<Vec<u32>>, ) -> PyResult<Bound<'py, PyAny>> { let centroids = rt().block_on( Some(py), @@ -251,6 +257,7 @@ fn train_ivf_model<'py>( distance_type, sample_rate, max_iters, + fragment_ids, ), )??; centroids.to_pyarrow(py) @@ -266,6 +273,7 @@ async fn do_train_pq_model( sample_rate: u32, max_iters: u32, ivf_model: IvfModel, + fragment_ids: Option<Vec<u32>>, ) -> PyResult<ArrayData> { // We verify distance_type earlier 
so can unwrap here let distance_type = DistanceType::try_from(distance_type).unwrap(); @@ -276,13 +284,14 @@ async fn do_train_pq_model( sample_rate: sample_rate as usize, ..Default::default() }; - let pq_model = lance::index::vector::pq::build_pq_model( + let pq_model = build_pq_model_in_fragments( dataset.ds.as_ref(), column, dimension, distance_type, &params, Some(&ivf_model), + fragment_ids.as_deref(), ) .await .infer_error()?; @@ -291,6 +300,7 @@ async fn do_train_pq_model( #[pyfunction] #[allow(clippy::too_many_arguments)] +#[pyo3(signature=(dataset, column, dimension, num_subvectors, distance_type, sample_rate, max_iters, ivf_centroids, fragment_ids=None))] fn train_pq_model<'py>( py: Python<'py>, dataset: &Dataset, @@ -301,6 +311,7 @@ fn train_pq_model<'py>( sample_rate: u32, max_iters: u32, ivf_centroids: PyArrowType<ArrayData>, + fragment_ids: Option<Vec<u32>>, ) -> PyResult<Bound<'py, PyAny>> { let ivf_centroids = ivf_centroids.0; let ivf_centroids = FixedSizeListArray::from(ivf_centroids); @@ -321,6 +332,7 @@ fn train_pq_model<'py>( sample_rate, max_iters, ivf_model, + fragment_ids, ), )??; codebook.to_pyarrow(py) diff --git a/rust/lance/src/index/create.rs b/rust/lance/src/index/create.rs index 0bf9fdd283c..7403e2b24c9 100644 --- a/rust/lance/src/index/create.rs +++ b/rust/lance/src/index/create.rs @@ -640,8 +640,9 @@ mod tests { use crate::utils::test::{DatagenExt, FragmentCount, FragmentRowCount}; use arrow::datatypes::{Float32Type, Int32Type}; use arrow_array::cast::AsArray; - use arrow_array::{FixedSizeListArray, RecordBatchIterator}; - use arrow_array::{Int32Array, RecordBatch, StringArray}; + use arrow_array::{ FixedSizeListArray, Int32Array, RecordBatch, RecordBatchIterator, StringArray, }; use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; use lance_arrow::FixedSizeListArrayExt; use lance_core::utils::tempfile::TempStrDir; diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index 4b5b113cae9..a60d4941a07 100644 --- 
a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -569,7 +569,6 @@ pub(crate) async fn build_distributed_vector_index( stages ))); }; - IvfIndexBuilder::::new( filtered_dataset, column.to_owned(), @@ -664,7 +663,6 @@ pub(crate) async fn build_distributed_vector_index( stages ))); }; - IvfIndexBuilder::::new( filtered_dataset, column.to_owned(), @@ -2141,6 +2139,7 @@ mod tests { dim, MetricType::L2, &ivf_params, + None, noop_progress(), ) .await @@ -2193,6 +2192,7 @@ mod tests { dim, MetricType::L2, &ivf_params, + None, noop_progress(), ) .await diff --git a/rust/lance/src/index/vector/builder.rs b/rust/lance/src/index/vector/builder.rs index 1f871046b71..017d50319f8 100644 --- a/rust/lance/src/index/vector/builder.rs +++ b/rust/lance/src/index/vector/builder.rs @@ -405,6 +405,7 @@ impl IvfIndexBuilder dim, self.distance_type, ivf_params, + None, self.progress.clone(), ) .await @@ -434,7 +435,8 @@ impl IvfIndexBuilder sample_size_hint ); let training_data = - utils::maybe_sample_training_data(dataset, &self.column, sample_size_hint).await?; + utils::maybe_sample_training_data(dataset, &self.column, sample_size_hint, None) + .await?; info!( "Finished loading training data in {:02} seconds", start.elapsed().as_secs_f32() diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 90864154094..d58cbae2fcd 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -1223,6 +1223,7 @@ pub async fn build_ivf_model( dim: usize, metric_type: MetricType, params: &IvfBuildParams, + fragment_ids: Option<&[u32]>, progress: std::sync::Arc, ) -> Result { let num_partitions = params.num_partitions.unwrap(); @@ -1245,7 +1246,8 @@ pub async fn build_ivf_model( "Loading training data for IVF. 
Sample size: {}", sample_size_hint ); - let training_data = maybe_sample_training_data(dataset, column, sample_size_hint).await?; + let training_data = + maybe_sample_training_data(dataset, column, sample_size_hint, fragment_ids).await?; info!( "Finished loading training data in {:02} seconds", start.elapsed().as_secs_f32() @@ -1302,8 +1304,16 @@ async fn build_ivf_model_and_pq( get_vector_type(dataset.schema(), column)?; let dim = get_vector_dim(dataset.schema(), column)?; - let ivf_model = - build_ivf_model(dataset, column, dim, metric_type, ivf_params, progress).await?; + let ivf_model = build_ivf_model( + dataset, + column, + dim, + metric_type, + ivf_params, + None, + progress, + ) + .await?; let ivf_residual = if matches!(metric_type, MetricType::Cosine | MetricType::L2) { Some(&ivf_model) @@ -3287,6 +3297,7 @@ mod tests { DIM, MetricType::L2, &ivf_params, + None, lance_index::progress::noop_progress(), ) .await @@ -3322,6 +3333,7 @@ mod tests { DIM, MetricType::Cosine, &ivf_params, + None, lance_index::progress::noop_progress(), ) .await @@ -3881,6 +3893,7 @@ mod tests { DIM, MetricType::L2, &ivf_params, + None, progress, ) .await diff --git a/rust/lance/src/index/vector/pq.rs b/rust/lance/src/index/vector/pq.rs index d89999d921c..615f1b9c829 100644 --- a/rust/lance/src/index/vector/pq.rs +++ b/rust/lance/src/index/vector/pq.rs @@ -502,6 +502,18 @@ pub async fn build_pq_model( metric_type: MetricType, params: &PQBuildParams, ivf: Option<&IvfModel>, +) -> Result { + build_pq_model_in_fragments(dataset, column, dim, metric_type, params, ivf, None).await +} + +pub async fn build_pq_model_in_fragments( + dataset: &Dataset, + column: &str, + dim: usize, + metric_type: MetricType, + params: &PQBuildParams, + ivf: Option<&IvfModel>, + fragment_ids: Option<&[u32]>, ) -> Result { let num_codes = 2_usize.pow(params.num_bits as u32); @@ -542,7 +554,7 @@ pub async fn build_pq_model( ); let start = std::time::Instant::now(); let mut training_data = - 
maybe_sample_training_data(dataset, column, expected_sample_size).await?; + maybe_sample_training_data(dataset, column, expected_sample_size, fragment_ids).await?; info!( "Finished loading training data in {:02} seconds", start.elapsed().as_secs_f32() ); @@ -712,6 +724,7 @@ mod tests { DIM, MetricType::Cosine, &ivf_params, + None, lance_index::progress::noop_progress(), ) .await diff --git a/rust/lance/src/index/vector/utils.rs b/rust/lance/src/index/vector/utils.rs index b20d659d6f3..f00a81b764d 100644 --- a/rust/lance/src/index/vector/utils.rs +++ b/rust/lance/src/index/vector/utils.rs @@ -6,7 +6,7 @@ use std::sync::Arc; use arrow::array::ArrayData; use arrow::datatypes::DataType; use arrow_array::new_empty_array; -use arrow_array::{Array, ArrayRef, FixedSizeListArray, RecordBatch, cast::AsArray}; +use arrow_array::{Array, ArrayRef, FixedSizeListArray, RecordBatch, UInt32Array, cast::AsArray}; use arrow_buffer::{Buffer, MutableBuffer}; use futures::StreamExt; use lance_arrow::DataTypeExt; @@ -86,6 +86,7 @@ async fn estimate_multivector_vectors_per_row( dataset: &Dataset, column: &str, num_rows: usize, + fragments: Option<&[u32]>, ) -> Result<usize> { if num_rows == 0 { return Ok(1030); } @@ -96,7 +97,9 @@ async fn estimate_multivector_vectors_per_row( // Try a few random samples first (fast path). let sample_batch_size = std::cmp::min(64, num_rows); for _ in 0..8 { - let batch = dataset.sample(sample_batch_size, &projection, None).await?; + let batch = dataset + .sample(sample_batch_size, &projection, fragments) + .await?; let array = get_column_from_batch(&batch, column)?; let list_array = array.as_list::<i32>(); for i in 0..list_array.len() { @@ -114,6 +117,9 @@ async fn estimate_multivector_vectors_per_row( // flakiness when values are extremely sparse. 
let mut scanner = dataset.scan(); scanner.project(&[column])?; + if let Some(fragments) = fragments { + scanner.with_fragments(resolve_scan_fragments(dataset, fragments)?); + } let column_expr = lance_datafusion::logical_expr::field_path_to_expr(column)?; scanner.filter_expr(column_expr.is_not_null()); scanner.limit(Some(std::cmp::min(num_rows, 1024) as i64), None)?; @@ -261,8 +267,15 @@ pub async fn maybe_sample_training_data( dataset: &Dataset, column: &str, sample_size_hint: usize, + fragment_ids: Option<&[u32]>, ) -> Result<FixedSizeListArray> { - let num_rows = dataset.count_rows(None).await?; + let num_rows = if let Some(fragment_ids) = fragment_ids { + let mut scanner = dataset.scan(); + scanner.with_fragments(resolve_scan_fragments(dataset, fragment_ids)?); + scanner.count_rows().await? as usize + } else { + dataset.count_rows(None).await? + }; let vector_field = dataset.schema().field(column).ok_or(Error::index(format!( "Sample training data: column {} does not exist in schema", @@ -291,7 +304,8 @@ pub async fn maybe_sample_training_data( // Set a minimum sample size of 128 to avoid too small samples, // it's not a problem because 128 multivectors is just about 64 MiB let vectors_per_row = - estimate_multivector_vectors_per_row(dataset, column, num_rows).await?; + estimate_multivector_vectors_per_row(dataset, column, num_rows, fragment_ids) + .await?; sample_size_hint.div_ceil(vectors_per_row).max(128) } _ => sample_size_hint, @@ -306,11 +320,12 @@ pub async fn maybe_sample_training_data( num_rows, vector_field, is_nullable, + fragment_ids, ) .await } else { // too small to require sampling - let batch = scan_all_training_data(dataset, column, is_nullable).await?; + let batch = scan_all_training_data(dataset, column, is_nullable, fragment_ids).await?; vector_column_to_fsl(&batch, column) } } @@ -378,9 +393,13 @@ async fn scan_all_training_data( dataset: &Dataset, column: &str, is_nullable: bool, + fragment_ids: Option<&[u32]>, ) -> Result<RecordBatch> { let mut scanner = dataset.scan(); 
scanner.project(&[column])?; + if let Some(fragment_ids) = fragment_ids { + scanner.with_fragments(resolve_scan_fragments(dataset, fragment_ids)?); + } if is_nullable { let column_expr = lance_datafusion::logical_expr::field_path_to_expr(column)?; scanner.filter_expr(column_expr.is_not_null()); } @@ -406,14 +425,38 @@ async fn sample_training_data( num_rows: usize, vector_field: &lance_core::datatypes::Field, is_nullable: bool, + fragment_ids: Option<&[u32]>, ) -> Result<FixedSizeListArray> { + if fragment_ids.is_some() { + if !is_nullable { + let projection = dataset.schema().project(&[column])?; + let batch = dataset + .sample(sample_size_hint, &projection, fragment_ids) + .await?; + return vector_column_to_fsl(&batch, column); + } + + let batch = scan_all_training_data(dataset, column, is_nullable, fragment_ids).await?; + let training_data = vector_column_to_fsl(&batch, column)?; + if training_data.len() <= sample_size_hint { + return Ok(training_data); + } + let indices = UInt32Array::from_iter_values( + generate_random_indices(training_data.len(), sample_size_hint) + .into_iter() + .map(|index| index as u32), + ); + let sampled = arrow_select::take::take(&training_data, &indices, None)?; + return Ok(sampled.as_fixed_size_list().clone()); + } + let byte_width = vector_field .data_type() .byte_width_opt() .unwrap_or(4 * 1024); match vector_field.data_type() { - DataType::FixedSizeList(_, _) if !is_nullable => { + DataType::FixedSizeList(_, _) if !is_nullable && fragment_ids.is_none() => { sample_fsl_uniform( dataset, column, @@ -454,6 +497,28 @@ fn sample_training_data_scan( )) } +fn resolve_scan_fragments( + dataset: &Dataset, + fragment_ids: &[u32], +) -> Result<Vec<Fragment>> { + let mut ordered_ids = fragment_ids.to_vec(); + ordered_ids.sort_unstable(); + let fragments = dataset.get_frags_from_ordered_ids(&ordered_ids); + if let Some(missing_id) = fragments + .iter() + .zip(ordered_ids.iter()) + .find_map(|(fragment, fragment_id)| fragment.is_none().then_some(*fragment_id)) + { + return 
Err(Error::invalid_input(format!( + "Unknown fragment id {missing_id} in training fragment filter" + ))); + } + Ok(fragments + .into_iter() + .map(|fragment| fragment.unwrap().metadata().clone()) + .collect()) +} + /// Build a FixedSizeListArray from raw flat value bytes. fn fsl_values_to_array( field: &lance_core::datatypes::Field, @@ -821,7 +886,7 @@ mod tests { .await .unwrap(); - let training_data = maybe_sample_training_data(&dataset, "mv", 1000) + let training_data = maybe_sample_training_data(&dataset, "mv", 1000, None) .await .unwrap(); assert_eq!(training_data.len(), 1000); @@ -897,7 +962,7 @@ mod tests { .await .unwrap(); - let training_data = maybe_sample_training_data(&dataset, "vec", sample_size) + let training_data = maybe_sample_training_data(&dataset, "vec", sample_size, None) .await .unwrap(); @@ -983,7 +1048,7 @@ mod tests { .await .unwrap(); - let n = estimate_multivector_vectors_per_row(&dataset, "mv", nrows) + let n = estimate_multivector_vectors_per_row(&dataset, "mv", nrows, None) .await .unwrap(); assert_eq!(n, 1030);