From 64f88967698a1816557dfe1816d85a029cc55ad6 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Thu, 5 Nov 2015 02:58:56 +0900 Subject: [PATCH 1/6] Bypass unnecessary network access if block managers share an identical host --- .../spark/network/BlockDataManager.scala | 9 +- .../shuffle/FileShuffleBlockResolver.scala | 9 +- .../shuffle/IndexShuffleBlockResolver.scala | 39 +++++++-- .../spark/shuffle/ShuffleBlockResolver.scala | 5 +- .../apache/spark/storage/BlockManager.scala | 17 +++- .../apache/spark/storage/BlockManagerId.scala | 4 + .../spark/storage/BlockManagerMaster.scala | 25 ++++-- .../storage/BlockManagerMasterEndpoint.scala | 38 ++++++--- .../spark/storage/BlockManagerMessages.scala | 3 + .../spark/storage/DiskBlockManager.scala | 55 ++++++++++-- .../storage/ShuffleBlockFetcherIterator.scala | 58 ++++++++----- .../org/apache/spark/DistributedSuite.scala | 12 +++ .../BlockStoreShuffleReaderSuite.scala | 3 +- .../spark/storage/DiskBlockManagerSuite.scala | 78 ++++++++++++++--- .../ShuffleBlockFetcherIteratorSuite.scala | 84 +++++++++++++++++-- docs/configuration.md | 8 ++ 16 files changed, 365 insertions(+), 82 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 1745d52c81923..94e33d156dc4b 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -18,7 +18,7 @@ package org.apache.spark.network import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.storage.{ShuffleBlockId, BlockManagerId, BlockId, StorageLevel} private[spark] trait BlockDataManager { @@ -29,6 +29,13 @@ trait BlockDataManager { */ def getBlockData(blockId: BlockId): ManagedBuffer + /** + * Interface to get the shuffle block data that block manager with given blockManagerId + * holds in a local host. Throws an exception if the block cannot be found or + * cannot be read successfully. + */ + def getShuffleBlockData(blockId: ShuffleBlockId, blockManagerId: BlockManagerId): ManagedBuffer + /** * Put the block locally, using the given storage level. 
*/ diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index cc5f933393adf..4528e982c4cbd 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -98,8 +98,13 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) } } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { - val file = blockManager.diskBlockManager.getFile(blockId) + override def getBlockData(blockId: ShuffleBlockId, blockManagerId: BlockManagerId) + : ManagedBuffer = { + val file = if (blockManager.blockManagerId != blockManagerId) { + blockManager.diskBlockManager.getShuffleFileBypassNetworkAccess(blockId, blockManagerId) + } else { + blockManager.diskBlockManager.getFile(blockId) + } new FileSegmentManagedBuffer(transportConf, file, 0, file.length) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index fadb8fe7ed0ab..d53a2666db0ce 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -49,12 +49,35 @@ private[spark] class IndexShuffleBlockResolver( private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") - def getDataFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) + private def getDataFile( + shuffleId: Int, + mapId: Int, + blockManagerId: BlockManagerId = blockManager.blockManagerId) + : File = { + if (blockManager.blockManagerId != blockManagerId) { + blockManager.diskBlockManager.getShuffleFileBypassNetworkAccess( + ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID), blockManagerId) + } else { + blockManager.diskBlockManager.getFile( + ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) + } } - private def getIndexFile(shuffleId: Int, mapId: Int): File = { - blockManager.diskBlockManager.getFile(ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) + def getDataFile(shuffleId: Int, mapId: Int): File = + getDataFile(shuffleId, mapId, blockManager.blockManagerId) + + private def getIndexFile( + shuffleId: Int, + mapId: Int, + blockManagerId: BlockManagerId = blockManager.blockManagerId) + : File = { + if (blockManager.blockManagerId != blockManagerId) { + blockManager.diskBlockManager.getShuffleFileBypassNetworkAccess( + ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID), blockManagerId) + } else { + blockManager.diskBlockManager.getFile( + ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) + } } /** @@ -183,10 +206,12 @@ private[spark] class IndexShuffleBlockResolver( } } - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { + override def getBlockData( + blockId: ShuffleBlockId, blockManagerId: BlockManagerId) + : ManagedBuffer = { // The block is actually going to be a range of a single map output file for this map, so // find out the consolidated file, then the offset within that from our index - val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId) + val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId, blockManagerId) val in = new DataInputStream(new FileInputStream(indexFile)) try { @@ -195,7 +220,7 @@ private[spark] class IndexShuffleBlockResolver( val nextOffset 
= in.readLong() new FileSegmentManagedBuffer( transportConf, - getDataFile(blockId.shuffleId, blockId.mapId), + getDataFile(blockId.shuffleId, blockId.mapId, blockManagerId), offset, nextOffset - offset) } finally { diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index 4342b0d598b16..907ef68ecf889 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -17,9 +17,8 @@ package org.apache.spark.shuffle -import java.nio.ByteBuffer import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} private[spark] /** @@ -35,7 +34,7 @@ trait ShuffleBlockResolver { * Retrieve the data for the specified block. If the data for that block is not available, * throws an unspecified exception. */ - def getBlockData(blockId: ShuffleBlockId): ManagedBuffer + def getBlockData(blockId: ShuffleBlockId, blockManagerId: BlockManagerId): ManagedBuffer def stop(): Unit } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ab0007fb78993..bd6783b0d9bea 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -187,7 +187,8 @@ private[spark] class BlockManager( blockManagerId } - master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) + master.registerBlockManager( + blockManagerId, maxMemory, diskBlockManager.getLocalDirsPath(), slaveEndpoint) // Register Executors' configuration with the local shuffle service, if one should exist. if (externalShuffleServiceEnabled && !blockManagerId.isDriver) { @@ -250,7 +251,8 @@ private[spark] class BlockManager( def reregister(): Unit = { // TODO: We might need to rate limit re-registering. logInfo("BlockManager re-registering with master") - master.registerBlockManager(blockManagerId, maxMemory, slaveEndpoint) + master.registerBlockManager( + blockManagerId, maxMemory, diskBlockManager.getLocalDirsPath(), slaveEndpoint) reportAllBlocks() } @@ -288,7 +290,7 @@ private[spark] class BlockManager( */ override def getBlockData(blockId: BlockId): ManagedBuffer = { if (blockId.isShuffle) { - shuffleManager.shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) + getShuffleBlockData(blockId.asInstanceOf[ShuffleBlockId], blockManagerId) } else { val blockBytesOpt = doGetLocal(blockId, asBlockResult = false) .asInstanceOf[Option[ByteBuffer]] @@ -301,6 +303,12 @@ private[spark] class BlockManager( } } + override def getShuffleBlockData( + blockId: ShuffleBlockId, blockManagerId: BlockManagerId) + : ManagedBuffer = { + shuffleManager.shuffleBlockResolver.getBlockData(blockId, blockManagerId) + } + /** * Put the block locally, using the given storage level. */ @@ -432,7 +440,8 @@ private[spark] class BlockManager( // TODO: This should gracefully handle case where local block is not available. Currently // downstream code will throw an exception. 
      Option(
-        shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId]).nioByteBuffer())
+        shuffleBlockResolver.getBlockData(blockId.asInstanceOf[ShuffleBlockId], blockManagerId)
+          .nioByteBuffer())
     } else {
       doGetLocal(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
     }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index 69ac37511e730..caf50bf0ff2cb 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -65,6 +65,10 @@ class BlockManagerId private (
     executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER
   }
 
+  def shareHost(other: BlockManagerId): Boolean = {
+    host == other.host
+  }
+
   override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
     out.writeUTF(executorId_)
     out.writeUTF(host_)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index f45bff34d4dbc..43f07e6f563c4 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -17,14 +17,14 @@
 
 package org.apache.spark.storage
 
-import scala.collection.Iterable
-import scala.collection.generic.CanBuildFrom
-import scala.concurrent.{Await, Future}
-
 import org.apache.spark.rpc.RpcEndpointRef
-import org.apache.spark.{Logging, SparkConf, SparkException}
 import org.apache.spark.storage.BlockManagerMessages._
-import org.apache.spark.util.{ThreadUtils, RpcUtils}
+import org.apache.spark.util.{RpcUtils, ThreadUtils}
+import org.apache.spark.{Logging, SparkConf, SparkException}
+
+import scala.collection.Iterable
+import scala.collection.generic.CanBuildFrom
+import scala.concurrent.Future
 
 private[spark]
 class BlockManagerMaster(
@@ -43,9 +43,12 @@ class BlockManagerMaster(
 
   /** Register the BlockManager's id with the driver. */
   def registerBlockManager(
-      blockManagerId: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef): Unit = {
+      blockManagerId: BlockManagerId,
+      maxMemSize: Long,
+      localDirsPath: Array[String],
+      slaveEndpoint: RpcEndpointRef): Unit = {
     logInfo("Trying to register BlockManager")
-    tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint))
+    tell(RegisterBlockManager(blockManagerId, maxMemSize, localDirsPath, slaveEndpoint))
     logInfo("Registered BlockManager")
   }
 
@@ -74,6 +77,12 @@ class BlockManagerMaster(
       GetLocationsMultipleBlockIds(blockIds))
   }
 
+  /** Return the local dirs of other block managers on the same host as the given blockManagerId. */
+  def getLocalDirsPath(blockManagerId: BlockManagerId): Map[BlockManagerId, Array[String]] = {
+    driverEndpoint.askWithRetry[Map[BlockManagerId, Array[String]]](
+      GetLocalDirsPath(blockManagerId))
+  }
+
   /**
    * Check if block manager master has a block. Note that this can be used to check for only
   * those blocks that are reported to block manager master.
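The registration and lookup changes above are the discovery half of the optimization: each BlockManager now reports its disk directories when it registers, and any executor can ask the driver where co-located executors keep their shuffle files. A minimal sketch of the intended call pattern, assuming an already-initialized executor-side `blockManager: BlockManager` (the snippet is illustrative and not part of the patch):

    // Ask the driver for the local dirs of every *other* block manager
    // registered on this executor's host, keyed by BlockManagerId.
    val peers: Map[BlockManagerId, Array[String]] =
      blockManager.master.getLocalDirsPath(blockManager.blockManagerId)

    peers.foreach { case (peerId, dirs) =>
      println(s"$peerId keeps shuffle files under: ${dirs.mkString(", ")}")
    }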
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
index 7db6035553ae6..84f71340c8eee 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
@@ -19,17 +19,16 @@ package org.apache.spark.storage
 
 import java.util.{HashMap => JHashMap}
 
-import scala.collection.immutable.HashSet
-import scala.collection.mutable
-import scala.collection.JavaConverters._
-import scala.concurrent.{ExecutionContext, Future}
-
-import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint}
-import org.apache.spark.{Logging, SparkConf}
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint}
 import org.apache.spark.scheduler._
 import org.apache.spark.storage.BlockManagerMessages._
 import org.apache.spark.util.{ThreadUtils, Utils}
+import org.apache.spark.{Logging, SparkConf}
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+import scala.concurrent.{ExecutionContext, Future}
 
 /**
  * BlockManagerMasterEndpoint is an [[ThreadSafeRpcEndpoint]] on the master node to track statuses
@@ -56,8 +55,8 @@ class BlockManagerMasterEndpoint(
   private implicit val askExecutionContext = ExecutionContext.fromExecutorService(askThreadPool)
 
   override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
-    case RegisterBlockManager(blockManagerId, maxMemSize, slaveEndpoint) =>
-      register(blockManagerId, maxMemSize, slaveEndpoint)
+    case RegisterBlockManager(blockManagerId, maxMemSize, localDirsPath, slaveEndpoint) =>
+      register(blockManagerId, maxMemSize, localDirsPath, slaveEndpoint)
       context.reply(true)
 
     case _updateBlockInfo @ UpdateBlockInfo(
@@ -81,6 +80,9 @@
     case GetMemoryStatus =>
       context.reply(memoryStatus)
 
+    case GetLocalDirsPath(blockManagerId) =>
+      context.reply(getLocalDirsPath(blockManagerId))
+
     case GetStorageStatus =>
       context.reply(storageStatus)
 
@@ -240,6 +242,15 @@
     }.toMap
   }
 
+  // Return the local dirs of the other block managers on the same host as the given blockManagerId
+  private def getLocalDirsPath(blockManagerId: BlockManagerId)
+    : Map[BlockManagerId, Array[String]] = {
+    blockManagerInfo
+      .filter { case(id, _) => id != blockManagerId && id.host == blockManagerId.host }
+      .mapValues { info => info.localDirsPath }
+      .toMap
+  }
+
   private def storageStatus: Array[StorageStatus] = {
     blockManagerInfo.map { case (blockManagerId, info) =>
       new StorageStatus(blockManagerId, info.maxMem, info.blocks.asScala)
@@ -299,7 +310,11 @@
     ).map(_.flatten.toSeq)
   }
 
-  private def register(id: BlockManagerId, maxMemSize: Long, slaveEndpoint: RpcEndpointRef) {
+  private def register(
+      id: BlockManagerId,
+      maxMemSize: Long,
+      localDirsPath: Array[String],
+      slaveEndpoint: RpcEndpointRef) {
     val time = System.currentTimeMillis()
     if (!blockManagerInfo.contains(id)) {
       blockManagerIdByExecutor.get(id.executorId) match {
@@ -316,7 +331,7 @@
       blockManagerIdByExecutor(id.executorId) = id
 
       blockManagerInfo(id) = new BlockManagerInfo(
-        id, System.currentTimeMillis(), maxMemSize, slaveEndpoint)
+        id, System.currentTimeMillis(), maxMemSize, localDirsPath, slaveEndpoint)
     }
     listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize))
   }
@@ -423,6 +438,7 @@ private[spark] class BlockManagerInfo(
     val blockManagerId: BlockManagerId,
     timeMs: Long,
     val maxMem: Long,
+    val localDirsPath: Array[String],
     val slaveEndpoint: RpcEndpointRef)
   extends Logging {
 
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 376e9eb48843d..2c32ca174c1cc 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -50,6 +50,7 @@
   case class RegisterBlockManager(
       blockManagerId: BlockManagerId,
       maxMemSize: Long,
+      localDirsPath: Array[String],
       sender: RpcEndpointRef)
     extends ToBlockManagerMaster
 
@@ -109,4 +110,6 @@
   case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
 
   case class HasCachedBlocks(executorId: String) extends ToBlockManagerMaster
+
+  case class GetLocalDirsPath(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
 }
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index f7e84a2c2e14c..a336baf893d6f 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -20,6 +20,8 @@ package org.apache.spark.storage
 import java.util.UUID
 import java.io.{IOException, File}
 
+import scala.collection.mutable
+
 import org.apache.spark.{SparkConf, Logging}
 import org.apache.spark.executor.ExecutorExitCode
 import org.apache.spark.util.{ShutdownHookManager, Utils}
@@ -51,16 +53,26 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
   // of subDirs(i) is protected by the lock of subDirs(i)
   private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
 
+  // Cache local directories for other block managers
+  private val localDirsByOtherBlkMgr = new mutable.HashMap[BlockManagerId, Array[String]]
+
   private val shutdownHook = addShutdownHook()
 
+  def blockManagerId: BlockManagerId = blockManager.blockManagerId
+
+  def getLocalDirsPath(): Array[String] = {
+    localDirs.map(_.getAbsolutePath)
+  }
+
+  def getLocalDirsPath(blockManagerId: BlockManagerId): Map[BlockManagerId, Array[String]] = {
+    blockManager.master.getLocalDirsPath(blockManagerId)
+  }
+
   /** Looks up a file by hashing it into one of our local subdirectories. */
   // This method should be kept in sync with
   // org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getFile().
   def getFile(filename: String): File = {
-    // Figure out which local directory it hashes to, and which subdirectory in that
-    val hash = Utils.nonNegativeHash(filename)
-    val dirId = hash % localDirs.length
-    val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
+    val (dirId, subDirId) = getDirInfo(filename, localDirs.length)
 
     // Create the subdirectory if it doesn't already exist
     val subDir = subDirs(dirId).synchronized {
@@ -82,6 +94,39 @@
 
   def getFile(blockId: BlockId): File = getFile(blockId.name)
 
+  def getShuffleFileBypassNetworkAccess(blockId: BlockId, blockManagerId: BlockManagerId): File = {
+    if (this.blockManagerId == blockManagerId) {
+      getFile(blockId)
+    } else {
+      // Get a file from another block manager with the given blockManagerId
+      val dirs = localDirsByOtherBlkMgr.synchronized {
+        localDirsByOtherBlkMgr.getOrElse(blockManagerId, {
+          localDirsByOtherBlkMgr ++= getLocalDirsPath(this.blockManagerId)
+          localDirsByOtherBlkMgr.getOrElse(blockManagerId, {
+            throw new IOException(s"Block manager (${blockManagerId}) not found " +
+              s"in host '${this.blockManagerId.host}'")
+          })
+        })
+      }
+      val (dirId, subDirId) = getDirInfo(blockId.name, dirs.length)
+      val file = new File(new File(dirs(dirId), "%02x".format(subDirId)), blockId.name)
+      if (!file.exists()) {
+        throw new IOException(s"File '${file}' not found in local dir")
+      }
+      logInfo(s"${this.blockManagerId} bypasses network access and " +
+        s"directly reads file '${file}' in local dir")
+      file
+    }
+  }
+
+  def getDirInfo(filename: String, numDirs: Int): (Int, Int) = {
+    // Figure out which local directory it hashes to, and which subdirectory in that
+    val hash = Utils.nonNegativeHash(filename)
+    val dirId = hash % numDirs
+    val subDirName = (hash / numDirs) % subDirsPerLocalDir
+    (dirId, subDirName)
+  }
+
   /** Check if disk block manager has a block. */
   def containsBlock(blockId: BlockId): Boolean = {
     getFile(blockId.name).exists()
@@ -166,7 +211,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
     // Only perform cleanup if an external service is not serving our shuffle files.
     // Also blockManagerId could be null if block manager is not initialized properly.
     if (!blockManager.externalShuffleServiceEnabled ||
-      (blockManager.blockManagerId != null && blockManager.blockManagerId.isDriver)) {
+      (this.blockManagerId != null && blockManager.blockManagerId.isDriver)) {
       localDirs.foreach { localDir =>
         if (localDir.isDirectory() && localDir.exists()) {
           try {
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 0d0448feb5b06..e3ff0ce72fd8e 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -20,7 +20,7 @@ package org.apache.spark.storage
 import java.io.InputStream
 import java.util.concurrent.LinkedBlockingQueue
 
-import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
+import scala.collection.mutable.{HashMap, ArrayBuffer, HashSet, Queue}
 import scala.util.control.NonFatal
 
 import org.apache.spark.{Logging, SparkException, TaskContext}
@@ -58,6 +58,13 @@ final class ShuffleBlockFetcherIterator(
 
   import ShuffleBlockFetcherIterator._
 
+  /**
+   * If this option is enabled, bypass unnecessary network interaction
+   * when multiple block managers work on a single host.
+   */
+  private[this] val enableBypassNetworkAccess =
+    blockManager.conf.getBoolean("spark.shuffle.bypassNetworkAccess", false)
+
   /**
    * Total number of blocks to fetch. This can be smaller than the total number of blocks
    * in [[blocksByAddress]] because we filter out zero-sized blocks in [[initialize]].
@@ -74,8 +81,12 @@ final class ShuffleBlockFetcherIterator(
 
   private[this] val startTime = System.currentTimeMillis
 
-  /** Local blocks to fetch, excluding zero-sized blocks. */
-  private[this] val localBlocks = new ArrayBuffer[BlockId]()
+  /**
+   * Local blocks to fetch, excluding zero-sized blocks.
+   * This iterator bypasses remote access to fetch the blocks that
+   * other block managers hold on an identical host.
+   */
+  private[this] val localBlocks = new HashMap[BlockManagerId, ArrayBuffer[BlockId]]
 
   /** Remote blocks to fetch, excluding zero-sized blocks. */
   private[this] val remoteBlocks = new HashSet[BlockId]()
@@ -188,10 +199,13 @@
     var totalBlocks = 0
     for ((address, blockInfos) <- blocksByAddress) {
       totalBlocks += blockInfos.size
-      if (address.executorId == blockManager.blockManagerId.executorId) {
+      // if (blockManager.blockManagerId.shareHost(address)) {
+      if (blockManager.blockManagerId == address ||
+          (enableBypassNetworkAccess && blockManager.blockManagerId.shareHost(address))) {
         // Filter out zero-sized blocks
-        localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
-        numBlocksToFetch += localBlocks.size
+        val blocks = blockInfos.filter(_._2 != 0).map(_._1)
+        localBlocks.getOrElseUpdate(address, ArrayBuffer()) ++= blocks
+        numBlocksToFetch += blocks.size
       } else {
         val iterator = blockInfos.iterator
         var curRequestSize = 0L
@@ -233,19 +247,25 @@
   private[this] def fetchLocalBlocks() {
     val iter = localBlocks.iterator
     while (iter.hasNext) {
-      val blockId = iter.next()
-      try {
-        val buf = blockManager.getBlockData(blockId)
-        shuffleMetrics.incLocalBlocksFetched(1)
-        shuffleMetrics.incLocalBytesRead(buf.size)
-        buf.retain()
-        results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
-      } catch {
-        case e: Exception =>
-          // If we see an exception, stop immediately.
-          logError(s"Error occurred while fetching local blocks", e)
-          results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e))
-          return
+      val (blockManagerId, blockIds) = iter.next()
+      val blockIter = blockIds.iterator
+      while (blockIter.hasNext) {
+        val blockId = blockIter.next()
+        assert(blockId.isShuffle)
+        try {
+          val buf = blockManager.getShuffleBlockData(
+            blockId.asInstanceOf[ShuffleBlockId], blockManagerId)
+          shuffleMetrics.incLocalBlocksFetched(1)
+          shuffleMetrics.incLocalBytesRead(buf.size)
+          buf.retain()
+          results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf))
+        } catch {
+          case e: Exception =>
+            // If we see an exception, stop immediately.
+ logError(s"Error occurred while fetching local blocks", e) + results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e)) + return + } } } } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 1c3f2bc315ddc..676948856a0a3 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -86,6 +86,18 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex assert(groups.map(_._2).sum === 2000) } + test("bypass remote access") { + val conf = new SparkConf().set("spark.shuffle.bypassNetworkAccess", "true") + Seq("hash", "sort", "tungsten-sort").map { shuffle => + sc = new SparkContext(clusterUrl, "test", conf.clone.set("spark.shuffle.manager", shuffle)) + val rdd = sc.parallelize((0 until 1000).map(x => (x % 4, 1)), 5) + val groups = rdd.reduceByKey(_ + _).collect + assert(groups.size === 4) + assert(groups.forall(_._2 == 250)) + resetSparkContext() + } + } + test("accumulators") { sc = new SparkContext(clusterUrl, "test") val accum = sc.accumulator(0) diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 26a372d6a905d..22b4f50860390 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -104,7 +104,8 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to // fetch shuffle data. val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId) - when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer) + when(blockManager.getShuffleBlockData(shuffleBlockId, localBlockManagerId)) + .thenReturn(managedBuffer) when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream]))) .thenAnswer(dummyCompressionFunction) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 688f56f4665f3..ee877478ed367 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -17,21 +17,21 @@ package org.apache.spark.storage -import java.io.{File, FileWriter} +import java.io.{File, FileWriter, IOException} -import scala.language.reflectiveCalls - -import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} - -import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.mockito.Matchers.{eq => meq} +import org.mockito.Mockito.{mock, times, verify, when} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, PrivateMethodTester} -class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll { +import scala.language.reflectiveCalls + +class DiskBlockManagerSuite extends SparkFunSuite + with BeforeAndAfterEach with BeforeAndAfterAll with PrivateMethodTester { private val testConf = new SparkConf(false) private var rootDir0: File = _ private var rootDir1: File = _ - private var rootDirs: String = _ val blockManager = 
mock(classOf[BlockManager]) when(blockManager.conf).thenReturn(testConf) @@ -41,7 +41,7 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B super.beforeAll() rootDir0 = Utils.createTempDir() rootDir1 = Utils.createTempDir() - rootDirs = rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath + testConf.set("spark.local.dir", rootDir0.getAbsolutePath + "," + rootDir1.getAbsolutePath) } override def afterAll() { @@ -51,9 +51,7 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B } override def beforeEach() { - val conf = testConf.clone - conf.set("spark.local.dir", rootDirs) - diskBlockManager = new DiskBlockManager(blockManager, conf) + diskBlockManager = new DiskBlockManager(blockManager, testConf.clone) } override def afterEach() { @@ -81,4 +79,58 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B for (i <- 0 until numBytes) writer.write(i) writer.close() } + + test("bypassing network access") { + val mockBlockManagerMaster = mock(classOf[BlockManagerMaster]) + val mockBlockManager = mock(classOf[BlockManager]) + + // Assume two executors in an identical host + val localBmId1 = BlockManagerId("test-exec1", "test-client1", 1) + val localBmId2 = BlockManagerId("test-exec2", "test-client1", 2) + + // Assume that localBmId2 holds 'shuffle_1_0_0' + val blockIdInLocalBmId2 = ShuffleBlockId(1, 0, 0) + val tempDir = Utils.createTempDir() + try { + // Create mock classes for testing + when(mockBlockManagerMaster.getLocalDirsPath(meq(localBmId1))) + .thenReturn(Map(localBmId2 -> Array(tempDir.getAbsolutePath))) + when(mockBlockManager.conf).thenReturn(testConf) + when(mockBlockManager.master).thenReturn(mockBlockManagerMaster) + when(mockBlockManager.blockManagerId).thenReturn(localBmId1) + + val testDiskBlockManager = new DiskBlockManager(mockBlockManager, testConf.clone) + + val getBlockDir: String => File = (s: String) => { + val (_, subDirId) = testDiskBlockManager.getDirInfo(s, 1) + new File(tempDir, "%02x".format(subDirId)) + } + + // Create a dummy file for a shuffle block + val blockDir = getBlockDir(blockIdInLocalBmId2.name) + assert(blockDir.mkdir()) + val dummyBlockFile = new File(blockDir, blockIdInLocalBmId2.name) + assert(dummyBlockFile.createNewFile()) + + val file = testDiskBlockManager.getShuffleFileBypassNetworkAccess( + blockIdInLocalBmId2, localBmId2) + assert(dummyBlockFile.getName === file.getName) + assert(dummyBlockFile.toString.contains(tempDir.toString)) + + verify(mockBlockManagerMaster, times(1)).getLocalDirsPath(meq(localBmId1)) + verify(mockBlockManager, times(1)).conf + verify(mockBlockManager, times(1)).master + verify(mockBlockManager, times(3)).blockManagerId + + // Throw an IOException if given shuffle file not found + val blockIdNotInLocalBmId2 = ShuffleBlockId(2, 0, 0) + val errMsg = intercept[IOException] { + testDiskBlockManager.getShuffleFileBypassNetworkAccess(blockIdNotInLocalBmId2, localBmId2) + } + assert(errMsg.getMessage contains s"File '${getBlockDir(blockIdNotInLocalBmId2.name)}/" + + s"${blockIdNotInLocalBmId2}' not found in local dir") + } finally { + Utils.deleteRecursively(tempDir) + } + } } diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 828153bdbfc44..4c13fb772b955 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -27,16 +27,24 @@ import org.mockito.Matchers.{any, eq => meq} import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer -import org.scalatest.PrivateMethodTester +import org.scalatest.{BeforeAndAfterAll, PrivateMethodTester} -import org.apache.spark.{SparkFunSuite, TaskContext} +import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.shuffle.FetchFailedException -class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { +class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite + with BeforeAndAfterAll with PrivateMethodTester { + private val testConf = new SparkConf(false) + + override def beforeAll() { + super.beforeAll() + testConf.set("spark.shuffle.bypassNetworkAccess", "false") + } + // Some of the tests are quite tricky because we are testing the cleanup behavior // in the presence of faults. @@ -70,15 +78,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT test("successful 3 local reads + 2 remote reads") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(testConf).when(blockManager).conf doReturn(localBmId).when(blockManager).blockManagerId // Make sure blockManager.getBlockData would return the blocks - val localBlocks = Map[BlockId, ManagedBuffer]( + val localBlocks = Map[ShuffleBlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) localBlocks.foreach { case (blockId, buf) => - doReturn(buf).when(blockManager).getBlockData(meq(blockId)) + doReturn(buf).when(blockManager).getShuffleBlockData(meq(blockId), meq(localBmId)) } // Make sure remote blocks would return @@ -102,14 +111,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT 48 * 1024 * 1024) // 3 local blocks fetched in initialization - verify(blockManager, times(3)).getBlockData(any()) + verify(blockManager, times(3)).getShuffleBlockData(any(), any()) for (i <- 0 until 5) { assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements") val (blockId, inputStream) = iterator.next() // Make sure we release buffers when a wrapped input stream is closed. 
- val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId)) + val mockBuf = localBlocks.getOrElse( + blockId.asInstanceOf[ShuffleBlockId], remoteBlocks(blockId)) // Note: ShuffleBlockFetcherIterator wraps input streams in a BufferReleasingInputStream val wrappedInputStream = inputStream.asInstanceOf[BufferReleasingInputStream] verify(mockBuf, times(0)).release() @@ -126,13 +136,70 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // 3 local blocks, and 2 remote blocks // (but from the same block manager so one call to fetchBlocks) - verify(blockManager, times(3)).getBlockData(any()) + verify(blockManager, times(3)).getShuffleBlockData(any(), any()) verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any()) } + test("bypass unnecessary network access if block managers share an identical host") { + val blockManager = mock(classOf[BlockManager]) + + // Assume two executors in an identical host + val localBmId1 = BlockManagerId("test-exec1", "test-client1", 1) + val localBmId2 = BlockManagerId("test-exec2", "test-client1", 2) + + // Enable an option to bypass network access + doReturn(testConf.clone.set("spark.shuffle.bypassNetworkAccess", "true")) + .when(blockManager).conf + doReturn(localBmId1).when(blockManager).blockManagerId + + // Make sure blockManager.getBlockData would return the blocks + val localBlocksInBmId1 = Map[ShuffleBlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) + localBlocksInBmId1.foreach { case (blockId, buf) => + doReturn(buf).when(blockManager).getShuffleBlockData(meq(blockId), meq(localBmId1)) + } + val localBlocksInBmId2 = Map[ShuffleBlockId, ManagedBuffer]( + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) + localBlocksInBmId2.foreach { case (blockId, buf) => + doReturn(buf).when(blockManager).getShuffleBlockData(meq(blockId), meq(localBmId2)) + } + + // Create mock transfer + val transfer = mock(classOf[BlockTransferService]) + when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = {} + }) + + val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (localBmId1, localBlocksInBmId1.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq), + (localBmId2, localBlocksInBmId2.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq) + ) + + val iterator = new ShuffleBlockFetcherIterator( + TaskContext.empty(), + transfer, + blockManager, + blocksByAddress, + 48 * 1024 * 1024) + + // Skip unnecessary remote reads + verify(blockManager, times(3)).getShuffleBlockData(any(), any()) + + for (i <- 0 until 3) { + assert(iterator.hasNext, s"iterator should have 3 elements but actually has $i elements") + iterator.next() + } + + // As a result, only 3 local reads (2 remote access skipped) + verify(blockManager, times(3)).getShuffleBlockData(any(), any()) + verify(transfer, times(0)).fetchBlocks(any(), any(), any(), any(), any()) + } + test("release current unexhausted buffer in case the task completes early") { val blockManager = mock(classOf[BlockManager]) val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(testConf).when(blockManager).conf doReturn(localBmId).when(blockManager).blockManagerId // Make sure remote blocks would return @@ -194,6 +261,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT test("fail all blocks if any of the remote request fails") { 
     val blockManager = mock(classOf[BlockManager])
     val localBmId = BlockManagerId("test-client", "test-client", 1)
+    doReturn(testConf).when(blockManager).conf
     doReturn(localBmId).when(blockManager).blockManagerId
 
     // Make sure remote blocks would return
diff --git a/docs/configuration.md b/docs/configuration.md
index 741d6b2b37a87..b267919ebfaa7 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -477,6 +477,14 @@ Apart from these, the following properties are also available, and may be useful
     <code>spark.io.compression.codec</code>.
   </td>
 </tr>
+<tr>
+  <td><code>spark.shuffle.bypassNetworkAccess</code></td>
+  <td>false</td>
+  <td>
+    Whether to bypass network interaction when fetching shuffle blocks from block
+    managers that share the host (e.g., when multiple executors run on a single node).
+  </td>
+</tr>
 </table>
 
 #### Spark UI

From a93ec369c0a4274a3991c909f0c7539b7d512b20 Mon Sep 17 00:00:00 2001
From: Takeshi YAMAMURO
Date: Thu, 5 Nov 2015 16:57:44 +0900
Subject: [PATCH 2/6] Disable the bypassing optimization if the external shuffle service is enabled

---
 .../storage/ShuffleBlockFetcherIterator.scala | 16 +++++++++++-----
 1 file changed, 11 insertions(+), 5 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index e3ff0ce72fd8e..755cc9f56a0b0 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -58,12 +58,16 @@ final class ShuffleBlockFetcherIterator(
 
   import ShuffleBlockFetcherIterator._
 
+  private[this] val enableExternalShuffleService =
+    blockManager.conf.getBoolean("spark.shuffle.service.enabled", false)
+
   /**
    * If this option is enabled, bypass unnecessary network interaction
    * when multiple block managers work on a single host.
    */
   private[this] val enableBypassNetworkAccess =
-    blockManager.conf.getBoolean("spark.shuffle.bypassNetworkAccess", false)
+    blockManager.conf.getBoolean("spark.shuffle.bypassNetworkAccess", false) &&
+      !enableExternalShuffleService
 
   /**
   * Total number of blocks to fetch.
This can be smaller than the total number of blocks @@ -199,8 +203,7 @@ final class ShuffleBlockFetcherIterator( var totalBlocks = 0 for ((address, blockInfos) <- blocksByAddress) { totalBlocks += blockInfos.size - // if (blockManager.blockManagerId.shareHost(address)) { - if (blockManager.blockManagerId == address || + if (address.executorId == blockManager.blockManagerId.executorId || (enableBypassNetworkAccess && blockManager.blockManagerId.shareHost(address))) { // Filter out zero-sized blocks val blocks = blockInfos.filter(_._2 != 0).map(_._1) @@ -253,8 +256,11 @@ final class ShuffleBlockFetcherIterator( val blockId = blockIter.next() assert(blockId.isShuffle) try { - val buf = blockManager.getShuffleBlockData( - blockId.asInstanceOf[ShuffleBlockId], blockManagerId) + val buf = if (!enableExternalShuffleService) { + blockManager.getShuffleBlockData(blockId.asInstanceOf[ShuffleBlockId], blockManagerId) + } else { + blockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId]) + } shuffleMetrics.incLocalBlocksFetched(1) shuffleMetrics.incLocalBytesRead(buf.size) buf.retain() From bee23aae0c17000cf9c72fa9a4f0913a8d76beb9 Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Fri, 6 Nov 2015 08:45:40 +0900 Subject: [PATCH 3/6] Fix bugs in tests --- .../org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala index 22b4f50860390..90583f4a7edde 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala @@ -77,6 +77,8 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext // can ensure retain() and release() are properly called. val blockManager = mock(classOf[BlockManager]) + when(blockManager.conf).thenReturn(testConf) + // Create a return function to use for the mocked wrapForCompression method that just returns // the original input stream. 
val dummyCompressionFunction = new Answer[InputStream] { From b1126aab64d810a8f165f3f2f46b22ec7bb597ba Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Fri, 6 Nov 2015 09:00:13 +0900 Subject: [PATCH 4/6] Apply comments --- .../spark/shuffle/IndexShuffleBlockResolver.scala | 4 ++-- .../org/apache/spark/storage/BlockManager.scala | 2 +- .../apache/spark/storage/BlockManagerMaster.scala | 12 ++++++------ 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index d53a2666db0ce..51cbc58d0140c 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -207,8 +207,8 @@ private[spark] class IndexShuffleBlockResolver( } override def getBlockData( - blockId: ShuffleBlockId, blockManagerId: BlockManagerId) - : ManagedBuffer = { + blockId: ShuffleBlockId, + blockManagerId: BlockManagerId): ManagedBuffer = { // The block is actually going to be a range of a single map output file for this map, so // find out the consolidated file, then the offset within that from our index val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId, blockManagerId) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index bd6783b0d9bea..a77d8f6fbc80f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -38,7 +38,7 @@ import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv -import org.apache.spark.serializer.{Serializer, SerializerInstance} +import org.apache.spark.serializer.{SerializerInstance, Serializer} import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.util._ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 43f07e6f563c4..3c3f777933236 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -17,14 +17,14 @@ package org.apache.spark.storage -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{RpcUtils, ThreadUtils} -import org.apache.spark.{Logging, SparkConf, SparkException} - import scala.collection.Iterable import scala.collection.generic.CanBuildFrom -import scala.concurrent.Future +import scala.concurrent.{Await, Future} + +import org.apache.spark.rpc.RpcEndpointRef +import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.storage.BlockManagerMessages._ +import org.apache.spark.util.{ThreadUtils, RpcUtils} private[spark] class BlockManagerMaster( From ba94687b48d08fc6a4c863fbafeb5d39181cc53c Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Thu, 26 Nov 2015 12:32:03 +0900 Subject: [PATCH 5/6] Apply style comments --- .../spark/shuffle/IndexShuffleBlockResolver.scala | 3 +-- .../org/apache/spark/storage/BlockManager.scala | 4 ++-- .../org/apache/spark/storage/BlockManagerId.scala | 4 ---- 
.../spark/storage/BlockManagerMasterEndpoint.scala | 14 +++++++------- .../storage/ShuffleBlockFetcherIterator.scala | 4 ++-- .../spark/storage/DiskBlockManagerSuite.scala | 7 ++++--- .../storage/ShuffleBlockFetcherIteratorSuite.scala | 5 +++-- 7 files changed, 19 insertions(+), 22 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 51cbc58d0140c..53fa49f2c9bcf 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -69,8 +69,7 @@ private[spark] class IndexShuffleBlockResolver( private def getIndexFile( shuffleId: Int, mapId: Int, - blockManagerId: BlockManagerId = blockManager.blockManagerId) - : File = { + blockManagerId: BlockManagerId = blockManager.blockManagerId): File = { if (blockManager.blockManagerId != blockManagerId) { blockManager.diskBlockManager.getShuffleFileBypassNetworkAccess( ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID), blockManagerId) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a77d8f6fbc80f..1f7ff8d3729d8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -304,8 +304,8 @@ private[spark] class BlockManager( } override def getShuffleBlockData( - blockId: ShuffleBlockId, blockManagerId: BlockManagerId) - : ManagedBuffer = { + blockId: ShuffleBlockId, + blockManagerId: BlockManagerId): ManagedBuffer = { shuffleManager.shuffleBlockResolver.getBlockData(blockId, blockManagerId) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala index caf50bf0ff2cb..69ac37511e730 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala @@ -65,10 +65,6 @@ class BlockManagerId private ( executorId == SparkContext.LEGACY_DRIVER_IDENTIFIER } - def shareHost(other: BlockManagerId): Boolean = { - host == other.host - } - override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeUTF(executorId_) out.writeUTF(host_) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 84f71340c8eee..e6450b097e3fb 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -19,6 +19,10 @@ package org.apache.spark.storage import java.util.{HashMap => JHashMap} +import scala.collection.JavaConverters._ +import scala.collection.mutable +import scala.concurrent.{ExecutionContext, Future} + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ @@ -26,10 +30,6 @@ import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.{ThreadUtils, Utils} import org.apache.spark.{Logging, SparkConf} -import scala.collection.JavaConverters._ -import scala.collection.mutable -import scala.concurrent.{ExecutionContext, Future} - /** * BlockManagerMasterEndpoint is an 
[[ThreadSafeRpcEndpoint]] on the master node to track statuses * of all slaves' block managers. @@ -237,7 +237,7 @@ class BlockManagerMasterEndpoint( // Return a map from the block manager id to max memory and remaining memory. private def memoryStatus: Map[BlockManagerId, (Long, Long)] = { - blockManagerInfo.map { case(blockManagerId, info) => + blockManagerInfo.map { case (blockManagerId, info) => (blockManagerId, (info.maxMem, info.remainingMem)) }.toMap } @@ -246,7 +246,7 @@ class BlockManagerMasterEndpoint( private def getLocalDirsPath(blockManagerId: BlockManagerId) : Map[BlockManagerId, Array[String]] = { blockManagerInfo - .filter { case(id, _) => id != blockManagerId && id.host == blockManagerId.host } + .filter { case (id, _) => id != blockManagerId && id.host == blockManagerId.host } .mapValues { info => info.localDirsPath } .toMap } @@ -314,7 +314,7 @@ class BlockManagerMasterEndpoint( id: BlockManagerId, maxMemSize: Long, localDirsPath: Array[String], - slaveEndpoint: RpcEndpointRef) { + slaveEndpoint: RpcEndpointRef): Unit = { val time = System.currentTimeMillis() if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 755cc9f56a0b0..a4eb063c541e1 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -204,7 +204,7 @@ final class ShuffleBlockFetcherIterator( for ((address, blockInfos) <- blocksByAddress) { totalBlocks += blockInfos.size if (address.executorId == blockManager.blockManagerId.executorId || - (enableBypassNetworkAccess && blockManager.blockManagerId.shareHost(address))) { + (enableBypassNetworkAccess && blockManager.blockManagerId.host == address.host)) { // Filter out zero-sized blocks val blocks = blockInfos.filter(_._2 != 0).map(_._1) localBlocks.getOrElseUpdate(address, ArrayBuffer()) ++= blocks @@ -266,7 +266,7 @@ final class ShuffleBlockFetcherIterator( buf.retain() results.put(new SuccessFetchResult(blockId, blockManager.blockManagerId, 0, buf)) } catch { - case e: Exception => + case NonFatal(e) => // If we see an exception, stop immediately. 
logError(s"Error occurred while fetching local blocks", e) results.put(new FailureFetchResult(blockId, blockManager.blockManagerId, e)) diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index ee877478ed367..8a30d6541a5b0 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.storage import java.io.{File, FileWriter, IOException} -import org.apache.spark.util.Utils -import org.apache.spark.{SparkConf, SparkFunSuite} +import scala.language.reflectiveCalls + import org.mockito.Matchers.{eq => meq} import org.mockito.Mockito.{mock, times, verify, when} import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, PrivateMethodTester} -import scala.language.reflectiveCalls +import org.apache.spark.util.Utils +import org.apache.spark.{SparkConf, SparkFunSuite} class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with BeforeAndAfterAll with PrivateMethodTester { diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 4c13fb772b955..d6bee52da3914 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -36,8 +36,9 @@ import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.shuffle.FetchFailedException -class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite - with BeforeAndAfterAll with PrivateMethodTester { +class ShuffleBlockFetcherIteratorSuite + extends SparkFunSuite with BeforeAndAfterAll with PrivateMethodTester +{ private val testConf = new SparkConf(false) override def beforeAll() { From 303abcd0f37d137c6a8ce4a0147466bb8feb9d9e Mon Sep 17 00:00:00 2001 From: Takeshi YAMAMURO Date: Fri, 27 Nov 2015 15:52:08 +0900 Subject: [PATCH 6/6] Fix the interface issues --- .../spark/network/BlockDataManager.scala | 11 +++--- .../shuffle/FileShuffleBlockResolver.scala | 6 +--- .../shuffle/IndexShuffleBlockResolver.scala | 36 ++++++------------- .../apache/spark/storage/BlockManager.scala | 13 ++++--- .../spark/storage/DiskBlockManager.scala | 2 +- .../storage/ShuffleBlockFetcherIterator.scala | 2 +- .../BlockStoreShuffleReaderSuite.scala | 2 +- .../spark/storage/DiskBlockManagerSuite.scala | 4 +-- .../ShuffleBlockFetcherIteratorSuite.scala | 14 ++++---- 9 files changed, 35 insertions(+), 55 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala index 94e33d156dc4b..e34c796b60c12 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala @@ -18,7 +18,9 @@ package org.apache.spark.network import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.storage.{ShuffleBlockId, BlockManagerId, BlockId, StorageLevel} +import org.apache.spark.storage.BlockId +import org.apache.spark.storage.BlockManagerId +import org.apache.spark.storage.StorageLevel private[spark] trait BlockDataManager { @@ -30,11 +32,10 @@ trait BlockDataManager { def getBlockData(blockId: BlockId): ManagedBuffer /** - * Interface to get the shuffle 
block data that block manager with given blockManagerId - * holds in a local host. Throws an exception if the block cannot be found or - * cannot be read successfully. + * Interface to get local block data managed by given BlockManagerId. + * Throws an exception if the block cannot be found or cannot be read successfully. */ - def getShuffleBlockData(blockId: ShuffleBlockId, blockManagerId: BlockManagerId): ManagedBuffer + def getBlockData(blockId: BlockId, blockManagerId: BlockManagerId): ManagedBuffer /** * Put the block locally, using the given storage level. diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 4528e982c4cbd..d8b96aa27fc0e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -100,11 +100,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) override def getBlockData(blockId: ShuffleBlockId, blockManagerId: BlockManagerId) : ManagedBuffer = { - val file = if (blockManager.blockManagerId != blockManagerId) { - blockManager.diskBlockManager.getShuffleFileBypassNetworkAccess(blockId, blockManagerId) - } else { - blockManager.diskBlockManager.getFile(blockId) - } + val file = blockManager.diskBlockManager.getFile(blockId, blockManagerId) new FileSegmentManagedBuffer(transportConf, file, 0, file.length) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index 53fa49f2c9bcf..6c5134411ceb2 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -49,35 +49,19 @@ private[spark] class IndexShuffleBlockResolver( private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") - private def getDataFile( - shuffleId: Int, - mapId: Int, - blockManagerId: BlockManagerId = blockManager.blockManagerId) - : File = { - if (blockManager.blockManagerId != blockManagerId) { - blockManager.diskBlockManager.getShuffleFileBypassNetworkAccess( - ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID), blockManagerId) - } else { - blockManager.diskBlockManager.getFile( - ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) - } - } - def getDataFile(shuffleId: Int, mapId: Int): File = getDataFile(shuffleId, mapId, blockManager.blockManagerId) - private def getIndexFile( - shuffleId: Int, - mapId: Int, - blockManagerId: BlockManagerId = blockManager.blockManagerId): File = { - if (blockManager.blockManagerId != blockManagerId) { - blockManager.diskBlockManager.getShuffleFileBypassNetworkAccess( - ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID), blockManagerId) - } else { - blockManager.diskBlockManager.getFile( - ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID)) - } - } + private def getDataFile(shuffleId: Int, mapId: Int, blockManagerId: BlockManagerId): File = + blockManager.diskBlockManager.getFile( + ShuffleDataBlockId(shuffleId, mapId, NOOP_REDUCE_ID), blockManagerId) + + private def getIndexFile(shuffleId: Int, mapId: Int): File = + getIndexFile(shuffleId, mapId, blockManager.blockManagerId) + + private def getIndexFile(shuffleId: Int, mapId: Int, blockManagerId: BlockManagerId): File = + blockManager.diskBlockManager.getFile( + ShuffleIndexBlockId(shuffleId, mapId, NOOP_REDUCE_ID), 
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 1f7ff8d3729d8..03bf8ac7fda28 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -38,7 +38,7 @@ import org.apache.spark.network.netty.SparkTransportConf
 import org.apache.spark.network.shuffle.ExternalShuffleClient
 import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
 import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.serializer.{SerializerInstance, Serializer}
+import org.apache.spark.serializer.{Serializer, SerializerInstance}
 import org.apache.spark.shuffle.ShuffleManager
 import org.apache.spark.util._
 
@@ -288,9 +288,10 @@ private[spark] class BlockManager(
    * Interface to get local block data. Throws an exception if the block cannot be found or
    * cannot be read successfully.
    */
-  override def getBlockData(blockId: BlockId): ManagedBuffer = {
+  override def getBlockData(blockId: BlockId, blockManagerId: BlockManagerId): ManagedBuffer = {
     if (blockId.isShuffle) {
-      getShuffleBlockData(blockId.asInstanceOf[ShuffleBlockId], blockManagerId)
+      shuffleManager.shuffleBlockResolver.getBlockData(
+        blockId.asInstanceOf[ShuffleBlockId], blockManagerId)
     } else {
       val blockBytesOpt = doGetLocal(blockId, asBlockResult = false)
         .asInstanceOf[Option[ByteBuffer]]
@@ -303,10 +304,8 @@ private[spark] class BlockManager(
     }
   }
 
-  override def getShuffleBlockData(
-      blockId: ShuffleBlockId,
-      blockManagerId: BlockManagerId): ManagedBuffer = {
-    shuffleManager.shuffleBlockResolver.getBlockData(blockId, blockManagerId)
+  override def getBlockData(blockId: BlockId): ManagedBuffer = {
+    getBlockData(blockId, this.blockManagerId)
   }
 
   /**
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index a336baf893d6f..6bb504f657f45 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -94,7 +94,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkConf)
 
   def getFile(blockId: BlockId): File = getFile(blockId.name)
 
-  def getShuffleFileBypassNetworkAccess(blockId: BlockId, blockManagerId: BlockManagerId): File = {
+  def getFile(blockId: BlockId, blockManagerId: BlockManagerId): File = {
     if (this.blockManagerId == blockManagerId) {
       getFile(blockId)
     } else {
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index a4eb063c541e1..54383ec8d6487 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -257,7 +257,7 @@ final class ShuffleBlockFetcherIterator(
       assert(blockId.isShuffle)
       try {
         val buf = if (!enableExternalShuffleService) {
-          blockManager.getShuffleBlockData(blockId.asInstanceOf[ShuffleBlockId], blockManagerId)
+          blockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId], blockManagerId)
         } else {
           blockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
         }
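The fetcher branch above is the heart of the optimization: when the external shuffle service is disabled, a shuffle block whose owner merely shares this executor's host can be read straight from disk via the two-argument getBlockData, while with the service enabled the original one-argument path is kept. The host test that decides this, as a runnable sketch (ManagerId is an illustrative stand-in for BlockManagerId):

    object HostLocalSketch {
      final case class ManagerId(executorId: String, host: String, port: Int)

      // A block is host-local when its owner is this manager itself, or a
      // different manager on the same host; only then can the network fetch
      // be bypassed in favor of a direct disk read.
      def isHostLocal(self: ManagerId, owner: ManagerId): Boolean =
        self == owner || self.host == owner.host

      def main(args: Array[String]): Unit = {
        val self   = ManagerId("exec-1", "node-1", 7071)
        val nearby = ManagerId("exec-2", "node-1", 7072)
        val remote = ManagerId("exec-3", "node-2", 7071)
        assert(isHostLocal(self, nearby))   // same host: direct disk read
        assert(!isHostLocal(self, remote))  // different host: network fetch
      }
    }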
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
index 90583f4a7edde..6bd5135c19849 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -106,7 +106,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
     // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to
     // fetch shuffle data.
     val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
-    when(blockManager.getShuffleBlockData(shuffleBlockId, localBlockManagerId))
+    when(blockManager.getBlockData(shuffleBlockId, localBlockManagerId))
       .thenReturn(managedBuffer)
     when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream])))
       .thenAnswer(dummyCompressionFunction)
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index 8a30d6541a5b0..f566491e380a8 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -113,7 +113,7 @@ class DiskBlockManagerSuite extends SparkFunSuite
     val dummyBlockFile = new File(blockDir, blockIdInLocalBmId2.name)
     assert(dummyBlockFile.createNewFile())
-    val file = testDiskBlockManager.getShuffleFileBypassNetworkAccess(
+    val file = testDiskBlockManager.getFile(
       blockIdInLocalBmId2, localBmId2)
     assert(dummyBlockFile.getName === file.getName)
     assert(dummyBlockFile.toString.contains(tempDir.toString))
@@ -126,7 +126,7 @@ class DiskBlockManagerSuite extends SparkFunSuite
     // Throw an IOException if given shuffle file not found
     val blockIdNotInLocalBmId2 = ShuffleBlockId(2, 0, 0)
     val errMsg = intercept[IOException] {
-      testDiskBlockManager.getShuffleFileBypassNetworkAccess(blockIdNotInLocalBmId2, localBmId2)
+      testDiskBlockManager.getFile(blockIdNotInLocalBmId2, localBmId2)
     }
     assert(errMsg.getMessage contains s"File '${getBlockDir(blockIdNotInLocalBmId2.name)}/" +
       s"${blockIdNotInLocalBmId2}' not found in local dir")
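The suite above pins down the failure mode of the two-argument getFile: when the owning manager's local directories do not contain the requested block, the lookup fails with an IOException rather than returning a dangling path. The same contract as a toy, self-contained sketch (not the real DiskBlockManager):

    import java.io.{File, IOException}

    object MissingFileSketch {
      // Toy lookup table standing in for a block manager's local directories.
      def getFile(ownerFiles: Map[String, File], blockName: String): File =
        ownerFiles.getOrElse(
          blockName,
          throw new IOException(s"File '$blockName' not found in local dir"))

      def main(args: Array[String]): Unit = {
        val files = Map("shuffle_1_0_0" -> new File("/tmp/shuffle_1_0_0.data"))
        getFile(files, "shuffle_1_0_0")          // resolves to the known file
        try getFile(files, "shuffle_2_0_0")      // unknown block: throws
        catch { case e: IOException => println(e.getMessage) }
      }
    }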
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index d6bee52da3914..d7b3b42474205 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -88,7 +88,7 @@
       ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
     localBlocks.foreach { case (blockId, buf) =>
-      doReturn(buf).when(blockManager).getShuffleBlockData(meq(blockId), meq(localBmId))
+      doReturn(buf).when(blockManager).getBlockData(meq(blockId), meq(localBmId))
     }
 
     // Make sure remote blocks would return
@@ -112,7 +112,7 @@
       48 * 1024 * 1024)
 
     // 3 local blocks fetched in initialization
-    verify(blockManager, times(3)).getShuffleBlockData(any(), any())
+    verify(blockManager, times(3)).getBlockData(any(), any())
 
     for (i <- 0 until 5) {
       assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
@@ -137,7 +137,7 @@
 
     // 3 local blocks, and 2 remote blocks
     // (but from the same block manager so one call to fetchBlocks)
-    verify(blockManager, times(3)).getShuffleBlockData(any(), any())
+    verify(blockManager, times(3)).getBlockData(any(), any())
     verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any())
   }
 
@@ -157,13 +157,13 @@
     val localBlocksInBmId1 = Map[ShuffleBlockId, ManagedBuffer](
       ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
     localBlocksInBmId1.foreach { case (blockId, buf) =>
-      doReturn(buf).when(blockManager).getShuffleBlockData(meq(blockId), meq(localBmId1))
+      doReturn(buf).when(blockManager).getBlockData(meq(blockId), meq(localBmId1))
     }
     val localBlocksInBmId2 = Map[ShuffleBlockId, ManagedBuffer](
       ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
     localBlocksInBmId2.foreach { case (blockId, buf) =>
-      doReturn(buf).when(blockManager).getShuffleBlockData(meq(blockId), meq(localBmId2))
+      doReturn(buf).when(blockManager).getBlockData(meq(blockId), meq(localBmId2))
     }
 
     // Create mock transfer
@@ -185,7 +185,7 @@
       48 * 1024 * 1024)
 
     // Skip unnecessary remote reads
-    verify(blockManager, times(3)).getShuffleBlockData(any(), any())
+    verify(blockManager, times(3)).getBlockData(any(), any())
 
     for (i <- 0 until 3) {
       assert(iterator.hasNext, s"iterator should have 3 elements but actually has $i elements")
@@ -193,7 +193,7 @@
 
     // As a result, only 3 local reads (2 remote access skipped)
-    verify(blockManager, times(3)).getShuffleBlockData(any(), any())
+    verify(blockManager, times(3)).getBlockData(any(), any())
     verify(transfer, times(0)).fetchBlocks(any(), any(), any(), any(), any())
   }
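The verify counts in this last test capture the payoff of the series: three blocks spread across two block managers on one host are all served by local reads, and fetchBlocks is never called. A compact sketch of that local/remote split (toy types; the real iterator additionally tracks block sizes and batches remote requests):

    object FetchPlanSketch {
      final case class ManagerId(executorId: String, host: String)

      // Partition blocks into host-local and remote, mirroring the counts the
      // test verifies: 3 local reads, 0 remote fetches.
      def plan(self: ManagerId, blocksByOwner: Map[ManagerId, Seq[String]])
          : (Seq[String], Seq[String]) = {
        val (local, remote) =
          blocksByOwner.partition { case (owner, _) => owner.host == self.host }
        (local.values.flatten.toSeq, remote.values.flatten.toSeq)
      }

      def main(args: Array[String]): Unit = {
        val bmId1 = ManagerId("exec-1", "host-a")
        val bmId2 = ManagerId("exec-2", "host-a")
        val (local, remote) = plan(bmId1, Map(
          bmId1 -> Seq("shuffle_0_0_0"),
          bmId2 -> Seq("shuffle_0_1_0", "shuffle_0_2_0")))
        assert(local.size == 3 && remote.isEmpty)
      }
    }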