diff --git a/datafusion/common/src/scalar.rs b/datafusion/common/src/scalar.rs index 553dc5eae5709..e3442d749a498 100644 --- a/datafusion/common/src/scalar.rs +++ b/datafusion/common/src/scalar.rs @@ -37,6 +37,8 @@ use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc}; /// Represents a dynamically typed, nullable single value. /// This is the single-valued counter-part of arrow’s `Array`. +/// https://arrow.apache.org/docs/python/api/datatypes.html +/// https://github.com/apache/arrow/blob/master/format/Schema.fbs#L354-L375 #[derive(Clone)] pub enum ScalarValue { /// represents `DataType::Null` (castable to/from any other type) @@ -75,9 +77,9 @@ pub enum ScalarValue { LargeBinary(Option>), /// list of nested ScalarValue List(Option>, Box), - /// Date stored as a signed 32bit int + /// Date stored as a signed 32bit int days since UNIX epoch 1970-01-01 Date32(Option), - /// Date stored as a signed 64bit int + /// Date stored as a signed 64bit int milliseconds since UNIX epoch 1970-01-01 Date64(Option), /// Timestamp Second TimestampSecond(Option, Option), @@ -87,11 +89,14 @@ pub enum ScalarValue { TimestampMicrosecond(Option, Option), /// Timestamp Nanoseconds TimestampNanosecond(Option, Option), - /// Interval with YearMonth unit + /// Number of elapsed whole months IntervalYearMonth(Option), - /// Interval with DayTime unit + /// Number of elapsed days and milliseconds (no leap seconds) + /// stored as 2 contiguous 32-bit signed integers IntervalDayTime(Option), - /// Interval with MonthDayNano unit + /// A triple of the number of elapsed months, days, and nanoseconds. + /// Months and days are encoded as 32-bit signed integers. + /// Nanoseconds is encoded as a 64-bit signed integer (no leap seconds). 
IntervalMonthDayNano(Option), /// struct of nested ScalarValue Struct(Option>, Box>), diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 458b915260a61..393e8aef34ec1 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -99,6 +99,7 @@ use chrono::{DateTime, Utc}; use datafusion_common::ScalarValue; use datafusion_expr::TableSource; use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys; +use datafusion_optimizer::subquery_decorrelate::SubqueryDecorrelate; use datafusion_sql::{ parser::DFParser, planner::{ContextProvider, SqlToRel}, @@ -1239,6 +1240,7 @@ impl SessionState { // of applying other optimizations Arc::new(SimplifyExpressions::new()), Arc::new(SubqueryFilterToJoin::new()), + Arc::new(SubqueryDecorrelate::new()), Arc::new(EliminateFilter::new()), Arc::new(CommonSubexprEliminate::new()), Arc::new(EliminateLimit::new()), diff --git a/datafusion/core/tests/sql/mod.rs b/datafusion/core/tests/sql/mod.rs index 0e3e08873cce4..fd5a189a9f660 100644 --- a/datafusion/core/tests/sql/mod.rs +++ b/datafusion/core/tests/sql/mod.rs @@ -108,6 +108,7 @@ mod explain; mod idenfifers; pub mod information_schema; mod partitioned_csv; +mod subqueries; #[cfg(feature = "unicode_expressions")] pub mod unicode; @@ -483,7 +484,37 @@ fn get_tpch_table_schema(table: &str) -> Schema { Field::new("n_comment", DataType::Utf8, false), ]), - _ => unimplemented!(), + "supplier" => Schema::new(vec![ + Field::new("s_suppkey", DataType::Int64, false), + Field::new("s_name", DataType::Utf8, false), + Field::new("s_address", DataType::Utf8, false), + Field::new("s_nationkey", DataType::Int64, false), + Field::new("s_phone", DataType::Utf8, false), + Field::new("s_acctbal", DataType::Float64, false), + Field::new("s_comment", DataType::Utf8, false), + ]), + + "partsupp" => Schema::new(vec![ + Field::new("ps_partkey", DataType::Int64, false), + Field::new("ps_suppkey", DataType::Int64, 
false), + Field::new("ps_availqty", DataType::Int32, false), + Field::new("ps_supplycost", DataType::Float64, false), + Field::new("ps_comment", DataType::Utf8, false), + ]), + + "part" => Schema::new(vec![ + Field::new("p_partkey", DataType::Int64, false), + Field::new("p_name", DataType::Utf8, false), + Field::new("p_mfgr", DataType::Utf8, false), + Field::new("p_brand", DataType::Utf8, false), + Field::new("p_type", DataType::Utf8, false), + Field::new("p_size", DataType::Int32, false), + Field::new("p_container", DataType::Utf8, false), + Field::new("p_retailprice", DataType::Float64, false), + Field::new("p_comment", DataType::Utf8, false), + ]), + + _ => unimplemented!("Table: {}", table), } } diff --git a/datafusion/core/tests/sql/subqueries.rs b/datafusion/core/tests/sql/subqueries.rs new file mode 100644 index 0000000000000..9ff1d34fc2e9e --- /dev/null +++ b/datafusion/core/tests/sql/subqueries.rs @@ -0,0 +1,67 @@ +use super::*; +use crate::sql::execute_to_batches; +use datafusion::assert_batches_eq; +use datafusion::prelude::SessionContext; + +#[tokio::test] +async fn tpch_q4_correlated() -> Result<()> { + let ctx = SessionContext::new(); + register_tpch_csv(&ctx, "orders").await?; + register_tpch_csv(&ctx, "lineitem").await?; + + /* + #orders.o_orderpriority ASC NULLS LAST + Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count + Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] + Filter: EXISTS ( -- plan + Subquery: Projection: * -- proj + Filter: #lineitem.l_orderkey = #orders.o_orderkey -- filter + TableScan: lineitem projection=None -- filter.input + ) + TableScan: orders projection=None -- plan.inputs + */ + let sql = r#" + select o_orderpriority, count(*) as order_count + from orders + where exists ( + select * from lineitem where l_orderkey = o_orderkey and l_commitdate < l_receiptdate) + group by o_orderpriority + order by o_orderpriority; + "#; + + // assert plan + let plan = ctx + .create_logical_plan(sql) 
+ .map_err(|e| format!("{:?} at {}", e, "error")) + .unwrap(); + let plan = ctx + .optimize(&plan) + .map_err(|e| format!("{:?} at {}", e, "error")) + .unwrap(); + let actual = format!("{}", plan.display_indent()); + let expected = r#"Sort: #orders.o_orderpriority ASC NULLS LAST + Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count + Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] + Inner Join: #orders.o_orderkey = #lineitem.l_orderkey + TableScan: orders projection=[o_orderkey, o_orderpriority] + Projection: #lineitem.l_orderkey + Aggregate: groupBy=[[#lineitem.l_orderkey]], aggr=[[]] + Filter: #lineitem.l_commitdate < #lineitem.l_receiptdate + TableScan: lineitem projection=[l_orderkey, l_commitdate, l_receiptdate], partial_filters=[#lineitem.l_commitdate < #lineitem.l_receiptdate]"# + .to_string(); + assert_eq!(actual, expected); + + // assert data + let results = execute_to_batches(&ctx, sql).await; + let expected = vec![ + "+-----------------+-------------+", + "| o_orderpriority | order_count |", + "+-----------------+-------------+", + "| 1-URGENT | 1 |", + "| 5-LOW | 1 |", + "+-----------------+-------------+", + ]; + assert_batches_eq!(expected, &results); + + Ok(()) +} diff --git a/datafusion/core/tests/sql/timestamp.rs b/datafusion/core/tests/sql/timestamp.rs index 1e475fb175bd7..9acc3f3cbe28f 100644 --- a/datafusion/core/tests/sql/timestamp.rs +++ b/datafusion/core/tests/sql/timestamp.rs @@ -814,3 +814,123 @@ async fn group_by_timestamp_millis() -> Result<()> { assert_batches_eq!(expected, &actual); Ok(()) } + +#[tokio::test] +async fn interval_year() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "select date '1994-01-01' + interval '1' year as date;"; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+------------+", + "| date |", + "+------------+", + "| 1995-01-01 |", + "+------------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} + 
+#[tokio::test] +async fn add_interval_month() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "select date '1994-01-31' + interval '1' month as date;"; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+------------+", + "| date |", + "+------------+", + "| 1994-02-28 |", + "+------------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn sub_interval_month() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "select date '1994-03-31' - interval '1' month as date;"; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+------------+", + "| date |", + "+------------+", + "| 1994-02-28 |", + "+------------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn sub_month_wrap() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "select date '1994-01-15' - interval '1' month as date;"; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+------------+", + "| date |", + "+------------+", + "| 1993-12-15 |", + "+------------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn add_interval_day() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "select date '1994-01-15' + interval '1' day as date;"; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+------------+", + "| date |", + "+------------+", + "| 1994-01-16 |", + "+------------+", + ]; + + assert_batches_eq!(expected, &results); + + Ok(()) +} + +#[tokio::test] +async fn sub_interval_day() -> Result<()> { + let ctx = SessionContext::new(); + + let sql = "select date '1994-01-01' - interval '1' day as date;"; + let results = execute_to_batches(&ctx, sql).await; + + let expected = vec![ + "+------------+", + "| date |", + "+------------+", + "| 1993-12-31 |", + "+------------+", + ]; + + assert_batches_eq!(expected, 
&results); + + Ok(()) +} diff --git a/datafusion/core/tests/tpch-csv/part.csv b/datafusion/core/tests/tpch-csv/part.csv new file mode 100644 index 0000000000000..f790f07bc2fe4 --- /dev/null +++ b/datafusion/core/tests/tpch-csv/part.csv @@ -0,0 +1,2 @@ +p_partkey,p_name,p_mfgr,p_brand,p_type,p_size,p_container,p_retailprice,p_comment +1,goldenrod lavender spring chocolate lace,Manufacturer#1,Brand#13,PROMO BURNISHED COPPER,7,JUMBO PKG,901.00,ly. slyly ironi diff --git a/datafusion/core/tests/tpch-csv/partsupp.csv b/datafusion/core/tests/tpch-csv/partsupp.csv new file mode 100644 index 0000000000000..d7db83d030429 --- /dev/null +++ b/datafusion/core/tests/tpch-csv/partsupp.csv @@ -0,0 +1,2 @@ +ps_partkey,ps_suppkey,ps_availqty,ps_supplycost,ps_comment +67310,7311,100,993.49,ven ideas. quickly even packages print. pending multipliers must have to are fluff diff --git a/datafusion/core/tests/tpch-csv/region.csv b/datafusion/core/tests/tpch-csv/region.csv new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/datafusion/core/tests/tpch-csv/supplier.csv b/datafusion/core/tests/tpch-csv/supplier.csv new file mode 100644 index 0000000000000..768096c7ffa68 --- /dev/null +++ b/datafusion/core/tests/tpch-csv/supplier.csv @@ -0,0 +1,2 @@ +s_suppkey,s_name,s_address,s_nationkey,s_phone,s_acctbal,s_comment +1,Supplier#000000001, N kD4on9OM Ipw3,gf0JBoQDd7tgrzrddZ,17,27-918-335-1736,5755.94,each slyly above the careful diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index 4bd3868793fe5..51b41a1c52915 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -44,4 +44,5 @@ datafusion-common = { path = "../common", version = "9.0.0" } datafusion-expr = { path = "../expr", version = "9.0.0" } datafusion-physical-expr = { path = "../physical-expr", version = "9.0.0" } hashbrown = { version = "0.12", features = ["raw"] } +itertools = "0.10" log = "^0.4" diff --git a/datafusion/optimizer/src/lib.rs 
b/datafusion/optimizer/src/lib.rs index a6b7cfcbb8fbb..9afe6af0ca130 100644 --- a/datafusion/optimizer/src/lib.rs +++ b/datafusion/optimizer/src/lib.rs @@ -27,6 +27,7 @@ pub mod projection_push_down; pub mod reduce_outer_join; pub mod simplify_expressions; pub mod single_distinct_to_groupby; +pub mod subquery_decorrelate; pub mod subquery_filter_to_join; pub mod utils; diff --git a/datafusion/optimizer/src/simplify_expressions.rs b/datafusion/optimizer/src/simplify_expressions.rs index aa089a00a6a91..da4dfa9eece1c 100644 --- a/datafusion/optimizer/src/simplify_expressions.rs +++ b/datafusion/optimizer/src/simplify_expressions.rs @@ -1951,7 +1951,7 @@ mod tests { let date_plus_interval_expr = to_timestamp_expr(ts_string) .cast_to(&DataType::Date32, schema) .unwrap() - + Expr::Literal(ScalarValue::IntervalDayTime(Some(123))); + + Expr::Literal(ScalarValue::IntervalDayTime(Some(123i64 << 32))); let plan = LogicalPlanBuilder::from(table_scan.clone()) .project(vec![date_plus_interval_expr]) @@ -1963,10 +1963,10 @@ mod tests { // Note that constant folder runs and folds the entire // expression down to a single constant (true) - let expected = "Projection: Date32(\"18636\") AS CAST(totimestamp(Utf8(\"2020-09-08T12:05:00+00:00\")) AS Date32) + IntervalDayTime(\"123\")\ - \n TableScan: test"; + let expected = r#"Projection: Date32("18636") AS CAST(totimestamp(Utf8("2020-09-08T12:05:00+00:00")) AS Date32) + IntervalDayTime("528280977408") + TableScan: test"#; let actual = get_optimized_plan_formatted(&plan, &time); - assert_eq!(expected, actual); + assert_eq!(actual, expected); } } diff --git a/datafusion/optimizer/src/subquery_decorrelate.rs b/datafusion/optimizer/src/subquery_decorrelate.rs new file mode 100644 index 0000000000000..cd19ea0bc7daf --- /dev/null +++ b/datafusion/optimizer/src/subquery_decorrelate.rs @@ -0,0 +1,200 @@ +use crate::{utils, OptimizerConfig, OptimizerRule}; +use datafusion_common::{Column}; +use datafusion_expr::logical_plan::{Filter, JoinType, 
Subquery}; +use datafusion_expr::{combine_filters, Expr, LogicalPlan, LogicalPlanBuilder, Operator}; +use hashbrown::HashSet; +use itertools::{Either, Itertools}; +use std::sync::Arc; + +/// Optimizer rule for rewriting subquery filters to joins +#[derive(Default)] +pub struct SubqueryDecorrelate {} + +impl SubqueryDecorrelate { + #[allow(missing_docs)] + pub fn new() -> Self { + Self {} + } +} + +impl OptimizerRule for SubqueryDecorrelate { + fn optimize( + &self, + plan: &LogicalPlan, + optimizer_config: &OptimizerConfig, + ) -> datafusion_common::Result { + match plan { + LogicalPlan::Filter(Filter { predicate, input }) => { + let mut filters = vec![]; + utils::split_conjunction(predicate, &mut filters); + + let (subqueries, others): (Vec<_>, Vec<_>) = filters.iter() + .partition_map(|f| { + match f { + Expr::Exists { subquery, negated } => { + if *negated { // TODO: not exists + Either::Right((*f).clone()) + } else { + Either::Left(subquery.clone()) + } + } + _ => Either::Right((*f).clone()) + } + }); + if subqueries.len() != 1 { + return Ok(plan.clone()); // TODO: >1 subquery + } + let subquery = match subqueries.get(0) { + Some(q) => q, + _ => return Ok(plan.clone()) + }; + + optimize_exists(plan, subquery, input, &others) + } + _ => { + // Apply the optimization to all inputs of the plan + utils::optimize_children(self, plan, optimizer_config) + } + } + } + + fn name(&self) -> &str { + "subquery_decorrelate" + } +} + +/* +#orders.o_orderpriority ASC NULLS LAST + Projection: #orders.o_orderpriority, #COUNT(UInt8(1)) AS order_count + Aggregate: groupBy=[[#orders.o_orderpriority]], aggr=[[COUNT(UInt8(1))]] + Filter: EXISTS ( -- plan + Subquery: Projection: * -- proj + Filter: #lineitem.l_orderkey = #orders.o_orderkey -- filter + TableScan: lineitem projection=None -- filter.input + ) + TableScan: orders projection=None -- plan.inputs + */ + +/// Takes a query like: +/// +/// select c.id from customers c where exists (select * from orders o where o.c_id = c.id) 
+/// +/// and optimizes it into: +/// +/// select c.id from customers c +/// inner join (select o.c_id from orders o group by o.c_id) o on o.c_id = c.id +fn optimize_exists( + plan: &LogicalPlan, + subquery: &Subquery, + input: &Arc, + outer_others: &[Expr], +) -> datafusion_common::Result { + // Only operate if there is one input + let sub_inputs = subquery.subquery.inputs(); + if sub_inputs.len() != 1 { + return Ok(plan.clone()); + } + let sub_input = if let Some(i) = sub_inputs.get(0) { + i + } else { + return Ok(plan.clone()); + }; + + // Only operate on subqueries that are trying to filter on an expression from an outer query + let filter = if let LogicalPlan::Filter(f) = sub_input { + f + } else { + return Ok(plan.clone()); + }; + + // split into filters + let mut filters = vec![]; + utils::split_conjunction(&filter.predicate, &mut filters); + + // get names of fields TODO: Must fully qualify these! + let fields: HashSet<_> = sub_input + .schema() + .fields() + .iter() + .map(|f| f.name()) + .collect(); + + // Grab column names to join on + let (cols, others) = find_join_exprs(filters, &fields); + if cols.is_empty() { + return Ok(plan.clone()); // no joins found + } + + // Only operate if one column is present and the other closed upon from outside scope + let l_col: Vec<_> = cols + .iter() + .map(|it| &it.0) + .map(|it| Column::from_qualified_name(it.as_str())) + .collect(); + let r_col: Vec<_> = cols + .iter() + .map(|it| &it.1) + .map(|it| Column::from_qualified_name(it.as_str())) + .collect(); + let expr: Vec<_> = r_col.iter().map(|it| Expr::Column(it.clone())).collect(); + let aggr_expr: Vec = vec![]; + let join_keys = (l_col, r_col); + let right = LogicalPlanBuilder::from((*filter.input).clone()); + let right = if let Some(expr) = combine_filters(&others) { + right.filter(expr)? + } else { + right + }; + let right = right + .aggregate(expr.clone(), aggr_expr)? + .project(expr)? 
+ .build()?; + let new_plan = LogicalPlanBuilder::from((**input).clone()) + .join(&right, JoinType::Inner, join_keys, None)?; + let new_plan = if let Some(expr) = combine_filters(outer_others) { + new_plan.filter(expr)? + } else { + new_plan + }; + new_plan.build() +} + +fn find_join_exprs( + filters: Vec<&Expr>, + fields: &HashSet<&String>, +) -> (Vec<(String, String)>, Vec) { + let (joins, others): (Vec<_>, Vec<_>) = filters.iter().partition_map(|filter| { + let (left, op, right) = match filter { + Expr::BinaryExpr { left, op, right } => (*left.clone(), *op, *right.clone()), + _ => return Either::Right((*filter).clone()), + }; + match op { + Operator::Eq => {} + _ => return Either::Right((*filter).clone()), + } + let left = match left { + Expr::Column(c) => c, + _ => return Either::Right((*filter).clone()), + }; + let right = match right { + Expr::Column(c) => c, + _ => return Either::Right((*filter).clone()), + }; + if fields.contains(&left.name) && fields.contains(&right.name) { + return Either::Right((*filter).clone()); // Need one of each + } + if !fields.contains(&left.name) && !fields.contains(&right.name) { + return Either::Right((*filter).clone()); // Need one of each + } + + let sorted = if fields.contains(&left.name) { + (right.name, left.name) + } else { + (left.name, right.name) + }; + + Either::Left(sorted) + }); + + (joins, others) +} diff --git a/datafusion/physical-expr/src/expressions/datetime.rs b/datafusion/physical-expr/src/expressions/datetime.rs index 3d84e79f2cc97..6b4b2e571fe1b 100644 --- a/datafusion/physical-expr/src/expressions/datetime.rs +++ b/datafusion/physical-expr/src/expressions/datetime.rs @@ -18,11 +18,14 @@ use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; +use chrono::{Datelike, Duration, NaiveDate}; use datafusion_common::Result; use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, Operator}; use std::any::Any; +use std::cmp::min; 
use std::fmt::{Display, Formatter}; +use std::ops::{Add, Sub}; use std::sync::Arc; /// Perform DATE +/ INTERVAL math @@ -74,88 +77,121 @@ impl PhysicalExpr for DateIntervalExpr { self } - fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + fn data_type(&self, input_schema: &Schema) -> Result { self.lhs.data_type(input_schema) } - fn nullable(&self, input_schema: &Schema) -> datafusion_common::Result { + fn nullable(&self, input_schema: &Schema) -> Result { self.lhs.nullable(input_schema) } - fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + fn evaluate(&self, batch: &RecordBatch) -> Result { let dates = self.lhs.evaluate(batch)?; let intervals = self.rhs.evaluate(batch)?; - let interval = match intervals { - ColumnarValue::Scalar(interval) => match interval { - ScalarValue::IntervalDayTime(Some(interval)) => interval as i32, - ScalarValue::IntervalYearMonth(Some(_)) => { - return Err(DataFusionError::Execution( - "DateIntervalExpr does not support IntervalYearMonth".to_string(), - )) - } - ScalarValue::IntervalMonthDayNano(Some(_)) => { - return Err(DataFusionError::Execution( - "DateIntervalExpr does not support IntervalMonthDayNano" - .to_string(), - )) - } - other => { - return Err(DataFusionError::Execution(format!( - "DateIntervalExpr does not support non-interval type {:?}", - other - ))) - } - }, - _ => { - return Err(DataFusionError::Execution( - "Columnar execution is not yet supported for DateIntervalExpr" - .to_string(), - )) - } + // Unwrap days since epoch + let operand = match dates { + ColumnarValue::Scalar(scalar) => scalar, + _ => Err(DataFusionError::Execution( + "Columnar execution is not yet supported for DateIntervalExpr" + .to_string(), + ))?, }; - match dates { - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Date32(Some(date)) => match &self.op { - Operator::Plus => Ok(ColumnarValue::Scalar(ScalarValue::Date32( - Some(date + interval), - ))), - Operator::Minus => 
Ok(ColumnarValue::Scalar(ScalarValue::Date32( - Some(date - interval), - ))), - _ => { - // this should be unreachable because we check the operators in `try_new` - Err(DataFusionError::Execution( - "Invalid operator for DateIntervalExpr".to_string(), - )) - } - }, - ScalarValue::Date64(Some(date)) => match &self.op { - Operator::Plus => Ok(ColumnarValue::Scalar(ScalarValue::Date64( - Some(date + interval as i64), - ))), - Operator::Minus => Ok(ColumnarValue::Scalar(ScalarValue::Date64( - Some(date - interval as i64), - ))), - _ => { - // this should be unreachable because we check the operators in `try_new` - Err(DataFusionError::Execution( - "Invalid operator for DateIntervalExpr".to_string(), - )) - } - }, - _ => { - // this should be unreachable because we check the types in `try_new` - Err(DataFusionError::Execution( - "Invalid lhs type for DateIntervalExpr".to_string(), - )) - } - }, + // Convert to NaiveDate + let epoch = NaiveDate::from_ymd(1970, 1, 1); + let prior = match operand { + ScalarValue::Date32(Some(d)) => epoch.add(Duration::days(d as i64)), + ScalarValue::Date64(Some(ms)) => epoch.add(Duration::milliseconds(ms)), + _ => Err(DataFusionError::Execution(format!( + "Invalid lhs type for DateIntervalExpr: {:?}", + operand + )))?, + }; + + // Unwrap interval to add + let scalar = match &intervals { + ColumnarValue::Scalar(interval) => interval, _ => Err(DataFusionError::Execution( "Columnar execution is not yet supported for DateIntervalExpr" .to_string(), - )), - } + ))?, + }; + + // Invert sign for subtraction + let sign = match &self.op { + Operator::Plus => 1, + Operator::Minus => -1, + _ => { + // this should be unreachable because we check the operators in `try_new` + Err(DataFusionError::Execution( + "Invalid operator for DateIntervalExpr".to_string(), + ))? 
+ } + }; + + // Do math + let posterior = match scalar { + ScalarValue::IntervalDayTime(Some(i)) => add_day_time(prior, *i, sign), + ScalarValue::IntervalYearMonth(Some(i)) => add_months(prior, *i * sign), + ScalarValue::IntervalMonthDayNano(Some(i)) => add_m_d_nano(prior, *i, sign), + other => Err(DataFusionError::Execution(format!( + "DateIntervalExpr does not support non-interval type {:?}", + other + )))?, + }; + + // convert back + let res = match operand { + ScalarValue::Date32(Some(_)) => { + let days = posterior.sub(epoch).num_days() as i32; + ColumnarValue::Scalar(ScalarValue::Date32(Some(days))) + } + ScalarValue::Date64(Some(_)) => { + let ms = posterior.sub(epoch).num_milliseconds(); + ColumnarValue::Scalar(ScalarValue::Date64(Some(ms))) + } + _ => Err(DataFusionError::Execution(format!( + "Invalid lhs type for DateIntervalExpr: {:?}", + operand + )))?, + }; + Ok(res) } } + +fn add_m_d_nano(prior: NaiveDate, interval: i128, sign: i32) -> NaiveDate { + let interval = interval as u128; + let months = (interval >> 96) as i32 * sign; + let days = (interval >> 64) as i32 * sign; + let nanos = interval as i64 * sign as i64; + let a = add_months(prior, months); + let b = a.add(Duration::days(days as i64)); + b.add(Duration::nanoseconds(nanos)) +} + +fn add_day_time(prior: NaiveDate, interval: i64, sign: i32) -> NaiveDate { + let interval = interval as u64; + let days = (interval >> 32) as i32 * sign; + let ms = interval as i32 * sign; + let intermediate = prior.add(Duration::days(days as i64)); + intermediate.add(Duration::milliseconds(ms as i64)) +} + +fn add_months(prior: NaiveDate, interval: i32) -> NaiveDate { + let target = chrono_add_months(prior, interval); + let target_plus = chrono_add_months(target, 1); + let last_day = target_plus.sub(chrono::Duration::days(1)); + let day = min(prior.day(), last_day.day()); + NaiveDate::from_ymd(target.year(), target.month(), day) +} + +fn chrono_add_months(dt: NaiveDate, delta: i32) -> NaiveDate { + let ay =
dt.year(); + let am = dt.month() as i32 - 1; // zero-based for modulo operations + let bm = am + delta as i32; + let by = ay + bm.div_euclid(12); // floor division: exact negative multiples of 12 (e.g. January minus 12 months) must move back one year, not two + let cm = bm % 12; + let dm = if cm < 0 { cm + 12 } else { cm }; + NaiveDate::from_ymd(by, dm as u32 + 1, 1) +}