From 5bee284e2489b5533cdf7ba78e6c000a66db20b0 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 31 Jul 2024 13:30:22 -0400 Subject: [PATCH] Add `TestPlanBuilder` to simplify creating test `ExecutionPlans` --- .../src/physical_optimizer/enforce_sorting.rs | 91 +++++++++---------- .../core/src/physical_optimizer/test_utils.rs | 81 +++++++++++++++++ 2 files changed, 125 insertions(+), 47 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/enforce_sorting.rs b/datafusion/core/src/physical_optimizer/enforce_sorting.rs index cf9d33252ad9d..e222adf6b6f24 100644 --- a/datafusion/core/src/physical_optimizer/enforce_sorting.rs +++ b/datafusion/core/src/physical_optimizer/enforce_sorting.rs @@ -616,11 +616,11 @@ mod tests { use super::*; use crate::physical_optimizer::enforce_distribution::EnforceDistribution; use crate::physical_optimizer::test_utils::{ - aggregate_exec, bounded_window_exec, check_integrity, coalesce_batches_exec, - coalesce_partitions_exec, filter_exec, global_limit_exec, hash_join_exec, - limit_exec, local_limit_exec, memory_exec, parquet_exec, parquet_exec_sorted, - repartition_exec, sort_exec, sort_expr, sort_expr_options, sort_merge_join_exec, - sort_preserving_merge_exec, spr_repartition_exec, union_exec, + aggregate_exec, bounded_window_exec, check_integrity, coalesce_partitions_exec, + filter_exec, global_limit_exec, hash_join_exec, limit_exec, local_limit_exec, + memory_exec, parquet_exec, parquet_exec_sorted, repartition_exec, sort_exec, + sort_expr, sort_expr_options, sort_merge_join_exec, sort_preserving_merge_exec, + spr_repartition_exec, union_exec, TestPlanBuilder, }; use crate::physical_plan::{displayable, get_plan_string, Partitioning}; use crate::prelude::{SessionConfig, SessionContext}; @@ -754,9 +754,10 @@ mod tests { #[tokio::test] async fn test_remove_unnecessary_sort() -> Result<()> { let schema = create_test_schema()?; - let source = memory_exec(&schema); - let input = sort_exec(vec![sort_expr("non_nullable_col", &schema)], source); - let physical_plan = sort_exec(vec![sort_expr("nullable_col", &schema)], input); + let physical_plan = TestPlanBuilder::new_memory_exec(&schema) + .sort(vec!["non_nullable_col"]) + .sort(vec!["nullable_col"]) + .build(); let expected_input = [ "SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", @@ -775,43 +776,39 @@ mod tests { #[tokio::test] async fn test_remove_unnecessary_sort_window_multilayer() -> Result<()> { let schema = create_test_schema()?; - let source = memory_exec(&schema); let sort_exprs = vec![sort_expr_options( "non_nullable_col", - &source.schema(), + &schema, SortOptions { descending: true, nulls_first: true, }, )]; - let sort = sort_exec(sort_exprs.clone(), source); - // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before - let coalesce_batches = coalesce_batches_exec(sort); - let window_agg = - bounded_window_exec("non_nullable_col", sort_exprs, coalesce_batches); + let builder = TestPlanBuilder::new_memory_exec(&schema) + .sort_by_exprs(sort_exprs.clone()) + // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before + .coalesce_batches() + .bounded_window("non_nullable_col", sort_exprs); let sort_exprs = vec![sort_expr_options( "non_nullable_col", - &window_agg.schema(), + builder.schema().as_ref(), SortOptions { descending: false, nulls_first: false, }, )]; - let sort = sort_exec(sort_exprs.clone(), window_agg); - - // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before - let filter = filter_exec( - Arc::new(NotExpr::new( + let physical_plan = builder + .sort_by_exprs(sort_exprs.clone()) + // Add dummy layer propagating Sort above, to test whether sort can be removed from multi layer before + .filter(Arc::new(NotExpr::new( col("non_nullable_col", schema.as_ref()).unwrap(), - )), - sort, - ); - - let physical_plan = bounded_window_exec("non_nullable_col", sort_exprs, filter); + ))) + .bounded_window("non_nullable_col", sort_exprs) + .build(); let expected_input = ["BoundedWindowAggExec: wdw=[count: Ok(Field { name: \"count\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(NULL), end_bound: CurrentRow, is_causal: false }], mode=[Sorted]", " FilterExec: NOT non_nullable_col@1", @@ -835,11 +832,9 @@ mod tests { #[tokio::test] async fn test_add_required_sort() -> Result<()> { let schema = create_test_schema()?; - let source = memory_exec(&schema); - - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - - let physical_plan = sort_preserving_merge_exec(sort_exprs, source); + let physical_plan = TestPlanBuilder::new_memory_exec(&schema) + .sort_preserving_merge(vec![sort_expr("nullable_col", &schema)]) + .build(); let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", @@ -857,14 +852,14 @@ mod tests { #[tokio::test] async fn test_remove_unnecessary_sort1() -> Result<()> { let schema = create_test_schema()?; - let source = memory_exec(&schema); let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source); - let spm = sort_preserving_merge_exec(sort_exprs, sort); + let physical_plan = TestPlanBuilder::new_memory_exec(&schema) + .sort_by_exprs(sort_exprs.clone()) + .sort_preserving_merge(sort_exprs.clone()) + .sort_by_exprs(sort_exprs.clone()) + .sort_preserving_merge(sort_exprs) + .build(); - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), spm); - let physical_plan = sort_preserving_merge_exec(sort_exprs, sort); let expected_input = [ "SortPreservingMergeExec: [nullable_col@0 ASC]", " SortExec: expr=[nullable_col@0 ASC], preserve_partitioning=[false]", @@ -884,21 +879,23 @@ mod tests { #[tokio::test] async fn test_remove_unnecessary_sort2() -> Result<()> { let schema = create_test_schema()?; - let source = memory_exec(&schema); - let sort_exprs = vec![sort_expr("non_nullable_col", &schema)]; - let sort = sort_exec(sort_exprs.clone(), source); - let spm = sort_preserving_merge_exec(sort_exprs, sort); - let sort_exprs = vec![ + let sort_exprs = vec![sort_expr("non_nullable_col", &schema)]; + let sort_exprs2 = vec![ sort_expr("nullable_col", &schema), sort_expr("non_nullable_col", &schema), ]; - let sort2 = sort_exec(sort_exprs.clone(), spm); - let spm2 = sort_preserving_merge_exec(sort_exprs, sort2); + let sort_exprs3 = vec![sort_expr("nullable_col", &schema)]; - let sort_exprs = vec![sort_expr("nullable_col", &schema)]; - let sort3 = sort_exec(sort_exprs, spm2); - let physical_plan = repartition_exec(repartition_exec(sort3)); + let physical_plan = TestPlanBuilder::new_memory_exec(&schema) + .sort_by_exprs(sort_exprs.clone()) + .sort_preserving_merge(sort_exprs) + .sort_by_exprs(sort_exprs2.clone()) + .sort_preserving_merge(sort_exprs2) + .sort_by_exprs(sort_exprs3) + .repartition() + .repartition() + .build(); let expected_input = [ "RepartitionExec: partitioning=RoundRobinBatch(10), input_partitions=10", diff --git a/datafusion/core/src/physical_optimizer/test_utils.rs b/datafusion/core/src/physical_optimizer/test_utils.rs index 5320938d2eb88..47634f7fda0bc 100644 --- a/datafusion/core/src/physical_optimizer/test_utils.rs +++ b/datafusion/core/src/physical_optimizer/test_utils.rs @@ -186,6 +186,87 @@ pub fn sort_merge_join_exec( ) } +/// Builder for creating ExecutionPlans for testing +pub struct TestPlanBuilder { + plan: Arc, +} + +impl TestPlanBuilder { + /// Start a new plan with a MemoryExec + pub fn new_memory_exec(schema: &SchemaRef) -> Self { + let plan = MemoryExec::try_new(&[vec![]], schema.clone(), None).unwrap(); + Self { + plan: Arc::new(plan), + } + } + + /// Adds a `SortExec to the current plan that sorts by the provided column names + /// with default options + pub fn sort<'a>(mut self, sort_exprs: impl IntoIterator) -> Self { + let sort_exprs = sort_exprs + .into_iter() + .map(|name| sort_expr(name, &self.plan.schema())) + .collect::>(); + self.plan = sort_exec(sort_exprs, self.plan); + self + } + + /// Adds a `SortExec to the current plan that sorts by the provided sort exprs + pub fn sort_by_exprs( + mut self, + sort_exprs: impl IntoIterator, + ) -> Self { + self.plan = sort_exec(sort_exprs, self.plan); + self + } + + /// Add a CoalesceBatches exec with a target size of 128 + pub fn coalesce_batches(mut self) -> Self { + self.plan = coalesce_batches_exec(self.plan); + self + } + + /// Add a [FilterExec] with the provided predicate + pub fn filter(mut self, predicate: Arc) -> Self { + self.plan = filter_exec(predicate, self.plan); + self + } + + /// Add a BoundWindowAggExec with a count aggregate with col_name as the argument + pub fn bounded_window( + mut self, + col_name: &str, + sort_exprs: impl IntoIterator, + ) -> Self { + self.plan = bounded_window_exec(col_name, sort_exprs, self.plan); + self + } + + /// Add a SortPreservingMergeExec with the provided sort exprs + pub fn sort_preserving_merge( + mut self, + sort_exprs: impl IntoIterator, + ) -> Self { + self.plan = sort_preserving_merge_exec(sort_exprs, self.plan); + self + } + + /// Add a [RepartitionExec] + pub fn repartition(mut self) -> Self { + self.plan = repartition_exec(self.plan); + self + } + + /// Return the current schema of the plan + pub fn schema(&self) -> SchemaRef { + self.plan.schema() + } + + pub fn build(self) -> Arc { + self.plan + } +} + /// make PhysicalSortExpr with default options pub fn sort_expr(name: &str, schema: &Schema) -> PhysicalSortExpr { sort_expr_options(name, schema, SortOptions::default())