Skip to content
5 changes: 3 additions & 2 deletions datafusion/physical-plan/src/sorts/cursor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use std::cmp::Ordering;
use std::sync::Arc;

use arrow::array::{
types::ByteArrayType, Array, ArrowPrimitiveType, GenericByteArray,
Expand Down Expand Up @@ -151,7 +152,7 @@ impl<T: CursorValues> Ord for Cursor<T> {
/// Used for sorting when there are multiple columns in the sort key
#[derive(Debug)]
pub struct RowValues {
rows: Rows,
rows: Arc<Rows>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes sense, thank you @Dandandan !


/// Tracks the memory used by the `Rows` of this
/// cursor. Freed on drop
Expand All @@ -164,7 +165,7 @@ impl RowValues {
///
/// Panics if the reservation is not for exactly `rows.size()`
/// bytes or if `rows` is empty.
pub fn new(rows: Rows, reservation: MemoryReservation) -> Self {
pub fn new(rows: Arc<Rows>, reservation: MemoryReservation) -> Self {
assert_eq!(
rows.size(),
reservation.size(),
Expand Down
69 changes: 63 additions & 6 deletions datafusion/physical-plan/src/sorts/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ use crate::{PhysicalExpr, PhysicalSortExpr};
use arrow::array::Array;
use arrow::datatypes::Schema;
use arrow::record_batch::RecordBatch;
use arrow::row::{RowConverter, SortField};
use datafusion_common::Result;
use arrow::row::{RowConverter, Rows, SortField};
use datafusion_common::{internal_datafusion_err, Result};
use datafusion_execution::memory_pool::MemoryReservation;
use datafusion_physical_expr_common::sort_expr::LexOrdering;
use futures::stream::{Fuse, StreamExt};
Expand Down Expand Up @@ -76,8 +76,40 @@ impl FusedStreams {
}
}

/// A pair of `Arc<Rows>` that can be reused
#[derive(Debug)]
struct ReusableRows {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

// inner[stream_idx] holds two Arcs:
// at the start of a new poll,
// .0 is the rows from the previous poll and
// .1 is the one that is going to be written to;
// at the end of a poll, .0 is swapped with .1,
inner: Vec<[Option<Arc<Rows>>; 2]>,
}

impl ReusableRows {
    /// Takes the `Rows` buffer for `stream_idx` out of slot 1 so it can be
    /// written to again without allocating a new one.
    ///
    /// # Errors
    ///
    /// Returns an internal error if
    /// * slot 1 is empty, i.e. an earlier `take_next` for this stream was
    ///   never followed by a `save`, or
    /// * another `Arc` handle to the buffer is still alive. In that case the
    ///   buffer is put back into slot 1, so this method can be called again
    ///   once the other handle has been dropped.
    ///
    /// # Panics
    ///
    /// Panics if `stream_idx` is out of bounds for `inner`.
    fn take_next(&mut self, stream_idx: usize) -> Result<Rows> {
        let slot = &mut self.inner[stream_idx][1];
        let shared = slot.take().ok_or_else(|| {
            internal_datafusion_err!(
                "Rows for RowCursorStream stream {} were taken but never saved",
                stream_idx
            )
        })?;
        Arc::try_unwrap(shared).map_err(|shared| {
            // Restore the slot: dropping the Arc here would leave it empty
            // and hide the real cause on the next call.
            *slot = Some(shared);
            internal_datafusion_err!(
                "Rows from RowCursorStream is still in use by consumer"
            )
        })
    }

    /// Stores `rows` for `stream_idx` and swaps the two slots.
    ///
    /// `rows` is written to slot 1 and the slots are then swapped, so after
    /// this call slot 0 holds `rows` and slot 1 holds whatever slot 0 held
    /// before; that older buffer is what the next `take_next` hands out.
    ///
    /// # Panics
    ///
    /// Panics if `stream_idx` is out of bounds for `inner`.
    fn save(&mut self, stream_idx: usize, rows: Arc<Rows>) {
        let [previous, current] = &mut self.inner[stream_idx];
        // `rows` is already owned here, so move it instead of cloning the Arc.
        *current = Some(rows);
        // Swap the current with the previous one, so that the next poll can
        // reuse the Rows from the previous poll.
        std::mem::swap(previous, current);
    }
}

/// A [`PartitionedStream`] that wraps a set of [`SendableRecordBatchStream`]
/// and computes [`RowValues`] based on the provided [`PhysicalSortExpr`]
/// Note: the stream returns an error if the consumer buffers more than one RowValues (i.e. holds on to two RowValues
/// from the same partition at the same time).
#[derive(Debug)]
pub struct RowCursorStream {
/// Converter to convert output of physical expressions
Expand All @@ -88,6 +120,9 @@ pub struct RowCursorStream {
streams: FusedStreams,
/// Tracks the memory used by `converter`
reservation: MemoryReservation,
/// Allocated rows for each partition, we keep two to allow for buffering one
/// in the consumer of the stream
rows: ReusableRows,
}

impl RowCursorStream {
Expand All @@ -105,26 +140,48 @@ impl RowCursorStream {
})
.collect::<Result<Vec<_>>>()?;

let streams = streams.into_iter().map(|s| s.fuse()).collect();
let streams: Vec<_> = streams.into_iter().map(|s| s.fuse()).collect();
let converter = RowConverter::new(sort_fields)?;
let mut rows = Vec::with_capacity(streams.len());
for _ in &streams {
// Initialize each stream with an empty Rows
rows.push([
Some(Arc::new(converter.empty_rows(0, 0))),
Some(Arc::new(converter.empty_rows(0, 0))),
]);
}
Ok(Self {
converter,
reservation,
column_expressions: expressions.iter().map(|x| Arc::clone(&x.expr)).collect(),
streams: FusedStreams(streams),
rows: ReusableRows { inner: rows },
})
}

fn convert_batch(&mut self, batch: &RecordBatch) -> Result<RowValues> {
fn convert_batch(
&mut self,
batch: &RecordBatch,
stream_idx: usize,
) -> Result<RowValues> {
let cols = self
.column_expressions
.iter()
.map(|expr| expr.evaluate(batch)?.into_array(batch.num_rows()))
.collect::<Result<Vec<_>>>()?;

let rows = self.converter.convert_columns(&cols)?;
// At this point, ownership of this Rows should be unique
let mut rows = self.rows.take_next(stream_idx)?;

rows.clear();

self.converter.append(&mut rows, &cols)?;
self.reservation.try_resize(self.converter.size())?;

let rows = Arc::new(rows);

self.rows.save(stream_idx, Arc::clone(&rows));

// track the memory in the newly created Rows.
let mut rows_reservation = self.reservation.new_empty();
rows_reservation.try_grow(rows.size())?;
Expand All @@ -146,7 +203,7 @@ impl PartitionedStream for RowCursorStream {
) -> Poll<Option<Self::Output>> {
Poll::Ready(ready!(self.streams.poll_next(cx, stream_idx)).map(|r| {
r.and_then(|batch| {
let cursor = self.convert_batch(&batch)?;
let cursor = self.convert_batch(&batch, stream_idx)?;
Ok((cursor, batch))
})
}))
Expand Down