diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index af84a0abcf94d..6d15157c24857 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -124,7 +124,9 @@ test_that("kmeans", { take(training, 1) model <- kmeans(x = training, centers = 2) - sample <- take(select(predict(model, training), "prediction"), 1) + prediction <- predict(model, training) + expect_equal(dim(select(prediction, "features")), c(150, 1)) + sample <- take(select(prediction, "prediction"), 1) expect_equal(typeof(sample$prediction), "integer") expect_equal(sample$prediction, 1) diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index d23e4fc9d1f57..edfd8f13d29a1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -58,7 +58,7 @@ private[r] object SparkRWrappers { maxIter: Double, k: Double, columns: Array[String]): PipelineModel = { - val assembler = new VectorAssembler().setInputCols(columns) + val assembler = new VectorAssembler().setInputCols(columns).setOutputCol("features") val kMeans = new KMeans() .setInitMode(initMode) .setMaxIter(maxIter.toInt)