60 changes: 42 additions & 18 deletions datafusion/src/physical_plan/hash_aggregate.rs
@@ -50,7 +50,9 @@ use pin_project_lite::pin_project;

use async_trait::async_trait;

use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::metrics::{
self, BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
};
use super::{expressions::Column, RecordBatchStream, SendableRecordBatchStream};

/// Hash aggregate modes
@@ -207,14 +209,15 @@ impl ExecutionPlan for HashAggregateExec {
let input = self.input.execute(partition).await?;
let group_expr = self.group_expr.iter().map(|x| x.0.clone()).collect();

let output_rows = MetricBuilder::new(&self.metrics).output_rows(partition);
let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);

if self.group_expr.is_empty() {
Ok(Box::pin(HashAggregateStream::new(
self.mode,
self.schema.clone(),
self.aggr_expr.clone(),
input,
baseline_metrics,
)))
} else {
Ok(Box::pin(GroupedHashAggregateStream::new(
Expand All @@ -223,7 +226,7 @@ impl ExecutionPlan for HashAggregateExec {
group_expr,
self.aggr_expr.clone(),
input,
output_rows,
baseline_metrics,
)))
}
}
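
For context, the new BaselineMetrics type bundles the per-partition counters that were previously registered one by one. Below is a minimal sketch of the registration pattern this hunk adopts, assuming the datafusion physical_plan::metrics API as it appears in the diff; MyExec and its field are hypothetical stand-ins.

use datafusion::physical_plan::metrics::{BaselineMetrics, ExecutionPlanMetricsSet};

struct MyExec {
    metrics: ExecutionPlanMetricsSet,
}

impl MyExec {
    fn baseline_for_partition(&self, partition: usize) -> BaselineMetrics {
        // One constructor call registers both output_rows and
        // elapsed_compute for this partition, replacing the separate
        // MetricBuilder::new(...).output_rows(partition) and
        // .elapsed_compute(partition) calls removed above.
        BaselineMetrics::new(&self.metrics, partition)
    }
}
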
@@ -315,7 +318,6 @@ pin_project! {
#[pin]
output: futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>>,
finished: bool,
output_rows: metrics::Count,
}
}

@@ -487,7 +489,9 @@ async fn compute_grouped_hash_aggregate(
group_expr: Vec<Arc<dyn PhysicalExpr>>,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
mut input: SendableRecordBatchStream,
elapsed_compute: metrics::Time,
) -> ArrowResult<RecordBatch> {
let timer = elapsed_compute.timer();
// The expressions to evaluate the batch, one vec of expressions per aggregation.
// Assume create_schema() always puts group columns in front of aggr columns; we set
// col_idx_base to the group expression count.
@@ -499,8 +503,10 @@

// iterate over all input batches and update the accumulators
let mut accumulators = Accumulators::default();
timer.done();
while let Some(batch) = input.next().await {
let batch = batch?;
let timer = elapsed_compute.timer();
accumulators = group_aggregate_batch(
&mode,
&random_state,
@@ -511,9 +517,13 @@
&aggregate_expressions,
)
.map_err(DataFusionError::into_arrow_external_error)?;
timer.done();
}

create_batch_from_map(&mode, &accumulators, group_expr.len(), &schema)
let timer = elapsed_compute.timer();
let batch = create_batch_from_map(&mode, &accumulators, group_expr.len(), &schema);
timer.done();
batch
}
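
The timer placement above is deliberate: the timer is stopped before each input.next().await and restarted once a batch arrives, so time spent waiting on upstream operators is not charged to this operator's elapsed_compute. A minimal sketch of that discipline, assuming the same metrics::Time API shown in the diff; process is a hypothetical stand-in.

use arrow::record_batch::RecordBatch;
use datafusion::physical_plan::{metrics, SendableRecordBatchStream};
use futures::StreamExt;

async fn timed_loop(
    mut input: SendableRecordBatchStream,
    elapsed_compute: metrics::Time,
) -> arrow::error::Result<()> {
    while let Some(batch) = input.next().await {
        let batch = batch?;
        let timer = elapsed_compute.timer(); // start: CPU-bound work begins
        process(&batch); // hypothetical per-batch aggregation step
        timer.done(); // stop before awaiting the next batch
    }
    Ok(())
}

fn process(_batch: &RecordBatch) {}
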

impl GroupedHashAggregateStream {
@@ -524,28 +534,30 @@ impl GroupedHashAggregateStream {
group_expr: Vec<Arc<dyn PhysicalExpr>>,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
input: SendableRecordBatchStream,
output_rows: metrics::Count,
baseline_metrics: BaselineMetrics,
) -> Self {
let (tx, rx) = futures::channel::oneshot::channel();

let schema_clone = schema.clone();
let elapsed_compute = baseline_metrics.elapsed_compute().clone();
tokio::spawn(async move {
let result = compute_grouped_hash_aggregate(
mode,
schema_clone,
group_expr,
aggr_expr,
input,
elapsed_compute,
)
.await;
.await
.record_output(&baseline_metrics);
tx.send(result)
});

Self {
schema,
output: rx,
finished: false,
output_rows,
}
}
}
@@ -604,8 +616,6 @@ impl Stream for GroupedHashAggregateStream {
return Poll::Ready(None);
}

let output_rows = self.output_rows.clone();

// is the output ready?
let this = self.project();
let output_poll = this.output.poll(cx);
@@ -620,10 +630,6 @@
Ok(result) => result,
};

if let Ok(batch) = &result {
output_rows.add(batch.num_rows())
}

Poll::Ready(Some(result))
}
Poll::Pending => Poll::Pending,
@@ -720,25 +726,33 @@ async fn compute_hash_aggregate(
schema: SchemaRef,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
mut input: SendableRecordBatchStream,
elapsed_compute: metrics::Time,
) -> ArrowResult<RecordBatch> {
let timer = elapsed_compute.timer();
let mut accumulators = create_accumulators(&aggr_expr)
.map_err(DataFusionError::into_arrow_external_error)?;
let expressions = aggregate_expressions(&aggr_expr, &mode, 0)
.map_err(DataFusionError::into_arrow_external_error)?;
let expressions = Arc::new(expressions);
timer.done();

// 1. for each batch, update / merge the accumulators with the expressions' values;
// the future is ready when all batches have been computed
while let Some(batch) = input.next().await {
let batch = batch?;
let timer = elapsed_compute.timer();
aggregate_batch(&mode, &batch, &mut accumulators, &expressions)
.map_err(DataFusionError::into_arrow_external_error)?;
timer.done();
}

// 2. convert values to a record batch
finalize_aggregation(&accumulators, &mode)
let timer = elapsed_compute.timer();
let batch = finalize_aggregation(&accumulators, &mode)
.map(|columns| RecordBatch::try_new(schema.clone(), columns))
.map_err(DataFusionError::into_arrow_external_error)?
.map_err(DataFusionError::into_arrow_external_error)?;
timer.done();
batch
}

impl HashAggregateStream {
@@ -748,13 +762,23 @@ impl HashAggregateStream {
schema: SchemaRef,
aggr_expr: Vec<Arc<dyn AggregateExpr>>,
input: SendableRecordBatchStream,
baseline_metrics: BaselineMetrics,
) -> Self {
let (tx, rx) = futures::channel::oneshot::channel();

let schema_clone = schema.clone();
let elapsed_compute = baseline_metrics.elapsed_compute().clone();
tokio::spawn(async move {
let result =
compute_hash_aggregate(mode, schema_clone, aggr_expr, input).await;
let result = compute_hash_aggregate(
mode,
schema_clone,
aggr_expr,
input,
elapsed_compute,
)
.await
.record_output(&baseline_metrics);

tx.send(result)
});

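
Both aggregate streams now follow the same record-then-send shape: the spawned task records the row count on the finished result via the RecordOutput trait before pushing it through the oneshot channel. A sketch under the assumption, suggested by the diff, that RecordOutput is implemented for ArrowResult<RecordBatch>; compute is a hypothetical stand-in.

use std::sync::Arc;
use arrow::{datatypes::Schema, error::Result as ArrowResult, record_batch::RecordBatch};
use datafusion::physical_plan::metrics::{BaselineMetrics, RecordOutput};

fn spawn_compute(
    baseline_metrics: BaselineMetrics,
) -> futures::channel::oneshot::Receiver<ArrowResult<RecordBatch>> {
    let (tx, rx) = futures::channel::oneshot::channel();
    tokio::spawn(async move {
        // record_output adds batch.num_rows() to output_rows on Ok and
        // returns the value unchanged, so it chains onto the result.
        let result = compute().record_output(&baseline_metrics);
        let _ = tx.send(result);
    });
    rx
}

fn compute() -> ArrowResult<RecordBatch> {
    // Hypothetical: a real implementation would aggregate input batches.
    Ok(RecordBatch::new_empty(Arc::new(Schema::empty())))
}
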
29 changes: 9 additions & 20 deletions datafusion/src/physical_plan/sort.rs
@@ -17,7 +17,9 @@

//! Defines the SORT plan

use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use super::metrics::{
BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, RecordOutput,
};
use super::{RecordBatchStream, SendableRecordBatchStream};
use crate::error::{DataFusionError, Result};
use crate::physical_plan::expressions::PhysicalSortExpr;
@@ -151,18 +153,13 @@ impl ExecutionPlan for SortExec {
}
}

let baseline_metrics = BaselineMetrics::new(&self.metrics, partition);
let input = self.input.execute(partition).await?;

let output_rows = MetricBuilder::new(&self.metrics).output_rows(partition);

let elapsed_compute =
MetricBuilder::new(&self.metrics).elapsed_compute(partition);

Ok(Box::pin(SortStream::new(
input,
self.expr.clone(),
output_rows,
elapsed_compute,
baseline_metrics,
)))
}

@@ -227,16 +224,14 @@ pin_project! {
output: futures::channel::oneshot::Receiver<ArrowResult<Option<RecordBatch>>>,
finished: bool,
schema: SchemaRef,
output_rows: metrics::Count
}
}

impl SortStream {
fn new(
input: SendableRecordBatchStream,
expr: Vec<PhysicalSortExpr>,
output_rows: metrics::Count,
sort_time: metrics::Time,
baseline_metrics: BaselineMetrics,
) -> Self {
let (tx, rx) = futures::channel::oneshot::channel();
let schema = input.schema();
@@ -246,13 +241,14 @@ impl SortStream {
.await
.map_err(DataFusionError::into_arrow_external_error)
.and_then(move |batches| {
let timer = sort_time.timer();
let timer = baseline_metrics.elapsed_compute().timer();
// combine all record batches into one for each column
let combined = common::combine_batches(&batches, schema.clone())?;
// sort combined record batch
let result = combined
.map(|batch| sort_batch(batch, schema, &expr))
.transpose()?;
.transpose()?
.record_output(&baseline_metrics);
timer.done();
Ok(result)
});
@@ -264,7 +260,6 @@
output: rx,
finished: false,
schema,
output_rows,
}
}
}
@@ -273,8 +268,6 @@ impl Stream for SortStream {
type Item = ArrowResult<RecordBatch>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let output_rows = self.output_rows.clone();

if self.finished {
return Poll::Ready(None);
}
@@ -293,10 +286,6 @@
Ok(result) => result.transpose(),
};

if let Some(Ok(batch)) = &result {
output_rows.add(batch.num_rows());
}

Poll::Ready(result)
}
Poll::Pending => Poll::Pending,
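
Taken together, the two files move metric recording from the consumer side (Stream::poll_next) to the producer side (the spawned task): output_rows is bumped exactly once where the result is created, the poll_next implementations lose their counter-cloning boilerplate, and elapsed_compute accumulates only CPU-bound sections rather than wall-clock time spent awaiting input. Note that in sort.rs the same record_output call is chained onto an Option<RecordBatch> (the combined input may be empty), which suggests the trait is implemented for Option as well, with None contributing zero rows.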