From f0e09c5dcc71e91a945c075412b0bb400a02b370 Mon Sep 17 00:00:00 2001 From: Andy MacKinlay Date: Fri, 20 Jan 2017 12:56:45 +1100 Subject: [PATCH 1/6] [SPARK-19234][MLLib] make sure label is positive for AFT regression --- .../org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 2f78dd30b3af7..9ed400bd9eafe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -644,4 +644,5 @@ private class AFTCostFun( */ private[regression] case class AFTPoint(features: Vector, label: Double, censor: Double) { require(censor == 1.0 || censor == 0.0, "censor of class AFTPoint must be 1.0 or 0.0") + require(label > 0.0, "label of AFTPoint must be positive") } From 0efbf0f478344a3735787b120fd7afc6b4a28f57 Mon Sep 17 00:00:00 2001 From: Andy MacKinlay Date: Fri, 20 Jan 2017 15:02:54 +1100 Subject: [PATCH 2/6] [SPARK-19234][MLLib] fix test suite to ensure no zero-labels get passed in test cases as they now throw errors --- .../spark/ml/regression/AFTSurvivalRegressionSuite.scala | 8 +++++--- .../scala/org/apache/spark/ml/util/MLTestingUtils.scala | 3 +-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 0fdfdf37cf38d..ccfa5c3f3d9a5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -18,20 +18,22 @@ package org.apache.spark.ml.regression import scala.util.Random - -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.{Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SparkSession} class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import testImplicits._ + import AFTSurvivalRegressionSuite._ + @transient var datasetUnivariate: DataFrame = _ @transient var datasetMultivariate: DataFrame = _ diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala index f1ed568d5e60a..9795550f64ce2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestingUtils.scala @@ -64,7 +64,7 @@ object MLTestingUtils extends SparkFunSuite { actuals.foreach(actual => check(expected, actual)) val dfWithStringLabels = spark.createDataFrame(Seq( - ("0", 1, Vectors.dense(0, 2, 3), 0.0) + ("1", 1, Vectors.dense(0, 2, 3), 0.0) )).toDF("label", "weight", "features", "censor") val thrown = intercept[IllegalArgumentException] { estimator.fit(dfWithStringLabels) @@ -156,7 +156,6 @@ object MLTestingUtils extends SparkFunSuite { featuresColName: String = "features", censorColName: String = "censor"): Map[NumericType, DataFrame] = { val df = spark.createDataFrame(Seq( - (0, Vectors.dense(0)), (1, Vectors.dense(1)), (2, Vectors.dense(2)), (3, Vectors.dense(3)), From cfdd28649b1d83bfcb63cfc7b5389470125d2cad Mon Sep 17 00:00:00 2001 From: Andrew MacKinlay Date: Mon, 23 Jan 2017 17:19:11 +1100 Subject: [PATCH 3/6] [SPARK-19234][MLLib] added test case to ensure fast failure on zero labels in AFT --- .../ml/regression/AFTSurvivalRegressionSuite.scala | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index ccfa5c3f3d9a5..92c98ceb709e3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -32,8 +32,6 @@ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { import testImplicits._ - import AFTSurvivalRegressionSuite._ - @transient var datasetUnivariate: DataFrame = _ @transient var datasetMultivariate: DataFrame = _ @@ -402,6 +400,17 @@ class AFTSurvivalRegressionSuite val trainer = new AFTSurvivalRegression() trainer.fit(dataset) } + + test("SPARK-19234: Fail fast on zero-valued labels") { + val dataset = spark.createDataFrame(Seq( + (1.218, 1.0, Vectors.dense(1.560, -0.605)), + (0.000, 0.0, Vectors.dense(0.346, 2.158)), // ← generates error; zero labels invalid + (4.199, 0.0, Vectors.dense(0.795, -0.226)))).toDF("label", "censor", "features") + val aft = new AFTSurvivalRegression() + intercept[SparkException] { + aft.fit(dataset) + } + } } object AFTSurvivalRegressionSuite { From f645ab8fb40af31a444c26602f76e6f48d58f60f Mon Sep 17 00:00:00 2001 From: Andrew MacKinlay Date: Tue, 31 Jan 2017 09:44:55 +1100 Subject: [PATCH 4/6] [SPARK-19234][MLLib] added a clue about test failure per code-review suggestion --- .../spark/ml/regression/AFTSurvivalRegressionSuite.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 92c98ceb709e3..b4df3040d78ab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -407,8 +407,10 @@ class AFTSurvivalRegressionSuite (0.000, 0.0, Vectors.dense(0.346, 2.158)), // ← generates error; zero labels invalid (4.199, 0.0, Vectors.dense(0.795, -0.226)))).toDF("label", "censor", "features") val aft = new AFTSurvivalRegression() - intercept[SparkException] { - aft.fit(dataset) + withClue("label of AFTPoint must be positive") { + intercept[SparkException] { + aft.fit(dataset) + } } } } From 88aee2f01dcfd927500d704197fde2a38f81976c Mon Sep 17 00:00:00 2001 From: Andrew MacKinlay Date: Tue, 31 Jan 2017 09:58:13 +1100 Subject: [PATCH 5/6] [SPARK-19234][MLLib] fixing indentation to match code guidelines --- .../spark/ml/regression/AFTSurvivalRegressionSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index b4df3040d78ab..823e5c4f4d78c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -403,9 +403,9 @@ class AFTSurvivalRegressionSuite test("SPARK-19234: Fail fast on zero-valued labels") { val dataset = spark.createDataFrame(Seq( - (1.218, 1.0, Vectors.dense(1.560, -0.605)), - (0.000, 0.0, Vectors.dense(0.346, 2.158)), // ← generates error; zero labels invalid - (4.199, 0.0, Vectors.dense(0.795, -0.226)))).toDF("label", "censor", "features") + (1.218, 1.0, Vectors.dense(1.560, -0.605)), + (0.000, 0.0, Vectors.dense(0.346, 2.158)), // ← generates error; zero labels invalid + (4.199, 0.0, Vectors.dense(0.795, -0.226)))).toDF("label", "censor", "features") val aft = new AFTSurvivalRegression() withClue("label of AFTPoint must be positive") { intercept[SparkException] { From c855976dd2806cd7e8fd13f46d120f229f7ff6cd Mon Sep 17 00:00:00 2001 From: Andrew MacKinlay Date: Tue, 31 Jan 2017 10:31:02 +1100 Subject: [PATCH 6/6] [SPARK-19234][MLLib] fixing imports to match scala style guidelines --- .../apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 823e5c4f4d78c..4eb4fc50ef3c4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ml.regression import scala.util.Random + import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.linalg.{Vector, Vectors}