diff --git a/OneSignalSDK/onesignal/core/src/main/java/com/onesignal/common/services/ServiceProvider.kt b/OneSignalSDK/onesignal/core/src/main/java/com/onesignal/common/services/ServiceProvider.kt index c2630baa39..aaa767c3ab 100644 --- a/OneSignalSDK/onesignal/core/src/main/java/com/onesignal/common/services/ServiceProvider.kt +++ b/OneSignalSDK/onesignal/core/src/main/java/com/onesignal/common/services/ServiceProvider.kt @@ -8,11 +8,9 @@ import com.onesignal.debug.internal.logging.Logging class ServiceProvider( registrations: List>, ) : IServiceProvider { - private var serviceMap: Map, List>> + private val serviceMap = mutableMapOf, MutableList>>() init { - val serviceMap = mutableMapOf, MutableList>>() - // go through the registrations to create the service map for easier lookup post-build for (reg in registrations) { for (service in reg.services) { @@ -23,8 +21,6 @@ class ServiceProvider( } } } - - this.serviceMap = serviceMap } internal inline fun hasService(): Boolean { @@ -44,23 +40,27 @@ class ServiceProvider( } override fun hasService(c: Class): Boolean { - return serviceMap.containsKey(c) + synchronized(serviceMap) { + return serviceMap.containsKey(c) + } } override fun getAllServices(c: Class): List { - val listOfServices: MutableList = mutableListOf() + synchronized(serviceMap) { + val listOfServices: MutableList = mutableListOf() - if (serviceMap.containsKey(c)) { - for (serviceReg in serviceMap!![c]!!) { - val service = - serviceReg.resolve(this) as T? - ?: throw Exception("Could not instantiate service: $serviceReg") + if (serviceMap.containsKey(c)) { + for (serviceReg in serviceMap!![c]!!) { + val service = + serviceReg.resolve(this) as T? + ?: throw Exception("Could not instantiate service: $serviceReg") - listOfServices.add(service) + listOfServices.add(service) + } } - } - return listOfServices + return listOfServices + } } override fun getService(c: Class): T { @@ -74,11 +74,10 @@ class ServiceProvider( } override fun getServiceOrNull(c: Class): T? { - Logging.debug("${indent}Retrieving service $c") -// indent += " " - val service = serviceMap[c]?.last()?.resolve(this) as T? -// indent = indent.substring(0, indent.length-2) - return service + synchronized(serviceMap) { + Logging.debug("${indent}Retrieving service $c") + return serviceMap[c]?.last()?.resolve(this) as T? + } } companion object { diff --git a/OneSignalSDK/onesignal/core/src/test/java/com/onesignal/common/ServiceProviderTest.kt b/OneSignalSDK/onesignal/core/src/test/java/com/onesignal/common/ServiceProviderTest.kt new file mode 100644 index 0000000000..4738d513ae --- /dev/null +++ b/OneSignalSDK/onesignal/core/src/test/java/com/onesignal/common/ServiceProviderTest.kt @@ -0,0 +1,44 @@ +package com.onesignal.common + +import com.onesignal.common.services.ServiceBuilder +import com.onesignal.common.services.ServiceProvider +import io.kotest.core.spec.style.FunSpec +import io.kotest.matchers.types.shouldBeSameInstanceAs +import java.util.concurrent.LinkedBlockingQueue + +internal interface IMyTestInterface + +internal class MySlowConstructorClass : IMyTestInterface { + init { + // NOTE: Keep these println calls, otherwise Kotlin optimizes + // something which cases the test not fail when it should. + println("MySlowConstructorClass BEFORE") + Thread.sleep(10) + println("MySlowConstructorClass AFTER") + } +} + +class ServiceProviderTest : FunSpec({ + + fun setupServiceProviderWithSlowInitClass(): ServiceProvider { + val serviceBuilder = ServiceBuilder() + serviceBuilder.register().provides() + return serviceBuilder.build() + } + + test("getService is thread safe") { + val services = setupServiceProviderWithSlowInitClass() + + val queue = LinkedBlockingQueue() + Thread { + queue.add(services.getService()) + }.start() + Thread { + queue.add(services.getService()) + }.start() + + val firstReference = queue.take() + val secondReference = queue.take() + firstReference shouldBeSameInstanceAs secondReference + } +})