diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 5a994f1ad0a39..2880e87ab1566 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -410,7 +410,7 @@ object ExtractSingleColumnNullAwareAntiJoin extends JoinSelectionHelper with Pre
    */
   def unapply(join: Join): Option[ReturnType] = join match {
     case Join(left, right, LeftAnti,
-      Some(Or(e @ EqualTo(leftAttr: AttributeReference, rightAttr: AttributeReference),
+      Some(Or(e @ EqualTo(leftAttr: Expression, rightAttr: Expression),
         IsNull(e2 @ EqualTo(_, _)))), _)
         if SQLConf.get.optimizeNullAwareAntiJoin &&
           e.semanticEquals(e2) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
index b160b8ac2ed68..6b56feceb2665 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanExec.scala
@@ -79,7 +79,8 @@ case class AdaptiveSparkPlanExec(
   @transient private val optimizer = new RuleExecutor[LogicalPlan] {
     // TODO add more optimization rules
    override protected def batches: Seq[Batch] = Seq(
-      Batch("Demote BroadcastHashJoin", Once, DemoteBroadcastHashJoin(conf))
+      Batch("Demote BroadcastHashJoin", Once, DemoteBroadcastHashJoin(conf)),
+      Batch("Eliminate Null Aware Anti Join", Once, EliminateNullAwareAntiJoin)
     )
   }
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateNullAwareAntiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateNullAwareAntiJoin.scala
new file mode 100644
index 0000000000000..4e0247e2f4bb5
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/EliminateNullAwareAntiJoin.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.adaptive
+
+import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.joins.EmptyHashedRelationWithAllNullKeys
+
+/**
+ * This optimization rule detects and converts a NAAJ to an empty LocalRelation
+ * when the buildSide is EmptyHashedRelationWithAllNullKeys.
+ */
+object EliminateNullAwareAntiJoin extends Rule[LogicalPlan] {
+
+  private def canEliminate(plan: LogicalPlan): Boolean = plan match {
+    case LogicalQueryStage(_, stage: BroadcastQueryStageExec) if stage.resultOption.get().isDefined
+      && stage.broadcast.relationFuture.get().value == EmptyHashedRelationWithAllNullKeys => true
+    case _ => false
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown {
+    case j @ ExtractSingleColumnNullAwareAntiJoin(_, _) if canEliminate(j.right) =>
+      LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
+  }
+}
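To see the new rule fire end to end, here is a minimal, self-contained sketch (not part of the patch; the session setup, view names, and the expectation about the final plan are illustrative assumptions). It reproduces the single-column NOT IN shape that ExtractSingleColumnNullAwareAntiJoin matches, with a build side containing only NULL keys:

```scala
import org.apache.spark.sql.SparkSession

// Illustrative only: a local session and ad-hoc views, not part of this patch.
object NaajEliminationDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("naaj-elimination-demo")
      .config("spark.sql.adaptive.enabled", "true") // the rule runs inside AQE re-optimization
      .getOrCreate()
    import spark.implicits._

    Seq((1, 1), (2, 2)).toDF("a", "b").createOrReplaceTempView("t1")
    // All build-side keys are NULL, so at runtime the broadcast build side becomes
    // EmptyHashedRelationWithAllNullKeys and the anti join folds to an empty LocalRelation.
    Seq[(Integer, Integer)]((null, null)).toDF("a", "b").createOrReplaceTempView("t2")

    val df = spark.sql("SELECT * FROM t1 WHERE t1.b NOT IN (SELECT b FROM t2)")
    assert(df.collect().isEmpty) // NOT IN against an all-NULL set keeps no rows
    df.explain() // with this patch, the adaptive final plan should show a LocalTableScan, not a join

    spark.stop()
  }
}
```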
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index e4935c8c72228..5df06f8e4d4fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -213,9 +213,11 @@ case class BroadcastHashJoinExec(
     (broadcastRelation, relationTerm)
   }
 
-  protected override def prepareRelation(ctx: CodegenContext): (String, Boolean) = {
+  protected override def prepareRelation(ctx: CodegenContext): HashedRelationInfo = {
     val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
-    (relationTerm, broadcastRelation.value.keyIsUnique)
+    HashedRelationInfo(relationTerm,
+      broadcastRelation.value.keyIsUnique,
+      broadcastRelation.value == EmptyHashedRelation)
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 1c6504b141890..2154e370a1596 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -29,6 +29,16 @@ import org.apache.spark.sql.execution.{CodegenSupport, ExplainUtils, RowIterator
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.{BooleanType, IntegralType, LongType}
 
+/**
+ * @param relationTerm variable name for the HashedRelation
+ * @param keyIsUnique whether the keys of the HashedRelation are known to be unique at code-gen time
+ * @param isEmpty whether the HashedRelation is known to be EmptyHashedRelation at code-gen time
+ */
+private[joins] case class HashedRelationInfo(
+    relationTerm: String,
+    keyIsUnique: Boolean,
+    isEmpty: Boolean)
+
 trait HashJoin extends BaseJoinExec with CodegenSupport {
   def buildSide: BuildSide
 
@@ -270,6 +280,11 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
   private def antiJoin(
      streamIter: Iterator[InternalRow],
      hashedRelation: HashedRelation): Iterator[InternalRow] = {
+    // If the right side is empty, the anti join simply returns the left side.
+    if (hashedRelation == EmptyHashedRelation) {
+      return streamIter
+    }
+
     val joinKeys = streamSideKeyGenerator()
     val joinedRow = new JoinedRow
 
@@ -417,7 +432,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
    * Generates the code for Inner join.
    */
   protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    val (relationTerm, keyIsUnique) = prepareRelation(ctx)
+    val HashedRelationInfo(relationTerm, keyIsUnique, _) = prepareRelation(ctx)
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
     val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
     val numOutput = metricTerm(ctx, "numOutputRows")
@@ -467,7 +482,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
    * Generates the code for left or right outer join.
    */
   protected def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    val (relationTerm, keyIsUnique) = prepareRelation(ctx)
+    val HashedRelationInfo(relationTerm, keyIsUnique, _) = prepareRelation(ctx)
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
     val matched = ctx.freshName("matched")
     val buildVars = genBuildSideVars(ctx, matched)
@@ -544,7 +559,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
    * Generates the code for left semi join.
    */
   protected def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    val (relationTerm, keyIsUnique) = prepareRelation(ctx)
+    val HashedRelationInfo(relationTerm, keyIsUnique, _) = prepareRelation(ctx)
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
     val (matched, checkCondition, _) = getJoinCondition(ctx, input)
     val numOutput = metricTerm(ctx, "numOutputRows")
@@ -593,10 +608,18 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
    * Generates the code for anti join.
    */
   protected def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    val (relationTerm, keyIsUnique) = prepareRelation(ctx)
+    val HashedRelationInfo(relationTerm, keyIsUnique, isEmptyHashedRelation) = prepareRelation(ctx)
+    val numOutput = metricTerm(ctx, "numOutputRows")
+    if (isEmptyHashedRelation) {
+      return s"""
+        |// If the right side is empty, the anti join simply returns the left side.
+        |$numOutput.add(1);
+        |${consume(ctx, input)}
+        |""".stripMargin
+    }
+
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
     val (matched, checkCondition, _) = getJoinCondition(ctx, input)
-    val numOutput = metricTerm(ctx, "numOutputRows")
 
     if (keyIsUnique) {
       val found = ctx.freshName("found")
@@ -654,7 +677,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
    * Generates the code for existence join.
    */
   protected def codegenExistence(ctx: CodegenContext, input: Seq[ExprCode]): String = {
-    val (relationTerm, keyIsUnique) = prepareRelation(ctx)
+    val HashedRelationInfo(relationTerm, keyIsUnique, _) = prepareRelation(ctx)
     val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
     val numOutput = metricTerm(ctx, "numOutputRows")
     val existsVar = ctx.freshName("exists")
@@ -715,12 +738,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
     }
   }
 
-  /**
-   * Returns a tuple of variable name for HashedRelation,
-   * and a boolean to indicate whether keys of HashedRelation
-   * known to be unique in code-gen time.
-   */
-  protected def prepareRelation(ctx: CodegenContext): (String, Boolean)
+  protected def prepareRelation(ctx: CodegenContext): HashedRelationInfo
 }
 
 object HashJoin {
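The two fast paths added above (the `return streamIter` in `antiJoin` and the early `return` in `codegenAnti`) implement the same observation: with an empty build side, no stream row can possibly find a match, so every row passes through. A minimal stand-in in plain Scala, with illustrative types rather than Spark's `InternalRow`/`HashedRelation`:

```scala
// Plain-Scala sketch of the anti-join fast path added above; not the Spark classes.
object AntiJoinFastPathSketch {
  def antiJoin[K, R](streamIter: Iterator[(K, R)], build: Map[K, R]): Iterator[R] = {
    // Mirrors `if (hashedRelation == EmptyHashedRelation) return streamIter`.
    if (build.isEmpty) return streamIter.map(_._2)
    // Normal path: keep only stream rows whose key has no build-side match.
    streamIter.collect { case (k, row) if !build.contains(k) => row }
  }

  def main(args: Array[String]): Unit = {
    val stream = Iterator(1 -> "a", 2 -> "b", 3 -> "c")
    assert(antiJoin(stream, Map.empty[Int, String]).toSeq == Seq("a", "b", "c"))
  }
}
```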
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index f2835c2fa6626..0d40520ae71a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -108,7 +108,7 @@ private[execution] object HashedRelation {
         0)
     }
 
-    if (isNullAware && !input.hasNext) {
+    if (!input.hasNext) {
       EmptyHashedRelation
     } else if (key.length == 1 && key.head.dataType == LongType) {
       LongHashedRelation(input, key, sizeEstimate, mm, isNullAware)
@@ -950,8 +950,18 @@ trait NullAwareHashedRelation extends HashedRelation with Externalizable {
 
 /**
  * A special HashedRelation indicates it built from a empty input:Iterator[InternalRow].
+ * get and getValue will return null, just like
+ * an empty LongHashedRelation or an empty UnsafeHashedRelation would.
  */
 object EmptyHashedRelation extends NullAwareHashedRelation {
+  override def get(key: Long): Iterator[InternalRow] = null
+
+  override def get(key: InternalRow): Iterator[InternalRow] = null
+
+  override def getValue(key: Long): InternalRow = null
+
+  override def getValue(key: InternalRow): InternalRow = null
+
   override def asReadOnlyCopy(): EmptyHashedRelation.type = this
 }
 
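The overrides above matter because the probe side of a hash join treats a null lookup result as "no matching build rows". A plain-Scala sketch of that contract (illustrative trait and names, not the actual Spark interfaces):

```scala
// The empty relation answers null for every key, and probe code checks for null
// before consuming matches, so an empty build side quietly produces no matches.
trait Relation { def get(key: Long): Iterator[String] }

object EmptyRelation extends Relation {
  override def get(key: Long): Iterator[String] = null // same convention as EmptyHashedRelation
}

object ProbeContractDemo {
  // Simplified semi-join probe: keep stream keys that find at least one build match.
  def semiProbe(streamKeys: Seq[Long], rel: Relation): Seq[Long] =
    streamKeys.filter { k =>
      val matches = rel.get(k)
      matches != null && matches.hasNext
    }

  def main(args: Array[String]): Unit = {
    assert(semiProbe(Seq(1L, 2L, 3L), EmptyRelation).isEmpty)
  }
}
```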
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
index 9f811cddef6a7..41cefd03dd931 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoinExec.scala
@@ -81,13 +81,13 @@ case class ShuffledHashJoinExec(
 
   override def needCopyResult: Boolean = true
 
-  protected override def prepareRelation(ctx: CodegenContext): (String, Boolean) = {
+  protected override def prepareRelation(ctx: CodegenContext): HashedRelationInfo = {
     val thisPlan = ctx.addReferenceObj("plan", this)
     val clsName = classOf[HashedRelation].getName
     // Inline mutable state since not many join operations in a task
     val relationTerm = ctx.addMutableState(clsName, "relation",
       v => s"$v = $thisPlan.buildHashedRelation(inputs[1]);", forceInline = true)
-    (relationTerm, false)
+    HashedRelationInfo(relationTerm, keyIsUnique = false, isEmpty = false)
   }
 }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
index 7fdcbd0d089cc..d3fb63e407f1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
 import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
 import org.apache.spark.sql.execution.command.DataWritingCommandExec
 import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
+import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, SortMergeJoinExec}
 import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
@@ -100,6 +100,12 @@ class AdaptiveQueryExecSuite
     }
   }
 
+  private def findTopLevelBaseJoin(plan: SparkPlan): Seq[BaseJoinExec] = {
+    collect(plan) {
+      case j: BaseJoinExec => j
+    }
+  }
+
   private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = {
     collectWithSubqueries(plan) {
       case ShuffleQueryStageExec(_, e: ReusedExchangeExec) => e
@@ -1148,4 +1154,18 @@
       }
     }
   }
+
+  test("SPARK-32573: Eliminate NAAJ when BuildSide is EmptyHashedRelationWithAllNullKeys") {
+    withSQLConf(
+      SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
+      SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString) {
+      val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
+        "SELECT * FROM testData2 t1 WHERE t1.b NOT IN (SELECT b FROM testData3)")
+      val bhj = findTopLevelBroadcastHashJoin(plan)
+      assert(bhj.size == 1)
+      val join = findTopLevelBaseJoin(adaptivePlan)
+      assert(join.isEmpty)
+      checkNumLocalShuffleReaders(adaptivePlan)
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index 21ee88f0d7426..8b270bd5a2636 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -580,4 +580,16 @@ class HashedRelationSuite extends SharedSparkSession {
       assert(proj(packedKeys).get(0, dt) == -i - 1)
     }
   }
+
+  test("EmptyHashedRelation returns null in get / getValue") {
+    val buildKey = Seq(BoundReference(0, LongType, false))
+    val hashed = HashedRelation(Seq.empty[InternalRow].toIterator, buildKey, 1, mm)
+    assert(hashed == EmptyHashedRelation)
+
+    val key = InternalRow(1L)
+    assert(hashed.get(0L) == null)
+    assert(hashed.get(key) == null)
+    assert(hashed.getValue(0L) == null)
+    assert(hashed.getValue(key) == null)
+  }
 }
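For completeness, the reason EmptyHashedRelationWithAllNullKeys lets the whole join be eliminated is SQL's three-valued logic: `x NOT IN (S)` can never evaluate to TRUE once `S` contains a NULL. A small self-contained check, modeling NULL as `None` in plain Scala (a restatement of the semantics, not Spark code):

```scala
// Three-valued logic sketch: Some(true)/Some(false)/None stand for TRUE/FALSE/UNKNOWN.
object ThreeValuedNotInDemo {
  def sqlEquals(a: Option[Int], b: Option[Int]): Option[Boolean] =
    for (x <- a; y <- b) yield x == y // any NULL operand yields UNKNOWN

  // x IN (s1, s2, ...) is a three-valued OR over the equality results.
  def sqlIn(x: Option[Int], set: Seq[Option[Int]]): Option[Boolean] =
    set.map(sqlEquals(x, _)).foldLeft(Option(false)) {
      case (Some(true), _) | (_, Some(true)) => Some(true)
      case (None, _) | (_, None) => None
      case _ => Some(false)
    }

  def main(args: Array[String]): Unit = {
    val allNullSet = Seq(Option.empty[Int])
    for (x <- Seq(Some(1), Some(2), Option.empty[Int])) {
      val notIn = sqlIn(x, allNullSet).map(!_) // NOT of UNKNOWN is still UNKNOWN
      assert(notIn != Some(true)) // never TRUE, so the anti join keeps no rows
    }
  }
}
```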