diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveJoinOrientation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveJoinOrientation.scala new file mode 100644 index 0000000000000..8131f6d58f504 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveJoinOrientation.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.adaptive + +import org.apache.spark.sql.catalyst.plans.InnerLike +import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf + +/** + * This optimization rule detects if the probe side of a SortMerge join is smaller + * than the build side. While joining, the probe side is streamed and the build side is + * buffered, so having a larger build side can cause memory issues. If that is the case, + * this rule swaps the join children and adds a Project on top to preserve the original + * output attribute order.
+ */
+case class AdaptiveJoinOrientation(conf: SQLConf) extends Rule[LogicalPlan] {
+
+  // A side qualifies only once its shuffle stage has finished materializing, so
+  // that accurate runtime statistics are available.
+  // NOTE(review): assumes `resultOption.get` yields an Option (AtomicReference-style
+  // accessor on QueryStageExec) — confirm against this tree's QueryStageExec API.
+  private def isMaterializedShuffleStage(plan: LogicalPlan): Boolean = plan match {
+    case LogicalQueryStage(_, stage: ShuffleQueryStageExec) =>
+      stage.resultOption.get.isDefined
+    case _ => false
+  }
+
+  // For an inner join whose two sides are both materialized shuffle stages and both
+  // above the broadcast threshold, swap the children when the left (streamed) side
+  // is the smaller one. The Project on top restores the original output order.
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
+    case join @ Join(probe, build, _: InnerLike, _, _)
+        if isMaterializedShuffleStage(probe) && isMaterializedShuffleStage(build) &&
+          Seq(probe, build).forall(_.stats.sizeInBytes > conf.autoBroadcastJoinThreshold) &&
+          probe.stats.sizeInBytes < build.stats.sizeInBytes =>
+      Project(join.output, join.copy(left = build, right = probe))
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index bc924e6978ddc..1f7a22705b972 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -79,7 +79,8 @@ case class AdaptiveSparkPlanExec(
   @transient private val optimizer = new RuleExecutor[LogicalPlan] {
     // TODO add more optimization rules
     override protected def batches: Seq[Batch] = Seq(
-      Batch("Demote BroadcastHashJoin", Once, DemoteBroadcastHashJoin(conf))
+      Batch("Demote BroadcastHashJoin", Once, DemoteBroadcastHashJoin(conf)),
+      Batch("Orient SortMergeJoin", Once, AdaptiveJoinOrientation(conf))
     )
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index c696d3f648ed1..a616882f2b51d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -25,6 +25,7 @@
import org.apache.log4j.Level import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent, SparkListenerJobStart} import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession, Strategy} import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight} +import org.apache.spark.sql.catalyst.plans.InnerLike import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan} import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan} import org.apache.spark.sql.execution.command.DataWritingCommandExec @@ -1147,4 +1148,62 @@ class AdaptiveQueryExecSuite } } } + + test("SPARK-SMJ_ORIENTATION: reorient SMJ using adaptive stats") { + withTable("testTbl1", "testTbl2") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + + def getProbeBuildSides(plan: SparkPlan): Option[(SparkPlan, SparkPlan)] = { + plan collectFirst { + case SortMergeJoinExec(_, _, _: InnerLike, _, left, right, _) => + (left, right) + } + } + + def validateOrientation(initialPlan: SparkPlan, updatedPlan: SparkPlan, + orderShouldBeChanged: Boolean): Unit = { + val (initialProbePlan, initialBuildPlan) = getProbeBuildSides(initialPlan).get + val (updatedProbePlan, updatedBuildPlan) = getProbeBuildSides(updatedPlan).get + if (orderShouldBeChanged) { + assert(updatedBuildPlan.output == initialProbePlan.output) + assert(updatedProbePlan.output == initialBuildPlan.output) + } else { + assert(updatedBuildPlan.output == initialBuildPlan.output) + assert(updatedProbePlan.output == initialProbePlan.output) + } + } + + val df1 = (0 until 10).toDF("col1").as("df1") + df1.write.format("parquet").saveAsTable("testTbl1") + + val df2 = (0 until 100).toDF("col1").as("df2") + df2.write.format("parquet").saveAsTable("testTbl2") + + val dfWrongOrder = + spark.sql("SELECT * from testTbl1 JOIN testTbl2 ON testTbl1.col1 = testTbl2.col1") + val dfCorrectOrder = + 
spark.sql("SELECT * from testTbl2 JOIN testTbl1 ON testTbl1.col1 = testTbl2.col1") + + val initialAdaptivePlan1 = dfWrongOrder.queryExecution.executedPlan + val initialPlan1 = initialAdaptivePlan1.asInstanceOf[AdaptiveSparkPlanExec].executedPlan + dfWrongOrder.collect + val adaptivePlan1 = dfWrongOrder.queryExecution.executedPlan + val updatedPlan1 = adaptivePlan1.asInstanceOf[AdaptiveSparkPlanExec].executedPlan + // Join orientation should have changed + validateOrientation(initialPlan1, updatedPlan1, true) + // Project should correct the result order + assert(initialPlan1.output == adaptivePlan1.output) + + val initialAdaptivePlan2 = dfCorrectOrder.queryExecution.executedPlan + val initialPlan2 = initialAdaptivePlan2.asInstanceOf[AdaptiveSparkPlanExec].executedPlan + dfCorrectOrder.collect + val adaptivePlan2 = dfCorrectOrder.queryExecution.executedPlan + val updatedPlan2 = adaptivePlan2.asInstanceOf[AdaptiveSparkPlanExec].executedPlan + // Join orientation should not have changed + validateOrientation(initialPlan2, updatedPlan2, false) + } + } + } }