From bbd572c1fe542c6b2fd642212f927ba384c882e4 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sat, 1 Sep 2018 00:07:00 +0800 Subject: [PATCH 1/9] Fix regression in FileFormatWriter output schema --- .../spark/sql/catalyst/plans/QueryPlan.scala | 14 ++++ .../command/DataWritingCommand.scala | 2 +- .../command/createDataSourceTables.scala | 4 +- .../execution/datasources/DataSource.scala | 15 ++-- .../datasources/DataSourceStrategy.scala | 4 +- .../InsertIntoHadoopFsRelationCommand.scala | 3 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 76 +++++++++++++++++++ .../spark/sql/hive/HiveStrategies.scala | 6 +- .../CreateHiveTableAsSelectCommand.scala | 6 +- .../execution/InsertIntoHiveDirCommand.scala | 4 +- .../hive/execution/InsertIntoHiveTable.scala | 4 +- 11 files changed, 116 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index b1ffdca091461..c0b80d1e8137d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -38,6 +38,20 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ def outputSet: AttributeSet = AttributeSet(output) + /** + * Returns output attributes with provided names. + * The length of provided names should be the same of the length of [[output]]. + */ + def outputWithNames(names: Seq[String]): Seq[Attribute] = { + // Save the output attributes to a variable to avoid duplicated function calls. + val outputAttributes = output + assert(outputAttributes.length == names.length, + "The length of provided names doesn't match the length of output attributes.") + outputAttributes.zipWithIndex.map { case (element, index) => + element.withName(names(index)) + } + } + /** * All Attributes that appear in expressions from this operator. Note that this set does not * include attributes that are implicitly referenced by being passed through to the output tuple. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index e11dbd201004d..ea0ca3944ee7f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -42,7 +42,7 @@ trait DataWritingCommand extends Command { override final def children: Seq[LogicalPlan] = query :: Nil // Output columns of the analyzed input query plan - def outputColumns: Seq[Attribute] + def outputColumnNames: Seq[String] lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index f6ef433f2ce15..b2e1f530b5328 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -139,7 +139,7 @@ case class CreateDataSourceTableAsSelectCommand( table: CatalogTable, mode: SaveMode, query: LogicalPlan, - outputColumns: Seq[Attribute]) + outputColumnNames: Seq[String]) extends DataWritingCommand { override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { @@ -214,7 +214,7 @@ case class CreateDataSourceTableAsSelectCommand( catalogTable = if (tableExists) Some(table) else None) try { - dataSource.writeAndRead(mode, query, outputColumns, physicalPlan) + dataSource.writeAndRead(mode, query, outputColumnNames, physicalPlan) } catch { case ex: AnalysisException => logError(s"Failed to write to table ${table.identifier.unquotedString}", ex) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 1dcf9f3185de9..784b2b01eb9a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -450,7 +450,7 @@ case class DataSource( mode = mode, catalogTable = catalogTable, fileIndex = fileIndex, - outputColumns = data.output) + outputColumnNames = data.output.map(_.name)) } /** @@ -460,9 +460,9 @@ case class DataSource( * @param mode The save mode for this writing. * @param data The input query plan that produces the data to be written. Note that this plan * is analyzed and optimized. - * @param outputColumns The original output columns of the input query plan. The optimizer may not - * preserve the output column's names' case, so we need this parameter - * instead of `data.output`. + * @param outputColumnNames The original output column names of the input query plan. The + * optimizer may not preserve the output column's names' case, so we need + * this parameter instead of `data.output`. * @param physicalPlan The physical plan of the input query plan. 
We should run the writing * command with this physical plan instead of creating a new physical plan, * so that the metrics can be correctly linked to the given physical plan and @@ -471,8 +471,9 @@ case class DataSource( def writeAndRead( mode: SaveMode, data: LogicalPlan, - outputColumns: Seq[Attribute], + outputColumnNames: Seq[String], physicalPlan: SparkPlan): BaseRelation = { + val outputColumns = data.outputWithNames(names = outputColumnNames) if (outputColumns.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } @@ -495,7 +496,9 @@ case class DataSource( s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]") } } - val resolved = cmd.copy(partitionColumns = resolvedPartCols, outputColumns = outputColumns) + val resolved = cmd.copy( + partitionColumns = resolvedPartCols, + outputColumnNames = outputColumns.map(_.name)) resolved.run(sparkSession, physicalPlan) // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 6b61e749e3063..c6000442fae76 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -139,7 +139,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast case CreateTable(tableDesc, mode, Some(query)) if query.resolved && DDLUtils.isDatasourceTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc.copy(schema = query.schema)) - CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output) + CreateDataSourceTableAsSelectCommand(tableDesc, mode, query, query.output.map(_.name)) case InsertIntoTable(l @ LogicalRelation(_: InsertableRelation, _, _, _), parts, query, overwrite, false) if parts.isEmpty => @@ -209,7 +209,7 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with Cast mode, table, Some(t.location), - actualQuery.output) + actualQuery.output.map(_.name)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 2ae21b7df9823..f5199b710036a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -56,7 +56,7 @@ case class InsertIntoHadoopFsRelationCommand( mode: SaveMode, catalogTable: Option[CatalogTable], fileIndex: Option[FileIndex], - outputColumns: Seq[Attribute]) + outputColumnNames: Seq[String]) extends DataWritingCommand { import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName @@ -155,6 +155,7 @@ case class InsertIntoHadoopFsRelationCommand( } } + val outputColumns = query.outputWithNames(outputColumnNames) val updatedPartitionPaths = FileFormatWriter.write( sparkSession = sparkSession, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala 
index 01dc28d70184e..4f734c79e0f4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -24,6 +24,7 @@ import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.{AccumulatorSuite, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} @@ -2853,6 +2854,81 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("Insert overwrite table command should output correct schema: basic") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).toDF("id") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + spark.sql("CREATE TABLE tbl2(ID long) USING parquet") + spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1") + val identifier = TableIdentifier("tbl2", Some("default")) + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + + test("Insert overwrite table command should output correct schema: complex") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl") + spark.sql("CREATE TABLE tbl2(COL1 long, COL2 int, COL3 int) USING parquet PARTITIONED " + + "BY (COL2) CLUSTERED BY (COL3) INTO 3 BUCKETS") + spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT COL1, COL2, COL3 " + + "FROM view1 CLUSTER BY COL3") + val identifier = TableIdentifier("tbl2", Some("default")) + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq( + StructField("COL1", LongType, true), + StructField("COL3", IntegerType, true), + StructField("COL2", IntegerType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + + test("Create table as select command should output correct schema: basic") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).toDF("id") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + spark.sql("CREATE TABLE tbl2 USING parquet AS SELECT ID FROM view1") + val identifier = TableIdentifier("tbl2", Some("default")) + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + + test("Create table as select command should output correct schema: complex") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl") + spark.sql("CREATE TABLE tbl2 USING parquet PARTITIONED BY (COL2) " + + "CLUSTERED BY (COL3) INTO 3 
BUCKETS AS SELECT COL1, COL2, COL3 FROM view1") + val identifier = TableIdentifier("tbl2", Some("default")) + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq( + StructField("COL1", LongType, true), + StructField("COL3", IntegerType, true), + StructField("COL2", IntegerType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + test("SPARK-25144 'distinct' causes memory leak") { val ds = List(Foo(Some("bar"))).toDS val result = ds.flatMap(_.bar).distinct diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 9fe83bb332a9a..07ee105404311 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -149,7 +149,7 @@ object HiveAnalysis extends Rule[LogicalPlan] { case InsertIntoTable(r: HiveTableRelation, partSpec, query, overwrite, ifPartitionNotExists) if DDLUtils.isHiveTable(r.tableMeta) => InsertIntoHiveTable(r.tableMeta, partSpec, query, overwrite, - ifPartitionNotExists, query.output) + ifPartitionNotExists, query.output.map(_.name)) case CreateTable(tableDesc, mode, None) if DDLUtils.isHiveTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) @@ -157,14 +157,14 @@ object HiveAnalysis extends Rule[LogicalPlan] { case CreateTable(tableDesc, mode, Some(query)) if DDLUtils.isHiveTable(tableDesc) => DDLUtils.checkDataColNames(tableDesc) - CreateHiveTableAsSelectCommand(tableDesc, query, query.output, mode) + CreateHiveTableAsSelectCommand(tableDesc, query, query.output.map(_.name), mode) case InsertIntoDir(isLocal, storage, provider, child, overwrite) if DDLUtils.isHiveTable(provider) => val outputPath = new Path(storage.locationUri.get) if (overwrite) DDLUtils.verifyNotReadPath(child, outputPath) - InsertIntoHiveDirCommand(isLocal, storage, child, overwrite, child.output) + InsertIntoHiveDirCommand(isLocal, storage, child, overwrite, child.output.map(_.name)) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 27d807cc35627..9b9f5d2a3bb0c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.execution.command.DataWritingCommand case class CreateHiveTableAsSelectCommand( tableDesc: CatalogTable, query: LogicalPlan, - outputColumns: Seq[Attribute], + outputColumnNames: Seq[String], mode: SaveMode) extends DataWritingCommand { @@ -63,7 +63,7 @@ case class CreateHiveTableAsSelectCommand( query, overwrite = false, ifPartitionNotExists = false, - outputColumns = outputColumns).run(sparkSession, child) + outputColumnNames = outputColumnNames).run(sparkSession, child) } else { // TODO ideally, we should get the output data ready first and then // add the relation into catalog, just in case of failure occurs while data @@ -82,7 +82,7 @@ case class CreateHiveTableAsSelectCommand( query, overwrite = true, ifPartitionNotExists = false, - outputColumns = outputColumns).run(sparkSession, child) + outputColumnNames = outputColumnNames).run(sparkSession, child) } catch { case 
NonFatal(e) => // drop the created table. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala index cebeca0ce9444..7a5a8db07b403 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala @@ -57,7 +57,7 @@ case class InsertIntoHiveDirCommand( storage: CatalogStorageFormat, query: LogicalPlan, overwrite: Boolean, - outputColumns: Seq[Attribute]) extends SaveAsHiveFile { + outputColumnNames: Seq[String]) extends SaveAsHiveFile { override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { assert(storage.locationUri.nonEmpty) @@ -105,7 +105,7 @@ case class InsertIntoHiveDirCommand( hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, outputLocation = tmpPath.toString, - allColumns = outputColumns) + allColumns = query.outputWithNames(outputColumnNames)) val fs = writeToPath.getFileSystem(hadoopConf) if (overwrite && fs.exists(writeToPath)) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 02a60f16b3b3a..e5b2abdc3c0f9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -69,7 +69,7 @@ case class InsertIntoHiveTable( query: LogicalPlan, overwrite: Boolean, ifPartitionNotExists: Boolean, - outputColumns: Seq[Attribute]) extends SaveAsHiveFile { + outputColumnNames: Seq[String]) extends SaveAsHiveFile { /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the @@ -198,7 +198,7 @@ case class InsertIntoHiveTable( hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, outputLocation = tmpLocation.toString, - allColumns = outputColumns, + allColumns = query.outputWithNames(outputColumnNames), partitionAttributes = partitionAttributes) if (partition.nonEmpty) { From 5bce8a0f325eed4c37687dab98b707c46ee4f50e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 3 Sep 2018 19:21:30 +0800 Subject: [PATCH 2/9] address comments --- .../spark/sql/catalyst/plans/QueryPlan.scala | 14 ----------- .../command/DataWritingCommand.scala | 23 ++++++++++++++++++- .../execution/datasources/DataSource.scala | 9 ++++---- .../InsertIntoHadoopFsRelationCommand.scala | 1 - .../execution/InsertIntoHiveDirCommand.scala | 2 +- .../hive/execution/InsertIntoHiveTable.scala | 2 +- 6 files changed, 29 insertions(+), 22 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index c0b80d1e8137d..b1ffdca091461 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -38,20 +38,6 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT */ def outputSet: AttributeSet = AttributeSet(output) - /** - * Returns output attributes with provided names. - * The length of provided names should be the same of the length of [[output]]. 
- */ - def outputWithNames(names: Seq[String]): Seq[Attribute] = { - // Save the output attributes to a variable to avoid duplicated function calls. - val outputAttributes = output - assert(outputAttributes.length == names.length, - "The length of provided names doesn't match the length of output attributes.") - outputAttributes.zipWithIndex.map { case (element, index) => - element.withName(names(index)) - } - } - /** * All Attributes that appear in expressions from this operator. Note that this set does not * include attributes that are implicitly referenced by being passed through to the output tuple. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index ea0ca3944ee7f..de7d5f46b9aca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker import org.apache.spark.sql.execution.datasources.FileFormatWriter -import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.util.SerializableConfiguration /** @@ -44,6 +44,9 @@ trait DataWritingCommand extends Command { // Output columns of the analyzed input query plan def outputColumnNames: Seq[String] + def outputColumns: Seq[Attribute] = + DataWritingCommand.logicalPlanOutputWithNames(query, outputColumnNames) + lazy val metrics: Map[String, SQLMetric] = BasicWriteJobStatsTracker.metrics def basicWriteJobStatsTracker(hadoopConf: Configuration): BasicWriteJobStatsTracker = { @@ -53,3 +56,21 @@ trait DataWritingCommand extends Command { def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] } + +object DataWritingCommand { + /** + * Returns output attributes with provided names. + * The length of provided names should be the same of the length of [[LogicalPlan.output]]. + */ + def logicalPlanOutputWithNames( + query: LogicalPlan, + names: Seq[String]): Seq[Attribute] = { + // Save the output attributes to a variable to avoid duplicated function calls. 
+ val outputAttributes = query.output + assert(outputAttributes.length == names.length, + "The length of provided names doesn't match the length of output attributes.") + outputAttributes.zipWithIndex.map { case (element, index) => + element.withName(names(index)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 784b2b01eb9a0..ce3bc3dd48327 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.command.DataWritingCommand import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat import org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider import org.apache.spark.sql.execution.datasources.json.JsonFileFormat @@ -461,8 +462,8 @@ case class DataSource( * @param data The input query plan that produces the data to be written. Note that this plan * is analyzed and optimized. * @param outputColumnNames The original output column names of the input query plan. The - * optimizer may not preserve the output column's names' case, so we need - * this parameter instead of `data.output`. + * optimizer may not preserve the output column's names' case, so we need + * this parameter instead of `data.output`. * @param physicalPlan The physical plan of the input query plan. We should run the writing * command with this physical plan instead of creating a new physical plan, * so that the metrics can be correctly linked to the given physical plan and @@ -473,7 +474,7 @@ case class DataSource( data: LogicalPlan, outputColumnNames: Seq[String], physicalPlan: SparkPlan): BaseRelation = { - val outputColumns = data.outputWithNames(names = outputColumnNames) + val outputColumns = DataWritingCommand.logicalPlanOutputWithNames(data, outputColumnNames) if (outputColumns.map(_.dataType).exists(_.isInstanceOf[CalendarIntervalType])) { throw new AnalysisException("Cannot save interval data type into external storage.") } @@ -498,7 +499,7 @@ case class DataSource( } val resolved = cmd.copy( partitionColumns = resolvedPartCols, - outputColumnNames = outputColumns.map(_.name)) + outputColumnNames = outputColumnNames) resolved.run(sparkSession, physicalPlan) // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring copy(userSpecifiedSchema = Some(outputColumns.toStructType.asNullable)).resolveRelation() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index f5199b710036a..9dc69104cd5f1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -155,7 +155,6 @@ case class InsertIntoHadoopFsRelationCommand( } } - val outputColumns = query.outputWithNames(outputColumnNames) val updatedPartitionPaths = FileFormatWriter.write( sparkSession = sparkSession, diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala index 7a5a8db07b403..0a73aaa94bc75 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala @@ -105,7 +105,7 @@ case class InsertIntoHiveDirCommand( hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, outputLocation = tmpPath.toString, - allColumns = query.outputWithNames(outputColumnNames)) + allColumns = outputColumns) val fs = writeToPath.getFileSystem(hadoopConf) if (overwrite && fs.exists(writeToPath)) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index e5b2abdc3c0f9..75a0563e72c91 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -198,7 +198,7 @@ case class InsertIntoHiveTable( hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, outputLocation = tmpLocation.toString, - allColumns = query.outputWithNames(outputColumnNames), + allColumns = outputColumns, partitionAttributes = partitionAttributes) if (partition.nonEmpty) { From 16bb457828ff0284456ef4ef36a23384b4a74b6e Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 3 Sep 2018 19:27:38 +0800 Subject: [PATCH 3/9] address more comment --- .../spark/sql/execution/command/DataWritingCommand.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index de7d5f46b9aca..a1bb5af1ab723 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -41,9 +41,10 @@ trait DataWritingCommand extends Command { override final def children: Seq[LogicalPlan] = query :: Nil - // Output columns of the analyzed input query plan + // Output column names of the analyzed input query plan. def outputColumnNames: Seq[String] + // Output columns of the analyzed input query plan. 
def outputColumns: Seq[Attribute] = DataWritingCommand.logicalPlanOutputWithNames(query, outputColumnNames) @@ -69,8 +70,8 @@ object DataWritingCommand { val outputAttributes = query.output assert(outputAttributes.length == names.length, "The length of provided names doesn't match the length of output attributes.") - outputAttributes.zipWithIndex.map { case (element, index) => - element.withName(names(index)) + outputAttributes.zip(names).map { case (attr, outputName) => + attr.withName(outputName) } } } From 3c282ef85acf80b1fb2507d75c1a2ad585efe115 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 3 Sep 2018 21:46:46 +0800 Subject: [PATCH 4/9] add more test cases --- .../org/apache/spark/sql/SQLQuerySuite.scala | 75 ------------------- .../sql/test/DataFrameReaderWriterSuite.scala | 75 +++++++++++++++++++ .../sql/hive/execution/HiveDDLSuite.scala | 41 ++++++++++ 3 files changed, 116 insertions(+), 75 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 4f734c79e0f4f..cef1332a0bda6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2854,81 +2854,6 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } - test("Insert overwrite table command should output correct schema: basic") { - withTable("tbl", "tbl2") { - withView("view1") { - val df = spark.range(10).toDF("id") - df.write.format("parquet").saveAsTable("tbl") - spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") - spark.sql("CREATE TABLE tbl2(ID long) USING parquet") - spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1") - val identifier = TableIdentifier("tbl2", Some("default")) - val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString - val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) - assert(spark.read.parquet(location).schema == expectedSchema) - checkAnswer(spark.table("tbl2"), df) - } - } - } - - test("Insert overwrite table command should output correct schema: complex") { - withTable("tbl", "tbl2") { - withView("view1") { - val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3") - df.write.format("parquet").saveAsTable("tbl") - spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl") - spark.sql("CREATE TABLE tbl2(COL1 long, COL2 int, COL3 int) USING parquet PARTITIONED " + - "BY (COL2) CLUSTERED BY (COL3) INTO 3 BUCKETS") - spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT COL1, COL2, COL3 " + - "FROM view1 CLUSTER BY COL3") - val identifier = TableIdentifier("tbl2", Some("default")) - val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString - val expectedSchema = StructType(Seq( - StructField("COL1", LongType, true), - StructField("COL3", IntegerType, true), - StructField("COL2", IntegerType, true))) - assert(spark.read.parquet(location).schema == expectedSchema) - checkAnswer(spark.table("tbl2"), df) - } - } - } - - test("Create table as select command should output correct schema: basic") { - withTable("tbl", "tbl2") { - withView("view1") { - val df = spark.range(10).toDF("id") - df.write.format("parquet").saveAsTable("tbl") - spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") - spark.sql("CREATE TABLE tbl2 USING parquet AS SELECT ID FROM view1") - val identifier = TableIdentifier("tbl2", Some("default")) - val location = 
spark.sessionState.catalog.getTableMetadata(identifier).location.toString - val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) - assert(spark.read.parquet(location).schema == expectedSchema) - checkAnswer(spark.table("tbl2"), df) - } - } - } - - test("Create table as select command should output correct schema: complex") { - withTable("tbl", "tbl2") { - withView("view1") { - val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3") - df.write.format("parquet").saveAsTable("tbl") - spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl") - spark.sql("CREATE TABLE tbl2 USING parquet PARTITIONED BY (COL2) " + - "CLUSTERED BY (COL3) INTO 3 BUCKETS AS SELECT COL1, COL2, COL3 FROM view1") - val identifier = TableIdentifier("tbl2", Some("default")) - val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString - val expectedSchema = StructType(Seq( - StructField("COL1", LongType, true), - StructField("COL3", IntegerType, true), - StructField("COL2", IntegerType, true))) - assert(spark.read.parquet(location).schema == expectedSchema) - checkAnswer(spark.table("tbl2"), df) - } - } - } - test("SPARK-25144 'distinct' causes memory leak") { val ds = List(Foo(Some("bar"))).toDS val result = ds.flatMap(_.bar).distinct diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index b65058fffd339..6b09835f09d0a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -805,6 +805,81 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be } } + test("Insert overwrite table command should output correct schema: basic") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).toDF("id") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + spark.sql("CREATE TABLE tbl2(ID long) USING parquet") + spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1") + val identifier = TableIdentifier("tbl2", Some("default")) + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + + test("Insert overwrite table command should output correct schema: complex") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl") + spark.sql("CREATE TABLE tbl2(COL1 long, COL2 int, COL3 int) USING parquet PARTITIONED " + + "BY (COL2) CLUSTERED BY (COL3) INTO 3 BUCKETS") + spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT COL1, COL2, COL3 " + + "FROM view1 CLUSTER BY COL3") + val identifier = TableIdentifier("tbl2", Some("default")) + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq( + StructField("COL1", LongType, true), + StructField("COL3", IntegerType, true), + StructField("COL2", IntegerType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + + test("Create table 
as select command should output correct schema: basic") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).toDF("id") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + spark.sql("CREATE TABLE tbl2 USING parquet AS SELECT ID FROM view1") + val identifier = TableIdentifier("tbl2", Some("default")) + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + + test("Create table as select command should output correct schema: complex") { + withTable("tbl", "tbl2") { + withView("view1") { + val df = spark.range(10).map(x => (x, x.toInt, x.toInt)).toDF("col1", "col2", "col3") + df.write.format("parquet").saveAsTable("tbl") + spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl") + spark.sql("CREATE TABLE tbl2 USING parquet PARTITIONED BY (COL2) " + + "CLUSTERED BY (COL3) INTO 3 BUCKETS AS SELECT COL1, COL2, COL3 FROM view1") + val identifier = TableIdentifier("tbl2", Some("default")) + val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString + val expectedSchema = StructType(Seq( + StructField("COL1", LongType, true), + StructField("COL3", IntegerType, true), + StructField("COL2", IntegerType, true))) + assert(spark.read.parquet(location).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), df) + } + } + } + test("use Spark jobs to list files") { withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") { withTempDir { dir => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 728817729dcf7..3bf3d2c7c6d8e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -754,6 +754,47 @@ class HiveDDLSuite } } + test("Insert overwrite Hive table should output correct schema") { + withTable("tbl", "tbl2") { + withView("view1") { + spark.sql("CREATE TABLE tbl(id long)") + spark.sql("INSERT OVERWRITE TABLE tbl SELECT 4") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + spark.sql("CREATE TABLE tbl2(ID long)") + spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1") + checkAnswer(spark.table("tbl2"), Seq(Row(4))) + } + } + } + + test("Insert into Hive directory should output correct schema") { + withTable("tbl") { + withView("view1") { + withTempPath { path => + spark.sql("CREATE TABLE tbl(id long)") + spark.sql("INSERT OVERWRITE TABLE tbl SELECT 4") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + spark.sql(s"CREATE TABLE tbl2(ID long) location '${path.toURI}'") + spark.sql(s"INSERT OVERWRITE DIRECTORY '${path.toURI}' SELECT ID FROM view1") + // spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1") + checkAnswer(spark.table("tbl2"), Seq(Row(4))) + } + } + } + } + + test("Create Hive table as select should output correct schema") { + withTable("tbl", "tbl2") { + withView("view1") { + spark.sql("CREATE TABLE tbl(id long)") + spark.sql("INSERT OVERWRITE TABLE tbl SELECT 4") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + spark.sql("CREATE TABLE tbl2 AS SELECT ID FROM view1") + checkAnswer(spark.table("tbl2"), Seq(Row(4))) + } + } + } + test("alter table partition - storage 
information") { sql("CREATE TABLE boxes (height INT, length INT) PARTITIONED BY (width INT)") sql("INSERT OVERWRITE TABLE boxes PARTITION (width=4) SELECT 4, 4") From 98bf027df9c4467adc2673097c6762a3ec5210ce Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 4 Sep 2018 00:51:57 +0800 Subject: [PATCH 5/9] revise tests --- .../src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 1 - .../org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index cef1332a0bda6..01dc28d70184e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -24,7 +24,6 @@ import java.util.concurrent.atomic.AtomicBoolean import org.apache.spark.{AccumulatorSuite, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 6b09835f09d0a..7f31372c89b1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -830,8 +830,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl") spark.sql("CREATE TABLE tbl2(COL1 long, COL2 int, COL3 int) USING parquet PARTITIONED " + "BY (COL2) CLUSTERED BY (COL3) INTO 3 BUCKETS") - spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT COL1, COL2, COL3 " + - "FROM view1 CLUSTER BY COL3") + spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT COL1, COL2, COL3 FROM view1") val identifier = TableIdentifier("tbl2", Some("default")) val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString val expectedSchema = StructType(Seq( From 538fea99ed2158316d89f64ce397c4791fbed1f3 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 4 Sep 2018 11:34:32 +0800 Subject: [PATCH 6/9] revise test --- .../apache/spark/sql/test/DataFrameReaderWriterSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 7f31372c89b1d..534cdf14bc1ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -831,7 +831,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be spark.sql("CREATE TABLE tbl2(COL1 long, COL2 int, COL3 int) USING parquet PARTITIONED " + "BY (COL2) CLUSTERED BY (COL3) INTO 3 BUCKETS") spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT COL1, COL2, COL3 FROM view1") - val identifier = TableIdentifier("tbl2", Some("default")) + val identifier = TableIdentifier("tbl2") val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString val expectedSchema = 
StructType(Seq( StructField("COL1", LongType, true), @@ -850,7 +850,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be df.write.format("parquet").saveAsTable("tbl") spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") spark.sql("CREATE TABLE tbl2 USING parquet AS SELECT ID FROM view1") - val identifier = TableIdentifier("tbl2", Some("default")) + val identifier = TableIdentifier("tbl2") val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) assert(spark.read.parquet(location).schema == expectedSchema) @@ -867,7 +867,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be spark.sql("CREATE VIEW view1 AS SELECT * FROM tbl") spark.sql("CREATE TABLE tbl2 USING parquet PARTITIONED BY (COL2) " + "CLUSTERED BY (COL3) INTO 3 BUCKETS AS SELECT COL1, COL2, COL3 FROM view1") - val identifier = TableIdentifier("tbl2", Some("default")) + val identifier = TableIdentifier("tbl2") val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString val expectedSchema = StructType(Seq( StructField("COL1", LongType, true), From 45d2a20fd9b13edabc3a36d3fec65b6bc7a0463a Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 4 Sep 2018 13:53:39 +0800 Subject: [PATCH 7/9] revise --- .../org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala | 2 +- .../org/apache/spark/sql/hive/execution/HiveDDLSuite.scala | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 534cdf14bc1ba..237872585e11d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -813,7 +813,7 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") spark.sql("CREATE TABLE tbl2(ID long) USING parquet") spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1") - val identifier = TableIdentifier("tbl2", Some("default")) + val identifier = TableIdentifier("tbl2") val location = spark.sessionState.catalog.getTableMetadata(identifier).location.toString val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) assert(spark.read.parquet(location).schema == expectedSchema) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 3bf3d2c7c6d8e..f07b4a1560ed9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -776,7 +776,6 @@ class HiveDDLSuite spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") spark.sql(s"CREATE TABLE tbl2(ID long) location '${path.toURI}'") spark.sql(s"INSERT OVERWRITE DIRECTORY '${path.toURI}' SELECT ID FROM view1") - // spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1") checkAnswer(spark.table("tbl2"), Seq(Row(4))) } } From 3ca072d18474d1536c3ac729fe1e0b79cd855cca Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 4 Sep 2018 16:28:50 +0800 Subject: [PATCH 8/9] fix CreateHiveTableAsSelectCommand output schema and verify it in test case --- 
.../command/DataWritingCommand.scala | 15 +++++ .../CreateHiveTableAsSelectCommand.scala | 3 +- .../sql/hive/execution/HiveDDLSuite.scala | 62 +++++++++++-------- 3 files changed, 52 insertions(+), 28 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index a1bb5af1ab723..0a185b8472060 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker import org.apache.spark.sql.execution.datasources.FileFormatWriter import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration /** @@ -74,4 +75,18 @@ object DataWritingCommand { attr.withName(outputName) } } + + /** + * Returns schema of logical plan with provided names. + * The length of provided names should be the same of the length of [[LogicalPlan.schema]]. + */ + def logicalPlanSchemaWithNames( + query: LogicalPlan, + names: Seq[String]): StructType = { + assert(query.schema.length == names.length, + "The length of provided names doesn't match the length of query schema.") + StructType(query.schema.zip(names).map { case (structField, outputName) => + structField.copy(name = outputName) + }) + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala index 9b9f5d2a3bb0c..0eb2f0de0acd9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala @@ -69,7 +69,8 @@ case class CreateHiveTableAsSelectCommand( // add the relation into catalog, just in case of failure occurs while data // processing. assert(tableDesc.schema.isEmpty) - catalog.createTable(tableDesc.copy(schema = query.schema), ignoreIfExists = false) + val schema = DataWritingCommand.logicalPlanSchemaWithNames(query, outputColumnNames) + catalog.createTable(tableDesc.copy(schema = schema), ignoreIfExists = false) try { // Read back the metadata of the table which was created just now. 
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index f07b4a1560ed9..6eab8ecc53b6d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -755,41 +755,49 @@ class HiveDDLSuite } test("Insert overwrite Hive table should output correct schema") { - withTable("tbl", "tbl2") { - withView("view1") { - spark.sql("CREATE TABLE tbl(id long)") - spark.sql("INSERT OVERWRITE TABLE tbl SELECT 4") - spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") - spark.sql("CREATE TABLE tbl2(ID long)") - spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1") - checkAnswer(spark.table("tbl2"), Seq(Row(4))) - } - } - } - - test("Insert into Hive directory should output correct schema") { - withTable("tbl") { - withView("view1") { - withTempPath { path => + withSQLConf(CONVERT_METASTORE_PARQUET.key -> "false") { + withTable("tbl", "tbl2") { + withView("view1") { spark.sql("CREATE TABLE tbl(id long)") - spark.sql("INSERT OVERWRITE TABLE tbl SELECT 4") + spark.sql("INSERT OVERWRITE TABLE tbl VALUES 4") spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") - spark.sql(s"CREATE TABLE tbl2(ID long) location '${path.toURI}'") - spark.sql(s"INSERT OVERWRITE DIRECTORY '${path.toURI}' SELECT ID FROM view1") - checkAnswer(spark.table("tbl2"), Seq(Row(4))) + withTempPath { path => + sql( + s""" + |CREATE TABLE tbl2(ID long) USING hive + |OPTIONS(fileFormat 'parquet') + |LOCATION '${path.toURI}' + """.stripMargin) + spark.sql("INSERT OVERWRITE TABLE tbl2 SELECT ID FROM view1") + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(path.toString).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), Seq(Row(4))) + } } } } } test("Create Hive table as select should output correct schema") { - withTable("tbl", "tbl2") { - withView("view1") { - spark.sql("CREATE TABLE tbl(id long)") - spark.sql("INSERT OVERWRITE TABLE tbl SELECT 4") - spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") - spark.sql("CREATE TABLE tbl2 AS SELECT ID FROM view1") - checkAnswer(spark.table("tbl2"), Seq(Row(4))) + withSQLConf(CONVERT_METASTORE_PARQUET.key -> "false") { + withTable("tbl", "tbl2") { + withView("view1") { + spark.sql("CREATE TABLE tbl(id long)") + spark.sql("INSERT OVERWRITE TABLE tbl VALUES 4") + spark.sql("CREATE VIEW view1 AS SELECT id FROM tbl") + withTempPath { path => + sql( + s""" + |CREATE TABLE tbl2 USING hive + |OPTIONS(fileFormat 'parquet') + |LOCATION '${path.toURI}' + |AS SELECT ID FROM view1 + """.stripMargin) + val expectedSchema = StructType(Seq(StructField("ID", LongType, true))) + assert(spark.read.parquet(path.toString).schema == expectedSchema) + checkAnswer(spark.table("tbl2"), Seq(Row(4))) + } + } } } } From 4590c9837026e820d7d91300a7ab3f87a668755c Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Wed, 5 Sep 2018 11:39:18 +0800 Subject: [PATCH 9/9] revise --- .../datasources/InsertIntoHadoopFsRelationCommand.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 9dc69104cd5f1..484942d35c857 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -62,8 +62,8 @@ case class InsertIntoHadoopFsRelationCommand( override def run(sparkSession: SparkSession, child: SparkPlan): Seq[Row] = { // Most formats don't do well with duplicate columns, so lets not allow that - SchemaUtils.checkSchemaColumnNameDuplication( - query.schema, + SchemaUtils.checkColumnNameDuplication( + outputColumnNames, s"when inserting into $outputPath", sparkSession.sessionState.conf.caseSensitiveAnalysis)
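
The core idea running through this series: the optimizer may rewrite attribute names of the input query (for example when the data is selected through a view whose columns differ only in case), so write commands now carry the analyzed plan's column names as plain strings (`outputColumnNames`) and re-apply them to the optimized plan's output attributes just before writing. The sketch below is a minimal, self-contained illustration of that renaming step only; it mirrors the behavior of `DataWritingCommand.logicalPlanOutputWithNames` from the patch, but `Attr`, `OutputNames`, and `OutputNamesDemo` are simplified stand-ins invented for this example rather than Spark's Catalyst classes.

```scala
// Minimal sketch (not part of the patch): re-apply the declared column names to a
// plan's output attributes, as DataWritingCommand.logicalPlanOutputWithNames does.
// `Attr` is a simplified stand-in for Catalyst's Attribute.
case class Attr(name: String, dataType: String) {
  def withName(newName: String): Attr = copy(name = newName)
}

object OutputNames {
  // Zip the optimized plan's attributes with the names captured at analysis time.
  def withNames(output: Seq[Attr], names: Seq[String]): Seq[Attr] = {
    require(output.length == names.length,
      "The length of provided names doesn't match the length of output attributes.")
    output.zip(names).map { case (attr, outputName) => attr.withName(outputName) }
  }
}

object OutputNamesDemo extends App {
  // The target table was declared with an upper-case column, but the optimizer
  // resolved the query through a view whose column is lower-case.
  val optimizedOutput = Seq(Attr("id", "bigint"))
  val declaredNames   = Seq("ID")
  // Renaming before the write preserves the declared schema: List(Attr(ID,bigint))
  println(OutputNames.withNames(optimizedOutput, declaredNames))
}
```

Passing `Seq[String]` instead of `Seq[Attribute]` is the deliberate design choice here: attribute identity (expr IDs, data types) should come from the plan that is actually executed, while only the user-visible names are pinned to what the analyzer saw.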