diff --git a/Cargo.lock b/Cargo.lock index 2e68786d957..9f26e238541 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2302,7 +2302,7 @@ dependencies = [ [[package]] name = "fsst" -version = "0.20.1" +version = "0.21.0" dependencies = [ "arrow-array", "lance-datagen", @@ -3002,7 +3002,7 @@ dependencies = [ [[package]] name = "lance" -version = "0.20.1" +version = "0.21.0" dependencies = [ "all_asserts", "approx", @@ -3082,7 +3082,7 @@ dependencies = [ [[package]] name = "lance-arrow" -version = "0.20.1" +version = "0.21.0" dependencies = [ "arrow-array", "arrow-buffer", @@ -3099,7 +3099,7 @@ dependencies = [ [[package]] name = "lance-core" -version = "0.20.1" +version = "0.21.0" dependencies = [ "arrow-array", "arrow-buffer", @@ -3138,7 +3138,7 @@ dependencies = [ [[package]] name = "lance-datafusion" -version = "0.20.1" +version = "0.21.0" dependencies = [ "arrow", "arrow-array", @@ -3166,7 +3166,7 @@ dependencies = [ [[package]] name = "lance-datagen" -version = "0.20.1" +version = "0.21.0" dependencies = [ "arrow", "arrow-array", @@ -3183,7 +3183,7 @@ dependencies = [ [[package]] name = "lance-encoding" -version = "0.20.1" +version = "0.21.0" dependencies = [ "arrayref", "arrow", @@ -3229,7 +3229,7 @@ dependencies = [ [[package]] name = "lance-encoding-datafusion" -version = "0.20.1" +version = "0.21.0" dependencies = [ "arrow-array", "arrow-buffer", @@ -3261,7 +3261,7 @@ dependencies = [ [[package]] name = "lance-file" -version = "0.20.1" +version = "0.21.0" dependencies = [ "arrow-arith", "arrow-array", @@ -3303,7 +3303,7 @@ dependencies = [ [[package]] name = "lance-index" -version = "0.20.1" +version = "0.21.0" dependencies = [ "approx", "arrow", @@ -3362,7 +3362,7 @@ dependencies = [ [[package]] name = "lance-io" -version = "0.20.1" +version = "0.21.0" dependencies = [ "arrow", "arrow-arith", @@ -3407,7 +3407,7 @@ dependencies = [ [[package]] name = "lance-jni" -version = "0.20.1" +version = "0.21.0" dependencies = [ "arrow", "arrow-schema", @@ -3428,7 +3428,7 @@ dependencies = [ [[package]] name = "lance-linalg" -version = "0.20.1" +version = "0.21.0" dependencies = [ "approx", "arrow-arith", @@ -3457,7 +3457,7 @@ dependencies = [ [[package]] name = "lance-table" -version = "0.20.1" +version = "0.21.0" dependencies = [ "arrow", "arrow-array", @@ -3501,7 +3501,7 @@ dependencies = [ [[package]] name = "lance-test-macros" -version = "0.20.1" +version = "0.21.0" dependencies = [ "proc-macro2", "quote", @@ -3510,7 +3510,7 @@ dependencies = [ [[package]] name = "lance-testing" -version = "0.20.1" +version = "0.21.0" dependencies = [ "arrow-array", "arrow-schema", diff --git a/Cargo.toml b/Cargo.toml index 94405a5c925..84c183579c2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,7 +21,7 @@ exclude = ["python"] resolver = "2" [workspace.package] -version = "0.20.1" +version = "0.21.0" edition = "2021" authors = ["Lance Devs "] license = "Apache-2.0" @@ -44,21 +44,21 @@ categories = [ rust-version = "1.78" [workspace.dependencies] -lance = { version = "=0.20.1", path = "./rust/lance" } -lance-arrow = { version = "=0.20.1", path = "./rust/lance-arrow" } -lance-core = { version = "=0.20.1", path = "./rust/lance-core" } -lance-datafusion = { version = "=0.20.1", path = "./rust/lance-datafusion" } -lance-datagen = { version = "=0.20.1", path = "./rust/lance-datagen" } -lance-encoding = { version = "=0.20.1", path = "./rust/lance-encoding" } -lance-encoding-datafusion = { version = "=0.20.1", path = "./rust/lance-encoding-datafusion" } -lance-file = { version = "=0.20.1", path = "./rust/lance-file" } -lance-index = { version = "=0.20.1", path = "./rust/lance-index" } -lance-io = { version = "=0.20.1", path = "./rust/lance-io" } -lance-jni = { version = "=0.20.1", path = "./java/core/lance-jni" } -lance-linalg = { version = "=0.20.1", path = "./rust/lance-linalg" } -lance-table = { version = "=0.20.1", path = "./rust/lance-table" } -lance-test-macros = { version = "=0.20.1", path = "./rust/lance-test-macros" } -lance-testing = { version = "=0.20.1", path = "./rust/lance-testing" } +lance = { version = "=0.21.0", path = "./rust/lance" } +lance-arrow = { version = "=0.21.0", path = "./rust/lance-arrow" } +lance-core = { version = "=0.21.0", path = "./rust/lance-core" } +lance-datafusion = { version = "=0.21.0", path = "./rust/lance-datafusion" } +lance-datagen = { version = "=0.21.0", path = "./rust/lance-datagen" } +lance-encoding = { version = "=0.21.0", path = "./rust/lance-encoding" } +lance-encoding-datafusion = { version = "=0.21.0", path = "./rust/lance-encoding-datafusion" } +lance-file = { version = "=0.21.0", path = "./rust/lance-file" } +lance-index = { version = "=0.21.0", path = "./rust/lance-index" } +lance-io = { version = "=0.21.0", path = "./rust/lance-io" } +lance-jni = { version = "=0.21.0", path = "./java/core/lance-jni" } +lance-linalg = { version = "=0.21.0", path = "./rust/lance-linalg" } +lance-table = { version = "=0.21.0", path = "./rust/lance-table" } +lance-test-macros = { version = "=0.21.0", path = "./rust/lance-test-macros" } +lance-testing = { version = "=0.21.0", path = "./rust/lance-testing" } approx = "0.5.1" # Note that this one does not include pyarrow arrow = { version = "53.2", optional = false, features = ["prettyprint"] } @@ -111,7 +111,7 @@ datafusion-physical-expr = { version = "42.0", features = [ ] } deepsize = "0.2.0" either = "1.0" -fsst = { version = "=0.20.1", path = "./rust/lance-encoding/src/compression_algo/fsst" } +fsst = { version = "=0.21.0", path = "./rust/lance-encoding/src/compression_algo/fsst" } futures = "0.3" http = "1.1.0" hyperloglogplus = { version = "0.4.1", features = ["const-loop"] } diff --git a/python/Cargo.toml b/python/Cargo.toml index a56a87cba14..e9e9f867c4d 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "pylance" -version = "0.20.1" +version = "0.21.0" edition = "2021" authors = ["Lance Devs "] rust-version = "1.65" diff --git a/python/src/utils.rs b/python/src/utils.rs index 9b8420e781b..9f53c90772b 100644 --- a/python/src/utils.rs +++ b/python/src/utils.rs @@ -15,6 +15,7 @@ use std::sync::Arc; use arrow::compute::concat; +use arrow::datatypes::Float32Type; use arrow::pyarrow::{FromPyArrow, ToPyArrow}; use arrow_array::{cast::AsArray, Array, FixedSizeListArray, Float32Array, UInt32Array}; use arrow_data::ArrayData; @@ -26,7 +27,7 @@ use lance_file::writer::FileWriter; use lance_index::scalar::IndexWriter; use lance_index::vector::hnsw::{builder::HnswBuildParams, HNSW}; use lance_index::vector::v3::subindex::IvfSubIndex; -use lance_linalg::kmeans::compute_partitions; +use lance_linalg::kmeans::{compute_partitions, KMeansAlgoFloat}; use lance_linalg::{ distance::DistanceType, kmeans::{KMeans as LanceKMeans, KMeansParams}, @@ -132,14 +133,15 @@ impl KMeans { if !matches!(fixed_size_arr.value_type(), DataType::Float32) { return Err(PyValueError::new_err("Must be a FixedSizeList of Float32")); }; - let values: Arc = fixed_size_arr.values().as_primitive().clone().into(); - let centroids: &Float32Array = kmeans.centroids.as_primitive(); - let cluster_ids = UInt32Array::from(compute_partitions( - centroids.values(), - values.values(), - kmeans.dimension, - kmeans.distance_type, - )); + let values = fixed_size_arr.values().as_primitive(); + let centroids = kmeans.centroids.as_primitive(); + let cluster_ids = + UInt32Array::from(compute_partitions::< + Float32Type, + KMeansAlgoFloat, + >( + centroids, values, kmeans.dimension, kmeans.distance_type + )); cluster_ids.into_data().to_pyarrow(py) } diff --git a/rust/lance-index/benches/hnsw.rs b/rust/lance-index/benches/hnsw.rs index b51d75d7469..e250dfffd83 100644 --- a/rust/lance-index/benches/hnsw.rs +++ b/rust/lance-index/benches/hnsw.rs @@ -15,7 +15,7 @@ use lance_index::vector::v3::subindex::IvfSubIndex; use pprof::criterion::{Output, PProfProfiler}; use lance_index::vector::{ - flat::storage::FlatStorage, + flat::storage::FlatFloatStorage, hnsw::builder::{HnswBuildParams, HNSW}, }; use lance_linalg::distance::DistanceType; @@ -31,7 +31,7 @@ fn bench_hnsw(c: &mut Criterion) { let data = generate_random_array_with_seed::(TOTAL * DIMENSION, SEED); let fsl = FixedSizeListArray::try_new_from_values(data, DIMENSION as i32).unwrap(); - let vectors = Arc::new(FlatStorage::new(fsl.clone(), DistanceType::L2)); + let vectors = Arc::new(FlatFloatStorage::new(fsl.clone(), DistanceType::L2)); let query = fsl.value(0); c.bench_function( diff --git a/rust/lance-index/src/vector/flat/index.rs b/rust/lance-index/src/vector/flat/index.rs index f50e995e4cb..bc26fd5620f 100644 --- a/rust/lance-index/src/vector/flat/index.rs +++ b/rust/lance-index/src/vector/flat/index.rs @@ -28,7 +28,7 @@ use crate::{ }, }; -use super::storage::{FlatStorage, FLAT_COLUMN}; +use super::storage::{FlatBinStorage, FlatFloatStorage, FLAT_COLUMN}; /// A Flat index is any index that stores no metadata, and /// during query, it simply scans over the storage and returns the top k results @@ -166,7 +166,7 @@ impl FlatQuantizer { impl Quantization for FlatQuantizer { type BuildParams = (); type Metadata = FlatMetadata; - type Storage = FlatStorage; + type Storage = FlatFloatStorage; fn build(data: &dyn Array, distance_type: DistanceType, _: &Self::BuildParams) -> Result { let dim = data.as_fixed_size_list().value_length(); @@ -228,3 +228,81 @@ impl TryFrom for FlatQuantizer { } } } + +#[derive(Debug, Clone, DeepSizeOf)] +pub struct FlatBinQuantizer { + dim: usize, + distance_type: DistanceType, +} + +impl FlatBinQuantizer { + pub fn new(dim: usize, distance_type: DistanceType) -> Self { + Self { dim, distance_type } + } +} + +impl Quantization for FlatBinQuantizer { + type BuildParams = (); + type Metadata = FlatMetadata; + type Storage = FlatBinStorage; + + fn build(data: &dyn Array, distance_type: DistanceType, _: &Self::BuildParams) -> Result { + let dim = data.as_fixed_size_list().value_length(); + Ok(Self::new(dim as usize, distance_type)) + } + + fn code_dim(&self) -> usize { + self.dim + } + + fn column(&self) -> &'static str { + FLAT_COLUMN + } + + fn from_metadata(metadata: &Self::Metadata, distance_type: DistanceType) -> Result { + Ok(Quantizer::FlatBin(Self { + dim: metadata.dim, + distance_type, + })) + } + + fn metadata( + &self, + _: Option, + ) -> Result { + let metadata = FlatMetadata { dim: self.dim }; + Ok(serde_json::to_value(metadata)?) + } + + fn metadata_key() -> &'static str { + "flat" + } + + fn quantization_type() -> QuantizationType { + QuantizationType::Flat + } + + fn quantize(&self, vectors: &dyn Array) -> Result { + Ok(vectors.slice(0, vectors.len())) + } +} + +impl From for Quantizer { + fn from(value: FlatBinQuantizer) -> Self { + Self::FlatBin(value) + } +} + +impl TryFrom for FlatBinQuantizer { + type Error = Error; + + fn try_from(value: Quantizer) -> Result { + match value { + Quantizer::FlatBin(quantizer) => Ok(quantizer), + _ => Err(Error::invalid_input( + "quantizer is not FlatBinQuantizer", + location!(), + )), + } + } +} diff --git a/rust/lance-index/src/vector/flat/storage.rs b/rust/lance-index/src/vector/flat/storage.rs index b3bb11d02a0..9fece3b3f8d 100644 --- a/rust/lance-index/src/vector/flat/storage.rs +++ b/rust/lance-index/src/vector/flat/storage.rs @@ -10,6 +10,8 @@ use crate::vector::storage::{DistCalculator, VectorStore}; use crate::vector::utils::do_prefetch; use arrow::array::AsArray; use arrow::compute::concat_batches; +use arrow::datatypes::UInt8Type; +use arrow_array::ArrowPrimitiveType; use arrow_array::{ types::{Float32Type, UInt64Type}, Array, ArrayRef, FixedSizeListArray, RecordBatch, UInt64Array, @@ -18,6 +20,7 @@ use arrow_schema::{DataType, SchemaRef}; use deepsize::DeepSizeOf; use lance_core::{Error, Result, ROW_ID}; use lance_file::reader::FileReader; +use lance_linalg::distance::hamming::hamming; use lance_linalg::distance::DistanceType; use snafu::{location, Location}; @@ -27,7 +30,7 @@ pub const FLAT_COLUMN: &str = "flat"; /// All data are stored in memory #[derive(Debug, Clone)] -pub struct FlatStorage { +pub struct FlatFloatStorage { batch: RecordBatch, distance_type: DistanceType, @@ -36,14 +39,14 @@ pub struct FlatStorage { vectors: Arc, } -impl DeepSizeOf for FlatStorage { +impl DeepSizeOf for FlatFloatStorage { fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize { self.batch.get_array_memory_size() } } #[async_trait::async_trait] -impl QuantizerStorage for FlatStorage { +impl QuantizerStorage for FlatFloatStorage { type Metadata = FlatMetadata; async fn load_partition( _: &FileReader, @@ -55,7 +58,7 @@ impl QuantizerStorage for FlatStorage { } } -impl FlatStorage { +impl FlatFloatStorage { // deprecated, use `try_from_batch` instead pub fn new(vectors: FixedSizeListArray, distance_type: DistanceType) -> Self { let row_ids = Arc::new(UInt64Array::from_iter_values(0..vectors.len() as u64)); @@ -80,8 +83,8 @@ impl FlatStorage { } } -impl VectorStore for FlatStorage { - type DistanceCalculator<'a> = FlatDistanceCal<'a>; +impl VectorStore for FlatFloatStorage { + type DistanceCalculator<'a> = FlatDistanceCal<'a, Float32Type>; fn try_from_batch(batch: RecordBatch, distance_type: DistanceType) -> Result { let row_ids = Arc::new( @@ -149,11 +152,11 @@ impl VectorStore for FlatStorage { } fn dist_calculator(&self, query: ArrayRef) -> Self::DistanceCalculator<'_> { - FlatDistanceCal::new(self.vectors.as_ref(), query, self.distance_type) + Self::DistanceCalculator::new(self.vectors.as_ref(), query, self.distance_type) } fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> { - FlatDistanceCal::new( + Self::DistanceCalculator::new( self.vectors.as_ref(), self.vectors.value(id as usize), self.distance_type, @@ -176,14 +179,147 @@ impl VectorStore for FlatStorage { } } -pub struct FlatDistanceCal<'a> { - vectors: &'a [f32], - query: Vec, +/// All data are stored in memory +#[derive(Debug, Clone)] +pub struct FlatBinStorage { + batch: RecordBatch, + distance_type: DistanceType, + + // helper fields + pub(super) row_ids: Arc, + vectors: Arc, +} + +impl DeepSizeOf for FlatBinStorage { + fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize { + self.batch.get_array_memory_size() + } +} + +#[async_trait::async_trait] +impl QuantizerStorage for FlatBinStorage { + type Metadata = FlatMetadata; + async fn load_partition( + _: &FileReader, + _: std::ops::Range, + _: DistanceType, + _: &Self::Metadata, + ) -> Result { + unimplemented!("Flat will be used in new index builder which doesn't require this") + } +} + +impl FlatBinStorage { + pub fn vector(&self, id: u32) -> ArrayRef { + self.vectors.value(id as usize) + } +} + +impl VectorStore for FlatBinStorage { + type DistanceCalculator<'a> = FlatDistanceCal<'a, UInt8Type>; + + fn try_from_batch(batch: RecordBatch, distance_type: DistanceType) -> Result { + let row_ids = Arc::new( + batch + .column_by_name(ROW_ID) + .ok_or(Error::Schema { + message: format!("column {} not found", ROW_ID), + location: location!(), + })? + .as_primitive::() + .clone(), + ); + let vectors = Arc::new( + batch + .column_by_name(FLAT_COLUMN) + .ok_or(Error::Schema { + message: "column flat not found".to_string(), + location: location!(), + })? + .as_fixed_size_list() + .clone(), + ); + Ok(Self { + batch, + distance_type, + row_ids, + vectors, + }) + } + + fn to_batches(&self) -> Result> { + Ok([self.batch.clone()].into_iter()) + } + + fn append_batch(&self, batch: RecordBatch, _vector_column: &str) -> Result { + // TODO: use chunked storage + let new_batch = concat_batches(&batch.schema(), vec![&self.batch, &batch].into_iter())?; + let mut storage = self.clone(); + storage.batch = new_batch; + Ok(storage) + } + + fn schema(&self) -> &SchemaRef { + self.batch.schema_ref() + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn len(&self) -> usize { + self.vectors.len() + } + + fn distance_type(&self) -> DistanceType { + self.distance_type + } + + fn row_id(&self, id: u32) -> u64 { + self.row_ids.values()[id as usize] + } + + fn row_ids(&self) -> impl Iterator { + self.row_ids.values().iter() + } + + fn dist_calculator(&self, query: ArrayRef) -> Self::DistanceCalculator<'_> { + Self::DistanceCalculator::new(self.vectors.as_ref(), query, self.distance_type) + } + + fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> { + Self::DistanceCalculator::new( + self.vectors.as_ref(), + self.vectors.value(id as usize), + self.distance_type, + ) + } + + /// Distance between two vectors. + fn distance_between(&self, a: u32, b: u32) -> f32 { + match self.vectors.value_type() { + DataType::Float32 => { + let vector1 = self.vectors.value(a as usize); + let vector2 = self.vectors.value(b as usize); + self.distance_type.func()( + vector1.as_primitive::().values(), + vector2.as_primitive::().values(), + ) + } + _ => unimplemented!(), + } + } +} + +pub struct FlatDistanceCal<'a, T: ArrowPrimitiveType> { + vectors: &'a [T::Native], + query: Vec, dimension: usize, - distance_fn: fn(&[f32], &[f32]) -> f32, + #[allow(clippy::type_complexity)] + distance_fn: fn(&[T::Native], &[T::Native]) -> f32, } -impl<'a> FlatDistanceCal<'a> { +impl<'a> FlatDistanceCal<'a, Float32Type> { fn new(vectors: &'a FixedSizeListArray, query: ArrayRef, distance_type: DistanceType) -> Self { // Gained significant performance improvement by using strong typed primitive slice. // TODO: to support other data types other than `f32`, make FlatDistanceCal a generic struct. @@ -196,14 +332,31 @@ impl<'a> FlatDistanceCal<'a> { distance_fn: distance_type.func(), } } +} + +impl<'a> FlatDistanceCal<'a, UInt8Type> { + fn new(vectors: &'a FixedSizeListArray, query: ArrayRef, _distance_type: DistanceType) -> Self { + // Gained significant performance improvement by using strong typed primitive slice. + // TODO: to support other data types other than `f32`, make FlatDistanceCal a generic struct. + let flat_array = vectors.values().as_primitive::(); + let dimension = vectors.value_length() as usize; + Self { + vectors: flat_array.values(), + query: query.as_primitive::().values().to_vec(), + dimension, + distance_fn: hamming, + } + } +} +impl FlatDistanceCal<'_, T> { #[inline] - fn get_vector(&self, id: u32) -> &[f32] { + fn get_vector(&self, id: u32) -> &[T::Native] { &self.vectors[self.dimension * id as usize..self.dimension * (id + 1) as usize] } } -impl DistCalculator for FlatDistanceCal<'_> { +impl DistCalculator for FlatDistanceCal<'_, T> { #[inline] fn distance(&self, id: u32) -> f32 { let vector = self.get_vector(id); diff --git a/rust/lance-index/src/vector/hnsw/builder.rs b/rust/lance-index/src/vector/hnsw/builder.rs index a30bbf993c7..abdebed2d36 100644 --- a/rust/lance-index/src/vector/hnsw/builder.rs +++ b/rust/lance-index/src/vector/hnsw/builder.rs @@ -31,7 +31,7 @@ use serde::{Deserialize, Serialize}; use super::super::graph::beam_search; use super::{select_neighbors_heuristic, HnswMetadata, HNSW_TYPE, VECTOR_ID_COL, VECTOR_ID_FIELD}; use crate::prefilter::PreFilter; -use crate::vector::flat::storage::FlatStorage; +use crate::vector::flat::storage::FlatFloatStorage; use crate::vector::graph::builder::GraphBuilderNode; use crate::vector::graph::{greedy_search, Visited}; use crate::vector::graph::{ @@ -100,7 +100,7 @@ impl HnswBuildParams { /// - `data`: A FixedSizeList to build the HNSW. /// - `distance_type`: The distance type to use. pub async fn build(self, data: ArrayRef, distance_type: DistanceType) -> Result { - let vec_store = Arc::new(FlatStorage::new( + let vec_store = Arc::new(FlatFloatStorage::new( data.as_fixed_size_list().clone(), distance_type, )); @@ -819,7 +819,7 @@ mod tests { use crate::scalar::IndexWriter; use crate::vector::v3::subindex::IvfSubIndex; use crate::vector::{ - flat::storage::FlatStorage, + flat::storage::FlatFloatStorage, graph::{DISTS_FIELD, NEIGHBORS_FIELD}, hnsw::{builder::HnswBuildParams, HNSW, VECTOR_ID_FIELD}, }; @@ -831,7 +831,7 @@ mod tests { const NUM_EDGES: usize = 20; let data = generate_random_array(TOTAL * DIM); let fsl = FixedSizeListArray::try_new_from_values(data, DIM as i32).unwrap(); - let store = Arc::new(FlatStorage::new(fsl.clone(), DistanceType::L2)); + let store = Arc::new(FlatFloatStorage::new(fsl.clone(), DistanceType::L2)); let builder = HNSW::index_vectors( store.as_ref(), HnswBuildParams::default() diff --git a/rust/lance-index/src/vector/ivf.rs b/rust/lance-index/src/vector/ivf.rs index 55bfc641732..ab3a685718b 100644 --- a/rust/lance-index/src/vector/ivf.rs +++ b/rust/lance-index/src/vector/ivf.rs @@ -54,7 +54,7 @@ pub fn new_ivf_transformer_with_quantizer( range: Option>, ) -> Result { match quantizer { - Quantizer::Flat(_) => Ok(IvfTransformer::new_flat( + Quantizer::Flat(_) | Quantizer::FlatBin(_) => Ok(IvfTransformer::new_flat( centroids, metric_type, vector_column, diff --git a/rust/lance-index/src/vector/quantizer.rs b/rust/lance-index/src/vector/quantizer.rs index 1290a0f07b2..110e438df0a 100644 --- a/rust/lance-index/src/vector/quantizer.rs +++ b/rust/lance-index/src/vector/quantizer.rs @@ -19,7 +19,7 @@ use snafu::{location, Location}; use crate::{IndexMetadata, INDEX_METADATA_SCHEMA_KEY}; -use super::flat::index::FlatQuantizer; +use super::flat::index::{FlatBinQuantizer, FlatQuantizer}; use super::pq::ProductQuantizer; use super::{ivf::storage::IvfModel, sq::ScalarQuantizer, storage::VectorStore}; @@ -98,6 +98,7 @@ impl QuantizerBuildParams for () { #[derive(Debug, Clone, DeepSizeOf)] pub enum Quantizer { Flat(FlatQuantizer), + FlatBin(FlatBinQuantizer), Product(ProductQuantizer), Scalar(ScalarQuantizer), } @@ -106,6 +107,7 @@ impl Quantizer { pub fn code_dim(&self) -> usize { match self { Self::Flat(fq) => fq.code_dim(), + Self::FlatBin(fq) => fq.code_dim(), Self::Product(pq) => pq.code_dim(), Self::Scalar(sq) => sq.code_dim(), } @@ -114,6 +116,7 @@ impl Quantizer { pub fn column(&self) -> &'static str { match self { Self::Flat(fq) => fq.column(), + Self::FlatBin(fq) => fq.column(), Self::Product(pq) => pq.column(), Self::Scalar(sq) => sq.column(), } @@ -122,6 +125,7 @@ impl Quantizer { pub fn metadata_key(&self) -> &'static str { match self { Self::Flat(_) => FlatQuantizer::metadata_key(), + Self::FlatBin(_) => FlatBinQuantizer::metadata_key(), Self::Product(_) => ProductQuantizer::metadata_key(), Self::Scalar(_) => ScalarQuantizer::metadata_key(), } @@ -130,6 +134,7 @@ impl Quantizer { pub fn quantization_type(&self) -> QuantizationType { match self { Self::Flat(_) => QuantizationType::Flat, + Self::FlatBin(_) => QuantizationType::Flat, Self::Product(_) => QuantizationType::Product, Self::Scalar(_) => QuantizationType::Scalar, } @@ -138,6 +143,7 @@ impl Quantizer { pub fn metadata(&self, args: Option) -> Result { match self { Self::Flat(fq) => fq.metadata(args), + Self::FlatBin(fq) => fq.metadata(args), Self::Product(pq) => pq.metadata(args), Self::Scalar(sq) => sq.metadata(args), } diff --git a/rust/lance-index/src/vector/residual.rs b/rust/lance-index/src/vector/residual.rs index b094e43d114..90730529b41 100644 --- a/rust/lance-index/src/vector/residual.rs +++ b/rust/lance-index/src/vector/residual.rs @@ -1,19 +1,21 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors +use std::ops::{AddAssign, DivAssign}; use std::sync::Arc; +use arrow_array::ArrowNumericType; use arrow_array::{ cast::AsArray, - types::{ArrowPrimitiveType, Float16Type, Float32Type, Float64Type, UInt32Type}, + types::{Float16Type, Float32Type, Float64Type, UInt32Type}, Array, FixedSizeListArray, PrimitiveArray, RecordBatch, UInt32Array, }; use arrow_schema::DataType; use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt}; use lance_core::{Error, Result}; use lance_linalg::distance::{DistanceType, Dot, L2}; -use lance_linalg::kmeans::compute_partitions; -use num_traits::Float; +use lance_linalg::kmeans::{compute_partitions, KMeansAlgoFloat}; +use num_traits::{Float, FromPrimitive, Num}; use snafu::{location, Location}; use tracing::instrument; @@ -53,29 +55,31 @@ impl ResidualTransform { } } -fn do_compute_residual( +fn do_compute_residual( centroids: &FixedSizeListArray, vectors: &FixedSizeListArray, distance_type: Option, partitions: Option<&UInt32Array>, ) -> Result where - T::Native: Float + L2 + Dot, + T::Native: Num + Float + L2 + Dot + DivAssign + AddAssign + FromPrimitive, { let dimension = centroids.value_length() as usize; - let centroids_slice = centroids.values().as_primitive::().values(); - let vectors_slice = vectors.values().as_primitive::().values(); + let centroids = centroids.values().as_primitive::(); + let vectors = vectors.values().as_primitive::(); let part_ids = partitions.cloned().unwrap_or_else(|| { - compute_partitions( - centroids_slice, - vectors_slice, + compute_partitions::>( + centroids, + vectors, dimension, distance_type.expect("provide either partitions or distance type"), ) .into() }); + let vectors_slice = vectors.values(); + let centroids_slice = centroids.values(); let residuals = vectors_slice .chunks_exact(dimension) .enumerate() diff --git a/rust/lance-linalg/benches/compute_partition.rs b/rust/lance-linalg/benches/compute_partition.rs index 7b155a9aa5b..5cdda57158a 100644 --- a/rust/lance-linalg/benches/compute_partition.rs +++ b/rust/lance-linalg/benches/compute_partition.rs @@ -5,6 +5,7 @@ use std::sync::Arc; use arrow_array::types::Float32Type; use criterion::{criterion_group, criterion_main, Criterion}; +use lance_linalg::kmeans::KMeansAlgoFloat; use lance_linalg::{distance::MetricType, kmeans::compute_partitions}; use lance_testing::datagen::generate_random_array_with_seed; #[cfg(target_os = "linux")] @@ -24,9 +25,9 @@ fn bench_compute_partitions(c: &mut Criterion) { c.bench_function("compute_centroids(L2)", |b| { b.iter(|| { - compute_partitions( - centroids.values(), - input.values(), + compute_partitions::>( + centroids.as_ref(), + &input, DIMENSION, MetricType::L2, ) @@ -35,9 +36,9 @@ fn bench_compute_partitions(c: &mut Criterion) { c.bench_function("compute_centroids(Cosine)", |b| { b.iter(|| { - compute_partitions( - centroids.values(), - input.values(), + compute_partitions::>( + centroids.as_ref(), + &input, DIMENSION, MetricType::Cosine, ) diff --git a/rust/lance-linalg/src/distance/hamming.rs b/rust/lance-linalg/src/distance/hamming.rs index 0b94f867bc0..80e03088318 100644 --- a/rust/lance-linalg/src/distance/hamming.rs +++ b/rust/lance-linalg/src/distance/hamming.rs @@ -3,6 +3,14 @@ //! Hamming distance. +use std::sync::Arc; + +use crate::{Error, Result}; +use arrow_array::cast::AsArray; +use arrow_array::types::UInt8Type; +use arrow_array::{Array, Float32Array}; +use arrow_schema::DataType; + pub trait Hamming { /// Hamming distance between two vectors. fn hamming(x: &[u8], y: &[u8]) -> f32; @@ -44,6 +52,37 @@ pub fn hamming_scalar(x: &[u8], y: &[u8]) -> f32 { .sum::() as f32 } +pub fn hamming_distance_batch<'a>( + from: &'a [u8], + to: &'a [u8], + dimension: usize, +) -> Box + 'a> { + debug_assert_eq!(from.len(), dimension); + debug_assert_eq!(to.len() % dimension, 0); + Box::new(to.chunks_exact(dimension).map(|v| hamming(from, v))) +} + +pub fn hamming_distance_arrow_batch(from: &dyn Array, to: &dyn Array) -> Result> { + let dists = match *from.data_type() { + DataType::UInt8 => hamming_distance_batch( + from.as_primitive::().values(), + to.as_primitive::().values(), + from.len(), + ), + _ => { + return Err(Error::InvalidArgumentError(format!( + "Unsupported data type: {:?}", + from.data_type() + ))) + } + }; + + Ok(Arc::new(Float32Array::new( + dists.collect(), + to.nulls().cloned(), + ))) +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/lance-linalg/src/kmeans.rs b/rust/lance-linalg/src/kmeans.rs index 57c8f16839a..a318a92b6cc 100644 --- a/rust/lance-linalg/src/kmeans.rs +++ b/rust/lance-linalg/src/kmeans.rs @@ -28,7 +28,7 @@ use num_traits::{AsPrimitive, Float, FromPrimitive, Num, Zero}; use rand::prelude::*; use rayon::prelude::*; -use crate::distance::hamming::hamming; +use crate::distance::hamming::{hamming, hamming_distance_batch}; use crate::distance::{dot_distance_batch, DistanceType}; use crate::kernels::{argmax, argmin_value_float}; use crate::{ @@ -170,7 +170,7 @@ fn hist_stddev(k: usize, membership: &[Option]) -> f32 { .sqrt() } -trait KMeansAlgo { +pub trait KMeansAlgo { /// Recompute the membership of each vector. /// /// Parameters: @@ -194,7 +194,7 @@ trait KMeansAlgo { ) -> KMeans; } -struct KMeansAlgoFloat +pub struct KMeansAlgoFloat where T::Native: Float + Num, { @@ -596,6 +596,12 @@ pub fn kmeans_find_partitions_arrow_array( nprobes, distance_type, )?), + (DataType::UInt8, DataType::UInt8) => kmeans_find_partitions_binary( + centroids.values().as_primitive::().values(), + query.as_primitive::().values(), + nprobes, + distance_type, + ), _ => Err(ArrowError::InvalidArgumentError(format!( "Centroids and vectors have different types: {} != {}", centroids.value_type(), @@ -637,6 +643,27 @@ pub fn kmeans_find_partitions( sort_to_indices(&dists_arr, None, Some(nprobes)) } +pub fn kmeans_find_partitions_binary( + centroids: &[u8], + query: &[u8], + nprobes: usize, + distance_type: DistanceType, +) -> Result { + let dists: Vec = match distance_type { + DistanceType::Hamming => hamming_distance_batch(query, centroids, query.len()).collect(), + _ => { + panic!( + "KMeans::find_partitions: {} is not supported", + distance_type + ); + } + }; + + // TODO: use heap to just keep nprobes smallest values. + let dists_arr = Float32Array::from(dists); + sort_to_indices(&dists_arr, None, Some(nprobes)) +} + /// Compute partitions from Arrow FixedSizeListArray. pub fn compute_partitions_arrow_array( centroids: &FixedSizeListArray, @@ -649,21 +676,36 @@ pub fn compute_partitions_arrow_array( )); } match (centroids.value_type(), vectors.value_type()) { - (DataType::Float16, DataType::Float16) => Ok(compute_partitions( - centroids.values().as_primitive::().values(), - vectors.values().as_primitive::().values(), + (DataType::Float16, DataType::Float16) => Ok(compute_partitions::< + Float16Type, + KMeansAlgoFloat, + >( + centroids.values().as_primitive(), + vectors.values().as_primitive(), centroids.value_length(), distance_type, )), - (DataType::Float32, DataType::Float32) => Ok(compute_partitions( - centroids.values().as_primitive::().values(), - vectors.values().as_primitive::().values(), + (DataType::Float32, DataType::Float32) => Ok(compute_partitions::< + Float32Type, + KMeansAlgoFloat, + >( + centroids.values().as_primitive(), + vectors.values().as_primitive(), centroids.value_length(), distance_type, )), - (DataType::Float64, DataType::Float64) => Ok(compute_partitions( - centroids.values().as_primitive::().values(), - vectors.values().as_primitive::().values(), + (DataType::Float64, DataType::Float64) => Ok(compute_partitions::< + Float64Type, + KMeansAlgoFloat, + >( + centroids.values().as_primitive(), + vectors.values().as_primitive(), + centroids.value_length(), + distance_type, + )), + (DataType::UInt8, DataType::UInt8) => Ok(compute_partitions::( + centroids.values().as_primitive(), + vectors.values().as_primitive(), centroids.value_length(), distance_type, )), @@ -676,17 +718,23 @@ pub fn compute_partitions_arrow_array( /// Compute partition ID of each vector in the KMeans. /// /// If returns `None`, means the vector is not valid, i.e., all `NaN`. -pub fn compute_partitions( - centroids: &[T], - vectors: &[T], +pub fn compute_partitions>( + centroids: &PrimitiveArray, + vectors: &PrimitiveArray, dimension: impl AsPrimitive, distance_type: DistanceType, -) -> Vec> { +) -> Vec> +where + T::Native: Num, +{ let dimension = dimension.as_(); - vectors - .par_chunks(dimension) - .map(|vec| compute_partition(centroids, vec, distance_type)) - .collect::>() + let (membership, _) = K::compute_membership_and_loss( + centroids.values(), + vectors.values(), + dimension, + distance_type, + ); + membership } #[inline] @@ -752,7 +800,12 @@ mod tests { ) }) .collect::>(); - let actual = compute_partitions(centroids.values(), data.values(), DIM, DistanceType::L2); + let actual = compute_partitions::>( + ¢roids, + &data, + DIM, + DistanceType::L2, + ); assert_eq!(expected, actual); } @@ -782,11 +835,16 @@ mod tests { let centroids = generate_random_array(DIM * NUM_CENTROIDS); let values = Float32Array::from_iter_values(repeat(f32::NAN).take(DIM * K)); - compute_partitions::(centroids.values(), values.values(), DIM, DistanceType::L2) - .iter() - .for_each(|cd| { - assert!(cd.is_none()); - }); + compute_partitions::>( + ¢roids, + &values, + DIM, + DistanceType::L2, + ) + .iter() + .for_each(|cd| { + assert!(cd.is_none()); + }); } #[tokio::test] diff --git a/rust/lance/examples/hnsw.rs b/rust/lance/examples/hnsw.rs index 414038167fa..9c8b9d558ae 100644 --- a/rust/lance/examples/hnsw.rs +++ b/rust/lance/examples/hnsw.rs @@ -16,7 +16,7 @@ use futures::StreamExt; use lance::Dataset; use lance_index::vector::v3::subindex::IvfSubIndex; use lance_index::vector::{ - flat::storage::FlatStorage, + flat::storage::FlatFloatStorage, hnsw::{builder::HnswBuildParams, HNSW}, }; use lance_linalg::distance::DistanceType; @@ -79,7 +79,7 @@ async fn main() { let fsl = concat(&arrs).unwrap().as_fixed_size_list().clone(); println!("Loaded {:?} batches", fsl.len()); - let vector_store = Arc::new(FlatStorage::new(fsl.clone(), DistanceType::L2)); + let vector_store = Arc::new(FlatFloatStorage::new(fsl.clone(), DistanceType::L2)); let q = fsl.value(0); let k = 10; diff --git a/rust/lance/src/dataset/optimize.rs b/rust/lance/src/dataset/optimize.rs index 034168c26c0..a1e8b82ea2d 100644 --- a/rust/lance/src/dataset/optimize.rs +++ b/rust/lance/src/dataset/optimize.rs @@ -1673,7 +1673,7 @@ mod tests { let mut scanner = dataset.scan(); scanner - .nearest("vec", &vec![0.0; 128].into(), 10) + .nearest("vec", &vec![0.0f32; 128].into(), 10) .unwrap() .project(&["i"]) .unwrap(); diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 9c7de8ba6f2..b813c633f03 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -7,7 +7,9 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use arrow_array::{Array, Float32Array, Int64Array, RecordBatch}; +use arrow_array::{ + Array, ArrowPrimitiveType, Float32Array, Int64Array, PrimitiveArray, RecordBatch, +}; use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema, SchemaRef, SortOptions}; use arrow_select::concat::concat_batches; use async_recursion::async_recursion; @@ -631,7 +633,12 @@ impl Scanner { } /// Find k-nearest neighbor within the vector column. - pub fn nearest(&mut self, column: &str, q: &Float32Array, k: usize) -> Result<&mut Self> { + pub fn nearest( + &mut self, + column: &str, + q: &PrimitiveArray, + k: usize, + ) -> Result<&mut Self> { if !self.prefilter { // We can allow fragment scan if the input to nearest is a prefilter. // The fragment scan will be performed by the prefilter. @@ -661,14 +668,20 @@ impl Scanner { ))?; let key = match field.data_type() { DataType::FixedSizeList(dt, _) => { - if dt.data_type().is_floating() { - coerce_float_vector(q, FloatType::try_from(dt.data_type())?)? + if dt.data_type() == q.data_type() { + Box::new(q.clone()) + } else if dt.data_type().is_floating() { + coerce_float_vector( + q.as_any().downcast_ref::().unwrap(), + FloatType::try_from(dt.data_type())?, + )? } else { return Err(Error::invalid_input( format!( - "Column {} is not a vector column (type: {})", + "Column {} has element type {} and the query vector is {}", column, - field.data_type() + dt.data_type(), + q.data_type(), ), location!(), )); @@ -1574,7 +1587,9 @@ impl Scanner { let schema = self.dataset.schema(); if let Some(field) = schema.field(&q.column) { match field.data_type() { - DataType::FixedSizeList(subfield, _) if subfield.data_type().is_floating() => {} + DataType::FixedSizeList(subfield, _) + if subfield.data_type().is_floating() + || *subfield.data_type() == DataType::UInt8 => {} _ => { return Err(Error::invalid_input( format!( diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index ace9906d5cc..8dc5d966228 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -21,7 +21,7 @@ use lance_index::scalar::expression::{ }; use lance_index::scalar::lance_format::LanceIndexStore; use lance_index::scalar::{InvertedIndexParams, ScalarIndex, ScalarIndexType}; -use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer}; +use lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantizer}; use lance_index::vector::hnsw::HNSW; use lance_index::vector::pq::ProductQuantizer; use lance_index::vector::sq::ScalarQuantizer; @@ -270,8 +270,15 @@ impl DatasetIndexExt for Dataset { location: location!(), })?; - build_vector_index(self, column, &index_name, &index_id.to_string(), vec_params) - .await?; + // this is a large future so move it to heap + Box::pin(build_vector_index( + self, + column, + &index_name, + &index_id.to_string(), + vec_params, + )) + .await?; vector_index_details() } // Can't use if let Some(...) here because it's not stable yet. @@ -757,6 +764,16 @@ impl DatasetIndexInternalExt for Dataset { .await?; Ok(Arc::new(ivf) as Arc) } + DataType::UInt8 => { + let ivf = IVFIndex::::try_new( + self.object_store.clone(), + self.indices_dir(), + uuid.to_owned(), + Arc::downgrade(&self.session), + ) + .await?; + Ok(Arc::new(ivf) as Arc) + } _ => Err(Error::Index { message: format!( "the field type {} is not supported for FLAT index", diff --git a/rust/lance/src/index/append.rs b/rust/lance/src/index/append.rs index 3c6a377dd5a..eb64030e103 100644 --- a/rust/lance/src/index/append.rs +++ b/rust/lance/src/index/append.rs @@ -152,6 +152,7 @@ pub async fn merge_indices<'a>( mod tests { use super::*; + use arrow::datatypes::Float32Type; use arrow_array::cast::AsArray; use arrow_array::types::UInt32Type; use arrow_array::{FixedSizeListArray, RecordBatch, RecordBatchIterator, UInt32Array}; @@ -225,7 +226,9 @@ mod tests { let q = array.value(5); let mut scanner = dataset.scan(); - scanner.nearest("vector", q.as_primitive(), 10).unwrap(); + scanner + .nearest("vector", q.as_primitive::(), 10) + .unwrap(); let results = scanner .try_into_stream() .await @@ -257,7 +260,9 @@ mod tests { assert_eq!(index_dirs.len(), 2); let mut scanner = dataset.scan(); - scanner.nearest("vector", q.as_primitive(), 10).unwrap(); + scanner + .nearest("vector", q.as_primitive::(), 10) + .unwrap(); let results = scanner .try_into_stream() .await @@ -385,7 +390,7 @@ mod tests { .scan() .project(&["id"]) .unwrap() - .nearest("vector", array.value(0).as_primitive(), 2) + .nearest("vector", array.value(0).as_primitive::(), 2) .unwrap() .refine(1) .try_into_batch() diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index bd05fcc6436..122889807e6 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -15,9 +15,10 @@ mod utils; #[cfg(test)] mod fixture_test; +use arrow_schema::DataType; use builder::IvfIndexBuilder; use lance_file::reader::FileReader; -use lance_index::vector::flat::index::{FlatIndex, FlatQuantizer}; +use lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantizer}; use lance_index::vector::hnsw::HNSW; use lance_index::vector::ivf::storage::IvfModel; use lance_index::vector::pq::ProductQuantizer; @@ -252,18 +253,61 @@ pub(crate) async fn build_vector_index( let temp_dir_path = Path::from_filesystem_path(temp_dir.path())?; let shuffler = IvfShuffler::new(temp_dir_path, ivf_params.num_partitions); if is_ivf_flat(stages) { - IvfIndexBuilder::::new( - dataset.clone(), - column.to_owned(), - dataset.indices_dir().child(uuid), - params.metric_type, - Box::new(shuffler), - Some(ivf_params.clone()), - Some(()), - (), - )? - .build() - .await?; + let data_type = dataset + .schema() + .field(column) + .ok_or(Error::Schema { + message: format!("Column {} not found in schema", column), + location: location!(), + })? + .data_type(); + match data_type { + DataType::FixedSizeList(f, _) => match f.data_type() { + DataType::Float16 | DataType::Float32 | DataType::Float64 => { + IvfIndexBuilder::::new( + dataset.clone(), + column.to_owned(), + dataset.indices_dir().child(uuid), + params.metric_type, + Box::new(shuffler), + Some(ivf_params.clone()), + Some(()), + (), + )? + .build() + .await?; + } + DataType::UInt8 => { + IvfIndexBuilder::::new( + dataset.clone(), + column.to_owned(), + dataset.indices_dir().child(uuid), + params.metric_type, + Box::new(shuffler), + Some(ivf_params.clone()), + Some(()), + (), + )? + .build() + .await?; + } + _ => { + return Err(Error::Index { + message: format!( + "Build Vector Index: invalid data type: {:?}", + f.data_type() + ), + location: location!(), + }); + } + }, + _ => { + return Err(Error::Index { + message: format!("Build Vector Index: invalid data type: {:?}", data_type), + location: location!(), + }); + } + } } else if is_ivf_pq(stages) { let len = stages.len(); let StageParams::PQ(pq_params) = &stages[len - 1] else { diff --git a/rust/lance/src/index/vector/builder.rs b/rust/lance/src/index/vector/builder.rs index c4c22265c4a..c79fcf45b45 100644 --- a/rust/lance/src/index/vector/builder.rs +++ b/rust/lance/src/index/vector/builder.rs @@ -14,7 +14,7 @@ use lance_core::{Error, Result, ROW_ID_FIELD}; use lance_encoding::decoder::{DecoderPlugins, FilterExpression}; use lance_file::v2::reader::FileReaderOptions; use lance_file::v2::{reader::FileReader, writer::FileWriter}; -use lance_index::vector::flat::storage::FlatStorage; +use lance_index::vector::flat::storage::FlatFloatStorage; use lance_index::vector::ivf::storage::IvfModel; use lance_index::vector::quantizer::{ QuantizationMetadata, QuantizationType, QuantizerBuildParams, @@ -434,7 +434,7 @@ impl IvfIndexBuilde // build the sub index, with in-memory storage let index_len = { let vectors = batch[&self.column].as_fixed_size_list(); - let flat_storage = FlatStorage::new(vectors.clone(), self.distance_type); + let flat_storage = FlatFloatStorage::new(vectors.clone(), self.distance_type); let sub_index = S::index_vectors(&flat_storage, self.sub_index_params.clone())?; let path = self.temp_dir.child(format!("index_part{}", part_id)); let writer = object_store.create(&path).await?; diff --git a/rust/lance/src/index/vector/ivf.rs b/rust/lance/src/index/vector/ivf.rs index 25dfee8b364..8b7fd6b62ac 100644 --- a/rust/lance/src/index/vector/ivf.rs +++ b/rust/lance/src/index/vector/ivf.rs @@ -9,6 +9,7 @@ use std::{ sync::{Arc, Weak}, }; +use arrow::datatypes::UInt8Type; use arrow_arith::numeric::sub; use arrow_array::{ cast::{as_struct_array, AsArray}, @@ -638,6 +639,7 @@ async fn optimize_ivf_hnsw_indices( // Write the metadata of quantizer let quantization_metadata = match &quantizer { Quantizer::Flat(_) => None, + Quantizer::FlatBin(_) => None, Quantizer::Product(pq) => { let codebook_tensor = pb::Tensor::try_from(&pq.codebook)?; let codebook_pos = aux_writer.tell().await?; @@ -1604,6 +1606,7 @@ async fn write_ivf_hnsw_file( // For PQ, we need to store the codebook let quantization_metadata = match &quantizer { Quantizer::Flat(_) => None, + Quantizer::FlatBin(_) => None, Quantizer::Product(pq) => { let codebook_tensor = pb::Tensor::try_from(&pq.codebook)?; let codebook_pos = aux_writer.tell().await?; @@ -1731,6 +1734,15 @@ async fn train_ivf_model( ) .await } + (DataType::UInt8, DistanceType::Hamming) => { + do_train_ivf_model::( + values.as_primitive::().values(), + dim, + distance_type, + params, + ) + .await + } _ => Err(Error::Index { message: "Unsupported data type".to_string(), location: location!(), @@ -2750,7 +2762,7 @@ mod tests { true, )])); - let arr = generate_random_array_with_range(1000 * DIM, 1000.0..1001.0); + let arr = generate_random_array_with_range::(1000 * DIM, 1000.0..1001.0); let fsl = FixedSizeListArray::try_new_from_values(arr.clone(), DIM as i32).unwrap(); let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(fsl)]).unwrap(); let batches = RecordBatchIterator::new(vec![batch].into_iter().map(Ok), schema.clone()); diff --git a/rust/lance/src/index/vector/ivf/io.rs b/rust/lance/src/index/vector/ivf/io.rs index 8290f88ab26..3fe89b74a82 100644 --- a/rust/lance/src/index/vector/ivf/io.rs +++ b/rust/lance/src/index/vector/ivf/io.rs @@ -320,6 +320,7 @@ pub(super) async fn write_hnsw_quantization_index_partitions( let code_column = match &quantizer { Quantizer::Flat(_) => None, + Quantizer::FlatBin(_) => None, Quantizer::Product(pq) => Some(pq.column()), Quantizer::Scalar(_) => None, }; diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index f518d41bfc6..a20282842cf 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -518,9 +518,11 @@ mod tests { use std::collections::HashSet; use std::{collections::HashMap, ops::Range, sync::Arc}; - use arrow::datatypes::UInt64Type; + use arrow::datatypes::{UInt64Type, UInt8Type}; use arrow::{array::AsArray, datatypes::Float32Type}; - use arrow_array::{Array, FixedSizeListArray, RecordBatch, RecordBatchIterator}; + use arrow_array::{ + Array, ArrowPrimitiveType, FixedSizeListArray, RecordBatch, RecordBatchIterator, + }; use arrow_schema::{DataType, Field, Schema}; use lance_arrow::FixedSizeListArrayExt; @@ -531,9 +533,10 @@ mod tests { use lance_index::vector::sq::builder::SQBuildParams; use lance_index::vector::DIST_COL; use lance_index::{DatasetIndexExt, IndexType}; + use lance_linalg::distance::hamming::hamming; use lance_linalg::distance::DistanceType; - use lance_linalg::kernels::normalize_arrow; use lance_testing::datagen::generate_random_array_with_range; + use rand::distributions::uniform::SampleUniform; use rstest::rstest; use tempfile::tempdir; @@ -541,28 +544,32 @@ mod tests { const DIM: usize = 32; - async fn generate_test_dataset( + async fn generate_test_dataset( test_uri: &str, - range: Range, - ) -> (Dataset, Arc) { - let vectors = generate_random_array_with_range::(1000 * DIM, range); - let vectors = normalize_arrow(&vectors).unwrap(); + range: Range, + ) -> (Dataset, Arc) + where + T::Native: SampleUniform, + { + let vectors = generate_random_array_with_range::(1000 * DIM, range); let metadata: HashMap = vec![("test".to_string(), "ivf_pq".to_string())] .into_iter() .collect(); - + let data_type = vectors.data_type().clone(); let schema: Arc<_> = Schema::new(vec![Field::new( "vector", DataType::FixedSizeList( - Arc::new(Field::new("item", DataType::Float32, true)), + Arc::new(Field::new("item", data_type.clone(), true)), DIM as i32, ), true, )]) .with_metadata(metadata) .into(); - let fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap(); - let fsl = lance_linalg::kernels::normalize_fsl(&fsl).unwrap(); + let mut fsl = FixedSizeListArray::try_new_from_values(vectors, DIM as i32).unwrap(); + if data_type != DataType::UInt8 { + fsl = lance_linalg::kernels::normalize_fsl(&fsl).unwrap(); + } let array = Arc::new(fsl); let batch = RecordBatch::try_new(schema.clone(), vec![array.clone()]).unwrap(); @@ -574,16 +581,22 @@ mod tests { #[allow(dead_code)] fn ground_truth( vectors: &FixedSizeListArray, - query: &[f32], + query: &dyn Array, k: usize, distance_type: DistanceType, ) -> Vec<(f32, u64)> { let mut dists = vec![]; for i in 0..vectors.len() { - let dist = distance_type.func()( - query, - vectors.value(i).as_primitive::().values(), - ); + let dist = match distance_type { + DistanceType::Hamming => hamming( + query.as_primitive::().values(), + vectors.value(i).as_primitive::().values(), + ), + _ => distance_type.func()( + query.as_primitive::().values(), + vectors.value(i).as_primitive::().values(), + ), + }; dists.push((dist, i as u64)); } dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap()); @@ -592,12 +605,31 @@ mod tests { } async fn test_index(params: VectorIndexParams, nlist: usize, recall_requirement: f32) { + match params.metric_type { + DistanceType::Hamming => { + test_index_impl::(params, nlist, recall_requirement, 0..2).await; + } + _ => { + test_index_impl::(params, nlist, recall_requirement, 0.0..1.0).await; + } + } + } + + async fn test_index_impl( + params: VectorIndexParams, + nlist: usize, + recall_requirement: f32, + range: Range, + ) where + T::Native: SampleUniform, + { let test_dir = tempdir().unwrap(); let test_uri = test_dir.path().to_str().unwrap(); - let (mut dataset, vectors) = generate_test_dataset(test_uri, 0.0..1.0).await; + let (mut dataset, vectors) = generate_test_dataset::(test_uri, range).await; + let vector_column = "vector"; dataset - .create_index(&["vector"], IndexType::Vector, None, ¶ms, true) + .create_index(&[vector_column], IndexType::Vector, None, ¶ms, true) .await .unwrap(); @@ -605,7 +637,7 @@ mod tests { let k = 100; let result = dataset .scan() - .nearest("vector", query.as_primitive::(), k) + .nearest(vector_column, query.as_primitive::(), k) .unwrap() .nprobs(nlist) .with_row_id() @@ -627,12 +659,7 @@ mod tests { .collect::>(); let row_ids = results.iter().map(|(_, id)| *id).collect::>(); - let gt = ground_truth( - &vectors, - query.as_primitive::().values(), - k, - params.metric_type, - ); + let gt = ground_truth(&vectors, query.as_ref(), k, params.metric_type); let gt_set = gt.iter().map(|r| r.1).collect::>(); let recall = row_ids.intersection(>_set).count() as f32 / k as f32; @@ -649,6 +676,7 @@ mod tests { #[case(4, DistanceType::L2, 1.0)] #[case(4, DistanceType::Cosine, 1.0)] #[case(4, DistanceType::Dot, 1.0)] + #[case(4, DistanceType::Hamming, 0.9)] #[tokio::test] async fn test_build_ivf_flat( #[case] nlist: usize, @@ -783,7 +811,7 @@ mod tests { let test_uri = test_dir.path().to_str().unwrap(); let nlist = 4; - let (mut dataset, _) = generate_test_dataset(test_uri, 0.0..1.0).await; + let (mut dataset, _) = generate_test_dataset::(test_uri, 0.0..1.0).await; let ivf_params = IvfBuildParams::new(nlist); let sq_params = SQBuildParams::default(); @@ -826,7 +854,7 @@ mod tests { let test_uri = test_dir.path().to_str().unwrap(); let nlist = 1000; - let (mut dataset, _) = generate_test_dataset(test_uri, 0.0..1.0).await; + let (mut dataset, _) = generate_test_dataset::(test_uri, 0.0..1.0).await; let ivf_params = IvfBuildParams::new(nlist); let sq_params = SQBuildParams::default(); diff --git a/rust/lance/src/index/vector/utils.rs b/rust/lance/src/index/vector/utils.rs index 661877ed539..b3c5f5b44c6 100644 --- a/rust/lance/src/index/vector/utils.rs +++ b/rust/lance/src/index/vector/utils.rs @@ -4,9 +4,6 @@ use std::sync::Arc; use arrow_array::{cast::AsArray, FixedSizeListArray}; -use arrow_schema::Schema as ArrowSchema; -use arrow_select::concat::concat_batches; -use futures::stream::TryStreamExt; use snafu::{location, Location}; use tokio::sync::Mutex; @@ -43,18 +40,13 @@ pub async fn maybe_sample_training_data( sample_size_hint: usize, ) -> Result { let num_rows = dataset.count_rows(None).await?; - let projection = dataset.schema().project(&[column])?; let batch = if num_rows > sample_size_hint { + let projection = dataset.schema().project(&[column])?; dataset.sample(sample_size_hint, &projection).await? } else { let mut scanner = dataset.scan(); scanner.project(&[column])?; - let batches = scanner - .try_into_stream() - .await? - .try_collect::>() - .await?; - concat_batches(&Arc::new(ArrowSchema::from(&projection)), &batches)? + scanner.try_into_batch().await? }; let array = batch.column_by_name(column).ok_or(Error::Index { diff --git a/rust/lance/src/io/exec/knn.rs b/rust/lance/src/io/exec/knn.rs index a09aa2f1331..96a017c706b 100644 --- a/rust/lance/src/io/exec/knn.rs +++ b/rust/lance/src/io/exec/knn.rs @@ -737,7 +737,7 @@ mod tests { let dataset = Dataset::open(test_uri).await.unwrap(); let stream = dataset .scan() - .nearest("vector", q.as_primitive(), 10) + .nearest("vector", q.as_primitive::(), 10) .unwrap() .try_into_stream() .await