diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md
index dc9a49e69aa5a..72fc74e87a48d 100644
--- a/docs/sql-migration-guide.md
+++ b/docs/sql-migration-guide.md
@@ -22,6 +22,10 @@ license: |
 * Table of contents
 {:toc}
 
+## Upgrading from Spark SQL 3.2 to 3.3
+
+  - In Spark 3.3, Spark fails when parsing a JSON/CSV string in `PERMISSIVE` mode if the schema contains non-nullable fields. You can set the mode to `FAILFAST` or `DROPMALFORMED` if you want to read JSON/CSV with a schema that contains non-nullable fields.
+
 ## Upgrading from Spark SQL 3.1 to 3.2
 
   - Since Spark 3.2, ADD FILE/JAR/ARCHIVE commands require each path to be enclosed by `"` or `'` if the path contains whitespaces.
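Before the parser changes below, here is roughly how the new 3.3 behavior looks from the DataFrame reader. This is a minimal sketch, not part of the patch: the schema, the sample record, and the application name are made up for illustration, and the reader calls are the standard `DataFrameReader` API.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

object NonNullableJsonSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("non-nullable-json").getOrCreate()
    import spark.implicits._

    // c2 is declared non-nullable, but the sample record leaves it out.
    val schema = StructType(Seq(
      StructField("c1", IntegerType, nullable = true),
      StructField("c2", IntegerType, nullable = false)))
    val records = Seq("""{"c1": 1}""").toDS()

    // PERMISSIVE (the default) now rejects a non-nullable schema outright,
    // FAILFAST fails on the offending record, and DROPMALFORMED drops it.
    val dropped = spark.read.option("mode", "DROPMALFORMED").schema(schema).json(records)
    assert(dropped.count() == 0)

    spark.stop()
  }
}
```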
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
index 8a1191c5b7ee2..8b6c211e759fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala
@@ -405,10 +405,18 @@ class JacksonParser(
       schema.getFieldIndex(parser.getCurrentName) match {
         case Some(index) =>
           try {
-            row.update(index, fieldConverters(index).apply(parser))
+            val fieldValue = fieldConverters(index).apply(parser)
+            val isIllegal =
+              options.parseMode != PermissiveMode && !schema(index).nullable && fieldValue == null
+            if (isIllegal) {
+              throw new IllegalSchemaArgumentException(
+                s"field ${schema(index).name} is not nullable but the parsed value is null.")
+            }
+            row.update(index, fieldValue)
             skipRow = structFilters.skipRow(row, index)
           } catch {
             case e: SparkUpgradeException => throw e
+            case e: IllegalSchemaArgumentException => throw e
             case NonFatal(e) if isRoot =>
               badRecordException = badRecordException.orElse(Some(e))
               parser.skipChildren()
@@ -418,6 +426,9 @@ class JacksonParser(
       }
     }
 
+    // When a field in the input schema is set to `nullable = false`, make sure it is not null.
+    checkNotNullableInRow(row, schema, skipRow, badRecordException)
+
     if (skipRow) {
       None
     } else if (badRecordException.isEmpty) {
@@ -427,6 +438,28 @@ class JacksonParser(
     }
   }
 
+  // As PERMISSIVE mode only works with nullable fields, we can skip this non-nullability check
+  // when the mode is PERMISSIVE (see FailureSafeParser.checkNullabilityForPermissiveMode).
+  private lazy val checkNotNullableInRow = if (options.parseMode != PermissiveMode) {
+    (row: GenericInternalRow,
+        schema: StructType,
+        skipRow: Boolean,
+        runtimeExceptionOption: Option[Throwable]) => {
+      if (runtimeExceptionOption.isEmpty && !skipRow) {
+        var index = 0
+        while (index < schema.length) {
+          if (!schema(index).nullable && row.isNullAt(index)) {
+            throw new IllegalSchemaArgumentException(
+              s"field ${schema(index).name} is not nullable but it's missing in one record.")
+          }
+          index += 1
+        }
+      }
+    }
+  } else {
+    (_: GenericInternalRow, _: StructType, _: Boolean, _: Option[Throwable]) => {}
+  }
+
   /**
    * Parse an object as a Map, preserving all fields.
    */
@@ -483,6 +516,7 @@ class JacksonParser(
       }
     } catch {
       case e: SparkUpgradeException => throw e
+      case e: IllegalSchemaArgumentException => throw e
       case e @ (_: RuntimeException | _: JsonProcessingException | _: MalformedInputException) =>
         // JSON parser currently doesn't support partial results for corrupted records.
         // For such records, all fields other than the field configured by
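The parser now distinguishes two situations for a non-nullable field: a field that is present but explicitly `null` is caught inside the field-conversion `try` ("the parsed value is null"), while a field that never appears in the record can only be detected after the whole object has been consumed, which is what `checkNotNullableInRow` does ("missing in one record"). Below is a simplified standalone sketch of that second pass, using a plain `Seq[Any]` instead of `GenericInternalRow`; the helper name and the generic exception type are made up for illustration.

```scala
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Walk the schema and complain if a non-nullable field ended up null, either because the
// parsed value was null or because the field never appeared in the record.
def checkNonNullableFields(schema: StructType, values: Seq[Any]): Unit = {
  var index = 0
  while (index < schema.length) {
    if (!schema(index).nullable && values(index) == null) {
      throw new IllegalArgumentException(
        s"field ${schema(index).name} is not nullable but it is null in one record.")
    }
    index += 1
  }
}

val schema = StructType(Seq(
  StructField("c1", IntegerType, nullable = true),
  StructField("c2", IntegerType, nullable = false)))

checkNonNullableFields(schema, Seq(1, 2))       // passes
// checkNonNullableFields(schema, Seq(1, null)) // would throw for c2
```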
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala
index d719a33929fcc..0a07a6c9c4862 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/BadRecordException.scala
@@ -41,3 +41,8 @@ case class BadRecordException(
     record: () => UTF8String,
     partialResult: () => Option[InternalRow],
     cause: Throwable) extends Exception(cause)
+
+/**
+ * Exception thrown when the actual value is null but the schema field is set to non-nullable.
+ */
+case class IllegalSchemaArgumentException(message: String) extends Exception(message)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala
index ab7c9310bf844..008ace454da1a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/FailureSafeParser.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
 import org.apache.spark.sql.errors.QueryExecutionErrors
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
 import org.apache.spark.unsafe.types.UTF8String
 
 class FailureSafeParser[IN](
@@ -29,11 +29,30 @@ class FailureSafeParser[IN](
     schema: StructType,
     columnNameOfCorruptRecord: String) {
 
+  checkNullabilityForPermissiveMode()
   private val corruptFieldIndex = schema.getFieldIndex(columnNameOfCorruptRecord)
   private val actualSchema = StructType(schema.filterNot(_.name == columnNameOfCorruptRecord))
   private val resultRow = new GenericInternalRow(schema.length)
   private val nullResult = new GenericInternalRow(schema.length)
 
+  // PERMISSIVE mode should never fail at runtime, so fail here if the mode is PERMISSIVE and the
+  // schema contains non-nullable fields.
+  private def checkNullabilityForPermissiveMode(): Unit = {
+    def checkNotNullableRecursively(schema: StructType): Unit = {
+      schema.fields.foreach {
+        case StructField(name, _, nullable, _) if !nullable =>
+          throw new IllegalSchemaArgumentException(s"Field $name is not nullable but " +
+            "PERMISSIVE mode only works with nullable fields.")
+        case StructField(_, dt: StructType, _, _) => checkNotNullableRecursively(dt)
+        case _ =>
+      }
+    }
+    mode match {
+      case PermissiveMode => checkNotNullableRecursively(schema)
+      case _ =>
+    }
+  }
+
   // This function takes 2 parameters: an optional partial result, and the bad record. If the given
   // schema doesn't contain a field for corrupted record, we just return the partial result or a
   // row with all fields null. If the given schema contains a field for corrupted record, we will
@@ -67,6 +86,8 @@ class FailureSafeParser[IN](
         case FailFastMode =>
           throw QueryExecutionErrors.malformedRecordsDetectedInRecordParsingError(e)
       }
+      case _: IllegalSchemaArgumentException if mode == DropMalformedMode =>
+        Iterator.empty
     }
   }
 }
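`FailureSafeParser` now validates the schema before any record is parsed: in `PERMISSIVE` mode every field, including the fields of nested structs, must be nullable, and in `DROPMALFORMED` mode a record that violates nullability is simply dropped. Below is a hedged, standalone sketch of the same recursive traversal that can be run without a Spark session; the function name is made up, and, as in the patch, it only descends into struct-typed fields (not arrays or maps).

```scala
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

// Reject the first non-nullable field found, recursing into struct-typed fields the same way
// the new checkNotNullableRecursively helper does.
def assertAllNullable(schema: StructType): Unit = schema.fields.foreach {
  case StructField(name, _, nullable, _) if !nullable =>
    throw new IllegalArgumentException(s"Field $name is not nullable")
  case StructField(_, nested: StructType, _, _) => assertAllNullable(nested)
  case _ =>
}

val schema = StructType(Seq(
  StructField("outer", StructType(Seq(
    StructField("inner", IntegerType, nullable = false))), nullable = true)))

// Fails on the nested field "inner" even though the outer struct itself is nullable.
try assertAllNullable(schema) catch {
  case e: IllegalArgumentException => println(e.getMessage) // Field inner is not nullable
}
```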
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JacksonParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JacksonParserSuite.scala
index 587e22e787b87..e32a03038af6c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JacksonParserSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/json/JacksonParserSuite.scala
@@ -19,24 +19,27 @@ package org.apache.spark.sql.catalyst.json
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.IllegalSchemaArgumentException
 import org.apache.spark.sql.sources.{EqualTo, Filter, StringStartsWith}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
 class JacksonParserSuite extends SparkFunSuite {
-  test("skipping rows using pushdown filters") {
-    def check(
+
+  private def check(
       input: String = """{"i":1, "s": "a"}""",
       schema: StructType = StructType.fromDDL("i INTEGER"),
       filters: Seq[Filter],
+      config: Map[String, String] = Map.empty,
       expected: Seq[InternalRow]): Unit = {
-      val options = new JSONOptions(Map.empty[String, String], "GMT", "")
-      val parser = new JacksonParser(schema, options, false, filters)
-      val createParser = CreateJacksonParser.string _
-      val actual = parser.parse(input, createParser, UTF8String.fromString)
-      assert(actual === expected)
-    }
+    val options = new JSONOptions(config, "GMT", "")
+    val parser = new JacksonParser(schema, options, false, filters)
+    val createParser = CreateJacksonParser.string _
+    val actual = parser.parse(input, createParser, UTF8String.fromString)
+    assert(actual === expected)
+  }
 
+  test("skipping rows using pushdown filters") {
     check(filters = Seq(), expected = Seq(InternalRow(1)))
     check(filters = Seq(EqualTo("i", 1)), expected = Seq(InternalRow(1)))
     check(filters = Seq(EqualTo("i", 2)), expected = Seq.empty)
@@ -54,4 +57,42 @@ class JacksonParserSuite extends SparkFunSuite {
       filters = Seq(EqualTo("d", 3.14)),
       expected = Seq(InternalRow(1, 3.14)))
   }
+
+  test("SPARK-35912: nullability with different schema nullable setting") {
+    val missingFieldInput = """{"c1":1}"""
+    val nullValueInput = """{"c1": 1, "c2": null}"""
+
+    def assertAction(nullable: Boolean, input: String)(action: => Unit): Unit = {
+      if (nullable) {
+        action
+      } else {
+        val msg = intercept[IllegalSchemaArgumentException] {
+          action
+        }.message
+        val expected = if (input == missingFieldInput) {
+          "field c2 is not nullable but it's missing in one record."
+        } else {
+          "field c2 is not nullable but the parsed value is null."
+        }
+        assert(msg.contains(expected))
+      }
+    }
+
+    Seq("FAILFAST", "DROPMALFORMED").foreach { mode =>
+      val config = Map("mode" -> mode)
+      Seq(true, false).foreach { nullable =>
+        val schema = StructType(Seq(
+          StructField("c1", IntegerType),
+          StructField("c2", IntegerType, nullable = nullable)
+        ))
+        val expected = Seq(InternalRow(1, null))
+        Seq(missingFieldInput, nullValueInput).foreach { input =>
+          assertAction(nullable, input) {
+            check(input = input, schema = schema, filters = Seq.empty, config = config,
+              expected = expected)
+          }
+        }
+      }
+    }
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedAPISuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedAPISuite.scala
index 25b8849d61248..9555317f46bbe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedAPISuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DeprecatedAPISuite.scala
@@ -182,8 +182,8 @@ class DeprecatedAPISuite extends QueryTest with SharedSparkSession {
     jsonDF = sqlContext.jsonRDD(jsonRDD.toJavaRDD())
     checkAnswer(jsonDF, Row(18, "Marry") :: Row(20, "Jack") :: Nil)
 
-    schema = StructType(StructField("name", StringType, false) ::
-      StructField("age", IntegerType, false) :: Nil)
+    schema = StructType(StructField("name", StringType, true) ::
+      StructField("age", IntegerType, true) :: Nil)
     jsonDF = sqlContext.jsonRDD(jsonRDD, schema)
     checkAnswer(jsonDF, Row("Jack", 20) :: Row("Marry", 18) :: Nil)
     jsonDF = sqlContext.jsonRDD(jsonRDD.toJavaRDD(), schema)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index cc52b6d8a14a7..4b779ef6d52f5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -147,8 +147,8 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque
       "{\"id\":2,\"vec\":[2.25,4.5,8.75]}"
     )
     val schema = StructType(Seq(
-      StructField("id", IntegerType, false),
-      StructField("vec", new TestUDT.MyDenseVectorUDT, false)
+      StructField("id", IntegerType, true),
+      StructField("vec", new TestUDT.MyDenseVectorUDT, true)
     ))
 
     val jsonRDD = spark.read.schema(schema).json(data.toDS())
@@ -167,8 +167,8 @@ class UserDefinedTypeSuite extends QueryTest with SharedSparkSession with Parque
     )
 
     val schema = StructType(Seq(
-      StructField("id", IntegerType, false),
-      StructField("vec", new TestUDT.MyDenseVectorUDT, false)
+      StructField("id", IntegerType, true),
+      StructField("vec", new TestUDT.MyDenseVectorUDT, true)
     ))
 
     val jsonDataset = spark.read.schema(schema).json(data.toDS())
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index dab1255eeab32..d16dfc7a7bdc1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -924,7 +924,7 @@ abstract class JsonSuite
   test("Applying schemas with MapType") {
     withTempView("jsonWithSimpleMap", "jsonWithComplexMap") {
       val schemaWithSimpleMap = StructType(
-        StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
+        StructField("map", MapType(StringType, IntegerType, true), true) :: Nil)
       val jsonWithSimpleMap = spark.read.schema(schemaWithSimpleMap).json(mapType1)
 
       jsonWithSimpleMap.createOrReplaceTempView("jsonWithSimpleMap")
@@ -953,7 +953,7 @@ abstract class JsonSuite
         StructField("field1", ArrayType(IntegerType, true), true) ::
         StructField("field2", IntegerType, true) :: Nil)
       val schemaWithComplexMap = StructType(
-        StructField("map", MapType(StringType, innerStruct, true), false) :: Nil)
+        StructField("map", MapType(StringType, innerStruct, true), true) :: Nil)
 
       val jsonWithComplexMap = spark.read.schema(schemaWithComplexMap).json(mapType2)
 
@@ -1392,7 +1392,7 @@ abstract class JsonSuite
     withSQLConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD.key -> "_unparsed") {
       withTempDir { dir =>
         val schemaWithSimpleMap = StructType(
-          StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
+          StructField("map", MapType(StringType, IntegerType, true), true) :: Nil)
         val df = spark.read.schema(schemaWithSimpleMap).json(mapType1)
 
         val path = dir.getAbsolutePath
@@ -2919,6 +2919,55 @@ abstract class JsonSuite
       }
     }
   }
+
+  test("SPARK-35912: nullability with different parse mode -- struct") {
+    // JSON field is missing.
+    val missingFieldInput = """{"c1": 1}"""
+    // JSON field is null.
+    val nullValueInput = """{"c1": 1, "c2": null}"""
+
+    val load = (mode: String, schema: StructType, inputJson: String) => {
+      val json = spark.createDataset(
+        spark.sparkContext.parallelize(inputJson :: Nil))(Encoders.STRING)
+      spark.read
+        .option("mode", mode)
+        .schema(schema)
+        .json(json)
+    }
+
+    Seq(true, false).foreach { nullable =>
+      val schema = StructType(Seq(
+        StructField("c1", IntegerType, nullable = true),
+        StructField("c2", IntegerType, nullable = nullable)))
+
+      Seq(missingFieldInput, nullValueInput).foreach { jsonString =>
+        if (nullable) {
+          checkAnswer(load("DROPMALFORMED", schema, jsonString), Row(1, null) :: Nil)
+          checkAnswer(load("FAILFAST", schema, jsonString), Row(1, null) :: Nil)
+          checkAnswer(load("PERMISSIVE", schema, jsonString), Row(1, null) :: Nil)
+        } else {
+          checkAnswer(load("DROPMALFORMED", schema, jsonString), Seq.empty)
+
+          val exceptionMsg1 = intercept[SparkException] {
+            load("FAILFAST", schema, jsonString).collect
+          }.getMessage
+          val expectedMsg1 = if (jsonString == missingFieldInput) {
+            "field c2 is not nullable but it's missing in one record."
+          } else {
+            "field c2 is not nullable but the parsed value is null."
+          }
+          assert(exceptionMsg1.contains(expectedMsg1))
+
+          val exceptionMsg2 = intercept[SparkException] {
+            load("PERMISSIVE", schema, jsonString).collect
+          }
+          val expectedMsg2 =
+            "Field c2 is not nullable but PERMISSIVE mode only works with nullable fields."
+          assert(exceptionMsg2.getMessage.contains(expectedMsg2))
+        }
+      }
+    }
+  }
 }
 
 class JsonV1Suite extends JsonSuite {
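For existing jobs that pass a non-nullable schema to the JSON/CSV reader and rely on the default `PERMISSIVE` mode (the situation the updated suites above were in), one migration path besides switching modes is to relax the schema before reading. A minimal sketch; `spark`, `strictSchema`, and `jsonDataset` are assumed to already exist in the calling code.

```scala
import org.apache.spark.sql.types.StructType

// Copy every top-level field as nullable so PERMISSIVE mode accepts the schema.
// Nested struct fields would need the same treatment if the schema contains any.
def relaxed(schema: StructType): StructType =
  StructType(schema.fields.map(_.copy(nullable = true)))

val lenientDF = spark.read
  .schema(relaxed(strictSchema)) // strictSchema: the original non-nullable StructType
  .json(jsonDataset)             // jsonDataset: a Dataset[String] of JSON records
```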