Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ object ExtractSingleColumnNullAwareAntiJoin extends JoinSelectionHelper with Pre
*/
def unapply(join: Join): Option[ReturnType] = join match {
case Join(left, right, LeftAnti,
Some(Or(e @ EqualTo(leftAttr: AttributeReference, rightAttr: AttributeReference),
Some(Or(e @ EqualTo(leftAttr: Expression, rightAttr: Expression),
IsNull(e2 @ EqualTo(_, _)))), _)
if SQLConf.get.optimizeNullAwareAntiJoin &&
e.semanticEquals(e2) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ case class AdaptiveSparkPlanExec(
@transient private val optimizer = new RuleExecutor[LogicalPlan] {
// TODO add more optimization rules
override protected def batches: Seq[Batch] = Seq(
Batch("Demote BroadcastHashJoin", Once, DemoteBroadcastHashJoin(conf))
Batch("Demote BroadcastHashJoin", Once, DemoteBroadcastHashJoin(conf)),
Batch("Eliminate Null Aware Anti Join", Once, EliminateNullAwareAntiJoin)
)
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJoin
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.joins.EmptyHashedRelationWithAllNullKeys

/**
* This optimization rule detects and convert a NAAJ to an Empty LocalRelation
* when buildSide is EmptyHashedRelationWithAllNullKeys.
*/
object EliminateNullAwareAntiJoin extends Rule[LogicalPlan] {

private def canEliminate(plan: LogicalPlan): Boolean = plan match {
case LogicalQueryStage(_, stage: BroadcastQueryStageExec) if stage.resultOption.get().isDefined
&& stage.broadcast.relationFuture.get().value == EmptyHashedRelationWithAllNullKeys => true
case _ => false
}

def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown {
case j @ ExtractSingleColumnNullAwareAntiJoin(_, _) if canEliminate(j.right) =>
LocalRelation(j.output, data = Seq.empty, isStreaming = j.isStreaming)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,11 @@ case class BroadcastHashJoinExec(
(broadcastRelation, relationTerm)
}

protected override def prepareRelation(ctx: CodegenContext): (String, Boolean) = {
protected override def prepareRelation(ctx: CodegenContext): HashedRelationInfo = {
val (broadcastRelation, relationTerm) = prepareBroadcast(ctx)
(relationTerm, broadcastRelation.value.keyIsUnique)
HashedRelationInfo(relationTerm,
broadcastRelation.value.keyIsUnique,
broadcastRelation.value == EmptyHashedRelation)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@ import org.apache.spark.sql.execution.{CodegenSupport, ExplainUtils, RowIterator
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.types.{BooleanType, IntegralType, LongType}

/**
* @param relationTerm variable name for HashedRelation
* @param keyIsUnique indicate whether keys of HashedRelation known to be unique in code-gen time
* @param isEmpty indicate whether it known to be EmptyHashedRelation in code-gen time
*/
private[joins] case class HashedRelationInfo(
relationTerm: String,
keyIsUnique: Boolean,
isEmpty: Boolean)

trait HashJoin extends BaseJoinExec with CodegenSupport {
def buildSide: BuildSide

Expand Down Expand Up @@ -270,6 +280,11 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
private def antiJoin(
streamIter: Iterator[InternalRow],
hashedRelation: HashedRelation): Iterator[InternalRow] = {
// If the right side is empty, AntiJoin simply returns the left side.
if (hashedRelation == EmptyHashedRelation) {
return streamIter
}

val joinKeys = streamSideKeyGenerator()
val joinedRow = new JoinedRow

Expand Down Expand Up @@ -417,7 +432,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
* Generates the code for Inner join.
*/
protected def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val (relationTerm, keyIsUnique) = prepareRelation(ctx)
val HashedRelationInfo(relationTerm, keyIsUnique, _) = prepareRelation(ctx)
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input)
val numOutput = metricTerm(ctx, "numOutputRows")
Expand Down Expand Up @@ -467,7 +482,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
* Generates the code for left or right outer join.
*/
protected def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val (relationTerm, keyIsUnique) = prepareRelation(ctx)
val HashedRelationInfo(relationTerm, keyIsUnique, _) = prepareRelation(ctx)
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
val matched = ctx.freshName("matched")
val buildVars = genBuildSideVars(ctx, matched)
Expand Down Expand Up @@ -544,7 +559,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
* Generates the code for left semi join.
*/
protected def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val (relationTerm, keyIsUnique) = prepareRelation(ctx)
val HashedRelationInfo(relationTerm, keyIsUnique, _) = prepareRelation(ctx)
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
val (matched, checkCondition, _) = getJoinCondition(ctx, input)
val numOutput = metricTerm(ctx, "numOutputRows")
Expand Down Expand Up @@ -593,10 +608,18 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
* Generates the code for anti join.
*/
protected def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val (relationTerm, keyIsUnique) = prepareRelation(ctx)
val HashedRelationInfo(relationTerm, keyIsUnique, isEmptyHashedRelation) = prepareRelation(ctx)
val numOutput = metricTerm(ctx, "numOutputRows")
if (isEmptyHashedRelation) {
return s"""
|// If the right side is empty, Anti Join simply returns the left side.
|$numOutput.add(1);
|${consume(ctx, input)}
|""".stripMargin
}

val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
val (matched, checkCondition, _) = getJoinCondition(ctx, input)
val numOutput = metricTerm(ctx, "numOutputRows")

if (keyIsUnique) {
val found = ctx.freshName("found")
Expand Down Expand Up @@ -654,7 +677,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
* Generates the code for existence join.
*/
protected def codegenExistence(ctx: CodegenContext, input: Seq[ExprCode]): String = {
val (relationTerm, keyIsUnique) = prepareRelation(ctx)
val HashedRelationInfo(relationTerm, keyIsUnique, _) = prepareRelation(ctx)
val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input)
val numOutput = metricTerm(ctx, "numOutputRows")
val existsVar = ctx.freshName("exists")
Expand Down Expand Up @@ -715,12 +738,7 @@ trait HashJoin extends BaseJoinExec with CodegenSupport {
}
}

/**
* Returns a tuple of variable name for HashedRelation,
* and a boolean to indicate whether keys of HashedRelation
* known to be unique in code-gen time.
*/
protected def prepareRelation(ctx: CodegenContext): (String, Boolean)
protected def prepareRelation(ctx: CodegenContext): HashedRelationInfo
}

object HashJoin {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ private[execution] object HashedRelation {
0)
}

if (isNullAware && !input.hasNext) {
if (!input.hasNext) {
EmptyHashedRelation
} else if (key.length == 1 && key.head.dataType == LongType) {
LongHashedRelation(input, key, sizeEstimate, mm, isNullAware)
Expand Down Expand Up @@ -950,8 +950,18 @@ trait NullAwareHashedRelation extends HashedRelation with Externalizable {

/**
* A special HashedRelation indicates it built from a empty input:Iterator[InternalRow].
* get & getValue will return null just like
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW where do we call these methods for EmptyHashedRelation?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes. It will be called. EmptyHashedRelation now is also applied at all joinType. I confirmed it while running UT.

* empty LongHashedRelation or empty UnsafeHashedRelation does.
*/
object EmptyHashedRelation extends NullAwareHashedRelation {
override def get(key: Long): Iterator[InternalRow] = null

override def get(key: InternalRow): Iterator[InternalRow] = null

override def getValue(key: Long): InternalRow = null

override def getValue(key: InternalRow): InternalRow = null

override def asReadOnlyCopy(): EmptyHashedRelation.type = this
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ case class ShuffledHashJoinExec(

override def needCopyResult: Boolean = true

protected override def prepareRelation(ctx: CodegenContext): (String, Boolean) = {
protected override def prepareRelation(ctx: CodegenContext): HashedRelationInfo = {
val thisPlan = ctx.addReferenceObj("plan", this)
val clsName = classOf[HashedRelation].getName

// Inline mutable state since not many join operations in a task
val relationTerm = ctx.addMutableState(clsName, "relation",
v => s"$v = $thisPlan.buildHashedRelation(inputs[1]);", forceInline = true)
(relationTerm, false)
HashedRelationInfo(relationTerm, keyIsUnique = false, isEmpty = false)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan}
import org.apache.spark.sql.execution.{PartialReducerPartitionSpec, ReusedSubqueryExec, ShuffledRowRDD, SparkPlan}
import org.apache.spark.sql.execution.command.DataWritingCommandExec
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, Exchange, ReusedExchangeExec, ShuffleExchangeExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.joins.{BaseJoinExec, BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.execution.ui.SparkListenerSQLAdaptiveExecutionUpdate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
Expand Down Expand Up @@ -100,6 +100,12 @@ class AdaptiveQueryExecSuite
}
}

private def findTopLevelBaseJoin(plan: SparkPlan): Seq[BaseJoinExec] = {
collect(plan) {
case j: BaseJoinExec => j
}
}

private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = {
collectWithSubqueries(plan) {
case ShuffleQueryStageExec(_, e: ReusedExchangeExec) => e
Expand Down Expand Up @@ -1148,4 +1154,18 @@ class AdaptiveQueryExecSuite
}
}
}

test("SPARK-32573: Eliminate NAAJ when BuildSide is EmptyHashedRelationWithAllNullKeys") {
withSQLConf(
SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true",
SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> Long.MaxValue.toString) {
val (plan, adaptivePlan) = runAdaptiveAndVerifyResult(
"SELECT * FROM testData2 t1 WHERE t1.b NOT IN (SELECT b FROM testData3)")
val bhj = findTopLevelBroadcastHashJoin(plan)
assert(bhj.size == 1)
val join = findTopLevelBaseJoin(adaptivePlan)
assert(join.isEmpty)
checkNumLocalShuffleReaders(adaptivePlan)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -580,4 +580,16 @@ class HashedRelationSuite extends SharedSparkSession {
assert(proj(packedKeys).get(0, dt) == -i - 1)
}
}

test("EmptyHashedRelation return null in get / getValue") {
val buildKey = Seq(BoundReference(0, LongType, false))
val hashed = HashedRelation(Seq.empty[InternalRow].toIterator, buildKey, 1, mm)
assert(hashed == EmptyHashedRelation)

val key = InternalRow(1L)
assert(hashed.get(0L) == null)
assert(hashed.get(key) == null)
assert(hashed.getValue(0L) == null)
assert(hashed.getValue(key) == null)
}
}