From 9ac6bb0650d4570ac052215297ab90a08bf8e7ea Mon Sep 17 00:00:00 2001 From: Jason Gustafson Date: Thu, 10 Dec 2020 18:36:31 -0800 Subject: [PATCH 01/10] KAFKA-10842; Use `InterBrokerSendThread` for raft's outbound network channel --- checkstyle/suppressions.xml | 2 +- .../common/requests/AbstractRequest.java | 3 - .../requests/AbstractRequestResponse.java | 7 +- .../common/requests/AbstractResponse.java | 3 - .../clients/admin/KafkaAdminClientTest.java | 12 +- .../kafka/common/InterBrokerSendThread.scala | 40 ++-- .../TransactionMarkerChannelManager.scala | 20 +- .../kafka/raft/KafkaNetworkChannel.scala | 221 +++++------------- ...BrokerToControllerChannelManagerImpl.scala | 7 +- .../kafka/tools/TestRaftRequestHandler.scala | 60 ++++- .../scala/kafka/tools/TestRaftServer.scala | 8 +- .../common/InterBrokerSendThreadTest.scala | 39 ++-- ...ransactionCoordinatorConcurrencyTest.scala | 2 +- .../TransactionMarkerChannelManagerTest.scala | 18 +- .../kafka/raft/KafkaNetworkChannelTest.scala | 57 ++--- .../apache/kafka/raft/KafkaRaftClient.java | 99 ++++++-- .../org/apache/kafka/raft/LeaderState.java | 11 +- .../org/apache/kafka/raft/NetworkChannel.java | 19 +- .../apache/kafka/raft/RaftMessageQueue.java | 57 +++++ .../org/apache/kafka/raft/RaftRequest.java | 5 + .../org/apache/kafka/raft/RequestManager.java | 18 +- .../raft/internals/BlockingMessageQueue.java | 78 +++++++ .../kafka/raft/KafkaRaftClientTest.java | 193 ++++++++------- .../java/org/apache/kafka/raft/MockLog.java | 7 + .../apache/kafka/raft/MockMessageQueue.java | 67 ++++++ .../apache/kafka/raft/MockNetworkChannel.java | 101 +++----- .../kafka/raft/RaftClientTestContext.java | 96 +++++--- .../kafka/raft/RaftEventSimulationTest.java | 81 ++++--- .../apache/kafka/raft/RequestManagerTest.java | 5 +- .../internals/BlockingMessageQueueTest.java | 59 +++++ 30 files changed, 822 insertions(+), 573 deletions(-) create mode 100644 raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java create mode 100644 
raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java create mode 100644 raft/src/test/java/org/apache/kafka/raft/MockMessageQueue.java create mode 100644 raft/src/test/java/org/apache/kafka/raft/internals/BlockingMessageQueueTest.java diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml index 853903487ce46..0e348d7b71b26 100644 --- a/checkstyle/suppressions.xml +++ b/checkstyle/suppressions.xml @@ -28,7 +28,7 @@ + files="(Fetcher|Sender|SenderTest|ConsumerCoordinator|KafkaConsumer|KafkaProducer|Utils|TransactionManager|TransactionManagerTest|KafkaAdminClient|NetworkClient|Admin|KafkaRaftClient|KafkaRaftClientTest|RaftClientTestContext).java"/> errorCounts, Errors error) errorCounts.put(error, count + 1); } - protected abstract Message data(); - /** * Parse a response from the provided buffer. The buffer is expected to hold both * the {@link ResponseHeader} as well as the response payload. diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/KafkaAdminClientTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/KafkaAdminClientTest.java index d6e9f04e87eed..74041f0bade6a 100644 --- a/clients/src/test/java/org/apache/kafka/clients/admin/KafkaAdminClientTest.java +++ b/clients/src/test/java/org/apache/kafka/clients/admin/KafkaAdminClientTest.java @@ -55,12 +55,12 @@ import org.apache.kafka.common.errors.InvalidTopicException; import org.apache.kafka.common.errors.LeaderNotAvailableException; import org.apache.kafka.common.errors.LogDirNotFoundException; -import org.apache.kafka.common.errors.ThrottlingQuotaExceededException; -import org.apache.kafka.common.errors.TimeoutException; import org.apache.kafka.common.errors.NotLeaderOrFollowerException; import org.apache.kafka.common.errors.OffsetOutOfRangeException; import org.apache.kafka.common.errors.SaslAuthenticationException; import org.apache.kafka.common.errors.SecurityDisabledException; +import 
org.apache.kafka.common.errors.ThrottlingQuotaExceededException; +import org.apache.kafka.common.errors.TimeoutException; import org.apache.kafka.common.errors.TopicAuthorizationException; import org.apache.kafka.common.errors.TopicDeletionDisabledException; import org.apache.kafka.common.errors.TopicExistsException; @@ -75,9 +75,9 @@ import org.apache.kafka.common.message.AlterReplicaLogDirsResponseData.AlterReplicaLogDirTopicResult; import org.apache.kafka.common.message.AlterUserScramCredentialsResponseData; import org.apache.kafka.common.message.ApiVersionsResponseData; +import org.apache.kafka.common.message.CreateAclsResponseData; import org.apache.kafka.common.message.CreatePartitionsResponseData; import org.apache.kafka.common.message.CreatePartitionsResponseData.CreatePartitionsTopicResult; -import org.apache.kafka.common.message.CreateAclsResponseData; import org.apache.kafka.common.message.CreateTopicsResponseData; import org.apache.kafka.common.message.CreateTopicsResponseData.CreatableTopicResult; import org.apache.kafka.common.message.CreateTopicsResponseData.CreatableTopicResultCollection; @@ -108,16 +108,16 @@ import org.apache.kafka.common.message.ListOffsetsResponseData; import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsTopicResponse; import org.apache.kafka.common.message.ListPartitionReassignmentsResponseData; -import org.apache.kafka.common.message.MetadataResponseData.MetadataResponseTopic; import org.apache.kafka.common.message.MetadataResponseData.MetadataResponsePartition; +import org.apache.kafka.common.message.MetadataResponseData.MetadataResponseTopic; import org.apache.kafka.common.message.OffsetDeleteResponseData; import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponsePartition; import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponsePartitionCollection; import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponseTopic; import 
org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponseTopicCollection; import org.apache.kafka.common.protocol.ApiKeys; +import org.apache.kafka.common.protocol.ApiMessage; import org.apache.kafka.common.protocol.Errors; -import org.apache.kafka.common.protocol.Message; import org.apache.kafka.common.quota.ClientQuotaAlteration; import org.apache.kafka.common.quota.ClientQuotaEntity; import org.apache.kafka.common.quota.ClientQuotaFilter; @@ -4937,7 +4937,7 @@ public Map errorCounts() { } @Override - protected Message data() { + public ApiMessage data() { return null; } diff --git a/core/src/main/scala/kafka/common/InterBrokerSendThread.scala b/core/src/main/scala/kafka/common/InterBrokerSendThread.scala index 11e1aa8de3920..7327695af680f 100644 --- a/core/src/main/scala/kafka/common/InterBrokerSendThread.scala +++ b/core/src/main/scala/kafka/common/InterBrokerSendThread.scala @@ -16,8 +16,9 @@ */ package kafka.common -import java.util.{ArrayDeque, ArrayList, Collection, Collections, HashMap, Iterator} import java.util.Map.Entry +import java.util.concurrent.ConcurrentLinkedQueue +import java.util.{ArrayDeque, ArrayList, Collection, Collections, HashMap, Iterator} import kafka.utils.ShutdownableThread import org.apache.kafka.clients.{ClientRequest, ClientResponse, KafkaClient, RequestCompletionHandler} @@ -32,17 +33,18 @@ import scala.jdk.CollectionConverters._ /** * Class for inter-broker send thread that utilize a non-blocking network client. 
*/ -abstract class InterBrokerSendThread(name: String, - networkClient: KafkaClient, - time: Time, - isInterruptible: Boolean = true) - extends ShutdownableThread(name, isInterruptible) { - - def generateRequests(): Iterable[RequestAndCompletionHandler] - def requestTimeoutMs: Int +class InterBrokerSendThread( + name: String, + networkClient: KafkaClient, + requestTimeoutMs: Int, + time: Time, + isInterruptible: Boolean = true +) extends ShutdownableThread(name, isInterruptible) { + + private val inboundQueue = new ConcurrentLinkedQueue[RequestAndCompletionHandler]() private val unsentRequests = new UnsentRequests - def hasUnsentRequests = unsentRequests.iterator().hasNext + def hasUnsentRequests: Boolean = unsentRequests.iterator().hasNext override def shutdown(): Unit = { initiateShutdown() @@ -51,22 +53,30 @@ abstract class InterBrokerSendThread(name: String, awaitShutdown() } - override def doWork(): Unit = { - var now = time.milliseconds() + def sendRequest(request: RequestAndCompletionHandler): Unit = { + inboundQueue.offer(request) + wakeup() + } - generateRequests().foreach { request => + private def drainInboundQueue(): Unit = { + while (!inboundQueue.isEmpty) { + val request = inboundQueue.poll() val completionHandler = request.handler unsentRequests.put(request.destination, networkClient.newClientRequest( request.destination.idString, request.request, - now, + time.milliseconds(), true, requestTimeoutMs, completionHandler)) } + } + override def doWork(): Unit = { try { + var now = time.milliseconds() + drainInboundQueue() val timeout = sendRequests(now) networkClient.poll(timeout, now) now = time.milliseconds() @@ -198,5 +208,5 @@ private class UnsentRequests { requests.iterator } - def nodes = unsent.keySet + def nodes: java.util.Set[Node] = unsent.keySet } diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala index 
029ded837a87d..b000a890d32c7 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala @@ -127,11 +127,14 @@ class TxnMarkerQueue(@volatile var destination: Node) { def totalNumMarkers(txnTopicPartition: Int): Int = markersPerTxnTopicPartition.get(txnTopicPartition).fold(0)(_.size) } -class TransactionMarkerChannelManager(config: KafkaConfig, - metadataCache: MetadataCache, - networkClient: NetworkClient, - txnStateManager: TransactionStateManager, - time: Time) extends InterBrokerSendThread("TxnMarkerSenderThread-" + config.brokerId, networkClient, time) with Logging with KafkaMetricsGroup { +class TransactionMarkerChannelManager( + config: KafkaConfig, + metadataCache: MetadataCache, + networkClient: NetworkClient, + txnStateManager: TransactionStateManager, + time: Time +) extends InterBrokerSendThread("TxnMarkerSenderThread-" + config.brokerId, networkClient, config.requestTimeoutMs, time) + with Logging with KafkaMetricsGroup { this.logIdent = "[Transaction Marker Channel Manager " + config.brokerId + "]: " @@ -145,8 +148,6 @@ class TransactionMarkerChannelManager(config: KafkaConfig, private val transactionsWithPendingMarkers = new ConcurrentHashMap[String, PendingCompleteTxn] - override val requestTimeoutMs: Int = config.requestTimeoutMs - val writeTxnMarkersRequestVersion: Short = if (config.interBrokerProtocolVersion >= KAFKA_2_8_IV0) 1 else 0 @@ -154,7 +155,10 @@ class TransactionMarkerChannelManager(config: KafkaConfig, newGauge("UnknownDestinationQueueSize", () => markersQueueForUnknownBroker.totalNumMarkers) newGauge("LogAppendRetryQueueSize", () => txnLogAppendRetryQueue.size) - override def generateRequests() = drainQueuedTransactionMarkers() + override def doWork(): Unit = { + drainQueuedTransactionMarkers().foreach(super.sendRequest) + super.doWork() + } override def shutdown(): Unit = { super.shutdown() diff 
--git a/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala b/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala index 7f769c8463e09..53ed7dcceab19 100644 --- a/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala +++ b/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala @@ -17,39 +17,22 @@ package kafka.raft import java.net.InetSocketAddress -import java.util -import java.util.concurrent.ArrayBlockingQueue import java.util.concurrent.atomic.AtomicInteger +import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler} import kafka.utils.Logging -import org.apache.kafka.clients.{ClientRequest, ClientResponse, KafkaClient} +import org.apache.kafka.clients.{ClientResponse, KafkaClient} +import org.apache.kafka.common.Node import org.apache.kafka.common.message._ import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors} import org.apache.kafka.common.requests._ import org.apache.kafka.common.utils.Time -import org.apache.kafka.common.{KafkaException, Node} -import org.apache.kafka.raft.{NetworkChannel, RaftMessage, RaftRequest, RaftResponse, RaftUtil} +import org.apache.kafka.raft.{NetworkChannel, RaftRequest, RaftResponse, RaftUtil} import scala.collection.mutable -import scala.jdk.CollectionConverters._ object KafkaNetworkChannel { - private[raft] def buildResponse(responseData: ApiMessage): AbstractResponse = { - responseData match { - case voteResponse: VoteResponseData => - new VoteResponse(voteResponse) - case beginEpochResponse: BeginQuorumEpochResponseData => - new BeginQuorumEpochResponse(beginEpochResponse) - case endEpochResponse: EndQuorumEpochResponseData => - new EndQuorumEpochResponse(endEpochResponse) - case fetchResponse: FetchResponseData => - new FetchResponse(fetchResponse) - case _ => - throw new IllegalArgumentException(s"Unexpected type for responseData: $responseData") - } - } - private[raft] def buildRequest(requestData: ApiMessage): AbstractRequest.Builder[_ <: AbstractRequest] = { requestData match { 
case voteRequest: VoteRequestData => @@ -68,161 +51,76 @@ object KafkaNetworkChannel { } } - private[raft] def responseData(response: AbstractResponse): ApiMessage = { - response match { - case voteResponse: VoteResponse => voteResponse.data - case beginEpochResponse: BeginQuorumEpochResponse => beginEpochResponse.data - case endEpochResponse: EndQuorumEpochResponse => endEpochResponse.data - case fetchResponse: FetchResponse[_] => fetchResponse.data - case _ => throw new IllegalArgumentException(s"Unexpected type for response: $response") - } - } - - private[raft] def requestData(request: AbstractRequest): ApiMessage = { - request match { - case voteRequest: VoteRequest => voteRequest.data - case beginEpochRequest: BeginQuorumEpochRequest => beginEpochRequest.data - case endEpochRequest: EndQuorumEpochRequest => endEpochRequest.data - case fetchRequest: FetchRequest => fetchRequest.data - case _ => throw new IllegalArgumentException(s"Unexpected type for request: $request") - } - } - } -class KafkaNetworkChannel(time: Time, - client: KafkaClient, - clientId: String, - retryBackoffMs: Int, - requestTimeoutMs: Int) extends NetworkChannel with Logging { +class KafkaNetworkChannel( + time: Time, + client: KafkaClient, + requestTimeoutMs: Int +) extends NetworkChannel with Logging { import KafkaNetworkChannel._ type ResponseHandler = AbstractResponse => Unit private val correlationIdCounter = new AtomicInteger(0) - private val pendingInbound = mutable.Map.empty[Long, ResponseHandler] - private val undelivered = new ArrayBlockingQueue[RaftMessage](10) - private val pendingOutbound = new ArrayBlockingQueue[RaftRequest.Outbound](10) private val endpoints = mutable.HashMap.empty[Int, Node] - override def newCorrelationId(): Int = correlationIdCounter.getAndIncrement() - - private def buildClientRequest(req: RaftRequest.Outbound): ClientRequest = { - val destination = req.destinationId.toString - val request = buildRequest(req.data) - val correlationId = req.correlationId - 
val createdTimeMs = req.createdTimeMs - new ClientRequest(destination, request, correlationId, clientId, createdTimeMs, true, - requestTimeoutMs, null) - } - - override def send(message: RaftMessage): Unit = { - message match { - case request: RaftRequest.Outbound => - if (!pendingOutbound.offer(request)) - throw new KafkaException("Pending outbound queue is full") - - case response: RaftResponse.Outbound => - pendingInbound.remove(response.correlationId).foreach { onResponseReceived: ResponseHandler => - onResponseReceived(buildResponse(response.data)) - } - case _ => - throw new IllegalArgumentException("Unhandled message type " + message) + private val requestThread = new InterBrokerSendThread( + name = "raft-outbound-request-thread", + networkClient = client, + requestTimeoutMs = requestTimeoutMs, + time = time, + isInterruptible = false + ) + + override def send(request: RaftRequest.Outbound): Unit = { + def completeFuture(message: ApiMessage): Unit = { + val response = new RaftResponse.Inbound( + request.correlationId, + message, + request.destinationId + ) + request.completion.complete(response) } - } - private def sendOutboundRequests(currentTimeMs: Long): Unit = { - while (!pendingOutbound.isEmpty) { - val request = pendingOutbound.peek() - endpoints.get(request.destinationId) match { - case Some(node) => - if (client.connectionFailed(node)) { - pendingOutbound.poll() - val apiKey = ApiKeys.forId(request.data.apiKey) - val disconnectResponse = RaftUtil.errorResponse(apiKey, Errors.BROKER_NOT_AVAILABLE) - val success = undelivered.offer(new RaftResponse.Inbound( - request.correlationId, disconnectResponse, request.destinationId)) - if (!success) { - throw new KafkaException("Undelivered queue is full") - } - - // Make sure to reset the connection state - client.ready(node, currentTimeMs) - } else if (client.ready(node, currentTimeMs)) { - pendingOutbound.poll() - val clientRequest = buildClientRequest(request) - client.send(clientRequest, currentTimeMs) - } 
else { - // We will retry this request on the next poll - return - } - - case None => - pendingOutbound.poll() - val apiKey = ApiKeys.forId(request.data.apiKey) - val responseData = RaftUtil.errorResponse(apiKey, Errors.BROKER_NOT_AVAILABLE) - val response = new RaftResponse.Inbound(request.correlationId, responseData, request.destinationId) - if (!undelivered.offer(response)) - throw new KafkaException("Undelivered queue is full") + def onComplete(clientResponse: ClientResponse): Unit = { + val response = if (clientResponse.authenticationException != null) { + errorResponse(request.data, Errors.CLUSTER_AUTHORIZATION_FAILED) + } else if (clientResponse.wasDisconnected()) { + errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE) + } else { + clientResponse.responseBody.data } + completeFuture(response) } - } - - def getConnectionInfo(nodeId: Int): Node = { - if (!endpoints.contains(nodeId)) - null - else - endpoints(nodeId) - } - - def allConnections(): Set[Node] = { - endpoints.values.toSet - } - private def buildInboundRaftResponse(response: ClientResponse): RaftResponse.Inbound = { - val header = response.requestHeader() - val data = if (response.authenticationException != null) { - RaftUtil.errorResponse(header.apiKey, Errors.CLUSTER_AUTHORIZATION_FAILED) - } else if (response.wasDisconnected) { - RaftUtil.errorResponse(header.apiKey, Errors.BROKER_NOT_AVAILABLE) - } else { - responseData(response.responseBody) - } - new RaftResponse.Inbound(header.correlationId, data, response.destination.toInt) - } + endpoints.get(request.destinationId) match { + case Some(node) => + requestThread.sendRequest(RequestAndCompletionHandler( + destination = node, + request = buildRequest(request.data), + handler = onComplete + )) - private def pollInboundResponses(timeoutMs: Long, inboundMessages: util.List[RaftMessage]): Unit = { - val responses = client.poll(timeoutMs, time.milliseconds()) - for (response <- responses.asScala) { - 
inboundMessages.add(buildInboundRaftResponse(response)) + case None => + completeFuture(errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE)) } } - private def drainInboundRequests(inboundMessages: util.List[RaftMessage]): Unit = { - undelivered.drainTo(inboundMessages) + def pollOnce(): Unit = { + requestThread.doWork() } - private def pollInboundMessages(timeoutMs: Long): util.List[RaftMessage] = { - val pollTimeoutMs = if (!undelivered.isEmpty) { - 0L - } else if (!pendingOutbound.isEmpty) { - retryBackoffMs - } else { - timeoutMs - } - val messages = new util.ArrayList[RaftMessage] - pollInboundResponses(pollTimeoutMs, messages) - drainInboundRequests(messages) - messages + override def newCorrelationId(): Int = { + correlationIdCounter.getAndIncrement() } - override def receive(timeoutMs: Long): util.List[RaftMessage] = { - sendOutboundRequests(time.milliseconds()) - pollInboundMessages(timeoutMs) - } - - override def wakeup(): Unit = { - client.wakeup() + private def errorResponse( + request: ApiMessage, + error: Errors + ): ApiMessage = { + val apiKey = ApiKeys.forId(request.apiKey) + RaftUtil.errorResponse(apiKey, error) } override def updateEndpoint(id: Int, address: InetSocketAddress): Unit = { @@ -230,17 +128,16 @@ class KafkaNetworkChannel(time: Time, endpoints.put(id, node) } - def postInboundRequest(request: AbstractRequest, onResponseReceived: ResponseHandler): Unit = { - val data = requestData(request) - val correlationId = newCorrelationId() - val req = new RaftRequest.Inbound(correlationId, data, time.milliseconds()) - pendingInbound.put(correlationId, onResponseReceived) - if (!undelivered.offer(req)) - throw new KafkaException("Undelivered queue is full") - wakeup() + def start(): Unit = { + requestThread.start() + } + + def initiateShutdown(): Unit = { + requestThread.initiateShutdown() } override def close(): Unit = { + requestThread.shutdown() client.close() } diff --git 
a/core/src/main/scala/kafka/server/BrokerToControllerChannelManagerImpl.scala b/core/src/main/scala/kafka/server/BrokerToControllerChannelManagerImpl.scala index c0918ad7d03e8..aa212f1c0d60b 100644 --- a/core/src/main/scala/kafka/server/BrokerToControllerChannelManagerImpl.scala +++ b/core/src/main/scala/kafka/server/BrokerToControllerChannelManagerImpl.scala @@ -164,13 +164,11 @@ class BrokerToControllerRequestThread(networkClient: KafkaClient, listenerName: ListenerName, time: Time, threadName: String) - extends InterBrokerSendThread(threadName, networkClient, time, isInterruptible = false) { + extends InterBrokerSendThread(threadName, networkClient, config.controllerSocketTimeoutMs, time, isInterruptible = false) { private var activeController: Option[Node] = None - override def requestTimeoutMs: Int = config.controllerSocketTimeoutMs - - override def generateRequests(): Iterable[RequestAndCompletionHandler] = { + def generateRequests(): Iterable[RequestAndCompletionHandler] = { val requestsToSend = new mutable.Queue[RequestAndCompletionHandler] val topRequest = requestQueue.poll() if (topRequest != null) { @@ -209,6 +207,7 @@ class BrokerToControllerRequestThread(networkClient: KafkaClient, override def doWork(): Unit = { if (activeController.isDefined) { + generateRequests().foreach(sendRequest) super.doWork() } else { debug("Controller isn't cached, looking for local metadata changes") diff --git a/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala b/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala index 4fad0db4f1e83..f8ab30cac8dc2 100644 --- a/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala +++ b/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala @@ -19,13 +19,15 @@ package kafka.tools import kafka.network.RequestChannel import kafka.network.RequestConvertToJson -import kafka.raft.KafkaNetworkChannel import kafka.server.ApiRequestHandler import kafka.utils.Logging import 
org.apache.kafka.common.internals.FatalExitError -import org.apache.kafka.common.protocol.{ApiKeys, Errors} -import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse} +import org.apache.kafka.common.message.{BeginQuorumEpochResponseData, EndQuorumEpochResponseData, FetchResponseData, VoteResponseData} +import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors} +import org.apache.kafka.common.record.BaseRecords +import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse, BeginQuorumEpochResponse, EndQuorumEpochResponse, FetchResponse, VoteResponse} import org.apache.kafka.common.utils.Time +import org.apache.kafka.raft.{KafkaRaftClient, RaftRequest} import scala.jdk.CollectionConverters._ @@ -33,7 +35,7 @@ import scala.jdk.CollectionConverters._ * Simple request handler implementation for use by [[TestRaftServer]]. */ class TestRaftRequestHandler( - networkChannel: KafkaNetworkChannel, + raftClient: KafkaRaftClient[_], requestChannel: RequestChannel, time: Time, ) extends ApiRequestHandler with Logging { @@ -43,13 +45,10 @@ class TestRaftRequestHandler( trace(s"Handling request:${request.requestDesc(true)} from connection ${request.context.connectionId};" + s"securityProtocol:${request.context.securityProtocol},principal:${request.context.principal}") request.header.apiKey match { - case ApiKeys.VOTE - | ApiKeys.BEGIN_QUORUM_EPOCH - | ApiKeys.END_QUORUM_EPOCH - | ApiKeys.FETCH => - val requestBody = request.body[AbstractRequest] - networkChannel.postInboundRequest(requestBody, response => - sendResponse(request, Some(response))) + case ApiKeys.VOTE => handleVote(request) + case ApiKeys.BEGIN_QUORUM_EPOCH => handleBeginQuorumEpoch(request) + case ApiKeys.END_QUORUM_EPOCH => handleEndQuorumEpoch(request) + case ApiKeys.FETCH => handleFetch(request) case _ => throw new IllegalArgumentException(s"Unsupported api key: ${request.header.apiKey}") } @@ -63,6 +62,45 @@ class TestRaftRequestHandler( } } + private def 
handleVote(request: RequestChannel.Request): Unit = { + handle(request, response => new VoteResponse(response.asInstanceOf[VoteResponseData])) + } + + private def handleBeginQuorumEpoch(request: RequestChannel.Request): Unit = { + handle(request, response => new BeginQuorumEpochResponse(response.asInstanceOf[BeginQuorumEpochResponseData])) + } + + private def handleEndQuorumEpoch(request: RequestChannel.Request): Unit = { + handle(request, response => new EndQuorumEpochResponse(response.asInstanceOf[EndQuorumEpochResponseData])) + } + + private def handleFetch(request: RequestChannel.Request): Unit = { + handle(request, response => new FetchResponse[BaseRecords](response.asInstanceOf[FetchResponseData])) + } + + private def handle( + request: RequestChannel.Request, + buildResponse: ApiMessage => AbstractResponse + ): Unit = { + val requestBody = request.body[AbstractRequest] + val inboundRequest = new RaftRequest.Inbound( + request.header.correlationId, + requestBody.data, + time.milliseconds() + ) + + inboundRequest.completion.whenComplete((response, exception) => { + val res = if (exception != null) { + requestBody.getErrorResponse(exception) + } else { + buildResponse(response.data) + } + sendResponse(request, Some(res)) + }) + + raftClient.handle(inboundRequest) + } + private def handleError(request: RequestChannel.Request, err: Throwable): Unit = { error("Error when handling request: " + s"clientId=${request.header.clientId}, " + diff --git a/core/src/main/scala/kafka/tools/TestRaftServer.scala b/core/src/main/scala/kafka/tools/TestRaftServer.scala index 8d179dccacd02..fb1f7c5b04686 100644 --- a/core/src/main/scala/kafka/tools/TestRaftServer.scala +++ b/core/src/main/scala/kafka/tools/TestRaftServer.scala @@ -95,6 +95,8 @@ class TestRaftServer( val metadataLog = buildMetadataLog(logDir) val networkChannel = buildNetworkChannel(raftConfig, logContext) + networkChannel.start() + val raftClient = buildRaftClient( raftConfig, metadataLog, @@ -114,7 +116,7 @@ 
class TestRaftServer( raftClient.initialize() val requestHandler = new TestRaftRequestHandler( - networkChannel, + raftClient, socketServer.dataPlaneRequestChannel, time ) @@ -163,9 +165,7 @@ class TestRaftServer( private def buildNetworkChannel(raftConfig: RaftConfig, logContext: LogContext): KafkaNetworkChannel = { val netClient = buildNetworkClient(raftConfig, logContext) - val clientId = s"Raft-${config.brokerId}" - new KafkaNetworkChannel(time, netClient, clientId, - raftConfig.retryBackoffMs, raftConfig.requestTimeoutMs) + new KafkaNetworkChannel(time, netClient, raftConfig.requestTimeoutMs) } private def buildMetadataLog(logDir: File): KafkaMetadataLog = { diff --git a/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala b/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala index f5110bf0b03e0..0b621d7332a6a 100644 --- a/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala +++ b/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala @@ -27,8 +27,6 @@ import org.apache.kafka.common.requests.AbstractRequest import org.easymock.EasyMock import org.junit.{Assert, Test} -import scala.collection.mutable - class InterBrokerSendThreadTest { private val time = new MockTime() private val networkClient: NetworkClient = EasyMock.createMock(classOf[NetworkClient]) @@ -37,10 +35,7 @@ class InterBrokerSendThreadTest { @Test def shouldNotSendAnythingWhenNoRequests(): Unit = { - val sendThread = new InterBrokerSendThread("name", networkClient, time) { - override val requestTimeoutMs: Int = InterBrokerSendThreadTest.this.requestTimeoutMs - override def generateRequests() = mutable.Iterable.empty - } + val sendThread = new InterBrokerSendThread("name", networkClient, requestTimeoutMs, time) // poll is always called but there should be no further invocations on NetworkClient EasyMock.expect(networkClient.poll(EasyMock.anyLong(), EasyMock.anyLong())) @@ -59,13 +54,12 @@ class InterBrokerSendThreadTest { val request = new 
StubRequestBuilder() val node = new Node(1, "", 8080) val handler = RequestAndCompletionHandler(node, request, completionHandler) - val sendThread = new InterBrokerSendThread("name", networkClient, time) { - override val requestTimeoutMs: Int = InterBrokerSendThreadTest.this.requestTimeoutMs - override def generateRequests() = List[RequestAndCompletionHandler](handler) - } + val sendThread = new InterBrokerSendThread("name", networkClient, requestTimeoutMs, time) val clientRequest = new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, handler.handler) + EasyMock.expect(networkClient.wakeup()) + EasyMock.expect(networkClient.newClientRequest( EasyMock.eq("1"), EasyMock.same(handler.request), @@ -85,6 +79,7 @@ class InterBrokerSendThreadTest { EasyMock.replay(networkClient) + sendThread.sendRequest(handler) sendThread.doWork() EasyMock.verify(networkClient) @@ -95,21 +90,20 @@ class InterBrokerSendThreadTest { def shouldCallCompletionHandlerWithDisconnectedResponseWhenNodeNotReady(): Unit = { val request = new StubRequestBuilder val node = new Node(1, "", 8080) - val requestAndCompletionHandler = RequestAndCompletionHandler(node, request, completionHandler) - val sendThread = new InterBrokerSendThread("name", networkClient, time) { - override val requestTimeoutMs: Int = InterBrokerSendThreadTest.this.requestTimeoutMs - override def generateRequests() = List[RequestAndCompletionHandler](requestAndCompletionHandler) - } + val handler = RequestAndCompletionHandler(node, request, completionHandler) + val sendThread = new InterBrokerSendThread("name", networkClient, requestTimeoutMs, time) - val clientRequest = new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, requestAndCompletionHandler.handler) + val clientRequest = new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, handler.handler) + + EasyMock.expect(networkClient.wakeup()) EasyMock.expect(networkClient.newClientRequest( EasyMock.eq("1"), - 
EasyMock.same(requestAndCompletionHandler.request), + EasyMock.same(handler.request), EasyMock.anyLong(), EasyMock.eq(true), EasyMock.eq(requestTimeoutMs), - EasyMock.same(requestAndCompletionHandler.handler))) + EasyMock.same(handler.handler))) .andReturn(clientRequest) EasyMock.expect(networkClient.ready(node, time.milliseconds())) @@ -129,6 +123,7 @@ class InterBrokerSendThreadTest { EasyMock.replay(networkClient) + sendThread.sendRequest(handler) sendThread.doWork() EasyMock.verify(networkClient) @@ -140,10 +135,7 @@ class InterBrokerSendThreadTest { val request = new StubRequestBuilder() val node = new Node(1, "", 8080) val handler = RequestAndCompletionHandler(node, request, completionHandler) - val sendThread = new InterBrokerSendThread("name", networkClient, time) { - override val requestTimeoutMs: Int = InterBrokerSendThreadTest.this.requestTimeoutMs - override def generateRequests() = List[RequestAndCompletionHandler](handler) - } + val sendThread = new InterBrokerSendThread("name", networkClient, requestTimeoutMs, time) val clientRequest = new ClientRequest("dest", request, @@ -155,6 +147,8 @@ class InterBrokerSendThreadTest { handler.handler) time.sleep(1500) + EasyMock.expect(networkClient.wakeup()) + EasyMock.expect(networkClient.newClientRequest( EasyMock.eq("1"), EasyMock.same(handler.request), @@ -180,6 +174,7 @@ class InterBrokerSendThreadTest { EasyMock.replay(networkClient) + sendThread.sendRequest(handler) sendThread.doWork() EasyMock.verify(networkClient) diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala index 3788cb1d65325..9043893a4d49e 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala @@ -385,7 +385,7 @@ class 
TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren new WriteTxnMarkersResponse(pidErrorMap) } synchronized { - txnMarkerChannelManager.generateRequests().foreach { requestAndHandler => + txnMarkerChannelManager.drainQueuedTransactionMarkers().foreach { requestAndHandler => val request = requestAndHandler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build() val response = createResponse(request) requestAndHandler.handler.onComplete(new ClientResponse(new RequestHeader(ApiKeys.PRODUCE, 0, "client", 1), diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala index 441b4e07ee100..5c82a5613007a 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala @@ -156,7 +156,7 @@ class TransactionMarkerChannelManagerTest { @Test def shouldGenerateEmptyMapWhenNoRequestsOutstanding(): Unit = { - assertTrue(channelManager.generateRequests().isEmpty) + assertTrue(channelManager.drainQueuedTransactionMarkers().isEmpty) } @Test @@ -194,12 +194,12 @@ class TransactionMarkerChannelManagerTest { val expectedBroker2Request = new WriteTxnMarkersRequest.Builder(ApiKeys.WRITE_TXN_MARKERS.latestVersion(), asList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, txnResult, asList(partition2)))).build() - val requests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().map { handler => + val requests: Map[Node, WriteTxnMarkersRequest] = channelManager.drainQueuedTransactionMarkers().map { handler => (handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build()) }.toMap assertEquals(Map(broker1 -> expectedBroker1Request, broker2 -> expectedBroker2Request), requests) - 
assertTrue(channelManager.generateRequests().isEmpty) + assertTrue(channelManager.drainQueuedTransactionMarkers().isEmpty) } @Test @@ -270,13 +270,13 @@ class TransactionMarkerChannelManagerTest { val expectedBroker2Request = new WriteTxnMarkersRequest.Builder(ApiKeys.WRITE_TXN_MARKERS.latestVersion(), asList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, txnResult, asList(partition2)))).build() - val firstDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().map { handler => + val firstDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.drainQueuedTransactionMarkers().map { handler => (handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build()) }.toMap assertEquals(Map(broker2 -> expectedBroker2Request), firstDrainedRequests) - val secondDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().map { handler => + val secondDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.drainQueuedTransactionMarkers().map { handler => (handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build()) }.toMap @@ -354,7 +354,7 @@ class TransactionMarkerChannelManagerTest { channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnTransitionMetadata2) - val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests() + val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.drainQueuedTransactionMarkers() val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE)) for (requestAndHandler <- requestAndHandlers) { @@ -401,7 +401,7 @@ class TransactionMarkerChannelManagerTest { channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnTransitionMetadata2) - val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests() + val requestAndHandlers: 
Iterable[RequestAndCompletionHandler] = channelManager.drainQueuedTransactionMarkers() val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE)) for (requestAndHandler <- requestAndHandlers) { @@ -450,7 +450,7 @@ class TransactionMarkerChannelManagerTest { channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnTransitionMetadata2) - val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests() + val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.drainQueuedTransactionMarkers() val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE)) for (requestAndHandler <- requestAndHandlers) { @@ -459,7 +459,7 @@ class TransactionMarkerChannelManagerTest { } // call this again so that append log will be retried - channelManager.generateRequests() + channelManager.drainQueuedTransactionMarkers() EasyMock.verify(txnStateManager) diff --git a/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala b/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala index 699e8faa711bb..bfa1675ff2e56 100644 --- a/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala +++ b/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala @@ -19,16 +19,15 @@ package kafka.raft import java.net.InetSocketAddress import java.util import java.util.Collections -import java.util.concurrent.atomic.AtomicReference import org.apache.kafka.clients.MockClient.MockMetadataUpdater import org.apache.kafka.clients.{ApiVersion, MockClient, NodeApiVersions} import org.apache.kafka.common.message.{BeginQuorumEpochResponseData, EndQuorumEpochResponseData, FetchResponseData, VoteResponseData} import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors} -import org.apache.kafka.common.requests.{AbstractResponse, BeginQuorumEpochRequest, EndQuorumEpochRequest, VoteRequest, VoteResponse} +import org.apache.kafka.common.requests.{AbstractResponse, 
BeginQuorumEpochRequest, BeginQuorumEpochResponse, EndQuorumEpochRequest, EndQuorumEpochResponse, FetchResponse, VoteRequest, VoteResponse} import org.apache.kafka.common.utils.{MockTime, Time} import org.apache.kafka.common.{Node, TopicPartition} -import org.apache.kafka.raft.{RaftRequest, RaftResponse, RaftUtil} +import org.apache.kafka.raft.{RaftRequest, RaftUtil} import org.junit.Assert._ import org.junit.{Before, Test} @@ -38,13 +37,11 @@ class KafkaNetworkChannelTest { import KafkaNetworkChannelTest._ private val clusterId = "clusterId" - private val clientId = "clientId" - private val retryBackoffMs = 100 private val requestTimeoutMs = 30000 private val time = new MockTime() private val client = new MockClient(time, new StubMetadataUpdater) private val topicPartition = new TopicPartition("topic", 0) - private val channel = new KafkaNetworkChannel(time, client, clientId, retryBackoffMs, requestTimeoutMs) + private val channel = new KafkaNetworkChannel(time, client, requestTimeoutMs) @Before def setupSupportedApis(): Unit = { @@ -74,7 +71,7 @@ class KafkaNetworkChannelTest { channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host, destinationNode.port)) for (apiKey <- RaftApis) { - val response = KafkaNetworkChannel.buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST)) + val response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST)) client.prepareResponseFrom(response, destinationNode, true) sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE) } @@ -109,33 +106,12 @@ class KafkaNetworkChannelTest { for (apiKey <- RaftApis) { val expectedError = Errors.INVALID_REQUEST - val response = KafkaNetworkChannel.buildResponse(buildTestErrorResponse(apiKey, expectedError)) + val response = buildResponse(buildTestErrorResponse(apiKey, expectedError)) client.prepareResponseFrom(response, destinationNode) sendAndAssertErrorResponse(apiKey, destinationId, expectedError) } } - @Test - 
def testReceiveAndSendInboundRequest(): Unit = { - for (apiKey <- RaftApis) { - val request = KafkaNetworkChannel.buildRequest(buildTestRequest(apiKey)).build() - val responseRef = new AtomicReference[AbstractResponse]() - - channel.postInboundRequest(request, responseRef.set) - val inbound = channel.receive(1000).asScala - assertEquals(1, inbound.size) - - val inboundRequest = inbound.head.asInstanceOf[RaftRequest.Inbound] - val errorResponse = buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST) - val outboundResponse = new RaftResponse.Outbound(inboundRequest.correlationId, errorResponse) - channel.send(outboundResponse) - channel.receive(1000) - - assertNotNull(responseRef.get) - assertEquals(Errors.INVALID_REQUEST, extractError(KafkaNetworkChannel.responseData(responseRef.get))) - } - } - private def sendAndAssertErrorResponse(apiKey: ApiKeys, destinationId: Int, error: Errors): Unit = { @@ -145,10 +121,11 @@ class KafkaNetworkChannelTest { val request = new RaftRequest.Outbound(correlationId, apiRequest, destinationId, createdTimeMs) channel.send(request) - val responses = channel.receive(1000).asScala - assertEquals(1, responses.size) + channel.pollOnce() + + assertTrue(request.completion.isDone) - val response = responses.head.asInstanceOf[RaftResponse.Inbound] + val response = request.completion.get() assertEquals(destinationId, response.sourceId) assertEquals(correlationId, response.correlationId) assertEquals(apiKey, ApiKeys.forId(response.data.apiKey)) @@ -216,6 +193,22 @@ class KafkaNetworkChannelTest { Errors.forCode(code) } + + def buildResponse(responseData: ApiMessage): AbstractResponse = { + responseData match { + case voteResponse: VoteResponseData => + new VoteResponse(voteResponse) + case beginEpochResponse: BeginQuorumEpochResponseData => + new BeginQuorumEpochResponse(beginEpochResponse) + case endEpochResponse: EndQuorumEpochResponseData => + new EndQuorumEpochResponse(endEpochResponse) + case fetchResponse: FetchResponseData => + new 
FetchResponse(fetchResponse) + case _ => + throw new IllegalArgumentException(s"Unexpected type for responseData: $responseData") + } + } + } object KafkaNetworkChannelTest { diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java index 17bd9584b15f9..7c63401bb75c0 100644 --- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java +++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java @@ -23,14 +23,14 @@ import org.apache.kafka.common.message.BeginQuorumEpochRequestData; import org.apache.kafka.common.message.BeginQuorumEpochResponseData; import org.apache.kafka.common.message.DescribeQuorumRequestData; -import org.apache.kafka.common.message.DescribeQuorumResponseData.ReplicaState; import org.apache.kafka.common.message.DescribeQuorumResponseData; +import org.apache.kafka.common.message.DescribeQuorumResponseData.ReplicaState; import org.apache.kafka.common.message.EndQuorumEpochRequestData; import org.apache.kafka.common.message.EndQuorumEpochResponseData; import org.apache.kafka.common.message.FetchRequestData; import org.apache.kafka.common.message.FetchResponseData; -import org.apache.kafka.common.message.LeaderChangeMessage.Voter; import org.apache.kafka.common.message.LeaderChangeMessage; +import org.apache.kafka.common.message.LeaderChangeMessage.Voter; import org.apache.kafka.common.message.VoteRequestData; import org.apache.kafka.common.message.VoteResponseData; import org.apache.kafka.common.metrics.Metrics; @@ -55,6 +55,7 @@ import org.apache.kafka.raft.RequestManager.ConnectionState; import org.apache.kafka.raft.internals.BatchAccumulator; import org.apache.kafka.raft.internals.BatchMemoryPool; +import org.apache.kafka.raft.internals.BlockingMessageQueue; import org.apache.kafka.raft.internals.CloseListener; import org.apache.kafka.raft.internals.FuturePurgatory; import org.apache.kafka.raft.internals.KafkaRaftMetrics; @@ -71,6 +72,7 @@ import 
java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Optional; import java.util.OptionalInt; import java.util.OptionalLong; @@ -141,6 +143,8 @@ public class KafkaRaftClient implements RaftClient { private final FuturePurgatory fetchPurgatory; private final RecordSerde serde; private final MemoryPool memoryPool; + private final RaftMessageQueue messageQueue; + private final List listenerContexts = new ArrayList<>(); private final ConcurrentLinkedQueue> pendingListeners = new ConcurrentLinkedQueue<>(); @@ -158,6 +162,7 @@ public KafkaRaftClient( ) { this(serde, channel, + new BlockingMessageQueue(), log, quorum, new BatchMemoryPool(5, MAX_BATCH_SIZE), @@ -177,6 +182,7 @@ public KafkaRaftClient( public KafkaRaftClient( RecordSerde serde, NetworkChannel channel, + RaftMessageQueue messageQueue, ReplicatedLog log, QuorumState quorum, MemoryPool memoryPool, @@ -194,6 +200,7 @@ public KafkaRaftClient( ) { this.serde = serde; this.channel = channel; + this.messageQueue = messageQueue; this.log = log; this.quorum = quorum; this.memoryPool = memoryPool; @@ -333,7 +340,7 @@ public void initialize() throws IOException { @Override public void register(Listener listener) { pendingListeners.add(listener); - channel.wakeup(); + wakeup(); } private OffsetAndEpoch endOffset() { @@ -779,9 +786,6 @@ private EndQuorumEpochResponseData handleEndQuorumEpochRequest( FollowerState state = quorum.followerStateOrThrow(); if (state.leaderId() == requestLeaderId) { List preferredSuccessors = partitionRequest.preferredSuccessors(); - if (!preferredSuccessors.contains(quorum.localId)) { - return buildEndQuorumEpochResponse(Errors.INCONSISTENT_VOTER_SET); - } long electionBackoffMs = endEpochElectionBackoff(preferredSuccessors); logger.debug("Overriding follower fetch timeout to {} after receiving " + "EndQuorumEpoch request from leader {} in epoch {}", electionBackoffMs, @@ -798,7 +802,7 @@ private long endEpochElectionBackoff(List 
preferredSuccessors) { // voter has a higher chance to be elected. If the node's priority is highest, become // candidate immediately instead of waiting for next poll. int position = preferredSuccessors.indexOf(quorum.localId); - if (position == 0) { + if (position <= 0) { return 0; } else { return strictExponentialElectionBackoffMs(position, preferredSuccessors.size()); @@ -932,6 +936,7 @@ private CompletableFuture handleFetchRequest( } } + // FIXME: `completionTimeMs`, which can be null logger.trace("Completing delayed fetch from {} starting at offset {} at {}", request.replicaId(), fetchPartition.fetchOffset(), completionTimeMs); @@ -967,6 +972,7 @@ private FetchResponseData tryCompleteFetchRequest( return buildFetchResponse(Errors.NONE, MemoryRecords.EMPTY, divergingEpoch, state.highWatermark()); } else { LogFetchInfo info = log.read(fetchOffset, Isolation.UNCOMMITTED); + if (state.updateReplicaState(replicaId, currentTimeMs, info.startOffsetMetadata)) { onUpdateLeaderHighWatermark(state, currentTimeMs); } @@ -1231,7 +1237,7 @@ private boolean handleTopLevelError(Errors error, RaftResponse.Inbound response) private boolean handleUnexpectedError(Errors error, RaftResponse.Inbound response) { logger.error("Unexpected error {} in {} response: {}", - error, response.data.apiKey(), response); + error, ApiKeys.forId(response.data.apiKey()), response); return false; } @@ -1263,7 +1269,7 @@ private void handleResponse(RaftResponse.Inbound response, long currentTimeMs) t ConnectionState connection = requestManager.getOrCreate(response.sourceId()); if (handledSuccessfully) { - connection.onResponseReceived(response.correlationId, currentTimeMs); + connection.onResponseReceived(response.correlationId); } else { connection.onResponseError(response.correlationId, currentTimeMs); } @@ -1343,7 +1349,10 @@ private void handleRequest(RaftRequest.Inbound request, long currentTimeMs) thro } else { message = RaftUtil.errorResponse(apiKey, Errors.forException(exception)); } - 
sendOutboundMessage(new RaftResponse.Outbound(request.correlationId(), message)); + + RaftResponse.Outbound responseMessage = new RaftResponse.Outbound(request.correlationId(), message); + request.completion.complete(responseMessage); + logger.trace("Sent response {} to inbound request {}", responseMessage, request); }); } @@ -1355,17 +1364,17 @@ private void handleInboundMessage(RaftMessage message, long currentTimeMs) throw handleRequest(request, currentTimeMs); } else if (message instanceof RaftResponse.Inbound) { RaftResponse.Inbound response = (RaftResponse.Inbound) message; - handleResponse(response, currentTimeMs); + ConnectionState connection = requestManager.getOrCreate(response.sourceId()); + if (connection.isResponseExpected(response.correlationId)) { + handleResponse(response, currentTimeMs); + } else { + logger.debug("Ignoring response {} since it is no longer needed", response); + } } else { throw new IllegalArgumentException("Unexpected message " + message); } } - private void sendOutboundMessage(RaftMessage message) { - channel.send(message); - logger.trace("Sent outbound message: {}", message); - } - /** * Attempt to send a request. Return the time to wait before the request can be retried. 
*/ @@ -1383,7 +1392,32 @@ private long maybeSendRequest( if (connection.isReady(currentTimeMs)) { int correlationId = channel.newCorrelationId(); ApiMessage request = requestSupplier.get(); - sendOutboundMessage(new RaftRequest.Outbound(correlationId, request, destinationId, currentTimeMs)); + + RaftRequest.Outbound requestMessage = new RaftRequest.Outbound( + correlationId, + request, + destinationId, + currentTimeMs + ); + + requestMessage.completion.whenComplete((response, exception) -> { + if (exception != null) { + ApiKeys api = ApiKeys.forId(request.apiKey()); + Errors error = Errors.forException(exception); + ApiMessage errorResponse = RaftUtil.errorResponse(api, error); + + response = new RaftResponse.Inbound( + correlationId, + errorResponse, + destinationId + ); + } + + messageQueue.offer(response); + }); + + channel.send(requestMessage); + logger.trace("Sent outbound request: {}", requestMessage); connection.onRequestSent(correlationId, currentTimeMs); return Long.MAX_VALUE; } @@ -1781,6 +1815,26 @@ private boolean maybeCompleteShutdown(long currentTimeMs) { return false; } + private void wakeup() { + messageQueue.wakeup(); + } + + /** + * Handle an inbound request. The response will be returned through + * {@link RaftRequest.Inbound#completion}. + * + * @param request The inbound request + */ + public void handle(RaftRequest.Inbound request) { + messageQueue.offer(Objects.requireNonNull(request)); + } + + /** + * Poll for new events. This allows the client to handle inbound + * requests and send any needed outbound requests. 
+ * + * @throws IOException for any IO errors encountered + */ public void poll() throws IOException { pollListeners(); @@ -1792,14 +1846,13 @@ public void poll() throws IOException { long pollTimeoutMs = pollCurrentState(currentTimeMs); kafkaRaftMetrics.updatePollStart(currentTimeMs); - List inboundMessages = channel.receive(pollTimeoutMs); + RaftMessage message = messageQueue.poll(pollTimeoutMs); currentTimeMs = time.milliseconds(); kafkaRaftMetrics.updatePollEnd(currentTimeMs); - for (RaftMessage message : inboundMessages) { + if (message != null) { handleInboundMessage(message, currentTimeMs); - currentTimeMs = time.milliseconds(); } } @@ -1819,7 +1872,7 @@ public Long scheduleAppend(int epoch, List records) { // the linger timeout so that it can schedule its own wakeup in case // there are no additional appends. if (isFirstAppend || accumulator.needsDrain(time.milliseconds())) { - channel.wakeup(); + wakeup(); } return offset; } @@ -1829,7 +1882,7 @@ public CompletableFuture shutdown(int timeoutMs) { logger.info("Beginning graceful shutdown"); CompletableFuture shutdownComplete = new CompletableFuture<>(); shutdown.set(new GracefulShutdown(timeoutMs, shutdownComplete)); - channel.wakeup(); + wakeup(); return shutdownComplete; } @@ -1991,7 +2044,7 @@ public synchronized void onClose(BatchReader reader) { if (lastSent == reader) { lastSent = null; - channel.wakeup(); + wakeup(); } } diff --git a/raft/src/main/java/org/apache/kafka/raft/LeaderState.java b/raft/src/main/java/org/apache/kafka/raft/LeaderState.java index 1d810a9c22aad..c44f0f030cf23 100644 --- a/raft/src/main/java/org/apache/kafka/raft/LeaderState.java +++ b/raft/src/main/java/org/apache/kafka/raft/LeaderState.java @@ -159,14 +159,15 @@ public boolean updateReplicaState(int replicaId, public List nonLeaderVotersByDescendingFetchOffset() { return followersByDescendingFetchOffset().stream() - .filter(state -> state.nodeId != localId) - .map(state -> state.nodeId) - .collect(Collectors.toList()); + 
.filter(state -> state.nodeId != localId) + .map(state -> state.nodeId) + .collect(Collectors.toList()); } private List followersByDescendingFetchOffset() { - return new ArrayList<>(this.voterReplicaStates.values()) - .stream().sorted().collect(Collectors.toList()); + return new ArrayList<>(this.voterReplicaStates.values()).stream() + .sorted() + .collect(Collectors.toList()); } private boolean updateEndOffset(ReplicaState state, diff --git a/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java b/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java index 421dae2b46fbb..c097f7cc40f79 100644 --- a/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java +++ b/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java @@ -18,11 +18,10 @@ import java.io.Closeable; import java.net.InetSocketAddress; -import java.util.List; /** * A simple network interface with few assumptions. We do not assume ordering - * of requests or even that every request will receive a response. + * of requests or even that every outbound request will receive a response. */ public interface NetworkChannel extends Closeable { @@ -37,21 +36,7 @@ public interface NetworkChannel extends Closeable { * or a response to a request that was received through {@link #receive(long)} * (i.e. an instance of {@link org.apache.kafka.raft.RaftResponse.Outbound}). */ - void send(RaftMessage message); - - /** - * Receive inbound messages. These could contain either inbound requests - * (i.e. instances of {@link org.apache.kafka.raft.RaftRequest.Inbound}) - * or responses to outbound requests sent through {@link #send(RaftMessage)} - * (i.e. instances of {@link org.apache.kafka.raft.RaftResponse.Inbound}). - */ - List receive(long timeoutMs); - - /** - * Wakeup the channel if it is blocking in {@link #receive(long)}. This will cause - * the call to immediately return with whatever messages are available. 
- */ - void wakeup(); + void send(RaftRequest.Outbound request); /** * Update connection information for the given id. diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java b/raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java new file mode 100644 index 0000000000000..ea46f5ed0cf17 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +/** + * This class is used to serialize inbound requests or responses to outbound requests. + * It basically just allows us to wrap a blocking queue so that we can have a mocked + * implementation which does not depend on system time. + * + * See {@link org.apache.kafka.raft.internals.BlockingMessageQueue}. + */ +public interface RaftMessageQueue { + + /** + * Block for the arrival of a new message. + * + * @param timeoutMs timeout in milliseconds to wait for a new event + * @return the event or null if the timeout was reached + */ + RaftMessage poll(long timeoutMs); + + /** + * Offer a new message to the queue. 
+ * + * @param message the message to deliver + * @throws IllegalStateException if the queue cannot accept the message + */ + void offer(RaftMessage message); + + /** + * Check whether there are pending messages awaiting delivery. + * + * @return true if there are no pending messages to deliver + */ + boolean isEmpty(); + + /** + * Wakeup the thread blocking in {@link #poll(long)}. This will cause + * {@link #poll(long)} to return null if no messages are available. + */ + void wakeup(); + +} diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftRequest.java b/raft/src/main/java/org/apache/kafka/raft/RaftRequest.java index f607027dced5b..28e63c14ce69f 100644 --- a/raft/src/main/java/org/apache/kafka/raft/RaftRequest.java +++ b/raft/src/main/java/org/apache/kafka/raft/RaftRequest.java @@ -18,6 +18,8 @@ import org.apache.kafka.common.protocol.ApiMessage; +import java.util.concurrent.CompletableFuture; + public abstract class RaftRequest implements RaftMessage { protected final int correlationId; protected final ApiMessage data; @@ -44,6 +46,8 @@ public long createdTimeMs() { } public static class Inbound extends RaftRequest { + public final CompletableFuture completion = new CompletableFuture<>(); + public Inbound(int correlationId, ApiMessage data, long createdTimeMs) { super(correlationId, data, createdTimeMs); } @@ -60,6 +64,7 @@ public String toString() { public static class Outbound extends RaftRequest { private final int destinationId; + public final CompletableFuture completion = new CompletableFuture<>(); public Outbound(int correlationId, ApiMessage data, int destinationId, long createdTimeMs) { super(correlationId, data, createdTimeMs); diff --git a/raft/src/main/java/org/apache/kafka/raft/RequestManager.java b/raft/src/main/java/org/apache/kafka/raft/RequestManager.java index aec05c7835710..5a5cb003c25af 100644 --- a/raft/src/main/java/org/apache/kafka/raft/RequestManager.java +++ b/raft/src/main/java/org/apache/kafka/raft/RequestManager.java @@ -20,8 +20,8
@@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Optional; import java.util.OptionalInt; +import java.util.OptionalLong; import java.util.Random; import java.util.Set; @@ -103,7 +103,7 @@ public class ConnectionState { private State state = State.READY; private long lastSendTimeMs = 0L; private long lastFailTimeMs = 0L; - private Optional inFlightCorrelationId = Optional.empty(); + private OptionalLong inFlightCorrelationId = OptionalLong.empty(); public ConnectionState(long id) { this.id = id; @@ -160,28 +160,32 @@ long remainingBackoffMs(long timeMs) { } } + boolean isResponseExpected(long correlationId) { + return inFlightCorrelationId.isPresent() && inFlightCorrelationId.getAsLong() == correlationId; + } + void onResponseError(long correlationId, long timeMs) { inFlightCorrelationId.ifPresent(inflightRequestId -> { if (inflightRequestId == correlationId) { lastFailTimeMs = timeMs; state = State.BACKING_OFF; - inFlightCorrelationId = Optional.empty(); + inFlightCorrelationId = OptionalLong.empty(); } }); } - void onResponseReceived(long correlationId, long timeMs) { + void onResponseReceived(long correlationId) { inFlightCorrelationId.ifPresent(inflightRequestId -> { if (inflightRequestId == correlationId) { state = State.READY; - inFlightCorrelationId = Optional.empty(); + inFlightCorrelationId = OptionalLong.empty(); } }); } void onRequestSent(long correlationId, long timeMs) { lastSendTimeMs = timeMs; - inFlightCorrelationId = Optional.of(correlationId); + inFlightCorrelationId = OptionalLong.of(correlationId); state = State.AWAITING_REQUEST; } @@ -192,7 +196,7 @@ void onRequestSent(long correlationId, long timeMs) { */ void reset() { state = State.READY; - inFlightCorrelationId = Optional.empty(); + inFlightCorrelationId = OptionalLong.empty(); } @Override diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java 
b/raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java new file mode 100644 index 0000000000000..9fe99f82080b1 --- /dev/null +++ b/raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.common.errors.InterruptException; +import org.apache.kafka.raft.RaftMessage; +import org.apache.kafka.raft.RaftMessageQueue; + +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +public class BlockingMessageQueue implements RaftMessageQueue { + private final BlockingQueue<RaftEvent> queue = new LinkedBlockingQueue<>(); + private final AtomicInteger size = new AtomicInteger(0); + + @Override + public RaftMessage poll(long timeoutMs) { + try { + RaftEvent event = queue.poll(timeoutMs, TimeUnit.MILLISECONDS); + if (event instanceof MessageReceived) { + size.decrementAndGet(); + return ((MessageReceived) event).message; + } else { + return null; + } + } catch (InterruptedException e) { + throw new InterruptException(e); + } + + } + + @Override + public void offer(RaftMessage message) { + queue.add(new MessageReceived(message)); + size.incrementAndGet(); + } + + @Override + public boolean isEmpty() { + return size.get() == 0; + } + + @Override + public void wakeup() { + queue.add(Wakeup.INSTANCE); + } + + public interface RaftEvent { + } + + static final class MessageReceived implements RaftEvent { + private final RaftMessage message; + private MessageReceived(RaftMessage message) { + this.message = message; + } + } + + static final class Wakeup implements RaftEvent { + public static final Wakeup INSTANCE = new Wakeup(); + } + +} diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java index 5a0e902e15bd5..43fb2301414f9 100644 --- a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java @@ -109,7 +109,7 @@ public void testRejectVotesFromSameEpochAfterResigningLeadership() throws Except // other voter in
the same epoch, even if it has caught up to the same position. context.deliverRequest(context.voteRequest(epoch, remoteId, context.log.lastFetchedEpoch(), context.log.endOffset().offset)); - context.client.poll(); + context.pollUntilResponse(); context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.of(localId), false); } @@ -134,7 +134,7 @@ public void testRejectVotesFromSameEpochAfterResigningCandidacy() throws Excepti // other voter in the same epoch, even if it has caught up to the same position. context.deliverRequest(context.voteRequest(epoch, remoteId, context.log.lastFetchedEpoch(), context.log.endOffset().offset)); - context.client.poll(); + context.pollUntilResponse(); context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.empty(), false); } @@ -158,13 +158,13 @@ public void testInitializeAsResignedLeaderFromStateStore() throws Exception { context.client.poll(); assertEquals(Long.MAX_VALUE, context.client.scheduleAppend(epoch, Arrays.asList("a", "b"))); - context.pollUntilSend(); + context.pollUntilRequest(); int correlationId = context.assertSentEndQuorumEpochRequest(epoch, 1); context.deliverResponse(correlationId, 1, context.endEpochResponse(epoch, OptionalInt.of(localId))); context.client.poll(); context.time.sleep(context.electionTimeoutMs); - context.pollUntilSend(); + context.pollUntilRequest(); context.assertVotedCandidate(epoch + 1, localId); context.assertSentVoteRequest(epoch + 1, 0, 0L, 1); } @@ -187,7 +187,7 @@ public void testEndQuorumEpochRetriesWhileResigned() throws Exception { .withElectedLeader(epoch, localId) .build(); - context.pollUntilSend(); + context.pollUntilRequest(); List requests = context.collectEndQuorumRequests(epoch, Utils.mkSet(voter1, voter2)); assertEquals(2, requests.size()); @@ -203,7 +203,7 @@ public void testEndQuorumEpochRetriesWhileResigned() throws Exception { // retried request from the voter that hasn't responded yet. 
int nonRespondedId = requests.get(1).destinationId(); context.time.sleep(6000); - context.pollUntilSend(); + context.pollUntilRequest(); List retries = context.collectEndQuorumRequests(epoch, Utils.mkSet(nonRespondedId)); assertEquals(1, retries.size()); } @@ -254,7 +254,7 @@ public void testInitializeAsCandidateFromStateStore() throws Exception { assertEquals(0L, context.log.endOffset().offset); // The candidate will resume the election after reinitialization - context.pollUntilSend(); + context.pollUntilRequest(); List voteRequests = context.collectVoteRequests(2, 0, 0); assertEquals(2, voteRequests.size()); } @@ -269,7 +269,7 @@ public void testInitializeAsCandidateAndBecomeLeader() throws Exception { context.assertUnknownLeader(0); context.time.sleep(2 * context.electionTimeoutMs); - context.pollUntilSend(); + context.pollUntilRequest(); context.assertVotedCandidate(1, context.localId); int correlationId = context.assertSentVoteRequest(1, 0, 0L, 1); @@ -309,7 +309,7 @@ public void testInitializeAsCandidateAndBecomeLeaderQuorumOfThree() throws Excep context.assertUnknownLeader(0); context.time.sleep(2 * context.electionTimeoutMs); - context.pollUntilSend(); + context.pollUntilRequest(); context.assertVotedCandidate(1, context.localId); int correlationId = context.assertSentVoteRequest(1, 0, 0L, 2); @@ -350,8 +350,7 @@ public void testHandleBeginQuorumRequest() throws Exception { .build(); context.deliverRequest(context.beginEpochRequest(votedCandidateEpoch, otherNodeId)); - - context.client.poll(); + context.pollUntilResponse(); context.assertElectedLeader(votedCandidateEpoch, otherNodeId); @@ -370,8 +369,7 @@ public void testHandleBeginQuorumResponse() throws Exception { .build(); context.deliverRequest(context.beginEpochRequest(leaderEpoch + 1, otherNodeId)); - - context.client.poll(); + context.pollUntilResponse(); context.assertElectedLeader(leaderEpoch + 1, otherNodeId); } @@ -431,7 +429,7 @@ public void testEndQuorumIgnoredAsLeaderIfOlderEpoch() throws 
Exception { // One of the voters may have sent EndQuorumEpoch from an earlier epoch context.deliverRequest(context.endEpochRequest(epoch - 2, voter2, Arrays.asList(context.localId, voter3))); - context.client.poll(); + context.pollUntilResponse(); context.assertSentEndQuorumEpochResponse(Errors.FENCED_LEADER_EPOCH, epoch, OptionalInt.of(context.localId)); // We should still be leader as long as fetch timeout has not expired @@ -455,7 +453,7 @@ public void testEndQuorumStartsNewElectionImmediatelyIfFollowerUnattached() thro context.deliverRequest(context.endEpochRequest(epoch, voter2, Arrays.asList(context.localId, voter3))); - context.client.poll(); + context.pollUntilResponse(); context.assertSentEndQuorumEpochResponse(Errors.NONE, epoch, OptionalInt.of(voter2)); // Should become a candidate immediately @@ -486,7 +484,7 @@ public void testAccumulatorClearedAfterBecomingFollower() throws Exception { assertEquals(1L, context.client.scheduleAppend(epoch, singletonList("a"))); context.deliverRequest(context.beginEpochRequest(epoch + 1, otherNodeId)); - context.client.poll(); + context.pollUntilResponse(); context.assertElectedLeader(epoch + 1, otherNodeId); Mockito.verify(memoryPool).release(buffer); @@ -516,7 +514,7 @@ public void testAccumulatorClearedAfterBecomingVoted() throws Exception { assertEquals(1L, context.client.scheduleAppend(epoch, singletonList("a"))); context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch, context.log.endOffset().offset)); - context.client.poll(); + context.pollUntilResponse(); context.assertVotedCandidate(epoch + 1, otherNodeId); Mockito.verify(memoryPool).release(buffer); @@ -545,7 +543,7 @@ public void testAccumulatorClearedAfterBecomingUnattached() throws Exception { assertEquals(1L, context.client.scheduleAppend(epoch, singletonList("a"))); context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch, 0L)); - context.client.poll(); + context.pollUntilResponse(); context.assertUnknownLeader(epoch + 1); 
Mockito.verify(memoryPool).release(buffer); @@ -571,14 +569,14 @@ public void testChannelWokenUpIfLingerTimeoutReachedWithoutAppend() throws Excep int epoch = context.currentEpoch(); assertEquals(1L, context.client.scheduleAppend(epoch, singletonList("a"))); - assertTrue(context.channel.wakeupRequested()); + assertTrue(context.messageQueue.wakeupRequested()); context.client.poll(); - assertEquals(OptionalLong.of(lingerMs), context.channel.lastReceiveTimeout()); + assertEquals(OptionalLong.of(lingerMs), context.messageQueue.lastPollTimeoutMs()); context.time.sleep(20); context.client.poll(); - assertEquals(OptionalLong.of(30), context.channel.lastReceiveTimeout()); + assertEquals(OptionalLong.of(30), context.messageQueue.lastPollTimeoutMs()); context.time.sleep(30); context.client.poll(); @@ -605,15 +603,15 @@ public void testChannelWokenUpIfLingerTimeoutReachedDuringAppend() throws Except int epoch = context.currentEpoch(); assertEquals(1L, context.client.scheduleAppend(epoch, singletonList("a"))); - assertTrue(context.channel.wakeupRequested()); + assertTrue(context.messageQueue.wakeupRequested()); context.client.poll(); - assertFalse(context.channel.wakeupRequested()); - assertEquals(OptionalLong.of(lingerMs), context.channel.lastReceiveTimeout()); + assertFalse(context.messageQueue.wakeupRequested()); + assertEquals(OptionalLong.of(lingerMs), context.messageQueue.lastPollTimeoutMs()); context.time.sleep(lingerMs); assertEquals(2L, context.client.scheduleAppend(epoch, singletonList("b"))); - assertTrue(context.channel.wakeupRequested()); + assertTrue(context.messageQueue.wakeupRequested()); context.client.poll(); assertEquals(3L, context.log.endOffset().offset); @@ -633,7 +631,7 @@ public void testHandleEndQuorumRequest() throws Exception { context.deliverRequest(context.endEpochRequest(leaderEpoch, oldLeaderId, Collections.singletonList(context.localId))); - context.client.poll(); + context.pollUntilResponse(); 
context.assertSentEndQuorumEpochResponse(Errors.NONE, leaderEpoch, OptionalInt.of(oldLeaderId)); context.client.poll(); @@ -655,19 +653,19 @@ public void testHandleEndQuorumRequestWithLowerPriorityToBecomeLeader() throws E context.deliverRequest(context.endEpochRequest(leaderEpoch, oldLeaderId, Arrays.asList(preferredNextLeader, context.localId))); - context.pollUntilSend(); + context.pollUntilResponse(); context.assertSentEndQuorumEpochResponse(Errors.NONE, leaderEpoch, OptionalInt.of(oldLeaderId)); // The election won't trigger by one round retry backoff context.time.sleep(1); - context.pollUntilSend(); + context.pollUntilRequest(); context.assertSentFetchRequest(leaderEpoch, 0, 0); context.time.sleep(context.retryBackoffMs); - context.pollUntilSend(); + context.pollUntilRequest(); List voteRequests = context.collectVoteRequests(leaderEpoch + 1, 0, 0); assertEquals(2, voteRequests.size()); @@ -687,7 +685,7 @@ public void testVoteRequestTimeout() throws Exception { context.assertUnknownLeader(0); context.time.sleep(2 * context.electionTimeoutMs); - context.pollUntilSend(); + context.pollUntilRequest(); context.assertVotedCandidate(epoch, context.localId); int correlationId = context.assertSentVoteRequest(epoch, 0, 0L, 1); @@ -696,13 +694,12 @@ public void testVoteRequestTimeout() throws Exception { context.client.poll(); int retryCorrelationId = context.assertSentVoteRequest(epoch, 0, 0L, 1); - // Even though we have resent the request, we should still accept the response to - // the first request if it arrives late. 
+ // We will ignore the timed out response if it arrives late context.deliverResponse(correlationId, otherNodeId, context.voteResponse(true, Optional.empty(), 1)); context.client.poll(); - context.assertElectedLeader(epoch, context.localId); + context.assertVotedCandidate(epoch, context.localId); - // If the second request arrives later, it should have no effect + // Become leader after receiving the retry response context.deliverResponse(retryCorrelationId, otherNodeId, context.voteResponse(true, Optional.empty(), 1)); context.client.poll(); context.assertElectedLeader(epoch, context.localId); @@ -720,8 +717,7 @@ public void testHandleValidVoteRequestAsFollower() throws Exception { .build(); context.deliverRequest(context.voteRequest(epoch, otherNodeId, epoch - 1, 1)); - - context.client.poll(); + context.pollUntilResponse(); context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.empty(), true); @@ -741,8 +737,7 @@ public void testHandleVoteRequestAsFollowerWithElectedLeader() throws Exception .build(); context.deliverRequest(context.voteRequest(epoch, otherNodeId, epoch - 1, 1)); - - context.client.poll(); + context.pollUntilResponse(); context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.of(electedLeaderId), false); @@ -762,8 +757,7 @@ public void testHandleVoteRequestAsFollowerWithVotedCandidate() throws Exception .build(); context.deliverRequest(context.voteRequest(epoch, otherNodeId, epoch - 1, 1)); - - context.client.poll(); + context.pollUntilResponse(); context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.empty(), false); context.assertVotedCandidate(epoch, votedCandidateId); @@ -781,8 +775,7 @@ public void testHandleInvalidVoteRequestWithOlderEpoch() throws Exception { .build(); context.deliverRequest(context.voteRequest(epoch - 1, otherNodeId, epoch - 2, 1)); - - context.client.poll(); + context.pollUntilResponse(); context.assertSentVoteResponse(Errors.FENCED_LEADER_EPOCH, epoch, OptionalInt.empty(), false); 
context.assertUnknownLeader(epoch); @@ -801,8 +794,7 @@ public void testHandleInvalidVoteRequestAsObserver() throws Exception { .build(); context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch, 1)); - - context.client.poll(); + context.pollUntilResponse(); context.assertSentVoteResponse(Errors.INCONSISTENT_VOTER_SET, epoch, OptionalInt.empty(), false); context.assertUnknownLeader(epoch); @@ -841,7 +833,8 @@ public void testListenerCommitCallbackAfterLeaderWrite() throws Exception { // Let follower send a fetch to initialize the high watermark, // note the offset 0 would be a control message for becoming the leader context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 0L, epoch, 500)); - context.pollUntilSend(); + context.pollUntilResponse(); + context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId)); assertEquals(OptionalLong.of(0L), context.client.highWatermark()); List records = Arrays.asList("a", "b", "c"); @@ -851,14 +844,14 @@ public void testListenerCommitCallbackAfterLeaderWrite() throws Exception { // Let the follower send a fetch, it should advance the high watermark context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 1L, epoch, 500)); - context.pollUntilSend(); + context.pollUntilResponse(); + context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId)); assertEquals(OptionalLong.of(1L), context.client.highWatermark()); assertEquals(OptionalLong.empty(), context.listener.lastCommitOffset()); // Let the follower send another fetch from offset 4 context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 4L, epoch, 500)); - context.client.poll(); - assertEquals(OptionalLong.of(4L), context.client.highWatermark()); + context.pollUntil(() -> context.client.highWatermark().equals(OptionalLong.of(4L))); assertEquals(records, context.listener.commitWithLastOffset(offset)); } @@ -873,7 +866,7 @@ public void testCandidateIgnoreVoteRequestOnSameEpoch() throws Exception { 
.withVotedCandidate(leaderEpoch, localId) .build(); - context.pollUntilSend(); + context.pollUntilRequest(); context.deliverRequest(context.voteRequest(leaderEpoch, otherNodeId, leaderEpoch - 1, 1)); context.client.poll(); @@ -898,7 +891,7 @@ public void testRetryElection() throws Exception { context.assertUnknownLeader(0); context.time.sleep(2 * context.electionTimeoutMs); - context.pollUntilSend(); + context.pollUntilRequest(); context.assertVotedCandidate(epoch, context.localId); // Quorum size is two. If the other member rejects, then we need to schedule a revote. @@ -919,7 +912,7 @@ public void testRetryElection() throws Exception { // After jitter expires, we become a candidate again context.time.sleep(1); context.client.poll(); - context.pollUntilSend(); + context.pollUntilRequest(); context.assertVotedCandidate(epoch + 1, context.localId); context.assertSentVoteRequest(epoch + 1, 0, 0L, 1); } @@ -937,7 +930,7 @@ public void testInitializeAsFollowerEmptyLog() throws Exception { context.assertElectedLeader(epoch, otherNodeId); - context.pollUntilSend(); + context.pollUntilRequest(); context.assertSentFetchRequest(epoch, 0L, 0); } @@ -957,7 +950,7 @@ public void testInitializeAsFollowerNonEmptyLog() throws Exception { context.assertElectedLeader(epoch, otherNodeId); - context.pollUntilSend(); + context.pollUntilRequest(); context.assertSentFetchRequest(epoch, 1L, lastEpoch); } @@ -975,12 +968,12 @@ public void testVoterBecomeCandidateAfterFetchTimeout() throws Exception { .build(); context.assertElectedLeader(epoch, otherNodeId); - context.pollUntilSend(); + context.pollUntilRequest(); context.assertSentFetchRequest(epoch, 1L, lastEpoch); context.time.sleep(context.fetchTimeoutMs); - context.pollUntilSend(); + context.pollUntilRequest(); context.assertSentVoteRequest(epoch + 1, lastEpoch, 1L, 1); context.assertVotedCandidate(epoch + 1, context.localId); @@ -996,7 +989,7 @@ public void testInitializeObserverNoPreviousState() throws Exception { 
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); - context.pollUntilSend(); + context.pollUntilRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); assertTrue(voters.contains(fetchRequest.destinationId())); context.assertFetchRequestData(fetchRequest, 0, 0L, 0); @@ -1017,7 +1010,7 @@ public void testObserverQuorumDiscoveryFailure() throws Exception { RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); - context.pollUntilSend(); + context.pollUntilRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); assertTrue(voters.contains(fetchRequest.destinationId())); context.assertFetchRequestData(fetchRequest, 0, 0L, 0); @@ -1027,7 +1020,7 @@ public void testObserverQuorumDiscoveryFailure() throws Exception { context.client.poll(); context.time.sleep(context.retryBackoffMs); - context.pollUntilSend(); + context.pollUntilRequest(); fetchRequest = context.assertSentFetchRequest(); assertTrue(voters.contains(fetchRequest.destinationId())); @@ -1050,7 +1043,7 @@ public void testObserverSendDiscoveryFetchAfterFetchTimeout() throws Exception { RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build(); - context.pollUntilSend(); + context.pollUntilRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); assertTrue(voters.contains(fetchRequest.destinationId())); context.assertFetchRequestData(fetchRequest, 0, 0L, 0); @@ -1062,7 +1055,7 @@ public void testObserverSendDiscoveryFetchAfterFetchTimeout() throws Exception { context.assertElectedLeader(epoch, leaderId); context.time.sleep(context.fetchTimeoutMs); - context.pollUntilSend(); + context.pollUntilRequest(); fetchRequest = context.assertSentFetchRequest(); assertTrue(voters.contains(fetchRequest.destinationId())); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); @@ -1079,27 +1072,27 @@ public void testInvalidFetchRequest() 
throws Exception { context.deliverRequest(context.fetchRequest( epoch, otherNodeId, -5L, 0, 0)); - context.client.poll(); + context.pollUntilResponse(); context.assertSentFetchResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(context.localId)); context.deliverRequest(context.fetchRequest( epoch, otherNodeId, 0L, -1, 0)); - context.client.poll(); + context.pollUntilResponse(); context.assertSentFetchResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(context.localId)); context.deliverRequest(context.fetchRequest( epoch, otherNodeId, 0L, epoch + 1, 0)); - context.client.poll(); + context.pollUntilResponse(); context.assertSentFetchResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(context.localId)); context.deliverRequest(context.fetchRequest( epoch + 1, otherNodeId, 0L, 0, 0)); - context.client.poll(); + context.pollUntilResponse(); context.assertSentFetchResponse(Errors.UNKNOWN_LEADER_EPOCH, epoch, OptionalInt.of(context.localId)); context.deliverRequest(context.fetchRequest( epoch, otherNodeId, 0L, 0, -1)); - context.client.poll(); + context.pollUntilResponse(); context.assertSentFetchResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(context.localId)); } @@ -1141,17 +1134,17 @@ public void testInvalidVoteRequest() throws Exception { context.assertElectedLeader(epoch, otherNodeId); context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, 0, -5L)); - context.client.poll(); + context.pollUntilResponse(); context.assertSentVoteResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(otherNodeId), false); context.assertElectedLeader(epoch, otherNodeId); context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, -1, 0L)); - context.client.poll(); + context.pollUntilResponse(); context.assertSentVoteResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(otherNodeId), false); context.assertElectedLeader(epoch, otherNodeId); context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch + 1, 0L)); - context.client.poll(); + 
context.pollUntilResponse(); context.assertSentVoteResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(otherNodeId), false); context.assertElectedLeader(epoch, otherNodeId); } @@ -1220,7 +1213,7 @@ public void testPurgatoryFetchCompletedByFollowerTransition() throws Exception { // Now we get a BeginEpoch from the other voter and become a follower context.deliverRequest(context.beginEpochRequest(epoch + 1, voter3)); - context.client.poll(); + context.pollUntilResponse(); context.assertElectedLeader(epoch + 1, voter3); // We expect the BeginQuorumEpoch response and a failed Fetch response @@ -1246,7 +1239,7 @@ public void testFetchResponseIgnoredAfterBecomingCandidate() throws Exception { context.assertElectedLeader(epoch, otherNodeId); // Wait until we have a Fetch inflight to the leader - context.pollUntilSend(); + context.pollUntilRequest(); int fetchCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); // Now await the fetch timeout and become a candidate @@ -1280,7 +1273,7 @@ public void testFetchResponseIgnoredAfterBecomingFollowerOfDifferentLeader() thr context.assertElectedLeader(epoch, voter2); // Wait until we have a Fetch inflight to the leader - context.pollUntilSend(); + context.pollUntilRequest(); int fetchCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); // Now receive a BeginEpoch from `voter3` @@ -1316,7 +1309,7 @@ public void testVoteResponseIgnoredAfterBecomingFollower() throws Exception { context.time.sleep(context.electionTimeoutMs * 2); // Wait until the vote requests are inflight - context.pollUntilSend(); + context.pollUntilRequest(); context.assertVotedCandidate(epoch, context.localId); List voteRequests = context.collectVoteRequests(epoch, 0, 0); assertEquals(2, voteRequests.size()); @@ -1349,14 +1342,14 @@ public void testObserverLeaderRediscoveryAfterBrokerNotAvailableError() throws E context.discoverLeaderAsObserver(leaderId, epoch); - context.pollUntilSend(); + context.pollUntilRequest(); RaftRequest.Outbound 
fetchRequest1 = context.assertSentFetchRequest(); assertEquals(leaderId, fetchRequest1.destinationId()); context.assertFetchRequestData(fetchRequest1, epoch, 0L, 0); context.deliverResponse(fetchRequest1.correlationId, fetchRequest1.destinationId(), context.fetchResponse(epoch, -1, MemoryRecords.EMPTY, -1, Errors.BROKER_NOT_AVAILABLE)); - context.pollUntilSend(); + context.pollUntilRequest(); // We should retry the Fetch against the other voter since the original // voter connection will be backing off. @@ -1386,13 +1379,13 @@ public void testObserverLeaderRediscoveryAfterRequestTimeout() throws Exception context.discoverLeaderAsObserver(leaderId, epoch); - context.pollUntilSend(); + context.pollUntilRequest(); RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest(); assertEquals(leaderId, fetchRequest1.destinationId()); context.assertFetchRequestData(fetchRequest1, epoch, 0L, 0); context.time.sleep(context.requestTimeoutMs); - context.pollUntilSend(); + context.pollUntilRequest(); // We should retry the Fetch against the other voter since the original // voter connection will be backing off. 
@@ -1427,7 +1420,7 @@ public void testLeaderGracefulShutdown() throws Exception { assertFalse(shutdownFuture.isDone()); // Send EndQuorumEpoch request to the other voter - context.pollUntilSend(); + context.pollUntilRequest(); assertTrue(context.client.isShuttingDown()); assertTrue(context.client.isRunning()); context.assertSentEndQuorumEpochRequest(1, otherNodeId); @@ -1469,7 +1462,7 @@ public void testEndQuorumEpochSentBasedOnFetchOffset() throws Exception { assertTrue(context.client.isRunning()); // Send EndQuorumEpoch request to the close follower - context.pollUntilSend(); + context.pollUntilRequest(); assertTrue(context.client.isRunning()); List endQuorumRequests = context.collectEndQuorumRequests( @@ -1494,14 +1487,14 @@ public void testDescribeQuorum() throws Exception { int observerId = 3; context.deliverRequest(context.fetchRequest(epoch, observerId, 0L, 0, 0)); - context.client.poll(); + context.pollUntilResponse(); long highWatermark = 1L; context.assertSentFetchResponse(highWatermark, epoch); context.deliverRequest(DescribeQuorumRequest.singletonRequest(context.metadataPartition)); - context.client.poll(); + context.pollUntilResponse(); context.assertSentDescribeQuorumResponse(context.localId, epoch, highWatermark, Arrays.asList( @@ -1540,7 +1533,7 @@ public void testLeaderGracefulShutdownTimeout() throws Exception { assertFalse(shutdownFuture.isDone()); // Send EndQuorumEpoch request to the other vote - context.pollUntilSend(); + context.pollUntilRequest(); assertTrue(context.client.isRunning()); context.assertSentEndQuorumEpochRequest(epoch, otherNodeId); @@ -1631,7 +1624,7 @@ public void testFollowerReplication() throws Exception { .build(); context.assertElectedLeader(epoch, otherNodeId); - context.pollUntilSend(); + context.pollUntilRequest(); int fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b")); @@ -1656,7 +1649,7 @@ public void 
testEmptyRecordSetInFetchResponse() throws Exception { context.assertElectedLeader(epoch, otherNodeId); // Receive an empty fetch response - context.pollUntilSend(); + context.pollUntilRequest(); int fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); FetchResponseData fetchResponse = context.fetchResponse(epoch, otherNodeId, MemoryRecords.EMPTY, 0L, Errors.NONE); @@ -1666,7 +1659,7 @@ public void testEmptyRecordSetInFetchResponse() throws Exception { assertEquals(OptionalLong.of(0L), context.client.highWatermark()); // Receive some records in the next poll, but do not advance high watermark - context.pollUntilSend(); + context.pollUntilRequest(); Records records = context.buildBatch(0L, epoch, Arrays.asList("a", "b")); fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0); fetchResponse = context.fetchResponse(epoch, otherNodeId, @@ -1677,7 +1670,7 @@ public void testEmptyRecordSetInFetchResponse() throws Exception { assertEquals(OptionalLong.of(0L), context.client.highWatermark()); // The next fetch response is empty, but should still advance the high watermark - context.pollUntilSend(); + context.pollUntilRequest(); fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 2L, epoch); fetchResponse = context.fetchResponse(epoch, otherNodeId, MemoryRecords.EMPTY, 2L, Errors.NONE); @@ -1704,7 +1697,7 @@ public void testFetchShouldBeTreatedAsLeaderEndorsement() throws Exception { context.time.sleep(context.electionTimeoutMs); context.expectAndGrantVotes(epoch); - context.pollUntilSend(); + context.pollUntilRequest(); // We send BeginEpoch, but it gets lost and the destination finds the leader through the Fetch API context.assertSentBeginQuorumEpochRequest(epoch, 1); @@ -1720,12 +1713,12 @@ public void testFetchShouldBeTreatedAsLeaderEndorsement() throws Exception { context.client.poll(); - List sentMessages = context.channel.drainSendQueue(); + List sentMessages = context.channel.drainSendQueue(); assertEquals(0, 
sentMessages.size()); } @Test - public void testLeaderAppendSingleMemberQuorum() throws IOException { + public void testLeaderAppendSingleMemberQuorum() throws Exception { int localId = 0; Set voters = Collections.singleton(localId); @@ -1752,8 +1745,7 @@ public void testLeaderAppendSingleMemberQuorum() throws IOException { // Now try reading it int otherNodeId = 1; context.deliverRequest(context.fetchRequest(1, otherNodeId, 0L, 0, 500)); - - context.client.poll(); + context.pollUntilResponse(); MemoryRecords fetchedRecords = context.assertSentFetchResponse(Errors.NONE, 1, OptionalInt.of(context.localId)); List batches = Utils.toList(fetchedRecords.batchIterator()); @@ -1796,7 +1788,7 @@ public void testFollowerLogReconciliation() throws Exception { context.assertElectedLeader(epoch, otherNodeId); assertEquals(3L, context.log.endOffset().offset); - context.pollUntilSend(); + context.pollUntilRequest(); int correlationId = context.assertSentFetchRequest(epoch, 3L, lastEpoch); @@ -1873,7 +1865,7 @@ public void testClusterAuthorizationFailedInFetch() throws Exception { context.assertElectedLeader(epoch, otherNodeId); - context.pollUntilSend(); + context.pollUntilRequest(); int correlationId = context.assertSentFetchRequest(epoch, 0, 0); FetchResponseData response = new FetchResponseData() @@ -1899,7 +1891,7 @@ public void testClusterAuthorizationFailedInBeginQuorumEpoch() throws Exception context.time.sleep(context.electionTimeoutMs); context.expectAndGrantVotes(epoch); - context.pollUntilSend(); + context.pollUntilRequest(); int correlationId = context.assertSentBeginQuorumEpochRequest(epoch, 1); BeginQuorumEpochResponseData response = new BeginQuorumEpochResponseData() .setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code()); @@ -1921,7 +1913,7 @@ public void testClusterAuthorizationFailedInVote() throws Exception { // Sleep a little to ensure that we become a candidate context.time.sleep(context.electionTimeoutMs * 2); - context.pollUntilSend(); + 
context.pollUntilRequest(); context.assertVotedCandidate(epoch, context.localId); int correlationId = context.assertSentVoteRequest(epoch, 0, 0L, 1); @@ -1942,7 +1934,7 @@ public void testClusterAuthorizationFailedInEndQuorumEpoch() throws Exception { RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch); context.client.shutdown(5000); - context.pollUntilSend(); + context.pollUntilRequest(); int correlationId = context.assertSentEndQuorumEpochRequest(epoch, otherNodeId); EndQuorumEpochResponseData response = new EndQuorumEpochResponseData() @@ -2067,7 +2059,7 @@ public void testHandleCommitCallbackFiresAfterFollowerHighWatermarkAdvances() th assertEquals(OptionalLong.empty(), context.client.highWatermark()); // Poll for our first fetch request - context.pollUntilSend(); + context.pollUntilRequest(); RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest(); assertTrue(voters.contains(fetchRequest.destinationId())); context.assertFetchRequestData(fetchRequest, epoch, 0L, 0); @@ -2085,7 +2077,7 @@ public void testHandleCommitCallbackFiresAfterFollowerHighWatermarkAdvances() th assertEquals(OptionalInt.empty(), context.listener.currentClaimedEpoch()); // Now look for the next fetch request - context.pollUntilSend(); + context.pollUntilRequest(); fetchRequest = context.assertSentFetchRequest(); assertTrue(voters.contains(fetchRequest.destinationId())); context.assertFetchRequestData(fetchRequest, epoch, 3L, 3); @@ -2131,7 +2123,7 @@ public void testHandleCommitCallbackFiresInVotedState() throws Exception { // Now we receive a vote request which transitions us to the 'voted' state int candidateEpoch = epoch + 1; context.deliverRequest(context.voteRequest(candidateEpoch, otherNodeId, epoch, 10L)); - context.client.poll(); + context.pollUntilResponse(); context.assertVotedCandidate(candidateEpoch, otherNodeId); assertEquals(OptionalLong.of(10L), context.client.highWatermark()); @@ -2166,12 +2158,13 @@ public void 
testHandleCommitCallbackFiresInCandidateState() throws Exception { // Start off as the leader and receive a fetch to initialize the high watermark context.becomeLeader(); context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 9L, epoch, 500)); - context.client.poll(); + context.pollUntilResponse(); assertEquals(OptionalLong.of(9L), context.client.highWatermark()); + context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(context.localId)); // Now we receive a vote request which transitions us to the 'unattached' state context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch, 9L)); - context.client.poll(); + context.pollUntilResponse(); context.assertUnknownLeader(epoch + 1); assertEquals(OptionalLong.of(9L), context.client.highWatermark()); diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLog.java b/raft/src/test/java/org/apache/kafka/raft/MockLog.java index 1e8c1b53fcb72..f0be8dd80d5a4 100644 --- a/raft/src/test/java/org/apache/kafka/raft/MockLog.java +++ b/raft/src/test/java/org/apache/kafka/raft/MockLog.java @@ -255,6 +255,13 @@ public LogAppendInfo appendAsFollower(Records records) { long baseOffset = endOffset().offset; long lastOffset = baseOffset; for (RecordBatch batch : records.batches()) { + Optional lastEntry = lastEntry(); + + if (lastEntry.isPresent() && batch.baseOffset() != lastEntry.get().offset + 1) { + throw new IllegalArgumentException("Illegal append at offset " + batch.baseOffset() + + " with current end offset of " + endOffset().offset); + } + List entries = buildEntries(batch, Record::offset); appendBatch(new LogBatch(batch.partitionLeaderEpoch(), batch.isControlBatch(), entries)); lastOffset = entries.get(entries.size() - 1).offset; diff --git a/raft/src/test/java/org/apache/kafka/raft/MockMessageQueue.java b/raft/src/test/java/org/apache/kafka/raft/MockMessageQueue.java new file mode 100644 index 0000000000000..d1c73e0f9b065 --- /dev/null +++ 
b/raft/src/test/java/org/apache/kafka/raft/MockMessageQueue.java @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft; + +import java.util.ArrayDeque; +import java.util.OptionalLong; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Mocked implementation which does not block in {@link #poll(long)}.
+ */ +public class MockMessageQueue implements RaftMessageQueue { + private final Queue messages = new ArrayDeque<>(); + private final AtomicBoolean wakeupRequested = new AtomicBoolean(false); + private final AtomicLong lastPollTimeout = new AtomicLong(-1); + + @Override + public RaftMessage poll(long timeoutMs) { + wakeupRequested.set(false); + lastPollTimeout.set(timeoutMs); + return messages.poll(); + } + + @Override + public void offer(RaftMessage message) { + messages.offer(message); + } + + public OptionalLong lastPollTimeoutMs() { + long lastTimeoutMs = lastPollTimeout.get(); + if (lastTimeoutMs < 0) { + return OptionalLong.empty(); + } else { + return OptionalLong.of(lastTimeoutMs); + } + } + + public boolean wakeupRequested() { + return wakeupRequested.get(); + } + + @Override + public boolean isEmpty() { + return messages.isEmpty(); + } + + @Override + public void wakeup() { + wakeupRequested.set(true); + } +} diff --git a/raft/src/test/java/org/apache/kafka/raft/MockNetworkChannel.java b/raft/src/test/java/org/apache/kafka/raft/MockNetworkChannel.java index da14df307336a..3f08ff58256d3 100644 --- a/raft/src/test/java/org/apache/kafka/raft/MockNetworkChannel.java +++ b/raft/src/test/java/org/apache/kafka/raft/MockNetworkChannel.java @@ -24,22 +24,17 @@ import java.util.Iterator; import java.util.List; import java.util.Map; -import java.util.OptionalLong; -import java.util.concurrent.atomic.AtomicBoolean; +import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicLong; public class MockNetworkChannel implements NetworkChannel { - private final AtomicInteger requestIdCounter; - private final AtomicBoolean wakeupRequested = new AtomicBoolean(false); - private final AtomicLong lastReceiveTimeout = new AtomicLong(-1); + private final AtomicInteger correlationIdCounter; + private final List sendQueue = new ArrayList<>(); + private final Map awaitingResponse = new HashMap<>(); + private final Map 
addressCache = new HashMap<>(); - private List sendQueue = new ArrayList<>(); - private List receiveQueue = new ArrayList<>(); - private Map addressCache = new HashMap<>(); - - public MockNetworkChannel(AtomicInteger requestIdCounter) { - this.requestIdCounter = requestIdCounter; + public MockNetworkChannel(AtomicInteger correlationIdCounter) { + this.correlationIdCounter = correlationIdCounter; } public MockNetworkChannel() { @@ -48,46 +43,16 @@ public MockNetworkChannel() { @Override public int newCorrelationId() { - return requestIdCounter.getAndIncrement(); - } - - @Override - public void send(RaftMessage message) { - if (message instanceof RaftRequest.Outbound) { - RaftRequest.Outbound request = (RaftRequest.Outbound) message; - if (!addressCache.containsKey(request.destinationId())) { - throw new IllegalArgumentException("Attempted to send to destination " + - request.destinationId() + ", but its address is not yet known"); - } - } - sendQueue.add(message); + return correlationIdCounter.getAndIncrement(); } @Override - public List receive(long timeoutMs) { - wakeupRequested.set(false); - lastReceiveTimeout.set(timeoutMs); - List messages = receiveQueue; - receiveQueue = new ArrayList<>(); - return messages; - } - - OptionalLong lastReceiveTimeout() { - long timeout = lastReceiveTimeout.get(); - if (timeout < 0) { - return OptionalLong.empty(); - } else { - return OptionalLong.of(timeout); + public void send(RaftRequest.Outbound request) { + if (!addressCache.containsKey(request.destinationId())) { + throw new IllegalArgumentException("Attempted to send to destination " + + request.destinationId() + ", but its address is not yet known"); } - } - - boolean wakeupRequested() { - return wakeupRequested.get(); - } - - @Override - public void wakeup() { - wakeupRequested.set(true); + sendQueue.add(request); } @Override @@ -95,19 +60,17 @@ public void updateEndpoint(int id, InetSocketAddress address) { addressCache.put(id, address); } - public List drainSendQueue() 
{ - List messages = sendQueue; - sendQueue = new ArrayList<>(); - return messages; + public List drainSendQueue() { + return drainSentRequests(Optional.empty()); } - public List drainSentRequests(ApiKeys apiKey) { + public List drainSentRequests(Optional apiKeyFilter) { List requests = new ArrayList<>(); - Iterator iterator = sendQueue.iterator(); + Iterator iterator = sendQueue.iterator(); while (iterator.hasNext()) { - RaftMessage message = iterator.next(); - if (message instanceof RaftRequest.Outbound && message.data().apiKey() == apiKey.id) { - RaftRequest.Outbound request = (RaftRequest.Outbound) message; + RaftRequest.Outbound request = iterator.next(); + if (!apiKeyFilter.isPresent() || request.data().apiKey() == apiKeyFilter.get().id) { + awaitingResponse.put(request.correlationId, request); requests.add(request); iterator.remove(); } @@ -115,27 +78,17 @@ public List drainSentRequests(ApiKeys apiKey) { return requests; } - public List drainSentResponses(ApiKeys apiKey) { - List responses = new ArrayList<>(); - Iterator iterator = sendQueue.iterator(); - while (iterator.hasNext()) { - RaftMessage message = iterator.next(); - if (message instanceof RaftResponse.Outbound && message.data().apiKey() == apiKey.id) { - RaftResponse.Outbound response = (RaftResponse.Outbound) message; - responses.add(response); - iterator.remove(); - } - } - return responses; - } - - public boolean hasSentMessages() { + public boolean hasSentRequests() { return !sendQueue.isEmpty(); } - public void mockReceive(RaftMessage message) { - receiveQueue.add(message); + public void mockReceive(RaftResponse.Inbound response) { + RaftRequest.Outbound request = awaitingResponse.get(response.correlationId); + if (request == null) { + throw new IllegalStateException("Received response for a request which is not being awaited"); + } + request.completion.complete(response); } } diff --git a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java 
b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java index 5ed8061a41b89..2b69ac8049f1f 100644 --- a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java +++ b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java @@ -52,6 +52,7 @@ import org.apache.kafka.common.utils.Utils; import org.apache.kafka.raft.internals.BatchBuilder; import org.apache.kafka.raft.internals.StringSerde; +import org.apache.kafka.test.TestCondition; import org.apache.kafka.test.TestUtils; import org.mockito.Mockito; @@ -63,6 +64,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Optional; @@ -96,11 +98,13 @@ public final class RaftClientTestContext { final Metrics metrics; public final MockLog log; final MockNetworkChannel channel; + final MockMessageQueue messageQueue; final MockTime time; final MockListener listener; - final Set voters; + private final List sentResponses = new ArrayList<>(); + public static final class Builder { static final int DEFAULT_ELECTION_TIMEOUT_MS = 10000; @@ -113,6 +117,7 @@ public static final class Builder { private static final int RETRY_BACKOFF_MS = 50; private static final int DEFAULT_APPEND_LINGER_MS = 0; + private final MockMessageQueue messageQueue = new MockMessageQueue(); private final MockTime time = new MockTime(); private final QuorumStateStore quorumStateStore = new MockQuorumStateStore(); private final Random random = Mockito.spy(new Random(1)); @@ -194,6 +199,7 @@ public RaftClientTestContext build() throws IOException { KafkaRaftClient client = new KafkaRaftClient<>( STRING_SERDE, channel, + messageQueue, log, quorum, memoryPool, @@ -218,6 +224,7 @@ public RaftClientTestContext build() throws IOException { client, log, channel, + messageQueue, time, quorumStateStore, quorum, @@ -233,6 +240,7 @@ private RaftClientTestContext( KafkaRaftClient client, MockLog log, 
MockNetworkChannel channel, + MockMessageQueue messageQueue, MockTime time, QuorumStateStore quorumStateStore, QuorumState quorum, @@ -244,6 +252,7 @@ private RaftClientTestContext( this.client = client; this.log = log; this.channel = channel; + this.messageQueue = messageQueue; this.time = time; this.quorumStateStore = quorumStateStore; this.quorum = quorum; @@ -326,7 +335,7 @@ LeaderAndEpoch currentLeaderAndEpoch() { void expectAndGrantVotes( int epoch ) throws Exception { - pollUntilSend(); + pollUntilRequest(); List voteRequests = collectVoteRequests(epoch, log.lastFetchedEpoch(), log.endOffset().offset); @@ -343,7 +352,7 @@ void expectAndGrantVotes( void expectBeginEpoch( int epoch ) throws Exception { - pollUntilSend(); + pollUntilRequest(); for (RaftRequest.Outbound request : collectBeginEpochRequests(epoch)) { BeginQuorumEpochResponseData beginEpochResponse = beginEpochResponse(epoch, localId); deliverResponse(request.correlationId, request.destinationId(), beginEpochResponse); @@ -351,11 +360,19 @@ void expectBeginEpoch( client.poll(); } - void pollUntilSend() throws InterruptedException { + void pollUntil(TestCondition condition) throws InterruptedException { TestUtils.waitForCondition(() -> { client.poll(); - return channel.hasSentMessages(); - }, 5000, "Condition failed to be satisfied before timeout"); + return condition.conditionMet(); + }, 5000, "Condition failed to be satisfied before timeout"); + } + + void pollUntilResponse() throws InterruptedException { + pollUntil(() -> !sentResponses.isEmpty()); + } + + void pollUntilRequest() throws InterruptedException { + pollUntil(channel::hasSentRequests); } void assertVotedCandidate(int epoch, int leaderId) throws IOException { @@ -375,14 +392,16 @@ void assertResignedLeader(int epoch, int leaderId) throws IOException { assertEquals(ElectionState.withElectedLeader(epoch, leaderId, voters), quorumStateStore.readElectionState()); } - int assertSentDescribeQuorumResponse(int leaderId, - int
leaderEpoch, - long highWatermark, - List voterStates, - List observerStates) { - List sentMessages = channel.drainSendQueue(); + int assertSentDescribeQuorumResponse( + int leaderId, + int leaderEpoch, + long highWatermark, + List voterStates, + List observerStates + ) { + List sentMessages = drainSentResponses(ApiKeys.DESCRIBE_QUORUM); assertEquals(1, sentMessages.size()); - RaftMessage raftMessage = sentMessages.get(0); + RaftResponse.Outbound raftMessage = sentMessages.get(0); assertTrue( raftMessage.data() instanceof DescribeQuorumResponseData, "Unexpected request type " + raftMessage.data()); @@ -412,7 +431,7 @@ void assertSentVoteResponse( OptionalInt leaderId, boolean voteGranted ) { - List sentMessages = channel.drainSentResponses(ApiKeys.VOTE); + List sentMessages = drainSentResponses(ApiKeys.VOTE); assertEquals(1, sentMessages.size()); RaftMessage raftMessage = sentMessages.get(0); assertTrue(raftMessage.data() instanceof VoteResponseData); @@ -449,8 +468,16 @@ List collectVoteRequests( } void deliverRequest(ApiMessage request) { - RaftRequest.Inbound message = new RaftRequest.Inbound(channel.newCorrelationId(), request, time.milliseconds()); - channel.mockReceive(message); + RaftRequest.Inbound inboundRequest = new RaftRequest.Inbound( + channel.newCorrelationId(), request, time.milliseconds()); + inboundRequest.completion.whenComplete((response, exception) -> { + if (exception != null) { + throw new RuntimeException(exception); + } else { + sentResponses.add(response); + } + }); + client.handle(inboundRequest); } void deliverResponse(int correlationId, int sourceId, ApiMessage response) { @@ -463,12 +490,27 @@ int assertSentBeginQuorumEpochRequest(int epoch, int numBeginEpochRequests) { return requests.get(0).correlationId; } + private List drainSentResponses( + ApiKeys apiKey + ) { + List res = new ArrayList<>(); + Iterator iterator = sentResponses.iterator(); + while (iterator.hasNext()) { + RaftResponse.Outbound response = iterator.next(); + if 
(response.data.apiKey() == apiKey.id) { + res.add(response); + iterator.remove(); + } + } + return res; + } + void assertSentBeginQuorumEpochResponse( Errors partitionError, int epoch, OptionalInt leaderId ) { - List sentMessages = channel.drainSentResponses(ApiKeys.BEGIN_QUORUM_EPOCH); + List sentMessages = drainSentResponses(ApiKeys.BEGIN_QUORUM_EPOCH); assertEquals(1, sentMessages.size()); RaftMessage raftMessage = sentMessages.get(0); assertTrue(raftMessage.data() instanceof BeginQuorumEpochResponseData); @@ -495,7 +537,7 @@ void assertSentEndQuorumEpochResponse( int epoch, OptionalInt leaderId ) { - List sentMessages = channel.drainSentResponses(ApiKeys.END_QUORUM_EPOCH); + List sentMessages = drainSentResponses(ApiKeys.END_QUORUM_EPOCH); assertEquals(1, sentMessages.size()); RaftMessage raftMessage = sentMessages.get(0); assertTrue(raftMessage.data() instanceof EndQuorumEpochResponseData); @@ -511,7 +553,7 @@ void assertSentEndQuorumEpochResponse( } RaftRequest.Outbound assertSentFetchRequest() { - List sentRequests = channel.drainSentRequests(ApiKeys.FETCH); + List sentRequests = channel.drainSentRequests(Optional.of(ApiKeys.FETCH)); assertEquals(1, sentRequests.size()); return sentRequests.get(0); } @@ -521,8 +563,10 @@ int assertSentFetchRequest( long fetchOffset, int lastFetchedEpoch ) { - List sentMessages = channel.drainSendQueue(); + List sentMessages = channel.drainSendQueue(); assertEquals(1, sentMessages.size()); + + // TODO: Use more specific type RaftMessage raftMessage = sentMessages.get(0); assertFetchRequestData(raftMessage, epoch, fetchOffset, lastFetchedEpoch); return raftMessage.correlationId(); @@ -559,8 +603,7 @@ void buildFollowerSet( // The lagging follower fetches first deliverRequest(fetchRequest(1, laggingFollower, 0L, 0, 0)); - client.poll(); - + pollUntilResponse(); assertSentFetchResponse(0L, epoch); // Append some records, so that the close follower will be able to advance further. 
@@ -569,8 +612,7 @@ void buildFollowerSet( deliverRequest(fetchRequest(epoch, closeFollower, 1L, epoch, 0)); - client.poll(); - + pollUntilResponse(); assertSentFetchResponse(1L, epoch); } @@ -600,7 +642,7 @@ void discoverLeaderAsObserver( int leaderId, int epoch ) throws Exception { - pollUntilSend(); + pollUntilRequest(); RaftRequest.Outbound fetchRequest = assertSentFetchRequest(); assertTrue(voters.contains(fetchRequest.destinationId())); assertFetchRequestData(fetchRequest, 0, 0L, 0); @@ -613,7 +655,7 @@ void discoverLeaderAsObserver( private List collectBeginEpochRequests(int epoch) { List requests = new ArrayList<>(); - for (RaftRequest.Outbound raftRequest : channel.drainSentRequests(ApiKeys.BEGIN_QUORUM_EPOCH)) { + for (RaftRequest.Outbound raftRequest : channel.drainSentRequests(Optional.of(ApiKeys.BEGIN_QUORUM_EPOCH))) { assertTrue(raftRequest.data() instanceof BeginQuorumEpochRequestData); BeginQuorumEpochRequestData request = (BeginQuorumEpochRequestData) raftRequest.data(); @@ -628,7 +670,7 @@ private List collectBeginEpochRequests(int epoch) { } private FetchResponseData.FetchablePartitionResponse assertSentPartitionResponse() { - List sentMessages = channel.drainSentResponses(ApiKeys.FETCH); + List sentMessages = drainSentResponses(ApiKeys.FETCH); assertEquals( 1, sentMessages.size(), "Found unexpected sent messages " + sentMessages); RaftResponse.Outbound raftMessage = sentMessages.get(0); diff --git a/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java b/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java index cd4c6a0d7adf8..78f65dbfe0c0e 100644 --- a/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java @@ -19,8 +19,8 @@ import org.apache.kafka.common.TopicPartition; import org.apache.kafka.common.memory.MemoryPool; import org.apache.kafka.common.metrics.Metrics; -import org.apache.kafka.common.protocol.Writable; import 
org.apache.kafka.common.protocol.Readable; +import org.apache.kafka.common.protocol.Writable; import org.apache.kafka.common.protocol.types.Type; import org.apache.kafka.common.utils.LogContext; import org.apache.kafka.common.utils.MockTime; @@ -62,9 +62,9 @@ public class RaftEventSimulationTest { private static final TopicPartition METADATA_PARTITION = new TopicPartition("__cluster_metadata", 0); private static final int ELECTION_TIMEOUT_MS = 1000; private static final int ELECTION_JITTER_MS = 100; - private static final int FETCH_TIMEOUT_MS = 5000; + private static final int FETCH_TIMEOUT_MS = 3000; private static final int RETRY_BACKOFF_MS = 50; - private static final int REQUEST_TIMEOUT_MS = 500; + private static final int REQUEST_TIMEOUT_MS = 3000; private static final int FETCH_MAX_WAIT_MS = 100; private static final int LINGER_MS = 0; @@ -115,7 +115,7 @@ public void testElectionAfterLeaderFailureQuorumSizeThree() { @Test public void testElectionAfterLeaderFailureQuorumSizeThreeAndTwoObservers() { - testElectionAfterLeaderFailure(new QuorumConfig(3, 2)); + testElectionAfterLeaderFailure(new QuorumConfig(3, 1)); } @Test @@ -748,6 +748,7 @@ void start(int nodeId) { LogContext logContext = new LogContext("[Node " + nodeId + "] "); PersistentState persistentState = nodes.get(nodeId); MockNetworkChannel channel = new MockNetworkChannel(correlationIdCounter); + MockMessageQueue messageQueue = new MockMessageQueue(); QuorumState quorum = new QuorumState(nodeId, voters(), ELECTION_TIMEOUT_MS, FETCH_TIMEOUT_MS, persistentState.store, time, logContext, random); Metrics metrics = new Metrics(time); @@ -766,6 +767,7 @@ void start(int nodeId) { KafkaRaftClient client = new KafkaRaftClient<>( serde, channel, + messageQueue, persistentState.log, quorum, memoryPool, @@ -781,8 +783,16 @@ void start(int nodeId) { logContext, random ); - RaftNode node = new RaftNode(nodeId, client, persistentState.log, channel, - persistentState.store, quorum, logContext); + RaftNode node = new 
RaftNode( + nodeId, + client, + persistentState.log, + channel, + messageQueue, + persistentState.store, + quorum, + logContext + ); node.initialize(); running.put(nodeId, node); } @@ -793,22 +803,27 @@ private static class RaftNode { final KafkaRaftClient client; final MockLog log; final MockNetworkChannel channel; + final MockMessageQueue messageQueue; final MockQuorumStateStore store; final QuorumState quorum; final LogContext logContext; final ReplicatedCounter counter; - private RaftNode(int nodeId, - KafkaRaftClient client, - MockLog log, - MockNetworkChannel channel, - MockQuorumStateStore store, - QuorumState quorum, - LogContext logContext) { + private RaftNode( + int nodeId, + KafkaRaftClient client, + MockLog log, + MockNetworkChannel channel, + MockMessageQueue messageQueue, + MockQuorumStateStore store, + QuorumState quorum, + LogContext logContext + ) { this.nodeId = nodeId; this.client = client; this.log = log; this.channel = channel; + this.messageQueue = messageQueue; this.store = store; this.quorum = quorum; this.logContext = logContext; @@ -826,9 +841,11 @@ void initialize() { void poll() { try { - client.poll(); - } catch (IOException e) { - throw new RuntimeException(e); + do { + client.poll(); + } while (client.isRunning() && !messageQueue.isEmpty()); + } catch (Exception e) { + throw new RuntimeException("Uncaught exception during poll of node " + nodeId, e); } } } @@ -1025,7 +1042,7 @@ private int parseSequenceNumber(ByteBuffer value) { return (int) Type.INT32.read(value); } - private void assertCommittedData(int nodeId, KafkaRaftClient manager, MockLog log) { + private void assertCommittedData(int nodeId, KafkaRaftClient manager, MockLog log) { OptionalLong highWatermark = manager.highWatermark(); if (!highWatermark.isPresent()) { // We cannot do validation if the current high watermark is unknown @@ -1070,6 +1087,9 @@ private MessageRouter(Cluster cluster) { } void deliver(int senderId, RaftRequest.Outbound outbound) { + if 
(!filters.get(senderId).acceptOutbound(outbound)) + return; + int correlationId = outbound.correlationId(); int destinationId = outbound.destinationId(); RaftRequest.Inbound inbound = new RaftRequest.Inbound(correlationId, outbound.data(), @@ -1079,9 +1099,15 @@ void deliver(int senderId, RaftRequest.Outbound outbound) { return; cluster.nodeIfRunning(destinationId).ifPresent(node -> { - MockNetworkChannel destChannel = node.channel; inflight.put(correlationId, new InflightRequest(correlationId, senderId, destinationId)); - destChannel.mockReceive(inbound); + + inbound.completion.whenComplete((response, exception) -> { + if (response != null && filters.get(destinationId).acceptOutbound(response)) { + deliver(destinationId, response); + } + }); + + node.client.handle(inbound); }); } @@ -1089,6 +1115,7 @@ void deliver(int senderId, RaftResponse.Outbound outbound) { int correlationId = outbound.correlationId(); RaftResponse.Inbound inbound = new RaftResponse.Inbound(correlationId, outbound.data(), senderId); InflightRequest inflightRequest = inflight.remove(correlationId); + if (!filters.get(inflightRequest.sourceId).acceptInbound(inbound)) return; @@ -1097,26 +1124,10 @@ void deliver(int senderId, RaftResponse.Outbound outbound) { }); } - void deliver(int senderId, RaftMessage message) { - if (!filters.get(senderId).acceptOutbound(message)) { - return; - } else if (message instanceof RaftRequest.Outbound) { - deliver(senderId, (RaftRequest.Outbound) message); - } else if (message instanceof RaftResponse.Outbound) { - deliver(senderId, (RaftResponse.Outbound) message); - } else { - throw new AssertionError("Illegal message type sent by node " + message); - } - } - void filter(int nodeId, NetworkFilter filter) { filters.put(nodeId, filter); } - void deliverRandom() { - cluster.forRandomRunning(this::deliverTo); - } - void deliverTo(RaftNode node) { node.channel.drainSendQueue().forEach(msg -> deliver(node.nodeId, msg)); } diff --git 
a/raft/src/test/java/org/apache/kafka/raft/RequestManagerTest.java b/raft/src/test/java/org/apache/kafka/raft/RequestManagerTest.java index b516ce57ecc2a..e6e2f7cf0a6a5 100644 --- a/raft/src/test/java/org/apache/kafka/raft/RequestManagerTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/RequestManagerTest.java @@ -92,7 +92,7 @@ public void testSuccessfulResponse() { long correlationId = 1; connectionState.onRequestSent(correlationId, time.milliseconds()); assertFalse(connectionState.isReady(time.milliseconds())); - connectionState.onResponseReceived(correlationId, time.milliseconds()); + connectionState.onResponseReceived(correlationId); assertTrue(connectionState.isReady(time.milliseconds())); } @@ -109,7 +109,7 @@ public void testIgnoreUnexpectedResponse() { long correlationId = 1; connectionState.onRequestSent(correlationId, time.milliseconds()); assertFalse(connectionState.isReady(time.milliseconds())); - connectionState.onResponseReceived(correlationId + 1, time.milliseconds()); + connectionState.onResponseReceived(correlationId + 1); assertFalse(connectionState.isReady(time.milliseconds())); } @@ -121,7 +121,6 @@ public void testRequestTimeout() { requestTimeoutMs, random); - RequestManager.ConnectionState connectionState = cache.getOrCreate(1); long correlationId = 1; diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/BlockingMessageQueueTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/BlockingMessageQueueTest.java new file mode 100644 index 0000000000000..4aafe5b12866a --- /dev/null +++ b/raft/src/test/java/org/apache/kafka/raft/internals/BlockingMessageQueueTest.java @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.kafka.raft.internals; + +import org.apache.kafka.raft.RaftMessage; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class BlockingMessageQueueTest { + + @Test + public void testOfferAndPoll() { + BlockingMessageQueue queue = new BlockingMessageQueue(); + assertTrue(queue.isEmpty()); + assertNull(queue.poll(0)); + + RaftMessage message1 = Mockito.mock(RaftMessage.class); + queue.offer(message1); + assertFalse(queue.isEmpty()); + assertEquals(message1, queue.poll(0)); + assertTrue(queue.isEmpty()); + + RaftMessage message2 = Mockito.mock(RaftMessage.class); + RaftMessage message3 = Mockito.mock(RaftMessage.class); + queue.offer(message2); + queue.offer(message3); + assertFalse(queue.isEmpty()); + assertEquals(message2, queue.poll(0)); + assertEquals(message3, queue.poll(0)); + + } + + @Test + public void testWakeupFromPoll() { + BlockingMessageQueue queue = new BlockingMessageQueue(); + queue.wakeup(); + assertNull(queue.poll(Long.MAX_VALUE)); + } + +} \ No newline at end of file From 065514e8e0caa9a5203a8364be11fde766b6ff0b Mon Sep 17 00:00:00 2001 From: Jason Gustafson Date: Fri, 11 Dec 2020 11:38:25 -0800 
Subject: [PATCH 02/10] Add batch send api to `InterBrokerSendThread` --- .../kafka/common/InterBrokerSendThread.scala | 6 ++- .../TransactionMarkerChannelManager.scala | 2 +- ...=> BrokerToControllerChannelManager.scala} | 51 ++++++++----------- .../main/scala/kafka/server/KafkaServer.scala | 4 +- 4 files changed, 29 insertions(+), 34 deletions(-) rename core/src/main/scala/kafka/server/{BrokerToControllerChannelManagerImpl.scala => BrokerToControllerChannelManager.scala} (90%) diff --git a/core/src/main/scala/kafka/common/InterBrokerSendThread.scala b/core/src/main/scala/kafka/common/InterBrokerSendThread.scala index 7327695af680f..ff1cd1fb44188 100644 --- a/core/src/main/scala/kafka/common/InterBrokerSendThread.scala +++ b/core/src/main/scala/kafka/common/InterBrokerSendThread.scala @@ -54,7 +54,11 @@ class InterBrokerSendThread( } def sendRequest(request: RequestAndCompletionHandler): Unit = { - inboundQueue.offer(request) + sendRequests(Seq(request)) + } + + def sendRequests(requests: Iterable[RequestAndCompletionHandler]): Unit = { + inboundQueue.addAll(requests.asJavaCollection) wakeup() } diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala index b000a890d32c7..ed0d2390c0b3d 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala @@ -156,7 +156,7 @@ class TransactionMarkerChannelManager( newGauge("LogAppendRetryQueueSize", () => txnLogAppendRetryQueue.size) override def doWork(): Unit = { - drainQueuedTransactionMarkers().foreach(super.sendRequest) + super.sendRequests(drainQueuedTransactionMarkers()) super.doWork() } diff --git a/core/src/main/scala/kafka/server/BrokerToControllerChannelManagerImpl.scala b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala 
similarity index 90% rename from core/src/main/scala/kafka/server/BrokerToControllerChannelManagerImpl.scala rename to core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala index aa212f1c0d60b..1546658db3e67 100644 --- a/core/src/main/scala/kafka/server/BrokerToControllerChannelManagerImpl.scala +++ b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala @@ -33,26 +33,6 @@ import org.apache.kafka.common.utils.{LogContext, Time} import scala.collection.mutable import scala.jdk.CollectionConverters._ -trait BrokerToControllerChannelManager { - - /** - * Send request to the controller. - * - * @param request The request to be sent. - * @param callback Request completion callback. - * @param retryDeadlineMs The retry deadline which will only be checked after receiving a response. - * This means that in the worst case, the total timeout would be twice of - * the configured timeout. - */ - def sendRequest(request: AbstractRequest.Builder[_ <: AbstractRequest], - callback: ControllerRequestCompletionHandler, - retryDeadlineMs: Long): Unit - - def start(): Unit - - def shutdown(): Unit -} - /** * This class manages the connection between a broker and the controller. It runs a single * {@link BrokerToControllerRequestThread} which uses the broker's metadata cache as its own metadata to find @@ -60,22 +40,24 @@ trait BrokerToControllerChannelManager { * The maximum number of in-flight requests are set to one to ensure orderly response from the controller, therefore * care must be taken to not block on outstanding requests for too long. 
*/ -class BrokerToControllerChannelManagerImpl(metadataCache: kafka.server.MetadataCache, - time: Time, - metrics: Metrics, - config: KafkaConfig, - channelName: String, - threadNamePrefix: Option[String] = None) extends BrokerToControllerChannelManager with Logging { +class BrokerToControllerChannelManager( + metadataCache: kafka.server.MetadataCache, + time: Time, + metrics: Metrics, + config: KafkaConfig, + channelName: String, + threadNamePrefix: Option[String] = None +) extends Logging { private val requestQueue = new LinkedBlockingDeque[BrokerToControllerQueueItem] private val logContext = new LogContext(s"[broker-${config.brokerId}-to-controller] ") private val manualMetadataUpdater = new ManualMetadataUpdater() private val requestThread = newRequestThread - override def start(): Unit = { + def start(): Unit = { requestThread.start() } - override def shutdown(): Unit = { + def shutdown(): Unit = { requestThread.shutdown() requestThread.awaitShutdown() info(s"Broker to controller channel manager for $channelName shutdown") @@ -135,7 +117,16 @@ class BrokerToControllerChannelManagerImpl(metadataCache: kafka.server.MetadataC brokerToControllerListenerName, time, threadName) } - override def sendRequest(request: AbstractRequest.Builder[_ <: AbstractRequest], + /** + * Send request to the controller. + * + * @param request The request to be sent. + * @param callback Request completion callback. + * @param retryDeadlineMs The retry deadline which will only be checked after receiving a response. + * This means that in the worst case, the total timeout would be twice of + * the configured timeout. 
+ */ + def sendRequest(request: AbstractRequest.Builder[_ <: AbstractRequest], callback: ControllerRequestCompletionHandler, retryDeadlineMs: Long): Unit = { requestQueue.put(BrokerToControllerQueueItem(request, callback, retryDeadlineMs)) @@ -207,7 +198,7 @@ class BrokerToControllerRequestThread(networkClient: KafkaClient, override def doWork(): Unit = { if (activeController.isDefined) { - generateRequests().foreach(sendRequest) + super.sendRequests(generateRequests()) super.doWork() } else { debug("Controller isn't cached, looking for local metadata changes") diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala index 1e7bb19c6dc4e..a270e803a5b43 100755 --- a/core/src/main/scala/kafka/server/KafkaServer.scala +++ b/core/src/main/scala/kafka/server/KafkaServer.scala @@ -309,7 +309,7 @@ class KafkaServer(val config: KafkaConfig, time: Time = Time.SYSTEM, threadNameP socketServer.startup(startProcessingRequests = false) /* start replica manager */ - alterIsrChannelManager = new BrokerToControllerChannelManagerImpl( + alterIsrChannelManager = new BrokerToControllerChannelManager( metadataCache, time, metrics, config, "alterIsrChannel", threadNamePrefix) replicaManager = createReplicaManager(isShuttingDown) replicaManager.startup() @@ -332,7 +332,7 @@ class KafkaServer(val config: KafkaConfig, time: Time = Time.SYSTEM, threadNameP var forwardingManager: ForwardingManager = null if (config.metadataQuorumEnabled) { /* start forwarding manager */ - forwardingChannelManager = new BrokerToControllerChannelManagerImpl(metadataCache, time, metrics, + forwardingChannelManager = new BrokerToControllerChannelManager(metadataCache, time, metrics, config, "forwardingChannel", threadNamePrefix) forwardingChannelManager.start() forwardingManager = new ForwardingManager(forwardingChannelManager, time, config.requestTimeoutMs.longValue()) From 6eb4ef252cb271b89563f78b4de64bb649094d02 Mon Sep 17 00:00:00 2001 From: Jason 
Gustafson Date: Fri, 11 Dec 2020 14:31:26 -0800 Subject: [PATCH 03/10] Walk back addition of queue in `InterBrokerSendThread` --- .../kafka/common/InterBrokerSendThread.scala | 49 +++++----- .../TransactionMarkerChannelManager.scala | 21 ++-- .../kafka/raft/KafkaNetworkChannel.scala | 42 +++++++- .../BrokerToControllerChannelManager.scala | 49 +++++----- .../common/InterBrokerSendThreadTest.scala | 47 +++++---- .../BrokerToControllerRequestThreadTest.scala | 96 +++++++++---------- ...ransactionCoordinatorConcurrencyTest.scala | 2 +- .../TransactionMarkerChannelManagerTest.scala | 20 ++-- .../apache/kafka/raft/RaftMessageQueue.java | 3 +- 9 files changed, 192 insertions(+), 137 deletions(-) diff --git a/core/src/main/scala/kafka/common/InterBrokerSendThread.scala b/core/src/main/scala/kafka/common/InterBrokerSendThread.scala index ff1cd1fb44188..0ef84265b5201 100644 --- a/core/src/main/scala/kafka/common/InterBrokerSendThread.scala +++ b/core/src/main/scala/kafka/common/InterBrokerSendThread.scala @@ -17,7 +17,6 @@ package kafka.common import java.util.Map.Entry -import java.util.concurrent.ConcurrentLinkedQueue import java.util.{ArrayDeque, ArrayList, Collection, Collections, HashMap, Iterator} import kafka.utils.ShutdownableThread @@ -33,7 +32,7 @@ import scala.jdk.CollectionConverters._ /** * Class for inter-broker send thread that utilize a non-blocking network client. 
*/ -class InterBrokerSendThread( +abstract class InterBrokerSendThread( name: String, networkClient: KafkaClient, requestTimeoutMs: Int, @@ -41,9 +40,10 @@ class InterBrokerSendThread( isInterruptible: Boolean = true ) extends ShutdownableThread(name, isInterruptible) { - private val inboundQueue = new ConcurrentLinkedQueue[RequestAndCompletionHandler]() private val unsentRequests = new UnsentRequests + def generateRequests(): Iterable[RequestAndCompletionHandler] + def hasUnsentRequests: Boolean = unsentRequests.iterator().hasNext override def shutdown(): Unit = { @@ -53,35 +53,25 @@ class InterBrokerSendThread( awaitShutdown() } - def sendRequest(request: RequestAndCompletionHandler): Unit = { - sendRequests(Seq(request)) - } - - def sendRequests(requests: Iterable[RequestAndCompletionHandler]): Unit = { - inboundQueue.addAll(requests.asJavaCollection) - wakeup() - } - - private def drainInboundQueue(): Unit = { - while (!inboundQueue.isEmpty) { - val request = inboundQueue.poll() - val completionHandler = request.handler + private def drainGeneratedRequests(): Unit = { + generateRequests().foreach { request => unsentRequests.put(request.destination, networkClient.newClientRequest( request.destination.idString, request.request, - time.milliseconds(), + request.creationTimeMs, true, requestTimeoutMs, - completionHandler)) + request.handler + )) } } - override def doWork(): Unit = { + protected def pollOnce(maxTimeoutMs: Long): Unit = { try { + drainGeneratedRequests() var now = time.milliseconds() - drainInboundQueue() - val timeout = sendRequests(now) + val timeout = sendRequests(now, maxTimeoutMs) networkClient.poll(timeout, now) now = time.milliseconds() checkDisconnects(now) @@ -99,8 +89,12 @@ class InterBrokerSendThread( } } - private def sendRequests(now: Long): Long = { - var pollTimeout = Long.MaxValue + override def doWork(): Unit = { + pollOnce(Long.MaxValue) + } + + private def sendRequests(now: Long, maxTimeoutMs: Long): Long = { + var pollTimeout = 
maxTimeoutMs for (node <- unsentRequests.nodes.asScala) { val requestIterator = unsentRequests.requestIterator(node) while (requestIterator.hasNext) { @@ -157,9 +151,12 @@ class InterBrokerSendThread( def wakeup(): Unit = networkClient.wakeup() } -case class RequestAndCompletionHandler(destination: Node, - request: AbstractRequest.Builder[_ <: AbstractRequest], - handler: RequestCompletionHandler) +case class RequestAndCompletionHandler( + creationTimeMs: Long, + destination: Node, + request: AbstractRequest.Builder[_ <: AbstractRequest], + handler: RequestCompletionHandler +) private class UnsentRequests { private val unsent = new HashMap[Node, ArrayDeque[ClientRequest]] diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala index ed0d2390c0b3d..ac582f37cf9d1 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala @@ -24,8 +24,8 @@ import kafka.api.KAFKA_2_8_IV0 import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler} import kafka.metrics.KafkaMetricsGroup import kafka.server.{KafkaConfig, MetadataCache} -import kafka.utils.{CoreUtils, Logging} import kafka.utils.Implicits._ +import kafka.utils.{CoreUtils, Logging} import org.apache.kafka.clients._ import org.apache.kafka.common.metrics.Metrics import org.apache.kafka.common.network._ @@ -36,8 +36,8 @@ import org.apache.kafka.common.security.JaasContext import org.apache.kafka.common.utils.{LogContext, Time} import org.apache.kafka.common.{Node, Reconfigurable, TopicPartition} -import scala.jdk.CollectionConverters._ import scala.collection.{concurrent, immutable} +import scala.jdk.CollectionConverters._ object TransactionMarkerChannelManager { def apply(config: KafkaConfig, @@ -155,11 +155,6 @@ class 
TransactionMarkerChannelManager( newGauge("UnknownDestinationQueueSize", () => markersQueueForUnknownBroker.totalNumMarkers) newGauge("LogAppendRetryQueueSize", () => txnLogAppendRetryQueue.size) - override def doWork(): Unit = { - super.sendRequests(drainQueuedTransactionMarkers()) - super.doWork() - } - override def shutdown(): Unit = { super.shutdown() markersQueuePerBroker.clear() @@ -195,7 +190,7 @@ class TransactionMarkerChannelManager( } } - private[transaction] def drainQueuedTransactionMarkers(): Iterable[RequestAndCompletionHandler] = { + override def generateRequests(): Iterable[RequestAndCompletionHandler] = { retryLogAppends() val txnIdAndMarkerEntries: java.util.List[TxnIdAndMarkerEntry] = new util.ArrayList[TxnIdAndMarkerEntry]() markersQueueForUnknownBroker.forEachTxnTopicPartition { case (_, queue) => @@ -213,6 +208,7 @@ class TransactionMarkerChannelManager( addTxnMarkersToBrokerQueue(transactionalId, producerId, producerEpoch, txnResult, coordinatorEpoch, topicPartitions) } + val currentTimeMs = time.milliseconds() markersQueuePerBroker.values.map { brokerRequestQueue => val txnIdAndMarkerEntries = new util.ArrayList[TxnIdAndMarkerEntry]() brokerRequestQueue.forEachTxnTopicPartition { case (_, queue) => @@ -222,7 +218,14 @@ class TransactionMarkerChannelManager( }.filter { case (_, entries) => !entries.isEmpty }.map { case (node, entries) => val markersToSend = entries.asScala.map(_.txnMarkerEntry).asJava val requestCompletionHandler = new TransactionMarkerRequestCompletionHandler(node.id, txnStateManager, this, entries) - RequestAndCompletionHandler(node, new WriteTxnMarkersRequest.Builder(writeTxnMarkersRequestVersion, markersToSend), requestCompletionHandler) + val request = new WriteTxnMarkersRequest.Builder(writeTxnMarkersRequestVersion, markersToSend) + + RequestAndCompletionHandler( + currentTimeMs, + node, + request, + requestCompletionHandler + ) } } diff --git a/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala 
b/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala index 53ed7dcceab19..435167208f70a 100644 --- a/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala +++ b/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala @@ -17,6 +17,7 @@ package kafka.raft import java.net.InetSocketAddress +import java.util.concurrent.ConcurrentLinkedQueue import java.util.concurrent.atomic.AtomicInteger import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler} @@ -53,6 +54,41 @@ object KafkaNetworkChannel { } +private[raft] class RaftSendThread( + name: String, + networkClient: KafkaClient, + requestTimeoutMs: Int, + time: Time, + isInterruptible: Boolean = true +) extends InterBrokerSendThread( + name, + networkClient, + requestTimeoutMs, + time, + isInterruptible +) { + private val queue = new ConcurrentLinkedQueue[RequestAndCompletionHandler]() + + def generateRequests(): Iterable[RequestAndCompletionHandler] = { + val buffer = mutable.Buffer[RequestAndCompletionHandler]() + while (true) { + val request = queue.poll() + if (request == null) { + return buffer + } else { + buffer += request + } + } + buffer + } + + def sendRequest(request: RequestAndCompletionHandler): Unit = { + queue.add(request) + } + +} + + class KafkaNetworkChannel( time: Time, client: KafkaClient, @@ -65,7 +101,7 @@ class KafkaNetworkChannel( private val correlationIdCounter = new AtomicInteger(0) private val endpoints = mutable.HashMap.empty[Int, Node] - private val requestThread = new InterBrokerSendThread( + private val requestThread = new RaftSendThread( name = "raft-outbound-request-thread", networkClient = client, requestTimeoutMs = requestTimeoutMs, @@ -97,6 +133,7 @@ class KafkaNetworkChannel( endpoints.get(request.destinationId) match { case Some(node) => requestThread.sendRequest(RequestAndCompletionHandler( + request.createdTimeMs, destination = node, request = buildRequest(request.data), handler = onComplete @@ -107,7 +144,8 @@ class KafkaNetworkChannel( } } - def 
pollOnce(): Unit = { + // Visible for testing + private[raft] def pollOnce(): Unit = { requestThread.doWork() } diff --git a/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala index 1546658db3e67..976f85c20bac2 100644 --- a/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala +++ b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala @@ -30,7 +30,6 @@ import org.apache.kafka.common.requests.AbstractRequest import org.apache.kafka.common.security.JaasContext import org.apache.kafka.common.utils.{LogContext, Time} -import scala.collection.mutable import scala.jdk.CollectionConverters._ /** @@ -48,7 +47,6 @@ class BrokerToControllerChannelManager( channelName: String, threadNamePrefix: Option[String] = None ) extends Logging { - private val requestQueue = new LinkedBlockingDeque[BrokerToControllerQueueItem] private val logContext = new LogContext(s"[broker-${config.brokerId}-to-controller] ") private val manualMetadataUpdater = new ManualMetadataUpdater() private val requestThread = newRequestThread @@ -113,7 +111,7 @@ class BrokerToControllerChannelManager( case Some(name) => s"$name:broker-${config.brokerId}-to-controller-send-thread" } - new BrokerToControllerRequestThread(networkClient, manualMetadataUpdater, requestQueue, metadataCache, config, + new BrokerToControllerRequestThread(networkClient, manualMetadataUpdater, metadataCache, config, brokerToControllerListenerName, time, threadName) } @@ -126,11 +124,16 @@ class BrokerToControllerChannelManager( * This means that in the worst case, the total timeout would be twice of * the configured timeout. 
*/ - def sendRequest(request: AbstractRequest.Builder[_ <: AbstractRequest], - callback: ControllerRequestCompletionHandler, - retryDeadlineMs: Long): Unit = { - requestQueue.put(BrokerToControllerQueueItem(request, callback, retryDeadlineMs)) - requestThread.wakeup() + def sendRequest( + request: AbstractRequest.Builder[_ <: AbstractRequest], + callback: ControllerRequestCompletionHandler, + retryDeadlineMs: Long + ): Unit = { + requestThread.enqueue(BrokerToControllerQueueItem( + request, + callback, + retryDeadlineMs + )) } } @@ -149,7 +152,6 @@ case class BrokerToControllerQueueItem(request: AbstractRequest.Builder[_ <: Abs class BrokerToControllerRequestThread(networkClient: KafkaClient, metadataUpdater: ManualMetadataUpdater, - requestQueue: LinkedBlockingDeque[BrokerToControllerQueueItem], metadataCache: kafka.server.MetadataCache, config: KafkaConfig, listenerName: ListenerName, @@ -157,21 +159,25 @@ class BrokerToControllerRequestThread(networkClient: KafkaClient, threadName: String) extends InterBrokerSendThread(threadName, networkClient, config.controllerSocketTimeoutMs, time, isInterruptible = false) { + private val requestQueue = new LinkedBlockingDeque[BrokerToControllerQueueItem]() private var activeController: Option[Node] = None - def generateRequests(): Iterable[RequestAndCompletionHandler] = { - val requestsToSend = new mutable.Queue[RequestAndCompletionHandler] - val topRequest = requestQueue.poll() - if (topRequest != null) { - val request = RequestAndCompletionHandler( + def enqueue(request: BrokerToControllerQueueItem): Unit = { + requestQueue.add(request) + if (activeController.isDefined) { + wakeup() + } + } + + override def generateRequests(): Iterable[RequestAndCompletionHandler] = { + Option(requestQueue.poll()).map { queueItem => + RequestAndCompletionHandler( + time.milliseconds(), activeController.get, - topRequest.request, - handleResponse(topRequest) + queueItem.request, + handleResponse(queueItem) ) - - 
requestsToSend.enqueue(request) } - requestsToSend } private[server] def handleResponse(request: BrokerToControllerQueueItem)(response: ClientResponse): Unit = { @@ -198,8 +204,7 @@ class BrokerToControllerRequestThread(networkClient: KafkaClient, override def doWork(): Unit = { if (activeController.isDefined) { - super.sendRequests(generateRequests()) - super.doWork() + super.pollOnce(Long.MaxValue) } else { debug("Controller isn't cached, looking for local metadata changes") val controllerOpt = metadataCache.getControllerId.flatMap(metadataCache.getAliveBroker) @@ -211,7 +216,7 @@ class BrokerToControllerRequestThread(networkClient: KafkaClient, } else { // need to backoff to avoid tight loops debug("No controller defined in metadata cache, retrying after backoff") - backoff() + super.pollOnce(100) } } } diff --git a/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala b/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala index 0b621d7332a6a..d716cbd00f280 100644 --- a/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala +++ b/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala @@ -27,15 +27,34 @@ import org.apache.kafka.common.requests.AbstractRequest import org.easymock.EasyMock import org.junit.{Assert, Test} +import scala.collection.mutable + class InterBrokerSendThreadTest { private val time = new MockTime() private val networkClient: NetworkClient = EasyMock.createMock(classOf[NetworkClient]) private val completionHandler = new StubCompletionHandler private val requestTimeoutMs = 1000 + class TestInterBrokerSendThread( + ) extends InterBrokerSendThread("name", networkClient, requestTimeoutMs, time) { + private val queue = mutable.Queue[RequestAndCompletionHandler]() + + def enqueue(request: RequestAndCompletionHandler): Unit = { + queue += request + } + + override def generateRequests(): Iterable[RequestAndCompletionHandler] = { + if (queue.isEmpty) { + None + } else { + Some(queue.dequeue()) + } + } + } + 
@Test def shouldNotSendAnythingWhenNoRequests(): Unit = { - val sendThread = new InterBrokerSendThread("name", networkClient, requestTimeoutMs, time) + val sendThread = new TestInterBrokerSendThread() // poll is always called but there should be no further invocations on NetworkClient EasyMock.expect(networkClient.poll(EasyMock.anyLong(), EasyMock.anyLong())) @@ -53,13 +72,11 @@ class InterBrokerSendThreadTest { def shouldCreateClientRequestAndSendWhenNodeIsReady(): Unit = { val request = new StubRequestBuilder() val node = new Node(1, "", 8080) - val handler = RequestAndCompletionHandler(node, request, completionHandler) - val sendThread = new InterBrokerSendThread("name", networkClient, requestTimeoutMs, time) + val handler = RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler) + val sendThread = new TestInterBrokerSendThread() val clientRequest = new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, handler.handler) - EasyMock.expect(networkClient.wakeup()) - EasyMock.expect(networkClient.newClientRequest( EasyMock.eq("1"), EasyMock.same(handler.request), @@ -79,7 +96,7 @@ class InterBrokerSendThreadTest { EasyMock.replay(networkClient) - sendThread.sendRequest(handler) + sendThread.enqueue(handler) sendThread.doWork() EasyMock.verify(networkClient) @@ -90,13 +107,11 @@ class InterBrokerSendThreadTest { def shouldCallCompletionHandlerWithDisconnectedResponseWhenNodeNotReady(): Unit = { val request = new StubRequestBuilder val node = new Node(1, "", 8080) - val handler = RequestAndCompletionHandler(node, request, completionHandler) - val sendThread = new InterBrokerSendThread("name", networkClient, requestTimeoutMs, time) + val handler = RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler) + val sendThread = new TestInterBrokerSendThread() val clientRequest = new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, handler.handler) - EasyMock.expect(networkClient.wakeup()) - 
EasyMock.expect(networkClient.newClientRequest( EasyMock.eq("1"), EasyMock.same(handler.request), @@ -123,7 +138,7 @@ class InterBrokerSendThreadTest { EasyMock.replay(networkClient) - sendThread.sendRequest(handler) + sendThread.enqueue(handler) sendThread.doWork() EasyMock.verify(networkClient) @@ -134,8 +149,8 @@ class InterBrokerSendThreadTest { def testFailingExpiredRequests(): Unit = { val request = new StubRequestBuilder() val node = new Node(1, "", 8080) - val handler = RequestAndCompletionHandler(node, request, completionHandler) - val sendThread = new InterBrokerSendThread("name", networkClient, requestTimeoutMs, time) + val handler = RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler) + val sendThread = new TestInterBrokerSendThread() val clientRequest = new ClientRequest("dest", request, @@ -147,12 +162,10 @@ class InterBrokerSendThreadTest { handler.handler) time.sleep(1500) - EasyMock.expect(networkClient.wakeup()) - EasyMock.expect(networkClient.newClientRequest( EasyMock.eq("1"), EasyMock.same(handler.request), - EasyMock.eq(time.milliseconds()), + EasyMock.eq(handler.creationTimeMs), EasyMock.eq(true), EasyMock.eq(requestTimeoutMs), EasyMock.same(handler.handler))) @@ -174,7 +187,7 @@ class InterBrokerSendThreadTest { EasyMock.replay(networkClient) - sendThread.sendRequest(handler) + sendThread.enqueue(handler) sendThread.doWork() EasyMock.verify(networkClient) diff --git a/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala b/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala index 9ef9d7c33bb8e..ee2326605f7db 100644 --- a/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala +++ b/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala @@ -17,19 +17,20 @@ package kafka.server -import java.util.concurrent.{CountDownLatch, LinkedBlockingDeque, TimeUnit} import java.util.Collections +import java.util.concurrent.atomic.AtomicBoolean + 
import kafka.cluster.{Broker, EndPoint} import kafka.utils.TestUtils import org.apache.kafka.clients.{ClientResponse, ManualMetadataUpdater, Metadata, MockClient} import org.apache.kafka.common.feature.Features import org.apache.kafka.common.feature.Features.emptySupportedFeatures -import org.apache.kafka.common.utils.{MockTime, SystemTime} import org.apache.kafka.common.message.MetadataRequestData import org.apache.kafka.common.network.ListenerName import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.requests.{AbstractRequest, MetadataRequest, MetadataResponse, RequestTestUtils} import org.apache.kafka.common.security.auth.SecurityProtocol +import org.apache.kafka.common.utils.{MockTime, SystemTime} import org.junit.Assert.{assertEquals, assertFalse, assertTrue} import org.junit.Test import org.mockito.Mockito._ @@ -46,7 +47,6 @@ class BrokerToControllerRequestThreadTest { val metadata = mock(classOf[Metadata]) val mockClient = new MockClient(time, metadata) - val requestQueue = new LinkedBlockingDeque[BrokerToControllerQueueItem]() val metadataCache = mock(classOf[MetadataCache]) val listenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT) val activeController = new Broker(controllerId, @@ -57,22 +57,24 @@ class BrokerToControllerRequestThreadTest { when(metadataCache.getAliveBroker(controllerId)).thenReturn(Some(activeController)) val expectedResponse = RequestTestUtils.metadataUpdateWith(2, Collections.singletonMap("a", 2)) - val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), requestQueue, metadataCache, + val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), metadataCache, config, listenerName, time, "") mockClient.prepareResponse(expectedResponse) - val responseLatch = new CountDownLatch(1) + val completionHandler = new TestRequestCompletionHandler(Some(expectedResponse)) val queueItem = BrokerToControllerQueueItem( new 
MetadataRequest.Builder(new MetadataRequestData()), - new TestRequestCompletionHandler(expectedResponse, responseLatch), - Long.MaxValue) - requestQueue.put(queueItem) + completionHandler, + Long.MaxValue + ) + + testRequestThread.enqueue(queueItem) // initialize to the controller testRequestThread.doWork() // send and process the request testRequestThread.doWork() - assertTrue(responseLatch.await(10, TimeUnit.SECONDS)) + assertTrue(completionHandler.completed.get()) } @Test @@ -86,7 +88,6 @@ class BrokerToControllerRequestThreadTest { val metadata = mock(classOf[Metadata]) val mockClient = new MockClient(time, metadata) - val requestQueue = new LinkedBlockingDeque[BrokerToControllerQueueItem]() val metadataCache = mock(classOf[MetadataCache]) val listenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT) val oldController = new Broker(oldControllerId, @@ -102,34 +103,31 @@ class BrokerToControllerRequestThreadTest { val expectedResponse = RequestTestUtils.metadataUpdateWith(3, Collections.singletonMap("a", 2)) val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), - requestQueue, metadataCache, config, listenerName, time, "") - - val responseLatch = new CountDownLatch(1) + metadataCache, config, listenerName, time, "") + val completionHandler = new TestRequestCompletionHandler(Some(expectedResponse)) val queueItem = BrokerToControllerQueueItem( new MetadataRequest.Builder(new MetadataRequestData()), - new TestRequestCompletionHandler(expectedResponse, responseLatch), - Long.MaxValue) - requestQueue.put(queueItem) + completionHandler, + Long.MaxValue + ) + + testRequestThread.enqueue(queueItem) mockClient.prepareResponse(expectedResponse) // initialize the thread with oldController testRequestThread.doWork() - // assert queue correctness - assertFalse(requestQueue.isEmpty) - assertEquals(1, requestQueue.size()) - assertEquals(queueItem, requestQueue.peek()) + 
assertFalse(completionHandler.completed.get()) + // disconnect the node mockClient.setUnreachable(oldControllerNode, time.milliseconds() + 5000) // verify that the client closed the connection to the faulty controller testRequestThread.doWork() - assertFalse(requestQueue.isEmpty) - assertEquals(1, requestQueue.size()) // should connect to the new controller testRequestThread.doWork() // should send the request and process the response testRequestThread.doWork() - assertTrue(responseLatch.await(10, TimeUnit.SECONDS)) + assertTrue(completionHandler.completed.get()) } @Test @@ -142,7 +140,6 @@ class BrokerToControllerRequestThreadTest { val metadata = mock(classOf[Metadata]) val mockClient = new MockClient(time, metadata) - val requestQueue = new LinkedBlockingDeque[BrokerToControllerQueueItem]() val metadataCache = mock(classOf[MetadataCache]) val listenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT) val oldController = new Broker(oldControllerId, @@ -159,16 +156,17 @@ class BrokerToControllerRequestThreadTest { Collections.singletonMap("a", Errors.NOT_CONTROLLER), Collections.singletonMap("a", 2)) val expectedResponse = RequestTestUtils.metadataUpdateWith(3, Collections.singletonMap("a", 2)) - val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), requestQueue, metadataCache, + val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), metadataCache, config, listenerName, time, "") - val responseLatch = new CountDownLatch(1) + val completionHandler = new TestRequestCompletionHandler(Some(expectedResponse)) val queueItem = BrokerToControllerQueueItem( new MetadataRequest.Builder(new MetadataRequestData() .setAllowAutoTopicCreation(true)), - new TestRequestCompletionHandler(expectedResponse, responseLatch), - Long.MaxValue) - requestQueue.put(queueItem) + completionHandler, + Long.MaxValue + ) + testRequestThread.enqueue(queueItem) // initialize to the 
controller testRequestThread.doWork() // send and process the request @@ -183,7 +181,7 @@ class BrokerToControllerRequestThreadTest { mockClient.prepareResponse(expectedResponse) testRequestThread.doWork() - assertTrue(responseLatch.await(10, TimeUnit.SECONDS)) + assertTrue(completionHandler.completed.get()) } @Test @@ -195,7 +193,6 @@ class BrokerToControllerRequestThreadTest { val metadata = mock(classOf[Metadata]) val mockClient = new MockClient(time, metadata) - val requestQueue = new LinkedBlockingDeque[BrokerToControllerQueueItem]() val metadataCache = mock(classOf[MetadataCache]) val listenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT) val controller = new Broker(controllerId, @@ -208,21 +205,19 @@ class BrokerToControllerRequestThreadTest { val responseWithNotControllerError = RequestTestUtils.metadataUpdateWith("cluster1", 2, Collections.singletonMap("a", Errors.NOT_CONTROLLER), Collections.singletonMap("a", 2)) - val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), requestQueue, metadataCache, + val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), metadataCache, config, listenerName, time, "") - val responseLatch = new CountDownLatch(1) val requestTimeout = config.requestTimeoutMs.longValue() + val completionHandler = new TestRequestCompletionHandler() val queueItem = BrokerToControllerQueueItem( new MetadataRequest.Builder(new MetadataRequestData() - .setAllowAutoTopicCreation(true)), new ControllerRequestCompletionHandler { - override def onComplete(response: ClientResponse): Unit = {} + .setAllowAutoTopicCreation(true)), + completionHandler, + requestTimeout + time.milliseconds() + ) - override def onTimeout(): Unit = { - responseLatch.countDown() - } - }, requestTimeout + time.milliseconds()) - requestQueue.put(queueItem) + testRequestThread.enqueue(queueItem) // initialize to the controller testRequestThread.doWork() @@ -237,21 
+232,24 @@ class BrokerToControllerRequestThreadTest { testRequestThread.doWork() - // The queued item should be timed out, instead of - // re-enqueue by NOT_CONTROLLER error. - assertEquals(0, requestQueue.size()) - - assertTrue(responseLatch.await(10, TimeUnit.SECONDS)) + assertTrue(completionHandler.timedOut.get()) } - class TestRequestCompletionHandler(expectedResponse: MetadataResponse, - responseLatch: CountDownLatch) extends ControllerRequestCompletionHandler { + class TestRequestCompletionHandler( + expectedResponse: Option[MetadataResponse] = None + ) extends ControllerRequestCompletionHandler { + val completed: AtomicBoolean = new AtomicBoolean(false) + val timedOut: AtomicBoolean = new AtomicBoolean(false) + override def onComplete(response: ClientResponse): Unit = { - assertEquals(expectedResponse, response.responseBody()) - responseLatch.countDown() + expectedResponse.foreach { expected => + assertEquals(expected, response.responseBody()) + } + completed.set(true) } override def onTimeout(): Unit = { + timedOut.set(true) } } } diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala index 9043893a4d49e..3788cb1d65325 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionCoordinatorConcurrencyTest.scala @@ -385,7 +385,7 @@ class TransactionCoordinatorConcurrencyTest extends AbstractCoordinatorConcurren new WriteTxnMarkersResponse(pidErrorMap) } synchronized { - txnMarkerChannelManager.drainQueuedTransactionMarkers().foreach { requestAndHandler => + txnMarkerChannelManager.generateRequests().foreach { requestAndHandler => val request = requestAndHandler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build() val response = createResponse(request) 
requestAndHandler.handler.onComplete(new ClientResponse(new RequestHeader(ApiKeys.PRODUCE, 0, "client", 1), diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala index 5c82a5613007a..1746147457fc2 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala @@ -132,7 +132,7 @@ class TransactionMarkerChannelManagerTest { response) TestUtils.waitUntilTrue(() => { - val requests = channelManager.drainQueuedTransactionMarkers() + val requests = channelManager.generateRequests() if (requests.nonEmpty) { assertEquals(1, requests.size) val request = requests.head @@ -156,7 +156,7 @@ class TransactionMarkerChannelManagerTest { @Test def shouldGenerateEmptyMapWhenNoRequestsOutstanding(): Unit = { - assertTrue(channelManager.drainQueuedTransactionMarkers().isEmpty) + assertTrue(channelManager.generateRequests().isEmpty) } @Test @@ -194,12 +194,12 @@ class TransactionMarkerChannelManagerTest { val expectedBroker2Request = new WriteTxnMarkersRequest.Builder(ApiKeys.WRITE_TXN_MARKERS.latestVersion(), asList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, txnResult, asList(partition2)))).build() - val requests: Map[Node, WriteTxnMarkersRequest] = channelManager.drainQueuedTransactionMarkers().map { handler => + val requests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().map { handler => (handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build()) }.toMap assertEquals(Map(broker1 -> expectedBroker1Request, broker2 -> expectedBroker2Request), requests) - assertTrue(channelManager.drainQueuedTransactionMarkers().isEmpty) + assertTrue(channelManager.generateRequests().isEmpty) } @Test @@ 
-270,13 +270,13 @@ class TransactionMarkerChannelManagerTest { val expectedBroker2Request = new WriteTxnMarkersRequest.Builder(ApiKeys.WRITE_TXN_MARKERS.latestVersion(), asList(new WriteTxnMarkersRequest.TxnMarkerEntry(producerId1, producerEpoch, coordinatorEpoch, txnResult, asList(partition2)))).build() - val firstDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.drainQueuedTransactionMarkers().map { handler => + val firstDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().map { handler => (handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build()) }.toMap assertEquals(Map(broker2 -> expectedBroker2Request), firstDrainedRequests) - val secondDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.drainQueuedTransactionMarkers().map { handler => + val secondDrainedRequests: Map[Node, WriteTxnMarkersRequest] = channelManager.generateRequests().map { handler => (handler.destination, handler.request.asInstanceOf[WriteTxnMarkersRequest.Builder].build()) }.toMap @@ -354,7 +354,7 @@ class TransactionMarkerChannelManagerTest { channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnTransitionMetadata2) - val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.drainQueuedTransactionMarkers() + val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests() val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE)) for (requestAndHandler <- requestAndHandlers) { @@ -401,7 +401,7 @@ class TransactionMarkerChannelManagerTest { channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnTransitionMetadata2) - val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.drainQueuedTransactionMarkers() + val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests() val response = new 
WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE)) for (requestAndHandler <- requestAndHandlers) { @@ -450,7 +450,7 @@ class TransactionMarkerChannelManagerTest { channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, txnMetadata2, txnTransitionMetadata2) - val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.drainQueuedTransactionMarkers() + val requestAndHandlers: Iterable[RequestAndCompletionHandler] = channelManager.generateRequests() val response = new WriteTxnMarkersResponse(createPidErrorMap(Errors.NONE)) for (requestAndHandler <- requestAndHandlers) { @@ -459,7 +459,7 @@ class TransactionMarkerChannelManagerTest { } // call this again so that append log will be retried - channelManager.drainQueuedTransactionMarkers() + channelManager.generateRequests() EasyMock.verify(txnStateManager) diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java b/raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java index ea46f5ed0cf17..dee765570d901 100644 --- a/raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java +++ b/raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java @@ -29,7 +29,8 @@ public interface RaftMessageQueue { * Block for the arrival of a new message. 
* * @param timeoutMs timeout in milliseconds to wait for a new event - * @return the event or null if the timeout was reached + * @return the event or null if either the timeout was reached or there was + * a call to {@link #wakeup()} before any events became available */ RaftMessage poll(long timeoutMs); From d0fd77f98411e1e28e113b04155c16ab42847879 Mon Sep 17 00:00:00 2001 From: Jason Gustafson Date: Mon, 14 Dec 2020 10:21:47 -0800 Subject: [PATCH 04/10] Use sentinel `RaftMessage` for wakeup --- .../raft/internals/BlockingMessageQueue.java | 44 +++++++++---------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java b/raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java index 9fe99f82080b1..5d8f384a10f39 100644 --- a/raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java +++ b/raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java @@ -17,6 +17,7 @@ package org.apache.kafka.raft.internals; import org.apache.kafka.common.errors.InterruptException; +import org.apache.kafka.common.protocol.ApiMessage; import org.apache.kafka.raft.RaftMessage; import org.apache.kafka.raft.RaftMessageQueue; @@ -26,28 +27,39 @@ import java.util.concurrent.atomic.AtomicInteger; public class BlockingMessageQueue implements RaftMessageQueue { - private final BlockingQueue queue = new LinkedBlockingQueue<>(); + private static final RaftMessage WAKEUP_MESSAGE = new RaftMessage() { + @Override + public int correlationId() { + return 0; + } + + @Override + public ApiMessage data() { + return null; + } + }; + + private final BlockingQueue queue = new LinkedBlockingQueue<>(); private final AtomicInteger size = new AtomicInteger(0); @Override public RaftMessage poll(long timeoutMs) { try { - RaftEvent event = queue.poll(timeoutMs, TimeUnit.MILLISECONDS); - if (event instanceof MessageReceived) { - size.decrementAndGet(); - return 
((MessageReceived) event).message; - } else { + RaftMessage message = queue.poll(timeoutMs, TimeUnit.MILLISECONDS); + if (message == null || message == WAKEUP_MESSAGE) { return null; + } else { + size.decrementAndGet(); + return message; } } catch (InterruptedException e) { throw new InterruptException(e); } - } @Override public void offer(RaftMessage message) { - queue.add(new MessageReceived(message)); + queue.add(message); size.incrementAndGet(); } @@ -58,21 +70,7 @@ public boolean isEmpty() { @Override public void wakeup() { - queue.add(Wakeup.INSTANCE); - } - - public interface RaftEvent { - } - - static final class MessageReceived implements RaftEvent { - private final RaftMessage message; - private MessageReceived(RaftMessage message) { - this.message = message; - } - } - - static final class Wakeup implements RaftEvent { - public static final Wakeup INSTANCE = new Wakeup(); + queue.add(WAKEUP_MESSAGE); } } From 53cf918db43b1178751fed1eaae20378c89dd289 Mon Sep 17 00:00:00 2001 From: Jason Gustafson Date: Tue, 15 Dec 2020 10:29:45 -0800 Subject: [PATCH 05/10] Fix timeout logic in `BrokerToControllerChannelManager` --- .../BrokerToControllerChannelManager.scala | 37 +++++++------- .../BrokerToControllerRequestThreadTest.scala | 51 +++++++++++++++---- .../apache/kafka/raft/KafkaRaftClient.java | 4 +- .../apache/kafka/raft/RaftMessageQueue.java | 4 +- .../raft/internals/BlockingMessageQueue.java | 2 +- .../apache/kafka/raft/MockMessageQueue.java | 2 +- .../internals/BlockingMessageQueueTest.java | 6 +-- 7 files changed, 69 insertions(+), 37 deletions(-) diff --git a/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala index 976f85c20bac2..13c24c2ba25e2 100644 --- a/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala +++ b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala @@ -17,7 +17,7 @@ package kafka.server -import 
java.util.concurrent.{LinkedBlockingDeque, TimeUnit} +import java.util.concurrent.LinkedBlockingDeque import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler} import kafka.utils.Logging @@ -34,7 +34,7 @@ import scala.jdk.CollectionConverters._ /** * This class manages the connection between a broker and the controller. It runs a single - * {@link BrokerToControllerRequestThread} which uses the broker's metadata cache as its own metadata to find + * [[BrokerToControllerRequestThread]] which uses the broker's metadata cache as its own metadata to find * and connect to the controller. The channel is async and runs the network connection in the background. * The maximum number of in-flight requests are set to one to ensure orderly response from the controller, therefore * care must be taken to not block on outstanding requests for too long. @@ -170,20 +170,27 @@ class BrokerToControllerRequestThread(networkClient: KafkaClient, } override def generateRequests(): Iterable[RequestAndCompletionHandler] = { - Option(requestQueue.poll()).map { queueItem => - RequestAndCompletionHandler( - time.milliseconds(), - activeController.get, - queueItem.request, - handleResponse(queueItem) - ) + val currentTimeMs = time.milliseconds() + val requestIter = requestQueue.iterator() + while (requestIter.hasNext) { + val request = requestIter.next + if (currentTimeMs >= request.deadlineMs) { + request.callback.onTimeout() + requestIter.remove() + } else if (activeController.isDefined) { + return Some(RequestAndCompletionHandler( + time.milliseconds(), + activeController.get, + request.request, + handleResponse(request) + )) + } } + None } private[server] def handleResponse(request: BrokerToControllerQueueItem)(response: ClientResponse): Unit = { - if (hasTimedOut(request, response)) { - request.callback.onTimeout() - } else if (response.wasDisconnected()) { + if (response.wasDisconnected()) { activeController = None requestQueue.putFirst(request) } else if 
(response.responseBody().errorCounts().containsKey(Errors.NOT_CONTROLLER)) { @@ -196,12 +203,6 @@ class BrokerToControllerRequestThread(networkClient: KafkaClient, } } - private def hasTimedOut(request: BrokerToControllerQueueItem, response: ClientResponse): Boolean = { - response.receivedTimeMs() > request.deadlineMs - } - - private[server] def backoff(): Unit = pause(100, TimeUnit.MILLISECONDS) - override def doWork(): Unit = { if (activeController.isDefined) { super.pollOnce(Long.MaxValue) diff --git a/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala b/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala index ee2326605f7db..d9beadd42870b 100644 --- a/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala +++ b/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala @@ -30,17 +30,48 @@ import org.apache.kafka.common.network.ListenerName import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.requests.{AbstractRequest, MetadataRequest, MetadataResponse, RequestTestUtils} import org.apache.kafka.common.security.auth.SecurityProtocol -import org.apache.kafka.common.utils.{MockTime, SystemTime} +import org.apache.kafka.common.utils.MockTime import org.junit.Assert.{assertEquals, assertFalse, assertTrue} import org.junit.Test import org.mockito.Mockito._ class BrokerToControllerRequestThreadTest { + @Test + def testRetryTimeoutWhileControllerNotAvailable(): Unit = { + val time = new MockTime() + val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181")) + val metadata = mock(classOf[Metadata]) + val mockClient = new MockClient(time, metadata) + val metadataCache = mock(classOf[MetadataCache]) + val listenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT) + + when(metadataCache.getControllerId).thenReturn(None) + + val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), 
metadataCache, + config, listenerName, time, "") + + val retryTimeout = 30000 + val completionHandler = new TestRequestCompletionHandler(None) + val queueItem = BrokerToControllerQueueItem( + new MetadataRequest.Builder(new MetadataRequestData()), + completionHandler, + time.milliseconds() + retryTimeout + ) + + testRequestThread.enqueue(queueItem) + testRequestThread.doWork() + + time.sleep(retryTimeout) + testRequestThread.doWork() + + assertTrue(completionHandler.timedOut.get) + } + @Test def testRequestsSent(): Unit = { // just a simple test that tests whether the request from 1 -> 2 is sent and the response callback is called - val time = new SystemTime + val time = new MockTime() val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181")) val controllerId = 2 @@ -80,7 +111,7 @@ class BrokerToControllerRequestThreadTest { @Test def testControllerChanged(): Unit = { // in this test the current broker is 1, and the controller changes from 2 -> 3 then back: 3 -> 2 - val time = new SystemTime + val time = new MockTime() val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181")) val oldControllerId = 1 val newControllerId = 2 @@ -132,7 +163,7 @@ class BrokerToControllerRequestThreadTest { @Test def testNotController(): Unit = { - val time = new SystemTime + val time = new MockTime() val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181")) val oldControllerId = 1 val newControllerId = 2 @@ -185,7 +216,7 @@ class BrokerToControllerRequestThreadTest { } @Test - def testRequestTimeout(): Unit = { + def testRetryTimeout(): Unit = { val time = new MockTime() val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181")) val controllerId = 1 @@ -208,24 +239,24 @@ class BrokerToControllerRequestThreadTest { val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), metadataCache, config, listenerName, time, "") - val requestTimeout = 
config.requestTimeoutMs.longValue() + val retryTimeout = 30000 val completionHandler = new TestRequestCompletionHandler() val queueItem = BrokerToControllerQueueItem( new MetadataRequest.Builder(new MetadataRequestData() .setAllowAutoTopicCreation(true)), completionHandler, - requestTimeout + time.milliseconds() + retryTimeout + time.milliseconds() ) testRequestThread.enqueue(queueItem) // initialize to the controller testRequestThread.doWork() + + time.sleep(retryTimeout) + // send and process the request mockClient.prepareResponse((body: AbstractRequest) => { - // Advance time to timeout the response - time.sleep(requestTimeout + 1) - body.isInstanceOf[MetadataRequest] && body.asInstanceOf[MetadataRequest].allowAutoTopicCreation() }, responseWithNotControllerError) diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java index 7c63401bb75c0..37eee73f2ce83 100644 --- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java +++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java @@ -1413,7 +1413,7 @@ private long maybeSendRequest( ); } - messageQueue.offer(response); + messageQueue.add(response); }); channel.send(requestMessage); @@ -1826,7 +1826,7 @@ private void wakeup() { * @param request The inbound request */ public void handle(RaftRequest.Inbound request) { - messageQueue.offer(Objects.requireNonNull(request)); + messageQueue.add(Objects.requireNonNull(request)); } /** diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java b/raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java index dee765570d901..7d1e4b7598a52 100644 --- a/raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java +++ b/raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java @@ -35,12 +35,12 @@ public interface RaftMessageQueue { RaftMessage poll(long timeoutMs); /** - * Offer a new message to the queue. + * Add a new message to the queue. 
* * @param message the message to deliver * @throws IllegalStateException if the queue cannot accept the message */ - void offer(RaftMessage message); + void add(RaftMessage message); /** * Check whether there are pending messages awaiting delivery. diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java b/raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java index 5d8f384a10f39..9343cca8d47c5 100644 --- a/raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java +++ b/raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java @@ -58,7 +58,7 @@ public RaftMessage poll(long timeoutMs) { } @Override - public void offer(RaftMessage message) { + public void add(RaftMessage message) { queue.add(message); size.incrementAndGet(); } diff --git a/raft/src/test/java/org/apache/kafka/raft/MockMessageQueue.java b/raft/src/test/java/org/apache/kafka/raft/MockMessageQueue.java index d1c73e0f9b065..5fcd599b95f1e 100644 --- a/raft/src/test/java/org/apache/kafka/raft/MockMessageQueue.java +++ b/raft/src/test/java/org/apache/kafka/raft/MockMessageQueue.java @@ -38,7 +38,7 @@ public RaftMessage poll(long timeoutMs) { } @Override - public void offer(RaftMessage message) { + public void add(RaftMessage message) { messages.offer(message); } diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/BlockingMessageQueueTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/BlockingMessageQueueTest.java index 4aafe5b12866a..e752fbd6dd37d 100644 --- a/raft/src/test/java/org/apache/kafka/raft/internals/BlockingMessageQueueTest.java +++ b/raft/src/test/java/org/apache/kafka/raft/internals/BlockingMessageQueueTest.java @@ -34,15 +34,15 @@ public void testOfferAndPoll() { assertNull(queue.poll(0)); RaftMessage message1 = Mockito.mock(RaftMessage.class); - queue.offer(message1); + queue.add(message1); assertFalse(queue.isEmpty()); assertEquals(message1, queue.poll(0)); 
assertTrue(queue.isEmpty()); RaftMessage message2 = Mockito.mock(RaftMessage.class); RaftMessage message3 = Mockito.mock(RaftMessage.class); - queue.offer(message2); - queue.offer(message3); + queue.add(message2); + queue.add(message3); assertFalse(queue.isEmpty()); assertEquals(message2, queue.poll(0)); assertEquals(message3, queue.poll(0)); From 2999fc7d174d24df9041ce86f1d599c2333a04b1 Mon Sep 17 00:00:00 2001 From: Jason Gustafson Date: Tue, 15 Dec 2020 14:20:18 -0800 Subject: [PATCH 06/10] Factor retry deadline out of `sendRequest` --- .../scala/kafka/server/AlterIsrManager.scala | 66 +++++++++++++---- .../BrokerToControllerChannelManager.scala | 73 +++++++++++-------- .../kafka/server/ForwardingManager.scala | 63 ++++++++++++---- .../main/scala/kafka/server/KafkaServer.scala | 51 +++++++------ .../BrokerToControllerRequestThreadTest.scala | 36 ++++----- .../kafka/server/AlterIsrManagerTest.scala | 27 ++++--- .../kafka/server/ForwardingManagerTest.scala | 8 +- .../scala/unit/kafka/utils/TestUtils.scala | 2 - 8 files changed, 210 insertions(+), 116 deletions(-) diff --git a/core/src/main/scala/kafka/server/AlterIsrManager.scala b/core/src/main/scala/kafka/server/AlterIsrManager.scala index fa4616bd202b0..7e29e03c8d7cc 100644 --- a/core/src/main/scala/kafka/server/AlterIsrManager.scala +++ b/core/src/main/scala/kafka/server/AlterIsrManager.scala @@ -22,11 +22,12 @@ import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import kafka.api.LeaderAndIsr import kafka.metrics.KafkaMetricsGroup -import kafka.utils.{Logging, Scheduler} +import kafka.utils.{KafkaScheduler, Logging, Scheduler} import kafka.zk.KafkaZkClient import org.apache.kafka.clients.ClientResponse import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.message.{AlterIsrRequestData, AlterIsrResponseData} +import org.apache.kafka.common.metrics.Metrics import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.requests.{AlterIsrRequest, AlterIsrResponse} import 
org.apache.kafka.common.utils.Time @@ -44,7 +45,9 @@ import scala.jdk.CollectionConverters._ * requests. */ trait AlterIsrManager { - def start(): Unit + def start(): Unit = {} + + def shutdown(): Unit = {} def submit(alterIsrItem: AlterIsrItem): Boolean @@ -57,30 +60,57 @@ case class AlterIsrItem(topicPartition: TopicPartition, controllerEpoch: Int) // controllerEpoch needed for Zk impl object AlterIsrManager { + /** * Factory to AlterIsr based implementation, used when IBP >= 2.7-IV2 */ - def apply(controllerChannelManager: BrokerToControllerChannelManager, - scheduler: Scheduler, - time: Time, - brokerId: Int, - brokerEpochSupplier: () => Long): AlterIsrManager = { - new DefaultAlterIsrManager(controllerChannelManager, scheduler, time, brokerId, brokerEpochSupplier) + def apply( + config: KafkaConfig, + metadataCache: MetadataCache, + scheduler: KafkaScheduler, + time: Time, + metrics: Metrics, + threadNamePrefix: Option[String], + brokerEpochSupplier: () => Long + ): AlterIsrManager = { + val channelManager = new BrokerToControllerChannelManager( + metadataCache = metadataCache, + time = time, + metrics = metrics, + config = config, + channelName = "forwardingChannel", + threadNamePrefix = threadNamePrefix, + retryTimeoutMs = Long.MaxValue + ) + new DefaultAlterIsrManager( + controllerChannelManager = channelManager, + scheduler = scheduler, + time = time, + brokerId = config.brokerId, + brokerEpochSupplier = brokerEpochSupplier + ) } /** * Factory for ZK based implementation, used when IBP < 2.7-IV2 */ - def apply(scheduler: Scheduler, time: Time, zkClient: KafkaZkClient): AlterIsrManager = { + def apply( + scheduler: Scheduler, + time: Time, + zkClient: KafkaZkClient + ): AlterIsrManager = { new ZkIsrManager(scheduler, time, zkClient) } + } -class DefaultAlterIsrManager(val controllerChannelManager: BrokerToControllerChannelManager, - val scheduler: Scheduler, - val time: Time, - val brokerId: Int, - val brokerEpochSupplier: () => Long) extends 
AlterIsrManager with Logging with KafkaMetricsGroup { +class DefaultAlterIsrManager( + val controllerChannelManager: BrokerToControllerChannelManager, + val scheduler: Scheduler, + val time: Time, + val brokerId: Int, + val brokerEpochSupplier: () => Long +) extends AlterIsrManager with Logging with KafkaMetricsGroup { // Used to allow only one pending ISR update per partition private val unsentIsrUpdates: util.Map[TopicPartition, AlterIsrItem] = new ConcurrentHashMap[TopicPartition, AlterIsrItem]() @@ -91,9 +121,14 @@ class DefaultAlterIsrManager(val controllerChannelManager: BrokerToControllerCha private val lastIsrPropagationMs = new AtomicLong(0) override def start(): Unit = { + controllerChannelManager.start() scheduler.schedule("send-alter-isr", propagateIsrChanges, 50, 50, TimeUnit.MILLISECONDS) } + override def shutdown(): Unit = { + controllerChannelManager.shutdown() + } + override def submit(alterIsrItem: AlterIsrItem): Boolean = { unsentIsrUpdates.putIfAbsent(alterIsrItem.topicPartition, alterIsrItem) == null } @@ -125,6 +160,7 @@ class DefaultAlterIsrManager(val controllerChannelManager: BrokerToControllerCha } debug(s"Sending AlterIsr to controller $message") + // We will not timeout AlterISR request, instead letting it retry indefinitely // until a response is received, or a new LeaderAndIsr overwrites the existing isrState // which causes the inflight requests to be ignored. 
@@ -142,7 +178,7 @@ class DefaultAlterIsrManager(val controllerChannelManager: BrokerToControllerCha override def onTimeout(): Unit = { throw new IllegalStateException("Encountered unexpected timeout when sending AlterIsr to the controller") } - }, Long.MaxValue) + }) } private def buildRequest(inflightAlterIsrItems: Seq[AlterIsrItem]): AlterIsrRequestData = { diff --git a/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala index 13c24c2ba25e2..5427cf2ac16e3 100644 --- a/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala +++ b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala @@ -45,7 +45,8 @@ class BrokerToControllerChannelManager( metrics: Metrics, config: KafkaConfig, channelName: String, - threadNamePrefix: Option[String] = None + threadNamePrefix: Option[String], + retryTimeoutMs: Long ) extends Logging { private val logContext = new LogContext(s"[broker-${config.brokerId}-to-controller] ") private val manualMetadataUpdater = new ManualMetadataUpdater() @@ -111,8 +112,16 @@ class BrokerToControllerChannelManager( case Some(name) => s"$name:broker-${config.brokerId}-to-controller-send-thread" } - new BrokerToControllerRequestThread(networkClient, manualMetadataUpdater, metadataCache, config, - brokerToControllerListenerName, time, threadName) + new BrokerToControllerRequestThread( + networkClient, + manualMetadataUpdater, + metadataCache, + config, + brokerToControllerListenerName, + time, + threadName, + retryTimeoutMs + ) } /** @@ -120,19 +129,15 @@ class BrokerToControllerChannelManager( * * @param request The request to be sent. * @param callback Request completion callback. - * @param retryDeadlineMs The retry deadline which will only be checked after receiving a response. - * This means that in the worst case, the total timeout would be twice of - * the configured timeout. 
*/ def sendRequest( request: AbstractRequest.Builder[_ <: AbstractRequest], - callback: ControllerRequestCompletionHandler, - retryDeadlineMs: Long + callback: ControllerRequestCompletionHandler ): Unit = { requestThread.enqueue(BrokerToControllerQueueItem( + time.milliseconds(), request, - callback, - retryDeadlineMs + callback )) } } @@ -146,18 +151,22 @@ abstract class ControllerRequestCompletionHandler extends RequestCompletionHandl def onTimeout(): Unit } -case class BrokerToControllerQueueItem(request: AbstractRequest.Builder[_ <: AbstractRequest], - callback: ControllerRequestCompletionHandler, - deadlineMs: Long) +case class BrokerToControllerQueueItem( + createdTimeMs: Long, + request: AbstractRequest.Builder[_ <: AbstractRequest], + callback: ControllerRequestCompletionHandler +) -class BrokerToControllerRequestThread(networkClient: KafkaClient, - metadataUpdater: ManualMetadataUpdater, - metadataCache: kafka.server.MetadataCache, - config: KafkaConfig, - listenerName: ListenerName, - time: Time, - threadName: String) - extends InterBrokerSendThread(threadName, networkClient, config.controllerSocketTimeoutMs, time, isInterruptible = false) { +class BrokerToControllerRequestThread( + networkClient: KafkaClient, + metadataUpdater: ManualMetadataUpdater, + metadataCache: kafka.server.MetadataCache, + config: KafkaConfig, + listenerName: ListenerName, + time: Time, + threadName: String, + retryTimeoutMs: Long +) extends InterBrokerSendThread(threadName, networkClient, config.controllerSocketTimeoutMs, time, isInterruptible = false) { private val requestQueue = new LinkedBlockingDeque[BrokerToControllerQueueItem]() private var activeController: Option[Node] = None @@ -174,7 +183,7 @@ class BrokerToControllerRequestThread(networkClient: KafkaClient, val requestIter = requestQueue.iterator() while (requestIter.hasNext) { val request = requestIter.next - if (currentTimeMs >= request.deadlineMs) { + if (currentTimeMs - request.createdTimeMs >= retryTimeoutMs) { 
request.callback.onTimeout() requestIter.remove() } else if (activeController.isDefined) { @@ -209,15 +218,17 @@ class BrokerToControllerRequestThread(networkClient: KafkaClient, } else { debug("Controller isn't cached, looking for local metadata changes") val controllerOpt = metadataCache.getControllerId.flatMap(metadataCache.getAliveBroker) - if (controllerOpt.isDefined) { - if (activeController.isEmpty || activeController.exists(_.id != controllerOpt.get.id)) - info(s"Recorded new controller, from now on will use broker ${controllerOpt.get.id}") - activeController = Option(controllerOpt.get.node(listenerName)) - metadataUpdater.setNodes(metadataCache.getAliveBrokers.map(_.node(listenerName)).asJava) - } else { - // need to backoff to avoid tight loops - debug("No controller defined in metadata cache, retrying after backoff") - super.pollOnce(100) + controllerOpt match { + case Some(controller) => + info(s"Recorded new controller, from now on will use broker $controller") + val controllerNode = controller.node(listenerName) + activeController = Some(controllerNode) + metadataUpdater.setNodes(Seq(controllerNode).asJava) + + case None => + // need to backoff to avoid tight loops + debug("No controller defined in metadata cache, retrying after backoff") + super.pollOnce(maxTimeoutMs = 100) } } } diff --git a/core/src/main/scala/kafka/server/ForwardingManager.scala b/core/src/main/scala/kafka/server/ForwardingManager.scala index 54bef3eefc8bf..9ee24ea859faa 100644 --- a/core/src/main/scala/kafka/server/ForwardingManager.scala +++ b/core/src/main/scala/kafka/server/ForwardingManager.scala @@ -22,6 +22,7 @@ import java.nio.ByteBuffer import kafka.network.RequestChannel import kafka.utils.Logging import org.apache.kafka.clients.ClientResponse +import org.apache.kafka.common.metrics.Metrics import org.apache.kafka.common.protocol.Errors import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse, EnvelopeRequest, EnvelopeResponse, RequestHeader} import 
org.apache.kafka.common.utils.Time @@ -29,12 +30,55 @@ import org.apache.kafka.common.utils.Time import scala.compat.java8.OptionConverters._ import scala.concurrent.TimeoutException -class ForwardingManager(channelManager: BrokerToControllerChannelManager, - time: Time, - retryTimeoutMs: Long) extends Logging { +trait ForwardingManager { + def forwardRequest( + request: RequestChannel.Request, + responseCallback: AbstractResponse => Unit + ): Unit - def forwardRequest(request: RequestChannel.Request, - responseCallback: AbstractResponse => Unit): Unit = { + def start(): Unit = {} + + def shutdown(): Unit = {} +} + +object ForwardingManager { + private val ThreadNamePrefix = "controller-forwarder" + + def apply( + config: KafkaConfig, + metadataCache: MetadataCache, + time: Time, + metrics: Metrics + ): ForwardingManager = { + val channelManager = new BrokerToControllerChannelManager( + metadataCache = metadataCache, + time = time, + metrics = metrics, + config = config, + channelName = "forwardingChannel", + threadNamePrefix = Some(ThreadNamePrefix), + retryTimeoutMs = config.requestTimeoutMs.longValue + ) + new ForwardingManagerImpl(channelManager) + } +} + +class ForwardingManagerImpl( + channelManager: BrokerToControllerChannelManager +) extends ForwardingManager with Logging { + + override def start(): Unit = { + channelManager.start() + } + + override def shutdown(): Unit = { + channelManager.shutdown() + } + + override def forwardRequest( + request: RequestChannel.Request, + responseCallback: AbstractResponse => Unit + ): Unit = { val principalSerde = request.context.principalSerde.asScala.getOrElse( throw new IllegalArgumentException(s"Cannot deserialize principal from request $request " + "since there is no serde defined") @@ -75,14 +119,7 @@ class ForwardingManager(channelManager: BrokerToControllerChannelManager, } } - val currentTime = time.milliseconds() - val deadlineMs = - if (Long.MaxValue - currentTime < retryTimeoutMs) - Long.MaxValue - else - 
currentTime + retryTimeoutMs - - channelManager.sendRequest(envelopeRequest, new ForwardingResponseHandler, deadlineMs) + channelManager.sendRequest(envelopeRequest, new ForwardingResponseHandler) } private def parseResponse( diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala index a270e803a5b43..60a66303c0d0a 100755 --- a/core/src/main/scala/kafka/server/KafkaServer.scala +++ b/core/src/main/scala/kafka/server/KafkaServer.scala @@ -50,8 +50,8 @@ import org.apache.kafka.common.{ClusterResource, Endpoint, Node} import org.apache.kafka.server.authorizer.Authorizer import org.apache.zookeeper.client.ZKClientConfig -import scala.jdk.CollectionConverters._ import scala.collection.{Map, Seq, mutable} +import scala.jdk.CollectionConverters._ object KafkaServer { // Copy the subset of properties that are relevant to Logs @@ -168,9 +168,9 @@ class KafkaServer(val config: KafkaConfig, time: Time = Time.SYSTEM, threadNameP var kafkaController: KafkaController = null - var forwardingChannelManager: BrokerToControllerChannelManager = null + var forwardingManager: ForwardingManager = null - var alterIsrChannelManager: BrokerToControllerChannelManager = null + var alterIsrManager: AlterIsrManager = null var kafkaScheduler: KafkaScheduler = null @@ -309,11 +309,23 @@ class KafkaServer(val config: KafkaConfig, time: Time = Time.SYSTEM, threadNameP socketServer.startup(startProcessingRequests = false) /* start replica manager */ - alterIsrChannelManager = new BrokerToControllerChannelManager( - metadataCache, time, metrics, config, "alterIsrChannel", threadNamePrefix) + alterIsrManager = if (config.interBrokerProtocolVersion.isAlterIsrSupported) { + AlterIsrManager( + config = config, + metadataCache = metadataCache, + scheduler = kafkaScheduler, + time = time, + metrics = metrics, + threadNamePrefix = threadNamePrefix, + brokerEpochSupplier = () => kafkaController.brokerEpoch + ) + } else { + 
AlterIsrManager(kafkaScheduler, time, zkClient) + } + alterIsrManager.start() + replicaManager = createReplicaManager(isShuttingDown) replicaManager.startup() - alterIsrChannelManager.start() val brokerInfo = createBrokerInfo val brokerEpoch = zkClient.registerBroker(brokerInfo) @@ -329,13 +341,14 @@ class KafkaServer(val config: KafkaConfig, time: Time = Time.SYSTEM, threadNameP kafkaController = new KafkaController(config, zkClient, time, metrics, brokerInfo, brokerEpoch, tokenManager, brokerFeatures, featureCache, threadNamePrefix) kafkaController.startup() - var forwardingManager: ForwardingManager = null if (config.metadataQuorumEnabled) { - /* start forwarding manager */ - forwardingChannelManager = new BrokerToControllerChannelManager(metadataCache, time, metrics, - config, "forwardingChannel", threadNamePrefix) - forwardingChannelManager.start() - forwardingManager = new ForwardingManager(forwardingChannelManager, time, config.requestTimeoutMs.longValue()) + forwardingManager = ForwardingManager( + config, + metadataCache, + time, + metrics + ) + forwardingManager.start() } adminManager = new AdminManager(config, metrics, metadataCache, zkClient) @@ -444,12 +457,6 @@ class KafkaServer(val config: KafkaConfig, time: Time = Time.SYSTEM, threadNameP } protected def createReplicaManager(isShuttingDown: AtomicBoolean): ReplicaManager = { - val alterIsrManager: AlterIsrManager = if (config.interBrokerProtocolVersion.isAlterIsrSupported) { - AlterIsrManager(alterIsrChannelManager, kafkaScheduler, - time, config.brokerId, () => kafkaController.brokerEpoch) - } else { - AlterIsrManager(kafkaScheduler, time, zkClient) - } new ReplicaManager(config, metrics, time, zkClient, kafkaScheduler, logManager, isShuttingDown, quotaManagers, brokerTopicStats, metadataCache, logDirFailureChannel, alterIsrManager) } @@ -730,11 +737,11 @@ class KafkaServer(val config: KafkaConfig, time: Time = Time.SYSTEM, threadNameP if (replicaManager != null) 
CoreUtils.swallow(replicaManager.shutdown(), this) - if (alterIsrChannelManager != null) - CoreUtils.swallow(alterIsrChannelManager.shutdown(), this) + if (alterIsrManager != null) + CoreUtils.swallow(alterIsrManager.shutdown(), this) - if (forwardingChannelManager != null) - CoreUtils.swallow(forwardingChannelManager.shutdown(), this) + if (forwardingManager != null) + CoreUtils.swallow(forwardingManager.shutdown(), this) if (logManager != null) CoreUtils.swallow(logManager.shutdown(), this) diff --git a/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala b/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala index d9beadd42870b..e44a38c2825ec 100644 --- a/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala +++ b/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala @@ -48,21 +48,21 @@ class BrokerToControllerRequestThreadTest { when(metadataCache.getControllerId).thenReturn(None) + val retryTimeoutMs = 30000 val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), metadataCache, - config, listenerName, time, "") + config, listenerName, time, "", retryTimeoutMs) - val retryTimeout = 30000 val completionHandler = new TestRequestCompletionHandler(None) val queueItem = BrokerToControllerQueueItem( + time.milliseconds(), new MetadataRequest.Builder(new MetadataRequestData()), - completionHandler, - time.milliseconds() + retryTimeout + completionHandler ) testRequestThread.enqueue(queueItem) testRequestThread.doWork() - time.sleep(retryTimeout) + time.sleep(retryTimeoutMs) testRequestThread.doWork() assertTrue(completionHandler.timedOut.get) @@ -89,14 +89,14 @@ class BrokerToControllerRequestThreadTest { val expectedResponse = RequestTestUtils.metadataUpdateWith(2, Collections.singletonMap("a", 2)) val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), metadataCache, - config, listenerName, 
time, "") + config, listenerName, time, "", retryTimeoutMs = Long.MaxValue) mockClient.prepareResponse(expectedResponse) val completionHandler = new TestRequestCompletionHandler(Some(expectedResponse)) val queueItem = BrokerToControllerQueueItem( + time.milliseconds(), new MetadataRequest.Builder(new MetadataRequestData()), - completionHandler, - Long.MaxValue + completionHandler ) testRequestThread.enqueue(queueItem) @@ -134,13 +134,13 @@ class BrokerToControllerRequestThreadTest { val expectedResponse = RequestTestUtils.metadataUpdateWith(3, Collections.singletonMap("a", 2)) val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), - metadataCache, config, listenerName, time, "") + metadataCache, config, listenerName, time, "", retryTimeoutMs = Long.MaxValue) val completionHandler = new TestRequestCompletionHandler(Some(expectedResponse)) val queueItem = BrokerToControllerQueueItem( + time.milliseconds(), new MetadataRequest.Builder(new MetadataRequestData()), completionHandler, - Long.MaxValue ) testRequestThread.enqueue(queueItem) @@ -188,14 +188,14 @@ class BrokerToControllerRequestThreadTest { Collections.singletonMap("a", 2)) val expectedResponse = RequestTestUtils.metadataUpdateWith(3, Collections.singletonMap("a", 2)) val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), metadataCache, - config, listenerName, time, "") + config, listenerName, time, "", retryTimeoutMs = Long.MaxValue) val completionHandler = new TestRequestCompletionHandler(Some(expectedResponse)) val queueItem = BrokerToControllerQueueItem( + time.milliseconds(), new MetadataRequest.Builder(new MetadataRequestData() .setAllowAutoTopicCreation(true)), - completionHandler, - Long.MaxValue + completionHandler ) testRequestThread.enqueue(queueItem) // initialize to the controller @@ -233,19 +233,19 @@ class BrokerToControllerRequestThreadTest { 
when(metadataCache.getAliveBrokers).thenReturn(Seq(controller)) when(metadataCache.getAliveBroker(controllerId)).thenReturn(Some(controller)) + val retryTimeoutMs = 30000 val responseWithNotControllerError = RequestTestUtils.metadataUpdateWith("cluster1", 2, Collections.singletonMap("a", Errors.NOT_CONTROLLER), Collections.singletonMap("a", 2)) val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), metadataCache, - config, listenerName, time, "") + config, listenerName, time, "", retryTimeoutMs) - val retryTimeout = 30000 val completionHandler = new TestRequestCompletionHandler() val queueItem = BrokerToControllerQueueItem( + time.milliseconds(), new MetadataRequest.Builder(new MetadataRequestData() .setAllowAutoTopicCreation(true)), - completionHandler, - retryTimeout + time.milliseconds() + completionHandler ) testRequestThread.enqueue(queueItem) @@ -253,7 +253,7 @@ class BrokerToControllerRequestThreadTest { // initialize to the controller testRequestThread.doWork() - time.sleep(retryTimeout) + time.sleep(retryTimeoutMs) // send and process the request mockClient.prepareResponse((body: AbstractRequest) => { diff --git a/core/src/test/scala/unit/kafka/server/AlterIsrManagerTest.scala b/core/src/test/scala/unit/kafka/server/AlterIsrManagerTest.scala index a2e269b3217d9..8c5657e4152c6 100644 --- a/core/src/test/scala/unit/kafka/server/AlterIsrManagerTest.scala +++ b/core/src/test/scala/unit/kafka/server/AlterIsrManagerTest.scala @@ -15,12 +15,11 @@ * limitations under the License. 
*/ -package unit.kafka.server +package kafka.server import java.util.Collections import java.util.concurrent.atomic.AtomicInteger import kafka.api.LeaderAndIsr -import kafka.server.{AlterIsrItem, AlterIsrManager, BrokerToControllerChannelManager, ControllerRequestCompletionHandler, DefaultAlterIsrManager, ZkIsrManager} import kafka.utils.{MockScheduler, MockTime} import kafka.zk.KafkaZkClient import org.apache.kafka.clients.ClientResponse @@ -42,7 +41,6 @@ class AlterIsrManagerTest { val time = new MockTime val metrics = new Metrics val brokerId = 1 - val requestTimeout = Long.MaxValue var brokerToController: BrokerToControllerChannelManager = _ @@ -57,7 +55,8 @@ class AlterIsrManagerTest { @Test def testBasic(): Unit = { - EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.eq(requestTimeout))).once() + EasyMock.expect(brokerToController.start()) + EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.anyObject())).once() EasyMock.replay(brokerToController) val scheduler = new MockScheduler(time) @@ -73,7 +72,8 @@ class AlterIsrManagerTest { @Test def testOverwriteWithinBatch(): Unit = { val capture = EasyMock.newCapture[AbstractRequest.Builder[AlterIsrRequest]]() - EasyMock.expect(brokerToController.sendRequest(EasyMock.capture(capture), EasyMock.anyObject(), EasyMock.eq(requestTimeout))).once() + EasyMock.expect(brokerToController.start()) + EasyMock.expect(brokerToController.sendRequest(EasyMock.capture(capture), EasyMock.anyObject())).once() EasyMock.replay(brokerToController) val scheduler = new MockScheduler(time) @@ -97,7 +97,8 @@ class AlterIsrManagerTest { @Test def testSingleBatch(): Unit = { val capture = EasyMock.newCapture[AbstractRequest.Builder[AlterIsrRequest]]() - EasyMock.expect(brokerToController.sendRequest(EasyMock.capture(capture), EasyMock.anyObject(), EasyMock.eq(requestTimeout))).once() + EasyMock.expect(brokerToController.start()) + 
EasyMock.expect(brokerToController.sendRequest(EasyMock.capture(capture), EasyMock.anyObject())).once() EasyMock.replay(brokerToController) val scheduler = new MockScheduler(time) @@ -151,7 +152,8 @@ class AlterIsrManagerTest { def testTopLevelError(isrs: Seq[AlterIsrItem], error: Errors): AlterIsrManager = { val callbackCapture = EasyMock.newCapture[ControllerRequestCompletionHandler]() - EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture), EasyMock.eq(requestTimeout))).once() + EasyMock.expect(brokerToController.start()) + EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture))).once() EasyMock.replay(brokerToController) val scheduler = new MockScheduler(time) @@ -184,7 +186,8 @@ class AlterIsrManagerTest { def testPartitionError(tp: TopicPartition, error: Errors): AlterIsrManager = { val callbackCapture = EasyMock.newCapture[ControllerRequestCompletionHandler]() EasyMock.reset(brokerToController) - EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture), EasyMock.eq(requestTimeout))).once() + EasyMock.expect(brokerToController.start()) + EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture))).once() EasyMock.replay(brokerToController) val scheduler = new MockScheduler(time) @@ -226,7 +229,8 @@ class AlterIsrManagerTest { def testOneInFlight(): Unit = { val callbackCapture = EasyMock.newCapture[ControllerRequestCompletionHandler]() EasyMock.reset(brokerToController) - EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture), EasyMock.eq(requestTimeout))).once() + EasyMock.expect(brokerToController.start()) + EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture))).once() EasyMock.replay(brokerToController) val scheduler = new MockScheduler(time) @@ -253,7 +257,7 @@ class 
AlterIsrManagerTest { callbackCapture.getValue.onComplete(resp) EasyMock.reset(brokerToController) - EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture), EasyMock.eq(requestTimeout))).once() + EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture))).once() EasyMock.replay(brokerToController) time.sleep(100) @@ -265,7 +269,8 @@ class AlterIsrManagerTest { def testPartitionMissingInResponse(): Unit = { val callbackCapture = EasyMock.newCapture[ControllerRequestCompletionHandler]() EasyMock.reset(brokerToController) - EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture), EasyMock.eq(requestTimeout))).once() + EasyMock.expect(brokerToController.start()) + EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture))).once() EasyMock.replay(brokerToController) val scheduler = new MockScheduler(time) diff --git a/core/src/test/scala/unit/kafka/server/ForwardingManagerTest.scala b/core/src/test/scala/unit/kafka/server/ForwardingManagerTest.scala index 5001f671c02f8..ca03359f90758 100644 --- a/core/src/test/scala/unit/kafka/server/ForwardingManagerTest.scala +++ b/core/src/test/scala/unit/kafka/server/ForwardingManagerTest.scala @@ -19,6 +19,7 @@ package kafka.server import java.net.InetAddress import java.nio.ByteBuffer import java.util.Optional + import kafka.network import kafka.network.RequestChannel import kafka.utils.MockTime @@ -34,7 +35,7 @@ import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuild import org.junit.Assert._ import org.junit.Test import org.mockito.ArgumentMatchers._ -import org.mockito.{ArgumentMatchers, Mockito} +import org.mockito.Mockito import scala.jdk.CollectionConverters._ @@ -45,7 +46,7 @@ class ForwardingManagerTest { @Test def testResponseCorrelationIdMismatch(): Unit = { - val forwardingManager = new 
ForwardingManager(brokerToController, time, Long.MaxValue) + val forwardingManager = new ForwardingManagerImpl(brokerToController) val requestCorrelationId = 27 val envelopeCorrelationId = 39 val clientPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "client") @@ -64,8 +65,7 @@ class ForwardingManagerTest { Mockito.when(brokerToController.sendRequest( any(classOf[EnvelopeRequest.Builder]), - any(classOf[ControllerRequestCompletionHandler]), - ArgumentMatchers.eq(Long.MaxValue) + any(classOf[ControllerRequestCompletionHandler]) )).thenAnswer(invocation => { val completionHandler = invocation.getArgument[RequestCompletionHandler](1) val response = buildEnvelopeResponse(responseBuffer, envelopeCorrelationId, completionHandler) diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala index 462e802707c25..e58e65b2a30b1 100755 --- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala +++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala @@ -1082,8 +1082,6 @@ object TestUtils extends Logging { inFlight.set(false); } - override def start(): Unit = { } - def completeIsrUpdate(newZkVersion: Int): Unit = { if (inFlight.compareAndSet(true, false)) { val item = isrUpdates.head From 67633d5b4a9b5aaa99e390e242f3139d27220d25 Mon Sep 17 00:00:00 2001 From: Jason Gustafson Date: Tue, 15 Dec 2020 15:48:27 -0800 Subject: [PATCH 07/10] We should use `disconnect` so that we get responses --- .../scala/kafka/server/BrokerToControllerChannelManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala index 5427cf2ac16e3..a0818aa3dbbe9 100644 --- a/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala +++ b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala @@ -204,7 +204,7 @@ class 
BrokerToControllerRequestThread( requestQueue.putFirst(request) } else if (response.responseBody().errorCounts().containsKey(Errors.NOT_CONTROLLER)) { // just close the controller connection and wait for metadata cache update in doWork - networkClient.close(activeController.get.idString) + networkClient.disconnect(activeController.get.idString) activeController = None requestQueue.putFirst(request) } else { From e113993296e77b12fe0fcbd17ec73ba4ddff52ba Mon Sep 17 00:00:00 2001 From: Jason Gustafson Date: Wed, 16 Dec 2020 09:43:34 -0800 Subject: [PATCH 08/10] Fix startup/send bugs --- core/src/main/scala/kafka/server/AlterIsrManager.scala | 1 + .../kafka/server/BrokerToControllerChannelManager.scala | 7 ++++++- .../kafka/server/BrokerToControllerRequestThreadTest.scala | 6 +++++- 3 files changed, 12 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/kafka/server/AlterIsrManager.scala b/core/src/main/scala/kafka/server/AlterIsrManager.scala index 7e29e03c8d7cc..746c53f73d42c 100644 --- a/core/src/main/scala/kafka/server/AlterIsrManager.scala +++ b/core/src/main/scala/kafka/server/AlterIsrManager.scala @@ -168,6 +168,7 @@ class DefaultAlterIsrManager( new ControllerRequestCompletionHandler { override def onComplete(response: ClientResponse): Unit = { try { + debug(s"Received AlterIsr response $response") val body = response.responseBody().asInstanceOf[AlterIsrResponse] handleAlterIsrResponse(body, message.brokerEpoch, inflightAlterIsrItems) } finally { diff --git a/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala index a0818aa3dbbe9..dac882ef58a5a 100644 --- a/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala +++ b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala @@ -178,15 +178,20 @@ class BrokerToControllerRequestThread( } } + def queueSize: Int = { + requestQueue.size + } + override def generateRequests(): 
Iterable[RequestAndCompletionHandler] = { val currentTimeMs = time.milliseconds() val requestIter = requestQueue.iterator() while (requestIter.hasNext) { val request = requestIter.next if (currentTimeMs - request.createdTimeMs >= retryTimeoutMs) { - request.callback.onTimeout() requestIter.remove() + request.callback.onTimeout() } else if (activeController.isDefined) { + requestIter.remove() return Some(RequestAndCompletionHandler( time.milliseconds(), activeController.get, diff --git a/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala b/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala index e44a38c2825ec..7b65bea556151 100644 --- a/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala +++ b/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala @@ -61,10 +61,11 @@ class BrokerToControllerRequestThreadTest { testRequestThread.enqueue(queueItem) testRequestThread.doWork() + assertEquals(1, testRequestThread.queueSize) time.sleep(retryTimeoutMs) testRequestThread.doWork() - + assertEquals(0, testRequestThread.queueSize) assertTrue(completionHandler.timedOut.get) } @@ -100,11 +101,14 @@ class BrokerToControllerRequestThreadTest { ) testRequestThread.enqueue(queueItem) + assertEquals(1, testRequestThread.queueSize) + // initialize to the controller testRequestThread.doWork() // send and process the request testRequestThread.doWork() + assertEquals(0, testRequestThread.queueSize) assertTrue(completionHandler.completed.get()) } From e9574f3b1170653348a3d8f403d3b6a996ae91a3 Mon Sep 17 00:00:00 2001 From: Jason Gustafson Date: Wed, 16 Dec 2020 11:39:00 -0800 Subject: [PATCH 09/10] Fix broken channel name --- core/src/main/scala/kafka/server/AlterIsrManager.scala | 2 +- core/src/main/scala/kafka/server/ForwardingManager.scala | 6 +++--- core/src/main/scala/kafka/server/KafkaServer.scala | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git 
a/core/src/main/scala/kafka/server/AlterIsrManager.scala b/core/src/main/scala/kafka/server/AlterIsrManager.scala index 746c53f73d42c..e463487167df6 100644 --- a/core/src/main/scala/kafka/server/AlterIsrManager.scala +++ b/core/src/main/scala/kafka/server/AlterIsrManager.scala @@ -78,7 +78,7 @@ object AlterIsrManager { time = time, metrics = metrics, config = config, - channelName = "forwardingChannel", + channelName = "alterIsrChannel", threadNamePrefix = threadNamePrefix, retryTimeoutMs = Long.MaxValue ) diff --git a/core/src/main/scala/kafka/server/ForwardingManager.scala b/core/src/main/scala/kafka/server/ForwardingManager.scala index 9ee24ea859faa..e4261a73a3aa1 100644 --- a/core/src/main/scala/kafka/server/ForwardingManager.scala +++ b/core/src/main/scala/kafka/server/ForwardingManager.scala @@ -42,13 +42,13 @@ trait ForwardingManager { } object ForwardingManager { - private val ThreadNamePrefix = "controller-forwarder" def apply( config: KafkaConfig, metadataCache: MetadataCache, time: Time, - metrics: Metrics + metrics: Metrics, + threadNamePrefix: Option[String] ): ForwardingManager = { val channelManager = new BrokerToControllerChannelManager( metadataCache = metadataCache, @@ -56,7 +56,7 @@ object ForwardingManager { metrics = metrics, config = config, channelName = "forwardingChannel", - threadNamePrefix = Some(ThreadNamePrefix), + threadNamePrefix = threadNamePrefix, retryTimeoutMs = config.requestTimeoutMs.longValue ) new ForwardingManagerImpl(channelManager) diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala index 60a66303c0d0a..220484eabda99 100755 --- a/core/src/main/scala/kafka/server/KafkaServer.scala +++ b/core/src/main/scala/kafka/server/KafkaServer.scala @@ -346,7 +346,8 @@ class KafkaServer(val config: KafkaConfig, time: Time = Time.SYSTEM, threadNameP config, metadataCache, time, - metrics + metrics, + threadNamePrefix ) forwardingManager.start() } From 
c31ba33f2ffb1144e0acbe975ddf45cc5d4065bc Mon Sep 17 00:00:00 2001 From: Jason Gustafson Date: Mon, 21 Dec 2020 12:43:34 -0800 Subject: [PATCH 10/10] Remove start() call in ReplicaManager --- core/src/main/scala/kafka/server/ReplicaManager.scala | 1 - 1 file changed, 1 deletion(-) diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala index 5da3b692e05fd..012c0ea916bfc 100644 --- a/core/src/main/scala/kafka/server/ReplicaManager.scala +++ b/core/src/main/scala/kafka/server/ReplicaManager.scala @@ -293,7 +293,6 @@ class ReplicaManager(val config: KafkaConfig, // A follower can lag behind leader for up to config.replicaLagTimeMaxMs x 1.5 before it is removed from ISR scheduler.schedule("isr-expiration", maybeShrinkIsr _, period = config.replicaLagTimeMaxMs / 2, unit = TimeUnit.MILLISECONDS) scheduler.schedule("shutdown-idle-replica-alter-log-dirs-thread", shutdownIdleReplicaAlterLogDirsThread _, period = 10000L, unit = TimeUnit.MILLISECONDS) - alterIsrManager.start() // If inter-broker protocol (IBP) < 1.0, the controller will send LeaderAndIsrRequest V0 which does not include isNew field. // In this case, the broker receiving the request cannot determine whether it is safe to create a partition if a log directory has failed.