Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 6 additions & 38 deletions rust/lance-index/src/vector/v3/shuffler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ use std::sync::Arc;
use arrow::{array::AsArray, compute::sort_to_indices};
use arrow_array::{RecordBatch, UInt32Array};
use arrow_schema::Schema;
use future::try_join_all;
use futures::prelude::*;
use futures::{future::try_join_all, prelude::*};
use lance_arrow::{RecordBatchExt, SchemaExt};
use lance_core::{
cache::LanceCache,
Expand Down Expand Up @@ -69,7 +68,6 @@ pub struct IvfShuffler {
num_partitions: usize,

// options
buffer_size: usize,
precomputed_shuffle_buffers: Option<Vec<String>>,
}

Expand All @@ -79,16 +77,10 @@ impl IvfShuffler {
object_store: Arc::new(ObjectStore::local()),
output_dir,
num_partitions,
buffer_size: 4096,
precomputed_shuffle_buffers: None,
}
}

pub fn with_buffer_size(mut self, buffer_size: usize) -> Self {
self.buffer_size = buffer_size;
self
}

pub fn with_precomputed_shuffle_buffers(
mut self,
precomputed_shuffle_buffers: Option<Vec<String>>,
Expand Down Expand Up @@ -163,44 +155,20 @@ impl Shuffler for IvfShuffler {
})
.buffered(get_num_compute_intensive_cpus());

// part_id: | 0 | 1 | 3 |
// partition_buffers: |[batch,batch,..]|[batch,batch,..]|[batch,batch,..]|
let mut partition_buffers = vec![Vec::new(); num_partitions];

let mut counter = 0;
let mut total_loss = 0.0;
while let Some(shuffled) = parallel_sort_stream.next().await {
let (shuffled, loss) = shuffled?;
total_loss += loss;

for (part_id, batches) in shuffled.into_iter().enumerate() {
let part_batches = &mut partition_buffers[part_id];
part_batches.extend(batches);
}

counter += 1;

// do flush
if counter % self.buffer_size == 0 {
let mut futs = vec![];
for (part_id, writer) in writers.iter_mut().enumerate() {
let batches = &partition_buffers[part_id];
let mut futs = Vec::new();
for (part_id, (writer, batches)) in writers.iter_mut().zip(shuffled.iter()).enumerate()
{
if !batches.is_empty() {
partition_sizes[part_id] += batches.iter().map(|b| b.num_rows()).sum::<usize>();
futs.push(writer.write_batches(batches.iter()));
}
try_join_all(futs).await?;

partition_buffers.iter_mut().for_each(|b| b.clear());
}
}

// final flush
for (part_id, batches) in partition_buffers.into_iter().enumerate() {
let writer = &mut writers[part_id];
partition_sizes[part_id] += batches.iter().map(|b| b.num_rows()).sum::<usize>();
for batch in batches.iter() {
writer.write_batch(batch).await?;
}
try_join_all(futs).await?;
}

// finish all writers
Expand Down
Loading