Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/execution/expression_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,12 @@ impl ExpressionExecutor {
let func = e.function.function;
func(&children_result)?
}
BoundExpression::BoundComparisonExpression(e) => {
let left_result = Self::execute_internal(&e.left, input)?;
let right_result = Self::execute_internal(&e.right, input)?;
let func = e.function.function;
func(&left_result, &right_result)?
}
})
}
}
33 changes: 33 additions & 0 deletions src/function/comparison/comparison_function.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use arrow::array::ArrayRef;
use derive_new::new;

use crate::function::FunctionError;
use crate::types_v2::LogicalType;

pub type ComparisonFunc = fn(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef, FunctionError>;

#[derive(new, Clone)]
pub struct ComparisonFunction {
// The name of the function
pub(crate) name: String,
/// The main comparision function to execute.
/// Left and right arguments must be the same type
pub(crate) function: ComparisonFunc,
/// The comparison type
pub(crate) comparison_type: LogicalType,
}

impl std::fmt::Debug for ComparisonFunction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("CompressionFunction")
.field("name", &self.name)
.field(
"func",
&format!(
"{}{}{}",
self.comparison_type, self.name, self.comparison_type
),
)
.finish()
}
}
79 changes: 79 additions & 0 deletions src/function/comparison/default_comparison.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
use std::sync::Arc;

use arrow::array::ArrayRef;
use arrow::compute::{eq_dyn, gt_dyn, gt_eq_dyn, lt_dyn, lt_eq_dyn, neq_dyn};
use sqlparser::ast::BinaryOperator;

use super::{ComparisonFunc, ComparisonFunction};
use crate::function::FunctionError;
use crate::types_v2::LogicalType;

pub struct DefaultComparisonFunctions;

impl DefaultComparisonFunctions {
fn default_gt_function(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef, FunctionError> {
Ok(Arc::new(gt_dyn(left, right)?))
}

fn default_gt_eq_function(
left: &ArrayRef,
right: &ArrayRef,
) -> Result<ArrayRef, FunctionError> {
Ok(Arc::new(gt_eq_dyn(left, right)?))
}

fn default_lt_function(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef, FunctionError> {
Ok(Arc::new(lt_dyn(left, right)?))
}

fn default_lt_eq_function(
left: &ArrayRef,
right: &ArrayRef,
) -> Result<ArrayRef, FunctionError> {
Ok(Arc::new(lt_eq_dyn(left, right)?))
}

fn default_eq_function(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef, FunctionError> {
Ok(Arc::new(eq_dyn(left, right)?))
}

fn default_neq_function(left: &ArrayRef, right: &ArrayRef) -> Result<ArrayRef, FunctionError> {
Ok(Arc::new(neq_dyn(left, right)?))
}

fn get_comparison_function_internal(
op: &BinaryOperator,
) -> Result<(&str, ComparisonFunc), FunctionError> {
Ok(match op {
BinaryOperator::Eq => ("eq", Self::default_eq_function),
BinaryOperator::NotEq => ("neq", Self::default_neq_function),
BinaryOperator::Lt => ("lt", Self::default_lt_function),
BinaryOperator::LtEq => ("lt_eq", Self::default_lt_eq_function),
BinaryOperator::Gt => ("gt", Self::default_gt_function),
BinaryOperator::GtEq => ("gt_eq", Self::default_gt_eq_function),
_ => {
return Err(FunctionError::ComparisonError(format!(
"Unsupported comparison operator {:?}",
op
)))
}
})
}

pub fn get_comparison_function(
op: &BinaryOperator,
comparison_type: &LogicalType,
) -> Result<ComparisonFunction, FunctionError> {
if comparison_type == &LogicalType::Invalid {
return Err(FunctionError::ComparisonError(
"Invalid comparison type".to_string(),
));
}
let (name, func) = Self::get_comparison_function_internal(op)?;
Ok(ComparisonFunction::new(
name.to_string(),
func,
comparison_type.clone(),
))
}
}
4 changes: 4 additions & 0 deletions src/function/comparison/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
mod comparison_function;
mod default_comparison;
pub use comparison_function::*;
pub use default_comparison::*;
2 changes: 2 additions & 0 deletions src/function/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,6 @@ pub enum FunctionError {
InternalError(String),
#[error("Cast error: {0}")]
CastError(String),
#[error("Comparison error: {0}")]
ComparisonError(String),
}
2 changes: 2 additions & 0 deletions src/function/mod.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
mod cast;
mod comparison;
mod errors;
mod scalar;
mod table;

use std::sync::Arc;

pub use cast::*;
pub use comparison::*;
use derive_new::new;
pub use errors::*;
pub use scalar::*;
Expand Down
70 changes: 70 additions & 0 deletions src/planner_v2/binder/expression/bind_comparison_expression.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use derive_new::new;

use super::{BoundCastExpression, BoundExpression, BoundExpressionBase};
use crate::function::{ComparisonFunction, DefaultComparisonFunctions};
use crate::planner_v2::{BindError, ExpressionBinder};
use crate::types_v2::LogicalType;

#[derive(new, Debug, Clone)]
pub struct BoundComparisonExpression {
pub(crate) base: BoundExpressionBase,
pub(crate) left: Box<BoundExpression>,
pub(crate) right: Box<BoundExpression>,
/// The comparison function to execute
pub(crate) function: ComparisonFunction,
}

impl ExpressionBinder<'_> {
pub fn bind_comparison_expression(
&mut self,
left: &sqlparser::ast::Expr,
op: &sqlparser::ast::BinaryOperator,
right: &sqlparser::ast::Expr,
result_names: &mut Vec<String>,
result_types: &mut Vec<LogicalType>,
) -> Result<BoundExpression, BindError> {
let mut return_names = vec![];
let mut return_types = vec![];
let mut bound_left = self.bind_expression(left, &mut return_names, &mut return_types)?;
let mut bound_right = self.bind_expression(right, &mut return_names, &mut return_types)?;
let left_type = bound_left.return_type();
let right_type = bound_right.return_type();
// cast the input types to the same type, now obtain the result type of the input types
let input_type = LogicalType::max_logical_type(&left_type, &right_type)?;
if input_type != left_type {
let alias = format!("cast({} as {}", bound_left.alias(), input_type);
bound_left = BoundCastExpression::add_cast_to_type(
bound_left,
input_type.clone(),
alias.clone(),
true,
)?;
return_names[0] = alias;
return_types[0] = input_type.clone();
}
if input_type != right_type {
let alias = format!("cast({} as {}", bound_right.alias(), input_type);
bound_right = BoundCastExpression::add_cast_to_type(
bound_right,
input_type.clone(),
alias.clone(),
true,
)?;
return_names[1] = alias;
return_types[1] = input_type.clone();
}

result_names.push(format!("{}({},{})", op, return_names[0], return_names[1]));
result_types.push(LogicalType::Boolean);
let function = DefaultComparisonFunctions::get_comparison_function(op, &input_type)?;
let base = BoundExpressionBase::new("".to_string(), LogicalType::Boolean);
Ok(BoundExpression::BoundComparisonExpression(
BoundComparisonExpression::new(
base,
Box::new(bound_left),
Box::new(bound_right),
function,
),
))
}
}
6 changes: 6 additions & 0 deletions src/planner_v2/binder/expression/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
mod bind_cast_expression;
mod bind_column_ref_expression;
mod bind_comparison_expression;
mod bind_constant_expression;
mod bind_function_expression;
mod bind_reference_expression;
mod column_binding;

pub use bind_cast_expression::*;
pub use bind_column_ref_expression::*;
pub use bind_comparison_expression::*;
pub use bind_constant_expression::*;
pub use bind_function_expression::*;
pub use bind_reference_expression::*;
Expand All @@ -30,6 +32,7 @@ pub enum BoundExpression {
BoundReferenceExpression(BoundReferenceExpression),
BoundCastExpression(BoundCastExpression),
BoundFunctionExpression(BoundFunctionExpression),
BoundComparisonExpression(BoundComparisonExpression),
}

impl BoundExpression {
Expand All @@ -40,6 +43,7 @@ impl BoundExpression {
BoundExpression::BoundReferenceExpression(expr) => expr.base.return_type.clone(),
BoundExpression::BoundCastExpression(expr) => expr.base.return_type.clone(),
BoundExpression::BoundFunctionExpression(expr) => expr.base.return_type.clone(),
BoundExpression::BoundComparisonExpression(expr) => expr.base.return_type.clone(),
}
}

Expand All @@ -50,6 +54,7 @@ impl BoundExpression {
BoundExpression::BoundReferenceExpression(expr) => expr.base.alias.clone(),
BoundExpression::BoundCastExpression(expr) => expr.base.alias.clone(),
BoundExpression::BoundFunctionExpression(expr) => expr.base.alias.clone(),
BoundExpression::BoundComparisonExpression(expr) => expr.base.alias.clone(),
}
}

Expand All @@ -60,6 +65,7 @@ impl BoundExpression {
BoundExpression::BoundReferenceExpression(expr) => expr.base.alias = alias,
BoundExpression::BoundCastExpression(expr) => expr.base.alias = alias,
BoundExpression::BoundFunctionExpression(expr) => expr.base.alias = alias,
BoundExpression::BoundComparisonExpression(expr) => expr.base.alias = alias,
}
}
}
14 changes: 8 additions & 6 deletions src/planner_v2/expression_binder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,14 @@ impl ExpressionBinder<'_> {
| sqlparser::ast::BinaryOperator::Divide => {
self.bind_function_expression(left, op, right, result_names, result_types)
}
sqlparser::ast::BinaryOperator::Gt => todo!(),
sqlparser::ast::BinaryOperator::Lt => todo!(),
sqlparser::ast::BinaryOperator::GtEq => todo!(),
sqlparser::ast::BinaryOperator::LtEq => todo!(),
sqlparser::ast::BinaryOperator::Eq => todo!(),
sqlparser::ast::BinaryOperator::NotEq => todo!(),
sqlparser::ast::BinaryOperator::Gt
| sqlparser::ast::BinaryOperator::Lt
| sqlparser::ast::BinaryOperator::GtEq
| sqlparser::ast::BinaryOperator::LtEq
| sqlparser::ast::BinaryOperator::Eq
| sqlparser::ast::BinaryOperator::NotEq => {
self.bind_comparison_expression(left, op, right, result_names, result_types)
}
sqlparser::ast::BinaryOperator::And => todo!(),
sqlparser::ast::BinaryOperator::Or => todo!(),
other => Err(BindError::UnsupportedExpr(other.to_string())),
Expand Down
4 changes: 4 additions & 0 deletions src/planner_v2/expression_iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ impl ExpressionIterator {
}
BoundExpression::BoundCastExpression(e) => callback(&mut e.child),
BoundExpression::BoundFunctionExpression(e) => e.children.iter_mut().for_each(callback),
BoundExpression::BoundComparisonExpression(e) => {
callback(&mut e.left);
callback(&mut e.right);
}
}
}
}
12 changes: 10 additions & 2 deletions src/planner_v2/logical_operator_visitor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::{
BoundCastExpression, BoundColumnRefExpression, BoundConstantExpression, BoundExpression,
BoundFunctionExpression, BoundReferenceExpression, ExpressionIterator, LogicalOperator,
BoundCastExpression, BoundColumnRefExpression, BoundComparisonExpression,
BoundConstantExpression, BoundExpression, BoundFunctionExpression, BoundReferenceExpression,
ExpressionIterator, LogicalOperator,
};

/// Visitor pattern on logical operators, also includes rewrite expression ability.
Expand Down Expand Up @@ -36,6 +37,7 @@ pub trait LogicalOperatorVisitor {
BoundExpression::BoundReferenceExpression(e) => self.visit_replace_reference(e),
BoundExpression::BoundCastExpression(e) => self.visit_replace_cast(e),
BoundExpression::BoundFunctionExpression(e) => self.visit_function_expression(e),
BoundExpression::BoundComparisonExpression(e) => self.visit_comparison_expression(e),
};
if let Some(new_expr) = result {
*expr = new_expr;
Expand Down Expand Up @@ -63,4 +65,10 @@ pub trait LogicalOperatorVisitor {
fn visit_function_expression(&self, _: &BoundFunctionExpression) -> Option<BoundExpression> {
None
}
fn visit_comparison_expression(
&self,
_: &BoundComparisonExpression,
) -> Option<BoundExpression> {
None
}
}
5 changes: 5 additions & 0 deletions src/util/tree_render.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ impl TreeRender {
.join(", ");
format!("{}({}])", e.function.name, args)
}
BoundExpression::BoundComparisonExpression(e) => {
let l = Self::bound_expression_to_string(&e.left);
let r = Self::bound_expression_to_string(&e.right);
format!("{} {} {}", l, e.function.name, r)
}
}
}

Expand Down
19 changes: 19 additions & 0 deletions tests/slt/comparison_function.slt
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
onlyif sqlrs_v2
statement error
select 'abc' > 10

onlyif sqlrs_v2
statement error
select 20.0 = 'abc'

onlyif sqlrs_v2
query T
select 100 > 20
----
true

onlyif sqlrs_v2
query T
select '1000' > '20'
----
false