From 92dc106aec59a0f2755d7621d2d03831250cccc0 Mon Sep 17 00:00:00 2001 From: donnyzone Date: Mon, 17 Jul 2017 21:06:18 +0800 Subject: [PATCH 1/6] Passing the joined row for condition evaluation --- .../apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 26fb6103953fc..c33f5a3561e28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -560,6 +560,9 @@ case class SortMergeJoinExec( val rightRow = ctx.freshName("rightRow") val rightVars = createRightVar(ctx, rightRow) + // Create variable for joinned row + val joinedRow = ctx.freshName("joinedRow") + val iterator = ctx.freshName("iterator") val numOutput = metricTerm(ctx, "numOutputRows") val (beforeLoop, condCheck) = if (condition.isDefined) { @@ -569,6 +572,7 @@ case class SortMergeJoinExec( val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) // Generate code for condition ctx.currentVars = leftVars ++ rightVars + ctx.INPUT_ROW = joinedRow val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) // evaluate the columns those used by condition before loop val before = s""" @@ -597,6 +601,7 @@ case class SortMergeJoinExec( | scala.collection.Iterator $iterator = $matches.generateIterator(); | while ($iterator.hasNext()) { | InternalRow $rightRow = (InternalRow) $iterator.next(); + | InternalRow $joinedRow = new JoinedRow($leftRow, $rightRow); | ${condCheck.trim} | $numOutput.add(1); | ${consume(ctx, leftVars ++ rightVars)} From 931c4f9c497c3344fd0fd91440340cb0bb8b3080 Mon Sep 17 00:00:00 2001 From: donnyzone Date: Wed, 19 Jul 2017 10:10:55 +0800 Subject: [PATCH 2/6] Disable codegen for SortMergeJoinExec 
with CodegenFallback expressions --- .../apache/spark/sql/execution/WholeStageCodegenExec.scala | 6 +++--- .../spark/sql/execution/joins/SortMergeJoinExec.scala | 5 ----- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index ac30b11557adb..dda62bf558308 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -489,13 +489,13 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { * Inserts an InputAdapter on top of those that do not support codegen. */ private def insertInputAdapter(plan: SparkPlan): SparkPlan = plan match { + case p if !supportCodegen(p) => + // collapse them recursively + InputAdapter(insertWholeStageCodegen(p)) case j @ SortMergeJoinExec(_, _, _, _, left, right) if j.supportCodegen => // The children of SortMergeJoin should do codegen separately. 
j.copy(left = InputAdapter(insertWholeStageCodegen(left)), right = InputAdapter(insertWholeStageCodegen(right))) - case p if !supportCodegen(p) => - // collapse them recursively - InputAdapter(insertWholeStageCodegen(p)) case p => p.withNewChildren(p.children.map(insertInputAdapter)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index c33f5a3561e28..26fb6103953fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -560,9 +560,6 @@ case class SortMergeJoinExec( val rightRow = ctx.freshName("rightRow") val rightVars = createRightVar(ctx, rightRow) - // Create variable for joinned row - val joinedRow = ctx.freshName("joinedRow") - val iterator = ctx.freshName("iterator") val numOutput = metricTerm(ctx, "numOutputRows") val (beforeLoop, condCheck) = if (condition.isDefined) { @@ -572,7 +569,6 @@ case class SortMergeJoinExec( val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) // Generate code for condition ctx.currentVars = leftVars ++ rightVars - ctx.INPUT_ROW = joinedRow val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) // evaluate the columns those used by condition before loop val before = s""" @@ -601,7 +597,6 @@ case class SortMergeJoinExec( | scala.collection.Iterator $iterator = $matches.generateIterator(); | while ($iterator.hasNext()) { | InternalRow $rightRow = (InternalRow) $iterator.next(); - | InternalRow $joinedRow = new JoinedRow($leftRow, $rightRow); | ${condCheck.trim} | $numOutput.add(1); | ${consume(ctx, leftVars ++ rightVars)} From 1161ffdf879b287f03065ac9e3661ffddd3b64f3 Mon Sep 17 00:00:00 2001 From: donnyzone Date: Wed, 19 Jul 2017 11:05:24 +0800 Subject: [PATCH 3/6] Avoid verifying supportCodegen again --- 
.../org/apache/spark/sql/execution/WholeStageCodegenExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index dda62bf558308..27ce7b2db1b70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -492,7 +492,7 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case p if !supportCodegen(p) => // collapse them recursively InputAdapter(insertWholeStageCodegen(p)) - case j @ SortMergeJoinExec(_, _, _, _, left, right) if j.supportCodegen => + case j @ SortMergeJoinExec(_, _, _, _, left, right) => // The children of SortMergeJoin should do codegen separately. j.copy(left = InputAdapter(insertWholeStageCodegen(left)), right = InputAdapter(insertWholeStageCodegen(right))) From b55cab315a3e7187f0c962bd78fe0ef71d08abc8 Mon Sep 17 00:00:00 2001 From: donnyzone Date: Wed, 19 Jul 2017 17:46:30 +0800 Subject: [PATCH 4/6] Add a test case --- .../execution/WholeStageCodegenSuite.scala | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index a4b30a2f8cec1..b1f0516e563f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -127,4 +127,25 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { "named_struct('a',id+2, 'b',id+2) as col2") .filter("col1 = col2").count() } + + test("SPARK-21441 SortMergeJoin codegen with CodegenFallback expressions should be disabled") { + 
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1") { + import testImplicits._ + + val df1 = Seq((1, 1), (2, 2), (3, 3)).toDF("key", "int") + val df2 = Seq((1, "1"), (2, "2"), (3, "3")).toDF("key", "str") + + // join condition contains CodegenFallback expression (i.e., reflect) + val df = df1.join(df2, df1("key") === df2("key")) + .filter("int = 2 or reflect('java.lang.Integer', 'valueOf', str) = 1") + .select("int") + + val plan = df.queryExecution.executedPlan + assert(!plan.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].child.children(0) + .isInstanceOf[SortMergeJoinExec]).isDefined) + assert(df.collect() === Array(Row(1), Row(2))) + } + } } From e2d8cefa17f741fc3d94da540c40b31d715bace1 Mon Sep 17 00:00:00 2001 From: donnyzone Date: Wed, 19 Jul 2017 17:49:35 +0800 Subject: [PATCH 5/6] Add imports required by the SortMergeJoin codegen test case --- .../org/apache/spark/sql/execution/WholeStageCodegenSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index b1f0516e563f6..e3ca988c94d7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -18,10 +18,12 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{Column, Dataset, Row} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack} import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec +import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions.{avg, broadcast, col, max} import
org.apache.spark.sql.test.SharedSQLContext From b53bf1c97b2e69dab0ac4afb9e293fe75577b057 Mon Sep 17 00:00:00 2001 From: donnyzone Date: Wed, 19 Jul 2017 18:49:22 +0800 Subject: [PATCH 6/6] fix scala style --- .../apache/spark/sql/execution/WholeStageCodegenSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index e3ca988c94d7b..183c68fd3c016 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.{Column, Dataset, Row} -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack} import org.apache.spark.sql.execution.aggregate.HashAggregateExec @@ -26,6 +25,7 @@ import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions.{avg, broadcast, col, max} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -137,7 +137,6 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val df1 = Seq((1, 1), (2, 2), (3, 3)).toDF("key", "int") val df2 = Seq((1, "1"), (2, "2"), (3, "3")).toDF("key", "str") - // join condition contains CodegenFallback expression (i.e., reflect) val df = df1.join(df2, df1("key") === df2("key")) .filter("int = 2 or reflect('java.lang.Integer', 'valueOf', str) = 1") .select("int")