@@ -21,7 +21,6 @@ import scala.collection.mutable

 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Expression, Murmur3HashFunction, RowOrdering}
-import org.apache.spark.sql.catalyst.plans.physical.KeyGroupedPartitioning
 import org.apache.spark.sql.connector.read.{HasPartitionKey, InputPartition}
 import org.apache.spark.sql.types.{DataType, StructField, StructType}
 import org.apache.spark.util.NonFateSharingCache
@@ -85,22 +84,25 @@ object InternalRowComparableWrapper {
   }

   def mergePartitions(
-      leftPartitioning: KeyGroupedPartitioning,
-      rightPartitioning: KeyGroupedPartitioning,
-      partitionExpression: Seq[Expression]): Seq[InternalRow] = {
+      leftPartitioning: Seq[InternalRow],
+      rightPartitioning: Seq[InternalRow],
+      partitionExpression: Seq[Expression],
+      intersect: Boolean = false): Seq[InternalRowComparableWrapper] = {
     val partitionDataTypes = partitionExpression.map(_.dataType)
-    val partitionsSet = new mutable.HashSet[InternalRowComparableWrapper]
-    leftPartitioning.partitionValues
+    val leftPartitionSet = new mutable.HashSet[InternalRowComparableWrapper]
+    leftPartitioning
       .map(new InternalRowComparableWrapper(_, partitionDataTypes))
-      .foreach(partition => partitionsSet.add(partition))
-    rightPartitioning.partitionValues
+      .foreach(partition => leftPartitionSet.add(partition))
+    val rightPartitionSet = new mutable.HashSet[InternalRowComparableWrapper]
+    rightPartitioning
       .map(new InternalRowComparableWrapper(_, partitionDataTypes))
-      .foreach(partition => partitionsSet.add(partition))
-    // SPARK-41471: We keep to order of partitions to make sure the order of
-    // partitions is deterministic in different case.
-    val partitionOrdering: Ordering[InternalRow] = {
-      RowOrdering.createNaturalAscendingOrdering(partitionDataTypes)
-    }
-    partitionsSet.map(_.row).toSeq.sorted(partitionOrdering)
+      .foreach(partition => rightPartitionSet.add(partition))
+
+    val result = if (intersect) {
+      leftPartitionSet.intersect(rightPartitionSet)
+    } else {
+      leftPartitionSet.union(rightPartitionSet)
+    }
+    result.toSeq
   }
 }

Member (on `result.toSeq`): do we still need to sort the result partitions?

Member: ah, I see: it is sorted later in the other method now.
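For readers skimming the diff, the behavioral change boils down to choosing a set intersection instead of the previous unconditional union over wrapped partition values. A minimal stand-in sketch (a plain Scala case class instead of `InternalRowComparableWrapper`, `Int` values instead of `InternalRow`, names are hypothetical):

```scala
// Stand-in wrapper: a case class gives the value-based equals/hashCode that
// InternalRow lacks, which is exactly the wrapper's job in the real code.
final case class Wrapper(partValue: Int)

val left = Set(Wrapper(0), Wrapper(1), Wrapper(2))   // left partition values
val right = Set(Wrapper(1), Wrapper(2), Wrapper(3))  // right partition values

val merged = left union right     // default path: Wrapper(0), (1), (2), (3)
val common = left intersect right // intersect = true: Wrapper(1), Wrapper(2)
```

Note the method now returns wrappers unsorted; as the thread above observes, sorting moved to the caller in EnsureRequirements.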
@@ -1635,6 +1635,17 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)

+  val V2_BUCKETING_PARTITION_FILTER_ENABLED =
+    buildConf("spark.sql.sources.v2.bucketing.partition.filter.enabled")
+      .doc(s"Whether to filter partitions when running storage-partition join. " +
+        s"When enabled, partitions without matches on the other side can be omitted from " +
+        s"scanning, if allowed by the join type. This config requires both " +
+        s"${V2_BUCKETING_ENABLED.key} and ${V2_BUCKETING_PUSH_PART_VALUES_ENABLED.key} to be " +
+        s"enabled.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val BUCKETING_MAX_BUCKETS = buildConf("spark.sql.sources.bucketing.maxBuckets")
     .doc("The maximum number of buckets allowed.")
     .version("2.4.0")
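A hedged usage sketch (not part of the diff) showing the new flag enabled alongside its two prerequisites. The prerequisite key names below are the ones `V2_BUCKETING_ENABLED` and `V2_BUCKETING_PUSH_PART_VALUES_ENABLED` resolve to in current Spark; verify them against your version:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("spj-partition-filter-demo") // placeholder app name
  .getOrCreate()

// Partition filtering only applies when all three flags are enabled:
spark.conf.set("spark.sql.sources.v2.bucketing.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.pushPartValues.enabled", "true")
spark.conf.set("spark.sql.sources.v2.bucketing.partition.filter.enabled", "true")
```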
@@ -61,7 +61,7 @@ object InternalRowComparableWrapperBenchmark extends BenchmarkBase {
       val leftPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions)
       val rightPartitioning = KeyGroupedPartitioning(expressions, bucketNum, partitions)
       val merged = InternalRowComparableWrapper.mergePartitions(
-        leftPartitioning, rightPartitioning, expressions)
+        leftPartitioning.partitionValues, rightPartitioning.partitionValues, expressions)
       assert(merged.size == bucketNum)

@@ -200,7 +200,12 @@ case class BatchScanExec(
           .get
           .map(t => (InternalRowComparableWrapper(t._1, partExpressions), t._2))
           .toMap
-        val nestGroupedPartitions = finalGroupedPartitions.map { case (partValue, splits) =>
+        val filteredGroupedPartitions = finalGroupedPartitions.filter {
+          case (partValues, _) =>
+            commonPartValuesMap.keySet.contains(
+              InternalRowComparableWrapper(partValues, partExpressions))
+        }
+        val nestGroupedPartitions = filteredGroupedPartitions.map { case (partValue, splits) =>
           // `commonPartValuesMap` should contain the part value since it's the super set.
           val numSplits = commonPartValuesMap
             .get(InternalRowComparableWrapper(partValue, partExpressions))
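A minimal stand-in (plain Scala collections with hypothetical values, not Spark's grouped-partition types) for the added filter step: grouped partitions whose value does not appear among the pushed-down common partition values are dropped before splits are replicated:

```scala
// Hypothetical data: partition value -> input splits for that partition.
val finalGrouped = Map("p0" -> Seq("s0", "s1"), "p1" -> Seq("s2"), "p2" -> Seq("s3"))
// Partition values that survived the merge on the join side.
val commonKeys = Set("p0", "p2")

// Same shape as the added filter in the diff: keep only common partitions.
val filtered = finalGrouped.filter { case (k, _) => commonKeys.contains(k) }
// Map("p0" -> Seq("s0", "s1"), "p2" -> Seq("s3"))
```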
@@ -429,8 +429,19 @@ case class EnsureRequirements(
         // expressions
         val partitionExprs = leftSpec.partitioning.expressions

-        var mergedPartValues = InternalRowComparableWrapper
-          .mergePartitions(leftSpec.partitioning, rightSpec.partitioning, partitionExprs)
+        // in case of compatible but not identical partition expressions, we apply 'reduce'
+        // transforms to group one side's partitions as well as the common partition values
+        val leftReducers = leftSpec.reducers(rightSpec)
+        val leftParts = reducePartValues(leftSpec.partitioning.partitionValues,
+          partitionExprs,
+          leftReducers)
+        val rightReducers = rightSpec.reducers(leftSpec)
+        val rightParts = reducePartValues(rightSpec.partitioning.partitionValues,
+          partitionExprs,
+          rightReducers)
viirya (Member), Aug 5, 2024 (on the `partitionExprs` argument to the right-side reducePartValues call): partitionExprs come from the left spec, but this call reduces the right spec's partition values. Even though the specs are compatible, does that guarantee that the right spec's partition expressions have the same data types as the left spec's?

For compatible partition expressions, the definition is r(t1(x)) = t2(x), or r(t2(x)) = t1(x). But t1 and t2 can still have different data types, can't they?

It only requires that r have the same data type as the other side, i.e., r(t1(x)) and t2(x), or r(t2(x)) and t1(x).

Member (Author): Yes, you may be right; let me double-check this with a test and get back to you.
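To make the reviewer's r/t1/t2 notation concrete, an illustrative sketch (not from the PR) using Iceberg-style bucket transforms, where bucket(n, x) is modeled as a non-negative hash mod n; the function names and the hash stand-in are assumptions:

```scala
// Illustrative sketch: bucket(n, x) modeled as a non-negative hash mod n.
def bucket(n: Int, hash: Int): Int = Math.floorMod(hash, n)

val t1 = (h: Int) => bucket(8, h) // left side transform t1(x)
val t2 = (h: Int) => bucket(4, h) // right side transform t2(x)
val r  = (b: Int) => b % 4        // reducer mapping left bucket ids to right ids

// r(t1(x)) == t2(x) because (h mod 8) mod 4 == h mod 4 whenever 4 divides 8.
assert((0 until 100).forall(h => r(t1(h)) == t2(h)))
```

The reviewer's caveat still stands: compatibility itself does not force t1 and t2 to share a data type; only r's output must match the other side's.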

+        // merge values on both sides
+        var mergedPartValues = mergePartitions(leftParts, rightParts, partitionExprs, joinType)
           .map(v => (v, 1))

         logInfo(log"After merging, there are " +
@@ -525,23 +536,6 @@ case class EnsureRequirements(
           }
         }

-        // in case of compatible but not identical partition expressions, we apply 'reduce'
-        // transforms to group one side's partitions as well as the common partition values
-        val leftReducers = leftSpec.reducers(rightSpec)
-        val rightReducers = rightSpec.reducers(leftSpec)
-
-        if (leftReducers.isDefined || rightReducers.isDefined) {
-          mergedPartValues = reduceCommonPartValues(mergedPartValues,
-            leftSpec.partitioning.expressions,
-            leftReducers)
-          mergedPartValues = reduceCommonPartValues(mergedPartValues,
-            rightSpec.partitioning.expressions,
-            rightReducers)
-          val rowOrdering = RowOrdering
-            .createNaturalAscendingOrdering(partitionExprs.map(_.dataType))
-          mergedPartValues = mergedPartValues.sorted(rowOrdering.on((t: (InternalRow, _)) => t._1))
-        }
-
         // Now we need to push-down the common partition information to the scan in each child
         newLeft = populateCommonPartitionInfo(left, mergedPartValues, leftSpec.joinKeyPositions,
           leftReducers, applyPartialClustering, replicateLeftSide)
@@ -602,15 +596,15 @@ case class EnsureRequirements(
           child, joinKeyPositions))
       }

-  private def reduceCommonPartValues(
-      commonPartValues: Seq[(InternalRow, Int)],
+  private def reducePartValues(
+      partValues: Seq[InternalRow],
       expressions: Seq[Expression],
       reducers: Option[Seq[Option[Reducer[_, _]]]]) = {
     reducers match {
-      case Some(reducers) => commonPartValues.groupBy { case (row, _) =>
+      case Some(reducers) => partValues.map { row =>
         KeyGroupedShuffleSpec.reducePartitionValue(row, expressions, reducers)
-      }.map{ case(wrapper, splits) => (wrapper.row, splits.map(_._2).sum) }.toSeq
-      case _ => commonPartValues
+      }.distinct.map(_.row)
+      case _ => partValues
     }
   }
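A stand-in sketch (plain Scala with a hypothetical reducer, not the PR's types) of the helper's new shape: partition values are mapped through the reducer and deduplicated, instead of the old group-and-sum over (value, numSplits) pairs:

```scala
// Partition values as Ints; a reducer coarsening bucket(4) ids to bucket(2) ids.
val partValues = Seq(0, 1, 2, 3)
val reducer: Int => Int = _ % 2

// New behavior: reduce each value, then deduplicate.
val reduced = partValues.map(reducer).distinct // Seq(0, 1)
```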

@@ -651,6 +645,46 @@ case class EnsureRequirements(
     }
   }

+  /**
+   * Merges and sorts partition values for storage-partitioned join (SPJ), optionally
+   * filtering out partitions. Both sides must have matching partition expressions.
+   * @param leftPartitioning left side partition values
+   * @param rightPartitioning right side partition values
+   * @param partitionExpression partition expressions
+   * @param joinType join type for optional partition filtering
+   * @return merged and sorted partition values
+   */
+  private def mergePartitions(
+      leftPartitioning: Seq[InternalRow],
+      rightPartitioning: Seq[InternalRow],
+      partitionExpression: Seq[Expression],
+      joinType: JoinType): Seq[InternalRow] = {
+
+    val merged = if (SQLConf.get.getConf(SQLConf.V2_BUCKETING_PARTITION_FILTER_ENABLED)) {
+      joinType match {
+        case Inner => InternalRowComparableWrapper.mergePartitions(
+          leftPartitioning, rightPartitioning, partitionExpression, intersect = true)
+        case LeftOuter => leftPartitioning.map(
+          InternalRowComparableWrapper(_, partitionExpression))
+        case RightOuter => rightPartitioning.map(
+          InternalRowComparableWrapper(_, partitionExpression))
+        case _ => InternalRowComparableWrapper.mergePartitions(leftPartitioning,
+          rightPartitioning, partitionExpression)
+      }
+    } else {
+      InternalRowComparableWrapper.mergePartitions(leftPartitioning, rightPartitioning,
+        partitionExpression)
+    }
+
+    // SPARK-41471: sort the merged partitions so that their order is deterministic
+    // across different cases.
+    val partitionOrdering: Ordering[InternalRow] = {
+      RowOrdering.createNaturalAscendingOrdering(partitionExpression.map(_.dataType))
+    }
+    merged.map(_.row).sorted(partitionOrdering)
+  }

   def apply(plan: SparkPlan): SparkPlan = {
     val newPlan = plan.transformUp {
       case operator @ ShuffleExchangeExec(upper: HashPartitioning, child, shuffleOrigin, _)
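To summarize the dispatch in the new method, a toy walk-through (plain Scala sets standing in for wrapped partition values) of which partitions survive per join type once spark.sql.sources.v2.bucketing.partition.filter.enabled is on:

```scala
val left = Set("a", "b")  // left side partition values
val right = Set("b", "c") // right side partition values

val inner = left intersect right // Inner: Set("b") -- only matches produce rows
val leftOuter = left             // LeftOuter: Set("a", "b") -- keep all left partitions
val rightOuter = right           // RightOuter: Set("b", "c") -- keep all right partitions
val other = left union right     // any other join type: Set("a", "b", "c")
```

When the flag is off, every join type falls back to the union, preserving the old behavior.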