From 74107fa9be8f2f4008dcca20e2db91c53519e2fa Mon Sep 17 00:00:00 2001 From: Rajini Sivaram Date: Thu, 21 May 2020 11:34:47 +0100 Subject: [PATCH 1/8] KAFKA-10029; Don't update completedReceives when channels are closed to avoid ConcurrentModificationException --- .../apache/kafka/common/network/Selector.java | 1 - .../kafka/common/network/SelectorTest.java | 55 +++++++++++++++++++ .../scala/kafka/network/SocketServer.scala | 4 +- .../unit/kafka/network/SocketServerTest.scala | 25 +++++++-- 4 files changed, 76 insertions(+), 9 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java index cb91cad92575e..8fb20e02b8409 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java +++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java @@ -935,7 +935,6 @@ private void doClose(KafkaChannel channel, boolean notifyDisconnect) { } this.sensors.connectionClosed.record(); - this.completedReceives.remove(channel); this.explicitlyMutedChannels.remove(channel); if (notifyDisconnect) this.disconnected.put(channel.id(), channel.state()); diff --git a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java index 57b0153f4dd83..ef1076a36ca96 100644 --- a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java +++ b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java @@ -48,6 +48,7 @@ import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -904,6 +905,60 @@ public void testWriteCompletesSendWithNoBytesWritten() throws IOException { assertEquals(asList(send), selector.completedSends()); } + /** + * Ensure that no errors are thrown if channels are closed while processing multiple completed receives + */ + @Test + public void testChannelCloseWhileProcessingReceives() throws Exception { + int numChannels = 4; + Map channels = TestUtils.fieldValue(selector, Selector.class, "channels"); + Set selectionKeys = new HashSet<>(); + for (int i = 0; i < numChannels; i++) { + String id = String.valueOf(i); + KafkaChannel channel = mock(KafkaChannel.class); + channels.put(id, channel); + when(channel.id()).thenReturn(id); + when(channel.state()).thenReturn(ChannelState.READY); + when(channel.isConnected()).thenReturn(true); + when(channel.ready()).thenReturn(true); + when(channel.read()).thenReturn(1L); + + SelectionKey selectionKey = mock(SelectionKey.class); + when(channel.selectionKey()).thenReturn(selectionKey); + when(selectionKey.isValid()).thenReturn(true); + when(selectionKey.readyOps()).thenReturn(SelectionKey.OP_READ); + selectionKey.attach(channel); + selectionKeys.add(selectionKey); + + NetworkReceive receive = mock(NetworkReceive.class); + when(receive.source()).thenReturn(id); + when(receive.size()).thenReturn(10); + when(receive.bytesRead()).thenReturn(1); + when(receive.payload()).thenReturn(ByteBuffer.allocate(10)); + when(channel.maybeCompleteReceive()).thenReturn(receive); + } + + selector.pollSelectionKeys(selectionKeys, false, System.nanoTime()); + assertEquals(numChannels, selector.completedReceives().size()); + Set closed = new HashSet<>(); + Set notClosed = new HashSet<>(); + for (NetworkReceive receive : selector.completedReceives()) { + KafkaChannel channel = selector.channel(receive.source()); + assertNotNull(channel); + if (closed.size() < 2) { + selector.close(channel.id()); + closed.add(channel); + } else + notClosed.add(channel); + } + assertEquals(notClosed, new HashSet<>(selector.channels())); + closed.forEach(channel -> assertNull(selector.channel(channel.id()))); + + selector.poll(0); + assertEquals(0, selector.completedReceives().size()); + } + + private String blockingRequest(String node, String s) throws IOException { selector.send(createSend(node, s)); selector.poll(1000L); diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala index 16d9322b546a1..ca88c54fe9909 100644 --- a/core/src/main/scala/kafka/network/SocketServer.scala +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -825,7 +825,7 @@ private[kafka] class Processor(val id: Int, // be either associated with a specific socket channel or a bad request. These exceptions are caught and // processed by the individual methods above which close the failing channel and continue processing other // channels. So this catch block should only ever see ControlThrowables. - case e: Throwable => processException("Processor got uncaught exception.", e) + case e: Throwable => processException("Processor got uncaught exception.", e, isUncaught = true) } } } finally { @@ -835,7 +835,7 @@ private[kafka] class Processor(val id: Int, } } - private def processException(errorMessage: String, throwable: Throwable): Unit = { + private[network] def processException(errorMessage: String, throwable: Throwable, isUncaught: Boolean = false): Unit = { throwable match { case e: ControlThrowable => throw e case e => error(errorMessage, e) diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index 6743dc8bab6bd..e307f8b2f335d 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -1597,6 +1597,8 @@ class SocketServerTest { testableSelector.waitForOperations(SelectorOperation.Poll, 1) testableSelector.waitForOperations(SelectorOperation.CloseSelector, 1) + assertEquals(1, testableServer.uncaughtExceptions) + testableServer.uncaughtExceptions = 0 }) } @@ -1675,6 +1677,7 @@ class SocketServerTest { testWithServer(testableServer) } finally { shutdownServerAndMetrics(testableServer) + assertEquals(0, testableServer.uncaughtExceptions) } } @@ -1730,6 +1733,7 @@ class SocketServerTest { new Metrics, time, credentialProvider) { @volatile var selector: Option[TestableSelector] = None + @volatile var uncaughtExceptions = 0 override def newProcessor(id: Int, requestChannel: RequestChannel, connectionQuotas: ConnectionQuotas, listenerName: ListenerName, protocol: SecurityProtocol, memoryPool: MemoryPool): Processor = { @@ -1742,6 +1746,12 @@ class SocketServerTest { selector = Some(testableSelector) testableSelector } + + override private[network] def processException(errorMessage: String, throwable: Throwable, isUncaught: Boolean): Unit = { + if (isUncaught) + uncaughtExceptions += 1 + super.processException(errorMessage, throwable, isUncaught) + } } } @@ -1807,6 +1817,7 @@ class SocketServerTest { currentPollValues ++= newValues } else deferredValues ++= newValues + newValues.clear() } def reset(): Unit = { currentPollValues.clear() @@ -1875,6 +1886,14 @@ class SocketServerTest { cachedCompletedReceives.update(super.completedReceives.asScala.toBuffer) cachedCompletedSends.update(super.completedSends.asScala) cachedDisconnected.update(super.disconnected.asScala.toBuffer) + + val map: util.Map[KafkaChannel, NetworkReceive] = JTestUtils.fieldValue(this, classOf[Selector], "completedReceives") + cachedCompletedReceives.currentPollValues.foreach { receive => + val channel = Option(super.channel(receive.source)).orElse(Option(super.closingChannel(receive.source))) + channel.foreach(map.put(_, receive)) + } + cachedCompletedSends.currentPollValues.foreach(super.completedSends.add) + cachedDisconnected.currentPollValues.foreach { case (id, state) => super.disconnected.put(id, state) } } } @@ -1899,12 +1918,6 @@ class SocketServerTest { } } - override def disconnected: java.util.Map[String, ChannelState] = cachedDisconnected.currentPollValues.toMap.asJava - - override def completedSends: java.util.List[Send] = cachedCompletedSends.currentPollValues.asJava - - override def completedReceives: java.util.List[NetworkReceive] = cachedCompletedReceives.currentPollValues.asJava - override def close(id: String): Unit = { runOp(SelectorOperation.Close, Some(id)) { super.close(id) From 09589be517c2e9ecc912c1e2fd8d17eeda433365 Mon Sep 17 00:00:00 2001 From: Rajini Sivaram Date: Tue, 26 May 2020 10:57:09 +0100 Subject: [PATCH 2/8] Address review comments --- .../org/apache/kafka/common/network/Selector.java | 6 +++--- core/src/main/scala/kafka/network/SocketServer.scala | 4 ++-- .../scala/unit/kafka/network/SocketServerTest.scala | 12 ++++++------ 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java index 8fb20e02b8409..f47ec39783281 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java +++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java @@ -107,7 +107,7 @@ private enum CloseMode { private final Set explicitlyMutedChannels; private boolean outOfMemory; private final List completedSends; - private final LinkedHashMap completedReceives; + private final LinkedHashMap completedReceives; private final Set immediatelyConnectedKeys; private final Map closingChannels; private Set keysWithBufferedRead; @@ -1014,7 +1014,7 @@ private KafkaChannel channel(SelectionKey key) { * Check if given channel has a completed receive */ private boolean hasCompletedReceive(KafkaChannel channel) { - return completedReceives.containsKey(channel); + return completedReceives.containsKey(channel.id()); } /** @@ -1024,7 +1024,7 @@ private void addToCompletedReceives(KafkaChannel channel, NetworkReceive network if (hasCompletedReceive(channel)) throw new IllegalStateException("Attempting to add second completed receive to channel " + channel.id()); - this.completedReceives.put(channel, networkReceive); + this.completedReceives.put(channel.id(), networkReceive); sensors.recordCompletedReceive(channel.id(), networkReceive.size(), currentTimeMs); } diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala index ca88c54fe9909..d31c645fbf8bb 100644 --- a/core/src/main/scala/kafka/network/SocketServer.scala +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -825,7 +825,7 @@ private[kafka] class Processor(val id: Int, // be either associated with a specific socket channel or a bad request. These exceptions are caught and // processed by the individual methods above which close the failing channel and continue processing other // channels. So this catch block should only ever see ControlThrowables. - case e: Throwable => processException("Processor got uncaught exception.", e, isUncaught = true) + case e: Throwable => processException("Processor got uncaught exception.", e) } } } finally { @@ -835,7 +835,7 @@ private[kafka] class Processor(val id: Int, } } - private[network] def processException(errorMessage: String, throwable: Throwable, isUncaught: Boolean = false): Unit = { + private[network] def processException(errorMessage: String, throwable: Throwable): Unit = { throwable match { case e: ControlThrowable => throw e case e => error(errorMessage, e) diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index e307f8b2f335d..240940bb6b8a0 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -1747,10 +1747,10 @@ class SocketServerTest { testableSelector } - override private[network] def processException(errorMessage: String, throwable: Throwable, isUncaught: Boolean): Unit = { - if (isUncaught) + override private[network] def processException(errorMessage: String, throwable: Throwable): Unit = { + if (errorMessage.contains("uncaught exception")) uncaughtExceptions += 1 - super.processException(errorMessage, throwable, isUncaught) + super.processException(errorMessage, throwable) } } } @@ -1887,10 +1887,10 @@ class SocketServerTest { cachedCompletedSends.update(super.completedSends.asScala) cachedDisconnected.update(super.disconnected.asScala.toBuffer) - val map: util.Map[KafkaChannel, NetworkReceive] = JTestUtils.fieldValue(this, classOf[Selector], "completedReceives") + val map: util.Map[String, NetworkReceive] = JTestUtils.fieldValue(this, classOf[Selector], "completedReceives") cachedCompletedReceives.currentPollValues.foreach { receive => - val channel = Option(super.channel(receive.source)).orElse(Option(super.closingChannel(receive.source))) - channel.foreach(map.put(_, receive)) + val channelOpt = Option(super.channel(receive.source)).orElse(Option(super.closingChannel(receive.source))) + channelOpt.foreach { channel => map.put(channel.id, receive) } } cachedCompletedSends.currentPollValues.foreach(super.completedSends.add) cachedDisconnected.currentPollValues.foreach { case (id, state) => super.disconnected.put(id, state) } From 64f84b3e7b72c2cb2c21987cecabb04d6d22ddee Mon Sep 17 00:00:00 2001 From: Rajini Sivaram Date: Wed, 27 May 2020 16:56:49 +0100 Subject: [PATCH 3/8] Address review comments --- .../apache/kafka/common/network/Selector.java | 28 ++++++++++++++++- .../kafka/common/network/SelectorTest.java | 30 +++++++++++++++++++ .../scala/kafka/network/SocketServer.scala | 2 ++ .../unit/kafka/network/SocketServerTest.scala | 3 ++ 4 files changed, 62 insertions(+), 1 deletion(-) diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java index f47ec39783281..be040d9e5c17a 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java +++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java @@ -804,7 +804,33 @@ private void maybeCloseOldestConnection(long currentTimeNanos) { } /** - * Clear the results from the prior poll + * Clears completed receives. This is used by SocketServer to remove references to + * receive buffers after processing completed receives, without waiting for the next + * poll() after all results have been processed. + */ + public void clearCompletedReceives() { + this.completedReceives.clear(); + } + + /** + * Clears completed sends. This is used by SocketServer to remove references to + * send buffers after processing completed sends, without waiting for the next + * poll() after all results have been processed. + */ + public void clearCompletedSends() { + this.completedSends.clear(); + } + + /** + * Clears all the results from the previous poll. This is invoked by Selector at the start of + * a poll() when all the results from the previous poll are expected to have been handled. + *

+ * SocketServer uses {@link #clearCompletedSends()} and {@link #clearCompletedSends()} to + * clear `completedSends` and `completedReceives` as soon as they are processed to avoid + * holding onto large request/response buffers from multiple connections longer than necessary. + * Clients rely on Selector invoking {@link #clear()} at the start of each poll() since memory usage + * is less critical and clearing once-per-poll provides the flexibility to process these results in + * any order before the next poll. */ private void clear() { this.completedSends.clear(); diff --git a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java index ef1076a36ca96..ac773eed3dcb4 100644 --- a/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java +++ b/clients/src/test/java/org/apache/kafka/common/network/SelectorTest.java @@ -342,6 +342,36 @@ public void testEmptyRequest() throws Exception { assertEquals("", blockingRequest(node, "")); } + @Test + public void testClearCompletedSendsAndReceives() throws Exception { + int bufferSize = 1024; + String node = "0"; + InetSocketAddress addr = new InetSocketAddress("localhost", server.port); + connect(node, addr); + String request = TestUtils.randomString(bufferSize); + selector.send(createSend(node, request)); + boolean sent = false; + boolean received = false; + while (!sent || !received) { + selector.poll(1000L); + assertEquals("No disconnects should have occurred.", 0, selector.disconnected().size()); + if (!selector.completedSends().isEmpty()) { + assertEquals(1, selector.completedSends().size()); + selector.clearCompletedSends(); + assertEquals(0, selector.completedSends().size()); + sent = true; + } + + if (!selector.completedReceives().isEmpty()) { + assertEquals(1, selector.completedReceives().size()); + assertEquals(request, asString(selector.completedReceives().iterator().next())); + selector.clearCompletedReceives(); + assertEquals(0, selector.completedReceives().size()); + received = true; + } + } + } + @Test(expected = IllegalStateException.class) public void testExistingConnectionId() throws IOException { blockingConnect("0"); diff --git a/core/src/main/scala/kafka/network/SocketServer.scala b/core/src/main/scala/kafka/network/SocketServer.scala index d31c645fbf8bb..bff185d3e0385 100644 --- a/core/src/main/scala/kafka/network/SocketServer.scala +++ b/core/src/main/scala/kafka/network/SocketServer.scala @@ -968,6 +968,7 @@ private[kafka] class Processor(val id: Int, processChannelException(receive.source, s"Exception while processing request from ${receive.source}", e) } } + selector.clearCompletedReceives() } private def processCompletedSends(): Unit = { @@ -991,6 +992,7 @@ private[kafka] class Processor(val id: Int, s"Exception while processing completed send to ${send.destination}", e) } } + selector.clearCompletedSends() } private def updateRequestMetrics(response: RequestChannel.Response): Unit = { diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index 240940bb6b8a0..8239dd4e15c7d 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -1872,6 +1872,9 @@ class SocketServerTest { override def poll(timeout: Long): Unit = { try { + assertEquals(0, super.completedReceives().size) + assertEquals(0, super.completedSends().size) + pollCallback.apply() while (!pendingClosingChannels.isEmpty) { makeClosing(pendingClosingChannels.poll()) From 8cbe30e0cb4d17d8f4ccb2e1721e1f6527e7be87 Mon Sep 17 00:00:00 2001 From: Rajini Sivaram Date: Wed, 27 May 2020 17:04:19 +0100 Subject: [PATCH 4/8] Address review comment --- .../java/org/apache/kafka/common/network/KafkaChannel.java | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java index 4e4edd47adb3c..0ed9ee0f96150 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java +++ b/clients/src/main/java/org/apache/kafka/common/network/KafkaChannel.java @@ -28,7 +28,6 @@ import java.net.SocketAddress; import java.nio.channels.SelectionKey; import java.nio.channels.SocketChannel; -import java.util.Objects; import java.util.Optional; import java.util.function.Supplier; @@ -471,12 +470,12 @@ public boolean equals(Object o) { return false; } KafkaChannel that = (KafkaChannel) o; - return Objects.equals(id, that.id); + return id.equals(that.id); } @Override public int hashCode() { - return Objects.hash(id); + return id.hashCode(); } @Override From a44db9437a46560e1f1f56cec036d69a2357bbe1 Mon Sep 17 00:00:00 2001 From: Rajini Sivaram Date: Thu, 28 May 2020 10:16:09 +0100 Subject: [PATCH 5/8] Address review comments --- .../apache/kafka/common/network/Selector.java | 4 +-- .../unit/kafka/network/SocketServerTest.scala | 28 ++++++++++--------- 2 files changed, 17 insertions(+), 15 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/common/network/Selector.java b/clients/src/main/java/org/apache/kafka/common/network/Selector.java index be040d9e5c17a..06f7048793fc5 100644 --- a/clients/src/main/java/org/apache/kafka/common/network/Selector.java +++ b/clients/src/main/java/org/apache/kafka/common/network/Selector.java @@ -806,7 +806,7 @@ private void maybeCloseOldestConnection(long currentTimeNanos) { /** * Clears completed receives. This is used by SocketServer to remove references to * receive buffers after processing completed receives, without waiting for the next - * poll() after all results have been processed. + * poll(). */ public void clearCompletedReceives() { this.completedReceives.clear(); @@ -815,7 +815,7 @@ public void clearCompletedReceives() { /** * Clears completed sends. This is used by SocketServer to remove references to * send buffers after processing completed sends, without waiting for the next - * poll() after all results have been processed. + * poll(). */ public void clearCompletedSends() { this.completedSends.clear(); diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index 8239dd4e15c7d..7cd80eb8999d7 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -1807,9 +1807,9 @@ class SocketServerTest { class PollData[T] { var minPerPoll = 1 val deferredValues = mutable.Buffer[T]() - val currentPollValues = mutable.Buffer[T]() - def update(newValues: mutable.Buffer[T]): Unit = { - if (currentPollValues.nonEmpty || deferredValues.size + newValues.size >= minPerPoll) { + def update(newValues: mutable.Buffer[T], addToCurrentResult: T => Unit): Unit = { + val currentPollValues = mutable.Buffer[T]() + if (deferredValues.size + newValues.size >= minPerPoll) { if (deferredValues.nonEmpty) { currentPollValues ++= deferredValues deferredValues.clear() @@ -1817,10 +1817,10 @@ class SocketServerTest { currentPollValues ++= newValues } else deferredValues ++= newValues + + // Update results for the current poll newValues.clear() - } - def reset(): Unit = { - currentPollValues.clear() + currentPollValues.foreach(addToCurrentResult) } } val cachedCompletedReceives = new PollData[NetworkReceive]() @@ -1879,24 +1879,26 @@ class SocketServerTest { while (!pendingClosingChannels.isEmpty) { makeClosing(pendingClosingChannels.poll()) } - allCachedPollData.foreach(_.reset) runOp(SelectorOperation.Poll, None) { super.poll(pollTimeoutOverride.getOrElse(timeout)) } } finally { super.channels.forEach(allChannels += _.id) allDisconnectedChannels ++= super.disconnected.asScala.keys - cachedCompletedReceives.update(super.completedReceives.asScala.toBuffer) - cachedCompletedSends.update(super.completedSends.asScala) - cachedDisconnected.update(super.disconnected.asScala.toBuffer) val map: util.Map[String, NetworkReceive] = JTestUtils.fieldValue(this, classOf[Selector], "completedReceives") - cachedCompletedReceives.currentPollValues.foreach { receive => + def addToCompletedReceives(receive: NetworkReceive): Unit = { val channelOpt = Option(super.channel(receive.source)).orElse(Option(super.closingChannel(receive.source))) channelOpt.foreach { channel => map.put(channel.id, receive) } } - cachedCompletedSends.currentPollValues.foreach(super.completedSends.add) - cachedDisconnected.currentPollValues.foreach { case (id, state) => super.disconnected.put(id, state) } + + // For each result type (completedReceives/completedSends/disconnected), defer the result to a subsequent poll() + // if `minPerPoll` results are not yet available. When sufficient results are available, all available results + // including previously deferred results are returned. This allows tests to process `minPerPoll` elements as the + // results of a single poll iteration. + cachedCompletedReceives.update(super.completedReceives.asScala.toBuffer, addToCompletedReceives) + cachedCompletedSends.update(super.completedSends.asScala, s => super.completedSends.add(s)) + cachedDisconnected.update(super.disconnected.asScala.toBuffer, d => super.disconnected.put(d._1, d._2)) } } From ceace47a433b930570018f256d69a9348653c5ed Mon Sep 17 00:00:00 2001 From: Rajini Sivaram Date: Thu, 28 May 2020 14:16:37 +0100 Subject: [PATCH 6/8] Address review comment --- core/src/test/scala/unit/kafka/network/SocketServerTest.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index 7cd80eb8999d7..cc352883a611b 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -1886,10 +1886,10 @@ class SocketServerTest { super.channels.forEach(allChannels += _.id) allDisconnectedChannels ++= super.disconnected.asScala.keys - val map: util.Map[String, NetworkReceive] = JTestUtils.fieldValue(this, classOf[Selector], "completedReceives") + val completedReceivesMap: util.Map[String, NetworkReceive] = JTestUtils.fieldValue(this, classOf[Selector], "completedReceives") def addToCompletedReceives(receive: NetworkReceive): Unit = { val channelOpt = Option(super.channel(receive.source)).orElse(Option(super.closingChannel(receive.source))) - channelOpt.foreach { channel => map.put(channel.id, receive) } + channelOpt.foreach { channel => completedReceivesMap.put(channel.id, receive) } } // For each result type (completedReceives/completedSends/disconnected), defer the result to a subsequent poll() From d572d8e2cb7cf0c3d7fb40e923b6e9579845555a Mon Sep 17 00:00:00 2001 From: Rajini Sivaram Date: Thu, 28 May 2020 15:04:41 +0100 Subject: [PATCH 7/8] Address review comments --- .../unit/kafka/network/SocketServerTest.scala | 68 ++++++++++++++----- 1 file changed, 51 insertions(+), 17 deletions(-) diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index cc352883a611b..0afc44acabe35 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -1804,10 +1804,16 @@ class SocketServerTest { // Enable data from `Selector.poll()` to be deferred to a subsequent poll() until // the number of elements of that type reaches `minPerPoll`. This enables tests to verify // that failed processing doesn't impact subsequent processing within the same iteration. - class PollData[T] { + abstract class PollData[T] { var minPerPoll = 1 val deferredValues = mutable.Buffer[T]() - def update(newValues: mutable.Buffer[T], addToCurrentResult: T => Unit): Unit = { + + /** + * Process new results and return the results for the current poll if at least + * `minPerPoll` results are available including any deferred results. Otherwise + * add the provided values to the deferred set and return an empty buffer. + */ + protected def update(newValues: mutable.Buffer[T]): mutable.Buffer[T] = { val currentPollValues = mutable.Buffer[T]() if (deferredValues.size + newValues.size >= minPerPoll) { if (deferredValues.nonEmpty) { @@ -1818,14 +1824,48 @@ class SocketServerTest { } else deferredValues ++= newValues - // Update results for the current poll - newValues.clear() - currentPollValues.foreach(addToCurrentResult) + currentPollValues + } + + /** + * Process results from the appropriate buffer in Selector and update the buffer to either + * defer and return nothing or return all results including previously deferred values. + */ + def updateResults(): Unit + } + + class CompletedReceivesPollData(selector: TestableSelector) extends PollData[NetworkReceive] { + val completedReceivesMap: util.Map[String, NetworkReceive] = JTestUtils.fieldValue(selector, classOf[Selector], "completedReceives") + + override def updateResults(): Unit = { + val currentReceives = update(selector.completedReceives.asScala.toBuffer) + completedReceivesMap.clear() + currentReceives.foreach { receive => + val channelOpt = Option(selector.channel(receive.source)).orElse(Option(selector.closingChannel(receive.source))) + channelOpt.foreach { channel => completedReceivesMap.put(channel.id, receive) } + } + } + } + + class CompletedSendsPollData(selector: TestableSelector) extends PollData[Send] { + override def updateResults(): Unit = { + val currentSends = update(selector.completedSends.asScala) + selector.completedSends.clear() + currentSends.foreach { selector.completedSends.add } + } + } + + class DisconnectedPollData(selector: TestableSelector) extends PollData[(String, ChannelState)] { + override def updateResults(): Unit = { + val currentDisconnected = update(selector.disconnected.asScala.toBuffer) + selector.disconnected.clear() + currentDisconnected.foreach { case (channelId, state) => selector.disconnected.put(channelId, state) } } } - val cachedCompletedReceives = new PollData[NetworkReceive]() - val cachedCompletedSends = new PollData[Send]() - val cachedDisconnected = new PollData[(String, ChannelState)]() + + val cachedCompletedReceives = new CompletedReceivesPollData(this) + val cachedCompletedSends = new CompletedSendsPollData(this) + val cachedDisconnected = new DisconnectedPollData(this) val allCachedPollData = Seq(cachedCompletedReceives, cachedCompletedSends, cachedDisconnected) val pendingClosingChannels = new ConcurrentLinkedQueue[KafkaChannel]() @volatile var minWakeupCount = 0 @@ -1886,19 +1926,13 @@ class SocketServerTest { super.channels.forEach(allChannels += _.id) allDisconnectedChannels ++= super.disconnected.asScala.keys - val completedReceivesMap: util.Map[String, NetworkReceive] = JTestUtils.fieldValue(this, classOf[Selector], "completedReceives") - def addToCompletedReceives(receive: NetworkReceive): Unit = { - val channelOpt = Option(super.channel(receive.source)).orElse(Option(super.closingChannel(receive.source))) - channelOpt.foreach { channel => completedReceivesMap.put(channel.id, receive) } - } - // For each result type (completedReceives/completedSends/disconnected), defer the result to a subsequent poll() // if `minPerPoll` results are not yet available. When sufficient results are available, all available results // including previously deferred results are returned. This allows tests to process `minPerPoll` elements as the // results of a single poll iteration. - cachedCompletedReceives.update(super.completedReceives.asScala.toBuffer, addToCompletedReceives) - cachedCompletedSends.update(super.completedSends.asScala, s => super.completedSends.add(s)) - cachedDisconnected.update(super.disconnected.asScala.toBuffer, d => super.disconnected.put(d._1, d._2)) + cachedCompletedReceives.updateResults() + cachedCompletedSends.updateResults() + cachedDisconnected.updateResults() } } From e605e29f4452aa6bfe1eb077e9af93ac3e57d7a8 Mon Sep 17 00:00:00 2001 From: Rajini Sivaram Date: Fri, 29 May 2020 09:08:32 +0100 Subject: [PATCH 8/8] Address review comment --- .../test/scala/unit/kafka/network/SocketServerTest.scala | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala index 0afc44acabe35..8f1091ed13303 100644 --- a/core/src/test/scala/unit/kafka/network/SocketServerTest.scala +++ b/core/src/test/scala/unit/kafka/network/SocketServerTest.scala @@ -1811,7 +1811,8 @@ class SocketServerTest { /** * Process new results and return the results for the current poll if at least * `minPerPoll` results are available including any deferred results. Otherwise - * add the provided values to the deferred set and return an empty buffer. + * add the provided values to the deferred set and return an empty buffer. This allows + * tests to process `minPerPoll` elements as the results of a single poll iteration. */ protected def update(newValues: mutable.Buffer[T]): mutable.Buffer[T] = { val currentPollValues = mutable.Buffer[T]() @@ -1926,10 +1927,6 @@ class SocketServerTest { super.channels.forEach(allChannels += _.id) allDisconnectedChannels ++= super.disconnected.asScala.keys - // For each result type (completedReceives/completedSends/disconnected), defer the result to a subsequent poll() - // if `minPerPoll` results are not yet available. When sufficient results are available, all available results - // including previously deferred results are returned. This allows tests to process `minPerPoll` elements as the - // results of a single poll iteration. cachedCompletedReceives.updateResults() cachedCompletedSends.updateResults() cachedDisconnected.updateResults()