diff --git a/rust/datafusion/README.md b/rust/datafusion/README.md index 5405f269d28..43e99e22bb5 100644 --- a/rust/datafusion/README.md +++ b/rust/datafusion/README.md @@ -52,7 +52,8 @@ DataFusion includes a simple command-line interactive SQL utility. See the [CLI - [x] Filter (WHERE) - [x] Limit - [x] Aggregate -- [x] UDFs +- [x] UDFs (user-defined functions) +- [x] UDAFs (user-defined aggregate functions) - [x] Common math functions - String functions - [x] Length diff --git a/rust/datafusion/examples/simple_udaf.rs b/rust/datafusion/examples/simple_udaf.rs new file mode 100644 index 00000000000..deec9b2cc73 --- /dev/null +++ b/rust/datafusion/examples/simple_udaf.rs @@ -0,0 +1,168 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// In this example we will declare a single-type, single return type UDAF that computes the geometric mean. +/// The geometric mean is described here: https://en.wikipedia.org/wiki/Geometric_mean +use arrow::{ + array::Float32Array, array::Float64Array, array::PrimitiveArrayOps, + datatypes::DataType, record_batch::RecordBatch, +}; + +use datafusion::{error::Result, logical_plan::create_udaf, physical_plan::Accumulator}; +use datafusion::{prelude::*, scalar::ScalarValue}; +use std::{cell::RefCell, rc::Rc, sync::Arc}; + +// create local execution context with an in-memory table +fn create_context() -> Result { + use arrow::datatypes::{Field, Schema}; + use datafusion::datasource::MemTable; + // define a schema. + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, false)])); + + // define data in two partitions + let batch1 = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Float32Array::from(vec![2.0, 4.0, 8.0]))], + )?; + let batch2 = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Float32Array::from(vec![64.0]))], + )?; + + // declare a new context. In spark API, this corresponds to a new spark SQLsession + let mut ctx = ExecutionContext::new(); + + // declare a table in memory. In spark API, this corresponds to createDataFrame(...). + let provider = MemTable::new(schema, vec![vec![batch1], vec![batch2]])?; + ctx.register_table("t", Box::new(provider)); + Ok(ctx) +} + +/// A UDAF has state across multiple rows, and thus we require a `struct` with that state. +#[derive(Debug)] +struct GeometricMean { + n: u32, + prod: f64, +} + +impl GeometricMean { + // how the struct is initialized + pub fn new() -> Self { + GeometricMean { n: 0, prod: 1.0 } + } +} + +// UDAFs are built using the trait `Accumulator`, that offers DataFusion the necessary functions +// to use them. +impl Accumulator for GeometricMean { + // this function serializes our state to `ScalarValue`, which DataFusion uses + // to pass this state between execution stages. + // Note that this can be arbitrary data. + fn state(&self) -> Result> { + Ok(vec![ + ScalarValue::from(self.prod), + ScalarValue::from(self.n), + ]) + } + + // this function receives one entry per argument of this accumulator. + // DataFusion calls this function on every row, and expects this function to update the accumulator's state. + fn update(&mut self, values: &Vec) -> Result<()> { + // this is a one-argument UDAF, and thus we use `0`. + let value = &values[0]; + match value { + // here we map `ScalarValue` to our internal state. `Float64` indicates that this function + // only accepts Float64 as its argument (DataFusion does try to coerce arguments to this type) + // + // Note that `.map` here ensures that we ignore Nulls. + ScalarValue::Float64(e) => e.map(|value| { + self.prod *= value; + self.n += 1; + }), + _ => unreachable!(""), + }; + Ok(()) + } + + // this function receives states from other accumulators (Vec) + // and updates the accumulator. + fn merge(&mut self, states: &Vec) -> Result<()> { + let prod = &states[0]; + let n = &states[1]; + match (prod, n) { + (ScalarValue::Float64(Some(prod)), ScalarValue::UInt32(Some(n))) => { + self.prod *= prod; + self.n += n; + } + _ => unreachable!(""), + }; + Ok(()) + } + + // DataFusion expects this function to return the final value of this aggregator. + // in this case, this is the formula of the geometric mean + fn evaluate(&self) -> Result { + let value = self.prod.powf(1.0 / self.n as f64); + Ok(ScalarValue::from(value)) + } + + // Optimization hint: this trait also supports `update_batch` and `merge_batch`, + // that can be used to perform these operations on arrays instead of single values. + // By default, these methods call `update` and `merge` row by row +} + +fn main() -> Result<()> { + let mut ctx = create_context()?; + + // here is where we define the UDAF. We also declare its signature: + let geometric_mean = create_udaf( + // the name; used to represent it in plan descriptions and in the registry, to use in SQL. + "geo_mean", + // the input type; DataFusion guarantees that the first entry of `values` in `update` has this type. + DataType::Float64, + // the return type; DataFusion expects this to match the type returned by `evaluate`. + Arc::new(DataType::Float64), + // This is the accumulator factory; DataFusion uses it to create new accumulators. + Arc::new(|| Ok(Rc::new(RefCell::new(GeometricMean::new())))), + // This is the description of the state. `state()` must match the types here. + Arc::new(vec![DataType::Float64, DataType::UInt32]), + ); + + // get a DataFrame from the context + // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0. + let df = ctx.table("t")?; + + // perform the aggregation + let df = df.aggregate(vec![], vec![geometric_mean.call(vec![col("a")])])?; + + // note that "a" is f32, not f64. DataFusion coerces it to match the UDAF's signature. + + // execute the query + let results = df.collect()?; + + // downcast the array to the expected type + let result = results[0] + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + + // verify that the calculation is correct + assert_eq!(result.value(0), 8.0); + + Ok(()) +} diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs index 09cf89217e9..216c582ffa5 100644 --- a/rust/datafusion/src/execution/context.rs +++ b/rust/datafusion/src/execution/context.rs @@ -30,7 +30,6 @@ use arrow::csv; use arrow::datatypes::*; use arrow::record_batch::RecordBatch; -use crate::dataframe::DataFrame; use crate::datasource::csv::CsvFile; use crate::datasource::parquet::ParquetTable; use crate::datasource::TableProvider; @@ -52,6 +51,7 @@ use crate::sql::{ planner::{SchemaProvider, SqlToRel}, }; use crate::variable::{VarProvider, VarType}; +use crate::{dataframe::DataFrame, physical_plan::udaf::AggregateUDF}; /// ExecutionContext is the main interface for executing queries with DataFusion. The context /// provides the following functionality: @@ -109,6 +109,7 @@ impl ExecutionContext { datasources: HashMap::new(), scalar_functions: HashMap::new(), var_provider: HashMap::new(), + aggregate_functions: HashMap::new(), config, }, }; @@ -195,6 +196,13 @@ impl ExecutionContext { .insert(f.name.clone(), Arc::new(f)); } + /// Register a aggregate UDF + pub fn register_udaf(&mut self, f: AggregateUDF) { + self.state + .aggregate_functions + .insert(f.name.clone(), Arc::new(f)); + } + /// Creates a DataFrame for reading a CSV data source. pub fn read_csv( &mut self, @@ -473,6 +481,8 @@ pub struct ExecutionContextState { pub scalar_functions: HashMap>, /// Variable provider that are registered with the context pub var_provider: HashMap>, + /// Aggregate functions registered in the context + pub aggregate_functions: HashMap>, /// Context configuration pub config: ExecutionConfig, } @@ -487,6 +497,12 @@ impl SchemaProvider for ExecutionContextState { .get(name) .and_then(|func| Some(func.clone())) } + + fn get_aggregate_meta(&self, name: &str) -> Option> { + self.aggregate_functions + .get(name) + .and_then(|func| Some(func.clone())) + } } impl FunctionRegistry for ExecutionContextState { @@ -504,23 +520,38 @@ impl FunctionRegistry for ExecutionContextState { Ok(result.unwrap()) } } + + fn udaf(&self, name: &str) -> Result<&AggregateUDF> { + let result = self.aggregate_functions.get(name); + if result.is_none() { + Err(ExecutionError::General( + format!("There is no UDAF named \"{}\" in the registry", name) + .to_string(), + )) + } else { + Ok(result.unwrap()) + } + } } #[cfg(test)] mod tests { use super::*; - use crate::datasource::MemTable; use crate::logical_plan::{col, create_udf, sum}; use crate::physical_plan::functions::ScalarFunctionImplementation; use crate::test; use crate::variable::VarType; + use crate::{ + datasource::MemTable, logical_plan::create_udaf, + physical_plan::expressions::AvgAccumulator, + }; use arrow::array::{ ArrayRef, Float64Array, Int32Array, PrimitiveArrayOps, StringArray, StringArrayOps, }; use arrow::compute::add; - use std::fs::File; + use std::{cell::RefCell, fs::File, rc::Rc}; use std::{io::prelude::*, sync::Mutex}; use tempfile::TempDir; use test::*; @@ -1200,6 +1231,57 @@ mod tests { Ok(()) } + /// tests the creation, registration and usage of a UDAF + #[test] + fn simple_udaf() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let batch1 = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![1, 2, 3]))], + )?; + let batch2 = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![4, 5]))], + )?; + + let mut ctx = ExecutionContext::new(); + + let provider = MemTable::new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; + ctx.register_table("t", Box::new(provider)); + + // define a udaf, using a DataFusion's accumulator + let my_avg = create_udaf( + "MY_AVG", + DataType::Float64, + Arc::new(DataType::Float64), + Arc::new(|| { + Ok(Rc::new(RefCell::new(AvgAccumulator::try_new( + &DataType::Float64, + )?))) + }), + Arc::new(vec![DataType::UInt64, DataType::Float64]), + ); + + ctx.register_udaf(my_avg); + + let result = collect(&mut ctx, "SELECT MY_AVG(a) FROM t")?; + + let batch = &result[0]; + assert_eq!(1, batch.num_columns()); + assert_eq!(1, batch.num_rows()); + + let values = batch + .column(0) + .as_any() + .downcast_ref::() + .expect("failed to cast version"); + assert_eq!(values.len(), 1); + // avg(1,2,3,4,5) = 3.0 + assert_eq!(values.value(0), 3.0_f64); + Ok(()) + } + #[test] fn custom_query_planner() -> Result<()> { let mut ctx = ExecutionContext::with_config( diff --git a/rust/datafusion/src/logical_plan/mod.rs b/rust/datafusion/src/logical_plan/mod.rs index 30759ffced2..35fbd67073c 100644 --- a/rust/datafusion/src/logical_plan/mod.rs +++ b/rust/datafusion/src/logical_plan/mod.rs @@ -24,6 +24,7 @@ use fmt::Debug; use std::{any::Any, collections::HashSet, fmt, sync::Arc}; +use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use crate::datasource::parquet::ParquetTable; @@ -31,6 +32,7 @@ use crate::datasource::TableProvider; use crate::error::{ExecutionError, Result}; use crate::{ datasource::csv::{CsvFile, CsvReadOptions}, + physical_plan::udaf::AggregateUDF, scalar::ScalarValue, }; use crate::{ @@ -88,6 +90,13 @@ fn create_name(e: &Expr, input_schema: &Schema) -> Result { Expr::AggregateFunction { fun, args, .. } => { create_function_name(&fun.to_string(), args, input_schema) } + Expr::AggregateUDF { fun, args } => { + let mut names = Vec::with_capacity(args.len()); + for e in args { + names.push(create_name(e, input_schema)?); + } + Ok(format!("{}({})", fun.name, names.join(","))) + } other => Err(ExecutionError::NotImplemented(format!( "Physical plan does not support logical expression {:?}", other @@ -179,6 +188,13 @@ pub enum Expr { /// List of expressions to feed to the functions as arguments args: Vec, }, + /// aggregate function + AggregateUDF { + /// The function + fun: Arc, + /// List of expressions to feed to the functions as arguments + args: Vec, + }, /// Represents a reference to all fields in a schema. Wildcard, } @@ -219,6 +235,13 @@ impl Expr { .collect::>>()?; aggregates::return_type(fun, &data_types) } + Expr::AggregateUDF { fun, args, .. } => { + let data_types = args + .iter() + .map(|e| e.get_type(schema)) + .collect::>>()?; + Ok((fun.return_type)(&data_types)?.as_ref().clone()) + } Expr::Not(_) => Ok(DataType::Boolean), Expr::IsNull(_) => Ok(DataType::Boolean), Expr::IsNotNull(_) => Ok(DataType::Boolean), @@ -255,6 +278,7 @@ impl Expr { Expr::ScalarFunction { .. } => Ok(true), Expr::ScalarUDF { .. } => Ok(true), Expr::AggregateFunction { .. } => Ok(true), + Expr::AggregateUDF { .. } => Ok(true), Expr::Not(expr) => expr.nullable(input_schema), Expr::IsNull(_) => Ok(false), Expr::IsNotNull(_) => Ok(false), @@ -568,6 +592,26 @@ pub fn create_udf( ScalarUDF::new(name, &Signature::Exact(input_types), &return_type, &fun) } +/// Creates a new UDAF with a specific signature, state type and return type. +/// The signature and state type must match the `Acumulator's implementation`. +pub fn create_udaf( + name: &str, + input_type: DataType, + return_type: Arc, + accumulator: AccumulatorFunctionImplementation, + state_type: Arc>, +) -> AggregateUDF { + let return_type: ReturnTypeFunction = Arc::new(move |_| Ok(return_type.clone())); + let state_type: StateTypeFunction = Arc::new(move |_| Ok(state_type.clone())); + AggregateUDF::new( + name, + &Signature::Exact(vec![input_type]), + &return_type, + &accumulator, + &state_type, + ) +} + fn fmt_function(f: &mut fmt::Formatter, fun: &String, args: &Vec) -> fmt::Result { let args: Vec = args.iter().map(|arg| format!("{:?}", arg)).collect(); write!(f, "{}({})", fun, args.join(", ")) @@ -612,6 +656,7 @@ impl fmt::Debug for Expr { Expr::AggregateFunction { fun, ref args, .. } => { fmt_function(f, &fun.to_string(), args) } + Expr::AggregateUDF { fun, ref args, .. } => fmt_function(f, &fun.name, args), Expr::Wildcard => write!(f, "*"), Expr::Nested(expr) => write!(f, "({:?})", expr), } @@ -975,6 +1020,9 @@ pub trait FunctionRegistry { /// Returns a reference to the udf named `name`. fn udf(&self, name: &str) -> Result<&ScalarUDF>; + + /// Returns a reference to the udaf named `name`. + fn udaf(&self, name: &str) -> Result<&AggregateUDF>; } /// Builder for logical plans diff --git a/rust/datafusion/src/optimizer/utils.rs b/rust/datafusion/src/optimizer/utils.rs index 788d1e4c2fb..4c44055a9da 100644 --- a/rust/datafusion/src/optimizer/utils.rs +++ b/rust/datafusion/src/optimizer/utils.rs @@ -65,6 +65,7 @@ pub fn expr_to_column_names(expr: &Expr, accum: &mut HashSet) -> Result< Expr::Cast { expr, .. } => expr_to_column_names(expr, accum), Expr::Sort { expr, .. } => expr_to_column_names(expr, accum), Expr::AggregateFunction { args, .. } => exprlist_to_column_names(args, accum), + Expr::AggregateUDF { args, .. } => exprlist_to_column_names(args, accum), Expr::ScalarFunction { args, .. } => exprlist_to_column_names(args, accum), Expr::ScalarUDF { args, .. } => exprlist_to_column_names(args, accum), Expr::Wildcard => Err(ExecutionError::General( @@ -206,6 +207,7 @@ pub fn expr_sub_expressions(expr: &Expr) -> Result> { Expr::ScalarFunction { args, .. } => Ok(args.iter().collect()), Expr::ScalarUDF { args, .. } => Ok(args.iter().collect()), Expr::AggregateFunction { args, .. } => Ok(args.iter().collect()), + Expr::AggregateUDF { args, .. } => Ok(args.iter().collect()), Expr::Cast { expr, .. } => Ok(vec![expr]), Expr::Column(_) => Ok(vec![]), Expr::Alias(expr, ..) => Ok(vec![expr]), @@ -243,6 +245,10 @@ pub fn rewrite_expression(expr: &Expr, expressions: &Vec) -> Result fun: fun.clone(), args: expressions.clone(), }), + Expr::AggregateUDF { fun, .. } => Ok(Expr::AggregateUDF { + fun: fun.clone(), + args: expressions.clone(), + }), Expr::Cast { data_type, .. } => Ok(Expr::Cast { expr: Box::new(expressions[0].clone()), data_type: data_type.clone(), diff --git a/rust/datafusion/src/physical_plan/aggregates.rs b/rust/datafusion/src/physical_plan/aggregates.rs index 4523e13bfba..03833f61a52 100644 --- a/rust/datafusion/src/physical_plan/aggregates.rs +++ b/rust/datafusion/src/physical_plan/aggregates.rs @@ -29,13 +29,22 @@ use super::{ functions::Signature, type_coercion::{coerce, data_types}, - AggregateExpr, PhysicalExpr, + Accumulator, AggregateExpr, PhysicalExpr, }; use crate::error::{ExecutionError, Result}; use crate::physical_plan::expressions; use arrow::datatypes::{DataType, Schema}; use expressions::{avg_return_type, sum_return_type}; -use std::{fmt, str::FromStr, sync::Arc}; +use std::{cell::RefCell, fmt, rc::Rc, str::FromStr, sync::Arc}; + +/// the implementation of an aggregate function +pub type AccumulatorFunctionImplementation = + Arc Result>> + Send + Sync>; + +/// This signature corresponds to which types an aggregator serializes +/// its state, given its return datatype. +pub type StateTypeFunction = + Arc Result>> + Send + Sync>; /// Enum of all built-in scalar functions #[derive(Debug, Clone, PartialEq, Eq)] diff --git a/rust/datafusion/src/physical_plan/expressions.rs b/rust/datafusion/src/physical_plan/expressions.rs index 8b62079711a..0c4eaaf11be 100644 --- a/rust/datafusion/src/physical_plan/expressions.rs +++ b/rust/datafusion/src/physical_plan/expressions.rs @@ -51,7 +51,8 @@ use arrow::{ datatypes::Field, }; -fn format_state_name(name: &str, state_name: &str) -> String { +/// returns the name of the state +pub fn format_state_name(name: &str, state_name: &str) -> String { format!("{}[{}]", name, state_name) } @@ -402,8 +403,9 @@ impl AggregateExpr for Avg { } } +/// An accumulator to compute the average #[derive(Debug)] -struct AvgAccumulator { +pub(crate) struct AvgAccumulator { // sum is used for null sum: ScalarValue, count: u64, diff --git a/rust/datafusion/src/physical_plan/mod.rs b/rust/datafusion/src/physical_plan/mod.rs index df399a13a40..342f2e5a5b9 100644 --- a/rust/datafusion/src/physical_plan/mod.rs +++ b/rust/datafusion/src/physical_plan/mod.rs @@ -202,4 +202,5 @@ pub mod projection; pub mod sort; pub mod string_expressions; pub mod type_coercion; +pub mod udaf; pub mod udf; diff --git a/rust/datafusion/src/physical_plan/planner.rs b/rust/datafusion/src/physical_plan/planner.rs index c7e035cb1ee..f9db3b673d6 100644 --- a/rust/datafusion/src/physical_plan/planner.rs +++ b/rust/datafusion/src/physical_plan/planner.rs @@ -19,7 +19,7 @@ use std::sync::Arc; -use super::{aggregates, empty::EmptyExec, expressions::binary, functions}; +use super::{aggregates, empty::EmptyExec, expressions::binary, functions, udaf}; use crate::error::{ExecutionError, Result}; use crate::execution::context::ExecutionContextState; use crate::logical_plan::{ @@ -474,6 +474,19 @@ impl DefaultPhysicalPlanner { e.name(input_schema)?, ) } + Expr::AggregateUDF { fun, args, .. } => { + let args = args + .iter() + .map(|e| self.create_physical_expr(e, input_schema, ctx_state)) + .collect::>>()?; + + udaf::create_aggregate_expr( + fun, + &args, + input_schema, + e.name(input_schema)?, + ) + } other => Err(ExecutionError::General(format!( "Invalid aggregate expression '{:?}'", other @@ -540,6 +553,7 @@ mod tests { datasources: HashMap::new(), scalar_functions: HashMap::new(), var_provider: HashMap::new(), + aggregate_functions: HashMap::new(), config: ExecutionConfig::new(), } } diff --git a/rust/datafusion/src/physical_plan/udaf.rs b/rust/datafusion/src/physical_plan/udaf.rs new file mode 100644 index 00000000000..933fd237c65 --- /dev/null +++ b/rust/datafusion/src/physical_plan/udaf.rs @@ -0,0 +1,156 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This module contains functions and structs supporting user-defined aggregate functions. + +use fmt::{Debug, Formatter}; +use std::{cell::RefCell, fmt, rc::Rc}; + +use arrow::{ + datatypes::Field, + datatypes::{DataType, Schema}, +}; + +use crate::physical_plan::PhysicalExpr; +use crate::{error::Result, logical_plan::Expr}; + +use super::{ + aggregates::AccumulatorFunctionImplementation, + aggregates::StateTypeFunction, + expressions::format_state_name, + functions::{ReturnTypeFunction, Signature}, + type_coercion::coerce, + Accumulator, AggregateExpr, +}; +use std::sync::Arc; + +/// Logical representation of a user-defined aggregate function (UDAF) +/// A UDAF is different from a UDF in that it is stateful across batches. +#[derive(Clone)] +pub struct AggregateUDF { + /// name + pub name: String, + /// signature + pub signature: Signature, + /// Return type + pub return_type: ReturnTypeFunction, + /// actual implementation + pub accumulator: AccumulatorFunctionImplementation, + /// the accumulator's state's description as a function of the return type + pub state_type: StateTypeFunction, +} + +impl Debug for AggregateUDF { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + f.debug_struct("AggregateUDF") + .field("name", &self.name) + .field("signature", &self.signature) + .field("fun", &"") + .finish() + } +} + +impl AggregateUDF { + /// Create a new AggregateUDF + pub fn new( + name: &str, + signature: &Signature, + return_type: &ReturnTypeFunction, + accumulator: &AccumulatorFunctionImplementation, + state_type: &StateTypeFunction, + ) -> Self { + Self { + name: name.to_owned(), + signature: signature.clone(), + return_type: return_type.clone(), + accumulator: accumulator.clone(), + state_type: state_type.clone(), + } + } + + /// creates a logical expression with a call of the UDAF + /// This utility allows using the UDAF without requiring access to the registry. + pub fn call(&self, args: Vec) -> Expr { + Expr::AggregateUDF { + fun: Arc::new(self.clone()), + args, + } + } +} + +/// Creates a physical expression of the UDAF, that includes all necessary type coercion. +/// This function errors when `args`' can't be coerced to a valid argument type of the UDAF. +pub fn create_aggregate_expr( + fun: &AggregateUDF, + args: &Vec>, + input_schema: &Schema, + name: String, +) -> Result> { + // coerce + let args = coerce(args, input_schema, &fun.signature)?; + + let arg_types = args + .iter() + .map(|arg| arg.data_type(input_schema)) + .collect::>>()?; + + Ok(Arc::new(AggregateFunctionExpr { + fun: fun.clone(), + args: args.clone(), + data_type: (fun.return_type)(&arg_types)?.as_ref().clone(), + name: name.clone(), + })) +} + +/// Physical aggregate expression of a UDAF. +#[derive(Debug)] +pub struct AggregateFunctionExpr { + fun: AggregateUDF, + args: Vec>, + data_type: DataType, + name: String, +} + +impl AggregateExpr for AggregateFunctionExpr { + fn expressions(&self) -> Vec> { + self.args.clone() + } + + fn state_fields(&self) -> Result> { + let fields = (self.fun.state_type)(&self.data_type)? + .iter() + .enumerate() + .map(|(i, data_type)| { + Field::new( + &format_state_name(&self.name, &format!("{}", i)), + data_type.clone(), + true, + ) + }) + .collect::>(); + + Ok(fields) + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, self.data_type.clone(), true)) + } + + fn create_accumulator(&self) -> Result>> { + (self.fun.accumulator)() + } +} diff --git a/rust/datafusion/src/sql/planner.rs b/rust/datafusion/src/sql/planner.rs index 73d341ecdf2..08d63513e72 100644 --- a/rust/datafusion/src/sql/planner.rs +++ b/rust/datafusion/src/sql/planner.rs @@ -20,12 +20,15 @@ use std::str::FromStr; use std::sync::Arc; -use crate::error::{ExecutionError, Result}; use crate::logical_plan::Expr::Alias; use crate::logical_plan::{ lit, Expr, LogicalPlan, LogicalPlanBuilder, Operator, PlanType, StringifiedPlan, }; use crate::scalar::ScalarValue; +use crate::{ + error::{ExecutionError, Result}, + physical_plan::udaf::AggregateUDF, +}; use crate::{ physical_plan::udf::ScalarUDF, physical_plan::{aggregates, functions}, @@ -49,6 +52,8 @@ pub trait SchemaProvider { fn get_table_meta(&self, name: &str) -> Option; /// Getter for a UDF description fn get_function_meta(&self, name: &str) -> Option>; + /// Getter for a UDAF description + fn get_aggregate_meta(&self, name: &str) -> Option>; } /// SQL query planner @@ -537,7 +542,7 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> { return Ok(Expr::AggregateFunction { fun, args }); }; - // finally, user-defined functions + // finally, user-defined functions (UDF) and UDAF match self.schema_provider.get_function_meta(&name) { Some(fm) => { let args = function @@ -551,10 +556,24 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> { args, }) } - _ => Err(ExecutionError::General(format!( - "Invalid function '{}'", - name - ))), + None => match self.schema_provider.get_aggregate_meta(&name) { + Some(fm) => { + let args = function + .args + .iter() + .map(|a| self.sql_to_rex(a, schema)) + .collect::>>()?; + + Ok(Expr::AggregateUDF { + fun: fm.clone(), + args, + }) + } + _ => Err(ExecutionError::General(format!( + "Invalid function '{}'", + name + ))), + }, } } @@ -571,7 +590,7 @@ impl<'a, S: SchemaProvider> SqlToRel<'a, S> { /// Determine if an expression is an aggregate expression or not fn is_aggregate_expr(e: &Expr) -> bool { match e { - Expr::AggregateFunction { .. } => true, + Expr::AggregateFunction { .. } | Expr::AggregateUDF { .. } => true, _ => false, } } @@ -936,5 +955,9 @@ mod tests { _ => None, } } + + fn get_aggregate_meta(&self, _name: &str) -> Option> { + unimplemented!() + } } }