diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index a0a7137653960..12fb0ff969c07 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -862,6 +862,7 @@ dependencies = [ "arrow-schema", "datafusion-common", "datafusion-expr", + "datafusion-optimizer", "log", "sqlparser", ] diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs index 82c5a35b7f7e1..829f02b976f17 100644 --- a/datafusion/core/tests/dataframe.rs +++ b/datafusion/core/tests/dataframe.rs @@ -24,6 +24,8 @@ use arrow::{ }, record_batch::RecordBatch, }; +use arrow_array::TimestampNanosecondArray; +use arrow_schema::TimeUnit; use datafusion::from_slice::FromSlice; use std::sync::Arc; @@ -153,7 +155,27 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> { #[tokio::test] async fn test_count_wildcard_on_window() -> Result<()> { - let ctx = create_join_context()?; + let ctx = SessionContext::new(); + let t1 = Arc::new(Schema::new(vec![ + Field::new("a", DataType::UInt32, false), + Field::new("ts", DataType::Timestamp(TimeUnit::Nanosecond, None), false), + ])); + + // define data. + let batch1 = RecordBatch::try_new( + t1, + vec![ + Arc::new(UInt32Array::from_slice([1, 10, 11, 100])), + Arc::new(TimestampNanosecondArray::from_slice([ + 1664264591000000000, + 1664264592000000000, + 1664264592000000000, + 1664264593000000000, + ])), + ], + )?; + + ctx.register_batch("t1", batch1)?; let sql_results = ctx .sql("select COUNT(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1") @@ -185,6 +207,43 @@ async fn test_count_wildcard_on_window() -> Result<()> { pretty_format_batches(&sql_results)?.to_string() ); + #[allow(unused_doc_comments)] + ///special timestamp scenarios, DataFrame cannot achieve the same logical plan as SQL, it needs to be tested separately. + ///```sql + /// COUNT(*) OVER (ORDER BY ts DESC RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND INTERVAL '2 DAY' FOLLOWING) + /// ``` + /// logic_plan + /// ```text + /// COUNT(UInt8(1)) ORDER BY [t1.ts DESC NULLS FIRST] RANGE BETWEEN 18446744073709551616 PRECEDING AND 36893488147419103232 FOLLOWING + /// AS + /// COUNT(UInt8(1)) ORDER BY [t1.ts DESC NULLS FIRST] RANGE BETWEEN 1 DAY PRECEDING AND 2 DAY FOLLOWING + /// ``` + let sql_results = ctx + .sql("select \ + COUNT(*) OVER (ORDER BY ts DESC RANGE BETWEEN INTERVAL '1' DAY PRECEDING AND INTERVAL '2 DAY' FOLLOWING) as cnt2 \ + from t1") + .await? + .explain(false, false)? + .collect() + .await?; + + #[rustfmt::skip] + let expected = vec![ + "+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", + "| plan_type | plan |", + "+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", + "| logical_plan | Projection: COUNT(UInt8(1)) ORDER BY [t1.ts DESC NULLS FIRST] RANGE BETWEEN 1 DAY PRECEDING AND 2 DAY FOLLOWING AS cnt2 |", + "| | WindowAggr: windowExpr=[[COUNT(UInt8(1)) ORDER BY [t1.ts DESC NULLS FIRST] RANGE BETWEEN 18446744073709551616 PRECEDING AND 36893488147419103232 FOLLOWING AS COUNT(UInt8(1)) ORDER BY [t1.ts DESC NULLS FIRST] RANGE BETWEEN 1 DAY PRECEDING AND 2 DAY FOLLOWING]] |", + "| | TableScan: t1 projection=[a, ts] |", + "| physical_plan | ProjectionExec: expr=[COUNT(UInt8(1)) ORDER BY [t1.ts DESC NULLS FIRST] RANGE BETWEEN 1 DAY PRECEDING AND 2 DAY FOLLOWING@2 as cnt2] |", + "| | BoundedWindowAggExec: wdw=[COUNT(UInt8(1)) ORDER BY [t1.ts DESC NULLS FIRST] RANGE BETWEEN 1 DAY PRECEDING AND 2 DAY FOLLOWING: Ok(Field { name: \"COUNT(UInt8(1)) ORDER BY [t1.ts DESC NULLS FIRST] RANGE BETWEEN 1 DAY PRECEDING AND 2 DAY FOLLOWING\", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(IntervalMonthDayNano(\"18446744073709551616\")), end_bound: Following(IntervalMonthDayNano(\"36893488147419103232\")) }] |", + "| | SortExec: expr=[ts@1 DESC] |", + "| | MemoryExec: partitions=1, partition_sizes=[1] |", + "| | |", + "+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", + ]; + assert_batches_eq!(expected, &sql_results); + Ok(()) } diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 192e2f65f8acf..0d183d92bc1b3 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -1131,15 +1131,6 @@ async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result Vec { let df = ctx.sql(sql).await.unwrap(); - - // We are not really interested in the direct output of optimized_logical_plan - // since the physical plan construction already optimizes the given logical plan - // and we want to avoid double-optimization as a consequence. So we just construct - // it here to make sure that it doesn't fail at this step and get the optimized - // schema (to assert later that the logical and optimized schemas are the same). - let optimized = df.clone().into_optimized_plan().unwrap(); - assert_eq!(df.logical_plan().schema(), optimized.schema()); - df.collect().await.unwrap() } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index ed48da7fde98c..d439abbfbd678 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -20,7 +20,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; use datafusion_common::{Column, DFField, DFSchema, DFSchemaRef, Result}; use datafusion_expr::expr::AggregateFunction; use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_expr::Expr::{Exists, InSubquery, ScalarSubquery}; +use datafusion_expr::Expr::{Alias, Exists, InSubquery, ScalarSubquery}; use datafusion_expr::{ aggregate_function, count, expr, lit, window_function, Aggregate, Expr, Filter, LogicalPlan, Projection, Sort, Subquery, Window, @@ -230,6 +230,7 @@ impl TreeNodeRewriter for CountWildcardRewriter { negated, } } + Alias(expr, name) => Alias(expr, replace_count_star(name)), _ => old_expr, }; Ok(new_expr) @@ -240,29 +241,35 @@ fn rewrite_schema(schema: &DFSchema) -> DFSchemaRef { .fields() .iter() .map(|field| { - let mut name = field.field().name().clone(); - if name.contains(COUNT_STAR) { - name = name.replace( - COUNT_STAR, - count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(), - ); - } DFField::new( field.qualifier().cloned(), - &name, + &replace_count_star(field.field().name().clone()), field.data_type().clone(), field.is_nullable(), ) }) .collect::>(); + DFSchemaRef::new( DFSchema::new_with_metadata(new_fields, schema.metadata().clone()).unwrap(), ) } +fn replace_count_star(name: String) -> String { + if name.contains(COUNT_STAR) { + name.replace( + COUNT_STAR, + count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(), + ) + } else { + name + } +} + #[cfg(test)] mod tests { use super::*; + use crate::analyzer::Analyzer; use crate::test::*; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::expr::Sort; @@ -376,6 +383,7 @@ mod tests { \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]"; assert_plan_eq(&plan, expected) } + #[test] fn test_count_wildcard_on_window() -> Result<()> { let table_scan = test_table_scan()?; diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index bb88226eaaa04..3e49c3316e9f9 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -41,6 +41,7 @@ arrow = { workspace = true } arrow-schema = { workspace = true } datafusion-common = { path = "../common", version = "22.0.0" } datafusion-expr = { path = "../expr", version = "22.0.0" } +datafusion-optimizer = { path = "../optimizer", version = "22.0.0" } log = "^0.4" sqlparser = "0.33" diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index bf076a2f3068a..00795ea899305 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -17,7 +17,6 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; use datafusion_common::{DFSchema, DataFusionError, Result}; -use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::window_frame::regularize; use datafusion_expr::{ expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame, @@ -215,12 +214,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { // Special case rewrite COUNT(*) to COUNT(constant) AggregateFunction::Count => args .into_iter() - .map(|a| match a { - FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { - Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone())) - } - _ => self.sql_fn_arg_to_logical_expr(a, schema, planner_context), - }) + .map(|a| self.sql_fn_arg_to_logical_expr(a, schema, planner_context)) .collect::>>()?, _ => self.function_args_to_expr(args, schema, planner_context)?, }; diff --git a/datafusion/sql/tests/integration_test.rs b/datafusion/sql/tests/integration_test.rs index 64ca85b72d989..017e68aa29164 100644 --- a/datafusion/sql/tests/integration_test.rs +++ b/datafusion/sql/tests/integration_test.rs @@ -35,6 +35,8 @@ use datafusion_sql::{ planner::{ContextProvider, ParserOptions, SqlToRel}, }; +use datafusion_optimizer::analyzer::Analyzer; +use datafusion_optimizer::{OptimizerConfig, OptimizerContext}; use rstest::rstest; #[cfg(test)] @@ -865,7 +867,7 @@ fn select_aggregate_with_having_referencing_column_not_in_select() { assert_eq!( "Plan(\"HAVING clause references non-aggregate values: \ Expression person.first_name could not be resolved from available columns: \ - COUNT(UInt8(1))\")", + COUNT(*)\")", format!("{err:?}") ); } @@ -2530,7 +2532,14 @@ fn logical_plan_with_dialect_and_options( let planner = SqlToRel::new_with_options(&context, options); let result = DFParser::parse_sql_with_dialect(sql, dialect); let mut ast = result?; - planner.statement_to_plan(ast.pop_front().unwrap()) + match planner.statement_to_plan(ast.pop_front().unwrap()) { + Ok(plan) => Analyzer::new().execute_and_check( + &plan, + OptimizerContext::new().options(), + |_, _| {}, + ), + Err(err) => Err(err), + } } /// Create logical plan, write with formatter, compare to expected output