diff --git a/infra/events/src/main/java/com/simprints/infra/events/event/local/EventDatabaseFactory.kt b/infra/events/src/main/java/com/simprints/infra/events/event/local/EventDatabaseFactory.kt index dd2126bb4a..447111fe4d 100644 --- a/infra/events/src/main/java/com/simprints/infra/events/event/local/EventDatabaseFactory.kt +++ b/infra/events/src/main/java/com/simprints/infra/events/event/local/EventDatabaseFactory.kt @@ -7,17 +7,28 @@ import dagger.hilt.android.qualifiers.ApplicationContext import net.sqlcipher.database.SQLiteDatabase.getBytes import net.sqlcipher.database.SupportFactory import javax.inject.Inject +import javax.inject.Singleton +@Singleton internal class EventDatabaseFactory @Inject constructor( @ApplicationContext val ctx: Context, private val securityManager: SecurityManager, ) { - fun build(): EventRoomDatabase { + private lateinit var eventDatabase: EventRoomDatabase + + fun get(): EventRoomDatabase { + if (!::eventDatabase.isInitialized) { + build() + } + return eventDatabase + } + + private fun build() { try { val key = getOrCreateKey(DB_NAME) val passphrase: ByteArray = getBytes(key) val factory = SupportFactory(passphrase) - return EventRoomDatabase.getDatabase( + eventDatabase = EventRoomDatabase.getDatabase( ctx, factory, DB_NAME, @@ -38,12 +49,14 @@ internal class EventDatabaseFactory @Inject constructor( securityManager.getLocalDbKeyOrThrow(dbName) }.value.decodeToString().toCharArray() - fun deleteDatabase() { + fun recreateDatabase() { + // DB corruption detected; either DB file or key is corrupt + // 1. Delete DB file in order to create a new one at next init ctx.deleteDatabase(DB_NAME) - } - - fun recreateDatabaseKey() { + // 2. Recreate the DB key securityManager.recreateLocalDatabaseKey(DB_NAME) + // 3. Rebuild the DB + build() } companion object { diff --git a/infra/events/src/main/java/com/simprints/infra/events/event/local/EventLocalDataSource.kt b/infra/events/src/main/java/com/simprints/infra/events/event/local/EventLocalDataSource.kt index 618d8b829c..1c7f53efbe 100644 --- a/infra/events/src/main/java/com/simprints/infra/events/event/local/EventLocalDataSource.kt +++ b/infra/events/src/main/java/com/simprints/infra/events/event/local/EventLocalDataSource.kt @@ -22,17 +22,19 @@ import kotlinx.coroutines.sync.Mutex import kotlinx.coroutines.sync.withLock import kotlinx.coroutines.withContext import javax.inject.Inject +import javax.inject.Singleton import kotlin.coroutines.CoroutineContext +@Singleton internal open class EventLocalDataSource @Inject constructor( private val eventDatabaseFactory: EventDatabaseFactory, private val jsonHelper: JsonHelper, @DispatcherIO private val readingDispatcher: CoroutineDispatcher, @NonCancellableIO private val writingContext: CoroutineContext, ) { - private var eventDao: EventRoomDao = eventDatabaseFactory.build().eventDao + private var eventDao: EventRoomDao = eventDatabaseFactory.get().eventDao - private var scopeDao: SessionScopeRoomDao = eventDatabaseFactory.build().scopeDao + private var scopeDao: SessionScopeRoomDao = eventDatabaseFactory.get().scopeDao private val mutex = Mutex() @@ -44,7 +46,7 @@ internal open class EventLocalDataSource @Inject constructor( block() } catch (ex: SQLiteException) { if (isFileCorruption(ex)) { - rebuildDatabase(ex) + recreateDatabase(ex) // Retry operation with new file and key block() } else { @@ -60,7 +62,7 @@ internal open class EventLocalDataSource @Inject constructor( try { block().catch { cause -> if (isFileCorruption(cause)) { - rebuildDatabase(cause) + recreateDatabase(cause) // Recreate flow and re-emit values with the new file and key emitAll(block()) } else { @@ -69,7 +71,7 @@ internal open class EventLocalDataSource @Inject constructor( } } catch (ex: SQLiteException) { if (isFileCorruption(ex)) { - rebuildDatabase(ex) + recreateDatabase(ex) // Recreate flow with the new file and key block() } else { @@ -81,17 +83,11 @@ internal open class EventLocalDataSource @Inject constructor( private fun isFileCorruption(ex: Throwable) = ex is SQLiteDatabaseCorruptException || ex.let { it as? SQLiteException }?.message?.contains("file is not a database") == true - private suspend fun rebuildDatabase(ex: Throwable) = mutex.withLock { - // DB corruption detected; either DB file or key is corrupt - // 1. Delete DB file in order to create a new one at next init - eventDatabaseFactory.deleteDatabase() - // 2. Recreate the DB key - eventDatabaseFactory.recreateDatabaseKey() - // 3. Log exception after recreating the key so we get extra info - Simber.e("Rebuilt event DB due to error", ex, tag = DB_CORRUPTION) - // 4. Rebuild database - eventDao = eventDatabaseFactory.build().eventDao - scopeDao = eventDatabaseFactory.build().scopeDao + private suspend fun recreateDatabase(ex: Throwable) = mutex.withLock { + eventDatabaseFactory.recreateDatabase() + Simber.e("Recreated event DB due to error", ex, tag = DB_CORRUPTION) + eventDao = eventDatabaseFactory.get().eventDao + scopeDao = eventDatabaseFactory.get().scopeDao } suspend fun saveEventScope(scope: EventScope) = useRoom(writingContext) { diff --git a/infra/events/src/test/java/com/simprints/infra/events/event/local/EventDatabaseFactoryTest.kt b/infra/events/src/test/java/com/simprints/infra/events/event/local/EventDatabaseFactoryTest.kt index e61272898d..d838cb0994 100644 --- a/infra/events/src/test/java/com/simprints/infra/events/event/local/EventDatabaseFactoryTest.kt +++ b/infra/events/src/test/java/com/simprints/infra/events/event/local/EventDatabaseFactoryTest.kt @@ -31,19 +31,36 @@ internal class EventDatabaseFactoryTest { } @Test - fun `test build db success`() = runTest { + fun `test get db success`() = runTest { // Given coEvery { securityManager.getLocalDbKeyOrThrow(dbName) } returns localDbKey mockkObject(EventRoomDatabase) val db: EventRoomDatabase = mockk() every { EventRoomDatabase.getDatabase(context, any(), dbName) } returns db // When - Truth.assertThat(dbEventDatabaseFactory.build()).isEqualTo(db) + Truth.assertThat(dbEventDatabaseFactory.get()).isEqualTo(db) verify { securityManager.getLocalDbKeyOrThrow(dbName) } } @Test - fun `test build db creates key if not exist`() = runTest { + fun `get should return the same db instance on multiple calls`() = runTest { + // Given + coEvery { securityManager.getLocalDbKeyOrThrow(dbName) } returns localDbKey + mockkObject(EventRoomDatabase) + every { EventRoomDatabase.getDatabase(context, any(), dbName) } returns mockk() + // When and Then + val db1 = dbEventDatabaseFactory.get() + val db2 = dbEventDatabaseFactory.get() + Truth.assertThat(db1).isSameInstanceAs(db2) + // Verify that getLocalDbKeyOrThrow is called only once + verify(exactly = 1) { + securityManager.getLocalDbKeyOrThrow(dbName) + EventRoomDatabase.getDatabase(context, any(), dbName) + } + } + + @Test + fun `test get db creates key if not exist`() = runTest { // Given coEvery { securityManager.getLocalDbKeyOrThrow(dbName) } throws Exception() andThen localDbKey justRun { securityManager.createLocalDatabaseKeyIfMissing(dbName) } @@ -51,12 +68,12 @@ internal class EventDatabaseFactoryTest { val db: EventRoomDatabase = mockk() every { EventRoomDatabase.getDatabase(context, any(), dbName) } returns db // When and Then - Truth.assertThat(dbEventDatabaseFactory.build()).isEqualTo(db) + Truth.assertThat(dbEventDatabaseFactory.get()).isEqualTo(db) verify(exactly = 2) { securityManager.getLocalDbKeyOrThrow(dbName) } } @Test(expected = Exception::class) - fun `test build db falure`() = runTest { + fun `test get db falure`() = runTest { // Given coEvery { securityManager.getLocalDbKeyOrThrow(dbName) } throws Exception() justRun { securityManager.createLocalDatabaseKeyIfMissing(dbName) } @@ -64,25 +81,22 @@ internal class EventDatabaseFactoryTest { val db: EventRoomDatabase = mockk() every { EventRoomDatabase.getDatabase(context, any(), dbName) } returns db // When calling build it should throw exception - dbEventDatabaseFactory.build() - } - - @Test - fun deleteDatabase() { - // When - dbEventDatabaseFactory.deleteDatabase() - - // Then - verify { context.deleteDatabase(dbName) } + dbEventDatabaseFactory.get() } @Test - fun recreateDatabaseKey() { + fun recreateDatabase() { // Given justRun { securityManager.recreateLocalDatabaseKey(dbName) } + coEvery { securityManager.getLocalDbKeyOrThrow(dbName) } returns localDbKey + justRun { securityManager.createLocalDatabaseKeyIfMissing(dbName) } + mockkObject(EventRoomDatabase) + val db: EventRoomDatabase = mockk() + every { EventRoomDatabase.getDatabase(context, any(), dbName) } returns db // When - dbEventDatabaseFactory.recreateDatabaseKey() + dbEventDatabaseFactory.recreateDatabase() // Then + verify { context.deleteDatabase(dbName) } verify { securityManager.recreateLocalDatabaseKey(dbName) } } } diff --git a/infra/events/src/test/java/com/simprints/infra/events/event/local/EventLocalDataSourceTest.kt b/infra/events/src/test/java/com/simprints/infra/events/event/local/EventLocalDataSourceTest.kt index 18c098b11f..1f4f8f3c96 100644 --- a/infra/events/src/test/java/com/simprints/infra/events/event/local/EventLocalDataSourceTest.kt +++ b/infra/events/src/test/java/com/simprints/infra/events/event/local/EventLocalDataSourceTest.kt @@ -4,9 +4,9 @@ import android.content.Context import android.database.sqlite.SQLiteDatabaseCorruptException import android.database.sqlite.SQLiteException import androidx.room.Room -import androidx.test.core.app.ApplicationProvider -import androidx.test.ext.junit.runners.AndroidJUnit4 -import com.google.common.truth.Truth.assertThat +import androidx.test.core.app.* +import androidx.test.ext.junit.runners.* +import com.google.common.truth.Truth.* import com.simprints.core.tools.json.JsonHelper import com.simprints.infra.events.event.domain.models.Event import com.simprints.infra.events.event.domain.models.EventType.CALLBACK_ENROLMENT @@ -20,7 +20,8 @@ import com.simprints.testtools.common.syntax.assertThrows import com.simprints.testtools.unit.robolectric.ShadowAndroidXMultiDex import dagger.hilt.android.testing.HiltTestApplication import io.mockk.* -import io.mockk.impl.annotations.RelaxedMockK +import io.mockk.impl.annotations.* +import kotlinx.coroutines.ExperimentalCoroutinesApi import kotlinx.coroutines.flow.flow import kotlinx.coroutines.flow.flowOf import kotlinx.coroutines.flow.toList @@ -44,6 +45,7 @@ internal class EventLocalDataSourceTest { @RelaxedMockK lateinit var eventDatabaseFactory: EventDatabaseFactory + @OptIn(ExperimentalCoroutinesApi::class) @Before fun setup() { MockKAnnotations.init(this) @@ -57,7 +59,7 @@ internal class EventLocalDataSourceTest { eventDao = db.eventDao scopeDao = db.scopeDao - every { eventDatabaseFactory.build() } returns db + every { eventDatabaseFactory.get() } returns db mockDaoLoadToMakeNothing() eventLocalDataSource = EventLocalDataSource( @@ -78,9 +80,8 @@ internal class EventLocalDataSourceTest { eventLocalDataSource.loadAllEvents() // Then verify { - eventDatabaseFactory.deleteDatabase() - eventDatabaseFactory.recreateDatabaseKey() - eventDatabaseFactory.build() + eventDatabaseFactory.recreateDatabase() + eventDatabaseFactory.get() } coVerify(exactly = 2) { eventDao.loadAll() } } @@ -95,9 +96,7 @@ internal class EventLocalDataSourceTest { eventLocalDataSource.loadAllEvents() // Then verify { - eventDatabaseFactory.deleteDatabase() - eventDatabaseFactory.recreateDatabaseKey() - eventDatabaseFactory.build() + eventDatabaseFactory.recreateDatabase() } coVerify(exactly = 2) { eventDao.loadAll() } } @@ -110,7 +109,7 @@ internal class EventLocalDataSourceTest { assertThrows { eventLocalDataSource.loadAllEvents() } // Then verify(exactly = 0) { - eventDatabaseFactory.deleteDatabase() + eventDatabaseFactory.recreateDatabase() } coVerify(exactly = 1) { eventDao.loadAll() } } @@ -123,7 +122,7 @@ internal class EventLocalDataSourceTest { assertThrows { eventLocalDataSource.loadAllEvents() } // Then verify(exactly = 0) { - eventDatabaseFactory.deleteDatabase() + eventDatabaseFactory.recreateDatabase() } coVerify(exactly = 1) { eventDao.loadAll() } } @@ -137,11 +136,7 @@ internal class EventLocalDataSourceTest { // When eventLocalDataSource.observeEventCount().toList() // Then - verify { - eventDatabaseFactory.deleteDatabase() - eventDatabaseFactory.recreateDatabaseKey() - eventDatabaseFactory.build() - } + verify { eventDatabaseFactory.recreateDatabase() } coVerify(exactly = 2) { eventDao.observeCount() } } @@ -155,7 +150,7 @@ internal class EventLocalDataSourceTest { } // Then verify(exactly = 0) { - eventDatabaseFactory.deleteDatabase() + eventDatabaseFactory.recreateDatabase() } coVerify(exactly = 1) { eventDao.observeCount() } } @@ -169,11 +164,7 @@ internal class EventLocalDataSourceTest { // When eventLocalDataSource.observeEventCount().toList() // Then - verify { - eventDatabaseFactory.deleteDatabase() - eventDatabaseFactory.recreateDatabaseKey() - eventDatabaseFactory.build() - } + verify { eventDatabaseFactory.recreateDatabase() } coVerify(exactly = 2) { eventDao.observeCount() } } @@ -187,7 +178,7 @@ internal class EventLocalDataSourceTest { } // Then verify(exactly = 0) { - eventDatabaseFactory.deleteDatabase() + eventDatabaseFactory.recreateDatabase() } coVerify(exactly = 1) { eventDao.observeCount() } } @@ -202,9 +193,7 @@ internal class EventLocalDataSourceTest { val count = eventLocalDataSource.observeEventCount().toList() // Then verify { - eventDatabaseFactory.deleteDatabase() - eventDatabaseFactory.recreateDatabaseKey() - eventDatabaseFactory.build() + eventDatabaseFactory.recreateDatabase() } coVerify(exactly = 2) { eventDao.observeCount() } assertThat(count).isEqualTo(listOf(1, 2, 3)) @@ -219,11 +208,7 @@ internal class EventLocalDataSourceTest { // When eventLocalDataSource.observeEventCount().toList() // Then - verify { - eventDatabaseFactory.deleteDatabase() - eventDatabaseFactory.recreateDatabaseKey() - eventDatabaseFactory.build() - } + verify { eventDatabaseFactory.recreateDatabase() } coVerify(exactly = 2) { eventDao.observeCount() } } @@ -237,7 +222,7 @@ internal class EventLocalDataSourceTest { eventLocalDataSource.observeEventCount().toList() } // Then - verify(exactly = 0) { eventDatabaseFactory.deleteDatabase() } + verify(exactly = 0) { eventDatabaseFactory.recreateDatabase() } coVerify(exactly = 1) { eventDao.observeCount() } } @@ -252,9 +237,7 @@ internal class EventLocalDataSourceTest { eventLocalDataSource.observeEventCount().toList() // Then verify { - eventDatabaseFactory.deleteDatabase() - eventDatabaseFactory.recreateDatabaseKey() - eventDatabaseFactory.build() + eventDatabaseFactory.recreateDatabase() } coVerify(exactly = 2) { eventDao.observeCount() } } @@ -268,7 +251,7 @@ internal class EventLocalDataSourceTest { eventLocalDataSource.observeEventCount().toList() } // Then - verify(exactly = 0) { eventDatabaseFactory.deleteDatabase() } + verify(exactly = 0) { eventDatabaseFactory.recreateDatabase() } coVerify(exactly = 1) { eventDao.observeCount() } } @@ -282,9 +265,7 @@ internal class EventLocalDataSourceTest { val count = eventLocalDataSource.observeEventCount().toList() // Then verify { - eventDatabaseFactory.deleteDatabase() - eventDatabaseFactory.recreateDatabaseKey() - eventDatabaseFactory.build() + eventDatabaseFactory.recreateDatabase() } coVerify(exactly = 2) { eventDao.observeCount() } assertThat(count).isEqualTo(listOf(1, 2, 3)) @@ -301,9 +282,7 @@ internal class EventLocalDataSourceTest { } // Then verify { - eventDatabaseFactory.deleteDatabase() - eventDatabaseFactory.recreateDatabaseKey() - eventDatabaseFactory.build() + eventDatabaseFactory.recreateDatabase() } coVerify(exactly = 2) { eventDao.observeCount() } } @@ -477,7 +456,7 @@ internal class EventLocalDataSourceTest { coEvery { scopeDao.count(any()) } returns 0 every { db.eventDao } returns eventDao every { db.scopeDao } returns scopeDao - every { eventDatabaseFactory.build() } returns db + every { eventDatabaseFactory.get() } returns db } @After