From 5e8cd703b22ffdb31624060127d5f9e8f389ef3a Mon Sep 17 00:00:00 2001 From: Will Jones Date: Wed, 31 Dec 2025 13:12:43 -0800 Subject: [PATCH 1/5] fix: check metric compatibility before using vector index MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously, vector search would use an ANN index regardless of whether the index's metric type matched the query's requested metric. This produced incorrect distances when, for example, an index built with metric="dot" was used for a query with metric="l2". Now the scanner checks if the index's metric matches the user's requested metric. If they don't match, it silently falls back to flat search. If the user doesn't specify a metric, the index's metric is used. Changes: - Query.metric_type is now Option<DistanceType> (None = use index default) - Scanner checks metric compatibility before using an index - Explain plan now shows the metric being used - Java bindings updated to make distanceType optional Fixes lancedb/lance#5608 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- java/lance-jni/src/blocking_scanner.rs | 13 +- java/lance-jni/src/utils.rs | 13 +- java/src/main/java/org/lance/ipc/Query.java | 21 +- rust/lance-index/src/vector.rs | 5 +- rust/lance/src/dataset/scanner.rs | 204 ++++++++++++++---- .../src/dataset/tests/dataset_scanner.rs | 2 +- rust/lance/src/index/vector/fixture_test.rs | 2 +- rust/lance/src/index/vector/ivf.rs | 2 +- rust/lance/src/io/exec/knn.rs | 19 +- 9 files changed, 211 insertions(+), 70 deletions(-) diff --git a/java/lance-jni/src/blocking_scanner.rs b/java/lance-jni/src/blocking_scanner.rs index 122824252cd..8c6bc402544 100644 --- a/java/lance-jni/src/blocking_scanner.rs +++ b/java/lance-jni/src/blocking_scanner.rs @@ -336,13 +336,12 @@ fn inner_create_scanner<'local>( scanner.refine(refine_factor); } - let distance_type_jstr: JString = env - .call_method(&java_obj, "getDistanceType", "()Ljava/lang/String;", 
&[])? - .l()? - .into(); - let distance_type_str: String = env.get_string(&distance_type_jstr)?.into(); - let distance_type = DistanceType::try_from(distance_type_str.as_str())?; - scanner.distance_metric(distance_type); + if let Some(distance_type_str) = + env.get_optional_string_from_method(&java_obj, "getDistanceTypeString")? + { + let distance_type = DistanceType::try_from(distance_type_str.as_str())?; + scanner.distance_metric(distance_type); + } let use_index = env.get_boolean_from_method(&java_obj, "isUseIndex")?; scanner.use_index(use_index); diff --git a/java/lance-jni/src/utils.rs b/java/lance-jni/src/utils.rs index 5cb55c200e1..03bcdf8f281 100644 --- a/java/lance-jni/src/utils.rs +++ b/java/lance-jni/src/utils.rs @@ -171,12 +171,13 @@ pub fn get_query(env: &mut JNIEnv, query_obj: JObject) -> Result> let refine_factor = env.get_optional_u32_from_method(&java_obj, "getRefineFactor")?; - let distance_type_jstr: JString = env - .call_method(&java_obj, "getDistanceType", "()Ljava/lang/String;", &[])? - .l()? - .into(); - let distance_type_str: String = env.get_string(&distance_type_jstr)?.into(); - let distance_type = DistanceType::try_from(distance_type_str.as_str())?; + let distance_type = if let Some(distance_type_str) = + env.get_optional_string_from_method(&java_obj, "getDistanceTypeString")? + { + Some(DistanceType::try_from(distance_type_str.as_str())?) 
+ } else { + None + }; let use_index = env.get_boolean_from_method(&java_obj, "isUseIndex")?; diff --git a/java/src/main/java/org/lance/ipc/Query.java b/java/src/main/java/org/lance/ipc/Query.java index 6c51db1dde8..9bd2dc03b90 100644 --- a/java/src/main/java/org/lance/ipc/Query.java +++ b/java/src/main/java/org/lance/ipc/Query.java @@ -29,7 +29,7 @@ public class Query { private final Optional maximumNprobes; private final Optional ef; private final Optional refineFactor; - private final DistanceType distanceType; + private final Optional distanceType; private final boolean useIndex; private Query(Builder builder) { @@ -48,7 +48,7 @@ private Query(Builder builder) { this.maximumNprobes = builder.maximumNprobes; this.ef = builder.ef; this.refineFactor = builder.refineFactor; - this.distanceType = Preconditions.checkNotNull(builder.distanceType, "Metric type must be set"); + this.distanceType = builder.distanceType; this.useIndex = builder.useIndex; } @@ -80,8 +80,12 @@ public Optional getRefineFactor() { return refineFactor; } - public String getDistanceType() { - return distanceType.toString(); + public Optional getDistanceType() { + return distanceType; + } + + public Optional getDistanceTypeString() { + return distanceType.map(DistanceType::toString); } public boolean isUseIndex() { @@ -98,7 +102,7 @@ public String toString() { .add("maximumNprobes", maximumNprobes.orElse(null)) .add("ef", ef.orElse(null)) .add("refineFactor", refineFactor.orElse(null)) - .add("distanceType", distanceType) + .add("distanceType", distanceType.orElse(null)) .add("useIndex", useIndex) .toString(); } @@ -111,7 +115,7 @@ public static class Builder { private Optional maximumNprobes = Optional.empty(); private Optional ef = Optional.empty(); private Optional refineFactor = Optional.empty(); - private DistanceType distanceType = DistanceType.L2; + private Optional distanceType = Optional.empty(); private boolean useIndex = true; /** @@ -219,11 +223,14 @@ public Builder 
setRefineFactor(int refineFactor) { /** * Sets the distance metric type. * + * <p>If not set, the query will use the index's metric type (if an index is available), or the + * default metric for the data type (L2 for float vectors, Hamming for binary). + * * @param distanceType The DistanceType to use for the query. * @return The Builder instance for method chaining. */ public Builder setDistanceType(DistanceType distanceType) { - this.distanceType = distanceType; + this.distanceType = Optional.ofNullable(distanceType); return this; } diff --git a/rust/lance-index/src/vector.rs b/rust/lance-index/src/vector.rs index 7871def65b6..c25bbe23cb4 100644 --- a/rust/lance-index/src/vector.rs +++ b/rust/lance-index/src/vector.rs @@ -104,8 +104,9 @@ pub struct Query { /// TODO: should we support fraction / float number here? pub refine_factor: Option<u32>, - /// Distance metric type - pub metric_type: DistanceType, + /// Distance metric type. If None, uses the index's metric (if available) + /// or the default for the data type. + pub metric_type: Option<DistanceType>, /// Whether to use an ANN index if available pub use_index: bool, diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index fc5f636f5b1..918e26ed39f 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -1231,7 +1231,7 @@ impl Scanner { maximum_nprobes: None, ef: None, refine_factor: None, - metric_type: default_distance_type_for(&element_type), + metric_type: None, use_index: true, dist_q_c: 0.0, }); @@ -1361,7 +1361,7 @@ impl Scanner { /// Change the distance [MetricType], i.e, L2 or Cosine distance. 
pub fn distance_metric(&mut self, metric_type: MetricType) -> &mut Self { if let Some(q) = self.nearest.as_mut() { - q.metric_type = metric_type + q.metric_type = Some(metric_type) } self } @@ -3034,20 +3034,39 @@ impl Scanner { } else { Arc::new(vec![]) }; - if let Some(index) = indices.iter().find(|i| i.fields.contains(&column_id)) { + // Find an index for the column and check if metric is compatible + let matching_index = + if let Some(index) = indices.iter().find(|i| i.fields.contains(&column_id)) { + let idx = self + .dataset + .open_vector_index( + q.column.as_str(), + &index.uuid.to_string(), + &NoOpMetricsCollector, + ) + .await?; + let index_metric = idx.metric_type(); + + // Check if user's requested metric is compatible with index + let use_this_index = match q.metric_type { + Some(user_metric) => user_metric == index_metric, + None => true, // No preference, use index's metric + }; + + if use_this_index { + Some((index, idx, index_metric)) + } else { + None + } + } else { + None + }; + + if let Some((index, _idx, index_metric)) = matching_index { log::trace!("index found for vector search"); - // There is an index built for the column. - // We will use the index. 
- let idx = self - .dataset - .open_vector_index( - q.column.as_str(), - &index.uuid.to_string(), - &NoOpMetricsCollector, - ) - .await?; - q.metric_type = idx.metric_type(); - validate_distance_type_for(q.metric_type, &element_type)?; + // Use the index's metric type + q.metric_type = Some(index_metric); + validate_distance_type_for(index_metric, &element_type)?; if matches!(q.refine_factor, Some(0)) { return Err(Error::invalid_input( @@ -3082,7 +3101,12 @@ impl Scanner { Ok(knn_node) } else { - validate_distance_type_for(q.metric_type, &element_type)?; + // Resolve metric type for flat search (use default if not specified) + let metric = q + .metric_type + .unwrap_or_else(|| default_distance_type_for(&element_type)); + q.metric_type = Some(metric); + validate_distance_type_for(metric, &element_type)?; // No index found. use flat search. let mut columns = vec![q.column.clone()]; if let Some(refine_expr) = filter_plan.refine_expr.as_ref() { @@ -3137,7 +3161,7 @@ impl Scanner { ) .await?; let mut q = q.clone(); - q.metric_type = idx.metric_type(); + q.metric_type = Some(idx.metric_type()); // If the vector column is not present, we need to take the vector column, so // that the distance value is comparable with the flat search ones. 
@@ -3589,11 +3613,19 @@ impl Scanner { /// Add a knn search node to the input plan fn flat_knn(&self, input: Arc, q: &Query) -> Result> { + // Resolve metric_type if not set (use default for the column's element type) + let metric_type = match q.metric_type { + Some(m) => m, + None => { + let (_, element_type) = get_vector_type(self.dataset.schema(), &q.column)?; + default_distance_type_for(&element_type) + } + }; let flat_dist = Arc::new(KNNVectorDistanceExec::try_new( input, &q.column, q.key.clone(), - q.metric_type, + metric_type, )?); let lower: Option<(Expr, Arc)> = q @@ -5122,10 +5154,8 @@ mod test { let mut scan = dataset.scan(); scan.nearest("bin", &query, 3).unwrap(); - assert_eq!( - scan.nearest.as_ref().unwrap().metric_type, - DistanceType::Hamming - ); + // metric_type is None initially; it will be resolved to Hamming during search + assert_eq!(scan.nearest.as_ref().unwrap().metric_type, None); let batch = scan.try_into_batch().await.unwrap(); let ids = batch @@ -5160,6 +5190,102 @@ mod test { ); } + /// Test that when query specifies a metric different from the index, + /// we fall back to flat search and return correct distances. 
+ /// Regression test for https://github.com/lance-format/lance/issues/5608 + #[tokio::test] + async fn test_knn_metric_mismatch_falls_back_to_flat_search() { + let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, true) + .await + .unwrap(); + // Create IVF_PQ index with L2 metric + test_ds.make_vector_index().await.unwrap(); + + let dataset = &test_ds.dataset; + let key: Float32Array = (32..64).map(|v| v as f32).collect(); + + // Query with Dot metric (different from the L2 index) + let mut scan = dataset.scan(); + scan.nearest("vec", &key, 5).unwrap(); + scan.distance_metric(DistanceType::Dot); + + // Verify the explain plan does NOT show ANNSubIndex (should use flat search) + let plan = scan.explain_plan(false).await.unwrap(); + assert!( + !plan.contains("ANNSubIndex"), + "Expected flat search, but got ANN index in plan:\n{}", + plan + ); + // Should show flat KNN with Dot metric (metric is displayed lowercase) + assert!( + plan.contains("KNNVectorDistance") && plan.to_lowercase().contains("dot"), + "Expected flat KNN with Dot metric in plan:\n{}", + plan + ); + + // Also verify the distances are different from L2 results + let dot_batch = dataset + .scan() + .nearest("vec", &key, 5) + .unwrap() + .distance_metric(DistanceType::Dot) + .try_into_batch() + .await + .unwrap(); + + let l2_batch = dataset + .scan() + .nearest("vec", &key, 5) + .unwrap() + .distance_metric(DistanceType::L2) + .try_into_batch() + .await + .unwrap(); + + let dot_distances: Vec = dot_batch + .column_by_name(DIST_COL) + .unwrap() + .as_primitive::() + .values() + .to_vec(); + let l2_distances: Vec = l2_batch + .column_by_name(DIST_COL) + .unwrap() + .as_primitive::() + .values() + .to_vec(); + + // Dot and L2 distances should be different (this verifies we're using the correct metric) + assert_ne!(dot_distances, l2_distances); + } + + /// Test that when query does not specify a metric, we use the index's metric. 
+ /// Regression test for https://github.com/lance-format/lance/issues/5608 + #[tokio::test] + async fn test_knn_no_metric_uses_index_metric() { + let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, true) + .await + .unwrap(); + // Create IVF_PQ index with L2 metric + test_ds.make_vector_index().await.unwrap(); + + let dataset = &test_ds.dataset; + let key: Float32Array = (32..64).map(|v| v as f32).collect(); + + // Query without specifying metric + let mut scan = dataset.scan(); + scan.nearest("vec", &key, 5).unwrap(); + // Don't call distance_metric() - should use index's L2 + + // Verify the explain plan shows ANNSubIndex with L2 metric + let plan = scan.explain_plan(false).await.unwrap(); + assert!( + plan.contains("ANNSubIndex") && plan.to_lowercase().contains("l2"), + "Expected ANN index with L2 metric in plan:\n{}", + plan + ); + } + #[rstest] #[tokio::test] async fn test_only_row_id( @@ -7251,7 +7377,7 @@ mod test { Take: columns=\"_distance, _rowid, (i), (s), (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=42), expr=... - ANNSubIndex: name=..., k=42, deltas=1 + ANNSubIndex: name=..., k=42, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, @@ -7271,7 +7397,7 @@ mod test { Take: columns=\"_distance, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=40), expr=... - ANNSubIndex: name=..., k=40, deltas=1 + ANNSubIndex: name=..., k=40, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, @@ -7315,7 +7441,7 @@ mod test { Take: columns=\"_distance, _rowid, (i)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=17), expr=... 
- ANNSubIndex: name=..., k=17, deltas=1 + ANNSubIndex: name=..., k=17, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, @@ -7336,7 +7462,7 @@ mod test { Take: columns=\"_distance, _rowid, (i), (s), (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=17), expr=... - ANNSubIndex: name=..., k=17, deltas=1 + ANNSubIndex: name=..., k=17, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 FilterExec: i@0 > 10 LanceScan: uri=..., projection=[i], row_id=true, row_addr=false, ordered=false, range=None" @@ -7345,7 +7471,7 @@ mod test { Take: columns=\"_distance, _rowid, (i), (s), (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=17), expr=... - ANNSubIndex: name=..., k=17, deltas=1 + ANNSubIndex: name=..., k=17, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, \ row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=i > Int32(10) @@ -7381,7 +7507,7 @@ mod test { Take: columns=\"_distance, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=6), expr=... - ANNSubIndex: name=..., k=6, deltas=1 + ANNSubIndex: name=..., k=6, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, @@ -7413,7 +7539,7 @@ mod test { Take: columns=\"_distance, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=15), expr=... 
- ANNSubIndex: name=..., k=15, deltas=1 + ANNSubIndex: name=..., k=15, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, @@ -7442,7 +7568,7 @@ mod test { Take: columns=\"_distance, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... - ANNSubIndex: name=..., k=5, deltas=1 + ANNSubIndex: name=..., k=5, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 FilterExec: i@0 > 10 LanceScan: uri=..., projection=[i], row_id=true, row_addr=false, ordered=false, range=None" @@ -7464,7 +7590,7 @@ mod test { Take: columns=\"_distance, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... - ANNSubIndex: name=..., k=5, deltas=1 + ANNSubIndex: name=..., k=5, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, \ row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=i > Int32(10)" @@ -7495,7 +7621,7 @@ mod test { Take: columns=\"_distance, _rowid, (i), (s), (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... - ANNSubIndex: name=..., k=5, deltas=1 + ANNSubIndex: name=..., k=5, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 ScalarIndexQuery: query=[i > 10]@i_idx"; assert_plan_equals( @@ -7516,7 +7642,7 @@ mod test { Take: columns=\"_distance, _rowid, (i), (s), (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... 
- ANNSubIndex: name=..., k=5, deltas=1 + ANNSubIndex: name=..., k=5, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 FilterExec: i@0 > 10 LanceScan: uri=..., projection=[i], row_id=true, row_addr=false, ordered=false, range=None" @@ -7525,7 +7651,7 @@ mod test { Take: columns=\"_distance, _rowid, (i), (s), (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... - ANNSubIndex: name=..., k=5, deltas=1 + ANNSubIndex: name=..., k=5, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 LanceRead: uri=..., projection=[], num_fragments=3, range_before=None, \ range_after=None, row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=i > Int32(10)" @@ -7563,7 +7689,7 @@ mod test { Take: columns=\"_distance, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=8), expr=... - ANNSubIndex: name=..., k=8, deltas=1 + ANNSubIndex: name=..., k=8, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 ScalarIndexQuery: query=[i > 10]@i_idx"; assert_plan_equals( @@ -7599,7 +7725,7 @@ mod test { Take: columns=\"_distance, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=11), expr=... - ANNSubIndex: name=..., k=11, deltas=1 + ANNSubIndex: name=..., k=11, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 ScalarIndexQuery: query=[i > 10]@i_idx"; dataset.make_scalar_index().await?; @@ -7939,7 +8065,7 @@ mod test { .project(&["_distance", "_rowid"]) }, "SortExec: TopK(fetch=32), expr=[_distance@0 ASC NULLS LAST, _rowid@1 ASC NULLS LAST]... 
- ANNSubIndex: name=idx, k=32, deltas=1 + ANNSubIndex: name=idx, k=32, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1", ) .await @@ -7954,7 +8080,7 @@ mod test { .project(&["_distance", "_rowid"]) }, "SortExec: TopK(fetch=33), expr=[_distance@0 ASC NULLS LAST, _rowid@1 ASC NULLS LAST]... - ANNSubIndex: name=idx, k=33, deltas=1 + ANNSubIndex: name=idx, k=33, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1", ) .await @@ -7982,7 +8108,7 @@ mod test { Take: columns=\"_distance, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=34), expr=[_distance@0 ASC NULLS LAST, _rowid@1 ASC NULLS LAST]... - ANNSubIndex: name=idx, k=34, deltas=1 + ANNSubIndex: name=idx, k=34, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1", ) .await diff --git a/rust/lance/src/dataset/tests/dataset_scanner.rs b/rust/lance/src/dataset/tests/dataset_scanner.rs index 7f5caff3908..9fce5f6d2ca 100644 --- a/rust/lance/src/dataset/tests/dataset_scanner.rs +++ b/rust/lance/src/dataset/tests/dataset_scanner.rs @@ -45,7 +45,7 @@ async fn test_vector_filter_fts_search() { maximum_nprobes: None, ef: None, refine_factor: None, - metric_type: MetricType::L2, + metric_type: Some(MetricType::L2), use_index: true, dist_q_c: 0.0, }; diff --git a/rust/lance/src/index/vector/fixture_test.rs b/rust/lance/src/index/vector/fixture_test.rs index 6316e88d898..3445a3cd5d4 100644 --- a/rust/lance/src/index/vector/fixture_test.rs +++ b/rust/lance/src/index/vector/fixture_test.rs @@ -264,7 +264,7 @@ mod test { maximum_nprobes: None, ef: None, refine_factor: None, - metric_type: metric, + metric_type: Some(metric), use_index: true, dist_q_c: 0.0, }; diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 8a590ea8513..d42c06e1638 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs 
@@ -2177,7 +2177,7 @@ mod tests { maximum_nprobes: None, ef: None, refine_factor: None, - metric_type: MetricType::L2, + metric_type: Some(MetricType::L2), use_index: true, dist_q_c: 0.0, }; diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 8c62541a519..111cb71e4cc 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -643,23 +643,30 @@ impl ANNIvfSubIndexExec { impl DisplayAs for ANNIvfSubIndexExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let metric_str = self + .query + .metric_type + .map(|m| format!("{:?}", m)) + .unwrap_or_else(|| "default".to_string()); match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!( f, - "ANNSubIndex: name={}, k={}, deltas={}", + "ANNSubIndex: name={}, k={}, deltas={}, metric={}", self.indices[0].name, self.query.k * self.query.refine_factor.unwrap_or(1) as usize, - self.indices.len() + self.indices.len(), + metric_str ) } DisplayFormatType::TreeRender => { write!( f, - "ANNSubIndex\nname={}\nk={}\ndeltas={}", + "ANNSubIndex\nname={}\nk={}\ndeltas={}\nmetric={}", self.indices[0].name, self.query.k * self.query.refine_factor.unwrap_or(1) as usize, - self.indices.len() + self.indices.len(), + metric_str ) } } @@ -1375,7 +1382,7 @@ mod tests { maximum_nprobes: None, ef: None, refine_factor: None, - metric_type: DistanceType::L2, + metric_type: Some(DistanceType::L2), use_index: true, dist_q_c: 0.0, } @@ -1552,7 +1559,7 @@ mod tests { maximum_nprobes: None, ef: None, refine_factor: None, - metric_type: DistanceType::Cosine, + metric_type: Some(DistanceType::Cosine), use_index: true, dist_q_c: 0.0, }; From 0a8d618b5ec6ce6ac978aadced88a0061487f9ec Mon Sep 17 00:00:00 2001 From: Will Jones Date: Thu, 1 Jan 2026 10:15:11 -0800 Subject: [PATCH 2/5] fix: use index when user specifies incompatible metric type MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a user 
specifies a metric that is incompatible with the data type (e.g., L2 on binary vectors), use the index with its own metric rather than falling back to flat search which would fail. The logic now is: - If metrics match: use the index - If user metric is incompatible with data type: use the index - If user metric is compatible but different from index: flat search 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 --- rust/lance/src/dataset/scanner.rs | 11 +++- .../src/dataset/tests/dataset_scanner.rs | 50 +++++++++++++++++++ rust/lance/src/index/vector/utils.rs | 18 ++++--- 3 files changed, 71 insertions(+), 8 deletions(-) diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 918e26ed39f..e56a20a7a21 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -79,7 +79,8 @@ use super::Dataset; use crate::dataset::row_offsets_to_row_addresses; use crate::dataset::utils::SchemaAdapter; use crate::index::vector::utils::{ - default_distance_type_for, get_vector_dim, get_vector_type, validate_distance_type_for, + default_distance_type_for, get_vector_dim, get_vector_type, is_distance_type_supported_for, + validate_distance_type_for, }; use crate::index::DatasetIndexInternalExt; use crate::io::exec::filtered_read::{FilteredReadExec, FilteredReadOptions}; @@ -3049,7 +3050,13 @@ impl Scanner { // Check if user's requested metric is compatible with index let use_this_index = match q.metric_type { - Some(user_metric) => user_metric == index_metric, + Some(user_metric) => { + // Use index if metrics match, OR if user's metric is invalid for + // this data type (e.g., L2 on binary vectors). In the latter case, + // there's no valid flat search option, so use the index. 
+ user_metric == index_metric + || !is_distance_type_supported_for(user_metric, &element_type) + } None => true, // No preference, use index's metric }; diff --git a/rust/lance/src/dataset/tests/dataset_scanner.rs b/rust/lance/src/dataset/tests/dataset_scanner.rs index 9fce5f6d2ca..d705cf8e794 100644 --- a/rust/lance/src/dataset/tests/dataset_scanner.rs +++ b/rust/lance/src/dataset/tests/dataset_scanner.rs @@ -465,3 +465,53 @@ async fn check_results( .unwrap(); assert_eq!(ids.values(), expected_ids); } + +/// Test that when a user specifies a metric incompatible with the data type +/// (e.g., L2 on binary vectors), the index is still used with its own metric. +#[tokio::test] +async fn test_knn_incompatible_metric_uses_index() { + use arrow_array::UInt8Array; + + // Create binary vectors (UInt8) - only Hamming is valid for these + let vectors = UInt8Array::from(vec![0u8, 128, 255]); + let vectors = FixedSizeListArray::try_new_from_values(vectors, 1).unwrap(); + let ids = Int32Array::from(vec![0, 1, 2]); + let schema = Arc::new(ArrowSchema::new(vec![ + ArrowField::new( + "vector", + DataType::FixedSizeList(Arc::new(ArrowField::new("item", DataType::UInt8, true)), 1), + false, + ), + ArrowField::new("id", DataType::Int32, false), + ])); + let batch = + RecordBatch::try_new(schema.clone(), vec![Arc::new(vectors), Arc::new(ids)]).unwrap(); + let reader = RecordBatchIterator::new(vec![Ok(batch)], schema); + + let mut ds = Dataset::write(reader, "memory://test_incompatible_metric", None) + .await + .unwrap(); + + // Create Hamming index (only valid metric for binary vectors) + let params = VectorIndexParams::ivf_flat(1, MetricType::Hamming); + ds.create_index(&["vector"], IndexType::Vector, None, ¶ms, true) + .await + .unwrap(); + + // Query with L2 metric (invalid for binary vectors) - should still work by using index + let query_vec = UInt8Array::from(vec![128u8]); + let query_vec = FixedSizeListArray::try_new_from_values(query_vec, 1).unwrap(); + + let results = ds + 
.scan() + .nearest("vector", query_vec.values().as_ref(), 3) + .unwrap() + .distance_metric(MetricType::L2) // Invalid for binary, should be ignored + .try_into_batch() + .await + .unwrap(); + + // Should succeed and return results (using Hamming from the index) + assert_eq!(results.num_rows(), 3); + assert!(results.column_by_name(DIST_COL).is_some()); +} diff --git a/rust/lance/src/index/vector/utils.rs b/rust/lance/src/index/vector/utils.rs index 3358f9093d5..607039bd41c 100644 --- a/rust/lance/src/index/vector/utils.rs +++ b/rust/lance/src/index/vector/utils.rs @@ -183,12 +183,12 @@ pub fn default_distance_type_for(element_type: &arrow_schema::DataType) -> Dista } } -/// Validate that the distance type is supported by the vector element type. -pub fn validate_distance_type_for( +/// Check if a distance type is supported for the given vector element type. +pub fn is_distance_type_supported_for( distance_type: DistanceType, element_type: &arrow_schema::DataType, -) -> Result<()> { - let supported = match element_type { +) -> bool { + match element_type { arrow_schema::DataType::UInt8 => matches!(distance_type, DistanceType::Hamming), arrow_schema::DataType::Int8 | arrow_schema::DataType::Float16 @@ -200,9 +200,15 @@ pub fn validate_distance_type_for( ) } _ => false, - }; + } +} - if supported { +/// Validate that the distance type is supported by the vector element type. +pub fn validate_distance_type_for( + distance_type: DistanceType, + element_type: &arrow_schema::DataType, +) -> Result<()> { + if is_distance_type_supported_for(distance_type, element_type) { Ok(()) } else { Err(Error::invalid_input( From f6aff694f9e4adad3b3fc8af1be6ea5314aeb949 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 2 Jan 2026 14:29:32 -0800 Subject: [PATCH 3/5] Revert "fix: use index when user specifies incompatible metric type" This reverts commit 0a8d618b5ec6ce6ac978aadced88a0061487f9ec. 
--- rust/lance/src/dataset/scanner.rs | 11 +--- .../src/dataset/tests/dataset_scanner.rs | 50 ------------------- rust/lance/src/index/vector/utils.rs | 18 +++---- 3 files changed, 8 insertions(+), 71 deletions(-) diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index e56a20a7a21..918e26ed39f 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -79,8 +79,7 @@ use super::Dataset; use crate::dataset::row_offsets_to_row_addresses; use crate::dataset::utils::SchemaAdapter; use crate::index::vector::utils::{ - default_distance_type_for, get_vector_dim, get_vector_type, is_distance_type_supported_for, - validate_distance_type_for, + default_distance_type_for, get_vector_dim, get_vector_type, validate_distance_type_for, }; use crate::index::DatasetIndexInternalExt; use crate::io::exec::filtered_read::{FilteredReadExec, FilteredReadOptions}; @@ -3050,13 +3049,7 @@ impl Scanner { // Check if user's requested metric is compatible with index let use_this_index = match q.metric_type { - Some(user_metric) => { - // Use index if metrics match, OR if user's metric is invalid for - // this data type (e.g., L2 on binary vectors). In the latter case, - // there's no valid flat search option, so use the index. 
- user_metric == index_metric - || !is_distance_type_supported_for(user_metric, &element_type) - } + Some(user_metric) => user_metric == index_metric, None => true, // No preference, use index's metric }; diff --git a/rust/lance/src/dataset/tests/dataset_scanner.rs b/rust/lance/src/dataset/tests/dataset_scanner.rs index d705cf8e794..9fce5f6d2ca 100644 --- a/rust/lance/src/dataset/tests/dataset_scanner.rs +++ b/rust/lance/src/dataset/tests/dataset_scanner.rs @@ -465,53 +465,3 @@ async fn check_results( .unwrap(); assert_eq!(ids.values(), expected_ids); } - -/// Test that when a user specifies a metric incompatible with the data type -/// (e.g., L2 on binary vectors), the index is still used with its own metric. -#[tokio::test] -async fn test_knn_incompatible_metric_uses_index() { - use arrow_array::UInt8Array; - - // Create binary vectors (UInt8) - only Hamming is valid for these - let vectors = UInt8Array::from(vec![0u8, 128, 255]); - let vectors = FixedSizeListArray::try_new_from_values(vectors, 1).unwrap(); - let ids = Int32Array::from(vec![0, 1, 2]); - let schema = Arc::new(ArrowSchema::new(vec![ - ArrowField::new( - "vector", - DataType::FixedSizeList(Arc::new(ArrowField::new("item", DataType::UInt8, true)), 1), - false, - ), - ArrowField::new("id", DataType::Int32, false), - ])); - let batch = - RecordBatch::try_new(schema.clone(), vec![Arc::new(vectors), Arc::new(ids)]).unwrap(); - let reader = RecordBatchIterator::new(vec![Ok(batch)], schema); - - let mut ds = Dataset::write(reader, "memory://test_incompatible_metric", None) - .await - .unwrap(); - - // Create Hamming index (only valid metric for binary vectors) - let params = VectorIndexParams::ivf_flat(1, MetricType::Hamming); - ds.create_index(&["vector"], IndexType::Vector, None, ¶ms, true) - .await - .unwrap(); - - // Query with L2 metric (invalid for binary vectors) - should still work by using index - let query_vec = UInt8Array::from(vec![128u8]); - let query_vec = 
FixedSizeListArray::try_new_from_values(query_vec, 1).unwrap(); - - let results = ds - .scan() - .nearest("vector", query_vec.values().as_ref(), 3) - .unwrap() - .distance_metric(MetricType::L2) // Invalid for binary, should be ignored - .try_into_batch() - .await - .unwrap(); - - // Should succeed and return results (using Hamming from the index) - assert_eq!(results.num_rows(), 3); - assert!(results.column_by_name(DIST_COL).is_some()); -} diff --git a/rust/lance/src/index/vector/utils.rs b/rust/lance/src/index/vector/utils.rs index 607039bd41c..3358f9093d5 100644 --- a/rust/lance/src/index/vector/utils.rs +++ b/rust/lance/src/index/vector/utils.rs @@ -183,12 +183,12 @@ pub fn default_distance_type_for(element_type: &arrow_schema::DataType) -> Dista } } -/// Check if a distance type is supported for the given vector element type. -pub fn is_distance_type_supported_for( +/// Validate that the distance type is supported by the vector element type. +pub fn validate_distance_type_for( distance_type: DistanceType, element_type: &arrow_schema::DataType, -) -> bool { - match element_type { +) -> Result<()> { + let supported = match element_type { arrow_schema::DataType::UInt8 => matches!(distance_type, DistanceType::Hamming), arrow_schema::DataType::Int8 | arrow_schema::DataType::Float16 @@ -200,15 +200,9 @@ pub fn is_distance_type_supported_for( ) } _ => false, - } -} + }; -/// Validate that the distance type is supported by the vector element type. 
-pub fn validate_distance_type_for( - distance_type: DistanceType, - element_type: &arrow_schema::DataType, -) -> Result<()> { - if is_distance_type_supported_for(distance_type, element_type) { + if supported { Ok(()) } else { Err(Error::invalid_input( From a7130314af66e4512d645f1c15425936108f1b86 Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 2 Jan 2026 14:34:41 -0800 Subject: [PATCH 4/5] fix test --- python/python/tests/test_vector_index.py | 59 +++--------------------- 1 file changed, 7 insertions(+), 52 deletions(-) diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index d66f91f0831..ea86905c55b 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -679,7 +679,7 @@ def test_ivf_flat_over_binary_vector(tmp_path): def test_ivf_flat_respects_index_metric_binary(tmp_path): - # Binary vectors indexed with Hamming should ignore a user-specified L2 metric. + # Searching with binary vectors should default to hamming distance table = pa.Table.from_pydict( { "vector": pa.array([[0], [128], [255]], type=pa.list_(pa.uint8(), 1)), @@ -697,67 +697,22 @@ def test_ivf_flat_respects_index_metric_binary(tmp_path): query = np.array([128], dtype=np.uint8) - # Search should succeed and use the index's Hamming metric despite the L2 hint. - indexed = ds.to_table( + # Search should succeed and use the index's Hamming metric. + indexed = ds.scanner( columns=["id"], nearest={ "column": "vector", "q": query, "k": 3, - "metric": "l2", }, ) + plan = indexed.explain_plan() + indexed = indexed.to_table() # Should succeed even though user asked for L2 (index metric is used). assert indexed["id"].to_pylist() == [1, 0, 2] - - -def test_ivf_flat_respects_index_metric_float(tmp_path): - # Float vectors indexed with L2 should ignore a user-specified Hamming metric. 
- vectors = np.array( - [ - [0.0, 0.0], - [1.0, 0.0], - [0.0, 2.0], - ], - dtype=np.float32, - ) - table = pa.Table.from_pydict( - { - "vector": pa.array(vectors.tolist(), type=pa.list_(pa.float32(), 2)), - "id": pa.array([0, 1, 2], type=pa.int32()), - } - ) - - ds = lance.write_dataset(table, tmp_path) - ds = ds.create_index( - "vector", - index_type="IVF_FLAT", - num_partitions=1, - metric="l2", - ) - - query = np.array([0.5, 0.0], dtype=np.float32) - - indexed = ds.to_table( - columns=["id"], - nearest={ - "column": "vector", - "q": query, - "k": 3, - "metric": "hamming", - }, - ) - - expected = ds.to_table( - columns=["id"], - nearest={"column": "vector", "q": query, "k": 3}, - ) - - assert indexed["id"].to_pylist() == expected["id"].to_pylist() - assert np.allclose( - indexed["_distance"].to_numpy(), expected["_distance"].to_numpy() - ) + assert "metric=Hamming" in plan + assert "metric=L2" not in plan def test_bruteforce_uses_user_metric(tmp_path): From 82a5b8e6d789272ca804d1d5859135f17a89ba2b Mon Sep 17 00:00:00 2001 From: Will Jones Date: Fri, 2 Jan 2026 14:53:46 -0800 Subject: [PATCH 5/5] add log if there is a mismatch --- rust/lance/src/dataset/scanner.rs | 59 +++++++++++++++++++------------ 1 file changed, 37 insertions(+), 22 deletions(-) diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 918e26ed39f..c286fc48b16 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -3035,32 +3035,47 @@ impl Scanner { Arc::new(vec![]) }; // Find an index for the column and check if metric is compatible - let matching_index = - if let Some(index) = indices.iter().find(|i| i.fields.contains(&column_id)) { - let idx = self - .dataset - .open_vector_index( - q.column.as_str(), - &index.uuid.to_string(), - &NoOpMetricsCollector, - ) - .await?; - let index_metric = idx.metric_type(); - - // Check if user's requested metric is compatible with index - let use_this_index = match q.metric_type { - 
Some(user_metric) => user_metric == index_metric, - None => true, // No preference, use index's metric - }; + let matching_index = if let Some(index) = + indices.iter().find(|i| i.fields.contains(&column_id)) + { + // TODO: Once we do https://github.com/lance-format/lance/issues/5231, we + // should be able to get the metric type directly from the index metadata, + // at least for newer indexes. + let idx = self + .dataset + .open_vector_index( + q.column.as_str(), + &index.uuid.to_string(), + &NoOpMetricsCollector, + ) + .await?; + let index_metric = idx.metric_type(); - if use_this_index { - Some((index, idx, index_metric)) - } else { - None + // Check if user's requested metric is compatible with index + let use_this_index = match q.metric_type { + Some(user_metric) => { + if user_metric == index_metric { + true + } else { + log::warn!( + "Requested metric {:?} is incompatible with index metric {:?}, falling back to brute-force search", + user_metric, + index_metric + ); + false + } } + None => true, // No preference, use index's metric + }; + + if use_this_index { + Some((index, idx, index_metric)) } else { None - }; + } + } else { + None + }; if let Some((index, _idx, index_metric)) = matching_index { log::trace!("index found for vector search");