diff --git a/rust/lance-index/benches/hnsw.rs b/rust/lance-index/benches/hnsw.rs index 967b2e67b67..5339074eb37 100644 --- a/rust/lance-index/benches/hnsw.rs +++ b/rust/lance-index/benches/hnsw.rs @@ -7,16 +7,21 @@ use std::{collections::HashSet, sync::Arc, time::Duration}; -use arrow_array::{types::Float32Type, FixedSizeListArray}; +use arrow_array::{types::Float32Type, FixedSizeListArray, RecordBatch, UInt64Array}; +use arrow_schema::{DataType, Field, Schema}; use criterion::{criterion_group, criterion_main, Criterion}; use lance_arrow::FixedSizeListArrayExt; use lance_index::vector::v3::subindex::IvfSubIndex; #[cfg(target_os = "linux")] use pprof::criterion::{Output, PProfProfiler}; +use lance_core::ROW_ID_FIELD; use lance_index::vector::{ flat::storage::FlatFloatStorage, hnsw::builder::{HnswBuildParams, HnswQueryParams, HNSW}, + quantizer::Quantization, + sq::{builder::SQBuildParams, ScalarQuantizer}, + storage::StorageBuilder, }; use lance_linalg::distance::DistanceType; use lance_testing::datagen::generate_random_array_with_seed; @@ -85,6 +90,96 @@ fn bench_hnsw(c: &mut Criterion) { }); } +fn bench_hnsw_sq(c: &mut Criterion) { + const DIMENSION: usize = 128; + const TOTAL: usize = 100_000; + const SEED: [u8; 32] = [42; 32]; + const K: usize = 100; + + let rt = tokio::runtime::Runtime::new().unwrap(); + + let data = generate_random_array_with_seed::(TOTAL * DIMENSION, SEED); + let fsl = FixedSizeListArray::try_new_from_values(data, DIMENSION as i32).unwrap(); + let quantizer = + ::build(&fsl, DistanceType::L2, &SQBuildParams::default()) + .unwrap(); + + let schema = Arc::new(Schema::new(vec![ + Field::new( + "vector", + DataType::FixedSizeList( + Field::new_list_field(DataType::Float32, true).into(), + DIMENSION as i32, + ), + true, + ), + ROW_ID_FIELD.clone(), + ])); + let row_ids = UInt64Array::from_iter_values((0..TOTAL).map(|v| v as u64)); + let batch = + RecordBatch::try_new(schema, vec![Arc::new(fsl.clone()), Arc::new(row_ids)]).unwrap(); + let sq_storage = StorageBuilder::new("vector".to_owned(), DistanceType::L2, quantizer, None) + .unwrap() + .build(vec![batch]) + .unwrap(); + let vectors = Arc::new(sq_storage); + + let query = fsl.value(0); + c.bench_function( + format!("create_hnsw_sq({TOTAL}x{DIMENSION})").as_str(), + |b| { + b.to_async(&rt).iter(|| async { + let hnsw = + HNSW::index_vectors(vectors.as_ref(), HnswBuildParams::default()).unwrap(); + let uids: HashSet = hnsw + .search_basic( + query.clone(), + K, + &HnswQueryParams { + ef: 300, + lower_bound: None, + upper_bound: None, + dist_q_c: 0.0, + }, + None, + vectors.as_ref(), + ) + .unwrap() + .iter() + .map(|node| node.id) + .collect(); + + assert_eq!(uids.len(), K); + }) + }, + ); + + let hnsw = HNSW::index_vectors(vectors.as_ref(), HnswBuildParams::default()).unwrap(); + c.bench_function(format!("search_hnsw_sq{TOTAL}x{DIMENSION}").as_str(), |b| { + b.to_async(&rt).iter(|| async { + let uids: HashSet = hnsw + .search_basic( + query.clone(), + K, + &HnswQueryParams { + ef: 300, + lower_bound: None, + upper_bound: None, + dist_q_c: 0.0, + }, + None, + vectors.as_ref(), + ) + .unwrap() + .iter() + .map(|node| node.id) + .collect(); + + assert_eq!(uids.len(), K); + }) + }); +} + #[cfg(target_os = "linux")] criterion_group!( name=benches; @@ -92,7 +187,7 @@ criterion_group!( .measurement_time(Duration::from_secs(10)) .sample_size(10) .with_profiler(PProfProfiler::new(100, Output::Flamegraph(None))); - targets = bench_hnsw); + targets = bench_hnsw, bench_hnsw_sq); // Non-linux version does not support pprof. #[cfg(not(target_os = "linux"))] @@ -101,6 +196,6 @@ criterion_group!( config = Criterion::default() .measurement_time(Duration::from_secs(10)) .sample_size(10); - targets = bench_hnsw); + targets = bench_hnsw, bench_hnsw_sq); criterion_main!(benches); diff --git a/rust/lance-index/src/vector/sq.rs b/rust/lance-index/src/vector/sq.rs index 6ac382bb347..520ed3fc212 100644 --- a/rust/lance-index/src/vector/sq.rs +++ b/rust/lance-index/src/vector/sq.rs @@ -276,15 +276,6 @@ pub(crate) fn scale_to_u8(values: &[T::Native], bounds: &Rang .collect_vec() } -pub(crate) fn inverse_scalar_dist( - values: impl Iterator, - bounds: &Range, -) -> Vec { - let range = (bounds.end - bounds.start) as f32; - values - .map(|v| v * range.powi(2) / 255.0.powi(2)) - .collect_vec() -} #[cfg(test)] mod tests { use arrow::datatypes::{Float16Type, Float32Type, Float64Type}; diff --git a/rust/lance-index/src/vector/sq/storage.rs b/rust/lance-index/src/vector/sq/storage.rs index c3ef4c96345..13c916aa657 100644 --- a/rust/lance-index/src/vector/sq/storage.rs +++ b/rust/lance-index/src/vector/sq/storage.rs @@ -23,7 +23,7 @@ use serde::{Deserialize, Serialize}; use snafu::location; use std::sync::Arc; -use super::{inverse_scalar_dist, scale_to_u8, ScalarQuantizer}; +use super::{scale_to_u8, ScalarQuantizer}; use crate::frag_reuse::FragReuseIndex; use crate::{ vector::{ @@ -387,17 +387,24 @@ impl VectorStore for ScalarQuantizationStorage { fn dist_calculator_from_id(&self, id: u32) -> Self::DistanceCalculator<'_> { let (offset, chunk) = self.chunk(id); let query_sq_code = chunk.sq_code_slice(id - offset).to_vec(); + let bounds = self.quantizer.bounds(); SQDistCalculator { query_sq_code, - bounds: self.quantizer.bounds(), + scale: sq_distance_scale(&bounds), storage: self, } } } +#[inline] +fn sq_distance_scale(bounds: &Range) -> f32 { + let range = (bounds.end - bounds.start) as f32; + (range * range) / (255.0_f32 * 255.0_f32) +} + pub struct SQDistCalculator<'a> { query_sq_code: Vec, - bounds: Range, + scale: f32, storage: &'a ScalarQuantizationStorage, } @@ -423,7 +430,7 @@ impl<'a> SQDistCalculator<'a> { }; Self { query_sq_code, - bounds, + scale: sq_distance_scale(&bounds), storage, } } @@ -440,29 +447,35 @@ impl DistCalculator for SQDistCalculator<'_> { DistanceType::Dot => dot_distance(sq_code, &self.query_sq_code), _ => panic!("We should not reach here: sq distance can only be L2 or Dot"), }; - inverse_scalar_dist(std::iter::once(dist), &self.bounds)[0] + dist * self.scale } fn distance_all(&self, _k_hint: usize) -> Vec { match self.storage.distance_type { - DistanceType::L2 | DistanceType::Cosine => inverse_scalar_dist( - self.storage.chunks.iter().flat_map(|c| { + DistanceType::L2 | DistanceType::Cosine => self + .storage + .chunks + .iter() + .flat_map(|c| { c.sq_codes .values() .chunks_exact(c.dim()) .map(|sq_codes| l2_distance_uint_scalar(sq_codes, &self.query_sq_code)) - }), - &self.bounds, - ), - DistanceType::Dot => inverse_scalar_dist( - self.storage.chunks.iter().flat_map(|c| { + }) + .map(|dist| dist * self.scale) + .collect(), + DistanceType::Dot => self + .storage + .chunks + .iter() + .flat_map(|c| { c.sq_codes .values() .chunks_exact(c.dim()) .map(|sq_codes| dot_distance(sq_codes, &self.query_sq_code)) - }), - &self.bounds, - ), + }) + .map(|dist| dist * self.scale) + .collect(), _ => panic!("We should not reach here: sq distance can only be L2 or Dot"), } }