-
Notifications
You must be signed in to change notification settings - Fork 2
[MS-949] Improve concurrency in matching #1194
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
088d400
deadb70
3b95419
841a2c4
457264f
ef741aa
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,45 +1,53 @@ | ||
| package com.simprints.matcher.usecases | ||
|
|
||
| import com.simprints.core.AvailableProcessors | ||
| import javax.inject.Inject | ||
| import kotlin.math.ceil | ||
|
|
||
| internal class CreateRangesUseCase @Inject constructor() { | ||
| internal class CreateRangesUseCase @Inject constructor( | ||
| @AvailableProcessors private val availableProcessors: Int, | ||
| ) { | ||
| /** | ||
| * Creates a list of ranges to be used for batch processing. | ||
| * Range size is increased dynamically to ensure that the first couple of batches are small | ||
| * to speed up initial reads, then it increases to ensure that the last batches are not too small. | ||
| * | ||
| * For example with minBatchSize = 10, returned batches will be 10, 10, 20, 30, 40, 50, 50 in size (if the total allows). | ||
| * The number of ranges will be a multiple of the available processors to ensure | ||
| * efficient parallel processing. | ||
| * Range sizes are adjusted to not exceed MAX_BATCH_SIZE. | ||
| */ | ||
| operator fun invoke( | ||
| totalCount: Int, | ||
| minBatchSize: Int = DEFAULT_BATCH_SIZE, | ||
| ): List<IntRange> { | ||
| val ranges = mutableListOf<IntRange>() | ||
| var index = 1 | ||
| if (totalCount <= 0) return emptyList() | ||
|
|
||
| // Calculate how many multiples of processors we need to keep batches under MAX_BATCH_SIZE | ||
| val batchesPerProcessor = ceil(totalCount.toDouble() / (availableProcessors * MAX_BATCH_SIZE)).toInt().coerceAtLeast(1) | ||
| val totalBatches = availableProcessors * batchesPerProcessor | ||
|
|
||
| // Calculate the base batch size and remainder for even distribution | ||
| val baseBatchSize = totalCount / totalBatches | ||
| val remainder = totalCount % totalBatches | ||
|
|
||
| var nextBatchSize = minBatchSize | ||
| val ranges = mutableListOf<IntRange>() | ||
| var start = 0 | ||
| var end = nextBatchSize | ||
|
|
||
| while (start < totalCount) { | ||
| if (end > totalCount) { | ||
| end = totalCount | ||
| } | ||
| ranges.add(start..end) | ||
| // Create ranges with sizes distributed as evenly as possible | ||
| for (i in 0 until totalBatches) { | ||
| // Add 1 to batch size for the first 'remainder' batches to distribute the remainder evenly | ||
| val batchSize = baseBatchSize + if (i < remainder) 1 else 0 | ||
| val end = (start + batchSize).coerceAtMost(totalCount) | ||
|
|
||
| ranges.add(start until end) | ||
| start = end | ||
| end += nextBatchSize | ||
|
|
||
| // Make sure next batch is increased | ||
| nextBatchSize = minBatchSize + (minBatchSize * index.coerceIn(1, 4)) | ||
| index++ | ||
| if (start >= totalCount) break | ||
| } | ||
|
|
||
| return ranges | ||
| } | ||
|
|
||
| companion object { | ||
| /** | ||
| * Experimentally determined batch size that works well for most cases. | ||
| * Maximum size for a batch to avoid huge memory consumption | ||
| */ | ||
| private const val DEFAULT_BATCH_SIZE = 1000 | ||
| private const val MAX_BATCH_SIZE = 2000 | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,6 +1,8 @@ | ||
| package com.simprints.matcher.usecases | ||
|
|
||
| import com.simprints.core.AvailableProcessors | ||
| import com.simprints.core.DispatcherBG | ||
| import com.simprints.core.DispatcherIO | ||
| import com.simprints.face.infra.basebiosdk.matching.FaceIdentity | ||
| import com.simprints.face.infra.basebiosdk.matching.FaceMatcher | ||
| import com.simprints.face.infra.basebiosdk.matching.FaceSample | ||
|
|
@@ -17,7 +19,9 @@ import kotlinx.coroutines.CoroutineDispatcher | |
| import kotlinx.coroutines.channels.ReceiveChannel | ||
| import kotlinx.coroutines.flow.Flow | ||
| import kotlinx.coroutines.flow.channelFlow | ||
| import kotlinx.coroutines.flow.flowOn | ||
| import kotlinx.coroutines.launch | ||
| import java.util.concurrent.atomic.AtomicInteger | ||
| import javax.inject.Inject | ||
| import kotlin.math.min | ||
| import com.simprints.infra.enrolment.records.repository.domain.models.FaceIdentity as DomainFaceIdentity | ||
|
|
@@ -26,16 +30,12 @@ internal class FaceMatcherUseCase @Inject constructor( | |
| private val enrolmentRecordRepository: EnrolmentRecordRepository, | ||
| private val resolveFaceBioSdk: ResolveFaceBioSdkUseCase, | ||
| private val createRanges: CreateRangesUseCase, | ||
| @AvailableProcessors private val availableProcessors: Int, | ||
| @DispatcherBG private val dispatcherBG: CoroutineDispatcher, | ||
| @DispatcherIO private val dispatcherIO: CoroutineDispatcher, | ||
| ) : MatcherUseCase { | ||
| override val crashReportTag = LoggingConstants.CrashReportTag.FACE_MATCHING | ||
|
|
||
| // When using local DB loadedCandidates = expectedCandidates | ||
| // However, when using CommCare as data source, loadedCandidates < expectedCandidates | ||
| // as its count function does not take into account filtering criteria | ||
| // This var is not thread safe | ||
| var loadedCandidates = 0 | ||
|
|
||
| override suspend operator fun invoke( | ||
| matchParams: MatchParams, | ||
| project: Project, | ||
|
|
@@ -59,12 +59,17 @@ internal class FaceMatcherUseCase @Inject constructor( | |
| send(MatcherState.Success(emptyList(), 0, bioSdk.matcherName)) | ||
| return@channelFlow | ||
| } | ||
| loadedCandidates = 0 | ||
|
|
||
| Simber.i("Matching candidates", tag = crashReportTag) | ||
| send(MatcherState.LoadingStarted(expectedCandidates)) | ||
|
|
||
| // When using local DB loadedCandidates = expectedCandidates | ||
| // However, when using CommCare as data source, loadedCandidates < expectedCandidates | ||
| // as its count function does not take into account filtering criteria | ||
| val loadedCandidates = AtomicInteger(0) | ||
| val ranges = createRanges(expectedCandidates) | ||
| // if the number of ranges is less than the number of cores, then use the number of ranges | ||
| val numConsumers = min(Runtime.getRuntime().availableProcessors(), ranges.size) | ||
| val numConsumers = min(availableProcessors, ranges.size) | ||
|
|
||
| val resultSet = MatchResultSet<FaceMatchResult.Item>() | ||
| val candidatesChannel = enrolmentRecordRepository | ||
|
|
@@ -75,7 +80,7 @@ internal class FaceMatcherUseCase @Inject constructor( | |
| project = project, | ||
| scope = this, | ||
| onCandidateLoaded = { | ||
| loadedCandidates++ | ||
| loadedCandidates.incrementAndGet() | ||
| this.trySend(MatcherState.CandidateLoaded) | ||
| }, | ||
| ) | ||
|
|
@@ -88,8 +93,8 @@ internal class FaceMatcherUseCase @Inject constructor( | |
| } | ||
| // Wait for all to complete | ||
| consumerJobs.forEach { it.join() } | ||
| send(MatcherState.Success(resultSet.toList(), loadedCandidates, bioSdk.matcherName)) | ||
| } | ||
| send(MatcherState.Success(resultSet.toList(), loadedCandidates.get(), bioSdk.matcherName)) | ||
| }.flowOn(dispatcherIO) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we need .flowOn(dispatcherIO), or at most we should use .flowOn(dispatcherBG). All read operations already run on the IO dispatcher using withContext(dispatcherIO), and all matching logic is executed within launch(dispatcherBG) { ... }.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
That's the problem - they don't! Reading was happening on the main thread because it used the scope passed from the ViewModel! This is why the UI wasn't updating counts until all reading was done.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think the flow should run on a background dispatcher, with only the database read operations running on the I/O dispatcher.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That would be the official recommendation. However, in my testing IO provided better performance. I have no explanation why but I don't see any harm, either 🤷 |
||
|
|
||
| suspend fun consumeAndMatch( | ||
| candidatesChannel: ReceiveChannel<List<DomainFaceIdentity>>, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,40 +1,48 @@ | ||
| package com.simprints.matcher.usecases | ||
|
|
||
| import com.simprints.matcher.MatchResultItem | ||
| import java.util.TreeSet | ||
| import java.util.concurrent.ConcurrentSkipListSet | ||
| import java.util.concurrent.atomic.AtomicReference | ||
| import java.util.concurrent.locks.ReentrantLock | ||
| import kotlin.concurrent.withLock | ||
|
|
||
| internal class MatchResultSet<T : MatchResultItem>( | ||
| private val maxSize: Int = MAX_RESULTS, | ||
| ) { | ||
| private var lowestConfidence: Float = 0f | ||
| private val lowestConfidence = AtomicReference(0f) | ||
| private val lock = ReentrantLock() | ||
|
|
||
| private val treeSet = TreeSet( | ||
| private val skipListSet = ConcurrentSkipListSet( | ||
| compareByDescending<T> { it.confidence }.thenByDescending { it.subjectId }, | ||
| ) | ||
|
|
||
| fun add(element: T): MatchResultSet<T> { | ||
| if (lowestConfidence > element.confidence) { | ||
| // skip adding if the last element is greater than the current element | ||
| // Use a lock to ensure thread safety during the entire add operation | ||
| lock.withLock { | ||
| // Only perform this optimization when we know the set is at max capacity | ||
| if (skipListSet.size >= maxSize && lowestConfidence.get() > element.confidence) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are atomic reference for "lowestConfidence" and concurrent collection required if the whole block is synchronised?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For how we use it currently - no. But they should not lead to any further slowdown, either.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Likely splitting hairs, but concurrent data structures are typically much slower.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm tempted to dismiss this concern for structures with max 10 results. However, if you feel strongly about this, we can do a quick benchmark and verify TreeSet indeed works fine in real concurrent situations (CoSync reading with N threads)!? |
||
| // skip adding if the set is full and the last element has higher confidence than the current element | ||
| return this | ||
| } | ||
|
|
||
| skipListSet.add(element) | ||
| if (skipListSet.size > maxSize) { | ||
| skipListSet.pollLast() | ||
|
|
||
| // Now that the set is full, we can skip adding elements | ||
| // with confidence lower than the current lowest | ||
| lowestConfidence.set(skipListSet.last().confidence) | ||
| } | ||
| return this | ||
| } | ||
|
|
||
| treeSet.add(element) | ||
| if (treeSet.size > maxSize) { | ||
| treeSet.pollLast() | ||
|
|
||
| // Now that the set is full, we can skip adding elements | ||
| // with confidence lower than the current lowest | ||
| lowestConfidence = treeSet.last().confidence | ||
| } | ||
| return this | ||
| } | ||
|
|
||
| fun addAll(elements: MatchResultSet<T>): MatchResultSet<T> { | ||
| elements.treeSet.forEach { add(it) } | ||
| elements.skipListSet.forEach { add(it) } | ||
| return this | ||
| } | ||
|
|
||
| fun toList(): List<T> = treeSet.toList() | ||
| fun toList(): List<T> = skipListSet.toList() | ||
|
|
||
| companion object { | ||
| /** | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.