Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions python/python/tests/test_vector.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright The Lance Authors

import lance
import numpy as np
import pyarrow as pa
import pytest
Expand Down Expand Up @@ -92,3 +93,57 @@ def _to_vec(lst):
return pa.FixedSizeListArray.from_arrays(
pa.array(np.array(lst).ravel(), type=pa.float32()), list_size=8
)


def _binary_vectors_table():
vectors = pa.FixedSizeListArray.from_arrays(
pa.array(
[
0x0F,
0,
0,
0,
0x03,
0,
0,
0,
0,
0,
0,
0,
],
type=pa.uint8(),
),
list_size=4,
)
ids = pa.array([0, 1, 2], type=pa.int32())
return pa.Table.from_arrays([ids, vectors], names=["id", "vector"])


def test_binary_vectors_default_hamming(tmp_path):
dataset = lance.write_dataset(_binary_vectors_table(), tmp_path / "bin")
scanner = dataset.scanner(
nearest={"column": "vector", "q": [0x0F, 0, 0, 0], "k": 3}
)

plan = scanner.analyze_plan()
assert "metric=hamming" in plan

tbl = scanner.to_table()
assert tbl["id"].to_pylist() == [0, 1, 2]
assert tbl["_distance"].to_pylist() == [0.0, 2.0, 4.0]


def test_binary_vectors_invalid_metric(tmp_path):
dataset = lance.write_dataset(_binary_vectors_table(), tmp_path / "bin")
with pytest.raises(
ValueError, match="Distance type l2 does not support .*UInt8 vectors"
):
Comment thread
BubbleCal marked this conversation as resolved.
dataset.scanner(
nearest={
"column": "vector",
"q": [0x0F, 0, 0, 0],
"k": 1,
"metric": "l2",
}
).to_table()
106 changes: 98 additions & 8 deletions rust/lance/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ use tracing::{info_span, instrument, Span};
use super::Dataset;
use crate::dataset::row_offsets_to_row_addresses;
use crate::dataset::utils::SchemaAdapter;
use crate::index::vector::utils::{get_vector_dim, get_vector_type};
use crate::index::vector::utils::{
default_distance_type_for, get_vector_dim, get_vector_type, validate_distance_type_for,
};
use crate::index::DatasetIndexInternalExt;
use crate::io::exec::filtered_read::{FilteredReadExec, FilteredReadOptions};
use crate::io::exec::fts::{BoostQueryExec, FlatMatchQueryExec, MatchQueryExec, PhraseQueryExec};
Expand Down Expand Up @@ -1048,11 +1050,11 @@ impl Scanner {
}
};

let key = match element_type {
dt if dt == *q.data_type() => q,
let key = match &element_type {
dt if dt == q.data_type() => q,
dt if dt.is_floating() => coerce_float_vector(
q.as_any().downcast_ref::<Float32Array>().unwrap(),
FloatType::try_from(&dt)?,
FloatType::try_from(dt)?,
)?,
_ => {
return Err(Error::invalid_input(
Expand All @@ -1077,7 +1079,7 @@ impl Scanner {
maximum_nprobes: None,
ef: None,
refine_factor: None,
metric_type: MetricType::L2,
metric_type: default_distance_type_for(&element_type),
use_index: true,
dist_q_c: 0.0,
});
Expand Down Expand Up @@ -2847,7 +2849,8 @@ impl Scanner {
};

// Sanity check
let (vector_type, _) = get_vector_type(self.dataset.schema(), &q.column)?;
let (vector_type, element_type) = get_vector_type(self.dataset.schema(), &q.column)?;
validate_distance_type_for(q.metric_type, &element_type)?;

let column_id = self.dataset.schema().field_id(q.column.as_str())?;
let use_index = self.nearest.as_ref().map(|q| q.use_index).unwrap_or(false);
Expand Down Expand Up @@ -3901,15 +3904,15 @@ mod test {
use arrow_array::types::{Float32Type, UInt64Type};
use arrow_array::{
ArrayRef, FixedSizeListArray, Float16Array, Int32Array, LargeStringArray, PrimitiveArray,
RecordBatchIterator, StringArray, StructArray,
RecordBatchIterator, StringArray, StructArray, UInt8Array,
};

use arrow_ord::sort::sort_to_indices;
use arrow_schema::Fields;
use arrow_select::take;
use datafusion::logical_expr::{col, lit};
use half::f16;
use lance_arrow::SchemaExt;
use lance_arrow::{FixedSizeListArrayExt, SchemaExt};
use lance_core::utils::tempfile::TempStrDir;
use lance_core::{ROW_CREATED_AT_VERSION, ROW_LAST_UPDATED_AT_VERSION};
use lance_datagen::{
Expand Down Expand Up @@ -3941,6 +3944,47 @@ mod test {
assert_plan_node_equals, DatagenExt, FragmentCount, FragmentRowCount, ThrottledStoreWrapper,
};

async fn make_binary_vector_dataset() -> Result<(TempStrDir, Dataset)> {
let tmp_dir = TempStrDir::default();
let dim = 4;
let schema = Arc::new(ArrowSchema::new(vec![
ArrowField::new("id", DataType::Int32, false),
ArrowField::new(
"bin",
DataType::FixedSizeList(
Arc::new(ArrowField::new("item", DataType::UInt8, true)),
dim,
),
false,
),
]));

let vectors = FixedSizeListArray::try_new_from_values(
UInt8Array::from(vec![
0b0000_1111u8,
0,
0,
0, //
0b0000_0011u8,
0,
0,
0, //
0u8,
0,
0,
0,
]),
dim,
)?;
let ids = Int32Array::from(vec![0, 1, 2]);

let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(ids), Arc::new(vectors)])?;
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
Dataset::write(reader, &tmp_dir, None).await?;
let dataset = Dataset::open(&tmp_dir).await?;
Ok((tmp_dir, dataset))
}

#[tokio::test]
async fn test_batch_size() {
let schema = Arc::new(ArrowSchema::new(vec![
Expand Down Expand Up @@ -4767,6 +4811,52 @@ mod test {
assert_eq!(expected_i, actual_i);
}

#[tokio::test]
async fn test_binary_vectors_default_to_hamming() {
let (_tmp_dir, dataset) = make_binary_vector_dataset().await.unwrap();
let query = UInt8Array::from(vec![0b0000_1111u8, 0, 0, 0]);

let mut scan = dataset.scan();
scan.nearest("bin", &query, 3).unwrap();

assert_eq!(
scan.nearest.as_ref().unwrap().metric_type,
DistanceType::Hamming
);

let batch = scan.try_into_batch().await.unwrap();
let ids = batch
.column_by_name("id")
.unwrap()
.as_primitive::<Int32Type>()
.values();
assert_eq!(ids, &[0, 1, 2]);
let distances = batch
.column_by_name(DIST_COL)
.unwrap()
.as_primitive::<Float32Type>()
.values();
assert_eq!(distances, &[0.0, 2.0, 4.0]);
}

#[tokio::test]
async fn test_binary_vectors_invalid_distance_error() {
let (_tmp_dir, dataset) = make_binary_vector_dataset().await.unwrap();
let query = UInt8Array::from(vec![0b0000_1111u8, 0, 0, 0]);

let mut scan = dataset.scan();
scan.nearest("bin", &query, 1).unwrap();
scan.distance_metric(DistanceType::L2);

let err = scan.try_into_batch().await.unwrap_err();
assert!(matches!(err, Error::InvalidInput { .. }));
let message = err.to_string();
assert!(
message.contains("l2") && message.contains("UInt8"),
"unexpected message: {message}"
);
}

#[rstest]
#[tokio::test]
async fn test_only_row_id(
Expand Down
41 changes: 41 additions & 0 deletions rust/lance/src/index/vector/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use arrow_array::{cast::AsArray, ArrayRef, FixedSizeListArray, RecordBatch};
use futures::StreamExt;
use lance_arrow::{interleave_batches, DataTypeExt};
use lance_core::datatypes::Schema;
use lance_linalg::distance::DistanceType;
use log::info;
use rand::rngs::SmallRng;
use rand::seq::{IteratorRandom, SliceRandom};
Expand Down Expand Up @@ -122,6 +123,46 @@ pub fn get_vector_type(
))
}

/// Returns the default distance type for the given vector element type.
pub fn default_distance_type_for(element_type: &arrow_schema::DataType) -> DistanceType {
match element_type {
arrow_schema::DataType::UInt8 => DistanceType::Hamming,
_ => DistanceType::L2,
}
}

/// Validate that the distance type is supported by the vector element type.
pub fn validate_distance_type_for(
distance_type: DistanceType,
element_type: &arrow_schema::DataType,
) -> Result<()> {
let supported = match element_type {
arrow_schema::DataType::UInt8 => matches!(distance_type, DistanceType::Hamming),
arrow_schema::DataType::Int8
| arrow_schema::DataType::Float16
| arrow_schema::DataType::Float32
| arrow_schema::DataType::Float64 => {
matches!(
distance_type,
DistanceType::L2 | DistanceType::Cosine | DistanceType::Dot
)
}
_ => false,
};

if supported {
Ok(())
} else {
Err(Error::invalid_input(
format!(
"Distance type {} does not support {} vectors",
distance_type, element_type
),
location!(),
))
}
}

/// If the data type is a fixed size list or list of fixed size list return the inner element type
/// and verify it is a type we can create a vector index on.
///
Expand Down
5 changes: 3 additions & 2 deletions rust/lance/src/io/exec/knn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ use tokio::sync::Notify;

use crate::dataset::Dataset;
use crate::index::prefilter::{DatasetPreFilter, FilterLoader};
use crate::index::vector::utils::get_vector_type;
use crate::index::vector::utils::{get_vector_type, validate_distance_type_for};
use crate::index::DatasetIndexInternalExt;
use crate::{Error, Result};
use lance_arrow::*;
Expand Down Expand Up @@ -146,7 +146,8 @@ impl KNNVectorDistanceExec {
distance_type: DistanceType,
) -> Result<Self> {
let mut output_schema = input.schema().as_ref().clone();
get_vector_type(&(&output_schema).try_into()?, column)?;
let (_, element_type) = get_vector_type(&(&output_schema).try_into()?, column)?;
validate_distance_type_for(distance_type, &element_type)?;

// FlatExec appends a distance column to the input schema. The input
// may already have a distance column (possibly in the wrong position), so
Expand Down
Loading