Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion java/lance-jni/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ pub fn get_vector_index_params(
"getRqParams",
|env, rq_obj| {
let num_bits = env.call_method(&rq_obj, "getNumBits", "()B", &[])?.b()? as u8;
Ok(RQBuildParams { num_bits })
Ok(RQBuildParams::new(num_bits))
},
)?;

Expand Down
5 changes: 5 additions & 0 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2905,6 +2905,11 @@ def create_index(
- index_file_version
The version of the index file. Default is "V3".

Optional parameters for `IVF_RQ`:

- num_bits
The number of bits for RQ (Rabit Quantization). Default is 1.

Optional parameters for `IVF_HNSW_*`:
max_level
Int, the maximum number of levels in the graph.
Expand Down
18 changes: 18 additions & 0 deletions python/python/tests/test_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,24 @@ def test_create_ivf_rq_index():
assert res["_distance"].to_numpy().max() == 0.0


def test_create_ivf_rq_requires_dim_divisible_by_8():
    """Creating an IVF_RQ index on 30-dimensional vectors must raise ValueError."""
    num_rows, dim = 1000, 30  # 30 is deliberately not a multiple of 8
    vector_type = pa.list_(pa.float32(), dim)
    zero_vectors = np.zeros((num_rows, dim), dtype=np.float32).tolist()
    tbl = pa.Table.from_pydict({"vector": pa.array(zero_vectors, type=vector_type)})
    ds = lance.write_dataset(tbl, "memory://", mode="overwrite")

    expected_message = "vector dimension must be divisible by 8 for IVF_RQ"
    with pytest.raises(ValueError, match=expected_message):
        ds.create_index(
            "vector",
            index_type="IVF_RQ",
            num_partitions=4,
            num_bits=1,
        )


def test_create_ivf_hnsw_pq_index(dataset, tmp_path):
assert not dataset.has_index
ann_ds = lance.write_dataset(dataset.to_table(), tmp_path / "indexed.lance")
Expand Down
106 changes: 59 additions & 47 deletions rust/lance-index/benches/rq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@ use lance_datagen::{BatchGeneratorBuilder, RowCount};
use lance_index::vector::bq::builder::RabitQuantizer;
use lance_index::vector::bq::storage::*;
use lance_index::vector::bq::transform::{ADD_FACTORS_COLUMN, SCALE_FACTORS_COLUMN};
use lance_index::vector::bq::RQRotationType;
use lance_index::vector::quantizer::{Quantization, QuantizerStorage};
use lance_index::vector::storage::{DistCalculator, VectorStore};
use lance_linalg::distance::DistanceType;

const DIM: usize = 128;
const TOTAL: usize = 16 * 1000;

fn mock_rq_storage(num_bits: u8) -> RabitQuantizationStorage {
fn mock_rq_storage(num_bits: u8, rotation_type: RQRotationType) -> RabitQuantizationStorage {
// generate random rq codes
let rq = RabitQuantizer::new::<Float32Type>(num_bits, DIM as i32);
let rq = RabitQuantizer::new_with_rotation::<Float32Type>(num_bits, DIM as i32, rotation_type);
let builder = BatchGeneratorBuilder::new()
.col(ROW_ID, lance_datagen::array::step::<UInt64Type>())
.col(
Expand All @@ -49,59 +50,70 @@ fn mock_rq_storage(num_bits: u8) -> RabitQuantizationStorage {
}

fn construct_dist_table(c: &mut Criterion) {
let rotation_types = [RQRotationType::Fast, RQRotationType::Matrix];
for num_bits in 1..=1 {
let rq = mock_rq_storage(num_bits);
let query = rand_type(&DataType::Float32)
.generate_default(RowCount::from(DIM as u64))
.unwrap();
c.bench_function(
format!(
"RQ{}: construct_dist_table: {},DIM={}",
num_bits,
DistanceType::L2,
DIM
)
.as_str(),
|b| {
b.iter(|| {
black_box(rq.dist_calculator(query.clone(), 0.0));
})
},
);
for rotation_type in rotation_types {
let rq = mock_rq_storage(num_bits, rotation_type);
let query = rand_type(&DataType::Float32)
.generate_default(RowCount::from(DIM as u64))
.unwrap();
c.bench_function(
format!(
"RQ{}({:?}): construct_dist_table: {},DIM={}",
num_bits,
rotation_type,
DistanceType::L2,
DIM
)
.as_str(),
|b| {
b.iter(|| {
black_box(rq.dist_calculator(query.clone(), 0.0));
})
},
);
}
}
}

fn compute_distances(c: &mut Criterion) {
let rotation_types = [RQRotationType::Fast, RQRotationType::Matrix];
for num_bits in 1..=1 {
let rq = mock_rq_storage(num_bits);
let query = rand_type(&DataType::Float32)
.generate_default(RowCount::from(DIM as u64))
.unwrap();
let dist_calc = rq.dist_calculator(query.clone(), 0.0);
for rotation_type in rotation_types {
let rq = mock_rq_storage(num_bits, rotation_type);
let query = rand_type(&DataType::Float32)
.generate_default(RowCount::from(DIM as u64))
.unwrap();
let dist_calc = rq.dist_calculator(query.clone(), 0.0);

c.bench_function(
format!("RQ{}: compute_distances: {},DIM={}", num_bits, TOTAL, DIM).as_str(),
|b| {
b.iter(|| {
black_box(dist_calc.distance_all(0));
})
},
);
c.bench_function(
format!(
"RQ{}({:?}): compute_distances: {},DIM={}",
num_bits, rotation_type, TOTAL, DIM
)
.as_str(),
|b| {
b.iter(|| {
black_box(dist_calc.distance_all(0));
})
},
);

c.bench_function(
format!(
"RQ{}: compute_distances_single: {},DIM={}",
num_bits, TOTAL, DIM
)
.as_str(),
|b| {
b.iter(|| {
for i in 0..TOTAL {
black_box(dist_calc.distance(i as u32));
}
})
},
);
c.bench_function(
format!(
"RQ{}({:?}): compute_distances_single: {},DIM={}",
num_bits, rotation_type, TOTAL, DIM
)
.as_str(),
|b| {
b.iter(|| {
for i in 0..TOTAL {
black_box(dist_calc.distance(i as u32));
}
})
},
);
}
}
}

Expand Down
60 changes: 58 additions & 2 deletions rust/lance-index/src/vector/bq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,20 @@
//! Binary Quantization (BQ)

use std::iter::once;
use std::str::FromStr;
use std::sync::Arc;

use arrow_array::types::Float32Type;
use arrow_array::{cast::AsArray, Array, ArrayRef, UInt8Array};
use lance_core::{Error, Result};
use num_traits::Float;
use serde::{Deserialize, Serialize};
use snafu::location;

use crate::vector::quantizer::QuantizerBuildParams;

pub mod builder;
pub mod rotation;
pub mod storage;
pub mod transform;

Expand Down Expand Up @@ -80,14 +83,51 @@ fn binary_quantization<T: Float>(data: &[T]) -> impl Iterator<Item = u8> + '_ {
}))
}

/// Rotation type selector for RQ (Rabit Quantization).
///
/// Serialized by serde using `snake_case` variant names (`"fast"`, `"matrix"`).
/// The `Default` implementation yields [`RQRotationType::Fast`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RQRotationType {
/// Default variant.
/// NOTE(review): presumably an FHT-based rotation, judging by the `fht_kac`
/// alias accepted when parsing from a string; confirm against the `rotation` module.
#[default]
Fast,
/// NOTE(review): presumably a dense rotation matrix, judging by the `dense`
/// alias accepted when parsing from a string; confirm against the `rotation` module.
Matrix,
}

impl FromStr for RQRotationType {
type Err = Error;

fn from_str(value: &str) -> std::result::Result<Self, Self::Err> {
match value.to_lowercase().as_str() {
"fast" | "fht_kac" | "fht-kac" => Ok(Self::Fast),
"matrix" | "dense" => Ok(Self::Matrix),
_ => Err(Error::invalid_input(
format!(
"Unknown RQ rotation type: {}. Expected one of: fast, matrix",
value
),
location!(),
)),
}
}
}

/// Build parameters for RQ (Rabit Quantization).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RQBuildParams {
/// The number of bits for RQ.
pub num_bits: u8,
/// Which rotation type to use; see [`RQRotationType`].
pub rotation_type: RQRotationType,
}

impl RQBuildParams {
pub fn new(num_bits: u8) -> Self {
Self { num_bits }
Self {
num_bits,
rotation_type: RQRotationType::default(),
}
}

pub fn with_rotation_type(num_bits: u8, rotation_type: RQRotationType) -> Self {
Self {
num_bits,
rotation_type,
}
}
}

Expand All @@ -99,7 +139,10 @@ impl QuantizerBuildParams for RQBuildParams {

impl Default for RQBuildParams {
fn default() -> Self {
Self { num_bits: 1 }
Self {
num_bits: 1,
rotation_type: RQRotationType::default(),
}
}
}

Expand All @@ -126,4 +169,17 @@ mod tests {
test_bq::<f32>();
test_bq::<f64>();
}

/// Checks string parsing of [`RQRotationType`]: canonical names, the accepted
/// aliases, case-insensitivity, and rejection of unknown names.
#[test]
fn test_rotation_type_parse() {
    // Every spelling the `FromStr` impl accepts, including aliases and
    // mixed-case input (the parser lowercases before matching).
    let accepted = [
        ("fast", RQRotationType::Fast),
        ("FAST", RQRotationType::Fast),
        ("fht_kac", RQRotationType::Fast),
        ("fht-kac", RQRotationType::Fast),
        ("matrix", RQRotationType::Matrix),
        ("Matrix", RQRotationType::Matrix),
        ("dense", RQRotationType::Matrix),
    ];
    for (input, expected) in accepted {
        assert_eq!(
            input.parse::<RQRotationType>().unwrap(),
            expected,
            "unexpected parse result for {:?}",
            input
        );
    }

    // Unknown names, including the empty string, must be rejected.
    for input in ["invalid", ""] {
        assert!(
            input.parse::<RQRotationType>().is_err(),
            "expected {:?} to be rejected",
            input
        );
    }
}
}
Loading
Loading