From 68cb3e2f92b1aaa14dddfcdb311a7d685d209e97 Mon Sep 17 00:00:00 2001
From: Felix Cheung
Date: Wed, 25 Jan 2017 22:03:43 -0800
Subject: [PATCH 1/2] add getNumPartitions and test

---
 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala |  7 +++++++
 .../test/scala/org/apache/spark/sql/DataFrameSuite.scala   |  7 ++++++-
 2 files changed, 13 insertions(+), 1 deletion(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 24b9b810fc5ca..9222562492814 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2420,6 +2420,13 @@ class Dataset[T] private[sql](
     RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions = None)
   }
 
+  /**
+   * Returns the number of partitions of this Dataset.
+   * @group basic
+   * @since 2.2.0
+   */
+  def getNumPartitions: Int = rdd.getNumPartitions()
+
   /**
    * Returns a new Dataset that has exactly `numPartitions` partitions.
    * Similar to coalesce defined on an `RDD`, this operation results in a narrow dependency, e.g.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index cb7b97906a7d7..6a7c630eda145 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -269,6 +269,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     checkAnswer(
       testData.select('key).repartition(10).select('key),
       testData.select('key).collect().toSeq)
+    checkAnswer(
+      testData.select('key).repartition(10).select('key).getNumPartitions(),
+      10)
   }
 
   test("coalesce") {
@@ -1301,12 +1304,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   test("distributeBy and localSort") {
     val original = testData.repartition(1)
     assert(original.rdd.partitions.length == 1)
+    assert(original.getNumPartitions() == 1)
     val df = original.repartition(5, $"key")
     assert(df.rdd.partitions.length == 5)
+    assert(df.getNumPartitions() == 5)
     checkAnswer(original.select(), df.select())
 
     val df2 = original.repartition(10, $"key")
-    assert(df2.rdd.partitions.length == 10)
+    assert(df2.getNumPartitions() == 10)
     checkAnswer(original.select(), df2.select())
 
     // Group by the column we are distributed by. This should generate a plan with no exchange

From 048759b23d9c4303dc7e7c9cd6d6d6e8eb4a3c21 Mon Sep 17 00:00:00 2001
From: Felix Cheung
Date: Wed, 25 Jan 2017 23:23:45 -0800
Subject: [PATCH 2/2] commit the right stuff

---
 .../src/main/scala/org/apache/spark/sql/Dataset.scala    |  2 +-
 .../scala/org/apache/spark/sql/DataFrameSuite.scala      | 10 ++++------
 2 files changed, 5 insertions(+), 7 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 9222562492814..157b377eb278e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2425,7 +2425,7 @@ class Dataset[T] private[sql](
    * @group basic
    * @since 2.2.0
    */
-  def getNumPartitions: Int = rdd.getNumPartitions()
+  def getNumPartitions: Int = rdd.getNumPartitions
 
   /**
    * Returns a new Dataset that has exactly `numPartitions` partitions.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 6a7c630eda145..044a72b678a50 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -269,9 +269,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
     checkAnswer(
       testData.select('key).repartition(10).select('key),
       testData.select('key).collect().toSeq)
-    checkAnswer(
-      testData.select('key).repartition(10).select('key).getNumPartitions(),
-      10)
+    assert(testData.select('key).repartition(10).select('key).getNumPartitions == 10)
   }
 
   test("coalesce") {
@@ -1304,14 +1302,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
   test("distributeBy and localSort") {
     val original = testData.repartition(1)
     assert(original.rdd.partitions.length == 1)
-    assert(original.getNumPartitions() == 1)
+    assert(original.getNumPartitions == 1)
     val df = original.repartition(5, $"key")
     assert(df.rdd.partitions.length == 5)
-    assert(df.getNumPartitions() == 5)
+    assert(df.getNumPartitions == 5)
     checkAnswer(original.select(), df.select())
 
     val df2 = original.repartition(10, $"key")
-    assert(df2.getNumPartitions() == 10)
+    assert(df2.getNumPartitions == 10)
     checkAnswer(original.select(), df2.select())
 
     // Group by the column we are distributed by. This should generate a plan with no exchange
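Usage note (not part of the patches): the sketch below shows how the Dataset.getNumPartitions API added above could be exercised once a build containing these two patches is available. The session setup, the DataFrame name df, and the sample data are illustrative assumptions, not taken from the patch; after PATCH 2/2 the method is parameterless, mirroring RDD.getNumPartitions, so it is written without parentheses.

    import org.apache.spark.sql.SparkSession

    // Illustrative local session; any existing SparkSession would do.
    val spark = SparkSession.builder()
      .master("local[4]")
      .appName("getNumPartitions sketch")
      .getOrCreate()
    import spark.implicits._

    // Small sample DataFrame; the contents are arbitrary.
    val df = (1 to 100).toDF("key")

    // Number of partitions of the Dataset's underlying RDD.
    println(df.getNumPartitions)

    // Matches the added test: repartition(10) should yield 10 partitions.
    assert(df.repartition(10).getNumPartitions == 10)

    spark.stop()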