diff --git a/src/execution/expression_executor.rs b/src/execution/expression_executor.rs index 244dbaf..8c0a82c 100644 --- a/src/execution/expression_executor.rs +++ b/src/execution/expression_executor.rs @@ -1,5 +1,4 @@ use arrow::array::ArrayRef; -use arrow::compute::{cast_with_options, CastOptions}; use arrow::record_batch::RecordBatch; use super::ExecutorError; @@ -31,9 +30,8 @@ impl ExpressionExecutor { BoundExpression::BoundReferenceExpression(e) => input.column(e.index).clone(), BoundExpression::BoundCastExpression(e) => { let child_result = Self::execute_internal(&e.child, input)?; - let to_type = e.base.return_type.clone().into(); - let options = CastOptions { safe: e.try_cast }; - cast_with_options(&child_result, &to_type, &options)? + let cast_function = e.function.function; + cast_function(&child_result, &e.base.return_type, e.try_cast)? } BoundExpression::BoundFunctionExpression(e) => { let children_result = e diff --git a/src/function/cast/cast_function.rs b/src/function/cast/cast_function.rs new file mode 100644 index 0000000..ab7e701 --- /dev/null +++ b/src/function/cast/cast_function.rs @@ -0,0 +1,26 @@ +use arrow::array::ArrayRef; +use derive_new::new; + +use crate::function::FunctionError; +use crate::types_v2::LogicalType; + +pub type CastFunc = + fn(array: &ArrayRef, to_type: &LogicalType, try_cast: bool) -> Result; + +#[derive(new, Clone)] +pub struct CastFunction { + /// The source type of the cast + pub(crate) source: LogicalType, + /// The target type of the cast + pub(crate) target: LogicalType, + /// The main cast function to execute + pub(crate) function: CastFunc, +} + +impl std::fmt::Debug for CastFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CastFunction") + .field("cast", &format!("{:?} -> {:?}", self.source, self.target)) + .finish() + } +} diff --git a/src/function/cast/cast_rules.rs b/src/function/cast/cast_rules.rs new file mode 100644 index 0000000..8be90a4 --- /dev/null +++ b/src/function/cast/cast_rules.rs @@ -0,0 +1,15 @@ +use crate::types_v2::LogicalType; + +pub struct CastRules; + +impl CastRules { + pub fn implicit_cast_cost(from: &LogicalType, to: &LogicalType) -> i32 { + if from == to { + 0 + } else if LogicalType::can_implicit_cast(from, to) { + 1 + } else { + -1 + } + } +} diff --git a/src/function/cast/default_cast.rs b/src/function/cast/default_cast.rs new file mode 100644 index 0000000..acc854a --- /dev/null +++ b/src/function/cast/default_cast.rs @@ -0,0 +1,37 @@ +use arrow::array::ArrayRef; +use arrow::compute::{cast_with_options, CastOptions}; + +use super::CastFunction; +use crate::function::FunctionError; +use crate::types_v2::LogicalType; + +pub struct DefaultCastFunctions; + +impl DefaultCastFunctions { + fn default_cast_function( + array: &ArrayRef, + to_type: &LogicalType, + try_cast: bool, + ) -> Result { + let to_type = to_type.clone().into(); + let options = CastOptions { safe: try_cast }; + Ok(cast_with_options(array, &to_type, &options)?) + } + + pub fn get_cast_function( + source: &LogicalType, + target: &LogicalType, + ) -> Result { + assert!(source != target); + match source { + LogicalType::Invalid => { + Err(FunctionError::CastError("Invalid source type".to_string())) + } + _ => Ok(CastFunction::new( + source.clone(), + target.clone(), + Self::default_cast_function, + )), + } + } +} diff --git a/src/function/cast/mod.rs b/src/function/cast/mod.rs new file mode 100644 index 0000000..6c1cfe5 --- /dev/null +++ b/src/function/cast/mod.rs @@ -0,0 +1,7 @@ +mod cast_function; +mod cast_rules; +mod default_cast; + +pub use cast_function::*; +pub use cast_rules::*; +pub use default_cast::*; diff --git a/src/function/errors.rs b/src/function/errors.rs index 8f2746a..70327cb 100644 --- a/src/function/errors.rs +++ b/src/function/errors.rs @@ -25,4 +25,6 @@ pub enum FunctionError { ), #[error("Internal error: {0}")] InternalError(String), + #[error("Cast error: {0}")] + CastError(String), } diff --git a/src/function/mod.rs b/src/function/mod.rs index 50388c0..5d28890 100644 --- a/src/function/mod.rs +++ b/src/function/mod.rs @@ -1,9 +1,11 @@ +mod cast; mod errors; mod scalar; mod table; use std::sync::Arc; +pub use cast::*; use derive_new::new; pub use errors::*; pub use scalar::*; diff --git a/src/function/scalar/scalar_function.rs b/src/function/scalar/scalar_function.rs index b8ed5d3..6cff15a 100644 --- a/src/function/scalar/scalar_function.rs +++ b/src/function/scalar/scalar_function.rs @@ -4,7 +4,6 @@ use derive_new::new; use crate::function::FunctionError; use crate::types_v2::LogicalType; -// pub type ScalarFunc = fn(left: &ArrayRef, right: &ArrayRef) -> Result; pub type ScalarFunc = fn(inputs: &[ArrayRef]) -> Result; #[derive(new, Clone)] diff --git a/src/planner_v2/binder/expression/bind_cast_expression.rs b/src/planner_v2/binder/expression/bind_cast_expression.rs index 0b18384..6abba1f 100644 --- a/src/planner_v2/binder/expression/bind_cast_expression.rs +++ b/src/planner_v2/binder/expression/bind_cast_expression.rs @@ -1,6 +1,8 @@ use derive_new::new; use super::{BoundExpression, BoundExpressionBase}; +use crate::function::{CastFunction, DefaultCastFunctions}; +use crate::planner_v2::BindError; use crate::types_v2::LogicalType; #[derive(new, Debug, Clone)] @@ -11,6 +13,8 @@ pub struct BoundCastExpression { /// Whether to use try_cast or not. try_cast converts cast failures into NULLs instead of /// throwing an error. pub(crate) try_cast: bool, + /// The cast function to execute + pub(crate) function: CastFunction, } impl BoundCastExpression { @@ -19,15 +23,13 @@ impl BoundCastExpression { target_type: LogicalType, alias: String, try_cast: bool, - ) -> BoundExpression { - if expr.return_type() == target_type { - return expr; - } + ) -> Result { + let source_type = expr.return_type(); + assert!(source_type != target_type); + let cast_function = DefaultCastFunctions::get_cast_function(&source_type, &target_type)?; let base = BoundExpressionBase::new(alias, target_type); - BoundExpression::BoundCastExpression(BoundCastExpression::new( - base, - Box::new(expr), - try_cast, + Ok(BoundExpression::BoundCastExpression( + BoundCastExpression::new(base, Box::new(expr), try_cast, cast_function), )) } } diff --git a/src/planner_v2/binder/query_node/plan_select_node.rs b/src/planner_v2/binder/query_node/plan_select_node.rs index 8f23324..808c010 100644 --- a/src/planner_v2/binder/query_node/plan_select_node.rs +++ b/src/planner_v2/binder/query_node/plan_select_node.rs @@ -55,7 +55,7 @@ impl Binder { target_type.clone(), alias, false, - ); + )?; node.base.types[idx] = target_type.clone(); } } diff --git a/src/planner_v2/binder/tableref/bind_expression_list_ref.rs b/src/planner_v2/binder/tableref/bind_expression_list_ref.rs index 2929d90..a68f078 100644 --- a/src/planner_v2/binder/tableref/bind_expression_list_ref.rs +++ b/src/planner_v2/binder/tableref/bind_expression_list_ref.rs @@ -61,14 +61,14 @@ impl Binder { // insert values contains SqlNull, the expr should be cast to the max logical type for exprs in bound_expr_list.iter_mut() { for (idx, bound_expr) in exprs.iter_mut().enumerate() { - if bound_expr.return_type() == LogicalType::SqlNull { + if bound_expr.return_type() != types[idx] { let alias = bound_expr.alias().clone(); *bound_expr = BoundCastExpression::add_cast_to_type( bound_expr.clone(), types[idx].clone(), alias, false, - ) + )? } } } diff --git a/src/planner_v2/function_binder.rs b/src/planner_v2/function_binder.rs index 64d8ebe..9fbae1c 100644 --- a/src/planner_v2/function_binder.rs +++ b/src/planner_v2/function_binder.rs @@ -1,8 +1,8 @@ use derive_new::new; -use super::{BindError, BoundExpressionBase, INVALID_INDEX}; +use super::{BindError, BoundCastExpression, BoundExpressionBase, INVALID_INDEX}; use crate::catalog_v2::ScalarFunctionCatalogEntry; -use crate::function::ScalarFunction; +use crate::function::{CastRules, ScalarFunction}; use crate::planner_v2::{BoundExpression, BoundFunctionExpression}; use crate::types_v2::LogicalType; @@ -16,11 +16,20 @@ impl FunctionBinder { func: ScalarFunctionCatalogEntry, children: Vec, ) -> Result { + // bind the function let arguments = self.get_logical_types_from_expressions(&children); + // found a matching function! let best_func_idx = self.bind_function_from_arguments(&func, &arguments)?; let bound_function = func.functions[best_func_idx].clone(); + // check if we need to add casts to the children + let new_children = self.cast_to_function_arguments(&bound_function, children)?; + // now create the function let base = BoundExpressionBase::new("".to_string(), bound_function.return_type.clone()); - Ok(BoundFunctionExpression::new(base, bound_function, children)) + Ok(BoundFunctionExpression::new( + base, + bound_function, + new_children, + )) } fn get_logical_types_from_expressions(&self, children: &[BoundExpression]) -> Vec { @@ -34,13 +43,28 @@ impl FunctionBinder { ) -> Result { let mut candidate_functions = vec![]; let mut best_function_idx = INVALID_INDEX; + let mut lowest_cost = i32::MAX; for (func_idx, each_func) in func.functions.iter().enumerate() { + // check the arguments of the function let cost = self.bind_function_cost(each_func, arguments); if cost < 0 { + // auto casting was not possible continue; } - candidate_functions.push(func_idx); + if cost == lowest_cost { + // we have multiple functions with the same cost, so just add it to the candidates + candidate_functions.push(func_idx); + continue; + } + if cost > lowest_cost { + // we have a function with a higher cost, so skip it + continue; + } + // we have a function with a lower cost, so clear the candidates and add this one + candidate_functions.clear(); + lowest_cost = cost; best_function_idx = func_idx; + candidate_functions.push(best_function_idx); } if best_function_idx == INVALID_INDEX { @@ -65,14 +89,45 @@ impl FunctionBinder { // invalid argument count: check the next function return -1; } - let cost = 0; - // TODO: use cast function to infer the cost and choose the best matched function. + let mut cost = 0; for (i, arg) in arguments.iter().enumerate() { if func.arguments[i] != *arg { // invalid argument count: check the next function - return -1; + let cast_cost = CastRules::implicit_cast_cost(arg, &func.arguments[i]); + if cast_cost >= 0 { + // we can implicitly cast, add the cost to the total cost + cost += cast_cost; + } else { + // we can't implicitly cast + return -1; + } } } cost } + + fn cast_to_function_arguments( + &self, + bound_function: &ScalarFunction, + children: Vec, + ) -> Result, BindError> { + let mut new_children = vec![]; + for (i, child) in children.into_iter().enumerate() { + let target_type = &bound_function.arguments[i]; + let source_type = &child.return_type(); + if source_type == target_type { + // no need to cast + new_children.push(child); + } else { + // we need to cast + new_children.push(BoundCastExpression::add_cast_to_type( + child, + target_type.clone(), + "".to_string(), + true, + )?); + } + } + Ok(new_children) + } } diff --git a/src/types_v2/types.rs b/src/types_v2/types.rs index ef16dcc..1842db9 100644 --- a/src/types_v2/types.rs +++ b/src/types_v2/types.rs @@ -132,7 +132,7 @@ impl LogicalType { } } - fn can_implicit_cast(from: &LogicalType, to: &LogicalType) -> bool { + pub fn can_implicit_cast(from: &LogicalType, to: &LogicalType) -> bool { if from == to { return true; } diff --git a/src/types_v2/values.rs b/src/types_v2/values.rs index 07eba7e..3459634 100644 --- a/src/types_v2/values.rs +++ b/src/types_v2/values.rs @@ -432,19 +432,8 @@ impl From<&sqlparser::ast::Value> for ScalarValue { fn from(v: &sqlparser::ast::Value) -> Self { match v { sqlparser::ast::Value::Number(n, _) => { - if let Ok(v) = n.parse::() { - v.into() - } else if let Ok(v) = n.parse::() { - v.into() - } else if let Ok(v) = n.parse::() { - v.into() - } else if let Ok(v) = n.parse::() { - v.into() - } else if let Ok(v) = n.parse::() { - v.into() - } else if let Ok(v) = n.parse::() { - v.into() - } else if let Ok(v) = n.parse::() { + // use i32 to handle most cases + if let Ok(v) = n.parse::() { v.into() } else if let Ok(v) = n.parse::() { v.into() diff --git a/tests/slt/scalar_function.slt b/tests/slt/scalar_function.slt index 60eef74..e740e66 100644 --- a/tests/slt/scalar_function.slt +++ b/tests/slt/scalar_function.slt @@ -38,3 +38,11 @@ select a/a from test 1 1 NULL + + +# cast arguments +onlyif sqlrs_v2 +query I +select 100 + 1000.2 +---- +1100.2 diff --git a/tests/slt/select.slt b/tests/slt/select.slt index b4e79bd..4956d4a 100644 --- a/tests/slt/select.slt +++ b/tests/slt/select.slt @@ -7,6 +7,11 @@ Gregg CO 2 10000 John CO 3 11500 Von (empty) 4 NULL +# test insert projection with cast expression +onlyif sqlrs_v2 +statement ok +create table t2(v1 tinyint); +insert into t2(v1) values (1), (5); onlyif sqlrs_v2 statement ok