From c74374a66c7725a6742409164963f222457b5e1b Mon Sep 17 00:00:00 2001 From: jackwener Date: Tue, 21 Mar 2023 22:10:24 +0800 Subject: [PATCH 1/3] refactor: move analyzer to new dir and polish CountWildcardRule. --- .../src/{ => analyzer}/count_wildcard_rule.rs | 43 ++++++++----------- .../src/{analyzer.rs => analyzer/mod.rs} | 4 +- datafusion/optimizer/src/lib.rs | 1 - 3 files changed, 21 insertions(+), 27 deletions(-) rename datafusion/optimizer/src/{ => analyzer}/count_wildcard_rule.rs (70%) rename datafusion/optimizer/src/{analyzer.rs => analyzer/mod.rs} (98%) diff --git a/datafusion/optimizer/src/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs similarity index 70% rename from datafusion/optimizer/src/count_wildcard_rule.rs rename to datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 416bd0337a4da..3654d431aef7d 100644 --- a/datafusion/optimizer/src/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -15,15 +15,16 @@ // specific language governing permissions and limitations // under the License. -use crate::analyzer::AnalyzerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::Result; use datafusion_expr::expr::AggregateFunction; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::{aggregate_function, lit, Aggregate, Expr, LogicalPlan, Window}; -use std::ops::Deref; -use std::sync::Arc; +use crate::analyzer::AnalyzerRule; + +/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. +/// Resolve issue: https://github.com/apache/arrow-datafusion/issues/5473. 
pub struct CountWildcardRule {} impl Default for CountWildcardRule { @@ -41,29 +42,22 @@ impl AnalyzerRule for CountWildcardRule { fn analyze(&self, plan: &LogicalPlan, _: &ConfigOptions) -> Result { let new_plan = match plan { LogicalPlan::Window(window) => { - let inputs = plan.inputs(); - let window_expr = window.clone().window_expr; - let window_expr = handle_wildcard(window_expr).unwrap(); + let window_expr = handle_wildcard(&window.window_expr); LogicalPlan::Window(Window { - input: Arc::new(inputs.get(0).unwrap().deref().clone()), + input: window.input.clone(), window_expr, schema: plan.schema().clone(), }) } - LogicalPlan::Aggregate(aggregate) => { - let inputs = plan.inputs(); - let aggr_expr = aggregate.clone().aggr_expr; - let aggr_expr = handle_wildcard(aggr_expr).unwrap(); - LogicalPlan::Aggregate( - Aggregate::try_new_with_schema( - Arc::new(inputs.get(0).unwrap().deref().clone()), - aggregate.clone().group_expr, - aggr_expr, - plan.schema().clone(), - ) - .unwrap(), - ) + LogicalPlan::Aggregate(agg) => { + let aggr_expr = handle_wildcard(&agg.aggr_expr); + LogicalPlan::Aggregate(Aggregate::try_new_with_schema( + agg.input.clone(), + agg.group_expr.clone(), + aggr_expr, + plan.schema().clone(), + )?) 
} _ => plan.clone(), }; @@ -75,9 +69,9 @@ impl AnalyzerRule for CountWildcardRule { } } -//handle Count(Expr:Wildcard) with DataFrame API -pub fn handle_wildcard(exprs: Vec) -> Result> { - let exprs: Vec = exprs +// handle Count(Expr:Wildcard) with DataFrame API +pub fn handle_wildcard(exprs: &[Expr]) -> Vec { + exprs .iter() .map(|expr| match expr { Expr::AggregateFunction(AggregateFunction { @@ -96,6 +90,5 @@ pub fn handle_wildcard(exprs: Vec) -> Result> { }, _ => expr.clone(), }) - .collect(); - Ok(exprs) + .collect() } diff --git a/datafusion/optimizer/src/analyzer.rs b/datafusion/optimizer/src/analyzer/mod.rs similarity index 98% rename from datafusion/optimizer/src/analyzer.rs rename to datafusion/optimizer/src/analyzer/mod.rs index e999eb2419d07..0982198bb841d 100644 --- a/datafusion/optimizer/src/analyzer.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::count_wildcard_rule::CountWildcardRule; +mod count_wildcard_rule; + +use crate::analyzer::count_wildcard_rule::CountWildcardRule; use crate::rewrite::TreeNodeRewritable; use datafusion_common::config::ConfigOptions; use datafusion_common::{DataFusionError, Result}; diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 3fa1995271dc8..7f930ae3a8d0b 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -45,7 +45,6 @@ pub mod type_coercion; pub mod unwrap_cast_in_comparison; pub mod utils; -pub mod count_wildcard_rule; #[cfg(test)] pub mod test; From 0dcf69f75048133d03c4dd7756587d37d0061338 Mon Sep 17 00:00:00 2001 From: jackwener Date: Tue, 21 Mar 2023 22:56:08 +0800 Subject: [PATCH 2/3] fix typo --- datafusion/optimizer/src/eliminate_filter.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index c5dc1711c244a..c97906a81adf1 
100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -27,7 +27,7 @@ use datafusion_expr::{ use crate::{OptimizerConfig, OptimizerRule}; -/// Optimization rule that elimanate the scalar value (true/false) filter with an [LogicalPlan::EmptyRelation] +/// Optimization rule that eliminates the scalar value (true/false) filter with a [LogicalPlan::EmptyRelation] #[derive(Default)] pub struct EliminateFilter; From 554b994b54d642d6a8164e77428e27670e385415 Mon Sep 17 00:00:00 2001 From: jackwener Date: Tue, 21 Mar 2023 23:53:46 +0800 Subject: [PATCH 3/3] correct rule. --- datafusion/core/tests/dataframe.rs | 6 ++- .../src/analyzer/count_wildcard_rule.rs | 49 ++++++++++--------- 2 files changed, 31 insertions(+), 24 deletions(-) diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs index 9da5bf2b86ad2..23c6623ab1b5a 100644 --- a/datafusion/core/tests/dataframe.rs +++ b/datafusion/core/tests/dataframe.rs @@ -51,16 +51,18 @@ async fn count_wildcard() -> Result<()> { let sql_results = ctx .sql("select count(*) from alltypes_tiny_pages") .await? + .select(vec![count(Expr::Wildcard)])? .explain(false, false)? .collect() .await?; + // add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all nodes instead of just the top node. let df_results = ctx .table("alltypes_tiny_pages") .await? .aggregate(vec![], vec![count(Expr::Wildcard)])? - .explain(false, false) - .unwrap() + .select(vec![count(Expr::Wildcard)])? + .explain(false, false)? 
.collect() .await?; diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 3654d431aef7d..4b4c603bcfe46 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -22,6 +22,7 @@ use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::{aggregate_function, lit, Aggregate, Expr, LogicalPlan, Window}; use crate::analyzer::AnalyzerRule; +use crate::rewrite::TreeNodeRewritable; /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// Resolve issue: https://github.com/apache/arrow-datafusion/issues/5473. @@ -40,28 +41,7 @@ impl CountWildcardRule { } impl AnalyzerRule for CountWildcardRule { fn analyze(&self, plan: &LogicalPlan, _: &ConfigOptions) -> Result { - let new_plan = match plan { - LogicalPlan::Window(window) => { - let window_expr = handle_wildcard(&window.window_expr); - LogicalPlan::Window(Window { - input: window.input.clone(), - window_expr, - schema: plan.schema().clone(), - }) - } - - LogicalPlan::Aggregate(agg) => { - let aggr_expr = handle_wildcard(&agg.aggr_expr); - LogicalPlan::Aggregate(Aggregate::try_new_with_schema( - agg.input.clone(), - agg.group_expr.clone(), - aggr_expr, - plan.schema().clone(), - )?) 
- } - _ => plan.clone(), - }; - Ok(new_plan) + plan.clone().transform_down(&analyze_internal) } fn name(&self) -> &str { @@ -69,6 +49,31 @@ impl AnalyzerRule for CountWildcardRule { } } +fn analyze_internal(plan: LogicalPlan) -> Result> { + match plan { + LogicalPlan::Window(window) => { + let window_expr = handle_wildcard(&window.window_expr); + Ok(Some(LogicalPlan::Window(Window { + input: window.input.clone(), + window_expr, + schema: window.schema, + }))) + } + LogicalPlan::Aggregate(agg) => { + let aggr_expr = handle_wildcard(&agg.aggr_expr); + Ok(Some(LogicalPlan::Aggregate( + Aggregate::try_new_with_schema( + agg.input.clone(), + agg.group_expr.clone(), + aggr_expr, + agg.schema, + )?, + ))) + } + _ => Ok(None), + } +} + // handle Count(Expr:Wildcard) with DataFrame API pub fn handle_wildcard(exprs: &[Expr]) -> Vec { exprs