diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py index ee7f955c0f7..c1b8db9ad86 100644 --- a/python/python/tests/test_vector_index.py +++ b/python/python/tests/test_vector_index.py @@ -184,6 +184,28 @@ def test_ann(indexed_dataset): run(indexed_dataset) +def test_distributed_ivf_pq_partition_window_env_override(tmp_path, monkeypatch): + # Keep this before other distributed vector merge tests so the process-level + # lazy window size initialization reads this override. + monkeypatch.setenv("LANCE_IVF_PQ_MERGE_PARTITION_WINDOW_SIZE", "4") + monkeypatch.setenv("LANCE_IVF_PQ_MERGE_PARTITION_PREFETCH_WINDOW_COUNT", "2") + + data = create_table(nvec=3000, ndim=128) + q = np.random.randn(128).astype(np.float32) + assert_distributed_vector_consistency( + data, + "vector", + index_type="IVF_PQ", + index_params={"num_partitions": 10, "num_sub_vectors": 16}, + queries=[q], + topk=10, + world=2, + tmp_path=tmp_path, + similarity_metric="recall", + similarity_threshold=0.80, + ) + + @pytest.mark.parametrize( "fixture_name,index_type,index_params,similarity_threshold", [ diff --git a/rust/lance-index/src/vector/distributed/index_merger.rs b/rust/lance-index/src/vector/distributed/index_merger.rs index dd604adb138..d0a9711bd18 100755 --- a/rust/lance-index/src/vector/distributed/index_merger.rs +++ b/rust/lance-index/src/vector/distributed/index_merger.rs @@ -37,6 +37,27 @@ use lance_io::scheduler::{ScanScheduler, SchedulerConfig}; use lance_io::utils::CachedFileSize; use lance_linalg::distance::DistanceType; use prost::Message; +use std::future::Future; +use std::pin::Pin; +use std::sync::LazyLock; + +const DEFAULT_PARTITION_WINDOW_SIZE: usize = 512; +const PARTITION_WINDOW_SIZE_ENV: &str = "LANCE_IVF_PQ_MERGE_PARTITION_WINDOW_SIZE"; +const DEFAULT_PARTITION_PREFETCH_WINDOW_COUNT: usize = 2; +const PARTITION_PREFETCH_WINDOW_COUNT_ENV: &str = + "LANCE_IVF_PQ_MERGE_PARTITION_PREFETCH_WINDOW_COUNT"; +static PARTITION_WINDOW_SIZE: LazyLock = LazyLock::new(|| { + std::env::var(PARTITION_WINDOW_SIZE_ENV) + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(DEFAULT_PARTITION_WINDOW_SIZE) +}); +static PARTITION_PREFETCH_WINDOW_COUNT: LazyLock = LazyLock::new(|| { + std::env::var(PARTITION_PREFETCH_WINDOW_COUNT_ENV) + .ok() + .and_then(|v| v.parse::().ok()) + .unwrap_or(DEFAULT_PARTITION_PREFETCH_WINDOW_COUNT) +}); /// Strict bitwise equality check for FixedSizeListArray values. /// Returns true only if length, value_length and all underlying primitive values are equal. @@ -457,6 +478,253 @@ async fn compute_shard_content_key( Ok((min_fragment_id, min_row_id, parent_name)) } +#[derive(Debug)] +struct ShardInfo { + reader: Arc, + lengths: Vec, + partition_offsets: Vec, + total_rows: usize, + sort_key: (u32, u64, String), +} + +#[derive(Debug)] +struct ShardWindowReadJob { + reader: Arc, + window_lengths: Vec, + window_total_rows: usize, + start_offset: usize, + end_offset: usize, +} + +#[derive(Debug)] +struct PartitionWindowBatches { + window_start: usize, + per_partition_batches: Vec>, +} + +type PartitionWindowFuture = Pin> + Send>>; + +struct ShardMergeReader { + shard_infos: Arc>, + nlist: usize, + partition_window_size: usize, + prefetch_window_count: usize, + next_window_start: usize, + in_flight_windows: futures::stream::FuturesOrdered, + current_window: Option, + current_partition_offset: usize, +} + +impl ShardMergeReader { + fn new( + shard_infos: Vec, + nlist: usize, + partition_window_size: usize, + prefetch_window_count: usize, + ) -> Self { + let mut this = Self { + shard_infos: Arc::new(shard_infos), + nlist, + partition_window_size: partition_window_size.max(1), + prefetch_window_count: prefetch_window_count.max(1), + next_window_start: 0, + in_flight_windows: futures::stream::FuturesOrdered::new(), + current_window: None, + current_partition_offset: 0, + }; + this.fill_prefetch(); + this + } + + fn fill_prefetch(&mut self) { + while self.in_flight_windows.len() < self.prefetch_window_count + && self.next_window_start < self.nlist + { + let window_start = self.next_window_start; + let window_end = std::cmp::min(window_start + self.partition_window_size, self.nlist); + self.next_window_start = window_end; + + let shard_infos = Arc::clone(&self.shard_infos); + let nlist = self.nlist; + let fut: PartitionWindowFuture = Box::pin(async move { + read_partition_window(shard_infos, nlist, window_start, window_end).await + }); + self.in_flight_windows.push_back(fut); + } + } + + async fn next_partition(&mut self) -> Result)>> { + loop { + if let Some(window) = self.current_window.as_mut() { + if self.current_partition_offset < window.per_partition_batches.len() { + let partition_id = window.window_start + self.current_partition_offset; + let batches = std::mem::take( + &mut window.per_partition_batches[self.current_partition_offset], + ); + self.current_partition_offset += 1; + if self.current_partition_offset == window.per_partition_batches.len() { + self.current_window = None; + self.current_partition_offset = 0; + } + self.fill_prefetch(); + return Ok(Some((partition_id, batches))); + } + self.current_window = None; + self.current_partition_offset = 0; + continue; + } + + self.fill_prefetch(); + match self.in_flight_windows.next().await { + Some(window) => { + self.current_window = Some(window?); + self.current_partition_offset = 0; + } + None => return Ok(None), + } + } + } +} + +async fn read_partition_window( + shard_infos: Arc>, + nlist: usize, + window_start: usize, + window_end: usize, +) -> Result { + let window_len = window_end - window_start; + + let shard_jobs: Vec = shard_infos + .iter() + .map(|shard| { + let window_lengths = shard.lengths[window_start..window_end].to_vec(); + let window_total_rows = window_lengths.iter().map(|len| *len as usize).sum(); + let start_offset = shard.partition_offsets[window_start]; + let end_offset = if window_end < nlist { + shard.partition_offsets[window_end] + } else { + shard.total_rows + }; + + ShardWindowReadJob { + reader: Arc::clone(&shard.reader), + window_lengths, + window_total_rows, + start_offset, + end_offset, + } + }) + .collect(); + + let shard_parallelism = shard_jobs.len().max(1); + let mut shard_results_stream = futures::stream::iter(shard_jobs.into_iter().enumerate().map( + |(shard_idx, shard_job)| async move { + let per_partition_batches = + read_shard_window_partitions(shard_job, window_start, window_end, window_len) + .await?; + Ok::<(usize, Vec>), Error>((shard_idx, per_partition_batches)) + }, + )) + .buffer_unordered(shard_parallelism); + + let mut shard_results: Vec<(usize, Vec>)> = + Vec::with_capacity(shard_parallelism); + while let Some(shard_result) = shard_results_stream.next().await { + shard_results.push(shard_result?); + } + shard_results.sort_by_key(|(shard_idx, _)| *shard_idx); + + let mut per_partition_batches: Vec> = vec![Vec::new(); window_len]; + for (_, mut shard_partition_batches) in shard_results { + for rel_partition in 0..window_len { + per_partition_batches[rel_partition] + .append(&mut shard_partition_batches[rel_partition]); + } + } + + Ok(PartitionWindowBatches { + window_start, + per_partition_batches, + }) +} + +async fn read_shard_window_partitions( + shard_job: ShardWindowReadJob, + window_start: usize, + window_end: usize, + window_len: usize, +) -> Result>> { + let mut per_partition_batches: Vec> = vec![Vec::new(); window_len]; + if shard_job.window_total_rows == 0 { + return Ok(per_partition_batches); + } + + let mut stream = shard_job.reader.read_stream( + lance_io::ReadBatchParams::Range(shard_job.start_offset..shard_job.end_offset), + u32::MAX, + 4, + lance_encoding::decoder::FilterExpression::no_filter(), + )?; + + let mut rel_partition = 0usize; + while rel_partition < window_len && shard_job.window_lengths[rel_partition] == 0 { + rel_partition += 1; + } + let mut remaining = if rel_partition < window_len { + shard_job.window_lengths[rel_partition] as usize + } else { + 0 + }; + + while let Some(rb) = stream.next().await { + let rb = rb?; + let mut consumed = 0usize; + + while consumed < rb.num_rows() { + while rel_partition < window_len && remaining == 0 { + rel_partition += 1; + if rel_partition < window_len { + remaining = shard_job.window_lengths[rel_partition] as usize; + } + } + + if rel_partition >= window_len { + return Err(Error::Index { + message: format!( + "Shard has more rows than declared lengths in partition window [{}, {})", + window_start, window_end + ), + location: location!(), + }); + } + + let to_take = std::cmp::min(remaining, rb.num_rows() - consumed); + per_partition_batches[rel_partition].push(rb.slice(consumed, to_take)); + consumed += to_take; + remaining -= to_take; + } + } + + while rel_partition < window_len && remaining == 0 { + rel_partition += 1; + if rel_partition < window_len { + remaining = shard_job.window_lengths[rel_partition] as usize; + } + } + + if rel_partition != window_len { + return Err(Error::Index { + message: format!( + "Shard has fewer rows than declared lengths in partition window [{}, {})", + window_start, window_end + ), + location: location!(), + }); + } + + Ok(per_partition_batches) +} + /// Merge all partial_* vector index auxiliary files under `index_dir/{uuid}/partial_*/auxiliary.idx` /// into `index_dir/{uuid}/auxiliary.idx`. /// @@ -545,9 +813,9 @@ pub async fn merge_partial_vector_auxiliary_files( let mut accumulated_lengths: Vec = Vec::new(); let mut first_centroids: Option = None; - // Track per-shard IVF lengths to reorder writing to partitions later - #[allow(clippy::type_complexity)] - let mut shard_infos: Vec<(object_store::path::Path, Vec, (u32, u64, String))> = Vec::new(); + // Track per-shard readers, IVF lengths, and precomputed partition offsets. + // This avoids reopening each shard file for every partition during merge. + let mut shard_infos: Vec = Vec::new(); // Iterate over each shard auxiliary file and merge its metadata and collect lengths for (aux, key) in &shard_keys { @@ -1146,18 +1414,32 @@ pub async fn merge_partial_vector_auxiliary_files( } } - // Collect per-shard lengths to write grouped by partition later - shard_infos.push((aux.clone(), lengths.clone(), key.clone())); - // Accumulate overall lengths per partition for unified IVF model + let mut partition_offsets = Vec::with_capacity(nlist); + let mut running_offset = 0usize; + for len in &lengths { + partition_offsets.push(running_offset); + running_offset = running_offset.saturating_add(*len as usize); + } + + // Accumulate overall lengths per partition for unified IVF model. for pid in 0..nlist { let part_len = lengths[pid]; accumulated_lengths[pid] = accumulated_lengths[pid].saturating_add(part_len); } + + // Keep one opened reader per shard and reuse it during partition merge. + shard_infos.push(ShardInfo { + reader: Arc::new(reader), + lengths, + partition_offsets, + total_rows: running_offset, + sort_key: key.clone(), + }); } // Re-sort shard_infos using content-derived keys to decouple per-partition // write ordering from discovery order. - shard_infos.sort_by(|a, b| a.2.cmp(&b.2)); + shard_infos.sort_by(|a, b| a.sort_key.cmp(&b.sort_key)); // Write rows grouped by partition across all shards to ensure contiguous ranges per partition @@ -1180,46 +1462,28 @@ pub async fn merge_partial_vector_auxiliary_files( SupportedIvfIndexType::IvfPq | SupportedIvfIndexType::IvfHnswPq => { // For PQ-backed indices, transpose PQ codes while merging partitions // so that the unified file stores column-major PQ codes. - for pid in 0..nlist { - let total_len = accumulated_lengths[pid] as usize; - if total_len == 0 { - continue; - } - - let mut part_batches: Vec = Vec::new(); - for (path, lens, _) in shard_infos.iter() { - let part_len = lens[pid] as usize; - if part_len == 0 { - continue; - } - let offset: usize = lens.iter().take(pid).map(|x| *x as usize).sum(); - let fh = sched.open_file(path, &CachedFileSize::unknown()).await?; - let reader = V2Reader::try_open( - fh, - None, - Arc::default(), - &lance_core::cache::LanceCache::no_cache(), - V2ReaderOptions::default(), - ) - .await?; - let mut stream = reader.read_stream( - lance_io::ReadBatchParams::Range(offset..offset + part_len), - u32::MAX, - 4, - lance_encoding::decoder::FilterExpression::no_filter(), - )?; - while let Some(rb) = stream.next().await { - let rb = rb?; - part_batches.push(rb); - } - } + let partition_window_size = *PARTITION_WINDOW_SIZE; + let prefetch_window_count = *PARTITION_PREFETCH_WINDOW_COUNT; + let mut shard_merge_reader = ShardMergeReader::new( + shard_infos, + nlist, + partition_window_size, + prefetch_window_count, + ); - if part_batches.is_empty() { + while let Some((pid, batches)) = shard_merge_reader.next_partition().await? { + if accumulated_lengths[pid] == 0 { continue; } + if batches.is_empty() { + return Err(Error::Index { + message: format!("No merged batches found for non-empty partition {}", pid), + location: location!(), + }); + } - let schema = part_batches[0].schema(); - let partition_batch = concat_batches(&schema, part_batches.iter())?; + let schema = batches[0].schema(); + let partition_batch = concat_batches(&schema, batches.iter())?; if let Some(w) = v2w_opt.as_mut() { write_partition_rows_pq_transposed(w, partition_batch).await?; } @@ -1227,23 +1491,15 @@ pub async fn merge_partial_vector_auxiliary_files( } _ => { for pid in 0..nlist { - for (path, lens, _) in shard_infos.iter() { - let part_len = lens[pid] as usize; + for shard in shard_infos.iter() { + let part_len = shard.lengths[pid] as usize; if part_len == 0 { continue; } - let offset: usize = lens.iter().take(pid).map(|x| *x as usize).sum(); - let fh = sched.open_file(path, &CachedFileSize::unknown()).await?; - let reader = V2Reader::try_open( - fh, - None, - Arc::default(), - &lance_core::cache::LanceCache::no_cache(), - V2ReaderOptions::default(), - ) - .await?; + let offset = shard.partition_offsets[pid]; if let Some(w) = v2w_opt.as_mut() { - write_partition_rows(&reader, w, offset..offset + part_len).await?; + write_partition_rows(shard.reader.as_ref(), w, offset..offset + part_len) + .await?; } } }