From b7a4a3e84c248445dfd49400e10e5bc0516880d7 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Thu, 22 Mar 2018 10:25:19 +0100 Subject: [PATCH 01/51] Secondary sort Sort Merge Join optimization. Not finished yet. --- .../analysis/StreamingJoinHelper.scala | 2 +- .../sql/catalyst/planning/patterns.scala | 62 +++- .../statsEstimation/JoinEstimation.scala | 2 +- .../execution/InMemoryUnsafeRowQueue.scala | 237 ++++++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 32 ++- .../exchange/EnsureRequirements.scala | 5 +- .../execution/joins/SortMergeJoinExec.scala | 266 +++++++++++++++++- .../execution/joins/ExistenceJoinSuite.scala | 6 +- .../sql/execution/joins/InnerJoinSuite.scala | 10 +- .../sql/execution/joins/OuterJoinSuite.scala | 6 +- 10 files changed, 586 insertions(+), 42 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala index 7a0aa08289efa..76733dd6dac3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala @@ -41,7 +41,7 @@ object StreamingJoinHelper extends PredicateHelper with Logging { */ def isWatermarkInJoinKeys(plan: LogicalPlan): Boolean = { plan match { - case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _) => + case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _, _) => (leftKeys ++ rightKeys).exists { case a: AttributeReference => a.metadata.contains(EventTimeWatermark.delayKey) case _ => false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 84be677e438a6..914f1a9f0695f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -24,6 +24,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import scala.collection.mutable + /** * A pattern that matches any number of project or filter operations on top of another relational * operator. All filter operators are collected and their conditions are broken up and returned @@ -100,7 +102,7 @@ object PhysicalOperation extends PredicateHelper { object ExtractEquiJoinKeys extends Logging with PredicateHelper { /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */ type ReturnType = - (JoinType, Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan) + (JoinType, Seq[Expression], Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan) def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { case join @ Join(left, right, joinType, condition) => @@ -131,14 +133,70 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { } if (joinKeys.nonEmpty) { + // Find any simple comparison expressions between two columns + // (and involving only those two columns) + // of the two tables being joined, + // which are not used in the equijoin expressions, + // and which can be used for secondary sort optimizations. 
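    // Illustrative note (query and names hypothetical, not part of the patch): for
    //   SELECT * FROM events e JOIN windows w
    //     ON e.user_id = w.user_id AND e.ts >= w.start_ts AND e.ts < w.end_ts
    // the extraction below is expected to yield
    //   joinKeys        = Seq((e.user_id, w.user_id))
    //   rangeConditions = Seq(GreaterThanOrEqual(e.ts, w.start_ts), LessThan(e.ts, w.end_ts))
    // because each inequality references exactly one column from each side and neither column
    // already serves as an equi-join key. The "vs" cases below swap operands so the column from
    // the left plan always comes first. Note also that the flatMap below has no default case,
    // so a leftover predicate of any other shape (an IsNotNull, say) would throw a MatchError;
    // a trailing `case _ => None` appears to be intended.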
+ val rangePreds:mutable.Set[Expression] = mutable.Set.empty + val rangeConditions:Seq[Expression] = otherPredicates.flatMap { + case p @ LessThan(l, r) => isValidRangeCondition(l, r, left, right, joinKeys) match { + case "asis" => rangePreds.add(p); Some(LessThan(l, r)) + case "vs" => rangePreds.add(p); Some(GreaterThan(r, l)) + case _ => None + } + case p @ LessThanOrEqual(l, r) => isValidRangeCondition(l, r, left, right, joinKeys) match { + case "asis" => rangePreds.add(p); Some(LessThanOrEqual(l, r)) + case "vs" => rangePreds.add(p); Some(GreaterThanOrEqual(r, l)) + case _ => None + } + case p @ GreaterThan(l, r) => isValidRangeCondition(l, r, left, right, joinKeys) match { + case "asis" => rangePreds.add(p); Some(GreaterThan(l, r)) + case "vs" => rangePreds.add(p); Some(LessThan(r, l)) + case _ => None + } + case p @ GreaterThanOrEqual(l, r) => isValidRangeCondition(l, r, left, right, joinKeys) match { + case "asis" => rangePreds.add(p); Some(GreaterThanOrEqual(l, r)) + case "vs" => rangePreds.add(p); Some(LessThanOrEqual(r, l)) + case _ => None + } + } val (leftKeys, rightKeys) = joinKeys.unzip logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys") - Some((joinType, leftKeys, rightKeys, otherPredicates.reduceOption(And), left, right)) + Some((joinType, leftKeys, rightKeys, rangeConditions, + otherPredicates.filterNot(rangePreds.contains(_)).reduceOption(And), left, right)) } else { None } case _ => None } + + + private def isValidRangeCondition(l:Expression, r:Expression, left:LogicalPlan, right:LogicalPlan, + joinKeys:Seq[(Expression, Expression)]) = { + val (lattrs, rattrs) = (l.references.toSeq, r.references.toSeq) + if(lattrs.size != 1 || rattrs.size != 1) + "none" + else if (canEvaluate(l, left) && canEvaluate(r, right)) { + val equiset = joinKeys.filter{ case (ljk:Expression, rjk:Expression) => + ljk.references.toSeq.contains(lattrs(0)) && rjk.references.toSeq.contains(rattrs(0))} + if(equiset.isEmpty) + "asis" + else + "none" + } + else if (canEvaluate(l, right) && canEvaluate(r, left)) { + val equiset = joinKeys.filter{ case (ljk:Expression, rjk:Expression) => + rjk.references.toSeq.contains(lattrs(0)) && ljk.references.toSeq.contains(rattrs(0))} + if(equiset.isEmpty) + "vs" + else + "none" + } + else + "none" + } + } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala index 2543e38a92c0a..19a0d1279cc32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala @@ -56,7 +56,7 @@ case class JoinEstimation(join: Join) extends Logging { case _ if !rowCountsExist(join.left, join.right) => None - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _) => + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, _, _, _, _) => // 1. 
Compute join selectivity
        val joinKeyPairs = extractJoinKeysWithColStats(leftKeys, rightKeys)
        val (numInnerJoinedRows, keyStatsAfterJoin) = computeCardinalityAndStats(joinKeyPairs)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala
new file mode 100644
index 0000000000000..db6842c4f6c55
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala
@@ -0,0 +1,237 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import java.util.ConcurrentModificationException
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.memory.TaskMemoryManager
+import org.apache.spark.serializer.SerializerManager
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer
+import org.apache.spark.storage.BlockManager
+import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator}
+import org.apache.spark.{SparkEnv, TaskContext}
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * An append-only queue of [[UnsafeRow]]s that strictly keeps content in an in-memory buffer
+ * until [[numRowsInMemoryBufferThreshold]] is reached, after which it is meant to switch to a
+ * mode that flushes to disk once [[numRowsSpillThreshold]] is met (or earlier if there is
+ * excessive memory consumption). Setting these thresholds involves the following trade-offs:
+ *
+ * - If [[numRowsInMemoryBufferThreshold]] is too high, the in-memory buffer may occupy more
+ *   memory than is available, resulting in OOM.
+ * - If [[numRowsSpillThreshold]] is too low, data will be spilled frequently and lead to
+ *   excessive disk writes. This may lead to a performance regression compared to the normal case
+ *   of using an [[ArrayBuffer]] or [[Array]]. 
+ */ +private[sql] class InMemoryUnsafeRowQueue( + taskMemoryManager: TaskMemoryManager, + blockManager: BlockManager, + serializerManager: SerializerManager, + taskContext: TaskContext, + initialSize: Int, + pageSizeBytes: Long, + numRowsInMemoryBufferThreshold: Int, + numRowsSpillThreshold: Int) extends Logging { + + def this(numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int) { + this( + TaskContext.get().taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get(), + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + numRowsInMemoryBufferThreshold, + numRowsSpillThreshold) + } + + private val initialSizeOfInMemoryBuffer = + Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsInMemoryBufferThreshold) + + private val inMemoryQueue = if (initialSizeOfInMemoryBuffer > 0) { + new mutable.Queue[UnsafeRow]() + } else { + null + } + +// private var spillableArray: UnsafeExternalSorter = _ + private var numRows = 0 + + // A counter to keep track of total modifications done to this array since its creation. + // This helps to invalidate iterators when there are changes done to the backing array. + private var modificationsCount: Long = 0 + + private var numFieldsPerRow = 0 + + def length: Int = numRows + + def isEmpty: Boolean = numRows == 0 + + /** + * Clears up resources (eg. memory) held by the backing storage + */ + def clear(): Unit = { + /*if (spillableArray != null) { + // The last `spillableArray` of this task will be cleaned up via task completion listener + // inside `UnsafeExternalSorter` + spillableArray.cleanupResources() + spillableArray = null + } else*/ + if (inMemoryQueue != null) { + inMemoryQueue.clear() + } + numFieldsPerRow = 0 + numRows = 0 + modificationsCount += 1 + } + + def dequeue(): Option[UnsafeRow] = { + if(numRows == 0) + None + else { + numRows -= 1 + Some(inMemoryQueue.dequeue()) + } + } + + def add(unsafeRow: UnsafeRow): Unit = { + if (numRows < numRowsInMemoryBufferThreshold) { + inMemoryQueue += unsafeRow.copy() + } else { + throw new RuntimeException(s"Reached spill threshold of $numRowsInMemoryBufferThreshold rows") + /*if (spillableArray == null) { + logInfo(s"Reached spill threshold of $numRowsInMemoryBufferThreshold rows, switching to " + + s"${classOf[UnsafeExternalSorter].getName}") + + // We will not sort the rows, so prefixComparator and recordComparator are null + spillableArray = UnsafeExternalSorter.create( + taskMemoryManager, + blockManager, + serializerManager, + taskContext, + null, + null, + initialSize, + pageSizeBytes, + numRowsSpillThreshold, + false) + + // populate with existing in-memory buffered rows + if (inMemoryBuffer != null) { + inMemoryBuffer.foreach(existingUnsafeRow => + spillableArray.insertRecord( + existingUnsafeRow.getBaseObject, + existingUnsafeRow.getBaseOffset, + existingUnsafeRow.getSizeInBytes, + 0, + false) + ) + inMemoryBuffer.clear() + } + numFieldsPerRow = unsafeRow.numFields() + } + + spillableArray.insertRecord( + unsafeRow.getBaseObject, + unsafeRow.getBaseOffset, + unsafeRow.getSizeInBytes, + 0, + false)*/ + } + + numRows += 1 + modificationsCount += 1 + } + + /** + * Creates an [[Iterator]] for the current rows in the array starting from a user provided index + * + * If there are subsequent [[add()]] or [[clear()]] calls made on this array after creation of + * the iterator, then the iterator is invalidated thus saving clients from thinking that they + * have read all the data while there were new rows added to this array. 
+ */ + def generateIterator(startIndex: Int): Iterator[UnsafeRow] = { + if (startIndex < 0 || (numRows > 0 && startIndex > numRows)) { + throw new ArrayIndexOutOfBoundsException( + "Invalid `startIndex` provided for generating iterator over the array. " + + s"Total elements: $numRows, requested `startIndex`: $startIndex") + } + + //if (spillableArray == null) { + new InMemoryBufferIterator(startIndex) + /*} else { + new SpillableArrayIterator(spillableArray.getIterator(startIndex), numFieldsPerRow) + }*/ + } + + def generateIterator(): Iterator[UnsafeRow] = generateIterator(startIndex = 0) + + private[this] + abstract class ExternalAppendOnlyUnsafeRowArrayIterator extends Iterator[UnsafeRow] { + private val expectedModificationsCount = modificationsCount + + protected def isModified(): Boolean = expectedModificationsCount != modificationsCount + + protected def throwExceptionIfModified(): Unit = { + if (expectedModificationsCount != modificationsCount) { + throw new ConcurrentModificationException( + s"The backing ${classOf[ExternalAppendOnlyUnsafeRowArray].getName} has been modified " + + s"since the creation of this Iterator") + } + } + } + + private[this] class InMemoryBufferIterator(startIndex: Int) + extends ExternalAppendOnlyUnsafeRowArrayIterator { + + private var currentIndex = startIndex + + override def hasNext(): Boolean = !isModified() && currentIndex < numRows + + override def next(): UnsafeRow = { + throwExceptionIfModified() + val result = inMemoryQueue(currentIndex) + currentIndex += 1 + result + } + } + + /*private[this] class SpillableArrayIterator( + iterator: UnsafeSorterIterator, + numFieldPerRow: Int) + extends ExternalAppendOnlyUnsafeRowArrayIterator { + + private val currentRow = new UnsafeRow(numFieldPerRow) + + override def hasNext(): Boolean = !isModified() && iterator.hasNext + + override def next(): UnsafeRow = { + throwExceptionIfModified() + iterator.loadNext() + currentRow.pointTo(iterator.getBaseObject, iterator.getBaseOffset, iterator.getRecordLength) + currentRow + } + }*/ +} + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index dbc6db62bd820..d18a2c6ac6625 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -241,41 +241,46 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // --- BroadcastHashJoin -------------------------------------------------------------------- // broadcast hints were specified - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, rangeConds, condition, left, right) if canBroadcastByHints(joinType, left, right) => val buildSide = broadcastSideByHints(joinType, left, right) + val cond = (rangeConds ++ condition.map(x => Seq(x)).getOrElse(Nil)).reduceOption(And) Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, joinType, buildSide, cond, planLater(left), planLater(right))) // broadcast hints were not specified, so need to infer it from size and configuration. 
- case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, rangeConds, condition, left, right) if canBroadcastBySizes(joinType, left, right) => val buildSide = broadcastSideBySizes(joinType, left, right) + val cond = (rangeConds ++ condition.map(x => Seq(x)).getOrElse(Nil)).reduceOption(And) Seq(joins.BroadcastHashJoinExec( - leftKeys, rightKeys, joinType, buildSide, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, joinType, buildSide, cond, planLater(left), planLater(right))) // --- ShuffledHashJoin --------------------------------------------------------------------- - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, rangeConds, condition, left, right) if !conf.preferSortMergeJoin && canBuildRight(joinType) && canBuildLocalHashMap(right) && muchSmaller(right, left) || !RowOrdering.isOrderable(leftKeys) => + val cond = (rangeConds ++ condition.map(x => Seq(x)).getOrElse(Nil)).reduceOption(And) Seq(joins.ShuffledHashJoinExec( - leftKeys, rightKeys, joinType, BuildRight, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, joinType, BuildRight, cond, planLater(left), planLater(right))) - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, rangeConds, condition, left, right) if !conf.preferSortMergeJoin && canBuildLeft(joinType) && canBuildLocalHashMap(left) && muchSmaller(left, right) || !RowOrdering.isOrderable(leftKeys) => + val cond = (rangeConds ++ condition.map(x => Seq(x)).getOrElse(Nil)).reduceOption(And) Seq(joins.ShuffledHashJoinExec( - leftKeys, rightKeys, joinType, BuildLeft, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, joinType, BuildLeft, cond, planLater(left), planLater(right))) // --- SortMergeJoin ------------------------------------------------------------ - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, rangeConds, condition, left, right) if RowOrdering.isOrderable(leftKeys) => +// val cond = (rangeConds ++ condition.map(x => Seq(x)).getOrElse(Nil)).reduceOption(And) joins.SortMergeJoinExec( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + leftKeys, rightKeys, joinType, rangeConds, condition, planLater(left), planLater(right)) :: Nil // --- Without joining keys ------------------------------------------------------------ @@ -380,11 +385,12 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { object StreamingJoinStrategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = { plan match { - case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, condition, left, right) + case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, rangePreds, condition, left, right) if left.isStreaming && right.isStreaming => new StreamingSymmetricHashJoinExec( - leftKeys, rightKeys, joinType, condition, planLater(left), planLater(right)) :: Nil + leftKeys, rightKeys, joinType, (rangePreds ++ condition.map(x => Seq(x)).getOrElse(Nil)).reduceOption(And), + planLater(left), planLater(right)) :: Nil case Join(left, right, _, _) if left.isStreaming && right.isStreaming => throw new AnalysisException( @@ -629,6 +635,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.RangeExec(r) :: Nil case r: 
logical.RepartitionByExpression => exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child)) :: Nil + case r: logical.FixedRangeRepartitionByExpression => + exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child)) :: Nil case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index d2d5011bbcb97..94c83a87c0a1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -284,11 +284,10 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, left, right) - case SortMergeJoinExec(leftKeys, rightKeys, joinType, condition, left, right) => + case SortMergeJoinExec(leftKeys, rightKeys, joinType, rangeConditions, condition, left, right) => val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) - SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, condition, left, right) - + SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, rangeConditions, condition, left, right) case other => other } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index f4b9d132122e4..55cf5d71d5288 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.execution.joins import scala.collection.mutable.ArrayBuffer - import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ @@ -37,6 +37,7 @@ case class SortMergeJoinExec( leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, + rangeConditions: Seq[Expression], condition: Option[Expression], left: SparkPlan, right: SparkPlan) extends BinaryExecNode with CodegenSupport { @@ -143,6 +144,13 @@ case class SortMergeJoinExec( sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold } + private val lowerRangeExpression : Option[Expression] = { + rangeConditions.find(p => p.isInstanceOf[GreaterThan] || p.isInstanceOf[GreaterThanOrEqual]) + } + private val upperRangeExpression : Option[Expression] = { + rangeConditions.find(p => p.isInstanceOf[LessThan] || p.isInstanceOf[LessThanOrEqual]) + } + protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val spillThreshold = getSpillThreshold @@ -155,6 +163,20 @@ case class SortMergeJoinExec( (r: InternalRow) => true } } + val 
lowerRangeCondition: (InternalRow) => Boolean = { + lowerRangeExpression.map { cond => + newPredicate(cond, left.output ++ right.output).eval _ + }.getOrElse { + (r: InternalRow) => true + } + } + val upperRangeCondition: (InternalRow) => Boolean = { + upperRangeExpression.map { cond => + newPredicate(cond, left.output ++ right.output).eval _ + }.getOrElse { + (r: InternalRow) => true + } + } // An ordering that can be used to compare keys from both sides. val keyOrdering = newNaturalAscendingOrdering(leftKeys.map(_.dataType)) @@ -166,15 +188,30 @@ case class SortMergeJoinExec( private[this] var currentLeftRow: InternalRow = _ private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _ private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null - private[this] val smjScanner = new SortMergeJoinScanner( - createLeftKeyGenerator(), - createRightKeyGenerator(), - keyOrdering, - RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter), - inMemoryThreshold, - spillThreshold - ) + private[this] val smjScanner = if(lowerRangeExpression.isDefined || upperRangeExpression.isDefined) { + new SortMergeJoinInnerRangeScanner( + createLeftKeyGenerator(), + createRightKeyGenerator(), + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter), + inMemoryThreshold, + spillThreshold, + lowerRangeCondition, + upperRangeCondition + ) + } + else { + new SortMergeJoinScanner( + createLeftKeyGenerator(), + createRightKeyGenerator(), + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter), + inMemoryThreshold, + spillThreshold + ) + } private[this] val joinRow = new JoinedRow if (smjScanner.findNextInnerJoinRows()) { @@ -697,7 +734,7 @@ private[joins] class SortMergeJoinScanner( * [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join * results. */ - final def findNextInnerJoinRows(): Boolean = { + def findNextInnerJoinRows(): Boolean = { while (advancedStreamed() && streamedRowKey.anyNull) { // Advance the streamed side of the join until we find the next row whose join key contains // no nulls or we hit the end of the streamed iterator. @@ -751,7 +788,7 @@ private[joins] class SortMergeJoinScanner( * then [[getStreamedRow]] and [[getBufferedMatches]] can be called to produce the outer * join results. */ - final def findNextOuterJoinRows(): Boolean = { + def findNextOuterJoinRows(): Boolean = { if (!advancedStreamed()) { // We have consumed the entire streamed iterator, so there can be no more matches. matchJoinKey = null @@ -841,6 +878,211 @@ private[joins] class SortMergeJoinScanner( } } +/** + * Helper class that is used to implement [[SortMergeJoinExec]]. + * + * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]] + * which returns `true` if a result has been produced and `false` + * otherwise. If a result has been produced, then the caller may call [[getStreamedRow]] to return + * the matching row from the streamed input and may call [[getBufferedMatches]] to return the + * sequence of matching rows from the buffered input (in the case of an outer join, this will return + * an empty sequence if there are no matches from the buffered input). For efficiency, both of these + * methods return mutable objects which are re-used across calls to the `findNext*JoinRows()` + * methods. + * + * @param streamedKeyGenerator a projection that produces join keys from the streamed input. 
+ * @param bufferedKeyGenerator a projection that produces join keys from the buffered input. + * @param keyOrdering an ordering which can be used to compare join keys. + * @param streamedIter an input whose rows will be streamed. + * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that + * have the same join key. + * @param inMemoryThreshold Threshold for number of rows guaranteed to be held in memory by + * internal buffer + * @param spillThreshold Threshold for number of rows to be spilled by internal buffer + */ +private[joins] class SortMergeJoinInnerRangeScanner( + streamedKeyGenerator: Projection, + bufferedKeyGenerator: Projection, + keyOrdering: Ordering[InternalRow], + streamedIter: RowIterator, + bufferedIter: RowIterator, + inMemoryThreshold: Int, + spillThreshold: Int, + lowerRangeCondition: InternalRow => Boolean, + upperRangeCondition: InternalRow => Boolean) + extends SortMergeJoinScanner(streamedKeyGenerator, bufferedKeyGenerator, keyOrdering, + streamedIter, bufferedIter, inMemoryThreshold, spillThreshold) { + private[this] var streamedRow: InternalRow = _ + private[this] var streamedRowKey: InternalRow = _ + private[this] var bufferedRow: InternalRow = _ + // Note: this is guaranteed to never have any null columns: + private[this] var bufferedRowKey: InternalRow = _ + /** + * The join key for the rows buffered in `bufferedMatches`, or null if `bufferedMatches` is empty + */ + private[this] var matchJoinKey: InternalRow = _ + /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ + private[this] val bufferedMatches = + new InMemoryUnsafeRowQueue(inMemoryThreshold, spillThreshold)//ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + + private[this] val joinRow = new JoinedRow + // Initialization (note: do _not_ want to advance streamed here). + advanceBufferedToRowWithNullFreeJoinKey() + + // --- Public methods --------------------------------------------------------------------------- + + override def getStreamedRow: InternalRow = streamedRow + + def getBufferedMatches: InMemoryUnsafeRowQueue = bufferedMatches + + /** + * Advances both input iterators, stopping when we have found rows with matching join keys. + * @return true if matching rows have been found and false otherwise. If this returns true, then + * [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join + * results. + */ + override final def findNextInnerJoinRows(): Boolean = { + while (advanceStreamed() && streamedRowKey.anyNull) { + // Advance the streamed side of the join until we find the next row whose join key contains + // no nulls or we hit the end of the streamed iterator. + } + if (streamedRow == null) { + // We have consumed the entire streamed iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { + // The new streamed row has the same join key as the previous row, so the same matches can be used. + // But lower and upper ranges might not hold anymore, so check them: + // First dequeue all rows from the queue until the lower range condition holds. + // Then try to enqueue new rows with the same join key and for which the upper range condition holds. 
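        // Worked example (illustrative, not part of the patch): with the range condition
        //   b.ts >= s.ts - 5 (lower) AND b.ts <= s.ts (upper)
        // and buffered rows with ts = 1, 3, 4, 7, 9 under the current join key:
        //   streamed s.ts = 4  ->  queue = {1, 3, 4}   (7 and 9 fail the upper bound)
        //   streamed s.ts = 8  ->  dequeue 1 (1 < 8 - 5), enqueue 7; queue = {3, 4, 7}
        // Because both inputs are secondary-sorted on the range column within each join key,
        // a row evicted from the front can never match a later streamed row, which is why a
        // FIFO queue suffices here where the plain scanner needs a rewindable array.
        // Caveat: as written, bufferMatchingRows(true) below clears the queue right after the
        // dequeue step (discarding 3 and 4 in this example), and dequeueUntilLowerConditionHolds
        // tests bufferedRow rather than the queue head; both look like remnants of the commit's
        // "Not finished yet" status.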
+ dequeueUntilLowerConditionHolds() + bufferMatchingRows(true) + true + } else if (bufferedRow == null) { + // The streamed row's join key does not match the current batch of buffered rows and there are + // no more rows to read from the buffered iterator, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else { + // Advance both the streamed and buffered iterators to find the next pair of matching rows. + var comp = -1//keyOrdering.compare(streamedRowKey, bufferedRowKey) + do { + if (streamedRowKey.anyNull) { + advanceStreamed() + } else { + assert(!bufferedRowKey.anyNull) + comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + if (comp > 0) advanceBufferedToRowWithNullFreeJoinKey() + else if (comp < 0) advanceStreamed() + else comp = checkLowerBoundAndAdvanceBuffered() + } + } while (streamedRow != null && bufferedRow != null && comp != 0) + if (streamedRow == null || bufferedRow == null) { + // We have hit the end of one of the iterators, so there can be no more matches. + matchJoinKey = null + bufferedMatches.clear() + false + } else { + // The streamed row's join key matches the current buffered row's join, so walk through the + // buffered iterator to buffer the rest of the matching rows. + assert(comp == 0) + bufferMatchingRows(true) + true + } + } + } + + // --- Private methods -------------------------------------------------------------------------- + + /** + * Advance the streamed iterator and compute the new row's join key. + * @return true if the streamed iterator returned a row and false otherwise. + */ + private def advanceStreamed(): Boolean = { + if (streamedIter.advanceNext()) { + streamedRow = streamedIter.getRow + streamedRowKey = streamedKeyGenerator(streamedRow) + true + } else { + streamedRow = null + streamedRowKey = null + false + } + } + + /** + * Advance the buffered iterator until we find a row with join key that does not contain nulls. + * @return true if the buffered iterator returned a row and false otherwise. + */ + private def advanceBufferedToRowWithNullFreeJoinKey(): Boolean = { + var foundRow: Boolean = false + while (!foundRow && bufferedIter.advanceNext()) { + bufferedRow = bufferedIter.getRow + bufferedRowKey = bufferedKeyGenerator(bufferedRow) + foundRow = !bufferedRowKey.anyNull + } + if (!foundRow) { + bufferedRow = null + bufferedRowKey = null + false + } else { + true + } + } + + /** + * Advance the buffered iterator as long as the join key is the same and the lower range condition is not satisfied. + * Skip rows with nulls. + * @return Result of the join key comparison. + */ + private def checkLowerBoundAndAdvanceBuffered(): Int = { + assert(bufferedRow != null) + assert(streamedRow != null) + var comp = 0 + var lowCheck = lowerRangeCondition(joinRow(streamedRow, bufferedRow)) + if(!lowCheck) + while(!lowCheck && comp == 0 && advanceBufferedToRowWithNullFreeJoinKey()) { + comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) + if(comp == 0) + lowCheck = lowerRangeCondition(joinRow(streamedRow, bufferedRow)) + } + comp + } + + /** + * Called when the streamed and buffered join keys match in order to buffer the matching rows. 
+ */ + private def bufferMatchingRows(clear: Boolean): Unit = { + assert(streamedRowKey != null) + assert(!streamedRowKey.anyNull) + assert(bufferedRowKey != null) + assert(!bufferedRowKey.anyNull) + assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) + // This join key may have been produced by a mutable projection, so we need to make a copy: + matchJoinKey = streamedRowKey.copy() + if(clear) + bufferedMatches.clear() + var upperRangeOk = false + var lowerRangeOk = false + do { + val jr = joinRow(streamedRow, bufferedRow) + lowerRangeOk = lowerRangeCondition(jr) + upperRangeOk = upperRangeCondition(jr) + if(lowerRangeOk && upperRangeOk) + bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow]) + advanceBufferedToRowWithNullFreeJoinKey() + } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0 && upperRangeOk) + } + + private def dequeueUntilLowerConditionHolds(): Unit = { + while(!bufferedMatches.isEmpty && !lowerRangeCondition(joinRow(streamedRow, bufferedRow))) { + bufferedMatches.dequeue() + } + } +} + /** * An iterator for outputting rows in left outer join. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index 38377164c10e6..bffff1572c45d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -102,7 +102,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using ShuffledHashJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(left.sqlContext.sessionState.conf).apply( @@ -121,7 +121,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using BroadcastHashJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(left.sqlContext.sessionState.conf).apply( @@ -140,7 +140,7 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using SortMergeJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(left.sqlContext.sessionState.conf).apply( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 4408ece112258..62de27fac95b1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -128,7 +128,7 @@ class InnerJoinSuite extends SparkPlanTest with 
SharedSQLContext { } test(s"$testName using BroadcastHashJoin (build=left)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeBroadcastHashJoin( @@ -140,7 +140,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using BroadcastHashJoin (build=right)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeBroadcastHashJoin( @@ -152,7 +152,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using ShuffledHashJoin (build=left)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeShuffledHashJoin( @@ -164,7 +164,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using ShuffledHashJoin (build=right)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeShuffledHashJoin( @@ -176,7 +176,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using SortMergeJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 001feb0f2b399..d4898a1447f7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -78,7 +78,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { if (joinType != FullOuter) { test(s"$testName using ShuffledHashJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => @@ -99,7 +99,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { case RightOuter => BuildLeft case _ => fail(s"Unsupported join type $joinType") } - 
extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => BroadcastHashJoinExec( @@ -112,7 +112,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using SortMergeJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(spark.sessionState.conf).apply( From fd8161385a263e3ee72aabaeb5d1e0f6b9b36b09 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Thu, 22 Mar 2018 10:56:33 +0100 Subject: [PATCH 02/51] SortMergeJoin secondary sort optimization --- .../execution/InMemoryUnsafeRowQueue.scala | 24 ++++++++++++------- .../execution/joins/SortMergeJoinExec.scala | 2 +- .../spark/sql/execution/PlannerSuite.scala | 14 ++++++----- .../execution/joins/ExistenceJoinSuite.scala | 4 ++-- .../sql/execution/joins/InnerJoinSuite.scala | 2 +- .../sql/execution/joins/OuterJoinSuite.scala | 2 +- 6 files changed, 29 insertions(+), 19 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala index db6842c4f6c55..9bb02570bea7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala @@ -51,7 +51,15 @@ private[sql] class InMemoryUnsafeRowQueue( initialSize: Int, pageSizeBytes: Long, numRowsInMemoryBufferThreshold: Int, - numRowsSpillThreshold: Int) extends Logging { + numRowsSpillThreshold: Int) + extends ExternalAppendOnlyUnsafeRowArray(taskMemoryManager, + blockManager, + serializerManager, + taskContext, + initialSize, + pageSizeBytes, + numRowsInMemoryBufferThreshold, + numRowsSpillThreshold) { def this(numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int) { this( @@ -83,14 +91,14 @@ private[sql] class InMemoryUnsafeRowQueue( private var numFieldsPerRow = 0 - def length: Int = numRows - - def isEmpty: Boolean = numRows == 0 +// def length: Int = numRows +// +// def isEmpty: Boolean = numRows == 0 /** * Clears up resources (eg. memory) held by the backing storage */ - def clear(): Unit = { + override def clear(): Unit = { /*if (spillableArray != null) { // The last `spillableArray` of this task will be cleaned up via task completion listener // inside `UnsafeExternalSorter` @@ -114,7 +122,7 @@ private[sql] class InMemoryUnsafeRowQueue( } } - def add(unsafeRow: UnsafeRow): Unit = { + override def add(unsafeRow: UnsafeRow): Unit = { if (numRows < numRowsInMemoryBufferThreshold) { inMemoryQueue += unsafeRow.copy() } else { @@ -170,7 +178,7 @@ private[sql] class InMemoryUnsafeRowQueue( * the iterator, then the iterator is invalidated thus saving clients from thinking that they * have read all the data while there were new rows added to this array. 
*/ - def generateIterator(startIndex: Int): Iterator[UnsafeRow] = { + override def generateIterator(startIndex: Int): Iterator[UnsafeRow] = { if (startIndex < 0 || (numRows > 0 && startIndex > numRows)) { throw new ArrayIndexOutOfBoundsException( "Invalid `startIndex` provided for generating iterator over the array. " + @@ -184,7 +192,7 @@ private[sql] class InMemoryUnsafeRowQueue( }*/ } - def generateIterator(): Iterator[UnsafeRow] = generateIterator(startIndex = 0) +// override def generateIterator(): Iterator[UnsafeRow] = generateIterator(startIndex = 0) private[this] abstract class ExternalAppendOnlyUnsafeRowArrayIterator extends Iterator[UnsafeRow] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 55cf5d71d5288..551d40ecd460f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -933,7 +933,7 @@ private[joins] class SortMergeJoinInnerRangeScanner( override def getStreamedRow: InternalRow = streamedRow - def getBufferedMatches: InMemoryUnsafeRowQueue = bufferedMatches + override def getBufferedMatches: InMemoryUnsafeRowQueue = bufferedMatches /** * Advances both input iterators, stopping when we have found rows with matching join keys. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 3db89ecfad9fc..e6447fa7c4a93 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -469,6 +469,7 @@ class PlannerSuite extends SharedSQLContext { Literal(1) :: Nil, Literal(1) :: Nil, Inner, + Nil, None, shuffle, shuffle) @@ -486,6 +487,7 @@ class PlannerSuite extends SharedSQLContext { Literal(1) :: Nil, Literal(1) :: Nil, Inner, + Nil, None, ShuffleExchangeExec(finalPartitioning, inputPlan), ShuffleExchangeExec(finalPartitioning, inputPlan)) @@ -542,7 +544,7 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements skips sort when either side of join keys is required after inner SMJ") { Seq(Inner, Cross).foreach { joinType => - val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, None, planA, planB) + val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, Nil, None, planA, planB) // Both left and right keys should be sorted after the SMJ. Seq(orderingA, orderingB).foreach { ordering => assertSortRequirementsAreSatisfied( @@ -556,8 +558,8 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements skips sort when key order of a parent SMJ is propagated from its " + "child SMJ") { Seq(Inner, Cross).foreach { joinType => - val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, None, planA, planB) - val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, joinType, None, childSmj, planC) + val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, Nil, None, planA, planB) + val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, joinType, Nil, None, childSmj, planC) // After the second SMJ, exprA, exprB and exprC should all be sorted. 
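      // Note (illustrative, not part of the patch): an inner SMJ emits rows ordered by its join
      // keys, and the equi-join makes the key columns equal on output, so one sort order can
      // satisfy orderings on exprA, exprB and exprC at once. Passing Nil for the new
      // rangeConditions parameter keeps these planner tests on the pre-patch code path.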
Seq(orderingA, orderingB, orderingC).foreach { ordering => assertSortRequirementsAreSatisfied( @@ -570,7 +572,7 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements for sort operator after left outer sort merge join") { // Only left key is sorted after left outer SMJ (thus doesn't need a sort). - val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, None, planA, planB) + val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, Nil, None, planA, planB) Seq((orderingA, false), (orderingB, true)).foreach { case (ordering, needSort) => assertSortRequirementsAreSatisfied( childPlan = leftSmj, @@ -581,7 +583,7 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements for sort operator after right outer sort merge join") { // Only right key is sorted after right outer SMJ (thus doesn't need a sort). - val rightSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, RightOuter, None, planA, planB) + val rightSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, RightOuter, Nil, None, planA, planB) Seq((orderingA, true), (orderingB, false)).foreach { case (ordering, needSort) => assertSortRequirementsAreSatisfied( childPlan = rightSmj, @@ -592,7 +594,7 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements adds sort after full outer sort merge join") { // Neither keys is sorted after full outer SMJ, so they both need sorts. - val fullSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, FullOuter, None, planA, planB) + val fullSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, FullOuter, Nil, None, planA, planB) Seq(orderingA, orderingB).foreach { ordering => assertSortRequirementsAreSatisfied( childPlan = fullSmj, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index bffff1572c45d..9e6a38765d2b4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -144,13 +144,13 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(left.sqlContext.sessionState.conf).apply( - SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, left, right)), + SortMergeJoinExec(leftKeys, rightKeys, joinType, Nil, boundCondition, left, right)), expectedAnswer, sortAnswers = true) checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(left.sqlContext.sessionState.conf).apply( createLeftSemiPlusJoin(SortMergeJoinExec( - leftKeys, rightKeys, leftSemiPlus, boundCondition, left, right))), + leftKeys, rightKeys, leftSemiPlus, Nil, boundCondition, left, right))), expectedAnswer, sortAnswers = true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 62de27fac95b1..b94553da10cab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -122,7 +122,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { boundCondition: Option[Expression], leftPlan: SparkPlan, rightPlan: SparkPlan) = { - val 
sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, boundCondition, + val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, Nil, boundCondition, leftPlan, rightPlan) EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index d4898a1447f7b..417b356f4e07c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -116,7 +116,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(spark.sessionState.conf).apply( - SortMergeJoinExec(leftKeys, rightKeys, joinType, boundCondition, left, right)), + SortMergeJoinExec(leftKeys, rightKeys, joinType, Nil, boundCondition, left, right)), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } From f1efa9b3e97bb6e2f111c7402e193f6a02b94b30 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Wed, 4 Apr 2018 20:17:47 +0200 Subject: [PATCH 03/51] Sort-Merge "inner range join" (secondary sort) - code generation --- .../UnsupportedOperationChecker.scala | 3 +- .../sql/catalyst/planning/patterns.scala | 8 +- .../execution/InMemoryUnsafeRowQueue.scala | 4 + .../spark/sql/execution/SparkStrategies.scala | 3 + .../execution/joins/SortMergeJoinExec.scala | 365 ++++++++++++++---- 5 files changed, 294 insertions(+), 89 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index cff4cee09427f..fad5e0574e503 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -18,9 +18,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, CurrentDate, CurrentTimestamp, MonotonicallyIncreasingID} +import org.apache.spark.sql.catalyst.expressions.{Attribute, CurrentDate, CurrentTimestamp, MonotonicallyIncreasingID} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 914f1a9f0695f..18bad3546ec31 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -17,8 +17,10 @@ package org.apache.spark.sql.catalyst.planning +import org.apache.avro.hadoop.file.HadoopCodecFactory import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.catalog.BucketSpec import 
org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ @@ -102,7 +104,8 @@ object PhysicalOperation extends PredicateHelper { object ExtractEquiJoinKeys extends Logging with PredicateHelper { /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */ type ReturnType = - (JoinType, Seq[Expression], Seq[Expression], Seq[Expression], Option[Expression], LogicalPlan, LogicalPlan) + (JoinType, Seq[Expression], Seq[Expression], Seq[Expression], + Option[Expression], LogicalPlan, LogicalPlan) def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { case join @ Join(left, right, joinType, condition) => @@ -162,7 +165,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { } } val (leftKeys, rightKeys) = joinKeys.unzip - logDebug(s"leftKeys:$leftKeys | rightKeys:$rightKeys") + Some((joinType, leftKeys, rightKeys, rangeConditions, otherPredicates.filterNot(rangePreds.contains(_)).reduceOption(And), left, right)) } else { @@ -171,7 +174,6 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { case _ => None } - private def isValidRangeCondition(l:Expression, r:Expression, left:LogicalPlan, right:LogicalPlan, joinKeys:Seq[(Expression, Expression)]) = { val (lattrs, rattrs) = (l.references.toSeq, r.references.toSeq) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala index 9bb02570bea7d..b5a51a07862e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala @@ -122,6 +122,10 @@ private[sql] class InMemoryUnsafeRowQueue( } } + def get(idx:Int): UnsafeRow = { + inMemoryQueue(idx) + } + override def add(unsafeRow: UnsafeRow): Unit = { if (numRows < numRowsInMemoryBufferThreshold) { inMemoryQueue += unsafeRow.copy() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d18a2c6ac6625..cdfe3308c2096 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -279,6 +279,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, rangeConds, condition, left, right) if RowOrdering.isOrderable(leftKeys) => // val cond = (rangeConds ++ condition.map(x => Seq(x)).getOrElse(Nil)).reduceOption(And) + logDebug(s"SMJExecJoinType: $joinType") + //TODO check if left and right are reading from bucketed tables and those are bucketed by join key + // and sorted by keys in range conditions joins.SortMergeJoinExec( leftKeys, rightKeys, joinType, rangeConds, condition, planLater(left), planLater(right)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 551d40ecd460f..c30a301042e28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.plans._ import 
org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.types._ import org.apache.spark.util.collection.BitSet /** @@ -42,6 +43,10 @@ case class SortMergeJoinExec( left: SparkPlan, right: SparkPlan) extends BinaryExecNode with CodegenSupport { + logDebug(s"SortMergeJoinExec args: leftKeys: $leftKeys, rightKeys: $rightKeys, joinType: $joinType," + + s" rangeConditions: $rangeConditions, " + + s"condition: $condition, left: $left, right: $right") + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -145,16 +150,24 @@ case class SortMergeJoinExec( } private val lowerRangeExpression : Option[Expression] = { - rangeConditions.find(p => p.isInstanceOf[GreaterThan] || p.isInstanceOf[GreaterThanOrEqual]) + logDebug(s"Finding greaterThan expressions in $rangeConditions") + val thefind = rangeConditions.find(p => p.isInstanceOf[GreaterThan] || p.isInstanceOf[GreaterThanOrEqual]) + logDebug(s"Found greaterThan expression: $thefind") + thefind } private val upperRangeExpression : Option[Expression] = { - rangeConditions.find(p => p.isInstanceOf[LessThan] || p.isInstanceOf[LessThanOrEqual]) + logDebug(s"Finding lowerThan expressions in $rangeConditions") + val thefind = rangeConditions.find(p => p.isInstanceOf[LessThan] || p.isInstanceOf[LessThanOrEqual]) + logDebug(s"Found lowerThan expression: $thefind") + thefind } protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val spillThreshold = getSpillThreshold val inMemoryThreshold = getInMemoryThreshold + if(lowerRangeExpression.isDefined || upperRangeExpression.isDefined) + logDebug("Should be using SortMergeJoinInnerRangeScanner") left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => val boundCondition: (InternalRow) => Boolean = { condition.map { cond => @@ -458,88 +471,270 @@ case class SortMergeJoinExec( * matched one row from left side and buffered rows from right side. */ private def genScanner(ctx: CodegenContext): (String, String) = { - // Create class member for next row from both sides. - // Inline mutable state since not many join operations in a task - val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) - val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) - - // Create variables for join keys from both sides. - val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) - val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ") - val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output) - val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ") - // Copy the right key as class members so they could be used in next function call. - val rightKeyVars = copyKeys(ctx, rightKeyTmpVars) - - // A list to hold all matched rows from right side. - val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName + if(lowerRangeExpression.isDefined || upperRangeExpression.isDefined) { + logInfo("SortMergeJoinExec: generating inner range join scanner") + // Create class member for next row from both sides. 
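+      // (leftRow/rightRow track the current input rows; rightTmpRow is used only to
+      // peek at the head of the match queue when re-checking the lower range bound)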
+ // Inline mutable state since not many join operations in a task + val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) + val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) + val rightTmpRow = ctx.addMutableState("InternalRow", "rightTmpRow", forceInline = true) + + // Create variables for join keys from both sides. + val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) + val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ") + val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output) + val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ") + // Copy the right key as class members so they could be used in next function call. + val rightKeyVars = copyKeys(ctx, rightKeyTmpVars) + + val rangeKeys = rangeConditions.map{ + case GreaterThan(l, r) => (Some(l), None, Some(r), None) + case GreaterThanOrEqual(l, r) => (Some(l), None, Some(r), None) + case LessThan(l, r) => (None, Some(l), None, Some(r)) + case LessThanOrEqual(l, r) => (None, Some(l), None, Some(r)) + } + val (leftLowerKeys, leftUpperKeys, rightLowerKeys, rightUpperKeys) = + (rangeKeys.map(_._1).flatMap(x => x), + rangeKeys.map(_._2).flatMap(x => x), + rangeKeys.map(_._3).flatMap(x => x), + rangeKeys.map(_._4).flatMap(x => x)) + + // Variables for range expressions= + val leftLowerKeyVars = createJoinKey(ctx, leftRow, leftLowerKeys, left.output) + val leftUpperKeyVars = createJoinKey(ctx, leftRow, leftUpperKeys, left.output) + val rightLowerKeyVars = createJoinKey(ctx, rightRow, rightLowerKeys, right.output) + val rightUpperKeyVars = createJoinKey(ctx, rightRow, rightUpperKeys, right.output) + + val dataType = if(leftLowerKeys.size > 0) leftLowerKeys(0).dataType + else leftUpperKeys(0).dataType + val initValue = CodeGenerator.defaultValue(dataType) + val leftLowerRangeKey = ctx.addBufferedState(dataType, "leftLowerRangeKey", initValue) + val leftUpperRangeKey = ctx.addBufferedState(dataType, "leftUpperRangeKey", initValue) + val rightLowerRangeKey = ctx.addBufferedState(dataType, "rightLowerRangeKey", initValue) + val rightUpperRangeKey = ctx.addBufferedState(dataType, "rightUpperRangeKey", initValue) + + // A queue to hold all matched rows from right side. 
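+      // Unlike ExternalAppendOnlyUnsafeRowArray, InMemoryUnsafeRowQueue can drop rows
+      // from the front (dequeue) once they fall below the advancing lower range bound
+      // of the current left row, while the equi-join key stays the same.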
+ val clsName = classOf[InMemoryUnsafeRowQueue].getName + + val spillThreshold = getSpillThreshold + val inMemoryThreshold = getInMemoryThreshold + + val matches = ctx.addMutableState(clsName, "matches", + v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) + val matchedKeyVars = copyKeys(ctx, leftKeyVars) + + val lowerCompop = lowerRangeExpression.get match { + case GreaterThanOrEqual(_, _) => "<" + case GreaterThan(_, _) => "<=" + case _ => "" + } + val upperCompop = upperRangeExpression.get match { + case LessThanOrEqual(_, _) => ">" + case LessThan(_, _) => ">=" + case _ => "" + } + val lowerCompExp = if(lowerRangeExpression.isEmpty) "" + else s" || (comp == 0 && ${leftLowerRangeKey.value} $lowerCompop ${rightLowerRangeKey.value})" + val upperCompExp = if(upperRangeExpression.isEmpty) "" + else s" || (comp == 0 && ${leftUpperRangeKey.value} $upperCompop ${rightUpperRangeKey.value})" + + logDebug(s"lowerCompExp: $lowerCompExp") + logDebug(s"upperCompExp: $upperCompExp") + + if(lowerRangeExpression.isEmpty || rightLowerKeys.size == 0) { + ctx.addNewFunction("dequeueUntilLowerConditionHolds", + "private void dequeueUntilLowerConditionHolds() { }", + inlineToOuterClass = true) + } + else { + val rightRngTmpKeyVars = createJoinKey(ctx, rightTmpRow, rightUpperKeys.slice(0, 1), right.output) + val rightRngTmpKeyVarsDecl = rightRngTmpKeyVars.map(_.code).mkString("\n") + rightRngTmpKeyVars.foreach(_.code = "") + val javaType = CodeGenerator.javaType(rightLowerKeys(0).dataType) - val spillThreshold = getSpillThreshold - val inMemoryThreshold = getInMemoryThreshold + ctx.addNewFunction("getRightTmpRangeValue", + s""" + |private $javaType getRightTmpRangeValue() { + | $rightRngTmpKeyVarsDecl + | return ${rightRngTmpKeyVars(0).value}; + |} + """.stripMargin) - // Inline mutable state since not many join operations in a task - val matches = ctx.addMutableState(clsName, "matches", - v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) - // Copy the left keys as class members so they could be used in next function call. 
- val matchedKeyVars = copyKeys(ctx, leftKeyVars) + ctx.addNewFunction("dequeueUntilLowerConditionHolds", + s""" + |private void dequeueUntilLowerConditionHolds() { + | if($matches.isEmpty()) + | return; + | $rightTmpRow = $matches.get(0); + | $javaType tempVal = getRightTmpRangeValue(); + | while(!$matches.isEmpty() && ${leftLowerRangeKey.value} $upperCompop tempVal) { + | $matches.dequeue(); + | $rightTmpRow = $matches.get(0); + | tempVal = getRightTmpRangeValue(); + | } + |} + """.stripMargin, inlineToOuterClass = true) + } - ctx.addNewFunction("findNextInnerJoinRows", - s""" - |private boolean findNextInnerJoinRows( - | scala.collection.Iterator leftIter, - | scala.collection.Iterator rightIter) { - | $leftRow = null; - | int comp = 0; - | while ($leftRow == null) { - | if (!leftIter.hasNext()) return false; - | $leftRow = (InternalRow) leftIter.next(); - | ${leftKeyVars.map(_.code).mkString("\n")} - | if ($leftAnyNull) { - | $leftRow = null; - | continue; - | } - | if (!$matches.isEmpty()) { - | ${genComparison(ctx, leftKeyVars, matchedKeyVars)} - | if (comp == 0) { - | return true; - | } - | $matches.clear(); - | } - | - | do { - | if ($rightRow == null) { - | if (!rightIter.hasNext()) { - | ${matchedKeyVars.map(_.code).mkString("\n")} - | return !$matches.isEmpty(); - | } - | $rightRow = (InternalRow) rightIter.next(); - | ${rightKeyTmpVars.map(_.code).mkString("\n")} - | if ($rightAnyNull) { - | $rightRow = null; - | continue; - | } - | ${rightKeyVars.map(_.code).mkString("\n")} - | } - | ${genComparison(ctx, leftKeyVars, rightKeyVars)} - | if (comp > 0) { - | $rightRow = null; - | } else if (comp < 0) { - | if (!$matches.isEmpty()) { - | ${matchedKeyVars.map(_.code).mkString("\n")} - | return true; - | } - | $leftRow = null; - | } else { - | $matches.add((UnsafeRow) $rightRow); - | $rightRow = null;; - | } - | } while ($leftRow != null); - | } - | return false; // unreachable - |} + val (leftLowAssignCode, rightLowAssignCode) = lowerRangeExpression.map(_ => + (s"${leftLowerRangeKey.value} = ${leftLowerKeyVars(0).value};", s"${rightLowerRangeKey.value} = ${rightLowerKeyVars(0).value};")). + getOrElse(("", "")) + val (leftUpperAssignCode, rightUpperAssignCode) = lowerRangeExpression.map(_ => + (s"${leftUpperRangeKey.value} = ${leftUpperKeyVars(0).value};", s"${rightUpperRangeKey.value} = ${rightUpperKeyVars(0).value};")). 
+ getOrElse(("", "")) + + ctx.addNewFunction("findNextInnerJoinRows", + s""" + |private boolean findNextInnerJoinRows( + | scala.collection.Iterator leftIter, + | scala.collection.Iterator rightIter) { + | $leftRow = null; + | int comp = 0; + | while ($leftRow == null) { + | if (!leftIter.hasNext()) return false; + | $leftRow = (InternalRow) leftIter.next(); + | ${leftKeyVars.map(_.code).mkString("\n")} + | if ($leftAnyNull) { + | $leftRow = null; + | continue; + | } + | ${leftLowerKeyVars.map(_.code).mkString("\n")} + | ${leftUpperKeyVars.map(_.code).mkString("\n")} + | $leftLowAssignCode + | $leftUpperAssignCode + | if (!$matches.isEmpty()) { + | ${genComparison(ctx, leftKeyVars, matchedKeyVars)} + | if (comp == 0) { + | dequeueUntilLowerConditionHolds(); + | } + | else { + | $matches.clear(); + | } + | } + | + | do { + | if ($rightRow == null) { + | if (!rightIter.hasNext()) { + | ${matchedKeyVars.map(_.code).mkString("\n")} + | return !$matches.isEmpty(); + | } + | $rightRow = (InternalRow) rightIter.next(); + | ${rightKeyTmpVars.map(_.code).mkString("\n")} + | if ($rightAnyNull) { + | $rightRow = null; + | continue; + | } + | ${rightKeyVars.map(_.code).mkString("\n")} + | ${rightLowerKeyVars.map(_.code).mkString("\n")} + | ${rightUpperKeyVars.map(_.code).mkString("\n")} + | $rightLowAssignCode + | $rightUpperAssignCode + | } + | ${genComparison(ctx, leftKeyVars, rightKeyVars)} + | if (comp > 0 $upperCompExp) { + | $rightRow = null; + | } else if (comp < 0 $lowerCompExp) { + | if (!$matches.isEmpty()) { + | ${matchedKeyVars.map(_.code).mkString("\n")} + | return true; + | } + | $leftRow = null; + | } else { + | $matches.add((UnsafeRow) $rightRow); + | $rightRow = null; + | } + | } while ($leftRow != null); + | } + | return false; // unreachable + |} """.stripMargin, inlineToOuterClass = true) - (leftRow, matches) + (leftRow, matches) + } + else { + // Create class member for next row from both sides. + // Inline mutable state since not many join operations in a task + val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) + val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) + + // Create variables for join keys from both sides. + val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) + val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ") + val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output) + val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ") + // Copy the right key as class members so they could be used in next function call. + val rightKeyVars = copyKeys(ctx, rightKeyTmpVars) + + // A list to hold all matched rows from right side. + val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName + + val spillThreshold = getSpillThreshold + val inMemoryThreshold = getInMemoryThreshold + + // Inline mutable state since not many join operations in a task + val matches = ctx.addMutableState(clsName, "matches", + v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) + // Copy the left keys as class members so they could be used in next function call. 
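+      // (leftKeyVars are bound to the current leftRow, which changes as the stream
+      // advances, so a stable copy is needed for comparisons across calls)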
+ val matchedKeyVars = copyKeys(ctx, leftKeyVars) + + ctx.addNewFunction("findNextInnerJoinRows", + s""" + |private boolean findNextInnerJoinRows( + | scala.collection.Iterator leftIter, + | scala.collection.Iterator rightIter) { + | $leftRow = null; + | int comp = 0; + | while ($leftRow == null) { + | if (!leftIter.hasNext()) return false; + | $leftRow = (InternalRow) leftIter.next(); + | ${leftKeyVars.map(_.code).mkString("\n")} + | if ($leftAnyNull) { + | $leftRow = null; + | continue; + | } + | if (!$matches.isEmpty()) { + | ${genComparison(ctx, leftKeyVars, matchedKeyVars)} + | if (comp == 0) { + | return true; + | } + | $matches.clear(); + | } + | + | do { + | if ($rightRow == null) { + | if (!rightIter.hasNext()) { + | ${matchedKeyVars.map(_.code).mkString("\n")} + | return !$matches.isEmpty(); + | } + | $rightRow = (InternalRow) rightIter.next(); + | ${rightKeyTmpVars.map(_.code).mkString("\n")} + | if ($rightAnyNull) { + | $rightRow = null; + | continue; + | } + | ${rightKeyVars.map(_.code).mkString("\n")} + | } + | ${genComparison(ctx, leftKeyVars, rightKeyVars)} + | if (comp > 0) { + | $rightRow = null; + | } else if (comp < 0) { + | if (!$matches.isEmpty()) { + | ${matchedKeyVars.map(_.code).mkString("\n")} + | return true; + | } + | $leftRow = null; + | } else { + | $matches.add((UnsafeRow) $rightRow); + | $rightRow = null;; + | } + | } while ($leftRow != null); + | } + | return false; // unreachable + |} + """.stripMargin, inlineToOuterClass = true) + + (leftRow, matches) + } } /** @@ -598,9 +793,10 @@ case class SortMergeJoinExec( */ private def splitVarsByCondition( attributes: Seq[Attribute], - variables: Seq[ExprCode]): (String, String) = { - if (condition.isDefined) { - val condRefs = condition.get.references + variables: Seq[ExprCode], + cond: Option[Expression]): (String, String) = { + if (cond.isDefined) { + val condRefs = cond.get.references val (used, notUsed) = attributes.zip(variables).partition{ case (a, ev) => condRefs.contains(a) } @@ -633,8 +829,8 @@ case class SortMergeJoinExec( val (beforeLoop, condCheck) = if (condition.isDefined) { // Split the code of creating variables based on whether it's used by condition or not. val loaded = ctx.freshName("loaded") - val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars) - val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) + val (leftBefore, leftAfter) = splitVarsByCondition(left.output, leftVars, condition) + val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars, condition) // Generate code for condition ctx.currentVars = leftVars ++ rightVars val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) @@ -922,6 +1118,7 @@ private[joins] class SortMergeJoinInnerRangeScanner( */ private[this] var matchJoinKey: InternalRow = _ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. 
*/ + private[this] val bufferedMatches = new InMemoryUnsafeRowQueue(inMemoryThreshold, spillThreshold)//ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) From f533f6521293c8e2add817bed870e0255e52757b Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Thu, 5 Apr 2018 20:24:59 +0200 Subject: [PATCH 04/51] Sort-Merge "inner range join" (secondary sort) - two bug fixes - works now --- .../spark/sql/execution/InMemoryUnsafeRowQueue.scala | 2 ++ .../sql/execution/joins/SortMergeJoinExec.scala | 12 +++++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala index b5a51a07862e8..1112326d8d3da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala @@ -85,6 +85,8 @@ private[sql] class InMemoryUnsafeRowQueue( // private var spillableArray: UnsafeExternalSorter = _ private var numRows = 0 + override def isEmpty: Boolean = numRows == 0 + // A counter to keep track of total modifications done to this array since its creation. // This helps to invalidate iterators when there are changes done to the backing array. private var modificationsCount: Long = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index c30a301042e28..e7fcf5a599914 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -523,16 +523,16 @@ case class SortMergeJoinExec( v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) val matchedKeyVars = copyKeys(ctx, leftKeyVars) - val lowerCompop = lowerRangeExpression.get match { + val lowerCompop = lowerRangeExpression.map { case GreaterThanOrEqual(_, _) => "<" case GreaterThan(_, _) => "<=" case _ => "" - } - val upperCompop = upperRangeExpression.get match { + }.getOrElse("") + val upperCompop = upperRangeExpression.map { case LessThanOrEqual(_, _) => ">" case LessThan(_, _) => ">=" case _ => "" - } + }.getOrElse("") val lowerCompExp = if(lowerRangeExpression.isEmpty) "" else s" || (comp == 0 && ${leftLowerRangeKey.value} $lowerCompop ${rightLowerRangeKey.value})" val upperCompExp = if(upperRangeExpression.isEmpty) "" @@ -567,8 +567,10 @@ case class SortMergeJoinExec( | return; | $rightTmpRow = $matches.get(0); | $javaType tempVal = getRightTmpRangeValue(); - | while(!$matches.isEmpty() && ${leftLowerRangeKey.value} $upperCompop tempVal) { + | while(${leftLowerRangeKey.value} $upperCompop tempVal) { | $matches.dequeue(); + | if($matches.isEmpty()) + | break; | $rightTmpRow = $matches.get(0); | tempVal = getRightTmpRangeValue(); | } From 2ff492c58a55a59313d7d3f083d4d4cafff65a9e Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Mon, 9 Apr 2018 15:20:15 +0200 Subject: [PATCH 05/51] Code simplification --- .../execution/joins/SortMergeJoinExec.scala | 460 ++++++++---------- 1 file changed, 209 insertions(+), 251 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index e7fcf5a599914..f62243449e20e 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.types._ import org.apache.spark.util.collection.BitSet /** @@ -41,12 +40,47 @@ case class SortMergeJoinExec( rangeConditions: Seq[Expression], condition: Option[Expression], left: SparkPlan, - right: SparkPlan) extends BinaryExecNode with CodegenSupport { + right: SparkPlan) extends BinaryExecNode with PredicateHelper with CodegenSupport { logDebug(s"SortMergeJoinExec args: leftKeys: $leftKeys, rightKeys: $rightKeys, joinType: $joinType," + s" rangeConditions: $rangeConditions, " + s"condition: $condition, left: $left, right: $right") +// val leftBucketSpec = findBucketSpec(left) +// val rightBucketSpec = findBucketSpec(right) +// +// logDebug(s"Found left bucket spec: $leftBucketSpec. Found right bucket spec: $rightBucketSpec") + + private val lowerSecondaryRangeExpression : Option[Expression] = { + logDebug(s"Finding secondary greaterThan expressions in $rangeConditions") + val thefind = rangeConditions.find(p => p.isInstanceOf[GreaterThan] || p.isInstanceOf[GreaterThanOrEqual]) + logDebug(s"Found secondary greaterThan expression: $thefind") + thefind + } + private val upperSecondaryRangeExpression : Option[Expression] = { + logDebug(s"Finding secondary lowerThan expressions in $rangeConditions") + val thefind = rangeConditions.find(p => p.isInstanceOf[LessThan] || p.isInstanceOf[LessThanOrEqual]) + logDebug(s"Found secondary lowerThan expression: $thefind") + thefind + } + + val useSecondaryRange = shouldUseSecondaryRangeJoin() + + logDebug(s"Use secondary range join resolved to $useSecondaryRange.") + +// private def findBucketSpec(plan: SparkPlan): Option[BucketSpec] = { +// if(plan.isInstanceOf[FileSourceScanExec]) { +// plan.asInstanceOf[FileSourceScanExec].relation.bucketSpec +// } +// else { +// val cb = plan.children.flatMap(c => findBucketSpec(c)) +// if (cb.size > 0) +// Some(cb(0)) +// else +// None +// } +// } + override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -149,25 +183,10 @@ case class SortMergeJoinExec( sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold } - private val lowerRangeExpression : Option[Expression] = { - logDebug(s"Finding greaterThan expressions in $rangeConditions") - val thefind = rangeConditions.find(p => p.isInstanceOf[GreaterThan] || p.isInstanceOf[GreaterThanOrEqual]) - logDebug(s"Found greaterThan expression: $thefind") - thefind - } - private val upperRangeExpression : Option[Expression] = { - logDebug(s"Finding lowerThan expressions in $rangeConditions") - val thefind = rangeConditions.find(p => p.isInstanceOf[LessThan] || p.isInstanceOf[LessThanOrEqual]) - logDebug(s"Found lowerThan expression: $thefind") - thefind - } - protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val spillThreshold = getSpillThreshold val inMemoryThreshold = getInMemoryThreshold - if(lowerRangeExpression.isDefined || upperRangeExpression.isDefined) - logDebug("Should be using SortMergeJoinInnerRangeScanner") left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => val boundCondition: (InternalRow) => 
Boolean = { condition.map { cond => @@ -177,14 +196,14 @@ case class SortMergeJoinExec( } } val lowerRangeCondition: (InternalRow) => Boolean = { - lowerRangeExpression.map { cond => + lowerSecondaryRangeExpression.map { cond => newPredicate(cond, left.output ++ right.output).eval _ }.getOrElse { (r: InternalRow) => true } } val upperRangeCondition: (InternalRow) => Boolean = { - upperRangeExpression.map { cond => + upperSecondaryRangeExpression.map { cond => newPredicate(cond, left.output ++ right.output).eval _ }.getOrElse { (r: InternalRow) => true @@ -201,7 +220,7 @@ case class SortMergeJoinExec( private[this] var currentLeftRow: InternalRow = _ private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _ private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null - private[this] val smjScanner = if(lowerRangeExpression.isDefined || upperRangeExpression.isDefined) { + private[this] val smjScanner = if(lowerSecondaryRangeExpression.isDefined || upperSecondaryRangeExpression.isDefined) { new SortMergeJoinInnerRangeScanner( createLeftKeyGenerator(), createRightKeyGenerator(), @@ -466,82 +485,98 @@ case class SortMergeJoinExec( """.stripMargin } + private def shouldUseSecondaryRangeJoin():Boolean = { + //TODO check sorting of the two relations? - Check it during planning? + lowerSecondaryRangeExpression.isDefined || upperSecondaryRangeExpression.isDefined + } + /** * Generate a function to scan both left and right to find a match, returns the term for * matched one row from left side and buffered rows from right side. */ private def genScanner(ctx: CodegenContext): (String, String) = { - if(lowerRangeExpression.isDefined || upperRangeExpression.isDefined) { - logInfo("SortMergeJoinExec: generating inner range join scanner") - // Create class member for next row from both sides. - // Inline mutable state since not many join operations in a task - val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) - val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) - val rightTmpRow = ctx.addMutableState("InternalRow", "rightTmpRow", forceInline = true) - - // Create variables for join keys from both sides. - val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) - val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ") - val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output) - val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ") - // Copy the right key as class members so they could be used in next function call. 
- val rightKeyVars = copyKeys(ctx, rightKeyTmpVars) - - val rangeKeys = rangeConditions.map{ - case GreaterThan(l, r) => (Some(l), None, Some(r), None) - case GreaterThanOrEqual(l, r) => (Some(l), None, Some(r), None) - case LessThan(l, r) => (None, Some(l), None, Some(r)) - case LessThanOrEqual(l, r) => (None, Some(l), None, Some(r)) - } - val (leftLowerKeys, leftUpperKeys, rightLowerKeys, rightUpperKeys) = - (rangeKeys.map(_._1).flatMap(x => x), - rangeKeys.map(_._2).flatMap(x => x), - rangeKeys.map(_._3).flatMap(x => x), - rangeKeys.map(_._4).flatMap(x => x)) - - // Variables for range expressions= - val leftLowerKeyVars = createJoinKey(ctx, leftRow, leftLowerKeys, left.output) - val leftUpperKeyVars = createJoinKey(ctx, leftRow, leftUpperKeys, left.output) - val rightLowerKeyVars = createJoinKey(ctx, rightRow, rightLowerKeys, right.output) - val rightUpperKeyVars = createJoinKey(ctx, rightRow, rightUpperKeys, right.output) - - val dataType = if(leftLowerKeys.size > 0) leftLowerKeys(0).dataType - else leftUpperKeys(0).dataType - val initValue = CodeGenerator.defaultValue(dataType) - val leftLowerRangeKey = ctx.addBufferedState(dataType, "leftLowerRangeKey", initValue) - val leftUpperRangeKey = ctx.addBufferedState(dataType, "leftUpperRangeKey", initValue) - val rightLowerRangeKey = ctx.addBufferedState(dataType, "rightLowerRangeKey", initValue) - val rightUpperRangeKey = ctx.addBufferedState(dataType, "rightUpperRangeKey", initValue) - - // A queue to hold all matched rows from right side. - val clsName = classOf[InMemoryUnsafeRowQueue].getName - - val spillThreshold = getSpillThreshold - val inMemoryThreshold = getInMemoryThreshold - - val matches = ctx.addMutableState(clsName, "matches", - v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) - val matchedKeyVars = copyKeys(ctx, leftKeyVars) - - val lowerCompop = lowerRangeExpression.map { - case GreaterThanOrEqual(_, _) => "<" - case GreaterThan(_, _) => "<=" - case _ => "" - }.getOrElse("") - val upperCompop = upperRangeExpression.map { - case LessThanOrEqual(_, _) => ">" - case LessThan(_, _) => ">=" - case _ => "" - }.getOrElse("") - val lowerCompExp = if(lowerRangeExpression.isEmpty) "" - else s" || (comp == 0 && ${leftLowerRangeKey.value} $lowerCompop ${rightLowerRangeKey.value})" - val upperCompExp = if(upperRangeExpression.isEmpty) "" - else s" || (comp == 0 && ${leftUpperRangeKey.value} $upperCompop ${rightUpperRangeKey.value})" - - logDebug(s"lowerCompExp: $lowerCompExp") - logDebug(s"upperCompExp: $upperCompExp") - - if(lowerRangeExpression.isEmpty || rightLowerKeys.size == 0) { + logInfo("SortMergeJoinE xec: generating inner range join scanner") + // Create class member for next row from both sides. + // Inline mutable state since not many join operations in a task + val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) + val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) + val rightTmpRow = if(useSecondaryRange) ctx.addMutableState("InternalRow", "rightTmpRow", forceInline = true) + else "" + + // Create variables for join keys from both sides. + val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) + val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ") + val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output) + val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ") + // Copy the right key as class members so they could be used in next function call. 
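+    // (rightKeyTmpVars are re-evaluated for every row read from rightIter, so only
+    // the copies remain valid once the iterator advances)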
+ val rightKeyVars = copyKeys(ctx, rightKeyTmpVars) + + val rangeKeys = rangeConditions.map{ + case GreaterThan(l, r) => (Some(l), None, Some(r), None) + case GreaterThanOrEqual(l, r) => (Some(l), None, Some(r), None) + case LessThan(l, r) => (None, Some(l), None, Some(r)) + case LessThanOrEqual(l, r) => (None, Some(l), None, Some(r)) + } + val (leftLowerKeys, leftUpperKeys, rightLowerKeys, rightUpperKeys) = + (rangeKeys.map(_._1).flatMap(x => x), + rangeKeys.map(_._2).flatMap(x => x), + rangeKeys.map(_._3).flatMap(x => x), + rangeKeys.map(_._4).flatMap(x => x)) + + // Variables for secondary range expressions + val (leftLowerKeyVars, leftUpperKeyVars, rightLowerKeyVars, rightUpperKeyVars) = + if(useSecondaryRange) + (createJoinKey(ctx, leftRow, leftLowerKeys, left.output), + createJoinKey(ctx, leftRow, leftUpperKeys, left.output), + createJoinKey(ctx, rightRow, rightLowerKeys, right.output), + createJoinKey(ctx, rightRow, rightUpperKeys, right.output)) + else (null, null, null, null) + + val secRangeDataType = if(leftLowerKeys.size > 0) leftLowerKeys(0).dataType + else if(leftUpperKeys.size > 0) leftUpperKeys(0).dataType + else null + val secRangeInitValue = CodeGenerator.defaultValue(secRangeDataType) + + val (leftLowerSecRangeKey, leftUpperSecRangeKey, rightLowerSecRangeKey, rightUpperSecRangeKey) = + if(useSecondaryRange) + (ctx.addBufferedState(secRangeDataType, "leftLowerSecRangeKey", secRangeInitValue), + ctx.addBufferedState(secRangeDataType, "leftUpperSecRangeKey", secRangeInitValue), + ctx.addBufferedState(secRangeDataType, "rightLowerSecRangeKey", secRangeInitValue), + ctx.addBufferedState(secRangeDataType, "rightUpperSecRangeKey", secRangeInitValue)) + else (null, null, null, null) + + // A queue to hold all matched rows from right side. 
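+    // With the secondary range optimization the buffer must support dequeuing rows
+    // that no longer satisfy the lower bound, so InMemoryUnsafeRowQueue is chosen;
+    // note that, unlike ExternalAppendOnlyUnsafeRowArray, the queue does not spill
+    // and throws once numRowsInMemoryBufferThreshold is exceeded.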
+ val clsName = if(useSecondaryRange) classOf[InMemoryUnsafeRowQueue].getName + else classOf[ExternalAppendOnlyUnsafeRowArray].getName + + val spillThreshold = getSpillThreshold + val inMemoryThreshold = getInMemoryThreshold + + val matches = ctx.addMutableState(clsName, "matches", + v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) + val matchedKeyVars = copyKeys(ctx, leftKeyVars) + + val lowerCompop = lowerSecondaryRangeExpression.map { + case GreaterThanOrEqual(_, _) => "<" + case GreaterThan(_, _) => "<=" + case _ => "" + }.getOrElse("") + val upperCompop = upperSecondaryRangeExpression.map { + case LessThanOrEqual(_, _) => ">" + case LessThan(_, _) => ">=" + case _ => "" + }.getOrElse("") + val lowerCompExp = if(lowerSecondaryRangeExpression.isEmpty) "" + else s" || (comp == 0 && ${leftLowerSecRangeKey.value} $lowerCompop ${rightLowerSecRangeKey.value})" + val upperCompExp = if(upperSecondaryRangeExpression.isEmpty) "" + else s" || (comp == 0 && ${leftUpperSecRangeKey.value} $upperCompop ${rightUpperSecRangeKey.value})" + + logDebug(s"lowerCompExp: $lowerCompExp") + logDebug(s"upperCompExp: $upperCompExp") + + // Add secondary range dequeue method + if(useSecondaryRange) { + if (lowerSecondaryRangeExpression.isEmpty || rightLowerKeys.size == 0 || rightUpperKeys.size == 0) { ctx.addNewFunction("dequeueUntilLowerConditionHolds", "private void dequeueUntilLowerConditionHolds() { }", inlineToOuterClass = true) @@ -567,7 +602,7 @@ case class SortMergeJoinExec( | return; | $rightTmpRow = $matches.get(0); | $javaType tempVal = getRightTmpRangeValue(); - | while(${leftLowerRangeKey.value} $upperCompop tempVal) { + | while(${leftLowerSecRangeKey.value} $upperCompop tempVal) { | $matches.dequeue(); | if($matches.isEmpty()) | break; @@ -577,166 +612,89 @@ case class SortMergeJoinExec( |} """.stripMargin, inlineToOuterClass = true) } - - val (leftLowAssignCode, rightLowAssignCode) = lowerRangeExpression.map(_ => - (s"${leftLowerRangeKey.value} = ${leftLowerKeyVars(0).value};", s"${rightLowerRangeKey.value} = ${rightLowerKeyVars(0).value};")). - getOrElse(("", "")) - val (leftUpperAssignCode, rightUpperAssignCode) = lowerRangeExpression.map(_ => - (s"${leftUpperRangeKey.value} = ${leftUpperKeyVars(0).value};", s"${rightUpperRangeKey.value} = ${rightUpperKeyVars(0).value};")). 
- getOrElse(("", "")) - - ctx.addNewFunction("findNextInnerJoinRows", - s""" - |private boolean findNextInnerJoinRows( - | scala.collection.Iterator leftIter, - | scala.collection.Iterator rightIter) { - | $leftRow = null; - | int comp = 0; - | while ($leftRow == null) { - | if (!leftIter.hasNext()) return false; - | $leftRow = (InternalRow) leftIter.next(); - | ${leftKeyVars.map(_.code).mkString("\n")} - | if ($leftAnyNull) { - | $leftRow = null; - | continue; - | } - | ${leftLowerKeyVars.map(_.code).mkString("\n")} - | ${leftUpperKeyVars.map(_.code).mkString("\n")} - | $leftLowAssignCode - | $leftUpperAssignCode - | if (!$matches.isEmpty()) { - | ${genComparison(ctx, leftKeyVars, matchedKeyVars)} - | if (comp == 0) { - | dequeueUntilLowerConditionHolds(); - | } - | else { - | $matches.clear(); - | } - | } - | - | do { - | if ($rightRow == null) { - | if (!rightIter.hasNext()) { - | ${matchedKeyVars.map(_.code).mkString("\n")} - | return !$matches.isEmpty(); - | } - | $rightRow = (InternalRow) rightIter.next(); - | ${rightKeyTmpVars.map(_.code).mkString("\n")} - | if ($rightAnyNull) { - | $rightRow = null; - | continue; - | } - | ${rightKeyVars.map(_.code).mkString("\n")} - | ${rightLowerKeyVars.map(_.code).mkString("\n")} - | ${rightUpperKeyVars.map(_.code).mkString("\n")} - | $rightLowAssignCode - | $rightUpperAssignCode - | } - | ${genComparison(ctx, leftKeyVars, rightKeyVars)} - | if (comp > 0 $upperCompExp) { - | $rightRow = null; - | } else if (comp < 0 $lowerCompExp) { - | if (!$matches.isEmpty()) { - | ${matchedKeyVars.map(_.code).mkString("\n")} - | return true; - | } - | $leftRow = null; - | } else { - | $matches.add((UnsafeRow) $rightRow); - | $rightRow = null; - | } - | } while ($leftRow != null); - | } - | return false; // unreachable - |} - """.stripMargin, inlineToOuterClass = true) - - (leftRow, matches) - } - else { - // Create class member for next row from both sides. - // Inline mutable state since not many join operations in a task - val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) - val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) - - // Create variables for join keys from both sides. - val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) - val leftAnyNull = leftKeyVars.map(_.isNull).mkString(" || ") - val rightKeyTmpVars = createJoinKey(ctx, rightRow, rightKeys, right.output) - val rightAnyNull = rightKeyTmpVars.map(_.isNull).mkString(" || ") - // Copy the right key as class members so they could be used in next function call. - val rightKeyVars = copyKeys(ctx, rightKeyTmpVars) - - // A list to hold all matched rows from right side. - val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName - - val spillThreshold = getSpillThreshold - val inMemoryThreshold = getInMemoryThreshold - - // Inline mutable state since not many join operations in a task - val matches = ctx.addMutableState(clsName, "matches", - v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) - // Copy the left keys as class members so they could be used in next function call. 
- val matchedKeyVars = copyKeys(ctx, leftKeyVars) - - ctx.addNewFunction("findNextInnerJoinRows", - s""" - |private boolean findNextInnerJoinRows( - | scala.collection.Iterator leftIter, - | scala.collection.Iterator rightIter) { - | $leftRow = null; - | int comp = 0; - | while ($leftRow == null) { - | if (!leftIter.hasNext()) return false; - | $leftRow = (InternalRow) leftIter.next(); - | ${leftKeyVars.map(_.code).mkString("\n")} - | if ($leftAnyNull) { - | $leftRow = null; - | continue; - | } - | if (!$matches.isEmpty()) { - | ${genComparison(ctx, leftKeyVars, matchedKeyVars)} - | if (comp == 0) { - | return true; - | } - | $matches.clear(); - | } - | - | do { - | if ($rightRow == null) { - | if (!rightIter.hasNext()) { - | ${matchedKeyVars.map(_.code).mkString("\n")} - | return !$matches.isEmpty(); - | } - | $rightRow = (InternalRow) rightIter.next(); - | ${rightKeyTmpVars.map(_.code).mkString("\n")} - | if ($rightAnyNull) { - | $rightRow = null; - | continue; - | } - | ${rightKeyVars.map(_.code).mkString("\n")} - | } - | ${genComparison(ctx, leftKeyVars, rightKeyVars)} - | if (comp > 0) { - | $rightRow = null; - | } else if (comp < 0) { - | if (!$matches.isEmpty()) { - | ${matchedKeyVars.map(_.code).mkString("\n")} - | return true; - | } - | $leftRow = null; - | } else { - | $matches.add((UnsafeRow) $rightRow); - | $rightRow = null;; - | } - | } while ($leftRow != null); - | } - | return false; // unreachable - |} - """.stripMargin, inlineToOuterClass = true) - - (leftRow, matches) } + val (leftLowVarsCode, leftUpperVarsCode) = if(useSecondaryRange) + (leftLowerKeyVars.map(_.code).mkString("\n"), leftUpperKeyVars.map(_.code).mkString("\n")) + else ("", "") + val (rightLowVarsCode, rightUpperVarsCode) = if(useSecondaryRange) + (rightLowerKeyVars.map(_.code).mkString("\n"), rightUpperKeyVars.map(_.code).mkString("\n")) + else ("", "") + val (leftLowAssignCode, rightLowAssignCode) = if(leftLowerKeyVars.size > 0) lowerSecondaryRangeExpression.map(_ => + (s"${leftLowerSecRangeKey.value} = ${leftLowerKeyVars(0).value};", s"${rightLowerSecRangeKey.value} = ${rightLowerKeyVars(0).value};")). + getOrElse(("", "")) + else ("", "") + val (leftUpperAssignCode, rightUpperAssignCode) = if(leftUpperKeyVars.size > 0) lowerSecondaryRangeExpression.map(_ => + (s"${leftUpperSecRangeKey.value} = ${leftUpperKeyVars(0).value};", s"${rightUpperSecRangeKey.value} = ${rightUpperKeyVars(0).value};")). 
+ getOrElse(("", "")) + else ("", "") + + ctx.addNewFunction("findNextInnerJoinRows", + s""" + |private boolean findNextInnerJoinRows( + | scala.collection.Iterator leftIter, + | scala.collection.Iterator rightIter) { + | $leftRow = null; + | int comp = 0; + | while ($leftRow == null) { + | if (!leftIter.hasNext()) return false; + | $leftRow = (InternalRow) leftIter.next(); + | ${leftKeyVars.map(_.code).mkString("\n")} + | if ($leftAnyNull) { + | $leftRow = null; + | continue; + | } + | $leftLowVarsCode + | $leftUpperVarsCode + | $leftLowAssignCode + | $leftUpperAssignCode + | if (!$matches.isEmpty()) { + | ${genComparison(ctx, leftKeyVars, matchedKeyVars)} + | if (comp == 0) { + | dequeueUntilLowerConditionHolds(); + | } + | else { + | $matches.clear(); + | } + | } + | + | do { + | if ($rightRow == null) { + | if (!rightIter.hasNext()) { + | ${matchedKeyVars.map(_.code).mkString("\n")} + | return !$matches.isEmpty(); + | } + | $rightRow = (InternalRow) rightIter.next(); + | ${rightKeyTmpVars.map(_.code).mkString("\n")} + | if ($rightAnyNull) { + | $rightRow = null; + | continue; + | } + | ${rightKeyVars.map(_.code).mkString("\n")} + | $rightLowVarsCode + | $rightUpperVarsCode + | $rightLowAssignCode + | $rightUpperAssignCode + | } + | ${genComparison(ctx, leftKeyVars, rightKeyVars)} + | if (comp > 0 $upperCompExp) { + | $rightRow = null; + | } else if (comp < 0 $lowerCompExp) { + | if (!$matches.isEmpty()) { + | ${matchedKeyVars.map(_.code).mkString("\n")} + | return true; + | } + | $leftRow = null; + | } else { + | $matches.add((UnsafeRow) $rightRow); + | $rightRow = null; + | } + | } while ($leftRow != null); + | } + | return false; // unreachable + |} + """.stripMargin, inlineToOuterClass = true) + + (leftRow, matches) } /** From 3ff654a60a4947ba860df94e7c2f8b00320298fd Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Tue, 10 Apr 2018 11:56:21 +0200 Subject: [PATCH 06/51] Bug fix --- .../apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index f62243449e20e..258861e5bd8cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -530,7 +530,7 @@ case class SortMergeJoinExec( createJoinKey(ctx, leftRow, leftUpperKeys, left.output), createJoinKey(ctx, rightRow, rightLowerKeys, right.output), createJoinKey(ctx, rightRow, rightUpperKeys, right.output)) - else (null, null, null, null) + else (Nil, Nil, Nil, Nil) val secRangeDataType = if(leftLowerKeys.size > 0) leftLowerKeys(0).dataType else if(leftUpperKeys.size > 0) leftUpperKeys(0).dataType From 85039bbbf5d99f3dcfaf1f338a0c84d63d9d6f3f Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Fri, 13 Apr 2018 08:45:11 +0200 Subject: [PATCH 07/51] Scalastyle fixes --- .../sql/catalyst/planning/patterns.scala | 68 ++++++++++++------- 1 file changed, 44 insertions(+), 24 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 18bad3546ec31..a2358900eb9b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala 
@@ -17,7 +17,8 @@ package org.apache.spark.sql.catalyst.planning -import org.apache.avro.hadoop.file.HadoopCodecFactory +import scala.collection.mutable + import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.catalog.BucketSpec @@ -26,8 +27,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import scala.collection.mutable - /** * A pattern that matches any number of project or filter operations on top of another relational * operator. All filter operators are collected and their conditions are broken up and returned @@ -136,36 +135,53 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { } if (joinKeys.nonEmpty) { - // Find any simple comparison expressions between two columns + // Find any simple range expressions between two columns // (and involving only those two columns) // of the two tables being joined, // which are not used in the equijoin expressions, // and which can be used for secondary sort optimizations. - val rangePreds:mutable.Set[Expression] = mutable.Set.empty - val rangeConditions:Seq[Expression] = otherPredicates.flatMap { + val rangePreds : mutable.Set[Expression] = mutable.Set.empty + var rangeConditions : Seq[BinaryComparison] = otherPredicates.flatMap { case p @ LessThan(l, r) => isValidRangeCondition(l, r, left, right, joinKeys) match { case "asis" => rangePreds.add(p); Some(LessThan(l, r)) case "vs" => rangePreds.add(p); Some(GreaterThan(r, l)) case _ => None } - case p @ LessThanOrEqual(l, r) => isValidRangeCondition(l, r, left, right, joinKeys) match { - case "asis" => rangePreds.add(p); Some(LessThanOrEqual(l, r)) - case "vs" => rangePreds.add(p); Some(GreaterThanOrEqual(r, l)) - case _ => None + case p @ LessThanOrEqual(l, r) => + isValidRangeCondition(l, r, left, right, joinKeys) match { + case "asis" => rangePreds.add(p); Some(LessThanOrEqual(l, r)) + case "vs" => rangePreds.add(p); Some(GreaterThanOrEqual(r, l)) + case _ => None } case p @ GreaterThan(l, r) => isValidRangeCondition(l, r, left, right, joinKeys) match { case "asis" => rangePreds.add(p); Some(GreaterThan(l, r)) case "vs" => rangePreds.add(p); Some(LessThan(r, l)) case _ => None } - case p @ GreaterThanOrEqual(l, r) => isValidRangeCondition(l, r, left, right, joinKeys) match { - case "asis" => rangePreds.add(p); Some(GreaterThanOrEqual(l, r)) - case "vs" => rangePreds.add(p); Some(LessThanOrEqual(r, l)) - case _ => None + case p @ GreaterThanOrEqual(l, r) => + isValidRangeCondition(l, r, left, right, joinKeys) match { + case "asis" => rangePreds.add(p); Some(GreaterThanOrEqual(l, r)) + case "vs" => rangePreds.add(p); Some(LessThanOrEqual(r, l)) + case _ => None } } val (leftKeys, rightKeys) = joinKeys.unzip + // Only using secondary join optimization when both lower and upper conditions + // are specified (e.g. 
t1.a < t2.b + x and t1.a > t2.b - x) + if(rangeConditions.size != 2 || + // Looking for one < and one > comparison: + rangeConditions.filter(x => x.isInstanceOf[LessThan] || + x.isInstanceOf[LessThanOrEqual]).size == 0 || + rangeConditions.filter(x => x.isInstanceOf[GreaterThan] || + x.isInstanceOf[GreaterThanOrEqual]).size == 0 || + // Both comparisons reference the same columns: + rangeConditions.map(c => c.left).distinct.size != 1 || + rangeConditions.map(c => c.right).distinct.size != 1) { + rangeConditions = Nil + rangePreds.clear() + } + Some((joinType, leftKeys, rightKeys, rangeConditions, otherPredicates.filterNot(rangePreds.contains(_)).reduceOption(And), left, right)) } else { @@ -174,24 +190,28 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { case _ => None } - private def isValidRangeCondition(l:Expression, r:Expression, left:LogicalPlan, right:LogicalPlan, - joinKeys:Seq[(Expression, Expression)]) = { + private def isValidRangeCondition(l : Expression, r : Expression, + left : LogicalPlan, right : LogicalPlan, + joinKeys : Seq[(Expression, Expression)]) = { val (lattrs, rattrs) = (l.references.toSeq, r.references.toSeq) - if(lattrs.size != 1 || rattrs.size != 1) + if(lattrs.size != 1 || rattrs.size != 1) { "none" - else if (canEvaluate(l, left) && canEvaluate(r, right)) { - val equiset = joinKeys.filter{ case (ljk:Expression, rjk:Expression) => - ljk.references.toSeq.contains(lattrs(0)) && rjk.references.toSeq.contains(rattrs(0))} - if(equiset.isEmpty) + } + else if(canEvaluate(l, left) && canEvaluate(r, right)) { + val equiset = joinKeys.filter{ case (ljk : Expression, rjk : Expression) => + ljk.references.toSeq.contains(lattrs(0)) && rjk.references.toSeq.contains(rattrs(0)) } + if(equiset.isEmpty) { "asis" + } else "none" } else if (canEvaluate(l, right) && canEvaluate(r, left)) { - val equiset = joinKeys.filter{ case (ljk:Expression, rjk:Expression) => - rjk.references.toSeq.contains(lattrs(0)) && ljk.references.toSeq.contains(rattrs(0))} - if(equiset.isEmpty) + val equiset = joinKeys.filter{ case (ljk : Expression, rjk : Expression) => + rjk.references.toSeq.contains(lattrs(0)) && ljk.references.toSeq.contains(rattrs(0)) } + if(equiset.isEmpty) { "vs" + } else "none" } From bbcb400a1bb21b597226b47ab5302bca52f18ba7 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Fri, 13 Apr 2018 08:45:43 +0200 Subject: [PATCH 08/51] Scalastyle fixes --- .../scala/org/apache/spark/sql/execution/SparkStrategies.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index cdfe3308c2096..79e00b3d3baf9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -638,8 +638,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.RangeExec(r) :: Nil case r: logical.RepartitionByExpression => exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child)) :: Nil - case r: logical.FixedRangeRepartitionByExpression => - exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child)) :: Nil case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil From 640aa6d7d075b24f1219d20c0617869d901f3f4f Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: 
Fri, 13 Apr 2018 08:46:28 +0200 Subject: [PATCH 09/51] SMJ range join unit tests --- .../sql/execution/joins/InnerJoinSuite.scala | 62 ++++++++++++++++++- .../sql/execution/joins/OuterJoinSuite.scala | 3 + 2 files changed, 64 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index b94553da10cab..782c87104b7d1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -70,6 +70,19 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { (3, 2) ).toDF("a", "b") + private lazy val rangeTestData1 = Seq( + (1, 1), (1, 2), (1, 3), + (2, 1), (2, 2), (2, 3), + (3, 1), (3, 2), (3, 3), + (4, 1), (4, 2), (4, 3) + ).toDF("a", "b") + + private lazy val rangeTestData2 = Seq( + (1, 1), (1, 2), (1, 3), + (2, 1), (2, 2), (2, 3), + (3, 1), (3, 2), (3, 3) + ).toDF("a", "b") + // Note: the input dataframes and expression must be evaluated lazily because // the SQLContext should be used only within a test to keep SQL tests stable private def testInnerJoin( @@ -77,7 +90,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { leftRows: => DataFrame, rightRows: => DataFrame, condition: () => Expression, - expectedAnswer: Seq[Product]): Unit = { + expectedAnswer: Seq[Product], + expectRangeJoin: Boolean = false): Unit = { def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition())) @@ -129,6 +143,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { test(s"$testName using BroadcastHashJoin (build=left)") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => + assert(!expectRangeJoin && rangeConditions.isEmpty || + expectRangeJoin && rangeConditions.size == 2) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeBroadcastHashJoin( @@ -141,6 +157,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { test(s"$testName using BroadcastHashJoin (build=right)") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => + assert(!expectRangeJoin && rangeConditions.isEmpty || + expectRangeJoin && rangeConditions.size == 2) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeBroadcastHashJoin( @@ -153,6 +171,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { test(s"$testName using ShuffledHashJoin (build=left)") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => + assert(!expectRangeJoin && rangeConditions.isEmpty || + expectRangeJoin && rangeConditions.size == 2) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeShuffledHashJoin( @@ -165,6 +185,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { test(s"$testName using ShuffledHashJoin (build=right)") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => + assert(!expectRangeJoin && rangeConditions.isEmpty || + expectRangeJoin && rangeConditions.size == 2) 
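+        // the extractor should produce exactly two range conditions (a lower and an
+        // upper bound) when a range join is expected, and none otherwise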
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
          checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
            makeShuffledHashJoin(
@@ -177,6 +199,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
     test(s"$testName using SortMergeJoin") {
       extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) =>
+        assert(!expectRangeJoin && rangeConditions.isEmpty ||
+          expectRangeJoin && rangeConditions.size == 2)
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
             makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan),
@@ -272,6 +296,42 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
     )
   }
 
+  {
+    lazy val left = rangeTestData1
+    lazy val right = rangeTestData2
+    testInnerJoin(
+      "inner range join",
+      left,
+      right,
+      () => ((left("a") === right("a")) and (left("b") <= right("b")+1)
+        and (left("b") >= right("b")-1)).expr,
+      Seq(
+        (1, 1, 1, 1),
+        (1, 1, 1, 2),
+        (1, 2, 1, 1),
+        (1, 2, 1, 2),
+        (1, 2, 1, 3),
+        (1, 3, 1, 2),
+        (1, 3, 1, 3),
+        (2, 1, 2, 1),
+        (2, 1, 2, 2),
+        (2, 2, 2, 1),
+        (2, 2, 2, 2),
+        (2, 2, 2, 3),
+        (2, 3, 2, 2),
+        (2, 3, 2, 3),
+        (3, 1, 3, 1),
+        (3, 1, 3, 2),
+        (3, 2, 3, 1),
+        (3, 2, 3, 2),
+        (3, 2, 3, 3),
+        (3, 3, 3, 2),
+        (3, 3, 3, 3)
+      ),
+      true
+    )
+  }
+
   {
     def df: DataFrame = spark.range(3).selectExpr("struct(id, id) as key", "id as value")
     lazy val left = df.selectExpr("key", "concat('L', value) as value").alias("left")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index 417b356f4e07c..257c58850be30 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -79,6 +79,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
     if (joinType != FullOuter) {
       test(s"$testName using ShuffledHashJoin") {
         extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) =>
+          assert(rangeConditions.isEmpty)
           withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
             val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft
             checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
@@ -100,6 +101,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
           case _ => fail(s"Unsupported join type $joinType")
         }
         extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) =>
+          assert(rangeConditions.isEmpty)
           withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
             checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
               BroadcastHashJoinExec(
@@ -113,6 +115,7 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
     test(s"$testName using SortMergeJoin") {
       extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) =>
+        assert(rangeConditions.isEmpty)
         withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
           checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
             EnsureRequirements(spark.sessionState.conf).apply(

From 2548b7d826c728f26263bfa3432e75dc19a38214 Mon Sep 17 00:00:00 2001
From: Petar Zecevic
Date: Fri, 13 Apr 2018 17:32:13 +0200
Subject: [PATCH 10/51] Scalastyle

---
 .../sql/catalyst/planning/patterns.scala      | 13 +-
.../execution/InMemoryUnsafeRowQueue.scala | 82 +---------- .../spark/sql/execution/SparkStrategies.scala | 11 +- .../exchange/EnsureRequirements.scala | 6 +- .../execution/joins/SortMergeJoinExec.scala | 128 +++++++++++------- .../spark/sql/execution/PlannerSuite.scala | 18 ++- .../sql/execution/joins/InnerJoinSuite.scala | 21 +-- 7 files changed, 125 insertions(+), 154 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index a2358900eb9b9..d1781752ff5ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -197,14 +197,15 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { if(lattrs.size != 1 || rattrs.size != 1) { "none" } - else if(canEvaluate(l, left) && canEvaluate(r, right)) { + else if (canEvaluate(l, left) && canEvaluate(r, right)) { val equiset = joinKeys.filter{ case (ljk : Expression, rjk : Expression) => ljk.references.toSeq.contains(lattrs(0)) && rjk.references.toSeq.contains(rattrs(0)) } - if(equiset.isEmpty) { + if (equiset.isEmpty) { "asis" } - else + else { "none" + } } else if (canEvaluate(l, right) && canEvaluate(r, left)) { val equiset = joinKeys.filter{ case (ljk : Expression, rjk : Expression) => @@ -212,11 +213,13 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { if(equiset.isEmpty) { "vs" } - else + else { "none" + } } - else + else { "none" + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala index 1112326d8d3da..eef3d6766a02f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala @@ -19,18 +19,16 @@ package org.apache.spark.sql.execution import java.util.ConcurrentModificationException -import org.apache.spark.internal.Logging +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer.SerializerManager import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer import org.apache.spark.storage.BlockManager -import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} import org.apache.spark.{SparkEnv, TaskContext} -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - /** * An append-only array for [[UnsafeRow]]s that strictly keeps content in an in-memory array * until [[numRowsInMemoryBufferThreshold]] is reached post which it will switch to a mode which @@ -101,12 +99,6 @@ private[sql] class InMemoryUnsafeRowQueue( * Clears up resources (eg. 
memory) held by the backing storage */ override def clear(): Unit = { - /*if (spillableArray != null) { - // The last `spillableArray` of this task will be cleaned up via task completion listener - // inside `UnsafeExternalSorter` - spillableArray.cleanupResources() - spillableArray = null - } else*/ if (inMemoryQueue != null) { inMemoryQueue.clear() } @@ -116,15 +108,16 @@ private[sql] class InMemoryUnsafeRowQueue( } def dequeue(): Option[UnsafeRow] = { - if(numRows == 0) + if (numRows == 0) { None + } else { numRows -= 1 Some(inMemoryQueue.dequeue()) } } - def get(idx:Int): UnsafeRow = { + def get(idx: Int): UnsafeRow = { inMemoryQueue(idx) } @@ -133,44 +126,6 @@ private[sql] class InMemoryUnsafeRowQueue( inMemoryQueue += unsafeRow.copy() } else { throw new RuntimeException(s"Reached spill threshold of $numRowsInMemoryBufferThreshold rows") - /*if (spillableArray == null) { - logInfo(s"Reached spill threshold of $numRowsInMemoryBufferThreshold rows, switching to " + - s"${classOf[UnsafeExternalSorter].getName}") - - // We will not sort the rows, so prefixComparator and recordComparator are null - spillableArray = UnsafeExternalSorter.create( - taskMemoryManager, - blockManager, - serializerManager, - taskContext, - null, - null, - initialSize, - pageSizeBytes, - numRowsSpillThreshold, - false) - - // populate with existing in-memory buffered rows - if (inMemoryBuffer != null) { - inMemoryBuffer.foreach(existingUnsafeRow => - spillableArray.insertRecord( - existingUnsafeRow.getBaseObject, - existingUnsafeRow.getBaseOffset, - existingUnsafeRow.getSizeInBytes, - 0, - false) - ) - inMemoryBuffer.clear() - } - numFieldsPerRow = unsafeRow.numFields() - } - - spillableArray.insertRecord( - unsafeRow.getBaseObject, - unsafeRow.getBaseOffset, - unsafeRow.getSizeInBytes, - 0, - false)*/ } numRows += 1 @@ -191,15 +146,9 @@ private[sql] class InMemoryUnsafeRowQueue( s"Total elements: $numRows, requested `startIndex`: $startIndex") } - //if (spillableArray == null) { - new InMemoryBufferIterator(startIndex) - /*} else { - new SpillableArrayIterator(spillableArray.getIterator(startIndex), numFieldsPerRow) - }*/ + new InMemoryBufferIterator(startIndex) } -// override def generateIterator(): Iterator[UnsafeRow] = generateIterator(startIndex = 0) - private[this] abstract class ExternalAppendOnlyUnsafeRowArrayIterator extends Iterator[UnsafeRow] { private val expectedModificationsCount = modificationsCount @@ -229,23 +178,6 @@ private[sql] class InMemoryUnsafeRowQueue( result } } - - /*private[this] class SpillableArrayIterator( - iterator: UnsafeSorterIterator, - numFieldPerRow: Int) - extends ExternalAppendOnlyUnsafeRowArrayIterator { - - private val currentRow = new UnsafeRow(numFieldPerRow) - - override def hasNext(): Boolean = !isModified() && iterator.hasNext - - override def next(): UnsafeRow = { - throwExceptionIfModified() - iterator.loadNext() - currentRow.pointTo(iterator.getBaseObject, iterator.getBaseOffset, iterator.getRecordLength) - currentRow - } - }*/ } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 79e00b3d3baf9..7db6992a58edb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -278,12 +278,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case ExtractEquiJoinKeys(joinType, leftKeys, rightKeys, rangeConds, condition, 
left, right) if RowOrdering.isOrderable(leftKeys) => -// val cond = (rangeConds ++ condition.map(x => Seq(x)).getOrElse(Nil)).reduceOption(And) - logDebug(s"SMJExecJoinType: $joinType") - //TODO check if left and right are reading from bucketed tables and those are bucketed by join key - // and sorted by keys in range conditions - joins.SortMergeJoinExec( - leftKeys, rightKeys, joinType, rangeConds, condition, planLater(left), planLater(right)) :: Nil + joins.SortMergeJoinExec(leftKeys, rightKeys, joinType, rangeConds, condition, + planLater(left), planLater(right)) :: Nil // --- Without joining keys ------------------------------------------------------------ @@ -392,7 +388,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { if left.isStreaming && right.isStreaming => new StreamingSymmetricHashJoinExec( - leftKeys, rightKeys, joinType, (rangePreds ++ condition.map(x => Seq(x)).getOrElse(Nil)).reduceOption(And), + leftKeys, rightKeys, joinType, + (rangePreds ++ condition.map(x => Seq(x)).getOrElse(Nil)).reduceOption(And), planLater(left), planLater(right)) :: Nil case Join(left, right, _, _) if left.isStreaming && right.isStreaming => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 94c83a87c0a1f..10cd5b2bbf485 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -284,10 +284,12 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { ShuffledHashJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, buildSide, condition, left, right) - case SortMergeJoinExec(leftKeys, rightKeys, joinType, rangeConditions, condition, left, right) => + case SortMergeJoinExec(leftKeys, rightKeys, joinType, rangeConditions, + condition, left, right) => val (reorderedLeftKeys, reorderedRightKeys) = reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) - SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, rangeConditions, condition, left, right) + SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, rangeConditions, + condition, left, right) case other => other } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 258861e5bd8cc..ef9a51ec7c20f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -42,7 +42,8 @@ case class SortMergeJoinExec( left: SparkPlan, right: SparkPlan) extends BinaryExecNode with PredicateHelper with CodegenSupport { - logDebug(s"SortMergeJoinExec args: leftKeys: $leftKeys, rightKeys: $rightKeys, joinType: $joinType," + + logDebug(s"SortMergeJoinExec args: leftKeys: $leftKeys, rightKeys: $rightKeys, " + + s"joinType: $joinType," + s" rangeConditions: $rangeConditions, " + s"condition: $condition, left: $left, right: $right") @@ -53,13 +54,15 @@ case class SortMergeJoinExec( private val lowerSecondaryRangeExpression : Option[Expression] = { logDebug(s"Finding secondary greaterThan expressions in $rangeConditions") - val thefind = rangeConditions.find(p => p.isInstanceOf[GreaterThan] || p.isInstanceOf[GreaterThanOrEqual]) + val 
thefind = rangeConditions.find(p => + p.isInstanceOf[GreaterThan] || p.isInstanceOf[GreaterThanOrEqual]) logDebug(s"Found secondary greaterThan expression: $thefind") thefind } private val upperSecondaryRangeExpression : Option[Expression] = { logDebug(s"Finding secondary lowerThan expressions in $rangeConditions") - val thefind = rangeConditions.find(p => p.isInstanceOf[LessThan] || p.isInstanceOf[LessThanOrEqual]) + val thefind = rangeConditions.find(p => + p.isInstanceOf[LessThan] || p.isInstanceOf[LessThanOrEqual]) logDebug(s"Found secondary lowerThan expression: $thefind") thefind } @@ -220,30 +223,32 @@ case class SortMergeJoinExec( private[this] var currentLeftRow: InternalRow = _ private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _ private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null - private[this] val smjScanner = if(lowerSecondaryRangeExpression.isDefined || upperSecondaryRangeExpression.isDefined) { - new SortMergeJoinInnerRangeScanner( - createLeftKeyGenerator(), - createRightKeyGenerator(), - keyOrdering, - RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter), - inMemoryThreshold, - spillThreshold, - lowerRangeCondition, - upperRangeCondition - ) - } - else { - new SortMergeJoinScanner( - createLeftKeyGenerator(), - createRightKeyGenerator(), - keyOrdering, - RowIterator.fromScala(leftIter), - RowIterator.fromScala(rightIter), - inMemoryThreshold, - spillThreshold - ) - } + private[this] val smjScanner = + if(lowerSecondaryRangeExpression.isDefined || + upperSecondaryRangeExpression.isDefined) { + new SortMergeJoinInnerRangeScanner( + createLeftKeyGenerator(), + createRightKeyGenerator(), + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter), + inMemoryThreshold, + spillThreshold, + lowerRangeCondition, + upperRangeCondition + ) + } + else { + new SortMergeJoinScanner( + createLeftKeyGenerator(), + createRightKeyGenerator(), + keyOrdering, + RowIterator.fromScala(leftIter), + RowIterator.fromScala(rightIter), + inMemoryThreshold, + spillThreshold + ) + } private[this] val joinRow = new JoinedRow if (smjScanner.findNextInnerJoinRows()) { @@ -500,7 +505,8 @@ case class SortMergeJoinExec( // Inline mutable state since not many join operations in a task val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) - val rightTmpRow = if(useSecondaryRange) ctx.addMutableState("InternalRow", "rightTmpRow", forceInline = true) + val rightTmpRow = if (useSecondaryRange) + ctx.addMutableState("InternalRow", "rightTmpRow", forceInline = true) else "" // Create variables for join keys from both sides. 
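This codegen path only fires for "band join" shapes, where the residual join condition carries one lower and one upper bound on a secondary column. A minimal sketch of a query that should take this path — the table and column names are hypothetical, mirroring the inner range join test added earlier in the series:

    // t1, t2: DataFrames with columns a (equi-join key) and b (range key)
    val banded = t1.join(t2,
      t1("a") === t2("a")           // primary sort: the equi-join key
        && t1("b") >= t2("b") - 1   // becomes lowerSecondaryRangeExpression
        && t1("b") <= t2("b") + 1)  // becomes upperSecondaryRangeExpression

ExtractEquiJoinKeys routes the two comparisons on b into rangeConditions, and the generated loop then keeps a sliding window of right-side rows ordered by b: rows that fall below the new lower bound are dequeued at the front, rows that still satisfy the upper bound are appended at the back, so each buffered row is visited roughly once per key group instead of once per streamed row.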
@@ -525,25 +531,31 @@ case class SortMergeJoinExec( // Variables for secondary range expressions val (leftLowerKeyVars, leftUpperKeyVars, rightLowerKeyVars, rightUpperKeyVars) = - if(useSecondaryRange) + if (useSecondaryRange) { (createJoinKey(ctx, leftRow, leftLowerKeys, left.output), createJoinKey(ctx, leftRow, leftUpperKeys, left.output), createJoinKey(ctx, rightRow, rightLowerKeys, right.output), createJoinKey(ctx, rightRow, rightUpperKeys, right.output)) - else (Nil, Nil, Nil, Nil) + } + else { + (Nil, Nil, Nil, Nil) + } - val secRangeDataType = if(leftLowerKeys.size > 0) leftLowerKeys(0).dataType - else if(leftUpperKeys.size > 0) leftUpperKeys(0).dataType + val secRangeDataType = if(leftLowerKeys.size > 0) { leftLowerKeys(0).dataType } + else if (leftUpperKeys.size > 0) { leftUpperKeys(0).dataType } else null val secRangeInitValue = CodeGenerator.defaultValue(secRangeDataType) val (leftLowerSecRangeKey, leftUpperSecRangeKey, rightLowerSecRangeKey, rightUpperSecRangeKey) = - if(useSecondaryRange) + if (useSecondaryRange) { (ctx.addBufferedState(secRangeDataType, "leftLowerSecRangeKey", secRangeInitValue), ctx.addBufferedState(secRangeDataType, "leftUpperSecRangeKey", secRangeInitValue), ctx.addBufferedState(secRangeDataType, "rightLowerSecRangeKey", secRangeInitValue), ctx.addBufferedState(secRangeDataType, "rightUpperSecRangeKey", secRangeInitValue)) - else (null, null, null, null) + } + else { + (null, null, null, null) + } // A queue to hold all matched rows from right side. val clsName = if(useSecondaryRange) classOf[InMemoryUnsafeRowQueue].getName @@ -566,23 +578,27 @@ case class SortMergeJoinExec( case LessThan(_, _) => ">=" case _ => "" }.getOrElse("") - val lowerCompExp = if(lowerSecondaryRangeExpression.isEmpty) "" - else s" || (comp == 0 && ${leftLowerSecRangeKey.value} $lowerCompop ${rightLowerSecRangeKey.value})" - val upperCompExp = if(upperSecondaryRangeExpression.isEmpty) "" - else s" || (comp == 0 && ${leftUpperSecRangeKey.value} $upperCompop ${rightUpperSecRangeKey.value})" + val lowerCompExp = if (lowerSecondaryRangeExpression.isEmpty) "" + else s" || (comp == 0 && ${leftLowerSecRangeKey.value} " + + s"$lowerCompop ${rightLowerSecRangeKey.value})" + val upperCompExp = if (upperSecondaryRangeExpression.isEmpty) "" + else s" || (comp == 0 && ${leftUpperSecRangeKey.value} " + + s"$upperCompop ${rightUpperSecRangeKey.value})" logDebug(s"lowerCompExp: $lowerCompExp") logDebug(s"upperCompExp: $upperCompExp") // Add secondary range dequeue method if(useSecondaryRange) { - if (lowerSecondaryRangeExpression.isEmpty || rightLowerKeys.size == 0 || rightUpperKeys.size == 0) { + if (lowerSecondaryRangeExpression.isEmpty || rightLowerKeys.size == 0 || + rightUpperKeys.size == 0) { ctx.addNewFunction("dequeueUntilLowerConditionHolds", "private void dequeueUntilLowerConditionHolds() { }", inlineToOuterClass = true) } else { - val rightRngTmpKeyVars = createJoinKey(ctx, rightTmpRow, rightUpperKeys.slice(0, 1), right.output) + val rightRngTmpKeyVars = createJoinKey(ctx, rightTmpRow, + rightUpperKeys.slice(0, 1), right.output) val rightRngTmpKeyVarsDecl = rightRngTmpKeyVars.map(_.code).mkString("\n") rightRngTmpKeyVars.foreach(_.code = "") val javaType = CodeGenerator.javaType(rightLowerKeys(0).dataType) @@ -619,12 +635,16 @@ case class SortMergeJoinExec( val (rightLowVarsCode, rightUpperVarsCode) = if(useSecondaryRange) (rightLowerKeyVars.map(_.code).mkString("\n"), rightUpperKeyVars.map(_.code).mkString("\n")) else ("", "") - val (leftLowAssignCode, rightLowAssignCode) = 
if(leftLowerKeyVars.size > 0) lowerSecondaryRangeExpression.map(_ => - (s"${leftLowerSecRangeKey.value} = ${leftLowerKeyVars(0).value};", s"${rightLowerSecRangeKey.value} = ${rightLowerKeyVars(0).value};")). + val (leftLowAssignCode, rightLowAssignCode) = if(leftLowerKeyVars.size > 0) + lowerSecondaryRangeExpression.map(_ => + (s"${leftLowerSecRangeKey.value} = ${leftLowerKeyVars(0).value};", + s"${rightLowerSecRangeKey.value} = ${rightLowerKeyVars(0).value};")). getOrElse(("", "")) else ("", "") - val (leftUpperAssignCode, rightUpperAssignCode) = if(leftUpperKeyVars.size > 0) lowerSecondaryRangeExpression.map(_ => - (s"${leftUpperSecRangeKey.value} = ${leftUpperKeyVars(0).value};", s"${rightUpperSecRangeKey.value} = ${rightUpperKeyVars(0).value};")). + val (leftUpperAssignCode, rightUpperAssignCode) = if(leftUpperKeyVars.size > 0) + lowerSecondaryRangeExpression.map(_ => + (s"${leftUpperSecRangeKey.value} = ${leftUpperKeyVars(0).value};", + s"${rightUpperSecRangeKey.value} = ${rightUpperKeyVars(0).value};")). getOrElse(("", "")) else ("", "") @@ -1080,7 +1100,7 @@ private[joins] class SortMergeJoinInnerRangeScanner( /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ private[this] val bufferedMatches = - new InMemoryUnsafeRowQueue(inMemoryThreshold, spillThreshold)//ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + new InMemoryUnsafeRowQueue(inMemoryThreshold, spillThreshold) private[this] val joinRow = new JoinedRow // Initialization (note: do _not_ want to advance streamed here). @@ -1109,10 +1129,12 @@ private[joins] class SortMergeJoinInnerRangeScanner( bufferedMatches.clear() false } else if (matchJoinKey != null && keyOrdering.compare(streamedRowKey, matchJoinKey) == 0) { - // The new streamed row has the same join key as the previous row, so the same matches can be used. + // The new streamed row has the same join key as the previous row, + // so the same matches can be used. // But lower and upper ranges might not hold anymore, so check them: // First dequeue all rows from the queue until the lower range condition holds. - // Then try to enqueue new rows with the same join key and for which the upper range condition holds. + // Then try to enqueue new rows with the same join key and for which the upper + // range condition holds. dequeueUntilLowerConditionHolds() bufferMatchingRows(true) true @@ -1190,7 +1212,8 @@ private[joins] class SortMergeJoinInnerRangeScanner( } /** - * Advance the buffered iterator as long as the join key is the same and the lower range condition is not satisfied. + * Advance the buffered iterator as long as the join key is the same and + * the lower range condition is not satisfied. * Skip rows with nulls. * @return Result of the join key comparison. 
*/ @@ -1202,8 +1225,9 @@ private[joins] class SortMergeJoinInnerRangeScanner( if(!lowCheck) while(!lowCheck && comp == 0 && advanceBufferedToRowWithNullFreeJoinKey()) { comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) - if(comp == 0) + if(comp == 0) { lowCheck = lowerRangeCondition(joinRow(streamedRow, bufferedRow)) + } } comp } @@ -1219,8 +1243,9 @@ private[joins] class SortMergeJoinInnerRangeScanner( assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) // This join key may have been produced by a mutable projection, so we need to make a copy: matchJoinKey = streamedRowKey.copy() - if(clear) + if (clear) { bufferedMatches.clear() + } var upperRangeOk = false var lowerRangeOk = false do { @@ -1230,7 +1255,8 @@ private[joins] class SortMergeJoinInnerRangeScanner( if(lowerRangeOk && upperRangeOk) bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow]) advanceBufferedToRowWithNullFreeJoinKey() - } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0 && upperRangeOk) + } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0 + && upperRangeOk) } private def dequeueUntilLowerConditionHolds(): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index e6447fa7c4a93..dfde1031fbd79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -544,7 +544,8 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements skips sort when either side of join keys is required after inner SMJ") { Seq(Inner, Cross).foreach { joinType => - val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, Nil, None, planA, planB) + val innerSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, Nil, + None, planA, planB) // Both left and right keys should be sorted after the SMJ. Seq(orderingA, orderingB).foreach { ordering => assertSortRequirementsAreSatisfied( @@ -558,8 +559,10 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements skips sort when key order of a parent SMJ is propagated from its " + "child SMJ") { Seq(Inner, Cross).foreach { joinType => - val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, Nil, None, planA, planB) - val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, joinType, Nil, None, childSmj, planC) + val childSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, joinType, Nil, + None, planA, planB) + val parentSmj = SortMergeJoinExec(exprB :: Nil, exprC :: Nil, joinType, Nil, + None, childSmj, planC) // After the second SMJ, exprA, exprB and exprC should all be sorted. Seq(orderingA, orderingB, orderingC).foreach { ordering => assertSortRequirementsAreSatisfied( @@ -572,7 +575,8 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements for sort operator after left outer sort merge join") { // Only left key is sorted after left outer SMJ (thus doesn't need a sort). 
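Every SortMergeJoinExec constructed in these planner tests threads the new rangeConditions argument as Nil; a sketch of the calling convention, using this suite's own fixtures (exprA, exprB, planA, planB):

    val smj = SortMergeJoinExec(
      leftKeys = exprA :: Nil,
      rightKeys = exprB :: Nil,
      joinType = Inner,
      rangeConditions = Nil,  // no range predicates: classic sort-merge path
      condition = None,
      left = planA,
      right = planB)

With rangeConditions empty, neither secondary range expression is defined, shouldUseSecondaryRangeJoin() returns false, and the sorting requirements exercised below are exactly those of the unpatched operator.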
- val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, Nil, None, planA, planB) + val leftSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, LeftOuter, Nil, + None, planA, planB) Seq((orderingA, false), (orderingB, true)).foreach { case (ordering, needSort) => assertSortRequirementsAreSatisfied( childPlan = leftSmj, @@ -583,7 +587,8 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements for sort operator after right outer sort merge join") { // Only right key is sorted after right outer SMJ (thus doesn't need a sort). - val rightSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, RightOuter, Nil, None, planA, planB) + val rightSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, RightOuter, Nil, + None, planA, planB) Seq((orderingA, true), (orderingB, false)).foreach { case (ordering, needSort) => assertSortRequirementsAreSatisfied( childPlan = rightSmj, @@ -594,7 +599,8 @@ class PlannerSuite extends SharedSQLContext { test("EnsureRequirements adds sort after full outer sort merge join") { // Neither keys is sorted after full outer SMJ, so they both need sorts. - val fullSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, FullOuter, Nil, None, planA, planB) + val fullSmj = SortMergeJoinExec(exprA :: Nil, exprB :: Nil, FullOuter, Nil, + None, planA, planB) Seq(orderingA, orderingB).foreach { ordering => assertSortRequirementsAreSatisfied( childPlan = fullSmj, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 782c87104b7d1..75a815e79deab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -136,13 +136,14 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { boundCondition: Option[Expression], leftPlan: SparkPlan, rightPlan: SparkPlan) = { - val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, Nil, boundCondition, - leftPlan, rightPlan) + val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, Nil, + boundCondition, leftPlan, rightPlan) EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin) } test(s"$testName using BroadcastHashJoin (build=left)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, + boundCondition, _, _) => assert(!expectRangeJoin && rangeConditions.isEmpty || expectRangeJoin && rangeConditions.size == 2) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -156,7 +157,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using BroadcastHashJoin (build=right)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, + boundCondition, _, _) => assert(!expectRangeJoin && rangeConditions.isEmpty || expectRangeJoin && rangeConditions.size == 2) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -170,7 +172,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using ShuffledHashJoin (build=left)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, + 
boundCondition, _, _) => assert(!expectRangeJoin && rangeConditions.isEmpty || expectRangeJoin && rangeConditions.size == 2) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -184,7 +187,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using ShuffledHashJoin (build=right)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, + boundCondition, _, _) => assert(!expectRangeJoin && rangeConditions.isEmpty || expectRangeJoin && rangeConditions.size == 2) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -198,7 +202,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using SortMergeJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, + boundCondition, _, _) => assert(!expectRangeJoin && rangeConditions.isEmpty || expectRangeJoin && rangeConditions.size == 2) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { @@ -304,7 +309,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { left, right, () => ((left("a") === right("a")) and (left("b") <= right("b")-1) - and (left("b") >= right("b")+1)).expr, + and (left("b") >= right("b") + 1)).expr, Seq( (1, 1, 1, 1), (1, 1, 1, 2), From 069bc019e91884deb7681d63741229d170a22bfc Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Fri, 13 Apr 2018 20:32:16 +0200 Subject: [PATCH 11/51] Scalastyle --- .../execution/InMemoryUnsafeRowQueue.scala | 4 +- .../execution/joins/SortMergeJoinExec.scala | 148 +++++++++--------- .../execution/joins/ExistenceJoinSuite.scala | 9 +- .../sql/execution/joins/OuterJoinSuite.scala | 9 +- 4 files changed, 92 insertions(+), 78 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala index eef3d6766a02f..d91a1f328303b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala @@ -24,10 +24,10 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer.SerializerManager -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer import org.apache.spark.storage.BlockManager import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer /** * An append-only array for [[UnsafeRow]]s that strictly keeps content in an in-memory array diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index ef9a51ec7c20f..7203752f0b15b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.joins import scala.collection.mutable.ArrayBuffer + import org.apache.spark.rdd.RDD import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -224,7 +225,7 @@ case class SortMergeJoinExec( private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _ private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null private[this] val smjScanner = - if(lowerSecondaryRangeExpression.isDefined || + if (lowerSecondaryRangeExpression.isDefined || upperSecondaryRangeExpression.isDefined) { new SortMergeJoinInnerRangeScanner( createLeftKeyGenerator(), @@ -490,8 +491,8 @@ case class SortMergeJoinExec( """.stripMargin } - private def shouldUseSecondaryRangeJoin():Boolean = { - //TODO check sorting of the two relations? - Check it during planning? + private def shouldUseSecondaryRangeJoin(): Boolean = { + // TODO check sorting of the two relations? - Check it during planning? lowerSecondaryRangeExpression.isDefined || upperSecondaryRangeExpression.isDefined } @@ -505,9 +506,10 @@ case class SortMergeJoinExec( // Inline mutable state since not many join operations in a task val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) - val rightTmpRow = if (useSecondaryRange) + val rightTmpRow = if (useSecondaryRange) { ctx.addMutableState("InternalRow", "rightTmpRow", forceInline = true) - else "" + } + else { "" } // Create variables for join keys from both sides. val leftKeyVars = createJoinKey(ctx, leftRow, leftKeys, left.output) @@ -541,7 +543,7 @@ case class SortMergeJoinExec( (Nil, Nil, Nil, Nil) } - val secRangeDataType = if(leftLowerKeys.size > 0) { leftLowerKeys(0).dataType } + val secRangeDataType = if (leftLowerKeys.size > 0) { leftLowerKeys(0).dataType } else if (leftUpperKeys.size > 0) { leftUpperKeys(0).dataType } else null val secRangeInitValue = CodeGenerator.defaultValue(secRangeDataType) @@ -558,7 +560,7 @@ case class SortMergeJoinExec( } // A queue to hold all matched rows from right side. - val clsName = if(useSecondaryRange) classOf[InMemoryUnsafeRowQueue].getName + val clsName = if (useSecondaryRange) classOf[InMemoryUnsafeRowQueue].getName else classOf[ExternalAppendOnlyUnsafeRowArray].getName val spillThreshold = getSpillThreshold @@ -629,24 +631,28 @@ case class SortMergeJoinExec( """.stripMargin, inlineToOuterClass = true) } } - val (leftLowVarsCode, leftUpperVarsCode) = if(useSecondaryRange) - (leftLowerKeyVars.map(_.code).mkString("\n"), leftUpperKeyVars.map(_.code).mkString("\n")) - else ("", "") - val (rightLowVarsCode, rightUpperVarsCode) = if(useSecondaryRange) - (rightLowerKeyVars.map(_.code).mkString("\n"), rightUpperKeyVars.map(_.code).mkString("\n")) - else ("", "") - val (leftLowAssignCode, rightLowAssignCode) = if(leftLowerKeyVars.size > 0) + val (leftLowVarsCode, leftUpperVarsCode) = if (useSecondaryRange) { + (leftLowerKeyVars.map(_.code).mkString("\n"), leftUpperKeyVars.map(_.code).mkString("\n")) + } + else { ("", "") } + val (rightLowVarsCode, rightUpperVarsCode) = if (useSecondaryRange) { + (rightLowerKeyVars.map(_.code).mkString("\n"), rightUpperKeyVars.map(_.code).mkString("\n")) + } + else { ("", "") } + val (leftLowAssignCode, rightLowAssignCode) = if (leftLowerKeyVars.size > 0) { lowerSecondaryRangeExpression.map(_ => - (s"${leftLowerSecRangeKey.value} = ${leftLowerKeyVars(0).value};", - s"${rightLowerSecRangeKey.value} = ${rightLowerKeyVars(0).value};")). 
- getOrElse(("", "")) - else ("", "") - val (leftUpperAssignCode, rightUpperAssignCode) = if(leftUpperKeyVars.size > 0) + (s"${leftLowerSecRangeKey.value} = ${leftLowerKeyVars(0).value};", + s"${rightLowerSecRangeKey.value} = ${rightLowerKeyVars(0).value};")). + getOrElse(("", "")) + } + else { ("", "") } + val (leftUpperAssignCode, rightUpperAssignCode) = if (leftUpperKeyVars.size > 0) { lowerSecondaryRangeExpression.map(_ => - (s"${leftUpperSecRangeKey.value} = ${leftUpperKeyVars(0).value};", - s"${rightUpperSecRangeKey.value} = ${rightUpperKeyVars(0).value};")). - getOrElse(("", "")) - else ("", "") + (s"${leftUpperSecRangeKey.value} = ${leftUpperKeyVars(0).value};", + s"${rightUpperSecRangeKey.value} = ${rightUpperKeyVars(0).value};")). + getOrElse(("", "")) + } + else { ("", "") } ctx.addNewFunction("findNextInnerJoinRows", s""" @@ -1055,27 +1061,27 @@ private[joins] class SortMergeJoinScanner( } /** - * Helper class that is used to implement [[SortMergeJoinExec]]. - * - * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]] - * which returns `true` if a result has been produced and `false` - * otherwise. If a result has been produced, then the caller may call [[getStreamedRow]] to return - * the matching row from the streamed input and may call [[getBufferedMatches]] to return the - * sequence of matching rows from the buffered input (in the case of an outer join, this will return - * an empty sequence if there are no matches from the buffered input). For efficiency, both of these - * methods return mutable objects which are re-used across calls to the `findNext*JoinRows()` - * methods. - * - * @param streamedKeyGenerator a projection that produces join keys from the streamed input. - * @param bufferedKeyGenerator a projection that produces join keys from the buffered input. - * @param keyOrdering an ordering which can be used to compare join keys. - * @param streamedIter an input whose rows will be streamed. - * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that - * have the same join key. - * @param inMemoryThreshold Threshold for number of rows guaranteed to be held in memory by - * internal buffer - * @param spillThreshold Threshold for number of rows to be spilled by internal buffer - */ + * Helper class that is used to implement [[SortMergeJoinExec]]. + * + * To perform an inner (outer) join, users of this class call [[findNextInnerJoinRows()]] + * which returns `true` if a result has been produced and `false` + * otherwise. If a result has been produced, then the caller may call [[getStreamedRow]] to return + * the matching row from the streamed input and may call [[getBufferedMatches]] to return the + * sequence of matching rows from the buffered input (in the case of an outer join, this will + * return an empty sequence if there are no matches from the buffered input). For efficiency, + * both of these methods return mutable objects which are re-used across calls to + * the `findNext*JoinRows()` methods. + * + * @param streamedKeyGenerator a projection that produces join keys from the streamed input. + * @param bufferedKeyGenerator a projection that produces join keys from the buffered input. + * @param keyOrdering an ordering which can be used to compare join keys. + * @param streamedIter an input whose rows will be streamed. + * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that + * have the same join key. 
+ * @param inMemoryThreshold Threshold for number of rows guaranteed to be held in memory by + * internal buffer + * @param spillThreshold Threshold for number of rows to be spilled by internal buffer + */ private[joins] class SortMergeJoinInnerRangeScanner( streamedKeyGenerator: Projection, bufferedKeyGenerator: Projection, @@ -1094,8 +1100,8 @@ private[joins] class SortMergeJoinInnerRangeScanner( // Note: this is guaranteed to never have any null columns: private[this] var bufferedRowKey: InternalRow = _ /** - * The join key for the rows buffered in `bufferedMatches`, or null if `bufferedMatches` is empty - */ + * The join key for the rows buffered in `bufferedMatches`, or null if `bufferedMatches` is empty + */ private[this] var matchJoinKey: InternalRow = _ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ @@ -1113,11 +1119,11 @@ private[joins] class SortMergeJoinInnerRangeScanner( override def getBufferedMatches: InMemoryUnsafeRowQueue = bufferedMatches /** - * Advances both input iterators, stopping when we have found rows with matching join keys. - * @return true if matching rows have been found and false otherwise. If this returns true, then - * [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join - * results. - */ + * Advances both input iterators, stopping when we have found rows with matching join keys. + * @return true if matching rows have been found and false otherwise. If this returns true, then + * [[getStreamedRow]] and [[getBufferedMatches]] can be called to construct the join + * results. + */ override final def findNextInnerJoinRows(): Boolean = { while (advanceStreamed() && streamedRowKey.anyNull) { // Advance the streamed side of the join until we find the next row whose join key contains @@ -1146,7 +1152,7 @@ private[joins] class SortMergeJoinInnerRangeScanner( false } else { // Advance both the streamed and buffered iterators to find the next pair of matching rows. - var comp = -1//keyOrdering.compare(streamedRowKey, bufferedRowKey) + var comp = -1 // keyOrdering.compare(streamedRowKey, bufferedRowKey) do { if (streamedRowKey.anyNull) { advanceStreamed() @@ -1176,9 +1182,9 @@ private[joins] class SortMergeJoinInnerRangeScanner( // --- Private methods -------------------------------------------------------------------------- /** - * Advance the streamed iterator and compute the new row's join key. - * @return true if the streamed iterator returned a row and false otherwise. - */ + * Advance the streamed iterator and compute the new row's join key. + * @return true if the streamed iterator returned a row and false otherwise. + */ private def advanceStreamed(): Boolean = { if (streamedIter.advanceNext()) { streamedRow = streamedIter.getRow @@ -1192,9 +1198,9 @@ private[joins] class SortMergeJoinInnerRangeScanner( } /** - * Advance the buffered iterator until we find a row with join key that does not contain nulls. - * @return true if the buffered iterator returned a row and false otherwise. - */ + * Advance the buffered iterator until we find a row with join key that does not contain nulls. + * @return true if the buffered iterator returned a row and false otherwise. 
+ */ private def advanceBufferedToRowWithNullFreeJoinKey(): Boolean = { var foundRow: Boolean = false while (!foundRow && bufferedIter.advanceNext()) { @@ -1212,29 +1218,30 @@ private[joins] class SortMergeJoinInnerRangeScanner( } /** - * Advance the buffered iterator as long as the join key is the same and - * the lower range condition is not satisfied. - * Skip rows with nulls. - * @return Result of the join key comparison. - */ + * Advance the buffered iterator as long as the join key is the same and + * the lower range condition is not satisfied. + * Skip rows with nulls. + * @return Result of the join key comparison. + */ private def checkLowerBoundAndAdvanceBuffered(): Int = { assert(bufferedRow != null) assert(streamedRow != null) var comp = 0 var lowCheck = lowerRangeCondition(joinRow(streamedRow, bufferedRow)) - if(!lowCheck) - while(!lowCheck && comp == 0 && advanceBufferedToRowWithNullFreeJoinKey()) { + if (!lowCheck) { + while (!lowCheck && comp == 0 && advanceBufferedToRowWithNullFreeJoinKey()) { comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) - if(comp == 0) { + if (comp == 0) { lowCheck = lowerRangeCondition(joinRow(streamedRow, bufferedRow)) } } + } comp } /** - * Called when the streamed and buffered join keys match in order to buffer the matching rows. - */ + * Called when the streamed and buffered join keys match in order to buffer the matching rows. + */ private def bufferMatchingRows(clear: Boolean): Unit = { assert(streamedRowKey != null) assert(!streamedRowKey.anyNull) @@ -1252,15 +1259,16 @@ private[joins] class SortMergeJoinInnerRangeScanner( val jr = joinRow(streamedRow, bufferedRow) lowerRangeOk = lowerRangeCondition(jr) upperRangeOk = upperRangeCondition(jr) - if(lowerRangeOk && upperRangeOk) + if (lowerRangeOk && upperRangeOk) { bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow]) + } advanceBufferedToRowWithNullFreeJoinKey() } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0 && upperRangeOk) } private def dequeueUntilLowerConditionHolds(): Unit = { - while(!bufferedMatches.isEmpty && !lowerRangeCondition(joinRow(streamedRow, bufferedRow))) { + while (!bufferedMatches.isEmpty && !lowerRangeCondition(joinRow(streamedRow, bufferedRow))) { bufferedMatches.dequeue() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala index 9e6a38765d2b4..1e3c6881c3364 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -102,7 +102,8 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using ShuffledHashJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rngConds, boundCondition, _, _) => + assert(rngConds.isEmpty) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(left.sqlContext.sessionState.conf).apply( @@ -121,7 +122,8 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using BroadcastHashJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, 
rngConds, boundCondition, _, _) => + assert(rngConds.isEmpty) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(left.sqlContext.sessionState.conf).apply( @@ -140,7 +142,8 @@ class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using SortMergeJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rngConds, boundCondition, _, _) => + assert(rngConds.isEmpty) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => EnsureRequirements(left.sqlContext.sessionState.conf).apply( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 257c58850be30..5ba0d3314fa9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -78,7 +78,8 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { if (joinType != FullOuter) { test(s"$testName using ShuffledHashJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, + boundCondition, _, _) => assert(rangeConditions.isEmpty) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { val buildSide = if (joinType == LeftOuter) BuildRight else BuildLeft @@ -100,7 +101,8 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { case RightOuter => BuildLeft case _ => fail(s"Unsupported join type $joinType") } - extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, + boundCondition, _, _) => assert(rangeConditions.isEmpty) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => @@ -114,7 +116,8 @@ class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { } test(s"$testName using SortMergeJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, + boundCondition, _, _) => assert(rangeConditions.isEmpty) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => From 3bc71c57ae6f0a46d9c0d2dd0439cddd887fd15b Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Fri, 13 Apr 2018 20:36:12 +0200 Subject: [PATCH 12/51] Scalastyle --- .../apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala | 4 ++-- .../apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala index d91a1f328303b..7b557c380779a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala @@ -22,12 +22,12 @@ import 
java.util.ConcurrentModificationException import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer.SerializerManager -import org.apache.spark.storage.BlockManager -import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer +import org.apache.spark.storage.BlockManager /** * An append-only array for [[UnsafeRow]]s that strictly keeps content in an in-memory array diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 7203752f0b15b..18dc41300f136 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ From 5c62f98d848bacbfb0f1eec044d7ac489225cf43 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Fri, 13 Apr 2018 20:47:55 +0200 Subject: [PATCH 13/51] Fix generated code - dequeue method missing --- .../execution/joins/SortMergeJoinExec.scala | 76 +++++++++---------- 1 file changed, 37 insertions(+), 39 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 18dc41300f136..84b729f8ec700 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -591,45 +591,43 @@ case class SortMergeJoinExec( logDebug(s"upperCompExp: $upperCompExp") // Add secondary range dequeue method - if(useSecondaryRange) { - if (lowerSecondaryRangeExpression.isEmpty || rightLowerKeys.size == 0 || - rightUpperKeys.size == 0) { - ctx.addNewFunction("dequeueUntilLowerConditionHolds", - "private void dequeueUntilLowerConditionHolds() { }", - inlineToOuterClass = true) - } - else { - val rightRngTmpKeyVars = createJoinKey(ctx, rightTmpRow, - rightUpperKeys.slice(0, 1), right.output) - val rightRngTmpKeyVarsDecl = rightRngTmpKeyVars.map(_.code).mkString("\n") - rightRngTmpKeyVars.foreach(_.code = "") - val javaType = CodeGenerator.javaType(rightLowerKeys(0).dataType) - - ctx.addNewFunction("getRightTmpRangeValue", - s""" - |private $javaType getRightTmpRangeValue() { - | $rightRngTmpKeyVarsDecl - | return ${rightRngTmpKeyVars(0).value}; - |} - """.stripMargin) - - ctx.addNewFunction("dequeueUntilLowerConditionHolds", - s""" - |private void dequeueUntilLowerConditionHolds() { - | if($matches.isEmpty()) - | return; - | $rightTmpRow = $matches.get(0); - | $javaType tempVal = getRightTmpRangeValue(); - | 
while(${leftLowerSecRangeKey.value} $upperCompop tempVal) { - | $matches.dequeue(); - | if($matches.isEmpty()) - | break; - | $rightTmpRow = $matches.get(0); - | tempVal = getRightTmpRangeValue(); - | } - |} - """.stripMargin, inlineToOuterClass = true) - } + if (!useSecondaryRange || lowerSecondaryRangeExpression.isEmpty || + rightLowerKeys.size == 0 || rightUpperKeys.size == 0) { + ctx.addNewFunction("dequeueUntilLowerConditionHolds", + "private void dequeueUntilLowerConditionHolds() { }", + inlineToOuterClass = true) + } + else { + val rightRngTmpKeyVars = createJoinKey(ctx, rightTmpRow, + rightUpperKeys.slice(0, 1), right.output) + val rightRngTmpKeyVarsDecl = rightRngTmpKeyVars.map(_.code).mkString("\n") + rightRngTmpKeyVars.foreach(_.code = "") + val javaType = CodeGenerator.javaType(rightLowerKeys(0).dataType) + + ctx.addNewFunction("getRightTmpRangeValue", + s""" + |private $javaType getRightTmpRangeValue() { + | $rightRngTmpKeyVarsDecl + | return ${rightRngTmpKeyVars(0).value}; + |} + """.stripMargin) + + ctx.addNewFunction("dequeueUntilLowerConditionHolds", + s""" + |private void dequeueUntilLowerConditionHolds() { + | if($matches.isEmpty()) + | return; + | $rightTmpRow = $matches.get(0); + | $javaType tempVal = getRightTmpRangeValue(); + | while(${leftLowerSecRangeKey.value} $upperCompop tempVal) { + | $matches.dequeue(); + | if($matches.isEmpty()) + | break; + | $rightTmpRow = $matches.get(0); + | tempVal = getRightTmpRangeValue(); + | } + |} + """.stripMargin, inlineToOuterClass = true) } val (leftLowVarsCode, leftUpperVarsCode) = if (useSecondaryRange) { (leftLowerKeyVars.map(_.code).mkString("\n"), leftUpperKeyVars.map(_.code).mkString("\n")) From 16e3e1b346a12dd6ffddd86f6f7e6e0abd644e1b Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Fri, 13 Apr 2018 20:53:24 +0200 Subject: [PATCH 14/51] Bug fix: include other binary comparisons in range conditions match --- .../scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index d1781752ff5ac..ab7a3485d8796 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -164,6 +164,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { case "vs" => rangePreds.add(p); Some(LessThanOrEqual(r, l)) case _ => None } + case _ => None } val (leftKeys, rightKeys) = joinKeys.unzip From e7f7bdfd3417976e7e4efb83e2c515bdfa57e4a6 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Mon, 16 Apr 2018 09:31:19 -0700 Subject: [PATCH 15/51] Test fix: sortWithinPartitions; Bug Fix: check references in rangeConditions, not columns --- .../org/apache/spark/sql/catalyst/planning/patterns.scala | 4 ++-- .../apache/spark/sql/execution/joins/InnerJoinSuite.scala | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index ab7a3485d8796..6c3cbd727314c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -177,8 +177,8 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { 
rangeConditions.filter(x => x.isInstanceOf[GreaterThan] || x.isInstanceOf[GreaterThanOrEqual]).size == 0 || // Both comparisons reference the same columns: - rangeConditions.map(c => c.left).distinct.size != 1 || - rangeConditions.map(c => c.right).distinct.size != 1) { + rangeConditions.map(c => c.left.references).distinct.size != 1 || + rangeConditions.map(c => c.right.references).distinct.size != 1) { rangeConditions = Nil rangePreds.clear() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 75a815e79deab..c13ab90cae6b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -302,14 +302,14 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } { - lazy val left = rangeTestData1 - lazy val right = rangeTestData2 + lazy val left = rangeTestData1.orderBy("a").sortWithinPartitions("b") + lazy val right = rangeTestData2.orderBy("a").sortWithinPartitions("b") testInnerJoin( "inner range join", left, right, - () => ((left("a") === right("a")) and (left("b") <= right("b")-1) - and (left("b") >= right("b") + 1)).expr, + () => ((left("a") === right("a")) and (left("b") <= right("b") + 1) + and (left("b") >= right("b") - 1)).expr, Seq( (1, 1, 1, 1), (1, 1, 1, 2), From 41cde2736c3aad4c6fca35e527b2523f589ed548 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Mon, 16 Apr 2018 16:58:44 -0700 Subject: [PATCH 16/51] Test fix --- .../apache/spark/sql/catalyst/planning/patterns.scala | 10 +++++++++- .../spark/sql/execution/joins/InnerJoinSuite.scala | 10 ++++++---- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 6c3cbd727314c..85c175af3474e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -176,9 +176,17 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { x.isInstanceOf[LessThanOrEqual]).size == 0 || rangeConditions.filter(x => x.isInstanceOf[GreaterThan] || x.isInstanceOf[GreaterThanOrEqual]).size == 0 || - // Both comparisons reference the same columns: + // Check if both comparisons reference the same columns: rangeConditions.map(c => c.left.references).distinct.size != 1 || rangeConditions.map(c => c.right.references).distinct.size != 1) { + logDebug(s"Clearing range conditions because: " + + s"${rangeConditions.size}, " + + s"${rangeConditions.filter(x => x.isInstanceOf[LessThan] || + x.isInstanceOf[LessThanOrEqual]).size}, " + + s"${rangeConditions.filter(x => x.isInstanceOf[GreaterThan] || + x.isInstanceOf[GreaterThanOrEqual]).size}, " + + s"${rangeConditions.map(c => c.left.references)}, " + + s"${rangeConditions.map(c => c.right.references)}") rangeConditions = Nil rangePreds.clear() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index c13ab90cae6b9..291635184670f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ 
-73,12 +73,12 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { private lazy val rangeTestData1 = Seq( (1, 1), (1, 2), (1, 3), (2, 1), (2, 2), (2, 3), - (3, 1), (3, 2), (3, 3), + (3, 1), (3, 2), (3, 3), (3, 5), (4, 1), (4, 2), (4, 3) ).toDF("a", "b") private lazy val rangeTestData2 = Seq( - (1, 1), (1, 2), (1, 3), + (1, 1), (1, 2), (1, 3), (1, 5), (2, 1), (2, 2), (2, 3), (3, 1), (3, 2), (3, 3) ).toDF("a", "b") @@ -134,9 +134,10 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { leftKeys: Seq[Expression], rightKeys: Seq[Expression], boundCondition: Option[Expression], + rangeConditions: Seq[Expression], leftPlan: SparkPlan, rightPlan: SparkPlan) = { - val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, Nil, + val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, rangeConditions, boundCondition, leftPlan, rightPlan) EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin) } @@ -208,7 +209,8 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { expectRangeJoin && rangeConditions.size == 2) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => - makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan), + makeSortMergeJoin(leftKeys, rightKeys, boundCondition, rangeConditions, + leftPlan, rightPlan), expectedAnswer.map(Row.fromTuple), sortAnswers = true) } From 080ab0d5070ea4532c27f524ab183f65c0bc7f4b Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Mon, 16 Apr 2018 17:19:49 -0700 Subject: [PATCH 17/51] Test fix --- .../sql/execution/joins/InnerJoinSuite.scala | 88 +++++++++---------- 1 file changed, 44 insertions(+), 44 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 291635184670f..21e2972f6d733 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -142,62 +142,62 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin) } - test(s"$testName using BroadcastHashJoin (build=left)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, - boundCondition, _, _) => - assert(!expectRangeJoin && rangeConditions.isEmpty || - expectRangeJoin && rangeConditions.size == 2) - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => - makeBroadcastHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + if(!expectRangeJoin) { + test(s"$testName using BroadcastHashJoin (build=left)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, _, + boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeBroadcastHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } } - test(s"$testName using BroadcastHashJoin (build=right)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, - boundCondition, _, _) => - 
assert(!expectRangeJoin && rangeConditions.isEmpty || - expectRangeJoin && rangeConditions.size == 2) - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => - makeBroadcastHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + if(!expectRangeJoin) { + test(s"$testName using BroadcastHashJoin (build=right)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, _, + boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeBroadcastHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } } - test(s"$testName using ShuffledHashJoin (build=left)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, - boundCondition, _, _) => - assert(!expectRangeJoin && rangeConditions.isEmpty || - expectRangeJoin && rangeConditions.size == 2) - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => - makeShuffledHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + if(!expectRangeJoin) { + test(s"$testName using ShuffledHashJoin (build=left)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, _, + boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } } - test(s"$testName using ShuffledHashJoin (build=right)") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, - boundCondition, _, _) => - assert(!expectRangeJoin && rangeConditions.isEmpty || - expectRangeJoin && rangeConditions.size == 2) - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => - makeShuffledHashJoin( - leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + if(!expectRangeJoin) { + test(s"$testName using ShuffledHashJoin (build=right)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, _, + boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } } From 862821603868212cf642f04f7162a4d321f83a21 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Wed, 18 Apr 2018 15:26:48 -0700 Subject: [PATCH 18/51] Fix required child ordering for inner range queries --- .../spark/sql/catalyst/planning/patterns.scala | 6 +++--- .../sql/execution/joins/SortMergeJoinExec.scala | 17 +++++++++++++---- .../sql/execution/joins/InnerJoinSuite.scala | 4 ++-- 3 files changed, 18 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 85c175af3474e..ee2920b81b131 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -103,7 +103,7 @@ object PhysicalOperation extends PredicateHelper { object ExtractEquiJoinKeys extends Logging with PredicateHelper { /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */ type ReturnType = - (JoinType, Seq[Expression], Seq[Expression], Seq[Expression], + (JoinType, Seq[Expression], Seq[Expression], Seq[BinaryComparison], Option[Expression], LogicalPlan, LogicalPlan) def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { @@ -177,8 +177,8 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { rangeConditions.filter(x => x.isInstanceOf[GreaterThan] || x.isInstanceOf[GreaterThanOrEqual]).size == 0 || // Check if both comparisons reference the same columns: - rangeConditions.map(c => c.left.references).distinct.size != 1 || - rangeConditions.map(c => c.right.references).distinct.size != 1) { + rangeConditions.flatMap(c => c.left.references.toSeq.distinct).distinct.size != 1 || + rangeConditions.flatMap(c => c.right.references.toSeq.distinct).distinct.size != 1) { logDebug(s"Clearing range conditions because: " + s"${rangeConditions.size}, " + s"${rangeConditions.filter(x => x.isInstanceOf[LessThan] || diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 84b729f8ec700..cc9134936e10b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -38,7 +38,7 @@ case class SortMergeJoinExec( leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - rangeConditions: Seq[Expression], + rangeConditions: Seq[BinaryComparison], condition: Option[Expression], left: SparkPlan, right: SparkPlan) extends BinaryExecNode with PredicateHelper with CodegenSupport { @@ -72,6 +72,9 @@ case class SortMergeJoinExec( logDebug(s"Use secondary range join resolved to $useSecondaryRange.") + val lrKeys = rangeConditions.flatMap(c => c.left.references.toSeq.distinct).distinct + val rrKeys = rangeConditions.flatMap(c => c.right.references.toSeq.distinct).distinct + // private def findBucketSpec(plan: SparkPlan): Option[BucketSpec] = { // if(plan.isInstanceOf[FileSourceScanExec]) { // plan.asInstanceOf[FileSourceScanExec].relation.bucketSpec @@ -126,9 +129,15 @@ case class SortMergeJoinExec( override def outputOrdering: Seq[SortOrder] = joinType match { // For inner join, orders of both sides keys should be kept. 
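// Illustration (hypothetical schemas t1(a, b) and t2(a, b); this example is a
// sketch added for clarity, not a line of the patched file): for a join on
// t1.a === t2.a with range conditions between t1.b and t2.b, the combined keys are
//   leftKeys ++ lrKeys  == Seq(t1.a, t1.b)
//   rightKeys ++ rrKeys == Seq(t2.a, t2.b)
// so each side is ordered by the equi-join key first and the range column second,
// which is exactly the secondary sort that the moving window over b relies on.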
+ case _: InnerLike => - val leftKeyOrdering = getKeyOrdering(leftKeys, left.outputOrdering) - val rightKeyOrdering = getKeyOrdering(rightKeys, right.outputOrdering) + + logDebug(s"Left range ordering: ${leftKeys ++ lrKeys}") + logDebug(s"Right range ordering: ${rightKeys ++ rrKeys}") + + val leftKeyOrdering = getKeyOrdering(leftKeys ++ lrKeys, left.outputOrdering) + val rightKeyOrdering = getKeyOrdering(rightKeys ++ rrKeys, right.outputOrdering) + logDebug(s"outputOrdering results: $leftKeyOrdering and $rightKeyOrdering") leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) => // Also add the right key and its `sameOrderExpressions` SortOrder(lKey.child, Ascending, lKey.sameOrderExpressions + rKey.child ++ rKey @@ -166,7 +175,7 @@ case class SortMergeJoinExec( } override def requiredChildOrdering: Seq[Seq[SortOrder]] = - requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil + requiredOrders(leftKeys ++ lrKeys) :: requiredOrders(rightKeys ++ rrKeys) :: Nil private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 21e2972f6d733..b0482eabeebd8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{BinaryComparison, Expression} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join @@ -134,7 +134,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { leftKeys: Seq[Expression], rightKeys: Seq[Expression], boundCondition: Option[Expression], - rangeConditions: Seq[Expression], + rangeConditions: Seq[BinaryComparison], leftPlan: SparkPlan, rightPlan: SparkPlan) = { val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, rangeConditions, From 7bd673223254d921c3f2e5ee26249ba22bf23ca1 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Thu, 19 Apr 2018 09:45:07 -0700 Subject: [PATCH 19/51] Parameter for turning off inner range optimization --- .../apache/spark/sql/internal/SQLConf.scala | 17 ++++++ .../execution/joins/SortMergeJoinExec.scala | 60 ++++++++----------- 2 files changed, 41 insertions(+), 36 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 738d8fee891d1..188b3f1f9fc3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1447,6 +1447,19 @@ object SQLConf { .intConf .createWithDefault(Int.MaxValue) + val USE_SMJ_INNER_RANGE_OPTIMIZATION = + buildConf("spark.sql.join.smj.useInnerRangeOptimization") + .internal() + .doc("Sort-merge join 'inner range optimization' is applicable in cases where the join " + + "condition includes equality expressions on pairs of columns and a range expression " + + "involving two other columns, (e.g. 
t1.x = t2.x AND t1.y BETWEEN t2.y - d AND t2.y + d)." +
+ " If the inner range optimization is enabled, the number of rows considered for each " +
+ "match of equality conditions can be reduced considerably because a moving window, " +
+ "corresponding to the range conditions, will be used for iterating over matched rows " +
+ "in the right relation.")
+ .booleanConf
+ .createWithDefault(true)
+
 object Deprecated {
 val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
 }
@@ -1718,7 +1731,8 @@ class SQLConf extends Serializable with Logging {
 def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION)
 def topKSortFallbackThreshold: Int = getConf(TOP_K_SORT_FALLBACK_THRESHOLD)
+ def useSmjInnerRangeOptimization: Boolean = getConf(USE_SMJ_INNER_RANGE_OPTIMIZATION)
 def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index cc9134936e10b..d8beb894a99c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -18,17 +18,17 @@ package org.apache.spark.sql.execution.joins
 import scala.collection.mutable.ArrayBuffer
-
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.collection.BitSet
 /**
@@ -48,11 +48,6 @@ case class SortMergeJoinExec(
 s" rangeConditions: $rangeConditions, " +
 s"condition: $condition, left: $left, right: $right")
-// val leftBucketSpec = findBucketSpec(left)
-// val rightBucketSpec = findBucketSpec(right)
-//
-// logDebug(s"Found left bucket spec: $leftBucketSpec.
Found right bucket spec: $rightBucketSpec") - private val lowerSecondaryRangeExpression : Option[Expression] = { logDebug(s"Finding secondary greaterThan expressions in $rangeConditions") val thefind = rangeConditions.find(p => @@ -68,25 +63,23 @@ case class SortMergeJoinExec( thefind } - val useSecondaryRange = shouldUseSecondaryRangeJoin() - - logDebug(s"Use secondary range join resolved to $useSecondaryRange.") + val useInnerRange = SQLConf.get.useSmjInnerRangeOptimization && + (lowerSecondaryRangeExpression.isDefined || upperSecondaryRangeExpression.isDefined) - val lrKeys = rangeConditions.flatMap(c => c.left.references.toSeq.distinct).distinct - val rrKeys = rangeConditions.flatMap(c => c.right.references.toSeq.distinct).distinct + logDebug(s"Use secondary range join resolved to $useInnerRange.") -// private def findBucketSpec(plan: SparkPlan): Option[BucketSpec] = { -// if(plan.isInstanceOf[FileSourceScanExec]) { -// plan.asInstanceOf[FileSourceScanExec].relation.bucketSpec -// } -// else { -// val cb = plan.children.flatMap(c => findBucketSpec(c)) -// if (cb.size > 0) -// Some(cb(0)) -// else -// None -// } -// } + val lrKeys = if(useInnerRange) { + rangeConditions.flatMap(c => c.left.references.toSeq).distinct + } + else { + Nil + } + val rrKeys = if(useInnerRange) { + rangeConditions.flatMap(c => c.right.references.toSeq).distinct + } + else { + Nil + } override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) @@ -500,11 +493,6 @@ case class SortMergeJoinExec( """.stripMargin } - private def shouldUseSecondaryRangeJoin(): Boolean = { - // TODO check sorting of the two relations? - Check it during planning? - lowerSecondaryRangeExpression.isDefined || upperSecondaryRangeExpression.isDefined - } - /** * Generate a function to scan both left and right to find a match, returns the term for * matched one row from left side and buffered rows from right side. @@ -515,7 +503,7 @@ case class SortMergeJoinExec( // Inline mutable state since not many join operations in a task val leftRow = ctx.addMutableState("InternalRow", "leftRow", forceInline = true) val rightRow = ctx.addMutableState("InternalRow", "rightRow", forceInline = true) - val rightTmpRow = if (useSecondaryRange) { + val rightTmpRow = if (useInnerRange) { ctx.addMutableState("InternalRow", "rightTmpRow", forceInline = true) } else { "" } @@ -542,7 +530,7 @@ case class SortMergeJoinExec( // Variables for secondary range expressions val (leftLowerKeyVars, leftUpperKeyVars, rightLowerKeyVars, rightUpperKeyVars) = - if (useSecondaryRange) { + if (useInnerRange) { (createJoinKey(ctx, leftRow, leftLowerKeys, left.output), createJoinKey(ctx, leftRow, leftUpperKeys, left.output), createJoinKey(ctx, rightRow, rightLowerKeys, right.output), @@ -558,7 +546,7 @@ case class SortMergeJoinExec( val secRangeInitValue = CodeGenerator.defaultValue(secRangeDataType) val (leftLowerSecRangeKey, leftUpperSecRangeKey, rightLowerSecRangeKey, rightUpperSecRangeKey) = - if (useSecondaryRange) { + if (useInnerRange) { (ctx.addBufferedState(secRangeDataType, "leftLowerSecRangeKey", secRangeInitValue), ctx.addBufferedState(secRangeDataType, "leftUpperSecRangeKey", secRangeInitValue), ctx.addBufferedState(secRangeDataType, "rightLowerSecRangeKey", secRangeInitValue), @@ -569,7 +557,7 @@ case class SortMergeJoinExec( } // A queue to hold all matched rows from right side. 
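// A sketch of the moving-window protocol this queue enables (pseudocode, assuming
// both sides are sorted ascending on the range column; the helper names here are
// illustrative, not the generated-code identifiers): for each streamed row s,
// queued rows that have dropped out of the window are dequeued, and newly eligible
// buffered rows are enqueued:
//
//   while (!queue.isEmpty && !lowerHolds(s, queue.head)) queue.dequeue()
//   while (buffered != null && sameEquiKey(s, buffered) && upperHolds(s, buffered)) {
//     queue.add(buffered)
//     advanceBuffered()
//   }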
- val clsName = if (useSecondaryRange) classOf[InMemoryUnsafeRowQueue].getName + val clsName = if (useInnerRange) classOf[InMemoryUnsafeRowQueue].getName else classOf[ExternalAppendOnlyUnsafeRowArray].getName val spillThreshold = getSpillThreshold @@ -600,7 +588,7 @@ case class SortMergeJoinExec( logDebug(s"upperCompExp: $upperCompExp") // Add secondary range dequeue method - if (!useSecondaryRange || lowerSecondaryRangeExpression.isEmpty || + if (!useInnerRange || lowerSecondaryRangeExpression.isEmpty || rightLowerKeys.size == 0 || rightUpperKeys.size == 0) { ctx.addNewFunction("dequeueUntilLowerConditionHolds", "private void dequeueUntilLowerConditionHolds() { }", @@ -638,11 +626,11 @@ case class SortMergeJoinExec( |} """.stripMargin, inlineToOuterClass = true) } - val (leftLowVarsCode, leftUpperVarsCode) = if (useSecondaryRange) { + val (leftLowVarsCode, leftUpperVarsCode) = if (useInnerRange) { (leftLowerKeyVars.map(_.code).mkString("\n"), leftUpperKeyVars.map(_.code).mkString("\n")) } else { ("", "") } - val (rightLowVarsCode, rightUpperVarsCode) = if (useSecondaryRange) { + val (rightLowVarsCode, rightUpperVarsCode) = if (useInnerRange) { (rightLowerKeyVars.map(_.code).mkString("\n"), rightUpperKeyVars.map(_.code).mkString("\n")) } else { ("", "") } From 094f66b10509d0c9c5a88cc1bd56de2f4e4ffbad Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Thu, 19 Apr 2018 11:21:41 -0700 Subject: [PATCH 20/51] Scala style --- .../spark/sql/execution/joins/SortMergeJoinExec.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index d8beb894a99c6..25d2523d5e666 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -18,12 +18,13 @@ package org.apache.spark.sql.execution.joins import scala.collection.mutable.ArrayBuffer + import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ @@ -68,13 +69,13 @@ case class SortMergeJoinExec( logDebug(s"Use secondary range join resolved to $useInnerRange.") - val lrKeys = if(useInnerRange) { + val lrKeys = if (useInnerRange) { rangeConditions.flatMap(c => c.left.references.toSeq).distinct } else { Nil } - val rrKeys = if(useInnerRange) { + val rrKeys = if (useInnerRange) { rangeConditions.flatMap(c => c.right.references.toSeq).distinct } else { From efd595e4ed669478fcc0d639465cadbfbea39195 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Thu, 19 Apr 2018 12:20:32 -0700 Subject: [PATCH 21/51] Bug fix - NPE when inner range optimization turned off --- .../apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
index 25d2523d5e666..ccdab185e2d7f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala
@@ -578,10 +578,10 @@ case class SortMergeJoinExec(
 case LessThan(_, _) => ">="
 case _ => ""
 }.getOrElse("")
- val lowerCompExp = if (lowerSecondaryRangeExpression.isEmpty) ""
+ val lowerCompExp = if (!useInnerRange || lowerSecondaryRangeExpression.isEmpty) ""
 else s" || (comp == 0 && ${leftLowerSecRangeKey.value} " +
 s"$lowerCompop ${rightLowerSecRangeKey.value})"
- val upperCompExp = if (upperSecondaryRangeExpression.isEmpty) ""
+ val upperCompExp = if (!useInnerRange || upperSecondaryRangeExpression.isEmpty) ""
 else s" || (comp == 0 && ${leftUpperSecRangeKey.value} " +
 s"$upperCompop ${rightUpperSecRangeKey.value})"
From a8372e35b264c2dd2b28d30a8468b724d807a0c2 Mon Sep 17 00:00:00 2001
From: Petar Zecevic
Date: Thu, 19 Apr 2018 12:21:01 -0700
Subject: [PATCH 22/51] Adding test case when inner range optimization is turned off
---
 .../spark/sql/execution/joins/InnerJoinSuite.scala | 11 +++++++++++
 1 file changed, 11 insertions(+)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index b0482eabeebd8..3beb3247ea294 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -142,6 +142,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
 EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin)
 }
+ //disabling these because the code would never follow this path in case of an inner range join
 if(!expectRangeJoin) {
 test(s"$testName using BroadcastHashJoin (build=left)") {
 extractJoinParts().foreach { case (_, leftKeys, rightKeys, _,
@@ -214,6 +215,16 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
 expectedAnswer.map(Row.fromTuple),
 sortAnswers = true)
 }
+ if (expectRangeJoin) {
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1",
+ SQLConf.USE_SMJ_INNER_RANGE_OPTIMIZATION.key -> "false") {
+ checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+ makeSortMergeJoin(leftKeys, rightKeys, boundCondition, rangeConditions,
+ leftPlan, rightPlan),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
 }
 }
From 4396985309e850f98b7e1784e76963e586c99335 Mon Sep 17 00:00:00 2001
From: Petar Zecevic
Date: Thu, 19 Apr 2018 12:24:49 -0700
Subject: [PATCH 23/51] Scala style
---
 .../org/apache/spark/sql/execution/joins/InnerJoinSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 3beb3247ea294..6cb9f91ed08ed 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -143,7 +143,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
 }
 //disabling these because the code would never follow this path in case of an inner range join
- if(!expectRangeJoin) {
+ if (!expectRangeJoin) {
 test(s"$testName using BroadcastHashJoin (build=left)") {
extractJoinParts().foreach { case (_, leftKeys, rightKeys, _,
 boundCondition, _, _) =>
From 9c14368ac776ef47922861fe4701509c02c94c46 Mon Sep 17 00:00:00 2001
From: Petar Zecevic
Date: Thu, 19 Apr 2018 12:27:28 -0700
Subject: [PATCH 24/51] Scala style
---
 .../org/apache/spark/sql/execution/joins/InnerJoinSuite.scala | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index 6cb9f91ed08ed..59f0451d269ce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -142,7 +142,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
 EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin)
 }
- //disabling these because the code would never follow this path in case of an inner range join
+ // Disabling these because the code would never follow this path in case of an inner range join
 if (!expectRangeJoin) {
 test(s"$testName using BroadcastHashJoin (build=left)") {
 extractJoinParts().foreach { case (_, leftKeys, rightKeys, _,
 boundCondition, _, _) =>
From 6cbf9fecaae62967aeae37d38adfa41750f55c2e Mon Sep 17 00:00:00 2001
From: Petar Zecevic
Date: Thu, 19 Apr 2018 12:40:31 -0700
Subject: [PATCH 25/51] Remove range condition extraction when inner range join optimization is disabled
---
 .../sql/catalyst/planning/patterns.scala | 65 +++++++++----------
 1 file changed, 32 insertions(+), 33 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index ee2920b81b131..448575dc2bdac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.planning
 import scala.collection.mutable
-
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.catalog.BucketSpec
@@ -26,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.internal.SQLConf
 /**
 * A pattern that matches any number of project or filter operations on top of another relational
@@ -135,38 +135,45 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper {
 }
 if (joinKeys.nonEmpty) {
+ val (leftKeys, rightKeys) = joinKeys.unzip
+
 // Find any simple range expressions between two columns
 // (and involving only those two columns)
 // of the two tables being joined,
 // which are not used in the equijoin expressions,
 // and which can be used for secondary sort optimizations.
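// For example (hypothetical schemas t1(a, b) and t2(a, b, c); an illustration
// added here, not a line of this diff): with the equi-join key t1.a = t2.a,
// the predicate t1.b < t2.b + 3 qualifies, because each side of the comparison
// references exactly one column and the pair (t1.b, t2.b) is not already an
// equi-join pair. By contrast, t1.b < t2.b + t2.c (two columns referenced on
// the right side) and t1.a < t2.a (columns already used as equi-join keys)
// are rejected and stay in the regular join condition.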
- val rangePreds : mutable.Set[Expression] = mutable.Set.empty - var rangeConditions : Seq[BinaryComparison] = otherPredicates.flatMap { - case p @ LessThan(l, r) => isValidRangeCondition(l, r, left, right, joinKeys) match { - case "asis" => rangePreds.add(p); Some(LessThan(l, r)) - case "vs" => rangePreds.add(p); Some(GreaterThan(r, l)) - case _ => None - } - case p @ LessThanOrEqual(l, r) => - isValidRangeCondition(l, r, left, right, joinKeys) match { - case "asis" => rangePreds.add(p); Some(LessThanOrEqual(l, r)) - case "vs" => rangePreds.add(p); Some(GreaterThanOrEqual(r, l)) + val rangePreds: mutable.Set[Expression] = mutable.Set.empty + var rangeConditions: Seq[BinaryComparison] = + if (SQLConf.get.useSmjInnerRangeOptimization) { + otherPredicates.flatMap { + case p@LessThan(l, r) => isValidRangeCondition(l, r, left, right, joinKeys) match { + case "asis" => rangePreds.add(p); Some(LessThan(l, r)) + case "vs" => rangePreds.add(p); Some(GreaterThan(r, l)) + case _ => None + } + case p@LessThanOrEqual(l, r) => + isValidRangeCondition(l, r, left, right, joinKeys) match { + case "asis" => rangePreds.add(p); Some(LessThanOrEqual(l, r)) + case "vs" => rangePreds.add(p); Some(GreaterThanOrEqual(r, l)) + case _ => None + } + case p@GreaterThan(l, r) => isValidRangeCondition(l, r, left, right, joinKeys) match { + case "asis" => rangePreds.add(p); Some(GreaterThan(l, r)) + case "vs" => rangePreds.add(p); Some(LessThan(r, l)) + case _ => None + } + case p@GreaterThanOrEqual(l, r) => + isValidRangeCondition(l, r, left, right, joinKeys) match { + case "asis" => rangePreds.add(p); Some(GreaterThanOrEqual(l, r)) + case "vs" => rangePreds.add(p); Some(LessThanOrEqual(r, l)) + case _ => None + } case _ => None + } } - case p @ GreaterThan(l, r) => isValidRangeCondition(l, r, left, right, joinKeys) match { - case "asis" => rangePreds.add(p); Some(GreaterThan(l, r)) - case "vs" => rangePreds.add(p); Some(LessThan(r, l)) - case _ => None + else { + Nil } - case p @ GreaterThanOrEqual(l, r) => - isValidRangeCondition(l, r, left, right, joinKeys) match { - case "asis" => rangePreds.add(p); Some(GreaterThanOrEqual(l, r)) - case "vs" => rangePreds.add(p); Some(LessThanOrEqual(r, l)) - case _ => None - } - case _ => None - } - val (leftKeys, rightKeys) = joinKeys.unzip // Only using secondary join optimization when both lower and upper conditions // are specified (e.g. 
t1.a < t2.b + x and t1.a > t2.b - x) @@ -179,14 +186,6 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { // Check if both comparisons reference the same columns: rangeConditions.flatMap(c => c.left.references.toSeq.distinct).distinct.size != 1 || rangeConditions.flatMap(c => c.right.references.toSeq.distinct).distinct.size != 1) { - logDebug(s"Clearing range conditions because: " + - s"${rangeConditions.size}, " + - s"${rangeConditions.filter(x => x.isInstanceOf[LessThan] || - x.isInstanceOf[LessThanOrEqual]).size}, " + - s"${rangeConditions.filter(x => x.isInstanceOf[GreaterThan] || - x.isInstanceOf[GreaterThanOrEqual]).size}, " + - s"${rangeConditions.map(c => c.left.references)}, " + - s"${rangeConditions.map(c => c.right.references)}") rangeConditions = Nil rangePreds.clear() } From 82943b806f7c9c60547b9e45583c900719650c85 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Thu, 19 Apr 2018 12:42:11 -0700 Subject: [PATCH 26/51] Scala style --- .../scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 448575dc2bdac..27065b8530dff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.planning import scala.collection.mutable + import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.catalog.BucketSpec From bbddf7a83263e4f87d9e45fd35bb3d197ac265e8 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Thu, 19 Apr 2018 13:44:58 -0700 Subject: [PATCH 27/51] Unit test fix --- .../sql/execution/joins/InnerJoinSuite.scala | 22 ++++++++++++------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 59f0451d269ce..638a4720a24ea 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -215,14 +215,20 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { expectedAnswer.map(Row.fromTuple), sortAnswers = true) } - if (expectRangeJoin) { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", - SQLConf.USE_SMJ_INNER_RANGE_OPTIMIZATION.key -> "false") { - checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => - makeSortMergeJoin(leftKeys, rightKeys, boundCondition, rangeConditions, - leftPlan, rightPlan), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + } + if (expectRangeJoin) { + withSQLConf(SQLConf.USE_SMJ_INNER_RANGE_OPTIMIZATION.key -> "false") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, + boundCondition, _, _) => + assert(!expectRangeJoin && rangeConditions.isEmpty || + expectRangeJoin && rangeConditions.size == 2) + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeSortMergeJoin(leftKeys, rightKeys, boundCondition, rangeConditions, + leftPlan, rightPlan), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } } From 
c4060d7f803c238a486a26b6e07541b936086e4b Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Thu, 19 Apr 2018 13:55:50 -0700 Subject: [PATCH 28/51] Unit test fix --- .../org/apache/spark/sql/execution/joins/InnerJoinSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 638a4720a24ea..b3d8d668c3b7e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -220,8 +220,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { withSQLConf(SQLConf.USE_SMJ_INNER_RANGE_OPTIMIZATION.key -> "false") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => - assert(!expectRangeJoin && rangeConditions.isEmpty || - expectRangeJoin && rangeConditions.size == 2) + assert(rangeConditions.isEmpty) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeSortMergeJoin(leftKeys, rightKeys, boundCondition, rangeConditions, From 5b0f2b5a0d08c6cb9a42036daabcabc452f59cd9 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Fri, 27 Apr 2018 09:21:13 -0700 Subject: [PATCH 29/51] - Turning off inner range optimization when whole stage code generation is turned off. - Debugging - comments --- .../sql/catalyst/planning/patterns.scala | 3 +- .../execution/joins/SortMergeJoinExec.scala | 30 +++++++++---------- 2 files changed, 17 insertions(+), 16 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 27065b8530dff..4b122d9c50479 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -102,7 +102,7 @@ object PhysicalOperation extends PredicateHelper { * value). */ object ExtractEquiJoinKeys extends Logging with PredicateHelper { - /** (joinType, leftKeys, rightKeys, condition, leftChild, rightChild) */ + /** (joinType, leftKeys, rightKeys, rangeConditions, condition, leftChild, rightChild) */ type ReturnType = (JoinType, Seq[Expression], Seq[Expression], Seq[BinaryComparison], Option[Expression], LogicalPlan, LogicalPlan) @@ -187,6 +187,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { // Check if both comparisons reference the same columns: rangeConditions.flatMap(c => c.left.references.toSeq.distinct).distinct.size != 1 || rangeConditions.flatMap(c => c.right.references.toSeq.distinct).distinct.size != 1) { + logDebug("Inner range optimization conditions not met. 
Clearing range conditions") rangeConditions = Nil rangePreds.clear() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index ccdab185e2d7f..d1a35263a483d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -228,8 +228,7 @@ case class SortMergeJoinExec( private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _ private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null private[this] val smjScanner = - if (lowerSecondaryRangeExpression.isDefined || - upperSecondaryRangeExpression.isDefined) { + if (false) { // useInnerRange) { new SortMergeJoinInnerRangeScanner( createLeftKeyGenerator(), createRightKeyGenerator(), @@ -1137,9 +1136,11 @@ private[joins] class SortMergeJoinInnerRangeScanner( // First dequeue all rows from the queue until the lower range condition holds. // Then try to enqueue new rows with the same join key and for which the upper // range condition holds. + bufferedRowKey = matchJoinKey dequeueUntilLowerConditionHolds() - bufferMatchingRows(true) - true + bufferMatchingRows() + if (bufferedMatches.isEmpty) matchJoinKey = null + ! bufferedMatches.isEmpty } else if (bufferedRow == null) { // The streamed row's join key does not match the current batch of buffered rows and there are // no more rows to read from the buffered iterator, so there can be no more matches. @@ -1169,8 +1170,10 @@ private[joins] class SortMergeJoinInnerRangeScanner( // The streamed row's join key matches the current buffered row's join, so walk through the // buffered iterator to buffer the rest of the matching rows. assert(comp == 0) - bufferMatchingRows(true) - true + bufferedMatches.clear() + bufferMatchingRows() + if (bufferedMatches.isEmpty) matchJoinKey = null + ! bufferedMatches.isEmpty } } } @@ -1207,10 +1210,8 @@ private[joins] class SortMergeJoinInnerRangeScanner( if (!foundRow) { bufferedRow = null bufferedRowKey = null - false - } else { - true } + foundRow } /** @@ -1238,7 +1239,7 @@ private[joins] class SortMergeJoinInnerRangeScanner( /** * Called when the streamed and buffered join keys match in order to buffer the matching rows. 
*/ - private def bufferMatchingRows(clear: Boolean): Unit = { + private def bufferMatchingRows(): Unit = { assert(streamedRowKey != null) assert(!streamedRowKey.anyNull) assert(bufferedRowKey != null) @@ -1246,9 +1247,6 @@ private[joins] class SortMergeJoinInnerRangeScanner( assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) // This join key may have been produced by a mutable projection, so we need to make a copy: matchJoinKey = streamedRowKey.copy() - if (clear) { - bufferedMatches.clear() - } var upperRangeOk = false var lowerRangeOk = false do { @@ -1264,8 +1262,10 @@ private[joins] class SortMergeJoinInnerRangeScanner( } private def dequeueUntilLowerConditionHolds(): Unit = { - while (!bufferedMatches.isEmpty && !lowerRangeCondition(joinRow(streamedRow, bufferedRow))) { - bufferedMatches.dequeue() + if (streamedRow != null && bufferedRow != null) { + while (!bufferedMatches.isEmpty && !lowerRangeCondition(joinRow(streamedRow, bufferedRow))) { + bufferedMatches.dequeue() + } } } } From 68e00c0f77499a24fd674a008ff1197967e6f3df Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Fri, 27 Apr 2018 09:56:09 -0700 Subject: [PATCH 30/51] Switch off inner range optimization when whole stage codegen is off. --- .../scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 4b122d9c50479..63e62457910ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -145,7 +145,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { // and which can be used for secondary sort optimizations. val rangePreds: mutable.Set[Expression] = mutable.Set.empty var rangeConditions: Seq[BinaryComparison] = - if (SQLConf.get.useSmjInnerRangeOptimization) { + if (SQLConf.get.useSmjInnerRangeOptimization && SQLConf.get.wholeStageEnabled) { otherPredicates.flatMap { case p@LessThan(l, r) => isValidRangeCondition(l, r, left, right, joinKeys) match { case "asis" => rangePreds.add(p); Some(LessThan(l, r)) From f5b9ca892286adb5da0d27c6b1f024186cae673f Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Fri, 27 Apr 2018 10:02:21 -0700 Subject: [PATCH 31/51] SMJ inner range optimization benchmarks --- .../execution/benchmark/JoinBenchmark.scala | 55 +++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala index 5a25d72308370..5e19b33969e63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala @@ -222,6 +222,61 @@ class JoinBenchmark extends BenchmarkBase { */ } + val expensiveFunc = (first: Int, second: Int) => { + for (i <- 1 to 2000) { + Math.sqrt(i * i * i) + } + Math.abs(first - second) + } + + def innerRangeTest(N: Int, M: Int): Unit = { + import sparkSession.implicits._ + val expUdf = sparkSession.udf.register("expensiveFunc", expensiveFunc) + val df1 = sparkSession.sparkContext.parallelize(1 to M). + cartesian(sparkSession.sparkContext.parallelize(1 to N)). 
+ toDF("col1a", "col1b") + val df2 = sparkSession.sparkContext.parallelize(1 to M). + cartesian(sparkSession.sparkContext.parallelize(1 to N)). + toDF("col2a", "col2b") + val df = df1.join(df2, 'col1a === 'col2a and ('col1b < 'col2b + 3) and ('col1b > 'col2b - 3)) + assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined) + df.where(expUdf('col1b, 'col2b) < 3).count() + } + + ignore("sort merge inner range join") { + sparkSession.conf.set("spark.sql.join.smj.useInnerRangeOptimization", "false") + val N = 2 << 5 + val M = 100 + runBenchmark("sort merge inner range join", N * M) { + innerRangeTest(N, M) + } + + /* + *AMD EPYC 7401 24-Core Processor + *sort merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + *--------------------------------------------------------------------------------------------- + *sort merge join wholestage off 13822 / 14068 0.0 2159662.3 1.0X + *sort merge join wholestage on 3863 / 4226 0.0 603547.0 3.6X + */ + } + + ignore("sort merge inner range join optimized") { + sparkSession.conf.set("spark.sql.join.smj.useInnerRangeOptimization", "true") + val N = 2 << 5 + val M = 100 + runBenchmark("sort merge inner range join optimized", N * M) { + innerRangeTest(N, M) + } + + /* + *AMD EPYC 7401 24-Core Processor + *sort merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + *--------------------------------------------------------------------------------------------- + *sort merge join wholestage off 12723 / 12800 0.0 1988008.4 1.0X + *sort merge join wholestage on 469 / 526 0.0 73340.4 27.1X + */ + } + ignore("shuffle hash join") { val N = 4 << 20 sparkSession.conf.set("spark.sql.shuffle.partitions", "2") From 7457ab3f7216fbaebcf54e1340b7a33f4cb60072 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Sat, 28 Apr 2018 12:27:55 -0700 Subject: [PATCH 32/51] Removing "expensive function" from the SMJ inner range optimization benchmark. --- .../execution/benchmark/JoinBenchmark.scala | 22 ++++++------------- 1 file changed, 7 insertions(+), 15 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala index 5e19b33969e63..907402bfa703f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala @@ -222,16 +222,8 @@ class JoinBenchmark extends BenchmarkBase { */ } - val expensiveFunc = (first: Int, second: Int) => { - for (i <- 1 to 2000) { - Math.sqrt(i * i * i) - } - Math.abs(first - second) - } - def innerRangeTest(N: Int, M: Int): Unit = { import sparkSession.implicits._ - val expUdf = sparkSession.udf.register("expensiveFunc", expensiveFunc) val df1 = sparkSession.sparkContext.parallelize(1 to M). cartesian(sparkSession.sparkContext.parallelize(1 to N)). 
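// (Here M sets the number of distinct equi-join keys and N the number of rows
// per key, so each side of the join below has M * N rows; the range predicate
// then keeps, per matching key, only pairs whose col1b and col2b values differ
// by at most 2.)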
toDF("col1a", "col1b") @@ -240,12 +232,12 @@ class JoinBenchmark extends BenchmarkBase { toDF("col2a", "col2b") val df = df1.join(df2, 'col1a === 'col2a and ('col1b < 'col2b + 3) and ('col1b > 'col2b - 3)) assert(df.queryExecution.sparkPlan.find(_.isInstanceOf[SortMergeJoinExec]).isDefined) - df.where(expUdf('col1b, 'col2b) < 3).count() + df.count() } ignore("sort merge inner range join") { sparkSession.conf.set("spark.sql.join.smj.useInnerRangeOptimization", "false") - val N = 2 << 5 + val N = 2 << 11 val M = 100 runBenchmark("sort merge inner range join", N * M) { innerRangeTest(N, M) @@ -255,14 +247,14 @@ class JoinBenchmark extends BenchmarkBase { *AMD EPYC 7401 24-Core Processor *sort merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative *--------------------------------------------------------------------------------------------- - *sort merge join wholestage off 13822 / 14068 0.0 2159662.3 1.0X - *sort merge join wholestage on 3863 / 4226 0.0 603547.0 3.6X + *sort merge join wholestage off 30956 / 31374 0.0 75575.5 1.0X + *sort merge join wholestage on 10864 / 11043 0.0 26523.6 2.8X */ } ignore("sort merge inner range join optimized") { sparkSession.conf.set("spark.sql.join.smj.useInnerRangeOptimization", "true") - val N = 2 << 5 + val N = 2 << 11 val M = 100 runBenchmark("sort merge inner range join optimized", N * M) { innerRangeTest(N, M) @@ -272,8 +264,8 @@ class JoinBenchmark extends BenchmarkBase { *AMD EPYC 7401 24-Core Processor *sort merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative *--------------------------------------------------------------------------------------------- - *sort merge join wholestage off 12723 / 12800 0.0 1988008.4 1.0X - *sort merge join wholestage on 469 / 526 0.0 73340.4 27.1X + *sort merge join wholestage off 30734 / 31135 0.0 75035.2 1.0X + *sort merge join wholestage on 959 / 1040 0.4 2341.3 32.0X */ } From c47c8cd5dbd8a667117d2e4a4031c2c9d77d7323 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Thu, 10 May 2018 07:46:39 +0200 Subject: [PATCH 33/51] SMJ inner range optimization with wholeStage codegen turned off - code + tests --- .../sql/catalyst/planning/patterns.scala | 3 +- .../execution/InMemoryUnsafeRowQueue.scala | 8 +- .../execution/joins/SortMergeJoinExec.scala | 119 ++++++----- .../sql/execution/joins/InnerJoinSuite.scala | 184 +++++++++--------- 4 files changed, 170 insertions(+), 144 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 63e62457910ae..8bdce8f1f0b68 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -109,7 +109,6 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { case join @ Join(left, right, joinType, condition) => - logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil) @@ -145,7 +144,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { // and which can be used for secondary sort optimizations. 
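// End-to-end usage sketch (assumes a SparkSession named spark and DataFrames t1
// and t2 with columns a and b; the flag name is the one introduced earlier in
// this series, everything else is illustrative):
//
//   spark.conf.set("spark.sql.join.smj.useInnerRangeOptimization", "true")
//   val joined = t1.join(t2, t1("a") === t2("a")
//     && t1("b") >= t2("b") - 3 && t1("b") <= t2("b") + 3)
//
// With the flag enabled, the two bound predicates are extracted below as
// rangeConditions rather than being left in the residual join condition.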
val rangePreds: mutable.Set[Expression] = mutable.Set.empty var rangeConditions: Seq[BinaryComparison] = - if (SQLConf.get.useSmjInnerRangeOptimization && SQLConf.get.wholeStageEnabled) { + if (SQLConf.get.useSmjInnerRangeOptimization) { // && SQLConf.get.wholeStageEnabled) { otherPredicates.flatMap { case p@LessThan(l, r) => isValidRangeCondition(l, r, left, right, joinKeys) match { case "asis" => rangePreds.add(p); Some(LessThan(l, r)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala index 7b557c380779a..b8a17e0763ac9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala @@ -83,6 +83,8 @@ private[sql] class InMemoryUnsafeRowQueue( // private var spillableArray: UnsafeExternalSorter = _ private var numRows = 0 + override def length: Int = numRows + override def isEmpty: Boolean = numRows == 0 // A counter to keep track of total modifications done to this array since its creation. @@ -91,10 +93,6 @@ private[sql] class InMemoryUnsafeRowQueue( private var numFieldsPerRow = 0 -// def length: Int = numRows -// -// def isEmpty: Boolean = numRows == 0 - /** * Clears up resources (eg. memory) held by the backing storage */ @@ -149,6 +147,8 @@ private[sql] class InMemoryUnsafeRowQueue( new InMemoryBufferIterator(startIndex) } + override def generateIterator(): Iterator[UnsafeRow] = generateIterator(startIndex = 0) + private[this] abstract class ExternalAppendOnlyUnsafeRowArrayIterator extends Iterator[UnsafeRow] { private val expectedModificationsCount = modificationsCount diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index d1a35263a483d..c405d8ef5d38f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -125,13 +125,8 @@ case class SortMergeJoinExec( // For inner join, orders of both sides keys should be kept. 
case _: InnerLike => - - logDebug(s"Left range ordering: ${leftKeys ++ lrKeys}") - logDebug(s"Right range ordering: ${rightKeys ++ rrKeys}") - val leftKeyOrdering = getKeyOrdering(leftKeys ++ lrKeys, left.outputOrdering) val rightKeyOrdering = getKeyOrdering(rightKeys ++ rrKeys, right.outputOrdering) - logDebug(s"outputOrdering results: $leftKeyOrdering and $rightKeyOrdering") leftKeyOrdering.zip(rightKeyOrdering).map { case (lKey, rKey) => // Also add the right key and its `sameOrderExpressions` SortOrder(lKey.child, Ascending, lKey.sameOrderExpressions + rKey.child ++ rKey @@ -228,7 +223,7 @@ case class SortMergeJoinExec( private[this] var currentRightMatches: ExternalAppendOnlyUnsafeRowArray = _ private[this] var rightMatchesIterator: Iterator[UnsafeRow] = null private[this] val smjScanner = - if (false) { // useInnerRange) { + if (useInnerRange) { new SortMergeJoinInnerRangeScanner( createLeftKeyGenerator(), createRightKeyGenerator(), @@ -283,7 +278,9 @@ case class SortMergeJoinExec( false } - override def getRow: InternalRow = resultProj(joinRow) + override def getRow: InternalRow = { + resultProj(joinRow) + } }.toScala case LeftOuter => @@ -590,8 +587,8 @@ case class SortMergeJoinExec( // Add secondary range dequeue method if (!useInnerRange || lowerSecondaryRangeExpression.isEmpty || rightLowerKeys.size == 0 || rightUpperKeys.size == 0) { - ctx.addNewFunction("dequeueUntilLowerConditionHolds", - "private void dequeueUntilLowerConditionHolds() { }", + ctx.addNewFunction("dequeueUntilUpperConditionHolds", + "private void dequeueUntilUpperConditionHolds() { }", inlineToOuterClass = true) } else { @@ -609,9 +606,9 @@ case class SortMergeJoinExec( |} """.stripMargin) - ctx.addNewFunction("dequeueUntilLowerConditionHolds", + ctx.addNewFunction("dequeueUntilUpperConditionHolds", s""" - |private void dequeueUntilLowerConditionHolds() { + |private void dequeueUntilUpperConditionHolds() { | if($matches.isEmpty()) | return; | $rightTmpRow = $matches.get(0); @@ -671,7 +668,7 @@ case class SortMergeJoinExec( | if (!$matches.isEmpty()) { | ${genComparison(ctx, leftKeyVars, matchedKeyVars)} | if (comp == 0) { - | dequeueUntilLowerConditionHolds(); + | dequeueUntilUpperConditionHolds(); | } | else { | $matches.clear(); @@ -1103,9 +1100,17 @@ private[joins] class SortMergeJoinInnerRangeScanner( private[this] val bufferedMatches = new InMemoryUnsafeRowQueue(inMemoryThreshold, spillThreshold) - private[this] val joinRow = new JoinedRow - // Initialization (note: do _not_ want to advance streamed here). - advanceBufferedToRowWithNullFreeJoinKey() + private[this] val testJoinRow = new JoinedRow + private[this] val joinedRow = new JoinedRow + private[this] var lowerConditionOk: Boolean = false + private[this] var upperConditionOk: Boolean = false + + // Already done in the superclass: + //advancedBufferedToRowWithNullFreeJoinKey() + bufferedRow = bufferedIter.getRow + if (bufferedRow != null) { + bufferedRowKey = bufferedKeyGenerator(bufferedRow) + } // --- Public methods --------------------------------------------------------------------------- @@ -1136,11 +1141,15 @@ private[joins] class SortMergeJoinInnerRangeScanner( // First dequeue all rows from the queue until the lower range condition holds. // Then try to enqueue new rows with the same join key and for which the upper // range condition holds. - bufferedRowKey = matchJoinKey - dequeueUntilLowerConditionHolds() - bufferMatchingRows() - if (bufferedMatches.isEmpty) matchJoinKey = null - ! 
bufferedMatches.isEmpty + dequeueUntilUpperConditionHolds() + if (bufferedRow != null) { + bufferMatchingRows() + } + if (bufferedMatches.isEmpty) { + matchJoinKey = null + findNextInnerJoinRows() + } + else true } else if (bufferedRow == null) { // The streamed row's join key does not match the current batch of buffered rows and there are // no more rows to read from the buffered iterator, so there can be no more matches. @@ -1149,31 +1158,33 @@ private[joins] class SortMergeJoinInnerRangeScanner( false } else { // Advance both the streamed and buffered iterators to find the next pair of matching rows. - var comp = -1 // keyOrdering.compare(streamedRowKey, bufferedRowKey) + var comp = -1 do { if (streamedRowKey.anyNull) { advanceStreamed() } else { assert(!bufferedRowKey.anyNull) comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) - if (comp > 0) advanceBufferedToRowWithNullFreeJoinKey() - else if (comp < 0) advanceStreamed() - else comp = checkLowerBoundAndAdvanceBuffered() + if (comp > 0) advancedBufferedToRowWithNullFreeJoinKey() + else if (comp < 0 || !lowerConditionOk) advanceStreamed() + else comp = checkBoundsAndAdvanceBuffered() } - } while (streamedRow != null && bufferedRow != null && comp != 0) + } while (streamedRow != null && bufferedRow != null && (comp != 0 || !lowerConditionOk)) + bufferedMatches.clear() if (streamedRow == null || bufferedRow == null) { // We have hit the end of one of the iterators, so there can be no more matches. matchJoinKey = null - bufferedMatches.clear() false } else { // The streamed row's join key matches the current buffered row's join, so walk through the // buffered iterator to buffer the rest of the matching rows. assert(comp == 0) - bufferedMatches.clear() bufferMatchingRows() - if (bufferedMatches.isEmpty) matchJoinKey = null - ! bufferedMatches.isEmpty + if (bufferedMatches.isEmpty) { + matchJoinKey = null + findNextInnerJoinRows() + } + else true } } } @@ -1188,6 +1199,7 @@ private[joins] class SortMergeJoinInnerRangeScanner( if (streamedIter.advanceNext()) { streamedRow = streamedIter.getRow streamedRowKey = streamedKeyGenerator(streamedRow) + updateJoinedRow() true } else { streamedRow = null @@ -1200,7 +1212,7 @@ private[joins] class SortMergeJoinInnerRangeScanner( * Advance the buffered iterator until we find a row with join key that does not contain nulls. * @return true if the buffered iterator returned a row and false otherwise. */ - private def advanceBufferedToRowWithNullFreeJoinKey(): Boolean = { + private def advancedBufferedToRowWithNullFreeJoinKey(): Boolean = { var foundRow: Boolean = false while (!foundRow && bufferedIter.advanceNext()) { bufferedRow = bufferedIter.getRow @@ -1211,26 +1223,34 @@ private[joins] class SortMergeJoinInnerRangeScanner( bufferedRow = null bufferedRowKey = null } + else { + updateJoinedRow() + } foundRow } + private def updateJoinedRow() = { + if (streamedRow != null && bufferedRow != null) { + joinedRow(streamedRow, bufferedRow) + lowerConditionOk = lowerRangeCondition(joinedRow) + upperConditionOk = upperRangeCondition(joinedRow) + } + } + /** * Advance the buffered iterator as long as the join key is the same and - * the lower range condition is not satisfied. + * the upper range condition is not satisfied. * Skip rows with nulls. * @return Result of the join key comparison. 
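* Example (an illustrative walk-through added for clarity): with streamed row
* (a=1, b=5) and a window of b in (3, 7), buffered rows (a=1, b=1) and (a=1, b=2)
* satisfy the lower condition but not the upper one, so they are skipped, and the
* scan stops at the first buffered row inside the window, e.g. (a=1, b=4).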
*/ - private def checkLowerBoundAndAdvanceBuffered(): Int = { + private def checkBoundsAndAdvanceBuffered(): Int = { assert(bufferedRow != null) assert(streamedRow != null) var comp = 0 - var lowCheck = lowerRangeCondition(joinRow(streamedRow, bufferedRow)) - if (!lowCheck) { - while (!lowCheck && comp == 0 && advanceBufferedToRowWithNullFreeJoinKey()) { + if (lowerConditionOk && !upperConditionOk) { + while (!upperConditionOk && lowerConditionOk && comp == 0 && + advancedBufferedToRowWithNullFreeJoinKey()) { comp = keyOrdering.compare(streamedRowKey, bufferedRowKey) - if (comp == 0) { - lowCheck = lowerRangeCondition(joinRow(streamedRow, bufferedRow)) - } } } comp @@ -1244,26 +1264,23 @@ private[joins] class SortMergeJoinInnerRangeScanner( assert(!streamedRowKey.anyNull) assert(bufferedRowKey != null) assert(!bufferedRowKey.anyNull) - assert(keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) + if (keyOrdering.compare(streamedRowKey, bufferedRowKey) != 0) { + return + } // This join key may have been produced by a mutable projection, so we need to make a copy: matchJoinKey = streamedRowKey.copy() - var upperRangeOk = false - var lowerRangeOk = false do { - val jr = joinRow(streamedRow, bufferedRow) - lowerRangeOk = lowerRangeCondition(jr) - upperRangeOk = upperRangeCondition(jr) - if (lowerRangeOk && upperRangeOk) { + if (lowerConditionOk && upperConditionOk) { bufferedMatches.add(bufferedRow.asInstanceOf[UnsafeRow]) } - advanceBufferedToRowWithNullFreeJoinKey() - } while (bufferedRow != null && keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0 - && upperRangeOk) + } while (lowerConditionOk && advancedBufferedToRowWithNullFreeJoinKey() && + keyOrdering.compare(streamedRowKey, bufferedRowKey) == 0) } - private def dequeueUntilLowerConditionHolds(): Unit = { - if (streamedRow != null && bufferedRow != null) { - while (!bufferedMatches.isEmpty && !lowerRangeCondition(joinRow(streamedRow, bufferedRow))) { + private def dequeueUntilUpperConditionHolds(): Unit = { + if (streamedRow != null) { + while (!bufferedMatches.isEmpty && + !upperRangeCondition(testJoinRow(streamedRow, bufferedMatches.get(0)))) { bufferedMatches.dequeue() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index b3d8d668c3b7e..08748ae52709b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -71,16 +71,16 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { ).toDF("a", "b") private lazy val rangeTestData1 = Seq( - (1, 1), (1, 2), (1, 3), - (2, 1), (2, 2), (2, 3), + (1, 3), (1, 4), (1, 7), (1, 8), (1, 10), + (2, 1), (2, 2), (2, 3), (2, 8), (3, 1), (3, 2), (3, 3), (3, 5), (4, 1), (4, 2), (4, 3) ).toDF("a", "b") private lazy val rangeTestData2 = Seq( - (1, 1), (1, 2), (1, 3), (1, 5), - (2, 1), (2, 2), (2, 3), - (3, 1), (3, 2), (3, 3) + (1, 1), (1, 2), (1, 2), (1, 3), (1, 5), (1, 7), (1, 20), + (2, 1), (2, 2), (2, 3), (2, 5), (2, 6), + (3, 3), (3, 6) ).toDF("a", "b") // Note: the input dataframes and expression must be evaluated lazily because @@ -142,84 +142,98 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin) } + val configOptions = List( + ("spark.sql.codegen.wholeStage", "true"), + ("spark.sql.codegen.wholeStage", "false")) + // Disabling these 
because the code would never follow this path in case of an inner range join
     if (!expectRangeJoin) {
-      test(s"$testName using BroadcastHashJoin (build=left)") {
-        extractJoinParts().foreach { case (_, leftKeys, rightKeys, _,
-          boundCondition, _, _) =>
-          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-            checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
-              makeBroadcastHashJoin(
-                leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
-              expectedAnswer.map(Row.fromTuple),
-              sortAnswers = true)
+      configOptions.foreach { case (config, confValue) =>
+        test(s"$testName using BroadcastHashJoin (build=left)") {
+          extractJoinParts().foreach { case (_, leftKeys, rightKeys, _,
+            boundCondition, _, _) =>
+            withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) {
+              checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+                makeBroadcastHashJoin(
+                  leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
+                expectedAnswer.map(Row.fromTuple),
+                sortAnswers = true)
+            }
           }
         }
       }
     }

     if(!expectRangeJoin) {
-      test(s"$testName using BroadcastHashJoin (build=right)") {
-        extractJoinParts().foreach { case (_, leftKeys, rightKeys, _,
-          boundCondition, _, _) =>
-          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-            checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
-              makeBroadcastHashJoin(
-                leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
-              expectedAnswer.map(Row.fromTuple),
-              sortAnswers = true)
+      configOptions.foreach { case (config, confValue) =>
+        test(s"$testName using BroadcastHashJoin (build=right)") {
+          extractJoinParts().foreach { case (_, leftKeys, rightKeys, _,
+            boundCondition, _, _) =>
+            withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) {
+              checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+                makeBroadcastHashJoin(
+                  leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
+                expectedAnswer.map(Row.fromTuple),
+                sortAnswers = true)
+            }
           }
         }
       }
     }

     if(!expectRangeJoin) {
-      test(s"$testName using ShuffledHashJoin (build=left)") {
-        extractJoinParts().foreach { case (_, leftKeys, rightKeys, _,
-          boundCondition, _, _) =>
-          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-            checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
-              makeShuffledHashJoin(
-                leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
-              expectedAnswer.map(Row.fromTuple),
-              sortAnswers = true)
+      configOptions.foreach { case (config, confValue) =>
+        test(s"$testName using ShuffledHashJoin (build=left)") {
+          extractJoinParts().foreach { case (_, leftKeys, rightKeys, _,
+            boundCondition, _, _) =>
+            withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) {
+              checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+                makeShuffledHashJoin(
+                  leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
+                expectedAnswer.map(Row.fromTuple),
+                sortAnswers = true)
+            }
           }
         }
       }
     }

     if(!expectRangeJoin) {
-      test(s"$testName using ShuffledHashJoin (build=right)") {
-        extractJoinParts().foreach { case (_, leftKeys, rightKeys, _,
-          boundCondition, _, _) =>
-          withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
-            checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
-              makeShuffledHashJoin(
-                leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
-
expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + configOptions.foreach { case (config, confValue) => + test(s"$testName using ShuffledHashJoin (build=right)") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, _, + boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeShuffledHashJoin( + leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } } } - test(s"$testName using SortMergeJoin") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, - boundCondition, _, _) => - assert(!expectRangeJoin && rangeConditions.isEmpty || - expectRangeJoin && rangeConditions.size == 2) - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => - makeSortMergeJoin(leftKeys, rightKeys, boundCondition, rangeConditions, - leftPlan, rightPlan), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + configOptions.foreach { case (config, confValue) => + test(s"$testName using SortMergeJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, + boundCondition, _, _) => + assert(!expectRangeJoin && rangeConditions.isEmpty || + expectRangeJoin && rangeConditions.size == 2) + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeSortMergeJoin(leftKeys, rightKeys, boundCondition, rangeConditions, + leftPlan, rightPlan), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } if (expectRangeJoin) { withSQLConf(SQLConf.USE_SMJ_INNER_RANGE_OPTIMIZATION.key -> "false") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, - boundCondition, _, _) => + boundCondition, _, _) => assert(rangeConditions.isEmpty) withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => @@ -233,31 +247,37 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } - test(s"$testName using CartesianProduct") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", - SQLConf.CROSS_JOINS_ENABLED.key -> "true") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - CartesianProductExec(left, right, Some(condition())), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + configOptions.foreach { case (config, confValue) => + test(s"$testName using CartesianProduct") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", + SQLConf.CROSS_JOINS_ENABLED.key -> "true", config -> confValue) { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + CartesianProductExec(left, right, Some(condition())), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } - test(s"$testName using BroadcastNestedLoopJoin build left") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoinExec(left, right, BuildLeft, Inner, Some(condition())), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + configOptions.foreach { case (config, confValue) => + test(s"$testName using BroadcastNestedLoopJoin build left") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) { + 
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoinExec(left, right, BuildLeft, Inner, Some(condition())), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } - test(s"$testName using BroadcastNestedLoopJoin build right") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoinExec(left, right, BuildRight, Inner, Some(condition())), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + configOptions.foreach { case (config, confValue) => + test(s"$testName using BroadcastNestedLoopJoin build right") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + BroadcastNestedLoopJoinExec(left, right, BuildRight, Inner, Some(condition())), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } } @@ -329,27 +349,17 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { () => ((left("a") === right("a")) and (left("b") <= right("b") + 1) and (left("b") >= right("b") - 1)).expr, Seq( - (1, 1, 1, 1), - (1, 1, 1, 2), - (1, 2, 1, 1), - (1, 2, 1, 2), - (1, 2, 1, 3), (1, 3, 1, 2), - (1, 3, 1, 3), - (2, 1, 2, 1), + (1, 3, 1, 2), + (1, 4, 1, 3), + (1, 4, 1, 5), + (1, 8, 1, 7), (2, 1, 2, 2), (2, 2, 2, 1), - (2, 2, 2, 2), (2, 2, 2, 3), (2, 3, 2, 2), - (2, 3, 2, 3), - (3, 1, 3, 1), - (3, 1, 3, 2), - (3, 2, 3, 1), - (3, 2, 3, 2), (3, 2, 3, 3), - (3, 3, 3, 2), - (3, 3, 3, 3) + (3, 5, 3, 6) ), true ) From 3fbedfc21bb0146cffdb9610c66e537da698bc7f Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Thu, 10 May 2018 21:55:21 +0200 Subject: [PATCH 34/51] Unit test fix. Benchmark results update. --- .../execution/joins/SortMergeJoinExec.scala | 2 +- .../execution/benchmark/JoinBenchmark.scala | 20 ++-- .../sql/execution/joins/InnerJoinSuite.scala | 110 +++++++++++------- 3 files changed, 77 insertions(+), 55 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index c405d8ef5d38f..806bee2d0edd1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -1106,7 +1106,7 @@ private[joins] class SortMergeJoinInnerRangeScanner( private[this] var upperConditionOk: Boolean = false // Already done in the superclass: - //advancedBufferedToRowWithNullFreeJoinKey() + // advancedBufferedToRowWithNullFreeJoinKey() bufferedRow = bufferedIter.getRow if (bufferedRow != null) { bufferedRowKey = bufferedKeyGenerator(bufferedRow) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala index 907402bfa703f..c6de3aa576bf0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala @@ -22,12 +22,12 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.IntegerType /** - * Benchmark to measure performance for aggregate primitives. - * To run this: - * build/sbt "sql/test-only *benchmark.JoinBenchmark" - * - * Benchmarks in this file are skipped in normal builds. 
- */ + * Benchmark to measure performance for aggregate primitives. + * To run this: + * build/sbt "sql/test-only *benchmark.JoinBenchmark" + * + * Benchmarks in this file are skipped in normal builds. + */ class JoinBenchmark extends BenchmarkBase { ignore("broadcast hash join, long key") { @@ -247,8 +247,8 @@ class JoinBenchmark extends BenchmarkBase { *AMD EPYC 7401 24-Core Processor *sort merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative *--------------------------------------------------------------------------------------------- - *sort merge join wholestage off 30956 / 31374 0.0 75575.5 1.0X - *sort merge join wholestage on 10864 / 11043 0.0 26523.6 2.8X + *sort merge join wholestage off 25226 / 25244 0.0 61585.9 1.0X + *sort merge join wholestage on 8581 / 8983 0.0 20948.6 2.9X */ } @@ -264,8 +264,8 @@ class JoinBenchmark extends BenchmarkBase { *AMD EPYC 7401 24-Core Processor *sort merge join: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative *--------------------------------------------------------------------------------------------- - *sort merge join wholestage off 30734 / 31135 0.0 75035.2 1.0X - *sort merge join wholestage on 959 / 1040 0.4 2341.3 32.0X + *sort merge join wholestage off 1194 / 1212 0.3 2915.2 1.0X + *sort merge join wholestage on 814 / 867 0.5 1988.4 1.5X */ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 08748ae52709b..bda7666a5a2e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -86,12 +86,12 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { // Note: the input dataframes and expression must be evaluated lazily because // the SQLContext should be used only within a test to keep SQL tests stable private def testInnerJoin( - testName: String, - leftRows: => DataFrame, - rightRows: => DataFrame, - condition: () => Expression, - expectedAnswer: Seq[Product], - expectRangeJoin: Boolean = false): Unit = { + testName: String, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: () => Expression, + expectedAnswer: Seq[Product], + expectRangeJoin: Boolean = false): Unit = { def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition())) @@ -99,12 +99,12 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } def makeBroadcastHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - boundCondition: Option[Expression], - leftPlan: SparkPlan, - rightPlan: SparkPlan, - side: BuildSide) = { + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { val broadcastJoin = joins.BroadcastHashJoinExec( leftKeys, rightKeys, @@ -117,12 +117,12 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } def makeShuffledHashJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - boundCondition: Option[Expression], - leftPlan: SparkPlan, - rightPlan: SparkPlan, - side: BuildSide) = { + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + leftPlan: SparkPlan, + rightPlan: SparkPlan, + side: BuildSide) = { val shuffledHashJoin = joins.ShuffledHashJoinExec(leftKeys, rightKeys, 
Inner, side, None, leftPlan, rightPlan) val filteredJoin = @@ -131,12 +131,12 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } def makeSortMergeJoin( - leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - boundCondition: Option[Expression], - rangeConditions: Seq[BinaryComparison], - leftPlan: SparkPlan, - rightPlan: SparkPlan) = { + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + boundCondition: Option[Expression], + rangeConditions: Seq[BinaryComparison], + leftPlan: SparkPlan, + rightPlan: SparkPlan) = { val sortMergeJoin = joins.SortMergeJoinExec(leftKeys, rightKeys, Inner, rangeConditions, boundCondition, leftPlan, rightPlan) EnsureRequirements(spark.sessionState.conf).apply(sortMergeJoin) @@ -148,8 +148,9 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { // Disabling these because the code would never follow this path in case of a inner range join if (!expectRangeJoin) { + var counter = 1 configOptions.foreach { case (config, confValue) => - test(s"$testName using BroadcastHashJoin (build=left)") { + test(s"$testName using BroadcastHashJoin (build=left) $counter") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, _, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) { @@ -161,12 +162,14 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } } + counter += 1 } } if(!expectRangeJoin) { + var counter = 1 configOptions.foreach { case (config, confValue) => - test(s"$testName using BroadcastHashJoin (build=right)") { + test(s"$testName using BroadcastHashJoin (build=right) $counter") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, _, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) { @@ -178,12 +181,14 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } } + counter += 1 } } if(!expectRangeJoin) { + var counter = 1 configOptions.foreach { case (config, confValue) => - test(s"$testName using ShuffledHashJoin (build=left)") { + test(s"$testName using ShuffledHashJoin (build=left) $counter") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, _, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) { @@ -195,12 +200,14 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } } + counter += 1 } } if(!expectRangeJoin) { + var counter = 1 configOptions.foreach { case (config, confValue) => - test(s"$testName using ShuffledHashJoin (build=right)") { + test(s"$testName using ShuffledHashJoin (build=right) $counter") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, _, boundCondition, _, _) => withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) { @@ -212,11 +219,13 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } } + counter += 1 } } + var counter = 1 configOptions.foreach { case (config, confValue) => - test(s"$testName using SortMergeJoin") { + test(s"$testName using SortMergeJoin $counter") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => assert(!expectRangeJoin && rangeConditions.isEmpty || @@ -229,26 +238,28 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { sortAnswers = true) } } - } - if (expectRangeJoin) { - withSQLConf(SQLConf.USE_SMJ_INNER_RANGE_OPTIMIZATION.key -> "false") { - extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, - boundCondition, _, _) => 
- assert(rangeConditions.isEmpty) - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => - makeSortMergeJoin(leftKeys, rightKeys, boundCondition, rangeConditions, - leftPlan, rightPlan), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) + if (expectRangeJoin) { + withSQLConf(SQLConf.USE_SMJ_INNER_RANGE_OPTIMIZATION.key -> "false") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, + boundCondition, _, _) => + assert(rangeConditions.isEmpty) + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => + makeSortMergeJoin(leftKeys, rightKeys, boundCondition, rangeConditions, + leftPlan, rightPlan), + expectedAnswer.map(Row.fromTuple), + sortAnswers = true) + } } } } } + counter += 1 } + counter = 1 configOptions.foreach { case (config, confValue) => - test(s"$testName using CartesianProduct") { + test(s"$testName using CartesianProduct $counter") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", SQLConf.CROSS_JOINS_ENABLED.key -> "true", config -> confValue) { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => @@ -257,10 +268,12 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { sortAnswers = true) } } + counter += 1 } + counter = 1 configOptions.foreach { case (config, confValue) => - test(s"$testName using BroadcastNestedLoopJoin build left") { + test(s"$testName using BroadcastNestedLoopJoin build left $counter") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => BroadcastNestedLoopJoinExec(left, right, BuildLeft, Inner, Some(condition())), @@ -268,10 +281,12 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { sortAnswers = true) } } + counter += 1 } + counter = 1 configOptions.foreach { case (config, confValue) => - test(s"$testName using BroadcastNestedLoopJoin build right") { + test(s"$testName using BroadcastNestedLoopJoin build right $counter") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => BroadcastNestedLoopJoinExec(left, right, BuildRight, Inner, Some(condition())), @@ -279,6 +294,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { sortAnswers = true) } } + counter += 1 } } @@ -351,14 +367,20 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { Seq( (1, 3, 1, 2), (1, 3, 1, 2), + (1, 3, 1, 3), (1, 4, 1, 3), (1, 4, 1, 5), + (1, 7, 1, 7), (1, 8, 1, 7), + (2, 1, 2, 1), (2, 1, 2, 2), (2, 2, 2, 1), + (2, 2, 2, 2), (2, 2, 2, 3), (2, 3, 2, 2), + (2, 3, 2, 3), (3, 2, 3, 3), + (3, 3, 3, 3), (3, 5, 3, 6) ), true From b8e1ee4e2913392d805735f7ee2b83954a5ba63f Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Thu, 10 May 2018 22:13:58 +0200 Subject: [PATCH 35/51] Scalastyle for comments --- .../sql/execution/benchmark/JoinBenchmark.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala index c6de3aa576bf0..3e24d4c4092ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/JoinBenchmark.scala @@ 
-22,12 +22,12 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.IntegerType /** - * Benchmark to measure performance for aggregate primitives. - * To run this: - * build/sbt "sql/test-only *benchmark.JoinBenchmark" - * - * Benchmarks in this file are skipped in normal builds. - */ + * Benchmark to measure performance for aggregate primitives. + * To run this: + * build/sbt "sql/test-only *benchmark.JoinBenchmark" + * + * Benchmarks in this file are skipped in normal builds. + */ class JoinBenchmark extends BenchmarkBase { ignore("broadcast hash join, long key") { From 27109571c4e892c10fff1cc6f9c1b1fec267e3da Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Tue, 15 May 2018 14:51:01 +0200 Subject: [PATCH 36/51] Code changes based on review comments. --- .../sql/catalyst/planning/patterns.scala | 84 +++++++++---------- .../execution/joins/SortMergeJoinExec.scala | 23 +++-- 2 files changed, 53 insertions(+), 54 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 8bdce8f1f0b68..1c9584a8f1ce8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -109,6 +109,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { def unapply(plan: LogicalPlan): Option[ReturnType] = plan match { case join @ Join(left, right, joinType, condition) => + logDebug(s"Considering join on: $condition") // Find equi-join predicates that can be evaluated before the join, and thus can be used // as join keys. val predicates = condition.map(splitConjunctivePredicates).getOrElse(Nil) @@ -136,42 +137,36 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { if (joinKeys.nonEmpty) { val (leftKeys, rightKeys) = joinKeys.unzip - // Find any simple range expressions between two columns - // (and involving only those two columns) - // of the two tables being joined, + // (and involving only those two columns) of the two tables being joined, // which are not used in the equijoin expressions, // and which can be used for secondary sort optimizations. + // rangePreds will contain the original expressions to be filtered out later. 
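+        // Illustrative example (not part of the original comment): for a query like
+        //   SELECT * FROM t1 JOIN t2 ON t1.a = t2.a
+        //     WHERE t1.b > t2.b - 3 AND t1.b < t2.b + 3
+        // the pair (t1.a, t2.a) is the equi-join key, both inequality predicates become
+        // range conditions, and both are recorded in rangePreds so that they are not
+        // evaluated a second time as part of the remaining join condition.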
val rangePreds: mutable.Set[Expression] = mutable.Set.empty var rangeConditions: Seq[BinaryComparison] = - if (SQLConf.get.useSmjInnerRangeOptimization) { // && SQLConf.get.wholeStageEnabled) { + if (SQLConf.get.useSmjInnerRangeOptimization) { otherPredicates.flatMap { - case p@LessThan(l, r) => isValidRangeCondition(l, r, left, right, joinKeys) match { - case "asis" => rangePreds.add(p); Some(LessThan(l, r)) - case "vs" => rangePreds.add(p); Some(GreaterThan(r, l)) - case _ => None + case p@LessThan(l, r) => checkRangeConditions(l, r, left, right, joinKeys).map { + case true => rangePreds.add(p); GreaterThan(r, l) + case false => rangePreds.add(p); p } case p@LessThanOrEqual(l, r) => - isValidRangeCondition(l, r, left, right, joinKeys) match { - case "asis" => rangePreds.add(p); Some(LessThanOrEqual(l, r)) - case "vs" => rangePreds.add(p); Some(GreaterThanOrEqual(r, l)) - case _ => None + checkRangeConditions(l, r, left, right, joinKeys).map { + case true => rangePreds.add(p); GreaterThanOrEqual(r, l) + case false => rangePreds.add(p); p } - case p@GreaterThan(l, r) => isValidRangeCondition(l, r, left, right, joinKeys) match { - case "asis" => rangePreds.add(p); Some(GreaterThan(l, r)) - case "vs" => rangePreds.add(p); Some(LessThan(r, l)) - case _ => None + case p@GreaterThan(l, r) => checkRangeConditions(l, r, left, right, joinKeys).map { + case true => rangePreds.add(p); LessThan(r, l) + case false => rangePreds.add(p); p } case p@GreaterThanOrEqual(l, r) => - isValidRangeCondition(l, r, left, right, joinKeys) match { - case "asis" => rangePreds.add(p); Some(GreaterThanOrEqual(l, r)) - case "vs" => rangePreds.add(p); Some(LessThanOrEqual(r, l)) - case _ => None + checkRangeConditions(l, r, left, right, joinKeys).map { + case true => rangePreds.add(p); LessThanOrEqual(r, l) + case false => rangePreds.add(p); p } case _ => None } - } - else { + } else { Nil } @@ -199,38 +194,43 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { case _ => None } - private def isValidRangeCondition(l : Expression, r : Expression, - left : LogicalPlan, right : LogicalPlan, - joinKeys : Seq[(Expression, Expression)]) = { + /** + * Checks if l and r are valid range conditions: + * - l and r expressions should both contain a single reference to one and the same column. + * - the referenced column should not be part of joinKeys + * If these conditions are not met, the function returns None. + * + * Otherwise, the function checks if the left plan contains l expression and the right plan + * contains r expression. If the expressions need to be switched, the function returns Some(true) + * and Some(false) otherwise. 
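+   *
+   * Illustrative example (hypothetical tables): for t1 JOIN t2 ON t1.a = t2.a, the
+   * predicate t1.b < t2.c yields Some(false) (no switch needed), the predicate
+   * t2.c > t1.b yields Some(true) (the sides must be switched), and a predicate over
+   * a join key column such as t1.a < t2.a yields None.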
+ */ + private def checkRangeConditions(l : Expression, r : Expression, + left : LogicalPlan, right : LogicalPlan, + joinKeys : Seq[(Expression, Expression)]) = { val (lattrs, rattrs) = (l.references.toSeq, r.references.toSeq) if(lattrs.size != 1 || rattrs.size != 1) { - "none" + None } else if (canEvaluate(l, left) && canEvaluate(r, right)) { - val equiset = joinKeys.filter{ case (ljk : Expression, rjk : Expression) => - ljk.references.toSeq.contains(lattrs(0)) && rjk.references.toSeq.contains(rattrs(0)) } - if (equiset.isEmpty) { - "asis" - } - else { - "none" + if (joinKeys.exists { case (ljk : Expression, rjk : Expression) => + ljk.references.toSeq.contains(lattrs(0)) && rjk.references.toSeq.contains(rattrs(0)) }) { + None + } else { + Some(false) } } else if (canEvaluate(l, right) && canEvaluate(r, left)) { - val equiset = joinKeys.filter{ case (ljk : Expression, rjk : Expression) => - rjk.references.toSeq.contains(lattrs(0)) && ljk.references.toSeq.contains(rattrs(0)) } - if(equiset.isEmpty) { - "vs" + if (joinKeys.exists{ case (ljk : Expression, rjk : Expression) => + rjk.references.toSeq.contains(lattrs(0)) && ljk.references.toSeq.contains(rattrs(0)) }) { + None } else { - "none" + Some(true) } - } - else { - "none" + } else { + None } } - } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 806bee2d0edd1..2932a7d7e03d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -588,8 +588,7 @@ case class SortMergeJoinExec( if (!useInnerRange || lowerSecondaryRangeExpression.isEmpty || rightLowerKeys.size == 0 || rightUpperKeys.size == 0) { ctx.addNewFunction("dequeueUntilUpperConditionHolds", - "private void dequeueUntilUpperConditionHolds() { }", - inlineToOuterClass = true) + "private void dequeueUntilUpperConditionHolds() { }") } else { val rightRngTmpKeyVars = createJoinKey(ctx, rightTmpRow, @@ -621,7 +620,7 @@ case class SortMergeJoinExec( | tempVal = getRightTmpRangeValue(); | } |} - """.stripMargin, inlineToOuterClass = true) + """.stripMargin) } val (leftLowVarsCode, leftUpperVarsCode) = if (useInnerRange) { (leftLowerKeyVars.map(_.code).mkString("\n"), leftUpperKeyVars.map(_.code).mkString("\n")) @@ -1075,15 +1074,15 @@ private[joins] class SortMergeJoinScanner( * @param spillThreshold Threshold for number of rows to be spilled by internal buffer */ private[joins] class SortMergeJoinInnerRangeScanner( - streamedKeyGenerator: Projection, - bufferedKeyGenerator: Projection, - keyOrdering: Ordering[InternalRow], - streamedIter: RowIterator, - bufferedIter: RowIterator, - inMemoryThreshold: Int, - spillThreshold: Int, - lowerRangeCondition: InternalRow => Boolean, - upperRangeCondition: InternalRow => Boolean) + streamedKeyGenerator: Projection, + bufferedKeyGenerator: Projection, + keyOrdering: Ordering[InternalRow], + streamedIter: RowIterator, + bufferedIter: RowIterator, + inMemoryThreshold: Int, + spillThreshold: Int, + lowerRangeCondition: InternalRow => Boolean, + upperRangeCondition: InternalRow => Boolean) extends SortMergeJoinScanner(streamedKeyGenerator, bufferedKeyGenerator, keyOrdering, streamedIter, bufferedIter, inMemoryThreshold, spillThreshold) { private[this] var streamedRow: InternalRow = _ From 52f2b70058f26211dd8436edf7e958a32dd68e3f Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: 
Thu, 7 Jun 2018 16:27:58 +0200 Subject: [PATCH 37/51] Code review changes --- .../sql/catalyst/planning/patterns.scala | 25 +++++------ .../execution/InMemoryUnsafeRowQueue.scala | 1 - .../sql/execution/joins/InnerJoinSuite.scala | 44 +++++-------------- 3 files changed, 23 insertions(+), 47 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 1c9584a8f1ce8..7288d11d982f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -142,7 +142,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { // which are not used in the equijoin expressions, // and which can be used for secondary sort optimizations. // rangePreds will contain the original expressions to be filtered out later. - val rangePreds: mutable.Set[Expression] = mutable.Set.empty + val rangePreds = mutable.Set.empty[Expression] var rangeConditions: Seq[BinaryComparison] = if (SQLConf.get.useSmjInnerRangeOptimization) { otherPredicates.flatMap { @@ -172,12 +172,12 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { // Only using secondary join optimization when both lower and upper conditions // are specified (e.g. t1.a < t2.b + x and t1.a > t2.b - x) - if(rangeConditions.size != 2 || + if (rangeConditions.size != 2 || // Looking for one < and one > comparison: - rangeConditions.filter(x => x.isInstanceOf[LessThan] || - x.isInstanceOf[LessThanOrEqual]).size == 0 || - rangeConditions.filter(x => x.isInstanceOf[GreaterThan] || - x.isInstanceOf[GreaterThanOrEqual]).size == 0 || + rangeConditions.forall(x => !x.isInstanceOf[LessThan] && + !x.isInstanceOf[LessThanOrEqual]) || + rangeConditions.forall(x => !x.isInstanceOf[GreaterThan] && + !x.isInstanceOf[GreaterThanOrEqual]) || // Check if both comparisons reference the same columns: rangeConditions.flatMap(c => c.left.references.toSeq.distinct).distinct.size != 1 || rangeConditions.flatMap(c => c.right.references.toSeq.distinct).distinct.size != 1) { @@ -206,25 +206,22 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { */ private def checkRangeConditions(l : Expression, r : Expression, left : LogicalPlan, right : LogicalPlan, - joinKeys : Seq[(Expression, Expression)]) = { + joinKeys : Seq[(Expression, Expression)]):Option[Boolean] = { val (lattrs, rattrs) = (l.references.toSeq, r.references.toSeq) - if(lattrs.size != 1 || rattrs.size != 1) { + if (lattrs.size != 1 || rattrs.size != 1) { None - } - else if (canEvaluate(l, left) && canEvaluate(r, right)) { + } else if (canEvaluate(l, left) && canEvaluate(r, right)) { if (joinKeys.exists { case (ljk : Expression, rjk : Expression) => ljk.references.toSeq.contains(lattrs(0)) && rjk.references.toSeq.contains(rattrs(0)) }) { None } else { Some(false) } - } - else if (canEvaluate(l, right) && canEvaluate(r, left)) { + } else if (canEvaluate(l, right) && canEvaluate(r, left)) { if (joinKeys.exists{ case (ljk : Expression, rjk : Expression) => rjk.references.toSeq.contains(lattrs(0)) && ljk.references.toSeq.contains(rattrs(0)) }) { None - } - else { + } else { Some(true) } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala index b8a17e0763ac9..dd691c38664ca 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala
@@ -80,7 +80,6 @@ private[sql] class InMemoryUnsafeRowQueue(
     null
   }

-//  private var spillableArray: UnsafeExternalSorter = _
   private var numRows = 0

   override def length: Int = numRows
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index bda7666a5a2e2..6ec98b50f785a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -85,8 +85,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {

   // Note: the input dataframes and expression must be evaluated lazily because
   // the SQLContext should be used only within a test to keep SQL tests stable
-  private def testInnerJoin(
-      testName: String,
+  private def testInnerJoin(testName: String,
       leftRows: => DataFrame,
       rightRows: => DataFrame,
       condition: () => Expression,
@@ -98,8 +97,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       ExtractEquiJoinKeys.unapply(join)
     }

-    def makeBroadcastHashJoin(
-        leftKeys: Seq[Expression],
+    def makeBroadcastHashJoin(leftKeys: Seq[Expression],
         rightKeys: Seq[Expression],
         boundCondition: Option[Expression],
         leftPlan: SparkPlan,
@@ -116,8 +114,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       EnsureRequirements(spark.sessionState.conf).apply(broadcastJoin)
     }

-    def makeShuffledHashJoin(
-        leftKeys: Seq[Expression],
+    def makeShuffledHashJoin(leftKeys: Seq[Expression],
         rightKeys: Seq[Expression],
         boundCondition: Option[Expression],
         leftPlan: SparkPlan,
@@ -130,8 +127,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
       EnsureRequirements(spark.sessionState.conf).apply(filteredJoin)
     }

-    def makeSortMergeJoin(
-        leftKeys: Seq[Expression],
+    def makeSortMergeJoin(leftKeys: Seq[Expression],
         rightKeys: Seq[Expression],
         boundCondition: Option[Expression],
         rangeConditions: Seq[BinaryComparison],
@@ -148,8 +144,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {

     // Disabling these because the code would never follow this path in case of an inner range join
     if (!expectRangeJoin) {
-      var counter = 1
-      configOptions.foreach { case (config, confValue) =>
+      configOptions.zipWithIndex.foreach { case ((config, confValue), counter) =>
         test(s"$testName using BroadcastHashJoin (build=left) $counter") {
           extractJoinParts().foreach { case (_, leftKeys,
rightKeys, _, boundCondition, _, _) => @@ -200,13 +191,11 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } } - counter += 1 } } if(!expectRangeJoin) { - var counter = 1 - configOptions.foreach { case (config, confValue) => + configOptions.zipWithIndex.foreach { case ((config, confValue), counter) => test(s"$testName using ShuffledHashJoin (build=right) $counter") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, _, boundCondition, _, _) => @@ -219,12 +208,10 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } } - counter += 1 } } - var counter = 1 - configOptions.foreach { case (config, confValue) => + configOptions.zipWithIndex.foreach { case ((config, confValue), counter) => test(s"$testName using SortMergeJoin $counter") { extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => @@ -254,11 +241,9 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { } } } - counter += 1 } - counter = 1 - configOptions.foreach { case (config, confValue) => + configOptions.zipWithIndex.foreach { case ((config, confValue), counter) => test(s"$testName using CartesianProduct $counter") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", SQLConf.CROSS_JOINS_ENABLED.key -> "true", config -> confValue) { @@ -268,11 +253,9 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { sortAnswers = true) } } - counter += 1 } - counter = 1 - configOptions.foreach { case (config, confValue) => + configOptions.zipWithIndex.foreach { case ((config, confValue), counter) => test(s"$testName using BroadcastNestedLoopJoin build left $counter") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => @@ -281,11 +264,9 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { sortAnswers = true) } } - counter += 1 } - counter = 1 - configOptions.foreach { case (config, confValue) => + configOptions.zipWithIndex.foreach { case ((config, confValue), counter) => test(s"$testName using BroadcastNestedLoopJoin build right $counter") { withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) { checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => @@ -294,7 +275,6 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { sortAnswers = true) } } - counter += 1 } } From 169bd709781b07d36656f3f26b7bec2dd26d9196 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Thu, 7 Jun 2018 20:15:39 +0200 Subject: [PATCH 38/51] Removing exception when numRowsInMemoryBufferThreshold is reached in InMemoryUnsafeRowQueue class. 
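
Behavior sketch (illustrative only; the threshold value and the someUnsafeRow helper
are made up for the example):

    val queue = new InMemoryUnsafeRowQueue(...)  // numRowsInMemoryBufferThreshold = 3
    (1 to 4).foreach(i => queue.add(someUnsafeRow(i)))
    // before this patch: the fourth add() threw
    //   RuntimeException("Reached spill threshold of 3 rows")
    // after this patch: the fourth add() logs a warning, the row is still
    // appended in memory, and queue.length == 4

The queue keeps growing in memory past the threshold, so the warning is the only
signal that the configured limit was exceeded.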
--- .../execution/InMemoryUnsafeRowQueue.scala | 22 +++++-------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala index dd691c38664ca..b7be8e42b290a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala @@ -30,16 +30,9 @@ import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultIn import org.apache.spark.storage.BlockManager /** - * An append-only array for [[UnsafeRow]]s that strictly keeps content in an in-memory array - * until [[numRowsInMemoryBufferThreshold]] is reached post which it will switch to a mode which - * would flush to disk after [[numRowsSpillThreshold]] is met (or before if there is - * excessive memory consumption). Setting these threshold involves following trade-offs: - * - * - If [[numRowsInMemoryBufferThreshold]] is too high, the in-memory array may occupy more memory - * than is available, resulting in OOM. - * - If [[numRowsSpillThreshold]] is too low, data will be spilled frequently and lead to - * excessive disk writes. This may lead to a performance regression compared to the normal case - * of using an [[ArrayBuffer]] or [[Array]]. + * A queue which implements a moving window used for sort-merge join inner range optimization. + * Unlike [[ExternalAppendOnlyUnsafeRowArray]] this class currently does not spill over to disk. + * In case [[numRowsInMemoryBufferThreshold]] is reached, only a warning will be logged. */ private[sql] class InMemoryUnsafeRowQueue( taskMemoryManager: TaskMemoryManager, @@ -119,11 +112,10 @@ private[sql] class InMemoryUnsafeRowQueue( } override def add(unsafeRow: UnsafeRow): Unit = { - if (numRows < numRowsInMemoryBufferThreshold) { - inMemoryQueue += unsafeRow.copy() - } else { - throw new RuntimeException(s"Reached spill threshold of $numRowsInMemoryBufferThreshold rows") + if (numRows >= numRowsInMemoryBufferThreshold) { + logWarning(s"Reached spill threshold of $numRowsInMemoryBufferThreshold rows") } + inMemoryQueue += unsafeRow.copy() numRows += 1 modificationsCount += 1 @@ -178,5 +170,3 @@ private[sql] class InMemoryUnsafeRowQueue( } } } - - From 89169dec6d812f49c325a0906404595f015f49ed Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Thu, 7 Jun 2018 20:38:26 +0200 Subject: [PATCH 39/51] Scala style --- .../scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 7288d11d982f4..922668d7100a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -206,7 +206,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { */ private def checkRangeConditions(l : Expression, r : Expression, left : LogicalPlan, right : LogicalPlan, - joinKeys : Seq[(Expression, Expression)]):Option[Boolean] = { + joinKeys : Seq[(Expression, Expression)]): Option[Boolean] = { val (lattrs, rattrs) = (l.references.toSeq, r.references.toSeq) if (lattrs.size != 1 || rattrs.size != 1) { None From eeaf048d6b5bbed7c772c844ae6f4535136a5866 Mon Sep 17 
00:00:00 2001 From: Petar Zecevic Date: Wed, 13 Jun 2018 10:44:18 +0200 Subject: [PATCH 40/51] Unneeded import --- .../org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala index b7be8e42b290a..773a222b4c86c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution import java.util.ConcurrentModificationException import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.memory.TaskMemoryManager From 75ce55d2e880fa7e76f94c1487ca21449ef317c7 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Wed, 13 Jun 2018 18:43:16 +0200 Subject: [PATCH 41/51] A dot --- .../scala/org/apache/spark/sql/catalyst/planning/patterns.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 922668d7100a3..5e9ba6c60d242 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -196,7 +196,7 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { /** * Checks if l and r are valid range conditions: - * - l and r expressions should both contain a single reference to one and the same column. + * - l and r expressions should both contain a single reference to one and the same column * - the referenced column should not be part of joinKeys * If these conditions are not met, the function returns None. * From 77dd2a8002ae9b5d5e8dc5b8bb28b2c4cd45db86 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Thu, 14 Jun 2018 09:05:07 +0200 Subject: [PATCH 42/51] A dot --- .../org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala index 773a222b4c86c..a82daa66e5401 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.BlockManager /** * A queue which implements a moving window used for sort-merge join inner range optimization. * Unlike [[ExternalAppendOnlyUnsafeRowArray]] this class currently does not spill over to disk. - * In case [[numRowsInMemoryBufferThreshold]] is reached, only a warning will be logged. 
+ * In case [[numRowsInMemoryBufferThreshold]] is reached, only a warning will be logged */ private[sql] class InMemoryUnsafeRowQueue( taskMemoryManager: TaskMemoryManager, From eac81b46d7cc3264499a1779062917ca337be6e5 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Fri, 15 Jun 2018 10:19:49 +0200 Subject: [PATCH 43/51] A dot --- .../org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala index a82daa66e5401..773a222b4c86c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.BlockManager /** * A queue which implements a moving window used for sort-merge join inner range optimization. * Unlike [[ExternalAppendOnlyUnsafeRowArray]] this class currently does not spill over to disk. - * In case [[numRowsInMemoryBufferThreshold]] is reached, only a warning will be logged + * In case [[numRowsInMemoryBufferThreshold]] is reached, only a warning will be logged. */ private[sql] class InMemoryUnsafeRowQueue( taskMemoryManager: TaskMemoryManager, From 1abde55994720b32e83dabc4d37d2d3254853ed0 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Mon, 18 Jun 2018 10:01:41 +0200 Subject: [PATCH 44/51] A dot --- .../org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala index 773a222b4c86c..a82daa66e5401 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.BlockManager /** * A queue which implements a moving window used for sort-merge join inner range optimization. * Unlike [[ExternalAppendOnlyUnsafeRowArray]] this class currently does not spill over to disk. - * In case [[numRowsInMemoryBufferThreshold]] is reached, only a warning will be logged. + * In case [[numRowsInMemoryBufferThreshold]] is reached, only a warning will be logged */ private[sql] class InMemoryUnsafeRowQueue( taskMemoryManager: TaskMemoryManager, From dfb4c0fd3d77eccc1761a825658481aff80b09ae Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Tue, 19 Jun 2018 15:50:18 +0200 Subject: [PATCH 45/51] A dot --- .../org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala index a82daa66e5401..773a222b4c86c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala @@ -31,7 +31,7 @@ import org.apache.spark.storage.BlockManager /** * A queue which implements a moving window used for sort-merge join inner range optimization. * Unlike [[ExternalAppendOnlyUnsafeRowArray]] this class currently does not spill over to disk. 
- * In case [[numRowsInMemoryBufferThreshold]] is reached, only a warning will be logged + * In case [[numRowsInMemoryBufferThreshold]] is reached, only a warning will be logged. */ private[sql] class InMemoryUnsafeRowQueue( taskMemoryManager: TaskMemoryManager, From 6d4c031d81616e36914e627420ed8c84f401173e Mon Sep 17 00:00:00 2001 From: Colin Slater Date: Thu, 28 Jun 2018 16:06:51 -0700 Subject: [PATCH 46/51] Fixes for some rebase issues. --- .../org/apache/spark/sql/catalyst/planning/patterns.scala | 1 - .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 4 +--- .../spark/sql/execution/exchange/EnsureRequirements.scala | 1 + .../apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 1 - 4 files changed, 2 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 5e9ba6c60d242..241873bbfe706 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -21,7 +21,6 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 188b3f1f9fc3e..094598433b4cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1731,11 +1731,9 @@ class SQLConf extends Serializable with Logging { def sortBeforeRepartition: Boolean = getConf(SORT_BEFORE_REPARTITION) -<<<<<<< HEAD def topKSortFallbackThreshold: Int = getConf(TOP_K_SORT_FALLBACK_THRESHOLD) -======= + def useSmjInnerRangeOptimization: Boolean = getConf(USE_SMJ_INNER_RANGE_OPTIMIZATION) ->>>>>>> Parameter for turning off inner range optimization def fastHashAggregateRowMaxCapacityBit: Int = getConf(FAST_HASH_AGGREGATE_MAX_ROWS_CAPACITY_BIT) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 10cd5b2bbf485..be6b4b89f60e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -290,6 +290,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { reorderJoinKeys(leftKeys, rightKeys, left.outputPartitioning, right.outputPartitioning) SortMergeJoinExec(reorderedLeftKeys, reorderedRightKeys, joinType, rangeConditions, condition, left, right) + case other => other } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 2932a7d7e03d7..e80db0f932a41 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -24,7 +24,6 @@ import 
org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ From 6d9cd123380b35739f444528c45eb80810a12379 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Fri, 29 Jun 2018 09:49:22 +0200 Subject: [PATCH 47/51] Merge with upstream --- .../apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 4 ++-- .../scala/org/apache/spark/sql/execution/PlannerSuite.scala | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index e80db0f932a41..8463cf0f1a20a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -577,7 +577,7 @@ case class SortMergeJoinExec( else s" || (comp == 0 && ${leftLowerSecRangeKey.value} " + s"$lowerCompop ${rightLowerSecRangeKey.value})" val upperCompExp = if (!useInnerRange || upperSecondaryRangeExpression.isEmpty) "" - else s" || (comp == 0 && ${leftUpperSecRangeKey.value} " + + else s" || (comp == 0 && ${leftUpperSecRangeKey.value} " +36:12 s"$upperCompop ${rightUpperSecRangeKey.value})" logDebug(s"lowerCompExp: $lowerCompExp") @@ -593,7 +593,7 @@ case class SortMergeJoinExec( val rightRngTmpKeyVars = createJoinKey(ctx, rightTmpRow, rightUpperKeys.slice(0, 1), right.output) val rightRngTmpKeyVarsDecl = rightRngTmpKeyVars.map(_.code).mkString("\n") - rightRngTmpKeyVars.foreach(_.code = "") + rightRngTmpKeyVars.foreach(_.code = EmptyBlock) val javaType = CodeGenerator.javaType(rightLowerKeys(0).dataType) ctx.addNewFunction("getRightTmpRangeValue", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index dfde1031fbd79..8d712d21f0589 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -694,11 +694,11 @@ class PlannerSuite extends SharedSQLContext { val plan2 = DummySparkPlan(outputOrdering = Seq(orderingB), outputPartitioning = HashPartitioning(exprB :: Nil, 5)) val smjExec = SortMergeJoinExec( - exprA :: exprA :: Nil, exprB :: exprC :: Nil, Inner, None, plan1, plan2) + exprA :: exprA :: Nil, exprB :: exprC :: Nil, Inner, Nil, None, plan1, plan2) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(smjExec) outputPlan match { - case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _) => + case SortMergeJoinExec(leftKeys, rightKeys, _, _, _, _, _) => assert(leftKeys == Seq(exprA, exprA)) assert(rightKeys == Seq(exprB, exprC)) case _ => fail() From 7742c104eee9a56331c6e8d607e6756ebdccba87 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Fri, 29 Jun 2018 11:22:37 +0200 Subject: [PATCH 48/51] Merge with upstream --- .../apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 8463cf0f1a20a..0e0f4612390a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -577,7 +577,7 @@ case class SortMergeJoinExec( else s" || (comp == 0 && ${leftLowerSecRangeKey.value} " + s"$lowerCompop ${rightLowerSecRangeKey.value})" val upperCompExp = if (!useInnerRange || upperSecondaryRangeExpression.isEmpty) "" - else s" || (comp == 0 && ${leftUpperSecRangeKey.value} " +36:12 + else s" || (comp == 0 && ${leftUpperSecRangeKey.value} " + s"$upperCompop ${rightUpperSecRangeKey.value})" logDebug(s"lowerCompExp: $lowerCompExp") From 3a717eef1ab733a5c58ba18997723a14f1929937 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Fri, 10 Aug 2018 09:21:35 +0200 Subject: [PATCH 49/51] SMJ inner range spill over implementation and tests --- .../ExternalAppendOnlyUnsafeRowArray.scala | 98 +++- .../execution/InMemoryUnsafeRowQueue.scala | 171 ------ .../execution/joins/SortMergeJoinExec.scala | 30 +- .../org/apache/spark/sql/JoinSuite.scala | 25 + ...xternalAppendOnlyUnsafeRowArraySuite.scala | 515 ++++++++++-------- .../sql/execution/joins/InnerJoinSuite.scala | 2 +- .../apache/spark/sql/test/SQLTestData.scala | 24 + 7 files changed, 440 insertions(+), 425 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala index ac282ea2e94f5..999f83db93311 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import java.util.ConcurrentModificationException -import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.{ArrayBuffer, Queue} import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.internal.Logging @@ -41,12 +41,16 @@ import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, Unsaf * - If [[numRowsSpillThreshold]] is too low, data will be spilled frequently and lead to * excessive disk writes. This may lead to a performance regression compared to the normal case * of using an [[ArrayBuffer]] or [[Array]]. + * + * If [[asQueue]] is set to true, the class will function as a queue, supporting peek() and + * dequeue() operations. 
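+ *
+ * A minimal usage sketch of the queue mode (illustrative only; the threshold values
+ * are placeholders):
+ * {{{
+ *   // the third constructor argument enables the queue behaviour
+ *   val window = new ExternalAppendOnlyUnsafeRowArray(bufferThreshold, spillThreshold, true)
+ *   window.add(row)                  // enqueues a copy of the row
+ *   val head = window.peek()         // Some(row) without consuming it
+ *   val consumed = window.dequeue()  // Some(row), removing it from the window
+ * }}}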
*/ private[sql] class ExternalAppendOnlyUnsafeRowArray( taskMemoryManager: TaskMemoryManager, blockManager: BlockManager, serializerManager: SerializerManager, taskContext: TaskContext, + asQueue: Boolean, initialSize: Int, pageSizeBytes: Long, numRowsInMemoryBufferThreshold: Int, @@ -58,6 +62,20 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( SparkEnv.get.blockManager, SparkEnv.get.serializerManager, TaskContext.get(), + false, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes, + numRowsInMemoryBufferThreshold, + numRowsSpillThreshold) + } + + def this(numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int, asQueue: Boolean) { + this( + TaskContext.get().taskMemoryManager(), + SparkEnv.get.blockManager, + SparkEnv.get.serializerManager, + TaskContext.get(), + asQueue, 1024, SparkEnv.get.memoryManager.pageSizeBytes, numRowsInMemoryBufferThreshold, @@ -67,7 +85,13 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( private val initialSizeOfInMemoryBuffer = Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsInMemoryBufferThreshold) - private val inMemoryBuffer = if (initialSizeOfInMemoryBuffer > 0) { + private val inMemoryQueue = if (asQueue && initialSizeOfInMemoryBuffer > 0) { + new Queue[UnsafeRow]() + } else { + null + } + + private val inMemoryBuffer = if (!asQueue && initialSizeOfInMemoryBuffer > 0) { new ArrayBuffer[UnsafeRow](initialSizeOfInMemoryBuffer) } else { null @@ -76,6 +100,9 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( private var spillableArray: UnsafeExternalSorter = _ private var numRows = 0 + // Used when functioning as a queue to allow skipping 'dequeued' items + private var spillableArrayOffset = 0 + // A counter to keep track of total modifications done to this array since its creation. // This helps to invalidate iterators when there are changes done to the backing array. 
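Note on the queue mode just introduced: dequeue() never removes anything from a spilled [[UnsafeExternalSorter]]; it only decrements numRows and bumps spillableArrayOffset, and generateIterator() later starts from startIndex + spillableArrayOffset so already-dequeued rows are skipped. A minimal, self-contained Scala sketch of that idea (invented names, an illustration rather than the Spark class):

import scala.collection.mutable.ArrayBuffer

// Stand-in for the append-only spilled storage: entries are never removed;
// a head offset marks how many rows count as already dequeued.
class OffsetQueue[T] {
  private val backing = new ArrayBuffer[T]
  private var offset = 0

  def add(elem: T): Unit = backing += elem
  def length: Int = backing.length - offset
  def peek(): Option[T] = if (length == 0) None else Some(backing(offset))
  def dequeue(): Option[T] = {
    val head = peek()
    if (head.isDefined) offset += 1 // skip the row instead of deleting it
    head
  }
  def clear(): Unit = { backing.clear(); offset = 0 }
}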
private var modificationsCount: Long = 0 @@ -95,17 +122,60 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( // inside `UnsafeExternalSorter` spillableArray.cleanupResources() spillableArray = null + spillableArrayOffset = 0 } else if (inMemoryBuffer != null) { inMemoryBuffer.clear() + } else if (inMemoryQueue != null) { + inMemoryQueue.clear() } numFieldsPerRow = 0 numRows = 0 modificationsCount += 1 } + def dequeue(): Option[UnsafeRow] = { + if (!asQueue) { + throw new IllegalStateException("Not instantiated as a queue!") + } + if (numRows == 0) { + None + } + else if (spillableArray != null) { + val retval = Some(generateIterator().next) + numRows -= 1 + modificationsCount += 1 + spillableArrayOffset += 1 + retval + } + else { + numRows -= 1 + modificationsCount += 1 + Some(inMemoryQueue.dequeue()) + } + } + + def peek(): Option[UnsafeRow] = { + if (!asQueue) { + throw new IllegalStateException("Not instantiated as a queue!") + } + if (numRows == 0) { + None + } + else if (spillableArray != null) { + Some(generateIterator().next) + } + else { + Some(inMemoryQueue(0)) + } + } + def add(unsafeRow: UnsafeRow): Unit = { - if (numRows < numRowsInMemoryBufferThreshold) { - inMemoryBuffer += unsafeRow.copy() + if (spillableArray == null && numRows < numRowsInMemoryBufferThreshold) { + if (asQueue) { + inMemoryQueue += unsafeRow.copy() + } else { + inMemoryBuffer += unsafeRow.copy() + } } else { if (spillableArray == null) { logInfo(s"Reached spill threshold of $numRowsInMemoryBufferThreshold rows, switching to " + @@ -124,8 +194,21 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( numRowsSpillThreshold, false) + spillableArrayOffset = 0 + // populate with existing in-memory buffered rows - if (inMemoryBuffer != null) { + if (asQueue && inMemoryQueue != null) { + inMemoryQueue.foreach(existingUnsafeRow => + spillableArray.insertRecord( + existingUnsafeRow.getBaseObject, + existingUnsafeRow.getBaseOffset, + existingUnsafeRow.getSizeInBytes, + 0, + false) + ) + inMemoryQueue.clear() + } + if (!asQueue && inMemoryBuffer != null) { inMemoryBuffer.foreach(existingUnsafeRow => spillableArray.insertRecord( existingUnsafeRow.getBaseObject, @@ -168,7 +251,8 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( if (spillableArray == null) { new InMemoryBufferIterator(startIndex) } else { - new SpillableArrayIterator(spillableArray.getIterator(startIndex), numFieldsPerRow) + new SpillableArrayIterator(spillableArray.getIterator(startIndex + spillableArrayOffset), + numFieldsPerRow) } } @@ -198,7 +282,7 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( override def next(): UnsafeRow = { throwExceptionIfModified() - val result = inMemoryBuffer(currentIndex) + val result = if (asQueue) inMemoryQueue(currentIndex) else inMemoryBuffer(currentIndex) currentIndex += 1 result } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala deleted file mode 100644 index 773a222b4c86c..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/InMemoryUnsafeRowQueue.scala +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import java.util.ConcurrentModificationException - -import scala.collection.mutable - -import org.apache.spark.{SparkEnv, TaskContext} -import org.apache.spark.memory.TaskMemoryManager -import org.apache.spark.serializer.SerializerManager -import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.execution.ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer -import org.apache.spark.storage.BlockManager - -/** - * A queue which implements a moving window used for sort-merge join inner range optimization. - * Unlike [[ExternalAppendOnlyUnsafeRowArray]] this class currently does not spill over to disk. - * In case [[numRowsInMemoryBufferThreshold]] is reached, only a warning will be logged. - */ -private[sql] class InMemoryUnsafeRowQueue( - taskMemoryManager: TaskMemoryManager, - blockManager: BlockManager, - serializerManager: SerializerManager, - taskContext: TaskContext, - initialSize: Int, - pageSizeBytes: Long, - numRowsInMemoryBufferThreshold: Int, - numRowsSpillThreshold: Int) - extends ExternalAppendOnlyUnsafeRowArray(taskMemoryManager, - blockManager, - serializerManager, - taskContext, - initialSize, - pageSizeBytes, - numRowsInMemoryBufferThreshold, - numRowsSpillThreshold) { - - def this(numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int) { - this( - TaskContext.get().taskMemoryManager(), - SparkEnv.get.blockManager, - SparkEnv.get.serializerManager, - TaskContext.get(), - 1024, - SparkEnv.get.memoryManager.pageSizeBytes, - numRowsInMemoryBufferThreshold, - numRowsSpillThreshold) - } - - private val initialSizeOfInMemoryBuffer = - Math.min(DefaultInitialSizeOfInMemoryBuffer, numRowsInMemoryBufferThreshold) - - private val inMemoryQueue = if (initialSizeOfInMemoryBuffer > 0) { - new mutable.Queue[UnsafeRow]() - } else { - null - } - - private var numRows = 0 - - override def length: Int = numRows - - override def isEmpty: Boolean = numRows == 0 - - // A counter to keep track of total modifications done to this array since its creation. - // This helps to invalidate iterators when there are changes done to the backing array. - private var modificationsCount: Long = 0 - - private var numFieldsPerRow = 0 - - /** - * Clears up resources (eg. 
memory) held by the backing storage - */ - override def clear(): Unit = { - if (inMemoryQueue != null) { - inMemoryQueue.clear() - } - numFieldsPerRow = 0 - numRows = 0 - modificationsCount += 1 - } - - def dequeue(): Option[UnsafeRow] = { - if (numRows == 0) { - None - } - else { - numRows -= 1 - Some(inMemoryQueue.dequeue()) - } - } - - def get(idx: Int): UnsafeRow = { - inMemoryQueue(idx) - } - - override def add(unsafeRow: UnsafeRow): Unit = { - if (numRows >= numRowsInMemoryBufferThreshold) { - logWarning(s"Reached spill threshold of $numRowsInMemoryBufferThreshold rows") - } - inMemoryQueue += unsafeRow.copy() - - numRows += 1 - modificationsCount += 1 - } - - /** - * Creates an [[Iterator]] for the current rows in the array starting from a user provided index - * - * If there are subsequent [[add()]] or [[clear()]] calls made on this array after creation of - * the iterator, then the iterator is invalidated thus saving clients from thinking that they - * have read all the data while there were new rows added to this array. - */ - override def generateIterator(startIndex: Int): Iterator[UnsafeRow] = { - if (startIndex < 0 || (numRows > 0 && startIndex > numRows)) { - throw new ArrayIndexOutOfBoundsException( - "Invalid `startIndex` provided for generating iterator over the array. " + - s"Total elements: $numRows, requested `startIndex`: $startIndex") - } - - new InMemoryBufferIterator(startIndex) - } - - override def generateIterator(): Iterator[UnsafeRow] = generateIterator(startIndex = 0) - - private[this] - abstract class ExternalAppendOnlyUnsafeRowArrayIterator extends Iterator[UnsafeRow] { - private val expectedModificationsCount = modificationsCount - - protected def isModified(): Boolean = expectedModificationsCount != modificationsCount - - protected def throwExceptionIfModified(): Unit = { - if (expectedModificationsCount != modificationsCount) { - throw new ConcurrentModificationException( - s"The backing ${classOf[ExternalAppendOnlyUnsafeRowArray].getName} has been modified " + - s"since the creation of this Iterator") - } - } - } - - private[this] class InMemoryBufferIterator(startIndex: Int) - extends ExternalAppendOnlyUnsafeRowArrayIterator { - - private var currentIndex = startIndex - - override def hasNext(): Boolean = !isModified() && currentIndex < numRows - - override def next(): UnsafeRow = { - throwExceptionIfModified() - val result = inMemoryQueue(currentIndex) - currentIndex += 1 - result - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 0e0f4612390a6..ac3ab6843d45b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -553,13 +553,14 @@ case class SortMergeJoinExec( } // A queue to hold all matched rows from right side. 
- val clsName = if (useInnerRange) classOf[InMemoryUnsafeRowQueue].getName - else classOf[ExternalAppendOnlyUnsafeRowArray].getName + val clsName = classOf[ExternalAppendOnlyUnsafeRowArray].getName val spillThreshold = getSpillThreshold val inMemoryThreshold = getInMemoryThreshold - val matches = ctx.addMutableState(clsName, "matches", + val matches = if (useInnerRange) ctx.addMutableState(clsName, "matches", + v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold, true);", forceInline = true) + else ctx.addMutableState(clsName, "matches", v => s"$v = new $clsName($inMemoryThreshold, $spillThreshold);", forceInline = true) val matchedKeyVars = copyKeys(ctx, leftKeyVars) @@ -607,15 +608,19 @@ case class SortMergeJoinExec( ctx.addNewFunction("dequeueUntilUpperConditionHolds", s""" |private void dequeueUntilUpperConditionHolds() { - | if($matches.isEmpty()) + | if($matches.isEmpty()) { + | $matches.clear(); | return; - | $rightTmpRow = $matches.get(0); + | } + | $rightTmpRow = (InternalRow) $matches.peek().get(); | $javaType tempVal = getRightTmpRangeValue(); | while(${leftLowerSecRangeKey.value} $upperCompop tempVal) { | $matches.dequeue(); - | if($matches.isEmpty()) + | if($matches.isEmpty()) { + | $matches.clear(); | break; - | $rightTmpRow = $matches.get(0); + | } + | $rightTmpRow = (InternalRow) $matches.peek().get(); | tempVal = getRightTmpRangeValue(); | } |} @@ -667,8 +672,7 @@ case class SortMergeJoinExec( | ${genComparison(ctx, leftKeyVars, matchedKeyVars)} | if (comp == 0) { | dequeueUntilUpperConditionHolds(); - | } - | else { + | } else { | $matches.clear(); | } | } @@ -698,6 +702,8 @@ case class SortMergeJoinExec( | if (!$matches.isEmpty()) { | ${matchedKeyVars.map(_.code).mkString("\n")} | return true; + | } else { + | $matches.clear(); | } | $leftRow = null; | } else { @@ -1096,7 +1102,7 @@ private[joins] class SortMergeJoinInnerRangeScanner( /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ private[this] val bufferedMatches = - new InMemoryUnsafeRowQueue(inMemoryThreshold, spillThreshold) + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold, true) private[this] val testJoinRow = new JoinedRow private[this] val joinedRow = new JoinedRow @@ -1114,7 +1120,7 @@ private[joins] class SortMergeJoinInnerRangeScanner( override def getStreamedRow: InternalRow = streamedRow - override def getBufferedMatches: InMemoryUnsafeRowQueue = bufferedMatches + override def getBufferedMatches: ExternalAppendOnlyUnsafeRowArray = bufferedMatches /** * Advances both input iterators, stopping when we have found rows with matching join keys. 
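Both the generated dequeueUntilUpperConditionHolds above and the interpreted SortMergeJoinInnerRangeScanner below maintain the same moving window over the buffered side: rows are enqueued while they satisfy the lower range bound for the current streamed row and dequeued from the front once they fail the upper bound. A simplified, self-contained sketch of that window for one equi-key group and a band condition d between (b - slack, b + slack), assuming both sides arrive sorted by the range key (which the secondary sort guarantees); the names here are illustrative, not the Spark implementation:

import scala.collection.mutable.Queue

// Joins one equi-key group: `lefts` are streamed-side range keys (b), `rights`
// are buffered-side range keys (d), both ascending. Emits all (b, d) pairs
// with b - slack <= d <= b + slack.
def innerRangeJoin(lefts: Seq[Int], rights: Seq[Int], slack: Int): Seq[(Int, Int)] = {
  val window = Queue.empty[Int]
  var ri = 0
  val out = Seq.newBuilder[(Int, Int)]
  for (b <- lefts) {
    // Enqueue buffered rows that already satisfy the lower bound: d <= b + slack.
    while (ri < rights.length && rights(ri) <= b + slack) {
      window.enqueue(rights(ri))
      ri += 1
    }
    // Dequeue buffered rows that can no longer satisfy the upper bound: d < b - slack.
    while (window.nonEmpty && window.head < b - slack) window.dequeue()
    window.foreach(d => out += ((b, d)))
  }
  out.result()
}

With the testData4/testData5 rows added later in this patch and slack = 1, the a = 1 group produces exactly the seven (b, d) pairs asserted for that key in the JoinSuite test below.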
@@ -1278,7 +1284,7 @@ private[joins] class SortMergeJoinInnerRangeScanner( private def dequeueUntilUpperConditionHolds(): Unit = { if (streamedRow != null) { while (!bufferedMatches.isEmpty && - !upperRangeCondition(testJoinRow(streamedRow, bufferedMatches.get(0)))) { + !upperRangeCondition(testJoinRow(streamedRow, bufferedMatches.peek.get))) { bufferedMatches.dequeue() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 44767dfc92497..499f869cf9020 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -780,6 +780,31 @@ class JoinSuite extends QueryTest with SharedSQLContext { } } + test("test SortMergeJoin inner range (with spill)") { + withSQLConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold" -> "1", + "spark.sql.sortMergeJoinExec.buffer.spill.threshold" -> "2") { + + val expected2 = new ListBuffer[Row]() + expected2.append( + Row(1, 3, 1, 2), Row(1, 3, 1, 2), + Row(1, 3, 1, 3), Row(1, 4, 1, 3), + Row(1, 4, 1, 5), Row(1, 7, 1, 7), + Row(1, 8, 1, 7), Row(2, 1, 2, 1), + Row(2, 1, 2, 2), Row(2, 2, 2, 1), + Row(2, 2, 2, 2), Row(2, 2, 2, 3), + Row(2, 3, 2, 2), Row(2, 3, 2, 3), + Row(3, 2, 3, 3), Row(3, 3, 3, 3), + Row(3, 5, 3, 6) + ) + assertSpilled(sparkContext, "inner range join") { + checkAnswer( + testData4.join(testData5, ('a === 'c ) and ('b between('d - 1, 'd + 1))), + expected2 + ) + } + } + } + test("outer broadcast hash join should not throw NPE") { withTempView("v1", "v2") { withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala index ecc7264d79442..bc39d2f3a5657 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala @@ -31,7 +31,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar override def afterAll(): Unit = TaskContext.unset() - private def withExternalArray(inMemoryThreshold: Int, spillThreshold: Int) + private def withExternalArray(inMemoryThreshold: Int, spillThreshold: Int, asQueue: Boolean) (f: ExternalAppendOnlyUnsafeRowArray => Unit): Unit = { sc = new SparkContext("local", "test", new SparkConf(false)) @@ -43,6 +43,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar SparkEnv.get.blockManager, SparkEnv.get.serializerManager, taskContext, + asQueue, 1024, SparkEnv.get.memoryManager.pageSizeBytes, inMemoryThreshold, @@ -110,265 +111,311 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar assert(getNumBytesSpilled > 0) } - test("insert rows less than the inMemoryThreshold") { - val (inMemoryThreshold, spillThreshold) = (100, 50) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - assert(array.isEmpty) + val asQueueVals = Array(false, true) - val expectedValues = populateRows(array, 1) - assert(!array.isEmpty) - assert(array.length == 1) + asQueueVals.foreach(q => { + test(s"insert rows less than the inMemoryThreshold $q") { + val (inMemoryThreshold, spillThreshold) = (100, 50) + withExternalArray(inMemoryThreshold, spillThreshold, q) { array => + assert(array.isEmpty) - val iterator1 = 
validateData(array, expectedValues) + val expectedValues = populateRows(array, 1) + assert(!array.isEmpty) + assert(array.length == 1) - // Add more rows (but not too many to trigger switch to [[UnsafeExternalSorter]]) - // Verify that NO spill has happened - populateRows(array, inMemoryThreshold - 1, expectedValues) - assert(array.length == inMemoryThreshold) - assertNoSpill() + val iterator1 = validateData(array, expectedValues) - val iterator2 = validateData(array, expectedValues) + // Add more rows (but not too many to trigger switch to [[UnsafeExternalSorter]]) + // Verify that NO spill has happened + populateRows(array, inMemoryThreshold - 1, expectedValues) + assert(array.length == inMemoryThreshold) + assertNoSpill() - assert(!iterator1.hasNext) - assert(!iterator2.hasNext) - } - } - - test("insert rows more than the inMemoryThreshold but less than spillThreshold") { - val (inMemoryThreshold, spillThreshold) = (10, 50) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - assert(array.isEmpty) - val expectedValues = populateRows(array, inMemoryThreshold - 1) - assert(array.length == (inMemoryThreshold - 1)) - val iterator1 = validateData(array, expectedValues) - assertNoSpill() - - // Add more rows to trigger switch to [[UnsafeExternalSorter]] but not too many to cause a - // spill to happen. Verify that NO spill has happened - populateRows(array, spillThreshold - expectedValues.length - 1, expectedValues) - assert(array.length == spillThreshold - 1) - assertNoSpill() - - val iterator2 = validateData(array, expectedValues) - assert(!iterator2.hasNext) - - assert(!iterator1.hasNext) - intercept[ConcurrentModificationException](iterator1.next()) - } - } + val iterator2 = validateData(array, expectedValues) - test("insert rows enough to force spill") { - val (inMemoryThreshold, spillThreshold) = (20, 10) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - assert(array.isEmpty) - val expectedValues = populateRows(array, inMemoryThreshold - 1) - assert(array.length == (inMemoryThreshold - 1)) - val iterator1 = validateData(array, expectedValues) - assertNoSpill() - - // Add more rows to trigger switch to [[UnsafeExternalSorter]] and cause a spill to happen. 
- // Verify that spill has happened - populateRows(array, 2, expectedValues) - assert(array.length == inMemoryThreshold + 1) - assertSpill() - - val iterator2 = validateData(array, expectedValues) - assert(!iterator2.hasNext) - - assert(!iterator1.hasNext) - intercept[ConcurrentModificationException](iterator1.next()) - } - } - - test("iterator on an empty array should be empty") { - withExternalArray(inMemoryThreshold = 4, spillThreshold = 10) { array => - val iterator = array.generateIterator() - assert(array.isEmpty) - assert(array.length == 0) - assert(!iterator.hasNext) - } - } - - test("generate iterator with negative start index") { - withExternalArray(inMemoryThreshold = 100, spillThreshold = 56) { array => - val exception = - intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10)) - - assert(exception.getMessage.contains( - "Invalid `startIndex` provided for generating iterator over the array") - ) - } - } - - test("generate iterator with start index exceeding array's size (without spill)") { - val (inMemoryThreshold, spillThreshold) = (20, 100) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - populateRows(array, spillThreshold / 2) + assert(!iterator1.hasNext) + assert(!iterator2.hasNext) + } + }}) + + asQueueVals.foreach(q => { + test(s"insert rows more than the inMemoryThreshold but less than spillThreshold $q") { + val (inMemoryThreshold, spillThreshold) = (10, 50) + withExternalArray(inMemoryThreshold, spillThreshold, q) { array => + assert(array.isEmpty) + val expectedValues = populateRows(array, inMemoryThreshold - 1) + assert(array.length == (inMemoryThreshold - 1)) + val iterator1 = validateData(array, expectedValues) + assertNoSpill() + + // Add more rows to trigger switch to [[UnsafeExternalSorter]] but not too many to cause a + // spill to happen. Verify that NO spill has happened + populateRows(array, spillThreshold - expectedValues.length - 1, expectedValues) + assert(array.length == spillThreshold - 1) + assertNoSpill() + + val iterator2 = validateData(array, expectedValues) + assert(!iterator2.hasNext) + + assert(!iterator1.hasNext) + intercept[ConcurrentModificationException](iterator1.next()) + } + }}) + + asQueueVals.foreach(q => { + test(s"insert rows enough to force spill $q") { + val (inMemoryThreshold, spillThreshold) = (20, 10) + withExternalArray(inMemoryThreshold, spillThreshold, q) { array => + assert(array.isEmpty) + val expectedValues = populateRows(array, inMemoryThreshold - 1) + assert(array.length == (inMemoryThreshold - 1)) + val iterator1 = validateData(array, expectedValues) + assertNoSpill() + + // Add more rows to trigger switch to [[UnsafeExternalSorter]] and cause a spill to happen. 
+ // Verify that spill has happened + populateRows(array, 2, expectedValues) + assert(array.length == inMemoryThreshold + 1) + assertSpill() + + val iterator2 = validateData(array, expectedValues) + assert(!iterator2.hasNext) + + assert(!iterator1.hasNext) + intercept[ConcurrentModificationException](iterator1.next()) + } + }}) + + asQueueVals.foreach(q => { + test(s"iterator on an empty array should be empty $q") { + withExternalArray(inMemoryThreshold = 4, spillThreshold = 10, q) { array => + val iterator = array.generateIterator() + assert(array.isEmpty) + assert(array.length == 0) + assert(!iterator.hasNext) + } + }}) - val exception = - intercept[ArrayIndexOutOfBoundsException]( - array.generateIterator(startIndex = spillThreshold * 10)) - assert(exception.getMessage.contains( - "Invalid `startIndex` provided for generating iterator over the array")) - } - } + asQueueVals.foreach(q => { + test(s"generate iterator with negative start index $q") { + withExternalArray(inMemoryThreshold = 100, spillThreshold = 56, q) { array => + val exception = + intercept[ArrayIndexOutOfBoundsException](array.generateIterator(startIndex = -10)) - test("generate iterator with start index exceeding array's size (with spill)") { - val (inMemoryThreshold, spillThreshold) = (20, 100) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - populateRows(array, spillThreshold * 2) + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array") + ) + } + }}) + + asQueueVals.foreach(q => { + test(s"generate iterator with start index exceeding array's size (without spill) $q") { + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold, q) { array => + populateRows(array, spillThreshold / 2) + + val exception = + intercept[ArrayIndexOutOfBoundsException]( + array.generateIterator(startIndex = spillThreshold * 10)) + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array")) + } + }}) - val exception = - intercept[ArrayIndexOutOfBoundsException]( - array.generateIterator(startIndex = spillThreshold * 10)) + asQueueVals.foreach(q => { + test(s"generate iterator with start index exceeding array's size (with spill) $q") { + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold, q) { array => + populateRows(array, spillThreshold * 2) - assert(exception.getMessage.contains( - "Invalid `startIndex` provided for generating iterator over the array")) - } - } + val exception = + intercept[ArrayIndexOutOfBoundsException]( + array.generateIterator(startIndex = spillThreshold * 10)) - test("generate iterator with custom start index (without spill)") { - val (inMemoryThreshold, spillThreshold) = (20, 100) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - val expectedValues = populateRows(array, inMemoryThreshold) - val startIndex = inMemoryThreshold / 2 - val iterator = array.generateIterator(startIndex = startIndex) - for (i <- startIndex until expectedValues.length) { - checkIfValueExists(iterator, expectedValues(i)) + assert(exception.getMessage.contains( + "Invalid `startIndex` provided for generating iterator over the array")) } - } - } - - test("generate iterator with custom start index (with spill)") { - val (inMemoryThreshold, spillThreshold) = (20, 100) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - val expectedValues = populateRows(array, spillThreshold * 10) - val 
startIndex = spillThreshold * 2 - val iterator = array.generateIterator(startIndex = startIndex) - for (i <- startIndex until expectedValues.length) { - checkIfValueExists(iterator, expectedValues(i)) + }}) + + asQueueVals.foreach(q => { + test(s"generate iterator with custom start index (without spill) $q") { + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold, q) { array => + val expectedValues = populateRows(array, inMemoryThreshold) + val startIndex = inMemoryThreshold / 2 + val iterator = array.generateIterator(startIndex = startIndex) + for (i <- startIndex until expectedValues.length) { + checkIfValueExists(iterator, expectedValues(i)) + } } - } - } + }}) + + asQueueVals.foreach(q => { + test(s"generate iterator with custom start index (with spill) $q") { + val (inMemoryThreshold, spillThreshold) = (20, 100) + withExternalArray(inMemoryThreshold, spillThreshold, q) { array => + val expectedValues = populateRows(array, spillThreshold * 10) + val startIndex = spillThreshold * 2 + val iterator = array.generateIterator(startIndex = startIndex) + for (i <- startIndex until expectedValues.length) { + checkIfValueExists(iterator, expectedValues(i)) + } + } + }}) + + asQueueVals.foreach(q => { + test(s"test iterator invalidation (without spill) $q") { + withExternalArray(inMemoryThreshold = 10, spillThreshold = 100, q) { array => + // insert 2 rows, iterate until the first row + populateRows(array, 2) + + var iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + // Adding more row(s) should invalidate any old iterators + populateRows(array, 1) + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + + // Clearing the array should also invalidate any old iterators + iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + array.clear() + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + } + }}) - test("test iterator invalidation (without spill)") { - withExternalArray(inMemoryThreshold = 10, spillThreshold = 100) { array => + test(s"test dequeue with spill") { + withExternalArray(inMemoryThreshold = 2, spillThreshold = 3, true) { array => // insert 2 rows, iterate until the first row - populateRows(array, 2) - - var iterator = array.generateIterator() - assert(iterator.hasNext) - iterator.next() - - // Adding more row(s) should invalidate any old iterators - populateRows(array, 1) - assert(!iterator.hasNext) - intercept[ConcurrentModificationException](iterator.next()) - - // Clearing the array should also invalidate any old iterators - iterator = array.generateIterator() - assert(iterator.hasNext) - iterator.next() - - array.clear() - assert(!iterator.hasNext) - intercept[ConcurrentModificationException](iterator.next()) - } - } - - test("test iterator invalidation (with spill)") { - val (inMemoryThreshold, spillThreshold) = (2, 10) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - // Populate enough rows so that spill happens - populateRows(array, spillThreshold * 2) + populateRows(array, 10) assertSpill() var iterator = array.generateIterator() assert(iterator.hasNext) - iterator.next() + val first = iterator.next() - // Adding more row(s) should invalidate any old iterators - populateRows(array, 1) - assert(!iterator.hasNext) - intercept[ConcurrentModificationException](iterator.next()) + val first2 = array.dequeue().get + assert(first.equals(first2)) + val second = 
array.dequeue().get + assert(!second.equals(first2)) - // Clearing the array should also invalidate any old iterators - iterator = array.generateIterator() - assert(iterator.hasNext) - iterator.next() + val third = array.peek().get + val third2 = array.dequeue().get + assert(third.equals(third2)) - array.clear() - assert(!iterator.hasNext) - intercept[ConcurrentModificationException](iterator.next()) - } - } + assert(array.length == 7) - test("clear on an empty the array") { - withExternalArray(inMemoryThreshold = 2, spillThreshold = 3) { array => - val iterator = array.generateIterator() - assert(!iterator.hasNext) + populateRows(array, 10) - // multiple clear'ing should not have an side-effect - array.clear() - array.clear() - array.clear() - assert(array.isEmpty) - assert(array.length == 0) - - // Clearing an empty array should also invalidate any old iterators - assert(!iterator.hasNext) - intercept[ConcurrentModificationException](iterator.next()) - } - } + array.dequeue() + array.dequeue() - test("clear array (without spill)") { - val (inMemoryThreshold, spillThreshold) = (10, 100) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - // Populate rows ... but not enough to trigger spill - populateRows(array, inMemoryThreshold / 2) - assertNoSpill() - - // Clear the array - array.clear() - assert(array.isEmpty) - - // Re-populate few rows so that there is no spill - // Verify the data. Verify that there was no spill - val expectedValues = populateRows(array, inMemoryThreshold / 2) - validateData(array, expectedValues) - assertNoSpill() - - // Populate more rows .. enough to not trigger a spill. - // Verify the data. Verify that there was no spill - populateRows(array, inMemoryThreshold / 2, expectedValues) - validateData(array, expectedValues) - assertNoSpill() + assert(array.length == 15) } } - test("clear array (with spill)") { - val (inMemoryThreshold, spillThreshold) = (10, 20) - withExternalArray(inMemoryThreshold, spillThreshold) { array => - // Populate enough rows to trigger spill - populateRows(array, spillThreshold * 2) - val bytesSpilled = getNumBytesSpilled - assert(bytesSpilled > 0) - - // Clear the array - array.clear() - assert(array.isEmpty) - - // Re-populate the array ... but NOT upto the point that there is spill. - // Verify data. Verify that there was NO "extra" spill - val expectedValues = populateRows(array, spillThreshold / 2) - validateData(array, expectedValues) - assert(getNumBytesSpilled == bytesSpilled) - - // Populate more rows to trigger spill - // Verify the data. 
Verify that there was "extra" spill - populateRows(array, spillThreshold * 2, expectedValues) - validateData(array, expectedValues) - assert(getNumBytesSpilled > bytesSpilled) - } - } + asQueueVals.foreach(q => { + test(s"test iterator invalidation (with spill) $q") { + val (inMemoryThreshold, spillThreshold) = (2, 10) + withExternalArray(inMemoryThreshold, spillThreshold, q) { array => + // Populate enough rows so that spill happens + populateRows(array, spillThreshold * 2) + assertSpill() + + var iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + // Adding more row(s) should invalidate any old iterators + populateRows(array, 1) + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + + // Clearing the array should also invalidate any old iterators + iterator = array.generateIterator() + assert(iterator.hasNext) + iterator.next() + + array.clear() + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + } + }}) + + asQueueVals.foreach(q => { + test(s"clear on an empty the array $q") { + withExternalArray(inMemoryThreshold = 2, spillThreshold = 3, q) { array => + val iterator = array.generateIterator() + assert(!iterator.hasNext) + + // multiple clear'ing should not have an side-effect + array.clear() + array.clear() + array.clear() + assert(array.isEmpty) + assert(array.length == 0) + + // Clearing an empty array should also invalidate any old iterators + assert(!iterator.hasNext) + intercept[ConcurrentModificationException](iterator.next()) + } + }}) + + asQueueVals.foreach(q => { + test(s"clear array (without spill) $q") { + val (inMemoryThreshold, spillThreshold) = (10, 100) + withExternalArray(inMemoryThreshold, spillThreshold, q) { array => + // Populate rows ... but not enough to trigger spill + populateRows(array, inMemoryThreshold / 2) + assertNoSpill() + + // Clear the array + array.clear() + assert(array.isEmpty) + + // Re-populate few rows so that there is no spill + // Verify the data. Verify that there was no spill + val expectedValues = populateRows(array, inMemoryThreshold / 2) + validateData(array, expectedValues) + assertNoSpill() + + // Populate more rows .. enough to not trigger a spill. + // Verify the data. Verify that there was no spill + populateRows(array, inMemoryThreshold / 2, expectedValues) + validateData(array, expectedValues) + assertNoSpill() + } + }}) + + asQueueVals.foreach(q => { + test(s"clear array (with spill) $q") { + val (inMemoryThreshold, spillThreshold) = (10, 20) + withExternalArray(inMemoryThreshold, spillThreshold, q) { array => + // Populate enough rows to trigger spill + populateRows(array, spillThreshold * 2) + val bytesSpilled = getNumBytesSpilled + assert(bytesSpilled > 0) + + // Clear the array + array.clear() + assert(array.isEmpty) + + // Re-populate the array ... but NOT upto the point that there is spill. + // Verify data. Verify that there was NO "extra" spill + val expectedValues = populateRows(array, spillThreshold / 2) + validateData(array, expectedValues) + assert(getNumBytesSpilled == bytesSpilled) + + // Populate more rows to trigger spill + // Verify the data. 
Verify that there was "extra" spill + populateRows(array, spillThreshold * 2, expectedValues) + validateData(array, expectedValues) + assert(getNumBytesSpilled > bytesSpilled) + } + }}) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 6ec98b50f785a..757eba0e91cc9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -230,7 +230,7 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext { extractJoinParts().foreach { case (_, leftKeys, rightKeys, rangeConditions, boundCondition, _, _) => assert(rangeConditions.isEmpty) - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1", config -> confValue) { checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) => makeSortMergeJoin(leftKeys, rightKeys, boundCondition, rangeConditions, leftPlan, rightPlan), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala index 615923fe02d6c..e1cf5504364ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala @@ -72,6 +72,30 @@ private[sql] trait SQLTestData { self => df } + protected lazy val testData4: DataFrame = { + val df = spark.sparkContext.parallelize( + TestData2(1, 3) :: TestData2(1, 4) :: TestData2(1, 7) :: TestData2(1, 8) :: + TestData2(1, 10) :: TestData2(2, 1) :: TestData2(2, 2) :: TestData2(2, 3) :: + TestData2(2, 8) :: TestData2(3, 1) :: TestData2(3, 2) :: TestData2(3, 3) :: + TestData2(3, 5) :: TestData2(4, 1) :: TestData2(4, 2) :: TestData2(4, 3) :: + Nil).toDF() + df.orderBy("a").sortWithinPartitions("b"). + createOrReplaceTempView("testData4") + df + } + + protected lazy val testData5: DataFrame = { + val df = spark.sparkContext.parallelize( + TestData2(1, 1) :: TestData2(1, 2) :: TestData2(1, 2) :: TestData2(1, 3) :: + TestData2(1, 5) :: TestData2(1, 7) :: TestData2(1, 20) :: + TestData2(2, 1) :: TestData2(2, 2) :: TestData2(2, 3) :: TestData2(2, 5) :: + TestData2(2, 6) :: TestData2(3, 3) :: TestData2(3, 6) :: + Nil).toDF("c", "d") + df.orderBy("c").sortWithinPartitions("d"). 
+ createOrReplaceTempView("testData5") + df + } + protected lazy val negativeData: DataFrame = { val df = spark.sparkContext.parallelize( (1 to 100).map(i => TestData(-i, (-i).toString))).toDF() From 64437e57d3f46a0c183f72bc7ef084b340dee6b4 Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Fri, 10 Aug 2018 09:56:08 +0200 Subject: [PATCH 50/51] External unsafe row dequeue test extension --- .../execution/ExternalAppendOnlyUnsafeRowArraySuite.scala | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala index bc39d2f3a5657..12bef254ecb5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala @@ -292,7 +292,7 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar test(s"test dequeue with spill") { withExternalArray(inMemoryThreshold = 2, spillThreshold = 3, true) { array => // insert 2 rows, iterate until the first row - populateRows(array, 10) + populateRows(array, 5) assertSpill() var iterator = array.generateIterator() @@ -308,14 +308,16 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar val third2 = array.dequeue().get assert(third.equals(third2)) - assert(array.length == 7) + assert(array.length == 2) + + array.dequeue() populateRows(array, 10) array.dequeue() array.dequeue() - assert(array.length == 15) + assert(array.length == 9) } } From 0a5c8de7769315934712bf853401c332dd747a6e Mon Sep 17 00:00:00 2001 From: Petar Zecevic Date: Sat, 11 Aug 2018 22:48:41 +0200 Subject: [PATCH 51/51] A dot --- .../apache/spark/sql/execution/joins/SortMergeJoinExec.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index ac3ab6843d45b..f7c56132b45cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -868,7 +868,7 @@ case class SortMergeJoinExec( * * @param streamedKeyGenerator a projection that produces join keys from the streamed input. * @param bufferedKeyGenerator a projection that produces join keys from the buffered input. - * @param keyOrdering an ordering which can be used to compare join keys. + * @param keyOrdering an ordering which can be used to compare join keys * @param streamedIter an input whose rows will be streamed. * @param bufferedIter an input whose rows will be buffered to construct sequences of rows that * have the same join key.
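End-to-end, the shape of query this series optimizes is an equi-join key plus a range ("band") condition between two non-key columns, as exercised by the JoinSuite test above. A usage sketch against the testData4/testData5 tables registered in SQLTestData, with the same threshold settings as the spill test (an illustration; `spark` is assumed to be an active SparkSession with the test data loaded):

import org.apache.spark.sql.functions.col

spark.conf.set("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold", "1")
spark.conf.set("spark.sql.sortMergeJoinExec.buffer.spill.threshold", "2")

val joined = spark.table("testData4").join(
  spark.table("testData5"),
  col("a") === col("c") && col("b").between(col("d") - 1, col("d") + 1))

// If the planner extracts the band condition into rangeConditions, the
// physical plan should show a SortMergeJoin handling it with the moving
// window (queue-mode buffer) rather than as a post-join filter.
joined.explain()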