diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala index 7a0aa08289efa..76733dd6dac3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala @@ -41,7 +41,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging { */ def isWatermarkInJoinKeys(plan: LogicalPlan): Boolean = { plan match { - case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _) => + case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _, _) => (leftKeys ++ rightKeys).exists { case a: AttributeReference => a.metadata.contains(EventTimeWatermark.delayKey) case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index cff4cee09427f..fad5e0574e503 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -18,9 +18,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, CurrentDate, CurrentTimestamp, MonotonicallyIncreasingID} +import org.apache.spark.sql.catalyst.expressions.{Attribute, CurrentDate, CurrentTimestamp, MonotonicallyIncreasingID} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 84be677e438a6..241873bbfe706 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.catalyst.planning +import scala.collection.mutable + import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.internal.SQLConf /** * A pattern that matches any number of project or filter operations on top of another relational @@ -98,9 +101,10 @@ object PhysicalOperation extends PredicateHelper { * value). 
*/ object ExtractEquiJoinKeys extends Logging with PredicateHelper { - /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */ + /** (joinType, leftKeys, rightKeys, rangeConditions, condition, leftChild, rightChild) */ type ReturnType = - (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan) + (JoinType, Seq[Expression], Seq[Expression], Seq[BinaryComparison], + Option[Expression], LogicalPlan, LogicalPlan) def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { case join @ Join(left, right, joinType, condition) => @@ -132,13 +136,97 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { if (joinKeys.nonEmpty) { val (leftKeys, rightKeys) = joinKeys.unzip - logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys") - Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) + // Find any simple range expressions between two columns + // (and involving only those two columns) of the two tables being joined, + // which are not used in the equijoin expressions, + // and which can be used for secondary sort optimizations. + // rangePreds will contain the original expressions to be filtered out later. + val rangePreds = mutable.Set.empty[Expression] + var rangeConditions: Seq[BinaryComparison] = + if (SQLConf.get.useSmjInnerRangeOptimization) { + otherPredicates.flatMap { + case p@LessThan(l, r) => checkRangeConditions(l, r, left, right, joinKeys).map { + case true => rangePreds.add(p); GreaterThan(r, l) + case false => rangePreds.add(p); p + } + case p@LessThanOrEqual(l, r) => + checkRangeConditions(l, r, left, right, joinKeys).map { + case true => rangePreds.add(p); GreaterThanOrEqual(r, l) + case false => rangePreds.add(p); p + } + case p@GreaterThan(l, r) => checkRangeConditions(l, r, left, right, joinKeys).map { + case true => rangePreds.add(p); LessThan(r, l) + case false => rangePreds.add(p); p + } + case p@GreaterThanOrEqual(l, r) => + checkRangeConditions(l, r, left, right, joinKeys).map { + case true => rangePreds.add(p); LessThanOrEqual(r, l) + case false => rangePreds.add(p); p + } + case _ => None + } + } else { + Nil + } + + // Only using secondary join optimization when both lower and upper conditions + // are specified (e.g. t1.a < t2.b + x and t1.a > t2.b - x) + if (rangeConditions.size != 2 || + // Looking for one < and one > comparison: + rangeConditions.forall(x => !x.isInstanceOf[LessThan] && + !x.isInstanceOf[LessThanOrEqual]) || + rangeConditions.forall(x => !x.isInstanceOf[GreaterThan] && + !x.isInstanceOf[GreaterThanOrEqual]) || + // Check if both comparisons reference the same columns: + rangeConditions.flatMap(c => c.left.references.toSeq.distinct).distinct.size != 1 || + rangeConditions.flatMap(c => c.right.references.toSeq.distinct).distinct.size != 1) { + logDebug("Inner range optimization conditions not met. Clearing range conditions") + rangeConditions = Nil + rangePreds.clear() + } + + Some((joinType, leftKeys, rightKeys, rangeConditions, + otherPredicates.filterNot(rangePreds.contains(_)).reduceOption(And), left, right)) } else { None } case _ => None } + + /** + * Checks if l and r are valid range conditions: + * - l and r expressions should both contain a single reference to one and the same column + * - the referenced column should not be part of joinKeys + * If these conditions are not met, the function returns None. + * + * Otherwise, the function checks if the left plan contains l expression and the right plan + * contains r expression. 
If the expressions need to be switched, the function returns Some(true) + * and Some(false) otherwise. + */ + private def checkRangeConditions(l : Expression, r : Expression, + left : LogicalPlan, right : LogicalPlan, + joinKeys : Seq[(Expression, Expression)]): Option[Boolean] = { + val (lattrs, rattrs) = (l.references.toSeq, r.references.toSeq) + if (lattrs.size != 1 || rattrs.size != 1) { + None + } else if (canEvaluate(l, left) && canEvaluate(r, right)) { + if (joinKeys.exists { case (ljk : Expression, rjk : Expression) => + ljk.references.toSeq.contains(lattrs(0)) && rjk.references.toSeq.contains(rattrs(0)) }) { + None + } else { + Some(false) + } + } else if (canEvaluate(l, right) && canEvaluate(r, left)) { + if (joinKeys.exists{ case (ljk : Expression, rjk : Expression) => + rjk.references.toSeq.contains(lattrs(0)) && ljk.references.toSeq.contains(rattrs(0)) }) { + None + } else { + Some(true) + } + } else { + None + } + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 2543e38a92c0a..19a0d1279cc32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -56,7 +56,7 @@ case class JoinEstimation(join: Join) extends Logging { case _ if !rowCountsExist(join.left, join.right) => None - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) => + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _, _) => // 1. Compute join selectivity val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys) val (numInnerJoinedRows, keyStatsAfterJoin) = computeCardinalityAndStats(joinKeyPairs) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 86e068bf632bd..5050f519bf926 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1492,6 +1492,19 @@ object SQLConf { .intConf .createWithDefault(ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) + val USE_SMJ_INNER_RANGE_OPTIMIZATION = + buildConf("spark.sql.join.smj.useInnerRangeOptimization") + .internal() + .doc("Sort-merge join 'inner range optimization' is applicable in cases where the join " + + "condition includes equality expressions on pairs of columns and a range expression " + + "involving two other columns, (e.g. t1.x = t2.x AND t1.y BETWEEN t2.y - d AND t2.y + d)." 
+ + " If the inner range optimization is enabled, the number of rows considered for each " + + "match of equality conditions can be reduced considerably because a moving window, " + + "corresponding to the range conditions, will be used for iterating over matched rows " + + "in the right relation.") + .booleanConf + .createWithDefault(true) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1818,6 +1831,8 @@ class SQLConf extends Serializable with Logging { def topKSortFallbackThreshold: Int = getConf(TOP_K_SORT_FALLBACK_THRESHOLD) + def useSmjInnerRangeOptimization: Boolean = getConf(USE_SMJ_INNER_RANGE_OPTIMIZATION) + def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala index ac282ea2e94f5..999f83db93311 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import java.util.ConcurrentModificationException -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, Queue} import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.internal.Logging @@ -41,12 +41,16 @@ import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, Unsaf * - If [[numRowsSpillThreshold]] is too low, data will be spilled frequently and lead to * excessive disk writes. This may lead to a performance regression compared to the normal case * of using an [[ArrayBuffer]] or [[Array]]. + * + * If [[asQueue]] is set to true, the class will function as a queue, supporting peek() and + * dequeue() operations. 
*/ private[sql] class ExternalAppendOnlyUnsafeRowArray( taskMemoryManager: TaskMemoryManager, blockManager: BlockManager, serializerManager: SerializerManager, taskContext: TaskContext, + asQueue: Boolean, initialSize: Int, pageSizeBytes: Long, numRowsInMemoryBufferThreshold: Int, @@ -58,6 +62,20 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( SparkEnv.get.blockManager, SparkEnv.get.serializerManager, TaskContext.get(), + false, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + numRowsInMemoryBufferThreshold, + numRowsSpillThreshold) + } + + def this(numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int, asQueue: Boolean) { + this( + TaskContext.get().taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get(), + asQueue, 1024, SparkEnv.get.memoryManager.pageSizeBytes, numRowsInMemoryBufferThreshold, @@ -67,7 +85,13 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( private val initialSizeOfInMemoryBuffer = Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsInMemoryBufferThreshold) - private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) { + private val inMemoryQueue = if (asQueue && initialSizeOfInMemoryBuffer > 0) { + new Queue[UnsafeRow]() + } else { + null + } + + private val inMemoryBuffer = if (!asQueue && initialSizeOfInMemoryBuffer > 0) { new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer) } else { null @@ -76,6 +100,9 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( private var spillableArray: UnsafeExternalSorter = _ private var numRows = 0 + // Used when functioning as a queue to allow skipping 'dequeued' items + private var spillableArrayOffset = 0 + // A counter to keep track of total modifications done to this array since its creation. // This helps to invalidate iterators when there are changes done to the backing array. 
private var modificationsCount: Long = 0 @@ -95,17 +122,60 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( // inside `UnsafeExternalSorter` spillableArray.cleanupResources() spillableArray = null + spillableArrayOffset = 0 } else if (inMemoryBuffer != null) { inMemoryBuffer.clear() + } else if (inMemoryQueue != null) { + inMemoryQueue.clear() } numFieldsPerRow = 0 numRows = 0 modificationsCount += 1 } + def dequeue(): Option[UnsafeRow] = { + if (!asQueue) { + throw new IllegalStateException("Not instantiated as a queue!") + } + if (numRows == 0) { + None + } + else if (spillableArray != null) { + val retval = Some(generateIterator().next) + numRows -= 1 + modificationsCount += 1 + spillableArrayOffset += 1 + retval + } + else { + numRows -= 1 + modificationsCount += 1 + Some(inMemoryQueue.dequeue()) + } + } + + def peek(): Option[UnsafeRow] = { + if (!asQueue) { + throw new IllegalStateException("Not instantiated as a queue!") + } + if (numRows == 0) { + None + } + else if (spillableArray != null) { + Some(generateIterator().next) + } + else { + Some(inMemoryQueue(0)) + } + } + def add(unsafeRow: UnsafeRow): Unit = { - if (numRows < numRowsInMemoryBufferThreshold) { - inMemoryBuffer += unsafeRow.copy() + if (spillableArray == null && numRows < numRowsInMemoryBufferThreshold) { + if (asQueue) { + inMemoryQueue += unsafeRow.copy() + } else { + inMemoryBuffer += unsafeRow.copy() + } } else { if (spillableArray == null) { logInfo(s"Reached spill threshold of $numRowsInMemoryBufferThreshold rows, switching to " + @@ -124,8 +194,21 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( numRowsSpillThreshold, false) + spillableArrayOffset = 0 + // populate with existing in-memory buffered rows - if (inMemoryBuffer != null) { + if (asQueue && inMemoryQueue != null) { + inMemoryQueue.foreach(existingUnsafeRow => + spillableArray.insertRecord( + existingUnsafeRow.getBaseObject, + existingUnsafeRow.getBaseOffset, + existingUnsafeRow.getSizeInBytes, + 0, + false) + ) + inMemoryQueue.clear() + } + if (!asQueue && inMemoryBuffer != null) { inMemoryBuffer.foreach(existingUnsafeRow => spillableArray.insertRecord( existingUnsafeRow.getBaseObject, @@ -168,7 +251,8 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( if (spillableArray == null) { new InMemoryBufferIterator(startIndex) } else { - new SpillableArrayIterator(spillableArray.getIterator(startIndex), numFieldsPerRow) + new SpillableArrayIterator(spillableArray.getIterator(startIndex + spillableArrayOffset), + numFieldsPerRow) } } @@ -198,7 +282,7 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( override def next(): UnsafeRow = { throwExceptionIfModified() - val result = inMemoryBuffer(currentIndex) + val result = if (asQueue) inMemoryQueue(currentIndex) else inMemoryBuffer(currentIndex) currentIndex += 1 result } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index dbc6db62bd820..7db6992a58edb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -241,41 +241,45 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- BroadcastHashJoin -------------------------------------------------------------------- // broadcast hints were specified - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, 
rightKeys, rangeConds, condition, left, right) if canBroadcastByHints(joinType, left, right) => val buildSide = broadcastSideByHints(joinType, left, right) + val cond = (rangeConds ++ condition.map(x => Seq(x)).getOrElse(Nil)).reduceOption(And) Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, joinType, buildSide, cond, planLater(left), planLater(right))) // broadcast hints were not specified, so need to infer it from size and configuration. - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, rangeConds, condition, left, right) if canBroadcastBySizes(joinType, left, right) => val buildSide = broadcastSideBySizes(joinType, left, right) + val cond = (rangeConds ++ condition.map(x => Seq(x)).getOrElse(Nil)).reduceOption(And) Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, joinType, buildSide, cond, planLater(left), planLater(right))) // --- ShuffledHashJoin --------------------------------------------------------------------- - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, rangeConds, condition, left, right) if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right) && muchSmaller(right, left) || !RowOrdering.isOrderable(leftKeys) => + val cond = (rangeConds ++ condition.map(x => Seq(x)).getOrElse(Nil)).reduceOption(And) Seq(joins.ShuffledHashJoinExec( - leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, joinType, BuildRight, cond, planLater(left), planLater(right))) - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, rangeConds, condition, left, right) if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left) && muchSmaller(left, right) || !RowOrdering.isOrderable(leftKeys) => + val cond = (rangeConds ++ condition.map(x => Seq(x)).getOrElse(Nil)).reduceOption(And) Seq(joins.ShuffledHashJoinExec( - leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, joinType, BuildLeft, cond, planLater(left), planLater(right))) // --- SortMergeJoin ------------------------------------------------------------ - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, rangeConds, condition, left, right) if RowOrdering.isOrderable(leftKeys) => - joins.SortMergeJoinExec( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + joins.SortMergeJoinExec(leftKeys, rightKeys, joinType, rangeConds, condition, + planLater(left), planLater(right)) :: Nil // --- Without joining keys ------------------------------------------------------------ @@ -380,11 +384,13 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object StreamingJoinStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = { plan match { - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, rangePreds, condition, left, right) if left.isStreaming && right.isStreaming => new 
StreamingSymmetricHashJoinExec( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + leftKeys, rightKeys, joinType, + (rangePreds ++ condition.map(x => Seq(x)).getOrElse(Nil)).reduceOption(And), + planLater(left), planLater(right)) :: Nil case Join(left, right, _, _) if left.isStreaming && right.isStreaming => throw new AnalysisException( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index d2d5011bbcb97..be6b4b89f60e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -284,10 +284,12 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, left, right) - case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) => + case SortMergeJoinExec(leftKeys, rightKeys, joinType, rangeConditions, + condition, left, right) => val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) - SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) + SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, rangeConditions, + condition, left, right) case other => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index d7d3f6d6078b4..f7c56132b45cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.collection.BitSet /** @@ -37,9 +38,48 @@ case class SortMergeJoinExec( leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, + rangeConditions: Seq[BinaryComparison], condition: Option[Expression], left: SparkPlan, - right: SparkPlan) extends BinaryExecNode with CodegenSupport { + right: SparkPlan) extends BinaryExecNode with PredicateHelper with CodegenSupport { + + logDebug(s"SortMergeJoinExec args: leftKeys: $leftKeys, rightKeys: $rightKeys, " + + s"joinType: $joinType," + + s" rangeConditions: $rangeConditions, " + + s"condition: $condition, left: $left, right: $right") + + private val lowerSecondaryRangeExpression : Option[Expression] = { + logDebug(s"Finding secondary greaterThan expressions in $rangeConditions") + val thefind = rangeConditions.find(p => + p.isInstanceOf[GreaterThan] || p.isInstanceOf[GreaterThanOrEqual]) + logDebug(s"Found secondary greaterThan expression: $thefind") + thefind + } + private val upperSecondaryRangeExpression : Option[Expression] = { + logDebug(s"Finding secondary lowerThan expressions in $rangeConditions") + val thefind = rangeConditions.find(p => + p.isInstanceOf[LessThan] || p.isInstanceOf[LessThanOrEqual]) + logDebug(s"Found secondary lowerThan expression: $thefind") + thefind + } + + val useInnerRange = 
SQLConf.get.useSmjInnerRangeOptimization && + (lowerSecondaryRangeExpression.isDefined || upperSecondaryRangeExpression.isDefined) + + logDebug(s"Use secondary range join resolved to $useInnerRange.") + + val lrKeys = if (useInnerRange) { + rangeConditions.flatMap(c => c.left.references.toSeq).distinct + } + else { + Nil + } + val rrKeys = if (useInnerRange) { + rangeConditions.flatMap(c => c.right.references.toSeq).distinct + } + else { + Nil + } override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -82,9 +122,10 @@ case class SortMergeJoinExec( override def outputOrdering: Seq[SortOrder] = joinType match { // For inner join, orders of both sides keys should be kept. + case _: InnerLike => - val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering) - val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering) + val leftKeyOrdering = getKeyOrdering(leftKeys ++ lrKeys, left.outputOrdering) + val rightKeyOrdering = getKeyOrdering(rightKeys ++ rrKeys, right.outputOrdering) leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) => // Also add the right key and its `sameOrderExpressions` SortOrder(lKey.child, Ascending, lKey.sameOrderExpressions + rKey.child ++ rKey @@ -122,7 +163,7 @@ case class SortMergeJoinExec( } override def requiredChildOrdering: Seq[Seq[SortOrder]] = - requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil + requiredOrders(leftKeys ++ lrKeys) :: requiredOrders(rightKeys ++ rrKeys) :: Nil private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. @@ -155,6 +196,20 @@ case class SortMergeJoinExec( (r: InternalRow) => true } } + val lowerRangeCondition: (InternalRow) => Boolean = { + lowerSecondaryRangeExpression.map { cond => + newPredicate(cond, left.output ++ right.output).eval _ + }.getOrElse { + (r: InternalRow) => true + } + } + val upperRangeCondition: (InternalRow) => Boolean = { + upperSecondaryRangeExpression.map { cond => + newPredicate(cond, left.output ++ right.output).eval _ + }.getOrElse { + (r: InternalRow) => true + } + } // An ordering that can be used to compare keys from both sides. 
val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) @@ -166,15 +221,31 @@ case class SortMergeJoinExec( private[this] var currentLeftRow: InternalRow = _ private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _ private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null - private[this] val smjScanner = new SortMergeJoinScanner( - createLeftKeyGenerator(), - createRightKeyGenerator(), - keyOrdering, - RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter), - inMemoryThreshold, - spillThreshold - ) + private[this] val smjScanner = + if (useInnerRange) { + new SortMergeJoinInnerRangeScanner( + createLeftKeyGenerator(), + createRightKeyGenerator(), + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter), + inMemoryThreshold, + spillThreshold, + lowerRangeCondition, + upperRangeCondition + ) + } + else { + new SortMergeJoinScanner( + createLeftKeyGenerator(), + createRightKeyGenerator(), + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter), + inMemoryThreshold, + spillThreshold + ) + } private[this] val joinRow = new JoinedRow if (smjScanner.findNextInnerJoinRows()) { @@ -206,7 +277,9 @@ case class SortMergeJoinExec( false } - override def getRow: InternalRow = resultProj(joinRow) + override def getRow: InternalRow = { + resultProj(joinRow) + } }.toScala case LeftOuter => @@ -421,10 +494,15 @@ case class SortMergeJoinExec( * matched one row from left side and buffered rows from right side. */ private def genScanner(ctx: CodegenContext): (String, String) = { + logInfo("SortMergeJoinE xec: generating inner range join scanner") // Create class member for next row from both sides. // Inline mutable state since not many join operations in a task val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) + val rightTmpRow = if (useInnerRange) { + ctx.addMutableState("InternalRow", "rightTmpRow", forceInline = true) + } + else { "" } // Create variables for join keys from both sides. val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) @@ -434,18 +512,143 @@ case class SortMergeJoinExec( // Copy the right key as class members so they could be used in next function call. val rightKeyVars = copyKeys(ctx, rightKeyTmpVars) - // A list to hold all matched rows from right side. 
+ val rangeKeys = rangeConditions.map{ + case GreaterThan(l, r) => (Some(l), None, Some(r), None) + case GreaterThanOrEqual(l, r) => (Some(l), None, Some(r), None) + case LessThan(l, r) => (None, Some(l), None, Some(r)) + case LessThanOrEqual(l, r) => (None, Some(l), None, Some(r)) + } + val (leftLowerKeys, leftUpperKeys, rightLowerKeys, rightUpperKeys) = + (rangeKeys.map(_._1).flatMap(x => x), + rangeKeys.map(_._2).flatMap(x => x), + rangeKeys.map(_._3).flatMap(x => x), + rangeKeys.map(_._4).flatMap(x => x)) + + // Variables for secondary range expressions + val (leftLowerKeyVars, leftUpperKeyVars, rightLowerKeyVars, rightUpperKeyVars) = + if (useInnerRange) { + (createJoinKey(ctx, leftRow, leftLowerKeys, left.output), + createJoinKey(ctx, leftRow, leftUpperKeys, left.output), + createJoinKey(ctx, rightRow, rightLowerKeys, right.output), + createJoinKey(ctx, rightRow, rightUpperKeys, right.output)) + } + else { + (Nil, Nil, Nil, Nil) + } + + val secRangeDataType = if (leftLowerKeys.size > 0) { leftLowerKeys(0).dataType } + else if (leftUpperKeys.size > 0) { leftUpperKeys(0).dataType } + else null + val secRangeInitValue = CodeGenerator.defaultValue(secRangeDataType) + + val (leftLowerSecRangeKey, leftUpperSecRangeKey, rightLowerSecRangeKey, rightUpperSecRangeKey) = + if (useInnerRange) { + (ctx.addBufferedState(secRangeDataType, "leftLowerSecRangeKey", secRangeInitValue), + ctx.addBufferedState(secRangeDataType, "leftUpperSecRangeKey", secRangeInitValue), + ctx.addBufferedState(secRangeDataType, "rightLowerSecRangeKey", secRangeInitValue), + ctx.addBufferedState(secRangeDataType, "rightUpperSecRangeKey", secRangeInitValue)) + } + else { + (null, null, null, null) + } + + // A queue to hold all matched rows from right side. val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName val spillThreshold = getSpillThreshold val inMemoryThreshold = getInMemoryThreshold - // Inline mutable state since not many join operations in a task - val matches = ctx.addMutableState(clsName, "matches", + val matches = if (useInnerRange) ctx.addMutableState(clsName, "matches", + v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold, true);", forceInline = true) + else ctx.addMutableState(clsName, "matches", v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) - // Copy the left keys as class members so they could be used in next function call. 
val matchedKeyVars = copyKeys(ctx, leftKeyVars) + val lowerCompop = lowerSecondaryRangeExpression.map { + case GreaterThanOrEqual(_, _) => "<" + case GreaterThan(_, _) => "<=" + case _ => "" + }.getOrElse("") + val upperCompop = upperSecondaryRangeExpression.map { + case LessThanOrEqual(_, _) => ">" + case LessThan(_, _) => ">=" + case _ => "" + }.getOrElse("") + val lowerCompExp = if (!useInnerRange || lowerSecondaryRangeExpression.isEmpty) "" + else s" || (comp == 0 && ${leftLowerSecRangeKey.value} " + + s"$lowerCompop ${rightLowerSecRangeKey.value})" + val upperCompExp = if (!useInnerRange || upperSecondaryRangeExpression.isEmpty) "" + else s" || (comp == 0 && ${leftUpperSecRangeKey.value} " + + s"$upperCompop ${rightUpperSecRangeKey.value})" + + logDebug(s"lowerCompExp: $lowerCompExp") + logDebug(s"upperCompExp: $upperCompExp") + + // Add secondary range dequeue method + if (!useInnerRange || lowerSecondaryRangeExpression.isEmpty || + rightLowerKeys.size == 0 || rightUpperKeys.size == 0) { + ctx.addNewFunction("dequeueUntilUpperConditionHolds", + "private void dequeueUntilUpperConditionHolds() { }") + } + else { + val rightRngTmpKeyVars = createJoinKey(ctx, rightTmpRow, + rightUpperKeys.slice(0, 1), right.output) + val rightRngTmpKeyVarsDecl = rightRngTmpKeyVars.map(_.code).mkString("\n") + rightRngTmpKeyVars.foreach(_.code = EmptyBlock) + val javaType = CodeGenerator.javaType(rightLowerKeys(0).dataType) + + ctx.addNewFunction("getRightTmpRangeValue", + s""" + |private $javaType getRightTmpRangeValue() { + | $rightRngTmpKeyVarsDecl + | return ${rightRngTmpKeyVars(0).value}; + |} + """.stripMargin) + + ctx.addNewFunction("dequeueUntilUpperConditionHolds", + s""" + |private void dequeueUntilUpperConditionHolds() { + | if($matches.isEmpty()) { + | $matches.clear(); + | return; + | } + | $rightTmpRow = (InternalRow) $matches.peek().get(); + | $javaType tempVal = getRightTmpRangeValue(); + | while(${leftLowerSecRangeKey.value} $upperCompop tempVal) { + | $matches.dequeue(); + | if($matches.isEmpty()) { + | $matches.clear(); + | break; + | } + | $rightTmpRow = (InternalRow) $matches.peek().get(); + | tempVal = getRightTmpRangeValue(); + | } + |} + """.stripMargin) + } + val (leftLowVarsCode, leftUpperVarsCode) = if (useInnerRange) { + (leftLowerKeyVars.map(_.code).mkString("\n"), leftUpperKeyVars.map(_.code).mkString("\n")) + } + else { ("", "") } + val (rightLowVarsCode, rightUpperVarsCode) = if (useInnerRange) { + (rightLowerKeyVars.map(_.code).mkString("\n"), rightUpperKeyVars.map(_.code).mkString("\n")) + } + else { ("", "") } + val (leftLowAssignCode, rightLowAssignCode) = if (leftLowerKeyVars.size > 0) { + lowerSecondaryRangeExpression.map(_ => + (s"${leftLowerSecRangeKey.value} = ${leftLowerKeyVars(0).value};", + s"${rightLowerSecRangeKey.value} = ${rightLowerKeyVars(0).value};")). + getOrElse(("", "")) + } + else { ("", "") } + val (leftUpperAssignCode, rightUpperAssignCode) = if (leftUpperKeyVars.size > 0) { + lowerSecondaryRangeExpression.map(_ => + (s"${leftUpperSecRangeKey.value} = ${leftUpperKeyVars(0).value};", + s"${rightUpperSecRangeKey.value} = ${rightUpperKeyVars(0).value};")). 
+ getOrElse(("", "")) + } + else { ("", "") } + ctx.addNewFunction("findNextInnerJoinRows", s""" |private boolean findNextInnerJoinRows( @@ -461,12 +664,17 @@ case class SortMergeJoinExec( | $leftRow = null; | continue; | } + | $leftLowVarsCode + | $leftUpperVarsCode + | $leftLowAssignCode + | $leftUpperAssignCode | if (!$matches.isEmpty()) { | ${genComparison(ctx, leftKeyVars, matchedKeyVars)} | if (comp == 0) { - | return true; + | dequeueUntilUpperConditionHolds(); + | } else { + | $matches.clear(); | } - | $matches.clear(); | } | | do { @@ -482,14 +690,20 @@ case class SortMergeJoinExec( | continue; | } | ${rightKeyVars.map(_.code).mkString("\n")} + | $rightLowVarsCode + | $rightUpperVarsCode + | $rightLowAssignCode + | $rightUpperAssignCode | } | ${genComparison(ctx, leftKeyVars, rightKeyVars)} - | if (comp > 0) { + | if (comp > 0 $upperCompExp) { | $rightRow = null; - | } else if (comp < 0) { + | } else if (comp < 0 $lowerCompExp) { | if (!$matches.isEmpty()) { | ${matchedKeyVars.map(_.code).mkString("\n")} | return true; + | } else { + | $matches.clear(); | } | $leftRow = null; | } else { @@ -500,7 +714,7 @@ case class SortMergeJoinExec( | } | return false; // unreachable |} - """.stripMargin, inlineToOuterClass = true) + """.stripMargin, inlineToOuterClass = true) (leftRow, matches) } @@ -561,9 +775,10 @@ case class SortMergeJoinExec( */ private def splitVarsByCondition( attributes: Seq[Attribute], - variables: Seq[ExprCode]): (String, String) = { - if (condition.isDefined) { - val condRefs = condition.get.references + variables: Seq[ExprCode], + cond: Option[Expression]): (String, String) = { + if (cond.isDefined) { + val condRefs = cond.get.references val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) => condRefs.contains(a) } @@ -596,8 +811,8 @@ case class SortMergeJoinExec( val (beforeLoop, condCheck) = if (condition.isDefined) { // Split the code of creating variables based on whether it's used by condition or not. val loaded = ctx.freshName("loaded") - val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars) - val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) + val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars, condition) + val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars, condition) // Generate code for condition ctx.currentVars = leftVars ++ rightVars val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) @@ -653,7 +868,7 @@ case class SortMergeJoinExec( * * @param streamedKeyGenerator a projection that produces join keys from the streamed input. * @param bufferedKeyGenerator a projection that produces join keys from the buffered input. - * @param keyOrdering an ordering which can be used to compare join keys. + * @param keyOrdering an ordering which can be used to compare join keys * @param streamedIter an input whose rows will be streamed. * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that * have the same join key. @@ -697,7 +912,7 @@ private[joins] class SortMergeJoinScanner( * [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join * results. */ - final def findNextInnerJoinRows(): Boolean = { + def findNextInnerJoinRows(): Boolean = { while (advancedStreamed() && streamedRowKey.anyNull) { // Advance the streamed side of the join until we find the next row whose join key contains // no nulls or we hit the end of the streamed iterator. 
@@ -751,7 +966,7 @@ private[joins] class SortMergeJoinScanner( * then [[getStreamedRow]] and [[getBufferedMatches]] can be called to produce the outer * join results. */ - final def findNextOuterJoinRows(): Boolean = { + def findNextOuterJoinRows(): Boolean = { if (!advancedStreamed()) { // We have consumed the entire streamed iterator, so there can be no more matches. matchJoinKey = null @@ -841,6 +1056,241 @@ private[joins] class SortMergeJoinScanner( } } +/** + * Helper class that is used to implement [[SortMergeJoinExec]]. + * + * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]] + * which returns `true` if a result has been produced and `false` + * otherwise. If a result has been produced, then the caller may call [[getStreamedRow]] to return + * the matching row from the streamed input and may call [[getBufferedMatches]] to return the + * sequence of matching rows from the buffered input (in the case of an outer join, this will + * return an empty sequence if there are no matches from the buffered input). For efficiency, + * both of these methods return mutable objects which are re-used across calls to + * the `findNext*JoinRows()` methods. + * + * @param streamedKeyGenerator a projection that produces join keys from the streamed input. + * @param bufferedKeyGenerator a projection that produces join keys from the buffered input. + * @param keyOrdering an ordering which can be used to compare join keys. + * @param streamedIter an input whose rows will be streamed. + * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that + * have the same join key. + * @param inMemoryThreshold Threshold for number of rows guaranteed to be held in memory by + * internal buffer + * @param spillThreshold Threshold for number of rows to be spilled by internal buffer + */ +private[joins] class SortMergeJoinInnerRangeScanner( + streamedKeyGenerator: Projection, + bufferedKeyGenerator: Projection, + keyOrdering: Ordering[InternalRow], + streamedIter: RowIterator, + bufferedIter: RowIterator, + inMemoryThreshold: Int, + spillThreshold: Int, + lowerRangeCondition: InternalRow => Boolean, + upperRangeCondition: InternalRow => Boolean) + extends SortMergeJoinScanner(streamedKeyGenerator, bufferedKeyGenerator, keyOrdering, + streamedIter, bufferedIter, inMemoryThreshold, spillThreshold) { + private[this] var streamedRow: InternalRow = _ + private[this] var streamedRowKey: InternalRow = _ + private[this] var bufferedRow: InternalRow = _ + // Note: this is guaranteed to never have any null columns: + private[this] var bufferedRowKey: InternalRow = _ + /** + * The join key for the rows buffered in `bufferedMatches`, or null if `bufferedMatches` is empty + */ + private[this] var matchJoinKey: InternalRow = _ + /** Buffered rows from the buffered side of the join. This is empty if there are no matches. 
*/ + + private[this] val bufferedMatches = + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold, true) + + private[this] val testJoinRow = new JoinedRow + private[this] val joinedRow = new JoinedRow + private[this] var lowerConditionOk: Boolean = false + private[this] var upperConditionOk: Boolean = false + + // Already done in the superclass: + // advancedBufferedToRowWithNullFreeJoinKey() + bufferedRow = bufferedIter.getRow + if (bufferedRow != null) { + bufferedRowKey = bufferedKeyGenerator(bufferedRow) + } + + // --- Public methods --------------------------------------------------------------------------- + + override def getStreamedRow: InternalRow = streamedRow + + override def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches + + /** + * Advances both input iterators, stopping when we have found rows with matching join keys. + * @return true if matching rows have been found and false otherwise. If this returns true, then + * [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join + * results. + */ + override final def findNextInnerJoinRows(): Boolean = { + while (advanceStreamed() && streamedRowKey.anyNull) { + // Advance the streamed side of the join until we find the next row whose join key contains + // no nulls or we hit the end of the streamed iterator. + } + if (streamedRow == null) { + // We have consumed the entire streamed iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { + // The new streamed row has the same join key as the previous row, + // so the same matches can be used. + // But lower and upper ranges might not hold anymore, so check them: + // First dequeue all rows from the queue until the lower range condition holds. + // Then try to enqueue new rows with the same join key and for which the upper + // range condition holds. + dequeueUntilUpperConditionHolds() + if (bufferedRow != null) { + bufferMatchingRows() + } + if (bufferedMatches.isEmpty) { + matchJoinKey = null + findNextInnerJoinRows() + } + else true + } else if (bufferedRow == null) { + // The streamed row's join key does not match the current batch of buffered rows and there are + // no more rows to read from the buffered iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else { + // Advance both the streamed and buffered iterators to find the next pair of matching rows. + var comp = -1 + do { + if (streamedRowKey.anyNull) { + advanceStreamed() + } else { + assert(!bufferedRowKey.anyNull) + comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey() + else if (comp < 0 || !lowerConditionOk) advanceStreamed() + else comp = checkBoundsAndAdvanceBuffered() + } + } while (streamedRow != null && bufferedRow != null && (comp != 0 || !lowerConditionOk)) + bufferedMatches.clear() + if (streamedRow == null || bufferedRow == null) { + // We have hit the end of one of the iterators, so there can be no more matches. + matchJoinKey = null + false + } else { + // The streamed row's join key matches the current buffered row's join, so walk through the + // buffered iterator to buffer the rest of the matching rows. 
+ assert(comp == 0) + bufferMatchingRows() + if (bufferedMatches.isEmpty) { + matchJoinKey = null + findNextInnerJoinRows() + } + else true + } + } + } + + // --- Private methods -------------------------------------------------------------------------- + + /** + * Advance the streamed iterator and compute the new row's join key. + * @return true if the streamed iterator returned a row and false otherwise. + */ + private def advanceStreamed(): Boolean = { + if (streamedIter.advanceNext()) { + streamedRow = streamedIter.getRow + streamedRowKey = streamedKeyGenerator(streamedRow) + updateJoinedRow() + true + } else { + streamedRow = null + streamedRowKey = null + false + } + } + + /** + * Advance the buffered iterator until we find a row with join key that does not contain nulls. + * @return true if the buffered iterator returned a row and false otherwise. + */ + private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = { + var foundRow: Boolean = false + while (!foundRow && bufferedIter.advanceNext()) { + bufferedRow = bufferedIter.getRow + bufferedRowKey = bufferedKeyGenerator(bufferedRow) + foundRow = !bufferedRowKey.anyNull + } + if (!foundRow) { + bufferedRow = null + bufferedRowKey = null + } + else { + updateJoinedRow() + } + foundRow + } + + private def updateJoinedRow() = { + if (streamedRow != null && bufferedRow != null) { + joinedRow(streamedRow, bufferedRow) + lowerConditionOk = lowerRangeCondition(joinedRow) + upperConditionOk = upperRangeCondition(joinedRow) + } + } + + /** + * Advance the buffered iterator as long as the join key is the same and + * the upper range condition is not satisfied. + * Skip rows with nulls. + * @return Result of the join key comparison. + */ + private def checkBoundsAndAdvanceBuffered(): Int = { + assert(bufferedRow != null) + assert(streamedRow != null) + var comp = 0 + if (lowerConditionOk && !upperConditionOk) { + while (!upperConditionOk && lowerConditionOk && comp == 0 && + advancedBufferedToRowWithNullFreeJoinKey()) { + comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + } + } + comp + } + + /** + * Called when the streamed and buffered join keys match in order to buffer the matching rows. + */ + private def bufferMatchingRows(): Unit = { + assert(streamedRowKey != null) + assert(!streamedRowKey.anyNull) + assert(bufferedRowKey != null) + assert(!bufferedRowKey.anyNull) + if (keyOrdering.compare(streamedRowKey, bufferedRowKey) != 0) { + return + } + // This join key may have been produced by a mutable projection, so we need to make a copy: + matchJoinKey = streamedRowKey.copy() + do { + if (lowerConditionOk && upperConditionOk) { + bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow]) + } + } while (lowerConditionOk && advancedBufferedToRowWithNullFreeJoinKey() && + keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) + } + + private def dequeueUntilUpperConditionHolds(): Unit = { + if (streamedRow != null) { + while (!bufferedMatches.isEmpty && + !upperRangeCondition(testJoinRow(streamedRow, bufferedMatches.peek.get))) { + bufferedMatches.dequeue() + } + } + } +} + /** * An iterator for outputting rows in left outer join. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index aa2162c9d2cda..b7a527b990b4d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -782,6 +782,31 @@ class JoinSuite extends QueryTest with SharedSQLContext { } } + test("test SortMergeJoin inner range (with spill)") { + withSQLConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold" -> "1", + "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "2") { + + val expected2 = new ListBuffer[Row]() + expected2.append( + Row(1, 3, 1, 2), Row(1, 3, 1, 2), + Row(1, 3, 1, 3), Row(1, 4, 1, 3), + Row(1, 4, 1, 5), Row(1, 7, 1, 7), + Row(1, 8, 1, 7), Row(2, 1, 2, 1), + Row(2, 1, 2, 2), Row(2, 2, 2, 1), + Row(2, 2, 2, 2), Row(2, 2, 2, 3), + Row(2, 3, 2, 2), Row(2, 3, 2, 3), + Row(3, 2, 3, 3), Row(3, 3, 3, 3), + Row(3, 5, 3, 6) + ) + assertSpilled(sparkContext, "inner range join") { + checkAnswer( + testData4.join(testData5, ('a === 'c ) and ('b between('d - 1, 'd + 1))), + expected2 + ) + } + } + } + test("outer broadcast hash join should not throw NPE") { withTempView("v1", "v2") { withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala index b29de9c4adbaa..378c8b76624fb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala @@ -35,7 +35,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar super.afterAll() } - private def withExternalArray(inMemoryThreshold: Int, spillThreshold: Int) + private def withExternalArray(inMemoryThreshold: Int, spillThreshold: Int, asQueue: Boolean) (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = { sc = new SparkContext("local", "test", new SparkConf(false)) @@ -47,6 +47,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar SparkEnv.get.blockManager, SparkEnv.get.serializerManager, taskContext, + asQueue, 1024, SparkEnv.get.memoryManager.pageSizeBytes, inMemoryThreshold, @@ -114,265 +115,313 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar assert(getNumBytesSpilled > 0) } - test("insert rows less than the inMemoryThreshold") { - val (inMemoryThreshold, spillThreshold) = (100, 50) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - assert(array.isEmpty) + val asQueueVals = Array(false, true) - val expectedValues = populateRows(array, 1) - assert(!array.isEmpty) - assert(array.length == 1) + asQueueVals.foreach(q => { + test(s"insert rows less than the inMemoryThreshold $q") { + val (inMemoryThreshold, spillThreshold) = (100, 50) + withExternalArray(inMemoryThreshold, spillThreshold, q) { array => + assert(array.isEmpty) - val iterator1 = validateData(array, expectedValues) + val expectedValues = populateRows(array, 1) + assert(!array.isEmpty) + assert(array.length == 1) - // Add more rows (but not too many to trigger switch to [[UnsafeExternalSorter]]) - // Verify that NO spill has happened - populateRows(array, inMemoryThreshold - 1, expectedValues) - assert(array.length == inMemoryThreshold) - assertNoSpill() + val iterator1 = 
validateData(array, expectedValues) - val iterator2 = validateData(array, expectedValues) + // Add more rows (but not too many to trigger switch to [[UnsafeExternalSorter]]) + // Verify that NO spill has happened + populateRows(array, inMemoryThreshold - 1, expectedValues) + assert(array.length == inMemoryThreshold) + assertNoSpill() - assert(!iterator1.hasNext) - assert(!iterator2.hasNext) - } - } - - test("insert rows more than the inMemoryThreshold but less than spillThreshold") { - val (inMemoryThreshold, spillThreshold) = (10, 50) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - assert(array.isEmpty) - val expectedValues = populateRows(array, inMemoryThreshold - 1) - assert(array.length == (inMemoryThreshold - 1)) - val iterator1 = validateData(array, expectedValues) - assertNoSpill() - - // Add more rows to trigger switch to [[UnsafeExternalSorter]] but not too many to cause a - // spill to happen. Verify that NO spill has happened - populateRows(array, spillThreshold - expectedValues.length - 1, expectedValues) - assert(array.length == spillThreshold - 1) - assertNoSpill() - - val iterator2 = validateData(array, expectedValues) - assert(!iterator2.hasNext) - - assert(!iterator1.hasNext) - intercept[ConcurrentModificationException](iterator1.next()) - } - } + val iterator2 = validateData(array, expectedValues) - test("insert rows enough to force spill") { - val (inMemoryThreshold, spillThreshold) = (20, 10) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - assert(array.isEmpty) - val expectedValues = populateRows(array, inMemoryThreshold - 1) - assert(array.length == (inMemoryThreshold - 1)) - val iterator1 = validateData(array, expectedValues) - assertNoSpill() - - // Add more rows to trigger switch to [[UnsafeExternalSorter]] and cause a spill to happen. 
- // Verify that spill has happened - populateRows(array, 2, expectedValues) - assert(array.length == inMemoryThreshold + 1) - assertSpill() - - val iterator2 = validateData(array, expectedValues) - assert(!iterator2.hasNext) - - assert(!iterator1.hasNext) - intercept[ConcurrentModificationException](iterator1.next()) - } - } - - test("iterator on an empty array should be empty") { - withExternalArray(inMemoryThreshold = 4, spillThreshold = 10) { array => - val iterator = array.generateIterator() - assert(array.isEmpty) - assert(array.length == 0) - assert(!iterator.hasNext) - } - } - - test("generate iterator with negative start index") { - withExternalArray(inMemoryThreshold = 100, spillThreshold = 56) { array => - val exception = - intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10)) - - assert(exception.getMessage.contains( - "Invalid `startIndex` provided for generating iterator over the array") - ) - } - } - - test("generate iterator with start index exceeding array's size (without spill)") { - val (inMemoryThreshold, spillThreshold) = (20, 100) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - populateRows(array, spillThreshold / 2) + assert(!iterator1.hasNext) + assert(!iterator2.hasNext) + } + }}) + + asQueueVals.foreach(q => { + test(s"insert rows more than the inMemoryThreshold but less than spillThreshold $q") { + val (inMemoryThreshold, spillThreshold) = (10, 50) + withExternalArray(inMemoryThreshold, spillThreshold, q) { array => + assert(array.isEmpty) + val expectedValues = populateRows(array, inMemoryThreshold - 1) + assert(array.length == (inMemoryThreshold - 1)) + val iterator1 = validateData(array, expectedValues) + assertNoSpill() + + // Add more rows to trigger switch to [[UnsafeExternalSorter]] but not too many to cause a + // spill to happen. Verify that NO spill has happened + populateRows(array, spillThreshold - expectedValues.length - 1, expectedValues) + assert(array.length == spillThreshold - 1) + assertNoSpill() + + val iterator2 = validateData(array, expectedValues) + assert(!iterator2.hasNext) + + assert(!iterator1.hasNext) + intercept[ConcurrentModificationException](iterator1.next()) + } + }}) + + asQueueVals.foreach(q => { + test(s"insert rows enough to force spill $q") { + val (inMemoryThreshold, spillThreshold) = (20, 10) + withExternalArray(inMemoryThreshold, spillThreshold, q) { array => + assert(array.isEmpty) + val expectedValues = populateRows(array, inMemoryThreshold - 1) + assert(array.length == (inMemoryThreshold - 1)) + val iterator1 = validateData(array, expectedValues) + assertNoSpill() + + // Add more rows to trigger switch to [[UnsafeExternalSorter]] and cause a spill to happen. 
+ // Verify that spill has happened + populateRows(array, 2, expectedValues) + assert(array.length == inMemoryThreshold + 1) + assertSpill() + + val iterator2 = validateData(array, expectedValues) + assert(!iterator2.hasNext) + + assert(!iterator1.hasNext) + intercept[ConcurrentModificationException](iterator1.next()) + } + }}) + + asQueueVals.foreach(q => { + test(s"iterator on an empty array should be empty $q") { + withExternalArray(inMemoryThreshold = 4, spillThreshold = 10, q) { array => + val iterator = array.generateIterator() + assert(array.isEmpty) + assert(array.length == 0) + assert(!iterator.hasNext) + } + }}) - val exception = - intercept[ArrayIndexOutOfBoundsException]( - array.generateIterator(startIndex = spillThreshold * 10)) - assert(exception.getMessage.contains( - "Invalid `startIndex` provided for generating iterator over the array")) - } - } + asQueueVals.foreach(q => { + test(s"generate iterator with negative start index $q") { + withExternalArray(inMemoryThreshold = 100, spillThreshold = 56, q) { array => + val exception = + intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10)) - test("generate iterator with start index exceeding array's size (with spill)") { - val (inMemoryThreshold, spillThreshold) = (20, 100) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - populateRows(array, spillThreshold * 2) + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array") + ) + } + }}) + + asQueueVals.foreach(q => { + test(s"generate iterator with start index exceeding array's size (without spill) $q") { + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold, q) { array => + populateRows(array, spillThreshold / 2) + + val exception = + intercept[ArrayIndexOutOfBoundsException]( + array.generateIterator(startIndex = spillThreshold * 10)) + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array")) + } + }}) - val exception = - intercept[ArrayIndexOutOfBoundsException]( - array.generateIterator(startIndex = spillThreshold * 10)) + asQueueVals.foreach(q => { + test(s"generate iterator with start index exceeding array's size (with spill) $q") { + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold, q) { array => + populateRows(array, spillThreshold * 2) - assert(exception.getMessage.contains( - "Invalid `startIndex` provided for generating iterator over the array")) - } - } + val exception = + intercept[ArrayIndexOutOfBoundsException]( + array.generateIterator(startIndex = spillThreshold * 10)) - test("generate iterator with custom start index (without spill)") { - val (inMemoryThreshold, spillThreshold) = (20, 100) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - val expectedValues = populateRows(array, inMemoryThreshold) - val startIndex = inMemoryThreshold / 2 - val iterator = array.generateIterator(startIndex = startIndex) - for (i <- startIndex until expectedValues.length) { - checkIfValueExists(iterator, expectedValues(i)) + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array")) } - } - } - - test("generate iterator with custom start index (with spill)") { - val (inMemoryThreshold, spillThreshold) = (20, 100) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - val expectedValues = populateRows(array, spillThreshold * 10) - val 
startIndex = spillThreshold * 2 - val iterator = array.generateIterator(startIndex = startIndex) - for (i <- startIndex until expectedValues.length) { - checkIfValueExists(iterator, expectedValues(i)) + }}) + + asQueueVals.foreach(q => { + test(s"generate iterator with custom start index (without spill) $q") { + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold, q) { array => + val expectedValues = populateRows(array, inMemoryThreshold) + val startIndex = inMemoryThreshold / 2 + val iterator = array.generateIterator(startIndex = startIndex) + for (i <- startIndex until expectedValues.length) { + checkIfValueExists(iterator, expectedValues(i)) + } } - } - } + }}) + + asQueueVals.foreach(q => { + test(s"generate iterator with custom start index (with spill) $q") { + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold, q) { array => + val expectedValues = populateRows(array, spillThreshold * 10) + val startIndex = spillThreshold * 2 + val iterator = array.generateIterator(startIndex = startIndex) + for (i <- startIndex until expectedValues.length) { + checkIfValueExists(iterator, expectedValues(i)) + } + } + }}) + + asQueueVals.foreach(q => { + test(s"test iterator invalidation (without spill) $q") { + withExternalArray(inMemoryThreshold = 10, spillThreshold = 100, q) { array => + // insert 2 rows, iterate until the first row + populateRows(array, 2) + + var iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + // Adding more row(s) should invalidate any old iterators + populateRows(array, 1) + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + + // Clearing the array should also invalidate any old iterators + iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + array.clear() + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + } + }}) - test("test iterator invalidation (without spill)") { - withExternalArray(inMemoryThreshold = 10, spillThreshold = 100) { array => + test(s"test dequeue with spill") { + withExternalArray(inMemoryThreshold = 2, spillThreshold = 3, true) { array => // insert 2 rows, iterate until the first row - populateRows(array, 2) - - var iterator = array.generateIterator() - assert(iterator.hasNext) - iterator.next() - - // Adding more row(s) should invalidate any old iterators - populateRows(array, 1) - assert(!iterator.hasNext) - intercept[ConcurrentModificationException](iterator.next()) - - // Clearing the array should also invalidate any old iterators - iterator = array.generateIterator() - assert(iterator.hasNext) - iterator.next() - - array.clear() - assert(!iterator.hasNext) - intercept[ConcurrentModificationException](iterator.next()) - } - } - - test("test iterator invalidation (with spill)") { - val (inMemoryThreshold, spillThreshold) = (2, 10) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - // Populate enough rows so that spill happens - populateRows(array, spillThreshold * 2) + populateRows(array, 5) assertSpill() var iterator = array.generateIterator() assert(iterator.hasNext) - iterator.next() - - // Adding more row(s) should invalidate any old iterators - populateRows(array, 1) - assert(!iterator.hasNext) - intercept[ConcurrentModificationException](iterator.next()) + val first = iterator.next() - // Clearing the array should also invalidate any old iterators - iterator = 
array.generateIterator() - assert(iterator.hasNext) - iterator.next() + val first2 = array.dequeue().get + assert(first.equals(first2)) + val second = array.dequeue().get + assert(!second.equals(first2)) - array.clear() - assert(!iterator.hasNext) - intercept[ConcurrentModificationException](iterator.next()) - } - } + val third = array.peek().get + val third2 = array.dequeue().get + assert(third.equals(third2)) - test("clear on an empty the array") { - withExternalArray(inMemoryThreshold = 2, spillThreshold = 3) { array => - val iterator = array.generateIterator() - assert(!iterator.hasNext) + assert(array.length == 2) - // multiple clear'ing should not have an side-effect - array.clear() - array.clear() - array.clear() - assert(array.isEmpty) - assert(array.length == 0) + array.dequeue() - // Clearing an empty array should also invalidate any old iterators - assert(!iterator.hasNext) - intercept[ConcurrentModificationException](iterator.next()) - } - } + populateRows(array, 10) - test("clear array (without spill)") { - val (inMemoryThreshold, spillThreshold) = (10, 100) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - // Populate rows ... but not enough to trigger spill - populateRows(array, inMemoryThreshold / 2) - assertNoSpill() + array.dequeue() + array.dequeue() - // Clear the array - array.clear() - assert(array.isEmpty) - - // Re-populate few rows so that there is no spill - // Verify the data. Verify that there was no spill - val expectedValues = populateRows(array, inMemoryThreshold / 2) - validateData(array, expectedValues) - assertNoSpill() - - // Populate more rows .. enough to not trigger a spill. - // Verify the data. Verify that there was no spill - populateRows(array, inMemoryThreshold / 2, expectedValues) - validateData(array, expectedValues) - assertNoSpill() + assert(array.length == 9) } } - test("clear array (with spill)") { - val (inMemoryThreshold, spillThreshold) = (10, 20) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - // Populate enough rows to trigger spill - populateRows(array, spillThreshold * 2) - val bytesSpilled = getNumBytesSpilled - assert(bytesSpilled > 0) - - // Clear the array - array.clear() - assert(array.isEmpty) - - // Re-populate the array ... but NOT upto the point that there is spill. - // Verify data. Verify that there was NO "extra" spill - val expectedValues = populateRows(array, spillThreshold / 2) - validateData(array, expectedValues) - assert(getNumBytesSpilled == bytesSpilled) - - // Populate more rows to trigger spill - // Verify the data. 
Verify that there was "extra" spill
-      populateRows(array, spillThreshold * 2, expectedValues)
-      validateData(array, expectedValues)
-      assert(getNumBytesSpilled > bytesSpilled)
-    }
-  }
+  asQueueVals.foreach(q => {
+    test(s"test iterator invalidation (with spill) $q") {
+      val (inMemoryThreshold, spillThreshold) = (2, 10)
+      withExternalArray(inMemoryThreshold, spillThreshold, q) { array =>
+        // Populate enough rows so that spill happens
+        populateRows(array, spillThreshold * 2)
+        assertSpill()
+
+        var iterator = array.generateIterator()
+        assert(iterator.hasNext)
+        iterator.next()
+
+        // Adding more row(s) should invalidate any old iterators
+        populateRows(array, 1)
+        assert(!iterator.hasNext)
+        intercept[ConcurrentModificationException](iterator.next())
+
+        // Clearing the array should also invalidate any old iterators
+        iterator = array.generateIterator()
+        assert(iterator.hasNext)
+        iterator.next()
+
+        array.clear()
+        assert(!iterator.hasNext)
+        intercept[ConcurrentModificationException](iterator.next())
+      }
+  }})
+
+  asQueueVals.foreach(q => {
+    test(s"clear on an empty array $q") {
+      withExternalArray(inMemoryThreshold = 2, spillThreshold = 3, q) { array =>
+        val iterator = array.generateIterator()
+        assert(!iterator.hasNext)
+
+        // multiple clear() calls should not have any side effect
+        array.clear()
+        array.clear()
+        array.clear()
+        assert(array.isEmpty)
+        assert(array.length == 0)
+
+        // Clearing an empty array should also invalidate any old iterators
+        assert(!iterator.hasNext)
+        intercept[ConcurrentModificationException](iterator.next())
+      }
+  }})
+
+  asQueueVals.foreach(q => {
+    test(s"clear array (without spill) $q") {
+      val (inMemoryThreshold, spillThreshold) = (10, 100)
+      withExternalArray(inMemoryThreshold, spillThreshold, q) { array =>
+        // Populate rows ... but not enough to trigger spill
+        populateRows(array, inMemoryThreshold / 2)
+        assertNoSpill()
+
+        // Clear the array
+        array.clear()
+        assert(array.isEmpty)
+
+        // Re-populate a few rows so that there is no spill
+        // Verify the data. Verify that there was no spill
+        val expectedValues = populateRows(array, inMemoryThreshold / 2)
+        validateData(array, expectedValues)
+        assertNoSpill()
+
+        // Populate more rows ... still not enough to trigger a spill.
+        // Verify the data. Verify that there was no spill
+        populateRows(array, inMemoryThreshold / 2, expectedValues)
+        validateData(array, expectedValues)
+        assertNoSpill()
+      }
+  }})
+
+  asQueueVals.foreach(q => {
+    test(s"clear array (with spill) $q") {
+      val (inMemoryThreshold, spillThreshold) = (10, 20)
+      withExternalArray(inMemoryThreshold, spillThreshold, q) { array =>
+        // Populate enough rows to trigger spill
+        populateRows(array, spillThreshold * 2)
+        val bytesSpilled = getNumBytesSpilled
+        assert(bytesSpilled > 0)
+
+        // Clear the array
+        array.clear()
+        assert(array.isEmpty)
+
+        // Re-populate the array ... but NOT up to the point where it spills.
+        // Verify data. Verify that there was NO "extra" spill
+        val expectedValues = populateRows(array, spillThreshold / 2)
+        validateData(array, expectedValues)
+        assert(getNumBytesSpilled == bytesSpilled)
+
+        // Populate more rows to trigger spill
+        // Verify the data.
Verify that there was "extra" spill + populateRows(array, spillThreshold * 2, expectedValues) + validateData(array, expectedValues) + assert(getNumBytesSpilled > bytesSpilled) + } + }}) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 142ab6170a734..c78f56c5b49eb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -469,6 +469,7 @@ class PlannerSuite extends SharedSQLContext { Literal(1) :: Nil, Literal(1) :: Nil, Inner, + Nil, None, shuffle, shuffle) @@ -486,6 +487,7 @@ class PlannerSuite extends SharedSQLContext { Literal(1) :: Nil, Literal(1) :: Nil, Inner, + Nil, None, ShuffleExchangeExec(finalPartitioning, inputPlan), ShuffleExchangeExec(finalPartitioning, inputPlan)) @@ -542,7 +544,8 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements skips sort when either side of join keys is required after inner SMJ") { Seq(Inner, Cross).foreach { joinType => - val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, None, planA, planB) + val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, Nil, + None, planA, planB) // Both left and right keys should be sorted after the SMJ. Seq(orderingA, orderingB).foreach { ordering => assertSortRequirementsAreSatisfied( @@ -556,8 +559,10 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements skips sort when key order of a parent SMJ is propagated from its " + "child SMJ") { Seq(Inner, Cross).foreach { joinType => - val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, None, planA, planB) - val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, joinType, None, childSmj, planC) + val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, Nil, + None, planA, planB) + val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, joinType, Nil, + None, childSmj, planC) // After the second SMJ, exprA, exprB and exprC should all be sorted. Seq(orderingA, orderingB, orderingC).foreach { ordering => assertSortRequirementsAreSatisfied( @@ -570,7 +575,8 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements for sort operator after left outer sort merge join") { // Only left key is sorted after left outer SMJ (thus doesn't need a sort). - val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, None, planA, planB) + val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, Nil, + None, planA, planB) Seq((orderingA, false), (orderingB, true)).foreach { case (ordering, needSort) => assertSortRequirementsAreSatisfied( childPlan = leftSmj, @@ -581,7 +587,8 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements for sort operator after right outer sort merge join") { // Only right key is sorted after right outer SMJ (thus doesn't need a sort). 
- val rightSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, RightOuter, None, planA, planB) + val rightSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, RightOuter, Nil, + None, planA, planB) Seq((orderingA, true), (orderingB, false)).foreach { case (ordering, needSort) => assertSortRequirementsAreSatisfied( childPlan = rightSmj, @@ -592,7 +599,8 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements adds sort after full outer sort merge join") { // Neither keys is sorted after full outer SMJ, so they both need sorts. - val fullSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, FullOuter, None, planA, planB) + val fullSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, FullOuter, Nil, + None, planA, planB) Seq(orderingA, orderingB).foreach { ordering => assertSortRequirementsAreSatisfied( childPlan = fullSmj, @@ -686,11 +694,11 @@ class PlannerSuite extends SharedSQLContext { val plan2 = DummySparkPlan(outputOrdering = Seq(orderingB), outputPartitioning = HashPartitioning(exprB :: Nil, 5)) val smjExec = SortMergeJoinExec( - exprA :: exprA :: Nil, exprB :: exprC :: Nil, Inner, None, plan1, plan2) + exprA :: exprA :: Nil, exprB :: exprC :: Nil, Inner, Nil, None, plan1, plan2) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec) outputPlan match { - case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) => + case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _, _) => assert(leftKeys == Seq(exprA, exprA)) assert(rightKeys == Seq(exprB, exprC)) case _ => fail() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala index ad81711a13947..9e67c90a4f0ee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala @@ -148,6 +148,53 @@ object JoinBenchmark extends SqlBasedBenchmark { } } + def innerRangeTest(N: Int, M: Int): Unit = { + import sparkSession.implicits._ + val df1 = sparkSession.sparkContext.parallelize(1 to M). + cartesian(sparkSession.sparkContext.parallelize(1 to N)). + toDF("col1a", "col1b") + val df2 = sparkSession.sparkContext.parallelize(1 to M). + cartesian(sparkSession.sparkContext.parallelize(1 to N)). 
+ toDF("col2a", "col2b") + val df = df1.join(df2, 'col1a === 'col2a and ('col1b < 'col2b + 3) and ('col1b > 'col2b - 3)) + assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined) + df.count() + } + + ignore("sort merge inner range join") { + sparkSession.conf.set("spark.sql.join.smj.useInnerRangeOptimization", "false") + val N = 2 << 11 + val M = 100 + runBenchmark("sort merge inner range join", N * M) { + innerRangeTest(N, M) + } + + /* + *AMD EPYC 7401 24-Core Processor + *sort merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + *--------------------------------------------------------------------------------------------- + *sort merge join wholestage off 25226 / 25244 0.0 61585.9 1.0X + *sort merge join wholestage on 8581 / 8983 0.0 20948.6 2.9X + */ + } + + ignore("sort merge inner range join optimized") { + sparkSession.conf.set("spark.sql.join.smj.useInnerRangeOptimization", "true") + val N = 2 << 11 + val M = 100 + runBenchmark("sort merge inner range join optimized", N * M) { + innerRangeTest(N, M) + } + + /* + *AMD EPYC 7401 24-Core Processor + *sort merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + *--------------------------------------------------------------------------------------------- + *sort merge join wholestage off 1194 / 1212 0.3 2915.2 1.0X + *sort merge join wholestage on 814 / 867 0.5 1988.4 1.5X + */ + } + def shuffleHashJoin(): Unit = { val N: Long = 4 << 20 withSQLConf( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index 22279a3a43eff..d8e43f9349836 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -102,7 +102,8 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using ShuffledHashJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rngConds, boundCondition, _, _) => + assert(rngConds.isEmpty) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(left.sqlContext.sessionState.conf).apply( @@ -140,17 +141,18 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using SortMergeJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rngConds, boundCondition, _, _) => + assert(rngConds.isEmpty) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(left.sqlContext.sessionState.conf).apply( - SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, left, right)), + SortMergeJoinExec(leftKeys, rightKeys, joinType, Nil, boundCondition, left, right)), expectedAnswer, sortAnswers = true) checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(left.sqlContext.sessionState.conf).apply( createLeftSemiPlusJoin(SortMergeJoinExec( - leftKeys, rightKeys, leftSemiPlus, boundCondition, left, right))), + leftKeys, rightKeys, leftSemiPlus, Nil, boundCondition, left, right))), expectedAnswer, sortAnswers = true) } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index f5edd6bbd5e69..2145ddcc8c89b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{BinaryComparison, Expression} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join @@ -70,27 +70,39 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { (3, 2) ).toDF("a", "b") + private lazy val rangeTestData1 = Seq( + (1, 3), (1, 4), (1, 7), (1, 8), (1, 10), + (2, 1), (2, 2), (2, 3), (2, 8), + (3, 1), (3, 2), (3, 3), (3, 5), + (4, 1), (4, 2), (4, 3) + ).toDF("a", "b") + + private lazy val rangeTestData2 = Seq( + (1, 1), (1, 2), (1, 2), (1, 3), (1, 5), (1, 7), (1, 20), + (2, 1), (2, 2), (2, 3), (2, 5), (2, 6), + (3, 3), (3, 6) + ).toDF("a", "b") + // Note: the input dataframes and expression must be evaluated lazily because // the SQLContext should be used only within a test to keep SQL tests stable - private def testInnerJoin( - testName: String, - leftRows: => DataFrame, - rightRows: => DataFrame, - condition: () => Expression, - expectedAnswer: Seq[Product]): Unit = { + private def testInnerJoin(testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: () => Expression, + expectedAnswer: Seq[Product], + expectRangeJoin: Boolean = false): Unit = { def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition())) ExtractEquiJoinKeys.unapply(join) } - def makeBroadcastHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - boundCondition: Option[Expression], - leftPlan: SparkPlan, - rightPlan: SparkPlan, - side: BuildSide) = { + def makeBroadcastHashJoin(leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { val broadcastJoin = joins.BroadcastHashJoinExec( leftKeys, rightKeys, @@ -102,13 +114,12 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { EnsureRequirements(spark.sessionState.conf).apply(broadcastJoin) } - def makeShuffledHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - boundCondition: Option[Expression], - leftPlan: SparkPlan, - rightPlan: SparkPlan, - side: BuildSide) = { + def makeShuffledHashJoin(leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { val shuffledHashJoin = joins.ShuffledHashJoinExec(leftKeys, rightKeys, Inner, side, None, leftPlan, rightPlan) val filteredJoin = @@ -116,61 +127,78 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { EnsureRequirements(spark.sessionState.conf).apply(filteredJoin) } - def makeSortMergeJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - boundCondition: Option[Expression], - leftPlan: SparkPlan, - rightPlan: SparkPlan) = { - val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, 
Inner, boundCondition,
-        leftPlan, rightPlan)
+    def makeSortMergeJoin(leftKeys: Seq[Expression],
+        rightKeys: Seq[Expression],
+        boundCondition: Option[Expression],
+        rangeConditions: Seq[BinaryComparison],
+        leftPlan: SparkPlan,
+        rightPlan: SparkPlan) = {
+      val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, rangeConditions,
+        boundCondition, leftPlan, rightPlan)
       EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin)
     }
+
+    val configOptions = List(
+      ("spark.sql.codegen.wholeStage", "true"),
+      ("spark.sql.codegen.wholeStage", "false"))

-    testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin (build=left)") { _ =>
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
-        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-          checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
-            makeBroadcastHashJoin(
-              leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
-            expectedAnswer.map(Row.fromTuple),
-            sortAnswers = true)
+    // Disabling these because the code would never follow this path in case of an inner
+    // range join
+    if (!expectRangeJoin) {
+      testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin (build=left)") { _ =>
+        extractJoinParts().foreach { case (_, leftKeys, rightKeys, _, boundCondition, _, _) =>
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+            checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+              makeBroadcastHashJoin(
+                leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
+              expectedAnswer.map(Row.fromTuple),
+              sortAnswers = true)
         }
       }
     }
-
-    testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin (build=right)") { _ =>
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
-        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-          checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
-            makeBroadcastHashJoin(
-              leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
-            expectedAnswer.map(Row.fromTuple),
-            sortAnswers = true)
+
+    if (!expectRangeJoin) {
+      testWithWholeStageCodegenOnAndOff(s"$testName using BroadcastHashJoin (build=right)") { _ =>
+        extractJoinParts().foreach { case (_, leftKeys, rightKeys, _, boundCondition, _, _) =>
+          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+            checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+              makeBroadcastHashJoin(
+                leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
+              expectedAnswer.map(Row.fromTuple),
+              sortAnswers = true)
         }
       }
     }

-    test(s"$testName using ShuffledHashJoin (build=left)") {
-      extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
-        withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-          checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
-            makeShuffledHashJoin(
-              leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
-            expectedAnswer.map(Row.fromTuple),
-            sortAnswers = true)
+    if (!expectRangeJoin) {
+      configOptions.zipWithIndex.foreach { case ((config, confValue), counter) =>
+        test(s"$testName using ShuffledHashJoin (build=left) $counter") {
+          extractJoinParts().foreach { case (_, leftKeys, rightKeys, _,
+              boundCondition, _, _) =>
+            withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) {
+              checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+                makeShuffledHashJoin(
+
leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } } } - test(s"$testName using ShuffledHashJoin (build=right)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => - makeShuffledHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + if(!expectRangeJoin) { + configOptions.zipWithIndex.foreach { case ((config, confValue), counter) => + test(s"$testName using ShuffledHashJoin (build=right) $counter") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, _, + boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } + } } } } @@ -186,31 +214,37 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } - test(s"$testName using CartesianProduct") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", - SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - CartesianProductExec(left, right, Some(condition())), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + configOptions.zipWithIndex.foreach { case ((config, confValue), counter) => + test(s"$testName using CartesianProduct $counter") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.CROSS_JOINS_ENABLED.key -> "true", config -> confValue) { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + CartesianProductExec(left, right, Some(condition())), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } - test(s"$testName using BroadcastNestedLoopJoin build left") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoinExec(left, right, BuildLeft, Inner, Some(condition())), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + configOptions.zipWithIndex.foreach { case ((config, confValue), counter) => + test(s"$testName using BroadcastNestedLoopJoin build left $counter") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoinExec(left, right, BuildLeft, Inner, Some(condition())), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } - test(s"$testName using BroadcastNestedLoopJoin build right") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoinExec(left, right, BuildRight, Inner, Some(condition())), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + configOptions.zipWithIndex.foreach { case ((config, confValue), counter) => + test(s"$testName using BroadcastNestedLoopJoin build right $counter") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoinExec(left, right, BuildRight, 
Inner, Some(condition())), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } } @@ -272,6 +306,38 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { ) } + { + lazy val left = rangeTestData1.orderBy("a").sortWithinPartitions("b") + lazy val right = rangeTestData2.orderBy("a").sortWithinPartitions("b") + testInnerJoin( + "inner range join", + left, + right, + () => ((left("a") === right("a")) and (left("b") <= right("b") + 1) + and (left("b") >= right("b") - 1)).expr, + Seq( + (1, 3, 1, 2), + (1, 3, 1, 2), + (1, 3, 1, 3), + (1, 4, 1, 3), + (1, 4, 1, 5), + (1, 7, 1, 7), + (1, 8, 1, 7), + (2, 1, 2, 1), + (2, 1, 2, 2), + (2, 2, 2, 1), + (2, 2, 2, 2), + (2, 2, 2, 3), + (2, 3, 2, 2), + (2, 3, 2, 3), + (3, 2, 3, 3), + (3, 3, 3, 3), + (3, 5, 3, 6) + ), + true + ) + } + { def df: DataFrame = spark.range(3).selectExpr("struct(id, id) as key", "id as value") lazy val left = df.selectExpr("key", "concat('L', value) as value").alias("left") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 513248dae48be..dc258c84ad81b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -78,7 +78,9 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { if (joinType != FullOuter) { test(s"$testName using ShuffledHashJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, + boundCondition, _, _) => + assert(rangeConditions.isEmpty) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => @@ -99,7 +101,9 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { case RightOuter => BuildLeft case _ => fail(s"Unsupported join type $joinType") } - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, + boundCondition, _, _) => + assert(rangeConditions.isEmpty) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => BroadcastHashJoinExec( @@ -112,11 +116,13 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using SortMergeJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, + boundCondition, _, _) => + assert(rangeConditions.isEmpty) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(spark.sessionState.conf).apply( - SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, left, right)), + SortMergeJoinExec(leftKeys, rightKeys, joinType, Nil, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 615923fe02d6c..e1cf5504364ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -72,6 +72,30 @@ private[sql] trait SQLTestData { self => df } + protected lazy val testData4: DataFrame = { + val df = spark.sparkContext.parallelize( + TestData2(1, 3) :: TestData2(1, 4) :: TestData2(1, 7) :: TestData2(1, 8) :: + TestData2(1, 10) :: TestData2(2, 1) :: TestData2(2, 2) :: TestData2(2, 3) :: + TestData2(2, 8) :: TestData2(3, 1) :: TestData2(3, 2) :: TestData2(3, 3) :: + TestData2(3, 5) :: TestData2(4, 1) :: TestData2(4, 2) :: TestData2(4, 3) :: + Nil).toDF() + df.orderBy("a").sortWithinPartitions("b"). + createOrReplaceTempView("testData4") + df + } + + protected lazy val testData5: DataFrame = { + val df = spark.sparkContext.parallelize( + TestData2(1, 1) :: TestData2(1, 2) :: TestData2(1, 2) :: TestData2(1, 3) :: + TestData2(1, 5) :: TestData2(1, 7) :: TestData2(1, 20) :: + TestData2(2, 1) :: TestData2(2, 2) :: TestData2(2, 3) :: TestData2(2, 5) :: + TestData2(2, 6) :: TestData2(3, 3) :: TestData2(3, 6) :: + Nil).toDF("c", "d") + df.orderBy("c").sortWithinPartitions("d"). + createOrReplaceTempView("testData5") + df + } + protected lazy val negativeData: DataFrame = { val df = spark.sparkContext.parallelize( (1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
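For reference, a minimal, self-contained sketch (not part of the patch) of how the inner range optimization is exercised end to end through the DataFrame API. It mirrors the band condition used in JoinBenchmark.innerRangeTest and the InnerJoinSuite range test above; the conf key spark.sql.join.smj.useInnerRangeOptimization is the one set in the benchmark, the object name and sample data are illustrative, and broadcast joins are disabled here only so the tiny inputs still plan as a sort merge join.

// Editor's sketch, not part of the patch. Assumes the conf key used in JoinBenchmark above.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.joins.SortMergeJoinExec

object InnerRangeJoinSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("inner-range-join-sketch")
      .config("spark.sql.join.smj.useInnerRangeOptimization", "true")
      // Disable broadcast joins so the tiny inputs below still plan as a sort merge join.
      .config("spark.sql.autoBroadcastJoinThreshold", "-1")
      .getOrCreate()
    import spark.implicits._

    // Same shape as rangeTestData1 / rangeTestData2 in InnerJoinSuite.
    val left = Seq((1, 3), (1, 4), (2, 1), (2, 2), (3, 5)).toDF("a", "b")
    val right = Seq((1, 2), (1, 5), (2, 1), (3, 6)).toDF("c", "d")

    // Equi-join on the first column plus a band condition |b - d| <= 1 on the second.
    // The two range predicates on non-key columns are what ExtractEquiJoinKeys now
    // pulls out as rangeConditions for SortMergeJoinExec.
    val joined = left.join(
      right,
      left("a") === right("c") &&
        left("b") <= right("d") + 1 &&
        left("b") >= right("d") - 1)

    // Same plan check as the benchmark: the join should still be a sort merge join.
    assert(joined.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined)
    joined.show()

    spark.stop()
  }
}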