Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/execution/expression_executor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use arrow::array::ArrayRef;
use arrow::compute::{cast_with_options, CastOptions};
use arrow::record_batch::RecordBatch;

use super::ExecutorError;
Expand Down Expand Up @@ -31,9 +30,8 @@ impl ExpressionExecutor {
BoundExpression::BoundReferenceExpression(e) => input.column(e.index).clone(),
BoundExpression::BoundCastExpression(e) => {
let child_result = Self::execute_internal(&e.child, input)?;
let to_type = e.base.return_type.clone().into();
let options = CastOptions { safe: e.try_cast };
cast_with_options(&child_result, &to_type, &options)?
let cast_function = e.function.function;
cast_function(&child_result, &e.base.return_type, e.try_cast)?
}
BoundExpression::BoundFunctionExpression(e) => {
let children_result = e
Expand Down
26 changes: 26 additions & 0 deletions src/function/cast/cast_function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use arrow::array::ArrayRef;
use derive_new::new;

use crate::function::FunctionError;
use crate::types_v2::LogicalType;

pub type CastFunc =
fn(array: &ArrayRef, to_type: &LogicalType, try_cast: bool) -> Result<ArrayRef, FunctionError>;

#[derive(new, Clone)]
pub struct CastFunction {
/// The source type of the cast
pub(crate) source: LogicalType,
/// The target type of the cast
pub(crate) target: LogicalType,
/// The main cast function to execute
pub(crate) function: CastFunc,
}

impl std::fmt::Debug for CastFunction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CastFunction")
.field("cast", &format!("{:?} -> {:?}", self.source, self.target))
.finish()
}
}
15 changes: 15 additions & 0 deletions src/function/cast/cast_rules.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
use crate::types_v2::LogicalType;

pub struct CastRules;

impl CastRules {
pub fn implicit_cast_cost(from: &LogicalType, to: &LogicalType) -> i32 {
if from == to {
0
} else if LogicalType::can_implicit_cast(from, to) {
1
} else {
-1
}
}
}
37 changes: 37 additions & 0 deletions src/function/cast/default_cast.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use arrow::array::ArrayRef;
use arrow::compute::{cast_with_options, CastOptions};

use super::CastFunction;
use crate::function::FunctionError;
use crate::types_v2::LogicalType;

pub struct DefaultCastFunctions;

impl DefaultCastFunctions {
fn default_cast_function(
array: &ArrayRef,
to_type: &LogicalType,
try_cast: bool,
) -> Result<ArrayRef, FunctionError> {
let to_type = to_type.clone().into();
let options = CastOptions { safe: try_cast };
Ok(cast_with_options(array, &to_type, &options)?)
}

pub fn get_cast_function(
source: &LogicalType,
target: &LogicalType,
) -> Result<CastFunction, FunctionError> {
assert!(source != target);
match source {
LogicalType::Invalid => {
Err(FunctionError::CastError("Invalid source type".to_string()))
}
_ => Ok(CastFunction::new(
source.clone(),
target.clone(),
Self::default_cast_function,
)),
}
}
}
7 changes: 7 additions & 0 deletions src/function/cast/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
mod cast_function;
mod cast_rules;
mod default_cast;

pub use cast_function::*;
pub use cast_rules::*;
pub use default_cast::*;
2 changes: 2 additions & 0 deletions src/function/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,6 @@ pub enum FunctionError {
),
#[error("Internal error: {0}")]
InternalError(String),
#[error("Cast error: {0}")]
CastError(String),
}
2 changes: 2 additions & 0 deletions src/function/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
mod cast;
mod errors;
mod scalar;
mod table;

use std::sync::Arc;

pub use cast::*;
use derive_new::new;
pub use errors::*;
pub use scalar::*;
Expand Down
1 change: 0 additions & 1 deletion src/function/scalar/scalar_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use derive_new::new;
use crate::function::FunctionError;
use crate::types_v2::LogicalType;

// pub type ScalarFunc = fn(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef, FunctionError>;
pub type ScalarFunc = fn(inputs: &[ArrayRef]) -> Result<ArrayRef, FunctionError>;

#[derive(new, Clone)]
Expand Down
18 changes: 10 additions & 8 deletions src/planner_v2/binder/expression/bind_cast_expression.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use derive_new::new;

use super::{BoundExpression, BoundExpressionBase};
use crate::function::{CastFunction, DefaultCastFunctions};
use crate::planner_v2::BindError;
use crate::types_v2::LogicalType;

#[derive(new, Debug, Clone)]
Expand All @@ -11,6 +13,8 @@ pub struct BoundCastExpression {
/// Whether to use try_cast or not. try_cast converts cast failures into NULLs instead of
/// throwing an error.
pub(crate) try_cast: bool,
/// The cast function to execute
pub(crate) function: CastFunction,
}

impl BoundCastExpression {
Expand All @@ -19,15 +23,13 @@ impl BoundCastExpression {
target_type: LogicalType,
alias: String,
try_cast: bool,
) -> BoundExpression {
if expr.return_type() == target_type {
return expr;
}
) -> Result<BoundExpression, BindError> {
let source_type = expr.return_type();
assert!(source_type != target_type);
let cast_function = DefaultCastFunctions::get_cast_function(&source_type, &target_type)?;
let base = BoundExpressionBase::new(alias, target_type);
BoundExpression::BoundCastExpression(BoundCastExpression::new(
base,
Box::new(expr),
try_cast,
Ok(BoundExpression::BoundCastExpression(
BoundCastExpression::new(base, Box::new(expr), try_cast, cast_function),
))
}
}
2 changes: 1 addition & 1 deletion src/planner_v2/binder/query_node/plan_select_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ impl Binder {
target_type.clone(),
alias,
false,
);
)?;
node.base.types[idx] = target_type.clone();
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/planner_v2/binder/tableref/bind_expression_list_ref.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,14 @@ impl Binder {
// insert values contains SqlNull, the expr should be cast to the max logical type
for exprs in bound_expr_list.iter_mut() {
for (idx, bound_expr) in exprs.iter_mut().enumerate() {
if bound_expr.return_type() == LogicalType::SqlNull {
if bound_expr.return_type() != types[idx] {
let alias = bound_expr.alias().clone();
*bound_expr = BoundCastExpression::add_cast_to_type(
bound_expr.clone(),
types[idx].clone(),
alias,
false,
)
)?
}
}
}
Expand Down
69 changes: 62 additions & 7 deletions src/planner_v2/function_binder.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use derive_new::new;

use super::{BindError, BoundExpressionBase, INVALID_INDEX};
use super::{BindError, BoundCastExpression, BoundExpressionBase, INVALID_INDEX};
use crate::catalog_v2::ScalarFunctionCatalogEntry;
use crate::function::ScalarFunction;
use crate::function::{CastRules, ScalarFunction};
use crate::planner_v2::{BoundExpression, BoundFunctionExpression};
use crate::types_v2::LogicalType;

Expand All @@ -16,11 +16,20 @@ impl FunctionBinder {
func: ScalarFunctionCatalogEntry,
children: Vec<BoundExpression>,
) -> Result<BoundFunctionExpression, BindError> {
// bind the function
let arguments = self.get_logical_types_from_expressions(&children);
// found a matching function!
let best_func_idx = self.bind_function_from_arguments(&func, &arguments)?;
let bound_function = func.functions[best_func_idx].clone();
// check if we need to add casts to the children
let new_children = self.cast_to_function_arguments(&bound_function, children)?;
// now create the function
let base = BoundExpressionBase::new("".to_string(), bound_function.return_type.clone());
Ok(BoundFunctionExpression::new(base, bound_function, children))
Ok(BoundFunctionExpression::new(
base,
bound_function,
new_children,
))
}

fn get_logical_types_from_expressions(&self, children: &[BoundExpression]) -> Vec<LogicalType> {
Expand All @@ -34,13 +43,28 @@ impl FunctionBinder {
) -> Result<usize, BindError> {
let mut candidate_functions = vec![];
let mut best_function_idx = INVALID_INDEX;
let mut lowest_cost = i32::MAX;
for (func_idx, each_func) in func.functions.iter().enumerate() {
// check the arguments of the function
let cost = self.bind_function_cost(each_func, arguments);
if cost < 0 {
// auto casting was not possible
continue;
}
candidate_functions.push(func_idx);
if cost == lowest_cost {
// we have multiple functions with the same cost, so just add it to the candidates
candidate_functions.push(func_idx);
continue;
}
if cost > lowest_cost {
// we have a function with a higher cost, so skip it
continue;
}
// we have a function with a lower cost, so clear the candidates and add this one
candidate_functions.clear();
lowest_cost = cost;
best_function_idx = func_idx;
candidate_functions.push(best_function_idx);
}

if best_function_idx == INVALID_INDEX {
Expand All @@ -65,14 +89,45 @@ impl FunctionBinder {
// invalid argument count: check the next function
return -1;
}
let cost = 0;
// TODO: use cast function to infer the cost and choose the best matched function.
let mut cost = 0;
for (i, arg) in arguments.iter().enumerate() {
if func.arguments[i] != *arg {
// invalid argument count: check the next function
return -1;
let cast_cost = CastRules::implicit_cast_cost(arg, &func.arguments[i]);
if cast_cost >= 0 {
// we can implicitly cast, add the cost to the total cost
cost += cast_cost;
} else {
// we can't implicitly cast
return -1;
}
}
}
cost
}

fn cast_to_function_arguments(
&self,
bound_function: &ScalarFunction,
children: Vec<BoundExpression>,
) -> Result<Vec<BoundExpression>, BindError> {
let mut new_children = vec![];
for (i, child) in children.into_iter().enumerate() {
let target_type = &bound_function.arguments[i];
let source_type = &child.return_type();
if source_type == target_type {
// no need to cast
new_children.push(child);
} else {
// we need to cast
new_children.push(BoundCastExpression::add_cast_to_type(
child,
target_type.clone(),
"".to_string(),
true,
)?);
}
}
Ok(new_children)
}
}
2 changes: 1 addition & 1 deletion src/types_v2/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ impl LogicalType {
}
}

fn can_implicit_cast(from: &LogicalType, to: &LogicalType) -> bool {
pub fn can_implicit_cast(from: &LogicalType, to: &LogicalType) -> bool {
if from == to {
return true;
}
Expand Down
15 changes: 2 additions & 13 deletions src/types_v2/values.rs
Original file line number Diff line number Diff line change
Expand Up @@ -432,19 +432,8 @@ impl From<&sqlparser::ast::Value> for ScalarValue {
fn from(v: &sqlparser::ast::Value) -> Self {
match v {
sqlparser::ast::Value::Number(n, _) => {
if let Ok(v) = n.parse::<u8>() {
v.into()
} else if let Ok(v) = n.parse::<u16>() {
v.into()
} else if let Ok(v) = n.parse::<u32>() {
v.into()
} else if let Ok(v) = n.parse::<u64>() {
v.into()
} else if let Ok(v) = n.parse::<i8>() {
v.into()
} else if let Ok(v) = n.parse::<i16>() {
v.into()
} else if let Ok(v) = n.parse::<i32>() {
// use i32 to handle most cases
if let Ok(v) = n.parse::<i32>() {
v.into()
} else if let Ok(v) = n.parse::<i64>() {
v.into()
Expand Down
8 changes: 8 additions & 0 deletions tests/slt/scalar_function.slt
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,11 @@ select a/a from test
1
1
NULL


# cast arguments
onlyif sqlrs_v2
query I
select 100 + 1000.2
----
1100.2
5 changes: 5 additions & 0 deletions tests/slt/select.slt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ Gregg CO 2 10000
John CO 3 11500
Von (empty) 4 NULL

# test insert projection with cast expression
onlyif sqlrs_v2
statement ok
create table t2(v1 tinyint);
insert into t2(v1) values (1), (5);

onlyif sqlrs_v2
statement ok
Expand Down