From 76f4bb7b29e0b605561dac048535380b645839fd Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Thu, 17 Apr 2014 21:43:32 -0700 Subject: [PATCH 1/5] init impl of allReduce --- .../spark/rdd/ButterflyReducedRDD.scala | 54 +++++++++++++++++++ .../main/scala/org/apache/spark/rdd/RDD.scala | 13 +++++ .../scala/org/apache/spark/rdd/RDDSuite.scala | 15 ++++++ 3 files changed, 82 insertions(+) create mode 100644 core/src/main/scala/org/apache/spark/rdd/ButterflyReducedRDD.scala diff --git a/core/src/main/scala/org/apache/spark/rdd/ButterflyReducedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ButterflyReducedRDD.scala new file mode 100644 index 0000000000000..a5274920a9480 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/ButterflyReducedRDD.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.{TaskContext, Partition} + +private case class ButterflyReducedRDDPartition( + override val index: Int, + _1: Partition, + _2: Partition) extends Partition + +private[spark] class ButterflyReducedRDD[T: ClassTag]( + rdd: RDD[T], + offset: Int, + reducer: (T, T) => T) extends RDD[T](rdd) { + + val numPartitions = rdd.partitions.size + + private def targetPartition(i: Int): Int = { + (i + offset) % numPartitions + } + + override def getPartitions: Array[Partition] = { + rdd.partitions.zipWithIndex.map { case (part, i) => + ButterflyReducedRDDPartition(i, part, rdd.partitions(targetPartition(i))) + } + } + + override def compute(s: Partition, context: TaskContext): Iterator[T] = { + val pair = s.asInstanceOf[ButterflyReducedRDDPartition] + Iterator((rdd.iterator(pair._1, context) ++ rdd.iterator(pair._2, context)).reduce(reducer)) + } + + override def getPreferredLocations(s: Partition): Seq[String] = { + rdd.preferredLocations(s.asInstanceOf[ButterflyReducedRDDPartition]._1) + } +} diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 891efccf23b6a..fbab07a7e4f1b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -765,6 +765,19 @@ abstract class RDD[T: ClassTag]( jobResult.getOrElse(throw new UnsupportedOperationException("empty collection")) } + def allReduce(f: (T, T) => T): RDD[T] = { + var butterfly = this.mapPartitions( (iter) => + Iterator(iter.reduce(f)), + preservesPartitioning = true + ).cache() + var offset = this.partitions.size / 2 + while (offset > 0) { + butterfly = new ButterflyReducedRDD[T](butterfly, offset, f).cache() + offset /= 2 + } + butterfly + } + /** * Aggregate the elements of each partition, and then the results for all the partitions, using a * given associative 
function and a neutral "zero value". The function op(t1, t2) is allowed to diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 1901330d8b188..e873d911b52e0 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -568,4 +568,19 @@ class RDDSuite extends FunSuite with SharedSparkContext { val ids = ranked.map(_._1).distinct().collect() assert(ids.length === n) } + + test("allReduce") { + val numPartitions = 2048 + val rdd = sc.parallelize(0 until numPartitions * 1000, numPartitions) + var start = System.nanoTime() + val sum = rdd.reduce(_ + _) + println((System.nanoTime() - start) / 1e9) + start = System.nanoTime() + val allReduced = rdd.allReduce(_ + _) + allReduced.count() + println((System.nanoTime() - start) / 1e9) + assert(allReduced.partitions.size === numPartitions) + assert(allReduced.collect().toSeq === Iterator.fill(numPartitions)(sum).toSeq) + fail("") + } } From d14300540da65900a2a19f8edc5adc5ce9c3e72d Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 22 Apr 2014 14:11:25 -0700 Subject: [PATCH 2/5] move allReduce to mllib --- .../main/scala/org/apache/spark/rdd/RDD.scala | 13 ----------- .../scala/org/apache/spark/rdd/RDDSuite.scala | 15 ------------- .../mllib}/rdd/ButterflyReducedRDD.scala | 3 ++- .../apache/spark/mllib/rdd/RDDFunctions.scala | 22 +++++++++++++++++++ .../spark/mllib/rdd/RDDFunctionsSuite.scala | 9 ++++++++ 5 files changed, 33 insertions(+), 29 deletions(-) rename {core/src/main/scala/org/apache/spark => mllib/src/main/scala/org/apache/spark/mllib}/rdd/ButterflyReducedRDD.scala (96%) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index fbab07a7e4f1b..891efccf23b6a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -765,19 +765,6 @@ abstract class RDD[T: ClassTag]( jobResult.getOrElse(throw new UnsupportedOperationException("empty collection")) } - def allReduce(f: (T, T) => T): RDD[T] = { - var butterfly = this.mapPartitions( (iter) => - Iterator(iter.reduce(f)), - preservesPartitioning = true - ).cache() - var offset = this.partitions.size / 2 - while (offset > 0) { - butterfly = new ButterflyReducedRDD[T](butterfly, offset, f).cache() - offset /= 2 - } - butterfly - } - /** * Aggregate the elements of each partition, and then the results for all the partitions, using a * given associative function and a neutral "zero value". 
The function op(t1, t2) is allowed to diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index e873d911b52e0..1901330d8b188 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -568,19 +568,4 @@ class RDDSuite extends FunSuite with SharedSparkContext { val ids = ranked.map(_._1).distinct().collect() assert(ids.length === n) } - - test("allReduce") { - val numPartitions = 2048 - val rdd = sc.parallelize(0 until numPartitions * 1000, numPartitions) - var start = System.nanoTime() - val sum = rdd.reduce(_ + _) - println((System.nanoTime() - start) / 1e9) - start = System.nanoTime() - val allReduced = rdd.allReduce(_ + _) - allReduced.count() - println((System.nanoTime() - start) / 1e9) - assert(allReduced.partitions.size === numPartitions) - assert(allReduced.collect().toSeq === Iterator.fill(numPartitions)(sum).toSeq) - fail("") - } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ButterflyReducedRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/ButterflyReducedRDD.scala similarity index 96% rename from core/src/main/scala/org/apache/spark/rdd/ButterflyReducedRDD.scala rename to mllib/src/main/scala/org/apache/spark/mllib/rdd/ButterflyReducedRDD.scala index a5274920a9480..a2051369931e9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ButterflyReducedRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/ButterflyReducedRDD.scala @@ -15,11 +15,12 @@ * limitations under the License. */ -package org.apache.spark.rdd +package org.apache.spark.mllib.rdd import scala.reflect.ClassTag import org.apache.spark.{TaskContext, Partition} +import org.apache.spark.rdd.RDD private case class ButterflyReducedRDDPartition( override val index: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 365b5e75d7f75..b9e3c3888db4c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -44,6 +44,28 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { new SlidingRDD[T](self, windowSize) } } + + + /** + * Computes the all-reduced RDD of the parent RDD, which has the same number of partitions and + * locality information as its parent RDD. Each partition contains only one record, the same as + * calling `RDD#reduce` on its parent RDD. 
+ * + * @param f reducer + * @return all-reduced RDD + */ + def allReduce(f: (T, T) => T): RDD[T] = { + var butterfly = self.mapPartitions( (iter) => + Iterator(iter.reduce(f)), + preservesPartitioning = true + ).cache() + var offset = self.partitions.size / 2 + while (offset > 0) { + butterfly = new ButterflyReducedRDD[T](butterfly, offset, f).cache() + offset /= 2 + } + butterfly + } } private[mllib] diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 3f3b10dfff35e..f2a79b672e587 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -46,4 +46,13 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext { val expected = data.flatMap(x => x).sliding(3).toList assert(sliding.collect().toList === expected) } + + test("allReduce") { + val numPartitions = 16 + val rdd = sc.parallelize(0 until numPartitions * 1000, numPartitions) + val sum = rdd.reduce(_ + _) + val allReduced = rdd.allReduce(_ + _) + assert(allReduced.partitions.size === numPartitions) + assert(allReduced.collect().toSeq === Iterator.fill(numPartitions)(sum).toSeq) + } } From 98c329d5ac12ec341a9cf6355cd67ea031189a24 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Wed, 23 Apr 2014 02:44:51 -0700 Subject: [PATCH 3/5] allow arbitrary number of partitions --- .../spark/mllib/rdd/ButterflyReducedRDD.scala | 33 ++++++++------ .../spark/mllib/rdd/PartitionSlicingRDD.scala | 43 +++++++++++++++++++ .../apache/spark/mllib/rdd/RDDFunctions.scala | 38 +++++++++++++--- .../spark/mllib/rdd/RDDFunctionsSuite.scala | 20 ++++++--- 4 files changed, 108 insertions(+), 26 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/rdd/PartitionSlicingRDD.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/ButterflyReducedRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/ButterflyReducedRDD.scala index a2051369931e9..4f4cd303efc36 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/ButterflyReducedRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/ButterflyReducedRDD.scala @@ -22,34 +22,39 @@ import scala.reflect.ClassTag import org.apache.spark.{TaskContext, Partition} import org.apache.spark.rdd.RDD +/** A partition in a butterfly-reduced RDD. */ private case class ButterflyReducedRDDPartition( override val index: Int, - _1: Partition, - _2: Partition) extends Partition + source: Partition, + target: Partition) extends Partition -private[spark] class ButterflyReducedRDD[T: ClassTag]( - rdd: RDD[T], - offset: Int, - reducer: (T, T) => T) extends RDD[T](rdd) { - - val numPartitions = rdd.partitions.size - - private def targetPartition(i: Int): Int = { - (i + offset) % numPartitions +/** + * Butterfly-reduced RDD. + */ +private[mllib] class ButterflyReducedRDD[T: ClassTag]( + @transient rdd: RDD[T], + reducer: (T, T) => T, + @transient offset: Int) extends RDD[T](rdd) { + + /** Computes the target partition. 
*/ + private def targetPartition(i: Int): Partition = { + val j = (i + offset) % rdd.partitions.size + rdd.partitions(j) } override def getPartitions: Array[Partition] = { rdd.partitions.zipWithIndex.map { case (part, i) => - ButterflyReducedRDDPartition(i, part, rdd.partitions(targetPartition(i))) + ButterflyReducedRDDPartition(i, part, targetPartition(i)) } } override def compute(s: Partition, context: TaskContext): Iterator[T] = { val pair = s.asInstanceOf[ButterflyReducedRDDPartition] - Iterator((rdd.iterator(pair._1, context) ++ rdd.iterator(pair._2, context)).reduce(reducer)) + Iterator((firstParent[T].iterator(pair.source, context) ++ + firstParent[T].iterator(pair.target, context)).reduce(reducer)) } override def getPreferredLocations(s: Partition): Seq[String] = { - rdd.preferredLocations(s.asInstanceOf[ButterflyReducedRDDPartition]._1) + rdd.preferredLocations(s.asInstanceOf[ButterflyReducedRDDPartition].source) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/PartitionSlicingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/PartitionSlicingRDD.scala new file mode 100644 index 0000000000000..6bfc06a802bbf --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/PartitionSlicingRDD.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.rdd + +import scala.reflect.ClassTag + +import org.apache.spark.rdd.RDD +import org.apache.spark.{TaskContext, Partition} + +/** + * Represents an RDD obtained from partition slicing of its parent RDD. + */ +private[mllib] class PartitionSlicingRDD[T: ClassTag]( + @transient rdd: RDD[T], + @transient slice: Seq[Int]) extends RDD[T](rdd) { + + override def getPartitions: Array[Partition] = { + slice.map(i => rdd.partitions(i)).toArray + } + + override def compute(s: Partition, context: TaskContext): Iterator[T] = { + firstParent[T].iterator(s, context) + } + + override def getPreferredLocations(s: Partition): Seq[String] = { + rdd.preferredLocations(s) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index b9e3c3888db4c..11ffcdb632918 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -45,26 +45,52 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { } } + /** + * Returns an RDD with the specified slice of partitions. + */ + def slicePartitions(slice: Seq[Int]): RDD[T] = { + new PartitionSlicingRDD(self, slice) + } /** * Computes the all-reduced RDD of the parent RDD, which has the same number of partitions and - * locality information as its parent RDD. 
Each partition contains only one record, the same as - * calling `RDD#reduce` on its parent RDD. + * locality information as its parent RDD. Each partition contains only one record, which is the + * same as calling `RDD#reduce` on its parent RDD. * * @param f reducer * @return all-reduced RDD */ def allReduce(f: (T, T) => T): RDD[T] = { + val numPartitions = self.partitions.size + require(numPartitions > 0, "Parent RDD does not have any partitions.") + val nextPowerOfTwo = { + var i = 0 + while ((numPartitions >> i) > 0) { + i += 1 + } + 1 << i + } var butterfly = self.mapPartitions( (iter) => Iterator(iter.reduce(f)), preservesPartitioning = true ).cache() - var offset = self.partitions.size / 2 + + if (nextPowerOfTwo > numPartitions) { + val padding = self.context.parallelize(Seq.empty[T], nextPowerOfTwo - numPartitions) + butterfly = butterfly.union(padding) + } + + var offset = nextPowerOfTwo >> 1 while (offset > 0) { - butterfly = new ButterflyReducedRDD[T](butterfly, offset, f).cache() - offset /= 2 + butterfly = new ButterflyReducedRDD[T](butterfly, f, offset).cache() + offset >>= 1 + } + + if (nextPowerOfTwo > numPartitions) { + new PartitionSlicingRDD(butterfly, 0 until numPartitions) + } else { + butterfly } - butterfly } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index f2a79b672e587..b39c64f0e39c3 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -47,12 +47,20 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext { assert(sliding.collect().toList === expected) } + test("slicePartitions") { + val rdd = sc.parallelize(0 until 10, 10) + val slice = Seq(0, 2, 4, 7) + val sliced = rdd.slicePartitions(slice) + assert(sliced.collect().toSeq === slice) + } + test("allReduce") { - val numPartitions = 16 - val rdd = sc.parallelize(0 until numPartitions * 1000, numPartitions) - val sum = rdd.reduce(_ + _) - val allReduced = rdd.allReduce(_ + _) - assert(allReduced.partitions.size === numPartitions) - assert(allReduced.collect().toSeq === Iterator.fill(numPartitions)(sum).toSeq) + for (numPartitions <- 1 to 10) { + val rdd = sc.parallelize(0 until 1000, numPartitions) + val sum = rdd.reduce(_ + _) + val allReduced = rdd.allReduce(_ + _) + assert(allReduced.partitions.size === numPartitions) + assert(allReduced.collect().toSeq === Iterator.fill(numPartitions)(sum).toSeq) + } } } From 49b42cb99ee25f0000615564e5e24afbf82930b9 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Fri, 25 Apr 2014 01:07:45 -0700 Subject: [PATCH 4/5] use PartitionPruningRDD --- .../spark/mllib/rdd/PartitionSlicingRDD.scala | 43 ------------------- .../apache/spark/mllib/rdd/RDDFunctions.scala | 11 +---- .../spark/mllib/rdd/RDDFunctionsSuite.scala | 7 --- 3 files changed, 2 insertions(+), 59 deletions(-) delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/rdd/PartitionSlicingRDD.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/PartitionSlicingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/PartitionSlicingRDD.scala deleted file mode 100644 index 6bfc06a802bbf..0000000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/PartitionSlicingRDD.scala +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.rdd - -import scala.reflect.ClassTag - -import org.apache.spark.rdd.RDD -import org.apache.spark.{TaskContext, Partition} - -/** - * Represents an RDD obtained from partition slicing of its parent RDD. - */ -private[mllib] class PartitionSlicingRDD[T: ClassTag]( - @transient rdd: RDD[T], - @transient slice: Seq[Int]) extends RDD[T](rdd) { - - override def getPartitions: Array[Partition] = { - slice.map(i => rdd.partitions(i)).toArray - } - - override def compute(s: Partition, context: TaskContext): Iterator[T] = { - firstParent[T].iterator(s, context) - } - - override def getPreferredLocations(s: Partition): Seq[String] = { - rdd.preferredLocations(s) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 11ffcdb632918..0d4cb85b63c66 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.rdd import scala.language.implicitConversions import scala.reflect.ClassTag -import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.{PartitionPruningRDD, RDD} /** * Machine learning specific RDD functions. @@ -45,13 +45,6 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { } } - /** - * Returns an RDD with the specified slice of partitions. - */ - def slicePartitions(slice: Seq[Int]): RDD[T] = { - new PartitionSlicingRDD(self, slice) - } - /** * Computes the all-reduced RDD of the parent RDD, which has the same number of partitions and * locality information as its parent RDD. 
Each partition contains only one record, which is the @@ -87,7 +80,7 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { } if (nextPowerOfTwo > numPartitions) { - new PartitionSlicingRDD(butterfly, 0 until numPartitions) + PartitionPruningRDD.create(butterfly, (i) => i < numPartitions) } else { butterfly } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index b39c64f0e39c3..697f0773c0fe9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -47,13 +47,6 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext { assert(sliding.collect().toList === expected) } - test("slicePartitions") { - val rdd = sc.parallelize(0 until 10, 10) - val slice = Seq(0, 2, 4, 7) - val sliced = rdd.slicePartitions(slice) - assert(sliced.collect().toSeq === slice) - } - test("allReduce") { for (numPartitions <- 1 to 10) { val rdd = sc.parallelize(0 until 1000, numPartitions) From 97b5588d7ac207274f3d11bbf6a40e3f3af47aa9 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Sun, 4 May 2014 22:43:19 -0700 Subject: [PATCH 5/5] add binaryTreeReduce --- .../mllib/rdd/BinaryTreeReducedRDD.scala | 71 +++++++++++++++++++ .../apache/spark/mllib/rdd/RDDFunctions.scala | 18 +++++ .../spark/mllib/rdd/RDDFunctionsSuite.scala | 10 +++ 3 files changed, 99 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/rdd/BinaryTreeReducedRDD.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/BinaryTreeReducedRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/BinaryTreeReducedRDD.scala new file mode 100644 index 0000000000000..58544b3428323 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/BinaryTreeReducedRDD.scala @@ -0,0 +1,71 @@ +package org.apache.spark.mllib.rdd + +import org.apache.spark.{TaskContext, Partition, NarrowDependency} + +import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD + +/** + * Represents a binary tree dependency, where partition `i` depends on partitions `2 * i` and + * `2 * i + 1` (if it exists) of the parent RDD. 
+ * @param rdd parent RDD + * @tparam T value type + */ +private class BinaryTreeDependency[T](@transient rdd: RDD[T]) extends NarrowDependency(rdd) { + + val n = rdd.partitions.size + + override def getParents(partitionId: Int): Seq[Int] = { + val i1 = 2 * partitionId + val i2 = i1 + 1 + if (i2 < n) { + Seq(i1, i2) + } else { + Seq(i1) + } + } +} + +private class BinaryTreeNodePartition( + override val index: Int, + val left: Partition, + val right: Option[Partition]) extends Partition { +} + +private object BinaryTreeNodePartition { + def apply(rdd: RDD[_], i: Int): Partition = { + val n = rdd.partitions.size + val i1 = 2 * i + val i2 = i1 + 1 + if (i2 < n) { + new BinaryTreeNodePartition(i, rdd.partitions(i1), Some(rdd.partitions(i2))) + } else { + new BinaryTreeNodePartition(i, rdd.partitions(i1), None) + } + } +} + +private[mllib] class BinaryTreeReducedRDD[T: ClassTag](rdd: RDD[T], f: (T, T) => T) + extends RDD[T](rdd.context, List(new BinaryTreeDependency(rdd))) { + + override protected def getPartitions: Array[Partition] = { + Array.tabulate((rdd.partitions.size + 1) / 2)(i => BinaryTreeNodePartition(rdd, i)) + } + + override def compute(split: Partition, context: TaskContext): Iterator[T] = { + val p = split.asInstanceOf[BinaryTreeNodePartition] + val iterLeft = rdd.compute(p.left, context) + val iterRight = if (p.right.isDefined) rdd.compute(p.right.get, context) else Iterator.empty + val iter = iterLeft ++ iterRight + if (iter.isEmpty) { + Iterator.empty + } else { + Iterator(iter.reduce(f)) + } + } + + override protected def getPreferredLocations(split: Partition): Seq[String] = { + val p = split.asInstanceOf[BinaryTreeNodePartition] + rdd.preferredLocations(p.left) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 0d4cb85b63c66..5de5da7c4b663 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -85,6 +85,24 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) { butterfly } } + + /** + * Reduce the elements of this RDD using the binary tree algorithm. + */ + def binaryTreeReduce(f: (T, T) => T): T = { + var reduced = self.mapPartitions( (iter) => + if (iter.isEmpty) { + Iterator.empty + } else { + Iterator(iter.reduce(f)) + }, + preservesPartitioning = true + ) + while (reduced.partitions.size > 3) { + reduced = new BinaryTreeReducedRDD(reduced, f) + } + reduced.reduce(f) + } } private[mllib] diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 697f0773c0fe9..3900bc0812770 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -56,4 +56,14 @@ class RDDFunctionsSuite extends FunSuite with LocalSparkContext { assert(allReduced.collect().toSeq === Iterator.fill(numPartitions)(sum).toSeq) } } + + test("binaryTreeReduce") { + val data = 0 until 5 + val expected = data.reduce(_ + _) + for (numPartitions <- 1 to 12) { + val rdd = sc.parallelize(data, numPartitions) + val actual = rdd.binaryTreeReduce(_ + _) + assert(actual === expected) + } + } }
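
For reference, a minimal usage sketch of the two reduction primitives added in this series. This is an illustrative sketch, not part of the patches themselves: it assumes the implicit conversion that accompanies RDDFunctions is in scope (as the test suites use it) and that `sc` is an existing SparkContext. With a power-of-two partition count, allReduce runs log2(numPartitions) butterfly stages; other counts are padded up to the next power of two and pruned back afterwards, as introduced in patch 3/5.

    import org.apache.spark.mllib.rdd.RDDFunctions._

    val rdd = sc.parallelize(0 until 16000, 16)

    // allReduce: an RDD with the same number of partitions as `rdd`, where every
    // partition holds a single record equal to rdd.reduce(_ + _).
    val sums = rdd.allReduce(_ + _)
    val perPartition = sums.collect()   // 16 copies of the global sum

    // binaryTreeReduce: a single value on the driver, combining partial results
    // pairwise on the cluster before a final reduce of the few remaining
    // partitions on the driver.
    val total = rdd.binaryTreeReduce(_ + _)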