Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/execution/expression_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ impl ExpressionExecutor {
let func = e.function.function;
func(&left_result, &right_result)?
}
BoundExpression::BoundConjunctionExpression(e) => {
assert!(e.children.len() >= 2);
let mut conjunction_result = Self::execute_internal(&e.children[0], input)?;
for i in 1..e.children.len() {
let func = e.function.function;
conjunction_result = func(
&conjunction_result,
&Self::execute_internal(&e.children[i], input)?,
)?;
}
conjunction_result
}
})
}
}
20 changes: 20 additions & 0 deletions src/function/conjunction/conjunction_function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use arrow::array::ArrayRef;
use derive_new::new;

use crate::function::FunctionError;

pub type ConjunctionFunc = fn(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef, FunctionError>;

#[derive(new, Clone)]
pub struct ConjunctionFunction {
pub(crate) name: String,
pub(crate) function: ConjunctionFunc,
}

impl std::fmt::Debug for ConjunctionFunction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("ConjunctionFunction")
.field("name", &self.name)
.finish()
}
}
65 changes: 65 additions & 0 deletions src/function/conjunction/default_conjunction.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
use std::sync::Arc;

use arrow::array::{ArrayRef, BooleanArray};
use arrow::compute::{and_kleene, or_kleene};
use arrow::datatypes::DataType;
use sqlparser::ast::BinaryOperator;

use super::{ConjunctionFunc, ConjunctionFunction};
use crate::function::FunctionError;

pub struct DefaultConjunctionFunctions;

macro_rules! boolean_op {
($LEFT:expr, $RIGHT:expr, $OP:ident) => {{
if *$LEFT.data_type() != DataType::Boolean || *$RIGHT.data_type() != DataType::Boolean {
return Err(FunctionError::ConjunctionError(format!(
"Cannot evaluate binary expression with types {:?} and {:?}, only Boolean supported",
$LEFT.data_type(),
$RIGHT.data_type()
)));
}

let ll = $LEFT
.as_any()
.downcast_ref::<BooleanArray>()
.expect("boolean_op failed to downcast array");
let rr = $RIGHT
.as_any()
.downcast_ref::<BooleanArray>()
.expect("boolean_op failed to downcast array");
Ok(Arc::new($OP(&ll, &rr)?))
}};
}

impl DefaultConjunctionFunctions {
fn default_and_function(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef, FunctionError> {
boolean_op!(left, right, and_kleene)
}

fn default_or_function(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef, FunctionError> {
boolean_op!(left, right, or_kleene)
}

fn get_conjunction_function_internal(
op: &BinaryOperator,
) -> Result<(&str, ConjunctionFunc), FunctionError> {
Ok(match op {
BinaryOperator::And => ("and", Self::default_and_function),
BinaryOperator::Or => ("or", Self::default_or_function),
_ => {
return Err(FunctionError::ConjunctionError(format!(
"Unsupported conjunction operator {:?}",
op
)))
}
})
}

pub fn get_conjunction_function(
op: &BinaryOperator,
) -> Result<ConjunctionFunction, FunctionError> {
let (name, func) = Self::get_conjunction_function_internal(op)?;
Ok(ConjunctionFunction::new(name.to_string(), func))
}
}
4 changes: 4 additions & 0 deletions src/function/conjunction/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
mod conjunction_function;
mod default_conjunction;
pub use conjunction_function::*;
pub use default_conjunction::*;
2 changes: 2 additions & 0 deletions src/function/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,6 @@ pub enum FunctionError {
CastError(String),
#[error("Comparison error: {0}")]
ComparisonError(String),
#[error("Conjunction error: {0}")]
ConjunctionError(String),
}
2 changes: 2 additions & 0 deletions src/function/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
mod cast;
mod comparison;
mod conjunction;
mod errors;
mod scalar;
mod table;
Expand All @@ -8,6 +9,7 @@ use std::sync::Arc;

pub use cast::*;
pub use comparison::*;
pub use conjunction::*;
use derive_new::new;
pub use errors::*;
pub use scalar::*;
Expand Down
1 change: 1 addition & 0 deletions src/planner_v2/binder/expression/bind_cast_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ impl BoundCastExpression {
alias: String,
try_cast: bool,
) -> Result<BoundExpression, BindError> {
// TODO: enhance alias to reduce outside alias assignment
let source_type = expr.return_type();
assert!(source_type != target_type);
let cast_function = DefaultCastFunctions::get_cast_function(&source_type, &target_type)?;
Expand Down
55 changes: 55 additions & 0 deletions src/planner_v2/binder/expression/bind_conjunction_expression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use derive_new::new;

use super::{BoundCastExpression, BoundExpression, BoundExpressionBase};
use crate::function::{ConjunctionFunction, DefaultConjunctionFunctions};
use crate::planner_v2::{BindError, ExpressionBinder};
use crate::types_v2::LogicalType;

#[derive(new, Debug, Clone)]
pub struct BoundConjunctionExpression {
pub(crate) base: BoundExpressionBase,
pub(crate) function: ConjunctionFunction,
pub(crate) children: Vec<BoundExpression>,
}

impl ExpressionBinder<'_> {
pub fn bind_conjunction_expression(
&mut self,
left: &sqlparser::ast::Expr,
op: &sqlparser::ast::BinaryOperator,
right: &sqlparser::ast::Expr,
result_names: &mut Vec<String>,
result_types: &mut Vec<LogicalType>,
) -> Result<BoundExpression, BindError> {
let function = DefaultConjunctionFunctions::get_conjunction_function(op)?;
let mut return_names = vec![];
let mut left = self.bind_expression(left, &mut return_names, &mut vec![])?;
let mut right = self.bind_expression(right, &mut return_names, &mut vec![])?;
if left.return_type() != LogicalType::Boolean {
let alias = format!("cast({} as {}", left.alias(), LogicalType::Boolean);
left = BoundCastExpression::add_cast_to_type(
left,
LogicalType::Boolean,
alias.clone(),
true,
)?;
return_names[0] = alias;
}
if right.return_type() != LogicalType::Boolean {
let alias = format!("cast({} as {}", right.alias(), LogicalType::Boolean);
right = BoundCastExpression::add_cast_to_type(
right,
LogicalType::Boolean,
alias.clone(),
true,
)?;
return_names[1] = alias;
}
result_names.push(format!("{}({},{})", op, return_names[0], return_names[1]));
result_types.push(LogicalType::Boolean);
let base = BoundExpressionBase::new("".to_string(), LogicalType::Boolean);
Ok(BoundExpression::BoundConjunctionExpression(
BoundConjunctionExpression::new(base, function, vec![left, right]),
))
}
}
6 changes: 6 additions & 0 deletions src/planner_v2/binder/expression/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
mod bind_cast_expression;
mod bind_column_ref_expression;
mod bind_comparison_expression;
mod bind_conjunction_expression;
mod bind_constant_expression;
mod bind_function_expression;
mod bind_reference_expression;
Expand All @@ -9,6 +10,7 @@ mod column_binding;
pub use bind_cast_expression::*;
pub use bind_column_ref_expression::*;
pub use bind_comparison_expression::*;
pub use bind_conjunction_expression::*;
pub use bind_constant_expression::*;
pub use bind_function_expression::*;
pub use bind_reference_expression::*;
Expand All @@ -33,6 +35,7 @@ pub enum BoundExpression {
BoundCastExpression(BoundCastExpression),
BoundFunctionExpression(BoundFunctionExpression),
BoundComparisonExpression(BoundComparisonExpression),
BoundConjunctionExpression(BoundConjunctionExpression),
}

impl BoundExpression {
Expand All @@ -44,6 +47,7 @@ impl BoundExpression {
BoundExpression::BoundCastExpression(expr) => expr.base.return_type.clone(),
BoundExpression::BoundFunctionExpression(expr) => expr.base.return_type.clone(),
BoundExpression::BoundComparisonExpression(expr) => expr.base.return_type.clone(),
BoundExpression::BoundConjunctionExpression(expr) => expr.base.return_type.clone(),
}
}

Expand All @@ -55,6 +59,7 @@ impl BoundExpression {
BoundExpression::BoundCastExpression(expr) => expr.base.alias.clone(),
BoundExpression::BoundFunctionExpression(expr) => expr.base.alias.clone(),
BoundExpression::BoundComparisonExpression(expr) => expr.base.alias.clone(),
BoundExpression::BoundConjunctionExpression(expr) => expr.base.alias.clone(),
}
}

Expand All @@ -66,6 +71,7 @@ impl BoundExpression {
BoundExpression::BoundCastExpression(expr) => expr.base.alias = alias,
BoundExpression::BoundFunctionExpression(expr) => expr.base.alias = alias,
BoundExpression::BoundComparisonExpression(expr) => expr.base.alias = alias,
BoundExpression::BoundConjunctionExpression(expr) => expr.base.alias = alias,
}
}
}
5 changes: 3 additions & 2 deletions src/planner_v2/expression_binder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,9 @@ impl ExpressionBinder<'_> {
| sqlparser::ast::BinaryOperator::NotEq => {
self.bind_comparison_expression(left, op, right, result_names, result_types)
}
sqlparser::ast::BinaryOperator::And => todo!(),
sqlparser::ast::BinaryOperator::Or => todo!(),
sqlparser::ast::BinaryOperator::And | sqlparser::ast::BinaryOperator::Or => {
self.bind_conjunction_expression(left, op, right, result_names, result_types)
}
other => Err(BindError::UnsupportedExpr(other.to_string())),
}
}
Expand Down
3 changes: 3 additions & 0 deletions src/planner_v2/expression_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ impl ExpressionIterator {
callback(&mut e.left);
callback(&mut e.right);
}
BoundExpression::BoundConjunctionExpression(e) => {
e.children.iter_mut().for_each(callback)
}
}
}
}
11 changes: 9 additions & 2 deletions src/planner_v2/logical_operator_visitor.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::{
BoundCastExpression, BoundColumnRefExpression, BoundComparisonExpression,
BoundConstantExpression, BoundExpression, BoundFunctionExpression, BoundReferenceExpression,
ExpressionIterator, LogicalOperator,
BoundConjunctionExpression, BoundConstantExpression, BoundExpression, BoundFunctionExpression,
BoundReferenceExpression, ExpressionIterator, LogicalOperator,
};

/// Visitor pattern on logical operators, also includes rewrite expression ability.
Expand Down Expand Up @@ -38,6 +38,7 @@ pub trait LogicalOperatorVisitor {
BoundExpression::BoundCastExpression(e) => self.visit_replace_cast(e),
BoundExpression::BoundFunctionExpression(e) => self.visit_function_expression(e),
BoundExpression::BoundComparisonExpression(e) => self.visit_comparison_expression(e),
BoundExpression::BoundConjunctionExpression(e) => self.visit_conjunction_expression(e),
};
if let Some(new_expr) = result {
*expr = new_expr;
Expand Down Expand Up @@ -71,4 +72,10 @@ pub trait LogicalOperatorVisitor {
) -> Option<BoundExpression> {
None
}
fn visit_conjunction_expression(
&self,
_: &BoundConjunctionExpression,
) -> Option<BoundExpression> {
None
}
}
9 changes: 9 additions & 0 deletions src/util/tree_render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,15 @@ impl TreeRender {
let r = Self::bound_expression_to_string(&e.right);
format!("{} {} {}", l, e.function.name, r)
}
BoundExpression::BoundConjunctionExpression(e) => {
let args = e
.children
.iter()
.map(Self::bound_expression_to_string)
.collect::<Vec<_>>()
.join(", ");
format!("{}({}])", e.function.name, args)
}
}
}

Expand Down
Loading