diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md
index 0a193684265..1ee4f10578b 100644
--- a/rust/datafusion/README.md
+++ b/rust/datafusion/README.md
@@ -53,6 +53,8 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI
 - [x] Limit
 - [x] Aggregate
 - [x] UDFs
+  - [x] Scalar UDFs
+  - [x] Aggregate UDFs
 - [x] Common math functions
 - String functions
   - [x] Length of the string
diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index 8f92aae302b..95e22ee1940 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -28,6 +28,7 @@ use arrow::csv;
 use arrow::datatypes::*;
 use arrow::record_batch::RecordBatch;
 
+use super::physical_plan::udf::AggregateFunction;
 use crate::dataframe::DataFrame;
 use crate::datasource::csv::CsvFile;
 use crate::datasource::parquet::ParquetTable;
@@ -38,10 +39,10 @@ use crate::execution::physical_plan::common;
 use crate::execution::physical_plan::csv::CsvReadOptions;
 use crate::execution::physical_plan::merge::MergeExec;
 use crate::execution::physical_plan::planner::PhysicalPlannerImpl;
-use crate::execution::physical_plan::scalar_functions;
 use crate::execution::physical_plan::udf::ScalarFunction;
 use crate::execution::physical_plan::ExecutionPlan;
 use crate::execution::physical_plan::PhysicalPlanner;
+use crate::execution::physical_plan::{aggregate_functions, scalar_functions};
 use crate::logicalplan::{FunctionMeta, FunctionType, LogicalPlan, LogicalPlanBuilder};
 use crate::optimizer::optimizer::OptimizerRule;
 use crate::optimizer::projection_push_down::ProjectionPushDown;
@@ -103,12 +104,16 @@ impl ExecutionContext {
             state: Arc::new(Mutex::new(ExecutionContextState {
                 datasources: Arc::new(Mutex::new(HashMap::new())),
                 scalar_functions: Arc::new(Mutex::new(HashMap::new())),
+                aggregate_functions: Arc::new(Mutex::new(HashMap::new())),
                 config,
             })),
         };
         for udf in scalar_functions() {
             ctx.register_udf(udf);
         }
+        for udf in aggregate_functions() {
+            ctx.register_aggregate_udf(udf);
+        }
         ctx
     }
@@ -200,6 +205,16 @@ impl ExecutionContext {
             .insert(f.name.clone(), Box::new(f));
     }
 
+    /// Register an aggregate function
+    pub fn register_aggregate_udf(&mut self, f: AggregateFunction) {
+        let state = self.state.lock().expect("failed to lock mutex");
+        state
+            .aggregate_functions
+            .lock()
+            .expect("failed to lock mutex")
+            .insert(f.name.clone(), Box::new(f));
+    }
+
     /// Get a reference to the registered scalar functions
     pub fn scalar_functions(&self) -> Arc<Mutex<HashMap<String, Box<ScalarFunction>>>> {
         self.state
@@ -209,6 +224,17 @@
             .clone()
     }
 
+    /// Get a reference to the registered aggregate functions
+    pub fn aggregate_functions(
+        &self,
+    ) -> Arc<Mutex<HashMap<String, Box<AggregateFunction>>>> {
+        self.state
+            .lock()
+            .expect("failed to lock mutex")
+            .aggregate_functions
+            .clone()
+    }
+
     /// Creates a DataFrame for reading a CSV data source.
     pub fn read_csv(
         &mut self,
@@ -336,11 +362,8 @@ impl ExecutionContext {
         let rules: Vec<Box<dyn OptimizerRule>> = vec![
             Box::new(ProjectionPushDown::new()),
             Box::new(TypeCoercionRule::new(
-                self.state
-                    .lock()
-                    .expect("failed to lock mutex")
-                    .scalar_functions
-                    .clone(),
+                self.scalar_functions(),
+                self.aggregate_functions(),
             )),
         ];
         let mut plan = plan.clone();
@@ -495,6 +518,8 @@ pub struct ExecutionContextState {
     pub datasources: Arc<Mutex<HashMap<String, Box<dyn TableProvider + Send + Sync>>>>,
     /// Scalar functions that are registered with the context
     pub scalar_functions: Arc<Mutex<HashMap<String, Box<ScalarFunction>>>>,
+    /// Aggregate functions that are registered with the context
+    pub aggregate_functions: Arc<Mutex<HashMap<String, Box<AggregateFunction>>>>,
     /// Context configuration
     pub config: ExecutionConfig,
 }
@@ -509,19 +534,58 @@ impl SchemaProvider for ExecutionContextState {
     }
 
     fn get_function_meta(&self, name: &str) -> Option<Arc<FunctionMeta>> {
-        self.scalar_functions
+        let scalar = self
+            .scalar_functions
             .lock()
             .expect("failed to lock mutex")
             .get(name)
             .map(|f| {
                 Arc::new(FunctionMeta::new(
                     name.to_owned(),
-                    f.args.clone(),
+                    f.arg_types.clone(),
                     f.return_type.clone(),
                     FunctionType::Scalar,
                 ))
+            });
+        // give priority to scalar functions
+        if scalar.is_some() {
+            return scalar;
+        }
+
+        self.aggregate_functions
+            .lock()
+            .expect("failed to lock mutex")
+            .get(name)
+            .map(|f| {
+                Arc::new(FunctionMeta::new(
+                    name.to_owned(),
+                    f.arg_types.clone(),
+                    // this is wrong, but the actual type is overwritten by the physical plan
+                    // as aggregate functions have a variable type.
+                    DataType::Float32,
+                    FunctionType::Aggregate,
+                ))
             })
     }
+
+    fn functions(&self) -> Vec<String> {
+        let mut scalars: Vec<String> = self
+            .scalar_functions
+            .lock()
+            .expect("failed to lock mutex")
+            .keys()
+            .cloned()
+            .collect();
+        let mut aggregates: Vec<String> = self
+            .aggregate_functions
+            .lock()
+            .expect("failed to lock mutex")
+            .keys()
+            .cloned()
+            .collect();
+        aggregates.append(&mut scalars);
+        aggregates
+    }
 }
 
 #[cfg(test)]
@@ -529,13 +593,17 @@ mod tests {
 
     use super::*;
     use crate::datasource::MemTable;
-    use crate::execution::physical_plan::udf::ScalarUdf;
-    use crate::logicalplan::{aggregate_expr, col, scalar_function};
+    use crate::execution::physical_plan::{
+        expressions::Column,
+        udf::{AggregateFunctionExpr, ScalarUdf},
+        Accumulator, AggregateExpr, Aggregator,
+    };
+    use crate::logicalplan::{aggregate_expr, col, scalar_function, ScalarValue};
     use crate::test;
-    use arrow::array::{ArrayRef, Int32Array};
-    use arrow::compute::add;
+    use arrow::array::{ArrayRef, Float64Array, Int32Array};
+    use arrow::compute::{add, sum};
     use std::fs::File;
-    use std::io::prelude::*;
+    use std::{cell::RefCell, io::prelude::*, rc::Rc};
     use tempdir::TempDir;
     use test::*;
@@ -737,7 +805,7 @@ mod tests {
 
         let batch = &results[0];
 
-        assert_eq!(field_names(batch), vec!["SUM(c1)", "SUM(c2)"]);
+        assert_eq!(field_names(batch), vec!["sum(c1)", "sum(c2)"]);
 
         let expected: Vec<&str> = vec!["60,220"];
         let mut rows = test::format_batch(&batch);
@@ -754,7 +822,7 @@ mod tests {
 
         let batch = &results[0];
 
-        assert_eq!(field_names(batch), vec!["AVG(c1)", "AVG(c2)"]);
+        assert_eq!(field_names(batch), vec!["avg(c1)", "avg(c2)"]);
 
         let expected: Vec<&str> = vec!["1.5,5.5"];
         let mut rows = test::format_batch(&batch);
@@ -771,7 +839,7 @@ mod tests {
 
         let batch = &results[0];
 
-        assert_eq!(field_names(batch), vec!["MAX(c1)", "MAX(c2)"]);
+        assert_eq!(field_names(batch), vec!["max(c1)", "max(c2)"]);
 
         let expected: Vec<&str> = vec!["3,10"];
         let mut rows = test::format_batch(&batch);
@@ -788,7 +856,7 @@ mod tests {
 
         let batch = &results[0];
 
-        assert_eq!(field_names(batch), vec!["MIN(c1)", "MIN(c2)"]);
+        assert_eq!(field_names(batch), vec!["min(c1)", "min(c2)"]);
 
         let expected: Vec<&str> = vec!["0,1"];
         let mut rows = test::format_batch(&batch);
@@ -805,7 +873,7 @@ mod tests {
 
         let batch = &results[0];
 
-        assert_eq!(field_names(batch), vec!["c1", "SUM(c2)"]);
+        assert_eq!(field_names(batch), vec!["c1", "sum(c2)"]);
 
         let expected: Vec<&str> = vec!["0,55", "1,55", "2,55", "3,55"];
         let mut rows = test::format_batch(&batch);
@@ -822,7 +890,7 @@ mod tests {
 
         let batch = &results[0];
 
-        assert_eq!(field_names(batch), vec!["c1", "AVG(c2)"]);
+        assert_eq!(field_names(batch), vec!["c1", "avg(c2)"]);
 
         let expected: Vec<&str> = vec!["0,5.5", "1,5.5", "2,5.5", "3,5.5"];
         let mut rows = test::format_batch(&batch);
@@ -839,7 +907,7 @@ mod tests {
 
         let batch = &results[0];
 
-        assert_eq!(field_names(batch), vec!["c1", "MAX(c2)"]);
+        assert_eq!(field_names(batch), vec!["c1", "max(c2)"]);
 
         let expected: Vec<&str> = vec!["0,10", "1,10", "2,10", "3,10"];
         let mut rows = test::format_batch(&batch);
@@ -856,7 +924,7 @@ mod tests {
 
         let batch = &results[0];
 
-        assert_eq!(field_names(batch), vec!["c1", "MIN(c2)"]);
+        assert_eq!(field_names(batch), vec!["c1", "min(c2)"]);
 
         let expected: Vec<&str> = vec!["0,1", "1,1", "2,1", "3,1"];
         let mut rows = test::format_batch(&batch);
@@ -873,7 +941,7 @@ mod tests {
 
         let batch = &results[0];
 
-        assert_eq!(field_names(batch), vec!["COUNT(c1)", "COUNT(c2)"]);
+        assert_eq!(field_names(batch), vec!["count(c1)", "count(c2)"]);
 
         let expected: Vec<&str> = vec!["10,10"];
         let mut rows = test::format_batch(&batch);
@@ -889,7 +957,7 @@ mod tests {
 
         let batch = &results[0];
 
-        assert_eq!(field_names(batch), vec!["COUNT(c1)", "COUNT(c2)"]);
+        assert_eq!(field_names(batch), vec!["count(c1)", "count(c2)"]);
 
         let expected: Vec<&str> = vec!["40,40"];
         let mut rows = test::format_batch(&batch);
@@ -905,7 +973,7 @@ mod tests {
 
         let batch = &results[0];
 
-        assert_eq!(field_names(batch), vec!["c1", "COUNT(c2)"]);
+        assert_eq!(field_names(batch), vec!["c1", "count(c2)"]);
 
         let expected = vec!["0,10", "1,10", "2,10", "3,10"];
         let mut rows = test::format_batch(&batch);
@@ -927,9 +995,9 @@ mod tests {
         let plan = LogicalPlanBuilder::scan("default", "test", schema.as_ref(), None)?
             .aggregate(
                 vec![col("c1")],
-                vec![aggregate_expr("SUM", col("c2"), DataType::UInt32)],
+                vec![aggregate_expr("sum", col("c2"), DataType::UInt32)],
             )?
-            .project(vec![col("c1"), col("SUM(c2)").alias("total_salary")])?
+            .project(vec![col("c1"), col("sum(c2)").alias("total_salary")])?
            .build()?;
 
         let plan = ctx.optimize(&plan)?;
@@ -1049,10 +1117,7 @@ mod tests {
 
         let my_add = ScalarFunction::new(
             "my_add",
-            vec![
-                Field::new("a", DataType::Int32, true),
-                Field::new("b", DataType::Int32, true),
-            ],
+            vec![vec![DataType::Int32, DataType::Int32]],
             DataType::Int32,
             myfunc,
         );
@@ -1205,6 +1270,107 @@ mod tests {
             CsvReadOptions::new().schema(&schema),
         )?;
 
+        ctx.register_aggregate_udf(avg());
+
         Ok(ctx)
     }
+
+    // declare an accumulator of an average of f64
+    // math details here: https://stackoverflow.com/a/23493727/931303
+    #[derive(Debug)]
+    struct MyAvg {
+        avg: f64, // online average
+        n: usize,
+    }
+
+    impl Accumulator for MyAvg {
+        fn accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()> {
+            if let Some(value) = value {
+                match value {
+                    ScalarValue::Float64(v) => {
+                        self.n += 1;
+                        self.avg = (self.avg * ((self.n - 1) as f64) + v) / self.n as f64;
+                    }
+                    _ => {
+                        return Err(ExecutionError::ExecutionError(format!(
+                            "Unsupported type {:?}.",
+                            value
+                        )))
+                    }
+                }
+            }
+            Ok(())
+        }
+
+        fn accumulate_batch(&mut self, array: &ArrayRef) -> Result<()> {
+            match array.data_type() {
+                DataType::Float64 => {
+                    let array = array
+                        .as_any()
+                        .downcast_ref::<Float64Array>()
+                        .expect("Failed to cast array");
+                    let sum = sum(array).unwrap_or(0.0);
+                    let m = array.len();
+
+                    self.n += m;
+                    self.avg = (self.avg * (self.n - m) as f64 + sum) / self.n as f64;
+                }
+                _ => {
+                    return Err(ExecutionError::ExecutionError(format!(
+                        "Unsupported type {:?}.",
+                        array.data_type()
+                    )))
+                }
+            }
+            Ok(())
+        }
+
+        fn get_value(&self) -> Result<Option<ScalarValue>> {
+            Ok(Some(ScalarValue::Float64(self.avg)))
+        }
+    }
+
+    fn avg() -> AggregateFunction {
+        AggregateFunction {
+            name: "my_avg".to_string(),
+            return_type: Arc::new(|_, _| Ok(DataType::Float64)),
+            arg_types: vec![vec![DataType::Float64]],
+            aggregate: Arc::new(MyAvg { avg: 0.0, n: 0 }),
+        }
+    }
+
+    // implement it on the same struct just for fun
+    impl Aggregator for MyAvg {
+        fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
+            Rc::new(RefCell::new(MyAvg { avg: 0.0, n: 0 }))
+        }
+
+        fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
+            Arc::new(AggregateFunctionExpr::new(
+                "my_avg",
+                vec![Arc::new(Column::new(column_name))],
+                Box::new(avg()),
+            ))
+        }
+    }
+
+    #[test]
+    fn aggregate_custom_agg() -> Result<()> {
+        let results = execute("SELECT c1, my_avg(c2) FROM test GROUP BY c1", 4)?;
+        assert_eq!(results.len(), 1);
+
+        let batch = &results[0];
+
+        assert_eq!(
+            field_names(batch),
+            vec!["c1", "my_avg(CAST(c2 as Float64))"]
+        );
+
+        let expected: Vec<&str> = vec!["0,5.5", "1,5.5", "2,5.5", "3,5.5"];
+        let mut rows = test::format_batch(&batch);
+        rows.sort();
+        assert_eq!(rows, expected);
+
+        Ok(())
+    }
 }
diff --git a/rust/datafusion/src/execution/dataframe_impl.rs b/rust/datafusion/src/execution/dataframe_impl.rs
index 491cb70a3a4..89994d9feff 100644
--- a/rust/datafusion/src/execution/dataframe_impl.rs
+++ b/rust/datafusion/src/execution/dataframe_impl.rs
@@ -94,27 +94,27 @@ impl DataFrame for DataFrameImpl {
 
     /// Create an expression to represent the min() aggregate function
     fn min(&self, expr: Expr) -> Result<Expr> {
-        self.aggregate_expr("MIN", expr)
+        self.aggregate_expr("min", expr)
     }
 
     /// Create an expression to represent the max() aggregate function
     fn max(&self, expr: Expr) -> Result<Expr> {
-        self.aggregate_expr("MAX", expr)
+        self.aggregate_expr("max", expr)
     }
 
     /// Create an expression to represent the sum() aggregate function
     fn sum(&self, expr: Expr) -> Result<Expr> {
-        self.aggregate_expr("SUM", expr)
+        self.aggregate_expr("sum", expr)
     }
 
     /// Create an expression to represent the avg() aggregate function
     fn avg(&self, expr: Expr) -> Result<Expr> {
-        self.aggregate_expr("AVG", expr)
+        self.aggregate_expr("avg", expr)
     }
 
     /// Create an expression to represent the count() aggregate function
     fn count(&self, expr: Expr) -> Result<Expr> {
-        self.aggregate_expr("COUNT", expr)
+        self.aggregate_expr("count", expr)
     }
 
     /// Convert to logical plan
@@ -218,7 +218,7 @@ mod tests {
         let plan = t2.to_logical_plan();
 
         // build same plan using SQL API
-        let sql = "SELECT c1, MIN(c12), MAX(c12), AVG(c12), SUM(c12), COUNT(c12) \
+        let sql = "SELECT c1, min(c12), max(c12), avg(c12), sum(c12), count(c12) \
                    FROM aggregate_test_100 \
                    GROUP BY c1";
         let sql_plan = create_plan(sql)?;
diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs
index dff913b9162..bc770883fd4 100644
--- a/rust/datafusion/src/execution/physical_plan/expressions.rs
+++ b/rust/datafusion/src/execution/physical_plan/expressions.rs
@@ -24,7 +24,10 @@ use std::sync::Arc;
 
 use crate::error::{ExecutionError, Result};
 use crate::execution::physical_plan::common::get_scalar_value;
-use crate::execution::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr};
+use crate::execution::physical_plan::udf;
+use crate::execution::physical_plan::{
+    Accumulator, AggregateExpr, Aggregator, PhysicalExpr,
+};
 use crate::logicalplan::{Operator, ScalarValue};
 use arrow::array::{
     ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
@@ -47,6 +50,7 @@ use arrow::compute::kernels::comparison::{
 use arrow::compute::kernels::sort::{SortColumn, SortOptions};
 use arrow::datatypes::{DataType, Schema, TimeUnit};
 use arrow::record_batch::RecordBatch;
+use udf::AggregateFunction;
 
 /// Represents the column at a given index in a RecordBatch
 #[derive(Debug)]
@@ -94,47 +98,54 @@ pub fn col(name: &str) -> Arc<dyn PhysicalExpr> {
     Arc::new(Column::new(name))
 }
 
+/// aggregate functions declared in this module
+pub fn aggregate_functions() -> Vec<AggregateFunction> {
+    vec![sum(), avg(), max(), min(), count()]
+}
+
 /// SUM aggregate expression
 #[derive(Debug)]
-pub struct Sum {
-    expr: Arc<dyn PhysicalExpr>,
-}
+pub struct Sum {}
 
-impl Sum {
-    /// Create a new SUM aggregate function
-    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
-        Self { expr }
-    }
-}
-
-impl AggregateExpr for Sum {
-    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
-        match self.expr.data_type(input_schema)? {
-            DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
-                Ok(DataType::Int64)
-            }
-            DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
-                Ok(DataType::UInt64)
-            }
-            DataType::Float32 => Ok(DataType::Float32),
-            DataType::Float64 => Ok(DataType::Float64),
-            other => Err(ExecutionError::General(format!(
-                "SUM does not support {:?}",
-                other
-            ))),
-        }
+impl Aggregator for Sum {
+    fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
+        Rc::new(RefCell::new(SumAccumulator { sum: None }))
     }
 
-    fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
-        self.expr.evaluate(batch)
+    fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
+        physical_sum(Arc::new(Column::new(column_name)))
     }
+}
 
-    fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
-        Rc::new(RefCell::new(SumAccumulator { sum: None }))
+fn sum_return_type(
+    expr: &Vec<Arc<dyn PhysicalExpr>>,
+    schema: &Schema,
+) -> Result<DataType> {
+    match expr[0].data_type(schema)? {
+        DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {
+            Ok(DataType::Int64)
+        }
+        DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => {
+            Ok(DataType::UInt64)
+        }
+        DataType::Float32 => Ok(DataType::Float32),
+        DataType::Float64 => Ok(DataType::Float64),
+        other => Err(ExecutionError::General(format!(
+            "SUM does not support {:?}",
+            other
+        ))),
     }
+}
 
-    fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
-        Arc::new(Sum::new(Arc::new(Column::new(column_name))))
+/// Creates a sum aggregate function
+pub fn sum() -> AggregateFunction {
+    AggregateFunction {
+        name: "sum".to_string(),
+        return_type: Arc::new(sum_return_type),
+        arg_types: common_types()
+            .iter()
+            .map(|x| vec![x.clone()])
+            .collect::<Vec<_>>(),
+        aggregate: Arc::new(Sum {}),
     }
 }
@@ -156,7 +167,7 @@ macro_rules! sum_accumulate {
 
 #[derive(Debug)]
 struct SumAccumulator {
-    sum: Option<ScalarValue>,
+    pub sum: Option<ScalarValue>,
 }
 
 impl Accumulator for SumAccumulator {
@@ -283,48 +294,20 @@ impl Accumulator for SumAccumulator {
     }
 }
 
-/// Create a sum expression
-pub fn sum(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
-    Arc::new(Sum::new(expr))
+/// Create a physical aggregate sum expression
+pub fn physical_sum(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
+    Arc::new(udf::AggregateFunctionExpr::new(
+        "SUM",
+        vec![expr],
+        Box::new(sum()),
+    ))
 }
 
-/// AVG aggregate expression
+/// Average aggregate expression.
 #[derive(Debug)]
-pub struct Avg {
-    expr: Arc<dyn PhysicalExpr>,
-}
-
-impl Avg {
-    /// Create a new AVG aggregate function
-    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
-        Self { expr }
-    }
-}
-
-impl AggregateExpr for Avg {
-    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
-        match self.expr.data_type(input_schema)? {
-            DataType::Int8
-            | DataType::Int16
-            | DataType::Int32
-            | DataType::Int64
-            | DataType::UInt8
-            | DataType::UInt16
-            | DataType::UInt32
-            | DataType::UInt64
-            | DataType::Float32
-            | DataType::Float64 => Ok(DataType::Float64),
-            other => Err(ExecutionError::General(format!(
-                "AVG does not support {:?}",
-                other
-            ))),
-        }
-    }
-
-    fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
-        self.expr.evaluate(batch)
-    }
+pub struct Avg {}
 
+impl Aggregator for Avg {
     fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
         Rc::new(RefCell::new(AvgAccumulator {
             sum: None,
@@ -333,7 +316,51 @@ impl AggregateExpr for Avg {
     }
 
     fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
-        Arc::new(Avg::new(Arc::new(Column::new(column_name))))
+        physical_avg(Arc::new(Column::new(column_name)))
+    }
+}
+
+fn avg_return_type(
+    expr: &Vec<Arc<dyn PhysicalExpr>>,
+    schema: &Schema,
+) -> Result<DataType> {
+    match expr[0].data_type(schema)? {
+        DataType::Int8
+        | DataType::Int16
+        | DataType::Int32
+        | DataType::Int64
+        | DataType::UInt8
+        | DataType::UInt16
+        | DataType::UInt32
+        | DataType::UInt64
+        | DataType::Float32
+        | DataType::Float64 => Ok(DataType::Float64),
+        other => Err(ExecutionError::General(format!(
+            "AVG does not support {:?}",
+            other
+        ))),
+    }
+}
+
+/// Create a physical aggregate avg expression
+pub fn physical_avg(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
+    Arc::new(udf::AggregateFunctionExpr::new(
+        "AVG",
+        vec![expr],
+        Box::new(avg()),
+    ))
+}
+
+/// Creates an avg aggregate function
+pub fn avg() -> AggregateFunction {
+    AggregateFunction {
+        name: "avg".to_string(),
+        return_type: Arc::new(avg_return_type),
+        arg_types: common_types()
+            .iter()
+            .map(|x| vec![x.clone()])
+            .collect::<Vec<_>>(),
+        aggregate: Arc::new(Avg {}),
     }
 }
@@ -399,40 +426,63 @@ impl Accumulator for AvgAccumulator {
     }
 }
 
-/// Create a avg expression
-pub fn avg(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
-    Arc::new(Avg::new(expr))
-}
-
 /// MAX aggregate expression
 #[derive(Debug)]
-pub struct Max {
-    expr: Arc<dyn PhysicalExpr>,
-}
+pub struct Max {}
+
+impl Aggregator for Max {
+    fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
+        Rc::new(RefCell::new(MaxAccumulator { max: None }))
+    }
 
-impl Max {
-    /// Create a new MAX aggregate function
-    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
-        Self { expr }
+    fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
+        physical_max(Arc::new(Column::new(column_name)))
     }
 }
 
-impl AggregateExpr for Max {
-    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
-        self.expr.data_type(input_schema)
-    }
+/// Create a physical aggregate max expression
+pub fn physical_max(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
+    Arc::new(udf::AggregateFunctionExpr::new(
+        "MAX",
+        vec![expr],
+        Box::new(max()),
+    ))
+}
 
-    fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
-        self.expr.evaluate(batch)
-    }
+fn common_types() -> Vec<DataType> {
+    // this order dictates the order in which we try to cast.
+    vec![
+        DataType::UInt8,
+        DataType::UInt16,
+        DataType::UInt32,
+        DataType::UInt64,
+        DataType::Int8,
+        DataType::Int16,
+        DataType::Int32,
+        DataType::Int64,
+        DataType::Float32,
+        DataType::Float64,
+    ]
+}
 
-    fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
-        Rc::new(RefCell::new(MaxAccumulator { max: None }))
+/// Creates a max aggregate function
+pub fn max() -> AggregateFunction {
+    AggregateFunction {
+        name: "max".to_string(),
+        return_type: Arc::new(max_return_type),
+        arg_types: common_types()
+            .iter()
+            .map(|x| vec![x.clone()])
+            .collect::<Vec<_>>(),
+        aggregate: Arc::new(Max {}),
     }
+}
 
-    fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
-        Arc::new(Max::new(Arc::new(Column::new(column_name))))
-    }
+fn max_return_type(
+    expr: &Vec<Arc<dyn PhysicalExpr>>,
+    schema: &Schema,
+) -> Result<DataType> {
+    expr[0].data_type(schema)
 }
 
macro_rules! max_accumulate {
@@ -583,39 +633,39 @@ impl Accumulator for MaxAccumulator {
     }
 }
 
-/// Create a max expression
-pub fn max(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
-    Arc::new(Max::new(expr))
-}
-
 /// MIN aggregate expression
 #[derive(Debug)]
-pub struct Min {
-    expr: Arc<dyn PhysicalExpr>,
-}
+pub struct Min {}
 
-impl Min {
-    /// Create a new MIN aggregate function
-    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
-        Self { expr }
-    }
+/// Create a physical aggregate min expression
+pub fn physical_min(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
+    Arc::new(udf::AggregateFunctionExpr::new(
+        "MIN",
+        vec![expr],
+        Box::new(min()),
+    ))
 }
 
-impl AggregateExpr for Min {
-    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
-        self.expr.data_type(input_schema)
-    }
-
-    fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
-        self.expr.evaluate(batch)
+/// Creates a min aggregate function
+pub fn min() -> AggregateFunction {
+    AggregateFunction {
+        name: "min".to_string(),
+        return_type: Arc::new(max_return_type),
+        arg_types: common_types()
+            .iter()
+            .map(|x| vec![x.clone()])
+            .collect::<Vec<_>>(),
+        aggregate: Arc::new(Min {}),
     }
+}
 
+impl Aggregator for Min {
     fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
         Rc::new(RefCell::new(MinAccumulator { min: None }))
     }
 
     fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
-        Arc::new(Min::new(Arc::new(Column::new(column_name))))
+        physical_min(Arc::new(Column::new(column_name)))
     }
 }
@@ -767,43 +817,53 @@ impl Accumulator for MinAccumulator {
     }
 }
 
-/// Create a min expression
-pub fn min(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
-    Arc::new(Min::new(expr))
-}
-
 /// COUNT aggregate expression
 /// Returns the number of non-null values of the given expression.
 #[derive(Debug)]
-pub struct Count {
-    expr: Arc<dyn PhysicalExpr>,
-}
+pub struct Count {}
 
-impl Count {
-    /// Create a new COUNT aggregate function.
-    pub fn new(expr: Arc<dyn PhysicalExpr>) -> Self {
-        Self { expr: expr }
+impl Aggregator for Count {
+    fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
+        Rc::new(RefCell::new(CountAccumulator { count: 0 }))
     }
-}
 
-impl AggregateExpr for Count {
-    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
-        Ok(DataType::UInt64)
+    fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
+        physical_sum(Arc::new(Column::new(column_name)))
     }
+}
 
-    fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef> {
-        self.expr.evaluate(batch)
-    }
+/// Create a physical aggregate count expression
+pub fn physical_count(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
+    Arc::new(udf::AggregateFunctionExpr::new(
+        "COUNT",
+        vec![expr],
+        Box::new(count()),
+    ))
+}
 
-    fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>> {
-        Rc::new(RefCell::new(CountAccumulator { count: 0 }))
-    }
+/// Creates a count aggregate function
+pub fn count() -> AggregateFunction {
+    let mut types = common_types()
+        .iter()
+        .map(|x| vec![x.clone()])
+        .collect::<Vec<_>>();
+    types.push(vec![DataType::Utf8]);
 
-    fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr> {
-        Arc::new(Sum::new(Arc::new(Column::new(column_name))))
+    AggregateFunction {
+        name: "count".to_string(),
+        return_type: Arc::new(count_return_type),
+        arg_types: types,
+        aggregate: Arc::new(Count {}),
     }
 }
 
+fn count_return_type(
+    _expr: &Vec<Arc<dyn PhysicalExpr>>,
+    _schema: &Schema,
+) -> Result<DataType> {
+    Ok(DataType::UInt64)
+}
+
 #[derive(Debug)]
 struct CountAccumulator {
     count: u64,
@@ -827,11 +887,6 @@ impl Accumulator for CountAccumulator {
     }
 }
 
-/// Create a count expression
-pub fn count(expr: Arc<dyn PhysicalExpr>) -> Arc<dyn AggregateExpr> {
-    Arc::new(Count::new(expr))
-}
-
 /// Invoke a compute kernel on a pair of binary data arrays
macro_rules! compute_utf8_op {
    ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{
@@ -1431,7 +1486,7 @@ mod tests {
     fn sum_contract() -> Result<()> {
         let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
 
-        let sum = sum(col("a"));
+        let sum = physical_sum(col("a"));
         assert_eq!(DataType::Int64, sum.data_type(&schema)?);
 
         // after the aggr expression is applied, the schema changes to:
@@ -1450,7 +1505,7 @@ mod tests {
     fn max_contract() -> Result<()> {
         let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
 
-        let max = max(col("a"));
+        let max = physical_max(col("a"));
         assert_eq!(DataType::Int32, max.data_type(&schema)?);
 
         // after the aggr expression is applied, the schema changes to:
@@ -1469,7 +1524,7 @@ mod tests {
     fn min_contract() -> Result<()> {
         let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
 
-        let min = min(col("a"));
+        let min = physical_min(col("a"));
         assert_eq!(DataType::Int32, min.data_type(&schema)?);
 
         // after the aggr expression is applied, the schema changes to:
@@ -1486,16 +1541,16 @@ mod tests {
     fn avg_contract() -> Result<()> {
         let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
 
-        let avg = avg(col("a"));
+        let avg = physical_avg(col("a"));
         assert_eq!(DataType::Float64, avg.data_type(&schema)?);
 
         // after the aggr expression is applied, the schema changes to:
         let schema = Schema::new(vec![
             schema.field(0).clone(),
-            Field::new("SUM(a)", avg.data_type(&schema)?, false),
+            Field::new("AVG(a)", avg.data_type(&schema)?, false),
         ]);
 
-        let combiner = avg.create_reducer("SUM(a)");
+        let combiner = avg.create_reducer("AVG(a)");
         assert_eq!(DataType::Float64, combiner.data_type(&schema)?);
 
         Ok(())
@@ -1826,9 +1881,9 @@ mod tests {
     }
 
     fn do_sum(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
-        let sum = sum(col("a"));
+        let sum = physical_sum(col("a"));
         let accum = sum.create_accumulator();
-        let input = sum.evaluate_input(batch)?;
+        let input = sum.evaluate(batch)?;
         let mut accum = accum.borrow_mut();
         for i in 0..batch.num_rows() {
             accum.accumulate_scalar(get_scalar_value(&input, i)?)?;
@@ -1837,9 +1892,9 @@ mod tests {
     }
 
     fn do_max(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
-        let max = max(col("a"));
+        let max = physical_max(col("a"));
         let accum = max.create_accumulator();
-        let input = max.evaluate_input(batch)?;
+        let input = max.evaluate(batch)?;
         let mut accum = accum.borrow_mut();
         for i in 0..batch.num_rows() {
             accum.accumulate_scalar(get_scalar_value(&input, i)?)?;
@@ -1848,9 +1903,9 @@ mod tests {
     }
 
     fn do_min(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
-        let min = min(col("a"));
+        let min = physical_min(col("a"));
        let accum = min.create_accumulator();
-        let input = min.evaluate_input(batch)?;
+        let input = min.evaluate(batch)?;
         let mut accum = accum.borrow_mut();
         for i in 0..batch.num_rows() {
             accum.accumulate_scalar(get_scalar_value(&input, i)?)?;
@@ -1859,9 +1914,9 @@ mod tests {
     }
 
     fn do_count(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
-        let count = count(col("a"));
+        let count = physical_count(col("a"));
         let accum = count.create_accumulator();
-        let input = count.evaluate_input(batch)?;
+        let input = count.evaluate(batch)?;
         let mut accum = accum.borrow_mut();
         for i in 0..batch.num_rows() {
             accum.accumulate_scalar(get_scalar_value(&input, i)?)?;
@@ -1870,9 +1925,9 @@ mod tests {
     }
 
     fn do_avg(batch: &RecordBatch) -> Result<Option<ScalarValue>> {
-        let avg = avg(col("a"));
+        let avg = physical_avg(col("a"));
         let accum = avg.create_accumulator();
-        let input = avg.evaluate_input(batch)?;
+        let input = avg.evaluate(batch)?;
         let mut accum = accum.borrow_mut();
         for i in 0..batch.num_rows() {
             accum.accumulate_scalar(get_scalar_value(&input, i)?)?;
diff --git a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
index d7366395ca4..54787b5e16c 100644
--- a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
+++ b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs
@@ -268,7 +268,7 @@ impl RecordBatchReader for GroupedHashAggregateIterator {
             .aggr_expr
             .iter()
             .map(|expr| {
-                expr.evaluate_input(&batch)
+                expr.evaluate(&batch)
                     .map_err(ExecutionError::into_arrow_external_error)
             })
             .collect::<ArrowResult<Vec<_>>>()?;
@@ -433,7 +433,7 @@ impl RecordBatchReader for HashAggregateIterator {
             .aggr_expr
             .iter()
             .map(|expr| {
-                expr.evaluate_input(&batch)
+                expr.evaluate(&batch)
                     .map_err(ExecutionError::into_arrow_external_error)
             })
             .collect::<ArrowResult<Vec<_>>>()?;
@@ -698,7 +698,7 @@ mod tests {
 
     use super::*;
     use crate::execution::physical_plan::csv::{CsvExec, CsvReadOptions};
-    use crate::execution::physical_plan::expressions::{col, sum};
+    use crate::execution::physical_plan::expressions::{col, physical_sum};
     use crate::execution::physical_plan::merge::MergeExec;
     use crate::test;
@@ -716,7 +716,7 @@
             vec![(col("c2"), "c2".to_string())];
 
         let aggregates: Vec<(Arc<dyn AggregateExpr>, String)> =
-            vec![(sum(col("c4")), "SUM(c4)".to_string())];
+            vec![(physical_sum(col("c4")), "SUM(c4)".to_string())];
 
         let partition_aggregate = HashAggregateExec::try_new(
             groups.clone(),
diff --git a/rust/datafusion/src/execution/physical_plan/math_expressions.rs b/rust/datafusion/src/execution/physical_plan/math_expressions.rs
index 97098d65b5e..1c97c2139d3 100644
--- a/rust/datafusion/src/execution/physical_plan/math_expressions.rs
+++ b/rust/datafusion/src/execution/physical_plan/math_expressions.rs
@@ -20,36 +20,62 @@
 use crate::error::ExecutionError;
 use crate::execution::physical_plan::udf::ScalarFunction;
 
-use arrow::array::{Array, ArrayRef, Float64Array, Float64Builder};
-use arrow::datatypes::{DataType, Field};
+use arrow::array::{Array, ArrayRef};
+use arrow::array::{Float32Array, Float64Array};
+use arrow::datatypes::DataType;
 
 use std::sync::Arc;
 
+macro_rules! compute_op {
+    ($ARRAY:expr, $FUNC:ident, $TYPE:ident) => {{
+        let mut builder = <$TYPE>::builder($ARRAY.len());
+        for i in 0..$ARRAY.len() {
+            if $ARRAY.is_null(i) {
+                builder.append_null()?;
+            } else {
+                builder.append_value($ARRAY.value(i).$FUNC())?;
+            }
+        }
+        Ok(Arc::new(builder.finish()))
+    }};
+}
+
+macro_rules! downcast_compute_op {
+    ($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident) => {{
+        let n = $ARRAY.as_any().downcast_ref::<$TYPE>();
+        match n {
+            Some(array) => compute_op!(array, $FUNC, $TYPE),
+            _ => Err(ExecutionError::General(format!(
+                "Invalid data type for {}",
+                $NAME
+            ))),
+        }
+    }};
+}
+
+macro_rules! unary_primitive_array_op {
+    ($ARRAY:expr, $NAME:expr, $FUNC:ident) => {{
+        match ($ARRAY).data_type() {
+            DataType::Float32 => downcast_compute_op!($ARRAY, $NAME, $FUNC, Float32Array),
+            DataType::Float64 => downcast_compute_op!($ARRAY, $NAME, $FUNC, Float64Array),
+            other => Err(ExecutionError::General(format!(
+                "Unsupported data type {:?} for function {}",
+                other, $NAME,
+            ))),
+        }
+    }};
+}
+
macro_rules! math_unary_function {
    ($NAME:expr, $FUNC:ident) => {
        ScalarFunction::new(
            $NAME,
-            vec![Field::new("n", DataType::Float64, true)],
+            // order: from faster to slower
+            vec![vec![DataType::Float32], vec![DataType::Float64]],
            DataType::Float64,
            Arc::new(|args: &[ArrayRef]| {
-                let n = &args[0].as_any().downcast_ref::<Float64Array>();
-                match n {
-                    Some(array) => {
-                        let mut builder = Float64Builder::new(array.len());
-                        for i in 0..array.len() {
-                            if array.is_null(i) {
-                                builder.append_null()?;
-                            } else {
-                                builder.append_value(array.value(i).$FUNC())?;
-                            }
-                        }
-                        Ok(Arc::new(builder.finish()))
-                    }
-                    _ => Err(ExecutionError::General(format!(
-                        "Invalid data type for {}",
-                        $NAME
-                    ))),
-                }
+                let array = &args[0];
+                unary_primitive_array_op!(array, $NAME, $FUNC)
            }),
        )
    };
@@ -86,7 +112,7 @@ mod tests {
         execution::context::ExecutionContext,
         logicalplan::{col, sqrt, LogicalPlanBuilder},
     };
-    use arrow::datatypes::Schema;
+    use arrow::datatypes::{Field, Schema};
 
     #[test]
     fn cast_i8_input() -> Result<()> {
@@ -96,7 +122,7 @@
             .build()?;
         let ctx = ExecutionContext::new();
         let plan = ctx.optimize(&plan)?;
-        let expected = "Projection: sqrt(CAST(#c0 AS Float64))\
+        let expected = "Projection: sqrt(CAST(#c0 AS Float32))\
         \n  TableScan: projection=Some([0])";
         assert_eq!(format!("{:?}", plan), expected);
         Ok(())
     }
@@ -115,4 +141,18 @@
         assert_eq!(format!("{:?}", plan), expected);
         Ok(())
     }
+
+    #[test]
+    fn no_cast_f32_input() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("c0", DataType::Float32, true)]);
+        let plan = LogicalPlanBuilder::scan("", "", &schema, None)?
+            .project(vec![sqrt(col("c0"))])?
+            .build()?;
+        let ctx = ExecutionContext::new();
+        let plan = ctx.optimize(&plan)?;
+        let expected = "Projection: sqrt(#c0)\
+        \n  TableScan: projection=Some([0])";
+        assert_eq!(format!("{:?}", plan), expected);
+        Ok(())
+    }
 }
diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs
index 4b66955d036..b16ba8163ea 100644
--- a/rust/datafusion/src/execution/physical_plan/mod.rs
+++ b/rust/datafusion/src/execution/physical_plan/mod.rs
@@ -26,7 +26,7 @@
 use crate::error::Result;
 use crate::execution::context::ExecutionContextState;
 use crate::logicalplan::{LogicalPlan, ScalarValue};
 use arrow::array::ArrayRef;
-use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
+use arrow::datatypes::{DataType, Schema, SchemaRef};
 use arrow::{
     compute::kernels::length::length,
     record_batch::{RecordBatch, RecordBatchReader},
@@ -68,12 +68,9 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug {
     fn evaluate(&self, batch: &RecordBatch) -> Result<ArrayRef>;
 }
 
-/// Aggregate expression that can be evaluated against a RecordBatch
-pub trait AggregateExpr: Send + Sync + Debug {
-    /// Get the data type of this expression, given the schema of the input
-    fn data_type(&self, input_schema: &Schema) -> Result<DataType>;
-    /// Evaluate the expression being aggregated
-    fn evaluate_input(&self, batch: &RecordBatch) -> Result<ArrayRef>;
+/// An aggregator knows how to accumulate arrays in parts, so that the array does not have
+/// to be all available in memory. This type of aggregation is also known as an online update.
+pub trait Aggregator: Send + Sync + Debug {
     /// Create an accumulator for this aggregate expression
     fn create_accumulator(&self) -> Rc<RefCell<dyn Accumulator>>;
     /// Create an aggregate expression for combining the results of accumulators from partitions.
@@ -82,7 +79,12 @@
     fn create_reducer(&self, column_name: &str) -> Arc<dyn AggregateExpr>;
 }
 
-/// Aggregate accumulator
+/// Aggregate expression that can be evaluated against a RecordBatch
+pub trait AggregateExpr: PhysicalExpr + Aggregator {}
+impl<T: PhysicalExpr + Aggregator> AggregateExpr for T {}
+
+/// An accumulator knows how to compute aggregations without full access to the complete array.
+/// This is also known as an online update.
 pub trait Accumulator: Debug {
     /// Update the accumulator based on a row in a batch
     fn accumulate_scalar(&mut self, value: Option<ScalarValue>) -> Result<()>;
@@ -96,7 +98,7 @@ pub fn scalar_functions() -> Vec<ScalarFunction> {
     let mut udfs = vec![ScalarFunction::new(
         "length",
-        vec![Field::new("n", DataType::Utf8, true)],
+        vec![vec![DataType::Utf8]],
         DataType::UInt32,
         Arc::new(|args: &[ArrayRef]| Ok(Arc::new(length(args[0].as_ref())?))),
     )];
@@ -104,6 +106,11 @@ pub fn scalar_functions() -> Vec<ScalarFunction> {
     udfs
 }
 
+/// Vector of aggregate functions declared in this module
+pub fn aggregate_functions() -> Vec<AggregateFunction> {
+    expressions::aggregate_functions()
+}
+
 pub mod common;
 pub mod csv;
 pub mod datasource;
diff --git a/rust/datafusion/src/execution/physical_plan/planner.rs b/rust/datafusion/src/execution/physical_plan/planner.rs
index 7ce845bf434..6475fccdfc0 100644
--- a/rust/datafusion/src/execution/physical_plan/planner.rs
+++ b/rust/datafusion/src/execution/physical_plan/planner.rs
@@ -19,13 +19,14 @@
 use std::sync::{Arc, Mutex};
 
+use super::udf::AggregateFunctionExpr;
 use crate::error::{ExecutionError, Result};
 use crate::execution::context::ExecutionContextState;
 use crate::execution::physical_plan::csv::{CsvExec, CsvReadOptions};
 use crate::execution::physical_plan::datasource::DatasourceExec;
 use crate::execution::physical_plan::explain::ExplainExec;
 use crate::execution::physical_plan::expressions::{
-    Avg, BinaryExpr, CastExpr, Column, Count, Literal, Max, Min, PhysicalSortExpr, Sum,
+    BinaryExpr, CastExpr, Column, Literal, PhysicalSortExpr,
 };
 use crate::execution::physical_plan::hash_aggregate::HashAggregateExec;
 use crate::execution::physical_plan::limit::GlobalLimitExec;
@@ -403,35 +404,32 @@ impl PhysicalPlannerImpl {
     ) -> Result<Arc<dyn AggregateExpr>> {
         match e {
             Expr::AggregateFunction { name, args, .. } => {
-                match name.to_lowercase().as_ref() {
-                    "sum" => Ok(Arc::new(Sum::new(self.create_physical_expr(
-                        &args[0],
-                        input_schema,
-                        ctx_state.clone(),
-                    )?))),
-                    "avg" => Ok(Arc::new(Avg::new(self.create_physical_expr(
-                        &args[0],
-                        input_schema,
-                        ctx_state.clone(),
-                    )?))),
-                    "max" => Ok(Arc::new(Max::new(self.create_physical_expr(
-                        &args[0],
-                        input_schema,
-                        ctx_state.clone(),
-                    )?))),
-                    "min" => Ok(Arc::new(Min::new(self.create_physical_expr(
-                        &args[0],
-                        input_schema,
-                        ctx_state.clone(),
-                    )?))),
-                    "count" => Ok(Arc::new(Count::new(self.create_physical_expr(
-                        &args[0],
-                        input_schema,
-                        ctx_state.clone(),
-                    )?))),
-                    other => Err(ExecutionError::NotImplemented(format!(
-                        "Unsupported aggregate function '{}'",
-                        other
+                match &ctx_state
+                    .lock()
+                    .expect("failed to lock mutex")
+                    .aggregate_functions
+                    .lock()
+                    .expect("failed to lock mutex")
+                    .get(name)
+                {
+                    Some(f) => {
+                        let mut physical_args = vec![];
+                        for e in args {
+                            physical_args.push(self.create_physical_expr(
+                                e,
+                                input_schema,
+                                ctx_state.clone(),
+                            )?);
+                        }
+                        Ok(Arc::new(AggregateFunctionExpr::new(
+                            name,
+                            physical_args,
+                            Box::new(f.as_ref().clone()),
+                        )))
+                    }
+                    _ => Err(ExecutionError::General(format!(
+                        "Invalid aggregate function '{:?}'",
+                        name
                     ))),
                 }
             }
diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs b/rust/datafusion/src/execution/physical_plan/udf.rs
index 944b5c9bef3..6fb68c92fd0 100644
--- a/rust/datafusion/src/execution/physical_plan/udf.rs
+++ b/rust/datafusion/src/execution/physical_plan/udf.rs
@@ -20,25 +20,33 @@
 use std::fmt;
 
 use arrow::array::ArrayRef;
-use arrow::datatypes::{DataType, Field, Schema};
+use arrow::datatypes::{DataType, Schema};
 
 use crate::error::Result;
 use crate::execution::physical_plan::PhysicalExpr;
 
+use super::{Accumulator, AggregateExpr, Aggregator};
 use arrow::record_batch::RecordBatch;
 use fmt::{Debug, Formatter};
-use std::sync::Arc;
+use std::{cell::RefCell, rc::Rc, sync::Arc};
 
 /// Scalar UDF
 pub type ScalarUdf = Arc<dyn Fn(&[ArrayRef]) -> Result<ArrayRef> + Send + Sync>;
 
+/// Function to construct the return type of a function given its arguments.
+pub type ReturnType =
+    Arc<dyn Fn(&Vec<Arc<dyn PhysicalExpr>>, &Schema) -> Result<DataType> + Send + Sync>;
+
 /// Scalar UDF Expression
 #[derive(Clone)]
 pub struct ScalarFunction {
     /// Function name
     pub name: String,
-    /// Function argument meta-data
-    pub args: Vec<Field>,
+    /// Set of valid argument types.
+    /// The first dimension (0) represents the valid combinations of argument types;
+    /// the second dimension (1) represents the type of each argument.
+    /// For example, [[t1, t2]] is a function of 2 arguments that only accepts t1 as the
+    /// first argument and t2 as the second.
    pub arg_types: Vec<Vec<DataType>>,
     /// Return type
     pub return_type: DataType,
     /// UDF implementation
@@ -49,7 +57,7 @@ impl Debug for ScalarFunction {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
         f.debug_struct("ScalarFunction")
             .field("name", &self.name)
-            .field("args", &self.args)
+            .field("arg_types", &self.arg_types)
             .field("return_type", &self.return_type)
             .field("fun", &"")
             .finish()
     }
 }
@@ -60,13 +68,13 @@ impl ScalarFunction {
     /// Create a new ScalarFunction
     pub fn new(
         name: &str,
-        args: Vec<Field>,
+        arg_types: Vec<Vec<DataType>>,
         return_type: DataType,
         fun: ScalarUdf,
     ) -> Self {
         Self {
             name: name.to_owned(),
-            args,
+            arg_types,
             return_type,
             fun,
         }
     }
@@ -146,3 +154,99 @@ impl PhysicalExpr for ScalarFunctionExpr {
         (fun)(&inputs)
     }
 }
+
+/// A generic aggregate function
+/*
+An aggregate function accepts an arbitrary number of arguments, of arbitrary data types,
+and returns an arbitrary type based on the incoming types.
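+
+(As an illustration only: a hypothetical `weighted_avg(value, weight)` over two Float64
+columns would declare `arg_types: vec![vec![DataType::Float64, DataType::Float64]]` and a
+`return_type` closure that always returns `Ok(DataType::Float64)`; see the `my_avg`
+function in context.rs's tests for a concrete, working declaration.)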
+ +It is the developer of the function's responsibility to ensure that the aggregator correctly handles the different +types that are presented to them, and that the return type correctly matches the type returned by the +aggregator. + +It is the user of the function's responsibility to pass arguments to the function that have valid types. +*/ +#[derive(Clone)] +pub struct AggregateFunction { + /// Function name + pub name: String, + /// A list of arguments and their respective types. A function can accept more than one type as argument + /// (e.g. sum(i8), sum(u8)). + pub arg_types: Vec>, + /// Return type. This function takes + pub return_type: ReturnType, + /// implementation of the aggregation + pub aggregate: Arc, +} + +/// An aggregate function physical expression +pub struct AggregateFunctionExpr { + name: String, + fun: Box, + // for now, our AggregateFunctionExpr accepts a single element only. + arg: Arc, +} + +impl AggregateFunctionExpr { + /// Create a new AggregateFunctionExpr + pub fn new( + name: &str, + args: Vec>, + fun: Box, + ) -> Self { + Self { + name: name.to_owned(), + arg: args[0].clone(), + fun, + } + } +} + +impl Debug for AggregateFunctionExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("AggregateFunctionExpr") + .field("fun", &"") + .field("name", &self.name) + .field("args", &self.arg) + .finish() + } +} + +impl PhysicalExpr for AggregateFunctionExpr { + fn data_type(&self, input_schema: &Schema) -> Result { + self.fun.as_ref().return_type.as_ref()(&vec![self.arg.clone()], input_schema) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(false) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + self.arg.evaluate(batch) + } +} + +impl Aggregator for AggregateFunctionExpr { + fn create_accumulator(&self) -> Rc> { + self.fun.aggregate.create_accumulator() + } + + fn create_reducer(&self, column_name: &str) -> Arc { + self.fun.aggregate.create_reducer(column_name) + } +} + +impl fmt::Display for AggregateFunctionExpr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}({})", + self.name, + [&self.arg] + .iter() + .map(|e| format!("{}", e)) + .collect::>() + .join(", ") + ) + } +} diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index ee3c17aedd2..f9f9d4e021c 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -47,8 +47,8 @@ pub enum FunctionType { pub struct FunctionMeta { /// Function name name: String, - /// Function arguments - args: Vec, + /// Function argument types + arg_types: Vec>, /// Function return type return_type: DataType, /// Function type (Scalar or Aggregate) @@ -59,13 +59,13 @@ impl FunctionMeta { #[allow(missing_docs)] pub fn new( name: String, - args: Vec, + arg_types: Vec>, return_type: DataType, function_type: FunctionType, ) -> Self { FunctionMeta { name, - args, + arg_types, return_type, function_type, } @@ -75,8 +75,8 @@ impl FunctionMeta { &self.name } /// Getter for the arg list - pub fn args(&self) -> &Vec { - &self.args + pub fn arg_types(&self) -> &Vec> { + &self.arg_types } /// Getter for the `DataType` the function returns pub fn return_type(&self) -> &DataType { @@ -603,7 +603,7 @@ pub fn aggregate_expr(name: &str, expr: Expr, return_type: DataType) -> Expr { } } -/// Create an aggregate expression +/// Create an scalar function expression pub fn scalar_function(name: &str, expr: Vec, return_type: DataType) -> Expr { Expr::ScalarFunction { name: name.to_owned(), @@ -612,6 
+612,15 @@ pub fn scalar_function(name: &str, expr: Vec, return_type: DataType) -> Ex } } +/// Create an aggregate expression +pub fn aggregate_function(name: &str, expr: Vec, return_type: DataType) -> Expr { + Expr::AggregateFunction { + name: name.to_owned(), + args: expr, + return_type, + } +} + impl fmt::Debug for Expr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index a485423975d..84fca1c2d48 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -23,10 +23,10 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; -use arrow::datatypes::Schema; +use arrow::datatypes::{DataType, Schema}; use crate::error::{ExecutionError, Result}; -use crate::execution::physical_plan::udf::ScalarFunction; +use crate::execution::physical_plan::udf::{AggregateFunction, ScalarFunction}; use crate::logicalplan::Expr; use crate::logicalplan::LogicalPlan; use crate::optimizer::optimizer::OptimizerRule; @@ -38,6 +38,7 @@ use utils::optimize_explain; /// This optimizer does not alter the structure of the plan, it only changes expressions on it. pub struct TypeCoercionRule { scalar_functions: Arc>>>, + aggregate_functions: Arc>>>, } impl TypeCoercionRule { @@ -45,8 +46,12 @@ impl TypeCoercionRule { /// scalar functions pub fn new( scalar_functions: Arc>>>, + aggregate_functions: Arc>>>, ) -> Self { - Self { scalar_functions } + Self { + scalar_functions, + aggregate_functions, + } } /// Rewrite an expression to include explicit CAST operations when required @@ -71,6 +76,50 @@ impl TypeCoercionRule { expressions[1] = expressions[1].cast_to(&super_type, schema)?; } } + Expr::AggregateFunction { name, .. } => { + match self + .aggregate_functions + .lock() + .expect("failed to lock mutex") + .get(name) + { + Some(func_meta) => { + // compute the current types and expressions + let current_types = expressions + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + + let new = if func_meta.arg_types.contains(¤t_types) { + Some(expressions) + } else { + maybe_rewrite( + &expressions, + ¤t_types, + &schema, + &func_meta.arg_types, + )? + }; + + if let Some(args) = new { + expressions = args; + } else { + return Err(ExecutionError::General(format!( + "The scalar function '{}' requires one of the type variants {:?}, but the arguments of type '{:?}' cannot be safely casted to any of them.", + func_meta.name, + func_meta.arg_types, + current_types, + ))); + } + } + None => { + return Err(ExecutionError::General(format!( + "Invalid aggregate function {}", + name + ))) + } + } + } Expr::ScalarFunction { name, .. } => { // cast the inputs of scalar functions to the appropriate type where possible match self @@ -80,17 +129,8 @@ impl TypeCoercionRule { .get(name) { Some(func_meta) => { - for i in 0..expressions.len() { - let field = &func_meta.args[i]; - let actual_type = expressions[i].get_type(schema)?; - let required_type = field.data_type(); - if &actual_type != required_type { - let super_type = - utils::get_supertype(&actual_type, required_type)?; - expressions[i] = - expressions[i].cast_to(&super_type, schema)? 
- }; - } + expressions = + rewrite_args(expressions, schema, func_meta.as_ref())?; } _ => { return Err(ExecutionError::General(format!( @@ -147,6 +187,78 @@ impl OptimizerRule for TypeCoercionRule { } } +/// rewrites +fn rewrite_args( + expressions: Vec, + schema: &Schema, + func_meta: &ScalarFunction, +) -> Result> { + // compute the current types and expressions + let current_types = expressions + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + + let new = if func_meta.arg_types.contains(¤t_types) { + Some(expressions) + } else { + maybe_rewrite(&expressions, ¤t_types, &schema, &func_meta.arg_types)? + }; + + if let Some(args) = new { + Ok(args) + } else { + Err(ExecutionError::General(format!( + "The scalar function '{}' requires one of the type variants {:?}, but the arguments of type '{:?}' cannot be safely casted to any of them.", + func_meta.name, + func_meta.arg_types, + current_types, + ))) + } +} + +/// tries to re-cast expressions under schema based on the set of valid signatures +fn maybe_rewrite( + expressions: &Vec, + current_types: &Vec, + schema: &Schema, + signature: &Vec>, +) -> Result>> { + // for each set of valid signatures, try to coerce all expressions to one of them + for valid_types in signature { + // for each option, try to coerce all arguments to it + if let Some(types) = maybe_coerce(valid_types, ¤t_types) { + // yes: let's re-write the expressions + return Ok(Some( + expressions + .iter() + .enumerate() + .map(|(i, expr)| expr.cast_to(&types[i], schema)) + .collect::>>()?, + )); + } + // we cannot: try the next + } + Ok(None) +} + +/// Try to coerce current_types into valid_types +fn maybe_coerce( + valid_types: &Vec, + current_types: &Vec, +) -> Option> { + let mut super_type = Vec::with_capacity(valid_types.len()); + for (i, valid_type) in valid_types.iter().enumerate() { + let current_type = ¤t_types[i]; + if let Ok(t) = utils::get_supertype(current_type, valid_type) { + super_type.push(t) + } else { + return None; + } + } + Some(super_type) +} + #[cfg(test)] mod tests { use super::*; @@ -168,14 +280,15 @@ mod tests { .project(vec![col("c1"), col("c2")])? .aggregate( vec![col("c1")], - vec![aggregate_expr("SUM", col("c2"), DataType::Int64)], + vec![aggregate_expr("sum", col("c2"), DataType::Int64)], )? .sort(vec![col("c1")])? .limit(10)? .build()?; - let scalar_functions = HashMap::new(); - let mut rule = TypeCoercionRule::new(Arc::new(Mutex::new(scalar_functions))); + let ctx = ExecutionContext::new(); + let mut rule = + TypeCoercionRule::new(ctx.scalar_functions(), ctx.aggregate_functions()); let plan = rule.optimize(&plan)?; // check that the filter had a cast added @@ -183,7 +296,7 @@ mod tests { println!("{}", plan_str); let expected_plan_str = "Limit: 10 Sort: #c1 - Aggregate: groupBy=[[#c1]], aggr=[[SUM(#c2)]] + Aggregate: groupBy=[[#c1]], aggr=[[sum(#c2)]] Projection: #c1, #c2 Selection: #c7 Lt CAST(UInt8(5) AS Int64)"; assert!(plan_str.starts_with(expected_plan_str)); @@ -201,8 +314,10 @@ mod tests { .filter(col("c7").lt(col("c12")))? 
.build()?; - let scalar_functions = HashMap::new(); - let mut rule = TypeCoercionRule::new(Arc::new(Mutex::new(scalar_functions))); + let mut rule = TypeCoercionRule::new( + Arc::new(Mutex::new(HashMap::new())), + Arc::new(Mutex::new(HashMap::new())), + ); let plan = rule.optimize(&plan)?; assert!( @@ -281,10 +396,150 @@ mod tests { }; let ctx = ExecutionContext::new(); - let rule = TypeCoercionRule::new(ctx.scalar_functions()); + let rule = + TypeCoercionRule::new(ctx.scalar_functions(), ctx.aggregate_functions()); let expr2 = rule.rewrite_expr(&expr, &schema).unwrap(); assert_eq!(expected, format!("{:?}", expr2)); } + + #[test] + fn test_maybe_coerce() -> Result<()> { + // this vec contains: arg1, arg2, expected result + let cases = vec![ + // 2 entries, same values + ( + vec![DataType::UInt8, DataType::UInt16], + vec![DataType::UInt8, DataType::UInt16], + Some(vec![DataType::UInt8, DataType::UInt16]), + ), + // 2 entries, can coerse values + ( + vec![DataType::UInt16, DataType::UInt16], + vec![DataType::UInt8, DataType::UInt16], + Some(vec![DataType::UInt16, DataType::UInt16]), + ), + // 0 entries, all good + (vec![], vec![], Some(vec![])), + // 2 entries, can't coerce + ( + vec![DataType::Boolean, DataType::UInt16], + vec![DataType::UInt8, DataType::UInt16], + None, + ), + // u32 -> u16 is possible + ( + vec![DataType::Boolean, DataType::UInt32], + vec![DataType::Boolean, DataType::UInt16], + Some(vec![DataType::Boolean, DataType::UInt32]), + ), + ]; + + for case in cases { + assert_eq!(maybe_coerce(&case.0, &case.1), case.2) + } + Ok(()) + } + + #[test] + fn test_maybe_rewrite() -> Result<()> { + // create a schema + let schema = |t: Vec| { + Schema::new( + t.iter() + .enumerate() + .map(|(i, t)| Field::new(&*format!("c{}", i), t.clone(), true)) + .collect(), + ) + }; + + // create a vector of expressions + let expressions = |t: Vec, schema| -> Result> { + t.iter() + .enumerate() + .map(|(i, t)| col(&*format!("c{}", i)).cast_to(&t, &schema)) + .collect::>>() + }; + + // map expr + schema to types + let current_types = |expressions: &Vec, schema| -> Result> { + Ok(expressions + .iter() + .map(|e| e.get_type(&schema)) + .collect::>>()?) + }; + + // create a case: input + expected result + let case = |observed: Vec, + valid, + expected: Option>| + -> Result<_> { + let schema = schema(observed.clone()); + let expr = expressions(observed, schema.clone())?; + let expected = if let Some(e) = expected { + // expressions re-written as cast + Some(expressions(e, schema.clone())?) 
+ } else { + None + }; + Ok(( + expr.clone(), + current_types(&expr, schema.clone())?, + schema, + valid, + expected, + )) + }; + + let cases = vec![ + // no conversion -> all good + case(vec![], vec![vec![]], Some(vec![]))?, + // u16 -> u32 + case( + vec![DataType::UInt16, DataType::UInt32], + vec![vec![DataType::UInt32, DataType::UInt32]], + Some(vec![DataType::UInt32, DataType::UInt32]), + )?, + // same type + case( + vec![DataType::UInt16, DataType::UInt32], + vec![vec![DataType::UInt16, DataType::UInt32]], + Some(vec![DataType::UInt16, DataType::UInt32]), + )?, + // we do not know how to cast bool to UInt16 => fail + case( + vec![DataType::Boolean, DataType::UInt32], + vec![vec![DataType::UInt16, DataType::UInt32]], + None, + )?, + // we do not know how to cast (bool,u16) to (u16,u32), + // but we know to cast to (bool,u32) + case( + vec![DataType::Boolean, DataType::UInt16], + vec![ + vec![DataType::UInt16, DataType::UInt32], + vec![DataType::Boolean, DataType::UInt32], + ], + Some(vec![DataType::Boolean, DataType::UInt32]), + )?, + // we do not know how to cast (bool,u16) to (u16,u32) nor (u32,u16) + case( + vec![DataType::Boolean, DataType::UInt32], + vec![ + vec![DataType::UInt16, DataType::UInt32], + vec![DataType::UInt32, DataType::UInt16], + ], + None, + )?, + ]; + + for (i, case) in cases.iter().enumerate() { + if maybe_rewrite(&case.0, &case.1, &case.2, &case.3)? != case.4 { + assert_eq!(maybe_rewrite(&case.0, &case.1, &case.2, &case.3)?, case.4); + return Err(ExecutionError::General(format!("case {} failed", i))); + } + } + Ok(()) + } } diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 16da4527557..37f696e3542 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -22,8 +22,8 @@ use std::sync::Arc; use crate::error::{ExecutionError, Result}; use crate::logicalplan::Expr::Alias; use crate::logicalplan::{ - lit, Expr, FunctionMeta, LogicalPlan, LogicalPlanBuilder, Operator, PlanType, - ScalarValue, StringifiedPlan, + lit, Expr, FunctionMeta, FunctionType, LogicalPlan, LogicalPlanBuilder, Operator, + PlanType, ScalarValue, StringifiedPlan, }; use crate::sql::parser::{CreateExternalTable, FileType, Statement as DFStatement}; @@ -44,6 +44,8 @@ pub trait SchemaProvider { fn get_table_meta(&self, name: &str) -> Option; /// Getter for a UDF description fn get_function_meta(&self, name: &str) -> Option>; + /// Getter list of valid udfs + fn functions(&self) -> Vec; } /// SQL query planner @@ -476,70 +478,50 @@ impl SqlToRel { } SQLExpr::Function(function) => { - //TODO: fix this hack - let name: String = function.name.to_string(); - match name.to_lowercase().as_ref() { - "min" | "max" | "sum" | "avg" => { - let rex_args = function - .args - .iter() - .map(|a| self.sql_to_rex(a, schema)) - .collect::>>()?; - - // return type is same as the argument type for these aggregate - // functions - let return_type = rex_args[0].get_type(schema)?.clone(); - - Ok(Expr::AggregateFunction { - name: name.clone(), - args: rex_args, - return_type, - }) - } - "count" => { - let rex_args = function - .args - .iter() - .map(|a| match a { - SQLExpr::Value(Value::Number(_)) => Ok(lit(1_u8)), - SQLExpr::Wildcard => Ok(lit(1_u8)), - _ => self.sql_to_rex(a, schema), - }) - .collect::>>()?; - - Ok(Expr::AggregateFunction { - name: name.clone(), - args: rex_args, - return_type: DataType::UInt64, - }) - } - _ => match self.schema_provider.get_function_meta(&name) { - Some(fm) => { - let rex_args = function + // make the search 
case-insensitive + let name: String = function.name.to_string().to_lowercase(); + + match self.schema_provider.get_function_meta(&name) { + Some(fm) => { + let args = if name == "count" { + // optimization to avoid computing expressions + function + .args + .iter() + .map(|a| match a { + SQLExpr::Value(Value::Number(_)) => Ok(lit(1_u8)), + SQLExpr::Wildcard => Ok(lit(1_u8)), + _ => self.sql_to_rex(a, schema), + }) + .collect::>>()? + } else { + function .args .iter() .map(|a| self.sql_to_rex(a, schema)) - .collect::>>()?; + .collect::>>()? + }; - let mut safe_args: Vec = vec![]; - for i in 0..rex_args.len() { - safe_args.push( - rex_args[i] - .cast_to(fm.args()[i].data_type(), schema)?, - ); - } + //let args = coerse_expr(&args, &fm, &schema)?; - Ok(Expr::ScalarFunction { + match fm.function_type() { + FunctionType::Scalar => Ok(Expr::ScalarFunction { name: name.clone(), - args: safe_args, + args, return_type: fm.return_type().clone(), - }) + }), + FunctionType::Aggregate => Ok(Expr::AggregateFunction { + name: name.clone(), + args, + return_type: fm.return_type().clone(), + }), } - _ => Err(ExecutionError::General(format!( - "Invalid function '{}'", - name - ))), - }, + } + _ => Err(ExecutionError::General(format!( + "Invalid function '{}'. Valid functions: {:?}", + name, + self.schema_provider.functions(), + ))), } } @@ -598,7 +580,7 @@ mod tests { fn select_scalar_func_with_literal_no_relation() { quick_test( "SELECT sqrt(9)", - "Projection: sqrt(CAST(Int64(9) AS Float64))\ + "Projection: sqrt(Int64(9))\ \n EmptyRelation", ); } @@ -685,7 +667,7 @@ mod tests { fn select_simple_aggregate() { quick_test( "SELECT MIN(age) FROM person", - "Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\ + "Aggregate: groupBy=[[]], aggr=[[min(#age)]]\ \n TableScan: person projection=None", ); } @@ -694,7 +676,7 @@ mod tests { fn test_sum_aggregate() { quick_test( "SELECT SUM(age) from person", - "Aggregate: groupBy=[[]], aggr=[[SUM(#age)]]\ + "Aggregate: groupBy=[[]], aggr=[[sum(#age)]]\ \n TableScan: person projection=None", ); } @@ -703,7 +685,7 @@ mod tests { fn select_simple_aggregate_with_groupby() { quick_test( "SELECT state, MIN(age), MAX(age) FROM person GROUP BY state", - "Aggregate: groupBy=[[#state]], aggr=[[MIN(#age), MAX(#age)]]\ + "Aggregate: groupBy=[[#state]], aggr=[[min(#age), max(#age)]]\ \n TableScan: person projection=None", ); } @@ -720,7 +702,7 @@ mod tests { #[test] fn select_count_one() { let sql = "SELECT COUNT(1) FROM person"; - let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ + let expected = "Aggregate: groupBy=[[]], aggr=[[count(UInt8(1))]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -728,7 +710,7 @@ mod tests { #[test] fn select_count_column() { let sql = "SELECT COUNT(id) FROM person"; - let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(#id)]]\ + let expected = "Aggregate: groupBy=[[]], aggr=[[count(#id)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -736,7 +718,7 @@ mod tests { #[test] fn select_scalar_func() { let sql = "SELECT sqrt(age) FROM person"; - let expected = "Projection: sqrt(CAST(#age AS Float64))\ + let expected = "Projection: sqrt(#age)\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -744,7 +726,7 @@ mod tests { #[test] fn select_aliased_scalar_func() { let sql = "SELECT sqrt(age) AS square_people FROM person"; - let expected = "Projection: sqrt(CAST(#age AS Float64)) AS square_people\ + let expected = "Projection: sqrt(#age) AS square_people\ \n TableScan: person 
projection=None"; quick_test(sql, expected); } @@ -797,8 +779,8 @@ mod tests { fn select_group_by_needs_projection() { let sql = "SELECT COUNT(state), state FROM person GROUP BY state"; let expected = "\ - Projection: #COUNT(state), #state\ - \n Aggregate: groupBy=[[#state]], aggr=[[COUNT(#state)]]\ + Projection: #count(state), #state\ + \n Aggregate: groupBy=[[#state]], aggr=[[count(#state)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -909,15 +891,43 @@ mod tests { } fn get_function_meta(&self, name: &str) -> Option> { - match name { - "sqrt" => Some(Arc::new(FunctionMeta::new( - "sqrt".to_string(), - vec![Field::new("n", DataType::Float64, false)], + let fnc_type = if name == "sqrt" { + FunctionType::Scalar + } else { + FunctionType::Aggregate + }; + let valid_types = if name == "sqrt" { + vec![DataType::Float64] + } else { + vec![ + DataType::UInt8, + DataType::UInt32, + DataType::Int64, + DataType::Int32, DataType::Float64, - FunctionType::Scalar, - ))), + DataType::Utf8, + ] + }; + + let fm = Arc::new(FunctionMeta::new( + name.to_string(), + vec![valid_types], + DataType::Float64, + fnc_type, + )); + + match name { + "sqrt" => Some(fm), + "min" => Some(fm), + "max" => Some(fm), + "sum" => Some(fm), + "count" => Some(fm), _ => None, } } + + fn functions(&self) -> Vec { + vec!["sqrt".to_string()] + } } } diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index d2d5349d5e4..9a1807f2fa2 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -201,17 +201,111 @@ fn csv_query_avg_sqrt() -> Result<()> { Ok(()) } +#[test] +fn csv_query_avg_custom_udf_f64_f64() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx)?; + // c12 is f64 + let sql = "SELECT avg(custom_add(c12, c12)) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql); + + // perform equivalent calculation + let sql = "SELECT avg(c12 + c12) FROM aggregate_test_100"; + let expected = execute(&mut ctx, sql); + + // verify equality + assert_eq!(actual.join("\n"), expected.join("\n")); + Ok(()) +} + +#[test] +fn csv_query_avg_custom_udf_f32_f64() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx)?; + // c11 is f32 + // c12 is f64 + let sql = "SELECT avg(custom_add(c11, c12)) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql); + + // function evaluted as f32,f64 returns a constant 3264.0 + let expected = "3264.0".to_string(); + + // verify equality + assert_eq!(actual.join("\n"), expected); + Ok(()) +} + +#[test] +fn csv_query_avg_custom_udf_f32_f32() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx)?; + // c11 is f32 + let sql = "SELECT avg(custom_add(c11, c11)) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql); + + // function evaluted as f32,f32 returns 3232.0 + let expected = "3232.0".to_string(); + + // verify equality + assert_eq!(actual.join("\n"), expected); + Ok(()) +} + +#[test] +fn csv_query_avg_custom_udf_i8() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx)?; + // c3 is i8, castable to float32 + let sql = "SELECT avg(custom_add(c3, c3)) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql); + + // function evaluted as float32,float32 returns a constant 1111.0 + let expected = "3232.0".to_string(); + + // verify equality + assert_eq!(actual.join("\n"), expected); + Ok(()) +} + +#[test] +fn csv_query_avg_custom_udf_utf8() -> Result<()> { + let mut ctx = create_ctx()?; + 
register_aggregate_csv(&mut ctx)?; + // utf8 is currently convertable to any type. See https://issues.apache.org/jira/browse/ARROW-4957 + let sql = "SELECT avg(custom_add(c1, c1)) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql); + + // function evaluted on any other type returns a constant 1111.0 + let expected = "1111.0".to_string(); + + // verify equality + assert_eq!(actual.join("\n"), expected); + Ok(()) +} + fn create_ctx() -> Result { let mut ctx = ExecutionContext::new(); - // register a custom UDF + // register a UDF of 1 argument ctx.register_udf(ScalarFunction::new( "custom_sqrt", - vec![Field::new("n", DataType::Float64, true)], + vec![vec![DataType::Float64]], DataType::Float64, Arc::new(custom_sqrt), )); + // register a udf of two arguments + ctx.register_udf(ScalarFunction::new( + "custom_add", + vec![ + vec![DataType::Float32, DataType::Float32], + vec![DataType::Float32, DataType::Float64], + vec![DataType::Float64, DataType::Float64], + ], + DataType::Float64, + Arc::new(custom_add), + )); + Ok(ctx) } @@ -232,6 +326,55 @@ fn custom_sqrt(args: &[ArrayRef]) -> Result { Ok(Arc::new(builder.finish())) } +fn custom_add(args: &[ArrayRef]) -> Result { + match (args[0].data_type(), args[1].data_type()) { + (DataType::Float64, DataType::Float64) => { + let input1 = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + let input2 = &args[1] + .as_any() + .downcast_ref::() + .expect("cast failed"); + + let mut builder = Float64Builder::new(input1.len()); + for i in 0..input1.len() { + if input1.is_null(i) || input2.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(input1.value(i) + input2.value(i))?; + } + } + Ok(Arc::new(builder.finish())) + } + (DataType::Float32, DataType::Float32) => { + // all other cases return a constant vector (just to be diferent) + let mut builder = Float64Builder::new(args[0].len()); + for _ in 0..args[0].len() { + builder.append_value(3232.0)?; + } + Ok(Arc::new(builder.finish())) + } + (DataType::Float32, DataType::Float64) => { + // all other cases return a constant vector (just to be diferent) + let mut builder = Float64Builder::new(args[0].len()); + for _ in 0..args[0].len() { + builder.append_value(3264.0)?; + } + Ok(Arc::new(builder.finish())) + } + (_, _) => { + // all other cases return a constant vector (just to be diferent) + let mut builder = Float64Builder::new(args[0].len()); + for _ in 0..args[0].len() { + builder.append_value(1111.0)?; + } + Ok(Arc::new(builder.finish())) + } + } +} + #[test] fn csv_query_avg() -> Result<()> { let mut ctx = ExecutionContext::new();