diff --git a/feature/dashboard/src/test/java/com/simprints/feature/dashboard/tools/di/FakeCoreModule.kt b/feature/dashboard/src/test/java/com/simprints/feature/dashboard/tools/di/FakeCoreModule.kt index f080af11e0..6e1578218c 100644 --- a/feature/dashboard/src/test/java/com/simprints/feature/dashboard/tools/di/FakeCoreModule.kt +++ b/feature/dashboard/src/test/java/com/simprints/feature/dashboard/tools/di/FakeCoreModule.kt @@ -1,6 +1,7 @@ package com.simprints.feature.dashboard.tools.di import com.simprints.core.AppScope +import com.simprints.core.AvailableProcessors import com.simprints.core.CoreModule import com.simprints.core.DeviceID import com.simprints.core.DispatcherBG @@ -52,6 +53,10 @@ object FakeCoreModule { @Provides fun providePackageVersionName(): String = PACKAGE_VERSION_NAME + @AvailableProcessors + @Provides + fun provideAvailableProcessors(): Int = 4 + @DispatcherIO @Provides fun provideCoroutineDispatcherIo(): CoroutineDispatcher = StandardTestDispatcher() diff --git a/feature/matcher/src/main/java/com/simprints/matcher/usecases/CreateRangesUseCase.kt b/feature/matcher/src/main/java/com/simprints/matcher/usecases/CreateRangesUseCase.kt index e85412449b..dd1ed5f129 100644 --- a/feature/matcher/src/main/java/com/simprints/matcher/usecases/CreateRangesUseCase.kt +++ b/feature/matcher/src/main/java/com/simprints/matcher/usecases/CreateRangesUseCase.kt @@ -1,45 +1,53 @@ package com.simprints.matcher.usecases +import com.simprints.core.AvailableProcessors import javax.inject.Inject +import kotlin.math.ceil -internal class CreateRangesUseCase @Inject constructor() { +internal class CreateRangesUseCase @Inject constructor( + @AvailableProcessors private val availableProcessors: Int, +) { /** * Creates a list of ranges to be used for batch processing. - * Range size is increased dynamically to ensure that first couple of batches are small - * to speed up initial reads, then it increases to ensure that the last batches are not too small. - * - * For example with minBatchSize = 10, returned batches will be 10, 10, 20, 30, 40, 50, 50 in size (if the total allows). + * The number of ranges will be a multiple of the available processors to ensure + * efficient parallel processing. + * Range sizes are adjusted to not exceed MAX_BATCH_SIZE. */ operator fun invoke( totalCount: Int, - minBatchSize: Int = DEFAULT_BATCH_SIZE, ): List { - val ranges = mutableListOf() - var index = 1 + if (totalCount <= 0) return emptyList() + + // Calculate how many multiples of processors we need to keep batches under MAX_BATCH_SIZE + val batchesPerProcessor = ceil(totalCount.toDouble() / (availableProcessors * MAX_BATCH_SIZE)).toInt().coerceAtLeast(1) + val totalBatches = availableProcessors * batchesPerProcessor + + // Calculate the base batch size and remainder for even distribution + val baseBatchSize = totalCount / totalBatches + val remainder = totalCount % totalBatches - var nextBatchSize = minBatchSize + val ranges = mutableListOf() var start = 0 - var end = nextBatchSize - while (start < totalCount) { - if (end > totalCount) { - end = totalCount - } - ranges.add(start..end) + // Create ranges with sizes distributed as evenly as possible + for (i in 0 until totalBatches) { + // Add 1 to batch size for the first 'remainder' batches to distribute the remainder evenly + val batchSize = baseBatchSize + if (i < remainder) 1 else 0 + val end = (start + batchSize).coerceAtMost(totalCount) + + ranges.add(start until end) start = end - end += nextBatchSize - // Make sure next batch is increased - nextBatchSize = minBatchSize + (minBatchSize * index.coerceIn(1, 4)) - index++ + if (start >= totalCount) break } + return ranges } companion object { /** - * Experimentally determined batch size that works well for most cases. + * Maximum size for a batch to avoid huge memory consumption */ - private const val DEFAULT_BATCH_SIZE = 1000 + private const val MAX_BATCH_SIZE = 2000 } } diff --git a/feature/matcher/src/main/java/com/simprints/matcher/usecases/FaceMatcherUseCase.kt b/feature/matcher/src/main/java/com/simprints/matcher/usecases/FaceMatcherUseCase.kt index 5077c81e0e..6395c5ec20 100644 --- a/feature/matcher/src/main/java/com/simprints/matcher/usecases/FaceMatcherUseCase.kt +++ b/feature/matcher/src/main/java/com/simprints/matcher/usecases/FaceMatcherUseCase.kt @@ -1,6 +1,8 @@ package com.simprints.matcher.usecases +import com.simprints.core.AvailableProcessors import com.simprints.core.DispatcherBG +import com.simprints.core.DispatcherIO import com.simprints.face.infra.basebiosdk.matching.FaceIdentity import com.simprints.face.infra.basebiosdk.matching.FaceMatcher import com.simprints.face.infra.basebiosdk.matching.FaceSample @@ -17,7 +19,9 @@ import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.channelFlow +import kotlinx.coroutines.flow.flowOn import kotlinx.coroutines.launch +import java.util.concurrent.atomic.AtomicInteger import javax.inject.Inject import kotlin.math.min import com.simprints.infra.enrolment.records.repository.domain.models.FaceIdentity as DomainFaceIdentity @@ -26,16 +30,12 @@ internal class FaceMatcherUseCase @Inject constructor( private val enrolmentRecordRepository: EnrolmentRecordRepository, private val resolveFaceBioSdk: ResolveFaceBioSdkUseCase, private val createRanges: CreateRangesUseCase, + @AvailableProcessors private val availableProcessors: Int, @DispatcherBG private val dispatcherBG: CoroutineDispatcher, + @DispatcherIO private val dispatcherIO: CoroutineDispatcher, ) : MatcherUseCase { override val crashReportTag = LoggingConstants.CrashReportTag.FACE_MATCHING - // When using local DB loadedCandidates = expectedCandidates - // However, when using CommCare as data source, loadedCandidates < expectedCandidates - // as it's count function does not take into account filtering criteria - // This var is not thread safe - var loadedCandidates = 0 - override suspend operator fun invoke( matchParams: MatchParams, project: Project, @@ -59,12 +59,17 @@ internal class FaceMatcherUseCase @Inject constructor( send(MatcherState.Success(emptyList(), 0, bioSdk.matcherName)) return@channelFlow } - loadedCandidates = 0 + Simber.i("Matching candidates", tag = crashReportTag) send(MatcherState.LoadingStarted(expectedCandidates)) + + // When using local DB loadedCandidates = expectedCandidates + // However, when using CommCare as data source, loadedCandidates < expectedCandidates + // as it's count function does not take into account filtering criteria + val loadedCandidates = AtomicInteger(0) val ranges = createRanges(expectedCandidates) // if number of ranges less than the number of cores then use the number of ranges - val numConsumers = min(Runtime.getRuntime().availableProcessors(), ranges.size) + val numConsumers = min(availableProcessors, ranges.size) val resultSet = MatchResultSet() val candidatesChannel = enrolmentRecordRepository @@ -75,7 +80,7 @@ internal class FaceMatcherUseCase @Inject constructor( project = project, scope = this, onCandidateLoaded = { - loadedCandidates++ + loadedCandidates.incrementAndGet() this.trySend(MatcherState.CandidateLoaded) }, ) @@ -88,8 +93,8 @@ internal class FaceMatcherUseCase @Inject constructor( } // Wait for all to complete consumerJobs.forEach { it.join() } - send(MatcherState.Success(resultSet.toList(), loadedCandidates, bioSdk.matcherName)) - } + send(MatcherState.Success(resultSet.toList(), loadedCandidates.get(), bioSdk.matcherName)) + }.flowOn(dispatcherIO) suspend fun consumeAndMatch( candidatesChannel: ReceiveChannel>, diff --git a/feature/matcher/src/main/java/com/simprints/matcher/usecases/FingerprintMatcherUseCase.kt b/feature/matcher/src/main/java/com/simprints/matcher/usecases/FingerprintMatcherUseCase.kt index 5d6317487e..b92f4e6162 100644 --- a/feature/matcher/src/main/java/com/simprints/matcher/usecases/FingerprintMatcherUseCase.kt +++ b/feature/matcher/src/main/java/com/simprints/matcher/usecases/FingerprintMatcherUseCase.kt @@ -1,6 +1,8 @@ package com.simprints.matcher.usecases +import com.simprints.core.AvailableProcessors import com.simprints.core.DispatcherBG +import com.simprints.core.DispatcherIO import com.simprints.core.domain.common.FlowType import com.simprints.core.domain.fingerprint.IFingerIdentifier import com.simprints.fingerprint.infra.basebiosdk.matching.domain.FingerIdentifier @@ -22,7 +24,9 @@ import kotlinx.coroutines.CoroutineDispatcher import kotlinx.coroutines.channels.ReceiveChannel import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.channelFlow +import kotlinx.coroutines.flow.flowOn import kotlinx.coroutines.launch +import java.util.concurrent.atomic.AtomicInteger import javax.inject.Inject import kotlin.math.min import com.simprints.infra.enrolment.records.repository.domain.models.FingerprintIdentity as DomainFingerprintIdentity @@ -32,16 +36,12 @@ internal class FingerprintMatcherUseCase @Inject constructor( private val resolveBioSdkWrapper: ResolveBioSdkWrapperUseCase, private val configManager: ConfigManager, private val createRanges: CreateRangesUseCase, + @AvailableProcessors private val availableProcessors: Int, @DispatcherBG private val dispatcherBG: CoroutineDispatcher, + @DispatcherIO private val dispatcherIO: CoroutineDispatcher, ) : MatcherUseCase { override val crashReportTag = LoggingConstants.CrashReportTag.FINGER_MATCHING - // When using local DB loadedCandidates = expectedCandidates - // However, when using CommCare as data source, loadedCandidates < expectedCandidates - // as it's count function does not take into account filtering criteria - // This var is not thread safe - var loadedCandidates = 0 - override suspend operator fun invoke( matchParams: MatchParams, project: Project, @@ -67,10 +67,14 @@ internal class FingerprintMatcherUseCase @Inject constructor( Simber.i("Matching candidates", tag = crashReportTag) send(MatcherState.LoadingStarted(expectedCandidates)) - loadedCandidates = 0 + + // When using local DB loadedCandidates = expectedCandidates + // However, when using CommCare as data source, loadedCandidates < expectedCandidates + // as it's count function does not take into account filtering criteria + val loadedCandidates = AtomicInteger(0) val ranges = createRanges(expectedCandidates) // if number of ranges less than the number of cores then use the number of ranges - val numConsumers = min(Runtime.getRuntime().availableProcessors(), ranges.size) + val numConsumers = min(availableProcessors, ranges.size) val channel = enrolmentRecordRepository.loadFingerprintIdentities( query = queryWithSupportedFormat, ranges = ranges, @@ -78,7 +82,7 @@ internal class FingerprintMatcherUseCase @Inject constructor( scope = this, project = project, ) { - loadedCandidates++ + loadedCandidates.incrementAndGet() trySend(MatcherState.CandidateLoaded) } @@ -94,8 +98,8 @@ internal class FingerprintMatcherUseCase @Inject constructor( consumerJobs.forEach { it.join() } Simber.i("Matched $loadedCandidates candidates", tag = crashReportTag) - send(MatcherState.Success(resultSet.toList(), loadedCandidates, bioSdkWrapper.matcherName)) - } + send(MatcherState.Success(resultSet.toList(), loadedCandidates.get(), bioSdkWrapper.matcherName)) + }.flowOn(dispatcherIO) private suspend fun consumeAndMatch( channel: ReceiveChannel>, diff --git a/feature/matcher/src/main/java/com/simprints/matcher/usecases/MatchResultSet.kt b/feature/matcher/src/main/java/com/simprints/matcher/usecases/MatchResultSet.kt index d91c869eb4..928d028397 100644 --- a/feature/matcher/src/main/java/com/simprints/matcher/usecases/MatchResultSet.kt +++ b/feature/matcher/src/main/java/com/simprints/matcher/usecases/MatchResultSet.kt @@ -1,40 +1,48 @@ package com.simprints.matcher.usecases import com.simprints.matcher.MatchResultItem -import java.util.TreeSet +import java.util.concurrent.ConcurrentSkipListSet +import java.util.concurrent.atomic.AtomicReference +import java.util.concurrent.locks.ReentrantLock +import kotlin.concurrent.withLock internal class MatchResultSet( private val maxSize: Int = MAX_RESULTS, ) { - private var lowestConfidence: Float = 0f + private val lowestConfidence = AtomicReference(0f) + private val lock = ReentrantLock() - private val treeSet = TreeSet( + private val skipListSet = ConcurrentSkipListSet( compareByDescending { it.confidence }.thenByDescending { it.subjectId }, ) fun add(element: T): MatchResultSet { - if (lowestConfidence > element.confidence) { - // skip adding if the last element is greater than the current element + // Use a lock to ensure thread safety during the entire add operation + lock.withLock { + // Only perform this optimization when we know the set is at max capacity + if (skipListSet.size >= maxSize && lowestConfidence.get() > element.confidence) { + // skip adding if the set is full and the last element has higher confidence than the current element + return this + } + + skipListSet.add(element) + if (skipListSet.size > maxSize) { + skipListSet.pollLast() + + // Now that the set is full, we can skip adding elements + // with confidence lower than the current lowest + lowestConfidence.set(skipListSet.last().confidence) + } return this } - - treeSet.add(element) - if (treeSet.size > maxSize) { - treeSet.pollLast() - - // Not that the set is full, we can skip adding elements - // with confidence lower than the current lowest - lowestConfidence = treeSet.last().confidence - } - return this } fun addAll(elements: MatchResultSet): MatchResultSet { - elements.treeSet.forEach { add(it) } + elements.skipListSet.forEach { add(it) } return this } - fun toList(): List = treeSet.toList() + fun toList(): List = skipListSet.toList() companion object { /** diff --git a/feature/matcher/src/test/java/com/simprints/matcher/usecases/CreateRangesUseCaseTest.kt b/feature/matcher/src/test/java/com/simprints/matcher/usecases/CreateRangesUseCaseTest.kt index f0e70e982b..e0b73d2a98 100644 --- a/feature/matcher/src/test/java/com/simprints/matcher/usecases/CreateRangesUseCaseTest.kt +++ b/feature/matcher/src/test/java/com/simprints/matcher/usecases/CreateRangesUseCaseTest.kt @@ -7,45 +7,382 @@ import org.junit.runners.JUnit4 @RunWith(JUnit4::class) class CreateRangesUseCaseTest { - private val useCase = CreateRangesUseCase() @Test - fun `Returns empty list if no total`() { - assertThat(useCase.invoke(0, 5)).isEqualTo(emptyList()) + fun `should create correct ranges when numCandidates equals numConsumers`() { + // Given + val useCase = CreateRangesUseCase(availableProcessors = 5) + + // When + val result = useCase(5) + + // Then + assertThat(result).containsExactly( + 0 until 1, + 1 until 2, + 2 until 3, + 3 until 4, + 4 until 5 + ).inOrder() + } + + @Test + fun `should create correct ranges when numCandidates is greater than numConsumers`() { + // Given + val useCase = CreateRangesUseCase(availableProcessors = 3) + + // When + val result = useCase(10) + + // Then + assertThat(result).containsExactly( + 0 until 4, + 4 until 7, + 7 until 10 + ).inOrder() + } + + @Test + fun `should handle single item`() { + // Given + val useCase = CreateRangesUseCase(availableProcessors = 4) + + // When + val result = useCase(1) + + // Then + assertThat(result).containsExactly(0 until 1).inOrder() + } + + @Test + fun `should handle totalCount equal to MAX_BATCH_SIZE`() { + // Given + val useCase = CreateRangesUseCase(availableProcessors = 1) + + // When + val result = useCase(2000) + + // Then + assertThat(result).containsExactly(0 until 2000) + } + + @Test + fun `should handle batch sizes that are exactly MAX_BATCH_SIZE`() { + // Given + val useCase = CreateRangesUseCase(availableProcessors = 2) + + // When + val result = useCase(4000) + + // Then + assertThat(result).containsExactly( + 0 until 2000, + 2000 until 4000 + ).inOrder() + } + + @Test + fun `should create correct ranges when numCandidates is less than numConsumers`() { + // Given + val useCase = CreateRangesUseCase(availableProcessors = 5) + + // When + val result = useCase(3) + + // Then + assertThat(result).containsExactly( + 0 until 1, + 1 until 2, + 2 until 3 + ).inOrder() + } + + @Test + fun `should create correct ranges with uneven distribution`() { + // Given + val useCase = CreateRangesUseCase(availableProcessors = 3) + + // When + val result = useCase(11) + + // Then + assertThat(result).containsExactly( + 0 until 4, + 4 until 8, + 8 until 11 + ).inOrder() + } + + @Test + fun `should create empty list when numCandidates is zero`() { + // Given + val useCase = CreateRangesUseCase(availableProcessors = 5) + + // When + val result = useCase(0) + + // Then + assertThat(result).isEmpty() + } + + @Test + fun `should create single range when numConsumers is one`() { + // Given + val useCase = CreateRangesUseCase(availableProcessors = 1) + + // When + val result = useCase(10) + + // Then + assertThat(result).containsExactly(0 until 10) + } + + @Test + fun `should handle large numbers correctly`() { + // Given + val useCase = CreateRangesUseCase(availableProcessors = 4) + + // When + val result = useCase(1000) + + // Then + assertThat(result).containsExactly( + 0 until 250, + 250 until 500, + 500 until 750, + 750 until 1000 + ).inOrder() } @Test - fun `Returns list if single item`() { - assertThat(useCase.invoke(1, 5)).isEqualTo(listOf(0..1)) + fun `should handle 2500 candidates with 4 processors`() { + // Given + val useCase = CreateRangesUseCase(availableProcessors = 4) + + // When + val result = useCase(2500) + + // Then + // 4 processors, batches under 2000 each, so 4 total batches + // Base size = 2500/4 = 625, remainder = 0 + assertThat(result).containsExactly( + 0 until 625, + 625 until 1250, + 1250 until 1875, + 1875 until 2500 + ).inOrder() } @Test - fun `Returns single item if max withing single batch`() { - assertThat(useCase.invoke(20, 25)).isEqualTo(listOf(0..20)) + fun `should handle 5000 candidates with 4 processors`() { + // Given + val useCase = CreateRangesUseCase(availableProcessors = 4) + + // When + val result = useCase(5000) + + // Then + // 4 processors, base size approaching MAX_BATCH_SIZE, so still 4 batches + // Base size = 5000/4 = 1250, remainder = 0 + assertThat(result).containsExactly( + 0 until 1250, + 1250 until 2500, + 2500 until 3750, + 3750 until 5000 + ).inOrder() } @Test - fun `Correctly calculates last batch reminder`() { - assertThat(useCase.invoke(17, 10)).isEqualTo( - listOf( - 0..10, - 10..17, - ), - ) + fun `should handle 10000 candidates with 8 processors`() { + // Given + val useCase = CreateRangesUseCase(availableProcessors = 8) + + // When + val result = useCase(10000) + + // Then + // 8 processors, base size = 10000/8 = 1250, remainder = 0 + assertThat(result).containsExactly( + 0 until 1250, + 1250 until 2500, + 2500 until 3750, + 3750 until 5000, + 5000 until 6250, + 6250 until 7500, + 7500 until 8750, + 8750 until 10000 + ).inOrder() } @Test - fun `Correctly calculates ranges for exact batches`() { - assertThat(useCase.invoke(210, 10)).isEqualTo( - listOf( - 0..10, // size=10 - 10..20, // size=10 - 20..40, // size=20 - 40..70, // size=30 - 70..110, // size=40 - 110..160, // size=50 - 160..210, // size=50 - ), - ) + fun `should limit batch size to 2000 for 15000 candidates with 4 processors`() { + // Given + val useCase = CreateRangesUseCase(availableProcessors = 4) + + // When + val result = useCase(15000) + + // Then + // Each processor would get 15000/4 = 3750 items, exceeding MAX_BATCH_SIZE + // Need ceiling(15000/(4*2000)) = 2 batches per processor = 8 total batches + // Base size = 15000/8 = 1875, remainder = 0 + assertThat(result).containsExactly( + 0 until 1875, + 1875 until 3750, + 3750 until 5625, + 5625 until 7500, + 7500 until 9375, + 9375 until 11250, + 11250 until 13125, + 13125 until 15000 + ).inOrder() + } + + @Test + fun `should limit batch size to 2000 for 20000 candidates with 8 processors`() { + // Given + val useCase = CreateRangesUseCase(availableProcessors = 8) + + // When + val result = useCase(20000) + + // Then + // Each processor would get 20000/8 = 2500, exceeding MAX_BATCH_SIZE + // Need ceiling(20000/(8*2000)) = 2 batches per processor = 16 total batches + // Base size = 20000/16 = 1250, remainder = 0 + assertThat(result).containsExactly( + 0 until 1250, + 1250 until 2500, + 2500 until 3750, + 3750 until 5000, + 5000 until 6250, + 6250 until 7500, + 7500 until 8750, + 8750 until 10000, + 10000 until 11250, + 11250 until 12500, + 12500 until 13750, + 13750 until 15000, + 15000 until 16250, + 16250 until 17500, + 17500 until 18750, + 18750 until 20000 + ).inOrder() + } + + @Test + fun `should limit batch size to 2000 for 50000 candidates with 4 processors`() { + // Given + val useCase = CreateRangesUseCase(availableProcessors = 4) + + // When + val result = useCase(50000) + + // Then + // Need ceiling(50000/(4*2000)) = 7 batches per processor = 28 total batches + // Base size = 50000/28 = 1785, remainder = 20 + // First 20 batches get size 1786, remaining 8 batches get size 1785 + assertThat(result).containsExactly( + 0 until 1786, + 1786 until 3572, + 3572 until 5358, + 5358 until 7144, + 7144 until 8930, + 8930 until 10716, + 10716 until 12502, + 12502 until 14288, + 14288 until 16074, + 16074 until 17860, + 17860 until 19646, + 19646 until 21432, + 21432 until 23218, + 23218 until 25004, + 25004 until 26790, + 26790 until 28576, + 28576 until 30362, + 30362 until 32148, + 32148 until 33934, + 33934 until 35720, + 35720 until 37505, + 37505 until 39290, + 39290 until 41075, + 41075 until 42860, + 42860 until 44645, + 44645 until 46430, + 46430 until 48215, + 48215 until 50000 + ).inOrder() + } + + @Test + fun `should limit batch size to 2000 for 100000 candidates with 8 processors`() { + // Given + val useCase = CreateRangesUseCase(availableProcessors = 8) + + // When + val result = useCase(100000) + + // Then + // Need ceiling(100000/(8*2000)) = 7 batches per processor = 56 total batches + // Base size = 100000/56 = 1785, remainder = 40 + // First 40 batches get size 1786, remaining 16 batches get size 1785 + assertThat(result).containsExactly( + 0 until 1786, + 1786 until 3572, + 3572 until 5358, + 5358 until 7144, + 7144 until 8930, + 8930 until 10716, + 10716 until 12502, + 12502 until 14288, + 14288 until 16074, + 16074 until 17860, + 17860 until 19646, + 19646 until 21432, + 21432 until 23218, + 23218 until 25004, + 25004 until 26790, + 26790 until 28576, + 28576 until 30362, + 30362 until 32148, + 32148 until 33934, + 33934 until 35720, + 35720 until 37506, + 37506 until 39292, + 39292 until 41078, + 41078 until 42864, + 42864 until 44650, + 44650 until 46436, + 46436 until 48222, + 48222 until 50008, + 50008 until 51794, + 51794 until 53580, + 53580 until 55366, + 55366 until 57152, + 57152 until 58938, + 58938 until 60724, + 60724 until 62510, + 62510 until 64296, + 64296 until 66082, + 66082 until 67868, + 67868 until 69654, + 69654 until 71440, + 71440 until 73225, + 73225 until 75010, + 75010 until 76795, + 76795 until 78580, + 78580 until 80365, + 80365 until 82150, + 82150 until 83935, + 83935 until 85720, + 85720 until 87505, + 87505 until 89290, + 89290 until 91075, + 91075 until 92860, + 92860 until 94645, + 94645 until 96430, + 96430 until 98215, + 98215 until 100000 + ).inOrder() } } diff --git a/feature/matcher/src/test/java/com/simprints/matcher/usecases/FaceMatcherUseCaseTest.kt b/feature/matcher/src/test/java/com/simprints/matcher/usecases/FaceMatcherUseCaseTest.kt index d2f06f5fcd..49fcc96fc4 100644 --- a/feature/matcher/src/test/java/com/simprints/matcher/usecases/FaceMatcherUseCaseTest.kt +++ b/feature/matcher/src/test/java/com/simprints/matcher/usecases/FaceMatcherUseCaseTest.kt @@ -53,6 +53,8 @@ internal class FaceMatcherUseCaseTest { enrolmentRecordRepository, resolveFaceBioSdk, createRangesUseCase, + 4, + testCoroutineRule.testCoroutineDispatcher, testCoroutineRule.testCoroutineDispatcher, ) } diff --git a/feature/matcher/src/test/java/com/simprints/matcher/usecases/FingerprintMatcherUseCaseTest.kt b/feature/matcher/src/test/java/com/simprints/matcher/usecases/FingerprintMatcherUseCaseTest.kt index 8fb25993a1..465241b33f 100644 --- a/feature/matcher/src/test/java/com/simprints/matcher/usecases/FingerprintMatcherUseCaseTest.kt +++ b/feature/matcher/src/test/java/com/simprints/matcher/usecases/FingerprintMatcherUseCaseTest.kt @@ -66,6 +66,8 @@ internal class FingerprintMatcherUseCaseTest { resolveBioSdkWrapperUseCase, configManager, createRangesUseCase, + 4, + testCoroutineRule.testCoroutineDispatcher, testCoroutineRule.testCoroutineDispatcher, ) } diff --git a/feature/matcher/src/test/java/com/simprints/matcher/usecases/MatchResultSetTest.kt b/feature/matcher/src/test/java/com/simprints/matcher/usecases/MatchResultSetTest.kt index 96ed8951ce..ef37f711e7 100644 --- a/feature/matcher/src/test/java/com/simprints/matcher/usecases/MatchResultSetTest.kt +++ b/feature/matcher/src/test/java/com/simprints/matcher/usecases/MatchResultSetTest.kt @@ -3,6 +3,9 @@ package com.simprints.matcher.usecases import com.google.common.truth.Truth.assertThat import com.simprints.matcher.FingerprintMatchResult import org.junit.Test +import java.util.concurrent.CountDownLatch +import java.util.concurrent.Executors +import java.util.concurrent.TimeUnit class MatchResultSetTest { @Test @@ -65,4 +68,134 @@ class MatchResultSetTest { ), ) } + + @Test + fun `Concurrent add operations maintain thread safety`() { + val set = MatchResultSet(5) + val threadCount = 10 + val elementsPerThread = 20 + val latch = CountDownLatch(1) + + val executor = Executors.newFixedThreadPool(threadCount) + + // Submit tasks to add items concurrently from multiple threads + repeat(threadCount) { threadIndex -> + executor.submit { + try { + // Wait for all threads to be ready + latch.await() + + // Each thread adds its own batch of elements + repeat(elementsPerThread) { i -> + val confidence = (threadIndex * elementsPerThread + i) / 100f + set.add(FingerprintMatchResult.Item("T$threadIndex-$i", confidence)) + } + } catch (e: Exception) { + e.printStackTrace() + } + } + } + + // Release all threads simultaneously + latch.countDown() + + // Shutdown executor and wait for completion + executor.shutdown() + executor.awaitTermination(10, TimeUnit.SECONDS) + + // Verify results + val results = set.toList() + + // Should have exactly 5 items (maxSize) + assertThat(results.size).isEqualTo(5) + + // Should be sorted by confidence descending + for (i in 0 until results.size - 1) { + assertThat(results[i].confidence).isAtLeast(results[i + 1].confidence) + } + + // Verify the highest confidence item is at the top + assertThat(results[0].confidence).isEqualTo(1.99f) + } + + @Test + fun `Concurrent addAll operations maintain thread safety`() { + val targetSet = MatchResultSet(5) + val threadCount = 5 + val latch = CountDownLatch(1) + + // Create source sets with different items + val sourceSets = List(threadCount) { threadIndex -> + MatchResultSet(3).apply { + repeat(5) { i -> + val confidence = 0.5f + (threadIndex * 5 + i) / 100f + add(FingerprintMatchResult.Item("S$threadIndex-$i", confidence)) + } + } + } + + val executor = Executors.newFixedThreadPool(threadCount) + + // Submit tasks to merge sets concurrently + repeat(threadCount) { threadIndex -> + executor.submit { + try { + latch.await() + targetSet.addAll(sourceSets[threadIndex]) + } catch (e: Exception) { + e.printStackTrace() + } + } + } + + // Release all threads simultaneously + latch.countDown() + + // Shutdown executor and wait for completion + executor.shutdown() + executor.awaitTermination(10, TimeUnit.SECONDS) + + // Verify results + val results = targetSet.toList() + + // Should have exactly 5 items (maxSize) + assertThat(results.size).isEqualTo(5) + + // Should be sorted by confidence descending + for (i in 0 until results.size - 1) { + assertThat(results[i].confidence).isAtLeast(results[i + 1].confidence) + } + + // Verify the highest confidence item is at the top + assertThat(results[0].confidence).isEqualTo(0.74f) + } + + @Test + fun `addAll correctly filters elements with lower confidence than current minimum`() { + val set = MatchResultSet(3) + + // Add higher confidence items first to fill the set + set.add(FingerprintMatchResult.Item("A", 0.8f)) + set.add(FingerprintMatchResult.Item("B", 0.7f)) + set.add(FingerprintMatchResult.Item("C", 0.6f)) + + // Try to add a new set with lower confidence items + val lowerSet = MatchResultSet(3) + lowerSet.add(FingerprintMatchResult.Item("D", 0.5f)) + lowerSet.add(FingerprintMatchResult.Item("E", 0.4f)) + lowerSet.add(FingerprintMatchResult.Item("F", 0.3f)) + + // Add one higher item to verify it still gets added + lowerSet.add(FingerprintMatchResult.Item("G", 0.9f)) + + set.addAll(lowerSet) + + // Verify results + val results = set.toList() + + assertThat(results).hasSize(3) + assertThat(results[0].confidence).isEqualTo(0.9f) + assertThat(results[1].confidence).isEqualTo(0.8f) + assertThat(results[2].confidence).isEqualTo(0.7f) + } } diff --git a/infra/core/src/main/java/com/simprints/core/CoreModule.kt b/infra/core/src/main/java/com/simprints/core/CoreModule.kt index 5327583052..ffbae4b1c7 100644 --- a/infra/core/src/main/java/com/simprints/core/CoreModule.kt +++ b/infra/core/src/main/java/com/simprints/core/CoreModule.kt @@ -78,6 +78,10 @@ object CoreModule { @Provides fun provideLibSimprintsVersionName(): String = com.simprints.libsimprints.BuildConfig.LIBRARY_PACKAGE_VERSION + @AvailableProcessors + @Provides + fun provideAvailableProcessors(): Int = Runtime.getRuntime().availableProcessors() + @DispatcherIO @Provides fun provideDispatcherIo(): CoroutineDispatcher = Dispatchers.IO @@ -152,6 +156,10 @@ annotation class DispatcherIO @Retention(AnnotationRetention.BINARY) annotation class DispatcherBG +@Qualifier +@Retention(AnnotationRetention.BINARY) +annotation class AvailableProcessors + @Qualifier @Retention(AnnotationRetention.BINARY) annotation class DispatcherMain diff --git a/infra/enrolment-records/repository/src/main/java/com/simprints/infra/enrolment/records/repository/EnrolmentRecordsStoreModule.kt b/infra/enrolment-records/repository/src/main/java/com/simprints/infra/enrolment/records/repository/EnrolmentRecordsStoreModule.kt index 7074fd5388..5aca8df5ad 100644 --- a/infra/enrolment-records/repository/src/main/java/com/simprints/infra/enrolment/records/repository/EnrolmentRecordsStoreModule.kt +++ b/infra/enrolment-records/repository/src/main/java/com/simprints/infra/enrolment/records/repository/EnrolmentRecordsStoreModule.kt @@ -1,6 +1,7 @@ package com.simprints.infra.enrolment.records.repository import android.content.Context +import com.simprints.core.AvailableProcessors import com.simprints.core.DispatcherIO import com.simprints.core.tools.json.JsonHelper import com.simprints.core.tools.utils.EncodingUtils @@ -45,12 +46,14 @@ class IdentityDataSourceModule { encoder: EncodingUtils, jsonHelper: JsonHelper, compareImplicitTokenizedStringsUseCase: CompareImplicitTokenizedStringsUseCase, + @AvailableProcessors availableProcessors: Int, @ApplicationContext context: Context, @DispatcherIO dispatcher: CoroutineDispatcher, ): IdentityDataSource = CommCareIdentityDataSource( encoder = encoder, jsonHelper = jsonHelper, compareImplicitTokenizedStringsUseCase = compareImplicitTokenizedStringsUseCase, + availableProcessors = availableProcessors, context = context, dispatcher = dispatcher, ) diff --git a/infra/enrolment-records/repository/src/main/java/com/simprints/infra/enrolment/records/repository/commcare/CommCareIdentityDataSource.kt b/infra/enrolment-records/repository/src/main/java/com/simprints/infra/enrolment/records/repository/commcare/CommCareIdentityDataSource.kt index 1a2b99cefc..7a3f58a4a3 100644 --- a/infra/enrolment-records/repository/src/main/java/com/simprints/infra/enrolment/records/repository/commcare/CommCareIdentityDataSource.kt +++ b/infra/enrolment-records/repository/src/main/java/com/simprints/infra/enrolment/records/repository/commcare/CommCareIdentityDataSource.kt @@ -7,6 +7,7 @@ import androidx.core.net.toUri import com.fasterxml.jackson.core.type.TypeReference import com.fasterxml.jackson.databind.module.SimpleModule import com.simprints.core.DispatcherBG +import com.simprints.core.AvailableProcessors import com.simprints.core.domain.face.FaceSample import com.simprints.core.domain.fingerprint.FingerprintSample import com.simprints.core.domain.tokenization.TokenizableString @@ -41,6 +42,7 @@ internal class CommCareIdentityDataSource @Inject constructor( private val encoder: EncodingUtils, private val jsonHelper: JsonHelper, private val compareImplicitTokenizedStringsUseCase: CompareImplicitTokenizedStringsUseCase, + @AvailableProcessors private val availableProcessors: Int, @ApplicationContext private val context: Context, @DispatcherBG private val dispatcher: CoroutineDispatcher, ) : IdentityDataSource { @@ -54,11 +56,10 @@ internal class CommCareIdentityDataSource @Inject constructor( dataSource: BiometricDataSource, project: Project, onCandidateLoaded: () -> Unit, - ): List = loadEnrolmentRecordCreationEvents(range, dataSource.callerPackageName(), query, project) + ): List = loadEnrolmentRecordCreationEvents(range, dataSource.callerPackageName(), query, project, onCandidateLoaded) .filter { erce -> erce.payload.biometricReferences.any { it is FingerprintReference && it.format == query.fingerprintSampleFormat } }.map { - onCandidateLoaded() FingerprintIdentity( it.payload.subjectId, it.payload.biometricReferences.filterIsInstance().flatMap { fingerprintReference -> @@ -79,6 +80,7 @@ internal class CommCareIdentityDataSource @Inject constructor( callerPackageName: String, query: SubjectQuery, project: Project, + onCandidateLoaded: () -> Unit, ): List { val enrolmentRecordCreationEvents: MutableList = mutableListOf() try { @@ -100,9 +102,9 @@ internal class CommCareIdentityDataSource @Inject constructor( caseMetadataCursor.getString(caseMetadataCursor.getColumnIndexOrThrow(COLUMN_CASE_ID))?.let { caseId -> enrolmentRecordCreationEvents.addAll( loadEnrolmentRecordCreationEvents(caseId, callerPackageName, query, project), - ) + ).also { onCandidateLoaded() } } - } while (caseMetadataCursor.moveToNext() && caseMetadataCursor.position < range.last) + } while (caseMetadataCursor.moveToNext() && caseMetadataCursor.position <= range.last) } } } catch (e: Exception) { @@ -126,11 +128,10 @@ internal class CommCareIdentityDataSource @Inject constructor( dataSource: BiometricDataSource, project: Project, onCandidateLoaded: () -> Unit, - ): List = loadEnrolmentRecordCreationEvents(range, dataSource.callerPackageName(), query, project) + ): List = loadEnrolmentRecordCreationEvents(range, dataSource.callerPackageName(), query, project, onCandidateLoaded) .filter { erce -> erce.payload.biometricReferences.any { it is FaceReference && it.format == query.faceSampleFormat } }.map { - onCandidateLoaded() FaceIdentity( it.payload.subjectId, it.payload.biometricReferences.filterIsInstance().flatMap { faceReference -> @@ -157,7 +158,7 @@ internal class CommCareIdentityDataSource @Inject constructor( return context.contentResolver .query(caseDataUri, null, null, null, null) ?.use { caseDataCursor -> - var subjectActions = getSubjectActionsValue(caseDataCursor) + val subjectActions = getSubjectActionsValue(caseDataCursor) Simber.d(subjectActions) val coSyncEnrolmentRecordEvents = parseRecordEvents(subjectActions) @@ -258,8 +259,6 @@ internal class CommCareIdentityDataSource @Inject constructor( count } - private val parallelism = Runtime.getRuntime().availableProcessors() - override fun loadFaceIdentities( query: SubjectQuery, ranges: List, @@ -270,7 +269,7 @@ internal class CommCareIdentityDataSource @Inject constructor( ): ReceiveChannel> = loadIdentitiesConcurrently( ranges = ranges, dispatcher = dispatcher, - parallelism = parallelism, + parallelism = availableProcessors, scope = scope, ) { range -> loadFaceIdentities( @@ -292,7 +291,7 @@ internal class CommCareIdentityDataSource @Inject constructor( ): ReceiveChannel> = loadIdentitiesConcurrently( ranges = ranges, dispatcher = dispatcher, - parallelism = parallelism, + parallelism = availableProcessors, scope = scope, ) { range -> loadFingerprintIdentities( diff --git a/infra/enrolment-records/repository/src/main/java/com/simprints/infra/enrolment/records/repository/local/EnrolmentRecordLocalDataSourceImpl.kt b/infra/enrolment-records/repository/src/main/java/com/simprints/infra/enrolment/records/repository/local/EnrolmentRecordLocalDataSourceImpl.kt index 5e1261ff67..5f9347936a 100644 --- a/infra/enrolment-records/repository/src/main/java/com/simprints/infra/enrolment/records/repository/local/EnrolmentRecordLocalDataSourceImpl.kt +++ b/infra/enrolment-records/repository/src/main/java/com/simprints/infra/enrolment/records/repository/local/EnrolmentRecordLocalDataSourceImpl.kt @@ -105,7 +105,8 @@ internal class EnrolmentRecordLocalDataSourceImpl @Inject constructor( realm .query(DbSubject::class) .buildRealmQueryForSubject(query) - .find { it.subList(range.first, range.last) } + // subList's second parameter is exclusive, so we need to add 1 to the last index + .find { it.subList(range.first, range.last+1) } .map { subject -> onCandidateLoaded() FingerprintIdentity( @@ -123,7 +124,8 @@ internal class EnrolmentRecordLocalDataSourceImpl @Inject constructor( realm .query(DbSubject::class) .buildRealmQueryForSubject(query) - .find { it.subList(range.first, range.last) } + // subList's second parameter is exclusive, so we need to add 1 to the last index + .find { it.subList(range.first, range.last+1) } .map { subject -> onCandidateLoaded() FaceIdentity( diff --git a/infra/enrolment-records/repository/src/test/java/com/simprints/infra/enrolment/records/repository/commcare/CommCareIdentityDataSourceTest.kt b/infra/enrolment-records/repository/src/test/java/com/simprints/infra/enrolment/records/repository/commcare/CommCareIdentityDataSourceTest.kt index db0274fe9a..a08e18d7ca 100644 --- a/infra/enrolment-records/repository/src/test/java/com/simprints/infra/enrolment/records/repository/commcare/CommCareIdentityDataSourceTest.kt +++ b/infra/enrolment-records/repository/src/test/java/com/simprints/infra/enrolment/records/repository/commcare/CommCareIdentityDataSourceTest.kt @@ -207,6 +207,7 @@ class CommCareIdentityDataSourceTest { encoder, JsonHelper, useCase, + 4, context, testCoroutineRule.testCoroutineDispatcher, ) @@ -647,7 +648,7 @@ class CommCareIdentityDataSourceTest { val templateFormat = "ISO_19794_2" val query = SubjectQuery(fingerprintSampleFormat = templateFormat) - val range = 0..expectedFingerprintIdentities.size + val range = expectedFingerprintIdentities.indices val actualIdentities = mutableListOf() dataSource .loadFingerprintIdentities(