Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ select * from read_csv('t1.csv', header=>true, delim=>',');
select * from 't1.csv';
-- copy
copy t1 from 't1.csv' ( DELIMITER '|', HEADER false);
-- date and interval
select date '1998-12-01' - interval '1' month;
select interval '1' year + date '1998-12-01';
```


Expand Down
7 changes: 6 additions & 1 deletion src/common/cast.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use arrow::array::{Array, BooleanArray};
use arrow::array::{Array, BooleanArray, Date32Array};

use crate::function::FunctionError;

Expand All @@ -16,3 +16,8 @@ macro_rules! downcast_value {
/// Downcasts a dynamically typed Arrow array to a [`BooleanArray`] reference.
///
/// The conversion itself is delegated to the `downcast_value!` macro; the
/// borrowed, concretely typed array it yields is returned wrapped in `Ok`.
// NOTE(review): a type mismatch is presumably reported as a `FunctionError`
// by `downcast_value!` — confirm against the macro definition.
pub fn as_boolean_array(array: &dyn Array) -> Result<&BooleanArray, FunctionError> {
    let booleans = downcast_value!(array, BooleanArray);
    Ok(booleans)
}

/// Downcasts a dynamically typed Arrow array to a [`Date32Array`] reference.
///
/// The conversion itself is delegated to the `downcast_value!` macro; the
/// borrowed, concretely typed array it yields is returned wrapped in `Ok`.
// NOTE(review): a type mismatch is presumably reported as a `FunctionError`
// by `downcast_value!` — confirm against the macro definition.
pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array, FunctionError> {
    let dates = downcast_value!(array, Date32Array);
    Ok(dates)
}
113 changes: 110 additions & 3 deletions src/function/scalar/arithmetic_function.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
use std::sync::Arc;

use arrow::array::{ArrayRef, *};
use arrow::compute::{add_checked, divide_checked, multiply_checked, subtract_checked};
use arrow::datatypes::DataType;
use arrow::compute::{
add_checked, add_dyn_checked, divide_checked, multiply_checked, negate_checked,
subtract_checked,
};
use arrow::datatypes::{DataType, IntervalUnit};

use super::ScalarFunction;
use crate::function::{BuiltinFunctions, FunctionError};
Expand Down Expand Up @@ -56,6 +59,7 @@ macro_rules! binary_primitive_array_op {
}
}};
}

pub struct AddFunction;

impl AddFunction {
Expand All @@ -66,6 +70,61 @@ impl AddFunction {
binary_primitive_array_op!(left, right, add_checked)
}

/// Evaluates `date + interval` for the `(Date, Interval)` overloads of `add`.
///
/// `inputs[0]` is passed as the left operand and `inputs[1]` as the right
/// operand of `add_dyn_checked`; any error from that kernel is converted
/// into a `FunctionError`.
///
/// # Panics
///
/// Panics if `inputs` does not contain exactly two arrays.
fn date_add_interval_func(inputs: &[ArrayRef]) -> Result<ArrayRef, FunctionError> {
    assert!(inputs.len() == 2);
    let (date, interval) = (&inputs[0], &inputs[1]);
    let sum = add_dyn_checked(date, interval)?;
    Ok(sum)
}

/// Evaluates `interval + date` for the `(Interval, Date)` overloads of `add`.
///
/// The operands are swapped before calling `add_dyn_checked`, so the kernel
/// always receives `inputs[1]` (the date) on the left and `inputs[0]` (the
/// interval) on the right. Any kernel error is converted into a
/// `FunctionError`.
///
/// # Panics
///
/// Panics if `inputs` does not contain exactly two arrays.
fn interval_add_date_func(inputs: &[ArrayRef]) -> Result<ArrayRef, FunctionError> {
    assert!(inputs.len() == 2);
    let (interval, date) = (&inputs[0], &inputs[1]);
    let sum = add_dyn_checked(date, interval)?;
    Ok(sum)
}

/// Builds the `add` overloads that combine a date with an interval.
///
/// Four signatures are produced, in this order:
/// `(Date, Interval(YearMonth))`, `(Date, Interval(DayTime))`,
/// `(Interval(YearMonth), Date)` and `(Interval(DayTime), Date)`.
/// Every overload is registered with `LogicalType::Date` as its return type.
fn gen_date_funcs() -> Vec<ScalarFunction> {
    let units = [IntervalUnit::YearMonth, IntervalUnit::DayTime];

    // Overloads with the date as the first argument.
    let date_first = units.iter().map(|unit| {
        ScalarFunction::new(
            "add".to_string(),
            Self::date_add_interval_func,
            vec![LogicalType::Date, LogicalType::Interval(unit.clone())],
            LogicalType::Date,
        )
    });

    // Overloads with the interval as the first argument; the implementation
    // swaps its operands before adding.
    let interval_first = units.iter().map(|unit| {
        ScalarFunction::new(
            "add".to_string(),
            Self::interval_add_date_func,
            vec![LogicalType::Interval(unit.clone()), LogicalType::Date],
            LogicalType::Date,
        )
    });

    date_first.chain(interval_first).collect()
}

pub fn register_function(set: &mut BuiltinFunctions) -> Result<(), FunctionError> {
let mut functions = vec![];
for ty in LogicalType::numeric().iter() {
Expand All @@ -76,7 +135,8 @@ impl AddFunction {
ty.clone(),
));
}
set.add_scalar_functions("add".to_string(), functions.clone())?;
functions.extend(Self::gen_date_funcs());
set.add_scalar_functions("add".to_string(), functions)?;
Ok(())
}
}
Expand All @@ -91,6 +151,52 @@ impl SubtractFunction {
binary_primitive_array_op!(left, right, subtract_checked)
}

/// Negates an interval array so that `date - interval` can be evaluated as
/// `date + (-interval)`.
///
/// Only `Interval(YearMonth)` and `Interval(DayTime)` arrays are handled; any
/// other data type yields `FunctionError::InternalError`.
// NOTE(review): `compute_op!` is defined outside this view; presumably it
// downcasts `input` to the named array type and applies `negate_checked`.
// NOTE(review): for `DayTime`, if the array's native value packs days and
// milliseconds into one 64-bit integer, negating that integer as a whole may
// not negate the two fields independently — confirm against the Arrow
// version in use and add a test for `date - interval '1' day`.
fn negate_interval(input: &ArrayRef) -> Result<ArrayRef, FunctionError> {
match input.data_type() {
DataType::Interval(IntervalUnit::YearMonth) => {
compute_op!(input, negate_checked, IntervalYearMonthArray)
}
DataType::Interval(IntervalUnit::DayTime) => {
compute_op!(input, negate_checked, IntervalDayTimeArray)
}
// Every other type (including other interval units) is rejected.
other => Err(FunctionError::InternalError(format!(
"Data type {:?} not supported for negate",
other
))),
}
}

/// Evaluates `date - interval` for the `(Date, Interval)` overloads of
/// `subtract`.
///
/// The subtraction is expressed as an addition: `inputs[1]` is negated with
/// `negate_interval` and the result is added to `inputs[0]` through
/// `add_dyn_checked`. Any kernel error is converted into a `FunctionError`.
///
/// # Panics
///
/// Panics if `inputs` does not contain exactly two arrays.
fn date_subtract_interval_func(inputs: &[ArrayRef]) -> Result<ArrayRef, FunctionError> {
    assert!(inputs.len() == 2);
    let date = &inputs[0];
    let negated_interval = Self::negate_interval(&inputs[1])?;
    let difference = add_dyn_checked(date, &negated_interval)?;
    Ok(difference)
}

/// Builds the `subtract` overloads that take a date and an interval.
///
/// Two signatures are produced, in this order: `(Date, Interval(YearMonth))`
/// and `(Date, Interval(DayTime))`. Both are registered with
/// `LogicalType::Date` as their return type. No `(Interval, Date)` overload
/// is generated here.
fn gen_date_funcs() -> Vec<ScalarFunction> {
    [IntervalUnit::YearMonth, IntervalUnit::DayTime]
        .iter()
        .map(|unit| {
            ScalarFunction::new(
                "subtract".to_string(),
                Self::date_subtract_interval_func,
                vec![LogicalType::Date, LogicalType::Interval(unit.clone())],
                LogicalType::Date,
            )
        })
        .collect()
}

pub fn register_function(set: &mut BuiltinFunctions) -> Result<(), FunctionError> {
let mut functions = vec![];
for ty in LogicalType::numeric().iter() {
Expand All @@ -101,6 +207,7 @@ impl SubtractFunction {
ty.clone(),
));
}
functions.extend(Self::gen_date_funcs());
set.add_scalar_functions("subtract".to_string(), functions.clone())?;
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion src/planner_v2/binder/expression/bind_cast_expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl BoundCastExpression {
return Ok(source_expr);
}
let cast_function = DefaultCastFunctions::get_cast_function(&source_type, &target_type)?;
let alias = format!("cast({} as {}", source_expr.alias(), target_type);
let alias = format!("cast({}) as {}", source_expr.alias(), target_type);
let base = BoundExpressionBase::new(alias, target_type);
Ok(BoundExpression::BoundCastExpression(
BoundCastExpression::new(base, Box::new(source_expr), try_cast, cast_function),
Expand Down
124 changes: 121 additions & 3 deletions src/planner_v2/expression_binder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@ use std::slice;

use derive_new::new;

use super::{BindError, Binder, BoundExpression, ColumnAliasData};
use crate::types_v2::LogicalType;
use super::{
BindError, Binder, BoundCastExpression, BoundConstantExpression, BoundExpression,
BoundExpressionBase, ColumnAliasData, SqlparserResolver,
};
use crate::types_v2::{LogicalType, ScalarValue};

#[derive(new)]
pub struct ExpressionBinder<'a> {
Expand Down Expand Up @@ -40,7 +43,13 @@ impl ExpressionBinder<'_> {
sqlparser::ast::Expr::Function(_) => todo!(),
sqlparser::ast::Expr::Exists { .. } => todo!(),
sqlparser::ast::Expr::Subquery(_) => todo!(),
_ => todo!(),
sqlparser::ast::Expr::TypedString { data_type, value } => {
self.bind_typed_string(data_type, value, result_names, result_types)
}
sqlparser::ast::Expr::Interval { .. } => {
self.bind_interval_expr(expr, result_names, result_types)
}
other => Err(BindError::UnsupportedExpr(other.to_string())),
}
}

Expand Down Expand Up @@ -73,4 +82,113 @@ impl ExpressionBinder<'_> {
other => Err(BindError::UnsupportedExpr(other.to_string())),
}
}

/// Binds a typed string literal, i.e. a constant written as `<data_type> 'value'`
/// (for example `date '1998-12-01'`).
///
/// The text is first bound as an ordinary single-quoted string constant and is
/// then passed to `BoundCastExpression::try_add_cast_to_type` with the logical
/// type converted from `data_type` (the final `true` argument is forwarded
/// as-is). The alias and return type of the resulting expression are appended
/// to `result_names` and `result_types`.
fn bind_typed_string(
    &mut self,
    data_type: &sqlparser::ast::DataType,
    value: &str,
    result_names: &mut Vec<String>,
    result_types: &mut Vec<LogicalType>,
) -> Result<BoundExpression, BindError> {
    // Bind the raw text as a plain string constant. Scratch vectors absorb the
    // names/types reported for the inner constant so that only the final
    // expression is published to the caller's vectors.
    let literal = sqlparser::ast::Value::SingleQuotedString(value.to_string());
    let mut scratch_names = vec![];
    let mut scratch_types = vec![];
    let string_constant =
        self.bind_constant_expr(&literal, &mut scratch_names, &mut scratch_types)?;

    let cast_target = LogicalType::try_from(data_type.clone())?;
    let casted = BoundCastExpression::try_add_cast_to_type(string_constant, cast_target, true)?;

    result_names.push(casted.alias());
    result_types.push(casted.return_type());
    Ok(casted)
}

/// bind a interval expression, currently only support one DateTimeFiled, such as: `interval '1'
/// day`. So if value contains unit, such as `interval '1 year 2 month'`, current binder will
/// return error. To support this, we need split the value into parts, and parse each part in
/// loop, so is more complex for now.
fn bind_interval_expr(
&mut self,
expr: &sqlparser::ast::Expr,
result_names: &mut Vec<String>,
result_types: &mut Vec<LogicalType>,
) -> Result<BoundExpression, BindError> {
match expr {
sqlparser::ast::Expr::Interval {
value,
leading_field,
leading_precision,
last_field,
fractional_seconds_precision,
} => {
if leading_precision.is_some()
|| last_field.is_some()
|| fractional_seconds_precision.is_some()
{
return Err(BindError::UnsupportedExpr(
"Unsupported Interval Expression".to_string(),
));
}

let val = SqlparserResolver::resolve_expr_to_string(value)?;
let num: i64 = val.parse().map_err(|e| {
BindError::UnsupportedExpr(format!(
"Interval value must be a number, but got {}",
e
))
})?;

let scalar = match leading_field {
Some(v) => {
match v {
// convert to IntervalYearMonth
sqlparser::ast::DateTimeField::Year => {
ScalarValue::IntervalYearMonth(Some(num as i32 * 12))
}
sqlparser::ast::DateTimeField::Month => {
ScalarValue::IntervalYearMonth(Some(num as i32))
}
// convert to IntervalDayTime
sqlparser::ast::DateTimeField::Week => {
ScalarValue::IntervalDayTime(Some(num * 7 * 24 * 60 * 60 * 1000))
}
sqlparser::ast::DateTimeField::Day => {
ScalarValue::IntervalDayTime(Some(num * 24 * 60 * 60 * 1000))
}
sqlparser::ast::DateTimeField::Hour => {
ScalarValue::IntervalDayTime(Some(num * 60 * 60 * 1000))
}
sqlparser::ast::DateTimeField::Minute => {
ScalarValue::IntervalDayTime(Some(num * 60 * 1000))
}
sqlparser::ast::DateTimeField::Second => {
ScalarValue::IntervalDayTime(Some(num * 1000))
}
other => {
return Err(BindError::UnsupportedExpr(format!(
"Unsupported Interval unit: {:?}",
other
)))
}
}
}
None => {
return Err(BindError::UnsupportedExpr(
"Interval must have DataTimeField".to_string(),
))
}
};

let base =
BoundExpressionBase::new(format!("{:?}", scalar), scalar.get_logical_type());
result_names.push(base.alias.clone());
result_types.push(base.return_type.clone());
let expr = BoundExpression::BoundConstantExpression(BoundConstantExpression::new(
base, scalar,
));
Ok(expr)
}
_ => Err(BindError::UnsupportedExpr(
"expect interval expr".to_string(),
)),
}
}
}
14 changes: 12 additions & 2 deletions src/types_v2/types.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use arrow::datatypes::IntervalUnit;
use strum_macros::AsRefStr;

use super::TypeError;
Expand All @@ -20,6 +21,8 @@ pub enum LogicalType {
Float,
Double,
Varchar,
Date,
Interval(IntervalUnit),
}

impl LogicalType {
Expand Down Expand Up @@ -191,6 +194,8 @@ impl LogicalType {
LogicalType::Float => matches!(to, LogicalType::Double),
LogicalType::Double => false,
LogicalType::Varchar => false,
LogicalType::Date => false,
LogicalType::Interval(_) => false,
}
}
}
Expand Down Expand Up @@ -220,6 +225,9 @@ impl TryFrom<sqlparser::ast::DataType> for LogicalType {
sqlparser::ast::DataType::BigInt(_) => Ok(LogicalType::Bigint),
sqlparser::ast::DataType::UnsignedBigInt(_) => Ok(LogicalType::UBigint),
sqlparser::ast::DataType::Boolean => Ok(LogicalType::Boolean),
sqlparser::ast::DataType::Date => Ok(LogicalType::Date),
// use day time interval for default interval value
sqlparser::ast::DataType::Interval => Ok(LogicalType::Interval(IntervalUnit::DayTime)),
other => Err(TypeError::NotImplementedSqlparserDataType(
other.to_string(),
)),
Expand All @@ -245,6 +253,8 @@ impl From<LogicalType> for arrow::datatypes::DataType {
LogicalType::Float => DataType::Float32,
LogicalType::Double => DataType::Float64,
LogicalType::Varchar => DataType::Utf8,
LogicalType::Date => DataType::Date32,
LogicalType::Interval(u) => DataType::Interval(u),
}
}
}
Expand All @@ -270,13 +280,13 @@ impl TryFrom<&arrow::datatypes::DataType> for LogicalType {
DataType::Float64 => LogicalType::Double,
DataType::Utf8 => LogicalType::Varchar,
DataType::LargeUtf8 => LogicalType::Varchar,
DataType::Date32 => LogicalType::Date,
DataType::Interval(u) => LogicalType::Interval(u.clone()),
DataType::Timestamp(_, _)
| DataType::Date32
| DataType::Date64
| DataType::Time32(_)
| DataType::Time64(_)
| DataType::Duration(_)
| DataType::Interval(_)
| DataType::Binary
| DataType::FixedSizeBinary(_)
| DataType::LargeBinary
Expand Down
Loading