From ff3f31678813130c4e5651d41dfd5cec593d674d Mon Sep 17 00:00:00 2001 From: Jeff Kim Date: Tue, 18 Apr 2023 04:41:54 -0400 Subject: [PATCH 1/3] KAFKA-14869: Ignore unknown record types for coordinators (KIP-915, Part-1) (#13511) This patch implements the first part of KIP-915. It updates the group coordinator and the transaction coordinator to ignore unknown record types while loading their respective state from the partitions. This allows downgrades from future versions that will include new record types. Reviewers: Alexandre Dupriez, David Jacot --- .../group/GroupMetadataManager.scala | 21 +++- .../transaction/TransactionLog.scala | 79 ++++++++++++------- .../transaction/TransactionStateManager.scala | 25 +++--- .../group/GroupMetadataManagerTest.scala | 59 +++++++++++++- .../transaction/TransactionLogTest.scala | 10 +++ .../TransactionStateManagerTest.scala | 40 +++++++++- 6 files changed, 186 insertions(+), 48 deletions(-) diff --git a/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala b/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala index 9e3769b6239a7..2e22dee91d5cd 100644 --- a/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala +++ b/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala @@ -649,7 +649,6 @@ class GroupMetadataManager(brokerId: Int, if (batchBaseOffset.isEmpty) batchBaseOffset = Some(record.offset) GroupMetadataManager.readMessageKey(record.key) match { - case offsetKey: OffsetKey => if (isTxnOffsetCommit && !pendingOffsets.contains(batch.producerId)) pendingOffsets.put(batch.producerId, mutable.Map[GroupTopicPartition, CommitRecordMetadataAndOffset]()) @@ -681,8 +680,10 @@ class GroupMetadataManager(brokerId: Int, removedGroups.add(groupId) } - case unknownKey => - throw new IllegalStateException(s"Unexpected message key $unknownKey while loading offsets and group metadata") + case unknownKey: UnknownKey => + warn(s"Unknown message key with version 
${unknownKey.version}" + + s" while loading offsets and group metadata from $topicPartition. Ignoring it. " + + "It could be a left over from an aborted upgrade.") } } } @@ -1146,7 +1147,9 @@ object GroupMetadataManager { // version 2 refers to group metadata val key = new GroupMetadataKeyData(new ByteBufferAccessor(buffer), version) GroupMetadataKey(version, key.group) - } else throw new IllegalStateException(s"Unknown group metadata message version: $version") + } else { + UnknownKey(version) + } } /** @@ -1266,7 +1269,7 @@ object GroupMetadataManager { GroupMetadataManager.readMessageKey(record.key) match { case offsetKey: OffsetKey => parseOffsets(offsetKey, record.value) case groupMetadataKey: GroupMetadataKey => parseGroupMetadata(groupMetadataKey, record.value) - case _ => throw new KafkaException("Failed to decode message using offset topic decoder (message had an invalid key)") + case unknownKey: UnknownKey => (Some(s"unknown::version=${unknownKey.version}"), None) } } } @@ -1344,18 +1347,20 @@ case class GroupTopicPartition(group: String, topicPartition: TopicPartition) { "[%s,%s,%d]".format(group, topicPartition.topic, topicPartition.partition) } -trait BaseKey{ +sealed trait BaseKey{ def version: Short def key: Any } case class OffsetKey(version: Short, key: GroupTopicPartition) extends BaseKey { - override def toString: String = key.toString } case class GroupMetadataKey(version: Short, key: String) extends BaseKey { - override def toString: String = key } +case class UnknownKey(version: Short) extends BaseKey { + override def key: String = null + override def toString: String = key +} diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala index cb501f774fd9d..30bd517c0093e 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala @@ -19,7 +19,6 @@ 
package kafka.coordinator.transaction import java.io.PrintStream import java.nio.ByteBuffer import java.nio.charset.StandardCharsets - import kafka.internals.generated.{TransactionLogKey, TransactionLogValue} import org.apache.kafka.clients.consumer.ConsumerRecord import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil} @@ -98,7 +97,7 @@ object TransactionLog { * * @return the key */ - def readTxnRecordKey(buffer: ByteBuffer): TxnKey = { + def readTxnRecordKey(buffer: ByteBuffer): BaseKey = { val version = buffer.getShort if (version >= TransactionLogKey.LOWEST_SUPPORTED_VERSION && version <= TransactionLogKey.HIGHEST_SUPPORTED_VERSION) { val value = new TransactionLogKey(new ByteBufferAccessor(buffer), version) @@ -106,7 +105,9 @@ object TransactionLog { version = version, transactionalId = value.transactionalId ) - } else throw new IllegalStateException(s"Unknown version $version from the transaction log message") + } else { + UnknownKey(version) + } } /** @@ -148,17 +149,21 @@ object TransactionLog { // Formatter for use with tools to read transaction log messages class TransactionLogMessageFormatter extends MessageFormatter { def writeTo(consumerRecord: ConsumerRecord[Array[Byte], Array[Byte]], output: PrintStream): Unit = { - Option(consumerRecord.key).map(key => readTxnRecordKey(ByteBuffer.wrap(key))).foreach { txnKey => - val transactionalId = txnKey.transactionalId - val value = consumerRecord.value - val producerIdMetadata = if (value == null) - None - else - readTxnRecordValue(transactionalId, ByteBuffer.wrap(value)) - output.write(transactionalId.getBytes(StandardCharsets.UTF_8)) - output.write("::".getBytes(StandardCharsets.UTF_8)) - output.write(producerIdMetadata.getOrElse("NULL").toString.getBytes(StandardCharsets.UTF_8)) - output.write("\n".getBytes(StandardCharsets.UTF_8)) + Option(consumerRecord.key).map(key => readTxnRecordKey(ByteBuffer.wrap(key))).foreach { + case txnKey: TxnKey => + val transactionalId = txnKey.transactionalId 
+ val value = consumerRecord.value + val producerIdMetadata = if (value == null) + None + else + readTxnRecordValue(transactionalId, ByteBuffer.wrap(value)) + output.write(transactionalId.getBytes(StandardCharsets.UTF_8)) + output.write("::".getBytes(StandardCharsets.UTF_8)) + output.write(producerIdMetadata.getOrElse("NULL").toString.getBytes(StandardCharsets.UTF_8)) + output.write("\n".getBytes(StandardCharsets.UTF_8)) + + case unknownKey: UnknownKey => + output.write(s"unknown::version=${unknownKey.version}\n".getBytes(StandardCharsets.UTF_8)) } } } @@ -167,25 +172,41 @@ object TransactionLog { * Exposed for printing records using [[kafka.tools.DumpLogSegments]] */ def formatRecordKeyAndValue(record: Record): (Option[String], Option[String]) = { - val txnKey = TransactionLog.readTxnRecordKey(record.key) - val keyString = s"transaction_metadata::transactionalId=${txnKey.transactionalId}" - - val valueString = TransactionLog.readTxnRecordValue(txnKey.transactionalId, record.value) match { - case None => "" - - case Some(txnMetadata) => s"producerId:${txnMetadata.producerId}," + - s"producerEpoch:${txnMetadata.producerEpoch}," + - s"state=${txnMetadata.state}," + - s"partitions=${txnMetadata.topicPartitions.mkString("[", ",", "]")}," + - s"txnLastUpdateTimestamp=${txnMetadata.txnLastUpdateTimestamp}," + - s"txnTimeoutMs=${txnMetadata.txnTimeoutMs}" - } + TransactionLog.readTxnRecordKey(record.key) match { + case txnKey: TxnKey => + val keyString = s"transaction_metadata::transactionalId=${txnKey.transactionalId}" + + val valueString = TransactionLog.readTxnRecordValue(txnKey.transactionalId, record.value) match { + case None => "" - (Some(keyString), Some(valueString)) + case Some(txnMetadata) => s"producerId:${txnMetadata.producerId}," + + s"producerEpoch:${txnMetadata.producerEpoch}," + + s"state=${txnMetadata.state}," + + s"partitions=${txnMetadata.topicPartitions.mkString("[", ",", "]")}," + + s"txnLastUpdateTimestamp=${txnMetadata.txnLastUpdateTimestamp}," + + 
s"txnTimeoutMs=${txnMetadata.txnTimeoutMs}" + } + + (Some(keyString), Some(valueString)) + + case unknownKey: UnknownKey => + (Some(s"unknown::version=${unknownKey.version}"), None) + } } } -case class TxnKey(version: Short, transactionalId: String) { +sealed trait BaseKey{ + def version: Short + def transactionalId: String +} + +case class TxnKey(version: Short, transactionalId: String) extends BaseKey { + override def toString: String = transactionalId +} + +case class UnknownKey(version: Short) extends BaseKey { + override def transactionalId: String = null override def toString: String = transactionalId } + diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala index 217b38382a7a1..4a19b57a6ec56 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala @@ -466,16 +466,23 @@ class TransactionStateManager(brokerId: Int, memRecords.batches.forEach { batch => for (record <- batch.asScala) { require(record.hasKey, "Transaction state log's key should not be null") - val txnKey = TransactionLog.readTxnRecordKey(record.key) - // load transaction metadata along with transaction state - val transactionalId = txnKey.transactionalId - TransactionLog.readTxnRecordValue(transactionalId, record.value) match { - case None => - loadedTransactions.remove(transactionalId) - case Some(txnMetadata) => - loadedTransactions.put(transactionalId, txnMetadata) + TransactionLog.readTxnRecordKey(record.key) match { + case txnKey: TxnKey => + // load transaction metadata along with transaction state + val transactionalId = txnKey.transactionalId + TransactionLog.readTxnRecordValue(transactionalId, record.value) match { + case None => + loadedTransactions.remove(transactionalId) + case Some(txnMetadata) => + loadedTransactions.put(transactionalId, 
txnMetadata) + } + currOffset = batch.nextOffset + + case unknownKey: UnknownKey => + warn(s"Unknown message key with version ${unknownKey.version}" + + s" while loading transaction state from $topicPartition. Ignoring it. " + + "It could be a left over from an aborted upgrade.") } - currOffset = batch.nextOffset } } } diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala index bf475cc75899a..9365c76dd9995 100644 --- a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala @@ -37,7 +37,7 @@ import org.apache.kafka.clients.consumer.internals.ConsumerProtocol import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.internals.Topic import org.apache.kafka.common.metrics.{JmxReporter, KafkaMetricsContext, Metrics => kMetrics} -import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.protocol.{Errors, MessageUtil} import org.apache.kafka.common.record._ import org.apache.kafka.common.requests.OffsetFetchResponse import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse @@ -657,6 +657,7 @@ class GroupMetadataManagerTest { val offsetCommitRecords = createCommittedOffsetRecords(committedOffsets) val memberId = "98098230493" val groupMetadataRecord = buildStableGroupRecordWithMember(generation, protocolType, protocol, memberId) + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, (offsetCommitRecords ++ Seq(groupMetadataRecord)).toArray: _*) @@ -2648,4 +2649,60 @@ class GroupMetadataManagerTest { assertTrue(partitionLoadTime("partition-load-time-max") >= diff) assertTrue(partitionLoadTime("partition-load-time-avg") >= diff) } + + @Test + def testReadMessageKeyCanReadUnknownMessage(): Unit = { + val record = new 
org.apache.kafka.coordinator.group.generated.GroupMetadataKey() + val unknownRecord = MessageUtil.toVersionPrefixedBytes(Short.MaxValue, record) + val key = GroupMetadataManager.readMessageKey(ByteBuffer.wrap(unknownRecord)) + assertEquals(UnknownKey(Short.MaxValue), key) + } + + @Test + def testLoadGroupsAndOffsetsWillIgnoreUnknownMessage(): Unit = { + val generation = 935 + val protocolType = "consumer" + val protocol = "range" + val startOffset = 15L + val committedOffsets = Map( + new TopicPartition("foo", 0) -> 23L, + new TopicPartition("foo", 1) -> 455L, + new TopicPartition("bar", 0) -> 8992L + ) + + val offsetCommitRecords = createCommittedOffsetRecords(committedOffsets) + val memberId = "98098230493" + val groupMetadataRecord = buildStableGroupRecordWithMember(generation, protocolType, protocol, memberId) + + // Should ignore unknown record + val unknownKey = new org.apache.kafka.coordinator.group.generated.GroupMetadataKey() + val lowestUnsupportedVersion = (org.apache.kafka.coordinator.group.generated.GroupMetadataKey + .HIGHEST_SUPPORTED_VERSION + 1).toShort + + val unknownMessage1 = MessageUtil.toVersionPrefixedBytes(Short.MaxValue, unknownKey) + val unknownMessage2 = MessageUtil.toVersionPrefixedBytes(lowestUnsupportedVersion, unknownKey) + val unknownRecord1 = new SimpleRecord(unknownMessage1, unknownMessage1) + val unknownRecord2 = new SimpleRecord(unknownMessage2, unknownMessage2) + + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, + (offsetCommitRecords ++ Seq(unknownRecord1, unknownRecord2) ++ Seq(groupMetadataRecord)).toArray: _*) + + expectGroupMetadataLoad(groupTopicPartition, startOffset, records) + + groupMetadataManager.loadGroupsAndOffsets(groupTopicPartition, 1, _ => (), 0L) + + val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Stable, group.currentState) + assertEquals(memberId, 
group.leaderOrNull) + assertEquals(generation, group.generationId) + assertEquals(Some(protocolType), group.protocolType) + assertEquals(protocol, group.protocolName.orNull) + assertEquals(Set(memberId), group.allMembers) + assertEquals(committedOffsets.size, group.allOffsets.size) + committedOffsets.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + assertTrue(group.offset(topicPartition).map(_.expireTimestamp).contains(None)) + } + } } diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala index 32e17d88a7b1e..eb1284278d29b 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala @@ -17,12 +17,15 @@ package kafka.coordinator.transaction +import kafka.internals.generated.TransactionLogKey import kafka.utils.TestUtils import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.MessageUtil import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord} import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} import org.junit.jupiter.api.Test +import java.nio.ByteBuffer import scala.jdk.CollectionConverters._ class TransactionLogTest { @@ -135,4 +138,11 @@ class TransactionLogTest { assertEquals(Some(""), valueStringOpt) } + @Test + def testReadTxnRecordKeyCanReadUnknownMessage(): Unit = { + val record = new TransactionLogKey() + val unknownRecord = MessageUtil.toVersionPrefixedBytes(Short.MaxValue, record) + val key = TransactionLog.readTxnRecordKey(ByteBuffer.wrap(unknownRecord)) + assertEquals(UnknownKey(Short.MaxValue), key) + } } diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala 
b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala index 9cff29ba858a3..f648378ce0874 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala @@ -16,6 +16,8 @@ */ package kafka.coordinator.transaction +import kafka.internals.generated.TransactionLogKey + import java.lang.management.ManagementFactory import java.nio.ByteBuffer import java.util.concurrent.CountDownLatch @@ -29,7 +31,7 @@ import kafka.zk.KafkaZkClient import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.internals.Topic.TRANSACTION_STATE_TOPIC_NAME import org.apache.kafka.common.metrics.{JmxReporter, KafkaMetricsContext, Metrics} -import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.protocol.{Errors, MessageUtil} import org.apache.kafka.common.record._ import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse import org.apache.kafka.common.requests.TransactionResult @@ -1112,4 +1114,40 @@ class TransactionStateManagerTest { assertTrue(partitionLoadTime("partition-load-time-max") >= 0) assertTrue(partitionLoadTime( "partition-load-time-avg") >= 0) } + + @Test + def testIgnoreUnknownRecordType(): Unit = { + txnMetadata1.state = PrepareCommit + txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), + new TopicPartition("topic1", 1))) + + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit())) + val startOffset = 0L + + val unknownKey = new TransactionLogKey() + val unknownMessage = MessageUtil.toVersionPrefixedBytes(Short.MaxValue, unknownKey) + val unknownRecord = new SimpleRecord(unknownMessage, unknownMessage) + + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, + (Seq(unknownRecord) ++ txnRecords).toArray: _*) + + prepareTxnLog(topicPartition, 0, 
records) + + transactionManager.loadTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch = 1, (_, _, _, _) => ()) + assertEquals(0, transactionManager.loadingPartitions.size) + assertTrue(transactionManager.transactionMetadataCache.contains(partitionId)) + val txnMetadataPool = transactionManager.transactionMetadataCache(partitionId).metadataPerTransactionalId + assertFalse(txnMetadataPool.isEmpty) + assertTrue(txnMetadataPool.contains(transactionalId1)) + val txnMetadata = txnMetadataPool.get(transactionalId1) + assertEquals(txnMetadata1.transactionalId, txnMetadata.transactionalId) + assertEquals(txnMetadata1.producerId, txnMetadata.producerId) + assertEquals(txnMetadata1.lastProducerId, txnMetadata.lastProducerId) + assertEquals(txnMetadata1.producerEpoch, txnMetadata.producerEpoch) + assertEquals(txnMetadata1.lastProducerEpoch, txnMetadata.lastProducerEpoch) + assertEquals(txnMetadata1.txnTimeoutMs, txnMetadata.txnTimeoutMs) + assertEquals(txnMetadata1.state, txnMetadata.state) + assertEquals(txnMetadata1.topicPartitions, txnMetadata.topicPartitions) + assertEquals(1, transactionManager.transactionMetadataCache(partitionId).coordinatorEpoch) + } } From f916d4916632554c0722a6089799ff38b6ce061d Mon Sep 17 00:00:00 2001 From: Jeff Kim Date: Wed, 19 Apr 2023 12:02:02 -0400 Subject: [PATCH 2/3] fix build --- .../kafka/coordinator/group/GroupMetadataManagerTest.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala index 9365c76dd9995..5cad046056a2b 100644 --- a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala @@ -2652,7 +2652,7 @@ class GroupMetadataManagerTest { @Test def testReadMessageKeyCanReadUnknownMessage(): Unit = { - val record = new 
org.apache.kafka.coordinator.group.generated.GroupMetadataKey() + val record = new kafka.internals.generated.GroupMetadataKey() val unknownRecord = MessageUtil.toVersionPrefixedBytes(Short.MaxValue, record) val key = GroupMetadataManager.readMessageKey(ByteBuffer.wrap(unknownRecord)) assertEquals(UnknownKey(Short.MaxValue), key) @@ -2675,8 +2675,8 @@ class GroupMetadataManagerTest { val groupMetadataRecord = buildStableGroupRecordWithMember(generation, protocolType, protocol, memberId) // Should ignore unknown record - val unknownKey = new org.apache.kafka.coordinator.group.generated.GroupMetadataKey() - val lowestUnsupportedVersion = (org.apache.kafka.coordinator.group.generated.GroupMetadataKey + val unknownKey = new kafka.internals.generated.GroupMetadataKey() + val lowestUnsupportedVersion = (kafka.internals.generated.GroupMetadataKey .HIGHEST_SUPPORTED_VERSION + 1).toShort val unknownMessage1 = MessageUtil.toVersionPrefixedBytes(Short.MaxValue, unknownKey) From 034507136c6fba036e78ec944576eb2d1cbb2238 Mon Sep 17 00:00:00 2001 From: Jeff Kim Date: Thu, 20 Apr 2023 09:32:16 -0400 Subject: [PATCH 3/3] fix test --- .../unit/kafka/coordinator/group/GroupMetadataManagerTest.scala | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala index 5cad046056a2b..c4cb53f1d01de 100644 --- a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala @@ -2688,6 +2688,7 @@ class GroupMetadataManagerTest { (offsetCommitRecords ++ Seq(unknownRecord1, unknownRecord2) ++ Seq(groupMetadataRecord)).toArray: _*) expectGroupMetadataLoad(groupTopicPartition, startOffset, records) + EasyMock.replay(replicaManager) groupMetadataManager.loadGroupsAndOffsets(groupTopicPartition, 1, _ => (), 0L)