diff --git a/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt b/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt new file mode 100644 index 0000000000..4f0b926b37 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt @@ -0,0 +1,87 @@ +package com.swmansion.rnexecutorch + +import android.util.Log +import com.facebook.react.bridge.Promise +import com.facebook.react.bridge.ReactApplicationContext +import com.swmansion.rnexecutorch.utils.ETError +import com.swmansion.rnexecutorch.utils.ImageProcessor +import org.opencv.android.OpenCVLoader +import com.swmansion.rnexecutorch.models.ocr.Detector +import com.swmansion.rnexecutorch.models.ocr.RecognitionHandler +import com.swmansion.rnexecutorch.models.ocr.utils.Constants +import org.opencv.imgproc.Imgproc + +class OCR(reactContext: ReactApplicationContext) : + NativeOCRSpec(reactContext) { + + private lateinit var detector: Detector + private lateinit var recognitionHandler: RecognitionHandler + + companion object { + const val NAME = "OCR" + } + + init { + if (!OpenCVLoader.initLocal()) { + Log.d("rn_executorch", "OpenCV not loaded") + } else { + Log.d("rn_executorch", "OpenCV loaded") + } + } + + override fun loadModule( + detectorSource: String, + recognizerSourceLarge: String, + recognizerSourceMedium: String, + recognizerSourceSmall: String, + symbols: String, + promise: Promise + ) { + try { + detector = Detector(reactApplicationContext) + detector.loadModel(detectorSource) + + recognitionHandler = RecognitionHandler( + symbols, + reactApplicationContext + ) + + recognitionHandler.loadRecognizers( + recognizerSourceLarge, + recognizerSourceMedium, + recognizerSourceSmall + ) { _, errorRecognizer -> + if (errorRecognizer != null) { + throw Error(errorRecognizer.message!!) 
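+          // Note: throwing here is safe only because loadRecognizers invokes its
+          // completion callback synchronously, so the throw is still caught by the
+          // enclosing try; if loading ever becomes asynchronous, reject the promise
+          // here instead (Promise.reject(code, message)).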
+ } + + promise.resolve(0) + } + } catch (e: Exception) { + promise.reject(e.message!!, ETError.InvalidModelSource.toString()) + } + } + + override fun forward(input: String, promise: Promise) { + try { + val inputImage = ImageProcessor.readImage(input) + val bBoxesList = detector.runModel(inputImage) + val detectorSize = detector.getModelImageSize() + Imgproc.cvtColor(inputImage, inputImage, Imgproc.COLOR_BGR2GRAY) + val result = recognitionHandler.recognize( + bBoxesList, + inputImage, + (detectorSize.width * Constants.RECOGNIZER_RATIO).toInt(), + (detectorSize.height * Constants.RECOGNIZER_RATIO).toInt() + ) + promise.resolve(result) + } catch (e: Exception) { + Log.d("rn_executorch", "Error running model: ${e.message}") + promise.reject(e.message!!, e.message) + } + } + + override fun getName(): String { + return NAME + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt index fb7fe1f63b..0ec2a51c4f 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt @@ -25,6 +25,8 @@ class RnExecutorchPackage : TurboReactPackage() { ObjectDetection(reactContext) } else if (name == SpeechToText.NAME) { SpeechToText(reactContext) + } else if (name == OCR.NAME){ + OCR(reactContext) } else { null @@ -85,6 +87,15 @@ class RnExecutorchPackage : TurboReactPackage() { false, // isCxxModule true ) + + moduleInfos[OCR.NAME] = ReactModuleInfo( + OCR.NAME, + OCR.NAME, + false, // canOverrideExistingModule + false, // needsEagerInit + false, // isCxxModule + true + ) moduleInfos } } diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt new file mode 100644 index 0000000000..85976e2281 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt @@ -0,0 +1,79 @@ +package com.swmansion.rnexecutorch.models.ocr + +import android.util.Log +import com.facebook.react.bridge.ReactApplicationContext +import com.swmansion.rnexecutorch.models.BaseModel +import com.swmansion.rnexecutorch.models.ocr.utils.Constants +import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils +import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox +import com.swmansion.rnexecutorch.utils.ImageProcessor +import org.opencv.core.Mat +import org.opencv.core.Scalar +import org.opencv.core.Size +import org.pytorch.executorch.EValue + +class Detector(reactApplicationContext: ReactApplicationContext) : + BaseModel>(reactApplicationContext) { + private lateinit var originalSize: Size + + fun getModelImageSize(): Size { + val inputShape = module.getInputShape(0) + val width = inputShape[inputShape.lastIndex] + val height = inputShape[inputShape.lastIndex - 1] + + val modelImageSize = Size(height.toDouble(), width.toDouble()) + + return modelImageSize + } + + override fun preprocess(input: Mat): EValue { + originalSize = Size(input.cols().toDouble(), input.rows().toDouble()) + val resizedImage = ImageProcessor.resizeWithPadding( + input, + getModelImageSize().width.toInt(), + getModelImageSize().height.toInt() + ) + + return ImageProcessor.matToEValue( + resizedImage, + module.getInputShape(0), + Constants.MEAN, + Constants.VARIANCE + ) + } + + override fun postprocess(output: Array): List { + val outputTensor = output[0].toTensor() + val outputArray = outputTensor.dataAsFloatArray + val 
modelImageSize = getModelImageSize() + + val (scoreText, scoreLink) = DetectorUtils.interleavedArrayToMats( + outputArray, + Size(modelImageSize.width / 2, modelImageSize.height / 2) + ) + var bBoxesList = DetectorUtils.getDetBoxesFromTextMap( + scoreText, + scoreLink, + Constants.TEXT_THRESHOLD, + Constants.LINK_THRESHOLD, + Constants.LOW_TEXT_THRESHOLD + ) + bBoxesList = + DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RECOGNIZER_RATIO * 2).toFloat()) + bBoxesList = DetectorUtils.groupTextBoxes( + bBoxesList, + Constants.CENTER_THRESHOLD, + Constants.DISTANCE_THRESHOLD, + Constants.HEIGHT_THRESHOLD, + Constants.MIN_SIDE_THRESHOLD, + Constants.MAX_SIDE_THRESHOLD, + Constants.MAX_WIDTH + ) + + return bBoxesList.toList() + } + + override fun runModel(input: Mat): List { + return postprocess(forward(preprocess(input))) + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt new file mode 100644 index 0000000000..90fd61280e --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt @@ -0,0 +1,114 @@ +package com.swmansion.rnexecutorch.models.ocr + +import com.facebook.react.bridge.Arguments +import com.facebook.react.bridge.ReactApplicationContext +import com.facebook.react.bridge.WritableArray +import com.swmansion.rnexecutorch.models.ocr.utils.CTCLabelConverter +import com.swmansion.rnexecutorch.models.ocr.utils.Constants +import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox +import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils +import com.swmansion.rnexecutorch.utils.ImageProcessor +import org.opencv.core.Core +import org.opencv.core.Mat + +class RecognitionHandler( + symbols: String, + reactApplicationContext: ReactApplicationContext +) { + private val recognizerLarge = Recognizer(reactApplicationContext) + private val recognizerMedium = Recognizer(reactApplicationContext) + private val recognizerSmall = Recognizer(reactApplicationContext) + private val converter = CTCLabelConverter(symbols) + + private fun runModel(croppedImage: Mat): Pair, Double> { + val result: Pair, Double> = if (croppedImage.cols() >= Constants.LARGE_MODEL_WIDTH) { + recognizerLarge.runModel(croppedImage) + } else if (croppedImage.cols() >= Constants.MEDIUM_MODEL_WIDTH) { + recognizerMedium.runModel(croppedImage) + } else { + recognizerSmall.runModel(croppedImage) + } + + return result + } + + fun loadRecognizers( + largeRecognizerPath: String, + mediumRecognizerPath: String, + smallRecognizerPath: String, + onComplete: (Int, Exception?) 
-> Unit + ) { + try { + recognizerLarge.loadModel(largeRecognizerPath) + recognizerMedium.loadModel(mediumRecognizerPath) + recognizerSmall.loadModel(smallRecognizerPath) + onComplete(0, null) + } catch (e: Exception) { + onComplete(1, e) + } + } + + fun recognize( + bBoxesList: List, + imgGray: Mat, + desiredWidth: Int, + desiredHeight: Int + ): WritableArray { + val res: WritableArray = Arguments.createArray() + val ratioAndPadding = RecognizerUtils.calculateResizeRatioAndPaddings( + imgGray.width(), + imgGray.height(), + desiredWidth, + desiredHeight + ) + + val left = ratioAndPadding["left"] as Int + val top = ratioAndPadding["top"] as Int + val resizeRatio = ratioAndPadding["resizeRatio"] as Float + val resizedImg = ImageProcessor.resizeWithPadding( + imgGray, + desiredWidth, + desiredHeight + ) + + for (box in bBoxesList) { + var croppedImage = RecognizerUtils.getCroppedImage(box, resizedImg, Constants.MODEL_HEIGHT) + if (croppedImage.empty()) { + continue + } + + croppedImage = RecognizerUtils.normalizeForRecognizer(croppedImage, Constants.ADJUST_CONTRAST) + + var result = runModel(croppedImage) + var confidenceScore = result.second + + if (confidenceScore < Constants.LOW_CONFIDENCE_THRESHOLD) { + Core.rotate(croppedImage, croppedImage, Core.ROTATE_180) + val rotatedResult = runModel(croppedImage) + val rotatedConfidenceScore = rotatedResult.second + if (rotatedConfidenceScore > confidenceScore) { + result = rotatedResult + confidenceScore = rotatedConfidenceScore + } + } + + val predIndex = result.first + val decodedTexts = converter.decodeGreedy(predIndex, predIndex.size) + + for (bBox in box.bBox) { + bBox.x = (bBox.x - left) * resizeRatio + bBox.y = (bBox.y - top) * resizeRatio + } + + val resMap = Arguments.createMap() + + resMap.putString("text", decodedTexts[0]) + resMap.putArray("bbox", box.toWritableArray()) + resMap.putDouble("confidence", confidenceScore) + + res.pushMap(resMap) + } + + return res + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Recognizer.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Recognizer.kt new file mode 100644 index 0000000000..2772cc4a98 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Recognizer.kt @@ -0,0 +1,56 @@ +package com.swmansion.rnexecutorch.models.ocr + +import com.facebook.react.bridge.ReactApplicationContext +import com.swmansion.rnexecutorch.models.BaseModel +import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils +import com.swmansion.rnexecutorch.utils.ImageProcessor +import org.opencv.core.Mat +import org.opencv.core.Size +import org.pytorch.executorch.EValue + +class Recognizer(reactApplicationContext: ReactApplicationContext) : + BaseModel, Double>>(reactApplicationContext) { + + private fun getModelOutputSize(): Size { + val outputShape = module.getOutputShape(0) + val width = outputShape[outputShape.lastIndex] + val height = outputShape[outputShape.lastIndex - 1] + + return Size(height.toDouble(), width.toDouble()) + } + + override fun preprocess(input: Mat): EValue { + return ImageProcessor.matToEValueGray(input) + } + + override fun postprocess(output: Array): Pair, Double> { + val modelOutputHeight = getModelOutputSize().height.toInt() + val tensor = output[0].toTensor().dataAsFloatArray + val numElements = tensor.size + val numRows = (numElements + modelOutputHeight - 1) / modelOutputHeight + val resultMat = Mat(numRows, modelOutputHeight, org.opencv.core.CvType.CV_32F) + var counter = 0 + var currentRow = 0 + for (num in tensor) 
{ + resultMat.put(currentRow, counter, floatArrayOf(num)) + counter++ + if (counter >= modelOutputHeight) { + counter = 0 + currentRow++ + } + } + + var probabilities = RecognizerUtils.softmax(resultMat) + val predsNorm = RecognizerUtils.sumProbabilityRows(probabilities, modelOutputHeight) + probabilities = RecognizerUtils.divideMatrixByVector(probabilities, predsNorm) + val (values, indices) = RecognizerUtils.findMaxValuesAndIndices(probabilities) + + val confidenceScore = RecognizerUtils.computeConfidenceScore(values, indices) + return Pair(indices, confidenceScore) + } + + + override fun runModel(input: Mat): Pair, Double> { + return postprocess(module.forward(preprocess(input))) + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt new file mode 100644 index 0000000000..007e7e7c29 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt @@ -0,0 +1,57 @@ +package com.swmansion.rnexecutorch.models.ocr.utils + +import java.io.File + +class CTCLabelConverter( + characters: String, +) { + private val dict = mutableMapOf() + private val character: List + private val ignoreIdx: List + + init { + val mutableCharacters = mutableListOf("[blank]") + characters.forEachIndexed { index, char -> + mutableCharacters.add(char.toString()) + dict[char.toString()] = index + 1 + } + character = mutableCharacters.toList() + + val ignoreIndexes = mutableListOf(0) + + ignoreIdx = ignoreIndexes.toList() + } + + fun decodeGreedy(textIndex: List, length: Int): List { + val texts = mutableListOf() + var index = 0 + while (index < textIndex.size) { + val segmentLength = minOf(length, textIndex.size - index) + val subArray = textIndex.subList(index, index + segmentLength) + + val text = StringBuilder() + var lastChar: Int? 
= null + val isNotRepeated = mutableListOf(true) + val isNotIgnored = mutableListOf() + + subArray.forEachIndexed { i, currentChar -> + if (i > 0) { + isNotRepeated.add(lastChar != currentChar) + } + isNotIgnored.add(!ignoreIdx.contains(currentChar)) + lastChar = currentChar + } + + subArray.forEachIndexed { j, charIndex -> + if (isNotRepeated[j] && isNotIgnored[j]) { + text.append(character[charIndex]) + } + } + + texts.add(text.toString()) + index += segmentLength + if (segmentLength < length) break + } + return texts.toList() + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt new file mode 100644 index 0000000000..b49232f41a --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt @@ -0,0 +1,27 @@ +package com.swmansion.rnexecutorch.models.ocr.utils + +import org.opencv.core.Scalar + +class Constants { + companion object { + const val RECOGNIZER_RATIO = 1.6 + const val MODEL_HEIGHT = 64 + const val LARGE_MODEL_WIDTH = 512 + const val MEDIUM_MODEL_WIDTH = 256 + const val SMALL_MODEL_WIDTH = 128 + const val LOW_CONFIDENCE_THRESHOLD = 0.3 + const val ADJUST_CONTRAST = 0.2 + const val TEXT_THRESHOLD = 0.4 + const val LINK_THRESHOLD = 0.4 + const val LOW_TEXT_THRESHOLD = 0.7 + const val CENTER_THRESHOLD = 0.5 + const val DISTANCE_THRESHOLD = 2.0 + const val HEIGHT_THRESHOLD = 2.0 + const val MIN_SIDE_THRESHOLD = 15 + const val MAX_SIDE_THRESHOLD = 30 + const val MAX_WIDTH = (LARGE_MODEL_WIDTH + (LARGE_MODEL_WIDTH * 0.15)).toInt() + const val MIN_SIZE = 20 + val MEAN = Scalar(0.485, 0.456, 0.406) + val VARIANCE = Scalar(0.229, 0.224, 0.225) + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt new file mode 100644 index 0000000000..4beb7ecf45 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt @@ -0,0 +1,468 @@ +package com.swmansion.rnexecutorch.models.ocr.utils + +import com.facebook.react.bridge.Arguments +import com.facebook.react.bridge.WritableArray +import org.opencv.core.Core +import org.opencv.core.CvType +import org.opencv.core.Mat +import org.opencv.core.MatOfFloat4 +import org.opencv.core.MatOfInt +import org.opencv.core.MatOfPoint +import org.opencv.core.MatOfPoint2f +import org.opencv.core.Point +import org.opencv.core.Rect +import org.opencv.core.Scalar +import org.opencv.core.Size +import org.opencv.imgproc.Imgproc +import kotlin.math.abs +import kotlin.math.atan +import kotlin.math.cos +import kotlin.math.max +import kotlin.math.min +import kotlin.math.pow +import kotlin.math.sin +import kotlin.math.sqrt + +class DetectorUtils { + companion object { + private fun normalizeAngle(angle: Double): Double { + if (angle > 45.0) { + return angle - 90.0 + } + + return angle + } + + private fun midpoint(p1: BBoxPoint, p2: BBoxPoint): BBoxPoint { + val midpoint = BBoxPoint((p1.x + p2.x) / 2, (p1.y + p2.y) / 2) + return midpoint + } + + private fun distanceBetweenPoints(p1: BBoxPoint, p2: BBoxPoint): Double { + return sqrt((p1.x - p2.x).pow(2.0) + (p1.y - p2.y).pow(2.0)) + } + + private fun centerOfBox(box: OCRbBox): BBoxPoint { + val p1 = box.bBox[0] + val p2 = box.bBox[2] + return midpoint(p1, p2) + } + + private fun maxSideLength(box: OCRbBox): Double { + var maxSideLength = 0.0 + val numOfPoints = box.bBox.size + for (i in 0 until 
numOfPoints) { + val currentPoint = box.bBox[i] + val nextPoint = box.bBox[(i + 1) % numOfPoints] + val sideLength = distanceBetweenPoints(currentPoint, nextPoint) + if (sideLength > maxSideLength) { + maxSideLength = sideLength + } + } + return maxSideLength + } + + private fun minSideLength(box: OCRbBox): Double { + var minSideLength = Double.MAX_VALUE + val numOfPoints = box.bBox.size + for (i in 0 until numOfPoints) { + val currentPoint = box.bBox[i] + val nextPoint = box.bBox[(i + 1) % numOfPoints] + val sideLength = distanceBetweenPoints(currentPoint, nextPoint) + if (sideLength < minSideLength) { + minSideLength = sideLength + } + } + return minSideLength + } + + + private fun calculateMinimalDistanceBetweenBoxes(box1: OCRbBox, box2: OCRbBox): Double { + var minDistance = Double.MAX_VALUE + for (i in 0 until 4) { + for (j in 0 until 4) { + val distance = distanceBetweenPoints(box1.bBox[i], box2.bBox[j]) + if (distance < minDistance) { + minDistance = distance + } + } + } + + return minDistance + } + + private fun rotateBox(box: OCRbBox, angle: Double): OCRbBox { + val center = centerOfBox(box) + val radians = angle * Math.PI / 180 + val newBBox = box.bBox.map { point -> + val translatedX = point.x - center.x + val translatedY = point.y - center.y + val rotatedX = translatedX * cos(radians) - translatedY * sin(radians) + val rotatedY = translatedX * sin(radians) + translatedY * cos(radians) + BBoxPoint(rotatedX + center.x, rotatedY + center.y) + } + + return OCRbBox(newBBox, box.angle) + } + + private fun orderPointsClockwise(box: OCRbBox): OCRbBox { + var topLeft = box.bBox[0] + var topRight = box.bBox[1] + var bottomRight = box.bBox[2] + var bottomLeft = box.bBox[3] + var minSum = Double.MAX_VALUE + var maxSum = -Double.MAX_VALUE + var minDiff = Double.MAX_VALUE + var maxDiff = -Double.MAX_VALUE + + for (point in box.bBox) { + val sum = point.x + point.y + val diff = point.x - point.y + if (sum < minSum) { + minSum = sum + topLeft = point + } + if (sum > maxSum) { + maxSum = sum + bottomRight = point + } + if (diff < minDiff) { + minDiff = diff + bottomLeft = point + } + if (diff > maxDiff) { + maxDiff = diff + topRight = point + } + } + + return OCRbBox(listOf(topLeft, topRight, bottomRight, bottomLeft), box.angle) + } + + private fun mergeRotatedBoxes(box1: OCRbBox, box2: OCRbBox): OCRbBox { + val orderedBox1 = orderPointsClockwise(box1) + val orderedBox2 = orderPointsClockwise(box2) + + val allPoints = arrayListOf() + allPoints.addAll(orderedBox1.bBox.map { Point(it.x, it.y) }) + allPoints.addAll(orderedBox2.bBox.map { Point(it.x, it.y) }) + + val matOfAllPoints = MatOfPoint() + matOfAllPoints.fromList(allPoints) + + val hullIndices = MatOfInt() + Imgproc.convexHull(matOfAllPoints, hullIndices, false) + + val hullPoints = hullIndices.toArray().map { allPoints[it] } + + val matOfHullPoints = MatOfPoint2f() + matOfHullPoints.fromList(hullPoints) + + val minAreaRect = Imgproc.minAreaRect(matOfHullPoints) + val rectPoints = arrayOfNulls(4) + minAreaRect.points(rectPoints) + + val bBoxPoints = rectPoints.filterNotNull().map { BBoxPoint(it.x, it.y) } + + return OCRbBox(bBoxPoints, minAreaRect.angle) + } + + private fun removeSmallBoxes( + boxes: MutableList, + minSideThreshold: Int, + maxSideThreshold: Int + ): MutableList { + return boxes.filter { minSideLength(it) > minSideThreshold && maxSideLength(it) > maxSideThreshold } + .toMutableList() + } + + private fun minimumYFromBox(box: List): Double = box.minOf { it.y } + + private fun fitLineToShortestSides(box: OCRbBox): LineInfo { + 
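+      // The two shortest sides of the quadrangle are its "ends"; the line fitted
+      // through their midpoints approximates the text baseline. For near-vertical
+      // boxes (dx < 20) the x and y coordinates are swapped before fitting so the
+      // slope stays finite.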
val sides = mutableListOf>() + val midpoints = mutableListOf() + + for (i in box.bBox.indices) { + val p1 = box.bBox[i] + val p2 = box.bBox[(i + 1) % 4] + val sideLength = distanceBetweenPoints(p1, p2) + sides.add(sideLength to i) + midpoints.add(midpoint(p1, p2)) + } + + sides.sortBy { it.first } + + val midpoint1 = midpoints[sides[0].second] + val midpoint2 = midpoints[sides[1].second] + + val dx = abs(midpoint2.x - midpoint1.x) + val line = MatOfFloat4() + + val isVertical = if (dx < 20) { + for (point in arrayOf(midpoint1, midpoint2)) { + val temp = point.x + point.x = point.y + point.y = temp + } + Imgproc.fitLine( + MatOfPoint2f( + Point(midpoint1.x, midpoint1.y), + Point(midpoint2.x, midpoint2.y) + ), line, Imgproc.DIST_L2, 0.0, 0.01, 0.01 + ) + true + } else { + Imgproc.fitLine( + MatOfPoint2f( + Point(midpoint1.x, midpoint1.y), + Point(midpoint2.x, midpoint2.y) + ), line, Imgproc.DIST_L2, 0.0, 0.01, 0.01 + ) + false + } + + val m = line.get(1, 0)[0] / line.get(0, 0)[0] // slope + val c = line.get(3, 0)[0] - m * line.get(2, 0)[0] // intercept + return LineInfo(m, c, isVertical) + } + + private fun findClosestBox( + boxes: MutableList, + ignoredIds: Set, + currentBox: OCRbBox, + isVertical: Boolean, + m: Double, + c: Double, + centerThreshold: Double + ): Pair? { + var smallestDistance = Double.MAX_VALUE + var idx = -1 + var boxHeight = 0.0 + val centerOfCurrentBox = centerOfBox(currentBox) + boxes.forEachIndexed { i, box -> + if (ignoredIds.contains(i)) { + return@forEachIndexed + } + val centerOfProcessedBox = centerOfBox(box) + val distanceBetweenCenters = distanceBetweenPoints(centerOfCurrentBox, centerOfProcessedBox) + if (distanceBetweenCenters >= smallestDistance) { + return@forEachIndexed + } + boxHeight = minSideLength(box) + val lineDistance = if (isVertical) + abs(centerOfProcessedBox.x - (m * centerOfProcessedBox.y + c)) + else + abs(centerOfProcessedBox.y - (m * centerOfProcessedBox.x + c)) + + if (lineDistance < boxHeight * centerThreshold) { + idx = i + smallestDistance = distanceBetweenCenters + } + } + + return if (idx == -1) null else Pair(idx, boxHeight) + } + + private fun createMaskFromLabels(labels: Mat, labelValue: Int): Mat { + val mask = Mat.zeros(labels.size(), CvType.CV_8U) + + Core.compare(labels, Scalar(labelValue.toDouble()), mask, Core.CMP_EQ) + + return mask + } + + fun interleavedArrayToMats(array: FloatArray, size: Size): Pair { + val mat1 = Mat(size.height.toInt(), size.width.toInt(), CvType.CV_32F) + val mat2 = Mat(size.height.toInt(), size.width.toInt(), CvType.CV_32F) + + array.forEachIndexed { index, value -> + val x = (index / 2) % (size.width.toInt()) + val y = (index / 2) / size.width.toInt() + if (index % 2 == 0) { + mat1.put(y, x, value.toDouble()) + } else { + mat2.put(y, x, value.toDouble()) + } + } + + return Pair(mat1, mat2) + } + + fun getDetBoxesFromTextMap( + textMap: Mat, + affinityMap: Mat, + textThreshold: Double, + linkThreshold: Double, + lowTextThreshold: Double + ): MutableList { + val imgH = textMap.rows() + val imgW = textMap.cols() + + val textScore = Mat() + val affinityScore = Mat() + Imgproc.threshold(textMap, textScore, textThreshold, 1.0, Imgproc.THRESH_BINARY) + Imgproc.threshold(affinityMap, affinityScore, linkThreshold, 1.0, Imgproc.THRESH_BINARY) + val textScoreComb = Mat() + Core.add(textScore, affinityScore, textScoreComb) + Imgproc.threshold(textScoreComb, textScoreComb, 0.0, 1.0, Imgproc.THRESH_BINARY) + + val binaryMat = Mat() + textScoreComb.convertTo(binaryMat, CvType.CV_8UC1) + + val labels = Mat() + val 
stats = Mat() + val centroids = Mat() + val nLabels = Imgproc.connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4) + + val detectedBoxes = mutableListOf() + for (i in 1 until nLabels) { + val area = stats.get(i, Imgproc.CC_STAT_AREA)[0].toInt() + if (area < 10) continue + val mask = createMaskFromLabels(labels, i) + val maxValResult = Core.minMaxLoc(textMap, mask) + val maxVal = maxValResult.maxVal + if (maxVal < lowTextThreshold) continue + val segMap = Mat.zeros(textMap.size(), CvType.CV_8U) + segMap.setTo(Scalar(255.0), mask) + + val x = stats.get(i, Imgproc.CC_STAT_LEFT)[0].toInt() + val y = stats.get(i, Imgproc.CC_STAT_TOP)[0].toInt() + val w = stats.get(i, Imgproc.CC_STAT_WIDTH)[0].toInt() + val h = stats.get(i, Imgproc.CC_STAT_HEIGHT)[0].toInt() + val dilationRadius = (sqrt(area / max(w, h).toDouble()) * 2.0).toInt() + val sx = max(x - dilationRadius, 0) + val ex = min(x + w + dilationRadius + 1, imgW) + val sy = max(y - dilationRadius, 0) + val ey = min(y + h + dilationRadius + 1, imgH) + val roi = Rect(sx, sy, ex - sx, ey - sy) + val kernel = Imgproc.getStructuringElement( + Imgproc.MORPH_RECT, + Size((1 + dilationRadius).toDouble(), (1 + dilationRadius).toDouble()) + ) + val roiSegMap = Mat(segMap, roi) + Imgproc.dilate(roiSegMap, roiSegMap, kernel) + + val contours: List = ArrayList() + Imgproc.findContours( + segMap, + contours, + Mat(), + Imgproc.RETR_EXTERNAL, + Imgproc.CHAIN_APPROX_SIMPLE + ) + if (contours.isNotEmpty()) { + val minRect = Imgproc.minAreaRect(MatOfPoint2f(*contours[0].toArray())) + val points = Array(4) { Point() } + minRect.points(points) + val pointsList = points.map { point -> BBoxPoint(point.x, point.y) } + val boxInfo = OCRbBox(pointsList, minRect.angle) + detectedBoxes.add(boxInfo) + } + } + + return detectedBoxes + } + + fun restoreBoxRatio(boxes: MutableList, restoreRatio: Float): MutableList { + for (box in boxes) { + for (b in box.bBox) { + b.x *= restoreRatio + b.y *= restoreRatio + } + } + + return boxes + } + + fun groupTextBoxes( + boxes: MutableList, + centerThreshold: Double, + distanceThreshold: Double, + heightThreshold: Double, + minSideThreshold: Int, + maxSideThreshold: Int, + maxWidth: Int + ): MutableList { + boxes.sortByDescending { maxSideLength(it) } + var mergedArray = mutableListOf() + + while (boxes.isNotEmpty()) { + var currentBox = boxes.removeAt(0) + val normalizedAngle = normalizeAngle(currentBox.angle) + val ignoredIds = mutableSetOf() + var lineAngle: Double + while (true) { + val fittedLine = + fitLineToShortestSides(currentBox) + val slope = fittedLine.slope + val intercept = fittedLine.intercept + val isVertical = fittedLine.isVertical + + lineAngle = atan(slope) * 180 / Math.PI + if (isVertical) { + lineAngle = -90.0 + } + + val closestBoxInfo = findClosestBox( + boxes, ignoredIds, currentBox, + isVertical, slope, intercept, centerThreshold + ) ?: break + + val candidateIdx = closestBoxInfo.first + var candidateBox = boxes[candidateIdx] + val candidateHeight = closestBoxInfo.second + if ((candidateBox.angle == 90.0 && !isVertical) || (candidateBox.angle == 0.0 && isVertical)) { + candidateBox = + rotateBox(candidateBox, normalizedAngle) + } + val minDistance = + calculateMinimalDistanceBetweenBoxes(candidateBox, currentBox) + val mergedHeight = minSideLength(currentBox) + if (minDistance < distanceThreshold * candidateHeight && abs(mergedHeight - candidateHeight) < candidateHeight * heightThreshold) { + currentBox = mergeRotatedBoxes(currentBox, candidateBox) + boxes.removeAt(candidateIdx) + 
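+            // Merging changes the current box's geometry, so candidates rejected
+            // earlier may match now - clear the ignore list and rescan.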
ignoredIds.clear() + if (maxSideLength(currentBox) > maxWidth) { + break + } + } else { + ignoredIds.add(candidateIdx) + } + } + mergedArray.add(currentBox.copy(angle = lineAngle)) + } + + mergedArray = removeSmallBoxes(mergedArray, minSideThreshold, maxSideThreshold) + mergedArray = mergedArray.sortedWith(compareBy { minimumYFromBox(it.bBox) }).toMutableList() + + return mergedArray + } + } +} + +data class BBoxPoint( + var x: Double, + var y: Double, +) + +data class OCRbBox( + val bBox: List, + val angle: Double, +) { + fun toWritableArray(): WritableArray { + val array = Arguments.createArray() + bBox.forEach { point -> + val pointMap = Arguments.createMap() + pointMap.putDouble("x", point.x) + pointMap.putDouble("y", point.y) + array.pushMap(pointMap) + } + return array + } +} + +data class LineInfo( + val slope: Double, + val intercept: Double, + val isVertical: Boolean +) diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt new file mode 100644 index 0000000000..99adcad9f0 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt @@ -0,0 +1,269 @@ +package com.swmansion.rnexecutorch.models.ocr.utils + +import com.swmansion.rnexecutorch.utils.ImageProcessor +import org.opencv.core.Core +import org.opencv.core.CvType +import org.opencv.core.Mat +import org.opencv.core.MatOfPoint2f +import org.opencv.core.Point +import org.opencv.core.Rect +import org.opencv.core.Scalar +import org.opencv.core.Size +import org.opencv.imgproc.Imgproc +import kotlin.math.max +import kotlin.math.min +import kotlin.math.pow +import kotlin.math.sqrt + +class RecognizerUtils { + companion object { + private fun calculateRatio(width: Int, height: Int): Double { + var ratio = width.toDouble() / height.toDouble() + if (ratio < 1.0) { + ratio = 1.0 / ratio + } + + return ratio + } + + private fun findIntersection(r1: Rect, r2: Rect): Rect { + val aLeft = r1.x + val aTop = r1.y + val aRight = r1.x + r1.width + val aBottom = r1.y + r1.height + + val bLeft = r2.x + val bTop = r2.y + val bRight = r2.x + r2.width + val bBottom = r2.y + r2.height + + val iLeft = max(aLeft, bLeft) + val iTop = max(aTop, bTop) + val iRight = min(aRight, bRight) + val iBottom = min(aBottom, bBottom) + + return if (iRight > iLeft && iBottom > iTop) { + Rect(iLeft, iTop, iRight - iLeft, iBottom - iTop) + } else { + Rect() + } + } + + private fun adjustContrastGrey(img: Mat, target: Double): Mat { + var high = 0 + var low = 255 + + for (i in 0 until img.rows()) { + for (j in 0 until img.cols()) { + val pixel = img.get(i, j)[0].toInt() + high = maxOf(high, pixel) + low = minOf(low, pixel) + } + } + + val contrast = (high - low) / 255.0 + + if (contrast < target) { + val ratio = 200.0 / maxOf(10, high - low) + val tempImg = Mat() + img.convertTo(tempImg, CvType.CV_32F) + Core.subtract(tempImg, Scalar(low.toDouble() - 25), tempImg) + Core.multiply(tempImg, Scalar(ratio), tempImg) + Imgproc.threshold(tempImg, tempImg, 255.0, 255.0, Imgproc.THRESH_TRUNC) + Imgproc.threshold(tempImg, tempImg, 0.0, 255.0, Imgproc.THRESH_TOZERO) + tempImg.convertTo(tempImg, CvType.CV_8U) + + return tempImg + } + + return img + } + + private fun computeRatioAndResize(img: Mat, width: Int, height: Int, modelHeight: Int): Mat { + var ratio = width.toDouble() / height.toDouble() + + if (ratio < 1.0) { + ratio = + calculateRatio(width, height) + Imgproc.resize( + img, img, 
Size(modelHeight.toDouble(), (modelHeight * ratio)), + 0.0, 0.0, Imgproc.INTER_LANCZOS4 + ) + } else { + Imgproc.resize( + img, img, Size((modelHeight * ratio), modelHeight.toDouble()), + 0.0, 0.0, Imgproc.INTER_LANCZOS4 + ) + } + + return img + } + + fun softmax(inputs: Mat): Mat { + val maxVal = Mat() + Core.reduce(inputs, maxVal, 1, Core.REDUCE_MAX, CvType.CV_32F) + + val tiledMaxVal = Mat() + Core.repeat(maxVal, 1, inputs.width(), tiledMaxVal) + val expInputs = Mat() + Core.subtract(inputs, tiledMaxVal, expInputs) + Core.exp(expInputs, expInputs) + + val sumExp = Mat() + Core.reduce(expInputs, sumExp, 1, Core.REDUCE_SUM, CvType.CV_32F) + + val tiledSumExp = Mat() + Core.repeat(sumExp, 1, inputs.width(), tiledSumExp) + val softmaxOutput = Mat() + Core.divide(expInputs, tiledSumExp, softmaxOutput) + + return softmaxOutput + } + + fun sumProbabilityRows(probabilities: Mat, modelOutputHeight: Int): FloatArray { + val predsNorm = FloatArray(probabilities.rows()) + + for (i in 0 until probabilities.rows()) { + var sum = 0.0 + for (j in 0 until modelOutputHeight) { + sum += probabilities.get(i, j)[0] + } + predsNorm[i] = sum.toFloat() + } + + return predsNorm + } + + fun divideMatrixByVector(matrix: Mat, vector: FloatArray): Mat { + for (i in 0 until matrix.rows()) { + for (j in 0 until matrix.cols()) { + val value = matrix.get(i, j)[0] / vector[i] + matrix.put(i, j, value) + } + } + + return matrix + } + + fun findMaxValuesAndIndices(probabilities: Mat): Pair> { + val values = DoubleArray(probabilities.rows()) + val indices = mutableListOf() + + for (i in 0 until probabilities.rows()) { + val row = probabilities.row(i) + val minMaxLocResult = Core.minMaxLoc(row) + + values[i] = minMaxLocResult.maxVal + indices.add(minMaxLocResult.maxLoc.x.toInt()) + } + + return Pair(values, indices) + } + + fun computeConfidenceScore(valuesArray: DoubleArray, indicesArray: List): Double { + val predsMaxProb = mutableListOf() + for ((index, value) in indicesArray.withIndex()) { + if (value != 0) predsMaxProb.add(valuesArray[index]) + } + + val nonZeroValues = + if (predsMaxProb.isEmpty()) doubleArrayOf(0.0) else predsMaxProb.toDoubleArray() + val product = nonZeroValues.reduce { acc, d -> acc * d } + val score = product.pow(2.0 / sqrt(nonZeroValues.size.toDouble())) + + return score + } + + fun calculateResizeRatioAndPaddings( + width: Int, + height: Int, + desiredWidth: Int, + desiredHeight: Int + ): Map { + val newRatioH = desiredHeight.toFloat() / height + val newRatioW = desiredWidth.toFloat() / width + var resizeRatio = minOf(newRatioH, newRatioW) + + val newWidth = (width * resizeRatio).toInt() + val newHeight = (height * resizeRatio).toInt() + + val deltaW = desiredWidth - newWidth + val deltaH = desiredHeight - newHeight + + val top = deltaH / 2 + val left = deltaW / 2 + + val heightRatio = height.toFloat() / desiredHeight + val widthRatio = width.toFloat() / desiredWidth + + resizeRatio = maxOf(heightRatio, widthRatio) + + return mapOf( + "resizeRatio" to resizeRatio, + "top" to top, + "left" to left + ) + } + + fun getCroppedImage(box: OCRbBox, image: Mat, modelHeight: Int): Mat { + val cords = box.bBox + val angle = box.angle + val points = ArrayList() + + cords.forEach { point -> + points.add(Point(point.x, point.y)) + } + + val rotatedRect = Imgproc.minAreaRect(MatOfPoint2f(*points.toTypedArray())) + val imageCenter = Point((image.cols() / 2.0), (image.rows() / 2.0)) + val rotationMatrix = Imgproc.getRotationMatrix2D(imageCenter, angle, 1.0) + val rotatedImage = Mat() + Imgproc.warpAffine(image, 
rotatedImage, rotationMatrix, image.size(), Imgproc.INTER_LINEAR) + + val rectPoints = Array(4) { Point() } + rotatedRect.points(rectPoints) + val transformedPoints = arrayOfNulls(4) + val rectMat = Mat(4, 2, CvType.CV_32FC2) + for (i in 0 until 4) { + rectMat.put(i, 0, *doubleArrayOf(rectPoints[i].x, rectPoints[i].y)) + } + Core.transform(rectMat, rectMat, rotationMatrix) + + for (i in 0 until 4) { + transformedPoints[i] = Point(rectMat.get(i, 0)[0], rectMat.get(i, 0)[1]) + } + + var boundingBox = + Imgproc.boundingRect(MatOfPoint2f(*transformedPoints.filterNotNull().toTypedArray())) + val validRegion = Rect(0, 0, rotatedImage.cols(), rotatedImage.rows()) + boundingBox = findIntersection(boundingBox, validRegion) + val croppedImage = Mat(rotatedImage, boundingBox) + if (croppedImage.empty()) { + return croppedImage + } + + return computeRatioAndResize(croppedImage, boundingBox.width, boundingBox.height, modelHeight) + } + + fun normalizeForRecognizer(image: Mat, adjustContrast: Double): Mat { + var img = image.clone() + + if (adjustContrast > 0) { + img = adjustContrastGrey(img, adjustContrast) + } + + val desiredWidth = when { + img.width() >= Constants.LARGE_MODEL_WIDTH -> Constants.LARGE_MODEL_WIDTH + img.width() >= Constants.MEDIUM_MODEL_WIDTH -> Constants.MEDIUM_MODEL_WIDTH + else -> Constants.SMALL_MODEL_WIDTH + } + + img = ImageProcessor.resizeWithPadding(img, desiredWidth, Constants.MODEL_HEIGHT) + img.convertTo(img, CvType.CV_32F, 1.0 / 255.0) + Core.subtract(img, Scalar(0.5), img) + Core.multiply(img, Scalar(2.0), img) + + return img + } + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt index 36a0a1a101..3dcbbed45d 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt @@ -19,6 +19,7 @@ class ArrayUtils { fun createCharArray(input: ReadableArray): CharArray { return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toChar() }.toCharArray() } + fun createIntArray(input: ReadableArray): IntArray { return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index) }.toIntArray() } diff --git a/android/src/main/java/com/swmansion/rnexecutorch/utils/ImageProcessor.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/ImageProcessor.kt index 5488ecb476..1e00aa4807 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/utils/ImageProcessor.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/utils/ImageProcessor.kt @@ -3,20 +3,30 @@ package com.swmansion.rnexecutorch.utils import android.content.Context import android.net.Uri import android.util.Base64 +import android.util.Log +import org.opencv.core.Core import org.opencv.core.CvType import org.opencv.core.Mat +import org.opencv.core.Scalar +import org.opencv.core.Size import org.opencv.imgcodecs.Imgcodecs +import org.opencv.imgproc.Imgproc import org.pytorch.executorch.EValue import org.pytorch.executorch.Tensor import java.io.File import java.io.InputStream import java.net.URL import java.util.UUID +import kotlin.math.floor class ImageProcessor { companion object { fun matToEValue(mat: Mat, shape: LongArray): EValue { + return matToEValue(mat, shape, Scalar(0.0, 0.0, 0.0), Scalar(1.0, 1.0, 1.0)) + } + + fun matToEValue(mat: Mat, shape: LongArray, mean: Scalar, variance: Scalar): EValue { val pixelCount = mat.cols() * mat.rows() val floatArray = 
FloatArray(pixelCount * 3) @@ -26,19 +36,38 @@ class ImageProcessor { val pixel = mat.get(row, col) if (mat.type() == CvType.CV_8UC3 || mat.type() == CvType.CV_8UC4) { - val b = pixel[0] / 255.0f - val g = pixel[1] / 255.0f - val r = pixel[2] / 255.0f + val b = (pixel[0] - mean.`val`[0] * 255.0f) / (variance.`val`[0] * 255.0f) + val g = (pixel[1] - mean.`val`[1] * 255.0f) / (variance.`val`[1] * 255.0f) + val r = (pixel[2] - mean.`val`[2] * 255.0f) / (variance.`val`[2] * 255.0f) - floatArray[i] = r.toFloat() - floatArray[pixelCount + i] = g.toFloat() - floatArray[2 * pixelCount + i] = b.toFloat() + floatArray[0 * pixelCount + i] = b.toFloat() + floatArray[1 * pixelCount + i] = g.toFloat() + floatArray[2 * pixelCount + i] = r.toFloat() } } return EValue.from(Tensor.fromBlob(floatArray, shape)) } + fun matToEValueGray(mat: Mat): EValue { + val pixelCount = mat.cols() * mat.rows() + val floatArray = FloatArray(pixelCount) + + for (i in 0 until pixelCount) { + val row = i / mat.cols() + val col = i % mat.cols() + val pixel = mat.get(row, col) + floatArray[i] = pixel[0].toFloat() + } + + return EValue.from( + Tensor.fromBlob( + floatArray, + longArrayOf(1, 1, mat.rows().toLong(), mat.cols().toLong()) + ) + ) + } + fun EValueToMat(array: FloatArray, width: Int, height: Int): Mat { val mat = Mat(height, width, CvType.CV_8UC3) @@ -64,7 +93,7 @@ class ImageProcessor { Imgcodecs.imwrite(tempFile.absolutePath, mat) return "file://${tempFile.absolutePath}" - }catch (e: Exception) { + } catch (e: Exception) { throw Exception(ETError.FileWriteFailed.toString()) } } @@ -89,11 +118,13 @@ class ImageProcessor { } inputImage = Imgcodecs.imdecode(encodedData, Imgcodecs.IMREAD_COLOR) } + scheme.equals("file", ignoreCase = true) -> { //device storage val path = uri.path inputImage = Imgcodecs.imread(path, Imgcodecs.IMREAD_COLOR) } + else -> { //external source val url = URL(source) @@ -117,5 +148,70 @@ class ImageProcessor { return inputImage } + + fun resizeWithPadding(img: Mat, desiredWidth: Int, desiredHeight: Int): Mat { + val height = img.rows() + val width = img.cols() + val heightRatio = desiredHeight.toFloat() / height + val widthRatio = desiredWidth.toFloat() / width + val resizeRatio = minOf(heightRatio, widthRatio) + val newWidth = (width * resizeRatio).toInt() + val newHeight = (height * resizeRatio).toInt() + + val resizedImg = Mat() + Imgproc.resize( + img, + resizedImg, + Size(newWidth.toDouble(), newHeight.toDouble()), + 0.0, + 0.0, + Imgproc.INTER_AREA + ) + + val cornerPatchSize = maxOf(1, minOf(width, height) / 30) + val corners = listOf( + img.submat(0, cornerPatchSize, 0, cornerPatchSize), + img.submat(0, cornerPatchSize, width - cornerPatchSize, width), + img.submat(height - cornerPatchSize, height, 0, cornerPatchSize), + img.submat(height - cornerPatchSize, height, width - cornerPatchSize, width) + ) + + var backgroundScalar = Core.mean(corners[0]) + for (i in 1 until corners.size) { + val mean = Core.mean(corners[i]) + backgroundScalar = Scalar( + backgroundScalar.`val`[0] + mean.`val`[0], + backgroundScalar.`val`[1] + mean.`val`[1], + backgroundScalar.`val`[2] + mean.`val`[2] + ) + } + + backgroundScalar = Scalar( + floor(backgroundScalar.`val`[0] / corners.size), + floor(backgroundScalar.`val`[1] / corners.size), + floor(backgroundScalar.`val`[2] / corners.size) + ) + + val deltaW = desiredWidth - newWidth + val deltaH = desiredHeight - newHeight + val top = deltaH / 2 + val bottom = deltaH - top + val left = deltaW / 2 + val right = deltaW - left + + val centeredImg = Mat() + 
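+      // Split the leftover space between opposite edges and fill the borders with
+      // the estimated background color so the padding blends with the image.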
Core.copyMakeBorder( + resizedImg, + centeredImg, + top, + bottom, + left, + right, + Core.BORDER_CONSTANT, + backgroundScalar + ) + + return centeredImg + } } } diff --git a/android/src/main/java/com/swmansion/rnexecutorch/utils/OkHttpClientSingleton.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/OkHttpClientSingleton.kt deleted file mode 100644 index 7a2dda79f5..0000000000 --- a/android/src/main/java/com/swmansion/rnexecutorch/utils/OkHttpClientSingleton.kt +++ /dev/null @@ -1,7 +0,0 @@ -package com.swmansion.rnexecutorch.utils - -import okhttp3.OkHttpClient - -object OkHttpClientSingleton { - val instance = OkHttpClient() -} \ No newline at end of file diff --git a/examples/computer-vision/App.tsx b/examples/computer-vision/App.tsx index 8d01269fd0..488c61cd56 100644 --- a/examples/computer-vision/App.tsx +++ b/examples/computer-vision/App.tsx @@ -8,11 +8,13 @@ import { SafeAreaProvider, SafeAreaView } from 'react-native-safe-area-context'; import { View, StyleSheet } from 'react-native'; import { ClassificationScreen } from './screens/ClassificationScreen'; import { ObjectDetectionScreen } from './screens/ObjectDetectionScreen'; +import { OCRScreen } from './screens/OCRScreen'; enum ModelType { STYLE_TRANSFER, OBJECT_DETECTION, CLASSIFICATION, + OCR, } export default function App() { @@ -46,6 +48,8 @@ export default function App() { return ( ); + case ModelType.OCR: + return ; default: return ( @@ -64,6 +68,7 @@ export default function App() { 'Style Transfer', 'Object Detection', 'Classification', + 'OCR', ]} onValueChange={(_, selectedIndex) => { handleModeChange(selectedIndex); diff --git a/examples/computer-vision/components/ImageWithOCRBboxes.tsx b/examples/computer-vision/components/ImageWithOCRBboxes.tsx new file mode 100644 index 0000000000..1c8fe616af --- /dev/null +++ b/examples/computer-vision/components/ImageWithOCRBboxes.tsx @@ -0,0 +1,103 @@ +// Import necessary components +import React from 'react'; +import { Image, StyleSheet, View } from 'react-native'; +import Svg, { Polygon } from 'react-native-svg'; +import { OCRDetection } from 'react-native-executorch'; + +interface Props { + imageUri: string; + detections: OCRDetection[]; + imageWidth: number; + imageHeight: number; +} + +export default function ImageWithOCRBboxes({ + imageUri, + detections, + imageWidth, + imageHeight, +}: Props) { + const [layout, setLayout] = React.useState({ width: 0, height: 0 }); + + const calculateAdjustedDimensions = () => { + const imageRatio = imageWidth / imageHeight; + const layoutRatio = layout.width / layout.height; + let sx, sy; + if (imageRatio > layoutRatio) { + sx = layout.width / imageWidth; + sy = layout.width / imageRatio / imageHeight; + } else { + sy = layout.height / imageHeight; + sx = (layout.height * imageRatio) / imageWidth; + } + return { + scaleX: sx, + scaleY: sy, + offsetX: (layout.width - imageWidth * sx) / 2, + offsetY: (layout.height - imageHeight * sy) / 2, + }; + }; + + return ( + { + const { width, height } = event.nativeEvent.layout; + setLayout({ width, height }); + }} + > + + + {detections.map((detection, index) => { + const { scaleX, scaleY, offsetX, offsetY } = + calculateAdjustedDimensions(); + const points = detection.bbox.map((point) => ({ + x: point.x * scaleX + offsetX, + y: point.y * scaleY + offsetY, + })); + + const pointsString = points + .map((point) => `${point.x},${point.y}`) + .join(' '); + + return ( + + ); + })} + + + ); +} + +const styles = StyleSheet.create({ + container: { + flex: 1, + position: 'relative', + }, + image: { 
+ flex: 1, + width: '100%', + height: '100%', + }, + svgContainer: { + position: 'absolute', + top: 0, + left: 0, + right: 0, + bottom: 0, + }, +}); diff --git a/examples/computer-vision/screens/OCRScreen.tsx b/examples/computer-vision/screens/OCRScreen.tsx new file mode 100644 index 0000000000..9d17118afb --- /dev/null +++ b/examples/computer-vision/screens/OCRScreen.tsx @@ -0,0 +1,112 @@ +import Spinner from 'react-native-loading-spinner-overlay'; +import { BottomBar } from '../components/BottomBar'; +import { getImage } from '../utils'; +import { useOCR } from 'react-native-executorch'; +import { View, StyleSheet, Image, Text } from 'react-native'; +import { useState } from 'react'; +import ImageWithBboxes2 from '../components/ImageWithOCRBboxes'; + +export const OCRScreen = ({ + imageUri, + setImageUri, +}: { + imageUri: string; + setImageUri: (imageUri: string) => void; +}) => { + const [results, setResults] = useState([]); + const [imageDimensions, setImageDimensions] = useState<{ + width: number; + height: number; + }>(); + const [detectedText, setDetectedText] = useState(''); + const model = useOCR({ + detectorSource: + 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_craft_800.pte', + recognizerSources: { + recognizerLarge: + 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_512.pte', + recognizerMedium: + 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_256.pte', + recognizerSmall: + 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_128.pte', + }, + language: 'en', + }); + + const handleCameraPress = async (isCamera: boolean) => { + const image = await getImage(isCamera); + const width = image?.width; + const height = image?.height; + setImageDimensions({ width: width as number, height: height as number }); + const uri = image?.uri; + if (typeof uri === 'string') { + setImageUri(uri as string); + setResults([]); + setDetectedText(''); + } + }; + + const runForward = async () => { + try { + const output = await model.forward(imageUri); + setResults(output); + console.log(output); + let txt = ''; + output.forEach((detection: any) => { + txt += detection.text + ' '; + }); + setDetectedText(txt); + } catch (e) { + console.error(e); + } + }; + + if (!model.isReady) { + return ( + + ); + } + + return ( + <> + + + {imageUri && imageDimensions?.width && imageDimensions?.height ? 
( + + ) : ( + + )} + + {detectedText} + + + + ); +}; + +const styles = StyleSheet.create({ + image: { + flex: 2, + borderRadius: 8, + width: '100%', + }, + imageContainer: { + flex: 6, + width: '100%', + padding: 16, + }, +}); diff --git a/ios/RnExecutorch.xcodeproj/project.pbxproj b/ios/RnExecutorch.xcodeproj/project.pbxproj index 3fad88ed1c..68e367a8e8 100644 --- a/ios/RnExecutorch.xcodeproj/project.pbxproj +++ b/ios/RnExecutorch.xcodeproj/project.pbxproj @@ -35,12 +35,20 @@ LLM.h, ); }; + 552754CC2D394AC9006B38A2 /* Exceptions for "RnExecutorch" folder in "Compile Sources" phase from "RnExecutorch" target */ = { + isa = PBXFileSystemSynchronizedGroupBuildPhaseMembershipExceptionSet; + buildPhase = 550986852CEF541900FECBB8 /* Sources */; + membershipExceptions = ( + models/ocr/utils/DetectorUtils.h, + ); + }; /* End PBXFileSystemSynchronizedGroupBuildPhaseMembershipExceptionSet section */ /* Begin PBXFileSystemSynchronizedRootGroup section */ 5509868B2CEF541900FECBB8 /* RnExecutorch */ = { isa = PBXFileSystemSynchronizedRootGroup; exceptions = ( + 552754CC2D394AC9006B38A2 /* Exceptions for "RnExecutorch" folder in "Compile Sources" phase from "RnExecutorch" target */, 550986902CEF541900FECBB8 /* Exceptions for "RnExecutorch" folder in "Copy Files" phase from "RnExecutorch" target */, ); path = RnExecutorch; @@ -119,6 +127,7 @@ TargetAttributes = { 550986882CEF541900FECBB8 = { CreatedOnToolsVersion = 16.1; + LastSwiftMigration = 1610; }; }; }; @@ -271,6 +280,7 @@ 550986942CEF541900FECBB8 /* Debug */ = { isa = XCBuildConfiguration; buildSettings = { + CLANG_ENABLE_MODULES = YES; CODE_SIGN_STYLE = Automatic; OTHER_LDFLAGS = "-ObjC"; PRODUCT_NAME = "$(TARGET_NAME)"; @@ -279,6 +289,8 @@ SUPPORTS_MACCATALYST = NO; SUPPORTS_MAC_DESIGNED_FOR_IPHONE_IPAD = NO; SUPPORTS_XR_DESIGNED_FOR_IPHONE_IPAD = NO; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + SWIFT_VERSION = 6.0; TARGETED_DEVICE_FAMILY = "1,2"; }; name = Debug; @@ -286,6 +298,7 @@ 550986952CEF541900FECBB8 /* Release */ = { isa = XCBuildConfiguration; buildSettings = { + CLANG_ENABLE_MODULES = YES; CODE_SIGN_STYLE = Automatic; OTHER_LDFLAGS = "-ObjC"; PRODUCT_NAME = "$(TARGET_NAME)"; @@ -294,6 +307,7 @@ SUPPORTS_MACCATALYST = NO; SUPPORTS_MAC_DESIGNED_FOR_IPHONE_IPAD = NO; SUPPORTS_XR_DESIGNED_FOR_IPHONE_IPAD = NO; + SWIFT_VERSION = 6.0; TARGETED_DEVICE_FAMILY = "1,2"; }; name = Release; diff --git a/ios/RnExecutorch/OCR.h b/ios/RnExecutorch/OCR.h new file mode 100644 index 0000000000..68c0878598 --- /dev/null +++ b/ios/RnExecutorch/OCR.h @@ -0,0 +1,7 @@ +#import + +constexpr CGFloat recognizerRatio = 1.6; + +@interface OCR : NSObject + +@end diff --git a/ios/RnExecutorch/OCR.mm b/ios/RnExecutorch/OCR.mm new file mode 100644 index 0000000000..59740c90bb --- /dev/null +++ b/ios/RnExecutorch/OCR.mm @@ -0,0 +1,101 @@ +#import "OCR.h" +#import "models/ocr/Detector.h" +#import "models/ocr/RecognitionHandler.h" +#import "utils/ImageProcessor.h" +#import +#import + +@implementation OCR { + Detector *detector; + RecognitionHandler *recognitionHandler; +} + +RCT_EXPORT_MODULE() + +- (void)loadModule:(NSString *)detectorSource + recognizerSourceLarge:(NSString *)recognizerSourceLarge + recognizerSourceMedium:(NSString *)recognizerSourceMedium + recognizerSourceSmall:(NSString *)recognizerSourceSmall + symbols:(NSString *)symbols + resolve:(RCTPromiseResolveBlock)resolve + reject:(RCTPromiseRejectBlock)reject { + detector = [[Detector alloc] init]; + [detector + loadModel:[NSURL URLWithString:detectorSource] + completion:^(BOOL success, NSNumber *errorCode) 
{ + if (!success) { + NSError *error = [NSError + errorWithDomain:@"OCRErrorDomain" + code:[errorCode intValue] + userInfo:@{ + NSLocalizedDescriptionKey : [NSString + stringWithFormat:@"%ld", (long)[errorCode longValue]] + }]; + reject(@"init_module_error", @"Failed to initialize detector module", + error); + return; + } + self->recognitionHandler = + [[RecognitionHandler alloc] initWithSymbols:symbols]; + [self->recognitionHandler + loadRecognizers:recognizerSourceLarge + mediumRecognizerPath:recognizerSourceMedium + smallRecognizerPath:recognizerSourceSmall + completion:^(BOOL allModelsLoaded, NSNumber *errorCode) { + if (allModelsLoaded) { + resolve(@(YES)); + } else { + NSError *error = [NSError + errorWithDomain:@"OCRErrorDomain" + code:[errorCode intValue] + userInfo:@{ + NSLocalizedDescriptionKey : [NSString + stringWithFormat:@"%ld", + (long)[errorCode + longValue]] + }]; + reject(@"init_recognizer_error", + @"Failed to initialize one or more " + @"recognizer models", + error); + } + }]; + }]; +} + +- (void)forward:(NSString *)input + resolve:(RCTPromiseResolveBlock)resolve + reject:(RCTPromiseRejectBlock)reject { + /* + The OCR pipeline consists of two phases: + 1. Detection - detecting text regions in the image; the result of this phase + is a list of bounding boxes. + 2. Recognition - recognizing the text in the bounding boxes; the result is a + list of strings and corresponding confidence scores. + + Recognition uses three models, each responsible for recognizing text + of a different size (e.g. large - 512x64, medium - 256x64, small - 128x64). + */ + @try { + cv::Mat image = [ImageProcessor readImage:input]; + NSArray *result = [detector runModel:image]; + cv::Size detectorSize = [detector getModelImageSize]; + cv::cvtColor(image, image, cv::COLOR_BGR2GRAY); + result = [self->recognitionHandler + recognize:result + imgGray:image + desiredWidth:detectorSize.width * recognizerRatio + desiredHeight:detectorSize.height * recognizerRatio]; + resolve(result); + } @catch (NSException *exception) { + reject(@"forward_error", + [NSString stringWithFormat:@"%@", exception.reason], nil); + } +} + +- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule: + (const facebook::react::ObjCTurboModule::InitParams &)params { + return std::make_shared<facebook::react::NativeOCRSpecJSI>(params); +} + +@end diff --git a/ios/RnExecutorch/models/ocr/Detector.h b/ios/RnExecutorch/models/ocr/Detector.h new file mode 100644 index 0000000000..0f67e93b84 --- /dev/null +++ b/ios/RnExecutorch/models/ocr/Detector.h @@ -0,0 +1,25 @@ +#import "BaseModel.h" +#import "RecognitionHandler.h" +#import "opencv2/opencv.hpp" + +constexpr CGFloat textThreshold = 0.4; +constexpr CGFloat linkThreshold = 0.4; +constexpr CGFloat lowTextThreshold = 0.7; +constexpr CGFloat centerThreshold = 0.5; +constexpr CGFloat distanceThreshold = 2.0; +constexpr CGFloat heightThreshold = 2.0; +constexpr CGFloat restoreRatio = 3.2; +constexpr int minSideThreshold = 15; +constexpr int maxSideThreshold = 30; +constexpr int maxWidth = largeModelWidth + (largeModelWidth * 0.15); +constexpr int minSize = 20; + +const cv::Scalar mean(0.485, 0.456, 0.406); +const cv::Scalar variance(0.229, 0.224, 0.225); + +@interface Detector : BaseModel + +- (cv::Size)getModelImageSize; +- (NSArray *)runModel:(cv::Mat &)input; + +@end diff --git a/ios/RnExecutorch/models/ocr/Detector.mm b/ios/RnExecutorch/models/ocr/Detector.mm new file mode 100644 index 0000000000..20b82b5ee7 --- /dev/null +++ b/ios/RnExecutorch/models/ocr/Detector.mm @@ -0,0 +1,100 @@ +#import "Detector.h" +#import "../../utils/ImageProcessor.h"
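For reference, the size-based routing described in the forward comment above comes down to a single comparison chain; a minimal Kotlin sketch mirroring RecognitionHandler.runModel on the Android side (Constants.* are the values added in this diff; selectRecognizerWidth is a hypothetical helper name):

```kotlin
// Route a cropped text box to the recognizer variant whose input width it fills best.
// Mirrors the comparison chain in RecognitionHandler.runModel.
fun selectRecognizerWidth(cropWidth: Int): Int = when {
    cropWidth >= Constants.LARGE_MODEL_WIDTH -> Constants.LARGE_MODEL_WIDTH   // 512
    cropWidth >= Constants.MEDIUM_MODEL_WIDTH -> Constants.MEDIUM_MODEL_WIDTH // 256
    else -> Constants.SMALL_MODEL_WIDTH                                       // 128
}
```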
+#import "utils/DetectorUtils.h" +#import "utils/OCRUtils.h" + +/* + The model used as detector is based on CRAFT (Character Region Awareness for + Text Detection) paper. https://arxiv.org/pdf/1904.01941 + */ + +@implementation Detector { + cv::Size originalSize; + cv::Size modelSize; +} + +- (cv::Size)getModelImageSize { + if (!modelSize.empty()) { + return modelSize; + } + + NSArray *inputShape = [module getInputShape:@0]; + NSNumber *widthNumber = inputShape.lastObject; + NSNumber *heightNumber = inputShape[inputShape.count - 2]; + + const int height = [heightNumber intValue]; + const int width = [widthNumber intValue]; + modelSize = cv::Size(height, width); + + return cv::Size(height, width); +} + +- (NSArray *)preprocess:(cv::Mat &)input { + /* + Detector as an input accepts tensor with a shape of [1, 3, 800, 800]. + Due to big influence of resize to quality of recognition the image preserves + original aspect ratio and the missing parts are filled with padding. + */ + self->originalSize = cv::Size(input.cols, input.rows); + + cv::Size modelImageSize = [self getModelImageSize]; + cv::Mat resizedImage; + resizedImage = [OCRUtils resizeWithPadding:input + desiredWidth:modelImageSize.width + desiredHeight:modelImageSize.height]; + NSArray *modelInput = [ImageProcessor matToNSArray:resizedImage + mean:mean + variance:variance]; + return modelInput; +} + +- (NSArray *)postprocess:(NSArray *)output { + /* + The output of the model consists of two matrices (heat maps): + 1. ScoreText(Score map) - The probability of a region containing character + 2. ScoreAffinity(Affinity map) - affinity between characters, used to to + group each character into a single instance (sequence) Both matrices are + 400x400 + + The result of this step is a list of bounding boxes that contain text. + */ + NSArray *predictions = [output objectAtIndex:0]; + + cv::Size modelImageSize = [self getModelImageSize]; + cv::Mat scoreTextCV, scoreAffinityCV; + /* + The output of the model is a matrix in size of input image containing two + matrices representing heatmap. Those two matrices are in the size of half of + the input image, that's why the width and height is divided by 2. 
+  [DetectorUtils interleavedArrayToMats:predictions
+    outputMat1:scoreTextCV
+    outputMat2:scoreAffinityCV
+    withSize:cv::Size(modelImageSize.width / 2,
+                      modelImageSize.height / 2)];
+  NSArray *bBoxesList = [DetectorUtils getDetBoxesFromTextMap:scoreTextCV
+    affinityMap:scoreAffinityCV
+    usingTextThreshold:textThreshold
+    linkThreshold:linkThreshold
+    lowTextThreshold:lowTextThreshold];
+  bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList
+    usingRestoreRatio:restoreRatio];
+  bBoxesList = [DetectorUtils groupTextBoxes:bBoxesList
+    centerThreshold:centerThreshold
+    distanceThreshold:distanceThreshold
+    heightThreshold:heightThreshold
+    minSideThreshold:minSideThreshold
+    maxSideThreshold:maxSideThreshold
+    maxWidth:maxWidth];
+
+  return bBoxesList;
+}
+
+- (NSArray *)runModel:(cv::Mat &)input {
+  NSArray *modelInput = [self preprocess:input];
+  NSArray *modelResult = [self forward:modelInput];
+  NSArray *result = [self postprocess:modelResult];
+  return result;
+}
+
+@end
diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.h b/ios/RnExecutorch/models/ocr/RecognitionHandler.h
new file mode 100644
index 0000000000..412504370e
--- /dev/null
+++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.h
@@ -0,0 +1,22 @@
+#import "opencv2/opencv.hpp"
+
+constexpr int modelHeight = 64;
+constexpr int largeModelWidth = 512;
+constexpr int mediumModelWidth = 256;
+constexpr int smallModelWidth = 128;
+constexpr CGFloat lowConfidenceThreshold = 0.3;
+constexpr CGFloat adjustContrast = 0.2;
+
+@interface RecognitionHandler : NSObject
+
+- (instancetype)initWithSymbols:(NSString *)symbols;
+- (void)loadRecognizers:(NSString *)largeRecognizerPath
+  mediumRecognizerPath:(NSString *)mediumRecognizerPath
+  smallRecognizerPath:(NSString *)smallRecognizerPath
+  completion:(void (^)(BOOL, NSNumber *))completion;
+- (NSArray *)recognize:(NSArray *)bBoxesList
+  imgGray:(cv::Mat)imgGray
+  desiredWidth:(int)desiredWidth
+  desiredHeight:(int)desiredHeight;
+
+@end
diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm
new file mode 100644
index 0000000000..60616b9099
--- /dev/null
+++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm
@@ -0,0 +1,151 @@
+#import "RecognitionHandler.h"
+#import "../../utils/ImageProcessor.h"
+#import "./utils/CTCLabelConverter.h"
+#import "./utils/OCRUtils.h"
+#import "./utils/RecognizerUtils.h"
+#import "ExecutorchLib/ETModel.h"
+#import "Recognizer.h"
+#import <Foundation/Foundation.h>
+
+/*
+ The RecognitionHandler class is responsible for loading and choosing the
+ appropriate recognizer model based on the input image size; it also handles
+ converting the model output to text.
+ */ + +@implementation RecognitionHandler { + Recognizer *recognizerLarge; + Recognizer *recognizerMedium; + Recognizer *recognizerSmall; + CTCLabelConverter *converter; +} + +- (instancetype)initWithSymbols:(NSString *)symbols { + self = [super init]; + if (self) { + recognizerLarge = [[Recognizer alloc] init]; + recognizerMedium = [[Recognizer alloc] init]; + recognizerSmall = [[Recognizer alloc] init]; + + converter = [[CTCLabelConverter alloc] + initWithCharacters:symbols + separatorList:@{}]; + } + return self; +} + +- (void)loadRecognizers:(NSString *)largeRecognizerPath + mediumRecognizerPath:(NSString *)mediumRecognizerPath + smallRecognizerPath:(NSString *)smallRecognizerPath + completion:(void (^)(BOOL, NSNumber *))completion { + dispatch_group_t group = dispatch_group_create(); + __block BOOL allSuccessful = YES; + + NSArray *recognizers = + @[ recognizerLarge, recognizerMedium, recognizerSmall ]; + NSArray *paths = + @[ largeRecognizerPath, mediumRecognizerPath, smallRecognizerPath ]; + + for (NSInteger i = 0; i < recognizers.count; i++) { + Recognizer *recognizer = recognizers[i]; + NSString *path = paths[i]; + + dispatch_group_enter(group); + [recognizer loadModel:[NSURL URLWithString:path] + completion:^(BOOL success, NSNumber *errorCode) { + if (!success) { + allSuccessful = NO; + dispatch_group_leave(group); + completion(NO, errorCode); + return; + } + dispatch_group_leave(group); + }]; + } + + dispatch_group_notify(group, dispatch_get_main_queue(), ^{ + if (allSuccessful) { + completion(YES, @(0)); + } + }); +} + +- (NSArray *)runModel:(cv::Mat)croppedImage { + NSArray *result; + if (croppedImage.cols >= largeModelWidth) { + result = [recognizerLarge runModel:croppedImage]; + } else if (croppedImage.cols >= mediumModelWidth) { + result = [recognizerMedium runModel:croppedImage]; + } else { + result = [recognizerSmall runModel:croppedImage]; + } + + return result; +} + +- (NSArray *)recognize:(NSArray *)bBoxesList + imgGray:(cv::Mat)imgGray + desiredWidth:(int)desiredWidth + desiredHeight:(int)desiredHeight { + NSDictionary *ratioAndPadding = + [RecognizerUtils calculateResizeRatioAndPaddings:imgGray.cols + height:imgGray.rows + desiredWidth:desiredWidth + desiredHeight:desiredHeight]; + const int left = [ratioAndPadding[@"left"] intValue]; + const int top = [ratioAndPadding[@"top"] intValue]; + const CGFloat resizeRatio = [ratioAndPadding[@"resizeRatio"] floatValue]; + imgGray = [OCRUtils resizeWithPadding:imgGray + desiredWidth:desiredWidth + desiredHeight:desiredHeight]; + + NSMutableArray *predictions = [NSMutableArray array]; + for (NSDictionary *box in bBoxesList) { + cv::Mat croppedImage = [RecognizerUtils getCroppedImage:box + image:imgGray + modelHeight:modelHeight]; + if (croppedImage.empty()) { + continue; + } + croppedImage = [RecognizerUtils normalizeForRecognizer:croppedImage + adjustContrast:adjustContrast]; + NSArray *result = [self runModel:croppedImage]; + + NSNumber *confidenceScore = [result objectAtIndex:1]; + if ([confidenceScore floatValue] < lowConfidenceThreshold) { + cv::rotate(croppedImage, croppedImage, cv::ROTATE_180); + + NSArray *rotatedResult = [self runModel:croppedImage]; + NSNumber *rotatedConfidenceScore = [rotatedResult objectAtIndex:1]; + + if ([rotatedConfidenceScore floatValue] > [confidenceScore floatValue]) { + result = rotatedResult; + confidenceScore = rotatedConfidenceScore; + } + } + + NSArray *predIndex = [result objectAtIndex:0]; + NSArray *decodedTexts = [converter decodeGreedy:predIndex + length:(int)(predIndex.count)]; + 
+      NSMutableArray *bbox = [NSMutableArray arrayWithCapacity:4];
+      for (NSValue *coords in box[@"bbox"]) {
+        const CGPoint point = [coords CGPointValue];
+        [bbox addObject:@{
+          @"x" : @((point.x - left) * resizeRatio),
+          @"y" : @((point.y - top) * resizeRatio)
+        }];
+      }
+
+      NSDictionary *res = @{
+        @"text" : decodedTexts[0],
+        @"bbox" : bbox,
+        @"score" : confidenceScore
+      };
+      [predictions addObject:res];
+  }
+
+  return predictions;
+}
+
+@end
diff --git a/ios/RnExecutorch/models/ocr/Recognizer.h b/ios/RnExecutorch/models/ocr/Recognizer.h
new file mode 100644
index 0000000000..4b301dbef7
--- /dev/null
+++ b/ios/RnExecutorch/models/ocr/Recognizer.h
@@ -0,0 +1,8 @@
+#import "BaseModel.h"
+#import "opencv2/opencv.hpp"
+
+@interface Recognizer : BaseModel
+
+- (NSArray *)runModel:(cv::Mat &)input;
+
+@end
diff --git a/ios/RnExecutorch/models/ocr/Recognizer.mm b/ios/RnExecutorch/models/ocr/Recognizer.mm
new file mode 100644
index 0000000000..8b339bc238
--- /dev/null
+++ b/ios/RnExecutorch/models/ocr/Recognizer.mm
@@ -0,0 +1,78 @@
+#import "Recognizer.h"
+#import "../../utils/ImageProcessor.h"
+#import "RecognizerUtils.h"
+#import "utils/OCRUtils.h"
+
+/*
+ The model used as the recognizer is based on the CRNN paper.
+ https://arxiv.org/pdf/1507.05717
+ */
+
+@implementation Recognizer {
+  cv::Size originalSize;
+}
+
+- (cv::Size)getModelImageSize {
+  NSArray *inputShape = [module getInputShape:@0];
+  NSNumber *widthNumber = inputShape.lastObject;
+  NSNumber *heightNumber = inputShape[inputShape.count - 2];
+
+  const int height = [heightNumber intValue];
+  const int width = [widthNumber intValue];
+  return cv::Size(height, width);
+}
+
+- (cv::Size)getModelOutputSize {
+  NSArray *outputShape = [module getOutputShape:@0];
+  NSNumber *widthNumber = outputShape.lastObject;
+  NSNumber *heightNumber = outputShape[outputShape.count - 2];
+
+  const int height = [heightNumber intValue];
+  const int width = [widthNumber intValue];
+  return cv::Size(height, width);
+}
+
+- (NSArray *)preprocess:(cv::Mat &)input {
+  return [ImageProcessor matToNSArrayGray:input];
+}
+
+- (NSArray *)postprocess:(NSArray *)output {
+  const int modelOutputHeight = [self getModelOutputSize].height;
+  NSInteger numElements = [output.firstObject count];
+  NSInteger numRows =
+    (numElements + modelOutputHeight - 1) / modelOutputHeight;
+  cv::Mat resultMat = cv::Mat::zeros(numRows, modelOutputHeight, CV_32F);
+  NSInteger counter = 0;
+  NSInteger currentRow = 0;
+  for (NSNumber *num in output.firstObject) {
+    resultMat.at<float>(currentRow, counter) = [num floatValue];
+    counter++;
+    if (counter >= modelOutputHeight) {
+      counter = 0;
+      currentRow++;
+    }
+  }
+
+  cv::Mat probabilities = [RecognizerUtils softmax:resultMat];
+  NSMutableArray *predsNorm =
+    [RecognizerUtils sumProbabilityRows:probabilities
+      modelOutputHeight:modelOutputHeight];
+  probabilities = [RecognizerUtils divideMatrix:probabilities
+    byVector:predsNorm];
+  NSArray *maxValuesIndices =
+    [RecognizerUtils findMaxValuesAndIndices:probabilities];
+  const CGFloat confidenceScore =
+    [RecognizerUtils computeConfidenceScore:maxValuesIndices[0]
+      indicesArray:maxValuesIndices[1]];
+
+  return @[ maxValuesIndices[1], @(confidenceScore) ];
+}
+
+- (NSArray *)runModel:(cv::Mat &)input {
+  NSArray *modelInput = [self preprocess:input];
+  NSArray *modelResult = [self forward:modelInput];
+  NSArray *result = [self postprocess:modelResult];
+
+  return result;
+}
+
+@end
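The confidence score returned from postprocess above compresses the per-step maximum probabilities into a single number. A TypeScript sketch of the formula that RecognizerUtils computeConfidenceScore (added later in this diff) implements:

// Confidence = (product of max probabilities at non-blank steps) ^ (2 / sqrt(n)).
// Steps decoded as blank (class index 0) are skipped; if every step is blank,
// the score collapses to 0.
function confidenceScore(maxProbs: number[], classIndices: number[]): number {
  const nonBlank = maxProbs.filter((_, i) => classIndices[i] !== 0);
  const probs = nonBlank.length > 0 ? nonBlank : [0];
  const product = probs.reduce((acc, p) => acc * p, 1);
  return Math.pow(product, 2 / Math.sqrt(probs.length));
}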
diff --git a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h
new file mode 100644
index 0000000000..498710dd03
--- /dev/null
+++ b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h
@@ -0,0 +1,16 @@
+#import <Foundation/Foundation.h>
+
+@interface CTCLabelConverter : NSObject
+
+@property(strong, nonatomic) NSMutableDictionary *dict;
+@property(strong, nonatomic) NSArray *character;
+@property(strong, nonatomic) NSDictionary *separatorList;
+@property(strong, nonatomic) NSArray *ignoreIdx;
+@property(strong, nonatomic) NSDictionary *dictList;
+
+- (instancetype)initWithCharacters:(NSString *)characters
+  separatorList:(NSDictionary *)separatorList;
+- (NSArray *)decodeGreedy:(NSArray *)textIndex
+  length:(NSInteger)length;
+
+@end
diff --git a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm
new file mode 100644
index 0000000000..7d50e3813f
--- /dev/null
+++ b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm
@@ -0,0 +1,80 @@
+#import "CTCLabelConverter.h"
+
+@implementation CTCLabelConverter
+
+- (instancetype)initWithCharacters:(NSString *)characters
+  separatorList:(NSDictionary *)separatorList {
+  self = [super init];
+  if (self) {
+    _dict = [NSMutableDictionary dictionary];
+    NSMutableArray *mutableCharacters =
+      [NSMutableArray arrayWithObject:@"[blank]"];
+
+    for (NSUInteger i = 0; i < [characters length]; i++) {
+      NSString *charStr =
+        [NSString stringWithFormat:@"%C", [characters characterAtIndex:i]];
+      [mutableCharacters addObject:charStr];
+      self.dict[charStr] = @(i + 1);
+    }
+
+    _character = [mutableCharacters copy];
+    _separatorList = separatorList;
+
+    NSMutableArray *ignoreIndexes = [NSMutableArray arrayWithObject:@(0)];
+    for (NSString *sep in separatorList.allValues) {
+      NSUInteger index = [characters rangeOfString:sep].location;
+      if (index != NSNotFound) {
+        [ignoreIndexes addObject:@(index)];
+      }
+    }
+    _ignoreIdx = [ignoreIndexes copy];
+  }
+  return self;
+}
+
+- (NSArray *)decodeGreedy:(NSArray *)textIndex
+  length:(NSInteger)length {
+  NSMutableArray *texts = [NSMutableArray array];
+  NSUInteger index = 0;
+
+  while (index < textIndex.count) {
+    NSUInteger segmentLength = MIN(length, textIndex.count - index);
+    NSRange range = NSMakeRange(index, segmentLength);
+    NSArray *subArray = [textIndex subarrayWithRange:range];
+
+    NSMutableString *text = [NSMutableString string];
+    NSNumber *lastChar = nil;
+
+    NSMutableArray *isNotRepeated =
+      [NSMutableArray arrayWithObject:@YES];
+    NSMutableArray *isNotIgnored = [NSMutableArray array];
+
+    for (NSUInteger i = 0; i < subArray.count; i++) {
+      NSNumber *currentChar = subArray[i];
+      if (i > 0) {
+        [isNotRepeated addObject:@(![lastChar isEqualToNumber:currentChar])];
+      }
+      [isNotIgnored addObject:@(![self.ignoreIdx containsObject:currentChar])];
+
+      lastChar = currentChar;
+    }
+
+    for (NSUInteger j = 0; j < subArray.count; j++) {
+      if ([isNotRepeated[j] boolValue] && [isNotIgnored[j] boolValue]) {
+        NSUInteger charIndex = [subArray[j] unsignedIntegerValue];
+        [text appendString:self.character[charIndex]];
+      }
+    }
+
+    [texts addObject:text.copy];
+    index += segmentLength;
+
+    if (segmentLength < length) {
+      break;
+    }
+  }
+
+  return texts.copy;
+}
+
+@end
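A worked example of decodeGreedy's collapse rule, as a TypeScript sketch (the charset and indices are illustrative): with a charset of ['[blank]', 'a', 'b'], the index sequence [1, 1, 0, 2, 2] drops the repeated 1, drops the blank 0, and decodes to "ab".

// Greedy CTC collapse: keep an index only if it differs from its predecessor
// and is not in the ignore set (blank is index 0, as in CTCLabelConverter).
function ctcGreedyDecode(
  indices: number[],
  charset: string[],
  ignore: number[] = [0]
): string {
  let text = '';
  let last: number | null = null;
  for (const idx of indices) {
    if (idx !== last && !ignore.includes(idx)) {
      text += charset[idx];
    }
    last = idx;
  }
  return text;
}
// ctcGreedyDecode([1, 1, 0, 2, 2], ['', 'a', 'b']) === 'ab'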
diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h
new file mode 100644
index 0000000000..3f205b8ebd
--- /dev/null
+++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h
@@ -0,0 +1,26 @@
+#import <opencv2/opencv.hpp>
+
+constexpr int verticalLineThreshold = 20;
+
+@interface DetectorUtils : NSObject
+
++ (void)interleavedArrayToMats:(NSArray *)array
+  outputMat1:(cv::Mat &)mat1
+  outputMat2:(cv::Mat &)mat2
+  withSize:(cv::Size)size;
++ (NSArray *)getDetBoxesFromTextMap:(cv::Mat)textMap
+  affinityMap:(cv::Mat)affinityMap
+  usingTextThreshold:(CGFloat)textThreshold
+  linkThreshold:(CGFloat)linkThreshold
+  lowTextThreshold:(CGFloat)lowTextThreshold;
++ (NSArray *)restoreBboxRatio:(NSArray *)boxes
+  usingRestoreRatio:(CGFloat)restoreRatio;
++ (NSArray *)groupTextBoxes:(NSArray *)polys
+  centerThreshold:(CGFloat)centerThreshold
+  distanceThreshold:(CGFloat)distanceThreshold
+  heightThreshold:(CGFloat)heightThreshold
+  minSideThreshold:(int)minSideThreshold
+  maxSideThreshold:(int)maxSideThreshold
+  maxWidth:(int)maxWidth;
+
+@end
diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm
new file mode 100644
index 0000000000..8ee7424d00
--- /dev/null
+++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm
@@ -0,0 +1,653 @@
+#import "DetectorUtils.h"
+
+@implementation DetectorUtils
+
++ (void)interleavedArrayToMats:(NSArray *)array
+  outputMat1:(cv::Mat &)mat1
+  outputMat2:(cv::Mat &)mat2
+  withSize:(cv::Size)size {
+  mat1 = cv::Mat(size.height, size.width, CV_32F);
+  mat2 = cv::Mat(size.height, size.width, CV_32F);
+
+  for (NSUInteger idx = 0; idx < array.count; idx++) {
+    const CGFloat value = [array[idx] doubleValue];
+    const int x = (idx / 2) % size.width;
+    const int y = (idx / 2) / size.width;
+
+    if (idx % 2 == 0) {
+      mat1.at<float>(y, x) = value;
+    } else {
+      mat2.at<float>(y, x) = value;
+    }
+  }
+}
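Before the connected-component analysis, the method below reduces the two maps to one binary segmentation. The per-pixel arithmetic, as a TypeScript sketch (thresholds default to the textThreshold/linkThreshold values from Detector.h):

// Each map is binarized, the results are summed, and the sum is clipped
// back to {0, 1}, mirroring the cv::threshold calls in getDetBoxesFromTextMap.
function segmentPixel(
  textScore: number,
  affinityScore: number,
  textThreshold = 0.4,
  linkThreshold = 0.4
): 0 | 1 {
  const text = textScore > textThreshold ? 1 : 0;
  const link = affinityScore > linkThreshold ? 1 : 0;
  return text + link > 0 ? 1 : 0;
}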
+/**
+ * This method applies a series of image processing operations to identify
+ * likely areas of text in the textMap and return the bounding boxes for
+ * single words.
+ *
+ * @param textMap A cv::Mat heat map of the probability that a character is
+ * present at each location in the image.
+ * @param affinityMap A cv::Mat representing a heat map of the affinity
+ * between characters.
+ * @param textThreshold A CGFloat representing the threshold for the text map.
+ * @param linkThreshold A CGFloat representing the threshold for the affinity
+ * map.
+ * @param lowTextThreshold A CGFloat representing the low-text threshold;
+ * regions whose peak text score falls below it are discarded.
+ *
+ * @return An NSArray containing NSDictionary objects. Each dictionary
+ * includes:
+ * - "bbox": an NSArray of CGPoint values representing the vertices of the
+ * detected text box.
+ * - "angle": an NSNumber representing the rotation angle of the box.
+ */
++ (NSArray *)getDetBoxesFromTextMap:(cv::Mat)textMap
+  affinityMap:(cv::Mat)affinityMap
+  usingTextThreshold:(CGFloat)textThreshold
+  linkThreshold:(CGFloat)linkThreshold
+  lowTextThreshold:(CGFloat)lowTextThreshold {
+  const int imgH = textMap.rows;
+  const int imgW = textMap.cols;
+  cv::Mat textScore;
+  cv::Mat affinityScore;
+  cv::threshold(textMap, textScore, textThreshold, 1, cv::THRESH_BINARY);
+  cv::threshold(affinityMap, affinityScore, linkThreshold, 1,
+    cv::THRESH_BINARY);
+  cv::Mat textScoreComb = textScore + affinityScore;
+  cv::threshold(textScoreComb, textScoreComb, 0, 1, cv::THRESH_BINARY);
+  cv::Mat binaryMat;
+  textScoreComb.convertTo(binaryMat, CV_8UC1);
+
+  cv::Mat labels, stats, centroids;
+  const int nLabels =
+    cv::connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4);
+
+  NSMutableArray *detectedBoxes = [NSMutableArray array];
+  for (int i = 1; i < nLabels; i++) {
+    const int area = stats.at<int>(i, cv::CC_STAT_AREA);
+    if (area < 10)
+      continue;
+
+    cv::Mat mask = (labels == i);
+    CGFloat maxVal;
+    cv::minMaxLoc(textMap, NULL, &maxVal, NULL, NULL, mask);
+    if (maxVal < lowTextThreshold)
+      continue;
+
+    cv::Mat segMap = cv::Mat::zeros(textMap.size(), CV_8U);
+    segMap.setTo(255, mask);
+
+    const int x = stats.at<int>(i, cv::CC_STAT_LEFT);
+    const int y = stats.at<int>(i, cv::CC_STAT_TOP);
+    const int w = stats.at<int>(i, cv::CC_STAT_WIDTH);
+    const int h = stats.at<int>(i, cv::CC_STAT_HEIGHT);
+    const int dilationRadius = (int)(sqrt((double)(area / MAX(w, h))) * 2.0);
+    const int sx = MAX(x - dilationRadius, 0);
+    const int ex = MIN(x + w + dilationRadius + 1, imgW);
+    const int sy = MAX(y - dilationRadius, 0);
+    const int ey = MIN(y + h + dilationRadius + 1, imgH);
+
+    cv::Rect roi(sx, sy, ex - sx, ey - sy);
+    cv::Mat kernel = cv::getStructuringElement(
+      cv::MORPH_RECT, cv::Size(1 + dilationRadius, 1 + dilationRadius));
+    cv::Mat roiSegMap = segMap(roi);
+    cv::dilate(roiSegMap, roiSegMap, kernel);
+
+    std::vector<std::vector<cv::Point>> contours;
+    cv::findContours(segMap, contours, cv::RETR_EXTERNAL,
+      cv::CHAIN_APPROX_SIMPLE);
+    if (!contours.empty()) {
+      cv::RotatedRect minRect = cv::minAreaRect(contours[0]);
+      cv::Point2f vertices[4];
+      minRect.points(vertices);
+      NSMutableArray *pointsArray = [NSMutableArray arrayWithCapacity:4];
+      for (int j = 0; j < 4; j++) {
+        const CGPoint point = CGPointMake(vertices[j].x, vertices[j].y);
+        [pointsArray addObject:[NSValue valueWithCGPoint:point]];
+      }
+      NSDictionary *dict =
+        @{@"bbox" : pointsArray, @"angle" : @(minRect.angle)};
+      [detectedBoxes addObject:dict];
+    }
+  }
+
+  return detectedBoxes;
+}
+
++ (NSArray *)restoreBboxRatio:(NSArray *)boxes
+  usingRestoreRatio:(CGFloat)restoreRatio {
+  NSMutableArray *result = [NSMutableArray array];
+  for (NSUInteger i = 0; i < [boxes count]; i++) {
+    NSDictionary *box = boxes[i];
+    NSMutableArray *boxArray = [NSMutableArray arrayWithCapacity:4];
+    for (NSValue *value in box[@"bbox"]) {
+      CGPoint point = [value CGPointValue];
+      point.x *= restoreRatio;
+      point.y *= restoreRatio;
+      [boxArray addObject:[NSValue valueWithCGPoint:point]];
+    }
+    NSDictionary *dict = @{@"bbox" : boxArray, @"angle" : box[@"angle"]};
+    [result addObject:dict];
+  }
+
+  return result;
+}
+
+/**
+ * This method normalizes the angle returned by the cv::minAreaRect function,
+ * which ranges from 0 to 90 degrees.
+ **/ ++ (CGFloat)normalizeAngle:(CGFloat)angle { + if (angle > 45) { + return angle - 90; + } + return angle; +} + ++ (CGPoint)midpointBetweenPoint:(CGPoint)p1 andPoint:(CGPoint)p2 { + return CGPointMake((p1.x + p2.x) / 2, (p1.y + p2.y) / 2); +} + ++ (CGFloat)distanceFromPoint:(CGPoint)p1 toPoint:(CGPoint)p2 { + const CGFloat xDist = (p2.x - p1.x); + const CGFloat yDist = (p2.y - p1.y); + return sqrt(xDist * xDist + yDist * yDist); +} + ++ (CGPoint)centerOfBox:(NSArray *)box { + return [self midpointBetweenPoint:[box[0] CGPointValue] + andPoint:[box[2] CGPointValue]]; +} + ++ (CGFloat)maxSideLength:(NSArray *)points { + CGFloat maxSideLength = 0; + NSInteger numOfPoints = points.count; + for (NSInteger i = 0; i < numOfPoints; i++) { + const CGPoint currentPoint = [points[i] CGPointValue]; + const CGPoint nextPoint = [points[(i + 1) % numOfPoints] CGPointValue]; + + const CGFloat sideLength = [self distanceFromPoint:currentPoint + toPoint:nextPoint]; + if (sideLength > maxSideLength) { + maxSideLength = sideLength; + } + } + return maxSideLength; +} + ++ (CGFloat)minSideLength:(NSArray *)points { + CGFloat minSideLength = CGFLOAT_MAX; + NSInteger numOfPoints = points.count; + + for (NSInteger i = 0; i < numOfPoints; i++) { + const CGPoint currentPoint = [points[i] CGPointValue]; + const CGPoint nextPoint = [points[(i + 1) % numOfPoints] CGPointValue]; + + const CGFloat sideLength = [self distanceFromPoint:currentPoint + toPoint:nextPoint]; + if (sideLength < minSideLength) { + minSideLength = sideLength; + } + } + + return minSideLength; +} + ++ (CGFloat)calculateMinimalDistanceBetweenBox:(NSArray *)box1 + andBox:(NSArray *)box2 { + CGFloat minDistance = CGFLOAT_MAX; + for (NSValue *value1 in box1) { + const CGPoint corner1 = [value1 CGPointValue]; + for (NSValue *value2 in box2) { + const CGPoint corner2 = [value2 CGPointValue]; + const CGFloat distance = [self distanceFromPoint:corner1 toPoint:corner2]; + if (distance < minDistance) { + minDistance = distance; + } + } + } + return minDistance; +} + ++ (NSArray *)rotateBox:(NSArray *)box + withAngle:(CGFloat)angle { + const CGPoint center = [self centerOfBox:box]; + + const CGFloat radians = angle * M_PI / 180.0; + + NSMutableArray *rotatedPoints = + [NSMutableArray arrayWithCapacity:4]; + for (NSValue *value in box) { + const CGPoint point = [value CGPointValue]; + + const CGFloat translatedX = point.x - center.x; + const CGFloat translatedY = point.y - center.y; + + const CGFloat rotatedX = + translatedX * cos(radians) - translatedY * sin(radians); + const CGFloat rotatedY = + translatedX * sin(radians) + translatedY * cos(radians); + + const CGPoint rotatedPoint = + CGPointMake(rotatedX + center.x, rotatedY + center.y); + [rotatedPoints addObject:[NSValue valueWithCGPoint:rotatedPoint]]; + } + + return rotatedPoints; +} + +/** + * Orders a set of points in a clockwise direction starting with the top-left + * point. + * + * Process: + * 1. It iterates through each CGPoint extracted from the NSValues. + * 2. For each point, it calculates the sum (x + y) and difference (y - x) of + * the coordinates. + * 3. Points are classified into: + * - Top-left: Minimum sum. + * - Bottom-right: Maximum sum. + * - Top-right: Minimum difference. + * - Bottom-left: Maximum difference. + * 4. The points are ordered starting from the top-left in a clockwise manner: + * top-left, top-right, bottom-right, bottom-left. 
+ */
++ (NSArray *)orderPointsClockwise:(NSArray *)points {
+  CGPoint topLeft, topRight, bottomRight, bottomLeft;
+  CGFloat minSum = FLT_MAX;
+  CGFloat maxSum = -FLT_MAX;
+  CGFloat minDiff = FLT_MAX;
+  CGFloat maxDiff = -FLT_MAX;
+
+  for (NSValue *value in points) {
+    const CGPoint pt = [value CGPointValue];
+    const CGFloat sum = pt.x + pt.y;
+    const CGFloat diff = pt.y - pt.x;
+
+    if (sum < minSum) {
+      minSum = sum;
+      topLeft = pt;
+    }
+    if (sum > maxSum) {
+      maxSum = sum;
+      bottomRight = pt;
+    }
+    if (diff < minDiff) {
+      minDiff = diff;
+      topRight = pt;
+    }
+    if (diff > maxDiff) {
+      maxDiff = diff;
+      bottomLeft = pt;
+    }
+  }
+
+  NSArray *rect = @[
+    [NSValue valueWithCGPoint:topLeft], [NSValue valueWithCGPoint:topRight],
+    [NSValue valueWithCGPoint:bottomRight],
+    [NSValue valueWithCGPoint:bottomLeft]
+  ];
+
+  return rect;
+}
+
++ (std::vector<cv::Point2f>)pointsFromNSValues:(NSArray *)nsValues {
+  std::vector<cv::Point2f> points;
+  for (NSValue *value in nsValues) {
+    const CGPoint point = [value CGPointValue];
+    points.emplace_back(point.x, point.y);
+  }
+  return points;
+}
+
++ (NSArray *)nsValuesFromPoints:(cv::Point2f *)points
+  count:(int)count {
+  NSMutableArray *nsValues =
+    [[NSMutableArray alloc] initWithCapacity:count];
+  for (int i = 0; i < count; i++) {
+    [nsValues addObject:[NSValue valueWithCGPoint:CGPointMake(points[i].x,
+      points[i].y)]];
+  }
+  return nsValues;
+}
+
++ (NSArray *)mergeRotatedBoxes:(NSArray *)box1
+  withBox:(NSArray *)box2 {
+  box1 = [self orderPointsClockwise:box1];
+  box2 = [self orderPointsClockwise:box2];
+
+  std::vector<cv::Point2f> points1 = [self pointsFromNSValues:box1];
+  std::vector<cv::Point2f> points2 = [self pointsFromNSValues:box2];
+
+  std::vector<cv::Point2f> allPoints;
+  allPoints.insert(allPoints.end(), points1.begin(), points1.end());
+  allPoints.insert(allPoints.end(), points2.begin(), points2.end());
+
+  std::vector<int> hullIndices;
+  cv::convexHull(allPoints, hullIndices, false);
+
+  std::vector<cv::Point2f> hullPoints;
+  for (int idx : hullIndices) {
+    hullPoints.push_back(allPoints[idx]);
+  }
+
+  cv::RotatedRect minAreaRect = cv::minAreaRect(hullPoints);
+
+  cv::Point2f rectPoints[4];
+  minAreaRect.points(rectPoints);
+
+  return [self nsValuesFromPoints:rectPoints count:4];
+}
+
++ (NSMutableArray *)
+  removeSmallBoxesFromArray:(NSArray *)boxes
+  usingMinSideThreshold:(CGFloat)minSideThreshold
+  maxSideThreshold:(CGFloat)maxSideThreshold {
+  NSMutableArray *filteredBoxes = [NSMutableArray array];
+
+  for (NSDictionary *box in boxes) {
+    const CGFloat maxSideLength = [self maxSideLength:box[@"bbox"]];
+    const CGFloat minSideLength = [self minSideLength:box[@"bbox"]];
+    if (minSideLength > minSideThreshold && maxSideLength > maxSideThreshold) {
+      [filteredBoxes addObject:box];
+    }
+  }
+
+  return filteredBoxes;
+}
+
++ (CGFloat)minimumYFromBox:(NSArray *)box {
+  __block CGFloat minY = CGFLOAT_MAX;
+  [box enumerateObjectsUsingBlock:^(NSValue *_Nonnull obj, NSUInteger idx,
+    BOOL *_Nonnull stop) {
+    const CGPoint pt = [obj CGPointValue];
+    if (pt.y < minY) {
+      minY = pt.y;
+    }
+  }];
+  return minY;
+}
+
+/**
+ * This method calculates the distances between each sequential pair of points
+ * in a presumed quadrilateral, identifies the two shortest sides, and fits a
+ * linear model to the midpoints of these sides. It also evaluates whether the
+ * resulting line should be considered vertical based on a predefined
+ * threshold for the x-coordinate differences.
+ *
+ * If the line is vertical it is fitted as a function of x = my + c, otherwise
+ * as y = mx + c.
+ *
+ * @return An NSDictionary containing:
+ * - "slope": NSNumber representing the slope (m) of the line.
+ * - "intercept": NSNumber representing the line's intercept (c) with the
+ * y-axis.
+ * - "isVertical": NSNumber (boolean) indicating whether the line is
+ * considered vertical.
+ */
++ (NSDictionary *)fitLineToShortestSides:(NSArray *)points {
+  NSMutableArray *sides = [NSMutableArray array];
+  NSMutableArray *midpoints = [NSMutableArray array];
+
+  for (int i = 0; i < 4; i++) {
+    const CGPoint p1 = [points[i] CGPointValue];
+    const CGPoint p2 = [points[(i + 1) % 4] CGPointValue];
+
+    const CGFloat sideLength = [self distanceFromPoint:p1 toPoint:p2];
+    [sides addObject:@{@"length" : @(sideLength), @"index" : @(i)}];
+    [midpoints
+      addObject:[NSValue valueWithCGPoint:[self midpointBetweenPoint:p1
+        andPoint:p2]]];
+  }
+
+  [sides
+    sortUsingDescriptors:@[ [NSSortDescriptor sortDescriptorWithKey:@"length"
+      ascending:YES] ]];
+
+  const CGPoint midpoint1 =
+    [midpoints[[sides[0][@"index"] intValue]] CGPointValue];
+  const CGPoint midpoint2 =
+    [midpoints[[sides[1][@"index"] intValue]] CGPointValue];
+  const CGFloat dx = fabs(midpoint2.x - midpoint1.x);
+
+  CGFloat m, c;
+  BOOL isVertical;
+
+  std::vector<cv::Point2f> cvMidPoints = {
+    cv::Point2f(midpoint1.x, midpoint1.y),
+    cv::Point2f(midpoint2.x, midpoint2.y)};
+  cv::Vec4f line;
+
+  if (dx < verticalLineThreshold) {
+    for (auto &pt : cvMidPoints)
+      std::swap(pt.x, pt.y);
+    cv::fitLine(cvMidPoints, line, cv::DIST_L2, 0, 0.01, 0.01);
+    m = line[1] / line[0];
+    c = line[3] - m * line[2];
+    isVertical = YES;
+  } else {
+    cv::fitLine(cvMidPoints, line, cv::DIST_L2, 0, 0.01, 0.01);
+    m = line[1] / line[0];
+    c = line[3] - m * line[2];
+    isVertical = NO;
+  }
+
+  return @{@"slope" : @(m), @"intercept" : @(c), @"isVertical" : @(isVertical)};
+}
+
+/**
+ * This method assesses each box from a provided array, checks its center
+ * against the center of a "current box", and evaluates its alignment with a
+ * specified line equation. The function specifically searches for the box
+ * whose center is closest to the current box, that has not been ignored, and
+ * fits within a defined distance from the line.
+ *
+ * @param boxes An NSArray of NSDictionary objects where each dictionary
+ * represents a box with keys "bbox" and "angle". "bbox" is an NSArray of
+ * NSValue objects each encapsulating a CGPoint that defines the box vertices.
+ * "angle" is an NSNumber representing the box's rotation angle.
+ * @param ignoredIdxs An NSSet of NSNumber objects representing indices of
+ * boxes to ignore in the evaluation.
+ * @param currentBox An NSArray of NSValue objects encapsulating CGPoints
+ * representing the current box to compare against.
+ * @param isVertical A BOOL indicating if the line to compare distance to is
+ * vertical.
+ * @param m The slope (gradient) of the line against which the box's
+ * alignment is checked.
+ * @param c The y-intercept of the line equation y = mx + c.
+ * @param centerThreshold A multiplier to determine the threshold for the
+ * distance between the box's center and the line.
+ *
+ * @return An NSDictionary containing:
+ * - "idx" : NSNumber indicating the index of the found box in the
+ * original NSArray.
+ * - "boxHeight" : NSNumber representing the shortest side length of the
+ * found box. Returns nil if no suitable box is found.
+ */ ++ (NSDictionary *)findClosestBox:(NSArray *)boxes + ignoredIdxs:(NSSet *)ignoredIdxs + currentBox:(NSArray *)currentBox + isVertical:(BOOL)isVertical + m:(CGFloat)m + c:(CGFloat)c + centerThreshold:(CGFloat)centerThreshold { + CGFloat smallestDistance = CGFLOAT_MAX; + NSInteger idx = -1; + CGFloat boxHeight = 0; + const CGPoint centerOfCurrentBox = [self centerOfBox:currentBox]; + + for (NSUInteger i = 0; i < boxes.count; i++) { + if ([ignoredIdxs containsObject:@(i)]) { + continue; + } + NSArray *bbox = boxes[i][@"bbox"]; + const CGPoint centerOfProcessedBox = [self centerOfBox:bbox]; + const CGFloat distanceBetweenCenters = + [self distanceFromPoint:centerOfCurrentBox + toPoint:centerOfProcessedBox]; + + if (distanceBetweenCenters >= smallestDistance) { + continue; + } + + boxHeight = [self minSideLength:bbox]; + + const CGFloat lineDistance = + (isVertical + ? fabs(centerOfProcessedBox.x - (m * centerOfProcessedBox.y + c)) + : fabs(centerOfProcessedBox.y - (m * centerOfProcessedBox.x + c))); + + if (lineDistance < boxHeight * centerThreshold) { + idx = i; + smallestDistance = distanceBetweenCenters; + } + } + + return idx != -1 ? @{@"idx" : @(idx), @"boxHeight" : @(boxHeight)} : nil; +} + +/** + * This method processes an array of text box dictionaries, each containing + * details about individual text boxes, and attempts to group and merge these + * boxes based on specified criteria including proximity, alignment, and size + * thresholds. It prioritizes merging of boxes that are aligned closely in + * angle, are near each other, and whose sizes are compatible based on the given + * thresholds. + * + * @param boxes An array of NSDictionary objects where each dictionary + * represents a text box. Each dictionary must have at least a "bbox" key with + * an NSArray of NSValue wrapping CGPoints defining the box vertices, and an + * "angle" key indicating the orientation of the box. + * @param centerThreshold A CGFloat representing the threshold for considering + * the distance between center and fitted line. + * @param distanceThreshold A CGFloat that defines the maximum allowed distance + * between boxes for them to be considered for merging. + * @param heightThreshold A CGFloat representing the maximum allowed difference + * in height between boxes for merging. + * @param minSideThreshold An int that defines the minimum dimension threshold + * to filter out small boxes after grouping. + * @param maxSideThreshold An int that specifies the maximum dimension threshold + * for filtering boxes post-grouping. + * @param maxWidth An int that represents the maximum width allowable for a + * merged box. + * + * @return An NSArray of NSDictionary objects representing the merged boxes. + * Each dictionary contains: + * - "bbox": An NSArray of NSValue each containing a CGPoint that + * defines the vertices of the merged box. + * - "angle": NSNumber representing the computed orientation of the + * merged box. + * + * Processing Steps: + * 1. Sort initial boxes based on their maximum side length. + * 2. Sequentially merge boxes considering alignment, proximity, and size + * compatibility. + * 3. Post-processing to remove any boxes that are too small or exceed max side + * criteria. + * 4. Sort the final array of boxes by their vertical positions. 
+ */ ++ (NSArray *)groupTextBoxes: + (NSMutableArray *)boxes + centerThreshold:(CGFloat)centerThreshold + distanceThreshold:(CGFloat)distanceThreshold + heightThreshold:(CGFloat)heightThreshold + minSideThreshold:(int)minSideThreshold + maxSideThreshold:(int)maxSideThreshold + maxWidth:(int)maxWidth { + // Sort boxes based on their maximum side length + boxes = [boxes sortedArrayUsingComparator:^NSComparisonResult( + NSDictionary *obj1, NSDictionary *obj2) { + const CGFloat maxLen1 = [self maxSideLength:obj1[@"bbox"]]; + const CGFloat maxLen2 = [self maxSideLength:obj2[@"bbox"]]; + return (maxLen1 < maxLen2) ? NSOrderedDescending + : (maxLen1 > maxLen2) ? NSOrderedAscending + : NSOrderedSame; + }].mutableCopy; + + NSMutableArray *mergedArray = [NSMutableArray array]; + CGFloat lineAngle; + while (boxes.count > 0) { + NSMutableDictionary *currentBox = [boxes[0] mutableCopy]; + CGFloat normalizedAngle = + [self normalizeAngle:[currentBox[@"angle"] floatValue]]; + [boxes removeObjectAtIndex:0]; + NSMutableArray *ignoredIdxs = [NSMutableArray array]; + + while (YES) { + // Find all aligned boxes and merge them until max_size is reached or no + // more boxes can be merged + NSDictionary *fittedLine = + [self fitLineToShortestSides:currentBox[@"bbox"]]; + const CGFloat slope = [fittedLine[@"slope"] floatValue]; + const CGFloat intercept = [fittedLine[@"intercept"] floatValue]; + const BOOL isVertical = [fittedLine[@"isVertical"] boolValue]; + + lineAngle = atan(slope) * 180 / M_PI; + if (isVertical) { + lineAngle = -90; + } + + NSDictionary *closestBoxInfo = + [self findClosestBox:boxes + ignoredIdxs:[NSSet setWithArray:ignoredIdxs] + currentBox:currentBox[@"bbox"] + isVertical:isVertical + m:slope + c:intercept + centerThreshold:centerThreshold]; + if (closestBoxInfo == nil) + break; + + NSInteger candidateIdx = [closestBoxInfo[@"idx"] integerValue]; + NSMutableDictionary *candidateBox = [boxes[candidateIdx] mutableCopy]; + const CGFloat candidateHeight = [closestBoxInfo[@"boxHeight"] floatValue]; + + if (([candidateBox[@"angle"] isEqual:@90] && !isVertical) || + ([candidateBox[@"angle"] isEqual:@0] && isVertical)) { + candidateBox[@"bbox"] = [self rotateBox:candidateBox[@"bbox"] + withAngle:normalizedAngle]; + } + + const CGFloat minDistance = + [self calculateMinimalDistanceBetweenBox:candidateBox[@"bbox"] + andBox:currentBox[@"bbox"]]; + const CGFloat mergedHeight = [self minSideLength:currentBox[@"bbox"]]; + if (minDistance < distanceThreshold * candidateHeight && + fabs(mergedHeight - candidateHeight) < + candidateHeight * heightThreshold) { + currentBox[@"bbox"] = [self mergeRotatedBoxes:currentBox[@"bbox"] + withBox:candidateBox[@"bbox"]]; + [boxes removeObjectAtIndex:candidateIdx]; + [ignoredIdxs removeAllObjects]; + if ([self maxSideLength:currentBox[@"bbox"]] > maxWidth) { + break; + } + } else { + [ignoredIdxs addObject:@(candidateIdx)]; + } + } + + [mergedArray + addObject:@{@"bbox" : currentBox[@"bbox"], @"angle" : @(lineAngle)}]; + } + + // Remove small boxes and sort by vertical + mergedArray = [self removeSmallBoxesFromArray:mergedArray + usingMinSideThreshold:minSideThreshold + maxSideThreshold:maxSideThreshold]; + + NSArray *sortedBoxes = [mergedArray + sortedArrayUsingComparator:^NSComparisonResult(NSDictionary *obj1, + NSDictionary *obj2) { + NSArray *coords1 = obj1[@"bbox"]; + NSArray *coords2 = obj2[@"bbox"]; + const CGFloat minY1 = [self minimumYFromBox:coords1]; + const CGFloat minY2 = [self minimumYFromBox:coords2]; + return (minY1 < minY2) ? 
NSOrderedAscending
+      : (minY1 > minY2) ? NSOrderedDescending
+      : NSOrderedSame;
+  }];
+
+  return sortedBoxes;
+}
+
+@end
diff --git a/ios/RnExecutorch/models/ocr/utils/OCRUtils.h b/ios/RnExecutorch/models/ocr/utils/OCRUtils.h
new file mode 100644
index 0000000000..dca8b9bba5
--- /dev/null
+++ b/ios/RnExecutorch/models/ocr/utils/OCRUtils.h
@@ -0,0 +1,9 @@
+#import <opencv2/opencv.hpp>
+
+@interface OCRUtils : NSObject
+
++ (cv::Mat)resizeWithPadding:(cv::Mat)img
+  desiredWidth:(int)desiredWidth
+  desiredHeight:(int)desiredHeight;
+
+@end
diff --git a/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm
new file mode 100644
index 0000000000..f530dac2da
--- /dev/null
+++ b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm
@@ -0,0 +1,55 @@
+#import "OCRUtils.h"
+
+@implementation OCRUtils
+
++ (cv::Mat)resizeWithPadding:(cv::Mat)img
+  desiredWidth:(int)desiredWidth
+  desiredHeight:(int)desiredHeight {
+  const int height = img.rows;
+  const int width = img.cols;
+  const float heightRatio = (float)desiredHeight / height;
+  const float widthRatio = (float)desiredWidth / width;
+  const float resizeRatio = MIN(heightRatio, widthRatio);
+
+  const int newWidth = width * resizeRatio;
+  const int newHeight = height * resizeRatio;
+
+  cv::Mat resizedImg;
+  cv::resize(img, resizedImg, cv::Size(newWidth, newHeight), 0, 0,
+    cv::INTER_AREA);
+
+  const int cornerPatchSize = MAX(1, MIN(height, width) / 30);
+  std::vector<cv::Mat> corners = {
+    img(cv::Rect(0, 0, cornerPatchSize, cornerPatchSize)),
+    img(cv::Rect(width - cornerPatchSize, 0, cornerPatchSize,
+      cornerPatchSize)),
+    img(cv::Rect(0, height - cornerPatchSize, cornerPatchSize,
+      cornerPatchSize)),
+    img(cv::Rect(width - cornerPatchSize, height - cornerPatchSize,
+      cornerPatchSize, cornerPatchSize))};
+
+  cv::Scalar backgroundScalar = cv::mean(corners[0]);
+  for (int i = 1; i < corners.size(); i++) {
+    backgroundScalar += cv::mean(corners[i]);
+  }
+  backgroundScalar /= (double)corners.size();
+
+  backgroundScalar[0] = cvFloor(backgroundScalar[0]);
+  backgroundScalar[1] = cvFloor(backgroundScalar[1]);
+  backgroundScalar[2] = cvFloor(backgroundScalar[2]);
+
+  const int deltaW = desiredWidth - newWidth;
+  const int deltaH = desiredHeight - newHeight;
+  const int top = deltaH / 2;
+  const int bottom = deltaH - top;
+  const int left = deltaW / 2;
+  const int right = deltaW - left;
+
+  cv::Mat centeredImg;
+  cv::copyMakeBorder(resizedImg, centeredImg, top, bottom, left, right,
+    cv::BORDER_CONSTANT, backgroundScalar);
+
+  return centeredImg;
+}
+
+@end
diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h
new file mode 100644
index 0000000000..7af748f58c
--- /dev/null
+++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h
@@ -0,0 +1,28 @@
+#import <opencv2/opencv.hpp>
+
+@interface RecognizerUtils : NSObject
+
++ (CGFloat)calculateRatio:(int)width height:(int)height;
++ (cv::Mat)computeRatioAndResize:(cv::Mat)img
+  width:(int)width
+  height:(int)height
+  modelHeight:(int)modelHeight;
++ (cv::Mat)normalizeForRecognizer:(cv::Mat)image
+  adjustContrast:(double)adjustContrast;
++ (cv::Mat)adjustContrastGrey:(cv::Mat)img target:(double)target;
++ (cv::Mat)divideMatrix:(cv::Mat)matrix byVector:(NSArray *)vector;
++ (cv::Mat)softmax:(cv::Mat)inputs;
++ (NSDictionary *)calculateResizeRatioAndPaddings:(int)width
+  height:(int)height
+  desiredWidth:(int)desiredWidth
+  desiredHeight:(int)desiredHeight;
++ (cv::Mat)getCroppedImage:(NSDictionary *)box
+  image:(cv::Mat)image
+  modelHeight:(int)modelHeight;
++ (NSMutableArray *)sumProbabilityRows:(cv::Mat)probabilities
+  modelOutputHeight:(int)modelOutputHeight;
++ (NSArray *)findMaxValuesAndIndices:(cv::Mat)probabilities;
++ (double)computeConfidenceScore:(NSArray *)valuesArray
+  indicesArray:(NSArray *)indicesArray;
+
+@end
diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm
new file mode 100644
index 0000000000..65c088b361
--- /dev/null
+++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm
@@ -0,0 +1,223 @@
+#import "RecognizerUtils.h"
+#import "OCRUtils.h"
+
+@implementation RecognizerUtils
+
++ (CGFloat)calculateRatio:(int)width height:(int)height {
+  CGFloat ratio = (CGFloat)width / (CGFloat)height;
+  if (ratio < 1.0) {
+    ratio = 1.0 / ratio;
+  }
+  return ratio;
+}
+
++ (cv::Mat)computeRatioAndResize:(cv::Mat)img
+  width:(int)width
+  height:(int)height
+  modelHeight:(int)modelHeight {
+  CGFloat ratio = (CGFloat)width / (CGFloat)height;
+  if (ratio < 1.0) {
+    ratio = [self calculateRatio:width height:height];
+    cv::resize(img, img, cv::Size(modelHeight, (int)(modelHeight * ratio)), 0,
+      0, cv::INTER_LANCZOS4);
+  } else {
+    cv::resize(img, img, cv::Size((int)(modelHeight * ratio), modelHeight), 0,
+      0, cv::INTER_LANCZOS4);
+  }
+  return img;
+}
+
++ (cv::Mat)adjustContrastGrey:(cv::Mat)img target:(double)target {
+  double contrast = 0.0;
+  int high = 0;
+  int low = 255;
+
+  for (int i = 0; i < img.rows; ++i) {
+    for (int j = 0; j < img.cols; ++j) {
+      uchar pixel = img.at<uchar>(i, j);
+      high = MAX(high, pixel);
+      low = MIN(low, pixel);
+    }
+  }
+  contrast = (high - low) / 255.0;
+
+  if (contrast < target) {
+    const double ratio = 200.0 / MAX(10, high - low);
+    img.convertTo(img, CV_32F);
+    img = ((img - low + 25) * ratio);
+
+    cv::threshold(img, img, 255, 255, cv::THRESH_TRUNC);
+    cv::threshold(img, img, 0, 0, cv::THRESH_TOZERO);
+
+    img.convertTo(img, CV_8U);
+  }
+
+  return img;
+}
+
++ (cv::Mat)normalizeForRecognizer:(cv::Mat)image
+  adjustContrast:(double)adjustContrast {
+  if (adjustContrast > 0) {
+    image = [self adjustContrastGrey:image target:adjustContrast];
+  }
+
+  int desiredWidth = 128;
+  if (image.cols >= 512) {
+    desiredWidth = 512;
+  } else if (image.cols >= 256) {
+    desiredWidth = 256;
+  }
+
+  image = [OCRUtils resizeWithPadding:image
+    desiredWidth:desiredWidth
+    desiredHeight:64];
+
+  image.convertTo(image, CV_32F, 1.0 / 255.0);
+  image = (image - 0.5) * 2.0;
+
+  return image;
+}
+
++ (cv::Mat)divideMatrix:(cv::Mat)matrix byVector:(NSArray *)vector {
+  cv::Mat result = matrix.clone();
+
+  for (int i = 0; i < matrix.rows; i++) {
+    const float divisor = [vector[i] floatValue];
+    for (int j = 0; j < matrix.cols; j++) {
+      result.at<float>(i, j) /= divisor;
+    }
+  }
+
+  return result;
+}
+
++ (cv::Mat)softmax:(cv::Mat)inputs {
+  cv::Mat maxVal;
+  cv::reduce(inputs, maxVal, 1, cv::REDUCE_MAX, CV_32F);
+  cv::Mat expInputs;
+  cv::exp(inputs - cv::repeat(maxVal, 1, inputs.cols), expInputs);
+  cv::Mat sumExp;
+  cv::reduce(expInputs, sumExp, 1, cv::REDUCE_SUM, CV_32F);
+  cv::Mat softmaxOutput = expInputs / cv::repeat(sumExp, 1, inputs.cols);
+  return softmaxOutput;
+}
+
++ (NSDictionary *)calculateResizeRatioAndPaddings:(int)width
+  height:(int)height
+  desiredWidth:(int)desiredWidth
+  desiredHeight:(int)desiredHeight {
+  const float newRatioH = (float)desiredHeight / height;
+  const float newRatioW = (float)desiredWidth / width;
+  float resizeRatio = MIN(newRatioH, newRatioW);
+  const int newWidth = width * resizeRatio;
+  const int newHeight = height * resizeRatio;
+  const int deltaW = desiredWidth - newWidth;
+  const int deltaH = desiredHeight - newHeight;
+  const int top = deltaH / 2;
+  const int left = deltaW / 2;
+  const float heightRatio = (float)height / desiredHeight;
+  const float widthRatio = (float)width / desiredWidth;
+
+  resizeRatio = MAX(heightRatio, widthRatio);
+
+  return @{
+    @"resizeRatio" : @(resizeRatio),
+    @"top" : @(top),
+    @"left" : @(left),
+  };
+}
+
++ (cv::Mat)getCroppedImage:(NSDictionary *)box
+  image:(cv::Mat)image
+  modelHeight:(int)modelHeight {
+  NSArray *coords = box[@"bbox"];
+  const CGFloat angle = [box[@"angle"] floatValue];
+
+  std::vector<cv::Point2f> points;
+  for (NSValue *value in coords) {
+    const CGPoint point = [value CGPointValue];
+    points.emplace_back(static_cast<float>(point.x),
+      static_cast<float>(point.y));
+  }
+
+  cv::RotatedRect rotatedRect = cv::minAreaRect(points);
+
+  cv::Point2f imageCenter = cv::Point2f(image.cols / 2.0, image.rows / 2.0);
+  cv::Mat rotationMatrix = cv::getRotationMatrix2D(imageCenter, angle, 1.0);
+  cv::Mat rotatedImage;
+  cv::warpAffine(image, rotatedImage, rotationMatrix, image.size(),
+    cv::INTER_LINEAR);
+  cv::Point2f rectPoints[4];
+  rotatedRect.points(rectPoints);
+  std::vector<cv::Point2f> transformedPoints(4);
+  // 4 points, 1 column, 2 channels (x, y) sharing storage with rectPoints.
+  cv::Mat rectMat(4, 1, CV_32FC2, rectPoints);
+  cv::transform(rectMat, rectMat, rotationMatrix);
+
+  for (int i = 0; i < 4; ++i) {
+    transformedPoints[i] = rectPoints[i];
+  }
+
+  cv::Rect boundingBox = cv::boundingRect(transformedPoints);
+  boundingBox &= cv::Rect(0, 0, rotatedImage.cols, rotatedImage.rows);
+  if (boundingBox.width == 0 || boundingBox.height == 0) {
+    return cv::Mat();
+  }
+  cv::Mat croppedImage = rotatedImage(boundingBox);
+
+  croppedImage = [self computeRatioAndResize:croppedImage
+    width:boundingBox.width
+    height:boundingBox.height
+    modelHeight:modelHeight];
+
+  return croppedImage;
+}
+
++ (NSMutableArray *)sumProbabilityRows:(cv::Mat)probabilities
+  modelOutputHeight:(int)modelOutputHeight {
+  NSMutableArray *predsNorm =
+    [NSMutableArray arrayWithCapacity:probabilities.rows];
+  for (int i = 0; i < probabilities.rows; i++) {
+    float sum = 0.0;
+    for (int j = 0; j < modelOutputHeight; j++) {
+      sum += probabilities.at<float>(i, j);
+    }
+    [predsNorm addObject:@(sum)];
+  }
+  return predsNorm;
+}
+
++ (NSArray *)findMaxValuesAndIndices:(cv::Mat)probabilities {
+  NSMutableArray *valuesArray = [NSMutableArray array];
+  NSMutableArray *indicesArray = [NSMutableArray array];
+  for (int i = 0; i < probabilities.rows; i++) {
+    double maxVal = 0;
+    cv::Point maxLoc;
+    cv::minMaxLoc(probabilities.row(i), NULL, &maxVal, NULL, &maxLoc);
+    [valuesArray addObject:@(maxVal)];
+    [indicesArray addObject:@(maxLoc.x)];
+  }
+  return @[ valuesArray, indicesArray ];
+}
+
++ (double)computeConfidenceScore:(NSArray *)valuesArray
+  indicesArray:(NSArray *)indicesArray {
+  NSMutableArray *predsMaxProb = [NSMutableArray array];
+  for (NSUInteger index = 0; index < indicesArray.count; index++) {
+    NSNumber *indicator = indicesArray[index];
+    if ([indicator intValue] != 0) {
+      [predsMaxProb addObject:valuesArray[index]];
+    }
+  }
+  if (predsMaxProb.count == 0) {
+    [predsMaxProb addObject:@(0)];
+  }
+  double product = 1.0;
+  for (NSNumber *prob in predsMaxProb) {
+    product *= [prob doubleValue];
+  }
+  return pow(product, 2.0 / sqrt(predsMaxProb.count));
+}
+
+@end
diff --git a/ios/RnExecutorch/utils/ImageProcessor.h b/ios/RnExecutorch/utils/ImageProcessor.h
index 4bb7034e87..c65182d0a6 100644
--- a/ios/RnExecutorch/utils/ImageProcessor.h
+++ b/ios/RnExecutorch/utils/ImageProcessor.h
@@ -3,8 +3,13 @@
 
 @interface ImageProcessor : NSObject
 
++ (NSArray *)matToNSArray:(const cv::Mat &)mat
+  mean:(cv::Scalar)mean
+  variance:(cv::Scalar)variance;
 + (NSArray *)matToNSArray:(const cv::Mat &)mat;
 + (cv::Mat)arrayToMat:(NSArray *)array width:(int)width height:(int)height;
++ (cv::Mat)arrayToMatGray:(NSArray *)array width:(int)width height:(int)height;
++ (NSArray *)matToNSArrayGray:(const cv::Mat &)mat;
 + (NSString *)saveToTempFile:(const cv::Mat &)image;
 + (cv::Mat)readImage:(NSString *)source;
diff --git a/ios/RnExecutorch/utils/ImageProcessor.mm b/ios/RnExecutorch/utils/ImageProcessor.mm
index feab17f608..a8617c262f 100644
--- a/ios/RnExecutorch/utils/ImageProcessor.mm
+++ b/ios/RnExecutorch/utils/ImageProcessor.mm
@@ -4,6 +4,12 @@ @implementation ImageProcessor
 
 + (NSArray *)matToNSArray:(const cv::Mat &)mat {
+  return [ImageProcessor matToNSArray:mat mean:cv::Scalar(0.0, 0.0, 0.0) variance:cv::Scalar(1.0, 1.0, 1.0)];
+}
+
++ (NSArray *)matToNSArray:(const cv::Mat &)mat
+  mean:(cv::Scalar)mean
+  variance:(cv::Scalar)variance {
   int pixelCount = mat.cols * mat.rows;
   NSMutableArray *floatArray = [[NSMutableArray alloc] initWithCapacity:pixelCount * 3];
   for (NSUInteger k = 0; k < pixelCount * 3; k++) {
@@ -14,14 +20,27 @@ + (NSArray *)matToNSArray:(const cv::Mat &)mat {
     int row = i / mat.cols;
     int col = i % mat.cols;
     cv::Vec3b pixel = mat.at<cv::Vec3b>(row, col);
-    floatArray[0 * pixelCount + i] = @(pixel[2] / 255.0f);
-    floatArray[1 * pixelCount + i] = @(pixel[1] / 255.0f);
-    floatArray[2 * pixelCount + i] = @(pixel[0] / 255.0f);
+    floatArray[0 * pixelCount + i] = @((pixel[0] - mean[0] * 255.0) / (variance[0] * 255.0));
+    floatArray[1 * pixelCount + i] = @((pixel[1] - mean[1] * 255.0) / (variance[1] * 255.0));
+    floatArray[2 * pixelCount + i] = @((pixel[2] - mean[2] * 255.0) / (variance[2] * 255.0));
   }
 
   return floatArray;
 }
 
++ (NSArray *)matToNSArrayGray:(const cv::Mat &)mat {
+  NSMutableArray *pixelArray = [[NSMutableArray alloc] initWithCapacity:mat.cols * mat.rows];
+
+  for (int row = 0; row < mat.rows; row++) {
+    for (int col = 0; col < mat.cols; col++) {
+      float pixelValue = mat.at<float>(row, col);
+      [pixelArray addObject:@(pixelValue)];
+    }
+  }
+
+  return pixelArray;
+}
+
 + (cv::Mat)arrayToMat:(NSArray *)array width:(int)width height:(int)height {
   cv::Mat mat(height, width, CV_8UC3);
@@ -42,6 +61,20 @@ + (NSArray *)matToNSArray:(const cv::Mat &)mat {
   return mat;
 }
 
++ (cv::Mat)arrayToMatGray:(NSArray *)array width:(int)width height:(int)height {
+  cv::Mat mat(height, width, CV_32F);
+
+  int pixelCount = width * height;
+  for (int i = 0; i < pixelCount; i++) {
+    int row = i / width;
+    int col = i % width;
+    float value = [array[i] floatValue];
+    mat.at<float>(row, col) = value;
+  }
+
+  return mat;
+}
+
 + (NSString *)saveToTempFile:(const cv::Mat&)image {
   NSString *uniqueID = [[NSUUID UUID] UUIDString];
   NSString *filename = [NSString stringWithFormat:@"rn_executorch_%@.png", uniqueID];
@@ -65,9 +98,9 @@ + (NSString *)saveToTempFile:(const cv::Mat&)image {
   //base64
   NSArray *parts = [source componentsSeparatedByString:@","];
   if ([parts count] < 2) {
-    @throw [NSException exceptionWithName:@"readImage_error"
-                                   reason:[NSString stringWithFormat:@"%ld", (long)InvalidArgument]
-                                 userInfo:nil];
+    @throw [NSException exceptionWithName:@"readImage_error"
+                        reason:[NSString stringWithFormat:@"%ld", (long)InvalidArgument]
+                        userInfo:nil];
   }
   NSString *encodedString = parts[1];
   NSData *data = [[NSData alloc] initWithBase64EncodedString:encodedString options:NSDataBase64DecodingIgnoreUnknownCharacters];
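The new mean/variance overload above maps 8-bit channel values into the normalized range the detector expects. The per-channel arithmetic, as a TypeScript sketch (mean and variance are the 0-to-1-scale scalars declared in Detector.h):

// value8bit is in [0, 255]; mean and variance are given on a 0..1 scale,
// so both are rescaled by 255 before normalizing.
function normalizeChannel(value8bit: number, mean: number, variance: number): number {
  return (value8bit - mean * 255.0) / (variance * 255.0);
}
// e.g. normalizeChannel(128, 0.485, 0.229) ≈ 0.074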
diff --git a/src/Error.ts b/src/Error.ts
index 767856393c..955b62a95e 100644
--- a/src/Error.ts
+++ b/src/Error.ts
@@ -4,6 +4,7 @@ export enum ETError {
   ModuleNotLoaded = 0x66,
   FileWriteFailed = 0x67,
   ModelGenerating = 0x68,
+  LanguageNotSupported = 0x69,
   InvalidModelSource = 0xff,
 
   // ExecuTorch mapped errors
diff --git a/src/constants/ocr/languageDicts.ts b/src/constants/ocr/languageDicts.ts
new file mode 100644
index 0000000000..fcd189b53c
--- /dev/null
+++ b/src/constants/ocr/languageDicts.ts
@@ -0,0 +1,4 @@
+export const languageDicts: { [key: string]: string } = {
+  en: 'https://huggingface.co/nklockiewicz/ocr/resolve/main/en.txt',
+  pl: 'https://huggingface.co/nklockiewicz/ocr/resolve/main/pl.txt',
+};
diff --git a/src/constants/ocr/symbols.ts b/src/constants/ocr/symbols.ts
new file mode 100644
index 0000000000..229c0613d1
--- /dev/null
+++ b/src/constants/ocr/symbols.ts
@@ -0,0 +1,4 @@
+export const symbols: { [key: string]: string } = {
+  en: '0123456789!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ €ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz',
+  pl: ' !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~ªÀÁÂÃÄÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖØÙÚÛÜÝÞßàáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿĀāĂ㥹ĆćČčĎďĐđĒēĖėĘęĚěĞğĨĩĪīĮįİıĶķĹĺĻļĽľŁłŃńŅņŇňŒœŔŕŘřŚśŞşŠšŤťŨũŪūŮůŲųŸŹźŻżŽžƏƠơƯưȘșȚțə̇ḌḍḶḷṀṁṂṃṄṅṆṇṬṭẠạẢảẤấẦầẨẩẪẫẬậẮắẰằẲẳẴẵẶặẸẹẺẻẼẽẾếỀềỂểỄễỆệỈỉỊịỌọỎỏỐốỒồỔổỖỗỘộỚớỜờỞởỠỡỢợỤụỦủỨứỪừỬửỮữỰựỲỳỴỵỶỷỸỹ€',
+};
diff --git a/src/hooks/computer_vision/useOCR.ts b/src/hooks/computer_vision/useOCR.ts
new file mode 100644
index 0000000000..56ee04e412
--- /dev/null
+++ b/src/hooks/computer_vision/useOCR.ts
@@ -0,0 +1,109 @@
+import { useEffect, useState } from 'react';
+import { fetchResource } from '../../utils/fetchResource';
+import { languageDicts } from '../../constants/ocr/languageDicts';
+import { symbols } from '../../constants/ocr/symbols';
+import { getError, ETError } from '../../Error';
+import { OCR } from '../../native/RnExecutorchModules';
+import { ResourceSource } from '../../types/common';
+import { OCRDetection } from '../../types/ocr';
+
+interface OCRModule {
+  error: string | null;
+  isReady: boolean;
+  isGenerating: boolean;
+  forward: (input: string) => Promise<OCRDetection[]>;
+  downloadProgress: number;
+}
+
+export const useOCR = ({
+  detectorSource,
+  recognizerSources,
+  language = 'en',
+}: {
+  detectorSource: ResourceSource;
+  recognizerSources: {
+    recognizerLarge: ResourceSource;
+    recognizerMedium: ResourceSource;
+    recognizerSmall: ResourceSource;
+  };
+  language?: string;
+}): OCRModule => {
+  const [error, setError] = useState<string | null>(null);
+  const [isReady, setIsReady] = useState(false);
+  const [isGenerating, setIsGenerating] = useState(false);
+  const [downloadProgress, setDownloadProgress] = useState(0);
+
+  useEffect(() => {
+    const loadModel = async () => {
+      try {
+        if (!detectorSource || Object.keys(recognizerSources).length === 0)
+          return;
+
+        const recognizerPaths = {} as {
+          recognizerLarge: string;
+          recognizerMedium: string;
+          recognizerSmall: string;
+        };
+
+        if (!symbols[language] || !languageDicts[language]) {
+          setError(getError(ETError.LanguageNotSupported));
+          return;
+        }
+
+        const detectorPath = await fetchResource(detectorSource);
+
+        await Promise.all([
+          fetchResource(recognizerSources.recognizerLarge, setDownloadProgress),
+          fetchResource(recognizerSources.recognizerMedium),
+          fetchResource(recognizerSources.recognizerSmall),
+        ]).then((values) => {
+          recognizerPaths.recognizerLarge = values[0];
+          recognizerPaths.recognizerMedium = values[1];
+
recognizerPaths.recognizerSmall = values[2]; + }); + + setIsReady(false); + await OCR.loadModule( + detectorPath, + recognizerPaths.recognizerLarge, + recognizerPaths.recognizerMedium, + recognizerPaths.recognizerSmall, + symbols[language] + ); + setIsReady(true); + } catch (e) { + setError(getError(e)); + } + }; + + loadModel(); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [detectorSource, language, JSON.stringify(recognizerSources)]); + + const forward = async (input: string) => { + if (!isReady) { + throw new Error(getError(ETError.ModuleNotLoaded)); + } + if (isGenerating) { + throw new Error(getError(ETError.ModelGenerating)); + } + + try { + setIsGenerating(true); + const output = await OCR.forward(input); + return output; + } catch (e) { + throw new Error(getError(e)); + } finally { + setIsGenerating(false); + } + }; + + return { + error, + isReady, + isGenerating, + forward, + downloadProgress, + }; +}; diff --git a/src/index.tsx b/src/index.tsx index ec7dd8c128..f5bfa1854d 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -2,6 +2,7 @@ export * from './hooks/computer_vision/useClassification'; export * from './hooks/computer_vision/useObjectDetection'; export * from './hooks/computer_vision/useStyleTransfer'; +export * from './hooks/computer_vision/useOCR'; export * from './hooks/natural_language_processing/useLLM'; @@ -11,6 +12,7 @@ export * from './hooks/general/useExecutorchModule'; export * from './modules/computer_vision/ClassificationModule'; export * from './modules/computer_vision/ObjectDetectionModule'; export * from './modules/computer_vision/StyleTransferModule'; +export * from './modules/computer_vision/OCRModule'; export * from './modules/natural_language_processing/LLMModule'; @@ -21,6 +23,7 @@ export * from './utils/listDownloadedResources'; // types export * from './types/object_detection'; +export * from './types/ocr'; // constants export * from './constants/modelUrls'; diff --git a/src/modules/computer_vision/OCRModule.ts b/src/modules/computer_vision/OCRModule.ts new file mode 100644 index 0000000000..26ea6f4e89 --- /dev/null +++ b/src/modules/computer_vision/OCRModule.ts @@ -0,0 +1,72 @@ +import { languageDicts } from '../../constants/ocr/languageDicts'; +import { symbols } from '../../constants/ocr/symbols'; +import { getError, ETError } from '../../Error'; +import { OCR } from '../../native/RnExecutorchModules'; +import { ResourceSource } from '../../types/common'; +import { fetchResource } from '../../utils/fetchResource'; + +export class OCRModule { + static onDownloadProgressCallback = (_downloadProgress: number) => {}; + + static async load( + detectorSource: ResourceSource, + recognizerSources: { + recognizerLarge: ResourceSource; + recognizerMedium: ResourceSource; + recognizerSmall: ResourceSource; + }, + language = 'en' + ) { + try { + if (!detectorSource || Object.keys(recognizerSources).length === 0) + return; + + const recognizerPaths = {} as { + recognizerLarge: string; + recognizerMedium: string; + recognizerSmall: string; + }; + + if (!symbols[language] || !languageDicts[language]) { + throw new Error(getError(ETError.LanguageNotSupported)); + } + + const detectorPath = await fetchResource(detectorSource); + + await Promise.all([ + fetchResource( + recognizerSources.recognizerLarge, + this.onDownloadProgressCallback + ), + fetchResource(recognizerSources.recognizerMedium), + fetchResource(recognizerSources.recognizerSmall), + ]).then((values) => { + recognizerPaths.recognizerLarge = values[0]; + recognizerPaths.recognizerMedium 
= values[1];
+        recognizerPaths.recognizerSmall = values[2];
+      });
+
+      await OCR.loadModule(
+        detectorPath,
+        recognizerPaths.recognizerLarge,
+        recognizerPaths.recognizerMedium,
+        recognizerPaths.recognizerSmall,
+        symbols[language]
+      );
+    } catch (e) {
+      throw new Error(getError(e));
+    }
+  }
+
+  static async forward(input: string) {
+    try {
+      return await OCR.forward(input);
+    } catch (e) {
+      throw new Error(getError(e));
+    }
+  }
+
+  static onDownloadProgress(callback: (downloadProgress: number) => void) {
+    this.onDownloadProgressCallback = callback;
+  }
+}
diff --git a/src/native/NativeOCR.ts b/src/native/NativeOCR.ts
new file mode 100644
index 0000000000..2c14c6ac0d
--- /dev/null
+++ b/src/native/NativeOCR.ts
@@ -0,0 +1,16 @@
+import type { TurboModule } from 'react-native';
+import { TurboModuleRegistry } from 'react-native';
+import { OCRDetection } from '../types/ocr';
+
+export interface Spec extends TurboModule {
+  loadModule(
+    detectorSource: string,
+    recognizerSourceLarge: string,
+    recognizerSourceMedium: string,
+    recognizerSourceSmall: string,
+    symbols: string
+  ): Promise<number>;
+  forward(input: string): Promise<OCRDetection[]>;
+}
+
+export default TurboModuleRegistry.get<Spec>('OCR');
diff --git a/src/native/RnExecutorchModules.ts b/src/native/RnExecutorchModules.ts
index fc1cfe28db..c8044aa473 100644
--- a/src/native/RnExecutorchModules.ts
+++ b/src/native/RnExecutorchModules.ts
@@ -88,6 +88,19 @@ const SpeechToText = SpeechToTextSpec
   }
 );
 
+const OCRSpec = require('./NativeOCR').default;
+
+const OCR = OCRSpec
+  ? OCRSpec
+  : new Proxy(
+      {},
+      {
+        get() {
+          throw new Error(LINKING_ERROR);
+        },
+      }
+    );
+
 class _ObjectDetectionModule {
   async forward(
     input: string
@@ -168,6 +181,7 @@ export {
   ObjectDetection,
   StyleTransfer,
   SpeechToText,
+  OCR,
   _ETModule,
   _ClassificationModule,
   _StyleTransferModule,
diff --git a/src/types/ocr.ts b/src/types/ocr.ts
new file mode 100644
index 0000000000..f5f2e6d35e
--- /dev/null
+++ b/src/types/ocr.ts
@@ -0,0 +1,10 @@
+export interface OCRDetection {
+  bbox: OCRBbox[];
+  text: string;
+  score: number;
+}
+
+export interface OCRBbox {
+  x: number;
+  y: number;
+}
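For reference, a usage sketch of the hook added in this diff; the model URLs are placeholders, and the package import assumes the library's public entry point:

import { useOCR } from 'react-native-executorch';

function useDocumentScanner(imageSource: string) {
  const ocr = useOCR({
    detectorSource: 'https://example.com/detector.pte', // placeholder
    recognizerSources: {
      recognizerLarge: 'https://example.com/recognizer_512.pte', // placeholder
      recognizerMedium: 'https://example.com/recognizer_256.pte', // placeholder
      recognizerSmall: 'https://example.com/recognizer_128.pte', // placeholder
    },
    language: 'en',
  });

  const scan = async () => {
    if (!ocr.isReady) return [];
    // Each detection carries text, a score, and a 4-point bbox.
    const detections = await ocr.forward(imageSource);
    return detections.filter((d) => d.score > 0.3);
  };

  return { scan, progress: ocr.downloadProgress, error: ocr.error };
}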