From 91014e2d0d6ceb81210b157716142d797073ae52 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Thu, 15 Apr 2021 17:55:11 -0700 Subject: [PATCH 01/10] Allow concurrent writers when writing dynamic partitions and bucketed table --- .../apache/spark/sql/internal/SQLConf.scala | 10 + .../datasources/BasicWriteStatsTracker.scala | 36 ++- .../datasources/FileFormatDataWriter.scala | 213 ++++++++++++++++-- .../datasources/FileFormatWriter.scala | 42 +++- .../datasources/WriteStatsTracker.scala | 29 +-- .../datasources/v2/FileWriterFactory.scala | 2 +- .../sql/test/DataFrameReaderWriterSuite.scala | 28 +++ 7 files changed, 290 insertions(+), 70 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 04e740039f005..f61c5f229ce34 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3150,6 +3150,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + val MAX_CONCURRENT_OUTPUT_WRITERS = buildConf("spark.sql.maxConcurrentOutputWriters") + .internal() + .doc("Maximum number of output writers to use concurrently. If number of writers needed " + + "exceeds this limit, task will sort rest of output then writing them.") + .version("3.2.0") + .intConf + .createWithDefault(0) + /** * Holds information about keys that have been deprecated. * @@ -3839,6 +3847,8 @@ class SQLConf extends Serializable with Logging { def decorrelateInnerQueryEnabled: Boolean = getConf(SQLConf.DECORRELATE_INNER_QUERY_ENABLED) + def maxConcurrentOutputWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_WRITERS) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index b6b07de8a5d17..f1c809cea6e58 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -53,11 +53,11 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) private[this] val partitions: mutable.ArrayBuffer[InternalRow] = mutable.ArrayBuffer.empty private[this] var numFiles: Int = 0 - private[this] var submittedFiles: Int = 0 + private[this] var numSubmittedFiles: Int = 0 private[this] var numBytes: Long = 0L private[this] var numRows: Long = 0L - private[this] var curFile: Option[String] = None + private[this] val submittedFiles = mutable.HashSet[String]() /** * Get the size of the file expected to have been written by a worker. 
@@ -134,23 +134,20 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) partitions.append(partitionValues) } - override def newBucket(bucketId: Int): Unit = { - // currently unhandled + override def newFile(filePath: String): Unit = { + submittedFiles += filePath + numSubmittedFiles += 1 } - override def newFile(filePath: String): Unit = { - statCurrentFile() - curFile = Some(filePath) - submittedFiles += 1 + override def closeFile(filePath: String): Unit = { + getFileStats(filePath) + submittedFiles.remove(filePath) } - private def statCurrentFile(): Unit = { - curFile.foreach { path => - getFileSize(path).foreach { len => - numBytes += len - numFiles += 1 - } - curFile = None + private def getFileStats(filePath: String): Unit = { + getFileSize(filePath).foreach { len => + numBytes += len + numFiles += 1 } } @@ -159,7 +156,8 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) } override def getFinalStats(): WriteTaskStats = { - statCurrentFile() + submittedFiles.foreach(getFileStats) + submittedFiles.clear() // Reports bytesWritten and recordsWritten to the Spark output metrics. Option(TaskContext.get()).map(_.taskMetrics().outputMetrics).foreach { outputMetrics => @@ -167,12 +165,12 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) outputMetrics.setRecordsWritten(numRows) } - if (submittedFiles != numFiles) { - logInfo(s"Expected $submittedFiles files, but only saw $numFiles. " + + if (numSubmittedFiles != numFiles) { + logInfo(s"Expected $numSubmittedFiles files, but only saw $numFiles. " + "This could be due to the output format not writing empty files, " + "or files being not immediately visible in the filesystem.") } - BasicWriteTaskStats(partitions.toSeq, numFiles, numBytes, numRows) + BasicWriteTaskStats(partitions, numFiles, numBytes, numRows) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index 6de9b1d7cea4b..bd1b43cd205bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} +import org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec import org.apache.spark.sql.types.StringType import org.apache.spark.util.SerializableConfiguration @@ -47,6 +48,7 @@ abstract class FileFormatDataWriter( protected val MAX_FILE_COUNTER: Int = 1000 * 1000 protected val updatedPartitions: mutable.Set[String] = mutable.Set[String]() protected var currentWriter: OutputWriter = _ + protected var currentPath: String = _ /** Trackers for computing various statistics on the data as it's being written out. 
*/ protected val statsTrackers: Seq[WriteTaskStatsTracker] = @@ -56,6 +58,7 @@ abstract class FileFormatDataWriter( if (currentWriter != null) { try { currentWriter.close() + statsTrackers.foreach(_.closeFile(currentPath)) } finally { currentWriter = null } @@ -115,7 +118,7 @@ class SingleDirectoryDataWriter( releaseResources() val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) - val currentPath = committer.newTaskTempFile( + currentPath = committer.newTaskTempFile( taskAttemptContext, None, f"-c$fileCounter%03d" + ext) @@ -150,7 +153,8 @@ class SingleDirectoryDataWriter( class DynamicPartitionDataWriter( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol) + committer: FileCommitProtocol, + concurrentOutputWriterSpec: Option[ConcurrentOutputWriterSpec]) extends FileFormatDataWriter(description, taskAttemptContext, committer) { /** Flag saying whether or not the data to be written out is partitioned. */ @@ -167,8 +171,14 @@ class DynamicPartitionDataWriter( private var fileCounter: Int = _ private var recordsInFile: Long = _ - private var currentPartitionValues: Option[UnsafeRow] = None - private var currentBucketId: Option[Int] = None + + private var mode: WriterMode = concurrentOutputWriterSpec match { + case Some(_) => ConcurrentWriterBeforeSort + case None => SingleWriter + } + private val concurrentWriters = + mutable.HashMap[WriterIndex, ConcurrentWriterStatus]() + private val currentWriterId = WriterIndex(None, None) /** Extracts the partition values out of an input row. */ private lazy val getPartitionValues: InternalRow => UnsafeRow = { @@ -204,6 +214,32 @@ class DynamicPartitionDataWriter( private val getOutputRow = UnsafeProjection.create(description.dataColumns, description.allColumns) + override protected def releaseResources(): Unit = { + mode match { + case SingleWriter => + if (currentWriter != null) { + try { + currentWriter.close() + statsTrackers.foreach(_.closeFile(currentPath)) + } finally { + currentWriter = null + } + } + case _ => + currentWriter = null + concurrentWriters.values.foreach(status => { + if (status.outputWriter != null) { + try { + status.outputWriter.close() + } finally { + status.outputWriter = null + } + } + }) + concurrentWriters.clear() + } + } + /** * Opens a new OutputWriter given a partition key and/or a bucket id. 
* If bucket id is specified, we will append it to the end of the file name, but before the @@ -212,10 +248,17 @@ class DynamicPartitionDataWriter( * @param partitionValues the partition which all tuples being written by this `OutputWriter` * belong to * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to + * @param closeCurrentWriter close and release resource for current writer */ - private def newOutputWriter(partitionValues: Option[InternalRow], bucketId: Option[Int]): Unit = { + private def newOutputWriter( + partitionValues: Option[InternalRow], + bucketId: Option[Int], + closeCurrentWriter: Boolean): Unit = { + recordsInFile = 0 - releaseResources() + if (closeCurrentWriter) { + super.releaseResources() + } val partDir = partitionValues.map(getPartitionPath(_)) partDir.foreach(updatedPartitions.add) @@ -229,7 +272,7 @@ class DynamicPartitionDataWriter( val customPath = partDir.flatMap { dir => description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) } - val currentPath = if (customPath.isDefined) { + currentPath = if (customPath.isDefined) { committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) } else { committer.newTaskTempFile(taskAttemptContext, partDir, ext) @@ -247,20 +290,25 @@ class DynamicPartitionDataWriter( val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None - if (currentPartitionValues != nextPartitionValues || currentBucketId != nextBucketId) { + if (currentWriterId.partitionValues != nextPartitionValues || + currentWriterId.bucketId != nextBucketId) { // See a new partition or bucket - write to a new partition dir (or a new bucket file). - if (isPartitioned && currentPartitionValues != nextPartitionValues) { - currentPartitionValues = Some(nextPartitionValues.get.copy()) - statsTrackers.foreach(_.newPartition(currentPartitionValues.get)) - } + updateCurrentWriterStatus() + if (isBucketed) { - currentBucketId = nextBucketId - statsTrackers.foreach(_.newBucket(currentBucketId.get)) + currentWriterId.bucketId = nextBucketId + } + if (isPartitioned && currentWriterId.partitionValues != nextPartitionValues) { + currentWriterId.partitionValues = Some(nextPartitionValues.get.copy()) + if (mode == SingleWriter || !concurrentWriters.contains(currentWriterId)) { + statsTrackers.foreach(_.newPartition(currentWriterId.partitionValues.get)) + } } - fileCounter = 0 - newOutputWriter(currentPartitionValues, currentBucketId) - } else if (description.maxRecordsPerFile > 0 && + getOrNewOutputWriter() + } + + if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { // Exceeded the threshold in terms of the number of records per file. // Create a new file by increasing the file counter. @@ -268,13 +316,142 @@ class DynamicPartitionDataWriter( assert(fileCounter < MAX_FILE_COUNTER, s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") - newOutputWriter(currentPartitionValues, currentBucketId) + newOutputWriter(currentWriterId.partitionValues, currentWriterId.bucketId, true) } val outputRow = getOutputRow(record) currentWriter.write(outputRow) statsTrackers.foreach(_.newRow(outputRow)) recordsInFile += 1 } + + /** + * Dedicated write code path when enabling concurrent writers. + * + * The process has the following step: + * - Step 1: Maintain a map of output writers per each partition and/or bucket columns. 
+ * Keep all writers open and write rows one by one. + * - Step 2: If number of concurrent writers exceeds limit, sort rest of rows. Write rows + * one by one, and eagerly close the writer when finishing each partition and/or + * bucket. + */ + def writeWithIterator(iterator: Iterator[InternalRow]): Unit = { + while (iterator.hasNext && mode == ConcurrentWriterBeforeSort) { + write(iterator.next()) + } + + if (iterator.hasNext) { + resetWriterStatus() + val sorter = concurrentOutputWriterSpec.get.createSorter() + val sortIterator = sorter.sort(iterator.asInstanceOf[Iterator[UnsafeRow]]) + while (sortIterator.hasNext) { + write(sortIterator.next()) + } + } + } + + sealed abstract class WriterMode + + /** + * Single writer mode always has at most one writer. + * The output is expected to be sorted on partition and/or bucket columns before writing. + */ + case object SingleWriter extends WriterMode + + /** + * Concurrent writer mode before sort happens, and can have multiple concurrent writers + * for each partition and/or bucket columns. + */ + case object ConcurrentWriterBeforeSort extends WriterMode + + /** + * Concurrent writer mode after sort happens. + */ + case object ConcurrentWriterAfterSort extends WriterMode + + /** Wrapper class to index a unique output writer. */ + private case class WriterIndex( + var partitionValues: Option[UnsafeRow], + var bucketId: Option[Int]) + + /** Wrapper class for status of a unique concurrent output writer. */ + private case class ConcurrentWriterStatus( + var outputWriter: OutputWriter, + var recordsInFile: Long, + var fileCounter: Int, + var filePath: String) + + /** + * Update current writer status when a new writer is needed for writing row. + */ + private def updateCurrentWriterStatus(): Unit = { + mode match { + case ConcurrentWriterBeforeSort + if currentWriterId.partitionValues.isDefined || currentWriterId.bucketId.isDefined => + // Update writer status in concurrent writers map, because the writer is probably needed + // again later for writing other rows. + val status = concurrentWriters(currentWriterId) + status.outputWriter = currentWriter + status.recordsInFile = recordsInFile + status.fileCounter = fileCounter + status.filePath = currentPath + case ConcurrentWriterAfterSort + if currentWriterId.partitionValues.isDefined || currentWriterId.bucketId.isDefined => + // Remove writer status in concurrent writers map and release writer resource, + // because the writer is not needed any more. + concurrentWriters.remove(currentWriterId) + super.releaseResources() + case _ => + } + } + + /** + * Get or create a new writer based on writer mode. 
+ */ + private def getOrNewOutputWriter(): Unit = { + mode match { + case SingleWriter => + fileCounter = 0 + newOutputWriter(currentWriterId.partitionValues, currentWriterId.bucketId, true) + case _ => + if (concurrentWriters.contains(currentWriterId)) { + val status = concurrentWriters(currentWriterId) + currentWriter = status.outputWriter + recordsInFile = status.recordsInFile + fileCounter = status.fileCounter + currentPath = status.filePath + } else { + fileCounter = 0 + newOutputWriter( + currentWriterId.partitionValues, + currentWriterId.bucketId, + false) + concurrentWriters.put( + WriterIndex(currentWriterId.partitionValues, currentWriterId.bucketId), + ConcurrentWriterStatus(currentWriter, recordsInFile, fileCounter, currentPath)) + if (concurrentWriters.size > concurrentOutputWriterSpec.get.maxWriters && + mode == ConcurrentWriterBeforeSort) { + // Fall back to sort-based single writer mode + mode = ConcurrentWriterAfterSort + } + } + } + } + + private def resetWriterStatus(): Unit = { + if (currentWriterId.partitionValues.isDefined || currentWriterId.bucketId.isDefined) { + val status = concurrentWriters(currentWriterId) + status.outputWriter = currentWriter + status.recordsInFile = recordsInFile + status.fileCounter = fileCounter + status.filePath = currentPath + } + currentWriterId.partitionValues = None + currentWriterId.bucketId = None + currentWriter = null + recordsInFile = 0 + fileCounter = 0 + currentPath = null + } } /** A shared job description for all the write tasks. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 6300e10c0bb3d..c4237d92178b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.{ProjectExec, SortExec, SparkPlan, SQLExecution, UnsafeExternalRowSorter} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StringType import org.apache.spark.unsafe.types.UTF8String @@ -73,6 +73,11 @@ object FileFormatWriter extends Logging { copy(child = newChild) } + /** Describes how concurrent output writers should be executed. */ + case class ConcurrentOutputWriterSpec( + maxWriters: Int, + createSorter: () => UnsafeExternalRowSorter) + /** * Basic work flow of this command is: * 1. Driver side setup, including output committer initialization and data source specific @@ -177,18 +182,27 @@ object FileFormatWriter extends Logging { committer.setupJob(job) try { - val rdd = if (orderingMatched) { - empty2NullPlan.execute() + val (rdd, concurrentOutputWriterSpec) = if (orderingMatched) { + (empty2NullPlan.execute(), None) } else { // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and // the physical plan may have different attribute ids due to optimizer removing some // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. 
val orderingExpr = bindReferences( requiredOrdering.map(SortOrder(_, Ascending)), outputSpec.outputColumns) - SortExec( + val sortPlan = SortExec( orderingExpr, global = false, - child = empty2NullPlan).execute() + child = empty2NullPlan) + + val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputWriters + val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty + if (concurrentWritersEnabled) { + (empty2NullPlan.execute(), + Some(ConcurrentOutputWriterSpec(maxWriters, () => sortPlan.createSorter()))) + } else { + (sortPlan.execute(), None) + } } // SPARK-23271 If we are attempting to write a zero partition rdd, create a dummy single @@ -211,7 +225,8 @@ object FileFormatWriter extends Logging { sparkPartitionId = taskContext.partitionId(), sparkAttemptNumber = taskContext.taskAttemptId().toInt & Integer.MAX_VALUE, committer, - iterator = iter) + iterator = iter, + concurrentOutputWriterSpec = concurrentOutputWriterSpec) }, rddWithNonEmptyPartitions.partitions.indices, (index, res: WriteTaskResult) => { @@ -245,7 +260,8 @@ object FileFormatWriter extends Logging { sparkPartitionId: Int, sparkAttemptNumber: Int, committer: FileCommitProtocol, - iterator: Iterator[InternalRow]): WriteTaskResult = { + iterator: Iterator[InternalRow], + concurrentOutputWriterSpec: Option[ConcurrentOutputWriterSpec]): WriteTaskResult = { val jobId = SparkHadoopWriterUtils.createJobID(new Date(jobIdInstant), sparkStageId) val taskId = new TaskID(jobId, TaskType.MAP, sparkPartitionId) @@ -273,14 +289,20 @@ object FileFormatWriter extends Logging { } else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) { new SingleDirectoryDataWriter(description, taskAttemptContext, committer) } else { - new DynamicPartitionDataWriter(description, taskAttemptContext, committer) + new DynamicPartitionDataWriter( + description, taskAttemptContext, committer, concurrentOutputWriterSpec) } try { Utils.tryWithSafeFinallyAndFailureCallbacks(block = { // Execute the task to write rows out and commit the task. - while (iterator.hasNext) { - dataWriter.write(iterator.next()) + dataWriter match { + case w: DynamicPartitionDataWriter if concurrentOutputWriterSpec.isDefined => + w.writeWithIterator(iterator) + case _ => + while (iterator.hasNext) { + dataWriter.write(iterator.next()) + } } dataWriter.commit() })(catchBlock = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala index c39a82ee037bc..aaf866bced868 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriteStatsTracker.scala @@ -32,20 +32,7 @@ trait WriteTaskStats extends Serializable * A trait for classes that are capable of collecting statistics on data that's being processed by * a single write task in [[FileFormatWriter]] - i.e. there should be one instance per executor. * - * This trait is coupled with the way [[FileFormatWriter]] works, in the sense that its methods - * will be called according to how tuples are being written out to disk, namely in sorted order - * according to partitionValue(s), then bucketId. - * - * As such, a typical call scenario is: - * - * newPartition -> newBucket -> newFile -> newRow -. 
- * ^ |______^___________^ ^ ^____| - * | | |______________| - * | |____________________________| - * |____________________________________________| - * - * newPartition and newBucket events are only triggered if the relation to be written out is - * partitioned and/or bucketed, respectively. + * newPartition event is only triggered if the relation to be written out is partitioned. */ trait WriteTaskStatsTracker { @@ -56,22 +43,20 @@ trait WriteTaskStatsTracker { */ def newPartition(partitionValues: InternalRow): Unit - /** - * Process the fact that a new bucket is about to written. - * Only triggered when the relation is bucketed by a (non-empty) sequence of columns. - * @param bucketId The bucket number. - */ - def newBucket(bucketId: Int): Unit - /** * Process the fact that a new file is about to be written. * @param filePath Path of the file into which future rows will be written. */ def newFile(filePath: String): Unit + /** + * Process the fact that a file is finished to be written and closed. + * @param filePath Path of the file. + */ + def closeFile(filePath: String): Unit + /** * Process the fact that a new row to update the tracked statistics accordingly. - * The row will be written to the most recently witnessed file (via `newFile`). * @note Keep in mind that any overhead here is per-row, obviously, * so implementations should be as lightweight as possible. * @param row Current data row to be processed. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala index 1f25fed3000b2..0feba01bd60ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala @@ -35,7 +35,7 @@ case class FileWriterFactory ( if (description.partitionColumns.isEmpty) { new SingleDirectoryDataWriter(description, taskAttemptContext, committer) } else { - new DynamicPartitionDataWriter(description, taskAttemptContext, committer) + new DynamicPartitionDataWriter(description, taskAttemptContext, committer, None) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 273658fcfa4c2..e7ec21def325d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -1219,4 +1219,32 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with } } } + + test("SPARK-26164: Allow concurrent writers for multiple partitions and buckets") { + withTable("t1", "t2") { + val df = spark.range(200).map(_ => { + val n = scala.util.Random.nextInt + (n, n.toString, n % 5) + }).toDF("k1", "k2", "part") + df.write.format("parquet").saveAsTable("t1") + spark.sql("CREATE TABLE t2(k1 int, k2 string, part int) USING parquet PARTITIONED " + + "BY (part) CLUSTERED BY (k1) INTO 3 BUCKETS") + val queryToInsertTable = "INSERT OVERWRITE TABLE t2 SELECT k1, k2, part FROM t1" + + Seq( + // Single writer + 0, + // Concurrent writers without fallback + 200, + // concurrent writers with fallback + 3 + ).foreach { maxWriters => + withSQLConf(SQLConf.MAX_CONCURRENT_OUTPUT_WRITERS.key -> maxWriters.toString) { + spark.sql(queryToInsertTable).collect() + 
checkAnswer(spark.table("t2").orderBy("k1"), + spark.table("t1").orderBy("k1")) + } + } + } + } } From 9a77db616a289f0b4abda95645213e92e2afffa2 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Tue, 20 Apr 2021 20:41:30 -0700 Subject: [PATCH 02/10] Refactor to keep separate classes of single vs concurrent writers --- .../datasources/FileFormatDataWriter.scala | 347 ++++++++++-------- .../datasources/FileFormatWriter.scala | 11 +- .../datasources/v2/FileWriterFactory.scala | 4 +- 3 files changed, 206 insertions(+), 156 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index bd1b43cd205bb..96d6d19811a95 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -147,41 +147,38 @@ class SingleDirectoryDataWriter( } /** - * Writes data to using dynamic partition writes, meaning this single function can write to + * Holds common logic for writing data with dynamic partition writes, meaning it can write to * multiple directories (partitions) or files (bucketing). */ -class DynamicPartitionDataWriter( +abstract class BaseDynamicPartitionDataWriter( description: WriteJobDescription, taskAttemptContext: TaskAttemptContext, - committer: FileCommitProtocol, - concurrentOutputWriterSpec: Option[ConcurrentOutputWriterSpec]) + committer: FileCommitProtocol) extends FileFormatDataWriter(description, taskAttemptContext, committer) { /** Flag saying whether or not the data to be written out is partitioned. */ - private val isPartitioned = description.partitionColumns.nonEmpty + protected val isPartitioned = description.partitionColumns.nonEmpty /** Flag saying whether or not the data to be written out is bucketed. */ - private val isBucketed = description.bucketIdExpression.isDefined + protected val isBucketed = description.bucketIdExpression.isDefined assert(isPartitioned || isBucketed, s"""DynamicPartitionWriteTask should be used for writing out data that's either - |partitioned or bucketed. In this case neither is true. - |WriteJobDescription: $description + |partitioned or bucketed. In this case neither is true. + |WriteJobDescription: $description """.stripMargin) - private var fileCounter: Int = _ - private var recordsInFile: Long = _ + /** Number of records in current file. */ + protected var recordsInFile: Long = _ - private var mode: WriterMode = concurrentOutputWriterSpec match { - case Some(_) => ConcurrentWriterBeforeSort - case None => SingleWriter - } - private val concurrentWriters = - mutable.HashMap[WriterIndex, ConcurrentWriterStatus]() - private val currentWriterId = WriterIndex(None, None) + /** + * File counter for writing current partition or bucket. For same partition or bucket, + * we may have more than one file, due to number of records limit per file. + */ + protected var fileCounter: Int = _ /** Extracts the partition values out of an input row. 
*/ - private lazy val getPartitionValues: InternalRow => UnsafeRow = { + protected lazy val getPartitionValues: InternalRow => UnsafeRow = { val proj = UnsafeProjection.create(description.partitionColumns, description.allColumns) row => proj(row) } @@ -196,68 +193,44 @@ class DynamicPartitionDataWriter( if (i == 0) Seq(partitionName) else Seq(Literal(Path.SEPARATOR), partitionName) }) - /** Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns - * the partition string. */ + /** + * Evaluates the `partitionPathExpression` above on a row of `partitionValues` and returns + * the partition string. + */ private lazy val getPartitionPath: InternalRow => String = { val proj = UnsafeProjection.create(Seq(partitionPathExpression), description.partitionColumns) row => proj(row).getString(0) } /** Given an input row, returns the corresponding `bucketId` */ - private lazy val getBucketId: InternalRow => Int = { + protected lazy val getBucketId: InternalRow => Int = { val proj = UnsafeProjection.create(description.bucketIdExpression.toSeq, description.allColumns) row => proj(row).getInt(0) } /** Returns the data columns to be written given an input row */ - private val getOutputRow = + protected val getOutputRow = UnsafeProjection.create(description.dataColumns, description.allColumns) - override protected def releaseResources(): Unit = { - mode match { - case SingleWriter => - if (currentWriter != null) { - try { - currentWriter.close() - statsTrackers.foreach(_.closeFile(currentPath)) - } finally { - currentWriter = null - } - } - case _ => - currentWriter = null - concurrentWriters.values.foreach(status => { - if (status.outputWriter != null) { - try { - status.outputWriter.close() - } finally { - status.outputWriter = null - } - } - }) - concurrentWriters.clear() - } - } - /** * Opens a new OutputWriter given a partition key and/or a bucket id. * If bucket id is specified, we will append it to the end of the file name, but before the * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet * - * @param partitionValues the partition which all tuples being written by this `OutputWriter` + * @param partitionValues the partition which all tuples being written by this OutputWriter * belong to - * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to + * @param bucketId the bucket which all tuples being written by this OutputWriter belong to * @param closeCurrentWriter close and release resource for current writer */ - private def newOutputWriter( + protected def newOutputWriter( partitionValues: Option[InternalRow], bucketId: Option[Int], closeCurrentWriter: Boolean): Unit = { recordsInFile = 0 if (closeCurrentWriter) { - super.releaseResources() + super[FileFormatDataWriter].releaseResources() } val partDir = partitionValues.map(getPartitionPath(_)) @@ -286,6 +259,135 @@ class DynamicPartitionDataWriter( statsTrackers.foreach(_.newFile(currentPath)) } + /** + * Checks if number of records exceeding limit. Open a new OutputWriter if needed. 
+ * + * @param partitionValues the partition which all tuples being written by this `OutputWriter` + * belong to + * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to + */ + protected def checkRecordsInFile( + partitionValues: Option[InternalRow], + bucketId: Option[Int]): Unit = { + if (description.maxRecordsPerFile > 0 && + recordsInFile >= description.maxRecordsPerFile) { + // Exceeded the threshold in terms of the number of records per file. + // Create a new file by increasing the file counter. + fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + + newOutputWriter(partitionValues, bucketId, closeCurrentWriter = true) + } + } + + /** + * Writes the given record with current writer. + * + * @param record The record to write + */ + protected def writeRecord(record: InternalRow): Unit = { + val outputRow = getOutputRow(record) + currentWriter.write(outputRow) + statsTrackers.foreach(_.newRow(outputRow)) + recordsInFile += 1 + } +} + +/** + * Dynamic partition writer with single writer, meaning only one writer is opened at any time for + * writing. The records to be written are required to be sorted on partition and/or bucket + * column(s) before writing. + */ +class DynamicPartitionDataSingleWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol) + extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer) { + + private var currentPartitionValues: Option[UnsafeRow] = None + private var currentBucketId: Option[Int] = None + + override def write(record: InternalRow): Unit = { + val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None + val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None + + if (currentPartitionValues != nextPartitionValues || currentBucketId != nextBucketId) { + // See a new partition or bucket - write to a new partition dir (or a new bucket file). + if (isPartitioned && currentPartitionValues != nextPartitionValues) { + currentPartitionValues = Some(nextPartitionValues.get.copy()) + statsTrackers.foreach(_.newPartition(currentPartitionValues.get)) + } + if (isBucketed) { + currentBucketId = nextBucketId + } + + fileCounter = 0 + newOutputWriter(currentPartitionValues, currentBucketId, true) + } else { + checkRecordsInFile(currentPartitionValues, currentBucketId) + } + writeRecord(record) + } +} + +/** + * Dynamic partition writer with concurrent writers, meaning multiple concurrent writers are opened + * for writing. + * + * The process has the following steps: + * - Step 1: Maintain a map of output writers per each partition and/or bucket columns. Keep all + * writers opened and write rows one by one. + * - Step 2: If number of concurrent writers exceeds limit, sort rest of rows on partition and/or + * bucket column(s). Write rows one by one, and eagerly close the writer when finishing + * each partition and/or bucket. + * + * Caller is expected to call `writeWithIterator()` instead of `write()` to write records. + */ +class DynamicPartitionDataConcurrentWriter( + description: WriteJobDescription, + taskAttemptContext: TaskAttemptContext, + committer: FileCommitProtocol, + concurrentOutputWriterSpec: ConcurrentOutputWriterSpec) + extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer) { + + /** Wrapper class to index a unique concurrent output writer. 
*/ + private case class WriterIndex( + var partitionValues: Option[UnsafeRow], + var bucketId: Option[Int]) + + /** Wrapper class for status of a unique concurrent output writer. */ + private case class WriterStatus( + var outputWriter: OutputWriter, + var recordsInFile: Long, + var fileCounter: Int, + var latestFilePath: String) + + /** + * State to indicate if we are falling back to sort-based writer. + * Because we first try to use concurrent writers, its initial value is false. + */ + private var sortBased: Boolean = false + private val concurrentWriters = mutable.HashMap[WriterIndex, WriterStatus]() + private val currentWriterId = WriterIndex(None, None) + + /** + * Release resources for all concurrent output writers. + */ + override protected def releaseResources(): Unit = { + currentWriter = null + concurrentWriters.values.foreach(status => { + if (status.outputWriter != null) { + try { + status.outputWriter.close() + } finally { + status.outputWriter = null + } + } + }) + concurrentWriters.clear() + } + override def write(record: InternalRow): Unit = { val nextPartitionValues = if (isPartitioned) Some(getPartitionValues(record)) else None val nextBucketId = if (isBucketed) Some(getBucketId(record)) else None @@ -293,55 +395,34 @@ class DynamicPartitionDataWriter( if (currentWriterId.partitionValues != nextPartitionValues || currentWriterId.bucketId != nextBucketId) { // See a new partition or bucket - write to a new partition dir (or a new bucket file). - updateCurrentWriterStatus() - + updateCurrentWriterStatusInMap() if (isBucketed) { currentWriterId.bucketId = nextBucketId } if (isPartitioned && currentWriterId.partitionValues != nextPartitionValues) { currentWriterId.partitionValues = Some(nextPartitionValues.get.copy()) - if (mode == SingleWriter || !concurrentWriters.contains(currentWriterId)) { + if (!concurrentWriters.contains(currentWriterId)) { statsTrackers.foreach(_.newPartition(currentWriterId.partitionValues.get)) } } - - getOrNewOutputWriter() + retrieveWriterInMap() } - if (description.maxRecordsPerFile > 0 && - recordsInFile >= description.maxRecordsPerFile) { - // Exceeded the threshold in terms of the number of records per file. - // Create a new file by increasing the file counter. - fileCounter += 1 - assert(fileCounter < MAX_FILE_COUNTER, - s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") - - newOutputWriter(currentWriterId.partitionValues, currentWriterId.bucketId, true) - } - val outputRow = getOutputRow(record) - currentWriter.write(outputRow) - statsTrackers.foreach(_.newRow(outputRow)) - recordsInFile += 1 + checkRecordsInFile(currentWriterId.partitionValues, currentWriterId.bucketId) + writeRecord(record) } /** - * Dedicated write code path when enabling concurrent writers. - * - * The process has the following step: - * - Step 1: Maintain a map of output writers per each partition and/or bucket columns. - * Keep all writers open and write rows one by one. - * - Step 2: If number of concurrent writers exceeds limit, sort rest of rows. Write rows - * one by one, and eagerly close the writer when finishing each partition and/or - * bucket. + * Write iterator of records with concurrent writers. 
*/ def writeWithIterator(iterator: Iterator[InternalRow]): Unit = { - while (iterator.hasNext && mode == ConcurrentWriterBeforeSort) { + while (iterator.hasNext && !sortBased) { write(iterator.next()) } if (iterator.hasNext) { - resetWriterStatus() - val sorter = concurrentOutputWriterSpec.get.createSorter() + clearCurrentWriterStatus() + val sorter = concurrentOutputWriterSpec.createSorter() val sortIterator = sorter.sort(iterator.asInstanceOf[Iterator[UnsafeRow]]) while (sortIterator.hasNext) { write(sortIterator.next()) @@ -349,101 +430,65 @@ class DynamicPartitionDataWriter( } } - sealed abstract class WriterMode - - /** - * Single writer mode always has at most one writer. - * The output is expected to be sorted on partition and/or bucket columns before writing. - */ - case object SingleWriter extends WriterMode - - /** - * Concurrent writer mode before sort happens, and can have multiple concurrent writers - * for each partition and/or bucket columns. - */ - case object ConcurrentWriterBeforeSort extends WriterMode - - /** - * Concurrent writer mode after sort happens. - */ - case object ConcurrentWriterAfterSort extends WriterMode - - /** Wrapper class to index a unique output writer. */ - private case class WriterIndex( - var partitionValues: Option[UnsafeRow], - var bucketId: Option[Int]) - - /** Wrapper class for status of a unique concurrent output writer. */ - private case class ConcurrentWriterStatus( - var outputWriter: OutputWriter, - var recordsInFile: Long, - var fileCounter: Int, - var filePath: String) - /** * Update current writer status when a new writer is needed for writing row. */ - private def updateCurrentWriterStatus(): Unit = { - mode match { - case ConcurrentWriterBeforeSort - if currentWriterId.partitionValues.isDefined || currentWriterId.bucketId.isDefined => + private def updateCurrentWriterStatusInMap(): Unit = { + if (currentWriterId.partitionValues.isDefined || currentWriterId.bucketId.isDefined) { + if (!sortBased) { // Update writer status in concurrent writers map, because the writer is probably needed // again later for writing other rows. val status = concurrentWriters(currentWriterId) status.outputWriter = currentWriter status.recordsInFile = recordsInFile status.fileCounter = fileCounter - status.filePath = currentPath - case ConcurrentWriterAfterSort - if currentWriterId.partitionValues.isDefined || currentWriterId.bucketId.isDefined => + status.latestFilePath = currentPath + } else { // Remove writer status in concurrent writers map and release writer resource, // because the writer is not needed any more. concurrentWriters.remove(currentWriterId) super.releaseResources() - case _ => + } } } /** - * Get or create a new writer based on writer mode. + * Retrieve writer in map, or create a new writer if not exists. 
*/ - private def getOrNewOutputWriter(): Unit = { - mode match { - case SingleWriter => - fileCounter = 0 - newOutputWriter(currentWriterId.partitionValues, currentWriterId.bucketId, true) - case _ => - if (concurrentWriters.contains(currentWriterId)) { - val status = concurrentWriters(currentWriterId) - currentWriter = status.outputWriter - recordsInFile = status.recordsInFile - fileCounter = status.fileCounter - currentPath = status.filePath - } else { - fileCounter = 0 - newOutputWriter( - currentWriterId.partitionValues, - currentWriterId.bucketId, - false) - concurrentWriters.put( - WriterIndex(currentWriterId.partitionValues, currentWriterId.bucketId), - ConcurrentWriterStatus(currentWriter, recordsInFile, fileCounter, currentPath)) - if (concurrentWriters.size > concurrentOutputWriterSpec.get.maxWriters && - mode == ConcurrentWriterBeforeSort) { - // Fall back to sort-based single writer mode - mode = ConcurrentWriterAfterSort - } - } + private def retrieveWriterInMap(): Unit = { + if (concurrentWriters.contains(currentWriterId)) { + val status = concurrentWriters(currentWriterId) + currentWriter = status.outputWriter + recordsInFile = status.recordsInFile + fileCounter = status.fileCounter + currentPath = status.latestFilePath + } else { + fileCounter = 0 + newOutputWriter( + currentWriterId.partitionValues, + currentWriterId.bucketId, + closeCurrentWriter = false) + concurrentWriters.put( + WriterIndex(currentWriterId.partitionValues, currentWriterId.bucketId), + WriterStatus(currentWriter, recordsInFile, fileCounter, currentPath)) + if (concurrentWriters.size > concurrentOutputWriterSpec.maxWriters && + !sortBased) { + // Fall back to sort-based sequential writer mode. + sortBased = true + } } } - private def resetWriterStatus(): Unit = { + /** + * Clear the current writer status in map. + */ + private def clearCurrentWriterStatus(): Unit = { if (currentWriterId.partitionValues.isDefined || currentWriterId.bucketId.isDefined) { val status = concurrentWriters(currentWriterId) status.outputWriter = currentWriter status.recordsInFile = recordsInFile status.fileCounter = fileCounter - status.filePath = currentPath + status.latestFilePath = currentPath } currentWriterId.partitionValues = None currentWriterId.bucketId = None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index c4237d92178b7..15df8d612f0a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -289,15 +289,20 @@ object FileFormatWriter extends Logging { } else if (description.partitionColumns.isEmpty && description.bucketIdExpression.isEmpty) { new SingleDirectoryDataWriter(description, taskAttemptContext, committer) } else { - new DynamicPartitionDataWriter( - description, taskAttemptContext, committer, concurrentOutputWriterSpec) + concurrentOutputWriterSpec match { + case Some(spec) => + new DynamicPartitionDataConcurrentWriter( + description, taskAttemptContext, committer, spec) + case _ => + new DynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) + } } try { Utils.tryWithSafeFinallyAndFailureCallbacks(block = { // Execute the task to write rows out and commit the task. 
dataWriter match { - case w: DynamicPartitionDataWriter if concurrentOutputWriterSpec.isDefined => + case w: DynamicPartitionDataConcurrentWriter => w.writeWithIterator(iterator) case _ => while (iterator.hasNext) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala index 0feba01bd60ba..d827e83623570 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriterFactory.scala @@ -24,7 +24,7 @@ import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark.internal.io.{FileCommitProtocol, SparkHadoopWriterUtils} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connector.write.{DataWriter, DataWriterFactory} -import org.apache.spark.sql.execution.datasources.{DynamicPartitionDataWriter, SingleDirectoryDataWriter, WriteJobDescription} +import org.apache.spark.sql.execution.datasources.{DynamicPartitionDataSingleWriter, SingleDirectoryDataWriter, WriteJobDescription} case class FileWriterFactory ( description: WriteJobDescription, @@ -35,7 +35,7 @@ case class FileWriterFactory ( if (description.partitionColumns.isEmpty) { new SingleDirectoryDataWriter(description, taskAttemptContext, committer) } else { - new DynamicPartitionDataWriter(description, taskAttemptContext, committer, None) + new DynamicPartitionDataSingleWriter(description, taskAttemptContext, committer) } } From 4dc952403dd5945f702f4a5848540f9d5e2a6929 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Tue, 20 Apr 2021 22:47:12 -0700 Subject: [PATCH 03/10] Try to fix scala compilation error --- .../sql/execution/datasources/BasicWriteStatsTracker.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index f1c809cea6e58..9f66c822d3ea6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -170,7 +170,7 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) "This could be due to the output format not writing empty files, " + "or files being not immediately visible in the filesystem.") } - BasicWriteTaskStats(partitions, numFiles, numBytes, numRows) + BasicWriteTaskStats(partitions.toSeq, numFiles, numBytes, numRows) } } From 5d98d402a69d14644b51db5e705b07d1c718e06d Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Wed, 21 Apr 2021 14:36:49 -0700 Subject: [PATCH 04/10] Address all comments --- .../spark/sql/avro/AvroOutputWriter.scala | 2 + .../ml/source/libsvm/LibSVMRelation.scala | 2 + .../apache/spark/sql/internal/SQLConf.scala | 8 ++-- .../datasources/BasicWriteStatsTracker.scala | 6 +-- .../datasources/FileFormatDataWriter.scala | 40 ++++++++++++------- .../datasources/FileFormatWriter.scala | 4 +- .../execution/datasources/OutputWriter.scala | 7 +++- .../datasources/csv/CsvOutputWriter.scala | 2 + .../datasources/json/JsonOutputWriter.scala | 2 + .../datasources/orc/OrcOutputWriter.scala | 2 + .../parquet/ParquetOutputWriter.scala | 2 + .../datasources/text/TextOutputWriter.scala | 2 + .../sql/test/DataFrameReaderWriterSuite.scala | 2 +- 
.../sql/hive/execution/HiveFileFormat.scala | 2 + .../spark/sql/hive/orc/OrcFileFormat.scala | 2 + .../sql/sources/SimpleTextRelation.scala | 4 ++ 16 files changed, 63 insertions(+), 26 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala index 52aa5a69737ef..ed403ee37c4ad 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala @@ -84,4 +84,6 @@ private[avro] class AvroOutputWriter( } override def close(): Unit = recordWriter.close(context) + + override def path(): String = path } diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index df64de4b10075..0b35ec3094cd4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -64,6 +64,8 @@ private[libsvm] class LibSVMOutputWriter( override def close(): Unit = { writer.close() } + + override def path(): String = path } /** @see [[LibSVMDataSource]] for public documentation. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index f61c5f229ce34..d29f4d01d240d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3150,10 +3150,10 @@ object SQLConf { .booleanConf .createWithDefault(false) - val MAX_CONCURRENT_OUTPUT_WRITERS = buildConf("spark.sql.maxConcurrentOutputWriters") + val MAX_CONCURRENT_OUTPUT_FILE_WRITERS = buildConf("spark.sql.maxConcurrentOutputFileWriters") .internal() - .doc("Maximum number of output writers to use concurrently. If number of writers needed " + - "exceeds this limit, task will sort rest of output then writing them.") + .doc("Maximum number of output file writers to use concurrently. 
If number of writers " + + "needed exceeds this limit, task will sort rest of output then writing them.") .version("3.2.0") .intConf .createWithDefault(0) @@ -3847,7 +3847,7 @@ class SQLConf extends Serializable with Logging { def decorrelateInnerQueryEnabled: Boolean = getConf(SQLConf.DECORRELATE_INNER_QUERY_ENABLED) - def maxConcurrentOutputWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_WRITERS) + def maxConcurrentOutputFileWriters: Int = getConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS) /** ********************** SQLConf functionality methods ************ */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index 9f66c822d3ea6..4f60a9d4c8c0c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -140,11 +140,11 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) } override def closeFile(filePath: String): Unit = { - getFileStats(filePath) + updateFileStats(filePath) submittedFiles.remove(filePath) } - private def getFileStats(filePath: String): Unit = { + private def updateFileStats(filePath: String): Unit = { getFileSize(filePath).foreach { len => numBytes += len numFiles += 1 @@ -156,7 +156,7 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) } override def getFinalStats(): WriteTaskStats = { - submittedFiles.foreach(getFileStats) + submittedFiles.foreach(updateFileStats) submittedFiles.clear() // Reports bytesWritten and recordsWritten to the Spark output metrics. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index 96d6d19811a95..ea54c588b27c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -48,23 +48,29 @@ abstract class FileFormatDataWriter( protected val MAX_FILE_COUNTER: Int = 1000 * 1000 protected val updatedPartitions: mutable.Set[String] = mutable.Set[String]() protected var currentWriter: OutputWriter = _ - protected var currentPath: String = _ /** Trackers for computing various statistics on the data as it's being written out. */ protected val statsTrackers: Seq[WriteTaskStatsTracker] = description.statsTrackers.map(_.newTaskInstance()) - protected def releaseResources(): Unit = { + /** Release resources of `currentWriter`. */ + protected def releaseCurrentWriter(): Unit = { if (currentWriter != null) { try { currentWriter.close() - statsTrackers.foreach(_.closeFile(currentPath)) + statsTrackers.foreach(_.closeFile(currentWriter.path())) } finally { currentWriter = null } } } + /** Release all resources. */ + protected def releaseResources(): Unit = { + // Call `releaseCurrentWriter()` by default, as this is the only resource to be released. 
+ releaseCurrentWriter() + } + /** Writes a record */ def write(record: InternalRow): Unit @@ -118,7 +124,7 @@ class SingleDirectoryDataWriter( releaseResources() val ext = description.outputWriterFactory.getFileExtension(taskAttemptContext) - currentPath = committer.newTaskTempFile( + val currentPath = committer.newTaskTempFile( taskAttemptContext, None, f"-c$fileCounter%03d" + ext) @@ -230,7 +236,7 @@ abstract class BaseDynamicPartitionDataWriter( recordsInFile = 0 if (closeCurrentWriter) { - super[FileFormatDataWriter].releaseResources() + releaseCurrentWriter() } val partDir = partitionValues.map(getPartitionPath(_)) @@ -245,7 +251,7 @@ abstract class BaseDynamicPartitionDataWriter( val customPath = partDir.flatMap { dir => description.customPartitionLocations.get(PartitioningUtils.parsePathFragment(dir)) } - currentPath = if (customPath.isDefined) { + val currentPath = if (customPath.isDefined) { committer.newTaskTempFileAbsPath(taskAttemptContext, customPath.get, ext) } else { committer.newTaskTempFile(taskAttemptContext, partDir, ext) @@ -292,6 +298,15 @@ abstract class BaseDynamicPartitionDataWriter( statsTrackers.foreach(_.newRow(outputRow)) recordsInFile += 1 } + + /** + * Write an iterator of records. + */ + def writeWithIterator(iterator: Iterator[InternalRow]): Unit = { + while (iterator.hasNext) { + write(iterator.next()) + } + } } /** @@ -357,11 +372,10 @@ class DynamicPartitionDataConcurrentWriter( var bucketId: Option[Int]) /** Wrapper class for status of a unique concurrent output writer. */ - private case class WriterStatus( + private class WriterStatus( var outputWriter: OutputWriter, var recordsInFile: Long, - var fileCounter: Int, - var latestFilePath: String) + var fileCounter: Int) /** * State to indicate if we are falling back to sort-based writer. @@ -415,7 +429,7 @@ class DynamicPartitionDataConcurrentWriter( /** * Write iterator of records with concurrent writers. */ - def writeWithIterator(iterator: Iterator[InternalRow]): Unit = { + override def writeWithIterator(iterator: Iterator[InternalRow]): Unit = { while (iterator.hasNext && !sortBased) { write(iterator.next()) } @@ -442,7 +456,6 @@ class DynamicPartitionDataConcurrentWriter( status.outputWriter = currentWriter status.recordsInFile = recordsInFile status.fileCounter = fileCounter - status.latestFilePath = currentPath } else { // Remove writer status in concurrent writers map and release writer resource, // because the writer is not needed any more. @@ -461,7 +474,6 @@ class DynamicPartitionDataConcurrentWriter( currentWriter = status.outputWriter recordsInFile = status.recordsInFile fileCounter = status.fileCounter - currentPath = status.latestFilePath } else { fileCounter = 0 newOutputWriter( @@ -470,7 +482,7 @@ class DynamicPartitionDataConcurrentWriter( closeCurrentWriter = false) concurrentWriters.put( WriterIndex(currentWriterId.partitionValues, currentWriterId.bucketId), - WriterStatus(currentWriter, recordsInFile, fileCounter, currentPath)) + new WriterStatus(currentWriter, recordsInFile, fileCounter)) if (concurrentWriters.size > concurrentOutputWriterSpec.maxWriters && !sortBased) { // Fall back to sort-based sequential writer mode. 
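// [Editorial aside -- illustration only, not part of the patch.] The fallback just
// above ("sortBased = true") is the heart of the two-step strategy described in
// DynamicPartitionDataConcurrentWriter's scaladoc. The following self-contained
// sketch shows the same idea outside Spark, with a made-up FakeWriter standing in
// for OutputWriter and plain strings standing in for partition values; it is a
// simplification under those assumptions, not the patch's implementation.

import scala.collection.mutable

object ConcurrentWriteSketch {
  // Stand-in for an OutputWriter: it just collects the rows written for one key.
  final class FakeWriter(val key: String) {
    val rows = mutable.ArrayBuffer[String]()
    def write(row: String): Unit = rows += row
  }

  /** Writes (partitionKey, row) pairs and returns the writers that were opened. */
  def writeAll(input: Iterator[(String, String)], maxWriters: Int): Map[String, FakeWriter] = {
    val writers = mutable.HashMap[String, FakeWriter]()

    // Step 1: keep one writer open per partition key and write rows in input
    // order, until the number of open writers exceeds the limit.
    var sortBased = false
    while (input.hasNext && !sortBased) {
      val (key, row) = input.next()
      writers.getOrElseUpdate(key, new FakeWriter(key)).write(row)
      if (writers.size > maxWriters) sortBased = true
    }

    // Step 2 (fallback): sort the remaining rows by key and write them in key
    // order, so each partition's rows arrive contiguously and a real writer
    // could close every file as soon as its partition is finished.
    if (input.hasNext) {
      input.toSeq.sortBy(_._1).foreach { case (key, row) =>
        writers.getOrElseUpdate(key, new FakeWriter(key)).write(row)
      }
    }
    writers.toMap
  }

  def main(args: Array[String]): Unit = {
    val rows = Iterator("a" -> "1", "b" -> "2", "a" -> "3", "c" -> "4", "b" -> "5")
    writeAll(rows, maxWriters = 2).toSeq.sortBy(_._1).foreach { case (k, w) =>
      println(s"partition $k -> rows ${w.rows.mkString(", ")}")
    }
  }
}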
@@ -488,14 +500,12 @@ class DynamicPartitionDataConcurrentWriter( status.outputWriter = currentWriter status.recordsInFile = recordsInFile status.fileCounter = fileCounter - status.latestFilePath = currentPath } currentWriterId.partitionValues = None currentWriterId.bucketId = None currentWriter = null recordsInFile = 0 fileCounter = 0 - currentPath = null } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 15df8d612f0a0..f7c324dcb0a9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -195,7 +195,7 @@ object FileFormatWriter extends Logging { global = false, child = empty2NullPlan) - val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputWriters + val maxWriters = sparkSession.sessionState.conf.maxConcurrentOutputFileWriters val concurrentWritersEnabled = maxWriters > 0 && sortColumns.isEmpty if (concurrentWritersEnabled) { (empty2NullPlan.execute(), @@ -302,7 +302,7 @@ object FileFormatWriter extends Logging { Utils.tryWithSafeFinallyAndFailureCallbacks(block = { // Execute the task to write rows out and commit the task. dataWriter match { - case w: DynamicPartitionDataConcurrentWriter => + case w: BaseDynamicPartitionDataWriter => w.writeWithIterator(iterator) case _ => while (iterator.hasNext) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala index 1d7abe5b938c2..7c479d986f3e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/OutputWriter.scala @@ -57,7 +57,7 @@ abstract class OutputWriterFactory extends Serializable { */ abstract class OutputWriter { /** - * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned + * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned * tables, dynamic partition columns are not included in rows to be written. */ def write(row: InternalRow): Unit @@ -67,4 +67,9 @@ abstract class OutputWriter { * the task output is committed. */ def close(): Unit + + /** + * The file path to write. Invoked on the executor side. 
+ */ + def path(): String } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala index 2b549536ae486..2be744ac864da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala @@ -46,4 +46,6 @@ class CsvOutputWriter( override def write(row: InternalRow): Unit = gen.write(row) override def close(): Unit = gen.close() + + override def path(): String = path } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala index 719d72f5b9b52..7ea0a7d86087b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala @@ -58,4 +58,6 @@ class JsonOutputWriter( gen.close() writer.close() } + + override def path(): String = path } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala index 08086bcd91f6e..e5c70ede2f894 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala @@ -57,4 +57,6 @@ private[sql] class OrcOutputWriter( override def close(): Unit = { recordWriter.close(context) } + + override def path(): String = path } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala index 70f6726c581a2..66a1d6144d4a7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala @@ -39,4 +39,6 @@ class ParquetOutputWriter(path: String, context: TaskAttemptContext) override def write(row: InternalRow): Unit = recordWriter.write(null, row) override def close(): Unit = recordWriter.close(context) + + override def path(): String = path } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala index 2b1b81f60ceb4..a6ea3763ec9eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala @@ -43,4 +43,6 @@ class TextOutputWriter( override def close(): Unit = { writer.close() } + + override def path(): String = path } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index e7ec21def325d..e1911e9831c0d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -1239,7 +1239,7 @@ class 
DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with // concurrent writers with fallback 3 ).foreach { maxWriters => - withSQLConf(SQLConf.MAX_CONCURRENT_OUTPUT_WRITERS.key -> maxWriters.toString) { + withSQLConf(SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS.key -> maxWriters.toString) { spark.sql(queryToInsertTable).collect() checkAnswer(spark.table("t2").orderBy("k1"), spark.table("t1").orderBy("k1")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala index c51c521cacba0..7ffdae4c4029a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala @@ -160,4 +160,6 @@ class HiveOutputWriter( // Seems the boolean value passed into close does not matter. hiveWriter.close(false) } + + override def path(): String = path } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 4707311341fcb..4fff5cd0d06ea 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -303,6 +303,8 @@ private[orc] class OrcOutputWriter( recordWriter.close(Reporter.NULL) } } + + override def path(): String = path } private[orc] object OrcFileFormat extends HiveInspectors with Logging { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index d1b97b2852fbc..2b87a356c7f7c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -134,6 +134,10 @@ class SimpleTextOutputWriter(path: String, dataSchema: StructType, context: Task override def close(): Unit = { writer.close() } + + override def path(): String = { + path + } } object SimpleTextRelation { From a498d9c8ca7d8723184d1cbd2cba4ef8bed04df9 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Wed, 21 Apr 2021 14:45:19 -0700 Subject: [PATCH 05/10] Move writeWithIterator to FileFormatDataWriter --- .../datasources/FileFormatDataWriter.scala | 19 +++++++++---------- .../datasources/FileFormatWriter.scala | 9 +-------- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index ea54c588b27c8..05a7a517573de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -71,9 +71,17 @@ abstract class FileFormatDataWriter( releaseCurrentWriter() } - /** Writes a record */ + /** Writes a record. */ def write(record: InternalRow): Unit + + /** Write an iterator of records. */ + def writeWithIterator(iterator: Iterator[InternalRow]): Unit = { + while (iterator.hasNext) { + write(iterator.next()) + } + } + /** * Returns the summary of relative information which * includes the list of partition strings written out. 
The list of partitions is sent back @@ -298,15 +306,6 @@ abstract class BaseDynamicPartitionDataWriter( statsTrackers.foreach(_.newRow(outputRow)) recordsInFile += 1 } - - /** - * Write an iterator of records. - */ - def writeWithIterator(iterator: Iterator[InternalRow]): Unit = { - while (iterator.hasNext) { - write(iterator.next()) - } - } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index f7c324dcb0a9d..6839a4db0bc28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -301,14 +301,7 @@ object FileFormatWriter extends Logging { try { Utils.tryWithSafeFinallyAndFailureCallbacks(block = { // Execute the task to write rows out and commit the task. - dataWriter match { - case w: BaseDynamicPartitionDataWriter => - w.writeWithIterator(iterator) - case _ => - while (iterator.hasNext) { - dataWriter.write(iterator.next()) - } - } + dataWriter.writeWithIterator(iterator) dataWriter.commit() })(catchBlock = { // If there is an error, abort the task From a4ed53b702f1514de66a5696989abfc85e14a551 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Wed, 21 Apr 2021 15:34:35 -0700 Subject: [PATCH 06/10] Fix the bug that did not update map for increased file counter --- .../datasources/FileFormatDataWriter.scala | 68 ++++++++++--------- .../sql/test/DataFrameReaderWriterSuite.scala | 6 ++ 2 files changed, 42 insertions(+), 32 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index 05a7a517573de..b47487808eee0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -274,25 +274,22 @@ abstract class BaseDynamicPartitionDataWriter( } /** - * Checks if number of records exceeding limit. Open a new OutputWriter if needed. + * Increase the file counter and open a new OutputWriter. + * This is used when number of records records exceeding limit. * * @param partitionValues the partition which all tuples being written by this `OutputWriter` * belong to * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to */ - protected def checkRecordsInFile( + protected def increaseFileCounter( partitionValues: Option[InternalRow], bucketId: Option[Int]): Unit = { - if (description.maxRecordsPerFile > 0 && - recordsInFile >= description.maxRecordsPerFile) { - // Exceeded the threshold in terms of the number of records per file. - // Create a new file by increasing the file counter. - fileCounter += 1 - assert(fileCounter < MAX_FILE_COUNTER, - s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") - - newOutputWriter(partitionValues, bucketId, closeCurrentWriter = true) - } + // Exceeded the threshold in terms of the number of records per file. + // Create a new file by increasing the file counter. 
+ fileCounter += 1 + assert(fileCounter < MAX_FILE_COUNTER, + s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") + newOutputWriter(partitionValues, bucketId, closeCurrentWriter = true) } /** @@ -338,8 +335,9 @@ class DynamicPartitionDataSingleWriter( fileCounter = 0 newOutputWriter(currentPartitionValues, currentBucketId, true) - } else { - checkRecordsInFile(currentPartitionValues, currentBucketId) + } else if (description.maxRecordsPerFile > 0 && + recordsInFile >= description.maxRecordsPerFile) { + increaseFileCounter(currentPartitionValues, currentBucketId) } writeRecord(record) } @@ -408,7 +406,19 @@ class DynamicPartitionDataConcurrentWriter( if (currentWriterId.partitionValues != nextPartitionValues || currentWriterId.bucketId != nextBucketId) { // See a new partition or bucket - write to a new partition dir (or a new bucket file). - updateCurrentWriterStatusInMap() + if (currentWriter != null) { + if (!sortBased) { + // Update writer status in concurrent writers map, because the writer is probably needed + // again later for writing other rows. + updateCurrentWriterStatusInMap() + } else { + // Remove writer status in concurrent writers map and release current writer resource, + // because the writer is not needed any more. + concurrentWriters.remove(currentWriterId) + releaseCurrentWriter() + } + } + if (isBucketed) { currentWriterId.bucketId = nextBucketId } @@ -421,7 +431,12 @@ class DynamicPartitionDataConcurrentWriter( retrieveWriterInMap() } - checkRecordsInFile(currentWriterId.partitionValues, currentWriterId.bucketId) + if (description.maxRecordsPerFile > 0 && + recordsInFile >= description.maxRecordsPerFile) { + increaseFileCounter(currentWriterId.partitionValues, currentWriterId.bucketId) + // Update writer status in concurrent writers map, as a new writer is created. + updateCurrentWriterStatusInMap() + } writeRecord(record) } @@ -444,24 +459,13 @@ class DynamicPartitionDataConcurrentWriter( } /** - * Update current writer status when a new writer is needed for writing row. + * Update current writer status in map. */ private def updateCurrentWriterStatusInMap(): Unit = { - if (currentWriterId.partitionValues.isDefined || currentWriterId.bucketId.isDefined) { - if (!sortBased) { - // Update writer status in concurrent writers map, because the writer is probably needed - // again later for writing other rows. - val status = concurrentWriters(currentWriterId) - status.outputWriter = currentWriter - status.recordsInFile = recordsInFile - status.fileCounter = fileCounter - } else { - // Remove writer status in concurrent writers map and release writer resource, - // because the writer is not needed any more. 
- concurrentWriters.remove(currentWriterId) - super.releaseResources() - } - } + val status = concurrentWriters(currentWriterId) + status.outputWriter = currentWriter + status.recordsInFile = recordsInFile + status.fileCounter = fileCounter } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index e1911e9831c0d..37f66e3ed8c32 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -1243,6 +1243,12 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with spark.sql(queryToInsertTable).collect() checkAnswer(spark.table("t2").orderBy("k1"), spark.table("t1").orderBy("k1")) + + withSQLConf(SQLConf.MAX_RECORDS_PER_FILE.key -> "1") { + spark.sql(queryToInsertTable).collect() + checkAnswer(spark.table("t2").orderBy("k1"), + spark.table("t1").orderBy("k1")) + } } } } From 1e8d447e8c8d6a4cd9954a719d8a55907bcb3e5d Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Thu, 22 Apr 2021 12:09:44 -0700 Subject: [PATCH 07/10] Address all comments --- .../spark/sql/avro/AvroOutputWriter.scala | 4 +--- .../ml/source/libsvm/LibSVMRelation.scala | 4 +--- .../datasources/FileFormatDataWriter.scala | 24 ++++++++----------- .../datasources/csv/CsvOutputWriter.scala | 4 +--- .../datasources/json/JsonOutputWriter.scala | 4 +--- .../datasources/orc/OrcOutputWriter.scala | 4 +--- .../parquet/ParquetOutputWriter.scala | 4 +--- .../datasources/text/TextOutputWriter.scala | 4 +--- .../sql/hive/execution/HiveFileFormat.scala | 4 +--- .../spark/sql/hive/orc/OrcFileFormat.scala | 4 +--- .../sql/sources/SimpleTextRelation.scala | 6 +---- 11 files changed, 20 insertions(+), 46 deletions(-) diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala index ed403ee37c4ad..424526eafdfaa 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroOutputWriter.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.types._ // NOTE: This class is instantiated and used on executor side only, no need to be serializable. 
private[avro] class AvroOutputWriter( - path: String, + val path: String, context: TaskAttemptContext, schema: StructType, avroSchema: Schema) extends OutputWriter { @@ -84,6 +84,4 @@ private[avro] class AvroOutputWriter( } override def close(): Unit = recordWriter.close(context) - - override def path(): String = path } diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 0b35ec3094cd4..837883e53d306 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.SerializableConfiguration private[libsvm] class LibSVMOutputWriter( - path: String, + val path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { @@ -64,8 +64,6 @@ private[libsvm] class LibSVMOutputWriter( override def close(): Unit = { writer.close() } - - override def path(): String = path } /** @see [[LibSVMDataSource]] for public documentation. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index b47487808eee0..0f62b14c7b9cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -237,7 +237,7 @@ abstract class BaseDynamicPartitionDataWriter( * @param bucketId the bucket which all tuples being written by this OutputWriter belong to * @param closeCurrentWriter close and release resource for current writer */ - protected def newOutputWriter( + protected def renewCurrentWriter( partitionValues: Option[InternalRow], bucketId: Option[Int], closeCurrentWriter: Boolean): Unit = { @@ -289,7 +289,7 @@ abstract class BaseDynamicPartitionDataWriter( fileCounter += 1 assert(fileCounter < MAX_FILE_COUNTER, s"File counter $fileCounter is beyond max value $MAX_FILE_COUNTER") - newOutputWriter(partitionValues, bucketId, closeCurrentWriter = true) + renewCurrentWriter(partitionValues, bucketId, closeCurrentWriter = true) } /** @@ -334,7 +334,7 @@ class DynamicPartitionDataSingleWriter( } fileCounter = 0 - newOutputWriter(currentPartitionValues, currentBucketId, true) + renewCurrentWriter(currentPartitionValues, currentBucketId, closeCurrentWriter = true) } else if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { increaseFileCounter(currentPartitionValues, currentBucketId) @@ -378,7 +378,7 @@ class DynamicPartitionDataConcurrentWriter( * State to indicate if we are falling back to sort-based writer. * Because we first try to use concurrent writers, its initial value is false. */ - private var sortBased: Boolean = false + private var sorted: Boolean = false private val concurrentWriters = mutable.HashMap[WriterIndex, WriterStatus]() private val currentWriterId = WriterIndex(None, None) @@ -407,7 +407,7 @@ class DynamicPartitionDataConcurrentWriter( currentWriterId.bucketId != nextBucketId) { // See a new partition or bucket - write to a new partition dir (or a new bucket file). 
if (currentWriter != null) { - if (!sortBased) { + if (!sorted) { // Update writer status in concurrent writers map, because the writer is probably needed // again later for writing other rows. updateCurrentWriterStatusInMap() @@ -444,7 +444,7 @@ class DynamicPartitionDataConcurrentWriter( * Write iterator of records with concurrent writers. */ override def writeWithIterator(iterator: Iterator[InternalRow]): Unit = { - while (iterator.hasNext && !sortBased) { + while (iterator.hasNext && !sorted) { write(iterator.next()) } @@ -479,17 +479,16 @@ class DynamicPartitionDataConcurrentWriter( fileCounter = status.fileCounter } else { fileCounter = 0 - newOutputWriter( + renewCurrentWriter( currentWriterId.partitionValues, currentWriterId.bucketId, closeCurrentWriter = false) concurrentWriters.put( WriterIndex(currentWriterId.partitionValues, currentWriterId.bucketId), new WriterStatus(currentWriter, recordsInFile, fileCounter)) - if (concurrentWriters.size > concurrentOutputWriterSpec.maxWriters && - !sortBased) { + if (concurrentWriters.size > concurrentOutputWriterSpec.maxWriters && !sorted) { // Fall back to sort-based sequential writer mode. - sortBased = true + sorted = true } } } @@ -499,10 +498,7 @@ class DynamicPartitionDataConcurrentWriter( */ private def clearCurrentWriterStatus(): Unit = { if (currentWriterId.partitionValues.isDefined || currentWriterId.bucketId.isDefined) { - val status = concurrentWriters(currentWriterId) - status.outputWriter = currentWriter - status.recordsInFile = recordsInFile - status.fileCounter = fileCounter + updateCurrentWriterStatusInMap() } currentWriterId.partitionValues = None currentWriterId.bucketId = None diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala index 2be744ac864da..35d0e098b19e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CsvOutputWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter} import org.apache.spark.sql.types.StructType class CsvOutputWriter( - path: String, + val path: String, dataSchema: StructType, context: TaskAttemptContext, params: CSVOptions) extends OutputWriter with Logging { @@ -46,6 +46,4 @@ class CsvOutputWriter( override def write(row: InternalRow): Unit = gen.write(row) override def close(): Unit = gen.close() - - override def path(): String = path } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala index 7ea0a7d86087b..55602ce2ed9b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JsonOutputWriter.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter} import org.apache.spark.sql.types.StructType class JsonOutputWriter( - path: String, + val path: String, options: JSONOptions, dataSchema: StructType, context: TaskAttemptContext) @@ -58,6 +58,4 @@ class JsonOutputWriter( gen.close() writer.close() } - - override def path(): String = path } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala index e5c70ede2f894..6f215737f5703 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOutputWriter.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.OutputWriter import org.apache.spark.sql.types._ private[sql] class OrcOutputWriter( - path: String, + val path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { @@ -57,6 +57,4 @@ private[sql] class OrcOutputWriter( override def close(): Unit = { recordWriter.close(context) } - - override def path(): String = path } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala index 66a1d6144d4a7..efb322f3fc906 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOutputWriter.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.OutputWriter // NOTE: This class is instantiated and used on executor side only, no need to be serializable. -class ParquetOutputWriter(path: String, context: TaskAttemptContext) +class ParquetOutputWriter(val path: String, context: TaskAttemptContext) extends OutputWriter { private val recordWriter: RecordWriter[Void, InternalRow] = { @@ -39,6 +39,4 @@ class ParquetOutputWriter(path: String, context: TaskAttemptContext) override def write(row: InternalRow): Unit = recordWriter.write(null, row) override def close(): Unit = recordWriter.close(context) - - override def path(): String = path } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala index a6ea3763ec9eb..2fb37c0dc0359 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/TextOutputWriter.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.execution.datasources.{CodecStreams, OutputWriter} import org.apache.spark.sql.types.StructType class TextOutputWriter( - path: String, + val path: String, dataSchema: StructType, lineSeparator: Array[Byte], context: TaskAttemptContext) @@ -43,6 +43,4 @@ class TextOutputWriter( override def close(): Unit = { writer.close() } - - override def path(): String = path } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala index 7ffdae4c4029a..d4ec590f79f5e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveFileFormat.scala @@ -108,7 +108,7 @@ class HiveFileFormat(fileSinkConf: FileSinkDesc) } class HiveOutputWriter( - path: String, + val path: String, fileSinkConf: FileSinkDesc, jobConf: JobConf, dataSchema: StructType) extends OutputWriter with HiveInspectors { @@ -160,6 +160,4 @@ class HiveOutputWriter( // Seems the boolean value passed into close does not matter. 
hiveWriter.close(false) } - - override def path(): String = path } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 4fff5cd0d06ea..d2ac06ad0a16a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -271,7 +271,7 @@ private[orc] class OrcSerializer(dataSchema: StructType, conf: Configuration) } private[orc] class OrcOutputWriter( - path: String, + val path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { @@ -303,8 +303,6 @@ private[orc] class OrcOutputWriter( recordWriter.close(Reporter.NULL) } } - - override def path(): String = path } private[orc] object OrcFileFormat extends HiveInspectors with Logging { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 2b87a356c7f7c..debe1ab734cc9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -117,7 +117,7 @@ class SimpleTextSource extends TextBasedFileFormat with DataSourceRegister { } } -class SimpleTextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext) +class SimpleTextOutputWriter(val path: String, dataSchema: StructType, context: TaskAttemptContext) extends OutputWriter { private val writer = CodecStreams.createOutputStreamWriter(context, new Path(path)) @@ -134,10 +134,6 @@ class SimpleTextOutputWriter(path: String, dataSchema: StructType, context: Task override def close(): Unit = { writer.close() } - - override def path(): String = { - path - } } object SimpleTextRelation { From 17a80a6f02d2011f7dc6d9ee6c4a9843af63fa62 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Thu, 22 Apr 2021 16:13:27 -0700 Subject: [PATCH 08/10] Address comments for doc and max value check around concurrentWriters --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 2 +- .../execution/datasources/FileFormatDataWriter.scala | 11 ++++++++++- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d29f4d01d240d..9d09715d25932 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -3153,7 +3153,7 @@ object SQLConf { val MAX_CONCURRENT_OUTPUT_FILE_WRITERS = buildConf("spark.sql.maxConcurrentOutputFileWriters") .internal() .doc("Maximum number of output file writers to use concurrently. 
If number of writers " + - "needed exceeds this limit, task will sort rest of output then writing them.") + "needed reaches this limit, task will sort rest of output then writing them.") .version("3.2.0") .intConf .createWithDefault(0) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index 0f62b14c7b9cd..a961899080bf6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -483,10 +483,19 @@ class DynamicPartitionDataConcurrentWriter( currentWriterId.partitionValues, currentWriterId.bucketId, closeCurrentWriter = false) + if (!sorted) { + assert(concurrentWriters.size <= concurrentOutputWriterSpec.maxWriters, + s"Number of concurrent output file writers is ${concurrentWriters.size} " + + s" which is beyond max value ${concurrentOutputWriterSpec.maxWriters}") + } else { + assert(concurrentWriters.size <= concurrentOutputWriterSpec.maxWriters + 1, + s"Number of output file writers after sort is ${concurrentWriters.size} " + + s" which is beyond max value ${concurrentOutputWriterSpec.maxWriters + 1}") + } concurrentWriters.put( WriterIndex(currentWriterId.partitionValues, currentWriterId.bucketId), new WriterStatus(currentWriter, recordsInFile, fileCounter)) - if (concurrentWriters.size > concurrentOutputWriterSpec.maxWriters && !sorted) { + if (concurrentWriters.size >= concurrentOutputWriterSpec.maxWriters && !sorted) { // Fall back to sort-based sequential writer mode. sorted = true } From 2895837e8d80e371388a23207f3c897808e62a55 Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Sun, 25 Apr 2021 22:49:26 -0700 Subject: [PATCH 09/10] Address all comments --- .../datasources/FileFormatDataWriter.scala | 20 +++++++++++++++---- .../sql/test/DataFrameReaderWriterSuite.scala | 6 ++++-- 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index a961899080bf6..9e4d2e3f6e713 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -21,6 +21,7 @@ import scala.collection.mutable import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.spark.internal.Logging import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.sql.catalyst.InternalRow @@ -29,6 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage} import org.apache.spark.sql.execution.datasources.FileFormatWriter.ConcurrentOutputWriterSpec +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StringType import org.apache.spark.util.SerializableConfiguration @@ -361,7 +363,8 @@ class DynamicPartitionDataConcurrentWriter( taskAttemptContext: TaskAttemptContext, committer: FileCommitProtocol, concurrentOutputWriterSpec: ConcurrentOutputWriterSpec) - extends 
BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer) { + extends BaseDynamicPartitionDataWriter(description, taskAttemptContext, committer) + with Logging { /** Wrapper class to index a unique concurrent output writer. */ private case class WriterIndex( @@ -380,6 +383,12 @@ class DynamicPartitionDataConcurrentWriter( */ private var sorted: Boolean = false private val concurrentWriters = mutable.HashMap[WriterIndex, WriterStatus]() + + /** + * The index for current writer. Intentionally make the index mutable and reusable. + * Avoid JVM GC issue when many short-living `WriterIndex` objects are created + * if switching between concurrent writers frequently. + */ private val currentWriterId = WriterIndex(None, None) /** @@ -428,7 +437,7 @@ class DynamicPartitionDataConcurrentWriter( statsTrackers.foreach(_.newPartition(currentWriterId.partitionValues.get)) } } - retrieveWriterInMap() + setupCurrentWriterUsingMap() } if (description.maxRecordsPerFile > 0 && @@ -471,7 +480,7 @@ class DynamicPartitionDataConcurrentWriter( /** * Retrieve writer in map, or create a new writer if not exists. */ - private def retrieveWriterInMap(): Unit = { + private def setupCurrentWriterUsingMap(): Unit = { if (concurrentWriters.contains(currentWriterId)) { val status = concurrentWriters(currentWriterId) currentWriter = status.outputWriter @@ -493,10 +502,13 @@ class DynamicPartitionDataConcurrentWriter( s" which is beyond max value ${concurrentOutputWriterSpec.maxWriters + 1}") } concurrentWriters.put( - WriterIndex(currentWriterId.partitionValues, currentWriterId.bucketId), + currentWriterId.copy(), new WriterStatus(currentWriter, recordsInFile, fileCounter)) if (concurrentWriters.size >= concurrentOutputWriterSpec.maxWriters && !sorted) { // Fall back to sort-based sequential writer mode. + logInfo(s"Number of concurrent writers ${concurrentWriters.size} reaches the threshold. " + + "Fall back from concurrent writers to sort-based sequential writer. 
You may change " + + s"threshold with configuration ${SQLConf.MAX_CONCURRENT_OUTPUT_FILE_WRITERS.key}") sorted = true } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 37f66e3ed8c32..41d11568750cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.test import java.io.File -import java.util.Locale +import java.util.{Locale, Random} import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.JavaConverters._ @@ -1222,8 +1222,10 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSparkSession with test("SPARK-26164: Allow concurrent writers for multiple partitions and buckets") { withTable("t1", "t2") { + // Uses fixed seed to ensure reproducible test execution + val r = new Random(31) val df = spark.range(200).map(_ => { - val n = scala.util.Random.nextInt + val n = r.nextInt() (n, n.toString, n % 5) }).toDF("k1", "k2", "part") df.write.format("parquet").saveAsTable("t1") From efe026cc797db994787b6a556aee699d6a79d4af Mon Sep 17 00:00:00 2001 From: Cheng Su Date: Mon, 26 Apr 2021 12:23:17 -0700 Subject: [PATCH 10/10] Address comment for renaming increaseFileCounter --- .../sql/execution/datasources/FileFormatDataWriter.scala | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala index 9e4d2e3f6e713..8230737a61ca0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatDataWriter.scala @@ -276,14 +276,13 @@ abstract class BaseDynamicPartitionDataWriter( } /** - * Increase the file counter and open a new OutputWriter. - * This is used when number of records records exceeding limit. + * Open a new output writer when number of records exceeding limit. * * @param partitionValues the partition which all tuples being written by this `OutputWriter` * belong to * @param bucketId the bucket which all tuples being written by this `OutputWriter` belong to */ - protected def increaseFileCounter( + protected def renewCurrentWriterIfTooManyRecords( partitionValues: Option[InternalRow], bucketId: Option[Int]): Unit = { // Exceeded the threshold in terms of the number of records per file. @@ -339,7 +338,7 @@ class DynamicPartitionDataSingleWriter( renewCurrentWriter(currentPartitionValues, currentBucketId, closeCurrentWriter = true) } else if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { - increaseFileCounter(currentPartitionValues, currentBucketId) + renewCurrentWriterIfTooManyRecords(currentPartitionValues, currentBucketId) } writeRecord(record) } @@ -442,7 +441,7 @@ class DynamicPartitionDataConcurrentWriter( if (description.maxRecordsPerFile > 0 && recordsInFile >= description.maxRecordsPerFile) { - increaseFileCounter(currentWriterId.partitionValues, currentWriterId.bucketId) + renewCurrentWriterIfTooManyRecords(currentWriterId.partitionValues, currentWriterId.bucketId) // Update writer status in concurrent writers map, as a new writer is created. updateCurrentWriterStatusInMap() }
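For reference, a hedged end-to-end usage sketch of the feature added by this series. It assumes a running SparkSession named spark and a hypothetical output path; spark.sql.maxConcurrentOutputFileWriters is the key introduced above, while spark.sql.files.maxRecordsPerFile is assumed to be the key behind SQLConf.MAX_RECORDS_PER_FILE used in the test.

// Illustrative sketch only, not part of the patch. Assumes a running
// SparkSession named `spark`; the output path is hypothetical.
import org.apache.spark.sql.functions.col

// Allow up to 3 concurrent output file writers; once that many are open,
// the remaining output is sorted and written sequentially.
spark.conf.set("spark.sql.maxConcurrentOutputFileWriters", 3L)
// A small per-file record cap also exercises the renew-writer-if-too-many-records path.
spark.conf.set("spark.sql.files.maxRecordsPerFile", 100L)

// 5 dynamic partitions, more than the 3 allowed concurrent writers.
val df = spark.range(1000).withColumn("part", col("id") % 5)

df.write
  .mode("overwrite")
  .partitionBy("part")
  .parquet("/tmp/concurrent_writers_demo")

Leaving spark.sql.maxConcurrentOutputFileWriters at its default of 0 keeps the existing behaviour of sorting the output before writing, since the concurrent path is only enabled when the configured maximum is greater than zero and no explicit sort columns are required.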