diff --git a/rust/lance-index/src/vector/distributed/index_merger.rs b/rust/lance-index/src/vector/distributed/index_merger.rs index c5181b7f842..dd604adb138 100755 --- a/rust/lance-index/src/vector/distributed/index_merger.rs +++ b/rust/lance-index/src/vector/distributed/index_merger.rs @@ -6,10 +6,12 @@ use crate::vector::shared::partition_merger::{ write_unified_ivf_and_index_metadata, SupportedIvfIndexType, }; -use arrow::datatypes::Float32Type; +use arrow::{compute::concat_batches, datatypes::Float32Type}; use arrow_array::cast::AsArray; -use arrow_array::{Array, FixedSizeListArray, UInt64Array}; +use arrow_array::types::UInt8Type; +use arrow_array::{Array, FixedSizeListArray, RecordBatch, UInt64Array}; use futures::StreamExt as _; +use lance_arrow::{FixedSizeListArrayExt, RecordBatchExt}; use lance_core::utils::address::RowAddress; use lance_core::{Error, Result, ROW_ID_FIELD}; use snafu::location; @@ -19,7 +21,7 @@ use std::sync::Arc; use crate::pb; use crate::vector::flat::index::FlatMetadata; use crate::vector::ivf::storage::{IvfModel as IvfStorageModel, IVF_METADATA_KEY}; -use crate::vector::pq::storage::{ProductQuantizationMetadata, PQ_METADATA_KEY}; +use crate::vector::pq::storage::{transpose, ProductQuantizationMetadata, PQ_METADATA_KEY}; use crate::vector::quantizer::QuantizerMetadata; use crate::vector::sq::storage::{ScalarQuantizationMetadata, SQ_METADATA_KEY}; use crate::vector::storage::STORAGE_METADATA_KEY; @@ -272,6 +274,54 @@ pub async fn write_partition_rows( Ok(()) } +/// Transpose the PQ code column for a batch and write it to the unified writer. +/// +/// This helper assumes `batch` contains a contiguous range of rows for a single +/// IVF partition. 
+async fn write_partition_rows_pq_transposed(
+    w: &mut FileWriter,
+    mut batch: RecordBatch,
+) -> Result<()> {
+    let num_rows = batch.num_rows();
+    if num_rows == 0 {
+        return Ok(());
+    }
+
+    let pq_col = batch
+        .column_by_name(PQ_CODE_COLUMN)
+        .ok_or_else(|| Error::Index {
+            message: format!("PQ column {} missing in auxiliary shard", PQ_CODE_COLUMN),
+            location: location!(),
+        })?;
+    let pq_fsl = pq_col
+        .as_fixed_size_list_opt()
+        .ok_or_else(|| Error::Index {
+            message: format!(
+                "PQ column {} is not a FixedSizeList in auxiliary shard, got {}",
+                PQ_CODE_COLUMN,
+                pq_col.data_type(),
+            ),
+            location: location!(),
+        })?;
+    let num_bytes = pq_fsl.value_length() as usize;
+    // PQ codes are stored as u8; the `UInt8Type` parameter matches the
+    // `arrow_array::types::UInt8Type` import added at the top of this file.
+    let values = pq_fsl.values().as_primitive::<UInt8Type>();
+    // Row-major -> column-major: group byte `b` of every row together so the
+    // unified file stores transposed PQ codes.
+    let transposed_codes = transpose(values, num_rows, num_bytes);
+    let transposed_fsl = Arc::new(FixedSizeListArray::try_new_from_values(
+        transposed_codes,
+        num_bytes as i32,
+    )?);
+    batch = batch.replace_column_by_name(PQ_CODE_COLUMN, transposed_fsl)?;
+
+    // Write in reasonably sized chunks to avoid huge batches.
+    let batch_size: usize = 10_240;
+    for offset in (0..num_rows).step_by(batch_size) {
+        let len = std::cmp::min(batch_size, num_rows - offset);
+        let slice = batch.slice(offset, len);
+        w.write_batch(&slice).await?;
+    }
+    Ok(())
+}
+
 /// Detect and return supported index type from reader and schema.
/// /// This is a lightweight wrapper around SupportedIndexType::detect to keep @@ -817,7 +867,9 @@ pub async fn merge_partial_vector_auxiliary_files( pq_meta = Some(pm.clone()); } if v2w_opt.is_none() { - let w = init_writer_for_pq(object_store, &aux_out, dt, &pm).await?; + let mut pm_for_unified = pm.clone(); + pm_for_unified.transposed = true; + let w = init_writer_for_pq(object_store, &aux_out, dt, &pm_for_unified).await?; v2w_opt = Some(w); } } @@ -1023,7 +1075,9 @@ pub async fn merge_partial_vector_auxiliary_files( pq_meta = Some(pm.clone()); } if v2w_opt.is_none() { - let w = init_writer_for_pq(object_store, &aux_out, dt, &pm).await?; + let mut pm_for_unified = pm.clone(); + pm_for_unified.transposed = true; + let w = init_writer_for_pq(object_store, &aux_out, dt, &pm_for_unified).await?; v2w_opt = Some(w); } } @@ -1117,24 +1171,81 @@ pub async fn merge_partial_vector_auxiliary_files( message: "Missing IVF partition count".to_string(), location: location!(), })?; - for pid in 0..nlist { - for (path, lens, _) in shard_infos.iter() { - let part_len = lens[pid] as usize; - if part_len == 0 { - continue; + let idx_type_final = detected_index_type.ok_or_else(|| Error::Index { + message: "Unable to detect index type".to_string(), + location: location!(), + })?; + + match idx_type_final { + SupportedIvfIndexType::IvfPq | SupportedIvfIndexType::IvfHnswPq => { + // For PQ-backed indices, transpose PQ codes while merging partitions + // so that the unified file stores column-major PQ codes. 
+ for pid in 0..nlist { + let total_len = accumulated_lengths[pid] as usize; + if total_len == 0 { + continue; + } + + let mut part_batches: Vec = Vec::new(); + for (path, lens, _) in shard_infos.iter() { + let part_len = lens[pid] as usize; + if part_len == 0 { + continue; + } + let offset: usize = lens.iter().take(pid).map(|x| *x as usize).sum(); + let fh = sched.open_file(path, &CachedFileSize::unknown()).await?; + let reader = V2Reader::try_open( + fh, + None, + Arc::default(), + &lance_core::cache::LanceCache::no_cache(), + V2ReaderOptions::default(), + ) + .await?; + let mut stream = reader.read_stream( + lance_io::ReadBatchParams::Range(offset..offset + part_len), + u32::MAX, + 4, + lance_encoding::decoder::FilterExpression::no_filter(), + )?; + while let Some(rb) = stream.next().await { + let rb = rb?; + part_batches.push(rb); + } + } + + if part_batches.is_empty() { + continue; + } + + let schema = part_batches[0].schema(); + let partition_batch = concat_batches(&schema, part_batches.iter())?; + if let Some(w) = v2w_opt.as_mut() { + write_partition_rows_pq_transposed(w, partition_batch).await?; + } } - let offset: usize = lens.iter().take(pid).map(|x| *x as usize).sum(); - let fh = sched.open_file(path, &CachedFileSize::unknown()).await?; - let reader = V2Reader::try_open( - fh, - None, - Arc::default(), - &lance_core::cache::LanceCache::no_cache(), - V2ReaderOptions::default(), - ) - .await?; - if let Some(w) = v2w_opt.as_mut() { - write_partition_rows(&reader, w, offset..offset + part_len).await?; + } + _ => { + for pid in 0..nlist { + for (path, lens, _) in shard_infos.iter() { + let part_len = lens[pid] as usize; + if part_len == 0 { + continue; + } + let offset: usize = lens.iter().take(pid).map(|x| *x as usize).sum(); + let fh = sched.open_file(path, &CachedFileSize::unknown()).await?; + let reader = V2Reader::try_open( + fh, + None, + Arc::default(), + &lance_core::cache::LanceCache::no_cache(), + V2ReaderOptions::default(), + ) + .await?; + if let 
Some(w) = v2w_opt.as_mut() { + write_partition_rows(&reader, w, offset..offset + part_len).await?; + } + } } } } @@ -1153,10 +1264,6 @@ pub async fn merge_partial_vector_auxiliary_files( message: "Distance type missing".to_string(), location: location!(), })?; - let idx_type_final = detected_index_type.ok_or_else(|| Error::Index { - message: "Unable to detect index type".to_string(), - location: location!(), - })?; write_unified_ivf_and_index_metadata(w, &ivf_model, dt2, idx_type_final).await?; w.finish().await?; } else { diff --git a/rust/lance/src/index/vector.rs b/rust/lance/src/index/vector.rs index 4e7316722b7..1dd015f789b 100644 --- a/rust/lance/src/index/vector.rs +++ b/rust/lance/src/index/vector.rs @@ -513,6 +513,9 @@ please provide PQBuildParams.codebook for distributed indexing" )? .with_ivf(ivf_model) .with_quantizer(global_pq) + // For distributed shards, keep PQ codes in their original layout + // and transpose only after all shards are merged. + .with_transpose(false) .with_fragment_filter(fragment_filter) .build() .await?; @@ -615,6 +618,9 @@ please provide PQBuildParams.codebook for distributed indexing" )? .with_ivf(ivf_model) .with_quantizer(global_pq) + // For distributed shards, keep PQ codes in their original layout + // and transpose only after all shards are merged. + .with_transpose(false) .with_fragment_filter(fragment_filter) .build() .await?; diff --git a/rust/lance/src/index/vector/builder.rs b/rust/lance/src/index/vector/builder.rs index 26fed852b86..59300550608 100644 --- a/rust/lance/src/index/vector/builder.rs +++ b/rust/lance/src/index/vector/builder.rs @@ -88,6 +88,26 @@ use super::{ utils::{self, get_vector_type}, }; +/// Stably sort a RecordBatch by the ROW_ID column in ascending order. +/// +/// If the batch has no ROW_ID column or has fewer than 2 rows, it is +/// returned unchanged. When sorting, the relative order of rows with the +/// same ROW_ID is preserved. 
+fn stable_sort_batch_by_row_id(batch: &RecordBatch) -> Result { + if let Some(row_id_col) = batch.column_by_name(ROW_ID) { + let row_ids = row_id_col.as_primitive::(); + if row_ids.len() > 1 { + let mut order: Vec = (0..row_ids.len()).collect(); + // Vec::sort_by is stable, so equal ROW_IDs keep their + // original relative order. + order.sort_by(|&i, &j| row_ids.value(i).cmp(&row_ids.value(j))); + let indices = UInt32Array::from_iter_values(order.into_iter().map(|i| i as u32)); + return Ok(batch.take(&indices)?); + } + } + Ok(batch.clone()) +} + // the number of partitions to evaluate for reassigning const REASSIGN_RANGE: usize = 64; @@ -128,6 +148,8 @@ pub struct IvfIndexBuilder { optimize_options: Option, // number of indices merged merged_num: usize, + // whether to transpose codes when building storage + transpose_codes: bool, } type BuildStream = @@ -169,6 +191,7 @@ impl IvfIndexBuilder fragment_filter: None, optimize_options: None, merged_num: 0, + transpose_codes: true, }) } @@ -235,6 +258,7 @@ impl IvfIndexBuilder fragment_filter: None, optimize_options: None, merged_num: 0, + transpose_codes: true, }) } @@ -334,6 +358,13 @@ impl IvfIndexBuilder self } + /// Control whether codes are transposed when building storage. + /// This mainly affects intermediate PQ/RQ storage when building distributed indices. + pub fn with_transpose(&mut self, transpose: bool) -> &mut Self { + self.transpose_codes = transpose; + self + } + #[instrument(name = "load_or_build_ivf", level = "debug", skip_all)] async fn load_or_build_ivf(&self) -> Result { match &self.ivf { @@ -935,6 +966,15 @@ impl IvfIndexBuilder } _ => {} } + + // Normalize each batch for this partition to be stably sorted by ROW_ID. 
+ for batch in part_batches.iter_mut() { + if batch.num_rows() == 0 { + continue; + } + *batch = stable_sort_batch_by_row_id(batch)?; + } + batches.extend(part_batches); } @@ -958,6 +998,7 @@ impl IvfIndexBuilder .map(|s| s.parse::().unwrap_or(0.0)) .unwrap_or(0.0); let batch = batch.drop_column(PART_ID_COLUMN)?; + let batch = stable_sort_batch_by_row_id(&batch)?; batches.push(batch); } } @@ -981,6 +1022,8 @@ impl IvfIndexBuilder )); }; + let is_pq = Q::quantization_type() == QuantizationType::Product; + // prepare the final writers let storage_path = self.index_dir.child(INDEX_AUXILIARY_FILE_NAME); let index_path = self.index_dir.child(INDEX_FILE_NAME); @@ -1024,7 +1067,51 @@ impl IvfIndexBuilder storage_ivf.add_partition(0); } else { let batches = storage.to_batches()?.collect::>(); - let batch = arrow::compute::concat_batches(&batches[0].schema(), batches.iter())?; + let mut batch = + arrow::compute::concat_batches(&batches[0].schema(), batches.iter())?; + + if is_pq && batch.column_by_name(PQ_CODE_COLUMN).is_some() { + // The PQ storage keeps codes in a transposed layout (bytes grouped + // across all rows). Convert them back to per-row layout so that a + // stable ROW_ID sort moves PQ_CODE_COLUMN together with ROW_ID. + let codes_fsl = batch + .column_by_name(PQ_CODE_COLUMN) + .unwrap() + .as_fixed_size_list(); + let num_rows = batch.num_rows(); + let bytes_per_code = codes_fsl.value_length() as usize; + let codes = codes_fsl.values().as_primitive::(); + let original_codes = transpose(codes, bytes_per_code, num_rows); + let original_fsl = Arc::new(FixedSizeListArray::try_new_from_values( + original_codes, + bytes_per_code as i32, + )?); + batch = batch.replace_column_by_name(PQ_CODE_COLUMN, original_fsl)?; + } + + // Enforce a stable ROW_ID ordering for all auxiliary batches so that the + // PQ code column moves together with ROW_ID. 
+ batch = stable_sort_batch_by_row_id(&batch)?; + + // For PQ storages, optionally convert codes back to transposed layout + // in the unified auxiliary file. This keeps final PQ storage column-major + // when `transpose_pq_codes` is enabled. + if is_pq && self.transpose_codes && batch.column_by_name(PQ_CODE_COLUMN).is_some() { + let codes_fsl = batch + .column_by_name(PQ_CODE_COLUMN) + .unwrap() + .as_fixed_size_list(); + let num_rows = batch.num_rows(); + let bytes_per_code = codes_fsl.value_length() as usize; + let codes = codes_fsl.values().as_primitive::(); + let transposed_codes = transpose(codes, num_rows, bytes_per_code); + let transposed_fsl = Arc::new(FixedSizeListArray::try_new_from_values( + transposed_codes, + bytes_per_code as i32, + )?); + batch = batch.replace_column_by_name(PQ_CODE_COLUMN, transposed_fsl)?; + } + storage_writer.write_batch(&batch).await?; storage_ivf.add_partition(batch.num_rows() as u32); } @@ -1066,12 +1153,18 @@ impl IvfIndexBuilder .add_global_buffer(storage_ivf_pb.encode_to_vec().into()) .await?; storage_writer.add_schema_metadata(IVF_METADATA_KEY, ivf_buffer_pos.to_string()); + let quant_type = Q::quantization_type(); + let transposed = match quant_type { + QuantizationType::Product => self.transpose_codes, + QuantizationType::Rabit => true, + _ => false, + }; // For now, each partition's metadata is just the quantizer, // it's all the same for now, so we just take the first one let mut metadata = quantizer.metadata(Some(QuantizationMetadata { codebook_position: Some(0), codebook: None, - transposed: true, + transposed, })); if let Some(extra_metadata) = metadata.extra_metadata()? 
{ let idx = storage_writer.add_global_buffer(extra_metadata).await?; diff --git a/rust/lance/src/index/vector/ivf/io.rs b/rust/lance/src/index/vector/ivf/io.rs index c79d568a6c3..dc06c935521 100644 --- a/rust/lance/src/index/vector/ivf/io.rs +++ b/rust/lance/src/index/vector/ivf/io.rs @@ -201,20 +201,26 @@ pub(super) async fn write_pq_partitions( location: location!(), })?; if let Some(pq_code) = pq_index.code.as_ref() { - let original_pq_codes = transpose( - pq_code, - pq_index.pq.num_sub_vectors, - pq_code.len() / pq_index.pq.code_dim(), - ); + let row_ids = pq_index.row_ids.as_ref().unwrap(); + let num_vectors = row_ids.len(); + if num_vectors == 0 || pq_code.is_empty() { + continue; + } + if pq_code.len() % num_vectors != 0 { + continue; + } + let num_bytes_per_code = pq_code.len() / num_vectors; + let original_pq_codes = transpose(pq_code, num_bytes_per_code, num_vectors); let fsl = Arc::new( FixedSizeListArray::try_new_from_values( original_pq_codes, - pq_index.pq.code_dim() as i32, + num_bytes_per_code as i32, ) .unwrap(), ); + pq_array.push(fsl); - row_id_array.push(pq_index.row_ids.as_ref().unwrap().clone()); + row_id_array.push(row_ids.clone()); } } } diff --git a/rust/lance/src/index/vector/ivf/v2.rs b/rust/lance/src/index/vector/ivf/v2.rs index 0e85378ab97..2cabe0b2fb4 100644 --- a/rust/lance/src/index/vector/ivf/v2.rs +++ b/rust/lance/src/index/vector/ivf/v2.rs @@ -629,6 +629,7 @@ mod tests { use lance_index::vector::storage::VectorStore; use crate::dataset::{InsertBuilder, UpdateBuilder, WriteMode, WriteParams}; + use crate::index::vector::ivf::finalize_distributed_merge; use crate::index::vector::ivf::v2::IvfPq; use crate::index::DatasetIndexInternalExt; use crate::utils::test::copy_test_data_to_tmp; @@ -647,6 +648,7 @@ mod tests { use lance_file::reader::{FileReader, FileReaderOptions}; use lance_file::writer::FileWriter; use lance_index::vector::ivf::IvfBuildParams; + use lance_index::vector::kmeans::{train_kmeans, KMeansParams}; use 
lance_index::vector::pq::PQBuildParams; use lance_index::vector::quantizer::QuantizerMetadata; use lance_index::vector::sq::builder::SQBuildParams; @@ -670,6 +672,7 @@ mod tests { use rand::distr::uniform::SampleUniform; use rand::{rngs::StdRng, Rng, SeedableRng}; use rstest::rstest; + use uuid::Uuid; const NUM_ROWS: usize = 512; const DIM: usize = 32; @@ -1293,6 +1296,338 @@ mod tests { .collect() } + const TWO_FRAG_NUM_ROWS: usize = 2000; + const TWO_FRAG_DIM: usize = 128; + const TWO_FRAG_NUM_PARTITIONS: usize = 4; + const TWO_FRAG_NUM_SUBVECTORS: usize = 16; + const TWO_FRAG_NUM_BITS: usize = 8; + const TWO_FRAG_SAMPLE_RATE: usize = 7; + const TWO_FRAG_MAX_ITERS: u32 = 20; + + fn make_two_fragment_batches() -> (Arc, Vec) { + let ids = Arc::new(UInt64Array::from_iter_values(0..TWO_FRAG_NUM_ROWS as u64)); + + let values = generate_random_array_with_range(TWO_FRAG_NUM_ROWS * TWO_FRAG_DIM, 0.0..1.0); + let vectors = Arc::new( + FixedSizeListArray::try_new_from_values( + Float32Array::from(values), + TWO_FRAG_DIM as i32, + ) + .unwrap(), + ); + + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::UInt64, false), + Field::new("vector", vectors.data_type().clone(), false), + ])); + let batch = RecordBatch::try_new(schema.clone(), vec![ids, vectors]).unwrap(); + + (schema, vec![batch]) + } + + async fn write_dataset_from_batches( + test_uri: &str, + schema: Arc, + batches: Vec, + ) -> Dataset { + let batches = RecordBatchIterator::new(batches.into_iter().map(Ok), schema); + + let write_params = WriteParams { + max_rows_per_file: 500, + mode: WriteMode::Overwrite, + ..Default::default() + }; + + Dataset::write(batches, test_uri, Some(write_params)) + .await + .unwrap() + } + + async fn prepare_global_ivf_pq( + dataset: &Dataset, + vector_column: &str, + ) -> (IvfBuildParams, PQBuildParams) { + let batch = dataset + .scan() + .project(&[vector_column.to_string()]) + .unwrap() + .try_into_batch() + .await + .unwrap(); + let vectors = batch + 
.column_by_name(vector_column) + .expect("vector column should exist") + .as_fixed_size_list(); + + let dim = vectors.value_length() as usize; + assert_eq!(dim, TWO_FRAG_DIM, "unexpected vector dimension"); + + let values = vectors.values().as_primitive::(); + + let kmeans_params = KMeansParams::new(None, TWO_FRAG_MAX_ITERS, 1, DistanceType::L2); + let kmeans = train_kmeans::( + values, + kmeans_params, + dim, + TWO_FRAG_NUM_PARTITIONS, + TWO_FRAG_SAMPLE_RATE, + ) + .unwrap(); + + let centroids_flat = kmeans.centroids.as_primitive::().clone(); + let centroids_fsl = + Arc::new(FixedSizeListArray::try_new_from_values(centroids_flat, dim as i32).unwrap()); + let mut ivf_params = + IvfBuildParams::try_with_centroids(TWO_FRAG_NUM_PARTITIONS, centroids_fsl).unwrap(); + ivf_params.max_iters = TWO_FRAG_MAX_ITERS as usize; + ivf_params.sample_rate = TWO_FRAG_SAMPLE_RATE; + + let mut pq_train_params = PQBuildParams::new(TWO_FRAG_NUM_SUBVECTORS, TWO_FRAG_NUM_BITS); + pq_train_params.max_iters = TWO_FRAG_MAX_ITERS as usize; + pq_train_params.sample_rate = TWO_FRAG_SAMPLE_RATE; + + let pq = pq_train_params.build(vectors, DistanceType::L2).unwrap(); + let codebook_flat = pq.codebook.values().as_primitive::().clone(); + let pq_codebook: ArrayRef = Arc::new(codebook_flat); + let mut pq_params = + PQBuildParams::with_codebook(TWO_FRAG_NUM_SUBVECTORS, TWO_FRAG_NUM_BITS, pq_codebook); + pq_params.max_iters = TWO_FRAG_MAX_ITERS as usize; + pq_params.sample_rate = TWO_FRAG_SAMPLE_RATE; + + (ivf_params, pq_params) + } + + async fn build_ivfpq_for_fragment_groups( + dataset: &mut Dataset, + fragment_groups: Vec>, // each group is a set of fragment ids + ivf_params: &IvfBuildParams, + pq_params: &PQBuildParams, + index_name: &str, + ) { + let shared_uuid = Uuid::new_v4(); + let params = VectorIndexParams::with_ivf_pq_params( + DistanceType::L2, + ivf_params.clone(), + pq_params.clone(), + ); + + for fragments in fragment_groups { + let mut builder = 
dataset.create_index_builder(&["vector"], IndexType::Vector, ¶ms); + builder = builder + .name(index_name.to_string()) + .fragments(fragments) + .index_uuid(shared_uuid.to_string()); + // Build partial index shards without committing to manifest. + builder.execute_uncommitted().await.unwrap(); + } + + let index_dir = dataset.indices_dir().child(shared_uuid.to_string()); + finalize_distributed_merge(dataset.object_store(), &index_dir, Some(IndexType::IvfPq)) + .await + .unwrap(); + + dataset + .commit_existing_index(index_name, "vector", shared_uuid) + .await + .unwrap(); + } + + fn assert_ivf_layout_equal(stats_a: &serde_json::Value, stats_b: &serde_json::Value) { + let idx_a = &stats_a["indices"][0]; + let idx_b = &stats_b["indices"][0]; + + // Centroids: same shape and values (within tolerance). + let centroids_a = idx_a["centroids"] + .as_array() + .expect("centroids should be an array"); + let centroids_b = idx_b["centroids"] + .as_array() + .expect("centroids should be an array"); + assert_eq!( + centroids_a.len(), + centroids_b.len(), + "num centroids mismatch", + ); + for (row_a, row_b) in centroids_a.iter().zip(centroids_b.iter()) { + let row_a = row_a + .as_array() + .unwrap_or_else(|| panic!("invalid centroid row: {:?}", row_a)); + let row_b = row_b + .as_array() + .unwrap_or_else(|| panic!("invalid centroid row: {:?}", row_b)); + assert_eq!(row_a.len(), row_b.len(), "centroid dim mismatch"); + for (va, vb) in row_a.iter().zip(row_b.iter()) { + let fa = va.as_f64().expect("centroid must be numeric") as f32; + let fb = vb.as_f64().expect("centroid must be numeric") as f32; + assert!( + (fa - fb).abs() <= 1e-4, + "centroid mismatch: {} vs {}", + fa, + fb + ); + } + } + + // Partitions sizes. 
+ let parts_a = idx_a["partitions"] + .as_array() + .expect("partitions should be an array"); + let parts_b = idx_b["partitions"] + .as_array() + .expect("partitions should be an array"); + assert_eq!(parts_a.len(), parts_b.len(), "num partitions mismatch"); + let sizes_a: Vec = parts_a + .iter() + .map(|p| p["size"].as_u64().expect("partition size")) + .collect(); + let sizes_b: Vec = parts_b + .iter() + .map(|p| p["size"].as_u64().expect("partition size")) + .collect(); + assert_eq!(sizes_a, sizes_b, "partition sizes mismatch"); + } + + #[tokio::test] + async fn test_ivfpq_recall_performance_on_two_frags_single_vs_split() { + const INDEX_NAME: &str = "vector_idx"; + + let test_dir = TempStrDir::default(); + let base_uri = test_dir.as_str(); + + // Generate the data once, then write it twice to two independent dataset URIs. + let (schema, batches) = make_two_fragment_batches(); + + let ds_single_uri = format!("{}/single", base_uri); + let ds_split_uri = format!("{}/split", base_uri); + + let mut ds_single = + write_dataset_from_batches(&ds_single_uri, schema.clone(), batches.clone()).await; + let mut ds_split = write_dataset_from_batches(&ds_split_uri, schema, batches).await; + + // Ensure we have at least 2 fragments. + let fragments_single = ds_single.get_fragments(); + assert!( + fragments_single.len() >= 2, + "expected at least 2 fragments in ds_single, got {}", + fragments_single.len() + ); + let fragments_split = ds_split.get_fragments(); + assert!( + fragments_split.len() >= 2, + "expected at least 2 fragments in ds_split, got {}", + fragments_split.len() + ); + + // Pretrain global IVF centroids and PQ codebook. + let (ivf_params, pq_params) = prepare_global_ivf_pq(&ds_single, "vector").await; + + // Build single index using two fragments in one distributed build. 
+ let group_single = vec![ + fragments_single[0].id() as u32, + fragments_single[1].id() as u32, + ]; + build_ivfpq_for_fragment_groups( + &mut ds_single, + vec![group_single], + &ivf_params, + &pq_params, + INDEX_NAME, + ) + .await; + + // Build split index: one fragment per distributed build, then merge. + let group0 = vec![fragments_split[0].id() as u32]; + let group1 = vec![fragments_split[1].id() as u32]; + build_ivfpq_for_fragment_groups( + &mut ds_split, + vec![group0, group1], + &ivf_params, + &pq_params, + INDEX_NAME, + ) + .await; + + // Compare IVF layout via index statistics. + let stats_single_json = ds_single.index_statistics(INDEX_NAME).await.unwrap(); + let stats_split_json = ds_split.index_statistics(INDEX_NAME).await.unwrap(); + let stats_single: serde_json::Value = serde_json::from_str(&stats_single_json).unwrap(); + let stats_split: serde_json::Value = serde_json::from_str(&stats_split_json).unwrap(); + assert_ivf_layout_equal(&stats_single, &stats_split); + + // Compare row id sets per partition. + let ctx_single = load_vector_index_context(&ds_single, "vector", INDEX_NAME).await; + let ctx_split = load_vector_index_context(&ds_split, "vector", INDEX_NAME).await; + + let ivf_single = ctx_single.ivf(); + let ivf_split = ctx_split.ivf(); + let total_partitions = ivf_single.total_partitions(); + assert_eq!(total_partitions, ivf_split.total_partitions()); + + for part_id in 0..total_partitions { + let row_ids_single = load_partition_row_ids(ivf_single, part_id).await; + let row_ids_split = load_partition_row_ids(ivf_split, part_id).await; + let set_single: HashSet = row_ids_single.into_iter().collect(); + let set_split: HashSet = row_ids_split.into_iter().collect(); + assert_eq!( + set_single, set_split, + "row id set mismatch for partition {}", + part_id + ); + } + + // Compare Top-K row ids on a deterministic set of queries. 
+ const K: usize = 10; + const NUM_QUERIES: usize = 10; + + async fn collect_row_ids(ds: &Dataset, queries: &[Arc]) -> Vec> { + let mut ids_per_query = Vec::with_capacity(queries.len()); + for q in queries { + let result = ds + .scan() + .with_row_id() + .project(&["_rowid"] as &[&str]) + .unwrap() + .nearest("vector", q.as_ref(), K) + .unwrap() + .try_into_batch() + .await + .unwrap(); + + let row_ids = result[ROW_ID] + .as_primitive::() + .values() + .iter() + .copied() + .collect::>(); + ids_per_query.push(row_ids); + } + ids_per_query + } + + // Collect a deterministic query set from ds_single. + let query_batch = ds_single + .scan() + .project(&["vector"] as &[&str]) + .unwrap() + .limit(Some(NUM_QUERIES as i64), None) + .unwrap() + .try_into_batch() + .await + .unwrap(); + let vectors = query_batch["vector"].as_fixed_size_list(); + let queries: Vec> = (0..vectors.len()) + .map(|i| vectors.value(i) as Arc) + .collect(); + + let ids_single = collect_row_ids(&ds_single, &queries).await; + let ids_split = collect_row_ids(&ds_split, &queries).await; + + assert_eq!( + ids_single, ids_split, + "single vs split index returned different Top-K row ids", + ); + } + async fn test_index( params: VectorIndexParams, nlist: usize,