diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 40d40692e5931..1de458a9838f1 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -1266,6 +1266,54 @@ impl Expr { Ok(Transformed::Yes(expr)) }) } + + /// Returns true if some of this `exprs` subexpressions may not be evaluated + /// and thus any side effects (like divide by zero) may not be encountered + pub fn short_circuits(&self) -> bool { + match self { + Expr::ScalarFunction(ScalarFunction { func_def, .. }) => { + matches!(func_def, ScalarFunctionDefinition::BuiltIn(fun) if *fun == BuiltinScalarFunction::Coalesce) + } + Expr::BinaryExpr(BinaryExpr { op, .. }) => { + matches!(op, Operator::And | Operator::Or) + } + Expr::Case { .. } => true, + // Use explicit pattern match instead of a default + // implementation, so that in the future if someone adds + // new Expr types, they will check here as well + Expr::AggregateFunction(..) + | Expr::Alias(..) + | Expr::Between(..) + | Expr::Cast(..) + | Expr::Column(..) + | Expr::Exists(..) + | Expr::GetIndexedField(..) + | Expr::GroupingSet(..) + | Expr::InList(..) + | Expr::InSubquery(..) + | Expr::IsFalse(..) + | Expr::IsNotFalse(..) + | Expr::IsNotNull(..) + | Expr::IsNotTrue(..) + | Expr::IsNotUnknown(..) + | Expr::IsNull(..) + | Expr::IsTrue(..) + | Expr::IsUnknown(..) + | Expr::Like(..) + | Expr::ScalarSubquery(..) + | Expr::ScalarVariable(_, _) + | Expr::SimilarTo(..) + | Expr::Not(..) + | Expr::Negative(..) + | Expr::OuterReferenceColumn(_, _) + | Expr::TryCast(..) + | Expr::Wildcard { .. } + | Expr::WindowFunction(..) + | Expr::Literal(..) + | Expr::Sort(..) + | Expr::Placeholder(..) => false, + } + } } // modifies expr if it is a placeholder with datatype of right diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index f29c7406acc99..fe71171ce5455 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -616,8 +616,8 @@ impl TreeNodeVisitor for ExprIdentifierVisitor<'_> { fn pre_visit(&mut self, expr: &Expr) -> Result { // related to https://github.com/apache/arrow-datafusion/issues/8814 - // If the expr contain volatile expression or is a case expression, skip it. - if matches!(expr, Expr::Case(..)) || is_volatile_expression(expr)? { + // If the expr contain volatile expression or is a short-circuit expression, skip it. + if expr.short_circuits() || is_volatile_expression(expr)? { return Ok(VisitRecursion::Skip); } self.visit_stack @@ -696,7 +696,13 @@ struct CommonSubexprRewriter<'a> { impl TreeNodeRewriter for CommonSubexprRewriter<'_> { type N = Expr; - fn pre_visit(&mut self, _: &Expr) -> Result { + fn pre_visit(&mut self, expr: &Expr) -> Result { + // The `CommonSubexprRewriter` relies on `ExprIdentifierVisitor` to generate + // the `id_array`, which records the expr's identifier used to rewrite expr. So if we + // skip an expr in `ExprIdentifierVisitor`, we should skip it here, too. + if expr.short_circuits() || is_volatile_expression(expr)? { + return Ok(RewriteRecursion::Stop); + } if self.curr_index >= self.id_array.len() || self.max_series_number > self.id_array[self.curr_index].0 { @@ -1249,12 +1255,11 @@ mod test { let table_scan = test_table_scan()?; let plan = LogicalPlanBuilder::from(table_scan) - .filter(lit(1).gt(col("a")).and(lit(1).gt(col("a"))))? + .filter((lit(1) + col("a") - lit(10)).gt(lit(1) + col("a")))? .build()?; let expected = "Projection: test.a, test.b, test.c\ - \n Filter: Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a AND Int32(1) > test.atest.aInt32(1) AS Int32(1) > test.a\ - \n Projection: Int32(1) > test.a AS Int32(1) > test.atest.aInt32(1), test.a, test.b, test.c\ + \n Filter: Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a - Int32(10) > Int32(1) + test.atest.aInt32(1) AS Int32(1) + test.a\n Projection: Int32(1) + test.a AS Int32(1) + test.atest.aInt32(1), test.a, test.b, test.c\ \n TableScan: test"; assert_optimized_plan_eq(expected, &plan); diff --git a/datafusion/sqllogictest/test_files/select.slt b/datafusion/sqllogictest/test_files/select.slt index ca48c07b09146..9ffddc6e2d465 100644 --- a/datafusion/sqllogictest/test_files/select.slt +++ b/datafusion/sqllogictest/test_files/select.slt @@ -1129,5 +1129,49 @@ FROM t AS A, (SELECT * FROM t WHERE x = 0) AS B; 0 0 0 0 +# Expressions that short circuit should not be refactored out as that may cause side effects (divide by zero) +# at plan time that would not actually happen during execution, so the follow three query should not be extract +# the common sub-expression +query TT +explain select coalesce(1, y/x), coalesce(2, y/x) from t; +---- +logical_plan +Projection: coalesce(Int64(1), CAST(t.y / t.x AS Int64)), coalesce(Int64(2), CAST(t.y / t.x AS Int64)) +--TableScan: t projection=[x, y] +physical_plan +ProjectionExec: expr=[coalesce(1, CAST(y@1 / x@0 AS Int64)) as coalesce(Int64(1),t.y / t.x), coalesce(2, CAST(y@1 / x@0 AS Int64)) as coalesce(Int64(2),t.y / t.x)] +--MemoryExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t; +---- +logical_plan +Projection: t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y > Int64(0) AND Int64(1) / t.y < Int64(1), t.x > Int32(0) AND t.y > Int32(0) AND Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x +--TableScan: t projection=[x, y] +physical_plan +ProjectionExec: expr=[y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 1 as t.y > Int64(0) AND Int64(1) / t.y < Int64(1), x@0 > 0 AND y@1 > 0 AND 1 / CAST(y@1 AS Int64) < 1 / CAST(x@0 AS Int64) as t.x > Int64(0) AND t.y > Int64(0) AND Int64(1) / t.y < Int64(1) / t.x] +--MemoryExec: partitions=1, partition_sizes=[1] + +query TT +EXPLAIN SELECT y = 0 or 1 / y < 1, x = 0 or y = 0 or 1 / y < 1 / x from t; +---- +logical_plan +Projection: t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) AS t.y = Int64(0) OR Int64(1) / t.y < Int64(1), t.x = Int32(0) OR t.y = Int32(0) OR Int64(1) / CAST(t.y AS Int64) < Int64(1) / CAST(t.x AS Int64) AS t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x +--TableScan: t projection=[x, y] +physical_plan +ProjectionExec: expr=[y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 as t.y = Int64(0) OR Int64(1) / t.y < Int64(1), x@0 = 0 OR y@1 = 0 OR 1 / CAST(y@1 AS Int64) < 1 / CAST(x@0 AS Int64) as t.x = Int64(0) OR t.y = Int64(0) OR Int64(1) / t.y < Int64(1) / t.x] +--MemoryExec: partitions=1, partition_sizes=[1] + +# due to the reason describe in https://github.com/apache/arrow-datafusion/issues/8927, +# the following queries will fail +query error +select coalesce(1, y/x), coalesce(2, y/x) from t; + +query error +SELECT y > 0 and 1 / y < 1, x > 0 and y > 0 and 1 / y < 1 / x from t; + +query error +SELECT y = 0 or 1 / y < 1, x = 0 or y = 0 or 1 / y < 1 / x from t; + statement ok DROP TABLE t;