diff --git a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs index 8fa5145938c4a..50179896f648c 100644 --- a/datafusion/core/src/physical_plan/joins/sort_merge_join.rs +++ b/datafusion/core/src/physical_plan/joins/sort_merge_join.rs @@ -24,6 +24,7 @@ use std::any::Any; use std::cmp::Ordering; use std::collections::VecDeque; use std::fmt::Formatter; +use std::mem; use std::ops::Range; use std::pin::Pin; use std::sync::Arc; @@ -39,6 +40,7 @@ use futures::{Stream, StreamExt}; use crate::error::DataFusionError; use crate::error::Result; use crate::execution::context::TaskContext; +use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation}; use crate::logical_expr::JoinType; use crate::physical_plan::expressions::Column; use crate::physical_plan::expressions::PhysicalSortExpr; @@ -305,6 +307,10 @@ impl ExecutionPlan for SortMergeJoinExec { // create output buffer let batch_size = context.session_config().batch_size(); + // create memory reservation + let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]")) + .register(context.memory_pool()); + // create join stream Ok(Box::pin(SMJStream::try_new( self.schema.clone(), @@ -317,6 +323,7 @@ impl ExecutionPlan for SortMergeJoinExec { self.join_type, batch_size, SortMergeJoinMetrics::new(partition, &self.metrics), + reservation, )?)) } @@ -362,6 +369,9 @@ struct SortMergeJoinMetrics { output_batches: metrics::Count, /// Number of rows produced by this operator output_rows: metrics::Count, + /// Peak memory used for buffered data. 
+ /// Calculated as sum of peak memory values across partitions + peak_mem_used: metrics::Gauge, } impl SortMergeJoinMetrics { @@ -374,6 +384,7 @@ impl SortMergeJoinMetrics { let output_batches = MetricBuilder::new(metrics).counter("output_batches", partition); let output_rows = MetricBuilder::new(metrics).output_rows(partition); + let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", partition); Self { join_time, @@ -381,6 +392,7 @@ impl SortMergeJoinMetrics { input_rows, output_batches, output_rows, + peak_mem_used, } } } @@ -505,15 +517,34 @@ struct BufferedBatch { pub join_arrays: Vec<ArrayRef>, /// Buffered joined index (null joining buffered) pub null_joined: Vec<usize>, + /// Size estimation used for reserving / releasing memory + pub size_estimation: usize, } impl BufferedBatch { fn new(batch: RecordBatch, range: Range<usize>, on_column: &[Column]) -> Self { let join_arrays = join_arrays(&batch, on_column); + + // Estimation is calculated as + // inner batch size + // + join keys size + // + worst case null_joined (as vector capacity * element size) + // + Range size + // + size of this estimation + let size_estimation = batch.get_array_memory_size() + + join_arrays + .iter() + .map(|arr| arr.get_array_memory_size()) + .sum::<usize>() + + batch.num_rows().next_power_of_two() * mem::size_of::<usize>() + + mem::size_of::<Range<usize>>() + + mem::size_of::<usize>(); + BufferedBatch { batch, range, join_arrays, null_joined: vec![], + size_estimation, } } } @@ -565,6 +596,8 @@ struct SMJStream { pub join_type: JoinType, /// Metrics pub join_metrics: SortMergeJoinMetrics, + /// Memory reservation + pub reservation: MemoryReservation, } impl RecordBatchStream for SMJStream { @@ -682,6 +715,7 @@ impl SMJStream { join_type: JoinType, batch_size: usize, join_metrics: SortMergeJoinMetrics, + reservation: MemoryReservation, ) -> Result<Self> { let streamed_schema = streamed.schema(); let buffered_schema = buffered.schema(); @@ -708,6 +742,7 @@ impl SMJStream { batch_size, join_type, join_metrics, + reservation, }) 
} @@ -763,7 +798,11 @@ impl SMJStream { let head_batch = self.buffered_data.head_batch(); if head_batch.range.end == head_batch.batch.num_rows() { self.freeze_dequeuing_buffered()?; - self.buffered_data.batches.pop_front(); + if let Some(buffered_batch) = + self.buffered_data.batches.pop_front() + { + self.reservation.shrink(buffered_batch.size_estimation); + } } else { break; } @@ -789,11 +828,14 @@ impl SMJStream { self.join_metrics.input_batches.add(1); self.join_metrics.input_rows.add(batch.num_rows()); if batch.num_rows() > 0 { - self.buffered_data.batches.push_back(BufferedBatch::new( - batch, - 0..1, - &self.on_buffered, - )); + let buffered_batch = + BufferedBatch::new(batch, 0..1, &self.on_buffered); + self.reservation.try_grow(buffered_batch.size_estimation)?; + self.join_metrics + .peak_mem_used + .set_max(self.reservation.size()); + + self.buffered_data.batches.push_back(buffered_batch); self.buffered_state = BufferedState::PollingRest; } } @@ -827,15 +869,19 @@ impl SMJStream { } Poll::Ready(Some(batch)) => { self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); if batch.num_rows() > 0 { - self.join_metrics.input_rows.add(batch.num_rows()); - self.buffered_data.batches.push_back( - BufferedBatch::new( - batch, - 0..0, - &self.on_buffered, - ), + let buffered_batch = BufferedBatch::new( + batch, + 0..0, + &self.on_buffered, ); + self.reservation + .try_grow(buffered_batch.size_estimation)?; + self.join_metrics + .peak_mem_used + .set_max(self.reservation.size()); + self.buffered_data.batches.push_back(buffered_batch); } } } @@ -1315,7 +1361,9 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; + use crate::common::assert_contains; use crate::error::Result; + use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use crate::logical_expr::JoinType; use crate::physical_plan::expressions::Column; use crate::physical_plan::joins::utils::JoinOn; @@ -2212,4 
+2260,135 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); Ok(()) } + + #[tokio::test] + async fn overallocation_single_batch() -> Result<()> { + let left = build_table( + ("a1", &vec![0, 1, 2, 3, 4, 5]), + ("b1", &vec![1, 2, 3, 4, 5, 6]), + ("c1", &vec![4, 5, 6, 7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![0, 10, 20, 30, 40]), + ("b2", &vec![1, 3, 4, 6, 8]), + ("c2", &vec![50, 60, 70, 80, 90]), + ); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + ]; + + for join_type in join_types { + let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let session_config = SessionConfig::default().with_batch_size(50); + let session_ctx = SessionContext::with_config_rt(session_config, runtime); + let task_ctx = session_ctx.task_ctx(); + let join = join_with_options( + left.clone(), + right.clone(), + on.clone(), + join_type, + sort_options.clone(), + false, + )?; + + let stream = join.execute(0, task_ctx)?; + let err = common::collect(stream).await.unwrap_err(); + + assert_contains!( + err.to_string(), + "Resources exhausted: Failed to allocate additional" + ); + assert_contains!(err.to_string(), "SMJStream[0]"); + } + + Ok(()) + } + + #[tokio::test] + async fn overallocation_multi_batch() -> Result<()> { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1]), + ("b1", &vec![1, 1]), + ("c1", &vec![4, 5]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![2, 3]), + ("b1", &vec![1, 1]), + ("c1", &vec![6, 7]), + ); + let left_batch_3 = build_table_i32( + ("a1", &vec![4, 5]), + ("b1", &vec![1, 1]), + ("c1", &vec![8, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10]), 
+ ("b2", &vec![1, 1]), + ("c2", &vec![50, 60]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![20, 30]), + ("b2", &vec![1, 1]), + ("c2", &vec![70, 80]), + ); + let right_batch_3 = + build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90])); + let left = + build_table_from_batches(vec![left_batch_1, left_batch_2, left_batch_3]); + let right = + build_table_from_batches(vec![right_batch_1, right_batch_2, right_batch_3]); + let on = vec![( + Column::new_with_schema("b1", &left.schema())?, + Column::new_with_schema("b2", &right.schema())?, + )]; + let sort_options = vec![SortOptions::default(); on.len()]; + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + ]; + + for join_type in join_types { + let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0); + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); + let session_config = SessionConfig::default().with_batch_size(50); + let session_ctx = SessionContext::with_config_rt(session_config, runtime); + let task_ctx = session_ctx.task_ctx(); + let join = join_with_options( + left.clone(), + right.clone(), + on.clone(), + join_type, + sort_options.clone(), + false, + )?; + + let stream = join.execute(0, task_ctx)?; + let err = common::collect(stream).await.unwrap_err(); + + assert_contains!( + err.to_string(), + "Resources exhausted: Failed to allocate additional" + ); + assert_contains!(err.to_string(), "SMJStream[0]"); + } + + Ok(()) + } } diff --git a/datafusion/core/src/physical_plan/metrics/value.rs b/datafusion/core/src/physical_plan/metrics/value.rs index 4df4e75675361..59b012f25a27d 100644 --- a/datafusion/core/src/physical_plan/metrics/value.rs +++ b/datafusion/core/src/physical_plan/metrics/value.rs @@ -129,6 +129,11 @@ impl Gauge { self.value.fetch_sub(n, Ordering::Relaxed); } + /// Set metric's value to maximum of `n` and current value + pub fn set_max(&self, n: usize) { + 
self.value.fetch_max(n, Ordering::Relaxed); + } + /// Set the metric's value to `n` and return the previous value pub fn set(&self, n: usize) -> usize { // relaxed ordering for operations on `value` poses no issues diff --git a/datafusion/core/tests/memory_limit.rs b/datafusion/core/tests/memory_limit.rs index 9de8aff7767b9..034505f52c966 100644 --- a/datafusion/core/tests/memory_limit.rs +++ b/datafusion/core/tests/memory_limit.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use datafusion::datasource::MemTable; +use datafusion::execution::context::SessionState; use datafusion::execution::disk_manager::DiskManagerConfig; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion_common::assert_contains; @@ -37,7 +38,9 @@ fn init() { async fn oom_sort() { run_limit_test( "select * from t order by host DESC", - "Resources exhausted: Memory Exhausted while Sorting (DiskManager is disabled)", + vec![ + "Resources exhausted: Memory Exhausted while Sorting (DiskManager is disabled)", + ], 200_000, ) .await @@ -47,7 +50,10 @@ async fn oom_sort() { async fn group_by_none() { run_limit_test( "select median(image) from t", - "Resources exhausted: Failed to allocate additional", + vec![ + "Resources exhausted: Failed to allocate additional", + "AggregateStream", + ], 20_000, ) .await @@ -57,7 +63,10 @@ async fn group_by_none() { async fn group_by_row_hash() { run_limit_test( "select count(*) from t GROUP BY response_bytes", - "Resources exhausted: Failed to allocate additional", + vec![ + "Resources exhausted: Failed to allocate additional", + "GroupedHashAggregateStream", + ], 2_000, ) .await @@ -68,18 +77,41 @@ async fn group_by_hash() { run_limit_test( // group by dict column "select count(*) from t GROUP BY service, host, pod, container", - "Resources exhausted: Failed to allocate additional", + vec![ + "Resources exhausted: Failed to allocate additional", + "GroupedHashAggregateStream", + ], 1_000, ) .await } #[tokio::test] -async fn join_by_key() { - 
run_limit_test( +async fn join_by_key_multiple_partitions() { + let config = SessionConfig::new().with_target_partitions(2); + run_limit_test_with_config( "select t1.* from t t1 JOIN t t2 ON t1.service = t2.service", - "Resources exhausted: Failed to allocate additional", + vec![ + "Resources exhausted: Failed to allocate additional", + "HashJoinStream", + ], 1_000, + config, + ) + .await +} + +#[tokio::test] +async fn join_by_key_single_partition() { + let config = SessionConfig::new().with_target_partitions(1); + run_limit_test_with_config( + "select t1.* from t t1 JOIN t t2 ON t1.service = t2.service", + vec![ + "Resources exhausted: Failed to allocate additional", + "HashJoinExec", + ], + 1_000, + config, ) .await } @@ -88,7 +120,10 @@ async fn join_by_key() { async fn join_by_expression() { run_limit_test( "select t1.* from t t1 JOIN t t2 ON t1.service != t2.service", - "Resources exhausted: Failed to allocate additional", + vec![ + "Resources exhausted: Failed to allocate additional", + "NestedLoopJoinExec", + ], 1_000, ) .await @@ -98,8 +133,30 @@ async fn join_by_expression() { async fn cross_join() { run_limit_test( "select t1.* from t t1 CROSS JOIN t t2", - "Resources exhausted: Failed to allocate additional", + vec![ + "Resources exhausted: Failed to allocate additional", + "CrossJoinExec", + ], + 1_000, + ) + .await +} + +#[tokio::test] +async fn merge_join() { + // Planner chooses MergeJoin only if number of partitions > 1 + let config = SessionConfig::new() + .with_target_partitions(2) + .set_bool("datafusion.optimizer.prefer_hash_join", false); + + run_limit_test_with_config( + "select t1.* from t t1 JOIN t t2 ON t1.pod = t2.pod AND t1.time = t2.time", + vec![ + "Resources exhausted: Failed to allocate additional", + "SMJStream", + ], 1_000, + config, ) .await } @@ -108,8 +165,26 @@ async fn cross_join() { const MEMORY_FRACTION: f64 = 0.95; /// runs the specified query against 1000 rows with a 50 -/// byte memory limit and no disk manager enabled. 
-async fn run_limit_test(query: &str, expected_error: &str, memory_limit: usize) { +/// byte memory limit and no disk manager enabled +/// with default SessionConfig. +async fn run_limit_test( + query: &str, + expected_error_contains: Vec<&str>, + memory_limit: usize, +) { + let config = SessionConfig::new(); + run_limit_test_with_config(query, expected_error_contains, memory_limit, config).await +} + +/// runs the specified query against 1000 rows with the given +/// `memory_limit` and no disk manager enabled +/// with the specified SessionConfig instance +async fn run_limit_test_with_config( + query: &str, + expected_error_contains: Vec<&str>, + memory_limit: usize, + config: SessionConfig, +) { let batches: Vec<_> = AccessLogGenerator::new() .with_row_limit(1000) .with_max_batch_size(50) @@ -125,11 +200,12 @@ async fn run_limit_test(query: &str, expected_error: &str, memory_limit: usize) let runtime = RuntimeEnv::new(rt_config).unwrap(); - let ctx = SessionContext::with_config_rt( - // do NOT re-partition (since RepartitionExec has also has a memory budget which we'll likely hit first) - SessionConfig::new().with_target_partitions(1), - Arc::new(runtime), - ); + // Disabling physical optimizer rules to avoid sorts / repartitions + // (since RepartitionExec / SortExec also have a memory budget which we'll likely hit first) + let state = SessionState::with_config_rt(config, Arc::new(runtime)) + .with_physical_optimizer_rules(vec![]); + + let ctx = SessionContext::with_state(state); ctx.register_table("t", Arc::new(table)) .expect("registering table"); @@ -140,7 +216,9 @@ async fn run_limit_test(query: &str, expected_error: &str, memory_limit: usize) panic!("Unexpected success when running, expected memory limit failure") } Err(e) => { - assert_contains!(e.to_string(), expected_error); + for error_substring in expected_error_contains { + assert_contains!(e.to_string(), error_substring); + } } } }