diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala index 5a72554c74d60..068dff4cca609 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerChannelManager.scala @@ -109,23 +109,43 @@ object TransactionMarkerChannelManager { } -class TxnMarkerQueue(@volatile var destination: Node) { +class TxnMarkerQueue(@volatile var destination: Node) extends Logging { // keep track of the requests per txn topic partition so we can easily clear the queue // during partition emigration - private val markersPerTxnTopicPartition = new ConcurrentHashMap[Int, BlockingQueue[TxnIdAndMarkerEntry]]().asScala + private val markersPerTxnTopicPartition = new ConcurrentHashMap[Int, BlockingQueue[PendingCompleteTxnAndMarkerEntry]]().asScala - def removeMarkersForTxnTopicPartition(partition: Int): Option[BlockingQueue[TxnIdAndMarkerEntry]] = { + def removeMarkersForTxnTopicPartition(partition: Int): Option[BlockingQueue[PendingCompleteTxnAndMarkerEntry]] = { markersPerTxnTopicPartition.remove(partition) } - def addMarkers(txnTopicPartition: Int, txnIdAndMarker: TxnIdAndMarkerEntry): Unit = { - val queue = CoreUtils.atomicGetOrUpdate(markersPerTxnTopicPartition, txnTopicPartition, - new LinkedBlockingQueue[TxnIdAndMarkerEntry]()) - queue.add(txnIdAndMarker) + def addMarkers(txnTopicPartition: Int, pendingCompleteTxnAndMarker: PendingCompleteTxnAndMarkerEntry): Unit = { + val queue = CoreUtils.atomicGetOrUpdate(markersPerTxnTopicPartition, txnTopicPartition, { + // Note that this may get called more than once if threads have a close race while adding new queue. + info(s"Creating new marker queue for txn partition $txnTopicPartition to destination broker ${destination.id}") + new LinkedBlockingQueue[PendingCompleteTxnAndMarkerEntry]() + }) + queue.add(pendingCompleteTxnAndMarker) + + if (markersPerTxnTopicPartition.get(txnTopicPartition).orNull != queue) { + // This could happen if the queue got removed concurrently. + // Note that it could create an unexpected state when the queue is removed from + // removeMarkersForTxnTopicPartition, we could have: + // + // 1. [addMarkers] Retrieve queue. + // 2. [removeMarkersForTxnTopicPartition] Remove queue. + // 3. [removeMarkersForTxnTopicPartition] Iterate over queue, but not removeMarkersForTxn because queue is empty. + // 4. [addMarkers] Add markers to the queue. + // + // Now we've effectively removed the markers while transactionsWithPendingMarkers has an entry. + // + // While this could lead to an orphan entry in transactionsWithPendingMarkers, sending new markers + // will fix the state, so it shouldn't impact the state machine operation. + warn(s"Added $pendingCompleteTxnAndMarker to dead queue for txn partition $txnTopicPartition to destination broker ${destination.id}") + } } - def forEachTxnTopicPartition[B](f:(Int, BlockingQueue[TxnIdAndMarkerEntry]) => B): Unit = + def forEachTxnTopicPartition[B](f:(Int, BlockingQueue[PendingCompleteTxnAndMarkerEntry]) => B): Unit = markersPerTxnTopicPartition.forKeyValue { (partition, queue) => if (!queue.isEmpty) f(partition, queue) } @@ -187,17 +207,21 @@ class TransactionMarkerChannelManager( // visible for testing private[transaction] def queueForUnknownBroker = markersQueueForUnknownBroker - private[transaction] def addMarkersForBroker(broker: Node, txnTopicPartition: Int, txnIdAndMarker: TxnIdAndMarkerEntry): Unit = { + private[transaction] def addMarkersForBroker(broker: Node, txnTopicPartition: Int, pendingCompleteTxnAndMarker: PendingCompleteTxnAndMarkerEntry): Unit = { val brokerId = broker.id // we do not synchronize on the update of the broker node with the enqueuing, // since even if there is a race condition we will just retry - val brokerRequestQueue = CoreUtils.atomicGetOrUpdate(markersQueuePerBroker, brokerId, - new TxnMarkerQueue(broker)) + val brokerRequestQueue = CoreUtils.atomicGetOrUpdate(markersQueuePerBroker, brokerId, { + // Note that this may get called more than once if threads have a close race while adding new queue. + info(s"Creating new marker queue map to destination broker $brokerId") + new TxnMarkerQueue(broker) + }) brokerRequestQueue.destination = broker - brokerRequestQueue.addMarkers(txnTopicPartition, txnIdAndMarker) + brokerRequestQueue.addMarkers(txnTopicPartition, pendingCompleteTxnAndMarker) - trace(s"Added marker ${txnIdAndMarker.txnMarkerEntry} for transactional id ${txnIdAndMarker.txnId} to destination broker $brokerId") + trace(s"Added marker ${pendingCompleteTxnAndMarker.txnMarkerEntry} for transactional id" + + s" ${pendingCompleteTxnAndMarker.pendingCompleteTxn.transactionalId} to destination broker $brokerId") } private def retryLogAppends(): Unit = { @@ -211,29 +235,28 @@ class TransactionMarkerChannelManager( override def generateRequests(): util.Collection[RequestAndCompletionHandler] = { retryLogAppends() - val txnIdAndMarkerEntries: util.List[TxnIdAndMarkerEntry] = new util.ArrayList[TxnIdAndMarkerEntry]() + val pendingCompleteTxnAndMarkerEntries = new util.ArrayList[PendingCompleteTxnAndMarkerEntry]() markersQueueForUnknownBroker.forEachTxnTopicPartition { case (_, queue) => - queue.drainTo(txnIdAndMarkerEntries) + queue.drainTo(pendingCompleteTxnAndMarkerEntries) } - for (txnIdAndMarker: TxnIdAndMarkerEntry <- txnIdAndMarkerEntries.asScala) { - val transactionalId = txnIdAndMarker.txnId - val producerId = txnIdAndMarker.txnMarkerEntry.producerId - val producerEpoch = txnIdAndMarker.txnMarkerEntry.producerEpoch - val txnResult = txnIdAndMarker.txnMarkerEntry.transactionResult - val coordinatorEpoch = txnIdAndMarker.txnMarkerEntry.coordinatorEpoch - val topicPartitions = txnIdAndMarker.txnMarkerEntry.partitions.asScala.toSet + for (pendingCompleteTxnAndMarker: PendingCompleteTxnAndMarkerEntry <- pendingCompleteTxnAndMarkerEntries.asScala) { + val producerId = pendingCompleteTxnAndMarker.txnMarkerEntry.producerId + val producerEpoch = pendingCompleteTxnAndMarker.txnMarkerEntry.producerEpoch + val txnResult = pendingCompleteTxnAndMarker.txnMarkerEntry.transactionResult + val pendingCompleteTxn = pendingCompleteTxnAndMarker.pendingCompleteTxn + val topicPartitions = pendingCompleteTxnAndMarker.txnMarkerEntry.partitions.asScala.toSet - addTxnMarkersToBrokerQueue(transactionalId, producerId, producerEpoch, txnResult, coordinatorEpoch, topicPartitions) + addTxnMarkersToBrokerQueue(producerId, producerEpoch, txnResult, pendingCompleteTxn, topicPartitions) } val currentTimeMs = time.milliseconds() markersQueuePerBroker.values.map { brokerRequestQueue => - val txnIdAndMarkerEntries = new util.ArrayList[TxnIdAndMarkerEntry]() + val pendingCompleteTxnAndMarkerEntries = new util.ArrayList[PendingCompleteTxnAndMarkerEntry]() brokerRequestQueue.forEachTxnTopicPartition { case (_, queue) => - queue.drainTo(txnIdAndMarkerEntries) + queue.drainTo(pendingCompleteTxnAndMarkerEntries) } - (brokerRequestQueue.destination, txnIdAndMarkerEntries) + (brokerRequestQueue.destination, pendingCompleteTxnAndMarkerEntries) }.filter { case (_, entries) => !entries.isEmpty }.map { case (node, entries) => val markersToSend = entries.asScala.map(_.txnMarkerEntry).asJava val requestCompletionHandler = new TransactionMarkerRequestCompletionHandler(node.id, txnStateManager, this, entries) @@ -300,9 +323,12 @@ class TransactionMarkerChannelManager( txnMetadata, newMetadata) - transactionsWithPendingMarkers.put(transactionalId, pendingCompleteTxn) - addTxnMarkersToBrokerQueue(transactionalId, txnMetadata.producerId, - txnMetadata.producerEpoch, txnResult, coordinatorEpoch, txnMetadata.topicPartitions.toSet) + val prev = transactionsWithPendingMarkers.put(transactionalId, pendingCompleteTxn) + if (prev != null) { + info(s"Replaced an existing pending complete txn $prev with $pendingCompleteTxn while adding markers to send.") + } + addTxnMarkersToBrokerQueue(txnMetadata.producerId, + txnMetadata.producerEpoch, txnResult, pendingCompleteTxn, txnMetadata.topicPartitions.toSet) maybeWriteTxnCompletion(transactionalId) } @@ -354,41 +380,42 @@ class TransactionMarkerChannelManager( txnLogAppend.newMetadata, appendCallback, _ == Errors.COORDINATOR_NOT_AVAILABLE, RequestLocal.NoCaching) } - def addTxnMarkersToBrokerQueue(transactionalId: String, - producerId: Long, + def addTxnMarkersToBrokerQueue(producerId: Long, producerEpoch: Short, result: TransactionResult, - coordinatorEpoch: Int, + pendingCompleteTxn: PendingCompleteTxn, topicPartitions: immutable.Set[TopicPartition]): Unit = { - val txnTopicPartition = txnStateManager.partitionFor(transactionalId) + val txnTopicPartition = txnStateManager.partitionFor(pendingCompleteTxn.transactionalId) val partitionsByDestination: immutable.Map[Option[Node], immutable.Set[TopicPartition]] = topicPartitions.groupBy { topicPartition: TopicPartition => metadataCache.getPartitionLeaderEndpoint(topicPartition.topic, topicPartition.partition, interBrokerListenerName) } + val coordinatorEpoch = pendingCompleteTxn.coordinatorEpoch for ((broker: Option[Node], topicPartitions: immutable.Set[TopicPartition]) <- partitionsByDestination) { broker match { case Some(brokerNode) => val marker = new TxnMarkerEntry(producerId, producerEpoch, coordinatorEpoch, result, topicPartitions.toList.asJava) - val txnIdAndMarker = TxnIdAndMarkerEntry(transactionalId, marker) + val pendingCompleteTxnAndMarker = PendingCompleteTxnAndMarkerEntry(pendingCompleteTxn, marker) if (brokerNode == Node.noNode) { // if the leader of the partition is known but node not available, put it into an unknown broker queue // and let the sender thread to look for its broker and migrate them later - markersQueueForUnknownBroker.addMarkers(txnTopicPartition, txnIdAndMarker) + markersQueueForUnknownBroker.addMarkers(txnTopicPartition, pendingCompleteTxnAndMarker) } else { - addMarkersForBroker(brokerNode, txnTopicPartition, txnIdAndMarker) + addMarkersForBroker(brokerNode, txnTopicPartition, pendingCompleteTxnAndMarker) } case None => + val transactionalId = pendingCompleteTxn.transactionalId txnStateManager.getTransactionState(transactionalId) match { case Left(error) => info(s"Encountered $error trying to fetch transaction metadata for $transactionalId with coordinator epoch $coordinatorEpoch; cancel sending markers to its partition leaders") - transactionsWithPendingMarkers.remove(transactionalId) + transactionsWithPendingMarkers.remove(transactionalId, pendingCompleteTxn) case Right(Some(epochAndMetadata)) => if (epochAndMetadata.coordinatorEpoch != coordinatorEpoch) { info(s"The cached metadata has changed to $epochAndMetadata (old coordinator epoch is $coordinatorEpoch) since preparing to send markers; cancel sending markers to its partition leaders") - transactionsWithPendingMarkers.remove(transactionalId) + transactionsWithPendingMarkers.remove(transactionalId, pendingCompleteTxn) } else { // if the leader of the partition is unknown, skip sending the txn marker since // the partition is likely to be deleted already @@ -419,25 +446,34 @@ class TransactionMarkerChannelManager( def removeMarkersForTxnTopicPartition(txnTopicPartitionId: Int): Unit = { markersQueueForUnknownBroker.removeMarkersForTxnTopicPartition(txnTopicPartitionId).foreach { queue => - for (entry: TxnIdAndMarkerEntry <- queue.asScala) - removeMarkersForTxnId(entry.txnId) + for (entry <- queue.asScala) { + info(s"Removing $entry for txn partition $txnTopicPartitionId to destination broker -1") + removeMarkersForTxn(entry.pendingCompleteTxn) + } } - markersQueuePerBroker.foreach { case(_, brokerQueue) => + markersQueuePerBroker.foreach { case(brokerId, brokerQueue) => brokerQueue.removeMarkersForTxnTopicPartition(txnTopicPartitionId).foreach { queue => - for (entry: TxnIdAndMarkerEntry <- queue.asScala) - removeMarkersForTxnId(entry.txnId) + for (entry <- queue.asScala) { + info(s"Removing $entry for txn partition $txnTopicPartitionId to destination broker $brokerId") + removeMarkersForTxn(entry.pendingCompleteTxn) + } } } } - def removeMarkersForTxnId(transactionalId: String): Unit = { - transactionsWithPendingMarkers.remove(transactionalId) + def removeMarkersForTxn(pendingCompleteTxn: PendingCompleteTxn): Unit = { + val transactionalId = pendingCompleteTxn.transactionalId + val removed = transactionsWithPendingMarkers.remove(transactionalId, pendingCompleteTxn) + if (!removed) { + val current = transactionsWithPendingMarkers.get(transactionalId) + if (current != null) { + info(s"Failed to remove pending marker entry $current trying to remove $pendingCompleteTxn") + } + } } } -case class TxnIdAndMarkerEntry(txnId: String, txnMarkerEntry: TxnMarkerEntry) - case class PendingCompleteTxn(transactionalId: String, coordinatorEpoch: Int, txnMetadata: TransactionMetadata, @@ -451,3 +487,5 @@ case class PendingCompleteTxn(transactionalId: String, s"newMetadata=$newMetadata)" } } + +case class PendingCompleteTxnAndMarkerEntry(pendingCompleteTxn: PendingCompleteTxn, txnMarkerEntry: TxnMarkerEntry) diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala index 7a59139b17c76..d95dabab6c356 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandler.scala @@ -29,7 +29,7 @@ import scala.jdk.CollectionConverters._ class TransactionMarkerRequestCompletionHandler(brokerId: Int, txnStateManager: TransactionStateManager, txnMarkerChannelManager: TransactionMarkerChannelManager, - txnIdAndMarkerEntries: java.util.List[TxnIdAndMarkerEntry]) extends RequestCompletionHandler with Logging { + pendingCompleteTxnAndMarkerEntries: java.util.List[PendingCompleteTxnAndMarkerEntry]) extends RequestCompletionHandler with Logging { this.logIdent = "[Transaction Marker Request Completion Handler " + brokerId + "]: " @@ -39,22 +39,23 @@ class TransactionMarkerRequestCompletionHandler(brokerId: Int, if (response.wasDisconnected) { trace(s"Cancelled request with header $requestHeader due to node ${response.destination} being disconnected") - for (txnIdAndMarker <- txnIdAndMarkerEntries.asScala) { - val transactionalId = txnIdAndMarker.txnId - val txnMarker = txnIdAndMarker.txnMarkerEntry + for (pendingCompleteTxnAndMarker <- pendingCompleteTxnAndMarkerEntries.asScala) { + val pendingCompleteTxn = pendingCompleteTxnAndMarker.pendingCompleteTxn + val transactionalId = pendingCompleteTxn.transactionalId + val txnMarker = pendingCompleteTxnAndMarker.txnMarkerEntry txnStateManager.getTransactionState(transactionalId) match { case Left(Errors.NOT_COORDINATOR) => info(s"I am no longer the coordinator for $transactionalId; cancel sending transaction markers $txnMarker to the brokers") - txnMarkerChannelManager.removeMarkersForTxnId(transactionalId) + txnMarkerChannelManager.removeMarkersForTxn(pendingCompleteTxn) case Left(Errors.COORDINATOR_LOAD_IN_PROGRESS) => info(s"I am loading the transaction partition that contains $transactionalId which means the current markers have to be obsoleted; " + s"cancel sending transaction markers $txnMarker to the brokers") - txnMarkerChannelManager.removeMarkersForTxnId(transactionalId) + txnMarkerChannelManager.removeMarkersForTxn(pendingCompleteTxn) case Left(unexpectedError) => throw new IllegalStateException(s"Unhandled error $unexpectedError when fetching current transaction state") @@ -69,17 +70,16 @@ class TransactionMarkerRequestCompletionHandler(brokerId: Int, info(s"Transaction coordinator epoch for $transactionalId has changed from ${txnMarker.coordinatorEpoch} to " + s"${epochAndMetadata.coordinatorEpoch}; cancel sending transaction markers $txnMarker to the brokers") - txnMarkerChannelManager.removeMarkersForTxnId(transactionalId) + txnMarkerChannelManager.removeMarkersForTxn(pendingCompleteTxn) } else { // re-enqueue the markers with possibly new destination brokers trace(s"Re-enqueuing ${txnMarker.transactionResult} transaction markers for transactional id $transactionalId " + s"under coordinator epoch ${txnMarker.coordinatorEpoch}") - txnMarkerChannelManager.addTxnMarkersToBrokerQueue(transactionalId, - txnMarker.producerId, + txnMarkerChannelManager.addTxnMarkersToBrokerQueue(txnMarker.producerId, txnMarker.producerEpoch, txnMarker.transactionResult, - txnMarker.coordinatorEpoch, + pendingCompleteTxn, txnMarker.partitions.asScala.toSet) } } @@ -90,9 +90,10 @@ class TransactionMarkerRequestCompletionHandler(brokerId: Int, val writeTxnMarkerResponse = response.responseBody.asInstanceOf[WriteTxnMarkersResponse] val responseErrors = writeTxnMarkerResponse.errorsByProducerId - for (txnIdAndMarker <- txnIdAndMarkerEntries.asScala) { - val transactionalId = txnIdAndMarker.txnId - val txnMarker = txnIdAndMarker.txnMarkerEntry + for (pendingCompleteTxnAndMarker <- pendingCompleteTxnAndMarkerEntries.asScala) { + val pendingCompleteTxn = pendingCompleteTxnAndMarker.pendingCompleteTxn + val transactionalId = pendingCompleteTxn.transactionalId + val txnMarker = pendingCompleteTxnAndMarker.txnMarkerEntry val errors = responseErrors.get(txnMarker.producerId) if (errors == null) @@ -102,13 +103,13 @@ class TransactionMarkerRequestCompletionHandler(brokerId: Int, case Left(Errors.NOT_COORDINATOR) => info(s"I am no longer the coordinator for $transactionalId; cancel sending transaction markers $txnMarker to the brokers") - txnMarkerChannelManager.removeMarkersForTxnId(transactionalId) + txnMarkerChannelManager.removeMarkersForTxn(pendingCompleteTxn) case Left(Errors.COORDINATOR_LOAD_IN_PROGRESS) => info(s"I am loading the transaction partition that contains $transactionalId which means the current markers have to be obsoleted; " + s"cancel sending transaction markers $txnMarker to the brokers") - txnMarkerChannelManager.removeMarkersForTxnId(transactionalId) + txnMarkerChannelManager.removeMarkersForTxn(pendingCompleteTxn) case Left(unexpectedError) => throw new IllegalStateException(s"Unhandled error $unexpectedError when fetching current transaction state") @@ -127,7 +128,7 @@ class TransactionMarkerRequestCompletionHandler(brokerId: Int, info(s"Transaction coordinator epoch for $transactionalId has changed from ${txnMarker.coordinatorEpoch} to " + s"${epochAndMetadata.coordinatorEpoch}; cancel sending transaction markers $txnMarker to the brokers") - txnMarkerChannelManager.removeMarkersForTxnId(transactionalId) + txnMarkerChannelManager.removeMarkersForTxn(pendingCompleteTxn) abortSending = true } else { txnMetadata.inLock { @@ -161,7 +162,7 @@ class TransactionMarkerRequestCompletionHandler(brokerId: Int, info(s"Sending $transactionalId's transaction marker for partition $topicPartition has permanently failed with error ${error.exceptionName} " + s"with the current coordinator epoch ${epochAndMetadata.coordinatorEpoch}; cancel sending any more transaction markers $txnMarker to the brokers") - txnMarkerChannelManager.removeMarkersForTxnId(transactionalId) + txnMarkerChannelManager.removeMarkersForTxn(pendingCompleteTxn) abortSending = true case Errors.UNSUPPORTED_FOR_MESSAGE_FORMAT | @@ -187,11 +188,10 @@ class TransactionMarkerRequestCompletionHandler(brokerId: Int, // re-enqueue with possible new leaders of the partitions txnMarkerChannelManager.addTxnMarkersToBrokerQueue( - transactionalId, txnMarker.producerId, txnMarker.producerEpoch, txnMarker.transactionResult, - txnMarker.coordinatorEpoch, + pendingCompleteTxn, retryPartitions.toSet) } else { txnMarkerChannelManager.maybeWriteTxnCompletion(transactionalId) diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala index 355e848980746..813fd75be36b6 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala @@ -557,6 +557,7 @@ class TransactionStateManager(brokerId: Int, loadingPartitions.remove(partitionAndLeaderEpoch) transactionsPendingForCompletion.foreach { txnTransitMetadata => + info(s"Sending txn markers for $txnTransitMetadata after loading partition $partitionId") sendTxnMarkers(txnTransitMetadata.coordinatorEpoch, txnTransitMetadata.result, txnTransitMetadata.txnMetadata, txnTransitMetadata.transitMetadata) } diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala index de58f8ed7fa86..3356c4f9e372c 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerChannelManagerTest.scala @@ -34,7 +34,7 @@ import org.junit.jupiter.api.Assertions._ import org.junit.jupiter.api.Test import org.mockito.ArgumentMatchers.any import org.mockito.{ArgumentCaptor, ArgumentMatchers} -import org.mockito.Mockito.{mock, mockConstruction, times, verify, verifyNoMoreInteractions, when} +import org.mockito.Mockito.{clearInvocations, mock, mockConstruction, times, verify, verifyNoMoreInteractions, when} import scala.jdk.CollectionConverters._ import scala.collection.mutable @@ -59,6 +59,7 @@ class TransactionMarkerChannelManagerTest { private val txnTopicPartition1 = 0 private val txnTopicPartition2 = 1 private val coordinatorEpoch = 0 + private val coordinatorEpoch2 = 1 private val txnTimeoutMs = 0 private val txnResult = TransactionResult.COMMIT private val txnMetadata1 = new TransactionMetadata(transactionalId1, producerId1, producerId1, producerEpoch, lastProducerEpoch, @@ -177,6 +178,86 @@ class TransactionMarkerChannelManagerTest { any()) } + @Test + def shouldNotLoseTxnCompletionAfterLoad(): Unit = { + mockCache() + + val expectedTransition = txnMetadata2.prepareComplete(time.milliseconds()) + + when(metadataCache.getPartitionLeaderEndpoint( + ArgumentMatchers.eq(partition1.topic), + ArgumentMatchers.eq(partition1.partition), + any()) + ).thenReturn(Some(broker1)) + + // Build a successful client response. + val header = new RequestHeader(ApiKeys.WRITE_TXN_MARKERS, 0, "client", 1) + val successfulResponse = new WriteTxnMarkersResponse( + Collections.singletonMap(producerId2: java.lang.Long, Collections.singletonMap(partition1, Errors.NONE))) + val successfulClientResponse = new ClientResponse(header, null, null, + time.milliseconds(), time.milliseconds(), false, null, null, + successfulResponse) + + // Build a disconnected client response. + val disconnectedClientResponse = new ClientResponse(header, null, null, + time.milliseconds(), time.milliseconds(), true, null, null, + null) + + // Test matrix to cover various scenarios: + val clientResponses = Seq(successfulClientResponse, disconnectedClientResponse) + val getTransactionStateResponses = Seq( + // NOT_COORDINATOR error case + Left(Errors.NOT_COORDINATOR), + // COORDINATOR_LOAD_IN_PROGRESS + Left(Errors.COORDINATOR_LOAD_IN_PROGRESS), + // "Newly loaded" transaction state with the new epoch. + Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch2, txnMetadata2))) + ) + + clientResponses.foreach { clientResponse => + getTransactionStateResponses.foreach { getTransactionStateResponse => + // Reset data from previous iteration. + txnMetadata2.topicPartitions.add(partition1) + clearInvocations(txnStateManager) + // Send out markers for a transaction before load. + channelManager.addTxnMarkersToSend(coordinatorEpoch, txnResult, + txnMetadata2, expectedTransition) + + // Drain the marker to make it "in-flight". + val requests1 = channelManager.generateRequests().asScala + assertEquals(1, requests1.size) + + // Simulate a partition load: + // 1. Remove the markers from the channel manager. + // 2. Simulate the corresponding test case scenario. + // 3. Add the markers back to the channel manager. + channelManager.removeMarkersForTxnTopicPartition(txnTopicPartition2) + when(txnStateManager.getTransactionState(ArgumentMatchers.eq(transactionalId2))) + .thenReturn(getTransactionStateResponse) + channelManager.addTxnMarkersToSend(coordinatorEpoch2, txnResult, + txnMetadata2, expectedTransition) + + // Complete the marker from the previous epoch. + requests1.head.handler.onComplete(clientResponse) + + // Now drain and complete the marker from the new epoch. + when(txnStateManager.getTransactionState(ArgumentMatchers.eq(transactionalId2))) + .thenReturn(Right(Some(CoordinatorEpochAndTxnMetadata(coordinatorEpoch2, txnMetadata2)))) + val requests2 = channelManager.generateRequests().asScala + assertEquals(1, requests2.size) + requests2.head.handler.onComplete(successfulClientResponse) + + verify(txnStateManager).appendTransactionToLog( + ArgumentMatchers.eq(transactionalId2), + ArgumentMatchers.eq(coordinatorEpoch2), + ArgumentMatchers.eq(expectedTransition), + capturedErrorsCallback.capture(), + any(), + any()) + } + } + } + @Test def shouldGenerateEmptyMapWhenNoRequestsOutstanding(): Unit = { assertTrue(channelManager.generateRequests().isEmpty) diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala index aecf6542f7d7f..1004915f46cb9 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionMarkerRequestCompletionHandlerTest.scala @@ -18,7 +18,6 @@ package kafka.coordinator.transaction import java.{lang, util} import java.util.Arrays.asList - import org.apache.kafka.clients.ClientResponse import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.protocol.{ApiKeys, Errors} @@ -43,18 +42,19 @@ class TransactionMarkerRequestCompletionHandlerTest { private val coordinatorEpoch = 0 private val txnResult = TransactionResult.COMMIT private val topicPartition = new TopicPartition("topic1", 0) - private val txnIdAndMarkers = asList( - TxnIdAndMarkerEntry(transactionalId, new WriteTxnMarkersRequest.TxnMarkerEntry(producerId, producerEpoch, coordinatorEpoch, txnResult, asList(topicPartition)))) - private val txnMetadata = new TransactionMetadata(transactionalId, producerId, producerId, producerEpoch, lastProducerEpoch, txnTimeoutMs, PrepareCommit, mutable.Set[TopicPartition](topicPartition), 0L, 0L) + private val pendingCompleteTxnAndMarkers = asList( + PendingCompleteTxnAndMarkerEntry( + PendingCompleteTxn(transactionalId, coordinatorEpoch, txnMetadata, txnMetadata.prepareComplete(42)), + new WriteTxnMarkersRequest.TxnMarkerEntry(producerId, producerEpoch, coordinatorEpoch, txnResult, asList(topicPartition)))) private val markerChannelManager: TransactionMarkerChannelManager = mock(classOf[TransactionMarkerChannelManager]) private val txnStateManager: TransactionStateManager = mock(classOf[TransactionStateManager]) - private val handler = new TransactionMarkerRequestCompletionHandler(brokerId, txnStateManager, markerChannelManager, txnIdAndMarkers) + private val handler = new TransactionMarkerRequestCompletionHandler(brokerId, txnStateManager, markerChannelManager, pendingCompleteTxnAndMarkers) private def mockCache(): Unit = { when(txnStateManager.partitionFor(transactionalId)) @@ -70,8 +70,9 @@ class TransactionMarkerRequestCompletionHandlerTest { handler.onComplete(new ClientResponse(new RequestHeader(ApiKeys.PRODUCE, 0, "client", 1), null, null, 0, 0, true, null, null, null)) - verify(markerChannelManager).addTxnMarkersToBrokerQueue(transactionalId, - producerId, producerEpoch, txnResult, coordinatorEpoch, Set[TopicPartition](topicPartition)) + verify(markerChannelManager).addTxnMarkersToBrokerQueue(producerId, + producerEpoch, txnResult, pendingCompleteTxnAndMarkers.get(0).pendingCompleteTxn, + Set[TopicPartition](topicPartition)) } @Test @@ -193,8 +194,9 @@ class TransactionMarkerRequestCompletionHandlerTest { null, null, 0, 0, false, null, null, response)) assertEquals(txnMetadata.topicPartitions, mutable.Set[TopicPartition](topicPartition)) - verify(markerChannelManager).addTxnMarkersToBrokerQueue(transactionalId, - producerId, producerEpoch, txnResult, coordinatorEpoch, Set[TopicPartition](topicPartition)) + verify(markerChannelManager).addTxnMarkersToBrokerQueue(producerId, + producerEpoch, txnResult, pendingCompleteTxnAndMarkers.get(0).pendingCompleteTxn, + Set[TopicPartition](topicPartition)) } private def verifyThrowIllegalStateExceptionOnError(error: Errors) = { @@ -222,7 +224,8 @@ class TransactionMarkerRequestCompletionHandlerTest { private def verifyRemoveDelayedOperationOnError(error: Errors): Unit = { var removed = false - when(markerChannelManager.removeMarkersForTxnId(transactionalId)) + val pendingCompleteTxn = pendingCompleteTxnAndMarkers.get(0).pendingCompleteTxn + when(markerChannelManager.removeMarkersForTxn(pendingCompleteTxn)) .thenAnswer(_ => removed = true) val response = new WriteTxnMarkersResponse(createProducerIdErrorMap(error))