From ba9c27e1097d971becea7e9a47963094928298ef Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 11 Jan 2019 23:14:40 +0800 Subject: [PATCH 1/4] Register channel for stream request. --- .../spark/network/server/OneForOneStreamManager.java | 5 +++++ .../org/apache/spark/network/server/StreamManager.java | 7 +++++++ .../spark/network/server/TransportRequestHandler.java | 1 + 3 files changed, 13 insertions(+) diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index 0f6a8824d95e5..1a9e7f8a738c2 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -78,6 +78,11 @@ public void registerChannel(Channel channel, long streamId) { } } + @Override + public void registerChannel(Channel channel, String streamChunkId) { + registerChannel(channel, parseStreamChunkId(streamChunkId).getLeft()); + } + @Override public ManagedBuffer getChunk(long streamId, int chunkIndex) { StreamState state = streams.get(streamId); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java index c535295831606..351b51bb3dc27 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java @@ -70,6 +70,13 @@ public ManagedBuffer openStream(String streamId) { */ public void registerChannel(Channel channel, long streamId) { } + /** + * Associates a stream with a single client connection, which is guaranteed to be the only reader + * of the stream. This is similar to {@link #registerChannel(Channel, long)} method, but the + * streamId argument is for the stream in response to a stream() request. + */ + public void registerChannel(Channel channel, String streamId) { } + /** * Indicates that the given channel has been terminated. After this occurs, we are guaranteed not * to read from the associated streams again, so any state can be cleaned up. diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index 3e089b4cae273..d66aa17e4131f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -127,6 +127,7 @@ private void processStreamRequest(final StreamRequest req) { ManagedBuffer buf; try { buf = streamManager.openStream(req.streamId); + streamManager.registerChannel(channel, req.streamId); } catch (Exception e) { logger.error(String.format( "Error opening stream %s for request from %s", req.streamId, getRemoteAddress(channel)), e); From 6b028a9a17e0c60fed11354d49c9dbe325202e36 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Mon, 14 Jan 2019 22:47:28 +0800 Subject: [PATCH 2/4] Remove registerChannel from StreamManager. 
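Instead of requiring callers to pair openStream()/getChunk() with a
separate registerChannel() call, pass the channel into the read methods
so the stream/channel association happens at the point of use. A
caller-side sketch of the intended flow (an illustration only; the
names mirror the ChunkFetchRequestHandler code changed below):

    // Before this patch: two calls every handler had to keep in sync.
    streamManager.registerChannel(channel, msg.streamChunkId.streamId);
    ManagedBuffer buf =
      streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex);

    // After this patch: the channel travels with the read itself.
    ManagedBuffer buf2 =
      streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex, channel);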
--- .../server/ChunkFetchRequestHandler.java | 3 +-- .../server/OneForOneStreamManager.java | 18 ++++++++------- .../spark/network/server/StreamManager.java | 23 ++++--------------- .../server/TransportRequestHandler.java | 3 +-- .../network/ChunkFetchIntegrationSuite.java | 3 ++- .../ChunkFetchRequestHandlerSuite.java | 1 - .../RequestTimeoutIntegrationSuite.java | 5 ++-- .../org/apache/spark/network/StreamSuite.java | 5 ++-- .../network/TransportRequestHandlerSuite.java | 1 - .../spark/network/sasl/SparkSaslSuite.java | 2 +- .../spark/rpc/netty/NettyStreamManager.scala | 6 +++-- 11 files changed, 29 insertions(+), 41 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java index f08d8b0f984cf..069dc61f2c27e 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java @@ -90,8 +90,7 @@ protected void channelRead0( ManagedBuffer buf; try { streamManager.checkAuthorization(client, msg.streamChunkId.streamId); - streamManager.registerChannel(channel, msg.streamChunkId.streamId); - buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex); + buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex, channel); } catch (Exception e) { logger.error(String.format("Error opening block %s for request from %s", msg.streamChunkId, getRemoteAddress(channel)), e); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index 1a9e7f8a738c2..2b45e7906ffff 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -71,7 +71,12 @@ public OneForOneStreamManager() { streams = new ConcurrentHashMap<>(); } - @Override + + /** + * Associates a stream with a single client connection, which is guaranteed to be the only reader + * of the stream. Once the connection is closed, the stream will never be used again, enabling + * cleanup by `connectionTerminated`. 
+   */
   public void registerChannel(Channel channel, long streamId) {
     if (streams.containsKey(streamId)) {
       streams.get(streamId).associatedChannel = channel;
@@ -79,12 +84,9 @@ public void registerChannel(Channel channel, long streamId) {
     }
   }
 
   @Override
-  public void registerChannel(Channel channel, String streamChunkId) {
-    registerChannel(channel, parseStreamChunkId(streamChunkId).getLeft());
-  }
+  public ManagedBuffer getChunk(long streamId, int chunkIndex, Channel channel) {
+    registerChannel(channel, streamId);
 
-  @Override
-  public ManagedBuffer getChunk(long streamId, int chunkIndex) {
     StreamState state = streams.get(streamId);
     if (chunkIndex != state.curChunk) {
       throw new IllegalStateException(String.format(
@@ -105,9 +107,9 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex) {
   }
 
   @Override
-  public ManagedBuffer openStream(String streamChunkId) {
+  public ManagedBuffer openStream(String streamChunkId, Channel channel) {
     Pair<Long, Integer> streamChunkIdPair = parseStreamChunkId(streamChunkId);
-    return getChunk(streamChunkIdPair.getLeft(), streamChunkIdPair.getRight());
+    return getChunk(streamChunkIdPair.getLeft(), streamChunkIdPair.getRight(), channel);
   }
 
   public static String genStreamChunkId(long streamId, int chunkId) {
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
index 351b51bb3dc27..393cc72c00e9f 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -42,9 +42,10 @@ public abstract class StreamManager {
    * The returned ManagedBuffer will be release()'d after being written to the network.
    *
    * @param streamId id of a stream that has been previously registered with the StreamManager.
+   * @param channel The connection used to serve chunk request.
    * @param chunkIndex 0-indexed chunk of the stream that's requested
    */
-  public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);
+  public abstract ManagedBuffer getChunk(long streamId, int chunkIndex, Channel channel);
 
   /**
    * Called in response to a stream() request. The returned data is streamed to the client
@@ -54,29 +55,13 @@ public abstract class StreamManager {
    * {@link #getChunk(long, int)} method.
    *
    * @param streamId id of a stream that has been previously registered with the StreamManager.
+   * @param channel The connection used to serve stream request.
    * @return A managed buffer for the stream, or null if the stream was not found.
    */
-  public ManagedBuffer openStream(String streamId) {
+  public ManagedBuffer openStream(String streamId, Channel channel) {
     throw new UnsupportedOperationException();
   }
 
-  /**
-   * Associates a stream with a single client connection, which is guaranteed to be the only reader
-   * of the stream. The getChunk() method will be called serially on this connection and once the
-   * connection is closed, the stream will never be used again, enabling cleanup.
-   *
-   * This must be called before the first getChunk() on the stream, but it may be invoked multiple
-   * times with the same channel and stream id.
-   */
-  public void registerChannel(Channel channel, long streamId) { }
-
-  /**
-   * Associates a stream with a single client connection, which is guaranteed to be the only reader
-   * of the stream.
This is similar to {@link #registerChannel(Channel, long)} method, but the - * streamId argument is for the stream in response to a stream() request. - */ - public void registerChannel(Channel channel, String streamId) { } - /** * Indicates that the given channel has been terminated. After this occurs, we are guaranteed not * to read from the associated streams again, so any state can be cleaned up. diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index d66aa17e4131f..fb8ab7ba39c91 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -126,8 +126,7 @@ private void processStreamRequest(final StreamRequest req) { } ManagedBuffer buf; try { - buf = streamManager.openStream(req.streamId); - streamManager.registerChannel(channel, req.streamId); + buf = streamManager.openStream(req.streamId, channel); } catch (Exception e) { logger.error(String.format( "Error opening stream %s for request from %s", req.streamId, getRemoteAddress(channel)), e); diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index 37a8664a52661..b58cbaaadf086 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -32,6 +32,7 @@ import com.google.common.collect.Sets; import com.google.common.io.Closeables; +import io.netty.channel.Channel; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -92,7 +93,7 @@ public static void setUp() throws Exception { streamManager = new StreamManager() { @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex) { + public ManagedBuffer getChunk(long streamId, int chunkIndex, Channel channel) { assertEquals(STREAM_ID, streamId); if (chunkIndex == BUFFER_CHUNK_INDEX) { return new NioManagedBuffer(buf); diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java index 2c72c53a33ae8..f8ca37054e5d9 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java @@ -65,7 +65,6 @@ public void handleChunkFetchRequest() throws Exception { managedBuffers.add(new TestManagedBuffer(30)); managedBuffers.add(new TestManagedBuffer(40)); long streamId = streamManager.registerStream("test-app", managedBuffers.iterator()); - streamManager.registerChannel(channel, streamId); TransportClient reverseClient = mock(TransportClient.class); ChunkFetchRequestHandler requestHandler = new ChunkFetchRequestHandler(reverseClient, rpcHandler.getStreamManager(), 2L); diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index c0724e018263f..55826b2840d7a 100644 --- 
a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.network; import com.google.common.util.concurrent.Uninterruptibles; +import io.netty.channel.Channel; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; @@ -65,7 +66,7 @@ public void setUp() throws Exception { defaultManager = new StreamManager() { @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex) { + public ManagedBuffer getChunk(long streamId, int chunkIndex, Channel channel) { throw new UnsupportedOperationException(); } }; @@ -184,7 +185,7 @@ public void furtherRequestsDelay() throws Exception { final byte[] response = new byte[16]; final StreamManager manager = new StreamManager() { @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex) { + public ManagedBuffer getChunk(long streamId, int chunkIndex, Channel channel) { Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS); return new NioManagedBuffer(ByteBuffer.wrap(response)); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java index f3050cb79cdfd..4fb1f313fec61 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -31,6 +31,7 @@ import java.util.concurrent.TimeUnit; import com.google.common.io.Files; +import io.netty.channel.Channel; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -70,12 +71,12 @@ public static void setUp() throws Exception { final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); final StreamManager streamManager = new StreamManager() { @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex) { + public ManagedBuffer getChunk(long streamId, int chunkIndex, Channel channel) { throw new UnsupportedOperationException(); } @Override - public ManagedBuffer openStream(String streamId) { + public ManagedBuffer openStream(String streamId, Channel channel) { return testData.openStream(conf, streamId); } }; diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java index ad640415a8e6d..316baebde4453 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java @@ -59,7 +59,6 @@ public void handleStreamRequest() throws Exception { managedBuffers.add(new TestManagedBuffer(30)); managedBuffers.add(new TestManagedBuffer(40)); long streamId = streamManager.registerStream("test-app", managedBuffers.iterator()); - streamManager.registerChannel(channel, streamId); TransportClient reverseClient = mock(TransportClient.class); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient, rpcHandler, 2L); diff --git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java 
b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 59adf9704cbf6..5fee5a4953d78 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -259,7 +259,7 @@ public void testFileRegionEncryption() throws Exception { try { TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf)); StreamManager sm = mock(StreamManager.class); - when(sm.getChunk(anyLong(), anyInt())).thenAnswer(invocation -> + when(sm.getChunk(anyLong(), anyInt(), any(Channel.class))).thenAnswer(invocation -> new FileSegmentManagedBuffer(conf, file, 0, file.length())); RpcHandler rpcHandler = mock(RpcHandler.class); diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala index 780fadd5bda8e..edc6f2736b3f4 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -19,6 +19,8 @@ package org.apache.spark.rpc.netty import java.io.File import java.util.concurrent.ConcurrentHashMap +import io.netty.channel.Channel + import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.server.StreamManager import org.apache.spark.rpc.RpcEnvFileServer @@ -43,11 +45,11 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) private val jars = new ConcurrentHashMap[String, File]() private val dirs = new ConcurrentHashMap[String, File]() - override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = { + override def getChunk(streamId: Long, chunkIndex: Int, channel: Channel): ManagedBuffer = { throw new UnsupportedOperationException() } - override def openStream(streamId: String): ManagedBuffer = { + override def openStream(streamId: String, channel: Channel): ManagedBuffer = { val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2) val file = ftype match { case "files" => files.get(fname) From 53e9c6e123ede351eb1e5ab330cb6d486bb1cab4 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 15 Jan 2019 08:40:50 +0800 Subject: [PATCH 3/4] Fix java style. 
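The previous commit pushed the getChunk() call past the project's
maximum line length (presumably the 100-character checkstyle limit),
so break the argument list:

    // Flagged by the style checker:
    buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex, channel);

    // Wrapped form used below:
    buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex,
      channel);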
--- .../apache/spark/network/server/ChunkFetchRequestHandler.java | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java index 069dc61f2c27e..36e82e95bd534 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java @@ -90,7 +90,8 @@ protected void channelRead0( ManagedBuffer buf; try { streamManager.checkAuthorization(client, msg.streamChunkId.streamId); - buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex, channel); + buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex, + channel); } catch (Exception e) { logger.error(String.format("Error opening block %s for request from %s", msg.streamChunkId, getRemoteAddress(channel)), e); From 7eed77988d5330b054b4a0af60639e680bc116e5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Tue, 15 Jan 2019 17:08:42 +0800 Subject: [PATCH 4/4] Register channel when registering stream. --- .../server/ChunkFetchRequestHandler.java | 3 +- .../server/OneForOneStreamManager.java | 38 +++++++++---------- .../spark/network/server/StreamManager.java | 6 +-- .../server/TransportRequestHandler.java | 2 +- .../network/ChunkFetchIntegrationSuite.java | 3 +- .../ChunkFetchRequestHandlerSuite.java | 2 +- .../RequestTimeoutIntegrationSuite.java | 5 +-- .../org/apache/spark/network/StreamSuite.java | 5 +-- .../network/TransportRequestHandlerSuite.java | 8 +++- .../spark/network/sasl/SparkSaslSuite.java | 2 +- .../server/OneForOneStreamManagerSuite.java | 5 ++- .../shuffle/ExternalShuffleBlockHandler.java | 2 +- .../ExternalShuffleBlockHandlerSuite.java | 3 +- .../network/netty/NettyBlockRpcServer.scala | 3 +- .../spark/rpc/netty/NettyStreamManager.scala | 6 +-- 15 files changed, 45 insertions(+), 48 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java index 36e82e95bd534..43c3d23b6304d 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/ChunkFetchRequestHandler.java @@ -90,8 +90,7 @@ protected void channelRead0( ManagedBuffer buf; try { streamManager.checkAuthorization(client, msg.streamChunkId.streamId); - buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex, - channel); + buf = streamManager.getChunk(msg.streamChunkId.streamId, msg.streamChunkId.chunkIndex); } catch (Exception e) { logger.error(String.format("Error opening block %s for request from %s", msg.streamChunkId, getRemoteAddress(channel)), e); diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index 2b45e7906ffff..6fafcc131fa24 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -23,6 +23,7 @@ import java.util.concurrent.ConcurrentHashMap; 
 import java.util.concurrent.atomic.AtomicLong;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Preconditions;
 import io.netty.channel.Channel;
 import org.apache.commons.lang3.tuple.ImmutablePair;
@@ -49,7 +50,7 @@ private static class StreamState {
     final Iterator<ManagedBuffer> buffers;
 
     // The channel associated to the stream
-    Channel associatedChannel = null;
+    final Channel associatedChannel;
 
     // Used to keep track of the index of the buffer that the user has retrieved, just to ensure
     // that the caller only requests each chunk one at a time, in order.
@@ -58,9 +59,10 @@ private static class StreamState {
     // Used to keep track of the number of chunks being transferred and not finished yet.
     volatile long chunksBeingTransferred = 0L;
 
-    StreamState(String appId, Iterator<ManagedBuffer> buffers) {
+    StreamState(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
       this.appId = appId;
       this.buffers = Preconditions.checkNotNull(buffers);
+      this.associatedChannel = channel;
     }
   }
 
@@ -71,22 +73,8 @@ public OneForOneStreamManager() {
     streams = new ConcurrentHashMap<>();
   }
 
-
-  /**
-   * Associates a stream with a single client connection, which is guaranteed to be the only reader
-   * of the stream. Once the connection is closed, the stream will never be used again, enabling
-   * cleanup by `connectionTerminated`.
-   */
-  public void registerChannel(Channel channel, long streamId) {
-    if (streams.containsKey(streamId)) {
-      streams.get(streamId).associatedChannel = channel;
-    }
-  }
-
   @Override
-  public ManagedBuffer getChunk(long streamId, int chunkIndex, Channel channel) {
-    registerChannel(channel, streamId);
-
+  public ManagedBuffer getChunk(long streamId, int chunkIndex) {
     StreamState state = streams.get(streamId);
     if (chunkIndex != state.curChunk) {
       throw new IllegalStateException(String.format(
@@ -107,9 +95,9 @@ public ManagedBuffer getChunk(long streamId, int chunkIndex, Channel channel) {
   }
 
   @Override
-  public ManagedBuffer openStream(String streamChunkId, Channel channel) {
+  public ManagedBuffer openStream(String streamChunkId) {
     Pair<Long, Integer> streamChunkIdPair = parseStreamChunkId(streamChunkId);
-    return getChunk(streamChunkIdPair.getLeft(), streamChunkIdPair.getRight(), channel);
+    return getChunk(streamChunkIdPair.getLeft(), streamChunkIdPair.getRight());
   }
 
   public static String genStreamChunkId(long streamId, int chunkId) {
@@ -202,11 +190,19 @@ public long chunksBeingTransferred() {
    *
    * If an app ID is provided, only callers who've authenticated with the given app ID will be
    * allowed to fetch from this stream.
+   *
+   * This method also associates the stream with a single client connection, which is guaranteed
+   * to be the only reader of the stream. Once the connection is closed, the stream will never
+   * be used again, enabling cleanup by `connectionTerminated`.
   */
-  public long registerStream(String appId, Iterator<ManagedBuffer> buffers) {
+  public long registerStream(String appId, Iterator<ManagedBuffer> buffers, Channel channel) {
     long myStreamId = nextStreamId.getAndIncrement();
-    streams.put(myStreamId, new StreamState(appId, buffers));
+    streams.put(myStreamId, new StreamState(appId, buffers, channel));
     return myStreamId;
   }
 
+  @VisibleForTesting
+  public int numStreamStates() {
+    return streams.size();
+  }
 }
diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
index 393cc72c00e9f..e48d27be1126a 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -42,10 +42,9 @@ public abstract class StreamManager {
    * The returned ManagedBuffer will be release()'d after being written to the network.
    *
    * @param streamId id of a stream that has been previously registered with the StreamManager.
-   * @param channel The connection used to serve chunk request.
    * @param chunkIndex 0-indexed chunk of the stream that's requested
    */
-  public abstract ManagedBuffer getChunk(long streamId, int chunkIndex, Channel channel);
+  public abstract ManagedBuffer getChunk(long streamId, int chunkIndex);
 
   /**
    * Called in response to a stream() request. The returned data is streamed to the client
@@ -55,10 +54,9 @@ public abstract class StreamManager {
    * {@link #getChunk(long, int)} method.
    *
    * @param streamId id of a stream that has been previously registered with the StreamManager.
-   * @param channel The connection used to serve stream request.
    * @return A managed buffer for the stream, or null if the stream was not found.
*/ - public ManagedBuffer openStream(String streamId, Channel channel) { + public ManagedBuffer openStream(String streamId) { throw new UnsupportedOperationException(); } diff --git a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java index fb8ab7ba39c91..3e089b4cae273 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java @@ -126,7 +126,7 @@ private void processStreamRequest(final StreamRequest req) { } ManagedBuffer buf; try { - buf = streamManager.openStream(req.streamId, channel); + buf = streamManager.openStream(req.streamId); } catch (Exception e) { logger.error(String.format( "Error opening stream %s for request from %s", req.streamId, getRemoteAddress(channel)), e); diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java index b58cbaaadf086..37a8664a52661 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchIntegrationSuite.java @@ -32,7 +32,6 @@ import com.google.common.collect.Sets; import com.google.common.io.Closeables; -import io.netty.channel.Channel; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -93,7 +92,7 @@ public static void setUp() throws Exception { streamManager = new StreamManager() { @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex, Channel channel) { + public ManagedBuffer getChunk(long streamId, int chunkIndex) { assertEquals(STREAM_ID, streamId); if (chunkIndex == BUFFER_CHUNK_INDEX) { return new NioManagedBuffer(buf); diff --git a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java index f8ca37054e5d9..6c9239606bb8c 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/ChunkFetchRequestHandlerSuite.java @@ -64,7 +64,7 @@ public void handleChunkFetchRequest() throws Exception { managedBuffers.add(new TestManagedBuffer(20)); managedBuffers.add(new TestManagedBuffer(30)); managedBuffers.add(new TestManagedBuffer(40)); - long streamId = streamManager.registerStream("test-app", managedBuffers.iterator()); + long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel); TransportClient reverseClient = mock(TransportClient.class); ChunkFetchRequestHandler requestHandler = new ChunkFetchRequestHandler(reverseClient, rpcHandler.getStreamManager(), 2L); diff --git a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java index 55826b2840d7a..c0724e018263f 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/RequestTimeoutIntegrationSuite.java @@ -18,7 +18,6 @@ package org.apache.spark.network; 
import com.google.common.util.concurrent.Uninterruptibles; -import io.netty.channel.Channel; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.buffer.NioManagedBuffer; import org.apache.spark.network.client.ChunkReceivedCallback; @@ -66,7 +65,7 @@ public void setUp() throws Exception { defaultManager = new StreamManager() { @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex, Channel channel) { + public ManagedBuffer getChunk(long streamId, int chunkIndex) { throw new UnsupportedOperationException(); } }; @@ -185,7 +184,7 @@ public void furtherRequestsDelay() throws Exception { final byte[] response = new byte[16]; final StreamManager manager = new StreamManager() { @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex, Channel channel) { + public ManagedBuffer getChunk(long streamId, int chunkIndex) { Uninterruptibles.sleepUninterruptibly(FOREVER, TimeUnit.MILLISECONDS); return new NioManagedBuffer(ByteBuffer.wrap(response)); } diff --git a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java index 4fb1f313fec61..f3050cb79cdfd 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/StreamSuite.java @@ -31,7 +31,6 @@ import java.util.concurrent.TimeUnit; import com.google.common.io.Files; -import io.netty.channel.Channel; import org.junit.AfterClass; import org.junit.BeforeClass; import org.junit.Test; @@ -71,12 +70,12 @@ public static void setUp() throws Exception { final TransportConf conf = new TransportConf("shuffle", MapConfigProvider.EMPTY); final StreamManager streamManager = new StreamManager() { @Override - public ManagedBuffer getChunk(long streamId, int chunkIndex, Channel channel) { + public ManagedBuffer getChunk(long streamId, int chunkIndex) { throw new UnsupportedOperationException(); } @Override - public ManagedBuffer openStream(String streamId, Channel channel) { + public ManagedBuffer openStream(String streamId) { return testData.openStream(conf, streamId); } }; diff --git a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java index 316baebde4453..a87f6c11a2bfd 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/TransportRequestHandlerSuite.java @@ -58,7 +58,10 @@ public void handleStreamRequest() throws Exception { managedBuffers.add(new TestManagedBuffer(20)); managedBuffers.add(new TestManagedBuffer(30)); managedBuffers.add(new TestManagedBuffer(40)); - long streamId = streamManager.registerStream("test-app", managedBuffers.iterator()); + long streamId = streamManager.registerStream("test-app", managedBuffers.iterator(), channel); + + assert streamManager.numStreamStates() == 1; + TransportClient reverseClient = mock(TransportClient.class); TransportRequestHandler requestHandler = new TransportRequestHandler(channel, reverseClient, rpcHandler, 2L); @@ -93,5 +96,8 @@ public void handleStreamRequest() throws Exception { requestHandler.handle(request3); verify(channel, times(1)).close(); assert responseAndPromisePairs.size() == 3; + + streamManager.connectionTerminated(channel); + assert streamManager.numStreamStates() == 0; } } diff 
--git a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java index 5fee5a4953d78..59adf9704cbf6 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java @@ -259,7 +259,7 @@ public void testFileRegionEncryption() throws Exception { try { TransportConf conf = new TransportConf("shuffle", new MapConfigProvider(testConf)); StreamManager sm = mock(StreamManager.class); - when(sm.getChunk(anyLong(), anyInt(), any(Channel.class))).thenAnswer(invocation -> + when(sm.getChunk(anyLong(), anyInt())).thenAnswer(invocation -> new FileSegmentManagedBuffer(conf, file, 0, file.length())); RpcHandler rpcHandler = mock(RpcHandler.class); diff --git a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java index c647525d8f1bd..4248762c32389 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/server/OneForOneStreamManagerSuite.java @@ -37,14 +37,15 @@ public void managedBuffersAreFeedWhenConnectionIsClosed() throws Exception { TestManagedBuffer buffer2 = Mockito.spy(new TestManagedBuffer(20)); buffers.add(buffer1); buffers.add(buffer2); - long streamId = manager.registerStream("appId", buffers.iterator()); Channel dummyChannel = Mockito.mock(Channel.class, Mockito.RETURNS_SMART_NULLS); - manager.registerChannel(dummyChannel, streamId); + manager.registerStream("appId", buffers.iterator(), dummyChannel); + assert manager.numStreamStates() == 1; manager.connectionTerminated(dummyChannel); Mockito.verify(buffer1, Mockito.times(1)).release(); Mockito.verify(buffer2, Mockito.times(1)).release(); + assert manager.numStreamStates() == 0; } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 788a845c57755..b25e48a164e6b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -92,7 +92,7 @@ protected void handleMessage( OpenBlocks msg = (OpenBlocks) msgObj; checkAuth(client, msg.appId); long streamId = streamManager.registerStream(client.getClientId(), - new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds)); + new ManagedBufferIterator(msg.appId, msg.execId, msg.blockIds), client.getChannel()); if (logger.isTraceEnabled()) { logger.trace("Registered streamId {} with {} buffers for client {} from host {}", streamId, diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index 4cc9a16e1449f..537c277cd26b5 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -103,7 +103,8 @@ public void 
testOpenShuffleBlocks() {
     @SuppressWarnings("unchecked")
     ArgumentCaptor<Iterator<ManagedBuffer>> stream = (ArgumentCaptor<Iterator<ManagedBuffer>>)
         (ArgumentCaptor<?>) ArgumentCaptor.forClass(Iterator.class);
-    verify(streamManager, times(1)).registerStream(anyString(), stream.capture());
+    verify(streamManager, times(1)).registerStream(anyString(), stream.capture(),
+      any());
     Iterator<ManagedBuffer> buffers = stream.getValue();
     assertEquals(block0Marker, buffers.next());
     assertEquals(block1Marker, buffers.next());
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index 7076701421e2e..27f4f94ea55f8 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -59,7 +59,8 @@ class NettyBlockRpcServer(
         val blocksNum = openBlocks.blockIds.length
         val blocks = for (i <- (0 until blocksNum).view)
           yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i)))
-        val streamId = streamManager.registerStream(appId, blocks.iterator.asJava)
+        val streamId = streamManager.registerStream(appId, blocks.iterator.asJava,
+          client.getChannel)
         logTrace(s"Registered streamId $streamId with $blocksNum buffers")
         responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer)
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
index edc6f2736b3f4..780fadd5bda8e 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala
@@ -19,8 +19,6 @@ package org.apache.spark.rpc.netty
 import java.io.File
 import java.util.concurrent.ConcurrentHashMap
 
-import io.netty.channel.Channel
-
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
 import org.apache.spark.network.server.StreamManager
 import org.apache.spark.rpc.RpcEnvFileServer
@@ -45,11 +43,11 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv)
   private val jars = new ConcurrentHashMap[String, File]()
   private val dirs = new ConcurrentHashMap[String, File]()
 
-  override def getChunk(streamId: Long, chunkIndex: Int, channel: Channel): ManagedBuffer = {
+  override def getChunk(streamId: Long, chunkIndex: Int): ManagedBuffer = {
     throw new UnsupportedOperationException()
   }
 
-  override def openStream(streamId: String, channel: Channel): ManagedBuffer = {
+  override def openStream(streamId: String): ManagedBuffer = {
     val Array(ftype, fname) = streamId.stripPrefix("/").split("/", 2)
     val file = ftype match {
       case "files" => files.get(fname)
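Post-series usage note (not part of the patches): a minimal sketch of
how the reworked OneForOneStreamManager API fits together after
PATCH 4/4. The method names come from the diffs above; the class name
Example, the method serveAndClose, and the buffers/channel parameters
are illustrative stand-ins.

    import java.util.List;
    import io.netty.channel.Channel;
    import org.apache.spark.network.buffer.ManagedBuffer;
    import org.apache.spark.network.server.OneForOneStreamManager;

    class Example {
      // Hypothetical caller; in practice `channel` comes from the Netty pipeline.
      static void serveAndClose(List<ManagedBuffer> buffers, Channel channel) {
        OneForOneStreamManager manager = new OneForOneStreamManager();
        // The owning channel is now fixed at registration time...
        long streamId = manager.registerStream("app-id", buffers.iterator(), channel);
        // ...so getChunk() is back to its original two-argument form.
        ManagedBuffer chunk = manager.getChunk(streamId, 0);
        // Terminating the connection releases the remaining buffers and drops
        // the stream state, as OneForOneStreamManagerSuite verifies above.
        manager.connectionTerminated(channel);
        assert manager.numStreamStates() == 0;
      }
    }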