From 92172e28be5dd3f9ecd36071b6c9133a40a84ed3 Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Thu, 4 May 2017 15:18:46 -0700 Subject: [PATCH 1/9] allow imputer to handle numeric types --- .../org/apache/spark/ml/feature/Imputer.scala | 2 +- .../spark/ml/feature/ImputerSuite.scala | 26 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 1c074e204ad99..d262b274e7810 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -73,7 +73,7 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu s" and outputCols(${$(outputCols).length}) should have the same length") val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) => val inputField = schema(inputCol) - SchemaUtils.checkColumnTypes(schema, inputCol, Seq(DoubleType, FloatType)) + SchemaUtils.checkNumericType(schema, inputCol) StructField(outputCol, inputField.dataType, inputField.nullable) } StructType(schema ++ outputFields) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index 75f63a623e6d8..fd0e0cdeccab3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -20,6 +20,8 @@ import org.apache.spark.SparkException import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ class ImputerSuite extends MLTest with DefaultReadWriteTest { @@ -176,6 +178,27 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { 
assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect()) } + test("Imputer for Numeric with default missing Value NaN") { + val df = spark.createDataFrame( Seq( + (0, 1.0, 1.0, 1.0), + (1, 11.0, 11.0, 11.0), + (2, 1.5, 1.5, 1.5), + (3, Double.NaN, 4.5, 1.5) + )).toDF("id", "value1", "expected_mean_value1", "expected_median_value1") + val imputer = new Imputer() + .setInputCols(Array("value1")) + .setOutputCols(Array("out1")) + + val types = Seq(ShortType, IntegerType, LongType, FloatType, DoubleType, + ByteType, DecimalType(10, 0)) + for (mType <- types) { + val df2 = df.withColumn("value1", col("value1").cast(mType)) + .withColumn("value1", when(col("value1").equalTo(0), null).otherwise(col("value1"))) + .withColumn("expected_mean_value1", col("expected_mean_value1").cast(mType)) + .withColumn("expected_median_value1", col("expected_median_value1").cast(mType)) + ImputerSuite.iterateStrategyTest(imputer, df2) + } + } } object ImputerSuite { @@ -197,6 +220,9 @@ object ImputerSuite { case Row(exp: Double, out: Double) => assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5), s"Imputed values differ. Expected: $exp, actual: $out") + case Row(exp, out) => + assert(exp == out, + s"Imputed values differ. 
Expected: $exp, actual: $out") } } } From 400e84801f98e910d9cdeab7e26a42b4a178f7e6 Mon Sep 17 00:00:00 2001 From: Wayne Zhang Date: Fri, 5 May 2017 22:19:37 -0700 Subject: [PATCH 2/9] convert output to double type --- .../org/apache/spark/ml/feature/Imputer.scala | 5 +++-- .../spark/ml/feature/ImputerSuite.scala | 22 ++++++------------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index d262b274e7810..f7a115fbe0922 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -74,7 +74,7 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) => val inputField = schema(inputCol) SchemaUtils.checkNumericType(schema, inputCol) - StructField(outputCol, inputField.dataType, inputField.nullable) + StructField(outputCol, DoubleType, inputField.nullable) } StructType(schema ++ outputFields) } @@ -84,12 +84,13 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu * :: Experimental :: * Imputation estimator for completing missing values, either using the mean or the median * of the columns in which the missing values are located. The input columns should be of - * DoubleType or FloatType. Currently Imputer does not support categorical features + * numeric type. Currently Imputer does not support categorical features * (SPARK-15041) and possibly creates incorrect values for a categorical feature. * * Note that the mean/median value is computed after filtering out missing values. * All Null values in the input columns are treated as missing, and so are also imputed. For * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001. 
+ * The output column is always of Double type regardless of the input column type. */ @Experimental @Since("2.2.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index fd0e0cdeccab3..1d2bbc9e99c6f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -53,11 +53,11 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { test("Imputer for Float with missing Value -1.0") { val df = spark.createDataFrame( Seq( - (0, 1.0F, 1.0F, 1.0F), - (1, 3.0F, 3.0F, 3.0F), - (2, 10.0F, 10.0F, 10.0F), - (3, 10.0F, 10.0F, 10.0F), - (4, -1.0F, 6.0F, 3.0F) + (0, 1.0F, 1.0, 1.0), + (1, 3.0F, 3.0, 3.0), + (2, 10.0F, 10.0, 10.0), + (3, 10.0F, 10.0, 10.0), + (4, -1.0F, 6.0, 3.0) )).toDF("id", "value", "expected_mean_value", "expected_median_value") val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) .setMissingValue(-1) @@ -182,8 +182,8 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { val df = spark.createDataFrame( Seq( (0, 1.0, 1.0, 1.0), (1, 11.0, 11.0, 11.0), - (2, 1.5, 1.5, 1.5), - (3, Double.NaN, 4.5, 1.5) + (2, 3.0, 3.0, 3.0), + (3, Double.NaN, 5.0, 3.0) )).toDF("id", "value1", "expected_mean_value1", "expected_median_value1") val imputer = new Imputer() .setInputCols(Array("value1")) @@ -194,8 +194,6 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { for (mType <- types) { val df2 = df.withColumn("value1", col("value1").cast(mType)) .withColumn("value1", when(col("value1").equalTo(0), null).otherwise(col("value1"))) - .withColumn("expected_mean_value1", col("expected_mean_value1").cast(mType)) - .withColumn("expected_median_value1", col("expected_median_value1").cast(mType)) ImputerSuite.iterateStrategyTest(imputer, df2) } } @@ -214,15 +212,9 @@ object ImputerSuite { val resultDF = model.transform(df) 
imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) => resultDF.select(s"expected_${strategy}_$inputCol", outputCol).collect().foreach { - case Row(exp: Float, out: Float) => - assert((exp.isNaN && out.isNaN) || (exp == out), - s"Imputed values differ. Expected: $exp, actual: $out") case Row(exp: Double, out: Double) => assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5), s"Imputed values differ. Expected: $exp, actual: $out") - case Row(exp, out) => - assert(exp == out, - s"Imputed values differ. Expected: $exp, actual: $out") } } } From e80f6cc5b4e730c612a3bbfc49324b109ea224b4 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Wed, 31 Jul 2019 09:31:18 -0700 Subject: [PATCH 3/9] force returning same type --- .../org/apache/spark/ml/feature/Imputer.scala | 2 +- .../spark/ml/feature/ImputerSuite.scala | 34 ++++++++++++++++--- 2 files changed, 31 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index f7a115fbe0922..e3d932518aa50 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -219,7 +219,7 @@ class ImputerModel private[ml] ( val newCols = $(inputCols).zip($(outputCols)).zip(surrogates).map { case ((inputCol, outputCol), surrogate) => val inputType = dataset.schema(inputCol).dataType - val ic = col(inputCol) + val ic = col(inputCol).cast(DoubleType) when(ic.isNull, surrogate) .when(ic === $(missingValue), surrogate) .otherwise(ic) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index 1d2bbc9e99c6f..74e7956dc83c5 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -179,22 +179,48 @@ class ImputerSuite 
extends MLTest with DefaultReadWriteTest { } test("Imputer for Numeric with default missing Value NaN") { - val df = spark.createDataFrame( Seq( + val df = spark.createDataFrame(Seq( (0, 1.0, 1.0, 1.0), (1, 11.0, 11.0, 11.0), - (2, 3.0, 3.0, 3.0), - (3, Double.NaN, 5.0, 3.0) + (2, 3.6, 3.6, 3.6), + (3, Double.NaN, 5.2, 3.6) )).toDF("id", "value1", "expected_mean_value1", "expected_median_value1") + val imputer = new Imputer() .setInputCols(Array("value1")) .setOutputCols(Array("out1")) val types = Seq(ShortType, IntegerType, LongType, FloatType, DoubleType, ByteType, DecimalType(10, 0)) + for (mType <- types) { + // cast `value` to desired data type for testing val df2 = df.withColumn("value1", col("value1").cast(mType)) .withColumn("value1", when(col("value1").equalTo(0), null).otherwise(col("value1"))) - ImputerSuite.iterateStrategyTest(imputer, df2) + + + Seq("mean", "median").foreach { strategy => + imputer.setStrategy(strategy) + val model = imputer.fit(df) + val resultDF = model.transform(df) + + imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) => + + // check dataType is consistent between input and output + val inputType = resultDF.schema(inputCol).dataType + val outputType = resultDF.schema(outputCol).dataType + assert(inputType === outputType, "Output type is not the same as input type.") + + // check value + val expectedDoubleColumn = s"texpected_${strategy}_$inputCol" + resultDF.select(col(expectedDoubleColumn), col(outputCol).cast(DoubleType)) + .collect().foreach { + case Row(exp: Double, out: Double) => + assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5), + s"Imputed values differ. 
Expected: $exp, actual: $out") + } + } + } } } } From c0722f12b5f25d0cb94fa65992a64e555b86abd6 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Wed, 31 Jul 2019 10:32:12 -0700 Subject: [PATCH 4/9] update test --- .../org/apache/spark/ml/feature/ImputerSuite.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index 74e7956dc83c5..174f9046dbd57 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -184,7 +184,7 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { (1, 11.0, 11.0, 11.0), (2, 3.6, 3.6, 3.6), (3, Double.NaN, 5.2, 3.6) - )).toDF("id", "value1", "expected_mean_value1", "expected_median_value1") + )).toDF("id", "value1", "expected_mean_value1_double", "expected_median_value1_double") val imputer = new Imputer() .setInputCols(Array("value1")) @@ -198,11 +198,10 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { val df2 = df.withColumn("value1", col("value1").cast(mType)) .withColumn("value1", when(col("value1").equalTo(0), null).otherwise(col("value1"))) - Seq("mean", "median").foreach { strategy => imputer.setStrategy(strategy) - val model = imputer.fit(df) - val resultDF = model.transform(df) + val model = imputer.fit(df2) + val resultDF = model.transform(df2) imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) => @@ -212,8 +211,9 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { assert(inputType === outputType, "Output type is not the same as input type.") // check value - val expectedDoubleColumn = s"texpected_${strategy}_$inputCol" - resultDF.select(col(expectedDoubleColumn), col(outputCol).cast(DoubleType)) + val expectedDoubleColumn = s"expected_${strategy}_${inputCol}_double" + 
resultDF.select(col(expectedDoubleColumn).cast(mType).cast(DoubleType), + col(outputCol).cast(DoubleType)) .collect().foreach { case Row(exp: Double, out: Double) => assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5), From 9e9589f83be86702e809c4c7b52bbffb9d106dcf Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Wed, 31 Jul 2019 10:52:08 -0700 Subject: [PATCH 5/9] update doc to explain logic for non-double type --- .../main/scala/org/apache/spark/ml/feature/Imputer.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index e3d932518aa50..46b5b6bb6b63b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -87,6 +87,12 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu * numeric type. Currently Imputer does not support categorical features * (SPARK-15041) and possibly creates incorrect values for a categorical feature. * + * Note that the input columns are converted to Double data type internally to compute + * the mean/median value and impute the missing values, which are then casted back to + * the original data type in the output. So the output column always has the same data + * type as the input. As an example, if the input column is IntegerType (1, 2, 4, null), + * the output will be IntegerType (1, 2, 4, 2) after mean imputation. + * * Note that the mean/median value is computed after filtering out missing values. * All Null values in the input columns are treated as missing, and so are also imputed. For * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001. 
From ad04109f3cacbc4966a61186328cb39685fafa62 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Wed, 31 Jul 2019 10:57:06 -0700 Subject: [PATCH 6/9] revert some unnecessary changes on test --- .../org/apache/spark/ml/feature/ImputerSuite.scala | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index 174f9046dbd57..2b139efa2bbbc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -53,11 +53,11 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { test("Imputer for Float with missing Value -1.0") { val df = spark.createDataFrame( Seq( - (0, 1.0F, 1.0, 1.0), - (1, 3.0F, 3.0, 3.0), - (2, 10.0F, 10.0, 10.0), - (3, 10.0F, 10.0, 10.0), - (4, -1.0F, 6.0, 3.0) + (0, 1.0F, 1.0F, 1.0F), + (1, 3.0F, 3.0F, 3.0F), + (2, 10.0F, 10.0F, 10.0F), + (3, 10.0F, 10.0F, 10.0F), + (4, -1.0F, 6.0F, 3.0F) )).toDF("id", "value", "expected_mean_value", "expected_median_value") val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) .setMissingValue(-1) @@ -238,6 +238,9 @@ object ImputerSuite { val resultDF = model.transform(df) imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) => resultDF.select(s"expected_${strategy}_$inputCol", outputCol).collect().foreach { + case Row(exp: Float, out: Float) => + assert((exp.isNaN && out.isNaN) || (exp == out), + s"Imputed values differ. Expected: $exp, actual: $out") case Row(exp: Double, out: Double) => assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5), s"Imputed values differ. 
Expected: $exp, actual: $out") From 7bbc69e25a842ef461cbb6b4d18a398064847ef9 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Wed, 31 Jul 2019 11:29:35 -0700 Subject: [PATCH 7/9] fix output data type --- mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 46b5b6bb6b63b..95ed9efeafb26 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -74,7 +74,7 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu val outputFields = $(inputCols).zip($(outputCols)).map { case (inputCol, outputCol) => val inputField = schema(inputCol) SchemaUtils.checkNumericType(schema, inputCol) - StructField(outputCol, DoubleType, inputField.nullable) + StructField(outputCol, inputField.dataType, inputField.nullable) } StructType(schema ++ outputFields) } From 13a9659dbfffd447c5ab729102d6653d84a2bb88 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Wed, 31 Jul 2019 11:31:28 -0700 Subject: [PATCH 8/9] fix doc --- mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 95ed9efeafb26..7a0083e5be6a4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -96,7 +96,6 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu * Note that the mean/median value is computed after filtering out missing values. * All Null values in the input columns are treated as missing, and so are also imputed. 
For * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001. - * The output column is always of Double type regardless of the input column type. */ @Experimental @Since("2.2.0") From 16a73481c11f70b249a96c7e231eab65dad96a70 Mon Sep 17 00:00:00 2001 From: actuaryzhang Date: Wed, 31 Jul 2019 14:58:46 -0700 Subject: [PATCH 9/9] improve tests and doc --- .../org/apache/spark/ml/feature/Imputer.scala | 6 +- .../spark/ml/feature/ImputerSuite.scala | 82 +++++++++++-------- 2 files changed, 48 insertions(+), 40 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 7a0083e5be6a4..bdad804083b01 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -87,10 +87,8 @@ private[feature] trait ImputerParams extends Params with HasInputCols with HasOu * numeric type. Currently Imputer does not support categorical features * (SPARK-15041) and possibly creates incorrect values for a categorical feature. * - * Note that the input columns are converted to Double data type internally to compute - * the mean/median value and impute the missing values, which are then casted back to - * the original data type in the output. So the output column always has the same data - * type as the input. As an example, if the input column is IntegerType (1, 2, 4, null), + * Note when an input column is integer, the imputed value is cast (truncated) to an integer type. + * For example, if the input column is IntegerType (1, 2, 4, null), * the output will be IntegerType (1, 2, 4, 2) after mean imputation. * * Note that the mean/median value is computed after filtering out missing values. 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index 2b139efa2bbbc..02ef261a6c067 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -178,49 +178,46 @@ class ImputerSuite extends MLTest with DefaultReadWriteTest { assert(newInstance.surrogateDF.collect() === instance.surrogateDF.collect()) } - test("Imputer for Numeric with default missing Value NaN") { - val df = spark.createDataFrame(Seq( - (0, 1.0, 1.0, 1.0), - (1, 11.0, 11.0, 11.0), - (2, 3.6, 3.6, 3.6), - (3, Double.NaN, 5.2, 3.6) - )).toDF("id", "value1", "expected_mean_value1_double", "expected_median_value1_double") + test("Imputer for IntegerType with default missing value null") { + + val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( + (1, 1, 1), + (11, 11, 11), + (3, 3, 3), + (null, 5, 3) + )).toDF("value1", "expected_mean_value1", "expected_median_value1") val imputer = new Imputer() .setInputCols(Array("value1")) .setOutputCols(Array("out1")) - val types = Seq(ShortType, IntegerType, LongType, FloatType, DoubleType, - ByteType, DecimalType(10, 0)) + val types = Seq(IntegerType, LongType) + for (mType <- types) { + // cast all columns to desired data type for testing + val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*) + ImputerSuite.iterateStrategyTest(imputer, df2) + } + } + + test("Imputer for IntegerType with missing value -1") { + + val df = spark.createDataFrame(Seq[(Integer, Integer, Integer)]( + (1, 1, 1), + (11, 11, 11), + (3, 3, 3), + (-1, 5, 3) + )).toDF("value1", "expected_mean_value1", "expected_median_value1") + + val imputer = new Imputer() + .setInputCols(Array("value1")) + .setOutputCols(Array("out1")) + .setMissingValue(-1.0) + val types = Seq(IntegerType, LongType) for (mType <- types) { - // cast `value` to desired data type for testing - val df2 = 
df.withColumn("value1", col("value1").cast(mType)) - .withColumn("value1", when(col("value1").equalTo(0), null).otherwise(col("value1"))) - - Seq("mean", "median").foreach { strategy => - imputer.setStrategy(strategy) - val model = imputer.fit(df2) - val resultDF = model.transform(df2) - - imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) => - - // check dataType is consistent between input and output - val inputType = resultDF.schema(inputCol).dataType - val outputType = resultDF.schema(outputCol).dataType - assert(inputType === outputType, "Output type is not the same as input type.") - - // check value - val expectedDoubleColumn = s"expected_${strategy}_${inputCol}_double" - resultDF.select(col(expectedDoubleColumn).cast(mType).cast(DoubleType), - col(outputCol).cast(DoubleType)) - .collect().foreach { - case Row(exp: Double, out: Double) => - assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5), - s"Imputed values differ. Expected: $exp, actual: $out") - } - } - } + // cast all columns to desired data type for testing + val df2 = df.select(df.columns.map(c => col(c).cast(mType)): _*) + ImputerSuite.iterateStrategyTest(imputer, df2) } } } @@ -237,6 +234,13 @@ object ImputerSuite { val model = imputer.fit(df) val resultDF = model.transform(df) imputer.getInputCols.zip(imputer.getOutputCols).foreach { case (inputCol, outputCol) => + + // check dataType is consistent between input and output + val inputType = resultDF.schema(inputCol).dataType + val outputType = resultDF.schema(outputCol).dataType + assert(inputType == outputType, "Output type is not the same as input type.") + + // check value resultDF.select(s"expected_${strategy}_$inputCol", outputCol).collect().foreach { case Row(exp: Float, out: Float) => assert((exp.isNaN && out.isNaN) || (exp == out), @@ -244,6 +248,12 @@ object ImputerSuite { case Row(exp: Double, out: Double) => assert((exp.isNaN && out.isNaN) || (exp ~== out absTol 1e-5), s"Imputed values differ. 
Expected: $exp, actual: $out") + case Row(exp: Integer, out: Integer) => + assert(exp == out, + s"Imputed values differ. Expected: $exp, actual: $out") + case Row(exp: Long, out: Long) => + assert(exp == out, + s"Imputed values differ. Expected: $exp, actual: $out") } } }