From 39479878ee823239336ba0446cfa357e47a01143 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Sat, 21 Feb 2026 14:23:13 +0800 Subject: [PATCH 1/7] feat(rq): add fast random rotation option --- python/python/lance/dataset.py | 8 + python/python/tests/test_vector_index.py | 20 +++ python/src/dataset.rs | 16 +- rust/lance-index/benches/rq.rs | 106 +++++++------ rust/lance-index/src/vector/bq.rs | 65 +++++++- rust/lance-index/src/vector/bq/builder.rs | 174 ++++++++++++++++----- rust/lance-index/src/vector/bq/rotation.rs | 130 +++++++++++++++ rust/lance-index/src/vector/bq/storage.rs | 137 +++++++++++++--- rust/lance/src/index/vector.rs | 19 ++- 9 files changed, 559 insertions(+), 116 deletions(-) create mode 100644 rust/lance-index/src/vector/bq/rotation.rs diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 97a9a0c6775..a027d982274 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -2862,6 +2862,14 @@ def create_index( - index_file_version The version of the index file. Default is "V3". + Optional parameters for `IVF_RQ`: + + - num_bits + The number of bits for RQ (Rabit Quantization). Default is 1. + - rq_rotation_type / rabitq_rotation_type + Rotation implementation for RabitQ. One of ``"fast"`` (default) + or ``"matrix"``. + Optional parameters for `IVF_HNSW_*`: max_level Int, the maximum number of levels in the graph. 
diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index 71850135ae3..cc27de999ba 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -803,6 +803,26 @@ def test_create_ivf_rq_index(): ) assert ds.list_indices()[0]["fields"] == ["vector"] + ds = ds.create_index( + "vector", + index_type="IVF_RQ", + num_partitions=4, + num_bits=1, + rq_rotation_type="matrix", + replace=True, + ) + assert ds.list_indices()[0]["fields"] == ["vector"] + + with pytest.raises(ValueError, match="Unknown RQ rotation type"): + ds.create_index( + "vector", + index_type="IVF_RQ", + num_partitions=4, + num_bits=1, + rq_rotation_type="not-a-rotation", + replace=True, + ) + with pytest.raises( NotImplementedError, match="Creating empty vector indices with train=False is not yet implemented", diff --git a/python/src/dataset.rs b/python/src/dataset.rs index b4b4d8af050..094f385f524 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -17,7 +17,7 @@ use async_trait::async_trait; use blob::LanceBlobFile; use chrono::{Duration, TimeDelta, Utc}; use futures::{StreamExt, TryFutureExt}; -use lance_index::vector::bq::RQBuildParams; +use lance_index::vector::bq::{RQBuildParams, RQRotationType}; use log::error; use object_store::path::Path; use pyo3::exceptions::{PyStopIteration, PyTypeError}; @@ -3331,6 +3331,20 @@ fn prepare_vector_index_params( rq_params.num_bits = num_bits; }; + let rq_rotation_type = if let Some(rotation_type) = kwargs.get_item("rq_rotation_type")? { + Some(rotation_type.extract::()?) + } else if let Some(rotation_type) = kwargs.get_item("rabitq_rotation_type")? { + Some(rotation_type.extract::()?) + } else { + None + }; + if let Some(rotation_type) = rq_rotation_type { + rq_params.rotation_type = + rotation_type + .parse::() + .map_err(|e| PyValueError::new_err(format!("{}", e)))?; + } + if let Some(n) = kwargs.get_item("num_sub_vectors")? 
{ pq_params.num_sub_vectors = n.extract()? }; diff --git a/rust/lance-index/benches/rq.rs b/rust/lance-index/benches/rq.rs index afedc809c59..06853a3a142 100644 --- a/rust/lance-index/benches/rq.rs +++ b/rust/lance-index/benches/rq.rs @@ -16,6 +16,7 @@ use lance_datagen::{BatchGeneratorBuilder, RowCount}; use lance_index::vector::bq::builder::RabitQuantizer; use lance_index::vector::bq::storage::*; use lance_index::vector::bq::transform::{ADD_FACTORS_COLUMN, SCALE_FACTORS_COLUMN}; +use lance_index::vector::bq::RQRotationType; use lance_index::vector::quantizer::{Quantization, QuantizerStorage}; use lance_index::vector::storage::{DistCalculator, VectorStore}; use lance_linalg::distance::DistanceType; @@ -23,9 +24,9 @@ use lance_linalg::distance::DistanceType; const DIM: usize = 128; const TOTAL: usize = 16 * 1000; -fn mock_rq_storage(num_bits: u8) -> RabitQuantizationStorage { +fn mock_rq_storage(num_bits: u8, rotation_type: RQRotationType) -> RabitQuantizationStorage { // generate random rq codes - let rq = RabitQuantizer::new::(num_bits, DIM as i32); + let rq = RabitQuantizer::new_with_rotation::(num_bits, DIM as i32, rotation_type); let builder = BatchGeneratorBuilder::new() .col(ROW_ID, lance_datagen::array::step::()) .col( @@ -49,59 +50,70 @@ fn mock_rq_storage(num_bits: u8) -> RabitQuantizationStorage { } fn construct_dist_table(c: &mut Criterion) { + let rotation_types = [RQRotationType::Fast, RQRotationType::Matrix]; for num_bits in 1..=1 { - let rq = mock_rq_storage(num_bits); - let query = rand_type(&DataType::Float32) - .generate_default(RowCount::from(DIM as u64)) - .unwrap(); - c.bench_function( - format!( - "RQ{}: construct_dist_table: {},DIM={}", - num_bits, - DistanceType::L2, - DIM - ) - .as_str(), - |b| { - b.iter(|| { - black_box(rq.dist_calculator(query.clone(), 0.0)); - }) - }, - ); + for rotation_type in rotation_types { + let rq = mock_rq_storage(num_bits, rotation_type); + let query = rand_type(&DataType::Float32) + 
.generate_default(RowCount::from(DIM as u64)) + .unwrap(); + c.bench_function( + format!( + "RQ{}({:?}): construct_dist_table: {},DIM={}", + num_bits, + rotation_type, + DistanceType::L2, + DIM + ) + .as_str(), + |b| { + b.iter(|| { + black_box(rq.dist_calculator(query.clone(), 0.0)); + }) + }, + ); + } } } fn compute_distances(c: &mut Criterion) { + let rotation_types = [RQRotationType::Fast, RQRotationType::Matrix]; for num_bits in 1..=1 { - let rq = mock_rq_storage(num_bits); - let query = rand_type(&DataType::Float32) - .generate_default(RowCount::from(DIM as u64)) - .unwrap(); - let dist_calc = rq.dist_calculator(query.clone(), 0.0); + for rotation_type in rotation_types { + let rq = mock_rq_storage(num_bits, rotation_type); + let query = rand_type(&DataType::Float32) + .generate_default(RowCount::from(DIM as u64)) + .unwrap(); + let dist_calc = rq.dist_calculator(query.clone(), 0.0); - c.bench_function( - format!("RQ{}: compute_distances: {},DIM={}", num_bits, TOTAL, DIM).as_str(), - |b| { - b.iter(|| { - black_box(dist_calc.distance_all(0)); - }) - }, - ); + c.bench_function( + format!( + "RQ{}({:?}): compute_distances: {},DIM={}", + num_bits, rotation_type, TOTAL, DIM + ) + .as_str(), + |b| { + b.iter(|| { + black_box(dist_calc.distance_all(0)); + }) + }, + ); - c.bench_function( - format!( - "RQ{}: compute_distances_single: {},DIM={}", - num_bits, TOTAL, DIM - ) - .as_str(), - |b| { - b.iter(|| { - for i in 0..TOTAL { - black_box(dist_calc.distance(i as u32)); - } - }) - }, - ); + c.bench_function( + format!( + "RQ{}({:?}): compute_distances_single: {},DIM={}", + num_bits, rotation_type, TOTAL, DIM + ) + .as_str(), + |b| { + b.iter(|| { + for i in 0..TOTAL { + black_box(dist_calc.distance(i as u32)); + } + }) + }, + ); + } } } diff --git a/rust/lance-index/src/vector/bq.rs b/rust/lance-index/src/vector/bq.rs index b36003fddf9..1210847c30e 100644 --- a/rust/lance-index/src/vector/bq.rs +++ b/rust/lance-index/src/vector/bq.rs @@ -4,17 +4,20 @@ //! 
Binary Quantization (BQ) use std::iter::once; +use std::str::FromStr; use std::sync::Arc; use arrow_array::types::Float32Type; use arrow_array::{cast::AsArray, Array, ArrayRef, UInt8Array}; use lance_core::{Error, Result}; use num_traits::Float; +use serde::{Deserialize, Serialize}; use snafu::location; use crate::vector::quantizer::QuantizerBuildParams; pub mod builder; +pub mod rotation; pub mod storage; pub mod transform; @@ -80,14 +83,56 @@ fn binary_quantization(data: &[T]) -> impl Iterator + '_ { })) } +#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum RQRotationType { + Fast, + Matrix, +} + +impl Default for RQRotationType { + fn default() -> Self { + Self::Fast + } +} + +impl FromStr for RQRotationType { + type Err = Error; + + fn from_str(value: &str) -> std::result::Result { + match value.to_lowercase().as_str() { + "fast" | "fht_kac" | "fht-kac" => Ok(Self::Fast), + "matrix" | "dense" => Ok(Self::Matrix), + _ => Err(Error::invalid_input( + format!( + "Unknown RQ rotation type: {}. 
Expected one of: fast, matrix", + value + ), + location!(), + )), + } + } +} + #[derive(Clone, Debug, PartialEq, Eq)] pub struct RQBuildParams { pub num_bits: u8, + pub rotation_type: RQRotationType, } impl RQBuildParams { pub fn new(num_bits: u8) -> Self { - Self { num_bits } + Self { + num_bits, + rotation_type: RQRotationType::default(), + } + } + + pub fn with_rotation_type(num_bits: u8, rotation_type: RQRotationType) -> Self { + Self { + num_bits, + rotation_type, + } } } @@ -99,7 +144,10 @@ impl QuantizerBuildParams for RQBuildParams { impl Default for RQBuildParams { fn default() -> Self { - Self { num_bits: 1 } + Self { + num_bits: 1, + rotation_type: RQRotationType::default(), + } } } @@ -126,4 +174,17 @@ mod tests { test_bq::(); test_bq::(); } + + #[test] + fn test_rotation_type_parse() { + assert_eq!( + "fast".parse::().unwrap(), + RQRotationType::Fast + ); + assert_eq!( + "matrix".parse::().unwrap(), + RQRotationType::Matrix + ); + assert!("invalid".parse::().is_err()); + } } diff --git a/rust/lance-index/src/vector/bq/builder.rs b/rust/lance-index/src/vector/bq/builder.rs index 47a40c55801..11cde4ba3e9 100644 --- a/rust/lance-index/src/vector/bq/builder.rs +++ b/rust/lance-index/src/vector/bq/builder.rs @@ -20,7 +20,10 @@ use crate::vector::bq::storage::{ RabitQuantizationMetadata, RabitQuantizationStorage, RABIT_CODE_COLUMN, RABIT_METADATA_KEY, }; use crate::vector::bq::transform::{ADD_FACTORS_FIELD, SCALE_FACTORS_FIELD}; -use crate::vector::bq::RQBuildParams; +use crate::vector::bq::{ + rotation::{apply_fast_rotation, random_fast_rotation_signs}, + RQBuildParams, RQRotationType, +}; use crate::vector::quantizer::{Quantization, Quantizer, QuantizerBuildParams}; /// Build parameters for RabitQuantizer. @@ -28,11 +31,15 @@ use crate::vector::quantizer::{Quantization, Quantizer, QuantizerBuildParams}; /// num_bits: the number of bits per dimension. 
pub struct RabitBuildParams { pub num_bits: u8, + pub rotation_type: RQRotationType, } impl Default for RabitBuildParams { fn default() -> Self { - Self { num_bits: 1 } + Self { + num_bits: 1, + rotation_type: RQRotationType::default(), + } } } @@ -50,25 +57,47 @@ pub struct RabitQuantizer { impl RabitQuantizer { pub fn new(num_bits: u8, dim: i32) -> Self { - // we don't need to calculate the inverse of P, - // just take the generated matrix as P^{-1} - let code_dim = dim * num_bits as i32; - let rotate_mat = random_orthogonal::(code_dim as usize); - let (rotate_mat, _) = rotate_mat.into_raw_vec_and_offset(); - - let rotate_mat = match T::FLOAT_TYPE { - FloatType::Float16 | FloatType::Float32 | FloatType::Float64 => { - let rotate_mat = >::from_values(rotate_mat); - FixedSizeListArray::try_new_from_values(rotate_mat, code_dim).unwrap() - } - _ => unimplemented!("RabitQ does not support data type: {:?}", T::FLOAT_TYPE), - }; + Self::new_with_rotation::(num_bits, dim, RQRotationType::default()) + } - let metadata = RabitQuantizationMetadata { - rotate_mat: Some(rotate_mat), - rotate_mat_position: 0, - num_bits, - packed: false, + pub fn new_with_rotation( + num_bits: u8, + dim: i32, + rotation_type: RQRotationType, + ) -> Self { + let code_dim = (dim * num_bits as i32) as usize; + let metadata = match rotation_type { + RQRotationType::Matrix => { + // we don't need to calculate the inverse of P, just take generated Q as P^{-1} + let rotate_mat = random_orthogonal::(code_dim); + let (rotate_mat, _) = rotate_mat.into_raw_vec_and_offset(); + let rotate_mat = match T::FLOAT_TYPE { + FloatType::Float16 | FloatType::Float32 | FloatType::Float64 => { + let rotate_mat = >::from_values(rotate_mat); + FixedSizeListArray::try_new_from_values(rotate_mat, code_dim as i32) + .unwrap() + } + _ => unimplemented!("RabitQ does not support data type: {:?}", T::FLOAT_TYPE), + }; + RabitQuantizationMetadata { + rotate_mat: Some(rotate_mat), + rotate_mat_position: None, + 
fast_rotation_signs: None, + rotation_type, + code_dim: code_dim as u32, + num_bits, + packed: false, + } + } + RQRotationType::Fast => RabitQuantizationMetadata { + rotate_mat: None, + rotate_mat_position: None, + fast_rotation_signs: Some(random_fast_rotation_signs(code_dim)), + rotation_type, + code_dim: code_dim as u32, + num_bits, + packed: false, + }, }; Self { metadata } } @@ -77,6 +106,10 @@ impl RabitQuantizer { self.metadata.num_bits } + pub fn rotation_type(&self) -> RQRotationType { + self.metadata.rotation_type + } + #[inline] fn rotate_mat_flat(&self) -> &[T::Native] { let rotate_mat = self.metadata.rotate_mat.as_ref().unwrap(); @@ -94,6 +127,43 @@ impl RabitQuantizer { ndarray::ArrayView2::from_shape((code_dim, code_dim), self.rotate_mat_flat::()).unwrap() } + fn rotate_vectors( + &self, + vectors: ndarray::ArrayView2<'_, T::Native>, + ) -> ndarray::Array2 + where + T::Native: AsPrimitive, + { + let dim = vectors.nrows(); + let code_dim = self.code_dim(); + match self.rotation_type() { + RQRotationType::Matrix => { + let rotate_mat = self.rotate_mat::(); + let rotate_mat = rotate_mat.slice(s![.., 0..dim]); + rotate_mat.dot(&vectors).mapv(|v| v.as_()) + } + RQRotationType::Fast => { + let signs = self + .metadata + .fast_rotation_signs + .as_ref() + .expect("RabitQ fast rotation signs missing"); + let mut rotated = ndarray::Array2::::zeros((code_dim, vectors.ncols())); + let mut scratch = vec![0.0f32; code_dim]; + for (col_idx, vector) in vectors.axis_iter(Axis(1)).enumerate() { + let input = vector + .as_slice() + .expect("RabitQ input vectors should be contiguous"); + apply_fast_rotation(input, &mut scratch, signs); + for (row_idx, value) in scratch.iter().enumerate() { + rotated[[row_idx, col_idx]] = *value; + } + } + rotated + } + } + } + pub fn dim(&self) -> usize { self.code_dim() / self.metadata.num_bits as usize } @@ -131,12 +201,9 @@ impl RabitQuantizer { .map_err(|e| Error::invalid_input(e.to_string(), location!()))?; let vec_mat = 
vec_mat.t(); - let rotate_mat = self.rotate_mat::(); - // slice to (code_dim, dim) - let rotate_mat = rotate_mat.slice(s![.., 0..dim]); - let rotated_vectors = rotate_mat.dot(&vec_mat); + let rotated_vectors = self.rotate_vectors::(vec_mat); let sqrt_dim = (dim as f32 * self.metadata.num_bits as f32).sqrt(); - let norm_dists = rotated_vectors.mapv(|v| v.as_().abs()).sum_axis(Axis(0)) / sqrt_dim; + let norm_dists = rotated_vectors.mapv(f32::abs).sum_axis(Axis(0)) / sqrt_dim; debug_assert_eq!(norm_dists.len(), residual_vectors.len()); Ok(norm_dists.to_vec()) } @@ -165,11 +232,9 @@ impl RabitQuantizer { ) .map_err(|e| Error::invalid_input(e.to_string(), location!()))?; let vectors = vectors.t(); - let rotate_mat = self.rotate_mat::(); - let rotate_mat = rotate_mat.slice(s![.., 0..dim]); - let rotated_vectors = rotate_mat.dot(&vectors); + let rotated_vectors = self.rotate_vectors::(vectors); - let quantized_vectors = rotated_vectors.t().mapv(|v| v.as_().is_sign_positive()); + let quantized_vectors = rotated_vectors.t().mapv(|v| v.is_sign_positive()); let bv: BitVec = BitVec::from_iter(quantized_vectors); let codes = UInt8Array::from(bv.into_vec()); @@ -192,15 +257,21 @@ impl Quantization for RabitQuantizer { params: &Self::BuildParams, ) -> Result { let q = match data.as_fixed_size_list().value_type() { - DataType::Float16 => { - Self::new::(params.num_bits, data.as_fixed_size_list().value_length()) - } - DataType::Float32 => { - Self::new::(params.num_bits, data.as_fixed_size_list().value_length()) - } - DataType::Float64 => { - Self::new::(params.num_bits, data.as_fixed_size_list().value_length()) - } + DataType::Float16 => Self::new_with_rotation::( + params.num_bits, + data.as_fixed_size_list().value_length(), + params.rotation_type, + ), + DataType::Float32 => Self::new_with_rotation::( + params.num_bits, + data.as_fixed_size_list().value_length(), + params.rotation_type, + ), + DataType::Float64 => Self::new_with_rotation::( + params.num_bits, + 
data.as_fixed_size_list().value_length(), + params.rotation_type, + ), dt => { return Err(Error::invalid_input( format!("Unsupported data type: {:?}", dt), @@ -216,11 +287,15 @@ impl Quantization for RabitQuantizer { } fn code_dim(&self) -> usize { - self.metadata - .rotate_mat - .as_ref() - .map(|inv_p| inv_p.len()) - .unwrap_or(0) + if self.metadata.code_dim > 0 { + self.metadata.code_dim as usize + } else { + self.metadata + .rotate_mat + .as_ref() + .map(|rotate_mat| rotate_mat.len()) + .unwrap_or(0) + } } fn column(&self) -> &'static str { @@ -370,6 +445,7 @@ where mod tests { use super::*; use approx::assert_relative_eq; + use arrow::datatypes::Float32Type; use rstest::rstest; #[rstest] @@ -410,4 +486,16 @@ mod tests { assert_eq!(q.dim(), (m, m)); assert_eq!(r.dim(), (m, n)); } + + #[test] + fn test_rabit_quantizer_rotation_modes() { + let fast_q = RabitQuantizer::new_with_rotation::(1, 128, RQRotationType::Fast); + assert_eq!(fast_q.rotation_type(), RQRotationType::Fast); + assert_eq!(fast_q.dim(), 128); + + let matrix_q = + RabitQuantizer::new_with_rotation::(1, 128, RQRotationType::Matrix); + assert_eq!(matrix_q.rotation_type(), RQRotationType::Matrix); + assert_eq!(matrix_q.dim(), 128); + } } diff --git a/rust/lance-index/src/vector/bq/rotation.rs b/rust/lance-index/src/vector/bq/rotation.rs new file mode 100644 index 00000000000..21161a43524 --- /dev/null +++ b/rust/lance-index/src/vector/bq/rotation.rs @@ -0,0 +1,130 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use num_traits::AsPrimitive; +use rand::RngCore; + +const FAST_ROTATION_ROUNDS: usize = 4; + +#[inline] +fn fwht_in_place(values: &mut [f32]) { + debug_assert!(values.len().is_power_of_two()); + let mut half = 1usize; + while half < values.len() { + let step = half * 2; + for i in (0..values.len()).step_by(step) { + for j in i..(i + half) { + let x = values[j]; + let y = values[j + half]; + values[j] = x + y; + values[j + half] = x - y; + } 
+ } + half = step; + } +} + +#[inline] +fn flip_signs(values: &mut [f32], signs: &[u8]) { + debug_assert!(signs.len() * 8 >= values.len()); + for (idx, value) in values.iter_mut().enumerate() { + if (signs[idx / 8] >> (idx % 8)) & 1 == 1 { + *value = -*value; + } + } +} + +#[inline] +fn kacs_walk(values: &mut [f32]) { + let half = values.len() / 2; + for i in 0..half { + let x = values[i]; + let y = values[i + half]; + values[i] = x + y; + values[i + half] = x - y; + } +} + +#[inline] +fn rescale(values: &mut [f32], factor: f32) { + values.iter_mut().for_each(|v| *v *= factor); +} + +#[inline] +fn sign_bytes_per_round(dim: usize) -> usize { + dim.div_ceil(8) +} + +pub fn random_fast_rotation_signs(dim: usize) -> Vec { + let mut signs = vec![0u8; FAST_ROTATION_ROUNDS * sign_bytes_per_round(dim)]; + rand::rng().fill_bytes(&mut signs); + signs +} + +pub fn apply_fast_rotation>(input: &[T], output: &mut [f32], signs: &[u8]) { + let dim = output.len(); + let bytes_per_round = sign_bytes_per_round(dim); + debug_assert_eq!(signs.len(), FAST_ROTATION_ROUNDS * bytes_per_round); + output.fill(0.0); + output + .iter_mut() + .zip(input.iter()) + .for_each(|(dst, src)| *dst = src.as_()); + + if dim == 0 { + return; + } + + let trunc_dim = 1usize << dim.ilog2(); + let scale = 1.0f32 / (trunc_dim as f32).sqrt(); + if trunc_dim == dim { + for round in 0..FAST_ROTATION_ROUNDS { + let offset = round * bytes_per_round; + flip_signs(output, &signs[offset..offset + bytes_per_round]); + fwht_in_place(output); + rescale(output, scale); + } + return; + } + + let start = dim - trunc_dim; + for round in 0..FAST_ROTATION_ROUNDS { + let offset = round * bytes_per_round; + flip_signs(output, &signs[offset..offset + bytes_per_round]); + + if round % 2 == 0 { + let head = &mut output[..trunc_dim]; + fwht_in_place(head); + rescale(head, scale); + } else { + let tail = &mut output[start..]; + fwht_in_place(tail); + rescale(tail, scale); + } + + kacs_walk(output); + } + + // Matches RaBitQ-Library 
FhtKacRotator behavior for non-power-of-two dimensions. + rescale(output, 0.25); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_fast_rotation_sign_bytes() { + assert_eq!(random_fast_rotation_signs(128).len(), 64); + assert_eq!(random_fast_rotation_signs(130).len(), 68); + } + + #[test] + fn test_fast_rotation_preserves_shape() { + let input = vec![1.0f32; 129]; + let mut output = vec![0.0f32; 129]; + let signs = random_fast_rotation_signs(129); + apply_fast_rotation(&input, &mut output, &signs); + assert_eq!(output.len(), 129); + } +} diff --git a/rust/lance-index/src/vector/bq/storage.rs b/rust/lance-index/src/vector/bq/storage.rs index cead98a72f6..f94f2135295 100644 --- a/rust/lance-index/src/vector/bq/storage.rs +++ b/rust/lance-index/src/vector/bq/storage.rs @@ -28,7 +28,9 @@ use snafu::location; use crate::frag_reuse::FragReuseIndex; use crate::pb; +use crate::vector::bq::rotation::apply_fast_rotation; use crate::vector::bq::transform::{ADD_FACTORS_COLUMN, SCALE_FACTORS_COLUMN}; +use crate::vector::bq::RQRotationType; use crate::vector::pq::storage::transpose; use crate::vector::quantizer::{QuantizerMetadata, QuantizerStorage}; use crate::vector::storage::{DistCalculator, VectorStore}; @@ -45,45 +47,80 @@ pub struct RabitQuantizationMetadata { // in the global buffer, which is a binary format (protobuf for now) for efficiency. #[serde(skip)] pub rotate_mat: Option, - pub rotate_mat_position: u32, + #[serde(default)] + pub rotate_mat_position: Option, + #[serde(default)] + pub fast_rotation_signs: Option>, + #[serde(default = "default_rotation_type_compat")] + pub rotation_type: RQRotationType, + #[serde(default)] + pub code_dim: u32, pub num_bits: u8, pub packed: bool, } +fn default_rotation_type_compat() -> RQRotationType { + // Older metadata does not have this field and always used dense matrices. 
+ RQRotationType::Matrix +} + impl DeepSizeOf for RabitQuantizationMetadata { fn deep_size_of_children(&self, _context: &mut deepsize::Context) -> usize { self.rotate_mat .as_ref() .map(|inv_p| inv_p.get_array_memory_size()) .unwrap_or(0) + + self + .fast_rotation_signs + .as_ref() + .map(|signs| signs.len()) + .unwrap_or(0) } } #[async_trait] impl QuantizerMetadata for RabitQuantizationMetadata { fn buffer_index(&self) -> Option { - Some(self.rotate_mat_position) + match self.rotation_type { + RQRotationType::Matrix => self.rotate_mat_position, + RQRotationType::Fast => None, + } } fn set_buffer_index(&mut self, index: u32) { - self.rotate_mat_position = index; + self.rotate_mat_position = Some(index); } fn parse_buffer(&mut self, bytes: Bytes) -> Result<()> { + if self.rotation_type != RQRotationType::Matrix { + return Ok(()); + } debug_assert!(!bytes.is_empty()); let codebook_tensor: pb::Tensor = pb::Tensor::decode(bytes)?; self.rotate_mat = Some(FixedSizeListArray::try_from(&codebook_tensor)?); + if self.code_dim == 0 { + self.code_dim = self + .rotate_mat + .as_ref() + .map(|rotate_mat| rotate_mat.len() as u32) + .unwrap_or(0); + } Ok(()) } fn extra_metadata(&self) -> Result> { - if let Some(inv_p) = &self.rotate_mat { - let inv_p_tensor = pb::Tensor::try_from(inv_p)?; - let mut bytes = BytesMut::new(); - inv_p_tensor.encode(&mut bytes)?; - Ok(Some(bytes.freeze())) - } else { - Ok(None) + match self.rotation_type { + RQRotationType::Matrix => { + if let Some(inv_p) = &self.rotate_mat { + let inv_p_tensor = pb::Tensor::try_from(inv_p)?; + let mut bytes = BytesMut::new(); + inv_p_tensor.encode(&mut bytes)?; + Ok(Some(bytes.freeze())) + } else { + Ok(None) + } + } + RQRotationType::Fast => Ok(None), } } @@ -127,7 +164,7 @@ impl DeepSizeOf for RabitQuantizationStorage { } impl RabitQuantizationStorage { - fn rotate_query_vector( + fn rotate_query_vector_dense( rotate_mat: &FixedSizeListArray, qr: &dyn Array, ) -> Vec @@ -154,6 +191,25 @@ impl 
RabitQuantizationStorage { .map(|chunk| lance_linalg::distance::dot(&chunk[..d], qr)) .collect() } + + fn rotate_query_vector_fast( + code_dim: usize, + signs: &[u8], + qr: &dyn Array, + ) -> Vec + where + T::Native: AsPrimitive, + { + let qr = qr + .as_any() + .downcast_ref::() + .unwrap() + .as_slice(); + + let mut output = vec![0.0f32; code_dim]; + apply_fast_rotation(qr, &mut output, signs); + output + } } pub struct RabitDistCalculator<'a> { @@ -408,17 +464,56 @@ impl VectorStore for RabitQuantizationStorage { #[inline(never)] fn dist_calculator(&self, qr: Arc, dist_q_c: f32) -> Self::DistanceCalculator<'_> { let codes = self.codes.values().as_primitive::().values(); - let rotate_mat = self - .metadata - .rotate_mat - .as_ref() - .expect("RabitQ metadata not loaded"); + let code_dim = if self.metadata.code_dim > 0 { + self.metadata.code_dim as usize + } else { + self.metadata + .rotate_mat + .as_ref() + .map(|rotate_mat| rotate_mat.len()) + .unwrap_or_default() + }; - let rotated_qr = match rotate_mat.value_type() { - DataType::Float16 => Self::rotate_query_vector::(rotate_mat, &qr), - DataType::Float32 => Self::rotate_query_vector::(rotate_mat, &qr), - DataType::Float64 => Self::rotate_query_vector::(rotate_mat, &qr), - dt => unimplemented!("RabitQ does not support data type: {}", dt), + let rotated_qr = match self.metadata.rotation_type { + RQRotationType::Matrix => { + let rotate_mat = self + .metadata + .rotate_mat + .as_ref() + .expect("RabitQ dense rotation metadata not loaded"); + + match rotate_mat.value_type() { + DataType::Float16 => { + Self::rotate_query_vector_dense::(rotate_mat, &qr) + } + DataType::Float32 => { + Self::rotate_query_vector_dense::(rotate_mat, &qr) + } + DataType::Float64 => { + Self::rotate_query_vector_dense::(rotate_mat, &qr) + } + dt => unimplemented!("RabitQ does not support data type: {}", dt), + } + } + RQRotationType::Fast => { + let signs = self + .metadata + .fast_rotation_signs + .as_ref() + .expect("RabitQ fast 
rotation metadata not loaded"); + match qr.data_type() { + DataType::Float16 => { + Self::rotate_query_vector_fast::(code_dim, signs, &qr) + } + DataType::Float32 => { + Self::rotate_query_vector_fast::(code_dim, signs, &qr) + } + DataType::Float64 => { + Self::rotate_query_vector_fast::(code_dim, signs, &qr) + } + dt => unimplemented!("RabitQ does not support data type: {}", dt), + } + } }; let dist_table = build_dist_table_direct::(&rotated_qr); diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index 1dd015f789b..7296aeafc7a 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -24,7 +24,7 @@ use lance_index::frag_reuse::FragReuseIndex; use lance_index::metrics::NoOpMetricsCollector; use lance_index::optimize::OptimizeOptions; use lance_index::vector::bq::builder::RabitQuantizer; -use lance_index::vector::bq::RQBuildParams; +use lance_index::vector::bq::{RQBuildParams, RQRotationType}; use lance_index::vector::flat::index::{FlatBinQuantizer, FlatIndex, FlatQuantizer}; use lance_index::vector::hnsw::HNSW; use lance_index::vector::ivf::builder::recommended_num_partitions; @@ -168,8 +168,22 @@ impl VectorIndexParams { } pub fn ivf_rq(num_partitions: usize, num_bits: u8, distance_type: DistanceType) -> Self { + Self::ivf_rq_with_rotation( + num_partitions, + num_bits, + distance_type, + RQRotationType::default(), + ) + } + + pub fn ivf_rq_with_rotation( + num_partitions: usize, + num_bits: u8, + distance_type: DistanceType, + rotation_type: RQRotationType, + ) -> Self { let ivf = IvfBuildParams::new(num_partitions); - let rq = RQBuildParams { num_bits }; + let rq = RQBuildParams::with_rotation_type(num_bits, rotation_type); let stages = vec![StageParams::Ivf(ivf), StageParams::RQ(rq)]; Self { stages, @@ -1639,6 +1653,7 @@ fn derive_sq_params(sq_quantizer: &ScalarQuantizer) -> SQBuildParams { fn derive_rabit_params(rabit_quantizer: &RabitQuantizer) -> RQBuildParams { RQBuildParams { num_bits: 
rabit_quantizer.num_bits(), + rotation_type: rabit_quantizer.rotation_type(), } } From 3df09d494ea5042462553a273ef3b836d25c8d3c Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Sat, 21 Feb 2026 15:02:18 +0800 Subject: [PATCH 2/7] perf(rq): optimize fast rotation path --- rust/lance-index/src/vector/bq/builder.rs | 183 ++++++++++++++------- rust/lance-index/src/vector/bq/rotation.rs | 94 +++++++++-- 2 files changed, 200 insertions(+), 77 deletions(-) diff --git a/rust/lance-index/src/vector/bq/builder.rs b/rust/lance-index/src/vector/bq/builder.rs index 11cde4ba3e9..db36723f425 100644 --- a/rust/lance-index/src/vector/bq/builder.rs +++ b/rust/lance-index/src/vector/bq/builder.rs @@ -11,9 +11,10 @@ use bitvec::prelude::{BitVec, Lsb0}; use deepsize::DeepSizeOf; use lance_arrow::{ArrowFloatType, FixedSizeListArrayExt, FloatArray, FloatType}; use lance_core::{Error, Result}; -use ndarray::{s, Axis}; +use ndarray::{s, Axis, ShapeBuilder}; use num_traits::{AsPrimitive, FromPrimitive}; use rand_distr::Distribution; +use rayon::prelude::*; use snafu::location; use crate::vector::bq::storage::{ @@ -55,6 +56,16 @@ pub struct RabitQuantizer { metadata: RabitQuantizationMetadata, } +#[inline] +fn pack_sign_bits(codes: &mut [u8], rotated: &[f32]) { + codes.fill(0); + for (bit_idx, value) in rotated.iter().enumerate() { + if value.is_sign_positive() { + codes[bit_idx / u8::BITS as usize] |= 1u8 << (bit_idx % u8::BITS as usize); + } + } +} + impl RabitQuantizer { pub fn new(num_bits: u8, dim: i32) -> Self { Self::new_with_rotation::(num_bits, dim, RQRotationType::default()) @@ -110,6 +121,15 @@ impl RabitQuantizer { self.metadata.rotation_type } + #[inline] + fn fast_rotation_signs(&self) -> &[u8] { + self.metadata + .fast_rotation_signs + .as_ref() + .expect("RabitQ fast rotation signs missing") + .as_slice() + } + #[inline] fn rotate_mat_flat(&self) -> &[T::Native] { let rotate_mat = self.metadata.rotate_mat.as_ref().unwrap(); @@ -143,23 +163,25 @@ impl RabitQuantizer { 
rotate_mat.dot(&vectors).mapv(|v| v.as_()) } RQRotationType::Fast => { - let signs = self - .metadata - .fast_rotation_signs - .as_ref() - .expect("RabitQ fast rotation signs missing"); - let mut rotated = ndarray::Array2::::zeros((code_dim, vectors.ncols())); - let mut scratch = vec![0.0f32; code_dim]; - for (col_idx, vector) in vectors.axis_iter(Axis(1)).enumerate() { - let input = vector - .as_slice() - .expect("RabitQ input vectors should be contiguous"); - apply_fast_rotation(input, &mut scratch, signs); - for (row_idx, value) in scratch.iter().enumerate() { - rotated[[row_idx, col_idx]] = *value; - } - } - rotated + let signs = self.fast_rotation_signs(); + let ncols = vectors.ncols(); + let mut rotated_data = vec![0.0f32; code_dim * ncols]; + rotated_data + .par_chunks_mut(code_dim) + .enumerate() + .for_each_init( + || vec![0.0f32; code_dim], + |scratch, (col_idx, dst)| { + let column = vectors.column(col_idx); + let input = column + .as_slice() + .expect("RabitQ input vectors should be contiguous"); + apply_fast_rotation(input, scratch, signs); + dst.copy_from_slice(scratch); + }, + ); + + ndarray::Array2::from_shape_vec((code_dim, ncols).f(), rotated_data).unwrap() } } } @@ -174,7 +196,7 @@ impl RabitQuantizer { residual_vectors: &FixedSizeListArray, ) -> Result> where - T::Native: AsPrimitive, + T::Native: AsPrimitive + Sync, { let dim = self.dim(); if residual_vectors.value_length() as usize != dim { @@ -188,24 +210,43 @@ impl RabitQuantizer { )); } - // convert the vector to a dxN matrix - let vec_mat = ndarray::ArrayView2::from_shape( - (residual_vectors.len(), dim), - residual_vectors - .values() - .as_any() - .downcast_ref::() - .unwrap() - .as_slice(), - ) - .map_err(|e| Error::invalid_input(e.to_string(), location!()))?; - let vec_mat = vec_mat.t(); - - let rotated_vectors = self.rotate_vectors::(vec_mat); let sqrt_dim = (dim as f32 * self.metadata.num_bits as f32).sqrt(); - let norm_dists = rotated_vectors.mapv(f32::abs).sum_axis(Axis(0)) / 
sqrt_dim; - debug_assert_eq!(norm_dists.len(), residual_vectors.len()); - Ok(norm_dists.to_vec()) + let values = residual_vectors + .values() + .as_any() + .downcast_ref::() + .unwrap() + .as_slice(); + + match self.rotation_type() { + RQRotationType::Matrix => { + // convert the vector to a dxN matrix + let vec_mat = + ndarray::ArrayView2::from_shape((residual_vectors.len(), dim), values) + .map_err(|e| Error::invalid_input(e.to_string(), location!()))?; + let vec_mat = vec_mat.t(); + let rotated_vectors = self.rotate_vectors::(vec_mat); + let norm_dists = rotated_vectors.mapv(f32::abs).sum_axis(Axis(0)) / sqrt_dim; + debug_assert_eq!(norm_dists.len(), residual_vectors.len()); + Ok(norm_dists.to_vec()) + } + RQRotationType::Fast => { + let code_dim = self.code_dim(); + let signs = self.fast_rotation_signs(); + let mut norm_dists = vec![0.0f32; residual_vectors.len()]; + norm_dists + .par_iter_mut() + .zip(values.par_chunks_exact(dim)) + .for_each_init( + || vec![0.0f32; code_dim], + |scratch, (dst, input)| { + apply_fast_rotation(input, scratch, signs); + *dst = scratch.iter().map(|v| v.abs()).sum::() / sqrt_dim; + }, + ); + Ok(norm_dists) + } + } } fn transform( @@ -213,36 +254,60 @@ impl RabitQuantizer { residual_vectors: &FixedSizeListArray, ) -> Result where - T::Native: AsPrimitive, + T::Native: AsPrimitive + Sync, { // we don't need to normalize the residual vectors, // because the sign of P^{-1} * v_r is the same as P^{-1} * v_r / ||v_r|| let n = residual_vectors.len(); let dim = self.dim(); debug_assert_eq!(residual_vectors.values().len(), n * dim); + let values = residual_vectors + .values() + .as_any() + .downcast_ref::() + .unwrap() + .as_slice(); + let code_dim = self.code_dim(); + let code_bytes = code_dim / u8::BITS as usize; - let vectors = ndarray::ArrayView2::from_shape( - (n, dim), - residual_vectors - .values() - .as_any() - .downcast_ref::() - .unwrap() - .as_slice(), - ) - .map_err(|e| Error::invalid_input(e.to_string(), location!()))?; - let 
vectors = vectors.t(); - let rotated_vectors = self.rotate_vectors::(vectors); - - let quantized_vectors = rotated_vectors.t().mapv(|v| v.is_sign_positive()); - let bv: BitVec = BitVec::from_iter(quantized_vectors); - - let codes = UInt8Array::from(bv.into_vec()); - debug_assert_eq!(codes.len(), n * self.code_dim() / u8::BITS as usize); - Ok(Arc::new(FixedSizeListArray::try_new_from_values( - codes, - self.code_dim() as i32 / u8::BITS as i32, // num_bits -> num_bytes - )?)) + match self.rotation_type() { + RQRotationType::Matrix => { + let vectors = ndarray::ArrayView2::from_shape((n, dim), values) + .map_err(|e| Error::invalid_input(e.to_string(), location!()))?; + let vectors = vectors.t(); + let rotated_vectors = self.rotate_vectors::(vectors); + + let quantized_vectors = rotated_vectors.t().mapv(|v| v.is_sign_positive()); + let bv: BitVec = BitVec::from_iter(quantized_vectors); + + let codes = UInt8Array::from(bv.into_vec()); + debug_assert_eq!(codes.len(), n * code_bytes); + Ok(Arc::new(FixedSizeListArray::try_new_from_values( + codes, + code_bytes as i32, // num_bits -> num_bytes + )?)) + } + RQRotationType::Fast => { + let signs = self.fast_rotation_signs(); + let mut encoded_codes = vec![0u8; n * code_bytes]; + encoded_codes + .par_chunks_mut(code_bytes) + .zip(values.par_chunks_exact(dim)) + .for_each_init( + || vec![0.0f32; code_dim], + |scratch, (code_dst, input)| { + apply_fast_rotation(input, scratch, signs); + pack_sign_bits(code_dst, scratch); + }, + ); + let codes = UInt8Array::from(encoded_codes); + debug_assert_eq!(codes.len(), n * code_bytes); + Ok(Arc::new(FixedSizeListArray::try_new_from_values( + codes, + code_bytes as i32, + )?)) + } + } } } diff --git a/rust/lance-index/src/vector/bq/rotation.rs b/rust/lance-index/src/vector/bq/rotation.rs index 21161a43524..3726bf8b377 100644 --- a/rust/lance-index/src/vector/bq/rotation.rs +++ b/rust/lance-index/src/vector/bq/rotation.rs @@ -12,42 +12,97 @@ fn fwht_in_place(values: &mut [f32]) { let mut 
half = 1usize; while half < values.len() { let step = half * 2; - for i in (0..values.len()).step_by(step) { - for j in i..(i + half) { - let x = values[j]; - let y = values[j + half]; - values[j] = x + y; - values[j + half] = x - y; + for block in values.chunks_exact_mut(step) { + let (left, right) = block.split_at_mut(half); + for (x, y) in left.iter_mut().zip(right.iter_mut()) { + let lx = *x; + let ry = *y; + *x = lx + ry; + *y = lx - ry; } } half = step; } } +#[inline] +fn flip_signs_scalar(values: &mut [f32], signs: &[u8]) { + for (byte_idx, &mask) in signs.iter().enumerate() { + let start = byte_idx * 8; + if start >= values.len() { + break; + } + let end = (start + 8).min(values.len()); + for (bit_idx, value) in values[start..end].iter_mut().enumerate() { + let sign_mask = (((mask >> bit_idx) & 1) as u32) << 31; + *value = f32::from_bits(value.to_bits() ^ sign_mask); + } + } +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[target_feature(enable = "avx2")] +unsafe fn flip_signs_avx2(values: &mut [f32], signs: &[u8]) { + #[cfg(target_arch = "x86")] + use std::arch::x86::*; + #[cfg(target_arch = "x86_64")] + use std::arch::x86_64::*; + + let full_chunks = values.len() / 8; + let bit_select = _mm256_setr_epi32(0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80); + let sign_flip = _mm256_set1_epi32(0x80000000u32 as i32); + + for chunk_idx in 0..full_chunks { + let mask = signs[chunk_idx] as i32; + let mask_bits = _mm256_set1_epi32(mask); + let test = _mm256_and_si256(mask_bits, bit_select); + let cmp = _mm256_cmpeq_epi32(test, bit_select); + let xor_mask = _mm256_and_si256(cmp, sign_flip); + + let ptr = unsafe { values.as_mut_ptr().add(chunk_idx * 8) }; + let vec = unsafe { _mm256_loadu_ps(ptr) }; + let out = _mm256_xor_ps(vec, _mm256_castsi256_ps(xor_mask)); + unsafe { _mm256_storeu_ps(ptr, out) }; + } + + if full_chunks * 8 < values.len() { + flip_signs_scalar(&mut values[full_chunks * 8..], &signs[full_chunks..]); + } +} + #[inline] fn 
flip_signs(values: &mut [f32], signs: &[u8]) { debug_assert!(signs.len() * 8 >= values.len()); - for (idx, value) in values.iter_mut().enumerate() { - if (signs[idx / 8] >> (idx % 8)) & 1 == 1 { - *value = -*value; + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + if std::arch::is_x86_feature_detected!("avx2") { + // SAFETY: guarded by runtime feature detection. + unsafe { + flip_signs_avx2(values, signs); + } + return; } } + flip_signs_scalar(values, signs); } #[inline] fn kacs_walk(values: &mut [f32]) { let half = values.len() / 2; - for i in 0..half { - let x = values[i]; - let y = values[i + half]; - values[i] = x + y; - values[i + half] = x - y; + let (left, right) = values.split_at_mut(half); + for (x, y) in left.iter_mut().zip(right.iter_mut()) { + let lx = *x; + let ry = *y; + *x = lx + ry; + *y = lx - ry; } } #[inline] fn rescale(values: &mut [f32], factor: f32) { - values.iter_mut().for_each(|v| *v *= factor); + for value in values.iter_mut() { + *value *= factor; + } } #[inline] @@ -65,11 +120,14 @@ pub fn apply_fast_rotation>(input: &[T], output: &mut [f32], let dim = output.len(); let bytes_per_round = sign_bytes_per_round(dim); debug_assert_eq!(signs.len(), FAST_ROTATION_ROUNDS * bytes_per_round); - output.fill(0.0); - output + let input_len = input.len().min(dim); + output[..input_len] .iter_mut() - .zip(input.iter()) + .zip(input[..input_len].iter()) .for_each(|(dst, src)| *dst = src.as_()); + if input_len < dim { + output[input_len..].fill(0.0); + } if dim == 0 { return; From c998f84eeb18d9cc283562584bd86ebdfeafe779 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 26 Feb 2026 18:36:31 +0800 Subject: [PATCH 3/7] Add matrix IVF_RQ optimize test --- rust/lance/src/index/vector/ivf/v2.rs | 96 ++++++++++++++++++++++++++- 1 file changed, 94 insertions(+), 2 deletions(-) diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index b510c60e41b..08b2cfa4e80 100644 --- 
a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -625,7 +625,9 @@ mod tests { use arrow_schema::{DataType, Field, Schema, SchemaRef}; use itertools::Itertools; use lance_arrow::FixedSizeListArrayExt; - use lance_index::vector::bq::RQBuildParams; + use lance_index::vector::bq::{ + storage::RabitQuantizationMetadata, RQBuildParams, RQRotationType, + }; use lance_index::vector::storage::VectorStore; use crate::dataset::{InsertBuilder, UpdateBuilder, WriteMode, WriteParams}; @@ -735,6 +737,49 @@ mod tests { vectors } + async fn get_rq_metadata( + dataset: &Dataset, + scheduler: Arc, + index_uuid: &str, + ) -> RabitQuantizationMetadata { + let index_path = dataset + .indices_dir() + .child(index_uuid) + .child(INDEX_AUXILIARY_FILE_NAME); + let file_scheduler = scheduler + .open_file(&index_path, &CachedFileSize::unknown()) + .await + .unwrap(); + let reader = FileReader::try_open( + file_scheduler, + None, + Arc::::default(), + &LanceCache::no_cache(), + FileReaderOptions::default(), + ) + .await + .unwrap(); + let metadata = reader.schema().metadata.get(STORAGE_METADATA_KEY).unwrap(); + let metadata_entries: Vec = serde_json::from_str(metadata).unwrap(); + serde_json::from_str(&metadata_entries[0]).unwrap() + } + + async fn assert_rq_rotation_type(dataset: &Dataset, expected: RQRotationType) { + let obj_store = Arc::new(ObjectStore::local()); + let scheduler = ScanScheduler::new(obj_store, SchedulerConfig::default_for_testing()); + let indices = dataset.load_indices().await.unwrap(); + assert!(!indices.is_empty(), "Expected at least one vector index"); + for index in indices.iter() { + let rq_meta = + get_rq_metadata(dataset, scheduler.clone(), &index.uuid.to_string()).await; + assert_eq!( + rq_meta.rotation_type, expected, + "RQ rotation type mismatch for index {}", + index.uuid + ); + } + } + fn generate_batch( num_rows: usize, start_id: Option, @@ -2061,10 +2106,11 @@ mod tests { #[case] nlist: usize, #[case] distance_type: 
DistanceType, #[case] recall_requirement: f32, + #[values(RQRotationType::Fast, RQRotationType::Matrix)] rotation_type: RQRotationType, ) { let _ = env_logger::try_init(); let ivf_params = IvfBuildParams::new(nlist); - let rq_params = RQBuildParams::new(1); + let rq_params = RQBuildParams::with_rotation_type(1, rotation_type); let params = VectorIndexParams::with_ivf_rq_params(distance_type, ivf_params, rq_params); test_index(params.clone(), nlist, recall_requirement, None).await; if distance_type == DistanceType::Cosine { @@ -2073,6 +2119,52 @@ mod tests { test_remap(params.clone(), nlist, recall_requirement).await; } + #[rstest] + #[case::fast(RQRotationType::Fast)] + #[case::matrix(RQRotationType::Matrix)] + #[tokio::test] + async fn test_ivf_rq_rotation_type_after_optimize(#[case] rotation_type: RQRotationType) { + let test_dir = TempStrDir::default(); + let test_uri = test_dir.as_str(); + let (mut dataset, _) = generate_test_dataset::(test_uri, 0.0..1.0).await; + + let ivf_params = IvfBuildParams::new(4); + let rq_params = RQBuildParams::with_rotation_type(1, rotation_type); + let params = VectorIndexParams::with_ivf_rq_params(DistanceType::L2, ivf_params, rq_params); + dataset + .create_index(&["vector"], IndexType::Vector, None, &params, true) + .await + .unwrap(); + + assert_rq_rotation_type(&dataset, rotation_type).await; + + append_dataset::(&mut dataset, 64, 0.0..1.0).await; + dataset + .optimize_indices(&OptimizeOptions::append()) + .await + .unwrap(); + + let indices_after_append = dataset.load_indices().await.unwrap(); + assert_eq!( + indices_after_append.len(), + 2, + "Expected append optimize to create one delta index" + ); + assert_rq_rotation_type(&dataset, rotation_type).await; + + dataset + .optimize_indices(&OptimizeOptions::merge(10)) + .await + .unwrap(); + let indices_after_merge = dataset.load_indices().await.unwrap(); + assert_eq!( + indices_after_merge.len(), + 1, + "Expected merge optimize to merge indices into one" + ); + 
assert_rq_rotation_type(&dataset, rotation_type).await; + } + #[rstest] #[case(4, DistanceType::L2, 0.9)] #[case(4, DistanceType::Cosine, 0.9)] From e51d3ac687c19145557add0453d8b5eb1747c4e0 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 26 Feb 2026 19:05:38 +0800 Subject: [PATCH 4/7] Remove rotation_type parameters --- python/python/lance/dataset.py | 3 --- python/python/tests/test_vector_index.py | 20 -------------------- python/src/dataset.rs | 16 +--------------- 3 files changed, 1 insertion(+), 38 deletions(-) diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index a027d982274..4c8ce8ae40e 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -2866,9 +2866,6 @@ def create_index( - num_bits The number of bits for RQ (Rabit Quantization). Default is 1. - - rq_rotation_type / rabitq_rotation_type - Rotation implementation for RabitQ. One of ``"fast"`` (default) - or ``"matrix"``. Optional parameters for `IVF_HNSW_*`: max_level diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index cc27de999ba..71850135ae3 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -803,26 +803,6 @@ def test_create_ivf_rq_index(): ) assert ds.list_indices()[0]["fields"] == ["vector"] - ds = ds.create_index( - "vector", - index_type="IVF_RQ", - num_partitions=4, - num_bits=1, - rq_rotation_type="matrix", - replace=True, - ) - assert ds.list_indices()[0]["fields"] == ["vector"] - - with pytest.raises(ValueError, match="Unknown RQ rotation type"): - ds.create_index( - "vector", - index_type="IVF_RQ", - num_partitions=4, - num_bits=1, - rq_rotation_type="not-a-rotation", - replace=True, - ) - with pytest.raises( NotImplementedError, match="Creating empty vector indices with train=False is not yet implemented", diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 094f385f524..b4b4d8af050 100644 --- a/python/src/dataset.rs 
+++ b/python/src/dataset.rs @@ -17,7 +17,7 @@ use async_trait::async_trait; use blob::LanceBlobFile; use chrono::{Duration, TimeDelta, Utc}; use futures::{StreamExt, TryFutureExt}; -use lance_index::vector::bq::{RQBuildParams, RQRotationType}; +use lance_index::vector::bq::RQBuildParams; use log::error; use object_store::path::Path; use pyo3::exceptions::{PyStopIteration, PyTypeError}; @@ -3331,20 +3331,6 @@ fn prepare_vector_index_params( rq_params.num_bits = num_bits; }; - let rq_rotation_type = if let Some(rotation_type) = kwargs.get_item("rq_rotation_type")? { - Some(rotation_type.extract::()?) - } else if let Some(rotation_type) = kwargs.get_item("rabitq_rotation_type")? { - Some(rotation_type.extract::()?) - } else { - None - }; - if let Some(rotation_type) = rq_rotation_type { - rq_params.rotation_type = - rotation_type - .parse::() - .map_err(|e| PyValueError::new_err(format!("{}", e)))?; - } - if let Some(n) = kwargs.get_item("num_sub_vectors")? { pq_params.num_sub_vectors = n.extract()? }; From 86ae6b77a2a4ae29d1bfd1ba1845918492c4e355 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 26 Feb 2026 19:18:34 +0800 Subject: [PATCH 5/7] Document rotation algorithms --- rust/lance-index/src/vector/bq/rotation.rs | 35 ++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/rust/lance-index/src/vector/bq/rotation.rs b/rust/lance-index/src/vector/bq/rotation.rs index 3726bf8b377..6ac231ad7f3 100644 --- a/rust/lance-index/src/vector/bq/rotation.rs +++ b/rust/lance-index/src/vector/bq/rotation.rs @@ -4,10 +4,30 @@ use num_traits::AsPrimitive; use rand::RngCore; +// Fast random rotation used by the RabitQ "fast" path. +// +// The transform is a composition of: +// 1) random diagonal sign flips (Rademacher variables), +// 2) FWHT-style mixing on a power-of-two window, +// 3) a Kac-style pairwise mixing step for non-power-of-two dimensions. 
+// +// Background: +// - Hadamard transform: https://en.wikipedia.org/wiki/Hadamard_transform +// - Fast Walsh-Hadamard transform (FWHT): +// https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform +// - Rademacher random signs: +// https://en.wikipedia.org/wiki/Rademacher_distribution +// - Kac-walk-based fast dimension reduction (uses fixed-angle pair rotations): +// https://arxiv.org/abs/2003.10069 +// - Givens / plane rotation: +// https://en.wikipedia.org/wiki/Givens_rotation const FAST_ROTATION_ROUNDS: usize = 4; #[inline] fn fwht_in_place(values: &mut [f32]) { + // In-place FWHT butterfly network. + // For each stage, pair entries (x, y) and map to (x + y, x - y). + // Complexity: O(n log n) operations, no extra heap allocation. debug_assert!(values.len().is_power_of_two()); let mut half = 1usize; while half < values.len() { @@ -27,6 +47,8 @@ fn fwht_in_place(values: &mut [f32]) { #[inline] fn flip_signs_scalar(values: &mut [f32], signs: &[u8]) { + // Apply a random diagonal matrix with +/-1 entries by toggling the f32 sign bit. + // One bit in `signs` controls one element in `values`. for (byte_idx, &mask) in signs.iter().enumerate() { let start = byte_idx * 8; if start >= values.len() { @@ -48,6 +70,8 @@ unsafe fn flip_signs_avx2(values: &mut [f32], signs: &[u8]) { #[cfg(target_arch = "x86_64")] use std::arch::x86_64::*; + // Vectorized variant of `flip_signs_scalar`: consume 8 f32 values per AVX2 lane. + // The sign mask is expanded from one byte to 8 lane-wise sign-bit masks. let full_chunks = values.len() / 8; let bit_select = _mm256_setr_epi32(0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80); let sign_flip = _mm256_set1_epi32(0x80000000u32 as i32); @@ -88,6 +112,9 @@ fn flip_signs(values: &mut [f32], signs: &[u8]) { #[inline] fn kacs_walk(values: &mut [f32]) { + // A fixed-angle (pi/4) plane-rotation-like sweep over paired coordinates: + // (x, y) -> (x + y, x - y). 
Up to normalization, this is a 2x2 Hadamard block + // and corresponds to one Kac-style mixing step. let half = values.len() / 2; let (left, right) = values.split_at_mut(half); for (x, y) in left.iter_mut().zip(right.iter_mut()) { @@ -100,6 +127,7 @@ fn kacs_walk(values: &mut [f32]) { #[inline] fn rescale(values: &mut [f32], factor: f32) { + // Keep the transform numerically stable and approximately orthonormal. for value in values.iter_mut() { *value *= factor; } @@ -111,12 +139,18 @@ fn sign_bytes_per_round(dim: usize) -> usize { } pub fn random_fast_rotation_signs(dim: usize) -> Vec { + // Each round needs one random sign bit per dimension. let mut signs = vec![0u8; FAST_ROTATION_ROUNDS * sign_bytes_per_round(dim)]; rand::rng().fill_bytes(&mut signs); signs } pub fn apply_fast_rotation>(input: &[T], output: &mut [f32], signs: &[u8]) { + // Fast random rotation pipeline, aligned with RaBitQ-Library's FhtKacRotator: + // - power-of-two dims: repeat [random signs -> FWHT -> scale] for 4 rounds + // - non-power-of-two dims: alternate FWHT on head/tail + Kac mixing + // + // This keeps the fast path matrix-free: no dense orthogonal matrix materialization. let dim = output.len(); let bytes_per_round = sign_bytes_per_round(dim); debug_assert_eq!(signs.len(), FAST_ROTATION_ROUNDS * bytes_per_round); @@ -164,6 +198,7 @@ pub fn apply_fast_rotation>(input: &[T], output: &mut [f32], } // Matches RaBitQ-Library FhtKacRotator behavior for non-power-of-two dimensions. + // The extra factor compensates the alternating truncated FWHT + Kac steps above. 
rescale(output, 0.25); } From a2163cea96ec7cfaf29baa563458bb69587c7e4e Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 26 Feb 2026 19:28:32 +0800 Subject: [PATCH 6/7] fix(ci): resolve clippy and jni rq params failures --- java/lance-jni/src/utils.rs | 15 +++++++++++++++ rust/lance-index/src/vector/bq.rs | 9 ++------- rust/lance-index/src/vector/bq/rotation.rs | 4 ++-- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/java/lance-jni/src/utils.rs b/java/lance-jni/src/utils.rs index 56ff617c821..9a85bffdfea 100644 --- a/java/lance-jni/src/utils.rs +++ b/java/lance-jni/src/utils.rs @@ -13,6 +13,7 @@ use lance::dataset::{WriteMode, WriteParams}; use lance::index::vector::{IndexFileVersion, StageParams, VectorIndexParams}; use lance::io::ObjectStoreParams; use lance_encoding::version::LanceFileVersion; +use lance_index::vector::bq::RQBuildParams; use lance_index::vector::hnsw::builder::HnswBuildParams; use lance_index::vector::ivf::IvfBuildParams; use lance_index::vector::pq::PQBuildParams; @@ -399,6 +400,20 @@ pub fn get_vector_index_params( stages.push(StageParams::SQ(sq_params)); } + // Parse RQBuildParams + let rq_params = env.get_optional_from_method( + &vector_index_params_obj, + "getRqParams", + |env, rq_obj| { + let num_bits = env.call_method(&rq_obj, "getNumBits", "()B", &[])?.b()? 
as u8; + Ok(RQBuildParams::new(num_bits)) + }, + )?; + + if let Some(rq_params) = rq_params { + stages.push(StageParams::RQ(rq_params)); + } + Ok(VectorIndexParams { metric_type: distance_type, stages, diff --git a/rust/lance-index/src/vector/bq.rs b/rust/lance-index/src/vector/bq.rs index 1210847c30e..54748db3264 100644 --- a/rust/lance-index/src/vector/bq.rs +++ b/rust/lance-index/src/vector/bq.rs @@ -83,19 +83,14 @@ fn binary_quantization(data: &[T]) -> impl Iterator + '_ { })) } -#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] pub enum RQRotationType { + #[default] Fast, Matrix, } -impl Default for RQRotationType { - fn default() -> Self { - Self::Fast - } -} - impl FromStr for RQRotationType { type Err = Error; diff --git a/rust/lance-index/src/vector/bq/rotation.rs b/rust/lance-index/src/vector/bq/rotation.rs index 6ac231ad7f3..de4fbf549f1 100644 --- a/rust/lance-index/src/vector/bq/rotation.rs +++ b/rust/lance-index/src/vector/bq/rotation.rs @@ -76,8 +76,8 @@ unsafe fn flip_signs_avx2(values: &mut [f32], signs: &[u8]) { let bit_select = _mm256_setr_epi32(0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80); let sign_flip = _mm256_set1_epi32(0x80000000u32 as i32); - for chunk_idx in 0..full_chunks { - let mask = signs[chunk_idx] as i32; + for (chunk_idx, &mask) in signs.iter().take(full_chunks).enumerate() { + let mask = mask as i32; let mask_bits = _mm256_set1_epi32(mask); let test = _mm256_and_si256(mask_bits, bit_select); let cmp = _mm256_cmpeq_epi32(test, bit_select); From 6e45ae7ed608a7dfc5d4eb5710c86aa8e1d60979 Mon Sep 17 00:00:00 2001 From: BubbleCal Date: Thu, 26 Feb 2026 23:23:18 +0800 Subject: [PATCH 7/7] Enforce IVF RQ dim multiple --- python/python/tests/test_vector_index.py | 18 ++++++++++++++++ rust/lance-index/src/vector/bq/builder.rs | 25 +++++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git 
a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index dcdf88ee84d..ee7f955c0f7 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -843,6 +843,24 @@ def test_create_ivf_rq_index(): assert res["_distance"].to_numpy().max() == 0.0 +def test_create_ivf_rq_requires_dim_divisible_by_8(): + vectors = np.zeros((1000, 30), dtype=np.float32).tolist() + tbl = pa.Table.from_pydict( + {"vector": pa.array(vectors, type=pa.list_(pa.float32(), 30))} + ) + ds = lance.write_dataset(tbl, "memory://", mode="overwrite") + + with pytest.raises( + ValueError, match="vector dimension must be divisible by 8 for IVF_RQ" + ): + ds.create_index( + "vector", + index_type="IVF_RQ", + num_partitions=4, + num_bits=1, + ) + + def test_create_ivf_hnsw_pq_index(dataset, tmp_path): assert not dataset.has_index ann_ds = lance.write_dataset(dataset.to_table(), tmp_path / "indexed.lance") diff --git a/rust/lance-index/src/vector/bq/builder.rs b/rust/lance-index/src/vector/bq/builder.rs index db36723f425..53c73e69cb3 100644 --- a/rust/lance-index/src/vector/bq/builder.rs +++ b/rust/lance-index/src/vector/bq/builder.rs @@ -321,6 +321,14 @@ impl Quantization for RabitQuantizer { _: lance_linalg::distance::DistanceType, params: &Self::BuildParams, ) -> Result { + let dim = data.as_fixed_size_list().value_length() as usize; + if !dim.is_multiple_of(u8::BITS as usize) { + return Err(Error::invalid_input( + "vector dimension must be divisible by 8 for IVF_RQ", + location!(), + )); + } + let q = match data.as_fixed_size_list().value_type() { DataType::Float16 => Self::new_with_rotation::( params.num_bits, @@ -511,6 +519,8 @@ mod tests { use super::*; use approx::assert_relative_eq; use arrow::datatypes::Float32Type; + use arrow_array::{FixedSizeListArray, Float32Array}; + use lance_linalg::distance::DistanceType; use rstest::rstest; #[rstest] @@ -563,4 +573,19 @@ mod tests { assert_eq!(matrix_q.rotation_type(), 
RQRotationType::Matrix); assert_eq!(matrix_q.dim(), 128); } + + #[test] + fn test_rabit_quantizer_requires_dim_divisible_by_8() { + let vectors = Float32Array::from(vec![0.0f32; 4 * 30]); + let fsl = FixedSizeListArray::try_new_from_values(vectors, 30).unwrap(); + let params = RQBuildParams::new(1); + + let err = RabitQuantizer::build(&fsl, DistanceType::L2, &params).unwrap_err(); + assert!( + err.to_string() + .contains("vector dimension must be divisible by 8 for IVF_RQ"), + "{}", + err + ); + } }