From aed59d7c14557b2330fcbd95a1eb9bb451375c94 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Tue, 24 Mar 2020 17:25:01 -0700 Subject: [PATCH 1/2] Conservatively constant fold Scala UDFs that meet all of the following criteria: - deterministic - all arguments are foldable - does not throw an exception upon evaluation --- .../sql/catalyst/optimizer/expressions.scala | 13 +++++ .../optimizer/ConstantFoldingSuite.scala | 51 +++++++++++++++++++ 2 files changed, 64 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index bd400f86ea2c1..8da67e0863203 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet import scala.collection.mutable.{ArrayBuffer, Stack} +import scala.util.{Failure, Success, Try} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ @@ -50,8 +51,20 @@ object ConstantFolding extends Rule[LogicalPlan] { // Fold expressions that are foldable. case e if e.foldable => Literal.create(e.eval(EmptyRow), e.dataType) + + // Fold ScalaUDFs if they're deterministic and all arguments are foldable. + // Watch out for potentially exception-throwing UDFs: some Scala UDFs may have been + // mis-declared as being deterministic, but throw exceptions at runtime. Do not optimize + // them, so that they can throw the exception at the expected time. 
+ case udf: ScalaUDF if maybeFoldable(udf) => Try(udf.eval(EmptyRow)) match { + case Success(v) => Literal.create(v, udf.dataType) + case Failure(_) => udf // defer any exception throwing to execution phase + } } } + + private def maybeFoldable(udf: ScalaUDF): Boolean = + udf.deterministic && udf.children.forall(_.foldable) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 23ab6b2df3e64..a39f2afe5b178 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -155,6 +155,57 @@ class ConstantFoldingSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Constant folding test: deterministic Scala UDFs") { + val normalFunc = (x: Int) => x + 41 + val exceptionFunc = (x: Int) => x / 0 + + val foldableUdf = ScalaUDF( + function = normalFunc, + dataType = IntegerType, + children = Seq(Literal(1)), + inputPrimitives = Seq(true), + udfName = None, + nullable = false, + udfDeterministic = true) + + val deterministicUnfoldableUdf = ScalaUDF( + function = normalFunc, + dataType = IntegerType, + children = Seq[Expression]('a), + inputPrimitives = Seq(true), + udfName = None, + nullable = false, + udfDeterministic = true) + + val exceptionUdf = ScalaUDF( + function = exceptionFunc, + dataType = IntegerType, + children = Seq(Literal(1)), + inputPrimitives = Seq(true), + udfName = None, + nullable = false, + udfDeterministic = true) // intentionally mis-declaring as deterministic + + val originalQuery = + testRelation + .select( + foldableUdf as Symbol("c1"), + deterministicUnfoldableUdf as Symbol("c2"), + exceptionUdf as Symbol("c3")) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = + testRelation + .select( + 
Literal(42) as Symbol("c1"), + deterministicUnfoldableUdf as Symbol("c2"), + exceptionUdf as Symbol("c3")) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("Constant folding test: expressions have nonfoldable functions") { val originalQuery = testRelation From ab15f5cbd4befc78abb7e1cafdd9729ec29b7216 Mon Sep 17 00:00:00 2001 From: Kris Mok Date: Tue, 24 Mar 2020 18:56:44 -0700 Subject: [PATCH 2/2] Rebase to latest master and fix compilation error in test suite --- .../sql/catalyst/optimizer/ConstantFoldingSuite.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index a39f2afe5b178..1065e809aeed2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} @@ -159,11 +160,13 @@ class ConstantFoldingSuite extends PlanTest { val normalFunc = (x: Int) => x + 41 val exceptionFunc = (x: Int) => x / 0 + val intEncoder = Option(ExpressionEncoder[Int]()) + val foldableUdf = ScalaUDF( function = normalFunc, dataType = IntegerType, children = Seq(Literal(1)), - inputPrimitives = Seq(true), + inputEncoders = Seq(intEncoder), udfName = None, nullable = false, udfDeterministic = true) @@ -172,7 
+175,7 @@ class ConstantFoldingSuite extends PlanTest { function = normalFunc, dataType = IntegerType, children = Seq[Expression]('a), - inputPrimitives = Seq(true), + inputEncoders = Seq(intEncoder), udfName = None, nullable = false, udfDeterministic = true) @@ -181,7 +184,7 @@ class ConstantFoldingSuite extends PlanTest { function = exceptionFunc, dataType = IntegerType, children = Seq(Literal(1)), - inputPrimitives = Seq(true), + inputEncoders = Seq(intEncoder), udfName = None, nullable = false, udfDeterministic = true) // intentionally mis-declaring as deterministic