Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import java.net.SocketAddress;
import java.nio.channels.SelectionKey;
import java.nio.channels.SocketChannel;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;

Expand Down Expand Up @@ -471,12 +470,12 @@ public boolean equals(Object o) {
return false;
}
KafkaChannel that = (KafkaChannel) o;
return Objects.equals(id, that.id);
return id.equals(that.id);
}

@Override
public int hashCode() {
return Objects.hash(id);
return id.hashCode();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ private enum CloseMode {
private final Set<KafkaChannel> explicitlyMutedChannels;
private boolean outOfMemory;
private final List<Send> completedSends;
private final LinkedHashMap<KafkaChannel, NetworkReceive> completedReceives;
private final LinkedHashMap<String, NetworkReceive> completedReceives;
private final Set<SelectionKey> immediatelyConnectedKeys;
private final Map<String, KafkaChannel> closingChannels;
private Set<SelectionKey> keysWithBufferedRead;
Expand Down Expand Up @@ -804,7 +804,33 @@ private void maybeCloseOldestConnection(long currentTimeNanos) {
}

/**
* Clear the results from the prior poll
* Clears completed receives. This is used by SocketServer to remove references to
* receive buffers after processing completed receives, without waiting for the next
* poll().
*/
public void clearCompletedReceives() {
this.completedReceives.clear();
}

/**
* Clears completed sends. This is used by SocketServer to remove references to
* send buffers after processing completed sends, without waiting for the next
* poll().
*/
public void clearCompletedSends() {
this.completedSends.clear();
}

/**
* Clears all the results from the previous poll. This is invoked by Selector at the start of
* a poll() when all the results from the previous poll are expected to have been handled.
* <p>
* SocketServer uses {@link #clearCompletedSends()} and {@link #clearCompletedSends()} to
* clear `completedSends` and `completedReceives` as soon as they are processed to avoid
* holding onto large request/response buffers from multiple connections longer than necessary.
* Clients rely on Selector invoking {@link #clear()} at the start of each poll() since memory usage
* is less critical and clearing once-per-poll provides the flexibility to process these results in
* any order before the next poll.
*/
private void clear() {
this.completedSends.clear();
Expand Down Expand Up @@ -935,7 +961,6 @@ private void doClose(KafkaChannel channel, boolean notifyDisconnect) {
}

this.sensors.connectionClosed.record();
this.completedReceives.remove(channel);
this.explicitlyMutedChannels.remove(channel);
if (notifyDisconnect)
this.disconnected.put(channel.id(), channel.state());
Expand Down Expand Up @@ -1015,7 +1040,7 @@ private KafkaChannel channel(SelectionKey key) {
* Check if given channel has a completed receive
*/
private boolean hasCompletedReceive(KafkaChannel channel) {
return completedReceives.containsKey(channel);
return completedReceives.containsKey(channel.id());
}

/**
Expand All @@ -1025,7 +1050,7 @@ private void addToCompletedReceives(KafkaChannel channel, NetworkReceive network
if (hasCompletedReceive(channel))
throw new IllegalStateException("Attempting to add second completed receive to channel " + channel.id());

this.completedReceives.put(channel, networkReceive);
this.completedReceives.put(channel.id(), networkReceive);
sensors.recordCompletedReceive(channel.id(), networkReceive.size(), currentTimeMs);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -341,6 +342,36 @@ public void testEmptyRequest() throws Exception {
assertEquals("", blockingRequest(node, ""));
}

@Test
public void testClearCompletedSendsAndReceives() throws Exception {
int bufferSize = 1024;
String node = "0";
InetSocketAddress addr = new InetSocketAddress("localhost", server.port);
connect(node, addr);
String request = TestUtils.randomString(bufferSize);
selector.send(createSend(node, request));
boolean sent = false;
boolean received = false;
while (!sent || !received) {
selector.poll(1000L);
assertEquals("No disconnects should have occurred.", 0, selector.disconnected().size());
if (!selector.completedSends().isEmpty()) {
assertEquals(1, selector.completedSends().size());
selector.clearCompletedSends();
assertEquals(0, selector.completedSends().size());
sent = true;
}

if (!selector.completedReceives().isEmpty()) {
assertEquals(1, selector.completedReceives().size());
assertEquals(request, asString(selector.completedReceives().iterator().next()));
selector.clearCompletedReceives();
assertEquals(0, selector.completedReceives().size());
received = true;
}
}
}

@Test(expected = IllegalStateException.class)
public void testExistingConnectionId() throws IOException {
blockingConnect("0");
Expand Down Expand Up @@ -904,6 +935,60 @@ public void testWriteCompletesSendWithNoBytesWritten() throws IOException {
assertEquals(asList(send), selector.completedSends());
}

/**
* Ensure that no errors are thrown if channels are closed while processing multiple completed receives
*/
@Test
public void testChannelCloseWhileProcessingReceives() throws Exception {
int numChannels = 4;
Map<String, KafkaChannel> channels = TestUtils.fieldValue(selector, Selector.class, "channels");
Set<SelectionKey> selectionKeys = new HashSet<>();
for (int i = 0; i < numChannels; i++) {
String id = String.valueOf(i);
KafkaChannel channel = mock(KafkaChannel.class);
channels.put(id, channel);
when(channel.id()).thenReturn(id);
when(channel.state()).thenReturn(ChannelState.READY);
when(channel.isConnected()).thenReturn(true);
when(channel.ready()).thenReturn(true);
when(channel.read()).thenReturn(1L);

SelectionKey selectionKey = mock(SelectionKey.class);
when(channel.selectionKey()).thenReturn(selectionKey);
when(selectionKey.isValid()).thenReturn(true);
when(selectionKey.readyOps()).thenReturn(SelectionKey.OP_READ);
selectionKey.attach(channel);
selectionKeys.add(selectionKey);

NetworkReceive receive = mock(NetworkReceive.class);
when(receive.source()).thenReturn(id);
when(receive.size()).thenReturn(10);
when(receive.bytesRead()).thenReturn(1);
when(receive.payload()).thenReturn(ByteBuffer.allocate(10));
when(channel.maybeCompleteReceive()).thenReturn(receive);
}

selector.pollSelectionKeys(selectionKeys, false, System.nanoTime());
assertEquals(numChannels, selector.completedReceives().size());
Set<KafkaChannel> closed = new HashSet<>();
Set<KafkaChannel> notClosed = new HashSet<>();
for (NetworkReceive receive : selector.completedReceives()) {
KafkaChannel channel = selector.channel(receive.source());
assertNotNull(channel);
if (closed.size() < 2) {
selector.close(channel.id());
closed.add(channel);
} else
notClosed.add(channel);
}
assertEquals(notClosed, new HashSet<>(selector.channels()));
closed.forEach(channel -> assertNull(selector.channel(channel.id())));

selector.poll(0);
assertEquals(0, selector.completedReceives().size());
}


private String blockingRequest(String node, String s) throws IOException {
selector.send(createSend(node, s));
selector.poll(1000L);
Expand Down
4 changes: 3 additions & 1 deletion core/src/main/scala/kafka/network/SocketServer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ private[kafka] class Processor(val id: Int,
}
}

private def processException(errorMessage: String, throwable: Throwable): Unit = {
private[network] def processException(errorMessage: String, throwable: Throwable): Unit = {
throwable match {
case e: ControlThrowable => throw e
case e => error(errorMessage, e)
Expand Down Expand Up @@ -968,6 +968,7 @@ private[kafka] class Processor(val id: Int,
processChannelException(receive.source, s"Exception while processing request from ${receive.source}", e)
}
}
selector.clearCompletedReceives()
}

private def processCompletedSends(): Unit = {
Expand All @@ -991,6 +992,7 @@ private[kafka] class Processor(val id: Int,
s"Exception while processing completed send to ${send.destination}", e)
}
}
selector.clearCompletedSends()
}

private def updateRequestMetrics(response: RequestChannel.Response): Unit = {
Expand Down
87 changes: 68 additions & 19 deletions core/src/test/scala/unit/kafka/network/SocketServerTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1597,6 +1597,8 @@ class SocketServerTest {
testableSelector.waitForOperations(SelectorOperation.Poll, 1)

testableSelector.waitForOperations(SelectorOperation.CloseSelector, 1)
assertEquals(1, testableServer.uncaughtExceptions)
testableServer.uncaughtExceptions = 0
})
}

Expand Down Expand Up @@ -1675,6 +1677,7 @@ class SocketServerTest {
testWithServer(testableServer)
} finally {
shutdownServerAndMetrics(testableServer)
assertEquals(0, testableServer.uncaughtExceptions)
}
}

Expand Down Expand Up @@ -1730,6 +1733,7 @@ class SocketServerTest {
new Metrics, time, credentialProvider) {

@volatile var selector: Option[TestableSelector] = None
@volatile var uncaughtExceptions = 0

override def newProcessor(id: Int, requestChannel: RequestChannel, connectionQuotas: ConnectionQuotas, listenerName: ListenerName,
protocol: SecurityProtocol, memoryPool: MemoryPool): Processor = {
Expand All @@ -1742,6 +1746,12 @@ class SocketServerTest {
selector = Some(testableSelector)
testableSelector
}

override private[network] def processException(errorMessage: String, throwable: Throwable): Unit = {
if (errorMessage.contains("uncaught exception"))
uncaughtExceptions += 1
super.processException(errorMessage, throwable)
}
}
}

Expand Down Expand Up @@ -1794,27 +1804,69 @@ class SocketServerTest {
// Enable data from `Selector.poll()` to be deferred to a subsequent poll() until
// the number of elements of that type reaches `minPerPoll`. This enables tests to verify
// that failed processing doesn't impact subsequent processing within the same iteration.
class PollData[T] {
abstract class PollData[T] {
var minPerPoll = 1
val deferredValues = mutable.Buffer[T]()
val currentPollValues = mutable.Buffer[T]()
def update(newValues: mutable.Buffer[T]): Unit = {
if (currentPollValues.nonEmpty || deferredValues.size + newValues.size >= minPerPoll) {

/**
* Process new results and return the results for the current poll if at least
* `minPerPoll` results are available including any deferred results. Otherwise
* add the provided values to the deferred set and return an empty buffer. This allows
* tests to process `minPerPoll` elements as the results of a single poll iteration.
*/
protected def update(newValues: mutable.Buffer[T]): mutable.Buffer[T] = {
val currentPollValues = mutable.Buffer[T]()
if (deferredValues.size + newValues.size >= minPerPoll) {
if (deferredValues.nonEmpty) {
currentPollValues ++= deferredValues
deferredValues.clear()
}
currentPollValues ++= newValues
} else
deferredValues ++= newValues

currentPollValues
}
def reset(): Unit = {
currentPollValues.clear()

/**
* Process results from the appropriate buffer in Selector and update the buffer to either
* defer and return nothing or return all results including previously deferred values.
*/
def updateResults(): Unit
}

class CompletedReceivesPollData(selector: TestableSelector) extends PollData[NetworkReceive] {
val completedReceivesMap: util.Map[String, NetworkReceive] = JTestUtils.fieldValue(selector, classOf[Selector], "completedReceives")

override def updateResults(): Unit = {
val currentReceives = update(selector.completedReceives.asScala.toBuffer)
completedReceivesMap.clear()
currentReceives.foreach { receive =>
val channelOpt = Option(selector.channel(receive.source)).orElse(Option(selector.closingChannel(receive.source)))
channelOpt.foreach { channel => completedReceivesMap.put(channel.id, receive) }
}
}
}
val cachedCompletedReceives = new PollData[NetworkReceive]()
val cachedCompletedSends = new PollData[Send]()
val cachedDisconnected = new PollData[(String, ChannelState)]()

class CompletedSendsPollData(selector: TestableSelector) extends PollData[Send] {
override def updateResults(): Unit = {
val currentSends = update(selector.completedSends.asScala)
selector.completedSends.clear()
currentSends.foreach { selector.completedSends.add }
}
}

class DisconnectedPollData(selector: TestableSelector) extends PollData[(String, ChannelState)] {
override def updateResults(): Unit = {
val currentDisconnected = update(selector.disconnected.asScala.toBuffer)
selector.disconnected.clear()
currentDisconnected.foreach { case (channelId, state) => selector.disconnected.put(channelId, state) }
}
}

val cachedCompletedReceives = new CompletedReceivesPollData(this)
val cachedCompletedSends = new CompletedSendsPollData(this)
val cachedDisconnected = new DisconnectedPollData(this)
val allCachedPollData = Seq(cachedCompletedReceives, cachedCompletedSends, cachedDisconnected)
val pendingClosingChannels = new ConcurrentLinkedQueue[KafkaChannel]()
@volatile var minWakeupCount = 0
Expand Down Expand Up @@ -1861,20 +1913,23 @@ class SocketServerTest {

override def poll(timeout: Long): Unit = {
try {
assertEquals(0, super.completedReceives().size)
assertEquals(0, super.completedSends().size)

pollCallback.apply()
while (!pendingClosingChannels.isEmpty) {
makeClosing(pendingClosingChannels.poll())
}
allCachedPollData.foreach(_.reset)
runOp(SelectorOperation.Poll, None) {
super.poll(pollTimeoutOverride.getOrElse(timeout))
}
} finally {
super.channels.forEach(allChannels += _.id)
allDisconnectedChannels ++= super.disconnected.asScala.keys
cachedCompletedReceives.update(super.completedReceives.asScala.toBuffer)
cachedCompletedSends.update(super.completedSends.asScala)
cachedDisconnected.update(super.disconnected.asScala.toBuffer)

cachedCompletedReceives.updateResults()
cachedCompletedSends.updateResults()
cachedDisconnected.updateResults()
}
}

Expand All @@ -1899,12 +1954,6 @@ class SocketServerTest {
}
}

override def disconnected: java.util.Map[String, ChannelState] = cachedDisconnected.currentPollValues.toMap.asJava

override def completedSends: java.util.List[Send] = cachedCompletedSends.currentPollValues.asJava

override def completedReceives: java.util.List[NetworkReceive] = cachedCompletedReceives.currentPollValues.asJava

override def close(id: String): Unit = {
runOp(SelectorOperation.Close, Some(id)) {
super.close(id)
Expand Down