From e59df34bb71a3bfaa4ec0dca704dab3f1ac14660 Mon Sep 17 00:00:00 2001
From: Yuming Wang
Date: Wed, 2 Sep 2020 00:48:02 +0800
Subject: [PATCH] Bucket join should work if SHUFFLE_PARTITIONS larger than
 bucket number
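
Currently, when every required child of a join is a bucketed table scan
(i.e. there is no ShuffleExchangeExec child), EnsureRequirements still
takes the larger of the max bucket number and `conf.numShufflePartitions`
as the expected number of partitions. So when SHUFFLE_PARTITIONS is larger
than the bucket number, both sides get shuffled even though the bucketed
output partitioning could be reused. With this change, when all required
children are non-shuffle we pick the max number of partitions among them,
so only the side with fewer buckets is shuffled.

A rough reproduction sketch (the table names t1/t2 and the row counts are
illustrative only, not part of this patch):

```scala
// Two tables bucketed on the join keys, with different bucket numbers.
spark.range(100).selectExpr("id AS i", "id AS j")
  .write.bucketBy(8, "i", "j").saveAsTable("t1")
spark.range(100).selectExpr("id AS i", "id AS j")
  .write.bucketBy(6, "i", "j").saveAsTable("t2")

spark.conf.set("spark.sql.shuffle.partitions", 9)           // larger than both bucket numbers
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)  // force sort-merge join

spark.sql("SELECT * FROM t1 JOIN t2 ON t1.i = t2.i AND t1.j = t2.j").explain()
// Before: both sides are shuffled to 9 partitions.
// After: only the 6-bucket side is shuffled, to 8 partitions.
```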
---
 .../exchange/EnsureRequirements.scala         | 15 ++++++----
 .../spark/sql/sources/BucketedReadSuite.scala | 30 +++++++++++++++++++
 2 files changed, 40 insertions(+), 5 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 186bac6f43332..b176598ed8c2c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -85,11 +85,16 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] {
         childrenIndexes.map(children).filterNot(_.isInstanceOf[ShuffleExchangeExec])
           .map(_.outputPartitioning.numPartitions)
       val expectedChildrenNumPartitions = if (nonShuffleChildrenNumPartitions.nonEmpty) {
-        // Here we pick the max number of partitions among these non-shuffle children as the
-        // expected number of shuffle partitions. However, if it's smaller than
-        // `conf.numShufflePartitions`, we pick `conf.numShufflePartitions` as the
-        // expected number of shuffle partitions.
-        math.max(nonShuffleChildrenNumPartitions.max, conf.defaultNumShufflePartitions)
+        if (nonShuffleChildrenNumPartitions.length == childrenIndexes.length) {
+          // Here we pick the max number of partitions among these non-shuffle children.
+          nonShuffleChildrenNumPartitions.max
+        } else {
+          // Here we pick the max number of partitions among these non-shuffle children as the
+          // expected number of shuffle partitions. However, if it's smaller than
+          // `conf.numShufflePartitions`, we pick `conf.numShufflePartitions` as the
+          // expected number of shuffle partitions.
+          math.max(nonShuffleChildrenNumPartitions.max, conf.defaultNumShufflePartitions)
+        }
       } else {
         childrenNumPartitions.max
       }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 98886d271e977..f8276b143c1e6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -875,6 +875,36 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils {
     }
   }
 
+  test("SPARK-32767 Bucket join should work if SHUFFLE_PARTITIONS larger than bucket number") {
+    withSQLConf(
+      SQLConf.SHUFFLE_PARTITIONS.key -> "9",
+      SQLConf.COALESCE_PARTITIONS_INITIAL_PARTITION_NUM.key -> "10") {
+
+      val testSpec1 = BucketedTableTestSpec(
+        Some(BucketSpec(8, Seq("i", "j"), Seq("i", "j"))),
+        numPartitions = 1,
+        expectedShuffle = false,
+        expectedSort = false,
+        expectedNumOutputPartitions = Some(8))
+      val testSpec2 = BucketedTableTestSpec(
+        Some(BucketSpec(6, Seq("i", "j"), Seq("i", "j"))),
+        numPartitions = 1,
+        expectedShuffle = true,
+        expectedSort = true,
+        expectedNumOutputPartitions = Some(8))
+      Seq(false, true).foreach { enableAdaptive =>
+        withSQLConf(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> s"$enableAdaptive") {
+          Seq((testSpec1, testSpec2), (testSpec2, testSpec1)).foreach { specs =>
+            testBucketing(
+              bucketedTableTestSpecLeft = specs._1,
+              bucketedTableTestSpecRight = specs._2,
+              joinCondition = joinCondition(Seq("i", "j")))
+          }
+        }
+      }
+    }
+  }
+
   test("bucket coalescing eliminates shuffle") {
     withSQLConf(SQLConf.COALESCE_BUCKETS_IN_JOIN_ENABLED.key -> "true") {
       // The side with bucketedTableTestSpec1 will be coalesced to have 4 output partitions.