Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions java/lance-jni/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,7 @@ pub fn get_vector_index_params(
metric_type: distance_type,
stages,
version: IndexFileVersion::V3,
skip_transpose: false,
})
},
)?;
Expand Down
4 changes: 4 additions & 0 deletions python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2796,6 +2796,7 @@ def create_index(
index_uuid: Optional[str] = None,
*,
target_partition_size: Optional[int] = None,
skip_transpose: bool = False,
**kwargs,
) -> LanceDataset:
"""Create index on column.
Expand Down Expand Up @@ -3271,6 +3272,9 @@ def create_index(
if shuffle_partition_concurrency is not None:
kwargs["shuffle_partition_concurrency"] = shuffle_partition_concurrency

if skip_transpose:
kwargs["skip_transpose"] = True

# Add fragment_ids and index_uuid to kwargs if provided for
# distributed indexing
if fragment_ids is not None:
Expand Down
31 changes: 31 additions & 0 deletions python/python/tests/test_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,6 +827,8 @@ def test_create_ivf_rq_index():
num_bits=1,
)
assert ds.describe_indices()[0].field_names == ["vector"]
stats = ds.stats.index_stats("vector_idx")
assert stats["indices"][0]["sub_index"]["packed"] is True

with pytest.raises(
NotImplementedError,
Expand Down Expand Up @@ -865,6 +867,19 @@ def test_create_ivf_rq_index():
assert res["_distance"].to_numpy().max() == 0.0


def test_create_ivf_rq_skip_transpose():
    """Building an IVF_RQ index with skip_transpose=True leaves codes unpacked.

    The sub-index stats expose this via the ``packed`` flag, which must be
    False when the transpose/packing step is skipped.
    """
    dataset = lance.write_dataset(create_table(), "memory://")
    dataset = dataset.create_index(
        "vector",
        index_type="IVF_RQ",
        num_partitions=4,
        num_bits=1,
        skip_transpose=True,
    )
    index_stats = dataset.stats.index_stats("vector_idx")
    sub_index = index_stats["indices"][0]["sub_index"]
    assert sub_index["packed"] is False


def test_create_ivf_rq_requires_dim_divisible_by_8():
vectors = np.zeros((1000, 30), dtype=np.float32).tolist()
tbl = pa.Table.from_pydict(
Expand Down Expand Up @@ -1058,6 +1073,22 @@ def test_pre_populated_ivf_centroids(dataset, tmp_path: Path):
assert all([partition_keys == set(p.keys()) for p in partitions])


def test_create_ivf_pq_skip_transpose(dataset, tmp_path: Path):
    """Building an IVF_PQ index with skip_transpose=True leaves codes untransposed.

    Copies the fixture dataset to a fresh path, indexes it with the transpose
    step disabled, and checks the sub-index stats report ``transposed`` False.
    """
    copied = lance.write_dataset(
        dataset.to_table(), tmp_path / "indexed_skip_transpose.lance"
    )
    indexed = copied.create_index(
        "vector",
        index_type="IVF_PQ",
        num_partitions=4,
        num_sub_vectors=16,
        skip_transpose=True,
    )

    index_stats = indexed.stats.index_stats("vector_idx")
    assert index_stats["indices"][0]["sub_index"]["transposed"] is False


def test_optimize_index(dataset, tmp_path):
dataset_uri = tmp_path / "dataset.lance"
assert not dataset.has_index
Expand Down
6 changes: 6 additions & 0 deletions python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3205,6 +3205,7 @@ fn prepare_vector_index_params(
let mut sq_params = SQBuildParams::default();
let mut rq_params = RQBuildParams::default();
let mut index_file_version = IndexFileVersion::V3;
let mut skip_transpose = false;

if let Some(kwargs) = kwargs {
// Parse metric type
Expand Down Expand Up @@ -3334,6 +3335,10 @@ fn prepare_vector_index_params(
index_file_version = IndexFileVersion::try_from(&version)
.map_err(|e| PyValueError::new_err(format!("Invalid index_file_version: {e}")))?;
}

if let Some(value) = kwargs.get_item("skip_transpose")? {
skip_transpose = value.extract()?;
}
}

let mut params = match index_type {
Expand Down Expand Up @@ -3378,6 +3383,7 @@ fn prepare_vector_index_params(
))),
}?;
params.version(index_file_version);
params.skip_transpose(skip_transpose);
Ok(params)
}

Expand Down
115 changes: 115 additions & 0 deletions rust/lance-index/src/vector/bq/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,14 @@ fn get_rq_code(
#[cfg(test)]
mod tests {
use super::*;
use std::collections::HashMap;

use arrow_array::{ArrayRef, Float32Array, UInt64Array};
use lance_core::ROW_ID;
use lance_linalg::distance::DistanceType;

use crate::vector::bq::{RQRotationType, builder::RabitQuantizer};
use crate::vector::quantizer::{Quantization, QuantizerStorage};

fn build_dist_table_not_optimized<T: ArrowFloatType>(
sub_vec: &[T::Native],
Expand Down Expand Up @@ -912,4 +920,111 @@ mod tests {
);
}
}

/// Quantize a synthetic float matrix into RQ codes.
///
/// Uses a 1-bit `RabitQuantizer` with the fast rotation, so the output is
/// deterministic for a given (`num_vectors`, `code_dim`) pair.
fn make_test_codes(num_vectors: usize, code_dim: i32) -> FixedSizeListArray {
    // Flat deterministic values: idx / code_dim, row-major.
    let total = num_vectors * code_dim as usize;
    let flat =
        Float32Array::from_iter_values((0..total).map(|idx| idx as f32 / code_dim as f32));
    let vectors = FixedSizeListArray::try_new_from_values(flat, code_dim).unwrap();
    let quantizer =
        RabitQuantizer::new_with_rotation::<Float32Type>(1, code_dim, RQRotationType::Fast);
    let quantized = quantizer.quantize(&vectors).unwrap();
    quantized.as_fixed_size_list().clone()
}

/// Metadata for a freshly constructed 1-bit RQ quantizer (codes start unpacked).
fn make_test_metadata(code_dim: usize) -> RabitQuantizationMetadata {
    let quantizer = RabitQuantizer::new_with_rotation::<Float32Type>(
        1,
        code_dim as i32,
        RQRotationType::Fast,
    );
    quantizer.metadata(None)
}

/// Assemble a storage batch for the given codes: sequential row ids plus
/// distinct per-row add/scale factors so each column is distinguishable.
fn make_test_batch(codes: FixedSizeListArray) -> RecordBatch {
    let n = codes.len();
    let row_ids: ArrayRef = Arc::new(UInt64Array::from_iter_values(0..n as u64));
    // add = i, scale = i + 0.5 — never equal, so a column mix-up would be caught.
    let add_factors: ArrayRef =
        Arc::new(Float32Array::from_iter_values((0..n).map(|i| i as f32)));
    let scale_factors: ArrayRef =
        Arc::new(Float32Array::from_iter_values((0..n).map(|i| i as f32 + 0.5)));
    RecordBatch::try_from_iter(vec![
        (ROW_ID, row_ids),
        (RABIT_CODE_COLUMN, Arc::new(codes) as ArrayRef),
        (ADD_FACTORS_COLUMN, add_factors),
        (SCALE_FACTORS_COLUMN, scale_factors),
    ])
    .unwrap()
}

/// Assert two code arrays are identical: same length, same per-row width,
/// and byte-for-byte equal raw values.
fn assert_codes_eq(actual: &FixedSizeListArray, expected: &FixedSizeListArray) {
    assert_eq!(actual.len(), expected.len());
    assert_eq!(actual.value_length(), expected.value_length());
    let lhs = actual.values().as_primitive::<UInt8Type>();
    let rhs = expected.values().as_primitive::<UInt8Type>();
    assert_eq!(lhs.values(), rhs.values());
}

#[test]
fn test_try_from_batch_canonicalizes_rq_codes_to_packed_layout() {
    // The quantizer emits codes in the unpacked layout (metadata.packed == false).
    let codes = make_test_codes(50, 64);
    let metadata = make_test_metadata(codes.value_length() as usize * 8);
    assert!(!metadata.packed);

    let storage = RabitQuantizationStorage::try_from_batch(
        make_test_batch(codes.clone()),
        &metadata,
        DistanceType::L2,
        None,
    )
    .unwrap();

    // Loading must canonicalize to the packed layout and record that fact.
    assert!(storage.metadata().packed);
    let batch = storage.to_batches().unwrap().next().unwrap();
    let stored = batch[RABIT_CODE_COLUMN].as_fixed_size_list();
    assert_codes_eq(stored, &pack_codes(&codes));
}

#[test]
fn test_remap_preserves_packed_rq_storage_layout() {
    let codes = make_test_codes(50, 64);
    let metadata = make_test_metadata(codes.value_length() as usize * 8);
    let storage = RabitQuantizationStorage::try_from_batch(
        make_test_batch(codes.clone()),
        &metadata,
        DistanceType::L2,
        None,
    )
    .unwrap();

    // Remap: row 1 -> 101, row 3 deleted, row 4 -> 104; all others untouched.
    let mut mapping = HashMap::new();
    for (old_id, new_id) in [(1, Some(101)), (3, None), (4, Some(104))] {
        mapping.insert(old_id, new_id);
    }

    let remapped = storage.remap(&mapping).unwrap();
    // Remapping must not drop the packed flag...
    assert!(remapped.metadata().packed);

    let batch = remapped.to_batches().unwrap().next().unwrap();
    let row_ids = batch[ROW_ID].as_primitive::<UInt64Type>().values();
    let expected_ids = UInt64Array::from_iter_values(
        [0, 101, 2, 104]
            .into_iter()
            .chain(5..codes.len() as u64),
    );
    assert_eq!(row_ids, expected_ids.values());

    // ...and the codes must still be in packed layout: unpack-then-repack
    // is the identity on genuinely packed codes.
    let remapped_codes = batch[RABIT_CODE_COLUMN].as_fixed_size_list();
    assert_codes_eq(remapped_codes, &pack_codes(&unpack_codes(remapped_codes)));
}
}
1 change: 1 addition & 0 deletions rust/lance/src/dataset/optimize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3862,6 +3862,7 @@ mod tests {
}),
],
version: crate::index::vector::IndexFileVersion::V3,
skip_transpose: false,
},
false,
)
Expand Down
1 change: 1 addition & 0 deletions rust/lance/src/dataset/scanner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8759,6 +8759,7 @@ mod test {
}),
],
version: crate::index::vector::IndexFileVersion::Legacy,
skip_transpose: false,
},
false,
)
Expand Down
Loading