diff --git a/face/infra/base-bio-sdk/src/main/java/com/simprints/face/infra/basebiosdk/matching/FaceMatcher.kt b/face/infra/base-bio-sdk/src/main/java/com/simprints/face/infra/basebiosdk/matching/FaceMatcher.kt index 26094588f0..80daab7261 100644 --- a/face/infra/base-bio-sdk/src/main/java/com/simprints/face/infra/basebiosdk/matching/FaceMatcher.kt +++ b/face/infra/base-bio-sdk/src/main/java/com/simprints/face/infra/basebiosdk/matching/FaceMatcher.kt @@ -1,38 +1,13 @@ package com.simprints.face.infra.basebiosdk.matching -abstract class FaceMatcher { +abstract class FaceMatcher( + open val probeSamples: List +) : AutoCloseable { /** - * The matching SDK name - */ - abstract val matcherName: String - abstract val supportedTemplateFormat: String - - /** - * Returns a comparison score of two templates from 0.0 - 100.0 - */ - abstract suspend fun getComparisonScore( - probe: ByteArray, - matchAgainst: ByteArray, - ): Float - - /** - * Get highest comparison score for matching candidate template against all probes + * Get highest comparison score for matching candidate template against samples * - * @param probes * @param candidate * @return the highest comparison score */ - suspend fun getHighestComparisonScoreForCandidate( - probes: List, - candidate: FaceIdentity, - ): Float { - var highestScore = 0f - probes.forEach { probe -> - candidate.faces.forEach { face -> - val score = getComparisonScore(probe.template, face.template) - if (score > highestScore) highestScore = score - } - } - return highestScore - } + abstract suspend fun getHighestComparisonScoreForCandidate(candidate: FaceIdentity): Float } diff --git a/face/infra/base-bio-sdk/src/test/java/com/simprints/infra/facebiosdk/matching/FaceMatcherTest.kt b/face/infra/base-bio-sdk/src/test/java/com/simprints/infra/facebiosdk/matching/FaceMatcherTest.kt deleted file mode 100644 index ab3f93c2ce..0000000000 --- a/face/infra/base-bio-sdk/src/test/java/com/simprints/infra/facebiosdk/matching/FaceMatcherTest.kt +++ /dev/null @@ -1,44 +0,0 @@ -package com.simprints.infra.facebiosdk.matching - -import com.google.common.truth.Truth.assertThat -import com.simprints.face.infra.basebiosdk.matching.FaceIdentity -import com.simprints.face.infra.basebiosdk.matching.FaceMatcher -import com.simprints.face.infra.basebiosdk.matching.FaceSample -import io.mockk.coEvery -import io.mockk.spyk -import kotlinx.coroutines.test.runTest -import org.junit.Test -import java.util.UUID -import kotlin.random.Random - -class FaceMatcherTest { - private val faceMatcher = spyk() - private val candidate1 = getFaceIdentity(2) - private val probes = generateSequenceN(2) { getFaceSample() }.toList() - - @Test - fun `Get highest score for a candidate`() = runTest { - coEvery { faceMatcher.getComparisonScore(any(), any()) } returnsMany listOf( - 0.1f, - 0.2f, - 0.3f, - 0.4f, - ) - - val score = faceMatcher.getHighestComparisonScoreForCandidate(probes, candidate1) - - assertThat(score).isEqualTo(0.4f) - } - - private fun getFaceIdentity(numFaces: Int): FaceIdentity = FaceIdentity( - UUID.randomUUID().toString(), - generateSequenceN(numFaces) { getFaceSample() }.toList(), - ) - - private fun getFaceSample(): FaceSample = FaceSample(UUID.randomUUID().toString(), Random.nextBytes(20)) - - private fun generateSequenceN( - n: Int, - f: () -> T, - ) = generateSequence(f).take(n) -} diff --git a/face/infra/bio-sdk-resolver/src/main/java/com/simprints/face/infra/biosdkresolver/FaceBioSDK.kt b/face/infra/bio-sdk-resolver/src/main/java/com/simprints/face/infra/biosdkresolver/FaceBioSDK.kt index 671695321e..2918f4ec3a 100644 --- a/face/infra/bio-sdk-resolver/src/main/java/com/simprints/face/infra/biosdkresolver/FaceBioSDK.kt +++ b/face/infra/bio-sdk-resolver/src/main/java/com/simprints/face/infra/biosdkresolver/FaceBioSDK.kt @@ -3,10 +3,15 @@ package com.simprints.face.infra.biosdkresolver import com.simprints.face.infra.basebiosdk.detection.FaceDetector import com.simprints.face.infra.basebiosdk.initialization.FaceBioSdkInitializer import com.simprints.face.infra.basebiosdk.matching.FaceMatcher +import com.simprints.face.infra.basebiosdk.matching.FaceSample interface FaceBioSDK { val initializer: FaceBioSdkInitializer val detector: FaceDetector - val matcher: FaceMatcher + val version: String + val templateFormat: String + val matcherName: String + + fun createMatcher(probeSamples: List): FaceMatcher } diff --git a/face/infra/bio-sdk-resolver/src/main/java/com/simprints/face/infra/biosdkresolver/RocV1BioSdk.kt b/face/infra/bio-sdk-resolver/src/main/java/com/simprints/face/infra/biosdkresolver/RocV1BioSdk.kt index d88a9fd091..d515368562 100644 --- a/face/infra/bio-sdk-resolver/src/main/java/com/simprints/face/infra/biosdkresolver/RocV1BioSdk.kt +++ b/face/infra/bio-sdk-resolver/src/main/java/com/simprints/face/infra/biosdkresolver/RocV1BioSdk.kt @@ -1,6 +1,9 @@ package com.simprints.face.infra.biosdkresolver +import com.simprints.face.infra.basebiosdk.matching.FaceMatcher +import com.simprints.face.infra.basebiosdk.matching.FaceSample import com.simprints.face.infra.rocv1.detection.RocV1Detector +import com.simprints.face.infra.rocv1.detection.RocV1Detector.Companion.RANK_ONE_TEMPLATE_FORMAT_1_23 import com.simprints.face.infra.rocv1.initialization.RocV1Initializer import com.simprints.face.infra.rocv1.matching.RocV1Matcher import javax.inject.Inject @@ -10,7 +13,10 @@ import javax.inject.Singleton class RocV1BioSdk @Inject constructor( override val initializer: RocV1Initializer, override val detector: RocV1Detector, - override val matcher: RocV1Matcher, ) : FaceBioSDK { override val version: String = "1.23" + override val templateFormat: String = RANK_ONE_TEMPLATE_FORMAT_1_23 + override val matcherName: String = "RANK_ONE" + + override fun createMatcher(probeSamples: List): FaceMatcher = RocV1Matcher(probeSamples) } diff --git a/face/infra/bio-sdk-resolver/src/main/java/com/simprints/face/infra/biosdkresolver/RocV3BioSdk.kt b/face/infra/bio-sdk-resolver/src/main/java/com/simprints/face/infra/biosdkresolver/RocV3BioSdk.kt index 38f7c515e2..f382030133 100644 --- a/face/infra/bio-sdk-resolver/src/main/java/com/simprints/face/infra/biosdkresolver/RocV3BioSdk.kt +++ b/face/infra/bio-sdk-resolver/src/main/java/com/simprints/face/infra/biosdkresolver/RocV3BioSdk.kt @@ -1,6 +1,9 @@ package com.simprints.face.infra.biosdkresolver +import com.simprints.face.infra.basebiosdk.matching.FaceMatcher +import com.simprints.face.infra.basebiosdk.matching.FaceSample import com.simprints.face.infra.rocv3.detection.RocV3Detector +import com.simprints.face.infra.rocv3.detection.RocV3Detector.Companion.RANK_ONE_TEMPLATE_FORMAT_3_1 import com.simprints.face.infra.rocv3.initialization.RocV3Initializer import com.simprints.face.infra.rocv3.matching.RocV3Matcher import javax.inject.Inject @@ -10,7 +13,10 @@ import javax.inject.Singleton class RocV3BioSdk @Inject constructor( override val initializer: RocV3Initializer, override val detector: RocV3Detector, - override val matcher: RocV3Matcher, ) : FaceBioSDK { override val version: String = "3.1" + override val templateFormat: String = RANK_ONE_TEMPLATE_FORMAT_3_1 + override val matcherName: String = "RANK_ONE" + + override fun createMatcher(probeSamples: List): FaceMatcher = RocV3Matcher(probeSamples) } diff --git a/face/infra/bio-sdk-resolver/src/test/java/ResolveFaceBioSdkUseCaseTest.kt b/face/infra/bio-sdk-resolver/src/test/java/com/simprints/face/infra/biosdkresolver/ResolveFaceBioSdkUseCaseTest.kt similarity index 87% rename from face/infra/bio-sdk-resolver/src/test/java/ResolveFaceBioSdkUseCaseTest.kt rename to face/infra/bio-sdk-resolver/src/test/java/com/simprints/face/infra/biosdkresolver/ResolveFaceBioSdkUseCaseTest.kt index bc1c092497..98e84280a6 100644 --- a/face/infra/bio-sdk-resolver/src/test/java/ResolveFaceBioSdkUseCaseTest.kt +++ b/face/infra/bio-sdk-resolver/src/test/java/com/simprints/face/infra/biosdkresolver/ResolveFaceBioSdkUseCaseTest.kt @@ -1,13 +1,11 @@ package com.simprints.face.infra.biosdkresolver -import com.google.common.truth.Truth.assertThat +import com.google.common.truth.Truth.* import com.simprints.infra.config.store.ConfigRepository -import io.mockk.coEvery -import io.mockk.mockk +import io.mockk.* import kotlinx.coroutines.test.runTest import org.junit.Before import org.junit.Test -import java.lang.IllegalArgumentException class ResolveFaceBioSdkUseCaseTest { private lateinit var resolveFaceBioSdkUseCase: ResolveFaceBioSdkUseCase @@ -17,9 +15,10 @@ class ResolveFaceBioSdkUseCaseTest { @Before fun setUp() { - rocV1BioSdk = RocV1BioSdk(mockk(), mockk(), mockk()) - rocV3BioSdk = RocV3BioSdk(mockk(), mockk(), mockk()) - resolveFaceBioSdkUseCase = ResolveFaceBioSdkUseCase(configRepository, rocV1BioSdk, rocV3BioSdk) + rocV1BioSdk = RocV1BioSdk(mockk(), mockk()) + rocV3BioSdk = RocV3BioSdk(mockk(), mockk()) + resolveFaceBioSdkUseCase = + ResolveFaceBioSdkUseCase(configRepository, rocV1BioSdk, rocV3BioSdk) } @Test(expected = IllegalArgumentException::class) diff --git a/face/infra/bio-sdk-resolver/src/test/java/com/simprints/face/infra/biosdkresolver/RocV1BioSdkTest.kt b/face/infra/bio-sdk-resolver/src/test/java/com/simprints/face/infra/biosdkresolver/RocV1BioSdkTest.kt new file mode 100644 index 0000000000..8a42d5549f --- /dev/null +++ b/face/infra/bio-sdk-resolver/src/test/java/com/simprints/face/infra/biosdkresolver/RocV1BioSdkTest.kt @@ -0,0 +1,22 @@ +package com.simprints.face.infra.biosdkresolver + + +import com.google.common.truth.Truth.* +import io.mockk.* +import org.junit.Test + +class RocV1BioSdkTest { + + + private lateinit var rocV1BioSdk: RocV1BioSdk + + @Test + fun createMatcher() { + rocV1BioSdk = RocV1BioSdk(mockk(), mockk()) + + val matcher = rocV1BioSdk.createMatcher(emptyList()) + + assertThat(matcher).isNotNull() + } + +} diff --git a/face/infra/bio-sdk-resolver/src/test/java/com/simprints/face/infra/biosdkresolver/RocV3BioSdkTest.kt b/face/infra/bio-sdk-resolver/src/test/java/com/simprints/face/infra/biosdkresolver/RocV3BioSdkTest.kt new file mode 100644 index 0000000000..080431f071 --- /dev/null +++ b/face/infra/bio-sdk-resolver/src/test/java/com/simprints/face/infra/biosdkresolver/RocV3BioSdkTest.kt @@ -0,0 +1,21 @@ +package com.simprints.face.infra.biosdkresolver + + +import com.google.common.truth.Truth.* +import io.mockk.* +import org.junit.Test + +class RocV3BioSdkTest { + + private lateinit var rocV3BioSdk: RocV3BioSdk + + @Test + fun createMatcher() { + rocV3BioSdk = RocV3BioSdk(mockk(), mockk()) + + val matcher = rocV3BioSdk.createMatcher(emptyList()) + + assertThat(matcher).isNotNull() + } + +} diff --git a/face/infra/roc-v1/src/main/java/com/simprints/face/infra/rocv1/RocV1WrapperModule.kt b/face/infra/roc-v1/src/main/java/com/simprints/face/infra/rocv1/RocV1WrapperModule.kt index 0d455a63ad..39f786c602 100644 --- a/face/infra/roc-v1/src/main/java/com/simprints/face/infra/rocv1/RocV1WrapperModule.kt +++ b/face/infra/roc-v1/src/main/java/com/simprints/face/infra/rocv1/RocV1WrapperModule.kt @@ -5,7 +5,6 @@ import com.simprints.face.infra.basebiosdk.initialization.FaceBioSdkInitializer import com.simprints.face.infra.basebiosdk.matching.FaceMatcher import com.simprints.face.infra.rocv1.detection.RocV1Detector import com.simprints.face.infra.rocv1.initialization.RocV1Initializer -import com.simprints.face.infra.rocv1.matching.RocV1Matcher import dagger.Binds import dagger.Module import dagger.hilt.InstallIn @@ -19,7 +18,4 @@ abstract class RocV1WrapperModule { @Binds abstract fun provideFaceDetector(impl: RocV1Detector): FaceDetector - - @Binds - abstract fun provideFaceMatcher(impl: RocV1Matcher): FaceMatcher } diff --git a/face/infra/roc-v1/src/main/java/com/simprints/face/infra/rocv1/matching/RocV1Matcher.kt b/face/infra/roc-v1/src/main/java/com/simprints/face/infra/rocv1/matching/RocV1Matcher.kt index 06922e15ca..5e8a3bfc38 100644 --- a/face/infra/roc-v1/src/main/java/com/simprints/face/infra/rocv1/matching/RocV1Matcher.kt +++ b/face/infra/roc-v1/src/main/java/com/simprints/face/infra/rocv1/matching/RocV1Matcher.kt @@ -1,33 +1,40 @@ package com.simprints.face.infra.rocv1.matching import com.simprints.core.ExcludedFromGeneratedTestCoverageReports +import com.simprints.face.infra.basebiosdk.matching.FaceIdentity import com.simprints.face.infra.basebiosdk.matching.FaceMatcher -import com.simprints.face.infra.rocv1.detection.RocV1Detector.Companion.RANK_ONE_TEMPLATE_FORMAT_1_23 +import com.simprints.face.infra.basebiosdk.matching.FaceSample +import io.rankone.rocsdk.embedded.SWIGTYPE_p_unsigned_char import io.rankone.rocsdk.embedded.roc import io.rankone.rocsdk.embedded.rocConstants.ROC_FAST_FV_SIZE -import javax.inject.Inject - -class RocV1Matcher @Inject constructor() : FaceMatcher() { - override val matcherName - get() = "RANK_ONE" - - override val supportedTemplateFormat - get() = RANK_ONE_TEMPLATE_FORMAT_1_23 - - // Ignore this method from test coverage calculations - // because it uses jni native code which is hard to test - @ExcludedFromGeneratedTestCoverageReports( - reason = "This function uses roc class that has native functions and can't be mocked", - ) - override suspend fun getComparisonScore( - probe: ByteArray, - matchAgainst: ByteArray, - ): Float { - val probeTemplate = roc.new_uint8_t_array(ROC_FAST_FV_SIZE.toInt()) - roc.memmove(roc.roc_cast(probeTemplate), probe) +@ExcludedFromGeneratedTestCoverageReports( + reason = "This function uses roc class that has native functions and can't be mocked", +) +class RocV1Matcher( + override val probeSamples: List +) : FaceMatcher(probeSamples) { + + var probeTemplates: List = probeSamples.mapIndexed { i, probe -> + val probeTemplate: SWIGTYPE_p_unsigned_char = + roc.new_uint8_t_array(ROC_FAST_FV_SIZE.toInt()) + roc.memmove(roc.roc_cast(probeTemplate), probe.template) + probeTemplate + } + + override suspend fun getHighestComparisonScoreForCandidate(candidate: FaceIdentity): Float = + probeTemplates.flatMap { probeTemplate -> + candidate.faces.map { face -> + getSimilarityScoreForCandidate(probeTemplate, face.template) + } + }.max() + + private fun getSimilarityScoreForCandidate( + probeTemplate: SWIGTYPE_p_unsigned_char, + candidateTemplate: ByteArray + ): Float { val matchTemplate = roc.new_uint8_t_array(ROC_FAST_FV_SIZE.toInt()) - roc.memmove(roc.roc_cast(matchTemplate), matchAgainst) + roc.memmove(roc.roc_cast(matchTemplate), candidateTemplate) val similarity = roc.roc_embedded_compare_templates( probeTemplate, @@ -35,10 +42,12 @@ class RocV1Matcher @Inject constructor() : FaceMatcher() { matchTemplate, ROC_FAST_FV_SIZE, ) - - roc.delete_uint8_t_array(probeTemplate) roc.delete_uint8_t_array(matchTemplate) - return (similarity * 100) + return similarity * 100f + } + + override fun close() { + probeTemplates.forEach { roc.delete_uint8_t_array(it) } } } diff --git a/face/infra/roc-v1/src/test/java/com/simprints/infra/rocwrapper/matching/RocV1MatcherTest.kt b/face/infra/roc-v1/src/test/java/com/simprints/infra/rocwrapper/matching/RocV1MatcherTest.kt index dc53eadf04..4fe85a6425 100644 --- a/face/infra/roc-v1/src/test/java/com/simprints/infra/rocwrapper/matching/RocV1MatcherTest.kt +++ b/face/infra/roc-v1/src/test/java/com/simprints/infra/rocwrapper/matching/RocV1MatcherTest.kt @@ -1,6 +1,6 @@ package com.simprints.infra.rocwrapper.matching -import com.google.common.truth.Truth +import com.google.common.truth.* import com.simprints.face.infra.rocv1.matching.RocV1Matcher import org.junit.Test @@ -8,7 +8,6 @@ class RocV1MatcherTest { // Dummy test to generate jacoco reports. @Test fun getMatcherName() { - RocV1Matcher().matcherName - Truth.assertThat(RocV1Matcher().matcherName).isEqualTo("RANK_ONE") + Truth.assertThat(RocV1Matcher(emptyList())).isNotNull() } } diff --git a/face/infra/roc-v3/src/main/java/com/simprints/face/infra/rocv3/RocV3WrapperModule.kt b/face/infra/roc-v3/src/main/java/com/simprints/face/infra/rocv3/RocV3WrapperModule.kt index 3aff32a7cc..e267f1869e 100644 --- a/face/infra/roc-v3/src/main/java/com/simprints/face/infra/rocv3/RocV3WrapperModule.kt +++ b/face/infra/roc-v3/src/main/java/com/simprints/face/infra/rocv3/RocV3WrapperModule.kt @@ -2,10 +2,8 @@ package com.simprints.face.infra.rocv3 import com.simprints.face.infra.basebiosdk.detection.FaceDetector import com.simprints.face.infra.basebiosdk.initialization.FaceBioSdkInitializer -import com.simprints.face.infra.basebiosdk.matching.FaceMatcher import com.simprints.face.infra.rocv3.detection.RocV3Detector import com.simprints.face.infra.rocv3.initialization.RocV3Initializer -import com.simprints.face.infra.rocv3.matching.RocV3Matcher import dagger.Binds import dagger.Module import dagger.hilt.InstallIn @@ -19,7 +17,4 @@ abstract class RocV3WrapperModule { @Binds abstract fun provideFaceDetector(impl: RocV3Detector): FaceDetector - - @Binds - abstract fun provideFaceMatcher(impl: RocV3Matcher): FaceMatcher } diff --git a/face/infra/roc-v3/src/main/java/com/simprints/face/infra/rocv3/matching/RocV3Matcher.kt b/face/infra/roc-v3/src/main/java/com/simprints/face/infra/rocv3/matching/RocV3Matcher.kt index 7f218bbceb..db2c4c52da 100644 --- a/face/infra/roc-v3/src/main/java/com/simprints/face/infra/rocv3/matching/RocV3Matcher.kt +++ b/face/infra/roc-v3/src/main/java/com/simprints/face/infra/rocv3/matching/RocV3Matcher.kt @@ -1,35 +1,40 @@ package com.simprints.face.infra.rocv3.matching +import ai.roc.rocsdk.embedded.SWIGTYPE_p_unsigned_char import ai.roc.rocsdk.embedded.roc import ai.roc.rocsdk.embedded.rocConstants.ROC_FACE_FAST_FV_SIZE import com.simprints.core.ExcludedFromGeneratedTestCoverageReports +import com.simprints.face.infra.basebiosdk.matching.FaceIdentity import com.simprints.face.infra.basebiosdk.matching.FaceMatcher -import com.simprints.face.infra.rocv3.detection.RocV3Detector.Companion.RANK_ONE_TEMPLATE_FORMAT_3_1 -import javax.inject.Inject -import javax.inject.Singleton - -@Singleton -class RocV3Matcher @Inject constructor() : FaceMatcher() { - override val matcherName - get() = "RANK_ONE" - - override val supportedTemplateFormat - get() = RANK_ONE_TEMPLATE_FORMAT_3_1 - - // Ignore this method from test coverage calculations - // because it uses jni native code which is hard to test - @ExcludedFromGeneratedTestCoverageReports( - reason = "This function uses roc class that has native functions and can't be mocked", - ) - override suspend fun getComparisonScore( - probe: ByteArray, - matchAgainst: ByteArray, - ): Float { - val probeTemplate = roc.new_uint8_t_array(ROC_FACE_FAST_FV_SIZE.toInt()) - roc.memmove(roc.roc_cast(probeTemplate), probe) +import com.simprints.face.infra.basebiosdk.matching.FaceSample + +@ExcludedFromGeneratedTestCoverageReports( + reason = "This function uses roc class that has native functions and can't be mocked", +) +class RocV3Matcher( + override val probeSamples: List +) : FaceMatcher(probeSamples) { + + var probeTemplates: List = probeSamples.mapIndexed { i, probe -> + val probeTemplate: SWIGTYPE_p_unsigned_char = + roc.new_uint8_t_array(ROC_FACE_FAST_FV_SIZE.toInt()) + roc.memmove(roc.roc_cast(probeTemplate), probe.template) + probeTemplate + } + override suspend fun getHighestComparisonScoreForCandidate(candidate: FaceIdentity): Float = + probeTemplates.flatMap { probeTemplate -> + candidate.faces.map { face -> + getSimilarityScoreForCandidate(probeTemplate, face.template) + } + }.max() + + private fun getSimilarityScoreForCandidate( + probeTemplate: SWIGTYPE_p_unsigned_char, + candidateTemplate: ByteArray, + ): Float { val matchTemplate = roc.new_uint8_t_array(ROC_FACE_FAST_FV_SIZE.toInt()) - roc.memmove(roc.roc_cast(matchTemplate), matchAgainst) + roc.memmove(roc.roc_cast(matchTemplate), candidateTemplate) val similarity = roc.roc_embedded_compare_templates( probeTemplate, @@ -37,10 +42,12 @@ class RocV3Matcher @Inject constructor() : FaceMatcher() { matchTemplate, ROC_FACE_FAST_FV_SIZE, ) - - roc.delete_uint8_t_array(probeTemplate) roc.delete_uint8_t_array(matchTemplate) - return (similarity * 100) + return similarity * 100f + } + + override fun close() { + probeTemplates.forEach { roc.delete_uint8_t_array(it) } } } diff --git a/face/infra/roc-v3/src/test/java/com/simprints/infra/rocwrapper/matching/RocV3MatcherTest.kt b/face/infra/roc-v3/src/test/java/com/simprints/infra/rocwrapper/matching/RocV3MatcherTest.kt index 6e27bb8aca..bc2424efb2 100644 --- a/face/infra/roc-v3/src/test/java/com/simprints/infra/rocwrapper/matching/RocV3MatcherTest.kt +++ b/face/infra/roc-v3/src/test/java/com/simprints/infra/rocwrapper/matching/RocV3MatcherTest.kt @@ -4,10 +4,10 @@ import com.google.common.truth.Truth import com.simprints.face.infra.rocv3.matching.RocV3Matcher import org.junit.Test +// Dummy test to generate jacoco reports. class RocV3MatcherTest { @Test fun getMatcherName() { - RocV3Matcher().matcherName - Truth.assertThat(RocV3Matcher().matcherName).isEqualTo("RANK_ONE") + Truth.assertThat(RocV3Matcher(emptyList())).isNotNull() } } diff --git a/feature/matcher/src/main/java/com/simprints/matcher/usecases/FaceMatcherUseCase.kt b/feature/matcher/src/main/java/com/simprints/matcher/usecases/FaceMatcherUseCase.kt index 82d70f9a32..054a85ff92 100644 --- a/feature/matcher/src/main/java/com/simprints/matcher/usecases/FaceMatcherUseCase.kt +++ b/feature/matcher/src/main/java/com/simprints/matcher/usecases/FaceMatcherUseCase.kt @@ -28,7 +28,6 @@ internal class FaceMatcherUseCase @Inject constructor( private val createRanges: CreateRangesUseCase, @DispatcherBG private val dispatcher: CoroutineDispatcher, ) : MatcherUseCase { - private lateinit var faceMatcher: FaceMatcher override val crashReportTag = LoggingConstants.CrashReportTag.FACE_MATCHING override suspend operator fun invoke( @@ -36,21 +35,22 @@ internal class FaceMatcherUseCase @Inject constructor( project: Project, ): Flow = channelFlow { Simber.i("Initialising matcher", tag = crashReportTag) - faceMatcher = resolveFaceBioSdk().matcher + val bioSdk = resolveFaceBioSdk() + if (matchParams.probeFaceSamples.isEmpty()) { - send(MatcherState.Success(emptyList(), 0, faceMatcher.matcherName)) + send(MatcherState.Success(emptyList(), 0, bioSdk.matcherName)) return@channelFlow } val samples = mapSamples(matchParams.probeFaceSamples) val queryWithSupportedFormat = matchParams.queryForCandidates.copy( - faceSampleFormat = faceMatcher.supportedTemplateFormat, + faceSampleFormat = bioSdk.templateFormat, ) val expectedCandidates = enrolmentRecordRepository.count( queryWithSupportedFormat, dataSource = matchParams.biometricDataSource, ) if (expectedCandidates == 0) { - send(MatcherState.Success(emptyList(), 0, faceMatcher.matcherName)) + send(MatcherState.Success(emptyList(), 0, bioSdk.matcherName)) return@channelFlow } @@ -61,6 +61,7 @@ internal class FaceMatcherUseCase @Inject constructor( // as it's count function does not take into account filtering criteria var loadedCandidates = 0 val resultItems = coroutineScope { + createRanges(expectedCandidates) .map { range -> async(dispatcher) { @@ -74,7 +75,7 @@ internal class FaceMatcherUseCase @Inject constructor( loadedCandidates++ trySend(MatcherState.CandidateLoaded) } - match(batchCandidates, samples) + bioSdk.createMatcher(samples).use { match(it, batchCandidates) } } }.awaitAll() .reduce { acc, subSet -> acc.addAll(subSet) } @@ -83,7 +84,7 @@ internal class FaceMatcherUseCase @Inject constructor( Simber.i("Matched $loadedCandidates candidates", tag = crashReportTag) - send(MatcherState.Success(resultItems, loadedCandidates, faceMatcher.matcherName)) + send(MatcherState.Success(resultItems, loadedCandidates, bioSdk.matcherName)) } private fun mapSamples(probes: List) = probes.map { FaceSample(it.faceId, it.template) } @@ -104,13 +105,13 @@ internal class FaceMatcherUseCase @Inject constructor( } private suspend fun match( - batchCandidates: List, - samples: List, - ) = batchCandidates.fold(MatchResultSet()) { acc, item -> + matcher: FaceMatcher, + batchCandidates: List + ) = batchCandidates.fold(MatchResultSet()) { acc, candidate -> acc.add( FaceMatchResult.Item( - item.subjectId, - faceMatcher.getHighestComparisonScoreForCandidate(samples, item), + candidate.subjectId, + matcher.getHighestComparisonScoreForCandidate(candidate), ), ) } diff --git a/feature/matcher/src/test/java/com/simprints/matcher/usecases/FaceMatcherUseCaseTest.kt b/feature/matcher/src/test/java/com/simprints/matcher/usecases/FaceMatcherUseCaseTest.kt index d9021cf83e..255e25cf9d 100644 --- a/feature/matcher/src/test/java/com/simprints/matcher/usecases/FaceMatcherUseCaseTest.kt +++ b/feature/matcher/src/test/java/com/simprints/matcher/usecases/FaceMatcherUseCaseTest.kt @@ -1,7 +1,7 @@ package com.simprints.matcher.usecases import androidx.arch.core.executor.testing.InstantTaskExecutorRule -import com.google.common.truth.Truth.assertThat +import com.google.common.truth.Truth.* import com.simprints.core.domain.common.FlowType import com.simprints.core.domain.face.FaceSample import com.simprints.face.infra.basebiosdk.matching.FaceMatcher @@ -14,9 +14,7 @@ import com.simprints.infra.enrolment.records.repository.domain.models.SubjectQue import com.simprints.matcher.FaceMatchResult import com.simprints.matcher.MatchParams import com.simprints.testtools.common.coroutines.TestCoroutineRule -import io.mockk.MockKAnnotations -import io.mockk.coEvery -import io.mockk.coVerify +import io.mockk.* import io.mockk.impl.annotations.MockK import kotlinx.coroutines.flow.toList import kotlinx.coroutines.test.runTest @@ -50,7 +48,7 @@ internal class FaceMatcherUseCaseTest { @Before fun setUp() { MockKAnnotations.init(this, relaxed = true) - coEvery { resolveFaceBioSdk().matcher } returns faceMatcher + coEvery { resolveFaceBioSdk().createMatcher(any()) } returns faceMatcher useCase = FaceMatcherUseCase( enrolmentRecordRepository, resolveFaceBioSdk, @@ -61,7 +59,7 @@ internal class FaceMatcherUseCaseTest { @Test fun `Skips matching if there are no probes`() = runTest { - coEvery { faceMatcher.getHighestComparisonScoreForCandidate(any(), any()) } returns 1f + coEvery { faceMatcher.getHighestComparisonScoreForCandidate(any()) } returns 1f val results = useCase .invoke( @@ -74,7 +72,7 @@ internal class FaceMatcherUseCaseTest { project, ).toList() - coVerify(exactly = 0) { faceMatcher.getHighestComparisonScoreForCandidate(any(), any()) } + coVerify(exactly = 0) { faceMatcher.getHighestComparisonScoreForCandidate(any()) } assertThat(results).containsExactly( MatcherUseCase.MatcherState.Success( @@ -103,7 +101,7 @@ internal class FaceMatcherUseCaseTest { project, ).toList() - coVerify(exactly = 0) { faceMatcher.getHighestComparisonScoreForCandidate(any(), any()) } + coVerify(exactly = 0) { faceMatcher.getHighestComparisonScoreForCandidate(any()) } assertThat(results).containsExactly( MatcherUseCase.MatcherState.Success( @@ -133,7 +131,7 @@ internal class FaceMatcherUseCaseTest { // Return the face identities faceIdentities } - coEvery { faceMatcher.getHighestComparisonScoreForCandidate(any(), any()) } returns 42f + coEvery { faceMatcher.getHighestComparisonScoreForCandidate(any()) } returns 42f val results = useCase .invoke( @@ -149,7 +147,7 @@ internal class FaceMatcherUseCaseTest { project, ).toList() - coVerify { faceMatcher.getHighestComparisonScoreForCandidate(any(), any()) } + coVerify { faceMatcher.getHighestComparisonScoreForCandidate(any()) } assertThat(results).containsExactly( MatcherUseCase.MatcherState.LoadingStarted(totalCandidates),