From 435a3dfacad4e655e377df6eb9ac8cd819207950 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 18 Nov 2020 23:45:38 +0800 Subject: [PATCH 01/11] support char/varchar type --- .../sql/catalyst/analysis/Analyzer.scala | 9 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 6 +- .../catalyst/analysis/ResolveCatalogs.scala | 5 - .../analysis/TableOutputResolver.scala | 19 +- .../sql/catalyst/catalog/SessionCatalog.scala | 10 +- .../sql/catalyst/parser/AstBuilder.scala | 13 +- .../catalyst/plans/logical/v2Commands.scala | 4 +- .../sql/catalyst/util/CharVarcharUtils.scala | 277 +++++++++++++ .../sql/connector/catalog/CatalogV2Util.scala | 18 +- .../datasources/v2/DataSourceV2Relation.scala | 8 +- .../org/apache/spark/sql/types/CharType.scala | 36 ++ .../org/apache/spark/sql/types/DataType.scala | 8 +- .../spark/sql/types/HiveStringType.scala | 81 ---- .../apache/spark/sql/types/VarcharType.scala | 35 ++ .../org/apache/spark/sql/types/package.scala | 10 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 18 +- .../parser/TableSchemaParserSuite.scala | 11 +- .../spark/sql/connector/InMemoryTable.scala | 15 +- .../catalog/CatalogV2UtilSuite.scala | 2 +- .../scala/org/apache/spark/sql/Column.scala | 6 +- .../apache/spark/sql/DataFrameReader.scala | 14 +- .../analysis/ResolveSessionCatalog.scala | 36 +- .../datasources/ApplyCharTypePadding.scala | 135 +++++++ .../datasources/LogicalRelation.scala | 18 +- .../datasources/jdbc/JdbcUtils.scala | 12 +- .../internal/BaseSessionStateBuilder.scala | 1 + .../sql/streaming/DataStreamReader.scala | 10 +- .../spark/sql/CharVarcharTestSuite.scala | 374 ++++++++++++++++++ .../command/PlanResolutionSuite.scala | 44 +-- .../spark/sql/sources/TableScanSuite.scala | 14 +- .../sql/hive/HiveSessionStateBuilder.scala | 1 + .../sql/hive/client/HiveClientImpl.scala | 19 +- .../spark/sql/HiveCharVarcharTestSuite.scala | 43 ++ .../sql/hive/HiveMetastoreCatalogSuite.scala | 15 +- .../sql/hive/execution/HiveDDLSuite.scala | 4 +- 35 files changed, 1034 insertions(+), 297 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/HiveCharVarcharTestSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 837686420375a..af1a2bc9db978 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils} import org.apache.spark.sql.connector.catalog._ import 
org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnChange, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType} @@ -3097,7 +3097,12 @@ class Analyzer(override val catalogManager: CatalogManager) val projection = TableOutputResolver.resolveOutputColumns( v2Write.table.name, v2Write.table.output, v2Write.query, v2Write.isByName, conf) if (projection != v2Write.query) { - v2Write.withNewQuery(projection) + val cleanedTable = v2Write.table match { + case r: DataSourceV2Relation => + r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata)) + case other => other + } + v2Write.withNewQuery(projection).withNewTable(cleanedTable) } else { v2Write } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 9998035d65c3f..d567107c9fc8d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.optimizer.BooleanSimplification import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, TypeUtils} import org.apache.spark.sql.connector.catalog.{SupportsAtomicPartitionManagement, SupportsPartitionManagement, Table} import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, After, ColumnPosition, DeleteColumn, RenameColumn, UpdateColumnComment, UpdateColumnNullability, UpdateColumnPosition, UpdateColumnType} import org.apache.spark.sql.internal.SQLConf @@ -94,6 +94,10 @@ trait CheckAnalysis extends PredicateHelper { case p if p.analyzed => // Skip already analyzed sub-plans + case leaf: LeafNode if leaf.output.map(_.dataType).exists(CharVarcharUtils.hasCharVarchar) => + throw new IllegalStateException( + "[BUG] leaf logical plan should not have output of char/varchar type: " + leaf) + case u: UnresolvedNamespace => u.failAnalysis(s"Namespace not found: ${u.multipartIdentifier.quoted}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala index deeb8215d22c6..128ca1278bf54 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveCatalogs.scala @@ -35,7 +35,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case AlterTableAddColumnsStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), cols) => cols.foreach(c => failNullType(c.dataType)) - cols.foreach(c => failCharType(c.dataType)) val changes = cols.map { col => TableChange.addColumn( col.name.toArray, @@ -49,7 +48,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case AlterTableReplaceColumnsStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), cols) => cols.foreach(c => failNullType(c.dataType)) - cols.foreach(c => failCharType(c.dataType)) val changes: Seq[TableChange] = loadTable(catalog, tbl.asIdentifier) match { 
case Some(table) => // REPLACE COLUMNS deletes all the existing columns and adds new columns specified. @@ -72,7 +70,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case a @ AlterTableAlterColumnStatement( nameParts @ NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _) => a.dataType.foreach(failNullType) - a.dataType.foreach(failCharType) val colName = a.column.toArray val typeChange = a.dataType.map { newDataType => TableChange.updateColumnType(colName, newDataType) @@ -145,7 +142,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case c @ CreateTableStatement( NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _) => assertNoNullTypeInSchema(c.tableSchema) - assertNoCharTypeInSchema(c.tableSchema) CreateV2Table( catalog.asTableCatalog, tbl.asIdentifier, @@ -173,7 +169,6 @@ class ResolveCatalogs(val catalogManager: CatalogManager) case c @ ReplaceTableStatement( NonSessionCatalogAndTable(catalog, tbl), _, _, _, _, _, _, _, _, _) => assertNoNullTypeInSchema(c.tableSchema) - assertNoCharTypeInSchema(c.tableSchema) ReplaceTable( catalog.asTableCatalog, tbl.asIdentifier, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index 4f33ca99c02db..c6bba370c8fef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Alias, AnsiCast, Attribute, Cast, NamedExpression} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.StoreAssignmentPolicy import org.apache.spark.sql.types.DataType @@ -93,19 +94,17 @@ object TableOutputResolver { tableAttr.metadata == queryExpr.metadata) { Some(queryExpr) } else { - // Renaming is needed for handling the following cases like - // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 - // 2) Target tables have column metadata - storeAssignmentPolicy match { + val casted = storeAssignmentPolicy match { case StoreAssignmentPolicy.ANSI => - Some(Alias( - AnsiCast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)), - tableAttr.name)(explicitMetadata = Option(tableAttr.metadata))) + AnsiCast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)) case _ => - Some(Alias( - Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)), - tableAttr.name)(explicitMetadata = Option(tableAttr.metadata))) + Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)) } + val strLenChecked = CharVarcharUtils.stringLengthCheck(casted, tableAttr) + // Renaming is needed for handling the following cases like + // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 + // 2) Target tables have column metadata + Some(Alias(strLenChecked, tableAttr.name)(explicitMetadata = Some(tableAttr.metadata))) } storeAssignmentPolicy match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 
17ab6664df75c..ec8daf3eb46d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, ImplicitCastInputTypes} import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} -import org.apache.spark.sql.catalyst.util.StringUtils +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, StringUtils} import org.apache.spark.sql.connector.catalog.CatalogManager import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.GLOBAL_TEMP_DATABASE @@ -473,7 +473,13 @@ class SessionCatalog( val table = formatTableName(name.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Some(db))) - externalCatalog.getTable(db, table) + removeCharVarcharFromTableSchema(externalCatalog.getTable(db, table)) + } + + // We replace char/varchar with string type in the table schema, as Spark's type system doesn't + // support char/varchar yet. + private def removeCharVarcharFromTableSchema(t: CatalogTable): CatalogTable = { + t.copy(schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(t.schema)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 50580b8e335ff..960aee4be6f89 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -36,8 +36,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, IntervalUtils} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} -import org.apache.spark.sql.catalyst.util.IntervalUtils import org.apache.spark.sql.catalyst.util.IntervalUtils.IntervalUnit import org.apache.spark.sql.connector.catalog.{SupportsNamespaces, TableCatalog} import org.apache.spark.sql.connector.catalog.TableChange.ColumnPosition @@ -2216,7 +2216,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg * Create a Spark DataType. */ private def visitSparkDataType(ctx: DataTypeContext): DataType = { - HiveStringType.replaceCharType(typedVisit(ctx)) + CharVarcharUtils.replaceCharVarcharWithString(typedVisit(ctx)) } /** @@ -2291,16 +2291,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg builder.putString("comment", _) } - // Add Hive type string to metadata. 
- val rawDataType = typedVisit[DataType](ctx.dataType) - val cleanedDataType = HiveStringType.replaceCharType(rawDataType) - if (rawDataType != cleanedDataType) { - builder.putString(HIVE_TYPE_STRING, rawDataType.catalogString) - } - StructField( name = colName.getText, - dataType = cleanedDataType, + dataType = typedVisit[DataType](ctx.dataType), nullable = NULL == null, metadata = builder.build()) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index 5bda2b5b8db01..1081fec37490c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.{NamedRelation, PartitionSpec, Res import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, Expression, Unevaluable} import org.apache.spark.sql.catalyst.plans.DescribeCommandSchema +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, ColumnChange} import org.apache.spark.sql.connector.expressions.Transform @@ -45,9 +46,10 @@ trait V2WriteCommand extends Command { table.skipSchemaResolution || (query.output.size == table.output.size && query.output.zip(table.output).forall { case (inAttr, outAttr) => + val outType = CharVarcharUtils.getRawType(outAttr.metadata).getOrElse(outAttr.dataType) // names and types must match, nullability must be compatible inAttr.name == outAttr.name && - DataType.equalsIgnoreCompatibleNullability(inAttr.dataType, outAttr.dataType) && + DataType.equalsIgnoreCompatibleNullability(inAttr.dataType, outType) && (outAttr.nullable || !inAttr.nullable) }) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala new file mode 100644 index 0000000000000..6b867b36c62d4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.util + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.types._ + +object CharVarcharUtils { + + private val CHAR_VARCHAR_TYPE_STRING_METADATA_KEY = "__CHAR_VARCHAR_TYPE_STRING" + + /** + * Replaces CharType/VarcharType with StringType recursively in the given struct type. If a + * top-level StructField's data type is CharType/VarcharType or has nested CharType/VarcharType, + * this method will add the original type string to the StructField's metadata, so that we can + * re-construct the original data type with CharType/VarcharType later when needed. + */ + def replaceCharVarcharWithStringInSchema(st: StructType): StructType = { + StructType(st.map { field => + if (hasCharVarchar(field.dataType)) { + val metadata = new MetadataBuilder().withMetadata(field.metadata) + .putString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY, field.dataType.sql).build() + field.copy(dataType = replaceCharVarcharWithString(field.dataType), metadata = metadata) + } else { + field + } + }) + } + + /** + * Returns true if the given data type is CharType/VarcharType or has nested CharType/VarcharType. + */ + def hasCharVarchar(dt: DataType): Boolean = { + dt.existsRecursively(f => f.isInstanceOf[CharType] || f.isInstanceOf[VarcharType]) + } + + /** + * Replaces CharType/VarcharType with StringType recursively in the given data type. + */ + def replaceCharVarcharWithString(dt: DataType): DataType = dt match { + case ArrayType(et, nullable) => + ArrayType(replaceCharVarcharWithString(et), nullable) + case MapType(kt, vt, nullable) => + MapType(replaceCharVarcharWithString(kt), replaceCharVarcharWithString(vt), nullable) + case StructType(fields) => + StructType(fields.map { field => + field.copy(dataType = replaceCharVarcharWithString(field.dataType)) + }) + case _: CharType => StringType + case _: VarcharType => StringType + case _ => dt + } + + /** + * Removes the metadata entry that contains the original type string of CharType/VarcharType from + * the given attribute's metadata. + */ + def cleanAttrMetadata(attr: AttributeReference): AttributeReference = { + val cleaned = new MetadataBuilder().withMetadata(attr.metadata) + .remove(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY).build() + attr.withMetadata(cleaned) + } + + /** + * Re-construct the original data type from the type string in the given metadata. + * This is needed when dealing with char/varchar columns/fields. + */ + def getRawType(metadata: Metadata): Option[DataType] = { + if (metadata.contains(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY)) { + Some(CatalystSqlParser.parseRawDataType( + metadata.getString(CHAR_VARCHAR_TYPE_STRING_METADATA_KEY))) + } else { + None + } + } + + /** + * Re-construct the original StructType from the type strings in the metadata of StructFields. + * This is needed when dealing with char/varchar columns/fields. + */ + def getRawSchema(schema: StructType): StructType = { + StructType(schema.map { field => + getRawType(field.metadata).map(rawType => field.copy(dataType = rawType)).getOrElse(field) + }) + } + + /** + * Returns expressions to apply read-side char type padding for the given attributes. String + * values should be right-padded to N characters if it's from a CHAR(N) column/field. 
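As a quick illustration of the read-side padding described in this file (a sketch only — the SparkSession, the table name `t`, and the Parquet provider are assumptions for the example, not part of this patch):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("char-padding-sketch").getOrCreate()
spark.sql("CREATE TABLE t (c CHAR(5)) USING parquet")
spark.sql("INSERT INTO t VALUES ('ab')")
// The scan of `t` gets wrapped in a projection roughly equivalent to SELECT rpad(c, 5, ' ') AS c,
// so the stored value 'ab' reads back as 'ab   ' and length(c) is 5.
spark.sql("SELECT c, length(c) FROM t").show()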
+ */ + def charTypePadding(output: Seq[AttributeReference]): Seq[NamedExpression] = { + output.map { attr => + getRawType(attr.metadata).filter { rawType => + rawType.existsRecursively(_.isInstanceOf[CharType]) + }.map { rawType => + Alias(charTypePadding(attr, rawType), attr.name)(explicitMetadata = Some(attr.metadata)) + }.getOrElse(attr) + } + } + + private def charTypePadding(expr: Expression, dt: DataType): Expression = dt match { + case CharType(length) => StringRPad(expr, Literal(length)) + + case StructType(fields) => + CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) => + Seq(Literal(f.name), charTypePadding(GetStructField(expr, i, Some(f.name)), f.dataType)) + }) + + case ArrayType(et, containsNull) => charTypePaddingInArray(expr, et, containsNull) + + case MapType(kt, vt, valueContainsNull) => + val newKeys = charTypePaddingInArray(MapKeys(expr), kt, containsNull = false) + val newValues = charTypePaddingInArray(MapValues(expr), vt, valueContainsNull) + MapFromArrays(newKeys, newValues) + + case _ => expr + } + + private def charTypePaddingInArray( + arr: Expression, et: DataType, containsNull: Boolean): Expression = { + val param = NamedLambdaVariable("x", replaceCharVarcharWithString(et), containsNull) + val func = LambdaFunction(charTypePadding(param, et), Seq(param)) + ArrayTransform(arr, func) + } + + /** + * Returns an expression to apply write-side char type padding for the given expression. A string + * value can not exceed N characters if it's written into a CHAR(N)/VARCHAR(N) column/field. + */ + def stringLengthCheck(expr: Expression, targetAttr: Attribute): Expression = { + getRawType(targetAttr.metadata).map { rawType => + stringLengthCheck(expr, rawType) + }.getOrElse(expr) + } + + private def stringLengthCheck(expr: Expression, dt: DataType): Expression = dt match { + case CharType(length) => + val trimmed = StringTrimRight(expr) + val errorMsg = Concat(Seq( + Literal("input string '"), + expr, + Literal(s"' exceeds char type length limitation: $length"))) + // Trailing spaces do not count in the length check. We don't need to retain the trailing + // spaces, as we will pad char type columns/fields at read time. + If( + GreaterThan(Length(trimmed), Literal(length)), + Cast(RaiseError(errorMsg), StringType), + trimmed) + + case VarcharType(length) => + val trimmed = StringTrimRight(expr) + val errorMsg = Concat(Seq( + Literal("input string '"), + expr, + Literal(s"' exceeds varchar type length limitation: $length"))) + // Trailing spaces do not count in the length check. We need to retain the trailing spaces + // (truncate to length N), as there is no read-time padding for varchar type. + // TODO: create a special TrimRight function that can trim to a certain length. 
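A rough sketch of how this write-side length check surfaces to users (the table name `c5` and the `spark` session are assumptions carried over from the earlier sketch); the VARCHAR branch of the check continues right below:

spark.sql("CREATE TABLE c5 (c CHAR(5)) USING parquet")
spark.sql("INSERT INTO c5 VALUES ('spark')")    // OK: exactly 5 characters
spark.sql("INSERT INTO c5 VALUES ('spark  ')")  // OK: trailing spaces are trimmed before the check
// The following would fail at runtime via the RaiseError branch above, with the message
// "input string 'sparkSQL' exceeds char type length limitation: 5":
// spark.sql("INSERT INTO c5 VALUES ('sparkSQL')")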
+ If( + LessThanOrEqual(Length(expr), Literal(length)), + expr, + If( + GreaterThan(Length(trimmed), Literal(length)), + Cast(RaiseError(errorMsg), StringType), + StringRPad(trimmed, Literal(length)))) + + case StructType(fields) => + CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) => + Seq(Literal(f.name), stringLengthCheck(GetStructField(expr, i, Some(f.name)), f.dataType)) + }) + + case ArrayType(et, containsNull) => stringLengthCheckInArray(expr, et, containsNull) + + case MapType(kt, vt, valueContainsNull) => + val newKeys = stringLengthCheckInArray(MapKeys(expr), kt, containsNull = false) + val newValues = stringLengthCheckInArray(MapValues(expr), vt, valueContainsNull) + MapFromArrays(newKeys, newValues) + + case _ => expr + } + + private def stringLengthCheckInArray( + arr: Expression, et: DataType, containsNull: Boolean): Expression = { + val param = NamedLambdaVariable("x", replaceCharVarcharWithString(et), containsNull) + val func = LambdaFunction(stringLengthCheck(param, et), Seq(param)) + ArrayTransform(arr, func) + } + + /** + * Return expressions to apply char type padding for the string comparison between the given + * attributes. When comparing two char type columns/fields, we need to pad the shorter one to + * the longer length. + */ + def addPaddingInStringComparison(attrs: Seq[Attribute]): Seq[Expression] = { + val rawTypes = attrs.map(attr => getRawType(attr.metadata)) + if (rawTypes.exists(_.isEmpty)) { + attrs + } else { + val typeWithTargetCharLength = rawTypes.map(_.get).reduce(typeWithWiderCharLength) + attrs.zip(rawTypes.map(_.get)).map { case (attr, rawType) => + padCharToTargetLength(attr, rawType, typeWithTargetCharLength).getOrElse(attr) + } + } + } + + private def typeWithWiderCharLength(type1: DataType, type2: DataType): DataType = { + (type1, type2) match { + case (CharType(len1), CharType(len2)) => + CharType(math.max(len1, len2)) + case (StructType(fields1), StructType(fields2)) => + assert(fields1.length == fields2.length) + StructType(fields1.zip(fields2).map { case (left, right) => + StructField("", typeWithWiderCharLength(left.dataType, right.dataType)) + }) + case (ArrayType(et1, _), ArrayType(et2, _)) => + ArrayType(typeWithWiderCharLength(et1, et2)) + case (MapType(kt1, vt1, _), MapType(kt2, vt2, _)) => + MapType(typeWithWiderCharLength(kt1, kt2), typeWithWiderCharLength(vt1, vt2)) + case _ => NullType + } + } + + private def padCharToTargetLength( + expr: Expression, + rawType: DataType, + typeWithTargetCharLength: DataType): Option[Expression] = { + (rawType, typeWithTargetCharLength) match { + case (CharType(len), CharType(target)) if target > len => + Some(StringRPad(expr, Literal(target))) + + case (StructType(fields), StructType(targets)) => + assert(fields.length == targets.length) + var i = 0 + var needPadding = false + val createStructExprs = mutable.ArrayBuffer.empty[Expression] + while (i < fields.length) { + val field = fields(i) + val fieldExpr = GetStructField(expr, i, Some(field.name)) + val padded = padCharToTargetLength(fieldExpr, field.dataType, targets(i).dataType) + needPadding = padded.isDefined + createStructExprs += Literal(field.name) + createStructExprs += padded.getOrElse(fieldExpr) + i += 1 + } + if (needPadding) Some(CreateNamedStruct(createStructExprs)) else None + + case (ArrayType(et, containsNull), ArrayType(target, _)) => + val param = NamedLambdaVariable("x", replaceCharVarcharWithString(et), containsNull) + padCharToTargetLength(param, et, target).map { padded => + val func = LambdaFunction(padded, 
Seq(param)) + ArrayTransform(expr, func) + } + + // We don't handle MapType here as it's not comparable. + + case _ => None + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index 1a3a7207c6ca9..a14e165f7adf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -24,11 +24,10 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{NamedRelation, NoSuchDatabaseException, NoSuchNamespaceException, NoSuchTableException, UnresolvedV2Relation} -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.AlterTable import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation -import org.apache.spark.sql.types.{ArrayType, DataType, HIVE_TYPE_STRING, HiveStringType, MapType, NullType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, NullType, StructField, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.Utils @@ -332,21 +331,6 @@ private[sql] object CatalogV2Util { .asTableCatalog } - def failCharType(dt: DataType): Unit = { - if (HiveStringType.containsCharType(dt)) { - throw new AnalysisException( - "Cannot use CHAR type in non-Hive-Serde tables, please use STRING type instead.") - } - } - - def assertNoCharTypeInSchema(schema: StructType): Unit = { - schema.foreach { f => - if (f.metadata.contains(HIVE_TYPE_STRING)) { - failCharType(CatalystSqlParser.parseRawDataType(f.metadata.getString(HIVE_TYPE_STRING))) - } - } - } - def failNullType(dt: DataType): Unit = { def containsNullType(dt: DataType): Boolean = dt match { case ArrayType(et, _) => containsNullType(et) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index f541411daeff4..a31b1cc924fc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils} import org.apache.spark.sql.connector.catalog.{CatalogPlugin, Identifier, MetadataColumn, SupportsMetadataColumns, Table, TableCapability} import org.apache.spark.sql.connector.read.{Scan, Statistics => V2Statistics, SupportsReportStatistics} import org.apache.spark.sql.connector.read.streaming.{Offset, SparkDataStream} @@ -171,8 +171,10 @@ object DataSourceV2Relation { catalog: Option[CatalogPlugin], identifier: Option[Identifier], options: CaseInsensitiveStringMap): DataSourceV2Relation = { - val output = table.schema().toAttributes - 
DataSourceV2Relation(table, output, catalog, identifier, options) + // The v2 source may return schema containing char/varchar type. We replace char/varchar + // with string type here as Spark's type system doesn't support char/varchar yet. + val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(table.schema) + DataSourceV2Relation(table, schema.toAttributes, catalog, identifier, options) } def create( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala new file mode 100644 index 0000000000000..dce4bfaa4fab5 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.Experimental +import org.apache.spark.unsafe.types.UTF8String + +@Experimental +case class CharType(length: Int) extends AtomicType { + private[sql] type InternalType = UTF8String + @transient private[sql] lazy val tag = typeTag[InternalType] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + override def defaultSize: Int = length + override def typeName: String = s"char($length)" + override def toString: String = s"CharType($length)" + private[spark] override def asNullable: CharType = this +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 7556a19f0d316..6b871c9783471 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -124,6 +124,8 @@ abstract class DataType extends AbstractDataType { object DataType { private val FIXED_DECIMAL = """decimal\(\s*(\d+)\s*,\s*(\-?\d+)\s*\)""".r + private val CHAR_TYPE = """char\(\s*(\d+)\s*\)""".r + private val VARCHAR_TYPE = """varchar\(\s*(\d+)\s*\)""".r def fromDDL(ddl: String): DataType = { parseTypeWithFallback( @@ -166,7 +168,7 @@ object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) - private val nonDecimalNameToType = { + private val otherTypes = { Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType) .map(t => t.typeName -> t).toMap @@ -177,7 +179,9 @@ object DataType { name match { case "decimal" => DecimalType.USER_DEFAULT case FIXED_DECIMAL(precision, scale) => DecimalType(precision.toInt, scale.toInt) - case other => nonDecimalNameToType.getOrElse( + case CHAR_TYPE(length) => CharType(length.toInt) + 
case VARCHAR_TYPE(length) => VarcharType(length.toInt) + case other => otherTypes.getOrElse( other, throw new IllegalArgumentException( s"Failed to convert the JSON string '$name' to a data type.")) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala deleted file mode 100644 index a29f49ad14a77..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/HiveStringType.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.types - -import scala.math.Ordering -import scala.reflect.runtime.universe.typeTag - -import org.apache.spark.unsafe.types.UTF8String - -/** - * A hive string type for compatibility. These datatypes should only used for parsing, - * and should NOT be used anywhere else. Any instance of these data types should be - * replaced by a [[StringType]] before analysis. - */ -sealed abstract class HiveStringType extends AtomicType { - private[sql] type InternalType = UTF8String - - private[sql] val ordering = implicitly[Ordering[InternalType]] - - @transient private[sql] lazy val tag = typeTag[InternalType] - - override def defaultSize: Int = length - - private[spark] override def asNullable: HiveStringType = this - - def length: Int -} - -object HiveStringType { - def replaceCharType(dt: DataType): DataType = dt match { - case ArrayType(et, nullable) => - ArrayType(replaceCharType(et), nullable) - case MapType(kt, vt, nullable) => - MapType(replaceCharType(kt), replaceCharType(vt), nullable) - case StructType(fields) => - StructType(fields.map { field => - field.copy(dataType = replaceCharType(field.dataType)) - }) - case _: HiveStringType => StringType - case _ => dt - } - - def containsCharType(dt: DataType): Boolean = dt match { - case ArrayType(et, _) => containsCharType(et) - case MapType(kt, vt, _) => containsCharType(kt) || containsCharType(vt) - case StructType(fields) => fields.exists(f => containsCharType(f.dataType)) - case _ => dt.isInstanceOf[CharType] - } -} - -/** - * Hive char type. Similar to other HiveStringType's, these datatypes should only used for - * parsing, and should NOT be used anywhere else. Any instance of these data types should be - * replaced by a [[StringType]] before analysis. - */ -case class CharType(length: Int) extends HiveStringType { - override def simpleString: String = s"char($length)" -} - -/** - * Hive varchar type. Similar to other HiveStringType's, these datatypes should only used for - * parsing, and should NOT be used anywhere else. Any instance of these data types should be - * replaced by a [[StringType]] before analysis. 
- */ -case class VarcharType(length: Int) extends HiveStringType { - override def simpleString: String = s"varchar($length)" -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala new file mode 100644 index 0000000000000..14454550dd981 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.types + +import scala.math.Ordering +import scala.reflect.runtime.universe.typeTag + +import org.apache.spark.annotation.Experimental +import org.apache.spark.unsafe.types.UTF8String + +@Experimental +case class VarcharType(length: Int) extends AtomicType { + private[sql] type InternalType = UTF8String + @transient private[sql] lazy val tag = typeTag[InternalType] + private[sql] val ordering = implicitly[Ordering[InternalType]] + + override def defaultSize: Int = length + override def typeName: String = s"varchar($length)" + override def toString: String = s"VarcharType($length)" + private[spark] override def asNullable: VarcharType = this +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala index f29cbc2069e39..346a51ea10c82 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/package.scala @@ -21,12 +21,4 @@ package org.apache.spark.sql * Contains a type system for attributes produced by relations, including complex types like * structs, arrays and maps. */ -package object types { - /** - * Metadata key used to store the raw hive type string in the metadata of StructField. This - * is relevant for datatypes that do not have a direct Spark SQL counterpart, such as CHAR and - * VARCHAR. We need to preserve the original type in order to invoke the correct object - * inspector in Hive.
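Since CharType and VarcharType are now ordinary atomic types that must never reach query execution directly, the metadata round-trip handled by CharVarcharUtils can be sketched as follows (illustrative usage of the utilities added in this patch; the schema itself is made up):

import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.types._

val original = new StructType().add("c", CharType(5)).add("v", VarcharType(10))
// char/varchar fields become StringType, with the original type string stashed in the
// field metadata so it can be recovered later.
val cleaned = CharVarcharUtils.replaceCharVarcharWithStringInSchema(original)
assert(cleaned("c").dataType == StringType)
assert(CharVarcharUtils.getRawType(cleaned("c").metadata).contains(CharType(5)))
assert(CharVarcharUtils.getRawSchema(cleaned).map(_.dataType) == Seq(CharType(5), VarcharType(10)))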
- */ - val HIVE_TYPE_STRING = "HIVE_TYPE_STRING" -} +package object types diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index f0a24d4a56048..6820d5d189537 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import java.util.TimeZone +import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -41,9 +42,11 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning, RoundRobinPartitioning} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.connector.InMemoryTable +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ - +import org.apache.spark.sql.util.CaseInsensitiveStringMap class AnalysisSuite extends AnalysisTest with Matchers { import org.apache.spark.sql.catalyst.analysis.TestRelations._ @@ -55,6 +58,19 @@ class AnalysisSuite extends AnalysisTest with Matchers { } } + test("fail for leaf node with char/varchar type") { + val schema1 = new StructType().add("c", CharType(5)) + val schema2 = new StructType().add("c", VarcharType(5)) + val schema3 = new StructType().add("c", ArrayType(CharType(5))) + Seq(schema1, schema2, schema3).foreach { schema => + val table = new InMemoryTable("t", schema, Array.empty, Map.empty[String, String].asJava) + intercept[IllegalStateException] { + DataSourceV2Relation( + table, schema.toAttributes, None, None, CaseInsensitiveStringMap.empty()).analyze + } + } + } + test("union project *") { val plan = (1 to 120) .map(_ => testRelation) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala index 6803fc307f919..5519f016e48d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala @@ -57,11 +57,6 @@ class TableSchemaParserSuite extends SparkFunSuite { |anotherArray:Array> """.stripMargin.replace("\n", "") - val builder = new MetadataBuilder - builder.putString(HIVE_TYPE_STRING, - "struct," + - "MAP:map,arrAy:array,anotherArray:array>") - val expectedDataType = StructType( StructField("complexStructCol", StructType( @@ -69,11 +64,9 @@ class TableSchemaParserSuite extends SparkFunSuite { StructType( StructField("deciMal", DecimalType.USER_DEFAULT) :: StructField("anotherDecimal", DecimalType(5, 2)) :: Nil)) :: - StructField("MAP", MapType(TimestampType, StringType)) :: + StructField("MAP", MapType(TimestampType, VarcharType(10))) :: StructField("arrAy", ArrayType(DoubleType)) :: - StructField("anotherArray", ArrayType(StringType)) :: Nil), - nullable = true, - builder.build()) :: Nil) + StructField("anotherArray", ArrayType(CharType(9))) :: Nil)) :: Nil) assert(parse(tableSchemaString) === expectedDataType) } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala index ffff00b54f1b8..cfb044b428e41 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/InMemoryTable.scala @@ -28,7 +28,7 @@ import org.scalatest.Assertions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, JoinedRow} -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, DateTimeUtils} import org.apache.spark.sql.connector.catalog._ import org.apache.spark.sql.connector.expressions.{BucketTransform, DaysTransform, HoursTransform, IdentityTransform, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.connector.read._ @@ -116,11 +116,12 @@ class InMemoryTable( } } + val cleanedSchema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema) partitioning.map { case IdentityTransform(ref) => - extractor(ref.fieldNames, schema, row)._1 + extractor(ref.fieldNames, cleanedSchema, row)._1 case YearsTransform(ref) => - extractor(ref.fieldNames, schema, row) match { + extractor(ref.fieldNames, cleanedSchema, row) match { case (days: Int, DateType) => ChronoUnit.YEARS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) case (micros: Long, TimestampType) => @@ -130,7 +131,7 @@ class InMemoryTable( throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } case MonthsTransform(ref) => - extractor(ref.fieldNames, schema, row) match { + extractor(ref.fieldNames, cleanedSchema, row) match { case (days: Int, DateType) => ChronoUnit.MONTHS.between(EPOCH_LOCAL_DATE, DateTimeUtils.daysToLocalDate(days)) case (micros: Long, TimestampType) => @@ -140,7 +141,7 @@ class InMemoryTable( throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } case DaysTransform(ref) => - extractor(ref.fieldNames, schema, row) match { + extractor(ref.fieldNames, cleanedSchema, row) match { case (days, DateType) => days case (micros: Long, TimestampType) => @@ -149,14 +150,14 @@ class InMemoryTable( throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } case HoursTransform(ref) => - extractor(ref.fieldNames, schema, row) match { + extractor(ref.fieldNames, cleanedSchema, row) match { case (micros: Long, TimestampType) => ChronoUnit.HOURS.between(Instant.EPOCH, DateTimeUtils.microsToInstant(micros)) case (v, t) => throw new IllegalArgumentException(s"Match: unsupported argument(s) type - ($v, $t)") } case BucketTransform(numBuckets, ref) => - val (value, dataType) = extractor(ref.fieldNames, schema, row) + val (value, dataType) = extractor(ref.fieldNames, cleanedSchema, row) val valueHashCode = if (value == null) 0 else value.hashCode ((valueHashCode + 31 * dataType.hashCode()) & Integer.MAX_VALUE) % numBuckets } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala index 7a9a7f52ff8fd..da5cfab8be3c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogV2UtilSuite.scala @@ -28,7 +28,7 @@ class CatalogV2UtilSuite extends SparkFunSuite { val 
testCatalog = mock(classOf[TableCatalog]) val ident = mock(classOf[Identifier]) val table = mock(classOf[Table]) - when(table.schema()).thenReturn(mock(classOf[StructType])) + when(table.schema()).thenReturn(new StructType().add("i", "int")) when(testCatalog.loadTable(ident)).thenReturn(table) val r = CatalogV2Util.loadRelation(testCatalog, ident) assert(r.isDefined) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index c164835c753e8..b3e403ffa7382 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util.toPrettySQL +import org.apache.spark.sql.catalyst.util.{toPrettySQL, CharVarcharUtils} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.lit @@ -1181,7 +1181,9 @@ class Column(val expr: Expression) extends Logging { * @group expr_ops * @since 1.3.0 */ - def cast(to: DataType): Column = withExpr { Cast(expr, to) } + def cast(to: DataType): Column = withExpr { + Cast(expr, CharVarcharUtils.replaceCharVarcharWithString(to)) + } /** * Casts the column to a different data type, using the canonical string representation diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b26bc6441b6cf..3b5532ccb910f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.csv.{CSVHeaderChecker, CSVOptions, UnivocityParser} import org.apache.spark.sql.catalyst.expressions.ExprUtils import org.apache.spark.sql.catalyst.json.{CreateJacksonParser, JacksonParser, JSONOptions} -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailureSafeParser} +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, FailureSafeParser} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, SupportsCatalogOptions, SupportsRead} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.execution.command.DDLUtils @@ -274,11 +274,14 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { extraOptions + ("paths" -> objectMapper.writeValueAsString(paths.toArray)) } + val cleanedUserSpecifiedSchema = userSpecifiedSchema + .map(CharVarcharUtils.replaceCharVarcharWithStringInSchema) + val finalOptions = sessionOptions.filterKeys(!optionsWithPath.contains(_)).toMap ++ optionsWithPath.originalMap val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) val (table, catalog, ident) = provider match { - case _: SupportsCatalogOptions if userSpecifiedSchema.nonEmpty => + case _: SupportsCatalogOptions if cleanedUserSpecifiedSchema.nonEmpty => throw new IllegalArgumentException( s"$source does not support user specified schema. 
Please don't specify the schema.") case hasCatalog: SupportsCatalogOptions => @@ -290,7 +293,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { (catalog.loadTable(ident), Some(catalog), Some(ident)) case _ => // TODO: Non-catalog paths for DSV2 are currently not well defined. - val tbl = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema) + val tbl = DataSourceV2Utils.getTableFromProvider( + provider, dsOptions, cleanedUserSpecifiedSchema) (tbl, None, None) } import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ @@ -312,13 +316,15 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } else { (paths, extraOptions) } + val cleanedUserSpecifiedSchema = userSpecifiedSchema + .map(CharVarcharUtils.replaceCharVarcharWithStringInSchema) // Code path for data source v1. sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, paths = finalPaths, - userSpecifiedSchema = userSpecifiedSchema, + userSpecifiedSchema = cleanedUserSpecifiedSchema, className = source, options = finalOptions.originalMap).resolveRelation()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index 303ae47f06b84..506ffbb4aadc9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource} import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 -import org.apache.spark.sql.types.{HIVE_TYPE_STRING, HiveStringType, MetadataBuilder, StructField, StructType} +import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType} /** * Resolves catalogs from the multi-part identifiers in SQL statements, and convert the statements @@ -50,9 +50,6 @@ class ResolveSessionCatalog( cols.foreach(c => failNullType(c.dataType)) loadTable(catalog, tbl.asIdentifier).collect { case v1Table: V1Table => - if (!DDLUtils.isHiveTable(v1Table.v1Table)) { - cols.foreach(c => failCharType(c.dataType)) - } cols.foreach { c => assertTopLevelColumn(c.name, "AlterTableAddColumnsCommand") if (!c.nullable) { @@ -62,7 +59,6 @@ class ResolveSessionCatalog( } AlterTableAddColumnsCommand(tbl.asTableIdentifier, cols.map(convertToStructField)) }.getOrElse { - cols.foreach(c => failCharType(c.dataType)) val changes = cols.map { col => TableChange.addColumn( col.name.toArray, @@ -81,7 +77,6 @@ class ResolveSessionCatalog( case Some(_: V1Table) => throw new AnalysisException("REPLACE COLUMNS is only supported with v2 tables.") case Some(table) => - cols.foreach(c => failCharType(c.dataType)) // REPLACE COLUMNS deletes all the existing columns and adds new columns specified. 
val deleteChanges = table.schema.fieldNames.map { name => TableChange.deleteColumn(Array(name)) @@ -104,10 +99,6 @@ class ResolveSessionCatalog( a.dataType.foreach(failNullType) loadTable(catalog, tbl.asIdentifier).collect { case v1Table: V1Table => - if (!DDLUtils.isHiveTable(v1Table.v1Table)) { - a.dataType.foreach(failCharType) - } - if (a.column.length > 1) { throw new AnalysisException( "ALTER COLUMN with qualified column is only supported with v2 tables.") @@ -133,19 +124,13 @@ class ResolveSessionCatalog( s"Available: ${v1Table.schema.fieldNames.mkString(", ")}") } } - // Add Hive type string to metadata. - val cleanedDataType = HiveStringType.replaceCharType(dataType) - if (dataType != cleanedDataType) { - builder.putString(HIVE_TYPE_STRING, dataType.catalogString) - } val newColumn = StructField( colName, - cleanedDataType, + dataType, nullable = true, builder.build()) AlterTableChangeColumnCommand(tbl.asTableIdentifier, colName, newColumn) }.getOrElse { - a.dataType.foreach(failCharType) val colName = a.column.toArray val typeChange = a.dataType.map { newDataType => TableChange.updateColumnType(colName, newDataType) @@ -269,16 +254,12 @@ class ResolveSessionCatalog( assertNoNullTypeInSchema(c.tableSchema) val provider = c.provider.getOrElse(conf.defaultDataSourceName) if (!isV2Provider(provider)) { - if (!DDLUtils.isHiveTable(Some(provider))) { - assertNoCharTypeInSchema(c.tableSchema) - } val tableDesc = buildCatalogTable(tbl.asTableIdentifier, c.tableSchema, c.partitioning, c.bucketSpec, c.properties, provider, c.options, c.location, c.comment, c.ifNotExists) val mode = if (c.ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTable(tableDesc, mode, None) } else { - assertNoCharTypeInSchema(c.tableSchema) CreateV2Table( catalog.asTableCatalog, tbl.asIdentifier, @@ -328,7 +309,6 @@ class ResolveSessionCatalog( if (!isV2Provider(provider)) { throw new AnalysisException("REPLACE TABLE is only supported with v2 tables.") } else { - assertNoCharTypeInSchema(c.tableSchema) ReplaceTable( catalog.asTableCatalog, tbl.asIdentifier, @@ -716,17 +696,7 @@ class ResolveSessionCatalog( private def convertToStructField(col: QualifiedColType): StructField = { val builder = new MetadataBuilder col.comment.foreach(builder.putString("comment", _)) - - val cleanedDataType = HiveStringType.replaceCharType(col.dataType) - if (col.dataType != cleanedDataType) { - builder.putString(HIVE_TYPE_STRING, col.dataType.catalogString) - } - - StructField( - col.name.head, - cleanedDataType, - nullable = true, - builder.build()) + StructField(col.name.head, col.dataType, nullable = true, builder.build()) } private def isV2Provider(provider: String): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala new file mode 100644 index 0000000000000..35bb86f178eb1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation +import org.apache.spark.sql.catalyst.expressions.{Attribute, BinaryComparison, Expression, In, Literal, StringRPad} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.types.{CharType, StringType} +import org.apache.spark.unsafe.types.UTF8String + +/** + * This rule applies char type padding in two places: + * 1. When reading values from column/field of type CHAR(N), right-pad the values to length N. + * 2. When comparing char type column/field with string literal or char type column/field, + * right-pad the shorter one to the longer length. + */ +object ApplyCharTypePadding extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = { + val padded = plan.resolveOperatorsUpWithNewOutput { + case r: LogicalRelation => + val projectList = CharVarcharUtils.charTypePadding(r.output) + if (projectList == r.output) { + r -> Nil + } else { + val cleanedOutput = r.output.map(CharVarcharUtils.cleanAttrMetadata) + val padded = Project(projectList, r.copy(output = cleanedOutput)) + padded -> r.output.zip(padded.output) + } + + case r: DataSourceV2Relation => + val projectList = CharVarcharUtils.charTypePadding(r.output) + if (projectList == r.output) { + r -> Nil + } else { + val cleanedOutput = r.output.map(CharVarcharUtils.cleanAttrMetadata) + val padded = Project(projectList, r.copy(output = cleanedOutput)) + padded -> r.output.zip(padded.output) + } + + case r: HiveTableRelation => + val projectList = CharVarcharUtils.charTypePadding(r.output) + if (projectList == r.output) { + r -> Nil + } else { + val cleanedDataCols = r.dataCols.map(CharVarcharUtils.cleanAttrMetadata) + val cleanedPartCols = r.partitionCols.map(CharVarcharUtils.cleanAttrMetadata) + val padded = Project(projectList, + r.copy(dataCols = cleanedDataCols, partitionCols = cleanedPartCols)) + padded -> r.output.zip(padded.output) + } + } + + padded.resolveOperatorsUp { + case operator if operator.resolved => operator.transformExpressionsUp { + // String literal is treated as char type when it's compared to a char type column. + // We should pad the shorter one to the longer length. 
+ case b @ BinaryComparison(attr: Attribute, lit) if lit.foldable => + padAttrLitCmp(attr, lit).map { newChildren => + b.withNewChildren(newChildren) + }.getOrElse(b) + + case b @ BinaryComparison(lit, attr: Attribute) if lit.foldable => + padAttrLitCmp(attr, lit).map { newChildren => + b.withNewChildren(newChildren.reverse) + }.getOrElse(b) + + case i @ In(attr: Attribute, list) + if attr.dataType == StringType && list.forall(_.foldable) => + CharVarcharUtils.getRawType(attr.metadata).flatMap { + case CharType(length) => + val literalCharLengths = list.map(_.eval().asInstanceOf[UTF8String].numChars()) + val targetLen = (length +: literalCharLengths).max + Some(i.copy( + value = addPadding(attr, length, targetLen), + list = list.zip(literalCharLengths).map { + case (lit, charLength) => addPadding(lit, charLength, targetLen) + })) + case _ => None + }.getOrElse(i) + + // For char type column or inner field comparison, pad the shorter one to the longer length. + case b @ BinaryComparison(left: Attribute, right: Attribute) => + b.withNewChildren(CharVarcharUtils.addPaddingInStringComparison(Seq(left, right))) + + case i @ In(attr: Attribute, list) if list.forall(_.isInstanceOf[Attribute]) => + val newChildren = CharVarcharUtils.addPaddingInStringComparison( + attr +: list.map(_.asInstanceOf[Attribute])) + i.copy(value = newChildren.head, list = newChildren.tail) + } + } + } + + private def padAttrLitCmp(attr: Attribute, lit: Expression): Option[Seq[Expression]] = { + if (attr.dataType == StringType) { + CharVarcharUtils.getRawType(attr.metadata).flatMap { + case CharType(length) => + val str = lit.eval().asInstanceOf[UTF8String] + val stringLitLen = str.numChars() + if (length < stringLitLen) { + Some(Seq(StringRPad(attr, Literal(stringLitLen)), lit)) + } else if (length > stringLitLen) { + Some(Seq(attr, StringRPad(lit, Literal(length)))) + } else { + None + } + case _ => None + } + } else { + None + } + } + + private def addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = { + if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 33a3486bf6f67..0c6a80d441686 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.{AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.catalyst.util.truncatedString +import org.apache.spark.sql.catalyst.util.{truncatedString, CharVarcharUtils} import org.apache.spark.sql.sources.BaseRelation /** @@ -69,9 +69,17 @@ case class LogicalRelation( } object LogicalRelation { - def apply(relation: BaseRelation, isStreaming: Boolean = false): LogicalRelation = - LogicalRelation(relation, relation.schema.toAttributes, None, isStreaming) + def apply(relation: BaseRelation, isStreaming: Boolean = false): LogicalRelation = { + // The v1 source may return schema containing char/varchar type. We replace char/varchar + // with string type here as Spark's type system doesn't support char/varchar yet. 
+ val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(relation.schema) + LogicalRelation(relation, schema.toAttributes, None, isStreaming) + } - def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation = - LogicalRelation(relation, relation.schema.toAttributes, Some(table), false) + def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation = { + // The v1 source may return schema containing char/varchar type. We replace char/varchar + // with string type here as Spark's type system doesn't support char/varchar yet. + val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(relation.schema) + LogicalRelation(relation, schema.toAttributes, Some(table), false) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 78f31fb80ecf6..6733aab947be6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -761,16 +761,6 @@ object JdbcUtils extends Logging { schema: StructType, caseSensitive: Boolean, createTableColumnTypes: String): Map[String, String] = { - def typeName(f: StructField): String = { - // char/varchar gets translated to string type. Real data type specified by the user - // is available in the field metadata as HIVE_TYPE_STRING - if (f.metadata.contains(HIVE_TYPE_STRING)) { - f.metadata.getString(HIVE_TYPE_STRING) - } else { - f.dataType.catalogString - } - } - val userSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes) val nameEquality = if (caseSensitive) { org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution @@ -791,7 +781,7 @@ object JdbcUtils extends Logging { } } - val userSchemaMap = userSchema.fields.map(f => f.name -> typeName(f)).toMap + val userSchemaMap = userSchema.fields.map(f => f.name -> f.dataType.catalogString).toMap if (caseSensitive) userSchemaMap else CaseInsensitiveMap(userSchemaMap) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 538a5408723bb..a89a5de3b7e72 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -189,6 +189,7 @@ abstract class BaseSessionStateBuilder( PreprocessTableCreation(session) +: PreprocessTableInsertion +: DataSourceAnalysis +: + ApplyCharTypePadding +: customPostHocResolutionRules override val extendedCheckRules: Seq[LogicalPlan => Unit] = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 9bc4acd49a980..b9a1a465d9e52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -26,7 +26,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SparkSession} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.streaming.StreamingRelationV2 -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import 
org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils} import org.apache.spark.sql.connector.catalog.{SupportsRead, TableProvider} import org.apache.spark.sql.connector.catalog.TableCapability._ import org.apache.spark.sql.execution.command.DDLUtils @@ -203,6 +203,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo extraOptions + ("path" -> path.get) } + val cleanedUserSpecifiedSchema = userSpecifiedSchema + .map(CharVarcharUtils.replaceCharVarcharWithStringInSchema) + val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf). getConstructor().newInstance() // We need to generate the V1 data source so we can pass it to the V2 relation as a shim. @@ -210,7 +213,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo // writer or whether the query is continuous. val v1DataSource = DataSource( sparkSession, - userSpecifiedSchema = userSpecifiedSchema, + userSpecifiedSchema = cleanedUserSpecifiedSchema, className = source, options = optionsWithPath.originalMap) val v1Relation = ds match { @@ -225,7 +228,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo val finalOptions = sessionOptions.filterKeys(!optionsWithPath.contains(_)).toMap ++ optionsWithPath.originalMap val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) - val table = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema) + val table = DataSourceV2Utils.getTableFromProvider( + provider, dsOptions, cleanedUserSpecifiedSchema) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { case _: SupportsRead if table.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala new file mode 100644 index 0000000000000..e192a63956232 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -0,0 +1,374 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import org.apache.spark.{SparkConf, SparkException} +import org.apache.spark.sql.connector.InMemoryTableCatalog +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} + +trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { + + def format: String + + test("char type values should be padded: top-level columns") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c CHAR(5)) USING $format") + sql("INSERT INTO t VALUES ('1', 'a')") + checkAnswer(spark.table("t"), Row("1", "a" + " " * 4)) + } + } + + test("char type values should be padded: partitioned columns") { + // DS V2 doesn't support partitioned table. + if (!conf.contains(SQLConf.DEFAULT_CATALOG.key)) { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c CHAR(5)) USING $format PARTITIONED BY (c)") + sql("INSERT INTO t VALUES ('1', 'a')") + checkAnswer(spark.table("t"), Row("1", "a" + " " * 4)) + } + } + } + + test("char type values should be padded: nested in struct") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c STRUCT) USING $format") + sql("INSERT INTO t VALUES ('1', struct('a'))") + checkAnswer(spark.table("t"), Row("1", Row("a" + " " * 4))) + } + } + + test("char type values should be padded: nested in array") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c ARRAY) USING $format") + sql("INSERT INTO t VALUES ('1', array('a', 'ab'))") + checkAnswer(spark.table("t"), Row("1", Seq("a" + " " * 4, "ab" + " " * 3))) + } + } + + test("char type values should be padded: nested in map key") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c MAP) USING $format") + sql("INSERT INTO t VALUES ('1', map('a', 'ab'))") + checkAnswer(spark.table("t"), Row("1", Map(("a" + " " * 4, "ab")))) + } + } + + test("char type values should be padded: nested in map value") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c MAP) USING $format") + sql("INSERT INTO t VALUES ('1', map('a', 'ab'))") + checkAnswer(spark.table("t"), Row("1", Map(("a", "ab" + " " * 3)))) + } + } + + test("char type values should be padded: nested in both map key and value") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c MAP) USING $format") + sql("INSERT INTO t VALUES ('1', map('a', 'ab'))") + checkAnswer(spark.table("t"), Row("1", Map(("a" + " " * 4, "ab" + " " * 8)))) + } + } + + test("char type values should be padded: nested in struct of array") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c STRUCT>) USING $format") + sql("INSERT INTO t VALUES ('1', struct(array('a', 'ab')))") + checkAnswer(spark.table("t"), Row("1", Row(Seq("a" + " " * 4, "ab" + " " * 3)))) + } + } + + test("char type values should be padded: nested in array of struct") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c ARRAY>) USING $format") + sql("INSERT INTO t VALUES ('1', array(struct('a'), struct('ab')))") + checkAnswer(spark.table("t"), Row("1", Seq(Row("a" + " " * 4), Row("ab" + " " * 3)))) + } + } + + test("char type values should be padded: nested in array of array") { + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c ARRAY>) USING $format") + sql("INSERT INTO t VALUES ('1', array(array('a', 'ab')))") + checkAnswer(spark.table("t"), Row("1", Seq(Seq("a" + " " * 4, "ab" + " " * 3)))) + } + } + + private def testTableWrite(f: String => Unit): Unit = { + withTable("t") { f("char") } + withTable("t") { f("varchar") } + } + + test("length check for input string values: top-level columns") { + testTableWrite { typeName => + sql(s"CREATE TABLE 
t(c $typeName(5)) USING $format") + sql("INSERT INTO t VALUES (null)") + checkAnswer(spark.table("t"), Row(null)) + val e = intercept[SparkException](sql("INSERT INTO t VALUES ('123456')")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: partitioned columns") { + // DS V2 doesn't support partitioned table. + if (!conf.contains(SQLConf.DEFAULT_CATALOG.key)) { + testTableWrite { typeName => + sql(s"CREATE TABLE t(i INT, c $typeName(5)) USING $format PARTITIONED BY (c)") + sql("INSERT INTO t VALUES (1, null)") + checkAnswer(spark.table("t"), Row(1, null)) + val e = intercept[SparkException](sql("INSERT INTO t VALUES (1, '123456')")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + } + + test("length check for input string values: nested in struct") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c STRUCT) USING $format") + sql("INSERT INTO t SELECT struct(null)") + checkAnswer(spark.table("t"), Row(Row(null))) + val e = intercept[SparkException](sql("INSERT INTO t SELECT struct('123456')")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in array") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c ARRAY<$typeName(5)>) USING $format") + sql("INSERT INTO t VALUES (array(null))") + checkAnswer(spark.table("t"), Row(Seq(null))) + val e = intercept[SparkException](sql("INSERT INTO t VALUES (array('a', '123456'))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in map key") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c MAP<$typeName(5), STRING>) USING $format") + val e = intercept[SparkException](sql("INSERT INTO t VALUES (map('123456', 'a'))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in map value") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c MAP) USING $format") + sql("INSERT INTO t VALUES (map('a', null))") + checkAnswer(spark.table("t"), Row(Map("a" -> null))) + val e = intercept[SparkException](sql("INSERT INTO t VALUES (map('a', '123456'))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in both map key and value") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c MAP<$typeName(5), $typeName(5)>) USING $format") + val e1 = intercept[SparkException](sql("INSERT INTO t VALUES (map('123456', 'a'))")) + assert(e1.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + val e2 = intercept[SparkException](sql("INSERT INTO t VALUES (map('a', '123456'))")) + assert(e2.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in struct of array") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c STRUCT>) USING $format") + sql("INSERT INTO t SELECT struct(array(null))") + checkAnswer(spark.table("t"), Row(Row(Seq(null)))) + val e = intercept[SparkException](sql("INSERT INTO t SELECT 
struct(array('123456'))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in array of struct") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c ARRAY>) USING $format") + sql("INSERT INTO t VALUES (array(struct(null)))") + checkAnswer(spark.table("t"), Row(Seq(Row(null)))) + val e = intercept[SparkException](sql("INSERT INTO t VALUES (array(struct('123456')))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: nested in array of array") { + testTableWrite { typeName => + sql(s"CREATE TABLE t(c ARRAY>) USING $format") + sql("INSERT INTO t VALUES (array(array(null)))") + checkAnswer(spark.table("t"), Row(Seq(Seq(null)))) + val e = intercept[SparkException](sql("INSERT INTO t VALUES (array(array('123456')))")) + assert(e.getCause.getMessage.contains( + s"input string '123456' exceeds $typeName type length limitation: 5")) + } + } + + test("length check for input string values: with trailing spaces") { + withTable("t") { + sql(s"CREATE TABLE t(c1 CHAR(5), c2 VARCHAR(5)) USING $format") + sql("INSERT INTO t VALUES ('12 ', '12 ')") + sql("INSERT INTO t VALUES ('1234 ', '1234 ')") + checkAnswer(spark.table("t"), Seq( + Row("12" + " " * 3, "12 "), + Row("1234 ", "1234 "))) + } + } + + test("length check for input string values: with implicit cast") { + withTable("t") { + sql(s"CREATE TABLE t(c1 CHAR(5), c2 VARCHAR(5)) USING $format") + sql("INSERT INTO t VALUES (1234, 1234)") + checkAnswer(spark.table("t"), Row("1234 ", "1234")) + val e1 = intercept[SparkException](sql("INSERT INTO t VALUES (123456, 1)")) + assert(e1.getCause.getMessage.contains( + "input string '123456' exceeds char type length limitation: 5")) + val e2 = intercept[SparkException](sql("INSERT INTO t VALUES (1, 123456)")) + assert(e2.getCause.getMessage.contains( + "input string '123456' exceeds varchar type length limitation: 5")) + } + } + + private def testConditions(df: DataFrame, conditions: Seq[(String, Boolean)]): Unit = { + checkAnswer(df.selectExpr(conditions.map(_._1): _*), Row.fromSeq(conditions.map(_._2))) + } + + test("char type comparison: top-level columns") { + withTable("t") { + sql(s"CREATE TABLE t(c1 CHAR(2), c2 CHAR(5)) USING $format") + sql("INSERT INTO t VALUES ('a', 'a')") + testConditions(spark.table("t"), Seq( + ("c1 = 'a'", true), + ("'a' = c1", true), + ("c1 = 'a '", true), + ("c1 > 'a'", false), + ("c1 IN ('a', 'b')", true), + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: partitioned columns") { + withTable("t") { + sql(s"CREATE TABLE t(i INT, c1 CHAR(2), c2 CHAR(5)) USING $format PARTITIONED BY (c1, c2)") + sql("INSERT INTO t VALUES (1, 'a', 'a')") + testConditions(spark.table("t"), Seq( + ("c1 = 'a'", true), + ("'a' = c1", true), + ("c1 = 'a '", true), + ("c1 > 'a'", false), + ("c1 IN ('a', 'b')", true), + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: join") { + withTable("t1", "t2") { + sql(s"CREATE TABLE t1(c CHAR(2)) USING $format") + sql(s"CREATE TABLE t2(c CHAR(5)) USING $format") + sql("INSERT INTO t1 VALUES ('a')") + sql("INSERT INTO t2 VALUES ('a')") + checkAnswer(sql("SELECT t1.c FROM t1 JOIN t2 ON t1.c = t2.c"), Row("a ")) + } + } + + test("char type comparison: nested in struct") { + withTable("t") { + sql(s"CREATE 
TABLE t(c1 STRUCT, c2 STRUCT) USING $format") + sql("INSERT INTO t VALUES (struct('a'), struct('a'))") + testConditions(spark.table("t"), Seq( + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: nested in array") { + withTable("t") { + sql(s"CREATE TABLE t(c1 ARRAY, c2 ARRAY) USING $format") + sql("INSERT INTO t VALUES (array('a', 'b'), array('a', 'b'))") + testConditions(spark.table("t"), Seq( + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: nested in struct of array") { + withTable("t") { + sql("CREATE TABLE t(c1 STRUCT>, c2 STRUCT>) " + + s"USING $format") + sql("INSERT INTO t VALUES (struct(array('a', 'b')), struct(array('a', 'b')))") + testConditions(spark.table("t"), Seq( + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: nested in array of struct") { + withTable("t") { + sql("CREATE TABLE t(c1 ARRAY>, c2 ARRAY>) " + + s"USING $format") + sql("INSERT INTO t VALUES (array(struct('a')), array(struct('a')))") + testConditions(spark.table("t"), Seq( + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } + + test("char type comparison: nested in array of array") { + withTable("t") { + sql("CREATE TABLE t(c1 ARRAY>, c2 ARRAY>) " + + s"USING $format") + sql("INSERT INTO t VALUES (array(array('a')), array(array('a')))") + testConditions(spark.table("t"), Seq( + ("c1 = c2", true), + ("c1 < c2", false), + ("c1 IN (c2)", true))) + } + } +} + +class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSparkSession { + override def format: String = "parquet" + override protected def sparkConf: SparkConf = { + super.sparkConf.set(SQLConf.USE_V1_SOURCE_LIST, "parquet") + } +} + +class DSV2CharVarcharTestSuite extends CharVarcharTestSuite + with SharedSparkSession { + override def format: String = "foo" + protected override def sparkConf = { + super.sparkConf + .set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) + .set(SQLConf.DEFAULT_CATALOG.key, "testcat") + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index fd1978c5137a5..ee8b3efb8eb31 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.SimpleScanSource -import org.apache.spark.sql.types.{CharType, DoubleType, HIVE_TYPE_STRING, IntegerType, LongType, MetadataBuilder, StringType, StructField, StructType} +import org.apache.spark.sql.types.{CharType, DoubleType, IntegerType, LongType, MetadataBuilder, StringType, StructField, StructType} class PlanResolutionSuite extends AnalysisTest { import CatalystSqlParser._ @@ -1076,9 +1076,7 @@ class PlanResolutionSuite extends AnalysisTest { } val sql = s"ALTER TABLE v1HiveTable ALTER COLUMN i TYPE char(1)" - val builder = new MetadataBuilder - builder.putString(HIVE_TYPE_STRING, CharType(1).catalogString) - val newColumnWithCleanedType = StructField("i", StringType, true, builder.build()) + val newColumnWithCleanedType = StructField("i", CharType(1), true) 
val expected = AlterTableChangeColumnCommand( TableIdentifier("v1HiveTable", Some("default")), "i", newColumnWithCleanedType) val parsed = parseAndResolve(sql) @@ -1519,44 +1517,6 @@ class PlanResolutionSuite extends AnalysisTest { } } - test("SPARK-31147: forbid CHAR type in non-Hive tables") { - def checkFailure(t: String, provider: String): Unit = { - val types = Seq( - "CHAR(2)", - "ARRAY", - "MAP", - "MAP", - "STRUCT") - types.foreach { tpe => - intercept[AnalysisException] { - parseAndResolve(s"CREATE TABLE $t(col $tpe) USING $provider") - } - intercept[AnalysisException] { - parseAndResolve(s"REPLACE TABLE $t(col $tpe) USING $provider") - } - intercept[AnalysisException] { - parseAndResolve(s"CREATE OR REPLACE TABLE $t(col $tpe) USING $provider") - } - intercept[AnalysisException] { - parseAndResolve(s"ALTER TABLE $t ADD COLUMN col $tpe") - } - intercept[AnalysisException] { - parseAndResolve(s"ALTER TABLE $t ADD COLUMN col $tpe") - } - intercept[AnalysisException] { - parseAndResolve(s"ALTER TABLE $t ALTER COLUMN col TYPE $tpe") - } - intercept[AnalysisException] { - parseAndResolve(s"ALTER TABLE $t REPLACE COLUMNS (col $tpe)") - } - } - } - - checkFailure("v1Table", v1Format) - checkFailure("v2Table", v2Format) - checkFailure("testcat.tab", "foo") - } - // TODO: add tests for more commands. } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index 9a95bf770772e..ca3e714665818 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ @@ -127,7 +128,7 @@ class TableScanSuite extends DataSourceTest with SharedSparkSession { Date.valueOf("1970-01-01"), new Timestamp(20000 + i), s"varchar_$i", - s"char_$i", + s"char_$i".padTo(18, ' '), Seq(i, i + 1), Seq(Map(s"str_$i" -> Row(i.toLong))), Map(i -> i.toString), @@ -206,10 +207,6 @@ class TableScanSuite extends DataSourceTest with SharedSparkSession { (2 to 10).map(i => Row(i, i - 1)).toSeq) test("Schema and all fields") { - def hiveMetadata(dt: String): Metadata = { - new MetadataBuilder().putString(HIVE_TYPE_STRING, dt).build() - } - val expectedSchema = StructType( StructField("string$%Field", StringType, true) :: StructField("binaryField", BinaryType, true) :: @@ -224,8 +221,8 @@ class TableScanSuite extends DataSourceTest with SharedSparkSession { StructField("decimalField2", DecimalType(9, 2), true) :: StructField("dateField", DateType, true) :: StructField("timestampField", TimestampType, true) :: - StructField("varcharField", StringType, true, hiveMetadata("varchar(12)")) :: - StructField("charField", StringType, true, hiveMetadata("char(18)")) :: + StructField("varcharField", VarcharType(12), true) :: + StructField("charField", CharType(18), true) :: StructField("arrayFieldSimple", ArrayType(IntegerType), true) :: StructField("arrayFieldComplex", ArrayType( @@ -248,7 +245,8 @@ class TableScanSuite extends DataSourceTest with SharedSparkSession { Nil ) - assert(expectedSchema == spark.table("tableWithSchema").schema) + assert(CharVarcharUtils.replaceCharVarcharWithStringInSchema(expectedSchema) == + 
spark.table("tableWithSchema").schema) withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { checkAnswer( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index b30492802495f..da37b61688951 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -90,6 +90,7 @@ class HiveSessionStateBuilder( PreprocessTableCreation(session) +: PreprocessTableInsertion +: DataSourceAnalysis +: + ApplyCharTypePadding +: HiveAnalysis +: customPostHocResolutionRules diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 9bc99b08c2cc8..338b023d49774 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -976,19 +976,14 @@ private[hive] class HiveClientImpl( private[hive] object HiveClientImpl extends Logging { /** Converts the native StructField to Hive's FieldSchema. */ def toHiveColumn(c: StructField): FieldSchema = { - val typeString = if (c.metadata.contains(HIVE_TYPE_STRING)) { - c.metadata.getString(HIVE_TYPE_STRING) - } else { - // replace NullType to HiveVoidType since Hive parse void not null. - HiveVoidType.replaceVoidType(c.dataType).catalogString - } + val typeString = HiveVoidType.replaceVoidType(c.dataType).catalogString new FieldSchema(c.name, typeString, c.getComment().orNull) } /** Get the Spark SQL native DataType from Hive's FieldSchema. */ private def getSparkSQLDataType(hc: FieldSchema): DataType = { try { - CatalystSqlParser.parseDataType(hc.getType) + CatalystSqlParser.parseRawDataType(hc.getType) } catch { case e: ParseException => throw new SparkException( @@ -999,18 +994,10 @@ private[hive] object HiveClientImpl extends Logging { /** Builds the native StructField from Hive's FieldSchema. */ def fromHiveColumn(hc: FieldSchema): StructField = { val columnType = getSparkSQLDataType(hc) - val replacedVoidType = HiveVoidType.replaceVoidType(columnType) - val metadata = if (hc.getType != replacedVoidType.catalogString) { - new MetadataBuilder().putString(HIVE_TYPE_STRING, hc.getType).build() - } else { - Metadata.empty - } - val field = StructField( name = hc.getName, dataType = columnType, - nullable = true, - metadata = metadata) + nullable = true) Option(hc.getComment).map(field.withComment).getOrElse(field) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/HiveCharVarcharTestSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/HiveCharVarcharTestSuite.scala new file mode 100644 index 0000000000000..55d305fda4f96 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/HiveCharVarcharTestSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.hive.test.TestHiveSingleton + +class HiveCharVarcharTestSuite extends CharVarcharTestSuite with TestHiveSingleton { + + // The default Hive serde doesn't support nested null values. + override def format: String = "hive OPTIONS(fileFormat='parquet')" + + private var originalPartitionMode = "" + + override protected def beforeAll(): Unit = { + super.beforeAll() + originalPartitionMode = spark.conf.get("hive.exec.dynamic.partition.mode", "") + spark.conf.set("hive.exec.dynamic.partition.mode", "nonstrict") + } + + override protected def afterAll(): Unit = { + if (originalPartitionMode == "") { + spark.conf.unset("hive.exec.dynamic.partition.mode") + } else { + spark.conf.set("hive.exec.dynamic.partition.mode", originalPartitionMode) + } + super.afterAll() + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 8f71ba3337aa2..1a6f6843d3911 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -113,24 +113,19 @@ class HiveMetastoreCatalogSuite extends TestHiveSingleton with SQLTestUtils { .add("c9", "date") .add("c10", "timestamp") .add("c11", "string") - .add("c12", "string", true, - new MetadataBuilder().putString(HIVE_TYPE_STRING, "char(10)").build()) - .add("c13", "string", true, - new MetadataBuilder().putString(HIVE_TYPE_STRING, "varchar(10)").build()) + .add("c12", CharType(10), true) + .add("c13", VarcharType(10), true) .add("c14", "binary") .add("c15", "decimal") .add("c16", "decimal(10)") .add("c17", "decimal(10,2)") .add("c18", "array") .add("c19", "array") - .add("c20", "array", true, - new MetadataBuilder().putString(HIVE_TYPE_STRING, "array").build()) + .add("c20", ArrayType(CharType(10)), true) .add("c21", "map") - .add("c22", "map", true, - new MetadataBuilder().putString(HIVE_TYPE_STRING, "map").build()) + .add("c22", MapType(IntegerType, CharType(10)), true) .add("c23", "struct") - .add("c24", "struct", true, - new MetadataBuilder().putString(HIVE_TYPE_STRING, "struct").build()) + .add("c24", new StructType().add("c", VarcharType(10)).add("d", "int"), true) assert(schema == expectedSchema) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 56b871644453b..1b4f69d557210 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -2252,8 +2252,8 @@ class HiveDDLSuite ) sql("ALTER TABLE tab ADD COLUMNS (c5 char(10))") - assert(spark.table("tab").schema.find(_.name == "c5") - .get.metadata.getString("HIVE_TYPE_STRING") == "char(10)") + assert(spark.sharedState.externalCatalog.getTable("default", "tab") + .schema.find(_.name == "c5").get.dataType == CharType(10)) } } } From 
d1adb0ed1cbb5755bb645560eae1b28d393368e8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 19 Nov 2020 22:37:03 +0800 Subject: [PATCH 02/11] address comments --- .../sql/catalyst/analysis/CheckAnalysis.scala | 4 ++-- .../analysis/TableOutputResolver.scala | 4 ++-- .../sql/catalyst/catalog/SessionCatalog.scala | 9 +++------ .../sql/catalyst/util/CharVarcharUtils.scala | 20 +++++++++---------- .../datasources/v2/DataSourceV2Relation.scala | 2 +- .../org/apache/spark/sql/types/CharType.scala | 2 ++ .../apache/spark/sql/types/VarcharType.scala | 2 ++ .../datasources/LogicalRelation.scala | 4 ++-- 8 files changed, 24 insertions(+), 23 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d567107c9fc8d..bb03370dd4166 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -94,9 +94,9 @@ trait CheckAnalysis extends PredicateHelper { case p if p.analyzed => // Skip already analyzed sub-plans - case leaf: LeafNode if leaf.output.map(_.dataType).exists(CharVarcharUtils.hasCharVarchar) => + case p if p.output.map(_.dataType).exists(CharVarcharUtils.hasCharVarchar) => throw new IllegalStateException( - "[BUG] leaf logical plan should not have output of char/varchar type: " + leaf) + "[BUG] logical plan should not have output of char/varchar type: " + p) case u: UnresolvedNamespace => u.failAnalysis(s"Namespace not found: ${u.multipartIdentifier.quoted}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala index c6bba370c8fef..d5c407b47c5be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TableOutputResolver.scala @@ -100,11 +100,11 @@ object TableOutputResolver { case _ => Cast(queryExpr, tableAttr.dataType, Option(conf.sessionLocalTimeZone)) } - val strLenChecked = CharVarcharUtils.stringLengthCheck(casted, tableAttr) + val exprWithStrLenCheck = CharVarcharUtils.stringLengthCheck(casted, tableAttr) // Renaming is needed for handling the following cases like // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 // 2) Target tables have column metadata - Some(Alias(strLenChecked, tableAttr.name)(explicitMetadata = Some(tableAttr.metadata))) + Some(Alias(exprWithStrLenCheck, tableAttr.name)(explicitMetadata = Some(tableAttr.metadata))) } storeAssignmentPolicy match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index ec8daf3eb46d2..a79c26a985982 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -473,12 +473,9 @@ class SessionCatalog( val table = formatTableName(name.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Some(db))) - removeCharVarcharFromTableSchema(externalCatalog.getTable(db, table)) - } - - // We replace char/varchar with string type in the table schema, as Spark's type system doesn't - // 
support char/varchar yet. - private def removeCharVarcharFromTableSchema(t: CatalogTable): CatalogTable = { + val t = externalCatalog.getTable(db, table) + // We replace char/varchar with "annotated" string type in the table schema, as the query + // engine doesn't support char/varchar yet. t.copy(schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(t.schema)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala index 6b867b36c62d4..eca6339073992 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala @@ -151,26 +151,26 @@ object CharVarcharUtils { }.getOrElse(expr) } + private def raiseError(expr: Expression, typeName: String, length: Int): Expression = { + val errorMsg = Concat(Seq( + Literal("input string '"), + expr, + Literal(s"' exceeds $typeName type length limitation: $length"))) + Cast(RaiseError(errorMsg), StringType) + } + private def stringLengthCheck(expr: Expression, dt: DataType): Expression = dt match { case CharType(length) => val trimmed = StringTrimRight(expr) - val errorMsg = Concat(Seq( - Literal("input string '"), - expr, - Literal(s"' exceeds char type length limitation: $length"))) // Trailing spaces do not count in the length check. We don't need to retain the trailing // spaces, as we will pad char type columns/fields at read time. If( GreaterThan(Length(trimmed), Literal(length)), - Cast(RaiseError(errorMsg), StringType), + raiseError(expr, "char", length), trimmed) case VarcharType(length) => val trimmed = StringTrimRight(expr) - val errorMsg = Concat(Seq( - Literal("input string '"), - expr, - Literal(s"' exceeds varchar type length limitation: $length"))) // Trailing spaces do not count in the length check. We need to retain the trailing spaces // (truncate to length N), as there is no read-time padding for varchar type. // TODO: create a special TrimRight function that can trim to a certain length. @@ -179,7 +179,7 @@ object CharVarcharUtils { expr, If( GreaterThan(Length(trimmed), Literal(length)), - Cast(RaiseError(errorMsg), StringType), + raiseError(expr, "varchar", length), StringRPad(trimmed, Literal(length)))) case StructType(fields) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index a31b1cc924fc3..4debdd380e6b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -172,7 +172,7 @@ object DataSourceV2Relation { identifier: Option[Identifier], options: CaseInsensitiveStringMap): DataSourceV2Relation = { // The v2 source may return schema containing char/varchar type. We replace char/varchar - // with string type here as Spark's type system doesn't support char/varchar yet. + // with "annotated" string type here as the query engine doesn't support char/varchar yet. 
val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(table.schema) DataSourceV2Relation(table, schema.toAttributes, catalog, identifier, options) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala index dce4bfaa4fab5..b329b5a964c87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala @@ -25,6 +25,8 @@ import org.apache.spark.unsafe.types.UTF8String @Experimental case class CharType(length: Int) extends AtomicType { + require(length >= 0, "The length of char type cannot be negative.") + private[sql] type InternalType = UTF8String @transient private[sql] lazy val tag = typeTag[InternalType] private[sql] val ordering = implicitly[Ordering[InternalType]] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala index 14454550dd981..dd52b76ee2783 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala @@ -24,6 +24,8 @@ import org.apache.spark.unsafe.types.UTF8String @Experimental case class VarcharType(length: Int) extends AtomicType { + require(length >= 0, "The length of varchar type cannot be negative.") + private[sql] type InternalType = UTF8String @transient private[sql] lazy val tag = typeTag[InternalType] private[sql] val ordering = implicitly[Ordering[InternalType]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 0c6a80d441686..8c61c8cd4f52e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -71,14 +71,14 @@ case class LogicalRelation( object LogicalRelation { def apply(relation: BaseRelation, isStreaming: Boolean = false): LogicalRelation = { // The v1 source may return schema containing char/varchar type. We replace char/varchar - // with string type here as Spark's type system doesn't support char/varchar yet. + // with "annotated" string type here as the query engine doesn't support char/varchar yet. val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(relation.schema) LogicalRelation(relation, schema.toAttributes, None, isStreaming) } def apply(relation: BaseRelation, table: CatalogTable): LogicalRelation = { // The v1 source may return schema containing char/varchar type. We replace char/varchar - // with string type here as Spark's type system doesn't support char/varchar yet. + // with "annotated" string type here as the query engine doesn't support char/varchar yet.
val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(relation.schema) LogicalRelation(relation, schema.toAttributes, Some(table), false) } From 090fda2c71b26b042284ea7ce3b5ec03ded74f9a Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 20 Nov 2020 02:04:17 +0800 Subject: [PATCH 03/11] fix test --- .../spark/sql/execution/datasources/v2/PushDownUtils.scala | 4 +++- .../spark/sql/execution/command/PlanResolutionSuite.scala | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala index ce8edce6f08d6..2208e930f6b08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownUtils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, NamedExpression, PredicateHelper, SchemaPruning} +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.read.{Scan, ScanBuilder, SupportsPushDownFilters, SupportsPushDownRequiredColumns} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.internal.SQLConf @@ -110,7 +111,8 @@ object PushDownUtils extends PredicateHelper { schema: StructType, relation: DataSourceV2Relation): Seq[AttributeReference] = { val nameToAttr = relation.output.map(_.name).zip(relation.output).toMap - schema.toAttributes.map { + val cleaned = CharVarcharUtils.replaceCharVarcharWithString(schema).asInstanceOf[StructType] + cleaned.toAttributes.map { // we have to keep the attribute id during transformation a => a.withExprId(nameToAttr(a.name).exprId) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index ee8b3efb8eb31..810942a1371c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.SimpleScanSource -import org.apache.spark.sql.types.{CharType, DoubleType, IntegerType, LongType, MetadataBuilder, StringType, StructField, StructType} +import org.apache.spark.sql.types.{CharType, DoubleType, IntegerType, LongType, StringType, StructField, StructType} class PlanResolutionSuite extends AnalysisTest { import CatalystSqlParser._ From f46e32fb1649023eed0ddab4cb23ca4a97b14a0f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 24 Nov 2020 01:42:53 +0800 Subject: [PATCH 04/11] update --- docs/sql-ref-datatypes.md | 2 ++ .../spark/sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../spark/sql/catalyst/util/CharVarcharUtils.scala | 10 ---------- .../scala/org/apache/spark/sql/types/DataType.scala | 4 +++- .../spark/sql/catalyst/analysis/AnalysisSuite.scala | 2 +- .../org/apache/spark/sql/CharVarcharTestSuite.scala | 11 ++++------- 6 files changed, 11 insertions(+), 20 deletions(-) diff --git a/docs/sql-ref-datatypes.md 
b/docs/sql-ref-datatypes.md index f27f1a0ca967f..322b223d1de4f 100644 --- a/docs/sql-ref-datatypes.md +++ b/docs/sql-ref-datatypes.md @@ -37,6 +37,8 @@ Spark SQL and DataFrames support the following data types: - `DecimalType`: Represents arbitrary-precision signed decimal numbers. Backed internally by `java.math.BigDecimal`. A `BigDecimal` consists of an arbitrary precision integer unscaled value and a 32-bit integer scale. * String type - `StringType`: Represents character string values. + - `VarcharType(length)`: A variant of `StringType` which has a length limitation. Data writing will fail if the input string exceeds the length limitation. Note: this type can only be used in table schema, not functions/operators. + - `CharType(length)`: A variant of `VarcharType(length)` which is fixed length. Data writing will pad the input string if its length is smaller than the char type length. Char type comparison will pad the short one to the longer length. * Binary type - `BinaryType`: Represents byte sequence values. * Boolean type diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index bb03370dd4166..0873eac0c1b1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -94,7 +94,7 @@ trait CheckAnalysis extends PredicateHelper { case p if p.analyzed => // Skip already analyzed sub-plans - case p if p.output.map(_.dataType).exists(CharVarcharUtils.hasCharVarchar) => + case p if p.resolved && p.output.map(_.dataType).exists(CharVarcharUtils.hasCharVarchar) => throw new IllegalStateException( "[BUG] logical plan should not have output of char/varchar type: " + p) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala index eca6339073992..00a0689243617 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala @@ -92,16 +92,6 @@ object CharVarcharUtils { } } - /** - * Re-construct the original StructType from the type strings in the metadata of StructFields. - * This is needed when dealing with char/varchar columns/fields. - */ - def getRawSchema(schema: StructType): StructType = { - StructType(schema.map { field => - getRawType(field.metadata).map(rawType => field.copy(dataType = rawType)).getOrElse(field) - }) - } - /** * Returns expressions to apply read-side char type padding for the given attributes. String * values should be right-padded to N characters if it's from a CHAR(N) column/field. 
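For context, a minimal end-to-end sketch of the behavior described by the doc entries and the read-side padding utility above (illustrative only, not part of the patch; it assumes an active SparkSession named `spark`, the built-in parquet source, and made-up table/column names):

    // CHAR(5) values are right-padded to length 5 on read; VARCHAR(5) values are not padded,
    // but writes longer than 5 characters are rejected by the write-side length check.
    spark.sql("CREATE TABLE char_demo(c CHAR(5), v VARCHAR(5)) USING parquet")
    spark.sql("INSERT INTO char_demo VALUES ('ab', 'ab')")
    spark.sql("SELECT c, length(c), v, length(v) FROM char_demo").show()
    // expected row: 'ab   ', 5, 'ab', 2
    // expected to fail with: input string 'abcdef' exceeds varchar type length limitation: 5
    spark.sql("INSERT INTO char_demo VALUES ('ab', 'abcdef')")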
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 6b871c9783471..5f6ebb2f20814 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.catalyst.util.DataTypeJsonUtils.{DataTypeJsonDeserializer, DataTypeJsonSerializer} import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat import org.apache.spark.sql.internal.SQLConf @@ -132,7 +133,8 @@ object DataType { ddl, CatalystSqlParser.parseDataType, "Cannot parse the data type: ", - fallbackParser = CatalystSqlParser.parseTableSchema) + fallbackParser = str => CharVarcharUtils.replaceCharVarcharWithString( + CatalystSqlParser.parseTableSchema(str))) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 6820d5d189537..592b89dee367e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -58,7 +58,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { } } - test("fail for leaf node with char/varchar type") { + test("fail if a plan node has char/varchar type output") { val schema1 = new StructType().add("c", CharType(5)) val schema2 = new StructType().add("c", VarcharType(5)) val schema3 = new StructType().add("c", ArrayType(CharType(5))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index e192a63956232..97e66b2dfe2bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -35,13 +35,10 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { } test("char type values should be padded: partitioned columns") { - // DS V2 doesn't support partitioned table. 
- if (!conf.contains(SQLConf.DEFAULT_CATALOG.key)) { - withTable("t") { - sql(s"CREATE TABLE t(i STRING, c CHAR(5)) USING $format PARTITIONED BY (c)") - sql("INSERT INTO t VALUES ('1', 'a')") - checkAnswer(spark.table("t"), Row("1", "a" + " " * 4)) - } + withTable("t") { + sql(s"CREATE TABLE t(i STRING, c CHAR(5)) USING $format PARTITIONED BY (c)") + sql("INSERT INTO t VALUES ('1', 'a')") + checkAnswer(spark.table("t"), Row("1", "a" + " " * 4)) } } From e5fb41aa2aa220b0750c6d6812e34075a9c118b2 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 24 Nov 2020 19:17:37 +0800 Subject: [PATCH 05/11] remove dead code --- .../org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala index 00a0689243617..d6fb6d1cc40b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala @@ -222,8 +222,6 @@ object CharVarcharUtils { }) case (ArrayType(et1, _), ArrayType(et2, _)) => ArrayType(typeWithWiderCharLength(et1, et2)) - case (MapType(kt1, vt1, _), MapType(kt2, vt2, _)) => - MapType(typeWithWiderCharLength(kt1, kt2), typeWithWiderCharLength(vt1, vt2)) case _ => NullType } } From b6d74c481ecb651286685490b1beded99f0d50f9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 24 Nov 2020 21:33:13 +0800 Subject: [PATCH 06/11] fix scala 2.13 --- .../org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala index d6fb6d1cc40b6..7ff01b831419c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala @@ -248,7 +248,7 @@ object CharVarcharUtils { createStructExprs += padded.getOrElse(fieldExpr) i += 1 } - if (needPadding) Some(CreateNamedStruct(createStructExprs)) else None + if (needPadding) Some(CreateNamedStruct(createStructExprs.toSeq)) else None case (ArrayType(et, containsNull), ArrayType(target, _)) => val param = NamedLambdaVariable("x", replaceCharVarcharWithString(et), containsNull) From 38999b535e78817d2647d186605618438f438220 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 24 Nov 2020 23:06:32 +0800 Subject: [PATCH 07/11] fix --- .../apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 0873eac0c1b1e..b1a06a3c855e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -94,9 +94,9 @@ trait CheckAnalysis extends PredicateHelper { case p if p.analyzed => // Skip already analyzed sub-plans - case p if p.resolved && p.output.map(_.dataType).exists(CharVarcharUtils.hasCharVarchar) => + case leaf: LeafNode if leaf.output.map(_.dataType).exists(CharVarcharUtils.hasCharVarchar) => throw new 
IllegalStateException( - "[BUG] logical plan should not have output of char/varchar type: " + p) + "[BUG] logical plan should not have output of char/varchar type: " + leaf) case u: UnresolvedNamespace => u.failAnalysis(s"Namespace not found: ${u.multipartIdentifier.quoted}") From 671471f9dad6a3cffa4c5442daba32424d9d3c13 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 26 Nov 2020 02:11:10 +0800 Subject: [PATCH 08/11] add more tests and fix bugs --- docs/sql-ref-datatypes.md | 2 +- .../analysis/ResolvePartitionSpec.scala | 4 +- .../sql/catalyst/util/CharVarcharUtils.scala | 14 +- .../org/apache/spark/sql/types/DataType.scala | 2 +- .../sql/catalyst/analysis/AnalysisSuite.scala | 2 +- .../spark/sql/CharVarcharTestSuite.scala | 130 +++++++++++++++++- 6 files changed, 146 insertions(+), 8 deletions(-) diff --git a/docs/sql-ref-datatypes.md b/docs/sql-ref-datatypes.md index 322b223d1de4f..fa829623545e1 100644 --- a/docs/sql-ref-datatypes.md +++ b/docs/sql-ref-datatypes.md @@ -38,7 +38,7 @@ Spark SQL and DataFrames support the following data types: * String type - `StringType`: Represents character string values. - `VarcharType(length)`: A variant of `StringType` which has a length limitation. Data writing will fail if the input string exceeds the length limitation. Note: this type can only be used in table schema, not functions/operators. - - `CharType(length)`: A variant of `VarcharType(length)` which is fixed length. Data writing will pad the input string if its length is smaller than the char type length. Char type comparison will pad the short one to the longer length. + - `CharType(length)`: A variant of `VarcharType(length)` which is fixed length. Reading column of type `VarcharType(n)` always returns string values of length `n`. Char type column comparison will pad the short one to the longer length. * Binary type - `BinaryType`: Represents byte sequence values. 
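As an illustrative aside (not part of the patch), the CHAR/VARCHAR semantics described in the data types doc above amount to behavior like the following sketch; the table name t and the parquet format are assumptions chosen only for the example:
// Hypothetical spark-shell session illustrating the documented semantics.
spark.sql("CREATE TABLE t(c CHAR(5), v VARCHAR(5)) USING parquet")
spark.sql("INSERT INTO t VALUES ('ab', 'ab')")
spark.table("t").show()                            // expected: c = 'ab   ' (padded to length 5), v = 'ab'
spark.sql("SELECT c = 'ab' FROM t").show()         // expected: true -- the shorter side is padded for the comparison
spark.sql("INSERT INTO t VALUES ('abcdef', 'x')")  // expected to fail: 'abcdef' exceeds the declared CHAR(5) length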
* Boolean type diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala index 6d061fce06919..98c6872a47cc6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolvePartitionSpec.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} import org.apache.spark.sql.catalyst.plans.logical.{AlterTableAddPartition, AlterTableDropPartition, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.connector.catalog.SupportsPartitionManagement import org.apache.spark.sql.types._ import org.apache.spark.sql.util.PartitioningUtils.normalizePartitionSpec @@ -66,7 +67,8 @@ object ResolvePartitionSpec extends Rule[LogicalPlan] { val partValues = partSchema.map { part => val raw = normalizedSpec.get(part.name).orNull - Cast(Literal.create(raw, StringType), part.dataType, Some(conf.sessionLocalTimeZone)).eval() + val dt = CharVarcharUtils.replaceCharVarcharWithString(part.dataType) + Cast(Literal.create(raw, StringType), dt, Some(conf.sessionLocalTimeZone)).eval() } InternalRow.fromSeq(partValues) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala index 7ff01b831419c..e8b09fd1247d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala @@ -110,9 +110,14 @@ object CharVarcharUtils { case CharType(length) => StringRPad(expr, Literal(length)) case StructType(fields) => - CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) => + val struct = CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) => Seq(Literal(f.name), charTypePadding(GetStructField(expr, i, Some(f.name)), f.dataType)) }) + if (expr.nullable) { + If(IsNull(expr), Literal(null, struct.dataType), struct) + } else { + struct + } case ArrayType(et, containsNull) => charTypePaddingInArray(expr, et, containsNull) @@ -173,9 +178,14 @@ object CharVarcharUtils { StringRPad(trimmed, Literal(length)))) case StructType(fields) => - CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) => + val struct = CreateNamedStruct(fields.zipWithIndex.flatMap { case (f, i) => Seq(Literal(f.name), stringLengthCheck(GetStructField(expr, i, Some(f.name)), f.dataType)) }) + if (expr.nullable) { + If(IsNull(expr), Literal(null, struct.dataType), struct) + } else { + struct + } case ArrayType(et, containsNull) => stringLengthCheckInArray(expr, et, containsNull) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 5f6ebb2f20814..73022de572747 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -133,7 +133,7 @@ object DataType { ddl, CatalystSqlParser.parseDataType, "Cannot parse the data type: ", - fallbackParser = str => CharVarcharUtils.replaceCharVarcharWithString( + fallbackParser = str => 
CharVarcharUtils.replaceCharVarcharWithStringInSchema( CatalystSqlParser.parseTableSchema(str))) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 592b89dee367e..0afa811e5d590 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -58,7 +58,7 @@ class AnalysisSuite extends AnalysisTest with Matchers { } } - test("fail if a plan node has char/varchar type output") { + test("fail if a leaf node has char/varchar type output") { val schema1 = new StructType().add("c", CharType(5)) val schema2 = new StructType().add("c", VarcharType(5)) val schema3 = new StructType().add("c", ArrayType(CharType(5))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 97e66b2dfe2bc..d5100c237f732 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -18,19 +18,34 @@ package org.apache.spark.sql import org.apache.spark.{SparkConf, SparkException} -import org.apache.spark.sql.connector.InMemoryTableCatalog +import org.apache.spark.sql.catalyst.util.CharVarcharUtils +import org.apache.spark.sql.connector.{InMemoryPartitionTableCatalog, SchemaRequiredDataSource} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.SimpleInsertSource import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils} +import org.apache.spark.sql.types.{ArrayType, CharType, DataType, MapType, StringType, StructField, StructType} +// The base trait for char/varchar tests that need to be run with different table implementations. 
trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { def format: String + def checkColType(f: StructField, dt: DataType): Unit = { + assert(f.dataType == CharVarcharUtils.replaceCharVarcharWithString(dt)) + assert(CharVarcharUtils.getRawType(f.metadata) == Some(dt)) + } + test("char type values should be padded: top-level columns") { withTable("t") { sql(s"CREATE TABLE t(i STRING, c CHAR(5)) USING $format") sql("INSERT INTO t VALUES ('1', 'a')") checkAnswer(spark.table("t"), Row("1", "a" + " " * 4)) + checkColType(spark.table("t").schema(1), CharType(5)) + + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) } } @@ -39,6 +54,11 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql(s"CREATE TABLE t(i STRING, c CHAR(5)) USING $format PARTITIONED BY (c)") sql("INSERT INTO t VALUES ('1', 'a')") checkAnswer(spark.table("t"), Row("1", "a" + " " * 4)) + checkColType(spark.table("t").schema(1), CharType(5)) + + sql("ALTER TABLE t DROP PARTITION(c='a')") + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) } } @@ -47,6 +67,12 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql(s"CREATE TABLE t(i STRING, c STRUCT<c: CHAR(5)>) USING $format") sql("INSERT INTO t VALUES ('1', struct('a'))") checkAnswer(spark.table("t"), Row("1", Row("a" + " " * 4))) + checkColType(spark.table("t").schema(1), new StructType().add("c", CharType(5))) + + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) + sql("INSERT OVERWRITE t VALUES ('1', struct(null))") + checkAnswer(spark.table("t"), Row("1", Row(null))) } } @@ -55,6 +81,12 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql(s"CREATE TABLE t(i STRING, c ARRAY<CHAR(5)>) USING $format") sql("INSERT INTO t VALUES ('1', array('a', 'ab'))") checkAnswer(spark.table("t"), Row("1", Seq("a" + " " * 4, "ab" + " " * 3))) + checkColType(spark.table("t").schema(1), ArrayType(CharType(5))) + + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) + sql("INSERT OVERWRITE t VALUES ('1', array(null))") + checkAnswer(spark.table("t"), Row("1", Seq(null))) } } @@ -63,6 +95,10 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql(s"CREATE TABLE t(i STRING, c MAP<CHAR(5), STRING>) USING $format") sql("INSERT INTO t VALUES ('1', map('a', 'ab'))") checkAnswer(spark.table("t"), Row("1", Map(("a" + " " * 4, "ab")))) + checkColType(spark.table("t").schema(1), MapType(CharType(5), StringType)) + + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) } } @@ -71,6 +107,12 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql(s"CREATE TABLE t(i STRING, c MAP<STRING, CHAR(5)>) USING $format") sql("INSERT INTO t VALUES ('1', map('a', 'ab'))") checkAnswer(spark.table("t"), Row("1", Map(("a", "ab" + " " * 3)))) + checkColType(spark.table("t").schema(1), MapType(StringType, CharType(5))) + + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) + sql("INSERT OVERWRITE t VALUES ('1', map('a', null))") + checkAnswer(spark.table("t"), Row("1", Map("a" -> null))) } } @@ -79,6 +121,10 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql(s"CREATE TABLE t(i STRING, c MAP<CHAR(5), CHAR(10)>) USING $format") sql("INSERT INTO t VALUES ('1', map('a', 'ab'))") checkAnswer(spark.table("t"), Row("1", Map(("a" + " " * 4, "ab" + " " * 8)))) + checkColType(spark.table("t").schema(1), MapType(CharType(5), CharType(10))) + +
sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) } } @@ -87,6 +133,15 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql(s"CREATE TABLE t(i STRING, c STRUCT>) USING $format") sql("INSERT INTO t VALUES ('1', struct(array('a', 'ab')))") checkAnswer(spark.table("t"), Row("1", Row(Seq("a" + " " * 4, "ab" + " " * 3)))) + checkColType(spark.table("t").schema(1), + new StructType().add("c", ArrayType(CharType(5)))) + + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) + sql("INSERT OVERWRITE t VALUES ('1', struct(null))") + checkAnswer(spark.table("t"), Row("1", Row(null))) + sql("INSERT OVERWRITE t VALUES ('1', struct(array(null)))") + checkAnswer(spark.table("t"), Row("1", Row(Seq(null)))) } } @@ -95,6 +150,15 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql(s"CREATE TABLE t(i STRING, c ARRAY>) USING $format") sql("INSERT INTO t VALUES ('1', array(struct('a'), struct('ab')))") checkAnswer(spark.table("t"), Row("1", Seq(Row("a" + " " * 4), Row("ab" + " " * 3)))) + checkColType(spark.table("t").schema(1), + ArrayType(new StructType().add("c", CharType(5)))) + + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) + sql("INSERT OVERWRITE t VALUES ('1', array(null))") + checkAnswer(spark.table("t"), Row("1", Seq(null))) + sql("INSERT OVERWRITE t VALUES ('1', array(struct(null)))") + checkAnswer(spark.table("t"), Row("1", Seq(Row(null)))) } } @@ -103,6 +167,14 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { sql(s"CREATE TABLE t(i STRING, c ARRAY>) USING $format") sql("INSERT INTO t VALUES ('1', array(array('a', 'ab')))") checkAnswer(spark.table("t"), Row("1", Seq(Seq("a" + " " * 4, "ab" + " " * 3)))) + checkColType(spark.table("t").schema(1), ArrayType(ArrayType(CharType(5)))) + + sql("INSERT OVERWRITE t VALUES ('1', null)") + checkAnswer(spark.table("t"), Row("1", null)) + sql("INSERT OVERWRITE t VALUES ('1', array(null))") + checkAnswer(spark.table("t"), Row("1", Seq(null))) + sql("INSERT OVERWRITE t VALUES ('1', array(array(null)))") + checkAnswer(spark.table("t"), Row("1", Seq(Seq(null)))) } } @@ -353,6 +425,60 @@ trait CharVarcharTestSuite extends QueryTest with SQLTestUtils { } } +// Some basic char/varchar tests which doesn't rely on table implementation. 
+class BasicCharVarcharTestSuite extends QueryTest with SharedSparkSession { + import testImplicits._ + + test("user-specified schema in cast") { + def assertNoCharType(df: DataFrame): Unit = { + checkAnswer(df, Row("0")) + assert(df.schema.map(_.dataType) == Seq(StringType)) + } + + assertNoCharType(spark.range(1).select($"id".cast("char(5)"))) + assertNoCharType(spark.range(1).select($"id".cast(CharType(5)))) + assertNoCharType(spark.range(1).selectExpr("CAST(id AS CHAR(5))")) + assertNoCharType(sql("SELECT CAST(id AS CHAR(5)) FROM range(1)")) + } + + test("user-specified schema in functions") { + val df = sql("""SELECT from_json('{"a": "str"}', 'a CHAR(5)')""") + checkAnswer(df, Row(Row("str"))) + val schema = df.schema.head.dataType.asInstanceOf[StructType] + assert(schema.map(_.dataType) == Seq(StringType)) + } + + test("user-specified schema in DataFrameReader: DSV1") { + def checkSchema(df: DataFrame): Unit = { + val relations = df.queryExecution.analyzed.collect { + case l: LogicalRelation => l.relation + } + assert(relations.length == 1) + assert(relations.head.schema.map(_.dataType) == Seq(StringType)) + } + + checkSchema(spark.read.schema(new StructType().add("id", CharType(5))) + .format(classOf[SimpleInsertSource].getName).load()) + checkSchema(spark.read.schema("id char(5)") + .format(classOf[SimpleInsertSource].getName).load()) + } + + test("user-specified schema in DataFrameReader: DSV2") { + def checkSchema(df: DataFrame): Unit = { + val tables = df.queryExecution.analyzed.collect { + case d: DataSourceV2Relation => d.table + } + assert(tables.length == 1) + assert(tables.head.schema.map(_.dataType) == Seq(StringType)) + } + + checkSchema(spark.read.schema(new StructType().add("id", CharType(5))) + .format(classOf[SchemaRequiredDataSource].getName).load()) + checkSchema(spark.read.schema("id char(5)") + .format(classOf[SchemaRequiredDataSource].getName).load()) + } +} + class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSparkSession { override def format: String = "parquet" override protected def sparkConf: SparkConf = { @@ -365,7 +491,7 @@ class DSV2CharVarcharTestSuite extends CharVarcharTestSuite override def format: String = "foo" protected override def sparkConf = { super.sparkConf - .set("spark.sql.catalog.testcat", classOf[InMemoryTableCatalog].getName) + .set("spark.sql.catalog.testcat", classOf[InMemoryPartitionTableCatalog].getName) .set(SQLConf.DEFAULT_CATALOG.key, "testcat") } } From 69adca50f12fde7deefc7bef2e31d459526cac3c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 27 Nov 2020 14:43:27 +0800 Subject: [PATCH 09/11] more fixes --- docs/sql-ref-datatypes.md | 2 +- .../spark/sql/catalyst/parser/AstBuilder.scala | 4 +++- .../spark/sql/catalyst/util/CharVarcharUtils.scala | 5 +++-- .../org/apache/spark/sql/types/CharType.scala | 2 +- .../org/apache/spark/sql/types/DataType.scala | 4 +--- .../org/apache/spark/sql/types/VarcharType.scala | 2 +- .../org/apache/spark/sql/DataFrameReader.scala | 14 ++++---------- .../spark/sql/streaming/DataStreamReader.scala | 10 +++------- .../apache/spark/sql/CharVarcharTestSuite.scala | 8 ++++++++ 9 files changed, 25 insertions(+), 26 deletions(-) diff --git a/docs/sql-ref-datatypes.md b/docs/sql-ref-datatypes.md index fa829623545e1..0087867a8c7f7 100644 --- a/docs/sql-ref-datatypes.md +++ b/docs/sql-ref-datatypes.md @@ -38,7 +38,7 @@ Spark SQL and DataFrames support the following data types: * String type - `StringType`: Represents character string values. 
- `VarcharType(length)`: A variant of `StringType` which has a length limitation. Data writing will fail if the input string exceeds the length limitation. Note: this type can only be used in table schema, not functions/operators. - - `CharType(length)`: A variant of `VarcharType(length)` which is fixed length. Reading column of type `VarcharType(n)` always returns string values of length `n`. Char type column comparison will pad the short one to the longer length. + - `CharType(length)`: A variant of `VarcharType(length)` which is fixed length. Reading column of type `CharType(n)` always returns string values of length `n`. Char type column comparison will pad the short one to the longer length. * Binary type - `BinaryType`: Represents byte sequence values. * Boolean type diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 3d428b39a6db0..d173756a45f32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -99,7 +99,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with SQLConfHelper with Logg } override def visitSingleTableSchema(ctx: SingleTableSchemaContext): StructType = { - withOrigin(ctx)(StructType(visitColTypeList(ctx.colTypeList))) + val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema( + StructType(visitColTypeList(ctx.colTypeList))) + withOrigin(ctx)(schema) } def parseRawDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala index e8b09fd1247d2..0cbe5abdbbd7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/CharVarcharUtils.scala @@ -137,8 +137,9 @@ object CharVarcharUtils { } /** - * Returns an expression to apply write-side char type padding for the given expression. A string - * value can not exceed N characters if it's written into a CHAR(N)/VARCHAR(N) column/field. + * Returns an expression to apply write-side string length check for the given expression. A + * string value can not exceed N characters if it's written into a CHAR(N)/VARCHAR(N) + * column/field. 
*/ def stringLengthCheck(expr: Expression, targetAttr: Attribute): Expression = { getRawType(targetAttr.metadata).map { rawType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala index b329b5a964c87..67ab1cc2f3321 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/CharType.scala @@ -25,7 +25,7 @@ import org.apache.spark.unsafe.types.UTF8String @Experimental case class CharType(length: Int) extends AtomicType { - require(length >= 0, "The length if char type cannot be negative.") + require(length >= 0, "The length of char type cannot be negative.") private[sql] type InternalType = UTF8String @transient private[sql] lazy val tag = typeTag[InternalType] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 73022de572747..e4ee6eb377a4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.catalyst.util.DataTypeJsonUtils.{DataTypeJsonDeserializer, DataTypeJsonSerializer} import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat import org.apache.spark.sql.internal.SQLConf @@ -133,8 +132,7 @@ object DataType { ddl, CatalystSqlParser.parseDataType, "Cannot parse the data type: ", - fallbackParser = str => CharVarcharUtils.replaceCharVarcharWithStringInSchema( - CatalystSqlParser.parseTableSchema(str))) + fallbackParser = str => CatalystSqlParser.parseTableSchema(str)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala index dd52b76ee2783..8d78640c1e125 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/VarcharType.scala @@ -24,7 +24,7 @@ import org.apache.spark.unsafe.types.UTF8String @Experimental case class VarcharType(length: Int) extends AtomicType { - require(length >= 0, "The length if varchar type cannot be negative.") + require(length >= 0, "The length of varchar type cannot be negative.") private[sql] type InternalType = UTF8String @transient private[sql] lazy val tag = typeTag[InternalType] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 3b5532ccb910f..49b3335bf1769 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -73,7 +73,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { * @since 1.4.0 */ def schema(schema: StructType): DataFrameReader = { - this.userSpecifiedSchema = Option(schema) + this.userSpecifiedSchema = Option(CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema)) this } @@ -274,14 +274,11 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { 
extraOptions + ("paths" -> objectMapper.writeValueAsString(paths.toArray)) } - val cleanedUserSpecifiedSchema = userSpecifiedSchema - .map(CharVarcharUtils.replaceCharVarcharWithStringInSchema) - val finalOptions = sessionOptions.filterKeys(!optionsWithPath.contains(_)).toMap ++ optionsWithPath.originalMap val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) val (table, catalog, ident) = provider match { - case _: SupportsCatalogOptions if cleanedUserSpecifiedSchema.nonEmpty => + case _: SupportsCatalogOptions if userSpecifiedSchema.nonEmpty => throw new IllegalArgumentException( s"$source does not support user specified schema. Please don't specify the schema.") case hasCatalog: SupportsCatalogOptions => @@ -293,8 +290,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { (catalog.loadTable(ident), Some(catalog), Some(ident)) case _ => // TODO: Non-catalog paths for DSV2 are currently not well defined. - val tbl = DataSourceV2Utils.getTableFromProvider( - provider, dsOptions, cleanedUserSpecifiedSchema) + val tbl = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema) (tbl, None, None) } import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ @@ -316,15 +312,13 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } else { (paths, extraOptions) } - val cleanedUserSpecifiedSchema = userSpecifiedSchema - .map(CharVarcharUtils.replaceCharVarcharWithStringInSchema) // Code path for data source v1. sparkSession.baseRelationToDataFrame( DataSource.apply( sparkSession, paths = finalPaths, - userSpecifiedSchema = cleanedUserSpecifiedSchema, + userSpecifiedSchema = userSpecifiedSchema, className = source, options = finalOptions.originalMap).resolveRelation()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index b9a1a465d9e52..4e755682242d9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -64,7 +64,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo * @since 2.0.0 */ def schema(schema: StructType): DataStreamReader = { - this.userSpecifiedSchema = Option(schema) + this.userSpecifiedSchema = Option(CharVarcharUtils.replaceCharVarcharWithStringInSchema(schema)) this } @@ -203,9 +203,6 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo extraOptions + ("path" -> path.get) } - val cleanedUserSpecifiedSchema = userSpecifiedSchema - .map(CharVarcharUtils.replaceCharVarcharWithStringInSchema) - val ds = DataSource.lookupDataSource(source, sparkSession.sqlContext.conf). getConstructor().newInstance() // We need to generate the V1 data source so we can pass it to the V2 relation as a shim. @@ -213,7 +210,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo // writer or whether the query is continuous. 
val v1DataSource = DataSource( sparkSession, - userSpecifiedSchema = cleanedUserSpecifiedSchema, + userSpecifiedSchema = userSpecifiedSchema, className = source, options = optionsWithPath.originalMap) val v1Relation = ds match { @@ -228,8 +225,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo val finalOptions = sessionOptions.filterKeys(!optionsWithPath.contains(_)).toMap ++ optionsWithPath.originalMap val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) - val table = DataSourceV2Utils.getTableFromProvider( - provider, dsOptions, cleanedUserSpecifiedSchema) + val table = DataSourceV2Utils.getTableFromProvider(provider, dsOptions, userSpecifiedSchema) import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { case _: SupportsRead if table.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index d5100c237f732..abb13270d20e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -448,6 +448,14 @@ class BasicCharVarcharTestSuite extends QueryTest with SharedSparkSession { assert(schema.map(_.dataType) == Seq(StringType)) } + test("user-specified schema in DataFrameReader: file source from Dataset") { + val ds = spark.range(10).map(_.toString) + val df1 = spark.read.schema(new StructType().add("id", CharType(5))).csv(ds) + assert(df1.schema.map(_.dataType) == Seq(StringType)) + val df2 = spark.read.schema("id char(5)").csv(ds) + assert(df2.schema.map(_.dataType) == Seq(StringType)) + } + test("user-specified schema in DataFrameReader: DSV1") { def checkSchema(df: DataFrame): Unit = { val relations = df.queryExecution.analyzed.collect { From 3bbe7e784b0544dbd05fbe113ade387d080ba990 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 27 Nov 2020 16:32:13 +0800 Subject: [PATCH 10/11] fix test --- .../spark/sql/catalyst/parser/TableSchemaParserSuite.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala index 5519f016e48d3..95851d44b4747 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableSchemaParserSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.types._ class TableSchemaParserSuite extends SparkFunSuite { @@ -68,7 +69,8 @@ class TableSchemaParserSuite extends SparkFunSuite { StructField("arrAy", ArrayType(DoubleType)) :: StructField("anotherArray", ArrayType(CharType(9))) :: Nil)) :: Nil) - assert(parse(tableSchemaString) === expectedDataType) + assert(parse(tableSchemaString) === + CharVarcharUtils.replaceCharVarcharWithStringInSchema(expectedDataType)) } // Negative cases From 73b99dc7a98f5d0673d309adba457cd19144be92 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 27 Nov 2020 18:44:33 +0800 Subject: [PATCH 11/11] fix JDBC --- .../spark/sql/execution/datasources/jdbc/JdbcUtils.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 6733aab947be6..5dd0d2bd74838 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.execution.datasources.jdbc.connection.ConnectionProvider import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} @@ -761,7 +761,10 @@ object JdbcUtils extends Logging { schema: StructType, caseSensitive: Boolean, createTableColumnTypes: String): Map[String, String] = { - val userSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes) + val parsedSchema = CatalystSqlParser.parseTableSchema(createTableColumnTypes) + val userSchema = StructType(parsedSchema.map { field => + field.copy(dataType = CharVarcharUtils.getRawType(field.metadata).getOrElse(field.dataType)) + }) val nameEquality = if (caseSensitive) { org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution } else {
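For context, an illustrative sketch (not part of the patch) of the JDBC write path this last hunk touches; the DataFrame contents, JDBC URL, table name, and column list below are made-up placeholders:
// Hypothetical write: user-specified DDL types are expected to keep their CHAR/VARCHAR form.
val df = Seq(("alice", "hello")).toDF("name", "comments")   // assumes spark.implicits._ is in scope
df.write
  .option("createTableColumnTypes", "name CHAR(64), comments VARCHAR(1024)")
  .jdbc("jdbc:postgresql:testdb", "people", new java.util.Properties())
// parseTableSchema now records char/varchar in StructField metadata, and the change above
// restores the raw type via CharVarcharUtils.getRawType before the custom column types are applied.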