diff --git a/checkstyle/suppressions.xml b/checkstyle/suppressions.xml
index 853903487ce46..0e348d7b71b26 100644
--- a/checkstyle/suppressions.xml
+++ b/checkstyle/suppressions.xml
@@ -28,7 +28,7 @@
+ files="(Fetcher|Sender|SenderTest|ConsumerCoordinator|KafkaConsumer|KafkaProducer|Utils|TransactionManager|TransactionManagerTest|KafkaAdminClient|NetworkClient|Admin|KafkaRaftClient|KafkaRaftClientTest|RaftClientTestContext).java"/>
errorCounts, Errors error)
errorCounts.put(error, count + 1);
}
- protected abstract Message data();
-
/**
* Parse a response from the provided buffer. The buffer is expected to hold both
* the {@link ResponseHeader} as well as the response payload.
diff --git a/clients/src/test/java/org/apache/kafka/clients/admin/KafkaAdminClientTest.java b/clients/src/test/java/org/apache/kafka/clients/admin/KafkaAdminClientTest.java
index d6e9f04e87eed..74041f0bade6a 100644
--- a/clients/src/test/java/org/apache/kafka/clients/admin/KafkaAdminClientTest.java
+++ b/clients/src/test/java/org/apache/kafka/clients/admin/KafkaAdminClientTest.java
@@ -55,12 +55,12 @@
import org.apache.kafka.common.errors.InvalidTopicException;
import org.apache.kafka.common.errors.LeaderNotAvailableException;
import org.apache.kafka.common.errors.LogDirNotFoundException;
-import org.apache.kafka.common.errors.ThrottlingQuotaExceededException;
-import org.apache.kafka.common.errors.TimeoutException;
import org.apache.kafka.common.errors.NotLeaderOrFollowerException;
import org.apache.kafka.common.errors.OffsetOutOfRangeException;
import org.apache.kafka.common.errors.SaslAuthenticationException;
import org.apache.kafka.common.errors.SecurityDisabledException;
+import org.apache.kafka.common.errors.ThrottlingQuotaExceededException;
+import org.apache.kafka.common.errors.TimeoutException;
import org.apache.kafka.common.errors.TopicAuthorizationException;
import org.apache.kafka.common.errors.TopicDeletionDisabledException;
import org.apache.kafka.common.errors.TopicExistsException;
@@ -75,9 +75,9 @@
import org.apache.kafka.common.message.AlterReplicaLogDirsResponseData.AlterReplicaLogDirTopicResult;
import org.apache.kafka.common.message.AlterUserScramCredentialsResponseData;
import org.apache.kafka.common.message.ApiVersionsResponseData;
+import org.apache.kafka.common.message.CreateAclsResponseData;
import org.apache.kafka.common.message.CreatePartitionsResponseData;
import org.apache.kafka.common.message.CreatePartitionsResponseData.CreatePartitionsTopicResult;
-import org.apache.kafka.common.message.CreateAclsResponseData;
import org.apache.kafka.common.message.CreateTopicsResponseData;
import org.apache.kafka.common.message.CreateTopicsResponseData.CreatableTopicResult;
import org.apache.kafka.common.message.CreateTopicsResponseData.CreatableTopicResultCollection;
@@ -108,16 +108,16 @@
import org.apache.kafka.common.message.ListOffsetsResponseData;
import org.apache.kafka.common.message.ListOffsetsResponseData.ListOffsetsTopicResponse;
import org.apache.kafka.common.message.ListPartitionReassignmentsResponseData;
-import org.apache.kafka.common.message.MetadataResponseData.MetadataResponseTopic;
import org.apache.kafka.common.message.MetadataResponseData.MetadataResponsePartition;
+import org.apache.kafka.common.message.MetadataResponseData.MetadataResponseTopic;
import org.apache.kafka.common.message.OffsetDeleteResponseData;
import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponsePartition;
import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponsePartitionCollection;
import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponseTopic;
import org.apache.kafka.common.message.OffsetDeleteResponseData.OffsetDeleteResponseTopicCollection;
import org.apache.kafka.common.protocol.ApiKeys;
+import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.protocol.Errors;
-import org.apache.kafka.common.protocol.Message;
import org.apache.kafka.common.quota.ClientQuotaAlteration;
import org.apache.kafka.common.quota.ClientQuotaEntity;
import org.apache.kafka.common.quota.ClientQuotaFilter;
@@ -4937,7 +4937,7 @@ public Map errorCounts() {
}
@Override
- protected Message data() {
+ public ApiMessage data() {
return null;
}
diff --git a/core/src/main/scala/kafka/common/InterBrokerSendThread.scala b/core/src/main/scala/kafka/common/InterBrokerSendThread.scala
index 11e1aa8de3920..0ef84265b5201 100644
--- a/core/src/main/scala/kafka/common/InterBrokerSendThread.scala
+++ b/core/src/main/scala/kafka/common/InterBrokerSendThread.scala
@@ -16,8 +16,8 @@
*/
package kafka.common
-import java.util.{ArrayDeque, ArrayList, Collection, Collections, HashMap, Iterator}
import java.util.Map.Entry
+import java.util.{ArrayDeque, ArrayList, Collection, Collections, HashMap, Iterator}
import kafka.utils.ShutdownableThread
import org.apache.kafka.clients.{ClientRequest, ClientResponse, KafkaClient, RequestCompletionHandler}
@@ -32,17 +32,19 @@ import scala.jdk.CollectionConverters._
/**
* Class for inter-broker send thread that utilize a non-blocking network client.
*/
-abstract class InterBrokerSendThread(name: String,
- networkClient: KafkaClient,
- time: Time,
- isInterruptible: Boolean = true)
- extends ShutdownableThread(name, isInterruptible) {
+abstract class InterBrokerSendThread(
+ name: String,
+ networkClient: KafkaClient,
+ requestTimeoutMs: Int,
+ time: Time,
+ isInterruptible: Boolean = true
+) extends ShutdownableThread(name, isInterruptible) {
- def generateRequests(): Iterable[RequestAndCompletionHandler]
- def requestTimeoutMs: Int
private val unsentRequests = new UnsentRequests
- def hasUnsentRequests = unsentRequests.iterator().hasNext
+ def generateRequests(): Iterable[RequestAndCompletionHandler]
+
+ def hasUnsentRequests: Boolean = unsentRequests.iterator().hasNext
override def shutdown(): Unit = {
initiateShutdown()
@@ -51,23 +53,25 @@ abstract class InterBrokerSendThread(name: String,
awaitShutdown()
}
- override def doWork(): Unit = {
- var now = time.milliseconds()
-
+ private def drainGeneratedRequests(): Unit = {
generateRequests().foreach { request =>
- val completionHandler = request.handler
unsentRequests.put(request.destination,
networkClient.newClientRequest(
request.destination.idString,
request.request,
- now,
+ request.creationTimeMs,
true,
requestTimeoutMs,
- completionHandler))
+ request.handler
+ ))
}
+ }
+ protected def pollOnce(maxTimeoutMs: Long): Unit = {
try {
- val timeout = sendRequests(now)
+ drainGeneratedRequests()
+ var now = time.milliseconds()
+ val timeout = sendRequests(now, maxTimeoutMs)
networkClient.poll(timeout, now)
now = time.milliseconds()
checkDisconnects(now)
@@ -85,8 +89,12 @@ abstract class InterBrokerSendThread(name: String,
}
}
- private def sendRequests(now: Long): Long = {
- var pollTimeout = Long.MaxValue
+ override def doWork(): Unit = {
+ pollOnce(Long.MaxValue)
+ }
+
+ private def sendRequests(now: Long, maxTimeoutMs: Long): Long = {
+ var pollTimeout = maxTimeoutMs
for (node <- unsentRequests.nodes.asScala) {
val requestIterator = unsentRequests.requestIterator(node)
while (requestIterator.hasNext) {
@@ -143,9 +151,12 @@ abstract class InterBrokerSendThread(name: String,
def wakeup(): Unit = networkClient.wakeup()
}
-case class RequestAndCompletionHandler(destination: Node,
- request: AbstractRequest.Builder[_ <: AbstractRequest],
- handler: RequestCompletionHandler)
+case class RequestAndCompletionHandler(
+ creationTimeMs: Long,
+ destination: Node,
+ request: AbstractRequest.Builder[_ <: AbstractRequest],
+ handler: RequestCompletionHandler
+)
private class UnsentRequests {
private val unsent = new HashMap[Node, ArrayDeque[ClientRequest]]
@@ -198,5 +209,5 @@ private class UnsentRequests {
requests.iterator
}
- def nodes = unsent.keySet
+ def nodes: java.util.Set[Node] = unsent.keySet
}
diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala
index 029ded837a87d..ac582f37cf9d1 100644
--- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala
+++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala
@@ -24,8 +24,8 @@ import kafka.api.KAFKA_2_8_IV0
import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler}
import kafka.metrics.KafkaMetricsGroup
import kafka.server.{KafkaConfig, MetadataCache}
-import kafka.utils.{CoreUtils, Logging}
import kafka.utils.Implicits._
+import kafka.utils.{CoreUtils, Logging}
import org.apache.kafka.clients._
import org.apache.kafka.common.metrics.Metrics
import org.apache.kafka.common.network._
@@ -36,8 +36,8 @@ import org.apache.kafka.common.security.JaasContext
import org.apache.kafka.common.utils.{LogContext, Time}
import org.apache.kafka.common.{Node, Reconfigurable, TopicPartition}
-import scala.jdk.CollectionConverters._
import scala.collection.{concurrent, immutable}
+import scala.jdk.CollectionConverters._
object TransactionMarkerChannelManager {
def apply(config: KafkaConfig,
@@ -127,11 +127,14 @@ class TxnMarkerQueue(@volatile var destination: Node) {
def totalNumMarkers(txnTopicPartition: Int): Int = markersPerTxnTopicPartition.get(txnTopicPartition).fold(0)(_.size)
}
-class TransactionMarkerChannelManager(config: KafkaConfig,
- metadataCache: MetadataCache,
- networkClient: NetworkClient,
- txnStateManager: TransactionStateManager,
- time: Time) extends InterBrokerSendThread("TxnMarkerSenderThread-" + config.brokerId, networkClient, time) with Logging with KafkaMetricsGroup {
+class TransactionMarkerChannelManager(
+ config: KafkaConfig,
+ metadataCache: MetadataCache,
+ networkClient: NetworkClient,
+ txnStateManager: TransactionStateManager,
+ time: Time
+) extends InterBrokerSendThread("TxnMarkerSenderThread-" + config.brokerId, networkClient, config.requestTimeoutMs, time)
+ with Logging with KafkaMetricsGroup {
this.logIdent = "[Transaction Marker Channel Manager " + config.brokerId + "]: "
@@ -145,8 +148,6 @@ class TransactionMarkerChannelManager(config: KafkaConfig,
private val transactionsWithPendingMarkers = new ConcurrentHashMap[String, PendingCompleteTxn]
- override val requestTimeoutMs: Int = config.requestTimeoutMs
-
val writeTxnMarkersRequestVersion: Short =
if (config.interBrokerProtocolVersion >= KAFKA_2_8_IV0) 1
else 0
@@ -154,8 +155,6 @@ class TransactionMarkerChannelManager(config: KafkaConfig,
newGauge("UnknownDestinationQueueSize", () => markersQueueForUnknownBroker.totalNumMarkers)
newGauge("LogAppendRetryQueueSize", () => txnLogAppendRetryQueue.size)
- override def generateRequests() = drainQueuedTransactionMarkers()
-
override def shutdown(): Unit = {
super.shutdown()
markersQueuePerBroker.clear()
@@ -191,7 +190,7 @@ class TransactionMarkerChannelManager(config: KafkaConfig,
}
}
- private[transaction] def drainQueuedTransactionMarkers(): Iterable[RequestAndCompletionHandler] = {
+ override def generateRequests(): Iterable[RequestAndCompletionHandler] = {
retryLogAppends()
val txnIdAndMarkerEntries: java.util.List[TxnIdAndMarkerEntry] = new util.ArrayList[TxnIdAndMarkerEntry]()
markersQueueForUnknownBroker.forEachTxnTopicPartition { case (_, queue) =>
@@ -209,6 +208,7 @@ class TransactionMarkerChannelManager(config: KafkaConfig,
addTxnMarkersToBrokerQueue(transactionalId, producerId, producerEpoch, txnResult, coordinatorEpoch, topicPartitions)
}
+ val currentTimeMs = time.milliseconds()
markersQueuePerBroker.values.map { brokerRequestQueue =>
val txnIdAndMarkerEntries = new util.ArrayList[TxnIdAndMarkerEntry]()
brokerRequestQueue.forEachTxnTopicPartition { case (_, queue) =>
@@ -218,7 +218,14 @@ class TransactionMarkerChannelManager(config: KafkaConfig,
}.filter { case (_, entries) => !entries.isEmpty }.map { case (node, entries) =>
val markersToSend = entries.asScala.map(_.txnMarkerEntry).asJava
val requestCompletionHandler = new TransactionMarkerRequestCompletionHandler(node.id, txnStateManager, this, entries)
- RequestAndCompletionHandler(node, new WriteTxnMarkersRequest.Builder(writeTxnMarkersRequestVersion, markersToSend), requestCompletionHandler)
+ val request = new WriteTxnMarkersRequest.Builder(writeTxnMarkersRequestVersion, markersToSend)
+
+ RequestAndCompletionHandler(
+ currentTimeMs,
+ node,
+ request,
+ requestCompletionHandler
+ )
}
}
diff --git a/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala b/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala
index 7f769c8463e09..435167208f70a 100644
--- a/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala
+++ b/core/src/main/scala/kafka/raft/KafkaNetworkChannel.scala
@@ -17,39 +17,23 @@
package kafka.raft
import java.net.InetSocketAddress
-import java.util
-import java.util.concurrent.ArrayBlockingQueue
+import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.atomic.AtomicInteger
+import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler}
import kafka.utils.Logging
-import org.apache.kafka.clients.{ClientRequest, ClientResponse, KafkaClient}
+import org.apache.kafka.clients.{ClientResponse, KafkaClient}
+import org.apache.kafka.common.Node
import org.apache.kafka.common.message._
import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors}
import org.apache.kafka.common.requests._
import org.apache.kafka.common.utils.Time
-import org.apache.kafka.common.{KafkaException, Node}
-import org.apache.kafka.raft.{NetworkChannel, RaftMessage, RaftRequest, RaftResponse, RaftUtil}
+import org.apache.kafka.raft.{NetworkChannel, RaftRequest, RaftResponse, RaftUtil}
import scala.collection.mutable
-import scala.jdk.CollectionConverters._
object KafkaNetworkChannel {
- private[raft] def buildResponse(responseData: ApiMessage): AbstractResponse = {
- responseData match {
- case voteResponse: VoteResponseData =>
- new VoteResponse(voteResponse)
- case beginEpochResponse: BeginQuorumEpochResponseData =>
- new BeginQuorumEpochResponse(beginEpochResponse)
- case endEpochResponse: EndQuorumEpochResponseData =>
- new EndQuorumEpochResponse(endEpochResponse)
- case fetchResponse: FetchResponseData =>
- new FetchResponse(fetchResponse)
- case _ =>
- throw new IllegalArgumentException(s"Unexpected type for responseData: $responseData")
- }
- }
-
private[raft] def buildRequest(requestData: ApiMessage): AbstractRequest.Builder[_ <: AbstractRequest] = {
requestData match {
case voteRequest: VoteRequestData =>
@@ -68,161 +52,113 @@ object KafkaNetworkChannel {
}
}
- private[raft] def responseData(response: AbstractResponse): ApiMessage = {
- response match {
- case voteResponse: VoteResponse => voteResponse.data
- case beginEpochResponse: BeginQuorumEpochResponse => beginEpochResponse.data
- case endEpochResponse: EndQuorumEpochResponse => endEpochResponse.data
- case fetchResponse: FetchResponse[_] => fetchResponse.data
- case _ => throw new IllegalArgumentException(s"Unexpected type for response: $response")
+}
+
+private[raft] class RaftSendThread(
+ name: String,
+ networkClient: KafkaClient,
+ requestTimeoutMs: Int,
+ time: Time,
+ isInterruptible: Boolean = true
+) extends InterBrokerSendThread(
+ name,
+ networkClient,
+ requestTimeoutMs,
+ time,
+ isInterruptible
+) {
+ private val queue = new ConcurrentLinkedQueue[RequestAndCompletionHandler]()
+
+ def generateRequests(): Iterable[RequestAndCompletionHandler] = {
+ val buffer = mutable.Buffer[RequestAndCompletionHandler]()
+ while (true) {
+ val request = queue.poll()
+ if (request == null) {
+ return buffer
+ } else {
+ buffer += request
+ }
}
+ buffer
}
- private[raft] def requestData(request: AbstractRequest): ApiMessage = {
- request match {
- case voteRequest: VoteRequest => voteRequest.data
- case beginEpochRequest: BeginQuorumEpochRequest => beginEpochRequest.data
- case endEpochRequest: EndQuorumEpochRequest => endEpochRequest.data
- case fetchRequest: FetchRequest => fetchRequest.data
- case _ => throw new IllegalArgumentException(s"Unexpected type for request: $request")
- }
+ def sendRequest(request: RequestAndCompletionHandler): Unit = {
+ queue.add(request)
}
}
-class KafkaNetworkChannel(time: Time,
- client: KafkaClient,
- clientId: String,
- retryBackoffMs: Int,
- requestTimeoutMs: Int) extends NetworkChannel with Logging {
+
+class KafkaNetworkChannel(
+ time: Time,
+ client: KafkaClient,
+ requestTimeoutMs: Int
+) extends NetworkChannel with Logging {
import KafkaNetworkChannel._
type ResponseHandler = AbstractResponse => Unit
private val correlationIdCounter = new AtomicInteger(0)
- private val pendingInbound = mutable.Map.empty[Long, ResponseHandler]
- private val undelivered = new ArrayBlockingQueue[RaftMessage](10)
- private val pendingOutbound = new ArrayBlockingQueue[RaftRequest.Outbound](10)
private val endpoints = mutable.HashMap.empty[Int, Node]
- override def newCorrelationId(): Int = correlationIdCounter.getAndIncrement()
-
- private def buildClientRequest(req: RaftRequest.Outbound): ClientRequest = {
- val destination = req.destinationId.toString
- val request = buildRequest(req.data)
- val correlationId = req.correlationId
- val createdTimeMs = req.createdTimeMs
- new ClientRequest(destination, request, correlationId, clientId, createdTimeMs, true,
- requestTimeoutMs, null)
- }
-
- override def send(message: RaftMessage): Unit = {
- message match {
- case request: RaftRequest.Outbound =>
- if (!pendingOutbound.offer(request))
- throw new KafkaException("Pending outbound queue is full")
-
- case response: RaftResponse.Outbound =>
- pendingInbound.remove(response.correlationId).foreach { onResponseReceived: ResponseHandler =>
- onResponseReceived(buildResponse(response.data))
- }
- case _ =>
- throw new IllegalArgumentException("Unhandled message type " + message)
+ private val requestThread = new RaftSendThread(
+ name = "raft-outbound-request-thread",
+ networkClient = client,
+ requestTimeoutMs = requestTimeoutMs,
+ time = time,
+ isInterruptible = false
+ )
+
+ override def send(request: RaftRequest.Outbound): Unit = {
+ def completeFuture(message: ApiMessage): Unit = {
+ val response = new RaftResponse.Inbound(
+ request.correlationId,
+ message,
+ request.destinationId
+ )
+ request.completion.complete(response)
}
- }
- private def sendOutboundRequests(currentTimeMs: Long): Unit = {
- while (!pendingOutbound.isEmpty) {
- val request = pendingOutbound.peek()
- endpoints.get(request.destinationId) match {
- case Some(node) =>
- if (client.connectionFailed(node)) {
- pendingOutbound.poll()
- val apiKey = ApiKeys.forId(request.data.apiKey)
- val disconnectResponse = RaftUtil.errorResponse(apiKey, Errors.BROKER_NOT_AVAILABLE)
- val success = undelivered.offer(new RaftResponse.Inbound(
- request.correlationId, disconnectResponse, request.destinationId))
- if (!success) {
- throw new KafkaException("Undelivered queue is full")
- }
-
- // Make sure to reset the connection state
- client.ready(node, currentTimeMs)
- } else if (client.ready(node, currentTimeMs)) {
- pendingOutbound.poll()
- val clientRequest = buildClientRequest(request)
- client.send(clientRequest, currentTimeMs)
- } else {
- // We will retry this request on the next poll
- return
- }
-
- case None =>
- pendingOutbound.poll()
- val apiKey = ApiKeys.forId(request.data.apiKey)
- val responseData = RaftUtil.errorResponse(apiKey, Errors.BROKER_NOT_AVAILABLE)
- val response = new RaftResponse.Inbound(request.correlationId, responseData, request.destinationId)
- if (!undelivered.offer(response))
- throw new KafkaException("Undelivered queue is full")
+ def onComplete(clientResponse: ClientResponse): Unit = {
+ val response = if (clientResponse.authenticationException != null) {
+ errorResponse(request.data, Errors.CLUSTER_AUTHORIZATION_FAILED)
+ } else if (clientResponse.wasDisconnected()) {
+ errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE)
+ } else {
+ clientResponse.responseBody.data
}
+ completeFuture(response)
}
- }
-
- def getConnectionInfo(nodeId: Int): Node = {
- if (!endpoints.contains(nodeId))
- null
- else
- endpoints(nodeId)
- }
-
- def allConnections(): Set[Node] = {
- endpoints.values.toSet
- }
-
- private def buildInboundRaftResponse(response: ClientResponse): RaftResponse.Inbound = {
- val header = response.requestHeader()
- val data = if (response.authenticationException != null) {
- RaftUtil.errorResponse(header.apiKey, Errors.CLUSTER_AUTHORIZATION_FAILED)
- } else if (response.wasDisconnected) {
- RaftUtil.errorResponse(header.apiKey, Errors.BROKER_NOT_AVAILABLE)
- } else {
- responseData(response.responseBody)
- }
- new RaftResponse.Inbound(header.correlationId, data, response.destination.toInt)
- }
- private def pollInboundResponses(timeoutMs: Long, inboundMessages: util.List[RaftMessage]): Unit = {
- val responses = client.poll(timeoutMs, time.milliseconds())
- for (response <- responses.asScala) {
- inboundMessages.add(buildInboundRaftResponse(response))
+ endpoints.get(request.destinationId) match {
+ case Some(node) =>
+ requestThread.sendRequest(RequestAndCompletionHandler(
+ request.createdTimeMs,
+ destination = node,
+ request = buildRequest(request.data),
+ handler = onComplete
+ ))
+
+ case None =>
+ completeFuture(errorResponse(request.data, Errors.BROKER_NOT_AVAILABLE))
}
}
- private def drainInboundRequests(inboundMessages: util.List[RaftMessage]): Unit = {
- undelivered.drainTo(inboundMessages)
+ // Visible for testing
+ private[raft] def pollOnce(): Unit = {
+ requestThread.doWork()
}
- private def pollInboundMessages(timeoutMs: Long): util.List[RaftMessage] = {
- val pollTimeoutMs = if (!undelivered.isEmpty) {
- 0L
- } else if (!pendingOutbound.isEmpty) {
- retryBackoffMs
- } else {
- timeoutMs
- }
- val messages = new util.ArrayList[RaftMessage]
- pollInboundResponses(pollTimeoutMs, messages)
- drainInboundRequests(messages)
- messages
+ override def newCorrelationId(): Int = {
+ correlationIdCounter.getAndIncrement()
}
- override def receive(timeoutMs: Long): util.List[RaftMessage] = {
- sendOutboundRequests(time.milliseconds())
- pollInboundMessages(timeoutMs)
- }
-
- override def wakeup(): Unit = {
- client.wakeup()
+ private def errorResponse(
+ request: ApiMessage,
+ error: Errors
+ ): ApiMessage = {
+ val apiKey = ApiKeys.forId(request.apiKey)
+ RaftUtil.errorResponse(apiKey, error)
}
override def updateEndpoint(id: Int, address: InetSocketAddress): Unit = {
@@ -230,17 +166,16 @@ class KafkaNetworkChannel(time: Time,
endpoints.put(id, node)
}
- def postInboundRequest(request: AbstractRequest, onResponseReceived: ResponseHandler): Unit = {
- val data = requestData(request)
- val correlationId = newCorrelationId()
- val req = new RaftRequest.Inbound(correlationId, data, time.milliseconds())
- pendingInbound.put(correlationId, onResponseReceived)
- if (!undelivered.offer(req))
- throw new KafkaException("Undelivered queue is full")
- wakeup()
+ def start(): Unit = {
+ requestThread.start()
+ }
+
+ def initiateShutdown(): Unit = {
+ requestThread.initiateShutdown()
}
override def close(): Unit = {
+ requestThread.shutdown()
client.close()
}
diff --git a/core/src/main/scala/kafka/server/AlterIsrManager.scala b/core/src/main/scala/kafka/server/AlterIsrManager.scala
index fa4616bd202b0..e463487167df6 100644
--- a/core/src/main/scala/kafka/server/AlterIsrManager.scala
+++ b/core/src/main/scala/kafka/server/AlterIsrManager.scala
@@ -22,11 +22,12 @@ import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import kafka.api.LeaderAndIsr
import kafka.metrics.KafkaMetricsGroup
-import kafka.utils.{Logging, Scheduler}
+import kafka.utils.{KafkaScheduler, Logging, Scheduler}
import kafka.zk.KafkaZkClient
import org.apache.kafka.clients.ClientResponse
import org.apache.kafka.common.TopicPartition
import org.apache.kafka.common.message.{AlterIsrRequestData, AlterIsrResponseData}
+import org.apache.kafka.common.metrics.Metrics
import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.requests.{AlterIsrRequest, AlterIsrResponse}
import org.apache.kafka.common.utils.Time
@@ -44,7 +45,9 @@ import scala.jdk.CollectionConverters._
* requests.
*/
trait AlterIsrManager {
- def start(): Unit
+ def start(): Unit = {}
+
+ def shutdown(): Unit = {}
def submit(alterIsrItem: AlterIsrItem): Boolean
@@ -57,30 +60,57 @@ case class AlterIsrItem(topicPartition: TopicPartition,
controllerEpoch: Int) // controllerEpoch needed for Zk impl
object AlterIsrManager {
+
/**
* Factory to AlterIsr based implementation, used when IBP >= 2.7-IV2
*/
- def apply(controllerChannelManager: BrokerToControllerChannelManager,
- scheduler: Scheduler,
- time: Time,
- brokerId: Int,
- brokerEpochSupplier: () => Long): AlterIsrManager = {
- new DefaultAlterIsrManager(controllerChannelManager, scheduler, time, brokerId, brokerEpochSupplier)
+ def apply(
+ config: KafkaConfig,
+ metadataCache: MetadataCache,
+ scheduler: KafkaScheduler,
+ time: Time,
+ metrics: Metrics,
+ threadNamePrefix: Option[String],
+ brokerEpochSupplier: () => Long
+ ): AlterIsrManager = {
+ val channelManager = new BrokerToControllerChannelManager(
+ metadataCache = metadataCache,
+ time = time,
+ metrics = metrics,
+ config = config,
+ channelName = "alterIsrChannel",
+ threadNamePrefix = threadNamePrefix,
+ retryTimeoutMs = Long.MaxValue
+ )
+ new DefaultAlterIsrManager(
+ controllerChannelManager = channelManager,
+ scheduler = scheduler,
+ time = time,
+ brokerId = config.brokerId,
+ brokerEpochSupplier = brokerEpochSupplier
+ )
}
/**
* Factory for ZK based implementation, used when IBP < 2.7-IV2
*/
- def apply(scheduler: Scheduler, time: Time, zkClient: KafkaZkClient): AlterIsrManager = {
+ def apply(
+ scheduler: Scheduler,
+ time: Time,
+ zkClient: KafkaZkClient
+ ): AlterIsrManager = {
new ZkIsrManager(scheduler, time, zkClient)
}
+
}
-class DefaultAlterIsrManager(val controllerChannelManager: BrokerToControllerChannelManager,
- val scheduler: Scheduler,
- val time: Time,
- val brokerId: Int,
- val brokerEpochSupplier: () => Long) extends AlterIsrManager with Logging with KafkaMetricsGroup {
+class DefaultAlterIsrManager(
+ val controllerChannelManager: BrokerToControllerChannelManager,
+ val scheduler: Scheduler,
+ val time: Time,
+ val brokerId: Int,
+ val brokerEpochSupplier: () => Long
+) extends AlterIsrManager with Logging with KafkaMetricsGroup {
// Used to allow only one pending ISR update per partition
private val unsentIsrUpdates: util.Map[TopicPartition, AlterIsrItem] = new ConcurrentHashMap[TopicPartition, AlterIsrItem]()
@@ -91,9 +121,14 @@ class DefaultAlterIsrManager(val controllerChannelManager: BrokerToControllerCha
private val lastIsrPropagationMs = new AtomicLong(0)
override def start(): Unit = {
+ controllerChannelManager.start()
scheduler.schedule("send-alter-isr", propagateIsrChanges, 50, 50, TimeUnit.MILLISECONDS)
}
+ override def shutdown(): Unit = {
+ controllerChannelManager.shutdown()
+ }
+
override def submit(alterIsrItem: AlterIsrItem): Boolean = {
unsentIsrUpdates.putIfAbsent(alterIsrItem.topicPartition, alterIsrItem) == null
}
@@ -125,6 +160,7 @@ class DefaultAlterIsrManager(val controllerChannelManager: BrokerToControllerCha
}
debug(s"Sending AlterIsr to controller $message")
+
// We will not timeout AlterISR request, instead letting it retry indefinitely
// until a response is received, or a new LeaderAndIsr overwrites the existing isrState
// which causes the inflight requests to be ignored.
@@ -132,6 +168,7 @@ class DefaultAlterIsrManager(val controllerChannelManager: BrokerToControllerCha
new ControllerRequestCompletionHandler {
override def onComplete(response: ClientResponse): Unit = {
try {
+ debug(s"Received AlterIsr response $response")
val body = response.responseBody().asInstanceOf[AlterIsrResponse]
handleAlterIsrResponse(body, message.brokerEpoch, inflightAlterIsrItems)
} finally {
@@ -142,7 +179,7 @@ class DefaultAlterIsrManager(val controllerChannelManager: BrokerToControllerCha
override def onTimeout(): Unit = {
throw new IllegalStateException("Encountered unexpected timeout when sending AlterIsr to the controller")
}
- }, Long.MaxValue)
+ })
}
private def buildRequest(inflightAlterIsrItems: Seq[AlterIsrItem]): AlterIsrRequestData = {
diff --git a/core/src/main/scala/kafka/server/BrokerToControllerChannelManagerImpl.scala b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala
similarity index 56%
rename from core/src/main/scala/kafka/server/BrokerToControllerChannelManagerImpl.scala
rename to core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala
index c0918ad7d03e8..dac882ef58a5a 100644
--- a/core/src/main/scala/kafka/server/BrokerToControllerChannelManagerImpl.scala
+++ b/core/src/main/scala/kafka/server/BrokerToControllerChannelManager.scala
@@ -17,7 +17,7 @@
package kafka.server
-import java.util.concurrent.{LinkedBlockingDeque, TimeUnit}
+import java.util.concurrent.LinkedBlockingDeque
import kafka.common.{InterBrokerSendThread, RequestAndCompletionHandler}
import kafka.utils.Logging
@@ -30,52 +30,33 @@ import org.apache.kafka.common.requests.AbstractRequest
import org.apache.kafka.common.security.JaasContext
import org.apache.kafka.common.utils.{LogContext, Time}
-import scala.collection.mutable
import scala.jdk.CollectionConverters._
-trait BrokerToControllerChannelManager {
-
- /**
- * Send request to the controller.
- *
- * @param request The request to be sent.
- * @param callback Request completion callback.
- * @param retryDeadlineMs The retry deadline which will only be checked after receiving a response.
- * This means that in the worst case, the total timeout would be twice of
- * the configured timeout.
- */
- def sendRequest(request: AbstractRequest.Builder[_ <: AbstractRequest],
- callback: ControllerRequestCompletionHandler,
- retryDeadlineMs: Long): Unit
-
- def start(): Unit
-
- def shutdown(): Unit
-}
-
/**
* This class manages the connection between a broker and the controller. It runs a single
- * {@link BrokerToControllerRequestThread} which uses the broker's metadata cache as its own metadata to find
+ * [[BrokerToControllerRequestThread]] which uses the broker's metadata cache as its own metadata to find
* and connect to the controller. The channel is async and runs the network connection in the background.
* The maximum number of in-flight requests are set to one to ensure orderly response from the controller, therefore
* care must be taken to not block on outstanding requests for too long.
*/
-class BrokerToControllerChannelManagerImpl(metadataCache: kafka.server.MetadataCache,
- time: Time,
- metrics: Metrics,
- config: KafkaConfig,
- channelName: String,
- threadNamePrefix: Option[String] = None) extends BrokerToControllerChannelManager with Logging {
- private val requestQueue = new LinkedBlockingDeque[BrokerToControllerQueueItem]
+class BrokerToControllerChannelManager(
+ metadataCache: kafka.server.MetadataCache,
+ time: Time,
+ metrics: Metrics,
+ config: KafkaConfig,
+ channelName: String,
+ threadNamePrefix: Option[String],
+ retryTimeoutMs: Long
+) extends Logging {
private val logContext = new LogContext(s"[broker-${config.brokerId}-to-controller] ")
private val manualMetadataUpdater = new ManualMetadataUpdater()
private val requestThread = newRequestThread
- override def start(): Unit = {
+ def start(): Unit = {
requestThread.start()
}
- override def shutdown(): Unit = {
+ def shutdown(): Unit = {
requestThread.shutdown()
requestThread.awaitShutdown()
info(s"Broker to controller channel manager for $channelName shutdown")
@@ -131,15 +112,33 @@ class BrokerToControllerChannelManagerImpl(metadataCache: kafka.server.MetadataC
case Some(name) => s"$name:broker-${config.brokerId}-to-controller-send-thread"
}
- new BrokerToControllerRequestThread(networkClient, manualMetadataUpdater, requestQueue, metadataCache, config,
- brokerToControllerListenerName, time, threadName)
+ new BrokerToControllerRequestThread(
+ networkClient,
+ manualMetadataUpdater,
+ metadataCache,
+ config,
+ brokerToControllerListenerName,
+ time,
+ threadName,
+ retryTimeoutMs
+ )
}
- override def sendRequest(request: AbstractRequest.Builder[_ <: AbstractRequest],
- callback: ControllerRequestCompletionHandler,
- retryDeadlineMs: Long): Unit = {
- requestQueue.put(BrokerToControllerQueueItem(request, callback, retryDeadlineMs))
- requestThread.wakeup()
+ /**
+ * Send request to the controller.
+ *
+ * @param request The request to be sent.
+ * @param callback Request completion callback.
+ */
+ def sendRequest(
+ request: AbstractRequest.Builder[_ <: AbstractRequest],
+ callback: ControllerRequestCompletionHandler
+ ): Unit = {
+ requestThread.enqueue(BrokerToControllerQueueItem(
+ time.milliseconds(),
+ request,
+ callback
+ ))
}
}
@@ -152,48 +151,65 @@ abstract class ControllerRequestCompletionHandler extends RequestCompletionHandl
def onTimeout(): Unit
}
-case class BrokerToControllerQueueItem(request: AbstractRequest.Builder[_ <: AbstractRequest],
- callback: ControllerRequestCompletionHandler,
- deadlineMs: Long)
-
-class BrokerToControllerRequestThread(networkClient: KafkaClient,
- metadataUpdater: ManualMetadataUpdater,
- requestQueue: LinkedBlockingDeque[BrokerToControllerQueueItem],
- metadataCache: kafka.server.MetadataCache,
- config: KafkaConfig,
- listenerName: ListenerName,
- time: Time,
- threadName: String)
- extends InterBrokerSendThread(threadName, networkClient, time, isInterruptible = false) {
-
+case class BrokerToControllerQueueItem(
+ createdTimeMs: Long,
+ request: AbstractRequest.Builder[_ <: AbstractRequest],
+ callback: ControllerRequestCompletionHandler
+)
+
+class BrokerToControllerRequestThread(
+ networkClient: KafkaClient,
+ metadataUpdater: ManualMetadataUpdater,
+ metadataCache: kafka.server.MetadataCache,
+ config: KafkaConfig,
+ listenerName: ListenerName,
+ time: Time,
+ threadName: String,
+ retryTimeoutMs: Long
+) extends InterBrokerSendThread(threadName, networkClient, config.controllerSocketTimeoutMs, time, isInterruptible = false) {
+
+ private val requestQueue = new LinkedBlockingDeque[BrokerToControllerQueueItem]()
private var activeController: Option[Node] = None
- override def requestTimeoutMs: Int = config.controllerSocketTimeoutMs
+ def enqueue(request: BrokerToControllerQueueItem): Unit = {
+ requestQueue.add(request)
+ if (activeController.isDefined) {
+ wakeup()
+ }
+ }
- override def generateRequests(): Iterable[RequestAndCompletionHandler] = {
- val requestsToSend = new mutable.Queue[RequestAndCompletionHandler]
- val topRequest = requestQueue.poll()
- if (topRequest != null) {
- val request = RequestAndCompletionHandler(
- activeController.get,
- topRequest.request,
- handleResponse(topRequest)
- )
+ def queueSize: Int = {
+ requestQueue.size
+ }
- requestsToSend.enqueue(request)
+ override def generateRequests(): Iterable[RequestAndCompletionHandler] = {
+ val currentTimeMs = time.milliseconds()
+ val requestIter = requestQueue.iterator()
+ while (requestIter.hasNext) {
+ val request = requestIter.next
+ if (currentTimeMs - request.createdTimeMs >= retryTimeoutMs) {
+ requestIter.remove()
+ request.callback.onTimeout()
+ } else if (activeController.isDefined) {
+ requestIter.remove()
+ return Some(RequestAndCompletionHandler(
+ time.milliseconds(),
+ activeController.get,
+ request.request,
+ handleResponse(request)
+ ))
+ }
}
- requestsToSend
+ None
}
private[server] def handleResponse(request: BrokerToControllerQueueItem)(response: ClientResponse): Unit = {
- if (hasTimedOut(request, response)) {
- request.callback.onTimeout()
- } else if (response.wasDisconnected()) {
+ if (response.wasDisconnected()) {
activeController = None
requestQueue.putFirst(request)
} else if (response.responseBody().errorCounts().containsKey(Errors.NOT_CONTROLLER)) {
// just close the controller connection and wait for metadata cache update in doWork
- networkClient.close(activeController.get.idString)
+ networkClient.disconnect(activeController.get.idString)
activeController = None
requestQueue.putFirst(request)
} else {
@@ -201,27 +217,23 @@ class BrokerToControllerRequestThread(networkClient: KafkaClient,
}
}
- private def hasTimedOut(request: BrokerToControllerQueueItem, response: ClientResponse): Boolean = {
- response.receivedTimeMs() > request.deadlineMs
- }
-
- private[server] def backoff(): Unit = pause(100, TimeUnit.MILLISECONDS)
-
override def doWork(): Unit = {
if (activeController.isDefined) {
- super.doWork()
+ super.pollOnce(Long.MaxValue)
} else {
debug("Controller isn't cached, looking for local metadata changes")
val controllerOpt = metadataCache.getControllerId.flatMap(metadataCache.getAliveBroker)
- if (controllerOpt.isDefined) {
- if (activeController.isEmpty || activeController.exists(_.id != controllerOpt.get.id))
- info(s"Recorded new controller, from now on will use broker ${controllerOpt.get.id}")
- activeController = Option(controllerOpt.get.node(listenerName))
- metadataUpdater.setNodes(metadataCache.getAliveBrokers.map(_.node(listenerName)).asJava)
- } else {
- // need to backoff to avoid tight loops
- debug("No controller defined in metadata cache, retrying after backoff")
- backoff()
+ controllerOpt match {
+ case Some(controller) =>
+ info(s"Recorded new controller, from now on will use broker $controller")
+ val controllerNode = controller.node(listenerName)
+ activeController = Some(controllerNode)
+ metadataUpdater.setNodes(Seq(controllerNode).asJava)
+
+ case None =>
+        // need to back off to avoid tight loops
+ debug("No controller defined in metadata cache, retrying after backoff")
+ super.pollOnce(maxTimeoutMs = 100)
}
}
}
diff --git a/core/src/main/scala/kafka/server/ForwardingManager.scala b/core/src/main/scala/kafka/server/ForwardingManager.scala
index 54bef3eefc8bf..e4261a73a3aa1 100644
--- a/core/src/main/scala/kafka/server/ForwardingManager.scala
+++ b/core/src/main/scala/kafka/server/ForwardingManager.scala
@@ -22,6 +22,7 @@ import java.nio.ByteBuffer
import kafka.network.RequestChannel
import kafka.utils.Logging
import org.apache.kafka.clients.ClientResponse
+import org.apache.kafka.common.metrics.Metrics
import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse, EnvelopeRequest, EnvelopeResponse, RequestHeader}
import org.apache.kafka.common.utils.Time
@@ -29,12 +30,55 @@ import org.apache.kafka.common.utils.Time
import scala.compat.java8.OptionConverters._
import scala.concurrent.TimeoutException
-class ForwardingManager(channelManager: BrokerToControllerChannelManager,
- time: Time,
- retryTimeoutMs: Long) extends Logging {
+trait ForwardingManager {
+ def forwardRequest(
+ request: RequestChannel.Request,
+ responseCallback: AbstractResponse => Unit
+ ): Unit
- def forwardRequest(request: RequestChannel.Request,
- responseCallback: AbstractResponse => Unit): Unit = {
+ def start(): Unit = {}
+
+ def shutdown(): Unit = {}
+}
+
+object ForwardingManager {
+
+ def apply(
+ config: KafkaConfig,
+ metadataCache: MetadataCache,
+ time: Time,
+ metrics: Metrics,
+ threadNamePrefix: Option[String]
+ ): ForwardingManager = {
+ val channelManager = new BrokerToControllerChannelManager(
+ metadataCache = metadataCache,
+ time = time,
+ metrics = metrics,
+ config = config,
+ channelName = "forwardingChannel",
+ threadNamePrefix = threadNamePrefix,
+ retryTimeoutMs = config.requestTimeoutMs.longValue
+ )
+ new ForwardingManagerImpl(channelManager)
+ }
+}
+
+class ForwardingManagerImpl(
+ channelManager: BrokerToControllerChannelManager
+) extends ForwardingManager with Logging {
+
+ override def start(): Unit = {
+ channelManager.start()
+ }
+
+ override def shutdown(): Unit = {
+ channelManager.shutdown()
+ }
+
+ override def forwardRequest(
+ request: RequestChannel.Request,
+ responseCallback: AbstractResponse => Unit
+ ): Unit = {
val principalSerde = request.context.principalSerde.asScala.getOrElse(
throw new IllegalArgumentException(s"Cannot deserialize principal from request $request " +
"since there is no serde defined")
@@ -75,14 +119,7 @@ class ForwardingManager(channelManager: BrokerToControllerChannelManager,
}
}
- val currentTime = time.milliseconds()
- val deadlineMs =
- if (Long.MaxValue - currentTime < retryTimeoutMs)
- Long.MaxValue
- else
- currentTime + retryTimeoutMs
-
- channelManager.sendRequest(envelopeRequest, new ForwardingResponseHandler, deadlineMs)
+ channelManager.sendRequest(envelopeRequest, new ForwardingResponseHandler)
}
private def parseResponse(
diff --git a/core/src/main/scala/kafka/server/KafkaServer.scala b/core/src/main/scala/kafka/server/KafkaServer.scala
index 1e7bb19c6dc4e..220484eabda99 100755
--- a/core/src/main/scala/kafka/server/KafkaServer.scala
+++ b/core/src/main/scala/kafka/server/KafkaServer.scala
@@ -50,8 +50,8 @@ import org.apache.kafka.common.{ClusterResource, Endpoint, Node}
import org.apache.kafka.server.authorizer.Authorizer
import org.apache.zookeeper.client.ZKClientConfig
-import scala.jdk.CollectionConverters._
import scala.collection.{Map, Seq, mutable}
+import scala.jdk.CollectionConverters._
object KafkaServer {
// Copy the subset of properties that are relevant to Logs
@@ -168,9 +168,9 @@ class KafkaServer(val config: KafkaConfig, time: Time = Time.SYSTEM, threadNameP
var kafkaController: KafkaController = null
- var forwardingChannelManager: BrokerToControllerChannelManager = null
+ var forwardingManager: ForwardingManager = null
- var alterIsrChannelManager: BrokerToControllerChannelManager = null
+ var alterIsrManager: AlterIsrManager = null
var kafkaScheduler: KafkaScheduler = null
@@ -309,11 +309,23 @@ class KafkaServer(val config: KafkaConfig, time: Time = Time.SYSTEM, threadNameP
socketServer.startup(startProcessingRequests = false)
/* start replica manager */
- alterIsrChannelManager = new BrokerToControllerChannelManagerImpl(
- metadataCache, time, metrics, config, "alterIsrChannel", threadNamePrefix)
+ alterIsrManager = if (config.interBrokerProtocolVersion.isAlterIsrSupported) {
+ AlterIsrManager(
+ config = config,
+ metadataCache = metadataCache,
+ scheduler = kafkaScheduler,
+ time = time,
+ metrics = metrics,
+ threadNamePrefix = threadNamePrefix,
+ brokerEpochSupplier = () => kafkaController.brokerEpoch
+ )
+ } else {
+ AlterIsrManager(kafkaScheduler, time, zkClient)
+ }
+ alterIsrManager.start()
+
replicaManager = createReplicaManager(isShuttingDown)
replicaManager.startup()
- alterIsrChannelManager.start()
val brokerInfo = createBrokerInfo
val brokerEpoch = zkClient.registerBroker(brokerInfo)
@@ -329,13 +341,15 @@ class KafkaServer(val config: KafkaConfig, time: Time = Time.SYSTEM, threadNameP
kafkaController = new KafkaController(config, zkClient, time, metrics, brokerInfo, brokerEpoch, tokenManager, brokerFeatures, featureCache, threadNamePrefix)
kafkaController.startup()
- var forwardingManager: ForwardingManager = null
if (config.metadataQuorumEnabled) {
- /* start forwarding manager */
- forwardingChannelManager = new BrokerToControllerChannelManagerImpl(metadataCache, time, metrics,
- config, "forwardingChannel", threadNamePrefix)
- forwardingChannelManager.start()
- forwardingManager = new ForwardingManager(forwardingChannelManager, time, config.requestTimeoutMs.longValue())
+ forwardingManager = ForwardingManager(
+ config,
+ metadataCache,
+ time,
+ metrics,
+ threadNamePrefix
+ )
+ forwardingManager.start()
}
adminManager = new AdminManager(config, metrics, metadataCache, zkClient)
@@ -444,12 +458,6 @@ class KafkaServer(val config: KafkaConfig, time: Time = Time.SYSTEM, threadNameP
}
protected def createReplicaManager(isShuttingDown: AtomicBoolean): ReplicaManager = {
- val alterIsrManager: AlterIsrManager = if (config.interBrokerProtocolVersion.isAlterIsrSupported) {
- AlterIsrManager(alterIsrChannelManager, kafkaScheduler,
- time, config.brokerId, () => kafkaController.brokerEpoch)
- } else {
- AlterIsrManager(kafkaScheduler, time, zkClient)
- }
new ReplicaManager(config, metrics, time, zkClient, kafkaScheduler, logManager, isShuttingDown, quotaManagers,
brokerTopicStats, metadataCache, logDirFailureChannel, alterIsrManager)
}
@@ -730,11 +738,11 @@ class KafkaServer(val config: KafkaConfig, time: Time = Time.SYSTEM, threadNameP
if (replicaManager != null)
CoreUtils.swallow(replicaManager.shutdown(), this)
- if (alterIsrChannelManager != null)
- CoreUtils.swallow(alterIsrChannelManager.shutdown(), this)
+ if (alterIsrManager != null)
+ CoreUtils.swallow(alterIsrManager.shutdown(), this)
- if (forwardingChannelManager != null)
- CoreUtils.swallow(forwardingChannelManager.shutdown(), this)
+ if (forwardingManager != null)
+ CoreUtils.swallow(forwardingManager.shutdown(), this)
if (logManager != null)
CoreUtils.swallow(logManager.shutdown(), this)
diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala
index 5da3b692e05fd..012c0ea916bfc 100644
--- a/core/src/main/scala/kafka/server/ReplicaManager.scala
+++ b/core/src/main/scala/kafka/server/ReplicaManager.scala
@@ -293,7 +293,6 @@ class ReplicaManager(val config: KafkaConfig,
// A follower can lag behind leader for up to config.replicaLagTimeMaxMs x 1.5 before it is removed from ISR
scheduler.schedule("isr-expiration", maybeShrinkIsr _, period = config.replicaLagTimeMaxMs / 2, unit = TimeUnit.MILLISECONDS)
scheduler.schedule("shutdown-idle-replica-alter-log-dirs-thread", shutdownIdleReplicaAlterLogDirsThread _, period = 10000L, unit = TimeUnit.MILLISECONDS)
- alterIsrManager.start()
// If inter-broker protocol (IBP) < 1.0, the controller will send LeaderAndIsrRequest V0 which does not include isNew field.
// In this case, the broker receiving the request cannot determine whether it is safe to create a partition if a log directory has failed.
diff --git a/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala b/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala
index 4fad0db4f1e83..f8ab30cac8dc2 100644
--- a/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala
+++ b/core/src/main/scala/kafka/tools/TestRaftRequestHandler.scala
@@ -19,13 +19,15 @@ package kafka.tools
import kafka.network.RequestChannel
import kafka.network.RequestConvertToJson
-import kafka.raft.KafkaNetworkChannel
import kafka.server.ApiRequestHandler
import kafka.utils.Logging
import org.apache.kafka.common.internals.FatalExitError
-import org.apache.kafka.common.protocol.{ApiKeys, Errors}
-import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse}
+import org.apache.kafka.common.message.{BeginQuorumEpochResponseData, EndQuorumEpochResponseData, FetchResponseData, VoteResponseData}
+import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors}
+import org.apache.kafka.common.record.BaseRecords
+import org.apache.kafka.common.requests.{AbstractRequest, AbstractResponse, BeginQuorumEpochResponse, EndQuorumEpochResponse, FetchResponse, VoteResponse}
import org.apache.kafka.common.utils.Time
+import org.apache.kafka.raft.{KafkaRaftClient, RaftRequest}
import scala.jdk.CollectionConverters._
@@ -33,7 +35,7 @@ import scala.jdk.CollectionConverters._
* Simple request handler implementation for use by [[TestRaftServer]].
*/
class TestRaftRequestHandler(
- networkChannel: KafkaNetworkChannel,
+ raftClient: KafkaRaftClient[_],
requestChannel: RequestChannel,
time: Time,
) extends ApiRequestHandler with Logging {
@@ -43,13 +45,10 @@ class TestRaftRequestHandler(
trace(s"Handling request:${request.requestDesc(true)} from connection ${request.context.connectionId};" +
s"securityProtocol:${request.context.securityProtocol},principal:${request.context.principal}")
request.header.apiKey match {
- case ApiKeys.VOTE
- | ApiKeys.BEGIN_QUORUM_EPOCH
- | ApiKeys.END_QUORUM_EPOCH
- | ApiKeys.FETCH =>
- val requestBody = request.body[AbstractRequest]
- networkChannel.postInboundRequest(requestBody, response =>
- sendResponse(request, Some(response)))
+ case ApiKeys.VOTE => handleVote(request)
+ case ApiKeys.BEGIN_QUORUM_EPOCH => handleBeginQuorumEpoch(request)
+ case ApiKeys.END_QUORUM_EPOCH => handleEndQuorumEpoch(request)
+ case ApiKeys.FETCH => handleFetch(request)
case _ => throw new IllegalArgumentException(s"Unsupported api key: ${request.header.apiKey}")
}
@@ -63,6 +62,45 @@ class TestRaftRequestHandler(
}
}
+ private def handleVote(request: RequestChannel.Request): Unit = {
+ handle(request, response => new VoteResponse(response.asInstanceOf[VoteResponseData]))
+ }
+
+ private def handleBeginQuorumEpoch(request: RequestChannel.Request): Unit = {
+ handle(request, response => new BeginQuorumEpochResponse(response.asInstanceOf[BeginQuorumEpochResponseData]))
+ }
+
+ private def handleEndQuorumEpoch(request: RequestChannel.Request): Unit = {
+ handle(request, response => new EndQuorumEpochResponse(response.asInstanceOf[EndQuorumEpochResponseData]))
+ }
+
+ private def handleFetch(request: RequestChannel.Request): Unit = {
+ handle(request, response => new FetchResponse[BaseRecords](response.asInstanceOf[FetchResponseData]))
+ }
+
+ private def handle(
+ request: RequestChannel.Request,
+ buildResponse: ApiMessage => AbstractResponse
+ ): Unit = {
+ val requestBody = request.body[AbstractRequest]
+ val inboundRequest = new RaftRequest.Inbound(
+ request.header.correlationId,
+ requestBody.data,
+ time.milliseconds()
+ )
+
+ inboundRequest.completion.whenComplete((response, exception) => {
+ val res = if (exception != null) {
+ requestBody.getErrorResponse(exception)
+ } else {
+ buildResponse(response.data)
+ }
+ sendResponse(request, Some(res))
+ })
+
+ raftClient.handle(inboundRequest)
+ }
+
private def handleError(request: RequestChannel.Request, err: Throwable): Unit = {
error("Error when handling request: " +
s"clientId=${request.header.clientId}, " +
diff --git a/core/src/main/scala/kafka/tools/TestRaftServer.scala b/core/src/main/scala/kafka/tools/TestRaftServer.scala
index 8d179dccacd02..fb1f7c5b04686 100644
--- a/core/src/main/scala/kafka/tools/TestRaftServer.scala
+++ b/core/src/main/scala/kafka/tools/TestRaftServer.scala
@@ -95,6 +95,8 @@ class TestRaftServer(
val metadataLog = buildMetadataLog(logDir)
val networkChannel = buildNetworkChannel(raftConfig, logContext)
+ networkChannel.start()
+
val raftClient = buildRaftClient(
raftConfig,
metadataLog,
@@ -114,7 +116,7 @@ class TestRaftServer(
raftClient.initialize()
val requestHandler = new TestRaftRequestHandler(
- networkChannel,
+ raftClient,
socketServer.dataPlaneRequestChannel,
time
)
@@ -163,9 +165,7 @@ class TestRaftServer(
private def buildNetworkChannel(raftConfig: RaftConfig,
logContext: LogContext): KafkaNetworkChannel = {
val netClient = buildNetworkClient(raftConfig, logContext)
- val clientId = s"Raft-${config.brokerId}"
- new KafkaNetworkChannel(time, netClient, clientId,
- raftConfig.retryBackoffMs, raftConfig.requestTimeoutMs)
+ new KafkaNetworkChannel(time, netClient, raftConfig.requestTimeoutMs)
}
private def buildMetadataLog(logDir: File): KafkaMetadataLog = {
diff --git a/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala b/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala
index f5110bf0b03e0..d716cbd00f280 100644
--- a/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala
+++ b/core/src/test/scala/kafka/common/InterBrokerSendThreadTest.scala
@@ -35,12 +35,26 @@ class InterBrokerSendThreadTest {
private val completionHandler = new StubCompletionHandler
private val requestTimeoutMs = 1000
+ class TestInterBrokerSendThread(
+ ) extends InterBrokerSendThread("name", networkClient, requestTimeoutMs, time) {
+ private val queue = mutable.Queue[RequestAndCompletionHandler]()
+
+ def enqueue(request: RequestAndCompletionHandler): Unit = {
+ queue += request
+ }
+
+ override def generateRequests(): Iterable[RequestAndCompletionHandler] = {
+ if (queue.isEmpty) {
+ None
+ } else {
+ Some(queue.dequeue())
+ }
+ }
+ }
+
@Test
def shouldNotSendAnythingWhenNoRequests(): Unit = {
- val sendThread = new InterBrokerSendThread("name", networkClient, time) {
- override val requestTimeoutMs: Int = InterBrokerSendThreadTest.this.requestTimeoutMs
- override def generateRequests() = mutable.Iterable.empty
- }
+ val sendThread = new TestInterBrokerSendThread()
// poll is always called but there should be no further invocations on NetworkClient
EasyMock.expect(networkClient.poll(EasyMock.anyLong(), EasyMock.anyLong()))
@@ -58,11 +72,8 @@ class InterBrokerSendThreadTest {
def shouldCreateClientRequestAndSendWhenNodeIsReady(): Unit = {
val request = new StubRequestBuilder()
val node = new Node(1, "", 8080)
- val handler = RequestAndCompletionHandler(node, request, completionHandler)
- val sendThread = new InterBrokerSendThread("name", networkClient, time) {
- override val requestTimeoutMs: Int = InterBrokerSendThreadTest.this.requestTimeoutMs
- override def generateRequests() = List[RequestAndCompletionHandler](handler)
- }
+ val handler = RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler)
+ val sendThread = new TestInterBrokerSendThread()
val clientRequest = new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, handler.handler)
@@ -85,6 +96,7 @@ class InterBrokerSendThreadTest {
EasyMock.replay(networkClient)
+ sendThread.enqueue(handler)
sendThread.doWork()
EasyMock.verify(networkClient)
@@ -95,21 +107,18 @@ class InterBrokerSendThreadTest {
def shouldCallCompletionHandlerWithDisconnectedResponseWhenNodeNotReady(): Unit = {
val request = new StubRequestBuilder
val node = new Node(1, "", 8080)
- val requestAndCompletionHandler = RequestAndCompletionHandler(node, request, completionHandler)
- val sendThread = new InterBrokerSendThread("name", networkClient, time) {
- override val requestTimeoutMs: Int = InterBrokerSendThreadTest.this.requestTimeoutMs
- override def generateRequests() = List[RequestAndCompletionHandler](requestAndCompletionHandler)
- }
+ val handler = RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler)
+ val sendThread = new TestInterBrokerSendThread()
- val clientRequest = new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, requestAndCompletionHandler.handler)
+ val clientRequest = new ClientRequest("dest", request, 0, "1", 0, true, requestTimeoutMs, handler.handler)
EasyMock.expect(networkClient.newClientRequest(
EasyMock.eq("1"),
- EasyMock.same(requestAndCompletionHandler.request),
+ EasyMock.same(handler.request),
EasyMock.anyLong(),
EasyMock.eq(true),
EasyMock.eq(requestTimeoutMs),
- EasyMock.same(requestAndCompletionHandler.handler)))
+ EasyMock.same(handler.handler)))
.andReturn(clientRequest)
EasyMock.expect(networkClient.ready(node, time.milliseconds()))
@@ -129,6 +138,7 @@ class InterBrokerSendThreadTest {
EasyMock.replay(networkClient)
+ sendThread.enqueue(handler)
sendThread.doWork()
EasyMock.verify(networkClient)
@@ -139,11 +149,8 @@ class InterBrokerSendThreadTest {
def testFailingExpiredRequests(): Unit = {
val request = new StubRequestBuilder()
val node = new Node(1, "", 8080)
- val handler = RequestAndCompletionHandler(node, request, completionHandler)
- val sendThread = new InterBrokerSendThread("name", networkClient, time) {
- override val requestTimeoutMs: Int = InterBrokerSendThreadTest.this.requestTimeoutMs
- override def generateRequests() = List[RequestAndCompletionHandler](handler)
- }
+ val handler = RequestAndCompletionHandler(time.milliseconds(), node, request, completionHandler)
+ val sendThread = new TestInterBrokerSendThread()
val clientRequest = new ClientRequest("dest",
request,
@@ -158,7 +165,7 @@ class InterBrokerSendThreadTest {
EasyMock.expect(networkClient.newClientRequest(
EasyMock.eq("1"),
EasyMock.same(handler.request),
- EasyMock.eq(time.milliseconds()),
+ EasyMock.eq(handler.creationTimeMs),
EasyMock.eq(true),
EasyMock.eq(requestTimeoutMs),
EasyMock.same(handler.handler)))
@@ -180,6 +187,7 @@ class InterBrokerSendThreadTest {
EasyMock.replay(networkClient)
+ sendThread.enqueue(handler)
sendThread.doWork()
EasyMock.verify(networkClient)
diff --git a/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala b/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala
index 9ef9d7c33bb8e..7b65bea556151 100644
--- a/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala
+++ b/core/src/test/scala/kafka/server/BrokerToControllerRequestThreadTest.scala
@@ -17,36 +17,68 @@
package kafka.server
-import java.util.concurrent.{CountDownLatch, LinkedBlockingDeque, TimeUnit}
import java.util.Collections
+import java.util.concurrent.atomic.AtomicBoolean
+
import kafka.cluster.{Broker, EndPoint}
import kafka.utils.TestUtils
import org.apache.kafka.clients.{ClientResponse, ManualMetadataUpdater, Metadata, MockClient}
import org.apache.kafka.common.feature.Features
import org.apache.kafka.common.feature.Features.emptySupportedFeatures
-import org.apache.kafka.common.utils.{MockTime, SystemTime}
import org.apache.kafka.common.message.MetadataRequestData
import org.apache.kafka.common.network.ListenerName
import org.apache.kafka.common.protocol.Errors
import org.apache.kafka.common.requests.{AbstractRequest, MetadataRequest, MetadataResponse, RequestTestUtils}
import org.apache.kafka.common.security.auth.SecurityProtocol
+import org.apache.kafka.common.utils.MockTime
import org.junit.Assert.{assertEquals, assertFalse, assertTrue}
import org.junit.Test
import org.mockito.Mockito._
class BrokerToControllerRequestThreadTest {
+ @Test
+ def testRetryTimeoutWhileControllerNotAvailable(): Unit = {
+ val time = new MockTime()
+ val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181"))
+ val metadata = mock(classOf[Metadata])
+ val mockClient = new MockClient(time, metadata)
+ val metadataCache = mock(classOf[MetadataCache])
+ val listenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)
+
+ when(metadataCache.getControllerId).thenReturn(None)
+
+ val retryTimeoutMs = 30000
+ val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), metadataCache,
+ config, listenerName, time, "", retryTimeoutMs)
+
+ val completionHandler = new TestRequestCompletionHandler(None)
+ val queueItem = BrokerToControllerQueueItem(
+ time.milliseconds(),
+ new MetadataRequest.Builder(new MetadataRequestData()),
+ completionHandler
+ )
+
+ testRequestThread.enqueue(queueItem)
+ testRequestThread.doWork()
+ assertEquals(1, testRequestThread.queueSize)
+
+ time.sleep(retryTimeoutMs)
+ testRequestThread.doWork()
+ assertEquals(0, testRequestThread.queueSize)
+ assertTrue(completionHandler.timedOut.get)
+ }
+
@Test
def testRequestsSent(): Unit = {
// just a simple test that tests whether the request from 1 -> 2 is sent and the response callback is called
- val time = new SystemTime
+ val time = new MockTime()
val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181"))
val controllerId = 2
val metadata = mock(classOf[Metadata])
val mockClient = new MockClient(time, metadata)
- val requestQueue = new LinkedBlockingDeque[BrokerToControllerQueueItem]()
val metadataCache = mock(classOf[MetadataCache])
val listenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)
val activeController = new Broker(controllerId,
@@ -57,28 +89,33 @@ class BrokerToControllerRequestThreadTest {
when(metadataCache.getAliveBroker(controllerId)).thenReturn(Some(activeController))
val expectedResponse = RequestTestUtils.metadataUpdateWith(2, Collections.singletonMap("a", 2))
- val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), requestQueue, metadataCache,
- config, listenerName, time, "")
+ val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), metadataCache,
+ config, listenerName, time, "", retryTimeoutMs = Long.MaxValue)
mockClient.prepareResponse(expectedResponse)
- val responseLatch = new CountDownLatch(1)
+ val completionHandler = new TestRequestCompletionHandler(Some(expectedResponse))
val queueItem = BrokerToControllerQueueItem(
+ time.milliseconds(),
new MetadataRequest.Builder(new MetadataRequestData()),
- new TestRequestCompletionHandler(expectedResponse, responseLatch),
- Long.MaxValue)
- requestQueue.put(queueItem)
+ completionHandler
+ )
+
+ testRequestThread.enqueue(queueItem)
+ assertEquals(1, testRequestThread.queueSize)
+
// initialize to the controller
testRequestThread.doWork()
// send and process the request
testRequestThread.doWork()
- assertTrue(responseLatch.await(10, TimeUnit.SECONDS))
+ assertEquals(0, testRequestThread.queueSize)
+ assertTrue(completionHandler.completed.get())
}
@Test
def testControllerChanged(): Unit = {
// in this test the current broker is 1, and the controller changes from 2 -> 3 then back: 3 -> 2
- val time = new SystemTime
+ val time = new MockTime()
val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181"))
val oldControllerId = 1
val newControllerId = 2
@@ -86,7 +123,6 @@ class BrokerToControllerRequestThreadTest {
val metadata = mock(classOf[Metadata])
val mockClient = new MockClient(time, metadata)
- val requestQueue = new LinkedBlockingDeque[BrokerToControllerQueueItem]()
val metadataCache = mock(classOf[MetadataCache])
val listenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)
val oldController = new Broker(oldControllerId,
@@ -102,39 +138,36 @@ class BrokerToControllerRequestThreadTest {
val expectedResponse = RequestTestUtils.metadataUpdateWith(3, Collections.singletonMap("a", 2))
val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(),
- requestQueue, metadataCache, config, listenerName, time, "")
-
- val responseLatch = new CountDownLatch(1)
+ metadataCache, config, listenerName, time, "", retryTimeoutMs = Long.MaxValue)
+ val completionHandler = new TestRequestCompletionHandler(Some(expectedResponse))
val queueItem = BrokerToControllerQueueItem(
+ time.milliseconds(),
new MetadataRequest.Builder(new MetadataRequestData()),
- new TestRequestCompletionHandler(expectedResponse, responseLatch),
- Long.MaxValue)
- requestQueue.put(queueItem)
+ completionHandler,
+ )
+
+ testRequestThread.enqueue(queueItem)
mockClient.prepareResponse(expectedResponse)
// initialize the thread with oldController
testRequestThread.doWork()
- // assert queue correctness
- assertFalse(requestQueue.isEmpty)
- assertEquals(1, requestQueue.size())
- assertEquals(queueItem, requestQueue.peek())
+ assertFalse(completionHandler.completed.get())
+
// disconnect the node
mockClient.setUnreachable(oldControllerNode, time.milliseconds() + 5000)
// verify that the client closed the connection to the faulty controller
testRequestThread.doWork()
- assertFalse(requestQueue.isEmpty)
- assertEquals(1, requestQueue.size())
// should connect to the new controller
testRequestThread.doWork()
// should send the request and process the response
testRequestThread.doWork()
- assertTrue(responseLatch.await(10, TimeUnit.SECONDS))
+ assertTrue(completionHandler.completed.get())
}
@Test
def testNotController(): Unit = {
- val time = new SystemTime
+ val time = new MockTime()
val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181"))
val oldControllerId = 1
val newControllerId = 2
@@ -142,7 +175,6 @@ class BrokerToControllerRequestThreadTest {
val metadata = mock(classOf[Metadata])
val mockClient = new MockClient(time, metadata)
- val requestQueue = new LinkedBlockingDeque[BrokerToControllerQueueItem]()
val metadataCache = mock(classOf[MetadataCache])
val listenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)
val oldController = new Broker(oldControllerId,
@@ -159,16 +191,17 @@ class BrokerToControllerRequestThreadTest {
Collections.singletonMap("a", Errors.NOT_CONTROLLER),
Collections.singletonMap("a", 2))
val expectedResponse = RequestTestUtils.metadataUpdateWith(3, Collections.singletonMap("a", 2))
- val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), requestQueue, metadataCache,
- config, listenerName, time, "")
+ val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), metadataCache,
+ config, listenerName, time, "", retryTimeoutMs = Long.MaxValue)
- val responseLatch = new CountDownLatch(1)
+ val completionHandler = new TestRequestCompletionHandler(Some(expectedResponse))
val queueItem = BrokerToControllerQueueItem(
+ time.milliseconds(),
new MetadataRequest.Builder(new MetadataRequestData()
.setAllowAutoTopicCreation(true)),
- new TestRequestCompletionHandler(expectedResponse, responseLatch),
- Long.MaxValue)
- requestQueue.put(queueItem)
+ completionHandler
+ )
+ testRequestThread.enqueue(queueItem)
// initialize to the controller
testRequestThread.doWork()
// send and process the request
@@ -183,11 +216,11 @@ class BrokerToControllerRequestThreadTest {
mockClient.prepareResponse(expectedResponse)
testRequestThread.doWork()
- assertTrue(responseLatch.await(10, TimeUnit.SECONDS))
+ assertTrue(completionHandler.completed.get())
}
@Test
- def testRequestTimeout(): Unit = {
+ def testRetryTimeout(): Unit = {
val time = new MockTime()
val config = new KafkaConfig(TestUtils.createBrokerConfig(1, "localhost:2181"))
val controllerId = 1
@@ -195,7 +228,6 @@ class BrokerToControllerRequestThreadTest {
val metadata = mock(classOf[Metadata])
val mockClient = new MockClient(time, metadata)
- val requestQueue = new LinkedBlockingDeque[BrokerToControllerQueueItem]()
val metadataCache = mock(classOf[MetadataCache])
val listenerName = ListenerName.forSecurityProtocol(SecurityProtocol.PLAINTEXT)
val controller = new Broker(controllerId,
@@ -205,53 +237,54 @@ class BrokerToControllerRequestThreadTest {
when(metadataCache.getAliveBrokers).thenReturn(Seq(controller))
when(metadataCache.getAliveBroker(controllerId)).thenReturn(Some(controller))
+ val retryTimeoutMs = 30000
val responseWithNotControllerError = RequestTestUtils.metadataUpdateWith("cluster1", 2,
Collections.singletonMap("a", Errors.NOT_CONTROLLER),
Collections.singletonMap("a", 2))
- val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), requestQueue, metadataCache,
- config, listenerName, time, "")
+ val testRequestThread = new BrokerToControllerRequestThread(mockClient, new ManualMetadataUpdater(), metadataCache,
+ config, listenerName, time, "", retryTimeoutMs)
- val responseLatch = new CountDownLatch(1)
- val requestTimeout = config.requestTimeoutMs.longValue()
+ val completionHandler = new TestRequestCompletionHandler()
val queueItem = BrokerToControllerQueueItem(
+ time.milliseconds(),
new MetadataRequest.Builder(new MetadataRequestData()
- .setAllowAutoTopicCreation(true)), new ControllerRequestCompletionHandler {
- override def onComplete(response: ClientResponse): Unit = {}
+ .setAllowAutoTopicCreation(true)),
+ completionHandler
+ )
- override def onTimeout(): Unit = {
- responseLatch.countDown()
- }
- }, requestTimeout + time.milliseconds())
- requestQueue.put(queueItem)
+ testRequestThread.enqueue(queueItem)
// initialize to the controller
testRequestThread.doWork()
+
+ time.sleep(retryTimeoutMs)
+
// send and process the request
mockClient.prepareResponse((body: AbstractRequest) => {
- // Advance time to timeout the response
- time.sleep(requestTimeout + 1)
-
body.isInstanceOf[MetadataRequest] &&
body.asInstanceOf[MetadataRequest].allowAutoTopicCreation()
}, responseWithNotControllerError)
testRequestThread.doWork()
- // The queued item should be timed out, instead of
- // re-enqueue by NOT_CONTROLLER error.
- assertEquals(0, requestQueue.size())
-
- assertTrue(responseLatch.await(10, TimeUnit.SECONDS))
+ assertTrue(completionHandler.timedOut.get())
}
- class TestRequestCompletionHandler(expectedResponse: MetadataResponse,
- responseLatch: CountDownLatch) extends ControllerRequestCompletionHandler {
+ class TestRequestCompletionHandler(
+ expectedResponse: Option[MetadataResponse] = None
+ ) extends ControllerRequestCompletionHandler {
+ val completed: AtomicBoolean = new AtomicBoolean(false)
+ val timedOut: AtomicBoolean = new AtomicBoolean(false)
+
override def onComplete(response: ClientResponse): Unit = {
- assertEquals(expectedResponse, response.responseBody())
- responseLatch.countDown()
+ expectedResponse.foreach { expected =>
+ assertEquals(expected, response.responseBody())
+ }
+ completed.set(true)
}
override def onTimeout(): Unit = {
+ timedOut.set(true)
}
}
}
diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
index 441b4e07ee100..1746147457fc2 100644
--- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala
@@ -132,7 +132,7 @@ class TransactionMarkerChannelManagerTest {
response)
TestUtils.waitUntilTrue(() => {
- val requests = channelManager.drainQueuedTransactionMarkers()
+ val requests = channelManager.generateRequests()
if (requests.nonEmpty) {
assertEquals(1, requests.size)
val request = requests.head
diff --git a/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala b/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala
index 699e8faa711bb..bfa1675ff2e56 100644
--- a/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala
+++ b/core/src/test/scala/unit/kafka/raft/KafkaNetworkChannelTest.scala
@@ -19,16 +19,15 @@ package kafka.raft
import java.net.InetSocketAddress
import java.util
import java.util.Collections
-import java.util.concurrent.atomic.AtomicReference
import org.apache.kafka.clients.MockClient.MockMetadataUpdater
import org.apache.kafka.clients.{ApiVersion, MockClient, NodeApiVersions}
import org.apache.kafka.common.message.{BeginQuorumEpochResponseData, EndQuorumEpochResponseData, FetchResponseData, VoteResponseData}
import org.apache.kafka.common.protocol.{ApiKeys, ApiMessage, Errors}
-import org.apache.kafka.common.requests.{AbstractResponse, BeginQuorumEpochRequest, EndQuorumEpochRequest, VoteRequest, VoteResponse}
+import org.apache.kafka.common.requests.{AbstractResponse, BeginQuorumEpochRequest, BeginQuorumEpochResponse, EndQuorumEpochRequest, EndQuorumEpochResponse, FetchResponse, VoteRequest, VoteResponse}
import org.apache.kafka.common.utils.{MockTime, Time}
import org.apache.kafka.common.{Node, TopicPartition}
-import org.apache.kafka.raft.{RaftRequest, RaftResponse, RaftUtil}
+import org.apache.kafka.raft.{RaftRequest, RaftUtil}
import org.junit.Assert._
import org.junit.{Before, Test}
@@ -38,13 +37,11 @@ class KafkaNetworkChannelTest {
import KafkaNetworkChannelTest._
private val clusterId = "clusterId"
- private val clientId = "clientId"
- private val retryBackoffMs = 100
private val requestTimeoutMs = 30000
private val time = new MockTime()
private val client = new MockClient(time, new StubMetadataUpdater)
private val topicPartition = new TopicPartition("topic", 0)
- private val channel = new KafkaNetworkChannel(time, client, clientId, retryBackoffMs, requestTimeoutMs)
+ private val channel = new KafkaNetworkChannel(time, client, requestTimeoutMs)
@Before
def setupSupportedApis(): Unit = {
@@ -74,7 +71,7 @@ class KafkaNetworkChannelTest {
channel.updateEndpoint(destinationId, new InetSocketAddress(destinationNode.host, destinationNode.port))
for (apiKey <- RaftApis) {
- val response = KafkaNetworkChannel.buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST))
+ val response = buildResponse(buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST))
client.prepareResponseFrom(response, destinationNode, true)
sendAndAssertErrorResponse(apiKey, destinationId, Errors.BROKER_NOT_AVAILABLE)
}
@@ -109,33 +106,12 @@ class KafkaNetworkChannelTest {
for (apiKey <- RaftApis) {
val expectedError = Errors.INVALID_REQUEST
- val response = KafkaNetworkChannel.buildResponse(buildTestErrorResponse(apiKey, expectedError))
+ val response = buildResponse(buildTestErrorResponse(apiKey, expectedError))
client.prepareResponseFrom(response, destinationNode)
sendAndAssertErrorResponse(apiKey, destinationId, expectedError)
}
}
- @Test
- def testReceiveAndSendInboundRequest(): Unit = {
- for (apiKey <- RaftApis) {
- val request = KafkaNetworkChannel.buildRequest(buildTestRequest(apiKey)).build()
- val responseRef = new AtomicReference[AbstractResponse]()
-
- channel.postInboundRequest(request, responseRef.set)
- val inbound = channel.receive(1000).asScala
- assertEquals(1, inbound.size)
-
- val inboundRequest = inbound.head.asInstanceOf[RaftRequest.Inbound]
- val errorResponse = buildTestErrorResponse(apiKey, Errors.INVALID_REQUEST)
- val outboundResponse = new RaftResponse.Outbound(inboundRequest.correlationId, errorResponse)
- channel.send(outboundResponse)
- channel.receive(1000)
-
- assertNotNull(responseRef.get)
- assertEquals(Errors.INVALID_REQUEST, extractError(KafkaNetworkChannel.responseData(responseRef.get)))
- }
- }
-
private def sendAndAssertErrorResponse(apiKey: ApiKeys,
destinationId: Int,
error: Errors): Unit = {
@@ -145,10 +121,11 @@ class KafkaNetworkChannelTest {
val request = new RaftRequest.Outbound(correlationId, apiRequest, destinationId, createdTimeMs)
channel.send(request)
- val responses = channel.receive(1000).asScala
- assertEquals(1, responses.size)
+ channel.pollOnce()
+
+ assertTrue(request.completion.isDone)
- val response = responses.head.asInstanceOf[RaftResponse.Inbound]
+ val response = request.completion.get()
assertEquals(destinationId, response.sourceId)
assertEquals(correlationId, response.correlationId)
assertEquals(apiKey, ApiKeys.forId(response.data.apiKey))
@@ -216,6 +193,22 @@ class KafkaNetworkChannelTest {
Errors.forCode(code)
}
+
+ def buildResponse(responseData: ApiMessage): AbstractResponse = {
+ responseData match {
+ case voteResponse: VoteResponseData =>
+ new VoteResponse(voteResponse)
+ case beginEpochResponse: BeginQuorumEpochResponseData =>
+ new BeginQuorumEpochResponse(beginEpochResponse)
+ case endEpochResponse: EndQuorumEpochResponseData =>
+ new EndQuorumEpochResponse(endEpochResponse)
+ case fetchResponse: FetchResponseData =>
+ new FetchResponse(fetchResponse)
+ case _ =>
+ throw new IllegalArgumentException(s"Unexpected type for responseData: $responseData")
+ }
+ }
+
}
object KafkaNetworkChannelTest {
diff --git a/core/src/test/scala/unit/kafka/server/AlterIsrManagerTest.scala b/core/src/test/scala/unit/kafka/server/AlterIsrManagerTest.scala
index a2e269b3217d9..8c5657e4152c6 100644
--- a/core/src/test/scala/unit/kafka/server/AlterIsrManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/AlterIsrManagerTest.scala
@@ -15,12 +15,11 @@
* limitations under the License.
*/
-package unit.kafka.server
+package kafka.server
import java.util.Collections
import java.util.concurrent.atomic.AtomicInteger
import kafka.api.LeaderAndIsr
-import kafka.server.{AlterIsrItem, AlterIsrManager, BrokerToControllerChannelManager, ControllerRequestCompletionHandler, DefaultAlterIsrManager, ZkIsrManager}
import kafka.utils.{MockScheduler, MockTime}
import kafka.zk.KafkaZkClient
import org.apache.kafka.clients.ClientResponse
@@ -42,7 +41,6 @@ class AlterIsrManagerTest {
val time = new MockTime
val metrics = new Metrics
val brokerId = 1
- val requestTimeout = Long.MaxValue
var brokerToController: BrokerToControllerChannelManager = _
@@ -57,7 +55,8 @@ class AlterIsrManagerTest {
@Test
def testBasic(): Unit = {
- EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.anyObject(), EasyMock.eq(requestTimeout))).once()
+ EasyMock.expect(brokerToController.start())
+ EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.anyObject())).once()
EasyMock.replay(brokerToController)
val scheduler = new MockScheduler(time)
@@ -73,7 +72,8 @@ class AlterIsrManagerTest {
@Test
def testOverwriteWithinBatch(): Unit = {
val capture = EasyMock.newCapture[AbstractRequest.Builder[AlterIsrRequest]]()
- EasyMock.expect(brokerToController.sendRequest(EasyMock.capture(capture), EasyMock.anyObject(), EasyMock.eq(requestTimeout))).once()
+ EasyMock.expect(brokerToController.start())
+ EasyMock.expect(brokerToController.sendRequest(EasyMock.capture(capture), EasyMock.anyObject())).once()
EasyMock.replay(brokerToController)
val scheduler = new MockScheduler(time)
@@ -97,7 +97,8 @@ class AlterIsrManagerTest {
@Test
def testSingleBatch(): Unit = {
val capture = EasyMock.newCapture[AbstractRequest.Builder[AlterIsrRequest]]()
- EasyMock.expect(brokerToController.sendRequest(EasyMock.capture(capture), EasyMock.anyObject(), EasyMock.eq(requestTimeout))).once()
+ EasyMock.expect(brokerToController.start())
+ EasyMock.expect(brokerToController.sendRequest(EasyMock.capture(capture), EasyMock.anyObject())).once()
EasyMock.replay(brokerToController)
val scheduler = new MockScheduler(time)
@@ -151,7 +152,8 @@ class AlterIsrManagerTest {
def testTopLevelError(isrs: Seq[AlterIsrItem], error: Errors): AlterIsrManager = {
val callbackCapture = EasyMock.newCapture[ControllerRequestCompletionHandler]()
- EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture), EasyMock.eq(requestTimeout))).once()
+ EasyMock.expect(brokerToController.start())
+ EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture))).once()
EasyMock.replay(brokerToController)
val scheduler = new MockScheduler(time)
@@ -184,7 +186,8 @@ class AlterIsrManagerTest {
def testPartitionError(tp: TopicPartition, error: Errors): AlterIsrManager = {
val callbackCapture = EasyMock.newCapture[ControllerRequestCompletionHandler]()
EasyMock.reset(brokerToController)
- EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture), EasyMock.eq(requestTimeout))).once()
+ EasyMock.expect(brokerToController.start())
+ EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture))).once()
EasyMock.replay(brokerToController)
val scheduler = new MockScheduler(time)
@@ -226,7 +229,8 @@ class AlterIsrManagerTest {
def testOneInFlight(): Unit = {
val callbackCapture = EasyMock.newCapture[ControllerRequestCompletionHandler]()
EasyMock.reset(brokerToController)
- EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture), EasyMock.eq(requestTimeout))).once()
+ EasyMock.expect(brokerToController.start())
+ EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture))).once()
EasyMock.replay(brokerToController)
val scheduler = new MockScheduler(time)
@@ -253,7 +257,7 @@ class AlterIsrManagerTest {
callbackCapture.getValue.onComplete(resp)
EasyMock.reset(brokerToController)
- EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture), EasyMock.eq(requestTimeout))).once()
+ EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture))).once()
EasyMock.replay(brokerToController)
time.sleep(100)
@@ -265,7 +269,8 @@ class AlterIsrManagerTest {
def testPartitionMissingInResponse(): Unit = {
val callbackCapture = EasyMock.newCapture[ControllerRequestCompletionHandler]()
EasyMock.reset(brokerToController)
- EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture), EasyMock.eq(requestTimeout))).once()
+ EasyMock.expect(brokerToController.start())
+ EasyMock.expect(brokerToController.sendRequest(EasyMock.anyObject(), EasyMock.capture(callbackCapture))).once()
EasyMock.replay(brokerToController)
val scheduler = new MockScheduler(time)
diff --git a/core/src/test/scala/unit/kafka/server/ForwardingManagerTest.scala b/core/src/test/scala/unit/kafka/server/ForwardingManagerTest.scala
index 5001f671c02f8..ca03359f90758 100644
--- a/core/src/test/scala/unit/kafka/server/ForwardingManagerTest.scala
+++ b/core/src/test/scala/unit/kafka/server/ForwardingManagerTest.scala
@@ -19,6 +19,7 @@ package kafka.server
import java.net.InetAddress
import java.nio.ByteBuffer
import java.util.Optional
+
import kafka.network
import kafka.network.RequestChannel
import kafka.utils.MockTime
@@ -34,7 +35,7 @@ import org.apache.kafka.common.security.authenticator.DefaultKafkaPrincipalBuild
import org.junit.Assert._
import org.junit.Test
import org.mockito.ArgumentMatchers._
-import org.mockito.{ArgumentMatchers, Mockito}
+import org.mockito.Mockito
import scala.jdk.CollectionConverters._
@@ -45,7 +46,7 @@ class ForwardingManagerTest {
@Test
def testResponseCorrelationIdMismatch(): Unit = {
- val forwardingManager = new ForwardingManager(brokerToController, time, Long.MaxValue)
+ val forwardingManager = new ForwardingManagerImpl(brokerToController)
val requestCorrelationId = 27
val envelopeCorrelationId = 39
val clientPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, "client")
@@ -64,8 +65,7 @@ class ForwardingManagerTest {
Mockito.when(brokerToController.sendRequest(
any(classOf[EnvelopeRequest.Builder]),
- any(classOf[ControllerRequestCompletionHandler]),
- ArgumentMatchers.eq(Long.MaxValue)
+ any(classOf[ControllerRequestCompletionHandler])
)).thenAnswer(invocation => {
val completionHandler = invocation.getArgument[RequestCompletionHandler](1)
val response = buildEnvelopeResponse(responseBuffer, envelopeCorrelationId, completionHandler)
diff --git a/core/src/test/scala/unit/kafka/utils/TestUtils.scala b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
index 462e802707c25..e58e65b2a30b1 100755
--- a/core/src/test/scala/unit/kafka/utils/TestUtils.scala
+++ b/core/src/test/scala/unit/kafka/utils/TestUtils.scala
@@ -1082,8 +1082,6 @@ object TestUtils extends Logging {
inFlight.set(false);
}
- override def start(): Unit = { }
-
def completeIsrUpdate(newZkVersion: Int): Unit = {
if (inFlight.compareAndSet(true, false)) {
val item = isrUpdates.head
diff --git a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
index 17bd9584b15f9..37eee73f2ce83 100644
--- a/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
+++ b/raft/src/main/java/org/apache/kafka/raft/KafkaRaftClient.java
@@ -23,14 +23,14 @@
import org.apache.kafka.common.message.BeginQuorumEpochRequestData;
import org.apache.kafka.common.message.BeginQuorumEpochResponseData;
import org.apache.kafka.common.message.DescribeQuorumRequestData;
-import org.apache.kafka.common.message.DescribeQuorumResponseData.ReplicaState;
import org.apache.kafka.common.message.DescribeQuorumResponseData;
+import org.apache.kafka.common.message.DescribeQuorumResponseData.ReplicaState;
import org.apache.kafka.common.message.EndQuorumEpochRequestData;
import org.apache.kafka.common.message.EndQuorumEpochResponseData;
import org.apache.kafka.common.message.FetchRequestData;
import org.apache.kafka.common.message.FetchResponseData;
-import org.apache.kafka.common.message.LeaderChangeMessage.Voter;
import org.apache.kafka.common.message.LeaderChangeMessage;
+import org.apache.kafka.common.message.LeaderChangeMessage.Voter;
import org.apache.kafka.common.message.VoteRequestData;
import org.apache.kafka.common.message.VoteResponseData;
import org.apache.kafka.common.metrics.Metrics;
@@ -55,6 +55,7 @@
import org.apache.kafka.raft.RequestManager.ConnectionState;
import org.apache.kafka.raft.internals.BatchAccumulator;
import org.apache.kafka.raft.internals.BatchMemoryPool;
+import org.apache.kafka.raft.internals.BlockingMessageQueue;
import org.apache.kafka.raft.internals.CloseListener;
import org.apache.kafka.raft.internals.FuturePurgatory;
import org.apache.kafka.raft.internals.KafkaRaftMetrics;
@@ -71,6 +72,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.OptionalLong;
@@ -141,6 +143,8 @@ public class KafkaRaftClient implements RaftClient {
private final FuturePurgatory fetchPurgatory;
private final RecordSerde serde;
private final MemoryPool memoryPool;
+ private final RaftMessageQueue messageQueue;
+
private final List listenerContexts = new ArrayList<>();
private final ConcurrentLinkedQueue> pendingListeners = new ConcurrentLinkedQueue<>();
@@ -158,6 +162,7 @@ public KafkaRaftClient(
) {
this(serde,
channel,
+ new BlockingMessageQueue(),
log,
quorum,
new BatchMemoryPool(5, MAX_BATCH_SIZE),
@@ -177,6 +182,7 @@ public KafkaRaftClient(
public KafkaRaftClient(
RecordSerde serde,
NetworkChannel channel,
+ RaftMessageQueue messageQueue,
ReplicatedLog log,
QuorumState quorum,
MemoryPool memoryPool,
@@ -194,6 +200,7 @@ public KafkaRaftClient(
) {
this.serde = serde;
this.channel = channel;
+ this.messageQueue = messageQueue;
this.log = log;
this.quorum = quorum;
this.memoryPool = memoryPool;
@@ -333,7 +340,7 @@ public void initialize() throws IOException {
@Override
public void register(Listener listener) {
pendingListeners.add(listener);
- channel.wakeup();
+ wakeup();
}
private OffsetAndEpoch endOffset() {
@@ -779,9 +786,6 @@ private EndQuorumEpochResponseData handleEndQuorumEpochRequest(
FollowerState state = quorum.followerStateOrThrow();
if (state.leaderId() == requestLeaderId) {
List preferredSuccessors = partitionRequest.preferredSuccessors();
- if (!preferredSuccessors.contains(quorum.localId)) {
- return buildEndQuorumEpochResponse(Errors.INCONSISTENT_VOTER_SET);
- }
long electionBackoffMs = endEpochElectionBackoff(preferredSuccessors);
logger.debug("Overriding follower fetch timeout to {} after receiving " +
"EndQuorumEpoch request from leader {} in epoch {}", electionBackoffMs,
@@ -798,7 +802,7 @@ private long endEpochElectionBackoff(List preferredSuccessors) {
// voter has a higher chance to be elected. If the node's priority is highest, become
// candidate immediately instead of waiting for next poll.
int position = preferredSuccessors.indexOf(quorum.localId);
- if (position == 0) {
+ if (position <= 0) {
return 0;
} else {
return strictExponentialElectionBackoffMs(position, preferredSuccessors.size());
@@ -932,6 +936,7 @@ private CompletableFuture handleFetchRequest(
}
}
+ // FIXME: `completionTimeMs`, which can be null
logger.trace("Completing delayed fetch from {} starting at offset {} at {}",
request.replicaId(), fetchPartition.fetchOffset(), completionTimeMs);
@@ -967,6 +972,7 @@ private FetchResponseData tryCompleteFetchRequest(
return buildFetchResponse(Errors.NONE, MemoryRecords.EMPTY, divergingEpoch, state.highWatermark());
} else {
LogFetchInfo info = log.read(fetchOffset, Isolation.UNCOMMITTED);
+
if (state.updateReplicaState(replicaId, currentTimeMs, info.startOffsetMetadata)) {
onUpdateLeaderHighWatermark(state, currentTimeMs);
}
@@ -1231,7 +1237,7 @@ private boolean handleTopLevelError(Errors error, RaftResponse.Inbound response)
private boolean handleUnexpectedError(Errors error, RaftResponse.Inbound response) {
logger.error("Unexpected error {} in {} response: {}",
- error, response.data.apiKey(), response);
+ error, ApiKeys.forId(response.data.apiKey()), response);
return false;
}
@@ -1263,7 +1269,7 @@ private void handleResponse(RaftResponse.Inbound response, long currentTimeMs) t
ConnectionState connection = requestManager.getOrCreate(response.sourceId());
if (handledSuccessfully) {
- connection.onResponseReceived(response.correlationId, currentTimeMs);
+ connection.onResponseReceived(response.correlationId);
} else {
connection.onResponseError(response.correlationId, currentTimeMs);
}
@@ -1343,7 +1349,10 @@ private void handleRequest(RaftRequest.Inbound request, long currentTimeMs) thro
} else {
message = RaftUtil.errorResponse(apiKey, Errors.forException(exception));
}
- sendOutboundMessage(new RaftResponse.Outbound(request.correlationId(), message));
+
+ RaftResponse.Outbound responseMessage = new RaftResponse.Outbound(request.correlationId(), message);
+ request.completion.complete(responseMessage);
+ logger.trace("Sent response {} to inbound request {}", responseMessage, request);
});
}
@@ -1355,17 +1364,17 @@ private void handleInboundMessage(RaftMessage message, long currentTimeMs) throw
handleRequest(request, currentTimeMs);
} else if (message instanceof RaftResponse.Inbound) {
RaftResponse.Inbound response = (RaftResponse.Inbound) message;
- handleResponse(response, currentTimeMs);
+ ConnectionState connection = requestManager.getOrCreate(response.sourceId());
+ if (connection.isResponseExpected(response.correlationId)) {
+ handleResponse(response, currentTimeMs);
+ } else {
+ logger.debug("Ignoring response {} since it is no longer needed", response);
+ }
} else {
throw new IllegalArgumentException("Unexpected message " + message);
}
}
- private void sendOutboundMessage(RaftMessage message) {
- channel.send(message);
- logger.trace("Sent outbound message: {}", message);
- }
-
/**
* Attempt to send a request. Return the time to wait before the request can be retried.
*/
@@ -1383,7 +1392,32 @@ private long maybeSendRequest(
if (connection.isReady(currentTimeMs)) {
int correlationId = channel.newCorrelationId();
ApiMessage request = requestSupplier.get();
- sendOutboundMessage(new RaftRequest.Outbound(correlationId, request, destinationId, currentTimeMs));
+
+ RaftRequest.Outbound requestMessage = new RaftRequest.Outbound(
+ correlationId,
+ request,
+ destinationId,
+ currentTimeMs
+ );
+
+ requestMessage.completion.whenComplete((response, exception) -> {
+ if (exception != null) {
+ ApiKeys api = ApiKeys.forId(request.apiKey());
+ Errors error = Errors.forException(exception);
+ ApiMessage errorResponse = RaftUtil.errorResponse(api, error);
+
+ response = new RaftResponse.Inbound(
+ correlationId,
+ errorResponse,
+ destinationId
+ );
+ }
+
+ messageQueue.add(response);
+ });
+
+ channel.send(requestMessage);
+ logger.trace("Sent outbound request: {}", requestMessage);
connection.onRequestSent(correlationId, currentTimeMs);
return Long.MAX_VALUE;
}
@@ -1781,6 +1815,26 @@ private boolean maybeCompleteShutdown(long currentTimeMs) {
return false;
}
+ private void wakeup() {
+ messageQueue.wakeup();
+ }
+
+ /**
+ * Handle an inbound request. The response will be returned through
+ * {@link RaftRequest.Inbound#completion}.
+ *
+ * @param request The inbound request
+ */
+ public void handle(RaftRequest.Inbound request) {
+ messageQueue.add(Objects.requireNonNull(request));
+ }
+
+ /**
+ * Poll for new events. This allows the client to handle inbound
+ * requests and send any needed outbound requests.
+ *
+ * @throws IOException for any IO errors encountered
+ */
public void poll() throws IOException {
pollListeners();
@@ -1792,14 +1846,13 @@ public void poll() throws IOException {
long pollTimeoutMs = pollCurrentState(currentTimeMs);
kafkaRaftMetrics.updatePollStart(currentTimeMs);
- List inboundMessages = channel.receive(pollTimeoutMs);
+ RaftMessage message = messageQueue.poll(pollTimeoutMs);
currentTimeMs = time.milliseconds();
kafkaRaftMetrics.updatePollEnd(currentTimeMs);
- for (RaftMessage message : inboundMessages) {
+ if (message != null) {
handleInboundMessage(message, currentTimeMs);
- currentTimeMs = time.milliseconds();
}
}
@@ -1819,7 +1872,7 @@ public Long scheduleAppend(int epoch, List records) {
// the linger timeout so that it can schedule its own wakeup in case
// there are no additional appends.
if (isFirstAppend || accumulator.needsDrain(time.milliseconds())) {
- channel.wakeup();
+ wakeup();
}
return offset;
}
@@ -1829,7 +1882,7 @@ public CompletableFuture shutdown(int timeoutMs) {
logger.info("Beginning graceful shutdown");
CompletableFuture shutdownComplete = new CompletableFuture<>();
shutdown.set(new GracefulShutdown(timeoutMs, shutdownComplete));
- channel.wakeup();
+ wakeup();
return shutdownComplete;
}
@@ -1991,7 +2044,7 @@ public synchronized void onClose(BatchReader reader) {
if (lastSent == reader) {
lastSent = null;
- channel.wakeup();
+ wakeup();
}
}
diff --git a/raft/src/main/java/org/apache/kafka/raft/LeaderState.java b/raft/src/main/java/org/apache/kafka/raft/LeaderState.java
index 1d810a9c22aad..c44f0f030cf23 100644
--- a/raft/src/main/java/org/apache/kafka/raft/LeaderState.java
+++ b/raft/src/main/java/org/apache/kafka/raft/LeaderState.java
@@ -159,14 +159,15 @@ public boolean updateReplicaState(int replicaId,
public List nonLeaderVotersByDescendingFetchOffset() {
return followersByDescendingFetchOffset().stream()
- .filter(state -> state.nodeId != localId)
- .map(state -> state.nodeId)
- .collect(Collectors.toList());
+ .filter(state -> state.nodeId != localId)
+ .map(state -> state.nodeId)
+ .collect(Collectors.toList());
}
private List followersByDescendingFetchOffset() {
- return new ArrayList<>(this.voterReplicaStates.values())
- .stream().sorted().collect(Collectors.toList());
+ return new ArrayList<>(this.voterReplicaStates.values()).stream()
+ .sorted()
+ .collect(Collectors.toList());
}
private boolean updateEndOffset(ReplicaState state,
diff --git a/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java b/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java
index 421dae2b46fbb..c097f7cc40f79 100644
--- a/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java
+++ b/raft/src/main/java/org/apache/kafka/raft/NetworkChannel.java
@@ -18,11 +18,10 @@
import java.io.Closeable;
import java.net.InetSocketAddress;
-import java.util.List;
/**
* A simple network interface with few assumptions. We do not assume ordering
- * of requests or even that every request will receive a response.
+ * of requests or even that every outbound request will receive a response.
*/
public interface NetworkChannel extends Closeable {
@@ -37,21 +36,7 @@ public interface NetworkChannel extends Closeable {
* or a response to a request that was received through {@link #receive(long)}
* (i.e. an instance of {@link org.apache.kafka.raft.RaftResponse.Outbound}).
*/
- void send(RaftMessage message);
-
- /**
- * Receive inbound messages. These could contain either inbound requests
- * (i.e. instances of {@link org.apache.kafka.raft.RaftRequest.Inbound})
- * or responses to outbound requests sent through {@link #send(RaftMessage)}
- * (i.e. instances of {@link org.apache.kafka.raft.RaftResponse.Inbound}).
- */
- List receive(long timeoutMs);
-
- /**
- * Wakeup the channel if it is blocking in {@link #receive(long)}. This will cause
- * the call to immediately return with whatever messages are available.
- */
- void wakeup();
+ void send(RaftRequest.Outbound request);
/**
* Update connection information for the given id.
diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java b/raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java
new file mode 100644
index 0000000000000..7d1e4b7598a52
--- /dev/null
+++ b/raft/src/main/java/org/apache/kafka/raft/RaftMessageQueue.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft;
+
+/**
+ * This class is used to serialize inbound requests or responses to outbound requests.
+ * It basically just allows us to wrap a blocking queue so that we can have a mocked
+ * implementation which does not depend on system time.
+ *
+ * See {@link org.apache.kafka.raft.internals.BlockingMessageQueue}.
+ */
+public interface RaftMessageQueue {
+
+ /**
+ * Block for the arrival of a new message.
+ *
+ * @param timeoutMs timeout in milliseconds to wait for a new event
+ * @return the event or null if either the timeout was reached or there was
+ * a call to {@link #wakeup()} before any events became available
+ */
+ RaftMessage poll(long timeoutMs);
+
+ /**
+ * Add a new message to the queue.
+ *
+ * @param message the message to deliver
+ * @throws IllegalStateException if the queue cannot accept the message
+ */
+ void add(RaftMessage message);
+
+ /**
+ * Check whether there are pending messages awaiting delivery.
+ *
+ * @return true if there are no pending messages awaiting delivery, false otherwise
+ */
+ boolean isEmpty();
+
+ /**
+ * Wakeup the thread blocking in {@link #poll(long)}. This will cause
+ * {@link #poll(long)} to return null if no messages are available.
+ */
+ void wakeup();
+
+}
diff --git a/raft/src/main/java/org/apache/kafka/raft/RaftRequest.java b/raft/src/main/java/org/apache/kafka/raft/RaftRequest.java
index f607027dced5b..28e63c14ce69f 100644
--- a/raft/src/main/java/org/apache/kafka/raft/RaftRequest.java
+++ b/raft/src/main/java/org/apache/kafka/raft/RaftRequest.java
@@ -18,6 +18,8 @@
import org.apache.kafka.common.protocol.ApiMessage;
+import java.util.concurrent.CompletableFuture;
+
public abstract class RaftRequest implements RaftMessage {
protected final int correlationId;
protected final ApiMessage data;
@@ -44,6 +46,8 @@ public long createdTimeMs() {
}
public static class Inbound extends RaftRequest {
+ public final CompletableFuture<RaftResponse.Outbound> completion = new CompletableFuture<>();
+
public Inbound(int correlationId, ApiMessage data, long createdTimeMs) {
super(correlationId, data, createdTimeMs);
}
@@ -60,6 +64,7 @@ public String toString() {
public static class Outbound extends RaftRequest {
private final int destinationId;
+ public final CompletableFuture<RaftResponse.Inbound> completion = new CompletableFuture<>();
public Outbound(int correlationId, ApiMessage data, int destinationId, long createdTimeMs) {
super(correlationId, data, createdTimeMs);
diff --git a/raft/src/main/java/org/apache/kafka/raft/RequestManager.java b/raft/src/main/java/org/apache/kafka/raft/RequestManager.java
index aec05c7835710..5a5cb003c25af 100644
--- a/raft/src/main/java/org/apache/kafka/raft/RequestManager.java
+++ b/raft/src/main/java/org/apache/kafka/raft/RequestManager.java
@@ -20,8 +20,8 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.Optional;
import java.util.OptionalInt;
+import java.util.OptionalLong;
import java.util.Random;
import java.util.Set;
@@ -103,7 +103,7 @@ public class ConnectionState {
private State state = State.READY;
private long lastSendTimeMs = 0L;
private long lastFailTimeMs = 0L;
- private Optional<Long> inFlightCorrelationId = Optional.empty();
+ private OptionalLong inFlightCorrelationId = OptionalLong.empty();
public ConnectionState(long id) {
this.id = id;
@@ -160,28 +160,32 @@ long remainingBackoffMs(long timeMs) {
}
}
+ boolean isResponseExpected(long correlationId) {
+ return inFlightCorrelationId.isPresent() && inFlightCorrelationId.getAsLong() == correlationId;
+ }
+
void onResponseError(long correlationId, long timeMs) {
inFlightCorrelationId.ifPresent(inflightRequestId -> {
if (inflightRequestId == correlationId) {
lastFailTimeMs = timeMs;
state = State.BACKING_OFF;
- inFlightCorrelationId = Optional.empty();
+ inFlightCorrelationId = OptionalLong.empty();
}
});
}
- void onResponseReceived(long correlationId, long timeMs) {
+ void onResponseReceived(long correlationId) {
inFlightCorrelationId.ifPresent(inflightRequestId -> {
if (inflightRequestId == correlationId) {
state = State.READY;
- inFlightCorrelationId = Optional.empty();
+ inFlightCorrelationId = OptionalLong.empty();
}
});
}
void onRequestSent(long correlationId, long timeMs) {
lastSendTimeMs = timeMs;
- inFlightCorrelationId = Optional.of(correlationId);
+ inFlightCorrelationId = OptionalLong.of(correlationId);
state = State.AWAITING_REQUEST;
}
@@ -192,7 +196,7 @@ void onRequestSent(long correlationId, long timeMs) {
*/
void reset() {
state = State.READY;
- inFlightCorrelationId = Optional.empty();
+ inFlightCorrelationId = OptionalLong.empty();
}
@Override
diff --git a/raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java b/raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java
new file mode 100644
index 0000000000000..9343cca8d47c5
--- /dev/null
+++ b/raft/src/main/java/org/apache/kafka/raft/internals/BlockingMessageQueue.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft.internals;
+
+import org.apache.kafka.common.errors.InterruptException;
+import org.apache.kafka.common.protocol.ApiMessage;
+import org.apache.kafka.raft.RaftMessage;
+import org.apache.kafka.raft.RaftMessageQueue;
+
+import java.util.concurrent.BlockingQueue;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
+public class BlockingMessageQueue implements RaftMessageQueue {
+ private static final RaftMessage WAKEUP_MESSAGE = new RaftMessage() {
+ @Override
+ public int correlationId() {
+ return 0;
+ }
+
+ @Override
+ public ApiMessage data() {
+ return null;
+ }
+ };
+
+ private final BlockingQueue<RaftMessage> queue = new LinkedBlockingQueue<>();
+ private final AtomicInteger size = new AtomicInteger(0);
+
+ @Override
+ public RaftMessage poll(long timeoutMs) {
+ try {
+ RaftMessage message = queue.poll(timeoutMs, TimeUnit.MILLISECONDS);
+ if (message == null || message == WAKEUP_MESSAGE) {
+ return null;
+ } else {
+ size.decrementAndGet();
+ return message;
+ }
+ } catch (InterruptedException e) {
+ throw new InterruptException(e);
+ }
+ }
+
+ @Override
+ public void add(RaftMessage message) {
+ queue.add(message);
+ size.incrementAndGet();
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return size.get() == 0;
+ }
+
+ @Override
+ public void wakeup() {
+ queue.add(WAKEUP_MESSAGE);
+ }
+
+}
diff --git a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
index 5a0e902e15bd5..43fb2301414f9 100644
--- a/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/KafkaRaftClientTest.java
@@ -109,7 +109,7 @@ public void testRejectVotesFromSameEpochAfterResigningLeadership() throws Except
// other voter in the same epoch, even if it has caught up to the same position.
context.deliverRequest(context.voteRequest(epoch, remoteId,
context.log.lastFetchedEpoch(), context.log.endOffset().offset));
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.of(localId), false);
}
@@ -134,7 +134,7 @@ public void testRejectVotesFromSameEpochAfterResigningCandidacy() throws Excepti
// other voter in the same epoch, even if it has caught up to the same position.
context.deliverRequest(context.voteRequest(epoch, remoteId,
context.log.lastFetchedEpoch(), context.log.endOffset().offset));
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.empty(), false);
}
@@ -158,13 +158,13 @@ public void testInitializeAsResignedLeaderFromStateStore() throws Exception {
context.client.poll();
assertEquals(Long.MAX_VALUE, context.client.scheduleAppend(epoch, Arrays.asList("a", "b")));
- context.pollUntilSend();
+ context.pollUntilRequest();
int correlationId = context.assertSentEndQuorumEpochRequest(epoch, 1);
context.deliverResponse(correlationId, 1, context.endEpochResponse(epoch, OptionalInt.of(localId)));
context.client.poll();
context.time.sleep(context.electionTimeoutMs);
- context.pollUntilSend();
+ context.pollUntilRequest();
context.assertVotedCandidate(epoch + 1, localId);
context.assertSentVoteRequest(epoch + 1, 0, 0L, 1);
}
@@ -187,7 +187,7 @@ public void testEndQuorumEpochRetriesWhileResigned() throws Exception {
.withElectedLeader(epoch, localId)
.build();
- context.pollUntilSend();
+ context.pollUntilRequest();
List<RaftRequest.Outbound> requests = context.collectEndQuorumRequests(epoch, Utils.mkSet(voter1, voter2));
assertEquals(2, requests.size());
@@ -203,7 +203,7 @@ public void testEndQuorumEpochRetriesWhileResigned() throws Exception {
// retried request from the voter that hasn't responded yet.
int nonRespondedId = requests.get(1).destinationId();
context.time.sleep(6000);
- context.pollUntilSend();
+ context.pollUntilRequest();
List<RaftRequest.Outbound> retries = context.collectEndQuorumRequests(epoch, Utils.mkSet(nonRespondedId));
assertEquals(1, retries.size());
}
@@ -254,7 +254,7 @@ public void testInitializeAsCandidateFromStateStore() throws Exception {
assertEquals(0L, context.log.endOffset().offset);
// The candidate will resume the election after reinitialization
- context.pollUntilSend();
+ context.pollUntilRequest();
List<RaftRequest.Outbound> voteRequests = context.collectVoteRequests(2, 0, 0);
assertEquals(2, voteRequests.size());
}
@@ -269,7 +269,7 @@ public void testInitializeAsCandidateAndBecomeLeader() throws Exception {
context.assertUnknownLeader(0);
context.time.sleep(2 * context.electionTimeoutMs);
- context.pollUntilSend();
+ context.pollUntilRequest();
context.assertVotedCandidate(1, context.localId);
int correlationId = context.assertSentVoteRequest(1, 0, 0L, 1);
@@ -309,7 +309,7 @@ public void testInitializeAsCandidateAndBecomeLeaderQuorumOfThree() throws Excep
context.assertUnknownLeader(0);
context.time.sleep(2 * context.electionTimeoutMs);
- context.pollUntilSend();
+ context.pollUntilRequest();
context.assertVotedCandidate(1, context.localId);
int correlationId = context.assertSentVoteRequest(1, 0, 0L, 2);
@@ -350,8 +350,7 @@ public void testHandleBeginQuorumRequest() throws Exception {
.build();
context.deliverRequest(context.beginEpochRequest(votedCandidateEpoch, otherNodeId));
-
- context.client.poll();
+ context.pollUntilResponse();
context.assertElectedLeader(votedCandidateEpoch, otherNodeId);
@@ -370,8 +369,7 @@ public void testHandleBeginQuorumResponse() throws Exception {
.build();
context.deliverRequest(context.beginEpochRequest(leaderEpoch + 1, otherNodeId));
-
- context.client.poll();
+ context.pollUntilResponse();
context.assertElectedLeader(leaderEpoch + 1, otherNodeId);
}
@@ -431,7 +429,7 @@ public void testEndQuorumIgnoredAsLeaderIfOlderEpoch() throws Exception {
// One of the voters may have sent EndQuorumEpoch from an earlier epoch
context.deliverRequest(context.endEpochRequest(epoch - 2, voter2, Arrays.asList(context.localId, voter3)));
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentEndQuorumEpochResponse(Errors.FENCED_LEADER_EPOCH, epoch, OptionalInt.of(context.localId));
// We should still be leader as long as fetch timeout has not expired
@@ -455,7 +453,7 @@ public void testEndQuorumStartsNewElectionImmediatelyIfFollowerUnattached() thro
context.deliverRequest(context.endEpochRequest(epoch, voter2,
Arrays.asList(context.localId, voter3)));
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentEndQuorumEpochResponse(Errors.NONE, epoch, OptionalInt.of(voter2));
// Should become a candidate immediately
@@ -486,7 +484,7 @@ public void testAccumulatorClearedAfterBecomingFollower() throws Exception {
assertEquals(1L, context.client.scheduleAppend(epoch, singletonList("a")));
context.deliverRequest(context.beginEpochRequest(epoch + 1, otherNodeId));
- context.client.poll();
+ context.pollUntilResponse();
context.assertElectedLeader(epoch + 1, otherNodeId);
Mockito.verify(memoryPool).release(buffer);
@@ -516,7 +514,7 @@ public void testAccumulatorClearedAfterBecomingVoted() throws Exception {
assertEquals(1L, context.client.scheduleAppend(epoch, singletonList("a")));
context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch,
context.log.endOffset().offset));
- context.client.poll();
+ context.pollUntilResponse();
context.assertVotedCandidate(epoch + 1, otherNodeId);
Mockito.verify(memoryPool).release(buffer);
@@ -545,7 +543,7 @@ public void testAccumulatorClearedAfterBecomingUnattached() throws Exception {
assertEquals(1L, context.client.scheduleAppend(epoch, singletonList("a")));
context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch, 0L));
- context.client.poll();
+ context.pollUntilResponse();
context.assertUnknownLeader(epoch + 1);
Mockito.verify(memoryPool).release(buffer);
@@ -571,14 +569,14 @@ public void testChannelWokenUpIfLingerTimeoutReachedWithoutAppend() throws Excep
int epoch = context.currentEpoch();
assertEquals(1L, context.client.scheduleAppend(epoch, singletonList("a")));
- assertTrue(context.channel.wakeupRequested());
+ assertTrue(context.messageQueue.wakeupRequested());
context.client.poll();
- assertEquals(OptionalLong.of(lingerMs), context.channel.lastReceiveTimeout());
+ assertEquals(OptionalLong.of(lingerMs), context.messageQueue.lastPollTimeoutMs());
context.time.sleep(20);
context.client.poll();
- assertEquals(OptionalLong.of(30), context.channel.lastReceiveTimeout());
+ assertEquals(OptionalLong.of(30), context.messageQueue.lastPollTimeoutMs());
context.time.sleep(30);
context.client.poll();
@@ -605,15 +603,15 @@ public void testChannelWokenUpIfLingerTimeoutReachedDuringAppend() throws Except
int epoch = context.currentEpoch();
assertEquals(1L, context.client.scheduleAppend(epoch, singletonList("a")));
- assertTrue(context.channel.wakeupRequested());
+ assertTrue(context.messageQueue.wakeupRequested());
context.client.poll();
- assertFalse(context.channel.wakeupRequested());
- assertEquals(OptionalLong.of(lingerMs), context.channel.lastReceiveTimeout());
+ assertFalse(context.messageQueue.wakeupRequested());
+ assertEquals(OptionalLong.of(lingerMs), context.messageQueue.lastPollTimeoutMs());
context.time.sleep(lingerMs);
assertEquals(2L, context.client.scheduleAppend(epoch, singletonList("b")));
- assertTrue(context.channel.wakeupRequested());
+ assertTrue(context.messageQueue.wakeupRequested());
context.client.poll();
assertEquals(3L, context.log.endOffset().offset);
@@ -633,7 +631,7 @@ public void testHandleEndQuorumRequest() throws Exception {
context.deliverRequest(context.endEpochRequest(leaderEpoch, oldLeaderId,
Collections.singletonList(context.localId)));
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentEndQuorumEpochResponse(Errors.NONE, leaderEpoch, OptionalInt.of(oldLeaderId));
context.client.poll();
@@ -655,19 +653,19 @@ public void testHandleEndQuorumRequestWithLowerPriorityToBecomeLeader() throws E
context.deliverRequest(context.endEpochRequest(leaderEpoch, oldLeaderId,
Arrays.asList(preferredNextLeader, context.localId)));
- context.pollUntilSend();
+ context.pollUntilResponse();
context.assertSentEndQuorumEpochResponse(Errors.NONE, leaderEpoch, OptionalInt.of(oldLeaderId));
// The election won't trigger by one round retry backoff
context.time.sleep(1);
- context.pollUntilSend();
+ context.pollUntilRequest();
context.assertSentFetchRequest(leaderEpoch, 0, 0);
context.time.sleep(context.retryBackoffMs);
- context.pollUntilSend();
+ context.pollUntilRequest();
List<RaftRequest.Outbound> voteRequests = context.collectVoteRequests(leaderEpoch + 1, 0, 0);
assertEquals(2, voteRequests.size());
@@ -687,7 +685,7 @@ public void testVoteRequestTimeout() throws Exception {
context.assertUnknownLeader(0);
context.time.sleep(2 * context.electionTimeoutMs);
- context.pollUntilSend();
+ context.pollUntilRequest();
context.assertVotedCandidate(epoch, context.localId);
int correlationId = context.assertSentVoteRequest(epoch, 0, 0L, 1);
@@ -696,13 +694,12 @@ public void testVoteRequestTimeout() throws Exception {
context.client.poll();
int retryCorrelationId = context.assertSentVoteRequest(epoch, 0, 0L, 1);
- // Even though we have resent the request, we should still accept the response to
- // the first request if it arrives late.
+ // We will ignore the timed out response if it arrives late
context.deliverResponse(correlationId, otherNodeId, context.voteResponse(true, Optional.empty(), 1));
context.client.poll();
- context.assertElectedLeader(epoch, context.localId);
+ context.assertVotedCandidate(epoch, context.localId);
- // If the second request arrives later, it should have no effect
+ // Become leader after receiving the retry response
context.deliverResponse(retryCorrelationId, otherNodeId, context.voteResponse(true, Optional.empty(), 1));
context.client.poll();
context.assertElectedLeader(epoch, context.localId);
@@ -720,8 +717,7 @@ public void testHandleValidVoteRequestAsFollower() throws Exception {
.build();
context.deliverRequest(context.voteRequest(epoch, otherNodeId, epoch - 1, 1));
-
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.empty(), true);
@@ -741,8 +737,7 @@ public void testHandleVoteRequestAsFollowerWithElectedLeader() throws Exception
.build();
context.deliverRequest(context.voteRequest(epoch, otherNodeId, epoch - 1, 1));
-
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.of(electedLeaderId), false);
@@ -762,8 +757,7 @@ public void testHandleVoteRequestAsFollowerWithVotedCandidate() throws Exception
.build();
context.deliverRequest(context.voteRequest(epoch, otherNodeId, epoch - 1, 1));
-
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentVoteResponse(Errors.NONE, epoch, OptionalInt.empty(), false);
context.assertVotedCandidate(epoch, votedCandidateId);
@@ -781,8 +775,7 @@ public void testHandleInvalidVoteRequestWithOlderEpoch() throws Exception {
.build();
context.deliverRequest(context.voteRequest(epoch - 1, otherNodeId, epoch - 2, 1));
-
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentVoteResponse(Errors.FENCED_LEADER_EPOCH, epoch, OptionalInt.empty(), false);
context.assertUnknownLeader(epoch);
@@ -801,8 +794,7 @@ public void testHandleInvalidVoteRequestAsObserver() throws Exception {
.build();
context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch, 1));
-
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentVoteResponse(Errors.INCONSISTENT_VOTER_SET, epoch, OptionalInt.empty(), false);
context.assertUnknownLeader(epoch);
@@ -841,7 +833,8 @@ public void testListenerCommitCallbackAfterLeaderWrite() throws Exception {
// Let follower send a fetch to initialize the high watermark,
// note the offset 0 would be a control message for becoming the leader
context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 0L, epoch, 500));
- context.pollUntilSend();
+ context.pollUntilResponse();
+ context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId));
assertEquals(OptionalLong.of(0L), context.client.highWatermark());
List records = Arrays.asList("a", "b", "c");
@@ -851,14 +844,14 @@ public void testListenerCommitCallbackAfterLeaderWrite() throws Exception {
// Let the follower send a fetch, it should advance the high watermark
context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 1L, epoch, 500));
- context.pollUntilSend();
+ context.pollUntilResponse();
+ context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(localId));
assertEquals(OptionalLong.of(1L), context.client.highWatermark());
assertEquals(OptionalLong.empty(), context.listener.lastCommitOffset());
// Let the follower send another fetch from offset 4
context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 4L, epoch, 500));
- context.client.poll();
- assertEquals(OptionalLong.of(4L), context.client.highWatermark());
+ context.pollUntil(() -> context.client.highWatermark().equals(OptionalLong.of(4L)));
assertEquals(records, context.listener.commitWithLastOffset(offset));
}
@@ -873,7 +866,7 @@ public void testCandidateIgnoreVoteRequestOnSameEpoch() throws Exception {
.withVotedCandidate(leaderEpoch, localId)
.build();
- context.pollUntilSend();
+ context.pollUntilRequest();
context.deliverRequest(context.voteRequest(leaderEpoch, otherNodeId, leaderEpoch - 1, 1));
context.client.poll();
@@ -898,7 +891,7 @@ public void testRetryElection() throws Exception {
context.assertUnknownLeader(0);
context.time.sleep(2 * context.electionTimeoutMs);
- context.pollUntilSend();
+ context.pollUntilRequest();
context.assertVotedCandidate(epoch, context.localId);
// Quorum size is two. If the other member rejects, then we need to schedule a revote.
@@ -919,7 +912,7 @@ public void testRetryElection() throws Exception {
// After jitter expires, we become a candidate again
context.time.sleep(1);
context.client.poll();
- context.pollUntilSend();
+ context.pollUntilRequest();
context.assertVotedCandidate(epoch + 1, context.localId);
context.assertSentVoteRequest(epoch + 1, 0, 0L, 1);
}
@@ -937,7 +930,7 @@ public void testInitializeAsFollowerEmptyLog() throws Exception {
context.assertElectedLeader(epoch, otherNodeId);
- context.pollUntilSend();
+ context.pollUntilRequest();
context.assertSentFetchRequest(epoch, 0L, 0);
}
@@ -957,7 +950,7 @@ public void testInitializeAsFollowerNonEmptyLog() throws Exception {
context.assertElectedLeader(epoch, otherNodeId);
- context.pollUntilSend();
+ context.pollUntilRequest();
context.assertSentFetchRequest(epoch, 1L, lastEpoch);
}
@@ -975,12 +968,12 @@ public void testVoterBecomeCandidateAfterFetchTimeout() throws Exception {
.build();
context.assertElectedLeader(epoch, otherNodeId);
- context.pollUntilSend();
+ context.pollUntilRequest();
context.assertSentFetchRequest(epoch, 1L, lastEpoch);
context.time.sleep(context.fetchTimeoutMs);
- context.pollUntilSend();
+ context.pollUntilRequest();
context.assertSentVoteRequest(epoch + 1, lastEpoch, 1L, 1);
context.assertVotedCandidate(epoch + 1, context.localId);
@@ -996,7 +989,7 @@ public void testInitializeObserverNoPreviousState() throws Exception {
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build();
- context.pollUntilSend();
+ context.pollUntilRequest();
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
context.assertFetchRequestData(fetchRequest, 0, 0L, 0);
@@ -1017,7 +1010,7 @@ public void testObserverQuorumDiscoveryFailure() throws Exception {
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build();
- context.pollUntilSend();
+ context.pollUntilRequest();
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
context.assertFetchRequestData(fetchRequest, 0, 0L, 0);
@@ -1027,7 +1020,7 @@ public void testObserverQuorumDiscoveryFailure() throws Exception {
context.client.poll();
context.time.sleep(context.retryBackoffMs);
- context.pollUntilSend();
+ context.pollUntilRequest();
fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
@@ -1050,7 +1043,7 @@ public void testObserverSendDiscoveryFetchAfterFetchTimeout() throws Exception {
RaftClientTestContext context = new RaftClientTestContext.Builder(localId, voters).build();
- context.pollUntilSend();
+ context.pollUntilRequest();
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
context.assertFetchRequestData(fetchRequest, 0, 0L, 0);
@@ -1062,7 +1055,7 @@ public void testObserverSendDiscoveryFetchAfterFetchTimeout() throws Exception {
context.assertElectedLeader(epoch, leaderId);
context.time.sleep(context.fetchTimeoutMs);
- context.pollUntilSend();
+ context.pollUntilRequest();
fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
@@ -1079,27 +1072,27 @@ public void testInvalidFetchRequest() throws Exception {
context.deliverRequest(context.fetchRequest(
epoch, otherNodeId, -5L, 0, 0));
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentFetchResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(context.localId));
context.deliverRequest(context.fetchRequest(
epoch, otherNodeId, 0L, -1, 0));
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentFetchResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(context.localId));
context.deliverRequest(context.fetchRequest(
epoch, otherNodeId, 0L, epoch + 1, 0));
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentFetchResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(context.localId));
context.deliverRequest(context.fetchRequest(
epoch + 1, otherNodeId, 0L, 0, 0));
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentFetchResponse(Errors.UNKNOWN_LEADER_EPOCH, epoch, OptionalInt.of(context.localId));
context.deliverRequest(context.fetchRequest(
epoch, otherNodeId, 0L, 0, -1));
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentFetchResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(context.localId));
}
@@ -1141,17 +1134,17 @@ public void testInvalidVoteRequest() throws Exception {
context.assertElectedLeader(epoch, otherNodeId);
context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, 0, -5L));
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentVoteResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(otherNodeId), false);
context.assertElectedLeader(epoch, otherNodeId);
context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, -1, 0L));
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentVoteResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(otherNodeId), false);
context.assertElectedLeader(epoch, otherNodeId);
context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch + 1, 0L));
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentVoteResponse(Errors.INVALID_REQUEST, epoch, OptionalInt.of(otherNodeId), false);
context.assertElectedLeader(epoch, otherNodeId);
}
@@ -1220,7 +1213,7 @@ public void testPurgatoryFetchCompletedByFollowerTransition() throws Exception {
// Now we get a BeginEpoch from the other voter and become a follower
context.deliverRequest(context.beginEpochRequest(epoch + 1, voter3));
- context.client.poll();
+ context.pollUntilResponse();
context.assertElectedLeader(epoch + 1, voter3);
// We expect the BeginQuorumEpoch response and a failed Fetch response
@@ -1246,7 +1239,7 @@ public void testFetchResponseIgnoredAfterBecomingCandidate() throws Exception {
context.assertElectedLeader(epoch, otherNodeId);
// Wait until we have a Fetch inflight to the leader
- context.pollUntilSend();
+ context.pollUntilRequest();
int fetchCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0);
// Now await the fetch timeout and become a candidate
@@ -1280,7 +1273,7 @@ public void testFetchResponseIgnoredAfterBecomingFollowerOfDifferentLeader() thr
context.assertElectedLeader(epoch, voter2);
// Wait until we have a Fetch inflight to the leader
- context.pollUntilSend();
+ context.pollUntilRequest();
int fetchCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0);
// Now receive a BeginEpoch from `voter3`
@@ -1316,7 +1309,7 @@ public void testVoteResponseIgnoredAfterBecomingFollower() throws Exception {
context.time.sleep(context.electionTimeoutMs * 2);
// Wait until the vote requests are inflight
- context.pollUntilSend();
+ context.pollUntilRequest();
context.assertVotedCandidate(epoch, context.localId);
List<RaftRequest.Outbound> voteRequests = context.collectVoteRequests(epoch, 0, 0);
assertEquals(2, voteRequests.size());
@@ -1349,14 +1342,14 @@ public void testObserverLeaderRediscoveryAfterBrokerNotAvailableError() throws E
context.discoverLeaderAsObserver(leaderId, epoch);
- context.pollUntilSend();
+ context.pollUntilRequest();
RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest();
assertEquals(leaderId, fetchRequest1.destinationId());
context.assertFetchRequestData(fetchRequest1, epoch, 0L, 0);
context.deliverResponse(fetchRequest1.correlationId, fetchRequest1.destinationId(),
context.fetchResponse(epoch, -1, MemoryRecords.EMPTY, -1, Errors.BROKER_NOT_AVAILABLE));
- context.pollUntilSend();
+ context.pollUntilRequest();
// We should retry the Fetch against the other voter since the original
// voter connection will be backing off.
@@ -1386,13 +1379,13 @@ public void testObserverLeaderRediscoveryAfterRequestTimeout() throws Exception
context.discoverLeaderAsObserver(leaderId, epoch);
- context.pollUntilSend();
+ context.pollUntilRequest();
RaftRequest.Outbound fetchRequest1 = context.assertSentFetchRequest();
assertEquals(leaderId, fetchRequest1.destinationId());
context.assertFetchRequestData(fetchRequest1, epoch, 0L, 0);
context.time.sleep(context.requestTimeoutMs);
- context.pollUntilSend();
+ context.pollUntilRequest();
// We should retry the Fetch against the other voter since the original
// voter connection will be backing off.
@@ -1427,7 +1420,7 @@ public void testLeaderGracefulShutdown() throws Exception {
assertFalse(shutdownFuture.isDone());
// Send EndQuorumEpoch request to the other voter
- context.pollUntilSend();
+ context.pollUntilRequest();
assertTrue(context.client.isShuttingDown());
assertTrue(context.client.isRunning());
context.assertSentEndQuorumEpochRequest(1, otherNodeId);
@@ -1469,7 +1462,7 @@ public void testEndQuorumEpochSentBasedOnFetchOffset() throws Exception {
assertTrue(context.client.isRunning());
// Send EndQuorumEpoch request to the close follower
- context.pollUntilSend();
+ context.pollUntilRequest();
assertTrue(context.client.isRunning());
List<RaftRequest.Outbound> endQuorumRequests = context.collectEndQuorumRequests(
@@ -1494,14 +1487,14 @@ public void testDescribeQuorum() throws Exception {
int observerId = 3;
context.deliverRequest(context.fetchRequest(epoch, observerId, 0L, 0, 0));
- context.client.poll();
+ context.pollUntilResponse();
long highWatermark = 1L;
context.assertSentFetchResponse(highWatermark, epoch);
context.deliverRequest(DescribeQuorumRequest.singletonRequest(context.metadataPartition));
- context.client.poll();
+ context.pollUntilResponse();
context.assertSentDescribeQuorumResponse(context.localId, epoch, highWatermark,
Arrays.asList(
@@ -1540,7 +1533,7 @@ public void testLeaderGracefulShutdownTimeout() throws Exception {
assertFalse(shutdownFuture.isDone());
// Send EndQuorumEpoch request to the other voter
- context.pollUntilSend();
+ context.pollUntilRequest();
assertTrue(context.client.isRunning());
context.assertSentEndQuorumEpochRequest(epoch, otherNodeId);
@@ -1631,7 +1624,7 @@ public void testFollowerReplication() throws Exception {
.build();
context.assertElectedLeader(epoch, otherNodeId);
- context.pollUntilSend();
+ context.pollUntilRequest();
int fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0);
Records records = context.buildBatch(0L, 3, Arrays.asList("a", "b"));
@@ -1656,7 +1649,7 @@ public void testEmptyRecordSetInFetchResponse() throws Exception {
context.assertElectedLeader(epoch, otherNodeId);
// Receive an empty fetch response
- context.pollUntilSend();
+ context.pollUntilRequest();
int fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0);
FetchResponseData fetchResponse = context.fetchResponse(epoch, otherNodeId,
MemoryRecords.EMPTY, 0L, Errors.NONE);
@@ -1666,7 +1659,7 @@ public void testEmptyRecordSetInFetchResponse() throws Exception {
assertEquals(OptionalLong.of(0L), context.client.highWatermark());
// Receive some records in the next poll, but do not advance high watermark
- context.pollUntilSend();
+ context.pollUntilRequest();
Records records = context.buildBatch(0L, epoch, Arrays.asList("a", "b"));
fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 0L, 0);
fetchResponse = context.fetchResponse(epoch, otherNodeId,
@@ -1677,7 +1670,7 @@ public void testEmptyRecordSetInFetchResponse() throws Exception {
assertEquals(OptionalLong.of(0L), context.client.highWatermark());
// The next fetch response is empty, but should still advance the high watermark
- context.pollUntilSend();
+ context.pollUntilRequest();
fetchQuorumCorrelationId = context.assertSentFetchRequest(epoch, 2L, epoch);
fetchResponse = context.fetchResponse(epoch, otherNodeId,
MemoryRecords.EMPTY, 2L, Errors.NONE);
@@ -1704,7 +1697,7 @@ public void testFetchShouldBeTreatedAsLeaderEndorsement() throws Exception {
context.time.sleep(context.electionTimeoutMs);
context.expectAndGrantVotes(epoch);
- context.pollUntilSend();
+ context.pollUntilRequest();
// We send BeginEpoch, but it gets lost and the destination finds the leader through the Fetch API
context.assertSentBeginQuorumEpochRequest(epoch, 1);
@@ -1720,12 +1713,12 @@ public void testFetchShouldBeTreatedAsLeaderEndorsement() throws Exception {
context.client.poll();
- List sentMessages = context.channel.drainSendQueue();
+ List sentMessages = context.channel.drainSendQueue();
assertEquals(0, sentMessages.size());
}
@Test
- public void testLeaderAppendSingleMemberQuorum() throws IOException {
+ public void testLeaderAppendSingleMemberQuorum() throws Exception {
int localId = 0;
Set voters = Collections.singleton(localId);
@@ -1752,8 +1745,7 @@ public void testLeaderAppendSingleMemberQuorum() throws IOException {
// Now try reading it
int otherNodeId = 1;
context.deliverRequest(context.fetchRequest(1, otherNodeId, 0L, 0, 500));
-
- context.client.poll();
+ context.pollUntilResponse();
MemoryRecords fetchedRecords = context.assertSentFetchResponse(Errors.NONE, 1, OptionalInt.of(context.localId));
List batches = Utils.toList(fetchedRecords.batchIterator());
@@ -1796,7 +1788,7 @@ public void testFollowerLogReconciliation() throws Exception {
context.assertElectedLeader(epoch, otherNodeId);
assertEquals(3L, context.log.endOffset().offset);
- context.pollUntilSend();
+ context.pollUntilRequest();
int correlationId = context.assertSentFetchRequest(epoch, 3L, lastEpoch);
@@ -1873,7 +1865,7 @@ public void testClusterAuthorizationFailedInFetch() throws Exception {
context.assertElectedLeader(epoch, otherNodeId);
- context.pollUntilSend();
+ context.pollUntilRequest();
int correlationId = context.assertSentFetchRequest(epoch, 0, 0);
FetchResponseData response = new FetchResponseData()
@@ -1899,7 +1891,7 @@ public void testClusterAuthorizationFailedInBeginQuorumEpoch() throws Exception
context.time.sleep(context.electionTimeoutMs);
context.expectAndGrantVotes(epoch);
- context.pollUntilSend();
+ context.pollUntilRequest();
int correlationId = context.assertSentBeginQuorumEpochRequest(epoch, 1);
BeginQuorumEpochResponseData response = new BeginQuorumEpochResponseData()
.setErrorCode(Errors.CLUSTER_AUTHORIZATION_FAILED.code());
@@ -1921,7 +1913,7 @@ public void testClusterAuthorizationFailedInVote() throws Exception {
// Sleep a little to ensure that we become a candidate
context.time.sleep(context.electionTimeoutMs * 2);
- context.pollUntilSend();
+ context.pollUntilRequest();
context.assertVotedCandidate(epoch, context.localId);
int correlationId = context.assertSentVoteRequest(epoch, 0, 0L, 1);
@@ -1942,7 +1934,7 @@ public void testClusterAuthorizationFailedInEndQuorumEpoch() throws Exception {
RaftClientTestContext context = RaftClientTestContext.initializeAsLeader(localId, voters, epoch);
context.client.shutdown(5000);
- context.pollUntilSend();
+ context.pollUntilRequest();
int correlationId = context.assertSentEndQuorumEpochRequest(epoch, otherNodeId);
EndQuorumEpochResponseData response = new EndQuorumEpochResponseData()
@@ -2067,7 +2059,7 @@ public void testHandleCommitCallbackFiresAfterFollowerHighWatermarkAdvances() th
assertEquals(OptionalLong.empty(), context.client.highWatermark());
// Poll for our first fetch request
- context.pollUntilSend();
+ context.pollUntilRequest();
RaftRequest.Outbound fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
context.assertFetchRequestData(fetchRequest, epoch, 0L, 0);
@@ -2085,7 +2077,7 @@ public void testHandleCommitCallbackFiresAfterFollowerHighWatermarkAdvances() th
assertEquals(OptionalInt.empty(), context.listener.currentClaimedEpoch());
// Now look for the next fetch request
- context.pollUntilSend();
+ context.pollUntilRequest();
fetchRequest = context.assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
context.assertFetchRequestData(fetchRequest, epoch, 3L, 3);
@@ -2131,7 +2123,7 @@ public void testHandleCommitCallbackFiresInVotedState() throws Exception {
// Now we receive a vote request which transitions us to the 'voted' state
int candidateEpoch = epoch + 1;
context.deliverRequest(context.voteRequest(candidateEpoch, otherNodeId, epoch, 10L));
- context.client.poll();
+ context.pollUntilResponse();
context.assertVotedCandidate(candidateEpoch, otherNodeId);
assertEquals(OptionalLong.of(10L), context.client.highWatermark());
@@ -2166,12 +2158,13 @@ public void testHandleCommitCallbackFiresInCandidateState() throws Exception {
// Start off as the leader and receive a fetch to initialize the high watermark
context.becomeLeader();
context.deliverRequest(context.fetchRequest(epoch, otherNodeId, 9L, epoch, 500));
- context.client.poll();
+ context.pollUntilResponse();
assertEquals(OptionalLong.of(9L), context.client.highWatermark());
+ context.assertSentFetchResponse(Errors.NONE, epoch, OptionalInt.of(context.localId));
// Now we receive a vote request which transitions us to the 'unattached' state
context.deliverRequest(context.voteRequest(epoch + 1, otherNodeId, epoch, 9L));
- context.client.poll();
+ context.pollUntilResponse();
context.assertUnknownLeader(epoch + 1);
assertEquals(OptionalLong.of(9L), context.client.highWatermark());
diff --git a/raft/src/test/java/org/apache/kafka/raft/MockLog.java b/raft/src/test/java/org/apache/kafka/raft/MockLog.java
index 1e8c1b53fcb72..f0be8dd80d5a4 100644
--- a/raft/src/test/java/org/apache/kafka/raft/MockLog.java
+++ b/raft/src/test/java/org/apache/kafka/raft/MockLog.java
@@ -255,6 +255,13 @@ public LogAppendInfo appendAsFollower(Records records) {
long baseOffset = endOffset().offset;
long lastOffset = baseOffset;
for (RecordBatch batch : records.batches()) {
+ Optional lastEntry = lastEntry();
+
+ if (lastEntry.isPresent() && batch.baseOffset() != lastEntry.get().offset + 1) {
+ throw new IllegalArgumentException("Illegal append at offset " + batch.baseOffset() +
+ " with current end offset of " + endOffset().offset);
+ }
+
List entries = buildEntries(batch, Record::offset);
appendBatch(new LogBatch(batch.partitionLeaderEpoch(), batch.isControlBatch(), entries));
lastOffset = entries.get(entries.size() - 1).offset;
diff --git a/raft/src/test/java/org/apache/kafka/raft/MockMessageQueue.java b/raft/src/test/java/org/apache/kafka/raft/MockMessageQueue.java
new file mode 100644
index 0000000000000..5fcd599b95f1e
--- /dev/null
+++ b/raft/src/test/java/org/apache/kafka/raft/MockMessageQueue.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft;
+
+import java.util.ArrayDeque;
+import java.util.OptionalLong;
+import java.util.Queue;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicLong;
+
+/**
+ * Mocked implementation which does not block in {@link #poll(long)}.
+ */
+public class MockMessageQueue implements RaftMessageQueue {
+ private final Queue messages = new ArrayDeque<>();
+ private final AtomicBoolean wakeupRequested = new AtomicBoolean(false);
+ private final AtomicLong lastPollTimeout = new AtomicLong(-1);
+
+ @Override
+ public RaftMessage poll(long timeoutMs) {
+ wakeupRequested.set(false);
+ lastPollTimeout.set(timeoutMs);
+ return messages.poll();
+ }
+
+ @Override
+ public void add(RaftMessage message) {
+ messages.offer(message);
+ }
+
+ public OptionalLong lastPollTimeoutMs() {
+ long lastTimeoutMs = lastPollTimeout.get();
+ if (lastTimeoutMs < 0) {
+ return OptionalLong.empty();
+ } else {
+ return OptionalLong.of(lastTimeoutMs);
+ }
+ }
+
+ public boolean wakeupRequested() {
+ return wakeupRequested.get();
+ }
+
+ @Override
+ public boolean isEmpty() {
+ return messages.isEmpty();
+ }
+
+ @Override
+ public void wakeup() {
+ wakeupRequested.set(true);
+ }
+}
diff --git a/raft/src/test/java/org/apache/kafka/raft/MockNetworkChannel.java b/raft/src/test/java/org/apache/kafka/raft/MockNetworkChannel.java
index da14df307336a..3f08ff58256d3 100644
--- a/raft/src/test/java/org/apache/kafka/raft/MockNetworkChannel.java
+++ b/raft/src/test/java/org/apache/kafka/raft/MockNetworkChannel.java
@@ -24,22 +24,17 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
-import java.util.OptionalLong;
-import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.Optional;
import java.util.concurrent.atomic.AtomicInteger;
-import java.util.concurrent.atomic.AtomicLong;
public class MockNetworkChannel implements NetworkChannel {
- private final AtomicInteger requestIdCounter;
- private final AtomicBoolean wakeupRequested = new AtomicBoolean(false);
- private final AtomicLong lastReceiveTimeout = new AtomicLong(-1);
+ private final AtomicInteger correlationIdCounter;
+ private final List sendQueue = new ArrayList<>();
+ private final Map awaitingResponse = new HashMap<>();
+ private final Map addressCache = new HashMap<>();
- private List sendQueue = new ArrayList<>();
- private List receiveQueue = new ArrayList<>();
- private Map addressCache = new HashMap<>();
-
- public MockNetworkChannel(AtomicInteger requestIdCounter) {
- this.requestIdCounter = requestIdCounter;
+ public MockNetworkChannel(AtomicInteger correlationIdCounter) {
+ this.correlationIdCounter = correlationIdCounter;
}
public MockNetworkChannel() {
@@ -48,46 +43,16 @@ public MockNetworkChannel() {
@Override
public int newCorrelationId() {
- return requestIdCounter.getAndIncrement();
- }
-
- @Override
- public void send(RaftMessage message) {
- if (message instanceof RaftRequest.Outbound) {
- RaftRequest.Outbound request = (RaftRequest.Outbound) message;
- if (!addressCache.containsKey(request.destinationId())) {
- throw new IllegalArgumentException("Attempted to send to destination " +
- request.destinationId() + ", but its address is not yet known");
- }
- }
- sendQueue.add(message);
+ return correlationIdCounter.getAndIncrement();
}
@Override
- public List receive(long timeoutMs) {
- wakeupRequested.set(false);
- lastReceiveTimeout.set(timeoutMs);
- List messages = receiveQueue;
- receiveQueue = new ArrayList<>();
- return messages;
- }
-
- OptionalLong lastReceiveTimeout() {
- long timeout = lastReceiveTimeout.get();
- if (timeout < 0) {
- return OptionalLong.empty();
- } else {
- return OptionalLong.of(timeout);
+ public void send(RaftRequest.Outbound request) {
+ if (!addressCache.containsKey(request.destinationId())) {
+ throw new IllegalArgumentException("Attempted to send to destination " +
+ request.destinationId() + ", but its address is not yet known");
}
- }
-
- boolean wakeupRequested() {
- return wakeupRequested.get();
- }
-
- @Override
- public void wakeup() {
- wakeupRequested.set(true);
+ sendQueue.add(request);
}
@Override
@@ -95,19 +60,17 @@ public void updateEndpoint(int id, InetSocketAddress address) {
addressCache.put(id, address);
}
- public List drainSendQueue() {
- List messages = sendQueue;
- sendQueue = new ArrayList<>();
- return messages;
+ public List drainSendQueue() {
+ return drainSentRequests(Optional.empty());
}
- public List drainSentRequests(ApiKeys apiKey) {
+ public List drainSentRequests(Optional apiKeyFilter) {
List requests = new ArrayList<>();
- Iterator iterator = sendQueue.iterator();
+ Iterator iterator = sendQueue.iterator();
while (iterator.hasNext()) {
- RaftMessage message = iterator.next();
- if (message instanceof RaftRequest.Outbound && message.data().apiKey() == apiKey.id) {
- RaftRequest.Outbound request = (RaftRequest.Outbound) message;
+ RaftRequest.Outbound request = iterator.next();
+ if (!apiKeyFilter.isPresent() || request.data().apiKey() == apiKeyFilter.get().id) {
+ awaitingResponse.put(request.correlationId, request);
requests.add(request);
iterator.remove();
}
@@ -115,27 +78,17 @@ public List drainSentRequests(ApiKeys apiKey) {
return requests;
}
- public List drainSentResponses(ApiKeys apiKey) {
- List responses = new ArrayList<>();
- Iterator iterator = sendQueue.iterator();
- while (iterator.hasNext()) {
- RaftMessage message = iterator.next();
- if (message instanceof RaftResponse.Outbound && message.data().apiKey() == apiKey.id) {
- RaftResponse.Outbound response = (RaftResponse.Outbound) message;
- responses.add(response);
- iterator.remove();
- }
- }
- return responses;
- }
-
- public boolean hasSentMessages() {
+ public boolean hasSentRequests() {
return !sendQueue.isEmpty();
}
- public void mockReceive(RaftMessage message) {
- receiveQueue.add(message);
+ public void mockReceive(RaftResponse.Inbound response) {
+ RaftRequest.Outbound request = awaitingResponse.get(response.correlationId);
+ if (request == null) {
+ throw new IllegalStateException("Received response for a request which is not being awaited");
+ }
+ request.completion.complete(response);
}
}
diff --git a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
index 5ed8061a41b89..2b69ac8049f1f 100644
--- a/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
+++ b/raft/src/test/java/org/apache/kafka/raft/RaftClientTestContext.java
@@ -52,6 +52,7 @@
import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.raft.internals.BatchBuilder;
import org.apache.kafka.raft.internals.StringSerde;
+import org.apache.kafka.test.TestCondition;
import org.apache.kafka.test.TestUtils;
import org.mockito.Mockito;
@@ -63,6 +64,7 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -96,11 +98,13 @@ public final class RaftClientTestContext {
final Metrics metrics;
public final MockLog log;
final MockNetworkChannel channel;
+ final MockMessageQueue messageQueue;
final MockTime time;
final MockListener listener;
-
final Set voters;
+ private final List sentResponses = new ArrayList<>();
+
public static final class Builder {
static final int DEFAULT_ELECTION_TIMEOUT_MS = 10000;
@@ -113,6 +117,7 @@ public static final class Builder {
private static final int RETRY_BACKOFF_MS = 50;
private static final int DEFAULT_APPEND_LINGER_MS = 0;
+ private final MockMessageQueue messageQueue = new MockMessageQueue();
private final MockTime time = new MockTime();
private final QuorumStateStore quorumStateStore = new MockQuorumStateStore();
private final Random random = Mockito.spy(new Random(1));
@@ -194,6 +199,7 @@ public RaftClientTestContext build() throws IOException {
KafkaRaftClient client = new KafkaRaftClient<>(
STRING_SERDE,
channel,
+ messageQueue,
log,
quorum,
memoryPool,
@@ -218,6 +224,7 @@ public RaftClientTestContext build() throws IOException {
client,
log,
channel,
+ messageQueue,
time,
quorumStateStore,
quorum,
@@ -233,6 +240,7 @@ private RaftClientTestContext(
KafkaRaftClient client,
MockLog log,
MockNetworkChannel channel,
+ MockMessageQueue messageQueue,
MockTime time,
QuorumStateStore quorumStateStore,
QuorumState quorum,
@@ -244,6 +252,7 @@ private RaftClientTestContext(
this.client = client;
this.log = log;
this.channel = channel;
+ this.messageQueue = messageQueue;
this.time = time;
this.quorumStateStore = quorumStateStore;
this.quorum = quorum;
@@ -326,7 +335,7 @@ LeaderAndEpoch currentLeaderAndEpoch() {
void expectAndGrantVotes(
int epoch
) throws Exception {
- pollUntilSend();
+ pollUntilRequest();
List voteRequests = collectVoteRequests(epoch,
log.lastFetchedEpoch(), log.endOffset().offset);
@@ -343,7 +352,7 @@ void expectAndGrantVotes(
void expectBeginEpoch(
int epoch
) throws Exception {
- pollUntilSend();
+ pollUntilRequest();
for (RaftRequest.Outbound request : collectBeginEpochRequests(epoch)) {
BeginQuorumEpochResponseData beginEpochResponse = beginEpochResponse(epoch, localId);
deliverResponse(request.correlationId, request.destinationId(), beginEpochResponse);
@@ -351,11 +360,19 @@ void expectBeginEpoch(
client.poll();
}
- void pollUntilSend() throws InterruptedException {
+ void pollUntil(TestCondition condition) throws InterruptedException {
TestUtils.waitForCondition(() -> {
client.poll();
- return channel.hasSentMessages();
- }, 5000, "Condition failed to be satisfied before timeout");
+ return condition.conditionMet();
+ }, 5000, "Condition failed to be satisfied before timeout");
+ }
+
+ void pollUntilResponse() throws InterruptedException {
+ pollUntil(() -> !sentResponses.isEmpty());
+ }
+
+ void pollUntilRequest() throws InterruptedException {
+ pollUntil(channel::hasSentRequests);
}
void assertVotedCandidate(int epoch, int leaderId) throws IOException {
@@ -375,14 +392,16 @@ void assertResignedLeader(int epoch, int leaderId) throws IOException {
assertEquals(ElectionState.withElectedLeader(epoch, leaderId, voters), quorumStateStore.readElectionState());
}
- int assertSentDescribeQuorumResponse(int leaderId,
- int leaderEpoch,
- long highWatermark,
- List voterStates,
- List observerStates) {
- List sentMessages = channel.drainSendQueue();
+ int assertSentDescribeQuorumResponse(
+ int leaderId,
+ int leaderEpoch,
+ long highWatermark,
+ List voterStates,
+ List observerStates
+ ) {
+ List sentMessages = drainSentResponses(ApiKeys.DESCRIBE_QUORUM);
assertEquals(1, sentMessages.size());
- RaftMessage raftMessage = sentMessages.get(0);
+ RaftResponse.Outbound raftMessage = sentMessages.get(0);
assertTrue(
raftMessage.data() instanceof DescribeQuorumResponseData,
"Unexpected request type " + raftMessage.data());
@@ -412,7 +431,7 @@ void assertSentVoteResponse(
OptionalInt leaderId,
boolean voteGranted
) {
- List sentMessages = channel.drainSentResponses(ApiKeys.VOTE);
+ List sentMessages = drainSentResponses(ApiKeys.VOTE);
assertEquals(1, sentMessages.size());
RaftMessage raftMessage = sentMessages.get(0);
assertTrue(raftMessage.data() instanceof VoteResponseData);
@@ -449,8 +468,16 @@ List collectVoteRequests(
}
void deliverRequest(ApiMessage request) {
- RaftRequest.Inbound message = new RaftRequest.Inbound(channel.newCorrelationId(), request, time.milliseconds());
- channel.mockReceive(message);
+ RaftRequest.Inbound inboundRequest = new RaftRequest.Inbound(
+ channel.newCorrelationId(), request, time.milliseconds());
+ inboundRequest.completion.whenComplete((response, exception) -> {
+ if (exception != null) {
+ throw new RuntimeException(exception);
+ } else {
+ sentResponses.add(response);
+ }
+ });
+ client.handle(inboundRequest);
}
void deliverResponse(int correlationId, int sourceId, ApiMessage response) {
@@ -463,12 +490,27 @@ int assertSentBeginQuorumEpochRequest(int epoch, int numBeginEpochRequests) {
return requests.get(0).correlationId;
}
+ private List drainSentResponses(
+ ApiKeys apiKey
+ ) {
+ List res = new ArrayList<>();
+ Iterator iterator = sentResponses.iterator();
+ while (iterator.hasNext()) {
+ RaftResponse.Outbound response = iterator.next();
+ if (response.data.apiKey() == apiKey.id) {
+ res.add(response);
+ iterator.remove();
+ }
+ }
+ return res;
+ }
+
void assertSentBeginQuorumEpochResponse(
Errors partitionError,
int epoch,
OptionalInt leaderId
) {
- List sentMessages = channel.drainSentResponses(ApiKeys.BEGIN_QUORUM_EPOCH);
+ List sentMessages = drainSentResponses(ApiKeys.BEGIN_QUORUM_EPOCH);
assertEquals(1, sentMessages.size());
RaftMessage raftMessage = sentMessages.get(0);
assertTrue(raftMessage.data() instanceof BeginQuorumEpochResponseData);
@@ -495,7 +537,7 @@ void assertSentEndQuorumEpochResponse(
int epoch,
OptionalInt leaderId
) {
- List sentMessages = channel.drainSentResponses(ApiKeys.END_QUORUM_EPOCH);
+ List sentMessages = drainSentResponses(ApiKeys.END_QUORUM_EPOCH);
assertEquals(1, sentMessages.size());
RaftMessage raftMessage = sentMessages.get(0);
assertTrue(raftMessage.data() instanceof EndQuorumEpochResponseData);
@@ -511,7 +553,7 @@ void assertSentEndQuorumEpochResponse(
}
RaftRequest.Outbound assertSentFetchRequest() {
- List sentRequests = channel.drainSentRequests(ApiKeys.FETCH);
+ List sentRequests = channel.drainSentRequests(Optional.of(ApiKeys.FETCH));
assertEquals(1, sentRequests.size());
return sentRequests.get(0);
}
@@ -521,8 +563,10 @@ int assertSentFetchRequest(
long fetchOffset,
int lastFetchedEpoch
) {
- List sentMessages = channel.drainSendQueue();
+ List sentMessages = channel.drainSendQueue();
assertEquals(1, sentMessages.size());
+
+ // TODO: Use more specific type
RaftMessage raftMessage = sentMessages.get(0);
assertFetchRequestData(raftMessage, epoch, fetchOffset, lastFetchedEpoch);
return raftMessage.correlationId();
@@ -559,8 +603,7 @@ void buildFollowerSet(
// The lagging follower fetches first
deliverRequest(fetchRequest(1, laggingFollower, 0L, 0, 0));
- client.poll();
-
+ pollUntilResponse();
assertSentFetchResponse(0L, epoch);
// Append some records, so that the close follower will be able to advance further.
@@ -569,8 +612,7 @@ void buildFollowerSet(
deliverRequest(fetchRequest(epoch, closeFollower, 1L, epoch, 0));
- client.poll();
-
+ pollUntilResponse();
assertSentFetchResponse(1L, epoch);
}
@@ -600,7 +642,7 @@ void discoverLeaderAsObserver(
int leaderId,
int epoch
) throws Exception {
- pollUntilSend();
+ pollUntilRequest();
RaftRequest.Outbound fetchRequest = assertSentFetchRequest();
assertTrue(voters.contains(fetchRequest.destinationId()));
assertFetchRequestData(fetchRequest, 0, 0L, 0);
@@ -613,7 +655,7 @@ void discoverLeaderAsObserver(
private List collectBeginEpochRequests(int epoch) {
List requests = new ArrayList<>();
- for (RaftRequest.Outbound raftRequest : channel.drainSentRequests(ApiKeys.BEGIN_QUORUM_EPOCH)) {
+ for (RaftRequest.Outbound raftRequest : channel.drainSentRequests(Optional.of(ApiKeys.BEGIN_QUORUM_EPOCH))) {
assertTrue(raftRequest.data() instanceof BeginQuorumEpochRequestData);
BeginQuorumEpochRequestData request = (BeginQuorumEpochRequestData) raftRequest.data();
@@ -628,7 +670,7 @@ private List collectBeginEpochRequests(int epoch) {
}
private FetchResponseData.FetchablePartitionResponse assertSentPartitionResponse() {
- List sentMessages = channel.drainSentResponses(ApiKeys.FETCH);
+ List sentMessages = drainSentResponses(ApiKeys.FETCH);
assertEquals(
1, sentMessages.size(), "Found unexpected sent messages " + sentMessages);
RaftResponse.Outbound raftMessage = sentMessages.get(0);
diff --git a/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java b/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java
index cd4c6a0d7adf8..78f65dbfe0c0e 100644
--- a/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/RaftEventSimulationTest.java
@@ -19,8 +19,8 @@
import org.apache.kafka.common.TopicPartition;
import org.apache.kafka.common.memory.MemoryPool;
import org.apache.kafka.common.metrics.Metrics;
-import org.apache.kafka.common.protocol.Writable;
import org.apache.kafka.common.protocol.Readable;
+import org.apache.kafka.common.protocol.Writable;
import org.apache.kafka.common.protocol.types.Type;
import org.apache.kafka.common.utils.LogContext;
import org.apache.kafka.common.utils.MockTime;
@@ -62,9 +62,9 @@ public class RaftEventSimulationTest {
private static final TopicPartition METADATA_PARTITION = new TopicPartition("__cluster_metadata", 0);
private static final int ELECTION_TIMEOUT_MS = 1000;
private static final int ELECTION_JITTER_MS = 100;
- private static final int FETCH_TIMEOUT_MS = 5000;
+ private static final int FETCH_TIMEOUT_MS = 3000;
private static final int RETRY_BACKOFF_MS = 50;
- private static final int REQUEST_TIMEOUT_MS = 500;
+ private static final int REQUEST_TIMEOUT_MS = 3000;
private static final int FETCH_MAX_WAIT_MS = 100;
private static final int LINGER_MS = 0;
@@ -115,7 +115,7 @@ public void testElectionAfterLeaderFailureQuorumSizeThree() {
@Test
public void testElectionAfterLeaderFailureQuorumSizeThreeAndTwoObservers() {
- testElectionAfterLeaderFailure(new QuorumConfig(3, 2));
+ testElectionAfterLeaderFailure(new QuorumConfig(3, 1));
}
@Test
@@ -748,6 +748,7 @@ void start(int nodeId) {
LogContext logContext = new LogContext("[Node " + nodeId + "] ");
PersistentState persistentState = nodes.get(nodeId);
MockNetworkChannel channel = new MockNetworkChannel(correlationIdCounter);
+ MockMessageQueue messageQueue = new MockMessageQueue();
QuorumState quorum = new QuorumState(nodeId, voters(), ELECTION_TIMEOUT_MS,
FETCH_TIMEOUT_MS, persistentState.store, time, logContext, random);
Metrics metrics = new Metrics(time);
@@ -766,6 +767,7 @@ void start(int nodeId) {
KafkaRaftClient client = new KafkaRaftClient<>(
serde,
channel,
+ messageQueue,
persistentState.log,
quorum,
memoryPool,
@@ -781,8 +783,16 @@ void start(int nodeId) {
logContext,
random
);
- RaftNode node = new RaftNode(nodeId, client, persistentState.log, channel,
- persistentState.store, quorum, logContext);
+ RaftNode node = new RaftNode(
+ nodeId,
+ client,
+ persistentState.log,
+ channel,
+ messageQueue,
+ persistentState.store,
+ quorum,
+ logContext
+ );
node.initialize();
running.put(nodeId, node);
}
@@ -793,22 +803,27 @@ private static class RaftNode {
final KafkaRaftClient client;
final MockLog log;
final MockNetworkChannel channel;
+ final MockMessageQueue messageQueue;
final MockQuorumStateStore store;
final QuorumState quorum;
final LogContext logContext;
final ReplicatedCounter counter;
- private RaftNode(int nodeId,
- KafkaRaftClient client,
- MockLog log,
- MockNetworkChannel channel,
- MockQuorumStateStore store,
- QuorumState quorum,
- LogContext logContext) {
+ private RaftNode(
+ int nodeId,
+ KafkaRaftClient client,
+ MockLog log,
+ MockNetworkChannel channel,
+ MockMessageQueue messageQueue,
+ MockQuorumStateStore store,
+ QuorumState quorum,
+ LogContext logContext
+ ) {
this.nodeId = nodeId;
this.client = client;
this.log = log;
this.channel = channel;
+ this.messageQueue = messageQueue;
this.store = store;
this.quorum = quorum;
this.logContext = logContext;
@@ -826,9 +841,11 @@ void initialize() {
void poll() {
try {
- client.poll();
- } catch (IOException e) {
- throw new RuntimeException(e);
+ do {
+ client.poll();
+ } while (client.isRunning() && !messageQueue.isEmpty());
+ } catch (Exception e) {
+ throw new RuntimeException("Uncaught exception during poll of node " + nodeId, e);
}
}
}
@@ -1025,7 +1042,7 @@ private int parseSequenceNumber(ByteBuffer value) {
return (int) Type.INT32.read(value);
}
- private void assertCommittedData(int nodeId, KafkaRaftClient manager, MockLog log) {
+ private void assertCommittedData(int nodeId, KafkaRaftClient manager, MockLog log) {
OptionalLong highWatermark = manager.highWatermark();
if (!highWatermark.isPresent()) {
// We cannot do validation if the current high watermark is unknown
@@ -1070,6 +1087,9 @@ private MessageRouter(Cluster cluster) {
}
void deliver(int senderId, RaftRequest.Outbound outbound) {
+ if (!filters.get(senderId).acceptOutbound(outbound))
+ return;
+
int correlationId = outbound.correlationId();
int destinationId = outbound.destinationId();
RaftRequest.Inbound inbound = new RaftRequest.Inbound(correlationId, outbound.data(),
@@ -1079,9 +1099,15 @@ void deliver(int senderId, RaftRequest.Outbound outbound) {
return;
cluster.nodeIfRunning(destinationId).ifPresent(node -> {
- MockNetworkChannel destChannel = node.channel;
inflight.put(correlationId, new InflightRequest(correlationId, senderId, destinationId));
- destChannel.mockReceive(inbound);
+
+ inbound.completion.whenComplete((response, exception) -> {
+ if (response != null && filters.get(destinationId).acceptOutbound(response)) {
+ deliver(destinationId, response);
+ }
+ });
+
+ node.client.handle(inbound);
});
}
@@ -1089,6 +1115,7 @@ void deliver(int senderId, RaftResponse.Outbound outbound) {
int correlationId = outbound.correlationId();
RaftResponse.Inbound inbound = new RaftResponse.Inbound(correlationId, outbound.data(), senderId);
InflightRequest inflightRequest = inflight.remove(correlationId);
+
if (!filters.get(inflightRequest.sourceId).acceptInbound(inbound))
return;
@@ -1097,26 +1124,10 @@ void deliver(int senderId, RaftResponse.Outbound outbound) {
});
}
- void deliver(int senderId, RaftMessage message) {
- if (!filters.get(senderId).acceptOutbound(message)) {
- return;
- } else if (message instanceof RaftRequest.Outbound) {
- deliver(senderId, (RaftRequest.Outbound) message);
- } else if (message instanceof RaftResponse.Outbound) {
- deliver(senderId, (RaftResponse.Outbound) message);
- } else {
- throw new AssertionError("Illegal message type sent by node " + message);
- }
- }
-
void filter(int nodeId, NetworkFilter filter) {
filters.put(nodeId, filter);
}
- void deliverRandom() {
- cluster.forRandomRunning(this::deliverTo);
- }
-
void deliverTo(RaftNode node) {
node.channel.drainSendQueue().forEach(msg -> deliver(node.nodeId, msg));
}
diff --git a/raft/src/test/java/org/apache/kafka/raft/RequestManagerTest.java b/raft/src/test/java/org/apache/kafka/raft/RequestManagerTest.java
index b516ce57ecc2a..e6e2f7cf0a6a5 100644
--- a/raft/src/test/java/org/apache/kafka/raft/RequestManagerTest.java
+++ b/raft/src/test/java/org/apache/kafka/raft/RequestManagerTest.java
@@ -92,7 +92,7 @@ public void testSuccessfulResponse() {
long correlationId = 1;
connectionState.onRequestSent(correlationId, time.milliseconds());
assertFalse(connectionState.isReady(time.milliseconds()));
- connectionState.onResponseReceived(correlationId, time.milliseconds());
+ connectionState.onResponseReceived(correlationId);
assertTrue(connectionState.isReady(time.milliseconds()));
}
@@ -109,7 +109,7 @@ public void testIgnoreUnexpectedResponse() {
long correlationId = 1;
connectionState.onRequestSent(correlationId, time.milliseconds());
assertFalse(connectionState.isReady(time.milliseconds()));
- connectionState.onResponseReceived(correlationId + 1, time.milliseconds());
+ connectionState.onResponseReceived(correlationId + 1);
assertFalse(connectionState.isReady(time.milliseconds()));
}
@@ -121,7 +121,6 @@ public void testRequestTimeout() {
requestTimeoutMs,
random);
-
RequestManager.ConnectionState connectionState = cache.getOrCreate(1);
long correlationId = 1;
diff --git a/raft/src/test/java/org/apache/kafka/raft/internals/BlockingMessageQueueTest.java b/raft/src/test/java/org/apache/kafka/raft/internals/BlockingMessageQueueTest.java
new file mode 100644
index 0000000000000..e752fbd6dd37d
--- /dev/null
+++ b/raft/src/test/java/org/apache/kafka/raft/internals/BlockingMessageQueueTest.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.raft.internals;
+
+import org.apache.kafka.raft.RaftMessage;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+class BlockingMessageQueueTest {
+
+ @Test
+ public void testOfferAndPoll() {
+ BlockingMessageQueue queue = new BlockingMessageQueue();
+ assertTrue(queue.isEmpty());
+ assertNull(queue.poll(0));
+
+ RaftMessage message1 = Mockito.mock(RaftMessage.class);
+ queue.add(message1);
+ assertFalse(queue.isEmpty());
+ assertEquals(message1, queue.poll(0));
+ assertTrue(queue.isEmpty());
+
+ RaftMessage message2 = Mockito.mock(RaftMessage.class);
+ RaftMessage message3 = Mockito.mock(RaftMessage.class);
+ queue.add(message2);
+ queue.add(message3);
+ assertFalse(queue.isEmpty());
+ assertEquals(message2, queue.poll(0));
+ assertEquals(message3, queue.poll(0));
+
+ }
+
+ @Test
+ public void testWakeupFromPoll() {
+ BlockingMessageQueue queue = new BlockingMessageQueue();
+ queue.wakeup();
+ assertNull(queue.poll(Long.MAX_VALUE));
+ }
+
+}