From 7f2e4c467c9f1c16fe55f512e9e9fd6ddaaaf6fe Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 13 Feb 2025 12:28:48 +0100 Subject: [PATCH 1/2] feat: implementation of ocr for android --- .../java/com/swmansion/rnexecutorch/OCR.kt | 102 ++++ .../rnexecutorch/RnExecutorchPackage.kt | 10 + .../rnexecutorch/models/ocr/Detector.kt | 65 +++ .../models/ocr/RecognitionHandler.kt | 133 +++++ .../rnexecutorch/models/ocr/Recognizer.kt | 56 +++ .../models/ocr/utils/CTCLabelConverter.kt | 75 +++ .../models/ocr/utils/DetectorUtils.kt | 467 ++++++++++++++++++ .../models/ocr/utils/RecognizerUtils.kt | 269 ++++++++++ .../swmansion/rnexecutorch/utils/Fetcher.kt | 5 + .../rnexecutorch/utils/ImageProcessor.kt | 110 ++++- 10 files changed, 1285 insertions(+), 7 deletions(-) create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/OCR.kt create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Recognizer.kt create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt diff --git a/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt b/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt new file mode 100644 index 0000000000..eba3d90b16 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt @@ -0,0 +1,102 @@ +package com.swmansion.rnexecutorch + +import android.util.Log +import com.facebook.react.bridge.Promise +import com.facebook.react.bridge.ReactApplicationContext +import com.swmansion.rnexecutorch.utils.ETError +import com.swmansion.rnexecutorch.utils.ImageProcessor +import 
org.opencv.android.OpenCVLoader +import com.swmansion.rnexecutorch.models.ocr.Detector +import com.swmansion.rnexecutorch.models.ocr.RecognitionHandler +import com.swmansion.rnexecutorch.utils.Fetcher +import com.swmansion.rnexecutorch.utils.ResourceType +import org.opencv.imgproc.Imgproc + +const val recognizerRatio = 1.6 + +class OCR(reactContext: ReactApplicationContext) : + NativeOCRSpec(reactContext) { + + private lateinit var detector: Detector + private lateinit var recognitionHandler: RecognitionHandler + + companion object { + const val NAME = "OCR" + } + + init { + if (!OpenCVLoader.initLocal()) { + Log.d("rn_executorch", "OpenCV not loaded") + } else { + Log.d("rn_executorch", "OpenCV loaded") + } + } + + override fun loadModule( + detectorSource: String, + recognizerSourceLarge: String, + recognizerSourceMedium: String, + recognizerSourceSmall: String, + symbols: String, + languageDictPath: String, + promise: Promise + ) { + try { + detector = Detector(reactApplicationContext) + detector.loadModel(detectorSource) + Fetcher.downloadResource( + reactApplicationContext, + languageDictPath, + ResourceType.TXT, + false, + { path, error -> + if (error != null) { + throw Error(error.message!!) + } + + recognitionHandler = RecognitionHandler( + symbols, + path!!, + reactApplicationContext + ) + + recognitionHandler.loadRecognizers( + recognizerSourceLarge, + recognizerSourceMedium, + recognizerSourceSmall + ) { _, errorRecognizer -> + if (errorRecognizer != null) { + throw Error(errorRecognizer.message!!) 
+ } + + promise.resolve(0) + } + }) + } catch (e: Exception) { + promise.reject(e.message!!, ETError.InvalidModelSource.toString()) + } + } + + override fun forward(input: String, promise: Promise) { + try { + val inputImage = ImageProcessor.readImage(input) + val bBoxesList = detector.runModel(inputImage) + val detectorSize = detector.getModelImageSize() + Imgproc.cvtColor(inputImage, inputImage, Imgproc.COLOR_BGR2GRAY) + val result = recognitionHandler.recognize( + bBoxesList, + inputImage, + (detectorSize.width * recognizerRatio).toInt(), + (detectorSize.height * recognizerRatio).toInt() + ) + promise.resolve(result) + } catch (e: Exception) { + Log.d("rn_executorch", "Error running model: ${e.message}") + promise.reject(e.message!!, e.message) + } + } + + override fun getName(): String { + return NAME + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt index 6a4fdc2df9..7ce07b9849 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt @@ -23,6 +23,8 @@ class RnExecutorchPackage : TurboReactPackage() { Classification(reactContext) } else if (name == ObjectDetection.NAME) { ObjectDetection(reactContext) + } else if (name == OCR.NAME){ + OCR(reactContext) } else { null @@ -74,6 +76,14 @@ class RnExecutorchPackage : TurboReactPackage() { false, // isCxxModule true ) + moduleInfos[OCR.NAME] = ReactModuleInfo( + OCR.NAME, + OCR.NAME, + false, // canOverrideExistingModule + false, // needsEagerInit + false, // isCxxModule + true + ) moduleInfos } } diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt new file mode 100644 index 0000000000..025555b5b2 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt @@ 
-0,0 +1,65 @@ +package com.swmansion.rnexecutorch.models.ocr + +import android.util.Log +import com.facebook.react.bridge.ReactApplicationContext +import com.swmansion.rnexecutorch.models.BaseModel +import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils +import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox +import com.swmansion.rnexecutorch.utils.ImageProcessor +import org.opencv.core.Mat +import org.opencv.core.Scalar +import org.opencv.core.Size +import org.pytorch.executorch.EValue + +val mean: Scalar = Scalar(0.485, 0.456, 0.406) +val variance: Scalar = Scalar(0.229, 0.224, 0.225) + +class Detector(reactApplicationContext: ReactApplicationContext) : + BaseModel>(reactApplicationContext) { + private lateinit var originalSize: Size + + fun getModelImageSize(): Size { + val inputShape = module.getInputShape(0) + val width = inputShape[inputShape.lastIndex] + val height = inputShape[inputShape.lastIndex - 1] + + val modelImageSize = Size(height.toDouble(), width.toDouble()) + + return modelImageSize + } + + override fun preprocess(input: Mat): EValue { + originalSize = Size(input.cols().toDouble(), input.rows().toDouble()) + val resizedImage = ImageProcessor.resizeWithPadding( + input, + getModelImageSize().width.toInt(), + getModelImageSize().height.toInt() + ) + + return ImageProcessor.matToEValue(resizedImage, module.getInputShape(0), mean, variance) + } + + override fun postprocess(output: Array): List { + val outputTensor = output[0].toTensor() + val outputArray = outputTensor.dataAsFloatArray + val modelImageSize = getModelImageSize() + + val (scoreText, scoreLink) = DetectorUtils.interleavedArrayToMats( + outputArray, + Size(modelImageSize.width / 2, modelImageSize.height / 2) + ) + var bBoxesList = DetectorUtils.getDetBoxesFromTextMap(scoreText, scoreLink, 0.4, 0.4, 0.7) + bBoxesList = DetectorUtils.restoreBoxRatio(bBoxesList, 3.2f) + bBoxesList = DetectorUtils.groupTextBoxes(bBoxesList, 0.5, 2.0, 2.0, 15, 30, 678) + + return 
bBoxesList.toList() + } + + override fun runModel(input: Mat): List { + val modelInput = preprocess(input) + val modelOutput = forward(modelInput) + Log.d("rn_executorch", "modelOutput: $modelOutput") + val output = postprocess(modelOutput) + return output + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt new file mode 100644 index 0000000000..81f8e57cad --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt @@ -0,0 +1,133 @@ +package com.swmansion.rnexecutorch.models.ocr + +import android.util.Log +import com.facebook.react.bridge.Arguments +import com.facebook.react.bridge.ReactApplicationContext +import com.facebook.react.bridge.WritableArray +import com.swmansion.rnexecutorch.models.ocr.utils.BBoxPoint +import com.swmansion.rnexecutorch.models.ocr.utils.CTCLabelConverter +import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox +import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils +import com.swmansion.rnexecutorch.utils.ImageProcessor +import org.opencv.core.Core +import org.opencv.core.Mat + +const val modelHeight = 64 +const val largeModelWidth = 512 +const val mediumModelWidth = 256 +const val smallModelWidth = 128 +const val lowConfidenceThreshold = 0.3 +const val adjustContrast = 0.2 + +class RecognitionHandler( + symbols: String, + languageDictPath: String, + reactApplicationContext: ReactApplicationContext +) { + private val recognizerLarge = Recognizer(reactApplicationContext) + private val recognizerMedium = Recognizer(reactApplicationContext) + private val recognizerSmall = Recognizer(reactApplicationContext) + private val converter = CTCLabelConverter(symbols, mapOf(languageDictPath to "key")) + + private fun runModel(croppedImage: Mat): Pair, Double> { + val result: Pair, Double> = if (croppedImage.cols() >= largeModelWidth) { + 
recognizerLarge.runModel(croppedImage) + } else if (croppedImage.cols() >= mediumModelWidth) { + recognizerMedium.runModel(croppedImage) + } else { + recognizerSmall.runModel(croppedImage) + } + + return result + } + + fun loadRecognizers( + largeRecognizerPath: String, + mediumRecognizerPath: String, + smallRecognizerPath: String, + onComplete: (Int, Exception?) -> Unit + ) { + try { + recognizerLarge.loadModel(largeRecognizerPath) + recognizerMedium.loadModel(mediumRecognizerPath) + recognizerSmall.loadModel(smallRecognizerPath) + onComplete(0, null) + } catch (e: Exception) { + onComplete(1, e) + } + } + + fun recognize( + bBoxesList: List, + imgGray: Mat, + desiredWidth: Int, + desiredHeight: Int + ): WritableArray { + val res: WritableArray = Arguments.createArray() + val ratioAndPadding = RecognizerUtils.calculateResizeRatioAndPaddings( + imgGray.width(), + imgGray.height(), + desiredWidth, + desiredHeight + ) + + val left = ratioAndPadding["left"] as Int + val top = ratioAndPadding["top"] as Int + val resizeRatio = ratioAndPadding["resizeRatio"] as Float + val resizedImg = ImageProcessor.resizeWithPadding( + imgGray, + desiredWidth, + desiredHeight + ) + + for (box in bBoxesList) { + var croppedImage = RecognizerUtils.getCroppedImage(box, resizedImg, modelHeight) + if (croppedImage.empty()) { + continue + } + + croppedImage = RecognizerUtils.normalizeForRecognizer(croppedImage, adjustContrast) + + var result = runModel(croppedImage) + var confidenceScore = result.second + + if (confidenceScore < lowConfidenceThreshold) { + Core.rotate(croppedImage, croppedImage, Core.ROTATE_180) + val rotatedResult = runModel(croppedImage) + val rotatedConfidenceScore = rotatedResult.second + if (rotatedConfidenceScore > confidenceScore) { + result = rotatedResult + confidenceScore = rotatedConfidenceScore + } + } + + val predIndex = result.first + val decodedTexts = converter.decodeGreedy(predIndex, predIndex.size) + + val bbox = Array(4) { BBoxPoint(0.0, 0.0) } + for (i in 
0 until 4) { + bbox[i] = BBoxPoint( + ((box.bBox[i].x - left) * resizeRatio), + ((box.bBox[i].y - top) * resizeRatio) + ) + } + + Log.d("rn_executorch", "confidenceScore: $confidenceScore") + val resMap = Arguments.createMap() + val bboxArray = Arguments.createArray() + bbox.forEach { point -> + val pointMap = Arguments.createMap() + pointMap.putDouble("x", point.x) + pointMap.putDouble("y", point.y) + bboxArray.pushMap(pointMap) + } + resMap.putString("text", decodedTexts[0]) + resMap.putArray("bbox", bboxArray) + resMap.putDouble("confidence", confidenceScore) + + res.pushMap(resMap) + } + + return res + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Recognizer.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Recognizer.kt new file mode 100644 index 0000000000..2772cc4a98 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Recognizer.kt @@ -0,0 +1,56 @@ +package com.swmansion.rnexecutorch.models.ocr + +import com.facebook.react.bridge.ReactApplicationContext +import com.swmansion.rnexecutorch.models.BaseModel +import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils +import com.swmansion.rnexecutorch.utils.ImageProcessor +import org.opencv.core.Mat +import org.opencv.core.Size +import org.pytorch.executorch.EValue + +class Recognizer(reactApplicationContext: ReactApplicationContext) : + BaseModel, Double>>(reactApplicationContext) { + + private fun getModelOutputSize(): Size { + val outputShape = module.getOutputShape(0) + val width = outputShape[outputShape.lastIndex] + val height = outputShape[outputShape.lastIndex - 1] + + return Size(height.toDouble(), width.toDouble()) + } + + override fun preprocess(input: Mat): EValue { + return ImageProcessor.matToEValueGray(input) + } + + override fun postprocess(output: Array): Pair, Double> { + val modelOutputHeight = getModelOutputSize().height.toInt() + val tensor = output[0].toTensor().dataAsFloatArray + val numElements = 
tensor.size + val numRows = (numElements + modelOutputHeight - 1) / modelOutputHeight + val resultMat = Mat(numRows, modelOutputHeight, org.opencv.core.CvType.CV_32F) + var counter = 0 + var currentRow = 0 + for (num in tensor) { + resultMat.put(currentRow, counter, floatArrayOf(num)) + counter++ + if (counter >= modelOutputHeight) { + counter = 0 + currentRow++ + } + } + + var probabilities = RecognizerUtils.softmax(resultMat) + val predsNorm = RecognizerUtils.sumProbabilityRows(probabilities, modelOutputHeight) + probabilities = RecognizerUtils.divideMatrixByVector(probabilities, predsNorm) + val (values, indices) = RecognizerUtils.findMaxValuesAndIndices(probabilities) + + val confidenceScore = RecognizerUtils.computeConfidenceScore(values, indices) + return Pair(indices, confidenceScore) + } + + + override fun runModel(input: Mat): Pair, Double> { + return postprocess(module.forward(preprocess(input))) + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt new file mode 100644 index 0000000000..336d2f600f --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt @@ -0,0 +1,75 @@ +package com.swmansion.rnexecutorch.models.ocr.utils + +import java.io.File + +class CTCLabelConverter( + characters: String, + dictPathList: Map +) { + private val dict = mutableMapOf() + val character: List + private val ignoreIdx: List + private val dictList: Map> + + init { + val mutableCharacters = mutableListOf("[blank]") + characters.forEachIndexed { index, char -> + mutableCharacters.add(char.toString()) + dict[char.toString()] = index + 1 + } + character = mutableCharacters.toList() + + val ignoreIndexes = mutableListOf(0) + + ignoreIdx = ignoreIndexes.toList() + + dictList = loadDictionariesWithDictPathList(dictPathList) + } + + private fun loadDictionariesWithDictPathList(dictPathList: 
Map): Map> { + val tempDictList = mutableMapOf>() + dictPathList.forEach { (lang, dictPath) -> + runCatching { + File(dictPath).readLines() + }.onSuccess { lines -> + tempDictList[lang] = lines + }.onFailure { error -> + println("Error reading file: ${error.localizedMessage}") + } + } + return tempDictList.toMap() + } + + fun decodeGreedy(textIndex: List, length: Int): List { + val texts = mutableListOf() + var index = 0 + while (index < textIndex.size) { + val segmentLength = minOf(length, textIndex.size - index) + val subArray = textIndex.subList(index, index + segmentLength) + + val text = StringBuilder() + var lastChar: Int? = null + val isNotRepeated = mutableListOf(true) + val isNotIgnored = mutableListOf() + + subArray.forEachIndexed { i, currentChar -> + if (i > 0) { + isNotRepeated.add(lastChar != currentChar) + } + isNotIgnored.add(!ignoreIdx.contains(currentChar)) + lastChar = currentChar + } + + subArray.forEachIndexed { j, charIndex -> + if (isNotRepeated[j] && isNotIgnored[j]) { + text.append(character[charIndex]) + } + } + + texts.add(text.toString()) + index += segmentLength + if (segmentLength < length) break + } + return texts.toList() + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt new file mode 100644 index 0000000000..fea965b05e --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt @@ -0,0 +1,467 @@ +package com.swmansion.rnexecutorch.models.ocr.utils + +import android.util.Log +import org.opencv.core.Core +import org.opencv.core.CvType +import org.opencv.core.Mat +import org.opencv.core.MatOfFloat4 +import org.opencv.core.MatOfInt +import org.opencv.core.MatOfPoint +import org.opencv.core.MatOfPoint2f +import org.opencv.core.Point +import org.opencv.core.Rect +import org.opencv.core.Scalar +import org.opencv.core.Size +import org.opencv.imgproc.Imgproc 
+import kotlin.math.abs +import kotlin.math.atan +import kotlin.math.cos +import kotlin.math.max +import kotlin.math.min +import kotlin.math.pow +import kotlin.math.sin +import kotlin.math.sqrt + +data class BBoxPoint( + var x: Double, + var y: Double, +) + +data class OCRbBox( + val bBox: List, + val angle: Double, +) + +data class LineInfo( + val slope: Double, + val intercept: Double, + val isVertical: Boolean +) + +class DetectorUtils { + companion object { + private fun normalizeAngle(angle: Double): Double { + if (angle > 45.0) { + return angle - 90.0 + } + + return angle + } + + private fun midpoint(p1: BBoxPoint, p2: BBoxPoint): BBoxPoint { + val midpoint = BBoxPoint((p1.x + p2.x) / 2, (p1.y + p2.y) / 2) + return midpoint + } + + private fun distanceBetweenPoints(p1: BBoxPoint, p2: BBoxPoint): Double { + return sqrt((p1.x - p2.x).pow(2.0) + (p1.y - p2.y).pow(2.0)) + } + + private fun centerOfBox(box: OCRbBox): BBoxPoint { + val p1 = box.bBox[0] + val p2 = box.bBox[2] + return midpoint(p1, p2) + } + + private fun maxSideLength(box: OCRbBox): Double { + var maxSideLength = 0.0 + val numOfPoints = box.bBox.size + for (i in 0 until numOfPoints) { + val currentPoint = box.bBox[i] + val nextPoint = box.bBox[(i + 1) % numOfPoints] + val sideLength = distanceBetweenPoints(currentPoint, nextPoint) + if (sideLength > maxSideLength) { + maxSideLength = sideLength + } + } + return maxSideLength + } + + private fun minSideLength(box: OCRbBox): Double { + var minSideLength = Double.MAX_VALUE + val numOfPoints = box.bBox.size + for (i in 0 until numOfPoints) { + val currentPoint = box.bBox[i] + val nextPoint = box.bBox[(i + 1) % numOfPoints] + val sideLength = distanceBetweenPoints(currentPoint, nextPoint) + if (sideLength < minSideLength) { + minSideLength = sideLength + } + } + return minSideLength + } + + + private fun calculateMinimalDistanceBetweenBoxes(box1: OCRbBox, box2: OCRbBox): Double { + var minDistance = Double.MAX_VALUE + for (i in 0 until 4) { + for (j in 0 
until 4) { + val distance = distanceBetweenPoints(box1.bBox[i], box2.bBox[j]) + if (distance < minDistance) { + minDistance = distance + } + } + } + + return minDistance + } + + private fun rotateBox(box: OCRbBox, angle: Double): OCRbBox { + val center = centerOfBox(box) + val radians = angle * Math.PI / 180 + val newBBox = box.bBox.map { point -> + val translatedX = point.x - center.x + val translatedY = point.y - center.y + val rotatedX = translatedX * cos(radians) - translatedY * sin(radians) + val rotatedY = translatedX * sin(radians) + translatedY * cos(radians) + BBoxPoint(rotatedX + center.x, rotatedY + center.y) + } + + return OCRbBox(newBBox, box.angle) + } + + private fun orderPointsClockwise(box: OCRbBox): OCRbBox { + var topLeft = box.bBox[0] + var topRight = box.bBox[1] + var bottomRight = box.bBox[2] + var bottomLeft = box.bBox[3] + var minSum = Double.MAX_VALUE + var maxSum = -Double.MAX_VALUE + var minDiff = Double.MAX_VALUE + var maxDiff = -Double.MAX_VALUE + + for (point in box.bBox) { + val sum = point.x + point.y + val diff = point.x - point.y + if (sum < minSum) { + minSum = sum + topLeft = point + } + if (sum > maxSum) { + maxSum = sum + bottomRight = point + } + if (diff < minDiff) { + minDiff = diff + bottomLeft = point + } + if (diff > maxDiff) { + maxDiff = diff + topRight = point + } + } + + return OCRbBox(listOf(topLeft, topRight, bottomRight, bottomLeft), box.angle) + } + + private fun mergeRotatedBoxes(box1: OCRbBox, box2: OCRbBox): OCRbBox { + val orderedBox1 = orderPointsClockwise(box1) + val orderedBox2 = orderPointsClockwise(box2) + + val allPoints = arrayListOf() + allPoints.addAll(orderedBox1.bBox.map { Point(it.x, it.y) }) + allPoints.addAll(orderedBox2.bBox.map { Point(it.x, it.y) }) + + val matOfAllPoints = MatOfPoint() + matOfAllPoints.fromList(allPoints) + + // Finding the convex hull + val hullIndices = MatOfInt() + Imgproc.convexHull(matOfAllPoints, hullIndices, false) + + // Mapping the hull indices back to points + val 
hullPoints = hullIndices.toArray().map { allPoints[it] } + + val matOfHullPoints = MatOfPoint2f() + matOfHullPoints.fromList(hullPoints) + + // Create the minimal area rectangle from the hull points + val minAreaRect = Imgproc.minAreaRect(matOfHullPoints) + val rectPoints = arrayOfNulls(4) + minAreaRect.points(rectPoints) + + // Convert points back to BBoxPoint + val bBoxPoints = rectPoints.filterNotNull().map { BBoxPoint(it.x, it.y) } + + return OCRbBox(bBoxPoints, minAreaRect.angle) + } + + private fun removeSmallBoxes( + boxes: MutableList, + minSideThreshold: Int, + maxSideThreshold: Int + ): MutableList { + return boxes.filter { minSideLength(it) > minSideThreshold && maxSideLength(it) > maxSideThreshold } + .toMutableList() + } + + private fun minimumYFromBox(box: List): Double = box.minOf { it.y } + + private fun fitLineToShortestSides(box: OCRbBox): LineInfo { + // Convert the BBoxPoints to OpenCV Points + val sides = mutableListOf>() // Store side length and index + val midpoints = mutableListOf() + + // Calculate side lengths and midpoints + for (i in box.bBox.indices) { + val p1 = box.bBox[i] + val p2 = box.bBox[(i + 1) % 4] + val sideLength = distanceBetweenPoints(p1, p2) + sides.add(sideLength to i) + midpoints.add(midpoint(p1, p2)) + } + + // Sort sides by length + sides.sortBy { it.first } + + val midpoint1 = midpoints[sides[0].second] + val midpoint2 = midpoints[sides[1].second] + + val dx = abs(midpoint2.x - midpoint1.x) + val line = MatOfFloat4() + + val isVertical = if (dx < 20) { + for (point in arrayOf(midpoint1, midpoint2)) { + val temp = point.x + point.x = point.y + point.y = temp + } + Imgproc.fitLine( + MatOfPoint2f( + Point(midpoint1.x, midpoint1.y), + Point(midpoint2.x, midpoint2.y) + ), line, Imgproc.DIST_L2, 0.0, 0.01, 0.01 + ) + true + } else { + Imgproc.fitLine( + MatOfPoint2f( + Point(midpoint1.x, midpoint1.y), + Point(midpoint2.x, midpoint2.y) + ), line, Imgproc.DIST_L2, 0.0, 0.01, 0.01 + ) + false + } + + val m = line.get(1, 0)[0] 
/ line.get(0, 0)[0] // slope + val c = line.get(3, 0)[0] - m * line.get(2, 0)[0] // intercept + return LineInfo(m, c, isVertical) + } + + private fun findClosestBox( + boxes: MutableList, + ignoredIds: Set, + currentBox: OCRbBox, + isVertical: Boolean, + m: Double, + c: Double, + centerThreshold: Double + ): Map? { + var smallestDistance = Double.MAX_VALUE + var idx = -1 + var boxHeight = 0.0 + val centerOfCurrentBox = centerOfBox(currentBox) + boxes.forEachIndexed { i, box -> + if (ignoredIds.contains(i)) { + return@forEachIndexed // continue in forEachIndexed is achieved by return@forEachIndexed + } + val centerOfProcessedBox = centerOfBox(box) + val distanceBetweenCenters = distanceBetweenPoints(centerOfCurrentBox, centerOfProcessedBox) + if (distanceBetweenCenters >= smallestDistance) { + return@forEachIndexed + } + boxHeight = minSideLength(box) + val lineDistance = if (isVertical) + abs(centerOfProcessedBox.x - (m * centerOfProcessedBox.y + c)) + else + abs(centerOfProcessedBox.y - (m * centerOfProcessedBox.x + c)) + + if (lineDistance < boxHeight * centerThreshold) { + idx = i + smallestDistance = distanceBetweenCenters + } + } + + return idx.takeIf { it != -1 }?.let { + mapOf("idx" to it, "boxHeight" to boxHeight) + } + } + + private fun createMaskFromLabels(labels: Mat, labelValue: Int): Mat { + val mask = Mat.zeros(labels.size(), CvType.CV_8U) + + Core.compare(labels, Scalar(labelValue.toDouble()), mask, Core.CMP_EQ) + + return mask + } + + fun interleavedArrayToMats(array: FloatArray, size: Size): Pair { + val mat1 = Mat(size.height.toInt(), size.width.toInt(), CvType.CV_32F) + val mat2 = Mat(size.height.toInt(), size.width.toInt(), CvType.CV_32F) + + array.forEachIndexed { index, value -> + val x = (index / 2) % (size.width.toInt()) + val y = (index / 2) / size.width.toInt() + if (index % 2 == 0) { + mat1.put(y, x, value.toDouble()) + } else { + mat2.put(y, x, value.toDouble()) + } + } + + return Pair(mat1, mat2) + } + + fun getDetBoxesFromTextMap( + 
textMap: Mat, + affinityMap: Mat, + textThreshold: Double, + linkThreshold: Double, + lowTextThreshold: Double + ): MutableList { + val imgH = textMap.rows() + val imgW = textMap.cols() + + val textScore = Mat() + val affinityScore = Mat() + Imgproc.threshold(textMap, textScore, textThreshold, 1.0, Imgproc.THRESH_BINARY) + Imgproc.threshold(affinityMap, affinityScore, linkThreshold, 1.0, Imgproc.THRESH_BINARY) + val textScoreComb = Mat() + Core.add(textScore, affinityScore, textScoreComb) + Imgproc.threshold(textScoreComb, textScoreComb, 0.0, 1.0, Imgproc.THRESH_BINARY) + + val binaryMat = Mat() + textScoreComb.convertTo(binaryMat, CvType.CV_8UC1) + + + val labels = Mat() + val stats = Mat() + val centroids = Mat() + val nLabels = Imgproc.connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4) + + val detectedBoxes = mutableListOf() + Log.d("rn_executorch", "nLabels: $nLabels") + for (i in 1 until nLabels) { + val area = stats.get(i, Imgproc.CC_STAT_AREA)[0].toInt() + if (area < 10) continue + val mask = createMaskFromLabels(labels, i) + val maxValResult = Core.minMaxLoc(textMap, mask) + val maxVal = maxValResult.maxVal + if (maxVal < lowTextThreshold) continue + val segMap = Mat.zeros(textMap.size(), CvType.CV_8U) + segMap.setTo(Scalar(255.0), mask) + + val x = stats.get(i, Imgproc.CC_STAT_LEFT)[0].toInt() + val y = stats.get(i, Imgproc.CC_STAT_TOP)[0].toInt() + val w = stats.get(i, Imgproc.CC_STAT_WIDTH)[0].toInt() + val h = stats.get(i, Imgproc.CC_STAT_HEIGHT)[0].toInt() + val dilationRadius = (sqrt(area / max(w, h).toDouble()) * 2.0).toInt() + val sx = max(x - dilationRadius, 0) + val ex = min(x + w + dilationRadius + 1, imgW) + val sy = max(y - dilationRadius, 0) + val ey = min(y + h + dilationRadius + 1, imgH) + val roi = Rect(sx, sy, ex - sx, ey - sy) + val kernel = Imgproc.getStructuringElement( + Imgproc.MORPH_RECT, + Size((1 + dilationRadius).toDouble(), (1 + dilationRadius).toDouble()) + ) + val roiSegMap = Mat(segMap, roi) + 
Imgproc.dilate(roiSegMap, roiSegMap, kernel) + + val contours: List = ArrayList() + Imgproc.findContours( + segMap, + contours, + Mat(), + Imgproc.RETR_EXTERNAL, + Imgproc.CHAIN_APPROX_SIMPLE + ) + if (contours.isNotEmpty()) { + val minRect = Imgproc.minAreaRect(MatOfPoint2f(*contours[0].toArray())) + val points = Array(4) { Point() } + minRect.points(points) + val pointsList = points.map { point -> BBoxPoint(point.x, point.y) } + val boxInfo = OCRbBox(pointsList, minRect.angle) + detectedBoxes.add(boxInfo) + } + } + + return detectedBoxes + } + + fun restoreBoxRatio(boxes: MutableList, restoreRatio: Float): MutableList { + for (box in boxes) { + for (b in box.bBox) { + b.x *= restoreRatio + b.y *= restoreRatio + } + } + + return boxes + } + + fun groupTextBoxes( + boxes: MutableList, + centerThreshold: Double, + distanceThreshold: Double, + heightThreshold: Double, + minSideThreshold: Int, + maxSideThreshold: Int, + maxWidth: Int + ): MutableList { + boxes.sortByDescending { maxSideLength(it) } + var mergedArray = mutableListOf() + + while (boxes.isNotEmpty()) { + var currentBox = boxes.removeAt(0) + val normalizedAngle = normalizeAngle(currentBox.angle) + val ignoredIdxs = mutableSetOf() + var lineAngle: Double + while (true) { + val fittedLine = + fitLineToShortestSides(currentBox) // Placeholder for actual implementation + val slope = fittedLine.slope + val intercept = fittedLine.intercept + val isVertical = fittedLine.isVertical + + lineAngle = atan(slope) * 180 / Math.PI + if (isVertical) { + lineAngle = -90.0 + } + + val closestBoxInfo = findClosestBox( + boxes, ignoredIdxs, currentBox, + isVertical, slope, intercept, centerThreshold + ) ?: break + + val candidateIdx = closestBoxInfo["idx"] as Int + var candidateBox = boxes[candidateIdx] + val candidateHeight = closestBoxInfo["boxHeight"] as Double + if ((candidateBox.angle == 90.0 && !isVertical) || (candidateBox.angle == 0.0 && isVertical)) { + candidateBox = + rotateBox(candidateBox, normalizedAngle) // 
Placeholder for actual implementation + } + val minDistance = + calculateMinimalDistanceBetweenBoxes(candidateBox, currentBox) // Placeholder + val mergedHeight = minSideLength(currentBox) + if (minDistance < distanceThreshold * candidateHeight && abs(mergedHeight - candidateHeight) < candidateHeight * heightThreshold) { + currentBox = mergeRotatedBoxes(currentBox, candidateBox) + boxes.removeAt(candidateIdx) + ignoredIdxs.clear() + if (maxSideLength(currentBox) > maxWidth) { + break + } + } else { + ignoredIdxs.add(candidateIdx) + } + } + mergedArray.add(currentBox.copy(angle = lineAngle)) + } + + mergedArray = removeSmallBoxes(mergedArray, minSideThreshold, maxSideThreshold) + mergedArray = mergedArray.sortedWith(compareBy { minimumYFromBox(it.bBox) }).toMutableList() + + return mergedArray + } + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt new file mode 100644 index 0000000000..4530ca8d11 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt @@ -0,0 +1,269 @@ +package com.swmansion.rnexecutorch.models.ocr.utils + +import com.swmansion.rnexecutorch.utils.ImageProcessor +import org.opencv.core.Core +import org.opencv.core.CvType +import org.opencv.core.Mat +import org.opencv.core.MatOfPoint2f +import org.opencv.core.Point +import org.opencv.core.Rect +import org.opencv.core.Scalar +import org.opencv.core.Size +import org.opencv.imgproc.Imgproc +import kotlin.math.max +import kotlin.math.min +import kotlin.math.pow +import kotlin.math.sqrt + +class RecognizerUtils { + companion object { + fun softmax(inputs: Mat): Mat { + val maxVal = Mat() + Core.reduce(inputs, maxVal, 1, Core.REDUCE_MAX, CvType.CV_32F) + + val tiledMaxVal = Mat() + Core.repeat(maxVal, 1, inputs.width(), tiledMaxVal) + val expInputs = Mat() + Core.subtract(inputs, tiledMaxVal, expInputs) + 
Core.exp(expInputs, expInputs) + + val sumExp = Mat() + Core.reduce(expInputs, sumExp, 1, Core.REDUCE_SUM, CvType.CV_32F) + + val tiledSumExp = Mat() + Core.repeat(sumExp, 1, inputs.width(), tiledSumExp) + val softmaxOutput = Mat() + Core.divide(expInputs, tiledSumExp, softmaxOutput) + + return softmaxOutput + } + + fun sumProbabilityRows(probabilities: Mat, modelOutputHeight: Int): FloatArray { + val predsNorm = FloatArray(probabilities.rows()) + + for (i in 0 until probabilities.rows()) { + var sum = 0.0 + for (j in 0 until modelOutputHeight) { + sum += probabilities.get(i, j)[0] + } + predsNorm[i] = sum.toFloat() + } + + return predsNorm + } + + fun divideMatrixByVector(matrix: Mat, vector: FloatArray): Mat { + for (i in 0 until matrix.rows()) { + for (j in 0 until matrix.cols()) { + val value = matrix.get(i, j)[0] / vector[i] + matrix.put(i, j, value) + } + } + + return matrix + } + + fun findMaxValuesAndIndices(probabilities: Mat): Pair> { + val values = DoubleArray(probabilities.rows()) + val indices = mutableListOf() + + for (i in 0 until probabilities.rows()) { + val row = probabilities.row(i) + val minMaxLocResult = Core.minMaxLoc(row) + + values[i] = minMaxLocResult.maxVal + indices.add(minMaxLocResult.maxLoc.x.toInt()) + } + + return Pair(values, indices) + } + + fun computeConfidenceScore(valuesArray: DoubleArray, indicesArray: List): Double { + val predsMaxProb = mutableListOf() + for ((index, value) in indicesArray.withIndex()) { + if (value != 0) predsMaxProb.add(valuesArray[index]) + } + + val nonZeroValues = + if (predsMaxProb.isEmpty()) doubleArrayOf(0.0) else predsMaxProb.toDoubleArray() + val product = nonZeroValues.reduce { acc, d -> acc * d } + val score = product.pow(2.0 / sqrt(nonZeroValues.size.toDouble())) + + return score + } + + private fun calculateRatio(width: Int, height: Int): Double { + var ratio = width.toDouble() / height.toDouble() + if (ratio < 1.0) { + ratio = 1.0 / ratio + } + + return ratio + } + + private fun 
/**
 * Intersection of two axis-aligned rectangles.
 *
 * @return the overlap region, or an empty [Rect] when the rectangles do not overlap.
 */
private fun findIntersection(r1: Rect, r2: Rect): Rect {
  val left = max(r1.x, r2.x)
  val top = max(r1.y, r2.y)
  val right = min(r1.x + r1.width, r2.x + r2.width)
  val bottom = min(r1.y + r1.height, r2.y + r2.height)
  return if (right > left && bottom > top) {
    Rect(left, top, right - left, bottom - top)
  } else {
    Rect()
  }
}

/**
 * Boosts the contrast of a greyscale crop when its dynamic range is below [target].
 *
 * Mirrors EasyOCR's adjust_contrast_grey: stretch by 200 / (high - low), shifted so
 * `low - 25` maps to 0, then clamp to [0, 255].
 *
 * @return a new contrast-stretched Mat, or [img] unchanged if contrast is sufficient.
 */
private fun adjustContrastGrey(img: Mat, target: Double): Mat {
  var high = 0
  var low = 255
  if (img.channels() == 1 && !img.empty()) {
    // Native min/max scan instead of a per-pixel JNI Mat.get loop (O(rows*cols) crossings).
    val extrema = Core.minMaxLoc(img)
    high = extrema.maxVal.toInt()
    low = extrema.minVal.toInt()
  } else {
    // Fallback for non-single-channel input: scan channel 0, as the original code did.
    for (i in 0 until img.rows()) {
      for (j in 0 until img.cols()) {
        val pixel = img.get(i, j)[0].toInt()
        high = maxOf(high, pixel)
        low = minOf(low, pixel)
      }
    }
  }

  val contrast = (high - low) / 255.0
  if (contrast >= target) {
    return img
  }

  val ratio = 200.0 / maxOf(10, high - low)
  val stretched = Mat()
  img.convertTo(stretched, CvType.CV_32F)
  Core.subtract(stretched, Scalar(low.toDouble() - 25), stretched)
  Core.multiply(stretched, Scalar(ratio), stretched)
  // TRUNC caps values above 255, TOZERO floors negatives at 0.
  Imgproc.threshold(stretched, stretched, 255.0, 255.0, Imgproc.THRESH_TRUNC)
  Imgproc.threshold(stretched, stretched, 0.0, 255.0, Imgproc.THRESH_TOZERO)
  stretched.convertTo(stretched, CvType.CV_8U)
  return stretched
}

/**
 * Resizes [img] in place so its short side equals [modelHeight] while keeping the
 * aspect ratio (inverted for tall crops), using Lanczos interpolation.
 *
 * @return the same (resized) [img] instance.
 */
private fun computeRatioAndResize(img: Mat, width: Int, height: Int, modelHeight: Int): Mat {
  val ratio = width.toDouble() / height.toDouble()
  val targetSize = if (ratio < 1.0) {
    // Tall crop: fix the width at modelHeight and scale the height by the inverted ratio.
    Size(modelHeight.toDouble(), modelHeight * calculateRatio(width, height))
  } else {
    Size(modelHeight * ratio, modelHeight.toDouble())
  }
  Imgproc.resize(img, img, targetSize, 0.0, 0.0, Imgproc.INTER_LANCZOS4)
  return img
}

/**
 * Computes the letterbox paddings used when fitting a (width x height) image into
 * (desiredWidth x desiredHeight), plus the ratio callers use to map detector
 * coordinates back onto the original image.
 *
 * @return map with keys "resizeRatio" (Float), "top" (Int), "left" (Int).
 */
fun calculateResizeRatioAndPaddings(
  width: Int,
  height: Int,
  desiredWidth: Int,
  desiredHeight: Int
): Map<String, Any> {
  // Scale factor that fits the image inside the desired box without cropping.
  val fitRatio = minOf(desiredHeight.toFloat() / height, desiredWidth.toFloat() / width)
  val newWidth = (width * fitRatio).toInt()
  val newHeight = (height * fitRatio).toInt()

  val top = (desiredHeight - newHeight) / 2
  val left = (desiredWidth - newWidth) / 2

  // Inverse ratio: desired-box coordinates -> original-image coordinates.
  val resizeRatio = maxOf(height.toFloat() / desiredHeight, width.toFloat() / desiredWidth)

  return mapOf(
    "resizeRatio" to resizeRatio,
    "top" to top,
    "left" to left
  )
}

/**
 * Extracts the (possibly rotated) text region [box] from [image]: rotates the whole
 * image so the box becomes axis-aligned, crops its bounding rectangle (clipped to the
 * image), and resizes the crop so its short side equals [modelHeight].
 *
 * @return the normalised crop, or an empty Mat when the clipped region is empty.
 */
fun getCroppedImage(box: OCRbBox, image: Mat, modelHeight: Int): Mat {
  val corners = box.bBox.map { Point(it.x, it.y) }
  val rotatedRect = Imgproc.minAreaRect(MatOfPoint2f(*corners.toTypedArray()))

  // De-rotate the full image around its center by the box angle.
  val imageCenter = Point(image.cols() / 2.0, image.rows() / 2.0)
  val rotationMatrix = Imgproc.getRotationMatrix2D(imageCenter, box.angle, 1.0)
  val rotatedImage = Mat()
  Imgproc.warpAffine(image, rotatedImage, rotationMatrix, image.size(), Imgproc.INTER_LINEAR)

  // Transform the box corners into the rotated image's coordinate frame.
  val rectPoints = Array(4) { Point() }
  rotatedRect.points(rectPoints)
  val rectMat = Mat(4, 2, CvType.CV_32FC2)
  for (i in 0 until 4) {
    rectMat.put(i, 0, *doubleArrayOf(rectPoints[i].x, rectPoints[i].y))
  }
  Core.transform(rectMat, rectMat, rotationMatrix)
  val transformedPoints = (0 until 4).map { Point(rectMat.get(it, 0)[0], rectMat.get(it, 0)[1]) }

  var boundingBox = Imgproc.boundingRect(MatOfPoint2f(*transformedPoints.toTypedArray()))
  // Clip to the valid image area so Mat(rotatedImage, boundingBox) cannot throw.
  boundingBox = findIntersection(boundingBox, Rect(0, 0, rotatedImage.cols(), rotatedImage.rows()))
  val croppedImage = Mat(rotatedImage, boundingBox)
  if (croppedImage.empty()) {
    return croppedImage
  }

  return computeRatioAndResize(croppedImage, boundingBox.width, boundingBox.height, modelHeight)
}
normalizeForRecognizer(image: Mat, adjustContrast: Double): Mat { + var img = image.clone() + + if (adjustContrast > 0) { + img = adjustContrastGrey(img, adjustContrast) + } + + val desiredWidth = when { + img.width() >= 512 -> 512 + img.width() >= 256 -> 256 + else -> 128 + } + + img = ImageProcessor.resizeWithPadding(img, desiredWidth, 64) + img.convertTo(img, CvType.CV_32F, 1.0 / 255.0) + Core.subtract(img, Scalar(0.5), img) + Core.multiply(img, Scalar(2.0), img) + + return img + } + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/utils/Fetcher.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/Fetcher.kt index deaf787e36..e79ef97b1f 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/utils/Fetcher.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/utils/Fetcher.kt @@ -15,6 +15,7 @@ import java.net.URL enum class ResourceType { TOKENIZER, MODEL, + TXT } class Fetcher { @@ -40,6 +41,10 @@ class Fetcher { ResourceType.MODEL -> { "pte" } + + ResourceType.TXT -> { + "txt" + } } } diff --git a/android/src/main/java/com/swmansion/rnexecutorch/utils/ImageProcessor.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/ImageProcessor.kt index 5488ecb476..1e00aa4807 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/utils/ImageProcessor.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/utils/ImageProcessor.kt @@ -3,20 +3,30 @@ package com.swmansion.rnexecutorch.utils import android.content.Context import android.net.Uri import android.util.Base64 +import android.util.Log +import org.opencv.core.Core import org.opencv.core.CvType import org.opencv.core.Mat +import org.opencv.core.Scalar +import org.opencv.core.Size import org.opencv.imgcodecs.Imgcodecs +import org.opencv.imgproc.Imgproc import org.pytorch.executorch.EValue import org.pytorch.executorch.Tensor import java.io.File import java.io.InputStream import java.net.URL import java.util.UUID +import kotlin.math.floor class ImageProcessor { companion 
object { fun matToEValue(mat: Mat, shape: LongArray): EValue { + return matToEValue(mat, shape, Scalar(0.0, 0.0, 0.0), Scalar(1.0, 1.0, 1.0)) + } + + fun matToEValue(mat: Mat, shape: LongArray, mean: Scalar, variance: Scalar): EValue { val pixelCount = mat.cols() * mat.rows() val floatArray = FloatArray(pixelCount * 3) @@ -26,19 +36,38 @@ class ImageProcessor { val pixel = mat.get(row, col) if (mat.type() == CvType.CV_8UC3 || mat.type() == CvType.CV_8UC4) { - val b = pixel[0] / 255.0f - val g = pixel[1] / 255.0f - val r = pixel[2] / 255.0f + val b = (pixel[0] - mean.`val`[0] * 255.0f) / (variance.`val`[0] * 255.0f) + val g = (pixel[1] - mean.`val`[1] * 255.0f) / (variance.`val`[1] * 255.0f) + val r = (pixel[2] - mean.`val`[2] * 255.0f) / (variance.`val`[2] * 255.0f) - floatArray[i] = r.toFloat() - floatArray[pixelCount + i] = g.toFloat() - floatArray[2 * pixelCount + i] = b.toFloat() + floatArray[0 * pixelCount + i] = b.toFloat() + floatArray[1 * pixelCount + i] = g.toFloat() + floatArray[2 * pixelCount + i] = r.toFloat() } } return EValue.from(Tensor.fromBlob(floatArray, shape)) } + fun matToEValueGray(mat: Mat): EValue { + val pixelCount = mat.cols() * mat.rows() + val floatArray = FloatArray(pixelCount) + + for (i in 0 until pixelCount) { + val row = i / mat.cols() + val col = i % mat.cols() + val pixel = mat.get(row, col) + floatArray[i] = pixel[0].toFloat() + } + + return EValue.from( + Tensor.fromBlob( + floatArray, + longArrayOf(1, 1, mat.rows().toLong(), mat.cols().toLong()) + ) + ) + } + fun EValueToMat(array: FloatArray, width: Int, height: Int): Mat { val mat = Mat(height, width, CvType.CV_8UC3) @@ -64,7 +93,7 @@ class ImageProcessor { Imgcodecs.imwrite(tempFile.absolutePath, mat) return "file://${tempFile.absolutePath}" - }catch (e: Exception) { + } catch (e: Exception) { throw Exception(ETError.FileWriteFailed.toString()) } } @@ -89,11 +118,13 @@ class ImageProcessor { } inputImage = Imgcodecs.imdecode(encodedData, Imgcodecs.IMREAD_COLOR) } + 
scheme.equals("file", ignoreCase = true) -> { //device storage val path = uri.path inputImage = Imgcodecs.imread(path, Imgcodecs.IMREAD_COLOR) } + else -> { //external source val url = URL(source) @@ -117,5 +148,70 @@ class ImageProcessor { return inputImage } + + fun resizeWithPadding(img: Mat, desiredWidth: Int, desiredHeight: Int): Mat { + val height = img.rows() + val width = img.cols() + val heightRatio = desiredHeight.toFloat() / height + val widthRatio = desiredWidth.toFloat() / width + val resizeRatio = minOf(heightRatio, widthRatio) + val newWidth = (width * resizeRatio).toInt() + val newHeight = (height * resizeRatio).toInt() + + val resizedImg = Mat() + Imgproc.resize( + img, + resizedImg, + Size(newWidth.toDouble(), newHeight.toDouble()), + 0.0, + 0.0, + Imgproc.INTER_AREA + ) + + val cornerPatchSize = maxOf(1, minOf(width, height) / 30) + val corners = listOf( + img.submat(0, cornerPatchSize, 0, cornerPatchSize), + img.submat(0, cornerPatchSize, width - cornerPatchSize, width), + img.submat(height - cornerPatchSize, height, 0, cornerPatchSize), + img.submat(height - cornerPatchSize, height, width - cornerPatchSize, width) + ) + + var backgroundScalar = Core.mean(corners[0]) + for (i in 1 until corners.size) { + val mean = Core.mean(corners[i]) + backgroundScalar = Scalar( + backgroundScalar.`val`[0] + mean.`val`[0], + backgroundScalar.`val`[1] + mean.`val`[1], + backgroundScalar.`val`[2] + mean.`val`[2] + ) + } + + backgroundScalar = Scalar( + floor(backgroundScalar.`val`[0] / corners.size), + floor(backgroundScalar.`val`[1] / corners.size), + floor(backgroundScalar.`val`[2] / corners.size) + ) + + val deltaW = desiredWidth - newWidth + val deltaH = desiredHeight - newHeight + val top = deltaH / 2 + val bottom = deltaH - top + val left = deltaW / 2 + val right = deltaW - left + + val centeredImg = Mat() + Core.copyMakeBorder( + resizedImg, + centeredImg, + top, + bottom, + left, + right, + Core.BORDER_CONSTANT, + backgroundScalar + ) + + return 
centeredImg + } } } From c0621e8b19cd24380d85bc2d63b3a612bdd8a5f5 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 13 Feb 2025 15:36:45 +0100 Subject: [PATCH 2/2] refactor: refactor ocr code --- .../java/com/swmansion/rnexecutorch/OCR.kt | 7 +- .../rnexecutorch/models/ocr/Detector.kt | 38 +++-- .../models/ocr/RecognitionHandler.kt | 40 ++--- .../models/ocr/utils/Constants.kt | 27 +++ .../models/ocr/utils/DetectorUtils.kt | 83 ++++----- .../models/ocr/utils/RecognizerUtils.kt | 158 +++++++++--------- 6 files changed, 188 insertions(+), 165 deletions(-) create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt diff --git a/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt b/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt index eba3d90b16..85acf06260 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt @@ -8,12 +8,11 @@ import com.swmansion.rnexecutorch.utils.ImageProcessor import org.opencv.android.OpenCVLoader import com.swmansion.rnexecutorch.models.ocr.Detector import com.swmansion.rnexecutorch.models.ocr.RecognitionHandler +import com.swmansion.rnexecutorch.models.ocr.utils.Constants import com.swmansion.rnexecutorch.utils.Fetcher import com.swmansion.rnexecutorch.utils.ResourceType import org.opencv.imgproc.Imgproc -const val recognizerRatio = 1.6 - class OCR(reactContext: ReactApplicationContext) : NativeOCRSpec(reactContext) { @@ -86,8 +85,8 @@ class OCR(reactContext: ReactApplicationContext) : val result = recognitionHandler.recognize( bBoxesList, inputImage, - (detectorSize.width * recognizerRatio).toInt(), - (detectorSize.height * recognizerRatio).toInt() + (detectorSize.width * Constants.RECOGNIZER_RATIO).toInt(), + (detectorSize.height * Constants.RECOGNIZER_RATIO).toInt() ) promise.resolve(result) } catch (e: Exception) { diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt 
b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt index 025555b5b2..85976e2281 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt @@ -3,6 +3,7 @@ package com.swmansion.rnexecutorch.models.ocr import android.util.Log import com.facebook.react.bridge.ReactApplicationContext import com.swmansion.rnexecutorch.models.BaseModel +import com.swmansion.rnexecutorch.models.ocr.utils.Constants import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox import com.swmansion.rnexecutorch.utils.ImageProcessor @@ -11,9 +12,6 @@ import org.opencv.core.Scalar import org.opencv.core.Size import org.pytorch.executorch.EValue -val mean: Scalar = Scalar(0.485, 0.456, 0.406) -val variance: Scalar = Scalar(0.229, 0.224, 0.225) - class Detector(reactApplicationContext: ReactApplicationContext) : BaseModel>(reactApplicationContext) { private lateinit var originalSize: Size @@ -36,7 +34,12 @@ class Detector(reactApplicationContext: ReactApplicationContext) : getModelImageSize().height.toInt() ) - return ImageProcessor.matToEValue(resizedImage, module.getInputShape(0), mean, variance) + return ImageProcessor.matToEValue( + resizedImage, + module.getInputShape(0), + Constants.MEAN, + Constants.VARIANCE + ) } override fun postprocess(output: Array): List { @@ -48,18 +51,29 @@ class Detector(reactApplicationContext: ReactApplicationContext) : outputArray, Size(modelImageSize.width / 2, modelImageSize.height / 2) ) - var bBoxesList = DetectorUtils.getDetBoxesFromTextMap(scoreText, scoreLink, 0.4, 0.4, 0.7) - bBoxesList = DetectorUtils.restoreBoxRatio(bBoxesList, 3.2f) - bBoxesList = DetectorUtils.groupTextBoxes(bBoxesList, 0.5, 2.0, 2.0, 15, 30, 678) + var bBoxesList = DetectorUtils.getDetBoxesFromTextMap( + scoreText, + scoreLink, + Constants.TEXT_THRESHOLD, + Constants.LINK_THRESHOLD, + 
Constants.LOW_TEXT_THRESHOLD + ) + bBoxesList = + DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RECOGNIZER_RATIO * 2).toFloat()) + bBoxesList = DetectorUtils.groupTextBoxes( + bBoxesList, + Constants.CENTER_THRESHOLD, + Constants.DISTANCE_THRESHOLD, + Constants.HEIGHT_THRESHOLD, + Constants.MIN_SIDE_THRESHOLD, + Constants.MAX_SIDE_THRESHOLD, + Constants.MAX_WIDTH + ) return bBoxesList.toList() } override fun runModel(input: Mat): List { - val modelInput = preprocess(input) - val modelOutput = forward(modelInput) - Log.d("rn_executorch", "modelOutput: $modelOutput") - val output = postprocess(modelOutput) - return output + return postprocess(forward(preprocess(input))) } } diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt index 81f8e57cad..1aeae02e22 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt @@ -1,24 +1,16 @@ package com.swmansion.rnexecutorch.models.ocr -import android.util.Log import com.facebook.react.bridge.Arguments import com.facebook.react.bridge.ReactApplicationContext import com.facebook.react.bridge.WritableArray -import com.swmansion.rnexecutorch.models.ocr.utils.BBoxPoint import com.swmansion.rnexecutorch.models.ocr.utils.CTCLabelConverter +import com.swmansion.rnexecutorch.models.ocr.utils.Constants import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils import com.swmansion.rnexecutorch.utils.ImageProcessor import org.opencv.core.Core import org.opencv.core.Mat -const val modelHeight = 64 -const val largeModelWidth = 512 -const val mediumModelWidth = 256 -const val smallModelWidth = 128 -const val lowConfidenceThreshold = 0.3 -const val adjustContrast = 0.2 - class RecognitionHandler( symbols: String, 
languageDictPath: String, @@ -30,9 +22,9 @@ class RecognitionHandler( private val converter = CTCLabelConverter(symbols, mapOf(languageDictPath to "key")) private fun runModel(croppedImage: Mat): Pair, Double> { - val result: Pair, Double> = if (croppedImage.cols() >= largeModelWidth) { + val result: Pair, Double> = if (croppedImage.cols() >= Constants.LARGE_MODEL_WIDTH) { recognizerLarge.runModel(croppedImage) - } else if (croppedImage.cols() >= mediumModelWidth) { + } else if (croppedImage.cols() >= Constants.MEDIUM_MODEL_WIDTH) { recognizerMedium.runModel(croppedImage) } else { recognizerSmall.runModel(croppedImage) @@ -81,17 +73,17 @@ class RecognitionHandler( ) for (box in bBoxesList) { - var croppedImage = RecognizerUtils.getCroppedImage(box, resizedImg, modelHeight) + var croppedImage = RecognizerUtils.getCroppedImage(box, resizedImg, Constants.MODEL_HEIGHT) if (croppedImage.empty()) { continue } - croppedImage = RecognizerUtils.normalizeForRecognizer(croppedImage, adjustContrast) + croppedImage = RecognizerUtils.normalizeForRecognizer(croppedImage, Constants.ADJUST_CONTRAST) var result = runModel(croppedImage) var confidenceScore = result.second - if (confidenceScore < lowConfidenceThreshold) { + if (confidenceScore < Constants.LOW_CONFIDENCE_THRESHOLD) { Core.rotate(croppedImage, croppedImage, Core.ROTATE_180) val rotatedResult = runModel(croppedImage) val rotatedConfidenceScore = rotatedResult.second @@ -104,25 +96,15 @@ class RecognitionHandler( val predIndex = result.first val decodedTexts = converter.decodeGreedy(predIndex, predIndex.size) - val bbox = Array(4) { BBoxPoint(0.0, 0.0) } - for (i in 0 until 4) { - bbox[i] = BBoxPoint( - ((box.bBox[i].x - left) * resizeRatio), - ((box.bBox[i].y - top) * resizeRatio) - ) + for (bBox in box.bBox) { + bBox.x = (bBox.x - left) * resizeRatio + bBox.y = (bBox.y - top) * resizeRatio } - Log.d("rn_executorch", "confidenceScore: $confidenceScore") val resMap = Arguments.createMap() - val bboxArray = 
Arguments.createArray() - bbox.forEach { point -> - val pointMap = Arguments.createMap() - pointMap.putDouble("x", point.x) - pointMap.putDouble("y", point.y) - bboxArray.pushMap(pointMap) - } + resMap.putString("text", decodedTexts[0]) - resMap.putArray("bbox", bboxArray) + resMap.putArray("bbox", box.toWritableArray()) resMap.putDouble("confidence", confidenceScore) res.pushMap(resMap) diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt new file mode 100644 index 0000000000..b49232f41a --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt @@ -0,0 +1,27 @@ +package com.swmansion.rnexecutorch.models.ocr.utils + +import org.opencv.core.Scalar + +class Constants { + companion object { + const val RECOGNIZER_RATIO = 1.6 + const val MODEL_HEIGHT = 64 + const val LARGE_MODEL_WIDTH = 512 + const val MEDIUM_MODEL_WIDTH = 256 + const val SMALL_MODEL_WIDTH = 128 + const val LOW_CONFIDENCE_THRESHOLD = 0.3 + const val ADJUST_CONTRAST = 0.2 + const val TEXT_THRESHOLD = 0.4 + const val LINK_THRESHOLD = 0.4 + const val LOW_TEXT_THRESHOLD = 0.7 + const val CENTER_THRESHOLD = 0.5 + const val DISTANCE_THRESHOLD = 2.0 + const val HEIGHT_THRESHOLD = 2.0 + const val MIN_SIDE_THRESHOLD = 15 + const val MAX_SIDE_THRESHOLD = 30 + const val MAX_WIDTH = (LARGE_MODEL_WIDTH + (LARGE_MODEL_WIDTH * 0.15)).toInt() + const val MIN_SIZE = 20 + val MEAN = Scalar(0.485, 0.456, 0.406) + val VARIANCE = Scalar(0.229, 0.224, 0.225) + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt index fea965b05e..4beb7ecf45 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt @@ -1,6 +1,7 
@@ package com.swmansion.rnexecutorch.models.ocr.utils -import android.util.Log +import com.facebook.react.bridge.Arguments +import com.facebook.react.bridge.WritableArray import org.opencv.core.Core import org.opencv.core.CvType import org.opencv.core.Mat @@ -22,22 +23,6 @@ import kotlin.math.pow import kotlin.math.sin import kotlin.math.sqrt -data class BBoxPoint( - var x: Double, - var y: Double, -) - -data class OCRbBox( - val bBox: List, - val angle: Double, -) - -data class LineInfo( - val slope: Double, - val intercept: Double, - val isVertical: Boolean -) - class DetectorUtils { companion object { private fun normalizeAngle(angle: Double): Double { @@ -165,22 +150,18 @@ class DetectorUtils { val matOfAllPoints = MatOfPoint() matOfAllPoints.fromList(allPoints) - // Finding the convex hull val hullIndices = MatOfInt() Imgproc.convexHull(matOfAllPoints, hullIndices, false) - // Mapping the hull indices back to points val hullPoints = hullIndices.toArray().map { allPoints[it] } val matOfHullPoints = MatOfPoint2f() matOfHullPoints.fromList(hullPoints) - // Create the minimal area rectangle from the hull points val minAreaRect = Imgproc.minAreaRect(matOfHullPoints) val rectPoints = arrayOfNulls(4) minAreaRect.points(rectPoints) - // Convert points back to BBoxPoint val bBoxPoints = rectPoints.filterNotNull().map { BBoxPoint(it.x, it.y) } return OCRbBox(bBoxPoints, minAreaRect.angle) @@ -198,11 +179,9 @@ class DetectorUtils { private fun minimumYFromBox(box: List): Double = box.minOf { it.y } private fun fitLineToShortestSides(box: OCRbBox): LineInfo { - // Convert the BBoxPoints to OpenCV Points - val sides = mutableListOf>() // Store side length and index + val sides = mutableListOf>() val midpoints = mutableListOf() - // Calculate side lengths and midpoints for (i in box.bBox.indices) { val p1 = box.bBox[i] val p2 = box.bBox[(i + 1) % 4] @@ -211,7 +190,6 @@ class DetectorUtils { midpoints.add(midpoint(p1, p2)) } - // Sort sides by length sides.sortBy { it.first 
} val midpoint1 = midpoints[sides[0].second] @@ -256,14 +234,14 @@ class DetectorUtils { m: Double, c: Double, centerThreshold: Double - ): Map? { + ): Pair? { var smallestDistance = Double.MAX_VALUE var idx = -1 var boxHeight = 0.0 val centerOfCurrentBox = centerOfBox(currentBox) boxes.forEachIndexed { i, box -> if (ignoredIds.contains(i)) { - return@forEachIndexed // continue in forEachIndexed is achieved by return@forEachIndexed + return@forEachIndexed } val centerOfProcessedBox = centerOfBox(box) val distanceBetweenCenters = distanceBetweenPoints(centerOfCurrentBox, centerOfProcessedBox) @@ -282,9 +260,7 @@ class DetectorUtils { } } - return idx.takeIf { it != -1 }?.let { - mapOf("idx" to it, "boxHeight" to boxHeight) - } + return if (idx == -1) null else Pair(idx, boxHeight) } private fun createMaskFromLabels(labels: Mat, labelValue: Int): Mat { @@ -333,14 +309,12 @@ class DetectorUtils { val binaryMat = Mat() textScoreComb.convertTo(binaryMat, CvType.CV_8UC1) - val labels = Mat() val stats = Mat() val centroids = Mat() val nLabels = Imgproc.connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4) val detectedBoxes = mutableListOf() - Log.d("rn_executorch", "nLabels: $nLabels") for (i in 1 until nLabels) { val area = stats.get(i, Imgproc.CC_STAT_AREA)[0].toInt() if (area < 10) continue @@ -415,11 +389,11 @@ class DetectorUtils { while (boxes.isNotEmpty()) { var currentBox = boxes.removeAt(0) val normalizedAngle = normalizeAngle(currentBox.angle) - val ignoredIdxs = mutableSetOf() + val ignoredIds = mutableSetOf() var lineAngle: Double while (true) { val fittedLine = - fitLineToShortestSides(currentBox) // Placeholder for actual implementation + fitLineToShortestSides(currentBox) val slope = fittedLine.slope val intercept = fittedLine.intercept val isVertical = fittedLine.isVertical @@ -430,29 +404,29 @@ class DetectorUtils { } val closestBoxInfo = findClosestBox( - boxes, ignoredIdxs, currentBox, + boxes, ignoredIds, currentBox, isVertical, slope, 
intercept, centerThreshold ) ?: break - val candidateIdx = closestBoxInfo["idx"] as Int + val candidateIdx = closestBoxInfo.first var candidateBox = boxes[candidateIdx] - val candidateHeight = closestBoxInfo["boxHeight"] as Double + val candidateHeight = closestBoxInfo.second if ((candidateBox.angle == 90.0 && !isVertical) || (candidateBox.angle == 0.0 && isVertical)) { candidateBox = - rotateBox(candidateBox, normalizedAngle) // Placeholder for actual implementation + rotateBox(candidateBox, normalizedAngle) } val minDistance = - calculateMinimalDistanceBetweenBoxes(candidateBox, currentBox) // Placeholder + calculateMinimalDistanceBetweenBoxes(candidateBox, currentBox) val mergedHeight = minSideLength(currentBox) if (minDistance < distanceThreshold * candidateHeight && abs(mergedHeight - candidateHeight) < candidateHeight * heightThreshold) { currentBox = mergeRotatedBoxes(currentBox, candidateBox) boxes.removeAt(candidateIdx) - ignoredIdxs.clear() + ignoredIds.clear() if (maxSideLength(currentBox) > maxWidth) { break } } else { - ignoredIdxs.add(candidateIdx) + ignoredIds.add(candidateIdx) } } mergedArray.add(currentBox.copy(angle = lineAngle)) @@ -465,3 +439,30 @@ class DetectorUtils { } } } + +data class BBoxPoint( + var x: Double, + var y: Double, +) + +data class OCRbBox( + val bBox: List, + val angle: Double, +) { + fun toWritableArray(): WritableArray { + val array = Arguments.createArray() + bBox.forEach { point -> + val pointMap = Arguments.createMap() + pointMap.putDouble("x", point.x) + pointMap.putDouble("y", point.y) + array.pushMap(pointMap) + } + return array + } +} + +data class LineInfo( + val slope: Double, + val intercept: Double, + val isVertical: Boolean +) diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt index 4530ca8d11..99adcad9f0 100644 --- 
a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt @@ -17,81 +17,6 @@ import kotlin.math.sqrt class RecognizerUtils { companion object { - fun softmax(inputs: Mat): Mat { - val maxVal = Mat() - Core.reduce(inputs, maxVal, 1, Core.REDUCE_MAX, CvType.CV_32F) - - val tiledMaxVal = Mat() - Core.repeat(maxVal, 1, inputs.width(), tiledMaxVal) - val expInputs = Mat() - Core.subtract(inputs, tiledMaxVal, expInputs) - Core.exp(expInputs, expInputs) - - val sumExp = Mat() - Core.reduce(expInputs, sumExp, 1, Core.REDUCE_SUM, CvType.CV_32F) - - val tiledSumExp = Mat() - Core.repeat(sumExp, 1, inputs.width(), tiledSumExp) - val softmaxOutput = Mat() - Core.divide(expInputs, tiledSumExp, softmaxOutput) - - return softmaxOutput - } - - fun sumProbabilityRows(probabilities: Mat, modelOutputHeight: Int): FloatArray { - val predsNorm = FloatArray(probabilities.rows()) - - for (i in 0 until probabilities.rows()) { - var sum = 0.0 - for (j in 0 until modelOutputHeight) { - sum += probabilities.get(i, j)[0] - } - predsNorm[i] = sum.toFloat() - } - - return predsNorm - } - - fun divideMatrixByVector(matrix: Mat, vector: FloatArray): Mat { - for (i in 0 until matrix.rows()) { - for (j in 0 until matrix.cols()) { - val value = matrix.get(i, j)[0] / vector[i] - matrix.put(i, j, value) - } - } - - return matrix - } - - fun findMaxValuesAndIndices(probabilities: Mat): Pair> { - val values = DoubleArray(probabilities.rows()) - val indices = mutableListOf() - - for (i in 0 until probabilities.rows()) { - val row = probabilities.row(i) - val minMaxLocResult = Core.minMaxLoc(row) - - values[i] = minMaxLocResult.maxVal - indices.add(minMaxLocResult.maxLoc.x.toInt()) - } - - return Pair(values, indices) - } - - fun computeConfidenceScore(valuesArray: DoubleArray, indicesArray: List): Double { - val predsMaxProb = mutableListOf() - for ((index, value) in 
indicesArray.withIndex()) { - if (value != 0) predsMaxProb.add(valuesArray[index]) - } - - val nonZeroValues = - if (predsMaxProb.isEmpty()) doubleArrayOf(0.0) else predsMaxProb.toDoubleArray() - val product = nonZeroValues.reduce { acc, d -> acc * d } - val score = product.pow(2.0 / sqrt(nonZeroValues.size.toDouble())) - - return score - } - private fun calculateRatio(width: Int, height: Int): Double { var ratio = width.toDouble() / height.toDouble() if (ratio < 1.0) { @@ -174,6 +99,81 @@ class RecognizerUtils { return img } + fun softmax(inputs: Mat): Mat { + val maxVal = Mat() + Core.reduce(inputs, maxVal, 1, Core.REDUCE_MAX, CvType.CV_32F) + + val tiledMaxVal = Mat() + Core.repeat(maxVal, 1, inputs.width(), tiledMaxVal) + val expInputs = Mat() + Core.subtract(inputs, tiledMaxVal, expInputs) + Core.exp(expInputs, expInputs) + + val sumExp = Mat() + Core.reduce(expInputs, sumExp, 1, Core.REDUCE_SUM, CvType.CV_32F) + + val tiledSumExp = Mat() + Core.repeat(sumExp, 1, inputs.width(), tiledSumExp) + val softmaxOutput = Mat() + Core.divide(expInputs, tiledSumExp, softmaxOutput) + + return softmaxOutput + } + + fun sumProbabilityRows(probabilities: Mat, modelOutputHeight: Int): FloatArray { + val predsNorm = FloatArray(probabilities.rows()) + + for (i in 0 until probabilities.rows()) { + var sum = 0.0 + for (j in 0 until modelOutputHeight) { + sum += probabilities.get(i, j)[0] + } + predsNorm[i] = sum.toFloat() + } + + return predsNorm + } + + fun divideMatrixByVector(matrix: Mat, vector: FloatArray): Mat { + for (i in 0 until matrix.rows()) { + for (j in 0 until matrix.cols()) { + val value = matrix.get(i, j)[0] / vector[i] + matrix.put(i, j, value) + } + } + + return matrix + } + + fun findMaxValuesAndIndices(probabilities: Mat): Pair> { + val values = DoubleArray(probabilities.rows()) + val indices = mutableListOf() + + for (i in 0 until probabilities.rows()) { + val row = probabilities.row(i) + val minMaxLocResult = Core.minMaxLoc(row) + + values[i] = 
minMaxLocResult.maxVal + indices.add(minMaxLocResult.maxLoc.x.toInt()) + } + + return Pair(values, indices) + } + + fun computeConfidenceScore(valuesArray: DoubleArray, indicesArray: List): Double { + val predsMaxProb = mutableListOf() + for ((index, value) in indicesArray.withIndex()) { + if (value != 0) predsMaxProb.add(valuesArray[index]) + } + + val nonZeroValues = + if (predsMaxProb.isEmpty()) doubleArrayOf(0.0) else predsMaxProb.toDoubleArray() + val product = nonZeroValues.reduce { acc, d -> acc * d } + val score = product.pow(2.0 / sqrt(nonZeroValues.size.toDouble())) + + return score + } + fun calculateResizeRatioAndPaddings( width: Int, height: Int, @@ -253,12 +253,12 @@ class RecognizerUtils { } val desiredWidth = when { - img.width() >= 512 -> 512 - img.width() >= 256 -> 256 - else -> 128 + img.width() >= Constants.LARGE_MODEL_WIDTH -> Constants.LARGE_MODEL_WIDTH + img.width() >= Constants.MEDIUM_MODEL_WIDTH -> Constants.MEDIUM_MODEL_WIDTH + else -> Constants.SMALL_MODEL_WIDTH } - img = ImageProcessor.resizeWithPadding(img, desiredWidth, 64) + img = ImageProcessor.resizeWithPadding(img, desiredWidth, Constants.MODEL_HEIGHT) img.convertTo(img, CvType.CV_32F, 1.0 / 255.0) Core.subtract(img, Scalar(0.5), img) Core.multiply(img, Scalar(2.0), img)