From e0d14c4b3429f08d9470aa4257df3e4e8832910f Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Tue, 23 Nov 2021 13:45:37 +0800 Subject: [PATCH 1/2] support min,max for decimal data type --- .../examples/csv_decimal_sql.rs | 54 ++++ datafusion/Cargo.toml | 6 +- datafusion/src/execution/context.rs | 25 ++ datafusion/src/physical_plan/aggregates.rs | 246 +++++++++++++++--- .../coercion_rule/aggregate_rule.rs | 175 +++++++++++++ .../src/physical_plan/coercion_rule/mod.rs | 19 ++ .../physical_plan/expressions/array_agg.rs | 4 + .../src/physical_plan/expressions/average.rs | 13 + .../src/physical_plan/expressions/coercion.rs | 87 ++++++- .../src/physical_plan/expressions/min_max.rs | 36 +++ .../src/physical_plan/expressions/mod.rs | 2 + .../src/physical_plan/expressions/sum.rs | 17 +- datafusion/src/physical_plan/functions.rs | 1 + datafusion/src/physical_plan/mod.rs | 1 + datafusion/src/physical_plan/type_coercion.rs | 18 +- datafusion/src/sql/planner.rs | 31 ++- datafusion/tests/sql.rs | 134 ++++++++++ 17 files changed, 814 insertions(+), 55 deletions(-) create mode 100644 datafusion-examples/examples/csv_decimal_sql.rs create mode 100644 datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs create mode 100644 datafusion/src/physical_plan/coercion_rule/mod.rs diff --git a/datafusion-examples/examples/csv_decimal_sql.rs b/datafusion-examples/examples/csv_decimal_sql.rs new file mode 100644 index 0000000000000..e419d62a480ce --- /dev/null +++ b/datafusion-examples/examples/csv_decimal_sql.rs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use datafusion::arrow::datatypes::{DataType, Field, Schema}; +use datafusion::error::Result; +use datafusion::prelude::*; +use std::sync::Arc; + +/// This example demonstrates executing a simple query against an Arrow data source (CSV) and +/// fetching results +#[tokio::main] +async fn main() -> Result<()> { + // create local execution context + let mut ctx = ExecutionContext::new(); + + let testdata = datafusion::test_util::arrow_test_data(); + + // schema with decimal type + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Decimal(10, 6), false), + Field::new("c2", DataType::Float64, false), + Field::new("c3", DataType::Boolean, false), + ])); + + // register csv file with the execution context + ctx.register_csv( + "aggregate_simple", + &format!("{}/csv/aggregate_simple.csv", testdata), + CsvReadOptions::new().schema(&schema), + ) + .await?; + + // execute the query + let df = ctx.sql("select c1 from aggregate_simple").await?; + + // print the results + df.show().await?; + + Ok(()) +} diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 390dab43160de..c76899a95dccf 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -52,8 +52,10 @@ avro = ["avro-rs", "num-traits"] [dependencies] ahash = "0.7" hashbrown = { version = "0.11", features = ["raw"] } -arrow = { version = "6.2.0", features = ["prettyprint"] } -parquet = { version = "6.2.0", features = ["arrow"] } +arrow = { path = "/Users/kliu3/Documents/github/arrow-rs/arrow", features = ["prettyprint"] } +#arrow = { version = "6.2.0", features = ["prettyprint"] } +#parquet = { version = "6.2.0", features = ["arrow"] } +parquet = { path = "/Users/kliu3/Documents/github/arrow-rs/parquet", features = ["arrow"] } sqlparser = "0.12" paste = "^1.0" num_cpus = "1.13.0" diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 5f77b2bfaa6bc..867dd30d126c8 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -3895,6 +3895,31 @@ mod tests { Ok(()) } + #[tokio::test] + async fn aggregate_decimal() -> Result<()> { + let mut ctx = ExecutionContext::new(); + // schema with data + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Decimal(10, 6), false), + Field::new("c2", DataType::Float64, false), + Field::new("c3", DataType::Boolean, false), + ])); + + ctx.register_csv( + "aggregate_simple", + "tests/aggregate_simple.csv", + CsvReadOptions::new().schema(&schema), + ) + .await?; + + // decimal query + let result = plan_and_collect(&mut ctx, "select min(c1) from aggregate_simple") + .await + .unwrap(); + println!("{:?}", result); + Ok(()) + } + #[tokio::test] async fn create_external_table_with_timestamps() { let mut ctx = ExecutionContext::new(); diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 0c99c4f99caf7..c7d507e6fbb55 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -28,15 +28,16 @@ use super::{ functions::{Signature, Volatility}, - type_coercion::{coerce, data_types}, Accumulator, AggregateExpr, PhysicalExpr, }; use crate::error::{DataFusionError, Result}; +use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_types}; use crate::physical_plan::distinct_expressions; use crate::physical_plan::expressions; use arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use expressions::{avg_return_type, sum_return_type}; use std::{fmt, str::FromStr, sync::Arc}; + /// the implementation of an aggregate function pub type AccumulatorFunctionImplementation = Arc Result> + Send + Sync>; @@ -87,96 +88,123 @@ impl FromStr for AggregateFunction { return Err(DataFusionError::Plan(format!( "There is no built-in function named {}", name - ))) + ))); } }) } } -/// Returns the datatype of the scalar function +/// Returns the datatype of the aggregate function. +/// This is used to get the returned data type for aggregate expr. pub fn return_type(fun: &AggregateFunction, arg_types: &[DataType]) -> Result { // Note that this function *must* return the same type that the respective physical expression returns // or the execution panics. - // verify that this is a valid set of data types for this function - data_types(arg_types, &signature(fun))?; + let coerced_data_types = coerce_types(fun, arg_types, &signature(fun))?; match fun { + // TODO If the datafusion is compatible with PostgreSQL, the returned data type should be INT64. AggregateFunction::Count | AggregateFunction::ApproxDistinct => { Ok(DataType::UInt64) } - AggregateFunction::Max | AggregateFunction::Min => Ok(arg_types[0].clone()), - AggregateFunction::Sum => sum_return_type(&arg_types[0]), - AggregateFunction::Avg => avg_return_type(&arg_types[0]), + AggregateFunction::Max | AggregateFunction::Min => { + // For min and max agg function, the returned type is same as input type. + // The coerced_data_types is same with input_types. + Ok(coerced_data_types[0].clone()) + } + AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]), + AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new( "item", - arg_types[0].clone(), + coerced_data_types[0].clone(), true, )))), } } -/// Create a physical (function) expression. -/// This function errors when `args`' can't be coerced to a valid argument type of the function. +/// Create a physical (aggregate) expression. +/// This function errors when `input_phy_exprs`' can't be coerced to a valid argument type of the aggregate function. pub fn create_aggregate_expr( fun: &AggregateFunction, distinct: bool, - args: &[Arc], + input_phy_exprs: &[Arc], input_schema: &Schema, name: impl Into, ) -> Result> { let name = name.into(); - let arg = coerce(args, input_schema, &signature(fun))?; - if arg.is_empty() { + // get the coerced phy exprs if some expr need try cast + let coerced_phy_exprs = + coerce_exprs(fun, input_phy_exprs, input_schema, &signature(fun))?; + if coerced_phy_exprs.is_empty() { return Err(DataFusionError::Plan(format!( "Invalid or wrong number of arguments passed to aggregate: '{}'", name, ))); } - let arg = arg[0].clone(); - - let arg_types = args + let coerced_types = coerced_phy_exprs .iter() .map(|e| e.data_type(input_schema)) .collect::>>()?; + let first_coerced_phy_expr = coerced_phy_exprs[0].clone(); - let return_type = return_type(fun, &arg_types)?; + // get the result data type for this aggregate function + let input_phy_types = input_phy_exprs + .iter() + .map(|e| e.data_type(input_schema)) + .collect::>>()?; + let return_type = return_type(fun, &input_phy_types)?; Ok(match (fun, distinct) { - (AggregateFunction::Count, false) => { - Arc::new(expressions::Count::new(arg, name, return_type)) - } + (AggregateFunction::Count, false) => Arc::new(expressions::Count::new( + first_coerced_phy_expr, + name, + return_type, + )), (AggregateFunction::Count, true) => { Arc::new(distinct_expressions::DistinctCount::new( - arg_types, - args.to_vec(), + coerced_types, + coerced_phy_exprs, name, return_type, )) } - (AggregateFunction::Sum, false) => { - Arc::new(expressions::Sum::new(arg, name, return_type)) - } + (AggregateFunction::Sum, false) => Arc::new(expressions::Sum::new( + first_coerced_phy_expr, + name, + return_type, + )), (AggregateFunction::Sum, true) => { return Err(DataFusionError::NotImplemented( "SUM(DISTINCT) aggregations are not available".to_string(), )); } - (AggregateFunction::ApproxDistinct, _) => Arc::new( - expressions::ApproxDistinct::new(arg, name, arg_types[0].clone()), - ), - (AggregateFunction::ArrayAgg, _) => { - Arc::new(expressions::ArrayAgg::new(arg, name, arg_types[0].clone())) - } - (AggregateFunction::Min, _) => { - Arc::new(expressions::Min::new(arg, name, return_type)) - } - (AggregateFunction::Max, _) => { - Arc::new(expressions::Max::new(arg, name, return_type)) - } - (AggregateFunction::Avg, false) => { - Arc::new(expressions::Avg::new(arg, name, return_type)) + (AggregateFunction::ApproxDistinct, _) => { + Arc::new(expressions::ApproxDistinct::new( + first_coerced_phy_expr, + name, + coerced_types[0].clone(), + )) } + (AggregateFunction::ArrayAgg, _) => Arc::new(expressions::ArrayAgg::new( + first_coerced_phy_expr, + name, + coerced_types[0].clone(), + )), + (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( + first_coerced_phy_expr, + name, + return_type, + )), + (AggregateFunction::Max, _) => Arc::new(expressions::Max::new( + first_coerced_phy_expr, + name, + return_type, + )), + (AggregateFunction::Avg, false) => Arc::new(expressions::Avg::new( + first_coerced_phy_expr, + name, + return_type, + )), (AggregateFunction::Avg, true) => { return Err(DataFusionError::NotImplemented( "AVG(DISTINCT) aggregations are not available".to_string(), @@ -236,6 +264,130 @@ pub fn signature(fun: &AggregateFunction) -> Signature { mod tests { use super::*; use crate::error::Result; + use crate::physical_plan::expressions::{ApproxDistinct, ArrayAgg, Count, Max, Min}; + + #[test] + fn test_count_arragg_approx_expr() -> Result<()> { + let funcs = vec![ + AggregateFunction::Count, + AggregateFunction::ArrayAgg, + AggregateFunction::ApproxDistinct, + ]; + let data_types = vec![ + DataType::UInt32, + DataType::Int32, + DataType::Float32, + DataType::Float64, + DataType::Decimal(10, 2), + DataType::Utf8, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + match fun { + AggregateFunction::Count => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::UInt64, true), + result_agg_phy_exprs.field().unwrap() + ); + } + AggregateFunction::ApproxDistinct => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::UInt64, false), + result_agg_phy_exprs.field().unwrap() + ); + } + AggregateFunction::ArrayAgg => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new( + "c1", + DataType::List(Box::new(Field::new( + "item", + data_type.clone(), + true + ))), + false + ), + result_agg_phy_exprs.field().unwrap() + ); + } + _ => {} + }; + } + } + Ok(()) + } + + #[test] + fn test_min_max_expr() -> Result<()> { + let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; + let data_types = vec![ + DataType::UInt32, + DataType::Int32, + DataType::Float32, + DataType::Float64, + DataType::Decimal(10, 2), + DataType::Utf8, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + match fun { + AggregateFunction::Min => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", data_type.clone(), true), + result_agg_phy_exprs.field().unwrap() + ); + } + AggregateFunction::Max => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", data_type.clone(), true), + result_agg_phy_exprs.field().unwrap() + ); + } + _ => {} + }; + } + } + Ok(()) + } + + #[test] + fn test_sum_avg_expr() -> Result<()> { + Ok(()) + } #[test] fn test_min_max() -> Result<()> { @@ -244,6 +396,16 @@ mod tests { let observed = return_type(&AggregateFunction::Max, &[DataType::Int32])?; assert_eq!(DataType::Int32, observed); + + // test decimal for min + let observed = return_type(&AggregateFunction::Min, &[DataType::Decimal(10, 6)])?; + assert_eq!(DataType::Decimal(10, 6), observed); + + // test decimal for max + let observed = + return_type(&AggregateFunction::Max, &[DataType::Decimal(28, 13)])?; + assert_eq!(DataType::Decimal(28, 13), observed); + Ok(()) } @@ -267,6 +429,10 @@ mod tests { let observed = return_type(&AggregateFunction::Count, &[DataType::Int8])?; assert_eq!(DataType::UInt64, observed); + + let observed = + return_type(&AggregateFunction::Count, &[DataType::Decimal(28, 13)])?; + assert_eq!(DataType::UInt64, observed); Ok(()) } diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs new file mode 100644 index 0000000000000..5508c34f3c2af --- /dev/null +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Support the coercion rule for aggregate function. + +use crate::arrow::datatypes::Schema; +use crate::error::{DataFusionError, Result}; +use crate::physical_plan::aggregates::AggregateFunction; +use crate::physical_plan::expressions::{ + is_avg_support_arg_type, is_sum_support_arg_type, try_cast, +}; +use crate::physical_plan::functions::{Signature, TypeSignature}; +use crate::physical_plan::PhysicalExpr; +use arrow::datatypes::DataType; +use std::sync::Arc; + +pub fn coerce_types( + agg_fun: &AggregateFunction, + input_types: &[DataType], + signature: &Signature, +) -> Result> { + match signature.type_signature { + TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => { + if input_types.len() != agg_count { + return Err(DataFusionError::Plan(format!("The function {:?} expect argument number is {:?}, but the input argument number is {:?}", + agg_fun, agg_count, input_types.len()))); + } + } + _ => { + return Err(DataFusionError::Plan(format!( + "The aggregate coercion rule don't support this {:?}", + signature + ))); + } + }; + match agg_fun { + AggregateFunction::Count | AggregateFunction::ApproxDistinct => { + Ok(input_types.to_vec()) + } + AggregateFunction::ArrayAgg => Ok(input_types.to_vec()), + AggregateFunction::Min | AggregateFunction::Max => Ok(input_types.to_vec()), + AggregateFunction::Sum => { + // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc + // smallint, int, bigint, real, double precision, decimal, or interval. + if !is_sum_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} do not support the {:?}", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + AggregateFunction::Avg => { + // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc + // smallint, int, bigint, real, double precision, decimal, or interval + if !is_avg_support_arg_type(&input_types[0]) { + return Err(DataFusionError::Plan(format!( + "The function {:?} do not support the {:?}", + agg_fun, input_types[0] + ))); + } + Ok(input_types.to_vec()) + } + } +} + +pub fn coerce_exprs( + agg_fun: &AggregateFunction, + input_exprs: &[Arc], + schema: &Schema, + signature: &Signature, +) -> Result>> { + if input_exprs.is_empty() { + return Ok(vec![]); + } + let input_types = input_exprs + .iter() + .map(|e| e.data_type(schema)) + .collect::>>()?; + + // get the coerced data types + let coerced_types = coerce_types(agg_fun, &input_types, signature)?; + + // try cast if need + input_exprs + .iter() + .enumerate() + .map(|(i, expr)| try_cast(expr.clone(), schema, coerced_types[i].clone())) + .collect::>>() +} + +#[cfg(test)] +mod tests { + use crate::physical_plan::aggregates; + use crate::physical_plan::aggregates::{signature, AggregateFunction}; + use crate::physical_plan::coercion_rule::aggregate_rule::coerce_types; + use arrow::datatypes::DataType; + + #[test] + fn test_aggregate_coerce_types() { + // test input args with error number input types + let fun = AggregateFunction::Min; + let input_types = vec![DataType::Int64, DataType::Int32]; + let signature = signature(&fun); + let result = coerce_types(&fun, &input_types, &signature); + assert_eq!("Error during planning: The function Min expect argument number is 1, but the input argument number is 2", result.unwrap_err().to_string()); + + // test input args is invalid data type for sum or avg + let fun = AggregateFunction::Sum; + let input_types = vec![DataType::Utf8]; + let signature = aggregates::signature(&fun); + let result = coerce_types(&fun, &input_types, &signature); + assert_eq!( + "Error during planning: The function Sum do not support the Utf8", + result.unwrap_err().to_string() + ); + let fun = AggregateFunction::Avg; + let signature = aggregates::signature(&fun); + let result = coerce_types(&fun, &input_types, &signature); + assert_eq!( + "Error during planning: The function Avg do not support the Utf8", + result.unwrap_err().to_string() + ); + + // test count, array_agg, approx_distinct, min, max. + // the coerced types is same with input types + let funs = vec![ + AggregateFunction::Count, + AggregateFunction::ArrayAgg, + AggregateFunction::ApproxDistinct, + AggregateFunction::Min, + AggregateFunction::Max, + ]; + let input_types = vec![ + vec![DataType::Int32], + vec![DataType::Decimal(10, 2)], + vec![DataType::Utf8], + ]; + for fun in funs { + for input_type in &input_types { + let signature = aggregates::signature(&fun); + let result = coerce_types(&fun, input_type, &signature); + assert_eq!(*input_type, result.unwrap()); + } + } + // test sum, avg + let funs = vec![AggregateFunction::Sum, AggregateFunction::Avg]; + let input_types = vec![ + vec![DataType::Int32], + vec![DataType::Float32], + vec![DataType::Decimal(20, 3)], + ]; + for fun in funs { + for input_type in &input_types { + let signature = aggregates::signature(&fun); + let result = coerce_types(&fun, input_type, &signature); + assert_eq!(*input_type, result.unwrap()); + } + } + } +} diff --git a/datafusion/src/physical_plan/coercion_rule/mod.rs b/datafusion/src/physical_plan/coercion_rule/mod.rs new file mode 100644 index 0000000000000..8d07b10bfe23d --- /dev/null +++ b/datafusion/src/physical_plan/coercion_rule/mod.rs @@ -0,0 +1,19 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! define the coercion rule for different Expr type +pub(crate) mod aggregate_rule; diff --git a/datafusion/src/physical_plan/expressions/array_agg.rs b/datafusion/src/physical_plan/expressions/array_agg.rs index 213b392f627b9..3139c874004b9 100644 --- a/datafusion/src/physical_plan/expressions/array_agg.rs +++ b/datafusion/src/physical_plan/expressions/array_agg.rs @@ -86,6 +86,10 @@ impl AggregateExpr for ArrayAgg { fn expressions(&self) -> Vec> { vec![self.expr.clone()] } + + fn name(&self) -> &str { + &self.name + } } #[derive(Debug)] diff --git a/datafusion/src/physical_plan/expressions/average.rs b/datafusion/src/physical_plan/expressions/average.rs index 2e218191f6683..9489d30e49064 100644 --- a/datafusion/src/physical_plan/expressions/average.rs +++ b/datafusion/src/physical_plan/expressions/average.rs @@ -62,6 +62,19 @@ pub fn avg_return_type(arg_type: &DataType) -> Result { } } +pub(crate) fn is_avg_support_arg_type(arg_type: &DataType) -> bool { + match arg_type { + // TODO: do we need to support the unsigned data type? + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => true, + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => true, + DataType::Float16 | DataType::Float32 | DataType::Float64 => true, + // TODO support the decimal data type + DataType::Decimal(_, _) => true, + // TODO support the interva + _ => false, + } +} + impl Avg { /// Create a new AVG aggregate function pub fn new( diff --git a/datafusion/src/physical_plan/expressions/coercion.rs b/datafusion/src/physical_plan/expressions/coercion.rs index 180b16548b32b..b8d8d62714d46 100644 --- a/datafusion/src/physical_plan/expressions/coercion.rs +++ b/datafusion/src/physical_plan/expressions/coercion.rs @@ -20,7 +20,7 @@ use arrow::datatypes::DataType; /// Determine if a DataType is signed numeric or not -pub fn is_signed_numeric(dt: &DataType) -> bool { +pub(crate) fn is_signed_numeric(dt: &DataType) -> bool { matches!( dt, DataType::Int8 @@ -29,12 +29,13 @@ pub fn is_signed_numeric(dt: &DataType) -> bool { | DataType::Int64 | DataType::Float16 | DataType::Float32 - | DataType::Float64 + | DataType::Float64 // TODO liukun4515 + // | DataType::Decimal(_,_) ) } /// Determine if a DataType is numeric or not -pub fn is_numeric(dt: &DataType) -> bool { +fn is_numeric(dt: &DataType) -> bool { is_signed_numeric(dt) || match dt { DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { @@ -125,6 +126,11 @@ pub fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option 0 { + let this_type = &test_types[index - 1]; + for i in 0..index { + assert_eq!( + Some(this_type.clone()), + numerical_coercion(this_type, &test_types[i]) + ); + } + index -= 1; + } + } } diff --git a/datafusion/src/physical_plan/expressions/min_max.rs b/datafusion/src/physical_plan/expressions/min_max.rs index 9e5b1e095cd6f..e445ab31f4ea1 100644 --- a/datafusion/src/physical_plan/expressions/min_max.rs +++ b/datafusion/src/physical_plan/expressions/min_max.rs @@ -487,6 +487,42 @@ mod tests { use arrow::datatypes::*; use arrow::record_batch::RecordBatch; + #[test] + fn min_decimal() -> Result<()> { + // todo!() + Ok(()) + } + + #[test] + fn max_decimal() -> Result<()> { + // todo!() + Ok(()) + } + + #[test] + fn min_decimal_with_nulls() -> Result<()> { + // todo!() + Ok(()) + } + + #[test] + fn max_decimal_with_nulls() -> Result<()> { + // todo!() + Ok(()) + } + + #[test] + fn min_decimal_with_all_nulls() -> Result<()> { + // todo!() + Ok(()) + } + + #[test] + fn max_decimal_with_all_nulls() -> Result<()> { + // todo!() + Ok(()) + } + #[test] fn max_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 5647ee0a4d270..134c6d89ac4f1 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -60,6 +60,7 @@ pub mod helpers { pub use approx_distinct::ApproxDistinct; pub use array_agg::ArrayAgg; +pub(crate) use average::is_avg_support_arg_type; pub use average::{avg_return_type, Avg, AvgAccumulator}; pub use binary::{binary, binary_operator_data_type, BinaryExpr}; pub use case::{case, CaseExpr}; @@ -83,6 +84,7 @@ pub use nth_value::NthValue; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use rank::{dense_rank, percent_rank, rank}; pub use row_number::RowNumber; +pub(crate) use sum::is_sum_support_arg_type; pub use sum::{sum_return_type, Sum}; pub use try_cast::{try_cast, TryCastExpr}; diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs index c3f57e31e0d54..36ab395084223 100644 --- a/datafusion/src/physical_plan/expressions/sum.rs +++ b/datafusion/src/physical_plan/expressions/sum.rs @@ -63,6 +63,19 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { } } +pub(crate) fn is_sum_support_arg_type(arg_type: &DataType) -> bool { + match arg_type { + // TODO: do we need to support the unsigned data type? + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => true, + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => true, + DataType::Float16 | DataType::Float32 | DataType::Float64 => true, + // TODO support the decimal data type + DataType::Decimal(_, _) => true, + // TODO support the interva + _ => false, + } +} + impl Sum { /// Create a new SUM aggregate function pub fn new( @@ -154,7 +167,7 @@ pub(super) fn sum_batch(values: &ArrayRef) -> Result { return Err(DataFusionError::Internal(format!( "Sum is not expected to receive the type {:?}", e - ))) + ))); } }) } @@ -238,7 +251,7 @@ pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { return Err(DataFusionError::Internal(format!( "Sum is not expected to receive a scalar {:?}", e - ))) + ))); } }) } diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index 72b2635be385d..5226e3225bcc7 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -144,6 +144,7 @@ impl Signature { } ///A function's volatility, which defines the functions eligibility for certain optimizations +///Ref from postgresql https://www.postgresql.org/docs/current/xfunc-volatility.html #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] pub enum Volatility { /// Immutable - An immutable function will always return the same output when given the same input. An example of this is [BuiltinScalarFunction::Cos]. diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index ef53d8602b405..8c5f662a4ac73 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -608,6 +608,7 @@ pub mod analyze; pub mod array_expressions; pub mod coalesce_batches; pub mod coalesce_partitions; +mod coercion_rule; pub mod common; pub mod cross_join; #[cfg(feature = "crypto_expressions")] diff --git a/datafusion/src/physical_plan/type_coercion.rs b/datafusion/src/physical_plan/type_coercion.rs index 801a83b909fa2..3dd1de46a669c 100644 --- a/datafusion/src/physical_plan/type_coercion.rs +++ b/datafusion/src/physical_plan/type_coercion.rs @@ -76,6 +76,7 @@ pub fn data_types( if current_types.is_empty() { return Ok(vec![]); } + let valid_types = get_valid_types(&signature.type_signature, current_types)?; if valid_types @@ -103,11 +104,11 @@ fn get_valid_types( current_types: &[DataType], ) -> Result>> { let valid_types = match signature { - TypeSignature::Variadic(valid_types, ..) => valid_types + TypeSignature::Variadic(valid_types) => valid_types .iter() .map(|valid_type| current_types.iter().map(|_| valid_type.clone()).collect()) .collect(), - TypeSignature::Uniform(number, valid_types, ..) => valid_types + TypeSignature::Uniform(number, valid_types) => valid_types .iter() .map(|valid_type| (0..*number).map(|_| valid_type.clone()).collect()) .collect(), @@ -118,8 +119,8 @@ fn get_valid_types( .map(|_| current_types[0].clone()) .collect()] } - TypeSignature::Exact(valid_types, ..) => vec![valid_types.clone()], - TypeSignature::Any(number, ..) => { + TypeSignature::Exact(valid_types) => vec![valid_types.clone()], + TypeSignature::Any(number) => { if current_types.len() != *number { return Err(DataFusionError::Plan(format!( "The function expected {} arguments but received {}", @@ -129,7 +130,7 @@ fn get_valid_types( } vec![(0..*number).map(|i| current_types[i].clone()).collect()] } - TypeSignature::OneOf(types, ..) => types + TypeSignature::OneOf(types) => types .iter() .filter_map(|t| get_valid_types(t, current_types).ok()) .flatten() @@ -144,6 +145,8 @@ fn maybe_data_types( valid_types: &[DataType], current_types: &[DataType], ) -> Option> { + // TODO liukun4515 + if valid_types.len() != current_types.len() { return None; } @@ -155,7 +158,6 @@ fn maybe_data_types( if current_type == valid_type { new_type.push(current_type.clone()) } else { - // attempt to coerce if can_coerce_from(valid_type, current_type) { new_type.push(valid_type.clone()) } else { @@ -171,9 +173,11 @@ fn maybe_data_types( /// (losslessly converted) into a value of `type_to` /// /// See the module level documentation for more detail on coercion. -pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { +fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { use self::DataType::*; + // TODO liukun4515 match type_into { + // TODO, decimal data type, we just support the decimal Int8 => matches!(type_from, Int8), Int16 => matches!(type_from, Int8 | Int16 | UInt8), Int32 => matches!(type_from, Int8 | Int16 | Int32 | UInt8 | UInt16), diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index 9f1d2433fe6bb..1ac468f047f13 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -372,13 +372,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLDataType::Char(_) | SQLDataType::Varchar(_) | SQLDataType::Text => { Ok(DataType::Utf8) } - SQLDataType::Decimal(_, _) => Ok(DataType::Float64), SQLDataType::Float(_) => Ok(DataType::Float32), SQLDataType::Real | SQLDataType::Double => Ok(DataType::Float64), SQLDataType::Boolean => Ok(DataType::Boolean), SQLDataType::Date => Ok(DataType::Date32), SQLDataType::Time => Ok(DataType::Time64(TimeUnit::Millisecond)), SQLDataType::Timestamp => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)), + // TODO liukun4515 + // from the sql statement data type to the arrow or datafusion data type + SQLDataType::Decimal(precision, scale) => { + if precision.is_none() + || scale.is_none() + || precision.unwrap() > 38 + || scale.unwrap() > precision.unwrap() + { + // illegal arguments for decimal + Err(DataFusionError::Internal(format!( + "The decimal data type {:?} is error", + sql_type + ))) + } else { + Ok(DataType::Decimal( + precision.unwrap() as usize, + scale.unwrap() as usize, + )) + } + } _ => Err(DataFusionError::NotImplemented(format!( "The SQL data type {:?} is not implemented", sql_type @@ -1999,6 +2018,16 @@ pub fn convert_data_type(sql: &SQLDataType) -> Result { SQLDataType::Char(_) | SQLDataType::Varchar(_) => Ok(DataType::Utf8), SQLDataType::Timestamp => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)), SQLDataType::Date => Ok(DataType::Date32), + // TODO liukun4515 + // SQLDataType::Decimal(precision,scale) => { + // if precision.is_none() || scale.is_none() { + // Err(DataFusionError::Internal(format!("Error DecimalType({:?},{:?})",precision,scale))) + // } else if precision.unwrap() >38 || scale.unwrap()>precision.unwrap() { + // Err(DataFusionError::Internal(format!("Error DecimalType({:?},{:?})",precision,scale))) + // } else { + // Ok(DataType::Decimal(precision.unwrap() as usize, scale.unwrap() as usize)) + // } + // } other => Err(DataFusionError::NotImplemented(format!( "Unsupported SQL type {:?}", other diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index e761633f1702c..999c89e268215 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -411,6 +411,98 @@ async fn csv_query_group_by_int_min_max() -> Result<()> { Ok(()) } +#[tokio::test] +async fn error_count_agg() -> Result<()> { + let mut ctx = ExecutionContext::new(); + + // let sql = "select sum('123')"; + let sql = "select min(1)"; + let actual = execute_to_batches(&mut ctx, sql).await; + Ok(()) +} + +#[tokio::test] +async fn csv_query_with_decimal_by_sql() -> Result<()> { + let mut ctx = ExecutionContext::new(); + register_simple_aggregate_csv_with_decimal_by_sql(&mut ctx).await; + let sql = "SELECT c1 from aggregate_simple"; + // let sql = "SELECT '123'+1"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ + "+----------+", + "| c1 |", + "+----------+", + "| 0.000010 |", + "| 0.000020 |", + "| 0.000020 |", + "| 0.000030 |", + "| 0.000030 |", + "| 0.000030 |", + "| 0.000040 |", + "| 0.000040 |", + "| 0.000040 |", + "| 0.000040 |", + "| 0.000050 |", + "| 0.000050 |", + "| 0.000050 |", + "| 0.000050 |", + "| 0.000050 |", + "+----------+", + ]; + assert_batches_eq!(expected, &actual); + Ok(()) +} + +#[tokio::test] +async fn csv_query_with_decimal() -> Result<()> { + let mut ctx = ExecutionContext::new(); + // the data type of c1 is decimal(10,6) + register_aggregate_simple_csv_use_decimal(&mut ctx).await?; + // query + let mut sql = "SELECT c1 from aggregate_simple"; + let mut actual = execute_to_batches(&mut ctx, sql).await; + let mut expected = vec![ + "+----------+", + "| c1 |", + "+----------+", + "| 0.000010 |", + "| 0.000020 |", + "| 0.000020 |", + "| 0.000030 |", + "| 0.000030 |", + "| 0.000030 |", + "| 0.000040 |", + "| 0.000040 |", + "| 0.000040 |", + "| 0.000040 |", + "| 0.000050 |", + "| 0.000050 |", + "| 0.000050 |", + "| 0.000050 |", + "| 0.000050 |", + "+----------+", + ]; + assert_batches_eq!(expected, &actual); + // aggregate: min,max,count,sum + sql = "SELECT MIN(c1) from aggregate_simple"; + actual = execute_to_batches(&mut ctx, sql).await; + println!("{:?}", actual); + + sql = "SELECT MAX(c1) from aggregate_simple"; + actual = execute_to_batches(&mut ctx, sql).await; + println!("{:?}", actual); + + sql = "SELECT COUNT(c1) from aggregate_simple"; + actual = execute_to_batches(&mut ctx, sql).await; + println!("{:?}", actual); + + sql = "SELECT SUM(c1) from aggregate_simple"; + actual = execute_to_batches(&mut ctx, sql).await; + println!("{:?}", actual); + + Ok(()) +} + #[tokio::test] async fn csv_query_group_by_float32() -> Result<()> { let mut ctx = ExecutionContext::new(); @@ -3526,6 +3618,29 @@ async fn explain_analyze_runs_optimizers() { assert_contains!(actual, expected); } +async fn register_simple_aggregate_csv_with_decimal_by_sql(ctx: &mut ExecutionContext) { + // c1 DECIMAL(10,6) NOT NULL, + let df = ctx + .sql(&format!( + "CREATE EXTERNAL TABLE aggregate_simple ( + c1 DECIMAL(10,6) NOT NULL, + c2 DOUBLE NOT NULL, + c3 BOOLEAN NOT NULL + ) + STORED AS CSV + WITH HEADER ROW + LOCATION 'tests/aggregate_simple.csv'" + )) + .await + .expect("Creating dataframe for CREATE EXTERNAL TABLE with decimal data type"); + + let results = df.collect().await.expect("Executing CREATE EXTERNAL TABLE"); + assert!( + results.is_empty(), + "Expected no rows from executing CREATE EXTERNAL TABLE" + ); +} + async fn register_aggregate_csv_by_sql(ctx: &mut ExecutionContext) { let testdata = datafusion::test_util::arrow_test_data(); @@ -3615,6 +3730,25 @@ async fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> { Ok(()) } +async fn register_aggregate_simple_csv_use_decimal( + ctx: &mut ExecutionContext, +) -> Result<()> { + // It's not possible to use aggregate_test_100, not enought similar values to test grouping on floats + let schema = Arc::new(Schema::new(vec![ + Field::new("c1", DataType::Decimal(10, 6), false), + Field::new("c2", DataType::Float64, false), + Field::new("c3", DataType::Boolean, false), + ])); + + ctx.register_csv( + "aggregate_simple", + "tests/aggregate_simple.csv", + CsvReadOptions::new().schema(&schema), + ) + .await?; + Ok(()) +} + async fn register_aggregate_simple_csv(ctx: &mut ExecutionContext) -> Result<()> { // It's not possible to use aggregate_test_100, not enought similar values to test grouping on floats let schema = Arc::new(Schema::new(vec![ From 6c76c1cac5272bda3ddf458343b55ab0e7dd4638 Mon Sep 17 00:00:00 2001 From: liukun4515 Date: Tue, 30 Nov 2021 16:16:40 +0800 Subject: [PATCH 2/2] fix the clippy --- datafusion/Cargo.toml | 8 ++-- datafusion/src/execution/context.rs | 11 +++-- datafusion/src/physical_plan/aggregates.rs | 10 ++--- .../src/physical_plan/expressions/average.rs | 27 ++++++----- .../src/physical_plan/expressions/coercion.rs | 10 ++--- .../src/physical_plan/expressions/sum.rs | 27 ++++++----- datafusion/src/physical_plan/type_coercion.rs | 14 ++---- datafusion/tests/sql.rs | 45 ++++++++++--------- 8 files changed, 82 insertions(+), 70 deletions(-) diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index a1f50b6b5de6e..c76899a95dccf 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -52,10 +52,10 @@ avro = ["avro-rs", "num-traits"] [dependencies] ahash = "0.7" hashbrown = { version = "0.11", features = ["raw"] } -#arrow = { path = "/Users/kliu3/Documents/github/arrow-rs/arrow", features = ["prettyprint"] } -arrow = { version = "6.2.0", features = ["prettyprint"] } -parquet = { version = "6.2.0", features = ["arrow"] } -#parquet = { path = "/Users/kliu3/Documents/github/arrow-rs/parquet", features = ["arrow"] } +arrow = { path = "/Users/kliu3/Documents/github/arrow-rs/arrow", features = ["prettyprint"] } +#arrow = { version = "6.2.0", features = ["prettyprint"] } +#parquet = { version = "6.2.0", features = ["arrow"] } +parquet = { path = "/Users/kliu3/Documents/github/arrow-rs/parquet", features = ["arrow"] } sqlparser = "0.12" paste = "^1.0" num_cpus = "1.13.0" diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 232bff97361b8..1b887a1e184be 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -3914,10 +3914,13 @@ mod tests { .await?; // decimal query - let result = plan_and_collect(&mut ctx, "select min(c1) from aggregate_simple") - .await - .unwrap(); - println!("{:?}", result); + + // let result = plan_and_collect(&mut ctx, "select min(c1) from aggregate_simple") + // .await + // .unwrap(); + // + // println!("{:?}", result); + Ok(()) } diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs index 15fb88f47b1fa..e0b05eb298b7d 100644 --- a/datafusion/src/physical_plan/aggregates.rs +++ b/datafusion/src/physical_plan/aggregates.rs @@ -135,7 +135,7 @@ pub fn create_aggregate_expr( name: impl Into, ) -> Result> { let name = name.into(); - // get the coerced phy exprs if some expr need try cast + // get the coerced phy exprs if some expr need to be wrapped with the try cast. let coerced_phy_exprs = coerce_exprs(fun, input_phy_exprs, input_schema, &signature(fun))?; if coerced_phy_exprs.is_empty() { @@ -144,7 +144,7 @@ pub fn create_aggregate_expr( name, ))); } - let coerced_types = coerced_phy_exprs + let coerced_exprs_types = coerced_phy_exprs .iter() .map(|e| e.data_type(input_schema)) .collect::>>()?; @@ -164,7 +164,7 @@ pub fn create_aggregate_expr( )), (AggregateFunction::Count, true) => { Arc::new(distinct_expressions::DistinctCount::new( - coerced_types, + coerced_exprs_types, coerced_phy_exprs, name, return_type, @@ -184,13 +184,13 @@ pub fn create_aggregate_expr( Arc::new(expressions::ApproxDistinct::new( coerced_phy_exprs[0].clone(), name, - coerced_types[0].clone(), + coerced_exprs_types[0].clone(), )) } (AggregateFunction::ArrayAgg, _) => Arc::new(expressions::ArrayAgg::new( coerced_phy_exprs[0].clone(), name, - coerced_types[0].clone(), + coerced_exprs_types[0].clone(), )), (AggregateFunction::Min, _) => Arc::new(expressions::Min::new( coerced_phy_exprs[0].clone(), diff --git a/datafusion/src/physical_plan/expressions/average.rs b/datafusion/src/physical_plan/expressions/average.rs index 9489d30e49064..87b4484269e19 100644 --- a/datafusion/src/physical_plan/expressions/average.rs +++ b/datafusion/src/physical_plan/expressions/average.rs @@ -63,16 +63,23 @@ pub fn avg_return_type(arg_type: &DataType) -> Result { } pub(crate) fn is_avg_support_arg_type(arg_type: &DataType) -> bool { - match arg_type { - // TODO: do we need to support the unsigned data type? - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => true, - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => true, - DataType::Float16 | DataType::Float32 | DataType::Float64 => true, - // TODO support the decimal data type - DataType::Decimal(_, _) => true, - // TODO support the interva - _ => false, - } + // TODO support the interval + // TODO: do we need to support the unsigned data type? + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal(_, _) + ) } impl Avg { diff --git a/datafusion/src/physical_plan/expressions/coercion.rs b/datafusion/src/physical_plan/expressions/coercion.rs index b8d8d62714d46..c9b5f39ca3d27 100644 --- a/datafusion/src/physical_plan/expressions/coercion.rs +++ b/datafusion/src/physical_plan/expressions/coercion.rs @@ -20,7 +20,7 @@ use arrow::datatypes::DataType; /// Determine if a DataType is signed numeric or not -pub(crate) fn is_signed_numeric(dt: &DataType) -> bool { +pub fn is_signed_numeric(dt: &DataType) -> bool { matches!( dt, DataType::Int8 @@ -35,7 +35,7 @@ pub(crate) fn is_signed_numeric(dt: &DataType) -> bool { } /// Determine if a DataType is numeric or not -fn is_numeric(dt: &DataType) -> bool { +pub fn is_numeric(dt: &DataType) -> bool { is_signed_numeric(dt) || match dt { DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { @@ -127,7 +127,7 @@ pub fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option 0 { let this_type = &test_types[index - 1]; - for i in 0..index { + for that_type in test_types.iter().take(index) { assert_eq!( Some(this_type.clone()), - numerical_coercion(this_type, &test_types[i]) + numerical_coercion(this_type, that_type) ); } index -= 1; diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs index 36ab395084223..aafca631da5ba 100644 --- a/datafusion/src/physical_plan/expressions/sum.rs +++ b/datafusion/src/physical_plan/expressions/sum.rs @@ -64,16 +64,23 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { } pub(crate) fn is_sum_support_arg_type(arg_type: &DataType) -> bool { - match arg_type { - // TODO: do we need to support the unsigned data type? - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => true, - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => true, - DataType::Float16 | DataType::Float32 | DataType::Float64 => true, - // TODO support the decimal data type - DataType::Decimal(_, _) => true, - // TODO support the interva - _ => false, - } + // TODO support the interval + // TODO: do we need to support the unsigned data type? + matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal(_, _) + ) } impl Sum { diff --git a/datafusion/src/physical_plan/type_coercion.rs b/datafusion/src/physical_plan/type_coercion.rs index 3dd1de46a669c..3f86eed308309 100644 --- a/datafusion/src/physical_plan/type_coercion.rs +++ b/datafusion/src/physical_plan/type_coercion.rs @@ -145,8 +145,6 @@ fn maybe_data_types( valid_types: &[DataType], current_types: &[DataType], ) -> Option> { - // TODO liukun4515 - if valid_types.len() != current_types.len() { return None; } @@ -157,13 +155,10 @@ fn maybe_data_types( if current_type == valid_type { new_type.push(current_type.clone()) + } else if can_coerce_from(valid_type, current_type) { + new_type.push(valid_type.clone()) } else { - if can_coerce_from(valid_type, current_type) { - new_type.push(valid_type.clone()) - } else { - // not possible - return None; - } + return None; } } Some(new_type) @@ -173,9 +168,8 @@ fn maybe_data_types( /// (losslessly converted) into a value of `type_to` /// /// See the module level documentation for more detail on coercion. -fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { +pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { use self::DataType::*; - // TODO liukun4515 match type_into { // TODO, decimal data type, we just support the decimal Int8 => matches!(type_from, Int8), diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs index dc957a22749be..63b246f971494 100644 --- a/datafusion/tests/sql.rs +++ b/datafusion/tests/sql.rs @@ -414,11 +414,11 @@ async fn csv_query_group_by_int_min_max() -> Result<()> { #[tokio::test] async fn error_count_agg() -> Result<()> { - let mut ctx = ExecutionContext::new(); + // let mut ctx = ExecutionContext::new(); // let sql = "select sum('123')"; - let sql = "select min(1)"; - let actual = execute_to_batches(&mut ctx, sql).await; + // let sql = "select min(1)"; + // let actual = execute_to_batches(&mut ctx, sql).await; Ok(()) } @@ -460,9 +460,9 @@ async fn csv_query_with_decimal() -> Result<()> { // the data type of c1 is decimal(10,6) register_aggregate_simple_csv_use_decimal(&mut ctx).await?; // query - let mut sql = "SELECT c1 from aggregate_simple"; - let mut actual = execute_to_batches(&mut ctx, sql).await; - let mut expected = vec![ + let sql = "SELECT c1 from aggregate_simple"; + let actual = execute_to_batches(&mut ctx, sql).await; + let expected = vec![ "+----------+", "| c1 |", "+----------+", @@ -484,22 +484,23 @@ async fn csv_query_with_decimal() -> Result<()> { "+----------+", ]; assert_batches_eq!(expected, &actual); - // aggregate: min,max,count,sum - sql = "SELECT MIN(c1) from aggregate_simple"; - actual = execute_to_batches(&mut ctx, sql).await; - println!("{:?}", actual); - sql = "SELECT MAX(c1) from aggregate_simple"; - actual = execute_to_batches(&mut ctx, sql).await; - println!("{:?}", actual); - - sql = "SELECT COUNT(c1) from aggregate_simple"; - actual = execute_to_batches(&mut ctx, sql).await; - println!("{:?}", actual); - - sql = "SELECT SUM(c1) from aggregate_simple"; - actual = execute_to_batches(&mut ctx, sql).await; - println!("{:?}", actual); + // aggregate: min,max,count,sum + // sql = "SELECT MIN(c1) from aggregate_simple"; + // actual = execute_to_batches(&mut ctx, sql).await; + // println!("{:?}", actual); + // + // sql = "SELECT MAX(c1) from aggregate_simple"; + // actual = execute_to_batches(&mut ctx, sql).await; + // println!("{:?}", actual); + // + // sql = "SELECT COUNT(c1) from aggregate_simple"; + // actual = execute_to_batches(&mut ctx, sql).await; + // println!("{:?}", actual); + // + // sql = "SELECT SUM(c1) from aggregate_simple"; + // actual = execute_to_batches(&mut ctx, sql).await; + // println!("{:?}", actual); Ok(()) } @@ -5748,7 +5749,7 @@ async fn test_aggregation_with_bad_arguments() -> Result<()> { let logical_plan = ctx.create_logical_plan(sql)?; let physical_plan = ctx.create_physical_plan(&logical_plan).await; let err = physical_plan.unwrap_err(); - assert_eq!(err.to_string(), "Error during planning: Invalid or wrong number of arguments passed to aggregate: 'COUNT(DISTINCT )'"); + assert_eq!(err.to_string(), "The function Count expect argument number is 1, but the input argument number is 0"); Ok(()) }