diff --git a/java/src/main/java/com/lancedb/lance/ipc/Query.java b/java/src/main/java/com/lancedb/lance/ipc/Query.java index 46af8692c5f..07b573dc494 100644 --- a/java/src/main/java/com/lancedb/lance/ipc/Query.java +++ b/java/src/main/java/com/lancedb/lance/ipc/Query.java @@ -107,7 +107,7 @@ public static class Builder { private String column; private float[] key; private int k = 10; - private int minimumNprobes = 20; + private int minimumNprobes = 1; private Optional maximumNprobes = Optional.empty(); private Optional ef = Optional.empty(); private Optional refineFactor = Optional.empty(); diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 7ef7f9e06ef..8956587d78a 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -717,7 +717,7 @@ def scanner( "column": , "q": , "k": 10, - "minimum_nprobes": 20, + "minimum_nprobes": 1, "maximum_nprobes": 50, "refine_factor": 1 } @@ -980,7 +980,7 @@ def to_table( "q": , "k": 10, "metric": "cosine", - "minimum_nprobes": 20, + "minimum_nprobes": 1, "maximum_nprobes": 50, "refine_factor": 1 } diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index 6f1611eecb8..eabe1cfd4bc 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -1092,7 +1092,7 @@ def query_index(ds, ntimes, q=None): nearest={ "column": "vector", "q": q if q is not None else rng.standard_normal(ndim), - "minimum_nprobes": 1, + "nprobes": 20, }, ) diff --git a/python/src/dataset.rs b/python/src/dataset.rs index eb67b60e12b..06cd5965899 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -103,7 +103,7 @@ pub mod commit; pub mod optimize; pub mod stats; -const DEFAULT_NPROBES: usize = 20; +const DEFAULT_NPROBES: usize = 1; const LANCE_COMMIT_MESSAGE_KEY: &str = "__lance_commit_message"; fn convert_reader(reader: &Bound) -> PyResult> { @@ -986,8 +986,9 @@ impl Dataset { if let Some(nprobes) = 
nearest.get_item("nprobes")? { if !nprobes.is_none() { - minimum_nprobes = nprobes.extract()?; - maximum_nprobes = Some(minimum_nprobes); + let extracted: usize = nprobes.extract()?; + minimum_nprobes = extracted; + maximum_nprobes = Some(extracted); } } @@ -1003,18 +1004,22 @@ impl Dataset { } } - if minimum_nprobes > maximum_nprobes.unwrap_or(usize::MAX) { - return Err(PyValueError::new_err( - "minimum_nprobes must be <= maximum_nprobes", - )); + if let Some(maximum_nprobes) = maximum_nprobes { + if minimum_nprobes > maximum_nprobes { + return Err(PyValueError::new_err( + "minimum_nprobes must be <= maximum_nprobes", + )); + } } if minimum_nprobes < 1 { return Err(PyValueError::new_err("minimum_nprobes must be >= 1")); } - if maximum_nprobes.unwrap_or(usize::MAX) < 1 { - return Err(PyValueError::new_err("maximum_nprobes must be >= 1")); + if let Some(maximum_nprobes) = maximum_nprobes { + if maximum_nprobes < 1 { + return Err(PyValueError::new_err("maximum_nprobes must be >= 1")); + } } let metric_type: Option = diff --git a/rust/lance-index/src/vector.rs b/rust/lance-index/src/vector.rs index 9f472206c88..b986b1d6c20 100644 --- a/rust/lance-index/src/vector.rs +++ b/rust/lance-index/src/vector.rs @@ -86,7 +86,10 @@ pub struct Query { pub upper_bound: Option, /// The minimum number of probes to load and search. More partitions - /// will only be loaded if we have not found k results. + /// will only be loaded if we have not found k results, or the algorithm + /// determines more partitions are needed to satisfy recall requirements. + /// + /// The planner will always search at least this many partitions. Defaults to 1. pub minimum_nprobes: usize, /// The maximum number of probes to load and search. 
If not set then diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 7aa8e379a0e..b1f2d075401 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -1047,7 +1047,7 @@ impl Scanner { k, lower_bound: None, upper_bound: None, - minimum_nprobes: 20, + minimum_nprobes: 1, maximum_nprobes: None, ef: None, refine_factor: None, @@ -1110,6 +1110,8 @@ impl Scanner { /// If we have found k matching results after searching this many partitions then /// the search will stop. Increasing this number can increase recall but will increase /// latency on all queries. + /// + /// The default value is 1. pub fn minimum_nprobes(&mut self, n: usize) -> &mut Self { if let Some(q) = self.nearest.as_mut() { q.minimum_nprobes = n; @@ -6848,7 +6850,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=42), expr=... ANNSubIndex: name=..., k=42, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1"; + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, |scan| scan.nearest("vec", &q, 42), @@ -6868,7 +6870,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=40), expr=... ANNSubIndex: name=..., k=40, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1"; + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, |scan| Ok(scan.nearest("vec", &q, 10)?.refine(4)), @@ -6912,7 +6914,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=17), expr=... 
ANNSubIndex: name=..., k=17, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1"; + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, |scan| { @@ -6933,7 +6935,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=17), expr=... ANNSubIndex: name=..., k=17, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 FilterExec: i@0 > 10 LanceScan: uri=..., projection=[i], row_id=true, row_addr=false, ordered=false, range=None" } else { @@ -6942,7 +6944,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=17), expr=... ANNSubIndex: name=..., k=17, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, \ row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=i > Int32(10) " @@ -6978,7 +6980,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=6), expr=... ANNSubIndex: name=..., k=6, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1"; + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, |scan| scan.nearest("vec", &q, 6), @@ -7010,7 +7012,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=15), expr=... 
ANNSubIndex: name=..., k=15, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1"; + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, |scan| scan.nearest("vec", &q, 15)?.filter("i > 10"), @@ -7039,7 +7041,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... ANNSubIndex: name=..., k=5, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 FilterExec: i@0 > 10 LanceScan: uri=..., projection=[i], row_id=true, row_addr=false, ordered=false, range=None" } else { @@ -7061,7 +7063,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... ANNSubIndex: name=..., k=5, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, \ row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=i > Int32(10)" }; @@ -7092,7 +7094,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... ANNSubIndex: name=..., k=5, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 ScalarIndexQuery: query=[i > 10]@i_idx"; assert_plan_equals( &dataset.dataset, @@ -7113,7 +7115,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... 
ANNSubIndex: name=..., k=5, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 FilterExec: i@0 > 10 LanceScan: uri=..., projection=[i], row_id=true, row_addr=false, ordered=false, range=None" } else { @@ -7122,7 +7124,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... ANNSubIndex: name=..., k=5, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 LanceRead: uri=..., projection=[], num_fragments=3, range_before=None, \ range_after=None, row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=i > Int32(10)" }; @@ -7160,7 +7162,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=8), expr=... ANNSubIndex: name=..., k=8, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 ScalarIndexQuery: query=[i > 10]@i_idx"; assert_plan_equals( &dataset.dataset, @@ -7196,7 +7198,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=11), expr=... ANNSubIndex: name=..., k=11, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1 + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 ScalarIndexQuery: query=[i > 10]@i_idx"; dataset.make_scalar_index().await?; assert_plan_equals( @@ -7536,7 +7538,7 @@ mod test { }, "SortExec: TopK(fetch=32), expr=[_distance@0 ASC NULLS LAST, _rowid@1 ASC NULLS LAST]... 
ANNSubIndex: name=idx, k=32, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1", + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1", ) .await .unwrap(); @@ -7551,7 +7553,7 @@ mod test { }, "SortExec: TopK(fetch=33), expr=[_distance@0 ASC NULLS LAST, _rowid@1 ASC NULLS LAST]... ANNSubIndex: name=idx, k=33, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1", + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1", ) .await .unwrap(); @@ -7579,7 +7581,7 @@ mod test { CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=34), expr=[_distance@0 ASC NULLS LAST, _rowid@1 ASC NULLS LAST]... ANNSubIndex: name=idx, k=34, deltas=1 - ANNIvfPartition: uuid=..., minimum_nprobes=20, maximum_nprobes=None, deltas=1", + ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1", ) .await .unwrap(); diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 9b8d442f070..adea72118f9 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -721,8 +721,12 @@ impl ANNIvfSubIndexExec { state: Arc, ) -> impl Stream> { let stream = futures::stream::once(async move { - let max_nprobes = query.maximum_nprobes.unwrap_or(partitions.len()); - if max_nprobes == query.minimum_nprobes { + let max_nprobes = query + .maximum_nprobes + .unwrap_or(partitions.len()) + .min(partitions.len()); + let min_nprobes = query.minimum_nprobes.min(max_nprobes); + if max_nprobes <= min_nprobes { // We've already searched all partitions, no late search needed return futures::stream::empty().boxed(); } @@ -784,7 +788,7 @@ impl ANNIvfSubIndexExec { let state_clone = state.clone(); - futures::stream::iter(query.minimum_nprobes..max_nprobes) + futures::stream::iter(min_nprobes..max_nprobes) .map(move |idx| { let part_id = partitions.value(idx); let mut query = query.clone(); @@ -1025,8 +1029,9 @@ impl ExecutionPlan 
for ANNIvfSubIndexExec { let metrics = metrics.clone(); let pre_filter = pre_filter.clone(); let state = state.clone(); - let query = query.clone(); - + let mut query = query.clone(); + let pruned_nprobes = early_pruning(q_c_dists.values(), query.k); + adjust_probes(&mut query, pruned_nprobes); async move { let raw_index = ds .open_vector_index(&column, &index_uuid, &metrics.index_metrics) @@ -1100,6 +1105,30 @@ impl ExecutionPlan for ANNIvfSubIndexExec { } } +fn adjust_probes(query: &mut Query, pruned_nprobes: usize) { + query.minimum_nprobes = query.minimum_nprobes.max(pruned_nprobes); + if let Some(maximum) = query.maximum_nprobes { + if query.minimum_nprobes > maximum { + query.minimum_nprobes = maximum; + } + } +} + +fn early_pruning(dists: &[f32], k: usize) -> usize { + if dists.is_empty() { + return 0; + } + + const PRUNING_FACTORS: [f32; 3] = [0.6, 7.0, 81.0]; + let factor = match k { + ..=1 => PRUNING_FACTORS[0], + 2..=10 => PRUNING_FACTORS[1], + 11.. => PRUNING_FACTORS[2], + }; + let dist_threshold = dists[0] * factor; + dists.partition_point(|dist| *dist <= dist_threshold) +} + #[derive(Debug)] pub struct MultivectorScoringExec { // the inputs are sorted ANN search results @@ -1305,7 +1334,9 @@ mod tests { use arrow::compute::{concat_batches, sort_to_indices, take_record_batch}; use arrow::datatypes::Float32Type; - use arrow_array::{FixedSizeListArray, Int32Array, RecordBatchIterator, StringArray}; + use arrow_array::{ + ArrayRef, FixedSizeListArray, Float32Array, Int32Array, RecordBatchIterator, StringArray, + }; use arrow_schema::{Field as ArrowField, Schema as ArrowSchema}; use lance_core::utils::tempfile::TempStrDir; use lance_datafusion::exec::{ExecutionStatsCallback, ExecutionSummaryCounts}; @@ -1322,6 +1353,56 @@ mod tests { use crate::index::vector::VectorIndexParams; use crate::io::exec::testing::TestingExec; + fn base_query() -> Query { + Query { + column: "vec".to_string(), + key: Arc::new(Float32Array::from(vec![0.0f32])) as ArrayRef, + k: 
10, + lower_bound: None, + upper_bound: None, + minimum_nprobes: 1, + maximum_nprobes: None, + ef: None, + refine_factor: None, + metric_type: DistanceType::L2, + use_index: true, + dist_q_c: 0.0, + } + } + + #[test] + fn test_adjust_probes_rules() { + let mut query = base_query(); + adjust_probes(&mut query, 10); + assert_eq!(query.minimum_nprobes, 10); + assert_eq!(query.maximum_nprobes, None); + + let mut query = base_query(); + query.minimum_nprobes = 20; + adjust_probes(&mut query, 10); + assert_eq!(query.minimum_nprobes, 20); + assert_eq!(query.maximum_nprobes, None); + + let mut query = base_query(); + query.maximum_nprobes = Some(25); + adjust_probes(&mut query, 10); + assert_eq!(query.minimum_nprobes, 10); + assert_eq!(query.maximum_nprobes, Some(25)); + + let mut query = base_query(); + query.maximum_nprobes = Some(5); + adjust_probes(&mut query, 10); + assert_eq!(query.minimum_nprobes, 5); + assert_eq!(query.maximum_nprobes, Some(5)); + + let mut query = base_query(); + query.minimum_nprobes = 30; + query.maximum_nprobes = Some(50); + adjust_probes(&mut query, 10); + assert_eq!(query.minimum_nprobes, 30); + assert_eq!(query.maximum_nprobes, Some(50)); + } + #[tokio::test] async fn knn_flat_search() { let schema = Arc::new(ArrowSchema::new(vec![ @@ -1724,7 +1805,7 @@ mod tests { .scan() .nearest("vector", q.as_ref(), 50) .unwrap() - .minimum_nprobes(10) + .minimum_nprobes(max_nprobes) .maximum_nprobes(max_nprobes) .prefilter(true) .filter("label = 17")