diff --git a/rust/lance-index/src/vector/v3/shuffler.rs b/rust/lance-index/src/vector/v3/shuffler.rs index c1d74812b85..595c0a0c736 100644 --- a/rust/lance-index/src/vector/v3/shuffler.rs +++ b/rust/lance-index/src/vector/v3/shuffler.rs @@ -9,8 +9,7 @@ use std::sync::Arc; use arrow::{array::AsArray, compute::sort_to_indices}; use arrow_array::{RecordBatch, UInt32Array}; use arrow_schema::Schema; -use future::try_join_all; -use futures::prelude::*; +use futures::{future::try_join_all, prelude::*}; use lance_arrow::{RecordBatchExt, SchemaExt}; use lance_core::{ cache::LanceCache, @@ -69,7 +68,6 @@ pub struct IvfShuffler { num_partitions: usize, // options - buffer_size: usize, precomputed_shuffle_buffers: Option>, } @@ -79,16 +77,10 @@ impl IvfShuffler { object_store: Arc::new(ObjectStore::local()), output_dir, num_partitions, - buffer_size: 4096, precomputed_shuffle_buffers: None, } } - pub fn with_buffer_size(mut self, buffer_size: usize) -> Self { - self.buffer_size = buffer_size; - self - } - pub fn with_precomputed_shuffle_buffers( mut self, precomputed_shuffle_buffers: Option>, @@ -163,44 +155,20 @@ impl Shuffler for IvfShuffler { }) .buffered(get_num_compute_intensive_cpus()); - // part_id: | 0 | 1 | 3 | - // partition_buffers: |[batch,batch,..]|[batch,batch,..]|[batch,batch,..]| - let mut partition_buffers = vec![Vec::new(); num_partitions]; - - let mut counter = 0; let mut total_loss = 0.0; while let Some(shuffled) = parallel_sort_stream.next().await { let (shuffled, loss) = shuffled?; total_loss += loss; - for (part_id, batches) in shuffled.into_iter().enumerate() { - let part_batches = &mut partition_buffers[part_id]; - part_batches.extend(batches); - } - - counter += 1; - - // do flush - if counter % self.buffer_size == 0 { - let mut futs = vec![]; - for (part_id, writer) in writers.iter_mut().enumerate() { - let batches = &partition_buffers[part_id]; + let mut futs = Vec::new(); + for (part_id, (writer, batches)) in writers.iter_mut().zip(shuffled.iter()).enumerate() + { + if !batches.is_empty() { partition_sizes[part_id] += batches.iter().map(|b| b.num_rows()).sum::(); futs.push(writer.write_batches(batches.iter())); } - try_join_all(futs).await?; - - partition_buffers.iter_mut().for_each(|b| b.clear()); - } - } - - // final flush - for (part_id, batches) in partition_buffers.into_iter().enumerate() { - let writer = &mut writers[part_id]; - partition_sizes[part_id] += batches.iter().map(|b| b.num_rows()).sum::(); - for batch in batches.iter() { - writer.write_batch(batch).await?; } + try_join_all(futs).await?; } // finish all writers