From 19952dedff9e1b1d609802dfc31bfbe174484333 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sat, 15 Aug 2020 12:01:44 +0200 Subject: [PATCH 01/13] Added support for multiple types per UDF argument. Currently a UDF's argument can only be a single type. This PR adds support for multiple types per argument, thus allowing users to register UDFs that can handle multiple types at once. --- rust/datafusion/src/execution/context.rs | 5 +- .../physical_plan/math_expressions.rs | 6 +- .../src/execution/physical_plan/mod.rs | 4 +- .../src/execution/physical_plan/udf.rs | 11 ++- rust/datafusion/src/logicalplan.rs | 8 +- .../datafusion/src/optimizer/type_coercion.rs | 82 ++++++++++++++++--- rust/datafusion/src/sql/planner.rs | 14 ++-- rust/datafusion/tests/sql.rs | 2 +- 8 files changed, 98 insertions(+), 34 deletions(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 8f92aae302b..6fb59b16148 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -1049,10 +1049,7 @@ mod tests { let my_add = ScalarFunction::new( "my_add", - vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - ], + vec![vec![DataType::Int32, DataType::Int32]], DataType::Int32, myfunc, ); diff --git a/rust/datafusion/src/execution/physical_plan/math_expressions.rs b/rust/datafusion/src/execution/physical_plan/math_expressions.rs index 97098d65b5e..c4ba6afc77a 100644 --- a/rust/datafusion/src/execution/physical_plan/math_expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/math_expressions.rs @@ -21,7 +21,7 @@ use crate::error::ExecutionError; use crate::execution::physical_plan::udf::ScalarFunction; use arrow::array::{Array, ArrayRef, Float64Array, Float64Builder}; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::DataType; use std::sync::Arc; @@ -29,7 +29,7 @@ macro_rules! 
math_unary_function { ($NAME:expr, $FUNC:ident) => { ScalarFunction::new( $NAME, - vec![Field::new("n", DataType::Float64, true)], + vec![vec![DataType::Float64]], DataType::Float64, Arc::new(|args: &[ArrayRef]| { let n = &args[0].as_any().downcast_ref::(); @@ -86,7 +86,7 @@ mod tests { execution::context::ExecutionContext, logicalplan::{col, sqrt, LogicalPlanBuilder}, }; - use arrow::datatypes::Schema; + use arrow::datatypes::{Field, Schema}; #[test] fn cast_i8_input() -> Result<()> { diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs index 4b66955d036..c2ba055e185 100644 --- a/rust/datafusion/src/execution/physical_plan/mod.rs +++ b/rust/datafusion/src/execution/physical_plan/mod.rs @@ -26,7 +26,7 @@ use crate::error::Result; use crate::execution::context::ExecutionContextState; use crate::logicalplan::{LogicalPlan, ScalarValue}; use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; use arrow::{ compute::kernels::length::length, record_batch::{RecordBatch, RecordBatchReader}, @@ -96,7 +96,7 @@ pub trait Accumulator: Debug { pub fn scalar_functions() -> Vec { let mut udfs = vec![ScalarFunction::new( "length", - vec![Field::new("n", DataType::Utf8, true)], + vec![vec![DataType::Utf8]], DataType::UInt32, Arc::new(|args: &[ArrayRef]| Ok(Arc::new(length(args[0].as_ref())?))), )]; diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs b/rust/datafusion/src/execution/physical_plan/udf.rs index 944b5c9bef3..370faf3500f 100644 --- a/rust/datafusion/src/execution/physical_plan/udf.rs +++ b/rust/datafusion/src/execution/physical_plan/udf.rs @@ -20,7 +20,7 @@ use std::fmt; use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Schema}; use crate::error::Result; use crate::execution::physical_plan::PhysicalExpr; @@ -37,8 +37,11 @@ pub type ScalarUdf = Arc Result + Send + Sync>; pub struct ScalarFunction { /// Function name pub name: String, - /// Function argument meta-data - pub args: Vec, + /// Set of valid argument types. + /// The first dimension (0) represents specific combinations of valid argument types + /// The second dimension (1) represents the types of each argument. + /// For example, [[t1, t2]] is a function of 2 arguments that only accept t1 on the first arg and t2 on the second + pub args: Vec>, /// Return type pub return_type: DataType, /// UDF implementation @@ -60,7 +63,7 @@ impl ScalarFunction { /// Create a new ScalarFunction pub fn new( name: &str, - args: Vec, + args: Vec>, return_type: DataType, fun: ScalarUdf, ) -> Self { diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index ee3c17aedd2..317f904d1f4 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -47,8 +47,8 @@ pub enum FunctionType { pub struct FunctionMeta { /// Function name name: String, - /// Function arguments - args: Vec, + /// Function arguments. 
Each argument i can be one of the types of args[i], with respective priority - args: Vec<Field>, + args: Vec<Vec<DataType>>, /// Function return type return_type: DataType, /// Function type (Scalar or Aggregate) @@ -59,7 +59,7 @@ impl FunctionMeta { #[allow(missing_docs)] pub fn new( name: String, - args: Vec<Field>, + args: Vec<Vec<DataType>>, return_type: DataType, function_type: FunctionType, ) -> Self { @@ -75,7 +75,7 @@ impl FunctionMeta { &self.name } /// Getter for the arg list - pub fn args(&self) -> &Vec<Field> { + pub fn args(&self) -> &Vec<Vec<DataType>> { &self.args } /// Getter for the `DataType` the function returns diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index a485423975d..1e4e43ec4c7 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -23,7 +23,7 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; -use arrow::datatypes::Schema; +use arrow::datatypes::{DataType, Schema}; use crate::error::{ExecutionError, Result}; use crate::execution::physical_plan::udf::ScalarFunction; @@ -80,16 +80,32 @@ impl TypeCoercionRule { .get(name) { Some(func_meta) => { - for i in 0..expressions.len() { - let field = &func_meta.args[i]; - let actual_type = expressions[i].get_type(schema)?; - let required_type = field.data_type(); - if &actual_type != required_type { - let super_type = - utils::get_supertype(&actual_type, required_type)?; - expressions[i] = - expressions[i].cast_to(&super_type, schema)? - }; + // compute the current types and expressions + let current_types = expressions + .iter() + .map(|e| e.get_type(schema)) + .collect::<Result<Vec<_>>>()?; + + let new = if func_meta.args.contains(&current_types) { + Some(expressions) + } else { + maybe_rewrite( + &expressions, + &current_types, + &schema, + &func_meta.args, + )?
+ }; + + if let Some(args) = new { + expressions = args; + } else { + return Err(ExecutionError::General(format!( + "The scalar function '{}' requires one of the type variants {:?}, but the arguments of type '{:?}' cannot be safely cast to any of them.", + func_meta.name, + func_meta.args, + current_types, + ))); } } _ => { @@ -147,6 +163,50 @@ impl OptimizerRule for TypeCoercionRule { } } +/// tries to re-cast expressions under schema based on the set of valid signatures +fn maybe_rewrite( + expressions: &Vec<Expr>, + current_types: &Vec<DataType>, + schema: &Schema, + signature: &Vec<Vec<DataType>>, +) -> Result<Option<Vec<Expr>>> { + // for each set of valid signatures, try to coerse all expressions to one of them + let mut new_expressions: Option<Vec<Expr>> = None; + for valid_types in signature { + // for each option, try to coerse all arguments to it + if let Some(types) = maybe_coerse(valid_types, &current_types) { + // yes: let's re-write the expressions + new_expressions = Some( + expressions + .iter() + .enumerate() + .map(|(i, expr)| expr.cast_to(&types[i], schema)) + .collect::<Result<Vec<_>>>()?, + ); + break; + } + // we cannot: try the next + } + Ok(new_expressions) +} + +/// Try to coerse current_types into valid_types +fn maybe_coerse( + valid_types: &Vec<DataType>, + current_types: &Vec<DataType>, +) -> Option<Vec<DataType>> { + let mut super_type = Vec::with_capacity(valid_types.len()); + for (i, valid_type) in valid_types.iter().enumerate() { + let current_type = &current_types[i]; + if let Ok(t) = utils::get_supertype(current_type, valid_type) { + super_type.push(t) + } else { + return None; + } + } + Some(super_type) +} + #[cfg(test)] mod tests { use super::*; diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 16da4527557..8cc4023dd6a 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -523,10 +523,14 @@ impl SqlToRel { let mut safe_args: Vec<Expr> = vec![]; for i in 0..rex_args.len() { - safe_args.push( - rex_args[i] - .cast_to(fm.args()[i].data_type(), schema)?, - ); + let expr = if fm.args()[i] + .contains(&rex_args[i].get_type(schema)?) + { + rex_args[i].clone() + } else { + rex_args[i].cast_to(&fm.args()[i][0], schema)? + }; + safe_args.push(expr) } Ok(Expr::ScalarFunction { @@ -912,7 +916,7 @@ mod tests { match name { "sqrt" => Some(Arc::new(FunctionMeta::new( "sqrt".to_string(), - vec![Field::new("n", DataType::Float64, false)], + vec![vec![DataType::Float64]], DataType::Float64, FunctionType::Scalar, ))), diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index d2d5349d5e4..f66793746dd 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -207,7 +207,7 @@ fn create_ctx() -> Result<ExecutionContext> { // register a custom UDF ctx.register_udf(ScalarFunction::new( "custom_sqrt", - vec![Field::new("n", DataType::Float64, true)], + vec![vec![DataType::Float64]], DataType::Float64, Arc::new(custom_sqrt), )); From 82a803292e0773f53fadca3e4800a5a8a511e9ff Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sat, 15 Aug 2020 12:41:27 +0200 Subject: [PATCH 02/13] Made math expressions accept float32.
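To make the new signature shape concrete: with PATCH 01 a UDF declares one inner vec per accepted combination of argument types, and registration looks roughly like the sketch below. This is illustrative only, not part of the diff; it uses the ScalarFunction::new and register_udf APIs shown above, but the crate paths are assumed from this tree, and "my_double" with its placeholder body is hypothetical.

use std::sync::Arc;
use arrow::array::ArrayRef;
use arrow::datatypes::DataType;
use datafusion::error::Result;
use datafusion::execution::context::ExecutionContext;
use datafusion::execution::physical_plan::udf::ScalarFunction;

fn main() -> Result<()> {
    let mut ctx = ExecutionContext::new();
    ctx.register_udf(ScalarFunction::new(
        "my_double",
        // one inner vec per accepted combination of argument types,
        // tried in the order they are declared
        vec![vec![DataType::Float32], vec![DataType::Float64]],
        DataType::Float64,
        // placeholder body: returns its input array unchanged
        Arc::new(|args: &[ArrayRef]| Ok(args[0].clone())),
    ));
    Ok(())
}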
--- .../physical_plan/math_expressions.rs | 82 ++++++++++++++----- 1 file changed, 61 insertions(+), 21 deletions(-) diff --git a/rust/datafusion/src/execution/physical_plan/math_expressions.rs b/rust/datafusion/src/execution/physical_plan/math_expressions.rs index c4ba6afc77a..1c97c2139d3 100644 --- a/rust/datafusion/src/execution/physical_plan/math_expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/math_expressions.rs @@ -20,36 +20,62 @@ use crate::error::ExecutionError; use crate::execution::physical_plan::udf::ScalarFunction; -use arrow::array::{Array, ArrayRef, Float64Array, Float64Builder}; +use arrow::array::{Array, ArrayRef}; +use arrow::array::{Float32Array, Float64Array}; use arrow::datatypes::DataType; use std::sync::Arc; +macro_rules! compute_op { + ($ARRAY:expr, $FUNC:ident, $TYPE:ident) => {{ + let mut builder = <$TYPE>::builder($ARRAY.len()); + for i in 0..$ARRAY.len() { + if $ARRAY.is_null(i) { + builder.append_null()?; + } else { + builder.append_value($ARRAY.value(i).$FUNC())?; + } + } + Ok(Arc::new(builder.finish())) + }}; +} + +macro_rules! downcast_compute_op { + ($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident) => {{ + let n = $ARRAY.as_any().downcast_ref::<$TYPE>(); + match n { + Some(array) => compute_op!(array, $FUNC, $TYPE), + _ => Err(ExecutionError::General(format!( + "Invalid data type for {}", + $NAME + ))), + } + }}; +} + +macro_rules! unary_primitive_array_op { + ($ARRAY:expr, $NAME:expr, $FUNC:ident) => {{ + match ($ARRAY).data_type() { + DataType::Float32 => downcast_compute_op!($ARRAY, $NAME, $FUNC, Float32Array), + DataType::Float64 => downcast_compute_op!($ARRAY, $NAME, $FUNC, Float64Array), + other => Err(ExecutionError::General(format!( + "Unsupported data type {:?} for function {}", + other, $NAME, + ))), + } + }}; +} + macro_rules! math_unary_function { ($NAME:expr, $FUNC:ident) => { ScalarFunction::new( $NAME, - vec![vec![DataType::Float64]], + // order: from faster to slower + vec![vec![DataType::Float32], vec![DataType::Float64]], DataType::Float64, Arc::new(|args: &[ArrayRef]| { - let n = &args[0].as_any().downcast_ref::(); - match n { - Some(array) => { - let mut builder = Float64Builder::new(array.len()); - for i in 0..array.len() { - if array.is_null(i) { - builder.append_null()?; - } else { - builder.append_value(array.value(i).$FUNC())?; - } - } - Ok(Arc::new(builder.finish())) - } - _ => Err(ExecutionError::General(format!( - "Invalid data type for {}", - $NAME - ))), - } + let array = &args[0]; + unary_primitive_array_op!(array, $NAME, $FUNC) }), ) }; @@ -96,7 +122,7 @@ mod tests { .build()?; let ctx = ExecutionContext::new(); let plan = ctx.optimize(&plan)?; - let expected = "Projection: sqrt(CAST(#c0 AS Float64))\ + let expected = "Projection: sqrt(CAST(#c0 AS Float32))\ \n TableScan: projection=Some([0])"; assert_eq!(format!("{:?}", plan), expected); Ok(()) @@ -115,4 +141,18 @@ mod tests { assert_eq!(format!("{:?}", plan), expected); Ok(()) } + + #[test] + fn no_cast_f32_input() -> Result<()> { + let schema = Schema::new(vec![Field::new("c0", DataType::Float32, true)]); + let plan = LogicalPlanBuilder::scan("", "", &schema, None)? + .project(vec![sqrt(col("c0"))])? + .build()?; + let ctx = ExecutionContext::new(); + let plan = ctx.optimize(&plan)?; + let expected = "Projection: sqrt(#c0)\ + \n TableScan: projection=Some([0])"; + assert_eq!(format!("{:?}", plan), expected); + Ok(()) + } } From 9bba8c4211cc3c8b600a9bfa686d924aff3e9a8b Mon Sep 17 00:00:00 2001 From: "Jorge C. 
Leitao" Date: Sun, 16 Aug 2020 16:12:35 +0200 Subject: [PATCH 03/13] Renamed attribute. --- rust/datafusion/src/execution/context.rs | 2 +- rust/datafusion/src/execution/physical_plan/udf.rs | 8 ++++---- rust/datafusion/src/logicalplan.rs | 10 +++++----- rust/datafusion/src/optimizer/type_coercion.rs | 6 +++--- rust/datafusion/src/sql/planner.rs | 4 ++-- 5 files changed, 15 insertions(+), 15 deletions(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 6fb59b16148..64a279b6bd0 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -516,7 +516,7 @@ impl SchemaProvider for ExecutionContextState { .map(|f| { Arc::new(FunctionMeta::new( name.to_owned(), - f.args.clone(), + f.arg_types.clone(), f.return_type.clone(), FunctionType::Scalar, )) diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs b/rust/datafusion/src/execution/physical_plan/udf.rs index 370faf3500f..4d63b8a0309 100644 --- a/rust/datafusion/src/execution/physical_plan/udf.rs +++ b/rust/datafusion/src/execution/physical_plan/udf.rs @@ -41,7 +41,7 @@ pub struct ScalarFunction { /// The first dimension (0) represents specific combinations of valid argument types /// The second dimension (1) represents the types of each argument. /// For example, [[t1, t2]] is a function of 2 arguments that only accept t1 on the first arg and t2 on the second - pub args: Vec>, + pub arg_types: Vec>, /// Return type pub return_type: DataType, /// UDF implementation @@ -52,7 +52,7 @@ impl Debug for ScalarFunction { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("ScalarFunction") .field("name", &self.name) - .field("args", &self.args) + .field("arg_types", &self.arg_types) .field("return_type", &self.return_type) .field("fun", &"") .finish() @@ -63,13 +63,13 @@ impl ScalarFunction { /// Create a new ScalarFunction pub fn new( name: &str, - args: Vec>, + arg_types: Vec>, return_type: DataType, fun: ScalarUdf, ) -> Self { Self { name: name.to_owned(), - args, + arg_types, return_type, fun, } diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index 317f904d1f4..38c86a2e951 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -48,7 +48,7 @@ pub struct FunctionMeta { /// Function name name: String, /// Function arguments. 
Each argument i can be one of the types of args[i], with respective priority - args: Vec<Vec<DataType>>, + arg_types: Vec<Vec<DataType>>, /// Function return type return_type: DataType, /// Function type (Scalar or Aggregate) @@ -59,13 +59,13 @@ impl FunctionMeta { #[allow(missing_docs)] pub fn new( name: String, - args: Vec<Vec<DataType>>, + arg_types: Vec<Vec<DataType>>, return_type: DataType, function_type: FunctionType, ) -> Self { FunctionMeta { name, - args, + arg_types, return_type, function_type, } } @@ -75,8 +75,8 @@ &self.name } /// Getter for the arg list - pub fn args(&self) -> &Vec<Vec<DataType>> { - &self.args + pub fn arg_types(&self) -> &Vec<Vec<DataType>> { + &self.arg_types } /// Getter for the `DataType` the function returns pub fn return_type(&self) -> &DataType { diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index 1e4e43ec4c7..4aedc4d833e 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -86,14 +86,14 @@ impl TypeCoercionRule { .map(|e| e.get_type(schema)) .collect::<Result<Vec<_>>>()?; - let new = if func_meta.args.contains(&current_types) { + let new = if func_meta.arg_types.contains(&current_types) { Some(expressions) } else { maybe_rewrite( &expressions, &current_types, &schema, - &func_meta.args, + &func_meta.arg_types, )? }; @@ -103,7 +103,7 @@ return Err(ExecutionError::General(format!( "The scalar function '{}' requires one of the type variants {:?}, but the arguments of type '{:?}' cannot be safely cast to any of them.", func_meta.name, - func_meta.args, + func_meta.arg_types, current_types, ))); } diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 8cc4023dd6a..23185284459 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -523,12 +523,12 @@ impl SqlToRel { let mut safe_args: Vec<Expr> = vec![]; for i in 0..rex_args.len() { - let expr = if fm.args()[i] + let expr = if fm.arg_types()[i] .contains(&rex_args[i].get_type(schema)?) { rex_args[i].clone() } else { - rex_args[i].cast_to(&fm.args()[i][0], schema)? + rex_args[i].cast_to(&fm.arg_types()[i][0], schema)? }; safe_args.push(expr) } From 1b0f2b6cfb753e3049518be167bcc78cf3765097 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 16 Aug 2020 16:18:56 +0200 Subject: [PATCH 04/13] Spell and cleanup --- .../src/execution/physical_plan/udf.rs | 3 --- rust/datafusion/src/logicalplan.rs | 5 ++++- rust/datafusion/src/optimizer/type_coercion.rs | 18 ++++++++---------- 3 files changed, 12 insertions(+), 14 deletions(-) diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs b/rust/datafusion/src/execution/physical_plan/udf.rs index 4d63b8a0309..5601f97c8bc 100644 --- a/rust/datafusion/src/execution/physical_plan/udf.rs +++ b/rust/datafusion/src/execution/physical_plan/udf.rs @@ -38,9 +38,6 @@ pub struct ScalarFunction { /// Function name pub name: String, /// Set of valid argument types. - /// The first dimension (0) represents specific combinations of valid argument types - /// The second dimension (1) represents the types of each argument.
- /// For example, [[t1, t2]] is a function of 2 arguments that only accept t1 on the first arg and t2 on the second pub arg_types: Vec<Vec<DataType>>, /// Return type pub return_type: DataType, /// UDF implementation diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index 38c86a2e951..f03e4ca640c 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -47,7 +47,10 @@ pub enum FunctionType { pub struct FunctionMeta { /// Function name name: String, - /// Function arguments. Each argument i can be one of the types of args[i], with respective priority + /// Function argument types + /// The first dimension (0) represents specific combinations of valid argument types + /// The second dimension (1) represents the types of each argument. + /// For example, [[t1, t2]] is a function of 2 arguments that only accept t1 on the first arg and t2 on the second arg_types: Vec<Vec<DataType>>, /// Function return type return_type: DataType, /// Function type (Scalar or Aggregate) diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index 4aedc4d833e..64babdb41b6 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -170,28 +170,26 @@ fn maybe_rewrite( schema: &Schema, signature: &Vec<Vec<DataType>>, ) -> Result<Option<Vec<Expr>>> { - // for each set of valid signatures, try to coerse all expressions to one of them - let mut new_expressions: Option<Vec<Expr>> = None; + // for each set of valid signatures, try to coerce all expressions to one of them for valid_types in signature { - // for each option, try to coerse all arguments to it - if let Some(types) = maybe_coerse(valid_types, &current_types) { + // for each option, try to coerce all arguments to it + if let Some(types) = maybe_coerce(valid_types, &current_types) { // yes: let's re-write the expressions - new_expressions = Some( + return Ok(Some( expressions .iter() .enumerate() .map(|(i, expr)| expr.cast_to(&types[i], schema)) .collect::<Result<Vec<_>>>()?, - ); - break; + )) } // we cannot: try the next } - Ok(new_expressions) + Ok(None) } -/// Try to coerse current_types into valid_types -fn maybe_coerse( +/// Try to coerce current_types into valid_types +fn maybe_coerce( valid_types: &Vec<DataType>, current_types: &Vec<DataType>, ) -> Option<Vec<DataType>> { From 286b25ca199b2cb7b3e6e23602716e6e24aacf93 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 16 Aug 2020 17:51:06 +0200 Subject: [PATCH 05/13] Added testing to type_coercion rules.
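The rule these tests exercise: a candidate signature is accepted only if every actual argument type has a common supertype with the corresponding valid type, and the first accepted signature wins. A self-contained toy sketch of that matching step follows (with a simplified stand-in for utils::get_supertype that only knows booleans and unsigned integers; illustrative, not the DataFusion implementation):

#[derive(Clone, Copy, PartialEq, Debug)]
enum Ty {
    Bool,
    UInt8,
    UInt16,
    UInt32,
}

// toy supertype: the wider unsigned integer wins; bool only pairs with bool
fn supertype(a: Ty, b: Ty) -> Option<Ty> {
    fn rank(t: Ty) -> Option<u8> {
        match t {
            Ty::UInt8 => Some(0),
            Ty::UInt16 => Some(1),
            Ty::UInt32 => Some(2),
            Ty::Bool => None,
        }
    }
    match (rank(a), rank(b)) {
        (Some(x), Some(y)) => Some(if x >= y { a } else { b }),
        (None, None) => Some(Ty::Bool),
        _ => None, // bool with an integer: no common supertype
    }
}

// mirrors maybe_coerce: each current type must reach a supertype with its valid type
fn coerce(valid: &[Ty], current: &[Ty]) -> Option<Vec<Ty>> {
    if valid.len() != current.len() {
        return None;
    }
    valid.iter().zip(current).map(|(v, c)| supertype(*c, *v)).collect()
}

fn main() {
    // (u8, u16) matches the signature (u16, u16) by widening the first argument
    assert_eq!(
        coerce(&[Ty::UInt16, Ty::UInt16], &[Ty::UInt8, Ty::UInt16]),
        Some(vec![Ty::UInt16, Ty::UInt16])
    );
    // bool cannot be widened to an integer, so this signature is rejected
    assert_eq!(coerce(&[Ty::Bool, Ty::UInt16], &[Ty::UInt8, Ty::UInt16]), None);
}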
--- .../datafusion/src/optimizer/type_coercion.rs | 141 +++++++++++++++++- 1 file changed, 140 insertions(+), 1 deletion(-) diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index 64babdb41b6..f12c2275ce3 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -181,7 +181,7 @@ fn maybe_rewrite( .enumerate() .map(|(i, expr)| expr.cast_to(&types[i], schema)) .collect::<Result<Vec<_>>>()?, - )) + )); } // we cannot: try the next } @@ -345,4 +345,143 @@ mod tests { assert_eq!(expected, format!("{:?}", expr2)); } + + #[test] + fn test_maybe_coerce() -> Result<()> { + // this vec contains: arg1, arg2, expected result + let cases = vec![ + // 2 entries, same values + ( + vec![DataType::UInt8, DataType::UInt16], + vec![DataType::UInt8, DataType::UInt16], + Some(vec![DataType::UInt8, DataType::UInt16]), + ), + // 2 entries, can coerce values + ( + vec![DataType::UInt16, DataType::UInt16], + vec![DataType::UInt8, DataType::UInt16], + Some(vec![DataType::UInt16, DataType::UInt16]), + ), + // 0 entries, all good + (vec![], vec![], Some(vec![])), + // 2 entries, can't coerce + ( + vec![DataType::Boolean, DataType::UInt16], + vec![DataType::UInt8, DataType::UInt16], + None, + ), + // coercing u16 up to the valid type u32 is possible + ( + vec![DataType::Boolean, DataType::UInt32], + vec![DataType::Boolean, DataType::UInt16], + Some(vec![DataType::Boolean, DataType::UInt32]), + ), + ]; + + for case in cases { + assert_eq!(maybe_coerce(&case.0, &case.1), case.2) + } + Ok(()) + } + + #[test] + fn test_maybe_rewrite() -> Result<()> { + // create a schema + let schema = |t: Vec<DataType>| { + Schema::new( + t.iter() + .enumerate() + .map(|(i, t)| Field::new(&*format!("c{}", i), t.clone(), true)) + .collect(), + ) + }; + + // create a vector of expressions + let expressions = |t: Vec<DataType>, schema| -> Result<Vec<Expr>> { + t.iter() + .enumerate() + .map(|(i, t)| col(&*format!("c{}", i)).cast_to(&t, &schema)) + .collect::<Result<Vec<Expr>>>() + }; + + // map expr + schema to types + let current_types = |expressions: &Vec<Expr>, schema| -> Result<Vec<DataType>> { + Ok(expressions + .iter() + .map(|e| e.get_type(&schema)) + .collect::<Result<Vec<_>>>()?) + }; + + // create a case: input + expected result + let case = |observed: Vec<DataType>, + valid, + expected: Option<Vec<DataType>>| + -> Result<_> { + let schema = schema(observed.clone()); + let expr = expressions(observed, schema.clone())?; + let expected = if let Some(e) = expected { + // expressions re-written as cast + Some(expressions(e, schema.clone())?)
+ } else { + None + }; + Ok(( + expr.clone(), + current_types(&expr, schema.clone())?, + schema, + valid, + expected, + )) + }; + + let cases = vec![ + // no conversion -> all good + case(vec![], vec![vec![]], Some(vec![]))?, + // u16 -> u32 + case( + vec![DataType::UInt16, DataType::UInt32], + vec![vec![DataType::UInt32, DataType::UInt32]], + Some(vec![DataType::UInt32, DataType::UInt32]), + )?, + // same type + case( + vec![DataType::UInt16, DataType::UInt32], + vec![vec![DataType::UInt16, DataType::UInt32]], + Some(vec![DataType::UInt16, DataType::UInt32]), + )?, + // we do not know how to cast bool to UInt16 => fail + case( + vec![DataType::Boolean, DataType::UInt32], + vec![vec![DataType::UInt16, DataType::UInt32]], + None, + )?, + // we do not know how to cast (bool,u16) to (u16,u32), + // but we know to cast to (bool,u32) + case( + vec![DataType::Boolean, DataType::UInt16], + vec![ + vec![DataType::UInt16, DataType::UInt32], + vec![DataType::Boolean, DataType::UInt32], + ], + Some(vec![DataType::Boolean, DataType::UInt32]), + )?, + // we do not know how to cast (bool,u16) to (u16,u32) nor (u32,u16) + case( + vec![DataType::Boolean, DataType::UInt32], + vec![ + vec![DataType::UInt16, DataType::UInt32], + vec![DataType::UInt32, DataType::UInt16], + ], + None, + )?, + ]; + + for (i, case) in cases.iter().enumerate() { + if maybe_rewrite(&case.0, &case.1, &case.2, &case.3)? != case.4 { + assert_eq!(maybe_rewrite(&case.0, &case.1, &case.2, &case.3)?, case.4); + return Err(ExecutionError::General(format!("case {} failed", i))); + } + } + Ok(()) + } } From 52218c852b7b3016afeaf95d8a46d6deea89d231 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 16 Aug 2020 18:11:18 +0200 Subject: [PATCH 06/13] Removed type coercion from the planner. This operation is already performed by the optimizer, where it is more complete. --- rust/datafusion/src/sql/planner.rs | 34 ++++++++++++++++-------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 23185284459..30fabaa3da4 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -515,27 +515,29 @@ impl SqlToRel { } _ => match self.schema_provider.get_function_meta(&name) { Some(fm) => { - let rex_args = function + let args = function .args .iter() .map(|a| self.sql_to_rex(a, schema)) .collect::<Result<Vec<Expr>>>()?; - - let mut safe_args: Vec<Expr> = vec![]; - for i in 0..rex_args.len() { - let expr = if fm.arg_types()[i] - .contains(&rex_args[i].get_type(schema)?) - { - rex_args[i].clone() - } else { - rex_args[i].cast_to(&fm.arg_types()[i][0], schema)?
- }; - safe_args.push(expr) + let expected_args = match fm.arg_types().len() { + 0 => 0, + _ => fm.arg_types()[0].len(), + }; + let current_args = args.len(); + + if current_args != expected_args { + return Err(ExecutionError::General( + format!("The function '{}' expects {} arguments, but {} were passed", + name, + expected_args, + current_args, + ))); } Ok(Expr::ScalarFunction { name: name.clone(), - args: safe_args, + args, return_type: fm.return_type().clone(), }) } @@ -602,7 +604,7 @@ mod tests { fn select_scalar_func_with_literal_no_relation() { quick_test( "SELECT sqrt(9)", - "Projection: sqrt(CAST(Int64(9) AS Float64))\ + "Projection: sqrt(Int64(9))\ \n EmptyRelation", ); } @@ -740,7 +742,7 @@ mod tests { #[test] fn select_scalar_func() { let sql = "SELECT sqrt(age) FROM person"; - let expected = "Projection: sqrt(CAST(#age AS Float64))\ + let expected = "Projection: sqrt(#age)\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -748,7 +750,7 @@ mod tests { #[test] fn select_aliased_scalar_func() { let sql = "SELECT sqrt(age) AS square_people FROM person"; - let expected = "Projection: sqrt(CAST(#age AS Float64)) AS square_people\ + let expected = "Projection: sqrt(#age) AS square_people\ \n TableScan: person projection=None"; quick_test(sql, expected); } From e6dad6b56ba8087f64d5a558174767a25032f9b4 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 16 Aug 2020 18:50:03 +0200 Subject: [PATCH 07/13] Added end-to-end tests on custom UDFs with multiple types. --- rust/datafusion/tests/sql.rs | 145 ++++++++++++++++++++++++++++++++++- 1 file changed, 144 insertions(+), 1 deletion(-) diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index f66793746dd..9a1807f2fa2 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -201,10 +201,92 @@ fn csv_query_avg_sqrt() -> Result<()> { Ok(()) } +#[test] +fn csv_query_avg_custom_udf_f64_f64() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx)?; + // c12 is f64 + let sql = "SELECT avg(custom_add(c12, c12)) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql); + + // perform equivalent calculation + let sql = "SELECT avg(c12 + c12) FROM aggregate_test_100"; + let expected = execute(&mut ctx, sql); + + // verify equality + assert_eq!(actual.join("\n"), expected.join("\n")); + Ok(()) +} + +#[test] +fn csv_query_avg_custom_udf_f32_f64() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx)?; + // c11 is f32 + // c12 is f64 + let sql = "SELECT avg(custom_add(c11, c12)) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql); + + // function evaluted as f32,f64 returns a constant 3264.0 + let expected = "3264.0".to_string(); + + // verify equality + assert_eq!(actual.join("\n"), expected); + Ok(()) +} + +#[test] +fn csv_query_avg_custom_udf_f32_f32() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx)?; + // c11 is f32 + let sql = "SELECT avg(custom_add(c11, c11)) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql); + + // function evaluted as f32,f32 returns 3232.0 + let expected = "3232.0".to_string(); + + // verify equality + assert_eq!(actual.join("\n"), expected); + Ok(()) +} + +#[test] +fn csv_query_avg_custom_udf_i8() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx)?; + // c3 is i8, castable to float32 + let sql = "SELECT avg(custom_add(c3, c3)) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql); + + // 
function evaluated as f32,f32 returns a constant 3232.0 + let expected = "3232.0".to_string(); + + // verify equality + assert_eq!(actual.join("\n"), expected); + Ok(()) +} + +#[test] +fn csv_query_avg_custom_udf_utf8() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx)?; + // utf8 is currently convertible to any type. See https://issues.apache.org/jira/browse/ARROW-4957 + let sql = "SELECT avg(custom_add(c1, c1)) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql); + + // function evaluated on any other type returns a constant 1111.0 + let expected = "1111.0".to_string(); + + // verify equality + assert_eq!(actual.join("\n"), expected); + Ok(()) +} + fn create_ctx() -> Result<ExecutionContext> { let mut ctx = ExecutionContext::new(); - // register a custom UDF + // register a UDF of 1 argument ctx.register_udf(ScalarFunction::new( "custom_sqrt", vec![vec![DataType::Float64]], @@ -212,6 +294,18 @@ fn create_ctx() -> Result<ExecutionContext> { Arc::new(custom_sqrt), )); + // register a UDF of two arguments + ctx.register_udf(ScalarFunction::new( + "custom_add", + vec![ + vec![DataType::Float32, DataType::Float32], + vec![DataType::Float32, DataType::Float64], + vec![DataType::Float64, DataType::Float64], + ], + DataType::Float64, + Arc::new(custom_add), + )); + Ok(ctx) } @@ -232,6 +326,55 @@ fn custom_sqrt(args: &[ArrayRef]) -> Result<ArrayRef> { Ok(Arc::new(builder.finish())) } +fn custom_add(args: &[ArrayRef]) -> Result<ArrayRef> { + match (args[0].data_type(), args[1].data_type()) { + (DataType::Float64, DataType::Float64) => { + let input1 = &args[0] + .as_any() + .downcast_ref::<Float64Array>() + .expect("cast failed"); + let input2 = &args[1] + .as_any() + .downcast_ref::<Float64Array>() + .expect("cast failed"); + + let mut builder = Float64Builder::new(input1.len()); + for i in 0..input1.len() { + if input1.is_null(i) || input2.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(input1.value(i) + input2.value(i))?; + } + } + Ok(Arc::new(builder.finish())) + } + (DataType::Float32, DataType::Float32) => { + // the f32,f32 case returns a constant vector (just to be different) + let mut builder = Float64Builder::new(args[0].len()); + for _ in 0..args[0].len() { + builder.append_value(3232.0)?; + } + Ok(Arc::new(builder.finish())) + } + (DataType::Float32, DataType::Float64) => { + // the f32,f64 case returns a constant vector (just to be different) + let mut builder = Float64Builder::new(args[0].len()); + for _ in 0..args[0].len() { + builder.append_value(3264.0)?; + } + Ok(Arc::new(builder.finish())) + } + (_, _) => { + // all other cases return a constant vector (just to be different) + let mut builder = Float64Builder::new(args[0].len()); + for _ in 0..args[0].len() { + builder.append_value(1111.0)?; + } + Ok(Arc::new(builder.finish())) + } + } +} + #[test] fn csv_query_avg() -> Result<()> { let mut ctx = ExecutionContext::new(); From 3ce50686e78c5310c412c7f538e7b08431826345 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 16 Aug 2020 18:55:00 +0200 Subject: [PATCH 08/13] Moved comment around.
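For reference, the i8 end-to-end test in PATCH 07 depends on the coercion rewrite working out as follows (hand-derived from the rule and the signatures registered above, not an actual plan dump):

// c3 is Int8, so (Int8, Int8) matches no signature exactly; the optimizer tries
// [Float32, Float32] first, get_supertype(Int8, Float32) succeeds for both
// arguments, and the call is rewritten to
//     custom_add(CAST(#c3 AS Float32), CAST(#c3 AS Float32))
// which dispatches to the (Float32, Float32) arm and yields the constant 3232.0.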
--- rust/datafusion/src/execution/physical_plan/udf.rs | 3 +++ rust/datafusion/src/logicalplan.rs | 3 --- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs b/rust/datafusion/src/execution/physical_plan/udf.rs index 5601f97c8bc..4d63b8a0309 100644 --- a/rust/datafusion/src/execution/physical_plan/udf.rs +++ b/rust/datafusion/src/execution/physical_plan/udf.rs @@ -38,6 +38,9 @@ pub struct ScalarFunction { /// Function name pub name: String, /// Set of valid argument types. + /// The first dimension (0) represents specific combinations of valid argument types + /// The second dimension (1) represents the types of each argument. + /// For example, [[t1, t2]] is a function of 2 arguments that only accept t1 on the first arg and t2 on the second pub arg_types: Vec>, /// Return type pub return_type: DataType, diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index f03e4ca640c..5cd75324a55 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -48,9 +48,6 @@ pub struct FunctionMeta { /// Function name name: String, /// Function argument types - /// The first dimension (0) represents specific combinations of valid argument types - /// The second dimension (1) represents the types of each argument. - /// For example, [[t1, t2]] is a function of 2 arguments that only accept t1 on the first arg and t2 on the second arg_types: Vec>, /// Function return type return_type: DataType, From 6f370e8ec85739eab52f18d1188d4f8db4e335c6 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sat, 15 Aug 2020 11:15:26 +0200 Subject: [PATCH 09/13] Split AggregateExpr in PhysicalExpr + Aggregate --- .../execution/physical_plan/expressions.rs | 94 +++++++++++++++---- .../execution/physical_plan/hash_aggregate.rs | 4 +- .../src/execution/physical_plan/mod.rs | 12 +-- 3 files changed, 86 insertions(+), 24 deletions(-) diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index dff913b9162..1c64c3fe44b 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -24,7 +24,9 @@ use std::sync::Arc; use crate::error::{ExecutionError, Result}; use crate::execution::physical_plan::common::get_scalar_value; -use crate::execution::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; +use crate::execution::physical_plan::{ + Accumulator, Aggregate, AggregateExpr, PhysicalExpr, +}; use crate::logicalplan::{Operator, ScalarValue}; use arrow::array::{ ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, @@ -107,7 +109,7 @@ impl Sum { } } -impl AggregateExpr for Sum { +impl PhysicalExpr for Sum { fn data_type(&self, input_schema: &Schema) -> Result { match self.expr.data_type(input_schema)? 
{ DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { @@ -125,10 +127,22 @@ impl AggregateExpr for Sum { } } - fn evaluate_input(&self, batch: &RecordBatch) -> Result { + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(false) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { self.expr.evaluate(batch) } +} +impl fmt::Display for Sum { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "SUM({})", self.expr) + } +} + +impl Aggregate for Sum { fn create_accumulator(&self) -> Rc> { Rc::new(RefCell::new(SumAccumulator { sum: None })) } @@ -301,7 +315,7 @@ impl Avg { } } -impl AggregateExpr for Avg { +impl PhysicalExpr for Avg { fn data_type(&self, input_schema: &Schema) -> Result { match self.expr.data_type(input_schema)? { DataType::Int8 @@ -321,10 +335,22 @@ impl AggregateExpr for Avg { } } - fn evaluate_input(&self, batch: &RecordBatch) -> Result { + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(false) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { self.expr.evaluate(batch) } +} +impl fmt::Display for Avg { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "AVG({})", self.expr) + } +} + +impl Aggregate for Avg { fn create_accumulator(&self) -> Rc> { Rc::new(RefCell::new(AvgAccumulator { sum: None, @@ -417,15 +443,27 @@ impl Max { } } -impl AggregateExpr for Max { +impl PhysicalExpr for Max { fn data_type(&self, input_schema: &Schema) -> Result { self.expr.data_type(input_schema) } - fn evaluate_input(&self, batch: &RecordBatch) -> Result { + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(false) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { self.expr.evaluate(batch) } +} +impl fmt::Display for Max { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "MAX({})", self.expr) + } +} + +impl Aggregate for Max { fn create_accumulator(&self) -> Rc> { Rc::new(RefCell::new(MaxAccumulator { max: None })) } @@ -601,15 +639,27 @@ impl Min { } } -impl AggregateExpr for Min { +impl PhysicalExpr for Min { fn data_type(&self, input_schema: &Schema) -> Result { self.expr.data_type(input_schema) } - fn evaluate_input(&self, batch: &RecordBatch) -> Result { + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(false) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { self.expr.evaluate(batch) } +} +impl fmt::Display for Min { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "MIN({})", self.expr) + } +} + +impl Aggregate for Min { fn create_accumulator(&self) -> Rc> { Rc::new(RefCell::new(MinAccumulator { min: None })) } @@ -786,15 +836,27 @@ impl Count { } } -impl AggregateExpr for Count { +impl PhysicalExpr for Count { fn data_type(&self, _input_schema: &Schema) -> Result { Ok(DataType::UInt64) } - fn evaluate_input(&self, batch: &RecordBatch) -> Result { + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(false) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { self.expr.evaluate(batch) } +} + +impl fmt::Display for Count { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "COUNT({})", self.expr) + } +} +impl Aggregate for Count { fn create_accumulator(&self) -> Rc> { Rc::new(RefCell::new(CountAccumulator { count: 0 })) } @@ -1828,7 +1890,7 @@ mod tests { fn do_sum(batch: &RecordBatch) -> Result> { let sum = sum(col("a")); let accum = sum.create_accumulator(); - let input = sum.evaluate_input(batch)?; + let input = sum.evaluate(batch)?; let mut accum = accum.borrow_mut(); for i in 
0..batch.num_rows() { accum.accumulate_scalar(get_scalar_value(&input, i)?)?; @@ -1839,7 +1901,7 @@ mod tests { fn do_max(batch: &RecordBatch) -> Result> { let max = max(col("a")); let accum = max.create_accumulator(); - let input = max.evaluate_input(batch)?; + let input = max.evaluate(batch)?; let mut accum = accum.borrow_mut(); for i in 0..batch.num_rows() { accum.accumulate_scalar(get_scalar_value(&input, i)?)?; @@ -1850,7 +1912,7 @@ mod tests { fn do_min(batch: &RecordBatch) -> Result> { let min = min(col("a")); let accum = min.create_accumulator(); - let input = min.evaluate_input(batch)?; + let input = min.evaluate(batch)?; let mut accum = accum.borrow_mut(); for i in 0..batch.num_rows() { accum.accumulate_scalar(get_scalar_value(&input, i)?)?; @@ -1861,7 +1923,7 @@ mod tests { fn do_count(batch: &RecordBatch) -> Result> { let count = count(col("a")); let accum = count.create_accumulator(); - let input = count.evaluate_input(batch)?; + let input = count.evaluate(batch)?; let mut accum = accum.borrow_mut(); for i in 0..batch.num_rows() { accum.accumulate_scalar(get_scalar_value(&input, i)?)?; @@ -1872,7 +1934,7 @@ mod tests { fn do_avg(batch: &RecordBatch) -> Result> { let avg = avg(col("a")); let accum = avg.create_accumulator(); - let input = avg.evaluate_input(batch)?; + let input = avg.evaluate(batch)?; let mut accum = accum.borrow_mut(); for i in 0..batch.num_rows() { accum.accumulate_scalar(get_scalar_value(&input, i)?)?; diff --git a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs index d7366395ca4..dab65250642 100644 --- a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs @@ -268,7 +268,7 @@ impl RecordBatchReader for GroupedHashAggregateIterator { .aggr_expr .iter() .map(|expr| { - expr.evaluate_input(&batch) + expr.evaluate(&batch) .map_err(ExecutionError::into_arrow_external_error) }) .collect::>>()?; @@ -433,7 +433,7 @@ impl RecordBatchReader for HashAggregateIterator { .aggr_expr .iter() .map(|expr| { - expr.evaluate_input(&batch) + expr.evaluate(&batch) .map_err(ExecutionError::into_arrow_external_error) }) .collect::>>()?; diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs index c2ba055e185..c9935c480fa 100644 --- a/rust/datafusion/src/execution/physical_plan/mod.rs +++ b/rust/datafusion/src/execution/physical_plan/mod.rs @@ -68,12 +68,8 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug { fn evaluate(&self, batch: &RecordBatch) -> Result; } -/// Aggregate expression that can be evaluated against a RecordBatch -pub trait AggregateExpr: Send + Sync + Debug { - /// Get the data type of this expression, given the schema of the input - fn data_type(&self, input_schema: &Schema) -> Result; - /// Evaluate the expression being aggregated - fn evaluate_input(&self, batch: &RecordBatch) -> Result; +/// Aggregate knows how to accumulate arrays +pub trait Aggregate: Send + Sync + Debug { /// Create an accumulator for this aggregate expression fn create_accumulator(&self) -> Rc>; /// Create an aggregate expression for combining the results of accumulators from partitions. 
@@ -82,6 +78,10 @@ pub trait AggregateExpr: Send + Sync + Debug { fn create_reducer(&self, column_name: &str) -> Arc; } +/// Aggregate expression that can be evaluated against a RecordBatch +pub trait AggregateExpr: PhysicalExpr + Aggregate {} +impl AggregateExpr for T {} + /// Aggregate accumulator pub trait Accumulator: Debug { /// Update the accumulator based on a row in a batch From de3f3d0c048439c75a6b1e56e918a5d0a4f4ae32 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sat, 15 Aug 2020 14:35:06 +0200 Subject: [PATCH 10/13] Added generic aggregate udf expressions and migrated sum to it. --- rust/datafusion/README.md | 2 + rust/datafusion/src/execution/context.rs | 84 +++- .../src/execution/dataframe_impl.rs | 12 +- .../execution/physical_plan/expressions.rs | 396 +++++++++--------- .../execution/physical_plan/hash_aggregate.rs | 4 +- .../src/execution/physical_plan/mod.rs | 17 +- .../src/execution/physical_plan/planner.rs | 58 ++- .../src/execution/physical_plan/udf.rs | 105 ++++- rust/datafusion/src/logicalplan.rs | 11 +- rust/datafusion/src/sql/planner.rs | 160 +++---- 10 files changed, 502 insertions(+), 347 deletions(-) diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index 0a193684265..1ee4f10578b 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -53,6 +53,8 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI - [x] Limit - [x] Aggregate - [x] UDFs + - [x] Scalar UDFs + - [x] Aggregate UDFs - [x] Common math functions - String functions - [x] Length of the string diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 64a279b6bd0..b5f210c5c48 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -28,6 +28,7 @@ use arrow::csv; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; +use super::physical_plan::udf::AggregateFunction; use crate::dataframe::DataFrame; use crate::datasource::csv::CsvFile; use crate::datasource::parquet::ParquetTable; @@ -38,10 +39,10 @@ use crate::execution::physical_plan::common; use crate::execution::physical_plan::csv::CsvReadOptions; use crate::execution::physical_plan::merge::MergeExec; use crate::execution::physical_plan::planner::PhysicalPlannerImpl; -use crate::execution::physical_plan::scalar_functions; use crate::execution::physical_plan::udf::ScalarFunction; use crate::execution::physical_plan::ExecutionPlan; use crate::execution::physical_plan::PhysicalPlanner; +use crate::execution::physical_plan::{aggregate_functions, scalar_functions}; use crate::logicalplan::{FunctionMeta, FunctionType, LogicalPlan, LogicalPlanBuilder}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::projection_push_down::ProjectionPushDown; @@ -103,12 +104,16 @@ impl ExecutionContext { state: Arc::new(Mutex::new(ExecutionContextState { datasources: Arc::new(Mutex::new(HashMap::new())), scalar_functions: Arc::new(Mutex::new(HashMap::new())), + aggregate_functions: Arc::new(Mutex::new(HashMap::new())), config, })), }; for udf in scalar_functions() { ctx.register_udf(udf); } + for udf in aggregate_functions() { + ctx.register_aggregate_udf(udf); + } ctx } @@ -200,6 +205,16 @@ impl ExecutionContext { .insert(f.name.clone(), Box::new(f)); } + /// Register an aggregate function + pub fn register_aggregate_udf(&mut self, f: AggregateFunction) { + let state = self.state.lock().expect("failed to lock mutex"); + state + .aggregate_functions + .lock() + .expect("failed to lock 
mutex") + .insert(f.name.clone(), Box::new(f)); + } + /// Get a reference to the registered scalar functions pub fn scalar_functions(&self) -> Arc>>> { self.state @@ -495,6 +510,8 @@ pub struct ExecutionContextState { pub datasources: Arc>>>, /// Scalar functions that are registered with the context pub scalar_functions: Arc>>>, + /// Aggregate functions that are registered with the context + pub aggregate_functions: Arc>>>, /// Context configuration pub config: ExecutionConfig, } @@ -509,7 +526,8 @@ impl SchemaProvider for ExecutionContextState { } fn get_function_meta(&self, name: &str) -> Option> { - self.scalar_functions + let scalar = self + .scalar_functions .lock() .expect("failed to lock mutex") .get(name) @@ -520,8 +538,44 @@ impl SchemaProvider for ExecutionContextState { f.return_type.clone(), FunctionType::Scalar, )) + }); + // give priority to scalar functions + if scalar.is_some() { + return scalar; + } + + self.aggregate_functions + .lock() + .expect("failed to lock mutex") + .get(name) + .map(|f| { + Arc::new(FunctionMeta::new( + name.to_owned(), + f.args.clone(), + DataType::Float32, + FunctionType::Aggregate, + )) }) } + + fn functions(&self) -> Vec { + let mut scalars: Vec = self + .scalar_functions + .lock() + .expect("failed to lock mutex") + .keys() + .cloned() + .collect(); + let mut aggregates: Vec = self + .aggregate_functions + .lock() + .expect("failed to lock mutex") + .keys() + .cloned() + .collect(); + aggregates.append(&mut scalars); + aggregates + } } #[cfg(test)] @@ -737,7 +791,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["SUM(c1)", "SUM(c2)"]); + assert_eq!(field_names(batch), vec!["sum(c1)", "sum(c2)"]); let expected: Vec<&str> = vec!["60,220"]; let mut rows = test::format_batch(&batch); @@ -754,7 +808,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["AVG(c1)", "AVG(c2)"]); + assert_eq!(field_names(batch), vec!["avg(c1)", "avg(c2)"]); let expected: Vec<&str> = vec!["1.5,5.5"]; let mut rows = test::format_batch(&batch); @@ -771,7 +825,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["MAX(c1)", "MAX(c2)"]); + assert_eq!(field_names(batch), vec!["max(c1)", "max(c2)"]); let expected: Vec<&str> = vec!["3,10"]; let mut rows = test::format_batch(&batch); @@ -788,7 +842,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["MIN(c1)", "MIN(c2)"]); + assert_eq!(field_names(batch), vec!["min(c1)", "min(c2)"]); let expected: Vec<&str> = vec!["0,1"]; let mut rows = test::format_batch(&batch); @@ -805,7 +859,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["c1", "SUM(c2)"]); + assert_eq!(field_names(batch), vec!["c1", "sum(c2)"]); let expected: Vec<&str> = vec!["0,55", "1,55", "2,55", "3,55"]; let mut rows = test::format_batch(&batch); @@ -822,7 +876,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["c1", "AVG(c2)"]); + assert_eq!(field_names(batch), vec!["c1", "avg(c2)"]); let expected: Vec<&str> = vec!["0,5.5", "1,5.5", "2,5.5", "3,5.5"]; let mut rows = test::format_batch(&batch); @@ -839,7 +893,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["c1", "MAX(c2)"]); + assert_eq!(field_names(batch), vec!["c1", "max(c2)"]); let expected: Vec<&str> = vec!["0,10", "1,10", "2,10", "3,10"]; let mut rows = test::format_batch(&batch); @@ -856,7 +910,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["c1", "MIN(c2)"]); + 
assert_eq!(field_names(batch), vec!["c1", "min(c2)"]); let expected: Vec<&str> = vec!["0,1", "1,1", "2,1", "3,1"]; let mut rows = test::format_batch(&batch); @@ -873,7 +927,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["COUNT(c1)", "COUNT(c2)"]); + assert_eq!(field_names(batch), vec!["count(c1)", "count(c2)"]); let expected: Vec<&str> = vec!["10,10"]; let mut rows = test::format_batch(&batch); @@ -889,7 +943,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["COUNT(c1)", "COUNT(c2)"]); + assert_eq!(field_names(batch), vec!["count(c1)", "count(c2)"]); let expected: Vec<&str> = vec!["40,40"]; let mut rows = test::format_batch(&batch); @@ -905,7 +959,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["c1", "COUNT(c2)"]); + assert_eq!(field_names(batch), vec!["c1", "count(c2)"]); let expected = vec!["0,10", "1,10", "2,10", "3,10"]; let mut rows = test::format_batch(&batch); @@ -927,9 +981,9 @@ mod tests { let plan = LogicalPlanBuilder::scan("default", "test", schema.as_ref(), None)? .aggregate( vec![col("c1")], - vec![aggregate_expr("SUM", col("c2"), DataType::UInt32)], + vec![aggregate_expr("sum", col("c2"), DataType::UInt32)], )? - .project(vec![col("c1"), col("SUM(c2)").alias("total_salary")])? + .project(vec![col("c1"), col("sum(c2)").alias("total_salary")])? .build()?; let plan = ctx.optimize(&plan)?; diff --git a/rust/datafusion/src/execution/dataframe_impl.rs b/rust/datafusion/src/execution/dataframe_impl.rs index 491cb70a3a4..89994d9feff 100644 --- a/rust/datafusion/src/execution/dataframe_impl.rs +++ b/rust/datafusion/src/execution/dataframe_impl.rs @@ -94,27 +94,27 @@ impl DataFrame for DataFrameImpl { /// Create an expression to represent the min() aggregate function fn min(&self, expr: Expr) -> Result { - self.aggregate_expr("MIN", expr) + self.aggregate_expr("min", expr) } /// Create an expression to represent the max() aggregate function fn max(&self, expr: Expr) -> Result { - self.aggregate_expr("MAX", expr) + self.aggregate_expr("max", expr) } /// Create an expression to represent the sum() aggregate function fn sum(&self, expr: Expr) -> Result { - self.aggregate_expr("SUM", expr) + self.aggregate_expr("sum", expr) } /// Create an expression to represent the avg() aggregate function fn avg(&self, expr: Expr) -> Result { - self.aggregate_expr("AVG", expr) + self.aggregate_expr("avg", expr) } /// Create an expression to represent the count() aggregate function fn count(&self, expr: Expr) -> Result { - self.aggregate_expr("COUNT", expr) + self.aggregate_expr("count", expr) } /// Convert to logical plan @@ -218,7 +218,7 @@ mod tests { let plan = t2.to_logical_plan(); // build same plan using SQL API - let sql = "SELECT c1, MIN(c12), MAX(c12), AVG(c12), SUM(c12), COUNT(c12) \ + let sql = "SELECT c1, min(c12), max(c12), avg(c12), sum(c12), count(c12) \ FROM aggregate_test_100 \ GROUP BY c1"; let sql_plan = create_plan(sql)?; diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index 1c64c3fe44b..e5e204a2a90 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -24,8 +24,9 @@ use std::sync::Arc; use crate::error::{ExecutionError, Result}; use crate::execution::physical_plan::common::get_scalar_value; +use crate::execution::physical_plan::udf; use crate::execution::physical_plan::{ - Accumulator, Aggregate, AggregateExpr, 
PhysicalExpr, + Accumulator, AggregateExpr, Aggregator, PhysicalExpr, }; use crate::logicalplan::{Operator, ScalarValue}; use arrow::array::{ @@ -49,6 +50,7 @@ use arrow::compute::kernels::comparison::{ use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::datatypes::{DataType, Schema, TimeUnit}; use arrow::record_batch::RecordBatch; +use udf::AggregateFunction; /// Represents the column at a given index in a RecordBatch #[derive(Debug)] @@ -96,59 +98,51 @@ pub fn col(name: &str) -> Arc { Arc::new(Column::new(name)) } +/// aggregate functions declared in this module +pub fn aggregate_functions() -> Vec { + vec![sum(), avg(), max(), min(), count()] +} /// SUM aggregate expression #[derive(Debug)] -pub struct Sum { - expr: Arc, -} - -impl Sum { - /// Create a new SUM aggregate function - pub fn new(expr: Arc) -> Self { - Self { expr } - } -} - -impl PhysicalExpr for Sum { - fn data_type(&self, input_schema: &Schema) -> Result { - match self.expr.data_type(input_schema)? { - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - Ok(DataType::Int64) - } - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { - Ok(DataType::UInt64) - } - DataType::Float32 => Ok(DataType::Float32), - DataType::Float64 => Ok(DataType::Float64), - other => Err(ExecutionError::General(format!( - "SUM does not support {:?}", - other - ))), - } - } +pub struct Sum {} - fn nullable(&self, _input_schema: &Schema) -> Result { - Ok(false) +impl Aggregator for Sum { + fn create_accumulator(&self) -> Rc> { + Rc::new(RefCell::new(SumAccumulator { sum: None })) } - fn evaluate(&self, batch: &RecordBatch) -> Result { - self.expr.evaluate(batch) + fn create_reducer(&self, column_name: &str) -> Arc { + physical_sum(Arc::new(Column::new(column_name))) } } -impl fmt::Display for Sum { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "SUM({})", self.expr) +fn sum_return_type( + expr: &Vec>, + schema: &Schema, +) -> Result { + match expr[0].data_type(schema)? { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + Ok(DataType::Int64) + } + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + Ok(DataType::UInt64) + } + DataType::Float32 => Ok(DataType::Float32), + DataType::Float64 => Ok(DataType::Float64), + other => Err(ExecutionError::General(format!( + "SUM does not support {:?}", + other + ))), } } -impl Aggregate for Sum { - fn create_accumulator(&self) -> Rc> { - Rc::new(RefCell::new(SumAccumulator { sum: None })) - } - - fn create_reducer(&self, column_name: &str) -> Arc { - Arc::new(Sum::new(Arc::new(Column::new(column_name)))) +/// Creates a sum aggregate function +pub fn sum() -> AggregateFunction { + AggregateFunction { + name: "sum".to_string(), + return_type: Arc::new(sum_return_type), + args: vec![common_types()], + aggregate: Arc::new(Sum {}), } } @@ -170,7 +164,7 @@ macro_rules! sum_accumulate { #[derive(Debug)] struct SumAccumulator { - sum: Option, + pub sum: Option, } impl Accumulator for SumAccumulator { @@ -297,69 +291,70 @@ impl Accumulator for SumAccumulator { } } -/// Create a sum expression -pub fn sum(expr: Arc) -> Arc { - Arc::new(Sum::new(expr)) +/// Create a physical aggregate sum expression +pub fn physical_sum(expr: Arc) -> Arc { + Arc::new(udf::AggregateFunctionExpr::new( + "SUM", + vec![expr], + Box::new(sum()), + )) } -/// AVG aggregate expression +/// Average aggregate expression. 
#[derive(Debug)] -pub struct Avg { - expr: Arc, -} +pub struct Avg {} -impl Avg { - /// Create a new AVG aggregate function - pub fn new(expr: Arc) -> Self { - Self { expr } - } -} - -impl PhysicalExpr for Avg { - fn data_type(&self, input_schema: &Schema) -> Result { - match self.expr.data_type(input_schema)? { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 => Ok(DataType::Float64), - other => Err(ExecutionError::General(format!( - "AVG does not support {:?}", - other - ))), - } - } - - fn nullable(&self, _input_schema: &Schema) -> Result { - Ok(false) +impl Aggregator for Avg { + fn create_accumulator(&self) -> Rc> { + Rc::new(RefCell::new(AvgAccumulator { + sum: None, + count: None, + })) } - fn evaluate(&self, batch: &RecordBatch) -> Result { - self.expr.evaluate(batch) + fn create_reducer(&self, column_name: &str) -> Arc { + physical_avg(Arc::new(Column::new(column_name))) } } -impl fmt::Display for Avg { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "AVG({})", self.expr) +fn avg_return_type( + expr: &Vec>, + schema: &Schema, +) -> Result { + match expr[0].data_type(schema)? { + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 => Ok(DataType::Float64), + other => Err(ExecutionError::General(format!( + "AVG does not support {:?}", + other + ))), } } -impl Aggregate for Avg { - fn create_accumulator(&self) -> Rc> { - Rc::new(RefCell::new(AvgAccumulator { - sum: None, - count: None, - })) - } +/// Create a physical aggregate avg expression +pub fn physical_avg(expr: Arc) -> Arc { + Arc::new(udf::AggregateFunctionExpr::new( + "AVG", + vec![expr], + Box::new(avg()), + )) +} - fn create_reducer(&self, column_name: &str) -> Arc { - Arc::new(Avg::new(Arc::new(Column::new(column_name)))) +/// Creates a avg aggregate function +pub fn avg() -> AggregateFunction { + AggregateFunction { + name: "avg".to_string(), + return_type: Arc::new(avg_return_type), + args: vec![common_types()], + aggregate: Arc::new(Avg {}), } } @@ -425,52 +420,60 @@ impl Accumulator for AvgAccumulator { } } -/// Create a avg expression -pub fn avg(expr: Arc) -> Arc { - Arc::new(Avg::new(expr)) -} - /// MAX aggregate expression #[derive(Debug)] -pub struct Max { - expr: Arc, -} +pub struct Max {} -impl Max { - /// Create a new MAX aggregate function - pub fn new(expr: Arc) -> Self { - Self { expr } - } -} - -impl PhysicalExpr for Max { - fn data_type(&self, input_schema: &Schema) -> Result { - self.expr.data_type(input_schema) +impl Aggregator for Max { + fn create_accumulator(&self) -> Rc> { + Rc::new(RefCell::new(MaxAccumulator { max: None })) } - fn nullable(&self, _input_schema: &Schema) -> Result { - Ok(false) + fn create_reducer(&self, column_name: &str) -> Arc { + physical_max(Arc::new(Column::new(column_name))) } +} - fn evaluate(&self, batch: &RecordBatch) -> Result { - self.expr.evaluate(batch) - } +/// Create a physical aggregate max expression +pub fn physical_max(expr: Arc) -> Arc { + Arc::new(udf::AggregateFunctionExpr::new( + "MAX", + vec![expr], + Box::new(max()), + )) } -impl fmt::Display for Max { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "MAX({})", self.expr) - } +fn common_types() -> Vec { + // this order dictactes the order on which we try to cast to. 
+ vec![ + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ] } -impl Aggregate for Max { - fn create_accumulator(&self) -> Rc> { - Rc::new(RefCell::new(MaxAccumulator { max: None })) +/// Creates a max aggregate function +pub fn max() -> AggregateFunction { + AggregateFunction { + name: "max".to_string(), + return_type: Arc::new(max_return_type), + args: vec![common_types()], + aggregate: Arc::new(Max {}), } +} - fn create_reducer(&self, column_name: &str) -> Arc { - Arc::new(Max::new(Arc::new(Column::new(column_name)))) - } +fn max_return_type( + expr: &Vec>, + schema: &Schema, +) -> Result { + expr[0].data_type(schema) } macro_rules! max_accumulate { @@ -621,51 +624,36 @@ impl Accumulator for MaxAccumulator { } } -/// Create a max expression -pub fn max(expr: Arc) -> Arc { - Arc::new(Max::new(expr)) -} - /// MIN aggregate expression #[derive(Debug)] -pub struct Min { - expr: Arc, -} - -impl Min { - /// Create a new MIN aggregate function - pub fn new(expr: Arc) -> Self { - Self { expr } - } -} - -impl PhysicalExpr for Min { - fn data_type(&self, input_schema: &Schema) -> Result { - self.expr.data_type(input_schema) - } - - fn nullable(&self, _input_schema: &Schema) -> Result { - Ok(false) - } +pub struct Min {} - fn evaluate(&self, batch: &RecordBatch) -> Result { - self.expr.evaluate(batch) - } +/// Create a physical aggregate min expression +pub fn physical_min(expr: Arc) -> Arc { + Arc::new(udf::AggregateFunctionExpr::new( + "MIN", + vec![expr], + Box::new(min()), + )) } -impl fmt::Display for Min { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "MIN({})", self.expr) +/// Creates a avg aggregate function +pub fn min() -> AggregateFunction { + AggregateFunction { + name: "min".to_string(), + return_type: Arc::new(max_return_type), + args: vec![common_types()], + aggregate: Arc::new(Min {}), } } -impl Aggregate for Min { +impl Aggregator for Min { fn create_accumulator(&self) -> Rc> { Rc::new(RefCell::new(MinAccumulator { min: None })) } fn create_reducer(&self, column_name: &str) -> Arc { - Arc::new(Min::new(Arc::new(Column::new(column_name)))) + physical_min(Arc::new(Column::new(column_name))) } } @@ -817,55 +805,50 @@ impl Accumulator for MinAccumulator { } } -/// Create a min expression -pub fn min(expr: Arc) -> Arc { - Arc::new(Min::new(expr)) -} - /// COUNT aggregate expression /// Returns the amount of non-null values of the given expression. #[derive(Debug)] -pub struct Count { - expr: Arc, -} - -impl Count { - /// Create a new COUNT aggregate function. 
- pub fn new(expr: Arc) -> Self { - Self { expr: expr } - } -} +pub struct Count {} -impl PhysicalExpr for Count { - fn data_type(&self, _input_schema: &Schema) -> Result { - Ok(DataType::UInt64) - } - - fn nullable(&self, _input_schema: &Schema) -> Result { - Ok(false) +impl Aggregator for Count { + fn create_accumulator(&self) -> Rc> { + Rc::new(RefCell::new(CountAccumulator { count: 0 })) } - fn evaluate(&self, batch: &RecordBatch) -> Result { - self.expr.evaluate(batch) + fn create_reducer(&self, column_name: &str) -> Arc { + physical_sum(Arc::new(Column::new(column_name))) } } -impl fmt::Display for Count { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "COUNT({})", self.expr) - } +/// Create a physical aggregate count expression +pub fn physical_count(expr: Arc) -> Arc { + Arc::new(udf::AggregateFunctionExpr::new( + "COUNT", + vec![expr], + Box::new(count()), + )) } -impl Aggregate for Count { - fn create_accumulator(&self) -> Rc> { - Rc::new(RefCell::new(CountAccumulator { count: 0 })) - } +/// Creates a count aggregate function +pub fn count() -> AggregateFunction { + let mut types = common_types(); + types.push(DataType::Utf8); - fn create_reducer(&self, column_name: &str) -> Arc { - Arc::new(Sum::new(Arc::new(Column::new(column_name)))) + AggregateFunction { + name: "count".to_string(), + return_type: Arc::new(count_return_type), + args: vec![types], + aggregate: Arc::new(Count {}), } } +fn count_return_type( + _expr: &Vec>, + _schema: &Schema, +) -> Result { + Ok(DataType::UInt64) +} + #[derive(Debug)] struct CountAccumulator { count: u64, @@ -889,11 +872,6 @@ impl Accumulator for CountAccumulator { } } -/// Create a count expression -pub fn count(expr: Arc) -> Arc { - Arc::new(Count::new(expr)) -} - /// Invoke a compute kernel on a pair of binary data arrays macro_rules! 
compute_utf8_op { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ @@ -1493,7 +1471,7 @@ mod tests { fn sum_contract() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let sum = sum(col("a")); + let sum = physical_sum(col("a")); assert_eq!(DataType::Int64, sum.data_type(&schema)?); // after the aggr expression is applied, the schema changes to: @@ -1512,7 +1490,7 @@ mod tests { fn max_contract() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let max = max(col("a")); + let max = physical_max(col("a")); assert_eq!(DataType::Int32, max.data_type(&schema)?); // after the aggr expression is applied, the schema changes to: @@ -1531,7 +1509,7 @@ mod tests { fn min_contract() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let min = min(col("a")); + let min = physical_min(col("a")); assert_eq!(DataType::Int32, min.data_type(&schema)?); // after the aggr expression is applied, the schema changes to: @@ -1548,16 +1526,16 @@ mod tests { fn avg_contract() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let avg = avg(col("a")); + let avg = physical_avg(col("a")); assert_eq!(DataType::Float64, avg.data_type(&schema)?); // after the aggr expression is applied, the schema changes to: let schema = Schema::new(vec![ schema.field(0).clone(), - Field::new("SUM(a)", avg.data_type(&schema)?, false), + Field::new("AVG(a)", avg.data_type(&schema)?, false), ]); - let combiner = avg.create_reducer("SUM(a)"); + let combiner = avg.create_reducer("AVG(a)"); assert_eq!(DataType::Float64, combiner.data_type(&schema)?); Ok(()) @@ -1888,7 +1866,7 @@ mod tests { } fn do_sum(batch: &RecordBatch) -> Result> { - let sum = sum(col("a")); + let sum = physical_sum(col("a")); let accum = sum.create_accumulator(); let input = sum.evaluate(batch)?; let mut accum = accum.borrow_mut(); @@ -1899,7 +1877,7 @@ mod tests { } fn do_max(batch: &RecordBatch) -> Result> { - let max = max(col("a")); + let max = physical_max(col("a")); let accum = max.create_accumulator(); let input = max.evaluate(batch)?; let mut accum = accum.borrow_mut(); @@ -1910,7 +1888,7 @@ mod tests { } fn do_min(batch: &RecordBatch) -> Result> { - let min = min(col("a")); + let min = physical_min(col("a")); let accum = min.create_accumulator(); let input = min.evaluate(batch)?; let mut accum = accum.borrow_mut(); @@ -1921,7 +1899,7 @@ mod tests { } fn do_count(batch: &RecordBatch) -> Result> { - let count = count(col("a")); + let count = physical_count(col("a")); let accum = count.create_accumulator(); let input = count.evaluate(batch)?; let mut accum = accum.borrow_mut(); @@ -1932,7 +1910,7 @@ mod tests { } fn do_avg(batch: &RecordBatch) -> Result> { - let avg = avg(col("a")); + let avg = physical_avg(col("a")); let accum = avg.create_accumulator(); let input = avg.evaluate(batch)?; let mut accum = accum.borrow_mut(); diff --git a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs index dab65250642..54787b5e16c 100644 --- a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs @@ -698,7 +698,7 @@ mod tests { use super::*; use crate::execution::physical_plan::csv::{CsvExec, CsvReadOptions}; - use crate::execution::physical_plan::expressions::{col, sum}; + use crate::execution::physical_plan::expressions::{col, physical_sum}; use 
crate::execution::physical_plan::merge::MergeExec; use crate::test; @@ -716,7 +716,7 @@ mod tests { vec![(col("c2"), "c2".to_string())]; let aggregates: Vec<(Arc, String)> = - vec![(sum(col("c4")), "SUM(c4)".to_string())]; + vec![(physical_sum(col("c4")), "SUM(c4)".to_string())]; let partition_aggregate = HashAggregateExec::try_new( groups.clone(), diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs index c9935c480fa..b16ba8163ea 100644 --- a/rust/datafusion/src/execution/physical_plan/mod.rs +++ b/rust/datafusion/src/execution/physical_plan/mod.rs @@ -68,8 +68,9 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug { fn evaluate(&self, batch: &RecordBatch) -> Result; } -/// Aggregate knows how to accumulate arrays -pub trait Aggregate: Send + Sync + Debug { +/// An aggregators knows how to accumulate arrays in parts, so that the array does not have +/// to be all available in memory. This type of aggregation is also known as online update. +pub trait Aggregator: Send + Sync + Debug { /// Create an accumulator for this aggregate expression fn create_accumulator(&self) -> Rc>; /// Create an aggregate expression for combining the results of accumulators from partitions. @@ -79,10 +80,11 @@ pub trait Aggregate: Send + Sync + Debug { } /// Aggregate expression that can be evaluated against a RecordBatch -pub trait AggregateExpr: PhysicalExpr + Aggregate {} -impl AggregateExpr for T {} +pub trait AggregateExpr: PhysicalExpr + Aggregator {} +impl AggregateExpr for T {} -/// Aggregate accumulator +/// An accumulator knows how compute aggregations without full access to the complete array. +/// This is also known as online updates. pub trait Accumulator: Debug { /// Update the accumulator based on a row in a batch fn accumulate_scalar(&mut self, value: Option) -> Result<()>; @@ -104,6 +106,11 @@ pub fn scalar_functions() -> Vec { udfs } +/// Vector of aggregate functions declared in this module +pub fn aggregate_functions() -> Vec { + expressions::aggregate_functions() +} + pub mod common; pub mod csv; pub mod datasource; diff --git a/rust/datafusion/src/execution/physical_plan/planner.rs b/rust/datafusion/src/execution/physical_plan/planner.rs index 7ce845bf434..6475fccdfc0 100644 --- a/rust/datafusion/src/execution/physical_plan/planner.rs +++ b/rust/datafusion/src/execution/physical_plan/planner.rs @@ -19,13 +19,14 @@ use std::sync::{Arc, Mutex}; +use super::udf::AggregateFunctionExpr; use crate::error::{ExecutionError, Result}; use crate::execution::context::ExecutionContextState; use crate::execution::physical_plan::csv::{CsvExec, CsvReadOptions}; use crate::execution::physical_plan::datasource::DatasourceExec; use crate::execution::physical_plan::explain::ExplainExec; use crate::execution::physical_plan::expressions::{ - Avg, BinaryExpr, CastExpr, Column, Count, Literal, Max, Min, PhysicalSortExpr, Sum, + BinaryExpr, CastExpr, Column, Literal, PhysicalSortExpr, }; use crate::execution::physical_plan::hash_aggregate::HashAggregateExec; use crate::execution::physical_plan::limit::GlobalLimitExec; @@ -403,35 +404,32 @@ impl PhysicalPlannerImpl { ) -> Result> { match e { Expr::AggregateFunction { name, args, .. 
} => { - match name.to_lowercase().as_ref() { - "sum" => Ok(Arc::new(Sum::new(self.create_physical_expr( - &args[0], - input_schema, - ctx_state.clone(), - )?))), - "avg" => Ok(Arc::new(Avg::new(self.create_physical_expr( - &args[0], - input_schema, - ctx_state.clone(), - )?))), - "max" => Ok(Arc::new(Max::new(self.create_physical_expr( - &args[0], - input_schema, - ctx_state.clone(), - )?))), - "min" => Ok(Arc::new(Min::new(self.create_physical_expr( - &args[0], - input_schema, - ctx_state.clone(), - )?))), - "count" => Ok(Arc::new(Count::new(self.create_physical_expr( - &args[0], - input_schema, - ctx_state.clone(), - )?))), - other => Err(ExecutionError::NotImplemented(format!( - "Unsupported aggregate function '{}'", - other + match &ctx_state + .lock() + .expect("failed to lock mutex") + .aggregate_functions + .lock() + .expect("failed to lock mutex") + .get(name) + { + Some(f) => { + let mut physical_args = vec![]; + for e in args { + physical_args.push(self.create_physical_expr( + e, + input_schema, + ctx_state.clone(), + )?); + } + Ok(Arc::new(AggregateFunctionExpr::new( + name, + physical_args, + Box::new(f.as_ref().clone()), + ))) + } + _ => Err(ExecutionError::General(format!( + "Invalid aggregate function '{:?}'", + name ))), } } diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs b/rust/datafusion/src/execution/physical_plan/udf.rs index 4d63b8a0309..a41abb801dc 100644 --- a/rust/datafusion/src/execution/physical_plan/udf.rs +++ b/rust/datafusion/src/execution/physical_plan/udf.rs @@ -25,13 +25,18 @@ use arrow::datatypes::{DataType, Schema}; use crate::error::Result; use crate::execution::physical_plan::PhysicalExpr; +use super::{Accumulator, AggregateExpr, Aggregator}; use arrow::record_batch::RecordBatch; use fmt::{Debug, Formatter}; -use std::sync::Arc; +use std::{cell::RefCell, rc::Rc, sync::Arc}; /// Scalar UDF pub type ScalarUdf = Arc Result + Send + Sync>; +/// Function to construct the return type of a function given its arguments. +pub type ReturnType = + Arc>, &Schema) -> Result + Send + Sync>; + /// Scalar UDF Expression #[derive(Clone)] pub struct ScalarFunction { @@ -149,3 +154,101 @@ impl PhysicalExpr for ScalarFunctionExpr { (fun)(&inputs) } } + +/// A generic aggregate function +/* +This struct is + +An aggregate function accepts an arbitrary number of arguments, of arbitrary data types, +and returns an arbitrary type based on the incoming types. + +It is the developer of the function's responsibility to ensure that the aggregator correctly handles the different +types that are presented to them, and that the return type correctly matches the type returned by the +aggregator. + +It is the user of the function's responsibility to pass arguments to the function that have valid types. +*/ +#[derive(Clone)] +pub struct AggregateFunction { + /// Function name + pub name: String, + /// A list of arguments and their respective types. A function can accept more than one type as argument + /// (e.g. sum(i8), sum(u8)). + pub args: Vec>, + /// Return type. This function takes + pub return_type: ReturnType, + /// implementation of the aggregation + pub aggregate: Arc, +} + +/// An aggregate function physical expression +pub struct AggregateFunctionExpr { + name: String, + fun: Box, + // for now, our AggregateFunctionExpr accepts a single element only. 
+ arg: Arc, +} + +impl AggregateFunctionExpr { + /// Create a new AggregateFunctionExpr + pub fn new( + name: &str, + args: Vec>, + fun: Box, + ) -> Self { + Self { + name: name.to_owned(), + arg: args[0].clone(), + fun, + } + } +} + +impl Debug for AggregateFunctionExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("AggregateFunctionExpr") + .field("fun", &"") + .field("name", &self.name) + .field("args", &self.arg) + .finish() + } +} + +impl PhysicalExpr for AggregateFunctionExpr { + fn data_type(&self, input_schema: &Schema) -> Result { + self.fun.as_ref().return_type.as_ref()(&vec![self.arg.clone()], input_schema) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(false) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + self.arg.evaluate(batch) + } +} + +impl Aggregator for AggregateFunctionExpr { + fn create_accumulator(&self) -> Rc> { + self.fun.aggregate.create_accumulator() + } + + fn create_reducer(&self, column_name: &str) -> Arc { + self.fun.aggregate.create_reducer(column_name) + } +} + +impl fmt::Display for AggregateFunctionExpr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}({})", + self.name, + [&self.arg] + .iter() + .map(|e| format!("{}", e)) + .collect::>() + .join(", ") + ) + } +} diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index 5cd75324a55..f9f9d4e021c 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -603,7 +603,7 @@ pub fn aggregate_expr(name: &str, expr: Expr, return_type: DataType) -> Expr { } } -/// Create an aggregate expression +/// Create an scalar function expression pub fn scalar_function(name: &str, expr: Vec, return_type: DataType) -> Expr { Expr::ScalarFunction { name: name.to_owned(), @@ -612,6 +612,15 @@ pub fn scalar_function(name: &str, expr: Vec, return_type: DataType) -> Ex } } +/// Create an aggregate expression +pub fn aggregate_function(name: &str, expr: Vec, return_type: DataType) -> Expr { + Expr::AggregateFunction { + name: name.to_owned(), + args: expr, + return_type, + } +} + impl fmt::Debug for Expr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 30fabaa3da4..37f696e3542 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -22,8 +22,8 @@ use std::sync::Arc; use crate::error::{ExecutionError, Result}; use crate::logicalplan::Expr::Alias; use crate::logicalplan::{ - lit, Expr, FunctionMeta, LogicalPlan, LogicalPlanBuilder, Operator, PlanType, - ScalarValue, StringifiedPlan, + lit, Expr, FunctionMeta, FunctionType, LogicalPlan, LogicalPlanBuilder, Operator, + PlanType, ScalarValue, StringifiedPlan, }; use crate::sql::parser::{CreateExternalTable, FileType, Statement as DFStatement}; @@ -44,6 +44,8 @@ pub trait SchemaProvider { fn get_table_meta(&self, name: &str) -> Option; /// Getter for a UDF description fn get_function_meta(&self, name: &str) -> Option>; + /// Getter list of valid udfs + fn functions(&self) -> Vec; } /// SQL query planner @@ -476,76 +478,50 @@ impl SqlToRel { } SQLExpr::Function(function) => { - //TODO: fix this hack - let name: String = function.name.to_string(); - match name.to_lowercase().as_ref() { - "min" | "max" | "sum" | "avg" => { - let rex_args = function - .args - .iter() - .map(|a| self.sql_to_rex(a, schema)) - .collect::>>()?; - - // return type is same as the argument type for these aggregate 
- // functions - let return_type = rex_args[0].get_type(schema)?.clone(); - - Ok(Expr::AggregateFunction { - name: name.clone(), - args: rex_args, - return_type, - }) - } - "count" => { - let rex_args = function - .args - .iter() - .map(|a| match a { - SQLExpr::Value(Value::Number(_)) => Ok(lit(1_u8)), - SQLExpr::Wildcard => Ok(lit(1_u8)), - _ => self.sql_to_rex(a, schema), - }) - .collect::>>()?; - - Ok(Expr::AggregateFunction { - name: name.clone(), - args: rex_args, - return_type: DataType::UInt64, - }) - } - _ => match self.schema_provider.get_function_meta(&name) { - Some(fm) => { - let args = function + // make the search case-insensitive + let name: String = function.name.to_string().to_lowercase(); + + match self.schema_provider.get_function_meta(&name) { + Some(fm) => { + let args = if name == "count" { + // optimization to avoid computing expressions + function + .args + .iter() + .map(|a| match a { + SQLExpr::Value(Value::Number(_)) => Ok(lit(1_u8)), + SQLExpr::Wildcard => Ok(lit(1_u8)), + _ => self.sql_to_rex(a, schema), + }) + .collect::>>()? + } else { + function .args .iter() .map(|a| self.sql_to_rex(a, schema)) - .collect::>>()?; - let expected_args = match fm.arg_types().len() { - 0 => 0, - _ => fm.arg_types()[0].len(), - }; - let current_args = args.len(); - - if current_args != expected_args { - return Err(ExecutionError::General( - format!("The function '{}' expects {} arguments, but {} were passed", - name, - expected_args, - current_args, - ))); - } - - Ok(Expr::ScalarFunction { + .collect::>>()? + }; + + //let args = coerse_expr(&args, &fm, &schema)?; + + match fm.function_type() { + FunctionType::Scalar => Ok(Expr::ScalarFunction { + name: name.clone(), + args, + return_type: fm.return_type().clone(), + }), + FunctionType::Aggregate => Ok(Expr::AggregateFunction { name: name.clone(), args, return_type: fm.return_type().clone(), - }) + }), } - _ => Err(ExecutionError::General(format!( - "Invalid function '{}'", - name - ))), - }, + } + _ => Err(ExecutionError::General(format!( + "Invalid function '{}'. 
Valid functions: {:?}", + name, + self.schema_provider.functions(), + ))), } } @@ -691,7 +667,7 @@ mod tests { fn select_simple_aggregate() { quick_test( "SELECT MIN(age) FROM person", - "Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\ + "Aggregate: groupBy=[[]], aggr=[[min(#age)]]\ \n TableScan: person projection=None", ); } @@ -700,7 +676,7 @@ mod tests { fn test_sum_aggregate() { quick_test( "SELECT SUM(age) from person", - "Aggregate: groupBy=[[]], aggr=[[SUM(#age)]]\ + "Aggregate: groupBy=[[]], aggr=[[sum(#age)]]\ \n TableScan: person projection=None", ); } @@ -709,7 +685,7 @@ mod tests { fn select_simple_aggregate_with_groupby() { quick_test( "SELECT state, MIN(age), MAX(age) FROM person GROUP BY state", - "Aggregate: groupBy=[[#state]], aggr=[[MIN(#age), MAX(#age)]]\ + "Aggregate: groupBy=[[#state]], aggr=[[min(#age), max(#age)]]\ \n TableScan: person projection=None", ); } @@ -726,7 +702,7 @@ mod tests { #[test] fn select_count_one() { let sql = "SELECT COUNT(1) FROM person"; - let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ + let expected = "Aggregate: groupBy=[[]], aggr=[[count(UInt8(1))]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -734,7 +710,7 @@ mod tests { #[test] fn select_count_column() { let sql = "SELECT COUNT(id) FROM person"; - let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(#id)]]\ + let expected = "Aggregate: groupBy=[[]], aggr=[[count(#id)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -803,8 +779,8 @@ mod tests { fn select_group_by_needs_projection() { let sql = "SELECT COUNT(state), state FROM person GROUP BY state"; let expected = "\ - Projection: #COUNT(state), #state\ - \n Aggregate: groupBy=[[#state]], aggr=[[COUNT(#state)]]\ + Projection: #count(state), #state\ + \n Aggregate: groupBy=[[#state]], aggr=[[count(#state)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -915,15 +891,43 @@ mod tests { } fn get_function_meta(&self, name: &str) -> Option> { - match name { - "sqrt" => Some(Arc::new(FunctionMeta::new( - "sqrt".to_string(), - vec![vec![DataType::Float64]], + let fnc_type = if name == "sqrt" { + FunctionType::Scalar + } else { + FunctionType::Aggregate + }; + let valid_types = if name == "sqrt" { + vec![DataType::Float64] + } else { + vec![ + DataType::UInt8, + DataType::UInt32, + DataType::Int64, + DataType::Int32, DataType::Float64, - FunctionType::Scalar, - ))), + DataType::Utf8, + ] + }; + + let fm = Arc::new(FunctionMeta::new( + name.to_string(), + vec![valid_types], + DataType::Float64, + fnc_type, + )); + + match name { + "sqrt" => Some(fm), + "min" => Some(fm), + "max" => Some(fm), + "sum" => Some(fm), + "count" => Some(fm), _ => None, } } + + fn functions(&self) -> Vec { + vec!["sqrt".to_string()] + } } } From 61c9ce1744b0efb26044e559be11c8475b76e277 Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 16 Aug 2020 08:01:16 +0200 Subject: [PATCH 11/13] Added test with a custom aggregation. 
--- rust/datafusion/src/execution/context.rs | 137 ++++++++++++++++-- .../execution/physical_plan/expressions.rs | 29 +++- .../src/execution/physical_plan/udf.rs | 4 +- .../datafusion/src/optimizer/type_coercion.rs | 82 +++++++++-- 4 files changed, 219 insertions(+), 33 deletions(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index b5f210c5c48..95e22ee1940 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -224,6 +224,17 @@ impl ExecutionContext { .clone() } + /// Get a reference to the registered aggregate functions + pub fn aggregate_functions( + &self, + ) -> Arc>>> { + self.state + .lock() + .expect("failed to lock mutex") + .aggregate_functions + .clone() + } + /// Creates a DataFrame for reading a CSV data source. pub fn read_csv( &mut self, @@ -351,11 +362,8 @@ impl ExecutionContext { let rules: Vec> = vec![ Box::new(ProjectionPushDown::new()), Box::new(TypeCoercionRule::new( - self.state - .lock() - .expect("failed to lock mutex") - .scalar_functions - .clone(), + self.scalar_functions(), + self.aggregate_functions(), )), ]; let mut plan = plan.clone(); @@ -551,7 +559,9 @@ impl SchemaProvider for ExecutionContextState { .map(|f| { Arc::new(FunctionMeta::new( name.to_owned(), - f.args.clone(), + f.arg_types.clone(), + // this is wrong, but the actual type is overwritten by the physical plan + // as aggregate functions have a variable type. DataType::Float32, FunctionType::Aggregate, )) @@ -583,13 +593,17 @@ mod tests { use super::*; use crate::datasource::MemTable; - use crate::execution::physical_plan::udf::ScalarUdf; - use crate::logicalplan::{aggregate_expr, col, scalar_function}; + use crate::execution::physical_plan::{ + expressions::Column, + udf::{AggregateFunctionExpr, ScalarUdf}, + Accumulator, AggregateExpr, Aggregator, + }; + use crate::logicalplan::{aggregate_expr, col, scalar_function, ScalarValue}; use crate::test; - use arrow::array::{ArrayRef, Int32Array}; - use arrow::compute::add; + use arrow::array::{ArrayRef, Float64Array, Int32Array}; + use arrow::compute::{add, sum}; use std::fs::File; - use std::io::prelude::*; + use std::{cell::RefCell, io::prelude::*, rc::Rc}; use tempdir::TempDir; use test::*; @@ -1256,6 +1270,107 @@ mod tests { CsvReadOptions::new().schema(&schema), )?; + ctx.register_aggregate_udf(avg()); + Ok(ctx) } + + // declare an accumulator of an average of f64 + // math details here: https://stackoverflow.com/a/23493727/931303 + #[derive(Debug)] + struct MyAvg { + avg: f64, // online average + n: usize, + } + + impl Accumulator for MyAvg { + fn accumulate_scalar(&mut self, value: Option) -> Result<()> { + if let Some(value) = value { + match value { + ScalarValue::Float64(v) => { + self.n += 1; + self.avg = (self.avg * ((self.n - 1) as f64) - v) / self.n as f64; + } + _ => { + return Err(ExecutionError::ExecutionError(format!( + "Unsupported type {:?}.", + value + ))) + } + } + } + Ok(()) + } + + fn accumulate_batch(&mut self, array: &ArrayRef) -> Result<()> { + match array.data_type() { + DataType::Float64 => { + let array = array + .as_any() + .downcast_ref::() + .expect("Failed to cast array"); + let sum = sum(array).unwrap_or(0.0); + let m = array.len(); + + self.n += m; + self.avg = (self.avg * (self.n - m) as f64 - sum) / self.n as f64; + } + _ => { + return Err(ExecutionError::ExecutionError(format!( + "Unsupported type {:?}.", + array.data_type() + ))) + } + } + Ok(()) + } + + fn get_value(&self) -> Result> { + 
Ok(Some(ScalarValue::Float64(self.avg))) + } + } + + fn avg() -> AggregateFunction { + AggregateFunction { + name: "my_avg".to_string(), + return_type: Arc::new(|_, _| Ok(DataType::Float64)), + arg_types: vec![vec![DataType::Float64]], + aggregate: Arc::new(MyAvg { avg: 0.0, n: 0 }), + } + } + + // implement it on the same struct just for fun + impl Aggregator for MyAvg { + fn create_accumulator(&self) -> Rc> { + Rc::new(RefCell::new(MyAvg { avg: 0.0, n: 0 })) + } + + fn create_reducer(&self, column_name: &str) -> Arc { + Arc::new(AggregateFunctionExpr::new( + "my_avg", + vec![Arc::new(Column::new(column_name))], + Box::new(avg()), + )) + } + } + + #[test] + fn aggregate_custom_agg() -> Result<()> { + let results = execute("SELECT c1, my_avg(c2) FROM test GROUP BY c1", 4)?; + assert_eq!(results.len(), 1); + + let batch = &results[0]; + + assert_eq!( + field_names(batch), + vec!["c1", "my_avg(CAST(c2 as Float64))"] + ); + + let expected: Vec<&str> = vec!["0,5.5", "1,5.5", "2,5.5", "3,5.5"]; + let mut rows = test::format_batch(&batch); + rows.sort(); + assert_eq!(rows, expected); + + Ok(()) + } } diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index e5e204a2a90..bc770883fd4 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -141,7 +141,10 @@ pub fn sum() -> AggregateFunction { AggregateFunction { name: "sum".to_string(), return_type: Arc::new(sum_return_type), - args: vec![common_types()], + arg_types: common_types() + .iter() + .map(|x| vec![x.clone()]) + .collect::>(), aggregate: Arc::new(Sum {}), } } @@ -353,7 +356,10 @@ pub fn avg() -> AggregateFunction { AggregateFunction { name: "avg".to_string(), return_type: Arc::new(avg_return_type), - args: vec![common_types()], + arg_types: common_types() + .iter() + .map(|x| vec![x.clone()]) + .collect::>(), aggregate: Arc::new(Avg {}), } } @@ -464,7 +470,10 @@ pub fn max() -> AggregateFunction { AggregateFunction { name: "max".to_string(), return_type: Arc::new(max_return_type), - args: vec![common_types()], + arg_types: common_types() + .iter() + .map(|x| vec![x.clone()]) + .collect::>(), aggregate: Arc::new(Max {}), } } @@ -642,7 +651,10 @@ pub fn min() -> AggregateFunction { AggregateFunction { name: "min".to_string(), return_type: Arc::new(max_return_type), - args: vec![common_types()], + arg_types: common_types() + .iter() + .map(|x| vec![x.clone()]) + .collect::>(), aggregate: Arc::new(Min {}), } } @@ -831,13 +843,16 @@ pub fn physical_count(expr: Arc) -> Arc { /// Creates a count aggregate function pub fn count() -> AggregateFunction { - let mut types = common_types(); - types.push(DataType::Utf8); + let mut types = common_types() + .iter() + .map(|x| vec![x.clone()]) + .collect::>(); + types.push(vec![DataType::Utf8]); AggregateFunction { name: "count".to_string(), return_type: Arc::new(count_return_type), - args: vec![types], + arg_types: types, aggregate: Arc::new(Count {}), } } diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs b/rust/datafusion/src/execution/physical_plan/udf.rs index a41abb801dc..6fb68c92fd0 100644 --- a/rust/datafusion/src/execution/physical_plan/udf.rs +++ b/rust/datafusion/src/execution/physical_plan/udf.rs @@ -157,8 +157,6 @@ impl PhysicalExpr for ScalarFunctionExpr { /// A generic aggregate function /* -This struct is - An aggregate function accepts an arbitrary number of arguments, of arbitrary data types, and returns an 
arbitrary type based on the incoming types. @@ -174,7 +172,7 @@ pub struct AggregateFunction { pub name: String, /// A list of arguments and their respective types. A function can accept more than one type as argument /// (e.g. sum(i8), sum(u8)). - pub args: Vec>, + pub arg_types: Vec>, /// Return type. This function takes pub return_type: ReturnType, /// implementation of the aggregation diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index f12c2275ce3..84fca1c2d48 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -26,7 +26,7 @@ use std::sync::{Arc, Mutex}; use arrow::datatypes::{DataType, Schema}; use crate::error::{ExecutionError, Result}; -use crate::execution::physical_plan::udf::ScalarFunction; +use crate::execution::physical_plan::udf::{AggregateFunction, ScalarFunction}; use crate::logicalplan::Expr; use crate::logicalplan::LogicalPlan; use crate::optimizer::optimizer::OptimizerRule; @@ -38,6 +38,7 @@ use utils::optimize_explain; /// This optimizer does not alter the structure of the plan, it only changes expressions on it. pub struct TypeCoercionRule { scalar_functions: Arc>>>, + aggregate_functions: Arc>>>, } impl TypeCoercionRule { @@ -45,8 +46,12 @@ impl TypeCoercionRule { /// scalar functions pub fn new( scalar_functions: Arc>>>, + aggregate_functions: Arc>>>, ) -> Self { - Self { scalar_functions } + Self { + scalar_functions, + aggregate_functions, + } } /// Rewrite an expression to include explicit CAST operations when required @@ -71,10 +76,9 @@ impl TypeCoercionRule { expressions[1] = expressions[1].cast_to(&super_type, schema)?; } } - Expr::ScalarFunction { name, .. } => { - // cast the inputs of scalar functions to the appropriate type where possible + Expr::AggregateFunction { name, .. } => { match self - .scalar_functions + .aggregate_functions .lock() .expect("failed to lock mutex") .get(name) @@ -108,6 +112,26 @@ impl TypeCoercionRule { ))); } } + None => { + return Err(ExecutionError::General(format!( + "Invalid aggregate function {}", + name + ))) + } + } + } + Expr::ScalarFunction { name, .. } => { + // cast the inputs of scalar functions to the appropriate type where possible + match self + .scalar_functions + .lock() + .expect("failed to lock mutex") + .get(name) + { + Some(func_meta) => { + expressions = + rewrite_args(expressions, schema, func_meta.as_ref())?; + } _ => { return Err(ExecutionError::General(format!( "Invalid scalar function {}", @@ -163,6 +187,36 @@ impl OptimizerRule for TypeCoercionRule { } } +/// rewrites +fn rewrite_args( + expressions: Vec, + schema: &Schema, + func_meta: &ScalarFunction, +) -> Result> { + // compute the current types and expressions + let current_types = expressions + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + + let new = if func_meta.arg_types.contains(¤t_types) { + Some(expressions) + } else { + maybe_rewrite(&expressions, ¤t_types, &schema, &func_meta.arg_types)? + }; + + if let Some(args) = new { + Ok(args) + } else { + Err(ExecutionError::General(format!( + "The scalar function '{}' requires one of the type variants {:?}, but the arguments of type '{:?}' cannot be safely casted to any of them.", + func_meta.name, + func_meta.arg_types, + current_types, + ))) + } +} + /// tries to re-cast expressions under schema based on the set of valid signatures fn maybe_rewrite( expressions: &Vec, @@ -226,14 +280,15 @@ mod tests { .project(vec![col("c1"), col("c2")])? 
.aggregate( vec![col("c1")], - vec![aggregate_expr("SUM", col("c2"), DataType::Int64)], + vec![aggregate_expr("sum", col("c2"), DataType::Int64)], )? .sort(vec![col("c1")])? .limit(10)? .build()?; - let scalar_functions = HashMap::new(); - let mut rule = TypeCoercionRule::new(Arc::new(Mutex::new(scalar_functions))); + let ctx = ExecutionContext::new(); + let mut rule = + TypeCoercionRule::new(ctx.scalar_functions(), ctx.aggregate_functions()); let plan = rule.optimize(&plan)?; // check that the filter had a cast added @@ -241,7 +296,7 @@ mod tests { println!("{}", plan_str); let expected_plan_str = "Limit: 10 Sort: #c1 - Aggregate: groupBy=[[#c1]], aggr=[[SUM(#c2)]] + Aggregate: groupBy=[[#c1]], aggr=[[sum(#c2)]] Projection: #c1, #c2 Selection: #c7 Lt CAST(UInt8(5) AS Int64)"; assert!(plan_str.starts_with(expected_plan_str)); @@ -259,8 +314,10 @@ mod tests { .filter(col("c7").lt(col("c12")))? .build()?; - let scalar_functions = HashMap::new(); - let mut rule = TypeCoercionRule::new(Arc::new(Mutex::new(scalar_functions))); + let mut rule = TypeCoercionRule::new( + Arc::new(Mutex::new(HashMap::new())), + Arc::new(Mutex::new(HashMap::new())), + ); let plan = rule.optimize(&plan)?; assert!( @@ -339,7 +396,8 @@ mod tests { }; let ctx = ExecutionContext::new(); - let rule = TypeCoercionRule::new(ctx.scalar_functions()); + let rule = + TypeCoercionRule::new(ctx.scalar_functions(), ctx.aggregate_functions()); let expr2 = rule.rewrite_expr(&expr, &schema).unwrap(); From 091f8a5f85556e87d9e9d72519865f830dd43ddf Mon Sep 17 00:00:00 2001 From: "Jorge C. Leitao" Date: Sun, 16 Aug 2020 10:10:52 +0200 Subject: [PATCH 12/13] Added support for UDFs of variable return types. This allows to declare UDFs that support multiple types. The existing UDFs (math) now support float32 and float64. --- rust/datafusion/src/datatyped.rs | 38 ++++++++ rust/datafusion/src/execution/context.rs | 18 ++-- .../src/execution/dataframe_impl.rs | 24 +---- .../execution/physical_plan/expressions.rs | 91 ++++++++++--------- .../execution/physical_plan/hash_aggregate.rs | 6 +- .../physical_plan/math_expressions.rs | 23 ++++- .../src/execution/physical_plan/mod.rs | 7 +- .../src/execution/physical_plan/planner.rs | 8 +- .../src/execution/physical_plan/projection.rs | 2 +- .../src/execution/physical_plan/udf.rs | 37 +++++--- rust/datafusion/src/lib.rs | 1 + rust/datafusion/src/logicalplan.rs | 71 ++++++++------- .../datafusion/src/optimizer/type_coercion.rs | 17 ++-- rust/datafusion/src/sql/planner.rs | 2 +- rust/datafusion/src/test/mod.rs | 2 +- rust/datafusion/tests/sql.rs | 4 +- 16 files changed, 208 insertions(+), 143 deletions(-) create mode 100644 rust/datafusion/src/datatyped.rs diff --git a/rust/datafusion/src/datatyped.rs b/rust/datafusion/src/datatyped.rs new file mode 100644 index 00000000000..f1ef241027c --- /dev/null +++ b/rust/datafusion/src/datatyped.rs @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains a public trait to annotate objects that know their data type. +// The pattern in this module follows https://stackoverflow.com/a/28664881/931303 + +use crate::error::Result; +use arrow::datatypes::{DataType, Schema}; + +/// Any object that knows how to infer its resulting data type from an underlying schema +pub trait DataTyped: AsDataTyped { + fn get_type(&self, input_schema: &Schema) -> Result; +} + +/// Trait that allows DataTyped objects to be upcasted to DataTyped. +pub trait AsDataTyped { + fn as_datatyped(&self) -> &dyn DataTyped; +} + +impl AsDataTyped for T { + fn as_datatyped(&self) -> &dyn DataTyped { + self + } +} diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 95e22ee1940..7ae806eef6a 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -560,9 +560,7 @@ impl SchemaProvider for ExecutionContextState { Arc::new(FunctionMeta::new( name.to_owned(), f.arg_types.clone(), - // this is wrong, but the actual type is overwritten by the physical plan - // as aggregate functions have a variable type. - DataType::Float32, + f.return_type.clone(), FunctionType::Aggregate, )) }) @@ -995,7 +993,11 @@ mod tests { let plan = LogicalPlanBuilder::scan("default", "test", schema.as_ref(), None)? .aggregate( vec![col("c1")], - vec![aggregate_expr("sum", col("c2"), DataType::UInt32)], + vec![aggregate_expr( + "sum", + col("c2"), + Arc::new(|_, _| Ok(DataType::Int32)), + )], )? .project(vec![col("c1"), col("sum(c2)").alias("total_salary")])? .build()?; @@ -1118,7 +1120,7 @@ mod tests { let my_add = ScalarFunction::new( "my_add", vec![vec![DataType::Int32, DataType::Int32]], - DataType::Int32, + Arc::new(|_, _| Ok(DataType::Int32)), myfunc, ); @@ -1130,7 +1132,11 @@ mod tests { .project(vec![ col("a"), col("b"), - scalar_function("my_add", vec![col("a"), col("b")], DataType::Int32), + scalar_function( + "my_add", + vec![col("a"), col("b")], + Arc::new(|_, _| Ok(DataType::Int32)), + ), ])? .build()?; diff --git a/rust/datafusion/src/execution/dataframe_impl.rs b/rust/datafusion/src/execution/dataframe_impl.rs index 89994d9feff..bd471041d49 100644 --- a/rust/datafusion/src/execution/dataframe_impl.rs +++ b/rust/datafusion/src/execution/dataframe_impl.rs @@ -19,10 +19,9 @@ use std::sync::{Arc, Mutex}; -use crate::arrow::datatypes::DataType; use crate::arrow::record_batch::RecordBatch; use crate::dataframe::*; -use crate::error::{ExecutionError, Result}; +use crate::error::Result; use crate::execution::context::{ExecutionContext, ExecutionContextState}; use crate::logicalplan::{col, Expr, LogicalPlan, LogicalPlanBuilder}; use arrow::datatypes::Schema; @@ -134,29 +133,12 @@ impl DataFrame for DataFrameImpl { } impl DataFrameImpl { - /// Determine the data type for a given expression - fn get_data_type(&self, expr: &Expr) -> Result { - match expr { - Expr::Column(name) => Ok(self - .plan - .schema() - .field_with_name(name)? 
- .data_type() - .clone()), - _ => Err(ExecutionError::General(format!( - "Could not determine data type for expr {:?}", - expr - ))), - } - } - /// Create an expression to represent a named aggregate function fn aggregate_expr(&self, name: &str, expr: Expr) -> Result { - let return_type = self.get_data_type(&expr)?; Ok(Expr::AggregateFunction { name: name.to_string(), - args: vec![expr.clone()], - return_type, + args: vec![expr], + return_type: Arc::new(|e, schema| e[0].get_type(&schema)), }) } } diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index bc770883fd4..b1fdb2f96ff 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -28,7 +28,10 @@ use crate::execution::physical_plan::udf; use crate::execution::physical_plan::{ Accumulator, AggregateExpr, Aggregator, PhysicalExpr, }; -use crate::logicalplan::{Operator, ScalarValue}; +use crate::{ + datatyped::DataTyped, + logicalplan::{Operator, ScalarValue}, +}; use arrow::array::{ ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, StringArray, TimestampNanosecondArray, UInt16Array, @@ -73,15 +76,16 @@ impl fmt::Display for Column { } } -impl PhysicalExpr for Column { - /// Get the data type of this expression, given the schema of the input - fn data_type(&self, input_schema: &Schema) -> Result { +impl DataTyped for Column { + fn get_type(&self, input_schema: &Schema) -> Result { Ok(input_schema .field_with_name(&self.name)? .data_type() .clone()) } +} +impl PhysicalExpr for Column { /// Decide whehter this expression is nullable, given the schema of the input fn nullable(&self, input_schema: &Schema) -> Result { Ok(input_schema.field_with_name(&self.name)?.is_nullable()) @@ -116,11 +120,8 @@ impl Aggregator for Sum { } } -fn sum_return_type( - expr: &Vec>, - schema: &Schema, -) -> Result { - match expr[0].data_type(schema)? { +fn sum_return_type(expr: &Vec<&dyn DataTyped>, schema: &Schema) -> Result { + match expr[0].get_type(schema)? { DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { Ok(DataType::Int64) } @@ -320,11 +321,8 @@ impl Aggregator for Avg { } } -fn avg_return_type( - expr: &Vec>, - schema: &Schema, -) -> Result { - match expr[0].data_type(schema)? { +fn avg_return_type(expr: &Vec<&dyn DataTyped>, schema: &Schema) -> Result { + match expr[0].get_type(schema)? { DataType::Int8 | DataType::Int16 | DataType::Int32 @@ -478,11 +476,8 @@ pub fn max() -> AggregateFunction { } } -fn max_return_type( - expr: &Vec>, - schema: &Schema, -) -> Result { - expr[0].data_type(schema) +fn max_return_type(expr: &Vec<&dyn DataTyped>, schema: &Schema) -> Result { + expr[0].get_type(schema) } macro_rules! 
max_accumulate { @@ -857,10 +852,7 @@ pub fn count() -> AggregateFunction { } } -fn count_return_type( - _expr: &Vec>, - _schema: &Schema, -) -> Result { +fn count_return_type(_expr: &Vec<&dyn DataTyped>, _schema: &Schema) -> Result { Ok(DataType::UInt64) } @@ -1019,11 +1011,13 @@ impl fmt::Display for BinaryExpr { } } -impl PhysicalExpr for BinaryExpr { - fn data_type(&self, input_schema: &Schema) -> Result { - self.left.data_type(input_schema) +impl DataTyped for BinaryExpr { + fn get_type(&self, input_schema: &Schema) -> Result { + self.left.get_type(input_schema) } +} +impl PhysicalExpr for BinaryExpr { fn nullable(&self, _input_schema: &Schema) -> Result { // binary operator should always return a boolean value Ok(false) @@ -1109,11 +1103,14 @@ impl fmt::Display for NotExpr { write!(f, "NOT {}", self.arg) } } -impl PhysicalExpr for NotExpr { - fn data_type(&self, _input_schema: &Schema) -> Result { + +impl DataTyped for NotExpr { + fn get_type(&self, _input_schema: &Schema) -> Result { return Ok(DataType::Boolean); } +} +impl PhysicalExpr for NotExpr { fn nullable(&self, _input_schema: &Schema) -> Result { // !Null == true Ok(false) @@ -1166,7 +1163,7 @@ impl CastExpr { input_schema: &Schema, cast_type: DataType, ) -> Result { - let expr_type = expr.data_type(input_schema)?; + let expr_type = expr.get_type(input_schema)?; // numbers can be cast to numbers and strings if is_numeric(&expr_type) && (is_numeric(&cast_type) || cast_type == DataType::Utf8) @@ -1193,11 +1190,13 @@ impl fmt::Display for CastExpr { } } -impl PhysicalExpr for CastExpr { - fn data_type(&self, _input_schema: &Schema) -> Result { +impl DataTyped for CastExpr { + fn get_type(&self, _input_schema: &Schema) -> Result { Ok(self.cast_type.clone()) } +} +impl PhysicalExpr for CastExpr { fn nullable(&self, input_schema: &Schema) -> Result { self.expr.nullable(input_schema) } @@ -1239,11 +1238,13 @@ impl fmt::Display for Literal { } } -impl PhysicalExpr for Literal { - fn data_type(&self, _input_schema: &Schema) -> Result { +impl DataTyped for Literal { + fn get_type(&self, _input_schema: &Schema) -> Result { self.value.get_datatype() } +} +impl PhysicalExpr for Literal { fn nullable(&self, _input_schema: &Schema) -> Result { match &self.value { ScalarValue::Null => Ok(true), @@ -1487,16 +1488,16 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let sum = physical_sum(col("a")); - assert_eq!(DataType::Int64, sum.data_type(&schema)?); + assert_eq!(DataType::Int64, sum.get_type(&schema)?); // after the aggr expression is applied, the schema changes to: let schema = Schema::new(vec![ schema.field(0).clone(), - Field::new("SUM(a)", sum.data_type(&schema)?, false), + Field::new("SUM(a)", sum.get_type(&schema)?, false), ]); let combiner = sum.create_reducer("SUM(a)"); - assert_eq!(DataType::Int64, combiner.data_type(&schema)?); + assert_eq!(DataType::Int64, combiner.get_type(&schema)?); Ok(()) } @@ -1506,16 +1507,16 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let max = physical_max(col("a")); - assert_eq!(DataType::Int32, max.data_type(&schema)?); + assert_eq!(DataType::Int32, max.get_type(&schema)?); // after the aggr expression is applied, the schema changes to: let schema = Schema::new(vec![ schema.field(0).clone(), - Field::new("Max(a)", max.data_type(&schema)?, false), + Field::new("Max(a)", max.get_type(&schema)?, false), ]); let combiner = max.create_reducer("Max(a)"); - assert_eq!(DataType::Int32, combiner.data_type(&schema)?); + 
assert_eq!(DataType::Int32, combiner.get_type(&schema)?); Ok(()) } @@ -1525,15 +1526,15 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let min = physical_min(col("a")); - assert_eq!(DataType::Int32, min.data_type(&schema)?); + assert_eq!(DataType::Int32, min.get_type(&schema)?); // after the aggr expression is applied, the schema changes to: let schema = Schema::new(vec![ schema.field(0).clone(), - Field::new("MIN(a)", min.data_type(&schema)?, false), + Field::new("MIN(a)", min.get_type(&schema)?, false), ]); let combiner = min.create_reducer("MIN(a)"); - assert_eq!(DataType::Int32, combiner.data_type(&schema)?); + assert_eq!(DataType::Int32, combiner.get_type(&schema)?); Ok(()) } @@ -1542,16 +1543,16 @@ mod tests { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); let avg = physical_avg(col("a")); - assert_eq!(DataType::Float64, avg.data_type(&schema)?); + assert_eq!(DataType::Float64, avg.get_type(&schema)?); // after the aggr expression is applied, the schema changes to: let schema = Schema::new(vec![ schema.field(0).clone(), - Field::new("AVG(a)", avg.data_type(&schema)?, false), + Field::new("AVG(a)", avg.get_type(&schema)?, false), ]); let combiner = avg.create_reducer("AVG(a)"); - assert_eq!(DataType::Float64, combiner.data_type(&schema)?); + assert_eq!(DataType::Float64, combiner.get_type(&schema)?); Ok(()) } diff --git a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs index 54787b5e16c..41780921e45 100644 --- a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs @@ -64,10 +64,10 @@ impl HashAggregateExec { let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len()); for (expr, name) in &group_expr { - fields.push(Field::new(name, expr.data_type(&input_schema)?, true)) + fields.push(Field::new(name, expr.get_type(&input_schema)?, true)) } for (expr, name) in &aggr_expr { - fields.push(Field::new(&name, expr.data_type(&input_schema)?, true)) + fields.push(Field::new(&name, expr.get_type(&input_schema)?, true)) } let schema = Arc::new(Schema::new(fields)); @@ -459,7 +459,7 @@ impl RecordBatchReader for HashAggregateIterator { // aggregate values for i in 0..self.aggr_expr.len() { let aggr_data_type = self.aggr_expr[i] - .data_type(&input_schema) + .get_type(&input_schema) .map_err(ExecutionError::into_arrow_external_error)?; let value = accumulators[i] .borrow_mut() diff --git a/rust/datafusion/src/execution/physical_plan/math_expressions.rs b/rust/datafusion/src/execution/physical_plan/math_expressions.rs index 1c97c2139d3..7da8a770b5e 100644 --- a/rust/datafusion/src/execution/physical_plan/math_expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/math_expressions.rs @@ -72,7 +72,7 @@ macro_rules! math_unary_function { $NAME, // order: from faster to slower vec![vec![DataType::Float32], vec![DataType::Float64]], - DataType::Float64, + Arc::new(|expr, schema| expr[0].get_type(schema)), Arc::new(|args: &[ArrayRef]| { let array = &args[0]; unary_primitive_array_op!(array, $NAME, $FUNC) @@ -122,6 +122,15 @@ mod tests { .build()?; let ctx = ExecutionContext::new(); let plan = ctx.optimize(&plan)?; + + assert_eq!( + *plan + .schema() + .field_with_name("sqrt(CAST(c0 as Float32))")? 
+ .data_type(), + DataType::Float32 + ); + let expected = "Projection: sqrt(CAST(#c0 AS Float32))\ \n TableScan: projection=Some([0])"; assert_eq!(format!("{:?}", plan), expected); @@ -136,6 +145,12 @@ mod tests { .build()?; let ctx = ExecutionContext::new(); let plan = ctx.optimize(&plan)?; + + assert_eq!( + *plan.schema().field_with_name("sqrt(c0)")?.data_type(), + DataType::Float64 + ); + let expected = "Projection: sqrt(#c0)\ \n TableScan: projection=Some([0])"; assert_eq!(format!("{:?}", plan), expected); @@ -150,6 +165,12 @@ mod tests { .build()?; let ctx = ExecutionContext::new(); let plan = ctx.optimize(&plan)?; + + assert_eq!( + *plan.schema().field_with_name("sqrt(c0)")?.data_type(), + DataType::Float32 + ); + let expected = "Projection: sqrt(#c0)\ \n TableScan: projection=Some([0])"; assert_eq!(format!("{:?}", plan), expected); diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs index b16ba8163ea..40fc04a1244 100644 --- a/rust/datafusion/src/execution/physical_plan/mod.rs +++ b/rust/datafusion/src/execution/physical_plan/mod.rs @@ -22,6 +22,7 @@ use std::fmt::{Debug, Display}; use std::rc::Rc; use std::sync::{Arc, Mutex}; +use crate::datatyped::DataTyped; use crate::error::Result; use crate::execution::context::ExecutionContextState; use crate::logicalplan::{LogicalPlan, ScalarValue}; @@ -59,9 +60,7 @@ pub trait Partition: Send + Sync + Debug { /// Expression that can be evaluated against a RecordBatch /// A Physical expression knows its type, nullability and how to evaluate itself. -pub trait PhysicalExpr: Send + Sync + Display + Debug { - /// Get the data type of this expression, given the schema of the input - fn data_type(&self, input_schema: &Schema) -> Result; +pub trait PhysicalExpr: Send + Sync + Display + Debug + DataTyped { /// Decide whehter this expression is nullable, given the schema of the input fn nullable(&self, input_schema: &Schema) -> Result; /// Evaluate an expression against a RecordBatch @@ -99,7 +98,7 @@ pub fn scalar_functions() -> Vec { let mut udfs = vec![ScalarFunction::new( "length", vec![vec![DataType::Utf8]], - DataType::UInt32, + Arc::new(|_, _| Ok(DataType::UInt32)), Arc::new(|args: &[ArrayRef]| Ok(Arc::new(length(args[0].as_ref())?))), )]; udfs.append(&mut math_expressions::scalar_functions()); diff --git a/rust/datafusion/src/execution/physical_plan/planner.rs b/rust/datafusion/src/execution/physical_plan/planner.rs index 6475fccdfc0..058e110e332 100644 --- a/rust/datafusion/src/execution/physical_plan/planner.rs +++ b/rust/datafusion/src/execution/physical_plan/planner.rs @@ -355,11 +355,7 @@ impl PhysicalPlannerImpl { input_schema, data_type.clone(), )?)), - Expr::ScalarFunction { - name, - args, - return_type, - } => match ctx_state + Expr::ScalarFunction { name, args, .. 
} => match ctx_state .lock() .expect("failed to lock mutex") .scalar_functions @@ -380,7 +376,7 @@ impl PhysicalPlannerImpl { name, Box::new(f.fun.clone()), physical_args, - return_type, + f.return_type.clone(), ))) } _ => Err(ExecutionError::General(format!( diff --git a/rust/datafusion/src/execution/physical_plan/projection.rs b/rust/datafusion/src/execution/physical_plan/projection.rs index a5ad0ef3e03..3ce3f637904 100644 --- a/rust/datafusion/src/execution/physical_plan/projection.rs +++ b/rust/datafusion/src/execution/physical_plan/projection.rs @@ -52,7 +52,7 @@ impl ProjectionExec { .map(|(e, name)| { Ok(Field::new( name, - e.data_type(&input_schema)?, + e.get_type(&input_schema)?, e.nullable(&input_schema)?, )) }) diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs b/rust/datafusion/src/execution/physical_plan/udf.rs index 6fb68c92fd0..d43d13acbb2 100644 --- a/rust/datafusion/src/execution/physical_plan/udf.rs +++ b/rust/datafusion/src/execution/physical_plan/udf.rs @@ -23,7 +23,7 @@ use arrow::array::ArrayRef; use arrow::datatypes::{DataType, Schema}; use crate::error::Result; -use crate::execution::physical_plan::PhysicalExpr; +use crate::{datatyped::DataTyped, execution::physical_plan::PhysicalExpr}; use super::{Accumulator, AggregateExpr, Aggregator}; use arrow::record_batch::RecordBatch; @@ -35,7 +35,7 @@ pub type ScalarUdf = Arc Result + Send + Sync>; /// Function to construct the return type of a function given its arguments. pub type ReturnType = - Arc>, &Schema) -> Result + Send + Sync>; + Arc, &Schema) -> Result + Send + Sync>; /// Scalar UDF Expression #[derive(Clone)] @@ -48,7 +48,7 @@ pub struct ScalarFunction { /// For example, [[t1, t2]] is a function of 2 arguments that only accept t1 on the first arg and t2 on the second pub arg_types: Vec>, /// Return type - pub return_type: DataType, + pub return_type: ReturnType, /// UDF implementation pub fun: ScalarUdf, } @@ -58,7 +58,7 @@ impl Debug for ScalarFunction { f.debug_struct("ScalarFunction") .field("name", &self.name) .field("arg_types", &self.arg_types) - .field("return_type", &self.return_type) + //.field("return_type", &self.return_type) .field("fun", &"") .finish() } @@ -69,7 +69,7 @@ impl ScalarFunction { pub fn new( name: &str, arg_types: Vec>, - return_type: DataType, + return_type: ReturnType, fun: ScalarUdf, ) -> Self { Self { @@ -86,7 +86,7 @@ pub struct ScalarFunctionExpr { fun: Box, name: String, args: Vec>, - return_type: DataType, + return_type: ReturnType, } impl Debug for ScalarFunctionExpr { @@ -95,7 +95,7 @@ impl Debug for ScalarFunctionExpr { .field("fun", &"") .field("name", &self.name) .field("args", &self.args) - .field("return_type", &self.return_type) + //.field("return_type", &self.return_type) .finish() } } @@ -106,7 +106,7 @@ impl ScalarFunctionExpr { name: &str, fun: Box, args: Vec>, - return_type: &DataType, + return_type: ReturnType, ) -> Self { Self { fun, @@ -132,11 +132,15 @@ impl fmt::Display for ScalarFunctionExpr { } } -impl PhysicalExpr for ScalarFunctionExpr { - fn data_type(&self, _input_schema: &Schema) -> Result { - Ok(self.return_type.clone()) +impl DataTyped for ScalarFunctionExpr { + fn get_type(&self, input_schema: &Schema) -> Result { + let x = self.args.clone(); + let x = x.iter().map(|x| x.as_datatyped()).collect::>(); + (self.return_type)(&x, input_schema) } +} +impl PhysicalExpr for ScalarFunctionExpr { fn nullable(&self, _input_schema: &Schema) -> Result { Ok(true) } @@ -212,11 +216,16 @@ impl Debug for AggregateFunctionExpr { } } -impl PhysicalExpr for 
diff --git a/rust/datafusion/src/lib.rs b/rust/datafusion/src/lib.rs
index 73897eeaedf..0fb0e184fa8 100644
--- a/rust/datafusion/src/lib.rs
+++ b/rust/datafusion/src/lib.rs
@@ -31,6 +31,7 @@ extern crate sqlparser;
 
 pub mod dataframe;
 pub mod datasource;
+mod datatyped;
 pub mod error;
 pub mod execution;
 pub mod logicalplan;
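The body of the new private `datatyped` module is not included in this excerpt. Inferred from its call sites (`get_type`, `as_datatyped`), it plausibly looks like the sketch below; treat this as a reconstruction under those assumptions, not the actual module:

    use arrow::datatypes::{DataType, Schema};

    use crate::error::Result;

    /// Anything whose Arrow DataType can be resolved against an input schema.
    pub trait DataTyped {
        fn get_type(&self, schema: &Schema) -> Result<DataType>;
    }

    /// Erases a concrete expression type to `&dyn DataTyped`, letting logical
    /// `Expr` and physical expressions share one `ReturnType` closure signature.
    pub trait AsDataTyped {
        fn as_datatyped(&self) -> &dyn DataTyped;
    }

    impl<T: DataTyped> AsDataTyped for T {
        fn as_datatyped(&self) -> &dyn DataTyped {
            self
        }
    }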
diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs
index f9f9d4e021c..bd3f66595e9 100644
--- a/rust/datafusion/src/logicalplan.rs
+++ b/rust/datafusion/src/logicalplan.rs
@@ -25,12 +25,16 @@ use std::{fmt, sync::Arc};
 
 use arrow::datatypes::{DataType, Field, Schema};
 
-use crate::datasource::csv::{CsvFile, CsvReadOptions};
+use crate::datasource::csv::CsvFile;
 use crate::datasource::parquet::ParquetTable;
-use crate::datasource::TableProvider;
+use crate::datasource::{CsvReadOptions, TableProvider};
 use crate::error::{ExecutionError, Result};
 use crate::optimizer::utils;
-use crate::sql::parser::FileType;
+use crate::{
+    datatyped::{AsDataTyped, DataTyped},
+    execution::physical_plan::udf::ReturnType,
+    sql::parser::FileType,
+};
 use arrow::record_batch::RecordBatch;
 
 /// Enumeration of supported function types (Scalar and Aggregate)
@@ -43,24 +47,24 @@ pub enum FunctionType {
 }
 
 /// Logical representation of a UDF (user-defined function)
-#[derive(Debug, Clone)]
+#[derive(Clone)]
 pub struct FunctionMeta {
     /// Function name
     name: String,
     /// Function argument types
     arg_types: Vec<Vec<DataType>>,
     /// Function return type
-    return_type: DataType,
+    return_type: ReturnType,
     /// Function type (Scalar or Aggregate)
     function_type: FunctionType,
 }
 
 impl FunctionMeta {
-    #[allow(missing_docs)]
+    /// Constructs a new FunctionMeta
     pub fn new(
         name: String,
         arg_types: Vec<Vec<DataType>>,
-        return_type: DataType,
+        return_type: ReturnType,
         function_type: FunctionType,
     ) -> Self {
         FunctionMeta {
@@ -79,7 +83,7 @@ impl FunctionMeta {
         &self.arg_types
     }
     /// Getter for the `DataType` the function returns
-    pub fn return_type(&self) -> &DataType {
+    pub fn return_type(&self) -> &ReturnType {
         &self.return_type
     }
     /// Getter for the `FunctionType`
@@ -271,12 +275,8 @@ pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result<Field> {
         Expr::Alias(expr, ..) => expr.get_type(input_schema),
         Expr::Column(name) => Ok(input_schema.field_with_name(name)?.data_type().clone()),
         Expr::Literal(ref lit) => lit.get_datatype(),
-        Expr::ScalarFunction {
-            ref return_type, ..
-        } => Ok(return_type.clone()),
-        Expr::AggregateFunction {
-            ref return_type, ..
-        } => Ok(return_type.clone()),
+        Expr::ScalarFunction { .. } => e.get_type(&input_schema),
+        Expr::AggregateFunction { .. } => e.get_type(&input_schema),
         Expr::Cast { ref data_type, .. } => Ok(data_type.clone()),
         Expr::BinaryExpr {
             ref left,
@@ -307,7 +307,7 @@ pub fn exprlist_to_fields(expr: &[Expr], input_schema: &Schema) -> Result<Vec<Field>>
     Alias(Box<Expr>, String),
@@ -355,7 +355,7 @@ pub enum Expr {
         /// List of expressions to feed to the functions as arguments
         args: Vec<Expr>,
         /// The `DataType` the expression will yield
-        return_type: DataType,
+        return_type: ReturnType,
     },
     /// aggregate function
     AggregateFunction {
@@ -364,22 +364,25 @@ pub enum Expr {
         /// List of expressions to feed to the functions as arguments
         args: Vec<Expr>,
         /// The `DataType` the expression will yield
-        return_type: DataType,
+        return_type: ReturnType,
     },
     /// Wildcard
     Wildcard,
 }
 
-impl Expr {
-    /// Find the `DataType` for the expression
-    pub fn get_type(&self, schema: &Schema) -> Result<DataType> {
+impl DataTyped for Expr {
+    fn get_type(&self, schema: &Schema) -> Result<DataType> {
         match self {
             Expr::Alias(expr, _) => expr.get_type(schema),
             Expr::Column(name) => Ok(schema.field_with_name(name)?.data_type().clone()),
             Expr::Literal(l) => l.get_datatype(),
             Expr::Cast { data_type, .. } => Ok(data_type.clone()),
-            Expr::ScalarFunction { return_type, .. } => Ok(return_type.clone()),
-            Expr::AggregateFunction { return_type, .. } => Ok(return_type.clone()),
+            Expr::ScalarFunction {
+                args, return_type, ..
+            } => return_type(&args.iter().map(|x| x.as_datatyped()).collect(), schema),
+            Expr::AggregateFunction {
+                args, return_type, ..
+            } => return_type(&args.iter().map(|x| x.as_datatyped()).collect(), schema),
             Expr::Not(_) => Ok(DataType::Boolean),
             Expr::IsNull(_) => Ok(DataType::Boolean),
             Expr::IsNotNull(_) => Ok(DataType::Boolean),
@@ -407,7 +410,9 @@ impl Expr {
             Expr::Nested(e) => e.get_type(schema),
         }
     }
+}
 
+impl Expr {
     /// Return the name of this expression
     ///
     /// This represents how a column with this expression is named when no alias is chosen
@@ -565,7 +570,7 @@ macro_rules! unary_math_expr {
     ($NAME:expr, $FUNC:ident) => {
         #[allow(missing_docs)]
         pub fn $FUNC(e: Expr) -> Expr {
-            scalar_function($NAME, vec![e], DataType::Float64)
+            scalar_function($NAME, vec![e], Arc::new(|e, schema| e[0].get_type(&schema)))
         }
     };
 }
@@ -591,11 +596,11 @@ unary_math_expr!("log10", log10);
 
 /// returns the length of a string in bytes
 pub fn length(e: Expr) -> Expr {
-    scalar_function("length", vec![e], DataType::UInt32)
+    scalar_function("length", vec![e], Arc::new(|_, _| Ok(DataType::UInt32)))
 }
 
 /// Create an aggregate expression
-pub fn aggregate_expr(name: &str, expr: Expr, return_type: DataType) -> Expr {
+pub fn aggregate_expr(name: &str, expr: Expr, return_type: ReturnType) -> Expr {
     Expr::AggregateFunction {
         name: name.to_owned(),
         args: vec![expr],
@@ -604,7 +609,7 @@ pub fn aggregate_expr(name: &str, expr: Expr, return_type: DataType) -> Expr {
 }
 
 /// Create a scalar function expression
-pub fn scalar_function(name: &str, expr: Vec<Expr>, return_type: DataType) -> Expr {
+pub fn scalar_function(name: &str, expr: Vec<Expr>, return_type: ReturnType) -> Expr {
     Expr::ScalarFunction {
         name: name.to_owned(),
         args: expr,
@@ -613,7 +618,7 @@ pub fn scalar_function(name: &str, expr: Vec<Expr>, return_type: DataType) -> Expr {
 }
 
 /// Create an aggregate expression
-pub fn aggregate_function(name: &str, expr: Vec<Expr>, return_type: DataType) -> Expr {
+pub fn aggregate_function(name: &str, expr: Vec<Expr>, return_type: ReturnType) -> Expr {
     Expr::AggregateFunction {
         name: name.to_owned(),
         args: expr,
@@ -1101,7 +1106,7 @@ impl LogicalPlanBuilder {
     /// Apply a projection
     pub fn project(&self, expr: Vec<Expr>) -> Result<Self> {
         let input_schema = self.plan.schema();
-        let projected_expr = if expr.contains(&Expr::Wildcard) {
+        let projected_expr = {
             let mut expr_vec = vec![];
             (0..expr.len()).for_each(|i| match &expr[i] {
                 Expr::Wildcard => {
@@ -1111,8 +1116,6 @@ impl LogicalPlanBuilder {
                 _ => expr_vec.push(expr[i].clone()),
             });
             expr_vec
-        } else {
-            expr.clone()
         };
 
         let schema =
@@ -1277,8 +1280,12 @@ mod tests {
         )?
         .aggregate(
             vec![col("state")],
-            vec![aggregate_expr("SUM", col("salary"), DataType::Int32)
-                .alias("total_salary")],
+            vec![aggregate_expr(
+                "SUM",
+                col("salary"),
+                Arc::new(|_, _| Ok(DataType::Int32)),
+            )
+            .alias("total_salary")],
         )?
         .project(vec![col("state"), col("total_salary")])?
         .build()?;
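Previously `Expr::ScalarFunction` carried a concrete `DataType`, so `get_type` was a plain field read; it is now computed by evaluating the stored closure against the argument expressions. A side effect is that `Expr` can no longer be compared structurally (a `ReturnType` closure has no `PartialEq`), which is presumably why `project` drops the `expr.contains(&Expr::Wildcard)` fast path above. A crate-internal sketch of the new behavior, with an invented column and schema (`DataTyped` is private to the crate in this patch):

    use arrow::datatypes::{DataType, Field, Schema};

    use crate::datatyped::DataTyped;
    use crate::error::Result;
    use crate::logicalplan::{col, sqrt};

    // The unary math functions now derive their output type from their first
    // argument, so a Float32 column yields a Float32 result.
    fn sqrt_tracks_input_type() -> Result<()> {
        let schema = Schema::new(vec![Field::new("c0", DataType::Float32, false)]);
        assert_eq!(sqrt(col("c0")).get_type(&schema)?, DataType::Float32);
        Ok(())
    }

This is the same property the last patch in the series asserts for `sqrt(c0)`.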
diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs
index 84fca1c2d48..98076f47a80 100644
--- a/rust/datafusion/src/optimizer/type_coercion.rs
+++ b/rust/datafusion/src/optimizer/type_coercion.rs
@@ -30,7 +30,7 @@ use crate::execution::physical_plan::udf::{AggregateFunction, ScalarFunction};
 use crate::logicalplan::Expr;
 use crate::logicalplan::LogicalPlan;
 use crate::optimizer::optimizer::OptimizerRule;
-use crate::optimizer::utils;
+use crate::{datatyped::DataTyped, optimizer::utils};
 use utils::optimize_explain;
 
 /// Optimizer that applies coercion rules to expressions in the logical plan.
@@ -267,6 +267,7 @@ mod tests {
     use crate::logicalplan::{aggregate_expr, col, lit, LogicalPlanBuilder, Operator};
     use crate::test::arrow_testdata_path;
     use arrow::datatypes::{DataType, Field, Schema};
+    use std::sync::Arc;
 
     #[test]
     fn test_all_operators() -> Result<()> {
@@ -280,7 +281,11 @@ mod tests {
         .project(vec![col("c1"), col("c2")])?
         .aggregate(
             vec![col("c1")],
-            vec![aggregate_expr("sum", col("c2"), DataType::Int64)],
+            vec![aggregate_expr(
+                "sum",
+                col("c2"),
+                Arc::new(|_, _| Ok(DataType::Int64)),
+            )],
         )?
         .sort(vec![col("c1")])?
         .limit(10)?
@@ -535,10 +540,10 @@ mod tests {
         ];
 
         for (i, case) in cases.iter().enumerate() {
-            if maybe_rewrite(&case.0, &case.1, &case.2, &case.3)? != case.4 {
-                assert_eq!(maybe_rewrite(&case.0, &case.1, &case.2, &case.3)?, case.4);
-                return Err(ExecutionError::General(format!("case {} failed", i)));
-            }
+            let result = maybe_rewrite(&case.0, &case.1, &case.2, &case.3)?;
+            let result = format!("case {}: {:?}", i, result);
+            let expected = format!("case {}: {:?}", i, case.4);
+            assert_eq!(result, expected);
         }
         Ok(())
     }
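Since the rewritten expressions can no longer be compared with `==`, the loop above checks `Debug` renderings instead, folding the case index into both strings so a failing `assert_eq!` names its case directly. The same pattern in miniature (a generic sketch, not code from the patch):

    use std::fmt::Debug;

    // Compare two slices of non-PartialEq (but Debug) values by their debug
    // renderings; the case index is baked into both sides so that a failing
    // assert_eq! identifies the offending case.
    fn assert_cases_by_debug<T: Debug>(results: &[T], expected: &[T]) {
        for (i, (r, e)) in results.iter().zip(expected.iter()).enumerate() {
            assert_eq!(
                format!("case {}: {:?}", i, r),
                format!("case {}: {:?}", i, e)
            );
        }
    }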
diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs
index 37f696e3542..8a61832dcde 100644
--- a/rust/datafusion/src/sql/planner.rs
+++ b/rust/datafusion/src/sql/planner.rs
@@ -912,7 +912,7 @@ mod tests {
         let fm = Arc::new(FunctionMeta::new(
             name.to_string(),
             vec![valid_types],
-            DataType::Float64,
+            Arc::new(|_, _| Ok(DataType::Float64)),
             fnc_type,
         ));
 
diff --git a/rust/datafusion/src/test/mod.rs b/rust/datafusion/src/test/mod.rs
index 317c14564f1..cd066c6ab91 100644
--- a/rust/datafusion/src/test/mod.rs
+++ b/rust/datafusion/src/test/mod.rs
@@ -228,6 +228,6 @@ pub fn max(expr: Expr) -> Expr {
     Expr::AggregateFunction {
         name: "MAX".to_owned(),
         args: vec![expr],
-        return_type: DataType::Float64,
+        return_type: Arc::new(|_, _| Ok(DataType::Float64)),
     }
 }
diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs
index 9a1807f2fa2..07ad59e4586 100644
--- a/rust/datafusion/tests/sql.rs
+++ b/rust/datafusion/tests/sql.rs
@@ -290,7 +290,7 @@ fn create_ctx() -> Result<ExecutionContext> {
     ctx.register_udf(ScalarFunction::new(
         "custom_sqrt",
         vec![vec![DataType::Float64]],
-        DataType::Float64,
+        Arc::new(|_, _| Ok(DataType::Float64)),
         Arc::new(custom_sqrt),
     ));
 
@@ -302,7 +302,7 @@ fn create_ctx() -> Result<ExecutionContext> {
             vec![DataType::Float32, DataType::Float64],
             vec![DataType::Float64, DataType::Float64],
         ],
-        DataType::Float64,
+        Arc::new(|_, _| Ok(DataType::Float64)),
         Arc::new(custom_add),
     ));
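With the two `custom_add` signatures registered above, a call whose arguments are `(Float32, Float64)` can be planned without inserting casts, while other numeric combinations must be coerced to a listed signature. A hypothetical follow-on next to `create_ctx` in tests/sql.rs (the table `t` and its columns `a: Float32`, `b: Float64` are invented for illustration):

    fn plan_custom_add() -> Result<()> {
        let mut ctx = create_ctx()?;
        // assumes a table "t" with columns a: Float32 and b: Float64 is registered
        let plan = ctx.create_logical_plan("SELECT custom_add(a, b) FROM t")?;
        let plan = ctx.optimize(&plan)?;
        // (Float32, Float64) is a registered combination, so no cast is needed
        // around `a`; with only (Float64, Float64) available the coercer would
        // rewrite the call to custom_add(CAST(a AS Float64), b).
        println!("{:?}", plan);
        Ok(())
    }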
From c01f9ebdb267a7a09e3600942d113674060f1905 Mon Sep 17 00:00:00 2001
From: "Jorge C. Leitao"
Date: Mon, 17 Aug 2020 20:16:34 +0200
Subject: [PATCH 13/13] Made re-write expression in type coercer respect new
 expr types.

---
 .../physical_plan/math_expressions.rs  |  5 +----
 rust/datafusion/src/optimizer/utils.rs | 19 +++++++++++++++++--
 2 files changed, 18 insertions(+), 6 deletions(-)

diff --git a/rust/datafusion/src/execution/physical_plan/math_expressions.rs b/rust/datafusion/src/execution/physical_plan/math_expressions.rs
index 7da8a770b5e..aea348d255f 100644
--- a/rust/datafusion/src/execution/physical_plan/math_expressions.rs
+++ b/rust/datafusion/src/execution/physical_plan/math_expressions.rs
@@ -124,10 +124,7 @@ mod tests {
         let plan = ctx.optimize(&plan)?;
 
         assert_eq!(
-            *plan
-                .schema()
-                .field_with_name("sqrt(CAST(c0 as Float32))")?
-                .data_type(),
+            *plan.schema().field_with_name("sqrt(c0)")?.data_type(),
             DataType::Float32
         );
 
diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs
index 9195e195895..bf72a7dc0c4 100644
--- a/rust/datafusion/src/optimizer/utils.rs
+++ b/rust/datafusion/src/optimizer/utils.rs
@@ -19,9 +19,10 @@
 
 use std::collections::HashSet;
 
-use arrow::datatypes::{DataType, Schema};
+use arrow::datatypes::{DataType, Field, Schema};
 
 use super::optimizer::OptimizerRule;
+use crate::datatyped::DataTyped;
 use crate::error::{ExecutionError, Result};
 use crate::logicalplan::{Expr, LogicalPlan, PlanType, StringifiedPlan};
 
@@ -269,7 +270,21 @@ pub fn from_plan(
         LogicalPlan::Projection { schema, .. } => Ok(LogicalPlan::Projection {
             expr: expr.clone(),
             input: Box::new(inputs[0].clone()),
-            schema: schema.clone(),
+            // new expressions may have a different type, which changes the resulting schema
+            schema: Box::new(Schema::new(
+                schema
+                    .fields()
+                    .iter()
+                    .enumerate()
+                    .map(|(i, f)| {
+                        Ok(Field::new(
+                            f.name(),
+                            expr[i].get_type(inputs[0].schema())?,
+                            f.is_nullable(),
+                        ))
+                    })
+                    .collect::<Result<Vec<Field>>>()?,
+            )),
         }),
         LogicalPlan::Selection { .. } => Ok(LogicalPlan::Selection {
             expr: expr[0].clone(),