From b89d603e9e1fc5b1df879374faa724a3a32eead7 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Fri, 11 Nov 2022 09:59:39 +0100 Subject: [PATCH 1/4] refactor: remove needless async --- datafusion/core/src/execution/memory_manager.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/datafusion/core/src/execution/memory_manager.rs b/datafusion/core/src/execution/memory_manager.rs index 48d4ca3c3d32a..e7148b06606c6 100644 --- a/datafusion/core/src/execution/memory_manager.rs +++ b/datafusion/core/src/execution/memory_manager.rs @@ -178,10 +178,8 @@ pub trait MemoryConsumer: Send + Sync { self.id(), ); - let can_grow_directly = self - .memory_manager() - .can_grow_directly(required, current) - .await; + let can_grow_directly = + self.memory_manager().can_grow_directly(required, current); if !can_grow_directly { debug!( "Failed to grow memory of {} directly from consumer {}, spilling first ...", @@ -334,7 +332,7 @@ impl MemoryManager { } /// Grow memory attempt from a consumer, return if we could grant that much to it - async fn can_grow_directly(&self, required: usize, current: usize) -> bool { + fn can_grow_directly(&self, required: usize, current: usize) -> bool { let num_rqt = self.requesters.lock().len(); let mut rqt_current_used = self.requesters_total.lock(); let mut rqt_max = self.max_mem_for_requesters(); From a0fc6cdd86aa630496a1fda4caad16e71aaca7b1 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Mon, 14 Nov 2022 11:36:31 +0100 Subject: [PATCH 2/4] feat: wire memory management into `GroupedHashAggregateStreamV2` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Most of it is refactoring to allow us to call the async memory subsystem while polling the stream. The actual memory accounting is rather easy (since it's only ever growing except when the stream is dropped). Helps with #3940. (not closing yet, also need to do V1) Performance Impact: ------------------- ```text ❯ cargo bench -p datafusion --bench aggregate_query_sql -- --baseline issue3940a-pre Finished bench [optimized] target(s) in 0.08s Running benches/aggregate_query_sql.rs (target/release/deps/aggregate_query_sql-e9e315ab7a06a262) aggregate_query_no_group_by 15 12 time: [654.77 µs 655.49 µs 656.29 µs] change: [-1.6711% -1.2910% -0.8435%] (p = 0.00 < 0.05) Change within noise threshold. Found 9 outliers among 100 measurements (9.00%) 1 (1.00%) low mild 5 (5.00%) high mild 3 (3.00%) high severe aggregate_query_no_group_by_min_max_f64 time: [579.93 µs 580.59 µs 581.27 µs] change: [-3.8985% -3.2219% -2.6198%] (p = 0.00 < 0.05) Performance has improved. Found 9 outliers among 100 measurements (9.00%) 1 (1.00%) low severe 3 (3.00%) low mild 1 (1.00%) high mild 4 (4.00%) high severe aggregate_query_no_group_by_count_distinct_wide time: [2.4610 ms 2.4801 ms 2.4990 ms] change: [-2.9300% -1.8414% -0.7493%] (p = 0.00 < 0.05) Change within noise threshold. Benchmarking aggregate_query_no_group_by_count_distinct_narrow: Warming up for 3.0000 s Warning: Unable to complete 100 samples in 5.0s. You may wish to increase target time to 8.4s, enable flat sampling, or reduce sample count to 50. aggregate_query_no_group_by_count_distinct_narrow time: [1.6578 ms 1.6661 ms 1.6743 ms] change: [-4.5391% -3.5033% -2.5050%] (p = 0.00 < 0.05) Performance has improved. Found 7 outliers among 100 measurements (7.00%) 1 (1.00%) low severe 2 (2.00%) low mild 2 (2.00%) high mild 2 (2.00%) high severe aggregate_query_group_by time: [2.1767 ms 2.2045 ms 2.2486 ms] change: [-4.1048% -2.5858% -0.3237%] (p = 0.00 < 0.05) Change within noise threshold. Found 1 outliers among 100 measurements (1.00%) 1 (1.00%) high severe Benchmarking aggregate_query_group_by_with_filter: Warming up for 3.0000 s Warning: Unable to complete 100 samples in 5.0s. You may wish to increase target time to 5.5s, enable flat sampling, or reduce sample count to 60. aggregate_query_group_by_with_filter time: [1.0916 ms 1.0927 ms 1.0941 ms] change: [-0.8524% -0.4230% -0.0724%] (p = 0.02 < 0.05) Change within noise threshold. Found 9 outliers among 100 measurements (9.00%) 2 (2.00%) low severe 1 (1.00%) low mild 4 (4.00%) high mild 2 (2.00%) high severe aggregate_query_group_by_u64 15 12 time: [2.2108 ms 2.2238 ms 2.2368 ms] change: [-4.2142% -3.2743% -2.3523%] (p = 0.00 < 0.05) Performance has improved. Benchmarking aggregate_query_group_by_with_filter_u64 15 12: Warming up for 3.0000 s Warning: Unable to complete 100 samples in 5.0s. You may wish to increase target time to 5.5s, enable flat sampling, or reduce sample count to 60. aggregate_query_group_by_with_filter_u64 15 12 time: [1.0922 ms 1.0931 ms 1.0940 ms] change: [-0.6872% -0.3192% +0.1193%] (p = 0.12 > 0.05) No change in performance detected. Found 7 outliers among 100 measurements (7.00%) 3 (3.00%) low mild 4 (4.00%) high severe aggregate_query_group_by_u64_multiple_keys time: [14.714 ms 15.023 ms 15.344 ms] change: [-5.8337% -2.7471% +0.2798%] (p = 0.09 > 0.05) No change in performance detected. aggregate_query_approx_percentile_cont_on_u64 time: [3.7776 ms 3.8049 ms 3.8329 ms] change: [-4.4977% -3.4230% -2.3282%] (p = 0.00 < 0.05) Performance has improved. Found 2 outliers among 100 measurements (2.00%) 2 (2.00%) high mild aggregate_query_approx_percentile_cont_on_f32 time: [3.1769 ms 3.1997 ms 3.2230 ms] change: [-4.4664% -3.2597% -2.0955%] (p = 0.00 < 0.05) Performance has improved. Found 1 outliers among 100 measurements (1.00%) 1 (1.00%) high mild ``` I think the mild improvements are either flux or due to the somewhat manual memory allocation pattern. --- .../core/src/physical_plan/aggregates/mod.rs | 66 +++- .../src/physical_plan/aggregates/row_hash.rs | 357 ++++++++++++++---- 2 files changed, 346 insertions(+), 77 deletions(-) diff --git a/datafusion/core/src/physical_plan/aggregates/mod.rs b/datafusion/core/src/physical_plan/aggregates/mod.rs index 43e75e3520109..6ce58592d83b5 100644 --- a/datafusion/core/src/physical_plan/aggregates/mod.rs +++ b/datafusion/core/src/physical_plan/aggregates/mod.rs @@ -348,7 +348,7 @@ impl ExecutionPlan for AggregateExec { context: Arc, ) -> Result { let batch_size = context.session_config().batch_size(); - let input = self.input.execute(partition, context)?; + let input = self.input.execute(partition, Arc::clone(&context))?; let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); @@ -369,6 +369,8 @@ impl ExecutionPlan for AggregateExec { input, baseline_metrics, batch_size, + context, + partition, )?)) } else { Ok(Box::pin(GroupedHashAggregateStream::new( @@ -689,7 +691,8 @@ fn evaluate_group_by( #[cfg(test)] mod tests { - use crate::execution::context::TaskContext; + use crate::execution::context::{SessionConfig, TaskContext}; + use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use crate::from_slice::FromSlice; use crate::physical_plan::aggregates::{ AggregateExec, AggregateMode, PhysicalGroupBy, @@ -700,7 +703,7 @@ mod tests { use crate::{assert_batches_sorted_eq, physical_plan::common}; use arrow::array::{Float64Array, UInt32Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; - use arrow::error::Result as ArrowResult; + use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_physical_expr::expressions::{lit, Count}; @@ -1081,6 +1084,63 @@ mod tests { check_grouping_sets(input).await } + #[tokio::test] + async fn test_oom() -> Result<()> { + let input: Arc = + Arc::new(TestYieldingExec { yield_first: true }); + let input_schema = input.schema(); + + let session_ctx = SessionContext::with_config_rt( + SessionConfig::default(), + Arc::new( + RuntimeEnv::new(RuntimeConfig::default().with_memory_limit(1, 1.0)) + .unwrap(), + ), + ); + let task_ctx = session_ctx.task_ctx(); + + let groups = PhysicalGroupBy { + expr: vec![(col("a", &input_schema)?, "a".to_string())], + null_expr: vec![], + groups: vec![vec![false]], + }; + + let aggregates: Vec> = vec![Arc::new(Avg::new( + col("b", &input_schema)?, + "AVG(b)".to_string(), + DataType::Float64, + ))]; + + let partial_aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + groups, + aggregates, + input, + input_schema.clone(), + )?); + + let err = common::collect(partial_aggregate.execute(0, task_ctx.clone())?) + .await + .unwrap_err(); + + // error root cause traversal is a bit complicated, see #4172. + if let DataFusionError::ArrowError(ArrowError::ExternalError(err)) = err { + if let Some(err) = err.downcast_ref::() { + assert!( + matches!(err, DataFusionError::ResourcesExhausted(_)), + "Wrong inner error type: {}", + err, + ); + } else { + panic!("Wrong arrow error type: {err}") + } + } else { + panic!("Wrong outer error type: {err}") + } + + Ok(()) + } + #[tokio::test] async fn test_drop_cancel_without_groups() -> Result<()> { let session_ctx = SessionContext::new(); diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs index aefc6571b068a..93297bdeac37b 100644 --- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs +++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs @@ -22,12 +22,14 @@ use std::task::{Context, Poll}; use std::vec; use ahash::RandomState; -use futures::{ - ready, - stream::{Stream, StreamExt}, -}; +use async_trait::async_trait; +use futures::stream::BoxStream; +use futures::stream::{Stream, StreamExt}; use crate::error::Result; +use crate::execution::context::TaskContext; +use crate::execution::memory_manager::ConsumerType; +use crate::execution::{MemoryConsumer, MemoryConsumerId, MemoryManager}; use crate::physical_plan::aggregates::{ evaluate_group_by, evaluate_many, group_schema, AccumulatorItemV2, AggregateMode, PhysicalGroupBy, @@ -45,7 +47,7 @@ use arrow::{ error::{ArrowError, Result as ArrowResult}, }; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; -use datafusion_common::ScalarValue; +use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_row::accessor::RowAccessor; use datafusion_row::layout::RowLayout; use datafusion_row::reader::{read_row, RowReader}; @@ -70,6 +72,16 @@ use hashbrown::raw::RawTable; /// [Compact]: datafusion_row::layout::RowType::Compact /// [WordAligned]: datafusion_row::layout::RowType::WordAligned pub(crate) struct GroupedHashAggregateStreamV2 { + stream: BoxStream<'static, ArrowResult>, + schema: SchemaRef, +} + +/// Actual implementation of [`GroupedHashAggregateStreamV2`]. +/// +/// This is wrapped into yet another struct because we need to interact with the async memory management subsystem +/// during poll. To have as little code "weirdness" as possible, we chose to just use [`BoxStream`] together with +/// [`futures::stream::unfold`]. The latter requires a state object, which is [`GroupedHashAggregateStreamV2Inner`]. +struct GroupedHashAggregateStreamV2Inner { schema: SchemaRef, input: SendableRecordBatchStream, mode: AggregateMode, @@ -102,6 +114,7 @@ fn aggr_state_schema(aggr_expr: &[Arc]) -> Result impl GroupedHashAggregateStreamV2 { /// Create a new GroupedRowHashAggregateStream + #[allow(clippy::too_many_arguments)] pub fn new( mode: AggregateMode, schema: SchemaRef, @@ -110,6 +123,8 @@ impl GroupedHashAggregateStreamV2 { input: SendableRecordBatchStream, baseline_metrics: BaselineMetrics, batch_size: usize, + context: Arc, + partition: usize, ) -> Result { let timer = baseline_metrics.elapsed_compute().timer(); @@ -125,10 +140,24 @@ impl GroupedHashAggregateStreamV2 { let aggr_schema = aggr_state_schema(&aggr_expr)?; let aggr_layout = Arc::new(RowLayout::new(&aggr_schema, RowType::WordAligned)); + + let aggr_state = AggregationState { + memory_consumer: AggregationStateMemoryConsumer { + id: MemoryConsumerId::new(partition), + memory_manager: Arc::clone(&context.runtime_env().memory_manager), + used: 0, + }, + map: RawTable::with_capacity(0), + group_states: Vec::with_capacity(0), + }; + context + .runtime_env() + .register_requester(aggr_state.memory_consumer.id()); + timer.done(); - Ok(Self { - schema, + let inner = GroupedHashAggregateStreamV2Inner { + schema: Arc::clone(&schema), mode, input, group_by, @@ -138,11 +167,78 @@ impl GroupedHashAggregateStreamV2 { aggr_layout, baseline_metrics, aggregate_expressions, - aggr_state: Default::default(), + aggr_state, random_state: Default::default(), batch_size, row_group_skip_position: 0, - }) + }; + + let stream = futures::stream::unfold(inner, |mut this| async move { + let elapsed_compute = this.baseline_metrics.elapsed_compute(); + + loop { + let result: ArrowResult> = + match this.input.next().await { + Some(Ok(batch)) => { + let timer = elapsed_compute.timer(); + let result = group_aggregate_batch( + &this.mode, + &this.random_state, + &this.group_by, + &mut this.accumulators, + &this.group_schema, + this.aggr_layout.clone(), + batch, + &mut this.aggr_state, + &this.aggregate_expressions, + ) + .await; + + timer.done(); + + match result { + Ok(_) => continue, + Err(e) => Err(ArrowError::ExternalError(Box::new(e))), + } + } + Some(Err(e)) => Err(e), + None => { + let timer = this.baseline_metrics.elapsed_compute().timer(); + let result = create_batch_from_map( + &this.mode, + &this.group_schema, + &this.aggr_schema, + this.batch_size, + this.row_group_skip_position, + &mut this.aggr_state, + &mut this.accumulators, + &this.schema, + ); + + timer.done(); + result + } + }; + + this.row_group_skip_position += this.batch_size; + match result { + Ok(Some(result)) => { + return Some(( + Ok(result.record_output(&this.baseline_metrics)), + this, + )); + } + Ok(None) => return None, + Err(error) => return Some((Err(error), this)), + } + } + }); + + // seems like some consumers call this stream even after it returned `None`, so let's fuse the stream. + let stream = stream.fuse(); + let stream = Box::pin(stream); + + Ok(Self { schema, stream }) } } @@ -154,63 +250,7 @@ impl Stream for GroupedHashAggregateStreamV2 { cx: &mut Context<'_>, ) -> Poll> { let this = &mut *self; - - let elapsed_compute = this.baseline_metrics.elapsed_compute(); - - loop { - let result: ArrowResult> = - match ready!(this.input.poll_next_unpin(cx)) { - Some(Ok(batch)) => { - let timer = elapsed_compute.timer(); - let result = group_aggregate_batch( - &this.mode, - &this.random_state, - &this.group_by, - &mut this.accumulators, - &this.group_schema, - this.aggr_layout.clone(), - batch, - &mut this.aggr_state, - &this.aggregate_expressions, - ); - - timer.done(); - - match result { - Ok(_) => continue, - Err(e) => Err(ArrowError::ExternalError(Box::new(e))), - } - } - Some(Err(e)) => Err(e), - None => { - let timer = this.baseline_metrics.elapsed_compute().timer(); - let result = create_batch_from_map( - &this.mode, - &this.group_schema, - &this.aggr_schema, - this.batch_size, - this.row_group_skip_position, - &mut this.aggr_state, - &mut this.accumulators, - &this.schema, - ); - - timer.done(); - result - } - }; - - this.row_group_skip_position += this.batch_size; - match result { - Ok(Some(result)) => { - return Poll::Ready(Some(Ok( - result.record_output(&this.baseline_metrics) - ))) - } - Ok(None) => return Poll::Ready(None), - Err(error) => return Poll::Ready(Some(Err(error))), - } - } + this.stream.poll_next_unpin(cx) } } @@ -222,7 +262,7 @@ impl RecordBatchStream for GroupedHashAggregateStreamV2 { /// TODO: Make this a member function of [`GroupedHashAggregateStreamV2`] #[allow(clippy::too_many_arguments)] -fn group_aggregate_batch( +async fn group_aggregate_batch( mode: &AggregateMode, random_state: &RandomState, grouping_set: &PhysicalGroupBy, @@ -236,6 +276,13 @@ fn group_aggregate_batch( // evaluate the grouping expressions let grouping_by_values = evaluate_group_by(grouping_set, &batch)?; + let AggregationState { + map, + group_states, + memory_consumer, + } = aggr_state; + let mut memory_pool = ShortLivedMemoryPool::new(memory_consumer); + for group_values in grouping_by_values { let group_rows: Vec> = create_group_rows(group_values, group_schema); @@ -256,8 +303,6 @@ fn group_aggregate_batch( create_row_hashes(&group_rows, random_state, &mut batch_hashes)?; for (row, hash) in batch_hashes.into_iter().enumerate() { - let AggregationState { map, group_states } = aggr_state; - let entry = map.get_mut(hash, |(_hash, group_idx)| { // verify that a group that we are inserting with hash is // actually the same key value as the group in @@ -270,10 +315,25 @@ fn group_aggregate_batch( // Existing entry for this group value Some((_hash, group_idx)) => { let group_state = &mut group_states[*group_idx]; + // 1.3 if group_state.indices.is_empty() { groups_with_rows.push(*group_idx); }; + + // ensure we have enough indices allocated + if group_state.indices.capacity() == group_state.indices.len() { + // allocate more + + // growth factor: 2, but at least 2 elements + let bump_elements = (group_state.indices.capacity() * 2).max(2); + let bump_size = std::mem::size_of::() * bump_elements; + + memory_pool.alloc(bump_size).await?; + + group_state.indices.reserve(bump_elements); + } + group_state.indices.push(row as u32); // remember this row } // 1.2 Need to create new entry @@ -285,11 +345,61 @@ fn group_aggregate_batch( indices: vec![row as u32], // 1.3 }; let group_idx = group_states.len(); + + // NOTE: do NOT include the `RowGroupState` struct size in here because this is captured by + // `group_states` (see allocation check down below) + let mut bump_size_total = (std::mem::size_of::() + * group_state.group_by_values.capacity()) + + (std::mem::size_of::() + * group_state.aggregation_buffer.capacity()) + + (std::mem::size_of::() * group_state.indices.capacity()); + + // ensure that `group_states` has enough space + let reserve_groups_states = + if group_states.capacity() == group_states.len() { + // growth factor: 2, but at least 16 elements + let bump_elements = (group_states.capacity() * 2).max(16); + let bump_size = + bump_elements * std::mem::size_of::(); + bump_size_total += bump_size; + + Some(bump_elements) + } else { + None + }; + + // for hasher function, use precomputed hash value + let reserve_map = + if map.try_insert_no_grow(hash, (hash, group_idx)).is_err() { + // need to request more memory + + let bump_elements = (map.capacity() * 2).max(16); + let bump_size = + bump_elements * std::mem::size_of::<(u64, usize)>(); + bump_size_total += bump_size; + + Some(bump_elements) + } else { + None + }; + + // allocate once + memory_pool.alloc(bump_size_total).await?; + + if let Some(bump_elements) = reserve_groups_states { + group_states.reserve(bump_elements); + } group_states.push(group_state); + groups_with_rows.push(group_idx); - // for hasher function, use precomputed hash value - map.insert(hash, (hash, group_idx), |(hash, _group_idx)| *hash); + if let Some(bump_elements) = reserve_map { + map.reserve(bump_elements, |(hash, _group_index)| *hash); + + // still need to insert the element since first try failed + map.try_insert_no_grow(hash, (hash, group_idx)) + .expect("just grew the container"); + } } }; } @@ -299,7 +409,7 @@ fn group_aggregate_batch( let mut offsets = vec![0]; let mut offset_so_far = 0; for group_idx in groups_with_rows.iter() { - let indices = &aggr_state.group_states[*group_idx].indices; + let indices = &group_states[*group_idx].indices; batch_indices.append_slice(indices); offset_so_far += indices.len(); offsets.push(offset_so_far); @@ -334,7 +444,7 @@ fn group_aggregate_batch( .iter() .zip(offsets.windows(2)) .try_for_each(|(group_idx, offsets)| { - let group_state = &mut aggr_state.group_states[*group_idx]; + let group_state = &mut group_states[*group_idx]; // 2.2 accumulators .iter_mut() @@ -392,8 +502,9 @@ struct RowGroupState { } /// The state of all the groups -#[derive(Default)] struct AggregationState { + memory_consumer: AggregationStateMemoryConsumer, + /// Logically maps group values to an index in `group_states` /// /// Uses the raw API of hashbrown to avoid actually storing the @@ -418,6 +529,104 @@ impl std::fmt::Debug for AggregationState { } } +/// Accounting data structure for memory usage. +struct AggregationStateMemoryConsumer { + /// Consumer ID. + id: MemoryConsumerId, + + /// Linked memory manager. + memory_manager: Arc, + + /// Currently used size in bytes. + used: usize, +} + +#[async_trait] +impl MemoryConsumer for AggregationStateMemoryConsumer { + fn name(&self) -> String { + "AggregationState".to_owned() + } + + fn id(&self) -> &crate::execution::MemoryConsumerId { + &self.id + } + + fn memory_manager(&self) -> Arc { + Arc::clone(&self.memory_manager) + } + + fn type_(&self) -> &ConsumerType { + &ConsumerType::Tracking + } + + async fn spill(&self) -> Result { + Err(DataFusionError::ResourcesExhausted( + "Cannot spill AggregationState".to_owned(), + )) + } + + fn mem_used(&self) -> usize { + self.used + } +} + +impl Drop for AggregationStateMemoryConsumer { + fn drop(&mut self) { + self.memory_manager + .drop_consumer(self.id(), self.mem_used()); + } +} + +/// Memory pool that can be used in a function scope. +/// +/// This is helpful if there are many small memory allocations (so the overhead if tracking them in [`MemoryManager`] is +/// high due to lock contention) and pre-calculating the entire allocation for a whole [`RecordBatch`] is complicated or +/// expensive. +/// +/// The pool will try to allocate a whole block of memory and gives back overallocated memory on [drop](Self::drop). +struct ShortLivedMemoryPool<'a> { + pool: &'a mut AggregationStateMemoryConsumer, + block_size: usize, + remaining: usize, +} + +impl<'a> ShortLivedMemoryPool<'a> { + fn new(pool: &'a mut AggregationStateMemoryConsumer) -> Self { + Self { + pool, + block_size: 1024 * 1024, // 1MB + remaining: 0, + } + } + + async fn alloc(&mut self, mut bytes: usize) -> Result<()> { + // are there enough bytes left within the current block? + if bytes <= self.remaining { + self.remaining -= bytes; + return Ok(()); + } + + // we can already use the remaining bytes from the current block + bytes -= self.remaining; + + // need to allocate a new block + let alloc_size = bytes.max(self.block_size); + self.pool.try_grow(alloc_size).await?; + self.pool.used += alloc_size; + self.remaining = alloc_size - bytes; + + Ok(()) + } +} + +impl<'a> Drop for ShortLivedMemoryPool<'a> { + fn drop(&mut self) { + // give back over-allocated memory + self.pool.shrink(self.remaining); + self.pool.used -= self.remaining; + } +} + /// Create grouping rows fn create_group_rows(arrays: Vec, schema: &Schema) -> Vec> { let mut writer = RowWriter::new(schema, RowType::Compact); From a3ab17b6e604990c952a5dc36947d833a9ec0172 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Wed, 16 Nov 2022 12:22:30 +0100 Subject: [PATCH 3/4] refactor: simplify memory accounting --- .../src/physical_plan/aggregates/row_hash.rs | 180 ++++++++---------- 1 file changed, 82 insertions(+), 98 deletions(-) diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs index 93297bdeac37b..b813f54f824a0 100644 --- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs +++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs @@ -53,7 +53,7 @@ use datafusion_row::layout::RowLayout; use datafusion_row::reader::{read_row, RowReader}; use datafusion_row::writer::{write_row, RowWriter}; use datafusion_row::{MutableRecordBatch, RowType}; -use hashbrown::raw::RawTable; +use hashbrown::raw::{Bucket, RawTable}; /// Grouping aggregate with row-format aggregation states inside. /// @@ -281,7 +281,7 @@ async fn group_aggregate_batch( group_states, memory_consumer, } = aggr_state; - let mut memory_pool = ShortLivedMemoryPool::new(memory_consumer); + let mut allocated = 0usize; for group_values in grouping_by_values { let group_rows: Vec> = create_group_rows(group_values, group_schema); @@ -321,20 +321,9 @@ async fn group_aggregate_batch( groups_with_rows.push(*group_idx); }; - // ensure we have enough indices allocated - if group_state.indices.capacity() == group_state.indices.len() { - // allocate more - - // growth factor: 2, but at least 2 elements - let bump_elements = (group_state.indices.capacity() * 2).max(2); - let bump_size = std::mem::size_of::() * bump_elements; - - memory_pool.alloc(bump_size).await?; - - group_state.indices.reserve(bump_elements); - } - - group_state.indices.push(row as u32); // remember this row + group_state + .indices + .push_accounted(row as u32, &mut allocated); // remember this row } // 1.2 Need to create new entry None => { @@ -347,63 +336,32 @@ async fn group_aggregate_batch( let group_idx = group_states.len(); // NOTE: do NOT include the `RowGroupState` struct size in here because this is captured by - // `group_states` (see allocation check down below) - let mut bump_size_total = (std::mem::size_of::() + // `group_states` (see allocation down below) + allocated += (std::mem::size_of::() * group_state.group_by_values.capacity()) + (std::mem::size_of::() * group_state.aggregation_buffer.capacity()) + (std::mem::size_of::() * group_state.indices.capacity()); - // ensure that `group_states` has enough space - let reserve_groups_states = - if group_states.capacity() == group_states.len() { - // growth factor: 2, but at least 16 elements - let bump_elements = (group_states.capacity() * 2).max(16); - let bump_size = - bump_elements * std::mem::size_of::(); - bump_size_total += bump_size; - - Some(bump_elements) - } else { - None - }; - // for hasher function, use precomputed hash value - let reserve_map = - if map.try_insert_no_grow(hash, (hash, group_idx)).is_err() { - // need to request more memory - - let bump_elements = (map.capacity() * 2).max(16); - let bump_size = - bump_elements * std::mem::size_of::<(u64, usize)>(); - bump_size_total += bump_size; - - Some(bump_elements) - } else { - None - }; + map.insert_accounted( + (hash, group_idx), + |(hash, _group_index)| *hash, + &mut allocated, + ); - // allocate once - memory_pool.alloc(bump_size_total).await?; - - if let Some(bump_elements) = reserve_groups_states { - group_states.reserve(bump_elements); - } - group_states.push(group_state); + group_states.push_accounted(group_state, &mut allocated); groups_with_rows.push(group_idx); - - if let Some(bump_elements) = reserve_map { - map.reserve(bump_elements, |(hash, _group_index)| *hash); - - // still need to insert the element since first try failed - map.try_insert_no_grow(hash, (hash, group_idx)) - .expect("just grew the container"); - } } }; } + // allocate memory + // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with + // overshooting a bit. Also this means we either store the whole record batch or not. + memory_consumer.alloc(allocated).await?; + // Collect all indices + offsets based on keys in this vec let mut batch_indices: UInt32Builder = UInt32Builder::with_capacity(0); let mut offsets = vec![0]; @@ -570,6 +528,14 @@ impl MemoryConsumer for AggregationStateMemoryConsumer { } } +impl AggregationStateMemoryConsumer { + async fn alloc(&mut self, bytes: usize) -> Result<()> { + self.try_grow(bytes).await?; + self.used = self.used.checked_add(bytes).expect("overflow"); + Ok(()) + } +} + impl Drop for AggregationStateMemoryConsumer { fn drop(&mut self) { self.memory_manager @@ -577,53 +543,71 @@ impl Drop for AggregationStateMemoryConsumer { } } -/// Memory pool that can be used in a function scope. -/// -/// This is helpful if there are many small memory allocations (so the overhead if tracking them in [`MemoryManager`] is -/// high due to lock contention) and pre-calculating the entire allocation for a whole [`RecordBatch`] is complicated or -/// expensive. -/// -/// The pool will try to allocate a whole block of memory and gives back overallocated memory on [drop](Self::drop). -struct ShortLivedMemoryPool<'a> { - pool: &'a mut AggregationStateMemoryConsumer, - block_size: usize, - remaining: usize, +trait VecAllocExt { + type T; + + fn push_accounted(&mut self, x: Self::T, accounting: &mut usize); } -impl<'a> ShortLivedMemoryPool<'a> { - fn new(pool: &'a mut AggregationStateMemoryConsumer) -> Self { - Self { - pool, - block_size: 1024 * 1024, // 1MB - remaining: 0, - } - } +impl VecAllocExt for Vec { + type T = T; + + fn push_accounted(&mut self, x: Self::T, accounting: &mut usize) { + if self.capacity() == self.len() { + // allocate more - async fn alloc(&mut self, mut bytes: usize) -> Result<()> { - // are there enough bytes left within the current block? - if bytes <= self.remaining { - self.remaining -= bytes; - return Ok(()); + // growth factor: 2, but at least 2 elements + let bump_elements = (self.capacity() * 2).max(2); + let bump_size = std::mem::size_of::() * bump_elements; + self.reserve(bump_elements); + *accounting = (*accounting).checked_add(bump_size).expect("overflow"); } - // we can already use the remaining bytes from the current block - bytes -= self.remaining; + self.push(x); + } +} - // need to allocate a new block - let alloc_size = bytes.max(self.block_size); - self.pool.try_grow(alloc_size).await?; - self.pool.used += alloc_size; - self.remaining = alloc_size - bytes; +trait RawTableAllocExt { + type T; - Ok(()) - } + fn insert_accounted( + &mut self, + x: Self::T, + hasher: impl Fn(&Self::T) -> u64, + accounting: &mut usize, + ) -> Bucket; } -impl<'a> Drop for ShortLivedMemoryPool<'a> { - fn drop(&mut self) { - // give back over-allocated memory - self.pool.shrink(self.remaining); - self.pool.used -= self.remaining; +impl RawTableAllocExt for RawTable { + type T = T; + + fn insert_accounted( + &mut self, + x: Self::T, + hasher: impl Fn(&Self::T) -> u64, + accounting: &mut usize, + ) -> Bucket { + let hash = hasher(&x); + + match self.try_insert_no_grow(hash, x) { + Ok(bucket) => bucket, + Err(x) => { + // need to request more memory + + let bump_elements = (self.capacity() * 2).max(16); + let bump_size = bump_elements * std::mem::size_of::(); + *accounting = (*accounting).checked_add(bump_size).expect("overflow"); + + self.reserve(bump_elements, hasher); + + // still need to insert the element since first try failed + // Note: cannot use `.expect` here because `T` may not implement `Debug` + match self.try_insert_no_grow(hash, x) { + Ok(bucket) => bucket, + Err(_) => panic!("just grew the container"), + } + } + } } } From abebcedebf79e9bdbcd7b70a6c64d40993ed6ff0 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Thu, 17 Nov 2022 11:00:18 +0100 Subject: [PATCH 4/4] refactor: de-couple memory allocation --- .../src/physical_plan/aggregates/row_hash.rs | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/datafusion/core/src/physical_plan/aggregates/row_hash.rs b/datafusion/core/src/physical_plan/aggregates/row_hash.rs index b813f54f824a0..c6658b2a6ee54 100644 --- a/datafusion/core/src/physical_plan/aggregates/row_hash.rs +++ b/datafusion/core/src/physical_plan/aggregates/row_hash.rs @@ -191,13 +191,22 @@ impl GroupedHashAggregateStreamV2 { batch, &mut this.aggr_state, &this.aggregate_expressions, - ) - .await; + ); timer.done(); + // allocate memory + // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with + // overshooting a bit. Also this means we either store the whole record batch or not. + let result = match result { + Ok(allocated) => { + this.aggr_state.memory_consumer.alloc(allocated).await + } + Err(e) => Err(e), + }; + match result { - Ok(_) => continue, + Ok(()) => continue, Err(e) => Err(ArrowError::ExternalError(Box::new(e))), } } @@ -260,9 +269,13 @@ impl RecordBatchStream for GroupedHashAggregateStreamV2 { } } +/// Perform group-by aggregation for the given [`RecordBatch`]. +/// +/// If successfull, this returns the additional number of bytes that were allocated during this process. +/// /// TODO: Make this a member function of [`GroupedHashAggregateStreamV2`] #[allow(clippy::too_many_arguments)] -async fn group_aggregate_batch( +fn group_aggregate_batch( mode: &AggregateMode, random_state: &RandomState, grouping_set: &PhysicalGroupBy, @@ -272,14 +285,12 @@ async fn group_aggregate_batch( batch: RecordBatch, aggr_state: &mut AggregationState, aggregate_expressions: &[Vec>], -) -> Result<()> { +) -> Result { // evaluate the grouping expressions let grouping_by_values = evaluate_group_by(grouping_set, &batch)?; let AggregationState { - map, - group_states, - memory_consumer, + map, group_states, .. } = aggr_state; let mut allocated = 0usize; @@ -357,11 +368,6 @@ async fn group_aggregate_batch( }; } - // allocate memory - // This happens AFTER we actually used the memory, but simplifies the whole accounting and we are OK with - // overshooting a bit. Also this means we either store the whole record batch or not. - memory_consumer.alloc(allocated).await?; - // Collect all indices + offsets based on keys in this vec let mut batch_indices: UInt32Builder = UInt32Builder::with_capacity(0); let mut offsets = vec![0]; @@ -442,7 +448,7 @@ async fn group_aggregate_batch( })?; } - Ok(()) + Ok(allocated) } /// The state that is built for each output group.