From dbde1e4ca84090a3a924ce389bf5a081f370a401 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sat, 28 Jul 2018 02:12:01 +0800 Subject: [PATCH 01/19] implement BarrierTaskContext.barrier() --- .../org/apache/spark/BarrierCoordinator.scala | 184 ++++++++++++++++++ .../apache/spark/BarrierTaskContextImpl.scala | 0 .../spark/internal/config/package.scala | 10 + .../spark/scheduler/TaskSchedulerImpl.scala | 19 +- 4 files changed, 212 insertions(+), 1 deletion(-) create mode 100644 core/src/main/scala/org/apache/spark/BarrierCoordinator.scala create mode 100644 core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala new file mode 100644 index 0000000000000..7664a514079d9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -0,0 +1,184 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.{Timer, TimerTask} + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +import org.apache.spark.internal.Logging +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} + +class BarrierCoordinator( + timeout: Long, + override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + + private val timer = new Timer("BarrierCoordinator barrier epoch increment timer") + + // Barrier epoch for each stage attempt, fail a sync request if the barrier epoch in the request + // mismatches the barrier epoch in the coordinator. + private val barrierEpochByStageIdAndAttempt = new HashMap[Int, HashMap[Int, Int]] + + // Any access to this should be synchronized. + private val syncRequestsByStageIdAndAttempt = + new HashMap[Int, HashMap[Int, ArrayBuffer[RpcCallContext]]] + + /** + * Get the array of [[RpcCallContext]]s that correspond to a barrier sync request from a stage + * attempt. + */ + private def getOrInitSyncRequests( + stageId: Int, + stageAttemptId: Int, + numTasks: Int = 0): ArrayBuffer[RpcCallContext] = synchronized { + val syncRequestsByStage = syncRequestsByStageIdAndAttempt + .getOrElseUpdate(stageId, new HashMap[Int, ArrayBuffer[RpcCallContext]]) + syncRequestsByStage.getOrElseUpdate(stageAttemptId, new ArrayBuffer[RpcCallContext](numTasks)) + } + + /** + * Clean up the array of [[RpcCallContext]]s that correspond to a barrier sync request from a + * stage attempt. + */ + private def cleanupSyncRequests(stageId: Int, stageAttemptId: Int): Unit = synchronized { + syncRequestsByStageIdAndAttempt.get(stageId).foreach { syncRequestByStage => + syncRequestByStage.get(stageAttemptId).foreach { syncRequests => + syncRequests.clear() + } + syncRequestByStage -= stageAttemptId + if (syncRequestByStage.isEmpty) { + syncRequestsByStageIdAndAttempt -= stageId + } + logInfo(s"Removed all the pending barrier sync requests from Stage $stageId(Attempt " + + s"$stageAttemptId).") + } + } + + /** + * Get the barrier epoch that correspond to a barrier sync request from a stage attempt. + */ + private def getOrInitBarrierEpoch(stageId: Int, stageAttemptId: Int): Int = synchronized { + val barrierEpochByStage = barrierEpochByStageIdAndAttempt + .getOrElseUpdate(stageId, new HashMap[Int, Int]) + val barrierEpoch = barrierEpochByStage.getOrElseUpdate(stageAttemptId, 0) + logInfo(s"Current barrier epoch for Stage $stageId(Attempt $stageAttemptId) is $barrierEpoch.") + barrierEpoch + } + + /** + * Update the barrier epoch that correspond to a barrier sync request from a stage attempt. + */ + private def updateBarrierEpoch( + stageId: Int, + stageAttemptId: Int, + newBarrierEpoch: Int): Unit = synchronized { + val barrierEpochByStage = barrierEpochByStageIdAndAttempt + .getOrElseUpdate(stageId, new HashMap[Int, Int]) + barrierEpochByStage.put(stageAttemptId, newBarrierEpoch) + logInfo(s"Current barrier epoch for Stage $stageId(Attempt $stageAttemptId) is " + + s"$newBarrierEpoch.") + } + + /** + * Send failure to all the blocking barrier sync requests from a stage attempt with proper + * failure message. + */ + private def failAllSyncRequests( + syncRequests: ArrayBuffer[RpcCallContext], + message: String): Unit = { + syncRequests.foreach(_.sendFailure(new SparkException(message))) + } + + /** + * Finish all the blocking barrier sync requests from a stage attempt successfully if we + * have received all the sync requests. + */ + private def maybeFinishAllSyncRequests( + syncRequests: ArrayBuffer[RpcCallContext], + numTasks: Int): Boolean = { + if (syncRequests.size == numTasks) { + syncRequests.foreach(_.reply(())) + return true + } + + false + } + + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case RequestToSync(numTasks, stageId, stageAttemptId, taskAttemptId, barrierEpoch) => + // Check the barrier epoch, fail the sync request if barrier epoch mismatches. + val currentBarrierEpoch = getOrInitBarrierEpoch(stageId, stageAttemptId) + val syncRequests = getOrInitSyncRequests(stageId, stageAttemptId) + if (barrierEpoch != currentBarrierEpoch) { + syncRequests += context + failAllSyncRequests(syncRequests, + "The request to sync fails due to mismatched barrier epoch, the barrier epoch from " + + s"task $taskAttemptId is $barrierEpoch, while the barrier epoch from the " + + s"coordinator is $currentBarrierEpoch.") + cleanupSyncRequests(stageId, stageAttemptId) + // The global sync fails so the stage is expected to retry another attempt, all sync + // messages come from current stage attempt shall fail. + updateBarrierEpoch(stageId, stageAttemptId, -1) + } else { + // If this is the first sync message received for a barrier() call, init a timer to ensure + // we may timeout for the sync. + if (syncRequests.isEmpty) { + timer.schedule(new TimerTask { + override def run(): Unit = { + // Timeout for current barrier() call, fail all the sync requests and reset the + // barrier epoch. + val requests = getOrInitSyncRequests(stageId, stageAttemptId) + failAllSyncRequests(requests, + "The coordinator didn't get all barrier sync requests for barrier epoch " + + s"$barrierEpoch from Stage $stageId(Attempt $stageAttemptId) within $timeout " + + "ms.") + cleanupSyncRequests(stageId, stageAttemptId) + // The global sync fails so the stage is expected to retry another attempt, all sync + // messages come from current stage attempt shall fail. + updateBarrierEpoch(stageId, stageAttemptId, -1) + } + }, timeout) + } + + syncRequests += context + logInfo(s"Barrier sync epoch $barrierEpoch from Stage $stageId(Attempt $stageAttemptId) " + + s"received update from Task $taskAttemptId, current progress: " + + s"${syncRequests.size}/$numTasks.") + if (maybeFinishAllSyncRequests(syncRequests, numTasks)) { + // Finished current barrier() call successfully, clean up internal data and increase the + // barrier epoch. + logInfo(s"Barrier sync epoch $barrierEpoch from Stage $stageId(Attempt " + + s"$stageAttemptId) received all updates from tasks, finished successfully.") + cleanupSyncRequests(stageId, stageAttemptId) + updateBarrierEpoch(stageId, stageAttemptId, currentBarrierEpoch + 1) + } + } + } + + override def onStop(): Unit = timer.cancel() +} + +private[spark] sealed trait BarrierCoordinatorMessage extends Serializable + +private[spark] case class RequestToSync( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + taskAttemptId: Long, + barrierEpoch: Int) extends BarrierCoordinatorMessage diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 8fef2aa6863c5..d5b244cf4b15b 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -567,4 +567,14 @@ package object config { .intConf .checkValue(v => v > 0, "The value should be a positive integer.") .createWithDefault(2000) + + private[spark] val BARRIER_SYNC_TIMEOUT = + ConfigBuilder("spark.barrier.sync.timeout") + .doc("The timeout in milliseconds for each barrier() call from a barrier task. If the " + + "coordinator didn't receive all the sync messages from barrier tasks within the " + + "configed time, throw a SparkException to fail all the tasks. The default value is set " + + "to Long.MaxValue so the barrier() call shall wait forever.") + .longConf + .checkValue(v => v > 0, "The value should be a positive long value.") + .createWithDefault(Long.MaxValue) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 72691389d271c..0006809651db9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -30,6 +30,7 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.internal.Logging import org.apache.spark.internal.config +import org.apache.spark.rpc.RpcEndpoint import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality import org.apache.spark.storage.BlockManagerId @@ -138,6 +139,15 @@ private[spark] class TaskSchedulerImpl( // This is a var so that we can reset it for testing purposes. private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) + private lazy val barrierSyncTimeout = conf.get(config.BARRIER_SYNC_TIMEOUT) + + private[scheduler] lazy val barrierCoordinator: RpcEndpoint = { + val coordinator = new BarrierCoordinator(barrierSyncTimeout, sc.env.rpcEnv) + sc.env.rpcEnv.setupEndpoint("barrierSync", coordinator) + logInfo("Registered BarrierCoordinator endpoint") + coordinator + } + override def setDAGScheduler(dagScheduler: DAGScheduler) { this.dagScheduler = dagScheduler } @@ -419,7 +429,11 @@ private[spark] class TaskSchedulerImpl( .sortBy(_._2.partitionId) .map(_._1) .mkString(",") - addressesWithDescs.foreach(_._2.properties.setProperty("addresses", addressesStr)) + addressesWithDescs.foreach { case (_, taskDesc) => + taskDesc.properties.setProperty("addresses", addressesStr) + taskDesc.properties.setProperty("numTasks", taskSet.numTasks.toString) + taskDesc.properties.setProperty("barrierTimeout", barrierSyncTimeout.toString) + } logInfo(s"Successfully scheduled all the ${addressesWithDescs.size} tasks for barrier " + s"stage ${taskSet.stageId}.") @@ -566,6 +580,9 @@ private[spark] class TaskSchedulerImpl( if (taskResultGetter != null) { taskResultGetter.stop() } + if (barrierCoordinator != null) { + barrierCoordinator.stop() + } starvationTimer.cancel() } From b690f675e616698a660bce3240948516fc339471 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sun, 29 Jul 2018 12:13:35 +0800 Subject: [PATCH 02/19] update --- .../apache/spark/BarrierTaskContextImpl.scala | 97 ++++++++++++ .../spark/internal/config/package.scala | 4 +- .../spark/scheduler/TaskSchedulerImpl.scala | 4 +- .../scheduler/BarrierTaskContextSuite.scala | 144 ++++++++++++++++++ 4 files changed, 246 insertions(+), 3 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala index e69de29bb2d1d..b5e85968b0e73 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import java.util.{Properties, Timer, TimerTask} + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout} +import org.apache.spark.util.RpcUtils + +/** A [[BarrierTaskContext]] implementation. */ +private[spark] class BarrierTaskContextImpl( + override val stageId: Int, + override val stageAttemptNumber: Int, + override val partitionId: Int, + override val taskAttemptId: Long, + override val attemptNumber: Int, + override val taskMemoryManager: TaskMemoryManager, + localProperties: Properties, + @transient private val metricsSystem: MetricsSystem, + // The default value is only used in tests. + override val taskMetrics: TaskMetrics = TaskMetrics.empty) + extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, + taskMemoryManager, localProperties, metricsSystem, taskMetrics) + with BarrierTaskContext { + + private val barrierCoordinator: RpcEndpointRef = { + val env = SparkEnv.get + RpcUtils.makeDriverRef("barrierSync", env.conf, env.rpcEnv) + } + + private val timer = new Timer("Barrier task timer for barrier() calls.") + + private var barrierEpoch = 0 + + private lazy val numTasks = localProperties.getProperty("numTasks", "0").toInt + + override def barrier(): Unit = { + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " + + s"the global sync, current barrier epoch is $barrierEpoch.") + + val startTime = System.currentTimeMillis() + val timerTask = new TimerTask { + override def run(): Unit = { + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " + + s"under the global sync since $startTime, has been waiting for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " + + s"is $barrierEpoch.") + } + } + // Log the update of global sync every 60 seconds. + timer.schedule(timerTask, 60000, 60000) + + try { + barrierCoordinator.askSync[Unit]( + message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, barrierEpoch), + timeout = new RpcTimeout(31536000 /** = 3600 * 24 * 365 */ seconds, "barrierTimeout")) + barrierEpoch += 1 + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " + + "global sync successfully, waited for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch is " + + s"$barrierEpoch.") + } catch { + case e: SparkException => + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " + + "to perform global sync, waited for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " + + s"is $barrierEpoch.") + throw e + } + } + + override def getTaskInfos(): Array[BarrierTaskInfo] = { + val addressesStr = localProperties.getProperty("addresses", "") + addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_)) + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index d5b244cf4b15b..7adab1a6c4734 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -573,8 +573,8 @@ package object config { .doc("The timeout in milliseconds for each barrier() call from a barrier task. If the " + "coordinator didn't receive all the sync messages from barrier tasks within the " + "configed time, throw a SparkException to fail all the tasks. The default value is set " + - "to Long.MaxValue so the barrier() call shall wait forever.") + "to 31536000000(3600 * 24 * 365 * 1000) so the barrier() call shall wait for one year.") .longConf .checkValue(v => v > 0, "The value should be a positive long value.") - .createWithDefault(Long.MaxValue) + .createWithDefault(31536000000L) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 0006809651db9..ab4a691f88014 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -423,6 +423,9 @@ private[spark] class TaskSchedulerImpl( s"${taskSet.numTasks} tasks got resource offers. The resource offers may have " + "been blacklisted or cannot fulfill task locality requirements.") + // materialize the barrier coordinator. + barrierCoordinator + // Update the taskInfos into all the barrier task properties. val addressesStr = addressesWithDescs // Addresses ordered by partitionId @@ -432,7 +435,6 @@ private[spark] class TaskSchedulerImpl( addressesWithDescs.foreach { case (_, taskDesc) => taskDesc.properties.setProperty("addresses", addressesStr) taskDesc.properties.setProperty("numTasks", taskSet.numTasks.toString) - taskDesc.properties.setProperty("barrierTimeout", barrierSyncTimeout.toString) } logInfo(s"Successfully scheduled all the ${addressesWithDescs.size} tasks for barrier " + diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala new file mode 100644 index 0000000000000..a9e613568b5c5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import scala.util.Random + +import org.apache.spark._ + +class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { + + test("global sync by barrier() call") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + Seq(System.currentTimeMillis()).iterator + } + + val times = rdd2.collect() + // All the tasks shall finish global sync within a short time slot. + assert(times.max - times.min <= 5) + } + + test("support multiple barrier() call within a single task") { + val conf = new SparkConf() + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + // Sleep for a random time before global sync. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + val time1 = System.currentTimeMillis() + // Sleep for a random time between two global syncs. + Thread.sleep(Random.nextInt(1000)) + context.barrier() + val time2 = System.currentTimeMillis() + Seq((time1, time2)).iterator + } + + val times = rdd2.collect() + // All the tasks shall finish the first round of global sync within a short time slot. + val times1 = times.map(_._1) + assert(times1.max - times1.min <= 5) + + // All the tasks shall finish the second round of global sync within a short time slot. + val times2 = times.map(_._2) + assert(times2.max - times2.min <= 5) + } + + test("throw exception on barrier() call timeout") { + val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "100") + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + // Task 3 shall sleep 200ms to ensure barrier() call timeout + if (context.taskAttemptId() == 3) { + Thread.sleep(200) + } + context.barrier() + it + } + + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("The coordinator didn't get all barrier sync requests")) + assert(error.contains("within 100 ms")) + } + + test("throw exception if barrier() call doesn't happen on every task") { + val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "100") + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + if (context.taskAttemptId() != 0) { + context.barrier() + } + it + } + + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("The coordinator didn't get all barrier sync requests")) + assert(error.contains("within 100 ms")) + } + + ignore("throw exception if barrier() call mismatched") { + val conf = new SparkConf() + .set("spark.test.noStageRetry", "true") + .setMaster("local-cluster[4, 1, 1024]") + .setAppName("test-cluster") + sc = new SparkContext(conf) + val rdd = sc.makeRDD(1 to 10, 4) + val rdd2 = rdd.barrier().mapPartitions { (it, context) => + try { + if (context.taskAttemptId() == 0) { + // Task 0 skip the first barrier() call. + throw new SparkException("test") + } + context.barrier() + } catch { + case e: Exception => // Do nothing + } + context.barrier() + it + } + + val error = intercept[SparkException] { + rdd2.collect() + }.getMessage + assert(error.contains("fails due to mismatched barrier epoch")) + } +} From 2696f1811272ce3c1c4047475812abdfb172129f Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sun, 29 Jul 2018 13:06:37 +0800 Subject: [PATCH 03/19] update --- .../main/scala/org/apache/spark/BarrierTaskContextImpl.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala index b5e85968b0e73..a8aeba5c47622 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala @@ -87,6 +87,8 @@ private[spark] class BarrierTaskContextImpl( s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " + s"is $barrierEpoch.") throw e + } finally { + timerTask.cancel() } } From 330a26b186210a63d532388266859e6ac33b63dc Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 2 Aug 2018 01:20:11 +0800 Subject: [PATCH 04/19] update --- .../org/apache/spark/BarrierCoordinator.scala | 133 +++++++++--------- .../apache/spark/BarrierTaskContextImpl.scala | 11 +- .../spark/internal/config/package.scala | 10 +- .../spark/scheduler/TaskSchedulerImpl.scala | 5 +- .../scheduler/BarrierTaskContextSuite.scala | 26 ++-- 5 files changed, 95 insertions(+), 90 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 7664a514079d9..7b5e22179489b 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -18,25 +18,33 @@ package org.apache.spark import java.util.{Timer, TimerTask} +import java.util.concurrent.ConcurrentHashMap +import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.mutable.ArrayBuffer import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} -class BarrierCoordinator( - timeout: Long, +/** + * A coordinator that handles all global sync requests from BarrierTaskContext. Each global sync + * request is generated by `BarrierTaskContext.barrier()`, and identified by + * stageId + stageAttemptId + barrierEpoch. Reply all the blocking global sync requests upon + * received all the requests for a group of `barrier()` calls. If the coordinator doesn't collect + * enough global sync requests within a configured time, fail all the requests due to timeout. + */ +private[spark] class BarrierCoordinator( + timeout: Int, override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { private val timer = new Timer("BarrierCoordinator barrier epoch increment timer") - // Barrier epoch for each stage attempt, fail a sync request if the barrier epoch in the request - // mismatches the barrier epoch in the coordinator. - private val barrierEpochByStageIdAndAttempt = new HashMap[Int, HashMap[Int, Int]] + // Epoch counter for each barrier (stage, attempt). + private val barrierEpochByStageIdAndAttempt = new ConcurrentHashMap[(Int, Int), AtomicInteger] - // Any access to this should be synchronized. + // Remember all the blocking global sync requests for each barrier (stage, attempt). private val syncRequestsByStageIdAndAttempt = - new HashMap[Int, HashMap[Int, ArrayBuffer[RpcCallContext]]] + new ConcurrentHashMap[(Int, Int), ArrayBuffer[RpcCallContext]] /** * Get the array of [[RpcCallContext]]s that correspond to a barrier sync request from a stage @@ -45,53 +53,40 @@ class BarrierCoordinator( private def getOrInitSyncRequests( stageId: Int, stageAttemptId: Int, - numTasks: Int = 0): ArrayBuffer[RpcCallContext] = synchronized { - val syncRequestsByStage = syncRequestsByStageIdAndAttempt - .getOrElseUpdate(stageId, new HashMap[Int, ArrayBuffer[RpcCallContext]]) - syncRequestsByStage.getOrElseUpdate(stageAttemptId, new ArrayBuffer[RpcCallContext](numTasks)) + numTasks: Int = 0): ArrayBuffer[RpcCallContext] = { + val requests = syncRequestsByStageIdAndAttempt.putIfAbsent((stageId, stageAttemptId), + new ArrayBuffer[RpcCallContext](numTasks)) + if (requests == null) { + syncRequestsByStageIdAndAttempt.get((stageId, stageAttemptId)) + } else { + requests + } } /** * Clean up the array of [[RpcCallContext]]s that correspond to a barrier sync request from a * stage attempt. */ - private def cleanupSyncRequests(stageId: Int, stageAttemptId: Int): Unit = synchronized { - syncRequestsByStageIdAndAttempt.get(stageId).foreach { syncRequestByStage => - syncRequestByStage.get(stageAttemptId).foreach { syncRequests => - syncRequests.clear() - } - syncRequestByStage -= stageAttemptId - if (syncRequestByStage.isEmpty) { - syncRequestsByStageIdAndAttempt -= stageId - } - logInfo(s"Removed all the pending barrier sync requests from Stage $stageId(Attempt " + - s"$stageAttemptId).") + private def cleanupSyncRequests(stageId: Int, stageAttemptId: Int): Unit = { + val requests = syncRequestsByStageIdAndAttempt.remove((stageId, stageAttemptId)) + if (requests != null) { + requests.clear() } + logInfo(s"Removed all the pending barrier sync requests from Stage $stageId (Attempt " + + s"$stageAttemptId).") } /** * Get the barrier epoch that correspond to a barrier sync request from a stage attempt. */ - private def getOrInitBarrierEpoch(stageId: Int, stageAttemptId: Int): Int = synchronized { - val barrierEpochByStage = barrierEpochByStageIdAndAttempt - .getOrElseUpdate(stageId, new HashMap[Int, Int]) - val barrierEpoch = barrierEpochByStage.getOrElseUpdate(stageAttemptId, 0) - logInfo(s"Current barrier epoch for Stage $stageId(Attempt $stageAttemptId) is $barrierEpoch.") - barrierEpoch - } - - /** - * Update the barrier epoch that correspond to a barrier sync request from a stage attempt. - */ - private def updateBarrierEpoch( - stageId: Int, - stageAttemptId: Int, - newBarrierEpoch: Int): Unit = synchronized { - val barrierEpochByStage = barrierEpochByStageIdAndAttempt - .getOrElseUpdate(stageId, new HashMap[Int, Int]) - barrierEpochByStage.put(stageAttemptId, newBarrierEpoch) - logInfo(s"Current barrier epoch for Stage $stageId(Attempt $stageAttemptId) is " + - s"$newBarrierEpoch.") + private def getOrInitBarrierEpoch(stageId: Int, stageAttemptId: Int): AtomicInteger = { + val barrierEpoch = barrierEpochByStageIdAndAttempt.putIfAbsent((stageId, stageAttemptId), + new AtomicInteger(0)) + if (barrierEpoch == null) { + barrierEpochByStageIdAndAttempt.get((stageId, stageAttemptId)) + } else { + barrierEpoch + } } /** @@ -122,51 +117,45 @@ class BarrierCoordinator( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestToSync(numTasks, stageId, stageAttemptId, taskAttemptId, barrierEpoch) => - // Check the barrier epoch, fail the sync request if barrier epoch mismatches. + // Check the barrier epoch, to see which barrier() call we are processing. val currentBarrierEpoch = getOrInitBarrierEpoch(stageId, stageAttemptId) - val syncRequests = getOrInitSyncRequests(stageId, stageAttemptId) - if (barrierEpoch != currentBarrierEpoch) { - syncRequests += context - failAllSyncRequests(syncRequests, - "The request to sync fails due to mismatched barrier epoch, the barrier epoch from " + - s"task $taskAttemptId is $barrierEpoch, while the barrier epoch from the " + - s"coordinator is $currentBarrierEpoch.") - cleanupSyncRequests(stageId, stageAttemptId) - // The global sync fails so the stage is expected to retry another attempt, all sync - // messages come from current stage attempt shall fail. - updateBarrierEpoch(stageId, stageAttemptId, -1) + logInfo(s"Current barrier epoch for Stage $stageId (Attempt $stageAttemptId) is" + + s"$currentBarrierEpoch.") + if (currentBarrierEpoch.get() != barrierEpoch) { + context.sendFailure(new SparkException(s"The request to sync of Stage $stageId (Attempt " + + s"$stageAttemptId) with barrier epoch $barrierEpoch has already finished. Maybe task " + + s"$taskAttemptId is not properly killed.")) } else { + val syncRequests = getOrInitSyncRequests(stageId, stageAttemptId) // If this is the first sync message received for a barrier() call, init a timer to ensure // we may timeout for the sync. if (syncRequests.isEmpty) { timer.schedule(new TimerTask { override def run(): Unit = { - // Timeout for current barrier() call, fail all the sync requests and reset the - // barrier epoch. + // Timeout for current barrier() call, fail all the sync requests. val requests = getOrInitSyncRequests(stageId, stageAttemptId) - failAllSyncRequests(requests, - "The coordinator didn't get all barrier sync requests for barrier epoch " + - s"$barrierEpoch from Stage $stageId(Attempt $stageAttemptId) within $timeout " + - "ms.") + failAllSyncRequests(requests, "The coordinator didn't get all barrier sync " + + s"requests for barrier epoch $barrierEpoch from Stage $stageId (Attempt " + + s"$stageAttemptId) within ${timeout}s.") cleanupSyncRequests(stageId, stageAttemptId) // The global sync fails so the stage is expected to retry another attempt, all sync // messages come from current stage attempt shall fail. - updateBarrierEpoch(stageId, stageAttemptId, -1) + currentBarrierEpoch.set(-1) } - }, timeout) + }, timeout * 1000) } syncRequests += context - logInfo(s"Barrier sync epoch $barrierEpoch from Stage $stageId(Attempt $stageAttemptId) " + - s"received update from Task $taskAttemptId, current progress: " + + logInfo(s"Barrier sync epoch $barrierEpoch from Stage $stageId (Attempt " + + s"$stageAttemptId) received update from Task $taskAttemptId, current progress: " + s"${syncRequests.size}/$numTasks.") if (maybeFinishAllSyncRequests(syncRequests, numTasks)) { // Finished current barrier() call successfully, clean up internal data and increase the // barrier epoch. - logInfo(s"Barrier sync epoch $barrierEpoch from Stage $stageId(Attempt " + + logInfo(s"Barrier sync epoch $barrierEpoch from Stage $stageId (Attempt " + s"$stageAttemptId) received all updates from tasks, finished successfully.") cleanupSyncRequests(stageId, stageAttemptId) - updateBarrierEpoch(stageId, stageAttemptId, currentBarrierEpoch + 1) + currentBarrierEpoch.incrementAndGet() } } } @@ -176,6 +165,16 @@ class BarrierCoordinator( private[spark] sealed trait BarrierCoordinatorMessage extends Serializable +/** + * A global sync request message from BarrierTaskContext, by `barrier()` call. Each request is + * identified by stageId + stageAttemptId + barrierEpoch. + * + * @param numTasks The number of global sync requests the BarrierCoordinator shall receive + * @param stageId ID of current stage + * @param stageAttemptId ID of current stage attempt + * @param taskAttemptId Unique ID of current task + * @param barrierEpoch ID of the `barrier()` call, a task may consists multiple `barrier()` calls. + */ private[spark] case class RequestToSync( numTasks: Int, stageId: Int, diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala index a8aeba5c47622..6e8d559bba9b0 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala @@ -26,7 +26,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout} -import org.apache.spark.util.RpcUtils +import org.apache.spark.util.{RpcUtils, Utils} /** A [[BarrierTaskContext]] implementation. */ private[spark] class BarrierTaskContextImpl( @@ -53,11 +53,13 @@ private[spark] class BarrierTaskContextImpl( private var barrierEpoch = 0 - private lazy val numTasks = localProperties.getProperty("numTasks", "0").toInt + private lazy val numTasks = getTaskInfos().size override def barrier(): Unit = { + val callSite = Utils.getCallSite() logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " + s"the global sync, current barrier epoch is $barrierEpoch.") + logTrace(s"Current callSite: $callSite") val startTime = System.currentTimeMillis() val timerTask = new TimerTask { @@ -73,7 +75,10 @@ private[spark] class BarrierTaskContextImpl( try { barrierCoordinator.askSync[Unit]( - message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, barrierEpoch), + message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, + barrierEpoch), + // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by + // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework. timeout = new RpcTimeout(31536000 /** = 3600 * 24 * 365 */ seconds, "barrierTimeout")) barrierEpoch += 1 logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " + diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 7adab1a6c4734..22f98873844ea 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -570,11 +570,11 @@ package object config { private[spark] val BARRIER_SYNC_TIMEOUT = ConfigBuilder("spark.barrier.sync.timeout") - .doc("The timeout in milliseconds for each barrier() call from a barrier task. If the " + + .doc("The timeout in seconds for each barrier() call from a barrier task. If the " + "coordinator didn't receive all the sync messages from barrier tasks within the " + "configed time, throw a SparkException to fail all the tasks. The default value is set " + - "to 31536000000(3600 * 24 * 365 * 1000) so the barrier() call shall wait for one year.") - .longConf - .checkValue(v => v > 0, "The value should be a positive long value.") - .createWithDefault(31536000000L) + "to 31536000(3600 * 24 * 365) so the barrier() call shall wait for one year.") + .intConf + .checkValue(v => v > 0, "The value should be a positive int value.") + .createWithDefault(31536000) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index ab4a691f88014..7d130ef1c3f32 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -432,10 +432,7 @@ private[spark] class TaskSchedulerImpl( .sortBy(_._2.partitionId) .map(_._1) .mkString(",") - addressesWithDescs.foreach { case (_, taskDesc) => - taskDesc.properties.setProperty("addresses", addressesStr) - taskDesc.properties.setProperty("numTasks", taskSet.numTasks.toString) - } + addressesWithDescs.foreach(_._2.properties.setProperty("addresses", addressesStr)) logInfo(s"Successfully scheduled all the ${addressesWithDescs.size} tasks for barrier " + s"stage ${taskSet.stageId}.") diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index a9e613568b5c5..191cc1067af70 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -25,6 +25,8 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { test("global sync by barrier() call") { val conf = new SparkConf() + // Init local cluster here so each barrier task runs in a separated process, thus `barrier()` + // call is actually useful. .setMaster("local-cluster[4, 1, 1024]") .setAppName("test-cluster") sc = new SparkContext(conf) @@ -38,7 +40,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { val times = rdd2.collect() // All the tasks shall finish global sync within a short time slot. - assert(times.max - times.min <= 5) + assert(times.max - times.min <= 1000) } test("support multiple barrier() call within a single task") { @@ -62,25 +64,25 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { val times = rdd2.collect() // All the tasks shall finish the first round of global sync within a short time slot. val times1 = times.map(_._1) - assert(times1.max - times1.min <= 5) + assert(times1.max - times1.min <= 1000) // All the tasks shall finish the second round of global sync within a short time slot. val times2 = times.map(_._2) - assert(times2.max - times2.min <= 5) + assert(times2.max - times2.min <= 1000) } test("throw exception on barrier() call timeout") { val conf = new SparkConf() - .set("spark.barrier.sync.timeout", "100") + .set("spark.barrier.sync.timeout", "1") .set("spark.test.noStageRetry", "true") .setMaster("local-cluster[4, 1, 1024]") .setAppName("test-cluster") sc = new SparkContext(conf) val rdd = sc.makeRDD(1 to 10, 4) val rdd2 = rdd.barrier().mapPartitions { (it, context) => - // Task 3 shall sleep 200ms to ensure barrier() call timeout + // Task 3 shall sleep 2000ms to ensure barrier() call timeout if (context.taskAttemptId() == 3) { - Thread.sleep(200) + Thread.sleep(2000) } context.barrier() it @@ -90,12 +92,12 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { rdd2.collect() }.getMessage assert(error.contains("The coordinator didn't get all barrier sync requests")) - assert(error.contains("within 100 ms")) + assert(error.contains("within 1s")) } test("throw exception if barrier() call doesn't happen on every task") { val conf = new SparkConf() - .set("spark.barrier.sync.timeout", "100") + .set("spark.barrier.sync.timeout", "1") .set("spark.test.noStageRetry", "true") .setMaster("local-cluster[4, 1, 1024]") .setAppName("test-cluster") @@ -112,11 +114,12 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { rdd2.collect() }.getMessage assert(error.contains("The coordinator didn't get all barrier sync requests")) - assert(error.contains("within 100 ms")) + assert(error.contains("within 1s")) } - ignore("throw exception if barrier() call mismatched") { + test("throw exception if the number of barrier() calls are not the same on every task") { val conf = new SparkConf() + .set("spark.barrier.sync.timeout", "1") .set("spark.test.noStageRetry", "true") .setMaster("local-cluster[4, 1, 1024]") .setAppName("test-cluster") @@ -139,6 +142,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { val error = intercept[SparkException] { rdd2.collect() }.getMessage - assert(error.contains("fails due to mismatched barrier epoch")) + assert(error.contains("The coordinator didn't get all barrier sync requests")) + assert(error.contains("within 1s")) } } From 7e413b4daaecdf36e4d96c0de2e8209266fba8c2 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Thu, 2 Aug 2018 01:44:31 +0800 Subject: [PATCH 05/19] cleanup internal data on stage completed --- .../org/apache/spark/BarrierCoordinator.scala | 16 ++++++++++++++++ .../spark/scheduler/TaskSchedulerImpl.scala | 2 +- 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 7b5e22179489b..8178b855b867d 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted} /** * A coordinator that handles all global sync requests from BarrierTaskContext. Each global sync @@ -35,10 +36,20 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} */ private[spark] class BarrierCoordinator( timeout: Int, + listenerBus: LiveListenerBus, override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { private val timer = new Timer("BarrierCoordinator barrier epoch increment timer") + private val listener = new SparkListener { + override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { + val stageInfo = stageCompleted.stageInfo + // Remove internal data from a finished stage attempt. + cleanupSyncRequests(stageInfo.stageId, stageInfo.attemptNumber) + barrierEpochByStageIdAndAttempt.remove((stageInfo.stageId, stageInfo.attemptNumber)) + } + } + // Epoch counter for each barrier (stage, attempt). private val barrierEpochByStageIdAndAttempt = new ConcurrentHashMap[(Int, Int), AtomicInteger] @@ -46,6 +57,11 @@ private[spark] class BarrierCoordinator( private val syncRequestsByStageIdAndAttempt = new ConcurrentHashMap[(Int, Int), ArrayBuffer[RpcCallContext]] + override def onStart(): Unit = { + super.onStart() + listenerBus.addToStatusQueue(listener) + } + /** * Get the array of [[RpcCallContext]]s that correspond to a barrier sync request from a stage * attempt. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 7d130ef1c3f32..89618e11b3da6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -142,7 +142,7 @@ private[spark] class TaskSchedulerImpl( private lazy val barrierSyncTimeout = conf.get(config.BARRIER_SYNC_TIMEOUT) private[scheduler] lazy val barrierCoordinator: RpcEndpoint = { - val coordinator = new BarrierCoordinator(barrierSyncTimeout, sc.env.rpcEnv) + val coordinator = new BarrierCoordinator(barrierSyncTimeout, sc.listenerBus, sc.env.rpcEnv) sc.env.rpcEnv.setupEndpoint("barrierSync", coordinator) logInfo("Registered BarrierCoordinator endpoint") coordinator From da447905ef7b947b0080616357dfaf032f85fedc Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 3 Aug 2018 01:42:01 +0800 Subject: [PATCH 06/19] address comments --- .../org/apache/spark/BarrierCoordinator.scala | 70 +++++++++++++++---- 1 file changed, 55 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 8178b855b867d..65fe360f1101a 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -47,16 +47,23 @@ private[spark] class BarrierCoordinator( // Remove internal data from a finished stage attempt. cleanupSyncRequests(stageInfo.stageId, stageInfo.attemptNumber) barrierEpochByStageIdAndAttempt.remove((stageInfo.stageId, stageInfo.attemptNumber)) + cancelTimerTask(stageInfo.stageId, stageInfo.attemptNumber) } } // Epoch counter for each barrier (stage, attempt). - private val barrierEpochByStageIdAndAttempt = new ConcurrentHashMap[(Int, Int), AtomicInteger] + private val barrierEpochByStageIdAndAttempt = new ConcurrentHashMap[(Int, Int), Int] // Remember all the blocking global sync requests for each barrier (stage, attempt). private val syncRequestsByStageIdAndAttempt = new ConcurrentHashMap[(Int, Int), ArrayBuffer[RpcCallContext]] + // Remember all the TimerTasks for each barrier (stage, attempt). + private val timerTaskByStageIdAndAttempt = new ConcurrentHashMap[(Int, Int), TimerTask] + + // Number of tasks for each stage. + private val numTasksByStage = new ConcurrentHashMap[Int, Int] + override def onStart(): Unit = { super.onStart() listenerBus.addToStatusQueue(listener) @@ -69,7 +76,7 @@ private[spark] class BarrierCoordinator( private def getOrInitSyncRequests( stageId: Int, stageAttemptId: Int, - numTasks: Int = 0): ArrayBuffer[RpcCallContext] = { + numTasks: Int): ArrayBuffer[RpcCallContext] = { val requests = syncRequestsByStageIdAndAttempt.putIfAbsent((stageId, stageAttemptId), new ArrayBuffer[RpcCallContext](numTasks)) if (requests == null) { @@ -95,16 +102,40 @@ private[spark] class BarrierCoordinator( /** * Get the barrier epoch that correspond to a barrier sync request from a stage attempt. */ - private def getOrInitBarrierEpoch(stageId: Int, stageAttemptId: Int): AtomicInteger = { + private def getOrInitBarrierEpoch(stageId: Int, stageAttemptId: Int): Int = { + val defaultBarrierEpoch = 0 val barrierEpoch = barrierEpochByStageIdAndAttempt.putIfAbsent((stageId, stageAttemptId), - new AtomicInteger(0)) + defaultBarrierEpoch) if (barrierEpoch == null) { - barrierEpochByStageIdAndAttempt.get((stageId, stageAttemptId)) + defaultBarrierEpoch } else { barrierEpoch } } + /** + * Increase the barrier epoch that correspond to a barrier sync request from a stage attempt. + */ + private def increaseBarrierEpoch(stageId: Int, stageAttemptId: Int): Unit = { + val barrierEpoch = barrierEpochByStageIdAndAttempt.get((stageId, stageAttemptId)) + if (barrierEpoch != null) { + barrierEpochByStageIdAndAttempt.put((stageId, stageAttemptId), barrierEpoch + 1) + } else { + // The barrier epoch have been removed because the stage attempt already completed. + } + } + + /** + * Cancel TimerTask for a stage attempt. + */ + private def cancelTimerTask(stageId: Int, stageAttemptId: Int): Unit = { + val timerTask = timerTaskByStageIdAndAttempt.get((stageId, stageAttemptId)) + if (timerTask != null) { + timerTask.cancel() + timerTaskByStageIdAndAttempt.remove((stageId, stageAttemptId)) + } + } + /** * Send failure to all the blocking barrier sync requests from a stage attempt with proper * failure message. @@ -124,41 +155,49 @@ private[spark] class BarrierCoordinator( numTasks: Int): Boolean = { if (syncRequests.size == numTasks) { syncRequests.foreach(_.reply(())) - return true + true + } else { + false } - - false } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestToSync(numTasks, stageId, stageAttemptId, taskAttemptId, barrierEpoch) => + // Require the number of tasks is correctly set from the BarrierTaskContext. + val currentNumTasks: Any = numTasksByStage.putIfAbsent(stageId, numTasks) + require(currentNumTasks == null || currentNumTasks == numTasks, "Number of tasks of " + + s"Stage $stageId is $numTasks from Task $taskAttemptId, previously it was " + + s"$currentNumTasks.") + // Check the barrier epoch, to see which barrier() call we are processing. val currentBarrierEpoch = getOrInitBarrierEpoch(stageId, stageAttemptId) logInfo(s"Current barrier epoch for Stage $stageId (Attempt $stageAttemptId) is" + s"$currentBarrierEpoch.") - if (currentBarrierEpoch.get() != barrierEpoch) { + if (currentBarrierEpoch != barrierEpoch) { context.sendFailure(new SparkException(s"The request to sync of Stage $stageId (Attempt " + s"$stageAttemptId) with barrier epoch $barrierEpoch has already finished. Maybe task " + s"$taskAttemptId is not properly killed.")) } else { - val syncRequests = getOrInitSyncRequests(stageId, stageAttemptId) + val syncRequests = getOrInitSyncRequests(stageId, stageAttemptId, numTasks) // If this is the first sync message received for a barrier() call, init a timer to ensure // we may timeout for the sync. if (syncRequests.isEmpty) { - timer.schedule(new TimerTask { + val timerTask = new TimerTask { override def run(): Unit = { // Timeout for current barrier() call, fail all the sync requests. - val requests = getOrInitSyncRequests(stageId, stageAttemptId) + val requests = getOrInitSyncRequests(stageId, stageAttemptId, numTasks) failAllSyncRequests(requests, "The coordinator didn't get all barrier sync " + s"requests for barrier epoch $barrierEpoch from Stage $stageId (Attempt " + s"$stageAttemptId) within ${timeout}s.") cleanupSyncRequests(stageId, stageAttemptId) // The global sync fails so the stage is expected to retry another attempt, all sync // messages come from current stage attempt shall fail. - currentBarrierEpoch.set(-1) + barrierEpochByStageIdAndAttempt.put((stageId, stageAttemptId), -1) } - }, timeout * 1000) + } + timer.schedule(timerTask, timeout * 1000) + timerTaskByStageIdAndAttempt.putIfAbsent((stageId, stageAttemptId), timerTask) } syncRequests += context @@ -171,7 +210,8 @@ private[spark] class BarrierCoordinator( logInfo(s"Barrier sync epoch $barrierEpoch from Stage $stageId (Attempt " + s"$stageAttemptId) received all updates from tasks, finished successfully.") cleanupSyncRequests(stageId, stageAttemptId) - currentBarrierEpoch.incrementAndGet() + increaseBarrierEpoch(stageId, stageAttemptId) + cancelTimerTask(stageId, stageAttemptId) } } } From 3ced829308c7c30cc29f3b8c363f1405d0ae08e1 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 3 Aug 2018 09:08:23 +0800 Subject: [PATCH 07/19] update --- .../org/apache/spark/BarrierCoordinator.scala | 24 ++++++------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 65fe360f1101a..5f2a88461a711 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -77,13 +77,9 @@ private[spark] class BarrierCoordinator( stageId: Int, stageAttemptId: Int, numTasks: Int): ArrayBuffer[RpcCallContext] = { - val requests = syncRequestsByStageIdAndAttempt.putIfAbsent((stageId, stageAttemptId), + syncRequestsByStageIdAndAttempt.putIfAbsent((stageId, stageAttemptId), new ArrayBuffer[RpcCallContext](numTasks)) - if (requests == null) { - syncRequestsByStageIdAndAttempt.get((stageId, stageAttemptId)) - } else { - requests - } + syncRequestsByStageIdAndAttempt.get((stageId, stageAttemptId)) } /** @@ -103,25 +99,19 @@ private[spark] class BarrierCoordinator( * Get the barrier epoch that correspond to a barrier sync request from a stage attempt. */ private def getOrInitBarrierEpoch(stageId: Int, stageAttemptId: Int): Int = { - val defaultBarrierEpoch = 0 - val barrierEpoch = barrierEpochByStageIdAndAttempt.putIfAbsent((stageId, stageAttemptId), - defaultBarrierEpoch) - if (barrierEpoch == null) { - defaultBarrierEpoch - } else { - barrierEpoch - } + barrierEpochByStageIdAndAttempt.putIfAbsent((stageId, stageAttemptId), 0) + barrierEpochByStageIdAndAttempt.get((stageId, stageAttemptId)) } /** * Increase the barrier epoch that correspond to a barrier sync request from a stage attempt. */ private def increaseBarrierEpoch(stageId: Int, stageAttemptId: Int): Unit = { - val barrierEpoch = barrierEpochByStageIdAndAttempt.get((stageId, stageAttemptId)) - if (barrierEpoch != null) { + val barrierEpoch = barrierEpochByStageIdAndAttempt.getOrDefault((stageId, stageAttemptId), -1) + if (barrierEpoch >= 0) { barrierEpochByStageIdAndAttempt.put((stageId, stageAttemptId), barrierEpoch + 1) } else { - // The barrier epoch have been removed because the stage attempt already completed. + // The stage attempt already finished, don't update barrier epoch. } } From 2f23e44d34f6ff4429d018cc601e654efb7031ad Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 3 Aug 2018 09:20:07 +0800 Subject: [PATCH 08/19] update --- .../org/apache/spark/BarrierTaskContext.scala | 58 +++++++++- .../apache/spark/BarrierTaskContextImpl.scala | 104 ------------------ 2 files changed, 56 insertions(+), 106 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index ba303680d1a0f..6dcf811e8aed2 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -17,12 +17,17 @@ package org.apache.spark -import java.util.Properties +import java.util.{Properties, Timer, TimerTask} + +import scala.concurrent.duration._ +import scala.language.postfixOps import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout} +import org.apache.spark.util.{RpcUtils, Utils} /** A [[TaskContext]] with extra info and tooling for a barrier stage. */ class BarrierTaskContext( @@ -39,6 +44,17 @@ class BarrierTaskContext( extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, taskMemoryManager, localProperties, metricsSystem, taskMetrics) { + private val barrierCoordinator: RpcEndpointRef = { + val env = SparkEnv.get + RpcUtils.makeDriverRef("barrierSync", env.conf, env.rpcEnv) + } + + private val timer = new Timer("Barrier task timer for barrier() calls.") + + private var barrierEpoch = 0 + + private lazy val numTasks = getTaskInfos().size + /** * :: Experimental :: * Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to @@ -80,7 +96,45 @@ class BarrierTaskContext( @Experimental @Since("2.4.0") def barrier(): Unit = { - // TODO SPARK-24817 implement global barrier. + val callSite = Utils.getCallSite() + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " + + s"the global sync, current barrier epoch is $barrierEpoch.") + logTrace(s"Current callSite: $callSite") + + val startTime = System.currentTimeMillis() + val timerTask = new TimerTask { + override def run(): Unit = { + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " + + s"under the global sync since $startTime, has been waiting for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " + + s"is $barrierEpoch.") + } + } + // Log the update of global sync every 60 seconds. + timer.schedule(timerTask, 60000, 60000) + + try { + barrierCoordinator.askSync[Unit]( + message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, + barrierEpoch), + // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by + // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework. + timeout = new RpcTimeout(31536000 /** = 3600 * 24 * 365 */ seconds, "barrierTimeout")) + barrierEpoch += 1 + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " + + "global sync successfully, waited for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch is " + + s"$barrierEpoch.") + } catch { + case e: SparkException => + logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " + + "to perform global sync, waited for " + + s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " + + s"is $barrierEpoch.") + throw e + } finally { + timerTask.cancel() + } } /** diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala deleted file mode 100644 index 6e8d559bba9b0..0000000000000 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContextImpl.scala +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import java.util.{Properties, Timer, TimerTask} - -import scala.concurrent.duration._ -import scala.language.postfixOps - -import org.apache.spark.executor.TaskMetrics -import org.apache.spark.memory.TaskMemoryManager -import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.rpc.{RpcEndpointRef, RpcTimeout} -import org.apache.spark.util.{RpcUtils, Utils} - -/** A [[BarrierTaskContext]] implementation. */ -private[spark] class BarrierTaskContextImpl( - override val stageId: Int, - override val stageAttemptNumber: Int, - override val partitionId: Int, - override val taskAttemptId: Long, - override val attemptNumber: Int, - override val taskMemoryManager: TaskMemoryManager, - localProperties: Properties, - @transient private val metricsSystem: MetricsSystem, - // The default value is only used in tests. - override val taskMetrics: TaskMetrics = TaskMetrics.empty) - extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, - taskMemoryManager, localProperties, metricsSystem, taskMetrics) - with BarrierTaskContext { - - private val barrierCoordinator: RpcEndpointRef = { - val env = SparkEnv.get - RpcUtils.makeDriverRef("barrierSync", env.conf, env.rpcEnv) - } - - private val timer = new Timer("Barrier task timer for barrier() calls.") - - private var barrierEpoch = 0 - - private lazy val numTasks = getTaskInfos().size - - override def barrier(): Unit = { - val callSite = Utils.getCallSite() - logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " + - s"the global sync, current barrier epoch is $barrierEpoch.") - logTrace(s"Current callSite: $callSite") - - val startTime = System.currentTimeMillis() - val timerTask = new TimerTask { - override def run(): Unit = { - logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) waiting " + - s"under the global sync since $startTime, has been waiting for " + - s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " + - s"is $barrierEpoch.") - } - } - // Log the update of global sync every 60 seconds. - timer.schedule(timerTask, 60000, 60000) - - try { - barrierCoordinator.askSync[Unit]( - message = RequestToSync(numTasks, stageId, stageAttemptNumber, taskAttemptId, - barrierEpoch), - // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by - // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework. - timeout = new RpcTimeout(31536000 /** = 3600 * 24 * 365 */ seconds, "barrierTimeout")) - barrierEpoch += 1 - logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " + - "global sync successfully, waited for " + - s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch is " + - s"$barrierEpoch.") - } catch { - case e: SparkException => - logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) failed " + - "to perform global sync, waited for " + - s"${(System.currentTimeMillis() - startTime) / 1000} seconds, current barrier epoch " + - s"is $barrierEpoch.") - throw e - } finally { - timerTask.cancel() - } - } - - override def getTaskInfos(): Array[BarrierTaskInfo] = { - val addressesStr = localProperties.getProperty("addresses", "") - addressesStr.split(",").map(_.trim()).map(new BarrierTaskInfo(_)) - } -} From 67dcf17a47333c030a877a4fade463747c7bcf38 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 3 Aug 2018 11:47:27 +0800 Subject: [PATCH 09/19] update --- .../apache/spark/scheduler/BarrierTaskContextSuite.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index 191cc1067af70..380521d1ce305 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -81,7 +81,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { val rdd = sc.makeRDD(1 to 10, 4) val rdd2 = rdd.barrier().mapPartitions { (it, context) => // Task 3 shall sleep 2000ms to ensure barrier() call timeout - if (context.taskAttemptId() == 3) { + if (context.taskAttemptId == 3) { Thread.sleep(2000) } context.barrier() @@ -104,7 +104,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { sc = new SparkContext(conf) val rdd = sc.makeRDD(1 to 10, 4) val rdd2 = rdd.barrier().mapPartitions { (it, context) => - if (context.taskAttemptId() != 0) { + if (context.taskAttemptId != 0) { context.barrier() } it @@ -127,7 +127,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { val rdd = sc.makeRDD(1 to 10, 4) val rdd2 = rdd.barrier().mapPartitions { (it, context) => try { - if (context.taskAttemptId() == 0) { + if (context.taskAttemptId == 0) { // Task 0 skip the first barrier() call. throw new SparkException("test") } From e29e3b6b087594c1a52b45705724fe4766b1b236 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 3 Aug 2018 21:55:48 +0800 Subject: [PATCH 10/19] refactor --- .../org/apache/spark/BarrierCoordinator.scala | 231 +++++++++--------- 1 file changed, 114 insertions(+), 117 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 5f2a88461a711..75936bfd90a66 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.util.{Timer, TimerTask} import java.util.concurrent.ConcurrentHashMap -import java.util.concurrent.atomic.AtomicInteger +import java.util.function.Consumer import scala.collection.mutable.ArrayBuffer @@ -27,6 +27,14 @@ import org.apache.spark.internal.Logging import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted} +/** + * Only one barrier() call shall happen on a barrier stage attempt at each time, we can use + * (stageId, stageAttemptId) to identify the stage attempt where the barrier() call is from. + */ +private case class ContextBarrierId(stageId: Int, stageAttemptId: Int) { + override def toString: String = s"Stage $stageId (Attempt $stageAttemptId)" +} + /** * A coordinator that handles all global sync requests from BarrierTaskContext. Each global sync * request is generated by `BarrierTaskContext.barrier()`, and identified by @@ -44,25 +52,16 @@ private[spark] class BarrierCoordinator( private val listener = new SparkListener { override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { val stageInfo = stageCompleted.stageInfo - // Remove internal data from a finished stage attempt. - cleanupSyncRequests(stageInfo.stageId, stageInfo.attemptNumber) - barrierEpochByStageIdAndAttempt.remove((stageInfo.stageId, stageInfo.attemptNumber)) - cancelTimerTask(stageInfo.stageId, stageInfo.attemptNumber) + val barrierId = ContextBarrierId(stageInfo.stageId, stageInfo.attemptNumber) + // Clear ContextBarrierState from a finished stage attempt. + val barrierState = states.remove(barrierId) + if (barrierState != null) { + barrierState.clear() + } } } - // Epoch counter for each barrier (stage, attempt). - private val barrierEpochByStageIdAndAttempt = new ConcurrentHashMap[(Int, Int), Int] - - // Remember all the blocking global sync requests for each barrier (stage, attempt). - private val syncRequestsByStageIdAndAttempt = - new ConcurrentHashMap[(Int, Int), ArrayBuffer[RpcCallContext]] - - // Remember all the TimerTasks for each barrier (stage, attempt). - private val timerTaskByStageIdAndAttempt = new ConcurrentHashMap[(Int, Int), TimerTask] - - // Number of tasks for each stage. - private val numTasksByStage = new ConcurrentHashMap[Int, Int] + private val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState] override def onStart(): Unit = { super.onStart() @@ -70,59 +69,91 @@ private[spark] class BarrierCoordinator( } /** - * Get the array of [[RpcCallContext]]s that correspond to a barrier sync request from a stage - * attempt. + * Provide current state of a barrier() call, the state is created when a new stage attempt send + * out a barrier() call, and recycled on stage completed. + * + * @param barrierId Identifier of the barrier stage that make a barrier() call. + * @param numTasks Number of tasks of the barrier stage, all barrier() calls from the stage shall + * collect `numTasks` requests to succeed. */ - private def getOrInitSyncRequests( - stageId: Int, - stageAttemptId: Int, - numTasks: Int): ArrayBuffer[RpcCallContext] = { - syncRequestsByStageIdAndAttempt.putIfAbsent((stageId, stageAttemptId), - new ArrayBuffer[RpcCallContext](numTasks)) - syncRequestsByStageIdAndAttempt.get((stageId, stageAttemptId)) - } + private class ContextBarrierState( + val barrierId: ContextBarrierId, + val numTasks: Int) { + + // There may be multiple barrier() calls from a barrier stage attempt, `barrierEpoch` is used + // to identify each barrier() call. It shall get increased when a barrier() call succeed, or + // reset when a barrier() call fail due to timeout. + private var barrierEpoch: Int = 0 + + // An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier() + // call. + private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks) + + // A timer task that ensures we may timeout for a barrier() call. + private var timerTask: TimerTask = null + + private def initTimerTask(): TimerTask = new TimerTask { + override def run(): Unit = { + // Timeout current barrier() call, fail all the sync requests. + failAllRequesters(requesters, "The coordinator didn't get all barrier sync " + + s"requests for barrier epoch $barrierEpoch from $barrierId within ${timeout}s.") + cleanupBarrierStage(barrierId) + } + } - /** - * Clean up the array of [[RpcCallContext]]s that correspond to a barrier sync request from a - * stage attempt. - */ - private def cleanupSyncRequests(stageId: Int, stageAttemptId: Int): Unit = { - val requests = syncRequestsByStageIdAndAttempt.remove((stageId, stageAttemptId)) - if (requests != null) { - requests.clear() + private def cancelTimerTask(): Unit = { + if (timerTask != null) { + timerTask.cancel() + timerTask = null + } } - logInfo(s"Removed all the pending barrier sync requests from Stage $stageId (Attempt " + - s"$stageAttemptId).") - } - /** - * Get the barrier epoch that correspond to a barrier sync request from a stage attempt. - */ - private def getOrInitBarrierEpoch(stageId: Int, stageAttemptId: Int): Int = { - barrierEpochByStageIdAndAttempt.putIfAbsent((stageId, stageAttemptId), 0) - barrierEpochByStageIdAndAttempt.get((stageId, stageAttemptId)) - } + def handleRequest(requester: RpcCallContext, epoch: Int, taskId: Long): Unit = synchronized { + // Check whether the epoch from the barrier tasks matches current barrierEpoch. + logInfo(s"Current barrier epoch for $barrierId is $barrierEpoch.") + if (epoch != barrierEpoch) { + requester.sendFailure(new SparkException(s"The request to sync of $barrierId with " + + s"barrier epoch $barrierEpoch has already finished. Maybe task $taskId is not " + + "properly killed.")) + } else { + // If this is the first sync message received for a barrier() call, start timer to ensure + // we may timeout for the sync. + if (requesters.isEmpty) { + timerTask = initTimerTask() + timer.schedule(timerTask, timeout * 1000L) + } + // Add the requester to array of RPCCallContexts pending for reply. + requesters += requester + logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received update from Task " + + s"$taskId, current progress: ${requesters.size}/$numTasks.") + if (maybeFinishAllRequesters(requesters, numTasks)) { + // Finished current barrier() call successfully, clean up ContextBarrierState and + // increase the barrier epoch. + logInfo(s"Barrier sync epoch $barrierEpoch from $barrierId received all updates from " + + s"tasks, finished successfully.") + barrierEpoch += 1 + requesters.clear() + cancelTimerTask() + } + } + } - /** - * Increase the barrier epoch that correspond to a barrier sync request from a stage attempt. - */ - private def increaseBarrierEpoch(stageId: Int, stageAttemptId: Int): Unit = { - val barrierEpoch = barrierEpochByStageIdAndAttempt.getOrDefault((stageId, stageAttemptId), -1) - if (barrierEpoch >= 0) { - barrierEpochByStageIdAndAttempt.put((stageId, stageAttemptId), barrierEpoch + 1) - } else { - // The stage attempt already finished, don't update barrier epoch. + def clear(): Unit = synchronized { + // The global sync fails so the stage is expected to retry another attempt, all sync + // messages come from current stage attempt shall fail. + barrierEpoch = -1 + requesters.clear() + cancelTimerTask() } } /** - * Cancel TimerTask for a stage attempt. + * Clean up the [[ContextBarrierState]] that correspond to a stage attempt. */ - private def cancelTimerTask(stageId: Int, stageAttemptId: Int): Unit = { - val timerTask = timerTaskByStageIdAndAttempt.get((stageId, stageAttemptId)) - if (timerTask != null) { - timerTask.cancel() - timerTaskByStageIdAndAttempt.remove((stageId, stageAttemptId)) + private def cleanupBarrierStage(barrierId: ContextBarrierId): Unit = { + val barrierState = states.remove(barrierId) + if (barrierState != null) { + barrierState.clear() } } @@ -130,83 +161,49 @@ private[spark] class BarrierCoordinator( * Send failure to all the blocking barrier sync requests from a stage attempt with proper * failure message. */ - private def failAllSyncRequests( - syncRequests: ArrayBuffer[RpcCallContext], + private def failAllRequesters( + requesters: ArrayBuffer[RpcCallContext], message: String): Unit = { - syncRequests.foreach(_.sendFailure(new SparkException(message))) + requesters.foreach(_.sendFailure(new SparkException(message))) } /** * Finish all the blocking barrier sync requests from a stage attempt successfully if we * have received all the sync requests. */ - private def maybeFinishAllSyncRequests( - syncRequests: ArrayBuffer[RpcCallContext], + private def maybeFinishAllRequesters( + requesters: ArrayBuffer[RpcCallContext], numTasks: Int): Boolean = { - if (syncRequests.size == numTasks) { - syncRequests.foreach(_.reply(())) + if (requesters.size == numTasks) { + requesters.foreach(_.reply(())) true } else { false } } - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RequestToSync(numTasks, stageId, stageAttemptId, taskAttemptId, barrierEpoch) => + // Get or init the ContextBarrierState correspond to the stage attempt. + val barrierId = ContextBarrierId(stageId, stageAttemptId) + states.putIfAbsent(barrierId, new ContextBarrierState(barrierId, numTasks)) + val barrierState = states.get(barrierId) + // Require the number of tasks is correctly set from the BarrierTaskContext. - val currentNumTasks: Any = numTasksByStage.putIfAbsent(stageId, numTasks) - require(currentNumTasks == null || currentNumTasks == numTasks, "Number of tasks of " + - s"Stage $stageId is $numTasks from Task $taskAttemptId, previously it was " + - s"$currentNumTasks.") - - // Check the barrier epoch, to see which barrier() call we are processing. - val currentBarrierEpoch = getOrInitBarrierEpoch(stageId, stageAttemptId) - logInfo(s"Current barrier epoch for Stage $stageId (Attempt $stageAttemptId) is" + - s"$currentBarrierEpoch.") - if (currentBarrierEpoch != barrierEpoch) { - context.sendFailure(new SparkException(s"The request to sync of Stage $stageId (Attempt " + - s"$stageAttemptId) with barrier epoch $barrierEpoch has already finished. Maybe task " + - s"$taskAttemptId is not properly killed.")) - } else { - val syncRequests = getOrInitSyncRequests(stageId, stageAttemptId, numTasks) - // If this is the first sync message received for a barrier() call, init a timer to ensure - // we may timeout for the sync. - if (syncRequests.isEmpty) { - val timerTask = new TimerTask { - override def run(): Unit = { - // Timeout for current barrier() call, fail all the sync requests. - val requests = getOrInitSyncRequests(stageId, stageAttemptId, numTasks) - failAllSyncRequests(requests, "The coordinator didn't get all barrier sync " + - s"requests for barrier epoch $barrierEpoch from Stage $stageId (Attempt " + - s"$stageAttemptId) within ${timeout}s.") - cleanupSyncRequests(stageId, stageAttemptId) - // The global sync fails so the stage is expected to retry another attempt, all sync - // messages come from current stage attempt shall fail. - barrierEpochByStageIdAndAttempt.put((stageId, stageAttemptId), -1) - } - } - timer.schedule(timerTask, timeout * 1000) - timerTaskByStageIdAndAttempt.putIfAbsent((stageId, stageAttemptId), timerTask) - } + require(barrierState.numTasks == numTasks, s"Number of tasks of $barrierId is $numTasks " + + s"from Task $taskAttemptId, previously it was ${barrierState.numTasks}.") - syncRequests += context - logInfo(s"Barrier sync epoch $barrierEpoch from Stage $stageId (Attempt " + - s"$stageAttemptId) received update from Task $taskAttemptId, current progress: " + - s"${syncRequests.size}/$numTasks.") - if (maybeFinishAllSyncRequests(syncRequests, numTasks)) { - // Finished current barrier() call successfully, clean up internal data and increase the - // barrier epoch. - logInfo(s"Barrier sync epoch $barrierEpoch from Stage $stageId (Attempt " + - s"$stageAttemptId) received all updates from tasks, finished successfully.") - cleanupSyncRequests(stageId, stageAttemptId) - increaseBarrierEpoch(stageId, stageAttemptId) - cancelTimerTask(stageId, stageAttemptId) - } - } + barrierState.handleRequest(context, barrierEpoch, taskAttemptId) + } + + private val stateConsumer = new Consumer[ContextBarrierState] { + override def accept(state: ContextBarrierState) = state.clear() } - override def onStop(): Unit = timer.cancel() + override def onStop(): Unit = { + states.forEachValue(1, stateConsumer) + states.clear() + } } private[spark] sealed trait BarrierCoordinatorMessage extends Serializable From 33a89269906e94780e1b992e9400c9d8f98698f5 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Fri, 3 Aug 2018 22:21:39 +0800 Subject: [PATCH 11/19] update comment --- .../main/scala/org/apache/spark/BarrierCoordinator.scala | 8 ++++++++ .../main/scala/org/apache/spark/BarrierTaskContext.scala | 5 +++++ 2 files changed, 13 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 75936bfd90a66..c3e114c6b2e6d 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -49,6 +49,7 @@ private[spark] class BarrierCoordinator( private val timer = new Timer("BarrierCoordinator barrier epoch increment timer") + // Listen to StageCompleted event, clear corresponding ContextBarrierState. private val listener = new SparkListener { override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = { val stageInfo = stageCompleted.stageInfo @@ -61,6 +62,8 @@ private[spark] class BarrierCoordinator( } } + // Remember all active stage attempts that make barrier() call(s), and the corresponding + // internal state. private val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState] override def onStart(): Unit = { @@ -92,6 +95,7 @@ private[spark] class BarrierCoordinator( // A timer task that ensures we may timeout for a barrier() call. private var timerTask: TimerTask = null + // Init a TimerTask for a barrier() call. private def initTimerTask(): TimerTask = new TimerTask { override def run(): Unit = { // Timeout current barrier() call, fail all the sync requests. @@ -101,6 +105,7 @@ private[spark] class BarrierCoordinator( } } + // Cancel the current active TimerTask and release resources. private def cancelTimerTask(): Unit = { if (timerTask != null) { timerTask.cancel() @@ -108,6 +113,8 @@ private[spark] class BarrierCoordinator( } } + // Process the global sync request. The barrier() call succeed if collected enough requests + // within a configured time, otherwise fail all the pending requests. def handleRequest(requester: RpcCallContext, epoch: Int, taskId: Long): Unit = synchronized { // Check whether the epoch from the barrier tasks matches current barrierEpoch. logInfo(s"Current barrier epoch for $barrierId is $barrierEpoch.") @@ -138,6 +145,7 @@ private[spark] class BarrierCoordinator( } } + // Cleanup the internal state of a barrier stage attempt. def clear(): Unit = synchronized { // The global sync fails so the stage is expected to retry another attempt, all sync // messages come from current stage attempt shall fail. diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 6dcf811e8aed2..e78ad0393dfea 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -44,6 +44,7 @@ class BarrierTaskContext( extends TaskContextImpl(stageId, stageAttemptNumber, partitionId, taskAttemptId, attemptNumber, taskMemoryManager, localProperties, metricsSystem, taskMetrics) { + // Find the driver side RPCEndpointRef of the coordinator that handles all the barrier() calls. private val barrierCoordinator: RpcEndpointRef = { val env = SparkEnv.get RpcUtils.makeDriverRef("barrierSync", env.conf, env.rpcEnv) @@ -51,8 +52,12 @@ class BarrierTaskContext( private val timer = new Timer("Barrier task timer for barrier() calls.") + // Local barrierEpoch that identify a barrier() call from current task, it shall be identical + // with the driver side epoch. private var barrierEpoch = 0 + // Number of tasks of the current barrier stage, a barrier() call must collect enough requests + // from different tasks within the same barrier stage attempt to succeed. private lazy val numTasks = getTaskInfos().size /** From 16ee90e5bbecbb94b54d61703e7d7d17a58b3bcb Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sat, 4 Aug 2018 00:57:31 +0800 Subject: [PATCH 12/19] update --- .../org/apache/spark/BarrierCoordinator.scala | 87 +++++++++---------- .../org/apache/spark/BarrierTaskContext.scala | 2 +- 2 files changed, 44 insertions(+), 45 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index c3e114c6b2e6d..b5141348af18a 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -96,12 +96,14 @@ private[spark] class BarrierCoordinator( private var timerTask: TimerTask = null // Init a TimerTask for a barrier() call. - private def initTimerTask(): TimerTask = new TimerTask { - override def run(): Unit = { - // Timeout current barrier() call, fail all the sync requests. - failAllRequesters(requesters, "The coordinator didn't get all barrier sync " + - s"requests for barrier epoch $barrierEpoch from $barrierId within ${timeout}s.") - cleanupBarrierStage(barrierId) + private def initTimerTask(): Unit = { + timerTask = new TimerTask { + override def run(): Unit = { + // Timeout current barrier() call, fail all the sync requests. + failAllRequesters(requesters, "The coordinator didn't get all barrier sync " + + s"requests for barrier epoch $barrierEpoch from $barrierId within ${timeout}s.") + cleanupBarrierStage(barrierId) + } } } @@ -115,7 +117,14 @@ private[spark] class BarrierCoordinator( // Process the global sync request. The barrier() call succeed if collected enough requests // within a configured time, otherwise fail all the pending requests. - def handleRequest(requester: RpcCallContext, epoch: Int, taskId: Long): Unit = synchronized { + def handleRequest(requester: RpcCallContext, request: RequestToSync): Unit = synchronized { + val taskId = request.taskAttemptId + val epoch = request.barrierEpoch + + // Require the number of tasks is correctly set from the BarrierTaskContext. + require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " + + s"${request.numTasks} from Task $taskId, previously it was $numTasks.") + // Check whether the epoch from the barrier tasks matches current barrierEpoch. logInfo(s"Current barrier epoch for $barrierId is $barrierEpoch.") if (epoch != barrierEpoch) { @@ -126,7 +135,7 @@ private[spark] class BarrierCoordinator( // If this is the first sync message received for a barrier() call, start timer to ensure // we may timeout for the sync. if (requesters.isEmpty) { - timerTask = initTimerTask() + initTimerTask() timer.schedule(timerTask, timeout * 1000L) } // Add the requester to array of RPCCallContexts pending for reply. @@ -145,6 +154,27 @@ private[spark] class BarrierCoordinator( } } + // Send failure to all the blocking barrier sync requests from a stage attempt with proper + // failure message. + private def failAllRequesters( + requesters: ArrayBuffer[RpcCallContext], + message: String): Unit = { + requesters.foreach(_.sendFailure(new SparkException(message))) + } + + // Finish all the blocking barrier sync requests from a stage attempt successfully if we + // have received all the sync requests. + private def maybeFinishAllRequesters( + requesters: ArrayBuffer[RpcCallContext], + numTasks: Int): Boolean = { + if (requesters.size == numTasks) { + requesters.foreach(_.reply(())) + true + } else { + false + } + } + // Cleanup the internal state of a barrier stage attempt. def clear(): Unit = synchronized { // The global sync fails so the stage is expected to retry another attempt, all sync @@ -155,9 +185,7 @@ private[spark] class BarrierCoordinator( } } - /** - * Clean up the [[ContextBarrierState]] that correspond to a stage attempt. - */ + // Clean up the [[ContextBarrierState]] that correspond to a stage attempt. private def cleanupBarrierStage(barrierId: ContextBarrierId): Unit = { val barrierState = states.remove(barrierId) if (barrierState != null) { @@ -165,51 +193,22 @@ private[spark] class BarrierCoordinator( } } - /** - * Send failure to all the blocking barrier sync requests from a stage attempt with proper - * failure message. - */ - private def failAllRequesters( - requesters: ArrayBuffer[RpcCallContext], - message: String): Unit = { - requesters.foreach(_.sendFailure(new SparkException(message))) - } - - /** - * Finish all the blocking barrier sync requests from a stage attempt successfully if we - * have received all the sync requests. - */ - private def maybeFinishAllRequesters( - requesters: ArrayBuffer[RpcCallContext], - numTasks: Int): Boolean = { - if (requesters.size == numTasks) { - requesters.foreach(_.reply(())) - true - } else { - false - } - } - override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RequestToSync(numTasks, stageId, stageAttemptId, taskAttemptId, barrierEpoch) => + case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) => // Get or init the ContextBarrierState correspond to the stage attempt. val barrierId = ContextBarrierId(stageId, stageAttemptId) states.putIfAbsent(barrierId, new ContextBarrierState(barrierId, numTasks)) val barrierState = states.get(barrierId) - // Require the number of tasks is correctly set from the BarrierTaskContext. - require(barrierState.numTasks == numTasks, s"Number of tasks of $barrierId is $numTasks " + - s"from Task $taskAttemptId, previously it was ${barrierState.numTasks}.") - - barrierState.handleRequest(context, barrierEpoch, taskAttemptId) + barrierState.handleRequest(context, request) } - private val stateConsumer = new Consumer[ContextBarrierState] { + private val clearStateConsumer = new Consumer[ContextBarrierState] { override def accept(state: ContextBarrierState) = state.clear() } override def onStop(): Unit = { - states.forEachValue(1, stateConsumer) + states.forEachValue(1, clearStateConsumer) states.clear() } } diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index e78ad0393dfea..343e218325438 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -124,7 +124,7 @@ class BarrierTaskContext( barrierEpoch), // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework. - timeout = new RpcTimeout(31536000 /** = 3600 * 24 * 365 */ seconds, "barrierTimeout")) + timeout = new RpcTimeout(31536000 /* = 3600 * 24 * 365 */ seconds, "barrierTimeout")) barrierEpoch += 1 logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " + "global sync successfully, waited for " + From 53aa316bb10344fdec3ed4378f9386a3400fa8cb Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sat, 4 Aug 2018 22:53:31 +0800 Subject: [PATCH 13/19] update --- .../org/apache/spark/BarrierCoordinator.scala | 20 +++++-------------- .../scheduler/BarrierTaskContextSuite.scala | 4 +++- 2 files changed, 8 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index b5141348af18a..55c73299d355e 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -55,10 +55,7 @@ private[spark] class BarrierCoordinator( val stageInfo = stageCompleted.stageInfo val barrierId = ContextBarrierId(stageInfo.stageId, stageInfo.attemptNumber) // Clear ContextBarrierState from a finished stage attempt. - val barrierState = states.remove(barrierId) - if (barrierState != null) { - barrierState.clear() - } + cleanupBarrierStage(barrierId) } } @@ -98,10 +95,11 @@ private[spark] class BarrierCoordinator( // Init a TimerTask for a barrier() call. private def initTimerTask(): Unit = { timerTask = new TimerTask { - override def run(): Unit = { + override def run(): Unit = synchronized { // Timeout current barrier() call, fail all the sync requests. - failAllRequesters(requesters, "The coordinator didn't get all barrier sync " + - s"requests for barrier epoch $barrierEpoch from $barrierId within ${timeout}s.") + requesters.foreach(_.sendFailure(new SparkException("The coordinator didn't get all " + + s"barrier sync requests for barrier epoch $barrierEpoch from $barrierId within " + + s"$timeout seconds."))) cleanupBarrierStage(barrierId) } } @@ -154,14 +152,6 @@ private[spark] class BarrierCoordinator( } } - // Send failure to all the blocking barrier sync requests from a stage attempt with proper - // failure message. - private def failAllRequesters( - requesters: ArrayBuffer[RpcCallContext], - message: String): Unit = { - requesters.foreach(_.sendFailure(new SparkException(message))) - } - // Finish all the blocking barrier sync requests from a stage attempt successfully if we // have received all the sync requests. private def maybeFinishAllRequesters( diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index 380521d1ce305..eada9db7045bd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -128,7 +128,9 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { val rdd2 = rdd.barrier().mapPartitions { (it, context) => try { if (context.taskAttemptId == 0) { - // Task 0 skip the first barrier() call. + // Due to some non-obvious reason, the code can trigger an Exception and skip the + // following statements within the try ... catch block, including the first barrier() + // call. throw new SparkException("test") } context.barrier() From da52db276e36b5fc7f06b752891f180d9d768645 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sun, 5 Aug 2018 16:42:01 +0800 Subject: [PATCH 14/19] fix test failure --- .../main/scala/org/apache/spark/BarrierCoordinator.scala | 4 +++- .../apache/spark/scheduler/BarrierTaskContextSuite.scala | 6 +++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 55c73299d355e..ffb39a2f935d6 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -99,7 +99,7 @@ private[spark] class BarrierCoordinator( // Timeout current barrier() call, fail all the sync requests. requesters.foreach(_.sendFailure(new SparkException("The coordinator didn't get all " + s"barrier sync requests for barrier epoch $barrierEpoch from $barrierId within " + - s"$timeout seconds."))) + s"$timeout second(s)."))) cleanupBarrierStage(barrierId) } } @@ -200,6 +200,8 @@ private[spark] class BarrierCoordinator( override def onStop(): Unit = { states.forEachValue(1, clearStateConsumer) states.clear() + listenerBus.removeListener(listener) + super.onStop() } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala index eada9db7045bd..5f96d6fb0cdb6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BarrierTaskContextSuite.scala @@ -92,7 +92,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { rdd2.collect() }.getMessage assert(error.contains("The coordinator didn't get all barrier sync requests")) - assert(error.contains("within 1s")) + assert(error.contains("within 1 second(s)")) } test("throw exception if barrier() call doesn't happen on every task") { @@ -114,7 +114,7 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { rdd2.collect() }.getMessage assert(error.contains("The coordinator didn't get all barrier sync requests")) - assert(error.contains("within 1s")) + assert(error.contains("within 1 second(s)")) } test("throw exception if the number of barrier() calls are not the same on every task") { @@ -145,6 +145,6 @@ class BarrierTaskContextSuite extends SparkFunSuite with LocalSparkContext { rdd2.collect() }.getMessage assert(error.contains("The coordinator didn't get all barrier sync requests")) - assert(error.contains("within 1s")) + assert(error.contains("within 1 second(s)")) } } From 8e888b516f4b6a5782ecc1fe8d5f108a1335549b Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sun, 5 Aug 2018 19:28:08 +0800 Subject: [PATCH 15/19] update --- .../main/scala/org/apache/spark/SparkContext.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 78ba0b31fc6bb..ba13567459e1d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1930,6 +1930,12 @@ class SparkContext(config: SparkConf) extends Logging { Utils.tryLogNonFatalError { _executorAllocationManager.foreach(_.stop()) } + if (_dagScheduler != null) { + Utils.tryLogNonFatalError { + _dagScheduler.stop() + } + _dagScheduler = null + } if (_listenerBusStarted) { Utils.tryLogNonFatalError { listenerBus.stop() @@ -1939,12 +1945,6 @@ class SparkContext(config: SparkConf) extends Logging { Utils.tryLogNonFatalError { _eventLogger.foreach(_.stop()) } - if (_dagScheduler != null) { - Utils.tryLogNonFatalError { - _dagScheduler.stop() - } - _dagScheduler = null - } if (env != null && _heartbeatReceiver != null) { Utils.tryLogNonFatalError { env.rpcEnv.stop(_heartbeatReceiver) From a8fa8db6a478009c9be6b9063fbddf30dcdaa316 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Sun, 5 Aug 2018 21:45:54 +0800 Subject: [PATCH 16/19] update --- .../org/apache/spark/BarrierCoordinator.scala | 43 +++++++++++-------- .../spark/internal/config/package.scala | 4 +- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index ffb39a2f935d6..abd2b3c612276 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -28,8 +28,9 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler.{LiveListenerBus, SparkListener, SparkListenerStageCompleted} /** - * Only one barrier() call shall happen on a barrier stage attempt at each time, we can use - * (stageId, stageAttemptId) to identify the stage attempt where the barrier() call is from. + * For each barrier stage attempt, only at most one barrier() call can be active at any time, thus + * we can use (stageId, stageAttemptId) to identify the stage attempt where the barrier() call is + * from. */ private case class ContextBarrierId(stageId: Int, stageAttemptId: Int) { override def toString: String = s"Stage $stageId (Attempt $stageAttemptId)" @@ -39,11 +40,12 @@ private case class ContextBarrierId(stageId: Int, stageAttemptId: Int) { * A coordinator that handles all global sync requests from BarrierTaskContext. Each global sync * request is generated by `BarrierTaskContext.barrier()`, and identified by * stageId + stageAttemptId + barrierEpoch. Reply all the blocking global sync requests upon - * received all the requests for a group of `barrier()` calls. If the coordinator doesn't collect - * enough global sync requests within a configured time, fail all the requests due to timeout. + * all the requests for a group of `barrier()` calls are received. If the coordinator is unable to + * collect enough global sync requests within a configured time, fail all the requests and return + * an Exception with timeout message. */ private[spark] class BarrierCoordinator( - timeout: Int, + timeoutInSecs: Long, listenerBus: LiveListenerBus, override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { @@ -59,8 +61,8 @@ private[spark] class BarrierCoordinator( } } - // Remember all active stage attempts that make barrier() call(s), and the corresponding - // internal state. + // Record all active stage attempts that make barrier() call(s), and the corresponding internal + // state. private val states = new ConcurrentHashMap[ContextBarrierId, ContextBarrierState] override def onStart(): Unit = { @@ -68,9 +70,19 @@ private[spark] class BarrierCoordinator( listenerBus.addToStatusQueue(listener) } + override def onStop(): Unit = { + try { + states.forEachValue(1, clearStateConsumer) + states.clear() + listenerBus.removeListener(listener) + } finally { + super.onStop() + } + } + /** - * Provide current state of a barrier() call, the state is created when a new stage attempt send - * out a barrier() call, and recycled on stage completed. + * Provide the current state of a barrier() call. A state is created when a new stage attempt + * sends out a barrier() call, and recycled on stage completed. * * @param barrierId Identifier of the barrier stage that make a barrier() call. * @param numTasks Number of tasks of the barrier stage, all barrier() calls from the stage shall @@ -99,7 +111,7 @@ private[spark] class BarrierCoordinator( // Timeout current barrier() call, fail all the sync requests. requesters.foreach(_.sendFailure(new SparkException("The coordinator didn't get all " + s"barrier sync requests for barrier epoch $barrierEpoch from $barrierId within " + - s"$timeout second(s)."))) + s"$timeoutInSecs second(s)."))) cleanupBarrierStage(barrierId) } } @@ -134,7 +146,7 @@ private[spark] class BarrierCoordinator( // we may timeout for the sync. if (requesters.isEmpty) { initTimerTask() - timer.schedule(timerTask, timeout * 1000L) + timer.schedule(timerTask, timeoutInSecs * 1000) } // Add the requester to array of RPCCallContexts pending for reply. requesters += requester @@ -175,7 +187,7 @@ private[spark] class BarrierCoordinator( } } - // Clean up the [[ContextBarrierState]] that correspond to a stage attempt. + // Clean up the [[ContextBarrierState]] that correspond to a specific stage attempt. private def cleanupBarrierStage(barrierId: ContextBarrierId): Unit = { val barrierState = states.remove(barrierId) if (barrierState != null) { @@ -196,13 +208,6 @@ private[spark] class BarrierCoordinator( private val clearStateConsumer = new Consumer[ContextBarrierState] { override def accept(state: ContextBarrierState) = state.clear() } - - override def onStop(): Unit = { - states.forEachValue(1, clearStateConsumer) - states.clear() - listenerBus.removeListener(listener) - super.onStop() - } } private[spark] sealed trait BarrierCoordinatorMessage extends Serializable diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 22f98873844ea..5760a897b567d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -574,7 +574,7 @@ package object config { "coordinator didn't receive all the sync messages from barrier tasks within the " + "configed time, throw a SparkException to fail all the tasks. The default value is set " + "to 31536000(3600 * 24 * 365) so the barrier() call shall wait for one year.") - .intConf + .timeConf(TimeUnit.SECONDS) .checkValue(v => v > 0, "The value should be a positive int value.") - .createWithDefault(31536000) + .createWithDefaultString("365d") } From ab49fedb59ff79fc4d04135859a41500f07fa734 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 6 Aug 2018 02:12:18 +0800 Subject: [PATCH 17/19] fix --- core/src/main/scala/org/apache/spark/BarrierCoordinator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index abd2b3c612276..30119d6d3b45d 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -49,7 +49,7 @@ private[spark] class BarrierCoordinator( listenerBus: LiveListenerBus, override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { - private val timer = new Timer("BarrierCoordinator barrier epoch increment timer") + private lazy val timer = new Timer("BarrierCoordinator barrier epoch increment timer") // Listen to StageCompleted event, clear corresponding ContextBarrierState. private val listener = new SparkListener { From 027ca717af9ea38ce2e7a8d2b19a7a0496cf4bb4 Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 6 Aug 2018 17:15:04 +0800 Subject: [PATCH 18/19] update --- .../org/apache/spark/BarrierCoordinator.scala | 13 ++++++++----- .../org/apache/spark/BarrierTaskContext.scala | 3 +-- .../apache/spark/internal/config/package.scala | 2 +- .../spark/scheduler/TaskSchedulerImpl.scala | 16 ++++++++++------ 4 files changed, 20 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 30119d6d3b45d..3bc4576bf1a76 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -19,7 +19,7 @@ package org.apache.spark import java.util.{Timer, TimerTask} import java.util.concurrent.ConcurrentHashMap -import java.util.function.Consumer +import java.util.function.{Consumer, Function} import scala.collection.mutable.ArrayBuffer @@ -93,8 +93,8 @@ private[spark] class BarrierCoordinator( val numTasks: Int) { // There may be multiple barrier() calls from a barrier stage attempt, `barrierEpoch` is used - // to identify each barrier() call. It shall get increased when a barrier() call succeed, or - // reset when a barrier() call fail due to timeout. + // to identify each barrier() call. It shall get increased when a barrier() call succeeds, or + // reset when a barrier() call fails due to timeout. private var barrierEpoch: Int = 0 // An array of RPCCallContexts for barrier tasks that are waiting for reply of a barrier() @@ -199,7 +199,10 @@ private[spark] class BarrierCoordinator( case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _) => // Get or init the ContextBarrierState correspond to the stage attempt. val barrierId = ContextBarrierId(stageId, stageAttemptId) - states.putIfAbsent(barrierId, new ContextBarrierState(barrierId, numTasks)) + states.computeIfAbsent(barrierId, new Function[ContextBarrierId, ContextBarrierState] { + override def apply(key: ContextBarrierId): ContextBarrierState = + new ContextBarrierState(key, numTasks) + }) val barrierState = states.get(barrierId) barrierState.handleRequest(context, request) @@ -220,7 +223,7 @@ private[spark] sealed trait BarrierCoordinatorMessage extends Serializable * @param stageId ID of current stage * @param stageAttemptId ID of current stage attempt * @param taskAttemptId Unique ID of current task - * @param barrierEpoch ID of the `barrier()` call, a task may consists multiple `barrier()` calls. + * @param barrierEpoch ID of the `barrier()` call, a task may consist multiple `barrier()` calls. */ private[spark] case class RequestToSync( numTasks: Int, diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 343e218325438..8e2b15599b674 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -101,10 +101,9 @@ class BarrierTaskContext( @Experimental @Since("2.4.0") def barrier(): Unit = { - val callSite = Utils.getCallSite() logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) has entered " + s"the global sync, current barrier epoch is $barrierEpoch.") - logTrace(s"Current callSite: $callSite") + logTrace("Current callSite: " + Utils.getCallSite()) val startTime = System.currentTimeMillis() val timerTask = new TimerTask { diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 5760a897b567d..eb08628ce1112 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -575,6 +575,6 @@ package object config { "configed time, throw a SparkException to fail all the tasks. The default value is set " + "to 31536000(3600 * 24 * 365) so the barrier() call shall wait for one year.") .timeConf(TimeUnit.SECONDS) - .checkValue(v => v > 0, "The value should be a positive int value.") + .checkValue(v => v > 0, "The value should be a positive time value.") .createWithDefaultString("365d") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 89618e11b3da6..8992d7e2284a4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -141,11 +141,15 @@ private[spark] class TaskSchedulerImpl( private lazy val barrierSyncTimeout = conf.get(config.BARRIER_SYNC_TIMEOUT) - private[scheduler] lazy val barrierCoordinator: RpcEndpoint = { - val coordinator = new BarrierCoordinator(barrierSyncTimeout, sc.listenerBus, sc.env.rpcEnv) - sc.env.rpcEnv.setupEndpoint("barrierSync", coordinator) - logInfo("Registered BarrierCoordinator endpoint") - coordinator + private[scheduler] var barrierCoordinator: RpcEndpoint = null + + private def maybeInitBarrierCoordinator(): Unit = { + if (barrierCoordinator == null) { + barrierCoordinator = new BarrierCoordinator(barrierSyncTimeout, sc.listenerBus, + sc.env.rpcEnv) + sc.env.rpcEnv.setupEndpoint("barrierSync", barrierCoordinator) + logInfo("Registered BarrierCoordinator endpoint") + } } override def setDAGScheduler(dagScheduler: DAGScheduler) { @@ -424,7 +428,7 @@ private[spark] class TaskSchedulerImpl( "been blacklisted or cannot fulfill task locality requirements.") // materialize the barrier coordinator. - barrierCoordinator + maybeInitBarrierCoordinator() // Update the taskInfos into all the barrier task properties. val addressesStr = addressesWithDescs From 1f71e6583f9f9f270d07323f15c731717e13518d Mon Sep 17 00:00:00 2001 From: Xingbo Jiang Date: Mon, 6 Aug 2018 21:31:59 +0800 Subject: [PATCH 19/19] add comment --- core/src/main/scala/org/apache/spark/BarrierCoordinator.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala index 3bc4576bf1a76..5e546c694e8d9 100644 --- a/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/BarrierCoordinator.scala @@ -49,6 +49,8 @@ private[spark] class BarrierCoordinator( listenerBus: LiveListenerBus, override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + // TODO SPARK-25030 Create a Timer() in the mainClass submitted to SparkSubmit makes it unable to + // fetch result, we shall fix the issue. private lazy val timer = new Timer("BarrierCoordinator barrier epoch increment timer") // Listen to StageCompleted event, clear corresponding ContextBarrierState.