From 3316d5e927cf6977ddf5826cd0c0a1f65448e5dd Mon Sep 17 00:00:00 2001
From: Marcelo Vanzin
Date: Tue, 8 Oct 2019 13:07:23 -0700
Subject: [PATCH 1/2] [SPARK-29398][core] Support dedicated thread pools for RPC endpoints.

The current RPC backend in Spark supports single- and multi-threaded message
delivery to endpoints, but they all share the same underlying thread pool. So
an RPC endpoint that blocks a dispatcher thread can negatively affect other
endpoints. This can be more pronounced in deployments where the number of RPC
dispatch threads is limited by configuration and / or the running environment.
And exposing the RPC layer to other code (for example with something like
SPARK-29396) could make it easy to affect normal Spark operation with a badly
written RPC handler.

This change adds a new RPC endpoint type that tells the RPC env to create
dedicated dispatch threads, so that those effects are minimised. Other
endpoints will still need CPU to process their messages, but they won't be
able to actively block the dispatch threads of these isolated endpoints.

As part of the change, I've made the most important Spark endpoints (the
driver, executor and block manager endpoints) isolated from others. This
means a couple of extra threads are created on the driver and executor for
these endpoints.

Tested with existing unit tests, which hammer the RPC system extensively, and
also by running applications on a cluster (with a prototype of SPARK-29396).
---
 .../CoarseGrainedExecutorBackend.scala        |   2 +-
 .../org/apache/spark/rpc/RpcEndpoint.scala    |  10 +
 .../apache/spark/rpc/netty/Dispatcher.scala   | 130 ++++--------
 .../org/apache/spark/rpc/netty/Inbox.scala    |   6 +-
 .../apache/spark/rpc/netty/MessageLoop.scala  | 194 ++++++++++++++++++
 .../CoarseGrainedSchedulerBackend.scala       |   2 +-
 .../storage/BlockManagerMasterEndpoint.scala  |   4 +-
 .../storage/BlockManagerSlaveEndpoint.scala   |   4 +-
 .../org/apache/spark/rpc/RpcEnvSuite.scala    |  35 +++-
 .../apache/spark/rpc/netty/InboxSuite.scala   |  23 +--
 10 files changed, 290 insertions(+), 120 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala

diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index fbf2dc73ea075..b4bca1e9401e2 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -51,7 +51,7 @@ private[spark] class CoarseGrainedExecutorBackend(
     userClassPath: Seq[URL],
     env: SparkEnv,
     resourcesFileOpt: Option[String])
-  extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
+  extends IsolatedRpcEndpoint with ExecutorBackend with Logging {
 
   private implicit val formats = DefaultFormats
 
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
index 97eed540b8f59..c7f56f7749a44 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala
@@ -146,3 +146,13 @@ private[spark] trait RpcEndpoint {
  * [[ThreadSafeRpcEndpoint]] for different messages.
  */
 private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint
+
+/**
+ * An endpoint that uses a dedicated thread pool for delivering messages.
+ */
+private[spark] trait IsolatedRpcEndpoint extends RpcEndpoint {
+
+  /** How many threads to use for delivering messages. 
By default, use a single thread. */ + def threadCount(): Int = 1 + +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 2f923d7902b05..27c943da88105 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -17,20 +17,16 @@ package org.apache.spark.rpc.netty -import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, CountDownLatch} import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.concurrent.Promise -import scala.util.control.NonFatal -import org.apache.spark.{SparkConf, SparkContext, SparkException} +import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.EXECUTOR_ID -import org.apache.spark.internal.config.Network.RPC_NETTY_DISPATCHER_NUM_THREADS import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc._ -import org.apache.spark.util.ThreadUtils /** * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s). @@ -40,20 +36,23 @@ import org.apache.spark.util.ThreadUtils */ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging { - private class EndpointData( - val name: String, - val endpoint: RpcEndpoint, - val ref: NettyRpcEndpointRef) { - val inbox = new Inbox(ref, endpoint) - } - - private val endpoints: ConcurrentMap[String, EndpointData] = - new ConcurrentHashMap[String, EndpointData] + private val endpoints: ConcurrentMap[String, MessageLoop] = + new ConcurrentHashMap[String, MessageLoop] private val endpointRefs: ConcurrentMap[RpcEndpoint, RpcEndpointRef] = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef] - // Track the receivers whose inboxes may contain messages. - private val receivers = new LinkedBlockingQueue[EndpointData] + private val shutdownLatch = new CountDownLatch(1) + private lazy val sharedLoop = new SharedMessageLoop(nettyEnv.conf, this, numUsableCores) + + private def getMessageLoop(name: String, endpoint: RpcEndpoint): MessageLoop = { + endpoint match { + case e: IsolatedRpcEndpoint => + new DedicatedMessageLoop(name, e, this) + case _ => + sharedLoop.register(name, endpoint) + sharedLoop + } + } /** * True if the dispatcher has been stopped. 
Once stopped, all messages posted will be bounced @@ -69,13 +68,11 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte if (stopped) { throw new IllegalStateException("RpcEnv has been stopped") } - if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) { + if (endpoints.putIfAbsent(name, getMessageLoop(name, endpoint)) != null) { throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name") } - val data = endpoints.get(name) - endpointRefs.put(data.endpoint, data.ref) - receivers.offer(data) // for the OnStart message } + endpointRefs.put(endpoint, endpointRef) endpointRef } @@ -85,10 +82,9 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte // Should be idempotent private def unregisterRpcEndpoint(name: String): Unit = { - val data = endpoints.remove(name) - if (data != null) { - data.inbox.stop() - receivers.offer(data) // for the OnStop message + val loop = endpoints.remove(name) + if (loop != null) { + loop.unregister(name) } // Don't clean `endpointRefs` here because it's possible that some messages are being processed // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via @@ -155,14 +151,13 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte message: InboxMessage, callbackIfStopped: (Exception) => Unit): Unit = { val error = synchronized { - val data = endpoints.get(endpointName) + val loop = endpoints.get(endpointName) if (stopped) { Some(new RpcEnvStoppedException()) - } else if (data == null) { + } else if (loop == null) { Some(new SparkException(s"Could not find $endpointName.")) } else { - data.inbox.post(message) - receivers.offer(data) + loop.post(endpointName, message) None } } @@ -177,15 +172,23 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte } stopped = true } - // Stop all endpoints. This will queue all endpoints for processing by the message loops. - endpoints.keySet().asScala.foreach(unregisterRpcEndpoint) - // Enqueue a message that tells the message loops to stop. - receivers.offer(PoisonPill) - threadpool.shutdown() + var stopSharedLoop = false + endpoints.asScala.foreach { case (name, loop) => + unregisterRpcEndpoint(name) + if (!loop.isInstanceOf[SharedMessageLoop]) { + loop.stop() + } else { + stopSharedLoop = true + } + } + if (stopSharedLoop) { + sharedLoop.stop() + } + shutdownLatch.countDown() } def awaitTermination(): Unit = { - threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS) + shutdownLatch.await() } /** @@ -194,61 +197,4 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte def verify(name: String): Boolean = { endpoints.containsKey(name) } - - private def getNumOfThreads(conf: SparkConf): Int = { - val availableCores = - if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors() - - val modNumThreads = conf.get(RPC_NETTY_DISPATCHER_NUM_THREADS) - .getOrElse(math.max(2, availableCores)) - - conf.get(EXECUTOR_ID).map { id => - val role = if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor" - conf.getInt(s"spark.$role.rpc.netty.dispatcher.numThreads", modNumThreads) - }.getOrElse(modNumThreads) - } - - /** Thread pool used for dispatching messages. 
*/ - private val threadpool: ThreadPoolExecutor = { - val numThreads = getNumOfThreads(nettyEnv.conf) - val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop") - for (i <- 0 until numThreads) { - pool.execute(new MessageLoop) - } - pool - } - - /** Message loop used for dispatching messages. */ - private class MessageLoop extends Runnable { - override def run(): Unit = { - try { - while (true) { - try { - val data = receivers.take() - if (data == PoisonPill) { - // Put PoisonPill back so that other MessageLoops can see it. - receivers.offer(PoisonPill) - return - } - data.inbox.process(Dispatcher.this) - } catch { - case NonFatal(e) => logError(e.getMessage, e) - } - } - } catch { - case _: InterruptedException => // exit - case t: Throwable => - try { - // Re-submit a MessageLoop so that Dispatcher will still work if - // UncaughtExceptionHandler decides to not kill JVM. - threadpool.execute(new MessageLoop) - } finally { - throw t - } - } - } - } - - /** A poison endpoint that indicates MessageLoop should exit its message loop. */ - private val PoisonPill = new EndpointData(null, null, null) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index 44d2622a42f58..2ed03f7430c32 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -54,9 +54,7 @@ private[netty] case class RemoteProcessConnectionError(cause: Throwable, remoteA /** * An inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely. */ -private[netty] class Inbox( - val endpointRef: NettyRpcEndpointRef, - val endpoint: RpcEndpoint) +private[netty] class Inbox(val endpointName: String, val endpoint: RpcEndpoint) extends Logging { inbox => // Give this an alias so we can use it more clearly in closures. @@ -195,7 +193,7 @@ private[netty] class Inbox( * Exposed for testing. */ protected def onDrop(message: InboxMessage): Unit = { - logWarning(s"Drop $message because $endpointRef is stopped") + logWarning(s"Drop $message because endpoint $endpointName is stopped") } /** diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala b/core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala new file mode 100644 index 0000000000000..c985c72f2adce --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rpc.netty + +import java.util.concurrent._ + +import scala.util.control.NonFatal + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.EXECUTOR_ID +import org.apache.spark.internal.config.Network._ +import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcEndpoint} +import org.apache.spark.util.ThreadUtils + +/** + * A message loop used by [[Dispatcher]] to deliver messages to endpoints. + */ +private sealed abstract class MessageLoop(dispatcher: Dispatcher) extends Logging { + + // List of inboxes with pending messages, to be processed by the message loop. + private val active = new LinkedBlockingQueue[Inbox]() + + // Message loop task; should be run in all threads of the message loop's pool. + protected val receiveLoopRunnable = new Runnable() { + override def run(): Unit = receiveLoop() + } + + protected val threadpool: ExecutorService + + private var stopped = false + + def post(endpointName: String, message: InboxMessage): Unit + + def unregister(name: String): Unit + + def stop(): Unit = { + synchronized { + if (!stopped) { + setActive(MessageLoop.PoisonPill) + threadpool.shutdown() + stopped = true + } + } + threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS) + } + + protected final def setActive(inbox: Inbox): Unit = active.offer(inbox) + + private def receiveLoop(): Unit = { + try { + while (true) { + try { + val inbox = active.take() + if (inbox == MessageLoop.PoisonPill) { + // Put PoisonPill back so that other threads can see it. + setActive(MessageLoop.PoisonPill) + return + } + inbox.process(dispatcher) + } catch { + case NonFatal(e) => logError(e.getMessage, e) + } + } + } catch { + case _: InterruptedException => // exit + case t: Throwable => + try { + // Re-submit a receive task so that message delivery will still work if + // UncaughtExceptionHandler decides to not kill JVM. + threadpool.execute(receiveLoopRunnable) + } finally { + throw t + } + } + } +} + +private object MessageLoop { + /** A poison inbox that indicates the message loop should stop processing messages. */ + val PoisonPill = new Inbox(null, null) +} + +/** + * A message loop that serves multiple RPC endpoints, using a shared thread pool. + */ +private class SharedMessageLoop( + conf: SparkConf, + dispatcher: Dispatcher, + numUsableCores: Int) + extends MessageLoop(dispatcher) { + + private val endpoints = new ConcurrentHashMap[String, Inbox]() + + private def getNumOfThreads(conf: SparkConf): Int = { + val availableCores = + if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors() + + val modNumThreads = conf.get(RPC_NETTY_DISPATCHER_NUM_THREADS) + .getOrElse(math.max(2, availableCores)) + + conf.get(EXECUTOR_ID).map { id => + val role = if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor" + conf.getInt(s"spark.$role.rpc.netty.dispatcher.numThreads", modNumThreads) + }.getOrElse(modNumThreads) + } + + /** Thread pool used for dispatching messages. 
*/ + override protected val threadpool: ThreadPoolExecutor = { + val numThreads = getNumOfThreads(conf) + val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop") + for (i <- 0 until numThreads) { + pool.execute(receiveLoopRunnable) + } + pool + } + + override def post(endpointName: String, message: InboxMessage): Unit = { + val inbox = endpoints.get(endpointName) + inbox.post(message) + setActive(inbox) + } + + override def unregister(name: String): Unit = { + val inbox = endpoints.remove(name) + if (inbox != null) { + inbox.stop() + // Mark active to handle the OnStop message. + setActive(inbox) + } + } + + def register(name: String, endpoint: RpcEndpoint): Unit = { + val inbox = new Inbox(name, endpoint) + endpoints.put(name, inbox) + // Mark active to handle the OnStart message. + setActive(inbox) + } +} + +/** + * A message loop that is dedicated to a single RPC endpoint. + */ +private class DedicatedMessageLoop( + name: String, + endpoint: IsolatedRpcEndpoint, + dispatcher: Dispatcher) + extends MessageLoop(dispatcher) { + + private val inbox = new Inbox(name, endpoint) + + override protected val threadpool = if (endpoint.threadCount() > 1) { + ThreadUtils.newDaemonCachedThreadPool(s"dispatcher-$name", endpoint.threadCount()) + } else { + ThreadUtils.newDaemonSingleThreadExecutor(s"dispatcher-$name") + } + + (1 to endpoint.threadCount()).foreach { _ => + threadpool.submit(receiveLoopRunnable) + } + + // Mark active to handle the OnStart message. + setActive(inbox) + + override def post(endpointName: String, message: InboxMessage): Unit = { + require(endpointName == name) + inbox.post(message) + setActive(inbox) + } + + override def unregister(endpointName: String): Unit = synchronized { + require(endpointName == name) + inbox.stop() + // Mark active to handle the OnStop message. 
+ setActive(inbox) + setActive(MessageLoop.PoisonPill) + threadpool.shutdown() + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 4958389ae4257..6e990d1335897 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -111,7 +111,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp private val reviveThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread") - class DriverEndpoint extends ThreadSafeRpcEndpoint with Logging { + class DriverEndpoint extends IsolatedRpcEndpoint with Logging { override val rpcEnv: RpcEnv = CoarseGrainedSchedulerBackend.this.rpcEnv diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index faf6f713c838f..02d0e1a834909 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -30,7 +30,7 @@ import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.{config, Logging} import org.apache.spark.network.shuffle.ExternalBlockStoreClient -import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} @@ -46,7 +46,7 @@ class BlockManagerMasterEndpoint( conf: SparkConf, listenerBus: LiveListenerBus, externalBlockStoreClient: Option[ExternalBlockStoreClient]) - extends ThreadSafeRpcEndpoint with Logging { + extends IsolatedRpcEndpoint with Logging { // Mapping from block manager id to the block manager's information. 
private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index f90595ab924b4..29e21142ce449 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -21,7 +21,7 @@ import scala.concurrent.{ExecutionContext, Future} import org.apache.spark.{MapOutputTracker, SparkEnv} import org.apache.spark.internal.Logging -import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEnv} import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.{ThreadUtils, Utils} @@ -34,7 +34,7 @@ class BlockManagerSlaveEndpoint( override val rpcEnv: RpcEnv, blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends ThreadSafeRpcEndpoint with Logging { + extends IsolatedRpcEndpoint with Logging { private val asyncThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool", 100) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 5929fbf85a1f4..c10f2c244e133 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -36,7 +36,6 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.config._ -import org.apache.spark.internal.config.Network import org.apache.spark.util.{ThreadUtils, Utils} /** @@ -954,6 +953,40 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { verify(endpoint, never()).onDisconnected(any()) verify(endpoint, never()).onNetworkError(any(), any()) } + + test("isolated endpoints") { + val latch = new CountDownLatch(1) + val singleThreadedEnv = createRpcEnv( + new SparkConf().set(Network.RPC_NETTY_DISPATCHER_NUM_THREADS, 1), "singleThread", 0) + try { + val blockingEndpoint = singleThreadedEnv.setupEndpoint("blocking", new IsolatedRpcEndpoint { + override val rpcEnv: RpcEnv = singleThreadedEnv + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case m => + latch.await() + context.reply(m) + } + }) + + val nonBlockingEndpoint = singleThreadedEnv.setupEndpoint("non-blocking", new RpcEndpoint { + override val rpcEnv: RpcEnv = singleThreadedEnv + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case m => context.reply(m) + } + }) + + val to = new RpcTimeout(5.seconds, "test-timeout") + val blockingFuture = blockingEndpoint.ask[String]("hi", to) + assert(nonBlockingEndpoint.askSync[String]("hello", to) === "hello") + latch.countDown() + assert(ThreadUtils.awaitResult(blockingFuture, 5.seconds) === "hi") + } finally { + latch.countDown() + singleThreadedEnv.shutdown() + } + } } class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala index e5539566e4b6f..c74c728b3e3f3 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -29,12 +29,9 @@ class InboxSuite extends SparkFunSuite { test("post") { val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) - when(endpointRef.name).thenReturn("hello") - val dispatcher = mock(classOf[Dispatcher]) - val inbox = new Inbox(endpointRef, endpoint) + val inbox = new Inbox("name", endpoint) val message = OneWayMessage(null, "hi") inbox.post(message) inbox.process(dispatcher) @@ -51,10 +48,9 @@ class InboxSuite extends SparkFunSuite { test("post: with reply") { val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) val dispatcher = mock(classOf[Dispatcher]) - val inbox = new Inbox(endpointRef, endpoint) + val inbox = new Inbox("name", endpoint) val message = RpcMessage(null, "hi", null) inbox.post(message) inbox.process(dispatcher) @@ -65,13 +61,10 @@ class InboxSuite extends SparkFunSuite { test("post: multiple threads") { val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) - when(endpointRef.name).thenReturn("hello") - val dispatcher = mock(classOf[Dispatcher]) val numDroppedMessages = new AtomicInteger(0) - val inbox = new Inbox(endpointRef, endpoint) { + val inbox = new Inbox("name", endpoint) { override def onDrop(message: InboxMessage): Unit = { numDroppedMessages.incrementAndGet() } @@ -107,12 +100,10 @@ class InboxSuite extends SparkFunSuite { test("post: Associated") { val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) val dispatcher = mock(classOf[Dispatcher]) - val remoteAddress = RpcAddress("localhost", 11111) - val inbox = new Inbox(endpointRef, endpoint) + val inbox = new Inbox("name", endpoint) inbox.post(RemoteProcessConnected(remoteAddress)) inbox.process(dispatcher) @@ -121,12 +112,11 @@ class InboxSuite extends SparkFunSuite { test("post: Disassociated") { val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) val dispatcher = mock(classOf[Dispatcher]) val remoteAddress = RpcAddress("localhost", 11111) - val inbox = new Inbox(endpointRef, endpoint) + val inbox = new Inbox("name", endpoint) inbox.post(RemoteProcessDisconnected(remoteAddress)) inbox.process(dispatcher) @@ -135,13 +125,12 @@ class InboxSuite extends SparkFunSuite { test("post: AssociationError") { val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) val dispatcher = mock(classOf[Dispatcher]) val remoteAddress = RpcAddress("localhost", 11111) val cause = new RuntimeException("Oops") - val inbox = new Inbox(endpointRef, endpoint) + val inbox = new Inbox("name", endpoint) inbox.post(RemoteProcessConnectionError(cause, remoteAddress)) inbox.process(dispatcher) From b674b4c251ecc897e50504bf999257dc0e6d8354 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 10 Oct 2019 15:10:11 -0700 Subject: [PATCH 2/2] Add comment about thread-safety. 
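
Endpoints that override threadCount() to return more than one thread will
have messages delivered to them concurrently, so they must guard their own
mutable state. As an illustration (the endpoint and its state are
hypothetical, not part of this change), such an endpoint could look like:

    private[spark] class CounterEndpoint(override val rpcEnv: RpcEnv)
      extends IsolatedRpcEndpoint {

      // Up to four dispatch threads may invoke receive() at the same time.
      override def threadCount(): Int = 4

      // Use an atomic (or otherwise synchronized) type: messages can arrive
      // on several threads at once, and possibly out of order.
      private val count = new java.util.concurrent.atomic.AtomicLong()

      override def receive: PartialFunction[Any, Unit] = {
        case "inc" => count.incrementAndGet()
      }
    }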
--- .../src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index c7f56f7749a44..4728759e7fb0d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -152,7 +152,13 @@ private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint */ private[spark] trait IsolatedRpcEndpoint extends RpcEndpoint { - /** How many threads to use for delivering messages. By default, use a single thread. */ + /** + * How many threads to use for delivering messages. By default, use a single thread. + * + * Note that requesting more than one thread means that the endpoint should be able to handle + * messages arriving from many threads at once, and all the things that entails (including + * messages being delivered to the endpoint out of order). + */ def threadCount(): Int = 1 }
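
For reference, a minimal sketch of how code opts into the new dedicated
dispatch loop (the endpoint and message names here are hypothetical, not part
of this patch). Extending IsolatedRpcEndpoint is all that is required; the
default threadCount() of 1 keeps the single-threaded delivery guarantee of
ThreadSafeRpcEndpoint while still isolating the endpoint from the shared pool:

    // Assumes the usual imports from org.apache.spark.rpc.
    private[spark] class HeartbeatEndpoint(override val rpcEnv: RpcEnv)
      extends IsolatedRpcEndpoint {

      // Handled on a dedicated "dispatcher-heartbeat" thread, so a slow
      // handler here cannot stall the shared dispatcher pool, and other
      // endpoints cannot block this one.
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case "ping" => context.reply("pong")
      }
    }

    // Registration is unchanged; the Dispatcher picks a message loop based
    // on the endpoint type.
    val ref = rpcEnv.setupEndpoint("heartbeat", new HeartbeatEndpoint(rpcEnv))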