From cd02df5a995ac7daa2e76dcc13c54b987265b6e4 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Tue, 8 May 2018 15:02:33 -0700 Subject: [PATCH 1/3] fix for Int id type for PowerIterationClustering in spark.ml --- .../ml/clustering/PowerIterationClustering.scala | 2 +- .../clustering/PowerIterationClusteringSuite.scala | 13 +++++-------- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala index 2c30a1d9aa947..b672a006a1685 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -232,7 +232,7 @@ class PowerIterationClustering private[clustering] ( case _: LongType => uncastPredictions case otherType => - uncastPredictions.select(col($(idCol)).cast(otherType).alias($(idCol))) + uncastPredictions.withColumn($(idCol), col($(idCol)).cast(otherType)) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala index 65328df17baff..db0f470dca85b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/PowerIterationClusteringSuite.scala @@ -103,24 +103,21 @@ class PowerIterationClusteringSuite extends SparkFunSuite .setK(2) .setMaxIter(1) - def runTest(idType: DataType, neighborType: DataType, similarityType: DataType): Unit = { + def runTest(idType: DataType, similarityType: DataType): Unit = { val typedData = data.select( col("id").cast(idType).alias("id"), - col("neighbors").cast(ArrayType(neighborType, containsNull = false)).alias("neighbors"), + col("neighbors").cast(ArrayType(idType, containsNull = false)).alias("neighbors"), col("similarities").cast(ArrayType(similarityType, containsNull = false)) .alias("similarities") ) - model.transform(typedData).collect() + model.transform(typedData).select("id", "prediction").collect() } for (idType <- Seq(IntegerType, LongType)) { - runTest(idType, LongType, DoubleType) - } - for (neighborType <- Seq(IntegerType, LongType)) { - runTest(LongType, neighborType, DoubleType) + runTest(idType, DoubleType) } for (similarityType <- Seq(FloatType, DoubleType)) { - runTest(LongType, LongType, similarityType) + runTest(LongType, similarityType) } } From e504c3e8172f06b9985a9987fa150ec26bde006a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 9 May 2018 10:41:59 -0700 Subject: [PATCH 2/3] cleaned up code --- .../spark/ml/clustering/PowerIterationClustering.scala | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala index b672a006a1685..f5cbd90d85a77 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -231,8 +231,12 @@ class PowerIterationClustering private[clustering] ( dataset.schema($(idCol)).dataType match { case _: LongType => uncastPredictions + case _: IntegerType => + uncastPredictions.withColumn($(idCol), col($(idCol)).cast(LongType)) case otherType => - uncastPredictions.withColumn($(idCol), col($(idCol)).cast(otherType)) + throw new IllegalArgumentException(s"PowerIterationClustering had an unexpected error: " + + s"ID col was found to be of type $otherType, despite initial schema checks. Please " + + s"report this bug.") } } From 3e40a92f4d80a5a455cd519d29a8f3c229a3f4d0 Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Wed, 9 May 2018 14:46:56 -0700 Subject: [PATCH 3/3] code cleanup per review --- .../apache/spark/ml/clustering/PowerIterationClustering.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala index f5cbd90d85a77..ad49d2595adc8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/PowerIterationClustering.scala @@ -235,8 +235,8 @@ class PowerIterationClustering private[clustering] ( uncastPredictions.withColumn($(idCol), col($(idCol)).cast(LongType)) case otherType => throw new IllegalArgumentException(s"PowerIterationClustering had an unexpected error: " + - s"ID col was found to be of type $otherType, despite initial schema checks. Please " + - s"report this bug.") + s"ID col was found to be of type ${otherType.simpleString}, despite initial schema " + + s"checks. Please report this bug.") } }