From bba0832e48b86c7d5a7206c91bb37925dd043f26 Mon Sep 17 00:00:00 2001
From: Kent Yao
Date: Wed, 3 Jul 2024 17:12:16 +0800
Subject: [PATCH 1/2] [SPARK-48792][SQL] Fix bug for INSERT with partial column list to a table with char/varchar

---
 .../analysis/TableOutputResolver.scala   | 37 ++++++++++---------
 .../spark/sql/CharVarcharTestSuite.scala | 13 ++++++-
 2 files changed, 31 insertions(+), 19 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
index 1398552399cd7..98cbdf3f4aa36 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
@@ -83,13 +83,9 @@ object TableOutputResolver extends SQLConfHelper with Logging {
       // TODO: Only DS v1 writing will set it to true. We should enable in for DS v2 as well.
       supportColDefaultValue: Boolean = false): LogicalPlan = {
 
-    val actualExpectedCols = expected.map { attr =>
-      attr.withDataType(CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType))
-    }
-
-    if (actualExpectedCols.size < query.output.size) {
+    if (expected.size < query.output.size) {
       throw QueryCompilationErrors.cannotWriteTooManyColumnsToTableError(
-        tableName, actualExpectedCols.map(_.name), query.output)
+        tableName, expected.map(_.name), query.output)
     }
 
     val errors = new mutable.ArrayBuffer[String]()
@@ -100,21 +96,21 @@ object TableOutputResolver extends SQLConfHelper with Logging {
       reorderColumnsByName(
         tableName,
         query.output,
-        actualExpectedCols,
+        expected,
         conf,
         errors += _,
         fillDefaultValue = supportColDefaultValue)
     } else {
-      if (actualExpectedCols.size > query.output.size) {
+      if (expected.size > query.output.size) {
         throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError(
-          tableName, actualExpectedCols.map(_.name), query.output)
+          tableName, expected.map(_.name), query.output)
       }
-      resolveColumnsByPosition(tableName, query.output, actualExpectedCols, conf, errors += _)
+      resolveColumnsByPosition(tableName, query.output, expected, conf, errors += _)
     }
 
     if (errors.nonEmpty) {
       throw QueryCompilationErrors.incompatibleDataToTableCannotFindDataError(
-        tableName, actualExpectedCols.map(_.name).map(toSQLId).mkString(", "))
+        tableName, expected.map(_.name).map(toSQLId).mkString(", "))
     }
 
     if (resolved == query.output) {
@@ -246,22 +242,25 @@ object TableOutputResolver extends SQLConfHelper with Logging {
           case a: Alias => a.withName(expectedName)
           case other => other
         }
-        (matchedCol.dataType, expectedCol.dataType) match {
+        val replacedExpectedCol = expectedCol.withDataType {
+          CharVarcharUtils.getRawType(expectedCol.metadata).getOrElse(expectedCol.dataType)
+        }
+        (matchedCol.dataType, replacedExpectedCol.dataType) match {
           case (matchedType: StructType, expectedType: StructType) =>
             resolveStructType(
-              tableName, matchedCol, matchedType, expectedCol, expectedType,
+              tableName, matchedCol, matchedType, replacedExpectedCol, expectedType,
               byName = true, conf, addError, newColPath)
           case (matchedType: ArrayType, expectedType: ArrayType) =>
             resolveArrayType(
-              tableName, matchedCol, matchedType, expectedCol, expectedType,
+              tableName, matchedCol, matchedType, replacedExpectedCol, expectedType,
               byName = true, conf, addError, newColPath)
           case (matchedType: MapType, expectedType: MapType) =>
             resolveMapType(
-              tableName, matchedCol, matchedType, expectedCol, expectedType,
+              tableName, matchedCol, matchedType, replacedExpectedCol, expectedType,
               byName = true, conf, addError, newColPath)
           case _ =>
             checkField(
-              tableName, expectedCol, matchedCol, byName = true, conf, addError, newColPath)
+              tableName, replacedExpectedCol, matchedCol, byName = true, conf, addError, newColPath)
         }
       }
     }
@@ -288,11 +287,13 @@ object TableOutputResolver extends SQLConfHelper with Logging {
   private def resolveColumnsByPosition(
       tableName: String,
       inputCols: Seq[NamedExpression],
-      expectedCols: Seq[Attribute],
+      expected: Seq[Attribute],
       conf: SQLConf,
       addError: String => Unit,
       colPath: Seq[String] = Nil): Seq[NamedExpression] = {
-
+    val expectedCols = expected.map { attr =>
+      attr.withDataType { CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType) }
+    }
     if (inputCols.size > expectedCols.size) {
       val extraColsStr = inputCols.takeRight(inputCols.size - expectedCols.size)
         .map(col => toSQLId(col.name))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
index a93dee3bf2a61..c05e919a49fc6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
@@ -81,7 +81,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
         checkPlainResult(spark.table("t"), typ, v)
       }
       sql("INSERT OVERWRITE t VALUES ('1', null)")
-      checkPlainResult(spark.table("t"), typ, null)
+      (spark.table("t"), typ, null)
     }
   }
 }
@@ -661,6 +661,17 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
       }
     }
   }
+
+  test("SPARK-48792: Fix INSERT with partial column list to a table with char/varchar") {
+    Seq("char", "varchar").foreach { typ =>
+      withTable("students") {
+        sql(s"CREATE TABLE students (name $typ(64), address $typ(64)) USING $format")
+        sql("INSERT INTO students VALUES ('Kent Yao', 'Hangzhou')")
+        sql("INSERT INTO students (address) VALUES ('')")
+        checkAnswer(sql("SELECT count(*) FROM students"), Row(2))
+      }
+    }
+  }
 }
 
 // Some basic char/varchar tests which doesn't rely on table implementation.
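What PATCH 1/2 changes: resolveOutputColumns no longer rewrites every expected
column with its raw char/varchar type (via CharVarcharUtils.getRawType) up
front; instead, reorderColumnsByName and resolveColumnsByPosition perform that
substitution per column, after name matching and default-value filling have
happened. The spark-shell sketch below mirrors the scenario of the added test;
the USING parquet clause and the surrounding session are illustrative
assumptions, not part of the patch:

    // Reproduction sketch for SPARK-48792 (hypothetical session, mirrors the new test).
    // Before the fix, the partial-column INSERT below failed during analysis; the
    // diff suggests the expected columns still carried raw char/varchar types when
    // the omitted "name" column was filled in with its default value.
    spark.sql("CREATE TABLE students (name CHAR(64), address VARCHAR(64)) USING parquet")
    spark.sql("INSERT INTO students VALUES ('Kent Yao', 'Hangzhou')")
    spark.sql("INSERT INTO students (address) VALUES ('')") // partial column list
    spark.sql("SELECT count(*) FROM students").show()       // expected count: 2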
From 0d8b9d313a1ab46678217c06ee568c49caab1335 Mon Sep 17 00:00:00 2001
From: Kent Yao
Date: Wed, 3 Jul 2024 20:03:11 +0800
Subject: [PATCH 2/2] fix

---
 .../analysis/TableOutputResolver.scala   | 30 +++++++++--------
 .../spark/sql/CharVarcharTestSuite.scala |  4 ++-
 2 files changed, 18 insertions(+), 16 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
index 98cbdf3f4aa36..5b559becbb118 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala
@@ -242,25 +242,25 @@ object TableOutputResolver extends SQLConfHelper with Logging {
           case a: Alias => a.withName(expectedName)
           case other => other
         }
-        val replacedExpectedCol = expectedCol.withDataType {
+        val actualExpectedCol = expectedCol.withDataType {
           CharVarcharUtils.getRawType(expectedCol.metadata).getOrElse(expectedCol.dataType)
         }
-        (matchedCol.dataType, replacedExpectedCol.dataType) match {
+        (matchedCol.dataType, actualExpectedCol.dataType) match {
           case (matchedType: StructType, expectedType: StructType) =>
             resolveStructType(
-              tableName, matchedCol, matchedType, replacedExpectedCol, expectedType,
+              tableName, matchedCol, matchedType, actualExpectedCol, expectedType,
               byName = true, conf, addError, newColPath)
           case (matchedType: ArrayType, expectedType: ArrayType) =>
             resolveArrayType(
-              tableName, matchedCol, matchedType, replacedExpectedCol, expectedType,
+              tableName, matchedCol, matchedType, actualExpectedCol, expectedType,
               byName = true, conf, addError, newColPath)
           case (matchedType: MapType, expectedType: MapType) =>
             resolveMapType(
-              tableName, matchedCol, matchedType, replacedExpectedCol, expectedType,
+              tableName, matchedCol, matchedType, actualExpectedCol, expectedType,
               byName = true, conf, addError, newColPath)
           case _ =>
             checkField(
-              tableName, replacedExpectedCol, matchedCol, byName = true, conf, addError, newColPath)
+              tableName, actualExpectedCol, matchedCol, byName = true, conf, addError, newColPath)
         }
       }
     }
@@ -287,32 +287,32 @@ object TableOutputResolver extends SQLConfHelper with Logging {
   private def resolveColumnsByPosition(
       tableName: String,
       inputCols: Seq[NamedExpression],
-      expected: Seq[Attribute],
+      expectedCols: Seq[Attribute],
       conf: SQLConf,
       addError: String => Unit,
       colPath: Seq[String] = Nil): Seq[NamedExpression] = {
-    val expectedCols = expected.map { attr =>
+    val actualExpectedCols = expectedCols.map { attr =>
       attr.withDataType { CharVarcharUtils.getRawType(attr.metadata).getOrElse(attr.dataType) }
     }
-    if (inputCols.size > expectedCols.size) {
-      val extraColsStr = inputCols.takeRight(inputCols.size - expectedCols.size)
+    if (inputCols.size > actualExpectedCols.size) {
+      val extraColsStr = inputCols.takeRight(inputCols.size - actualExpectedCols.size)
         .map(col => toSQLId(col.name))
         .mkString(", ")
       if (colPath.isEmpty) {
         throw QueryCompilationErrors.cannotWriteTooManyColumnsToTableError(tableName,
-          expectedCols.map(_.name), inputCols.map(_.toAttribute))
+          actualExpectedCols.map(_.name), inputCols.map(_.toAttribute))
       } else {
         throw QueryCompilationErrors.incompatibleDataToTableExtraStructFieldsError(
           tableName, colPath.quoted, extraColsStr
         )
       }
-    } else if (inputCols.size < expectedCols.size) {
-      val missingColsStr = expectedCols.takeRight(expectedCols.size - inputCols.size)
+    } else if (inputCols.size < actualExpectedCols.size) {
+      val missingColsStr = actualExpectedCols.takeRight(actualExpectedCols.size - inputCols.size)
         .map(col => toSQLId(col.name))
         .mkString(", ")
       if (colPath.isEmpty) {
         throw QueryCompilationErrors.cannotWriteNotEnoughColumnsToTableError(tableName,
-          expectedCols.map(_.name), inputCols.map(_.toAttribute))
+          actualExpectedCols.map(_.name), inputCols.map(_.toAttribute))
       } else {
         throw QueryCompilationErrors.incompatibleDataToTableStructMissingFieldsError(
           tableName, colPath.quoted, missingColsStr
@@ -320,7 +320,7 @@ object TableOutputResolver extends SQLConfHelper with Logging {
         )
       }
     }
-    inputCols.zip(expectedCols).flatMap { case (inputCol, expectedCol) =>
+    inputCols.zip(actualExpectedCols).flatMap { case (inputCol, expectedCol) =>
       val newColPath = colPath :+ expectedCol.name
       (inputCol.dataType, expectedCol.dataType) match {
         case (inputType: StructType, expectedType: StructType) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
index c05e919a49fc6..5df46ea101c19 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala
@@ -81,7 +81,7 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
         checkPlainResult(spark.table("t"), typ, v)
       }
       sql("INSERT OVERWRITE t VALUES ('1', null)")
-      (spark.table("t"), typ, null)
+      checkPlainResult(spark.table("t"), typ, null)
     }
   }
 }
@@ -663,6 +663,8 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils {
   }
 
   test("SPARK-48792: Fix INSERT with partial column list to a table with char/varchar") {
+    assume(format != "foo",
+      "TODO: TableOutputResolver.resolveOutputColumns supportColDefaultValue is false")
     Seq("char", "varchar").foreach { typ =>
       withTable("students") {
         sql(s"CREATE TABLE students (name $typ(64), address $typ(64)) USING $format")
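What PATCH 2/2 changes: it renames replacedExpectedCol and the rebuilt
expectedCols to actualExpectedCol/actualExpectedCols, restores the
checkPlainResult call that PATCH 1/2 accidentally reduced to a bare tuple, and
guards the new test with assume(), since per its TODO message the test-only
"foo" format reaches resolveOutputColumns with supportColDefaultValue = false
and so cannot fill the omitted column. Both patches lean on the metadata
round-trip that CharVarcharUtils provides; the sketch below illustrates it
under the assumption that the metadata key is "__CHAR_VARCHAR_TYPE_STRING"
(it is illustrative, not code from the patch):

    import org.apache.spark.sql.catalyst.util.CharVarcharUtils
    import org.apache.spark.sql.types._

    // A resolved table schema stores VARCHAR(64) as a StringType field whose
    // metadata remembers the original type; built by hand here for illustration.
    val md = new MetadataBuilder()
      .putString("__CHAR_VARCHAR_TYPE_STRING", "varchar(64)")
      .build()
    val addressCol = StructField("address", StringType, nullable = true, md)

    // getRawType recovers VarcharType(64) from that metadata. After these patches,
    // the recovery happens per matched column inside reorderColumnsByName and
    // resolveColumnsByPosition, instead of on all expected columns up front.
    assert(CharVarcharUtils.getRawType(addressCol.metadata).contains(VarcharType(64)))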