From a44159c264d71b7daa5c8d5c4089594e88c3272c Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 18 Sep 2025 20:47:09 +0800 Subject: [PATCH 01/12] return dists Signed-off-by: BubbleCal --- java/lance-jni/Cargo.lock | 1 + rust/lance-index/src/vector.rs | 4 ++-- rust/lance-index/src/vector/hnsw/index.rs | 4 ++-- rust/lance-index/src/vector/ivf.rs | 8 ++++++-- rust/lance-index/src/vector/ivf/storage.rs | 2 +- rust/lance-index/src/vector/kmeans.rs | 18 +++++++++++++----- rust/lance/src/index/vector/fixture_test.rs | 2 +- rust/lance/src/index/vector/ivf.rs | 3 ++- rust/lance/src/index/vector/ivf/v2.rs | 4 ++-- rust/lance/src/index/vector/pq.rs | 2 +- rust/lance/src/io/exec/knn.rs | 2 +- 11 files changed, 32 insertions(+), 18 deletions(-) diff --git a/java/lance-jni/Cargo.lock b/java/lance-jni/Cargo.lock index dfbb529019d..2501eae8981 100644 --- a/java/lance-jni/Cargo.lock +++ b/java/lance-jni/Cargo.lock @@ -3381,6 +3381,7 @@ dependencies = [ "lance-io", "lance-linalg", "lance-table", + "libm", "log", "num-traits", "object_store", diff --git a/rust/lance-index/src/vector.rs b/rust/lance-index/src/vector.rs index ee2f9b9f2f5..ddb38c37f5b 100644 --- a/rust/lance-index/src/vector.rs +++ b/rust/lance-index/src/vector.rs @@ -8,7 +8,7 @@ use std::any::Any; use std::fmt::Debug; use std::{collections::HashMap, sync::Arc}; -use arrow_array::{ArrayRef, RecordBatch, UInt32Array}; +use arrow_array::{ArrayRef, Float32Array, RecordBatch, UInt32Array}; use arrow_schema::Field; use async_trait::async_trait; use datafusion::execution::SendableRecordBatchStream; @@ -166,7 +166,7 @@ pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index { /// partitions to the query vector). /// /// The results should be in sorted order from closest to farthest. - fn find_partitions(&self, query: &Query) -> Result; + fn find_partitions(&self, query: &Query) -> Result<(UInt32Array, Float32Array)>; /// Get the total number of partitions in the index. 
fn total_partitions(&self) -> usize; diff --git a/rust/lance-index/src/vector/hnsw/index.rs b/rust/lance-index/src/vector/hnsw/index.rs index 8d1364e17a1..e17471b0382 100644 --- a/rust/lance-index/src/vector/hnsw/index.rs +++ b/rust/lance-index/src/vector/hnsw/index.rs @@ -8,7 +8,7 @@ use std::{ sync::Arc, }; -use arrow_array::{RecordBatch, UInt32Array}; +use arrow_array::{Float32Array, RecordBatch, UInt32Array}; use async_trait::async_trait; use datafusion::execution::SendableRecordBatchStream; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; @@ -186,7 +186,7 @@ impl VectorIndex for HNSWIndex { ) } - fn find_partitions(&self, _: &Query) -> Result { + fn find_partitions(&self, _: &Query) -> Result<(UInt32Array, Float32Array)> { unimplemented!("only for IVF") } diff --git a/rust/lance-index/src/vector/ivf.rs b/rust/lance-index/src/vector/ivf.rs index 15b9b086a0c..0b171849edf 100644 --- a/rust/lance-index/src/vector/ivf.rs +++ b/rust/lance-index/src/vector/ivf.rs @@ -6,7 +6,7 @@ use std::ops::Range; use std::sync::Arc; -use arrow_array::{Array, FixedSizeListArray, RecordBatch, UInt32Array}; +use arrow_array::{Array, FixedSizeListArray, Float32Array, RecordBatch, UInt32Array}; pub use builder::IvfBuildParams; use lance_core::Result; @@ -255,7 +255,11 @@ impl IvfTransformer { ) } - pub fn find_partitions(&self, query: &dyn Array, nprobes: usize) -> Result { + pub fn find_partitions( + &self, + query: &dyn Array, + nprobes: usize, + ) -> Result<(UInt32Array, Float32Array)> { Ok(kmeans_find_partitions_arrow_array( &self.centroids, query, diff --git a/rust/lance-index/src/vector/ivf/storage.rs b/rust/lance-index/src/vector/ivf/storage.rs index a9571dca366..3250967285b 100644 --- a/rust/lance-index/src/vector/ivf/storage.rs +++ b/rust/lance-index/src/vector/ivf/storage.rs @@ -107,7 +107,7 @@ impl IvfModel { query: &dyn Array, nprobes: usize, distance_type: DistanceType, - ) -> Result { + ) -> Result<(UInt32Array, Float32Array)> { let internal = 
crate::vector::ivf::new_ivf_transformer( self.centroids.clone().unwrap(), distance_type, diff --git a/rust/lance-index/src/vector/kmeans.rs b/rust/lance-index/src/vector/kmeans.rs index 40fbe691859..7a9083d10cf 100644 --- a/rust/lance-index/src/vector/kmeans.rs +++ b/rust/lance-index/src/vector/kmeans.rs @@ -1015,7 +1015,7 @@ pub fn kmeans_find_partitions_arrow_array( query: &dyn Array, nprobes: usize, distance_type: DistanceType, -) -> arrow::error::Result { +) -> arrow::error::Result<(UInt32Array, Float32Array)> { if centroids.value_length() as usize != query.len() { return Err(ArrowError::InvalidArgumentError(format!( "Centroids and vectors have different dimensions: {} != {}", @@ -1073,7 +1073,7 @@ pub fn kmeans_find_partitions( query: &[T], nprobes: usize, distance_type: DistanceType, -) -> arrow::error::Result { +) -> arrow::error::Result<(UInt32Array, Float32Array)> { let dists: Vec = match distance_type { DistanceType::L2 => l2_distance_batch(query, centroids, query.len()).collect(), DistanceType::Dot => dot_distance_batch(query, centroids, query.len()).collect(), @@ -1087,7 +1087,11 @@ pub fn kmeans_find_partitions( // TODO: use heap to just keep nprobes smallest values. let dists_arr = Float32Array::from(dists); - sort_to_indices(&dists_arr, None, Some(nprobes)) + let indices = sort_to_indices(&dists_arr, None, Some(nprobes))?; + let dists = arrow::compute::take(&dists_arr, &indices, None)? + .as_primitive::() + .clone(); + Ok((indices, dists)) } pub fn kmeans_find_partitions_binary( @@ -1095,7 +1099,7 @@ pub fn kmeans_find_partitions_binary( query: &[u8], nprobes: usize, distance_type: DistanceType, -) -> arrow::error::Result { +) -> arrow::error::Result<(UInt32Array, Float32Array)> { let dists: Vec = match distance_type { DistanceType::Hamming => hamming_distance_batch(query, centroids, query.len()).collect(), _ => { @@ -1108,7 +1112,11 @@ pub fn kmeans_find_partitions_binary( // TODO: use heap to just keep nprobes smallest values. 
let dists_arr = Float32Array::from(dists); - sort_to_indices(&dists_arr, None, Some(nprobes)) + let indices = sort_to_indices(&dists_arr, None, Some(nprobes))?; + let dists = arrow::compute::take(&dists_arr, &indices, None)? + .as_primitive::() + .clone(); + Ok((indices, dists)) } /// Compute partitions from Arrow FixedSizeListArray. diff --git a/rust/lance/src/index/vector/fixture_test.rs b/rust/lance/src/index/vector/fixture_test.rs index 8714dc15784..cd0c0e3a944 100644 --- a/rust/lance/src/index/vector/fixture_test.rs +++ b/rust/lance/src/index/vector/fixture_test.rs @@ -109,7 +109,7 @@ mod test { Ok(self.ret_val.clone()) } - fn find_partitions(&self, _: &Query) -> Result { + fn find_partitions(&self, _: &Query) -> Result<(UInt32Array, Float32Array)> { unimplemented!("only for IVF") } diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 151d249c084..13f0e2386a8 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -19,6 +19,7 @@ use crate::{ }; use arrow::datatypes::UInt8Type; use arrow_arith::numeric::sub; +use arrow_array::Float32Array; use arrow_array::{ cast::AsArray, types::{ArrowPrimitiveType, Float16Type, Float32Type, Float64Type}, @@ -923,7 +924,7 @@ impl VectorIndex for IVFIndex { /// Internal API with no stability guarantees. /// /// Assumes the query vector is normalized if the metric type is cosine. 
- fn find_partitions(&self, query: &Query) -> Result { + fn find_partitions(&self, query: &Query) -> Result<(UInt32Array, Float32Array)> { let mt = if self.metric_type == MetricType::Cosine { MetricType::L2 } else { diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index 0cfb32c63e9..7dd27fd8585 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -16,7 +16,7 @@ use crate::index::{ }; use arrow::compute::concat_batches; use arrow_arith::numeric::sub; -use arrow_array::{RecordBatch, UInt32Array}; +use arrow_array::{Float32Array, RecordBatch, UInt32Array}; use async_trait::async_trait; use datafusion::execution::SendableRecordBatchStream; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; @@ -447,7 +447,7 @@ impl VectorIndex for IVFInd unimplemented!("IVFIndex not currently used as sub-index and top-level indices do partition-aware search") } - fn find_partitions(&self, query: &Query) -> Result { + fn find_partitions(&self, query: &Query) -> Result<(UInt32Array, Float32Array)> { let dt = if self.distance_type == DistanceType::Cosine { DistanceType::L2 } else { diff --git a/rust/lance/src/index/vector/pq.rs b/rust/lance/src/index/vector/pq.rs index 4b82dc9e158..5c685b279e7 100644 --- a/rust/lance/src/index/vector/pq.rs +++ b/rust/lance/src/index/vector/pq.rs @@ -290,7 +290,7 @@ impl VectorIndex for PQIndex { .await } - fn find_partitions(&self, _: &Query) -> Result { + fn find_partitions(&self, _: &Query) -> Result<(UInt32Array, Float32Array)> { unimplemented!("only for IVF") } diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 4297a440a78..83b17e7be1a 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -491,7 +491,7 @@ impl ExecutionPlan for ANNIvfPartitionExec { metrics.partitions_ranked.add(index.total_partitions()); - let partitions = index.find_partitions(&query).map_err(|e| { + let (partitions, _dists) = 
index.find_partitions(&query).map_err(|e| { DataFusionError::Execution(format!("Failed to find partitions: {}", e)) })?; From c7e73358a818b4c80edc3532731c1caca90bdc5d Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Fri, 19 Sep 2025 17:17:43 +0800 Subject: [PATCH 02/12] perf: dynamic pruning for vector search Signed-off-by: BubbleCal --- rust/lance/src/io/exec/knn.rs | 56 ++++++++++++++++++++++++++++------- 1 file changed, 45 insertions(+), 11 deletions(-) diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 83b17e7be1a..6aed0642d92 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -7,6 +7,7 @@ use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::{Arc, LazyLock, Mutex}; use std::time::Instant; +use arrow::array::Float32Builder; use arrow::datatypes::{Float32Type, UInt32Type, UInt64Type}; use arrow_array::{ builder::{ListBuilder, UInt32Builder}, @@ -296,6 +297,11 @@ pub static KNN_PARTITION_SCHEMA: LazyLock = LazyLock::new(|| { false, ), Field::new(INDEX_UUID_COLUMN, DataType::Utf8, false), + Field::new( + DIST_COL, + DataType::List(Field::new_list_field(DataType::Float32, false).into()), + false, + ), ])) }); @@ -491,18 +497,26 @@ impl ExecutionPlan for ANNIvfPartitionExec { metrics.partitions_ranked.add(index.total_partitions()); - let (partitions, _dists) = index.find_partitions(&query).map_err(|e| { + let (partitions, dists) = index.find_partitions(&query).map_err(|e| { DataFusionError::Execution(format!("Failed to find partitions: {}", e)) })?; - let mut list_builder = ListBuilder::new(UInt32Builder::new()) + let mut part_id_builder = ListBuilder::new(UInt32Builder::new()) .with_field(Field::new("item", DataType::UInt32, false)); - list_builder.append_value(partitions.iter()); - let partition_col = list_builder.finish(); + part_id_builder.append_value(partitions.iter()); + let mut dist_builder = ListBuilder::new(Float32Builder::new()) + .with_field(Field::new("item", 
DataType::Float32, false)); + dist_builder.append_value(dists.iter()); + let part_id_col = part_id_builder.finish(); + let dist_col = dist_builder.finish(); let uuid_col = StringArray::from(vec![uuid.as_str()]); let batch = RecordBatch::try_new( KNN_PARTITION_SCHEMA.clone(), - vec![Arc::new(partition_col), Arc::new(uuid_col)], + vec![ + Arc::new(part_id_col), + Arc::new(uuid_col), + Arc::new(dist_col), + ], )?; metrics.baseline_metrics.record_output(batch.num_rows()); Ok::<_, DataFusionError>(batch) @@ -940,19 +954,25 @@ impl ExecutionPlan for ANNIvfSubIndexExec { .column_by_name(PART_ID_COLUMN) .expect("ANNSubIndexExec: input missing part_id column"); let part_id_arr = part_id_col.as_list::().clone(); + let dist_col = batch + .column_by_name(DIST_COL) + .expect("ANNSubIndexExec: input missing dist column"); + let dist_arr = dist_col.as_list::().clone(); let index_uuid_col = batch .column_by_name(INDEX_UUID_COLUMN) .expect("ANNSubIndexExec: input missing index_uuid column"); let index_uuid = index_uuid_col.as_string::().clone(); - let plan: Vec> = part_id_arr + let plan: Vec> = part_id_arr .iter() .zip(index_uuid.iter()) - .map(|(part_id, uuid)| { + .zip(dist_arr.iter()) + .map(|((part_id, uuid), dists)| { let partitions = Arc::new(part_id.unwrap().as_primitive::().clone()); let uuid = uuid.unwrap().to_string(); - Ok((partitions, uuid)) + let dists = dists.unwrap().as_primitive::().clone(); + Ok((partitions, uuid, Arc::new(dists))) }) .collect_vec(); async move { DataFusionResult::Ok(stream::iter(plan)) } @@ -981,14 +1001,17 @@ impl ExecutionPlan for ANNIvfSubIndexExec { Ok(Box::pin(RecordBatchStreamAdapter::new( schema, per_index_stream - .and_then(move |(part_ids, index_uuid)| { + .and_then(move |(part_ids, index_uuid, dists)| { let ds = ds.clone(); let column = column.clone(); let metrics = metrics.clone(); let pre_filter = pre_filter.clone(); let state = state.clone(); - let query = query.clone(); - + let mut query = query.clone(); + query.minimum_nprobes = 
std::cmp::min( + query.minimum_nprobes, + early_pruning(dists.values(), query.k), + ); async move { let raw_index = ds .open_vector_index(&column, &index_uuid, &metrics.index_metrics) @@ -1056,6 +1079,17 @@ impl ExecutionPlan for ANNIvfSubIndexExec { } } +fn early_pruning(dists: &[f32], k: usize) -> usize { + const PRUNING_FACTORS: [f32; 3] = [0.6, 7.0, 81.0]; + let factor = match k { + ..=1 => PRUNING_FACTORS[0], + 2..=10 => PRUNING_FACTORS[1], + 11.. => PRUNING_FACTORS[2], + }; + let dist_threshold = dists[0] * factor; + dists.partition_point(|dist| *dist <= dist_threshold) +} + #[derive(Debug)] pub struct MultivectorScoringExec { // the inputs are sorted ANN search results From 7326200815cda10096137c452ebb7ec0c4468635 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Fri, 19 Sep 2025 18:49:58 +0800 Subject: [PATCH 03/12] fmt Signed-off-by: BubbleCal --- rust/lance/src/index/vector/fixture_test.rs | 2 +- rust/lance/src/index/vector/ivf.rs | 2 +- rust/lance/src/session/index_extension.rs | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/rust/lance/src/index/vector/fixture_test.rs b/rust/lance/src/index/vector/fixture_test.rs index cd0c0e3a944..ed43afdd4f6 100644 --- a/rust/lance/src/index/vector/fixture_test.rs +++ b/rust/lance/src/index/vector/fixture_test.rs @@ -263,7 +263,7 @@ mod test { use_index: true, }; let idx = make_idx.clone()(expected_query_at_subindex, metric).await; - let partition_ids = idx.find_partitions(&q).unwrap(); + let (partition_ids, _) = idx.find_partitions(&q).unwrap(); assert_eq!(partition_ids.len(), 4); let nearest_partition_id = partition_ids.value(0); idx.search_in_partition( diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 13f0e2386a8..4a0f287387b 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -2071,7 +2071,7 @@ mod tests { metric_type: MetricType::L2, use_index: true, }; - let partitions = index.find_partitions(&query).unwrap(); + 
let (partitions, _) = index.find_partitions(&query).unwrap(); let nearest_partition_id = partitions.value(0) as usize; let search_result = index .search_in_partition( diff --git a/rust/lance/src/session/index_extension.rs b/rust/lance/src/session/index_extension.rs index 9502131318e..ce684f7b4cb 100644 --- a/rust/lance/src/session/index_extension.rs +++ b/rust/lance/src/session/index_extension.rs @@ -65,7 +65,7 @@ mod test { sync::{atomic::AtomicBool, Arc}, }; - use arrow_array::{RecordBatch, UInt32Array}; + use arrow_array::{Float32Array, RecordBatch, UInt32Array}; use arrow_schema::Schema; use datafusion::execution::SendableRecordBatchStream; use deepsize::DeepSizeOf; @@ -140,7 +140,7 @@ mod test { unimplemented!() } - fn find_partitions(&self, _: &Query) -> Result { + fn find_partitions(&self, _: &Query) -> Result<(UInt32Array, Float32Array)> { unimplemented!() } From db5f098effb66b71eac9b37f867061341ffe5ea4 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Fri, 19 Sep 2025 19:28:38 +0800 Subject: [PATCH 04/12] fix ut Signed-off-by: BubbleCal --- rust/lance/src/io/exec/knn.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 6aed0642d92..0b0579a980c 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -1304,7 +1304,7 @@ mod tests { use rstest::rstest; use tempfile::{tempdir, TempDir}; - use crate::dataset::{WriteMode, WriteParams}; + use crate::dataset::{ProjectionRequest, WriteMode, WriteParams}; use crate::index::vector::VectorIndexParams; use crate::io::exec::testing::TestingExec; @@ -1665,7 +1665,17 @@ mod tests { async fn test_no_prefilter_results(#[values(1, 20)] num_deltas: usize) { let fixture = NprobesTestFixture::new(100, num_deltas).await; - let q = fixture.get_centroid(0); + let q = fixture + .dataset + .take( + &[0], + ProjectionRequest::from_schema(fixture.dataset.schema().clone()), + ) + .await + .unwrap() + 
.column_by_name("vector") + .unwrap() + .clone(); let stats_holder = StatsHolder::default(); let results = fixture From c9c3d229c8a3c588383dd99868d491f2a97c3eb3 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 5 Nov 2025 19:38:36 +0800 Subject: [PATCH 05/12] fix Signed-off-by: BubbleCal --- rust/lance/src/io/exec/knn.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index d22ddfe5087..097f619f696 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -1033,7 +1033,7 @@ impl ExecutionPlan for ANNIvfSubIndexExec { let mut query = query.clone(); query.minimum_nprobes = std::cmp::min( query.minimum_nprobes, - early_pruning(dists.values(), query.k), + early_pruning(q_c_dists.values(), query.k), ); async move { let raw_index = ds From 4aa97e9b8d2b1089cd7cb322c6c728e25b8bbaa7 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Wed, 5 Nov 2025 20:36:56 +0800 Subject: [PATCH 06/12] fix Signed-off-by: BubbleCal --- rust/lance/src/io/exec/knn.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 097f619f696..c065c809c81 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -306,11 +306,6 @@ pub static KNN_PARTITION_SCHEMA: LazyLock = LazyLock::new(|| { false, ), Field::new(INDEX_UUID_COLUMN, DataType::Utf8, false), - Field::new( - DIST_COL, - DataType::List(Field::new_list_field(DataType::Float32, false).into()), - false, - ), ])) }); From b25aa20c5778b635b17d1ff699f8cf9f8874e991 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 6 Nov 2025 15:25:41 +0800 Subject: [PATCH 07/12] fmt Signed-off-by: BubbleCal --- java/lance-jni/src/blocking_scanner.rs | 2 +- java/lance-jni/src/utils.rs | 2 +- .../java/com/lancedb/lance/ipc/Query.java | 41 ++++-- python/src/dataset.rs | 37 +++-- rust/examples/src/ivf_hnsw.rs | 2 +- rust/lance-index/src/vector.rs | 4 +- 
rust/lance/benches/vector_index.rs | 6 +- rust/lance/src/dataset/scanner.rs | 51 +++++-- rust/lance/src/index/vector/fixture_test.rs | 2 +- rust/lance/src/index/vector/ivf.rs | 10 +- rust/lance/src/index/vector/ivf/v2.rs | 14 +- rust/lance/src/io/exec/knn.rs | 126 +++++++++++++----- 12 files changed, 217 insertions(+), 80 deletions(-) diff --git a/java/lance-jni/src/blocking_scanner.rs b/java/lance-jni/src/blocking_scanner.rs index c97cbbc170a..3be1fd0d75e 100644 --- a/java/lance-jni/src/blocking_scanner.rs +++ b/java/lance-jni/src/blocking_scanner.rs @@ -179,7 +179,7 @@ fn inner_create_scanner<'local>( let k = env.get_int_as_usize_from_method(&java_obj, "getK")?; let _ = scanner.nearest(&column, &key, k); - let minimum_nprobes = env.get_int_as_usize_from_method(&java_obj, "getMinimumNprobes")?; + let minimum_nprobes = env.get_optional_usize_from_method(&java_obj, "getMinimumNprobes")?; scanner.minimum_nprobes(minimum_nprobes); let maximum_nprobes = env.get_optional_usize_from_method(&java_obj, "getMaximumNprobes")?; diff --git a/java/lance-jni/src/utils.rs b/java/lance-jni/src/utils.rs index 495ab229f4b..6ec0c560f97 100644 --- a/java/lance-jni/src/utils.rs +++ b/java/lance-jni/src/utils.rs @@ -133,7 +133,7 @@ pub fn get_query(env: &mut JNIEnv, query_obj: JObject) -> Result> let key = Arc::new(Float32Array::from(key_array)); let k = env.get_int_as_usize_from_method(&java_obj, "getK")?; - let minimum_nprobes = env.get_int_as_usize_from_method(&java_obj, "getMinimumNprobes")?; + let minimum_nprobes = env.get_optional_usize_from_method(&java_obj, "getMinimumNprobes")?; let maximum_nprobes = env.get_optional_usize_from_method(&java_obj, "getMaximumNprobes")?; let ef = env.get_optional_usize_from_method(&java_obj, "getEf")?; diff --git a/java/src/main/java/com/lancedb/lance/ipc/Query.java b/java/src/main/java/com/lancedb/lance/ipc/Query.java index 46af8692c5f..d30f6cee5e5 100644 --- a/java/src/main/java/com/lancedb/lance/ipc/Query.java +++ 
b/java/src/main/java/com/lancedb/lance/ipc/Query.java @@ -25,7 +25,7 @@ public class Query { private final String column; private final float[] key; private final int k; - private final int minimumNprobes; + private final Optional minimumNprobes; private final Optional maximumNprobes; private final Optional ef; private final Optional refineFactor; @@ -38,10 +38,13 @@ private Query(Builder builder) { this.key = Preconditions.checkNotNull(builder.key, "Key must be set"); Preconditions.checkArgument(builder.k > 0, "K must be greater than 0"); Preconditions.checkArgument( - builder.minimumNprobes > 0, "Minimum Nprobes must be greater than 0"); + builder.minimumNprobes.map(n -> n > 0).orElse(true), + "Minimum Nprobes must be greater than 0"); Preconditions.checkArgument( !builder.maximumNprobes.isPresent() - || builder.maximumNprobes.get() >= builder.minimumNprobes, + || builder.minimumNprobes + .map(min -> builder.maximumNprobes.get() >= min) + .orElse(true), "Maximum Nprobes must be greater than minimum Nprobes"); this.k = builder.k; this.minimumNprobes = builder.minimumNprobes; @@ -64,7 +67,7 @@ public int getK() { return k; } - public int getMinimumNprobes() { + public Optional getMinimumNprobes() { return minimumNprobes; } @@ -94,7 +97,7 @@ public String toString() { .add("column", column) .add("key", key) .add("k", k) - .add("minimumNprobes", minimumNprobes) + .add("minimumNprobes", minimumNprobes.orElse(null)) .add("maximumNprobes", maximumNprobes.orElse(null)) .add("ef", ef.orElse(null)) .add("refineFactor", refineFactor.orElse(null)) @@ -107,7 +110,7 @@ public static class Builder { private String column; private float[] key; private int k = 10; - private int minimumNprobes = 20; + private Optional minimumNprobes = Optional.of(20); private Optional maximumNprobes = Optional.empty(); private Optional ef = Optional.empty(); private Optional refineFactor = Optional.empty(); @@ -157,7 +160,7 @@ public Builder setK(int k) { * @return The Builder instance for method 
chaining. */ public Builder setNprobes(int nprobes) { - this.minimumNprobes = nprobes; + this.minimumNprobes = Optional.of(nprobes); this.maximumNprobes = Optional.of(nprobes); return this; } @@ -172,7 +175,29 @@ public Builder setNprobes(int nprobes) { * @return The Builder instance for method chaining. */ public Builder setMinimumNprobes(int minimumNprobes) { - this.minimumNprobes = minimumNprobes; + this.minimumNprobes = Optional.of(minimumNprobes); + return this; + } + + /** + * Sets the minimum number of partitions to search. + * + * @param minimumNprobes The optional number of partitions to search. + * @return The Builder instance for method chaining. + */ + public Builder setMinimumNprobes(Optional minimumNprobes) { + this.minimumNprobes = + Preconditions.checkNotNull(minimumNprobes, "minimumNprobes must not be null"); + return this; + } + + /** + * Clears any previously configured minimum number of partitions to search. + * + * @return The Builder instance for method chaining. + */ + public Builder clearMinimumNprobes() { + this.minimumNprobes = Optional.empty(); return this; } diff --git a/python/src/dataset.rs b/python/src/dataset.rs index eb67b60e12b..4d7222bd322 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -981,19 +981,22 @@ impl Dataset { 10 }; - let mut minimum_nprobes = DEFAULT_NPROBES; + let mut minimum_nprobes = Some(DEFAULT_NPROBES); let mut maximum_nprobes = None; if let Some(nprobes) = nearest.get_item("nprobes")? { if !nprobes.is_none() { - minimum_nprobes = nprobes.extract()?; - maximum_nprobes = Some(minimum_nprobes); + let extracted: usize = nprobes.extract()?; + minimum_nprobes = Some(extracted); + maximum_nprobes = Some(extracted); } } if let Some(min_nprobes) = nearest.get_item("minimum_nprobes")? 
{ - if !min_nprobes.is_none() { - minimum_nprobes = min_nprobes.extract()?; + if min_nprobes.is_none() { + minimum_nprobes = None; + } else { + minimum_nprobes = Some(min_nprobes.extract()?); } } @@ -1003,18 +1006,26 @@ impl Dataset { } } - if minimum_nprobes > maximum_nprobes.unwrap_or(usize::MAX) { - return Err(PyValueError::new_err( - "minimum_nprobes must be <= maximum_nprobes", - )); + if let (Some(minimum_nprobes), Some(maximum_nprobes)) = + (minimum_nprobes, maximum_nprobes) + { + if minimum_nprobes > maximum_nprobes { + return Err(PyValueError::new_err( + "minimum_nprobes must be <= maximum_nprobes", + )); + } } - if minimum_nprobes < 1 { - return Err(PyValueError::new_err("minimum_nprobes must be >= 1")); + if let Some(minimum_nprobes) = minimum_nprobes { + if minimum_nprobes < 1 { + return Err(PyValueError::new_err("minimum_nprobes must be >= 1")); + } } - if maximum_nprobes.unwrap_or(usize::MAX) < 1 { - return Err(PyValueError::new_err("maximum_nprobes must be >= 1")); + if let Some(maximum_nprobes) = maximum_nprobes { + if maximum_nprobes < 1 { + return Err(PyValueError::new_err("maximum_nprobes must be >= 1")); + } } let metric_type: Option = diff --git a/rust/examples/src/ivf_hnsw.rs b/rust/examples/src/ivf_hnsw.rs index 34bd4cbca7f..9e9aa4910b8 100644 --- a/rust/examples/src/ivf_hnsw.rs +++ b/rust/examples/src/ivf_hnsw.rs @@ -117,7 +117,7 @@ async fn main() { .with_row_id() .nearest(column, &q, args.k) .unwrap() - .minimum_nprobes(args.nprobe); + .minimum_nprobes(Some(args.nprobe)); println!("{:?}", plan.explain_plan(true).await.unwrap()); let now = std::time::Instant::now(); diff --git a/rust/lance-index/src/vector.rs b/rust/lance-index/src/vector.rs index 9f472206c88..200e1b13dc1 100644 --- a/rust/lance-index/src/vector.rs +++ b/rust/lance-index/src/vector.rs @@ -87,7 +87,9 @@ pub struct Query { /// The minimum number of probes to load and search. More partitions /// will only be loaded if we have not found k results. 
- pub minimum_nprobes: usize, + /// + /// If None, the planner will decide how many partitions to search first. + pub minimum_nprobes: Option, /// The maximum number of probes to load and search. If not set then /// ALL partitions will be searched, if needed, to satisfy k results. diff --git a/rust/lance/benches/vector_index.rs b/rust/lance/benches/vector_index.rs index e20febfd2fb..a39c034c6c0 100644 --- a/rust/lance/benches/vector_index.rs +++ b/rust/lance/benches/vector_index.rs @@ -56,7 +56,7 @@ fn bench_ivf_pq_index(c: &mut Criterion) { .scan() .nearest("vector", q, 10) .unwrap() - .minimum_nprobes(10) + .minimum_nprobes(Some(10)) .try_into_stream() .await .unwrap() @@ -76,7 +76,7 @@ fn bench_ivf_pq_index(c: &mut Criterion) { .scan() .nearest("vector", q, 10) .unwrap() - .minimum_nprobes(10) + .minimum_nprobes(Some(10)) .refine(2) .try_into_stream() .await @@ -110,7 +110,7 @@ fn bench_ivf_pq_index(c: &mut Criterion) { .scan() .nearest("vector", q, 10) .unwrap() - .minimum_nprobes(32) + .minimum_nprobes(Some(32)) .try_into_stream() .await .unwrap() diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 7aa8e379a0e..df6b39d680d 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -1047,7 +1047,7 @@ impl Scanner { k, lower_bound: None, upper_bound: None, - minimum_nprobes: 20, + minimum_nprobes: Some(20), maximum_nprobes: None, ef: None, refine_factor: None, @@ -1082,7 +1082,7 @@ impl Scanner { /// [Self::maximum_nprobes] to the same value. 
pub fn nprobes(&mut self, n: usize) -> &mut Self { if let Some(q) = self.nearest.as_mut() { - q.minimum_nprobes = n; + q.minimum_nprobes = Some(n); q.maximum_nprobes = Some(n); } else { log::warn!("nprobes is not set because nearest has not been called yet"); @@ -1097,7 +1097,7 @@ impl Scanner { #[deprecated(note = "Use nprobes instead")] pub fn nprobs(&mut self, n: usize) -> &mut Self { if let Some(q) = self.nearest.as_mut() { - q.minimum_nprobes = n; + q.minimum_nprobes = Some(n); q.maximum_nprobes = Some(n); } else { log::warn!("nprobes is not set because nearest has not been called yet"); @@ -1110,7 +1110,10 @@ impl Scanner { /// If we have found k matching results after searching this many partitions then /// the search will stop. Increasing this number can increase recall but will increase /// latency on all queries. - pub fn minimum_nprobes(&mut self, n: usize) -> &mut Self { + /// + /// Passing [`None`] clears any previously configured minimum which allows the planner to + /// determine an appropriate value dynamically. 
+ pub fn minimum_nprobes(&mut self, n: Option) -> &mut Self { if let Some(q) = self.nearest.as_mut() { q.minimum_nprobes = n; } else { @@ -3978,6 +3981,38 @@ mod test { } } + #[tokio::test] + async fn test_minimum_nprobes_can_be_cleared() { + let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false) + .await + .unwrap(); + test_ds.make_vector_index().await.unwrap(); + + let vector_field = test_ds + .schema + .field_with_name("vec") + .expect("vector field must exist"); + let dimension = match vector_field.data_type() { + DataType::FixedSizeList(_, size) => *size as usize, + _ => panic!("expected fixed size list for vector field"), + }; + + let query = std::sync::Arc::new(arrow_array::Float32Array::from(vec![0.0f32; dimension])); + + let mut scanner = test_ds.dataset.scan(); + scanner.nearest("vec", query.as_ref(), 5).unwrap(); + scanner.minimum_nprobes(Some(3)); + scanner.minimum_nprobes(None); + + assert!(scanner + .nearest_mut() + .expect("nearest query should be configured") + .minimum_nprobes + .is_none()); + + scanner.try_into_stream().await.unwrap(); + } + #[tokio::test] async fn test_strict_batch_size() { let dataset = lance_datagen::gen_batch() @@ -5167,7 +5202,7 @@ mod test { let mut scan = dataset.scan(); scan.filter("filterable > 5").unwrap(); scan.nearest("vector", query_key.as_ref(), 1).unwrap(); - scan.minimum_nprobes(100); + scan.minimum_nprobes(Some(100)); scan.with_row_id(); let batches = scan @@ -5471,7 +5506,7 @@ mod test { let key: Float32Array = (0..32).map(|_v| 1.0_f32).collect(); scan.nearest("vec", &key, 5).unwrap(); scan.refine(100); - scan.minimum_nprobes(100); + scan.minimum_nprobes(Some(100)); assert_eq!( dataset.index_cache_entry_count().await, @@ -5507,7 +5542,7 @@ mod test { let mut scan = dataset.scan(); scan.nearest("vec", &key, 5).unwrap(); scan.refine(100); - scan.minimum_nprobes(100); + scan.minimum_nprobes(Some(100)); let results = scan .try_into_stream() @@ -5569,7 +5604,7 @@ mod test { let mut scan = 
dataset.scan(); scan.nearest("vec", &key, 5).unwrap(); scan.refine(100); - scan.minimum_nprobes(100); + scan.minimum_nprobes(Some(100)); let results = scan .try_into_stream() diff --git a/rust/lance/src/index/vector/fixture_test.rs b/rust/lance/src/index/vector/fixture_test.rs index 0ec68319121..519c2510b24 100644 --- a/rust/lance/src/index/vector/fixture_test.rs +++ b/rust/lance/src/index/vector/fixture_test.rs @@ -255,7 +255,7 @@ mod test { k: 1, lower_bound: None, upper_bound: None, - minimum_nprobes: 1, + minimum_nprobes: Some(1), maximum_nprobes: None, ef: None, refine_factor: None, diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index f9652eb5048..8ae86db8529 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -2153,7 +2153,7 @@ mod tests { k: 5, lower_bound: None, upper_bound: None, - minimum_nprobes: 1, + minimum_nprobes: Some(1), maximum_nprobes: None, ef: None, refine_factor: None, @@ -2658,7 +2658,7 @@ mod tests { .nearest("vec", &query, 2_000) .unwrap() .ef(100_000) - .minimum_nprobes(2) + .minimum_nprobes(Some(2)) .try_into_batch() .await .unwrap(); @@ -2703,7 +2703,7 @@ mod tests { .scan() .nearest("vec", &query, 2_000) .unwrap() - .minimum_nprobes(2) + .minimum_nprobes(Some(2)) .try_into_batch() .await .unwrap(); @@ -3125,7 +3125,7 @@ mod tests { .with_row_id() .nearest("vector", query, k) .unwrap() - .minimum_nprobes(nlist) + .minimum_nprobes(Some(nlist)) .try_into_stream() .await .unwrap() @@ -3207,7 +3207,7 @@ mod tests { .with_row_id() .nearest("vector", query, k) .unwrap() - .minimum_nprobes(nlist) + .minimum_nprobes(Some(nlist)) .try_into_stream() .await .unwrap() diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index f6ea14b107e..480720f6bbb 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -1135,7 +1135,7 @@ mod tests { .scan() .nearest(vector_column, query.as_primitive::(), 
100) .unwrap() - .minimum_nprobes(nlist) + .minimum_nprobes(Some(nlist)) .with_row_id() .try_into_batch() .await @@ -1166,7 +1166,7 @@ mod tests { .scan() .nearest(vector_column, query.as_primitive::(), 100) .unwrap() - .minimum_nprobes(nlist) + .minimum_nprobes(Some(nlist)) .with_row_id() .try_into_batch() .await @@ -1539,7 +1539,7 @@ mod tests { .scan() .nearest("vector", &query, k) .unwrap() - .minimum_nprobes(nlist) + .minimum_nprobes(Some(nlist)) .with_row_id() .try_into_batch() .await @@ -1768,7 +1768,7 @@ mod tests { .scan() .nearest(vector_column, query.as_primitive::(), k) .unwrap() - .minimum_nprobes(nlist) + .minimum_nprobes(Some(nlist)) .ef(100) .with_row_id() .try_into_batch() @@ -1785,7 +1785,7 @@ mod tests { .scan() .nearest(vector_column, query.as_primitive::(), part_idx) .unwrap() - .minimum_nprobes(nlist) + .minimum_nprobes(Some(nlist)) .ef(100) .with_row_id() .distance_range(None, Some(part_dist)) @@ -1796,7 +1796,7 @@ mod tests { .scan() .nearest(vector_column, query.as_primitive::(), k - part_idx) .unwrap() - .minimum_nprobes(nlist) + .minimum_nprobes(Some(nlist)) .ef(100) .with_row_id() .distance_range(Some(part_dist), None) @@ -1831,7 +1831,7 @@ mod tests { .scan() .nearest(vector_column, query.as_primitive::(), k) .unwrap() - .minimum_nprobes(nlist) + .minimum_nprobes(Some(nlist)) .ef(100) .with_row_id() .distance_range(dists.first().copied(), dists.last().copied()) diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index c065c809c81..e9a623c859e 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -403,21 +403,31 @@ impl DisplayAs for ANNIvfPartitionExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { + let min_display = self + .query + .minimum_nprobes + .map(|n| n.to_string()) + .unwrap_or_else(|| "None".to_string()); write!( f, "ANNIvfPartition: uuid={}, 
minimum_nprobes={}, maximum_nprobes={:?}, deltas={}", self.index_uuids[0], - self.query.minimum_nprobes, + min_display, self.query.maximum_nprobes, self.index_uuids.len() ) } DisplayFormatType::TreeRender => { + let min_display = self + .query + .minimum_nprobes + .map(|n| n.to_string()) + .unwrap_or_else(|| "None".to_string()); write!( f, "ANNIvfPartition\nuuid={}\nminimum_nprobes={}\nmaximum_nprobes={:?}\ndeltas={}", self.index_uuids[0], - self.query.minimum_nprobes, + min_display, self.query.maximum_nprobes, self.index_uuids.len() ) @@ -441,7 +451,7 @@ impl ExecutionPlan for ANNIvfPartitionExec { fn statistics(&self) -> DataFusionResult { Ok(Statistics { - num_rows: Precision::Exact(self.query.minimum_nprobes), + num_rows: Precision::Exact(self.query.minimum_nprobes.unwrap_or(0)), ..Statistics::new_unknown(self.schema().as_ref()) }) } @@ -721,8 +731,12 @@ impl ANNIvfSubIndexExec { state: Arc, ) -> impl Stream> { let stream = futures::stream::once(async move { - let max_nprobes = query.maximum_nprobes.unwrap_or(partitions.len()); - if max_nprobes == query.minimum_nprobes { + let max_nprobes = query + .maximum_nprobes + .unwrap_or(partitions.len()) + .min(partitions.len()); + let min_nprobes = query.minimum_nprobes.unwrap_or(0).min(max_nprobes); + if max_nprobes <= min_nprobes { // We've already searched all partitions, no late search needed return futures::stream::empty().boxed(); } @@ -784,7 +798,7 @@ impl ANNIvfSubIndexExec { let state_clone = state.clone(); - futures::stream::iter(query.minimum_nprobes..max_nprobes) + futures::stream::iter(min_nprobes..max_nprobes) .map(move |idx| { let part_id = partitions.value(idx); let mut query = query.clone(); @@ -839,7 +853,7 @@ impl ANNIvfSubIndexExec { metrics: Arc, state: Arc, ) -> impl Stream> { - let minimum_nprobes = query.minimum_nprobes.min(partitions.len()); + let minimum_nprobes = query.minimum_nprobes.unwrap_or(0).min(partitions.len()); metrics.partitions_searched.add(minimum_nprobes); 
futures::stream::iter(0..minimum_nprobes) @@ -1026,10 +1040,8 @@ impl ExecutionPlan for ANNIvfSubIndexExec { let pre_filter = pre_filter.clone(); let state = state.clone(); let mut query = query.clone(); - query.minimum_nprobes = std::cmp::min( - query.minimum_nprobes, - early_pruning(q_c_dists.values(), query.k), - ); + let pruned_nprobes = early_pruning(q_c_dists.values(), query.k); + adjust_probes(&mut query, pruned_nprobes); async move { let raw_index = ds .open_vector_index(&column, &index_uuid, &metrics.index_metrics) @@ -1103,6 +1115,22 @@ impl ExecutionPlan for ANNIvfSubIndexExec { } } +fn adjust_probes(query: &mut Query, pruned_nprobes: usize) { + let minimum = query + .minimum_nprobes + .map(|current| current.max(pruned_nprobes)) + .unwrap_or(pruned_nprobes); + let mut maximum = query + .maximum_nprobes + .map(|current| current.min(pruned_nprobes)) + .unwrap_or(pruned_nprobes); + if minimum > maximum { + maximum = minimum; + } + query.minimum_nprobes = Some(minimum); + query.maximum_nprobes = Some(maximum); +} + fn early_pruning(dists: &[f32], k: usize) -> usize { const PRUNING_FACTORS: [f32; 3] = [0.6, 7.0, 81.0]; let factor = match k { @@ -1319,7 +1347,9 @@ mod tests { use arrow::compute::{concat_batches, sort_to_indices, take_record_batch}; use arrow::datatypes::Float32Type; - use arrow_array::{FixedSizeListArray, Int32Array, RecordBatchIterator, StringArray}; + use arrow_array::{ + ArrayRef, FixedSizeListArray, Float32Array, Int32Array, RecordBatchIterator, StringArray, + }; use arrow_schema::{Field as ArrowField, Schema as ArrowSchema}; use lance_core::utils::tempfile::TempStrDir; use lance_datafusion::exec::{ExecutionStatsCallback, ExecutionSummaryCounts}; @@ -1332,10 +1362,54 @@ mod tests { use lance_testing::datagen::generate_random_array; use rstest::rstest; - use crate::dataset::{ProjectionRequest, WriteMode, WriteParams}; + use crate::dataset::{WriteMode, WriteParams}; use crate::index::vector::VectorIndexParams; use 
crate::io::exec::testing::TestingExec; + fn base_query() -> Query { + Query { + column: "vec".to_string(), + key: Arc::new(Float32Array::from(vec![0.0f32])) as ArrayRef, + k: 10, + lower_bound: None, + upper_bound: None, + minimum_nprobes: None, + maximum_nprobes: None, + ef: None, + refine_factor: None, + metric_type: DistanceType::L2, + use_index: true, + dist_q_c: 0.0, + } + } + + #[test] + fn test_adjust_probes_rules() { + let mut query = base_query(); + adjust_probes(&mut query, 10); + assert_eq!(query.minimum_nprobes, Some(10)); + assert_eq!(query.maximum_nprobes, Some(10)); + + let mut query = base_query(); + query.minimum_nprobes = Some(20); + adjust_probes(&mut query, 10); + assert_eq!(query.minimum_nprobes, Some(20)); + assert_eq!(query.maximum_nprobes, Some(20)); + + let mut query = base_query(); + query.maximum_nprobes = Some(25); + adjust_probes(&mut query, 10); + assert_eq!(query.minimum_nprobes, Some(10)); + assert_eq!(query.maximum_nprobes, Some(10)); + + let mut query = base_query(); + query.minimum_nprobes = Some(30); + query.maximum_nprobes = Some(50); + adjust_probes(&mut query, 10); + assert_eq!(query.minimum_nprobes, Some(30)); + assert_eq!(query.maximum_nprobes, Some(30)); + } + #[tokio::test] async fn knn_flat_search() { let schema = Arc::new(ArrowSchema::new(vec![ @@ -1470,7 +1544,7 @@ mod tests { k: 10, lower_bound: None, upper_bound: None, - minimum_nprobes: 1, + minimum_nprobes: Some(1), maximum_nprobes: None, ef: None, refine_factor: None, @@ -1665,7 +1739,7 @@ mod tests { .scan() .nearest("vector", q.as_ref(), 50) .unwrap() - .minimum_nprobes(10) + .minimum_nprobes(Some(10)) .prefilter(true) .scan_stats_callback(stats_holder.get_setter()) .filter("label = 17") @@ -1694,17 +1768,7 @@ mod tests { async fn test_no_prefilter_results(#[values(1, 20)] num_deltas: usize) { let fixture = NprobesTestFixture::new(100, num_deltas).await; - let q = fixture - .dataset - .take( - &[0], - 
ProjectionRequest::from_schema(fixture.dataset.schema().clone()), - ) - .await - .unwrap() - .column_by_name("vector") - .unwrap() - .clone(); + let q = fixture.get_centroid(0); let stats_holder = StatsHolder::default(); let results = fixture @@ -1712,7 +1776,7 @@ mod tests { .scan() .nearest("vector", q.as_ref(), 50) .unwrap() - .minimum_nprobes(10) + .minimum_nprobes(Some(10)) .prefilter(true) .scan_stats_callback(stats_holder.get_setter()) .filter("label = 17 AND label = 18") @@ -1748,7 +1812,7 @@ mod tests { .scan() .nearest("vector", q.as_ref(), 50) .unwrap() - .minimum_nprobes(10) + .minimum_nprobes(Some(10)) .maximum_nprobes(max_nprobes) .prefilter(true) .filter("label = 17") @@ -1787,7 +1851,7 @@ mod tests { .scan() .nearest("vector", q.as_ref(), 50) .unwrap() - .minimum_nprobes(10) + .minimum_nprobes(Some(10)) .prefilter(true) .filter("userid < 20") .unwrap() @@ -1826,7 +1890,7 @@ mod tests { .scan() .nearest("vector", q.as_ref(), 50) .unwrap() - .minimum_nprobes(10) + .minimum_nprobes(Some(10)) .prefilter(true) .refine(1) .filter("userid < 20") @@ -1863,7 +1927,7 @@ mod tests { .scan() .nearest("vector", q.as_ref(), 40000) .unwrap() - .minimum_nprobes(10) + .minimum_nprobes(Some(10)) .scan_stats_callback(stats_holder.get_setter()) .project(&Vec::::new()) .unwrap() From 6d832ef95dccc730a0e1e9b7a487666b2823d6a9 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 6 Nov 2025 16:49:13 +0800 Subject: [PATCH 08/12] fix Signed-off-by: BubbleCal --- java/lance-jni/src/blocking_scanner.rs | 2 +- java/lance-jni/src/utils.rs | 2 +- .../java/com/lancedb/lance/ipc/Query.java | 41 ++------- python/python/lance/dataset.py | 4 +- python/src/dataset.rs | 22 ++--- rust/examples/src/ivf_hnsw.rs | 2 +- rust/lance-index/src/vector.rs | 4 +- rust/lance/benches/vector_index.rs | 6 +- rust/lance/src/dataset/scanner.rs | 85 ++++++------------- rust/lance/src/index/vector/fixture_test.rs | 2 +- rust/lance/src/index/vector/ivf.rs | 10 +-- rust/lance/src/index/vector/ivf/v2.rs | 14 
+-- rust/lance/src/io/exec/knn.rs | 81 ++++++++---------- 13 files changed, 100 insertions(+), 175 deletions(-) diff --git a/java/lance-jni/src/blocking_scanner.rs b/java/lance-jni/src/blocking_scanner.rs index 3be1fd0d75e..c97cbbc170a 100644 --- a/java/lance-jni/src/blocking_scanner.rs +++ b/java/lance-jni/src/blocking_scanner.rs @@ -179,7 +179,7 @@ fn inner_create_scanner<'local>( let k = env.get_int_as_usize_from_method(&java_obj, "getK")?; let _ = scanner.nearest(&column, &key, k); - let minimum_nprobes = env.get_optional_usize_from_method(&java_obj, "getMinimumNprobes")?; + let minimum_nprobes = env.get_int_as_usize_from_method(&java_obj, "getMinimumNprobes")?; scanner.minimum_nprobes(minimum_nprobes); let maximum_nprobes = env.get_optional_usize_from_method(&java_obj, "getMaximumNprobes")?; diff --git a/java/lance-jni/src/utils.rs b/java/lance-jni/src/utils.rs index 6ec0c560f97..495ab229f4b 100644 --- a/java/lance-jni/src/utils.rs +++ b/java/lance-jni/src/utils.rs @@ -133,7 +133,7 @@ pub fn get_query(env: &mut JNIEnv, query_obj: JObject) -> Result> let key = Arc::new(Float32Array::from(key_array)); let k = env.get_int_as_usize_from_method(&java_obj, "getK")?; - let minimum_nprobes = env.get_optional_usize_from_method(&java_obj, "getMinimumNprobes")?; + let minimum_nprobes = env.get_int_as_usize_from_method(&java_obj, "getMinimumNprobes")?; let maximum_nprobes = env.get_optional_usize_from_method(&java_obj, "getMaximumNprobes")?; let ef = env.get_optional_usize_from_method(&java_obj, "getEf")?; diff --git a/java/src/main/java/com/lancedb/lance/ipc/Query.java b/java/src/main/java/com/lancedb/lance/ipc/Query.java index d30f6cee5e5..07b573dc494 100644 --- a/java/src/main/java/com/lancedb/lance/ipc/Query.java +++ b/java/src/main/java/com/lancedb/lance/ipc/Query.java @@ -25,7 +25,7 @@ public class Query { private final String column; private final float[] key; private final int k; - private final Optional minimumNprobes; + private final int minimumNprobes; private 
final Optional maximumNprobes; private final Optional ef; private final Optional refineFactor; @@ -38,13 +38,10 @@ private Query(Builder builder) { this.key = Preconditions.checkNotNull(builder.key, "Key must be set"); Preconditions.checkArgument(builder.k > 0, "K must be greater than 0"); Preconditions.checkArgument( - builder.minimumNprobes.map(n -> n > 0).orElse(true), - "Minimum Nprobes must be greater than 0"); + builder.minimumNprobes > 0, "Minimum Nprobes must be greater than 0"); Preconditions.checkArgument( !builder.maximumNprobes.isPresent() - || builder.minimumNprobes - .map(min -> builder.maximumNprobes.get() >= min) - .orElse(true), + || builder.maximumNprobes.get() >= builder.minimumNprobes, "Maximum Nprobes must be greater than minimum Nprobes"); this.k = builder.k; this.minimumNprobes = builder.minimumNprobes; @@ -67,7 +64,7 @@ public int getK() { return k; } - public Optional getMinimumNprobes() { + public int getMinimumNprobes() { return minimumNprobes; } @@ -97,7 +94,7 @@ public String toString() { .add("column", column) .add("key", key) .add("k", k) - .add("minimumNprobes", minimumNprobes.orElse(null)) + .add("minimumNprobes", minimumNprobes) .add("maximumNprobes", maximumNprobes.orElse(null)) .add("ef", ef.orElse(null)) .add("refineFactor", refineFactor.orElse(null)) @@ -110,7 +107,7 @@ public static class Builder { private String column; private float[] key; private int k = 10; - private Optional minimumNprobes = Optional.of(20); + private int minimumNprobes = 1; private Optional maximumNprobes = Optional.empty(); private Optional ef = Optional.empty(); private Optional refineFactor = Optional.empty(); @@ -160,7 +157,7 @@ public Builder setK(int k) { * @return The Builder instance for method chaining. 
*/ public Builder setNprobes(int nprobes) { - this.minimumNprobes = Optional.of(nprobes); + this.minimumNprobes = nprobes; this.maximumNprobes = Optional.of(nprobes); return this; } @@ -175,29 +172,7 @@ public Builder setNprobes(int nprobes) { * @return The Builder instance for method chaining. */ public Builder setMinimumNprobes(int minimumNprobes) { - this.minimumNprobes = Optional.of(minimumNprobes); - return this; - } - - /** - * Sets the minimum number of partitions to search. - * - * @param minimumNprobes The optional number of partitions to search. - * @return The Builder instance for method chaining. - */ - public Builder setMinimumNprobes(Optional minimumNprobes) { - this.minimumNprobes = - Preconditions.checkNotNull(minimumNprobes, "minimumNprobes must not be null"); - return this; - } - - /** - * Clears any previously configured minimum number of partitions to search. - * - * @return The Builder instance for method chaining. - */ - public Builder clearMinimumNprobes() { - this.minimumNprobes = Optional.empty(); + this.minimumNprobes = minimumNprobes; return this; } diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 7ef7f9e06ef..8956587d78a 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -717,7 +717,7 @@ def scanner( "column": , "q": , "k": 10, - "minimum_nprobes": 20, + "minimum_nprobes": 1, "maximum_nprobes": 50, "refine_factor": 1 } @@ -980,7 +980,7 @@ def to_table( "q": , "k": 10, "metric": "cosine", - "minimum_nprobes": 20, + "minimum_nprobes": 1, "maximum_nprobes": 50, "refine_factor": 1 } diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 4d7222bd322..06cd5965899 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -103,7 +103,7 @@ pub mod commit; pub mod optimize; pub mod stats; -const DEFAULT_NPROBES: usize = 20; +const DEFAULT_NPROBES: usize = 1; const LANCE_COMMIT_MESSAGE_KEY: &str = "__lance_commit_message"; fn convert_reader(reader: &Bound) -> 
PyResult> { @@ -981,22 +981,20 @@ impl Dataset { 10 }; - let mut minimum_nprobes = Some(DEFAULT_NPROBES); + let mut minimum_nprobes = DEFAULT_NPROBES; let mut maximum_nprobes = None; if let Some(nprobes) = nearest.get_item("nprobes")? { if !nprobes.is_none() { let extracted: usize = nprobes.extract()?; - minimum_nprobes = Some(extracted); + minimum_nprobes = extracted; maximum_nprobes = Some(extracted); } } if let Some(min_nprobes) = nearest.get_item("minimum_nprobes")? { - if min_nprobes.is_none() { - minimum_nprobes = None; - } else { - minimum_nprobes = Some(min_nprobes.extract()?); + if !min_nprobes.is_none() { + minimum_nprobes = min_nprobes.extract()?; } } @@ -1006,9 +1004,7 @@ impl Dataset { } } - if let (Some(minimum_nprobes), Some(maximum_nprobes)) = - (minimum_nprobes, maximum_nprobes) - { + if let Some(maximum_nprobes) = maximum_nprobes { if minimum_nprobes > maximum_nprobes { return Err(PyValueError::new_err( "minimum_nprobes must be <= maximum_nprobes", @@ -1016,10 +1012,8 @@ impl Dataset { } } - if let Some(minimum_nprobes) = minimum_nprobes { - if minimum_nprobes < 1 { - return Err(PyValueError::new_err("minimum_nprobes must be >= 1")); - } + if minimum_nprobes < 1 { + return Err(PyValueError::new_err("minimum_nprobes must be >= 1")); } if let Some(maximum_nprobes) = maximum_nprobes { diff --git a/rust/examples/src/ivf_hnsw.rs b/rust/examples/src/ivf_hnsw.rs index 9e9aa4910b8..34bd4cbca7f 100644 --- a/rust/examples/src/ivf_hnsw.rs +++ b/rust/examples/src/ivf_hnsw.rs @@ -117,7 +117,7 @@ async fn main() { .with_row_id() .nearest(column, &q, args.k) .unwrap() - .minimum_nprobes(Some(args.nprobe)); + .minimum_nprobes(args.nprobe); println!("{:?}", plan.explain_plan(true).await.unwrap()); let now = std::time::Instant::now(); diff --git a/rust/lance-index/src/vector.rs b/rust/lance-index/src/vector.rs index 200e1b13dc1..4377675202e 100644 --- a/rust/lance-index/src/vector.rs +++ b/rust/lance-index/src/vector.rs @@ -88,8 +88,8 @@ pub struct Query { /// The 
minimum number of probes to load and search. More partitions /// will only be loaded if we have not found k results. /// - /// If None, the planner will decide how many partitions to search first. - pub minimum_nprobes: Option, + /// The planner will always search at least this many partitions. Defaults to 1. + pub minimum_nprobes: usize, /// The maximum number of probes to load and search. If not set then /// ALL partitions will be searched, if needed, to satisfy k results. diff --git a/rust/lance/benches/vector_index.rs b/rust/lance/benches/vector_index.rs index a39c034c6c0..e20febfd2fb 100644 --- a/rust/lance/benches/vector_index.rs +++ b/rust/lance/benches/vector_index.rs @@ -56,7 +56,7 @@ fn bench_ivf_pq_index(c: &mut Criterion) { .scan() .nearest("vector", q, 10) .unwrap() - .minimum_nprobes(Some(10)) + .minimum_nprobes(10) .try_into_stream() .await .unwrap() @@ -76,7 +76,7 @@ fn bench_ivf_pq_index(c: &mut Criterion) { .scan() .nearest("vector", q, 10) .unwrap() - .minimum_nprobes(Some(10)) + .minimum_nprobes(10) .refine(2) .try_into_stream() .await @@ -110,7 +110,7 @@ fn bench_ivf_pq_index(c: &mut Criterion) { .scan() .nearest("vector", q, 10) .unwrap() - .minimum_nprobes(Some(32)) + .minimum_nprobes(32) .try_into_stream() .await .unwrap() diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index df6b39d680d..b1f2d075401 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -1047,7 +1047,7 @@ impl Scanner { k, lower_bound: None, upper_bound: None, - minimum_nprobes: Some(20), + minimum_nprobes: 1, maximum_nprobes: None, ef: None, refine_factor: None, @@ -1082,7 +1082,7 @@ impl Scanner { /// [Self::maximum_nprobes] to the same value. 
pub fn nprobes(&mut self, n: usize) -> &mut Self { if let Some(q) = self.nearest.as_mut() { - q.minimum_nprobes = Some(n); + q.minimum_nprobes = n; q.maximum_nprobes = Some(n); } else { log::warn!("nprobes is not set because nearest has not been called yet"); @@ -1097,7 +1097,7 @@ impl Scanner { #[deprecated(note = "Use nprobes instead")] pub fn nprobs(&mut self, n: usize) -> &mut Self { if let Some(q) = self.nearest.as_mut() { - q.minimum_nprobes = Some(n); + q.minimum_nprobes = n; q.maximum_nprobes = Some(n); } else { log::warn!("nprobes is not set because nearest has not been called yet"); @@ -1111,9 +1111,8 @@ impl Scanner { /// the search will stop. Increasing this number can increase recall but will increase /// latency on all queries. /// - /// Passing [`None`] clears any previously configured minimum which allows the planner to - /// determine an appropriate value dynamically. - pub fn minimum_nprobes(&mut self, n: Option) -> &mut Self { + /// The default value is 1. + pub fn minimum_nprobes(&mut self, n: usize) -> &mut Self { if let Some(q) = self.nearest.as_mut() { q.minimum_nprobes = n; } else { @@ -3981,38 +3980,6 @@ mod test { } } - #[tokio::test] - async fn test_minimum_nprobes_can_be_cleared() { - let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, false) - .await - .unwrap(); - test_ds.make_vector_index().await.unwrap(); - - let vector_field = test_ds - .schema - .field_with_name("vec") - .expect("vector field must exist"); - let dimension = match vector_field.data_type() { - DataType::FixedSizeList(_, size) => *size as usize, - _ => panic!("expected fixed size list for vector field"), - }; - - let query = std::sync::Arc::new(arrow_array::Float32Array::from(vec![0.0f32; dimension])); - - let mut scanner = test_ds.dataset.scan(); - scanner.nearest("vec", query.as_ref(), 5).unwrap(); - scanner.minimum_nprobes(Some(3)); - scanner.minimum_nprobes(None); - - assert!(scanner - .nearest_mut() - .expect("nearest query should be configured") - 
.minimum_nprobes - .is_none()); - - scanner.try_into_stream().await.unwrap(); - } - #[tokio::test] async fn test_strict_batch_size() { let dataset = lance_datagen::gen_batch() @@ -5202,7 +5169,7 @@ mod test { let mut scan = dataset.scan(); scan.filter("filterable > 5").unwrap(); scan.nearest("vector", query_key.as_ref(), 1).unwrap(); - scan.minimum_nprobes(Some(100)); + scan.minimum_nprobes(100); scan.with_row_id(); let batches = scan @@ -5506,7 +5473,7 @@ mod test { let key: Float32Array = (0..32).map(|_v| 1.0_f32).collect(); scan.nearest("vec", &key, 5).unwrap(); scan.refine(100); - scan.minimum_nprobes(Some(100)); + scan.minimum_nprobes(100); assert_eq!( dataset.index_cache_entry_count().await, @@ -5542,7 +5509,7 @@ mod test { let mut scan = dataset.scan(); scan.nearest("vec", &key, 5).unwrap(); scan.refine(100); - scan.minimum_nprobes(Some(100)); + scan.minimum_nprobes(100); let results = scan .try_into_stream() @@ -5604,7 +5571,7 @@ mod test { let mut scan = dataset.scan(); scan.nearest("vec", &key, 5).unwrap(); scan.refine(100); - scan.minimum_nprobes(Some(100)); + scan.minimum_nprobes(100); let results = scan .try_into_stream() @@ -6883,7 +6850,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=42), expr=... ANNSubIndex: name=..., k=42, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1"; + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, |scan| scan.nearest("vec", &q, 42), @@ -6903,7 +6870,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=40), expr=... 
ANNSubIndex: name=..., k=40, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1"; + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, |scan| Ok(scan.nearest("vec", &q, 10)?.refine(4)), @@ -6947,7 +6914,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=17), expr=... ANNSubIndex: name=..., k=17, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1"; + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, |scan| { @@ -6968,7 +6935,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=17), expr=... ANNSubIndex: name=..., k=17, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 FilterExec: i@0 > 10 LanceScan: uri=..., projection=[i], row_id=true, row_addr=false, ordered=false, range=None" } else { @@ -6977,7 +6944,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=17), expr=... ANNSubIndex: name=..., k=17, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, \ row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=i > Int32(10) " @@ -7013,7 +6980,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=6), expr=... 
ANNSubIndex: name=..., k=6, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1"; + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, |scan| scan.nearest("vec", &q, 6), @@ -7045,7 +7012,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=15), expr=... ANNSubIndex: name=..., k=15, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1"; + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, |scan| scan.nearest("vec", &q, 15)?.filter("i > 10"), @@ -7074,7 +7041,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... ANNSubIndex: name=..., k=5, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 FilterExec: i@0 > 10 LanceScan: uri=..., projection=[i], row_id=true, row_addr=false, ordered=false, range=None" } else { @@ -7096,7 +7063,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... ANNSubIndex: name=..., k=5, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, \ row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=i > Int32(10)" }; @@ -7127,7 +7094,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... 
ANNSubIndex: name=..., k=5, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 ScalarIndexQuery: query=[i > 10]@i_idx"; assert_plan_equals( &dataset.dataset, @@ -7148,7 +7115,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... ANNSubIndex: name=..., k=5, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 FilterExec: i@0 > 10 LanceScan: uri=..., projection=[i], row_id=true, row_addr=false, ordered=false, range=None" } else { @@ -7157,7 +7124,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... ANNSubIndex: name=..., k=5, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 LanceRead: uri=..., projection=[], num_fragments=3, range_before=None, \ range_after=None, row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=i > Int32(10)" }; @@ -7195,7 +7162,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=8), expr=... ANNSubIndex: name=..., k=8, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 ScalarIndexQuery: query=[i > 10]@i_idx"; assert_plan_equals( &dataset.dataset, @@ -7231,7 +7198,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=11), expr=... 
ANNSubIndex: name=..., k=11, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 ScalarIndexQuery: query=[i > 10]@i_idx"; dataset.make_scalar_index().await?; assert_plan_equals( @@ -7571,7 +7538,7 @@ mod test { }, "SortExec: TopK(fetch=32), expr=[_distance@0 ASC NULLS LAST, _rowid@1 ASC NULLS LAST]... ANNSubIndex: name=idx, k=32, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1", + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1", ) .await .unwrap(); @@ -7586,7 +7553,7 @@ mod test { }, "SortExec: TopK(fetch=33), expr=[_distance@0 ASC NULLS LAST, _rowid@1 ASC NULLS LAST]... ANNSubIndex: name=idx, k=33, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1", + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1", ) .await .unwrap(); @@ -7614,7 +7581,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=34), expr=[_distance@0 ASC NULLS LAST, _rowid@1 ASC NULLS LAST]... 
ANNSubIndex: name=idx, k=34, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1", + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1", ) .await .unwrap(); diff --git a/rust/lance/src/index/vector/fixture_test.rs b/rust/lance/src/index/vector/fixture_test.rs index 519c2510b24..0ec68319121 100644 --- a/rust/lance/src/index/vector/fixture_test.rs +++ b/rust/lance/src/index/vector/fixture_test.rs @@ -255,7 +255,7 @@ mod test { k: 1, lower_bound: None, upper_bound: None, - minimum_nprobes: Some(1), + minimum_nprobes: 1, maximum_nprobes: None, ef: None, refine_factor: None, diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 8ae86db8529..f9652eb5048 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -2153,7 +2153,7 @@ mod tests { k: 5, lower_bound: None, upper_bound: None, - minimum_nprobes: Some(1), + minimum_nprobes: 1, maximum_nprobes: None, ef: None, refine_factor: None, @@ -2658,7 +2658,7 @@ mod tests { .nearest("vec", &query, 2_000) .unwrap() .ef(100_000) - .minimum_nprobes(Some(2)) + .minimum_nprobes(2) .try_into_batch() .await .unwrap(); @@ -2703,7 +2703,7 @@ mod tests { .scan() .nearest("vec", &query, 2_000) .unwrap() - .minimum_nprobes(Some(2)) + .minimum_nprobes(2) .try_into_batch() .await .unwrap(); @@ -3125,7 +3125,7 @@ mod tests { .with_row_id() .nearest("vector", query, k) .unwrap() - .minimum_nprobes(Some(nlist)) + .minimum_nprobes(nlist) .try_into_stream() .await .unwrap() @@ -3207,7 +3207,7 @@ mod tests { .with_row_id() .nearest("vector", query, k) .unwrap() - .minimum_nprobes(Some(nlist)) + .minimum_nprobes(nlist) .try_into_stream() .await .unwrap() diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index 480720f6bbb..f6ea14b107e 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -1135,7 +1135,7 @@ mod tests { .scan() 
.nearest(vector_column, query.as_primitive::(), 100) .unwrap() - .minimum_nprobes(Some(nlist)) + .minimum_nprobes(nlist) .with_row_id() .try_into_batch() .await @@ -1166,7 +1166,7 @@ mod tests { .scan() .nearest(vector_column, query.as_primitive::(), 100) .unwrap() - .minimum_nprobes(Some(nlist)) + .minimum_nprobes(nlist) .with_row_id() .try_into_batch() .await @@ -1539,7 +1539,7 @@ mod tests { .scan() .nearest("vector", &query, k) .unwrap() - .minimum_nprobes(Some(nlist)) + .minimum_nprobes(nlist) .with_row_id() .try_into_batch() .await @@ -1768,7 +1768,7 @@ mod tests { .scan() .nearest(vector_column, query.as_primitive::(), k) .unwrap() - .minimum_nprobes(Some(nlist)) + .minimum_nprobes(nlist) .ef(100) .with_row_id() .try_into_batch() @@ -1785,7 +1785,7 @@ mod tests { .scan() .nearest(vector_column, query.as_primitive::(), part_idx) .unwrap() - .minimum_nprobes(Some(nlist)) + .minimum_nprobes(nlist) .ef(100) .with_row_id() .distance_range(None, Some(part_dist)) @@ -1796,7 +1796,7 @@ mod tests { .scan() .nearest(vector_column, query.as_primitive::(), k - part_idx) .unwrap() - .minimum_nprobes(Some(nlist)) + .minimum_nprobes(nlist) .ef(100) .with_row_id() .distance_range(Some(part_dist), None) @@ -1831,7 +1831,7 @@ mod tests { .scan() .nearest(vector_column, query.as_primitive::(), k) .unwrap() - .minimum_nprobes(Some(nlist)) + .minimum_nprobes(nlist) .ef(100) .with_row_id() .distance_range(dists.first().copied(), dists.last().copied()) diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index e9a623c859e..6d10ceaf446 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -403,31 +403,21 @@ impl DisplayAs for ANNIvfPartitionExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { - let min_display = self - .query - .minimum_nprobes - .map(|n| n.to_string()) - .unwrap_or_else(|| "None".to_string()); write!( 
f, "ANNIvfPartition: uuid={}, minimum_nprobes={}, maximum_nprobes={:?}, deltas={}", self.index_uuids[0], - min_display, + self.query.minimum_nprobes, self.query.maximum_nprobes, self.index_uuids.len() ) } DisplayFormatType::TreeRender => { - let min_display = self - .query - .minimum_nprobes - .map(|n| n.to_string()) - .unwrap_or_else(|| "None".to_string()); write!( f, "ANNIvfPartition\nuuid={}\nminimum_nprobes={}\nmaximum_nprobes={:?}\ndeltas={}", self.index_uuids[0], - min_display, + self.query.minimum_nprobes, self.query.maximum_nprobes, self.index_uuids.len() ) @@ -451,7 +441,7 @@ impl ExecutionPlan for ANNIvfPartitionExec { fn statistics(&self) -> DataFusionResult { Ok(Statistics { - num_rows: Precision::Exact(self.query.minimum_nprobes.unwrap_or(0)), + num_rows: Precision::Exact(self.query.minimum_nprobes), ..Statistics::new_unknown(self.schema().as_ref()) }) } @@ -735,7 +725,7 @@ impl ANNIvfSubIndexExec { .maximum_nprobes .unwrap_or(partitions.len()) .min(partitions.len()); - let min_nprobes = query.minimum_nprobes.unwrap_or(0).min(max_nprobes); + let min_nprobes = query.minimum_nprobes.min(max_nprobes); if max_nprobes <= min_nprobes { // We've already searched all partitions, no late search needed return futures::stream::empty().boxed(); @@ -853,7 +843,7 @@ impl ANNIvfSubIndexExec { metrics: Arc, state: Arc, ) -> impl Stream> { - let minimum_nprobes = query.minimum_nprobes.unwrap_or(0).min(partitions.len()); + let minimum_nprobes = query.minimum_nprobes.min(partitions.len()); metrics.partitions_searched.add(minimum_nprobes); futures::stream::iter(0..minimum_nprobes) @@ -1116,19 +1106,12 @@ impl ExecutionPlan for ANNIvfSubIndexExec { } fn adjust_probes(query: &mut Query, pruned_nprobes: usize) { - let minimum = query - .minimum_nprobes - .map(|current| current.max(pruned_nprobes)) - .unwrap_or(pruned_nprobes); - let mut maximum = query - .maximum_nprobes - .map(|current| current.min(pruned_nprobes)) - .unwrap_or(pruned_nprobes); - if minimum > maximum { - 
maximum = minimum; - } - query.minimum_nprobes = Some(minimum); - query.maximum_nprobes = Some(maximum); + query.minimum_nprobes = query.minimum_nprobes.max(pruned_nprobes); + if let Some(maximum) = query.maximum_nprobes { + if query.minimum_nprobes > maximum { + query.minimum_nprobes = maximum; + } + } } fn early_pruning(dists: &[f32], k: usize) -> usize { @@ -1373,7 +1356,7 @@ mod tests { k: 10, lower_bound: None, upper_bound: None, - minimum_nprobes: None, + minimum_nprobes: 1, maximum_nprobes: None, ef: None, refine_factor: None, @@ -1387,27 +1370,33 @@ mod tests { fn test_adjust_probes_rules() { let mut query = base_query(); adjust_probes(&mut query, 10); - assert_eq!(query.minimum_nprobes, Some(10)); - assert_eq!(query.maximum_nprobes, Some(10)); + assert_eq!(query.minimum_nprobes, 10); + assert_eq!(query.maximum_nprobes, None); let mut query = base_query(); - query.minimum_nprobes = Some(20); + query.minimum_nprobes = 20; adjust_probes(&mut query, 10); - assert_eq!(query.minimum_nprobes, Some(20)); - assert_eq!(query.maximum_nprobes, Some(20)); + assert_eq!(query.minimum_nprobes, 20); + assert_eq!(query.maximum_nprobes, None); let mut query = base_query(); query.maximum_nprobes = Some(25); adjust_probes(&mut query, 10); - assert_eq!(query.minimum_nprobes, Some(10)); - assert_eq!(query.maximum_nprobes, Some(10)); + assert_eq!(query.minimum_nprobes, 10); + assert_eq!(query.maximum_nprobes, Some(25)); + + let mut query = base_query(); + query.maximum_nprobes = Some(5); + adjust_probes(&mut query, 10); + assert_eq!(query.minimum_nprobes, 5); + assert_eq!(query.maximum_nprobes, Some(5)); let mut query = base_query(); - query.minimum_nprobes = Some(30); + query.minimum_nprobes = 30; query.maximum_nprobes = Some(50); adjust_probes(&mut query, 10); - assert_eq!(query.minimum_nprobes, Some(30)); - assert_eq!(query.maximum_nprobes, Some(30)); + assert_eq!(query.minimum_nprobes, 30); + assert_eq!(query.maximum_nprobes, Some(50)); } #[tokio::test] @@ -1544,7 +1533,7 @@ 
mod tests { k: 10, lower_bound: None, upper_bound: None, - minimum_nprobes: Some(1), + minimum_nprobes: 1, maximum_nprobes: None, ef: None, refine_factor: None, @@ -1739,7 +1728,7 @@ mod tests { .scan() .nearest("vector", q.as_ref(), 50) .unwrap() - .minimum_nprobes(Some(10)) + .minimum_nprobes(10) .prefilter(true) .scan_stats_callback(stats_holder.get_setter()) .filter("label = 17") @@ -1776,7 +1765,7 @@ mod tests { .scan() .nearest("vector", q.as_ref(), 50) .unwrap() - .minimum_nprobes(Some(10)) + .minimum_nprobes(10) .prefilter(true) .scan_stats_callback(stats_holder.get_setter()) .filter("label = 17 AND label = 18") @@ -1812,7 +1801,7 @@ mod tests { .scan() .nearest("vector", q.as_ref(), 50) .unwrap() - .minimum_nprobes(Some(10)) + .minimum_nprobes(max_nprobes) .maximum_nprobes(max_nprobes) .prefilter(true) .filter("label = 17") @@ -1851,7 +1840,7 @@ mod tests { .scan() .nearest("vector", q.as_ref(), 50) .unwrap() - .minimum_nprobes(Some(10)) + .minimum_nprobes(10) .prefilter(true) .filter("userid < 20") .unwrap() @@ -1890,7 +1879,7 @@ mod tests { .scan() .nearest("vector", q.as_ref(), 50) .unwrap() - .minimum_nprobes(Some(10)) + .minimum_nprobes(10) .prefilter(true) .refine(1) .filter("userid < 20") @@ -1927,7 +1916,7 @@ mod tests { .scan() .nearest("vector", q.as_ref(), 40000) .unwrap() - .minimum_nprobes(Some(10)) + .minimum_nprobes(10) .scan_stats_callback(stats_holder.get_setter()) .project(&Vec::::new()) .unwrap() From 8ee0ba7305d220383b53e0dd9ba3514e8db8e181 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 6 Nov 2025 17:12:43 +0800 Subject: [PATCH 09/12] fix ut Signed-off-by: BubbleCal --- python/python/tests/test_vector_index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index 6f1611eecb8..c2de6c8e03d 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -1092,7 +1092,7 @@ def query_index(ds, 
ntimes, q=None): nearest={ "column": "vector", "q": q if q is not None else rng.standard_normal(ndim), - "minimum_nprobes": 1, + "minimum_nprobes": 20, }, ) From 2d55da386c397f11c40bd1a50a1dcbf287a481bf Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 6 Nov 2025 17:56:59 +0800 Subject: [PATCH 10/12] fix ut Signed-off-by: BubbleCal --- python/python/tests/test_vector_index.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index c2de6c8e03d..eabe1cfd4bc 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -1092,7 +1092,7 @@ def query_index(ds, ntimes, q=None): nearest={ "column": "vector", "q": q if q is not None else rng.standard_normal(ndim), - "minimum_nprobes": 20, + "nprobes": 20, }, ) From 2398244a3af8264ec6fea4b147442930bca91c3b Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Fri, 7 Nov 2025 20:15:34 +0800 Subject: [PATCH 11/12] fix Signed-off-by: BubbleCal --- rust/lance/src/io/exec/knn.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 6d10ceaf446..adea72118f9 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -1115,6 +1115,10 @@ fn adjust_probes(query: &mut Query, pruned_nprobes: usize) { } fn early_pruning(dists: &[f32], k: usize) -> usize { + if dists.is_empty() { + return 0; + } + const PRUNING_FACTORS: [f32; 3] = [0.6, 7.0, 81.0]; let factor = match k { ..=1 => PRUNING_FACTORS[0], From f492fd341d22ecffac072d2aceed1caf4e83346b Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Mon, 10 Nov 2025 13:54:32 +0800 Subject: [PATCH 12/12] more comments Signed-off-by: BubbleCal --- rust/lance-index/src/vector.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rust/lance-index/src/vector.rs b/rust/lance-index/src/vector.rs index 4377675202e..b986b1d6c20 100644 --- 
a/rust/lance-index/src/vector.rs +++ b/rust/lance-index/src/vector.rs @@ -86,7 +86,8 @@ pub struct Query { pub upper_bound: Option, /// The minimum number of probes to load and search. More partitions - /// will only be loaded if we have not found k results. + /// will only be loaded if we have not found k results, or the algorithm + /// determines more partitions are needed to satisfy recall requirements. /// /// The planner will always search at least this many partitions. Defaults to 1. pub minimum_nprobes: usize,