From 0663979b82909cdffddfd0b3a027bd2038999b8b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 1 Jun 2022 08:56:55 -0400 Subject: [PATCH 1/6] Fix `AggregateStatistics` optimization so it doesn't change output type --- .../aggregate_statistics.rs | 68 +++++++++++++++---- 1 file changed, 55 insertions(+), 13 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 4cf96d2350ebb..140a270ffd241 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -150,7 +150,7 @@ fn take_optimizable_table_count( { if lit_expr.value() == &ScalarValue::UInt8(Some(1)) { return Some(( - ScalarValue::UInt64(Some(num_rows as u64)), + ScalarValue::Int64(Some(num_rows as i64)), "COUNT(UInt8(1))", )); } @@ -183,7 +183,7 @@ fn take_optimizable_column_count( { let expr = format!("COUNT({})", col_expr.name()); return Some(( - ScalarValue::UInt64(Some((num_rows - val) as u64)), + ScalarValue::Int64(Some((num_rows - val) as i64)), expr, )); } @@ -254,7 +254,7 @@ mod tests { use super::*; use std::sync::Arc; - use arrow::array::{Int32Array, UInt64Array}; + use arrow::array::{Int32Array, Int64Array}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; @@ -293,38 +293,80 @@ mod tests { /// Checks that the count optimization was applied and we still get the right result async fn assert_count_optim_success(plan: AggregateExec, nulls: bool) -> Result<()> { let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); let conf = session_ctx.copied_config(); - let optimized = AggregateStatistics::new().optimize(Arc::new(plan), &conf)?; + let plan = Arc::new(plan) as _; + let optimized = AggregateStatistics::new().optimize(Arc::clone(&plan), &conf)?; let (col, count) = match nulls { - false => (Field::new("COUNT(UInt8(1))", DataType::UInt64, false), 3), - 
true => (Field::new("COUNT(a)", DataType::UInt64, false), 2), + false => (Field::new("COUNT(UInt8(1))", DataType::Int64, false), 3), + true => (Field::new("COUNT(a)", DataType::Int64, false), 2), }; // A ProjectionExec is a sign that the count optimization was applied assert!(optimized.as_any().is::()); + let task_ctx = session_ctx.task_ctx(); let result = common::collect(optimized.execute(0, task_ctx)?).await?; assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col]))); assert_eq!( result[0] .column(0) .as_any() - .downcast_ref::() + .downcast_ref::() .unwrap() .values(), &[count] ); + + // Validate that the optimized plan returns the exact same + // answer (both schema and data) as the original plan + let task_ctx = session_ctx.task_ctx(); + let plan_result = common::collect(plan.execute(0, task_ctx)?).await?; + assert_eq!(normalize(result), normalize(plan_result)); Ok(()) } + /// Normalize record batches for comparison: + /// 1. Sets nullable to `true` + fn normalize(batches: Vec) -> Vec { + let schema = normalize_schema(&batches[0].schema()); + batches + .into_iter() + .map(|batch| { + RecordBatch::try_new(schema.clone(), batch.columns().to_vec()) + .expect("Error creating record batch") + }) + .collect() + } + fn normalize_schema(schema: &Schema) -> Arc { + let nullable = true; + let normalized_fields = schema + .fields() + .iter() + .map(|f| { + Field::new(f.name(), f.data_type().clone(), nullable) + .with_metadata(f.metadata().cloned()) + }) + .collect(); + Arc::new(Schema::new_with_metadata( + normalized_fields, + schema.metadata().clone(), + )) + } + fn count_expr(schema: Option<&Schema>, col: Option<&str>) -> Arc { - // Return appropriate expr depending if COUNT is for col or table - let expr = match schema { - None => expressions::lit(ScalarValue::UInt8(Some(1))), - Some(s) => expressions::col(col.unwrap(), s).unwrap(), + // Return appropriate expr depending if COUNT is for col or table (*) + let (expr, name) = match schema { + None => ( + 
expressions::lit(ScalarValue::UInt8(Some(1))), + "COUNT(UInt8(1))".to_string(), + ), + Some(s) => ( + expressions::col(col.unwrap(), s).unwrap(), + format!("COUNT({})", col.unwrap()), + ), }; - Arc::new(Count::new(expr, "my_count_alias", DataType::UInt64)) + + Arc::new(Count::new(expr, name, DataType::Int64)) } #[tokio::test] From b82c30eb1af254545edf7e6b2f8a7daadf191e66 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 1 Jun 2022 09:37:44 -0400 Subject: [PATCH 2/6] fix test --- datafusion/core/tests/custom_sources.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/core/tests/custom_sources.rs b/datafusion/core/tests/custom_sources.rs index 1e4ac6e5134e6..cccac05234045 100644 --- a/datafusion/core/tests/custom_sources.rs +++ b/datafusion/core/tests/custom_sources.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{Int32Array, PrimitiveArray, UInt64Array}; +use arrow::array::{Int32Array, Int64Array, PrimitiveArray}; use arrow::compute::kernels::aggregate; use arrow::datatypes::{DataType, Field, Int32Type, Schema, SchemaRef}; use arrow::error::Result as ArrowResult; @@ -284,12 +284,12 @@ async fn optimizers_catch_all_statistics() { let expected = RecordBatch::try_new( Arc::new(Schema::new(vec![ - Field::new("COUNT(UInt8(1))", DataType::UInt64, false), + Field::new("COUNT(UInt8(1))", DataType::Int64, false), Field::new("MIN(test.c1)", DataType::Int32, false), Field::new("MAX(test.c1)", DataType::Int32, false), ])), vec![ - Arc::new(UInt64Array::from_slice(&[4])), + Arc::new(Int64Array::from_slice(&[4])), Arc::new(Int32Array::from_slice(&[1])), Arc::new(Int32Array::from_slice(&[100])), ], From 85e54f21ae35b9808c3274c6af91c05226e5fa8b Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Wed, 1 Jun 2022 09:52:36 -0400 Subject: [PATCH 3/6] Give some constants symbolic names to improve readability --- .../src/physical_optimizer/aggregate_statistics.rs | 14 
+++++++++----- datafusion/expr/src/utils.rs | 6 +++++- datafusion/sql/src/planner.rs | 9 ++++++--- 3 files changed, 20 insertions(+), 9 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 140a270ffd241..c83a156994bff 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use arrow::datatypes::Schema; +use datafusion_expr::utils::COUNT_STAR_EXPANSION; use crate::execution::context::SessionConfig; use crate::physical_plan::aggregates::{AggregateExec, AggregateMode}; @@ -37,6 +38,9 @@ use crate::error::Result; #[derive(Default)] pub struct AggregateStatistics {} +/// The name of the column corresponding to [`COUNT_STAR_EXPANSION`] +const COUNT_STAR_NAME: &str = "COUNT(UInt8(1))"; + impl AggregateStatistics { #[allow(missing_docs)] pub fn new() -> Self { @@ -148,10 +152,10 @@ fn take_optimizable_table_count( .as_any() .downcast_ref::() { - if lit_expr.value() == &ScalarValue::UInt8(Some(1)) { + if lit_expr.value() == &COUNT_STAR_EXPANSION { return Some(( ScalarValue::Int64(Some(num_rows as i64)), - "COUNT(UInt8(1))", + COUNT_STAR_NAME, )); } } @@ -298,7 +302,7 @@ mod tests { let optimized = AggregateStatistics::new().optimize(Arc::clone(&plan), &conf)?; let (col, count) = match nulls { - false => (Field::new("COUNT(UInt8(1))", DataType::Int64, false), 3), + false => (Field::new(COUNT_STAR_NAME, DataType::Int64, false), 3), true => (Field::new("COUNT(a)", DataType::Int64, false), 2), }; @@ -357,8 +361,8 @@ mod tests { // Return appropriate expr depending if COUNT is for col or table (*) let (expr, name) = match schema { None => ( - expressions::lit(ScalarValue::UInt8(Some(1))), - "COUNT(UInt8(1))".to_string(), + expressions::lit(COUNT_STAR_EXPANSION), + COUNT_STAR_NAME.to_string(), ), Some(s) => ( expressions::col(col.unwrap(), 
s).unwrap(), diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 483b12b49a5e8..35e538b3cbe81 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -26,11 +26,15 @@ use crate::logical_plan::{ }; use crate::{Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; use datafusion_common::{ - Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, + Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, }; use std::collections::HashSet; use std::sync::Arc; +/// The value to which `COUNT(*)` is expanded in +/// `COUNT(<constant>)` expressions +pub const COUNT_STAR_EXPANSION: ScalarValue = ScalarValue::UInt8(Some(1)); + /// Recursively walk a list of expression trees, collecting the unique set of columns /// referenced in the expression pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet) -> Result<()> { diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index b21550567f53d..948b03f08e600 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -30,7 +30,7 @@ use datafusion_expr::logical_plan::{ }; use datafusion_expr::utils::{ expand_qualified_wildcard, expand_wildcard, expr_as_column_expr, expr_to_columns, - find_aggregate_exprs, find_column_exprs, find_window_exprs, + find_aggregate_exprs, find_column_exprs, find_window_exprs, COUNT_STAR_EXPANSION, }; use datafusion_expr::{ and, col, lit, AggregateFunction, AggregateUDF, Expr, Operator, ScalarUDF, @@ -2197,14 +2197,17 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { schema: &DFSchema, ) -> Result<(AggregateFunction, Vec)> { let args = match fun { + // Special case rewrite COUNT(*) to COUNT(constant) AggregateFunction::Count => function .args .into_iter() .map(|a| match a { FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Value( Value::Number(_, _), - ))) => Ok(lit(1_u8)), - FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => Ok(lit(1_u8)), + ))) => 
Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone())), + FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => { + Ok(Expr::Literal(COUNT_STAR_EXPANSION.clone())) + } _ => self.sql_fn_arg_to_logical_expr(a, schema, &mut HashMap::new()), }) .collect::>>()?, From eb14658de72776cbd241e0e44dd65eb92215ef4f Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 2 Jun 2022 14:40:08 -0400 Subject: [PATCH 4/6] Consolidate expected differences in COUNT(*) and COUNT(a) in tests --- .../aggregate_statistics.rs | 122 ++++++++++++------ 1 file changed, 85 insertions(+), 37 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index c83a156994bff..5c09f3ca15630 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -261,6 +261,7 @@ mod tests { use arrow::array::{Int32Array, Int64Array}; use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; + use datafusion_physical_expr::PhysicalExpr; use crate::error::Result; use crate::logical_plan::Operator; @@ -295,22 +296,22 @@ mod tests { } /// Checks that the count optimization was applied and we still get the right result - async fn assert_count_optim_success(plan: AggregateExec, nulls: bool) -> Result<()> { + async fn assert_count_optim_success( + plan: AggregateExec, + agg: TestAggregate, + ) -> Result<()> { let session_ctx = SessionContext::new(); let conf = session_ctx.copied_config(); let plan = Arc::new(plan) as _; let optimized = AggregateStatistics::new().optimize(Arc::clone(&plan), &conf)?; - let (col, count) = match nulls { - false => (Field::new(COUNT_STAR_NAME, DataType::Int64, false), 3), - true => (Field::new("COUNT(a)", DataType::Int64, false), 2), - }; + let expected_schema = Arc::new(Schema::new(vec![agg.field()])); // A ProjectionExec is a sign that the count optimization was applied 
assert!(optimized.as_any().is::()); let task_ctx = session_ctx.task_ctx(); let result = common::collect(optimized.execute(0, task_ctx)?).await?; - assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col]))); + assert_eq!(result[0].schema(), expected_schema); assert_eq!( result[0] .column(0) @@ -318,7 +319,7 @@ mod tests { .downcast_ref::() .unwrap() .values(), - &[count] + &[agg.expected_count()] ); // Validate that the optimized plan returns the exact same @@ -357,20 +358,61 @@ mod tests { )) } - fn count_expr(schema: Option<&Schema>, col: Option<&str>) -> Arc { - // Return appropriate expr depending if COUNT is for col or table (*) - let (expr, name) = match schema { - None => ( - expressions::lit(COUNT_STAR_EXPANSION), - COUNT_STAR_NAME.to_string(), - ), - Some(s) => ( - expressions::col(col.unwrap(), s).unwrap(), - format!("COUNT({})", col.unwrap()), - ), - }; - - Arc::new(Count::new(expr, name, DataType::Int64)) + /// Describe the type of aggregate being tested + enum TestAggregate { + /// Testing COUNT(*) type aggregates + CountStar, + + /// Testing for COUNT(column) aggregate + ColumnA(Arc), + } + + impl TestAggregate { + fn new_count_star() -> Self { + Self::CountStar + } + + fn new_count_column(schema: &Arc) -> Self { + Self::ColumnA(schema.clone()) + } + + /// Return appropriate expr depending if COUNT is for col or table (*) + fn count_expr(&self) -> Arc { + Arc::new(Count::new( + self.column(), + self.column_name(), + DataType::Int64, + )) + } + + /// what argument would this aggregate need in the plan? + fn column(&self) -> Arc { + match self { + Self::CountStar => expressions::lit(COUNT_STAR_EXPANSION), + Self::ColumnA(s) => expressions::col("a", s).unwrap(), + } + } + + /// What name would this aggregate produce in a plan? + fn column_name(&self) -> &'static str { + match self { + Self::CountStar => COUNT_STAR_NAME, + Self::ColumnA(_) => "COUNT(a)", + } + } + + /// What is the output Field this aggregate would produce? 
+ fn field(&self) -> Field { + Field::new(self.column_name(), DataType::Int64, false) + } + + /// What is the expected count? + fn expected_count(&self) -> i64 { + match self { + TestAggregate::CountStar => 3, + TestAggregate::ColumnA(_) => 2, + } + } } #[tokio::test] @@ -378,11 +420,12 @@ mod tests { // basic test case with the aggregation applied on a source with exact statistics let source = mock_data()?; let schema = source.schema(); + let agg = TestAggregate::new_count_star(); let partial_agg = AggregateExec::try_new( AggregateMode::Partial, vec![], - vec![count_expr(None, None)], + vec![agg.count_expr()], source, Arc::clone(&schema), )?; @@ -390,12 +433,12 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, vec![], - vec![count_expr(None, None)], + vec![agg.count_expr()], Arc::new(partial_agg), Arc::clone(&schema), )?; - assert_count_optim_success(final_agg, false).await?; + assert_count_optim_success(final_agg, agg).await?; Ok(()) } @@ -405,11 +448,12 @@ mod tests { // basic test case with the aggregation applied on a source with exact statistics let source = mock_data()?; let schema = source.schema(); + let agg = TestAggregate::new_count_column(&schema); let partial_agg = AggregateExec::try_new( AggregateMode::Partial, vec![], - vec![count_expr(Some(&schema), Some("a"))], + vec![agg.count_expr()], source, Arc::clone(&schema), )?; @@ -417,12 +461,12 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, vec![], - vec![count_expr(Some(&schema), Some("a"))], + vec![agg.count_expr()], Arc::new(partial_agg), Arc::clone(&schema), )?; - assert_count_optim_success(final_agg, true).await?; + assert_count_optim_success(final_agg, agg).await?; Ok(()) } @@ -431,11 +475,12 @@ mod tests { async fn test_count_partial_indirect_child() -> Result<()> { let source = mock_data()?; let schema = source.schema(); + let agg = TestAggregate::new_count_star(); let partial_agg = AggregateExec::try_new( AggregateMode::Partial, vec![], - 
vec![count_expr(None, None)], + vec![agg.count_expr()], source, Arc::clone(&schema), )?; @@ -446,12 +491,12 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, vec![], - vec![count_expr(None, None)], + vec![agg.count_expr()], Arc::new(coalesce), Arc::clone(&schema), )?; - assert_count_optim_success(final_agg, false).await?; + assert_count_optim_success(final_agg, agg).await?; Ok(()) } @@ -460,11 +505,12 @@ mod tests { async fn test_count_partial_with_nulls_indirect_child() -> Result<()> { let source = mock_data()?; let schema = source.schema(); + let agg = TestAggregate::new_count_column(&schema); let partial_agg = AggregateExec::try_new( AggregateMode::Partial, vec![], - vec![count_expr(Some(&schema), Some("a"))], + vec![agg.count_expr()], source, Arc::clone(&schema), )?; @@ -475,12 +521,12 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, vec![], - vec![count_expr(Some(&schema), Some("a"))], + vec![agg.count_expr()], Arc::new(coalesce), Arc::clone(&schema), )?; - assert_count_optim_success(final_agg, true).await?; + assert_count_optim_success(final_agg, agg).await?; Ok(()) } @@ -489,6 +535,7 @@ mod tests { async fn test_count_inexact_stat() -> Result<()> { let source = mock_data()?; let schema = source.schema(); + let agg = TestAggregate::new_count_star(); // adding a filter makes the statistics inexact let filter = Arc::new(FilterExec::try_new( @@ -504,7 +551,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, vec![], - vec![count_expr(None, None)], + vec![agg.count_expr()], filter, Arc::clone(&schema), )?; @@ -512,7 +559,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, vec![], - vec![count_expr(None, None)], + vec![agg.count_expr()], Arc::new(partial_agg), Arc::clone(&schema), )?; @@ -531,6 +578,7 @@ mod tests { async fn test_count_with_nulls_inexact_stat() -> Result<()> { let source = mock_data()?; let schema = source.schema(); + let agg = 
TestAggregate::new_count_column(&schema); // adding a filter makes the statistics inexact let filter = Arc::new(FilterExec::try_new( @@ -546,7 +594,7 @@ mod tests { let partial_agg = AggregateExec::try_new( AggregateMode::Partial, vec![], - vec![count_expr(Some(&schema), Some("a"))], + vec![agg.count_expr()], filter, Arc::clone(&schema), )?; @@ -554,7 +602,7 @@ mod tests { let final_agg = AggregateExec::try_new( AggregateMode::Final, vec![], - vec![count_expr(Some(&schema), Some("a"))], + vec![agg.count_expr()], Arc::new(partial_agg), Arc::clone(&schema), )?; From 171c89901ecdadca6c2eccb2973bc7ad0990c92f Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 2 Jun 2022 15:07:55 -0400 Subject: [PATCH 5/6] Simplify how the verification is done --- .../aggregate_statistics.rs | 72 +++++++------------ 1 file changed, 27 insertions(+), 45 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index 5c09f3ca15630..f34d67849a5ca 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -305,15 +305,37 @@ mod tests { let plan = Arc::new(plan) as _; let optimized = AggregateStatistics::new().optimize(Arc::clone(&plan), &conf)?; - let expected_schema = Arc::new(Schema::new(vec![agg.field()])); // A ProjectionExec is a sign that the count optimization was applied assert!(optimized.as_any().is::()); - let task_ctx = session_ctx.task_ctx(); - let result = common::collect(optimized.execute(0, task_ctx)?).await?; - assert_eq!(result[0].schema(), expected_schema); + + // run both the optimized and nonoptimized plan + let optimized_result = common::collect(optimized.execute(0, session_ctx.task_ctx())?).await?; + let nonoptimized_result = common::collect(plan.execute(0, session_ctx.task_ctx())?).await?; + assert_eq!(optimized_result.len(), nonoptimized_result.len()); + + // and validate the results 
are the same and expected + assert_eq!(optimized_result.len(), 1); + check_batch(optimized_result.into_iter().next().unwrap(), &agg); + // check the non optimized one too to ensure types and names remain the same + assert_eq!(nonoptimized_result.len(), 1); + check_batch(nonoptimized_result.into_iter().next().unwrap(), &agg); + + Ok(()) + } + + fn check_batch(batch: RecordBatch, agg: &TestAggregate){ + let schema = batch.schema(); + let fields = schema.fields(); + assert_eq!(fields.len(), 1); + + let field = &fields[0]; + assert_eq!(field.name(), agg.column_name()); + assert_eq!(field.data_type(), &DataType::Int64); + // note that nullability differs + assert_eq!( - result[0] + batch .column(0) .as_any() .downcast_ref::() @@ -321,41 +343,6 @@ mod tests { .values(), &[agg.expected_count()] ); - - // Validate that the optimized plan returns the exact same - // answer (both schema and data) as the original plan - let task_ctx = session_ctx.task_ctx(); - let plan_result = common::collect(plan.execute(0, task_ctx)?).await?; - assert_eq!(normalize(result), normalize(plan_result)); - Ok(()) - } - - /// Normalize record batches for comparison: - /// 1. Sets nullable to `true` - fn normalize(batches: Vec) -> Vec { - let schema = normalize_schema(&batches[0].schema()); - batches - .into_iter() - .map(|batch| { - RecordBatch::try_new(schema.clone(), batch.columns().to_vec()) - .expect("Error creating record batch") - }) - .collect() - } - fn normalize_schema(schema: &Schema) -> Arc { - let nullable = true; - let normalized_fields = schema - .fields() - .iter() - .map(|f| { - Field::new(f.name(), f.data_type().clone(), nullable) - .with_metadata(f.metadata().cloned()) - }) - .collect(); - Arc::new(Schema::new_with_metadata( - normalized_fields, - schema.metadata().clone(), - )) } /// Describe the type of aggregate being tested @@ -401,11 +388,6 @@ mod tests { } } - /// What is the output Field this aggregate would produce? 
- fn field(&self) -> Field { - Field::new(self.column_name(), DataType::Int64, false) - } - /// What is the expected count? fn expected_count(&self) -> i64 { match self { From fde3cc4a2c45252d78c3caecd31e377ab35ab3fb Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 2 Jun 2022 15:08:15 -0400 Subject: [PATCH 6/6] fmt --- .../core/src/physical_optimizer/aggregate_statistics.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs index f34d67849a5ca..bcf4fec071d47 100644 --- a/datafusion/core/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/core/src/physical_optimizer/aggregate_statistics.rs @@ -305,13 +305,14 @@ mod tests { let plan = Arc::new(plan) as _; let optimized = AggregateStatistics::new().optimize(Arc::clone(&plan), &conf)?; - // A ProjectionExec is a sign that the count optimization was applied assert!(optimized.as_any().is::()); // run both the optimized and nonoptimized plan - let optimized_result = common::collect(optimized.execute(0, session_ctx.task_ctx())?).await?; - let nonoptimized_result = common::collect(plan.execute(0, session_ctx.task_ctx())?).await?; + let optimized_result = + common::collect(optimized.execute(0, session_ctx.task_ctx())?).await?; + let nonoptimized_result = + common::collect(plan.execute(0, session_ctx.task_ctx())?).await?; assert_eq!(optimized_result.len(), nonoptimized_result.len()); // and validate the results are the same and expected @@ -324,7 +325,7 @@ mod tests { Ok(()) } - fn check_batch(batch: RecordBatch, agg: &TestAggregate){ + fn check_batch(batch: RecordBatch, agg: &TestAggregate) { let schema = batch.schema(); let fields = schema.fields(); assert_eq!(fields.len(), 1);