Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 44 additions & 47 deletions datafusion/core/src/physical_optimizer/enforce_sorting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,11 +616,11 @@ mod tests {
use super::*;
use crate::physical_optimizer::enforce_distribution::EnforceDistribution;
use crate::physical_optimizer::test_utils::{
aggregate_exec, bounded_window_exec, check_integrity, coalesce_batches_exec,
coalesce_partitions_exec, filter_exec, global_limit_exec, hash_join_exec,
limit_exec, local_limit_exec, memory_exec, parquet_exec, parquet_exec_sorted,
repartition_exec, sort_exec, sort_expr, sort_expr_options, sort_merge_join_exec,
sort_preserving_merge_exec, spr_repartition_exec, union_exec,
aggregate_exec, bounded_window_exec, check_integrity, coalesce_partitions_exec,
filter_exec, global_limit_exec, hash_join_exec, limit_exec, local_limit_exec,
memory_exec, parquet_exec, parquet_exec_sorted, repartition_exec, sort_exec,
sort_expr, sort_expr_options, sort_merge_join_exec, sort_preserving_merge_exec,
spr_repartition_exec, union_exec, TestPlanBuilder,
};
use crate::physical_plan::{displayable, get_plan_string, Partitioning};
use crate::prelude::{SessionConfig, SessionContext};
Expand Down Expand Up @@ -754,9 +754,10 @@ mod tests {
#[tokio::test]
async fn test_remove_unnecessary_sort() -> Result<()> {
let schema = create_test_schema()?;
let source = memory_exec(&schema);
let input = sort_exec(vec![sort_expr("non_nullable_col", &schema)], source);
let physical_plan = sort_exec(vec![sort_expr("nullable_col", &schema)], input);
let physical_plan = TestPlanBuilder::new_memory_exec(&schema)
.sort(vec!["non_nullable_col"])
.sort(vec!["nullable_col"])
.build();

let expected_input = [
"SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]",
Expand All @@ -775,43 +776,39 @@ mod tests {
#[tokio::test]
async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> {
let schema = create_test_schema()?;
let source = memory_exec(&schema);

let sort_exprs = vec![sort_expr_options(
"non_nullable_col",
&source.schema(),
&schema,
SortOptions {
descending: true,
nulls_first: true,
},
)];
let sort = sort_exec(sort_exprs.clone(), source);
// Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before
let coalesce_batches = coalesce_batches_exec(sort);

let window_agg =
bounded_window_exec("non_nullable_col", sort_exprs, coalesce_batches);
let builder = TestPlanBuilder::new_memory_exec(&schema)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This shows a good example of the difference -- the plan gets built in a single builder chain rather than through several individual function calls

.sort_by_exprs(sort_exprs.clone())
// Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before
.coalesce_batches()
.bounded_window("non_nullable_col", sort_exprs);

let sort_exprs = vec![sort_expr_options(
"non_nullable_col",
&window_agg.schema(),
builder.schema().as_ref(),
SortOptions {
descending: false,
nulls_first: false,
},
)];

let sort = sort_exec(sort_exprs.clone(), window_agg);

// Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before
let filter = filter_exec(
Arc::new(NotExpr::new(
let physical_plan = builder
.sort_by_exprs(sort_exprs.clone())
// Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before
.filter(Arc::new(NotExpr::new(
col("non_nullable_col", schema.as_ref()).unwrap(),
)),
sort,
);

let physical_plan = bounded_window_exec("non_nullable_col", sort_exprs, filter);
)))
.bounded_window("non_nullable_col", sort_exprs)
.build();

let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]",
" FilterExec: NOT non_nullable_col@1",
Expand All @@ -835,11 +832,9 @@ mod tests {
#[tokio::test]
async fn test_add_required_sort() -> Result<()> {
let schema = create_test_schema()?;
let source = memory_exec(&schema);

let sort_exprs = vec![sort_expr("nullable_col", &schema)];

let physical_plan = sort_preserving_merge_exec(sort_exprs, source);
let physical_plan = TestPlanBuilder::new_memory_exec(&schema)
.sort_preserving_merge(vec![sort_expr("nullable_col", &schema)])
.build();

let expected_input = [
"SortPreservingMergeExec: [nullable_col@0 ASC]",
Expand All @@ -857,14 +852,14 @@ mod tests {
#[tokio::test]
async fn test_remove_unnecessary_sort1() -> Result<()> {
let schema = create_test_schema()?;
let source = memory_exec(&schema);
let sort_exprs = vec![sort_expr("nullable_col", &schema)];
let sort = sort_exec(sort_exprs.clone(), source);
let spm = sort_preserving_merge_exec(sort_exprs, sort);
let physical_plan = TestPlanBuilder::new_memory_exec(&schema)
.sort_by_exprs(sort_exprs.clone())
.sort_preserving_merge(sort_exprs.clone())
.sort_by_exprs(sort_exprs.clone())
.sort_preserving_merge(sort_exprs)
.build();

let sort_exprs = vec![sort_expr("nullable_col", &schema)];
let sort = sort_exec(sort_exprs.clone(), spm);
let physical_plan = sort_preserving_merge_exec(sort_exprs, sort);
let expected_input = [
"SortPreservingMergeExec: [nullable_col@0 ASC]",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests already have an expected output that nicely confirms the created plan still has the same structure as in the previous tests

" SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]",
Expand All @@ -884,21 +879,23 @@ mod tests {
#[tokio::test]
async fn test_remove_unnecessary_sort2() -> Result<()> {
let schema = create_test_schema()?;
let source = memory_exec(&schema);
let sort_exprs = vec![sort_expr("non_nullable_col", &schema)];
let sort = sort_exec(sort_exprs.clone(), source);
let spm = sort_preserving_merge_exec(sort_exprs, sort);

let sort_exprs = vec![
let sort_exprs = vec![sort_expr("non_nullable_col", &schema)];
let sort_exprs2 = vec![
sort_expr("nullable_col", &schema),
sort_expr("non_nullable_col", &schema),
];
let sort2 = sort_exec(sort_exprs.clone(), spm);
let spm2 = sort_preserving_merge_exec(sort_exprs, sort2);
let sort_exprs3 = vec![sort_expr("nullable_col", &schema)];

let sort_exprs = vec![sort_expr("nullable_col", &schema)];
let sort3 = sort_exec(sort_exprs, spm2);
let physical_plan = repartition_exec(repartition_exec(sort3));
let physical_plan = TestPlanBuilder::new_memory_exec(&schema)
.sort_by_exprs(sort_exprs.clone())
.sort_preserving_merge(sort_exprs)
.sort_by_exprs(sort_exprs2.clone())
.sort_preserving_merge(sort_exprs2)
.sort_by_exprs(sort_exprs3)
.repartition()
.repartition()
.build();

let expected_input = [
"RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10",
Expand Down
81 changes: 81 additions & 0 deletions datafusion/core/src/physical_optimizer/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,87 @@ pub fn sort_merge_join_exec(
)
}

/// Builder for creating `ExecutionPlan`s for testing.
///
/// Each builder method consumes the builder, wraps the current plan in one
/// more operator, and returns the builder again so calls can be chained.
/// Call `build` at the end of the chain to obtain the finished plan.
#[derive(Debug, Clone)]
#[must_use = "a TestPlanBuilder does nothing unless `build` is called"]
pub struct TestPlanBuilder {
    /// The plan assembled so far; each builder call makes it the input of a new operator.
    plan: Arc<dyn ExecutionPlan>,
}

impl TestPlanBuilder {
    /// Starts a new plan whose leaf is a [`MemoryExec`] with the given schema.
    ///
    /// The `MemoryExec` is created from a single, empty list of batches and
    /// without a projection.
    ///
    /// # Panics
    ///
    /// Panics if `MemoryExec::try_new` returns an error.
    pub fn new_memory_exec(schema: &SchemaRef) -> Self {
        let plan = MemoryExec::try_new(&[vec![]], schema.clone(), None)
            .expect("creating a MemoryExec from one empty partition should not fail");
        Self {
            plan: Arc::new(plan),
        }
    }

    /// Adds a `SortExec` to the current plan that sorts by the provided column
    /// names with default options.
    ///
    /// The names are resolved against the schema of the current plan (see
    /// `sort_expr`).
    pub fn sort<'a>(mut self, sort_exprs: impl IntoIterator<Item = &'a str>) -> Self {
        // Look the schema up once instead of once per column name.
        let schema = self.plan.schema();
        let sort_exprs = sort_exprs
            .into_iter()
            .map(|name| sort_expr(name, &schema))
            .collect::<Vec<_>>();
        self.plan = sort_exec(sort_exprs, self.plan);
        self
    }

    /// Adds a `SortExec` to the current plan that sorts by the provided sort exprs.
    pub fn sort_by_exprs(
        mut self,
        sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>,
    ) -> Self {
        self.plan = sort_exec(sort_exprs, self.plan);
        self
    }

    /// Adds a `CoalesceBatchesExec` on top of the current plan via
    /// `coalesce_batches_exec`.
    ///
    /// NOTE(review): previously documented as using a target size of 128;
    /// confirm against `coalesce_batches_exec`.
    pub fn coalesce_batches(mut self) -> Self {
        self.plan = coalesce_batches_exec(self.plan);
        self
    }

    /// Adds a [`FilterExec`] with the provided predicate on top of the current plan.
    pub fn filter(mut self, predicate: Arc<dyn PhysicalExpr>) -> Self {
        self.plan = filter_exec(predicate, self.plan);
        self
    }

    /// Adds a `BoundedWindowAggExec` with a count aggregate that takes
    /// `col_name` as its argument, using the provided sort exprs.
    pub fn bounded_window(
        mut self,
        col_name: &str,
        sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>,
    ) -> Self {
        self.plan = bounded_window_exec(col_name, sort_exprs, self.plan);
        self
    }

    /// Adds a `SortPreservingMergeExec` with the provided sort exprs on top of
    /// the current plan.
    pub fn sort_preserving_merge(
        mut self,
        sort_exprs: impl IntoIterator<Item = PhysicalSortExpr>,
    ) -> Self {
        self.plan = sort_preserving_merge_exec(sort_exprs, self.plan);
        self
    }

    /// Adds a [`RepartitionExec`] on top of the current plan via `repartition_exec`.
    pub fn repartition(mut self) -> Self {
        self.plan = repartition_exec(self.plan);
        self
    }

    /// Returns the schema of the plan built so far.
    ///
    /// Useful for creating sort expressions that refer to the output of an
    /// operator added earlier in the chain.
    pub fn schema(&self) -> SchemaRef {
        self.plan.schema()
    }

    /// Consumes the builder and returns the assembled plan.
    pub fn build(self) -> Arc<dyn ExecutionPlan> {
        self.plan
    }
}

/// make PhysicalSortExpr with default options
pub fn sort_expr(name: &str, schema: &Schema) -> PhysicalSortExpr {
sort_expr_options(name, schema, SortOptions::default())
Expand Down