diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 063f5bb3ec9..b36d31d2b45 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -39,9 +39,11 @@ use crate::execution::physical_plan::expressions::{ }; use crate::execution::physical_plan::hash_aggregate::HashAggregateExec; use crate::execution::physical_plan::limit::LimitExec; +use crate::execution::physical_plan::math_expressions::register_math_functions; use crate::execution::physical_plan::merge::MergeExec; use crate::execution::physical_plan::projection::ProjectionExec; use crate::execution::physical_plan::selection::SelectionExec; +use crate::execution::physical_plan::udf::{ScalarFunction, ScalarFunctionExpr}; use crate::execution::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr}; use crate::execution::table_impl::TableImpl; use crate::logicalplan::*; @@ -57,14 +59,18 @@ use sqlparser::sqlast::{SQLColumnDef, SQLType}; /// Execution context for registering data sources and executing queries pub struct ExecutionContext { datasources: HashMap>, + scalar_functions: HashMap>, } impl ExecutionContext { /// Create a new execution context for in-memory queries pub fn new() -> Self { - Self { + let mut ctx = Self { datasources: HashMap::new(), - } + scalar_functions: HashMap::new(), + }; + register_math_functions(&mut ctx); + ctx } /// Execute a SQL query and produce a Relation (a schema-aware iterator over a series @@ -120,6 +126,7 @@ impl ExecutionContext { DFASTNode::ANSI(ansi) => { let schema_provider = ExecutionContextSchemaProvider { datasources: &self.datasources, + scalar_functions: &self.scalar_functions, }; // create a query planner @@ -150,6 +157,16 @@ impl ExecutionContext { } } + /// Register a scalar UDF + pub fn register_udf(&mut self, f: ScalarFunction) { + self.scalar_functions.insert(f.name.clone(), Box::new(f)); + } + + /// Get a reference to the registered scalar functions + pub fn scalar_functions(&self) -> &HashMap> { + &self.scalar_functions + } + fn build_schema(&self, columns: Vec) -> Result { let mut fields = Vec::new(); @@ -239,7 +256,7 @@ impl ExecutionContext { let rules: Vec> = vec![ Box::new(ResolveColumnsRule::new()), Box::new(ProjectionPushDown::new()), - Box::new(TypeCoercionRule::new()), + Box::new(TypeCoercionRule::new(&self.scalar_functions)), ]; let mut plan = plan.clone(); for mut rule in rules { @@ -403,6 +420,28 @@ impl ExecutionContext { input_schema, data_type.clone(), )?)), + Expr::ScalarFunction { + name, + args, + return_type, + } => match &self.scalar_functions.get(name) { + Some(f) => { + let mut physical_args = vec![]; + for e in args { + physical_args.push(self.create_physical_expr(e, input_schema)?); + } + Ok(Arc::new(ScalarFunctionExpr::new( + name, + Box::new(f.fun.clone()), + physical_args, + return_type, + ))) + } + _ => Err(ExecutionError::General(format!( + "Invalid scalar function '{:?}'", + name + ))), + }, other => Err(ExecutionError::NotImplemented(format!( "Physical plan does not support logical expression {:?}", other @@ -519,6 +558,7 @@ impl ExecutionContext { struct ExecutionContextSchemaProvider<'a> { datasources: &'a HashMap>, + scalar_functions: &'a HashMap>, } impl SchemaProvider for ExecutionContextSchemaProvider<'_> { @@ -526,8 +566,15 @@ impl SchemaProvider for ExecutionContextSchemaProvider<'_> { self.datasources.get(name).map(|ds| ds.schema().clone()) } - fn get_function_meta(&self, _name: &str) -> Option> { - None + fn get_function_meta(&self, name: &str) -> Option> { + self.scalar_functions.get(name).map(|f| { + Arc::new(FunctionMeta::new( + name.to_owned(), + f.args.clone(), + f.return_type.clone(), + FunctionType::Scalar, + )) + }) } } @@ -535,7 +582,11 @@ impl SchemaProvider for ExecutionContextSchemaProvider<'_> { mod tests { use super::*; + use crate::datasource::MemTable; + use crate::execution::physical_plan::udf::ScalarUdf; use crate::test; + use arrow::array::{ArrayRef, Int32Array}; + use arrow::compute::add; use std::fs::File; use std::io::prelude::*; use tempdir::TempDir; @@ -806,6 +857,99 @@ mod tests { Ok(()) } + #[test] + fn scalar_udf() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 10, 10, 100])), + Arc::new(Int32Array::from(vec![2, 12, 12, 120])), + ], + )?; + + let mut ctx = ExecutionContext::new(); + + let provider = MemTable::new(schema, vec![batch])?; + ctx.register_table("t", Box::new(provider)); + + let myfunc: ScalarUdf = |args: &Vec| { + let l = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + let r = &args[1] + .as_any() + .downcast_ref::() + .expect("cast failed"); + Ok(Arc::new(add(l, r)?)) + }; + + let my_add = ScalarFunction::new( + "my_add", + vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ], + DataType::Int32, + myfunc, + ); + + ctx.register_udf(my_add); + + let t = ctx.table("t")?; + + let plan = LogicalPlanBuilder::from(&t.to_logical_plan()) + .project(vec![ + col("a"), + col("b"), + scalar_function("my_add", vec![col("a"), col("b")], DataType::Int32), + ])? + .build()?; + + assert_eq!( + format!("{:?}", plan), + "Projection: #a, #b, my_add(#a, #b)\n TableScan: t projection=None" + ); + + let plan = ctx.optimize(&plan)?; + let plan = ctx.create_physical_plan(&plan, 1024)?; + let result = ctx.collect(plan.as_ref())?; + + let batch = &result[0]; + assert_eq!(3, batch.num_columns()); + assert_eq!(4, batch.num_rows()); + + let a = batch + .column(0) + .as_any() + .downcast_ref::() + .expect("failed to cast a"); + let b = batch + .column(1) + .as_any() + .downcast_ref::() + .expect("failed to cast b"); + let sum = batch + .column(2) + .as_any() + .downcast_ref::() + .expect("failed to cast sum"); + + assert_eq!(4, a.len()); + assert_eq!(4, b.len()); + assert_eq!(4, sum.len()); + for i in 0..sum.len() { + assert_eq!(a.value(i) + b.value(i), sum.value(i)); + } + + Ok(()) + } + /// Execute SQL and return results fn collect(ctx: &mut ExecutionContext, sql: &str) -> Result> { let logical_plan = ctx.create_logical_plan(sql)?; diff --git a/rust/datafusion/src/execution/physical_plan/math_expressions.rs b/rust/datafusion/src/execution/physical_plan/math_expressions.rs new file mode 100644 index 00000000000..25bb7330af0 --- /dev/null +++ b/rust/datafusion/src/execution/physical_plan/math_expressions.rs @@ -0,0 +1,114 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Math expressions + +use crate::error::ExecutionError; +use crate::execution::context::ExecutionContext; +use crate::execution::physical_plan::udf::ScalarFunction; + +use arrow::array::{Array, ArrayRef, Float64Array, Float64Builder}; +use arrow::datatypes::{DataType, Field}; + +use std::sync::Arc; + +macro_rules! math_unary_function { + ($NAME:expr, $FUNC:ident) => { + ScalarFunction::new( + $NAME, + vec![Field::new("n", DataType::Float64, true)], + DataType::Float64, + |args: &Vec| { + let n = &args[0].as_any().downcast_ref::(); + match n { + Some(array) => { + let mut builder = Float64Builder::new(array.len()); + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(array.value(i).$FUNC())?; + } + } + Ok(Arc::new(builder.finish())) + } + _ => Err(ExecutionError::General(format!( + "Invalid data type for {}", + $NAME + ))), + } + }, + ) + }; +} + +/// Register math scalar functions with the context +pub fn register_math_functions(ctx: &mut ExecutionContext) { + ctx.register_udf(math_unary_function!("sqrt", sqrt)); + ctx.register_udf(math_unary_function!("sin", sin)); + ctx.register_udf(math_unary_function!("cos", cos)); + ctx.register_udf(math_unary_function!("tan", tan)); + ctx.register_udf(math_unary_function!("asin", asin)); + ctx.register_udf(math_unary_function!("acos", acos)); + ctx.register_udf(math_unary_function!("atan", atan)); + ctx.register_udf(math_unary_function!("floor", floor)); + ctx.register_udf(math_unary_function!("ceil", ceil)); + ctx.register_udf(math_unary_function!("round", round)); + ctx.register_udf(math_unary_function!("trunc", trunc)); + ctx.register_udf(math_unary_function!("abs", abs)); + ctx.register_udf(math_unary_function!("signum", signum)); + ctx.register_udf(math_unary_function!("exp", exp)); + ctx.register_udf(math_unary_function!("log", ln)); + ctx.register_udf(math_unary_function!("log2", log2)); + ctx.register_udf(math_unary_function!("log10", log10)); +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + use crate::logicalplan::{sqrt, Expr, LogicalPlanBuilder}; + use arrow::datatypes::Schema; + + #[test] + fn cast_i8_input() -> Result<()> { + let schema = Schema::new(vec![Field::new("c0", DataType::Int8, true)]); + let plan = LogicalPlanBuilder::scan("", "", &schema, None)? + .project(vec![sqrt(Expr::UnresolvedColumn("c0".to_owned()))])? + .build()?; + let ctx = ExecutionContext::new(); + let plan = ctx.optimize(&plan)?; + let expected = "Projection: sqrt(CAST(#0 AS Float64))\ + \n TableScan: projection=Some([0])"; + assert_eq!(format!("{:?}", plan), expected); + Ok(()) + } + + #[test] + fn no_cast_f64_input() -> Result<()> { + let schema = Schema::new(vec![Field::new("c0", DataType::Float64, true)]); + let plan = LogicalPlanBuilder::scan("", "", &schema, None)? + .project(vec![sqrt(Expr::UnresolvedColumn("c0".to_owned()))])? + .build()?; + let ctx = ExecutionContext::new(); + let plan = ctx.optimize(&plan)?; + let expected = "Projection: sqrt(#0)\ + \n TableScan: projection=Some([0])"; + assert_eq!(format!("{:?}", plan), expected); + Ok(()) + } +} diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs index 1b1226836e0..9868b78db01 100644 --- a/rust/datafusion/src/execution/physical_plan/mod.rs +++ b/rust/datafusion/src/execution/physical_plan/mod.rs @@ -91,7 +91,9 @@ pub mod datasource; pub mod expressions; pub mod hash_aggregate; pub mod limit; +pub mod math_expressions; pub mod merge; pub mod parquet; pub mod projection; pub mod selection; +pub mod udf; diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs b/rust/datafusion/src/execution/physical_plan/udf.rs new file mode 100644 index 00000000000..13926f9a483 --- /dev/null +++ b/rust/datafusion/src/execution/physical_plan/udf.rs @@ -0,0 +1,108 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! UDF support + +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field, Schema}; + +use crate::error::Result; +use crate::execution::physical_plan::PhysicalExpr; + +use arrow::record_batch::RecordBatch; +use std::sync::Arc; + +/// Scalar UDF +pub type ScalarUdf = fn(input: &Vec) -> Result; + +/// Scalar UDF Expression +#[derive(Clone)] +pub struct ScalarFunction { + /// Function name + pub name: String, + /// Function argument meta-data + pub args: Vec, + /// Return type + pub return_type: DataType, + /// UDF implementation + pub fun: ScalarUdf, +} + +impl ScalarFunction { + /// Create a new ScalarFunction + pub fn new( + name: &str, + args: Vec, + return_type: DataType, + fun: ScalarUdf, + ) -> Self { + Self { + name: name.to_owned(), + args, + return_type, + fun, + } + } +} + +/// Scalar UDF Physical Expression +pub struct ScalarFunctionExpr { + name: String, + fun: Box, + args: Vec>, + return_type: DataType, +} + +impl ScalarFunctionExpr { + /// Create a new Scalar function + pub fn new( + name: &str, + fun: Box, + args: Vec>, + return_type: &DataType, + ) -> Self { + Self { + name: name.to_owned(), + fun, + args, + return_type: return_type.clone(), + } + } +} + +impl PhysicalExpr for ScalarFunctionExpr { + fn name(&self) -> String { + self.name.clone() + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(self.return_type.clone()) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + // evaluate the arguments + let inputs = self + .args + .iter() + .map(|e| e.evaluate(batch)) + .collect::>>()?; + + // evaluate the function + let fun = self.fun.as_ref(); + (fun)(&inputs) + } +} diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index 0b9464c1588..0443e6cd8ee 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -376,6 +376,35 @@ pub fn lit_str(str: &str) -> Expr { Expr::Literal(ScalarValue::Utf8(str.to_owned())) } +/// Create an convenience function representing a unary scalar function +macro_rules! unary_math_expr { + ($NAME:expr, $FUNC:ident) => { + #[allow(missing_docs)] + pub fn $FUNC(e: Expr) -> Expr { + scalar_function($NAME, vec![e], DataType::Float64) + } + }; +} + +// generate methods for creating the supported unary math expressions +unary_math_expr!("sqrt", sqrt); +unary_math_expr!("sin", sin); +unary_math_expr!("cos", cos); +unary_math_expr!("tan", tan); +unary_math_expr!("asin", asin); +unary_math_expr!("acos", acos); +unary_math_expr!("atan", atan); +unary_math_expr!("floor", floor); +unary_math_expr!("ceil", ceil); +unary_math_expr!("round", round); +unary_math_expr!("trunc", trunc); +unary_math_expr!("abs", abs); +unary_math_expr!("signum", signum); +unary_math_expr!("exp", exp); +unary_math_expr!("log", ln); +unary_math_expr!("log2", log2); +unary_math_expr!("log10", log10); + /// Create an aggregate expression pub fn aggregate_expr(name: &str, expr: Expr, return_type: DataType) -> Expr { Expr::AggregateFunction { @@ -385,6 +414,15 @@ pub fn aggregate_expr(name: &str, expr: Expr, return_type: DataType) -> Expr { } } +/// Create an aggregate expression +pub fn scalar_function(name: &str, expr: Vec, return_type: DataType) -> Expr { + Expr::ScalarFunction { + name: name.to_owned(), + args: expr, + return_type, + } +} + impl fmt::Debug for Expr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index b8d90a90910..5e5dc0a8deb 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -15,35 +15,145 @@ // specific language governing permissions and limitations // under the License. -//! The type_coercion optimizer rule ensures that all binary operators are operating on +//! The type_coercion optimizer rule ensures that all operators are operating on //! compatible types by adding explicit cast operations to expressions. For example, //! the operation `c_float + c_int` would be rewritten as `c_float + CAST(c_int AS //! float)`. This keeps the runtime query execution code much simpler. +use std::collections::HashMap; use std::sync::Arc; use arrow::datatypes::Schema; use crate::error::{ExecutionError, Result}; +use crate::execution::physical_plan::udf::ScalarFunction; use crate::logicalplan::LogicalPlan; use crate::logicalplan::{Expr, LogicalPlanBuilder}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; /// Implementation of type coercion optimizer rule -pub struct TypeCoercionRule {} +pub struct TypeCoercionRule<'a> { + scalar_functions: &'a HashMap>, +} + +impl<'a> TypeCoercionRule<'a> { + /// Create a new type coercion optimizer rule using meta-data about registered + /// scalar functions + pub fn new(scalar_functions: &'a HashMap>) -> Self { + Self { scalar_functions } + } + + /// Rewrite an expression list to include explicit CAST operations when required + fn rewrite_expr_list(&self, expr: &Vec, schema: &Schema) -> Result> { + Ok(expr + .iter() + .map(|e| self.rewrite_expr(e, schema)) + .collect::>>()?) + } + + /// Rewrite an expression to include explicit CAST operations when required + fn rewrite_expr(&self, expr: &Expr, schema: &Schema) -> Result { + match expr { + Expr::BinaryExpr { left, op, right } => { + let left = self.rewrite_expr(left, schema)?; + let right = self.rewrite_expr(right, schema)?; + let left_type = left.get_type(schema)?; + let right_type = right.get_type(schema)?; + if left_type == right_type { + Ok(Expr::BinaryExpr { + left: Arc::new(left), + op: op.clone(), + right: Arc::new(right), + }) + } else { + let super_type = utils::get_supertype(&left_type, &right_type)?; + Ok(Expr::BinaryExpr { + left: Arc::new(left.cast_to(&super_type, schema)?), + op: op.clone(), + right: Arc::new(right.cast_to(&super_type, schema)?), + }) + } + } + Expr::IsNull(e) => Ok(Expr::IsNull(Arc::new(self.rewrite_expr(e, schema)?))), + Expr::IsNotNull(e) => { + Ok(Expr::IsNotNull(Arc::new(self.rewrite_expr(e, schema)?))) + } + Expr::ScalarFunction { + name, + args, + return_type, + } => { + // cast the inputs of scalar functions to the appropriate type where possible + match self.scalar_functions.get(name) { + Some(func_meta) => { + let mut func_args = Vec::with_capacity(args.len()); + for i in 0..args.len() { + let field = &func_meta.args[i]; + let expr = self.rewrite_expr(&args[i], schema)?; + let actual_type = expr.get_type(schema)?; + let required_type = field.data_type(); + if &actual_type == required_type { + func_args.push(expr) + } else { + let super_type = + utils::get_supertype(&actual_type, required_type)?; + func_args.push(expr.cast_to(&super_type, schema)?); + } + } + + Ok(Expr::ScalarFunction { + name: name.clone(), + args: func_args, + return_type: return_type.clone(), + }) + } + _ => Err(ExecutionError::General(format!( + "Invalid scalar function {}", + name + ))), + } + } + Expr::AggregateFunction { + name, + args, + return_type, + } => Ok(Expr::AggregateFunction { + name: name.clone(), + args: args + .iter() + .map(|a| self.rewrite_expr(a, schema)) + .collect::>>()?, + return_type: return_type.clone(), + }), + Expr::Cast { .. } => Ok(expr.clone()), + Expr::Column(_) => Ok(expr.clone()), + Expr::Alias(expr, alias) => Ok(Expr::Alias( + Arc::new(self.rewrite_expr(expr, schema)?), + alias.to_owned(), + )), + Expr::Literal(_) => Ok(expr.clone()), + Expr::UnresolvedColumn(_) => Ok(expr.clone()), + Expr::Not(_) => Ok(expr.clone()), + Expr::Sort { .. } => Ok(expr.clone()), + Expr::Wildcard { .. } => Err(ExecutionError::General( + "Wildcard expressions are not valid in a logical query plan".to_owned(), + )), + } + } +} -impl OptimizerRule for TypeCoercionRule { +impl<'a> OptimizerRule for TypeCoercionRule<'a> { fn optimize(&mut self, plan: &LogicalPlan) -> Result { match plan { LogicalPlan::Projection { expr, input, .. } => { LogicalPlanBuilder::from(&self.optimize(input)?) - .project(rewrite_expr_list(expr, input.schema())?)? + .project(self.rewrite_expr_list(expr, input.schema())?)? .build() } LogicalPlan::Selection { expr, input, .. } => { LogicalPlanBuilder::from(&self.optimize(input)?) - .filter(rewrite_expr(expr, input.schema())?)? + .filter(self.rewrite_expr(expr, input.schema())?)? .build() } LogicalPlan::Aggregate { @@ -53,8 +163,8 @@ impl OptimizerRule for TypeCoercionRule { .. } => LogicalPlanBuilder::from(&self.optimize(input)?) .aggregate( - rewrite_expr_list(group_expr, input.schema())?, - rewrite_expr_list(aggr_expr, input.schema())?, + self.rewrite_expr_list(group_expr, input.schema())?, + self.rewrite_expr_list(aggr_expr, input.schema())?, )? .build(), LogicalPlan::TableScan { .. } => Ok(plan.clone()), @@ -69,88 +179,10 @@ impl OptimizerRule for TypeCoercionRule { } } -impl TypeCoercionRule { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -fn rewrite_expr_list(expr: &Vec, schema: &Schema) -> Result> { - Ok(expr - .iter() - .map(|e| rewrite_expr(e, schema)) - .collect::>>()?) -} - -/// Rewrite an expression to include explicit CAST operations when required -fn rewrite_expr(expr: &Expr, schema: &Schema) -> Result { - match expr { - Expr::BinaryExpr { left, op, right } => { - let left = rewrite_expr(left, schema)?; - let right = rewrite_expr(right, schema)?; - let left_type = left.get_type(schema)?; - let right_type = right.get_type(schema)?; - if left_type == right_type { - Ok(Expr::BinaryExpr { - left: Arc::new(left), - op: op.clone(), - right: Arc::new(right), - }) - } else { - let super_type = utils::get_supertype(&left_type, &right_type)?; - Ok(Expr::BinaryExpr { - left: Arc::new(left.cast_to(&super_type, schema)?), - op: op.clone(), - right: Arc::new(right.cast_to(&super_type, schema)?), - }) - } - } - Expr::IsNull(e) => Ok(Expr::IsNull(Arc::new(rewrite_expr(e, schema)?))), - Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Arc::new(rewrite_expr(e, schema)?))), - Expr::ScalarFunction { - name, - args, - return_type, - } => Ok(Expr::ScalarFunction { - name: name.clone(), - args: args - .iter() - .map(|a| rewrite_expr(a, schema)) - .collect::>>()?, - return_type: return_type.clone(), - }), - Expr::AggregateFunction { - name, - args, - return_type, - } => Ok(Expr::AggregateFunction { - name: name.clone(), - args: args - .iter() - .map(|a| rewrite_expr(a, schema)) - .collect::>>()?, - return_type: return_type.clone(), - }), - Expr::Cast { .. } => Ok(expr.clone()), - Expr::Column(_) => Ok(expr.clone()), - Expr::Alias(expr, alias) => Ok(Expr::Alias( - Arc::new(rewrite_expr(expr, schema)?), - alias.to_owned(), - )), - Expr::Literal(_) => Ok(expr.clone()), - Expr::UnresolvedColumn(_) => Ok(expr.clone()), - Expr::Not(_) => Ok(expr.clone()), - Expr::Sort { .. } => Ok(expr.clone()), - Expr::Wildcard { .. } => Err(ExecutionError::General( - "Wildcard expressions are not valid in a logical query plan".to_owned(), - )), - } -} - #[cfg(test)] mod tests { use super::*; + use crate::execution::context::ExecutionContext; use crate::logicalplan::Expr::*; use crate::logicalplan::Operator; use arrow::datatypes::{DataType, Field, Schema}; @@ -223,7 +255,10 @@ mod tests { right: Arc::new(Column(1)), }; - let expr2 = rewrite_expr(&expr, &schema).unwrap(); + let ctx = ExecutionContext::new(); + let rule = TypeCoercionRule::new(ctx.scalar_functions()); + + let expr2 = rule.rewrite_expr(&expr, &schema).unwrap(); assert_eq!(expected, format!("{:?}", expr2)); } diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index c45af09a840..949b6e2395c 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -27,6 +27,7 @@ use arrow::record_batch::RecordBatch; use datafusion::error::Result; use datafusion::execution::context::ExecutionContext; +use datafusion::execution::physical_plan::udf::ScalarFunction; use datafusion::logicalplan::LogicalPlan; const DEFAULT_BATCH_SIZE: usize = 1024 * 1024; @@ -144,6 +145,49 @@ fn csv_query_group_by_int_min_max() { assert_eq!(expected, actual.join("\n")); } +#[test] +fn csv_query_avg_sqrt() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx); + let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql); + actual.sort(); + let expected = "0.6706002946036462".to_string(); + assert_eq!(actual.join("\n"), expected); + Ok(()) +} + +fn create_ctx() -> Result { + let mut ctx = ExecutionContext::new(); + + // register a custom UDF + ctx.register_udf(ScalarFunction::new( + "custom_sqrt", + vec![Field::new("n", DataType::Float64, true)], + DataType::Float64, + custom_sqrt, + )); + + Ok(ctx) +} + +fn custom_sqrt(args: &Vec) -> Result { + let input = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + + let mut builder = Float64Builder::new(input.len()); + for i in 0..input.len() { + if input.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(input.value(i).sqrt())?; + } + } + Ok(Arc::new(builder.finish())) +} + #[test] fn csv_query_avg() { let mut ctx = ExecutionContext::new();