diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 003d64f78e853..4858af71c1a9c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1433,17 +1433,18 @@ class DAGScheduler(
         val failedStage = stageIdToStage(task.stageId)
         logInfo(s"Marking $failedStage (${failedStage.name}) as failed due to a barrier task " +
           "failed.")
-        val message = s"Stage failed because barrier task $task finished unsuccessfully. " +
+        val message = s"Stage failed because barrier task $task finished unsuccessfully.\n" +
           failure.toErrorString
         try {
-          // cancelTasks will fail if a SchedulerBackend does not implement killTask
-          taskScheduler.cancelTasks(stageId, interruptThread = false)
+          // killAllTaskAttempts will fail if a SchedulerBackend does not implement killTask.
+          val reason = s"Task $task from barrier stage $failedStage (${failedStage.name}) failed."
+          taskScheduler.killAllTaskAttempts(stageId, interruptThread = false, reason)
         } catch {
           case e: UnsupportedOperationException =>
             // Cannot continue with barrier stage if failed to cancel zombie barrier tasks.
             // TODO SPARK-24877 leave the zombie tasks and ignore their completion events.
-            logWarning(s"Could not cancel tasks for stage $stageId", e)
-            abortStage(failedStage, "Could not cancel zombie barrier tasks for stage " +
+            logWarning(s"Could not kill all tasks for stage $stageId", e)
+            abortStage(failedStage, "Could not kill zombie barrier tasks for stage " +
              s"$failedStage (${failedStage.name})", Some(e))
         }
         markStageAsFinished(failedStage, Some(message))
@@ -1457,7 +1458,8 @@ class DAGScheduler(
 
         if (shouldAbortStage) {
           val abortMessage = if (disallowStageRetryForTest) {
-            "Barrier stage will not retry stage due to testing config"
+            "Barrier stage will not retry stage due to testing config. Most recent failure " +
+              s"reason: $message"
           } else {
             s"""$failedStage (${failedStage.name})
                |has failed the maximum allowable number of
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 90644fea23ab1..95f7ae4fd39a2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -51,16 +51,22 @@ private[spark] trait TaskScheduler {
   // Submit a sequence of tasks to run.
   def submitTasks(taskSet: TaskSet): Unit
 
-  // Cancel a stage.
+  // Kill all the tasks in a stage and fail the stage and all the jobs that depend on the stage.
+  // Throws UnsupportedOperationException if the backend doesn't support killing tasks.
   def cancelTasks(stageId: Int, interruptThread: Boolean): Unit
 
   /**
    * Kills a task attempt.
+   * Throws UnsupportedOperationException if the backend doesn't support killing a task.
    *
    * @return Whether the task was successfully killed.
    */
   def killTaskAttempt(taskId: Long, interruptThread: Boolean, reason: String): Boolean
 
+  // Kill all the running task attempts in a stage.
+  // Throws UnsupportedOperationException if the backend doesn't support killing tasks.
+  def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit
+
   // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
   def setDAGScheduler(dagScheduler: DAGScheduler): Unit
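The new `killAllTaskAttempts` contract has two sides: a scheduler whose backend can kill tasks is expected to kill every running attempt of the stage, while a scheduler whose backend cannot is expected to throw `UnsupportedOperationException`, which the DAGScheduler change above catches and turns into a stage abort. A minimal, self-contained sketch of that contract; `KillSupport`, `KillingScheduler` and `NonKillingScheduler` are invented stand-ins, not Spark's `private[spark]` trait or classes:

```scala
// Illustrative stand-ins only; KillSupport is not Spark's TaskScheduler trait.
import scala.collection.mutable

trait KillSupport {
  // Same contract as the new TaskScheduler.killAllTaskAttempts: kill every running attempt
  // of the stage, or throw UnsupportedOperationException if tasks cannot be killed.
  def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit
}

class KillingScheduler extends KillSupport {
  private val runningByStage = mutable.Map.empty[Int, mutable.Set[Long]]

  def markRunning(stageId: Int, taskId: Long): Unit =
    runningByStage.getOrElseUpdate(stageId, mutable.Set.empty) += taskId

  override def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit =
    runningByStage.get(stageId).foreach(_.clear()) // "killing" = dropping the running attempts
}

class NonKillingScheduler extends KillSupport {
  override def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit =
    throw new UnsupportedOperationException(s"backend cannot kill tasks: $reason")
}

object ContractDemo extends App {
  val schedulers: Seq[KillSupport] = Seq(new KillingScheduler, new NonKillingScheduler)
  schedulers.foreach { s =>
    try s.killAllTaskAttempts(stageId = 0, interruptThread = false, reason = "barrier task failed")
    catch {
      // Mirrors the DAGScheduler change above: if the backend cannot kill, abort the stage.
      case _: UnsupportedOperationException => println("would abort the stage instead")
    }
  }
}
```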
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 587ed4b5243b7..72691389d271c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -222,18 +222,11 @@ private[spark] class TaskSchedulerImpl(
   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
     logInfo("Cancelling stage " + stageId)
+    // Kill all running tasks for the stage.
+    killAllTaskAttempts(stageId, interruptThread, reason = "Stage cancelled")
+    // Cancel all attempts for the stage.
     taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
       attempts.foreach { case (_, tsm) =>
-        // There are two possible cases here:
-        // 1. The task set manager has been created and some tasks have been scheduled.
-        //    In this case, send a kill signal to the executors to kill the task and then abort
-        //    the stage.
-        // 2. The task set manager has been created but no tasks have been scheduled. In this case,
-        //    simply abort the stage.
-        tsm.runningTasksSet.foreach { tid =>
-          taskIdToExecutorId.get(tid).foreach(execId =>
-            backend.killTask(tid, execId, interruptThread, reason = "Stage cancelled"))
-        }
         tsm.abort("Stage %s cancelled".format(stageId))
         logInfo("Stage %d was cancelled".format(stageId))
       }
@@ -252,6 +245,27 @@ private[spark] class TaskSchedulerImpl(
     }
   }
 
+  override def killAllTaskAttempts(
+      stageId: Int,
+      interruptThread: Boolean,
+      reason: String): Unit = synchronized {
+    logInfo(s"Killing all running tasks in stage $stageId: $reason")
+    taskSetsByStageIdAndAttempt.get(stageId).foreach { attempts =>
+      attempts.foreach { case (_, tsm) =>
+        // There are two possible cases here:
+        // 1. The task set manager has been created and some tasks have been scheduled.
+        //    In this case, send a kill signal to the executors to kill the task.
+        // 2. The task set manager has been created but no tasks have been scheduled. In this case,
+        //    simply continue.
+        tsm.runningTasksSet.foreach { tid =>
+          taskIdToExecutorId.get(tid).foreach { execId =>
+            backend.killTask(tid, execId, interruptThread, reason)
+          }
+        }
+      }
+    }
+  }
+
   /**
    * Called to indicate that all task attempts (including speculated tasks) associated with the
    * given TaskSetManager have completed, so state associated with the TaskSetManager should be
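The refactoring above makes the split explicit: `killAllTaskAttempts` only kills the running attempts and leaves the task set managers schedulable, while `cancelTasks` kills and then aborts every task set for the stage. A toy, self-contained model of that state difference; `ToyTaskSet` and `ToyScheduler` are invented stand-ins for `TaskSetManager` and `TaskSchedulerImpl`, and the sketch mirrors what the new tests at the end of this patch assert:

```scala
// Toy model of the two code paths; none of these types are Spark classes.
import scala.collection.mutable

final class ToyTaskSet(val stageId: Int) {
  val runningTasks = mutable.Set.empty[Long]
  var isZombie = false // aborted: no longer schedulable, stage treated as failed
}

final class ToyScheduler {
  val taskSetsByStage = mutable.Map.empty[Int, ToyTaskSet]

  // Mirrors killAllTaskAttempts: kill running attempts, keep the task set schedulable.
  def killAllTaskAttempts(stageId: Int, reason: String): Unit =
    taskSetsByStage.get(stageId).foreach(_.runningTasks.clear())

  // Mirrors the new cancelTasks: kill, then fail (abort) every task set of the stage.
  def cancelTasks(stageId: Int): Unit = {
    killAllTaskAttempts(stageId, reason = "Stage cancelled")
    taskSetsByStage.get(stageId).foreach(ts => ts.isZombie = true)
  }
}

object SplitDemo extends App {
  val sched = new ToyScheduler
  sched.taskSetsByStage(0) = new ToyTaskSet(0)
  sched.taskSetsByStage(0).runningTasks ++= Seq(1L, 2L)

  sched.killAllTaskAttempts(0, reason = "barrier task failed")
  assert(sched.taskSetsByStage(0).runningTasks.isEmpty)
  assert(!sched.taskSetsByStage(0).isZombie) // stage can still be retried

  sched.cancelTasks(0)
  assert(sched.taskSetsByStage(0).isZombie) // stage is failed for good
}
```

Routing `cancelTasks` through `killAllTaskAttempts`, as the patch does, keeps the kill logic in one place; cancellation only adds the abort step on top of it.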
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index b3db5e29fb82e..dad339e2cdb91 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -131,6 +131,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     }
     override def killTaskAttempt(
       taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
+    override def killAllTaskAttempts(
+      stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
     override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
     override def defaultParallelism() = 2
     override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
@@ -629,6 +631,10 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
        taskId: Long, interruptThread: Boolean, reason: String): Boolean = {
         throw new UnsupportedOperationException
       }
+      override def killAllTaskAttempts(
+        stageId: Int, interruptThread: Boolean, reason: String): Unit = {
+        throw new UnsupportedOperationException
+      }
       override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
       override def defaultParallelism(): Int = 2
       override def executorHeartbeatReceived(
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
index a4e4ea7cd2894..02b19e01ce7a0 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
@@ -81,6 +81,8 @@ private class DummyTaskScheduler extends TaskScheduler {
   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {}
   override def killTaskAttempt(
     taskId: Long, interruptThread: Boolean, reason: String): Boolean = false
+  override def killAllTaskAttempts(
+    stageId: Int, interruptThread: Boolean, reason: String): Unit = {}
   override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {}
   override def defaultParallelism(): Int = 2
   override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 16c273b7bc8a4..38e26a82e750f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -1055,4 +1055,66 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
     assert(3 === taskDescriptions.length)
   }
+
+  test("cancelTasks shall kill all the running tasks and fail the stage") {
+    val taskScheduler = setupScheduler()
+
+    taskScheduler.initialize(new FakeSchedulerBackend {
+      override def killTask(
+          taskId: Long,
+          executorId: String,
+          interruptThread: Boolean,
+          reason: String): Unit = {
+        // Since we only submit one stage attempt, the following call is sufficient to mark the
+        // task as killed.
+        taskScheduler.taskSetManagerForAttempt(0, 0).get.runningTasksSet.remove(taskId)
+      }
+    })
+
+    val attempt1 = FakeTask.createTaskSet(10, 0)
+    taskScheduler.submitTasks(attempt1)
+
+    val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1),
+      new WorkerOffer("executor1", "host1", 1))
+    val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(2 === taskDescriptions.length)
+    val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get
+    assert(2 === tsm.runningTasks)
+
+    taskScheduler.cancelTasks(0, false)
+    assert(0 === tsm.runningTasks)
+    assert(tsm.isZombie)
+    assert(taskScheduler.taskSetManagerForAttempt(0, 0).isEmpty)
+  }
+
+  test("killAllTaskAttempts shall kill all the running tasks and not fail the stage") {
+    val taskScheduler = setupScheduler()
+
+    taskScheduler.initialize(new FakeSchedulerBackend {
+      override def killTask(
+          taskId: Long,
+          executorId: String,
+          interruptThread: Boolean,
+          reason: String): Unit = {
+        // Since we only submit one stage attempt, the following call is sufficient to mark the
+        // task as killed.
+        taskScheduler.taskSetManagerForAttempt(0, 0).get.runningTasksSet.remove(taskId)
+      }
+    })
+
+    val attempt1 = FakeTask.createTaskSet(10, 0)
+    taskScheduler.submitTasks(attempt1)
+
+    val workerOffers = IndexedSeq(new WorkerOffer("executor0", "host0", 1),
+      new WorkerOffer("executor1", "host1", 1))
+    val taskDescriptions = taskScheduler.resourceOffers(workerOffers).flatten
+    assert(2 === taskDescriptions.length)
+    val tsm = taskScheduler.taskSetManagerForAttempt(0, 0).get
+    assert(2 === tsm.runningTasks)
+
+    taskScheduler.killAllTaskAttempts(0, false, "test")
+    assert(0 === tsm.runningTasks)
+    assert(!tsm.isZombie)
+    assert(taskScheduler.taskSetManagerForAttempt(0, 0).isDefined)
+  }
 }
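For context on when the rewritten `cancelTasks` path is hit outside the test suite: stage or job cancellation issued from the driver is the usual way it gets exercised, since the DAGScheduler cancels a failed job's stages through the task scheduler. A rough, illustrative driver-side program, assuming a local master; the sleepy workload, thread handling, and timing exist only to leave tasks running long enough to cancel:

```scala
// Rough driver-side illustration; not part of the patch.
import org.apache.spark.{SparkConf, SparkContext}

object CancelStageDemo extends App {
  val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("cancel-demo"))

  val job = new Thread(new Runnable {
    override def run(): Unit = {
      try {
        // A deliberately slow stage so tasks are still running when we cancel.
        sc.parallelize(1 to 1000, 2).map { i => Thread.sleep(100); i }.count()
      } catch {
        case e: Exception => println(s"Job failed after cancellation: ${e.getMessage}")
      }
    }
  })
  job.start()

  Thread.sleep(2000)
  // Cancel whatever is active; the scheduler then has to kill the stage's running tasks.
  sc.statusTracker.getActiveStageIds().foreach(id => sc.cancelStage(id))
  job.join()
  sc.stop()
}
```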