From 126c88c2d5334618deeebd62eb9133c9329b86f6 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Mon, 3 Apr 2023 16:56:52 +0100 Subject: [PATCH 1/2] Use SortPreservingMerge for in memory sort --- .../core/src/physical_plan/sorts/sort.rs | 336 ++---------------- 1 file changed, 39 insertions(+), 297 deletions(-) diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs index c3fc06206ca15..0edda59ac8f45 100644 --- a/datafusion/core/src/physical_plan/sorts/sort.rs +++ b/datafusion/core/src/physical_plan/sorts/sort.rs @@ -25,7 +25,7 @@ use crate::execution::memory_pool::{ human_readable_size, MemoryConsumer, MemoryReservation, }; use crate::execution::runtime_env::RuntimeEnv; -use crate::physical_plan::common::{batch_byte_size, IPCWriter, SizedRecordBatchStream}; +use crate::physical_plan::common::{batch_byte_size, IPCWriter}; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::metrics::{ BaselineMetrics, CompositeMetricsSet, MemTrackingMetrics, MetricsSet, @@ -35,28 +35,27 @@ use crate::physical_plan::sorts::SortedStream; use crate::physical_plan::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; use crate::physical_plan::{ DisplayFormatType, Distribution, EmptyRecordBatchStream, ExecutionPlan, Partitioning, - RecordBatchStream, SendableRecordBatchStream, Statistics, + SendableRecordBatchStream, Statistics, }; use crate::prelude::SessionConfig; -use arrow::array::{make_array, Array, ArrayRef, MutableArrayData}; +use arrow::array::ArrayRef; pub use arrow::compute::SortOptions; -use arrow::compute::{concat, lexsort_to_indices, take, SortColumn, TakeOptions}; +use arrow::compute::{lexsort_to_indices, take, SortColumn, TakeOptions}; use arrow::datatypes::SchemaRef; use arrow::error::ArrowError; use arrow::ipc::reader::FileReader; use arrow::record_batch::RecordBatch; use datafusion_physical_expr::EquivalenceProperties; -use futures::{Stream, StreamExt, TryStreamExt}; +use futures::{StreamExt, TryStreamExt}; use log::{debug, error}; use std::any::Any; -use std::cmp::{min, Ordering}; +use std::cmp::Ordering; use std::fmt; use std::fmt::{Debug, Formatter}; use std::fs::File; use std::io::BufReader; use std::path::{Path, PathBuf}; use std::sync::Arc; -use std::task::{Context, Poll}; use tempfile::NamedTempFile; use tokio::sync::mpsc::{Receiver, Sender}; use tokio::task; @@ -72,7 +71,7 @@ use tokio::task; /// 3. when input is exhausted, merge all in memory batches and spills to get a total order. struct ExternalSorter { schema: SchemaRef, - in_mem_batches: Vec, + in_mem_batches: Vec, spills: Vec, /// Sort expressions expr: Vec, @@ -132,12 +131,13 @@ impl ExternalSorter { // NB timer records time taken on drop, so there are no // calls to `timer.done()` below. let _timer = tracking_metrics.elapsed_compute().timer(); - let partial = sort_batch(input, self.schema.clone(), &self.expr, self.fetch)?; + let sorted_batch = + sort_batch(input, self.schema.clone(), &self.expr, self.fetch)?; // The resulting batch might be smaller (or larger, see #3747) than the input // batch due to either a propagated limit or the re-construction of arrays. So // for being reliable, we need to reflect the memory usage of the partial batch. - let new_size = batch_byte_size(&partial.sorted_batch); + let new_size = batch_byte_size(&sorted_batch); match new_size.cmp(&size) { Ordering::Greater => { // We don't have to call try_grow here, since we have already used the @@ -155,7 +155,7 @@ impl ExternalSorter { } Ordering::Equal => {} } - self.in_mem_batches.push(partial); + self.in_mem_batches.push(sorted_batch); } Ok(()) } @@ -276,279 +276,42 @@ impl Debug for ExternalSorter { /// consume the non-empty `sorted_batches` and do in_mem_sort fn in_mem_partial_sort( - buffered_batches: &mut Vec, + buffered_batches: &mut Vec, schema: SchemaRef, expressions: &[PhysicalSortExpr], batch_size: usize, tracking_metrics: MemTrackingMetrics, - fetch: Option, + _fetch: Option, ) -> Result { - assert_ne!(buffered_batches.len(), 0); - if buffered_batches.len() == 1 { - let result = buffered_batches.pop(); - Ok(Box::pin(SizedRecordBatchStream::new( - schema, - vec![Arc::new(result.unwrap().sorted_batch)], - tracking_metrics, - ))) - } else { - let (sorted_arrays, batches): (Vec>, Vec) = - buffered_batches - .drain(..) - .map(|b| { - let BatchWithSortArray { - sort_arrays, - sorted_batch: batch, - } = b; - (sort_arrays, batch) - }) - .unzip(); - - let sorted_iter = { - // NB timer records time taken on drop, so there are no - // calls to `timer.done()` below. - let _timer = tracking_metrics.elapsed_compute().timer(); - get_sorted_iter(&sorted_arrays, expressions, batch_size, fetch)? - }; - Ok(Box::pin(SortedSizedRecordBatchStream::new( - schema, - batches, - sorted_iter, - tracking_metrics, - ))) + if buffered_batches.len() < 2 { + let batches: Vec<_> = buffered_batches.drain(..).collect(); + return Ok(Box::pin(RecordBatchStreamAdapter::new( + schema.clone(), + futures::stream::iter(batches.into_iter().map(move |batch| { + tracking_metrics.output_rows().add(batch.num_rows()); + Ok(batch) + })), + ))); } -} - -#[derive(Debug, Copy, Clone)] -struct CompositeIndex { - batch_idx: u32, - row_idx: u32, -} -/// Get sorted iterator by sort concatenated `SortColumn`s -fn get_sorted_iter( - sort_arrays: &[Vec], - expr: &[PhysicalSortExpr], - batch_size: usize, - fetch: Option, -) -> Result { - let row_indices = sort_arrays - .iter() - .enumerate() - .flat_map(|(i, arrays)| { - (0..arrays[0].len()).map(move |r| CompositeIndex { - // since we original use UInt32Array to index the combined mono batch, - // component record batches won't overflow as well, - // use u32 here for space efficiency. - batch_idx: i as u32, - row_idx: r as u32, - }) - }) - .collect::>(); - - let sort_columns = expr - .iter() - .enumerate() - .map(|(i, expr)| { - let columns_i = sort_arrays - .iter() - .map(|cs| cs[i].as_ref()) - .collect::>(); - Ok(SortColumn { - values: concat(columns_i.as_slice())?, - options: Some(expr.options), - }) + let streams = buffered_batches + .drain(..) + .map(|batch| { + let s = RecordBatchStreamAdapter::new( + schema.clone(), + futures::stream::once(futures::future::ready(Ok(batch))), + ); + SortedStream::new(Box::pin(s), batch_size) }) - .collect::>>()?; - let indices = lexsort_to_indices(&sort_columns, fetch)?; - - // Calculate composite index based on sorted indices - let row_indices = indices - .values() - .iter() - .map(|i| row_indices[*i as usize]) .collect(); - Ok(SortedIterator::new(row_indices, batch_size)) -} - -struct SortedIterator { - /// Current logical position in the iterator - pos: usize, - /// Sorted composite index of where to find the rows in buffered batches - composite: Vec, - /// Maximum batch size to produce - batch_size: usize, -} - -impl SortedIterator { - fn new(composite: Vec, batch_size: usize) -> Self { - Self { - pos: 0, - composite, - batch_size, - } - } - - fn memory_size(&self) -> usize { - std::mem::size_of_val(self) + std::mem::size_of_val(&self.composite[..]) - } -} - -impl Iterator for SortedIterator { - type Item = Vec; - - /// Emit a max of `batch_size` positions each time - fn next(&mut self) -> Option { - let length = self.composite.len(); - if self.pos >= length { - return None; - } - - let current_size = min(self.batch_size, length - self.pos); - - // Combine adjacent indexes from the same batch to make a slice, - // for more efficient `extend` later. - let mut last_batch_idx = self.composite[self.pos].batch_idx; - let mut indices_in_batch = Vec::with_capacity(current_size); - - let mut slices = vec![]; - for ci in &self.composite[self.pos..self.pos + current_size] { - if ci.batch_idx != last_batch_idx { - group_indices(last_batch_idx, &mut indices_in_batch, &mut slices); - last_batch_idx = ci.batch_idx; - } - indices_in_batch.push(ci.row_idx); - } - - assert!( - !indices_in_batch.is_empty(), - "There should have at least one record in a sort output slice." - ); - group_indices(last_batch_idx, &mut indices_in_batch, &mut slices); - - self.pos += current_size; - Some(slices) - } -} - -/// Group continuous indices into a slice for better `extend` performance -fn group_indices( - batch_idx: u32, - positions: &mut Vec, - output: &mut Vec, -) { - positions.sort_unstable(); - let mut last_pos = 0; - let mut run_length = 0; - for pos in positions.iter() { - if run_length == 0 { - last_pos = *pos; - run_length = 1; - } else if *pos == last_pos + 1 { - run_length += 1; - last_pos = *pos; - } else { - output.push(CompositeSlice { - batch_idx, - start_row_idx: last_pos + 1 - run_length, - len: run_length as usize, - }); - last_pos = *pos; - run_length = 1; - } - } - assert!( - run_length > 0, - "There should have at least one record in a sort output slice." - ); - output.push(CompositeSlice { - batch_idx, - start_row_idx: last_pos + 1 - run_length, - len: run_length as usize, - }); - positions.clear() -} - -/// Stream of sorted record batches -struct SortedSizedRecordBatchStream { - schema: SchemaRef, - batches: Vec, - sorted_iter: SortedIterator, - num_cols: usize, - metrics: MemTrackingMetrics, -} - -impl SortedSizedRecordBatchStream { - /// new - pub fn new( - schema: SchemaRef, - batches: Vec, - sorted_iter: SortedIterator, - mut metrics: MemTrackingMetrics, - ) -> Self { - let size = batches.iter().map(batch_byte_size).sum::() - + sorted_iter.memory_size(); - metrics.init_mem_used(size); - let num_cols = batches[0].num_columns(); - SortedSizedRecordBatchStream { - schema, - batches, - sorted_iter, - num_cols, - metrics, - } - } -} - -impl Stream for SortedSizedRecordBatchStream { - type Item = Result; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - _: &mut Context<'_>, - ) -> Poll> { - match self.sorted_iter.next() { - None => Poll::Ready(None), - Some(slices) => { - let num_rows = slices.iter().map(|s| s.len).sum(); - let output = (0..self.num_cols) - .map(|i| { - let arrays = self - .batches - .iter() - .map(|b| b.column(i).data()) - .collect::>(); - let mut mutable = MutableArrayData::new(arrays, false, num_rows); - for x in slices.iter() { - mutable.extend( - x.batch_idx as usize, - x.start_row_idx as usize, - x.start_row_idx as usize + x.len, - ); - } - make_array(mutable.freeze()) - }) - .collect::>(); - let batch = - RecordBatch::try_new(self.schema.clone(), output).map_err(Into::into); - let poll = Poll::Ready(Some(batch)); - self.metrics.record_poll(poll) - } - } - } -} - -struct CompositeSlice { - batch_idx: u32, - start_row_idx: u32, - len: usize, -} - -impl RecordBatchStream for SortedSizedRecordBatchStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } + Ok(Box::pin(SortPreservingMergeStream::new_from_streams( + streams, + schema, + expressions, + tracking_metrics, + batch_size, + )?)) } async fn spill_partial_sorted_stream( @@ -808,17 +571,12 @@ impl ExecutionPlan for SortExec { } } -struct BatchWithSortArray { - sort_arrays: Vec, - sorted_batch: RecordBatch, -} - fn sort_batch( batch: RecordBatch, schema: SchemaRef, expr: &[PhysicalSortExpr], fetch: Option, -) -> Result { +) -> Result { let sort_columns = expr .iter() .map(|e| e.evaluate_to_sort_column(&batch)) @@ -846,23 +604,7 @@ fn sort_batch( .collect::, ArrowError>>()?, )?; - let sort_arrays = sort_columns - .into_iter() - .map(|sc| { - Ok(take( - sc.values.as_ref(), - &indices, - Some(TakeOptions { - check_bounds: false, - }), - )?) - }) - .collect::>>()?; - - Ok(BatchWithSortArray { - sort_arrays, - sorted_batch, - }) + Ok(sorted_batch) } async fn do_sort( From fc56bfffea1795c49a58759ecc8a0fc5a5476ebd Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Mon, 3 Apr 2023 18:52:31 +0100 Subject: [PATCH 2/2] Generic SortPreservingMerge --- .../core/src/physical_plan/sorts/sort.rs | 14 +- .../sorts/sort_preserving_merge.rs | 386 ++++++++++++------ 2 files changed, 266 insertions(+), 134 deletions(-) diff --git a/datafusion/core/src/physical_plan/sorts/sort.rs b/datafusion/core/src/physical_plan/sorts/sort.rs index 0edda59ac8f45..0fb15e8cf2c00 100644 --- a/datafusion/core/src/physical_plan/sorts/sort.rs +++ b/datafusion/core/src/physical_plan/sorts/sort.rs @@ -30,7 +30,7 @@ use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::metrics::{ BaselineMetrics, CompositeMetricsSet, MemTrackingMetrics, MetricsSet, }; -use crate::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeStream; +use crate::physical_plan::sorts::sort_preserving_merge::streaming_merge; use crate::physical_plan::sorts::SortedStream; use crate::physical_plan::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; use crate::physical_plan::{ @@ -193,13 +193,13 @@ impl ExternalSorter { let tracking_metrics = self .metrics_set .new_final_tracking(self.partition_id, &self.runtime.memory_pool); - Ok(Box::pin(SortPreservingMergeStream::new_from_streams( + streaming_merge( streams, self.schema.clone(), &self.expr, tracking_metrics, self.session_config.batch_size(), - )?)) + ) } else if !self.in_mem_batches.is_empty() { let tracking_metrics = self .metrics_set @@ -305,13 +305,7 @@ fn in_mem_partial_sort( }) .collect(); - Ok(Box::pin(SortPreservingMergeStream::new_from_streams( - streams, - schema, - expressions, - tracking_metrics, - batch_size, - )?)) + streaming_merge(streams, schema, expressions, tracking_metrics, batch_size) } async fn spill_partial_sorted_stream( diff --git a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs index 7ef4d3bf8e868..a68b45a1e9022 100644 --- a/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs +++ b/datafusion/core/src/physical_plan/sorts/sort_preserving_merge.rs @@ -18,11 +18,14 @@ //! Defines the sort preserving merge plan use std::any::Any; +use std::cmp::Ordering; use std::collections::VecDeque; -use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; +use arrow::array::{ArrayRef, AsArray, PrimitiveArray}; +use arrow::datatypes::{ArrowNativeTypeOp, ArrowPrimitiveType, DataType, Int64Type}; +use arrow::error::ArrowError; use arrow::row::{RowConverter, SortField}; use arrow::{ array::{make_array as make_arrow_array, MutableArrayData}, @@ -30,7 +33,7 @@ use arrow::{ record_batch::RecordBatch, }; use futures::stream::{Fuse, FusedStream}; -use futures::{ready, Stream, StreamExt}; +use futures::{ready, StreamExt}; use log::debug; use tokio::sync::mpsc; @@ -40,11 +43,11 @@ use crate::physical_plan::metrics::{ ExecutionPlanMetricsSet, MemTrackingMetrics, MetricsSet, }; use crate::physical_plan::sorts::{RowIndex, SortKeyCursor, SortedStream}; -use crate::physical_plan::stream::RecordBatchReceiverStream; +use crate::physical_plan::stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter}; use crate::physical_plan::{ common::spawn_execution, expressions::PhysicalSortExpr, DisplayFormatType, - Distribution, ExecutionPlan, Partitioning, PhysicalExpr, RecordBatchStream, - SendableRecordBatchStream, Statistics, + Distribution, ExecutionPlan, Partitioning, PhysicalExpr, SendableRecordBatchStream, + Statistics, }; use datafusion_physical_expr::EquivalenceProperties; @@ -221,13 +224,13 @@ impl ExecutionPlan for SortPreservingMergeExec { debug!("Done setting up sender-receiver for SortPreservingMergeExec::execute"); - let result = Box::pin(SortPreservingMergeStream::new_from_streams( + let result = streaming_merge( receivers, schema, &self.expr, tracking_metrics, context.session_config().batch_size(), - )?); + )?; debug!("Got stream result from SortPreservingMergeStream::new_from_receivers"); @@ -286,8 +289,232 @@ impl MergingStreams { } } +/// Performs a streaming merge of the input +pub(crate) fn streaming_merge( + streams: Vec, + schema: SchemaRef, + expressions: &[PhysicalSortExpr], + mut tracking_metrics: MemTrackingMetrics, + batch_size: usize, +) -> Result { + let stream_count = streams.len(); + let batches = (0..stream_count).map(|_| VecDeque::new()).collect(); + tracking_metrics.init_mem_used(streams.iter().map(|s| s.mem_used).sum()); + let wrappers = streams.into_iter().map(|s| s.stream.fuse()).collect(); + + let mut stream = SortPreservingMergeStream { + schema: schema.clone(), + batches, + streams: MergingStreams::new(wrappers), + column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), + tracking_metrics, + aborted: false, + in_progress: Vec::with_capacity(batch_size), + next_batch_id: 0, + loser_tree: Vec::with_capacity(stream_count), + loser_tree_adjusted: false, + batch_size, + }; + + if expressions.len() == 1 { + match expressions[0].expr.data_type(&schema)? { + DataType::Int64 => { + let mut merge = PrimitiveMerge:: { + cursors: (0..stream_count).map(|_| None).collect(), + }; + + return Ok(Box::pin(RecordBatchStreamAdapter::new( + schema, + futures::stream::poll_fn(move |cx| { + let poll = stream.poll_next(cx, &mut merge); + stream.tracking_metrics.record_poll(poll) + }), + ))); + } + _ => {} + } + } + + let sort_fields = expressions + .iter() + .map(|expr| { + let data_type = expr.expr.data_type(&schema)?; + Ok(SortField::new_with_options(data_type, expr.options)) + }) + .collect::>>()?; + let row_converter = RowConverter::new(sort_fields)?; + + let mut merge = RowMerge { + cursors: (0..stream_count).map(|_| None).collect(), + row_converter, + }; + + Ok(Box::pin(RecordBatchStreamAdapter::new( + schema, + futures::stream::poll_fn(move |cx| { + let poll = stream.poll_next(cx, &mut merge); + stream.tracking_metrics.record_poll(poll) + }), + ))) +} + +trait Merge { + fn is_finished(&self, stream_idx: usize) -> bool; + + fn push( + &mut self, + stream_idx: usize, + batch_idx: usize, + cols: Vec, + ) -> Result<(), ArrowError>; + + fn challenger_win(&self, winner: usize, challenger: usize) -> bool; + + fn advance(&mut self, stream_idx: usize) -> Option; +} + +struct RowMerge { + /// Vector that holds all [`SortKeyCursor`]s + cursors: Vec>, + + /// row converter + row_converter: RowConverter, +} + +impl Merge for RowMerge { + fn is_finished(&self, stream_idx: usize) -> bool { + self.cursors[stream_idx] + .as_ref() + .map(|cursor| cursor.is_finished()) + .unwrap_or(true) + } + + fn push( + &mut self, + stream_idx: usize, + batch_idx: usize, + cols: Vec, + ) -> Result<(), ArrowError> { + let rows = self.row_converter.convert_columns(&cols)?; + self.cursors[stream_idx] = Some(SortKeyCursor::new(stream_idx, batch_idx, rows)); + Ok(()) + } + + fn challenger_win(&self, winner: usize, challenger: usize) -> bool { + match (&self.cursors[winner], &self.cursors[challenger]) { + (None, _) => true, + (_, None) => false, + (Some(winner), Some(challenger)) => challenger < winner, + } + } + + fn advance(&mut self, stream_idx: usize) -> Option { + self.cursors[stream_idx] + .as_mut() + .filter(|cursor| !cursor.is_finished()) + .map(|cursor| cursor.advance()) + } +} + +struct PrimitiveCursor { + array: PrimitiveArray, + row_idx: usize, + stream_idx: usize, +} + +impl PrimitiveCursor { + fn is_finished(&self) -> bool { + self.array.len() == self.row_idx + } + + fn advance(&mut self) -> usize { + assert!(!self.is_finished()); + let t = self.row_idx; + self.row_idx += 1; + t + } + + fn current(&self) -> T::Native { + // TODO: Handle nulls + self.array.values()[self.row_idx] + } +} + +impl PartialEq for PrimitiveCursor { + fn eq(&self, other: &Self) -> bool { + self.current() == other.current() + } +} + +impl Eq for PrimitiveCursor {} + +impl PartialOrd for PrimitiveCursor { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for PrimitiveCursor { + fn cmp(&self, other: &Self) -> Ordering { + // Order finished cursors greater (last) + match (self.is_finished(), other.is_finished()) { + (true, true) => Ordering::Equal, + (_, true) => Ordering::Less, + (true, _) => Ordering::Greater, + _ => self + .current() + .compare(other.current()) + .then_with(|| self.stream_idx.cmp(&other.stream_idx)), + } + } +} + +struct PrimitiveMerge { + cursors: Vec>>, +} + +impl Merge for PrimitiveMerge { + fn is_finished(&self, stream_idx: usize) -> bool { + self.cursors[stream_idx] + .as_ref() + .map(|cursor| cursor.is_finished()) + .unwrap_or(true) + } + + fn push( + &mut self, + stream_idx: usize, + batch_idx: usize, + cols: Vec, + ) -> Result<(), ArrowError> { + assert_eq!(cols.len(), 1); + let array = cols[0].as_primitive().clone(); + self.cursors[stream_idx] = Some(PrimitiveCursor { + array, + row_idx: 0, + stream_idx, + }); + Ok(()) + } + + fn challenger_win(&self, winner: usize, challenger: usize) -> bool { + match (&self.cursors[winner], &self.cursors[challenger]) { + (None, _) => true, + (_, None) => false, + (Some(winner), Some(challenger)) => challenger < winner, + } + } + + fn advance(&mut self, stream_idx: usize) -> Option { + self.cursors[stream_idx] + .as_mut() + .filter(|cursor| !cursor.is_finished()) + .map(|cursor| cursor.advance()) + } +} + #[derive(Debug)] -pub(crate) struct SortPreservingMergeStream { +struct SortPreservingMergeStream { /// The schema of the RecordBatches yielded by this stream schema: SchemaRef, @@ -315,9 +542,6 @@ pub(crate) struct SortPreservingMergeStream { /// An id to uniquely identify the input stream batch next_batch_id: usize, - /// Vector that holds all [`SortKeyCursor`]s - cursors: Vec>, - /// A loser tree that always produces the minimum cursor /// /// Node 0 stores the top winner, Nodes 1..num_streams store @@ -340,63 +564,19 @@ pub(crate) struct SortPreservingMergeStream { /// target batch size batch_size: usize, - - /// row converter - row_converter: RowConverter, } impl SortPreservingMergeStream { - pub(crate) fn new_from_streams( - streams: Vec, - schema: SchemaRef, - expressions: &[PhysicalSortExpr], - mut tracking_metrics: MemTrackingMetrics, - batch_size: usize, - ) -> Result { - let stream_count = streams.len(); - let batches = (0..stream_count).map(|_| VecDeque::new()).collect(); - tracking_metrics.init_mem_used(streams.iter().map(|s| s.mem_used).sum()); - let wrappers = streams.into_iter().map(|s| s.stream.fuse()).collect(); - - let sort_fields = expressions - .iter() - .map(|expr| { - let data_type = expr.expr.data_type(&schema)?; - Ok(SortField::new_with_options(data_type, expr.options)) - }) - .collect::>>()?; - let row_converter = RowConverter::new(sort_fields)?; - - Ok(Self { - schema, - batches, - streams: MergingStreams::new(wrappers), - column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), - tracking_metrics, - aborted: false, - in_progress: vec![], - next_batch_id: 0, - cursors: (0..stream_count).map(|_| None).collect(), - loser_tree: Vec::with_capacity(stream_count), - loser_tree_adjusted: false, - batch_size, - row_converter, - }) - } - /// If the stream at the given index is not exhausted, and the last cursor for the /// stream is finished, poll the stream for the next RecordBatch and create a new /// cursor for the stream from the returned result - fn maybe_poll_stream( + fn maybe_poll_stream( &mut self, cx: &mut Context<'_>, idx: usize, + merge: &mut M, ) -> Poll> { - if self.cursors[idx] - .as_ref() - .map(|cursor| !cursor.is_finished()) - .unwrap_or(false) - { + if !merge.is_finished(idx) { // Cursor is not finished - don't need a new RecordBatch yet return Poll::Ready(Ok(())); } @@ -423,18 +603,9 @@ impl SortPreservingMergeStream { }) .collect::>>()?; - let rows = match self.row_converter.convert_columns(&cols) { - Ok(rows) => rows, - Err(e) => { - return Poll::Ready(Err(DataFusionError::ArrowError(e))); - } - }; - - self.cursors[idx] = Some(SortKeyCursor::new( - idx, - self.next_batch_id, // assign this batch an ID - rows, - )); + if let Err(e) = merge.push(idx, self.next_batch_id, cols) { + return Poll::Ready(Err(DataFusionError::ArrowError(e))); + } self.next_batch_id += 1; self.batches[idx].push_back(batch) } else { @@ -445,7 +616,7 @@ impl SortPreservingMergeStream { } if empty_batch { - self.maybe_poll_stream(cx, idx) + self.maybe_poll_stream(cx, idx, merge) } else { Poll::Ready(Ok(())) } @@ -538,31 +709,18 @@ impl SortPreservingMergeStream { RecordBatch::try_new(self.schema.clone(), columns).map_err(Into::into) } -} - -impl Stream for SortPreservingMergeStream { - type Item = Result; - fn poll_next( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { - let poll = self.poll_next_inner(cx); - self.tracking_metrics.record_poll(poll) - } -} - -impl SortPreservingMergeStream { #[inline] - fn poll_next_inner( - self: &mut Pin<&mut Self>, + fn poll_next( + &mut self, cx: &mut Context<'_>, + merge: &mut M, ) -> Poll>> { if self.aborted { return Poll::Ready(None); } // try to initialize the loser tree - if let Err(e) = ready!(self.init_loser_tree(cx)) { + if let Err(e) = ready!(self.init_loser_tree(cx, merge)) { return Poll::Ready(Some(Err(e))); } @@ -573,17 +731,14 @@ impl SortPreservingMergeStream { loop { // Adjust the loser tree if necessary, returning control if needed - if let Err(e) = ready!(self.update_loser_tree(cx)) { + if let Err(e) = ready!(self.update_loser_tree(cx, merge)) { return Poll::Ready(Some(Err(e))); } - let min_cursor_idx = self.loser_tree[0]; - let next = self.cursors[min_cursor_idx] - .as_mut() - .filter(|cursor| !cursor.is_finished()) - .map(|cursor| (cursor.stream_idx(), cursor.advance())); + let stream_idx = self.loser_tree[0]; + let next = merge.advance(stream_idx); - if let Some((stream_idx, row_idx)) = next { + if let Some(row_idx) = next { self.loser_tree_adjusted = false; let batch_idx = self.batches[stream_idx].len() - 1; self.in_progress.push(RowIndex { @@ -610,9 +765,10 @@ impl SortPreservingMergeStream { /// * Poll::Ready(Ok()) on success /// * Poll::Ready(Err..) if any of the inputs errored #[inline] - fn init_loser_tree( - self: &mut Pin<&mut Self>, + fn init_loser_tree( + &mut self, cx: &mut Context<'_>, + merge: &mut M, ) -> Poll> { let num_streams = self.streams.num_streams(); @@ -623,7 +779,7 @@ impl SortPreservingMergeStream { // Ensure all non-exhausted streams have a cursor from which // rows can be pulled for i in 0..num_streams { - if let Err(e) = ready!(self.maybe_poll_stream(cx, i)) { + if let Err(e) = ready!(self.maybe_poll_stream(cx, i, merge)) { self.aborted = true; return Poll::Ready(Err(e)); } @@ -636,14 +792,7 @@ impl SortPreservingMergeStream { let mut cmp_node = (num_streams + i) / 2; while cmp_node != 0 && self.loser_tree[cmp_node] != usize::MAX { let challenger = self.loser_tree[cmp_node]; - let challenger_win = - match (&self.cursors[winner], &self.cursors[challenger]) { - (None, _) => true, - (_, None) => false, - (Some(winner), Some(challenger)) => challenger < winner, - }; - - if challenger_win { + if merge.challenger_win(winner, challenger) { self.loser_tree[cmp_node] = winner; winner = challenger; } @@ -663,9 +812,10 @@ impl SortPreservingMergeStream { /// * Poll::Ready(Ok()) on success /// * Poll::Ready(Err..) if any of the winning input erroed #[inline] - fn update_loser_tree( - self: &mut Pin<&mut Self>, + fn update_loser_tree( + &mut self, cx: &mut Context<'_>, + merge: &mut M, ) -> Poll> { if self.loser_tree_adjusted { return Poll::Ready(Ok(())); @@ -673,7 +823,7 @@ impl SortPreservingMergeStream { let num_streams = self.streams.num_streams(); let mut winner = self.loser_tree[0]; - if let Err(e) = ready!(self.maybe_poll_stream(cx, winner)) { + if let Err(e) = ready!(self.maybe_poll_stream(cx, winner, merge)) { self.aborted = true; return Poll::Ready(Err(e)); } @@ -682,13 +832,7 @@ impl SortPreservingMergeStream { let mut cmp_node = (num_streams + winner) / 2; while cmp_node != 0 { let challenger = self.loser_tree[cmp_node]; - let challenger_win = match (&self.cursors[winner], &self.cursors[challenger]) - { - (None, _) => true, - (_, None) => false, - (Some(winner), Some(challenger)) => challenger < winner, - }; - if challenger_win { + if merge.challenger_win(winner, challenger) { self.loser_tree[cmp_node] = winner; winner = challenger; } @@ -700,12 +844,6 @@ impl SortPreservingMergeStream { } } -impl RecordBatchStream for SortPreservingMergeStream { - fn schema(&self) -> SchemaRef { - self.schema.clone() - } -} - #[cfg(test)] mod tests { use std::iter::FromIterator; @@ -1297,7 +1435,7 @@ mod tests { let tracking_metrics = MemTrackingMetrics::new(&metrics, task_ctx.memory_pool(), 0); - let merge_stream = SortPreservingMergeStream::new_from_streams( + let merge_stream = streaming_merge( streams, batches.schema(), sort.as_slice(), @@ -1306,7 +1444,7 @@ mod tests { ) .unwrap(); - let mut merged = common::collect(Box::pin(merge_stream)).await.unwrap(); + let mut merged = common::collect(merge_stream).await.unwrap(); assert_eq!(merged.len(), 1); let merged = merged.remove(0);