diff --git a/java/lance-jni/src/utils.rs b/java/lance-jni/src/utils.rs index 03c874052e9..9a85bffdfea 100644 --- a/java/lance-jni/src/utils.rs +++ b/java/lance-jni/src/utils.rs @@ -406,7 +406,7 @@ pub fn get_vector_index_params( "getRqParams", |env, rq_obj| { let num_bits = env.call_method(&rq_obj, "getNumBits", "()B", &[])?.b()? as u8; - Ok(RQBuildParams { num_bits }) + Ok(RQBuildParams::new(num_bits)) }, )?; diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 83b376e1704..2f4d78fb1bc 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -2905,6 +2905,11 @@ def create_index( - index_file_version The version of the index file. Default is "V3". + Optional parameters for `IVF_RQ`: + + - num_bits + The number of bits for RQ (Rabit Quantization). Default is 1. + Optional parameters for `IVF_HNSW_*`: max_level Int, the maximum number of levels in the graph. diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index dcdf88ee84d..ee7f955c0f7 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -843,6 +843,24 @@ def test_create_ivf_rq_index(): assert res["_distance"].to_numpy().max() == 0.0 +def test_create_ivf_rq_requires_dim_divisible_by_8(): + vectors = np.zeros((1000, 30), dtype=np.float32).tolist() + tbl = pa.Table.from_pydict( + {"vector": pa.array(vectors, type=pa.list_(pa.float32(), 30))} + ) + ds = lance.write_dataset(tbl, "memory://", mode="overwrite") + + with pytest.raises( + ValueError, match="vector dimension must be divisible by 8 for IVF_RQ" + ): + ds.create_index( + "vector", + index_type="IVF_RQ", + num_partitions=4, + num_bits=1, + ) + + def test_create_ivf_hnsw_pq_index(dataset, tmp_path): assert not dataset.has_index ann_ds = lance.write_dataset(dataset.to_table(), tmp_path / "indexed.lance") diff --git a/rust/lance-index/benches/rq.rs b/rust/lance-index/benches/rq.rs index afedc809c59..06853a3a142 100644 --- a/rust/lance-index/benches/rq.rs +++ b/rust/lance-index/benches/rq.rs @@ -16,6 +16,7 @@ use lance_datagen::{BatchGeneratorBuilder, RowCount}; use lance_index::vector::bq::builder::RabitQuantizer; use lance_index::vector::bq::storage::*; use lance_index::vector::bq::transform::{ADD_FACTORS_COLUMN, SCALE_FACTORS_COLUMN}; +use lance_index::vector::bq::RQRotationType; use lance_index::vector::quantizer::{Quantization, QuantizerStorage}; use lance_index::vector::storage::{DistCalculator, VectorStore}; use lance_linalg::distance::DistanceType; @@ -23,9 +24,9 @@ use lance_linalg::distance::DistanceType; const DIM: usize = 128; const TOTAL: usize = 16 * 1000; -fn mock_rq_storage(num_bits: u8) -> RabitQuantizationStorage { +fn mock_rq_storage(num_bits: u8, rotation_type: RQRotationType) -> RabitQuantizationStorage { // generate random rq codes - let rq = RabitQuantizer::new::(num_bits, DIM as i32); + let rq = RabitQuantizer::new_with_rotation::(num_bits, DIM as i32, rotation_type); let builder = BatchGeneratorBuilder::new() .col(ROW_ID, lance_datagen::array::step::()) .col( @@ -49,59 +50,70 @@ fn mock_rq_storage(num_bits: u8) -> RabitQuantizationStorage { } fn construct_dist_table(c: &mut Criterion) { + let rotation_types = [RQRotationType::Fast, RQRotationType::Matrix]; for num_bits in 1..=1 { - let rq = mock_rq_storage(num_bits); - let query = rand_type(&DataType::Float32) - .generate_default(RowCount::from(DIM as u64)) - .unwrap(); - c.bench_function( - format!( - "RQ{}: construct_dist_table: {},DIM={}", - num_bits, - DistanceType::L2, - DIM - ) - .as_str(), - |b| { - b.iter(|| { - black_box(rq.dist_calculator(query.clone(), 0.0)); - }) - }, - ); + for rotation_type in rotation_types { + let rq = mock_rq_storage(num_bits, rotation_type); + let query = rand_type(&DataType::Float32) + .generate_default(RowCount::from(DIM as u64)) + .unwrap(); + c.bench_function( + format!( + "RQ{}({:?}): construct_dist_table: {},DIM={}", + num_bits, + rotation_type, + DistanceType::L2, + DIM + ) + .as_str(), + |b| { + b.iter(|| { + black_box(rq.dist_calculator(query.clone(), 0.0)); + }) + }, + ); + } } } fn compute_distances(c: &mut Criterion) { + let rotation_types = [RQRotationType::Fast, RQRotationType::Matrix]; for num_bits in 1..=1 { - let rq = mock_rq_storage(num_bits); - let query = rand_type(&DataType::Float32) - .generate_default(RowCount::from(DIM as u64)) - .unwrap(); - let dist_calc = rq.dist_calculator(query.clone(), 0.0); + for rotation_type in rotation_types { + let rq = mock_rq_storage(num_bits, rotation_type); + let query = rand_type(&DataType::Float32) + .generate_default(RowCount::from(DIM as u64)) + .unwrap(); + let dist_calc = rq.dist_calculator(query.clone(), 0.0); - c.bench_function( - format!("RQ{}: compute_distances: {},DIM={}", num_bits, TOTAL, DIM).as_str(), - |b| { - b.iter(|| { - black_box(dist_calc.distance_all(0)); - }) - }, - ); + c.bench_function( + format!( + "RQ{}({:?}): compute_distances: {},DIM={}", + num_bits, rotation_type, TOTAL, DIM + ) + .as_str(), + |b| { + b.iter(|| { + black_box(dist_calc.distance_all(0)); + }) + }, + ); - c.bench_function( - format!( - "RQ{}: compute_distances_single: {},DIM={}", - num_bits, TOTAL, DIM - ) - .as_str(), - |b| { - b.iter(|| { - for i in 0..TOTAL { - black_box(dist_calc.distance(i as u32)); - } - }) - }, - ); + c.bench_function( + format!( + "RQ{}({:?}): compute_distances_single: {},DIM={}", + num_bits, rotation_type, TOTAL, DIM + ) + .as_str(), + |b| { + b.iter(|| { + for i in 0..TOTAL { + black_box(dist_calc.distance(i as u32)); + } + }) + }, + ); + } } } diff --git a/rust/lance-index/src/vector/bq.rs b/rust/lance-index/src/vector/bq.rs index b36003fddf9..54748db3264 100644 --- a/rust/lance-index/src/vector/bq.rs +++ b/rust/lance-index/src/vector/bq.rs @@ -4,17 +4,20 @@ //! Binary Quantization (BQ) use std::iter::once; +use std::str::FromStr; use std::sync::Arc; use arrow_array::types::Float32Type; use arrow_array::{cast::AsArray, Array, ArrayRef, UInt8Array}; use lance_core::{Error, Result}; use num_traits::Float; +use serde::{Deserialize, Serialize}; use snafu::location; use crate::vector::quantizer::QuantizerBuildParams; pub mod builder; +pub mod rotation; pub mod storage; pub mod transform; @@ -80,14 +83,51 @@ fn binary_quantization(data: &[T]) -> impl Iterator + '_ { })) } +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum RQRotationType { + #[default] + Fast, + Matrix, +} + +impl FromStr for RQRotationType { + type Err = Error; + + fn from_str(value: &str) -> std::result::Result { + match value.to_lowercase().as_str() { + "fast" | "fht_kac" | "fht-kac" => Ok(Self::Fast), + "matrix" | "dense" => Ok(Self::Matrix), + _ => Err(Error::invalid_input( + format!( + "Unknown RQ rotation type: {}. Expected one of: fast, matrix", + value + ), + location!(), + )), + } + } +} + #[derive(Clone, Debug, PartialEq, Eq)] pub struct RQBuildParams { pub num_bits: u8, + pub rotation_type: RQRotationType, } impl RQBuildParams { pub fn new(num_bits: u8) -> Self { - Self { num_bits } + Self { + num_bits, + rotation_type: RQRotationType::default(), + } + } + + pub fn with_rotation_type(num_bits: u8, rotation_type: RQRotationType) -> Self { + Self { + num_bits, + rotation_type, + } } } @@ -99,7 +139,10 @@ impl QuantizerBuildParams for RQBuildParams { impl Default for RQBuildParams { fn default() -> Self { - Self { num_bits: 1 } + Self { + num_bits: 1, + rotation_type: RQRotationType::default(), + } } } @@ -126,4 +169,17 @@ mod tests { test_bq::(); test_bq::(); } + + #[test] + fn test_rotation_type_parse() { + assert_eq!( + "fast".parse::().unwrap(), + RQRotationType::Fast + ); + assert_eq!( + "matrix".parse::().unwrap(), + RQRotationType::Matrix + ); + assert!("invalid".parse::().is_err()); + } } diff --git a/rust/lance-index/src/vector/bq/builder.rs b/rust/lance-index/src/vector/bq/builder.rs index 47a40c55801..53c73e69cb3 100644 --- a/rust/lance-index/src/vector/bq/builder.rs +++ b/rust/lance-index/src/vector/bq/builder.rs @@ -11,16 +11,20 @@ use bitvec::prelude::{BitVec, Lsb0}; use deepsize::DeepSizeOf; use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray, FloatType}; use lance_core::{Error, Result}; -use ndarray::{s, Axis}; +use ndarray::{s, Axis, ShapeBuilder}; use num_traits::{AsPrimitive, FromPrimitive}; use rand_distr::Distribution; +use rayon::prelude::*; use snafu::location; use crate::vector::bq::storage::{ RabitQuantizationMetadata, RabitQuantizationStorage, RABIT_CODE_COLUMN, RABIT_METADATA_KEY, }; use crate::vector::bq::transform::{ADD_FACTORS_FIELD, SCALE_FACTORS_FIELD}; -use crate::vector::bq::RQBuildParams; +use crate::vector::bq::{ + rotation::{apply_fast_rotation, random_fast_rotation_signs}, + RQBuildParams, RQRotationType, +}; use crate::vector::quantizer::{Quantization, Quantizer, QuantizerBuildParams}; /// Build parameters for RabitQuantizer. @@ -28,11 +32,15 @@ use crate::vector::quantizer::{Quantization, Quantizer, QuantizerBuildParams}; /// num_bits: the number of bits per dimension. pub struct RabitBuildParams { pub num_bits: u8, + pub rotation_type: RQRotationType, } impl Default for RabitBuildParams { fn default() -> Self { - Self { num_bits: 1 } + Self { + num_bits: 1, + rotation_type: RQRotationType::default(), + } } } @@ -48,27 +56,59 @@ pub struct RabitQuantizer { metadata: RabitQuantizationMetadata, } +#[inline] +fn pack_sign_bits(codes: &mut [u8], rotated: &[f32]) { + codes.fill(0); + for (bit_idx, value) in rotated.iter().enumerate() { + if value.is_sign_positive() { + codes[bit_idx / u8::BITS as usize] |= 1u8 << (bit_idx % u8::BITS as usize); + } + } +} + impl RabitQuantizer { pub fn new(num_bits: u8, dim: i32) -> Self { - // we don't need to calculate the inverse of P, - // just take the generated matrix as P^{-1} - let code_dim = dim * num_bits as i32; - let rotate_mat = random_orthogonal::(code_dim as usize); - let (rotate_mat, _) = rotate_mat.into_raw_vec_and_offset(); - - let rotate_mat = match T::FLOAT_TYPE { - FloatType::Float16 | FloatType::Float32 | FloatType::Float64 => { - let rotate_mat = >::from_values(rotate_mat); - FixedSizeListArray::try_new_from_values(rotate_mat, code_dim).unwrap() - } - _ => unimplemented!("RabitQ does not support data type: {:?}", T::FLOAT_TYPE), - }; + Self::new_with_rotation::(num_bits, dim, RQRotationType::default()) + } - let metadata = RabitQuantizationMetadata { - rotate_mat: Some(rotate_mat), - rotate_mat_position: 0, - num_bits, - packed: false, + pub fn new_with_rotation( + num_bits: u8, + dim: i32, + rotation_type: RQRotationType, + ) -> Self { + let code_dim = (dim * num_bits as i32) as usize; + let metadata = match rotation_type { + RQRotationType::Matrix => { + // we don't need to calculate the inverse of P, just take generated Q as P^{-1} + let rotate_mat = random_orthogonal::(code_dim); + let (rotate_mat, _) = rotate_mat.into_raw_vec_and_offset(); + let rotate_mat = match T::FLOAT_TYPE { + FloatType::Float16 | FloatType::Float32 | FloatType::Float64 => { + let rotate_mat = >::from_values(rotate_mat); + FixedSizeListArray::try_new_from_values(rotate_mat, code_dim as i32) + .unwrap() + } + _ => unimplemented!("RabitQ does not support data type: {:?}", T::FLOAT_TYPE), + }; + RabitQuantizationMetadata { + rotate_mat: Some(rotate_mat), + rotate_mat_position: None, + fast_rotation_signs: None, + rotation_type, + code_dim: code_dim as u32, + num_bits, + packed: false, + } + } + RQRotationType::Fast => RabitQuantizationMetadata { + rotate_mat: None, + rotate_mat_position: None, + fast_rotation_signs: Some(random_fast_rotation_signs(code_dim)), + rotation_type, + code_dim: code_dim as u32, + num_bits, + packed: false, + }, }; Self { metadata } } @@ -77,6 +117,19 @@ impl RabitQuantizer { self.metadata.num_bits } + pub fn rotation_type(&self) -> RQRotationType { + self.metadata.rotation_type + } + + #[inline] + fn fast_rotation_signs(&self) -> &[u8] { + self.metadata + .fast_rotation_signs + .as_ref() + .expect("RabitQ fast rotation signs missing") + .as_slice() + } + #[inline] fn rotate_mat_flat(&self) -> &[T::Native] { let rotate_mat = self.metadata.rotate_mat.as_ref().unwrap(); @@ -94,6 +147,45 @@ impl RabitQuantizer { ndarray::ArrayView2::from_shape((code_dim, code_dim), self.rotate_mat_flat::()).unwrap() } + fn rotate_vectors( + &self, + vectors: ndarray::ArrayView2<'_, T::Native>, + ) -> ndarray::Array2 + where + T::Native: AsPrimitive, + { + let dim = vectors.nrows(); + let code_dim = self.code_dim(); + match self.rotation_type() { + RQRotationType::Matrix => { + let rotate_mat = self.rotate_mat::(); + let rotate_mat = rotate_mat.slice(s![.., 0..dim]); + rotate_mat.dot(&vectors).mapv(|v| v.as_()) + } + RQRotationType::Fast => { + let signs = self.fast_rotation_signs(); + let ncols = vectors.ncols(); + let mut rotated_data = vec![0.0f32; code_dim * ncols]; + rotated_data + .par_chunks_mut(code_dim) + .enumerate() + .for_each_init( + || vec![0.0f32; code_dim], + |scratch, (col_idx, dst)| { + let column = vectors.column(col_idx); + let input = column + .as_slice() + .expect("RabitQ input vectors should be contiguous"); + apply_fast_rotation(input, scratch, signs); + dst.copy_from_slice(scratch); + }, + ); + + ndarray::Array2::from_shape_vec((code_dim, ncols).f(), rotated_data).unwrap() + } + } + } + pub fn dim(&self) -> usize { self.code_dim() / self.metadata.num_bits as usize } @@ -104,7 +196,7 @@ impl RabitQuantizer { residual_vectors: &FixedSizeListArray, ) -> Result> where - T::Native: AsPrimitive, + T::Native: AsPrimitive + Sync, { let dim = self.dim(); if residual_vectors.value_length() as usize != dim { @@ -118,27 +210,43 @@ impl RabitQuantizer { )); } - // convert the vector to a dxN matrix - let vec_mat = ndarray::ArrayView2::from_shape( - (residual_vectors.len(), dim), - residual_vectors - .values() - .as_any() - .downcast_ref::() - .unwrap() - .as_slice(), - ) - .map_err(|e| Error::invalid_input(e.to_string(), location!()))?; - let vec_mat = vec_mat.t(); - - let rotate_mat = self.rotate_mat::(); - // slice to (code_dim, dim) - let rotate_mat = rotate_mat.slice(s![.., 0..dim]); - let rotated_vectors = rotate_mat.dot(&vec_mat); let sqrt_dim = (dim as f32 * self.metadata.num_bits as f32).sqrt(); - let norm_dists = rotated_vectors.mapv(|v| v.as_().abs()).sum_axis(Axis(0)) / sqrt_dim; - debug_assert_eq!(norm_dists.len(), residual_vectors.len()); - Ok(norm_dists.to_vec()) + let values = residual_vectors + .values() + .as_any() + .downcast_ref::() + .unwrap() + .as_slice(); + + match self.rotation_type() { + RQRotationType::Matrix => { + // convert the vector to a dxN matrix + let vec_mat = + ndarray::ArrayView2::from_shape((residual_vectors.len(), dim), values) + .map_err(|e| Error::invalid_input(e.to_string(), location!()))?; + let vec_mat = vec_mat.t(); + let rotated_vectors = self.rotate_vectors::(vec_mat); + let norm_dists = rotated_vectors.mapv(f32::abs).sum_axis(Axis(0)) / sqrt_dim; + debug_assert_eq!(norm_dists.len(), residual_vectors.len()); + Ok(norm_dists.to_vec()) + } + RQRotationType::Fast => { + let code_dim = self.code_dim(); + let signs = self.fast_rotation_signs(); + let mut norm_dists = vec![0.0f32; residual_vectors.len()]; + norm_dists + .par_iter_mut() + .zip(values.par_chunks_exact(dim)) + .for_each_init( + || vec![0.0f32; code_dim], + |scratch, (dst, input)| { + apply_fast_rotation(input, scratch, signs); + *dst = scratch.iter().map(|v| v.abs()).sum::() / sqrt_dim; + }, + ); + Ok(norm_dists) + } + } } fn transform( @@ -146,38 +254,60 @@ impl RabitQuantizer { residual_vectors: &FixedSizeListArray, ) -> Result where - T::Native: AsPrimitive, + T::Native: AsPrimitive + Sync, { // we don't need to normalize the residual vectors, // because the sign of P^{-1} * v_r is the same as P^{-1} * v_r / ||v_r|| let n = residual_vectors.len(); let dim = self.dim(); debug_assert_eq!(residual_vectors.values().len(), n * dim); - - let vectors = ndarray::ArrayView2::from_shape( - (n, dim), - residual_vectors - .values() - .as_any() - .downcast_ref::() - .unwrap() - .as_slice(), - ) - .map_err(|e| Error::invalid_input(e.to_string(), location!()))?; - let vectors = vectors.t(); - let rotate_mat = self.rotate_mat::(); - let rotate_mat = rotate_mat.slice(s![.., 0..dim]); - let rotated_vectors = rotate_mat.dot(&vectors); - - let quantized_vectors = rotated_vectors.t().mapv(|v| v.as_().is_sign_positive()); - let bv: BitVec = BitVec::from_iter(quantized_vectors); - - let codes = UInt8Array::from(bv.into_vec()); - debug_assert_eq!(codes.len(), n * self.code_dim() / u8::BITS as usize); - Ok(Arc::new(FixedSizeListArray::try_new_from_values( - codes, - self.code_dim() as i32 / u8::BITS as i32, // num_bits -> num_bytes - )?)) + let values = residual_vectors + .values() + .as_any() + .downcast_ref::() + .unwrap() + .as_slice(); + let code_dim = self.code_dim(); + let code_bytes = code_dim / u8::BITS as usize; + + match self.rotation_type() { + RQRotationType::Matrix => { + let vectors = ndarray::ArrayView2::from_shape((n, dim), values) + .map_err(|e| Error::invalid_input(e.to_string(), location!()))?; + let vectors = vectors.t(); + let rotated_vectors = self.rotate_vectors::(vectors); + + let quantized_vectors = rotated_vectors.t().mapv(|v| v.is_sign_positive()); + let bv: BitVec = BitVec::from_iter(quantized_vectors); + + let codes = UInt8Array::from(bv.into_vec()); + debug_assert_eq!(codes.len(), n * code_bytes); + Ok(Arc::new(FixedSizeListArray::try_new_from_values( + codes, + code_bytes as i32, // num_bits -> num_bytes + )?)) + } + RQRotationType::Fast => { + let signs = self.fast_rotation_signs(); + let mut encoded_codes = vec![0u8; n * code_bytes]; + encoded_codes + .par_chunks_mut(code_bytes) + .zip(values.par_chunks_exact(dim)) + .for_each_init( + || vec![0.0f32; code_dim], + |scratch, (code_dst, input)| { + apply_fast_rotation(input, scratch, signs); + pack_sign_bits(code_dst, scratch); + }, + ); + let codes = UInt8Array::from(encoded_codes); + debug_assert_eq!(codes.len(), n * code_bytes); + Ok(Arc::new(FixedSizeListArray::try_new_from_values( + codes, + code_bytes as i32, + )?)) + } + } } } @@ -191,16 +321,30 @@ impl Quantization for RabitQuantizer { _: lance_linalg::distance::DistanceType, params: &Self::BuildParams, ) -> Result { + let dim = data.as_fixed_size_list().value_length() as usize; + if !dim.is_multiple_of(u8::BITS as usize) { + return Err(Error::invalid_input( + "vector dimension must be divisible by 8 for IVF_RQ", + location!(), + )); + } + let q = match data.as_fixed_size_list().value_type() { - DataType::Float16 => { - Self::new::(params.num_bits, data.as_fixed_size_list().value_length()) - } - DataType::Float32 => { - Self::new::(params.num_bits, data.as_fixed_size_list().value_length()) - } - DataType::Float64 => { - Self::new::(params.num_bits, data.as_fixed_size_list().value_length()) - } + DataType::Float16 => Self::new_with_rotation::( + params.num_bits, + data.as_fixed_size_list().value_length(), + params.rotation_type, + ), + DataType::Float32 => Self::new_with_rotation::( + params.num_bits, + data.as_fixed_size_list().value_length(), + params.rotation_type, + ), + DataType::Float64 => Self::new_with_rotation::( + params.num_bits, + data.as_fixed_size_list().value_length(), + params.rotation_type, + ), dt => { return Err(Error::invalid_input( format!("Unsupported data type: {:?}", dt), @@ -216,11 +360,15 @@ impl Quantization for RabitQuantizer { } fn code_dim(&self) -> usize { - self.metadata - .rotate_mat - .as_ref() - .map(|inv_p| inv_p.len()) - .unwrap_or(0) + if self.metadata.code_dim > 0 { + self.metadata.code_dim as usize + } else { + self.metadata + .rotate_mat + .as_ref() + .map(|rotate_mat| rotate_mat.len()) + .unwrap_or(0) + } } fn column(&self) -> &'static str { @@ -370,6 +518,9 @@ where mod tests { use super::*; use approx::assert_relative_eq; + use arrow::datatypes::Float32Type; + use arrow_array::{FixedSizeListArray, Float32Array}; + use lance_linalg::distance::DistanceType; use rstest::rstest; #[rstest] @@ -410,4 +561,31 @@ mod tests { assert_eq!(q.dim(), (m, m)); assert_eq!(r.dim(), (m, n)); } + + #[test] + fn test_rabit_quantizer_rotation_modes() { + let fast_q = RabitQuantizer::new_with_rotation::(1, 128, RQRotationType::Fast); + assert_eq!(fast_q.rotation_type(), RQRotationType::Fast); + assert_eq!(fast_q.dim(), 128); + + let matrix_q = + RabitQuantizer::new_with_rotation::(1, 128, RQRotationType::Matrix); + assert_eq!(matrix_q.rotation_type(), RQRotationType::Matrix); + assert_eq!(matrix_q.dim(), 128); + } + + #[test] + fn test_rabit_quantizer_requires_dim_divisible_by_8() { + let vectors = Float32Array::from(vec![0.0f32; 4 * 30]); + let fsl = FixedSizeListArray::try_new_from_values(vectors, 30).unwrap(); + let params = RQBuildParams::new(1); + + let err = RabitQuantizer::build(&fsl, DistanceType::L2, ¶ms).unwrap_err(); + assert!( + err.to_string() + .contains("vector dimension must be divisible by 8 for IVF_RQ"), + "{}", + err + ); + } } diff --git a/rust/lance-index/src/vector/bq/rotation.rs b/rust/lance-index/src/vector/bq/rotation.rs new file mode 100644 index 00000000000..de4fbf549f1 --- /dev/null +++ b/rust/lance-index/src/vector/bq/rotation.rs @@ -0,0 +1,223 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use num_traits::AsPrimitive; +use rand::RngCore; + +// Fast random rotation used by the RabitQ "fast" path. +// +// The transform is a composition of: +// 1) random diagonal sign flips (Rademacher variables), +// 2) FWHT-style mixing on a power-of-two window, +// 3) a Kac-style pairwise mixing step for non-power-of-two dimensions. +// +// Background: +// - Hadamard transform: https://en.wikipedia.org/wiki/Hadamard_transform +// - Fast Walsh-Hadamard transform (FWHT): +// https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform +// - Rademacher random signs: +// https://en.wikipedia.org/wiki/Rademacher_distribution +// - Kac-walk-based fast dimension reduction (uses fixed-angle pair rotations): +// https://arxiv.org/abs/2003.10069 +// - Givens / plane rotation: +// https://en.wikipedia.org/wiki/Givens_rotation +const FAST_ROTATION_ROUNDS: usize = 4; + +#[inline] +fn fwht_in_place(values: &mut [f32]) { + // In-place FWHT butterfly network. + // For each stage, pair entries (x, y) and map to (x + y, x - y). + // Complexity: O(n log n) operations, no extra heap allocation. + debug_assert!(values.len().is_power_of_two()); + let mut half = 1usize; + while half < values.len() { + let step = half * 2; + for block in values.chunks_exact_mut(step) { + let (left, right) = block.split_at_mut(half); + for (x, y) in left.iter_mut().zip(right.iter_mut()) { + let lx = *x; + let ry = *y; + *x = lx + ry; + *y = lx - ry; + } + } + half = step; + } +} + +#[inline] +fn flip_signs_scalar(values: &mut [f32], signs: &[u8]) { + // Apply a random diagonal matrix with +/-1 entries by toggling the f32 sign bit. + // One bit in `signs` controls one element in `values`. + for (byte_idx, &mask) in signs.iter().enumerate() { + let start = byte_idx * 8; + if start >= values.len() { + break; + } + let end = (start + 8).min(values.len()); + for (bit_idx, value) in values[start..end].iter_mut().enumerate() { + let sign_mask = (((mask >> bit_idx) & 1) as u32) << 31; + *value = f32::from_bits(value.to_bits() ^ sign_mask); + } + } +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[target_feature(enable = "avx2")] +unsafe fn flip_signs_avx2(values: &mut [f32], signs: &[u8]) { + #[cfg(target_arch = "x86")] + use std::arch::x86::*; + #[cfg(target_arch = "x86_64")] + use std::arch::x86_64::*; + + // Vectorized variant of `flip_signs_scalar`: consume 8 f32 values per AVX2 lane. + // The sign mask is expanded from one byte to 8 lane-wise sign-bit masks. + let full_chunks = values.len() / 8; + let bit_select = _mm256_setr_epi32(0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80); + let sign_flip = _mm256_set1_epi32(0x80000000u32 as i32); + + for (chunk_idx, &mask) in signs.iter().take(full_chunks).enumerate() { + let mask = mask as i32; + let mask_bits = _mm256_set1_epi32(mask); + let test = _mm256_and_si256(mask_bits, bit_select); + let cmp = _mm256_cmpeq_epi32(test, bit_select); + let xor_mask = _mm256_and_si256(cmp, sign_flip); + + let ptr = unsafe { values.as_mut_ptr().add(chunk_idx * 8) }; + let vec = unsafe { _mm256_loadu_ps(ptr) }; + let out = _mm256_xor_ps(vec, _mm256_castsi256_ps(xor_mask)); + unsafe { _mm256_storeu_ps(ptr, out) }; + } + + if full_chunks * 8 < values.len() { + flip_signs_scalar(&mut values[full_chunks * 8..], &signs[full_chunks..]); + } +} + +#[inline] +fn flip_signs(values: &mut [f32], signs: &[u8]) { + debug_assert!(signs.len() * 8 >= values.len()); + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if std::arch::is_x86_feature_detected!("avx2") { + // SAFETY: guarded by runtime feature detection. + unsafe { + flip_signs_avx2(values, signs); + } + return; + } + } + flip_signs_scalar(values, signs); +} + +#[inline] +fn kacs_walk(values: &mut [f32]) { + // A fixed-angle (pi/4) plane-rotation-like sweep over paired coordinates: + // (x, y) -> (x + y, x - y). Up to normalization, this is a 2x2 Hadamard block + // and corresponds to one Kac-style mixing step. + let half = values.len() / 2; + let (left, right) = values.split_at_mut(half); + for (x, y) in left.iter_mut().zip(right.iter_mut()) { + let lx = *x; + let ry = *y; + *x = lx + ry; + *y = lx - ry; + } +} + +#[inline] +fn rescale(values: &mut [f32], factor: f32) { + // Keep the transform numerically stable and approximately orthonormal. + for value in values.iter_mut() { + *value *= factor; + } +} + +#[inline] +fn sign_bytes_per_round(dim: usize) -> usize { + dim.div_ceil(8) +} + +pub fn random_fast_rotation_signs(dim: usize) -> Vec { + // Each round needs one random sign bit per dimension. + let mut signs = vec![0u8; FAST_ROTATION_ROUNDS * sign_bytes_per_round(dim)]; + rand::rng().fill_bytes(&mut signs); + signs +} + +pub fn apply_fast_rotation>(input: &[T], output: &mut [f32], signs: &[u8]) { + // Fast random rotation pipeline, aligned with RaBitQ-Library's FhtKacRotator: + // - power-of-two dims: repeat [random signs -> FWHT -> scale] for 4 rounds + // - non-power-of-two dims: alternate FWHT on head/tail + Kac mixing + // + // This keeps the fast path matrix-free: no dense orthogonal matrix materialization. + let dim = output.len(); + let bytes_per_round = sign_bytes_per_round(dim); + debug_assert_eq!(signs.len(), FAST_ROTATION_ROUNDS * bytes_per_round); + let input_len = input.len().min(dim); + output[..input_len] + .iter_mut() + .zip(input[..input_len].iter()) + .for_each(|(dst, src)| *dst = src.as_()); + if input_len < dim { + output[input_len..].fill(0.0); + } + + if dim == 0 { + return; + } + + let trunc_dim = 1usize << dim.ilog2(); + let scale = 1.0f32 / (trunc_dim as f32).sqrt(); + if trunc_dim == dim { + for round in 0..FAST_ROTATION_ROUNDS { + let offset = round * bytes_per_round; + flip_signs(output, &signs[offset..offset + bytes_per_round]); + fwht_in_place(output); + rescale(output, scale); + } + return; + } + + let start = dim - trunc_dim; + for round in 0..FAST_ROTATION_ROUNDS { + let offset = round * bytes_per_round; + flip_signs(output, &signs[offset..offset + bytes_per_round]); + + if round % 2 == 0 { + let head = &mut output[..trunc_dim]; + fwht_in_place(head); + rescale(head, scale); + } else { + let tail = &mut output[start..]; + fwht_in_place(tail); + rescale(tail, scale); + } + + kacs_walk(output); + } + + // Matches RaBitQ-Library FhtKacRotator behavior for non-power-of-two dimensions. + // The extra factor compensates the alternating truncated FWHT + Kac steps above. + rescale(output, 0.25); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_fast_rotation_sign_bytes() { + assert_eq!(random_fast_rotation_signs(128).len(), 64); + assert_eq!(random_fast_rotation_signs(130).len(), 68); + } + + #[test] + fn test_fast_rotation_preserves_shape() { + let input = vec![1.0f32; 129]; + let mut output = vec![0.0f32; 129]; + let signs = random_fast_rotation_signs(129); + apply_fast_rotation(&input, &mut output, &signs); + assert_eq!(output.len(), 129); + } +} diff --git a/rust/lance-index/src/vector/bq/storage.rs b/rust/lance-index/src/vector/bq/storage.rs index cead98a72f6..f94f2135295 100644 --- a/rust/lance-index/src/vector/bq/storage.rs +++ b/rust/lance-index/src/vector/bq/storage.rs @@ -28,7 +28,9 @@ use snafu::location; use crate::frag_reuse::FragReuseIndex; use crate::pb; +use crate::vector::bq::rotation::apply_fast_rotation; use crate::vector::bq::transform::{ADD_FACTORS_COLUMN, SCALE_FACTORS_COLUMN}; +use crate::vector::bq::RQRotationType; use crate::vector::pq::storage::transpose; use crate::vector::quantizer::{QuantizerMetadata, QuantizerStorage}; use crate::vector::storage::{DistCalculator, VectorStore}; @@ -45,45 +47,80 @@ pub struct RabitQuantizationMetadata { // in the global buffer, which is a binary format (protobuf for now) for efficiency. #[serde(skip)] pub rotate_mat: Option, - pub rotate_mat_position: u32, + #[serde(default)] + pub rotate_mat_position: Option, + #[serde(default)] + pub fast_rotation_signs: Option>, + #[serde(default = "default_rotation_type_compat")] + pub rotation_type: RQRotationType, + #[serde(default)] + pub code_dim: u32, pub num_bits: u8, pub packed: bool, } +fn default_rotation_type_compat() -> RQRotationType { + // Older metadata does not have this field and always used dense matrices. + RQRotationType::Matrix +} + impl DeepSizeOf for RabitQuantizationMetadata { fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize { self.rotate_mat .as_ref() .map(|inv_p| inv_p.get_array_memory_size()) .unwrap_or(0) + + self + .fast_rotation_signs + .as_ref() + .map(|signs| signs.len()) + .unwrap_or(0) } } #[async_trait] impl QuantizerMetadata for RabitQuantizationMetadata { fn buffer_index(&self) -> Option { - Some(self.rotate_mat_position) + match self.rotation_type { + RQRotationType::Matrix => self.rotate_mat_position, + RQRotationType::Fast => None, + } } fn set_buffer_index(&mut self, index: u32) { - self.rotate_mat_position = index; + self.rotate_mat_position = Some(index); } fn parse_buffer(&mut self, bytes: Bytes) -> Result<()> { + if self.rotation_type != RQRotationType::Matrix { + return Ok(()); + } debug_assert!(!bytes.is_empty()); let codebook_tensor: pb::Tensor = pb::Tensor::decode(bytes)?; self.rotate_mat = Some(FixedSizeListArray::try_from(&codebook_tensor)?); + if self.code_dim == 0 { + self.code_dim = self + .rotate_mat + .as_ref() + .map(|rotate_mat| rotate_mat.len() as u32) + .unwrap_or(0); + } Ok(()) } fn extra_metadata(&self) -> Result> { - if let Some(inv_p) = &self.rotate_mat { - let inv_p_tensor = pb::Tensor::try_from(inv_p)?; - let mut bytes = BytesMut::new(); - inv_p_tensor.encode(&mut bytes)?; - Ok(Some(bytes.freeze())) - } else { - Ok(None) + match self.rotation_type { + RQRotationType::Matrix => { + if let Some(inv_p) = &self.rotate_mat { + let inv_p_tensor = pb::Tensor::try_from(inv_p)?; + let mut bytes = BytesMut::new(); + inv_p_tensor.encode(&mut bytes)?; + Ok(Some(bytes.freeze())) + } else { + Ok(None) + } + } + RQRotationType::Fast => Ok(None), } } @@ -127,7 +164,7 @@ impl DeepSizeOf for RabitQuantizationStorage { } impl RabitQuantizationStorage { - fn rotate_query_vector( + fn rotate_query_vector_dense( rotate_mat: &FixedSizeListArray, qr: &dyn Array, ) -> Vec @@ -154,6 +191,25 @@ impl RabitQuantizationStorage { .map(|chunk| lance_linalg::distance::dot(&chunk[..d], qr)) .collect() } + + fn rotate_query_vector_fast( + code_dim: usize, + signs: &[u8], + qr: &dyn Array, + ) -> Vec + where + T::Native: AsPrimitive, + { + let qr = qr + .as_any() + .downcast_ref::() + .unwrap() + .as_slice(); + + let mut output = vec![0.0f32; code_dim]; + apply_fast_rotation(qr, &mut output, signs); + output + } } pub struct RabitDistCalculator<'a> { @@ -408,17 +464,56 @@ impl VectorStore for RabitQuantizationStorage { #[inline(never)] fn dist_calculator(&self, qr: Arc, dist_q_c: f32) -> Self::DistanceCalculator<'_> { let codes = self.codes.values().as_primitive::().values(); - let rotate_mat = self - .metadata - .rotate_mat - .as_ref() - .expect("RabitQ metadata not loaded"); + let code_dim = if self.metadata.code_dim > 0 { + self.metadata.code_dim as usize + } else { + self.metadata + .rotate_mat + .as_ref() + .map(|rotate_mat| rotate_mat.len()) + .unwrap_or_default() + }; - let rotated_qr = match rotate_mat.value_type() { - DataType::Float16 => Self::rotate_query_vector::(rotate_mat, &qr), - DataType::Float32 => Self::rotate_query_vector::(rotate_mat, &qr), - DataType::Float64 => Self::rotate_query_vector::(rotate_mat, &qr), - dt => unimplemented!("RabitQ does not support data type: {}", dt), + let rotated_qr = match self.metadata.rotation_type { + RQRotationType::Matrix => { + let rotate_mat = self + .metadata + .rotate_mat + .as_ref() + .expect("RabitQ dense rotation metadata not loaded"); + + match rotate_mat.value_type() { + DataType::Float16 => { + Self::rotate_query_vector_dense::(rotate_mat, &qr) + } + DataType::Float32 => { + Self::rotate_query_vector_dense::(rotate_mat, &qr) + } + DataType::Float64 => { + Self::rotate_query_vector_dense::(rotate_mat, &qr) + } + dt => unimplemented!("RabitQ does not support data type: {}", dt), + } + } + RQRotationType::Fast => { + let signs = self + .metadata + .fast_rotation_signs + .as_ref() + .expect("RabitQ fast rotation metadata not loaded"); + match qr.data_type() { + DataType::Float16 => { + Self::rotate_query_vector_fast::(code_dim, signs, &qr) + } + DataType::Float32 => { + Self::rotate_query_vector_fast::(code_dim, signs, &qr) + } + DataType::Float64 => { + Self::rotate_query_vector_fast::(code_dim, signs, &qr) + } + dt => unimplemented!("RabitQ does not support data type: {}", dt), + } + } }; let dist_table = build_dist_table_direct::(&rotated_qr); diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index 22a79117924..4f3de52659d 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -25,7 +25,7 @@ use lance_index::metrics::NoOpMetricsCollector; use lance_index::optimize::OptimizeOptions; use lance_index::progress::{noop_progress, IndexBuildProgress}; use lance_index::vector::bq::builder::RabitQuantizer; -use lance_index::vector::bq::RQBuildParams; +use lance_index::vector::bq::{RQBuildParams, RQRotationType}; use lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantizer}; use lance_index::vector::hnsw::HNSW; use lance_index::vector::ivf::builder::recommended_num_partitions; @@ -169,8 +169,22 @@ impl VectorIndexParams { } pub fn ivf_rq(num_partitions: usize, num_bits: u8, distance_type: DistanceType) -> Self { + Self::ivf_rq_with_rotation( + num_partitions, + num_bits, + distance_type, + RQRotationType::default(), + ) + } + + pub fn ivf_rq_with_rotation( + num_partitions: usize, + num_bits: u8, + distance_type: DistanceType, + rotation_type: RQRotationType, + ) -> Self { let ivf = IvfBuildParams::new(num_partitions); - let rq = RQBuildParams { num_bits }; + let rq = RQBuildParams::with_rotation_type(num_bits, rotation_type); let stages = vec![StageParams::Ivf(ivf), StageParams::RQ(rq)]; Self { stages, @@ -1671,6 +1685,7 @@ fn derive_sq_params(sq_quantizer: &ScalarQuantizer) -> SQBuildParams { fn derive_rabit_params(rabit_quantizer: &RabitQuantizer) -> RQBuildParams { RQBuildParams { num_bits: rabit_quantizer.num_bits(), + rotation_type: rabit_quantizer.rotation_type(), } } diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index b510c60e41b..08b2cfa4e80 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -625,7 +625,9 @@ mod tests { use arrow_schema::{DataType, Field, Schema, SchemaRef}; use itertools::Itertools; use lance_arrow::FixedSizeListArrayExt; - use lance_index::vector::bq::RQBuildParams; + use lance_index::vector::bq::{ + storage::RabitQuantizationMetadata, RQBuildParams, RQRotationType, + }; use lance_index::vector::storage::VectorStore; use crate::dataset::{InsertBuilder, UpdateBuilder, WriteMode, WriteParams}; @@ -735,6 +737,49 @@ mod tests { vectors } + async fn get_rq_metadata( + dataset: &Dataset, + scheduler: Arc, + index_uuid: &str, + ) -> RabitQuantizationMetadata { + let index_path = dataset + .indices_dir() + .child(index_uuid) + .child(INDEX_AUXILIARY_FILE_NAME); + let file_scheduler = scheduler + .open_file(&index_path, &CachedFileSize::unknown()) + .await + .unwrap(); + let reader = FileReader::try_open( + file_scheduler, + None, + Arc::::default(), + &LanceCache::no_cache(), + FileReaderOptions::default(), + ) + .await + .unwrap(); + let metadata = reader.schema().metadata.get(STORAGE_METADATA_KEY).unwrap(); + let metadata_entries: Vec = serde_json::from_str(metadata).unwrap(); + serde_json::from_str(&metadata_entries[0]).unwrap() + } + + async fn assert_rq_rotation_type(dataset: &Dataset, expected: RQRotationType) { + let obj_store = Arc::new(ObjectStore::local()); + let scheduler = ScanScheduler::new(obj_store, SchedulerConfig::default_for_testing()); + let indices = dataset.load_indices().await.unwrap(); + assert!(!indices.is_empty(), "Expected at least one vector index"); + for index in indices.iter() { + let rq_meta = + get_rq_metadata(dataset, scheduler.clone(), &index.uuid.to_string()).await; + assert_eq!( + rq_meta.rotation_type, expected, + "RQ rotation type mismatch for index {}", + index.uuid + ); + } + } + fn generate_batch( num_rows: usize, start_id: Option, @@ -2061,10 +2106,11 @@ mod tests { #[case] nlist: usize, #[case] distance_type: DistanceType, #[case] recall_requirement: f32, + #[values(RQRotationType::Fast, RQRotationType::Matrix)] rotation_type: RQRotationType, ) { let _ = env_logger::try_init(); let ivf_params = IvfBuildParams::new(nlist); - let rq_params = RQBuildParams::new(1); + let rq_params = RQBuildParams::with_rotation_type(1, rotation_type); let params = VectorIndexParams::with_ivf_rq_params(distance_type, ivf_params, rq_params); test_index(params.clone(), nlist, recall_requirement, None).await; if distance_type == DistanceType::Cosine { @@ -2073,6 +2119,52 @@ mod tests { test_remap(params.clone(), nlist, recall_requirement).await; } + #[rstest] + #[case::fast(RQRotationType::Fast)] + #[case::matrix(RQRotationType::Matrix)] + #[tokio::test] + async fn test_ivf_rq_rotation_type_after_optimize(#[case] rotation_type: RQRotationType) { + let test_dir = TempStrDir::default(); + let test_uri = test_dir.as_str(); + let (mut dataset, _) = generate_test_dataset::(test_uri, 0.0..1.0).await; + + let ivf_params = IvfBuildParams::new(4); + let rq_params = RQBuildParams::with_rotation_type(1, rotation_type); + let params = VectorIndexParams::with_ivf_rq_params(DistanceType::L2, ivf_params, rq_params); + dataset + .create_index(&["vector"], IndexType::Vector, None, ¶ms, true) + .await + .unwrap(); + + assert_rq_rotation_type(&dataset, rotation_type).await; + + append_dataset::(&mut dataset, 64, 0.0..1.0).await; + dataset + .optimize_indices(&OptimizeOptions::append()) + .await + .unwrap(); + + let indices_after_append = dataset.load_indices().await.unwrap(); + assert_eq!( + indices_after_append.len(), + 2, + "Expected append optimize to create one delta index" + ); + assert_rq_rotation_type(&dataset, rotation_type).await; + + dataset + .optimize_indices(&OptimizeOptions::merge(10)) + .await + .unwrap(); + let indices_after_merge = dataset.load_indices().await.unwrap(); + assert_eq!( + indices_after_merge.len(), + 1, + "Expected merge optimize to merge indices into one" + ); + assert_rq_rotation_type(&dataset, rotation_type).await; + } + #[rstest] #[case(4, DistanceType::L2, 0.9)] #[case(4, DistanceType::Cosine, 0.9)]