diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs
index 5a2ec9ee2d88f..e21cc311ccf07 100644
--- a/datafusion/src/physical_plan/hash_aggregate.rs
+++ b/datafusion/src/physical_plan/hash_aggregate.rs
@@ -50,7 +50,9 @@ use pin_project_lite::pin_project;
 
 use async_trait::async_trait;
 
-use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
+use super::metrics::{
+    self, BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
+};
 use super::{expressions::Column, RecordBatchStream, SendableRecordBatchStream};
 
 /// Hash aggregate modes
@@ -207,7 +209,7 @@ impl ExecutionPlan for HashAggregateExec {
         let input = self.input.execute(partition).await?;
         let group_expr = self.group_expr.iter().map(|x| x.0.clone()).collect();
 
-        let output_rows = MetricBuilder::new(&self.metrics).output_rows(partition);
+        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
 
         if self.group_expr.is_empty() {
             Ok(Box::pin(HashAggregateStream::new(
@@ -215,6 +217,7 @@ impl ExecutionPlan for HashAggregateExec {
                 self.schema.clone(),
                 self.aggr_expr.clone(),
                 input,
+                baseline_metrics,
             )))
         } else {
             Ok(Box::pin(GroupedHashAggregateStream::new(
@@ -223,7 +226,7 @@ impl ExecutionPlan for HashAggregateExec {
                 group_expr,
                 self.aggr_expr.clone(),
                 input,
-                output_rows,
+                baseline_metrics,
             )))
         }
     }
@@ -315,7 +318,6 @@ pin_project! {
         #[pin]
         output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
         finished: bool,
-        output_rows: metrics::Count,
     }
 }
 
@@ -487,7 +489,9 @@ async fn compute_grouped_hash_aggregate(
     group_expr: Vec<Arc<dyn PhysicalExpr>>,
     aggr_expr: Vec<Arc<dyn AggregateExpr>>,
     mut input: SendableRecordBatchStream,
+    elapsed_compute: metrics::Time,
 ) -> ArrowResult<RecordBatch> {
+    let timer = elapsed_compute.timer();
     // The expressions to evaluate the batch, one vec of expressions per aggregation.
     // Assume create_schema() always put group columns in front of aggr columns, we set
     // col_idx_base to group expression count.
@@ -499,8 +503,10 @@ async fn compute_grouped_hash_aggregate(
 
     // iterate over all input batches and update the accumulators
     let mut accumulators = Accumulators::default();
+    timer.done();
     while let Some(batch) = input.next().await {
         let batch = batch?;
+        let timer = elapsed_compute.timer();
         accumulators = group_aggregate_batch(
             &mode,
             &random_state,
@@ -511,9 +517,13 @@ async fn compute_grouped_hash_aggregate(
             &aggregate_expressions,
         )
         .map_err(DataFusionError::into_arrow_external_error)?;
+        timer.done();
     }
 
-    create_batch_from_map(&mode, &accumulators, group_expr.len(), &schema)
+    let timer = elapsed_compute.timer();
+    let batch = create_batch_from_map(&mode, &accumulators, group_expr.len(), &schema);
+    timer.done();
+    batch
 }
 
 impl GroupedHashAggregateStream {
@@ -524,11 +534,12 @@ impl GroupedHashAggregateStream {
         group_expr: Vec<Arc<dyn PhysicalExpr>>,
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
         input: SendableRecordBatchStream,
-        output_rows: metrics::Count,
+        baseline_metrics: BaselineMetrics,
     ) -> Self {
         let (tx, rx) = futures::channel::oneshot::channel();
         let schema_clone = schema.clone();
+        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
         tokio::spawn(async move {
             let result = compute_grouped_hash_aggregate(
                 mode,
@@ -536,8 +547,10 @@ impl GroupedHashAggregateStream {
                 group_expr,
                 aggr_expr,
                 input,
+                elapsed_compute,
             )
-            .await;
+            .await
+            .record_output(&baseline_metrics);
 
             tx.send(result)
         });
@@ -545,7 +558,6 @@ impl GroupedHashAggregateStream {
             schema,
             output: rx,
             finished: false,
-            output_rows,
         }
     }
 }
@@ -604,8 +616,6 @@ impl Stream for GroupedHashAggregateStream {
             return Poll::Ready(None);
         }
 
-        let output_rows = self.output_rows.clone();
-
         // is the output ready?
         let this = self.project();
         let output_poll = this.output.poll(cx);
@@ -620,10 +630,6 @@ impl Stream for GroupedHashAggregateStream {
                     Ok(result) => result,
                 };
 
-                if let Ok(batch) = &result {
-                    output_rows.add(batch.num_rows())
-                }
-
                 Poll::Ready(Some(result))
             }
             Poll::Pending => Poll::Pending,
@@ -720,25 +726,33 @@ async fn compute_hash_aggregate(
     schema: SchemaRef,
     aggr_expr: Vec<Arc<dyn AggregateExpr>>,
     mut input: SendableRecordBatchStream,
+    elapsed_compute: metrics::Time,
 ) -> ArrowResult<RecordBatch> {
+    let timer = elapsed_compute.timer();
     let mut accumulators = create_accumulators(&aggr_expr)
         .map_err(DataFusionError::into_arrow_external_error)?;
     let expressions = aggregate_expressions(&aggr_expr, &mode, 0)
        .map_err(DataFusionError::into_arrow_external_error)?;
     let expressions = Arc::new(expressions);
+    timer.done();
 
     // 1 for each batch, update / merge accumulators with the expressions' values
     // future is ready when all batches are computed
     while let Some(batch) = input.next().await {
         let batch = batch?;
+        let timer = elapsed_compute.timer();
         aggregate_batch(&mode, &batch, &mut accumulators, &expressions)
             .map_err(DataFusionError::into_arrow_external_error)?;
+        timer.done();
     }
 
     // 2. convert values to a record batch
-    finalize_aggregation(&accumulators, &mode)
+    let timer = elapsed_compute.timer();
+    let batch = finalize_aggregation(&accumulators, &mode)
         .map(|columns| RecordBatch::try_new(schema.clone(), columns))
-        .map_err(DataFusionError::into_arrow_external_error)?
+        .map_err(DataFusionError::into_arrow_external_error)?;
+    timer.done();
+    batch
 }
 
 impl HashAggregateStream {
@@ -748,13 +762,23 @@ impl HashAggregateStream {
         schema: SchemaRef,
         aggr_expr: Vec<Arc<dyn AggregateExpr>>,
         input: SendableRecordBatchStream,
+        baseline_metrics: BaselineMetrics,
     ) -> Self {
         let (tx, rx) = futures::channel::oneshot::channel();
         let schema_clone = schema.clone();
+        let elapsed_compute = baseline_metrics.elapsed_compute().clone();
         tokio::spawn(async move {
-            let result =
-                compute_hash_aggregate(mode, schema_clone, aggr_expr, input).await;
+            let result = compute_hash_aggregate(
+                mode,
+                schema_clone,
+                aggr_expr,
+                input,
+                elapsed_compute,
+            )
+            .await
+            .record_output(&baseline_metrics);
+
             tx.send(result)
         });
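Both aggregate streams above now follow the same pattern: `execute()` creates a `BaselineMetrics`, the spawned task receives a clone of its `elapsed_compute` time, only the CPU-bound sections are wrapped in `timer()` / `timer.done()`, and the finished result is passed through `record_output` so `output_rows` is bumped. A minimal sketch of that pattern, assuming `BaselineMetrics`, `ExecutionPlanMetricsSet`, and `RecordOutput` are publicly re-exported at `datafusion::physical_plan::metrics` (the diff reaches them via `super::metrics`); `do_compute` is a hypothetical stand-in for the operator's real work:

```rust
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use datafusion::physical_plan::metrics::{
    BaselineMetrics, ExecutionPlanMetricsSet, RecordOutput,
};

/// Sketch: run one partition's compute under baseline metrics.
fn timed_compute(
    metrics: &ExecutionPlanMetricsSet,
    partition: usize,
    do_compute: impl FnOnce() -> ArrowResult<RecordBatch>,
) -> ArrowResult<RecordBatch> {
    // registers output_rows / elapsed_compute for this partition
    let baseline_metrics = BaselineMetrics::new(metrics, partition);
    let elapsed_compute = baseline_metrics.elapsed_compute().clone();

    // time only the compute work, not waiting on input
    let timer = elapsed_compute.timer();
    let result = do_compute();
    timer.done();

    // on success, adds the batch's row count to output_rows
    result.record_output(&baseline_metrics)
}
```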
diff --git a/datafusion/src/physical_plan/sort.rs b/datafusion/src/physical_plan/sort.rs
index df77a16c29476..5a47931f96e85 100644
--- a/datafusion/src/physical_plan/sort.rs
+++ b/datafusion/src/physical_plan/sort.rs
@@ -17,7 +17,9 @@
 //! Defines the SORT plan
 
-use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
+use super::metrics::{
+    BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
+};
 use super::{RecordBatchStream, SendableRecordBatchStream};
 use crate::error::{DataFusionError, Result};
 use crate::physical_plan::expressions::PhysicalSortExpr;
@@ -151,18 +153,13 @@ impl ExecutionPlan for SortExec {
             }
         }
 
+        let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
         let input = self.input.execute(partition).await?;
 
-        let output_rows = MetricBuilder::new(&self.metrics).output_rows(partition);
-
-        let elapsed_compute =
-            MetricBuilder::new(&self.metrics).elapsed_compute(partition);
-
         Ok(Box::pin(SortStream::new(
             input,
             self.expr.clone(),
-            output_rows,
-            elapsed_compute,
+            baseline_metrics,
         )))
     }
 
@@ -227,7 +224,6 @@ pin_project! {
         output: futures::channel::oneshot::Receiver<ArrowResult<Option<RecordBatch>>>,
         finished: bool,
         schema: SchemaRef,
-        output_rows: metrics::Count
     }
 }
 
@@ -235,8 +231,7 @@ impl SortStream {
     fn new(
         input: SendableRecordBatchStream,
        expr: Vec<PhysicalSortExpr>,
-        output_rows: metrics::Count,
-        sort_time: metrics::Time,
+        baseline_metrics: BaselineMetrics,
     ) -> Self {
         let (tx, rx) = futures::channel::oneshot::channel();
         let schema = input.schema();
@@ -246,13 +241,14 @@ impl SortStream {
                 .await
                 .map_err(DataFusionError::into_arrow_external_error)
                 .and_then(move |batches| {
-                    let timer = sort_time.timer();
+                    let timer = baseline_metrics.elapsed_compute().timer();
                     // combine all record batches into one for each column
                     let combined = common::combine_batches(&batches, schema.clone())?;
                     // sort combined record batch
                     let result = combined
                         .map(|batch| sort_batch(batch, schema, &expr))
-                        .transpose()?;
+                        .transpose()?
+                        .record_output(&baseline_metrics);
                     timer.done();
                     Ok(result)
                 });
@@ -264,7 +260,6 @@ impl SortStream {
             output: rx,
             finished: false,
             schema,
-            output_rows,
         }
     }
 }
@@ -273,8 +268,6 @@ impl Stream for SortStream {
     type Item = ArrowResult<RecordBatch>;
 
     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        let output_rows = self.output_rows.clone();
-
         if self.finished {
             return Poll::Ready(None);
         }
@@ -293,10 +286,6 @@ impl Stream for SortStream {
                     Ok(result) => result.transpose(),
                 };
 
-                if let Some(Ok(batch)) = &result {
-                    output_rows.add(batch.num_rows());
-                }
-
                 Poll::Ready(result)
             }
             Poll::Pending => Poll::Pending,
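In `SortExec::execute` (and `HashAggregateExec::execute` above), a single `BaselineMetrics::new` replaces the individual `MetricBuilder` registrations of `output_rows` and `elapsed_compute`. A before/after sketch using only calls that appear in the removed and added lines, under the same re-export assumption as above:

```rust
use datafusion::physical_plan::metrics::{
    BaselineMetrics, ExecutionPlanMetricsSet, MetricBuilder,
};

// What the removed lines did: register each metric by hand.
fn per_metric_style(metrics: &ExecutionPlanMetricsSet, partition: usize) {
    let _output_rows = MetricBuilder::new(metrics).output_rows(partition);
    let _elapsed_compute = MetricBuilder::new(metrics).elapsed_compute(partition);
}

// What the added line does: one bundle the operator threads through its stream.
fn baseline_style(metrics: &ExecutionPlanMetricsSet, partition: usize) -> BaselineMetrics {
    BaselineMetrics::new(metrics, partition)
}
```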
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 8aae3d9507190..807edb6258637 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -41,6 +41,9 @@ use arrow::{
 use datafusion::assert_batches_eq;
 use datafusion::assert_batches_sorted_eq;
 use datafusion::logical_plan::LogicalPlan;
+use datafusion::physical_plan::metrics::MetricValue;
+use datafusion::physical_plan::ExecutionPlan;
+use datafusion::physical_plan::ExecutionPlanVisitor;
 use datafusion::prelude::*;
 use datafusion::{
     datasource::{csv::CsvReadOptions, MemTable},
@@ -2194,8 +2197,6 @@ async fn csv_explain_analyze() {
     let formatted = arrow::util::pretty::pretty_format_batches(&actual).unwrap();
     let formatted = normalize_for_explain(&formatted);
 
-    println!("ANALYZE EXPLAIN:\n{}", formatted);
-
     // Only test basic plumbing and try to avoid having to change too
     // many things
     let needle =
@@ -2221,6 +2222,112 @@ async fn csv_explain_analyze_verbose() {
     assert_contains!(formatted, verbose_needle);
 }
 
+/// A macro to assert that some particular line contains two substrings
+///
+/// Usage: `assert_metrics!(actual, operator_name, metrics)`
+///
+macro_rules! assert_metrics {
+    ($ACTUAL: expr, $OPERATOR_NAME: expr, $METRICS: expr) => {
+        let found = $ACTUAL
+            .lines()
+            .any(|line| line.contains($OPERATOR_NAME) && line.contains($METRICS));
+        assert!(
+            found,
+            "Can not find a line with both '{}' and '{}' in\n\n{}",
+            $OPERATOR_NAME, $METRICS, $ACTUAL
+        );
+    };
+}
+
+#[tokio::test]
+async fn explain_analyze_baseline_metrics() {
+    // This test uses the execute function to run an actual plan under EXPLAIN ANALYZE
+    // and then validate the presence of baseline metrics for supported operators
+    let config = ExecutionConfig::new().with_target_partitions(3);
+    let mut ctx = ExecutionContext::with_config(config);
+    register_aggregate_csv_by_sql(&mut ctx).await;
+    // a query with as many operators as we have metrics for
+    let sql = "EXPLAIN ANALYZE select count(*) from (SELECT count(*), c1 FROM aggregate_test_100 group by c1 ORDER BY c1)";
+    let plan = ctx.create_logical_plan(sql).unwrap();
+    let plan = ctx.optimize(&plan).unwrap();
+    let physical_plan = ctx.create_physical_plan(&plan).unwrap();
+    let results = collect(physical_plan.clone()).await.unwrap();
+    let formatted = arrow::util::pretty::pretty_format_batches(&results).unwrap();
+    let formatted = normalize_for_explain(&formatted);
+
+    assert_metrics!(
+        &formatted,
+        "CoalescePartitionsExec",
+        "metrics=[output_rows=5, elapsed_compute=NOT RECORDED"
+    );
+    assert_metrics!(
+        &formatted,
+        "HashAggregateExec: mode=Partial, gby=[]",
+        "metrics=[output_rows=3, elapsed_compute="
+    );
+    assert_metrics!(
+        &formatted,
+        "HashAggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1]",
+        "metrics=[output_rows=5, elapsed_compute="
+    );
+    assert_metrics!(
+        &formatted,
+        "SortExec: [c1@0 ASC]",
+        "metrics=[output_rows=5, elapsed_compute="
+    );
+
+    fn expected_to_have_metrics(plan: &dyn ExecutionPlan) -> bool {
+        use datafusion::physical_plan::{
+            hash_aggregate::HashAggregateExec, sort::SortExec,
+        };
+
+        plan.as_any().downcast_ref::<HashAggregateExec>().is_some()
+            || plan.as_any().downcast_ref::<SortExec>().is_some()
+    }
+
+    // Validate that the recorded elapsed compute time was more than
+    // zero for all operators as well as the start/end timestamp are set
+    struct TimeValidator {}
+    impl ExecutionPlanVisitor for TimeValidator {
+        type Error = std::convert::Infallible;
+
+        fn pre_visit(
+            &mut self,
+            plan: &dyn ExecutionPlan,
+        ) -> std::result::Result<bool, Self::Error> {
+            if !expected_to_have_metrics(plan) {
+                return Ok(true);
+            }
+            let metrics = plan.metrics().unwrap().aggregate_by_partition();
+
+            assert!(metrics.output_rows().unwrap() > 0);
+            assert!(metrics.elapsed_compute().unwrap() > 0);
+
+            let mut saw_start = false;
+            let mut saw_end = false;
+            metrics.iter().for_each(|m| match m.value() {
+                MetricValue::StartTimestamp(ts) => {
+                    saw_start = true;
+                    assert!(ts.value().unwrap().timestamp_nanos() > 0);
+                }
+                MetricValue::EndTimestamp(ts) => {
+                    saw_end = true;
+                    assert!(ts.value().unwrap().timestamp_nanos() > 0);
+                }
+                _ => {}
+            });
+
+            assert!(saw_start);
+            assert!(saw_end);
+
+            Ok(true)
+        }
+    }
+
+    datafusion::physical_plan::accept(physical_plan.as_ref(), &mut TimeValidator {})
+        .unwrap();
+}
+
 #[tokio::test]
 async fn csv_explain_plans() {
     // This test verify the look of each plan in its full cycle plan creation
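The `TimeValidator` in the new test also shows the general way to read these metrics back outside of `EXPLAIN ANALYZE`: walk the executed plan with `accept` and inspect `plan.metrics()` in `pre_visit`. A small sketch of that pattern; `MetricsPrinter` is an illustrative name, not part of this diff:

```rust
use datafusion::physical_plan::{ExecutionPlan, ExecutionPlanVisitor};

/// Prints each operator's aggregated baseline metrics after the plan has run.
struct MetricsPrinter {}

impl ExecutionPlanVisitor for MetricsPrinter {
    type Error = std::convert::Infallible;

    fn pre_visit(
        &mut self,
        plan: &dyn ExecutionPlan,
    ) -> std::result::Result<bool, Self::Error> {
        if let Some(metrics) = plan.metrics() {
            // sum the per-partition values, as the test does
            let metrics = metrics.aggregate_by_partition();
            println!(
                "output_rows={:?}, elapsed_compute_ns={:?}",
                metrics.output_rows(),
                metrics.elapsed_compute()
            );
        }
        Ok(true) // keep visiting children
    }
}

// Usage, after executing `physical_plan` with collect():
// datafusion::physical_plan::accept(physical_plan.as_ref(), &mut MetricsPrinter {}).unwrap();
```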