diff --git a/core/src/main/scala/kafka/cluster/Partition.scala b/core/src/main/scala/kafka/cluster/Partition.scala
index 9e8edaa5ea4e2..fb0576ee5648c 100755
--- a/core/src/main/scala/kafka/cluster/Partition.scala
+++ b/core/src/main/scala/kafka/cluster/Partition.scala
@@ -678,10 +678,13 @@ class Partition(val topicPartition: TopicPartition,
                              isr: Set[Int],
                              addingReplicas: Seq[Int],
                              removingReplicas: Seq[Int]): Unit = {
-    remoteReplicasMap.clear()
-    assignment
-      .filter(_ != localBrokerId)
-      .foreach(id => remoteReplicasMap.getAndMaybePut(id, new Replica(id, topicPartition)))
+    val newRemoteReplicas = assignment.filter(_ != localBrokerId)
+    val removedReplicas = remoteReplicasMap.keys.filter(!newRemoteReplicas.contains(_))
+
+    // due to code paths accessing remoteReplicasMap without a lock,
+    // first add the new replicas and then remove the old ones
+    newRemoteReplicas.foreach(id => remoteReplicasMap.getAndMaybePut(id, new Replica(id, topicPartition)))
+    remoteReplicasMap.removeAll(removedReplicas)
 
     if (addingReplicas.nonEmpty || removingReplicas.nonEmpty)
       assignmentState = OngoingReassignmentState(addingReplicas, removingReplicas, assignment)
diff --git a/core/src/main/scala/kafka/utils/Pool.scala b/core/src/main/scala/kafka/utils/Pool.scala
index 0a1531ba4ad04..d64ff5d6538ff 100644
--- a/core/src/main/scala/kafka/utils/Pool.scala
+++ b/core/src/main/scala/kafka/utils/Pool.scala
@@ -69,6 +69,8 @@ class Pool[K,V](valueFactory: Option[K => V] = None) extends Iterable[(K, V)] {
 
   def remove(key: K, value: V): Boolean = pool.remove(key, value)
 
+  def removeAll(keys: Iterable[K]): Unit = pool.keySet.removeAll(keys.asJavaCollection)
+
   def keys: Set[K] = pool.keySet.asScala
 
   def values: Iterable[V] = pool.values.asScala
diff --git a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
index 403eebec30dfb..8dd3b53f81fac 100644
--- a/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
+++ b/core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
@@ -36,6 +36,7 @@ import org.mockito.ArgumentMatchers
 import org.mockito.Mockito.{mock, when}
 
 import scala.jdk.CollectionConverters._
+import scala.concurrent.duration._
 
 /**
  * Verifies that slow appends to log don't block request threads processing replica fetch requests.
@@ -116,6 +117,56 @@ class PartitionLockTest extends Logging {
     future.get(15, TimeUnit.SECONDS)
   }
 
+  /**
+   * Concurrently calling updateAssignmentAndIsr should always ensure that non-lock access
+   * to the inner remoteReplicasMap (accessed by getReplica) cannot see an intermediate state
+   * where replicas present in both the old and new assignment are missing
+   */
+  @Test
+  def testGetReplicaWithUpdateAssignmentAndIsr(): Unit = {
+    val active = new AtomicBoolean(true)
+    val replicaToCheck = 3
+    val firstReplicaSet = Seq[Integer](3, 4, 5).asJava
+    val secondReplicaSet = Seq[Integer](1, 2, 3).asJava
+    def partitionState(replicas: java.util.List[Integer]) = new LeaderAndIsrPartitionState()
+      .setControllerEpoch(1)
+      .setLeader(replicas.get(0))
+      .setLeaderEpoch(1)
+      .setIsr(replicas)
+      .setZkVersion(1)
+      .setReplicas(replicas)
+      .setIsNew(true)
+    val offsetCheckpoints: OffsetCheckpoints = mock(classOf[OffsetCheckpoints])
+    // Update replica set synchronously first to avoid race conditions
+    partition.makeLeader(partitionState(secondReplicaSet), offsetCheckpoints)
+    assertTrue(s"Expected replica $replicaToCheck to be defined", partition.getReplica(replicaToCheck).isDefined)
+
+    val future = executorService.submit((() => {
+      var i = 0
+      // Flip assignment between two replica sets
+      while (active.get) {
+        val replicas = if (i % 2 == 0) {
+          firstReplicaSet
+        } else {
+          secondReplicaSet
+        }
+
+        partition.makeLeader(partitionState(replicas), offsetCheckpoints)
+
+        i += 1
+        Thread.sleep(1) // just to avoid tight loop
+      }
+    }): Runnable)
+
+    val deadline = 1.seconds.fromNow
+    while (deadline.hasTimeLeft()) {
+      assertTrue(s"Expected replica $replicaToCheck to be defined", partition.getReplica(replicaToCheck).isDefined)
+    }
+    active.set(false)
+    future.get(5, TimeUnit.SECONDS)
+    assertTrue(s"Expected replica $replicaToCheck to be defined", partition.getReplica(replicaToCheck).isDefined)
+  }
+
   /**
    * Perform concurrent appends and replica fetch requests that don't require write lock to
    * update follower state. Release sufficient append permits to complete all except one append.
diff --git a/core/src/test/scala/unit/kafka/utils/PoolTest.scala b/core/src/test/scala/unit/kafka/utils/PoolTest.scala
new file mode 100644
index 0000000000000..74751806a11ef
--- /dev/null
+++ b/core/src/test/scala/unit/kafka/utils/PoolTest.scala
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package kafka.utils
+
+import org.junit.Assert.assertEquals
+import org.junit.Test
+
+
+class PoolTest {
+  @Test
+  def testRemoveAll(): Unit = {
+    val pool = new Pool[Int, String]
+    pool.put(1, "1")
+    pool.put(2, "2")
+    pool.put(3, "3")
+
+    assertEquals(3, pool.size)
+
+    pool.removeAll(Seq(1, 2))
+    assertEquals(1, pool.size)
+    assertEquals("3", pool.get(3))
+    pool.removeAll(Seq(3))
+    assertEquals(0, pool.size)
+  }
+}