diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 93a2bbe25157b..c436025e06bb8 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -89,7 +89,8 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
     val aggregator: Option[Aggregator[K, V, C]] = None,
     val mapSideCombine: Boolean = false,
     val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor,
-    val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS)
+    val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS,
+    val checksumMismatchFullRetryEnabled: Boolean = false)
   extends Dependency[Product2[K, V]] with Logging {

   def this(
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 3f823b60156ad..334eb832c4c2b 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -165,9 +165,11 @@ private class ShuffleStatus(

   /**
    * Register a map output. If there is already a registered location for the map output then it
-   * will be replaced by the new location.
+   * will be replaced by the new location. Returns true if the checksum in the new MapStatus is
+   * different from a previously registered MapStatus. Otherwise, returns false.
    */
-  def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock {
+  def addMapOutput(mapIndex: Int, status: MapStatus): Boolean = withWriteLock {
+    var isChecksumMismatch: Boolean = false
     val currentMapStatus = mapStatuses(mapIndex)
     if (currentMapStatus == null) {
       _numAvailableMapOutputs += 1
@@ -183,9 +185,11 @@ private class ShuffleStatus(
       logInfo(s"Checksum of map output changes from ${preStatus.checksumValue} to " +
         s"${status.checksumValue} for task ${status.mapId}.")
       checksumMismatchIndices.add(mapIndex)
+      isChecksumMismatch = true
     }
     mapStatuses(mapIndex) = status
     mapIdToMapIndex(status.mapId) = mapIndex
+    isChecksumMismatch
   }

   /**
@@ -853,7 +857,7 @@ private[spark] class MapOutputTrackerMaster(
     }
   }

-  def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Unit = {
+  def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Boolean = {
     shuffleStatuses(shuffleId).addMapOutput(mapIndex, status)
   }

diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 117b2925710d3..d1408ee774ce8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1773,7 +1773,7 @@ abstract class RDD[T: ClassTag](
   /**
    * Return whether this RDD is reliably checkpointed and materialized.
   */
-  private[rdd] def isReliablyCheckpointed: Boolean = {
+  private[spark] def isReliablyCheckpointed: Boolean = {
     checkpointData match {
       case Some(reliable: ReliableRDDCheckpointData[_]) if reliable.isCheckpointed => true
       case _ => false
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 30eb49b0c0798..3b719a2c7d24c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1551,29 +1551,46 @@ private[spark] class DAGScheduler(
     // The operation here can make sure for the partially completed intermediate stage,
     // `findMissingPartitions()` returns all partitions every time.
     stage match {
-      case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable =>
-        // already executed at least once
-        if (sms.getNextAttemptId > 0) {
-          // While we previously validated possible rollbacks during the handling of a FetchFailure,
-          // where we were fetching from an indeterminate source map stages, this later check
-          // covers additional cases like recalculating an indeterminate stage after an executor
-          // loss. Moreover, because this check occurs later in the process, if a result stage task
-          // has successfully completed, we can detect this and abort the job, as rolling back a
-          // result stage is not possible.
-          val stagesToRollback = collectSucceedingStages(sms)
-          abortStageWithInvalidRollBack(stagesToRollback)
-          // stages which cannot be rolled back were aborted which leads to removing the
-          // the dependant job(s) from the active jobs set
-          val numActiveJobsWithStageAfterRollback =
-            activeJobs.count(job => stagesToRollback.contains(job.finalStage))
-          if (numActiveJobsWithStageAfterRollback == 0) {
-            logInfo(log"All jobs depending on the indeterminate stage " +
-              log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.")
-            return
+      case sms: ShuffleMapStage if !sms.isAvailable =>
+        val needFullStageRetry = if (sms.shuffleDep.checksumMismatchFullRetryEnabled) {
+          // When the parents of this stage are indeterminate (e.g., some parents are not
+          // checkpointed and checksum mismatches are detected), the output data of the parents
+          // may have changed due to task retries. For correctness reasons, we need to
+          // retry all tasks of the current stage. The legacy way of using the current stage's
+          // deterministic level to trigger a full stage retry is not accurate.
+          stage.isParentIndeterminate
+        } else {
+          if (stage.isIndeterminate) {
+            // already executed at least once
+            if (sms.getNextAttemptId > 0) {
+              // While we previously validated possible rollbacks during the handling of a FetchFailure,
+              // where we were fetching from an indeterminate source map stages, this later check
+              // covers additional cases like recalculating an indeterminate stage after an executor
+              // loss. Moreover, because this check occurs later in the process, if a result stage task
+              // has successfully completed, we can detect this and abort the job, as rolling back a
+              // result stage is not possible.
+              val stagesToRollback = collectSucceedingStages(sms)
+              abortStageWithInvalidRollBack(stagesToRollback)
+              // stages which cannot be rolled back were aborted which leads to removing the
+              // the dependant job(s) from the active jobs set
+              val numActiveJobsWithStageAfterRollback =
+                activeJobs.count(job => stagesToRollback.contains(job.finalStage))
+              if (numActiveJobsWithStageAfterRollback == 0) {
+                logInfo(log"All jobs depending on the indeterminate stage " +
+                  log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.")
+                return
+              }
+            }
+            true
+          } else {
+            false
+          }
         }
-        mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
-        sms.shuffleDep.newShuffleMergeState()
+
+        if (needFullStageRetry) {
+          mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId)
+          sms.shuffleDep.newShuffleMergeState()
+        }
       case _ =>
     }

@@ -1886,6 +1903,20 @@ private[spark] class DAGScheduler(
     }
   }

+  /**
+   * If a map stage is non-deterministic, its map tasks may return different results when they
+   * are retried. To ensure data correctness, we need to retry all the tasks of its succeeding
+   * stages, as their input data may have changed after the map tasks were retried. For stages
+   * where rolling back and retrying all tasks is not possible, we need to abort the stages.
+   */
+  private[scheduler] def abortUnrollbackableStages(mapStage: ShuffleMapStage): Unit = {
+    val stagesToRollback = collectSucceedingStages(mapStage)
+    val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback)
+    logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output " +
+      log"was failed, we will roll back and rerun below stages which include itself and all its " +
+      log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
+  }
+
   /**
    * Responds to a task finishing. This is called inside the event loop so it assumes that it can
    * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
@@ -2022,8 +2053,26 @@ private[spark] class DAGScheduler(
               // The epoch of the task is acceptable (i.e., the task was launched after the most
               // recent failure we're aware of for the executor), so mark the task's output as
               // available.
-              mapOutputTracker.registerMapOutput(
+              val isChecksumMismatched = mapOutputTracker.registerMapOutput(
                 shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
+              if (isChecksumMismatched) {
+                shuffleStage.isChecksumMismatched = isChecksumMismatched
+                // There could be multiple checksum mismatches detected for a single stage attempt.
+                // We check for stage abortion once and only once when we first detect a checksum
+                // mismatch for each stage attempt. For example, assume that we have
+                // stage1 -> stage2, and we encounter a checksum mismatch during the retry of stage1.
+                // In this case, we need to call abortUnrollbackableStages() for the succeeding
+                // stages. Assume that when stage2 is retried, some tasks finish and some tasks
+                // fail again with FetchFailed. If we encounter a checksum mismatch again
+                // during the retry of stage1, we need to call abortUnrollbackableStages() again.
+                if (shuffleStage.maxChecksumMismatchedId < smt.stageAttemptId) {
+                  shuffleStage.maxChecksumMismatchedId = smt.stageAttemptId
+                  if (shuffleStage.shuffleDep.checksumMismatchFullRetryEnabled
+                    && shuffleStage.isStageIndeterminate) {
+                    abortUnrollbackableStages(shuffleStage)
+                  }
+                }
+              }
             }
           } else {
             logInfo(log"Ignoring ${MDC(TASK_NAME, smt)} completion from an older attempt of indeterminate stage")
@@ -2148,12 +2197,8 @@ private[spark] class DAGScheduler(
             // Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is
             // guaranteed to be determinate, so the input data of the reducers will not change
            // even if the map tasks are re-tried.
-            if (mapStage.isIndeterminate) {
-              val stagesToRollback = collectSucceedingStages(mapStage)
-              val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback)
-              logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output was failed, " +
-                log"we will roll back and rerun below stages which include itself and all its " +
-                log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}")
+            if (mapStage.isIndeterminate && !mapStage.shuffleDep.checksumMismatchFullRetryEnabled) {
+              abortUnrollbackableStages(mapStage)
             }

             // We expect one executor failure to trigger many FetchFailures in rapid succession,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index f35beafd87480..9bf604e9a83cf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -72,6 +72,18 @@ private[scheduler] abstract class Stage(
   private var nextAttemptId: Int = 0

   private[scheduler] def getNextAttemptId: Int = nextAttemptId

+  /**
+   * Whether checksum mismatches have been detected across different attempts of the stage. A
+   * checksum mismatch typically indicates that different stage attempts have produced different
+   * data.
+   */
+  private[scheduler] var isChecksumMismatched: Boolean = false
+
+  /**
+   * The maximum stage attempt id for which a checksum mismatch has been detected.
+   */
+  private[scheduler] var maxChecksumMismatchedId: Int = nextAttemptId
+
   val name: String = callSite.shortForm
   val details: String = callSite.longForm

@@ -131,4 +143,14 @@ private[scheduler] abstract class Stage(
   def isIndeterminate: Boolean = {
     rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE
   }
+
+  // Returns true if any parents of this stage are indeterminate.
+  def isParentIndeterminate: Boolean = {
+    parents.exists(_.isStageIndeterminate)
+  }
+
+  // Returns true if the stage itself is indeterminate.
+  def isStageIndeterminate: Boolean = {
+    !rdd.isReliablyCheckpointed && isChecksumMismatched
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 1ada81cbdd0ee..c20866fda0a3b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -3415,6 +3415,19 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assertDataStructuresEmpty()
   }

+  private def checkAndCompleteRetryStage(
+      taskSetIndex: Int,
+      stageId: Int,
+      shuffleId: Int,
+      numTasks: Int = 2,
+      checksumVal: Long = 0): Unit = {
+    assert(taskSets(taskSetIndex).stageId == stageId)
+    assert(taskSets(taskSetIndex).stageAttemptId == 1)
+    assert(taskSets(taskSetIndex).tasks.length == numTasks)
+    completeShuffleMapStageSuccessfully(stageId, 1, 2, checksumVal = checksumVal)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
+  }
+
   test("SPARK-25341: continuous indeterminate stage roll back") {
     // shuffleMapRdd1/2/3 are all indeterminate.
     val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true)
@@ -3454,17 +3467,6 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assert(scheduler.failedStages.toSeq.map(_.id) == Seq(1, 2))
     scheduler.resubmitFailedStages()

-    def checkAndCompleteRetryStage(
-        taskSetIndex: Int,
-        stageId: Int,
-        shuffleId: Int): Unit = {
-      assert(taskSets(taskSetIndex).stageId == stageId)
-      assert(taskSets(taskSetIndex).stageAttemptId == 1)
-      assert(taskSets(taskSetIndex).tasks.length == 2)
-      completeShuffleMapStageSuccessfully(stageId, 1, 2)
-      assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
-    }
-
     // Check all indeterminate stage roll back.
     checkAndCompleteRetryStage(3, 0, shuffleId1)
     checkAndCompleteRetryStage(4, 1, shuffleId2)
@@ -3477,6 +3479,253 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assertDataStructuresEmpty()
   }

+  // Construct the scenario of stages with checksum mismatches and FetchFailed.
+  private def constructChecksumMismatchStageFetchFailed(): (Int, Int) = {
+    val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+
+    val shuffleDep1 = new ShuffleDependency(
+      shuffleMapRdd1,
+      new HashPartitioner(2),
+      checksumMismatchFullRetryEnabled = true
+    )
+    val shuffleId1 = shuffleDep1.shuffleId
+    val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+
+    val shuffleDep2 = new ShuffleDependency(
+      shuffleMapRdd2,
+      new HashPartitioner(2),
+      checksumMismatchFullRetryEnabled = true
+    )
+    val shuffleId2 = shuffleDep2.shuffleId
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    // Finish the first shuffle map stage.
+    completeShuffleMapStageSuccessfully(
+      0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty))
+
+    // The first task of the second shuffle map stage failed with FetchFailed.
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0, "ignored"),
+      null))
+
+    // Finish the second task of the second shuffle map stage.
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(1), Success, makeMapStatus("hostB", 2),
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
+
+    (shuffleId1, shuffleId2)
+  }
+
+  // Construct the scenario of stages with checksum mismatches and FetchFailed.
+  // This function assumes that the input `mapRdd` has a single stage with 2 partitions.
+  private def constructChecksumMismatchStageFetchFailed(mapRdd: MyRDD): Unit = {
+    val shuffleDep = new ShuffleDependency(
+      mapRdd,
+      new HashPartitioner(2),
+      checksumMismatchFullRetryEnabled = true
+    )
+    val shuffleId = shuffleDep.shuffleId
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1))
+
+    completeShuffleMapStageSuccessfully(
+      0, 0, numShufflePartitions = 2, Seq("hostA", "hostB"), checksumVal = 100)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
+
+    // Fail the first task of the result stage with FetchFailed.
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"),
+      null))
+
+    // Finish the second task of the result stage.
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(1), Success, 42,
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(0)))
+
+    // Check status for all failedStages.
+    val failedStages = scheduler.failedStages.toSeq
+    // Shuffle blocks of "hostA" are lost, so the first tasks of the shuffle map stage and
+    // the result stage need to retry.
+    assert(failedStages.map(_.id) == Seq(0, 1))
+    assert(failedStages.forall(_.findMissingPartitions() == Seq(0)))
+
+    scheduler.resubmitFailedStages()
+
+    // First shuffle map stage reran failed tasks with a different checksum.
+    completeShuffleMapStageSuccessfully(0, 1, 2, checksumVal = 101)
+  }
+
+  private def assertChecksumMismatchResultStageFailToRollback(mapRdd: MyRDD): Unit = {
+    constructChecksumMismatchStageFetchFailed(mapRdd)
+
+    // The job should fail because Spark can't rollback the result stage.
+    assert(failure != null && failure.getMessage.contains("Spark cannot rollback"))
+  }
+
+  private def assertChecksumMismatchResultStageNotRolledBack(mapRdd: MyRDD): Unit = {
+    constructChecksumMismatchStageFetchFailed(mapRdd)
+
+    assert(failure == null, "job should not fail")
+    // Result stage succeeded, the whole job ended.
+    complete(taskSets(3), Seq((Success, 41)))
+    assert(results === Map(0 -> 41, 1 -> 42))
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-53575: abort stage while using old fetch protocol") {
+    conf.set(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL.key, "true")
+    constructChecksumMismatchStageFetchFailed()
+
+    scheduler.resubmitFailedStages()
+    completeShuffleMapStageSuccessfully(0, 1, 2, checksumVal = 101)
+
+    // The job should fail because Spark can't rollback the shuffle map stage while
+    // using the old protocol.
+    assert(failure != null && failure.getMessage.contains(
+      "Spark can only do this while using the new shuffle block fetching protocol"))
+  }
+
+  test("SPARK-53575: retry all the succeeding stages when the map stage has checksum mismatches") {
+    val (shuffleId1, shuffleId2) =
+      constructChecksumMismatchStageFetchFailed()
+
+    // Check status for all failedStages.
+    val failedStages = scheduler.failedStages.toSeq
+    // Shuffle blocks of "hostA" are lost, so the first tasks of `shuffleMapRdd1` and
+    // `shuffleMapRdd2` need to retry.
+    assert(failedStages.map(_.id) == Seq(0, 1))
+    assert(failedStages.forall(_.findMissingPartitions() == Seq(0)))
+
+    scheduler.resubmitFailedStages()
+
+    // First shuffle map stage reran failed tasks with a different checksum.
+    checkAndCompleteRetryStage(2, 0, shuffleId1, numTasks = 1, checksumVal = 101)
+
+    // Second shuffle map stage reran all tasks.
+    checkAndCompleteRetryStage(3, 1, shuffleId2, numTasks = 2)
+
+    complete(taskSets(4), Seq((Success, 11), (Success, 12)))
+
+    // Job ended successfully.
+    assert(results === Map(0 -> 11, 1 -> 12))
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-53575: continuous checksum mismatch stage roll back") {
+    // shuffleMapRdd1/2 have checksum mismatches, and shuffleMapRdd2/3 require full stage retries.
+    val shuffleMapRdd1 = new MyRDD(sc, 2, Nil)
+    val shuffleDep1 = new ShuffleDependency(
+      shuffleMapRdd1,
+      new HashPartitioner(2),
+      checksumMismatchFullRetryEnabled = true
+    )
+    val shuffleId1 = shuffleDep1.shuffleId
+
+    val shuffleMapRdd2 = new MyRDD(
+      sc, 2, List(shuffleDep1), tracker = mapOutputTracker)
+    val shuffleDep2 = new ShuffleDependency(
+      shuffleMapRdd2,
+      new HashPartitioner(2),
+      checksumMismatchFullRetryEnabled = true
+    )
+    val shuffleId2 = shuffleDep2.shuffleId
+
+    val shuffleMapRdd3 = new MyRDD(
+      sc, 2, List(shuffleDep2), tracker = mapOutputTracker)
+    val shuffleDep3 = new ShuffleDependency(
+      shuffleMapRdd3,
+      new HashPartitioner(2),
+      checksumMismatchFullRetryEnabled = true
+    )
+    val shuffleId3 = shuffleDep3.shuffleId
+    val finalRdd = new MyRDD(sc, 2, List(shuffleDep3), tracker = mapOutputTracker)
+
+    submit(finalRdd, Array(0, 1), properties = new Properties())
+
+    // Finish the first 2 shuffle map stages.
+    completeShuffleMapStageSuccessfully(0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty))
+    completeShuffleMapStageSuccessfully(1, 0, 2, Seq("hostA", "hostB"), checksumVal = 200)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty))
+
+    // Fail the first task of the third shuffle map stage with FetchFailed.
+    runEvent(makeCompletionEvent(
+      taskSets(2).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleId2, 0L, 0, 0, "ignored"),
+      null))
+
+    // Finish the second task of the third shuffle map stage.
+    runEvent(makeCompletionEvent(
+      taskSets(2).tasks(1), Success, makeMapStatus("hostB", 2),
+      Seq.empty, Array.empty, createFakeTaskInfoWithId(1)))
+    mapOutputTracker.removeOutputsOnHost("hostA")
+
+    // Check status for all failedStages.
+    val failedStages = scheduler.failedStages.toSeq
+    // Shuffle blocks of "hostA" are lost, so the first tasks of `shuffleMapRdd2` and
+    // `shuffleMapRdd3` need to retry.
+    assert(failedStages.map(_.id) == Seq(1, 2))
+    assert(failedStages.forall(_.findMissingPartitions() == Seq(0)))
+
+    scheduler.resubmitFailedStages()
+
+    // First shuffle map stage reran failed tasks with a different checksum.
+    checkAndCompleteRetryStage(3, 0, shuffleId1, numTasks = 1, checksumVal = 101)
+    // Second and third shuffle map stages reran all tasks with a different checksum.
+    checkAndCompleteRetryStage(4, 1, shuffleId2, numTasks = 2, checksumVal = 201)
+    checkAndCompleteRetryStage(5, 2, shuffleId3, numTasks = 2, checksumVal = 301)
+    // Result stage succeeded, the whole job ended.
+    complete(taskSets(6), Seq((Success, 11), (Success, 12)))
+    assert(results === Map(0 -> 11, 1 -> 12))
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-53575: cannot rollback a result stage") {
+    val shuffleMapRdd = new MyRDD(sc, 2, Nil)
+    assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd)
+  }
+
+  test("SPARK-53575: local checkpoint fail to rollback (checkpointed before)") {
+    val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil)
+    shuffleMapRdd.localCheckpoint()
+    shuffleMapRdd.doCheckpoint()
+    assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd)
+  }
+
+  test("SPARK-53575: local checkpoint fail to rollback (checkpointing now)") {
+    val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil)
+    shuffleMapRdd.localCheckpoint()
+    assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd)
+  }
+
+  test("SPARK-53575: reliable checkpoint can avoid rollback (checkpointed before)") {
+    withTempDir { dir =>
+      sc.setCheckpointDir(dir.getCanonicalPath)
+      val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil)
+      shuffleMapRdd.checkpoint()
+      shuffleMapRdd.doCheckpoint()
+      assertChecksumMismatchResultStageNotRolledBack(shuffleMapRdd)
+    }
+  }
+
+  test("SPARK-53575: reliable checkpoint fail to rollback (checkpointing now)") {
+    withTempDir { dir =>
+      sc.setCheckpointDir(dir.getCanonicalPath)
+      val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil)
+      shuffleMapRdd.checkpoint()
+      assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd)
+    }
+  }
+
   test("SPARK-29042: Sampled RDD with unordered input should be indeterminate") {
     val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = false)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 17b8dd493cf80..477d09d29a051 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -890,6 +890,14 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)

+  private[spark] val SHUFFLE_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED =
+    buildConf("spark.sql.shuffle.orderIndependentChecksum.enableFullRetryOnMismatch")
+      .doc("Whether to retry all tasks of a consumer stage when checksum mismatches are " +
+        "detected in its producer stages.")
+      .version("4.1.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE =
     buildConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize")
       .internal()
@@ -6651,6 +6659,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
   def shuffleOrderIndependentChecksumEnabled: Boolean =
     getConf(SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED)

+  def shuffleChecksumMismatchFullRetryEnabled: Boolean =
+    getConf(SHUFFLE_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED)
+
   def allowCollationsInMapKeys: Boolean = getConf(ALLOW_COLLATIONS_IN_MAP_KEYS)

   def objectLevelCollationsEnabled: Boolean = getConf(OBJECT_LEVEL_COLLATIONS_ENABLED)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
index 9c86bbb606a57..f052bd9068805 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala
@@ -480,19 +480,22 @@ object ShuffleExchangeExec {
       // Now, we manually create a ShuffleDependency. Because pairs in rddWithPartitionIds
       // are in the form of (partitionId, row) and every partitionId is in the expected range
       // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough.
-      val checksumSize =
-        if (SQLConf.get.shuffleOrderIndependentChecksumEnabled) {
+      val checksumSize = {
+        if (SQLConf.get.shuffleOrderIndependentChecksumEnabled ||
+            SQLConf.get.shuffleChecksumMismatchFullRetryEnabled) {
           part.numPartitions
         } else {
           0
         }
+      }
       val dependency =
         new ShuffleDependency[Int, InternalRow, InternalRow](
           rddWithPartitionIds,
           new PartitionIdPassthrough(part.numPartitions),
           serializer,
           shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics),
-          rowBasedChecksums = UnsafeRowChecksum.createUnsafeRowChecksums(checksumSize))
+          rowBasedChecksums = UnsafeRowChecksum.createUnsafeRowChecksums(checksumSize),
+          checksumMismatchFullRetryEnabled = SQLConf.get.shuffleChecksumMismatchFullRetryEnabled)
       dependency
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala
index 0fe6603122103..abcd346c32775 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala
@@ -25,7 +25,6 @@ import org.apache.spark.sql.test.SQLTestUtils
 class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils {
   override def spark: SparkSession = SparkSession.builder()
     .master("local")
-    .config(SQLConf.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, value = true)
     .config(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key, value = 5)
     .config(SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key, value = false)
     .getOrCreate()
@@ -39,26 +38,34 @@ class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils {
   }

   test("Propagate checksum from executor to driver") {
-    assert(spark.sparkContext.conf
-      .get("spark.sql.shuffle.orderIndependentChecksum.enabled") == "true")
-    assert(spark.conf.get("spark.sql.shuffle.orderIndependentChecksum.enabled") == "true")
-    assert(spark.sparkContext.conf.get("spark.sql.leafNodeDefaultParallelism") == "5")
-    assert(spark.conf.get("spark.sql.leafNodeDefaultParallelism") == "5")
-    assert(spark.sparkContext.conf.get("spark.sql.classic.shuffleDependency.fileCleanup.enabled")
+    assert(spark.sparkContext.conf.get(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key) == "5")
+    assert(spark.conf.get(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key) == "5")
+    assert(spark.sparkContext.conf.get(SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key)
       == "false")
-    assert(spark.conf.get("spark.sql.classic.shuffleDependency.fileCleanup.enabled") == "false")
+    assert(spark.conf.get(SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key) == "false")

-    withTable("t") {
-      spark.range(1000).repartition(10).write.mode("overwrite").
-        saveAsTable("t")
-    }
+    var shuffleId = 0
+    Seq(("true", "false"), ("false", "true"), ("true", "true")).foreach {
+      case (orderIndependentChecksumEnabled: String, checksumMismatchFullRetryEnabled: String) =>
+        withSQLConf(
+          SQLConf.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key ->
+            orderIndependentChecksumEnabled,
+          SQLConf.SHUFFLE_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key ->
+            checksumMismatchFullRetryEnabled) {
+          withTable("t") {
+            spark.range(1000).repartition(10).write.mode("overwrite").
+              saveAsTable("t")
+          }

-    val shuffleStatuses = spark.sparkContext.env.mapOutputTracker.
-      asInstanceOf[MapOutputTrackerMaster].shuffleStatuses
-    assert(shuffleStatuses.size == 1)
+          val shuffleStatuses = spark.sparkContext.env.mapOutputTracker.
+            asInstanceOf[MapOutputTrackerMaster].shuffleStatuses
+          assert(shuffleStatuses.contains(shuffleId))

-    val mapStatuses = shuffleStatuses(0).mapStatuses
-    assert(mapStatuses.length == 5)
-    assert(mapStatuses.forall(_.checksumValue != 0))
+          val mapStatuses = shuffleStatuses(shuffleId).mapStatuses
+          assert(mapStatuses.length == 5)
+          assert(mapStatuses.forall(_.checksumValue != 0))
+          shuffleId += 1
+        }
+    }
   }
 }
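
Taken together, the scheduler change reduces to a small decision rule: with spark.sql.shuffle.orderIndependentChecksum.enableFullRetryOnMismatch set to true (it defaults to false and is plumbed into ShuffleDependency.checksumMismatchFullRetryEnabled by ShuffleExchangeExec), a missing shuffle map stage is fully retried only when one of its parent stages is effectively indeterminate, i.e. the parent is not reliably checkpointed and a checksum mismatch was observed across its attempts; with the flag off, the legacy rdd-level deterministic check still applies. Below is a minimal standalone Scala sketch of that rule for reference; the StageInfo case class and the free-standing functions are simplified stand-ins invented for illustration, and only the field and method names mirror the diff.

// Sketch only: simplified stand-ins for the Stage/ShuffleMapStage state touched by this diff.
object ChecksumFullRetrySketch {
  // Stand-in for the per-stage state: rdd.isReliablyCheckpointed and Stage.isChecksumMismatched.
  final case class StageInfo(reliablyCheckpointed: Boolean, checksumMismatched: Boolean)

  // Mirrors Stage.isStageIndeterminate: only a non-checkpointed stage with an observed
  // checksum mismatch is treated as indeterminate for full-retry purposes.
  def isStageIndeterminate(s: StageInfo): Boolean =
    !s.reliablyCheckpointed && s.checksumMismatched

  // Mirrors the needFullStageRetry branch added to DAGScheduler: the flag switches between
  // the parent-driven (checksum-based) decision and the legacy rdd-determinism decision.
  def needFullStageRetry(
      checksumMismatchFullRetryEnabled: Boolean,
      parents: Seq[StageInfo],
      legacyIsIndeterminate: => Boolean): Boolean = {
    if (checksumMismatchFullRetryEnabled) {
      parents.exists(isStageIndeterminate)
    } else {
      legacyIsIndeterminate
    }
  }

  def main(args: Array[String]): Unit = {
    val mismatchedParent = StageInfo(reliablyCheckpointed = false, checksumMismatched = true)
    // Flag on, a non-checkpointed parent produced different data across attempts: full retry.
    assert(needFullStageRetry(true, Seq(mismatchedParent), legacyIsIndeterminate = false))
    // A reliably checkpointed parent cannot have changed its output: no full retry.
    assert(!needFullStageRetry(
      true, Seq(mismatchedParent.copy(reliablyCheckpointed = true)), legacyIsIndeterminate = false))
    // Flag off: fall back to the legacy rdd-level determinism check.
    assert(needFullStageRetry(false, Seq.empty, legacyIsIndeterminate = true))
    println("checksum full-retry sketch ok")
  }
}

To exercise the new path end to end, the flag can be enabled per session, e.g. spark.conf.set("spark.sql.shuffle.orderIndependentChecksum.enableFullRetryOnMismatch", "true"), which also turns on per-partition row-based checksums in ShuffleExchangeExec so that mismatches can be detected in the first place.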