From 79d706085e8371fb1724ce73377767c38d551e5d Mon Sep 17 00:00:00 2001 From: Menglong TAN Date: Fri, 10 Mar 2017 12:45:56 +0800 Subject: [PATCH 1/4] Enhance StringIndexer with NULL values --- .../spark/ml/feature/StringIndexer.scala | 57 +++++++------ .../spark/ml/feature/StringIndexerSuite.scala | 80 +++++++++++++++++++ 2 files changed, 114 insertions(+), 23 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 810b02febbe77..fd372de27686a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -28,7 +28,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap @@ -39,20 +39,21 @@ import org.apache.spark.util.collection.OpenHashMap private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { /** - * Param for how to handle unseen labels. Options are 'skip' (filter out rows with - * unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional + * Param for how to handle invalid data (unseen labels or NULL values). + * Options are 'skip' (filter out rows with invalid data), + * 'error' (throw an error), or 'keep' (put invalid data in a special additional * bucket, at index numLabels. * Default: "error" * @group param */ @Since("1.6.0") val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " + - "unseen labels. 
Options are 'skip' (filter out rows with unseen labels), " + - "error (throw an error), or 'keep' (put unseen labels in a special additional bucket, " + - "at index numLabels).", + "invalid data (unseen labels or NULL values). " + + "Options are 'skip' (filter out rows with invalid data), error (throw an error), " + + "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", ParamValidators.inArray(StringIndexer.supportedHandleInvalids)) - setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL) + setDefault(handleInvalid, StringIndexer.ERROR_INVALID) /** @group getParam */ @Since("1.6.0") @@ -106,7 +107,7 @@ class StringIndexer @Since("1.4.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): StringIndexerModel = { transformSchema(dataset.schema, logging = true) - val counts = dataset.select(col($(inputCol)).cast(StringType)) + val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType)) .rdd .map(_.getString(0)) .countByValue() @@ -125,11 +126,11 @@ class StringIndexer @Since("1.4.0") ( @Since("1.6.0") object StringIndexer extends DefaultParamsReadable[StringIndexer] { - private[feature] val SKIP_UNSEEN_LABEL: String = "skip" - private[feature] val ERROR_UNSEEN_LABEL: String = "error" - private[feature] val KEEP_UNSEEN_LABEL: String = "keep" + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" private[feature] val supportedHandleInvalids: Array[String] = - Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL) + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) @@ -188,7 +189,7 @@ class StringIndexerModel ( transformSchema(dataset.schema, logging = true) val filteredLabels = getHandleInvalid match { - case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown" + case StringIndexer.KEEP_INVALID => 
labels :+ "__unknown" case _ => labels } @@ -196,27 +197,37 @@ class StringIndexerModel ( .withName($(outputCol)).withValues(filteredLabels).toMetadata() // If we are skipping invalid records, filter them out. val (filteredDataset, keepInvalid) = getHandleInvalid match { - case StringIndexer.SKIP_UNSEEN_LABEL => + case StringIndexer.SKIP_INVALID => val filterer = udf { label: String => labelToIndex.contains(label) } (dataset.where(filterer(dataset($(inputCol)))), false) - case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL) + case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID) } - val indexer = udf { label: String => - if (labelToIndex.contains(label)) { - labelToIndex(label) - } else if (keepInvalid) { - labels.length + val indexer = udf { row: Row => + if (row.isNullAt(0)) { + if (keepInvalid) { + labels.length + } else { + throw new SparkException("StringIndexer encountered NULL value. To handle or skip " + + "NULLS, try setting StringIndexer.handleInvalid.") + } } else { - throw new SparkException(s"Unseen label: $label. To handle unseen labels, " + - s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.") + val label = String.valueOf(row.get(0)) + if (labelToIndex.contains(label)) { + labelToIndex(label) + } else if (keepInvalid) { + labels.length + } else { + throw new SparkException(s"Unseen label: $label. 
To handle unseen labels, " + + s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.") + } } } filteredDataset.select(col("*"), - indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata)) + indexer(struct(Array(dataset($(inputCol))): _*)).as($(outputCol), metadata)) } @Since("1.4.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 188dffb3dd55f..b328da5a81191 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -122,6 +122,86 @@ class StringIndexerSuite assert(output === expected) } + test("StringIndexer with a string input column with NULLs") { + val data: Seq[java.lang.String] = Seq("a", "b", "b", null) + val data2: Seq[java.lang.String] = Seq("a", "b", null) + val expectedSkip = Array(1.0, 0.0) + val expectedKeep = Array(1.0, 0.0, 2.0) + val df = data.toDF("label") + val df2 = data2.toDF("label") + + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + + withClue("StringIndexer should throw error when setHandleValid=error when given NULL values") { + intercept[SparkException] { + indexer.setHandleInvalid("error") + indexer.fit(df).transform(df2).collect() + } + } + + indexer.setHandleInvalid("skip") + val transformedSkip = indexer.fit(df).transform(df2) + val attrSkip = Attribute + .fromStructField(transformedSkip.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + assert(transformedSkip.select("labelIndex").rdd.map { r => + r.getDouble(0) + }.collect() === expectedSkip) + + indexer.setHandleInvalid("keep") + val transformedKeep = indexer.fit(df).transform(df2) + val attrKeep = Attribute + .fromStructField(transformedKeep.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === 
Array("b", "a", "__unknown")) + assert(transformedKeep.select("labelIndex").rdd.map { r => + r.getDouble(0) + }.collect() === expectedKeep) + } + + test("StringIndexer with a numeric input column with NULLs") { + val data: Seq[Integer] = Seq(1, 2, 2, null) + val data2: Seq[Integer] = Seq(1, 2, null) + val expectedSkip = Array(1.0, 0.0) + val expectedKeep = Array(1.0, 0.0, 2.0) + val df = data.toDF("label") + val df2 = data2.toDF("label") + + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + + withClue("StringIndexer should throw error when setHandleValid=error when given NULL values") { + intercept[SparkException] { + indexer.setHandleInvalid("error") + indexer.fit(df).transform(df2).collect() + } + } + + indexer.setHandleInvalid("skip") + val transformedSkip = indexer.fit(df).transform(df2) + val attrSkip = Attribute + .fromStructField(transformedSkip.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("2", "1")) + assert(transformedSkip.select("labelIndex").rdd.map { r => + r.getDouble(0) + }.collect() === expectedSkip) + + indexer.setHandleInvalid("keep") + val transformedKeep = indexer.fit(df).transform(df2) + val attrKeep = Attribute + .fromStructField(transformedKeep.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("2", "1", "__unknown")) + assert(transformedKeep.select("labelIndex").rdd.map { r => + r.getDouble(0) + }.collect() === expectedKeep) + } + test("StringIndexerModel should keep silent if the input column does not exist.") { val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) .setInputCol("label") From 0cb121c65f592b9623bdeef2746d7c2a3c281ae1 Mon Sep 17 00:00:00 2001 From: Menglong TAN Date: Fri, 10 Mar 2017 12:52:30 +0800 Subject: [PATCH 2/4] Filter out NULLs when transforming the dataset --- .../main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index fd372de27686a..cbf2cd974eeab 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -201,7 +201,7 @@ class StringIndexerModel ( val filterer = udf { label: String => labelToIndex.contains(label) } - (dataset.where(filterer(dataset($(inputCol)))), false) + (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false) case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID) } From e80a158cc5aae0c593ec774b682fb642a6c2298d Mon Sep 17 00:00:00 2001 From: Menglong TAN Date: Tue, 14 Mar 2017 11:20:42 +0800 Subject: [PATCH 3/4] improve code and unit tests --- .../spark/ml/feature/StringIndexer.scala | 9 +-- .../spark/ml/feature/StringIndexerSuite.scala | 73 +++++-------------- 2 files changed, 23 insertions(+), 59 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index cbf2cd974eeab..027dfc2968773 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -42,7 +42,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha * Param for how to handle invalid data (unseen labels or NULL values). * Options are 'skip' (filter out rows with invalid data), * 'error' (throw an error), or 'keep' (put invalid data in a special additional - * bucket, at index numLabels. + * bucket, at index numLabels). 
* Default: "error" * @group param */ @@ -205,8 +205,8 @@ class StringIndexerModel ( case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID) } - val indexer = udf { row: Row => - if (row.isNullAt(0)) { + val indexer = udf { label: String => + if (label == null) { if (keepInvalid) { labels.length } else { @@ -214,7 +214,6 @@ class StringIndexerModel ( "NULLS, try setting StringIndexer.handleInvalid.") } } else { - val label = String.valueOf(row.get(0)) if (labelToIndex.contains(label)) { labelToIndex(label) } else if (keepInvalid) { @@ -227,7 +226,7 @@ class StringIndexerModel ( } filteredDataset.select(col("*"), - indexer(struct(Array(dataset($(inputCol))): _*)).as($(outputCol), metadata)) + indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata)) } @Since("1.4.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index b328da5a81191..8d9042b31e033 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -122,19 +122,18 @@ class StringIndexerSuite assert(output === expected) } - test("StringIndexer with a string input column with NULLs") { - val data: Seq[java.lang.String] = Seq("a", "b", "b", null) - val data2: Seq[java.lang.String] = Seq("a", "b", null) - val expectedSkip = Array(1.0, 0.0) - val expectedKeep = Array(1.0, 0.0, 2.0) - val df = data.toDF("label") - val df2 = data2.toDF("label") + test("StringIndexer with NULLs") { + val data: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (2, "b"), (3, null)) + val data2: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (3, null)) + val df = data.toDF("id", "label") + val df2 = data2.toDF("id", "label") val indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex") - withClue("StringIndexer should throw error when setHandleValid=error when given NULL 
values") { + withClue("StringIndexer should throw error when setHandleInvalid=error " + + "when given NULL values") { intercept[SparkException] { indexer.setHandleInvalid("error") indexer.fit(df).transform(df2).collect() @@ -147,9 +146,12 @@ class StringIndexerSuite .fromStructField(transformedSkip.schema("labelIndex")) .asInstanceOf[NominalAttribute] assert(attrSkip.values.get === Array("b", "a")) - assert(transformedSkip.select("labelIndex").rdd.map { r => - r.getDouble(0) - }.collect() === expectedSkip) + val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0 + val expectedSkip = Set((0, 1.0), (1, 0.0)) + assert(outputSkip === expectedSkip) indexer.setHandleInvalid("keep") val transformedKeep = indexer.fit(df).transform(df2) @@ -157,49 +159,12 @@ class StringIndexerSuite .fromStructField(transformedKeep.schema("labelIndex")) .asInstanceOf[NominalAttribute] assert(attrKeep.values.get === Array("b", "a", "__unknown")) - assert(transformedKeep.select("labelIndex").rdd.map { r => - r.getDouble(0) - }.collect() === expectedKeep) - } - - test("StringIndexer with a numeric input column with NULLs") { - val data: Seq[Integer] = Seq(1, 2, 2, null) - val data2: Seq[Integer] = Seq(1, 2, null) - val expectedSkip = Array(1.0, 0.0) - val expectedKeep = Array(1.0, 0.0, 2.0) - val df = data.toDF("label") - val df2 = data2.toDF("label") - - val indexer = new StringIndexer() - .setInputCol("label") - .setOutputCol("labelIndex") - - withClue("StringIndexer should throw error when setHandleValid=error when given NULL values") { - intercept[SparkException] { - indexer.setHandleInvalid("error") - indexer.fit(df).transform(df2).collect() - } - } - - indexer.setHandleInvalid("skip") - val transformedSkip = indexer.fit(df).transform(df2) - val attrSkip = Attribute - .fromStructField(transformedSkip.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrSkip.values.get === Array("2", 
"1")) - assert(transformedSkip.select("labelIndex").rdd.map { r => - r.getDouble(0) - }.collect() === expectedSkip) - - indexer.setHandleInvalid("keep") - val transformedKeep = indexer.fit(df).transform(df2) - val attrKeep = Attribute - .fromStructField(transformedKeep.schema("labelIndex")) - .asInstanceOf[NominalAttribute] - assert(attrKeep.values.get === Array("2", "1", "__unknown")) - assert(transformedKeep.select("labelIndex").rdd.map { r => - r.getDouble(0) - }.collect() === expectedKeep) + val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0, null -> 2 + val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0)) + assert(outputKeep === expectedKeep) } test("StringIndexerModel should keep silent if the input column does not exist.") { From 2a0a7567dcbb3ffcabd88cc1b3824a559fdd6d9e Mon Sep 17 00:00:00 2001 From: Menglong TAN Date: Tue, 14 Mar 2017 11:22:32 +0800 Subject: [PATCH 4/4] remove unused import --- .../main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 027dfc2968773..99321bcc7cf98 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -28,7 +28,7 @@ import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashMap