From 9af089eb44864e9abe721e88c622183ec495a725 Mon Sep 17 00:00:00 2001
From: Fedomn
Date: Sat, 24 Dec 2022 14:55:46 +0800
Subject: [PATCH] feat(planner): introduce comparison function

Bind the =, <>, <, <=, > and >= binary operators to a new
BoundComparisonExpression instead of hitting todo!().

- function: add ComparisonFunction plus DefaultComparisonFunctions,
  which maps each operator to the matching arrow *_dyn kernel.
- binder: cast both operands to LogicalType::max_logical_type of their
  types before comparing; the expression itself returns Boolean.
- executor: evaluate left and right, then call the bound function.
- iterator, visitor and tree renderer learn about the new variant.
- tests: add tests/slt/comparison_function.slt.

Signed-off-by: Fedomn
---
 src/execution/expression_executor.rs          |  6 ++
 .../comparison/comparison_function.rs         | 33 ++++++++
 src/function/comparison/default_comparison.rs | 79 +++++++++++++++++++
 src/function/comparison/mod.rs                |  4 +
 src/function/errors.rs                        |  2 +
 src/function/mod.rs                           |  2 +
 .../expression/bind_comparison_expression.rs  | 70 ++++++++++++++++
 src/planner_v2/binder/expression/mod.rs       |  6 ++
 src/planner_v2/expression_binder.rs           | 14 ++--
 src/planner_v2/expression_iterator.rs         |  4 +
 src/planner_v2/logical_operator_visitor.rs    | 12 ++-
 src/util/tree_render.rs                       |  5 ++
 tests/slt/comparison_function.slt             | 19 +++++
 13 files changed, 248 insertions(+), 8 deletions(-)
 create mode 100644 src/function/comparison/comparison_function.rs
 create mode 100644 src/function/comparison/default_comparison.rs
 create mode 100644 src/function/comparison/mod.rs
 create mode 100644 src/planner_v2/binder/expression/bind_comparison_expression.rs
 create mode 100644 tests/slt/comparison_function.slt

diff --git a/src/execution/expression_executor.rs b/src/execution/expression_executor.rs
index 8c0a82c..4dee327 100644
--- a/src/execution/expression_executor.rs
+++ b/src/execution/expression_executor.rs
@@ -42,6 +42,12 @@ impl ExpressionExecutor {
                 let func = e.function.function;
                 func(&children_result)?
             }
+            BoundExpression::BoundComparisonExpression(e) => {
+                let left_result = Self::execute_internal(&e.left, input)?;
+                let right_result = Self::execute_internal(&e.right, input)?;
+                let func = e.function.function;
+                func(&left_result, &right_result)?
+            }
         })
     }
 }
diff --git a/src/function/comparison/comparison_function.rs b/src/function/comparison/comparison_function.rs
new file mode 100644
index 0000000..5c6c9e8
--- /dev/null
+++ b/src/function/comparison/comparison_function.rs
@@ -0,0 +1,33 @@
+use arrow::array::ArrayRef;
+use derive_new::new;
+
+use crate::function::FunctionError;
+use crate::types_v2::LogicalType;
+
+pub type ComparisonFunc = fn(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef, FunctionError>;
+
+#[derive(new, Clone)]
+pub struct ComparisonFunction {
+    /// The name of the function
+    pub(crate) name: String,
+    /// The main comparison function to execute.
+    /// Left and right arguments must be the same type
+    pub(crate) function: ComparisonFunc,
+    /// The comparison type
+    pub(crate) comparison_type: LogicalType,
+}
+
+impl std::fmt::Debug for ComparisonFunction {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("ComparisonFunction")
+            .field("name", &self.name)
+            .field(
+                "func",
+                &format!(
+                    "{}{}{}",
+                    self.comparison_type, self.name, self.comparison_type
+                ),
+            )
+            .finish()
+    }
+}
diff --git a/src/function/comparison/default_comparison.rs b/src/function/comparison/default_comparison.rs
new file mode 100644
index 0000000..3c57a81
--- /dev/null
+++ b/src/function/comparison/default_comparison.rs
@@ -0,0 +1,79 @@
+use std::sync::Arc;
+
+use arrow::array::ArrayRef;
+use arrow::compute::{eq_dyn, gt_dyn, gt_eq_dyn, lt_dyn, lt_eq_dyn, neq_dyn};
+use sqlparser::ast::BinaryOperator;
+
+use super::{ComparisonFunc, ComparisonFunction};
+use crate::function::FunctionError;
+use crate::types_v2::LogicalType;
+
+pub struct DefaultComparisonFunctions;
+
+impl DefaultComparisonFunctions {
+    fn default_gt_function(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef, FunctionError> {
+        Ok(Arc::new(gt_dyn(left, right)?))
+    }
+
+    fn default_gt_eq_function(
+        left: &ArrayRef,
+        right: &ArrayRef,
+    ) -> Result<ArrayRef, FunctionError> {
+        Ok(Arc::new(gt_eq_dyn(left, right)?))
+    }
+
+    fn default_lt_function(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef, FunctionError> {
+        Ok(Arc::new(lt_dyn(left, right)?))
+    }
+
+    fn default_lt_eq_function(
+        left: &ArrayRef,
+        right: &ArrayRef,
+    ) -> Result<ArrayRef, FunctionError> {
+        Ok(Arc::new(lt_eq_dyn(left, right)?))
+    }
+
+    fn default_eq_function(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef, FunctionError> {
+        Ok(Arc::new(eq_dyn(left, right)?))
+    }
+
+    fn default_neq_function(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef, FunctionError> {
+        Ok(Arc::new(neq_dyn(left, right)?))
+    }
+
+    fn get_comparison_function_internal(
+        op: &BinaryOperator,
+    ) -> Result<(&str, ComparisonFunc), FunctionError> {
+        Ok(match op {
+            BinaryOperator::Eq => ("eq", Self::default_eq_function),
+            BinaryOperator::NotEq => ("neq", Self::default_neq_function),
+            BinaryOperator::Lt => ("lt", Self::default_lt_function),
+            BinaryOperator::LtEq => ("lt_eq", Self::default_lt_eq_function),
+            BinaryOperator::Gt => ("gt", Self::default_gt_function),
+            BinaryOperator::GtEq => ("gt_eq", Self::default_gt_eq_function),
+            _ => {
+                return Err(FunctionError::ComparisonError(format!(
+                    "Unsupported comparison operator {:?}",
+                    op
+                )))
+            }
+        })
+    }
+
+    pub fn get_comparison_function(
+        op: &BinaryOperator,
+        comparison_type: &LogicalType,
+    ) -> Result<ComparisonFunction, FunctionError> {
+        if comparison_type == &LogicalType::Invalid {
+            return Err(FunctionError::ComparisonError(
+                "Invalid comparison type".to_string(),
+            ));
+        }
+        let (name, func) = Self::get_comparison_function_internal(op)?;
+        Ok(ComparisonFunction::new(
+            name.to_string(),
+            func,
+            comparison_type.clone(),
+        ))
+    }
+}
diff --git a/src/function/comparison/mod.rs b/src/function/comparison/mod.rs
new file mode 100644
index 0000000..500dea0
--- /dev/null
+++ b/src/function/comparison/mod.rs
@@ -0,0 +1,4 @@
+mod comparison_function;
+mod default_comparison;
+pub use comparison_function::*;
+pub use default_comparison::*;
diff --git a/src/function/errors.rs b/src/function/errors.rs
index 70327cb..97833ca 100644
--- a/src/function/errors.rs
+++ b/src/function/errors.rs
@@ -27,4 +27,6 @@ pub enum FunctionError {
     InternalError(String),
     #[error("Cast error: {0}")]
     CastError(String),
+    #[error("Comparison error: {0}")]
+    ComparisonError(String),
 }
diff --git a/src/function/mod.rs b/src/function/mod.rs
index 5d28890..a4040b9 100644
--- a/src/function/mod.rs
+++ b/src/function/mod.rs
@@ -1,4 +1,5 @@
 mod cast;
+mod comparison;
 mod errors;
 mod scalar;
 mod table;
@@ -6,6 +7,7 @@ mod table;
 use std::sync::Arc;
 
 pub use cast::*;
+pub use comparison::*;
 use derive_new::new;
 pub use errors::*;
 pub use scalar::*;
diff --git a/src/planner_v2/binder/expression/bind_comparison_expression.rs b/src/planner_v2/binder/expression/bind_comparison_expression.rs
new file mode 100644
index 0000000..d866a4c
--- /dev/null
+++ b/src/planner_v2/binder/expression/bind_comparison_expression.rs
@@ -0,0 +1,70 @@
+use derive_new::new;
+
+use super::{BoundCastExpression, BoundExpression, BoundExpressionBase};
+use crate::function::{ComparisonFunction, DefaultComparisonFunctions};
+use crate::planner_v2::{BindError, ExpressionBinder};
+use crate::types_v2::LogicalType;
+
+#[derive(new, Debug, Clone)]
+pub struct BoundComparisonExpression {
+    pub(crate) base: BoundExpressionBase,
+    pub(crate) left: Box<BoundExpression>,
+    pub(crate) right: Box<BoundExpression>,
+    /// The comparison function to execute
+    pub(crate) function: ComparisonFunction,
+}
+
+impl ExpressionBinder<'_> {
+    pub fn bind_comparison_expression(
+        &mut self,
+        left: &sqlparser::ast::Expr,
+        op: &sqlparser::ast::BinaryOperator,
+        right: &sqlparser::ast::Expr,
+        result_names: &mut Vec<String>,
+        result_types: &mut Vec<LogicalType>,
+    ) -> Result<BoundExpression, BindError> {
+        let mut return_names = vec![];
+        let mut return_types = vec![];
+        let mut bound_left = self.bind_expression(left, &mut return_names, &mut return_types)?;
+        let mut bound_right = self.bind_expression(right, &mut return_names, &mut return_types)?;
+        let left_type = bound_left.return_type();
+        let right_type = bound_right.return_type();
+        // cast the input types to the same type, now obtain the result type of the input types
+        let input_type = LogicalType::max_logical_type(&left_type, &right_type)?;
+        if input_type != left_type {
+            let alias = format!("cast({} as {})", bound_left.alias(), input_type);
+            bound_left = BoundCastExpression::add_cast_to_type(
+                bound_left,
+                input_type.clone(),
+                alias.clone(),
+                true,
+            )?;
+            return_names[0] = alias;
+            return_types[0] = input_type.clone();
+        }
+        if input_type != right_type {
+            let alias = format!("cast({} as {})", bound_right.alias(), input_type);
+            bound_right = BoundCastExpression::add_cast_to_type(
+                bound_right,
+                input_type.clone(),
+                alias.clone(),
+                true,
+            )?;
+            return_names[1] = alias;
+            return_types[1] = input_type.clone();
+        }
+
+        result_names.push(format!("{}({},{})", op, return_names[0], return_names[1]));
+        result_types.push(LogicalType::Boolean);
+        let function = DefaultComparisonFunctions::get_comparison_function(op, &input_type)?;
+        let base = BoundExpressionBase::new("".to_string(), LogicalType::Boolean);
+        Ok(BoundExpression::BoundComparisonExpression(
+            BoundComparisonExpression::new(
+                base,
+                Box::new(bound_left),
+                Box::new(bound_right),
+                function,
+            ),
+        ))
+    }
+}
diff --git a/src/planner_v2/binder/expression/mod.rs b/src/planner_v2/binder/expression/mod.rs
index e5b224f..3e7f023 100644
--- a/src/planner_v2/binder/expression/mod.rs
+++ b/src/planner_v2/binder/expression/mod.rs
@@ -1,5 +1,6 @@
 mod bind_cast_expression;
 mod bind_column_ref_expression;
+mod bind_comparison_expression;
 mod bind_constant_expression;
 mod bind_function_expression;
 mod bind_reference_expression;
@@ -7,6 +8,7 @@ mod column_binding;
 
 pub use bind_cast_expression::*;
 pub use bind_column_ref_expression::*;
+pub use bind_comparison_expression::*;
 pub use bind_constant_expression::*;
 pub use bind_function_expression::*;
 pub use bind_reference_expression::*;
@@ -30,6 +32,7 @@ pub enum BoundExpression {
     BoundReferenceExpression(BoundReferenceExpression),
     BoundCastExpression(BoundCastExpression),
     BoundFunctionExpression(BoundFunctionExpression),
+    BoundComparisonExpression(BoundComparisonExpression),
 }
 
 impl BoundExpression {
@@ -40,6 +43,7 @@ impl BoundExpression {
             BoundExpression::BoundReferenceExpression(expr) => expr.base.return_type.clone(),
             BoundExpression::BoundCastExpression(expr) => expr.base.return_type.clone(),
             BoundExpression::BoundFunctionExpression(expr) => expr.base.return_type.clone(),
+            BoundExpression::BoundComparisonExpression(expr) => expr.base.return_type.clone(),
         }
     }
 
@@ -50,6 +54,7 @@ impl BoundExpression {
             BoundExpression::BoundReferenceExpression(expr) => expr.base.alias.clone(),
             BoundExpression::BoundCastExpression(expr) => expr.base.alias.clone(),
             BoundExpression::BoundFunctionExpression(expr) => expr.base.alias.clone(),
+            BoundExpression::BoundComparisonExpression(expr) => expr.base.alias.clone(),
         }
     }
 
@@ -60,6 +65,7 @@ impl BoundExpression {
             BoundExpression::BoundReferenceExpression(expr) => expr.base.alias = alias,
             BoundExpression::BoundCastExpression(expr) => expr.base.alias = alias,
             BoundExpression::BoundFunctionExpression(expr) => expr.base.alias = alias,
+            BoundExpression::BoundComparisonExpression(expr) => expr.base.alias = alias,
         }
     }
 }
diff --git a/src/planner_v2/expression_binder.rs b/src/planner_v2/expression_binder.rs
index 3c77e94..cb39376 100644
--- a/src/planner_v2/expression_binder.rs
+++ b/src/planner_v2/expression_binder.rs
@@ -53,12 +53,14 @@ impl ExpressionBinder<'_> {
                 | sqlparser::ast::BinaryOperator::Divide => {
                     self.bind_function_expression(left, op, right, result_names, result_types)
                 }
-                sqlparser::ast::BinaryOperator::Gt => todo!(),
-                sqlparser::ast::BinaryOperator::Lt => todo!(),
-                sqlparser::ast::BinaryOperator::GtEq => todo!(),
-                sqlparser::ast::BinaryOperator::LtEq => todo!(),
-                sqlparser::ast::BinaryOperator::Eq => todo!(),
-                sqlparser::ast::BinaryOperator::NotEq => todo!(),
+                sqlparser::ast::BinaryOperator::Gt
+                | sqlparser::ast::BinaryOperator::Lt
+                | sqlparser::ast::BinaryOperator::GtEq
+                | sqlparser::ast::BinaryOperator::LtEq
+                | sqlparser::ast::BinaryOperator::Eq
+                | sqlparser::ast::BinaryOperator::NotEq => {
+                    self.bind_comparison_expression(left, op, right, result_names, result_types)
+                }
                 sqlparser::ast::BinaryOperator::And => todo!(),
                 sqlparser::ast::BinaryOperator::Or => todo!(),
                 other => Err(BindError::UnsupportedExpr(other.to_string())),
diff --git a/src/planner_v2/expression_iterator.rs b/src/planner_v2/expression_iterator.rs
index 9beb092..b1e9113 100644
--- a/src/planner_v2/expression_iterator.rs
+++ b/src/planner_v2/expression_iterator.rs
@@ -15,6 +15,10 @@ impl ExpressionIterator {
             }
             BoundExpression::BoundCastExpression(e) => callback(&mut e.child),
             BoundExpression::BoundFunctionExpression(e) => e.children.iter_mut().for_each(callback),
+            BoundExpression::BoundComparisonExpression(e) => {
+                callback(&mut e.left);
+                callback(&mut e.right);
+            }
         }
     }
 }
diff --git a/src/planner_v2/logical_operator_visitor.rs b/src/planner_v2/logical_operator_visitor.rs
index 6bb078b..e07de49 100644
--- a/src/planner_v2/logical_operator_visitor.rs
+++ b/src/planner_v2/logical_operator_visitor.rs
@@ -1,6 +1,7 @@
 use super::{
-    BoundCastExpression, BoundColumnRefExpression, BoundConstantExpression, BoundExpression,
-    BoundFunctionExpression, BoundReferenceExpression, ExpressionIterator, LogicalOperator,
+    BoundCastExpression, BoundColumnRefExpression, BoundComparisonExpression,
+    BoundConstantExpression, BoundExpression, BoundFunctionExpression, BoundReferenceExpression,
+    ExpressionIterator, LogicalOperator,
 };
 
 /// Visitor pattern on logical operators, also includes rewrite expression ability.
@@ -36,6 +37,7 @@ pub trait LogicalOperatorVisitor {
             BoundExpression::BoundReferenceExpression(e) => self.visit_replace_reference(e),
             BoundExpression::BoundCastExpression(e) => self.visit_replace_cast(e),
             BoundExpression::BoundFunctionExpression(e) => self.visit_function_expression(e),
+            BoundExpression::BoundComparisonExpression(e) => self.visit_comparison_expression(e),
         };
         if let Some(new_expr) = result {
             *expr = new_expr;
@@ -63,4 +65,10 @@ pub trait LogicalOperatorVisitor {
     fn visit_function_expression(&self, _: &BoundFunctionExpression) -> Option<BoundExpression> {
         None
     }
+    fn visit_comparison_expression(
+        &self,
+        _: &BoundComparisonExpression,
+    ) -> Option<BoundExpression> {
+        None
+    }
 }
diff --git a/src/util/tree_render.rs b/src/util/tree_render.rs
index 1f13fa5..fc7a4f4 100644
--- a/src/util/tree_render.rs
+++ b/src/util/tree_render.rs
@@ -46,6 +46,11 @@ impl TreeRender {
                     .join(", ");
                 format!("{}({}])", e.function.name, args)
             }
+            BoundExpression::BoundComparisonExpression(e) => {
+                let l = Self::bound_expression_to_string(&e.left);
+                let r = Self::bound_expression_to_string(&e.right);
+                format!("{} {} {}", l, e.function.name, r)
+            }
         }
     }
 
diff --git a/tests/slt/comparison_function.slt b/tests/slt/comparison_function.slt
new file mode 100644
index 0000000..3226b17
--- /dev/null
+++ b/tests/slt/comparison_function.slt
@@ -0,0 +1,19 @@
+onlyif sqlrs_v2
+statement error
+select 'abc' > 10
+
+onlyif sqlrs_v2
+statement error
+select 20.0 = 'abc'
+
+onlyif sqlrs_v2
+query T
+select 100 > 20
+----
+true
+
+onlyif sqlrs_v2
+query T
+select '1000' > '20'
+----
+false