diff --git a/src/catalog_v2/catalog.rs b/src/catalog_v2/catalog.rs index 440014e..001b9bf 100644 --- a/src/catalog_v2/catalog.rs +++ b/src/catalog_v2/catalog.rs @@ -1,8 +1,11 @@ use std::sync::Arc; use super::entry::{CatalogEntry, DataTable}; -use super::{CatalogError, CatalogSet, TableCatalogEntry, TableFunctionCatalogEntry}; -use crate::common::CreateTableFunctionInfo; +use super::{ + CatalogError, CatalogSet, ScalarFunctionCatalogEntry, TableCatalogEntry, + TableFunctionCatalogEntry, +}; +use crate::common::{CreateScalarFunctionInfo, CreateTableFunctionInfo}; use crate::main_entry::ClientContext; /// The Catalog object represents the catalog of the database. @@ -113,4 +116,39 @@ impl Catalog { } Err(CatalogError::CatalogEntryTypeNotMatch) } + + pub fn create_scalar_function( + client_context: Arc, + info: CreateScalarFunctionInfo, + ) -> Result<(), CatalogError> { + let mut catalog = match client_context.db.catalog.try_write() { + Ok(c) => c, + Err(_) => return Err(CatalogError::CatalogLockedError), + }; + let version = catalog.catalog_version; + let entry = catalog.schemas.get_mut_entry(info.base.schema.clone())?; + + if let CatalogEntry::SchemaCatalogEntry(mut_entry) = entry { + mut_entry.create_scalar_function(version + 1, info)?; + catalog.catalog_version += 1; + Ok(()) + } else { + Err(CatalogError::CatalogEntryTypeNotMatch) + } + } + + pub fn get_scalar_function( + client_context: Arc, + schema: String, + scalar_function: String, + ) -> Result { + let catalog = match client_context.db.catalog.try_read() { + Ok(c) => c, + Err(_) => return Err(CatalogError::CatalogLockedError), + }; + if let CatalogEntry::SchemaCatalogEntry(entry) = catalog.schemas.get_entry(schema)? { + return entry.get_scalar_function(scalar_function); + } + Err(CatalogError::CatalogEntryTypeNotMatch) + } } diff --git a/src/catalog_v2/catalog_set.rs b/src/catalog_v2/catalog_set.rs index f5c2cf3..8a27685 100644 --- a/src/catalog_v2/catalog_set.rs +++ b/src/catalog_v2/catalog_set.rs @@ -33,6 +33,15 @@ impl CatalogSet { Err(CatalogError::CatalogEntryNotExists(name)) } + pub fn get_mut_entry(&mut self, name: String) -> Result<&mut CatalogEntry, CatalogError> { + if let Some(index) = self.mapping.get(&name) { + if let Some(entry) = self.entries.get_mut(index) { + return Ok(entry); + } + } + Err(CatalogError::CatalogEntryNotExists(name)) + } + pub fn replace_entry( &mut self, name: String, diff --git a/src/catalog_v2/entry/mod.rs b/src/catalog_v2/entry/mod.rs index ce1ce85..fb992ed 100644 --- a/src/catalog_v2/entry/mod.rs +++ b/src/catalog_v2/entry/mod.rs @@ -1,8 +1,10 @@ +mod scalar_function_catalog_entry; mod schema_catalog_entry; mod table_catalog_entry; mod table_function_catalog_entry; use derive_new::new; +pub use scalar_function_catalog_entry::*; pub use schema_catalog_entry::*; pub use table_catalog_entry::*; pub use table_function_catalog_entry::*; @@ -12,6 +14,7 @@ pub enum CatalogEntry { SchemaCatalogEntry(SchemaCatalogEntry), TableCatalogEntry(TableCatalogEntry), TableFunctionCatalogEntry(TableFunctionCatalogEntry), + ScalarFunctionCatalogEntry(ScalarFunctionCatalogEntry), } impl CatalogEntry { diff --git a/src/catalog_v2/entry/scalar_function_catalog_entry.rs b/src/catalog_v2/entry/scalar_function_catalog_entry.rs new file mode 100644 index 0000000..05adf3a --- /dev/null +++ b/src/catalog_v2/entry/scalar_function_catalog_entry.rs @@ -0,0 +1,12 @@ +use derive_new::new; + +use super::CatalogEntryBase; +use crate::function::ScalarFunction; + +#[derive(new, Clone, Debug)] +pub struct ScalarFunctionCatalogEntry { + #[allow(dead_code)] + pub(crate) base: CatalogEntryBase, + #[allow(dead_code)] + pub(crate) functions: Vec, +} diff --git a/src/catalog_v2/entry/schema_catalog_entry.rs b/src/catalog_v2/entry/schema_catalog_entry.rs index 276fd53..e766c14 100644 --- a/src/catalog_v2/entry/schema_catalog_entry.rs +++ b/src/catalog_v2/entry/schema_catalog_entry.rs @@ -1,7 +1,9 @@ use super::table_catalog_entry::{DataTable, TableCatalogEntry}; -use super::{CatalogEntry, CatalogEntryBase, TableFunctionCatalogEntry}; +use super::{ + CatalogEntry, CatalogEntryBase, ScalarFunctionCatalogEntry, TableFunctionCatalogEntry, +}; use crate::catalog_v2::{CatalogError, CatalogSet}; -use crate::common::CreateTableFunctionInfo; +use crate::common::{CreateScalarFunctionInfo, CreateTableFunctionInfo}; #[allow(dead_code)] #[derive(Clone, Debug)] @@ -76,4 +78,28 @@ impl SchemaCatalogEntry { result.extend(self.functions.scan_entries(callback)); result } + + pub fn create_scalar_function( + &mut self, + oid: usize, + info: CreateScalarFunctionInfo, + ) -> Result<(), CatalogError> { + let entry = ScalarFunctionCatalogEntry::new( + CatalogEntryBase::new(oid, info.name.clone()), + info.functions, + ); + let entry = CatalogEntry::ScalarFunctionCatalogEntry(entry); + self.functions.create_entry(info.name, entry)?; + Ok(()) + } + + pub fn get_scalar_function( + &self, + scalar_function: String, + ) -> Result { + match self.functions.get_entry(scalar_function.clone())? { + CatalogEntry::ScalarFunctionCatalogEntry(e) => Ok(e), + _ => Err(CatalogError::CatalogEntryNotExists(scalar_function)), + } + } } diff --git a/src/common/create_info.rs b/src/common/create_info.rs index 5508454..dcea19d 100644 --- a/src/common/create_info.rs +++ b/src/common/create_info.rs @@ -1,7 +1,7 @@ use derive_new::new; use crate::catalog_v2::ColumnDefinition; -use crate::function::TableFunction; +use crate::function::{ScalarFunction, TableFunction}; #[derive(new, Debug, Clone)] pub struct CreateInfoBase { @@ -25,3 +25,12 @@ pub struct CreateTableFunctionInfo { /// Functions with different arguments pub(crate) functions: Vec, } + +#[derive(new)] +pub struct CreateScalarFunctionInfo { + pub(crate) base: CreateInfoBase, + /// Function name + pub(crate) name: String, + /// Functions with different arguments + pub(crate) functions: Vec, +} diff --git a/src/execution/expression_executor.rs b/src/execution/expression_executor.rs index 7f3aa6c..244dbaf 100644 --- a/src/execution/expression_executor.rs +++ b/src/execution/expression_executor.rs @@ -35,6 +35,15 @@ impl ExpressionExecutor { let options = CastOptions { safe: e.try_cast }; cast_with_options(&child_result, &to_type, &options)? } + BoundExpression::BoundFunctionExpression(e) => { + let children_result = e + .children + .iter() + .map(|c| Self::execute_internal(c, input)) + .collect::, _>>()?; + let func = e.function.function; + func(&children_result)? + } }) } } diff --git a/src/function/mod.rs b/src/function/mod.rs index 5a65524..50388c0 100644 --- a/src/function/mod.rs +++ b/src/function/mod.rs @@ -1,14 +1,16 @@ mod errors; +mod scalar; mod table; use std::sync::Arc; use derive_new::new; pub use errors::*; +pub use scalar::*; pub use table::*; use crate::catalog_v2::{Catalog, DEFAULT_SCHEMA}; -use crate::common::{CreateInfoBase, CreateTableFunctionInfo}; +use crate::common::{CreateInfoBase, CreateScalarFunctionInfo, CreateTableFunctionInfo}; use crate::main_entry::ClientContext; #[derive(Debug, Clone)] @@ -33,9 +35,26 @@ impl BuiltinFunctions { Ok(Catalog::create_table_function(self.context.clone(), info)?) } + pub fn add_scalar_functions( + &mut self, + function_name: String, + functions: Vec, + ) -> Result<(), FunctionError> { + let info = CreateScalarFunctionInfo::new( + CreateInfoBase::new(DEFAULT_SCHEMA.to_string()), + function_name, + functions, + ); + Ok(Catalog::create_scalar_function(self.context.clone(), info)?) + } + pub fn initialize(&mut self) -> Result<(), FunctionError> { SqlrsTablesFunc::register_function(self)?; SqlrsColumnsFunc::register_function(self)?; + AddFunction::register_function(self)?; + SubtractFunction::register_function(self)?; + MultiplyFunction::register_function(self)?; + DivideFunction::register_function(self)?; Ok(()) } } diff --git a/src/function/scalar/arithmetic_function.rs b/src/function/scalar/arithmetic_function.rs new file mode 100644 index 0000000..deb8ae4 --- /dev/null +++ b/src/function/scalar/arithmetic_function.rs @@ -0,0 +1,157 @@ +use std::sync::Arc; + +use arrow::array::{ArrayRef, *}; +use arrow::compute::{add_checked, divide_checked, multiply_checked, subtract_checked}; +use arrow::datatypes::DataType; + +use super::ScalarFunction; +use crate::function::{BuiltinFunctions, FunctionError}; +use crate::types_v2::LogicalType; + +/// Invoke a compute kernel on array(s) +macro_rules! compute_op { + // invoke binary operator + ($LEFT:expr, $RIGHT:expr, $OP:ident, $DT:ident) => {{ + let ll = $LEFT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + let rr = $RIGHT + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + Ok(Arc::new($OP(&ll, &rr)?)) + }}; + // invoke unary operator + ($OPERAND:expr, $OP:ident, $DT:ident) => {{ + let operand = $OPERAND + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + Ok(Arc::new($OP(&operand)?)) + }}; +} + +/// Invoke a compute kernel on a pair of arrays +/// The binary_primitive_array_op macro only evaluates for primitive types +/// like integers and floats. +macro_rules! binary_primitive_array_op { + ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + match $LEFT.data_type() { + DataType::Int8 => compute_op!($LEFT, $RIGHT, $OP, Int8Array), + DataType::Int16 => compute_op!($LEFT, $RIGHT, $OP, Int16Array), + DataType::Int32 => compute_op!($LEFT, $RIGHT, $OP, Int32Array), + DataType::Int64 => compute_op!($LEFT, $RIGHT, $OP, Int64Array), + DataType::UInt8 => compute_op!($LEFT, $RIGHT, $OP, UInt8Array), + DataType::UInt16 => compute_op!($LEFT, $RIGHT, $OP, UInt16Array), + DataType::UInt32 => compute_op!($LEFT, $RIGHT, $OP, UInt32Array), + DataType::UInt64 => compute_op!($LEFT, $RIGHT, $OP, UInt64Array), + DataType::Float32 => compute_op!($LEFT, $RIGHT, $OP, Float32Array), + DataType::Float64 => compute_op!($LEFT, $RIGHT, $OP, Float64Array), + other => Err(FunctionError::InternalError(format!( + "Data type {:?} not supported for binary operation '{}' on primitive arrays", + other, + stringify!($OP) + ))), + } + }}; +} +pub struct AddFunction; + +impl AddFunction { + fn add(inputs: &[ArrayRef]) -> Result { + assert!(inputs.len() == 2); + let left = &inputs[0]; + let right = &inputs[1]; + binary_primitive_array_op!(left, right, add_checked) + } + + pub fn register_function(set: &mut BuiltinFunctions) -> Result<(), FunctionError> { + let mut functions = vec![]; + for ty in LogicalType::numeric().iter() { + functions.push(ScalarFunction::new( + "add".to_string(), + Self::add, + vec![ty.clone(), ty.clone()], + ty.clone(), + )); + } + set.add_scalar_functions("add".to_string(), functions.clone())?; + Ok(()) + } +} + +pub struct SubtractFunction; + +impl SubtractFunction { + fn subtract(inputs: &[ArrayRef]) -> Result { + assert!(inputs.len() == 2); + let left = &inputs[0]; + let right = &inputs[1]; + binary_primitive_array_op!(left, right, subtract_checked) + } + + pub fn register_function(set: &mut BuiltinFunctions) -> Result<(), FunctionError> { + let mut functions = vec![]; + for ty in LogicalType::numeric().iter() { + functions.push(ScalarFunction::new( + "subtract".to_string(), + Self::subtract, + vec![ty.clone(), ty.clone()], + ty.clone(), + )); + } + set.add_scalar_functions("subtract".to_string(), functions.clone())?; + Ok(()) + } +} + +pub struct MultiplyFunction; + +impl MultiplyFunction { + fn multiply(inputs: &[ArrayRef]) -> Result { + assert!(inputs.len() == 2); + let left = &inputs[0]; + let right = &inputs[1]; + binary_primitive_array_op!(left, right, multiply_checked) + } + + pub fn register_function(set: &mut BuiltinFunctions) -> Result<(), FunctionError> { + let mut functions = vec![]; + for ty in LogicalType::numeric().iter() { + functions.push(ScalarFunction::new( + "multiply".to_string(), + Self::multiply, + vec![ty.clone(), ty.clone()], + ty.clone(), + )); + } + set.add_scalar_functions("multiply".to_string(), functions.clone())?; + Ok(()) + } +} + +pub struct DivideFunction; + +impl DivideFunction { + fn divide(inputs: &[ArrayRef]) -> Result { + assert!(inputs.len() == 2); + let left = &inputs[0]; + let right = &inputs[1]; + binary_primitive_array_op!(left, right, divide_checked) + } + + pub fn register_function(set: &mut BuiltinFunctions) -> Result<(), FunctionError> { + let mut functions = vec![]; + for ty in LogicalType::numeric().iter() { + functions.push(ScalarFunction::new( + "divide".to_string(), + Self::divide, + vec![ty.clone(), ty.clone()], + ty.clone(), + )); + } + set.add_scalar_functions("divide".to_string(), functions.clone())?; + Ok(()) + } +} diff --git a/src/function/scalar/mod.rs b/src/function/scalar/mod.rs new file mode 100644 index 0000000..6c903ac --- /dev/null +++ b/src/function/scalar/mod.rs @@ -0,0 +1,4 @@ +mod arithmetic_function; +mod scalar_function; +pub use arithmetic_function::*; +pub use scalar_function::*; diff --git a/src/function/scalar/scalar_function.rs b/src/function/scalar/scalar_function.rs new file mode 100644 index 0000000..b8ed5d3 --- /dev/null +++ b/src/function/scalar/scalar_function.rs @@ -0,0 +1,32 @@ +use arrow::array::ArrayRef; +use derive_new::new; + +use crate::function::FunctionError; +use crate::types_v2::LogicalType; + +// pub type ScalarFunc = fn(left: &ArrayRef, right: &ArrayRef) -> Result; +pub type ScalarFunc = fn(inputs: &[ArrayRef]) -> Result; + +#[derive(new, Clone)] +pub struct ScalarFunction { + // The name of the function + pub(crate) name: String, + /// The main scalar function to execute + pub(crate) function: ScalarFunc, + /// The set of arguments of the function + pub(crate) arguments: Vec, + /// Return type of the function + pub(crate) return_type: LogicalType, +} + +impl std::fmt::Debug for ScalarFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ScalarFunction") + .field("name", &self.name) + .field( + "types", + &format!("{:?} -> {:?}", self.arguments, self.return_type), + ) + .finish() + } +} diff --git a/src/planner_v2/binder/errors.rs b/src/planner_v2/binder/errors.rs index 8cf6922..39290dd 100644 --- a/src/planner_v2/binder/errors.rs +++ b/src/planner_v2/binder/errors.rs @@ -11,6 +11,8 @@ pub enum BindError { SqlParserUnsupportedStmt(String), #[error("bind internal error: {0}")] Internal(String), + #[error("{0}")] + FunctionBindError(String), #[error("type error: {0}")] TypeError( #[from] diff --git a/src/planner_v2/binder/expression/bind_function_expression.rs b/src/planner_v2/binder/expression/bind_function_expression.rs new file mode 100644 index 0000000..d67c2d0 --- /dev/null +++ b/src/planner_v2/binder/expression/bind_function_expression.rs @@ -0,0 +1,53 @@ +use derive_new::new; + +use super::{BoundExpression, BoundExpressionBase}; +use crate::catalog_v2::{Catalog, DEFAULT_SCHEMA}; +use crate::function::ScalarFunction; +use crate::planner_v2::{BindError, ExpressionBinder, FunctionBinder}; +use crate::types_v2::LogicalType; + +#[derive(new, Debug, Clone)] +pub struct BoundFunctionExpression { + pub(crate) base: BoundExpressionBase, + /// The bound function expression + pub(crate) function: ScalarFunction, + /// List of child-expressions of the function + pub(crate) children: Vec, +} + +impl ExpressionBinder<'_> { + pub fn bind_function_expression( + &mut self, + left: &sqlparser::ast::Expr, + op: &sqlparser::ast::BinaryOperator, + right: &sqlparser::ast::Expr, + result_names: &mut Vec, + result_types: &mut Vec, + ) -> Result { + let function_name = match op { + sqlparser::ast::BinaryOperator::Plus => "add", + sqlparser::ast::BinaryOperator::Minus => "subtract", + sqlparser::ast::BinaryOperator::Multiply => "multiply", + sqlparser::ast::BinaryOperator::Divide => "divide", + other => { + return Err(BindError::Internal(format!( + "unexpected binary operator {} for function expression", + other + ))) + } + }; + let function = Catalog::get_scalar_function( + self.binder.clone_client_context(), + DEFAULT_SCHEMA.to_string(), + function_name.to_string(), + )?; + let mut return_names = vec![]; + let left = self.bind_expression(left, &mut return_names, &mut vec![])?; + let right = self.bind_expression(right, &mut return_names, &mut vec![])?; + let func_binder = FunctionBinder::new(); + let bound_function = func_binder.bind_scalar_function(function, vec![left, right])?; + result_names.push(format!("{}({})", function_name, return_names.join(", "))); + result_types.push(bound_function.base.return_type.clone()); + Ok(BoundExpression::BoundFunctionExpression(bound_function)) + } +} diff --git a/src/planner_v2/binder/expression/mod.rs b/src/planner_v2/binder/expression/mod.rs index 22aabae..e5b224f 100644 --- a/src/planner_v2/binder/expression/mod.rs +++ b/src/planner_v2/binder/expression/mod.rs @@ -1,12 +1,14 @@ mod bind_cast_expression; mod bind_column_ref_expression; mod bind_constant_expression; +mod bind_function_expression; mod bind_reference_expression; mod column_binding; pub use bind_cast_expression::*; pub use bind_column_ref_expression::*; pub use bind_constant_expression::*; +pub use bind_function_expression::*; pub use bind_reference_expression::*; pub use column_binding::*; use derive_new::new; @@ -27,6 +29,7 @@ pub enum BoundExpression { BoundConstantExpression(BoundConstantExpression), BoundReferenceExpression(BoundReferenceExpression), BoundCastExpression(BoundCastExpression), + BoundFunctionExpression(BoundFunctionExpression), } impl BoundExpression { @@ -36,6 +39,7 @@ impl BoundExpression { BoundExpression::BoundConstantExpression(expr) => expr.base.return_type.clone(), BoundExpression::BoundReferenceExpression(expr) => expr.base.return_type.clone(), BoundExpression::BoundCastExpression(expr) => expr.base.return_type.clone(), + BoundExpression::BoundFunctionExpression(expr) => expr.base.return_type.clone(), } } @@ -45,6 +49,7 @@ impl BoundExpression { BoundExpression::BoundConstantExpression(expr) => expr.base.alias.clone(), BoundExpression::BoundReferenceExpression(expr) => expr.base.alias.clone(), BoundExpression::BoundCastExpression(expr) => expr.base.alias.clone(), + BoundExpression::BoundFunctionExpression(expr) => expr.base.alias.clone(), } } @@ -54,6 +59,7 @@ impl BoundExpression { BoundExpression::BoundConstantExpression(expr) => expr.base.alias = alias, BoundExpression::BoundReferenceExpression(expr) => expr.base.alias = alias, BoundExpression::BoundCastExpression(expr) => expr.base.alias = alias, + BoundExpression::BoundFunctionExpression(expr) => expr.base.alias = alias, } } } diff --git a/src/planner_v2/binder/statement/bind_select.rs b/src/planner_v2/binder/statement/bind_select.rs index dea07f3..0d53ddd 100644 --- a/src/planner_v2/binder/statement/bind_select.rs +++ b/src/planner_v2/binder/statement/bind_select.rs @@ -8,7 +8,6 @@ impl Binder { match stmt { Statement::Query(query) => { let node = self.bind_select_node(query)?; - // println!("bind context: {:#?}", self.bind_context); self.create_plan_for_select_node(node) } _ => Err(BindError::UnsupportedStmt(format!("{:?}", stmt))), diff --git a/src/planner_v2/expression_binder.rs b/src/planner_v2/expression_binder.rs index 4fef4e8..3c77e94 100644 --- a/src/planner_v2/expression_binder.rs +++ b/src/planner_v2/expression_binder.rs @@ -24,7 +24,9 @@ impl ExpressionBinder<'_> { sqlparser::ast::Expr::CompoundIdentifier(idents) => { self.bind_column_ref_expr(idents, result_names, result_types) } - sqlparser::ast::Expr::BinaryOp { .. } => todo!(), + sqlparser::ast::Expr::BinaryOp { left, op, right } => { + self.bind_binary_op_internal(left, op, right, result_names, result_types) + } sqlparser::ast::Expr::UnaryOp { .. } => todo!(), sqlparser::ast::Expr::Value(v) => { self.bind_constant_expr(v, result_names, result_types) @@ -35,4 +37,31 @@ impl ExpressionBinder<'_> { _ => todo!(), } } + + fn bind_binary_op_internal( + &mut self, + left: &sqlparser::ast::Expr, + op: &sqlparser::ast::BinaryOperator, + right: &sqlparser::ast::Expr, + result_names: &mut Vec, + result_types: &mut Vec, + ) -> Result { + match op { + sqlparser::ast::BinaryOperator::Plus + | sqlparser::ast::BinaryOperator::Minus + | sqlparser::ast::BinaryOperator::Multiply + | sqlparser::ast::BinaryOperator::Divide => { + self.bind_function_expression(left, op, right, result_names, result_types) + } + sqlparser::ast::BinaryOperator::Gt => todo!(), + sqlparser::ast::BinaryOperator::Lt => todo!(), + sqlparser::ast::BinaryOperator::GtEq => todo!(), + sqlparser::ast::BinaryOperator::LtEq => todo!(), + sqlparser::ast::BinaryOperator::Eq => todo!(), + sqlparser::ast::BinaryOperator::NotEq => todo!(), + sqlparser::ast::BinaryOperator::And => todo!(), + sqlparser::ast::BinaryOperator::Or => todo!(), + other => Err(BindError::UnsupportedExpr(other.to_string())), + } + } } diff --git a/src/planner_v2/expression_iterator.rs b/src/planner_v2/expression_iterator.rs index b175dd0..9beb092 100644 --- a/src/planner_v2/expression_iterator.rs +++ b/src/planner_v2/expression_iterator.rs @@ -14,6 +14,7 @@ impl ExpressionIterator { // these node types have no children } BoundExpression::BoundCastExpression(e) => callback(&mut e.child), + BoundExpression::BoundFunctionExpression(e) => e.children.iter_mut().for_each(callback), } } } diff --git a/src/planner_v2/function_binder.rs b/src/planner_v2/function_binder.rs new file mode 100644 index 0000000..64d8ebe --- /dev/null +++ b/src/planner_v2/function_binder.rs @@ -0,0 +1,78 @@ +use derive_new::new; + +use super::{BindError, BoundExpressionBase, INVALID_INDEX}; +use crate::catalog_v2::ScalarFunctionCatalogEntry; +use crate::function::ScalarFunction; +use crate::planner_v2::{BoundExpression, BoundFunctionExpression}; +use crate::types_v2::LogicalType; + +/// Find the function with matching parameters from the function list. +#[derive(new)] +pub struct FunctionBinder; + +impl FunctionBinder { + pub fn bind_scalar_function( + &self, + func: ScalarFunctionCatalogEntry, + children: Vec, + ) -> Result { + let arguments = self.get_logical_types_from_expressions(&children); + let best_func_idx = self.bind_function_from_arguments(&func, &arguments)?; + let bound_function = func.functions[best_func_idx].clone(); + let base = BoundExpressionBase::new("".to_string(), bound_function.return_type.clone()); + Ok(BoundFunctionExpression::new(base, bound_function, children)) + } + + fn get_logical_types_from_expressions(&self, children: &[BoundExpression]) -> Vec { + children.iter().map(|c| c.return_type()).collect() + } + + fn bind_function_from_arguments( + &self, + func: &ScalarFunctionCatalogEntry, + arguments: &[LogicalType], + ) -> Result { + let mut candidate_functions = vec![]; + let mut best_function_idx = INVALID_INDEX; + for (func_idx, each_func) in func.functions.iter().enumerate() { + let cost = self.bind_function_cost(each_func, arguments); + if cost < 0 { + continue; + } + candidate_functions.push(func_idx); + best_function_idx = func_idx; + } + + if best_function_idx == INVALID_INDEX { + return Err(BindError::FunctionBindError(format!( + "No function matched for given function and arguments {} {:?}", + func.base.name, arguments + ))); + } + + if candidate_functions.len() > 1 { + return Err(BindError::FunctionBindError(format!( + "Ambiguous function call for function {} and arguments {:?}", + func.base.name, arguments + ))); + } + + Ok(candidate_functions[0]) + } + + fn bind_function_cost(&self, func: &ScalarFunction, arguments: &[LogicalType]) -> i32 { + if func.arguments.len() != arguments.len() { + // invalid argument count: check the next function + return -1; + } + let cost = 0; + // TODO: use cast function to infer the cost and choose the best matched function. + for (i, arg) in arguments.iter().enumerate() { + if func.arguments[i] != *arg { + // invalid argument count: check the next function + return -1; + } + } + cost + } +} diff --git a/src/planner_v2/logical_operator_visitor.rs b/src/planner_v2/logical_operator_visitor.rs index 0d1a4fc..6bb078b 100644 --- a/src/planner_v2/logical_operator_visitor.rs +++ b/src/planner_v2/logical_operator_visitor.rs @@ -1,6 +1,6 @@ use super::{ BoundCastExpression, BoundColumnRefExpression, BoundConstantExpression, BoundExpression, - BoundReferenceExpression, ExpressionIterator, LogicalOperator, + BoundFunctionExpression, BoundReferenceExpression, ExpressionIterator, LogicalOperator, }; /// Visitor pattern on logical operators, also includes rewrite expression ability. @@ -35,6 +35,7 @@ pub trait LogicalOperatorVisitor { BoundExpression::BoundConstantExpression(e) => self.visit_replace_constant(e), BoundExpression::BoundReferenceExpression(e) => self.visit_replace_reference(e), BoundExpression::BoundCastExpression(e) => self.visit_replace_cast(e), + BoundExpression::BoundFunctionExpression(e) => self.visit_function_expression(e), }; if let Some(new_expr) = result { *expr = new_expr; @@ -59,4 +60,7 @@ pub trait LogicalOperatorVisitor { fn visit_replace_cast(&self, _: &BoundCastExpression) -> Option { None } + fn visit_function_expression(&self, _: &BoundFunctionExpression) -> Option { + None + } } diff --git a/src/planner_v2/mod.rs b/src/planner_v2/mod.rs index 0ebbdd5..0ae769f 100644 --- a/src/planner_v2/mod.rs +++ b/src/planner_v2/mod.rs @@ -3,6 +3,7 @@ mod constants; mod errors; mod expression_binder; mod expression_iterator; +mod function_binder; mod logical_operator_visitor; mod operator; @@ -13,6 +14,7 @@ pub use constants::*; pub use errors::*; pub use expression_binder::*; pub use expression_iterator::*; +pub use function_binder::*; use log::debug; pub use logical_operator_visitor::*; pub use operator::*; diff --git a/src/types_v2/types.rs b/src/types_v2/types.rs index 27fd9f8..ef16dcc 100644 --- a/src/types_v2/types.rs +++ b/src/types_v2/types.rs @@ -23,6 +23,21 @@ pub enum LogicalType { } impl LogicalType { + pub fn numeric() -> Vec { + vec![ + LogicalType::Tinyint, + LogicalType::UTinyint, + LogicalType::Smallint, + LogicalType::USmallint, + LogicalType::Integer, + LogicalType::UInteger, + LogicalType::Bigint, + LogicalType::UBigint, + LogicalType::Float, + LogicalType::Double, + ] + } + pub fn is_numeric(&self) -> bool { matches!( self, diff --git a/src/util/tree_render.rs b/src/util/tree_render.rs index 6d40b77..1f13fa5 100644 --- a/src/util/tree_render.rs +++ b/src/util/tree_render.rs @@ -37,6 +37,15 @@ impl TreeRender { e.base.return_type, ) } + BoundExpression::BoundFunctionExpression(e) => { + let args = e + .children + .iter() + .map(Self::bound_expression_to_string) + .collect::>() + .join(", "); + format!("{}({}])", e.function.name, args) + } } } diff --git a/tests/slt/scalar_function.slt b/tests/slt/scalar_function.slt new file mode 100644 index 0000000..60eef74 --- /dev/null +++ b/tests/slt/scalar_function.slt @@ -0,0 +1,40 @@ +onlyif sqlrs_v2 +statement ok +CREATE TABLE test(a integer); +insert into test values (1), (2), (3), (NULL); + +onlyif sqlrs_v2 +query I +select a+a from test +---- +2 +4 +6 +NULL + +onlyif sqlrs_v2 +query I +select a-a from test +---- +0 +0 +0 +NULL + +onlyif sqlrs_v2 +query I +select a*a from test +---- +1 +4 +9 +NULL + +onlyif sqlrs_v2 +query I +select a/a from test +---- +1 +1 +1 +NULL