From d665450d27181e144fc6aaf3eb14e560c008b09c Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Fri, 25 Nov 2022 11:58:13 +0100 Subject: [PATCH 1/3] feat: `ResourceExhausted` for memory limit in `GroupedHashAggregateStream` Closes #3940. --- .../core/src/physical_plan/aggregates/hash.rs | 215 ++++++++++++------ .../core/src/physical_plan/aggregates/mod.rs | 176 +++++++++----- 2 files changed, 268 insertions(+), 123 deletions(-) diff --git a/datafusion/core/src/physical_plan/aggregates/hash.rs b/datafusion/core/src/physical_plan/aggregates/hash.rs index 9487df61afa06..e0edf5e259a62 100644 --- a/datafusion/core/src/physical_plan/aggregates/hash.rs +++ b/datafusion/core/src/physical_plan/aggregates/hash.rs @@ -22,12 +22,16 @@ use std::task::{Context, Poll}; use std::vec; use ahash::RandomState; -use futures::{ - ready, - stream::{Stream, StreamExt}, -}; +use datafusion_expr::Accumulator; +use futures::stream::BoxStream; +use futures::stream::{Stream, StreamExt}; use crate::error::Result; +use crate::execution::context::TaskContext; +use crate::execution::memory_manager::proxy::{ + MemoryConsumerProxy, RawTableAllocExt, VecAllocExt, +}; +use crate::execution::MemoryConsumerId; use crate::physical_plan::aggregates::{ evaluate_group_by, evaluate_many, AccumulatorItem, AggregateMode, PhysicalGroupBy, }; @@ -74,6 +78,16 @@ Example: average * Finally, `get_value` returns an array with one entry computed from the state */ pub(crate) struct GroupedHashAggregateStream { + stream: BoxStream<'static, ArrowResult>, + schema: SchemaRef, +} + +/// Actual implementation of [`GroupedHashAggregateStream`]. +/// +/// This is wrapped into yet another struct because we need to interact with the async memory management subsystem +/// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with +/// [`futures::stream::unfold`]. The latter requires a state object, which is [`GroupedHashAggregateStreamInner`]. 
+struct GroupedHashAggregateStreamInner { schema: SchemaRef, input: SendableRecordBatchStream, mode: AggregateMode, @@ -90,6 +104,7 @@ pub(crate) struct GroupedHashAggregateStream { impl GroupedHashAggregateStream { /// Create a new GroupedHashAggregateStream + #[allow(clippy::too_many_arguments)] pub fn new( mode: AggregateMode, schema: SchemaRef, @@ -97,6 +112,8 @@ impl GroupedHashAggregateStream { aggr_expr: Vec>, input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, + context: Arc, + partition: usize, ) -> Result { let timer = baseline_metrics.elapsed_compute().timer(); @@ -108,18 +125,92 @@ impl GroupedHashAggregateStream { timer.done(); - Ok(Self { - schema, + let inner = GroupedHashAggregateStreamInner { + schema: Arc::clone(&schema), mode, input, aggr_expr, group_by, baseline_metrics, aggregate_expressions, - accumulators: Default::default(), + accumulators: Accumulators { + memory_consumer: MemoryConsumerProxy::new( + "Accumulators", + MemoryConsumerId::new(partition), + Arc::clone(&context.runtime_env().memory_manager), + ), + map: RawTable::with_capacity(0), + group_states: Vec::with_capacity(0), + }, random_state: Default::default(), finished: false, - }) + }; + + let stream = futures::stream::unfold(inner, |mut this| async move { + if this.finished { + return None; + } + + let elapsed_compute = this.baseline_metrics.elapsed_compute(); + + loop { + let result = match this.input.next().await { + Some(Ok(batch)) => { + let timer = elapsed_compute.timer(); + let result = group_aggregate_batch( + &this.mode, + &this.random_state, + &this.group_by, + &this.aggr_expr, + batch, + &mut this.accumulators, + &this.aggregate_expressions, + ); + + timer.done(); + + // allocate memory + // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with + // overshooting a bit. Also this means we either store the whole record batch or not. 
+ let result = match result { + Ok(allocated) => { + this.accumulators.memory_consumer.alloc(allocated).await + } + Err(e) => Err(e), + }; + + match result { + Ok(()) => continue, + Err(e) => Err(ArrowError::ExternalError(Box::new(e))), + } + } + Some(Err(e)) => Err(e), + None => { + this.finished = true; + let timer = this.baseline_metrics.elapsed_compute().timer(); + let result = create_batch_from_map( + &this.mode, + &this.accumulators, + this.group_by.expr.len(), + &this.schema, + ) + .record_output(&this.baseline_metrics); + + timer.done(); + result + } + }; + + this.finished = true; + return Some((result, this)); + } + }); + + // seems like some consumers call this stream even after it returned `None`, so let's fuse the stream. + let stream = stream.fuse(); + let stream = Box::pin(stream); + + Ok(Self { schema, stream }) } } @@ -131,53 +222,7 @@ impl Stream for GroupedHashAggregateStream { cx: &mut Context<'_>, ) -> Poll> { let this = &mut *self; - if this.finished { - return Poll::Ready(None); - } - - let elapsed_compute = this.baseline_metrics.elapsed_compute(); - - loop { - let result = match ready!(this.input.poll_next_unpin(cx)) { - Some(Ok(batch)) => { - let timer = elapsed_compute.timer(); - let result = group_aggregate_batch( - &this.mode, - &this.random_state, - &this.group_by, - &this.aggr_expr, - batch, - &mut this.accumulators, - &this.aggregate_expressions, - ); - - timer.done(); - - match result { - Ok(_) => continue, - Err(e) => Err(ArrowError::ExternalError(Box::new(e))), - } - } - Some(Err(e)) => Err(e), - None => { - this.finished = true; - let timer = this.baseline_metrics.elapsed_compute().timer(); - let result = create_batch_from_map( - &this.mode, - &this.accumulators, - this.group_by.expr.len(), - &this.schema, - ) - .record_output(&this.baseline_metrics); - - timer.done(); - result - } - }; - - this.finished = true; - return Poll::Ready(Some(result)); - } + this.stream.poll_next_unpin(cx) } } @@ -187,6 +232,10 @@ impl RecordBatchStream 
for GroupedHashAggregateStream { } } +/// Perform group-by aggregation for the given [`RecordBatch`]. +/// +/// If successful, this returns the additional number of bytes that were allocated during this process. +/// /// TODO: Make this a member function of [`GroupedHashAggregateStream`] fn group_aggregate_batch( mode: &AggregateMode, @@ -196,7 +245,7 @@ fn group_aggregate_batch( batch: RecordBatch, accumulators: &mut Accumulators, aggregate_expressions: &[Vec>], -) -> Result<()> { +) -> Result { // evaluate the grouping expressions let group_by_values = evaluate_group_by(group_by, &batch)?; @@ -205,6 +254,9 @@ fn group_aggregate_batch( // of them anyways, it is more performant to do it while they are together. let aggr_input_values = evaluate_many(aggregate_expressions, &batch)?; + // track memory allocations + let mut allocated = 0usize; + for grouping_set_values in group_by_values { // 1.1 construct the key from the group values // 1.2 construct the mapping key if it does not exist @@ -218,7 +270,9 @@ fn group_aggregate_batch( create_hashes(&grouping_set_values, random_state, &mut batch_hashes)?; for (row, hash) in batch_hashes.into_iter().enumerate() { - let Accumulators { map, group_states } = accumulators; + let Accumulators { + map, group_states, .. 
+ } = accumulators; let entry = map.get_mut(hash, |(_hash, group_idx)| { // verify that a group that we are inserting with hash is @@ -239,7 +293,9 @@ fn group_aggregate_batch( if group_state.indices.is_empty() { groups_with_rows.push(*group_idx); }; - group_state.indices.push(row as u32); // remember this row + group_state + .indices + .push_accounted(row as u32, &mut allocated); // remember this row } // 1.2 Need to create new entry None => { @@ -257,12 +313,32 @@ fn group_aggregate_batch( accumulator_set, indices: vec![row as u32], // 1.3 }; + // NOTE: do NOT include the `GroupState` struct size in here because this is captured by + // `group_states` (see allocation down below) + allocated += group_state + .group_by_values + .iter() + .map(|sv| sv.size()) + .sum::() + + (std::mem::size_of::>() + * group_state.accumulator_set.capacity()) + + group_state + .accumulator_set + .iter() + .map(|accu| accu.size()) + .sum::() + + (std::mem::size_of::() * group_state.indices.capacity()); + let group_idx = group_states.len(); - group_states.push(group_state); + group_states.push_accounted(group_state, &mut allocated); groups_with_rows.push(group_idx); // for hasher function, use precomputed hash value - map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash); + map.insert_accounted( + (hash, group_idx), + |(hash, _group_idx)| *hash, + &mut allocated, + ); } }; } @@ -326,10 +402,20 @@ fn group_aggregate_batch( ) }) .try_for_each(|(accumulator, values)| match mode { - AggregateMode::Partial => accumulator.update_batch(&values), + AggregateMode::Partial => { + let size_pre = accumulator.size(); + let res = accumulator.update_batch(&values); + let size_post = accumulator.size(); + allocated += size_post.saturating_sub(size_pre); + res + } AggregateMode::FinalPartitioned | AggregateMode::Final => { // note: the aggregation here is over states, not values, thus the merge - accumulator.merge_batch(&values) + let size_pre = accumulator.size(); + let res = 
accumulator.merge_batch(&values); + let size_post = accumulator.size(); + allocated += size_post.saturating_sub(size_pre); + res } }) // 2.5 @@ -340,7 +426,7 @@ fn group_aggregate_batch( })?; } - Ok(()) + Ok(allocated) } /// The state that is built for each output group. @@ -358,8 +444,9 @@ struct GroupState { } /// The state of all the groups -#[derive(Default)] struct Accumulators { + memory_consumer: MemoryConsumerProxy, + /// Logically maps group values to an index in `group_states` /// /// Uses the raw API of hashbrown to avoid actually storing the diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs index 6ce58592d83b5..312a3263aa814 100644 --- a/datafusion/core/src/physical_plan/aggregates/mod.rs +++ b/datafusion/core/src/physical_plan/aggregates/mod.rs @@ -150,6 +150,22 @@ impl PhysicalGroupBy { } } +enum StreamType { + AggregateStream(AggregateStream), + GroupedHashAggregateStreamV2(GroupedHashAggregateStreamV2), + GroupedHashAggregateStream(GroupedHashAggregateStream), +} + +impl From for SendableRecordBatchStream { + fn from(stream: StreamType) -> Self { + match stream { + StreamType::AggregateStream(stream) => Box::pin(stream), + StreamType::GroupedHashAggregateStreamV2(stream) => Box::pin(stream), + StreamType::GroupedHashAggregateStream(stream) => Box::pin(stream), + } + } +} + /// Hash aggregate execution plan #[derive(Debug)] pub struct AggregateExec { @@ -261,6 +277,54 @@ impl AggregateExec { row_supported(&group_schema, RowType::Compact) && accumulator_v2_supported(&self.aggr_expr) } + + fn execute_typed( + &self, + partition: usize, + context: Arc, + ) -> Result { + let batch_size = context.session_config().batch_size(); + let input = self.input.execute(partition, Arc::clone(&context))?; + + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + + if self.group_by.expr.is_empty() { + Ok(StreamType::AggregateStream(AggregateStream::new( + self.mode, + 
self.schema.clone(), + self.aggr_expr.clone(), + input, + baseline_metrics, + )?)) + } else if self.row_aggregate_supported() { + Ok(StreamType::GroupedHashAggregateStreamV2( + GroupedHashAggregateStreamV2::new( + self.mode, + self.schema.clone(), + self.group_by.clone(), + self.aggr_expr.clone(), + input, + baseline_metrics, + batch_size, + context, + partition, + )?, + )) + } else { + Ok(StreamType::GroupedHashAggregateStream( + GroupedHashAggregateStream::new( + self.mode, + self.schema.clone(), + self.group_by.clone(), + self.aggr_expr.clone(), + input, + baseline_metrics, + context, + partition, + )?, + )) + } + } } impl ExecutionPlan for AggregateExec { @@ -347,41 +411,8 @@ impl ExecutionPlan for AggregateExec { partition: usize, context: Arc, ) -> Result { - let batch_size = context.session_config().batch_size(); - let input = self.input.execute(partition, Arc::clone(&context))?; - - let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - - if self.group_by.expr.is_empty() { - Ok(Box::pin(AggregateStream::new( - self.mode, - self.schema.clone(), - self.aggr_expr.clone(), - input, - baseline_metrics, - )?)) - } else if self.row_aggregate_supported() { - Ok(Box::pin(GroupedHashAggregateStreamV2::new( - self.mode, - self.schema.clone(), - self.group_by.clone(), - self.aggr_expr.clone(), - input, - baseline_metrics, - batch_size, - context, - partition, - )?)) - } else { - Ok(Box::pin(GroupedHashAggregateStream::new( - self.mode, - self.schema.clone(), - self.group_by.clone(), - self.aggr_expr.clone(), - input, - baseline_metrics, - )?)) - } + self.execute_typed(partition, context) + .map(|stream| stream.into()) } fn metrics(&self) -> Option { @@ -706,13 +737,14 @@ mod tests { use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::{DataFusionError, Result, ScalarValue}; - use datafusion_physical_expr::expressions::{lit, Count}; + use datafusion_physical_expr::expressions::{lit, 
ApproxDistinct, Count}; use datafusion_physical_expr::{AggregateExpr, PhysicalExpr, PhysicalSortExpr}; use futures::{FutureExt, Stream}; use std::any::Any; use std::sync::Arc; use std::task::{Context, Poll}; + use super::StreamType; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::{ ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, @@ -1105,37 +1137,63 @@ mod tests { groups: vec![vec![false]], }; - let aggregates: Vec> = vec![Arc::new(Avg::new( + // use slow-path in `hash.rs` + let aggregates_v1: Vec> = + vec![Arc::new(ApproxDistinct::new( + col("a", &input_schema)?, + "APPROX_DISTINCT(a)".to_string(), + DataType::UInt32, + ))]; + + // use fast-path in `row_hash.rs`. + let aggregates_v2: Vec> = vec![Arc::new(Avg::new( col("b", &input_schema)?, "AVG(b)".to_string(), DataType::Float64, ))]; - let partial_aggregate = Arc::new(AggregateExec::try_new( - AggregateMode::Partial, - groups, - aggregates, - input, - input_schema.clone(), - )?); + for (version, aggregates) in [(1, aggregates_v1), (2, aggregates_v2)] { + let partial_aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups.clone(), + aggregates, + input.clone(), + input_schema.clone(), + )?); + + let stream = partial_aggregate.execute_typed(0, task_ctx.clone())?; + + // ensure that we really got the version we wanted + match version { + 1 => { + assert!(matches!(stream, StreamType::GroupedHashAggregateStream(_))); + } + 2 => { + assert!(matches!( + stream, + StreamType::GroupedHashAggregateStreamV2(_) + )); + } + _ => panic!("Unknown version: {version}"), + } - let err = common::collect(partial_aggregate.execute(0, task_ctx.clone())?) - .await - .unwrap_err(); - - // error root cause traversal is a bit complicated, see #4172. 
- if let DataFusionError::ArrowError(ArrowError::ExternalError(err)) = err { - if let Some(err) = err.downcast_ref::() { - assert!( - matches!(err, DataFusionError::ResourcesExhausted(_)), - "Wrong inner error type: {}", - err, - ); + let stream: SendableRecordBatchStream = stream.into(); + let err = common::collect(stream).await.unwrap_err(); + + // error root cause traversal is a bit complicated, see #4172. + if let DataFusionError::ArrowError(ArrowError::ExternalError(err)) = err { + if let Some(err) = err.downcast_ref::() { + assert!( + matches!(err, DataFusionError::ResourcesExhausted(_)), + "Wrong inner error type: {}", + err, + ); + } else { + panic!("Wrong arrow error type: {err}") + } } else { - panic!("Wrong arrow error type: {err}") + panic!("Wrong outer error type: {err}") } - } else { - panic!("Wrong outer error type: {err}") } Ok(()) From defec7536203ee0d41af906c42b1d314c5095305 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Fri, 25 Nov 2022 12:54:50 +0100 Subject: [PATCH 2/3] fix: `ScalarValue` size calculations --- datafusion/common/src/scalar.rs | 38 ++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 9a1119469a59b..fd6a248ee40d1 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -2295,7 +2295,7 @@ impl ScalarValue { /// Estimate size if bytes including `Self`. For values with internal containers such as `String` /// includes the allocated size (`capacity`) rather than the current length (`len`) pub fn size(&self) -> usize { - std::mem::size_of_val(&self) + std::mem::size_of_val(self) + match self { ScalarValue::Null | ScalarValue::Boolean(_) @@ -2362,7 +2362,8 @@ impl ScalarValue { /// /// Includes the size of the [`Vec`] container itself. 
pub fn size_of_vec(vec: &Vec) -> usize { - (std::mem::size_of::() * vec.capacity()) + std::mem::size_of_val(vec) + + (std::mem::size_of::() * vec.capacity()) + vec .iter() .map(|sv| sv.size() - std::mem::size_of_val(sv)) @@ -2373,7 +2374,8 @@ impl ScalarValue { /// /// Includes the size of the [`HashSet`] container itself. pub fn size_of_hashset(set: &HashSet) -> usize { - (std::mem::size_of::() * set.capacity()) + std::mem::size_of_val(set) + + (std::mem::size_of::() * set.capacity()) + set .iter() .map(|sv| sv.size() - std::mem::size_of_val(sv)) @@ -3279,6 +3281,36 @@ mod tests { assert_eq!(std::mem::size_of::(), 48); } + #[test] + fn memory_size() { + let sv = ScalarValue::Binary(Some(Vec::with_capacity(10))); + assert_eq!(sv.size(), std::mem::size_of::() + 10,); + let sv_size = sv.size(); + + let mut v = Vec::with_capacity(10); + // do NOT clone `sv` here because this may shrink the vector capacity + v.push(sv); + assert_eq!(v.capacity(), 10); + assert_eq!( + ScalarValue::size_of_vec(&v), + std::mem::size_of::>() + + (9 * std::mem::size_of::()) + + sv_size, + ); + + let mut s = HashSet::with_capacity(0); + // do NOT clone `sv` here because this may shrink the vector capacity + s.insert(v.pop().unwrap()); + // hashsets may easily grow during insert, so capacity is dynamic + let s_capacity = s.capacity(); + assert_eq!( + ScalarValue::size_of_hashset(&s), + std::mem::size_of::>() + + ((s_capacity - 1) * std::mem::size_of::()) + + sv_size, + ); + } + #[test] fn scalar_eq_array() { // Validate that eq_array has the same semantics as ScalarValue::eq From 386467347f79cca64c53c8a7557341cfbed1e9f2 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Mon, 28 Nov 2022 11:16:37 +0100 Subject: [PATCH 3/3] refactor: de-dup code --- .../core/src/physical_plan/aggregates/hash.rs | 28 ++++++++----------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/datafusion/core/src/physical_plan/aggregates/hash.rs b/datafusion/core/src/physical_plan/aggregates/hash.rs index 
e0edf5e259a62..d3d5a337e02fd 100644 --- a/datafusion/core/src/physical_plan/aggregates/hash.rs +++ b/datafusion/core/src/physical_plan/aggregates/hash.rs @@ -401,22 +401,18 @@ fn group_aggregate_batch( .collect::>(), ) }) - .try_for_each(|(accumulator, values)| match mode { - AggregateMode::Partial => { - let size_pre = accumulator.size(); - let res = accumulator.update_batch(&values); - let size_post = accumulator.size(); - allocated += size_post.saturating_sub(size_pre); - res - } - AggregateMode::FinalPartitioned | AggregateMode::Final => { - // note: the aggregation here is over states, not values, thus the merge - let size_pre = accumulator.size(); - let res = accumulator.merge_batch(&values); - let size_post = accumulator.size(); - allocated += size_post.saturating_sub(size_pre); - res - } + .try_for_each(|(accumulator, values)| { + let size_pre = accumulator.size(); + let res = match mode { + AggregateMode::Partial => accumulator.update_batch(&values), + AggregateMode::FinalPartitioned | AggregateMode::Final => { + // note: the aggregation here is over states, not values, thus the merge + accumulator.merge_batch(&values) + } + }; + let size_post = accumulator.size(); + allocated += size_post.saturating_sub(size_pre); + res }) // 2.5 .and({