From 14c72e8a23ac5ddbff8fda203269a0a299d0801a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 Mar 2020 10:19:58 -0600 Subject: [PATCH 01/13] Implement Scalar UDF support (WIP) --- rust/datafusion/src/execution/context.rs | 121 +++++++++++++++++- .../src/execution/physical_plan/mod.rs | 1 + .../src/execution/physical_plan/udf.rs | 105 +++++++++++++++ rust/datafusion/src/logicalplan.rs | 9 ++ rust/datafusion/tests/sql.rs | 40 ++++++ 5 files changed, 274 insertions(+), 2 deletions(-) create mode 100644 rust/datafusion/src/execution/physical_plan/udf.rs diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 063f5bb3ec9..5af47c1307c 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -42,6 +42,7 @@ use crate::execution::physical_plan::limit::LimitExec; use crate::execution::physical_plan::merge::MergeExec; use crate::execution::physical_plan::projection::ProjectionExec; use crate::execution::physical_plan::selection::SelectionExec; +use crate::execution::physical_plan::udf::{ScalarFunction, ScalarFunctionExpr}; use crate::execution::physical_plan::{AggregateExpr, ExecutionPlan, PhysicalExpr}; use crate::execution::table_impl::TableImpl; use crate::logicalplan::*; @@ -57,6 +58,7 @@ use sqlparser::sqlast::{SQLColumnDef, SQLType}; /// Execution context for registering data sources and executing queries pub struct ExecutionContext { datasources: HashMap>, + scalar_functions: HashMap>, } impl ExecutionContext { @@ -64,6 +66,7 @@ impl ExecutionContext { pub fn new() -> Self { Self { datasources: HashMap::new(), + scalar_functions: HashMap::new(), } } @@ -120,6 +123,7 @@ impl ExecutionContext { DFASTNode::ANSI(ansi) => { let schema_provider = ExecutionContextSchemaProvider { datasources: &self.datasources, + scalar_functions: &self.scalar_functions, }; // create a query planner @@ -150,6 +154,11 @@ impl ExecutionContext { } } + /// Register a scalar UDF + pub fn register_udf(&mut self, name: &str, f: ScalarFunction) { + self.scalar_functions.insert(name.to_owned(), Box::new(f)); + } + fn build_schema(&self, columns: Vec) -> Result { let mut fields = Vec::new(); @@ -403,6 +412,29 @@ impl ExecutionContext { input_schema, data_type.clone(), )?)), + Expr::ScalarFunction { + name, + args, + return_type, + } => { + match &self.scalar_functions.get(name) { + Some(f) => { + let mut physical_args = vec![]; + for e in args { + physical_args + .push(self.create_physical_expr(e, input_schema)?); + } + //TODO pass refs not clone + Ok(Arc::new(ScalarFunctionExpr::new( + name.to_owned(), + Box::new(f.fun.clone()), + physical_args, + return_type.clone(), + ))) + } + _ => panic!(), + } + } other => Err(ExecutionError::NotImplemented(format!( "Physical plan does not support logical expression {:?}", other @@ -519,6 +551,7 @@ impl ExecutionContext { struct ExecutionContextSchemaProvider<'a> { datasources: &'a HashMap>, + scalar_functions: &'a HashMap>, } impl SchemaProvider for ExecutionContextSchemaProvider<'_> { @@ -526,8 +559,15 @@ impl SchemaProvider for ExecutionContextSchemaProvider<'_> { self.datasources.get(name).map(|ds| ds.schema().clone()) } - fn get_function_meta(&self, _name: &str) -> Option> { - None + fn get_function_meta(&self, name: &str) -> Option> { + self.scalar_functions.get(name).map(|f| { + Arc::new(FunctionMeta::new( + name.to_owned(), + f.args.clone(), + f.return_type.clone(), + FunctionType::Scalar, + )) + }) } } @@ -535,7 +575,11 @@ impl SchemaProvider for ExecutionContextSchemaProvider<'_> { mod tests { use super::*; + use crate::datasource::MemTable; + use crate::execution::physical_plan::udf::ScalarUdf; use crate::test; + use arrow::array::{ArrayRef, Int32Array}; + use arrow::compute::add; use std::fs::File; use std::io::prelude::*; use tempdir::TempDir; @@ -806,6 +850,79 @@ mod tests { Ok(()) } + #[test] + fn scalar_udf() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 10, 10, 100])), + Arc::new(Int32Array::from(vec![2, 12, 12, 120])), + ], + )?; + + let mut ctx = ExecutionContext::new(); + + let provider = MemTable::new(schema, vec![batch]).unwrap(); + ctx.register_table("t", Box::new(provider)); + + let myfunc: ScalarUdf = |args: &Vec| { + let l = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + let r = &args[1] + .as_any() + .downcast_ref::() + .expect("cast failed"); + println!("Running my_add"); + Ok(Arc::new(add(l, r).unwrap())) + }; + + let def = ScalarFunction::new( + "my_add", + vec![ + Field::new("a", DataType::Float64, true), + Field::new("b", DataType::Float64, true), + ], + DataType::Float64, + myfunc, + ); + + ctx.register_udf("my_add", def); + + let t = ctx.table("t").unwrap(); + + let plan = LogicalPlanBuilder::from(t.to_logical_plan().as_ref()) + .project(vec![ + col(0), + col(1), + scalar_function("my_add", vec![col(0), col(1)], DataType::Int32), + ])? + .build()?; + + assert_eq!( + format!("{:?}", plan), + "Projection: #0, #1, my_add(#0, #1)\n TableScan: t projection=None" + ); + + let plan = ctx.optimize(&plan)?; + let plan = ctx.create_physical_plan(&plan, 1024)?; + let result = ctx.collect(plan.as_ref())?; + + let batch = &result[0]; + assert_eq!(3, batch.num_columns()); + assert_eq!(4, batch.num_rows()); + + //TODO assert correct results + + Ok(()) + } + /// Execute SQL and return results fn collect(ctx: &mut ExecutionContext, sql: &str) -> Result> { let logical_plan = ctx.create_logical_plan(sql)?; diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs index 1b1226836e0..9bad646b557 100644 --- a/rust/datafusion/src/execution/physical_plan/mod.rs +++ b/rust/datafusion/src/execution/physical_plan/mod.rs @@ -95,3 +95,4 @@ pub mod merge; pub mod parquet; pub mod projection; pub mod selection; +pub mod udf; diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs b/rust/datafusion/src/execution/physical_plan/udf.rs new file mode 100644 index 00000000000..0d8e9eaa2f0 --- /dev/null +++ b/rust/datafusion/src/execution/physical_plan/udf.rs @@ -0,0 +1,105 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! UDF support + +use arrow::array::ArrayRef; +use arrow::datatypes::{DataType, Field, Schema}; + +use crate::error::Result; +use crate::execution::physical_plan::PhysicalExpr; + +use arrow::record_batch::RecordBatch; +use std::sync::Arc; + +/// Scalar UDF +pub type ScalarUdf = fn(input: &Vec) -> Result; + +/// Scalar UDF Expression +#[derive(Clone)] +pub struct ScalarFunction { + /// Function name + pub name: String, + /// Function argument meta-data + pub args: Vec, + /// Return type + pub return_type: DataType, + /// UDF implementation + pub fun: ScalarUdf, +} + +impl ScalarFunction { + /// Create a new ScalarFunction + pub fn new( + name: &str, + args: Vec, + return_type: DataType, + fun: ScalarUdf, + ) -> Self { + Self { + name: name.to_owned(), + args, + return_type, + fun, + } + } +} +/// Scalar UDF Expression +pub struct ScalarFunctionExpr { + name: String, + fun: Box, + args: Vec>, + return_type: DataType, +} + +impl ScalarFunctionExpr { + /// Create a new Scalar function + pub fn new( + name: String, + fun: Box, + args: Vec>, + return_type: DataType, + ) -> Self { + Self { + name, + fun, + args, + return_type, + } + } +} + +impl PhysicalExpr for ScalarFunctionExpr { + fn name(&self) -> String { + self.name.clone() + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(self.return_type.clone()) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let inputs = self + .args + .iter() + .map(|e| e.evaluate(batch)) + .collect::>>()?; + + let fun = self.fun.as_ref(); + (fun)(&inputs) + } +} diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index 0b9464c1588..0f7dbfff0ca 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -385,6 +385,15 @@ pub fn aggregate_expr(name: &str, expr: Expr, return_type: DataType) -> Expr { } } +/// Create an aggregate expression +pub fn scalar_function(name: &str, expr: Vec, return_type: DataType) -> Expr { + Expr::ScalarFunction { + name: name.to_owned(), + args: expr, + return_type, + } +} + impl fmt::Debug for Expr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index c45af09a840..c46664fa3f4 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -27,6 +27,7 @@ use arrow::record_batch::RecordBatch; use datafusion::error::Result; use datafusion::execution::context::ExecutionContext; +use datafusion::execution::physical_plan::udf::{ScalarFunction, ScalarUdf}; use datafusion::logicalplan::LogicalPlan; const DEFAULT_BATCH_SIZE: usize = 1024 * 1024; @@ -144,6 +145,45 @@ fn csv_query_group_by_int_min_max() { assert_eq!(expected, actual.join("\n")); } +#[test] +fn csv_query_avg_sqrt() -> Result<()> { + let mut ctx = create_ctx()?; + register_aggregate_csv(&mut ctx); + let sql = "SELECT avg(sqrt(c12)) FROM aggregate_test_100"; + let mut actual = execute(&mut ctx, sql); + actual.sort(); + let expected = "0.6706002946036462".to_string(); + assert_eq!(actual.join("\n"), expected); + Ok(()) +} + +fn create_ctx() -> Result { + let mut ctx = ExecutionContext::new(); + + let sqrt: ScalarUdf = |args: &Vec| { + let input = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + + let mut builder = Float64Builder::new(input.len()); + for i in 0..input.len() { + builder.append_value(input.value(i).sqrt())?; + } + Ok(Arc::new(builder.finish())) + }; + + let sqrt_meta = ScalarFunction::new( + "sqrt", + vec![Field::new("n", DataType::Float64, true)], + DataType::Float64, + sqrt, + ); + + ctx.register_udf("sqrt", sqrt_meta); + Ok(ctx) +} + #[test] fn csv_query_avg() { let mut ctx = ExecutionContext::new(); From 967da362d7ed7a711445a1bce9327c5b2c1a670e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 Mar 2020 10:25:02 -0600 Subject: [PATCH 02/13] remove panic --- rust/datafusion/src/execution/context.rs | 33 +++++++++---------- .../src/execution/physical_plan/udf.rs | 13 +++++--- 2 files changed, 24 insertions(+), 22 deletions(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 5af47c1307c..cce88bac8e8 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -416,25 +416,24 @@ impl ExecutionContext { name, args, return_type, - } => { - match &self.scalar_functions.get(name) { - Some(f) => { - let mut physical_args = vec![]; - for e in args { - physical_args - .push(self.create_physical_expr(e, input_schema)?); - } - //TODO pass refs not clone - Ok(Arc::new(ScalarFunctionExpr::new( - name.to_owned(), - Box::new(f.fun.clone()), - physical_args, - return_type.clone(), - ))) + } => match &self.scalar_functions.get(name) { + Some(f) => { + let mut physical_args = vec![]; + for e in args { + physical_args.push(self.create_physical_expr(e, input_schema)?); } - _ => panic!(), + Ok(Arc::new(ScalarFunctionExpr::new( + name, + Box::new(f.fun.clone()), + physical_args, + return_type, + ))) } - } + _ => Err(ExecutionError::General(format!( + "Invalid scalar function '{:?}'", + name + ))), + }, other => Err(ExecutionError::NotImplemented(format!( "Physical plan does not support logical expression {:?}", other diff --git a/rust/datafusion/src/execution/physical_plan/udf.rs b/rust/datafusion/src/execution/physical_plan/udf.rs index 0d8e9eaa2f0..13926f9a483 100644 --- a/rust/datafusion/src/execution/physical_plan/udf.rs +++ b/rust/datafusion/src/execution/physical_plan/udf.rs @@ -58,7 +58,8 @@ impl ScalarFunction { } } } -/// Scalar UDF Expression + +/// Scalar UDF Physical Expression pub struct ScalarFunctionExpr { name: String, fun: Box, @@ -69,16 +70,16 @@ pub struct ScalarFunctionExpr { impl ScalarFunctionExpr { /// Create a new Scalar function pub fn new( - name: String, + name: &str, fun: Box, args: Vec>, - return_type: DataType, + return_type: &DataType, ) -> Self { Self { - name, + name: name.to_owned(), fun, args, - return_type, + return_type: return_type.clone(), } } } @@ -93,12 +94,14 @@ impl PhysicalExpr for ScalarFunctionExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { + // evaluate the arguments let inputs = self .args .iter() .map(|e| e.evaluate(batch)) .collect::>>()?; + // evaluate the function let fun = self.fun.as_ref(); (fun)(&inputs) } From 13305ac8e8f46d7e6f3c69d34050fcb10fc1d428 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 Mar 2020 11:28:03 -0600 Subject: [PATCH 03/13] rebase --- rust/datafusion/src/execution/context.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index cce88bac8e8..b5c609e9bd2 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -878,7 +878,6 @@ mod tests { .as_any() .downcast_ref::() .expect("cast failed"); - println!("Running my_add"); Ok(Arc::new(add(l, r).unwrap())) }; @@ -898,15 +897,15 @@ mod tests { let plan = LogicalPlanBuilder::from(t.to_logical_plan().as_ref()) .project(vec![ - col(0), - col(1), - scalar_function("my_add", vec![col(0), col(1)], DataType::Int32), + col("a"), + col("b"), + scalar_function("my_add", vec![col("a"), col("b")], DataType::Int32), ])? .build()?; assert_eq!( format!("{:?}", plan), - "Projection: #0, #1, my_add(#0, #1)\n TableScan: t projection=None" + "Projection: #a, #b, my_add(#a, #b)\n TableScan: t projection=None" ); let plan = ctx.optimize(&plan)?; From 905cd642c42498a04ff382310a1820fd7cd239e9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 Mar 2020 11:41:30 -0600 Subject: [PATCH 04/13] unit test passes --- rust/datafusion/src/execution/context.rs | 29 ++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index b5c609e9bd2..49512c4f2bb 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -884,10 +884,10 @@ mod tests { let def = ScalarFunction::new( "my_add", vec![ - Field::new("a", DataType::Float64, true), - Field::new("b", DataType::Float64, true), + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), ], - DataType::Float64, + DataType::Int32, myfunc, ); @@ -916,7 +916,28 @@ mod tests { assert_eq!(3, batch.num_columns()); assert_eq!(4, batch.num_rows()); - //TODO assert correct results + let a = batch + .column(0) + .as_any() + .downcast_ref::() + .expect("failed to cast a"); + let b = batch + .column(1) + .as_any() + .downcast_ref::() + .expect("failed to cast b"); + let sum = batch + .column(2) + .as_any() + .downcast_ref::() + .expect("failed to cast sum"); + + assert_eq!(4, a.len()); + assert_eq!(4, b.len()); + assert_eq!(4, sum.len()); + for i in 0..sum.len() { + assert_eq!(a.value(i) + b.value(i), sum.value(i)); + } Ok(()) } From 16d9a53c3636611abd1a7c03b62454d725ebd6ff Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 Mar 2020 11:43:34 -0600 Subject: [PATCH 05/13] remove unwrap from test --- rust/datafusion/src/execution/context.rs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 49512c4f2bb..7b9c3087343 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -866,7 +866,7 @@ mod tests { let mut ctx = ExecutionContext::new(); - let provider = MemTable::new(schema, vec![batch]).unwrap(); + let provider = MemTable::new(schema, vec![batch])?; ctx.register_table("t", Box::new(provider)); let myfunc: ScalarUdf = |args: &Vec| { @@ -878,7 +878,7 @@ mod tests { .as_any() .downcast_ref::() .expect("cast failed"); - Ok(Arc::new(add(l, r).unwrap())) + Ok(Arc::new(add(l, r)?)) }; let def = ScalarFunction::new( @@ -893,7 +893,7 @@ mod tests { ctx.register_udf("my_add", def); - let t = ctx.table("t").unwrap(); + let t = ctx.table("t")?; let plan = LogicalPlanBuilder::from(t.to_logical_plan().as_ref()) .project(vec![ From a4d7631beee74ba81cea2f91077a5c8c7bc10b8e Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 Mar 2020 12:00:21 -0600 Subject: [PATCH 06/13] implement sqrt as first built-in scalar function --- rust/datafusion/src/execution/context.rs | 15 +++--- .../physical_plan/math_expressions.rs | 51 +++++++++++++++++++ .../src/execution/physical_plan/mod.rs | 1 + rust/datafusion/tests/sql.rs | 6 +-- 4 files changed, 64 insertions(+), 9 deletions(-) create mode 100644 rust/datafusion/src/execution/physical_plan/math_expressions.rs diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 7b9c3087343..8c809d3794e 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -39,6 +39,7 @@ use crate::execution::physical_plan::expressions::{ }; use crate::execution::physical_plan::hash_aggregate::HashAggregateExec; use crate::execution::physical_plan::limit::LimitExec; +use crate::execution::physical_plan::math_expressions::register_math_functions; use crate::execution::physical_plan::merge::MergeExec; use crate::execution::physical_plan::projection::ProjectionExec; use crate::execution::physical_plan::selection::SelectionExec; @@ -64,10 +65,12 @@ pub struct ExecutionContext { impl ExecutionContext { /// Create a new execution context for in-memory queries pub fn new() -> Self { - Self { + let mut ctx = Self { datasources: HashMap::new(), scalar_functions: HashMap::new(), - } + }; + register_math_functions(&mut ctx); + ctx } /// Execute a SQL query and produce a Relation (a schema-aware iterator over a series @@ -155,8 +158,8 @@ impl ExecutionContext { } /// Register a scalar UDF - pub fn register_udf(&mut self, name: &str, f: ScalarFunction) { - self.scalar_functions.insert(name.to_owned(), Box::new(f)); + pub fn register_udf(&mut self, f: ScalarFunction) { + self.scalar_functions.insert(f.name.clone(), Box::new(f)); } fn build_schema(&self, columns: Vec) -> Result { @@ -881,7 +884,7 @@ mod tests { Ok(Arc::new(add(l, r)?)) }; - let def = ScalarFunction::new( + let my_add = ScalarFunction::new( "my_add", vec![ Field::new("a", DataType::Int32, true), @@ -891,7 +894,7 @@ mod tests { myfunc, ); - ctx.register_udf("my_add", def); + ctx.register_udf(my_add); let t = ctx.table("t")?; diff --git a/rust/datafusion/src/execution/physical_plan/math_expressions.rs b/rust/datafusion/src/execution/physical_plan/math_expressions.rs new file mode 100644 index 00000000000..4f105af3821 --- /dev/null +++ b/rust/datafusion/src/execution/physical_plan/math_expressions.rs @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Math expressions + +use crate::execution::context::ExecutionContext; +use crate::execution::physical_plan::udf::ScalarFunction; + +use arrow::array::{ArrayRef, Float64Array, Float64Builder}; +use arrow::datatypes::{DataType, Field}; + +use std::sync::Arc; + +/// Register math scalar functions with the context +pub fn register_math_functions(ctx: &mut ExecutionContext) { + ctx.register_udf(sqrt_fn()); +} + +fn sqrt_fn() -> ScalarFunction { + ScalarFunction::new( + "sqrt", + vec![Field::new("n", DataType::Float64, true)], + DataType::Float64, + |args: &Vec| { + let input = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + + let mut builder = Float64Builder::new(input.len()); + for i in 0..input.len() { + builder.append_value(input.value(i).sqrt())?; + } + Ok(Arc::new(builder.finish())) + }, + ) +} diff --git a/rust/datafusion/src/execution/physical_plan/mod.rs b/rust/datafusion/src/execution/physical_plan/mod.rs index 9bad646b557..9868b78db01 100644 --- a/rust/datafusion/src/execution/physical_plan/mod.rs +++ b/rust/datafusion/src/execution/physical_plan/mod.rs @@ -91,6 +91,7 @@ pub mod datasource; pub mod expressions; pub mod hash_aggregate; pub mod limit; +pub mod math_expressions; pub mod merge; pub mod parquet; pub mod projection; diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index c46664fa3f4..b60124ade09 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -149,7 +149,7 @@ fn csv_query_group_by_int_min_max() { fn csv_query_avg_sqrt() -> Result<()> { let mut ctx = create_ctx()?; register_aggregate_csv(&mut ctx); - let sql = "SELECT avg(sqrt(c12)) FROM aggregate_test_100"; + let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql); actual.sort(); let expected = "0.6706002946036462".to_string(); @@ -174,13 +174,13 @@ fn create_ctx() -> Result { }; let sqrt_meta = ScalarFunction::new( - "sqrt", + "custom_sqrt", vec![Field::new("n", DataType::Float64, true)], DataType::Float64, sqrt, ); - ctx.register_udf("sqrt", sqrt_meta); + ctx.register_udf(sqrt_meta); Ok(ctx) } From b7bacbda18812e01805bfbc098406955d7643bad Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 Mar 2020 12:06:02 -0600 Subject: [PATCH 07/13] code cleanup --- .../physical_plan/math_expressions.rs | 27 ++++++++++++------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/rust/datafusion/src/execution/physical_plan/math_expressions.rs b/rust/datafusion/src/execution/physical_plan/math_expressions.rs index 4f105af3821..10eceda3bc1 100644 --- a/rust/datafusion/src/execution/physical_plan/math_expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/math_expressions.rs @@ -17,10 +17,11 @@ //! Math expressions +use crate::error::ExecutionError; use crate::execution::context::ExecutionContext; use crate::execution::physical_plan::udf::ScalarFunction; -use arrow::array::{ArrayRef, Float64Array, Float64Builder}; +use arrow::array::{Array, ArrayRef, Float64Array, Float64Builder}; use arrow::datatypes::{DataType, Field}; use std::sync::Arc; @@ -36,16 +37,24 @@ fn sqrt_fn() -> ScalarFunction { vec![Field::new("n", DataType::Float64, true)], DataType::Float64, |args: &Vec| { - let input = &args[0] - .as_any() - .downcast_ref::() - .expect("cast failed"); + let n = &args[0].as_any().downcast_ref::(); - let mut builder = Float64Builder::new(input.len()); - for i in 0..input.len() { - builder.append_value(input.value(i).sqrt())?; + match n { + Some(array) => { + let mut builder = Float64Builder::new(array.len()); + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(array.value(i).sqrt())?; + } + } + Ok(Arc::new(builder.finish())) + } + _ => Err(ExecutionError::General( + "Invalid data type for sqrt".to_owned(), + )), } - Ok(Arc::new(builder.finish())) }, ) } From 2e440ce56fdb10afc906dcb8585d89f82e023ad2 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 Mar 2020 12:09:30 -0600 Subject: [PATCH 08/13] rebase --- rust/datafusion/src/execution/context.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 8c809d3794e..7b9893f9245 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -898,7 +898,7 @@ mod tests { let t = ctx.table("t")?; - let plan = LogicalPlanBuilder::from(t.to_logical_plan().as_ref()) + let plan = LogicalPlanBuilder::from(&t.to_logical_plan()) .project(vec![ col("a"), col("b"), From fee950ab8da080f0afaf27d1d069f664ace9e78a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 28 Mar 2020 13:31:29 -0600 Subject: [PATCH 09/13] implement some common unary math expressions --- .../physical_plan/math_expressions.rs | 74 ++++++++++++------- 1 file changed, 46 insertions(+), 28 deletions(-) diff --git a/rust/datafusion/src/execution/physical_plan/math_expressions.rs b/rust/datafusion/src/execution/physical_plan/math_expressions.rs index 10eceda3bc1..65a9dbe178c 100644 --- a/rust/datafusion/src/execution/physical_plan/math_expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/math_expressions.rs @@ -26,35 +26,53 @@ use arrow::datatypes::{DataType, Field}; use std::sync::Arc; -/// Register math scalar functions with the context -pub fn register_math_functions(ctx: &mut ExecutionContext) { - ctx.register_udf(sqrt_fn()); -} - -fn sqrt_fn() -> ScalarFunction { - ScalarFunction::new( - "sqrt", - vec![Field::new("n", DataType::Float64, true)], - DataType::Float64, - |args: &Vec| { - let n = &args[0].as_any().downcast_ref::(); - - match n { - Some(array) => { - let mut builder = Float64Builder::new(array.len()); - for i in 0..array.len() { - if array.is_null(i) { - builder.append_null()?; - } else { - builder.append_value(array.value(i).sqrt())?; +macro_rules! math_unary_function { + ($NAME:expr, $FUNC:ident) => { + ScalarFunction::new( + $NAME, + vec![Field::new("n", DataType::Float64, true)], + DataType::Float64, + |args: &Vec| { + let n = &args[0].as_any().downcast_ref::(); + match n { + Some(array) => { + let mut builder = Float64Builder::new(array.len()); + for i in 0..array.len() { + if array.is_null(i) { + builder.append_null()?; + } else { + builder.append_value(array.value(i).$FUNC())?; + } } + Ok(Arc::new(builder.finish())) } - Ok(Arc::new(builder.finish())) + _ => Err(ExecutionError::General(format!( + "Invalid data type for {}", + $NAME + ))), } - _ => Err(ExecutionError::General( - "Invalid data type for sqrt".to_owned(), - )), - } - }, - ) + }, + ) + }; +} + +/// Register math scalar functions with the context +pub fn register_math_functions(ctx: &mut ExecutionContext) { + ctx.register_udf(math_unary_function!("sqrt", sqrt)); + ctx.register_udf(math_unary_function!("sin", sin)); + ctx.register_udf(math_unary_function!("cos", cos)); + ctx.register_udf(math_unary_function!("tan", tan)); + ctx.register_udf(math_unary_function!("asin", asin)); + ctx.register_udf(math_unary_function!("acos", acos)); + ctx.register_udf(math_unary_function!("atan", atan)); + ctx.register_udf(math_unary_function!("floor", floor)); + ctx.register_udf(math_unary_function!("ceil", ceil)); + ctx.register_udf(math_unary_function!("round", round)); + ctx.register_udf(math_unary_function!("trunc", trunc)); + ctx.register_udf(math_unary_function!("abs", abs)); + ctx.register_udf(math_unary_function!("signum", signum)); + ctx.register_udf(math_unary_function!("exp", exp)); + ctx.register_udf(math_unary_function!("log", ln)); + ctx.register_udf(math_unary_function!("log2", log2)); + ctx.register_udf(math_unary_function!("log10", log10)); } From d490521ccaada39afe97834c220b757cdf74d4a8 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 29 Mar 2020 09:25:48 -0600 Subject: [PATCH 10/13] Implement type coercion for scalar function arguments --- rust/datafusion/src/execution/context.rs | 7 +- .../physical_plan/math_expressions.rs | 36 +++ rust/datafusion/src/logicalplan.rs | 5 + .../datafusion/src/optimizer/type_coercion.rs | 209 ++++++++++-------- 4 files changed, 169 insertions(+), 88 deletions(-) diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 7b9893f9245..b36d31d2b45 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -162,6 +162,11 @@ impl ExecutionContext { self.scalar_functions.insert(f.name.clone(), Box::new(f)); } + /// Get a reference to the registered scalar functions + pub fn scalar_functions(&self) -> &HashMap> { + &self.scalar_functions + } + fn build_schema(&self, columns: Vec) -> Result { let mut fields = Vec::new(); @@ -251,7 +256,7 @@ impl ExecutionContext { let rules: Vec> = vec![ Box::new(ResolveColumnsRule::new()), Box::new(ProjectionPushDown::new()), - Box::new(TypeCoercionRule::new()), + Box::new(TypeCoercionRule::new(&self.scalar_functions)), ]; let mut plan = plan.clone(); for mut rule in rules { diff --git a/rust/datafusion/src/execution/physical_plan/math_expressions.rs b/rust/datafusion/src/execution/physical_plan/math_expressions.rs index 65a9dbe178c..25bb7330af0 100644 --- a/rust/datafusion/src/execution/physical_plan/math_expressions.rs +++ b/rust/datafusion/src/execution/physical_plan/math_expressions.rs @@ -76,3 +76,39 @@ pub fn register_math_functions(ctx: &mut ExecutionContext) { ctx.register_udf(math_unary_function!("log2", log2)); ctx.register_udf(math_unary_function!("log10", log10)); } + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::Result; + use crate::logicalplan::{sqrt, Expr, LogicalPlanBuilder}; + use arrow::datatypes::Schema; + + #[test] + fn cast_i8_input() -> Result<()> { + let schema = Schema::new(vec![Field::new("c0", DataType::Int8, true)]); + let plan = LogicalPlanBuilder::scan("", "", &schema, None)? + .project(vec![sqrt(Expr::UnresolvedColumn("c0".to_owned()))])? + .build()?; + let ctx = ExecutionContext::new(); + let plan = ctx.optimize(&plan)?; + let expected = "Projection: sqrt(CAST(#0 AS Float64))\ + \n TableScan: projection=Some([0])"; + assert_eq!(format!("{:?}", plan), expected); + Ok(()) + } + + #[test] + fn no_cast_f64_input() -> Result<()> { + let schema = Schema::new(vec![Field::new("c0", DataType::Float64, true)]); + let plan = LogicalPlanBuilder::scan("", "", &schema, None)? + .project(vec![sqrt(Expr::UnresolvedColumn("c0".to_owned()))])? + .build()?; + let ctx = ExecutionContext::new(); + let plan = ctx.optimize(&plan)?; + let expected = "Projection: sqrt(#0)\ + \n TableScan: projection=Some([0])"; + assert_eq!(format!("{:?}", plan), expected); + Ok(()) + } +} diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index 0f7dbfff0ca..2a41d120040 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -376,6 +376,11 @@ pub fn lit_str(str: &str) -> Expr { Expr::Literal(ScalarValue::Utf8(str.to_owned())) } +/// Create an expression representing the sqrt scalar function +pub fn sqrt(e: Expr) -> Expr { + scalar_function("sqrt", vec![e], DataType::Float64) +} + /// Create an aggregate expression pub fn aggregate_expr(name: &str, expr: Expr, return_type: DataType) -> Expr { Expr::AggregateFunction { diff --git a/rust/datafusion/src/optimizer/type_coercion.rs b/rust/datafusion/src/optimizer/type_coercion.rs index b8d90a90910..5e5dc0a8deb 100644 --- a/rust/datafusion/src/optimizer/type_coercion.rs +++ b/rust/datafusion/src/optimizer/type_coercion.rs @@ -15,35 +15,145 @@ // specific language governing permissions and limitations // under the License. -//! The type_coercion optimizer rule ensures that all binary operators are operating on +//! The type_coercion optimizer rule ensures that all operators are operating on //! compatible types by adding explicit cast operations to expressions. For example, //! the operation `c_float + c_int` would be rewritten as `c_float + CAST(c_int AS //! float)`. This keeps the runtime query execution code much simpler. +use std::collections::HashMap; use std::sync::Arc; use arrow::datatypes::Schema; use crate::error::{ExecutionError, Result}; +use crate::execution::physical_plan::udf::ScalarFunction; use crate::logicalplan::LogicalPlan; use crate::logicalplan::{Expr, LogicalPlanBuilder}; use crate::optimizer::optimizer::OptimizerRule; use crate::optimizer::utils; /// Implementation of type coercion optimizer rule -pub struct TypeCoercionRule {} +pub struct TypeCoercionRule<'a> { + scalar_functions: &'a HashMap>, +} + +impl<'a> TypeCoercionRule<'a> { + /// Create a new type coercion optimizer rule using meta-data about registered + /// scalar functions + pub fn new(scalar_functions: &'a HashMap>) -> Self { + Self { scalar_functions } + } + + /// Rewrite an expression list to include explicit CAST operations when required + fn rewrite_expr_list(&self, expr: &Vec, schema: &Schema) -> Result> { + Ok(expr + .iter() + .map(|e| self.rewrite_expr(e, schema)) + .collect::>>()?) + } + + /// Rewrite an expression to include explicit CAST operations when required + fn rewrite_expr(&self, expr: &Expr, schema: &Schema) -> Result { + match expr { + Expr::BinaryExpr { left, op, right } => { + let left = self.rewrite_expr(left, schema)?; + let right = self.rewrite_expr(right, schema)?; + let left_type = left.get_type(schema)?; + let right_type = right.get_type(schema)?; + if left_type == right_type { + Ok(Expr::BinaryExpr { + left: Arc::new(left), + op: op.clone(), + right: Arc::new(right), + }) + } else { + let super_type = utils::get_supertype(&left_type, &right_type)?; + Ok(Expr::BinaryExpr { + left: Arc::new(left.cast_to(&super_type, schema)?), + op: op.clone(), + right: Arc::new(right.cast_to(&super_type, schema)?), + }) + } + } + Expr::IsNull(e) => Ok(Expr::IsNull(Arc::new(self.rewrite_expr(e, schema)?))), + Expr::IsNotNull(e) => { + Ok(Expr::IsNotNull(Arc::new(self.rewrite_expr(e, schema)?))) + } + Expr::ScalarFunction { + name, + args, + return_type, + } => { + // cast the inputs of scalar functions to the appropriate type where possible + match self.scalar_functions.get(name) { + Some(func_meta) => { + let mut func_args = Vec::with_capacity(args.len()); + for i in 0..args.len() { + let field = &func_meta.args[i]; + let expr = self.rewrite_expr(&args[i], schema)?; + let actual_type = expr.get_type(schema)?; + let required_type = field.data_type(); + if &actual_type == required_type { + func_args.push(expr) + } else { + let super_type = + utils::get_supertype(&actual_type, required_type)?; + func_args.push(expr.cast_to(&super_type, schema)?); + } + } + + Ok(Expr::ScalarFunction { + name: name.clone(), + args: func_args, + return_type: return_type.clone(), + }) + } + _ => Err(ExecutionError::General(format!( + "Invalid scalar function {}", + name + ))), + } + } + Expr::AggregateFunction { + name, + args, + return_type, + } => Ok(Expr::AggregateFunction { + name: name.clone(), + args: args + .iter() + .map(|a| self.rewrite_expr(a, schema)) + .collect::>>()?, + return_type: return_type.clone(), + }), + Expr::Cast { .. } => Ok(expr.clone()), + Expr::Column(_) => Ok(expr.clone()), + Expr::Alias(expr, alias) => Ok(Expr::Alias( + Arc::new(self.rewrite_expr(expr, schema)?), + alias.to_owned(), + )), + Expr::Literal(_) => Ok(expr.clone()), + Expr::UnresolvedColumn(_) => Ok(expr.clone()), + Expr::Not(_) => Ok(expr.clone()), + Expr::Sort { .. } => Ok(expr.clone()), + Expr::Wildcard { .. } => Err(ExecutionError::General( + "Wildcard expressions are not valid in a logical query plan".to_owned(), + )), + } + } +} -impl OptimizerRule for TypeCoercionRule { +impl<'a> OptimizerRule for TypeCoercionRule<'a> { fn optimize(&mut self, plan: &LogicalPlan) -> Result { match plan { LogicalPlan::Projection { expr, input, .. } => { LogicalPlanBuilder::from(&self.optimize(input)?) - .project(rewrite_expr_list(expr, input.schema())?)? + .project(self.rewrite_expr_list(expr, input.schema())?)? .build() } LogicalPlan::Selection { expr, input, .. } => { LogicalPlanBuilder::from(&self.optimize(input)?) - .filter(rewrite_expr(expr, input.schema())?)? + .filter(self.rewrite_expr(expr, input.schema())?)? .build() } LogicalPlan::Aggregate { @@ -53,8 +163,8 @@ impl OptimizerRule for TypeCoercionRule { .. } => LogicalPlanBuilder::from(&self.optimize(input)?) .aggregate( - rewrite_expr_list(group_expr, input.schema())?, - rewrite_expr_list(aggr_expr, input.schema())?, + self.rewrite_expr_list(group_expr, input.schema())?, + self.rewrite_expr_list(aggr_expr, input.schema())?, )? .build(), LogicalPlan::TableScan { .. } => Ok(plan.clone()), @@ -69,88 +179,10 @@ impl OptimizerRule for TypeCoercionRule { } } -impl TypeCoercionRule { - #[allow(missing_docs)] - pub fn new() -> Self { - Self {} - } -} - -fn rewrite_expr_list(expr: &Vec, schema: &Schema) -> Result> { - Ok(expr - .iter() - .map(|e| rewrite_expr(e, schema)) - .collect::>>()?) -} - -/// Rewrite an expression to include explicit CAST operations when required -fn rewrite_expr(expr: &Expr, schema: &Schema) -> Result { - match expr { - Expr::BinaryExpr { left, op, right } => { - let left = rewrite_expr(left, schema)?; - let right = rewrite_expr(right, schema)?; - let left_type = left.get_type(schema)?; - let right_type = right.get_type(schema)?; - if left_type == right_type { - Ok(Expr::BinaryExpr { - left: Arc::new(left), - op: op.clone(), - right: Arc::new(right), - }) - } else { - let super_type = utils::get_supertype(&left_type, &right_type)?; - Ok(Expr::BinaryExpr { - left: Arc::new(left.cast_to(&super_type, schema)?), - op: op.clone(), - right: Arc::new(right.cast_to(&super_type, schema)?), - }) - } - } - Expr::IsNull(e) => Ok(Expr::IsNull(Arc::new(rewrite_expr(e, schema)?))), - Expr::IsNotNull(e) => Ok(Expr::IsNotNull(Arc::new(rewrite_expr(e, schema)?))), - Expr::ScalarFunction { - name, - args, - return_type, - } => Ok(Expr::ScalarFunction { - name: name.clone(), - args: args - .iter() - .map(|a| rewrite_expr(a, schema)) - .collect::>>()?, - return_type: return_type.clone(), - }), - Expr::AggregateFunction { - name, - args, - return_type, - } => Ok(Expr::AggregateFunction { - name: name.clone(), - args: args - .iter() - .map(|a| rewrite_expr(a, schema)) - .collect::>>()?, - return_type: return_type.clone(), - }), - Expr::Cast { .. } => Ok(expr.clone()), - Expr::Column(_) => Ok(expr.clone()), - Expr::Alias(expr, alias) => Ok(Expr::Alias( - Arc::new(rewrite_expr(expr, schema)?), - alias.to_owned(), - )), - Expr::Literal(_) => Ok(expr.clone()), - Expr::UnresolvedColumn(_) => Ok(expr.clone()), - Expr::Not(_) => Ok(expr.clone()), - Expr::Sort { .. } => Ok(expr.clone()), - Expr::Wildcard { .. } => Err(ExecutionError::General( - "Wildcard expressions are not valid in a logical query plan".to_owned(), - )), - } -} - #[cfg(test)] mod tests { use super::*; + use crate::execution::context::ExecutionContext; use crate::logicalplan::Expr::*; use crate::logicalplan::Operator; use arrow::datatypes::{DataType, Field, Schema}; @@ -223,7 +255,10 @@ mod tests { right: Arc::new(Column(1)), }; - let expr2 = rewrite_expr(&expr, &schema).unwrap(); + let ctx = ExecutionContext::new(); + let rule = TypeCoercionRule::new(ctx.scalar_functions()); + + let expr2 = rule.rewrite_expr(&expr, &schema).unwrap(); assert_eq!(expected, format!("{:?}", expr2)); } From 32496a6c30c0346139f2ac7297a350b6be286f7a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 29 Mar 2020 09:33:24 -0600 Subject: [PATCH 11/13] add convenience methods for creating logical unary math expressions --- rust/datafusion/src/logicalplan.rs | 30 +++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index 2a41d120040..7670f0e10f1 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -376,11 +376,35 @@ pub fn lit_str(str: &str) -> Expr { Expr::Literal(ScalarValue::Utf8(str.to_owned())) } -/// Create an expression representing the sqrt scalar function -pub fn sqrt(e: Expr) -> Expr { - scalar_function("sqrt", vec![e], DataType::Float64) +/// Create an convenience function representing a unary scalar function +macro_rules! unary_math_expr { + ($NAME:expr, $FUNC:ident) => { + #[allow(missing_docs)] + pub fn $FUNC(e: Expr) -> Expr { + scalar_function($NAME, vec![e], DataType::Float64) + } + } } +// generate methods for creating the supported unary math expressions +unary_math_expr!("sqrt", sqrt); +unary_math_expr!("sin", sin); +unary_math_expr!("cos", cos); +unary_math_expr!("tan", tan); +unary_math_expr!("asin", asin); +unary_math_expr!("acos", acos); +unary_math_expr!("atan", atan); +unary_math_expr!("floor", floor); +unary_math_expr!("ceil", ceil); +unary_math_expr!("round", round); +unary_math_expr!("trunc", trunc); +unary_math_expr!("abs", abs); +unary_math_expr!("signum", signum); +unary_math_expr!("exp", exp); +unary_math_expr!("log", ln); +unary_math_expr!("log2", log2); +unary_math_expr!("log10", log10); + /// Create an aggregate expression pub fn aggregate_expr(name: &str, expr: Expr, return_type: DataType) -> Expr { Expr::AggregateFunction { From 08bd7017cfe3784536762957a6a463bcb79a431a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 29 Mar 2020 09:41:53 -0600 Subject: [PATCH 12/13] code cleanup --- rust/datafusion/tests/sql.rs | 38 ++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 17 deletions(-) diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index b60124ade09..072f3ce8bc2 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -160,30 +160,34 @@ fn csv_query_avg_sqrt() -> Result<()> { fn create_ctx() -> Result { let mut ctx = ExecutionContext::new(); - let sqrt: ScalarUdf = |args: &Vec| { - let input = &args[0] - .as_any() - .downcast_ref::() - .expect("cast failed"); - - let mut builder = Float64Builder::new(input.len()); - for i in 0..input.len() { - builder.append_value(input.value(i).sqrt())?; - } - Ok(Arc::new(builder.finish())) - }; - - let sqrt_meta = ScalarFunction::new( + // register a custom UDF + ctx.register_udf(ScalarFunction::new( "custom_sqrt", vec![Field::new("n", DataType::Float64, true)], DataType::Float64, - sqrt, - ); + custom_sqrt, + )); - ctx.register_udf(sqrt_meta); Ok(ctx) } +fn custom_sqrt(args: &Vec) -> Result { + let input = &args[0] + .as_any() + .downcast_ref::() + .expect("cast failed"); + + let mut builder = Float64Builder::new(input.len()); + for i in 0..input.len() { + if input.is_null(i) { + builder.append_null(); + } else { + builder.append_value(input.value(i).sqrt())?; + } + } + Ok(Arc::new(builder.finish())) +} + #[test] fn csv_query_avg() { let mut ctx = ExecutionContext::new(); From 8250e90d0713f889f14cc4dd325b8e8d0967acb0 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sun, 29 Mar 2020 09:48:59 -0600 Subject: [PATCH 13/13] formatting --- rust/datafusion/src/logicalplan.rs | 2 +- rust/datafusion/tests/sql.rs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/rust/datafusion/src/logicalplan.rs b/rust/datafusion/src/logicalplan.rs index 7670f0e10f1..0443e6cd8ee 100644 --- a/rust/datafusion/src/logicalplan.rs +++ b/rust/datafusion/src/logicalplan.rs @@ -383,7 +383,7 @@ macro_rules! unary_math_expr { pub fn $FUNC(e: Expr) -> Expr { scalar_function($NAME, vec![e], DataType::Float64) } - } + }; } // generate methods for creating the supported unary math expressions diff --git a/rust/datafusion/tests/sql.rs b/rust/datafusion/tests/sql.rs index 072f3ce8bc2..949b6e2395c 100644 --- a/rust/datafusion/tests/sql.rs +++ b/rust/datafusion/tests/sql.rs @@ -27,7 +27,7 @@ use arrow::record_batch::RecordBatch; use datafusion::error::Result; use datafusion::execution::context::ExecutionContext; -use datafusion::execution::physical_plan::udf::{ScalarFunction, ScalarUdf}; +use datafusion::execution::physical_plan::udf::ScalarFunction; use datafusion::logicalplan::LogicalPlan; const DEFAULT_BATCH_SIZE: usize = 1024 * 1024; @@ -180,7 +180,7 @@ fn custom_sqrt(args: &Vec) -> Result { let mut builder = Float64Builder::new(input.len()); for i in 0..input.len() { if input.is_null(i) { - builder.append_null(); + builder.append_null()?; } else { builder.append_value(input.value(i).sqrt())?; }