From 473b6a7bac6e660463b97626c6a7288659cbd664 Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Sat, 28 May 2022 20:19:44 +0800 Subject: [PATCH] change result type of count/count_distinct to int64 --- .../optimizer/single_distinct_to_groupby.rs | 20 ++++---- .../core/src/physical_plan/windows/mod.rs | 2 +- datafusion/core/tests/path_partition.rs | 22 ++++----- .../core/tests/provider_filter_pushdown.rs | 8 ++-- datafusion/expr/src/aggregate_function.rs | 3 +- datafusion/expr/src/window_function.rs | 4 +- .../physical-expr/src/aggregate/build_in.rs | 10 ++-- .../physical-expr/src/aggregate/count.rs | 46 +++++++++---------- .../src/aggregate/count_distinct.rs | 30 ++++++------ 9 files changed, 71 insertions(+), 74 deletions(-) diff --git a/datafusion/core/src/optimizer/single_distinct_to_groupby.rs b/datafusion/core/src/optimizer/single_distinct_to_groupby.rs index 65ff565562341..1748f9af624f9 100644 --- a/datafusion/core/src/optimizer/single_distinct_to_groupby.rs +++ b/datafusion/core/src/optimizer/single_distinct_to_groupby.rs @@ -238,8 +238,8 @@ mod tests { .build()?; // Should work - let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):UInt64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):UInt64;N]\ + let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT test.b) [COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):Int64;N]\ \n Aggregate: groupBy=[[#test.b AS alias1]], aggr=[[]] [alias1:UInt32]\ \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; @@ -255,8 +255,8 @@ mod tests { .aggregate(Vec::::new(), vec![count_distinct(lit(2) * col("b"))])? .build()?; - let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b) [COUNT(DISTINCT Int32(2) * test.b):UInt64;N]\ - \n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):UInt64;N]\ + let expected = "Projection: #COUNT(alias1) AS COUNT(DISTINCT Int32(2) * test.b) [COUNT(DISTINCT Int32(2) * test.b):Int64;N]\ + \n Aggregate: groupBy=[[]], aggr=[[COUNT(#alias1)]] [COUNT(alias1):Int64;N]\ \n Aggregate: groupBy=[[Int32(2) * #test.b AS alias1]], aggr=[[]] [alias1:Int32]\ \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; @@ -273,8 +273,8 @@ mod tests { .build()?; // Should work - let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N]\ - \n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1)]] [a:UInt32, COUNT(alias1):UInt64;N]\ + let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N]\ + \n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1)]] [a:UInt32, COUNT(alias1):Int64;N]\ \n Aggregate: groupBy=[[#test.a, #test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; @@ -294,7 +294,7 @@ mod tests { .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(DISTINCT #test.b), COUNT(DISTINCT #test.c)]] [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, COUNT(DISTINCT test.c):UInt64;N]\ + let expected = "Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(DISTINCT #test.b), COUNT(DISTINCT #test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(DISTINCT test.c):Int64;N]\ \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_eq(&plan, expected); @@ -319,8 +319,8 @@ mod tests { )? .build()?; // Should work - let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b), #MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, MAX(DISTINCT test.b):UInt32;N]\ - \n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1), MAX(#alias1)]] [a:UInt32, COUNT(alias1):UInt64;N, MAX(alias1):UInt32;N]\ + let expected = "Projection: #test.a AS a, #COUNT(alias1) AS COUNT(DISTINCT test.b), #MAX(alias1) AS MAX(DISTINCT test.b) [a:UInt32, COUNT(DISTINCT test.b):Int64;N, MAX(DISTINCT test.b):UInt32;N]\ + \n Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(#alias1), MAX(#alias1)]] [a:UInt32, COUNT(alias1):Int64;N, MAX(alias1):UInt32;N]\ \n Aggregate: groupBy=[[#test.a, #test.b AS alias1]], aggr=[[]] [a:UInt32, alias1:UInt32]\ \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; @@ -340,7 +340,7 @@ mod tests { .build()?; // Do nothing - let expected = "Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(DISTINCT #test.b), COUNT(#test.c)]] [a:UInt32, COUNT(DISTINCT test.b):UInt64;N, COUNT(test.c):UInt64;N]\ + let expected = "Aggregate: groupBy=[[#test.a]], aggr=[[COUNT(DISTINCT #test.b), COUNT(#test.c)]] [a:UInt32, COUNT(DISTINCT test.b):Int64;N, COUNT(test.c):Int64;N]\ \n TableScan: test projection=None [a:UInt32, b:UInt32, c:UInt32]"; assert_optimized_plan_eq(&plan, expected); diff --git a/datafusion/core/src/physical_plan/windows/mod.rs b/datafusion/core/src/physical_plan/windows/mod.rs index 496b4a6ef7f0b..26cb14fe33a96 100644 --- a/datafusion/core/src/physical_plan/windows/mod.rs +++ b/datafusion/core/src/physical_plan/windows/mod.rs @@ -219,7 +219,7 @@ mod tests { // c3 is small int - let count: &UInt64Array = as_primitive_array(&columns[0]); + let count: &Int64Array = as_primitive_array(&columns[0]); assert_eq!(count.value(0), 100); assert_eq!(count.value(99), 100); diff --git a/datafusion/core/tests/path_partition.rs b/datafusion/core/tests/path_partition.rs index 4297b12b60473..873554747f1e9 100644 --- a/datafusion/core/tests/path_partition.rs +++ b/datafusion/core/tests/path_partition.rs @@ -106,8 +106,8 @@ async fn parquet_distinct_partition_col() -> Result<()> { .await?; let mut max_limit = match ScalarValue::try_from_array(results[0].column(0), 0)? { - ScalarValue::UInt64(Some(count)) => count, - s => panic!("Expected count as Int64 found {}", s), + ScalarValue::Int64(Some(count)) => count, + s => panic!("Expected count as Int64 found {}", s.get_datatype()), }; max_limit += 1; @@ -117,40 +117,40 @@ async fn parquet_distinct_partition_col() -> Result<()> { let last_row_idx = last_batch.num_rows() - 1; let mut min_limit = match ScalarValue::try_from_array(last_batch.column(0), last_row_idx)? { - ScalarValue::UInt64(Some(count)) => count, - s => panic!("Expected count as Int64 found {}", s), + ScalarValue::Int64(Some(count)) => count, + s => panic!("Expected count as Int64 found {}", s.get_datatype()), }; min_limit -= 1; let sql_cross_partition_boundary = format!("SELECT month FROM t limit {}", max_limit); - let resulting_limit: u64 = ctx + let resulting_limit: i64 = ctx .sql(sql_cross_partition_boundary.as_str()) .await? .collect() .await? .into_iter() - .map(|r| r.num_rows() as u64) + .map(|r| r.num_rows() as i64) .sum(); assert_eq!(max_limit, resulting_limit); let sql_within_partition_boundary = format!("SELECT month from t limit {}", min_limit); - let resulting_limit: u64 = ctx + let resulting_limit: i64 = ctx .sql(sql_within_partition_boundary.as_str()) .await? .collect() .await? .into_iter() - .map(|r| r.num_rows() as u64) + .map(|r| r.num_rows() as i64) .sum(); assert_eq!(min_limit, resulting_limit); let month = match ScalarValue::try_from_array(results[0].column(1), 0)? { ScalarValue::Utf8(Some(month)) => month, - s => panic!("Expected count as Int64 found {}", s), + s => panic!("Expected count as Int64 found {}", s.get_datatype()), }; let sql_on_partition_boundary = format!( @@ -158,13 +158,13 @@ async fn parquet_distinct_partition_col() -> Result<()> { month, max_limit - 1 ); - let resulting_limit: u64 = ctx + let resulting_limit: i64 = ctx .sql(sql_on_partition_boundary.as_str()) .await? .collect() .await? .into_iter() - .map(|r| r.num_rows() as u64) + .map(|r| r.num_rows() as i64) .sum(); let partition_row_count = max_limit - 1; assert_eq!(partition_row_count, resulting_limit); diff --git a/datafusion/core/tests/provider_filter_pushdown.rs b/datafusion/core/tests/provider_filter_pushdown.rs index c8fe483ea9f46..79c71afb341a7 100644 --- a/datafusion/core/tests/provider_filter_pushdown.rs +++ b/datafusion/core/tests/provider_filter_pushdown.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -use arrow::array::{as_primitive_array, Int32Builder, UInt64Array}; +use arrow::array::{as_primitive_array, Int32Builder, Int64Array}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use arrow::record_batch::RecordBatch; use async_trait::async_trait; @@ -170,7 +170,7 @@ impl TableProvider for CustomProvider { } } -async fn assert_provider_row_count(value: i64, expected_count: u64) -> Result<()> { +async fn assert_provider_row_count(value: i64, expected_count: i64) -> Result<()> { let provider = CustomProvider { zero_batch: create_batch(0, 10)?, one_batch: create_batch(1, 5)?, @@ -183,7 +183,7 @@ async fn assert_provider_row_count(value: i64, expected_count: u64) -> Result<() .aggregate(vec![], vec![count(col("flag"))])?; let results = df.collect().await?; - let result_col: &UInt64Array = as_primitive_array(results[0].column(0)); + let result_col: &Int64Array = as_primitive_array(results[0].column(0)); assert_eq!(result_col.value(0), expected_count); ctx.register_table("data", Arc::new(provider))?; @@ -193,7 +193,7 @@ async fn assert_provider_row_count(value: i64, expected_count: u64) -> Result<() .collect() .await?; - let sql_result_col: &UInt64Array = as_primitive_array(sql_results[0].column(0)); + let sql_result_col: &Int64Array = as_primitive_array(sql_results[0].column(0)); assert_eq!(sql_result_col.value(0), expected_count); Ok(()) diff --git a/datafusion/expr/src/aggregate_function.rs b/datafusion/expr/src/aggregate_function.rs index eacb3f74a8644..fb9f89691624b 100644 --- a/datafusion/expr/src/aggregate_function.rs +++ b/datafusion/expr/src/aggregate_function.rs @@ -146,9 +146,8 @@ pub fn return_type( let coerced_data_types = coerce_types(fun, input_expr_types, &signature(fun))?; match fun { - // TODO If the datafusion is compatible with PostgreSQL, the returned data type should be INT64. AggregateFunction::Count | AggregateFunction::ApproxDistinct => { - Ok(DataType::UInt64) + Ok(DataType::Int64) } AggregateFunction::Max | AggregateFunction::Min => { // For min and max agg function, the returned type is same as input type. diff --git a/datafusion/expr/src/window_function.rs b/datafusion/expr/src/window_function.rs index e3daa0922f5ba..414f4bf6f97a5 100644 --- a/datafusion/expr/src/window_function.rs +++ b/datafusion/expr/src/window_function.rs @@ -227,10 +227,10 @@ mod tests { fn test_count_return_type() -> Result<()> { let fun = WindowFunction::from_str("count")?; let observed = return_type(&fun, &[DataType::Utf8])?; - assert_eq!(DataType::UInt64, observed); + assert_eq!(DataType::Int64, observed); let observed = return_type(&fun, &[DataType::UInt64])?; - assert_eq!(DataType::UInt64, observed); + assert_eq!(DataType::Int64, observed); Ok(()) } diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index 6d5dfc75633cd..23d2a84d132ea 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -304,7 +304,7 @@ mod tests { assert!(result_agg_phy_exprs.as_any().is::()); assert_eq!("c1", result_agg_phy_exprs.name()); assert_eq!( - Field::new("c1", DataType::UInt64, true), + Field::new("c1", DataType::Int64, true), result_agg_phy_exprs.field().unwrap() ); } @@ -347,7 +347,7 @@ mod tests { assert!(result_distinct.as_any().is::()); assert_eq!("c1", result_distinct.name()); assert_eq!( - Field::new("c1", DataType::UInt64, true), + Field::new("c1", DataType::Int64, true), result_distinct.field().unwrap() ); } @@ -954,14 +954,14 @@ mod tests { #[test] fn test_count_return_type() -> Result<()> { let observed = return_type(&AggregateFunction::Count, &[DataType::Utf8])?; - assert_eq!(DataType::UInt64, observed); + assert_eq!(DataType::Int64, observed); let observed = return_type(&AggregateFunction::Count, &[DataType::Int8])?; - assert_eq!(DataType::UInt64, observed); + assert_eq!(DataType::Int64, observed); let observed = return_type(&AggregateFunction::Count, &[DataType::Decimal(28, 13)])?; - assert_eq!(DataType::UInt64, observed); + assert_eq!(DataType::Int64, observed); Ok(()) } diff --git a/datafusion/physical-expr/src/aggregate/count.rs b/datafusion/physical-expr/src/aggregate/count.rs index 54bec05d72f0a..2b02d03b51f6c 100644 --- a/datafusion/physical-expr/src/aggregate/count.rs +++ b/datafusion/physical-expr/src/aggregate/count.rs @@ -23,12 +23,10 @@ use std::sync::Arc; use crate::aggregate::row_accumulator::RowAccumulator; use crate::{AggregateExpr, PhysicalExpr}; +use arrow::array::Int64Array; use arrow::compute; use arrow::datatypes::DataType; -use arrow::{ - array::{ArrayRef, UInt64Array}, - datatypes::Field, -}; +use arrow::{array::ArrayRef, datatypes::Field}; use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::Accumulator; @@ -110,7 +108,7 @@ impl AggregateExpr for Count { #[derive(Debug)] struct CountAccumulator { - count: u64, + count: i64, } impl CountAccumulator { @@ -123,12 +121,12 @@ impl CountAccumulator { impl Accumulator for CountAccumulator { fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { let array = &values[0]; - self.count += (array.len() - array.data().null_count()) as u64; + self.count += (array.len() - array.data().null_count()) as i64; Ok(()) } fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - let counts = states[0].as_any().downcast_ref::().unwrap(); + let counts = states[0].as_any().downcast_ref::().unwrap(); let delta = &compute::sum(counts); if let Some(d) = delta { self.count += *d; @@ -137,11 +135,11 @@ impl Accumulator for CountAccumulator { } fn state(&self) -> Result> { - Ok(vec![ScalarValue::UInt64(Some(self.count))]) + Ok(vec![ScalarValue::Int64(Some(self.count))]) } fn evaluate(&self) -> Result { - Ok(ScalarValue::UInt64(Some(self.count))) + Ok(ScalarValue::Int64(Some(self.count))) } } @@ -173,16 +171,16 @@ impl RowAccumulator for CountRowAccumulator { states: &[ArrayRef], accessor: &mut RowAccessor, ) -> Result<()> { - let counts = states[0].as_any().downcast_ref::().unwrap(); + let counts = states[0].as_any().downcast_ref::().unwrap(); let delta = &compute::sum(counts); if let Some(d) = delta { - accessor.add_u64(self.state_index, *d); + accessor.add_i64(self.state_index, *d); } Ok(()) } fn evaluate(&self, accessor: &RowAccessor) -> Result { - Ok(accessor.get_as_scalar(&DataType::UInt64, self.state_index)) + Ok(accessor.get_as_scalar(&DataType::Int64, self.state_index)) } #[inline(always)] @@ -208,8 +206,8 @@ mod tests { a, DataType::Int32, Count, - ScalarValue::from(5u64), - DataType::UInt64 + ScalarValue::from(5i64), + DataType::Int64 ) } @@ -227,8 +225,8 @@ mod tests { a, DataType::Int32, Count, - ScalarValue::from(3u64), - DataType::UInt64 + ScalarValue::from(3i64), + DataType::Int64 ) } @@ -241,8 +239,8 @@ mod tests { a, DataType::Boolean, Count, - ScalarValue::from(0u64), - DataType::UInt64 + ScalarValue::from(0i64), + DataType::Int64 ) } @@ -254,8 +252,8 @@ mod tests { a, DataType::Boolean, Count, - ScalarValue::from(0u64), - DataType::UInt64 + ScalarValue::from(0i64), + DataType::Int64 ) } @@ -267,8 +265,8 @@ mod tests { a, DataType::Utf8, Count, - ScalarValue::from(5u64), - DataType::UInt64 + ScalarValue::from(5i64), + DataType::Int64 ) } @@ -280,8 +278,8 @@ mod tests { a, DataType::LargeUtf8, Count, - ScalarValue::from(5u64), - DataType::UInt64 + ScalarValue::from(5i64), + DataType::Int64 ) } } diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index f1e3afe6b041b..e7a8b8c5663f3 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -218,7 +218,7 @@ impl Accumulator for DistinctCountAccumulator { fn evaluate(&self) -> Result { match &self.count_data_type { - DataType::UInt64 => Ok(ScalarValue::UInt64(Some(self.values.len() as u64))), + DataType::Int64 => Ok(ScalarValue::Int64(Some(self.values.len() as i64))), t => Err(DataFusionError::Internal(format!( "Invalid data type {:?} for count distinct aggregation", t @@ -317,7 +317,7 @@ mod tests { assert_eq!(states.len(), 1); assert_eq!(state_vec, vec![Some(1), Some(2), Some(3)]); - assert_eq!(result, ScalarValue::UInt64(Some(3))); + assert_eq!(result, ScalarValue::Int64(Some(3))); Ok(()) }}; @@ -344,7 +344,7 @@ mod tests { .collect::>(), vec![], String::from("__col_name__"), - DataType::UInt64, + DataType::Int64, ); let mut accum = agg.create_accumulator()?; @@ -361,7 +361,7 @@ mod tests { data_types.to_vec(), vec![], String::from("__col_name__"), - DataType::UInt64, + DataType::Int64, ); let mut accum = agg.create_accumulator()?; @@ -393,7 +393,7 @@ mod tests { .collect::>(), vec![], String::from("__col_name__"), - DataType::UInt64, + DataType::Int64, ); let mut accum = agg.create_accumulator()?; @@ -466,7 +466,7 @@ mod tests { ] ); assert!(state_vec[nan_idx].unwrap_or_default().is_nan()); - assert_eq!(result, ScalarValue::UInt64(Some(8))); + assert_eq!(result, ScalarValue::Int64(Some(8))); Ok(()) }}; @@ -524,17 +524,17 @@ mod tests { #[test] fn count_distinct_update_batch_boolean() -> Result<()> { - let get_count = |data: BooleanArray| -> Result<(Vec>, u64)> { + let get_count = |data: BooleanArray| -> Result<(Vec>, i64)> { let arrays = vec![Arc::new(data) as ArrayRef]; let (states, result) = run_update_batch(&arrays)?; let mut state_vec = state_to_vec!(&states[0], Boolean, bool).unwrap(); state_vec.sort(); let count = match result { - ScalarValue::UInt64(c) => c.ok_or_else(|| { + ScalarValue::Int64(c) => c.ok_or_else(|| { DataFusionError::Internal("Found None count".to_string()) }), scalar => Err(DataFusionError::Internal(format!( - "Found non Uint64 scalar value from count: {}", + "Found non int64 scalar value from count: {}", scalar ))), }?; @@ -587,7 +587,7 @@ mod tests { assert_eq!(states.len(), 1); assert_eq!(state_to_vec!(&states[0], Int32, i32), Some(vec![])); - assert_eq!(result, ScalarValue::UInt64(Some(0))); + assert_eq!(result, ScalarValue::Int64(Some(0))); Ok(()) } @@ -600,7 +600,7 @@ mod tests { assert_eq!(states.len(), 1); assert_eq!(state_to_vec!(&states[0], Int32, i32), Some(vec![])); - assert_eq!(result, ScalarValue::UInt64(Some(0))); + assert_eq!(result, ScalarValue::Int64(Some(0))); Ok(()) } @@ -623,7 +623,7 @@ mod tests { vec![(Some(1_i8), Some(3_i16)), (Some(2_i8), Some(4_i16))] ); - assert_eq!(result, ScalarValue::UInt64(Some(2))); + assert_eq!(result, ScalarValue::Int64(Some(2))); Ok(()) } @@ -658,7 +658,7 @@ mod tests { (Some(5_i32), Some(1_u64)), ] ); - assert_eq!(result, ScalarValue::UInt64(Some(5))); + assert_eq!(result, ScalarValue::Int64(Some(5))); Ok(()) } @@ -690,7 +690,7 @@ mod tests { vec![(Some(-2_i32), Some(5_u64)), (Some(-1_i32), Some(5_u64))] ); - assert_eq!(result, ScalarValue::UInt64(Some(2))); + assert_eq!(result, ScalarValue::Int64(Some(2))); Ok(()) } @@ -730,7 +730,7 @@ mod tests { ] ); - assert_eq!(result, ScalarValue::UInt64(Some(5))); + assert_eq!(result, ScalarValue::Int64(Some(5))); Ok(()) }