diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index be527005bc0ad..62433da4b4e23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -103,6 +103,7 @@ class Analyzer( ResolveWindowOrder :: ResolveWindowFrame :: ResolveNaturalAndUsingJoin :: + ResolveOutputColumns :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: @@ -451,7 +452,7 @@ class Analyzer( } def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => + case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _, _) if child.resolved => // A partitioned relation's schema can be different from the input logicalPlan, since // partition columns are all moved after data columns. We Project to adjust the ordering. val input = if (parts.nonEmpty) { @@ -516,6 +517,124 @@ class Analyzer( } } + object ResolveOutputColumns extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case ins @ InsertIntoTable(relation: LogicalPlan, partition, _, _, _, _) + if relation.resolved && ins.childrenResolved && !ins.resolved => + resolveOutputColumns(ins, expectedColumns(relation, partition), relation.toString) + } + + private def resolveOutputColumns( + insertInto: InsertIntoTable, + columns: Seq[Attribute], + relation: String) = { + val resolved = if (insertInto.isMatchByName) { + projectAndCastOutputColumns(columns, insertInto.child, relation) + } else { + castAndRenameOutputColumns(columns, insertInto.child, relation) + } + + if (resolved == insertInto.child.output) { + insertInto + } else { + insertInto.copy(child = Project(resolved, insertInto.child)) + } + } + + /** + * Resolves output columns by input column name, adding casts if necessary. 
+ */ + private def projectAndCastOutputColumns( + output: Seq[Attribute], + data: LogicalPlan, + relation: String): Seq[NamedExpression] = { + if (output.size > data.output.size) { + // always a problem + throw new AnalysisException( + s"""Not enough data columns to write into $relation: + |Data columns: ${data.output.mkString(",")} + |Table columns: ${output.mkString(",")}""".stripMargin) + } else if (output.size < data.output.size) { + // be conservative and fail if there are too many columns + throw new AnalysisException( + s"""Extra data columns to write into $relation: + |Data columns: ${data.output.mkString(",")} + |Table columns: ${output.mkString(",")}""".stripMargin) + } + + output.map { col => + data.resolveQuoted(col.name, resolver) match { + case Some(inCol) if !col.dataType.sameType(inCol.dataType) => + Alias(UpCast(inCol, col.dataType, Seq()), col.name)() + case Some(inCol) => inCol + case None => + throw new AnalysisException( + s"Cannot resolve ${col.name} in ${data.output.mkString(",")}") + } + } + } + + private def castAndRenameOutputColumns( + output: Seq[Attribute], + data: LogicalPlan, + relation: String): Seq[NamedExpression] = { + val outputNames = output.map(_.name) + // incoming expressions may not have names + val inputNames = data.output.flatMap(col => Option(col.name)) + if (output.size > data.output.size) { + // always a problem + throw new AnalysisException( + s"""Not enough data columns to write into $relation: + |Data columns: ${data.output.mkString(",")} + |Table columns: ${outputNames.mkString(",")}""".stripMargin) + } else if (output.size < data.output.size) { + // be conservative and fail if there are too many columns + throw new AnalysisException( + s"""Extra data columns to write into $relation: + |Data columns: ${data.output.mkString(",")} + |Table columns: ${outputNames.mkString(",")}""".stripMargin) + } else { + // check for reordered names and warn. this may be on purpose, so it isn't an error. + if (outputNames.toSet == inputNames.toSet && outputNames != inputNames) { + logWarning( + s"""Data column names match the table in a different order: + |Data columns: ${inputNames.mkString(",")} + |Table columns: ${outputNames.mkString(",")}""".stripMargin) + } + } + + data.output.zip(output).map { + case (in, out) if !in.dataType.sameType(out.dataType) => + Alias(Cast(in, out.dataType), out.name)() + case (in, out) if in.name != out.name => + Alias(in, out.name)() + case (in, _) => in + } + } + + private def expectedColumns( + data: LogicalPlan, + partitionData: Map[String, Option[String]]): Seq[Attribute] = { + data match { + case partitioned: CatalogRelation => + val tablePartitionNames = partitioned.catalogTable.partitionColumns.map(_.name) + val (inputPartCols, dataColumns) = data.output.partition { attr => + tablePartitionNames.contains(attr.name) + } + // Get the dynamic partition columns in partition order + val dynamicNames = tablePartitionNames.filter( + name => partitionData.getOrElse(name, None).isEmpty) + val dynamicPartCols = dynamicNames.map { name => + inputPartCols.find(_.name == name).getOrElse( + throw new AnalysisException(s"Cannot find partition column $name")) + } + + dataColumns ++ dynamicPartCols + case _ => data.output + } + } + } + /** * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from * a logical plan node's children. 
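A minimal usage sketch of the behavior the new ResolveOutputColumns rule enables (not part of this patch). It assumes a Hive-enabled SparkSession and an existing target table named `events` with schema `(id bigint, data string)`; only the `matchByName` option name and the `DataFrameWriter.option`/`insertInto` calls are taken from the diff above.

// Sketch only: the session setup and the `events` table are assumptions for illustration.
import org.apache.spark.sql.SparkSession

object MatchByNameSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().enableHiveSupport().getOrCreate()
    import spark.implicits._

    val df = Seq((1L, "a"), (2L, "b")).toDF("id", "data")

    // Default: columns are matched to the table by position; ResolveOutputColumns
    // adds casts and renames as needed, or fails analysis on a column-count mismatch.
    df.write.insertInto("events")

    // With the new option, columns are matched to the table by name instead, so the
    // input may arrive in any order as long as the names resolve against the table schema.
    df.select($"data", $"id").write.option("matchByName", true).insertInto("events")
  }
}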
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7b451baaa02b9..edce9c4e22bfa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -313,7 +313,7 @@ trait CheckAnalysis extends PredicateHelper { |${s.catalogTable.identifier} """.stripMargin) - case InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _) => + case InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _, _) => failAnalysis( s""" |Hive support is required to insert into the following tables: diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 2ca990d19a2cb..a5c2fd038ace8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -367,7 +367,7 @@ package object dsl { def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( analysis.UnresolvedRelation(TableIdentifier(tableName)), - Map.empty, logicalPlan, overwrite, false) + Map.empty, logicalPlan, overwrite, ifNotExists = false, Map.empty) def as(alias: String): LogicalPlan = logicalPlan match { case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e380643f548ba..38c1f3c72031d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -175,8 +175,9 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { UnresolvedRelation(tableIdent, None), partitionKeys, query, - ctx.OVERWRITE != null, - ctx.EXISTS != null) + overwrite = ctx.OVERWRITE != null, + ifNotExists = ctx.EXISTS != null, + Map.empty /* SQL always matches by position */) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 898784dab1d98..bbb8e4221c450 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -359,30 +359,43 @@ case class InsertIntoTable( partition: Map[String, Option[String]], child: LogicalPlan, overwrite: Boolean, - ifNotExists: Boolean) + ifNotExists: Boolean, + options: Map[String, String]) extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil override def output: Seq[Attribute] = Seq.empty + private[spark] def isMatchByName: Boolean = { + options.get("matchByName").map(_.toBoolean).getOrElse(false) + } + private[spark] lazy val expectedColumns = { if (table.output.isEmpty) { None } else { - val numDynamicPartitions = partition.values.count(_.isEmpty) + val dynamicPartitionNames = partition.filter { + case (name, Some(_)) => false + case (name, None) => true + }.keySet val (partitionColumns, dataColumns) = table.output .partition(a => 
partition.keySet.contains(a.name)) - Some(dataColumns ++ partitionColumns.takeRight(numDynamicPartitions)) + Some(dataColumns ++ partitionColumns.filter(col => dynamicPartitionNames.contains(col.name))) } } assert(overwrite || !ifNotExists) override lazy val resolved: Boolean = - childrenResolved && table.resolved && expectedColumns.forall { expected => - child.output.size == expected.size && child.output.zip(expected).forall { - case (childAttr, tableAttr) => - DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType) - } + childrenResolved && table.resolved && { + expectedColumns match { + case Some(expected) => + child.output.size == expected.size && child.output.zip(expected).forall { + case (childAttr, tableAttr) => + childAttr.name == tableAttr.name && // required by some relations + DataType.equalsIgnoreCompatibleNullability(childAttr.dataType, tableAttr.dataType) + } + case None => true + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 77023cfd3d60f..9c21ca98c75f1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -178,7 +178,7 @@ class PlanParserSuite extends PlanTest { partition: Map[String, Option[String]], overwrite: Boolean = false, ifNotExists: Boolean = false): LogicalPlan = - InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) + InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists, Map.empty) // Single inserts assertEqual(s"insert overwrite table s $sql", @@ -196,9 +196,11 @@ class PlanParserSuite extends PlanTest { val plan2 = table("t").where('x > 5).select(star()) assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", InsertIntoTable( - table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union( + table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false, + Map.empty).union( InsertIntoTable( - table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) + table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false, + Map.empty))) } test("aggregation") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 1c2003c18e3fc..d8757c8dcde89 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -512,7 +512,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { partitions.getOrElse(Map.empty[String, Option[String]]), df.logicalPlan, overwrite, - ifNotExists = false)).toRdd + ifNotExists = false, + options = extraOptions.toMap)).toRdd } private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2b4786542c72f..436e7d10bc4d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -46,7 +46,7 @@ import org.apache.spark.unsafe.types.UTF8String 
private[sql] object DataSourceAnalysis extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transform { case i @ logical.InsertIntoTable( - l @ LogicalRelation(t: HadoopFsRelation, _, _), part, query, overwrite, false) + l @ LogicalRelation(t: HadoopFsRelation, _, _), part, query, overwrite, false, _) if query.resolved && t.schema.asNullable == query.schema.asNullable => // Sanity checks @@ -110,7 +110,7 @@ private[sql] class FindDataSourceTable(sparkSession: SparkSession) extends Rule[ } override def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case i @ logical.InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _) + case i @ logical.InsertIntoTable(s: SimpleCatalogRelation, _, _, _, _, _) if DDLUtils.isDatasourceTable(s.metadata) => i.copy(table = readDataSourceTable(sparkSession, s.metadata)) @@ -152,7 +152,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _), - part, query, overwrite, false) if part.isEmpty => + part, query, overwrite, false, _) if part.isEmpty => ExecutedCommandExec(InsertIntoDataSourceCommand(l, query, overwrite)) :: Nil case _ => Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 7ac62fb191d40..23827b1ec239e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -61,55 +61,6 @@ private[sql] class ResolveDataSource(sparkSession: SparkSession) extends Rule[Lo } } -/** - * A rule to do pre-insert data type casting and field renaming. Before we insert into - * an [[InsertableRelation]], we will use this rule to make sure that - * the columns to be inserted have the correct data type and fields have the correct names. - */ -private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Wait until children are resolved. - case p: LogicalPlan if !p.childrenResolved => p - - // We are inserting into an InsertableRelation or HadoopFsRelation. - case i @ InsertIntoTable( - l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _), _, child, _, _) => - // First, make sure the data to be inserted have the same number of fields with the - // schema of the relation. - if (l.output.size != child.output.size) { - sys.error( - s"$l requires that the query in the SELECT clause of the INSERT INTO/OVERWRITE " + - s"statement generates the same number of columns as its schema.") - } - castAndRenameChildOutput(i, l.output, child) - } - - /** If necessary, cast data types and rename fields to the expected types and names. */ - def castAndRenameChildOutput( - insertInto: InsertIntoTable, - expectedOutput: Seq[Attribute], - child: LogicalPlan): InsertIntoTable = { - val newChildOutput = expectedOutput.zip(child.output).map { - case (expected, actual) => - val needCast = !expected.dataType.sameType(actual.dataType) - // We want to make sure the filed names in the data to be inserted exactly match - // names in the schema. 
- val needRename = expected.name != actual.name - (needCast, needRename) match { - case (true, _) => Alias(Cast(actual, expected.dataType), expected.name)() - case (false, true) => Alias(actual, expected.name)() - case (_, _) => actual - } - } - - if (newChildOutput == child.output) { - insertInto - } else { - insertInto.copy(child = Project(newChildOutput, child)) - } - } -} - /** * A rule to do various checks before inserting into or writing to a data source table. */ @@ -122,7 +73,7 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) plan.foreach { case i @ logical.InsertIntoTable( l @ LogicalRelation(t: InsertableRelation, _, _), - partition, query, overwrite, ifNotExists) => + partition, query, overwrite, ifNotExists, _) => // Right now, we do not support insert into a data source table with partition specs. if (partition.nonEmpty) { failAnalysis(s"Insert into a partition is not allowed because $l is not partitioned.") @@ -140,7 +91,7 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) } case logical.InsertIntoTable( - LogicalRelation(r: HadoopFsRelation, _, _), part, query, overwrite, _) => + LogicalRelation(r: HadoopFsRelation, _, _), part, query, overwrite, _, _) => // We need to make sure the partition columns specified by users do match partition // columns of the relation. val existingPartitionColumns = r.partitionSchema.fieldNames.toSet @@ -168,11 +119,11 @@ private[sql] case class PreWriteCheck(conf: SQLConf, catalog: SessionCatalog) // OK } - case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _) => + case logical.InsertIntoTable(l: LogicalRelation, _, _, _, _, _) => // The relation in l is not an InsertableRelation. failAnalysis(s"$l does not allow insertion.") - case logical.InsertIntoTable(t, _, _, _, _) => + case logical.InsertIntoTable(t, _, _, _, _, _) => if (!t.isInstanceOf[LeafNode] || t == OneRowRelation || t.isInstanceOf[LocalRelation]) { failAnalysis(s"Inserting into an RDD-based table is not allowed.") } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index b2db377ec7f8d..0710ae6617ec9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.AnalyzeTableCommand -import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, FindDataSourceTable, PreInsertCastAndRename, ResolveDataSource} +import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, FindDataSourceTable, ResolveDataSource} import org.apache.spark.sql.streaming.{ContinuousQuery, ContinuousQueryManager} import org.apache.spark.sql.util.ExecutionListenerManager @@ -111,7 +111,6 @@ private[sql] class SessionState(sparkSession: SparkSession) { lazy val analyzer: Analyzer = { new Analyzer(catalog, conf) { override val extendedResolutionRules = - PreInsertCastAndRename :: new FindDataSourceTable(sparkSession) :: DataSourceAnalysis :: (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 
4780eb473d79b..ebbae11db981d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -87,15 +87,15 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { } test("SELECT clause generating a different number of columns is not allowed.") { - val message = intercept[RuntimeException] { + val message = intercept[AnalysisException] { sql( s""" |INSERT OVERWRITE TABLE jsonTable SELECT a FROM jt """.stripMargin) }.getMessage assert( - message.contains("generates the same number of columns as its schema"), - "SELECT clause generating a different number of columns should not be not allowed." + message.contains("Not enough data columns to write"), + "SELECT clause must generate all of a table's columns to write" ) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index f10afa75f2bfc..85a740c551417 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._ @@ -373,16 +372,20 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log plan transformUp { // Write path - case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists, + options) // Inserting into partitioned table is not supported in Parquet data source (yet). if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreParquet(r) => - InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists) + InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists, + options) // Write path - case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + case InsertIntoHiveTable(r: MetastoreRelation, + partition, child, overwrite, ifNotExists, options) // Inserting into partitioned table is not supported in Parquet data source (yet). if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreParquet(r) => - InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists) + InsertIntoTable(convertToParquetRelation(r), partition, child, overwrite, ifNotExists, + options) // Read path case relation: MetastoreRelation if shouldConvertMetastoreParquet(relation) => @@ -417,16 +420,20 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log plan transformUp { // Write path - case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists, + options) // Inserting into partitioned table is not supported in Orc data source (yet). 
if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreOrc(r) => - InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists) + InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists, + options) // Write path - case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite, ifNotExists) + case InsertIntoHiveTable(r: MetastoreRelation, + partition, child, overwrite, ifNotExists, options) // Inserting into partitioned table is not supported in Orc data source (yet). if !r.hiveQlTable.isPartitioned && shouldConvertMetastoreOrc(r) => - InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists) + InsertIntoTable(convertToOrcRelation(r), partition, child, overwrite, ifNotExists, + options) // Read path case relation: MetastoreRelation if shouldConvertMetastoreOrc(relation) => @@ -463,49 +470,6 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log allowExisting) } } - - /** - * Casts input data to correct data types according to table definition before inserting into - * that table. - */ - object PreInsertionCasts extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan.transform { - // Wait until children are resolved. - case p: LogicalPlan if !p.childrenResolved => p - - case p @ InsertIntoTable(table: MetastoreRelation, _, child, _, _) => - castChildOutput(p, table, child) - } - - def castChildOutput(p: InsertIntoTable, table: MetastoreRelation, child: LogicalPlan) - : LogicalPlan = { - val childOutputDataTypes = child.output.map(_.dataType) - val numDynamicPartitions = p.partition.values.count(_.isEmpty) - val tableOutputDataTypes = - (table.attributes ++ table.partitionKeys.takeRight(numDynamicPartitions)) - .take(child.output.length).map(_.dataType) - - if (childOutputDataTypes == tableOutputDataTypes) { - InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists) - } else if (childOutputDataTypes.size == tableOutputDataTypes.size && - childOutputDataTypes.zip(tableOutputDataTypes) - .forall { case (left, right) => left.sameType(right) }) { - // If both types ignoring nullability of ArrayType, MapType, StructType are the same, - // use InsertIntoHiveTable instead of InsertIntoTable. - InsertIntoHiveTable(table, p.partition, p.child, p.overwrite, p.ifNotExists) - } else { - // Only do the casting when child output data types differ from table output data types. 
- val castedChildOutput = child.output.zip(table.output).map { - case (input, output) if input.dataType != output.dataType => - Alias(Cast(input, output.dataType), input.name)() - case (input, _) => input - } - - p.copy(child = logical.Project(castedChildOutput, child)) - } - } - } - } /** @@ -549,7 +513,8 @@ private[hive] case class InsertIntoHiveTable( partition: Map[String, Option[String]], child: LogicalPlan, overwrite: Boolean, - ifNotExists: Boolean) + ifNotExists: Boolean, + options: Map[String, String]) extends LogicalPlan { override def children: Seq[LogicalPlan] = child :: Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 4f8aac8c2fcdd..2f6a2207855ec 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -87,7 +87,6 @@ private[sql] class HiveSessionCatalog( val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions val OrcConversions: Rule[LogicalPlan] = metastoreCatalog.OrcConversions val CreateTables: Rule[LogicalPlan] = metastoreCatalog.CreateTables - val PreInsertionCasts: Rule[LogicalPlan] = metastoreCatalog.PreInsertionCasts override def refreshTable(name: TableIdentifier): Unit = { metastoreCatalog.refreshTable(name) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index ca8e5f8223968..75899bd1cb307 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -65,8 +65,6 @@ private[hive] class HiveSessionState(sparkSession: SparkSession) catalog.ParquetConversions :: catalog.OrcConversions :: catalog.CreateTables :: - catalog.PreInsertionCasts :: - PreInsertCastAndRename :: DataSourceAnalysis :: (if (conf.runSQLonFile) new ResolveDataSource(sparkSession) :: Nil else Nil) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 71b180e55b58c..7d1daa496f094 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -43,11 +43,11 @@ private[hive] trait HiveStrategies { object DataSinks extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.InsertIntoTable( - table: MetastoreRelation, partition, child, overwrite, ifNotExists) => + table: MetastoreRelation, partition, child, overwrite, ifNotExists, _) => execution.InsertIntoHiveTable( table, partition, planLater(child), overwrite, ifNotExists) :: Nil case hive.InsertIntoHiveTable( - table: MetastoreRelation, partition, child, overwrite, ifNotExists) => + table: MetastoreRelation, partition, child, overwrite, ifNotExists, _) => execution.InsertIntoHiveTable( table, partition, planLater(child), overwrite, ifNotExists) :: Nil case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index b8099385a466b..4169b39c8e06e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -88,7 +88,7 @@ case class CreateHiveTableAsSelectCommand( } } else { sparkSession.sessionState.executePlan(InsertIntoTable( - metastoreRelation, Map(), query, overwrite = true, ifNotExists = false)).toRdd + metastoreRelation, Map(), query, overwrite = true, ifNotExists = false, Map.empty)).toRdd } Seq.empty[Row] diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index fae59001b98e1..b277d4194633a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -284,8 +284,133 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef val data = (1 to 10).map(i => (i.toLong, s"data-$i")).toDF("id", "data") val logical = InsertIntoTable(spark.table("partitioned").logicalPlan, - Map("part" -> None), data.logicalPlan, overwrite = false, ifNotExists = false) + Map("part" -> None), data.logicalPlan, overwrite = false, ifNotExists = false, + Map("matchByName" -> "true")) assert(!logical.resolved, "Should not resolve: missing partition data") } } + + test("Insert unnamed expressions by position") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (id bigint, part string)") + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + + val expected = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")) + .toDF("id", "data", "part") + val data = expected.select("id", "part") + + data.write.insertInto("source") + checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq) + + // should be able to insert an expression when NOT mapping columns by name + spark.table("source").selectExpr("id", "part", "CONCAT('data-', id)") + .write.insertInto("partitioned") + checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq) + } + } + + test("Insert expression by name") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (id bigint, part string)") + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + + val expected = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")) + .toDF("id", "data", "part") + val data = expected.select("id", "part") + + data.write.insertInto("source") + checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq) + + intercept[AnalysisException] { + // also a problem when mapping by name + spark.table("source").selectExpr("id", "part", "CONCAT('data-', id)") + .write.option("matchByName", true).insertInto("partitioned") + } + + // should be able to insert an expression using AS when mapping columns by name + spark.table("source").selectExpr("id", "part", "CONCAT('data-', id) as data") + .write.option("matchByName", true).insertInto("partitioned") + checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq) + } + } + + test("Reject missing columns") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (id bigint, part string)") + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + + intercept[AnalysisException] { + spark.table("source").write.insertInto("partitioned") + } + + intercept[AnalysisException] { + // also a 
problem when mapping by name + spark.table("source").write.option("matchByName", true).insertInto("partitioned") + } + } + } + + test("Reject extra columns") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (id bigint, data string, extra string, part string)") + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + + intercept[AnalysisException] { + spark.table("source").write.insertInto("partitioned") + } + + val data = (1 to 10) + .map(i => (i, s"data-$i", s"${i * i}", if ((i % 2) == 0) "even" else "odd")) + .toDF("id", "data", "extra", "part") + data.write.insertInto("source") + checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq) + + intercept[AnalysisException] { + spark.table("source").write.option("matchByName", true).insertInto("partitioned") + } + + spark.table("source").select("data", "id", "part").write.option("matchByName", true) + .insertInto("partitioned") + + val expected = data.select("id", "data", "part") + checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq) + } + } + + test("Ignore names when writing by position") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (id bigint, part string, data string)") // part, data transposed + sql("CREATE TABLE destination (id bigint, data string, part string)") + + val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")) + .toDF("id", "data", "part") + + // write into the reordered table by name + data.write.option("matchByName", true).insertInto("source") + checkAnswer(sql("SELECT id, data, part FROM source"), data.collect().toSeq) + + val expected = data.select($"id", $"part" as "data", $"data" as "part") + + // this produces a warning, but writes src.part -> dest.data and src.data -> dest.part + spark.table("source").write.insertInto("destination") + checkAnswer(sql("SELECT id, data, part FROM destination"), expected.collect().toSeq) + } + } + + test("Reorder columns by name") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + sql("CREATE TABLE source (data string, part string, id bigint)") + sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") + + val data = (1 to 10).map(i => (s"data-$i", if ((i % 2) == 0) "even" else "odd", i)) + .toDF("data", "part", "id") + data.write.insertInto("source") + checkAnswer(sql("SELECT * FROM source"), data.collect().toSeq) + + spark.table("source").write.option("matchByName", true).insertInto("partitioned") + + val expected = data.select("id", "data", "part") + checkAnswer(sql("SELECT * FROM partitioned"), expected.collect().toSeq) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index a7652143a4252..ad68fc388be28 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -348,6 +348,7 @@ abstract class HiveComparisonTest val containsCommands = originalQuery.analyzed.collectFirst { case _: Command => () case _: LogicalInsertIntoHiveTable => () + case _: InsertIntoTable => () }.nonEmpty if (containsCommands) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 
0a2bab4f5d1e1..7eec6fb0f0d7e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1057,7 +1057,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("SET hive.exec.dynamic.partition.mode=nonstrict") sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart") - sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src") + sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value, ds, hr FROM srcpart") .queryExecution.analyzed } @@ -1068,6 +1068,26 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + test("SPARK-14543: AnalysisException for missing partition columns") { + loadTestTable("srcpart") + sql("DROP TABLE IF EXISTS withparts") + sql("CREATE TABLE withparts LIKE srcpart") + sql("SET hive.exec.dynamic.partition.mode=nonstrict") + sql("CREATE TABLE IF NOT EXISTS withparts LIKE srcpart") + + intercept[AnalysisException] { + // src doesn't have ds and hr partition columns + sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM src") + .queryExecution.analyzed + } + + intercept[AnalysisException] { + // ds and hr partition columns aren't selected + sql("INSERT INTO TABLE withparts PARTITION(ds, hr) SELECT key, value FROM srcpart") + .queryExecution.analyzed + } + } + test("parse HQL set commands") { // Adapted from its SQL counterpart. val testKey = "spark.sql.key.usedfortestonly"
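For reference, a standalone sketch (not part of the patch) of how the dynamic-partition handling above computes the expected output columns: data columns first, then only the dynamic partition columns in table partition order, while partitions given a literal value in the PARTITION spec are dropped. The `Col` type, object name, and column names are illustrative; the logic is a simplification of `expectedColumns` in the diff.

object ExpectedColumnsSketch {
  case class Col(name: String)

  def expectedColumns(
      tableCols: Seq[Col],
      partitionColNames: Seq[String],
      partitionSpec: Map[String, Option[String]]): Seq[Col] = {
    val (partCols, dataCols) = tableCols.partition(c => partitionColNames.contains(c.name))
    // A partition column with no literal value in the spec is dynamic and must be
    // supplied by the query, so it stays in the expected output.
    val dynamicNames = partitionColNames.filter(n => partitionSpec.getOrElse(n, None).isEmpty)
    dataCols ++ dynamicNames.flatMap(n => partCols.find(_.name == n))
  }

  def main(args: Array[String]): Unit = {
    // INSERT ... PARTITION (ds='2016-01-01', hr) keeps dynamic `hr`, drops static `ds`.
    println(expectedColumns(
      Seq(Col("key"), Col("value"), Col("ds"), Col("hr")),
      Seq("ds", "hr"),
      Map("ds" -> Some("2016-01-01"), "hr" -> None)))
    // prints: List(Col(key), Col(value), Col(hr))
  }
}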