diff --git a/datafusion/expr/src/binary_rule.rs b/datafusion/expr/src/binary_rule.rs index 5eeb1867a0170..29c1cf0d4174b 100644 --- a/datafusion/expr/src/binary_rule.rs +++ b/datafusion/expr/src/binary_rule.rs @@ -72,6 +72,12 @@ pub fn binary_operator_data_type( /// Coercion rules for all binary operators. Returns the output type /// of applying `op` to an argument of `lhs_type` and `rhs_type`. +/// +/// TODO this function is trying to serve two purposes at once; it determines the result type +/// of the binary operation and also determines how the inputs can be coerced but this +/// results in inconsistencies in some cases (particularly around date + interval) +/// +/// Tracking issue is https://github.com/apache/arrow-datafusion/issues/3419 pub fn coerce_types( lhs_type: &DataType, op: &Operator, @@ -516,6 +522,8 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option Some(Date64), + (Date32, Date64) => Some(Date64), (Utf8, Date32) => Some(Date32), (Date32, Utf8) => Some(Date32), (Utf8, Date64) => Some(Date64), diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs index 42c081af3dfe1..df0d3681177a1 100644 --- a/datafusion/optimizer/src/type_coercion.rs +++ b/datafusion/optimizer/src/type_coercion.rs @@ -18,15 +18,15 @@ //! 
Optimizer rule for type validation and coercion use crate::{OptimizerConfig, OptimizerRule}; -use datafusion_common::{DFSchema, DFSchemaRef, Result}; -use datafusion_expr::binary_rule::coerce_types; +use arrow::datatypes::DataType; +use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result}; +use datafusion_expr::binary_rule::{coerce_types, comparison_coercion}; use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; -use datafusion_expr::logical_plan::builder::build_join_schema; -use datafusion_expr::logical_plan::JoinType; use datafusion_expr::type_coercion::data_types; use datafusion_expr::utils::from_plan; use datafusion_expr::{Expr, LogicalPlan}; use datafusion_expr::{ExprSchemable, Signature}; +use std::sync::Arc; #[derive(Default)] pub struct TypeCoercion {} @@ -54,17 +54,19 @@ impl OptimizerRule for TypeCoercion { .map(|p| self.optimize(p, optimizer_config)) .collect::>>()?; - let schema = match new_inputs.len() { - 1 => new_inputs[0].schema().clone(), - 2 => DFSchemaRef::new(build_join_schema( - new_inputs[0].schema(), - new_inputs[1].schema(), - &JoinType::Inner, - )?), - _ => DFSchemaRef::new(DFSchema::empty()), - }; + // get schema representing all available input fields. 
This is used for data type + // resolution only, so order does not matter here + let schema = new_inputs.iter().map(|input| input.schema()).fold( + DFSchema::empty(), + |mut lhs, rhs| { + lhs.merge(rhs); + lhs + }, + ); - let mut expr_rewrite = TypeCoercionRewriter { schema }; + let mut expr_rewrite = TypeCoercionRewriter { + schema: Arc::new(schema), + }; let new_expr = plan .expressions() @@ -87,14 +89,55 @@ impl ExprRewriter for TypeCoercionRewriter { fn mutate(&mut self, expr: Expr) -> Result { match expr { - Expr::BinaryExpr { left, op, right } => { + Expr::BinaryExpr { + ref left, + op, + ref right, + } => { let left_type = left.get_type(&self.schema)?; let right_type = right.get_type(&self.schema)?; - let coerced_type = coerce_types(&left_type, &op, &right_type)?; - Ok(Expr::BinaryExpr { - left: Box::new(left.cast_to(&coerced_type, &self.schema)?), - op, - right: Box::new(right.cast_to(&coerced_type, &self.schema)?), + match (&left_type, &right_type) { + ( + DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _), + &DataType::Interval(_), + ) => { + // this is a workaround for https://github.com/apache/arrow-datafusion/issues/3419 + Ok(expr.clone()) + } + _ => { + let coerced_type = coerce_types(&left_type, &op, &right_type)?; + Ok(Expr::BinaryExpr { + left: Box::new( + left.clone().cast_to(&coerced_type, &self.schema)?, + ), + op, + right: Box::new( + right.clone().cast_to(&coerced_type, &self.schema)?, + ), + }) + } + } + } + Expr::Between { + expr, + negated, + low, + high, + } => { + let expr_type = expr.get_type(&self.schema)?; + let low_type = low.get_type(&self.schema)?; + let coerced_type = comparison_coercion(&expr_type, &low_type) + .ok_or_else(|| { + DataFusionError::Internal(format!( + "Failed to coerce types {} and {} in BETWEEN expression", + expr_type, low_type + )) + })?; + Ok(Expr::Between { + expr: Box::new(expr.cast_to(&coerced_type, &self.schema)?), + negated, + low: Box::new(low.cast_to(&coerced_type, &self.schema)?), + high: 
Box::new(high.cast_to(&coerced_type, &self.schema)?), }) } Expr::ScalarUDF { fun, args } => { @@ -145,12 +188,12 @@ mod test { use crate::type_coercion::TypeCoercion; use crate::{OptimizerConfig, OptimizerRule}; use arrow::datatypes::DataType; - use datafusion_common::{DFSchema, Result}; + use datafusion_common::{DFSchema, Result, ScalarValue}; use datafusion_expr::{ lit, logical_plan::{EmptyRelation, Projection}, - Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF, - Signature, Volatility, + Expr, LogicalPlan, Operator, ReturnTypeFunction, ScalarFunctionImplementation, + ScalarUDF, Signature, Volatility, }; use std::sync::Arc; @@ -244,6 +287,34 @@ mod test { Ok(()) } + #[test] + fn binary_op_date32_add_interval() -> Result<()> { + //CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640") + let expr = Expr::BinaryExpr { + left: Box::new(Expr::Cast { + expr: Box::new(lit("1998-03-18")), + data_type: DataType::Date32, + }), + op: Operator::Plus, + right: Box::new(Expr::Literal(ScalarValue::IntervalDayTime(Some( + 386547056640, + )))), + }; + let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + })); + let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?); + let rule = TypeCoercion::new(); + let mut config = OptimizerConfig::default(); + let plan = rule.optimize(&plan, &mut config)?; + assert_eq!( + "Projection: CAST(Utf8(\"1998-03-18\") AS Date32) + IntervalDayTime(\"386547056640\")\n EmptyRelation", + &format!("{:?}", plan) + ); + Ok(()) + } + fn empty() -> Arc { Arc::new(LogicalPlan::EmptyRelation(EmptyRelation { produce_one_row: false, diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 55c38689bdfb6..87a0bab68a40a 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -27,6 +27,7 @@ use 
datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys; use datafusion_optimizer::filter_push_down::FilterPushDown; use datafusion_optimizer::limit_push_down::LimitPushDown; use datafusion_optimizer::optimizer::Optimizer; +use datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions; use datafusion_optimizer::projection_push_down::ProjectionPushDown; use datafusion_optimizer::reduce_outer_join::ReduceOuterJoin; use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; @@ -34,6 +35,7 @@ use datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin; use datafusion_optimizer::simplify_expressions::SimplifyExpressions; use datafusion_optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; use datafusion_optimizer::subquery_filter_to_join::SubqueryFilterToJoin; +use datafusion_optimizer::type_coercion::TypeCoercion; use datafusion_optimizer::{OptimizerConfig, OptimizerRule}; use datafusion_sql::planner::{ContextProvider, SqlToRel}; use datafusion_sql::sqlparser::ast::Statement; @@ -56,11 +58,56 @@ fn distribute_by() -> Result<()> { Ok(()) } +#[test] +fn intersect() -> Result<()> { + let sql = "SELECT col_int32, col_utf8 FROM test \ + INTERSECT SELECT col_int32, col_utf8 FROM test \ + INTERSECT SELECT col_int32, col_utf8 FROM test"; + let plan = test_sql(sql)?; + let expected = + "Semi Join: #test.col_int32 = #test.col_int32, #test.col_utf8 = #test.col_utf8\ + \n Distinct:\ + \n Semi Join: #test.col_int32 = #test.col_int32, #test.col_utf8 = #test.col_utf8\ + \n Distinct:\ + \n TableScan: test projection=[col_int32, col_utf8]\ + \n TableScan: test projection=[col_int32, col_utf8]\ + \n TableScan: test projection=[col_int32, col_utf8]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + +#[test] +fn between_date32_plus_interval() -> Result<()> { + let sql = "SELECT count(1) FROM test \ + WHERE col_date32 between '1998-03-18' AND cast('1998-03-18' as date) + INTERVAL '90 days'"; + 
let plan = test_sql(sql)?; + let expected = + "Projection: #COUNT(UInt8(1))\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ + \n Filter: #test.col_date32 >= CAST(Utf8(\"1998-03-18\") AS Date32) AND #test.col_date32 <= Date32(\"10393\")\ + \n TableScan: test projection=[col_date32]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + +#[test] +fn between_date64_plus_interval() -> Result<()> { + let sql = "SELECT count(1) FROM test \ + WHERE col_date64 between '1998-03-18' AND cast('1998-03-18' as date) + INTERVAL '90 days'"; + let plan = test_sql(sql)?; + let expected = + "Projection: #COUNT(UInt8(1))\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\ + \n Filter: #test.col_date64 >= CAST(Utf8(\"1998-03-18\") AS Date64) AND #test.col_date64 <= CAST(Date32(\"10393\") AS Date64)\ + \n TableScan: test projection=[col_date64]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + fn test_sql(sql: &str) -> Result { let rules: Vec> = vec![ // Simplify expressions first to maximize the chance // of applying other optimizations Arc::new(SimplifyExpressions::new()), + Arc::new(PreCastLitInComparisonExpressions::new()), Arc::new(DecorrelateWhereExists::new()), Arc::new(DecorrelateWhereIn::new()), Arc::new(ScalarSubqueryToJoin::new()), @@ -73,6 +120,7 @@ fn test_sql(sql: &str) -> Result { Arc::new(FilterNullJoinKeys::default()), Arc::new(ReduceOuterJoin::new()), Arc::new(FilterPushDown::new()), + Arc::new(TypeCoercion::new()), Arc::new(LimitPushDown::new()), Arc::new(SingleDistinctToGroupBy::new()), ]; @@ -107,6 +155,8 @@ impl ContextProvider for MySchemaProvider { vec![ Field::new("col_int32", DataType::Int32, true), Field::new("col_utf8", DataType::Utf8, true), + Field::new("col_date32", DataType::Date32, true), + Field::new("col_date64", DataType::Date64, true), ], HashMap::new(), );