-
Notifications
You must be signed in to change notification settings - Fork 1.9k
Fix AggregateStatistics optimization so it doesn't change output type
#2674
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
0663979
b82c30e
85e54f2
5948c0e
eb14658
171c899
fde3cc4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |
| use std::sync::Arc; | ||
|
|
||
| use arrow::datatypes::Schema; | ||
| use datafusion_expr::utils::COUNT_STAR_EXPANSION; | ||
|
|
||
| use crate::execution::context::SessionConfig; | ||
| use crate::physical_plan::aggregates::{AggregateExec, AggregateMode}; | ||
|
|
@@ -37,6 +38,9 @@ use crate::error::Result; | |
| #[derive(Default)] | ||
| pub struct AggregateStatistics {} | ||
|
|
||
| /// The name of the column corresponding to [`COUNT_STAR_EXPANSION`] | ||
| const COUNT_STAR_NAME: &str = "COUNT(UInt8(1))"; | ||
|
|
||
| impl AggregateStatistics { | ||
| #[allow(missing_docs)] | ||
| pub fn new() -> Self { | ||
|
|
@@ -148,10 +152,10 @@ fn take_optimizable_table_count( | |
| .as_any() | ||
| .downcast_ref::<expressions::Literal>() | ||
| { | ||
| if lit_expr.value() == &ScalarValue::UInt8(Some(1)) { | ||
| if lit_expr.value() == &COUNT_STAR_EXPANSION { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. There was an implicit coupling between the SQL planner and this file, which I have now made explicit with a named constant (`COUNT_STAR_EXPANSION`). |
||
| return Some(( | ||
| ScalarValue::UInt64(Some(num_rows as u64)), | ||
| "COUNT(UInt8(1))", | ||
| ScalarValue::Int64(Some(num_rows as i64)), | ||
|
Contributor
Author
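There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The change from `UInt64` to `Int64` is the actual fix: the optimized plan now produces the same output type as the unoptimized aggregate. |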
||
| COUNT_STAR_NAME, | ||
| )); | ||
| } | ||
| } | ||
|
|
@@ -183,7 +187,7 @@ fn take_optimizable_column_count( | |
| { | ||
| let expr = format!("COUNT({})", col_expr.name()); | ||
| return Some(( | ||
| ScalarValue::UInt64(Some((num_rows - val) as u64)), | ||
| ScalarValue::Int64(Some((num_rows - val) as i64)), | ||
| expr, | ||
| )); | ||
| } | ||
|
|
@@ -254,9 +258,10 @@ mod tests { | |
| use super::*; | ||
| use std::sync::Arc; | ||
|
|
||
| use arrow::array::{Int32Array, UInt64Array}; | ||
| use arrow::array::{Int32Array, Int64Array}; | ||
| use arrow::datatypes::{DataType, Field, Schema}; | ||
| use arrow::record_batch::RecordBatch; | ||
| use datafusion_physical_expr::PhysicalExpr; | ||
|
|
||
| use crate::error::Result; | ||
| use crate::logical_plan::Operator; | ||
|
|
@@ -291,65 +296,132 @@ mod tests { | |
| } | ||
|
|
||
| /// Checks that the count optimization was applied and we still get the right result | ||
| async fn assert_count_optim_success(plan: AggregateExec, nulls: bool) -> Result<()> { | ||
| async fn assert_count_optim_success( | ||
| plan: AggregateExec, | ||
| agg: TestAggregate, | ||
| ) -> Result<()> { | ||
| let session_ctx = SessionContext::new(); | ||
| let task_ctx = session_ctx.task_ctx(); | ||
| let conf = session_ctx.copied_config(); | ||
| let optimized = AggregateStatistics::new().optimize(Arc::new(plan), &conf)?; | ||
|
|
||
| let (col, count) = match nulls { | ||
| false => (Field::new("COUNT(UInt8(1))", DataType::UInt64, false), 3), | ||
| true => (Field::new("COUNT(a)", DataType::UInt64, false), 2), | ||
| }; | ||
| let plan = Arc::new(plan) as _; | ||
| let optimized = AggregateStatistics::new().optimize(Arc::clone(&plan), &conf)?; | ||
|
|
||
| // A ProjectionExec is a sign that the count optimization was applied | ||
| assert!(optimized.as_any().is::<ProjectionExec>()); | ||
| let result = common::collect(optimized.execute(0, task_ctx)?).await?; | ||
| assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col]))); | ||
|
|
||
| // run both the optimized and nonoptimized plan | ||
| let optimized_result = | ||
| common::collect(optimized.execute(0, session_ctx.task_ctx())?).await?; | ||
| let nonoptimized_result = | ||
| common::collect(plan.execute(0, session_ctx.task_ctx())?).await?; | ||
| assert_eq!(optimized_result.len(), nonoptimized_result.len()); | ||
|
|
||
| // and validate the results are the same and expected | ||
| assert_eq!(optimized_result.len(), 1); | ||
| check_batch(optimized_result.into_iter().next().unwrap(), &agg); | ||
| // check the non optimized one too to ensure types and names remain the same | ||
| assert_eq!(nonoptimized_result.len(), 1); | ||
| check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg); | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
||
| fn check_batch(batch: RecordBatch, agg: &TestAggregate) { | ||
| let schema = batch.schema(); | ||
| let fields = schema.fields(); | ||
| assert_eq!(fields.len(), 1); | ||
|
|
||
| let field = &fields[0]; | ||
| assert_eq!(field.name(), agg.column_name()); | ||
| assert_eq!(field.data_type(), &DataType::Int64); | ||
| // note that nullability differs | ||
|
|
||
| assert_eq!( | ||
| result[0] | ||
| batch | ||
| .column(0) | ||
| .as_any() | ||
| .downcast_ref::<UInt64Array>() | ||
| .downcast_ref::<Int64Array>() | ||
| .unwrap() | ||
| .values(), | ||
| &[count] | ||
| &[agg.expected_count()] | ||
| ); | ||
| Ok(()) | ||
| } | ||
|
|
||
| fn count_expr(schema: Option<&Schema>, col: Option<&str>) -> Arc<dyn AggregateExpr> { | ||
| // Return appropriate expr depending if COUNT is for col or table | ||
| let expr = match schema { | ||
| None => expressions::lit(ScalarValue::UInt8(Some(1))), | ||
| Some(s) => expressions::col(col.unwrap(), s).unwrap(), | ||
| }; | ||
| Arc::new(Count::new(expr, "my_count_alias", DataType::UInt64)) | ||
| /// Describe the type of aggregate being tested | ||
| enum TestAggregate { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This now parameterizes the difference between different tests into an explicit `TestAggregate` enum. |
||
| /// Testing COUNT(*) type aggregates | ||
| CountStar, | ||
|
|
||
| /// Testing for COUNT(column) aggregate | ||
| ColumnA(Arc<Schema>), | ||
| } | ||
|
|
||
| impl TestAggregate { | ||
| fn new_count_star() -> Self { | ||
| Self::CountStar | ||
| } | ||
|
|
||
| fn new_count_column(schema: &Arc<Schema>) -> Self { | ||
| Self::ColumnA(schema.clone()) | ||
| } | ||
|
|
||
| /// Return appropriate expr depending if COUNT is for col or table (*) | ||
| fn count_expr(&self) -> Arc<dyn AggregateExpr> { | ||
| Arc::new(Count::new( | ||
| self.column(), | ||
| self.column_name(), | ||
| DataType::Int64, | ||
| )) | ||
| } | ||
|
|
||
| /// what argument would this aggregate need in the plan? | ||
| fn column(&self) -> Arc<dyn PhysicalExpr> { | ||
| match self { | ||
| Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION), | ||
| Self::ColumnA(s) => expressions::col("a", s).unwrap(), | ||
| } | ||
| } | ||
|
|
||
| /// What name would this aggregate produce in a plan? | ||
| fn column_name(&self) -> &'static str { | ||
| match self { | ||
| Self::CountStar => COUNT_STAR_NAME, | ||
| Self::ColumnA(_) => "COUNT(a)", | ||
| } | ||
| } | ||
|
|
||
| /// What is the expected count? | ||
| fn expected_count(&self) -> i64 { | ||
| match self { | ||
| TestAggregate::CountStar => 3, | ||
| TestAggregate::ColumnA(_) => 2, | ||
| } | ||
| } | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn test_count_partial_direct_child() -> Result<()> { | ||
| // basic test case with the aggregation applied on a source with exact statistics | ||
| let source = mock_data()?; | ||
| let schema = source.schema(); | ||
| let agg = TestAggregate::new_count_star(); | ||
|
|
||
| let partial_agg = AggregateExec::try_new( | ||
| AggregateMode::Partial, | ||
| vec![], | ||
| vec![count_expr(None, None)], | ||
| vec![agg.count_expr()], | ||
| source, | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| let final_agg = AggregateExec::try_new( | ||
| AggregateMode::Final, | ||
| vec![], | ||
| vec![count_expr(None, None)], | ||
| vec![agg.count_expr()], | ||
| Arc::new(partial_agg), | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| assert_count_optim_success(final_agg, false).await?; | ||
| assert_count_optim_success(final_agg, agg).await?; | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
@@ -359,24 +431,25 @@ mod tests { | |
| // basic test case with the aggregation applied on a source with exact statistics | ||
| let source = mock_data()?; | ||
| let schema = source.schema(); | ||
| let agg = TestAggregate::new_count_column(&schema); | ||
|
|
||
| let partial_agg = AggregateExec::try_new( | ||
| AggregateMode::Partial, | ||
| vec![], | ||
| vec![count_expr(Some(&schema), Some("a"))], | ||
| vec![agg.count_expr()], | ||
| source, | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| let final_agg = AggregateExec::try_new( | ||
| AggregateMode::Final, | ||
| vec![], | ||
| vec![count_expr(Some(&schema), Some("a"))], | ||
| vec![agg.count_expr()], | ||
| Arc::new(partial_agg), | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| assert_count_optim_success(final_agg, true).await?; | ||
| assert_count_optim_success(final_agg, agg).await?; | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
@@ -385,11 +458,12 @@ mod tests { | |
| async fn test_count_partial_indirect_child() -> Result<()> { | ||
| let source = mock_data()?; | ||
| let schema = source.schema(); | ||
| let agg = TestAggregate::new_count_star(); | ||
|
|
||
| let partial_agg = AggregateExec::try_new( | ||
| AggregateMode::Partial, | ||
| vec![], | ||
| vec![count_expr(None, None)], | ||
| vec![agg.count_expr()], | ||
| source, | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
@@ -400,12 +474,12 @@ mod tests { | |
| let final_agg = AggregateExec::try_new( | ||
| AggregateMode::Final, | ||
| vec![], | ||
| vec![count_expr(None, None)], | ||
| vec![agg.count_expr()], | ||
| Arc::new(coalesce), | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| assert_count_optim_success(final_agg, false).await?; | ||
| assert_count_optim_success(final_agg, agg).await?; | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
@@ -414,11 +488,12 @@ mod tests { | |
| async fn test_count_partial_with_nulls_indirect_child() -> Result<()> { | ||
| let source = mock_data()?; | ||
| let schema = source.schema(); | ||
| let agg = TestAggregate::new_count_column(&schema); | ||
|
|
||
| let partial_agg = AggregateExec::try_new( | ||
| AggregateMode::Partial, | ||
| vec![], | ||
| vec![count_expr(Some(&schema), Some("a"))], | ||
| vec![agg.count_expr()], | ||
| source, | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
@@ -429,12 +504,12 @@ mod tests { | |
| let final_agg = AggregateExec::try_new( | ||
| AggregateMode::Final, | ||
| vec![], | ||
| vec![count_expr(Some(&schema), Some("a"))], | ||
| vec![agg.count_expr()], | ||
| Arc::new(coalesce), | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| assert_count_optim_success(final_agg, true).await?; | ||
| assert_count_optim_success(final_agg, agg).await?; | ||
|
|
||
| Ok(()) | ||
| } | ||
|
|
@@ -443,6 +518,7 @@ mod tests { | |
| async fn test_count_inexact_stat() -> Result<()> { | ||
| let source = mock_data()?; | ||
| let schema = source.schema(); | ||
| let agg = TestAggregate::new_count_star(); | ||
|
|
||
| // adding a filter makes the statistics inexact | ||
| let filter = Arc::new(FilterExec::try_new( | ||
|
|
@@ -458,15 +534,15 @@ mod tests { | |
| let partial_agg = AggregateExec::try_new( | ||
| AggregateMode::Partial, | ||
| vec![], | ||
| vec![count_expr(None, None)], | ||
| vec![agg.count_expr()], | ||
| filter, | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| let final_agg = AggregateExec::try_new( | ||
| AggregateMode::Final, | ||
| vec![], | ||
| vec![count_expr(None, None)], | ||
| vec![agg.count_expr()], | ||
| Arc::new(partial_agg), | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
@@ -485,6 +561,7 @@ mod tests { | |
| async fn test_count_with_nulls_inexact_stat() -> Result<()> { | ||
| let source = mock_data()?; | ||
| let schema = source.schema(); | ||
| let agg = TestAggregate::new_count_column(&schema); | ||
|
|
||
| // adding a filter makes the statistics inexact | ||
| let filter = Arc::new(FilterExec::try_new( | ||
|
|
@@ -500,15 +577,15 @@ mod tests { | |
| let partial_agg = AggregateExec::try_new( | ||
| AggregateMode::Partial, | ||
| vec![], | ||
| vec![count_expr(Some(&schema), Some("a"))], | ||
| vec![agg.count_expr()], | ||
| filter, | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
| let final_agg = AggregateExec::try_new( | ||
| AggregateMode::Final, | ||
| vec![], | ||
| vec![count_expr(Some(&schema), Some("a"))], | ||
| vec![agg.count_expr()], | ||
| Arc::new(partial_agg), | ||
| Arc::clone(&schema), | ||
| )?; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This constant was hard-coded in a few places, and I think this symbolic name helps readers understand what it is doing.