From 16d6af8ef3eb20e88452db3cca1cdc307697b1c0 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Tue, 3 Jun 2014 18:50:35 -0700 Subject: [PATCH 1/6] [SPARK-1552] Fix type comparison bug in mapVertices and outerJoinVertices In GraphImpl, mapVertices and outerJoinVertices use a more efficient implementation when the map function conserves vertex attribute types. This is implemented by comparing the ClassTags of the old and new vertex attribute types. However, ClassTags store erased types, so the comparison will return a false positive for types with different type parameters, such as Option[Int] and Option[Double]. This PR resolves the problem by unconditionally using the general implementation in mapVertices and outerJoinVertices, and introducing "Conserve" variants of these methods that enforce type equality and use the more efficient implementation. It also adds a test called "mapVertices changing type with same erased type" that failed before the PR and succeeds now. The "Conserve" naming comes from Scala's `List#mapConserve` method. --- docs/graphx-programming-guide.md | 17 +++++- .../scala/org/apache/spark/graphx/Graph.scala | 32 ++++++++++++ .../org/apache/spark/graphx/GraphOps.scala | 2 +- .../org/apache/spark/graphx/Pregel.scala | 4 +- .../apache/spark/graphx/impl/GraphImpl.scala | 50 +++++++++--------- .../apache/spark/graphx/lib/SVDPlusPlus.scala | 8 +-- .../lib/StronglyConnectedComponents.scala | 10 ++-- .../org/apache/spark/graphx/GraphSuite.scala | 52 ++++++++++++++++++- 8 files changed, 136 insertions(+), 39 deletions(-) diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index fdb9f98e214e5..7ba2830548953 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -320,6 +320,7 @@ class Graph[VD, ED] { def partitionBy(partitionStrategy: PartitionStrategy): Graph[VD, ED] // Transform vertex and edge attributes ========================================================== def mapVertices[VD2](map: (VertexID, VD) => VD2): Graph[VD2, ED] + def mapVerticesConserve(map: (VertexID, VD) => VD): Graph[VD, ED] def mapEdges[ED2](map: Edge[ED] => ED2): Graph[VD, ED2] def mapEdges[ED2](map: (PartitionID, Iterator[Edge[ED]]) => Iterator[ED2]): Graph[VD, ED2] def mapTriplets[ED2](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] @@ -338,6 +339,9 @@ class Graph[VD, ED] { def outerJoinVertices[U, VD2](other: RDD[(VertexID, U)]) (mapFunc: (VertexID, VD, Option[U]) => VD2) : Graph[VD2, ED] + def outerJoinVerticesConserve[U](other: RDD[(VertexID, U)]) + (mapFunc: (VertexID, VD, Option[U]) => VD) + : Graph[VD, ED] // Aggregate information about adjacent triplets ================================================= def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexID]] def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexID, VD)]] @@ -369,6 +373,7 @@ graph contains the following: {% highlight scala %} class Graph[VD, ED] { def mapVertices[VD2](map: (VertexId, VD) => VD2): Graph[VD2, ED] + def mapVerticesConserve(map: (VertexId, VD) => VD): Graph[VD, ED] def mapEdges[ED2](map: Edge[ED] => ED2): Graph[VD, ED2] def mapTriplets[ED2](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] } @@ -392,6 +397,10 @@ val newGraph = graph.mapVertices((id, attr) => mapUdf(id, attr)) [Graph.mapVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@mapVertices[VD2]((VertexId,VD)⇒VD2)(ClassTag[VD2]):Graph[VD2,ED] +When a call to `mapVertices` would not change the vertex attribute type, use the +`mapVerticesConserve` operator for better performance. This version of the operator avoids moving +unchanged vertex attributes when updating the triplets view. + These operators are often used to initialize the graph for a particular computation or project away unnecessary properties. For example, given a graph with the out-degrees as the vertex properties (we describe how to construct such a graph later), we initialize it for PageRank: @@ -506,6 +515,8 @@ class Graph[VD, ED] { : Graph[VD, ED] def outerJoinVertices[U, VD2](table: RDD[(VertexId, U)])(map: (VertexId, VD, Option[U]) => VD2) : Graph[VD2, ED] + def outerJoinVerticesConserve[U](table: RDD[(VertexId, U)])(map: (VertexId, VD, Option[U]) => VD) + : Graph[VD, ED] } {% endhighlight %} @@ -533,6 +544,10 @@ property type. Because not all vertices may have a matching value in the input function takes an `Option` type. For example, we can setup a graph for PageRank by initializing vertex properties with their `outDegree`. +Similarly to `mapVerticesConserve`, when a call to `outerJoinVertices` would not change the vertex +attribute type, use the `outerJoinVerticesConserve` operator for better performance. This version of +the operator avoids moving unchanged vertex attributes when updating the triplets view. + [Graph.outerJoinVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@outerJoinVertices[U,VD2](RDD[(VertexId,U)])((VertexId,VD,Option[U])⇒VD2)(ClassTag[U],ClassTag[VD2]):Graph[VD2,ED] @@ -748,7 +763,7 @@ class GraphOps[VD, ED] { // Run the vertex program on all vertices that receive messages val newVerts = g.vertices.innerJoin(messages)(vprog).cache() // Merge the new vertex values back into the graph - g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }.cache() + g = g.outerJoinVerticesConserve(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }.cache() // Send Messages: ------------------------------------------------------------------------------ // Vertices that didn't receive a message above don't appear in newVerts and therefore don't // get to send messages. More precisely the map phase of mapReduceTriplets is only invoked diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index c4f9d6514cae3..a0112eea3c0de 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -130,6 +130,19 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab */ def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2): Graph[VD2, ED] + /** + * Transforms each vertex attribute in the graph using the map function. Like [[mapVertices]], but + * since the type is conserved, is able to avoid moving unchanged vertex attributes when updating + * the triplets view. + * + * @note The new graph has the same structure. As a consequence the underlying index structures + * can be reused. + * + * @param map the function from a vertex object to a new vertex value of the same type + * + */ + def mapVerticesConserve(map: (VertexId, VD) => VD): Graph[VD, ED] + /** * Transforms each edge attribute in the graph using the map function. The map function is not * passed the vertex value for the vertices adjacent to the edge. If vertex values are desired, @@ -341,6 +354,25 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab (mapFunc: (VertexId, VD, Option[U]) => VD2) : Graph[VD2, ED] + /** + * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`. Like + * [[outerJoinVertices]], but since the type is conserved, is able to avoid moving unchanged + * vertex attributes when updating the triplets view. + * + * The input table should contain at most one entry for each vertex. If no entry in `other` is + * provided for a particular vertex in the graph, the map function receives `None`. + * + * @tparam U the type of entry in the table of updates + * + * @param other the table to join with the vertices in the graph. + * The table should contain at most one entry for each vertex. + * @param mapFunc the function used to compute the new vertex values. The map function is invoked + * for all vertices, even those that do not have a corresponding entry in the table. It must + * conserve the original vertex attribute type. + */ + def outerJoinVerticesConserve[U: ClassTag](other: RDD[(VertexId, U)]) + (mapFunc: (VertexId, VD, Option[U]) => VD): Graph[VD, ED] + /** * The associated [[GraphOps]] object. */ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index edd5b79da1522..eae85fb6dcf1d 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -213,7 +213,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali case None => data } } - graph.outerJoinVertices(table)(uf) + graph.outerJoinVerticesConserve(table)(uf) } /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 4572eab2875bb..7d28c33d2adf6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -119,7 +119,7 @@ object Pregel extends Logging { mergeMsg: (A, A) => A) : Graph[VD, ED] = { - var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() + var g = graph.mapVerticesConserve((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() // compute the messages var messages = g.mapReduceTriplets(sendMsg, mergeMsg) var activeMessages = messages.count() @@ -131,7 +131,7 @@ object Pregel extends Logging { val newVerts = g.vertices.innerJoin(messages)(vprog).cache() // Update the graph with the new vertices. prevG = g - g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) } + g = g.outerJoinVerticesConserve(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) } g.cache() val oldMessages = messages diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 59d9a8808e56e..72de1c37ca1e7 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -101,18 +101,17 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( } override def mapVertices[VD2: ClassTag](f: (VertexId, VD) => VD2): Graph[VD2, ED] = { - if (classTag[VD] equals classTag[VD2]) { - vertices.cache() - // The map preserves type, so we can use incremental replication - val newVerts = vertices.mapVertexPartitions(_.map(f)).cache() - val changedVerts = vertices.asInstanceOf[VertexRDD[VD2]].diff(newVerts) - val newReplicatedVertexView = replicatedVertexView.asInstanceOf[ReplicatedVertexView[VD2, ED]] - .updateVertices(changedVerts) - new GraphImpl(newVerts, newReplicatedVertexView) - } else { - // The map does not preserve type, so we must re-replicate all vertices - GraphImpl(vertices.mapVertexPartitions(_.map(f)), replicatedVertexView.edges) - } + // The map does not conserve type, so we must re-replicate all vertices + GraphImpl(vertices.mapVertexPartitions(_.map(f)), replicatedVertexView.edges) + } + + override def mapVerticesConserve(f: (VertexId, VD) => VD): Graph[VD, ED] = { + vertices.cache() + // The map conserves type, so we can use incremental replication + val newVerts = vertices.mapVertexPartitions(_.map(f)).cache() + val changedVerts = vertices.diff(newVerts) + val newReplicatedVertexView = replicatedVertexView.updateVertices(changedVerts) + new GraphImpl(newVerts, newReplicatedVertexView) } override def mapEdges[ED2: ClassTag]( @@ -229,19 +228,20 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( override def outerJoinVertices[U: ClassTag, VD2: ClassTag] (other: RDD[(VertexId, U)]) (updateF: (VertexId, VD, Option[U]) => VD2): Graph[VD2, ED] = { - if (classTag[VD] equals classTag[VD2]) { - vertices.cache() - // updateF preserves type, so we can use incremental replication - val newVerts = vertices.leftJoin(other)(updateF).cache() - val changedVerts = vertices.asInstanceOf[VertexRDD[VD2]].diff(newVerts) - val newReplicatedVertexView = replicatedVertexView.asInstanceOf[ReplicatedVertexView[VD2, ED]] - .updateVertices(changedVerts) - new GraphImpl(newVerts, newReplicatedVertexView) - } else { - // updateF does not preserve type, so we must re-replicate all vertices - val newVerts = vertices.leftJoin(other)(updateF) - GraphImpl(newVerts, replicatedVertexView.edges) - } + // updateF does not conserve type, so we must re-replicate all vertices + val newVerts = vertices.leftJoin(other)(updateF) + GraphImpl(newVerts, replicatedVertexView.edges) + } + + override def outerJoinVerticesConserve[U: ClassTag] + (other: RDD[(VertexId, U)]) + (updateF: (VertexId, VD, Option[U]) => VD): Graph[VD, ED] = { + vertices.cache() + // updateF conserves type, so we can use incremental replication + val newVerts = vertices.leftJoin(other)(updateF).cache() + val changedVerts = vertices.diff(newVerts) + val newReplicatedVertexView = replicatedVertexView.updateVertices(changedVerts) + new GraphImpl(newVerts, newReplicatedVertexView) } /** Test whether the closure accesses the the attribute with name `attrName`. */ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index ccd7de537b6e3..65a06770fee37 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -78,7 +78,7 @@ object SVDPlusPlus { et => Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr))), (g1: (Long, Double), g2: (Long, Double)) => (g1._1 + g2._1, g1._2 + g2._2)) - g = g.outerJoinVertices(t0) { + g = g.outerJoinVerticesConserve(t0) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[(Long, Double)]) => (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) @@ -112,7 +112,7 @@ object SVDPlusPlus { val t1 = g.mapReduceTriplets( et => Iterator((et.srcId, et.dstAttr._2)), (g1: DoubleMatrix, g2: DoubleMatrix) => g1.addColumnVector(g2)) - g = g.outerJoinVertices(t1) { + g = g.outerJoinVerticesConserve(t1) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[DoubleMatrix]) => if (msg.isDefined) (vd._1, vd._1 @@ -125,7 +125,7 @@ object SVDPlusPlus { mapTrainF(conf, u), (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) => (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3)) - g = g.outerJoinVertices(t2) { + g = g.outerJoinVerticesConserve(t2) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[(DoubleMatrix, DoubleMatrix, Double)]) => @@ -149,7 +149,7 @@ object SVDPlusPlus { } g.cache() val t3 = g.mapReduceTriplets(mapTestF(conf, u), (g1: Double, g2: Double) => g1 + g2) - g = g.outerJoinVertices(t3) { + g = g.outerJoinVerticesConserve(t3) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) => if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala index 46da38eeb725a..a6ae0883d1d3f 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala @@ -48,9 +48,9 @@ object StronglyConnectedComponents { iter += 1 do { numVertices = sccWorkGraph.numVertices - sccWorkGraph = sccWorkGraph.outerJoinVertices(sccWorkGraph.outDegrees) { + sccWorkGraph = sccWorkGraph.outerJoinVerticesConserve(sccWorkGraph.outDegrees) { (vid, data, degreeOpt) => if (degreeOpt.isDefined) data else (vid, true) - }.outerJoinVertices(sccWorkGraph.inDegrees) { + }.outerJoinVerticesConserve(sccWorkGraph.inDegrees) { (vid, data, degreeOpt) => if (degreeOpt.isDefined) data else (vid, true) }.cache() @@ -60,14 +60,16 @@ object StronglyConnectedComponents { .mapValues { (vid, data) => data._1} // write values to sccGraph - sccGraph = sccGraph.outerJoinVertices(finalVertices) { + sccGraph = sccGraph.outerJoinVerticesConserve(finalVertices) { (vid, scc, opt) => opt.getOrElse(scc) } // only keep vertices that are not final sccWorkGraph = sccWorkGraph.subgraph(vpred = (vid, data) => !data._2).cache() } while (sccWorkGraph.numVertices < numVertices) - sccWorkGraph = sccWorkGraph.mapVertices{ case (vid, (color, isFinal)) => (vid, isFinal) } + sccWorkGraph = sccWorkGraph.mapVerticesConserve { + case (vid, (color, isFinal)) => (vid, isFinal) + } // collect min of all my neighbor's scc values, update if it's smaller than mine // then notify any neighbors with scc values larger than mine diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index abc25d0671133..4d9dfaf5d92d2 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -150,7 +150,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val star = starGraph(sc, n) - // mapVertices preserving type + // mapVertices conserving type val mappedVAttrs = star.mapVertices((vid, attr) => attr + "2") assert(mappedVAttrs.vertices.collect.toSet === (0 to n).map(x => (x: VertexId, "v2")).toSet) // mapVertices changing type @@ -159,6 +159,40 @@ class GraphSuite extends FunSuite with LocalSparkContext { } } + test("mapVertices changing type with same erased type") { + withSpark { sc => + val vertices = sc.parallelize(Array[(Long, Option[java.lang.Integer])]( + (1L, Some(1)), + (2L, Some(2)), + (3L, Some(3)) + )) + val edges = sc.parallelize(Array( + Edge(1L, 2L, 0), + Edge(2L, 3L, 0), + Edge(3L, 1L, 0) + )) + val graph0 = Graph(vertices, edges) + // Trigger initial vertex replication + graph0.triplets.foreach(x => {}) + // Change type of replicated vertices, but conserve erased type + val graph1 = graph0.mapVertices { + case (vid, integerOpt) => integerOpt.map((x: java.lang.Integer) => (x.toDouble): java.lang.Double) + } + // Access replicated vertices, exposing the erased type + val graph2 = graph1.mapTriplets(t => t.srcAttr.get) + assert(graph2.edges.map(_.attr).collect.toSet === Set[java.lang.Double](1.0, 2.0, 3.0)) + } + } + + test("mapVerticesConserve") { + withSpark { sc => + val n = 5 + val star = starGraph(sc, n) + val mappedVAttrs = star.mapVerticesConserve((vid, attr) => attr + "2") + assert(mappedVAttrs.vertices.collect.toSet === (0 to n).map(x => (x: VertexId, "v2")).toSet) + } + } + test("mapEdges") { withSpark { sc => val n = 3 @@ -297,6 +331,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val reverseStar = starGraph(sc, n).reverse.cache() + // outerJoinVertices changing type val reverseStarDegrees = reverseStar.outerJoinVertices(reverseStar.outDegrees) { (vid, a, bOpt) => bOpt.getOrElse(0) } @@ -304,7 +339,8 @@ class GraphSuite extends FunSuite with LocalSparkContext { et => Iterator((et.srcId, et.dstAttr), (et.dstId, et.srcAttr)), (a: Int, b: Int) => a + b).collect.toSet assert(neighborDegreeSums === Set((0: VertexId, n)) ++ (1 to n).map(x => (x: VertexId, 0))) - // outerJoinVertices preserving type + + // outerJoinVertices conserving type val messages = reverseStar.vertices.mapValues { (vid, attr) => vid.toString } val newReverseStar = reverseStar.outerJoinVertices(messages) { (vid, a, bOpt) => a + bOpt.getOrElse("") } @@ -313,6 +349,18 @@ class GraphSuite extends FunSuite with LocalSparkContext { } } + test("outerJoinVerticesConserve") { + withSpark { sc => + val n = 5 + val reverseStar = starGraph(sc, n).reverse.cache() + val messages = reverseStar.vertices.mapValues { (vid, attr) => vid.toString } + val newReverseStar = + reverseStar.outerJoinVerticesConserve(messages) { (vid, a, bOpt) => a + bOpt.getOrElse("") } + assert(newReverseStar.vertices.map(_._2).collect.toSet === + (0 to n).map(x => "v%d".format(x)).toSet) + } + } + test("more edge partitions than vertex partitions") { withSpark { sc => val verts = sc.parallelize(List((1: VertexId, "a"), (2: VertexId, "b")), 1) From f458c8308f74d4cee7a1d29c5a0e23f00a0a89b1 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 4 Jun 2014 13:33:34 -0700 Subject: [PATCH 2/6] Revert "[SPARK-1552] Fix type comparison bug in mapVertices and outerJoinVertices" This reverts commit 16d6af8ef3eb20e88452db3cca1cdc307697b1c0. --- docs/graphx-programming-guide.md | 17 +----- .../scala/org/apache/spark/graphx/Graph.scala | 32 ------------ .../org/apache/spark/graphx/GraphOps.scala | 2 +- .../org/apache/spark/graphx/Pregel.scala | 4 +- .../apache/spark/graphx/impl/GraphImpl.scala | 50 +++++++++--------- .../apache/spark/graphx/lib/SVDPlusPlus.scala | 8 +-- .../lib/StronglyConnectedComponents.scala | 10 ++-- .../org/apache/spark/graphx/GraphSuite.scala | 52 +------------------ 8 files changed, 39 insertions(+), 136 deletions(-) diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index 7ba2830548953..fdb9f98e214e5 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -320,7 +320,6 @@ class Graph[VD, ED] { def partitionBy(partitionStrategy: PartitionStrategy): Graph[VD, ED] // Transform vertex and edge attributes ========================================================== def mapVertices[VD2](map: (VertexID, VD) => VD2): Graph[VD2, ED] - def mapVerticesConserve(map: (VertexID, VD) => VD): Graph[VD, ED] def mapEdges[ED2](map: Edge[ED] => ED2): Graph[VD, ED2] def mapEdges[ED2](map: (PartitionID, Iterator[Edge[ED]]) => Iterator[ED2]): Graph[VD, ED2] def mapTriplets[ED2](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] @@ -339,9 +338,6 @@ class Graph[VD, ED] { def outerJoinVertices[U, VD2](other: RDD[(VertexID, U)]) (mapFunc: (VertexID, VD, Option[U]) => VD2) : Graph[VD2, ED] - def outerJoinVerticesConserve[U](other: RDD[(VertexID, U)]) - (mapFunc: (VertexID, VD, Option[U]) => VD) - : Graph[VD, ED] // Aggregate information about adjacent triplets ================================================= def collectNeighborIds(edgeDirection: EdgeDirection): VertexRDD[Array[VertexID]] def collectNeighbors(edgeDirection: EdgeDirection): VertexRDD[Array[(VertexID, VD)]] @@ -373,7 +369,6 @@ graph contains the following: {% highlight scala %} class Graph[VD, ED] { def mapVertices[VD2](map: (VertexId, VD) => VD2): Graph[VD2, ED] - def mapVerticesConserve(map: (VertexId, VD) => VD): Graph[VD, ED] def mapEdges[ED2](map: Edge[ED] => ED2): Graph[VD, ED2] def mapTriplets[ED2](map: EdgeTriplet[VD, ED] => ED2): Graph[VD, ED2] } @@ -397,10 +392,6 @@ val newGraph = graph.mapVertices((id, attr) => mapUdf(id, attr)) [Graph.mapVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@mapVertices[VD2]((VertexId,VD)⇒VD2)(ClassTag[VD2]):Graph[VD2,ED] -When a call to `mapVertices` would not change the vertex attribute type, use the -`mapVerticesConserve` operator for better performance. This version of the operator avoids moving -unchanged vertex attributes when updating the triplets view. - These operators are often used to initialize the graph for a particular computation or project away unnecessary properties. For example, given a graph with the out-degrees as the vertex properties (we describe how to construct such a graph later), we initialize it for PageRank: @@ -515,8 +506,6 @@ class Graph[VD, ED] { : Graph[VD, ED] def outerJoinVertices[U, VD2](table: RDD[(VertexId, U)])(map: (VertexId, VD, Option[U]) => VD2) : Graph[VD2, ED] - def outerJoinVerticesConserve[U](table: RDD[(VertexId, U)])(map: (VertexId, VD, Option[U]) => VD) - : Graph[VD, ED] } {% endhighlight %} @@ -544,10 +533,6 @@ property type. Because not all vertices may have a matching value in the input function takes an `Option` type. For example, we can setup a graph for PageRank by initializing vertex properties with their `outDegree`. -Similarly to `mapVerticesConserve`, when a call to `outerJoinVertices` would not change the vertex -attribute type, use the `outerJoinVerticesConserve` operator for better performance. This version of -the operator avoids moving unchanged vertex attributes when updating the triplets view. - [Graph.outerJoinVertices]: api/scala/index.html#org.apache.spark.graphx.Graph@outerJoinVertices[U,VD2](RDD[(VertexId,U)])((VertexId,VD,Option[U])⇒VD2)(ClassTag[U],ClassTag[VD2]):Graph[VD2,ED] @@ -763,7 +748,7 @@ class GraphOps[VD, ED] { // Run the vertex program on all vertices that receive messages val newVerts = g.vertices.innerJoin(messages)(vprog).cache() // Merge the new vertex values back into the graph - g = g.outerJoinVerticesConserve(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }.cache() + g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) }.cache() // Send Messages: ------------------------------------------------------------------------------ // Vertices that didn't receive a message above don't appear in newVerts and therefore don't // get to send messages. More precisely the map phase of mapReduceTriplets is only invoked diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index a0112eea3c0de..c4f9d6514cae3 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -130,19 +130,6 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab */ def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2): Graph[VD2, ED] - /** - * Transforms each vertex attribute in the graph using the map function. Like [[mapVertices]], but - * since the type is conserved, is able to avoid moving unchanged vertex attributes when updating - * the triplets view. - * - * @note The new graph has the same structure. As a consequence the underlying index structures - * can be reused. - * - * @param map the function from a vertex object to a new vertex value of the same type - * - */ - def mapVerticesConserve(map: (VertexId, VD) => VD): Graph[VD, ED] - /** * Transforms each edge attribute in the graph using the map function. The map function is not * passed the vertex value for the vertices adjacent to the edge. If vertex values are desired, @@ -354,25 +341,6 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab (mapFunc: (VertexId, VD, Option[U]) => VD2) : Graph[VD2, ED] - /** - * Joins the vertices with entries in the `table` RDD and merges the results using `mapFunc`. Like - * [[outerJoinVertices]], but since the type is conserved, is able to avoid moving unchanged - * vertex attributes when updating the triplets view. - * - * The input table should contain at most one entry for each vertex. If no entry in `other` is - * provided for a particular vertex in the graph, the map function receives `None`. - * - * @tparam U the type of entry in the table of updates - * - * @param other the table to join with the vertices in the graph. - * The table should contain at most one entry for each vertex. - * @param mapFunc the function used to compute the new vertex values. The map function is invoked - * for all vertices, even those that do not have a corresponding entry in the table. It must - * conserve the original vertex attribute type. - */ - def outerJoinVerticesConserve[U: ClassTag](other: RDD[(VertexId, U)]) - (mapFunc: (VertexId, VD, Option[U]) => VD): Graph[VD, ED] - /** * The associated [[GraphOps]] object. */ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index eae85fb6dcf1d..edd5b79da1522 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -213,7 +213,7 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali case None => data } } - graph.outerJoinVerticesConserve(table)(uf) + graph.outerJoinVertices(table)(uf) } /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 7d28c33d2adf6..4572eab2875bb 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -119,7 +119,7 @@ object Pregel extends Logging { mergeMsg: (A, A) => A) : Graph[VD, ED] = { - var g = graph.mapVerticesConserve((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() + var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() // compute the messages var messages = g.mapReduceTriplets(sendMsg, mergeMsg) var activeMessages = messages.count() @@ -131,7 +131,7 @@ object Pregel extends Logging { val newVerts = g.vertices.innerJoin(messages)(vprog).cache() // Update the graph with the new vertices. prevG = g - g = g.outerJoinVerticesConserve(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) } + g = g.outerJoinVertices(newVerts) { (vid, old, newOpt) => newOpt.getOrElse(old) } g.cache() val oldMessages = messages diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 72de1c37ca1e7..59d9a8808e56e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -101,17 +101,18 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( } override def mapVertices[VD2: ClassTag](f: (VertexId, VD) => VD2): Graph[VD2, ED] = { - // The map does not conserve type, so we must re-replicate all vertices - GraphImpl(vertices.mapVertexPartitions(_.map(f)), replicatedVertexView.edges) - } - - override def mapVerticesConserve(f: (VertexId, VD) => VD): Graph[VD, ED] = { - vertices.cache() - // The map conserves type, so we can use incremental replication - val newVerts = vertices.mapVertexPartitions(_.map(f)).cache() - val changedVerts = vertices.diff(newVerts) - val newReplicatedVertexView = replicatedVertexView.updateVertices(changedVerts) - new GraphImpl(newVerts, newReplicatedVertexView) + if (classTag[VD] equals classTag[VD2]) { + vertices.cache() + // The map preserves type, so we can use incremental replication + val newVerts = vertices.mapVertexPartitions(_.map(f)).cache() + val changedVerts = vertices.asInstanceOf[VertexRDD[VD2]].diff(newVerts) + val newReplicatedVertexView = replicatedVertexView.asInstanceOf[ReplicatedVertexView[VD2, ED]] + .updateVertices(changedVerts) + new GraphImpl(newVerts, newReplicatedVertexView) + } else { + // The map does not preserve type, so we must re-replicate all vertices + GraphImpl(vertices.mapVertexPartitions(_.map(f)), replicatedVertexView.edges) + } } override def mapEdges[ED2: ClassTag]( @@ -228,20 +229,19 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( override def outerJoinVertices[U: ClassTag, VD2: ClassTag] (other: RDD[(VertexId, U)]) (updateF: (VertexId, VD, Option[U]) => VD2): Graph[VD2, ED] = { - // updateF does not conserve type, so we must re-replicate all vertices - val newVerts = vertices.leftJoin(other)(updateF) - GraphImpl(newVerts, replicatedVertexView.edges) - } - - override def outerJoinVerticesConserve[U: ClassTag] - (other: RDD[(VertexId, U)]) - (updateF: (VertexId, VD, Option[U]) => VD): Graph[VD, ED] = { - vertices.cache() - // updateF conserves type, so we can use incremental replication - val newVerts = vertices.leftJoin(other)(updateF).cache() - val changedVerts = vertices.diff(newVerts) - val newReplicatedVertexView = replicatedVertexView.updateVertices(changedVerts) - new GraphImpl(newVerts, newReplicatedVertexView) + if (classTag[VD] equals classTag[VD2]) { + vertices.cache() + // updateF preserves type, so we can use incremental replication + val newVerts = vertices.leftJoin(other)(updateF).cache() + val changedVerts = vertices.asInstanceOf[VertexRDD[VD2]].diff(newVerts) + val newReplicatedVertexView = replicatedVertexView.asInstanceOf[ReplicatedVertexView[VD2, ED]] + .updateVertices(changedVerts) + new GraphImpl(newVerts, newReplicatedVertexView) + } else { + // updateF does not preserve type, so we must re-replicate all vertices + val newVerts = vertices.leftJoin(other)(updateF) + GraphImpl(newVerts, replicatedVertexView.edges) + } } /** Test whether the closure accesses the the attribute with name `attrName`. */ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index 65a06770fee37..ccd7de537b6e3 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -78,7 +78,7 @@ object SVDPlusPlus { et => Iterator((et.srcId, (1L, et.attr)), (et.dstId, (1L, et.attr))), (g1: (Long, Double), g2: (Long, Double)) => (g1._1 + g2._1, g1._2 + g2._2)) - g = g.outerJoinVerticesConserve(t0) { + g = g.outerJoinVertices(t0) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[(Long, Double)]) => (vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1)) @@ -112,7 +112,7 @@ object SVDPlusPlus { val t1 = g.mapReduceTriplets( et => Iterator((et.srcId, et.dstAttr._2)), (g1: DoubleMatrix, g2: DoubleMatrix) => g1.addColumnVector(g2)) - g = g.outerJoinVerticesConserve(t1) { + g = g.outerJoinVertices(t1) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[DoubleMatrix]) => if (msg.isDefined) (vd._1, vd._1 @@ -125,7 +125,7 @@ object SVDPlusPlus { mapTrainF(conf, u), (g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) => (g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3)) - g = g.outerJoinVerticesConserve(t2) { + g = g.outerJoinVertices(t2) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[(DoubleMatrix, DoubleMatrix, Double)]) => @@ -149,7 +149,7 @@ object SVDPlusPlus { } g.cache() val t3 = g.mapReduceTriplets(mapTestF(conf, u), (g1: Double, g2: Double) => g1 + g2) - g = g.outerJoinVerticesConserve(t3) { + g = g.outerJoinVertices(t3) { (vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) => if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala index a6ae0883d1d3f..46da38eeb725a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/StronglyConnectedComponents.scala @@ -48,9 +48,9 @@ object StronglyConnectedComponents { iter += 1 do { numVertices = sccWorkGraph.numVertices - sccWorkGraph = sccWorkGraph.outerJoinVerticesConserve(sccWorkGraph.outDegrees) { + sccWorkGraph = sccWorkGraph.outerJoinVertices(sccWorkGraph.outDegrees) { (vid, data, degreeOpt) => if (degreeOpt.isDefined) data else (vid, true) - }.outerJoinVerticesConserve(sccWorkGraph.inDegrees) { + }.outerJoinVertices(sccWorkGraph.inDegrees) { (vid, data, degreeOpt) => if (degreeOpt.isDefined) data else (vid, true) }.cache() @@ -60,16 +60,14 @@ object StronglyConnectedComponents { .mapValues { (vid, data) => data._1} // write values to sccGraph - sccGraph = sccGraph.outerJoinVerticesConserve(finalVertices) { + sccGraph = sccGraph.outerJoinVertices(finalVertices) { (vid, scc, opt) => opt.getOrElse(scc) } // only keep vertices that are not final sccWorkGraph = sccWorkGraph.subgraph(vpred = (vid, data) => !data._2).cache() } while (sccWorkGraph.numVertices < numVertices) - sccWorkGraph = sccWorkGraph.mapVerticesConserve { - case (vid, (color, isFinal)) => (vid, isFinal) - } + sccWorkGraph = sccWorkGraph.mapVertices{ case (vid, (color, isFinal)) => (vid, isFinal) } // collect min of all my neighbor's scc values, update if it's smaller than mine // then notify any neighbors with scc values larger than mine diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 4d9dfaf5d92d2..abc25d0671133 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -150,7 +150,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val star = starGraph(sc, n) - // mapVertices conserving type + // mapVertices preserving type val mappedVAttrs = star.mapVertices((vid, attr) => attr + "2") assert(mappedVAttrs.vertices.collect.toSet === (0 to n).map(x => (x: VertexId, "v2")).toSet) // mapVertices changing type @@ -159,40 +159,6 @@ class GraphSuite extends FunSuite with LocalSparkContext { } } - test("mapVertices changing type with same erased type") { - withSpark { sc => - val vertices = sc.parallelize(Array[(Long, Option[java.lang.Integer])]( - (1L, Some(1)), - (2L, Some(2)), - (3L, Some(3)) - )) - val edges = sc.parallelize(Array( - Edge(1L, 2L, 0), - Edge(2L, 3L, 0), - Edge(3L, 1L, 0) - )) - val graph0 = Graph(vertices, edges) - // Trigger initial vertex replication - graph0.triplets.foreach(x => {}) - // Change type of replicated vertices, but conserve erased type - val graph1 = graph0.mapVertices { - case (vid, integerOpt) => integerOpt.map((x: java.lang.Integer) => (x.toDouble): java.lang.Double) - } - // Access replicated vertices, exposing the erased type - val graph2 = graph1.mapTriplets(t => t.srcAttr.get) - assert(graph2.edges.map(_.attr).collect.toSet === Set[java.lang.Double](1.0, 2.0, 3.0)) - } - } - - test("mapVerticesConserve") { - withSpark { sc => - val n = 5 - val star = starGraph(sc, n) - val mappedVAttrs = star.mapVerticesConserve((vid, attr) => attr + "2") - assert(mappedVAttrs.vertices.collect.toSet === (0 to n).map(x => (x: VertexId, "v2")).toSet) - } - } - test("mapEdges") { withSpark { sc => val n = 3 @@ -331,7 +297,6 @@ class GraphSuite extends FunSuite with LocalSparkContext { withSpark { sc => val n = 5 val reverseStar = starGraph(sc, n).reverse.cache() - // outerJoinVertices changing type val reverseStarDegrees = reverseStar.outerJoinVertices(reverseStar.outDegrees) { (vid, a, bOpt) => bOpt.getOrElse(0) } @@ -339,8 +304,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { et => Iterator((et.srcId, et.dstAttr), (et.dstId, et.srcAttr)), (a: Int, b: Int) => a + b).collect.toSet assert(neighborDegreeSums === Set((0: VertexId, n)) ++ (1 to n).map(x => (x: VertexId, 0))) - - // outerJoinVertices conserving type + // outerJoinVertices preserving type val messages = reverseStar.vertices.mapValues { (vid, attr) => vid.toString } val newReverseStar = reverseStar.outerJoinVertices(messages) { (vid, a, bOpt) => a + bOpt.getOrElse("") } @@ -349,18 +313,6 @@ class GraphSuite extends FunSuite with LocalSparkContext { } } - test("outerJoinVerticesConserve") { - withSpark { sc => - val n = 5 - val reverseStar = starGraph(sc, n).reverse.cache() - val messages = reverseStar.vertices.mapValues { (vid, attr) => vid.toString } - val newReverseStar = - reverseStar.outerJoinVerticesConserve(messages) { (vid, a, bOpt) => a + bOpt.getOrElse("") } - assert(newReverseStar.vertices.map(_._2).collect.toSet === - (0 to n).map(x => "v%d".format(x)).toSet) - } - } - test("more edge partitions than vertex partitions") { withSpark { sc => val verts = sc.parallelize(List((1: VertexId, "a"), (2: VertexId, "b")), 1) From 29a5ab7574bec9eb93f9912a62206f6cec8c603c Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 4 Jun 2014 13:58:27 -0700 Subject: [PATCH 3/6] Add failing test --- .../org/apache/spark/graphx/GraphSuite.scala | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index abc25d0671133..3208faaed117e 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -159,6 +159,31 @@ class GraphSuite extends FunSuite with LocalSparkContext { } } + test("mapVertices changing type with same erased type") { + withSpark { sc => + val vertices = sc.parallelize(Array[(Long, Option[java.lang.Integer])]( + (1L, Some(1)), + (2L, Some(2)), + (3L, Some(3)) + )) + val edges = sc.parallelize(Array( + Edge(1L, 2L, 0), + Edge(2L, 3L, 0), + Edge(3L, 1L, 0) + )) + val graph0 = Graph(vertices, edges) + // Trigger initial vertex replication + graph0.triplets.foreach(x => {}) + // Change type of replicated vertices, but conserve erased type + val graph1 = graph0.mapVertices { + case (vid, integerOpt) => integerOpt.map((x: java.lang.Integer) => (x.toDouble): java.lang.Double) + } + // Access replicated vertices, exposing the erased type + val graph2 = graph1.mapTriplets(t => t.srcAttr.get) + assert(graph2.edges.map(_.attr).collect.toSet === Set[java.lang.Double](1.0, 2.0, 3.0)) + } + } + test("mapEdges") { withSpark { sc => val n = 3 From a704e5fece80497a1c57264414f86243d5ceffe0 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 4 Jun 2014 13:57:05 -0700 Subject: [PATCH 4/6] Use type equality constraint with default argument --- .../main/scala/org/apache/spark/graphx/Graph.scala | 5 +++-- .../org/apache/spark/graphx/impl/GraphImpl.scala | 14 ++++++++++---- .../apache/spark/graphx/lib/LabelPropagation.scala | 2 +- .../apache/spark/graphx/lib/ShortestPaths.scala | 2 +- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index c4f9d6514cae3..f10b453cdc075 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -128,7 +128,8 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * }}} * */ - def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2): Graph[VD2, ED] + def mapVertices[VD2: ClassTag](map: (VertexId, VD) => VD2) + (implicit eq: VD =:= VD2 = null): Graph[VD2, ED] /** * Transforms each edge attribute in the graph using the map function. The map function is not @@ -338,7 +339,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * }}} */ def outerJoinVertices[U: ClassTag, VD2: ClassTag](other: RDD[(VertexId, U)]) - (mapFunc: (VertexId, VD, Option[U]) => VD2) + (mapFunc: (VertexId, VD, Option[U]) => VD2)(implicit eq: VD =:= VD2 = null) : Graph[VD2, ED] /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index 59d9a8808e56e..12dc45d2f8c40 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -100,8 +100,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( new GraphImpl(vertices.reverseRoutingTables(), replicatedVertexView.reverse()) } - override def mapVertices[VD2: ClassTag](f: (VertexId, VD) => VD2): Graph[VD2, ED] = { - if (classTag[VD] equals classTag[VD2]) { + override def mapVertices[VD2: ClassTag] + (f: (VertexId, VD) => VD2)(implicit eq: VD =:= VD2 = null): Graph[VD2, ED] = { + // The implicit parameter eq will be populated by the compiler if VD and VD2 are equal, and left + // null if not + if (eq != null) { vertices.cache() // The map preserves type, so we can use incremental replication val newVerts = vertices.mapVertexPartitions(_.map(f)).cache() @@ -228,8 +231,11 @@ class GraphImpl[VD: ClassTag, ED: ClassTag] protected ( override def outerJoinVertices[U: ClassTag, VD2: ClassTag] (other: RDD[(VertexId, U)]) - (updateF: (VertexId, VD, Option[U]) => VD2): Graph[VD2, ED] = { - if (classTag[VD] equals classTag[VD2]) { + (updateF: (VertexId, VD, Option[U]) => VD2) + (implicit eq: VD =:= VD2 = null): Graph[VD2, ED] = { + // The implicit parameter eq will be populated by the compiler if VD and VD2 are equal, and left + // null if not + if (eq != null) { vertices.cache() // updateF preserves type, so we can use incremental replication val newVerts = vertices.leftJoin(other)(updateF).cache() diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index 776bfb8dd6bfa..f6c3fa69ae0dc 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -41,7 +41,7 @@ object LabelPropagation { * * @return a graph with vertex attributes containing the label of community affiliation */ - def run[ED: ClassTag](graph: Graph[_, ED], maxSteps: Int): Graph[VertexId, ED] = { + def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = { val lpaGraph = graph.mapVertices { case (vid, _) => vid } def sendMessage(e: EdgeTriplet[VertexId, ED]) = { Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L))) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala index bba070f256d80..1b78407706f93 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala @@ -49,7 +49,7 @@ object ShortestPaths { * @return a graph where each vertex attribute is a map containing the shortest-path distance to * each reachable landmark vertex. */ - def run[ED: ClassTag](graph: Graph[_, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = { + def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = { val spGraph = graph.mapVertices { (vid, attr) => if (landmarks.contains(vid)) makeMap(vid -> 0) else makeMap() } From 7388705b5d5bed01a39fa358eadd358485ea1d11 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 4 Jun 2014 14:12:30 -0700 Subject: [PATCH 5/6] Remove unnecessary ClassTag for VD parameters --- .../scala/org/apache/spark/graphx/lib/LabelPropagation.scala | 2 +- .../main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index f6c3fa69ae0dc..82e9e06515179 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -41,7 +41,7 @@ object LabelPropagation { * * @return a graph with vertex attributes containing the label of community affiliation */ - def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = { + def run[VD, ED: ClassTag](graph: Graph[VD, ED], maxSteps: Int): Graph[VertexId, ED] = { val lpaGraph = graph.mapVertices { case (vid, _) => vid } def sendMessage(e: EdgeTriplet[VertexId, ED]) = { Iterator((e.srcId, Map(e.dstAttr -> 1L)), (e.dstId, Map(e.srcAttr -> 1L))) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala index 1b78407706f93..590f0474957dd 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala @@ -49,7 +49,7 @@ object ShortestPaths { * @return a graph where each vertex attribute is a map containing the shortest-path distance to * each reachable landmark vertex. */ - def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = { + def run[VD, ED: ClassTag](graph: Graph[VD, ED], landmarks: Seq[VertexId]): Graph[SPMap, ED] = { val spGraph = graph.mapVertices { (vid, attr) => if (landmarks.contains(vid)) makeMap(vid -> 0) else makeMap() } From 68a4fffbf77be3b9dbb7e400d96d6d1118696d50 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Wed, 4 Jun 2014 14:15:54 -0700 Subject: [PATCH 6/6] Undo conserve naming --- graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala index 3208faaed117e..6506bac73d71c 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala @@ -174,7 +174,7 @@ class GraphSuite extends FunSuite with LocalSparkContext { val graph0 = Graph(vertices, edges) // Trigger initial vertex replication graph0.triplets.foreach(x => {}) - // Change type of replicated vertices, but conserve erased type + // Change type of replicated vertices, but preserve erased type val graph1 = graph0.mapVertices { case (vid, integerOpt) => integerOpt.map((x: java.lang.Integer) => (x.toDouble): java.lang.Double) }