From 1f64ff56c5be5ff2463b6abafd273fcf242d7016 Mon Sep 17 00:00:00 2001 From: Fedomn Date: Sat, 7 Jan 2023 19:34:53 +0800 Subject: [PATCH] feat(types): support date and interval Signed-off-by: Fedomn --- README.md | 3 + src/common/cast.rs | 7 +- src/function/scalar/arithmetic_function.rs | 113 +++++++++++++++- .../binder/expression/bind_cast_expression.rs | 2 +- src/planner_v2/expression_binder.rs | 124 +++++++++++++++++- src/types_v2/types.rs | 14 +- src/types_v2/values.rs | 88 ++++++++++++- tests/slt/time.slt | 35 +++++ 8 files changed, 371 insertions(+), 15 deletions(-) create mode 100644 tests/slt/time.slt diff --git a/README.md b/README.md index 9564221..65b08e5 100644 --- a/README.md +++ b/README.md @@ -72,6 +72,9 @@ select * from read_csv('t1.csv', header=>true, delim=>','); select * from 't1.csv'; -- copy copy t1 from 't1.csv' ( DELIMITER '|', HEADER false); +-- date and interval +select date '1998-12-01' - interval '1' month; +select interval '1' year + date '1998-12-01'; ``` diff --git a/src/common/cast.rs b/src/common/cast.rs index 2501fa9..66a975a 100644 --- a/src/common/cast.rs +++ b/src/common/cast.rs @@ -1,4 +1,4 @@ -use arrow::array::{Array, BooleanArray}; +use arrow::array::{Array, BooleanArray, Date32Array}; use crate::function::FunctionError; @@ -16,3 +16,8 @@ macro_rules! downcast_value { pub fn as_boolean_array(array: &dyn Array) -> Result<&BooleanArray, FunctionError> { Ok(downcast_value!(array, BooleanArray)) } + +// Downcast ArrayRef to Date32Array +pub fn as_date32_array(array: &dyn Array) -> Result<&Date32Array, FunctionError> { + Ok(downcast_value!(array, Date32Array)) +} diff --git a/src/function/scalar/arithmetic_function.rs b/src/function/scalar/arithmetic_function.rs index deb8ae4..58d2ca9 100644 --- a/src/function/scalar/arithmetic_function.rs +++ b/src/function/scalar/arithmetic_function.rs @@ -1,8 +1,11 @@ use std::sync::Arc; use arrow::array::{ArrayRef, *}; -use arrow::compute::{add_checked, divide_checked, multiply_checked, subtract_checked}; -use arrow::datatypes::DataType; +use arrow::compute::{ + add_checked, add_dyn_checked, divide_checked, multiply_checked, negate_checked, + subtract_checked, +}; +use arrow::datatypes::{DataType, IntervalUnit}; use super::ScalarFunction; use crate::function::{BuiltinFunctions, FunctionError}; @@ -56,6 +59,7 @@ macro_rules! binary_primitive_array_op { } }}; } + pub struct AddFunction; impl AddFunction { @@ -66,6 +70,61 @@ impl AddFunction { binary_primitive_array_op!(left, right, add_checked) } + fn date_add_interval_func(inputs: &[ArrayRef]) -> Result { + assert!(inputs.len() == 2); + let left = &inputs[0]; + let right = &inputs[1]; + Ok(add_dyn_checked(left, right)?) + } + + fn interval_add_date_func(inputs: &[ArrayRef]) -> Result { + assert!(inputs.len() == 2); + let left = &inputs[0]; + let right = &inputs[1]; + Ok(add_dyn_checked(right, left)?) + } + + fn gen_date_funcs() -> Vec { + let mut functions = vec![]; + let args1 = [ + [ + LogicalType::Date, + LogicalType::Interval(IntervalUnit::YearMonth), + ], + [ + LogicalType::Date, + LogicalType::Interval(IntervalUnit::DayTime), + ], + ]; + for arg in args1.iter() { + functions.push(ScalarFunction::new( + "add".to_string(), + Self::date_add_interval_func, + arg.to_vec(), + LogicalType::Date, + )); + } + let args2 = [ + [ + LogicalType::Interval(IntervalUnit::YearMonth), + LogicalType::Date, + ], + [ + LogicalType::Interval(IntervalUnit::DayTime), + LogicalType::Date, + ], + ]; + for arg in args2.iter() { + functions.push(ScalarFunction::new( + "add".to_string(), + Self::interval_add_date_func, + arg.to_vec(), + LogicalType::Date, + )); + } + functions + } + pub fn register_function(set: &mut BuiltinFunctions) -> Result<(), FunctionError> { let mut functions = vec![]; for ty in LogicalType::numeric().iter() { @@ -76,7 +135,8 @@ impl AddFunction { ty.clone(), )); } - set.add_scalar_functions("add".to_string(), functions.clone())?; + functions.extend(Self::gen_date_funcs()); + set.add_scalar_functions("add".to_string(), functions)?; Ok(()) } } @@ -91,6 +151,52 @@ impl SubtractFunction { binary_primitive_array_op!(left, right, subtract_checked) } + fn negate_interval(input: &ArrayRef) -> Result { + match input.data_type() { + DataType::Interval(IntervalUnit::YearMonth) => { + compute_op!(input, negate_checked, IntervalYearMonthArray) + } + DataType::Interval(IntervalUnit::DayTime) => { + compute_op!(input, negate_checked, IntervalDayTimeArray) + } + other => Err(FunctionError::InternalError(format!( + "Data type {:?} not supported for negate", + other + ))), + } + } + + fn date_subtract_interval_func(inputs: &[ArrayRef]) -> Result { + assert!(inputs.len() == 2); + let left = &inputs[0]; + let right = &inputs[1]; + let right = Self::negate_interval(right)?; + Ok(add_dyn_checked(left, &right)?) + } + + fn gen_date_funcs() -> Vec { + let mut functions = vec![]; + let args1 = [ + [ + LogicalType::Date, + LogicalType::Interval(IntervalUnit::YearMonth), + ], + [ + LogicalType::Date, + LogicalType::Interval(IntervalUnit::DayTime), + ], + ]; + for arg in args1.iter() { + functions.push(ScalarFunction::new( + "subtract".to_string(), + Self::date_subtract_interval_func, + arg.to_vec(), + LogicalType::Date, + )); + } + functions + } + pub fn register_function(set: &mut BuiltinFunctions) -> Result<(), FunctionError> { let mut functions = vec![]; for ty in LogicalType::numeric().iter() { @@ -101,6 +207,7 @@ impl SubtractFunction { ty.clone(), )); } + functions.extend(Self::gen_date_funcs()); set.add_scalar_functions("subtract".to_string(), functions.clone())?; Ok(()) } diff --git a/src/planner_v2/binder/expression/bind_cast_expression.rs b/src/planner_v2/binder/expression/bind_cast_expression.rs index f87ee69..f516773 100644 --- a/src/planner_v2/binder/expression/bind_cast_expression.rs +++ b/src/planner_v2/binder/expression/bind_cast_expression.rs @@ -30,7 +30,7 @@ impl BoundCastExpression { return Ok(source_expr); } let cast_function = DefaultCastFunctions::get_cast_function(&source_type, &target_type)?; - let alias = format!("cast({} as {}", source_expr.alias(), target_type); + let alias = format!("cast({}) as {}", source_expr.alias(), target_type); let base = BoundExpressionBase::new(alias, target_type); Ok(BoundExpression::BoundCastExpression( BoundCastExpression::new(base, Box::new(source_expr), try_cast, cast_function), diff --git a/src/planner_v2/expression_binder.rs b/src/planner_v2/expression_binder.rs index 8aae59f..2103499 100644 --- a/src/planner_v2/expression_binder.rs +++ b/src/planner_v2/expression_binder.rs @@ -2,8 +2,11 @@ use std::slice; use derive_new::new; -use super::{BindError, Binder, BoundExpression, ColumnAliasData}; -use crate::types_v2::LogicalType; +use super::{ + BindError, Binder, BoundCastExpression, BoundConstantExpression, BoundExpression, + BoundExpressionBase, ColumnAliasData, SqlparserResolver, +}; +use crate::types_v2::{LogicalType, ScalarValue}; #[derive(new)] pub struct ExpressionBinder<'a> { @@ -40,7 +43,13 @@ impl ExpressionBinder<'_> { sqlparser::ast::Expr::Function(_) => todo!(), sqlparser::ast::Expr::Exists { .. } => todo!(), sqlparser::ast::Expr::Subquery(_) => todo!(), - _ => todo!(), + sqlparser::ast::Expr::TypedString { data_type, value } => { + self.bind_typed_string(data_type, value, result_names, result_types) + } + sqlparser::ast::Expr::Interval { .. } => { + self.bind_interval_expr(expr, result_names, result_types) + } + other => Err(BindError::UnsupportedExpr(other.to_string())), } } @@ -73,4 +82,113 @@ impl ExpressionBinder<'_> { other => Err(BindError::UnsupportedExpr(other.to_string())), } } + + /// TypedString: A constant of form ` 'value'`. + fn bind_typed_string( + &mut self, + data_type: &sqlparser::ast::DataType, + value: &str, + result_names: &mut Vec, + result_types: &mut Vec, + ) -> Result { + // A constant of form ` 'value'`. + let val = sqlparser::ast::Value::SingleQuotedString(value.to_string()); + let constant_expr = self.bind_constant_expr(&val, &mut vec![], &mut vec![])?; + let target_type = LogicalType::try_from(data_type.clone())?; + let expr = BoundCastExpression::try_add_cast_to_type(constant_expr, target_type, true)?; + result_names.push(expr.alias()); + result_types.push(expr.return_type()); + Ok(expr) + } + + /// bind a interval expression, currently only support one DateTimeFiled, such as: `interval '1' + /// day`. So if value contains unit, such as `interval '1 year 2 month'`, current binder will + /// return error. To support this, we need split the value into parts, and parse each part in + /// loop, so is more complex for now. + fn bind_interval_expr( + &mut self, + expr: &sqlparser::ast::Expr, + result_names: &mut Vec, + result_types: &mut Vec, + ) -> Result { + match expr { + sqlparser::ast::Expr::Interval { + value, + leading_field, + leading_precision, + last_field, + fractional_seconds_precision, + } => { + if leading_precision.is_some() + || last_field.is_some() + || fractional_seconds_precision.is_some() + { + return Err(BindError::UnsupportedExpr( + "Unsupported Interval Expression".to_string(), + )); + } + + let val = SqlparserResolver::resolve_expr_to_string(value)?; + let num: i64 = val.parse().map_err(|e| { + BindError::UnsupportedExpr(format!( + "Interval value must be a number, but got {}", + e + )) + })?; + + let scalar = match leading_field { + Some(v) => { + match v { + // convert to IntervalYearMonth + sqlparser::ast::DateTimeField::Year => { + ScalarValue::IntervalYearMonth(Some(num as i32 * 12)) + } + sqlparser::ast::DateTimeField::Month => { + ScalarValue::IntervalYearMonth(Some(num as i32)) + } + // convert to IntervalDayTime + sqlparser::ast::DateTimeField::Week => { + ScalarValue::IntervalDayTime(Some(num * 7 * 24 * 60 * 60 * 1000)) + } + sqlparser::ast::DateTimeField::Day => { + ScalarValue::IntervalDayTime(Some(num * 24 * 60 * 60 * 1000)) + } + sqlparser::ast::DateTimeField::Hour => { + ScalarValue::IntervalDayTime(Some(num * 60 * 60 * 1000)) + } + sqlparser::ast::DateTimeField::Minute => { + ScalarValue::IntervalDayTime(Some(num * 60 * 1000)) + } + sqlparser::ast::DateTimeField::Second => { + ScalarValue::IntervalDayTime(Some(num * 1000)) + } + other => { + return Err(BindError::UnsupportedExpr(format!( + "Unsupported Interval unit: {:?}", + other + ))) + } + } + } + None => { + return Err(BindError::UnsupportedExpr( + "Interval must have DataTimeField".to_string(), + )) + } + }; + + let base = + BoundExpressionBase::new(format!("{:?}", scalar), scalar.get_logical_type()); + result_names.push(base.alias.clone()); + result_types.push(base.return_type.clone()); + let expr = BoundExpression::BoundConstantExpression(BoundConstantExpression::new( + base, scalar, + )); + Ok(expr) + } + _ => Err(BindError::UnsupportedExpr( + "expect interval expr".to_string(), + )), + } + } } diff --git a/src/types_v2/types.rs b/src/types_v2/types.rs index be88976..8356e1e 100644 --- a/src/types_v2/types.rs +++ b/src/types_v2/types.rs @@ -1,3 +1,4 @@ +use arrow::datatypes::IntervalUnit; use strum_macros::AsRefStr; use super::TypeError; @@ -20,6 +21,8 @@ pub enum LogicalType { Float, Double, Varchar, + Date, + Interval(IntervalUnit), } impl LogicalType { @@ -191,6 +194,8 @@ impl LogicalType { LogicalType::Float => matches!(to, LogicalType::Double), LogicalType::Double => false, LogicalType::Varchar => false, + LogicalType::Date => false, + LogicalType::Interval(_) => false, } } } @@ -220,6 +225,9 @@ impl TryFrom for LogicalType { sqlparser::ast::DataType::BigInt(_) => Ok(LogicalType::Bigint), sqlparser::ast::DataType::UnsignedBigInt(_) => Ok(LogicalType::UBigint), sqlparser::ast::DataType::Boolean => Ok(LogicalType::Boolean), + sqlparser::ast::DataType::Date => Ok(LogicalType::Date), + // use day time interval for default interval value + sqlparser::ast::DataType::Interval => Ok(LogicalType::Interval(IntervalUnit::DayTime)), other => Err(TypeError::NotImplementedSqlparserDataType( other.to_string(), )), @@ -245,6 +253,8 @@ impl From for arrow::datatypes::DataType { LogicalType::Float => DataType::Float32, LogicalType::Double => DataType::Float64, LogicalType::Varchar => DataType::Utf8, + LogicalType::Date => DataType::Date32, + LogicalType::Interval(u) => DataType::Interval(u), } } } @@ -270,13 +280,13 @@ impl TryFrom<&arrow::datatypes::DataType> for LogicalType { DataType::Float64 => LogicalType::Double, DataType::Utf8 => LogicalType::Varchar, DataType::LargeUtf8 => LogicalType::Varchar, + DataType::Date32 => LogicalType::Date, + DataType::Interval(u) => LogicalType::Interval(u.clone()), DataType::Timestamp(_, _) - | DataType::Date32 | DataType::Date64 | DataType::Time32(_) | DataType::Time64(_) | DataType::Duration(_) - | DataType::Interval(_) | DataType::Binary | DataType::FixedSizeBinary(_) | DataType::LargeBinary diff --git a/src/types_v2/values.rs b/src/types_v2/values.rs index 3459634..70fc266 100644 --- a/src/types_v2/values.rs +++ b/src/types_v2/values.rs @@ -5,13 +5,15 @@ use std::iter::repeat; use std::sync::Arc; use arrow::array::{ - new_null_array, ArrayBuilder, ArrayRef, BooleanArray, BooleanBuilder, Float32Array, - Float32Builder, Float64Array, Float64Builder, Int16Array, Int16Builder, Int32Array, - Int32Builder, Int64Array, Int64Builder, Int8Array, Int8Builder, StringArray, StringBuilder, - UInt16Array, UInt16Builder, UInt32Array, UInt32Builder, UInt64Array, UInt64Builder, UInt8Array, + new_null_array, ArrayBuilder, ArrayRef, BooleanArray, BooleanBuilder, Date32Array, + Date32Builder, Float32Array, Float32Builder, Float64Array, Float64Builder, Int16Array, + Int16Builder, Int32Array, Int32Builder, Int64Array, Int64Builder, Int8Array, Int8Builder, + IntervalDayTimeArray, IntervalDayTimeBuilder, IntervalMonthDayNanoBuilder, + IntervalYearMonthArray, IntervalYearMonthBuilder, StringArray, StringBuilder, UInt16Array, + UInt16Builder, UInt32Array, UInt32Builder, UInt64Array, UInt64Builder, UInt8Array, UInt8Builder, }; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, IntervalUnit}; use ordered_float::OrderedFloat; use super::{LogicalType, TypeError}; @@ -31,6 +33,13 @@ pub enum ScalarValue { UInt32(Option), UInt64(Option), Utf8(Option), + /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 + Date32(Option), + /// Number of elapsed whole months + IntervalYearMonth(Option), + /// Number of elapsed days and milliseconds (no leap seconds) + /// stored as 2 contiguous 32-bit signed integers + IntervalDayTime(Option), } impl PartialEq for ScalarValue { @@ -71,6 +80,12 @@ impl PartialEq for ScalarValue { (Utf8(_), _) => false, (Null, Null) => true, (Null, _) => false, + (Date32(v1), Date32(v2)) => v1.eq(v2), + (Date32(_), _) => false, + (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.eq(v2), + (IntervalYearMonth(_), _) => false, + (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.eq(v2), + (IntervalDayTime(_), _) => false, } } } @@ -113,6 +128,12 @@ impl PartialOrd for ScalarValue { (Utf8(_), _) => None, (Null, Null) => Some(Ordering::Equal), (Null, _) => None, + (Date32(v1), Date32(v2)) => v1.partial_cmp(v2), + (Date32(_), _) => None, + (IntervalYearMonth(v1), IntervalYearMonth(v2)) => v1.partial_cmp(v2), + (IntervalYearMonth(_), _) => None, + (IntervalDayTime(v1), IntervalDayTime(v2)) => v1.partial_cmp(v2), + (IntervalDayTime(_), _) => None, } } } @@ -142,6 +163,9 @@ impl Hash for ScalarValue { UInt64(v) => v.hash(state), Utf8(v) => v.hash(state), Null => 1.hash(state), + Date32(v) => v.hash(state), + IntervalYearMonth(v) => v.hash(state), + IntervalDayTime(v) => v.hash(state), } } } @@ -244,6 +268,9 @@ impl ScalarValue { ScalarValue::UInt32(_) => LogicalType::UInteger, ScalarValue::UInt64(_) => LogicalType::UBigint, ScalarValue::Utf8(_) => LogicalType::Varchar, + ScalarValue::Date32(_) => LogicalType::Date, + ScalarValue::IntervalYearMonth(_) => LogicalType::Interval(IntervalUnit::YearMonth), + ScalarValue::IntervalDayTime(_) => LogicalType::Interval(IntervalUnit::DayTime), } } @@ -282,6 +309,23 @@ impl ScalarValue { None => new_null_array(&DataType::Utf8, size), }, ScalarValue::Null => new_null_array(&DataType::Null, size), + ScalarValue::Date32(e) => { + build_array_from_option!(Date32, Date32Array, e, size) + } + ScalarValue::IntervalDayTime(e) => build_array_from_option!( + Interval, + IntervalUnit::DayTime, + IntervalDayTimeArray, + e, + size + ), + ScalarValue::IntervalYearMonth(e) => build_array_from_option!( + Interval, + IntervalUnit::YearMonth, + IntervalYearMonthArray, + e, + size + ), } } @@ -303,6 +347,16 @@ impl ScalarValue { LogicalType::Float => Ok(Box::new(Float32Builder::new())), LogicalType::Double => Ok(Box::new(Float64Builder::new())), LogicalType::Varchar => Ok(Box::new(StringBuilder::new())), + LogicalType::Date => Ok(Box::new(Date32Builder::new())), + LogicalType::Interval(IntervalUnit::DayTime) => { + Ok(Box::new(IntervalDayTimeBuilder::new())) + } + LogicalType::Interval(IntervalUnit::YearMonth) => { + Ok(Box::new(IntervalYearMonthBuilder::new())) + } + LogicalType::Interval(IntervalUnit::MonthDayNano) => { + Ok(Box::new(IntervalMonthDayNanoBuilder::new())) + } } } @@ -376,6 +430,21 @@ impl ScalarValue { .downcast_mut::() .unwrap() .append_option(*v), + ScalarValue::Date32(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + ScalarValue::IntervalYearMonth(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), + ScalarValue::IntervalDayTime(v) => builder + .as_any_mut() + .downcast_mut::() + .unwrap() + .append_option(*v), } Ok(()) } @@ -395,6 +464,9 @@ impl ScalarValue { ScalarValue::Float64(_) => DataType::Float64, ScalarValue::Utf8(_) => DataType::Utf8, ScalarValue::Null => DataType::Null, + ScalarValue::Date32(_) => DataType::Date32, + ScalarValue::IntervalYearMonth(_) => DataType::Interval(IntervalUnit::YearMonth), + ScalarValue::IntervalDayTime(_) => DataType::Interval(IntervalUnit::DayTime), } } } @@ -479,6 +551,9 @@ impl fmt::Display for ScalarValue { ScalarValue::UInt64(e) => format_option!(f, e)?, ScalarValue::Utf8(e) => format_option!(f, e)?, ScalarValue::Null => write!(f, "NULL")?, + ScalarValue::Date32(e) => format_option!(f, e)?, + ScalarValue::IntervalDayTime(e) => format_option!(f, e)?, + ScalarValue::IntervalYearMonth(e) => format_option!(f, e)?, }; Ok(()) } @@ -501,6 +576,9 @@ impl fmt::Debug for ScalarValue { ScalarValue::Utf8(None) => write!(f, "Utf8({})", self), ScalarValue::Utf8(Some(_)) => write!(f, "Utf8(\"{}\")", self), ScalarValue::Null => write!(f, "NULL"), + ScalarValue::Date32(_) => write!(f, "Date32({})", self), + ScalarValue::IntervalYearMonth(_) => write!(f, "IntervalYearMonth({})", self), + ScalarValue::IntervalDayTime(_) => write!(f, "IntervalDayTime({})", self), } } } diff --git a/tests/slt/time.slt b/tests/slt/time.slt new file mode 100644 index 0000000..8e83fbe --- /dev/null +++ b/tests/slt/time.slt @@ -0,0 +1,35 @@ +onlyif sqlrs_v2 +statement ok +create table t5(v1 date); +insert into t5 values ('2021-01-02'), ('2021-01-03'); + +onlyif sqlrs_v2 +query I +select v1 + interval '1' day from t5; +---- +2021-01-03 +2021-01-04 + +onlyif sqlrs_v2 +query I +select interval '1' year + date '1998-12-01'; +---- +1999-12-01 + +onlyif sqlrs_v2 +query I +select interval '1' month + date '1998-12-01'; +---- +1999-01-01 + +onlyif sqlrs_v2 +query I +select date '1998-12-01' - interval '1' month; +---- +1998-11-01 + +onlyif sqlrs_v2 +query I +select date '1998-12-01' - interval '1' day; +---- +1998-11-29