diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 4cf96d2350ebb..bcf4fec071d47 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use arrow::datatypes::Schema; +use datafusion_expr::utils::COUNT_STAR_EXPANSION; use crate::execution::context::SessionConfig; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode}; @@ -37,6 +38,9 @@ use crate::error::Result; #[derive(Default)] pub struct AggregateStatistics {} +/// The name of the column corresponding to [`COUNT_STAR_EXPANSION`] +const COUNT_STAR_NAME: &str = "COUNT(UInt8(1))"; + impl AggregateStatistics { #[allow(missing_docs)] pub fn new() -> Self { @@ -148,10 +152,10 @@ fn take_optimizable_table_count( .as_any() .downcast_ref::() { - if lit_expr.value() == &ScalarValue::UInt8(Some(1)) { + if lit_expr.value() == &COUNT_STAR_EXPANSION { return Some(( - ScalarValue::UInt64(Some(num_rows as u64)), - "COUNT(UInt8(1))", + ScalarValue::Int64(Some(num_rows as i64)), + COUNT_STAR_NAME, )); } } @@ -183,7 +187,7 @@ fn take_optimizable_column_count( { let expr = format!("COUNT({})", col_expr.name()); return Some(( - ScalarValue::UInt64(Some((num_rows - val) as u64)), + ScalarValue::Int64(Some((num_rows - val) as i64)), expr, )); } @@ -254,9 +258,10 @@ mod tests { use super::*; use std::sync::Arc; - use arrow::array::{Int32Array, UInt64Array}; + use arrow::array::{Int32Array, Int64Array}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; + use datafusion_physical_expr::PhysicalExpr; use crate::error::Result; use crate::logical_plan::Operator; @@ -291,40 +296,106 @@ mod tests { } /// Checks that the count optimization was applied and we still get the right result - async fn assert_count_optim_success(plan: AggregateExec, nulls: bool) -> Result<()> { + async fn assert_count_optim_success( + plan: AggregateExec, + agg: TestAggregate, + ) -> Result<()> { let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); let conf = session_ctx.copied_config(); - let optimized = AggregateStatistics::new().optimize(Arc::new(plan), &conf)?; - - let (col, count) = match nulls { - false => (Field::new("COUNT(UInt8(1))", DataType::UInt64, false), 3), - true => (Field::new("COUNT(a)", DataType::UInt64, false), 2), - }; + let plan = Arc::new(plan) as _; + let optimized = AggregateStatistics::new().optimize(Arc::clone(&plan), &conf)?; // A ProjectionExec is a sign that the count optimization was applied assert!(optimized.as_any().is::()); - let result = common::collect(optimized.execute(0, task_ctx)?).await?; - assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col]))); + + // run both the optimized and nonoptimized plan + let optimized_result = + common::collect(optimized.execute(0, session_ctx.task_ctx())?).await?; + let nonoptimized_result = + common::collect(plan.execute(0, session_ctx.task_ctx())?).await?; + assert_eq!(optimized_result.len(), nonoptimized_result.len()); + + // and validate the results are the same and expected + assert_eq!(optimized_result.len(), 1); + check_batch(optimized_result.into_iter().next().unwrap(), &agg); + // check the non optimized one too to ensure types and names remain the same + assert_eq!(nonoptimized_result.len(), 1); + check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg); + + Ok(()) + } + + fn check_batch(batch: RecordBatch, agg: &TestAggregate) { + let schema = batch.schema(); + let fields = schema.fields(); + assert_eq!(fields.len(), 1); + + let field = &fields[0]; + assert_eq!(field.name(), agg.column_name()); + assert_eq!(field.data_type(), &DataType::Int64); + // note that nullabiolity differs + assert_eq!( - result[0] + batch .column(0) .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap() .values(), - &[count] + &[agg.expected_count()] ); - Ok(()) } - fn count_expr(schema: Option<&Schema>, col: Option<&str>) -> Arc { - // Return appropriate expr depending if COUNT is for col or table - let expr = match schema { - None => expressions::lit(ScalarValue::UInt8(Some(1))), - Some(s) => expressions::col(col.unwrap(), s).unwrap(), - }; - Arc::new(Count::new(expr, "my_count_alias", DataType::UInt64)) + /// Describe the type of aggregate being tested + enum TestAggregate { + /// Testing COUNT(*) type aggregates + CountStar, + + /// Testing for COUNT(column) aggregate + ColumnA(Arc), + } + + impl TestAggregate { + fn new_count_star() -> Self { + Self::CountStar + } + + fn new_count_column(schema: &Arc) -> Self { + Self::ColumnA(schema.clone()) + } + + /// Return appropriate expr depending if COUNT is for col or table (*) + fn count_expr(&self) -> Arc { + Arc::new(Count::new( + self.column(), + self.column_name(), + DataType::Int64, + )) + } + + /// what argument would this aggregate need in the plan? + fn column(&self) -> Arc { + match self { + Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION), + Self::ColumnA(s) => expressions::col("a", s).unwrap(), + } + } + + /// What name would this aggregate produce in a plan? + fn column_name(&self) -> &'static str { + match self { + Self::CountStar => COUNT_STAR_NAME, + Self::ColumnA(_) => "COUNT(a)", + } + } + + /// What is the expected count? + fn expected_count(&self) -> i64 { + match self { + TestAggregate::CountStar => 3, + TestAggregate::ColumnA(_) => 2, + } + } } #[tokio::test] @@ -332,11 +403,12 @@ mod tests { // basic test case with the aggregation applied on a source with exact statistics let source = mock_data()?; let schema = source.schema(); + let agg = TestAggregate::new_count_star(); let partial_agg = AggregateExec::try_new( AggregateMode::Partial, vec![], - vec![count_expr(None, None)], + vec![agg.count_expr()], source, Arc::clone(&schema), )?; @@ -344,12 +416,12 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, vec![], - vec![count_expr(None, None)], + vec![agg.count_expr()], Arc::new(partial_agg), Arc::clone(&schema), )?; - assert_count_optim_success(final_agg, false).await?; + assert_count_optim_success(final_agg, agg).await?; Ok(()) } @@ -359,11 +431,12 @@ mod tests { // basic test case with the aggregation applied on a source with exact statistics let source = mock_data()?; let schema = source.schema(); + let agg = TestAggregate::new_count_column(&schema); let partial_agg = AggregateExec::try_new( AggregateMode::Partial, vec![], - vec![count_expr(Some(&schema), Some("a"))], + vec![agg.count_expr()], source, Arc::clone(&schema), )?; @@ -371,12 +444,12 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, vec![], - vec![count_expr(Some(&schema), Some("a"))], + vec![agg.count_expr()], Arc::new(partial_agg), Arc::clone(&schema), )?; - assert_count_optim_success(final_agg, true).await?; + assert_count_optim_success(final_agg, agg).await?; Ok(()) } @@ -385,11 +458,12 @@ mod tests { async fn test_count_partial_indirect_child() -> Result<()> { let source = mock_data()?; let schema = source.schema(); + let agg = TestAggregate::new_count_star(); let partial_agg = AggregateExec::try_new( AggregateMode::Partial, vec![], - vec![count_expr(None, None)], + vec![agg.count_expr()], source, Arc::clone(&schema), )?; @@ -400,12 +474,12 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, vec![], - vec![count_expr(None, None)], + vec![agg.count_expr()], Arc::new(coalesce), Arc::clone(&schema), )?; - assert_count_optim_success(final_agg, false).await?; + assert_count_optim_success(final_agg, agg).await?; Ok(()) } @@ -414,11 +488,12 @@ mod tests { async fn test_count_partial_with_nulls_indirect_child() -> Result<()> { let source = mock_data()?; let schema = source.schema(); + let agg = TestAggregate::new_count_column(&schema); let partial_agg = AggregateExec::try_new( AggregateMode::Partial, vec![], - vec![count_expr(Some(&schema), Some("a"))], + vec![agg.count_expr()], source, Arc::clone(&schema), )?; @@ -429,12 +504,12 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, vec![], - vec![count_expr(Some(&schema), Some("a"))], + vec![agg.count_expr()], Arc::new(coalesce), Arc::clone(&schema), )?; - assert_count_optim_success(final_agg, true).await?; + assert_count_optim_success(final_agg, agg).await?; Ok(()) } @@ -443,6 +518,7 @@ mod tests { async fn test_count_inexact_stat() -> Result<()> { let source = mock_data()?; let schema = source.schema(); + let agg = TestAggregate::new_count_star(); // adding a filter makes the statistics inexact let filter = Arc::new(FilterExec::try_new( @@ -458,7 +534,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, vec![], - vec![count_expr(None, None)], + vec![agg.count_expr()], filter, Arc::clone(&schema), )?; @@ -466,7 +542,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, vec![], - vec![count_expr(None, None)], + vec![agg.count_expr()], Arc::new(partial_agg), Arc::clone(&schema), )?; @@ -485,6 +561,7 @@ mod tests { async fn test_count_with_nulls_inexact_stat() -> Result<()> { let source = mock_data()?; let schema = source.schema(); + let agg = TestAggregate::new_count_column(&schema); // adding a filter makes the statistics inexact let filter = Arc::new(FilterExec::try_new( @@ -500,7 +577,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, vec![], - vec![count_expr(Some(&schema), Some("a"))], + vec![agg.count_expr()], filter, Arc::clone(&schema), )?; @@ -508,7 +585,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, vec![], - vec![count_expr(Some(&schema), Some("a"))], + vec![agg.count_expr()], Arc::new(partial_agg), Arc::clone(&schema), )?; diff --git a/datafusion/core/tests/custom_sources.rs b/datafusion/core/tests/custom_sources.rs index 1e4ac6e5134e6..cccac05234045 100644 --- a/datafusion/core/tests/custom_sources.rs +++ b/datafusion/core/tests/custom_sources.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Int32Array, PrimitiveArray, UInt64Array}; +use arrow::array::{Int32Array, Int64Array, PrimitiveArray}; use arrow::compute::kernels::aggregate; use arrow::datatypes::{DataType, Field, Int32Type, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; @@ -284,12 +284,12 @@ async fn optimizers_catch_all_statistics() { let expected = RecordBatch::try_new( Arc::new(Schema::new(vec![ - Field::new("COUNT(UInt8(1))", DataType::UInt64, false), + Field::new("COUNT(UInt8(1))", DataType::Int64, false), Field::new("MIN(test.c1)", DataType::Int32, false), Field::new("MAX(test.c1)", DataType::Int32, false), ])), vec![ - Arc::new(UInt64Array::from_slice(&[4])), + Arc::new(Int64Array::from_slice(&[4])), Arc::new(Int32Array::from_slice(&[1])), Arc::new(Int32Array::from_slice(&[100])), ], diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index ac22d094e5b73..3986eb3e64e3c 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -26,11 +26,15 @@ use crate::logical_plan::{ }; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; use datafusion_common::{ - Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, + Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; use std::collections::HashSet; use std::sync::Arc; +/// The value to which `COUNT(*)` is expanded to in +/// `COUNT()` expressions +pub const COUNT_STAR_EXPANSION: ScalarValue = ScalarValue::UInt8(Some(1)); + /// Recursively walk a list of expression trees, collecting the unique set of columns /// referenced in the expression pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet) -> Result<()> { diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 51f16033624e5..1e5daa472baba 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -30,7 +30,7 @@ use datafusion_expr::logical_plan::{ }; use datafusion_expr::utils::{ expand_qualified_wildcard, expand_wildcard, expr_as_column_expr, expr_to_columns, - find_aggregate_exprs, find_column_exprs, find_window_exprs, + find_aggregate_exprs, find_column_exprs, find_window_exprs, COUNT_STAR_EXPANSION, }; use datafusion_expr::{ and, col, lit, AggregateFunction, AggregateUDF, Expr, Operator, ScalarUDF, @@ -2122,14 +2122,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, ) -> Result<(AggregateFunction, Vec)> { let args = match fun { + // Special case rewrite COUNT(*) to COUNT(constant) AggregateFunction::Count => function .args .into_iter() .map(|a| match a { FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Value( Value::Number(_, _), - ))) => Ok(lit(1_u8)), - FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => Ok(lit(1_u8)), + ))) => Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone())), + FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { + Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone())) + } _ => self.sql_fn_arg_to_logical_expr(a, schema, &mut HashMap::new()), }) .collect::>>()?,