diff --git a/java/lance-jni/src/utils.rs b/java/lance-jni/src/utils.rs index 7387ca4564b..2b342c99ef7 100644 --- a/java/lance-jni/src/utils.rs +++ b/java/lance-jni/src/utils.rs @@ -430,6 +430,7 @@ pub fn get_vector_index_params( metric_type: distance_type, stages, version: IndexFileVersion::V3, + skip_transpose: false, }) }, )?; diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index b57d4b15efe..7b9840048e5 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -2796,6 +2796,7 @@ def create_index( index_uuid: Optional[str] = None, *, target_partition_size: Optional[int] = None, + skip_transpose: bool = False, **kwargs, ) -> LanceDataset: """Create index on column. @@ -3271,6 +3272,9 @@ def create_index( if shuffle_partition_concurrency is not None: kwargs["shuffle_partition_concurrency"] = shuffle_partition_concurrency + if skip_transpose: + kwargs["skip_transpose"] = True + # Add fragment_ids and index_uuid to kwargs if provided for # distributed indexing if fragment_ids is not None: diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index 245825b980d..c4dec90994e 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -827,6 +827,8 @@ def test_create_ivf_rq_index(): num_bits=1, ) assert ds.describe_indices()[0].field_names == ["vector"] + stats = ds.stats.index_stats("vector_idx") + assert stats["indices"][0]["sub_index"]["packed"] is True with pytest.raises( NotImplementedError, @@ -865,6 +867,19 @@ def test_create_ivf_rq_index(): assert res["_distance"].to_numpy().max() == 0.0 +def test_create_ivf_rq_skip_transpose(): + ds = lance.write_dataset(create_table(), "memory://") + ds = ds.create_index( + "vector", + index_type="IVF_RQ", + num_partitions=4, + num_bits=1, + skip_transpose=True, + ) + stats = ds.stats.index_stats("vector_idx") + assert stats["indices"][0]["sub_index"]["packed"] is False + + def test_create_ivf_rq_requires_dim_divisible_by_8(): vectors = np.zeros((1000, 30), dtype=np.float32).tolist() tbl = pa.Table.from_pydict( @@ -1058,6 +1073,22 @@ def test_pre_populated_ivf_centroids(dataset, tmp_path: Path): assert all([partition_keys == set(p.keys()) for p in partitions]) +def test_create_ivf_pq_skip_transpose(dataset, tmp_path: Path): + ds = lance.write_dataset( + dataset.to_table(), tmp_path / "indexed_skip_transpose.lance" + ) + ds = ds.create_index( + "vector", + index_type="IVF_PQ", + num_partitions=4, + num_sub_vectors=16, + skip_transpose=True, + ) + + stats = ds.stats.index_stats("vector_idx") + assert stats["indices"][0]["sub_index"]["transposed"] is False + + def test_optimize_index(dataset, tmp_path): dataset_uri = tmp_path / "dataset.lance" assert not dataset.has_index diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 11a7ab2f9c2..9ca47485e42 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -3205,6 +3205,7 @@ fn prepare_vector_index_params( let mut sq_params = SQBuildParams::default(); let mut rq_params = RQBuildParams::default(); let mut index_file_version = IndexFileVersion::V3; + let mut skip_transpose = false; if let Some(kwargs) = kwargs { // Parse metric type @@ -3334,6 +3335,10 @@ fn prepare_vector_index_params( index_file_version = IndexFileVersion::try_from(&version) .map_err(|e| PyValueError::new_err(format!("Invalid index_file_version: {e}")))?; } + + if let Some(value) = kwargs.get_item("skip_transpose")? { + skip_transpose = value.extract()?; + } } let mut params = match index_type { @@ -3378,6 +3383,7 @@ fn prepare_vector_index_params( ))), }?; params.version(index_file_version); + params.skip_transpose(skip_transpose); Ok(params) } diff --git a/rust/lance-index/src/vector/bq/storage.rs b/rust/lance-index/src/vector/bq/storage.rs index fdb47bfdf16..c47dd0211eb 100644 --- a/rust/lance-index/src/vector/bq/storage.rs +++ b/rust/lance-index/src/vector/bq/storage.rs @@ -848,6 +848,14 @@ fn get_rq_code( #[cfg(test)] mod tests { use super::*; + use std::collections::HashMap; + + use arrow_array::{ArrayRef, Float32Array, UInt64Array}; + use lance_core::ROW_ID; + use lance_linalg::distance::DistanceType; + + use crate::vector::bq::{RQRotationType, builder::RabitQuantizer}; + use crate::vector::quantizer::{Quantization, QuantizerStorage}; fn build_dist_table_not_optimized( sub_vec: &[T::Native], @@ -912,4 +920,111 @@ mod tests { ); } } + + fn make_test_codes(num_vectors: usize, code_dim: i32) -> FixedSizeListArray { + let quantizer = + RabitQuantizer::new_with_rotation::(1, code_dim, RQRotationType::Fast); + let values = Float32Array::from_iter_values( + (0..num_vectors * code_dim as usize).map(|idx| idx as f32 / code_dim as f32), + ); + let vectors = FixedSizeListArray::try_new_from_values(values, code_dim).unwrap(); + quantizer + .quantize(&vectors) + .unwrap() + .as_fixed_size_list() + .clone() + } + + fn make_test_metadata(code_dim: usize) -> RabitQuantizationMetadata { + RabitQuantizer::new_with_rotation::(1, code_dim as i32, RQRotationType::Fast) + .metadata(None) + } + + fn make_test_batch(codes: FixedSizeListArray) -> RecordBatch { + let num_rows = codes.len(); + RecordBatch::try_from_iter(vec![ + ( + ROW_ID, + Arc::new(UInt64Array::from_iter_values(0..num_rows as u64)) as ArrayRef, + ), + (RABIT_CODE_COLUMN, Arc::new(codes) as ArrayRef), + ( + ADD_FACTORS_COLUMN, + Arc::new(Float32Array::from_iter_values( + (0..num_rows).map(|v| v as f32), + )) as ArrayRef, + ), + ( + SCALE_FACTORS_COLUMN, + Arc::new(Float32Array::from_iter_values( + (0..num_rows).map(|v| v as f32 + 0.5), + )) as ArrayRef, + ), + ]) + .unwrap() + } + + fn assert_codes_eq(actual: &FixedSizeListArray, expected: &FixedSizeListArray) { + assert_eq!(actual.len(), expected.len()); + assert_eq!(actual.value_length(), expected.value_length()); + assert_eq!( + actual.values().as_primitive::().values(), + expected.values().as_primitive::().values() + ); + } + + #[test] + fn test_try_from_batch_canonicalizes_rq_codes_to_packed_layout() { + let original_codes = make_test_codes(50, 64); + let metadata = make_test_metadata(original_codes.value_length() as usize * 8); + assert!(!metadata.packed); + + let storage = RabitQuantizationStorage::try_from_batch( + make_test_batch(original_codes.clone()), + &metadata, + DistanceType::L2, + None, + ) + .unwrap(); + + assert!(storage.metadata().packed); + let stored_batch = storage.to_batches().unwrap().next().unwrap(); + let stored_codes = stored_batch[RABIT_CODE_COLUMN].as_fixed_size_list(); + let expected_codes = pack_codes(&original_codes); + assert_codes_eq(stored_codes, &expected_codes); + } + + #[test] + fn test_remap_preserves_packed_rq_storage_layout() { + let original_codes = make_test_codes(50, 64); + let metadata = make_test_metadata(original_codes.value_length() as usize * 8); + let storage = RabitQuantizationStorage::try_from_batch( + make_test_batch(original_codes.clone()), + &metadata, + DistanceType::L2, + None, + ) + .unwrap(); + + let mut mapping = HashMap::new(); + mapping.insert(1, Some(101)); + mapping.insert(3, None); + mapping.insert(4, Some(104)); + + let remapped = storage.remap(&mapping).unwrap(); + assert!(remapped.metadata().packed); + + let remapped_batch = remapped.to_batches().unwrap().next().unwrap(); + let remapped_row_ids = remapped_batch[ROW_ID].as_primitive::().values(); + let expected_row_ids = UInt64Array::from_iter_values( + [0, 101, 2, 104] + .into_iter() + .chain(5..original_codes.len() as u64), + ); + assert_eq!(remapped_row_ids, expected_row_ids.values()); + + let remapped_codes = remapped_batch[RABIT_CODE_COLUMN].as_fixed_size_list(); + let repacked = pack_codes(&unpack_codes(remapped_codes)); + assert_codes_eq(remapped_codes, &repacked); + } } diff --git a/rust/lance/src/dataset/optimize.rs b/rust/lance/src/dataset/optimize.rs index 8bbc8d177cd..c3e1149a2d8 100644 --- a/rust/lance/src/dataset/optimize.rs +++ b/rust/lance/src/dataset/optimize.rs @@ -3862,6 +3862,7 @@ mod tests { }), ], version: crate::index::vector::IndexFileVersion::V3, + skip_transpose: false, }, false, ) diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 78aaa43fa30..cdaa9408281 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -8759,6 +8759,7 @@ mod test { }), ], version: crate::index::vector::IndexFileVersion::Legacy, + skip_transpose: false, }, false, ) diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index 92336c15c41..4fc2f910527 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -106,6 +106,9 @@ pub struct VectorIndexParams { /// The version of the index file. pub version: IndexFileVersion, + + /// Skip transpose / packing for PQ and RQ storage. + pub skip_transpose: bool, } impl VectorIndexParams { @@ -114,6 +117,11 @@ impl VectorIndexParams { self } + pub fn skip_transpose(&mut self, skip_transpose: bool) -> &mut Self { + self.skip_transpose = skip_transpose; + self + } + pub fn ivf_flat(num_partitions: usize, metric_type: MetricType) -> Self { let ivf_params = IvfBuildParams::new(num_partitions); let stages = vec![StageParams::Ivf(ivf_params)]; @@ -121,6 +129,7 @@ impl VectorIndexParams { stages, metric_type, version: IndexFileVersion::V3, + skip_transpose: false, } } @@ -130,6 +139,7 @@ impl VectorIndexParams { stages, metric_type, version: IndexFileVersion::V3, + skip_transpose: false, } } @@ -163,6 +173,7 @@ impl VectorIndexParams { stages, metric_type, version: IndexFileVersion::V3, + skip_transpose: false, } } @@ -188,6 +199,7 @@ impl VectorIndexParams { stages, metric_type: distance_type, version: IndexFileVersion::V3, + skip_transpose: false, } } @@ -202,6 +214,7 @@ impl VectorIndexParams { stages, metric_type, version: IndexFileVersion::V3, + skip_transpose: false, } } @@ -215,6 +228,7 @@ impl VectorIndexParams { stages, metric_type, version: IndexFileVersion::V3, + skip_transpose: false, } } @@ -228,6 +242,7 @@ impl VectorIndexParams { stages, metric_type, version: IndexFileVersion::V3, + skip_transpose: false, } } @@ -241,6 +256,7 @@ impl VectorIndexParams { stages, metric_type: distance_type, version: IndexFileVersion::V3, + skip_transpose: false, } } @@ -261,6 +277,7 @@ impl VectorIndexParams { stages, metric_type, version: IndexFileVersion::V3, + skip_transpose: false, } } @@ -281,6 +298,7 @@ impl VectorIndexParams { stages, metric_type, version: IndexFileVersion::V3, + skip_transpose: false, } } @@ -514,8 +532,8 @@ pub(crate) async fn build_distributed_vector_index( )? .with_ivf(ivf_model) .with_quantizer(global_pq) - // For distributed shards, keep PQ codes in their original layout - // and transpose only after all shards are merged. + // For distributed shards, keep PQ codes in row-major layout. + // A single transpose is performed in the distributed merge stage. .with_transpose(false) .with_fragment_filter(fragment_filter) .with_progress(progress.clone()) @@ -610,8 +628,8 @@ pub(crate) async fn build_distributed_vector_index( )? .with_ivf(ivf_model) .with_quantizer(global_pq) - // For distributed shards, keep PQ codes in their original layout - // and transpose only after all shards are merged. + // For distributed shards, keep PQ codes in row-major layout. + // A single transpose is performed in the distributed merge stage. .with_transpose(false) .with_fragment_filter(fragment_filter) .with_progress(progress.clone()) @@ -787,7 +805,7 @@ pub(crate) async fn build_vector_index( .await?; } IndexFileVersion::V3 => { - IvfIndexBuilder::::new( + let mut builder = IvfIndexBuilder::::new( dataset.clone(), column.to_owned(), dataset.indices_dir().child(uuid), @@ -797,10 +815,13 @@ pub(crate) async fn build_vector_index( Some(pq_params.clone()), (), frag_reuse_index, - )? - .with_progress(progress.clone()) - .build() - .await?; + )?; + + builder + .with_transpose(!params.skip_transpose) + .with_progress(progress.clone()) + .build() + .await?; } } } @@ -835,7 +856,7 @@ pub(crate) async fn build_vector_index( ))); }; - IvfIndexBuilder::::new( + let mut builder = IvfIndexBuilder::::new( dataset.clone(), column.to_owned(), dataset.indices_dir().child(uuid), @@ -845,10 +866,13 @@ pub(crate) async fn build_vector_index( Some(rq_params.clone()), (), frag_reuse_index, - )? - .with_progress(progress.clone()) - .build() - .await?; + )?; + + builder + .with_transpose(!params.skip_transpose) + .with_progress(progress.clone()) + .build() + .await?; } IndexType::IvfHnswFlat => { let StageParams::Hnsw(hnsw_params) = &stages[1] else { @@ -1047,7 +1071,7 @@ pub(crate) async fn build_vector_index_incremental( }, // IVF_PQ (SubIndexType::Flat, QuantizationType::Product) => { - IvfIndexBuilder::::new_incremental( + let mut builder = IvfIndexBuilder::::new_incremental( dataset.clone(), column.to_owned(), index_dir, @@ -1056,12 +1080,14 @@ pub(crate) async fn build_vector_index_incremental( (), frag_reuse_index, OptimizeOptions::append(), - )? - .with_ivf(ivf_model) - .with_quantizer(quantizer.try_into()?) - .with_progress(progress.clone()) - .build() - .await?; + )?; + builder + .with_ivf(ivf_model) + .with_quantizer(quantizer.try_into()?) + .with_transpose(!params.skip_transpose) + .with_progress(progress.clone()) + .build() + .await?; } // IVF_SQ (SubIndexType::Flat, QuantizationType::Scalar) => { @@ -1083,7 +1109,7 @@ pub(crate) async fn build_vector_index_incremental( } // IVF_RQ (SubIndexType::Flat, QuantizationType::Rabit) => { - IvfIndexBuilder::::new_incremental( + let mut builder = IvfIndexBuilder::::new_incremental( dataset.clone(), column.to_owned(), index_dir, @@ -1092,12 +1118,14 @@ pub(crate) async fn build_vector_index_incremental( (), frag_reuse_index, OptimizeOptions::append(), - )? - .with_ivf(ivf_model) - .with_quantizer(quantizer.try_into()?) - .with_progress(progress.clone()) - .build() - .await?; + )?; + builder + .with_ivf(ivf_model) + .with_quantizer(quantizer.try_into()?) + .with_transpose(!params.skip_transpose) + .with_progress(progress.clone()) + .build() + .await?; } // IVF_HNSW variants (SubIndexType::Hnsw, quantization_type) => { diff --git a/rust/lance/src/index/vector/builder.rs b/rust/lance/src/index/vector/builder.rs index ab4aacd1d66..14b134092cc 100644 --- a/rust/lance/src/index/vector/builder.rs +++ b/rust/lance/src/index/vector/builder.rs @@ -33,7 +33,7 @@ use lance_index::frag_reuse::FragReuseIndex; use lance_index::metrics::NoOpMetricsCollector; use lance_index::optimize::OptimizeOptions; use lance_index::progress::{IndexBuildProgress, NoopIndexBuildProgress}; -use lance_index::vector::bq::storage::{RABIT_CODE_COLUMN, unpack_codes}; +use lance_index::vector::bq::storage::{RABIT_CODE_COLUMN, pack_codes, unpack_codes}; use lance_index::vector::kmeans::KMeansParams; use lance_index::vector::pq::storage::transpose; use lance_index::vector::quantizer::{ @@ -1016,6 +1016,7 @@ impl IvfIndexBuilder }; let is_pq = Q::quantization_type() == QuantizationType::Product; + let is_rq = Q::quantization_type() == QuantizationType::Rabit; // prepare the final writers let storage_path = self.index_dir.child(INDEX_AUXILIARY_FILE_NAME); @@ -1084,6 +1085,19 @@ impl IvfIndexBuilder batch = batch.replace_column_by_name(PQ_CODE_COLUMN, original_fsl)?; } + if is_rq && batch.column_by_name(RABIT_CODE_COLUMN).is_some() { + // RQ storage batches reaching merge_partitions always come + // from RabitQuantizationStorage, which canonicalizes codes + // into packed layout in try_from_batch/remap. Materialize + // row-major bytes so row-wise sort operates on per-row codes. + let codes_fsl = batch + .column_by_name(RABIT_CODE_COLUMN) + .unwrap() + .as_fixed_size_list(); + let unpacked = Arc::new(unpack_codes(codes_fsl)); + batch = batch.replace_column_by_name(RABIT_CODE_COLUMN, unpacked)?; + } + // Enforce a stable ROW_ID ordering for all auxiliary batches so that the // PQ code column moves together with ROW_ID. batch = stable_sort_batch_by_row_id(&batch)?; @@ -1107,6 +1121,18 @@ impl IvfIndexBuilder batch = batch.replace_column_by_name(PQ_CODE_COLUMN, transposed_fsl)?; } + if is_rq + && self.transpose_codes + && batch.column_by_name(RABIT_CODE_COLUMN).is_some() + { + let codes_fsl = batch + .column_by_name(RABIT_CODE_COLUMN) + .unwrap() + .as_fixed_size_list(); + let packed = Arc::new(pack_codes(codes_fsl)); + batch = batch.replace_column_by_name(RABIT_CODE_COLUMN, packed)?; + } + storage_writer.write_batch(&batch).await?; storage_ivf.add_partition(batch.num_rows() as u32); } @@ -1150,8 +1176,7 @@ impl IvfIndexBuilder storage_writer.add_schema_metadata(IVF_METADATA_KEY, ivf_buffer_pos.to_string()); let quant_type = Q::quantization_type(); let transposed = match quant_type { - QuantizationType::Product => self.transpose_codes, - QuantizationType::Rabit => true, + QuantizationType::Product | QuantizationType::Rabit => self.transpose_codes, _ => false, }; // For now, each partition's metadata is just the quantizer,