Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 192 additions & 13 deletions datafusion/core/src/physical_plan/joins/sort_merge_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use std::any::Any;
use std::cmp::Ordering;
use std::collections::VecDeque;
use std::fmt::Formatter;
use std::mem;
use std::ops::Range;
use std::pin::Pin;
use std::sync::Arc;
Expand All @@ -39,6 +40,7 @@ use futures::{Stream, StreamExt};
use crate::error::DataFusionError;
use crate::error::Result;
use crate::execution::context::TaskContext;
use crate::execution::memory_pool::{MemoryConsumer, MemoryReservation};
use crate::logical_expr::JoinType;
use crate::physical_plan::expressions::Column;
use crate::physical_plan::expressions::PhysicalSortExpr;
Expand Down Expand Up @@ -305,6 +307,10 @@ impl ExecutionPlan for SortMergeJoinExec {
// create output buffer
let batch_size = context.session_config().batch_size();

// create memory reservation
let reservation = MemoryConsumer::new(format!("SMJStream[{partition}]"))
.register(context.memory_pool());

// create join stream
Ok(Box::pin(SMJStream::try_new(
self.schema.clone(),
Expand All @@ -317,6 +323,7 @@ impl ExecutionPlan for SortMergeJoinExec {
self.join_type,
batch_size,
SortMergeJoinMetrics::new(partition, &self.metrics),
reservation,
)?))
}

Expand Down Expand Up @@ -362,6 +369,9 @@ struct SortMergeJoinMetrics {
output_batches: metrics::Count,
/// Number of rows produced by this operator
output_rows: metrics::Count,
/// Peak memory used for buffered data.
/// Calculated as sum of peak memory values across partitions
peak_mem_used: metrics::Gauge,
}

impl SortMergeJoinMetrics {
Expand All @@ -374,13 +384,15 @@ impl SortMergeJoinMetrics {
let output_batches =
MetricBuilder::new(metrics).counter("output_batches", partition);
let output_rows = MetricBuilder::new(metrics).output_rows(partition);
let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", partition);

Self {
join_time,
input_batches,
input_rows,
output_batches,
output_rows,
peak_mem_used,
}
}
}
Expand Down Expand Up @@ -505,15 +517,34 @@ struct BufferedBatch {
pub join_arrays: Vec<ArrayRef>,
/// Buffered joined index (null joining buffered)
pub null_joined: Vec<usize>,
/// Size estimation used for reserving / releasing memory
pub size_estimation: usize,
}
impl BufferedBatch {
    /// Wraps `batch` together with its evaluated join-key arrays and a
    /// memory size estimate.
    ///
    /// * `batch` - the buffered-side record batch to hold on to
    /// * `range` - the initial row range stored alongside the batch
    /// * `on_column` - join key columns passed to `join_arrays`
    ///
    /// The returned value starts with an empty `null_joined` list. Its
    /// `size_estimation` is the figure used for reserving / releasing memory.
    fn new(batch: RecordBatch, range: Range<usize>, on_column: &[Column]) -> Self {
        let key_arrays = join_arrays(&batch, on_column);

        // Memory reported by the batch's own arrays.
        let batch_bytes = batch.get_array_memory_size();

        // Memory reported by the join key arrays.
        let key_bytes = key_arrays
            .iter()
            .map(|arr| arr.get_array_memory_size())
            .sum::<usize>();

        // Worst case for `null_joined`: one index per row, with the row
        // count rounded up to a power of two to stand in for vector capacity.
        let null_joined_bytes =
            batch.num_rows().next_power_of_two() * mem::size_of::<usize>();

        // Fixed-size bookkeeping: the `range` field and the estimate itself.
        let range_bytes = mem::size_of::<Range<usize>>();
        let estimate_bytes = mem::size_of::<usize>();

        let size_estimation =
            batch_bytes + key_bytes + null_joined_bytes + range_bytes + estimate_bytes;

        Self {
            batch,
            range,
            join_arrays: key_arrays,
            null_joined: Vec::new(),
            size_estimation,
        }
    }
}
Expand Down Expand Up @@ -565,6 +596,8 @@ struct SMJStream {
pub join_type: JoinType,
/// Metrics
pub join_metrics: SortMergeJoinMetrics,
/// Memory reservation
pub reservation: MemoryReservation,
}

impl RecordBatchStream for SMJStream {
Expand Down Expand Up @@ -682,6 +715,7 @@ impl SMJStream {
join_type: JoinType,
batch_size: usize,
join_metrics: SortMergeJoinMetrics,
reservation: MemoryReservation,
) -> Result<Self> {
let streamed_schema = streamed.schema();
let buffered_schema = buffered.schema();
Expand All @@ -708,6 +742,7 @@ impl SMJStream {
batch_size,
join_type,
join_metrics,
reservation,
})
}

Expand Down Expand Up @@ -763,7 +798,11 @@ impl SMJStream {
let head_batch = self.buffered_data.head_batch();
if head_batch.range.end == head_batch.batch.num_rows() {
self.freeze_dequeuing_buffered()?;
self.buffered_data.batches.pop_front();
if let Some(buffered_batch) =
self.buffered_data.batches.pop_front()
{
self.reservation.shrink(buffered_batch.size_estimation);
}
} else {
break;
}
Expand All @@ -789,11 +828,14 @@ impl SMJStream {
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
if batch.num_rows() > 0 {
self.buffered_data.batches.push_back(BufferedBatch::new(
batch,
0..1,
&self.on_buffered,
));
let buffered_batch =
BufferedBatch::new(batch, 0..1, &self.on_buffered);
self.reservation.try_grow(buffered_batch.size_estimation)?;
self.join_metrics
.peak_mem_used
.set_max(self.reservation.size());

self.buffered_data.batches.push_back(buffered_batch);
self.buffered_state = BufferedState::PollingRest;
}
}
Expand Down Expand Up @@ -827,15 +869,19 @@ impl SMJStream {
}
Poll::Ready(Some(batch)) => {
self.join_metrics.input_batches.add(1);
self.join_metrics.input_rows.add(batch.num_rows());
if batch.num_rows() > 0 {
self.join_metrics.input_rows.add(batch.num_rows());
self.buffered_data.batches.push_back(
BufferedBatch::new(
batch,
0..0,
&self.on_buffered,
),
let buffered_batch = BufferedBatch::new(
batch,
0..0,
&self.on_buffered,
);
self.reservation
.try_grow(buffered_batch.size_estimation)?;
self.join_metrics
.peak_mem_used
.set_max(self.reservation.size());
self.buffered_data.batches.push_back(buffered_batch);
}
}
}
Expand Down Expand Up @@ -1315,7 +1361,9 @@ mod tests {
use arrow::datatypes::{DataType, Field, Schema};
use arrow::record_batch::RecordBatch;

use crate::common::assert_contains;
use crate::error::Result;
use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use crate::logical_expr::JoinType;
use crate::physical_plan::expressions::Column;
use crate::physical_plan::joins::utils::JoinOn;
Expand Down Expand Up @@ -2212,4 +2260,135 @@ mod tests {
assert_batches_sorted_eq!(expected, &batches);
Ok(())
}

// Runs a sort-merge join over two single-batch inputs with the runtime memory
// limit set via `with_memory_limit(100, 1.0)`, once per join type listed below.
// Each run is expected to fail while collecting the stream, and the error text
// must mention both the allocation failure and the `SMJStream[0]` consumer.
#[tokio::test]
async fn overallocation_single_batch() -> Result<()> {
let left = build_table(
("a1", &vec![0, 1, 2, 3, 4, 5]),
("b1", &vec![1, 2, 3, 4, 5, 6]),
("c1", &vec![4, 5, 6, 7, 8, 9]),
);
let right = build_table(
("a2", &vec![0, 10, 20, 30, 40]),
("b2", &vec![1, 3, 4, 6, 8]),
("c2", &vec![50, 60, 70, 80, 90]),
);
// Join on b1 = b2.
let on = vec![(
Column::new_with_schema("b1", &left.schema())?,
Column::new_with_schema("b2", &right.schema())?,
)];
// One default sort option per join key.
let sort_options = vec![SortOptions::default(); on.len()];

let join_types = vec![
JoinType::Inner,
JoinType::Left,
JoinType::Right,
JoinType::Full,
JoinType::LeftSemi,
JoinType::LeftAnti,
];

for join_type in join_types {
// A fresh runtime, session and task context per join type, so every run
// starts from its own memory pool configuration.
let runtime_config = RuntimeConfig::new().with_memory_limit(100, 1.0);
let runtime = Arc::new(RuntimeEnv::new(runtime_config)?);
let session_config = SessionConfig::default().with_batch_size(50);
let session_ctx = SessionContext::with_config_rt(session_config, runtime);
let task_ctx = session_ctx.task_ctx();
let join = join_with_options(
left.clone(),
right.clone(),
on.clone(),
join_type,
sort_options.clone(),
false,
)?;

// Collecting partition 0 must return an error; `unwrap_err` panics otherwise.
let stream = join.execute(0, task_ctx)?;
let err = common::collect(stream).await.unwrap_err();

assert_contains!(
err.to_string(),
"Resources exhausted: Failed to allocate additional"
// NOTE(review): the lines from "Copy link" through "Thank you for the unit test"
// appear to be review-page text captured inside this macro call, not Rust source.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the unit test

);
// The error must name the memory consumer registered for partition 0.
assert_contains!(err.to_string(), "SMJStream[0]");
}

Ok(())
}

/// Runs a sort-merge join over inputs split into several small batches, all
/// sharing join key value 1, with the runtime memory limit set via
/// `with_memory_limit(100, 1.0)`. For every join type listed, collecting the
/// stream must fail with an allocation error that names `SMJStream[0]`.
#[tokio::test]
async fn overallocation_multi_batch() -> Result<()> {
    // Left side: three two-row batches.
    let left = build_table_from_batches(vec![
        build_table_i32(("a1", &vec![0, 1]), ("b1", &vec![1, 1]), ("c1", &vec![4, 5])),
        build_table_i32(("a1", &vec![2, 3]), ("b1", &vec![1, 1]), ("c1", &vec![6, 7])),
        build_table_i32(("a1", &vec![4, 5]), ("b1", &vec![1, 1]), ("c1", &vec![8, 9])),
    ]);
    // Right side: two two-row batches followed by a single-row batch.
    let right = build_table_from_batches(vec![
        build_table_i32(("a2", &vec![0, 10]), ("b2", &vec![1, 1]), ("c2", &vec![50, 60])),
        build_table_i32(("a2", &vec![20, 30]), ("b2", &vec![1, 1]), ("c2", &vec![70, 80])),
        build_table_i32(("a2", &vec![40]), ("b2", &vec![1]), ("c2", &vec![90])),
    ]);

    // Join on b1 = b2, with one default sort option per key.
    let join_keys = vec![(
        Column::new_with_schema("b1", &left.schema())?,
        Column::new_with_schema("b2", &right.schema())?,
    )];
    let sort_options = vec![SortOptions::default(); join_keys.len()];

    for join_type in [
        JoinType::Inner,
        JoinType::Left,
        JoinType::Right,
        JoinType::Full,
        JoinType::LeftSemi,
        JoinType::LeftAnti,
    ] {
        // Fresh runtime and session per join type.
        let runtime = Arc::new(RuntimeEnv::new(
            RuntimeConfig::new().with_memory_limit(100, 1.0),
        )?);
        let session_ctx = SessionContext::with_config_rt(
            SessionConfig::default().with_batch_size(50),
            runtime,
        );
        let task_ctx = session_ctx.task_ctx();

        let join = join_with_options(
            left.clone(),
            right.clone(),
            join_keys.clone(),
            join_type,
            sort_options.clone(),
            false,
        )?;

        // Collecting partition 0 must return an error; `unwrap_err` panics otherwise.
        let err = common::collect(join.execute(0, task_ctx)?)
            .await
            .unwrap_err();

        assert_contains!(
            err.to_string(),
            "Resources exhausted: Failed to allocate additional"
        );
        assert_contains!(err.to_string(), "SMJStream[0]");
    }

    Ok(())
}
}
5 changes: 5 additions & 0 deletions datafusion/core/src/physical_plan/metrics/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ impl Gauge {
self.value.fetch_sub(n, Ordering::Relaxed);
}

/// Raise the metric's value to `n` if `n` is greater than the current value;
/// otherwise leave it unchanged.
pub fn set_max(&self, n: usize) {
    // `fetch_max` hands back the prior value, which is not needed here.
    let _previous = self.value.fetch_max(n, Ordering::Relaxed);
}

/// Set the metric's value to `n` and return the previous value
pub fn set(&self, n: usize) -> usize {
// relaxed ordering for operations on `value` poses no issues
Expand Down
Loading