-
Notifications
You must be signed in to change notification settings - Fork 4k
ARROW-9836: [Rust][DataFusion] Improve API for usage of UDFs #8032
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,14 +21,15 @@ | |
| //! Logical query plans can then be optimized and executed directly, or translated into | ||
| //! physical query plans and executed. | ||
|
|
||
| use std::{fmt, sync::Arc}; | ||
| use std::{collections::HashSet, fmt, sync::Arc}; | ||
|
|
||
| use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; | ||
|
|
||
| use crate::datasource::csv::{CsvFile, CsvReadOptions}; | ||
| use crate::datasource::parquet::ParquetTable; | ||
| use crate::datasource::TableProvider; | ||
| use crate::error::{ExecutionError, Result}; | ||
| use crate::physical_plan::udf; | ||
| use crate::{ | ||
| physical_plan::{ | ||
| expressions::binary_operator_data_type, functions, type_coercion::can_coerce_from, | ||
|
|
@@ -199,12 +200,12 @@ fn create_name(e: &Expr, input_schema: &Schema) -> Result<String> { | |
| } | ||
| Ok(format!("{}({})", fun, names.join(","))) | ||
| } | ||
| Expr::ScalarUDF { name, args, .. } => { | ||
| Expr::ScalarUDF { fun, args, .. } => { | ||
| let mut names = Vec::with_capacity(args.len()); | ||
| for e in args { | ||
| names.push(create_name(e, input_schema)?); | ||
| } | ||
| Ok(format!("{}({})", name, names.join(","))) | ||
| Ok(format!("{}({})", fun.name, names.join(","))) | ||
| } | ||
| Expr::AggregateFunction { name, args, .. } => { | ||
| let mut names = Vec::with_capacity(args.len()); | ||
|
|
@@ -226,7 +227,7 @@ pub fn exprlist_to_fields(expr: &[Expr], input_schema: &Schema) -> Result<Vec<Fi | |
| } | ||
|
|
||
| /// Relation expression | ||
| #[derive(Clone, PartialEq)] | ||
| #[derive(Clone)] | ||
| pub enum Expr { | ||
| /// An aliased expression | ||
| Alias(Box<Expr>, String), | ||
|
|
@@ -276,12 +277,10 @@ pub enum Expr { | |
| }, | ||
| /// scalar udf. | ||
| ScalarUDF { | ||
| /// The function's name | ||
| name: String, | ||
| /// The function | ||
| fun: Arc<udf::ScalarFunction>, | ||
| /// List of expressions to feed to the functions as arguments | ||
| args: Vec<Expr>, | ||
| /// The `DataType` the expression will yield | ||
| return_type: DataType, | ||
| }, | ||
| /// aggregate function | ||
| AggregateFunction { | ||
|
|
@@ -302,7 +301,7 @@ impl Expr { | |
| Expr::Column(name) => Ok(schema.field_with_name(name)?.data_type().clone()), | ||
| Expr::Literal(l) => l.get_datatype(), | ||
| Expr::Cast { data_type, .. } => Ok(data_type.clone()), | ||
| Expr::ScalarUDF { return_type, .. } => Ok(return_type.clone()), | ||
| Expr::ScalarUDF { fun, .. } => Ok(fun.return_type.clone()), | ||
| Expr::ScalarFunction { fun, args } => { | ||
| let data_types = args | ||
| .iter() | ||
|
|
@@ -686,15 +685,6 @@ pub fn aggregate_expr(name: &str, expr: Expr) -> Expr { | |
| } | ||
| } | ||
|
|
||
| /// call a scalar UDF | ||
| pub fn scalar_function(name: &str, expr: Vec<Expr>, return_type: DataType) -> Expr { | ||
| Expr::ScalarUDF { | ||
| name: name.to_owned(), | ||
| args: expr, | ||
| return_type, | ||
| } | ||
| } | ||
|
|
||
| impl fmt::Debug for Expr { | ||
| fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | ||
| match self { | ||
|
|
@@ -737,8 +727,8 @@ impl fmt::Debug for Expr { | |
|
|
||
| write!(f, ")") | ||
| } | ||
| Expr::ScalarUDF { name, ref args, .. } => { | ||
| write!(f, "{}(", name)?; | ||
| Expr::ScalarUDF { fun, ref args, .. } => { | ||
| write!(f, "{}(", fun.name)?; | ||
| for i in 0..args.len() { | ||
| if i > 0 { | ||
| write!(f, ", ")?; | ||
|
|
@@ -1038,6 +1028,15 @@ impl fmt::Debug for LogicalPlan { | |
| } | ||
| } | ||
|
|
||
| /// A registry knows how to build logical expressions out of user-defined function' names | ||
| pub trait FunctionRegistry { | ||
| /// Set of all available udfs. | ||
| fn udfs(&self) -> HashSet<String>; | ||
|
|
||
| /// Constructs a logical expression with a call to the udf. | ||
| fn udf(&self, name: &str, args: Vec<Expr>) -> Result<Expr>; | ||
| } | ||
|
|
||
| /// Builder for logical plans | ||
| pub struct LogicalPlanBuilder { | ||
| plan: LogicalPlan, | ||
|
|
@@ -1137,19 +1136,14 @@ impl LogicalPlanBuilder { | |
| /// Apply a projection | ||
| pub fn project(&self, expr: Vec<Expr>) -> Result<Self> { | ||
| let input_schema = self.plan.schema(); | ||
| let projected_expr = if expr.contains(&Expr::Wildcard) { | ||
| let mut expr_vec = vec![]; | ||
| (0..expr.len()).for_each(|i| match &expr[i] { | ||
| Expr::Wildcard => { | ||
| (0..input_schema.fields().len()) | ||
| .for_each(|i| expr_vec.push(col(input_schema.field(i).name()))); | ||
| } | ||
| _ => expr_vec.push(expr[i].clone()), | ||
| }); | ||
| expr_vec | ||
| } else { | ||
| expr.clone() | ||
| }; | ||
| let mut projected_expr = vec![]; | ||
|
||
| (0..expr.len()).for_each(|i| match &expr[i] { | ||
| Expr::Wildcard => { | ||
| (0..input_schema.fields().len()) | ||
| .for_each(|i| projected_expr.push(col(input_schema.field(i).name()))); | ||
| } | ||
| _ => projected_expr.push(expr[i].clone()), | ||
| }); | ||
|
|
||
| let schema = | ||
| Schema::new(exprlist_to_fields(&projected_expr, input_schema.as_ref())?); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this registry specific to scalar functions or will it also be used for aggregate functions? Perhaps we should name the method either
function_registryorscalar_function_registry?