From bd7f0d5cf4afc7f349232932adb0a97c8a58a8db Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 3 Jan 2017 23:06:50 +0800 Subject: [PATCH] [SPARK-18877][SQL][BACKPORT-2.0] `CSVInferSchema.inferField` on DecimalType should find a common type with `typeSoFar` CSV type inferencing causes `IllegalArgumentException` on decimal numbers with heterogeneous precisions and scales because the current logic uses the last decimal type in a **partition**. Specifically, `inferRowType`, the **seqOp** of **aggregate**, returns the last decimal type. This PR fixes it to use `findTightestCommonType`. **decimal.csv** ``` 9.03E+12 1.19E+11 ``` **BEFORE** ```scala scala> spark.read.format("csv").option("inferSchema", true).load("decimal.csv").printSchema root |-- _c0: decimal(3,-9) (nullable = true) scala> spark.read.format("csv").option("inferSchema", true).load("decimal.csv").show 16/12/16 14:32:49 ERROR Executor: Exception in task 0.0 in stage 4.0 (TID 4) java.lang.IllegalArgumentException: requirement failed: Decimal precision 4 exceeds max precision 3 ``` **AFTER** ```scala scala> spark.read.format("csv").option("inferSchema", true).load("decimal.csv").printSchema root |-- _c0: decimal(4,-9) (nullable = true) scala> spark.read.format("csv").option("inferSchema", true).load("decimal.csv").show +---------+ | _c0| +---------+ |9.030E+12| | 1.19E+11| +---------+ ``` Pass the newly add test case. --- .../datasources/csv/CSVInferSchema.scala | 4 +++- .../datasources/csv/CSVInferSchemaSuite.scala | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 3ab775c909238..bc2071ee5a50c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -85,7 +85,9 @@ private[csv] object CSVInferSchema { case NullType => tryParseInteger(field, options) case IntegerType => tryParseInteger(field, options) case LongType => tryParseLong(field, options) - case _: DecimalType => tryParseDecimal(field, options) + case _: DecimalType => + // DecimalTypes have different precisions and scales, so we try to find the common type. + findTightestCommonType(typeSoFar, tryParseDecimal(field, options)).getOrElse(StringType) case DoubleType => tryParseDouble(field, options) case TimestampType => tryParseTimestamp(field, options) case BooleanType => tryParseBoolean(field, options) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 5e00f669b8593..7aca644796e76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -109,4 +109,21 @@ class CSVInferSchemaSuite extends SparkFunSuite { val mergedNullTypes = CSVInferSchema.mergeRowTypes(Array(NullType), Array(NullType)) assert(mergedNullTypes.deep == Array(NullType).deep) } + + test("SPARK-18877: `inferField` on DecimalType should find a common type with `typeSoFar`") { + val options = new CSVOptions(Map.empty[String, String]) + + // 9.03E+12 is Decimal(3, -10) and 1.19E+11 is Decimal(3, -9). + assert(CSVInferSchema.inferField(DecimalType(3, -10), "1.19E+11", options) == + DecimalType(4, -9)) + + // BigDecimal("12345678901234567890.01234567890123456789") is precision 40 and scale 20. + val value = "12345678901234567890.01234567890123456789" + assert(CSVInferSchema.inferField(DecimalType(3, -10), value, options) == DoubleType) + + // Seq(s"${Long.MaxValue}1", "2015-12-01 00:00:00") should be StringType + assert(CSVInferSchema.inferField(NullType, s"${Long.MaxValue}1", options) == DecimalType(20, 0)) + assert(CSVInferSchema.inferField(DecimalType(20, 0), "2015-12-01 00:00:00", options) + == StringType) + } }