From ab410fc48a039f161693ea81f8d15dbd3041d57f Mon Sep 17 00:00:00 2001 From: Adi Muraru Date: Sun, 15 Sep 2019 14:44:29 +0000 Subject: [PATCH] [SPARK-27734][CORE][SQL] Add memory based thresholds for shuffle spill When running large shuffles (700TB input data, 200k map tasks, 50k reducers on a 300 nodes cluster) the job is regularly OOMing in map and reduce phase. IIUC ShuffleExternalSorter (map side) and ExternalAppendOnlyMap and ExternalSorter (reduce side) are trying to max out the available execution memory. This in turn doesn't play nice with the Garbage Collector and executors are failing with OutOfMemoryError when the memory allocation from these in-memory structure is maxing out the available heap size (in our case we are running with 9 cores/executor, 32G per executor) To mitigate this, one can set spark.shuffle.spill.numElementsForceSpillThreshold to force the spill on disk. While this config works, it is not flexible enough as it's expressed in number of elements, and in our case we run multiple shuffles in a single job and element size is different from one stage to another. This patch extends the spill threshold behaviour and adds two new parameters to control the spill based on memory usage: - spark.shuffle.spill.map.maxRecordsSizeForSpillThreshold - spark.shuffle.spill.reduce.maxRecordsSizeForSpillThreshold --- .../shuffle/sort/ShuffleExternalSorter.java | 19 ++++++++++-- .../unsafe/sort/UnsafeExternalSorter.java | 25 ++++++++++++--- .../spark/internal/config/package.scala | 20 ++++++++++++ .../spark/util/collection/Spillable.scala | 19 ++++++++---- .../sort/UnsafeExternalSorterSuite.java | 14 ++++++--- .../apache/spark/sql/internal/SQLConf.scala | 31 +++++++++++++++++++ .../execution/UnsafeExternalRowSorter.java | 2 ++ .../UnsafeFixedWidthAggregationMap.java | 2 ++ .../sql/execution/UnsafeKVExternalSorter.java | 8 +++-- .../ExternalAppendOnlyUnsafeRowArray.scala | 11 +++++-- .../aggregate/ObjectAggregationIterator.scala | 1 + .../aggregate/ObjectAggregationMap.scala | 1 + .../joins/CartesianProductExec.scala | 9 ++++-- .../execution/joins/SortMergeJoinExec.scala | 14 ++++++++- .../execution/python/WindowInPandasExec.scala | 4 ++- .../sql/execution/window/WindowExec.scala | 4 ++- ...nalAppendOnlyUnsafeRowArrayBenchmark.scala | 8 +++-- ...xternalAppendOnlyUnsafeRowArraySuite.scala | 3 +- .../UnsafeKVExternalSorterSuite.scala | 5 ++- 19 files changed, 167 insertions(+), 33 deletions(-) diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index 833744f4777ce..0ac3a7b891a64 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -83,6 +83,11 @@ final class ShuffleExternalSorter extends MemoryConsumer { */ private final int numElementsForSpillThreshold; + /** + * Force this sorter to spill when the size in memory is beyond this threshold. + */ + private final long recordsSizeForSpillThreshold; + /** The buffer size to use when writing spills using DiskBlockObjectWriter */ private final int fileBufferSizeBytes; @@ -106,6 +111,7 @@ final class ShuffleExternalSorter extends MemoryConsumer { @Nullable private ShuffleInMemorySorter inMemSorter; @Nullable private MemoryBlock currentPage = null; private long pageCursor = -1; + private long inMemRecordsSize = 0; ShuffleExternalSorter( TaskMemoryManager memoryManager, @@ -127,6 +133,8 @@ final class ShuffleExternalSorter extends MemoryConsumer { (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.numElementsForSpillThreshold = (int) conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()); + this.recordsSizeForSpillThreshold = + (long) conf.get(package$.MODULE$.SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD()); this.writeMetrics = writeMetrics; this.inMemSorter = new ShuffleInMemorySorter( this, initialSize, (boolean) conf.get(package$.MODULE$.SHUFFLE_SORT_USE_RADIXSORT())); @@ -316,6 +324,7 @@ private long freeMemory() { allocatedPages.clear(); currentPage = null; pageCursor = 0; + inMemRecordsSize = 0; return memoryFreed; } @@ -397,11 +406,14 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p // for tests assert(inMemSorter != null); if (inMemSorter.numRecords() >= numElementsForSpillThreshold) { - logger.info("Spilling data because number of spilledRecords crossed the threshold " + - numElementsForSpillThreshold); + logger.info("Spilling data because number of spilledRecords ({}) crossed the threshold: {}", + inMemSorter.numRecords(), numElementsForSpillThreshold); + spill(); + } else if (inMemRecordsSize >= recordsSizeForSpillThreshold) { + logger.info("Spilling data because size of spilledRecords ({}) crossed the threshold: {}", + inMemRecordsSize, recordsSizeForSpillThreshold); spill(); } - growPointerArrayIfNecessary(); final int uaoSize = UnsafeAlignedOffset.getUaoSize(); // Need 4 or 8 bytes to store the record length. @@ -416,6 +428,7 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); pageCursor += length; inMemSorter.insertRecord(recordAddress, partitionId); + inMemRecordsSize += length; } /** diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 55e4e609c3c7b..d63f2fd0ab880 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -74,6 +74,11 @@ public final class UnsafeExternalSorter extends MemoryConsumer { */ private final int numElementsForSpillThreshold; + /** + * Force this sorter to spill when the size in memory is beyond this threshold. + */ + private final long maxRecordsSizeForSpillThreshold; + /** * Memory pages that hold the records being sorted. The pages in this list are freed when * spilling, although in principle we could recycle these pages across spills (on the other hand, @@ -86,6 +91,7 @@ public final class UnsafeExternalSorter extends MemoryConsumer { // These variables are reset after spilling: @Nullable private volatile UnsafeInMemorySorter inMemSorter; + private long inMemRecordsSize = 0; private MemoryBlock currentPage = null; private long pageCursor = -1; @@ -104,10 +110,12 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter( int initialSize, long pageSizeBytes, int numElementsForSpillThreshold, + long maxRecordsSizeForSpillThreshold, UnsafeInMemorySorter inMemorySorter) throws IOException { UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager, - serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize, - pageSizeBytes, numElementsForSpillThreshold, inMemorySorter, false /* ignored */); + serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize, + pageSizeBytes, numElementsForSpillThreshold, maxRecordsSizeForSpillThreshold, + inMemorySorter, false /* ignored */); sorter.spill(Long.MAX_VALUE, sorter); // The external sorter will be used to insert records, in-memory sorter is not needed. sorter.inMemSorter = null; @@ -124,10 +132,11 @@ public static UnsafeExternalSorter create( int initialSize, long pageSizeBytes, int numElementsForSpillThreshold, + long maxRecordsSizeForSpillThreshold, boolean canUseRadixSort) { return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager, taskContext, recordComparatorSupplier, prefixComparator, initialSize, pageSizeBytes, - numElementsForSpillThreshold, null, canUseRadixSort); + numElementsForSpillThreshold, maxRecordsSizeForSpillThreshold, null, canUseRadixSort); } private UnsafeExternalSorter( @@ -140,6 +149,7 @@ private UnsafeExternalSorter( int initialSize, long pageSizeBytes, int numElementsForSpillThreshold, + long maxRecordsSizeForSpillThreshold, @Nullable UnsafeInMemorySorter existingInMemorySorter, boolean canUseRadixSort) { super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode()); @@ -170,6 +180,7 @@ private UnsafeExternalSorter( } this.peakMemoryUsedBytes = getMemoryUsage(); this.numElementsForSpillThreshold = numElementsForSpillThreshold; + this.maxRecordsSizeForSpillThreshold = maxRecordsSizeForSpillThreshold; // Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at // the end of the task. This is necessary to avoid memory leaks in when the downstream operator @@ -228,7 +239,7 @@ public long spill(long size, MemoryConsumer trigger) throws IOException { // Reset the in-memory sorter's pointer array only after freeing up the memory pages holding the // records. Otherwise, if the task is over allocated memory, then without freeing the memory // pages, we might not be able to get memory for the pointer array. - + inMemRecordsSize = 0; taskContext.taskMetrics().incMemoryBytesSpilled(spillSize); taskContext.taskMetrics().incDiskBytesSpilled(writeMetrics.bytesWritten()); totalSpillBytes += spillSize; @@ -396,8 +407,11 @@ public void insertRecord( logger.info("Spilling data because number of spilledRecords crossed the threshold " + numElementsForSpillThreshold); spill(); + } else if (inMemRecordsSize >= maxRecordsSizeForSpillThreshold) { + logger.info("Spilling data because size of spilledRecords crossed the threshold " + + maxRecordsSizeForSpillThreshold); + spill(); } - growPointerArrayIfNecessary(); int uaoSize = UnsafeAlignedOffset.getUaoSize(); // Need 4 or 8 bytes to store the record length. @@ -411,6 +425,7 @@ public void insertRecord( Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length); pageCursor += length; inMemSorter.insertRecord(recordAddress, prefix, prefixIsNull); + inMemRecordsSize += length; } /** diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 652db2bdf90a9..71ad2ecc2e7a6 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -998,6 +998,26 @@ package object config { .intConf .createWithDefault(Integer.MAX_VALUE) + private[spark] val SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD = + ConfigBuilder("spark.shuffle.spill.map.maxRecordsSizeForSpillThreshold") + .internal() + .doc("The maximum size in memory before forcing the map-side shuffle sorter to spill. " + + "By default it is Long.MAX_VALUE, which means we never force the sorter to spill, " + + "until we reach some limitations, like the max page size limitation for the pointer " + + "array in the sorter.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(Long.MaxValue) + + private[spark] val SHUFFLE_SPILL_REDUCE_MAX_SIZE_FORCE_SPILL_THRESHOLD = + ConfigBuilder("spark.shuffle.spill.reduce.maxRecordsSizeForSpillThreshold") + .internal() + .doc("The maximum size in memory before forcing the reduce-side to spill. " + + "By default it is Long.MAX_VALUE, which means we never force the sorter to spill, " + + "until we reach some limitations, like the max page size limitation for the pointer " + + "array in the sorter.") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(Long.MaxValue) + private[spark] val SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD = ConfigBuilder("spark.shuffle.mapOutput.parallelAggregationThreshold") .internal() diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 1983b0002853d..e6df4ef799ae6 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -53,6 +53,10 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) private[this] val initialMemoryThreshold: Long = SparkEnv.get.conf.get(SHUFFLE_SPILL_INITIAL_MEM_THRESHOLD) + // Force this collection to spill when its size is greater than this threshold + private[this] val maxSizeForceSpillThreshold: Long = + SparkEnv.get.conf.get(SHUFFLE_SPILL_REDUCE_MAX_SIZE_FORCE_SPILL_THRESHOLD) + // Force this collection to spill when there are this many elements in memory // For testing only private[this] val numElementsForceSpillThreshold: Int = @@ -81,7 +85,11 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) */ protected def maybeSpill(collection: C, currentMemory: Long): Boolean = { var shouldSpill = false - if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { + // Check number of elements or memory usage limits, whichever is hit first + if (_elementsRead > numElementsForceSpillThreshold + || currentMemory > maxSizeForceSpillThreshold) { + shouldSpill = true + } else if (elementsRead % 32 == 0 && currentMemory >= myMemoryThreshold) { // Claim up to double our current memory from the shuffle memory pool val amountToRequest = 2 * currentMemory - myMemoryThreshold val granted = acquireMemory(amountToRequest) @@ -90,11 +98,10 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) // or we already had more memory than myMemoryThreshold), spill the current collection shouldSpill = currentMemory >= myMemoryThreshold } - shouldSpill = shouldSpill || _elementsRead > numElementsForceSpillThreshold // Actually spill if (shouldSpill) { _spillCount += 1 - logSpillage(currentMemory) + logSpillage(currentMemory, elementsRead) spill(collection) _elementsRead = 0 _memoryBytesSpilled += currentMemory @@ -141,10 +148,10 @@ private[spark] abstract class Spillable[C](taskMemoryManager: TaskMemoryManager) * * @param size number of bytes spilled */ - @inline private def logSpillage(size: Long): Unit = { + @inline private def logSpillage(size: Long, elements: Int) { val threadId = Thread.currentThread().getId - logInfo("Thread %d spilling in-memory map of %s to disk (%d time%s so far)" - .format(threadId, org.apache.spark.util.Utils.bytesToString(size), + logInfo("Thread %d spilling in-memory map of %s (elements: %d) to disk (%d time%s so far)" + .format(threadId, org.apache.spark.util.Utils.bytesToString(size), elements, _spillCount, if (_spillCount > 1) "s" else "")) } } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 43977717f6c97..0b202f98d4583 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -91,9 +91,12 @@ public int compare( private final long pageSizeBytes = conf.getSizeAsBytes( package$.MODULE$.BUFFER_PAGESIZE().key(), "4m"); - private final int spillThreshold = + private final int spillElementsThreshold = (int) conf.get(package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()); + private final long spillSizeThreshold = + (long) conf.get(package$.MODULE$.SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD()); + @Before public void setUp() { MockitoAnnotations.initMocks(this); @@ -167,7 +170,8 @@ private UnsafeExternalSorter newSorter() throws IOException { prefixComparator, /* initialSize */ 1024, pageSizeBytes, - spillThreshold, + spillElementsThreshold, + spillSizeThreshold, shouldUseRadixSort()); } @@ -394,7 +398,8 @@ public void forcedSpillingWithoutComparator() throws Exception { null, /* initialSize */ 1024, pageSizeBytes, - spillThreshold, + spillElementsThreshold, + spillSizeThreshold, shouldUseRadixSort()); long[] record = new long[100]; int recordSize = record.length * 8; @@ -456,7 +461,8 @@ public void testPeakMemoryUsed() throws Exception { prefixComparator, 1024, pageSizeBytes, - spillThreshold, + spillElementsThreshold, + spillSizeThreshold, shouldUseRadixSort()); // Peak memory should be monotonically increasing. More specifically, every time diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 700f6b773727b..ebfca1f88708b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1516,6 +1516,13 @@ object SQLConf { .intConf .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) + val WINDOW_EXEC_BUFFER_SIZE_SPILL_THRESHOLD = + buildConf("spark.sql.windowExec.buffer.spill.size.threshold") + .internal() + .doc("Threshold for size of rows to be spilled by window operator") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(SHUFFLE_SPILL_REDUCE_MAX_SIZE_FORCE_SPILL_THRESHOLD.defaultValue.get) + val SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD = buildConf("spark.sql.sortMergeJoinExec.buffer.in.memory.threshold") .internal() @@ -1531,6 +1538,13 @@ object SQLConf { .intConf .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) + val SORT_MERGE_JOIN_EXEC_BUFFER_SIZE_SPILL_THRESHOLD = + buildConf("spark.sql.sortMergeJoinExec.buffer.spill.size.threshold") + .internal() + .doc("Threshold for size of rows to be spilled by sort merge join operator") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD.defaultValue.get) + val CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD = buildConf("spark.sql.cartesianProductExec.buffer.in.memory.threshold") .internal() @@ -1546,6 +1560,15 @@ object SQLConf { .intConf .createWithDefault(SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) + + val CARTESIAN_PRODUCT_EXEC_BUFFER_SIZE_SPILL_THRESHOLD = + buildConf("spark.sql.cartesianProductExec.buffer.spill.size.threshold") + .internal() + .doc("Threshold for size of rows to be spilled by cartesian product operator") + .bytesConf(ByteUnit.BYTE) + .createWithDefault(SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD.defaultValue.get) + + val SUPPORT_QUOTED_REGEX_COLUMN_NAME = buildConf("spark.sql.parser.quotedRegexColumnNames") .doc("When true, quoted Identifiers (using backticks) in SELECT statement are interpreted" + " as regular expressions.") @@ -2647,18 +2670,26 @@ class SQLConf extends Serializable with Logging { def windowExecBufferSpillThreshold: Int = getConf(WINDOW_EXEC_BUFFER_SPILL_THRESHOLD) + def windowExecBufferSpillSizeThreshold: Long = getConf(WINDOW_EXEC_BUFFER_SIZE_SPILL_THRESHOLD) + def sortMergeJoinExecBufferInMemoryThreshold: Int = getConf(SORT_MERGE_JOIN_EXEC_BUFFER_IN_MEMORY_THRESHOLD) def sortMergeJoinExecBufferSpillThreshold: Int = getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SPILL_THRESHOLD) + def sortMergeJoinExecBufferSpillSizeThreshold: Long = + getConf(SORT_MERGE_JOIN_EXEC_BUFFER_SIZE_SPILL_THRESHOLD) + def cartesianProductExecBufferInMemoryThreshold: Int = getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_IN_MEMORY_THRESHOLD) def cartesianProductExecBufferSpillThreshold: Int = getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SPILL_THRESHOLD) + def cartesianProductExecBufferSizeSpillThreshold: Long = + getConf(CARTESIAN_PRODUCT_EXEC_BUFFER_SIZE_SPILL_THRESHOLD) + def codegenSplitAggregateFunc: Boolean = getConf(SQLConf.CODEGEN_SPLIT_AGGREGATE_FUNC) def maxNestedViewDepth: Int = getConf(SQLConf.MAX_NESTED_VIEW_DEPTH) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 90b55a8586de7..bd5dbc82b07f1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -120,6 +120,8 @@ private UnsafeExternalRowSorter( pageSizeBytes, (int) SparkEnv.get().conf().get( package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()), + (long) SparkEnv.get().conf().get( + package$.MODULE$.SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD()), canUseRadixSort ); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index 117e98f33a0ec..ffd4e67fe5411 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -247,6 +247,8 @@ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOExcepti map.getPageSizeBytes(), (int) SparkEnv.get().conf().get( package$.MODULE$.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD()), + (long) SparkEnv.get().conf().get( + package$.MODULE$.SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD()), map); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index acd54fe25d62d..d0ac80da6a69a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -59,9 +59,10 @@ public UnsafeKVExternalSorter( BlockManager blockManager, SerializerManager serializerManager, long pageSizeBytes, - int numElementsForSpillThreshold) throws IOException { + int numElementsForSpillThreshold, + long maxRecordsSizeForSpillThreshold) throws IOException { this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes, - numElementsForSpillThreshold, null); + numElementsForSpillThreshold, maxRecordsSizeForSpillThreshold, null); } public UnsafeKVExternalSorter( @@ -71,6 +72,7 @@ public UnsafeKVExternalSorter( SerializerManager serializerManager, long pageSizeBytes, int numElementsForSpillThreshold, + long maxRecordsSizeForSpillThreshold, @Nullable BytesToBytesMap map) throws IOException { this.keySchema = keySchema; this.valueSchema = valueSchema; @@ -97,6 +99,7 @@ public UnsafeKVExternalSorter( (int) (long) SparkEnv.get().conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()), pageSizeBytes, numElementsForSpillThreshold, + maxRecordsSizeForSpillThreshold, canUseRadixSort); } else { // During spilling, the pointer array in `BytesToBytesMap` will not be used, so we can borrow @@ -163,6 +166,7 @@ public UnsafeKVExternalSorter( (int) (long) SparkEnv.get().conf().get(package$.MODULE$.SHUFFLE_SORT_INIT_BUFFER_SIZE()), pageSizeBytes, numElementsForSpillThreshold, + maxRecordsSizeForSpillThreshold, inMemSorter); // reset the map, so we can re-use it to insert new records. the inMemSorter will not used diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala index ac282ea2e94f5..1c016d4f82415 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArray.scala @@ -50,9 +50,12 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( initialSize: Int, pageSizeBytes: Long, numRowsInMemoryBufferThreshold: Int, - numRowsSpillThreshold: Int) extends Logging { + numRowsSpillThreshold: Int, + maxSizeSpillThreshold: Long) extends Logging { - def this(numRowsInMemoryBufferThreshold: Int, numRowsSpillThreshold: Int) { + def this(numRowsInMemoryBufferThreshold: Int, + numRowsSpillThreshold: Int, + maxSizeSpillThreshold: Long) { this( TaskContext.get().taskMemoryManager(), SparkEnv.get.blockManager, @@ -61,7 +64,8 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( 1024, SparkEnv.get.memoryManager.pageSizeBytes, numRowsInMemoryBufferThreshold, - numRowsSpillThreshold) + numRowsSpillThreshold, + maxSizeSpillThreshold) } private val initialSizeOfInMemoryBuffer = @@ -122,6 +126,7 @@ private[sql] class ExternalAppendOnlyUnsafeRowArray( initialSize, pageSizeBytes, numRowsSpillThreshold, + maxSizeSpillThreshold, false) // populate with existing in-memory buffered rows diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala index 75651500954cf..b3e8cdcf24089 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationIterator.scala @@ -316,6 +316,7 @@ class SortBasedAggregator( SparkEnv.get.serializerManager, TaskContext.get().taskMemoryManager().pageSizeBytes, SparkEnv.get.conf.get(config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD), + SparkEnv.get.conf.get(config.SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD), null ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala index b5372bcca89dd..31676473ec8ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectAggregationMap.scala @@ -75,6 +75,7 @@ class ObjectAggregationMap() { SparkEnv.get.serializerManager, TaskContext.get().taskMemoryManager().pageSizeBytes, SparkEnv.get.conf.get(config.SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD), + SparkEnv.get.conf.get(config.SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD), null ) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala index 29645a736548c..a6476aa4b80ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala @@ -36,11 +36,13 @@ class UnsafeCartesianRDD( right : RDD[UnsafeRow], numFieldsOfRight: Int, inMemoryBufferThreshold: Int, - spillThreshold: Int) + spillThreshold: Int, + spillSizeThreshold: Long) extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { override def compute(split: Partition, context: TaskContext): Iterator[(UnsafeRow, UnsafeRow)] = { - val rowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryBufferThreshold, spillThreshold) + val rowArray = new ExternalAppendOnlyUnsafeRowArray(inMemoryBufferThreshold, spillThreshold, + spillSizeThreshold) val partition = split.asInstanceOf[CartesianPartition] rdd2.iterator(partition.s2, context).foreach(rowArray.add) @@ -88,7 +90,8 @@ case class CartesianProductExec( rightResults, right.output.size, sqlContext.conf.cartesianProductExecBufferInMemoryThreshold, - sqlContext.conf.cartesianProductExecBufferSpillThreshold) + sqlContext.conf.cartesianProductExecBufferSpillThreshold, + sqlContext.conf.cartesianProductExecBufferSizeSpillThreshold) pair.mapPartitionsWithIndexInternal { (index, iter) => val joiner = GenerateUnsafeRowJoiner.create(left.schema, right.schema) val filtered = if (condition.isDefined) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 62eea611556ff..59197c6740bc1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -171,6 +171,10 @@ case class SortMergeJoinExec( sqlContext.conf.sortMergeJoinExecBufferSpillThreshold } + private def getSpillSizeThreshold: Long = { + sqlContext.conf.sortMergeJoinExecBufferSpillSizeThreshold + } + private def getInMemoryThreshold: Int = { sqlContext.conf.sortMergeJoinExecBufferInMemoryThreshold } @@ -178,6 +182,7 @@ case class SortMergeJoinExec( protected override def doExecute(): RDD[InternalRow] = { val numOutputRows = longMetric("numOutputRows") val spillThreshold = getSpillThreshold + val spillSizeThreshold = getSpillSizeThreshold val inMemoryThreshold = getInMemoryThreshold left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => val boundCondition: (InternalRow) => Boolean = { @@ -206,6 +211,7 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, + spillSizeThreshold, cleanupResources ) private[this] val joinRow = new JoinedRow @@ -251,6 +257,7 @@ case class SortMergeJoinExec( bufferedIter = RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, + spillSizeThreshold, cleanupResources ) val rightNullRow = new GenericInternalRow(right.output.length) @@ -266,6 +273,7 @@ case class SortMergeJoinExec( bufferedIter = RowIterator.fromScala(leftIter), inMemoryThreshold, spillThreshold, + spillSizeThreshold, cleanupResources ) val leftNullRow = new GenericInternalRow(left.output.length) @@ -301,6 +309,7 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, + spillSizeThreshold, cleanupResources ) private[this] val joinRow = new JoinedRow @@ -337,6 +346,7 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, + spillSizeThreshold, cleanupResources ) private[this] val joinRow = new JoinedRow @@ -380,6 +390,7 @@ case class SortMergeJoinExec( RowIterator.fromScala(rightIter), inMemoryThreshold, spillThreshold, + spillSizeThreshold, cleanupResources ) private[this] val joinRow = new JoinedRow @@ -712,6 +723,7 @@ private[joins] class SortMergeJoinScanner( bufferedIter: RowIterator, inMemoryThreshold: Int, spillThreshold: Int, + spillSizeThreshold: Long, eagerCleanupResources: () => Unit) { private[this] var streamedRow: InternalRow = _ private[this] var streamedRowKey: InternalRow = _ @@ -724,7 +736,7 @@ private[joins] class SortMergeJoinScanner( private[this] var matchJoinKey: InternalRow = _ /** Buffered rows from the buffered side of the join. This is empty if there are no matches. */ private[this] val bufferedMatches = - new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold, spillSizeThreshold) // Initialization (note: do _not_ want to advance streamed here). advancedBufferedToRowWithNullFreeJoinKey() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index f54c4b8f22066..7102456ca1b05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -199,6 +199,7 @@ case class WindowInPandasExec( val inMemoryThreshold = conf.windowExecBufferInMemoryThreshold val spillThreshold = conf.windowExecBufferSpillThreshold + val spillSizeThreshold = conf.windowExecBufferSpillSizeThreshold val sessionLocalTimeZone = conf.sessionLocalTimeZone // Extract window expressions and window functions @@ -318,7 +319,8 @@ case class WindowInPandasExec( // Manage the current partition. val buffer: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold, + spillSizeThreshold) var bufferIterator: Iterator[UnsafeRow] = _ val indexRow = new SpecificInternalRow(Array.fill(numBoundIndices)(IntegerType)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala index d191f3790ffa8..18efae24e96a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/window/WindowExec.scala @@ -110,6 +110,7 @@ case class WindowExec( val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray val inMemoryThreshold = sqlContext.conf.windowExecBufferInMemoryThreshold val spillThreshold = sqlContext.conf.windowExecBufferSpillThreshold + val spillSizeThreshold = sqlContext.conf.windowExecBufferSpillSizeThreshold // Start processing. child.execute().mapPartitions { stream => @@ -137,7 +138,8 @@ case class WindowExec( // Manage the current partition. val buffer: ExternalAppendOnlyUnsafeRowArray = - new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold) + new ExternalAppendOnlyUnsafeRowArray(inMemoryThreshold, spillThreshold, + spillSizeThreshold) var bufferIterator: Iterator[UnsafeRow] = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala index 0869e25674e69..6f66104b2c82b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArrayBenchmark.scala @@ -103,7 +103,8 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark extends BenchmarkBase { for (_ <- 0L until iterations) { val array = new ExternalAppendOnlyUnsafeRowArray( ExternalAppendOnlyUnsafeRowArray.DefaultInitialSizeOfInMemoryBuffer, - numSpillThreshold) + numSpillThreshold, + Long.MaxValue) rows.foreach(x => array.add(x)) @@ -142,6 +143,7 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark extends BenchmarkBase { 1024, SparkEnv.get.memoryManager.pageSizeBytes, numSpillThreshold, + Long.MaxValue, false) rows.foreach(x => @@ -166,7 +168,9 @@ object ExternalAppendOnlyUnsafeRowArrayBenchmark extends BenchmarkBase { benchmark.addCase("ExternalAppendOnlyUnsafeRowArray") { _: Int => var sum = 0L for (_ <- 0L until iterations) { - val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold, numSpillThreshold) + val array = new ExternalAppendOnlyUnsafeRowArray(numSpillThreshold, + numSpillThreshold, + Long.MaxValue) rows.foreach(x => array.add(x)) val iterator = array.generateIterator() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala index b29de9c4adbaa..fc255426c2699 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExternalAppendOnlyUnsafeRowArraySuite.scala @@ -50,7 +50,8 @@ class ExternalAppendOnlyUnsafeRowArraySuite extends SparkFunSuite with LocalSpar 1024, SparkEnv.get.memoryManager.pageSizeBytes, inMemoryThreshold, - spillThreshold) + spillThreshold, + Long.MaxValue) try f(array) finally { array.clear() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 8aa003a3dfeb0..48560311f0ce7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -127,7 +127,9 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession val sorter = new UnsafeKVExternalSorter( keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager, - pageSize, SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get) + pageSize, SHUFFLE_SPILL_NUM_ELEMENTS_FORCE_SPILL_THRESHOLD.defaultValue.get, + SHUFFLE_SPILL_MAP_MAX_SIZE_FORCE_SPILL_THRESHOLD.defaultValue.get + ) // Insert the keys and values into the sorter inputData.foreach { case (k, v) => @@ -240,6 +242,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSparkSession sparkContext.env.serializerManager, taskMemoryManager.pageSizeBytes(), Int.MaxValue, + Long.MaxValue, map) } finally { TaskContext.unset()