Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2182,6 +2182,15 @@ object SQLConf {
.booleanConf
.createWithDefault(true)

val ENABLE_BUILD_SIDE_OUTER_SHUFFLED_HASH_JOIN_CODEGEN =
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't entirely sure if this is the right approach, but the existing flag was specific to FULL OUTER JOIN, and in discussions such as #41398 (comment), "build-side outer join" is taken to mean an outer join where only one side is OUTER and that side is the build side.

buildConf("spark.sql.codegen.join.buildSideOuterShuffledHashJoin.enabled")
.internal()
.doc("When true, enable code-gen for an OUTER shuffled hash join where outer side" +
" is the build side.")
.version("3.5.0")
.booleanConf
.createWithDefault(true)

val ENABLE_FULL_OUTER_SORT_MERGE_JOIN_CODEGEN =
buildConf("spark.sql.codegen.join.fullOuterSortMergeJoin.enabled")
.internal()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -340,8 +340,10 @@ case class ShuffledHashJoinExec(

override def supportCodegen: Boolean = joinType match {
case FullOuter => conf.getConf(SQLConf.ENABLE_FULL_OUTER_SHUFFLED_HASH_JOIN_CODEGEN)
case LeftOuter if buildSide == BuildLeft => false
case RightOuter if buildSide == BuildRight => false
case LeftOuter if buildSide == BuildLeft =>
conf.getConf(SQLConf.ENABLE_BUILD_SIDE_OUTER_SHUFFLED_HASH_JOIN_CODEGEN)
case RightOuter if buildSide == BuildRight =>
conf.getConf(SQLConf.ENABLE_BUILD_SIDE_OUTER_SHUFFLED_HASH_JOIN_CODEGEN)
case _ => true
}

Expand All @@ -362,9 +364,15 @@ case class ShuffledHashJoinExec(
}

override def doProduce(ctx: CodegenContext): String = {
// Specialize `doProduce` code for full outer join, because full outer join needs to
// iterate streamed and build side separately.
if (joinType != FullOuter) {
// Specialize `doProduce` code for full outer join and build-side outer join,
// because we need to iterate streamed and build side separately.
val specializedProduce = joinType match {
case FullOuter => true
case LeftOuter if buildSide == BuildLeft => true
case RightOuter if buildSide == BuildRight => true
case _ => false
}
if (!specializedProduce) {
return super.doProduce(ctx)
}

Expand Down Expand Up @@ -407,21 +415,24 @@ case class ShuffledHashJoinExec(
case BuildLeft => buildResultVars ++ streamedResultVars
case BuildRight => streamedResultVars ++ buildResultVars
}
val consumeFullOuterJoinRow = ctx.freshName("consumeFullOuterJoinRow")
ctx.addNewFunction(consumeFullOuterJoinRow,
val consumeOuterJoinRow = ctx.freshName("consumeOuterJoinRow")
ctx.addNewFunction(consumeOuterJoinRow,
s"""
|private void $consumeFullOuterJoinRow() throws java.io.IOException {
|private void $consumeOuterJoinRow() throws java.io.IOException {
| ${metricTerm(ctx, "numOutputRows")}.add(1);
| ${consume(ctx, resultVars)}
|}
""".stripMargin)

val joinWithUniqueKey = codegenFullOuterJoinWithUniqueKey(
val isFullOuterJoin = joinType == FullOuter
val joinWithUniqueKey = codegenBuildSideOrFullOuterJoinWithUniqueKey(
ctx, (streamedRow, buildRow), (streamedInput, buildInput), streamedKeyEv, streamedKeyAnyNull,
streamedKeyExprCode.value, relationTerm, conditionCheck, consumeFullOuterJoinRow)
val joinWithNonUniqueKey = codegenFullOuterJoinWithNonUniqueKey(
streamedKeyExprCode.value, relationTerm, conditionCheck, consumeOuterJoinRow,
isFullOuterJoin)
val joinWithNonUniqueKey = codegenBuildSideOrFullOuterJoinNonUniqueKey(
ctx, (streamedRow, buildRow), (streamedInput, buildInput), streamedKeyEv, streamedKeyAnyNull,
streamedKeyExprCode.value, relationTerm, conditionCheck, consumeFullOuterJoinRow)
streamedKeyExprCode.value, relationTerm, conditionCheck, consumeOuterJoinRow,
isFullOuterJoin)

s"""
|if ($keyIsUnique) {
Expand All @@ -433,10 +444,10 @@ case class ShuffledHashJoinExec(
}

/**
* Generates the code for full outer join with unique join keys.
* This is code-gen version of `fullOuterJoinWithUniqueKey()`.
* Generates the code for build-side or full outer join with unique join keys.
* This is code-gen version of `buildSideOrFullOuterJoinUniqueKey()`.
*/
private def codegenFullOuterJoinWithUniqueKey(
private def codegenBuildSideOrFullOuterJoinWithUniqueKey(
ctx: CodegenContext,
rows: (String, String),
inputs: (String, String),
Expand All @@ -445,7 +456,8 @@ case class ShuffledHashJoinExec(
streamedKeyValue: ExprValue,
relationTerm: String,
conditionCheck: String,
consumeFullOuterJoinRow: String): String = {
consumeOuterJoinRow: String,
isFullOuterJoin: Boolean): String = {
// Inline mutable state since not many join operations in a task
val matchedKeySetClsName = classOf[BitSet].getName
val matchedKeySet = ctx.addMutableState(matchedKeySetClsName, "matchedKeySet",
Expand Down Expand Up @@ -484,7 +496,10 @@ case class ShuffledHashJoinExec(
| }
| }
|
| $consumeFullOuterJoinRow();
| if ($foundMatch || $isFullOuterJoin) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we can statically remove or simplify the condition:

val consumeCode = if (isFullOuterJoin) {
  s"$consumeOuterJoinRow();"
} else {
  s"""
    |if ($foundMatch) {
    |  $consumeOuterJoinRow();
    |}
  """
}

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, let me make some follow-ups a bit later.

| $consumeOuterJoinRow();
| }
|
| if (shouldStop()) return;
|}
""".stripMargin
Expand All @@ -500,7 +515,7 @@ case class ShuffledHashJoinExec(
| // check if key index is not in matched keys set
| if (!$matchedKeySet.get($rowWithIndex.getKeyIndex())) {
| $buildRow = $rowWithIndex.getValue();
| $consumeFullOuterJoinRow();
| $consumeOuterJoinRow();
| }
|
| if (shouldStop()) return;
Expand All @@ -514,10 +529,10 @@ case class ShuffledHashJoinExec(
}

/**
* Generates the code for full outer join with non-unique join keys.
* This is code-gen version of `fullOuterJoinWithNonUniqueKey()`.
* Generates the code for build-side or full outer join with non-unique join keys.
* This is code-gen version of `buildSideOrFullOuterJoinNonUniqueKey()`.
*/
private def codegenFullOuterJoinWithNonUniqueKey(
private def codegenBuildSideOrFullOuterJoinNonUniqueKey(
ctx: CodegenContext,
rows: (String, String),
inputs: (String, String),
Expand All @@ -526,7 +541,8 @@ case class ShuffledHashJoinExec(
streamedKeyValue: ExprValue,
relationTerm: String,
conditionCheck: String,
consumeFullOuterJoinRow: String): String = {
consumeOuterJoinRow: String,
isFullOuterJoin: Boolean): String = {
// Inline mutable state since not many join operations in a task
val matchedRowSetClsName = classOf[OpenHashSet[_]].getName
val matchedRowSet = ctx.addMutableState(matchedRowSetClsName, "matchedRowSet",
Expand Down Expand Up @@ -572,13 +588,15 @@ case class ShuffledHashJoinExec(
| // set row index in matched row set
| $matchedRowSet.add($rowIndex);
| $foundMatch = true;
| $consumeFullOuterJoinRow();
| $consumeOuterJoinRow();
| }
| }
|
| if (!$foundMatch) {
| $buildRow = null;
| $consumeFullOuterJoinRow();
| if ($isFullOuterJoin) {
| $consumeOuterJoinRow();
| }
| }
|
| if (shouldStop()) return;
Expand All @@ -603,7 +621,7 @@ case class ShuffledHashJoinExec(
| // check if row index is not in matched row set
| if (!$matchedRowSet.contains($rowIndex)) {
| $buildRow = $rowWithIndex.getValue();
| $consumeFullOuterJoinRow();
| $consumeOuterJoinRow();
| }
|
| if (shouldStop()) return;
Expand Down
146 changes: 76 additions & 70 deletions sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1315,78 +1315,84 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan

test("SPARK-36612: Support left outer join build left or right outer join build right in " +
"shuffled hash join") {
val inputDFs = Seq(
// Test unique join key
(spark.range(10).selectExpr("id as k1"),
spark.range(30).selectExpr("id as k2"),
$"k1" === $"k2"),
// Test non-unique join key
(spark.range(10).selectExpr("id % 5 as k1"),
spark.range(30).selectExpr("id % 5 as k2"),
$"k1" === $"k2"),
// Test empty build side
(spark.range(10).selectExpr("id as k1").filter("k1 < -1"),
spark.range(30).selectExpr("id as k2"),
$"k1" === $"k2"),
// Test empty stream side
(spark.range(10).selectExpr("id as k1"),
spark.range(30).selectExpr("id as k2").filter("k2 < -1"),
$"k1" === $"k2"),
// Test empty build and stream side
(spark.range(10).selectExpr("id as k1").filter("k1 < -1"),
spark.range(30).selectExpr("id as k2").filter("k2 < -1"),
$"k1" === $"k2"),
// Test string join key
(spark.range(10).selectExpr("cast(id * 3 as string) as k1"),
spark.range(30).selectExpr("cast(id as string) as k2"),
$"k1" === $"k2"),
// Test build side at right
(spark.range(30).selectExpr("cast(id / 3 as string) as k1"),
spark.range(10).selectExpr("cast(id as string) as k2"),
$"k1" === $"k2"),
// Test NULL join key
(spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr("value as k1"),
spark.range(30).map(i => if (i % 4 == 0) i else null).selectExpr("value as k2"),
$"k1" === $"k2"),
(spark.range(10).map(i => if (i % 3 == 0) i else null).selectExpr("value as k1"),
spark.range(30).map(i => if (i % 5 == 0) i else null).selectExpr("value as k2"),
$"k1" === $"k2"),
// Test multiple join keys
(spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr(
"value as k1", "cast(value % 5 as short) as k2", "cast(value * 3 as long) as k3"),
spark.range(30).map(i => if (i % 3 == 0) i else null).selectExpr(
"value as k4", "cast(value % 5 as short) as k5", "cast(value * 3 as long) as k6"),
$"k1" === $"k4" && $"k2" === $"k5" && $"k3" === $"k6")
)

// test left outer with left side build
inputDFs.foreach { case (df1, df2, joinExprs) =>
val smjDF = df1.hint("SHUFFLE_MERGE").join(df2, joinExprs, "leftouter")
assert(collect(smjDF.queryExecution.executedPlan) {
case _: SortMergeJoinExec => true }.size === 1)
val smjResult = smjDF.collect()

val shjDF = df1.hint("SHUFFLE_HASH").join(df2, joinExprs, "leftouter")
assert(collect(shjDF.queryExecution.executedPlan) {
case _: ShuffledHashJoinExec => true
}.size === 1)
// Same result between shuffled hash join and sort merge join
checkAnswer(shjDF, smjResult)
}
Seq("true", "false").foreach{ codegen =>
withSQLConf(SQLConf.ENABLE_BUILD_SIDE_OUTER_SHUFFLED_HASH_JOIN_CODEGEN.key -> codegen) {
val inputDFs = Seq(
// Test unique join key
(spark.range(10).selectExpr("id as k1"),
spark.range(30).selectExpr("id as k2"),
$"k1" === $"k2"),
// Test non-unique join key
(spark.range(10).selectExpr("id % 5 as k1"),
spark.range(30).selectExpr("id % 5 as k2"),
$"k1" === $"k2"),
// Test empty build side
(spark.range(10).selectExpr("id as k1").filter("k1 < -1"),
spark.range(30).selectExpr("id as k2"),
$"k1" === $"k2"),
// Test empty stream side
(spark.range(10).selectExpr("id as k1"),
spark.range(30).selectExpr("id as k2").filter("k2 < -1"),
$"k1" === $"k2"),
// Test empty build and stream side
(spark.range(10).selectExpr("id as k1").filter("k1 < -1"),
spark.range(30).selectExpr("id as k2").filter("k2 < -1"),
$"k1" === $"k2"),
// Test string join key
(spark.range(10).selectExpr("cast(id * 3 as string) as k1"),
spark.range(30).selectExpr("cast(id as string) as k2"),
$"k1" === $"k2"),
// Test build side at right
(spark.range(30).selectExpr("cast(id / 3 as string) as k1"),
spark.range(10).selectExpr("cast(id as string) as k2"),
$"k1" === $"k2"),
// Test NULL join key
(spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr("value as k1"),
spark.range(30).map(i => if (i % 4 == 0) i else null).selectExpr("value as k2"),
$"k1" === $"k2"),
(spark.range(10).map(i => if (i % 3 == 0) i else null).selectExpr("value as k1"),
spark.range(30).map(i => if (i % 5 == 0) i else null).selectExpr("value as k2"),
$"k1" === $"k2"),
// Test multiple join keys
(spark.range(10).map(i => if (i % 2 == 0) i else null).selectExpr(
"value as k1", "cast(value % 5 as short) as k2", "cast(value * 3 as long) as k3"),
spark.range(30).map(i => if (i % 3 == 0) i else null).selectExpr(
"value as k4", "cast(value % 5 as short) as k5", "cast(value * 3 as long) as k6"),
$"k1" === $"k4" && $"k2" === $"k5" && $"k3" === $"k6")
)

// test right outer with right side build
inputDFs.foreach { case (df2, df1, joinExprs) =>
val smjDF = df2.join(df1.hint("SHUFFLE_MERGE"), joinExprs, "rightouter")
assert(collect(smjDF.queryExecution.executedPlan) {
case _: SortMergeJoinExec => true }.size === 1)
val smjResult = smjDF.collect()
// test left outer with left side build
inputDFs.foreach { case (df1, df2, joinExprs) =>
val smjDF = df1.hint("SHUFFLE_MERGE").join(df2, joinExprs, "leftouter")
assert(collect(smjDF.queryExecution.executedPlan) {
case _: SortMergeJoinExec => true
}.size === 1)
val smjResult = smjDF.collect()

val shjDF = df1.hint("SHUFFLE_HASH").join(df2, joinExprs, "leftouter")
assert(collect(shjDF.queryExecution.executedPlan) {
case _: ShuffledHashJoinExec => true
}.size === 1)
// Same result between shuffled hash join and sort merge join
checkAnswer(shjDF, smjResult)
}

val shjDF = df2.join(df1.hint("SHUFFLE_HASH"), joinExprs, "rightouter")
assert(collect(shjDF.queryExecution.executedPlan) {
case _: ShuffledHashJoinExec => true
}.size === 1)
// Same result between shuffled hash join and sort merge join
checkAnswer(shjDF, smjResult)
// test right outer with right side build
inputDFs.foreach { case (df2, df1, joinExprs) =>
val smjDF = df2.join(df1.hint("SHUFFLE_MERGE"), joinExprs, "rightouter")
assert(collect(smjDF.queryExecution.executedPlan) {
case _: SortMergeJoinExec => true
}.size === 1)
val smjResult = smjDF.collect()

val shjDF = df2.join(df1.hint("SHUFFLE_HASH"), joinExprs, "rightouter")
assert(collect(shjDF.queryExecution.executedPlan) {
case _: ShuffledHashJoinExec => true
}.size === 1)
// Same result between shuffled hash join and sort merge join
checkAnswer(shjDF, smjResult)
}
}
}
}

Expand Down
Loading