From 48061de21addf5b021f874fa93cddb2882959042 Mon Sep 17 00:00:00 2001 From: NarineK Date: Thu, 17 Mar 2016 13:00:10 -0700 Subject: [PATCH 1/3] Fixed features column header --- mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index d23e4fc9d1f57..edfd8f13d29a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -58,7 +58,7 @@ private[r] object SparkRWrappers { maxIter: Double, k: Double, columns: Array[String]): PipelineModel = { - val assembler = new VectorAssembler().setInputCols(columns) + val assembler = new VectorAssembler().setInputCols(columns).setOutputCol("features") val kMeans = new KMeans() .setInitMode(initMode) .setMaxIter(maxIter.toInt) From 7900fcf705e618f6aceaf49302d7e410852fea9f Mon Sep 17 00:00:00 2001 From: NarineK Date: Thu, 24 Mar 2016 20:42:06 -0700 Subject: [PATCH 2/3] added assertion test check for features --- R/pkg/inst/tests/testthat/test_mllib.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index af84a0abcf94d..1e80a77f99c8c 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -124,7 +124,9 @@ test_that("kmeans", { take(training, 1) model <- kmeans(x = training, centers = 2) - sample <- take(select(predict(model, training), "prediction"), 1) + prediction <- predict(model, training) + expect_equal(dim(select(prediction, "features")), c(150,1)) + sample <- take(select(prediction, "prediction"), 1) expect_equal(typeof(sample$prediction), "integer") expect_equal(sample$prediction, 1) From 41017871a3a0d91c9136855524c0257e5c51ecdb Mon Sep 17 00:00:00 2001 From: NarineK Date: Thu, 24 Mar 2016 20:55:37 -0700 Subject: [PATCH 3/3] fixed stylecheck --- R/pkg/inst/tests/testthat/test_mllib.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 1e80a77f99c8c..6d15157c24857 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -125,7 +125,7 @@ test_that("kmeans", { model <- kmeans(x = training, centers = 2) prediction <- predict(model, training) - expect_equal(dim(select(prediction, "features")), c(150,1)) + expect_equal(dim(select(prediction, "features")), c(150, 1)) sample <- take(select(prediction, "prediction"), 1) expect_equal(typeof(sample$prediction), "integer") expect_equal(sample$prediction, 1)