diff --git a/python/python/tests/test_vector.py b/python/python/tests/test_vector.py
index ffa1428a7ea..c02c8312f88 100644
--- a/python/python/tests/test_vector.py
+++ b/python/python/tests/test_vector.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright The Lance Authors
 
+import lance
 import numpy as np
 import pyarrow as pa
 import pytest
@@ -92,3 +93,47 @@ def _to_vec(lst):
     return pa.FixedSizeListArray.from_arrays(
         pa.array(np.array(lst).ravel(), type=pa.float32()), list_size=8
     )
+
+
+def _binary_vectors_table():
+    """Build a three-row table of 4-byte binary vectors keyed by an int32 id.
+
+    Row 0 has four bits set (0x0F), row 1 has two bits set (0x03) and row 2
+    has none.
+    """
+    rows = [
+        [0x0F, 0, 0, 0],
+        [0x03, 0, 0, 0],
+        [0, 0, 0, 0],
+    ]
+    flat = pa.array([byte for row in rows for byte in row], type=pa.uint8())
+    vectors = pa.FixedSizeListArray.from_arrays(flat, list_size=4)
+    ids = pa.array(list(range(len(rows))), type=pa.int32())
+    return pa.Table.from_arrays([ids, vectors], names=["id", "vector"])
+
+
+def test_binary_vectors_default_hamming(tmp_path):
+    """A uint8 vector column is searched with Hamming distance by default."""
+    dataset = lance.write_dataset(_binary_vectors_table(), tmp_path / "bin")
+    query = {"column": "vector", "q": [0x0F, 0, 0, 0], "k": 3}
+    scanner = dataset.scanner(nearest=query)
+
+    assert "metric=hamming" in scanner.analyze_plan()
+
+    result = scanner.to_table()
+    assert result["id"].to_pylist() == [0, 1, 2]
+    assert result["_distance"].to_pylist() == [0.0, 2.0, 4.0]
+
+
+def test_binary_vectors_invalid_metric(tmp_path):
+    """Asking for the l2 metric on a uint8 vector column raises ValueError."""
+    dataset = lance.write_dataset(_binary_vectors_table(), tmp_path / "bin")
+    query = {
+        "column": "vector",
+        "q": [0x0F, 0, 0, 0],
+        "k": 1,
+        "metric": "l2",
+    }
+    expected = "Distance type l2 does not support .*UInt8 vectors"
+    with pytest.raises(ValueError, match=expected):
+        dataset.scanner(nearest=query).to_table()
diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs
index 96af9996887..8743c9e1c06 100644
--- a/rust/lance/src/dataset/scanner.rs
+++ b/rust/lance/src/dataset/scanner.rs
@@ -76,7 +76,9 @@ use tracing::{info_span, instrument, Span};
 use super::Dataset;
 use crate::dataset::row_offsets_to_row_addresses;
 use
crate::dataset::utils::SchemaAdapter; -use crate::index::vector::utils::{get_vector_dim, get_vector_type}; +use crate::index::vector::utils::{ + default_distance_type_for, get_vector_dim, get_vector_type, validate_distance_type_for, +}; use crate::index::DatasetIndexInternalExt; use crate::io::exec::filtered_read::{FilteredReadExec, FilteredReadOptions}; use crate::io::exec::fts::{BoostQueryExec, FlatMatchQueryExec, MatchQueryExec, PhraseQueryExec}; @@ -1048,11 +1050,11 @@ impl Scanner { } }; - let key = match element_type { - dt if dt == *q.data_type() => q, + let key = match &element_type { + dt if dt == q.data_type() => q, dt if dt.is_floating() => coerce_float_vector( q.as_any().downcast_ref::().unwrap(), - FloatType::try_from(&dt)?, + FloatType::try_from(dt)?, )?, _ => { return Err(Error::invalid_input( @@ -1077,7 +1079,7 @@ impl Scanner { maximum_nprobes: None, ef: None, refine_factor: None, - metric_type: MetricType::L2, + metric_type: default_distance_type_for(&element_type), use_index: true, dist_q_c: 0.0, }); @@ -2847,7 +2849,8 @@ impl Scanner { }; // Sanity check - let (vector_type, _) = get_vector_type(self.dataset.schema(), &q.column)?; + let (vector_type, element_type) = get_vector_type(self.dataset.schema(), &q.column)?; + validate_distance_type_for(q.metric_type, &element_type)?; let column_id = self.dataset.schema().field_id(q.column.as_str())?; let use_index = self.nearest.as_ref().map(|q| q.use_index).unwrap_or(false); @@ -3901,7 +3904,7 @@ mod test { use arrow_array::types::{Float32Type, UInt64Type}; use arrow_array::{ ArrayRef, FixedSizeListArray, Float16Array, Int32Array, LargeStringArray, PrimitiveArray, - RecordBatchIterator, StringArray, StructArray, + RecordBatchIterator, StringArray, StructArray, UInt8Array, }; use arrow_ord::sort::sort_to_indices; @@ -3909,7 +3912,7 @@ mod test { use arrow_select::take; use datafusion::logical_expr::{col, lit}; use half::f16; - use lance_arrow::SchemaExt; + use lance_arrow::{FixedSizeListArrayExt, 
SchemaExt};
     use lance_core::utils::tempfile::TempStrDir;
     use lance_core::{ROW_CREATED_AT_VERSION, ROW_LAST_UPDATED_AT_VERSION};
     use lance_datagen::{
@@ -3941,6 +3944,53 @@ mod test {
         assert_plan_node_equals, DatagenExt, FragmentCount, FragmentRowCount, ThrottledStoreWrapper,
     };
 
+    /// Writes a three-row dataset of binary vectors into a fresh temporary
+    /// directory and opens it.
+    ///
+    /// The `id` column holds 0, 1 and 2. The `bin` column holds 4-byte `UInt8`
+    /// vectors with four, two and zero bits set respectively. The temporary
+    /// directory is returned together with the dataset.
+    async fn make_binary_vector_dataset() -> Result<(TempStrDir, Dataset)> {
+        let tmp_dir = TempStrDir::default();
+        let dim = 4;
+        let schema = Arc::new(ArrowSchema::new(vec![
+            ArrowField::new("id", DataType::Int32, false),
+            ArrowField::new(
+                "bin",
+                DataType::FixedSizeList(
+                    Arc::new(ArrowField::new("item", DataType::UInt8, true)),
+                    dim,
+                ),
+                false,
+            ),
+        ]));
+
+        let vectors = FixedSizeListArray::try_new_from_values(
+            UInt8Array::from(vec![
+                0b0000_1111u8,
+                0,
+                0,
+                0, //
+                0b0000_0011u8,
+                0,
+                0,
+                0, //
+                0u8,
+                0,
+                0,
+                0,
+            ]),
+            dim,
+        )?;
+        let ids = Int32Array::from(vec![0, 1, 2]);
+
+        let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)])?;
+        let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
+        Dataset::write(reader, &tmp_dir, None).await?;
+        let dataset = Dataset::open(&tmp_dir).await?;
+        Ok((tmp_dir, dataset))
+    }
+
     #[tokio::test]
     async fn test_batch_size() {
         let schema = Arc::new(ArrowSchema::new(vec![
@@ -4767,6 +4817,57 @@ mod test {
         assert_eq!(expected_i, actual_i);
     }
 
+    #[tokio::test]
+    async fn test_binary_vectors_default_to_hamming() {
+        let (_tmp_dir, dataset) = make_binary_vector_dataset().await.unwrap();
+        let query = UInt8Array::from(vec![0b0000_1111u8, 0, 0, 0]);
+
+        let mut scan = dataset.scan();
+        scan.nearest("bin", &query, 3).unwrap();
+
+        // No metric was requested, so the scanner must have picked the default
+        // for UInt8 vectors.
+        assert_eq!(
+            scan.nearest.as_ref().unwrap().metric_type,
+            DistanceType::Hamming
+        );
+
+        let batch = scan.try_into_batch().await.unwrap();
+        let ids = batch
+            .column_by_name("id")
+            .unwrap()
+            .as_primitive::<arrow_array::types::Int32Type>()
+            .values();
+        assert_eq!(ids, &[0, 1, 2]);
+        // The query equals row 0 and differs from rows 1 and 2 in 2 and 4 bits.
+        let distances = batch
+            .column_by_name(DIST_COL)
+            .unwrap()
+            .as_primitive::<Float32Type>()
+            .values();
+        assert_eq!(distances, &[0.0, 2.0, 4.0]);
+    }
+
+    #[tokio::test]
+    async fn test_binary_vectors_invalid_distance_error() {
+        let (_tmp_dir, dataset) = make_binary_vector_dataset().await.unwrap();
+        let query = UInt8Array::from(vec![0b0000_1111u8, 0, 0, 0]);
+
+        let mut scan = dataset.scan();
+        scan.nearest("bin", &query, 1).unwrap();
+        // Replace the default metric with one UInt8 vectors do not support.
+        scan.distance_metric(DistanceType::L2);
+
+        // The unsupported combination is reported when the scan runs.
+        let err = scan.try_into_batch().await.unwrap_err();
+        assert!(matches!(err, Error::InvalidInput { .. }));
+        let message = err.to_string();
+        assert!(
+            message.contains("l2") && message.contains("UInt8"),
+            "unexpected message: {message}"
+        );
+    }
+
     #[rstest]
     #[tokio::test]
     async fn test_only_row_id(
diff --git a/rust/lance/src/index/vector/utils.rs b/rust/lance/src/index/vector/utils.rs
index 4c30a399ce5..8b1a000fb1b 100644
--- a/rust/lance/src/index/vector/utils.rs
+++ b/rust/lance/src/index/vector/utils.rs
@@ -7,6 +7,7 @@ use arrow_array::{cast::AsArray, ArrayRef, FixedSizeListArray, RecordBatch};
 use futures::StreamExt;
 use lance_arrow::{interleave_batches, DataTypeExt};
 use lance_core::datatypes::Schema;
+use lance_linalg::distance::DistanceType;
 use log::info;
 use rand::rngs::SmallRng;
 use rand::seq::{IteratorRandom, SliceRandom};
@@ -122,6 +123,59 @@ pub fn get_vector_type(
     ))
 }
 
+/// Returns the default distance type for the given vector element type.
+///
+/// `UInt8` vectors default to [`DistanceType::Hamming`]; every other element
+/// type defaults to [`DistanceType::L2`].
+pub fn default_distance_type_for(element_type: &arrow_schema::DataType) -> DistanceType {
+    match element_type {
+        arrow_schema::DataType::UInt8 => DistanceType::Hamming,
+        _ => DistanceType::L2,
+    }
+}
+
+/// Validate that the distance type is supported by the vector element type.
+///
+/// * `UInt8` vectors accept only [`DistanceType::Hamming`].
+/// * `Int8`, `Float16`, `Float32` and `Float64` vectors accept
+///   [`DistanceType::L2`], [`DistanceType::Cosine`] and [`DistanceType::Dot`].
+/// * Every other element type is rejected regardless of the distance type.
+///
+/// # Errors
+///
+/// Returns an invalid-input error naming the distance type and the element
+/// type when the combination is not supported.
+pub fn validate_distance_type_for(
+    distance_type: DistanceType,
+    element_type: &arrow_schema::DataType,
+) -> Result<()> {
+    let supported = match element_type {
+        arrow_schema::DataType::UInt8 => matches!(distance_type, DistanceType::Hamming),
+        arrow_schema::DataType::Int8
+        | arrow_schema::DataType::Float16
+        | arrow_schema::DataType::Float32
+        | arrow_schema::DataType::Float64 => {
+            matches!(
+                distance_type,
+                DistanceType::L2 | DistanceType::Cosine | DistanceType::Dot
+            )
+        }
+        _ => false,
+    };
+
+    if supported {
+        Ok(())
+    } else {
+        Err(Error::invalid_input(
+            format!(
+                "Distance type {} does not support {} vectors",
+                distance_type, element_type
+            ),
+            location!(),
+        ))
+    }
+}
+
 /// If the data type is a fixed size list or list of fixed size list return the inner element type
 /// and verify it is a type we can create a vector index on.
 ///
diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs
index adea72118f9..010f0f48515 100644
--- a/rust/lance/src/io/exec/knn.rs
+++ b/rust/lance/src/io/exec/knn.rs
@@ -55,7 +55,7 @@ use tokio::sync::Notify;
 
 use crate::dataset::Dataset;
 use crate::index::prefilter::{DatasetPreFilter, FilterLoader};
-use crate::index::vector::utils::get_vector_type;
+use crate::index::vector::utils::{get_vector_type, validate_distance_type_for};
 use crate::index::DatasetIndexInternalExt;
 use crate::{Error, Result};
 use lance_arrow::*;
@@ -146,7 +146,8 @@ impl KNNVectorDistanceExec {
         distance_type: DistanceType,
     ) -> Result<Self> {
         let mut output_schema = input.schema().as_ref().clone();
-        get_vector_type(&(&output_schema).try_into()?, column)?;
+        let (_, element_type) = get_vector_type(&(&output_schema).try_into()?, column)?;
+        validate_distance_type_for(distance_type, &element_type)?;
 
         // FlatExec appends a distance column to the input schema. The input
         // may already have a distance column (possibly in the wrong position), so