diff --git a/java/lance-jni/src/blocking_scanner.rs b/java/lance-jni/src/blocking_scanner.rs index 122824252cd..8c6bc402544 100644 --- a/java/lance-jni/src/blocking_scanner.rs +++ b/java/lance-jni/src/blocking_scanner.rs @@ -336,13 +336,12 @@ fn inner_create_scanner<'local>( scanner.refine(refine_factor); } - let distance_type_jstr: JString = env - .call_method(&java_obj, "getDistanceType", "()Ljava/lang/String;", &[])? - .l()? - .into(); - let distance_type_str: String = env.get_string(&distance_type_jstr)?.into(); - let distance_type = DistanceType::try_from(distance_type_str.as_str())?; - scanner.distance_metric(distance_type); + if let Some(distance_type_str) = + env.get_optional_string_from_method(&java_obj, "getDistanceTypeString")? + { + let distance_type = DistanceType::try_from(distance_type_str.as_str())?; + scanner.distance_metric(distance_type); + } let use_index = env.get_boolean_from_method(&java_obj, "isUseIndex")?; scanner.use_index(use_index); diff --git a/java/lance-jni/src/utils.rs b/java/lance-jni/src/utils.rs index 5cb55c200e1..03bcdf8f281 100644 --- a/java/lance-jni/src/utils.rs +++ b/java/lance-jni/src/utils.rs @@ -171,12 +171,13 @@ pub fn get_query(env: &mut JNIEnv, query_obj: JObject) -> Result> let refine_factor = env.get_optional_u32_from_method(&java_obj, "getRefineFactor")?; - let distance_type_jstr: JString = env - .call_method(&java_obj, "getDistanceType", "()Ljava/lang/String;", &[])? - .l()? - .into(); - let distance_type_str: String = env.get_string(&distance_type_jstr)?.into(); - let distance_type = DistanceType::try_from(distance_type_str.as_str())?; + let distance_type = if let Some(distance_type_str) = + env.get_optional_string_from_method(&java_obj, "getDistanceTypeString")? + { + Some(DistanceType::try_from(distance_type_str.as_str())?) + } else { + None + }; let use_index = env.get_boolean_from_method(&java_obj, "isUseIndex")?; diff --git a/java/src/main/java/org/lance/ipc/Query.java b/java/src/main/java/org/lance/ipc/Query.java index 6c51db1dde8..9bd2dc03b90 100644 --- a/java/src/main/java/org/lance/ipc/Query.java +++ b/java/src/main/java/org/lance/ipc/Query.java @@ -29,7 +29,7 @@ public class Query { private final Optional maximumNprobes; private final Optional ef; private final Optional refineFactor; - private final DistanceType distanceType; + private final Optional distanceType; private final boolean useIndex; private Query(Builder builder) { @@ -48,7 +48,7 @@ private Query(Builder builder) { this.maximumNprobes = builder.maximumNprobes; this.ef = builder.ef; this.refineFactor = builder.refineFactor; - this.distanceType = Preconditions.checkNotNull(builder.distanceType, "Metric type must be set"); + this.distanceType = builder.distanceType; this.useIndex = builder.useIndex; } @@ -80,8 +80,12 @@ public Optional getRefineFactor() { return refineFactor; } - public String getDistanceType() { - return distanceType.toString(); + public Optional getDistanceType() { + return distanceType; + } + + public Optional getDistanceTypeString() { + return distanceType.map(DistanceType::toString); } public boolean isUseIndex() { @@ -98,7 +102,7 @@ public String toString() { .add("maximumNprobes", maximumNprobes.orElse(null)) .add("ef", ef.orElse(null)) .add("refineFactor", refineFactor.orElse(null)) - .add("distanceType", distanceType) + .add("distanceType", distanceType.orElse(null)) .add("useIndex", useIndex) .toString(); } @@ -111,7 +115,7 @@ public static class Builder { private Optional maximumNprobes = Optional.empty(); private Optional ef = Optional.empty(); private Optional refineFactor = Optional.empty(); - private DistanceType distanceType = DistanceType.L2; + private Optional distanceType = Optional.empty(); private boolean useIndex = true; /** @@ -219,11 +223,14 @@ public Builder setRefineFactor(int refineFactor) { /** * Sets the distance metric type. * + *

If not set, the query will use the index's metric type (if an index is available), or the + * default metric for the data type (L2 for float vectors, Hamming for binary). + * * @param distanceType The DistanceType to use for the query. * @return The Builder instance for method chaining. */ public Builder setDistanceType(DistanceType distanceType) { - this.distanceType = distanceType; + this.distanceType = Optional.ofNullable(distanceType); return this; } diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index d66f91f0831..ea86905c55b 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -679,7 +679,7 @@ def test_ivf_flat_over_binary_vector(tmp_path): def test_ivf_flat_respects_index_metric_binary(tmp_path): - # Binary vectors indexed with Hamming should ignore a user-specified L2 metric. + # Searching with binary vectors should default to hamming distance table = pa.Table.from_pydict( { "vector": pa.array([[0], [128], [255]], type=pa.list_(pa.uint8(), 1)), @@ -697,67 +697,22 @@ def test_ivf_flat_respects_index_metric_binary(tmp_path): query = np.array([128], dtype=np.uint8) - # Search should succeed and use the index's Hamming metric despite the L2 hint. - indexed = ds.to_table( + # Search should succeed and use the index's Hamming metric. + indexed = ds.scanner( columns=["id"], nearest={ "column": "vector", "q": query, "k": 3, - "metric": "l2", }, ) + plan = indexed.explain_plan() + indexed = indexed.to_table() # Should succeed even though user asked for L2 (index metric is used). assert indexed["id"].to_pylist() == [1, 0, 2] - - -def test_ivf_flat_respects_index_metric_float(tmp_path): - # Float vectors indexed with L2 should ignore a user-specified Hamming metric. - vectors = np.array( - [ - [0.0, 0.0], - [1.0, 0.0], - [0.0, 2.0], - ], - dtype=np.float32, - ) - table = pa.Table.from_pydict( - { - "vector": pa.array(vectors.tolist(), type=pa.list_(pa.float32(), 2)), - "id": pa.array([0, 1, 2], type=pa.int32()), - } - ) - - ds = lance.write_dataset(table, tmp_path) - ds = ds.create_index( - "vector", - index_type="IVF_FLAT", - num_partitions=1, - metric="l2", - ) - - query = np.array([0.5, 0.0], dtype=np.float32) - - indexed = ds.to_table( - columns=["id"], - nearest={ - "column": "vector", - "q": query, - "k": 3, - "metric": "hamming", - }, - ) - - expected = ds.to_table( - columns=["id"], - nearest={"column": "vector", "q": query, "k": 3}, - ) - - assert indexed["id"].to_pylist() == expected["id"].to_pylist() - assert np.allclose( - indexed["_distance"].to_numpy(), expected["_distance"].to_numpy() - ) + assert "metric=Hamming" in plan + assert "metric=L2" not in plan def test_bruteforce_uses_user_metric(tmp_path): diff --git a/rust/lance-index/src/vector.rs b/rust/lance-index/src/vector.rs index 7871def65b6..c25bbe23cb4 100644 --- a/rust/lance-index/src/vector.rs +++ b/rust/lance-index/src/vector.rs @@ -104,8 +104,9 @@ pub struct Query { /// TODO: should we support fraction / float number here? pub refine_factor: Option, - /// Distance metric type - pub metric_type: DistanceType, + /// Distance metric type. If None, uses the index's metric (if available) + /// or the default for the data type. + pub metric_type: Option, /// Whether to use an ANN index if available pub use_index: bool, diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index fc5f636f5b1..c286fc48b16 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -1231,7 +1231,7 @@ impl Scanner { maximum_nprobes: None, ef: None, refine_factor: None, - metric_type: default_distance_type_for(&element_type), + metric_type: None, use_index: true, dist_q_c: 0.0, }); @@ -1361,7 +1361,7 @@ impl Scanner { /// Change the distance [MetricType], i.e, L2 or Cosine distance. pub fn distance_metric(&mut self, metric_type: MetricType) -> &mut Self { if let Some(q) = self.nearest.as_mut() { - q.metric_type = metric_type + q.metric_type = Some(metric_type) } self } @@ -3034,10 +3034,13 @@ impl Scanner { } else { Arc::new(vec![]) }; - if let Some(index) = indices.iter().find(|i| i.fields.contains(&column_id)) { - log::trace!("index found for vector search"); - // There is an index built for the column. - // We will use the index. + // Find an index for the column and check if metric is compatible + let matching_index = if let Some(index) = + indices.iter().find(|i| i.fields.contains(&column_id)) + { + // TODO: Once we do https://github.com/lance-format/lance/issues/5231, we + // should be able to get the metric type directly from the index metadata, + // at least for newer indexes. let idx = self .dataset .open_vector_index( @@ -3046,8 +3049,39 @@ impl Scanner { &NoOpMetricsCollector, ) .await?; - q.metric_type = idx.metric_type(); - validate_distance_type_for(q.metric_type, &element_type)?; + let index_metric = idx.metric_type(); + + // Check if user's requested metric is compatible with index + let use_this_index = match q.metric_type { + Some(user_metric) => { + if user_metric == index_metric { + true + } else { + log::warn!( + "Requested metric {:?} is incompatible with index metric {:?}, falling back to brute-force search", + user_metric, + index_metric + ); + false + } + } + None => true, // No preference, use index's metric + }; + + if use_this_index { + Some((index, idx, index_metric)) + } else { + None + } + } else { + None + }; + + if let Some((index, _idx, index_metric)) = matching_index { + log::trace!("index found for vector search"); + // Use the index's metric type + q.metric_type = Some(index_metric); + validate_distance_type_for(index_metric, &element_type)?; if matches!(q.refine_factor, Some(0)) { return Err(Error::invalid_input( @@ -3082,7 +3116,12 @@ impl Scanner { Ok(knn_node) } else { - validate_distance_type_for(q.metric_type, &element_type)?; + // Resolve metric type for flat search (use default if not specified) + let metric = q + .metric_type + .unwrap_or_else(|| default_distance_type_for(&element_type)); + q.metric_type = Some(metric); + validate_distance_type_for(metric, &element_type)?; // No index found. use flat search. let mut columns = vec![q.column.clone()]; if let Some(refine_expr) = filter_plan.refine_expr.as_ref() { @@ -3137,7 +3176,7 @@ impl Scanner { ) .await?; let mut q = q.clone(); - q.metric_type = idx.metric_type(); + q.metric_type = Some(idx.metric_type()); // If the vector column is not present, we need to take the vector column, so // that the distance value is comparable with the flat search ones. @@ -3589,11 +3628,19 @@ impl Scanner { /// Add a knn search node to the input plan fn flat_knn(&self, input: Arc, q: &Query) -> Result> { + // Resolve metric_type if not set (use default for the column's element type) + let metric_type = match q.metric_type { + Some(m) => m, + None => { + let (_, element_type) = get_vector_type(self.dataset.schema(), &q.column)?; + default_distance_type_for(&element_type) + } + }; let flat_dist = Arc::new(KNNVectorDistanceExec::try_new( input, &q.column, q.key.clone(), - q.metric_type, + metric_type, )?); let lower: Option<(Expr, Arc)> = q @@ -5122,10 +5169,8 @@ mod test { let mut scan = dataset.scan(); scan.nearest("bin", &query, 3).unwrap(); - assert_eq!( - scan.nearest.as_ref().unwrap().metric_type, - DistanceType::Hamming - ); + // metric_type is None initially; it will be resolved to Hamming during search + assert_eq!(scan.nearest.as_ref().unwrap().metric_type, None); let batch = scan.try_into_batch().await.unwrap(); let ids = batch @@ -5160,6 +5205,102 @@ mod test { ); } + /// Test that when query specifies a metric different from the index, + /// we fall back to flat search and return correct distances. + /// Regression test for https://github.com/lance-format/lance/issues/5608 + #[tokio::test] + async fn test_knn_metric_mismatch_falls_back_to_flat_search() { + let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, true) + .await + .unwrap(); + // Create IVF_PQ index with L2 metric + test_ds.make_vector_index().await.unwrap(); + + let dataset = &test_ds.dataset; + let key: Float32Array = (32..64).map(|v| v as f32).collect(); + + // Query with Dot metric (different from the L2 index) + let mut scan = dataset.scan(); + scan.nearest("vec", &key, 5).unwrap(); + scan.distance_metric(DistanceType::Dot); + + // Verify the explain plan does NOT show ANNSubIndex (should use flat search) + let plan = scan.explain_plan(false).await.unwrap(); + assert!( + !plan.contains("ANNSubIndex"), + "Expected flat search, but got ANN index in plan:\n{}", + plan + ); + // Should show flat KNN with Dot metric (metric is displayed lowercase) + assert!( + plan.contains("KNNVectorDistance") && plan.to_lowercase().contains("dot"), + "Expected flat KNN with Dot metric in plan:\n{}", + plan + ); + + // Also verify the distances are different from L2 results + let dot_batch = dataset + .scan() + .nearest("vec", &key, 5) + .unwrap() + .distance_metric(DistanceType::Dot) + .try_into_batch() + .await + .unwrap(); + + let l2_batch = dataset + .scan() + .nearest("vec", &key, 5) + .unwrap() + .distance_metric(DistanceType::L2) + .try_into_batch() + .await + .unwrap(); + + let dot_distances: Vec = dot_batch + .column_by_name(DIST_COL) + .unwrap() + .as_primitive::() + .values() + .to_vec(); + let l2_distances: Vec = l2_batch + .column_by_name(DIST_COL) + .unwrap() + .as_primitive::() + .values() + .to_vec(); + + // Dot and L2 distances should be different (this verifies we're using the correct metric) + assert_ne!(dot_distances, l2_distances); + } + + /// Test that when query does not specify a metric, we use the index's metric. + /// Regression test for https://github.com/lance-format/lance/issues/5608 + #[tokio::test] + async fn test_knn_no_metric_uses_index_metric() { + let mut test_ds = TestVectorDataset::new(LanceFileVersion::Stable, true) + .await + .unwrap(); + // Create IVF_PQ index with L2 metric + test_ds.make_vector_index().await.unwrap(); + + let dataset = &test_ds.dataset; + let key: Float32Array = (32..64).map(|v| v as f32).collect(); + + // Query without specifying metric + let mut scan = dataset.scan(); + scan.nearest("vec", &key, 5).unwrap(); + // Don't call distance_metric() - should use index's L2 + + // Verify the explain plan shows ANNSubIndex with L2 metric + let plan = scan.explain_plan(false).await.unwrap(); + assert!( + plan.contains("ANNSubIndex") && plan.to_lowercase().contains("l2"), + "Expected ANN index with L2 metric in plan:\n{}", + plan + ); + } + #[rstest] #[tokio::test] async fn test_only_row_id( @@ -7251,7 +7392,7 @@ mod test { Take: columns=\"_distance, _rowid, (i), (s), (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=42), expr=... - ANNSubIndex: name=..., k=42, deltas=1 + ANNSubIndex: name=..., k=42, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, @@ -7271,7 +7412,7 @@ mod test { Take: columns=\"_distance, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=40), expr=... - ANNSubIndex: name=..., k=40, deltas=1 + ANNSubIndex: name=..., k=40, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, @@ -7315,7 +7456,7 @@ mod test { Take: columns=\"_distance, _rowid, (i)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=17), expr=... - ANNSubIndex: name=..., k=17, deltas=1 + ANNSubIndex: name=..., k=17, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, @@ -7336,7 +7477,7 @@ mod test { Take: columns=\"_distance, _rowid, (i), (s), (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=17), expr=... - ANNSubIndex: name=..., k=17, deltas=1 + ANNSubIndex: name=..., k=17, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 FilterExec: i@0 > 10 LanceScan: uri=..., projection=[i], row_id=true, row_addr=false, ordered=false, range=None" @@ -7345,7 +7486,7 @@ mod test { Take: columns=\"_distance, _rowid, (i), (s), (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=17), expr=... - ANNSubIndex: name=..., k=17, deltas=1 + ANNSubIndex: name=..., k=17, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, \ row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=i > Int32(10) @@ -7381,7 +7522,7 @@ mod test { Take: columns=\"_distance, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=6), expr=... - ANNSubIndex: name=..., k=6, deltas=1 + ANNSubIndex: name=..., k=6, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, @@ -7413,7 +7554,7 @@ mod test { Take: columns=\"_distance, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=15), expr=... - ANNSubIndex: name=..., k=15, deltas=1 + ANNSubIndex: name=..., k=15, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1"; assert_plan_equals( &dataset.dataset, @@ -7442,7 +7583,7 @@ mod test { Take: columns=\"_distance, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... - ANNSubIndex: name=..., k=5, deltas=1 + ANNSubIndex: name=..., k=5, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 FilterExec: i@0 > 10 LanceScan: uri=..., projection=[i], row_id=true, row_addr=false, ordered=false, range=None" @@ -7464,7 +7605,7 @@ mod test { Take: columns=\"_distance, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... - ANNSubIndex: name=..., k=5, deltas=1 + ANNSubIndex: name=..., k=5, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 LanceRead: uri=..., projection=[], num_fragments=2, range_before=None, range_after=None, \ row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=i > Int32(10)" @@ -7495,7 +7636,7 @@ mod test { Take: columns=\"_distance, _rowid, (i), (s), (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... - ANNSubIndex: name=..., k=5, deltas=1 + ANNSubIndex: name=..., k=5, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 ScalarIndexQuery: query=[i > 10]@i_idx"; assert_plan_equals( @@ -7516,7 +7657,7 @@ mod test { Take: columns=\"_distance, _rowid, (i), (s), (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... - ANNSubIndex: name=..., k=5, deltas=1 + ANNSubIndex: name=..., k=5, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 FilterExec: i@0 > 10 LanceScan: uri=..., projection=[i], row_id=true, row_addr=false, ordered=false, range=None" @@ -7525,7 +7666,7 @@ mod test { Take: columns=\"_distance, _rowid, (i), (s), (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=5), expr=... - ANNSubIndex: name=..., k=5, deltas=1 + ANNSubIndex: name=..., k=5, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 LanceRead: uri=..., projection=[], num_fragments=3, range_before=None, \ range_after=None, row_id=true, row_addr=false, full_filter=i > Int32(10), refine_filter=i > Int32(10)" @@ -7563,7 +7704,7 @@ mod test { Take: columns=\"_distance, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=8), expr=... - ANNSubIndex: name=..., k=8, deltas=1 + ANNSubIndex: name=..., k=8, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 ScalarIndexQuery: query=[i > 10]@i_idx"; assert_plan_equals( @@ -7599,7 +7740,7 @@ mod test { Take: columns=\"_distance, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=11), expr=... - ANNSubIndex: name=..., k=11, deltas=1 + ANNSubIndex: name=..., k=11, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1 ScalarIndexQuery: query=[i > 10]@i_idx"; dataset.make_scalar_index().await?; @@ -7939,7 +8080,7 @@ mod test { .project(&["_distance", "_rowid"]) }, "SortExec: TopK(fetch=32), expr=[_distance@0 ASC NULLS LAST, _rowid@1 ASC NULLS LAST]... - ANNSubIndex: name=idx, k=32, deltas=1 + ANNSubIndex: name=idx, k=32, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1", ) .await @@ -7954,7 +8095,7 @@ mod test { .project(&["_distance", "_rowid"]) }, "SortExec: TopK(fetch=33), expr=[_distance@0 ASC NULLS LAST, _rowid@1 ASC NULLS LAST]... - ANNSubIndex: name=idx, k=33, deltas=1 + ANNSubIndex: name=idx, k=33, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1", ) .await @@ -7982,7 +8123,7 @@ mod test { Take: columns=\"_distance, _rowid, (vec)\" CoalesceBatchesExec: target_batch_size=8192 SortExec: TopK(fetch=34), expr=[_distance@0 ASC NULLS LAST, _rowid@1 ASC NULLS LAST]... - ANNSubIndex: name=idx, k=34, deltas=1 + ANNSubIndex: name=idx, k=34, deltas=1, metric=L2 ANNIvfPartition: uuid=..., minimum_nprobes=1, maximum_nprobes=None, deltas=1", ) .await diff --git a/rust/lance/src/dataset/tests/dataset_scanner.rs b/rust/lance/src/dataset/tests/dataset_scanner.rs index 7f5caff3908..9fce5f6d2ca 100644 --- a/rust/lance/src/dataset/tests/dataset_scanner.rs +++ b/rust/lance/src/dataset/tests/dataset_scanner.rs @@ -45,7 +45,7 @@ async fn test_vector_filter_fts_search() { maximum_nprobes: None, ef: None, refine_factor: None, - metric_type: MetricType::L2, + metric_type: Some(MetricType::L2), use_index: true, dist_q_c: 0.0, }; diff --git a/rust/lance/src/index/vector/fixture_test.rs b/rust/lance/src/index/vector/fixture_test.rs index 6316e88d898..3445a3cd5d4 100644 --- a/rust/lance/src/index/vector/fixture_test.rs +++ b/rust/lance/src/index/vector/fixture_test.rs @@ -264,7 +264,7 @@ mod test { maximum_nprobes: None, ef: None, refine_factor: None, - metric_type: metric, + metric_type: Some(metric), use_index: true, dist_q_c: 0.0, }; diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 8a590ea8513..d42c06e1638 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -2177,7 +2177,7 @@ mod tests { maximum_nprobes: None, ef: None, refine_factor: None, - metric_type: MetricType::L2, + metric_type: Some(MetricType::L2), use_index: true, dist_q_c: 0.0, }; diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index 8c62541a519..111cb71e4cc 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -643,23 +643,30 @@ impl ANNIvfSubIndexExec { impl DisplayAs for ANNIvfSubIndexExec { fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result { + let metric_str = self + .query + .metric_type + .map(|m| format!("{:?}", m)) + .unwrap_or_else(|| "default".to_string()); match t { DisplayFormatType::Default | DisplayFormatType::Verbose => { write!( f, - "ANNSubIndex: name={}, k={}, deltas={}", + "ANNSubIndex: name={}, k={}, deltas={}, metric={}", self.indices[0].name, self.query.k * self.query.refine_factor.unwrap_or(1) as usize, - self.indices.len() + self.indices.len(), + metric_str ) } DisplayFormatType::TreeRender => { write!( f, - "ANNSubIndex\nname={}\nk={}\ndeltas={}", + "ANNSubIndex\nname={}\nk={}\ndeltas={}\nmetric={}", self.indices[0].name, self.query.k * self.query.refine_factor.unwrap_or(1) as usize, - self.indices.len() + self.indices.len(), + metric_str ) } } @@ -1375,7 +1382,7 @@ mod tests { maximum_nprobes: None, ef: None, refine_factor: None, - metric_type: DistanceType::L2, + metric_type: Some(DistanceType::L2), use_index: true, dist_q_c: 0.0, } @@ -1552,7 +1559,7 @@ mod tests { maximum_nprobes: None, ef: None, refine_factor: None, - metric_type: DistanceType::Cosine, + metric_type: Some(DistanceType::Cosine), use_index: true, dist_q_c: 0.0, };