95 changes: 93 additions & 2 deletions datafusion/src/physical_plan/sort_preserving_merge.rs
@@ -17,6 +17,7 @@

//! Defines the sort preserving merge plan

use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
use std::any::Any;
use std::cmp::Ordering;
use std::collections::VecDeque;
@@ -59,6 +60,8 @@ pub struct SortPreservingMergeExec {
expr: Vec<PhysicalSortExpr>,
/// The target size of yielded batches
target_batch_size: usize,
/// Execution metrics
metrics: ExecutionPlanMetricsSet,
}

impl SortPreservingMergeExec {
@@ -72,6 +75,7 @@ impl SortPreservingMergeExec {
input,
expr,
target_batch_size,
metrics: ExecutionPlanMetricsSet::new(),
}
}

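Note on the pattern here: ExecutionPlanMetricsSet is a cheaply clonable, thread-safe registry owned by the operator, and each BaselineMetrics::new(&metrics, partition) call registers a per-partition view into it. A minimal sketch of that pattern (the partition indices and row counts are illustrative, not from this PR):

    // Hedged sketch: two partitions writing into one shared registry
    let metrics = ExecutionPlanMetricsSet::new();
    let partition_0 = BaselineMetrics::new(&metrics, 0);
    let partition_1 = BaselineMetrics::new(&metrics, 1);
    partition_0.output_rows().add(2);
    partition_1.output_rows().add(3);
    // clone_inner() takes a point-in-time snapshot across all partitions
    assert_eq!(metrics.clone_inner().output_rows(), Some(5));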
@@ -134,14 +138,16 @@ impl ExecutionPlan for SortPreservingMergeExec {
)));
}

let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);

let input_partitions = self.input.output_partitioning().partition_count();
match input_partitions {
0 => Err(DataFusionError::Internal(
"SortPreservingMergeExec requires at least one input partition"
.to_owned(),
)),
1 => {
- // bypass if there is only one partition to merge
+ // bypass if there is only one partition to merge (no metrics in this case either)
self.input.execute(0).await
}
_ => {
@@ -159,6 +165,7 @@
self.schema(),
&self.expr,
self.target_batch_size,
baseline_metrics,
)))
}
}
@@ -176,6 +183,10 @@ impl ExecutionPlan for SortPreservingMergeExec {
}
}
}

fn metrics(&self) -> Option<MetricsSet> {
Some(self.metrics.clone_inner())
}
}
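With the metrics() accessor in place, any caller that keeps a reference to the operator can inspect what execution recorded, without touching the stream itself. A brief usage sketch (assuming a merge: Arc<SortPreservingMergeExec> that has already been executed, like the one built in the new test below):

    // Hedged usage sketch; `merge` is assumed from the test below
    let metrics = merge.metrics().unwrap();
    println!("output rows: {:?}", metrics.output_rows());          // Option<usize>
    println!("elapsed compute (ns): {:?}", metrics.elapsed_compute()); // Option<usize>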

/// A `SortKeyCursor` is created from a `RecordBatch`, and a set of
@@ -338,6 +349,8 @@ struct SortPreservingMergeStream {
sort_options: Vec<SortOptions>,
/// The desired RecordBatch size to yield
target_batch_size: usize,
/// Used to record execution metrics
baseline_metrics: BaselineMetrics,
/// If the stream has encountered an error
aborted: bool,

@@ -351,6 +364,7 @@ impl SortPreservingMergeStream {
schema: SchemaRef,
expressions: &[PhysicalSortExpr],
target_batch_size: usize,
baseline_metrics: BaselineMetrics,
) -> Self {
let cursors = (0..streams.len())
.into_iter()
@@ -364,6 +378,7 @@
column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(),
sort_options: expressions.iter().map(|x| x.options).collect(),
target_batch_size,
baseline_metrics,
aborted: false,
in_progress: vec![],
next_batch_index: 0,
@@ -390,7 +405,7 @@ impl SortPreservingMergeStream {
return Poll::Ready(Ok(()));
}

- // Fetch a new record and create a cursor from it
+ // Fetch a new input record and create a cursor from it
match futures::ready!(stream.poll_next_unpin(cx)) {
None => return Poll::Ready(Ok(())),
Some(Err(e)) => {
@@ -539,6 +554,17 @@ impl Stream for SortPreservingMergeStream {
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
let poll = self.poll_next_inner(cx);
self.baseline_metrics.record_poll(poll)
}
}

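The poll_next wrapper above funnels every poll result through BaselineMetrics::record_poll before returning it. Roughly, and as an assumption about that helper's internals rather than code from this PR, it counts the rows of each yielded batch and finalizes the end timestamp once the stream is exhausted:

    // Hedged sketch of what BaselineMetrics::record_poll plausibly does
    fn record_poll(
        &self,
        poll: Poll<Option<ArrowResult<RecordBatch>>>,
    ) -> Poll<Option<ArrowResult<RecordBatch>>> {
        if let Poll::Ready(maybe_batch) = &poll {
            match maybe_batch {
                Some(Ok(batch)) => self.output_rows().add(batch.num_rows()),
                None => self.done(), // records the EndTimestamp
                _ => {}
            }
        }
        poll
    }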
impl SortPreservingMergeStream {
#[inline]
fn poll_next_inner(
self: &mut Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<ArrowResult<RecordBatch>>> {
if self.aborted {
return Poll::Ready(None);
}
@@ -556,6 +582,11 @@ impl Stream for SortPreservingMergeStream {
}

loop {
// NB timer records time taken on drop, so there are no
// calls to `timer.done()` below.
let elapsed_compute = self.baseline_metrics.elapsed_compute().clone();
let _timer = elapsed_compute.timer();

let stream_idx = match self.next_stream_idx() {
Ok(Some(idx)) => idx,
Ok(None) if self.in_progress.is_empty() => return Poll::Ready(None),
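The _timer binding above relies on a drop-based guard: the value returned by timer() starts counting immediately and adds the elapsed nanoseconds to the underlying Time metric when it goes out of scope. A sketch of the pattern in isolation (do_merge_work is a hypothetical stand-in for the body of the loop iteration):

    // Hedged sketch of the scoped-timer pattern used above
    let elapsed_compute = baseline_metrics.elapsed_compute().clone();
    {
        let _timer = elapsed_compute.timer(); // starts timing on creation
        do_merge_work();
    } // guard dropped here; elapsed time is added to the metric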
@@ -607,6 +638,7 @@ impl RecordBatchStream for SortPreservingMergeStream {

#[cfg(test)]
mod tests {
use crate::physical_plan::metrics::MetricValue;
use std::iter::FromIterator;

use crate::arrow::array::{Int32Array, StringArray, TimestampNanosecondArray};
@@ -1149,11 +1181,15 @@ mod tests {
streams.push(receiver);
}

let metrics = ExecutionPlanMetricsSet::new();
let baseline_metrics = BaselineMetrics::new(&metrics, 0);

let merge_stream = SortPreservingMergeStream::new(
streams,
batches.schema(),
sort.as_slice(),
1024,
baseline_metrics,
);

let mut merged = common::collect(Box::pin(merge_stream)).await.unwrap();
@@ -1172,4 +1208,59 @@

assert_eq!(basic, partition);
}

#[tokio::test]
async fn test_merge_metrics() {
let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2]));
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")]));
let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20]));
let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("b"), Some("d")]));
let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap();

let schema = b1.schema();
let sort = vec![PhysicalSortExpr {
expr: col("b", &schema).unwrap(),
options: Default::default(),
}];
let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap();
let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec), 1024));

let collected = collect(merge.clone()).await.unwrap();
let expected = vec![
"+----+---+",
"| a | b |",
"+----+---+",
"| 1 | a |",
"| 10 | b |",
"| 2 | c |",
"| 20 | d |",
"+----+---+",
];
assert_batches_eq!(expected, collected.as_slice());

// Now, validate metrics
let metrics = merge.metrics().unwrap();

assert_eq!(metrics.output_rows().unwrap(), 4);
assert!(metrics.elapsed_compute().unwrap() > 0);

let mut saw_start = false;
let mut saw_end = false;
metrics.iter().for_each(|m| match m.value() {
MetricValue::StartTimestamp(ts) => {
saw_start = true;
assert!(ts.value().unwrap().timestamp_nanos() > 0);
}
MetricValue::EndTimestamp(ts) => {
saw_end = true;
assert!(ts.value().unwrap().timestamp_nanos() > 0);
}
_ => {}
});

assert!(saw_start);
assert!(saw_end);
}
}