From 2da68859749137fa35ed71a50c346c559d2b35e2 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 19 Dec 2020 10:08:22 +0800 Subject: [PATCH 1/5] Improve SimplifyConditionals and PushFoldableIntoBranches --- .../sql/catalyst/optimizer/expressions.scala | 8 ++- .../PushFoldableIntoBranchesSuite.scala | 54 ++++++++++--------- .../optimizer/SimplifyConditionalSuite.scala | 10 ++++ 3 files changed, 46 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index e6730c9275a1e..eb4e11ac1b0c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -475,6 +475,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue case If(Literal(null, _), _, falseValue) => falseValue + case If(cond, TrueLiteral, FalseLiteral) if cond.deterministic => cond + case If(cond, FalseLiteral, TrueLiteral) if cond.deterministic => Not(cond) case If(cond, trueValue, falseValue) if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue case If(cond, l @ Literal(null, _), FalseLiteral) if !cond.nullable => And(cond, l) @@ -558,13 +560,15 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( branches.map(e => e.copy(_2 = b.makeCopy(Array(e._2, right)))), - elseValue.map(e => b.makeCopy(Array(e, right)))) + elseValue.orElse(Some(Literal.create(null, right.dataType))) + .map(e => b.makeCopy(Array(e, right)))) case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue)) if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( branches.map(e => e.copy(_2 = b.makeCopy(Array(left, e._2)))), - elseValue.map(e => b.makeCopy(Array(left, e)))) + elseValue.orElse(Some(Literal.create(null, left.dataType))) + .map(e => b.makeCopy(Array(left, e)))) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index 43360af46ffb3..ab393befbfecc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -44,6 +44,8 @@ class PushFoldableIntoBranchesSuite private val c = EqualTo(UnresolvedAttribute("c"), Literal(true)) private val ifExp = If(a, Literal(2), Literal(3)) private val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3))) + private val nullInt = Literal(null, IntegerType) + private val nullBoolean = Literal(null, BooleanType) protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { val correctAnswer = Project(Alias(e2, "out")() :: Nil, relation).analyze @@ -53,7 +55,7 @@ class PushFoldableIntoBranchesSuite test("Push down EqualTo through If") { assertEquivalent(EqualTo(ifExp, Literal(4)), FalseLiteral) - assertEquivalent(EqualTo(ifExp, Literal(3)), If(a, FalseLiteral, TrueLiteral)) + assertEquivalent(EqualTo(ifExp, Literal(3)), Not(a)) // Push down at most one not foldable expressions. assertEquivalent( @@ -73,17 +75,17 @@ class PushFoldableIntoBranchesSuite // Handle Null values. assertEquivalent( - EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1)), - If(a, Literal(null, BooleanType), TrueLiteral)) + EqualTo(If(a, nullInt, Literal(1)), Literal(1)), + If(a, nullBoolean, TrueLiteral)) assertEquivalent( - EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2)), - If(a, Literal(null, BooleanType), FalseLiteral)) + EqualTo(If(a, nullInt, Literal(1)), Literal(2)), + If(a, nullBoolean, FalseLiteral)) assertEquivalent( - EqualTo(If(a, Literal(1), Literal(2)), Literal(null, IntegerType)), - Literal(null, BooleanType)) + EqualTo(If(a, Literal(1), Literal(2)), nullInt), + nullBoolean) assertEquivalent( - EqualTo(If(a, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)), - Literal(null, BooleanType)) + EqualTo(If(a, nullInt, nullInt), Literal(1)), + nullBoolean) } test("Push down other BinaryComparison through If") { @@ -102,8 +104,7 @@ class PushFoldableIntoBranchesSuite assertEquivalent(Remainder(ifExp, Literal(4)), If(a, Literal(2), Literal(3))) assertEquivalent(Divide(If(a, Literal(2.0), Literal(3.0)), Literal(1.0)), If(a, Literal(2.0), Literal(3.0))) - assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), - If(a, FalseLiteral, TrueLiteral)) + assertEquivalent(And(If(a, FalseLiteral, TrueLiteral), TrueLiteral), Not(a)) assertEquivalent(Or(If(a, FalseLiteral, TrueLiteral), TrueLiteral), TrueLiteral) } @@ -123,7 +124,9 @@ class PushFoldableIntoBranchesSuite CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral))) assertEquivalent( EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None), Literal(4)), - CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None)) + CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), nullBoolean)) + assertEquivalent( + EqualTo(CaseWhen(Seq((a, nullInt), (c, nullInt)), None), Literal(4)), nullBoolean) assertEquivalent( And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))), @@ -131,7 +134,7 @@ class PushFoldableIntoBranchesSuite // Push down at most one branch is not foldable expressions. assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, Literal(1))), None), Literal(1)), - CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), None)) + CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), nullBoolean)) assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1)), EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1))) assertEquivalent(EqualTo(CaseWhen(Seq((a, b)), None), Literal(1)), @@ -148,22 +151,22 @@ class PushFoldableIntoBranchesSuite // Handle Null values. assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)), - CaseWhen(Seq((a, Literal(null, BooleanType))), Some(FalseLiteral))) + EqualTo(CaseWhen(Seq((a, nullInt)), Some(Literal(1))), Literal(2)), + CaseWhen(Seq((a, nullBoolean)), Some(FalseLiteral))) assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)), - Literal(null, BooleanType)) + EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), nullInt), + nullBoolean) assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)), - CaseWhen(Seq((a, Literal(null, BooleanType))), Some(TrueLiteral))) + EqualTo(CaseWhen(Seq((a, nullInt)), Some(Literal(1))), Literal(1)), + CaseWhen(Seq((a, nullBoolean)), Some(TrueLiteral))) assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))), + EqualTo(CaseWhen(Seq((a, nullInt)), Some(nullInt)), Literal(1)), - Literal(null, BooleanType)) + nullBoolean) assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))), - Literal(null, IntegerType)), - Literal(null, BooleanType)) + EqualTo(CaseWhen(Seq((a, nullInt)), Some(nullInt)), + nullInt), + nullBoolean) } test("Push down other BinaryComparison through CaseWhen") { @@ -220,6 +223,9 @@ class PushFoldableIntoBranchesSuite test("Push down BinaryExpression through If/CaseWhen backwards") { assertEquivalent(EqualTo(Literal(4), ifExp), FalseLiteral) + assertEquivalent(EqualTo(Literal(4), If(a, nullInt, nullInt)), nullBoolean) assertEquivalent(EqualTo(Literal(4), caseWhen), FalseLiteral) + assertEquivalent(EqualTo(Literal(4), CaseWhen(Seq((a, nullInt), (c, nullInt)), None)), + nullBoolean) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index bac962ced4618..2611e0f41b149 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -79,6 +79,16 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P Literal(9))) } + test("remove unnecessary if when the outputs are boolean type") { + assertEquivalent( + If(IsNotNull(UnresolvedAttribute("a")), TrueLiteral, FalseLiteral), + IsNotNull(UnresolvedAttribute("a"))) + + assertEquivalent( + If(IsNotNull(UnresolvedAttribute("a")), FalseLiteral, TrueLiteral), + IsNull(UnresolvedAttribute("a"))) + } + test("remove unreachable branches") { // i.e. removing branches whose conditions are always false assertEquivalent( From 090890c2e2da18b2968f590f538f0e1ab0e8a050 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 19 Dec 2020 12:28:33 +0800 Subject: [PATCH 2/5] Improve SimplifyConditionals --- .../sql/catalyst/optimizer/expressions.scala | 6 ++-- .../PushFoldableIntoBranchesSuite.scala | 8 ++--- ...ReplaceNullWithFalseInPredicateSuite.scala | 31 +++++++++++-------- .../optimizer/SimplifyConditionalSuite.scala | 26 ++++++++++------ 4 files changed, 38 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index eb4e11ac1b0c2..2c7346770146f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -560,15 +560,13 @@ object PushFoldableIntoBranches extends Rule[LogicalPlan] with PredicateHelper { if right.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( branches.map(e => e.copy(_2 = b.makeCopy(Array(e._2, right)))), - elseValue.orElse(Some(Literal.create(null, right.dataType))) - .map(e => b.makeCopy(Array(e, right)))) + elseValue.map(e => b.makeCopy(Array(e, right)))) case b @ BinaryExpression(left, c @ CaseWhen(branches, elseValue)) if left.foldable && atMostOneUnfoldable(branches.map(_._2) ++ elseValue) => c.copy( branches.map(e => e.copy(_2 = b.makeCopy(Array(left, e._2)))), - elseValue.orElse(Some(Literal.create(null, left.dataType))) - .map(e => b.makeCopy(Array(left, e)))) + elseValue.map(e => b.makeCopy(Array(left, e)))) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index ab393befbfecc..c65b1f94b22e9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -124,9 +124,7 @@ class PushFoldableIntoBranchesSuite CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), Some(TrueLiteral))) assertEquivalent( EqualTo(CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), None), Literal(4)), - CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), nullBoolean)) - assertEquivalent( - EqualTo(CaseWhen(Seq((a, nullInt), (c, nullInt)), None), Literal(4)), nullBoolean) + CaseWhen(Seq((a, FalseLiteral), (c, FalseLiteral)), None)) assertEquivalent( And(EqualTo(caseWhen, Literal(5)), EqualTo(caseWhen, Literal(6))), @@ -134,7 +132,7 @@ class PushFoldableIntoBranchesSuite // Push down at most one branch is not foldable expressions. assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, Literal(1))), None), Literal(1)), - CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), nullBoolean)) + CaseWhen(Seq((a, EqualTo(b, Literal(1))), (c, TrueLiteral)), None)) assertEquivalent(EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1)), EqualTo(CaseWhen(Seq((a, b), (c, b + 1)), None), Literal(1))) assertEquivalent(EqualTo(CaseWhen(Seq((a, b)), None), Literal(1)), @@ -225,7 +223,5 @@ class PushFoldableIntoBranchesSuite assertEquivalent(EqualTo(Literal(4), ifExp), FalseLiteral) assertEquivalent(EqualTo(Literal(4), If(a, nullInt, nullInt)), nullBoolean) assertEquivalent(EqualTo(Literal(4), caseWhen), FalseLiteral) - assertEquivalent(EqualTo(Literal(4), CaseWhen(Seq((a, nullInt), (c, nullInt)), None)), - nullBoolean) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index 00433a5490574..5da71c31e1990 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, LessThanOrEqual, Literal, MapFilter, NamedExpression, Or, UnresolvedNamedLambdaVariable} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{DeleteFromTable, LocalRelation, LogicalPlan, UpdateTable} @@ -236,12 +236,13 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { Literal(2) === nestedCaseWhen, TrueLiteral, FalseLiteral) - val branches = Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue) - val condition = CaseWhen(branches) - testFilter(originalCond = condition, expectedCond = condition) - testJoin(originalCond = condition, expectedCond = condition) - testDelete(originalCond = condition, expectedCond = condition) - testUpdate(originalCond = condition, expectedCond = condition) + val condition = CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> branchValue)) + val expectedCond = + CaseWhen(Seq((UnresolvedAttribute("i") > Literal(10)) -> (Literal(2) === nestedCaseWhen))) + testFilter(originalCond = condition, expectedCond = expectedCond) + testJoin(originalCond = condition, expectedCond = expectedCond) + testDelete(originalCond = condition, expectedCond = expectedCond) + testUpdate(originalCond = condition, expectedCond = expectedCond) } test("inability to replace null in non-boolean branches of If inside another If") { @@ -252,10 +253,14 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { Literal(3)), TrueLiteral, FalseLiteral) - testFilter(originalCond = condition, expectedCond = condition) - testJoin(originalCond = condition, expectedCond = condition) - testDelete(originalCond = condition, expectedCond = condition) - testUpdate(originalCond = condition, expectedCond = condition) + val expectedCond = Literal(5) > If( + UnresolvedAttribute("i") === Literal(15), + Literal(null, IntegerType), + Literal(3)) + testFilter(originalCond = condition, expectedCond = expectedCond) + testJoin(originalCond = condition, expectedCond = expectedCond) + testDelete(originalCond = condition, expectedCond = expectedCond) + testUpdate(originalCond = condition, expectedCond = expectedCond) } test("replace null in If used as a join condition") { @@ -405,9 +410,9 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { val lambda1 = LambdaFunction( function = If(cond, Literal(null, BooleanType), TrueLiteral), arguments = lambdaArgs) - // the optimized lambda body is: if(arg > 0, false, true) + // the optimized lambda body is: if(arg > 0, false, true) => arg <= 0 val lambda2 = LambdaFunction( - function = If(cond, FalseLiteral, TrueLiteral), + function = LessThanOrEqual(condArg, Literal(0)), arguments = lambdaArgs) testProjection( originalExpr = createExpr(argument, lambda1) as 'x, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 2611e0f41b149..438fa75a7836e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -79,16 +79,6 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P Literal(9))) } - test("remove unnecessary if when the outputs are boolean type") { - assertEquivalent( - If(IsNotNull(UnresolvedAttribute("a")), TrueLiteral, FalseLiteral), - IsNotNull(UnresolvedAttribute("a"))) - - assertEquivalent( - If(IsNotNull(UnresolvedAttribute("a")), FalseLiteral, TrueLiteral), - IsNull(UnresolvedAttribute("a"))) - } - test("remove unreachable branches") { // i.e. removing branches whose conditions are always false assertEquivalent( @@ -209,4 +199,20 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P If(Factorial(5) > 100L, b, nullLiteral).eval(EmptyRow)) } } + + test("remove unnecessary if when the outputs are boolean type") { + assertEquivalent( + If(IsNotNull(UnresolvedAttribute("a")), TrueLiteral, FalseLiteral), + IsNotNull(UnresolvedAttribute("a"))) + assertEquivalent( + If(IsNotNull(UnresolvedAttribute("a")), FalseLiteral, TrueLiteral), + IsNull(UnresolvedAttribute("a"))) + + assertEquivalent( + If(GreaterThan(Rand(0), UnresolvedAttribute("a")), TrueLiteral, FalseLiteral), + If(GreaterThan(Rand(0), UnresolvedAttribute("a")), TrueLiteral, FalseLiteral)) + assertEquivalent( + If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral), + If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral)) + } } From f4d8f6b78404594e2d4ff4a8269586032dcd56e3 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Sat, 19 Dec 2020 13:26:15 +0800 Subject: [PATCH 3/5] fix --- .../PushFoldableIntoBranchesSuite.scala | 41 +++++++++---------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index c65b1f94b22e9..c9124e6dc5f33 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -44,8 +44,6 @@ class PushFoldableIntoBranchesSuite private val c = EqualTo(UnresolvedAttribute("c"), Literal(true)) private val ifExp = If(a, Literal(2), Literal(3)) private val caseWhen = CaseWhen(Seq((a, Literal(1)), (c, Literal(2))), Some(Literal(3))) - private val nullInt = Literal(null, IntegerType) - private val nullBoolean = Literal(null, BooleanType) protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { val correctAnswer = Project(Alias(e2, "out")() :: Nil, relation).analyze @@ -75,17 +73,17 @@ class PushFoldableIntoBranchesSuite // Handle Null values. assertEquivalent( - EqualTo(If(a, nullInt, Literal(1)), Literal(1)), - If(a, nullBoolean, TrueLiteral)) + EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(1)), + If(a, Literal(null, BooleanType), TrueLiteral)) assertEquivalent( - EqualTo(If(a, nullInt, Literal(1)), Literal(2)), - If(a, nullBoolean, FalseLiteral)) + EqualTo(If(a, Literal(null, IntegerType), Literal(1)), Literal(2)), + If(a, Literal(null, BooleanType), FalseLiteral)) assertEquivalent( - EqualTo(If(a, Literal(1), Literal(2)), nullInt), - nullBoolean) + EqualTo(If(a, Literal(1), Literal(2)), Literal(null, IntegerType)), + Literal(null, BooleanType)) assertEquivalent( - EqualTo(If(a, nullInt, nullInt), Literal(1)), - nullBoolean) + EqualTo(If(a, Literal(null, IntegerType), Literal(null, IntegerType)), Literal(1)), + Literal(null, BooleanType)) } test("Push down other BinaryComparison through If") { @@ -149,22 +147,22 @@ class PushFoldableIntoBranchesSuite // Handle Null values. assertEquivalent( - EqualTo(CaseWhen(Seq((a, nullInt)), Some(Literal(1))), Literal(2)), - CaseWhen(Seq((a, nullBoolean)), Some(FalseLiteral))) + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(2)), + CaseWhen(Seq((a, Literal(null, BooleanType))), Some(FalseLiteral))) assertEquivalent( - EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), nullInt), - nullBoolean) + EqualTo(CaseWhen(Seq((a, Literal(1))), Some(Literal(2))), Literal(null, IntegerType)), + Literal(null, BooleanType)) assertEquivalent( - EqualTo(CaseWhen(Seq((a, nullInt)), Some(Literal(1))), Literal(1)), - CaseWhen(Seq((a, nullBoolean)), Some(TrueLiteral))) + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(1))), Literal(1)), + CaseWhen(Seq((a, Literal(null, BooleanType))), Some(TrueLiteral))) assertEquivalent( - EqualTo(CaseWhen(Seq((a, nullInt)), Some(nullInt)), + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))), Literal(1)), - nullBoolean) + Literal(null, BooleanType)) assertEquivalent( - EqualTo(CaseWhen(Seq((a, nullInt)), Some(nullInt)), - nullInt), - nullBoolean) + EqualTo(CaseWhen(Seq((a, Literal(null, IntegerType))), Some(Literal(null, IntegerType))), + Literal(null, IntegerType)), + Literal(null, BooleanType)) } test("Push down other BinaryComparison through CaseWhen") { @@ -221,7 +219,6 @@ class PushFoldableIntoBranchesSuite test("Push down BinaryExpression through If/CaseWhen backwards") { assertEquivalent(EqualTo(Literal(4), ifExp), FalseLiteral) - assertEquivalent(EqualTo(Literal(4), If(a, nullInt, nullInt)), nullBoolean) assertEquivalent(EqualTo(Literal(4), caseWhen), FalseLiteral) } } From c11dbd0f79325b5e856ec4086417560500d243e4 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 21 Dec 2020 17:11:23 +0800 Subject: [PATCH 4/5] fix --- .../org/apache/spark/sql/catalyst/optimizer/expressions.scala | 4 ++-- .../catalyst/optimizer/PushFoldableIntoBranchesSuite.scala | 2 +- .../sql/catalyst/optimizer/SimplifyConditionalSuite.scala | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 2c7346770146f..ac2caaeb15357 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -475,8 +475,8 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue case If(Literal(null, _), _, falseValue) => falseValue - case If(cond, TrueLiteral, FalseLiteral) if cond.deterministic => cond - case If(cond, FalseLiteral, TrueLiteral) if cond.deterministic => Not(cond) + case If(cond, TrueLiteral, FalseLiteral) => cond + case If(cond, FalseLiteral, TrueLiteral) => Not(cond) case If(cond, trueValue, falseValue) if cond.deterministic && trueValue.semanticEquals(falseValue) => trueValue case If(cond, l @ Literal(null, _), FalseLiteral) if !cond.nullable => And(cond, l) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala index c9124e6dc5f33..de4f4be8ec333 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PushFoldableIntoBranchesSuite.scala @@ -67,7 +67,7 @@ class PushFoldableIntoBranchesSuite val nonDeterministic = If(LessThan(Rand(1), Literal(0.5)), Literal(1), Literal(2)) assert(!nonDeterministic.deterministic) assertEquivalent(EqualTo(nonDeterministic, Literal(2)), - If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, TrueLiteral)) + GreaterThanOrEqual(Rand(1), Literal(0.5))) assertEquivalent(EqualTo(nonDeterministic, Literal(3)), If(LessThan(Rand(1), Literal(0.5)), FalseLiteral, FalseLiteral)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 438fa75a7836e..dc23f9ea485c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -210,9 +210,9 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P assertEquivalent( If(GreaterThan(Rand(0), UnresolvedAttribute("a")), TrueLiteral, FalseLiteral), - If(GreaterThan(Rand(0), UnresolvedAttribute("a")), TrueLiteral, FalseLiteral)) + GreaterThan(Rand(0), UnresolvedAttribute("a"))) assertEquivalent( If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral), - If(GreaterThan(Rand(0), UnresolvedAttribute("a")), FalseLiteral, TrueLiteral)) + LessThanOrEqual(Rand(0), UnresolvedAttribute("a"))) } } From 1055e481c6780956f2d2cfc5a1469086365e688f Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Mon, 21 Dec 2020 20:04:32 +0800 Subject: [PATCH 5/5] fix --- .../spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index dc23f9ea485c6..328fc107e1c1b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -200,7 +200,7 @@ class SimplifyConditionalSuite extends PlanTest with ExpressionEvalHelper with P } } - test("remove unnecessary if when the outputs are boolean type") { + test("SPARK-33845: remove unnecessary if when the outputs are boolean type") { assertEquivalent( If(IsNotNull(UnresolvedAttribute("a")), TrueLiteral, FalseLiteral), IsNotNull(UnresolvedAttribute("a")))