diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java index 758631a1d87aa..24a84de438d54 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java @@ -230,7 +230,7 @@ public static void validateRecords(short version, BaseRecords baseRecords) { Iterator iterator = records.batches().iterator(); if (!iterator.hasNext()) throw new InvalidRecordException("Produce requests with version " + version + " must have at least " + - "one record batch"); + "one record batch per partition"); RecordBatch entry = iterator.next(); if (entry.magic() != RecordBatch.MAGIC_VALUE_V2) @@ -243,7 +243,7 @@ public static void validateRecords(short version, BaseRecords baseRecords) { if (iterator.hasNext()) throw new InvalidRecordException("Produce requests with version " + version + " are only allowed to " + - "contain exactly one record batch"); + "contain exactly one record batch per partition"); } } diff --git a/clients/src/test/java/org/apache/kafka/common/requests/ProduceRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/ProduceRequestTest.java index 9a01600012577..4ac61dbf5b3a3 100644 --- a/clients/src/test/java/org/apache/kafka/common/requests/ProduceRequestTest.java +++ b/clients/src/test/java/org/apache/kafka/common/requests/ProduceRequestTest.java @@ -73,7 +73,7 @@ public void shouldNotBeFlaggedAsTransactionalWhenNoRecords() { @Test public void shouldNotBeFlaggedAsIdempotentWhenRecordsNotIdempotent() { final ProduceRequest request = createNonIdempotentNonTransactionalRecords(); - assertFalse(RequestUtils.hasTransactionalRecords(request)); + assertFalse(RequestTestUtils.hasIdempotentRecords(request)); } @Test @@ -271,18 +271,18 @@ public void testMixedIdempotentData() { final short producerEpoch = 5; final int sequence = 10; - final MemoryRecords nonTxnRecords = MemoryRecords.withRecords(CompressionType.NONE, + final MemoryRecords nonIdempotentRecords = MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("foo".getBytes())); - final MemoryRecords txnRecords = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, + final MemoryRecords idempotentRecords = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence, new SimpleRecord("bar".getBytes())); ProduceRequest.Builder builder = ProduceRequest.forMagic(RecordVersion.current().value, new ProduceRequestData() .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Arrays.asList( new ProduceRequestData.TopicProduceData().setName("foo").setPartitionData(Collections.singletonList( - new ProduceRequestData.PartitionProduceData().setIndex(0).setRecords(txnRecords))), + new ProduceRequestData.PartitionProduceData().setIndex(0).setRecords(idempotentRecords))), new ProduceRequestData.TopicProduceData().setName("foo").setPartitionData(Collections.singletonList( - new ProduceRequestData.PartitionProduceData().setIndex(1).setRecords(nonTxnRecords)))) + new ProduceRequestData.PartitionProduceData().setIndex(1).setRecords(nonIdempotentRecords)))) .iterator())) .setAcks((short) -1) .setTimeoutMs(5000)); diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala index 5980ebef2219d..92ddb59ed5786 100644 --- a/core/src/main/scala/kafka/server/ReplicaManager.scala +++ b/core/src/main/scala/kafka/server/ReplicaManager.scala @@ -637,17 +637,31 @@ class ReplicaManager(val config: KafkaConfig, if (isValidRequiredAcks(requiredAcks)) { val sTime = time.milliseconds + val transactionalProducerIds = mutable.HashSet[Long]() val (verifiedEntriesPerPartition, notYetVerifiedEntriesPerPartition) = if (transactionStatePartition.isEmpty || !config.transactionPartitionVerificationEnable) (entriesPerPartition, Map.empty) - else + else { entriesPerPartition.partition { case (topicPartition, records) => - getPartitionOrException(topicPartition).hasOngoingTransaction(records.firstBatch().producerId()) + // Produce requests (only requests that require verification) should only have one batch per partition in "batches" but check all just to be safe. + val transactionalBatches = records.batches.asScala.filter(batch => batch.hasProducerId && batch.isTransactional) + transactionalBatches.foreach(batch => transactionalProducerIds.add(batch.producerId)) + if (transactionalBatches.nonEmpty) { + getPartitionOrException(topicPartition).hasOngoingTransaction(transactionalBatches.head.producerId) + } else { + // If there is no producer ID in the batches, no need to verify. + true + } } + } + // We should have exactly one producer ID for transactional records + if (transactionalProducerIds.size > 1) { + throw new InvalidPidMappingException("Transactional records contained more than one producer ID") + } def appendEntries(allEntries: Map[TopicPartition, MemoryRecords])(unverifiedEntries: Map[TopicPartition, Errors]): Unit = { val verifiedEntries = - if (unverifiedEntries.isEmpty) + if (unverifiedEntries.isEmpty) allEntries else allEntries.filter { case (tp, _) => @@ -743,7 +757,8 @@ class ReplicaManager(val config: KafkaConfig, .setPartitions(tps.map(tp => Integer.valueOf(tp.partition())).toList.asJava)) } - // map not yet verified partitions to a request object + // Map not yet verified partitions to a request object. + // We verify above that all partitions use the same producer ID. val batchInfo = notYetVerifiedEntriesPerPartition.head._2.firstBatch() val notYetVerifiedTransaction = new AddPartitionsToTxnTransaction() .setTransactionalId(transactionalId) diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala index 8319a4836efd9..36e6abfe588f4 100644 --- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala +++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala @@ -35,7 +35,7 @@ import kafka.server.epoch.util.MockBlockingSender import kafka.utils.timer.MockTimer import kafka.utils.{MockTime, Pool, TestUtils} import org.apache.kafka.clients.FetchSessionHandler -import org.apache.kafka.common.errors.KafkaStorageException +import org.apache.kafka.common.errors.{InvalidPidMappingException, KafkaStorageException} import org.apache.kafka.common.message.LeaderAndIsrRequestData import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset @@ -2162,11 +2162,65 @@ class ReplicaManagerTest { callback(Map(tp -> Errors.INVALID_RECORD).toMap) assertEquals(Errors.INVALID_RECORD, result.assertFired.error) - // If we don't supply a transaction coordinator partition, we do not verify, so counter stays the same. - val transactionalRecords2 = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence + 1, + // If we supply no transactional ID and idempotent records, we do not verify, so counter stays the same. + val idempotentRecords2 = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence + 1, new SimpleRecord(s"message $sequence".getBytes)) - appendRecords(replicaManager, tp, transactionalRecords2) + appendRecords(replicaManager, tp, idempotentRecords2) verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), any[AddPartitionsToTxnManager.AppendCallback]()) + + // If we supply a transactional ID and some transactional and some idempotent records, we should only verify the topic partition with transactional records. + appendRecordsToMultipleTopics(replicaManager, Map(tp -> transactionalRecords, new TopicPartition(topic, 1) -> idempotentRecords2), transactionalId, Some(0)) + verify(addPartitionsToTxnManager, times(2)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), any[AddPartitionsToTxnManager.AppendCallback]()) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } + + @Test + def testExceptionWhenUnverifiedTransactionHasMultipleProducerIds(): Unit = { + val tp0 = new TopicPartition(topic, 0) + val tp1 = new TopicPartition(topic, 1) + val transactionalId = "txn1" + val producerId = 24L + val producerEpoch = 0.toShort + val sequence = 0 + + val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_))) + val metadataCache = mock(classOf[MetadataCache]) + val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager]) + + val replicaManager = new ReplicaManager( + metrics = metrics, + config = config, + time = time, + scheduler = new MockScheduler(time), + logManager = mockLogMgr, + quotaManagers = quotaManager, + metadataCache = metadataCache, + logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size), + alterPartitionManager = alterPartitionManager, + addPartitionsToTxnManager = Some(addPartitionsToTxnManager)) + + try { + replicaManager.becomeLeaderOrFollower(1, + makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), LeaderAndIsr(1, List(0, 1))), + (_, _) => ()) + + replicaManager.becomeLeaderOrFollower(1, + makeLeaderAndIsrRequest(topicIds(tp1.topic), tp1, Seq(0, 1), LeaderAndIsr(1, List(0, 1))), + (_, _) => ()) + + // Append some transactional records with different producer IDs + val transactionalRecords = mutable.Map[TopicPartition, MemoryRecords]() + transactionalRecords.put(tp0, MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence, + new SimpleRecord(s"message $sequence".getBytes))) + transactionalRecords.put(tp1, MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId + 1, producerEpoch, sequence, + new SimpleRecord(s"message $sequence".getBytes))) + + assertThrows(classOf[InvalidPidMappingException], + () => appendRecordsToMultipleTopics(replicaManager, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))) } finally { replicaManager.shutdown() } @@ -2544,6 +2598,27 @@ class ReplicaManagerTest { result } + private def appendRecordsToMultipleTopics(replicaManager: ReplicaManager, + entriesToAppend: Map[TopicPartition, MemoryRecords], + transactionalId: String, + transactionStatePartition: Option[Int], + origin: AppendOrigin = AppendOrigin.CLIENT, + requiredAcks: Short = -1): Unit = { + def appendCallback(responses: Map[TopicPartition, PartitionResponse]): Unit = { + responses.foreach( response => responses.get(response._1).isDefined) + } + + replicaManager.appendRecords( + timeout = 1000, + requiredAcks = requiredAcks, + internalTopicsAllowed = false, + origin = origin, + entriesPerPartition = entriesToAppend, + responseCallback = appendCallback, + transactionalId = transactionalId, + transactionStatePartition = transactionStatePartition) + } + private def fetchPartitionAsConsumer( replicaManager: ReplicaManager, partition: TopicIdPartition,