diff --git a/datafusion-examples/examples/csv_decimal_sql.rs b/datafusion-examples/examples/csv_decimal_sql.rs
new file mode 100644
index 0000000000000..e419d62a480ce
--- /dev/null
+++ b/datafusion-examples/examples/csv_decimal_sql.rs
@@ -0,0 +1,54 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion::arrow::datatypes::{DataType, Field, Schema};
+use datafusion::error::Result;
+use datafusion::prelude::*;
+use std::sync::Arc;
+
+/// This example demonstrates executing a simple query against a CSV data source that
+/// contains a decimal column and printing the results
+#[tokio::main]
+async fn main() -> Result<()> {
+    // create local execution context
+    let mut ctx = ExecutionContext::new();
+
+    let testdata = datafusion::test_util::arrow_test_data();
+
+    // schema with a decimal type
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Decimal(10, 6), false),
+        Field::new("c2", DataType::Float64, false),
+        Field::new("c3", DataType::Boolean, false),
+    ]));
+
+    // register csv file with the execution context
+    ctx.register_csv(
+        "aggregate_simple",
+        &format!("{}/csv/aggregate_simple.csv", testdata),
+        CsvReadOptions::new().schema(&schema),
+    )
+    .await?;
+
+    // execute the query
+    let df = ctx.sql("select c1 from aggregate_simple").await?;
+
+    // print the results
+    df.show().await?;
+
+    Ok(())
+}
diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml
index 390dab43160de..c76899a95dccf 100644
--- a/datafusion/Cargo.toml
+++ b/datafusion/Cargo.toml
@@ -52,8 +52,10 @@ avro = ["avro-rs", "num-traits"]
 [dependencies]
 ahash = "0.7"
 hashbrown = { version = "0.11", features = ["raw"] }
-arrow = { version = "6.2.0", features = ["prettyprint"] }
-parquet = { version = "6.2.0", features = ["arrow"] }
+# Temporarily track arrow-rs from git for the decimal support this change relies on;
+# switch back to a released crates.io version before merging.
+arrow = { git = "https://github.com/apache/arrow-rs", features = ["prettyprint"] }
+parquet = { git = "https://github.com/apache/arrow-rs", features = ["arrow"] }
 sqlparser = "0.12"
 paste = "^1.0"
 num_cpus = "1.13.0"
diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index 27116c0a4a952..1b887a1e184be 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -2058,7 +2058,7 @@ mod tests {
         .await
         .unwrap_err();
 
-        assert_eq!(results.to_string(), "Error during planning: Coercion from [Timestamp(Nanosecond, None)] to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed.");
+        assert_eq!(results.to_string(), "Error during planning: The function Sum does not support the data type Timestamp(Nanosecond, None)");
 
         Ok(())
     }
@@ -2155,7 +2155,7 @@ mod tests {
         .await
         .unwrap_err();
 
-        assert_eq!(results.to_string(), "Error during planning: Coercion from [Timestamp(Nanosecond, None)] to the signature Uniform(1, [Int8, Int16, Int32, Int64, UInt8, UInt16, UInt32, UInt64, Float32, Float64]) failed.");
+        assert_eq!(results.to_string(), "Error during planning: The function Avg does not support the data type Timestamp(Nanosecond, None)");
 
         Ok(())
     }
@@ -3896,6 +3896,34 @@ mod tests {
         Ok(())
     }
 
+    #[tokio::test]
+    async fn aggregate_decimal() -> Result<()> {
+        let mut ctx = ExecutionContext::new();
+        // schema with a decimal column
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("c1", DataType::Decimal(10, 6), false),
+            Field::new("c2", DataType::Float64, false),
+            Field::new("c3", DataType::Boolean, false),
+        ]));
+
+        ctx.register_csv(
+            "aggregate_simple",
+            "tests/aggregate_simple.csv",
+            CsvReadOptions::new().schema(&schema),
+        )
+        .await?;
+
+        // the decimal column can at least be projected; the fixture has 15 rows
+        let result = plan_and_collect(&mut ctx, "select c1 from aggregate_simple")
+            .await
+            .unwrap();
+        assert_eq!(15, result.iter().map(|b| b.num_rows()).sum::<usize>());
+
+        // TODO: assert min/max/sum/count once decimal accumulators are implemented,
+        // e.g. `select min(c1) from aggregate_simple`
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn create_external_table_with_timestamps() {
diff --git a/datafusion/src/physical_plan/aggregates.rs b/datafusion/src/physical_plan/aggregates.rs
index 1ec33a409efba..e0b05eb298b7d 100644
--- a/datafusion/src/physical_plan/aggregates.rs
+++ b/datafusion/src/physical_plan/aggregates.rs
@@ -28,15 +28,16 @@ use super::{
     functions::{Signature, Volatility},
-    type_coercion::{coerce, data_types},
     Accumulator, AggregateExpr, PhysicalExpr,
 };
 use crate::error::{DataFusionError, Result};
+use crate::physical_plan::coercion_rule::aggregate_rule::{coerce_exprs, coerce_types};
 use crate::physical_plan::distinct_expressions;
 use crate::physical_plan::expressions;
 use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
 use expressions::{avg_return_type, sum_return_type};
 use std::{fmt, str::FromStr, sync::Arc};
+
 /// the implementation of an aggregate function
 pub type AccumulatorFunctionImplementation =
     Arc<dyn Fn() -> Result<Box<dyn Accumulator>> + Send + Sync>;
@@ -87,13 +88,14 @@ impl FromStr for AggregateFunction {
                 return Err(DataFusionError::Plan(format!(
                     "There is no built-in function named {}",
                     name
-                )))
+                )));
             }
         })
     }
 }
 
-/// Returns the datatype of the aggregation function
+/// Returns the datatype of the aggregate function.
+/// This is used to derive the return type of an aggregate expression.
 pub fn return_type(
     fun: &AggregateFunction,
     input_expr_types: &[DataType],
 ) -> Result<DataType> {
     // Note that this function *must* return the same type that the respective physical expression returns
     // or the execution panics.
 
-    // verify that this is a valid set of data types for this function
-    data_types(input_expr_types, &signature(fun))?;
+    let coerced_data_types = coerce_types(fun, input_expr_types, &signature(fun))?;
 
     match fun {
+        // TODO: to be consistent with PostgreSQL, COUNT should return Int64.
         AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
             Ok(DataType::UInt64)
         }
         AggregateFunction::Max | AggregateFunction::Min => {
-            Ok(input_expr_types[0].clone())
+            // For min and max the return type is the same as the input type;
+            // the coerced types are identical to the input types for these functions.
+ Ok(coerced_data_types[0].clone()) } - AggregateFunction::Sum => sum_return_type(&input_expr_types[0]), - AggregateFunction::Avg => avg_return_type(&input_expr_types[0]), + AggregateFunction::Sum => sum_return_type(&coerced_data_types[0]), + AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]), AggregateFunction::ArrayAgg => Ok(DataType::List(Box::new(Field::new( "item", - input_expr_types[0].clone(), + coerced_data_types[0].clone(), true, )))), } @@ -131,26 +135,26 @@ pub fn create_aggregate_expr( name: impl Into, ) -> Result> { let name = name.into(); - let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &signature(fun))?; + // get the coerced phy exprs if some expr need to be wrapped with the try cast. + let coerced_phy_exprs = + coerce_exprs(fun, input_phy_exprs, input_schema, &signature(fun))?; if coerced_phy_exprs.is_empty() { return Err(DataFusionError::Plan(format!( "Invalid or wrong number of arguments passed to aggregate: '{}'", name, ))); } - let coerced_exprs_types = coerced_phy_exprs .iter() .map(|e| e.data_type(input_schema)) .collect::>>()?; - let input_exprs_types = input_phy_exprs + // get the result data type for this aggregate function + let input_phy_types = input_phy_exprs .iter() .map(|e| e.data_type(input_schema)) .collect::>>()?; - - // In order to get the result data type, we must use the original input data type to calculate the result type. - let return_type = return_type(fun, &input_exprs_types)?; + let return_type = return_type(fun, &input_phy_types)?; Ok(match (fun, distinct) { (AggregateFunction::Count, false) => Arc::new(expressions::Count::new( @@ -161,7 +165,7 @@ pub fn create_aggregate_expr( (AggregateFunction::Count, true) => { Arc::new(distinct_expressions::DistinctCount::new( coerced_exprs_types, - coerced_phy_exprs.to_vec(), + coerced_phy_exprs, name, return_type, )) @@ -262,6 +266,130 @@ pub fn signature(fun: &AggregateFunction) -> Signature { mod tests { use super::*; use crate::error::Result; + use crate::physical_plan::expressions::{ApproxDistinct, ArrayAgg, Count, Max, Min}; + + #[test] + fn test_count_arragg_approx_expr() -> Result<()> { + let funcs = vec![ + AggregateFunction::Count, + AggregateFunction::ArrayAgg, + AggregateFunction::ApproxDistinct, + ]; + let data_types = vec![ + DataType::UInt32, + DataType::Int32, + DataType::Float32, + DataType::Float64, + DataType::Decimal(10, 2), + DataType::Utf8, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + match fun { + AggregateFunction::Count => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::UInt64, true), + result_agg_phy_exprs.field().unwrap() + ); + } + AggregateFunction::ApproxDistinct => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", DataType::UInt64, false), + result_agg_phy_exprs.field().unwrap() + ); + } + AggregateFunction::ArrayAgg => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new( + "c1", + DataType::List(Box::new(Field::new( + "item", + data_type.clone(), + true + 
))), + false + ), + result_agg_phy_exprs.field().unwrap() + ); + } + _ => {} + }; + } + } + Ok(()) + } + + #[test] + fn test_min_max_expr() -> Result<()> { + let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; + let data_types = vec![ + DataType::UInt32, + DataType::Int32, + DataType::Float32, + DataType::Float64, + DataType::Decimal(10, 2), + DataType::Utf8, + ]; + for fun in funcs { + for data_type in &data_types { + let input_schema = + Schema::new(vec![Field::new("c1", data_type.clone(), true)]); + let input_phy_exprs: Vec> = vec![Arc::new( + expressions::Column::new_with_schema("c1", &input_schema).unwrap(), + )]; + let result_agg_phy_exprs = create_aggregate_expr( + &fun, + false, + &input_phy_exprs[0..1], + &input_schema, + "c1", + )?; + match fun { + AggregateFunction::Min => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", data_type.clone(), true), + result_agg_phy_exprs.field().unwrap() + ); + } + AggregateFunction::Max => { + assert!(result_agg_phy_exprs.as_any().is::()); + assert_eq!("c1", result_agg_phy_exprs.name()); + assert_eq!( + Field::new("c1", data_type.clone(), true), + result_agg_phy_exprs.field().unwrap() + ); + } + _ => {} + }; + } + } + Ok(()) + } + + #[test] + fn test_sum_avg_expr() -> Result<()> { + Ok(()) + } #[test] fn test_min_max() -> Result<()> { @@ -270,6 +398,16 @@ mod tests { let observed = return_type(&AggregateFunction::Max, &[DataType::Int32])?; assert_eq!(DataType::Int32, observed); + + // test decimal for min + let observed = return_type(&AggregateFunction::Min, &[DataType::Decimal(10, 6)])?; + assert_eq!(DataType::Decimal(10, 6), observed); + + // test decimal for max + let observed = + return_type(&AggregateFunction::Max, &[DataType::Decimal(28, 13)])?; + assert_eq!(DataType::Decimal(28, 13), observed); + Ok(()) } @@ -293,6 +431,10 @@ mod tests { let observed = return_type(&AggregateFunction::Count, &[DataType::Int8])?; assert_eq!(DataType::UInt64, observed); + + let observed = + return_type(&AggregateFunction::Count, &[DataType::Decimal(28, 13)])?; + assert_eq!(DataType::UInt64, observed); Ok(()) } diff --git a/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs new file mode 100644 index 0000000000000..24f24fec94809 --- /dev/null +++ b/datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs @@ -0,0 +1,175 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Support the coercion rule for aggregate function. 
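+//!
+//! The rule validates the number of arguments of an aggregate call against the
+//! function signature and returns the coerced input types: COUNT, APPROX_DISTINCT,
+//! ARRAY_AGG, MIN and MAX accept their inputs unchanged, while SUM and AVG are
+//! restricted to numeric types (including decimal), following the PostgreSQL
+//! aggregate functions.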
+
+use crate::arrow::datatypes::Schema;
+use crate::error::{DataFusionError, Result};
+use crate::physical_plan::aggregates::AggregateFunction;
+use crate::physical_plan::expressions::{
+    is_avg_support_arg_type, is_sum_support_arg_type, try_cast,
+};
+use crate::physical_plan::functions::{Signature, TypeSignature};
+use crate::physical_plan::PhysicalExpr;
+use arrow::datatypes::DataType;
+use std::sync::Arc;
+
+/// Returns the coerced input types for the given aggregate function and signature.
+pub fn coerce_types(
+    agg_fun: &AggregateFunction,
+    input_types: &[DataType],
+    signature: &Signature,
+) -> Result<Vec<DataType>> {
+    match signature.type_signature {
+        TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
+            if input_types.len() != agg_count {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} expects {} argument(s), but {} were provided",
+                    agg_fun,
+                    agg_count,
+                    input_types.len()
+                )));
+            }
+        }
+        _ => {
+            return Err(DataFusionError::Plan(format!(
+                "The aggregate coercion rule does not support the signature {:?}",
+                signature
+            )));
+        }
+    };
+    match agg_fun {
+        AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
+            Ok(input_types.to_vec())
+        }
+        AggregateFunction::ArrayAgg => Ok(input_types.to_vec()),
+        AggregateFunction::Min | AggregateFunction::Max => Ok(input_types.to_vec()),
+        AggregateFunction::Sum => {
+            // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html:
+            // smallint, int, bigint, real, double precision, decimal, or interval.
+            if !is_sum_support_arg_type(&input_types[0]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not support the data type {:?}",
+                    agg_fun, input_types[0]
+                )));
+            }
+            Ok(input_types.to_vec())
+        }
+        AggregateFunction::Avg => {
+            // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html:
+            // smallint, int, bigint, real, double precision, decimal, or interval.
+            if !is_avg_support_arg_type(&input_types[0]) {
+                return Err(DataFusionError::Plan(format!(
+                    "The function {:?} does not support the data type {:?}",
+                    agg_fun, input_types[0]
+                )));
+            }
+            Ok(input_types.to_vec())
+        }
+    }
+}
+
+/// Wraps the input expressions in a `try_cast` where their types differ from the
+/// coerced types returned by [`coerce_types`].
+pub fn coerce_exprs(
+    agg_fun: &AggregateFunction,
+    input_exprs: &[Arc<dyn PhysicalExpr>],
+    schema: &Schema,
+    signature: &Signature,
+) -> Result<Vec<Arc<dyn PhysicalExpr>>> {
+    if input_exprs.is_empty() {
+        return Ok(vec![]);
+    }
+    let input_types = input_exprs
+        .iter()
+        .map(|e| e.data_type(schema))
+        .collect::<Result<Vec<_>>>()?;
+
+    // get the coerced data types
+    let coerced_types = coerce_types(agg_fun, &input_types, signature)?;
+
+    // add a try_cast wrapper wherever the input type differs from the coerced type
+    input_exprs
+        .iter()
+        .enumerate()
+        .map(|(i, expr)| try_cast(expr.clone(), schema, coerced_types[i].clone()))
+        .collect::<Result<Vec<_>>>()
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::physical_plan::aggregates;
+    use crate::physical_plan::aggregates::{signature, AggregateFunction};
+    use crate::physical_plan::coercion_rule::aggregate_rule::coerce_types;
+    use arrow::datatypes::DataType;
+
+    #[test]
+    fn test_aggregate_coerce_types() {
+        // wrong number of input types
+        let fun = AggregateFunction::Min;
+        let input_types = vec![DataType::Int64, DataType::Int32];
+        let signature = signature(&fun);
+        let result = coerce_types(&fun, &input_types, &signature);
+        assert_eq!("Error during planning: The function Min expects 1 argument(s), but 2 were provided", result.unwrap_err().to_string());
+
+        // invalid input data type for sum and avg
+        let fun = AggregateFunction::Sum;
+        let input_types = vec![DataType::Utf8];
+        let signature = aggregates::signature(&fun);
+        let result = coerce_types(&fun, &input_types, &signature);
+        assert_eq!(
+            "Error during planning: The function Sum does not support the data type Utf8",
+            result.unwrap_err().to_string()
+        );
+        let fun = AggregateFunction::Avg;
+        let signature = aggregates::signature(&fun);
+        let result = coerce_types(&fun, &input_types, &signature);
+        assert_eq!(
+            "Error during planning: The function Avg does not support the data type Utf8",
+            result.unwrap_err().to_string()
+        );
+
+        // count, array_agg, approx_distinct, min and max:
+        // the coerced types are the same as the input types
+        let funs = vec![
+            AggregateFunction::Count,
+            AggregateFunction::ArrayAgg,
+            AggregateFunction::ApproxDistinct,
+            AggregateFunction::Min,
+            AggregateFunction::Max,
+        ];
+        let input_types = vec![
+            vec![DataType::Int32],
+            vec![DataType::Decimal(10, 2)],
+            vec![DataType::Utf8],
+        ];
+        for fun in funs {
+            for input_type in &input_types {
+                let signature = aggregates::signature(&fun);
+                let result = coerce_types(&fun, input_type, &signature);
+                assert_eq!(*input_type, result.unwrap());
+            }
+        }
+        // sum and avg accept numeric types, including decimal, unchanged
+        let funs = vec![AggregateFunction::Sum, AggregateFunction::Avg];
+        let input_types = vec![
+            vec![DataType::Int32],
+            vec![DataType::Float32],
+            vec![DataType::Decimal(20, 3)],
+        ];
+        for fun in funs {
+            for input_type in &input_types {
+                let signature = aggregates::signature(&fun);
+                let result = coerce_types(&fun, input_type, &signature);
+                assert_eq!(*input_type, result.unwrap());
+            }
+        }
+    }
+}
diff --git a/datafusion/src/physical_plan/coercion_rule/mod.rs b/datafusion/src/physical_plan/coercion_rule/mod.rs
new file mode 100644
index 0000000000000..8d07b10bfe23d
--- /dev/null
+++ b/datafusion/src/physical_plan/coercion_rule/mod.rs
@@ -0,0 +1,19 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Define coercion rules for the different expression types.
+pub(crate) mod aggregate_rule;
diff --git a/datafusion/src/physical_plan/expressions/average.rs b/datafusion/src/physical_plan/expressions/average.rs
index 2e218191f6683..87b4484269e19 100644
--- a/datafusion/src/physical_plan/expressions/average.rs
+++ b/datafusion/src/physical_plan/expressions/average.rs
@@ -62,6 +62,26 @@ pub fn avg_return_type(arg_type: &DataType) -> Result<DataType> {
     }
 }
 
+pub(crate) fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
+    // TODO: support the interval type
+    // TODO: do we need to support the unsigned data types?
+    // Decimal is accepted so that AVG over decimal columns passes the aggregate
+    // coercion rule.
+ matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal(_, _) + ) +} + impl Avg { /// Create a new AVG aggregate function pub fn new( diff --git a/datafusion/src/physical_plan/expressions/coercion.rs b/datafusion/src/physical_plan/expressions/coercion.rs index 180b16548b32b..c9b5f39ca3d27 100644 --- a/datafusion/src/physical_plan/expressions/coercion.rs +++ b/datafusion/src/physical_plan/expressions/coercion.rs @@ -29,7 +29,8 @@ pub fn is_signed_numeric(dt: &DataType) -> bool { | DataType::Int64 | DataType::Float16 | DataType::Float32 - | DataType::Float64 + | DataType::Float64 // TODO liukun4515 + // | DataType::Decimal(_,_) ) } @@ -125,6 +126,11 @@ pub fn numerical_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Option 0 { + let this_type = &test_types[index - 1]; + for that_type in test_types.iter().take(index) { + assert_eq!( + Some(this_type.clone()), + numerical_coercion(this_type, that_type) + ); + } + index -= 1; + } + } } diff --git a/datafusion/src/physical_plan/expressions/min_max.rs b/datafusion/src/physical_plan/expressions/min_max.rs index 9e5b1e095cd6f..e445ab31f4ea1 100644 --- a/datafusion/src/physical_plan/expressions/min_max.rs +++ b/datafusion/src/physical_plan/expressions/min_max.rs @@ -487,6 +487,42 @@ mod tests { use arrow::datatypes::*; use arrow::record_batch::RecordBatch; + #[test] + fn min_decimal() -> Result<()> { + // todo!() + Ok(()) + } + + #[test] + fn max_decimal() -> Result<()> { + // todo!() + Ok(()) + } + + #[test] + fn min_decimal_with_nulls() -> Result<()> { + // todo!() + Ok(()) + } + + #[test] + fn max_decimal_with_nulls() -> Result<()> { + // todo!() + Ok(()) + } + + #[test] + fn min_decimal_with_all_nulls() -> Result<()> { + // todo!() + Ok(()) + } + + #[test] + fn max_decimal_with_all_nulls() -> Result<()> { + // todo!() + Ok(()) + } + #[test] fn max_i32() -> Result<()> { let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3, 4, 5])); diff --git a/datafusion/src/physical_plan/expressions/mod.rs b/datafusion/src/physical_plan/expressions/mod.rs index 5647ee0a4d270..134c6d89ac4f1 100644 --- a/datafusion/src/physical_plan/expressions/mod.rs +++ b/datafusion/src/physical_plan/expressions/mod.rs @@ -60,6 +60,7 @@ pub mod helpers { pub use approx_distinct::ApproxDistinct; pub use array_agg::ArrayAgg; +pub(crate) use average::is_avg_support_arg_type; pub use average::{avg_return_type, Avg, AvgAccumulator}; pub use binary::{binary, binary_operator_data_type, BinaryExpr}; pub use case::{case, CaseExpr}; @@ -83,6 +84,7 @@ pub use nth_value::NthValue; pub use nullif::{nullif_func, SUPPORTED_NULLIF_TYPES}; pub use rank::{dense_rank, percent_rank, rank}; pub use row_number::RowNumber; +pub(crate) use sum::is_sum_support_arg_type; pub use sum::{sum_return_type, Sum}; pub use try_cast::{try_cast, TryCastExpr}; diff --git a/datafusion/src/physical_plan/expressions/sum.rs b/datafusion/src/physical_plan/expressions/sum.rs index c3f57e31e0d54..aafca631da5ba 100644 --- a/datafusion/src/physical_plan/expressions/sum.rs +++ b/datafusion/src/physical_plan/expressions/sum.rs @@ -63,6 +63,26 @@ pub fn sum_return_type(arg_type: &DataType) -> Result { } } +pub(crate) fn is_sum_support_arg_type(arg_type: &DataType) -> bool { + // TODO support the interval + // TODO: do we need to support the unsigned data type? 
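+    // Decimal is included so that SUM over decimal columns passes the new
+    // aggregate coercion rule (see coercion_rule::aggregate_rule::coerce_types).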
+ matches!( + arg_type, + DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float16 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal(_, _) + ) +} + impl Sum { /// Create a new SUM aggregate function pub fn new( @@ -154,7 +174,7 @@ pub(super) fn sum_batch(values: &ArrayRef) -> Result { return Err(DataFusionError::Internal(format!( "Sum is not expected to receive the type {:?}", e - ))) + ))); } }) } @@ -238,7 +258,7 @@ pub(super) fn sum(lhs: &ScalarValue, rhs: &ScalarValue) -> Result { return Err(DataFusionError::Internal(format!( "Sum is not expected to receive a scalar {:?}", e - ))) + ))); } }) } diff --git a/datafusion/src/physical_plan/functions.rs b/datafusion/src/physical_plan/functions.rs index 9c59b9662daac..6db15c7e84eeb 100644 --- a/datafusion/src/physical_plan/functions.rs +++ b/datafusion/src/physical_plan/functions.rs @@ -144,6 +144,7 @@ impl Signature { } ///A function's volatility, which defines the functions eligibility for certain optimizations +///Ref from postgresql https://www.postgresql.org/docs/current/xfunc-volatility.html #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] pub enum Volatility { /// Immutable - An immutable function will always return the same output when given the same input. An example of this is [BuiltinScalarFunction::Cos]. diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index ef53d8602b405..8c5f662a4ac73 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -608,6 +608,7 @@ pub mod analyze; pub mod array_expressions; pub mod coalesce_batches; pub mod coalesce_partitions; +mod coercion_rule; pub mod common; pub mod cross_join; #[cfg(feature = "crypto_expressions")] diff --git a/datafusion/src/physical_plan/type_coercion.rs b/datafusion/src/physical_plan/type_coercion.rs index b4133565aebfa..3f86eed308309 100644 --- a/datafusion/src/physical_plan/type_coercion.rs +++ b/datafusion/src/physical_plan/type_coercion.rs @@ -76,6 +76,7 @@ pub fn data_types( if current_types.is_empty() { return Ok(vec![]); } + let valid_types = get_valid_types(&signature.type_signature, current_types)?; if valid_types @@ -154,14 +155,10 @@ fn maybe_data_types( if current_type == valid_type { new_type.push(current_type.clone()) + } else if can_coerce_from(valid_type, current_type) { + new_type.push(valid_type.clone()) } else { - // attempt to coerce - if can_coerce_from(valid_type, current_type) { - new_type.push(valid_type.clone()) - } else { - // not possible - return None; - } + return None; } } Some(new_type) @@ -174,6 +171,7 @@ fn maybe_data_types( pub fn can_coerce_from(type_into: &DataType, type_from: &DataType) -> bool { use self::DataType::*; match type_into { + // TODO, decimal data type, we just support the decimal Int8 => matches!(type_from, Int8), Int16 => matches!(type_from, Int8 | Int16 | UInt8), Int32 => matches!(type_from, Int8 | Int16 | Int32 | UInt8 | UInt16), diff --git a/datafusion/src/sql/planner.rs b/datafusion/src/sql/planner.rs index ac3d3b176dcdb..72e59bbcf8b09 100644 --- a/datafusion/src/sql/planner.rs +++ b/datafusion/src/sql/planner.rs @@ -372,13 +372,32 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { SQLDataType::Char(_) | SQLDataType::Varchar(_) | SQLDataType::Text => { Ok(DataType::Utf8) } - SQLDataType::Decimal(_, _) => Ok(DataType::Float64), SQLDataType::Float(_) => Ok(DataType::Float32), SQLDataType::Real | 
SQLDataType::Double => Ok(DataType::Float64),
             SQLDataType::Boolean => Ok(DataType::Boolean),
             SQLDataType::Date => Ok(DataType::Date32),
             SQLDataType::Time => Ok(DataType::Time64(TimeUnit::Millisecond)),
             SQLDataType::Timestamp => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)),
+            // map the SQL DECIMAL(precision, scale) type to the arrow decimal type
+            SQLDataType::Decimal(precision, scale) => {
+                if precision.is_none()
+                    || scale.is_none()
+                    || precision.unwrap() > 38
+                    || scale.unwrap() > precision.unwrap()
+                {
+                    // illegal arguments for the decimal type
+                    Err(DataFusionError::Internal(format!(
+                        "The decimal data type {:?} is not supported: precision must be <= 38 and scale must be <= precision",
+                        sql_type
+                    )))
+                } else {
+                    Ok(DataType::Decimal(
+                        precision.unwrap() as usize,
+                        scale.unwrap() as usize,
+                    ))
+                }
+            }
             _ => Err(DataFusionError::NotImplemented(format!(
                 "The SQL data type {:?} is not implemented",
                 sql_type
             ))),
@@ -1996,6 +2015,16 @@ pub fn convert_data_type(sql: &SQLDataType) -> Result<DataType> {
         SQLDataType::Char(_) | SQLDataType::Varchar(_) => Ok(DataType::Utf8),
         SQLDataType::Timestamp => Ok(DataType::Timestamp(TimeUnit::Nanosecond, None)),
         SQLDataType::Date => Ok(DataType::Date32),
+        // TODO: also map DECIMAL(precision, scale) to DataType::Decimal here once
+        // casting to decimal is supported, mirroring the handling added above.
         other => Err(DataFusionError::NotImplemented(format!(
             "Unsupported SQL type {:?}",
             other
diff --git a/datafusion/tests/sql.rs b/datafusion/tests/sql.rs
index 640556cb27249..63b246f971494 100644
--- a/datafusion/tests/sql.rs
+++ b/datafusion/tests/sql.rs
@@ -412,6 +412,99 @@ async fn csv_query_group_by_int_min_max() -> Result<()> {
     Ok(())
 }
 
+#[tokio::test]
+async fn error_count_agg() -> Result<()> {
+    // TODO: add negative cases such as `select sum('123')`, which should fail
+    // planning with the aggregate coercion error message introduced in this change.
+    Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_with_decimal_by_sql() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    register_simple_aggregate_csv_with_decimal_by_sql(&mut ctx).await;
+    let sql = "SELECT c1 from aggregate_simple";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+----------+",
+        "| c1       |",
+        "+----------+",
+        "| 0.000010 |",
+        "| 0.000020 |",
+        "| 0.000020 |",
+        "| 0.000030 |",
+        "| 0.000030 |",
+        "| 0.000030 |",
+        "| 0.000040 |",
+        "| 0.000040 |",
+        "| 0.000040 |",
+        "| 0.000040 |",
+        "| 0.000050 |",
+        "| 0.000050 |",
+        "| 0.000050 |",
+        "| 0.000050 |",
+        "| 0.000050 |",
+        "+----------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+    Ok(())
+}
+
+#[tokio::test]
+async fn csv_query_with_decimal() -> Result<()> {
+    let mut ctx = ExecutionContext::new();
+    // the data type of c1 is decimal(10,6)
+    register_aggregate_simple_csv_use_decimal(&mut ctx).await?;
+    // query
+    let sql = "SELECT c1 from aggregate_simple";
+    let actual = execute_to_batches(&mut ctx, sql).await;
+    let expected = vec![
+        "+----------+",
+        "| c1       |",
+        "+----------+",
+        "| 0.000010 |",
+        "| 0.000020 |",
+        "| 0.000020 |",
+        "| 0.000030 |",
+        "| 0.000030 |",
+        "| 0.000030 |",
+        "| 0.000040 |",
+        "| 0.000040 |",
+        "| 0.000040 |",
+        "| 0.000040 |",
+        "| 0.000050 |",
+        "| 0.000050 |",
+        "| 0.000050 |",
+        "| 0.000050 |",
+        "| 0.000050 |",
+        "+----------+",
+    ];
+    assert_batches_eq!(expected, &actual);
+
+    // TODO: check MIN/MAX/COUNT/SUM over c1 once decimal aggregate execution is
+    // supported, e.g. "SELECT MIN(c1) from aggregate_simple".
+
+    Ok(())
+}
+
 #[tokio::test]
 async fn csv_query_group_by_float32() -> Result<()> {
     let mut ctx = ExecutionContext::new();
@@ -3761,6 +3854,47 @@ async fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> {
     Ok(())
 }
 
+async fn register_simple_aggregate_csv_with_decimal_by_sql(ctx: &mut ExecutionContext) {
+    let df = ctx
+        .sql(&format!(
+            "CREATE EXTERNAL TABLE aggregate_simple (
+            c1 DECIMAL(10,6) NOT NULL,
+            c2 DOUBLE NOT NULL,
+            c3 BOOLEAN NOT NULL
+            )
+            STORED AS CSV
+            WITH HEADER ROW
+            LOCATION 'tests/aggregate_simple.csv'"
+        ))
+        .await
+        .expect("Creating dataframe for CREATE EXTERNAL TABLE with decimal data type");
+
+    let results = df.collect().await.expect("Executing CREATE EXTERNAL TABLE");
+    assert!(
+        results.is_empty(),
+        "Expected no rows from executing CREATE EXTERNAL TABLE"
+    );
+}
+
+async fn register_aggregate_simple_csv_use_decimal(
+    ctx: &mut ExecutionContext,
+) -> Result<()> {
+    // It's not possible to use aggregate_test_100: not enough similar values to test grouping on floats
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("c1", DataType::Decimal(10, 6), false),
+        Field::new("c2", DataType::Float64, false),
+        Field::new("c3", DataType::Boolean, false),
+    ]));
+
+    ctx.register_csv(
+        "aggregate_simple",
+        "tests/aggregate_simple.csv",
+        CsvReadOptions::new().schema(&schema),
+    )
+    .await?;
+    Ok(())
+}
+
 async fn register_aggregate_simple_csv(ctx: &mut ExecutionContext) -> Result<()> {
     // It's not possible to use aggregate_test_100, not enought similar values to test grouping on floats
     let schema = Arc::new(Schema::new(vec![
@@ -5615,7 +5749,7 @@ async fn test_aggregation_with_bad_arguments() -> Result<()> {
     let logical_plan = ctx.create_logical_plan(sql)?;
     let physical_plan = ctx.create_physical_plan(&logical_plan).await;
     let err = physical_plan.unwrap_err();
-    assert_eq!(err.to_string(), "Error during planning: Invalid or wrong number of arguments passed to aggregate: 'COUNT(DISTINCT )'");
+    assert_eq!(err.to_string(), "The function Count expects 1 argument(s), but 0 were provided");
     Ok(())
 }