diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs index 9da5bf2b86ad2..23c6623ab1b5a 100644 --- a/datafusion/core/tests/dataframe.rs +++ b/datafusion/core/tests/dataframe.rs @@ -51,16 +51,18 @@ async fn count_wildcard() -> Result<()> { let sql_results = ctx .sql("select count(*) from alltypes_tiny_pages") .await? + .select(vec![count(Expr::Wildcard)])? .explain(false, false)? .collect() .await?; + // add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all node instead of just top node. let df_results = ctx .table("alltypes_tiny_pages") .await? .aggregate(vec![], vec![count(Expr::Wildcard)])? - .explain(false, false) - .unwrap() + .select(vec![count(Expr::Wildcard)])? + .explain(false, false)? .collect() .await?; diff --git a/datafusion/optimizer/src/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs similarity index 61% rename from datafusion/optimizer/src/count_wildcard_rule.rs rename to datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 416bd0337a4da..4b4c603bcfe46 100644 --- a/datafusion/optimizer/src/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -15,15 +15,17 @@ // specific language governing permissions and limitations // under the License. -use crate::analyzer::AnalyzerRule; use datafusion_common::config::ConfigOptions; use datafusion_common::Result; use datafusion_expr::expr::AggregateFunction; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::{aggregate_function, lit, Aggregate, Expr, LogicalPlan, Window}; -use std::ops::Deref; -use std::sync::Arc; +use crate::analyzer::AnalyzerRule; +use crate::rewrite::TreeNodeRewritable; + +/// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. +/// Resolve issue: https://github.com/apache/arrow-datafusion/issues/5473. pub struct CountWildcardRule {} impl Default for CountWildcardRule { @@ -39,35 +41,7 @@ impl CountWildcardRule { } impl AnalyzerRule for CountWildcardRule { fn analyze(&self, plan: &LogicalPlan, _: &ConfigOptions) -> Result { - let new_plan = match plan { - LogicalPlan::Window(window) => { - let inputs = plan.inputs(); - let window_expr = window.clone().window_expr; - let window_expr = handle_wildcard(window_expr).unwrap(); - LogicalPlan::Window(Window { - input: Arc::new(inputs.get(0).unwrap().deref().clone()), - window_expr, - schema: plan.schema().clone(), - }) - } - - LogicalPlan::Aggregate(aggregate) => { - let inputs = plan.inputs(); - let aggr_expr = aggregate.clone().aggr_expr; - let aggr_expr = handle_wildcard(aggr_expr).unwrap(); - LogicalPlan::Aggregate( - Aggregate::try_new_with_schema( - Arc::new(inputs.get(0).unwrap().deref().clone()), - aggregate.clone().group_expr, - aggr_expr, - plan.schema().clone(), - ) - .unwrap(), - ) - } - _ => plan.clone(), - }; - Ok(new_plan) + plan.clone().transform_down(&analyze_internal) } fn name(&self) -> &str { @@ -75,9 +49,34 @@ impl AnalyzerRule for CountWildcardRule { } } -//handle Count(Expr:Wildcard) with DataFrame API -pub fn handle_wildcard(exprs: Vec) -> Result> { - let exprs: Vec = exprs +fn analyze_internal(plan: LogicalPlan) -> Result> { + match plan { + LogicalPlan::Window(window) => { + let window_expr = handle_wildcard(&window.window_expr); + Ok(Some(LogicalPlan::Window(Window { + input: window.input.clone(), + window_expr, + schema: window.schema, + }))) + } + LogicalPlan::Aggregate(agg) => { + let aggr_expr = handle_wildcard(&agg.aggr_expr); + Ok(Some(LogicalPlan::Aggregate( + Aggregate::try_new_with_schema( + agg.input.clone(), + agg.group_expr.clone(), + aggr_expr, + agg.schema, + )?, + ))) + } + _ => Ok(None), + } +} + +// handle Count(Expr:Wildcard) with DataFrame API +pub fn handle_wildcard(exprs: &[Expr]) -> Vec { + exprs .iter() .map(|expr| match expr { Expr::AggregateFunction(AggregateFunction { @@ -96,6 +95,5 @@ pub fn handle_wildcard(exprs: Vec) -> Result> { }, _ => expr.clone(), }) - .collect(); - Ok(exprs) + .collect() } diff --git a/datafusion/optimizer/src/analyzer.rs b/datafusion/optimizer/src/analyzer/mod.rs similarity index 98% rename from datafusion/optimizer/src/analyzer.rs rename to datafusion/optimizer/src/analyzer/mod.rs index e999eb2419d07..0982198bb841d 100644 --- a/datafusion/optimizer/src/analyzer.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. -use crate::count_wildcard_rule::CountWildcardRule; +mod count_wildcard_rule; + +use crate::analyzer::count_wildcard_rule::CountWildcardRule; use crate::rewrite::TreeNodeRewritable; use datafusion_common::config::ConfigOptions; use datafusion_common::{DataFusionError, Result}; diff --git a/datafusion/optimizer/src/eliminate_filter.rs b/datafusion/optimizer/src/eliminate_filter.rs index c5dc1711c244a..c97906a81adf1 100644 --- a/datafusion/optimizer/src/eliminate_filter.rs +++ b/datafusion/optimizer/src/eliminate_filter.rs @@ -27,7 +27,7 @@ use datafusion_expr::{ use crate::{OptimizerConfig, OptimizerRule}; -/// Optimization rule that elimanate the scalar value (true/false) filter with an [LogicalPlan::EmptyRelation] +/// Optimization rule that eliminate the scalar value (true/false) filter with an [LogicalPlan::EmptyRelation] #[derive(Default)] pub struct EliminateFilter; diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs index 3fa1995271dc8..7f930ae3a8d0b 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -45,7 +45,6 @@ pub mod type_coercion; pub mod unwrap_cast_in_comparison; pub mod utils; -pub mod count_wildcard_rule; #[cfg(test)] pub mod test;