From 86a014fe827bb8d53e73d40ecf4296d8efa999a4 Mon Sep 17 00:00:00 2001 From: jiangzhx Date: Fri, 14 Apr 2023 17:53:51 +0800 Subject: [PATCH] update count_wildcard_rule for more scenario --- datafusion/core/tests/dataframe.rs | 235 +++++++++-- .../src/analyzer/count_wildcard_rule.rs | 372 +++++++++++++++++- datafusion/optimizer/src/test/mod.rs | 12 + 3 files changed, 576 insertions(+), 43 deletions(-) diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs index 4b2daa100b791..82c5a35b7f7e1 100644 --- a/datafusion/core/tests/dataframe.rs +++ b/datafusion/core/tests/dataframe.rs @@ -32,24 +32,169 @@ use datafusion::error::Result; use datafusion::execution::context::SessionContext; use datafusion::prelude::JoinType; use datafusion::prelude::{CsvReadOptions, ParquetReadOptions}; +use datafusion::test_util::parquet_test_data; use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; +use datafusion_common::ScalarValue; use datafusion_expr::expr::{GroupingSet, Sort}; -use datafusion_expr::{avg, col, count, lit, max, sum, Expr, ExprSchemable}; +use datafusion_expr::utils::COUNT_STAR_EXPANSION; +use datafusion_expr::Expr::Wildcard; +use datafusion_expr::{ + avg, col, count, exists, expr, in_subquery, lit, max, scalar_subquery, sum, + AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunction, +}; #[tokio::test] -async fn count_wildcard() -> Result<()> { - let ctx = SessionContext::new(); - let testdata = datafusion::test_util::parquet_test_data(); +async fn test_count_wildcard_on_sort() -> Result<()> { + let ctx = create_join_context()?; - ctx.register_parquet( - "alltypes_tiny_pages", - &format!("{testdata}/alltypes_tiny_pages.parquet"), - ParquetReadOptions::default(), - ) - .await?; + let sql_results = ctx + .sql("select b,count(*) from t1 group by b order by count(*)") + .await? + .explain(false, false)? + .collect() + .await?; + + let df_results = ctx + .table("t1") + .await? + .aggregate(vec![col("b")], vec![count(Wildcard)])? + .sort(vec![count(Wildcard).sort(true, false)])? + .explain(false, false)? + .collect() + .await?; + //make sure sql plan same with df plan + assert_eq!( + pretty_format_batches(&sql_results)?.to_string(), + pretty_format_batches(&df_results)?.to_string() + ); + Ok(()) +} + +#[tokio::test] +async fn test_count_wildcard_on_where_in() -> Result<()> { + let ctx = create_join_context()?; + let sql_results = ctx + .sql("SELECT a,b FROM t1 WHERE a in (SELECT count(*) FROM t2)") + .await? + .explain(false, false)? + .collect() + .await?; + // In the same SessionContext, AliasGenerator will increase subquery_alias id by 1 + // https://github.com/apache/arrow-datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43 + // for compare difference betwwen sql and df logical plan, we need to create a new SessionContext here + let ctx = create_join_context()?; + let df_results = ctx + .table("t1") + .await? + .filter(in_subquery( + col("a"), + Arc::new( + ctx.table("t2") + .await? + .aggregate(vec![], vec![count(Expr::Wildcard)])? + .select(vec![count(Expr::Wildcard)])? + .into_unoptimized_plan(), + // Usually, into_optimized_plan() should be used here, but due to + // https://github.com/apache/arrow-datafusion/issues/5771, + // subqueries in SQL cannot be optimized, resulting in differences in logical_plan. Therefore, into_unoptimized_plan() is temporarily used here. + ), + ))? + .select(vec![col("a"), col("b")])? + .explain(false, false)? + .collect() + .await?; + + // make sure sql plan same with df plan + assert_eq!( + pretty_format_batches(&sql_results)?.to_string(), + pretty_format_batches(&df_results)?.to_string() + ); + + Ok(()) +} + +#[tokio::test] +async fn test_count_wildcard_on_where_exist() -> Result<()> { + let ctx = create_join_context()?; let sql_results = ctx - .sql("select count(*) from alltypes_tiny_pages") + .sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(*) FROM t2)") + .await? + .explain(false, false)? + .collect() + .await?; + let df_results = ctx + .table("t1") + .await? + .filter(exists(Arc::new( + ctx.table("t2") + .await? + .aggregate(vec![], vec![count(Expr::Wildcard)])? + .select(vec![count(Expr::Wildcard)])? + .into_unoptimized_plan(), + // Usually, into_optimized_plan() should be used here, but due to + // https://github.com/apache/arrow-datafusion/issues/5771, + // subqueries in SQL cannot be optimized, resulting in differences in logical_plan. Therefore, into_unoptimized_plan() is temporarily used here. + )))? + .select(vec![col("a"), col("b")])? + .explain(false, false)? + .collect() + .await?; + + //make sure sql plan same with df plan + assert_eq!( + pretty_format_batches(&sql_results)?.to_string(), + pretty_format_batches(&df_results)?.to_string() + ); + + Ok(()) +} + +#[tokio::test] +async fn test_count_wildcard_on_window() -> Result<()> { + let ctx = create_join_context()?; + + let sql_results = ctx + .sql("select COUNT(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1") + .await? + .explain(false, false)? + .collect() + .await?; + let df_results = ctx + .table("t1") + .await? + .select(vec![Expr::WindowFunction(expr::WindowFunction::new( + WindowFunction::AggregateFunction(AggregateFunction::Count), + vec![Expr::Wildcard], + vec![], + vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], + WindowFrame { + units: WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), + end_bound: WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), + }, + ))])? + .explain(false, false)? + .collect() + .await?; + + //make sure sql plan same with df plan + assert_eq!( + pretty_format_batches(&df_results)?.to_string(), + pretty_format_batches(&sql_results)?.to_string() + ); + + Ok(()) +} + +#[tokio::test] +async fn test_count_wildcard_on_aggregate() -> Result<()> { + let ctx = create_join_context()?; + register_alltypes_tiny_pages_parquet(&ctx).await?; + + let sql_results = ctx + .sql("select count(*) from t1") .await? .select(vec![count(Expr::Wildcard)])? .explain(false, false)? @@ -58,7 +203,7 @@ async fn count_wildcard() -> Result<()> { // add `.select(vec![count(Expr::Wildcard)])?` to make sure we can analyze all node instead of just top node. let df_results = ctx - .table("alltypes_tiny_pages") + .table("t1") .await? .aggregate(vec![], vec![count(Expr::Wildcard)])? .select(vec![count(Expr::Wildcard)])? @@ -72,24 +217,51 @@ async fn count_wildcard() -> Result<()> { pretty_format_batches(&df_results)?.to_string() ); - let results = ctx - .table("alltypes_tiny_pages") + Ok(()) +} +#[tokio::test] +async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { + let ctx = create_join_context()?; + + let sql_results = ctx + .sql("select a,b from t1 where (select count(*) from t2 where t1.a = t2.a)>0;") .await? - .aggregate(vec![], vec![count(Expr::Wildcard)])? + .explain(false, false)? .collect() .await?; - let expected = vec![ - "+-----------------+", - "| COUNT(UInt8(1)) |", - "+-----------------+", - "| 7300 |", - "+-----------------+", - ]; - assert_batches_sorted_eq!(expected, &results); + // In the same SessionContext, AliasGenerator will increase subquery_alias id by 1 + // https://github.com/apache/arrow-datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43 + // for compare difference betwwen sql and df logical plan, we need to create a new SessionContext here + let ctx = create_join_context()?; + let df_results = ctx + .table("t1") + .await? + .filter( + scalar_subquery(Arc::new( + ctx.table("t2") + .await? + .filter(col("t1.a").eq(col("t2.a")))? + .aggregate(vec![], vec![count(lit(COUNT_STAR_EXPANSION))])? + .select(vec![count(lit(COUNT_STAR_EXPANSION))])? + .into_unoptimized_plan(), + )) + .gt(lit(ScalarValue::UInt8(Some(0)))), + )? + .select(vec![col("t1.a"), col("t1.b")])? + .explain(false, false)? + .collect() + .await?; + + //make sure sql plan same with df plan + assert_eq!( + pretty_format_batches(&sql_results)?.to_string(), + pretty_format_batches(&df_results)?.to_string() + ); Ok(()) } + #[tokio::test] async fn describe() -> Result<()> { let ctx = SessionContext::new(); @@ -229,7 +401,7 @@ async fn sort_on_unprojected_columns() -> Result<()> { let results = df.collect().await.unwrap(); #[rustfmt::skip] - let expected = vec![ + let expected = vec![ "+-----+", "| a |", "+-----+", @@ -275,7 +447,7 @@ async fn sort_on_distinct_columns() -> Result<()> { let results = df.collect().await.unwrap(); #[rustfmt::skip] - let expected = vec![ + let expected = vec![ "+-----+", "| a |", "+-----+", @@ -417,7 +589,7 @@ async fn filter_with_alias_overwrite() -> Result<()> { let results = df.collect().await.unwrap(); #[rustfmt::skip] - let expected = vec![ + let expected = vec![ "+------+", "| a |", "+------+", @@ -1047,3 +1219,14 @@ async fn table_with_nested_types(n: usize) -> Result { ctx.register_batch("shapes", batch)?; ctx.table("shapes").await } + +pub async fn register_alltypes_tiny_pages_parquet(ctx: &SessionContext) -> Result<()> { + let testdata = parquet_test_data(); + ctx.register_parquet( + "alltypes_tiny_pages", + &format!("{testdata}/alltypes_tiny_pages.parquet"), + ParquetReadOptions::default(), + ) + .await?; + Ok(()) +} diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index ecd00d7ac15c8..ed48da7fde98c 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -16,14 +16,22 @@ // under the License. use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::Result; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::{Column, DFField, DFSchema, DFSchemaRef, Result}; use datafusion_expr::expr::AggregateFunction; use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::{aggregate_function, lit, Aggregate, Expr, LogicalPlan, Window}; +use datafusion_expr::Expr::{Exists, InSubquery, ScalarSubquery}; +use datafusion_expr::{ + aggregate_function, count, expr, lit, window_function, Aggregate, Expr, Filter, + LogicalPlan, Projection, Sort, Subquery, Window, +}; +use std::string::ToString; +use std::sync::Arc; use crate::analyzer::AnalyzerRule; +pub const COUNT_STAR: &str = "COUNT(*)"; + /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// Resolve issue: https://github.com/apache/arrow-datafusion/issues/5473. #[derive(Default)] @@ -46,35 +54,116 @@ impl AnalyzerRule for CountWildcardRule { } fn analyze_internal(plan: LogicalPlan) -> Result> { + let mut rewriter = CountWildcardRewriter {}; match plan { LogicalPlan::Window(window) => { - let window_expr = handle_wildcard(&window.window_expr); + let window_expr = window + .window_expr + .iter() + .map(|expr| expr.clone().rewrite(&mut rewriter).unwrap()) + .collect::>(); + Ok(Transformed::Yes(LogicalPlan::Window(Window { input: window.input.clone(), window_expr, - schema: window.schema, + schema: rewrite_schema(&window.schema), }))) } LogicalPlan::Aggregate(agg) => { - let aggr_expr = handle_wildcard(&agg.aggr_expr); + let aggr_expr = agg + .aggr_expr + .iter() + .map(|expr| expr.clone().rewrite(&mut rewriter).unwrap()) + .collect(); + Ok(Transformed::Yes(LogicalPlan::Aggregate( Aggregate::try_new_with_schema( agg.input.clone(), agg.group_expr.clone(), aggr_expr, - agg.schema, + rewrite_schema(&agg.schema), )?, ))) } + LogicalPlan::Sort(Sort { expr, input, fetch }) => { + let sort_expr = expr + .iter() + .map(|expr| expr.clone().rewrite(&mut rewriter).unwrap()) + .collect(); + Ok(Transformed::Yes(LogicalPlan::Sort(Sort { + expr: sort_expr, + input, + fetch, + }))) + } + LogicalPlan::Projection(projection) => { + let projection_expr = projection + .expr + .iter() + .map(|expr| expr.clone().rewrite(&mut rewriter).unwrap()) + .collect(); + Ok(Transformed::Yes(LogicalPlan::Projection( + Projection::try_new_with_schema( + projection_expr, + projection.input, + // rewrite_schema(projection.schema.clone()), + rewrite_schema(&projection.schema), + )?, + ))) + } + LogicalPlan::Filter(Filter { + predicate, input, .. + }) => { + let predicate = predicate.rewrite(&mut rewriter).unwrap(); + Ok(Transformed::Yes(LogicalPlan::Filter( + Filter::try_new(predicate, input).unwrap(), + ))) + } + _ => Ok(Transformed::No(plan)), } } -// handle Count(Expr:Wildcard) with DataFrame API -pub fn handle_wildcard(exprs: &[Expr]) -> Vec { - exprs - .iter() - .map(|expr| match expr { +struct CountWildcardRewriter {} + +impl TreeNodeRewriter for CountWildcardRewriter { + type N = Expr; + + fn mutate(&mut self, old_expr: Expr) -> Result { + let new_expr = match old_expr.clone() { + Expr::Column(Column { name, relation }) if name.contains(COUNT_STAR) => { + Expr::Column(Column { + name: name.replace( + COUNT_STAR, + count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(), + ), + relation: relation.clone(), + }) + } + Expr::WindowFunction(expr::WindowFunction { + fun: + window_function::WindowFunction::AggregateFunction( + aggregate_function::AggregateFunction::Count, + ), + args, + partition_by, + order_by, + window_frame, + }) if args.len() == 1 => match args[0] { + Expr::Wildcard => { + Expr::WindowFunction(datafusion_expr::expr::WindowFunction { + fun: window_function::WindowFunction::AggregateFunction( + aggregate_function::AggregateFunction::Count, + ), + args: vec![lit(COUNT_STAR_EXPANSION)], + partition_by, + order_by, + window_frame, + }) + } + + _ => old_expr, + }, Expr::AggregateFunction(AggregateFunction { fun: aggregate_function::AggregateFunction::Count, args, @@ -84,12 +173,261 @@ pub fn handle_wildcard(exprs: &[Expr]) -> Vec { Expr::Wildcard => Expr::AggregateFunction(AggregateFunction { fun: aggregate_function::AggregateFunction::Count, args: vec![lit(COUNT_STAR_EXPANSION)], - distinct: *distinct, - filter: filter.clone(), + distinct, + filter, }), - _ => expr.clone(), + _ => old_expr, }, - _ => expr.clone(), + + ScalarSubquery(Subquery { + subquery, + outer_ref_columns, + }) => { + let new_plan = subquery + .as_ref() + .clone() + .transform_down(&analyze_internal) + .unwrap(); + ScalarSubquery(Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns, + }) + } + InSubquery { + expr, + subquery, + negated, + } => { + let new_plan = subquery + .subquery + .as_ref() + .clone() + .transform_down(&analyze_internal) + .unwrap(); + + InSubquery { + expr, + subquery: Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + }, + negated, + } + } + Exists { subquery, negated } => { + let new_plan = subquery + .subquery + .as_ref() + .clone() + .transform_down(&analyze_internal) + .unwrap(); + + Exists { + subquery: Subquery { + subquery: Arc::new(new_plan), + outer_ref_columns: subquery.outer_ref_columns, + }, + negated, + } + } + _ => old_expr, + }; + Ok(new_expr) + } +} +fn rewrite_schema(schema: &DFSchema) -> DFSchemaRef { + let new_fields = schema + .fields() + .iter() + .map(|field| { + let mut name = field.field().name().clone(); + if name.contains(COUNT_STAR) { + name = name.replace( + COUNT_STAR, + count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(), + ); + } + DFField::new( + field.qualifier().cloned(), + &name, + field.data_type().clone(), + field.is_nullable(), + ) }) - .collect() + .collect::>(); + DFSchemaRef::new( + DFSchema::new_with_metadata(new_fields, schema.metadata().clone()).unwrap(), + ) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test::*; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::expr::Sort; + use datafusion_expr::{ + col, count, exists, expr, in_subquery, lit, logical_plan::LogicalPlanBuilder, + max, scalar_subquery, AggregateFunction, Expr, WindowFrame, WindowFrameBound, + WindowFrameUnits, WindowFunction, + }; + + fn assert_plan_eq(plan: &LogicalPlan, expected: &str) -> Result<()> { + assert_analyzed_plan_eq_display_indent( + Arc::new(CountWildcardRule::new()), + plan, + expected, + ) + } + + #[test] + fn test_count_wildcard_on_sort() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(vec![col("b")], vec![count(Expr::Wildcard)])? + .project(vec![count(Expr::Wildcard)])? + .sort(vec![count(Expr::Wildcard).sort(true, false)])? + .build()?; + let expected = "Sort: COUNT(UInt8(1)) ASC NULLS LAST [COUNT(UInt8(1)):Int64;N]\ + \n Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\ + \n Aggregate: groupBy=[[test.b]], aggr=[[COUNT(UInt8(1))]] [b:UInt32, COUNT(UInt8(1)):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + assert_plan_eq(&plan, expected) + } + + #[test] + fn test_count_wildcard_on_where_in() -> Result<()> { + let table_scan_t1 = test_table_scan_with_name("t1")?; + let table_scan_t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(table_scan_t1) + .filter(in_subquery( + col("a"), + Arc::new( + LogicalPlanBuilder::from(table_scan_t2) + .aggregate(Vec::::new(), vec![count(Expr::Wildcard)])? + .project(vec![count(Expr::Wildcard)])? + .build()?, + ), + ))? + .build()?; + + let expected = "Filter: t1.a IN () [a:UInt32, b:UInt32, c:UInt32]\ + \n Subquery: [COUNT(UInt8(1)):Int64;N]\ + \n Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]\ + \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; + assert_plan_eq(&plan, expected) + } + + #[test] + fn test_count_wildcard_on_where_exists() -> Result<()> { + let table_scan_t1 = test_table_scan_with_name("t1")?; + let table_scan_t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(table_scan_t1) + .filter(exists(Arc::new( + LogicalPlanBuilder::from(table_scan_t2) + .aggregate(Vec::::new(), vec![count(Expr::Wildcard)])? + .project(vec![count(Expr::Wildcard)])? + .build()?, + )))? + .build()?; + + let expected = "Filter: EXISTS () [a:UInt32, b:UInt32, c:UInt32]\ + \n Subquery: [COUNT(UInt8(1)):Int64;N]\ + \n Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]\ + \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; + assert_plan_eq(&plan, expected) + } + + #[test] + fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { + let table_scan_t1 = test_table_scan_with_name("t1")?; + let table_scan_t2 = test_table_scan_with_name("t2")?; + + let plan = LogicalPlanBuilder::from(table_scan_t1) + .filter( + scalar_subquery(Arc::new( + LogicalPlanBuilder::from(table_scan_t2) + .filter(col("t1.a").eq(col("t2.a")))? + .aggregate( + Vec::::new(), + vec![count(lit(COUNT_STAR_EXPANSION))], + )? + .project(vec![count(lit(COUNT_STAR_EXPANSION))])? + .build()?, + )) + .gt(lit(ScalarValue::UInt8(Some(0)))), + )? + .project(vec![col("t1.a"), col("t1.b")])? + .build()?; + + let expected = "Projection: t1.a, t1.b [a:UInt32, b:UInt32]\ + \n Filter: () > UInt8(0) [a:UInt32, b:UInt32, c:UInt32]\ + \n Subquery: [COUNT(UInt8(1)):Int64;N]\ + \n Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]\ + \n Filter: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\ + \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; + assert_plan_eq(&plan, expected) + } + #[test] + fn test_count_wildcard_on_window() -> Result<()> { + let table_scan = test_table_scan()?; + + let plan = LogicalPlanBuilder::from(table_scan) + .window(vec![Expr::WindowFunction(expr::WindowFunction::new( + WindowFunction::AggregateFunction(AggregateFunction::Count), + vec![Expr::Wildcard], + vec![], + vec![Expr::Sort(Sort::new(Box::new(col("a")), false, true))], + WindowFrame { + units: WindowFrameUnits::Range, + start_bound: WindowFrameBound::Preceding(ScalarValue::UInt32(Some( + 6, + ))), + end_bound: WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), + }, + ))])? + .project(vec![count(Expr::Wildcard)])? + .build()?; + + let expected = "Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\ + \n WindowAggr: windowExpr=[[COUNT(UInt8(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] [a:UInt32, b:UInt32, c:UInt32, COUNT(UInt8(1)) ORDER BY [test.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING:Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + assert_plan_eq(&plan, expected) + } + + #[test] + fn test_count_wildcard_on_aggregate() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(Vec::::new(), vec![count(Expr::Wildcard)])? + .project(vec![count(Expr::Wildcard)])? + .build()?; + + let expected = "Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]] [COUNT(UInt8(1)):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + assert_plan_eq(&plan, expected) + } + + #[test] + fn test_count_wildcard_on_nesting() -> Result<()> { + let table_scan = test_table_scan()?; + let plan = LogicalPlanBuilder::from(table_scan) + .aggregate(Vec::::new(), vec![max(count(Expr::Wildcard))])? + .project(vec![count(Expr::Wildcard)])? + .build()?; + + let expected = "Projection: COUNT(UInt8(1)) [COUNT(UInt8(1)):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[MAX(COUNT(UInt8(1)))]] [MAX(COUNT(UInt8(1))):Int64;N]\ + \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; + assert_plan_eq(&plan, expected) + } } diff --git a/datafusion/optimizer/src/test/mod.rs b/datafusion/optimizer/src/test/mod.rs index 439f44151ed77..67d342b4cb4b0 100644 --- a/datafusion/optimizer/src/test/mod.rs +++ b/datafusion/optimizer/src/test/mod.rs @@ -121,7 +121,19 @@ pub fn assert_analyzed_plan_eq( Ok(()) } +pub fn assert_analyzed_plan_eq_display_indent( + rule: Arc, + plan: &LogicalPlan, + expected: &str, +) -> Result<()> { + let options = ConfigOptions::default(); + let analyzed_plan = + Analyzer::with_rules(vec![rule]).execute_and_check(plan, &options)?; + let formatted_plan = format!("{}", analyzed_plan.display_indent_schema()); + assert_eq!(formatted_plan, expected); + Ok(()) +} pub fn assert_optimized_plan_eq( rule: Arc, plan: &LogicalPlan,