From 6ea5bd4b642fd666ce5ed1fa1d723cdd6387fdf8 Mon Sep 17 00:00:00 2001 From: Justine Date: Tue, 18 Apr 2023 12:14:06 -0700 Subject: [PATCH 1/8] to validate records and only verify the necessary partitions --- .../kafka/common/requests/ProduceRequest.java | 28 +++++++++++-- .../common/requests/ProduceRequestTest.java | 41 ++++++++++++++++--- .../main/scala/kafka/server/KafkaApis.scala | 8 +++- .../scala/kafka/server/ReplicaManager.scala | 10 ++++- .../unit/kafka/server/KafkaApisTest.scala | 36 ++++++++++++++++ .../kafka/server/ReplicaManagerTest.scala | 31 ++++++++++++-- 6 files changed, 140 insertions(+), 14 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java index 758631a1d87aa..6269091da0502 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java @@ -24,10 +24,7 @@ import org.apache.kafka.common.protocol.ApiKeys; import org.apache.kafka.common.protocol.ByteBufferAccessor; import org.apache.kafka.common.protocol.Errors; -import org.apache.kafka.common.record.BaseRecords; -import org.apache.kafka.common.record.CompressionType; -import org.apache.kafka.common.record.RecordBatch; -import org.apache.kafka.common.record.Records; +import org.apache.kafka.common.record.*; import org.apache.kafka.common.utils.Utils; import java.nio.ByteBuffer; @@ -88,6 +85,7 @@ private ProduceRequest build(short version, boolean validate) { data.topicData().forEach(tpd -> tpd.partitionData().forEach(partitionProduceData -> ProduceRequest.validateRecords(version, partitionProduceData.records()))); + validateProducerIds(version, data); } return new ProduceRequest(data, version); } @@ -222,9 +220,31 @@ public void clearPartitionRecords() { partitionSizes(); data = null; } + + public static void validateProducerIds(short version, ProduceRequestData data) { + if (version >= 3) { + long producerId = -1; + for (ProduceRequestData.TopicProduceData topicData : data.topicData()) { + for (ProduceRequestData.PartitionProduceData partitionData : topicData.partitionData()) { + BaseRecords baseRecords = partitionData.records(); + if (baseRecords instanceof Records) { + Records records = (Records) baseRecords; + for (RecordBatch batch : records.batches()) { + if (producerId == -1 && batch.hasProducerId()) + producerId = batch.producerId(); + else if (batch.hasProducerId()) + if (batch.producerId() != producerId) + throw new InvalidRecordException("Produce requests with producer IDs can not have differing producer IDs"); + } + } + } + } + } + } public static void validateRecords(short version, BaseRecords baseRecords) { if (version >= 3) { + long producerId = -1L; if (baseRecords instanceof Records) { Records records = (Records) baseRecords; Iterator iterator = records.batches().iterator(); diff --git a/clients/src/test/java/org/apache/kafka/common/requests/ProduceRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/ProduceRequestTest.java index 9a01600012577..d784d5e83ab9c 100644 --- a/clients/src/test/java/org/apache/kafka/common/requests/ProduceRequestTest.java +++ b/clients/src/test/java/org/apache/kafka/common/requests/ProduceRequestTest.java @@ -73,7 +73,7 @@ public void shouldNotBeFlaggedAsTransactionalWhenNoRecords() { @Test public void shouldNotBeFlaggedAsIdempotentWhenRecordsNotIdempotent() { final ProduceRequest request = createNonIdempotentNonTransactionalRecords(); - assertFalse(RequestUtils.hasTransactionalRecords(request)); + assertFalse(RequestTestUtils.hasIdempotentRecords(request)); } @Test @@ -238,6 +238,37 @@ public void testV6AndBelowCannotUseZStdCompression() { // Works fine with current version (>= 7) ProduceRequest.forCurrentMagic(produceData); } + + @Test + public void testNoMixedProducerIds() { + final long producerId1 = 15L; + final long producerId2 = 16L; + final short producerEpoch = 5; + final int sequence = 10; + + final MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE, + new SimpleRecord("foo".getBytes())); + final MemoryRecords txnRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId1, + producerEpoch, sequence, new SimpleRecord("bar".getBytes())); + final MemoryRecords idempotentRecords = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId2, + producerEpoch, sequence, new SimpleRecord("bee".getBytes())); + + + ProduceRequest.Builder requestBuilder = ProduceRequest.forCurrentMagic( + new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Arrays.asList( + new ProduceRequestData.TopicProduceData().setName("foo").setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData().setIndex(0).setRecords(records))), + new ProduceRequestData.TopicProduceData().setName("bar").setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData().setIndex(1).setRecords(txnRecords))), + new ProduceRequestData.TopicProduceData().setName("bee").setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData().setIndex(0).setRecords(idempotentRecords)))) + .iterator())) + .setAcks((short) 1) + .setTimeoutMs(5000)); + IntStream.range(3, ApiKeys.PRODUCE.latestVersion()) + .forEach(version -> assertThrows(InvalidRecordException.class, () -> requestBuilder.build((short) version).serialize())); + } @Test public void testMixedTransactionalData() { @@ -271,18 +302,18 @@ public void testMixedIdempotentData() { final short producerEpoch = 5; final int sequence = 10; - final MemoryRecords nonTxnRecords = MemoryRecords.withRecords(CompressionType.NONE, + final MemoryRecords nonIdempotentRecords = MemoryRecords.withRecords(CompressionType.NONE, new SimpleRecord("foo".getBytes())); - final MemoryRecords txnRecords = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, + final MemoryRecords idempotentRecords = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence, new SimpleRecord("bar".getBytes())); ProduceRequest.Builder builder = ProduceRequest.forMagic(RecordVersion.current().value, new ProduceRequestData() .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Arrays.asList( new ProduceRequestData.TopicProduceData().setName("foo").setPartitionData(Collections.singletonList( - new ProduceRequestData.PartitionProduceData().setIndex(0).setRecords(txnRecords))), + new ProduceRequestData.PartitionProduceData().setIndex(0).setRecords(idempotentRecords))), new ProduceRequestData.TopicProduceData().setName("foo").setPartitionData(Collections.singletonList( - new ProduceRequestData.PartitionProduceData().setIndex(1).setRecords(nonTxnRecords)))) + new ProduceRequestData.PartitionProduceData().setIndex(1).setRecords(nonIdempotentRecords)))) .iterator())) .setAcks((short) -1) .setTimeoutMs(5000)); diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala index 6a4971c44d353..003cdb85dcda3 100644 --- a/core/src/main/scala/kafka/server/KafkaApis.scala +++ b/core/src/main/scala/kafka/server/KafkaApis.scala @@ -65,7 +65,7 @@ import org.apache.kafka.common.resource.{Resource, ResourceType} import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} import org.apache.kafka.common.security.token.delegation.{DelegationToken, TokenInformation} import org.apache.kafka.common.utils.{ProducerIdAndEpoch, Time} -import org.apache.kafka.common.{Node, TopicIdPartition, TopicPartition, Uuid} +import org.apache.kafka.common.{InvalidRecordException, Node, TopicIdPartition, TopicPartition, Uuid} import org.apache.kafka.coordinator.group.GroupCoordinator import org.apache.kafka.server.authorizer._ import org.apache.kafka.server.common.MetadataVersion @@ -565,6 +565,12 @@ class KafkaApis(val requestChannel: RequestChannel, requestHelper.sendErrorResponseMaybeThrottle(request, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.exception) return } + try { + ProduceRequest.validateProducerIds(request.header.apiVersion, produceRequest.data) + } catch { + case e: InvalidRecordException => + requestHelper.sendErrorResponseMaybeThrottle(request, Errors.INVALID_RECORD.exception) + } } val unauthorizedTopicResponses = mutable.Map[TopicPartition, PartitionResponse]() diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala index e8bb496436f06..56a9d5157b76f 100644 --- a/core/src/main/scala/kafka/server/ReplicaManager.scala +++ b/core/src/main/scala/kafka/server/ReplicaManager.scala @@ -642,7 +642,14 @@ class ReplicaManager(val config: KafkaConfig, (entriesPerPartition, Map.empty) else entriesPerPartition.partition { case (topicPartition, records) => - getPartitionOrException(topicPartition).hasOngoingTransaction(records.firstBatch().producerId()) + // Produce requests (only requests that require verification) should only have one batch in "batches" but check all just to be safe. + val transactionalBatches = records.batches.asScala.filter(batch => batch.hasProducerId && batch.isTransactional) + if (!transactionalBatches.isEmpty) { + getPartitionOrException(topicPartition).hasOngoingTransaction(transactionalBatches.head.producerId) + } else { + // If there is no producer ID in the batches, no need to verify. + true + } } def appendEntries(allEntries: Map[TopicPartition, MemoryRecords])(unverifiedEntries: Map[TopicPartition, Errors]): Unit = { @@ -744,6 +751,7 @@ class ReplicaManager(val config: KafkaConfig, } // map not yet verified partitions to a request object + // verification only occurs on produce requests and those will always have one batch for versions that support transactions. val batchInfo = notYetVerifiedEntriesPerPartition.head._2.firstBatch() val notYetVerifiedTransaction = new AddPartitionsToTxnTransaction() .setTransactionalId(transactionalId) diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala index 69052491acc1c..d72b972d573be 100644 --- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala +++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala @@ -2381,6 +2381,42 @@ class KafkaApisTest { ArgumentMatchers.eq(Some(transactionCoordinatorPartition))) } } + + @Test + def testDifferingProducerIdsThrowError(): Unit = { + val topic = "topic" + val transactionalId = "txn1" + + addTopicToMetadataCache(topic, numPartitions = 2) + + val tp = new TopicPartition("topic", 0) + val tp1 = new TopicPartition("topic", 1) + + val produceRequest = new ProduceRequest(new ProduceRequestData() + .setTopicData(new ProduceRequestData.TopicProduceDataCollection( + util.Arrays.asList(new ProduceRequestData.TopicProduceData() + .setName(tp.topic).setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData() + .setIndex(tp.partition) + .setRecords(MemoryRecords.withTransactionalRecords(CompressionType.NONE, 0, 0, 0, new SimpleRecord("test".getBytes))))), + new ProduceRequestData.TopicProduceData() + .setName(tp1.topic).setPartitionData(Collections.singletonList( + new ProduceRequestData.PartitionProduceData() + .setIndex(tp1.partition) + .setRecords(MemoryRecords.withIdempotentRecords(CompressionType.NONE, 2, 0, 0, new SimpleRecord("test".getBytes)))))) + .iterator)) + .setAcks(1.toShort) + .setTransactionalId(transactionalId) + .setTimeoutMs(5000), ApiKeys.PRODUCE.latestVersion) + val request = buildRequest(produceRequest) + + val kafkaApis = createKafkaApis() + + kafkaApis.handleProduceRequest(request, RequestLocal.withThreadConfinedCaching) + + val response = verifyNoThrottling[ProduceResponse](request) + assertEquals(2, response.errorCounts().get(Errors.INVALID_RECORD)) + } @Test def testAddPartitionsToTxnWithInvalidPartition(): Unit = { diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala index 1b7e99356779b..24936415e14a9 100644 --- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala +++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala @@ -2124,11 +2124,15 @@ class ReplicaManagerTest { callback(Map(tp -> Errors.INVALID_RECORD).toMap) assertEquals(Errors.INVALID_RECORD, result.assertFired.error) - // If we don't supply a transaction coordinator partition, we do not verify, so counter stays the same. - val transactionalRecords2 = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence + 1, + // If we supply no transactional ID and idempotent records, we do not verify, so counter stays the same. + val idempotentRecords2 = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId, producerEpoch, sequence + 1, new SimpleRecord(s"message $sequence".getBytes)) - appendRecords(replicaManager, tp, transactionalRecords2) + appendRecords(replicaManager, tp, idempotentRecords2) verify(addPartitionsToTxnManager, times(1)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), any[AddPartitionsToTxnManager.AppendCallback]()) + + // If we supply a transactional ID and some transactional and some idempotent records, we should only verify the topic partition with transactional records. + appendRecordsToMultipleTopics(replicaManager, Map(tp -> transactionalRecords, new TopicPartition(topic, 1) -> idempotentRecords2), transactionalId, Some(0)) + verify(addPartitionsToTxnManager, times(2)).addTxnData(ArgumentMatchers.eq(node), ArgumentMatchers.eq(transactionToAdd), any[AddPartitionsToTxnManager.AppendCallback]()) } finally { replicaManager.shutdown() } @@ -2506,6 +2510,27 @@ class ReplicaManagerTest { result } + private def appendRecordsToMultipleTopics(replicaManager: ReplicaManager, + entriesToAppend: Map[TopicPartition, MemoryRecords], + transactionalId: String, + transactionStatePartition: Option[Int], + origin: AppendOrigin = AppendOrigin.CLIENT, + requiredAcks: Short = -1): Unit = { + def appendCallback(responses: Map[TopicPartition, PartitionResponse]): Unit = { + responses.foreach( response => responses.get(response._1).isDefined) + } + + replicaManager.appendRecords( + timeout = 1000, + requiredAcks = requiredAcks, + internalTopicsAllowed = false, + origin = origin, + entriesPerPartition = entriesToAppend, + responseCallback = appendCallback, + transactionalId = transactionalId, + transactionStatePartition = transactionStatePartition) + } + private def fetchPartitionAsConsumer( replicaManager: ReplicaManager, partition: TopicIdPartition, From 72ce37854e33d3c75ba1b6e3139951b559a0debd Mon Sep 17 00:00:00 2001 From: Justine Date: Tue, 18 Apr 2023 12:23:46 -0700 Subject: [PATCH 2/8] checkstyle --- .../org/apache/kafka/common/requests/ProduceRequest.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java index 6269091da0502..35a53a7b7daf3 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java @@ -24,7 +24,10 @@ import org.apache.kafka.common.protocol.ApiKeys; import org.apache.kafka.common.protocol.ByteBufferAccessor; import org.apache.kafka.common.protocol.Errors; -import org.apache.kafka.common.record.*; +import org.apache.kafka.common.record.BaseRecords; +import org.apache.kafka.common.record.CompressionType; +import org.apache.kafka.common.record.RecordBatch; +import org.apache.kafka.common.record.Records; import org.apache.kafka.common.utils.Utils; import java.nio.ByteBuffer; @@ -244,7 +247,6 @@ else if (batch.hasProducerId()) public static void validateRecords(short version, BaseRecords baseRecords) { if (version >= 3) { - long producerId = -1L; if (baseRecords instanceof Records) { Records records = (Records) baseRecords; Iterator iterator = records.batches().iterator(); From 40ddfa9b411add9fa91768559dc3bf7998ef5f63 Mon Sep 17 00:00:00 2001 From: Justine Date: Mon, 1 May 2023 11:04:11 -0700 Subject: [PATCH 3/8] add return --- core/src/main/scala/kafka/server/KafkaApis.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala index 003cdb85dcda3..4a5ff8056c7be 100644 --- a/core/src/main/scala/kafka/server/KafkaApis.scala +++ b/core/src/main/scala/kafka/server/KafkaApis.scala @@ -568,8 +568,9 @@ class KafkaApis(val requestChannel: RequestChannel, try { ProduceRequest.validateProducerIds(request.header.apiVersion, produceRequest.data) } catch { - case e: InvalidRecordException => + case _: InvalidRecordException => requestHelper.sendErrorResponseMaybeThrottle(request, Errors.INVALID_RECORD.exception) + return } } From 26a1b678f5f0446c848ad514d0d2956956d06cb6 Mon Sep 17 00:00:00 2001 From: Justine Date: Mon, 1 May 2023 14:05:55 -0700 Subject: [PATCH 4/8] Fix comment --- core/src/main/scala/kafka/server/ReplicaManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala index 56a9d5157b76f..e80c36725cbb1 100644 --- a/core/src/main/scala/kafka/server/ReplicaManager.scala +++ b/core/src/main/scala/kafka/server/ReplicaManager.scala @@ -751,7 +751,7 @@ class ReplicaManager(val config: KafkaConfig, } // map not yet verified partitions to a request object - // verification only occurs on produce requests and those will always have one batch for versions that support transactions. + // Since verification occurs on produce requests only, and each produce request has one batch, we can just grab the first one. val batchInfo = notYetVerifiedEntriesPerPartition.head._2.firstBatch() val notYetVerifiedTransaction = new AddPartitionsToTxnTransaction() .setTransactionalId(transactionalId) From 0ccd46e38e2d2e00cb638450a71d8e8100fdc0cb Mon Sep 17 00:00:00 2001 From: Justine Date: Tue, 2 May 2023 13:08:10 -0700 Subject: [PATCH 5/8] Remove the enforcement of a single producer ID in batches in the request --- .../kafka/common/requests/ProduceRequest.java | 22 ------------ .../common/requests/ProduceRequestTest.java | 31 ---------------- .../main/scala/kafka/server/KafkaApis.scala | 9 +---- .../unit/kafka/server/KafkaApisTest.scala | 36 ------------------- 4 files changed, 1 insertion(+), 97 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java index 35a53a7b7daf3..758631a1d87aa 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java @@ -88,7 +88,6 @@ private ProduceRequest build(short version, boolean validate) { data.topicData().forEach(tpd -> tpd.partitionData().forEach(partitionProduceData -> ProduceRequest.validateRecords(version, partitionProduceData.records()))); - validateProducerIds(version, data); } return new ProduceRequest(data, version); } @@ -223,27 +222,6 @@ public void clearPartitionRecords() { partitionSizes(); data = null; } - - public static void validateProducerIds(short version, ProduceRequestData data) { - if (version >= 3) { - long producerId = -1; - for (ProduceRequestData.TopicProduceData topicData : data.topicData()) { - for (ProduceRequestData.PartitionProduceData partitionData : topicData.partitionData()) { - BaseRecords baseRecords = partitionData.records(); - if (baseRecords instanceof Records) { - Records records = (Records) baseRecords; - for (RecordBatch batch : records.batches()) { - if (producerId == -1 && batch.hasProducerId()) - producerId = batch.producerId(); - else if (batch.hasProducerId()) - if (batch.producerId() != producerId) - throw new InvalidRecordException("Produce requests with producer IDs can not have differing producer IDs"); - } - } - } - } - } - } public static void validateRecords(short version, BaseRecords baseRecords) { if (version >= 3) { diff --git a/clients/src/test/java/org/apache/kafka/common/requests/ProduceRequestTest.java b/clients/src/test/java/org/apache/kafka/common/requests/ProduceRequestTest.java index d784d5e83ab9c..4ac61dbf5b3a3 100644 --- a/clients/src/test/java/org/apache/kafka/common/requests/ProduceRequestTest.java +++ b/clients/src/test/java/org/apache/kafka/common/requests/ProduceRequestTest.java @@ -238,37 +238,6 @@ public void testV6AndBelowCannotUseZStdCompression() { // Works fine with current version (>= 7) ProduceRequest.forCurrentMagic(produceData); } - - @Test - public void testNoMixedProducerIds() { - final long producerId1 = 15L; - final long producerId2 = 16L; - final short producerEpoch = 5; - final int sequence = 10; - - final MemoryRecords records = MemoryRecords.withRecords(CompressionType.NONE, - new SimpleRecord("foo".getBytes())); - final MemoryRecords txnRecords = MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId1, - producerEpoch, sequence, new SimpleRecord("bar".getBytes())); - final MemoryRecords idempotentRecords = MemoryRecords.withIdempotentRecords(CompressionType.NONE, producerId2, - producerEpoch, sequence, new SimpleRecord("bee".getBytes())); - - - ProduceRequest.Builder requestBuilder = ProduceRequest.forCurrentMagic( - new ProduceRequestData() - .setTopicData(new ProduceRequestData.TopicProduceDataCollection(Arrays.asList( - new ProduceRequestData.TopicProduceData().setName("foo").setPartitionData(Collections.singletonList( - new ProduceRequestData.PartitionProduceData().setIndex(0).setRecords(records))), - new ProduceRequestData.TopicProduceData().setName("bar").setPartitionData(Collections.singletonList( - new ProduceRequestData.PartitionProduceData().setIndex(1).setRecords(txnRecords))), - new ProduceRequestData.TopicProduceData().setName("bee").setPartitionData(Collections.singletonList( - new ProduceRequestData.PartitionProduceData().setIndex(0).setRecords(idempotentRecords)))) - .iterator())) - .setAcks((short) 1) - .setTimeoutMs(5000)); - IntStream.range(3, ApiKeys.PRODUCE.latestVersion()) - .forEach(version -> assertThrows(InvalidRecordException.class, () -> requestBuilder.build((short) version).serialize())); - } @Test public void testMixedTransactionalData() { diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala b/core/src/main/scala/kafka/server/KafkaApis.scala index 4a5ff8056c7be..6a4971c44d353 100644 --- a/core/src/main/scala/kafka/server/KafkaApis.scala +++ b/core/src/main/scala/kafka/server/KafkaApis.scala @@ -65,7 +65,7 @@ import org.apache.kafka.common.resource.{Resource, ResourceType} import org.apache.kafka.common.security.auth.{KafkaPrincipal, SecurityProtocol} import org.apache.kafka.common.security.token.delegation.{DelegationToken, TokenInformation} import org.apache.kafka.common.utils.{ProducerIdAndEpoch, Time} -import org.apache.kafka.common.{InvalidRecordException, Node, TopicIdPartition, TopicPartition, Uuid} +import org.apache.kafka.common.{Node, TopicIdPartition, TopicPartition, Uuid} import org.apache.kafka.coordinator.group.GroupCoordinator import org.apache.kafka.server.authorizer._ import org.apache.kafka.server.common.MetadataVersion @@ -565,13 +565,6 @@ class KafkaApis(val requestChannel: RequestChannel, requestHelper.sendErrorResponseMaybeThrottle(request, Errors.TRANSACTIONAL_ID_AUTHORIZATION_FAILED.exception) return } - try { - ProduceRequest.validateProducerIds(request.header.apiVersion, produceRequest.data) - } catch { - case _: InvalidRecordException => - requestHelper.sendErrorResponseMaybeThrottle(request, Errors.INVALID_RECORD.exception) - return - } } val unauthorizedTopicResponses = mutable.Map[TopicPartition, PartitionResponse]() diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala index d72b972d573be..69052491acc1c 100644 --- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala +++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala @@ -2381,42 +2381,6 @@ class KafkaApisTest { ArgumentMatchers.eq(Some(transactionCoordinatorPartition))) } } - - @Test - def testDifferingProducerIdsThrowError(): Unit = { - val topic = "topic" - val transactionalId = "txn1" - - addTopicToMetadataCache(topic, numPartitions = 2) - - val tp = new TopicPartition("topic", 0) - val tp1 = new TopicPartition("topic", 1) - - val produceRequest = new ProduceRequest(new ProduceRequestData() - .setTopicData(new ProduceRequestData.TopicProduceDataCollection( - util.Arrays.asList(new ProduceRequestData.TopicProduceData() - .setName(tp.topic).setPartitionData(Collections.singletonList( - new ProduceRequestData.PartitionProduceData() - .setIndex(tp.partition) - .setRecords(MemoryRecords.withTransactionalRecords(CompressionType.NONE, 0, 0, 0, new SimpleRecord("test".getBytes))))), - new ProduceRequestData.TopicProduceData() - .setName(tp1.topic).setPartitionData(Collections.singletonList( - new ProduceRequestData.PartitionProduceData() - .setIndex(tp1.partition) - .setRecords(MemoryRecords.withIdempotentRecords(CompressionType.NONE, 2, 0, 0, new SimpleRecord("test".getBytes)))))) - .iterator)) - .setAcks(1.toShort) - .setTransactionalId(transactionalId) - .setTimeoutMs(5000), ApiKeys.PRODUCE.latestVersion) - val request = buildRequest(produceRequest) - - val kafkaApis = createKafkaApis() - - kafkaApis.handleProduceRequest(request, RequestLocal.withThreadConfinedCaching) - - val response = verifyNoThrottling[ProduceResponse](request) - assertEquals(2, response.errorCounts().get(Errors.INVALID_RECORD)) - } @Test def testAddPartitionsToTxnWithInvalidPartition(): Unit = { From f9697e22017e0fc86ad889538a946854c4ba31f4 Mon Sep 17 00:00:00 2001 From: Justine Date: Tue, 2 May 2023 15:55:08 -0700 Subject: [PATCH 6/8] Update comments --- .../org/apache/kafka/common/requests/ProduceRequest.java | 4 ++-- core/src/main/scala/kafka/server/ReplicaManager.scala | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java b/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java index 758631a1d87aa..24a84de438d54 100644 --- a/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java +++ b/clients/src/main/java/org/apache/kafka/common/requests/ProduceRequest.java @@ -230,7 +230,7 @@ public static void validateRecords(short version, BaseRecords baseRecords) { Iterator iterator = records.batches().iterator(); if (!iterator.hasNext()) throw new InvalidRecordException("Produce requests with version " + version + " must have at least " + - "one record batch"); + "one record batch per partition"); RecordBatch entry = iterator.next(); if (entry.magic() != RecordBatch.MAGIC_VALUE_V2) @@ -243,7 +243,7 @@ public static void validateRecords(short version, BaseRecords baseRecords) { if (iterator.hasNext()) throw new InvalidRecordException("Produce requests with version " + version + " are only allowed to " + - "contain exactly one record batch"); + "contain exactly one record batch per partition"); } } diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala index e80c36725cbb1..f37d9cf54497e 100644 --- a/core/src/main/scala/kafka/server/ReplicaManager.scala +++ b/core/src/main/scala/kafka/server/ReplicaManager.scala @@ -642,7 +642,7 @@ class ReplicaManager(val config: KafkaConfig, (entriesPerPartition, Map.empty) else entriesPerPartition.partition { case (topicPartition, records) => - // Produce requests (only requests that require verification) should only have one batch in "batches" but check all just to be safe. + // Produce requests (only requests that require verification) should only have one batch per partition in "batches" but check all just to be safe. val transactionalBatches = records.batches.asScala.filter(batch => batch.hasProducerId && batch.isTransactional) if (!transactionalBatches.isEmpty) { getPartitionOrException(topicPartition).hasOngoingTransaction(transactionalBatches.head.producerId) @@ -751,7 +751,8 @@ class ReplicaManager(val config: KafkaConfig, } // map not yet verified partitions to a request object - // Since verification occurs on produce requests only, and each produce request has one batch, we can just grab the first one. + // Since verification occurs on produce requests only, and each produce request has one batch per partition, we know the producer ID is transactional + // from the checks above val batchInfo = notYetVerifiedEntriesPerPartition.head._2.firstBatch() val notYetVerifiedTransaction = new AddPartitionsToTxnTransaction() .setTransactionalId(transactionalId) From b4586601de5dcaf38608db32c85dd994baad33c2 Mon Sep 17 00:00:00 2001 From: Justine Date: Wed, 3 May 2023 11:27:18 -0700 Subject: [PATCH 7/8] Updated to check producer ID is the same on verifying partitions --- .../scala/kafka/server/ReplicaManager.scala | 20 ++++--- .../kafka/server/ReplicaManagerTest.scala | 52 ++++++++++++++++++- 2 files changed, 64 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala index f37d9cf54497e..9bb360e97ea6e 100644 --- a/core/src/main/scala/kafka/server/ReplicaManager.scala +++ b/core/src/main/scala/kafka/server/ReplicaManager.scala @@ -51,7 +51,7 @@ import org.apache.kafka.common.requests.FetchRequest.PartitionData import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse import org.apache.kafka.common.requests._ import org.apache.kafka.common.utils.Time -import org.apache.kafka.common.{ElectionType, IsolationLevel, Node, TopicIdPartition, TopicPartition, Uuid} +import org.apache.kafka.common.{ElectionType, InvalidRecordException, IsolationLevel, Node, TopicIdPartition, TopicPartition, Uuid} import org.apache.kafka.image.{LocalReplicaChanges, MetadataImage, TopicsDelta} import org.apache.kafka.metadata.LeaderConstants.NO_LEADER import org.apache.kafka.server.common.MetadataVersion._ @@ -637,24 +637,31 @@ class ReplicaManager(val config: KafkaConfig, if (isValidRequiredAcks(requiredAcks)) { val sTime = time.milliseconds + val transactionalProducerIds = mutable.HashSet[Long]() val (verifiedEntriesPerPartition, notYetVerifiedEntriesPerPartition) = if (transactionStatePartition.isEmpty || !config.transactionPartitionVerificationEnable) (entriesPerPartition, Map.empty) - else + else { entriesPerPartition.partition { case (topicPartition, records) => // Produce requests (only requests that require verification) should only have one batch per partition in "batches" but check all just to be safe. val transactionalBatches = records.batches.asScala.filter(batch => batch.hasProducerId && batch.isTransactional) - if (!transactionalBatches.isEmpty) { + transactionalBatches.map(_.producerId()).toSet.foreach(transactionalProducerIds.add(_)) + if (transactionalBatches.nonEmpty) { getPartitionOrException(topicPartition).hasOngoingTransaction(transactionalBatches.head.producerId) } else { // If there is no producer ID in the batches, no need to verify. true } } + } + // We should have exactly one producer ID for transactional records + if (transactionalProducerIds.size > 1) { + throw new InvalidRecordException("Transactional records contained more than one producer ID") + } def appendEntries(allEntries: Map[TopicPartition, MemoryRecords])(unverifiedEntries: Map[TopicPartition, Errors]): Unit = { val verifiedEntries = - if (unverifiedEntries.isEmpty) + if (unverifiedEntries.isEmpty) allEntries else allEntries.filter { case (tp, _) => @@ -750,9 +757,8 @@ class ReplicaManager(val config: KafkaConfig, .setPartitions(tps.map(tp => Integer.valueOf(tp.partition())).toList.asJava)) } - // map not yet verified partitions to a request object - // Since verification occurs on produce requests only, and each produce request has one batch per partition, we know the producer ID is transactional - // from the checks above + // Map not yet verified partitions to a request object. + // We verify above that all partitions use the same producer ID. val batchInfo = notYetVerifiedEntriesPerPartition.head._2.firstBatch() val notYetVerifiedTransaction = new AddPartitionsToTxnTransaction() .setTransactionalId(transactionalId) diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala index 24936415e14a9..6674580deba53 100644 --- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala +++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala @@ -53,7 +53,7 @@ import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse import org.apache.kafka.common.requests._ import org.apache.kafka.common.security.auth.KafkaPrincipal import org.apache.kafka.common.utils.{LogContext, Time, Utils} -import org.apache.kafka.common.{IsolationLevel, Node, TopicIdPartition, TopicPartition, Uuid} +import org.apache.kafka.common.{InvalidRecordException, IsolationLevel, Node, TopicIdPartition, TopicPartition, Uuid} import org.apache.kafka.image._ import org.apache.kafka.metadata.LeaderConstants.NO_LEADER import org.apache.kafka.metadata.LeaderRecoveryState @@ -2139,6 +2139,56 @@ class ReplicaManagerTest { TestUtils.assertNoNonDaemonThreads(this.getClass.getName) } + + @Test + def testExceptionWhenUnverifiedTransactionHasMultipleProducerIds(): Unit = { + val tp0 = new TopicPartition(topic, 0) + val tp1 = new TopicPartition(topic, 1) + val transactionalId = "txn1" + val producerId = 24L + val producerEpoch = 0.toShort + val sequence = 0 + + val mockLogMgr = TestUtils.createLogManager(config.logDirs.map(new File(_))) + val metadataCache = mock(classOf[MetadataCache]) + val addPartitionsToTxnManager = mock(classOf[AddPartitionsToTxnManager]) + + val replicaManager = new ReplicaManager( + metrics = metrics, + config = config, + time = time, + scheduler = new MockScheduler(time), + logManager = mockLogMgr, + quotaManagers = quotaManager, + metadataCache = metadataCache, + logDirFailureChannel = new LogDirFailureChannel(config.logDirs.size), + alterPartitionManager = alterPartitionManager, + addPartitionsToTxnManager = Some(addPartitionsToTxnManager)) + + try { + replicaManager.becomeLeaderOrFollower(1, + makeLeaderAndIsrRequest(topicIds(tp0.topic), tp0, Seq(0, 1), LeaderAndIsr(1, List(0, 1))), + (_, _) => ()) + + replicaManager.becomeLeaderOrFollower(1, + makeLeaderAndIsrRequest(topicIds(tp1.topic), tp1, Seq(0, 1), LeaderAndIsr(1, List(0, 1))), + (_, _) => ()) + + // Append some transactional records with different producer IDs + val transactionalRecords = mutable.Map[TopicPartition, MemoryRecords]() + transactionalRecords.put(tp0, MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId, producerEpoch, sequence, + new SimpleRecord(s"message $sequence".getBytes))) + transactionalRecords.put(tp1, MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId + 1, producerEpoch, sequence, + new SimpleRecord(s"message $sequence".getBytes))) + + assertThrows(classOf[InvalidRecordException], + () => appendRecordsToMultipleTopics(replicaManager, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))) + } finally { + replicaManager.shutdown() + } + + TestUtils.assertNoNonDaemonThreads(this.getClass.getName) + } @Test def testDisabledVerification(): Unit = { From 5774739b399c8e78004cfc27ffd677af078f56b3 Mon Sep 17 00:00:00 2001 From: Justine Date: Wed, 3 May 2023 12:10:11 -0700 Subject: [PATCH 8/8] Changed error and simplified --- core/src/main/scala/kafka/server/ReplicaManager.scala | 6 +++--- .../test/scala/unit/kafka/server/ReplicaManagerTest.scala | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/kafka/server/ReplicaManager.scala b/core/src/main/scala/kafka/server/ReplicaManager.scala index 9bb360e97ea6e..314be5b3f4252 100644 --- a/core/src/main/scala/kafka/server/ReplicaManager.scala +++ b/core/src/main/scala/kafka/server/ReplicaManager.scala @@ -51,7 +51,7 @@ import org.apache.kafka.common.requests.FetchRequest.PartitionData import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse import org.apache.kafka.common.requests._ import org.apache.kafka.common.utils.Time -import org.apache.kafka.common.{ElectionType, InvalidRecordException, IsolationLevel, Node, TopicIdPartition, TopicPartition, Uuid} +import org.apache.kafka.common.{ElectionType, IsolationLevel, Node, TopicIdPartition, TopicPartition, Uuid} import org.apache.kafka.image.{LocalReplicaChanges, MetadataImage, TopicsDelta} import org.apache.kafka.metadata.LeaderConstants.NO_LEADER import org.apache.kafka.server.common.MetadataVersion._ @@ -645,7 +645,7 @@ class ReplicaManager(val config: KafkaConfig, entriesPerPartition.partition { case (topicPartition, records) => // Produce requests (only requests that require verification) should only have one batch per partition in "batches" but check all just to be safe. val transactionalBatches = records.batches.asScala.filter(batch => batch.hasProducerId && batch.isTransactional) - transactionalBatches.map(_.producerId()).toSet.foreach(transactionalProducerIds.add(_)) + transactionalBatches.foreach(batch => transactionalProducerIds.add(batch.producerId)) if (transactionalBatches.nonEmpty) { getPartitionOrException(topicPartition).hasOngoingTransaction(transactionalBatches.head.producerId) } else { @@ -656,7 +656,7 @@ class ReplicaManager(val config: KafkaConfig, } // We should have exactly one producer ID for transactional records if (transactionalProducerIds.size > 1) { - throw new InvalidRecordException("Transactional records contained more than one producer ID") + throw new InvalidPidMappingException("Transactional records contained more than one producer ID") } def appendEntries(allEntries: Map[TopicPartition, MemoryRecords])(unverifiedEntries: Map[TopicPartition, Errors]): Unit = { diff --git a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala index 6674580deba53..83b5fed422df2 100644 --- a/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala +++ b/core/src/test/scala/unit/kafka/server/ReplicaManagerTest.scala @@ -35,7 +35,7 @@ import kafka.server.epoch.util.MockBlockingSender import kafka.utils.timer.MockTimer import kafka.utils.{MockTime, Pool, TestUtils} import org.apache.kafka.clients.FetchSessionHandler -import org.apache.kafka.common.errors.KafkaStorageException +import org.apache.kafka.common.errors.{InvalidPidMappingException, KafkaStorageException} import org.apache.kafka.common.message.LeaderAndIsrRequestData import org.apache.kafka.common.message.LeaderAndIsrRequestData.LeaderAndIsrPartitionState import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset @@ -53,7 +53,7 @@ import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse import org.apache.kafka.common.requests._ import org.apache.kafka.common.security.auth.KafkaPrincipal import org.apache.kafka.common.utils.{LogContext, Time, Utils} -import org.apache.kafka.common.{InvalidRecordException, IsolationLevel, Node, TopicIdPartition, TopicPartition, Uuid} +import org.apache.kafka.common.{IsolationLevel, Node, TopicIdPartition, TopicPartition, Uuid} import org.apache.kafka.image._ import org.apache.kafka.metadata.LeaderConstants.NO_LEADER import org.apache.kafka.metadata.LeaderRecoveryState @@ -2181,7 +2181,7 @@ class ReplicaManagerTest { transactionalRecords.put(tp1, MemoryRecords.withTransactionalRecords(CompressionType.NONE, producerId + 1, producerEpoch, sequence, new SimpleRecord(s"message $sequence".getBytes))) - assertThrows(classOf[InvalidRecordException], + assertThrows(classOf[InvalidPidMappingException], () => appendRecordsToMultipleTopics(replicaManager, transactionalRecords, transactionalId = transactionalId, transactionStatePartition = Some(0))) } finally { replicaManager.shutdown()