diff --git a/app/src/main/java/com/urik/keyboard/UrikInputMethodService.kt b/app/src/main/java/com/urik/keyboard/UrikInputMethodService.kt index 78338b88..11a42a3c 100644 --- a/app/src/main/java/com/urik/keyboard/UrikInputMethodService.kt +++ b/app/src/main/java/com/urik/keyboard/UrikInputMethodService.kt @@ -37,6 +37,7 @@ import com.urik.keyboard.service.AutofillStateTracker import com.urik.keyboard.service.CharacterVariationService import com.urik.keyboard.service.ClipboardMonitorService import com.urik.keyboard.service.EmojiSearchManager +import com.urik.keyboard.service.EnglishPronounI import com.urik.keyboard.service.InputMethod import com.urik.keyboard.service.InputStateManager import com.urik.keyboard.service.LanguageManager @@ -47,7 +48,6 @@ import com.urik.keyboard.service.SpellConfirmationState import com.urik.keyboard.service.SuggestionPipeline import com.urik.keyboard.service.TextInputProcessor import com.urik.keyboard.service.ViewCallback -import com.urik.keyboard.service.EnglishPronounI import com.urik.keyboard.service.WordLearningEngine import com.urik.keyboard.service.WordState import com.urik.keyboard.settings.KeyboardSettings @@ -95,6 +95,9 @@ class UrikInputMethodService : @Inject lateinit var swipeDetector: SwipeDetector + @Inject + lateinit var streamingScoringEngine: com.urik.keyboard.ui.keyboard.components.StreamingScoringEngine + @Inject lateinit var languageManager: LanguageManager @@ -229,6 +232,7 @@ class UrikInputMethodService : private fun isSentenceEndingPunctuation(char: Char): Boolean = UCharacter.hasBinaryProperty(char.code, UProperty.S_TERM) private fun coordinateStateClear() { + streamingScoringEngine.cancelActiveGesture() outputBridge.coordinateStateClear() } @@ -1159,7 +1163,11 @@ class UrikInputMethodService : val cursorPosInWord = if (inputState.composingRegionStart != -1 && inputState.displayBuffer.isNotEmpty()) { val absoluteCursorPos = outputBridge.safeGetCursorPosition() - CursorEditingUtils.calculateCursorPositionInWord(absoluteCursorPos, inputState.composingRegionStart, inputState.displayBuffer.length) + CursorEditingUtils.calculateCursorPositionInWord( + absoluteCursorPos, + inputState.composingRegionStart, + inputState.displayBuffer.length, + ) } else { inputState.displayBuffer.length } @@ -1313,7 +1321,11 @@ class UrikInputMethodService : outputBridge.highlightCurrentWord() val suggestions = textInputProcessor.getSuggestions(inputState.displayBuffer) - val displaySuggestions = suggestionPipeline.storeAndCapitalizeSuggestions(suggestions, inputState.isCurrentWordAtSentenceStart) + val displaySuggestions = + suggestionPipeline.storeAndCapitalizeSuggestions( + suggestions, + inputState.isCurrentWordAtSentenceStart, + ) inputState.pendingSuggestions = displaySuggestions if (displaySuggestions.isNotEmpty()) { swipeKeyboardView?.updateSuggestions(displaySuggestions) @@ -1520,8 +1532,7 @@ class UrikInputMethodService : return caseTransformer.applyCasing(suggestion, keyboardState, isSentenceStart) } - private fun getEnglishPronounIForm(normalizedWord: String): String? = - EnglishPronounI.capitalize(normalizedWord) + private fun getEnglishPronounIForm(normalizedWord: String): String? = EnglishPronounI.capitalize(normalizedWord) private fun handleSuggestionSelected(suggestion: String) { serviceScope.launch { @@ -1701,7 +1712,8 @@ class UrikInputMethodService : val actualCursorPos = outputBridge.safeGetCursorPosition() if (inputState.displayBuffer.isNotEmpty() && inputState.composingRegionStart != -1) { - val expectedCursorRange = inputState.composingRegionStart..(inputState.composingRegionStart + inputState.displayBuffer.length) + val expectedCursorRange = + inputState.composingRegionStart..(inputState.composingRegionStart + inputState.displayBuffer.length) if (actualCursorPos !in expectedCursorRange) { invalidateComposingStateOnCursorJump() } @@ -1811,12 +1823,20 @@ class UrikInputMethodService : val cursorPosInWord = if (inputState.composingRegionStart != -1) { - CursorEditingUtils.calculateCursorPositionInWord(absoluteCursorPos, inputState.composingRegionStart, inputState.displayBuffer.length) + CursorEditingUtils.calculateCursorPositionInWord( + absoluteCursorPos, + inputState.composingRegionStart, + inputState.displayBuffer.length, + ) } else { val potentialStart = absoluteCursorPos - inputState.displayBuffer.length if (potentialStart >= 0) { inputState.composingRegionStart = potentialStart - CursorEditingUtils.calculateCursorPositionInWord(absoluteCursorPos, inputState.composingRegionStart, inputState.displayBuffer.length) + CursorEditingUtils.calculateCursorPositionInWord( + absoluteCursorPos, + inputState.composingRegionStart, + inputState.displayBuffer.length, + ) } else { inputState.displayBuffer.length } @@ -2099,7 +2119,11 @@ class UrikInputMethodService : inputState.pendingWordForLearning = inputState.displayBuffer outputBridge.highlightCurrentWord() - val displaySuggestions = suggestionPipeline.storeAndCapitalizeSuggestions(suggestions, inputState.isCurrentWordAtSentenceStart) + val displaySuggestions = + suggestionPipeline.storeAndCapitalizeSuggestions( + suggestions, + inputState.isCurrentWordAtSentenceStart, + ) inputState.pendingSuggestions = displaySuggestions if (displaySuggestions.isNotEmpty()) { swipeKeyboardView?.updateSuggestions(displaySuggestions) @@ -2483,6 +2507,7 @@ class UrikInputMethodService : override fun onConfigurationChanged(newConfig: android.content.res.Configuration) { super.onConfigurationChanged(newConfig) + streamingScoringEngine.cancelActiveGesture() val currentDensity = resources.displayMetrics.density @@ -2645,6 +2670,7 @@ class UrikInputMethodService : } override fun onDestroy() { + streamingScoringEngine.shutdown() wordFrequencyRepository.clearCache() autofillStateTracker.cleanup() diff --git a/app/src/main/java/com/urik/keyboard/di/KeyboardModule.kt b/app/src/main/java/com/urik/keyboard/di/KeyboardModule.kt index 2316bea0..cfef30fd 100644 --- a/app/src/main/java/com/urik/keyboard/di/KeyboardModule.kt +++ b/app/src/main/java/com/urik/keyboard/di/KeyboardModule.kt @@ -19,6 +19,7 @@ import com.urik.keyboard.service.WordNormalizer import com.urik.keyboard.settings.SettingsRepository import com.urik.keyboard.ui.keyboard.components.PathGeometryAnalyzer import com.urik.keyboard.ui.keyboard.components.ResidualScorer +import com.urik.keyboard.ui.keyboard.components.StreamingScoringEngine import com.urik.keyboard.ui.keyboard.components.SwipeDetector import com.urik.keyboard.ui.keyboard.components.ZipfCheck import com.urik.keyboard.utils.CacheMemoryManager @@ -104,7 +105,7 @@ object KeyboardModule { @Provides @Singleton - fun provideSwipeDetector( + fun provideStreamingScoringEngine( spellCheckManager: SpellCheckManager, wordLearningEngine: WordLearningEngine, pathGeometryAnalyzer: PathGeometryAnalyzer, @@ -112,8 +113,8 @@ object KeyboardModule { residualScorer: ResidualScorer, zipfCheck: ZipfCheck, wordNormalizer: WordNormalizer, - ): SwipeDetector = - SwipeDetector( + ): StreamingScoringEngine = + StreamingScoringEngine( spellCheckManager, wordLearningEngine, pathGeometryAnalyzer, @@ -123,6 +124,12 @@ object KeyboardModule { wordNormalizer, ) + @Provides + @Singleton + fun provideSwipeDetector( + streamingScoringEngine: StreamingScoringEngine, + ): SwipeDetector = SwipeDetector(streamingScoringEngine) + @Provides @Singleton fun provideSpellCheckManager( diff --git a/app/src/main/java/com/urik/keyboard/service/WordLearningEngine.kt b/app/src/main/java/com/urik/keyboard/service/WordLearningEngine.kt index dee742ea..2f3b7166 100644 --- a/app/src/main/java/com/urik/keyboard/service/WordLearningEngine.kt +++ b/app/src/main/java/com/urik/keyboard/service/WordLearningEngine.kt @@ -4,7 +4,6 @@ import android.database.sqlite.SQLiteDatabaseCorruptException import android.database.sqlite.SQLiteDatabaseLockedException import android.database.sqlite.SQLiteException import android.database.sqlite.SQLiteFullException -import com.urik.keyboard.KeyboardConstants.MemoryConstants import com.urik.keyboard.data.database.LearnedWord import com.urik.keyboard.data.database.LearnedWordDao import com.urik.keyboard.data.database.WordSource diff --git a/app/src/main/java/com/urik/keyboard/ui/keyboard/components/GestureInterpolator.kt b/app/src/main/java/com/urik/keyboard/ui/keyboard/components/GestureInterpolator.kt new file mode 100644 index 00000000..5a65378f --- /dev/null +++ b/app/src/main/java/com/urik/keyboard/ui/keyboard/components/GestureInterpolator.kt @@ -0,0 +1,129 @@ +package com.urik.keyboard.ui.keyboard.components + +import kotlin.math.sqrt + +/** Catmull-Rom spline interpolation for raw touch input. */ +class GestureInterpolator(private val ringBuffer: SwipePointRingBuffer) { + + private val windowX = FloatArray(WINDOW_SIZE) + private val windowY = FloatArray(WINDOW_SIZE) + private val windowTimestamp = LongArray(WINDOW_SIZE) + private val windowPressure = FloatArray(WINDOW_SIZE) + private val windowVelocity = FloatArray(WINDOW_SIZE) + private var windowCount = 0 + private var rawPointIndex = 0 + val rawPointCount: Int get() = rawPointIndex + + fun onRawPoint(x: Float, y: Float, timestamp: Long, pressure: Float, velocity: Float) { + rawPointIndex++ + if (windowCount < WINDOW_SIZE) { + val i = windowCount + windowX[i] = x + windowY[i] = y + windowTimestamp[i] = timestamp + windowPressure[i] = pressure + windowVelocity[i] = velocity + windowCount++ + + if (windowCount == WINDOW_SIZE) { + interpolateSegment() + } + + ringBuffer.write(x, y, timestamp, pressure, velocity) + return + } + + windowX[0] = windowX[1] + windowY[0] = windowY[1] + windowTimestamp[0] = windowTimestamp[1] + windowPressure[0] = windowPressure[1] + windowVelocity[0] = windowVelocity[1] + + windowX[1] = windowX[2] + windowY[1] = windowY[2] + windowTimestamp[1] = windowTimestamp[2] + windowPressure[1] = windowPressure[2] + windowVelocity[1] = windowVelocity[2] + + windowX[2] = windowX[3] + windowY[2] = windowY[3] + windowTimestamp[2] = windowTimestamp[3] + windowPressure[2] = windowPressure[3] + windowVelocity[2] = windowVelocity[3] + + windowX[3] = x + windowY[3] = y + windowTimestamp[3] = timestamp + windowPressure[3] = pressure + windowVelocity[3] = velocity + + interpolateSegment() + + ringBuffer.write(x, y, timestamp, pressure, velocity) + } + + private fun interpolateSegment() { + val dx = windowX[3] - windowX[2] + val dy = windowY[3] - windowY[2] + val segmentLength = sqrt(dx * dx + dy * dy) + + if (segmentLength < MIN_SEGMENT_FOR_INTERPOLATION) { + return + } + + val pointCount = ((segmentLength / TARGET_DENSITY_PX).toInt() - 1) + .coerceIn(0, MAX_INTERPOLATED_PER_SEGMENT) + + if (pointCount <= 0) return + + val p1x = windowX[1]; val p1y = windowY[1] + val p2x = windowX[2]; val p2y = windowY[2] + val p3x = windowX[3]; val p3y = windowY[3] + + val t1 = windowTimestamp[2] + val t2 = windowTimestamp[3] + val pr1 = windowPressure[2] + val pr2 = windowPressure[3] + val v1 = windowVelocity[2] + val v2 = windowVelocity[3] + + for (i in 1..pointCount) { + val t = i.toFloat() / (pointCount + 1) + val t2f = t * t + val t3f = t2f * t + + val interpX = ALPHA * ( + (-p1x + 2f * p2x - p3x) * t3f + + (2f * p1x - 4f * p2x + 2f * p3x) * t2f + + (-p1x + p3x) * t + + 2f * p2x + ) + + val interpY = ALPHA * ( + (-p1y + 2f * p2y - p3y) * t3f + + (2f * p1y - 4f * p2y + 2f * p3y) * t2f + + (-p1y + p3y) * t + + 2f * p2y + ) + + val interpTimestamp = t1 + ((t2 - t1) * t).toLong() + val interpPressure = pr1 + (pr2 - pr1) * t + val interpVelocity = v1 + (v2 - v1) * t + + ringBuffer.write(interpX, interpY, interpTimestamp, interpPressure, interpVelocity) + } + } + + fun reset() { + windowCount = 0 + rawPointIndex = 0 + } + + companion object { + private const val WINDOW_SIZE = 4 + private const val TARGET_DENSITY_PX = 6f + private const val MIN_SEGMENT_FOR_INTERPOLATION = 6f + private const val MAX_INTERPOLATED_PER_SEGMENT = 10 + private const val ALPHA = 0.5f + } +} diff --git a/app/src/main/java/com/urik/keyboard/ui/keyboard/components/PathGeometryAnalyzer.kt b/app/src/main/java/com/urik/keyboard/ui/keyboard/components/PathGeometryAnalyzer.kt index 461e937c..c91b2a6e 100644 --- a/app/src/main/java/com/urik/keyboard/ui/keyboard/components/PathGeometryAnalyzer.kt +++ b/app/src/main/java/com/urik/keyboard/ui/keyboard/components/PathGeometryAnalyzer.kt @@ -355,12 +355,13 @@ class PathGeometryAnalyzer flags.fill(false, 0, minOf(size, flags.size)) val coverageRadius = PATH_COVERAGE_RADIUS val radiusSq = coverageRadius * coverageRadius + val windowRadius = (size / 15).coerceIn(3, 20) letterPathIndices.forEachIndexed { letterIdx, pathIdx -> if (pathIdx in 0.. + Thread(r, "urik-swipe-scorer").apply { + priority = Thread.NORM_PRIORITY + 1 + isDaemon = true + } + } + private val scoringDispatcher = scoringExecutor.asCoroutineDispatcher() + private val scoringScope = CoroutineScope(scoringDispatcher + SupervisorJob()) + + private var tickerJob: Job? = null + + @Volatile private var keyPositions = emptyMap() + @Volatile private var liveCandidates = ArrayList(LIVE_SET_CAPACITY) + @Volatile private var gestureActive = false + @Volatile private var tickCount = 0 + @Volatile private var gestureStartTimeNanos = 0L + + @Volatile private var cachedDictionary = emptyMap() + @Volatile private var cachedLanguageCombination = emptyList() + @Volatile private var cachedAdaptiveSigmas = emptyMap() + @Volatile private var cachedKeyNeighborhoods = emptyMap() + @Volatile private var lastKeyPositionsHash = 0 + + @Volatile var lastCommittedWord: String = "" + @Volatile var currentLanguageTag: String = "en" + + private var fullDictionary = ArrayList() + + lateinit var ringBuffer: SwipePointRingBuffer + private set + + fun bindRingBuffer(buffer: SwipePointRingBuffer) { + ringBuffer = buffer + } + + fun startGesture( + currentKeyPositions: Map, + activeLanguages: List, + languageTag: String, + ) { + cancelActiveGesture() + + keyPositions = currentKeyPositions + currentLanguageTag = languageTag + gestureActive = true + tickCount = 0 + gestureStartTimeNanos = System.nanoTime() + liveCandidates.clear() + + scoringScope.launch { + try { + val dictionary = loadOrCacheDictionary(activeLanguages) + if (dictionary.isEmpty()) return@launch + + val indexed = buildDictionaryIndex(dictionary) + fullDictionary = ArrayList(indexed) + liveCandidates = ArrayList(indexed) + + updateAdaptiveSigmaCache(currentKeyPositions) + startTicker() + } catch (_: Exception) { } + } + } + + private fun startTicker() { + tickerJob?.cancel() + tickerJob = scoringScope.launch { + var nextTickNanos = System.nanoTime() + TICK_INTERVAL_NANOS + + while (isActive && gestureActive) { + val now = System.nanoTime() + val sleepMs = ((nextTickNanos - now) / 1_000_000L).coerceAtLeast(1L) + delay(sleepMs) + + if (!gestureActive) break + + onTick() + tickCount++ + nextTickNanos += TICK_INTERVAL_NANOS + } + } + } + + private fun onTick() { + if (!::ringBuffer.isInitialized) return + val path = ringBuffer.snapshot() + if (path.size < 3) return + + val currentKeyPositions = keyPositions + if (currentKeyPositions.isEmpty()) return + + val elapsedMs = (System.nanoTime() - gestureStartTimeNanos) / 1_000_000L + + when { + elapsedMs >= TRAVERSAL_PRUNE_MS && tickCount >= 3 -> { + val charsInBounds = computeCharsInBounds(path, currentKeyPositions) + val traversedKeys = computeTraversedKeys(path, currentKeyPositions) + liveCandidates = ArrayList(pruneByTraversal(liveCandidates, traversedKeys)) + liveCandidates = ArrayList(pruneByBounds(liveCandidates, charsInBounds)) + } + elapsedMs >= BOUNDS_PRUNE_MS && tickCount >= 2 -> { + val charsInBounds = computeCharsInBounds(path, currentKeyPositions) + liveCandidates = ArrayList(pruneByBounds(liveCandidates, charsInBounds)) + } + elapsedMs >= ANCHOR_PRUNE_MS && tickCount >= 1 -> { + val startAnchorKeys = computeStartAnchorKeys(path, currentKeyPositions) + liveCandidates = ArrayList(pruneByStartAnchor(liveCandidates, startAnchorKeys)) + } + } + } + + fun cancelActiveGesture() { + gestureActive = false + tickerJob?.cancel() + tickerJob = null + liveCandidates.clear() + tickCount = 0 + } + + suspend fun finalize( + swipePath: List, + rawPointCount: Int, + ): List = withContext(scoringDispatcher) { + gestureActive = false + tickerJob?.cancel() + + if (swipePath.isEmpty()) return@withContext emptyList() + + val currentKeyPositions = keyPositions + if (currentKeyPositions.isEmpty()) return@withContext emptyList() + + val maxLength = (rawPointCount / 5).coerceIn(5, 20) + val unfilteredCandidates = if (liveCandidates.isNotEmpty()) { + liveCandidates + } else { + fullDictionary + } + val candidates = unfilteredCandidates.filter { it.word.length <= maxLength } + + if (candidates.isEmpty()) return@withContext emptyList() + + val sigmaCache = cachedAdaptiveSigmas + val neighborhoodCache = cachedKeyNeighborhoods + + val signal = SwipeSignal.extract( + swipePath, + currentKeyPositions, + pathGeometryAnalyzer, + sigmaCache, + rawPointCount, + ) + + val bigramPredictions: Set = + if (lastCommittedWord.isNotBlank()) { + wordFrequencyRepository.getBigramPredictions( + lastCommittedWord, currentLanguageTag, + ).toSet() + } else { + emptySet() + } + + var maxFrequencySeen = 0L + val results = ArrayList(candidates.size / 4) + + for (i in candidates.indices) { + if (i % 50 == 0) yield() + + val entry = candidates[i] + if (entry.rawFrequency > maxFrequencySeen) { + maxFrequencySeen = entry.rawFrequency + } + + val result = residualScorer.scoreCandidate( + entry, signal, currentKeyPositions, + sigmaCache, neighborhoodCache, maxFrequencySeen, + ) ?: continue + + results.add(result) + + if (result.combinedScore > EXCELLENT_CANDIDATE_THRESHOLD) { + var excellentCount = 0 + for (candidate in results) { + if (candidate.combinedScore > 0.90f) excellentCount++ + } + if (excellentCount >= MIN_EXCELLENT_CANDIDATES) break + } + } + + val wordFrequencyMap = cachedDictionary + + val arbitration = zipfCheck.arbitrate( + results, + signal.geometricAnalysis, + currentKeyPositions, + bigramPredictions, + wordFrequencyMap, + rawPointCount, + ) + + return@withContext arbitration.candidates + } + + fun pruneByStartAnchor( + candidates: List, + startKeys: Set, + ): List { + if (startKeys.isEmpty()) return candidates + return candidates.filter { it.firstChar in startKeys } + } + + fun pruneByBounds( + candidates: List, + charsInBounds: Set, + ): List { + if (charsInBounds.isEmpty()) return candidates + return candidates.filter { entry -> + val uniqueChars = entry.word.lowercase().toSet() + val outOfBoundsCount = uniqueChars.count { it !in charsInBounds } + outOfBoundsCount <= BOUNDS_SAFETY_MARGIN + } + } + + fun pruneByTraversal( + candidates: List, + traversedKeys: Set, + ): List { + if (traversedKeys.size < 2) return candidates + return candidates.filter { entry -> + val wordChars = entry.word.lowercase().toSet() + val traversedOverlap = wordChars.count { it in traversedKeys } + traversedOverlap.toFloat() / wordChars.size >= TRAVERSAL_MIN_OVERLAP + } + } + + private fun computeStartAnchorKeys( + path: List, + positions: Map, + ): Set { + if (path.isEmpty()) return emptySet() + + val sampleCount = minOf(5, path.size) + var cx = 0f + var cy = 0f + for (i in 0 until sampleCount) { + cx += path[i].x + cy += path[i].y + } + cx /= sampleCount + cy /= sampleCount + + val thresholdSq = START_ANCHOR_RADIUS * START_ANCHOR_RADIUS + val result = mutableSetOf() + for ((char, pos) in positions) { + val dx = pos.x - cx + val dy = pos.y - cy + if (dx * dx + dy * dy < thresholdSq) { + result.add(char) + } + } + return result + } + + private fun computeCharsInBounds( + path: List, + positions: Map, + ): Set { + if (path.isEmpty()) return emptySet() + + var minX = Float.MAX_VALUE + var maxX = Float.MIN_VALUE + var minY = Float.MAX_VALUE + var maxY = Float.MIN_VALUE + for (point in path) { + if (point.x < minX) minX = point.x + if (point.x > maxX) maxX = point.x + if (point.y < minY) minY = point.y + if (point.y > maxY) maxY = point.y + } + + minX -= BOUNDS_MARGIN + maxX += BOUNDS_MARGIN + minY -= BOUNDS_MARGIN + maxY += BOUNDS_MARGIN + + val result = mutableSetOf() + for ((char, pos) in positions) { + if (pos.x in minX..maxX && pos.y in minY..maxY) { + result.add(char) + } + } + return result + } + + private fun computeTraversedKeys( + path: List, + positions: Map, + ): Set { + val result = mutableSetOf() + val traversalRadiusSq = TRAVERSAL_RADIUS * TRAVERSAL_RADIUS + + for (point in path) { + for ((char, pos) in positions) { + if (char in result) continue + val dx = point.x - pos.x + val dy = point.y - pos.y + if (dx * dx + dy * dy < traversalRadiusSq) { + result.add(char) + } + } + } + return result + } + + private suspend fun loadOrCacheDictionary( + compatibleLanguages: List, + ): Map { + if (compatibleLanguages == cachedLanguageCombination && cachedDictionary.isNotEmpty()) { + return cachedDictionary + } + + val dictionaryWordsMap = spellCheckManager.getCommonWordsForLanguages(compatibleLanguages) + val learnedWordsMap = wordLearningEngine.getLearnedWordsForSwipeAllLanguages( + compatibleLanguages, 2, 20, + ) + + val mergedMap = HashMap(dictionaryWordsMap.size + learnedWordsMap.size) + dictionaryWordsMap.forEach { (word, freq) -> mergedMap[word] = freq } + learnedWordsMap.forEach { (word, freq) -> + mergedMap[word] = maxOf(mergedMap[word] ?: 0, freq) + } + + cachedDictionary = mergedMap + cachedLanguageCombination = compatibleLanguages + return mergedMap + } + + private fun buildDictionaryIndex( + wordFrequencyMap: Map, + ): List { + val sorted = wordFrequencyMap.entries + .filter { (word, _) -> word.length in 2..20 } + .sortedByDescending { it.value } + + return sorted.mapIndexed { rank, (word, frequency) -> + SwipeDetector.DictionaryEntry( + word = word, + frequencyScore = ln(frequency.toFloat() + 1f) / 20f, + rawFrequency = frequency.toLong(), + firstChar = wordNormalizer.stripDiacritics( + word.first().toString(), + ).first().lowercaseChar(), + uniqueLetterCount = word.toSet().size, + frequencyTier = SwipeDetector.FrequencyTier.fromRank(rank), + ) + } + } + + private fun updateAdaptiveSigmaCache(positions: Map) { + val hash = positions.hashCode() + if (hash != lastKeyPositionsHash) { + val newSigmas = mutableMapOf() + positions.keys.forEach { char -> + newSigmas[char] = pathGeometryAnalyzer.calculateAdaptiveSigma(char, positions) + } + cachedAdaptiveSigmas = newSigmas + cachedKeyNeighborhoods = pathGeometryAnalyzer.computeKeyNeighborhoods(positions) + lastKeyPositionsHash = hash + } + } + + fun shutdown() { + cancelActiveGesture() + scoringScope.cancel() + scoringDispatcher.close() + scoringExecutor.shutdown() + } + + companion object { + private const val TICK_INTERVAL_NANOS = 50_000_000L + private const val ANCHOR_PRUNE_MS = 100L + private const val BOUNDS_PRUNE_MS = 200L + private const val TRAVERSAL_PRUNE_MS = 300L + private const val START_ANCHOR_RADIUS = 85f + private const val BOUNDS_MARGIN = 60f + private const val TRAVERSAL_RADIUS = 40f + private const val BOUNDS_SAFETY_MARGIN = 1 + private const val TRAVERSAL_MIN_OVERLAP = 0.30f + private const val LIVE_SET_CAPACITY = 50 + private const val EXCELLENT_CANDIDATE_THRESHOLD = 0.95f + private const val MIN_EXCELLENT_CANDIDATES = 3 + } +} diff --git a/app/src/main/java/com/urik/keyboard/ui/keyboard/components/SwipeDetector.kt b/app/src/main/java/com/urik/keyboard/ui/keyboard/components/SwipeDetector.kt index 4a591ef0..5d547c54 100644 --- a/app/src/main/java/com/urik/keyboard/ui/keyboard/components/SwipeDetector.kt +++ b/app/src/main/java/com/urik/keyboard/ui/keyboard/components/SwipeDetector.kt @@ -3,26 +3,18 @@ package com.urik.keyboard.ui.keyboard.components import android.graphics.PointF -import android.util.Log import android.view.MotionEvent import com.ibm.icu.lang.UScript import com.ibm.icu.util.ULocale -import com.urik.keyboard.KeyboardConstants.GeometricScoringConstants -import com.urik.keyboard.data.WordFrequencyRepository import com.urik.keyboard.model.KeyboardKey -import com.urik.keyboard.service.SpellCheckManager -import com.urik.keyboard.service.WordLearningEngine -import com.urik.keyboard.service.WordNormalizer import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.Job import kotlinx.coroutines.SupervisorJob import kotlinx.coroutines.launch import kotlinx.coroutines.withContext -import kotlinx.coroutines.yield import javax.inject.Inject import javax.inject.Singleton -import kotlin.math.ln import kotlin.math.sqrt /** @@ -46,13 +38,7 @@ data class WordCandidate( class SwipeDetector @Inject constructor( - private val spellCheckManager: SpellCheckManager, - private val wordLearningEngine: WordLearningEngine, - private val pathGeometryAnalyzer: PathGeometryAnalyzer, - private val wordFrequencyRepository: WordFrequencyRepository, - private val residualScorer: ResidualScorer, - private val zipfCheck: ZipfCheck, - private val wordNormalizer: WordNormalizer, + private val streamingScoringEngine: StreamingScoringEngine, ) { /** * Captured swipe point with metadata. @@ -108,12 +94,19 @@ class SwipeDetector @Suppress("ktlint:standard:backing-property-naming") private var _swipeListener: SwipeListener? = null + private val ringBuffer = SwipePointRingBuffer() + private val interpolator = GestureInterpolator(ringBuffer) + + init { + streamingScoringEngine.bindRingBuffer(ringBuffer) + } + private var lastUpdateTime = 0L private var isSwiping = false - private var swipePoints = ArrayList(MAX_SWIPE_POINTS) private var startTime = 0L private var pointCounter = 0 - private var firstPoint: SwipePoint? = null + private var firstPointX = 0f + private var firstPointY = 0f private var startingKey: KeyboardKey? = null private var lastDeltaX = 0f private var directionReversals = 0 @@ -145,30 +138,6 @@ class SwipeDetector @Volatile private var keyCharacterPositions = emptyMap() - @Volatile - private var cachedSwipeDictionary = emptyMap() - - @Volatile - private var cachedLanguageCombination = emptyList() - - @Volatile - private var cachedScriptCode = UScript.LATIN - - @Volatile - private var cachedAdaptiveSigmas = emptyMap() - - @Volatile - private var cachedKeyNeighborhoods = emptyMap() - - @Volatile - private var lastKeyPositionsHash = 0 - - @Volatile - private var lastCommittedWord: String = "" - - @Volatile - private var currentLanguageTag: String = "en" - private val scopeJob = SupervisorJob() private val scope = CoroutineScope(Dispatchers.Default + scopeJob) private var scoringJob: Job? = null @@ -280,11 +249,11 @@ class SwipeDetector } fun updateLastCommittedWord(word: String) { - lastCommittedWord = word + streamingScoringEngine.lastCommittedWord = word } fun updateCurrentLanguage(tag: String) { - currentLanguageTag = tag + streamingScoringEngine.currentLanguageTag = tag } /** @@ -375,16 +344,13 @@ class SwipeDetector startingKey = key val transformed = transformTouchCoordinate(event.x, event.y) - val point = - SwipePoint( - x = transformed.x, - y = transformed.y, - timestamp = startTime, - pressure = event.pressure, - velocity = 0.0f, - ) - firstPoint = point - swipePoints.add(point) + firstPointX = transformed.x + firstPointY = transformed.y + + interpolator.onRawPoint( + transformed.x, transformed.y, event.eventTime, + event.pressure, 0f, + ) lastCheckX = transformed.x } @@ -392,153 +358,152 @@ class SwipeDetector event: MotionEvent, keyAt: (Float, Float) -> KeyboardKey?, ) { - firstPoint?.let { start -> - val now = System.currentTimeMillis() - val timeSinceDown = now - startTime - val distance = calculateDistance(start.x, start.y, event.x, event.y) - - if (lastCheckX != 0f) { - val deltaX = event.x - lastCheckX - if (lastDeltaX != 0f && deltaX != 0f) { - if ((lastDeltaX > 0) != (deltaX > 0)) { - directionReversals++ - } - } - lastDeltaX = deltaX - } - lastCheckX = event.x - - for (h in 0 until event.historySize) { - val histX = event.getHistoricalX(h) - val histY = event.getHistoricalY(h) - val histTime = event.getHistoricalEventTime(h) - val histTransformed = transformTouchCoordinate(histX, histY) - val histLastPoint = swipePoints.lastOrNull() - - val histVelocity = - if (histLastPoint != null) { - val dx = histTransformed.x - histLastPoint.x - val dy = histTransformed.y - histLastPoint.y - val dt = (histTime - histLastPoint.timestamp).coerceAtLeast(1L).toFloat() - sqrt(dx * dx + dy * dy) / dt - } else { - 0f - } - - val histDist = - if (histLastPoint != null) { - calculateDistance(histLastPoint.x, histLastPoint.y, histTransformed.x, histTransformed.y) - } else { - Float.MAX_VALUE - } + if (firstPointX == 0f && firstPointY == 0f) return - if (histDist > 4f) { - swipePoints.add( - SwipePoint( - x = histTransformed.x, - y = histTransformed.y, - timestamp = histTime, - pressure = event.getHistoricalPressure(h), - velocity = histVelocity, - ), - ) + val now = System.currentTimeMillis() + val timeSinceDown = now - startTime + val distance = calculateDistance(firstPointX, firstPointY, event.x, event.y) + + if (lastCheckX != 0f) { + val deltaX = event.x - lastCheckX + if (lastDeltaX != 0f && deltaX != 0f) { + if ((lastDeltaX > 0) != (deltaX > 0)) { + directionReversals++ } } + lastDeltaX = deltaX + } + lastCheckX = event.x - val transformed = transformTouchCoordinate(event.x, event.y) - val lastPoint = swipePoints.lastOrNull() - val velocityFromLast = - if (lastPoint != null && timeSinceDown > 0) { - val dx = transformed.x - lastPoint.x - val dy = transformed.y - lastPoint.y - sqrt(dx * dx + dy * dy) / (now - lastPoint.timestamp).coerceAtLeast(1L).toFloat() + for (h in 0 until event.historySize) { + val histX = event.getHistoricalX(h) + val histY = event.getHistoricalY(h) + val histTime = event.getHistoricalEventTime(h) + val histTransformed = transformTouchCoordinate(histX, histY) + val histLastPoint = ringBuffer.peekLast() + + val histVelocity = + if (histLastPoint != null) { + val dx = histTransformed.x - histLastPoint.x + val dy = histTransformed.y - histLastPoint.y + val dt = (histTime - histLastPoint.timestamp).coerceAtLeast(1L).toFloat() + sqrt(dx * dx + dy * dy) / dt } else { 0f } - val distFromLast = - if (lastPoint != null) { - calculateDistance(lastPoint.x, lastPoint.y, transformed.x, transformed.y) + val histDist = + if (histLastPoint != null) { + calculateDistance(histLastPoint.x, histLastPoint.y, histTransformed.x, histTransformed.y) } else { Float.MAX_VALUE } - if (distFromLast > 4f) { - swipePoints.add( - SwipePoint( - x = transformed.x, - y = transformed.y, - timestamp = now, - pressure = event.pressure, - velocity = velocityFromLast, - ), + if (histDist > 4f) { + interpolator.onRawPoint( + histTransformed.x, histTransformed.y, histTime, + event.getHistoricalPressure(h), histVelocity, ) } + } + + val transformed = transformTouchCoordinate(event.x, event.y) + val lastPoint = ringBuffer.peekLast() + val velocityFromLast = + if (lastPoint != null && timeSinceDown > 0) { + val dx = transformed.x - lastPoint.x + val dy = transformed.y - lastPoint.y + sqrt(dx * dx + dy * dy) / (event.eventTime - lastPoint.timestamp).coerceAtLeast(1L).toFloat() + } else { + 0f + } - if (swipePoints.size >= 3 && timeSinceDown >= SWIPE_TIME_THRESHOLD_MS) { - var largeGapCount = 0 - for (i in 0 until swipePoints.size - 1) { - val prev = swipePoints[i] - val curr = swipePoints[i + 1] - if (calculateDistance(prev.x, prev.y, curr.x, curr.y) > MAX_CONSECUTIVE_GAP_PX) { - largeGapCount++ - } - } + val distFromLast = + if (lastPoint != null) { + calculateDistance(lastPoint.x, lastPoint.y, transformed.x, transformed.y) + } else { + Float.MAX_VALUE + } - val gapRatio = largeGapCount.toFloat() / (swipePoints.size - 1) - if (gapRatio > 0.5f) { - reset() - return + if (distFromLast > 4f) { + interpolator.onRawPoint( + transformed.x, transformed.y, event.eventTime, + event.pressure, velocityFromLast, + ) + } + + if (ringBuffer.size >= 3 && timeSinceDown >= SWIPE_TIME_THRESHOLD_MS) { + val snapshot = ringBuffer.snapshot() + var largeGapCount = 0 + for (i in 0 until snapshot.size - 1) { + val prev = snapshot[i] + val curr = snapshot[i + 1] + if (calculateDistance(prev.x, prev.y, curr.x, curr.y) > MAX_CONSECUTIVE_GAP_PX) { + largeGapCount++ } } - if (directionReversals >= 3) { + val gapRatio = largeGapCount.toFloat() / (snapshot.size - 1) + if (gapRatio > 0.5f) { reset() return } + } - val isHighVelocity = timeSinceDown < SWIPE_TIME_THRESHOLD_MS - val effectiveDistance = - if (isHighVelocity) { - swipeStartDistancePx * HIGH_VELOCITY_DISTANCE_MULTIPLIER - } else { - swipeStartDistancePx - } + if (directionReversals >= 3) { + reset() + return + } - if (distance > effectiveDistance) { - val currentKey = keyAt(event.x, event.y) - if (currentKey == startingKey) { - return - } + val isHighVelocity = timeSinceDown < SWIPE_TIME_THRESHOLD_MS + val effectiveDistance = + if (isHighVelocity) { + swipeStartDistancePx * HIGH_VELOCITY_DISTANCE_MULTIPLIER + } else { + swipeStartDistancePx + } - val avgVelocity = distance / timeSinceDown.coerceAtLeast(1L).toFloat() - if (avgVelocity > MAX_SWIPE_VELOCITY_PX_PER_MS) { - return - } + if (distance > effectiveDistance) { + val currentKey = keyAt(event.x, event.y) + if (currentKey == startingKey) { + return + } - if (isPeckLikeMotion()) { - return - } + val avgVelocity = distance / timeSinceDown.coerceAtLeast(1L).toFloat() + if (avgVelocity > MAX_SWIPE_VELOCITY_PX_PER_MS) { + return + } - if (isGhostPath(distance, avgVelocity)) { - return - } + if (isPeckLikeMotion()) { + return + } - isSwiping = true - pointCounter = swipePoints.size - cachedTransformPoint.set(start.x, start.y) - _swipeListener?.onSwipeStart(cachedTransformPoint) - updateSwipePath(event) + if (isGhostPath(distance, avgVelocity)) { + return } + + isSwiping = true + pointCounter = ringBuffer.size + + val compatibleLanguages = getCompatibleLanguagesForSwipe(activeLanguages, currentScriptCode) + streamingScoringEngine.startGesture( + keyCharacterPositions, + compatibleLanguages, + streamingScoringEngine.currentLanguageTag, + ) + + cachedTransformPoint.set(firstPointX, firstPointY) + _swipeListener?.onSwipeStart(cachedTransformPoint) } } private fun isPeckLikeMotion(): Boolean { - val pointCount = swipePoints.size + val snapshot = ringBuffer.snapshot() + val pointCount = snapshot.size if (pointCount < 3) return false - val first = swipePoints[0] - val last = swipePoints[pointCount - 1] + val first = snapshot[0] + val last = snapshot[pointCount - 1] val totalDuration = last.timestamp - first.timestamp if (totalDuration <= 0) return false @@ -546,7 +511,7 @@ class SwipeDetector var midPointIndex = 0 var minTimeDiff = Long.MAX_VALUE for (i in 1 until pointCount - 1) { - val diff = kotlin.math.abs(swipePoints[i].timestamp - midTimestamp) + val diff = kotlin.math.abs(snapshot[i].timestamp - midTimestamp) if (diff < minTimeDiff) { minTimeDiff = diff midPointIndex = i @@ -555,7 +520,7 @@ class SwipeDetector if (midPointIndex == 0) return false - val midPoint = swipePoints[midPointIndex] + val midPoint = snapshot[midPointIndex] val earlyDisplacement = calculateDistance(first.x, first.y, midPoint.x, midPoint.y) val lateDisplacement = calculateDistance(midPoint.x, midPoint.y, last.x, last.y) val totalDisplacement = earlyDisplacement + lateDisplacement @@ -575,11 +540,12 @@ class SwipeDetector } private fun hasImpossibleGap(): Boolean { + val snapshot = ringBuffer.snapshot() val threshold = GHOST_IMPOSSIBLE_GAP_PX val thresholdSq = threshold * threshold - for (i in 0 until swipePoints.size - 1) { - val p1 = swipePoints[i] - val p2 = swipePoints[i + 1] + for (i in 0 until snapshot.size - 1) { + val p1 = snapshot[i] + val p2 = snapshot[i + 1] val dx = p2.x - p1.x val dy = p2.y - p1.y val dt = (p2.timestamp - p1.timestamp).coerceAtLeast(1L) @@ -596,16 +562,17 @@ class SwipeDetector ): Boolean { if (avgVelocity < GHOST_DENSITY_VELOCITY_GATE) return false if (totalDistance < 1f) return false - val density = swipePoints.size.toFloat() / totalDistance + val density = ringBuffer.size.toFloat() / totalDistance return density < GHOST_MIN_PATH_DENSITY } private fun isSlideOffStart(avgVelocity: Float): Boolean { - if (swipePoints.size < 3) return false + val snapshot = ringBuffer.snapshot() + if (snapshot.size < 3) return false if (avgVelocity < GHOST_DENSITY_VELOCITY_GATE) return false - val p0 = swipePoints[0] - val p1 = swipePoints[1] + val p0 = snapshot[0] + val p1 = snapshot[1] val dt01 = (p1.timestamp - p0.timestamp).coerceAtLeast(1L).toFloat() val dx01 = p1.x - p0.x val dy01 = p1.y - p0.y @@ -613,10 +580,10 @@ class SwipeDetector if (initialVelocity < GHOST_START_MOMENTUM_VELOCITY) return false - val checkEnd = minOf(swipePoints.size, GHOST_START_INTENT_POINTS) + val checkEnd = minOf(snapshot.size, GHOST_START_INTENT_POINTS) for (i in 2 until checkEnd) { - val prev = swipePoints[i - 1] - val curr = swipePoints[i] + val prev = snapshot[i - 1] + val curr = snapshot[i] val dt = (curr.timestamp - prev.timestamp).coerceAtLeast(1L).toFloat() val dx = curr.x - prev.x val dy = curr.y - prev.y @@ -626,8 +593,8 @@ class SwipeDetector return false } - val prevDx = prev.x - swipePoints[i - 2].x - val prevDy = prev.y - swipePoints[i - 2].y + val prevDx = prev.x - snapshot[i - 2].x + val prevDy = prev.y - snapshot[i - 2].y val dot = prevDx * dx + prevDy * dy val prevLen = sqrt(prevDx * prevDx + prevDy * prevDy) val currLen = sqrt(dx * dx + dy * dy) @@ -643,27 +610,28 @@ class SwipeDetector } private fun shouldSamplePoint( - newPoint: SwipePoint, + newX: Float, + newY: Float, counter: Int, velocity: Float, ): Boolean { - if (swipePoints.size < MIN_SWIPE_POINTS_FOR_SAMPLING) return true + if (ringBuffer.size < MIN_SWIPE_POINTS_FOR_SAMPLING) return true - val lastPoint = swipePoints.lastOrNull() ?: return true + val lastPoint = ringBuffer.peekLast() ?: return true - val dx = newPoint.x - lastPoint.x - val dy = newPoint.y - lastPoint.y + val dx = newX - lastPoint.x + val dy = newY - lastPoint.y val distanceSquared = dx * dx + dy * dy if (distanceSquared < MIN_POINT_DISTANCE * MIN_POINT_DISTANCE) return false val samplingInterval = when { - swipePoints.size < ADAPTIVE_THRESHOLD -> { + ringBuffer.size < ADAPTIVE_THRESHOLD -> { MIN_SAMPLING_INTERVAL } - swipePoints.size < MAX_SWIPE_POINTS * ADAPTIVE_THRESHOLD_RATIO -> { + ringBuffer.size < SwipePointRingBuffer.CAPACITY * ADAPTIVE_THRESHOLD_RATIO -> { MIN_SAMPLING_INTERVAL + 2 } @@ -675,7 +643,7 @@ class SwipeDetector if (counter % samplingInterval != 0) return false - val isSlowPreciseMovement = velocity < SLOW_MOVEMENT_VELOCITY_THRESHOLD && swipePoints.size > 10 + val isSlowPreciseMovement = velocity < SLOW_MOVEMENT_VELOCITY_THRESHOLD && ringBuffer.size > 10 if (isSlowPreciseMovement) { return counter % MIN_SAMPLING_INTERVAL == 0 } @@ -690,7 +658,7 @@ class SwipeDetector val histY = event.getHistoricalY(h) val histTime = event.getHistoricalEventTime(h) val histTransformed = transformTouchCoordinate(histX, histY) - val histLastPoint = swipePoints.lastOrNull() + val histLastPoint = ringBuffer.peekLast() val histVelocity = if (histLastPoint != null) { @@ -702,36 +670,23 @@ class SwipeDetector 0f } - val histPoint = - SwipePoint( - x = histTransformed.x, - y = histTransformed.y, - timestamp = histTime, - pressure = event.getHistoricalPressure(h), - velocity = histVelocity, + if (shouldSamplePoint(histTransformed.x, histTransformed.y, pointCounter, histVelocity)) { + interpolator.onRawPoint( + histTransformed.x, histTransformed.y, histTime, + event.getHistoricalPressure(h), histVelocity, ) - - if (shouldSamplePoint(histPoint, pointCounter, histVelocity)) { - swipePoints.add(histPoint) } } pointCounter++ val transformed = transformTouchCoordinate(event.x, event.y) val velocity = calculateVelocity(event) - val newPoint = - SwipePoint( - x = transformed.x, - y = transformed.y, - timestamp = System.currentTimeMillis(), - pressure = event.pressure, - velocity = velocity, - ) - val shouldAddPoint = shouldSamplePoint(newPoint, pointCounter, velocity) - - if (shouldAddPoint) { - swipePoints.add(newPoint) + if (shouldSamplePoint(transformed.x, transformed.y, pointCounter, velocity)) { + interpolator.onRawPoint( + transformed.x, transformed.y, event.eventTime, + event.pressure, velocity, + ) } val now = System.currentTimeMillis() @@ -747,24 +702,21 @@ class SwipeDetector ): Boolean { if (isSwiping) { val transformed = transformTouchCoordinate(event.x, event.y) - val finalPoint = - SwipePoint( - x = transformed.x, - y = transformed.y, - timestamp = System.currentTimeMillis(), - pressure = event.pressure, - velocity = calculateVelocity(event), - ) - swipePoints.add(finalPoint) + interpolator.onRawPoint( + transformed.x, transformed.y, + event.eventTime, event.pressure, + calculateVelocity(event), + ) - val pathSnapshot = ArrayList(swipePoints) + val pathSnapshot = ringBuffer.snapshot().toList() + val rawPointCount = interpolator.rawPointCount _swipeListener?.onSwipeEnd() scoringJob = scope.launch { try { - val topCandidates = performSpatialScoringAsync(pathSnapshot) + val topCandidates = streamingScoringEngine.finalize(pathSnapshot, rawPointCount) withContext(Dispatchers.Main) { _swipeListener?.onSwipeResults(topCandidates) @@ -793,247 +745,13 @@ class SwipeDetector } } - private suspend fun performSpatialScoringAsync(swipePath: List): List = - withContext(Dispatchers.Default) { - try { - if (swipePath.isEmpty()) return@withContext emptyList() - - val keyPositionsSnapshot = keyCharacterPositions - if (keyPositionsSnapshot.isEmpty()) return@withContext emptyList() - - val interpolatedPath = interpolatePathForFastSegments(swipePath, keyPositionsSnapshot) - - val minLength = 2 - val maxLength = (interpolatedPath.size / 5).coerceIn(5, 20) - - val compatibleLanguages = getCompatibleLanguagesForSwipe(activeLanguages, currentScriptCode) - val wordFrequencyMap = loadOrCacheDictionary(compatibleLanguages, minLength, maxLength) - if (wordFrequencyMap.isEmpty()) return@withContext emptyList() - - val bigramPredictions: Set = - if (lastCommittedWord.isNotBlank()) { - wordFrequencyRepository.getBigramPredictions(lastCommittedWord, currentLanguageTag).toSet() - } else { - emptySet() - } - - updateAdaptiveSigmaCache(keyPositionsSnapshot) - val sigmaCache = cachedAdaptiveSigmas - val neighborhoodCache = cachedKeyNeighborhoods - - val signal = - SwipeSignal.extract( - interpolatedPath, - keyPositionsSnapshot, - pathGeometryAnalyzer, - sigmaCache, - ) - - val dictionaryByFirstChar = buildDictionaryIndex(wordFrequencyMap, minLength, maxLength) - val relevantChars = signal.startAnchor.candidateKeys.ifEmpty { dictionaryByFirstChar.keys } - val dictionarySnapshot = relevantChars.flatMap { dictionaryByFirstChar[it] ?: emptyList() } - - var maxFrequencySeen = 0L - val results = ArrayList(dictionarySnapshot.size / 4) - - for (i in dictionarySnapshot.indices) { - if (i % 50 == 0) yield() - - val entry = dictionarySnapshot[i] - if (entry.rawFrequency > maxFrequencySeen) { - maxFrequencySeen = entry.rawFrequency - } - - val result = - residualScorer.scoreCandidate( - entry, - signal, - keyPositionsSnapshot, - sigmaCache, - neighborhoodCache, - maxFrequencySeen, - ) ?: continue - - results.add(result) - - if (result.combinedScore > EXCELLENT_CANDIDATE_THRESHOLD) { - var excellentCount = 0 - for (candidate in results) { - if (candidate.combinedScore > 0.90f) excellentCount++ - } - if (excellentCount >= MIN_EXCELLENT_CANDIDATES) break - } - } - - val arbitration = - zipfCheck.arbitrate( - results, - signal.geometricAnalysis, - keyPositionsSnapshot, - bigramPredictions, - wordFrequencyMap, - interpolatedPath.size, - ) - return@withContext arbitration.candidates - } catch (_: Exception) { - return@withContext emptyList() - } - } - - private fun interpolatePathForFastSegments( - path: List, - keyPositions: Map, - ): List { - if (path.size < 2) return path - - val result = ArrayList(path.size + 20) - result.add(path[0]) - - for (i in 1 until path.size) { - val prev = path[i - 1] - val curr = path[i] - - val dx = curr.x - prev.x - val dy = curr.y - prev.y - val distance = sqrt(dx * dx + dy * dy) - - val avgVelocity = (prev.velocity + curr.velocity) / 2f - - val shouldInterpolate = - ( - avgVelocity > VELOCITY_INTERPOLATION_THRESHOLD && - distance > INTERPOLATION_MIN_GAP_PX - ) || - distance > LARGE_GAP_INTERPOLATION_THRESHOLD_PX - - if (shouldInterpolate) { - val keysOnSegment = findKeysOnSegment(prev, curr, keyPositions) - - keysOnSegment.take(MAX_INTERPOLATED_POINTS).forEach { (_, _, t) -> - val interpX = prev.x + dx * t - val interpY = prev.y + dy * t - val interpTime = prev.timestamp + ((curr.timestamp - prev.timestamp) * t).toLong() - val interpVelocity = prev.velocity + (curr.velocity - prev.velocity) * t - - result.add( - SwipePoint( - x = interpX, - y = interpY, - timestamp = interpTime, - pressure = (prev.pressure + curr.pressure) / 2f, - velocity = interpVelocity, - ), - ) - } - } - - result.add(curr) - } - - return result - } - - private fun findKeysOnSegment( - p1: SwipePoint, - p2: SwipePoint, - keyPositions: Map, - ): List> { - val keysOnPath = mutableListOf>() - - val dx = p2.x - p1.x - val dy = p2.y - p1.y - val segmentLengthSq = dx * dx + dy * dy - - if (segmentLengthSq < 1f) return keysOnPath - - keyPositions.forEach { (char, keyPos) -> - val t = ((keyPos.x - p1.x) * dx + (keyPos.y - p1.y) * dy) / segmentLengthSq - - if (t in 0.1f..0.9f) { - val projX = p1.x + t * dx - val projY = p1.y + t * dy - val distToKey = sqrt((projX - keyPos.x).let { it * it } + (projY - keyPos.y).let { it * it }) - - if (distToKey < GeometricScoringConstants.KEY_TRAVERSAL_RADIUS) { - keysOnPath.add(Triple(char, keyPos, t)) - } - } - } - - return keysOnPath.sortedBy { it.third } - } - - private suspend fun loadOrCacheDictionary( - compatibleLanguages: List, - minLength: Int, - maxLength: Int, - ): Map = - if (compatibleLanguages == cachedLanguageCombination && - currentScriptCode == cachedScriptCode && - cachedSwipeDictionary.isNotEmpty() - ) { - cachedSwipeDictionary - } else { - val dictionaryWordsMap = spellCheckManager.getCommonWordsForLanguages(compatibleLanguages) - val learnedWordsMap = - wordLearningEngine.getLearnedWordsForSwipeAllLanguages( - compatibleLanguages, - minLength, - maxLength, - ) - - val mergedMap = HashMap(dictionaryWordsMap.size + learnedWordsMap.size) - dictionaryWordsMap.forEach { (word, freq) -> mergedMap[word] = freq } - learnedWordsMap.forEach { (word, freq) -> - mergedMap[word] = maxOf(mergedMap[word] ?: 0, freq) - } - - cachedSwipeDictionary = mergedMap - cachedLanguageCombination = compatibleLanguages - cachedScriptCode = currentScriptCode - - mergedMap - } - - private fun updateAdaptiveSigmaCache(keyPositions: Map) { - val positionsHash = keyPositions.hashCode() - if (positionsHash != lastKeyPositionsHash) { - val newSigmas = mutableMapOf() - keyPositions.keys.forEach { char -> - newSigmas[char] = pathGeometryAnalyzer.calculateAdaptiveSigma(char, keyPositions) - } - cachedAdaptiveSigmas = newSigmas - cachedKeyNeighborhoods = pathGeometryAnalyzer.computeKeyNeighborhoods(keyPositions) - lastKeyPositionsHash = positionsHash - } - } - - private fun buildDictionaryIndex( - wordFrequencyMap: Map, - minLength: Int, - maxLength: Int, - ): Map> { - val filtered = wordFrequencyMap.entries - .filter { (word, _) -> word.length in minLength..maxLength } - .sortedByDescending { it.value } - return filtered.mapIndexed { rank, (word, frequency) -> - DictionaryEntry( - word = word, - frequencyScore = ln(frequency.toFloat() + 1f) / 20f, - rawFrequency = frequency.toLong(), - firstChar = wordNormalizer.stripDiacritics(word.first().toString()).first().lowercaseChar(), - uniqueLetterCount = word.toSet().size, - frequencyTier = FrequencyTier.fromRank(rank), - ) - }.groupBy { it.firstChar } - } private fun calculateVelocity(event: MotionEvent): Float { - if (swipePoints.size < 2) return 0.0f + if (ringBuffer.size < 2) return 0.0f - val lastPoint = swipePoints.lastOrNull() ?: return 0.0f + val lastPoint = ringBuffer.peekLast() ?: return 0.0f val distance = calculateDistance(lastPoint.x, lastPoint.y, event.x, event.y) - val timeDelta = System.currentTimeMillis() - lastPoint.timestamp + val timeDelta = event.eventTime - lastPoint.timestamp return if (timeDelta > 0) distance / timeDelta else 0.0f } @@ -1051,10 +769,12 @@ class SwipeDetector private fun reset() { isSwiping = false - swipePoints.clear() + ringBuffer.reset() + interpolator.reset() startTime = 0L pointCounter = 0 - firstPoint = null + firstPointX = 0f + firstPointY = 0f startingKey = null lastDeltaX = 0f directionReversals = 0 @@ -1068,24 +788,19 @@ class SwipeDetector fun cleanup() { scoringJob?.cancel() scopeJob.cancel() + streamingScoringEngine.cancelActiveGesture() _swipeListener = null keyCharacterPositions = emptyMap() - cachedAdaptiveSigmas = emptyMap() - cachedKeyNeighborhoods = emptyMap() reset() } private companion object { - private const val TAG = "SwipeEngine" - - private const val MAX_SWIPE_POINTS = 500 private const val MIN_SAMPLING_INTERVAL = 2 private const val MAX_SAMPLING_INTERVAL = 8 private const val ADAPTIVE_THRESHOLD = 40 private const val ADAPTIVE_THRESHOLD_RATIO = 0.75 private const val MIN_POINT_DISTANCE = 8f private const val MAX_CONSECUTIVE_GAP_PX = 45f - private const val MIN_EXCELLENT_CANDIDATES = 3 private const val SWIPE_TIME_THRESHOLD_MS = 100L private const val SWIPE_START_DISTANCE_DP = 35f private const val MIN_SWIPE_POINTS_FOR_SAMPLING = 3 @@ -1093,7 +808,6 @@ class SwipeDetector private const val UI_UPDATE_INTERVAL_MS = 16 private const val TAP_DURATION_THRESHOLD_MS = 350L private const val MAX_SWIPE_VELOCITY_PX_PER_MS = 5f - private const val EXCELLENT_CANDIDATE_THRESHOLD = 0.95f private const val PECK_LATE_DISPLACEMENT_RATIO = 0.95f private const val HIGH_VELOCITY_DISTANCE_MULTIPLIER = 1.5f private const val GHOST_DENSITY_VELOCITY_GATE = 2.0f @@ -1102,10 +816,5 @@ class SwipeDetector private const val GHOST_START_INTENT_POINTS = 4 private const val GHOST_START_SLOWDOWN_RATIO = 0.7f private const val GHOST_IMPOSSIBLE_GAP_PX = 200f - - private const val VELOCITY_INTERPOLATION_THRESHOLD = 1.1f - private const val MAX_INTERPOLATED_POINTS = 3 - private const val INTERPOLATION_MIN_GAP_PX = 25f - private const val LARGE_GAP_INTERPOLATION_THRESHOLD_PX = 60f } } diff --git a/app/src/main/java/com/urik/keyboard/ui/keyboard/components/SwipePointRingBuffer.kt b/app/src/main/java/com/urik/keyboard/ui/keyboard/components/SwipePointRingBuffer.kt new file mode 100644 index 00000000..2c9e81bc --- /dev/null +++ b/app/src/main/java/com/urik/keyboard/ui/keyboard/components/SwipePointRingBuffer.kt @@ -0,0 +1,80 @@ +package com.urik.keyboard.ui.keyboard.components + +/** Fixed-capacity ring buffer for swipe touch coordinates. */ +class SwipePointRingBuffer { + + class Slot { + var x: Float = 0f + var y: Float = 0f + var timestamp: Long = 0L + var pressure: Float = 0f + var velocity: Float = 0f + + fun reset() { + x = 0f + y = 0f + timestamp = 0L + pressure = 0f + velocity = 0f + } + + fun toSwipePoint(): SwipeDetector.SwipePoint = + SwipeDetector.SwipePoint( + x = x, + y = y, + timestamp = timestamp, + pressure = pressure, + velocity = velocity, + ) + } + + private val slots = Array(CAPACITY) { Slot() } + private var head = 0 + private var count = 0 + + val size: Int get() = count + + fun write(x: Float, y: Float, timestamp: Long, pressure: Float, velocity: Float) { + val slot = slots[head] + slot.x = x + slot.y = y + slot.timestamp = timestamp + slot.pressure = pressure + slot.velocity = velocity + + head = (head + 1) and MASK + if (count < CAPACITY) count++ + } + + fun peekLast(): SwipeDetector.SwipePoint? { + if (count == 0) return null + val index = (head - 1 + CAPACITY) and MASK + return slots[index].toSwipePoint() + } + + fun snapshot(): List { + if (count == 0) return emptyList() + + val result = ArrayList(count) + val tail = (head - count + CAPACITY) and MASK + for (i in 0 until count) { + val index = (tail + i) and MASK + result.add(slots[index].toSwipePoint()) + } + + return result + } + + fun reset() { + for (slot in slots) { + slot.reset() + } + head = 0 + count = 0 + } + + companion object { + const val CAPACITY = 512 + private const val MASK = CAPACITY - 1 + } +} diff --git a/app/src/main/java/com/urik/keyboard/ui/keyboard/components/SwipeSignal.kt b/app/src/main/java/com/urik/keyboard/ui/keyboard/components/SwipeSignal.kt index d9f39f24..b4127e37 100644 --- a/app/src/main/java/com/urik/keyboard/ui/keyboard/components/SwipeSignal.kt +++ b/app/src/main/java/com/urik/keyboard/ui/keyboard/components/SwipeSignal.kt @@ -5,8 +5,8 @@ package com.urik.keyboard.ui.keyboard.components import android.graphics.PointF import com.urik.keyboard.KeyboardConstants.GeometricScoringConstants import kotlin.math.ln -import kotlin.math.min import kotlin.math.max +import kotlin.math.min import kotlin.math.sqrt /** @@ -32,6 +32,7 @@ class SwipeSignal private constructor( val spatialWeight: Float, val frequencyWeight: Float, val pointZeroDominant: Boolean, + val rawPointCount: Int, ) { data class PathBounds( val minX: Float, @@ -86,6 +87,7 @@ class SwipeSignal private constructor( keyPositions: Map, pathGeometryAnalyzer: PathGeometryAnalyzer, cachedAdaptiveSigmas: Map, + rawPointCount: Int = interpolatedPath.size, ): SwipeSignal { val geometricAnalysis = pathGeometryAnalyzer.analyze(interpolatedPath, keyPositions) @@ -111,9 +113,10 @@ class SwipeSignal private constructor( for ((key, traversal) in geometricAnalysis.traversedKeys) { val lc = key.lowercaseChar() if (traversal.velocityAtKey > PASSTHROUGH_VELOCITY_THRESHOLD) { - val hasIntentionalInflection = geometricAnalysis.inflectionPoints.any { inflection -> - inflection.isIntentional && inflection.nearestKey?.lowercaseChar() == lc - } + val hasIntentionalInflection = + geometricAnalysis.inflectionPoints.any { inflection -> + inflection.isIntentional && inflection.nearestKey?.lowercaseChar() == lc + } if (!hasIntentionalInflection) { passthroughKeys.add(lc) } @@ -125,13 +128,18 @@ class SwipeSignal private constructor( if (inflection.isIntentional) intentionalInflectionCount++ } - val expectedWordLength = calculateExpectedWordLength( - intentionalInflectionCount, interpolatedPath.size, - ) - - val startAnchor = buildStartAnchor( - interpolatedPath, keyPositions, pointZeroDominant, - ) + val expectedWordLength = + calculateExpectedWordLength( + intentionalInflectionCount, + rawPointCount, + ) + + val startAnchor = + buildStartAnchor( + interpolatedPath, + keyPositions, + pointZeroDominant, + ) val endAnchor = buildEndAnchor(interpolatedPath, keyPositions) return SwipeSignal( @@ -151,6 +159,7 @@ class SwipeSignal private constructor( spatialWeight = baselineSpatialWeight, frequencyWeight = baselineFreqWeight, pointZeroDominant = pointZeroDominant, + rawPointCount = rawPointCount, ) } @@ -221,15 +230,16 @@ class SwipeSignal private constructor( intentionalInflectionCount: Int, pathSize: Int, ): Int { - val maxInflectionLength = when { - pathSize < 35 -> 3 - pathSize < 50 -> 4 - pathSize < 70 -> 6 - pathSize < 100 -> 8 - pathSize < 150 -> 12 - pathSize < 200 -> 16 - else -> 20 - } + val maxInflectionLength = + when { + pathSize < 35 -> 3 + pathSize < 50 -> 4 + pathSize < 70 -> 6 + pathSize < 100 -> 8 + pathSize < 150 -> 12 + pathSize < 200 -> 16 + else -> 20 + } val inflectionBasedLength = (intentionalInflectionCount + 2).coerceIn(2, maxInflectionLength) val pathPointBasedLength = (pathSize / 14).coerceIn(2, 20) return maxOf(inflectionBasedLength, pathPointBasedLength) @@ -245,27 +255,29 @@ class SwipeSignal private constructor( val candidateKeys = findCandidateStartKeys(centroid, path, keyPositions, backprojected) val pointZero = path[0] - val keyDistances = candidateKeys.associateWith { char -> - val keyPos = keyPositions[char] ?: return@associateWith Float.MAX_VALUE - val dxC = keyPos.x - centroid.x - val dyC = keyPos.y - centroid.y - val distCentroid = sqrt(dxC * dxC + dyC * dyC) - val dxP = keyPos.x - pointZero.x - val dyP = keyPos.y - pointZero.y - val distPointZero = sqrt(dxP * dxP + dyP * dyP) - val weightedPointZero = if (pointZeroDominant) { - distPointZero * POINT_ZERO_DISTANCE_WEIGHT - } else { - distPointZero - } - var best = minOf(distCentroid, weightedPointZero) - if (backprojected != null) { - val dxB = keyPos.x - backprojected.x - val dyB = keyPos.y - backprojected.y - best = minOf(best, sqrt(dxB * dxB + dyB * dyB)) + val keyDistances = + candidateKeys.associateWith { char -> + val keyPos = keyPositions[char] ?: return@associateWith Float.MAX_VALUE + val dxC = keyPos.x - centroid.x + val dyC = keyPos.y - centroid.y + val distCentroid = sqrt(dxC * dxC + dyC * dyC) + val dxP = keyPos.x - pointZero.x + val dyP = keyPos.y - pointZero.y + val distPointZero = sqrt(dxP * dxP + dyP * dyP) + val weightedPointZero = + if (pointZeroDominant) { + distPointZero * POINT_ZERO_DISTANCE_WEIGHT + } else { + distPointZero + } + var best = minOf(distCentroid, weightedPointZero) + if (backprojected != null) { + val dxB = keyPos.x - backprojected.x + val dyB = keyPos.y - backprojected.y + best = minOf(best, sqrt(dxB * dxB + dyB * dyB)) + } + best } - best - } val closestKey = keyDistances.minByOrNull { it.value }?.key var pointZeroNearest: Char? = null @@ -287,10 +299,12 @@ class SwipeSignal private constructor( } } - val anchorThresholdSq = GeometricScoringConstants.VERTEX_MIN_SEGMENT_LENGTH_PX * - GeometricScoringConstants.VERTEX_MIN_SEGMENT_LENGTH_PX - val isAmbiguous = pointZeroSecond != null && - (pointZeroSecondDistSq - pointZeroMinDistSq) < anchorThresholdSq + val anchorThresholdSq = + GeometricScoringConstants.VERTEX_MIN_SEGMENT_LENGTH_PX * + GeometricScoringConstants.VERTEX_MIN_SEGMENT_LENGTH_PX + val isAmbiguous = + pointZeroSecond != null && + (pointZeroSecondDistSq - pointZeroMinDistSq) < anchorThresholdSq val isAnchorLocked = pointZeroMinDistSq < anchorThresholdSq return StartAnchor( @@ -308,21 +322,23 @@ class SwipeSignal private constructor( private fun computeStartCentroid(path: List): PointF { if (path.isEmpty()) return PointF(0f, 0f) - val startVelocity = if (path.size >= 2) { - val p0 = path[0] - val p1 = path[1] - val dt = (p1.timestamp - p0.timestamp).coerceAtLeast(1L).toFloat() - val dx = p1.x - p0.x - val dy = p1.y - p0.y - sqrt(dx * dx + dy * dy) / dt - } else { - 0f - } - val sampleCount = if (startVelocity > HIGH_VELOCITY_START_THRESHOLD) { - START_CENTROID_POINTS_FAST - } else { - START_CENTROID_POINTS_NORMAL - } + val startVelocity = + if (path.size >= 2) { + val p0 = path[0] + val p1 = path[1] + val dt = (p1.timestamp - p0.timestamp).coerceAtLeast(1L).toFloat() + val dx = p1.x - p0.x + val dy = p1.y - p0.y + sqrt(dx * dx + dy * dy) / dt + } else { + 0f + } + val sampleCount = + if (startVelocity > HIGH_VELOCITY_START_THRESHOLD) { + START_CENTROID_POINTS_FAST + } else { + START_CENTROID_POINTS_NORMAL + } val n = minOf(sampleCount, path.size) var sumX = 0f var sumY = 0f @@ -353,11 +369,12 @@ class SwipeSignal private constructor( val normX = vecX / vecLen val normY = vecY / vecLen - val projectionDist = minOf( - BACKPROJECTION_BASE_PX + - BACKPROJECTION_LOG_SCALE * ln(startVelocity), - BACKPROJECTION_MAX_PX, - ) + val projectionDist = + minOf( + BACKPROJECTION_BASE_PX + + BACKPROJECTION_LOG_SCALE * ln(startVelocity), + BACKPROJECTION_MAX_PX, + ) return PointF(p0.x - normX * projectionDist, p0.y - normY * projectionDist) } @@ -369,65 +386,74 @@ class SwipeSignal private constructor( ): Set { if (path.isEmpty()) return emptySet() - val startVelocity = if (path.size >= 2) { - val p0 = path[0] - val p1 = path[1] - val dt = (p1.timestamp - p0.timestamp).coerceAtLeast(1L).toFloat() - val dx = p1.x - p0.x - val dy = p1.y - p0.y - sqrt(dx * dx + dy * dy) / dt - } else { - 0f - } + val startVelocity = + if (path.size >= 2) { + val p0 = path[0] + val p1 = path[1] + val dt = (p1.timestamp - p0.timestamp).coerceAtLeast(1L).toFloat() + val dx = p1.x - p0.x + val dy = p1.y - p0.y + sqrt(dx * dx + dy * dy) / dt + } else { + 0f + } val baseThresholdSq = CLOSE_KEY_DISTANCE_THRESHOLD_SQ - val effectiveThresholdSq = when { - startVelocity > EXTREME_VELOCITY_START_THRESHOLD -> { - val m = EXTREME_VELOCITY_RADIUS_MULTIPLIER - baseThresholdSq * m * m - } - startVelocity > HIGH_VELOCITY_START_THRESHOLD -> { - val m = VELOCITY_EXPANDED_RADIUS_MULTIPLIER - baseThresholdSq * m * m + val effectiveThresholdSq = + when { + startVelocity > EXTREME_VELOCITY_START_THRESHOLD -> { + val m = EXTREME_VELOCITY_RADIUS_MULTIPLIER + baseThresholdSq * m * m + } + + startVelocity > HIGH_VELOCITY_START_THRESHOLD -> { + val m = VELOCITY_EXPANDED_RADIUS_MULTIPLIER + baseThresholdSq * m * m + } + + else -> { + baseThresholdSq + } } - else -> baseThresholdSq - } - val centroidKeys = keyPositions.entries - .map { (char, pos) -> - val dx = pos.x - centroid.x - val dy = pos.y - centroid.y - char to (dx * dx + dy * dy) - }.sortedBy { it.second } - .take(8) - .filter { it.second < effectiveThresholdSq } - .map { it.first } - .toSet() + val centroidKeys = + keyPositions.entries + .map { (char, pos) -> + val dx = pos.x - centroid.x + val dy = pos.y - centroid.y + char to (dx * dx + dy * dy) + }.sortedBy { it.second } + .take(8) + .filter { it.second < effectiveThresholdSq } + .map { it.first } + .toSet() val firstPoint = path[0] - val pointZeroKeys = keyPositions.entries - .map { (char, pos) -> - val dx = pos.x - firstPoint.x - val dy = pos.y - firstPoint.y - char to (dx * dx + dy * dy) - }.sortedBy { it.second } - .take(POINT_ZERO_PROXIMITY_COUNT) - .map { it.first } - .toSet() - - val backprojKeys = if (backprojectedStart != null) { + val pointZeroKeys = keyPositions.entries .map { (char, pos) -> - val dx = pos.x - backprojectedStart.x - val dy = pos.y - backprojectedStart.y + val dx = pos.x - firstPoint.x + val dy = pos.y - firstPoint.y char to (dx * dx + dy * dy) }.sortedBy { it.second } .take(POINT_ZERO_PROXIMITY_COUNT) .map { it.first } .toSet() - } else { - emptySet() - } + + val backprojKeys = + if (backprojectedStart != null) { + keyPositions.entries + .map { (char, pos) -> + val dx = pos.x - backprojectedStart.x + val dy = pos.y - backprojectedStart.y + char to (dx * dx + dy * dy) + }.sortedBy { it.second } + .take(POINT_ZERO_PROXIMITY_COUNT) + .map { it.first } + .toSet() + } else { + emptySet() + } return centroidKeys + pointZeroKeys + backprojKeys } @@ -447,11 +473,12 @@ class SwipeSignal private constructor( endCentroidY /= endN val centroid = PointF(endCentroidX, endCentroidY) - val keyDistances = keyPositions.mapValues { (_, keyPos) -> - val dx = keyPos.x - endCentroidX - val dy = keyPos.y - endCentroidY - sqrt(dx * dx + dy * dy) - } + val keyDistances = + keyPositions.mapValues { (_, keyPos) -> + val dx = keyPos.x - endCentroidX + val dy = keyPos.y - endCentroidY + sqrt(dx * dx + dy * dy) + } val closestKey = keyDistances.minByOrNull { it.value }?.key return EndAnchor( diff --git a/app/src/test/java/com/urik/keyboard/integration/SwipeInputIntegrationTest.kt b/app/src/test/java/com/urik/keyboard/integration/SwipeInputIntegrationTest.kt index 8232560f..3c8c8784 100644 --- a/app/src/test/java/com/urik/keyboard/integration/SwipeInputIntegrationTest.kt +++ b/app/src/test/java/com/urik/keyboard/integration/SwipeInputIntegrationTest.kt @@ -18,6 +18,7 @@ import com.urik.keyboard.settings.KeyboardSettings import com.urik.keyboard.settings.SettingsRepository import com.urik.keyboard.ui.keyboard.components.PathGeometryAnalyzer import com.urik.keyboard.ui.keyboard.components.ResidualScorer +import com.urik.keyboard.ui.keyboard.components.StreamingScoringEngine import com.urik.keyboard.ui.keyboard.components.SwipeDetector import com.urik.keyboard.ui.keyboard.components.ZipfCheck import com.urik.keyboard.utils.CacheMemoryManager @@ -155,16 +156,16 @@ class SwipeInputIntegrationTest { val pathGeometryAnalyzer = PathGeometryAnalyzer() val residualScorer = ResidualScorer(pathGeometryAnalyzer) val zipfCheck = ZipfCheck(spellCheckManager) - swipeDetector = - SwipeDetector( - spellCheckManager, - wordLearningEngine, - pathGeometryAnalyzer, - wordFrequencyRepository, - residualScorer, - zipfCheck, - wordNormalizer, - ) + val streamingScoringEngine = StreamingScoringEngine( + spellCheckManager, + wordLearningEngine, + pathGeometryAnalyzer, + wordFrequencyRepository, + residualScorer, + zipfCheck, + wordNormalizer, + ) + swipeDetector = SwipeDetector(streamingScoringEngine) } @After diff --git a/app/src/test/java/com/urik/keyboard/ui/keyboard/components/GestureInterpolatorTest.kt b/app/src/test/java/com/urik/keyboard/ui/keyboard/components/GestureInterpolatorTest.kt new file mode 100644 index 00000000..1d996c4b --- /dev/null +++ b/app/src/test/java/com/urik/keyboard/ui/keyboard/components/GestureInterpolatorTest.kt @@ -0,0 +1,207 @@ +package com.urik.keyboard.ui.keyboard.components + +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Before +import org.junit.Test +import kotlin.math.sqrt + +class GestureInterpolatorTest { + + private lateinit var interpolator: GestureInterpolator + private lateinit var ringBuffer: SwipePointRingBuffer + + @Before + fun setup() { + ringBuffer = SwipePointRingBuffer() + interpolator = GestureInterpolator(ringBuffer) + } + + @Test + fun `first point passes through directly`() { + interpolator.onRawPoint(100f, 200f, 1000L, 1f, 0f) + assertEquals(1, ringBuffer.size) + } + + @Test + fun `second point passes through directly`() { + interpolator.onRawPoint(100f, 200f, 1000L, 1f, 0f) + interpolator.onRawPoint(110f, 200f, 1010L, 1f, 1f) + assertEquals(2, ringBuffer.size) + } + + @Test + fun `third point passes through directly`() { + interpolator.onRawPoint(100f, 200f, 1000L, 1f, 0f) + interpolator.onRawPoint(110f, 200f, 1010L, 1f, 1f) + interpolator.onRawPoint(120f, 200f, 1020L, 1f, 1f) + assertEquals(3, ringBuffer.size) + } + + @Test + fun `fourth point triggers spline interpolation with intermediate points`() { + interpolator.onRawPoint(0f, 0f, 0L, 1f, 0f) + interpolator.onRawPoint(50f, 0f, 10L, 1f, 5f) + interpolator.onRawPoint(100f, 50f, 20L, 1f, 5f) + interpolator.onRawPoint(150f, 50f, 30L, 1f, 5f) + + assertTrue( + "Spline should produce more points than raw input", + ringBuffer.size > 4, + ) + } + + @Test + fun `slow movement below 6px gap does not interpolate`() { + interpolator.onRawPoint(100f, 200f, 1000L, 1f, 0f) + interpolator.onRawPoint(102f, 200f, 1010L, 1f, 0.2f) + interpolator.onRawPoint(104f, 200f, 1020L, 1f, 0.2f) + interpolator.onRawPoint(106f, 200f, 1030L, 1f, 0.2f) + + assertEquals( + "Close points should pass through without interpolation", + 4, + ringBuffer.size, + ) + } + + @Test + fun `fast movement with 60px plus gap caps at 10 interpolated points`() { + interpolator.onRawPoint(0f, 0f, 0L, 1f, 0f) + interpolator.onRawPoint(100f, 0f, 5L, 1f, 20f) + interpolator.onRawPoint(200f, 100f, 10L, 1f, 20f) + interpolator.onRawPoint(400f, 100f, 15L, 1f, 40f) + + assertTrue( + "Large gap should interpolate but cap at max density", + ringBuffer.size <= 4 + 10 + 10, + ) + } + + @Test + fun `interpolated points lie between control points spatially`() { + interpolator.onRawPoint(0f, 0f, 0L, 1f, 0f) + interpolator.onRawPoint(100f, 0f, 10L, 1f, 10f) + interpolator.onRawPoint(200f, 100f, 20L, 1f, 10f) + interpolator.onRawPoint(300f, 100f, 30L, 1f, 10f) + + val snapshot = ringBuffer.snapshot() + + for (point in snapshot) { + assertTrue("X should be in range", point.x >= -50f && point.x <= 350f) + assertTrue("Y should be in range", point.y >= -50f && point.y <= 150f) + } + } + + @Test + fun `timestamps are monotonically increasing`() { + interpolator.onRawPoint(0f, 0f, 100L, 1f, 0f) + interpolator.onRawPoint(50f, 10f, 200L, 1f, 0.5f) + interpolator.onRawPoint(100f, 20f, 300L, 1f, 0.5f) + interpolator.onRawPoint(200f, 30f, 400L, 1f, 1f) + + val snapshot = ringBuffer.snapshot() + for (i in 1 until snapshot.size) { + assertTrue( + "Timestamps must be monotonically increasing", + snapshot[i].timestamp >= snapshot[i - 1].timestamp, + ) + } + } + + @Test + fun `reset clears sliding window`() { + interpolator.onRawPoint(0f, 0f, 0L, 1f, 0f) + interpolator.onRawPoint(50f, 0f, 10L, 1f, 5f) + interpolator.onRawPoint(100f, 0f, 20L, 1f, 5f) + + interpolator.reset() + ringBuffer.reset() + + interpolator.onRawPoint(200f, 200f, 100L, 1f, 0f) + val snapshot = ringBuffer.snapshot() + + assertEquals(1, snapshot.size) + assertEquals(200f, snapshot[0].x, 0.001f) + } + + @Test + fun `sharp 90 degree turn does not produce spline overshoot`() { + interpolator.onRawPoint(0f, 100f, 0L, 1f, 0f) + interpolator.onRawPoint(50f, 100f, 10L, 1f, 5f) + interpolator.onRawPoint(100f, 100f, 20L, 1f, 5f) + interpolator.onRawPoint(100f, 50f, 30L, 1f, 5f) + + val snapshot = ringBuffer.snapshot() + + val marginPx = 80f + for ((i, point) in snapshot.withIndex()) { + assertTrue( + "Point $i overshoots X: ${point.x}", + point.x >= -marginPx && point.x <= 100f + marginPx, + ) + assertTrue( + "Point $i overshoots Y: ${point.y}", + point.y >= 50f - marginPx && point.y <= 100f + marginPx, + ) + } + } + + @Test + fun `U-shaped reversal does not create loop artifacts`() { + interpolator.onRawPoint(0f, 0f, 0L, 1f, 0f) + interpolator.onRawPoint(80f, 0f, 10L, 1f, 8f) + interpolator.onRawPoint(80f, 80f, 20L, 1f, 8f) + interpolator.onRawPoint(0f, 80f, 30L, 1f, 8f) + + val snapshot = ringBuffer.snapshot() + + var hasBacktrack = false + for (i in 1 until snapshot.size) { + val dx = snapshot[i].x - snapshot[i - 1].x + val dy = snapshot[i].y - snapshot[i - 1].y + val dist = sqrt(dx * dx + dy * dy) + if (dist > 100f) { + hasBacktrack = true + } + } + + assertFalse( + "U-turn should not produce large inter-point jumps (loop artifact)", + hasBacktrack, + ) + } + + @Test + fun `continuous multi-segment gesture produces monotonic X progression for straight swipe`() { + interpolator.onRawPoint(0f, 100f, 0L, 1f, 0f) + interpolator.onRawPoint(40f, 100f, 10L, 1f, 4f) + interpolator.onRawPoint(80f, 100f, 20L, 1f, 4f) + interpolator.onRawPoint(120f, 100f, 30L, 1f, 4f) + interpolator.onRawPoint(160f, 100f, 40L, 1f, 4f) + interpolator.onRawPoint(200f, 100f, 50L, 1f, 4f) + + val snapshot = ringBuffer.snapshot() + + for (i in 1 until snapshot.size) { + assertTrue( + "X should be monotonically increasing for a rightward swipe, but point $i: ${snapshot[i].x} < ${snapshot[i - 1].x}", + snapshot[i].x >= snapshot[i - 1].x - 1f, + ) + } + } + + @Test + fun `high velocity segment does not exceed 10 interpolated points per segment`() { + interpolator.onRawPoint(0f, 0f, 0L, 1f, 0f) + interpolator.onRawPoint(20f, 0f, 5L, 1f, 4f) + interpolator.onRawPoint(40f, 0f, 10L, 1f, 4f) + interpolator.onRawPoint(300f, 0f, 15L, 1f, 52f) + + assertTrue( + "Even with 260px gap, should not exceed 4 raw + 10 + 10 interpolated", + ringBuffer.size <= 24, + ) + } +} diff --git a/app/src/test/java/com/urik/keyboard/ui/keyboard/components/StreamingScoringEngineTest.kt b/app/src/test/java/com/urik/keyboard/ui/keyboard/components/StreamingScoringEngineTest.kt new file mode 100644 index 00000000..91a044c9 --- /dev/null +++ b/app/src/test/java/com/urik/keyboard/ui/keyboard/components/StreamingScoringEngineTest.kt @@ -0,0 +1,309 @@ +package com.urik.keyboard.ui.keyboard.components + +import android.graphics.PointF +import com.urik.keyboard.data.WordFrequencyRepository +import com.urik.keyboard.service.SpellCheckManager +import com.urik.keyboard.service.WordLearningEngine +import com.urik.keyboard.service.WordNormalizer +import kotlinx.coroutines.ExperimentalCoroutinesApi +import kotlinx.coroutines.test.runTest +import org.junit.After +import org.junit.Assert.assertEquals +import org.junit.Assert.assertTrue +import org.junit.Before +import org.junit.Test +import org.junit.runner.RunWith +import org.mockito.Mock +import org.mockito.MockitoAnnotations +import org.robolectric.RobolectricTestRunner + +@OptIn(ExperimentalCoroutinesApi::class) +@RunWith(RobolectricTestRunner::class) +class StreamingScoringEngineTest { + + @Mock private lateinit var spellCheckManager: SpellCheckManager + @Mock private lateinit var wordLearningEngine: WordLearningEngine + @Mock private lateinit var pathGeometryAnalyzer: PathGeometryAnalyzer + @Mock private lateinit var wordFrequencyRepository: WordFrequencyRepository + @Mock private lateinit var residualScorer: ResidualScorer + @Mock private lateinit var zipfCheck: ZipfCheck + @Mock private lateinit var wordNormalizer: WordNormalizer + + private lateinit var engine: StreamingScoringEngine + private lateinit var closeable: AutoCloseable + + private val qwertyKeyPositions = mapOf( + 'q' to PointF(30f, 50f), 'w' to PointF(80f, 50f), 'e' to PointF(130f, 50f), + 'r' to PointF(180f, 50f), 't' to PointF(230f, 50f), 'y' to PointF(280f, 50f), + 'u' to PointF(330f, 50f), 'i' to PointF(380f, 50f), 'o' to PointF(430f, 50f), + 'p' to PointF(480f, 50f), + 'a' to PointF(40f, 130f), 's' to PointF(90f, 130f), 'd' to PointF(140f, 130f), + 'f' to PointF(190f, 130f), 'g' to PointF(240f, 130f), 'h' to PointF(290f, 130f), + 'j' to PointF(340f, 130f), 'k' to PointF(390f, 130f), 'l' to PointF(440f, 130f), + 'z' to PointF(90f, 210f), 'x' to PointF(140f, 210f), 'c' to PointF(190f, 210f), + 'v' to PointF(240f, 210f), 'b' to PointF(290f, 210f), 'n' to PointF(340f, 210f), + 'm' to PointF(390f, 210f), + ) + + @Before + fun setup() { + closeable = MockitoAnnotations.openMocks(this) + engine = StreamingScoringEngine( + spellCheckManager = spellCheckManager, + wordLearningEngine = wordLearningEngine, + pathGeometryAnalyzer = pathGeometryAnalyzer, + wordFrequencyRepository = wordFrequencyRepository, + residualScorer = residualScorer, + zipfCheck = zipfCheck, + wordNormalizer = wordNormalizer, + ) + } + + @After + fun teardown() { + engine.shutdown() + closeable.close() + } + + @Test + fun `cancelActiveGesture clears live candidate set`() = runTest { + engine.startGesture(qwertyKeyPositions, listOf("en"), "en") + engine.cancelActiveGesture() + + val results = engine.finalize(emptyList(), 0) + assertTrue("Cancelled gesture should produce no results", results.isEmpty()) + } + + @Test + fun `startGesture resets state from previous gesture`() = runTest { + engine.startGesture(qwertyKeyPositions, listOf("en"), "en") + engine.cancelActiveGesture() + + engine.startGesture(qwertyKeyPositions, listOf("en"), "en") + engine.cancelActiveGesture() + + val results = engine.finalize(emptyList(), 0) + assertTrue(results.isEmpty()) + } + + @Test + fun `pruneByStartAnchor removes candidates with wrong first letter`() { + val candidates = listOf( + makeDictionaryEntry("hello", 'h', 1000), + makeDictionaryEntry("world", 'w', 900), + makeDictionaryEntry("zebra", 'z', 800), + ) + + val startKeys = setOf('h', 'j') + val pruned = engine.pruneByStartAnchor(candidates, startKeys) + + assertEquals(1, pruned.size) + assertEquals("hello", pruned[0].word) + } + + @Test + fun `pruneByBounds removes candidates requiring keys outside bounds`() { + val candidates = listOf( + makeDictionaryEntry("hello", 'h', 1000), + makeDictionaryEntry("zebra", 'z', 800), + ) + + val charsInBounds = setOf('h', 'e', 'l', 'o', 'w', 'r', 't') + val pruned = engine.pruneByBounds(candidates, charsInBounds) + + assertEquals(1, pruned.size) + assertEquals("hello", pruned[0].word) + } + + @Test + fun `pruneByBounds applies fingertip safety margin`() { + val candidates = listOf( + makeDictionaryEntry("hello", 'h', 1000), + ) + + val charsInBounds = setOf('h', 'e', 'l', 'o') + val pruned = engine.pruneByBounds(candidates, charsInBounds) + + assertEquals( + "All letters within bounds should survive", + 1, + pruned.size, + ) + } + + @Test + fun `pruneByBounds allows one out-of-bounds letter as safety margin`() { + val candidates = listOf( + makeDictionaryEntry("hello", 'h', 1000), + ) + + val charsInBounds = setOf('h', 'e', 'l') + val pruned = engine.pruneByBounds(candidates, charsInBounds) + + assertEquals( + "One out-of-bounds letter should be tolerated", + 1, + pruned.size, + ) + } + + @Test + fun `monotonic pruning does not re-add eliminated candidates`() { + val full = listOf( + makeDictionaryEntry("hello", 'h', 1000), + makeDictionaryEntry("world", 'w', 900), + makeDictionaryEntry("zebra", 'z', 800), + ) + + val afterStart = engine.pruneByStartAnchor(full, setOf('h', 'w')) + assertEquals(2, afterStart.size) + + val afterBounds = engine.pruneByBounds(afterStart, setOf('h', 'e', 'l', 'o')) + assertEquals(1, afterBounds.size) + assertEquals("hello", afterBounds[0].word) + } + + @Test + fun `traversal pruning preserves hello with 30 percent overlap threshold`() { + val candidates = listOf( + makeDictionaryEntry("hello", 'h', 1000), + makeDictionaryEntry("world", 'w', 900), + makeDictionaryEntry("help", 'h', 800), + ) + + val traversedKeys = setOf('h', 'e', 'l', 'o', 'w', 'r') + val pruned = engine.pruneByTraversal(candidates, traversedKeys) + + val prunedWords = pruned.map { it.word }.toSet() + assertTrue("'hello' should survive traversal (100% overlap)", "hello" in prunedWords) + assertTrue("'world' should survive traversal (75% overlap)", "world" in prunedWords) + assertTrue("'help' should survive traversal (100% overlap)", "help" in prunedWords) + } + + @Test + fun `traversal pruning rejects word with low overlap`() { + val candidates = listOf( + makeDictionaryEntry("hello", 'h', 1000), + makeDictionaryEntry("pizza", 'p', 500), + ) + + val traversedKeys = setOf('h', 'e', 'l', 'o') + val pruned = engine.pruneByTraversal(candidates, traversedKeys) + + assertEquals("Only hello should survive", 1, pruned.size) + assertEquals("hello", pruned[0].word) + } + + @Test + fun `bounds pruning tolerates one out-of-bounds char for long words`() { + val candidates = listOf( + makeDictionaryEntry("together", 't', 5000), + ) + + val charsInBounds = setOf('t', 'o', 'g', 'e', 'h', 'r') + val pruned = engine.pruneByBounds(candidates, charsInBounds) + + assertEquals( + "together should survive with 1 out-of-bounds char", + 1, + pruned.size, + ) + } + + @Test + fun `bounds pruning rejects word with 2 plus out-of-bounds chars`() { + val candidates = listOf( + makeDictionaryEntry("together", 't', 5000), + ) + + val charsInBounds = setOf('t', 'o', 'g') + val pruned = engine.pruneByBounds(candidates, charsInBounds) + + assertEquals( + "together should be rejected with many out-of-bounds chars", + 0, + pruned.size, + ) + } + + @Test + fun `start anchor pruning preserves words matching any start key`() { + val candidates = listOf( + makeDictionaryEntry("another", 'a', 2000), + makeDictionaryEntry("seven", 's', 1500), + makeDictionaryEntry("together", 't', 1000), + ) + + val startKeys = setOf('a', 's') + val pruned = engine.pruneByStartAnchor(candidates, startKeys) + + assertEquals(2, pruned.size) + val prunedWords = pruned.map { it.word }.toSet() + assertTrue("another" in prunedWords) + assertTrue("seven" in prunedWords) + } + + @Test + fun `cascaded pruning does not eliminate common words prematurely`() { + val commonWords = listOf( + makeDictionaryEntry("the", 't', 100_000_000), + makeDictionaryEntry("hello", 'h', 50_000_000), + makeDictionaryEntry("world", 'w', 30_000_000), + makeDictionaryEntry("picture", 'p', 10_000_000), + makeDictionaryEntry("together", 't', 8_000_000), + makeDictionaryEntry("another", 'a', 5_000_000), + makeDictionaryEntry("proctor", 'p', 100), + makeDictionaryEntry("zebra", 'z', 50), + ) + + val startKeys = setOf('h', 'p', 't', 'w', 'a') + val afterAnchor = engine.pruneByStartAnchor(commonWords, startKeys) + assertTrue( + "Anchor prune should keep most common words", + afterAnchor.size >= 6, + ) + + val charsInBounds = setOf('h', 'e', 'l', 'o', 'p', 'i', 'c', 't', 'u', 'r') + val afterBounds = engine.pruneByBounds(afterAnchor, charsInBounds) + val survivingWords = afterBounds.map { it.word }.toSet() + + assertTrue( + "hello should survive full cascade", + "hello" in survivingWords, + ) + assertTrue( + "picture should survive full cascade", + "picture" in survivingWords, + ) + } + + @Test + fun `traversal with sparse traversed keys does not prune`() { + val candidates = listOf( + makeDictionaryEntry("hello", 'h', 1000), + makeDictionaryEntry("world", 'w', 900), + ) + + val sparseKeys = setOf('h') + val pruned = engine.pruneByTraversal(candidates, sparseKeys) + + assertEquals( + "Sparse traversed keys (size < 2) should skip pruning entirely", + 2, + pruned.size, + ) + } + + private fun makeDictionaryEntry( + word: String, + firstChar: Char, + frequency: Long, + ): SwipeDetector.DictionaryEntry = + SwipeDetector.DictionaryEntry( + word = word, + frequencyScore = 0.5f, + rawFrequency = frequency, + firstChar = firstChar, + uniqueLetterCount = word.toSet().size, + ) +} diff --git a/app/src/test/java/com/urik/keyboard/ui/keyboard/components/SwipeDetectorTest.kt b/app/src/test/java/com/urik/keyboard/ui/keyboard/components/SwipeDetectorTest.kt index ab67c25c..e1d6adef 100644 --- a/app/src/test/java/com/urik/keyboard/ui/keyboard/components/SwipeDetectorTest.kt +++ b/app/src/test/java/com/urik/keyboard/ui/keyboard/components/SwipeDetectorTest.kt @@ -2,11 +2,7 @@ package com.urik.keyboard.ui.keyboard.components import android.graphics.PointF import android.view.MotionEvent -import com.urik.keyboard.data.WordFrequencyRepository import com.urik.keyboard.model.KeyboardKey -import com.urik.keyboard.service.SpellCheckManager -import com.urik.keyboard.service.WordLearningEngine -import com.urik.keyboard.service.WordNormalizer import kotlinx.coroutines.ExperimentalCoroutinesApi import org.junit.After import org.junit.Assert.assertFalse @@ -26,25 +22,7 @@ import org.robolectric.RobolectricTestRunner @RunWith(RobolectricTestRunner::class) class SwipeDetectorTest { @Mock - private lateinit var spellCheckManager: SpellCheckManager - - @Mock - private lateinit var wordLearningEngine: WordLearningEngine - - @Mock - private lateinit var pathGeometryAnalyzer: PathGeometryAnalyzer - - @Mock - private lateinit var wordNormalizer: WordNormalizer - - @Mock - private lateinit var wordFrequencyRepository: WordFrequencyRepository - - @Mock - private lateinit var residualScorer: ResidualScorer - - @Mock - private lateinit var zipfCheck: ZipfCheck + private lateinit var streamingScoringEngine: StreamingScoringEngine @Mock private lateinit var swipeListener: SwipeDetector.SwipeListener @@ -55,16 +33,7 @@ class SwipeDetectorTest { @Before fun setup() { closeable = MockitoAnnotations.openMocks(this) - swipeDetector = - SwipeDetector( - spellCheckManager, - wordLearningEngine, - pathGeometryAnalyzer, - wordFrequencyRepository, - residualScorer, - zipfCheck, - wordNormalizer, - ) + swipeDetector = SwipeDetector(streamingScoringEngine) swipeDetector.setSwipeListener(swipeListener) } diff --git a/app/src/test/java/com/urik/keyboard/ui/keyboard/components/SwipePointRingBufferTest.kt b/app/src/test/java/com/urik/keyboard/ui/keyboard/components/SwipePointRingBufferTest.kt new file mode 100644 index 00000000..133be353 --- /dev/null +++ b/app/src/test/java/com/urik/keyboard/ui/keyboard/components/SwipePointRingBufferTest.kt @@ -0,0 +1,165 @@ +package com.urik.keyboard.ui.keyboard.components + +import org.junit.Assert.assertEquals +import org.junit.Assert.assertFalse +import org.junit.Assert.assertTrue +import org.junit.Before +import org.junit.Test + +class SwipePointRingBufferTest { + + private lateinit var buffer: SwipePointRingBuffer + + @Before + fun setup() { + buffer = SwipePointRingBuffer() + } + + @Test + fun `new buffer has zero size`() { + assertEquals(0, buffer.size) + } + + @Test + fun `write and read single point`() { + buffer.write(10f, 20f, 1000L, 0.8f, 1.5f) + assertEquals(1, buffer.size) + + val snapshot = buffer.snapshot() + assertEquals(1, snapshot.size) + assertEquals(10f, snapshot[0].x, 0.001f) + assertEquals(20f, snapshot[0].y, 0.001f) + assertEquals(1000L, snapshot[0].timestamp) + assertEquals(0.8f, snapshot[0].pressure, 0.001f) + assertEquals(1.5f, snapshot[0].velocity, 0.001f) + } + + @Test + fun `write fills to capacity without crash`() { + for (i in 0 until 512) { + buffer.write(i.toFloat(), i.toFloat(), i.toLong(), 1f, 0f) + } + assertEquals(512, buffer.size) + } + + @Test + fun `write beyond capacity overwrites oldest`() { + for (i in 0 until 513) { + buffer.write(i.toFloat(), i.toFloat(), i.toLong(), 1f, 0f) + } + assertEquals(512, buffer.size) + + val snapshot = buffer.snapshot() + assertEquals(1f, snapshot[0].x, 0.001f) + assertEquals(512f, snapshot[511].x, 0.001f) + } + + @Test + fun `snapshot returns points in chronological order`() { + for (i in 0 until 600) { + buffer.write(i.toFloat(), 0f, i.toLong(), 1f, 0f) + } + + val snapshot = buffer.snapshot() + for (i in 1 until snapshot.size) { + assertTrue( + "Points must be chronological", + snapshot[i].timestamp > snapshot[i - 1].timestamp, + ) + } + } + + @Test + fun `reset clears size and zeroes all slots`() { + for (i in 0 until 100) { + buffer.write(i.toFloat(), i.toFloat(), i.toLong(), 1f, 1f) + } + + buffer.reset() + assertEquals(0, buffer.size) + + val snapshot = buffer.snapshot() + assertTrue(snapshot.isEmpty()) + } + + @Test + fun `reset prevents data contamination between gestures`() { + buffer.write(99f, 99f, 99L, 0.5f, 3f) + buffer.reset() + + buffer.write(1f, 2f, 100L, 1f, 0f) + val snapshot = buffer.snapshot() + + assertEquals(1, snapshot.size) + assertEquals(1f, snapshot[0].x, 0.001f) + assertEquals(2f, snapshot[0].y, 0.001f) + } + + @Test + fun `snapshot returns independent copy each call`() { + buffer.write(1f, 2f, 100L, 1f, 0f) + val first = buffer.snapshot() + buffer.write(3f, 4f, 200L, 1f, 0f) + val second = buffer.snapshot() + assertFalse("Snapshots must be independent to avoid ConcurrentModificationException", first === second) + assertEquals("First snapshot unchanged after second write", 1, first.size) + assertEquals("Second snapshot includes both writes", 2, second.size) + } + + @Test + fun `peek returns last written point`() { + buffer.write(10f, 20f, 1000L, 0.8f, 1.5f) + buffer.write(30f, 40f, 2000L, 0.9f, 2.0f) + + val last = buffer.peekLast() + assertEquals(30f, last!!.x, 0.001f) + assertEquals(40f, last.y, 0.001f) + } + + @Test + fun `peek on empty buffer returns null`() { + assertEquals(null, buffer.peekLast()) + } + + @Test + fun `snapshot after wrap-around preserves data integrity`() { + for (i in 0 until 600) { + buffer.write( + i.toFloat(), + (i * 2).toFloat(), + (i * 10).toLong(), + 0.5f + (i % 10) * 0.05f, + i.toFloat() * 0.1f, + ) + } + + val snapshot = buffer.snapshot() + assertEquals(512, snapshot.size) + + assertEquals(88f, snapshot[0].x, 0.001f) + assertEquals(176f, snapshot[0].y, 0.001f) + assertEquals(880L, snapshot[0].timestamp) + + assertEquals(599f, snapshot[511].x, 0.001f) + } + + @Test + fun `multiple wrap-arounds maintain correct ordering`() { + for (i in 0 until 1500) { + buffer.write(i.toFloat(), 0f, i.toLong(), 1f, 0f) + } + + val snapshot = buffer.snapshot() + assertEquals(512, snapshot.size) + + for (i in 1 until snapshot.size) { + assertTrue( + "Points must be chronological after multiple wraps", + snapshot[i].x > snapshot[i - 1].x, + ) + } + + assertEquals(988f, snapshot[0].x, 0.001f) + assertEquals(1499f, snapshot[511].x, 0.001f) + } +}