From 2ab6a38fbcf2a18512e7fe945a6aecfe9e789b83 Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Mon, 15 Sep 2025 08:01:11 +0000 Subject: [PATCH 1/7] Retry entire consumer stages when checksum mismatch detected for a retried shuffle map task --- .../org/apache/spark/MapOutputTracker.scala | 10 +- .../spark/internal/config/package.scala | 9 + .../main/scala/org/apache/spark/rdd/RDD.scala | 2 +- .../apache/spark/scheduler/DAGScheduler.scala | 146 ++++++++-- .../org/apache/spark/scheduler/Stage.scala | 22 ++ .../spark/scheduler/DAGSchedulerSuite.scala | 258 +++++++++++++++++- 6 files changed, 409 insertions(+), 38 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 3f823b60156ad..334eb832c4c2b 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -165,9 +165,11 @@ private class ShuffleStatus( /** * Register a map output. If there is already a registered location for the map output then it - * will be replaced by the new location. + * will be replaced by the new location. Returns true if the checksum in the new MapStatus is + * different from a previous registered MapStatus. Otherwise, returns false. */ - def addMapOutput(mapIndex: Int, status: MapStatus): Unit = withWriteLock { + def addMapOutput(mapIndex: Int, status: MapStatus): Boolean = withWriteLock { + var isChecksumMismatch: Boolean = false val currentMapStatus = mapStatuses(mapIndex) if (currentMapStatus == null) { _numAvailableMapOutputs += 1 @@ -183,9 +185,11 @@ private class ShuffleStatus( logInfo(s"Checksum of map output changes from ${preStatus.checksumValue} to " + s"${status.checksumValue} for task ${status.mapId}.") checksumMismatchIndices.add(mapIndex) + isChecksumMismatch = true } mapStatuses(mapIndex) = status mapIdToMapIndex(status.mapId) = mapIndex + isChecksumMismatch } /** @@ -853,7 +857,7 @@ private[spark] class MapOutputTrackerMaster( } } - def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Unit = { + def registerMapOutput(shuffleId: Int, mapIndex: Int, status: MapStatus): Boolean = { shuffleStatuses(shuffleId).addMapOutput(mapIndex, status) } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index d0f4806c49482..1a037af8ce567 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1679,6 +1679,15 @@ package object config { .checkValues(Set("ADLER32", "CRC32", "CRC32C")) .createWithDefault("ADLER32") + private[spark] val SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED = + ConfigBuilder("spark.scheduler.checksumMismatchFullRetry.enabled") + .doc("Whether to retry all tasks of a consumer stage when we detect checksum mismatches " + + "with its producer stages. The checksum computation is controlled by another config " + + "called SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.") + .version("4.1.0") + .booleanConf + .createWithDefault(false) + private[spark] val SHUFFLE_COMPRESS = ConfigBuilder("spark.shuffle.compress") .doc("Whether to compress shuffle output. 
Compression will use " + diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 117b2925710d3..d1408ee774ce8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1773,7 +1773,7 @@ abstract class RDD[T: ClassTag]( /** * Return whether this RDD is reliably checkpointed and materialized. */ - private[rdd] def isReliablyCheckpointed: Boolean = { + private[spark] def isReliablyCheckpointed: Boolean = { checkpointData match { case Some(reliable: ReliableRDDCheckpointData[_]) if reliable.isCheckpointed => true case _ => false diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 30eb49b0c0798..12a43f3ad219d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -307,6 +307,9 @@ private[spark] class DAGScheduler( private val shuffleFinalizeRpcThreads = sc.conf.get(config.PUSH_SHUFFLE_FINALIZE_RPC_THREADS) + private val checksumMismatchFullRetryEnabled = + sc.getConf.get(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED) + // Since SparkEnv gets initialized after DAGScheduler, externalShuffleClient needs to be // initialized lazily private lazy val externalShuffleClient: Option[BlockStoreClient] = @@ -1551,29 +1554,41 @@ private[spark] class DAGScheduler( // The operation here can make sure for the partially completed intermediate stage, // `findMissingPartitions()` returns all partitions every time. stage match { - case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => - // already executed at least once - if (sms.getNextAttemptId > 0) { - // While we previously validated possible rollbacks during the handling of a FetchFailure, - // where we were fetching from an indeterminate source map stages, this later check - // covers additional cases like recalculating an indeterminate stage after an executor - // loss. Moreover, because this check occurs later in the process, if a result stage task - // has successfully completed, we can detect this and abort the job, as rolling back a - // result stage is not possible. - val stagesToRollback = collectSucceedingStages(sms) - abortStageWithInvalidRollBack(stagesToRollback) - // stages which cannot be rolled back were aborted which leads to removing the - // the dependant job(s) from the active jobs set - val numActiveJobsWithStageAfterRollback = - activeJobs.count(job => stagesToRollback.contains(job.finalStage)) - if (numActiveJobsWithStageAfterRollback == 0) { - logInfo(log"All jobs depending on the indeterminate stage " + - log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.") - return + case sms: ShuffleMapStage if !sms.isAvailable => + if (checksumMismatchFullRetryEnabled) { + // When the parents of this stage are indeterminate (e.g., some parents are not + // checkpointed and checksum mismatches are detected), the output data of the parents + // may have changed due to task retries. For correctness reason, we need to + // retry all tasks of the current stage. The legacy way of using current stage's + // deterministic level to trigger full stage retry is not accurate. 
+ if (stage.isParentIndeterminate) { + mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) + sms.shuffleDep.newShuffleMergeState() + } + } else if (stage.isIndeterminate) { + // already executed at least once + if (sms.getNextAttemptId > 0) { + // While we previously validated possible rollbacks during the handling of a FetchFailure, + // where we were fetching from an indeterminate source map stages, this later check + // covers additional cases like recalculating an indeterminate stage after an executor + // loss. Moreover, because this check occurs later in the process, if a result stage task + // has successfully completed, we can detect this and abort the job, as rolling back a + // result stage is not possible. + val stagesToRollback = collectSucceedingStages(sms) + abortStageWithInvalidRollBack(stagesToRollback) + // stages which cannot be rolled back were aborted which leads to removing the + // the dependant job(s) from the active jobs set + val numActiveJobsWithStageAfterRollback = + activeJobs.count(job => stagesToRollback.contains(job.finalStage)) + if (numActiveJobsWithStageAfterRollback == 0) { + logInfo(log"All jobs depending on the indeterminate stage " + + log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.") + return + } } + mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) + sms.shuffleDep.newShuffleMergeState() } - mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) - sms.shuffleDep.newShuffleMergeState() case _ => } @@ -1886,6 +1901,74 @@ private[spark] class DAGScheduler( } } + /** + * If a map stage is non-deterministic, the map tasks of the stage may return different result + * when re-try. To make sure data correctness, we need to re-try all the tasks of its succeeding + * stages, as the input data may be changed after the map tasks are re-tried. For stages where + * rollback and retry all tasks are not possible, we will need to abort the stages. + */ + private[scheduler] def abortUnrollbackableStages(mapStage: ShuffleMapStage): Unit = { + // It's a little tricky to find all the succeeding stages of `mapStage`, because + // each stage only know its parents not children. Here we traverse the stages from + // the leaf nodes (the result stages of active jobs), and rollback all the stages + // in the stage chains that connect to the `mapStage`. To speed up the stage + // traversing, we collect the stages to rollback first. If a stage needs to + // rollback, all its succeeding stages need to rollback to. + val stagesToRollback = HashSet[Stage](mapStage) + + def collectStagesToRollback(stageChain: List[Stage]): Unit = { + if (stagesToRollback.contains(stageChain.head)) { + stageChain.drop(1).foreach(s => stagesToRollback += s) + } else { + stageChain.head.parents.foreach { s => + collectStagesToRollback(s :: stageChain) + } + } + } + + def generateErrorMessage(stage: Stage): String = { + "A shuffle map stage with indeterminate output was failed and retried. " + + s"However, Spark cannot rollback the $stage to re-process the input data, " + + "and has to fail this job. Please eliminate the indeterminacy by " + + "checkpointing the RDD before repartition and try again." + } + + activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil)) + + // Abort all stages where rollback is not possible. Other stages will be rolled back and + // the whole task set for the stages will be retried when we resubmit missing tasks for the + // stages. 
+ val rollingBackStages = HashSet[Stage](mapStage) + stagesToRollback.foreach { + case mapStage: ShuffleMapStage => + val numMissingPartitions = mapStage.findMissingPartitions().length + if (numMissingPartitions < mapStage.numTasks) { + if (sc.getConf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) { + val reason = "A shuffle map stage with indeterminate output was failed " + + "and retried. However, Spark can only do this while using the new " + + "shuffle block fetching protocol. Please check the config " + + "'spark.shuffle.useOldFetchProtocol', see more detail in " + + "SPARK-27665 and SPARK-25341." + abortStage(mapStage, reason, None) + } else { + rollingBackStages += mapStage + } + } + + case resultStage: ResultStage if resultStage.activeJob.isDefined => + val numMissingPartitions = resultStage.findMissingPartitions().length + if (numMissingPartitions < resultStage.numTasks) { + // TODO: support to rollback result tasks. + abortStage(resultStage, generateErrorMessage(resultStage), None) + } + + case _ => + } + logInfo(s"The shuffle map stage $mapStage with indeterminate output was failed, " + + s"we will roll back and rerun below stages which include itself and all its " + + s"indeterminate child stages: $rollingBackStages") + } + /** * Responds to a task finishing. This is called inside the event loop so it assumes that it can * modify the scheduler's internal state. Use taskEnded() to post a task end event from outside. @@ -2022,8 +2105,25 @@ private[spark] class DAGScheduler( // The epoch of the task is acceptable (i.e., the task was launched after the most // recent failure we're aware of for the executor), so mark the task's output as // available. - mapOutputTracker.registerMapOutput( + val isChecksumMismatched = mapOutputTracker.registerMapOutput( shuffleStage.shuffleDep.shuffleId, smt.partitionId, status) + if (isChecksumMismatched) { + shuffleStage.isChecksumMismatched = isChecksumMismatched + // There could be multiple checksum mismatches detected for a single stage attempt. + // We check for stage abortion once and only once when we first detect checksum + // mismatch for each stage attempt. For example, assume that we have + // stage1 -> stage2, and we encounter checksum mismatch during the retry of stage1. + // In this case, we need to call abortUnrollbackableStages() for the succeeding + // stages. Assume that when stage2 is retried, some tasks finish and some tasks + // failed again with FetchFailed. In case that we encounter checksum mismatch again + // during the retry of stage1, we need to call abortUnrollbackableStages() again. + if (shuffleStage.maxChecksumMismatchedId < smt.stageAttemptId) { + shuffleStage.maxChecksumMismatchedId = smt.stageAttemptId + if (checksumMismatchFullRetryEnabled && shuffleStage.isStageIndeterminate) { + abortUnrollbackableStages(shuffleStage) + } + } + } } } else { logInfo(log"Ignoring ${MDC(TASK_NAME, smt)} completion from an older attempt of indeterminate stage") @@ -2148,7 +2248,7 @@ private[spark] class DAGScheduler( // Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is // guaranteed to be determinate, so the input data of the reducers will not change // even if the map tasks are re-tried. 
-              if (mapStage.isIndeterminate) {
+              if (mapStage.isIndeterminate && !checksumMismatchFullRetryEnabled) {
                 val stagesToRollback = collectSucceedingStages(mapStage)
                 val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback)
                 logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output was failed, " +
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index f35beafd87480..9bf604e9a83cf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -72,6 +72,18 @@ private[scheduler] abstract class Stage(
   private var nextAttemptId: Int = 0
   private[scheduler] def getNextAttemptId: Int = nextAttemptId
 
+  /**
+   * Whether checksum mismatches have been detected across different attempts of the stage. A
+   * checksum mismatch typically indicates that different stage attempts have produced different
+   * data.
+   */
+  private[scheduler] var isChecksumMismatched: Boolean = false
+
+  /**
+   * The highest stage attempt id for which a checksum mismatch has been detected.
+   */
+  private[scheduler] var maxChecksumMismatchedId: Int = nextAttemptId
+
   val name: String = callSite.shortForm
   val details: String = callSite.longForm
 
@@ -131,4 +143,14 @@ private[scheduler] abstract class Stage(
   def isIndeterminate: Boolean = {
     rdd.outputDeterministicLevel == DeterministicLevel.INDETERMINATE
   }
+
+  // Returns true if any parent of this stage is indeterminate.
+  def isParentIndeterminate: Boolean = {
+    parents.exists(_.isStageIndeterminate)
+  }
+
+  // Returns true if the stage itself is indeterminate.
+  def isStageIndeterminate: Boolean = {
+    !rdd.isReliablyCheckpointed && isChecksumMismatched
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 1ada81cbdd0ee..a01f7553ddff4 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -3415,6 +3415,19 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assertDataStructuresEmpty()
   }
 
+  private def checkAndCompleteRetryStage(
+      taskSetIndex: Int,
+      stageId: Int,
+      shuffleId: Int,
+      numTasks: Int = 2,
+      checksumVal: Long = 0): Unit = {
+    assert(taskSets(taskSetIndex).stageId == stageId)
+    assert(taskSets(taskSetIndex).stageAttemptId == 1)
+    assert(taskSets(taskSetIndex).tasks.length == 2)
+    completeShuffleMapStageSuccessfully(stageId, 1, 2)
+    assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
+  }
+
   test("SPARK-25341: continuous indeterminate stage roll back") {
     // shuffleMapRdd1/2/3 are all indeterminate.
     val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = true)
@@ -3454,17 +3467,6 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assert(scheduler.failedStages.toSeq.map(_.id) == Seq(1, 2))
     scheduler.resubmitFailedStages()
 
-    def checkAndCompleteRetryStage(
-        taskSetIndex: Int,
-        stageId: Int,
-        shuffleId: Int): Unit = {
-      assert(taskSets(taskSetIndex).stageId == stageId)
-      assert(taskSets(taskSetIndex).stageAttemptId == 1)
-      assert(taskSets(taskSetIndex).tasks.length == 2)
-      completeShuffleMapStageSuccessfully(stageId, 1, 2)
-      assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty))
-    }
-
     // Check all indeterminate stage roll back.
checkAndCompleteRetryStage(3, 0, shuffleId1) checkAndCompleteRetryStage(4, 1, shuffleId2) @@ -3477,6 +3479,240 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti assertDataStructuresEmpty() } + // Construct the scenario of stages with checksum mismatches and FetchFailed. + private def constructChecksumMismatchStageFetchFailed(): (Int, Int) = { + val shuffleMapRdd1 = new MyRDD(sc, 2, Nil) + + val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) + val shuffleId1 = shuffleDep1.shuffleId + val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker) + + val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2)) + val shuffleId2 = shuffleDep2.shuffleId + val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker) + + submit(finalRdd, Array(0, 1)) + + // Finish the first shuffle map stage. + completeShuffleMapStageSuccessfully( + 0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100) + assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty)) + + // The first task of the second shuffle map stage failed with FetchFailed. + runEvent(makeCompletionEvent( + taskSets(1).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId1, 0L, 0, 0, "ignored"), + null)) + + // Finish the second task of the second shuffle map stage. + runEvent(makeCompletionEvent( + taskSets(1).tasks(1), Success, makeMapStatus("hostB", 2), + Seq.empty, Array.empty, createFakeTaskInfoWithId(1))) + + (shuffleId1, shuffleId2) + } + + // Construct the scenario of stages with checksum mismatches and FetchFailed. + // This function assumes that the input `mapRdd` has a single stage with 2 partitions. + private def constructChecksumMismatchStageFetchFailed(mapRdd: MyRDD): Unit = { + val shuffleDep = new ShuffleDependency(mapRdd, new HashPartitioner(2)) + val shuffleId = shuffleDep.shuffleId + val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) + + submit(finalRdd, Array(0, 1)) + + completeShuffleMapStageSuccessfully( + 0, 0, numShufflePartitions = 2, Seq("hostA", "hostB"), checksumVal = 100) + assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) + + // Fail the first task of the result stage with FetchFailed. + runEvent(makeCompletionEvent( + taskSets(1).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0L, 0, 0, "ignored"), + null)) + + // Finish the second task of the result stage. + runEvent(makeCompletionEvent( + taskSets(1).tasks(1), Success, 42, + Seq.empty, Array.empty, createFakeTaskInfoWithId(0))) + + // Check status for all failedStages. + val failedStages = scheduler.failedStages.toSeq + // Shuffle blocks of "hostA" is lost, so first task of the shuffle map stage and + // result stage needs to retry. + assert(failedStages.map(_.id) == Seq(0, 1)) + assert(failedStages.forall(_.findMissingPartitions() == Seq(0))) + + scheduler.resubmitFailedStages() + + // First shuffle map stage reran failed tasks with a different checksum. + completeShuffleMapStageSuccessfully(0, 1, 2, checksumVal = 101) + } + + private def assertChecksumMismatchResultStageFailToRollback(mapRdd: MyRDD): Unit = { + constructChecksumMismatchStageFetchFailed(mapRdd) + + // The job should fail because Spark can't rollback the result stage. 
+ assert(failure != null && failure.getMessage.contains("Spark cannot rollback")) + } + + private def assertChecksumMismatchResultStageNotRolledBack(mapRdd: MyRDD): Unit = { + constructChecksumMismatchStageFetchFailed(mapRdd) + + assert(failure == null, "job should not fail") + // Result stage success, all job ended. + complete(taskSets(3), Seq((Success, 41))) + assert(results === Map(0 -> 41, 1 -> 42)) + results.clear() + assertDataStructuresEmpty() + } + + test("SPARK-53575: abort stage while using old fetch protocol") { + conf.set(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL.key, "true") + conf.set(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key, "true") + constructChecksumMismatchStageFetchFailed() + + scheduler.resubmitFailedStages() + completeShuffleMapStageSuccessfully(0, 1, 2, checksumVal = 101) + + // The job should fail because Spark can't rollback the shuffle map stage while + // using old protocol. + assert(failure != null && failure.getMessage.contains( + "Spark can only do this while using the new shuffle block fetching protocol")) + } + + test("SPARK-53575: retry all the succeeding stages when the map stage has checksum mismatches") { + // Disable the stage resubmit that triggered by `DAGScheduler.messageScheduler`, + // so that the stage resubmit won't happen earlier than `scheduler.resubmitFailedStages()` + conf.set(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key, "true") + + val (shuffleId1, shuffleId2) = constructChecksumMismatchStageFetchFailed() + + // Check status for all failedStages. + val failedStages = scheduler.failedStages.toSeq + // Shuffle blocks of "hostA" is lost, so first task of the `shuffleMapRdd1` and + // `shuffleMapRdd2` needs to retry. + assert(failedStages.map(_.id) == Seq(0, 1)) + assert(failedStages.forall(_.findMissingPartitions() == Seq(0))) + + scheduler.resubmitFailedStages() + + // First shuffle map stage reran failed tasks with a different checksum. + checkAndCompleteRetryStage(2, 0, shuffleId1, numTasks = 1, checksumVal = 101) + + // Second shuffle map stage reran all tasks. + checkAndCompleteRetryStage(3, 1, shuffleId2, numTasks = 2) + + complete(taskSets(4), Seq((Success, 11), (Success, 12))) + + // Job successful ended. + assert(results === Map(0 -> 11, 1 -> 12)) + results.clear() + assertDataStructuresEmpty() + } + + test("SPARK-53575: continuous checksum mismatch stage roll back") { + conf.set(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key, "true") + + // shuffleMapRdd1/2 have checksum mismatches, and shuffleMapRdd2/3 requires full stage retries. + val shuffleMapRdd1 = new MyRDD(sc, 2, Nil) + val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) + val shuffleId1 = shuffleDep1.shuffleId + + val shuffleMapRdd2 = new MyRDD( + sc, 2, List(shuffleDep1), tracker = mapOutputTracker) + val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2)) + val shuffleId2 = shuffleDep2.shuffleId + + val shuffleMapRdd3 = new MyRDD( + sc, 2, List(shuffleDep2), tracker = mapOutputTracker) + val shuffleDep3 = new ShuffleDependency(shuffleMapRdd3, new HashPartitioner(2)) + val shuffleId3 = shuffleDep3.shuffleId + val finalRdd = new MyRDD(sc, 2, List(shuffleDep3), tracker = mapOutputTracker) + + submit(finalRdd, Array(0, 1), properties = new Properties()) + + // Finish the first 2 shuffle map stages. 
+ completeShuffleMapStageSuccessfully(0, 0, 2, Seq("hostA", "hostB"), checksumVal = 100) + assert(mapOutputTracker.findMissingPartitions(shuffleId1) === Some(Seq.empty)) + completeShuffleMapStageSuccessfully(1, 0, 2, Seq("hostA", "hostB"), checksumVal = 200) + assert(mapOutputTracker.findMissingPartitions(shuffleId2) === Some(Seq.empty)) + + // Fail the first task of the third shuffle map stage with FetchFailed. + runEvent(makeCompletionEvent( + taskSets(2).tasks(0), + FetchFailed(makeBlockManagerId("hostA"), shuffleId2, 0L, 0, 0, "ignored"), + null)) + + // Finish the second task of the third shuffle map stage. + runEvent(makeCompletionEvent( + taskSets(2).tasks(1), Success, makeMapStatus("hostB", 2), + Seq.empty, Array.empty, createFakeTaskInfoWithId(1))) + mapOutputTracker.removeOutputsOnHost("hostA") + + // Check status for all failedStages. + val failedStages = scheduler.failedStages.toSeq + // Shuffle blocks of "hostA" is lost, so first task of the `shuffleMapRdd2` and + // `shuffleMapRdd3` needs to retry. + assert(failedStages.map(_.id) == Seq(1, 2)) + assert(failedStages.forall(_.findMissingPartitions() == Seq(0))) + + scheduler.resubmitFailedStages() + + // First shuffle map stage reran failed tasks with a different checksum. + checkAndCompleteRetryStage(3, 0, shuffleId1, numTasks = 1, checksumVal = 101) + // Second and third shuffle map stages reran all tasks with a different checksum. + checkAndCompleteRetryStage(4, 1, shuffleId2, numTasks = 2, checksumVal = 201) + checkAndCompleteRetryStage(5, 2, shuffleId3, numTasks = 2, checksumVal = 301) + // Result stage success, all job ended. + complete(taskSets(6), Seq((Success, 11), (Success, 12))) + assert(results === Map(0 -> 11, 1 -> 12)) + results.clear() + assertDataStructuresEmpty() + } + + test("SPARK-53575: cannot rollback a result stage") { + conf.set(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key, "true") + val shuffleMapRdd = new MyRDD(sc, 2, Nil) + assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd) + } + + test("SPARK-53575: local checkpoint fail to rollback (checkpointed before)") { + conf.set(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key, "true") + val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil) + shuffleMapRdd.localCheckpoint() + shuffleMapRdd.doCheckpoint() + assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd) + } + + test("SPARK-53575: local checkpoint fail to rollback (checkpointing now)") { + conf.set(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key, "true") + val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil) + shuffleMapRdd.localCheckpoint() + assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd) + } + + test("SPARK-53575: reliable checkpoint can avoid rollback (checkpointed before)") { + conf.set(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key, "true") + withTempDir { dir => + sc.setCheckpointDir(dir.getCanonicalPath) + val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil) + shuffleMapRdd.checkpoint() + shuffleMapRdd.doCheckpoint() + assertChecksumMismatchResultStageNotRolledBack(shuffleMapRdd) + } + } + + test("SPARK-53575: reliable checkpoint fail to rollback (checkpointing now)") { + conf.set(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key, "true") + withTempDir { dir => + sc.setCheckpointDir(dir.getCanonicalPath) + val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil) + shuffleMapRdd.checkpoint() + assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd) + } + } + test("SPARK-29042: Sampled RDD with unordered input should 
be indeterminate") { val shuffleMapRdd1 = new MyRDD(sc, 2, Nil, indeterminate = false) From 86442326d3063255d7e93e08bf2e389d9ee18b55 Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Wed, 17 Sep 2025 02:43:48 +0000 Subject: [PATCH 2/7] refactor code and fix ut --- .../apache/spark/scheduler/DAGScheduler.scala | 70 ++----------------- .../spark/scheduler/DAGSchedulerSuite.scala | 4 +- 2 files changed, 8 insertions(+), 66 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 12a43f3ad219d..8893e63bd7b7f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1908,65 +1908,11 @@ private[spark] class DAGScheduler( * rollback and retry all tasks are not possible, we will need to abort the stages. */ private[scheduler] def abortUnrollbackableStages(mapStage: ShuffleMapStage): Unit = { - // It's a little tricky to find all the succeeding stages of `mapStage`, because - // each stage only know its parents not children. Here we traverse the stages from - // the leaf nodes (the result stages of active jobs), and rollback all the stages - // in the stage chains that connect to the `mapStage`. To speed up the stage - // traversing, we collect the stages to rollback first. If a stage needs to - // rollback, all its succeeding stages need to rollback to. - val stagesToRollback = HashSet[Stage](mapStage) - - def collectStagesToRollback(stageChain: List[Stage]): Unit = { - if (stagesToRollback.contains(stageChain.head)) { - stageChain.drop(1).foreach(s => stagesToRollback += s) - } else { - stageChain.head.parents.foreach { s => - collectStagesToRollback(s :: stageChain) - } - } - } - - def generateErrorMessage(stage: Stage): String = { - "A shuffle map stage with indeterminate output was failed and retried. " + - s"However, Spark cannot rollback the $stage to re-process the input data, " + - "and has to fail this job. Please eliminate the indeterminacy by " + - "checkpointing the RDD before repartition and try again." - } - - activeJobs.foreach(job => collectStagesToRollback(job.finalStage :: Nil)) - - // Abort all stages where rollback is not possible. Other stages will be rolled back and - // the whole task set for the stages will be retried when we resubmit missing tasks for the - // stages. - val rollingBackStages = HashSet[Stage](mapStage) - stagesToRollback.foreach { - case mapStage: ShuffleMapStage => - val numMissingPartitions = mapStage.findMissingPartitions().length - if (numMissingPartitions < mapStage.numTasks) { - if (sc.getConf.get(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL)) { - val reason = "A shuffle map stage with indeterminate output was failed " + - "and retried. However, Spark can only do this while using the new " + - "shuffle block fetching protocol. Please check the config " + - "'spark.shuffle.useOldFetchProtocol', see more detail in " + - "SPARK-27665 and SPARK-25341." - abortStage(mapStage, reason, None) - } else { - rollingBackStages += mapStage - } - } - - case resultStage: ResultStage if resultStage.activeJob.isDefined => - val numMissingPartitions = resultStage.findMissingPartitions().length - if (numMissingPartitions < resultStage.numTasks) { - // TODO: support to rollback result tasks. 
- abortStage(resultStage, generateErrorMessage(resultStage), None) - } - - case _ => - } - logInfo(s"The shuffle map stage $mapStage with indeterminate output was failed, " + - s"we will roll back and rerun below stages which include itself and all its " + - s"indeterminate child stages: $rollingBackStages") + val stagesToRollback = collectSucceedingStages(mapStage) + val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback) + logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output " + + log"was failed, we will roll back and rerun below stages which include itself and all its " + + log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}") } /** @@ -2249,11 +2195,7 @@ private[spark] class DAGScheduler( // guaranteed to be determinate, so the input data of the reducers will not change // even if the map tasks are re-tried. if (mapStage.isIndeterminate && !checksumMismatchFullRetryEnabled) { - val stagesToRollback = collectSucceedingStages(mapStage) - val rollingBackStages = abortStageWithInvalidRollBack(stagesToRollback) - logInfo(log"The shuffle map stage ${MDC(SHUFFLE_ID, mapStage)} with indeterminate output was failed, " + - log"we will roll back and rerun below stages which include itself and all its " + - log"indeterminate child stages: ${MDC(STAGES, rollingBackStages)}") + abortUnrollbackableStages(mapStage) } // We expect one executor failure to trigger many FetchFailures in rapid succession, diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index a01f7553ddff4..98a0d89674050 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -3423,8 +3423,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti checksumVal: Long = 0): Unit = { assert(taskSets(taskSetIndex).stageId == stageId) assert(taskSets(taskSetIndex).stageAttemptId == 1) - assert(taskSets(taskSetIndex).tasks.length == 2) - completeShuffleMapStageSuccessfully(stageId, 1, 2) + assert(taskSets(taskSetIndex).tasks.length == numTasks) + completeShuffleMapStageSuccessfully(stageId, 1, 2, checksumVal = checksumVal) assert(mapOutputTracker.findMissingPartitions(shuffleId) === Some(Seq.empty)) } From 6e214f021c6fea6a5e743756a45603e31f1fbb5e Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Mon, 22 Sep 2025 15:24:13 +0800 Subject: [PATCH 3/7] remove invalid comments --- .../scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 98a0d89674050..b2287714e49a7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -3582,8 +3582,6 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } test("SPARK-53575: retry all the succeeding stages when the map stage has checksum mismatches") { - // Disable the stage resubmit that triggered by `DAGScheduler.messageScheduler`, - // so that the stage resubmit won't happen earlier than `scheduler.resubmitFailedStages()` conf.set(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key, "true") val (shuffleId1, shuffleId2) = constructChecksumMismatchStageFetchFailed() From 
cd20fa4417e2af882ea0d12f401281696e8fe576 Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Wed, 24 Sep 2025 02:00:55 +0000 Subject: [PATCH 4/7] address comments --- .../exchange/ShuffleExchangeExec.scala | 3 +- .../spark/sql/MapStatusEndToEndSuite.scala | 44 ++++++++++++------- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 9c86bbb606a57..90485f37ba4df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -481,7 +481,8 @@ object ShuffleExchangeExec { // are in the form of (partitionId, row) and every partitionId is in the expected range // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough. val checksumSize = - if (SQLConf.get.shuffleOrderIndependentChecksumEnabled) { + if (SQLConf.get.shuffleOrderIndependentChecksumEnabled || + SparkEnv.get.conf.get(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED)) { part.numPartitions } else { 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala index 0fe6603122103..b3909b024d833 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql -import org.apache.spark.{MapOutputTrackerMaster, SparkFunSuite} +import org.apache.spark.{MapOutputTrackerMaster, SparkEnv, SparkFunSuite} +import org.apache.spark.internal.config import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -25,7 +26,6 @@ import org.apache.spark.sql.test.SQLTestUtils class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { override def spark: SparkSession = SparkSession.builder() .master("local") - .config(SQLConf.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key, value = true) .config(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key, value = 5) .config(SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key, value = false) .getOrCreate() @@ -39,26 +39,40 @@ class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { } test("Propagate checksum from executor to driver") { - assert(spark.sparkContext.conf - .get("spark.sql.shuffle.orderIndependentChecksum.enabled") == "true") - assert(spark.conf.get("spark.sql.shuffle.orderIndependentChecksum.enabled") == "true") assert(spark.sparkContext.conf.get("spark.sql.leafNodeDefaultParallelism") == "5") assert(spark.conf.get("spark.sql.leafNodeDefaultParallelism") == "5") assert(spark.sparkContext.conf.get("spark.sql.classic.shuffleDependency.fileCleanup.enabled") == "false") assert(spark.conf.get("spark.sql.classic.shuffleDependency.fileCleanup.enabled") == "false") - withTable("t") { - spark.range(1000).repartition(10).write.mode("overwrite"). 
- saveAsTable("t") - } + var shuffleId = 0 + Seq(("true", "false"), ("false", "true"), ("true", "true")).foreach { + case (orderIndependentChecksumEnabled: String, checksumMismatchFullRetryEnabled: String) => + withSQLConf( + "spark.sql.shuffle.orderIndependentChecksum.enabled" -> orderIndependentChecksumEnabled) { + withSparkEnvConfs("spark.scheduler.checksumMismatchFullRetry.enabled" -> + checksumMismatchFullRetryEnabled) { + + assert(SQLConf.get.shuffleOrderIndependentChecksumEnabled === + orderIndependentChecksumEnabled.toBoolean) + assert(SparkEnv.get.conf.get(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED) === + checksumMismatchFullRetryEnabled.toBoolean) - val shuffleStatuses = spark.sparkContext.env.mapOutputTracker. - asInstanceOf[MapOutputTrackerMaster].shuffleStatuses - assert(shuffleStatuses.size == 1) + withTable("t") { + spark.range(1000).repartition(10).write.mode("overwrite"). + saveAsTable("t") + } - val mapStatuses = shuffleStatuses(0).mapStatuses - assert(mapStatuses.length == 5) - assert(mapStatuses.forall(_.checksumValue != 0)) + val shuffleStatuses = spark.sparkContext.env.mapOutputTracker. + asInstanceOf[MapOutputTrackerMaster].shuffleStatuses + assert(shuffleStatuses.contains(shuffleId)) + + val mapStatuses = shuffleStatuses(shuffleId).mapStatuses + assert(mapStatuses.length == 5) + assert(mapStatuses.forall(_.checksumValue != 0)) + shuffleId += 1 + } + } + } } } From e75e88aca5d30c779b655d4784c2397694e71bc9 Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Fri, 26 Sep 2025 02:56:08 +0000 Subject: [PATCH 5/7] address comments --- .../scala/org/apache/spark/Dependency.scala | 3 +- .../spark/internal/config/package.scala | 9 ---- .../apache/spark/scheduler/DAGScheduler.scala | 10 ++-- .../spark/scheduler/DAGSchedulerSuite.scala | 49 ++++++++++++------- .../apache/spark/sql/internal/SQLConf.scala | 11 +++++ .../exchange/ShuffleExchangeExec.scala | 8 +-- .../spark/sql/MapStatusEndToEndSuite.scala | 39 +++++++-------- 7 files changed, 72 insertions(+), 57 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index 93a2bbe25157b..c436025e06bb8 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -89,7 +89,8 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val aggregator: Option[Aggregator[K, V, C]] = None, val mapSideCombine: Boolean = false, val shuffleWriterProcessor: ShuffleWriteProcessor = new ShuffleWriteProcessor, - val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS) + val rowBasedChecksums: Array[RowBasedChecksum] = ShuffleDependency.EMPTY_ROW_BASED_CHECKSUMS, + val checksumMismatchFullRetryEnabled: Boolean = false) extends Dependency[Product2[K, V]] with Logging { def this( diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 1a037af8ce567..d0f4806c49482 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -1679,15 +1679,6 @@ package object config { .checkValues(Set("ADLER32", "CRC32", "CRC32C")) .createWithDefault("ADLER32") - private[spark] val SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED = - ConfigBuilder("spark.scheduler.checksumMismatchFullRetry.enabled") - .doc("Whether to retry all tasks of a consumer stage when we detect checksum 
mismatches " + - "with its producer stages. The checksum computation is controlled by another config " + - "called SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.") - .version("4.1.0") - .booleanConf - .createWithDefault(false) - private[spark] val SHUFFLE_COMPRESS = ConfigBuilder("spark.shuffle.compress") .doc("Whether to compress shuffle output. Compression will use " + diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8893e63bd7b7f..b10fe5e1479d2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -307,9 +307,6 @@ private[spark] class DAGScheduler( private val shuffleFinalizeRpcThreads = sc.conf.get(config.PUSH_SHUFFLE_FINALIZE_RPC_THREADS) - private val checksumMismatchFullRetryEnabled = - sc.getConf.get(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED) - // Since SparkEnv gets initialized after DAGScheduler, externalShuffleClient needs to be // initialized lazily private lazy val externalShuffleClient: Option[BlockStoreClient] = @@ -1555,7 +1552,7 @@ private[spark] class DAGScheduler( // `findMissingPartitions()` returns all partitions every time. stage match { case sms: ShuffleMapStage if !sms.isAvailable => - if (checksumMismatchFullRetryEnabled) { + if (sms.shuffleDep.checksumMismatchFullRetryEnabled) { // When the parents of this stage are indeterminate (e.g., some parents are not // checkpointed and checksum mismatches are detected), the output data of the parents // may have changed due to task retries. For correctness reason, we need to @@ -2065,7 +2062,8 @@ private[spark] class DAGScheduler( // during the retry of stage1, we need to call abortUnrollbackableStages() again. if (shuffleStage.maxChecksumMismatchedId < smt.stageAttemptId) { shuffleStage.maxChecksumMismatchedId = smt.stageAttemptId - if (checksumMismatchFullRetryEnabled && shuffleStage.isStageIndeterminate) { + if (shuffleStage.shuffleDep.checksumMismatchFullRetryEnabled + && shuffleStage.isStageIndeterminate) { abortUnrollbackableStages(shuffleStage) } } @@ -2194,7 +2192,7 @@ private[spark] class DAGScheduler( // Note that, if map stage is UNORDERED, we are fine. The shuffle partitioner is // guaranteed to be determinate, so the input data of the reducers will not change // even if the map tasks are re-tried. 
- if (mapStage.isIndeterminate && !checksumMismatchFullRetryEnabled) { + if (mapStage.isIndeterminate && !mapStage.shuffleDep.checksumMismatchFullRetryEnabled) { abortUnrollbackableStages(mapStage) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index b2287714e49a7..c20866fda0a3b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -3483,11 +3483,19 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti private def constructChecksumMismatchStageFetchFailed(): (Int, Int) = { val shuffleMapRdd1 = new MyRDD(sc, 2, Nil) - val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) + val shuffleDep1 = new ShuffleDependency( + shuffleMapRdd1, + new HashPartitioner(2), + checksumMismatchFullRetryEnabled = true + ) val shuffleId1 = shuffleDep1.shuffleId val shuffleMapRdd2 = new MyRDD(sc, 2, List(shuffleDep1), tracker = mapOutputTracker) - val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2)) + val shuffleDep2 = new ShuffleDependency( + shuffleMapRdd2, + new HashPartitioner(2), + checksumMismatchFullRetryEnabled = true + ) val shuffleId2 = shuffleDep2.shuffleId val finalRdd = new MyRDD(sc, 2, List(shuffleDep2), tracker = mapOutputTracker) @@ -3515,7 +3523,11 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti // Construct the scenario of stages with checksum mismatches and FetchFailed. // This function assumes that the input `mapRdd` has a single stage with 2 partitions. private def constructChecksumMismatchStageFetchFailed(mapRdd: MyRDD): Unit = { - val shuffleDep = new ShuffleDependency(mapRdd, new HashPartitioner(2)) + val shuffleDep = new ShuffleDependency( + mapRdd, + new HashPartitioner(2), + checksumMismatchFullRetryEnabled = true + ) val shuffleId = shuffleDep.shuffleId val finalRdd = new MyRDD(sc, 2, List(shuffleDep), tracker = mapOutputTracker) @@ -3569,7 +3581,6 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti test("SPARK-53575: abort stage while using old fetch protocol") { conf.set(config.SHUFFLE_USE_OLD_FETCH_PROTOCOL.key, "true") - conf.set(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key, "true") constructChecksumMismatchStageFetchFailed() scheduler.resubmitFailedStages() @@ -3582,9 +3593,8 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } test("SPARK-53575: retry all the succeeding stages when the map stage has checksum mismatches") { - conf.set(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key, "true") - - val (shuffleId1, shuffleId2) = constructChecksumMismatchStageFetchFailed() + val (shuffleId1, shuffleId2) = + constructChecksumMismatchStageFetchFailed() // Check status for all failedStages. val failedStages = scheduler.failedStages.toSeq @@ -3610,21 +3620,31 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } test("SPARK-53575: continuous checksum mismatch stage roll back") { - conf.set(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key, "true") - // shuffleMapRdd1/2 have checksum mismatches, and shuffleMapRdd2/3 requires full stage retries. 
val shuffleMapRdd1 = new MyRDD(sc, 2, Nil) - val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(2)) + val shuffleDep1 = new ShuffleDependency( + shuffleMapRdd1, + new HashPartitioner(2), + checksumMismatchFullRetryEnabled = true + ) val shuffleId1 = shuffleDep1.shuffleId val shuffleMapRdd2 = new MyRDD( sc, 2, List(shuffleDep1), tracker = mapOutputTracker) - val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(2)) + val shuffleDep2 = new ShuffleDependency( + shuffleMapRdd2, + new HashPartitioner(2), + checksumMismatchFullRetryEnabled = true + ) val shuffleId2 = shuffleDep2.shuffleId val shuffleMapRdd3 = new MyRDD( sc, 2, List(shuffleDep2), tracker = mapOutputTracker) - val shuffleDep3 = new ShuffleDependency(shuffleMapRdd3, new HashPartitioner(2)) + val shuffleDep3 = new ShuffleDependency( + shuffleMapRdd3, + new HashPartitioner(2), + checksumMismatchFullRetryEnabled = true + ) val shuffleId3 = shuffleDep3.shuffleId val finalRdd = new MyRDD(sc, 2, List(shuffleDep3), tracker = mapOutputTracker) @@ -3670,13 +3690,11 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } test("SPARK-53575: cannot rollback a result stage") { - conf.set(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key, "true") val shuffleMapRdd = new MyRDD(sc, 2, Nil) assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd) } test("SPARK-53575: local checkpoint fail to rollback (checkpointed before)") { - conf.set(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key, "true") val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil) shuffleMapRdd.localCheckpoint() shuffleMapRdd.doCheckpoint() @@ -3684,14 +3702,12 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } test("SPARK-53575: local checkpoint fail to rollback (checkpointing now)") { - conf.set(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key, "true") val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil) shuffleMapRdd.localCheckpoint() assertChecksumMismatchResultStageFailToRollback(shuffleMapRdd) } test("SPARK-53575: reliable checkpoint can avoid rollback (checkpointed before)") { - conf.set(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key, "true") withTempDir { dir => sc.setCheckpointDir(dir.getCanonicalPath) val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil) @@ -3702,7 +3718,6 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } test("SPARK-53575: reliable checkpoint fail to rollback (checkpointing now)") { - conf.set(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key, "true") withTempDir { dir => sc.setCheckpointDir(dir.getCanonicalPath) val shuffleMapRdd = new MyCheckpointRDD(sc, 2, Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 17b8dd493cf80..477d09d29a051 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -890,6 +890,14 @@ object SQLConf { .booleanConf .createWithDefault(false) + private[spark] val SHUFFLE_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED = + buildConf("spark.sql.shuffle.orderIndependentChecksum.enableFullRetryOnMismatch") + .doc("Whether to retry all tasks of a consumer stage when we detect checksum mismatches " + + "with its producer stages.") + .version("4.1.0") + .booleanConf + .createWithDefault(false) + val 
SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE = buildConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize") .internal() @@ -6651,6 +6659,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def shuffleOrderIndependentChecksumEnabled: Boolean = getConf(SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED) + def shuffleChecksumMismatchFullRetryEnabled: Boolean = + getConf(SHUFFLE_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED) + def allowCollationsInMapKeys: Boolean = getConf(ALLOW_COLLATIONS_IN_MAP_KEYS) def objectLevelCollationsEnabled: Boolean = getConf(OBJECT_LEVEL_COLLATIONS_ENABLED) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 90485f37ba4df..f052bd9068805 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -480,20 +480,22 @@ object ShuffleExchangeExec { // Now, we manually create a ShuffleDependency. Because pairs in rddWithPartitionIds // are in the form of (partitionId, row) and every partitionId is in the expected range // [0, part.numPartitions - 1]. The partitioner of this is a PartitionIdPassthrough. - val checksumSize = + val checksumSize = { if (SQLConf.get.shuffleOrderIndependentChecksumEnabled || - SparkEnv.get.conf.get(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED)) { + SQLConf.get.shuffleChecksumMismatchFullRetryEnabled) { part.numPartitions } else { 0 } + } val dependency = new ShuffleDependency[Int, InternalRow, InternalRow]( rddWithPartitionIds, new PartitionIdPassthrough(part.numPartitions), serializer, shuffleWriterProcessor = createShuffleWriteProcessor(writeMetrics), - rowBasedChecksums = UnsafeRowChecksum.createUnsafeRowChecksums(checksumSize)) + rowBasedChecksums = UnsafeRowChecksum.createUnsafeRowChecksums(checksumSize), + checksumMismatchFullRetryEnabled = SQLConf.get.shuffleChecksumMismatchFullRetryEnabled) dependency } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala index b3909b024d833..af0e2a9a7f256 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.{MapOutputTrackerMaster, SparkEnv, SparkFunSuite} -import org.apache.spark.internal.config +import org.apache.spark.{MapOutputTrackerMaster, SparkFunSuite} import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils @@ -49,29 +48,27 @@ class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { Seq(("true", "false"), ("false", "true"), ("true", "true")).foreach { case (orderIndependentChecksumEnabled: String, checksumMismatchFullRetryEnabled: String) => withSQLConf( - "spark.sql.shuffle.orderIndependentChecksum.enabled" -> orderIndependentChecksumEnabled) { - withSparkEnvConfs("spark.scheduler.checksumMismatchFullRetry.enabled" -> + "spark.sql.shuffle.orderIndependentChecksum.enabled" -> orderIndependentChecksumEnabled, + "spark.sql.shuffle.orderIndependentChecksum.enableFullRetryOnMismatch" -> checksumMismatchFullRetryEnabled) { + assert(SQLConf.get.shuffleOrderIndependentChecksumEnabled === + 
orderIndependentChecksumEnabled.toBoolean) + assert(SQLConf.get.shuffleChecksumMismatchFullRetryEnabled === + checksumMismatchFullRetryEnabled.toBoolean) - assert(SQLConf.get.shuffleOrderIndependentChecksumEnabled === - orderIndependentChecksumEnabled.toBoolean) - assert(SparkEnv.get.conf.get(config.SCHEDULER_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED) === - checksumMismatchFullRetryEnabled.toBoolean) - - withTable("t") { - spark.range(1000).repartition(10).write.mode("overwrite"). - saveAsTable("t") - } + withTable("t") { + spark.range(1000).repartition(10).write.mode("overwrite"). + saveAsTable("t") + } - val shuffleStatuses = spark.sparkContext.env.mapOutputTracker. - asInstanceOf[MapOutputTrackerMaster].shuffleStatuses - assert(shuffleStatuses.contains(shuffleId)) + val shuffleStatuses = spark.sparkContext.env.mapOutputTracker. + asInstanceOf[MapOutputTrackerMaster].shuffleStatuses + assert(shuffleStatuses.contains(shuffleId)) - val mapStatuses = shuffleStatuses(shuffleId).mapStatuses - assert(mapStatuses.length == 5) - assert(mapStatuses.forall(_.checksumValue != 0)) - shuffleId += 1 - } + val mapStatuses = shuffleStatuses(shuffleId).mapStatuses + assert(mapStatuses.length == 5) + assert(mapStatuses.forall(_.checksumValue != 0)) + shuffleId += 1 } } } From 491a04b3bb56c6dd19667153c41e5dd05358848b Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Mon, 29 Sep 2025 02:27:34 +0000 Subject: [PATCH 6/7] address comments --- .../apache/spark/scheduler/DAGScheduler.scala | 53 ++++++++++--------- .../spark/sql/MapStatusEndToEndSuite.scala | 13 ++--- 2 files changed, 36 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b10fe5e1479d2..3b719a2c7d24c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1552,37 +1552,42 @@ private[spark] class DAGScheduler( // `findMissingPartitions()` returns all partitions every time. stage match { case sms: ShuffleMapStage if !sms.isAvailable => - if (sms.shuffleDep.checksumMismatchFullRetryEnabled) { + val needFullStageRetry = if (sms.shuffleDep.checksumMismatchFullRetryEnabled) { // When the parents of this stage are indeterminate (e.g., some parents are not // checkpointed and checksum mismatches are detected), the output data of the parents // may have changed due to task retries. For correctness reason, we need to // retry all tasks of the current stage. The legacy way of using current stage's // deterministic level to trigger full stage retry is not accurate. - if (stage.isParentIndeterminate) { - mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) - sms.shuffleDep.newShuffleMergeState() - } - } else if (stage.isIndeterminate) { - // already executed at least once - if (sms.getNextAttemptId > 0) { - // While we previously validated possible rollbacks during the handling of a FetchFailure, - // where we were fetching from an indeterminate source map stages, this later check - // covers additional cases like recalculating an indeterminate stage after an executor - // loss. Moreover, because this check occurs later in the process, if a result stage task - // has successfully completed, we can detect this and abort the job, as rolling back a - // result stage is not possible. 
- val stagesToRollback = collectSucceedingStages(sms) - abortStageWithInvalidRollBack(stagesToRollback) - // stages which cannot be rolled back were aborted which leads to removing the - // the dependant job(s) from the active jobs set - val numActiveJobsWithStageAfterRollback = - activeJobs.count(job => stagesToRollback.contains(job.finalStage)) - if (numActiveJobsWithStageAfterRollback == 0) { - logInfo(log"All jobs depending on the indeterminate stage " + - log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.") - return + stage.isParentIndeterminate + } else { + if (stage.isIndeterminate) { + // already executed at least once + if (sms.getNextAttemptId > 0) { + // While we previously validated possible rollbacks during the handling of a FetchFailure, + // where we were fetching from an indeterminate source map stages, this later check + // covers additional cases like recalculating an indeterminate stage after an executor + // loss. Moreover, because this check occurs later in the process, if a result stage task + // has successfully completed, we can detect this and abort the job, as rolling back a + // result stage is not possible. + val stagesToRollback = collectSucceedingStages(sms) + abortStageWithInvalidRollBack(stagesToRollback) + // stages which cannot be rolled back were aborted which leads to removing the + // the dependant job(s) from the active jobs set + val numActiveJobsWithStageAfterRollback = + activeJobs.count(job => stagesToRollback.contains(job.finalStage)) + if (numActiveJobsWithStageAfterRollback == 0) { + logInfo(log"All jobs depending on the indeterminate stage " + + log"(${MDC(STAGE_ID, stage.id)}) were aborted so this stage is not needed anymore.") + return + } } + true + } else { + false } + } + + if (needFullStageRetry) { mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) sms.shuffleDep.newShuffleMergeState() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala index af0e2a9a7f256..837d273303a74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala @@ -38,18 +38,19 @@ class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { } test("Propagate checksum from executor to driver") { - assert(spark.sparkContext.conf.get("spark.sql.leafNodeDefaultParallelism") == "5") - assert(spark.conf.get("spark.sql.leafNodeDefaultParallelism") == "5") - assert(spark.sparkContext.conf.get("spark.sql.classic.shuffleDependency.fileCleanup.enabled") + assert(spark.sparkContext.conf.get(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key) == "5") + assert(spark.conf.get(SQLConf.LEAF_NODE_DEFAULT_PARALLELISM.key) == "5") + assert(spark.sparkContext.conf.get(SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key) == "false") - assert(spark.conf.get("spark.sql.classic.shuffleDependency.fileCleanup.enabled") == "false") + assert(spark.conf.get(SQLConf.CLASSIC_SHUFFLE_DEPENDENCY_FILE_CLEANUP_ENABLED.key) == "false") var shuffleId = 0 Seq(("true", "false"), ("false", "true"), ("true", "true")).foreach { case (orderIndependentChecksumEnabled: String, checksumMismatchFullRetryEnabled: String) => withSQLConf( - "spark.sql.shuffle.orderIndependentChecksum.enabled" -> orderIndependentChecksumEnabled, - "spark.sql.shuffle.orderIndependentChecksum.enableFullRetryOnMismatch" -> + 
SQLConf.SHUFFLE_ORDER_INDEPENDENT_CHECKSUM_ENABLED.key -> + orderIndependentChecksumEnabled, + SQLConf.SHUFFLE_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key -> checksumMismatchFullRetryEnabled) { assert(SQLConf.get.shuffleOrderIndependentChecksumEnabled === orderIndependentChecksumEnabled.toBoolean) From 0bd01f603f1554e4e8156a5b4749ab6a5488f651 Mon Sep 17 00:00:00 2001 From: Tengfei Huang Date: Mon, 29 Sep 2025 02:31:41 +0000 Subject: [PATCH 7/7] remove unnecessary assert --- .../scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala | 5 ----- 1 file changed, 5 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala index 837d273303a74..abcd346c32775 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MapStatusEndToEndSuite.scala @@ -52,11 +52,6 @@ class MapStatusEndToEndSuite extends SparkFunSuite with SQLTestUtils { orderIndependentChecksumEnabled, SQLConf.SHUFFLE_CHECKSUM_MISMATCH_FULL_RETRY_ENABLED.key -> checksumMismatchFullRetryEnabled) { - assert(SQLConf.get.shuffleOrderIndependentChecksumEnabled === - orderIndependentChecksumEnabled.toBoolean) - assert(SQLConf.get.shuffleChecksumMismatchFullRetryEnabled === - checksumMismatchFullRetryEnabled.toBoolean) - withTable("t") { spark.range(1000).repartition(10).write.mode("overwrite"). saveAsTable("t")