From 32d8b975377d122a1c5b2124d8ea4c0d585a277c Mon Sep 17 00:00:00 2001 From: jackwener Date: Mon, 12 Jun 2023 21:45:11 +0800 Subject: [PATCH] refactor: unify replace count(*) analyzer by removing it in sql crate fix: CountWildcardRule ignore Expr::Alias --- .../sqllogictests/test_files/functions.slt | 1 - .../src/analyzer/count_wildcard_rule.rs | 7 ++++ datafusion/sql/src/expr/function.rs | 37 ++----------------- datafusion/sql/tests/integration_test.rs | 30 +++++++-------- 4 files changed, 26 insertions(+), 49 deletions(-) diff --git a/datafusion/core/tests/sqllogictests/test_files/functions.slt b/datafusion/core/tests/sqllogictests/test_files/functions.slt index 92597118c65a6..301d73befb02f 100644 --- a/datafusion/core/tests/sqllogictests/test_files/functions.slt +++ b/datafusion/core/tests/sqllogictests/test_files/functions.slt @@ -565,4 +565,3 @@ SELECT sqrt(column1),sqrt(column2),sqrt(column3),sqrt(column4),sqrt(column5),sqr statement ok drop table t - diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index 0e2689c14fa14..51354cb666612 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -132,6 +132,13 @@ impl TreeNodeRewriter for CountWildcardRewriter { fn mutate(&mut self, old_expr: Expr) -> Result { let new_expr = match old_expr.clone() { + Expr::Alias(expr, alias) if alias.contains(COUNT_STAR) => Expr::Alias( + expr, + alias.replace( + COUNT_STAR, + count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(), + ), + ), Expr::Column(Column { name, relation }) if name.contains(COUNT_STAR) => { Expr::Column(Column { name: name.replace( diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 0289e804110ce..afd0f61292014 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -19,7 +19,6 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{DFSchema, DataFusionError, Result}; use datafusion_expr::expr::{ScalarFunction, ScalarUDF}; use datafusion_expr::function::suggest_valid_function; -use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::window_frame::regularize; use datafusion_expr::{ expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, @@ -96,8 +95,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { if let Ok(fun) = self.find_window_func(&name) { let expr = match fun { WindowFunction::AggregateFunction(aggregate_fun) => { - let (aggregate_fun, args) = self.aggregate_fn_to_expr( - aggregate_fun, + let args = self.function_args_to_expr( function.args, schema, planner_context, @@ -135,12 +133,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { planner_context, )?; let order_by = (!order_by.is_empty()).then_some(order_by); - let (fun, args) = self.aggregate_fn_to_expr( - fun, - function.args, - schema, - planner_context, - )?; + let args = + self.function_args_to_expr(function.args, schema, planner_context)?; + return Ok(Expr::AggregateFunction(expr::AggregateFunction::new( fun, args, distinct, None, order_by, ))); @@ -228,28 +223,4 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { .map(|a| self.sql_fn_arg_to_logical_expr(a, schema, planner_context)) .collect::>>() } - - pub(super) fn aggregate_fn_to_expr( - &self, - fun: AggregateFunction, - args: Vec, - schema: &DFSchema, - planner_context: &mut PlannerContext, - ) -> Result<(AggregateFunction, Vec)> { - let args = match fun { - // Special case rewrite COUNT(*) to COUNT(constant) - AggregateFunction::Count => args - .into_iter() - .map(|a| match a { - FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { - Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone())) - } - _ => self.sql_fn_arg_to_logical_expr(a, schema, planner_context), - }) - .collect::>>()?, - _ => self.function_args_to_expr(args, schema, planner_context)?, - }; - - Ok((fun, args)) - } } diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs index 7161fa481cbe1..77706ea39967a 100644 --- a/datafusion/sql/tests/integration_test.rs +++ b/datafusion/sql/tests/integration_test.rs @@ -881,7 +881,7 @@ fn select_aggregate_with_having_referencing_column_not_in_select() { assert_eq!( "Plan(\"HAVING clause references non-aggregate values: \ Expression person.first_name could not be resolved from available columns: \ - COUNT(UInt8(1))\")", + COUNT(*)\")", format!("{err:?}") ); } @@ -1084,8 +1084,8 @@ fn select_aggregate_with_group_by_with_having_using_count_star_not_in_select() { GROUP BY first_name HAVING MAX(age) > 100 AND COUNT(*) < 50"; let expected = "Projection: person.first_name, MAX(person.age)\ - \n Filter: MAX(person.age) > Int64(100) AND COUNT(UInt8(1)) < Int64(50)\ - \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age), COUNT(UInt8(1))]]\ + \n Filter: MAX(person.age) > Int64(100) AND COUNT(*) < Int64(50)\ + \n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age), COUNT(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -1665,8 +1665,8 @@ fn select_group_by_columns_not_in_select() { #[test] fn select_group_by_count_star() { let sql = "SELECT state, COUNT(*) FROM person GROUP BY state"; - let expected = "Projection: person.state, COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[person.state]], aggr=[[COUNT(UInt8(1))]]\ + let expected = "Projection: person.state, COUNT(*)\ + \n Aggregate: groupBy=[[person.state]], aggr=[[COUNT(*)]]\ \n TableScan: person"; quick_test(sql, expected); @@ -2884,8 +2884,8 @@ fn scalar_subquery_reference_outer_field() { let expected = "Projection: j1.j1_string, j2.j2_string\ \n Filter: j1.j1_id = j2.j2_id - Int64(1) AND j2.j2_id < ()\ \n Subquery:\ - \n Projection: COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ + \n Projection: COUNT(*)\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(*)]]\ \n Filter: outer_ref(j2.j2_id) = j1.j1_id AND j1.j1_id = j3.j3_id\ \n CrossJoin:\ \n TableScan: j1\ @@ -2983,8 +2983,8 @@ fn cte_unbalanced_number_of_columns() { fn aggregate_with_rollup() { let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, ROLLUP (state, age)"; - let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[COUNT(UInt8(1))]]\ + let expected = "Projection: person.id, person.state, person.age, COUNT(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[COUNT(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -2993,8 +2993,8 @@ fn aggregate_with_rollup() { fn aggregate_with_rollup_with_grouping() { let sql = "SELECT id, state, age, grouping(state), grouping(age), grouping(state) + grouping(age), COUNT(*) \ FROM person GROUP BY id, ROLLUP (state, age)"; - let expected = "Projection: person.id, person.state, person.age, GROUPING(person.state), GROUPING(person.age), GROUPING(person.state) + GROUPING(person.age), COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[GROUPING(person.state), GROUPING(person.age), COUNT(UInt8(1))]]\ + let expected = "Projection: person.id, person.state, person.age, GROUPING(person.state), GROUPING(person.age), GROUPING(person.state) + GROUPING(person.age), COUNT(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[GROUPING(person.state), GROUPING(person.age), COUNT(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -3025,8 +3025,8 @@ fn rank_partition_grouping() { fn aggregate_with_cube() { let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, CUBE (state, age)"; - let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.age), (person.id, person.state, person.age))]], aggr=[[COUNT(UInt8(1))]]\ + let expected = "Projection: person.id, person.state, person.age, COUNT(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.age), (person.id, person.state, person.age))]], aggr=[[COUNT(*)]]\ \n TableScan: person"; quick_test(sql, expected); } @@ -3042,8 +3042,8 @@ fn round_decimal() { #[test] fn aggregate_with_grouping_sets() { let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, GROUPING SETS ((state), (state, age), (id, state))"; - let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\ - \n Aggregate: groupBy=[[GROUPING SETS ((person.id, person.state), (person.id, person.state, person.age), (person.id, person.id, person.state))]], aggr=[[COUNT(UInt8(1))]]\ + let expected = "Projection: person.id, person.state, person.age, COUNT(*)\ + \n Aggregate: groupBy=[[GROUPING SETS ((person.id, person.state), (person.id, person.state, person.age), (person.id, person.id, person.state))]], aggr=[[COUNT(*)]]\ \n TableScan: person"; quick_test(sql, expected); }