From 5dff653ea1f8678162bc9364915099d88aebf0f5 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Sun, 18 Nov 2018 01:09:53 -0800 Subject: [PATCH 1/2] Extend ReplaceNullWithFalseInPredicate to support higher-order functions: ArrayExists, ArrayFilter, MapFilter --- .../sql/catalyst/optimizer/Optimizer.scala | 2 +- .../sql/catalyst/optimizer/expressions.scala | 11 ++++- ...eplaceNullWithFalseInPredicateSuite.scala} | 48 +++++++++++++++++-- 3 files changed, 55 insertions(+), 6 deletions(-) rename sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/{ReplaceNullWithFalseSuite.scala => ReplaceNullWithFalseInPredicateSuite.scala} (86%) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a330a84a3a24f..8d251eeab8484 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -84,7 +84,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) SimplifyConditionals, RemoveDispensableExpressions, SimplifyBinaryComparison, - ReplaceNullWithFalse, + ReplaceNullWithFalseInPredicate, PruneFilters, EliminateSorts, SimplifyCasts, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 2b29b49d00ab9..354efd883f814 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -755,7 +755,7 @@ object CombineConcats extends Rule[LogicalPlan] { * * As a result, many unnecessary computations can be removed in the query optimization phase. 
*/ -object ReplaceNullWithFalse extends Rule[LogicalPlan] { +object ReplaceNullWithFalseInPredicate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case f @ Filter(cond, _) => f.copy(condition = replaceNullWithFalse(cond)) @@ -767,6 +767,15 @@ object ReplaceNullWithFalse extends Rule[LogicalPlan] { replaceNullWithFalse(cond) -> value } cw.copy(branches = newBranches) + case af @ ArrayFilter(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + af.copy(function = newLambda) + case ae @ ArrayExists(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + ae.copy(function = newLambda) + case mf @ MapFilter(_, lf @ LambdaFunction(func, _, _)) => + val newLambda = lf.copy(function = replaceNullWithFalse(func)) + mf.copy(function = newLambda) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala similarity index 86% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index c6b5d0ec96776..5de29e0ef02be 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{And, CaseWhen, Expression, GreaterThan, If, Literal, Or} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, Or} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types.{BooleanType, IntegerType} -class ReplaceNullWithFalseSuite extends PlanTest { +class ReplaceNullWithFalseInPredicateSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = @@ -36,10 +36,11 @@ class ReplaceNullWithFalseSuite extends PlanTest { ConstantFolding, BooleanSimplification, SimplifyConditionals, - ReplaceNullWithFalse) :: Nil + ReplaceNullWithFalseInPredicate) :: Nil } - private val testRelation = LocalRelation('i.int, 'b.boolean) + private val testRelation = + LocalRelation('i.int, 'b.boolean, 'a.array(IntegerType), 'm.map(IntegerType, IntegerType)) private val anotherTestRelation = LocalRelation('d.int) test("replace null inside filter and join conditions") { @@ -298,6 +299,45 @@ class ReplaceNullWithFalseSuite extends PlanTest { testProjection(originalExpr = column, expectedExpr = column) } + test("replace nulls in lambda function of ArrayFilter") { + val cond = GreaterThan(UnresolvedAttribute("e"), Literal(0)) + val lambda1 = LambdaFunction( + function = If(cond, Literal(null, BooleanType), TrueLiteral), + arguments = 
Seq(UnresolvedAttribute("e"))) + val lambda2 = LambdaFunction( + function = If(cond, FalseLiteral, TrueLiteral), + arguments = Seq(UnresolvedAttribute("e"))) + testProjection( + originalExpr = ArrayFilter('a, lambda1) as 'x, + expectedExpr = ArrayFilter('a, lambda2) as 'x) + } + + test("replace nulls in lambda function of ArrayExists") { + val cond = GreaterThan(UnresolvedAttribute("e"), Literal(0)) + val lambda1 = LambdaFunction( + function = If(cond, Literal(null, BooleanType), TrueLiteral), + arguments = Seq(UnresolvedAttribute("e"))) + val lambda2 = LambdaFunction( + function = If(cond, FalseLiteral, TrueLiteral), + arguments = Seq(UnresolvedAttribute("e"))) + testProjection( + originalExpr = ArrayExists('a, lambda1) as 'x, + expectedExpr = ArrayExists('a, lambda2) as 'x) + } + + test("replace nulls in lambda function of MapFilter") { + val cond = GreaterThan(UnresolvedAttribute("k"), Literal(0)) + val lambda1 = LambdaFunction( + function = If(cond, Literal(null, BooleanType), TrueLiteral), + arguments = Seq(UnresolvedAttribute("k"), UnresolvedAttribute("v"))) + val lambda2 = LambdaFunction( + function = If(cond, FalseLiteral, TrueLiteral), + arguments = Seq(UnresolvedAttribute("k"), UnresolvedAttribute("v"))) + testProjection( + originalExpr = MapFilter('m, lambda1) as 'x, + expectedExpr = MapFilter('m, lambda2) as 'x) + } + private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = { test((rel, exp) => rel.where(exp), originalCond, expectedCond) } From 6646a96c8b9e905e3cad0b29e7f4063551b23c4c Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Mon, 19 Nov 2018 00:55:18 -0800 Subject: [PATCH 2/2] Address comments --- ...ReplaceNullWithFalseInPredicateSuite.scala | 62 +++++++++---------- ...llWithFalseInPredicateEndToEndSuite.scala} | 45 +++++++++++++- 2 files changed, 74 insertions(+), 33 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/{ReplaceNullWithFalseEndToEndSuite.scala => ReplaceNullWithFalseInPredicateEndToEndSuite.scala} (63%) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala index 5de29e0ef02be..3a9e6cae0fd87 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ReplaceNullWithFalseInPredicateSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, Or} +import org.apache.spark.sql.catalyst.expressions.{And, ArrayExists, ArrayFilter, ArrayTransform, CaseWhen, Expression, GreaterThan, If, LambdaFunction, Literal, MapFilter, NamedExpression, Or} import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} @@ -300,42 +300,23 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { } test("replace nulls in lambda function of ArrayFilter") { - val cond = GreaterThan(UnresolvedAttribute("e"), Literal(0)) - val 
lambda1 = LambdaFunction( - function = If(cond, Literal(null, BooleanType), TrueLiteral), - arguments = Seq(UnresolvedAttribute("e"))) - val lambda2 = LambdaFunction( - function = If(cond, FalseLiteral, TrueLiteral), - arguments = Seq(UnresolvedAttribute("e"))) - testProjection( - originalExpr = ArrayFilter('a, lambda1) as 'x, - expectedExpr = ArrayFilter('a, lambda2) as 'x) + testHigherOrderFunc('a, ArrayFilter, Seq('e)) } test("replace nulls in lambda function of ArrayExists") { - val cond = GreaterThan(UnresolvedAttribute("e"), Literal(0)) - val lambda1 = LambdaFunction( - function = If(cond, Literal(null, BooleanType), TrueLiteral), - arguments = Seq(UnresolvedAttribute("e"))) - val lambda2 = LambdaFunction( - function = If(cond, FalseLiteral, TrueLiteral), - arguments = Seq(UnresolvedAttribute("e"))) - testProjection( - originalExpr = ArrayExists('a, lambda1) as 'x, - expectedExpr = ArrayExists('a, lambda2) as 'x) + testHigherOrderFunc('a, ArrayExists, Seq('e)) } test("replace nulls in lambda function of MapFilter") { - val cond = GreaterThan(UnresolvedAttribute("k"), Literal(0)) - val lambda1 = LambdaFunction( - function = If(cond, Literal(null, BooleanType), TrueLiteral), - arguments = Seq(UnresolvedAttribute("k"), UnresolvedAttribute("v"))) - val lambda2 = LambdaFunction( - function = If(cond, FalseLiteral, TrueLiteral), - arguments = Seq(UnresolvedAttribute("k"), UnresolvedAttribute("v"))) - testProjection( - originalExpr = MapFilter('m, lambda1) as 'x, - expectedExpr = MapFilter('m, lambda2) as 'x) + testHigherOrderFunc('m, MapFilter, Seq('k, 'v)) + } + + test("inability to replace nulls in arbitrary higher-order function") { + val lambdaFunc = LambdaFunction( + function = If('e > 0, Literal(null, BooleanType), TrueLiteral), + arguments = Seq[NamedExpression]('e)) + val column = ArrayTransform('a, lambdaFunc) + testProjection(originalExpr = column, expectedExpr = column) } private def testFilter(originalCond: Expression, expectedCond: Expression): Unit = { @@ -350,6 +331,25 @@ class ReplaceNullWithFalseInPredicateSuite extends PlanTest { test((rel, exp) => rel.select(exp), originalExpr, expectedExpr) } + private def testHigherOrderFunc( + argument: Expression, + createExpr: (Expression, Expression) => Expression, + lambdaArgs: Seq[NamedExpression]): Unit = { + val condArg = lambdaArgs.last + // the lambda body is: if(arg > 0, null, true) + val cond = GreaterThan(condArg, Literal(0)) + val lambda1 = LambdaFunction( + function = If(cond, Literal(null, BooleanType), TrueLiteral), + arguments = lambdaArgs) + // the optimized lambda body is: if(arg > 0, false, true) + val lambda2 = LambdaFunction( + function = If(cond, FalseLiteral, TrueLiteral), + arguments = lambdaArgs) + testProjection( + originalExpr = createExpr(argument, lambda1) as 'x, + expectedExpr = createExpr(argument, lambda2) as 'x) + } + private def test( func: (LogicalPlan, Expression) => LogicalPlan, originalExpr: Expression, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala similarity index 63% rename from sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala index fc6ecc4e032f6..0f84b0c961a10 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseEndToEndSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/ReplaceNullWithFalseInPredicateEndToEndSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If} +import org.apache.spark.sql.catalyst.expressions.{CaseWhen, If, Literal} import org.apache.spark.sql.execution.LocalTableScanExec import org.apache.spark.sql.functions.{lit, when} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.BooleanType -class ReplaceNullWithFalseEndToEndSuite extends QueryTest with SharedSQLContext { +class ReplaceNullWithFalseInPredicateEndToEndSuite extends QueryTest with SharedSQLContext { import testImplicits._ test("SPARK-25860: Replace Literal(null, _) with FalseLiteral whenever possible") { @@ -68,4 +69,44 @@ class ReplaceNullWithFalseEndToEndSuite extends QueryTest with SharedSQLContext case p => fail(s"$p is not LocalTableScanExec") } } + + test("SPARK-26107: Replace Literal(null, _) with FalseLiteral in higher-order functions") { + def assertNoLiteralNullInPlan(df: DataFrame): Unit = { + df.queryExecution.executedPlan.foreach { p => + assert(p.expressions.forall(_.find { + case Literal(null, BooleanType) => true + case _ => false + }.isEmpty)) + } + } + + withTable("t1", "t2") { + // to test ArrayFilter and ArrayExists + spark.sql("select array(null, 1, null, 3) as a") + .write.saveAsTable("t1") + // to test MapFilter + spark.sql(""" + select map_from_entries(arrays_zip(a, transform(a, e -> if(mod(e, 2) = 0, null, e)))) as m + from (select array(0, 1, 2, 3) as a) + """).write.saveAsTable("t2") + + val df1 = spark.table("t1") + val df2 = spark.table("t2") + + // ArrayExists + val q1 = df1.selectExpr("EXISTS(a, e -> IF(e is null, null, true))") + checkAnswer(q1, Row(true) :: Nil) + assertNoLiteralNullInPlan(q1) + + // ArrayFilter + val q2 = df1.selectExpr("FILTER(a, e -> IF(e is null, null, true))") + checkAnswer(q2, Row(Seq[Any](1, 3)) :: Nil) + assertNoLiteralNullInPlan(q2) + + // MapFilter + val q3 = df2.selectExpr("MAP_FILTER(m, (k, v) -> IF(v is null, null, true))") + checkAnswer(q3, Row(Map[Any, Any](1 -> 1, 3 -> 3))) + assertNoLiteralNullInPlan(q3) + } + } }
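
Note for reviewers: the sketch below is a minimal, self-contained illustration of the rewrite this patch extends, written against a hypothetical mini-AST (`Expr`, `NullLit`, `FilterHof`, etc. are stand-ins, not Catalyst classes). It assumes the shape of `replaceNullWithFalse` as used in the hunks above: a null boolean literal in a null-intolerant predicate position is replaced by `false`, recursing through the boolean connectives; what the patch adds, conceptually, is recursing into the lambda body of the boolean-valued higher-order functions (`ArrayFilter`, `ArrayExists`, `MapFilter`) as well.

object ReplaceNullWithFalseSketch {
  sealed trait Expr
  case object NullLit extends Expr   // stand-in for Literal(null, BooleanType)
  case object TrueLit extends Expr
  case object FalseLit extends Expr
  case class If(pred: Expr, onTrue: Expr, onFalse: Expr) extends Expr
  case class And(left: Expr, right: Expr) extends Expr
  case class Or(left: Expr, right: Expr) extends Expr
  case class Lambda(body: Expr) extends Expr
  // Stand-in for a boolean-valued higher-order function such as ArrayFilter.
  case class FilterHof(input: String, fn: Lambda) extends Expr

  // In a predicate position a null boolean behaves like false, so it can be
  // replaced outright; recurse through the null-intolerant boolean operators.
  def replaceNullWithFalse(e: Expr): Expr = e match {
    case NullLit     => FalseLit
    case If(p, t, f) => If(replaceNullWithFalse(p), replaceNullWithFalse(t), replaceNullWithFalse(f))
    case And(l, r)   => And(replaceNullWithFalse(l), replaceNullWithFalse(r))
    case Or(l, r)    => Or(replaceNullWithFalse(l), replaceNullWithFalse(r))
    case other       => other
  }

  // What this patch adds, conceptually: the lambda body of a filtering
  // higher-order function is itself a predicate, so rewrite inside it too.
  def applyRule(e: Expr): Expr = e match {
    case FilterHof(in, Lambda(body)) => FilterHof(in, Lambda(replaceNullWithFalse(body)))
    case other                       => other
  }

  def main(args: Array[String]): Unit = {
    // FILTER(a, e -> IF(cond, null, true))  ==>  FILTER(a, e -> IF(cond, false, true))
    val before = FilterHof("a", Lambda(If(TrueLit, NullLit, TrueLit)))
    val after  = applyRule(before)
    assert(after == FilterHof("a", Lambda(If(TrueLit, FalseLit, TrueLit))))
    println(after)
  }
}

Once the null literal is gone, downstream rules such as BooleanSimplification and SimplifyConditionals can fold the branch away entirely, which is why the end-to-end suite asserts that no Literal(null, BooleanType) survives anywhere in the executed plan.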