diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
index e2a96267082b8..ddeca123cafd0 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
@@ -232,6 +232,17 @@ object FileCommitProtocol extends Logging {
   def getStagingDir(path: String, jobId: String): Path = {
     new Path(path, ".spark-staging-" + jobId)
   }
+
+  /**
+   * Staging directory for overwrite jobs: a hidden sibling of `path` (it shares `path`'s parent)
+   * named `.<last component of path>-spark-staging-<jobId>`.
+   *
+   * NOTE(review): `Path.getParent` returns null for a root path -- confirm `path` is never one.
+   */
+  def overwriteStagingDir(path: String, jobId: String): Path = {
+    val outputPath = new Path(path)
+    new Path(outputPath.getParent, s".${outputPath.getName}-spark-staging-$jobId")
+  }
 }
 
 /**
diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
index 3a24da98ecc24..2fe6f128e340d 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala
@@ -236,7 +236,9 @@ class HadoopMapReduceCommitProtocol(
         }
       }
 
-      fs.delete(stagingDir, true)
+      // NOTE(review): deleteOnExit defers removal until the FileSystem is closed, so staging
+      // dirs linger for the lifetime of long-running applications -- confirm this is intended.
+      fs.deleteOnExit(stagingDir)
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 0839a2f487511..b7e27ab2698bb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -4351,6 +4351,9 @@ class SQLConf extends Serializable with Logging {
 
   def fileCommitProtocolClass: String = getConf(SQLConf.FILE_COMMIT_PROTOCOL_CLASS)
 
+  def useOverwriteFileCommitProtocol: Boolean = fileCommitProtocolClass ==
+    "org.apache.spark.sql.execution.datasources.SQLOverwriteHadoopMapReduceCommitProtocol"
+
   def parallelPartitionDiscoveryThreshold: Int =
     getConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
index 794d90b242c1b..9d0fbaeef83b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala
@@ -76,6 +76,7 @@ object FileFormatWriter extends Logging {
    *    processing statistics.
    * @return The set of all partition paths that were updated during this write job.
    */
+  // scalastyle:off argcount
   def write(
       sparkSession: SparkSession,
       plan: SparkPlan,
@@ -87,7 +88,9 @@ object FileFormatWriter extends Logging {
       bucketSpec: Option[BucketSpec],
       statsTrackers: Seq[WriteJobStatsTracker],
       options: Map[String, String],
-      numStaticPartitionCols: Int = 0)
+      numStaticPartitionCols: Int = 0,
+      preCommitJob: Option[() => Unit] = None)
     : Set[String] = {
+    // scalastyle:on argcount
     require(partitionColumns.size >= numStaticPartitionCols)
 
@@ -234,6 +237,7 @@ object FileFormatWriter extends Logging {
 
       val commitMsgs = ret.map(_.commitMsg)
 
+      preCommitJob.foreach(hook => hook())
       logInfo(s"Start to commit write Job ${description.uuid}.")
       val (_, duration) = Utils.timeTakenMs { committer.commitJob(job, commitMsgs) }
       logInfo(s"Write Job ${description.uuid} committed. Elapsed time: $duration ms.")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
index 41b55a3b6e936..e3d180b9b3250 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.execution.datasources
 
+import java.io.IOException
+
 import org.apache.hadoop.fs.{FileSystem, Path}
 
 import org.apache.spark.internal.io.FileCommitProtocol
@@ -84,7 +86,8 @@ case class InsertIntoHadoopFsRelationCommand(
       outputColumnNames,
       s"when inserting into $outputPath",
       sparkSession.sessionState.conf.caseSensitiveAnalysis)
-
+    val useOverwriteFileCommitProtocol =
+      sparkSession.sessionState.conf.useOverwriteFileCommitProtocol
     val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options)
     val fs = outputPath.getFileSystem(hadoopConf)
     val qualifiedOutputPath = outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory)
@@ -125,6 +128,9 @@ case class InsertIntoHadoopFsRelationCommand(
       case (SaveMode.Overwrite, true) =>
         if (ifPartitionNotExists && matchingPartitions.nonEmpty) {
           false
+        } else if (useOverwriteFileCommitProtocol) {
+          // For `SQLOverwriteHadoopMapReduceCommitProtocol`, do not delete directories first.
+          true
         } else if (dynamicPartitionOverwrite) {
           // For dynamic partition overwrite, do not delete partition directories ahead.
           true
@@ -168,13 +174,27 @@ case class InsertIntoHadoopFsRelationCommand(
 
       // For dynamic partition overwrite, FileOutputCommitter's output path is staging path, files
       // will be renamed from staging path to final output path during commit job
-      val committerOutputPath = if (dynamicPartitionOverwrite) {
+      val committerOutputPath = if (useOverwriteFileCommitProtocol && mode == SaveMode.Overwrite) {
+        FileCommitProtocol.overwriteStagingDir(outputPath.toString, jobId)
+          .makeQualified(fs.getUri, fs.getWorkingDirectory)
+      } else if (dynamicPartitionOverwrite) {
         FileCommitProtocol.getStagingDir(outputPath.toString, jobId)
           .makeQualified(fs.getUri, fs.getWorkingDirectory)
       } else {
         qualifiedOutputPath
       }
 
+      // When `dynamicPartitionOverwrite` is true, `SQLOverwriteHadoopMapReduceCommitProtocol`
+      // will execute as the method `dynamicPartitionOverwrite`, so Spark doesn't need to delete
+      // matching partitions here.
+      val preCommitJob = if (useOverwriteFileCommitProtocol &&
+        mode == SaveMode.Overwrite && !dynamicPartitionOverwrite) {
+        Some(() =>
+          deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer))
+      } else {
+        None
+      }
+
       val updatedPartitionPaths =
         FileFormatWriter.write(
           sparkSession = sparkSession,
@@ -188,8 +208,44 @@ case class InsertIntoHadoopFsRelationCommand(
           bucketSpec = bucketSpec,
           statsTrackers = Seq(basicWriteJobStatsTracker(hadoopConf)),
           options = options,
-          numStaticPartitionCols = staticPartitions.size)
+          numStaticPartitionCols = staticPartitions.size,
+          preCommitJob = preCommitJob)
 
+      if (useOverwriteFileCommitProtocol && mode == SaveMode.Overwrite) {
+        if (partitionColumns.isEmpty) {
+          // Non-partition table overwrite should rename staging dir to output path
+          if (!fs.rename(committerOutputPath, qualifiedOutputPath)) {
+            throw new IOException(s"Failed to rename $committerOutputPath to $qualifiedOutputPath")
+          }
+        } else if (staticPartitions.size == partitionColumns.size) {
+          // For static partition overwrite, if custom partition path is not empty, result data
+          // has been written to target custom partition path during commitJob.
+          if (!customPartitionLocations.contains(staticPartitions)) {
+            val stagingStaticPartitionPath = committerOutputPath.suffix(staticPartitionPrefix)
+            val targetLocation = qualifiedOutputPath.suffix(staticPartitionPrefix)
+            if (!fs.exists(targetLocation.getParent)) {
+              fs.mkdirs(targetLocation.getParent)
+            }
+            if (!fs.rename(stagingStaticPartitionPath, targetLocation)) {
+              throw new IOException(s"Failed to rename $stagingStaticPartitionPath to " +
+                s"$targetLocation")
+            }
+          }
+        } else if (dynamicPartitionOverwrite) {
+          // Same behavior as default, do nothing here.
+        } else {
+          // STATIC mode dynamic partition overwrite
+          val targetLocation = qualifiedOutputPath.suffix(staticPartitionPrefix)
+          if (!fs.exists(targetLocation.getParent)) {
+            fs.mkdirs(targetLocation.getParent)
+          }
+          val stagingStaticPartitionPath = committerOutputPath.suffix(staticPartitionPrefix)
+          if (!fs.rename(stagingStaticPartitionPath, targetLocation)) {
+            throw new IOException(s"Failed to rename $stagingStaticPartitionPath to " +
+              s"$targetLocation")
+          }
+        }
+      }
 
       // update metastore partition metadata
       if (updatedPartitionPaths.isEmpty && staticPartitions.nonEmpty
@@ -218,6 +274,20 @@ case class InsertIntoHadoopFsRelationCommand(
     Seq.empty[Row]
   }
 
+  /**
+   * Path suffix built from the static partition columns, in `partitionColumns` order (each
+   * segment comes from `getPartitionPathString`); empty when no static partitions are given.
+   */
+  private def staticPartitionPrefix: String = {
+    if (staticPartitions.nonEmpty) {
+      "/" + partitionColumns.flatMap { p =>
+        staticPartitions.get(p.name).map(getPartitionPathString(p.name, _))
+      }.mkString("/")
+    } else {
+      ""
+    }
+  }
+
   /**
    * Deletes all partition files that match the specified static prefix. Partitions with custom
    * locations are also cleared based on the custom locations map given to this class.
@@ -227,13 +297,6 @@ case class InsertIntoHadoopFsRelationCommand(
       qualifiedOutputPath: Path,
       customPartitionLocations: Map[TablePartitionSpec, String],
       committer: FileCommitProtocol): Unit = {
-    val staticPartitionPrefix = if (staticPartitions.nonEmpty) {
-      "/" + partitionColumns.flatMap { p =>
-        staticPartitions.get(p.name).map(getPartitionPathString(p.name, _))
-      }.mkString("/")
-    } else {
-      ""
-    }
     // first clear the path determined by the static partition keys (e.g. /table/foo=1)
     val staticPrefixPath = qualifiedOutputPath.suffix(staticPartitionPrefix)
     if (fs.exists(staticPrefixPath) && !committer.deleteWithJob(fs, staticPrefixPath, true)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLOverwriteHadoopMapReduceCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLOverwriteHadoopMapReduceCommitProtocol.scala
new file mode 100644
index 0000000000000..e6799caee6a71
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLOverwriteHadoopMapReduceCommitProtocol.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.internal.io.FileCommitProtocol.overwriteStagingDir
+import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol
+
+/**
+ * A variant of [[HadoopMapReduceCommitProtocol]] that is used for overwrite save mode.
+ */
+class SQLOverwriteHadoopMapReduceCommitProtocol(
+    jobId: String,
+    path: String,
+    dynamicPartitionOverwrite: Boolean)
+  extends HadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite) {
+  // Override stagingDir so the same staging dir is used when dynamicPartitionOverwrite is true.
+  @transient override lazy val stagingDir: Path = overwriteStagingDir(path, jobId)
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SQLOverwriteHadoopMapReduceCommitProtocolSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SQLOverwriteHadoopMapReduceCommitProtocolSuite.scala
new file mode 100644
index 0000000000000..d8e65bd95ee1d
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SQLOverwriteHadoopMapReduceCommitProtocolSuite.scala
@@ -0,0 +1,208 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode._
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.util.Utils
+
+class SQLOverwriteHadoopMapReduceCommitProtocolSuite extends QueryTest with SharedSparkSession {
+
+  import testImplicits._
+
+  test("SPARK-36571: Check staging dir") {
+    val path = new Path(Utils.createTempDir().toString)
+    val commitProtocol =
+      new SQLOverwriteHadoopMapReduceCommitProtocol("000001", path.toString, true)
+    assert(commitProtocol.stagingDir.getParent == path.getParent)
+    assert(commitProtocol.stagingDir.getName == s".${path.getName}-spark-staging-000001")
+  }
+
+  test("SPARK-36571: Non-partitioned table insert overwrite") {
+    withSQLConf(SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key ->
+      classOf[SQLOverwriteHadoopMapReduceCommitProtocol].getName) {
+      withTable("t") {
+        withTempView("temp") {
+          sql(
+            s"""
+              | CREATE TABLE t(c1 int, p1 int) USING PARQUET
+            """.stripMargin)
+
+          val df = Seq((1, 2), (1, 2))
+            .toDF("c1", "p1").repartition(1)
+          df.createOrReplaceTempView("temp")
+          sql("INSERT OVERWRITE TABLE t SELECT * FROM temp")
+          checkAnswer(sql("SELECT * FROM t"), df)
+
+          // test can delete data correctly
+          sql("INSERT INTO TABLE t SELECT * FROM temp")
+          checkAnswer(sql("SELECT * FROM t"),
+            Row(1, 2) :: Row(1, 2) :: Row(1, 2) :: Row(1, 2) :: Nil)
+
+          // test can delete data correctly
+          sql("INSERT OVERWRITE TABLE t SELECT * FROM temp")
+          checkAnswer(sql("SELECT * FROM t"), df)
+
+        }
+      }
+    }
+  }
+
+  test("SPARK-36571: Partitioned table insert single partition") {
+    withSQLConf(SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key ->
+      classOf[SQLOverwriteHadoopMapReduceCommitProtocol].getName) {
+      withTable("t") {
+        withTempView("temp") {
+          sql(
+            s"""
+              | CREATE TABLE t(c1 int, p1 string, p2 string)
+              | USING PARQUET
+              | PARTITIONED BY (p1, p2)
+            """.stripMargin)
+
+          val df = Seq(1, 2, 3).toDF("c1")
+          df.createOrReplaceTempView("temp")
+          sql("INSERT OVERWRITE TABLE t PARTITION (p1 = 1, p2 = 1) SELECT * FROM temp")
+          checkAnswer(sql("SELECT c1 FROM t WHERE p1 = 1 AND p2 = 1"), df)
+
+          // test won't delete other partitions data
+          sql("INSERT OVERWRITE TABLE t PARTITION (p1 = 2, p2 = 2) SELECT * FROM temp")
+          checkAnswer(sql("SELECT c1 FROM t WHERE p1 = 2 AND p2 = 2"), df)
+          checkAnswer(sql("SELECT c1 FROM t WHERE p1 = 1 AND p2 = 1"), df)
+
+          // test can delete data correctly
+          sql("INSERT OVERWRITE TABLE t PARTITION (p1 = 1, p2 = 1) SELECT * FROM temp")
+          checkAnswer(sql("SELECT c1 FROM t WHERE p1 = 1 AND p2 = 1"), df)
+
+          // test can delete data correctly
+          sql("INSERT INTO TABLE t PARTITION (p1 = 1, p2 = 1) SELECT * FROM temp")
+          checkAnswer(sql("SELECT c1 FROM t WHERE p1 = 1 AND p2 = 1"),
+            Row(1) :: Row(2) :: Row(3) :: Row(1) :: Row(2) :: Row(3) :: Nil)
+
+          // customized partition location
+          withTempPath { path =>
+            sql(
+              s"""
+                |ALTER TABLE t ADD PARTITION (p1=3, p2=3)
+                |LOCATION '$path'
+                |""".stripMargin)
+            sql("INSERT OVERWRITE TABLE t PARTITION (p1 = 3, p2 = 3) SELECT 3")
+            checkAnswer(sql("SELECT c1 FROM t WHERE p1 = 3 AND p2 = 3"), Row(3) :: Nil)
+            sql("INSERT INTO TABLE t PARTITION (p1 = 3, p2 = 3) SELECT 3")
+            checkAnswer(sql("SELECT c1 FROM t WHERE p1 = 3 AND p2 = 3"), Row(3) :: Row(3) :: Nil)
+          }
+        }
+      }
+    }
+  }
+
+  test("SPARK-36571: Dynamic partition overwrite - DYNAMIC mode") {
+    withSQLConf(SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key ->
+      classOf[SQLOverwriteHadoopMapReduceCommitProtocol].getName,
+      SQLConf.PARTITION_OVERWRITE_MODE.key -> DYNAMIC.toString) {
+      withTable("t") {
+        withTempView("temp") {
+          sql(
+            s"""
+              | CREATE TABLE t(c1 int, p1 string, p2 string)
+              | USING PARQUET
+              | PARTITIONED BY (p1, p2)
+            """.stripMargin)
+
+          val df = Seq((1, "1", "1"), (2, "2", "2"), (3, "3", "3")).toDF("c1", "p1", "p2")
+          df.createOrReplaceTempView("temp")
+          sql("INSERT OVERWRITE TABLE t SELECT * FROM temp")
+          checkAnswer(sql("SELECT * FROM t"), df)
+          checkAnswer(sql("SELECT c1 FROM t WHERE p1 = 1 AND p2 = 1"), Row(1) :: Nil)
+
+          Seq((3, 3, 3), (4, 4, 4), (5, 5, 5))
+            .toDF("c1", "p1", "p2").createOrReplaceTempView("temp")
+          // test won't delete other partitions data
+          sql("INSERT OVERWRITE TABLE t SELECT * FROM temp")
+          checkAnswer(sql("SELECT * FROM t"),
+            Row(1, "1", "1") :: Row(2, "2", "2") :: Row(3, "3", "3") ::
+              Row(4, "4", "4") :: Row(5, "5", "5") :: Nil)
+          checkAnswer(sql("SELECT c1 FROM t WHERE p1 = 5 AND p2 = 5"), Row(5) :: Nil)
+
+          // customized partition location
+          withTempPath { path =>
+            sql(
+              s"""
+                |ALTER TABLE t ADD PARTITION (p1=6, p2=6)
+                |LOCATION '$path'
+                |""".stripMargin)
+            Seq((5, 5, 5), (6, 6, 6))
+              .toDF("c1", "p1", "p2").createOrReplaceTempView("temp")
+            sql("INSERT OVERWRITE TABLE t SELECT * FROM temp")
+            checkAnswer(sql("SELECT * FROM t"),
+              Row(1, "1", "1") :: Row(2, "2", "2") :: Row(3, "3", "3") ::
+                Row(4, "4", "4") :: Row(5, "5", "5") :: Row(6, "6", "6") :: Nil)
+          }
+        }
+      }
+    }
+  }
+
+  test("SPARK-36571: Dynamic partition overwrite - STATIC mode") {
+    withSQLConf(SQLConf.FILE_COMMIT_PROTOCOL_CLASS.key ->
+      classOf[SQLOverwriteHadoopMapReduceCommitProtocol].getName,
+      SQLConf.PARTITION_OVERWRITE_MODE.key -> STATIC.toString) {
+      withTable("t") {
+        withTempView("temp") {
+          sql(
+            s"""
+              | CREATE TABLE t(c1 int, p1 string, p2 string)
+              | USING PARQUET
+              | PARTITIONED BY (p1, p2)
+            """.stripMargin)
+
+          val df = Seq((1, "1", "1"), (2, "2", "2"), (3, "3", "3")).toDF("c1", "p1", "p2")
+          df.createOrReplaceTempView("temp")
+          sql("INSERT OVERWRITE TABLE t SELECT * FROM temp")
+          checkAnswer(sql("SELECT * FROM t"), df)
+          checkAnswer(sql("SELECT c1 FROM t WHERE p1 = 1 AND p2 = 1"), Row(1) :: Nil)
+
+          Seq((3, 3, 3), (4, 4, 4), (5, 5, 5))
+            .toDF("c1", "p1", "p2").createOrReplaceTempView("temp")
+          // test won't delete other partitions data
+          sql("INSERT OVERWRITE TABLE t SELECT * FROM temp")
+          checkAnswer(sql("SELECT * FROM t"),
+            Row(3, "3", "3") :: Row(4, "4", "4") :: Row(5, "5", "5") :: Nil)
+          checkAnswer(sql("SELECT c1 FROM t WHERE p1 = 5 AND p2 = 5"), Row(5) :: Nil)
+
+          // customized partition location
+          withTempPath { path =>
+            sql(
+              s"""
+                |ALTER TABLE t ADD PARTITION (p1=6, p2=6)
+                |LOCATION '$path'
+                |""".stripMargin)
+            Seq((5, 5, 5), (6, 6, 6))
+              .toDF("c1", "p1", "p2").createOrReplaceTempView("temp")
+            sql("INSERT OVERWRITE TABLE t SELECT * FROM temp")
+            checkAnswer(sql("SELECT * FROM t"), Row(5, "5", "5") :: Row(6, "6", "6") :: Nil)
+          }
+        }
+      }
+    }
+  }
+}