Skip to content
11 changes: 7 additions & 4 deletions core/src/main/scala/kafka/cluster/Partition.scala
Original file line number Diff line number Diff line change
Expand Up @@ -678,10 +678,13 @@ class Partition(val topicPartition: TopicPartition,
isr: Set[Int],
addingReplicas: Seq[Int],
removingReplicas: Seq[Int]): Unit = {
remoteReplicasMap.clear()
assignment
.filter(_ != localBrokerId)
.foreach(id => remoteReplicasMap.getAndMaybePut(id, new Replica(id, topicPartition)))
val newRemoteReplicas = assignment.filter(_ != localBrokerId)
val removedReplicas = remoteReplicasMap.keys.filter(!newRemoteReplicas.contains(_))

// due to code paths accessing remoteReplicasMap without a lock,
// first add the new replicas and then remove the old ones
newRemoteReplicas.foreach(id => remoteReplicasMap.getAndMaybePut(id, new Replica(id, topicPartition)))
remoteReplicasMap.removeAll(removedReplicas)

if (addingReplicas.nonEmpty || removingReplicas.nonEmpty)
assignmentState = OngoingReassignmentState(addingReplicas, removingReplicas, assignment)
Expand Down
2 changes: 2 additions & 0 deletions core/src/main/scala/kafka/utils/Pool.scala
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ class Pool[K,V](valueFactory: Option[K => V] = None) extends Iterable[(K, V)] {

def remove(key: K, value: V): Boolean = pool.remove(key, value)

/** Removes every entry whose key appears in `keys`, delegating to the backing map's key-set view.
  * NOTE(review): keys are presumably removed one at a time by the underlying set, so the bulk
  * removal is not atomic with respect to concurrent readers — confirm callers tolerate this. */
def removeAll(keys: Iterable[K]): Unit = pool.keySet.removeAll(keys.asJavaCollection)

def keys: Set[K] = pool.keySet.asScala

def values: Iterable[V] = pool.values.asScala
Expand Down
51 changes: 51 additions & 0 deletions core/src/test/scala/unit/kafka/cluster/PartitionLockTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import org.mockito.ArgumentMatchers
import org.mockito.Mockito.{mock, when}

import scala.jdk.CollectionConverters._
import scala.concurrent.duration._

/**
* Verifies that slow appends to log don't block request threads processing replica fetch requests.
Expand Down Expand Up @@ -116,6 +117,56 @@ class PartitionLockTest extends Logging {
future.get(15, TimeUnit.SECONDS)
}

/**
 * Concurrently calling updateAssignmentAndIsr should always ensure that non-lock access
 * to the inner remoteReplicasMap (accessed by getReplica) cannot see an intermediate state
 * where replicas present both in the old and new assignment are missing
 */
@Test
def testGetReplicaWithUpdateAssignmentAndIsr(): Unit = {
  val active = new AtomicBoolean(true)
  // Replica 3 is a member of BOTH assignments below, so it must be visible at every instant
  val replicaToCheck = 3
  val firstReplicaSet = Seq[Integer](3, 4, 5).asJava
  val secondReplicaSet = Seq[Integer](1, 2, 3).asJava
  // Builds a minimal LeaderAndIsr state where the first replica in the list is the leader
  def partitionState(replicas: java.util.List[Integer]) = new LeaderAndIsrPartitionState()
    .setControllerEpoch(1)
    .setLeader(replicas.get(0))
    .setLeaderEpoch(1)
    .setIsr(replicas)
    .setZkVersion(1)
    .setReplicas(replicas)
    .setIsNew(true)
  val offsetCheckpoints: OffsetCheckpoints = mock(classOf[OffsetCheckpoints])
  // Update replica set synchronously first to avoid race conditions
  partition.makeLeader(partitionState(secondReplicaSet), offsetCheckpoints)
  assertTrue(s"Expected replica $replicaToCheck to be defined", partition.getReplica(replicaToCheck).isDefined)

  val future = executorService.submit((() => {
    var i = 0
    // Flip assignment between two replica sets
    while (active.get) {
      val replicas = if (i % 2 == 0) {
        firstReplicaSet
      } else {
        secondReplicaSet
      }

      partition.makeLeader(partitionState(replicas), offsetCheckpoints)

      i += 1
      Thread.sleep(1) // just to avoid tight loop
    }
  }): Runnable)

  // Hammer the unlocked read path for ~1s while the assignment flips in the background
  val deadline = 1.seconds.fromNow
  while (deadline.hasTimeLeft()) {
    assertTrue(s"Expected replica $replicaToCheck to be defined", partition.getReplica(replicaToCheck).isDefined)
  }
  active.set(false)
  future.get(5, TimeUnit.SECONDS)
  assertTrue(s"Expected replica $replicaToCheck to be defined", partition.getReplica(replicaToCheck).isDefined)
}

/**
* Perform concurrent appends and replica fetch requests that don't require write lock to
* update follower state. Release sufficient append permits to complete all except one append.
Expand Down
40 changes: 40 additions & 0 deletions core/src/test/scala/unit/kafka/utils/PoolTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package kafka.utils

import org.junit.Assert.assertEquals
import org.junit.Test


class PoolTest {
  /**
   * removeAll must delete exactly the entries whose keys were passed in,
   * leaving every other mapping untouched.
   */
  @Test
  def testRemoveAll(): Unit = {
    val pool = new Pool[Int, String]
    for (key <- 1 to 3)
      pool.put(key, key.toString)

    assertEquals(3, pool.size)

    pool.removeAll(Seq(1, 2))
    assertEquals(1, pool.size)
    assertEquals("3", pool.get(3))

    pool.removeAll(Seq(3))
    assertEquals(0, pool.size)
  }
}