From 27ce39af9a0660b353e624551549e9610260a3f9 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Wed, 16 Jun 2021 16:19:55 +0800 Subject: [PATCH 1/4] [SPARK-35688][SQL]Subexpressions should be lazy evaluation in GeneratePredicate --- .../expressions/codegen/CodeGenerator.scala | 145 ++++++++++++++---- .../codegen/GeneratePredicate.scala | 10 +- .../org/apache/spark/sql/DataFrameSuite.scala | 30 ++++ 3 files changed, 148 insertions(+), 37 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9831b13ea754f..f056645e77309 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -417,6 +417,10 @@ class CodegenContext extends Logging { // The collection of sub-expression result resetting methods that need to be called on each row. private val subexprFunctions = mutable.ArrayBuffer.empty[String] + // The collection of reset sub-expression, in lazy evaluation sub-expression, we should invoke + // after processing sub-expression. + private val subexprResetFunctions = mutable.ArrayBuffer.empty[String] + val outerClassName = "OuterClass" /** @@ -1012,6 +1016,14 @@ class CodegenContext extends Logging { splitExpressions(subexprFunctions.toSeq, "subexprFunc_split", Seq("InternalRow" -> INPUT_ROW)) } + /** + * Returns the code for reset subexpression after splitting it if necessary. + */ + def subexprResetFunctionCode: String = { + assert(currentVars == null || subexprResetFunctions.isEmpty) + splitExpressions(subexprResetFunctions.toSeq, "subexprResetFunc_split", Seq()) + } + /** * Perform a function which generates a sequence of ExprCodes with a given mapping between * expressions and common expressions, instead of using the mapping in current context. @@ -1136,7 +1148,9 @@ class CodegenContext extends Logging { * common subexpressions, generates the functions that evaluate those expressions and populates * the mapping of common subexpressions to the generated functions. */ - private def subexpressionElimination(expressions: Seq[Expression]): Unit = { + private def subexpressionElimination( + expressions: Seq[Expression], + lazyEval: Boolean = false): Unit = { // Add each expression tree and compute the common subexpressions. expressions.foreach(equivalentExpressions.addExprTree(_)) @@ -1145,40 +1159,104 @@ class CodegenContext extends Logging { val commonExprs = equivalentExpressions.getAllEquivalentExprs(1) commonExprs.foreach { e => val expr = e.head - val fnName = freshName("subExpr") - val isNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull") - val value = addMutableState(javaType(expr.dataType), "subExprValue") - // Generate the code for this expression tree and wrap it in a function. val eval = expr.genCode(this) - val fn = - s""" - |private void $fnName(InternalRow $INPUT_ROW) { - | ${eval.code} - | $isNull = ${eval.isNull}; - | $value = ${eval.value}; - |} + + val subExprValue = addMutableState(javaType(expr.dataType), "subExprValue") + val subExprValueisNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull") + val evalSubExprValueFnName = freshName("evalSubExprValue") + + if (!lazyEval) { + val fn = + s""" + |private void $evalSubExprValueFnName(InternalRow $INPUT_ROW) { + | ${eval.code} + | $subExprValueisNull = ${eval.isNull}; + | $subExprValue = ${eval.value}; + |} """.stripMargin - // Add a state and a mapping of the common subexpressions that are associate with this - // state. Adding this expression to subExprEliminationExprMap means it will call `fn` - // when it is code generated. This decision should be a cost based one. - // - // The cost of doing subexpression elimination is: - // 1. Extra function call, although this is probably *good* as the JIT can decide to - // inline or not. - // The benefit doing subexpression elimination is: - // 1. Running the expression logic. Even for a simple expression, it is likely more than 3 - // above. - // 2. Less code. - // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with - // at least two nodes) as the cost of doing it is expected to be low. - - subexprFunctions += s"${addNewFunction(fnName, fn)}($INPUT_ROW);" - val state = SubExprEliminationState( - JavaCode.isNullGlobal(isNull), - JavaCode.global(value, expr.dataType)) - subExprEliminationExprs ++= e.map(_ -> state).toMap + // Add a state and a mapping of the common subexpressions that are associate with this + // state. Adding this expression to subExprEliminationExprMap means it will call `fn` + // when it is code generated. This decision should be a cost based one. + // + // The cost of doing subexpression elimination is: + // 1. Extra function call, although this is probably *good* as the JIT can decide to + // inline or not. + // The benefit doing subexpression elimination is: + // 1. Running the expression logic. Even for a simple expression, it is likely more than 3 + // above. + // 2. Less code. + // Currently, we will do this for all non-leaf only expression trees (i.e. expr trees with + // at least two nodes) as the cost of doing it is expected to be low. + + subexprFunctions += s"${addNewFunction(evalSubExprValueFnName, fn)}($INPUT_ROW);" + val state = SubExprEliminationState( + JavaCode.isNullGlobal(subExprValueisNull), + JavaCode.global(subExprValue, expr.dataType)) + subExprEliminationExprs ++= e.map(_ -> state).toMap + } else { + + // the variable to check if a subexpression is evaulated + val isSubExprEval = addMutableState(JAVA_BOOLEAN, "isSubExprEval") + + val evalSubExprValueFnName = freshName("evalSubExprValue") + val evalSubExprValueFn = + s""" + |private void $evalSubExprValueFnName(InternalRow ${INPUT_ROW}) { + | ${eval.code} + | $subExprValueisNull = ${eval.isNull}; + | $subExprValue = ${eval.value}; + | $isSubExprEval = true; + |} + |""".stripMargin + + val getSubExprValueFnName = freshName("getSubExprValue") + val getSubExprValueFn = + s""" + |private ${boxedType(expr.dataType)} $getSubExprValueFnName(InternalRow ${INPUT_ROW}) { + | if (!$isSubExprEval) { + | $evalSubExprValueFnName($INPUT_ROW); + | } + | return ${subExprValue}; + |} + |""".stripMargin + + val getSubExprValueIsNullFnName = freshName("getSubExprValueIsNull") + val getSubExprValueIsNullFn = + s""" + |private boolean ${getSubExprValueIsNullFnName}(InternalRow ${INPUT_ROW}) { + | if (!$isSubExprEval) { + | $evalSubExprValueFnName($INPUT_ROW); + | } + | return $subExprValueisNull; + |} + |""".stripMargin + + // the function for reset subexpression after processing. + val resetFnName = freshName("resetSubExpr") + val resetFn = + s""" + |private void $resetFnName() { + | $isSubExprEval = false; + |} + |""".stripMargin + + addNewFunction(evalSubExprValueFnName, evalSubExprValueFn) + subexprResetFunctions += s"${addNewFunction(resetFnName, resetFn)}();" + + val splitIsNull = splitExpressions(Seq( + s"${addNewFunction(getSubExprValueIsNullFnName, getSubExprValueIsNullFn)}($INPUT_ROW)"), + s"${getSubExprValueIsNullFnName}_split", Seq("InternalRow" -> INPUT_ROW)) + val splitValue = splitExpressions( + Seq(s"${addNewFunction(getSubExprValueFnName, getSubExprValueFn)}($INPUT_ROW)"), + s"${getSubExprValueFnName}_split", Seq("InternalRow" -> INPUT_ROW)) + val state = SubExprEliminationState( + JavaCode.isNullGlobal(splitIsNull), + JavaCode.global(splitValue, expr.dataType)) + + subExprEliminationExprs ++= e.map(_ -> state).toMap + } } } @@ -1189,8 +1267,9 @@ class CodegenContext extends Logging { */ def generateExpressions( expressions: Seq[Expression], - doSubexpressionElimination: Boolean = false): Seq[ExprCode] = { - if (doSubexpressionElimination) subexpressionElimination(expressions) + doSubexpressionElimination: Boolean = false, + lazyEvalSubexpression: Boolean = false): Seq[ExprCode] = { + if (doSubexpressionElimination) subexpressionElimination(expressions, lazyEvalSubexpression) expressions.map(e => e.genCode(this)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index c246d07f189b4..3d64ac73f7268 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -38,8 +38,9 @@ object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] { val ctx = newCodeGenContext() // Do sub-expression elimination for predicates. - val eval = ctx.generateExpressions(Seq(predicate), useSubexprElimination).head - val evalSubexpr = ctx.subexprFunctionsCode + val eval = + ctx.generateExpressions(Seq(predicate), useSubexprElimination, useSubexprElimination).head + val subExprReset = ctx.subexprResetFunctionCode val codeBody = s""" public SpecificPredicate generate(Object[] references) { @@ -60,9 +61,10 @@ object GeneratePredicate extends CodeGenerator[Expression, BasePredicate] { } public boolean eval(InternalRow ${ctx.INPUT_ROW}) { - $evalSubexpr ${eval.code} - return !${eval.isNull} && ${eval.value}; + boolean result = !${eval.isNull} && ${eval.value}; + $subExprReset; + return result; } ${ctx.declareAddedFunctions()} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 3e810a4533770..51b691819fd5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2907,6 +2907,36 @@ class DataFrameSuite extends QueryTest } } } + + test("SPARK-35688: subexpressions should be lazy evaluation in GeneratePredicate") { + withTempPath(dir => { + Seq( + ("true", "false"), + ("false", "true"), + ("false", "false"), + ("true", "true"), + ).foreach { case (subExprEliminationEnabled, codegenEnabled) => + + withSQLConf( + SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> subExprEliminationEnabled, + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled, + "spark.sql.ansi.enabled" -> "true") { + Seq( + (1 to 10).toArray, + (1 to 5).toArray + ).toDF("c1") + .write + .mode("overwrite") + .save(dir.getCanonicalPath) + val df = spark.read.load(dir.getCanonicalPath) + .filter("size(c1) > 5 and (element_at(c1, 7) = 8 or element_at(c1, 7) = 7)") + checkAnswer( + df, Row((1 to 10).toArray) :: Nil + ) + } + } + }) + } } case class GroupByKey(a: Int, b: Int) From f6ae8a36deb41122be0b9515605e1e714363974e Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Sun, 20 Jun 2021 14:05:28 +0800 Subject: [PATCH 2/4] fix style --- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 51b691819fd5c..f7d0139570b17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2909,7 +2909,7 @@ class DataFrameSuite extends QueryTest } test("SPARK-35688: subexpressions should be lazy evaluation in GeneratePredicate") { - withTempPath(dir => { + withTempPath { dir => Seq( ("true", "false"), ("false", "true"), @@ -2935,7 +2935,7 @@ class DataFrameSuite extends QueryTest ) } } - }) + } } } From 8642336013d7ce148091fb94299cebf157ca8a4e Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Sun, 20 Jun 2021 14:48:42 +0800 Subject: [PATCH 3/4] fix style --- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index f7d0139570b17..d125ebc763ad7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2914,9 +2914,8 @@ class DataFrameSuite extends QueryTest ("true", "false"), ("false", "true"), ("false", "false"), - ("true", "true"), + ("true", "true") ).foreach { case (subExprEliminationEnabled, codegenEnabled) => - withSQLConf( SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> subExprEliminationEnabled, SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> codegenEnabled, From 05a11c880da09799bccfbd3f5dd48c3448315c27 Mon Sep 17 00:00:00 2001 From: Fu Chen Date: Mon, 21 Jun 2021 13:59:07 +0800 Subject: [PATCH 4/4] add test --- .../expressions/codegen/CodeGenerator.scala | 25 +++++++++++-------- .../codegen/GenerateMutableProjection.scala | 17 ++++++++----- .../codegen/GenerateUnsafeProjection.scala | 17 +++++++------ .../expressions/CodeGenerationSuite.scala | 6 +++-- 4 files changed, 39 insertions(+), 26 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f056645e77309..1dc18b3a130cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -1150,7 +1150,7 @@ class CodegenContext extends Logging { */ private def subexpressionElimination( expressions: Seq[Expression], - lazyEval: Boolean = false): Unit = { + lazyEvaluation: Boolean = false): Unit = { // Add each expression tree and compute the common subexpressions. expressions.foreach(equivalentExpressions.addExprTree(_)) @@ -1163,15 +1163,15 @@ class CodegenContext extends Logging { val eval = expr.genCode(this) val subExprValue = addMutableState(javaType(expr.dataType), "subExprValue") - val subExprValueisNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull") + val subExprValueIsNull = addMutableState(JAVA_BOOLEAN, "subExprIsNull") val evalSubExprValueFnName = freshName("evalSubExprValue") - if (!lazyEval) { + if (!lazyEvaluation) { val fn = s""" |private void $evalSubExprValueFnName(InternalRow $INPUT_ROW) { | ${eval.code} - | $subExprValueisNull = ${eval.isNull}; + | $subExprValueIsNull = ${eval.isNull}; | $subExprValue = ${eval.value}; |} """.stripMargin @@ -1192,12 +1192,12 @@ class CodegenContext extends Logging { subexprFunctions += s"${addNewFunction(evalSubExprValueFnName, fn)}($INPUT_ROW);" val state = SubExprEliminationState( - JavaCode.isNullGlobal(subExprValueisNull), + JavaCode.isNullGlobal(subExprValueIsNull), JavaCode.global(subExprValue, expr.dataType)) subExprEliminationExprs ++= e.map(_ -> state).toMap } else { - // the variable to check if a subexpression is evaulated + // the variable to check if a subexpression is evaluated or not. val isSubExprEval = addMutableState(JAVA_BOOLEAN, "isSubExprEval") val evalSubExprValueFnName = freshName("evalSubExprValue") @@ -1205,18 +1205,22 @@ class CodegenContext extends Logging { s""" |private void $evalSubExprValueFnName(InternalRow ${INPUT_ROW}) { | ${eval.code} - | $subExprValueisNull = ${eval.isNull}; + | $subExprValueIsNull = ${eval.isNull}; | $subExprValue = ${eval.value}; | $isSubExprEval = true; |} |""".stripMargin + val splitEvalSubExprValueFnName = splitExpressions(Seq( + s"${addNewFunction(evalSubExprValueFnName, evalSubExprValueFn)}($INPUT_ROW)"), + s"${evalSubExprValueFnName}_split", Seq("InternalRow" -> INPUT_ROW)) + val getSubExprValueFnName = freshName("getSubExprValue") val getSubExprValueFn = s""" |private ${boxedType(expr.dataType)} $getSubExprValueFnName(InternalRow ${INPUT_ROW}) { | if (!$isSubExprEval) { - | $evalSubExprValueFnName($INPUT_ROW); + | $splitEvalSubExprValueFnName; | } | return ${subExprValue}; |} @@ -1227,9 +1231,9 @@ class CodegenContext extends Logging { s""" |private boolean ${getSubExprValueIsNullFnName}(InternalRow ${INPUT_ROW}) { | if (!$isSubExprEval) { - | $evalSubExprValueFnName($INPUT_ROW); + | $splitEvalSubExprValueFnName; | } - | return $subExprValueisNull; + | return $subExprValueIsNull; |} |""".stripMargin @@ -1242,7 +1246,6 @@ class CodegenContext extends Logging { |} |""".stripMargin - addNewFunction(evalSubExprValueFnName, evalSubExprValueFn) subexprResetFunctions += s"${addNewFunction(resetFnName, resetFn)}();" val splitIsNull = splitExpressions(Seq( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 2e018de07101e..743d265be9234 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -42,26 +42,31 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP expressions: Seq[Expression], inputSchema: Seq[Attribute], useSubexprElimination: Boolean): MutableProjection = { - create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination) + create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination, false) } - def generate(expressions: Seq[Expression], useSubexprElimination: Boolean): MutableProjection = { - create(canonicalize(expressions), useSubexprElimination) + def generate( + expressions: Seq[Expression], + useSubexprElimination: Boolean, + lazyEvaluation: Boolean = false): MutableProjection = { + create(canonicalize(expressions), useSubexprElimination, lazyEvaluation) } protected def create(expressions: Seq[Expression]): MutableProjection = { - create(expressions, false) + create(expressions, false, false) } private def create( expressions: Seq[Expression], - useSubexprElimination: Boolean): MutableProjection = { + useSubexprElimination: Boolean, + lazyEvaluation: Boolean): MutableProjection = { val ctx = newCodeGenContext() val validExpr = expressions.zipWithIndex.filter { case (NoOp, _) => false case _ => true } - val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination) + val exprVals = ctx.generateExpressions(validExpr.map(_._1), useSubexprElimination, + lazyEvaluation) // 4-tuples: (code for projection, isNull variable name, value variable name, column index) val projectionCodes: Seq[(String, String)] = validExpr.zip(exprVals).map { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 459c1d9a8ba11..ab4c438c62dd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -286,8 +286,9 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro def createCode( ctx: CodegenContext, expressions: Seq[Expression], - useSubexprElimination: Boolean = false): ExprCode = { - val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination) + useSubexprElimination: Boolean = false, + lazyEvaluation: Boolean = false): ExprCode = { + val exprEvals = ctx.generateExpressions(expressions, useSubexprElimination, lazyEvaluation) val exprSchemas = expressions.map(e => Schema(e.dataType, e.nullable)) val numVarLenFields = exprSchemas.count { @@ -323,19 +324,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro def generate( expressions: Seq[Expression], - subexpressionEliminationEnabled: Boolean): UnsafeProjection = { - create(canonicalize(expressions), subexpressionEliminationEnabled) + subexpressionEliminationEnabled: Boolean, + lazyEvaluation: Boolean = false): UnsafeProjection = { + create(canonicalize(expressions), subexpressionEliminationEnabled, lazyEvaluation) } protected def create(references: Seq[Expression]): UnsafeProjection = { - create(references, subexpressionEliminationEnabled = false) + create(references, subexpressionEliminationEnabled = false, lazyEvaluation = false) } private def create( expressions: Seq[Expression], - subexpressionEliminationEnabled: Boolean): UnsafeProjection = { + subexpressionEliminationEnabled: Boolean, + lazyEvaluation: Boolean): UnsafeProjection = { val ctx = newCodeGenContext() - val eval = createCode(ctx, expressions, subexpressionEliminationEnabled) + val eval = createCode(ctx, expressions, subexpressionEliminationEnabled, false) val codeBody = s""" diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 44b6aa6b6271f..fedb67876b53c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -535,8 +535,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { Add(BoundReference(colIndex, DoubleType, true), BoundReference(numOfExprs + colIndex, DoubleType, true)))) // these should not fail to compile due to 64K limit - GenerateUnsafeProjection.generate(exprs, true) - GenerateMutableProjection.generate(exprs, true) + GenerateUnsafeProjection.generate(exprs, true, false) + GenerateMutableProjection.generate(exprs, true, false) + GenerateUnsafeProjection.generate(exprs, true, true) + GenerateMutableProjection.generate(exprs, true, true) } test("SPARK-32624: Use CodeGenerator.typeName() to fix byte[] compile issue") {