diff --git a/datafusion/src/physical_plan/sort_preserving_merge.rs b/datafusion/src/physical_plan/sort_preserving_merge.rs index b4bcc2935e4f4..1bcdd63886b62 100644 --- a/datafusion/src/physical_plan/sort_preserving_merge.rs +++ b/datafusion/src/physical_plan/sort_preserving_merge.rs @@ -17,6 +17,7 @@ //! Defines the sort preserving merge plan +use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use std::any::Any; use std::cmp::Ordering; use std::collections::VecDeque; @@ -59,6 +60,8 @@ pub struct SortPreservingMergeExec { expr: Vec, /// The target size of yielded batches target_batch_size: usize, + /// Execution metrics + metrics: ExecutionPlanMetricsSet, } impl SortPreservingMergeExec { @@ -72,6 +75,7 @@ impl SortPreservingMergeExec { input, expr, target_batch_size, + metrics: ExecutionPlanMetricsSet::new(), } } @@ -134,6 +138,8 @@ impl ExecutionPlan for SortPreservingMergeExec { ))); } + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + let input_partitions = self.input.output_partitioning().partition_count(); match input_partitions { 0 => Err(DataFusionError::Internal( @@ -141,7 +147,7 @@ impl ExecutionPlan for SortPreservingMergeExec { .to_owned(), )), 1 => { - // bypass if there is only one partition to merge + // bypass if there is only one partition to merge (no metrics in this case either) self.input.execute(0).await } _ => { @@ -159,6 +165,7 @@ impl ExecutionPlan for SortPreservingMergeExec { self.schema(), &self.expr, self.target_batch_size, + baseline_metrics, ))) } } @@ -176,6 +183,10 @@ impl ExecutionPlan for SortPreservingMergeExec { } } } + + fn metrics(&self) -> Option { + Some(self.metrics.clone_inner()) + } } /// A `SortKeyCursor` is created from a `RecordBatch`, and a set of @@ -338,6 +349,8 @@ struct SortPreservingMergeStream { sort_options: Vec, /// The desired RecordBatch size to yield target_batch_size: usize, + /// used to record execution metrics + baseline_metrics: BaselineMetrics, /// If the stream has encountered an error aborted: bool, @@ -351,6 +364,7 @@ impl SortPreservingMergeStream { schema: SchemaRef, expressions: &[PhysicalSortExpr], target_batch_size: usize, + baseline_metrics: BaselineMetrics, ) -> Self { let cursors = (0..streams.len()) .into_iter() @@ -364,6 +378,7 @@ impl SortPreservingMergeStream { column_expressions: expressions.iter().map(|x| x.expr.clone()).collect(), sort_options: expressions.iter().map(|x| x.options).collect(), target_batch_size, + baseline_metrics, aborted: false, in_progress: vec![], next_batch_index: 0, @@ -390,7 +405,7 @@ impl SortPreservingMergeStream { return Poll::Ready(Ok(())); } - // Fetch a new record and create a cursor from it + // Fetch a new input record and create a cursor from it match futures::ready!(stream.poll_next_unpin(cx)) { None => return Poll::Ready(Ok(())), Some(Err(e)) => { @@ -539,6 +554,17 @@ impl Stream for SortPreservingMergeStream { mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll> { + let poll = self.poll_next_inner(cx); + self.baseline_metrics.record_poll(poll) + } +} + +impl SortPreservingMergeStream { + #[inline] + fn poll_next_inner( + self: &mut Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll>> { if self.aborted { return Poll::Ready(None); } @@ -556,6 +582,11 @@ impl Stream for SortPreservingMergeStream { } loop { + // NB timer records time taken on drop, so there are no + // calls to `timer.done()` below. + let elapsed_compute = self.baseline_metrics.elapsed_compute().clone(); + let _timer = elapsed_compute.timer(); + let stream_idx = match self.next_stream_idx() { Ok(Some(idx)) => idx, Ok(None) if self.in_progress.is_empty() => return Poll::Ready(None), @@ -607,6 +638,7 @@ impl RecordBatchStream for SortPreservingMergeStream { #[cfg(test)] mod tests { + use crate::physical_plan::metrics::MetricValue; use std::iter::FromIterator; use crate::arrow::array::{Int32Array, StringArray, TimestampNanosecondArray}; @@ -1149,11 +1181,15 @@ mod tests { streams.push(receiver); } + let metrics = ExecutionPlanMetricsSet::new(); + let baseline_metrics = BaselineMetrics::new(&metrics, 0); + let merge_stream = SortPreservingMergeStream::new( streams, batches.schema(), sort.as_slice(), 1024, + baseline_metrics, ); let mut merged = common::collect(Box::pin(merge_stream)).await.unwrap(); @@ -1172,4 +1208,59 @@ mod tests { assert_eq!(basic, partition); } + + #[tokio::test] + async fn test_merge_metrics() { + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")])); + let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); + + let a: ArrayRef = Arc::new(Int32Array::from(vec![10, 20])); + let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("b"), Some("d")])); + let b2 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); + + let schema = b1.schema(); + let sort = vec![PhysicalSortExpr { + expr: col("b", &schema).unwrap(), + options: Default::default(), + }]; + let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); + let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec), 1024)); + + let collected = collect(merge.clone()).await.unwrap(); + let expected = vec![ + "+----+---+", + "| a | b |", + "+----+---+", + "| 1 | a |", + "| 10 | b |", + "| 2 | c |", + "| 20 | d |", + "+----+---+", + ]; + assert_batches_eq!(expected, collected.as_slice()); + + // Now, validate metrics + let metrics = merge.metrics().unwrap(); + + assert_eq!(metrics.output_rows().unwrap(), 4); + assert!(metrics.elapsed_compute().unwrap() > 0); + + let mut saw_start = false; + let mut saw_end = false; + metrics.iter().for_each(|m| match m.value() { + MetricValue::StartTimestamp(ts) => { + saw_start = true; + assert!(ts.value().unwrap().timestamp_nanos() > 0); + } + MetricValue::EndTimestamp(ts) => { + saw_end = true; + assert!(ts.value().unwrap().timestamp_nanos() > 0); + } + _ => {} + }); + + assert!(saw_start); + assert!(saw_end); + } }