Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions datafusion/physical-plan/src/aggregates/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3821,6 +3821,142 @@ mod tests {
Ok(())
}

/// Tests that when the memory pool is too small to accommodate the sort
/// reservation during spill, the error is properly propagated as
/// ResourcesExhausted rather than silently exceeding memory limits.
#[tokio::test]
async fn test_sort_reservation_fails_during_spill() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("g", DataType::Int64, false),
Field::new("a", DataType::Float64, false),
Field::new("b", DataType::Float64, false),
Field::new("c", DataType::Float64, false),
Field::new("d", DataType::Float64, false),
Field::new("e", DataType::Float64, false),
]));

let batches = vec![vec![
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int64Array::from(vec![1])),
Arc::new(Float64Array::from(vec![10.0])),
Arc::new(Float64Array::from(vec![20.0])),
Arc::new(Float64Array::from(vec![30.0])),
Arc::new(Float64Array::from(vec![40.0])),
Arc::new(Float64Array::from(vec![50.0])),
],
)?,
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int64Array::from(vec![2])),
Arc::new(Float64Array::from(vec![11.0])),
Arc::new(Float64Array::from(vec![21.0])),
Arc::new(Float64Array::from(vec![31.0])),
Arc::new(Float64Array::from(vec![41.0])),
Arc::new(Float64Array::from(vec![51.0])),
],
)?,
RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int64Array::from(vec![3])),
Arc::new(Float64Array::from(vec![12.0])),
Arc::new(Float64Array::from(vec![22.0])),
Arc::new(Float64Array::from(vec![32.0])),
Arc::new(Float64Array::from(vec![42.0])),
Arc::new(Float64Array::from(vec![52.0])),
],
)?,
]];

let scan = TestMemoryExec::try_new(&batches, Arc::clone(&schema), None)?;

let aggr = Arc::new(AggregateExec::try_new(
AggregateMode::Single,
PhysicalGroupBy::new(
vec![(col("g", schema.as_ref())?, "g".to_string())],
vec![],
vec![vec![false]],
false,
),
vec![
Arc::new(
AggregateExprBuilder::new(
avg_udaf(),
vec![col("a", schema.as_ref())?],
)
.schema(Arc::clone(&schema))
.alias("AVG(a)")
.build()?,
),
Arc::new(
AggregateExprBuilder::new(
avg_udaf(),
vec![col("b", schema.as_ref())?],
)
.schema(Arc::clone(&schema))
.alias("AVG(b)")
.build()?,
),
Arc::new(
AggregateExprBuilder::new(
avg_udaf(),
vec![col("c", schema.as_ref())?],
)
.schema(Arc::clone(&schema))
.alias("AVG(c)")
.build()?,
),
Arc::new(
AggregateExprBuilder::new(
avg_udaf(),
vec![col("d", schema.as_ref())?],
)
.schema(Arc::clone(&schema))
.alias("AVG(d)")
.build()?,
),
Arc::new(
AggregateExprBuilder::new(
avg_udaf(),
vec![col("e", schema.as_ref())?],
)
.schema(Arc::clone(&schema))
.alias("AVG(e)")
.build()?,
),
],
vec![None, None, None, None, None],
Arc::new(scan) as Arc<dyn ExecutionPlan>,
Arc::clone(&schema),
)?);

// Pool must be large enough for accumulation to start but too small for
// sort_memory after clearing.
let task_ctx = new_spill_ctx(1, 500);
let result = collect(aggr.execute(0, Arc::clone(&task_ctx))?).await;

match &result {
Ok(_) => panic!("Expected ResourcesExhausted error but query succeeded"),
Err(e) => {
let root = e.find_root();
assert!(
matches!(root, DataFusionError::ResourcesExhausted(_)),
"Expected ResourcesExhausted, got: {root}",
);
let msg = root.to_string();
assert!(
msg.contains("Failed to reserve memory for sort during spill"),
"Expected sort reservation error, got: {msg}",
);
}
}

Ok(())
}

/// Tests that PartialReduce mode:
/// 1. Accepts state as input (like Final)
/// 2. Produces state as output (like Partial)
Expand Down
72 changes: 60 additions & 12 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,16 @@ use crate::aggregates::{
create_schema, evaluate_group_by, evaluate_many, evaluate_optional,
};
use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
use crate::sorts::sort::sort_batch;
use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
use crate::spill::spill_manager::SpillManager;
use crate::spill::spill_manager::{GetSlicedSize, SpillManager};
use crate::{PhysicalExpr, aggregates, metrics};
use crate::{RecordBatchStream, SendableRecordBatchStream};

use arrow::array::*;
use arrow::datatypes::SchemaRef;
use datafusion_common::{
DataFusionError, Result, assert_eq_or_internal_err, assert_or_internal_err,
internal_err,
internal_err, resources_datafusion_err,
};
use datafusion_execution::TaskContext;
use datafusion_execution::memory_pool::proxy::VecAllocExt;
Expand All @@ -51,7 +50,9 @@ use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::{GroupsAccumulatorAdapter, PhysicalSortExpr};
use datafusion_physical_expr_common::sort_expr::LexOrdering;

use crate::sorts::IncrementalSortIterator;
use datafusion_common::instant::Instant;
use datafusion_common::utils::memory::get_record_batch_memory_size;
use futures::ready;
use futures::stream::{Stream, StreamExt};
use log::debug;
Expand Down Expand Up @@ -1048,10 +1049,27 @@ impl GroupedHashAggregateStream {

fn update_memory_reservation(&mut self) -> Result<()> {
let acc = self.accumulators.iter().map(|x| x.size()).sum::<usize>();
let new_size = acc
let groups_and_acc_size = acc
+ self.group_values.size()
+ self.group_ordering.size()
+ self.current_group_indices.allocated_size();

// Reserve extra headroom for sorting during potential spill.
// When OOM triggers, group_aggregate_batch has already processed the
// latest input batch, so the internal state may have grown well beyond
// the last successful reservation. The emit batch reflects this larger
// actual state, and the sort needs memory proportional to it.
// By reserving headroom equal to the data size, we trigger OOM earlier
// (before too much data accumulates), ensuring the freed reservation
// after clear_shrink is sufficient to cover the sort memory.
let sort_headroom =
if self.oom_mode == OutOfMemoryMode::Spill && !self.group_values.is_empty() {
acc + self.group_values.size()
} else {
0
};

let new_size = groups_and_acc_size + sort_headroom;
let reservation_result = self.reservation.try_resize(new_size);

if reservation_result.is_ok() {
Expand Down Expand Up @@ -1110,17 +1128,47 @@ impl GroupedHashAggregateStream {
let Some(emit) = self.emit(EmitTo::All, true)? else {
return Ok(());
};
let sorted = sort_batch(&emit, &self.spill_state.spill_expr, None)?;

// Spill sorted state to disk
// Free accumulated state now that data has been emitted into `emit`.
// This must happen before reserving sort memory so the pool has room.
// Use 0 to minimize allocated capacity and maximize memory available for sorting.
self.clear_shrink(0);
self.update_memory_reservation()?;

let batch_size_ratio = self.batch_size as f32 / emit.num_rows() as f32;
let batch_memory = get_record_batch_memory_size(&emit);
// The worst case for a sort is 2X the original underlying buffers (regardless of slicing).
// First we get the underlying buffers' size, then the sliced ("actual") size of the batch,
// and multiply the latter by the ratio of batch_size to actual row count to estimate the memory needed to sort the batch.
// If something goes wrong in get_sliced_size() (e.g. double counting),
// we fall back to the worst case.
let sort_memory = (batch_memory
+ (emit.get_sliced_size()? as f32 * batch_size_ratio) as usize)
.min(batch_memory * 2);

// If we can't grow even that, we have no choice but to return an error since we can't spill to disk without sorting the data first.
self.reservation.try_grow(sort_memory).map_err(|err| {
resources_datafusion_err!(
"Failed to reserve memory for sort during spill: {err}"
)
})?;

let sorted_iter = IncrementalSortIterator::new(
emit,
self.spill_state.spill_expr.clone(),
self.batch_size,
);
let spillfile = self
.spill_state
.spill_manager
.spill_record_batch_by_size_and_return_max_batch_memory(
&sorted,
.spill_record_batch_iter_and_return_max_batch_memory(
sorted_iter,
"HashAggSpill",
self.batch_size,
)?;

// Shrink the memory we allocated for sorting as the sorting is fully done at this point.
self.reservation.shrink(sort_memory);

match spillfile {
Some((spillfile, max_record_batch_memory)) => {
self.spill_state.spills.push(SortedSpillFile {
Expand All @@ -1138,14 +1186,14 @@ impl GroupedHashAggregateStream {
Ok(())
}

/// Release accumulated group state, shrinking capacities down to
/// `num_rows` entries.
fn clear_shrink(&mut self, num_rows: usize) {
    // The index scratch buffer and the group values are independent;
    // trim both to the requested capacity.
    self.current_group_indices.clear();
    self.current_group_indices.shrink_to(num_rows);
    self.group_values.clear_shrink(num_rows);
}

/// Release all accumulated group state, shrinking every capacity to zero.
fn clear_all(&mut self) {
    // Delegate to `clear_shrink` with a target of zero rows.
    self.clear_shrink(0)
}
Expand Down Expand Up @@ -1184,7 +1232,7 @@ impl GroupedHashAggregateStream {
// instead.
// Spilling to disk and reading back also ensures batch size is consistent
// rather than potentially having one significantly larger last batch.
self.spill()?; // TODO: use sort_batch_chunked instead?
self.spill()?;

// Mark that we're switching to stream merging mode.
self.spill_state.is_stream_merging = true;
Expand Down
2 changes: 2 additions & 0 deletions datafusion/physical-plan/src/sorts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,5 @@ pub mod sort;
pub mod sort_preserving_merge;
mod stream;
pub mod streaming_merge;

pub(crate) use stream::IncrementalSortIterator;
47 changes: 9 additions & 38 deletions datafusion/physical-plan/src/sorts/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ use crate::metrics::{
BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, SpillMetrics,
};
use crate::projection::{ProjectionExec, make_with_child, update_ordering};
use crate::sorts::IncrementalSortIterator;
use crate::sorts::streaming_merge::{SortedSpillFile, StreamingMergeBuilder};
use crate::spill::get_record_batch_memory_size;
use crate::spill::in_progress_spill_file::InProgressSpillFile;
Expand Down Expand Up @@ -726,7 +727,6 @@ impl ExternalSorter {

// Sort the batch immediately and get all output batches
let sorted_batches = sort_batch_chunked(&batch, &expressions, batch_size)?;
drop(batch);

// Free the old reservation and grow it to match the actual sorted output size
reservation.free();
Expand Down Expand Up @@ -851,11 +851,13 @@ pub(crate) fn get_reserved_bytes_for_record_batch_size(
/// Estimate how much memory is needed to sort a `RecordBatch`.
/// This will just call `get_reserved_bytes_for_record_batch_size` with the
/// memory size of the record batch and its sliced size.
pub(super) fn get_reserved_bytes_for_record_batch(batch: &RecordBatch) -> Result<usize> {
Ok(get_reserved_bytes_for_record_batch_size(
get_record_batch_memory_size(batch),
batch.get_sliced_size()?,
))
pub(crate) fn get_reserved_bytes_for_record_batch(batch: &RecordBatch) -> Result<usize> {
    // Propagate any slicing-size error, then combine the batch's total
    // buffer size with its sliced size to produce the estimate.
    let sliced_size = batch.get_sliced_size()?;
    let buffer_size = get_record_batch_memory_size(batch);
    Ok(get_reserved_bytes_for_record_batch_size(
        buffer_size,
        sliced_size,
    ))
}
Comment thread
EmilyMatt marked this conversation as resolved.

impl Debug for ExternalSorter {
Expand Down Expand Up @@ -898,38 +900,7 @@ pub fn sort_batch_chunked(
expressions: &LexOrdering,
batch_size: usize,
) -> Result<Vec<RecordBatch>> {
let sort_columns = expressions
.iter()
.map(|expr| expr.evaluate_to_sort_column(batch))
.collect::<Result<Vec<_>>>()?;

let indices = lexsort_to_indices(&sort_columns, None)?;

// Split indices into chunks of batch_size
let num_rows = indices.len();
let num_chunks = num_rows.div_ceil(batch_size);

let result_batches = (0..num_chunks)
.map(|chunk_idx| {
let start = chunk_idx * batch_size;
let end = (start + batch_size).min(num_rows);
let chunk_len = end - start;

// Create a slice of indices for this chunk
let chunk_indices = indices.slice(start, chunk_len);

// Take the columns using this chunk of indices
let columns = take_arrays(batch.columns(), &chunk_indices, None)?;

let options = RecordBatchOptions::new().with_row_count(Some(chunk_len));
let chunk_batch =
RecordBatch::try_new_with_options(batch.schema(), columns, &options)?;

Ok(chunk_batch)
})
.collect::<Result<Vec<RecordBatch>>>()?;

Ok(result_batches)
IncrementalSortIterator::new(batch.clone(), expressions.clone(), batch_size).collect()
}

/// Sort execution plan.
Expand Down
Loading