diff --git a/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala b/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala index 9e3769b6239a7..2e22dee91d5cd 100644 --- a/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala +++ b/core/src/main/scala/kafka/coordinator/group/GroupMetadataManager.scala @@ -649,7 +649,6 @@ class GroupMetadataManager(brokerId: Int, if (batchBaseOffset.isEmpty) batchBaseOffset = Some(record.offset) GroupMetadataManager.readMessageKey(record.key) match { - case offsetKey: OffsetKey => if (isTxnOffsetCommit && !pendingOffsets.contains(batch.producerId)) pendingOffsets.put(batch.producerId, mutable.Map[GroupTopicPartition, CommitRecordMetadataAndOffset]()) @@ -681,8 +680,10 @@ class GroupMetadataManager(brokerId: Int, removedGroups.add(groupId) } - case unknownKey => - throw new IllegalStateException(s"Unexpected message key $unknownKey while loading offsets and group metadata") + case unknownKey: UnknownKey => + warn(s"Unknown message key with version ${unknownKey.version}" + + s" while loading offsets and group metadata from $topicPartition. Ignoring it. " + + "It could be a left over from an aborted upgrade.") } } } @@ -1146,7 +1147,9 @@ object GroupMetadataManager { // version 2 refers to group metadata val key = new GroupMetadataKeyData(new ByteBufferAccessor(buffer), version) GroupMetadataKey(version, key.group) - } else throw new IllegalStateException(s"Unknown group metadata message version: $version") + } else { + UnknownKey(version) + } } /** @@ -1266,7 +1269,7 @@ object GroupMetadataManager { GroupMetadataManager.readMessageKey(record.key) match { case offsetKey: OffsetKey => parseOffsets(offsetKey, record.value) case groupMetadataKey: GroupMetadataKey => parseGroupMetadata(groupMetadataKey, record.value) - case _ => throw new KafkaException("Failed to decode message using offset topic decoder (message had an invalid key)") + case unknownKey: UnknownKey => (Some(s"unknown::version=${unknownKey.version}"), None) } } } @@ -1344,18 +1347,20 @@ case class GroupTopicPartition(group: String, topicPartition: TopicPartition) { "[%s,%s,%d]".format(group, topicPartition.topic, topicPartition.partition) } -trait BaseKey{ +sealed trait BaseKey{ def version: Short def key: Any } case class OffsetKey(version: Short, key: GroupTopicPartition) extends BaseKey { - override def toString: String = key.toString } case class GroupMetadataKey(version: Short, key: String) extends BaseKey { - override def toString: String = key } +case class UnknownKey(version: Short) extends BaseKey { + override def key: String = null + override def toString: String = key +} diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala index cb501f774fd9d..30bd517c0093e 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionLog.scala @@ -19,7 +19,6 @@ package kafka.coordinator.transaction import java.io.PrintStream import java.nio.ByteBuffer import java.nio.charset.StandardCharsets - import kafka.internals.generated.{TransactionLogKey, TransactionLogValue} import org.apache.kafka.clients.consumer.ConsumerRecord import org.apache.kafka.common.protocol.{ByteBufferAccessor, MessageUtil} @@ -98,7 +97,7 @@ object TransactionLog { * * @return the key */ - def readTxnRecordKey(buffer: ByteBuffer): TxnKey = { + def readTxnRecordKey(buffer: ByteBuffer): BaseKey = { val version = buffer.getShort if (version >= TransactionLogKey.LOWEST_SUPPORTED_VERSION && version <= TransactionLogKey.HIGHEST_SUPPORTED_VERSION) { val value = new TransactionLogKey(new ByteBufferAccessor(buffer), version) @@ -106,7 +105,9 @@ object TransactionLog { version = version, transactionalId = value.transactionalId ) - } else throw new IllegalStateException(s"Unknown version $version from the transaction log message") + } else { + UnknownKey(version) + } } /** @@ -148,17 +149,21 @@ object TransactionLog { // Formatter for use with tools to read transaction log messages class TransactionLogMessageFormatter extends MessageFormatter { def writeTo(consumerRecord: ConsumerRecord[Array[Byte], Array[Byte]], output: PrintStream): Unit = { - Option(consumerRecord.key).map(key => readTxnRecordKey(ByteBuffer.wrap(key))).foreach { txnKey => - val transactionalId = txnKey.transactionalId - val value = consumerRecord.value - val producerIdMetadata = if (value == null) - None - else - readTxnRecordValue(transactionalId, ByteBuffer.wrap(value)) - output.write(transactionalId.getBytes(StandardCharsets.UTF_8)) - output.write("::".getBytes(StandardCharsets.UTF_8)) - output.write(producerIdMetadata.getOrElse("NULL").toString.getBytes(StandardCharsets.UTF_8)) - output.write("\n".getBytes(StandardCharsets.UTF_8)) + Option(consumerRecord.key).map(key => readTxnRecordKey(ByteBuffer.wrap(key))).foreach { + case txnKey: TxnKey => + val transactionalId = txnKey.transactionalId + val value = consumerRecord.value + val producerIdMetadata = if (value == null) + None + else + readTxnRecordValue(transactionalId, ByteBuffer.wrap(value)) + output.write(transactionalId.getBytes(StandardCharsets.UTF_8)) + output.write("::".getBytes(StandardCharsets.UTF_8)) + output.write(producerIdMetadata.getOrElse("NULL").toString.getBytes(StandardCharsets.UTF_8)) + output.write("\n".getBytes(StandardCharsets.UTF_8)) + + case unknownKey: UnknownKey => + output.write(s"unknown::version=${unknownKey.version}\n".getBytes(StandardCharsets.UTF_8)) } } } @@ -167,25 +172,41 @@ object TransactionLog { * Exposed for printing records using [[kafka.tools.DumpLogSegments]] */ def formatRecordKeyAndValue(record: Record): (Option[String], Option[String]) = { - val txnKey = TransactionLog.readTxnRecordKey(record.key) - val keyString = s"transaction_metadata::transactionalId=${txnKey.transactionalId}" - - val valueString = TransactionLog.readTxnRecordValue(txnKey.transactionalId, record.value) match { - case None => "" - - case Some(txnMetadata) => s"producerId:${txnMetadata.producerId}," + - s"producerEpoch:${txnMetadata.producerEpoch}," + - s"state=${txnMetadata.state}," + - s"partitions=${txnMetadata.topicPartitions.mkString("[", ",", "]")}," + - s"txnLastUpdateTimestamp=${txnMetadata.txnLastUpdateTimestamp}," + - s"txnTimeoutMs=${txnMetadata.txnTimeoutMs}" - } + TransactionLog.readTxnRecordKey(record.key) match { + case txnKey: TxnKey => + val keyString = s"transaction_metadata::transactionalId=${txnKey.transactionalId}" + + val valueString = TransactionLog.readTxnRecordValue(txnKey.transactionalId, record.value) match { + case None => "" - (Some(keyString), Some(valueString)) + case Some(txnMetadata) => s"producerId:${txnMetadata.producerId}," + + s"producerEpoch:${txnMetadata.producerEpoch}," + + s"state=${txnMetadata.state}," + + s"partitions=${txnMetadata.topicPartitions.mkString("[", ",", "]")}," + + s"txnLastUpdateTimestamp=${txnMetadata.txnLastUpdateTimestamp}," + + s"txnTimeoutMs=${txnMetadata.txnTimeoutMs}" + } + + (Some(keyString), Some(valueString)) + + case unknownKey: UnknownKey => + (Some(s"unknown::version=${unknownKey.version}"), None) + } } } -case class TxnKey(version: Short, transactionalId: String) { +sealed trait BaseKey{ + def version: Short + def transactionalId: String +} + +case class TxnKey(version: Short, transactionalId: String) extends BaseKey { + override def toString: String = transactionalId +} + +case class UnknownKey(version: Short) extends BaseKey { + override def transactionalId: String = null override def toString: String = transactionalId } + diff --git a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala index 217b38382a7a1..4a19b57a6ec56 100644 --- a/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala +++ b/core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala @@ -466,16 +466,23 @@ class TransactionStateManager(brokerId: Int, memRecords.batches.forEach { batch => for (record <- batch.asScala) { require(record.hasKey, "Transaction state log's key should not be null") - val txnKey = TransactionLog.readTxnRecordKey(record.key) - // load transaction metadata along with transaction state - val transactionalId = txnKey.transactionalId - TransactionLog.readTxnRecordValue(transactionalId, record.value) match { - case None => - loadedTransactions.remove(transactionalId) - case Some(txnMetadata) => - loadedTransactions.put(transactionalId, txnMetadata) + TransactionLog.readTxnRecordKey(record.key) match { + case txnKey: TxnKey => + // load transaction metadata along with transaction state + val transactionalId = txnKey.transactionalId + TransactionLog.readTxnRecordValue(transactionalId, record.value) match { + case None => + loadedTransactions.remove(transactionalId) + case Some(txnMetadata) => + loadedTransactions.put(transactionalId, txnMetadata) + } + currOffset = batch.nextOffset + + case unknownKey: UnknownKey => + warn(s"Unknown message key with version ${unknownKey.version}" + + s" while loading transaction state from $topicPartition. Ignoring it. " + + "It could be a left over from an aborted upgrade.") } - currOffset = batch.nextOffset } } } diff --git a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala index bf475cc75899a..c4cb53f1d01de 100644 --- a/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/group/GroupMetadataManagerTest.scala @@ -37,7 +37,7 @@ import org.apache.kafka.clients.consumer.internals.ConsumerProtocol import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.internals.Topic import org.apache.kafka.common.metrics.{JmxReporter, KafkaMetricsContext, Metrics => kMetrics} -import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.protocol.{Errors, MessageUtil} import org.apache.kafka.common.record._ import org.apache.kafka.common.requests.OffsetFetchResponse import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse @@ -657,6 +657,7 @@ class GroupMetadataManagerTest { val offsetCommitRecords = createCommittedOffsetRecords(committedOffsets) val memberId = "98098230493" val groupMetadataRecord = buildStableGroupRecordWithMember(generation, protocolType, protocol, memberId) + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, (offsetCommitRecords ++ Seq(groupMetadataRecord)).toArray: _*) @@ -2648,4 +2649,61 @@ class GroupMetadataManagerTest { assertTrue(partitionLoadTime("partition-load-time-max") >= diff) assertTrue(partitionLoadTime("partition-load-time-avg") >= diff) } + + @Test + def testReadMessageKeyCanReadUnknownMessage(): Unit = { + val record = new kafka.internals.generated.GroupMetadataKey() + val unknownRecord = MessageUtil.toVersionPrefixedBytes(Short.MaxValue, record) + val key = GroupMetadataManager.readMessageKey(ByteBuffer.wrap(unknownRecord)) + assertEquals(UnknownKey(Short.MaxValue), key) + } + + @Test + def testLoadGroupsAndOffsetsWillIgnoreUnknownMessage(): Unit = { + val generation = 935 + val protocolType = "consumer" + val protocol = "range" + val startOffset = 15L + val committedOffsets = Map( + new TopicPartition("foo", 0) -> 23L, + new TopicPartition("foo", 1) -> 455L, + new TopicPartition("bar", 0) -> 8992L + ) + + val offsetCommitRecords = createCommittedOffsetRecords(committedOffsets) + val memberId = "98098230493" + val groupMetadataRecord = buildStableGroupRecordWithMember(generation, protocolType, protocol, memberId) + + // Should ignore unknown record + val unknownKey = new kafka.internals.generated.GroupMetadataKey() + val lowestUnsupportedVersion = (kafka.internals.generated.GroupMetadataKey + .HIGHEST_SUPPORTED_VERSION + 1).toShort + + val unknownMessage1 = MessageUtil.toVersionPrefixedBytes(Short.MaxValue, unknownKey) + val unknownMessage2 = MessageUtil.toVersionPrefixedBytes(lowestUnsupportedVersion, unknownKey) + val unknownRecord1 = new SimpleRecord(unknownMessage1, unknownMessage1) + val unknownRecord2 = new SimpleRecord(unknownMessage2, unknownMessage2) + + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, + (offsetCommitRecords ++ Seq(unknownRecord1, unknownRecord2) ++ Seq(groupMetadataRecord)).toArray: _*) + + expectGroupMetadataLoad(groupTopicPartition, startOffset, records) + EasyMock.replay(replicaManager) + + groupMetadataManager.loadGroupsAndOffsets(groupTopicPartition, 1, _ => (), 0L) + + val group = groupMetadataManager.getGroup(groupId).getOrElse(throw new AssertionError("Group was not loaded into the cache")) + assertEquals(groupId, group.groupId) + assertEquals(Stable, group.currentState) + assertEquals(memberId, group.leaderOrNull) + assertEquals(generation, group.generationId) + assertEquals(Some(protocolType), group.protocolType) + assertEquals(protocol, group.protocolName.orNull) + assertEquals(Set(memberId), group.allMembers) + assertEquals(committedOffsets.size, group.allOffsets.size) + committedOffsets.foreach { case (topicPartition, offset) => + assertEquals(Some(offset), group.offset(topicPartition).map(_.offset)) + assertTrue(group.offset(topicPartition).map(_.expireTimestamp).contains(None)) + } + } } diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala index 32e17d88a7b1e..eb1284278d29b 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionLogTest.scala @@ -17,12 +17,15 @@ package kafka.coordinator.transaction +import kafka.internals.generated.TransactionLogKey import kafka.utils.TestUtils import org.apache.kafka.common.TopicPartition +import org.apache.kafka.common.protocol.MessageUtil import org.apache.kafka.common.record.{CompressionType, MemoryRecords, SimpleRecord} import org.junit.jupiter.api.Assertions.{assertEquals, assertThrows} import org.junit.jupiter.api.Test +import java.nio.ByteBuffer import scala.jdk.CollectionConverters._ class TransactionLogTest { @@ -135,4 +138,11 @@ class TransactionLogTest { assertEquals(Some(""), valueStringOpt) } + @Test + def testReadTxnRecordKeyCanReadUnknownMessage(): Unit = { + val record = new TransactionLogKey() + val unknownRecord = MessageUtil.toVersionPrefixedBytes(Short.MaxValue, record) + val key = TransactionLog.readTxnRecordKey(ByteBuffer.wrap(unknownRecord)) + assertEquals(UnknownKey(Short.MaxValue), key) + } } diff --git a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala index 9cff29ba858a3..f648378ce0874 100644 --- a/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala +++ b/core/src/test/scala/unit/kafka/coordinator/transaction/TransactionStateManagerTest.scala @@ -16,6 +16,8 @@ */ package kafka.coordinator.transaction +import kafka.internals.generated.TransactionLogKey + import java.lang.management.ManagementFactory import java.nio.ByteBuffer import java.util.concurrent.CountDownLatch @@ -29,7 +31,7 @@ import kafka.zk.KafkaZkClient import org.apache.kafka.common.TopicPartition import org.apache.kafka.common.internals.Topic.TRANSACTION_STATE_TOPIC_NAME import org.apache.kafka.common.metrics.{JmxReporter, KafkaMetricsContext, Metrics} -import org.apache.kafka.common.protocol.Errors +import org.apache.kafka.common.protocol.{Errors, MessageUtil} import org.apache.kafka.common.record._ import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse import org.apache.kafka.common.requests.TransactionResult @@ -1112,4 +1114,40 @@ class TransactionStateManagerTest { assertTrue(partitionLoadTime("partition-load-time-max") >= 0) assertTrue(partitionLoadTime( "partition-load-time-avg") >= 0) } + + @Test + def testIgnoreUnknownRecordType(): Unit = { + txnMetadata1.state = PrepareCommit + txnMetadata1.addPartitions(Set[TopicPartition](new TopicPartition("topic1", 0), + new TopicPartition("topic1", 1))) + + txnRecords += new SimpleRecord(txnMessageKeyBytes1, TransactionLog.valueToBytes(txnMetadata1.prepareNoTransit())) + val startOffset = 0L + + val unknownKey = new TransactionLogKey() + val unknownMessage = MessageUtil.toVersionPrefixedBytes(Short.MaxValue, unknownKey) + val unknownRecord = new SimpleRecord(unknownMessage, unknownMessage) + + val records = MemoryRecords.withRecords(startOffset, CompressionType.NONE, + (Seq(unknownRecord) ++ txnRecords).toArray: _*) + + prepareTxnLog(topicPartition, 0, records) + + transactionManager.loadTransactionsForTxnTopicPartition(partitionId, coordinatorEpoch = 1, (_, _, _, _) => ()) + assertEquals(0, transactionManager.loadingPartitions.size) + assertTrue(transactionManager.transactionMetadataCache.contains(partitionId)) + val txnMetadataPool = transactionManager.transactionMetadataCache(partitionId).metadataPerTransactionalId + assertFalse(txnMetadataPool.isEmpty) + assertTrue(txnMetadataPool.contains(transactionalId1)) + val txnMetadata = txnMetadataPool.get(transactionalId1) + assertEquals(txnMetadata1.transactionalId, txnMetadata.transactionalId) + assertEquals(txnMetadata1.producerId, txnMetadata.producerId) + assertEquals(txnMetadata1.lastProducerId, txnMetadata.lastProducerId) + assertEquals(txnMetadata1.producerEpoch, txnMetadata.producerEpoch) + assertEquals(txnMetadata1.lastProducerEpoch, txnMetadata.lastProducerEpoch) + assertEquals(txnMetadata1.txnTimeoutMs, txnMetadata.txnTimeoutMs) + assertEquals(txnMetadata1.state, txnMetadata.state) + assertEquals(txnMetadata1.topicPartitions, txnMetadata.topicPartitions) + assertEquals(1, transactionManager.transactionMetadataCache(partitionId).coordinatorEpoch) + } }