diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index e4fc803416743..c5dc1711c244a 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -86,7 +86,7 @@ mod tests { use crate::test::*; - fn assert_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(EliminateFilter::new()), plan, expected) } @@ -102,7 +102,7 @@ mod tests { // No aggregate / scan / limit let expected = "EmptyRelation"; - assert_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -124,7 +124,7 @@ mod tests { \n EmptyRelation\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -139,7 +139,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -162,7 +162,7 @@ mod tests { \n TableScan: test\ \n Aggregate: groupBy=[[test.a]], aggr=[[SUM(test.b)]]\ \n TableScan: test"; - assert_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -185,6 +185,6 @@ mod tests { // Filter is removed let expected = "Projection: test.a\ \n EmptyRelation"; - assert_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } } diff --git a/datafusion/optimizer/src/eliminate_outer_join.rs b/datafusion/optimizer/src/eliminate_outer_join.rs index 8c02950d84762..590f9855529fb 100644 --- a/datafusion/optimizer/src/eliminate_outer_join.rs +++ b/datafusion/optimizer/src/eliminate_outer_join.rs @@ -309,7 +309,7 @@ mod tests { Operator::{And, Or}, }; - fn assert_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> 
Result<()> { assert_optimized_plan_eq(Arc::new(EliminateOuterJoin::new()), plan, expected) } @@ -333,7 +333,7 @@ mod tests { \n Left Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -356,7 +356,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -383,7 +383,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -410,7 +410,7 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -437,6 +437,6 @@ mod tests { \n Inner Join: t1.a = t2.a\ \n TableScan: t1\ \n TableScan: t2"; - assert_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } } diff --git a/datafusion/optimizer/src/filter_null_join_keys.rs b/datafusion/optimizer/src/filter_null_join_keys.rs index 8a6c995dee533..8f221eaccd351 100644 --- a/datafusion/optimizer/src/filter_null_join_keys.rs +++ b/datafusion/optimizer/src/filter_null_join_keys.rs @@ -115,7 +115,7 @@ mod tests { use datafusion_expr::logical_plan::table_scan; use datafusion_expr::{col, lit, logical_plan::JoinType, LogicalPlanBuilder}; - fn assert_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { assert_optimized_plan_eq(Arc::new(FilterNullJoinKeys {}), plan, expected) } @@ -127,7 +127,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -138,7 +138,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_eq(&plan, expected) + 
assert_optimized_plan_equal(&plan, expected) } #[test] @@ -175,7 +175,7 @@ mod tests { \n Filter: t1.optional_id IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -196,7 +196,7 @@ mod tests { \n Filter: t1.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t1\ \n TableScan: t2"; - assert_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -217,7 +217,7 @@ mod tests { \n TableScan: t1\ \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t2"; - assert_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -240,7 +240,7 @@ mod tests { \n TableScan: t1\ \n Filter: t2.optional_id + UInt32(1) IS NOT NULL\ \n TableScan: t2"; - assert_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } fn build_plan( diff --git a/datafusion/optimizer/src/push_down_limit.rs b/datafusion/optimizer/src/push_down_limit.rs index ad5ceaea0b848..2619091832e1b 100644 --- a/datafusion/optimizer/src/push_down_limit.rs +++ b/datafusion/optimizer/src/push_down_limit.rs @@ -17,7 +17,8 @@ //! Optimizer rule to push down LIMIT in the query plan //! 
It will push down through projection, limits (taking the smaller limit) -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_expr::{ logical_plan::{Join, JoinType, Limit, LogicalPlan, Sort, TableScan, Union}, @@ -78,11 +79,11 @@ impl OptimizerRule for PushDownLimit { fn try_optimize( &self, plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _config: &dyn OptimizerConfig, ) -> Result<Option<LogicalPlan>> { let limit = match plan { LogicalPlan::Limit(limit) => limit, - _ => return Ok(Some(utils::optimize_children(self, plan, config)?)), + _ => return Ok(None), }; if let LogicalPlan::Limit(child_limit) = &*limit.input { @@ -112,12 +113,12 @@ impl OptimizerRule for PushDownLimit { fetch: new_fetch, input: Arc::new((*child_limit.input).clone()), }); - return self.try_optimize(&plan, config); + return self.try_optimize(&plan, _config); } let fetch = match limit.fetch { Some(fetch) => fetch, - None => return Ok(Some(utils::optimize_children(self, plan, config)?)), + None => return Ok(None), }; let skip = limit.skip; @@ -225,12 +226,16 @@ impl OptimizerRule for PushDownLimit { _ => plan.clone(), }; - Ok(Some(utils::optimize_children(self, &plan, config)?)) + Ok(Some(plan)) } fn name(&self) -> &str { "push_down_limit" } + + fn apply_order(&self) -> Option<ApplyOrder> { + Some(ApplyOrder::TopDown) + } } fn fetch_minus_skip(fetch: usize, skip: usize) -> usize { @@ -247,25 +252,14 @@ mod test { use super::*; use crate::test::*; - use crate::OptimizerContext; use datafusion_expr::{ col, exists, logical_plan::{builder::LogicalPlanBuilder, JoinType, LogicalPlan}, max, }; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { - let optimized_plan = PushDownLimit::new() - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); - - let formatted_plan = format!("{:?}", optimized_plan); - - assert_eq!(formatted_plan, 
expected); - assert_eq!(optimized_plan.schema(), plan.schema()); - - Ok(()) + fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq(Arc::new(PushDownLimit::new()), plan, expected) } #[test] @@ -283,7 +277,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -301,7 +295,7 @@ mod test { let expected = "Limit: skip=0, fetch=10\ \n TableScan: test, fetch=10"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -318,7 +312,7 @@ mod test { \n Aggregate: groupBy=[[test.a]], aggr=[[MAX(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -338,7 +332,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -355,7 +349,7 @@ mod test { \n Sort: test.a, fetch=10\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -372,7 +366,7 @@ mod test { \n Sort: test.a, fetch=15\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -391,7 +385,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -406,7 +400,7 @@ mod test { let expected = "Limit: skip=10, fetch=None\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -424,7 +418,7 @@ mod test { \n Limit: skip=10, fetch=1000\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ 
-441,7 +435,7 @@ mod test { \n Limit: skip=10, fetch=990\ \n TableScan: test, fetch=1000"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -458,7 +452,7 @@ mod test { \n Limit: skip=10, fetch=1000\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -474,7 +468,7 @@ mod test { let expected = "Limit: skip=10, fetch=10\ \n TableScan: test, fetch=20"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -491,7 +485,7 @@ mod test { \n Aggregate: groupBy=[[test.a]], aggr=[[MAX(test.b)]]\ \n TableScan: test"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -511,7 +505,7 @@ mod test { \n Limit: skip=0, fetch=1010\ \n TableScan: test, fetch=1010"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -535,7 +529,7 @@ mod test { \n TableScan: test\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -559,7 +553,7 @@ mod test { \n TableScan: test\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -588,7 +582,7 @@ mod test { \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&outer_query, expected) + assert_optimized_plan_equal(&outer_query, expected) } #[test] @@ -617,7 +611,7 @@ mod test { \n Projection: test2.a\ \n TableScan: test2"; - assert_optimized_plan_eq(&outer_query, expected) + assert_optimized_plan_equal(&outer_query, expected) } #[test] @@ -643,7 +637,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_equal(&plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -662,7 
+656,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_equal(&plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -681,7 +675,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_equal(&plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -699,7 +693,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_equal(&plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -717,7 +711,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_equal(&plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1.clone()) .join( @@ -735,7 +729,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_eq(&plan, expected)?; + assert_optimized_plan_equal(&plan, expected)?; let plan = LogicalPlanBuilder::from(table_scan_1) .join( @@ -753,7 +747,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -778,7 +772,7 @@ mod test { \n TableScan: test, fetch=1000\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -803,7 +797,7 @@ mod test { \n TableScan: test, fetch=1010\ \n TableScan: test2"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -828,7 +822,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } 
#[test] @@ -853,7 +847,7 @@ mod test { \n Limit: skip=0, fetch=1010\ \n TableScan: test2, fetch=1010"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -873,7 +867,7 @@ mod test { \n Limit: skip=0, fetch=1000\ \n TableScan: test2, fetch=1000"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -893,7 +887,7 @@ mod test { \n Limit: skip=0, fetch=2000\ \n TableScan: test2, fetch=2000"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -908,7 +902,7 @@ mod test { let expected = "Limit: skip=1000, fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -923,7 +917,7 @@ mod test { let expected = "Limit: skip=1000, fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -940,6 +934,6 @@ mod test { \n Limit: skip=1000, fetch=0\ \n TableScan: test, fetch=0"; - assert_optimized_plan_eq(&plan, expected) + assert_optimized_plan_equal(&plan, expected) } } diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs index 0f9ba3d371e12..c14574eaf6f7c 100644 --- a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs +++ b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. 
-use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::Result; use datafusion_expr::expr::BinaryExpr; use datafusion_expr::logical_plan::Filter; use datafusion_expr::{Expr, LogicalPlan, Operator}; -use std::sync::Arc; /// Optimizer pass that rewrites predicates of the form /// @@ -127,7 +127,7 @@ impl OptimizerRule for RewriteDisjunctivePredicate { fn try_optimize( &self, plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _config: &dyn OptimizerConfig, ) -> Result<Option<LogicalPlan>> { match plan { LogicalPlan::Filter(filter) => { @@ -136,18 +136,20 @@ impl OptimizerRule for RewriteDisjunctivePredicate { let rewritten_expr = normalize_predicate(rewritten_predicate); Ok(Some(LogicalPlan::Filter(Filter::try_new( rewritten_expr, - self.try_optimize(filter.input(), config)? - .map(Arc::new) - .unwrap_or_else(|| filter.input().clone()), + filter.input.clone(), )?))) } - _ => Ok(Some(utils::optimize_children(self, plan, config)?)), + _ => Ok(None), } } fn name(&self) -> &str { "rewrite_disjunctive_predicate" } + + fn apply_order(&self) -> Option<ApplyOrder> { + Some(ApplyOrder::TopDown) + } } #[derive(Clone, PartialEq, Debug)] diff --git a/datafusion/optimizer/src/single_distinct_to_groupby.rs b/datafusion/optimizer/src/single_distinct_to_groupby.rs index bf4231c1f0cdc..c03d763b21e01 100644 --- a/datafusion/optimizer/src/single_distinct_to_groupby.rs +++ b/datafusion/optimizer/src/single_distinct_to_groupby.rs @@ -17,7 +17,8 @@ //! 
single distinct to group by optimizer rule -use crate::{utils, OptimizerConfig, OptimizerRule}; +use crate::optimizer::ApplyOrder; +use crate::{OptimizerConfig, OptimizerRule}; use datafusion_common::{DFSchema, Result}; use datafusion_expr::{ col, @@ -90,7 +91,7 @@ impl OptimizerRule for SingleDistinctToGroupBy { fn try_optimize( &self, plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _config: &dyn OptimizerConfig, ) -> Result<Option<LogicalPlan>> { match plan { LogicalPlan::Aggregate(Aggregate { @@ -157,13 +158,11 @@ impl OptimizerRule for SingleDistinctToGroupBy { inner_fields, input.schema().metadata().clone(), )?; - let grouped_aggr = LogicalPlan::Aggregate(Aggregate::try_new( + let inner_agg = LogicalPlan::Aggregate(Aggregate::try_new( input.clone(), inner_group_exprs, Vec::new(), )?); - let inner_agg = - utils::optimize_children(self, &grouped_aggr, config)?; let outer_aggr_schema = Arc::new(DFSchema::new_with_metadata( outer_group_exprs @@ -207,22 +206,26 @@ impl OptimizerRule for SingleDistinctToGroupBy { )?, ))) } else { - Ok(Some(utils::optimize_children(self, plan, config)?)) + Ok(None) } } - _ => Ok(Some(utils::optimize_children(self, plan, config)?)), + _ => Ok(None), } } + fn name(&self) -> &str { "single_distinct_aggregation_to_group_by" } + + fn apply_order(&self) -> Option<ApplyOrder> { + Some(ApplyOrder::TopDown) + } } #[cfg(test)] mod tests { use super::*; use crate::test::*; - use crate::OptimizerContext; use datafusion_expr::expr; use datafusion_expr::expr::GroupingSet; use datafusion_expr::{ @@ -230,15 +233,13 @@ mod tests { AggregateFunction, }; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { - let rule = SingleDistinctToGroupBy::new(); - let optimized_plan = rule - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); - - let formatted_plan = format!("{}", optimized_plan.display_indent_schema()); - assert_eq!(formatted_plan, expected); + fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> 
Result<()> { + assert_optimized_plan_eq_display_indent( + Arc::new(SingleDistinctToGroupBy::new()), + plan, + expected, + ); + Ok(()) } #[test] @@ -254,8 +255,7 @@ mod tests { "Aggregate: groupBy=[[]], aggr=[[MAX(test.b)]] [MAX(test.b):UInt32;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -272,8 +272,7 @@ mod tests { \n Aggregate: groupBy=[[test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -294,8 +293,7 @@ mod tests { let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -313,8 +311,7 @@ mod tests { let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } // Currently this optimization is disabled for CUBE/ROLLUP/GROUPING SET @@ -333,8 +330,7 @@ mod tests { let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[COUNT(DISTINCT test.c)]] [a:UInt32, b:UInt32, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -350,8 +346,7 @@ mod tests { \n Aggregate: groupBy=[[Int32(2) * test.b AS alias1]], 
aggr=[[]] [alias1:Int32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -368,8 +363,7 @@ mod tests { \n Aggregate: groupBy=[[test.a AS group_alias_0, test.b AS alias1]], aggr=[[]] [group_alias_0:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -387,8 +381,7 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(DISTINCT test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -415,8 +408,7 @@ mod tests { \n Aggregate: groupBy=[[test.a AS group_alias_0, test.b AS alias1]], aggr=[[]] [group_alias_0:UInt32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -434,19 +426,16 @@ mod tests { let expected = "Aggregate: groupBy=[[test.a]], aggr=[[COUNT(DISTINCT test.b), COUNT(test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(test.c):Int64;N]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } #[test] - fn group_by_with_expr() { + fn group_by_with_expr() -> Result<()> { let table_scan = test_table_scan().unwrap(); let plan = LogicalPlanBuilder::from(table_scan) - .aggregate(vec![col("a") + lit(1)], vec![count_distinct(col("c"))]) - .unwrap() - .build() - .unwrap(); + .aggregate(vec![col("a") + lit(1)], vec![count_distinct(col("c"))])? 
+ .build()?; // Should work let expected = "Projection: group_alias_0 AS test.a + Int32(1), COUNT(alias1) AS COUNT(DISTINCT test.c) [test.a + Int32(1):Int32, COUNT(DISTINCT test.c):Int64;N]\ @@ -454,6 +443,6 @@ mod tests { \n Aggregate: groupBy=[[test.a + Int32(1) AS group_alias_0, test.c AS alias1]], aggr=[[]] [group_alias_0:Int32, alias1:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); + assert_optimized_plan_equal(&plan, expected) } } diff --git a/datafusion/optimizer/src/subquery_filter_to_join.rs b/datafusion/optimizer/src/subquery_filter_to_join.rs index da79869566d20..8d0c4e88d02e1 100644 --- a/datafusion/optimizer/src/subquery_filter_to_join.rs +++ b/datafusion/optimizer/src/subquery_filter_to_join.rs @@ -26,6 +26,7 @@ //! WHERE t1.f IN (SELECT f FROM t2) OR t2.f = 'x' //! ``` //! won't +use crate::optimizer::ApplyOrder; use crate::{utils, OptimizerConfig, OptimizerRule}; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{ @@ -52,14 +53,12 @@ impl OptimizerRule for SubqueryFilterToJoin { fn try_optimize( &self, plan: &LogicalPlan, - config: &dyn OptimizerConfig, + _config: &dyn OptimizerConfig, ) -> Result<Option<LogicalPlan>> { match plan { LogicalPlan::Filter(filter) => { // Apply optimizer rule to current input - let optimized_input = self - .try_optimize(filter.input(), config)? 
- .unwrap_or_else(|| filter.input().as_ref().clone()); + let input = filter.input().as_ref().clone(); // Splitting filter expression into components by AND let filters = utils::split_conjunction(filter.predicate()); @@ -82,14 +81,14 @@ impl OptimizerRule for SubqueryFilterToJoin { if !subqueries_in_regular.is_empty() { return Ok(Some(LogicalPlan::Filter(Filter::try_new( filter.predicate().clone(), - Arc::new(optimized_input), + Arc::new(input), )?))); }; // Add subquery joins to new_input // optimized_input value should retain for possible optimization rollback let opt_result = subquery_filters.iter().try_fold( - optimized_input.clone(), + input.clone(), |input, &e| match e { Expr::InSubquery { expr, @@ -98,7 +97,7 @@ impl OptimizerRule for SubqueryFilterToJoin { } => { let right_input = self.try_optimize( &subquery.subquery, - config + _config )?.unwrap_or_else(||subquery.subquery.as_ref().clone()); let right_schema = right_input.schema(); if right_schema.fields().len() != 1 { @@ -124,7 +123,7 @@ impl OptimizerRule for SubqueryFilterToJoin { }; let schema = build_join_schema( - optimized_input.schema(), + input.schema(), right_schema, &join_type, )?; @@ -154,7 +153,7 @@ impl OptimizerRule for SubqueryFilterToJoin { Err(_) => { return Ok(Some(LogicalPlan::Filter(Filter::try_new( filter.predicate().clone(), - Arc::new(optimized_input), + Arc::new(input), )?))) } }; @@ -166,16 +165,17 @@ impl OptimizerRule for SubqueryFilterToJoin { Ok(Some(utils::add_filter(new_input, &regular_filters)?)) } } - _ => { - // Apply the optimization to all inputs of the plan - Ok(Some(utils::optimize_children(self, plan, config)?)) - } + _ => Ok(None), } } fn name(&self) -> &str { "subquery_filter_to_join" } + + fn apply_order(&self) -> Option<ApplyOrder> { + Some(ApplyOrder::TopDown) + } } fn extract_subquery_filters(expression: &Expr, extracted: &mut Vec<Expr>) -> Result<()> { @@ -200,20 +200,18 @@ fn extract_subquery_filters(expression: &Expr, extracted: &mut Vec<Expr>) -> Res mod tests { use super::*; use 
crate::test::*; - use crate::OptimizerContext; use datafusion_expr::{ and, binary_expr, col, in_subquery, lit, logical_plan::LogicalPlanBuilder, not_in_subquery, or, Operator, }; - fn assert_optimized_plan_eq(plan: &LogicalPlan, expected: &str) { - let rule = SubqueryFilterToJoin::new(); - let optimized_plan = rule - .try_optimize(plan, &OptimizerContext::new()) - .unwrap() - .expect("failed to optimize plan"); - let formatted_plan = format!("{}", optimized_plan.display_indent_schema()); - assert_eq!(formatted_plan, expected); + fn assert_optimized_plan_equal(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_optimized_plan_eq_display_indent( + Arc::new(SubqueryFilterToJoin::new()), + plan, + expected, + ); + Ok(()) } fn test_subquery_with_name(name: &str) -> Result> { @@ -240,8 +238,7 @@ mod tests { \n Projection: sq.c [c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } /// Test for single NOT IN subquery filter @@ -259,8 +256,7 @@ mod tests { \n Projection: sq.c [c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } /// Test for several IN subquery expressions @@ -284,8 +280,7 @@ mod tests { \n Projection: sq_2.c [c:UInt32]\ \n TableScan: sq_2 [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } /// Test for IN subquery with additional AND filter @@ -310,8 +305,7 @@ mod tests { \n Projection: sq.c [c:UInt32]\ \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } /// Test for IN subquery with additional OR filter @@ -337,8 +331,7 @@ mod tests { \n TableScan: sq [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - 
assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } #[test] @@ -365,8 +358,7 @@ mod tests { \n TableScan: sq2 [a:UInt32, b:UInt32, c:UInt32]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } /// Test for nested IN subqueries @@ -393,8 +385,7 @@ mod tests { \n Projection: sq_nested.c [c:UInt32]\ \n TableScan: sq_nested [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } /// Test for filter input modification in case filter not supported @@ -425,7 +416,6 @@ mod tests { \n Projection: sq_inner.c [c:UInt32]\ \n TableScan: sq_inner [a:UInt32, b:UInt32, c:UInt32]"; - assert_optimized_plan_eq(&plan, expected); - Ok(()) + assert_optimized_plan_equal(&plan, expected) } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index a51c2ec29fe78..7532a9d1a3785 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -114,7 +114,7 @@ pub fn assert_optimized_plan_eq( plan, &OptimizerContext::new(), )? - .unwrap(); + .unwrap_or_else(|| plan.clone()); let formatted_plan = format!("{:?}", optimized_plan); assert_eq!(formatted_plan, expected); diff --git a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs index 9f6d1a5ac2ba9..9c4887699ea7a 100644 --- a/datafusion/optimizer/src/unwrap_cast_in_comparison.rs +++ b/datafusion/optimizer/src/unwrap_cast_in_comparison.rs @@ -18,6 +18,7 @@ //! Unwrap-cast binary comparison rule can be used to the binary/inlist comparison expr now, and other type //! of expr can be added if needed. //! This rule can reduce adding the `Expr::Cast` the expr instead of adding the `Expr::Cast` to literal expr. 
+use crate::optimizer::ApplyOrder; use crate::utils::rewrite_preserving_name; use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::{ @@ -84,16 +85,9 @@ impl OptimizerRule for UnwrapCastInComparison { plan: &LogicalPlan, _config: &dyn OptimizerConfig, ) -> Result<Option<LogicalPlan>> { - let new_inputs = plan - .inputs() - .into_iter() - .map(|input| { - self.try_optimize(input, _config) - .map(|o| o.unwrap_or_else(|| input.clone())) - }) - .collect::<Result<Vec<_>>>()?; + let inputs: Vec<LogicalPlan> = plan.inputs().into_iter().cloned().collect(); - let mut schema = new_inputs.iter().map(|input| input.schema()).fold( + let mut schema = inputs.iter().map(|input| input.schema()).fold( DFSchema::empty(), |mut lhs, rhs| { lhs.merge(rhs); @@ -116,13 +110,17 @@ impl OptimizerRule for UnwrapCastInComparison { Ok(Some(from_plan( plan, new_exprs.as_slice(), - new_inputs.as_slice(), + inputs.as_slice(), )?)) } fn name(&self) -> &str { "unwrap_cast_in_comparison" } + + fn apply_order(&self) -> Option<ApplyOrder> { + Some(ApplyOrder::BottomUp) + } } struct UnwrapCastExprRewriter {