diff --git a/src/execution/expression_executor.rs b/src/execution/expression_executor.rs index 4dee327..03f8187 100644 --- a/src/execution/expression_executor.rs +++ b/src/execution/expression_executor.rs @@ -48,6 +48,18 @@ impl ExpressionExecutor { let func = e.function.function; func(&left_result, &right_result)? } + BoundExpression::BoundConjunctionExpression(e) => { + assert!(e.children.len() >= 2); + let mut conjunction_result = Self::execute_internal(&e.children[0], input)?; + for i in 1..e.children.len() { + let func = e.function.function; + conjunction_result = func( + &conjunction_result, + &Self::execute_internal(&e.children[i], input)?, + )?; + } + conjunction_result + } }) } } diff --git a/src/function/conjunction/conjunction_function.rs b/src/function/conjunction/conjunction_function.rs new file mode 100644 index 0000000..e2c7436 --- /dev/null +++ b/src/function/conjunction/conjunction_function.rs @@ -0,0 +1,20 @@ +use arrow::array::ArrayRef; +use derive_new::new; + +use crate::function::FunctionError; + +pub type ConjunctionFunc = fn(left: &ArrayRef, right: &ArrayRef) -> Result; + +#[derive(new, Clone)] +pub struct ConjunctionFunction { + pub(crate) name: String, + pub(crate) function: ConjunctionFunc, +} + +impl std::fmt::Debug for ConjunctionFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ConjunctionFunction") + .field("name", &self.name) + .finish() + } +} diff --git a/src/function/conjunction/default_conjunction.rs b/src/function/conjunction/default_conjunction.rs new file mode 100644 index 0000000..93ad755 --- /dev/null +++ b/src/function/conjunction/default_conjunction.rs @@ -0,0 +1,65 @@ +use std::sync::Arc; + +use arrow::array::{ArrayRef, BooleanArray}; +use arrow::compute::{and_kleene, or_kleene}; +use arrow::datatypes::DataType; +use sqlparser::ast::BinaryOperator; + +use super::{ConjunctionFunc, ConjunctionFunction}; +use crate::function::FunctionError; + +pub struct DefaultConjunctionFunctions; + +macro_rules! boolean_op { + ($LEFT:expr, $RIGHT:expr, $OP:ident) => {{ + if *$LEFT.data_type() != DataType::Boolean || *$RIGHT.data_type() != DataType::Boolean { + return Err(FunctionError::ConjunctionError(format!( + "Cannot evaluate binary expression with types {:?} and {:?}, only Boolean supported", + $LEFT.data_type(), + $RIGHT.data_type() + ))); + } + + let ll = $LEFT + .as_any() + .downcast_ref::() + .expect("boolean_op failed to downcast array"); + let rr = $RIGHT + .as_any() + .downcast_ref::() + .expect("boolean_op failed to downcast array"); + Ok(Arc::new($OP(&ll, &rr)?)) + }}; +} + +impl DefaultConjunctionFunctions { + fn default_and_function(left: &ArrayRef, right: &ArrayRef) -> Result { + boolean_op!(left, right, and_kleene) + } + + fn default_or_function(left: &ArrayRef, right: &ArrayRef) -> Result { + boolean_op!(left, right, or_kleene) + } + + fn get_conjunction_function_internal( + op: &BinaryOperator, + ) -> Result<(&str, ConjunctionFunc), FunctionError> { + Ok(match op { + BinaryOperator::And => ("and", Self::default_and_function), + BinaryOperator::Or => ("or", Self::default_or_function), + _ => { + return Err(FunctionError::ConjunctionError(format!( + "Unsupported conjunction operator {:?}", + op + ))) + } + }) + } + + pub fn get_conjunction_function( + op: &BinaryOperator, + ) -> Result { + let (name, func) = Self::get_conjunction_function_internal(op)?; + Ok(ConjunctionFunction::new(name.to_string(), func)) + } +} diff --git a/src/function/conjunction/mod.rs b/src/function/conjunction/mod.rs new file mode 100644 index 0000000..b915760 --- /dev/null +++ b/src/function/conjunction/mod.rs @@ -0,0 +1,4 @@ +mod conjunction_function; +mod default_conjunction; +pub use conjunction_function::*; +pub use default_conjunction::*; diff --git a/src/function/errors.rs b/src/function/errors.rs index 97833ca..6b2abf7 100644 --- a/src/function/errors.rs +++ b/src/function/errors.rs @@ -29,4 +29,6 @@ pub enum FunctionError { CastError(String), #[error("Comparison error: {0}")] ComparisonError(String), + #[error("Conjunction error: {0}")] + ConjunctionError(String), } diff --git a/src/function/mod.rs b/src/function/mod.rs index a4040b9..ce6bd41 100644 --- a/src/function/mod.rs +++ b/src/function/mod.rs @@ -1,5 +1,6 @@ mod cast; mod comparison; +mod conjunction; mod errors; mod scalar; mod table; @@ -8,6 +9,7 @@ use std::sync::Arc; pub use cast::*; pub use comparison::*; +pub use conjunction::*; use derive_new::new; pub use errors::*; pub use scalar::*; diff --git a/src/planner_v2/binder/expression/bind_cast_expression.rs b/src/planner_v2/binder/expression/bind_cast_expression.rs index 6abba1f..e98e5e0 100644 --- a/src/planner_v2/binder/expression/bind_cast_expression.rs +++ b/src/planner_v2/binder/expression/bind_cast_expression.rs @@ -24,6 +24,7 @@ impl BoundCastExpression { alias: String, try_cast: bool, ) -> Result { + // TODO: enhance alias to reduce outside alias assignment let source_type = expr.return_type(); assert!(source_type != target_type); let cast_function = DefaultCastFunctions::get_cast_function(&source_type, &target_type)?; diff --git a/src/planner_v2/binder/expression/bind_conjunction_expression.rs b/src/planner_v2/binder/expression/bind_conjunction_expression.rs new file mode 100644 index 0000000..30f65ba --- /dev/null +++ b/src/planner_v2/binder/expression/bind_conjunction_expression.rs @@ -0,0 +1,55 @@ +use derive_new::new; + +use super::{BoundCastExpression, BoundExpression, BoundExpressionBase}; +use crate::function::{ConjunctionFunction, DefaultConjunctionFunctions}; +use crate::planner_v2::{BindError, ExpressionBinder}; +use crate::types_v2::LogicalType; + +#[derive(new, Debug, Clone)] +pub struct BoundConjunctionExpression { + pub(crate) base: BoundExpressionBase, + pub(crate) function: ConjunctionFunction, + pub(crate) children: Vec, +} + +impl ExpressionBinder<'_> { + pub fn bind_conjunction_expression( + &mut self, + left: &sqlparser::ast::Expr, + op: &sqlparser::ast::BinaryOperator, + right: &sqlparser::ast::Expr, + result_names: &mut Vec, + result_types: &mut Vec, + ) -> Result { + let function = DefaultConjunctionFunctions::get_conjunction_function(op)?; + let mut return_names = vec![]; + let mut left = self.bind_expression(left, &mut return_names, &mut vec![])?; + let mut right = self.bind_expression(right, &mut return_names, &mut vec![])?; + if left.return_type() != LogicalType::Boolean { + let alias = format!("cast({} as {}", left.alias(), LogicalType::Boolean); + left = BoundCastExpression::add_cast_to_type( + left, + LogicalType::Boolean, + alias.clone(), + true, + )?; + return_names[0] = alias; + } + if right.return_type() != LogicalType::Boolean { + let alias = format!("cast({} as {}", right.alias(), LogicalType::Boolean); + right = BoundCastExpression::add_cast_to_type( + right, + LogicalType::Boolean, + alias.clone(), + true, + )?; + return_names[1] = alias; + } + result_names.push(format!("{}({},{})", op, return_names[0], return_names[1])); + result_types.push(LogicalType::Boolean); + let base = BoundExpressionBase::new("".to_string(), LogicalType::Boolean); + Ok(BoundExpression::BoundConjunctionExpression( + BoundConjunctionExpression::new(base, function, vec![left, right]), + )) + } +} diff --git a/src/planner_v2/binder/expression/mod.rs b/src/planner_v2/binder/expression/mod.rs index 3e7f023..7a068d5 100644 --- a/src/planner_v2/binder/expression/mod.rs +++ b/src/planner_v2/binder/expression/mod.rs @@ -1,6 +1,7 @@ mod bind_cast_expression; mod bind_column_ref_expression; mod bind_comparison_expression; +mod bind_conjunction_expression; mod bind_constant_expression; mod bind_function_expression; mod bind_reference_expression; @@ -9,6 +10,7 @@ mod column_binding; pub use bind_cast_expression::*; pub use bind_column_ref_expression::*; pub use bind_comparison_expression::*; +pub use bind_conjunction_expression::*; pub use bind_constant_expression::*; pub use bind_function_expression::*; pub use bind_reference_expression::*; @@ -33,6 +35,7 @@ pub enum BoundExpression { BoundCastExpression(BoundCastExpression), BoundFunctionExpression(BoundFunctionExpression), BoundComparisonExpression(BoundComparisonExpression), + BoundConjunctionExpression(BoundConjunctionExpression), } impl BoundExpression { @@ -44,6 +47,7 @@ impl BoundExpression { BoundExpression::BoundCastExpression(expr) => expr.base.return_type.clone(), BoundExpression::BoundFunctionExpression(expr) => expr.base.return_type.clone(), BoundExpression::BoundComparisonExpression(expr) => expr.base.return_type.clone(), + BoundExpression::BoundConjunctionExpression(expr) => expr.base.return_type.clone(), } } @@ -55,6 +59,7 @@ impl BoundExpression { BoundExpression::BoundCastExpression(expr) => expr.base.alias.clone(), BoundExpression::BoundFunctionExpression(expr) => expr.base.alias.clone(), BoundExpression::BoundComparisonExpression(expr) => expr.base.alias.clone(), + BoundExpression::BoundConjunctionExpression(expr) => expr.base.alias.clone(), } } @@ -66,6 +71,7 @@ impl BoundExpression { BoundExpression::BoundCastExpression(expr) => expr.base.alias = alias, BoundExpression::BoundFunctionExpression(expr) => expr.base.alias = alias, BoundExpression::BoundComparisonExpression(expr) => expr.base.alias = alias, + BoundExpression::BoundConjunctionExpression(expr) => expr.base.alias = alias, } } } diff --git a/src/planner_v2/expression_binder.rs b/src/planner_v2/expression_binder.rs index cb39376..7c963ce 100644 --- a/src/planner_v2/expression_binder.rs +++ b/src/planner_v2/expression_binder.rs @@ -61,8 +61,9 @@ impl ExpressionBinder<'_> { | sqlparser::ast::BinaryOperator::NotEq => { self.bind_comparison_expression(left, op, right, result_names, result_types) } - sqlparser::ast::BinaryOperator::And => todo!(), - sqlparser::ast::BinaryOperator::Or => todo!(), + sqlparser::ast::BinaryOperator::And | sqlparser::ast::BinaryOperator::Or => { + self.bind_conjunction_expression(left, op, right, result_names, result_types) + } other => Err(BindError::UnsupportedExpr(other.to_string())), } } diff --git a/src/planner_v2/expression_iterator.rs b/src/planner_v2/expression_iterator.rs index b1e9113..02adf9b 100644 --- a/src/planner_v2/expression_iterator.rs +++ b/src/planner_v2/expression_iterator.rs @@ -19,6 +19,9 @@ impl ExpressionIterator { callback(&mut e.left); callback(&mut e.right); } + BoundExpression::BoundConjunctionExpression(e) => { + e.children.iter_mut().for_each(callback) + } } } } diff --git a/src/planner_v2/logical_operator_visitor.rs b/src/planner_v2/logical_operator_visitor.rs index e07de49..66972e3 100644 --- a/src/planner_v2/logical_operator_visitor.rs +++ b/src/planner_v2/logical_operator_visitor.rs @@ -1,7 +1,7 @@ use super::{ BoundCastExpression, BoundColumnRefExpression, BoundComparisonExpression, - BoundConstantExpression, BoundExpression, BoundFunctionExpression, BoundReferenceExpression, - ExpressionIterator, LogicalOperator, + BoundConjunctionExpression, BoundConstantExpression, BoundExpression, BoundFunctionExpression, + BoundReferenceExpression, ExpressionIterator, LogicalOperator, }; /// Visitor pattern on logical operators, also includes rewrite expression ability. @@ -38,6 +38,7 @@ pub trait LogicalOperatorVisitor { BoundExpression::BoundCastExpression(e) => self.visit_replace_cast(e), BoundExpression::BoundFunctionExpression(e) => self.visit_function_expression(e), BoundExpression::BoundComparisonExpression(e) => self.visit_comparison_expression(e), + BoundExpression::BoundConjunctionExpression(e) => self.visit_conjunction_expression(e), }; if let Some(new_expr) = result { *expr = new_expr; @@ -71,4 +72,10 @@ pub trait LogicalOperatorVisitor { ) -> Option { None } + fn visit_conjunction_expression( + &self, + _: &BoundConjunctionExpression, + ) -> Option { + None + } } diff --git a/src/util/tree_render.rs b/src/util/tree_render.rs index fc7a4f4..e604315 100644 --- a/src/util/tree_render.rs +++ b/src/util/tree_render.rs @@ -51,6 +51,15 @@ impl TreeRender { let r = Self::bound_expression_to_string(&e.right); format!("{} {} {}", l, e.function.name, r) } + BoundExpression::BoundConjunctionExpression(e) => { + let args = e + .children + .iter() + .map(Self::bound_expression_to_string) + .collect::>() + .join(", "); + format!("{}({}])", e.function.name, args) + } } } diff --git a/tests/slt/conjunction_function.slt b/tests/slt/conjunction_function.slt new file mode 100644 index 0000000..3b3ad80 --- /dev/null +++ b/tests/slt/conjunction_function.slt @@ -0,0 +1,159 @@ +onlyif sqlrs_v2 +query T +SELECT true and true and true +---- +true + +onlyif sqlrs_v2 +query T +SELECT true and false +---- +false + +onlyif sqlrs_v2 +query T +SELECT false and NULL +---- +false + +onlyif sqlrs_v2 +query T +SELECT NULL and true +---- +NULL + +onlyif sqlrs_v2 +query T +SELECT true and false or false +---- +false + +onlyif sqlrs_v2 +query T +SELECT true or false +---- +true + +onlyif sqlrs_v2 +query T +SELECT false or NULL +---- +NULL + +onlyif sqlrs_v2 +query T +SELECT NULL or false +---- +NULL + + +# create table +onlyif sqlrs_v2 +statement ok +CREATE TABLE a (i integer, j integer); +INSERT INTO a VALUES (3, 4), (4, 5), (5, 6); + +# test single constant in conjunctions +onlyif sqlrs_v2 +query T +SELECT true AND i>3 FROM a +---- +false +true +true + +onlyif sqlrs_v2 +query T +SELECT i>3 AND true FROM a +---- +false +true +true + +onlyif sqlrs_v2 +query T +SELECT 2>3 AND i>3 FROM a +---- +false +false +false + +onlyif sqlrs_v2 +query T +SELECT false AND i>3 FROM a +---- +false +false +false + +onlyif sqlrs_v2 +query T +SELECT i>3 AND false FROM a +---- +false +false +false + +onlyif sqlrs_v2 +query T +SELECT false OR i>3 FROM a +---- +false +true +true + +onlyif sqlrs_v2 +query T +SELECT i>3 OR false FROM a +---- +false +true +true + +onlyif sqlrs_v2 +query T +SELECT true OR i>3 FROM a +---- +true +true +true + +onlyif sqlrs_v2 +query T +SELECT i>3 OR true FROM a +---- +true +true +true + +onlyif sqlrs_v2 +query T +SELECT NULL OR i>3 FROM a +---- +NULL +true +true + +onlyif sqlrs_v2 +query T +SELECT i>3 OR NULL FROM a +---- +NULL +true +true + +onlyif sqlrs_v2 +query T +SELECT NULL AND i>3 FROM a +---- +false +NULL +NULL + +onlyif sqlrs_v2 +query T +SELECT i>3 AND NULL FROM a +---- +false +NULL +NULL