From e7a961e15c9750cb7cce5d1c0c77b5d1b369cfd9 Mon Sep 17 00:00:00 2001 From: jackwener Date: Mon, 19 Dec 2022 13:13:35 +0800 Subject: [PATCH] optimizer: remove recursion in optimizer rules --- benchmarks/expected-plans/q20.txt | 8 +- datafusion/core/tests/sql/subqueries.rs | 20 +-- datafusion/expr/src/logical_plan/plan.rs | 8 +- .../optimizer/src/decorrelate_where_exists.rs | 80 ++++----- .../optimizer/src/decorrelate_where_in.rs | 118 ++++++++----- datafusion/optimizer/src/eliminate_filter.rs | 167 ++++++++---------- datafusion/optimizer/src/eliminate_limit.rs | 39 ++-- .../optimizer/src/eliminate_outer_join.rs | 38 ++-- .../optimizer/src/filter_null_join_keys.rs | 61 ++----- datafusion/optimizer/src/inline_table_scan.rs | 51 +++--- .../optimizer/src/propagate_empty_relation.rs | 158 ++++++++--------- .../optimizer/src/scalar_subquery_to_join.rs | 125 ++++++++----- datafusion/optimizer/src/test/mod.rs | 62 +++++-- 13 files changed, 479 insertions(+), 456 deletions(-) diff --git a/benchmarks/expected-plans/q20.txt b/benchmarks/expected-plans/q20.txt index 1266622ea6c4d..b2676f61f8eb9 100644 --- a/benchmarks/expected-plans/q20.txt +++ b/benchmarks/expected-plans/q20.txt @@ -1,17 +1,17 @@ Sort: supplier.s_name ASC NULLS LAST Projection: supplier.s_name, supplier.s_address - LeftSemi Join: supplier.s_suppkey = __sq_2.ps_suppkey + LeftSemi Join: supplier.s_suppkey = __sq_1.ps_suppkey Inner Join: supplier.s_nationkey = nation.n_nationkey TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey] Filter: nation.n_name = Utf8("CANADA") TableScan: nation projection=[n_nationkey, n_name] - SubqueryAlias: __sq_2 + SubqueryAlias: __sq_1 Projection: partsupp.ps_suppkey AS ps_suppkey Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, partsupp.ps_suppkey = __sq_3.l_suppkey - LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey + LeftSemi Join: partsupp.ps_partkey = __sq_2.p_partkey TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty] - SubqueryAlias: __sq_1 + SubqueryAlias: __sq_2 Projection: part.p_partkey AS p_partkey Filter: part.p_name LIKE Utf8("forest%") TableScan: part projection=[p_partkey, p_name] diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs index e6c98edf59861..d221ddfe28015 100644 --- a/datafusion/core/tests/sql/subqueries.rs +++ b/datafusion/core/tests/sql/subqueries.rs @@ -52,16 +52,16 @@ where c_acctbal < ( let actual = format!("{}", plan.display_indent()); let expected = "Sort: customer.c_custkey ASC NULLS LAST\ \n Projection: customer.c_custkey\ - \n Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __sq_2.__value\ - \n Inner Join: customer.c_custkey = __sq_2.o_custkey\ + \n Filter: CAST(customer.c_acctbal AS Decimal128(25, 2)) < __sq_1.__value\ + \n Inner Join: customer.c_custkey = __sq_1.o_custkey\ \n TableScan: customer projection=[c_custkey, c_acctbal]\ - \n SubqueryAlias: __sq_2\ + \n SubqueryAlias: __sq_1\ \n Projection: orders.o_custkey, SUM(orders.o_totalprice) AS __value\ \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[SUM(orders.o_totalprice)]]\ - \n Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __sq_1.__value\ - \n Inner Join: orders.o_orderkey = __sq_1.l_orderkey\ + \n Filter: CAST(orders.o_totalprice AS Decimal128(25, 2)) < __sq_2.__value\ + \n Inner Join: orders.o_orderkey = __sq_2.l_orderkey\ \n TableScan: orders projection=[o_orderkey, o_custkey, o_totalprice]\ - \n SubqueryAlias: __sq_1\ + \n SubqueryAlias: __sq_2\ \n Projection: lineitem.l_orderkey, SUM(lineitem.l_extendedprice) AS price AS __value\ \n Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_extendedprice)]]\ \n TableScan: lineitem projection=[l_orderkey, l_extendedprice]"; @@ -324,18 +324,18 @@ order by s_name; let actual = format!("{}", plan.display_indent()); let expected = "Sort: supplier.s_name ASC NULLS LAST\ \n Projection: supplier.s_name, supplier.s_address\ - \n LeftSemi Join: supplier.s_suppkey = __sq_2.ps_suppkey\ + \n LeftSemi Join: supplier.s_suppkey = __sq_1.ps_suppkey\ \n Inner Join: supplier.s_nationkey = nation.n_nationkey\ \n TableScan: supplier projection=[s_suppkey, s_name, s_address, s_nationkey]\ \n Filter: nation.n_name = Utf8(\"CANADA\")\ \n TableScan: nation projection=[n_nationkey, n_name], partial_filters=[nation.n_name = Utf8(\"CANADA\")]\ - \n SubqueryAlias: __sq_2\ + \n SubqueryAlias: __sq_1\ \n Projection: partsupp.ps_suppkey AS ps_suppkey\ \n Filter: CAST(partsupp.ps_availqty AS Float64) > __sq_3.__value\ \n Inner Join: partsupp.ps_partkey = __sq_3.l_partkey, partsupp.ps_suppkey = __sq_3.l_suppkey\ - \n LeftSemi Join: partsupp.ps_partkey = __sq_1.p_partkey\ + \n LeftSemi Join: partsupp.ps_partkey = __sq_2.p_partkey\ \n TableScan: partsupp projection=[ps_partkey, ps_suppkey, ps_availqty]\ - \n SubqueryAlias: __sq_1\ + \n SubqueryAlias: __sq_2\ \n Projection: part.p_partkey AS p_partkey\ \n Filter: part.p_name LIKE Utf8(\"forest%\")\ \n TableScan: part projection=[p_partkey, p_name], partial_filters=[part.p_name LIKE Utf8(\"forest%\")]\ diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 14dfe71437dfc..9d7fdf8f0a0c0 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -1329,12 +1329,16 @@ impl SubqueryAlias { /// If the value of `` is true, the input row is passed to /// the output. If the value of `` is false, the row is /// discarded. +/// +/// Filter should not be created directly but instead use `try_new()` +/// and that these fields are only pub to support pattern matching #[derive(Clone)] +#[non_exhaustive] pub struct Filter { /// The predicate expression, which must have Boolean type. - predicate: Expr, + pub predicate: Expr, /// The incoming logical plan - input: Arc, + pub input: Arc, } impl Filter { diff --git a/datafusion/optimizer/src/decorrelate_where_exists.rs b/datafusion/optimizer/src/decorrelate_where_exists.rs index f1addf651b210..50bbf6bb51aa0 100644 --- a/datafusion/optimizer/src/decorrelate_where_exists.rs +++ b/datafusion/optimizer/src/decorrelate_where_exists.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. +use crate::optimizer::ApplyOrder; use crate::utils::{ conjunction, exprs_to_join_cols, find_join_exprs, split_conjunction, verify_not_disjunction, }; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{context, Result}; use datafusion_expr::{ logical_plan::{Filter, JoinType, Subquery}, @@ -81,27 +82,15 @@ impl OptimizerRule for DecorrelateWhereExists { ) -> Result> { match plan { LogicalPlan::Filter(filter) => { - let predicate = filter.predicate(); - let filter_input = filter.input().as_ref(); - - // Apply optimizer rule to current input - let optimized_input = self - .try_optimize(filter_input, config)? - .unwrap_or_else(|| filter_input.clone()); - let (subqueries, other_exprs) = - self.extract_subquery_exprs(predicate, config)?; - let optimized_plan = LogicalPlan::Filter(Filter::try_new( - predicate.clone(), - Arc::new(optimized_input), - )?); + self.extract_subquery_exprs(filter.predicate(), config)?; if subqueries.is_empty() { // regular filter, no subquery exists clause here - return Ok(Some(optimized_plan)); + return Ok(None); } // iterate through all exists clauses in predicate, turning each into a join - let mut cur_input = filter_input.clone(); + let mut cur_input = filter.input().as_ref().clone(); for subquery in subqueries { if let Some(x) = optimize_exists(&subquery, &cur_input, &other_exprs)? { @@ -112,16 +101,17 @@ impl OptimizerRule for DecorrelateWhereExists { } Ok(Some(cur_input)) } - _ => { - // Apply the optimization to all inputs of the plan - Ok(Some(utils::optimize_children(self, plan, config)?)) - } + _ => Ok(None), } } fn name(&self) -> &str { "decorrelate_where_exists" } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } } /// Takes a query like: @@ -226,6 +216,15 @@ mod tests { }; use std::ops::Add; + fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereExists::new()), + plan, + expected, + ); + Ok(()) + } + /// Test for multiple exists subqueries in the same filter expression #[test] fn multiple_subqueries() -> Result<()> { @@ -248,8 +247,7 @@ mod tests { TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; - assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected); - Ok(()) + assert_plan_eq(&plan, expected) } /// Test recursive correlated subqueries @@ -284,8 +282,7 @@ mod tests { TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"#; - assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected); - Ok(()) + assert_plan_eq(&plan, expected) } /// Test for correlated exists subquery filter with additional subquery filters @@ -313,8 +310,7 @@ mod tests { Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; - assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected); - Ok(()) + assert_plan_eq(&plan, expected) } /// Test for correlated exists subquery with no columns in schema @@ -332,8 +328,7 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - assert_optimization_skipped(&DecorrelateWhereExists::new(), &plan); - Ok(()) + assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan) } /// Test for exists subquery with both columns in schema @@ -351,8 +346,7 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - assert_optimization_skipped(&DecorrelateWhereExists::new(), &plan); - Ok(()) + assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan) } /// Test for correlated exists subquery not equal @@ -370,8 +364,7 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - assert_optimization_skipped(&DecorrelateWhereExists::new(), &plan); - Ok(()) + assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan) } /// Test for correlated exists subquery less than @@ -391,7 +384,7 @@ mod tests { let expected = r#"can't optimize < column comparison"#; - assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected); + assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected); Ok(()) } @@ -416,7 +409,7 @@ mod tests { let expected = r#"Optimizing disjunctions not supported!"#; - assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected); + assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected); Ok(()) } @@ -434,8 +427,7 @@ mod tests { .project(vec![col("customer.c_custkey")])? .build()?; - assert_optimization_skipped(&DecorrelateWhereExists::new(), &plan); - Ok(()) + assert_optimization_skipped(Arc::new(DecorrelateWhereExists::new()), &plan) } /// Test for correlated exists expressions @@ -459,8 +451,7 @@ mod tests { TableScan: customer [c_custkey:Int64, c_name:Utf8] TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; - assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected); - Ok(()) + assert_plan_eq(&plan, expected) } /// Test for correlated exists subquery filter with additional filters @@ -483,8 +474,7 @@ mod tests { TableScan: customer [c_custkey:Int64, c_name:Utf8] TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"#; - assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected); - Ok(()) + assert_plan_eq(&plan, expected) } /// Test for correlated exists subquery filter with disjustions @@ -511,8 +501,7 @@ mod tests { TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; - assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected); - Ok(()) + assert_plan_eq(&plan, expected) } /// Test for correlated EXISTS subquery filter @@ -535,8 +524,7 @@ mod tests { TableScan: test [a:UInt32, b:UInt32, c:UInt32] TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"#; - assert_optimized_plan_eq(&DecorrelateWhereExists::new(), &plan, expected); - Ok(()) + assert_plan_eq(&plan, expected) } /// Test for single exists subquery filter @@ -550,7 +538,7 @@ mod tests { let expected = "cannot optimize non-correlated subquery"; - assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected); + assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected); Ok(()) } @@ -565,7 +553,7 @@ mod tests { let expected = "cannot optimize non-correlated subquery"; - assert_optimizer_err(&DecorrelateWhereExists::new(), &plan, expected); + assert_optimizer_err(Arc::new(DecorrelateWhereExists::new()), &plan, expected); Ok(()) } } diff --git a/datafusion/optimizer/src/decorrelate_where_in.rs b/datafusion/optimizer/src/decorrelate_where_in.rs index d2555ea5c31a2..91dd9c550130c 100644 --- a/datafusion/optimizer/src/decorrelate_where_in.rs +++ b/datafusion/optimizer/src/decorrelate_where_in.rs @@ -15,13 +15,14 @@ // specific language governing permissions and limitations // under the License. +use crate::optimizer::ApplyOrder; use crate::utils::{ alias_cols, conjunction, exprs_to_join_cols, find_join_exprs, merge_cols, only_or_err, split_conjunction, swap_table, verify_not_disjunction, }; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{context, Result}; -use datafusion_expr::logical_plan::{Filter, JoinType, Projection, Subquery}; +use datafusion_expr::logical_plan::{JoinType, Projection, Subquery}; use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; use log::debug; use std::sync::Arc; @@ -85,43 +86,32 @@ impl OptimizerRule for DecorrelateWhereIn { ) -> Result> { match plan { LogicalPlan::Filter(filter) => { - let predicate = filter.predicate(); - let filter_input = filter.input().as_ref(); - - // Apply optimizer rule to current input - let optimized_input = self - .try_optimize(filter_input, config)? - .unwrap_or_else(|| filter_input.clone()); - let (subqueries, other_exprs) = - self.extract_subquery_exprs(predicate, config)?; - let optimized_plan = LogicalPlan::Filter(Filter::try_new( - predicate.clone(), - Arc::new(optimized_input), - )?); + self.extract_subquery_exprs(filter.predicate(), config)?; if subqueries.is_empty() { // regular filter, no subquery exists clause here - return Ok(Some(optimized_plan)); + return Ok(None); } // iterate through all exists clauses in predicate, turning each into a join - let mut cur_input = filter_input.clone(); + let mut cur_input = filter.input().as_ref().clone(); for subquery in subqueries { cur_input = optimize_where_in(&subquery, &cur_input, &other_exprs, config)?; } Ok(Some(cur_input)) } - _ => { - // Apply the optimization to all inputs of the plan - Ok(Some(utils::optimize_children(self, plan, config)?)) - } + _ => Ok(None), } } fn name(&self) -> &str { "decorrelate_where_in" } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } } fn optimize_where_in( @@ -268,7 +258,11 @@ mod tests { \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n SubqueryAlias: __sq_2 [o_custkey:Int64]\n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); Ok(()) } @@ -299,17 +293,21 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n LeftSemi Join: customer.c_custkey = __sq_2.o_custkey [c_custkey:Int64, c_name:Utf8]\ + \n LeftSemi Join: customer.c_custkey = __sq_1.o_custkey [c_custkey:Int64, c_name:Utf8]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __sq_2 [o_custkey:Int64]\ + \n SubqueryAlias: __sq_1 [o_custkey:Int64]\ \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ - \n LeftSemi Join: orders.o_orderkey = __sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ + \n LeftSemi Join: orders.o_orderkey = __sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __sq_1 [l_orderkey:Int64]\ + \n SubqueryAlias: __sq_2 [l_orderkey:Int64]\ \n Projection: lineitem.l_orderkey AS l_orderkey [l_orderkey:Int64]\ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; - assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); Ok(()) } @@ -340,7 +338,11 @@ mod tests { \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); Ok(()) } @@ -368,7 +370,11 @@ mod tests { \n Filter: customer.c_custkey = customer.c_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); Ok(()) } @@ -395,7 +401,11 @@ mod tests { \n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); Ok(()) } @@ -421,7 +431,11 @@ mod tests { \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); Ok(()) } @@ -442,7 +456,7 @@ mod tests { // can't optimize on arbitrary expressions (yet) assert_optimizer_err( - &DecorrelateWhereIn::new(), + Arc::new(DecorrelateWhereIn::new()), &plan, "column correlation not found", ); @@ -469,7 +483,7 @@ mod tests { .build()?; assert_optimizer_err( - &DecorrelateWhereIn::new(), + Arc::new(DecorrelateWhereIn::new()), &plan, "Optimizing disjunctions not supported!", ); @@ -492,7 +506,7 @@ mod tests { // Maybe okay if the table only has a single column? assert_optimizer_err( - &DecorrelateWhereIn::new(), + Arc::new(DecorrelateWhereIn::new()), &plan, "a projection is required", ); @@ -516,7 +530,7 @@ mod tests { // TODO: support join on expression assert_optimizer_err( - &DecorrelateWhereIn::new(), + Arc::new(DecorrelateWhereIn::new()), &plan, "column comparison required", ); @@ -540,7 +554,7 @@ mod tests { // TODO: support join on expressions? assert_optimizer_err( - &DecorrelateWhereIn::new(), + Arc::new(DecorrelateWhereIn::new()), &plan, "single column projection required", ); @@ -566,7 +580,7 @@ mod tests { .build()?; assert_optimizer_err( - &DecorrelateWhereIn::new(), + Arc::new(DecorrelateWhereIn::new()), &plan, "single expression projection required", ); @@ -599,7 +613,11 @@ mod tests { \n Projection: orders.o_custkey AS o_custkey [o_custkey:Int64]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); Ok(()) } @@ -630,7 +648,11 @@ mod tests { TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; - assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); Ok(()) } @@ -656,7 +678,11 @@ mod tests { \n Projection: sq.c AS c, sq.a AS a [c:UInt32, a:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); Ok(()) } @@ -676,7 +702,11 @@ mod tests { \n Projection: sq.c AS c [c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); Ok(()) } @@ -696,7 +726,11 @@ mod tests { \n Projection: sq.c AS c [c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&DecorrelateWhereIn::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(DecorrelateWhereIn::new()), + &plan, + expected, + ); Ok(()) } } diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index 7636a6a9fcc7b..e4fc803416743 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -18,13 +18,14 @@ //! Optimizer rule to replace `where false` on a plan with an empty relation. //! This saves time in planning and executing the query. //! Note that this rule should be applied after simplify expressions optimizer rule. +use crate::optimizer::ApplyOrder; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ logical_plan::{EmptyRelation, LogicalPlan}, - Expr, + Expr, Filter, }; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; /// Optimization rule that elimanate the scalar value (true/false) filter with an [LogicalPlan::EmptyRelation] #[derive(Default)] @@ -41,139 +42,119 @@ impl OptimizerRule for EliminateFilter { fn try_optimize( &self, plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _config: &dyn OptimizerConfig, ) -> Result> { - let predicate_and_input = match plan { - LogicalPlan::Filter(filter) => match filter.predicate() { - Expr::Literal(ScalarValue::Boolean(Some(v))) => { - Some((*v, filter.input())) + match plan { + LogicalPlan::Filter(Filter { + predicate: Expr::Literal(ScalarValue::Boolean(Some(v))), + input, + .. + }) => { + match *v { + // input also can be filter, apply again + true => Ok(Some( + self.try_optimize(input, _config)? + .unwrap_or_else(|| input.as_ref().clone()), + )), + false => Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: input.schema().clone(), + }))), } - _ => None, - }, - _ => None, - }; - - match predicate_and_input { - Some((true, input)) => self.try_optimize(input, config), - Some((false, input)) => Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: input.schema().clone(), - }))), - None => { - // Apply the optimization to all inputs of the plan - Ok(Some(utils::optimize_children(self, plan, config)?)) } + _ => Ok(None), } } fn name(&self) -> &str { "eliminate_filter" } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } } #[cfg(test)] mod tests { - use datafusion_expr::{col, lit, logical_plan::builder::LogicalPlanBuilder, sum}; + use crate::eliminate_filter::EliminateFilter; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::{ + col, lit, logical_plan::builder::LogicalPlanBuilder, sum, Expr, LogicalPlan, + }; + use std::sync::Arc; - use crate::optimizer::OptimizerContext; use crate::test::*; - use super::*; - - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { - let rule = EliminateFilter::new(); - let optimized_plan = rule - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); - let formatted_plan = format!("{:?}", optimized_plan); - assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); + fn assert_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan, expected) } #[test] - fn filter_false() { + fn filter_false() -> Result<()> { let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(false))); let table_scan = test_table_scan().unwrap(); let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(vec![col("a")], vec![sum(col("b"))]) - .unwrap() - .filter(filter_expr) - .unwrap() - .build() - .unwrap(); + .aggregate(vec![col("a")], vec![sum(col("b"))])? + .filter(filter_expr)? + .build()?; // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_optimized_plan_eq(&plan, expected); + assert_eq(&plan, expected) } #[test] - fn filter_false_nested() { + fn filter_false_nested() -> Result<()> { let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(false))); - let table_scan = test_table_scan().unwrap(); + let table_scan = test_table_scan()?; let plan1 = LogicalPlanBuilder::from(table_scan.clone()) - .aggregate(vec![col("a")], vec![sum(col("b"))]) - .unwrap() - .build() - .unwrap(); + .aggregate(vec![col("a")], vec![sum(col("b"))])? + .build()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(vec![col("a")], vec![sum(col("b"))]) - .unwrap() - .filter(filter_expr) - .unwrap() - .union(plan1) - .unwrap() - .build() - .unwrap(); + .aggregate(vec![col("a")], vec![sum(col("b"))])? + .filter(filter_expr)? + .union(plan1)? + .build()?; // Left side is removed let expected = "Union\ \n EmptyRelation\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); + assert_eq(&plan, expected) } #[test] - fn filter_true() { + fn filter_true() -> Result<()> { let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(true))); - let table_scan = test_table_scan().unwrap(); + let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(vec![col("a")], vec![sum(col("b"))]) - .unwrap() - .filter(filter_expr) - .unwrap() - .build() - .unwrap(); + .aggregate(vec![col("a")], vec![sum(col("b"))])? + .filter(filter_expr)? + .build()?; let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); + assert_eq(&plan, expected) } #[test] - fn filter_true_nested() { + fn filter_true_nested() -> Result<()> { let filter_expr = Expr::Literal(ScalarValue::Boolean(Some(true))); - let table_scan = test_table_scan().unwrap(); + let table_scan = test_table_scan()?; let plan1 = LogicalPlanBuilder::from(table_scan.clone()) - .aggregate(vec![col("a")], vec![sum(col("b"))]) - .unwrap() - .build() - .unwrap(); + .aggregate(vec![col("a")], vec![sum(col("b"))])? + .build()?; let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(vec![col("a")], vec![sum(col("b"))]) - .unwrap() - .filter(filter_expr) - .unwrap() - .union(plan1) - .unwrap() - .build() - .unwrap(); + .aggregate(vec![col("a")], vec![sum(col("b"))])? + .filter(filter_expr)? + .union(plan1)? + .build()?; // Filter is removed let expected = "Union\ @@ -181,35 +162,29 @@ mod tests { \n TableScan: test\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected); + assert_eq(&plan, expected) } #[test] - fn filter_from_subquery() { + fn filter_from_subquery() -> Result<()> { // SELECT a FROM (SELECT a FROM test WHERE FALSE) WHERE TRUE let false_filter = lit(false); - let table_scan = test_table_scan().unwrap(); + let table_scan = test_table_scan()?; let plan1 = LogicalPlanBuilder::from(table_scan) - .project(vec![col("a")]) - .unwrap() - .filter(false_filter) - .unwrap() - .build() - .unwrap(); + .project(vec![col("a")])? + .filter(false_filter)? + .build()?; let true_filter = lit(true); let plan = LogicalPlanBuilder::from(plan1) - .project(vec![col("a")]) - .unwrap() - .filter(true_filter) - .unwrap() - .build() - .unwrap(); + .project(vec![col("a")])? + .filter(true_filter)? + .build()?; // Filter is removed let expected = "Projection: test.a\ \n EmptyRelation"; - assert_optimized_plan_eq(&plan, expected); + assert_eq(&plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_limit.rs b/datafusion/optimizer/src/eliminate_limit.rs index 9e3cbf6fab038..caea145dd2f28 100644 --- a/datafusion/optimizer/src/eliminate_limit.rs +++ b/datafusion/optimizer/src/eliminate_limit.rs @@ -43,28 +43,25 @@ impl OptimizerRule for EliminateLimit { plan: &LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result> { - let limit = match plan { - LogicalPlan::Limit(limit) => limit, - _ => return Ok(None), - }; - - match limit.fetch { - Some(fetch) => { - if fetch == 0 { - return Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: false, - schema: limit.input.schema().clone(), - }))); + if let LogicalPlan::Limit(limit) = plan { + match limit.fetch { + Some(fetch) => { + if fetch == 0 { + return Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: limit.input.schema().clone(), + }))); + } } - } - None => { - if limit.skip == 0 { - let input = limit.input.as_ref(); - // input also can be Limit, so we should apply again. - return Ok(Some( - self.try_optimize(input, _config)? - .unwrap_or_else(|| input.clone()), - )); + None => { + if limit.skip == 0 { + let input = limit.input.as_ref(); + // input also can be Limit, so we should apply again. + return Ok(Some( + self.try_optimize(input, _config)? + .unwrap_or_else(|| input.clone()), + )); + } } } } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index cc535117def36..8c02950d84762 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -16,7 +16,7 @@ // under the License. //! Optimizer rule to eliminate left/right/full join to inner join if possible. -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{Column, DFSchema, Result}; use datafusion_expr::{ logical_plan::{Join, JoinType, LogicalPlan}, @@ -24,6 +24,7 @@ use datafusion_expr::{ }; use datafusion_expr::{Expr, Operator}; +use crate::optimizer::ApplyOrder; use datafusion_expr::expr::{BinaryExpr, Cast, TryCast}; use std::sync::Arc; @@ -64,7 +65,7 @@ impl OptimizerRule for EliminateOuterJoin { fn try_optimize( &self, plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _config: &dyn OptimizerConfig, ) -> Result> { match plan { LogicalPlan::Filter(filter) => match filter.input().as_ref() { @@ -109,17 +110,21 @@ impl OptimizerRule for EliminateOuterJoin { null_equals_null: join.null_equals_null, }); let new_plan = from_plan(plan, &plan.expressions(), &[new_join])?; - Ok(Some(utils::optimize_children(self, &new_plan, config)?)) + Ok(Some(new_plan)) } - _ => Ok(Some(utils::optimize_children(self, plan, config)?)), + _ => Ok(None), }, - _ => Ok(Some(utils::optimize_children(self, plan, config)?)), + _ => Ok(None), } } fn name(&self) -> &str { "eliminate_outer_join" } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } } pub fn eliminate_outer( @@ -295,7 +300,6 @@ fn extract_non_nullable_columns( #[cfg(test)] mod tests { use super::*; - use crate::optimizer::OptimizerContext; use crate::test::*; use arrow::datatypes::DataType; use datafusion_expr::{ @@ -305,16 +309,8 @@ mod tests { Operator::{And, Or}, }; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { - let rule = EliminateOuterJoin::new(); - let optimized_plan = rule - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); - let formatted_plan = format!("{:?}", optimized_plan); - assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); - Ok(()) + fn assert_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq(Arc::new(EliminateOuterJoin::new()), plan, expected) } #[test] @@ -337,7 +333,7 @@ mod tests { \n Left Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_eq(&plan, expected) + assert_eq(&plan, expected) } #[test] @@ -360,7 +356,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_eq(&plan, expected) + assert_eq(&plan, expected) } #[test] @@ -387,7 +383,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_eq(&plan, expected) + assert_eq(&plan, expected) } #[test] @@ -414,7 +410,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_eq(&plan, expected) + assert_eq(&plan, expected) } #[test] @@ -441,6 +437,6 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_eq(&plan, expected) + assert_eq(&plan, expected) } } diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index eea98ad1f484a..8a6c995dee533 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -20,7 +20,8 @@ //! and then insert an `IsNotNull` filter on the nullable side since null values //! can never match. -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_expr::{ and, logical_plan::Filter, logical_plan::JoinType, Expr, ExprSchemable, LogicalPlan, @@ -42,20 +43,11 @@ impl OptimizerRule for FilterNullJoinKeys { fn try_optimize( &self, plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _config: &dyn OptimizerConfig, ) -> Result> { match plan { LogicalPlan::Join(join) if join.join_type == JoinType::Inner => { - // recurse down first and optimize inputs let mut join = join.clone(); - join.left = Arc::new( - self.try_optimize(&join.left, config)? - .unwrap_or_else(|| join.left.as_ref().clone()), - ); - join.right = Arc::new( - self.try_optimize(&join.right, config)? - .unwrap_or_else(|| join.right.as_ref().clone()), - ); let left_schema = join.left.schema(); let right_schema = join.right.schema(); @@ -89,16 +81,17 @@ impl OptimizerRule for FilterNullJoinKeys { } Ok(Some(LogicalPlan::Join(join))) } - _ => { - // Apply the optimization to all inputs of the plan - Ok(Some(utils::optimize_children(self, plan, config)?)) - } + _ => Ok(None), } } fn name(&self) -> &str { Self::NAME } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } } fn create_not_null_predicate(filters: Vec) -> Expr { @@ -115,27 +108,15 @@ fn create_not_null_predicate(filters: Vec) -> Expr { #[cfg(test)] mod tests { + use super::*; + use crate::test::assert_optimized_plan_eq; use arrow::datatypes::{DataType, Field, Schema}; - use datafusion_common::{Column, Result}; use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{col, lit, logical_plan::JoinType, LogicalPlanBuilder}; - use crate::optimizer::OptimizerContext; - - use super::*; - - fn optimize_plan(plan: &LogicalPlan) -> LogicalPlan { - let rule = FilterNullJoinKeys::default(); - rule.try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan") - } - - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { - let optimized_plan = optimize_plan(plan); - let formatted_plan = format!("{:?}", optimized_plan); - assert_eq!(formatted_plan, expected); + fn assert_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected) } #[test] @@ -146,8 +127,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_eq(&plan, expected) } #[test] @@ -158,8 +138,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_eq(&plan, expected) } #[test] @@ -196,8 +175,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_eq(&plan, expected) } #[test] @@ -218,8 +196,7 @@ mod tests { \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_eq(&plan, expected) } #[test] @@ -240,8 +217,7 @@ mod tests { \n TableScan: t1\ \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_eq(&plan, expected) } #[test] @@ -264,8 +240,7 @@ mod tests { \n TableScan: t1\ \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t2"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_eq(&plan, expected) } fn build_plan( diff --git a/datafusion/optimizer/src/inline_table_scan.rs b/datafusion/optimizer/src/inline_table_scan.rs index fe24e675de0c0..1783cf0a2a7e9 100644 --- a/datafusion/optimizer/src/inline_table_scan.rs +++ b/datafusion/optimizer/src/inline_table_scan.rs @@ -18,7 +18,8 @@ //! Optimizer rule to replace TableScan references //! such as DataFrames and Views and inlines the LogicalPlan //! to support further optimization -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_expr::{logical_plan::LogicalPlan, Expr, LogicalPlanBuilder, TableScan}; @@ -38,7 +39,7 @@ impl OptimizerRule for InlineTableScan { fn try_optimize( &self, plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _config: &dyn OptimizerConfig, ) -> Result> { match plan { // Match only on scans without filter / projection / fetch @@ -51,29 +52,25 @@ impl OptimizerRule for InlineTableScan { .. }) if filters.is_empty() => { if let Some(sub_plan) = source.get_logical_plan() { - // Recursively apply optimization - let plan = utils::optimize_children(self, sub_plan, config)?; - let plan = LogicalPlanBuilder::from(plan) + let plan = LogicalPlanBuilder::from(sub_plan.clone()) .project(vec![Expr::Wildcard])? .alias(table_name)?; Ok(Some(plan.build()?)) } else { - // No plan available, return with table scan as is - Ok(Some(plan.clone())) + Ok(None) } } - - // Rest: Recurse - _ => { - // apply the optimization to all inputs of the plan - Ok(Some(utils::optimize_children(self, plan, config)?)) - } + _ => Ok(None), } } fn name(&self) -> &str { "inline_table_scan" } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } } #[cfg(test)] @@ -83,8 +80,8 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_expr::{col, lit, LogicalPlan, LogicalPlanBuilder, TableSource}; - use crate::optimizer::OptimizerContext; - use crate::{inline_table_scan::InlineTableScan, OptimizerRule}; + use crate::inline_table_scan::InlineTableScan; + use crate::test::assert_optimized_plan_eq; pub struct RawTableSource {} @@ -144,26 +141,18 @@ mod tests { } #[test] - fn inline_table_scan() { - let rule = InlineTableScan::new(); - - let source = Arc::new(CustomSource::new()); - - let scan = LogicalPlanBuilder::scan("x".to_string(), source, None).unwrap(); - - let plan = scan.filter(col("x.a").eq(lit(1))).unwrap().build().unwrap(); - - let optimized_plan = rule - .try_optimize(&plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); - let formatted_plan = format!("{:?}", optimized_plan); + fn inline_table_scan() -> datafusion_common::Result<()> { + let scan = LogicalPlanBuilder::scan( + "x".to_string(), + Arc::new(CustomSource::new()), + None, + )?; + let plan = scan.filter(col("x.a").eq(lit(1)))?.build()?; let expected = "Filter: x.a = Int32(1)\ \n SubqueryAlias: x\ \n Projection: y.a\ \n TableScan: y"; - assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); + assert_optimized_plan_eq(Arc::new(InlineTableScan::new()), &plan, expected) } } diff --git a/datafusion/optimizer/src/propagate_empty_relation.rs b/datafusion/optimizer/src/propagate_empty_relation.rs index 7ef769e2122e8..e3a86381fbc14 100644 --- a/datafusion/optimizer/src/propagate_empty_relation.rs +++ b/datafusion/optimizer/src/propagate_empty_relation.rs @@ -20,7 +20,8 @@ use datafusion_expr::logical_plan::LogicalPlan; use datafusion_expr::{EmptyRelation, JoinType, Projection, Union}; use std::sync::Arc; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; /// Optimization rule that bottom-up to eliminate plan by propagating empty_relation. #[derive(Default)] @@ -37,32 +38,28 @@ impl OptimizerRule for PropagateEmptyRelation { fn try_optimize( &self, plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _config: &dyn OptimizerConfig, ) -> Result> { - // optimize child plans first - let optimized_children_plan = utils::optimize_children(self, plan, config)?; - match &optimized_children_plan { - LogicalPlan::EmptyRelation(_) => Ok(Some(optimized_children_plan)), + match plan { + LogicalPlan::EmptyRelation(_) => {} LogicalPlan::Projection(_) | LogicalPlan::Filter(_) | LogicalPlan::Window(_) | LogicalPlan::Sort(_) | LogicalPlan::SubqueryAlias(_) | LogicalPlan::Repartition(_) - | LogicalPlan::Limit(_) => match empty_child(&optimized_children_plan)? { - Some(empty) => Ok(Some(empty)), - None => Ok(Some(optimized_children_plan)), - }, + | LogicalPlan::Limit(_) => { + if let Some(empty) = empty_child(plan)? { + return Ok(Some(empty)); + } + } LogicalPlan::CrossJoin(_) => { - let (left_empty, right_empty) = - binary_plan_children_is_empty(&optimized_children_plan)?; + let (left_empty, right_empty) = binary_plan_children_is_empty(plan)?; if left_empty || right_empty { - Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { + return Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema: optimized_children_plan.schema().clone(), - }))) - } else { - Ok(Some(optimized_children_plan)) + schema: plan.schema().clone(), + }))); } } LogicalPlan::Join(join) => { @@ -79,18 +76,13 @@ impl OptimizerRule for PropagateEmptyRelation { // For RightOut/Full Join, if the left side is empty, the Join can be eliminated with a Projection with right side // columns + left side columns replaced with null values. if join.join_type == JoinType::Inner { - let (left_empty, right_empty) = - binary_plan_children_is_empty(&optimized_children_plan)?; + let (left_empty, right_empty) = binary_plan_children_is_empty(plan)?; if left_empty || right_empty { - Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { + return Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema: optimized_children_plan.schema().clone(), - }))) - } else { - Ok(Some(optimized_children_plan)) + schema: plan.schema().clone(), + }))); } - } else { - Ok(Some(optimized_children_plan)) } } LogicalPlan::Union(union) => { @@ -105,46 +97,50 @@ impl OptimizerRule for PropagateEmptyRelation { .collect::>(); if new_inputs.len() == union.inputs.len() { - Ok(Some(optimized_children_plan)) + return Ok(None); } else if new_inputs.is_empty() { - Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { + return Ok(Some(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, - schema: optimized_children_plan.schema().clone(), - }))) + schema: plan.schema().clone(), + }))); } else if new_inputs.len() == 1 { let child = (**(union.inputs.get(0).unwrap())).clone(); - if child.schema().eq(optimized_children_plan.schema()) { - Ok(Some(child)) + if child.schema().eq(plan.schema()) { + return Ok(Some(child)); } else { - Ok(Some(LogicalPlan::Projection(Projection::new_from_schema( - Arc::new(child), - optimized_children_plan.schema().clone(), - )))) + return Ok(Some(LogicalPlan::Projection( + Projection::new_from_schema( + Arc::new(child), + plan.schema().clone(), + ), + ))); } } else { - Ok(Some(LogicalPlan::Union(Union { + return Ok(Some(LogicalPlan::Union(Union { inputs: new_inputs, schema: union.schema.clone(), - }))) + }))); } } LogicalPlan::Aggregate(agg) => { if !agg.group_expr.is_empty() { - match empty_child(&optimized_children_plan)? { - Some(empty) => Ok(Some(empty)), - None => Ok(Some(optimized_children_plan)), + if let Some(empty) = empty_child(plan)? { + return Ok(Some(empty)); } - } else { - Ok(Some(optimized_children_plan)) } } - _ => Ok(Some(optimized_children_plan)), + _ => {} } + Ok(None) } fn name(&self) -> &str { "propagate_empty_relation" } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::BottomUp) + } } fn binary_plan_children_is_empty(plan: &LogicalPlan) -> Result<(bool, bool)> { @@ -202,7 +198,10 @@ fn empty_child(plan: &LogicalPlan) -> Result> { #[cfg(test)] mod tests { use crate::eliminate_filter::EliminateFilter; - use crate::test::{test_table_scan, test_table_scan_with_name}; + use crate::optimizer::Optimizer; + use crate::test::{ + assert_optimized_plan_eq, test_table_scan, test_table_scan_with_name, + }; use crate::OptimizerContext; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{Column, ScalarValue}; @@ -214,29 +213,29 @@ mod tests { use super::*; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { - let rule = PropagateEmptyRelation::new(); - let optimized_plan = rule - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); - let formatted_plan = format!("{:?}", optimized_plan); - assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimized_plan.schema()); + fn assert_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq(Arc::new(PropagateEmptyRelation::new()), plan, expected) } - fn assert_together_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { - let optimize_one = EliminateFilter::new() - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); - let optimize_two = PropagateEmptyRelation::new() - .try_optimize(&optimize_one, &OptimizerContext::new()) - .unwrap() + fn assert_together_optimized_plan_eq( + plan: &LogicalPlan, + expected: &str, + ) -> Result<()> { + fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + let optimizer = Optimizer::with_rules(vec![ + Arc::new(EliminateFilter::new()), + Arc::new(PropagateEmptyRelation::new()), + ]); + let config = &mut OptimizerContext::new() + .with_max_passes(1) + .with_skip_failing_rules(false); + let optimized_plan = optimizer + .optimize(plan, config, observe) .expect("failed to optimize plan"); - let formatted_plan = format!("{:?}", optimize_two); + let formatted_plan = format!("{:?}", optimized_plan); assert_eq!(formatted_plan, expected); - assert_eq!(plan.schema(), optimize_two.schema()); + assert_eq!(plan.schema(), optimized_plan.schema()); + Ok(()) } #[test] @@ -248,9 +247,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_optimized_plan_eq(&plan, expected); - - Ok(()) + assert_eq(&plan, expected) } #[test] @@ -273,9 +270,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected); - - Ok(()) + assert_together_optimized_plan_eq(&plan, expected) } #[test] @@ -289,9 +284,7 @@ mod tests { let expected = "Projection: a, b, c\ \n TableScan: test"; - assert_together_optimized_plan_eq(&plan, expected); - - Ok(()) + assert_together_optimized_plan_eq(&plan, expected) } #[test] @@ -316,9 +309,7 @@ mod tests { let expected = "Union\ \n TableScan: test1\ \n TableScan: test4"; - assert_together_optimized_plan_eq(&plan, expected); - - Ok(()) + assert_together_optimized_plan_eq(&plan, expected) } #[test] @@ -343,9 +334,7 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected); - - Ok(()) + assert_together_optimized_plan_eq(&plan, expected) } #[test] @@ -372,9 +361,7 @@ mod tests { let expected = "Union\ \n TableScan: test2\ \n TableScan: test3"; - assert_together_optimized_plan_eq(&plan, expected); - - Ok(()) + assert_together_optimized_plan_eq(&plan, expected) } #[test] @@ -388,8 +375,7 @@ mod tests { let expected = "Projection: a, b, c\ \n TableScan: test"; - assert_together_optimized_plan_eq(&plan, expected); - Ok(()) + assert_together_optimized_plan_eq(&plan, expected) } #[test] @@ -404,8 +390,6 @@ mod tests { .build()?; let expected = "EmptyRelation"; - assert_together_optimized_plan_eq(&plan, expected); - - Ok(()) + assert_together_optimized_plan_eq(&plan, expected) } } diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs b/datafusion/optimizer/src/scalar_subquery_to_join.rs index 0a61105411792..8e7610bcc3c7f 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. +use crate::optimizer::ApplyOrder; use crate::utils::{ conjunction, exprs_to_join_cols, find_join_exprs, only_or_err, split_conjunction, verify_not_disjunction, }; -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{context, plan_err, Column, Result}; use datafusion_expr::expr::BinaryExpr; use datafusion_expr::logical_plan::{Filter, JoinType, Limit, Subquery}; @@ -97,20 +98,12 @@ impl OptimizerRule for ScalarSubqueryToJoin { ) -> Result> { match plan { LogicalPlan::Filter(filter) => { - // Apply optimizer rule to current input - let optimized_input = self - .try_optimize(filter.input(), config)? - .unwrap_or_else(|| filter.input().as_ref().clone()); - let (subqueries, other_exprs) = self.extract_subquery_exprs(filter.predicate(), config)?; if subqueries.is_empty() { // regular filter, no subquery exists clause here - return Ok(Some(LogicalPlan::Filter(Filter::try_new( - filter.predicate().clone(), - Arc::new(optimized_input), - )?))); + return Ok(None); } // iterate through all subqueries in predicate, turning each into a join @@ -122,24 +115,22 @@ impl OptimizerRule for ScalarSubqueryToJoin { cur_input = optimized_subquery; } else { // if we can't handle all of the subqueries then bail for now - return Ok(Some(LogicalPlan::Filter(Filter::try_new( - filter.predicate().clone(), - Arc::new(optimized_input), - )?))); + return Ok(None); } } Ok(Some(cur_input)) } - _ => { - // Apply the optimization to all inputs of the plan - Ok(Some(utils::optimize_children(self, plan, config)?)) - } + _ => Ok(None), } } fn name(&self) -> &str { "scalar_subquery_to_join" } + + fn apply_order(&self) -> Option { + Some(ApplyOrder::TopDown) + } } /// Takes a query like: @@ -408,7 +399,11 @@ mod tests { \n Projection: orders.o_custkey, MAX(orders.o_custkey) AS __value [o_custkey:Int64, __value:Int64;N]\ \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(ScalarSubqueryToJoin::new()), + &plan, + expected, + ); Ok(()) } @@ -444,20 +439,24 @@ mod tests { .build()?; let expected = "Projection: customer.c_custkey [c_custkey:Int64]\ - \n Filter: customer.c_acctbal < __sq_2.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Float64;N]\ - \n Inner Join: customer.c_custkey = __sq_2.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Float64;N]\ + \n Filter: customer.c_acctbal < __sq_1.__value [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Float64;N]\ + \n Inner Join: customer.c_custkey = __sq_1.o_custkey [c_custkey:Int64, c_name:Utf8, o_custkey:Int64, __value:Float64;N]\ \n TableScan: customer [c_custkey:Int64, c_name:Utf8]\ - \n SubqueryAlias: __sq_2 [o_custkey:Int64, __value:Float64;N]\ + \n SubqueryAlias: __sq_1 [o_custkey:Int64, __value:Float64;N]\ \n Projection: orders.o_custkey, SUM(orders.o_totalprice) AS __value [o_custkey:Int64, __value:Float64;N]\ \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[SUM(orders.o_totalprice)]] [o_custkey:Int64, SUM(orders.o_totalprice):Float64;N]\ - \n Filter: orders.o_totalprice < __sq_1.__value [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64, __value:Float64;N]\ - \n Inner Join: orders.o_orderkey = __sq_1.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64, __value:Float64;N]\ + \n Filter: orders.o_totalprice < __sq_2.__value [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64, __value:Float64;N]\ + \n Inner Join: orders.o_orderkey = __sq_2.l_orderkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N, l_orderkey:Int64, __value:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ - \n SubqueryAlias: __sq_1 [l_orderkey:Int64, __value:Float64;N]\ + \n SubqueryAlias: __sq_2 [l_orderkey:Int64, __value:Float64;N]\ \n Projection: lineitem.l_orderkey, SUM(lineitem.l_extendedprice) AS __value [l_orderkey:Int64, __value:Float64;N]\ \n Aggregate: groupBy=[[lineitem.l_orderkey]], aggr=[[SUM(lineitem.l_extendedprice)]] [l_orderkey:Int64, SUM(lineitem.l_extendedprice):Float64;N]\ \n TableScan: lineitem [l_orderkey:Int64, l_partkey:Int64, l_suppkey:Int64, l_linenumber:Int32, l_quantity:Float64, l_extendedprice:Float64]"; - assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(ScalarSubqueryToJoin::new()), + &plan, + expected, + ); Ok(()) } @@ -490,7 +489,11 @@ mod tests { \n Filter: orders.o_orderkey = Int32(1) [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(ScalarSubqueryToJoin::new()), + &plan, + expected, + ); Ok(()) } @@ -520,7 +523,11 @@ mod tests { \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ \n Filter: customer.c_custkey = customer.c_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(ScalarSubqueryToJoin::new()), + &plan, + expected, + ); Ok(()) } @@ -550,7 +557,11 @@ mod tests { \n Filter: orders.o_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(ScalarSubqueryToJoin::new()), + &plan, + expected, + ); Ok(()) } @@ -572,7 +583,7 @@ mod tests { let expected = r#"only joins on column equality are presently supported"#; - assert_optimizer_err(&ScalarSubqueryToJoin::new(), &plan, expected); + assert_optimizer_err(Arc::new(ScalarSubqueryToJoin::new()), &plan, expected); Ok(()) } @@ -593,7 +604,7 @@ mod tests { .build()?; let expected = r#"can't optimize < column comparison"#; - assert_optimizer_err(&ScalarSubqueryToJoin::new(), &plan, expected); + assert_optimizer_err(Arc::new(ScalarSubqueryToJoin::new()), &plan, expected); Ok(()) } @@ -618,7 +629,7 @@ mod tests { .build()?; let expected = r#"Optimizing disjunctions not supported!"#; - assert_optimizer_err(&ScalarSubqueryToJoin::new(), &plan, expected); + assert_optimizer_err(Arc::new(ScalarSubqueryToJoin::new()), &plan, expected); Ok(()) } @@ -644,7 +655,11 @@ mod tests { TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; - assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(ScalarSubqueryToJoin::new()), + &plan, + expected, + ); Ok(()) } @@ -670,7 +685,11 @@ mod tests { let expected = r#""#; - assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(ScalarSubqueryToJoin::new()), + &plan, + expected, + ); Ok(()) } @@ -694,7 +713,7 @@ mod tests { .build()?; let expected = r#"exactly one expression should be projected"#; - assert_optimizer_err(&ScalarSubqueryToJoin::new(), &plan, expected); + assert_optimizer_err(Arc::new(ScalarSubqueryToJoin::new()), &plan, expected); Ok(()) } @@ -728,7 +747,11 @@ mod tests { \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(ScalarSubqueryToJoin::new()), + &plan, + expected, + ); Ok(()) } @@ -760,7 +783,11 @@ mod tests { \n Aggregate: groupBy=[[orders.o_custkey]], aggr=[[MAX(orders.o_custkey)]] [o_custkey:Int64, MAX(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(ScalarSubqueryToJoin::new()), + &plan, + expected, + ); Ok(()) } @@ -793,7 +820,11 @@ mod tests { Filter: customer.c_custkey = orders.o_custkey [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N] TableScan: customer [c_custkey:Int64, c_name:Utf8]"#; - assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(ScalarSubqueryToJoin::new()), + &plan, + expected, + ); Ok(()) } @@ -822,7 +853,11 @@ mod tests { \n Aggregate: groupBy=[[sq.a]], aggr=[[MIN(sq.c)]] [a:UInt32, MIN(sq.c):UInt32;N]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(ScalarSubqueryToJoin::new()), + &plan, + expected, + ); Ok(()) } @@ -850,7 +885,11 @@ mod tests { \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(ScalarSubqueryToJoin::new()), + &plan, + expected, + ); Ok(()) } @@ -877,7 +916,11 @@ mod tests { \n Aggregate: groupBy=[[]], aggr=[[MAX(orders.o_custkey)]] [MAX(orders.o_custkey):Int64;N]\ \n TableScan: orders [o_orderkey:Int64, o_custkey:Int64, o_orderstatus:Utf8, o_totalprice:Float64;N]"; - assert_optimized_plan_eq(&ScalarSubqueryToJoin::new(), &plan, expected); + assert_optimized_plan_eq_display_indent( + Arc::new(ScalarSubqueryToJoin::new()), + &plan, + expected, + ); Ok(()) } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 462b94dd0d050..a51c2ec29fe78 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use crate::optimizer::Optimizer; use crate::{OptimizerContext, OptimizerRule}; use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::Result; @@ -102,24 +103,53 @@ pub fn get_tpch_table_schema(table: &str) -> Schema { } pub fn assert_optimized_plan_eq( - rule: &dyn OptimizerRule, + rule: Arc, + plan: &LogicalPlan, + expected: &str, +) -> Result<()> { + let optimizer = Optimizer::with_rules(vec![rule]); + let optimized_plan = optimizer + .optimize_recursively( + optimizer.rules.get(0).unwrap(), + plan, + &OptimizerContext::new(), + )? + .unwrap(); + let formatted_plan = format!("{:?}", optimized_plan); + assert_eq!(formatted_plan, expected); + + Ok(()) +} + +pub fn assert_optimized_plan_eq_display_indent( + rule: Arc, plan: &LogicalPlan, expected: &str, ) { - let optimized_plan = rule - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); + let optimizer = Optimizer::with_rules(vec![rule]); + let optimized_plan = optimizer + .optimize_recursively( + optimizer.rules.get(0).unwrap(), + plan, + &OptimizerContext::new(), + ) + .expect("failed to optimize plan") + .unwrap_or_else(|| plan.clone()); let formatted_plan = format!("{}", optimized_plan.display_indent_schema()); assert_eq!(formatted_plan, expected); } pub fn assert_optimizer_err( - rule: &dyn OptimizerRule, + rule: Arc, plan: &LogicalPlan, expected: &str, ) { - let res = rule.try_optimize(plan, &OptimizerContext::new()); + let optimizer = Optimizer::with_rules(vec![rule]); + let res = optimizer.optimize_recursively( + optimizer.rules.get(0).unwrap(), + plan, + &OptimizerContext::new(), + ); match res { Ok(plan) => assert_eq!(format!("{}", plan.unwrap().display_indent()), "An error"), Err(ref e) => { @@ -131,13 +161,21 @@ pub fn assert_optimizer_err( } } -pub fn assert_optimization_skipped(rule: &dyn OptimizerRule, plan: &LogicalPlan) { - let new_plan = rule - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .unwrap(); +pub fn assert_optimization_skipped( + rule: Arc, + plan: &LogicalPlan, +) -> Result<()> { + let optimizer = Optimizer::with_rules(vec![rule]); + let new_plan = optimizer + .optimize_recursively( + optimizer.rules.get(0).unwrap(), + plan, + &OptimizerContext::new(), + )? + .unwrap_or_else(|| plan.clone()); assert_eq!( format!("{}", plan.display_indent()), format!("{}", new_plan.display_indent()) ); + Ok(()) }