diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index 0a193684265..1ee4f10578b 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -53,6 +53,8 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI - [x] Limit - [x] Aggregate - [x] UDFs + - [x] Scalar UDFs + - [x] Aggregate UDFs - [x] Common math functions - String functions - [x] Length of the string diff --git a/rust/datafusion/src/datatyped.rs b/rust/datafusion/src/datatyped.rs new file mode 100644 index 00000000000..f1ef241027c --- /dev/null +++ b/rust/datafusion/src/datatyped.rs @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains a public trait to annotate objects that know their data type. +// The pattern in this module follows https://stackoverflow.com/a/28664881/931303 + +use crate::error::Result; +use arrow::datatypes::{DataType, Schema}; + +/// Any object that knows how to infer its resulting data type from an underlying schema +pub trait DataTyped: AsDataTyped { + fn get_type(&self, input_schema: &Schema) -> Result; +} + +/// Trait that allows DataTyped objects to be upcasted to DataTyped. 
+pub trait AsDataTyped { + fn as_datatyped(&self) -> &dyn DataTyped; +} + +impl AsDataTyped for T { + fn as_datatyped(&self) -> &dyn DataTyped { + self + } +} diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 8f92aae302b..7ae806eef6a 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -28,6 +28,7 @@ use arrow::csv; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; +use super::physical_plan::udf::AggregateFunction; use crate::dataframe::DataFrame; use crate::datasource::csv::CsvFile; use crate::datasource::parquet::ParquetTable; @@ -38,10 +39,10 @@ use crate::execution::physical_plan::common; use crate::execution::physical_plan::csv::CsvReadOptions; use crate::execution::physical_plan::merge::MergeExec; use crate::execution::physical_plan::planner::PhysicalPlannerImpl; -use crate::execution::physical_plan::scalar_functions; use crate::execution::physical_plan::udf::ScalarFunction; use crate::execution::physical_plan::ExecutionPlan; use crate::execution::physical_plan::PhysicalPlanner; +use crate::execution::physical_plan::{aggregate_functions, scalar_functions}; use crate::logicalplan::{FunctionMeta, FunctionType, LogicalPlan, LogicalPlanBuilder}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::projection_push_down::ProjectionPushDown; @@ -103,12 +104,16 @@ impl ExecutionContext { state: Arc::new(Mutex::new(ExecutionContextState { datasources: Arc::new(Mutex::new(HashMap::new())), scalar_functions: Arc::new(Mutex::new(HashMap::new())), + aggregate_functions: Arc::new(Mutex::new(HashMap::new())), config, })), }; for udf in scalar_functions() { ctx.register_udf(udf); } + for udf in aggregate_functions() { + ctx.register_aggregate_udf(udf); + } ctx } @@ -200,6 +205,16 @@ impl ExecutionContext { .insert(f.name.clone(), Box::new(f)); } + /// Register an aggregate function + pub fn register_aggregate_udf(&mut self, f: 
AggregateFunction) { + let state = self.state.lock().expect("failed to lock mutex"); + state + .aggregate_functions + .lock() + .expect("failed to lock mutex") + .insert(f.name.clone(), Box::new(f)); + } + /// Get a reference to the registered scalar functions pub fn scalar_functions(&self) -> Arc>>> { self.state @@ -209,6 +224,17 @@ impl ExecutionContext { .clone() } + /// Get a reference to the registered aggregate functions + pub fn aggregate_functions( + &self, + ) -> Arc>>> { + self.state + .lock() + .expect("failed to lock mutex") + .aggregate_functions + .clone() + } + /// Creates a DataFrame for reading a CSV data source. pub fn read_csv( &mut self, @@ -336,11 +362,8 @@ impl ExecutionContext { let rules: Vec> = vec![ Box::new(ProjectionPushDown::new()), Box::new(TypeCoercionRule::new( - self.state - .lock() - .expect("failed to lock mutex") - .scalar_functions - .clone(), + self.scalar_functions(), + self.aggregate_functions(), )), ]; let mut plan = plan.clone(); @@ -495,6 +518,8 @@ pub struct ExecutionContextState { pub datasources: Arc>>>, /// Scalar functions that are registered with the context pub scalar_functions: Arc>>>, + /// Aggregate functions that are registered with the context + pub aggregate_functions: Arc>>>, /// Context configuration pub config: ExecutionConfig, } @@ -509,19 +534,56 @@ impl SchemaProvider for ExecutionContextState { } fn get_function_meta(&self, name: &str) -> Option> { - self.scalar_functions + let scalar = self + .scalar_functions .lock() .expect("failed to lock mutex") .get(name) .map(|f| { Arc::new(FunctionMeta::new( name.to_owned(), - f.args.clone(), + f.arg_types.clone(), f.return_type.clone(), FunctionType::Scalar, )) + }); + // give priority to scalar functions + if scalar.is_some() { + return scalar; + } + + self.aggregate_functions + .lock() + .expect("failed to lock mutex") + .get(name) + .map(|f| { + Arc::new(FunctionMeta::new( + name.to_owned(), + f.arg_types.clone(), + f.return_type.clone(), + 
FunctionType::Aggregate, + )) }) } + + fn functions(&self) -> Vec { + let mut scalars: Vec = self + .scalar_functions + .lock() + .expect("failed to lock mutex") + .keys() + .cloned() + .collect(); + let mut aggregates: Vec = self + .aggregate_functions + .lock() + .expect("failed to lock mutex") + .keys() + .cloned() + .collect(); + aggregates.append(&mut scalars); + aggregates + } } #[cfg(test)] @@ -529,13 +591,17 @@ mod tests { use super::*; use crate::datasource::MemTable; - use crate::execution::physical_plan::udf::ScalarUdf; - use crate::logicalplan::{aggregate_expr, col, scalar_function}; + use crate::execution::physical_plan::{ + expressions::Column, + udf::{AggregateFunctionExpr, ScalarUdf}, + Accumulator, AggregateExpr, Aggregator, + }; + use crate::logicalplan::{aggregate_expr, col, scalar_function, ScalarValue}; use crate::test; - use arrow::array::{ArrayRef, Int32Array}; - use arrow::compute::add; + use arrow::array::{ArrayRef, Float64Array, Int32Array}; + use arrow::compute::{add, sum}; use std::fs::File; - use std::io::prelude::*; + use std::{cell::RefCell, io::prelude::*, rc::Rc}; use tempdir::TempDir; use test::*; @@ -737,7 +803,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["SUM(c1)", "SUM(c2)"]); + assert_eq!(field_names(batch), vec!["sum(c1)", "sum(c2)"]); let expected: Vec<&str> = vec!["60,220"]; let mut rows = test::format_batch(&batch); @@ -754,7 +820,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["AVG(c1)", "AVG(c2)"]); + assert_eq!(field_names(batch), vec!["avg(c1)", "avg(c2)"]); let expected: Vec<&str> = vec!["1.5,5.5"]; let mut rows = test::format_batch(&batch); @@ -771,7 +837,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["MAX(c1)", "MAX(c2)"]); + assert_eq!(field_names(batch), vec!["max(c1)", "max(c2)"]); let expected: Vec<&str> = vec!["3,10"]; let mut rows = test::format_batch(&batch); @@ -788,7 +854,7 @@ mod tests { let batch = 
&results[0]; - assert_eq!(field_names(batch), vec!["MIN(c1)", "MIN(c2)"]); + assert_eq!(field_names(batch), vec!["min(c1)", "min(c2)"]); let expected: Vec<&str> = vec!["0,1"]; let mut rows = test::format_batch(&batch); @@ -805,7 +871,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["c1", "SUM(c2)"]); + assert_eq!(field_names(batch), vec!["c1", "sum(c2)"]); let expected: Vec<&str> = vec!["0,55", "1,55", "2,55", "3,55"]; let mut rows = test::format_batch(&batch); @@ -822,7 +888,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["c1", "AVG(c2)"]); + assert_eq!(field_names(batch), vec!["c1", "avg(c2)"]); let expected: Vec<&str> = vec!["0,5.5", "1,5.5", "2,5.5", "3,5.5"]; let mut rows = test::format_batch(&batch); @@ -839,7 +905,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["c1", "MAX(c2)"]); + assert_eq!(field_names(batch), vec!["c1", "max(c2)"]); let expected: Vec<&str> = vec!["0,10", "1,10", "2,10", "3,10"]; let mut rows = test::format_batch(&batch); @@ -856,7 +922,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["c1", "MIN(c2)"]); + assert_eq!(field_names(batch), vec!["c1", "min(c2)"]); let expected: Vec<&str> = vec!["0,1", "1,1", "2,1", "3,1"]; let mut rows = test::format_batch(&batch); @@ -873,7 +939,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["COUNT(c1)", "COUNT(c2)"]); + assert_eq!(field_names(batch), vec!["count(c1)", "count(c2)"]); let expected: Vec<&str> = vec!["10,10"]; let mut rows = test::format_batch(&batch); @@ -889,7 +955,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["COUNT(c1)", "COUNT(c2)"]); + assert_eq!(field_names(batch), vec!["count(c1)", "count(c2)"]); let expected: Vec<&str> = vec!["40,40"]; let mut rows = test::format_batch(&batch); @@ -905,7 +971,7 @@ mod tests { let batch = &results[0]; - assert_eq!(field_names(batch), vec!["c1", "COUNT(c2)"]); + 
assert_eq!(field_names(batch), vec!["c1", "count(c2)"]); let expected = vec!["0,10", "1,10", "2,10", "3,10"]; let mut rows = test::format_batch(&batch); @@ -927,9 +993,13 @@ mod tests { let plan = LogicalPlanBuilder::scan("default", "test", schema.as_ref(), None)? .aggregate( vec![col("c1")], - vec![aggregate_expr("SUM", col("c2"), DataType::UInt32)], + vec![aggregate_expr( + "sum", + col("c2"), + Arc::new(|_, _| Ok(DataType::Int32)), + )], )? - .project(vec![col("c1"), col("SUM(c2)").alias("total_salary")])? + .project(vec![col("c1"), col("sum(c2)").alias("total_salary")])? .build()?; let plan = ctx.optimize(&plan)?; @@ -1049,11 +1119,8 @@ mod tests { let my_add = ScalarFunction::new( "my_add", - vec![ - Field::new("a", DataType::Int32, true), - Field::new("b", DataType::Int32, true), - ], - DataType::Int32, + vec![vec![DataType::Int32, DataType::Int32]], + Arc::new(|_, _| Ok(DataType::Int32)), myfunc, ); @@ -1065,7 +1132,11 @@ mod tests { .project(vec![ col("a"), col("b"), - scalar_function("my_add", vec![col("a"), col("b")], DataType::Int32), + scalar_function( + "my_add", + vec![col("a"), col("b")], + Arc::new(|_, _| Ok(DataType::Int32)), + ), ])? 
.build()?; @@ -1205,6 +1276,107 @@ mod tests { CsvReadOptions::new().schema(&schema), )?; + ctx.register_aggregate_udf(avg()); + Ok(ctx) } + + // declare an accumulator of an average of f64 + // math details here: https://stackoverflow.com/a/23493727/931303 (new_avg = (old_avg * (n - 1) + value) / n) + #[derive(Debug)] + struct MyAvg { + avg: f64, // online average + n: usize, + } + + impl Accumulator for MyAvg { + fn accumulate_scalar(&mut self, value: Option) -> Result<()> { + if let Some(value) = value { + match value { + ScalarValue::Float64(v) => { + self.n += 1; + self.avg = (self.avg * ((self.n - 1) as f64) + v) / self.n as f64; + } + _ => { + return Err(ExecutionError::ExecutionError(format!( + "Unsupported type {:?}.", + value + ))) + } + } + } + Ok(()) + } + + fn accumulate_batch(&mut self, array: &ArrayRef) -> Result<()> { + match array.data_type() { + DataType::Float64 => { + let array = array + .as_any() + .downcast_ref::() + .expect("Failed to cast array"); + let sum = sum(array).unwrap_or(0.0); + let m = array.len(); + + self.n += m; + self.avg = (self.avg * (self.n - m) as f64 + sum) / self.n as f64; + } + _ => { + return Err(ExecutionError::ExecutionError(format!( + "Unsupported type {:?}.", + array.data_type() + ))) + } + } + Ok(()) + } + + fn get_value(&self) -> Result> { + Ok(Some(ScalarValue::Float64(self.avg))) + } + } + + fn avg() -> AggregateFunction { + AggregateFunction { + name: "my_avg".to_string(), + return_type: Arc::new(|_, _| Ok(DataType::Float64)), + arg_types: vec![vec![DataType::Float64]], + aggregate: Arc::new(MyAvg { avg: 0.0, n: 0 }), + } + } + + // implement it on the same struct just for fun + impl Aggregator for MyAvg { + fn create_accumulator(&self) -> Rc> { + Rc::new(RefCell::new(MyAvg { avg: 0.0, n: 0 })) + } + + fn create_reducer(&self, column_name: &str) -> Arc { + Arc::new(AggregateFunctionExpr::new( + "my_avg", + vec![Arc::new(Column::new(column_name))], + Box::new(avg()), + )) + } + } + + #[test] + fn aggregate_custom_agg() -> Result<()> { + let results =
execute("SELECT c1, my_avg(c2) FROM test GROUP BY c1", 4)?; + assert_eq!(results.len(), 1); + + let batch = &results[0]; + + assert_eq!( + field_names(batch), + vec!["c1", "my_avg(CAST(c2 as Float64))"] + ); + + let expected: Vec<&str> = vec!["0,5.5", "1,5.5", "2,5.5", "3,5.5"]; + let mut rows = test::format_batch(&batch); + rows.sort(); + assert_eq!(rows, expected); + + Ok(()) + } } diff --git a/rust/datafusion/src/execution/dataframe_impl.rs b/rust/datafusion/src/execution/dataframe_impl.rs index 491cb70a3a4..bd471041d49 100644 --- a/rust/datafusion/src/execution/dataframe_impl.rs +++ b/rust/datafusion/src/execution/dataframe_impl.rs @@ -19,10 +19,9 @@ use std::sync::{Arc, Mutex}; -use crate::arrow::datatypes::DataType; use crate::arrow::record_batch::RecordBatch; use crate::dataframe::*; -use crate::error::{ExecutionError, Result}; +use crate::error::Result; use crate::execution::context::{ExecutionContext, ExecutionContextState}; use crate::logicalplan::{col, Expr, LogicalPlan, LogicalPlanBuilder}; use arrow::datatypes::Schema; @@ -94,27 +93,27 @@ impl DataFrame for DataFrameImpl { /// Create an expression to represent the min() aggregate function fn min(&self, expr: Expr) -> Result { - self.aggregate_expr("MIN", expr) + self.aggregate_expr("min", expr) } /// Create an expression to represent the max() aggregate function fn max(&self, expr: Expr) -> Result { - self.aggregate_expr("MAX", expr) + self.aggregate_expr("max", expr) } /// Create an expression to represent the sum() aggregate function fn sum(&self, expr: Expr) -> Result { - self.aggregate_expr("SUM", expr) + self.aggregate_expr("sum", expr) } /// Create an expression to represent the avg() aggregate function fn avg(&self, expr: Expr) -> Result { - self.aggregate_expr("AVG", expr) + self.aggregate_expr("avg", expr) } /// Create an expression to represent the count() aggregate function fn count(&self, expr: Expr) -> Result { - self.aggregate_expr("COUNT", expr) + self.aggregate_expr("count", expr) } /// 
Convert to logical plan @@ -134,29 +133,12 @@ impl DataFrame for DataFrameImpl { } impl DataFrameImpl { - /// Determine the data type for a given expression - fn get_data_type(&self, expr: &Expr) -> Result { - match expr { - Expr::Column(name) => Ok(self - .plan - .schema() - .field_with_name(name)? - .data_type() - .clone()), - _ => Err(ExecutionError::General(format!( - "Could not determine data type for expr {:?}", - expr - ))), - } - } - /// Create an expression to represent a named aggregate function fn aggregate_expr(&self, name: &str, expr: Expr) -> Result { - let return_type = self.get_data_type(&expr)?; Ok(Expr::AggregateFunction { name: name.to_string(), - args: vec![expr.clone()], - return_type, + args: vec![expr], + return_type: Arc::new(|e, schema| e[0].get_type(&schema)), }) } } @@ -218,7 +200,7 @@ mod tests { let plan = t2.to_logical_plan(); // build same plan using SQL API - let sql = "SELECT c1, MIN(c12), MAX(c12), AVG(c12), SUM(c12), COUNT(c12) \ + let sql = "SELECT c1, min(c12), max(c12), avg(c12), sum(c12), count(c12) \ FROM aggregate_test_100 \ GROUP BY c1"; let sql_plan = create_plan(sql)?; diff --git a/rust/datafusion/src/execution/physical_plan/expressions.rs b/rust/datafusion/src/execution/physical_plan/expressions.rs index dff913b9162..b1fdb2f96ff 100644 --- a/rust/datafusion/src/execution/physical_plan/expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/expressions.rs @@ -24,8 +24,14 @@ use std::sync::Arc; use crate::error::{ExecutionError, Result}; use crate::execution::physical_plan::common::get_scalar_value; -use crate::execution::physical_plan::{Accumulator, AggregateExpr, PhysicalExpr}; -use crate::logicalplan::{Operator, ScalarValue}; +use crate::execution::physical_plan::udf; +use crate::execution::physical_plan::{ + Accumulator, AggregateExpr, Aggregator, PhysicalExpr, +}; +use crate::{ + datatyped::DataTyped, + logicalplan::{Operator, ScalarValue}, +}; use arrow::array::{ ArrayRef, BooleanArray, Float32Array, 
Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, StringArray, TimestampNanosecondArray, UInt16Array, @@ -47,6 +53,7 @@ use arrow::compute::kernels::comparison::{ use arrow::compute::kernels::sort::{SortColumn, SortOptions}; use arrow::datatypes::{DataType, Schema, TimeUnit}; use arrow::record_batch::RecordBatch; +use udf::AggregateFunction; /// Represents the column at a given index in a RecordBatch #[derive(Debug)] @@ -69,15 +76,16 @@ impl fmt::Display for Column { } } -impl PhysicalExpr for Column { - /// Get the data type of this expression, given the schema of the input - fn data_type(&self, input_schema: &Schema) -> Result { +impl DataTyped for Column { + fn get_type(&self, input_schema: &Schema) -> Result { Ok(input_schema .field_with_name(&self.name)? .data_type() .clone()) } +} +impl PhysicalExpr for Column { /// Decide whehter this expression is nullable, given the schema of the input fn nullable(&self, input_schema: &Schema) -> Result { Ok(input_schema.field_with_name(&self.name)?.is_nullable()) @@ -94,47 +102,51 @@ pub fn col(name: &str) -> Arc { Arc::new(Column::new(name)) } +/// aggregate functions declared in this module +pub fn aggregate_functions() -> Vec { + vec![sum(), avg(), max(), min(), count()] +} /// SUM aggregate expression #[derive(Debug)] -pub struct Sum { - expr: Arc, -} - -impl Sum { - /// Create a new SUM aggregate function - pub fn new(expr: Arc) -> Self { - Self { expr } - } -} +pub struct Sum {} -impl AggregateExpr for Sum { - fn data_type(&self, input_schema: &Schema) -> Result { - match self.expr.data_type(input_schema)? 
{ - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - Ok(DataType::Int64) - } - DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { - Ok(DataType::UInt64) - } - DataType::Float32 => Ok(DataType::Float32), - DataType::Float64 => Ok(DataType::Float64), - other => Err(ExecutionError::General(format!( - "SUM does not support {:?}", - other - ))), - } +impl Aggregator for Sum { + fn create_accumulator(&self) -> Rc> { + Rc::new(RefCell::new(SumAccumulator { sum: None })) } - fn evaluate_input(&self, batch: &RecordBatch) -> Result { - self.expr.evaluate(batch) + fn create_reducer(&self, column_name: &str) -> Arc { + physical_sum(Arc::new(Column::new(column_name))) } +} - fn create_accumulator(&self) -> Rc> { - Rc::new(RefCell::new(SumAccumulator { sum: None })) +fn sum_return_type(expr: &Vec<&dyn DataTyped>, schema: &Schema) -> Result { + match expr[0].get_type(schema)? { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + Ok(DataType::Int64) + } + DataType::UInt8 | DataType::UInt16 | DataType::UInt32 | DataType::UInt64 => { + Ok(DataType::UInt64) + } + DataType::Float32 => Ok(DataType::Float32), + DataType::Float64 => Ok(DataType::Float64), + other => Err(ExecutionError::General(format!( + "SUM does not support {:?}", + other + ))), } +} - fn create_reducer(&self, column_name: &str) -> Arc { - Arc::new(Sum::new(Arc::new(Column::new(column_name)))) +/// Creates a sum aggregate function +pub fn sum() -> AggregateFunction { + AggregateFunction { + name: "sum".to_string(), + return_type: Arc::new(sum_return_type), + arg_types: common_types() + .iter() + .map(|x| vec![x.clone()]) + .collect::>(), + aggregate: Arc::new(Sum {}), } } @@ -156,7 +168,7 @@ macro_rules! 
sum_accumulate { #[derive(Debug)] struct SumAccumulator { - sum: Option, + pub sum: Option, } impl Accumulator for SumAccumulator { @@ -283,48 +295,20 @@ impl Accumulator for SumAccumulator { } } -/// Create a sum expression -pub fn sum(expr: Arc) -> Arc { - Arc::new(Sum::new(expr)) +/// Create a physical aggregate sum expression +pub fn physical_sum(expr: Arc) -> Arc { + Arc::new(udf::AggregateFunctionExpr::new( + "SUM", + vec![expr], + Box::new(sum()), + )) } -/// AVG aggregate expression +/// Average aggregate expression. #[derive(Debug)] -pub struct Avg { - expr: Arc, -} - -impl Avg { - /// Create a new AVG aggregate function - pub fn new(expr: Arc) -> Self { - Self { expr } - } -} - -impl AggregateExpr for Avg { - fn data_type(&self, input_schema: &Schema) -> Result { - match self.expr.data_type(input_schema)? { - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::UInt8 - | DataType::UInt16 - | DataType::UInt32 - | DataType::UInt64 - | DataType::Float32 - | DataType::Float64 => Ok(DataType::Float64), - other => Err(ExecutionError::General(format!( - "AVG does not support {:?}", - other - ))), - } - } - - fn evaluate_input(&self, batch: &RecordBatch) -> Result { - self.expr.evaluate(batch) - } +pub struct Avg {} +impl Aggregator for Avg { fn create_accumulator(&self) -> Rc> { Rc::new(RefCell::new(AvgAccumulator { sum: None, @@ -333,7 +317,48 @@ impl AggregateExpr for Avg { } fn create_reducer(&self, column_name: &str) -> Arc { - Arc::new(Avg::new(Arc::new(Column::new(column_name)))) + physical_avg(Arc::new(Column::new(column_name))) + } +} + +fn avg_return_type(expr: &Vec<&dyn DataTyped>, schema: &Schema) -> Result { + match expr[0].get_type(schema)? 
{ + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::UInt8 + | DataType::UInt16 + | DataType::UInt32 + | DataType::UInt64 + | DataType::Float32 + | DataType::Float64 => Ok(DataType::Float64), + other => Err(ExecutionError::General(format!( + "AVG does not support {:?}", + other + ))), + } +} + +/// Create a physical aggregate avg expression +pub fn physical_avg(expr: Arc) -> Arc { + Arc::new(udf::AggregateFunctionExpr::new( + "AVG", + vec![expr], + Box::new(avg()), + )) +} + +/// Creates an avg aggregate function +pub fn avg() -> AggregateFunction { + AggregateFunction { + name: "avg".to_string(), + return_type: Arc::new(avg_return_type), + arg_types: common_types() + .iter() + .map(|x| vec![x.clone()]) + .collect::>(), + aggregate: Arc::new(Avg {}), } } @@ -399,40 +424,60 @@ impl Accumulator for AvgAccumulator { } } -/// Create a avg expression -pub fn avg(expr: Arc) -> Arc { - Arc::new(Avg::new(expr)) -} - /// MAX aggregate expression #[derive(Debug)] -pub struct Max { - expr: Arc, -} +pub struct Max {} + +impl Aggregator for Max { + fn create_accumulator(&self) -> Rc> { + Rc::new(RefCell::new(MaxAccumulator { max: None })) + } -impl Max { - /// Create a new MAX aggregate function - pub fn new(expr: Arc) -> Self { - Self { expr } + fn create_reducer(&self, column_name: &str) -> Arc { + physical_max(Arc::new(Column::new(column_name))) } } -impl AggregateExpr for Max { - fn data_type(&self, input_schema: &Schema) -> Result { - self.expr.data_type(input_schema) - } +/// Create a physical aggregate max expression +pub fn physical_max(expr: Arc) -> Arc { + Arc::new(udf::AggregateFunctionExpr::new( + "MAX", + vec![expr], + Box::new(max()), + )) +} - fn evaluate_input(&self, batch: &RecordBatch) -> Result { - self.expr.evaluate(batch) - } +fn common_types() -> Vec { + // this order dictates the order in which we try to cast.
+ vec![ + DataType::UInt8, + DataType::UInt16, + DataType::UInt32, + DataType::UInt64, + DataType::Int8, + DataType::Int16, + DataType::Int32, + DataType::Int64, + DataType::Float32, + DataType::Float64, + ] +} - fn create_accumulator(&self) -> Rc> { - Rc::new(RefCell::new(MaxAccumulator { max: None })) +/// Creates a max aggregate function +pub fn max() -> AggregateFunction { + AggregateFunction { + name: "max".to_string(), + return_type: Arc::new(max_return_type), + arg_types: common_types() + .iter() + .map(|x| vec![x.clone()]) + .collect::>(), + aggregate: Arc::new(Max {}), } +} - fn create_reducer(&self, column_name: &str) -> Arc { - Arc::new(Max::new(Arc::new(Column::new(column_name)))) - } +fn max_return_type(expr: &Vec<&dyn DataTyped>, schema: &Schema) -> Result { + expr[0].get_type(schema) } macro_rules! max_accumulate { @@ -583,39 +628,39 @@ impl Accumulator for MaxAccumulator { } } -/// Create a max expression -pub fn max(expr: Arc) -> Arc { - Arc::new(Max::new(expr)) -} - /// MIN aggregate expression #[derive(Debug)] -pub struct Min { - expr: Arc, -} +pub struct Min {} -impl Min { - /// Create a new MIN aggregate function - pub fn new(expr: Arc) -> Self { - Self { expr } - } +/// Create a physical aggregate min expression +pub fn physical_min(expr: Arc) -> Arc { + Arc::new(udf::AggregateFunctionExpr::new( + "MIN", + vec![expr], + Box::new(min()), + )) } -impl AggregateExpr for Min { - fn data_type(&self, input_schema: &Schema) -> Result { - self.expr.data_type(input_schema) - } - - fn evaluate_input(&self, batch: &RecordBatch) -> Result { - self.expr.evaluate(batch) +/// Creates a min aggregate function +pub fn min() -> AggregateFunction { + AggregateFunction { + name: "min".to_string(), + return_type: Arc::new(max_return_type), + arg_types: common_types() + .iter() + .map(|x| vec![x.clone()]) + .collect::>(), + aggregate: Arc::new(Min {}), } +} +impl Aggregator for Min { fn create_accumulator(&self) -> Rc> { Rc::new(RefCell::new(MinAccumulator { min:
None })) } fn create_reducer(&self, column_name: &str) -> Arc { - Arc::new(Min::new(Arc::new(Column::new(column_name)))) + physical_min(Arc::new(Column::new(column_name))) } } @@ -767,43 +812,50 @@ impl Accumulator for MinAccumulator { } } -/// Create a min expression -pub fn min(expr: Arc) -> Arc { - Arc::new(Min::new(expr)) -} - /// COUNT aggregate expression /// Returns the amount of non-null values of the given expression. #[derive(Debug)] -pub struct Count { - expr: Arc, -} +pub struct Count {} -impl Count { - /// Create a new COUNT aggregate function. - pub fn new(expr: Arc) -> Self { - Self { expr: expr } +impl Aggregator for Count { + fn create_accumulator(&self) -> Rc> { + Rc::new(RefCell::new(CountAccumulator { count: 0 })) } -} -impl AggregateExpr for Count { - fn data_type(&self, _input_schema: &Schema) -> Result { - Ok(DataType::UInt64) + fn create_reducer(&self, column_name: &str) -> Arc { + physical_sum(Arc::new(Column::new(column_name))) } +} - fn evaluate_input(&self, batch: &RecordBatch) -> Result { - self.expr.evaluate(batch) - } +/// Create a physical aggregate count expression +pub fn physical_count(expr: Arc) -> Arc { + Arc::new(udf::AggregateFunctionExpr::new( + "COUNT", + vec![expr], + Box::new(count()), + )) +} - fn create_accumulator(&self) -> Rc> { - Rc::new(RefCell::new(CountAccumulator { count: 0 })) - } +/// Creates a count aggregate function +pub fn count() -> AggregateFunction { + let mut types = common_types() + .iter() + .map(|x| vec![x.clone()]) + .collect::>(); + types.push(vec![DataType::Utf8]); - fn create_reducer(&self, column_name: &str) -> Arc { - Arc::new(Sum::new(Arc::new(Column::new(column_name)))) + AggregateFunction { + name: "count".to_string(), + return_type: Arc::new(count_return_type), + arg_types: types, + aggregate: Arc::new(Count {}), } } +fn count_return_type(_expr: &Vec<&dyn DataTyped>, _schema: &Schema) -> Result { + Ok(DataType::UInt64) +} + #[derive(Debug)] struct CountAccumulator { count: u64, @@ -827,11 
+879,6 @@ impl Accumulator for CountAccumulator { } } -/// Create a count expression -pub fn count(expr: Arc) -> Arc { - Arc::new(Count::new(expr)) -} - /// Invoke a compute kernel on a pair of binary data arrays macro_rules! compute_utf8_op { ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ @@ -964,11 +1011,13 @@ impl fmt::Display for BinaryExpr { } } -impl PhysicalExpr for BinaryExpr { - fn data_type(&self, input_schema: &Schema) -> Result { - self.left.data_type(input_schema) +impl DataTyped for BinaryExpr { + fn get_type(&self, input_schema: &Schema) -> Result { + self.left.get_type(input_schema) } +} +impl PhysicalExpr for BinaryExpr { fn nullable(&self, _input_schema: &Schema) -> Result { // binary operator should always return a boolean value Ok(false) @@ -1054,11 +1103,14 @@ impl fmt::Display for NotExpr { write!(f, "NOT {}", self.arg) } } -impl PhysicalExpr for NotExpr { - fn data_type(&self, _input_schema: &Schema) -> Result { + +impl DataTyped for NotExpr { + fn get_type(&self, _input_schema: &Schema) -> Result { return Ok(DataType::Boolean); } +} +impl PhysicalExpr for NotExpr { fn nullable(&self, _input_schema: &Schema) -> Result { // !Null == true Ok(false) @@ -1111,7 +1163,7 @@ impl CastExpr { input_schema: &Schema, cast_type: DataType, ) -> Result { - let expr_type = expr.data_type(input_schema)?; + let expr_type = expr.get_type(input_schema)?; // numbers can be cast to numbers and strings if is_numeric(&expr_type) && (is_numeric(&cast_type) || cast_type == DataType::Utf8) @@ -1138,11 +1190,13 @@ impl fmt::Display for CastExpr { } } -impl PhysicalExpr for CastExpr { - fn data_type(&self, _input_schema: &Schema) -> Result { +impl DataTyped for CastExpr { + fn get_type(&self, _input_schema: &Schema) -> Result { Ok(self.cast_type.clone()) } +} +impl PhysicalExpr for CastExpr { fn nullable(&self, input_schema: &Schema) -> Result { self.expr.nullable(input_schema) } @@ -1184,11 +1238,13 @@ impl fmt::Display for Literal { } } -impl PhysicalExpr for 
Literal { - fn data_type(&self, _input_schema: &Schema) -> Result { +impl DataTyped for Literal { + fn get_type(&self, _input_schema: &Schema) -> Result { self.value.get_datatype() } +} +impl PhysicalExpr for Literal { fn nullable(&self, _input_schema: &Schema) -> Result { match &self.value { ScalarValue::Null => Ok(true), @@ -1431,17 +1487,17 @@ mod tests { fn sum_contract() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let sum = sum(col("a")); - assert_eq!(DataType::Int64, sum.data_type(&schema)?); + let sum = physical_sum(col("a")); + assert_eq!(DataType::Int64, sum.get_type(&schema)?); // after the aggr expression is applied, the schema changes to: let schema = Schema::new(vec![ schema.field(0).clone(), - Field::new("SUM(a)", sum.data_type(&schema)?, false), + Field::new("SUM(a)", sum.get_type(&schema)?, false), ]); let combiner = sum.create_reducer("SUM(a)"); - assert_eq!(DataType::Int64, combiner.data_type(&schema)?); + assert_eq!(DataType::Int64, combiner.get_type(&schema)?); Ok(()) } @@ -1450,17 +1506,17 @@ mod tests { fn max_contract() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let max = max(col("a")); - assert_eq!(DataType::Int32, max.data_type(&schema)?); + let max = physical_max(col("a")); + assert_eq!(DataType::Int32, max.get_type(&schema)?); // after the aggr expression is applied, the schema changes to: let schema = Schema::new(vec![ schema.field(0).clone(), - Field::new("Max(a)", max.data_type(&schema)?, false), + Field::new("Max(a)", max.get_type(&schema)?, false), ]); let combiner = max.create_reducer("Max(a)"); - assert_eq!(DataType::Int32, combiner.data_type(&schema)?); + assert_eq!(DataType::Int32, combiner.get_type(&schema)?); Ok(()) } @@ -1469,16 +1525,16 @@ mod tests { fn min_contract() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let min = min(col("a")); - assert_eq!(DataType::Int32, 
min.data_type(&schema)?); + let min = physical_min(col("a")); + assert_eq!(DataType::Int32, min.get_type(&schema)?); // after the aggr expression is applied, the schema changes to: let schema = Schema::new(vec![ schema.field(0).clone(), - Field::new("MIN(a)", min.data_type(&schema)?, false), + Field::new("MIN(a)", min.get_type(&schema)?, false), ]); let combiner = min.create_reducer("MIN(a)"); - assert_eq!(DataType::Int32, combiner.data_type(&schema)?); + assert_eq!(DataType::Int32, combiner.get_type(&schema)?); Ok(()) } @@ -1486,17 +1542,17 @@ mod tests { fn avg_contract() -> Result<()> { let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); - let avg = avg(col("a")); - assert_eq!(DataType::Float64, avg.data_type(&schema)?); + let avg = physical_avg(col("a")); + assert_eq!(DataType::Float64, avg.get_type(&schema)?); // after the aggr expression is applied, the schema changes to: let schema = Schema::new(vec![ schema.field(0).clone(), - Field::new("SUM(a)", avg.data_type(&schema)?, false), + Field::new("AVG(a)", avg.get_type(&schema)?, false), ]); - let combiner = avg.create_reducer("SUM(a)"); - assert_eq!(DataType::Float64, combiner.data_type(&schema)?); + let combiner = avg.create_reducer("AVG(a)"); + assert_eq!(DataType::Float64, combiner.get_type(&schema)?); Ok(()) } @@ -1826,9 +1882,9 @@ mod tests { } fn do_sum(batch: &RecordBatch) -> Result> { - let sum = sum(col("a")); + let sum = physical_sum(col("a")); let accum = sum.create_accumulator(); - let input = sum.evaluate_input(batch)?; + let input = sum.evaluate(batch)?; let mut accum = accum.borrow_mut(); for i in 0..batch.num_rows() { accum.accumulate_scalar(get_scalar_value(&input, i)?)?; @@ -1837,9 +1893,9 @@ mod tests { } fn do_max(batch: &RecordBatch) -> Result> { - let max = max(col("a")); + let max = physical_max(col("a")); let accum = max.create_accumulator(); - let input = max.evaluate_input(batch)?; + let input = max.evaluate(batch)?; let mut accum = accum.borrow_mut(); for i in 
0..batch.num_rows() { accum.accumulate_scalar(get_scalar_value(&input, i)?)?; @@ -1848,9 +1904,9 @@ mod tests { } fn do_min(batch: &RecordBatch) -> Result> { - let min = min(col("a")); + let min = physical_min(col("a")); let accum = min.create_accumulator(); - let input = min.evaluate_input(batch)?; + let input = min.evaluate(batch)?; let mut accum = accum.borrow_mut(); for i in 0..batch.num_rows() { accum.accumulate_scalar(get_scalar_value(&input, i)?)?; @@ -1859,9 +1915,9 @@ mod tests { } fn do_count(batch: &RecordBatch) -> Result> { - let count = count(col("a")); + let count = physical_count(col("a")); let accum = count.create_accumulator(); - let input = count.evaluate_input(batch)?; + let input = count.evaluate(batch)?; let mut accum = accum.borrow_mut(); for i in 0..batch.num_rows() { accum.accumulate_scalar(get_scalar_value(&input, i)?)?; @@ -1870,9 +1926,9 @@ mod tests { } fn do_avg(batch: &RecordBatch) -> Result> { - let avg = avg(col("a")); + let avg = physical_avg(col("a")); let accum = avg.create_accumulator(); - let input = avg.evaluate_input(batch)?; + let input = avg.evaluate(batch)?; let mut accum = accum.borrow_mut(); for i in 0..batch.num_rows() { accum.accumulate_scalar(get_scalar_value(&input, i)?)?; diff --git a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs index d7366395ca4..41780921e45 100644 --- a/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs +++ b/rust/datafusion/src/execution/physical_plan/hash_aggregate.rs @@ -64,10 +64,10 @@ impl HashAggregateExec { let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len()); for (expr, name) in &group_expr { - fields.push(Field::new(name, expr.data_type(&input_schema)?, true)) + fields.push(Field::new(name, expr.get_type(&input_schema)?, true)) } for (expr, name) in &aggr_expr { - fields.push(Field::new(&name, expr.data_type(&input_schema)?, true)) + fields.push(Field::new(&name, 
expr.get_type(&input_schema)?, true)) } let schema = Arc::new(Schema::new(fields)); @@ -268,7 +268,7 @@ impl RecordBatchReader for GroupedHashAggregateIterator { .aggr_expr .iter() .map(|expr| { - expr.evaluate_input(&batch) + expr.evaluate(&batch) .map_err(ExecutionError::into_arrow_external_error) }) .collect::>>()?; @@ -433,7 +433,7 @@ impl RecordBatchReader for HashAggregateIterator { .aggr_expr .iter() .map(|expr| { - expr.evaluate_input(&batch) + expr.evaluate(&batch) .map_err(ExecutionError::into_arrow_external_error) }) .collect::>>()?; @@ -459,7 +459,7 @@ impl RecordBatchReader for HashAggregateIterator { // aggregate values for i in 0..self.aggr_expr.len() { let aggr_data_type = self.aggr_expr[i] - .data_type(&input_schema) + .get_type(&input_schema) .map_err(ExecutionError::into_arrow_external_error)?; let value = accumulators[i] .borrow_mut() @@ -698,7 +698,7 @@ mod tests { use super::*; use crate::execution::physical_plan::csv::{CsvExec, CsvReadOptions}; - use crate::execution::physical_plan::expressions::{col, sum}; + use crate::execution::physical_plan::expressions::{col, physical_sum}; use crate::execution::physical_plan::merge::MergeExec; use crate::test; @@ -716,7 +716,7 @@ mod tests { vec![(col("c2"), "c2".to_string())]; let aggregates: Vec<(Arc, String)> = - vec![(sum(col("c4")), "SUM(c4)".to_string())]; + vec![(physical_sum(col("c4")), "SUM(c4)".to_string())]; let partition_aggregate = HashAggregateExec::try_new( groups.clone(), diff --git a/rust/datafusion/src/execution/physical_plan/math_expressions.rs b/rust/datafusion/src/execution/physical_plan/math_expressions.rs index 97098d65b5e..aea348d255f 100644 --- a/rust/datafusion/src/execution/physical_plan/math_expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/math_expressions.rs @@ -20,36 +20,62 @@ use crate::error::ExecutionError; use crate::execution::physical_plan::udf::ScalarFunction; -use arrow::array::{Array, ArrayRef, Float64Array, Float64Builder}; -use 
arrow::datatypes::{DataType, Field}; +use arrow::array::{Array, ArrayRef}; +use arrow::array::{Float32Array, Float64Array}; +use arrow::datatypes::DataType; use std::sync::Arc; +macro_rules! compute_op { + ($ARRAY:expr, $FUNC:ident, $TYPE:ident) => {{ + let mut builder = <$TYPE>::builder($ARRAY.len()); + for i in 0..$ARRAY.len() { + if $ARRAY.is_null(i) { + builder.append_null()?; + } else { + builder.append_value($ARRAY.value(i).$FUNC())?; + } + } + Ok(Arc::new(builder.finish())) + }}; +} + +macro_rules! downcast_compute_op { + ($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident) => {{ + let n = $ARRAY.as_any().downcast_ref::<$TYPE>(); + match n { + Some(array) => compute_op!(array, $FUNC, $TYPE), + _ => Err(ExecutionError::General(format!( + "Invalid data type for {}", + $NAME + ))), + } + }}; +} + +macro_rules! unary_primitive_array_op { + ($ARRAY:expr, $NAME:expr, $FUNC:ident) => {{ + match ($ARRAY).data_type() { + DataType::Float32 => downcast_compute_op!($ARRAY, $NAME, $FUNC, Float32Array), + DataType::Float64 => downcast_compute_op!($ARRAY, $NAME, $FUNC, Float64Array), + other => Err(ExecutionError::General(format!( + "Unsupported data type {:?} for function {}", + other, $NAME, + ))), + } + }}; +} + macro_rules! 
math_unary_function { ($NAME:expr, $FUNC:ident) => { ScalarFunction::new( $NAME, - vec![Field::new("n", DataType::Float64, true)], - DataType::Float64, + // order: from faster to slower + vec![vec![DataType::Float32], vec![DataType::Float64]], + Arc::new(|expr, schema| expr[0].get_type(schema)), Arc::new(|args: &[ArrayRef]| { - let n = &args[0].as_any().downcast_ref::(); - match n { - Some(array) => { - let mut builder = Float64Builder::new(array.len()); - for i in 0..array.len() { - if array.is_null(i) { - builder.append_null()?; - } else { - builder.append_value(array.value(i).$FUNC())?; - } - } - Ok(Arc::new(builder.finish())) - } - _ => Err(ExecutionError::General(format!( - "Invalid data type for {}", - $NAME - ))), - } + let array = &args[0]; + unary_primitive_array_op!(array, $NAME, $FUNC) }), ) }; @@ -86,7 +112,7 @@ mod tests { execution::context::ExecutionContext, logicalplan::{col, sqrt, LogicalPlanBuilder}, }; - use arrow::datatypes::Schema; + use arrow::datatypes::{Field, Schema}; #[test] fn cast_i8_input() -> Result<()> { @@ -96,7 +122,13 @@ mod tests { .build()?; let ctx = ExecutionContext::new(); let plan = ctx.optimize(&plan)?; - let expected = "Projection: sqrt(CAST(#c0 AS Float64))\ + + assert_eq!( + *plan.schema().field_with_name("sqrt(c0)")?.data_type(), + DataType::Float32 + ); + + let expected = "Projection: sqrt(CAST(#c0 AS Float32))\ \n TableScan: projection=Some([0])"; assert_eq!(format!("{:?}", plan), expected); Ok(()) @@ -110,6 +142,32 @@ mod tests { .build()?; let ctx = ExecutionContext::new(); let plan = ctx.optimize(&plan)?; + + assert_eq!( + *plan.schema().field_with_name("sqrt(c0)")?.data_type(), + DataType::Float64 + ); + + let expected = "Projection: sqrt(#c0)\ + \n TableScan: projection=Some([0])"; + assert_eq!(format!("{:?}", plan), expected); + Ok(()) + } + + #[test] + fn no_cast_f32_input() -> Result<()> { + let schema = Schema::new(vec![Field::new("c0", DataType::Float32, true)]); + let plan = LogicalPlanBuilder::scan("", "", 
&schema, None)? + .project(vec![sqrt(col("c0"))])? + .build()?; + let ctx = ExecutionContext::new(); + let plan = ctx.optimize(&plan)?; + + assert_eq!( + *plan.schema().field_with_name("sqrt(c0)")?.data_type(), + DataType::Float32 + ); + let expected = "Projection: sqrt(#c0)\ \n TableScan: projection=Some([0])"; assert_eq!(format!("{:?}", plan), expected); diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs index 4b66955d036..40fc04a1244 100644 --- a/rust/datafusion/src/execution/physical_plan/mod.rs +++ b/rust/datafusion/src/execution/physical_plan/mod.rs @@ -22,11 +22,12 @@ use std::fmt::{Debug, Display}; use std::rc::Rc; use std::sync::{Arc, Mutex}; +use crate::datatyped::DataTyped; use crate::error::Result; use crate::execution::context::ExecutionContextState; use crate::logicalplan::{LogicalPlan, ScalarValue}; use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow::datatypes::{DataType, Schema, SchemaRef}; use arrow::{ compute::kernels::length::length, record_batch::{RecordBatch, RecordBatchReader}, @@ -59,21 +60,16 @@ pub trait Partition: Send + Sync + Debug { /// Expression that can be evaluated against a RecordBatch /// A Physical expression knows its type, nullability and how to evaluate itself. 
-pub trait PhysicalExpr: Send + Sync + Display + Debug { - /// Get the data type of this expression, given the schema of the input - fn data_type(&self, input_schema: &Schema) -> Result; +pub trait PhysicalExpr: Send + Sync + Display + Debug + DataTyped { /// Decide whehter this expression is nullable, given the schema of the input fn nullable(&self, input_schema: &Schema) -> Result; /// Evaluate an expression against a RecordBatch fn evaluate(&self, batch: &RecordBatch) -> Result; } -/// Aggregate expression that can be evaluated against a RecordBatch -pub trait AggregateExpr: Send + Sync + Debug { - /// Get the data type of this expression, given the schema of the input - fn data_type(&self, input_schema: &Schema) -> Result; - /// Evaluate the expression being aggregated - fn evaluate_input(&self, batch: &RecordBatch) -> Result; +/// An aggregators knows how to accumulate arrays in parts, so that the array does not have +/// to be all available in memory. This type of aggregation is also known as online update. +pub trait Aggregator: Send + Sync + Debug { /// Create an accumulator for this aggregate expression fn create_accumulator(&self) -> Rc>; /// Create an aggregate expression for combining the results of accumulators from partitions. @@ -82,7 +78,12 @@ pub trait AggregateExpr: Send + Sync + Debug { fn create_reducer(&self, column_name: &str) -> Arc; } -/// Aggregate accumulator +/// Aggregate expression that can be evaluated against a RecordBatch +pub trait AggregateExpr: PhysicalExpr + Aggregator {} +impl AggregateExpr for T {} + +/// An accumulator knows how compute aggregations without full access to the complete array. +/// This is also known as online updates. 
pub trait Accumulator: Debug { /// Update the accumulator based on a row in a batch fn accumulate_scalar(&mut self, value: Option) -> Result<()>; @@ -96,14 +97,19 @@ pub trait Accumulator: Debug { pub fn scalar_functions() -> Vec { let mut udfs = vec![ScalarFunction::new( "length", - vec![Field::new("n", DataType::Utf8, true)], - DataType::UInt32, + vec![vec![DataType::Utf8]], + Arc::new(|_, _| Ok(DataType::UInt32)), Arc::new(|args: &[ArrayRef]| Ok(Arc::new(length(args[0].as_ref())?))), )]; udfs.append(&mut math_expressions::scalar_functions()); udfs } +/// Vector of aggregate functions declared in this module +pub fn aggregate_functions() -> Vec { + expressions::aggregate_functions() +} + pub mod common; pub mod csv; pub mod datasource; diff --git a/rust/datafusion/src/execution/physical_plan/planner.rs b/rust/datafusion/src/execution/physical_plan/planner.rs index 7ce845bf434..058e110e332 100644 --- a/rust/datafusion/src/execution/physical_plan/planner.rs +++ b/rust/datafusion/src/execution/physical_plan/planner.rs @@ -19,13 +19,14 @@ use std::sync::{Arc, Mutex}; +use super::udf::AggregateFunctionExpr; use crate::error::{ExecutionError, Result}; use crate::execution::context::ExecutionContextState; use crate::execution::physical_plan::csv::{CsvExec, CsvReadOptions}; use crate::execution::physical_plan::datasource::DatasourceExec; use crate::execution::physical_plan::explain::ExplainExec; use crate::execution::physical_plan::expressions::{ - Avg, BinaryExpr, CastExpr, Column, Count, Literal, Max, Min, PhysicalSortExpr, Sum, + BinaryExpr, CastExpr, Column, Literal, PhysicalSortExpr, }; use crate::execution::physical_plan::hash_aggregate::HashAggregateExec; use crate::execution::physical_plan::limit::GlobalLimitExec; @@ -354,11 +355,7 @@ impl PhysicalPlannerImpl { input_schema, data_type.clone(), )?)), - Expr::ScalarFunction { - name, - args, - return_type, - } => match ctx_state + Expr::ScalarFunction { name, args, .. 
} => match ctx_state .lock() .expect("failed to lock mutex") .scalar_functions @@ -379,7 +376,7 @@ impl PhysicalPlannerImpl { name, Box::new(f.fun.clone()), physical_args, - return_type, + f.return_type.clone(), ))) } _ => Err(ExecutionError::General(format!( @@ -403,35 +400,32 @@ impl PhysicalPlannerImpl { ) -> Result> { match e { Expr::AggregateFunction { name, args, .. } => { - match name.to_lowercase().as_ref() { - "sum" => Ok(Arc::new(Sum::new(self.create_physical_expr( - &args[0], - input_schema, - ctx_state.clone(), - )?))), - "avg" => Ok(Arc::new(Avg::new(self.create_physical_expr( - &args[0], - input_schema, - ctx_state.clone(), - )?))), - "max" => Ok(Arc::new(Max::new(self.create_physical_expr( - &args[0], - input_schema, - ctx_state.clone(), - )?))), - "min" => Ok(Arc::new(Min::new(self.create_physical_expr( - &args[0], - input_schema, - ctx_state.clone(), - )?))), - "count" => Ok(Arc::new(Count::new(self.create_physical_expr( - &args[0], - input_schema, - ctx_state.clone(), - )?))), - other => Err(ExecutionError::NotImplemented(format!( - "Unsupported aggregate function '{}'", - other + match &ctx_state + .lock() + .expect("failed to lock mutex") + .aggregate_functions + .lock() + .expect("failed to lock mutex") + .get(name) + { + Some(f) => { + let mut physical_args = vec![]; + for e in args { + physical_args.push(self.create_physical_expr( + e, + input_schema, + ctx_state.clone(), + )?); + } + Ok(Arc::new(AggregateFunctionExpr::new( + name, + physical_args, + Box::new(f.as_ref().clone()), + ))) + } + _ => Err(ExecutionError::General(format!( + "Invalid aggregate function '{:?}'", + name ))), } } diff --git a/rust/datafusion/src/execution/physical_plan/projection.rs b/rust/datafusion/src/execution/physical_plan/projection.rs index a5ad0ef3e03..3ce3f637904 100644 --- a/rust/datafusion/src/execution/physical_plan/projection.rs +++ b/rust/datafusion/src/execution/physical_plan/projection.rs @@ -52,7 +52,7 @@ impl ProjectionExec { .map(|(e, name)| { 
Ok(Field::new( name, - e.data_type(&input_schema)?, + e.get_type(&input_schema)?, e.nullable(&input_schema)?, )) }) diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs b/rust/datafusion/src/execution/physical_plan/udf.rs index 944b5c9bef3..d43d13acbb2 100644 --- a/rust/datafusion/src/execution/physical_plan/udf.rs +++ b/rust/datafusion/src/execution/physical_plan/udf.rs @@ -20,27 +20,35 @@ use std::fmt; use arrow::array::ArrayRef; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Schema}; use crate::error::Result; -use crate::execution::physical_plan::PhysicalExpr; +use crate::{datatyped::DataTyped, execution::physical_plan::PhysicalExpr}; +use super::{Accumulator, AggregateExpr, Aggregator}; use arrow::record_batch::RecordBatch; use fmt::{Debug, Formatter}; -use std::sync::Arc; +use std::{cell::RefCell, rc::Rc, sync::Arc}; /// Scalar UDF pub type ScalarUdf = Arc Result + Send + Sync>; +/// Function to construct the return type of a function given its arguments. +pub type ReturnType = + Arc, &Schema) -> Result + Send + Sync>; + /// Scalar UDF Expression #[derive(Clone)] pub struct ScalarFunction { /// Function name pub name: String, - /// Function argument meta-data - pub args: Vec, + /// Set of valid argument types. + /// The first dimension (0) represents specific combinations of valid argument types + /// The second dimension (1) represents the types of each argument. 
+ /// For example, [[t1, t2]] is a function of 2 arguments that only accept t1 on the first arg and t2 on the second + pub arg_types: Vec>, /// Return type - pub return_type: DataType, + pub return_type: ReturnType, /// UDF implementation pub fun: ScalarUdf, } @@ -49,8 +57,8 @@ impl Debug for ScalarFunction { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("ScalarFunction") .field("name", &self.name) - .field("args", &self.args) - .field("return_type", &self.return_type) + .field("arg_types", &self.arg_types) + //.field("return_type", &self.return_type) .field("fun", &"") .finish() } @@ -60,13 +68,13 @@ impl ScalarFunction { /// Create a new ScalarFunction pub fn new( name: &str, - args: Vec, - return_type: DataType, + arg_types: Vec>, + return_type: ReturnType, fun: ScalarUdf, ) -> Self { Self { name: name.to_owned(), - args, + arg_types, return_type, fun, } @@ -78,7 +86,7 @@ pub struct ScalarFunctionExpr { fun: Box, name: String, args: Vec>, - return_type: DataType, + return_type: ReturnType, } impl Debug for ScalarFunctionExpr { @@ -87,7 +95,7 @@ impl Debug for ScalarFunctionExpr { .field("fun", &"") .field("name", &self.name) .field("args", &self.args) - .field("return_type", &self.return_type) + //.field("return_type", &self.return_type) .finish() } } @@ -98,7 +106,7 @@ impl ScalarFunctionExpr { name: &str, fun: Box, args: Vec>, - return_type: &DataType, + return_type: ReturnType, ) -> Self { Self { fun, @@ -124,11 +132,15 @@ impl fmt::Display for ScalarFunctionExpr { } } -impl PhysicalExpr for ScalarFunctionExpr { - fn data_type(&self, _input_schema: &Schema) -> Result { - Ok(self.return_type.clone()) +impl DataTyped for ScalarFunctionExpr { + fn get_type(&self, input_schema: &Schema) -> Result { + let x = self.args.clone(); + let x = x.iter().map(|x| x.as_datatyped()).collect::>(); + (self.return_type)(&x, input_schema) } +} +impl PhysicalExpr for ScalarFunctionExpr { fn nullable(&self, _input_schema: &Schema) -> Result { Ok(true) } @@ 
-146,3 +158,104 @@ impl PhysicalExpr for ScalarFunctionExpr { (fun)(&inputs) } } + +/// A generic aggregate function +/* +An aggregate function accepts an arbitrary number of arguments, of arbitrary data types, +and returns an arbitrary type based on the incoming types. + +It is the developer of the function's responsibility to ensure that the aggregator correctly handles the different +types that are presented to them, and that the return type correctly matches the type returned by the +aggregator. + +It is the user of the function's responsibility to pass arguments to the function that have valid types. +*/ +#[derive(Clone)] +pub struct AggregateFunction { + /// Function name + pub name: String, + /// A list of arguments and their respective types. A function can accept more than one type as argument + /// (e.g. sum(i8), sum(u8)). + pub arg_types: Vec>, + /// Return type. This function takes + pub return_type: ReturnType, + /// implementation of the aggregation + pub aggregate: Arc, +} + +/// An aggregate function physical expression +pub struct AggregateFunctionExpr { + name: String, + fun: Box, + // for now, our AggregateFunctionExpr accepts a single element only. 
+ arg: Arc, +} + +impl AggregateFunctionExpr { + /// Create a new AggregateFunctionExpr + pub fn new( + name: &str, + args: Vec>, + fun: Box, + ) -> Self { + Self { + name: name.to_owned(), + arg: args[0].clone(), + fun, + } + } +} + +impl Debug for AggregateFunctionExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("AggregateFunctionExpr") + .field("fun", &"") + .field("name", &self.name) + .field("args", &self.arg) + .finish() + } +} + +impl DataTyped for AggregateFunctionExpr { + fn get_type(&self, input_schema: &Schema) -> Result { + self.fun.as_ref().return_type.as_ref()( + &vec![self.arg.as_datatyped()], + input_schema, + ) + } +} + +impl PhysicalExpr for AggregateFunctionExpr { + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(false) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + self.arg.evaluate(batch) + } +} + +impl Aggregator for AggregateFunctionExpr { + fn create_accumulator(&self) -> Rc> { + self.fun.aggregate.create_accumulator() + } + + fn create_reducer(&self, column_name: &str) -> Arc { + self.fun.aggregate.create_reducer(column_name) + } +} + +impl fmt::Display for AggregateFunctionExpr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}({})", + self.name, + [&self.arg] + .iter() + .map(|e| format!("{}", e)) + .collect::>() + .join(", ") + ) + } +} diff --git a/rust/datafusion/src/lib.rs b/rust/datafusion/src/lib.rs index 73897eeaedf..0fb0e184fa8 100644 --- a/rust/datafusion/src/lib.rs +++ b/rust/datafusion/src/lib.rs @@ -31,6 +31,7 @@ extern crate sqlparser; pub mod dataframe; pub mod datasource; +mod datatyped; pub mod error; pub mod execution; pub mod logicalplan; diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index ee3c17aedd2..bd3f66595e9 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -25,12 +25,16 @@ use std::{fmt, sync::Arc}; use arrow::datatypes::{DataType, Field, Schema}; 
-use crate::datasource::csv::{CsvFile, CsvReadOptions}; +use crate::datasource::csv::CsvFile; use crate::datasource::parquet::ParquetTable; -use crate::datasource::TableProvider; +use crate::datasource::{CsvReadOptions, TableProvider}; use crate::error::{ExecutionError, Result}; use crate::optimizer::utils; -use crate::sql::parser::FileType; +use crate::{ + datatyped::{AsDataTyped, DataTyped}, + execution::physical_plan::udf::ReturnType, + sql::parser::FileType, +}; use arrow::record_batch::RecordBatch; /// Enumeration of supported function types (Scalar and Aggregate) @@ -43,29 +47,29 @@ pub enum FunctionType { } /// Logical representation of a UDF (user-defined function) -#[derive(Debug, Clone)] +#[derive(Clone)] pub struct FunctionMeta { /// Function name name: String, - /// Function arguments - args: Vec, + /// Function argument types + arg_types: Vec>, /// Function return type - return_type: DataType, + return_type: ReturnType, /// Function type (Scalar or Aggregate) function_type: FunctionType, } impl FunctionMeta { - #[allow(missing_docs)] + /// constructs a new FunctionMeta pub fn new( name: String, - args: Vec, - return_type: DataType, + arg_types: Vec>, + return_type: ReturnType, function_type: FunctionType, ) -> Self { FunctionMeta { name, - args, + arg_types, return_type, function_type, } @@ -75,11 +79,11 @@ impl FunctionMeta { &self.name } /// Getter for the arg list - pub fn args(&self) -> &Vec { - &self.args + pub fn arg_types(&self) -> &Vec> { + &self.arg_types } /// Getter for the `DataType` the function returns - pub fn return_type(&self) -> &DataType { + pub fn return_type(&self) -> &ReturnType { &self.return_type } /// Getter for the `FunctionType` @@ -271,12 +275,8 @@ pub fn expr_to_field(e: &Expr, input_schema: &Schema) -> Result { Expr::Alias(expr, ..) 
=> expr.get_type(input_schema), Expr::Column(name) => Ok(input_schema.field_with_name(name)?.data_type().clone()), Expr::Literal(ref lit) => lit.get_datatype(), - Expr::ScalarFunction { - ref return_type, .. - } => Ok(return_type.clone()), - Expr::AggregateFunction { - ref return_type, .. - } => Ok(return_type.clone()), + Expr::ScalarFunction { .. } => e.get_type(&input_schema), + Expr::AggregateFunction { .. } => e.get_type(&input_schema), Expr::Cast { ref data_type, .. } => Ok(data_type.clone()), Expr::BinaryExpr { ref left, @@ -307,7 +307,7 @@ pub fn exprlist_to_fields(expr: &[Expr], input_schema: &Schema) -> Result, String), @@ -355,7 +355,7 @@ pub enum Expr { /// List of expressions to feed to the functions as arguments args: Vec, /// The `DataType` the expression will yield - return_type: DataType, + return_type: ReturnType, }, /// aggregate function AggregateFunction { @@ -364,22 +364,25 @@ pub enum Expr { /// List of expressions to feed to the functions as arguments args: Vec, /// The `DataType` the expression will yield - return_type: DataType, + return_type: ReturnType, }, /// Wildcard Wildcard, } -impl Expr { - /// Find the `DataType` for the expression - pub fn get_type(&self, schema: &Schema) -> Result { +impl DataTyped for Expr { + fn get_type(&self, schema: &Schema) -> Result { match self { Expr::Alias(expr, _) => expr.get_type(schema), Expr::Column(name) => Ok(schema.field_with_name(name)?.data_type().clone()), Expr::Literal(l) => l.get_datatype(), Expr::Cast { data_type, .. } => Ok(data_type.clone()), - Expr::ScalarFunction { return_type, .. } => Ok(return_type.clone()), - Expr::AggregateFunction { return_type, .. } => Ok(return_type.clone()), + Expr::ScalarFunction { + args, return_type, .. + } => return_type(&args.iter().map(|x| x.as_datatyped()).collect(), schema), + Expr::AggregateFunction { + args, return_type, .. 
+ } => return_type(&args.iter().map(|x| x.as_datatyped()).collect(), schema), Expr::Not(_) => Ok(DataType::Boolean), Expr::IsNull(_) => Ok(DataType::Boolean), Expr::IsNotNull(_) => Ok(DataType::Boolean), @@ -407,7 +410,9 @@ impl Expr { Expr::Nested(e) => e.get_type(schema), } } +} +impl Expr { /// Return the name of this expression /// /// This represents how a column with this expression is named when no alias is chosen @@ -565,7 +570,7 @@ macro_rules! unary_math_expr { ($NAME:expr, $FUNC:ident) => { #[allow(missing_docs)] pub fn $FUNC(e: Expr) -> Expr { - scalar_function($NAME, vec![e], DataType::Float64) + scalar_function($NAME, vec![e], Arc::new(|e, schema| e[0].get_type(&schema))) } }; } @@ -591,11 +596,11 @@ unary_math_expr!("log10", log10); /// returns the length of a string in bytes pub fn length(e: Expr) -> Expr { - scalar_function("length", vec![e], DataType::UInt32) + scalar_function("length", vec![e], Arc::new(|_, _| Ok(DataType::UInt32))) } /// Create an aggregate expression -pub fn aggregate_expr(name: &str, expr: Expr, return_type: DataType) -> Expr { +pub fn aggregate_expr(name: &str, expr: Expr, return_type: ReturnType) -> Expr { Expr::AggregateFunction { name: name.to_owned(), args: vec![expr], @@ -603,8 +608,8 @@ pub fn aggregate_expr(name: &str, expr: Expr, return_type: DataType) -> Expr { } } -/// Create an aggregate expression -pub fn scalar_function(name: &str, expr: Vec, return_type: DataType) -> Expr { +/// Create an scalar function expression +pub fn scalar_function(name: &str, expr: Vec, return_type: ReturnType) -> Expr { Expr::ScalarFunction { name: name.to_owned(), args: expr, @@ -612,6 +617,15 @@ pub fn scalar_function(name: &str, expr: Vec, return_type: DataType) -> Ex } } +/// Create an aggregate expression +pub fn aggregate_function(name: &str, expr: Vec, return_type: ReturnType) -> Expr { + Expr::AggregateFunction { + name: name.to_owned(), + args: expr, + return_type, + } +} + impl fmt::Debug for Expr { fn fmt(&self, f: &mut 
fmt::Formatter) -> fmt::Result { match self { @@ -1092,7 +1106,7 @@ impl LogicalPlanBuilder { /// Apply a projection pub fn project(&self, expr: Vec) -> Result { let input_schema = self.plan.schema(); - let projected_expr = if expr.contains(&Expr::Wildcard) { + let projected_expr = { let mut expr_vec = vec![]; (0..expr.len()).for_each(|i| match &expr[i] { Expr::Wildcard => { @@ -1102,8 +1116,6 @@ impl LogicalPlanBuilder { _ => expr_vec.push(expr[i].clone()), }); expr_vec - } else { - expr.clone() }; let schema = @@ -1268,8 +1280,12 @@ mod tests { )? .aggregate( vec![col("state")], - vec![aggregate_expr("SUM", col("salary"), DataType::Int32) - .alias("total_salary")], + vec![aggregate_expr( + "SUM", + col("salary"), + Arc::new(|_, _| Ok(DataType::Int32)), + ) + .alias("total_salary")], )? .project(vec![col("state"), col("total_salary")])? .build()?; diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index a485423975d..98076f47a80 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -23,14 +23,14 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; -use arrow::datatypes::Schema; +use arrow::datatypes::{DataType, Schema}; use crate::error::{ExecutionError, Result}; -use crate::execution::physical_plan::udf::ScalarFunction; +use crate::execution::physical_plan::udf::{AggregateFunction, ScalarFunction}; use crate::logicalplan::Expr; use crate::logicalplan::LogicalPlan; use crate::optimizer::optimizer::OptimizerRule; -use crate::optimizer::utils; +use crate::{datatyped::DataTyped, optimizer::utils}; use utils::optimize_explain; /// Optimizer that applies coercion rules to expressions in the logical plan. @@ -38,6 +38,7 @@ use utils::optimize_explain; /// This optimizer does not alter the structure of the plan, it only changes expressions on it. 
pub struct TypeCoercionRule { scalar_functions: Arc>>>, + aggregate_functions: Arc>>>, } impl TypeCoercionRule { @@ -45,8 +46,12 @@ impl TypeCoercionRule { /// scalar functions pub fn new( scalar_functions: Arc>>>, + aggregate_functions: Arc>>>, ) -> Self { - Self { scalar_functions } + Self { + scalar_functions, + aggregate_functions, + } } /// Rewrite an expression to include explicit CAST operations when required @@ -71,6 +76,50 @@ impl TypeCoercionRule { expressions[1] = expressions[1].cast_to(&super_type, schema)?; } } + Expr::AggregateFunction { name, .. } => { + match self + .aggregate_functions + .lock() + .expect("failed to lock mutex") + .get(name) + { + Some(func_meta) => { + // compute the current types and expressions + let current_types = expressions + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + + let new = if func_meta.arg_types.contains(¤t_types) { + Some(expressions) + } else { + maybe_rewrite( + &expressions, + ¤t_types, + &schema, + &func_meta.arg_types, + )? + }; + + if let Some(args) = new { + expressions = args; + } else { + return Err(ExecutionError::General(format!( + "The scalar function '{}' requires one of the type variants {:?}, but the arguments of type '{:?}' cannot be safely casted to any of them.", + func_meta.name, + func_meta.arg_types, + current_types, + ))); + } + } + None => { + return Err(ExecutionError::General(format!( + "Invalid aggregate function {}", + name + ))) + } + } + } Expr::ScalarFunction { name, .. 
} => { // cast the inputs of scalar functions to the appropriate type where possible match self @@ -80,17 +129,8 @@ impl TypeCoercionRule { .get(name) { Some(func_meta) => { - for i in 0..expressions.len() { - let field = &func_meta.args[i]; - let actual_type = expressions[i].get_type(schema)?; - let required_type = field.data_type(); - if &actual_type != required_type { - let super_type = - utils::get_supertype(&actual_type, required_type)?; - expressions[i] = - expressions[i].cast_to(&super_type, schema)? - }; - } + expressions = + rewrite_args(expressions, schema, func_meta.as_ref())?; } _ => { return Err(ExecutionError::General(format!( @@ -147,6 +187,78 @@ impl OptimizerRule for TypeCoercionRule { } } +/// rewrites +fn rewrite_args( + expressions: Vec, + schema: &Schema, + func_meta: &ScalarFunction, +) -> Result> { + // compute the current types and expressions + let current_types = expressions + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + + let new = if func_meta.arg_types.contains(¤t_types) { + Some(expressions) + } else { + maybe_rewrite(&expressions, ¤t_types, &schema, &func_meta.arg_types)? 
+ }; + + if let Some(args) = new { + Ok(args) + } else { + Err(ExecutionError::General(format!( + "The scalar function '{}' requires one of the type variants {:?}, but the arguments of type '{:?}' cannot be safely casted to any of them.", + func_meta.name, + func_meta.arg_types, + current_types, + ))) + } +} + +/// tries to re-cast expressions under schema based on the set of valid signatures +fn maybe_rewrite( + expressions: &Vec, + current_types: &Vec, + schema: &Schema, + signature: &Vec>, +) -> Result>> { + // for each set of valid signatures, try to coerce all expressions to one of them + for valid_types in signature { + // for each option, try to coerce all arguments to it + if let Some(types) = maybe_coerce(valid_types, ¤t_types) { + // yes: let's re-write the expressions + return Ok(Some( + expressions + .iter() + .enumerate() + .map(|(i, expr)| expr.cast_to(&types[i], schema)) + .collect::>>()?, + )); + } + // we cannot: try the next + } + Ok(None) +} + +/// Try to coerce current_types into valid_types +fn maybe_coerce( + valid_types: &Vec, + current_types: &Vec, +) -> Option> { + let mut super_type = Vec::with_capacity(valid_types.len()); + for (i, valid_type) in valid_types.iter().enumerate() { + let current_type = ¤t_types[i]; + if let Ok(t) = utils::get_supertype(current_type, valid_type) { + super_type.push(t) + } else { + return None; + } + } + Some(super_type) +} + #[cfg(test)] mod tests { use super::*; @@ -155,6 +267,7 @@ mod tests { use crate::logicalplan::{aggregate_expr, col, lit, LogicalPlanBuilder, Operator}; use crate::test::arrow_testdata_path; use arrow::datatypes::{DataType, Field, Schema}; + use std::sync::Arc; #[test] fn test_all_operators() -> Result<()> { @@ -168,14 +281,19 @@ mod tests { .project(vec![col("c1"), col("c2")])? .aggregate( vec![col("c1")], - vec![aggregate_expr("SUM", col("c2"), DataType::Int64)], + vec![aggregate_expr( + "sum", + col("c2"), + Arc::new(|_, _| Ok(DataType::Int64)), + )], )? .sort(vec![col("c1")])? 
.limit(10)? .build()?; - let scalar_functions = HashMap::new(); - let mut rule = TypeCoercionRule::new(Arc::new(Mutex::new(scalar_functions))); + let ctx = ExecutionContext::new(); + let mut rule = + TypeCoercionRule::new(ctx.scalar_functions(), ctx.aggregate_functions()); let plan = rule.optimize(&plan)?; // check that the filter had a cast added @@ -183,7 +301,7 @@ mod tests { println!("{}", plan_str); let expected_plan_str = "Limit: 10 Sort: #c1 - Aggregate: groupBy=[[#c1]], aggr=[[SUM(#c2)]] + Aggregate: groupBy=[[#c1]], aggr=[[sum(#c2)]] Projection: #c1, #c2 Selection: #c7 Lt CAST(UInt8(5) AS Int64)"; assert!(plan_str.starts_with(expected_plan_str)); @@ -201,8 +319,10 @@ mod tests { .filter(col("c7").lt(col("c12")))? .build()?; - let scalar_functions = HashMap::new(); - let mut rule = TypeCoercionRule::new(Arc::new(Mutex::new(scalar_functions))); + let mut rule = TypeCoercionRule::new( + Arc::new(Mutex::new(HashMap::new())), + Arc::new(Mutex::new(HashMap::new())), + ); let plan = rule.optimize(&plan)?; assert!( @@ -281,10 +401,150 @@ mod tests { }; let ctx = ExecutionContext::new(); - let rule = TypeCoercionRule::new(ctx.scalar_functions()); + let rule = + TypeCoercionRule::new(ctx.scalar_functions(), ctx.aggregate_functions()); let expr2 = rule.rewrite_expr(&expr, &schema).unwrap(); assert_eq!(expected, format!("{:?}", expr2)); } + + #[test] + fn test_maybe_coerce() -> Result<()> { + // this vec contains: arg1, arg2, expected result + let cases = vec![ + // 2 entries, same values + ( + vec![DataType::UInt8, DataType::UInt16], + vec![DataType::UInt8, DataType::UInt16], + Some(vec![DataType::UInt8, DataType::UInt16]), + ), + // 2 entries, can coerse values + ( + vec![DataType::UInt16, DataType::UInt16], + vec![DataType::UInt8, DataType::UInt16], + Some(vec![DataType::UInt16, DataType::UInt16]), + ), + // 0 entries, all good + (vec![], vec![], Some(vec![])), + // 2 entries, can't coerce + ( + vec![DataType::Boolean, DataType::UInt16], + vec![DataType::UInt8, 
DataType::UInt16], + None, + ), + // u32 -> u16 is possible + ( + vec![DataType::Boolean, DataType::UInt32], + vec![DataType::Boolean, DataType::UInt16], + Some(vec![DataType::Boolean, DataType::UInt32]), + ), + ]; + + for case in cases { + assert_eq!(maybe_coerce(&case.0, &case.1), case.2) + } + Ok(()) + } + + #[test] + fn test_maybe_rewrite() -> Result<()> { + // create a schema + let schema = |t: Vec| { + Schema::new( + t.iter() + .enumerate() + .map(|(i, t)| Field::new(&*format!("c{}", i), t.clone(), true)) + .collect(), + ) + }; + + // create a vector of expressions + let expressions = |t: Vec, schema| -> Result> { + t.iter() + .enumerate() + .map(|(i, t)| col(&*format!("c{}", i)).cast_to(&t, &schema)) + .collect::>>() + }; + + // map expr + schema to types + let current_types = |expressions: &Vec, schema| -> Result> { + Ok(expressions + .iter() + .map(|e| e.get_type(&schema)) + .collect::>>()?) + }; + + // create a case: input + expected result + let case = |observed: Vec, + valid, + expected: Option>| + -> Result<_> { + let schema = schema(observed.clone()); + let expr = expressions(observed, schema.clone())?; + let expected = if let Some(e) = expected { + // expressions re-written as cast + Some(expressions(e, schema.clone())?) 
+ } else { + None + }; + Ok(( + expr.clone(), + current_types(&expr, schema.clone())?, + schema, + valid, + expected, + )) + }; + + let cases = vec![ + // no conversion -> all good + case(vec![], vec![vec![]], Some(vec![]))?, + // u16 -> u32 + case( + vec![DataType::UInt16, DataType::UInt32], + vec![vec![DataType::UInt32, DataType::UInt32]], + Some(vec![DataType::UInt32, DataType::UInt32]), + )?, + // same type + case( + vec![DataType::UInt16, DataType::UInt32], + vec![vec![DataType::UInt16, DataType::UInt32]], + Some(vec![DataType::UInt16, DataType::UInt32]), + )?, + // we do not know how to cast bool to UInt16 => fail + case( + vec![DataType::Boolean, DataType::UInt32], + vec![vec![DataType::UInt16, DataType::UInt32]], + None, + )?, + // we do not know how to cast (bool,u16) to (u16,u32), + // but we know to cast to (bool,u32) + case( + vec![DataType::Boolean, DataType::UInt16], + vec![ + vec![DataType::UInt16, DataType::UInt32], + vec![DataType::Boolean, DataType::UInt32], + ], + Some(vec![DataType::Boolean, DataType::UInt32]), + )?, + // we do not know how to cast (bool,u16) to (u16,u32) nor (u32,u16) + case( + vec![DataType::Boolean, DataType::UInt32], + vec![ + vec![DataType::UInt16, DataType::UInt32], + vec![DataType::UInt32, DataType::UInt16], + ], + None, + )?, + ]; + + for (i, case) in cases.iter().enumerate() { + let result = maybe_rewrite(&case.0, &case.1, &case.2, &case.3)?; + let result = format!("case {}: {:?}", i, result); + let expected = format!("case {}: {:?}", i, case.4); + assert_eq!(result, expected); + } + Ok(()) + } } diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index 9195e195895..bf72a7dc0c4 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -19,9 +19,10 @@ use std::collections::HashSet; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use super::optimizer::OptimizerRule; +use 
crate::datatyped::DataTyped; use crate::error::{ExecutionError, Result}; use crate::logicalplan::{Expr, LogicalPlan, PlanType, StringifiedPlan}; @@ -269,7 +270,21 @@ pub fn from_plan( LogicalPlan::Projection { schema, .. } => Ok(LogicalPlan::Projection { expr: expr.clone(), input: Box::new(inputs[0].clone()), - schema: schema.clone(), + // new expressions may have a different type, which changes the resulting schema + schema: Box::new(Schema::new( + schema + .fields() + .iter() + .enumerate() + .map(|(i, f)| { + Ok(Field::new( + f.name(), + expr[i].get_type(inputs[0].schema())?, + f.is_nullable(), + )) + }) + .collect::>>()?, + )), }), LogicalPlan::Selection { .. } => Ok(LogicalPlan::Selection { expr: expr[0].clone(), diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 16da4527557..8a61832dcde 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -22,8 +22,8 @@ use std::sync::Arc; use crate::error::{ExecutionError, Result}; use crate::logicalplan::Expr::Alias; use crate::logicalplan::{ - lit, Expr, FunctionMeta, LogicalPlan, LogicalPlanBuilder, Operator, PlanType, - ScalarValue, StringifiedPlan, + lit, Expr, FunctionMeta, FunctionType, LogicalPlan, LogicalPlanBuilder, Operator, + PlanType, ScalarValue, StringifiedPlan, }; use crate::sql::parser::{CreateExternalTable, FileType, Statement as DFStatement}; @@ -44,6 +44,8 @@ pub trait SchemaProvider { fn get_table_meta(&self, name: &str) -> Option; /// Getter for a UDF description fn get_function_meta(&self, name: &str) -> Option>; + /// Getter list of valid udfs + fn functions(&self) -> Vec; } /// SQL query planner @@ -476,70 +478,50 @@ impl SqlToRel { } SQLExpr::Function(function) => { - //TODO: fix this hack - let name: String = function.name.to_string(); - match name.to_lowercase().as_ref() { - "min" | "max" | "sum" | "avg" => { - let rex_args = function - .args - .iter() - .map(|a| self.sql_to_rex(a, schema)) - .collect::>>()?; - - // 
return type is same as the argument type for these aggregate - // functions - let return_type = rex_args[0].get_type(schema)?.clone(); - - Ok(Expr::AggregateFunction { - name: name.clone(), - args: rex_args, - return_type, - }) - } - "count" => { - let rex_args = function - .args - .iter() - .map(|a| match a { - SQLExpr::Value(Value::Number(_)) => Ok(lit(1_u8)), - SQLExpr::Wildcard => Ok(lit(1_u8)), - _ => self.sql_to_rex(a, schema), - }) - .collect::>>()?; - - Ok(Expr::AggregateFunction { - name: name.clone(), - args: rex_args, - return_type: DataType::UInt64, - }) - } - _ => match self.schema_provider.get_function_meta(&name) { - Some(fm) => { - let rex_args = function + // make the search case-insensitive + let name: String = function.name.to_string().to_lowercase(); + + match self.schema_provider.get_function_meta(&name) { + Some(fm) => { + let args = if name == "count" { + // optimization to avoid computing expressions + function + .args + .iter() + .map(|a| match a { + SQLExpr::Value(Value::Number(_)) => Ok(lit(1_u8)), + SQLExpr::Wildcard => Ok(lit(1_u8)), + _ => self.sql_to_rex(a, schema), + }) + .collect::>>()? + } else { + function .args .iter() .map(|a| self.sql_to_rex(a, schema)) - .collect::>>()?; + .collect::>>()? + }; - let mut safe_args: Vec = vec![]; - for i in 0..rex_args.len() { - safe_args.push( - rex_args[i] - .cast_to(fm.args()[i].data_type(), schema)?, - ); - } + //let args = coerse_expr(&args, &fm, &schema)?; - Ok(Expr::ScalarFunction { + match fm.function_type() { + FunctionType::Scalar => Ok(Expr::ScalarFunction { name: name.clone(), - args: safe_args, + args, return_type: fm.return_type().clone(), - }) + }), + FunctionType::Aggregate => Ok(Expr::AggregateFunction { + name: name.clone(), + args, + return_type: fm.return_type().clone(), + }), } - _ => Err(ExecutionError::General(format!( - "Invalid function '{}'", - name - ))), - }, + } + _ => Err(ExecutionError::General(format!( + "Invalid function '{}'. 
Valid functions: {:?}", + name, + self.schema_provider.functions(), + ))), } } @@ -598,7 +580,7 @@ mod tests { fn select_scalar_func_with_literal_no_relation() { quick_test( "SELECT sqrt(9)", - "Projection: sqrt(CAST(Int64(9) AS Float64))\ + "Projection: sqrt(Int64(9))\ \n EmptyRelation", ); } @@ -685,7 +667,7 @@ mod tests { fn select_simple_aggregate() { quick_test( "SELECT MIN(age) FROM person", - "Aggregate: groupBy=[[]], aggr=[[MIN(#age)]]\ + "Aggregate: groupBy=[[]], aggr=[[min(#age)]]\ \n TableScan: person projection=None", ); } @@ -694,7 +676,7 @@ mod tests { fn test_sum_aggregate() { quick_test( "SELECT SUM(age) from person", - "Aggregate: groupBy=[[]], aggr=[[SUM(#age)]]\ + "Aggregate: groupBy=[[]], aggr=[[sum(#age)]]\ \n TableScan: person projection=None", ); } @@ -703,7 +685,7 @@ mod tests { fn select_simple_aggregate_with_groupby() { quick_test( "SELECT state, MIN(age), MAX(age) FROM person GROUP BY state", - "Aggregate: groupBy=[[#state]], aggr=[[MIN(#age), MAX(#age)]]\ + "Aggregate: groupBy=[[#state]], aggr=[[min(#age), max(#age)]]\ \n TableScan: person projection=None", ); } @@ -720,7 +702,7 @@ mod tests { #[test] fn select_count_one() { let sql = "SELECT COUNT(1) FROM person"; - let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ + let expected = "Aggregate: groupBy=[[]], aggr=[[count(UInt8(1))]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -728,7 +710,7 @@ mod tests { #[test] fn select_count_column() { let sql = "SELECT COUNT(id) FROM person"; - let expected = "Aggregate: groupBy=[[]], aggr=[[COUNT(#id)]]\ + let expected = "Aggregate: groupBy=[[]], aggr=[[count(#id)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -736,7 +718,7 @@ mod tests { #[test] fn select_scalar_func() { let sql = "SELECT sqrt(age) FROM person"; - let expected = "Projection: sqrt(CAST(#age AS Float64))\ + let expected = "Projection: sqrt(#age)\ \n TableScan: person projection=None"; quick_test(sql, expected); } 
@@ -744,7 +726,7 @@ mod tests { #[test] fn select_aliased_scalar_func() { let sql = "SELECT sqrt(age) AS square_people FROM person"; - let expected = "Projection: sqrt(CAST(#age AS Float64)) AS square_people\ + let expected = "Projection: sqrt(#age) AS square_people\ \n TableScan: person projection=None"; quick_test(sql, expected); } @@ -797,8 +779,8 @@ mod tests { fn select_group_by_needs_projection() { let sql = "SELECT COUNT(state), state FROM person GROUP BY state"; let expected = "\ - Projection: #COUNT(state), #state\ - \n Aggregate: groupBy=[[#state]], aggr=[[COUNT(#state)]]\ + Projection: #count(state), #state\ + \n Aggregate: groupBy=[[#state]], aggr=[[count(#state)]]\ \n TableScan: person projection=None"; quick_test(sql, expected); @@ -909,15 +891,43 @@ mod tests { } fn get_function_meta(&self, name: &str) -> Option> { - match name { - "sqrt" => Some(Arc::new(FunctionMeta::new( - "sqrt".to_string(), - vec![Field::new("n", DataType::Float64, false)], + let fnc_type = if name == "sqrt" { + FunctionType::Scalar + } else { + FunctionType::Aggregate + }; + let valid_types = if name == "sqrt" { + vec![DataType::Float64] + } else { + vec![ + DataType::UInt8, + DataType::UInt32, + DataType::Int64, + DataType::Int32, DataType::Float64, - FunctionType::Scalar, - ))), + DataType::Utf8, + ] + }; + + let fm = Arc::new(FunctionMeta::new( + name.to_string(), + vec![valid_types], + Arc::new(|_, _| Ok(DataType::Float64)), + fnc_type, + )); + + match name { + "sqrt" => Some(fm), + "min" => Some(fm), + "max" => Some(fm), + "sum" => Some(fm), + "count" => Some(fm), _ => None, } } + + fn functions(&self) -> Vec { + vec!["sqrt".to_string()] + } } } diff --git a/rust/datafusion/src/test/mod.rs b/rust/datafusion/src/test/mod.rs index 317c14564f1..cd066c6ab91 100644 --- a/rust/datafusion/src/test/mod.rs +++ b/rust/datafusion/src/test/mod.rs @@ -228,6 +228,6 @@ pub fn max(expr: Expr) -> Expr { Expr::AggregateFunction { name: "MAX".to_owned(), args: vec![expr], - return_type: 
DataType::Float64, + return_type: Arc::new(|_, _| Ok(DataType::Float64)), } } diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index d2d5349d5e4..07ad59e4586 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -201,17 +201,111 @@ fn csv_query_avg_sqrt() -> Result<()> { Ok(()) } +#[test] +fn csv_query_avg_custom_udf_f64_f64() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx)?; + // c12 is f64 + let sql = "SELECT avg(custom_add(c12, c12)) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql); + + // perform equivalent calculation + let sql = "SELECT avg(c12 + c12) FROM aggregate_test_100"; + let expected = execute(&mut ctx, sql); + + // verify equality + assert_eq!(actual.join("\n"), expected.join("\n")); + Ok(()) +} + +#[test] +fn csv_query_avg_custom_udf_f32_f64() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx)?; + // c11 is f32 + // c12 is f64 + let sql = "SELECT avg(custom_add(c11, c12)) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql); + + // function evaluted as f32,f64 returns a constant 3264.0 + let expected = "3264.0".to_string(); + + // verify equality + assert_eq!(actual.join("\n"), expected); + Ok(()) +} + +#[test] +fn csv_query_avg_custom_udf_f32_f32() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx)?; + // c11 is f32 + let sql = "SELECT avg(custom_add(c11, c11)) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql); + + // function evaluted as f32,f32 returns 3232.0 + let expected = "3232.0".to_string(); + + // verify equality + assert_eq!(actual.join("\n"), expected); + Ok(()) +} + +#[test] +fn csv_query_avg_custom_udf_i8() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx)?; + // c3 is i8, castable to float32 + let sql = "SELECT avg(custom_add(c3, c3)) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql); + + // function 
evaluted as float32,float32 returns a constant 1111.0 + let expected = "3232.0".to_string(); + + // verify equality + assert_eq!(actual.join("\n"), expected); + Ok(()) +} + +#[test] +fn csv_query_avg_custom_udf_utf8() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx)?; + // utf8 is currently convertable to any type. See https://issues.apache.org/jira/browse/ARROW-4957 + let sql = "SELECT avg(custom_add(c1, c1)) FROM aggregate_test_100"; + let actual = execute(&mut ctx, sql); + + // function evaluted on any other type returns a constant 1111.0 + let expected = "1111.0".to_string(); + + // verify equality + assert_eq!(actual.join("\n"), expected); + Ok(()) +} + fn create_ctx() -> Result { let mut ctx = ExecutionContext::new(); - // register a custom UDF + // register a UDF of 1 argument ctx.register_udf(ScalarFunction::new( "custom_sqrt", - vec![Field::new("n", DataType::Float64, true)], - DataType::Float64, + vec![vec![DataType::Float64]], + Arc::new(|_, _| Ok(DataType::Float64)), Arc::new(custom_sqrt), )); + // register a udf of two arguments + ctx.register_udf(ScalarFunction::new( + "custom_add", + vec![ + vec![DataType::Float32, DataType::Float32], + vec![DataType::Float32, DataType::Float64], + vec![DataType::Float64, DataType::Float64], + ], + Arc::new(|_, _| Ok(DataType::Float64)), + Arc::new(custom_add), + )); + Ok(ctx) } @@ -232,6 +326,55 @@ fn custom_sqrt(args: &[ArrayRef]) -> Result { Ok(Arc::new(builder.finish())) } +fn custom_add(args: &[ArrayRef]) -> Result { + match (args[0].data_type(), args[1].data_type()) { + (DataType::Float64, DataType::Float64) => { + let input1 = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + let input2 = &args[1] + .as_any() + .downcast_ref::() + .expect("cast failed"); + + let mut builder = Float64Builder::new(input1.len()); + for i in 0..input1.len() { + if input1.is_null(i) || input2.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(input1.value(i) 
+ input2.value(i))?; + } + } + Ok(Arc::new(builder.finish())) + } + (DataType::Float32, DataType::Float32) => { + // all other cases return a constant vector (just to be diferent) + let mut builder = Float64Builder::new(args[0].len()); + for _ in 0..args[0].len() { + builder.append_value(3232.0)?; + } + Ok(Arc::new(builder.finish())) + } + (DataType::Float32, DataType::Float64) => { + // all other cases return a constant vector (just to be diferent) + let mut builder = Float64Builder::new(args[0].len()); + for _ in 0..args[0].len() { + builder.append_value(3264.0)?; + } + Ok(Arc::new(builder.finish())) + } + (_, _) => { + // all other cases return a constant vector (just to be diferent) + let mut builder = Float64Builder::new(args[0].len()); + for _ in 0..args[0].len() { + builder.append_value(1111.0)?; + } + Ok(Arc::new(builder.finish())) + } + } +} + #[test] fn csv_query_avg() -> Result<()> { let mut ctx = ExecutionContext::new();