diff --git a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt
index 0ec2a51c4f..64f746613c 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt
@@ -25,10 +25,11 @@ class RnExecutorchPackage : TurboReactPackage() {
       ObjectDetection(reactContext)
     } else if (name == SpeechToText.NAME) {
       SpeechToText(reactContext)
-    } else if (name == OCR.NAME){
+    } else if (name == OCR.NAME) {
       OCR(reactContext)
-    }
-    else {
+    } else if (name == VerticalOCR.NAME) {
+      VerticalOCR(reactContext)
+    } else {
       null
     }
 
@@ -44,54 +45,49 @@ class RnExecutorchPackage : TurboReactPackage() {
         true,
       )
       moduleInfos[ETModule.NAME] = ReactModuleInfo(
-        ETModule.NAME,
-        ETModule.NAME,
-        false, // canOverrideExistingModule
+        ETModule.NAME, ETModule.NAME, false, // canOverrideExistingModule
         false, // needsEagerInit
         false, // isCxxModule
         true
       )
 
       moduleInfos[StyleTransfer.NAME] = ReactModuleInfo(
-        StyleTransfer.NAME,
-        StyleTransfer.NAME,
-        false, // canOverrideExistingModule
+        StyleTransfer.NAME, StyleTransfer.NAME, false, // canOverrideExistingModule
         false, // needsEagerInit
         false, // isCxxModule
         true
       )
 
       moduleInfos[Classification.NAME] = ReactModuleInfo(
-        Classification.NAME,
-        Classification.NAME,
-        false, // canOverrideExistingModule
+        Classification.NAME, Classification.NAME, false, // canOverrideExistingModule
         false, // needsEagerInit
         false, // isCxxModule
         true
       )
 
       moduleInfos[ObjectDetection.NAME] = ReactModuleInfo(
-        ObjectDetection.NAME,
-        ObjectDetection.NAME,
-        false, // canOverrideExistingModule
+        ObjectDetection.NAME, ObjectDetection.NAME, false, // canOverrideExistingModule
         false, // needsEagerInit
         false, // isCxxModule
         true
       )
 
       moduleInfos[SpeechToText.NAME] = ReactModuleInfo(
-        SpeechToText.NAME,
-        SpeechToText.NAME,
-        false, // canOverrideExistingModule
+        SpeechToText.NAME, SpeechToText.NAME, false, // canOverrideExistingModule
         false, // needsEagerInit
         false, // isCxxModule
         true
       )
 
       moduleInfos[OCR.NAME] = ReactModuleInfo(
-        OCR.NAME,
-        OCR.NAME,
-        false, // canOverrideExistingModule
+        OCR.NAME, OCR.NAME, false, // canOverrideExistingModule
+        false, // needsEagerInit
+        false, // isCxxModule
+        true
+      )
+
+      moduleInfos[VerticalOCR.NAME] = ReactModuleInfo(
+        VerticalOCR.NAME, VerticalOCR.NAME, false, // canOverrideExistingModule
         false, // needsEagerInit
         false, // isCxxModule
         true
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt b/android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt
new file mode 100644
index 0000000000..2d8006774c
--- /dev/null
+++ b/android/src/main/java/com/swmansion/rnexecutorch/VerticalOCR.kt
@@ -0,0 +1,173 @@
+package com.swmansion.rnexecutorch
+
+import android.util.Log
+import com.facebook.react.bridge.Arguments
+import com.facebook.react.bridge.Promise
+import com.facebook.react.bridge.ReactApplicationContext
+import com.swmansion.rnexecutorch.utils.ETError
+import com.swmansion.rnexecutorch.utils.ImageProcessor
+import org.opencv.android.OpenCVLoader
+import com.swmansion.rnexecutorch.models.ocr.Recognizer
+import com.swmansion.rnexecutorch.models.ocr.VerticalDetector
+import com.swmansion.rnexecutorch.models.ocr.utils.CTCLabelConverter
+import com.swmansion.rnexecutorch.models.ocr.utils.Constants
+import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils
+import org.opencv.core.Core
+import org.opencv.core.Mat
+
+class VerticalOCR(reactContext: ReactApplicationContext) :
+  NativeVerticalOCRSpec(reactContext) {
+
+  private lateinit var detectorLarge: VerticalDetector
+  private lateinit var detectorNarrow: VerticalDetector
+  private lateinit var recognizer: Recognizer
+  private lateinit var converter: CTCLabelConverter
+  private var independentCharacters = true
+
+  companion object {
+    const val NAME = "VerticalOCR"
+  }
+
+  init {
+    if (!OpenCVLoader.initLocal()) {
+      Log.d("rn_executorch", "OpenCV not loaded")
+    } else {
+      Log.d("rn_executorch", "OpenCV loaded")
+    }
+  }
+
+  override fun loadModule(
+    detectorLargeSource: String,
+    detectorNarrowSource: String,
+    recognizerSource: String,
+    symbols: String,
+    independentCharacters: Boolean,
+    promise: Promise
+  ) {
+    try {
+      this.independentCharacters = independentCharacters
+      detectorLarge = VerticalDetector(false, reactApplicationContext)
+      detectorLarge.loadModel(detectorLargeSource)
+      detectorNarrow = VerticalDetector(true, reactApplicationContext)
+      detectorNarrow.loadModel(detectorNarrowSource)
+      recognizer = Recognizer(reactApplicationContext)
+      recognizer.loadModel(recognizerSource)
+
+      converter = CTCLabelConverter(symbols)
+
+      promise.resolve(0)
+    } catch (e: Exception) {
+      promise.reject(e.message!!, ETError.InvalidModelSource.toString())
+    }
+  }
+
+  override fun forward(input: String, promise: Promise) {
+    try {
+      val inputImage = ImageProcessor.readImage(input)
+      val result = detectorLarge.runModel(inputImage)
+      val largeDetectorSize = detectorLarge.getModelImageSize()
+      val resizedImage = ImageProcessor.resizeWithPadding(
+        inputImage,
+        largeDetectorSize.width.toInt(),
+        largeDetectorSize.height.toInt()
+      )
+      val predictions = Arguments.createArray()
+      for (box in result) {
+        val cords = box.bBox
+        val boxWidth = cords[2].x - cords[0].x
+        val boxHeight = cords[2].y - cords[0].y
+
+        val boundingBox = RecognizerUtils.extractBoundingBox(cords)
+        val croppedImage = Mat(resizedImage, boundingBox)
+
+        val paddings = RecognizerUtils.calculateResizeRatioAndPaddings(
+          inputImage.width(),
+          inputImage.height(),
+          largeDetectorSize.width.toInt(),
+          largeDetectorSize.height.toInt()
+        )
+
+        var text = ""
+        var confidenceScore = 0.0
+        val boxResult = detectorNarrow.runModel(croppedImage)
+        val narrowDetectorSize = detectorNarrow.getModelImageSize()
+
+        val croppedCharacters = mutableListOf<Mat>()
+
+        for (characterBox in boxResult) {
+          val boxCords = characterBox.bBox
+          val paddingsBox = RecognizerUtils.calculateResizeRatioAndPaddings(
+            boxWidth.toInt(),
+            boxHeight.toInt(),
+            narrowDetectorSize.width.toInt(),
+            narrowDetectorSize.height.toInt()
+          )
+
+          var croppedCharacter = RecognizerUtils.cropImageWithBoundingBox(
+            inputImage,
+            boxCords,
+            cords,
+            paddingsBox,
+            paddings
+          )
+
+          if (this.independentCharacters) {
+            croppedCharacter = RecognizerUtils.cropSingleCharacter(croppedCharacter)
+            croppedCharacter = RecognizerUtils.normalizeForRecognizer(croppedCharacter, 0.0, true)
+            val recognitionResult = recognizer.runModel(croppedCharacter)
+            val predIndex = recognitionResult.first
+            val decodedText = converter.decodeGreedy(predIndex, predIndex.size)
+            text += decodedText[0]
+            confidenceScore += recognitionResult.second
+          } else {
+            croppedCharacters.add(croppedCharacter)
+          }
+        }
+
+        if (this.independentCharacters) {
+          confidenceScore /= boxResult.size
+        } else {
+          var mergedCharacters = Mat()
+          Core.hconcat(croppedCharacters, mergedCharacters)
+          mergedCharacters = ImageProcessor.resizeWithPadding(
+            mergedCharacters,
+            Constants.LARGE_MODEL_WIDTH,
+            Constants.MODEL_HEIGHT
+          )
+          mergedCharacters = RecognizerUtils.normalizeForRecognizer(mergedCharacters, 0.0)
+
+          val recognitionResult = recognizer.runModel(mergedCharacters)
+          val predIndex = recognitionResult.first
+          val decodedText = converter.decodeGreedy(predIndex, predIndex.size)
+
+          text = decodedText[0]
+          confidenceScore = recognitionResult.second
+        }
+
+        for (bBox in box.bBox) {
+          bBox.x =
+            (bBox.x - paddings["left"] as Int) * paddings["resizeRatio"] as Float
+          bBox.y =
+            (bBox.y - paddings["top"] as Int) * paddings["resizeRatio"] as Float
+        }
+
+        val resMap = Arguments.createMap()
+
+        resMap.putString("text", text)
+        resMap.putArray("bbox", box.toWritableArray())
+        resMap.putDouble("confidence", confidenceScore)
+
+        predictions.pushMap(resMap)
+      }
+
+      promise.resolve(predictions)
+    } catch (e: Exception) {
+      Log.d("rn_executorch", "Error running model: ${e.message}")
+      promise.reject(e.message!!, e.message)
+    }
+  }
+
+  override fun getName(): String {
+    return NAME
+  }
+}
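For reference, the box remapping at the end of forward() above undoes the detector-input letterboxing: subtract the padding offsets, then scale by the resize ratio. A minimal TypeScript sketch of the same arithmetic (Paddings and toOriginalSpace are illustrative names, not part of the library):

type Paddings = { left: number; top: number; resizeRatio: number };

// Map a point from the padded, model-sized image back to original-image space.
function toOriginalSpace(p: { x: number; y: number }, pad: Paddings) {
  return {
    x: (p.x - pad.left) * pad.resizeRatio,
    y: (p.y - pad.top) * pad.resizeRatio,
  };
}

console.log(toOriginalSpace({ x: 160, y: 320 }, { left: 0, top: 280, resizeRatio: 0.75 }));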
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt
index 85976e2281..fb8e4329fe 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt
@@ -1,6 +1,5 @@
 package com.swmansion.rnexecutorch.models.ocr
 
-import android.util.Log
 import com.facebook.react.bridge.ReactApplicationContext
 import com.swmansion.rnexecutorch.models.BaseModel
 import com.swmansion.rnexecutorch.models.ocr.utils.Constants
@@ -8,18 +7,18 @@ import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils
 import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox
 import com.swmansion.rnexecutorch.utils.ImageProcessor
 import org.opencv.core.Mat
-import org.opencv.core.Scalar
 import org.opencv.core.Size
 import org.pytorch.executorch.EValue
 
-class Detector(reactApplicationContext: ReactApplicationContext) :
-  BaseModel<Mat, List<OCRbBox>>(reactApplicationContext) {
+class Detector(
+  reactApplicationContext: ReactApplicationContext
+) : BaseModel<Mat, List<OCRbBox>>(reactApplicationContext) {
   private lateinit var originalSize: Size
 
   fun getModelImageSize(): Size {
     val inputShape = module.getInputShape(0)
-    val width = inputShape[inputShape.lastIndex]
-    val height = inputShape[inputShape.lastIndex - 1]
+    val width = inputShape[inputShape.lastIndex - 1]
+    val height = inputShape[inputShape.lastIndex]
 
     val modelImageSize = Size(height.toDouble(), width.toDouble())
 
@@ -29,16 +28,11 @@ class Detector(reactApplicationContext: ReactApplicationContext) :
   override fun preprocess(input: Mat): EValue {
     originalSize = Size(input.cols().toDouble(), input.rows().toDouble())
     val resizedImage = ImageProcessor.resizeWithPadding(
-      input,
-      getModelImageSize().width.toInt(),
-      getModelImageSize().height.toInt()
+      input, getModelImageSize().width.toInt(), getModelImageSize().height.toInt()
     )
 
     return ImageProcessor.matToEValue(
-      resizedImage,
-      module.getInputShape(0),
-      Constants.MEAN,
-      Constants.VARIANCE
+      resizedImage, module.getInputShape(0), Constants.MEAN, Constants.VARIANCE
     )
   }
 
@@ -48,8 +42,7 @@ class Detector(reactApplicationContext: ReactApplicationContext) :
     val modelImageSize = getModelImageSize()
 
     val (scoreText, scoreLink) = DetectorUtils.interleavedArrayToMats(
-      outputArray,
-      Size(modelImageSize.width / 2, modelImageSize.height / 2)
+      outputArray, Size(modelImageSize.width / 2, modelImageSize.height / 2)
     )
     var bBoxesList = DetectorUtils.getDetBoxesFromTextMap(
       scoreText,
@@ -58,8 +51,10 @@ class Detector(reactApplicationContext: ReactApplicationContext) :
       Constants.LINK_THRESHOLD,
       Constants.LOW_TEXT_THRESHOLD
     )
+
    bBoxesList =
      DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RECOGNIZER_RATIO * 2).toFloat())
+
    bBoxesList = DetectorUtils.groupTextBoxes(
      bBoxesList,
      Constants.CENTER_THRESHOLD,
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/VerticalDetector.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/VerticalDetector.kt
new file mode 100644
index 0000000000..1566525673
--- /dev/null
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/VerticalDetector.kt
@@ -0,0 +1,89 @@
+package com.swmansion.rnexecutorch.models.ocr
+
+import com.facebook.react.bridge.ReactApplicationContext
+import com.swmansion.rnexecutorch.models.BaseModel
+import com.swmansion.rnexecutorch.models.ocr.utils.Constants
+import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils
+import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox
+import com.swmansion.rnexecutorch.utils.ImageProcessor
+import org.opencv.core.Mat
+import org.opencv.core.Size
+import org.pytorch.executorch.EValue
+
+class VerticalDetector(
+  private val detectSingleCharacter: Boolean,
+  reactApplicationContext: ReactApplicationContext
+) :
+  BaseModel<Mat, List<OCRbBox>>(reactApplicationContext) {
+  private lateinit var originalSize: Size
+
+  fun getModelImageSize(): Size {
+    val inputShape = module.getInputShape(0)
+    val width = inputShape[inputShape.lastIndex - 1]
+    val height = inputShape[inputShape.lastIndex]
+
+    val modelImageSize = Size(height.toDouble(), width.toDouble())
+
+    return modelImageSize
+  }
+
+  override fun preprocess(input: Mat): EValue {
+    originalSize = Size(input.cols().toDouble(), input.rows().toDouble())
+    val resizedImage = ImageProcessor.resizeWithPadding(
+      input,
+      getModelImageSize().width.toInt(),
+      getModelImageSize().height.toInt()
+    )
+
+    return ImageProcessor.matToEValue(
+      resizedImage,
+      module.getInputShape(0),
+      Constants.MEAN,
+      Constants.VARIANCE
+    )
+  }
+
+  override fun postprocess(output: Array<EValue>): List<OCRbBox> {
+    val outputTensor = output[0].toTensor()
+    val outputArray = outputTensor.dataAsFloatArray
+    val modelImageSize = getModelImageSize()
+
+    val (scoreText, scoreLink) = DetectorUtils.interleavedArrayToMats(
+      outputArray,
+      Size(modelImageSize.width / 2, modelImageSize.height / 2)
+    )
+
+    val txtThreshold = if (detectSingleCharacter) Constants.TEXT_THRESHOLD else Constants.TEXT_THRESHOLD_VERTICAL
+    var bBoxesList = DetectorUtils.getDetBoxesFromTextMapVertical(
+      scoreText,
+      scoreLink,
+      txtThreshold,
+      Constants.LINK_THRESHOLD,
+      detectSingleCharacter
+    )
+
+    bBoxesList =
+      DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RESTORE_RATIO_VERTICAL).toFloat())
+
+    if (detectSingleCharacter) {
+      return bBoxesList
+    }
+
+    bBoxesList = DetectorUtils.groupTextBoxes(
+      bBoxesList,
+      Constants.CENTER_THRESHOLD,
+      Constants.DISTANCE_THRESHOLD,
+      Constants.HEIGHT_THRESHOLD,
+      Constants.MIN_SIDE_THRESHOLD,
+      Constants.MAX_SIDE_THRESHOLD,
+      Constants.MAX_WIDTH
+    )
+
+    return bBoxesList.toList()
+  }
+
+  override fun runModel(input: Mat): List<OCRbBox> {
+    return postprocess(forward(preprocess(input)))
+  }
+}
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt
index b49232f41a..5dc25cd796 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt
@@ -5,13 +5,16 @@ import org.opencv.core.Scalar
 class Constants {
   companion object {
     const val RECOGNIZER_RATIO = 1.6
+    const val RESTORE_RATIO_VERTICAL = 2.0
     const val MODEL_HEIGHT = 64
     const val LARGE_MODEL_WIDTH = 512
     const val MEDIUM_MODEL_WIDTH = 256
     const val SMALL_MODEL_WIDTH = 128
+    const val VERTICAL_SMALL_MODEL_WIDTH = 64
     const val LOW_CONFIDENCE_THRESHOLD = 0.3
     const val ADJUST_CONTRAST = 0.2
     const val TEXT_THRESHOLD = 0.4
+    const val TEXT_THRESHOLD_VERTICAL = 0.3
     const val LINK_THRESHOLD = 0.4
     const val LOW_TEXT_THRESHOLD = 0.7
     const val CENTER_THRESHOLD = 0.5
@@ -21,6 +24,7 @@ class Constants {
     const val MAX_SIDE_THRESHOLD = 30
     const val MAX_WIDTH = (LARGE_MODEL_WIDTH + (LARGE_MODEL_WIDTH * 0.15)).toInt()
     const val MIN_SIZE = 20
+    const val SINGLE_CHARACTER_MIN_SIZE = 70
     val MEAN = Scalar(0.485, 0.456, 0.406)
     val VARIANCE = Scalar(0.229, 0.224, 0.225)
   }
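Spelled out, the derived values in Constants.kt work out as follows (a worked TypeScript example mirroring the constants above):

const LARGE_MODEL_WIDTH = 512;
const MAX_WIDTH = Math.trunc(LARGE_MODEL_WIDTH + LARGE_MODEL_WIDTH * 0.15); // 588
const RECOGNIZER_RATIO = 1.6;
// Detector heat maps are half the input resolution, so horizontal boxes are
// restored with RECOGNIZER_RATIO * 2 = 3.2 (restoreRatio on iOS), while the
// vertical pipeline uses RESTORE_RATIO_VERTICAL = 2.0.
console.log(MAX_WIDTH, RECOGNIZER_RATIO * 2); // 588 3.2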
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt
index 4beb7ecf45..c6b5789ae7 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt
@@ -1,5 +1,6 @@
 package com.swmansion.rnexecutorch.models.ocr.utils
 
+import android.util.Log
 import com.facebook.react.bridge.Arguments
 import com.facebook.react.bridge.WritableArray
 import org.opencv.core.Core
@@ -288,6 +289,98 @@ class DetectorUtils {
       return Pair(mat1, mat2)
     }
 
+    fun getDetBoxesFromTextMapVertical(
+      textMap: Mat,
+      affinityMap: Mat,
+      textThreshold: Double,
+      linkThreshold: Double,
+      independentCharacters: Boolean
+    ): MutableList<OCRbBox> {
+      val imgH = textMap.rows()
+      val imgW = textMap.cols()
+
+      val textScore = Mat()
+      val affinityScore = Mat()
+      Imgproc.threshold(textMap, textScore, textThreshold, 1.0, Imgproc.THRESH_BINARY)
+      Imgproc.threshold(affinityMap, affinityScore, linkThreshold, 1.0, Imgproc.THRESH_BINARY)
+      val textScoreComb = Mat()
+      var kernel = Imgproc.getStructuringElement(
+        Imgproc.MORPH_RECT,
+        Size(3.0, 3.0)
+      )
+      if (independentCharacters) {
+        Core.subtract(textScore, affinityScore, textScoreComb)
+        Imgproc.threshold(textScoreComb, textScoreComb, 0.0, 0.0, Imgproc.THRESH_TOZERO)
+        Imgproc.threshold(textScoreComb, textScoreComb, 1.0, 1.0, Imgproc.THRESH_TRUNC)
+        Imgproc.erode(textScoreComb, textScoreComb, kernel, Point(-1.0, -1.0), 1)
+        Imgproc.dilate(textScoreComb, textScoreComb, kernel, Point(-1.0, -1.0), 4)
+      } else {
+        Core.add(textScore, affinityScore, textScoreComb)
+        Imgproc.threshold(textScoreComb, textScoreComb, 0.0, 0.0, Imgproc.THRESH_TOZERO)
+        Imgproc.threshold(textScoreComb, textScoreComb, 1.0, 1.0, Imgproc.THRESH_TRUNC)
+        Imgproc.dilate(textScoreComb, textScoreComb, kernel, Point(-1.0, -1.0), 2)
+      }
+
+      val binaryMat = Mat()
+      textScoreComb.convertTo(binaryMat, CvType.CV_8UC1)
+
+      val labels = Mat()
+      val stats = Mat()
+      val centroids = Mat()
+      val nLabels = Imgproc.connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4)
+
+      val detectedBoxes = mutableListOf<OCRbBox>()
+      for (i in 1 until nLabels) {
+        val area = stats.get(i, Imgproc.CC_STAT_AREA)[0].toInt()
+        if (area < Constants.MIN_SIZE) continue
+
+        val height = stats.get(i, Imgproc.CC_STAT_HEIGHT)[0].toInt()
+        val width = stats.get(i, Imgproc.CC_STAT_WIDTH)[0].toInt()
+
+        if (!independentCharacters && height < width) continue
+
+        val mask = createMaskFromLabels(labels, i)
+
+        val segMap = Mat.zeros(textMap.size(), CvType.CV_8U)
+        segMap.setTo(Scalar(255.0), mask)
+
+        val x = stats.get(i, Imgproc.CC_STAT_LEFT)[0].toInt()
+        val y = stats.get(i, Imgproc.CC_STAT_TOP)[0].toInt()
+        val w = stats.get(i, Imgproc.CC_STAT_WIDTH)[0].toInt()
+        val h = stats.get(i, Imgproc.CC_STAT_HEIGHT)[0].toInt()
+        val dilationRadius = (sqrt(area / max(w, h).toDouble()) * 2.0).toInt()
+        val sx = max(x - dilationRadius, 0)
+        val ex = min(x + w + dilationRadius + 1, imgW)
+        val sy = max(y - dilationRadius, 0)
+        val ey = min(y + h + dilationRadius + 1, imgH)
+        val roi = Rect(sx, sy, ex - sx, ey - sy)
+        kernel = Imgproc.getStructuringElement(
+          Imgproc.MORPH_RECT,
+          Size((1 + dilationRadius).toDouble(), (1 + dilationRadius).toDouble())
+        )
+        val roiSegMap = Mat(segMap, roi)
+        Imgproc.dilate(roiSegMap, roiSegMap, kernel, Point(-1.0, -1.0), 2)
+
+        val contours: List<MatOfPoint> = ArrayList()
+        Imgproc.findContours(
+          segMap,
+          contours,
+          Mat(),
+          Imgproc.RETR_EXTERNAL,
+          Imgproc.CHAIN_APPROX_SIMPLE
+        )
+        if (contours.isNotEmpty()) {
+          val minRect = Imgproc.minAreaRect(MatOfPoint2f(*contours[0].toArray()))
+          val points = Array(4) { Point() }
+          minRect.points(points)
+          val pointsList = points.map { point -> BBoxPoint(point.x, point.y) }
+          val boxInfo = OCRbBox(pointsList, minRect.angle)
+          detectedBoxes.add(boxInfo)
+        }
+      }
+
+      return detectedBoxes
+    }
+
     fun getDetBoxesFromTextMap(
       textMap: Mat,
       affinityMap: Mat,
@@ -435,6 +528,8 @@ class DetectorUtils {
       mergedArray = removeSmallBoxes(mergedArray, minSideThreshold, maxSideThreshold)
       mergedArray = mergedArray.sortedWith(compareBy { minimumYFromBox(it.bBox) }).toMutableList()
 
+      mergedArray = mergedArray.map { box -> orderPointsClockwise(box) }.toMutableList()
+
       return mergedArray
     }
   }
diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt
index 99adcad9f0..1847e8ee5c 100644
--- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt
+++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt
@@ -4,6 +4,8 @@ import com.swmansion.rnexecutorch.utils.ImageProcessor
 import org.opencv.core.Core
 import org.opencv.core.CvType
 import org.opencv.core.Mat
+import org.opencv.core.MatOfFloat
+import org.opencv.core.MatOfInt
 import org.opencv.core.MatOfPoint2f
 import org.opencv.core.Point
 import org.opencv.core.Rect
@@ -245,17 +247,21 @@ class RecognizerUtils {
       return computeRatioAndResize(croppedImage, boundingBox.width, boundingBox.height, modelHeight)
     }
 
-    fun normalizeForRecognizer(image: Mat, adjustContrast: Double): Mat {
+    fun normalizeForRecognizer(
+      image: Mat,
+      adjustContrast: Double,
+      isVertical: Boolean = false
+    ): Mat {
       var img = image.clone()
 
       if (adjustContrast > 0) {
        img = adjustContrastGrey(img, adjustContrast)
       }
 
       val desiredWidth = when {
         img.width() >= Constants.LARGE_MODEL_WIDTH -> Constants.LARGE_MODEL_WIDTH
         img.width() >= Constants.MEDIUM_MODEL_WIDTH -> Constants.MEDIUM_MODEL_WIDTH
-        else -> Constants.SMALL_MODEL_WIDTH
+        else -> if (isVertical) Constants.VERTICAL_SMALL_MODEL_WIDTH else Constants.SMALL_MODEL_WIDTH
       }
 
       img = ImageProcessor.resizeWithPadding(img, desiredWidth, Constants.MODEL_HEIGHT)
@@ -265,5 +271,121 @@ class RecognizerUtils {
 
       return img
     }
+
+    fun cropImageWithBoundingBox(
+      image: Mat,
+      box: List<BBoxPoint>,
+      originalBox: List<BBoxPoint>,
+      paddings: Map<String, Any>,
+      originalPaddings: Map<String, Any>
+    ): Mat {
+      val topLeft = originalBox[0]
+      val points = arrayOfNulls<Point>(4)
+
+      for (i in 0 until 4) {
+        val cords = box[i]
+        cords.x -= paddings["left"]!! as Int
+        cords.y -= paddings["top"]!! as Int
+
+        cords.x *= paddings["resizeRatio"]!! as Float
+        cords.y *= paddings["resizeRatio"]!! as Float
+
+        cords.x += topLeft.x
+        cords.y += topLeft.y
+
+        cords.x -= originalPaddings["left"]!! as Int
+        cords.y -= (originalPaddings["top"]!! as Int)
+
+        cords.x *= originalPaddings["resizeRatio"]!! as Float
+        cords.y *= originalPaddings["resizeRatio"]!! as Float
+
+        points[i] = Point(cords.x, cords.y)
+      }
+
+      val boundingBox = Imgproc.boundingRect(MatOfPoint2f(*points))
+      val croppedImage = Mat(image, boundingBox)
+      Imgproc.cvtColor(croppedImage, croppedImage, Imgproc.COLOR_BGR2GRAY)
+      Imgproc.resize(croppedImage, croppedImage, Size(64.0, 64.0), 0.0, 0.0, Imgproc.INTER_LANCZOS4)
+      Imgproc.medianBlur(croppedImage, croppedImage, 1)
+
+      return croppedImage
+    }
+
+    fun extractBoundingBox(cords: List<BBoxPoint>): Rect {
+      val points = arrayOfNulls<Point>(4)
+
+      for (i in 0 until 4) {
+        points[i] = Point(cords[i].x, cords[i].y)
+      }
+
+      val boundingBox = Imgproc.boundingRect(MatOfPoint2f(*points))
+
+      return boundingBox
+    }
+
+    fun cropSingleCharacter(img: Mat): Mat {
+      val histogram = Mat()
+      val histSize = MatOfInt(256)
+      val range = MatOfFloat(0f, 256f)
+      Imgproc.calcHist(
+        listOf(img),
+        MatOfInt(0),
+        Mat(),
+        histogram,
+        histSize,
+        range
+      )
+
+      val midPoint = 256 / 2
+      var sumLeft = 0.0
+      var sumRight = 0.0
+      for (i in 0 until midPoint) {
+        sumLeft += histogram.get(i, 0)[0]
+      }
+      for (i in midPoint until 256) {
+        sumRight += histogram.get(i, 0)[0]
+      }
+
+      val thresholdType = if (sumLeft < sumRight) Imgproc.THRESH_BINARY_INV else Imgproc.THRESH_BINARY
+
+      val thresh = Mat()
+      Imgproc.threshold(img, thresh, 0.0, 255.0, thresholdType + Imgproc.THRESH_OTSU)
+
+      val labels = Mat()
+      val stats = Mat()
+      val centroids = Mat()
+      val numLabels = Imgproc.connectedComponentsWithStats(thresh, labels, stats, centroids, 8)
+
+      val centralThreshold = 0.3
+      val height = thresh.rows()
+      val width = thresh.cols()
+      val minX = centralThreshold * width
+      val maxX = (1 - centralThreshold) * width
+      val minY = centralThreshold * height
+      val maxY = (1 - centralThreshold) * height
+
+      var selectedComponent = -1
+      for (i in 1 until numLabels) {
+        val area = stats.get(i, Imgproc.CC_STAT_AREA)[0].toInt()
+        val cx = centroids.get(i, 0)[0]
+        val cy = centroids.get(i, 1)[0]
+        if (cx > minX && cx < maxX && cy > minY && cy < maxY && area > Constants.SINGLE_CHARACTER_MIN_SIZE) {
+          if (selectedComponent == -1 || area > stats.get(selectedComponent, Imgproc.CC_STAT_AREA)[0]) {
+            selectedComponent = i
+          }
+        }
+      }
+
+      val mask = Mat.zeros(img.size(), CvType.CV_8UC1)
+      if (selectedComponent != -1) {
+        Core.compare(labels, Scalar(selectedComponent.toDouble()), mask, Core.CMP_EQ)
+      }
+
+      val resultImage = Mat.zeros(img.size(), img.type())
+      img.copyTo(resultImage, mask)
+
+      Core.bitwise_not(resultImage, resultImage)
+      return resultImage
+    }
   }
 }
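The recognizer-selection rule that normalizeForRecognizer implements reduces to the following (illustrative TypeScript; the values mirror Constants.kt):

// Pick the recognizer input width for a cropped text box.
function desiredWidth(cropWidth: number, isVertical: boolean): number {
  if (cropWidth >= 512) return 512; // large recognizer
  if (cropWidth >= 256) return 256; // medium recognizer
  return isVertical ? 64 : 128;     // small recognizer; vertical crops are 64x64
}

console.log(desiredWidth(600, false)); // 512
console.log(desiredWidth(300, false)); // 256
console.log(desiredWidth(100, true));  // 64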
diff --git a/examples/computer-vision/App.tsx b/examples/computer-vision/App.tsx
index 488c61cd56..c79519ca89 100644
--- a/examples/computer-vision/App.tsx
+++ b/examples/computer-vision/App.tsx
@@ -9,12 +9,14 @@ import { View, StyleSheet } from 'react-native';
 import { ClassificationScreen } from './screens/ClassificationScreen';
 import { ObjectDetectionScreen } from './screens/ObjectDetectionScreen';
 import { OCRScreen } from './screens/OCRScreen';
+import { VerticalOCRScreen } from './screens/VerticalOCRScreen';
 
 enum ModelType {
   STYLE_TRANSFER,
   OBJECT_DETECTION,
   CLASSIFICATION,
   OCR,
+  VERTICAL_OCR,
 }
 
 export default function App() {
@@ -50,6 +52,10 @@ export default function App() {
         );
       case ModelType.OCR:
         return <OCRScreen imageUri={imageUri} setImageUri={setImageUri} />;
+      case ModelType.VERTICAL_OCR:
+        return (
+          <VerticalOCRScreen imageUri={imageUri} setImageUri={setImageUri} />
+        );
       default:
         return (
@@ -69,6 +75,7 @@ export default function App() {
           'Object Detection',
           'Classification',
           'OCR',
+          'Vertical OCR',
         ]}
         onValueChange={(_, selectedIndex) => {
           handleModeChange(selectedIndex);
diff --git a/examples/computer-vision/node_modules/react-native-executorch/ios/RnExecutorch/models/ocr/VerticalDetector.h b/examples/computer-vision/node_modules/react-native-executorch/ios/RnExecutorch/models/ocr/VerticalDetector.h
new file mode 100644
index 0000000000..8263ddd4e8
--- /dev/null
+++ b/examples/computer-vision/node_modules/react-native-executorch/ios/RnExecutorch/models/ocr/VerticalDetector.h
@@ -0,0 +1,28 @@
+#import "BaseModel.h"
+#import "RecognitionHandler.h"
+#import "opencv2/opencv.hpp"
+
+constexpr CGFloat textThreshold = 0.4;
+constexpr CGFloat textThresholdVertical = 0.3;
+constexpr CGFloat linkThreshold = 0.4;
+constexpr CGFloat lowTextThreshold = 0.7;
+constexpr CGFloat centerThreshold = 0.5;
+constexpr CGFloat distanceThreshold = 2.0;
+constexpr CGFloat heightThreshold = 2.0;
+constexpr CGFloat restoreRatio = 3.2;
+constexpr CGFloat restoreRatioVertical = 2.0;
+constexpr int minSideThreshold = 15;
+constexpr int maxSideThreshold = 30;
+constexpr int maxWidth = largeModelWidth + (largeModelWidth * 0.15);
+constexpr int minSize = 20;
+
+const cv::Scalar mean(0.485, 0.456, 0.406);
+const cv::Scalar variance(0.229, 0.224, 0.225);
+
+@interface VerticalDetector : BaseModel
+
+- (instancetype)initWithDetectSingleCharacters:(BOOL)detectSingleCharacters;
+- (cv::Size)getModelImageSize;
+- (NSArray *)runModel:(cv::Mat &)input;
+
+@end
diff --git a/examples/computer-vision/node_modules/react-native-executorch/ios/RnExecutorch/models/ocr/VerticalDetector.mm b/examples/computer-vision/node_modules/react-native-executorch/ios/RnExecutorch/models/ocr/VerticalDetector.mm
new file mode 100644
index 0000000000..a2657a00cf
--- /dev/null
+++ b/examples/computer-vision/node_modules/react-native-executorch/ios/RnExecutorch/models/ocr/VerticalDetector.mm
@@ -0,0 +1,118 @@
+#import "VerticalDetector.h"
+#import "../../utils/ImageProcessor.h"
+#import "utils/DetectorUtils.h"
+#import "utils/OCRUtils.h"
+
+/*
+ The model used as the detector is based on the CRAFT (Character Region
+ Awareness for Text Detection) paper. https://arxiv.org/pdf/1904.01941
+ */
+
+@implementation VerticalDetector {
+  cv::Size originalSize;
+  cv::Size modelSize;
+  BOOL detectSingleCharacters;
+}
+
+- (instancetype)initWithDetectSingleCharacters:(BOOL)detectSingleCharacters {
+  self = [super init];
+  if (self) {
+    self->detectSingleCharacters = detectSingleCharacters;
+  }
+  return self;
+}
+
+- (cv::Size)getModelImageSize {
+  if (!modelSize.empty()) {
+    return modelSize;
+  }
+
+  NSArray *inputShape = [module getInputShape:@0];
+  NSNumber *widthNumber = inputShape[inputShape.count - 2];
+  NSNumber *heightNumber = inputShape.lastObject;
+
+  const int height = [heightNumber intValue];
+  const int width = [widthNumber intValue];
+  modelSize = cv::Size(height, width);
+
+  return cv::Size(height, width);
+}
+
+- (NSArray *)preprocess:(cv::Mat &)input {
+  /*
+   The detector accepts an input tensor of shape [1, 3, H, W]. Because
+   resizing strongly affects recognition quality, the image keeps its
+   original aspect ratio and the missing parts are filled with padding.
+   */
+  self->originalSize = cv::Size(input.cols, input.rows);
+  cv::Size modelImageSize = [self getModelImageSize];
+  cv::Mat resizedImage;
+  resizedImage = [OCRUtils resizeWithPadding:input
+                                desiredWidth:modelImageSize.width
+                               desiredHeight:modelImageSize.height];
+  NSArray *modelInput = [ImageProcessor matToNSArray:resizedImage
+                                                mean:mean
+                                            variance:variance];
+  return modelInput;
+}
+
+- (NSArray *)postprocess:(NSArray *)output {
+  /*
+   The output of the model consists of two matrices (heat maps):
+   1. ScoreText (score map) - the probability of a region containing a
+   character.
+   2. ScoreAffinity (affinity map) - the affinity between characters, used to
+   group characters into a single instance (sequence). Both matrices are half
+   the input resolution.
+
+   The result of this step is a list of bounding boxes that contain text.
+   */
+  NSArray *predictions = [output objectAtIndex:0];
+
+  cv::Size modelImageSize = [self getModelImageSize];
+  cv::Mat scoreTextCV, scoreAffinityCV;
+  /*
+   The model output holds the two heat maps at half the size of the input
+   image, which is why the width and height are divided by 2.
+   */
+  [DetectorUtils interleavedArrayToMats:predictions
+                             outputMat1:scoreTextCV
+                             outputMat2:scoreAffinityCV
+                               withSize:cv::Size(modelImageSize.width / 2,
+                                                 modelImageSize.height / 2)];
+  CGFloat txtThreshold = textThreshold;
+  if (!self->detectSingleCharacters) {
+    txtThreshold = textThresholdVertical;
+  }
+  NSArray *bBoxesList = [DetectorUtils
+      getDetBoxesFromTextMapVertical:scoreTextCV
+                         affinityMap:scoreAffinityCV
+                  usingTextThreshold:txtThreshold
+                       linkThreshold:linkThreshold
+               independentCharacters:self->detectSingleCharacters];
+  bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList
+                             usingRestoreRatio:restoreRatioVertical];
+
+  if (self->detectSingleCharacters) {
+    return bBoxesList;
+  }
+
+  bBoxesList = [DetectorUtils groupTextBoxes:bBoxesList
+                             centerThreshold:centerThreshold
+                           distanceThreshold:distanceThreshold
+                             heightThreshold:heightThreshold
+                            minSideThreshold:minSideThreshold
+                            maxSideThreshold:maxSideThreshold
+                                    maxWidth:maxWidth];
+
+  return bBoxesList;
+}
+
+- (NSArray *)runModel:(cv::Mat &)input {
+  NSArray *modelInput = [self preprocess:input];
+  NSArray *modelResult = [self forward:modelInput];
+  NSArray *result = [self postprocess:modelResult];
+  return result;
+}
+
+@end
diff --git a/examples/computer-vision/screens/OCRScreen.tsx b/examples/computer-vision/screens/OCRScreen.tsx
index 9d17118afb..4bfc0bee2f 100644
--- a/examples/computer-vision/screens/OCRScreen.tsx
+++ b/examples/computer-vision/screens/OCRScreen.tsx
@@ -1,7 +1,13 @@
 import Spinner from 'react-native-loading-spinner-overlay';
 import { BottomBar } from '../components/BottomBar';
 import { getImage } from '../utils';
-import { useOCR } from 'react-native-executorch';
+import {
+  DETECTOR_CRAFT_800,
+  RECOGNIZER_EN_CRNN_128,
+  RECOGNIZER_EN_CRNN_256,
+  RECOGNIZER_EN_CRNN_512,
+  useOCR,
+} from 'react-native-executorch';
 import { View, StyleSheet, Image, Text } from 'react-native';
 import { useState } from 'react';
 import ImageWithBboxes2 from '../components/ImageWithOCRBboxes';
@@ -19,16 +25,13 @@ export const OCRScreen = ({
     height: number;
   }>();
   const [detectedText, setDetectedText] = useState('');
+
   const model = useOCR({
-    detectorSource:
-      'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_craft_800.pte',
+    detectorSource: DETECTOR_CRAFT_800,
     recognizerSources: {
-      recognizerLarge:
-        'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_512.pte',
-      recognizerMedium:
-        'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_256.pte',
-      recognizerSmall:
-        'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_128.pte',
+      recognizerLarge: RECOGNIZER_EN_CRNN_512,
+      recognizerMedium: RECOGNIZER_EN_CRNN_256,
+      recognizerSmall: RECOGNIZER_EN_CRNN_128,
     },
     language: 'en',
   });
@@ -63,7 +66,10 @@ export const OCRScreen = ({
   if (!model.isReady) {
     return (
-      <Spinner visible={!model.isReady} textContent={`Loading the model ${(model.downloadProgress * 100).toFixed(0)} %`} />
+      <Spinner
+        visible={!model.isReady}
+        textContent={`Loading the model ${(model.downloadProgress * 100).toFixed(0)} %`}
+      />
     );
   }
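On the JS side both hooks are configured alike, and forward() resolves with an array of detections. A usage sketch (the result-field names follow this diff: iOS emits score, Android emits confidence):

// Inside a component that already called useOCR(...) or useVerticalOCR(...):
const detections = await model.forward(imageUri);
for (const d of detections) {
  // d.text: decoded string; d.bbox: four corner points in original image space
  console.log(d.text, d.bbox);
}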
diff --git a/examples/computer-vision/screens/VerticalOCRScreen.tsx b/examples/computer-vision/screens/VerticalOCRScreen.tsx
new file mode 100644
index 0000000000..e242fb1138
--- /dev/null
+++ b/examples/computer-vision/screens/VerticalOCRScreen.tsx
@@ -0,0 +1,120 @@
+import Spinner from 'react-native-loading-spinner-overlay';
+import { BottomBar } from '../components/BottomBar';
+import { getImage } from '../utils';
+import {
+  DETECTOR_CRAFT_1280,
+  DETECTOR_CRAFT_320,
+  RECOGNIZER_EN_CRNN_512,
+  RECOGNIZER_EN_CRNN_64,
+  useVerticalOCR,
+} from 'react-native-executorch';
+import { View, StyleSheet, Image, Text } from 'react-native';
+import { useState } from 'react';
+import ImageWithBboxes2 from '../components/ImageWithOCRBboxes';
+
+export const VerticalOCRScreen = ({
+  imageUri,
+  setImageUri,
+}: {
+  imageUri: string;
+  setImageUri: (imageUri: string) => void;
+}) => {
+  const [results, setResults] = useState([]);
+  const [imageDimensions, setImageDimensions] = useState<{
+    width: number;
+    height: number;
+  }>();
+  const [detectedText, setDetectedText] = useState('');
+  const model = useVerticalOCR({
+    detectorSources: {
+      detectorLarge: DETECTOR_CRAFT_1280,
+      detectorNarrow: DETECTOR_CRAFT_320,
+    },
+    recognizerSources: {
+      recognizerLarge: RECOGNIZER_EN_CRNN_512,
+      recognizerSmall: RECOGNIZER_EN_CRNN_64,
+    },
+    language: 'en',
+    independentCharacters: true,
+  });
+
+  const handleCameraPress = async (isCamera: boolean) => {
+    const image = await getImage(isCamera);
+    const width = image?.width;
+    const height = image?.height;
+    setImageDimensions({ width: width as number, height: height as number });
+    const uri = image?.uri;
+    if (typeof uri === 'string') {
+      setImageUri(uri as string);
+      setResults([]);
+      setDetectedText('');
+    }
+  };
+
+  const runForward = async () => {
+    try {
+      const output = await model.forward(imageUri);
+      setResults(output);
+      console.log(output);
+      let txt = '';
+      output.forEach((detection: any) => {
+        txt += detection.text + ' ';
+      });
+      setDetectedText(txt);
+    } catch (e) {
+      console.error(e);
+    }
+  };
+
+  if (!model.isReady) {
+    return (
+      <Spinner
+        visible={!model.isReady}
+        textContent={`Loading the model ${(model.downloadProgress * 100).toFixed(0)} %`}
+      />
+    );
+  }
+
+  return (
+    <>
+      <View style={styles.imageContainer}>
+        {imageUri && imageDimensions?.width && imageDimensions?.height ? (
+          <ImageWithBboxes2
+            imageUri={imageUri}
+            detections={results}
+            imageWidth={imageDimensions.width}
+            imageHeight={imageDimensions.height}
+          />
+        ) : (
+          <Image style={styles.image} resizeMode="contain" />
+        )}
+        <Text>{detectedText}</Text>
+      </View>
+      <BottomBar handleCameraPress={handleCameraPress} runForward={runForward} />
+    </>
+  );
+};
+
+const styles = StyleSheet.create({
+  image: {
+    flex: 2,
+    borderRadius: 8,
+    width: '100%',
+  },
+  imageContainer: {
+    flex: 6,
+    width: '100%',
+    padding: 16,
+  },
+});
diff --git a/ios/RnExecutorch/OCR.h b/ios/RnExecutorch/OCR.h
index 68c0878598..4994108bce 100644
--- a/ios/RnExecutorch/OCR.h
+++ b/ios/RnExecutorch/OCR.h
@@ -1,7 +1,5 @@
 #import <RnExecutorchSpec/RnExecutorchSpec.h>
 
-constexpr CGFloat recognizerRatio = 1.6;
-
 @interface OCR : NSObject <NativeOCRSpec>
 
 @end
diff --git a/ios/RnExecutorch/OCR.mm b/ios/RnExecutorch/OCR.mm
index 59740c90bb..509e38765a 100644
--- a/ios/RnExecutorch/OCR.mm
+++ b/ios/RnExecutorch/OCR.mm
@@ -2,6 +2,7 @@
 #import "models/ocr/Detector.h"
 #import "models/ocr/RecognitionHandler.h"
 #import "utils/ImageProcessor.h"
+#import "models/ocr/utils/Constants.h"
@@ -80,6 +81,7 @@ of different sizes (e.g. large - 512x64, medium - 256x64, small - 128x64).
   cv::Mat image = [ImageProcessor readImage:input];
   NSArray *result = [detector runModel:image];
   cv::Size detectorSize = [detector getModelImageSize];
+  const CGFloat recognizerRatio = recognizerImageSize / detectorSize.width;
   cv::cvtColor(image, image, cv::COLOR_BGR2GRAY);
   result = [self->recognitionHandler recognize:result
diff --git a/ios/RnExecutorch/VerticalOCR.h b/ios/RnExecutorch/VerticalOCR.h
new file mode 100644
index 0000000000..5692d37897
--- /dev/null
+++ b/ios/RnExecutorch/VerticalOCR.h
@@ -0,0 +1,5 @@
+#import <RnExecutorchSpec/RnExecutorchSpec.h>
+
+@interface VerticalOCR : NSObject <NativeVerticalOCRSpec>
+
+@end
diff --git a/ios/RnExecutorch/VerticalOCR.mm b/ios/RnExecutorch/VerticalOCR.mm
new file mode 100644
index 0000000000..ef5e58a2bc
--- /dev/null
+++ b/ios/RnExecutorch/VerticalOCR.mm
@@ -0,0 +1,156 @@
+#import "VerticalOCR.h"
+#import "models/ocr/VerticalDetector.h"
+#import "models/ocr/RecognitionHandler.h"
+#import "models/ocr/Recognizer.h"
+#import "models/ocr/utils/RecognizerUtils.h"
+#import "utils/ImageProcessor.h"
+#import <ExecutorchLib/ETModel.h>
+#import <React/RCTBridgeModule.h>
+#import "models/ocr/utils/OCRUtils.h"
+#import "models/ocr/utils/CTCLabelConverter.h"
+
+@implementation VerticalOCR {
+  VerticalDetector *detectorLarge;
+  VerticalDetector *detectorNarrow;
+  Recognizer *recognizer;
+  CTCLabelConverter *converter;
+  BOOL independentCharacters;
+}
+
+RCT_EXPORT_MODULE()
+
+- (void)loadModule:(NSString *)detectorLargeSource
+    detectorNarrowSource:(NSString *)detectorNarrowSource
+        recognizerSource:(NSString *)recognizerSource
+                 symbols:(NSString *)symbols
+   independentCharacters:(BOOL)independentCharacters
+                 resolve:(RCTPromiseResolveBlock)resolve
+                  reject:(RCTPromiseRejectBlock)reject {
+  detectorLarge = [[VerticalDetector alloc] initWithDetectSingleCharacters:NO];
+  converter = [[CTCLabelConverter alloc] initWithCharacters:symbols
+                                              separatorList:@{}];
+  self->independentCharacters = independentCharacters;
+  [detectorLarge
+      loadModel:[NSURL URLWithString:detectorLargeSource]
+     completion:^(BOOL success, NSNumber *errorCode) {
+       if (!success) {
+         reject(@"init_module_error", @"Failed to initialize detector module",
+                nil);
+         return;
+       }
+       self->detectorNarrow =
+           [[VerticalDetector alloc] initWithDetectSingleCharacters:YES];
+       [self->detectorNarrow
+           loadModel:[NSURL URLWithString:detectorNarrowSource]
+          completion:^(BOOL success, NSNumber *errorCode) {
+            if (!success) {
+              reject(@"init_module_error",
+                     @"Failed to initialize detector module", nil);
+              return;
+            }
+
+            self->recognizer = [[Recognizer alloc] init];
+            [self->recognizer
+                loadModel:[NSURL URLWithString:recognizerSource]
+               completion:^(BOOL success, NSNumber *errorCode) {
+                 if (!success) {
+                   reject(@"init_module_error",
+                          @"Failed to initialize recognizer module", nil);
+                   return;
+                 }
+
+                 resolve(@(YES));
+               }];
+          }];
+     }];
+}
+
+- (void)forward:(NSString *)input
+        resolve:(RCTPromiseResolveBlock)resolve
+         reject:(RCTPromiseRejectBlock)reject {
+  @try {
+    cv::Mat image = [ImageProcessor readImage:input];
+    NSArray *result = [detectorLarge runModel:image];
+    cv::Size largeDetectorSize = [detectorLarge getModelImageSize];
+    cv::Mat resizedImage =
+        [OCRUtils resizeWithPadding:image
+                       desiredWidth:largeDetectorSize.width
+                      desiredHeight:largeDetectorSize.height];
+    NSMutableArray *predictions = [NSMutableArray array];
+
+    for (NSDictionary *box in result) {
+      NSArray *cords = box[@"bbox"];
+      const int boxWidth = [[cords objectAtIndex:2] CGPointValue].x -
+                           [[cords objectAtIndex:0] CGPointValue].x;
+      const int boxHeight = [[cords objectAtIndex:2] CGPointValue].y -
+                            [[cords objectAtIndex:0] CGPointValue].y;
+
+      cv::Rect boundingBox = [OCRUtils extractBoundingBox:cords];
+      cv::Mat croppedImage = resizedImage(boundingBox);
+      NSDictionary *paddings =
+          [RecognizerUtils calculateResizeRatioAndPaddings:image.cols
+                                                    height:image.rows
+                                              desiredWidth:largeDetectorSize.width
+                                             desiredHeight:largeDetectorSize.height];
+
+      NSString *text = @"";
+      NSNumber *confidenceScore = @0.0;
+      NSArray *boxResult = [detectorNarrow runModel:croppedImage];
+      std::vector<cv::Mat> croppedCharacters;
+      cv::Size narrowRecognizerSize = [detectorNarrow getModelImageSize];
+      for (NSDictionary *characterBox in boxResult) {
+        NSArray *boxCords = characterBox[@"bbox"];
+        NSDictionary *paddingsBox = [RecognizerUtils
+            calculateResizeRatioAndPaddings:boxWidth
+                                     height:boxHeight
+                               desiredWidth:narrowRecognizerSize.width
+                              desiredHeight:narrowRecognizerSize.height];
+        cv::Mat croppedCharacter =
+            [RecognizerUtils cropImageWithBoundingBox:image
+                                                 bbox:boxCords
+                                         originalBbox:cords
+                                             paddings:paddingsBox
+                                     originalPaddings:paddings];
+        if (self->independentCharacters) {
+          croppedCharacter = [RecognizerUtils cropSingleCharacter:croppedCharacter];
+          croppedCharacter = [RecognizerUtils normalizeForRecognizer:croppedCharacter
+                                                      adjustContrast:0.0
+                                                          isVertical:YES];
+          NSArray *recognitionResult = [recognizer runModel:croppedCharacter];
+          NSArray *predIndex = [recognitionResult objectAtIndex:0];
+          NSArray *decodedText = [converter decodeGreedy:predIndex
+                                                   length:(int)(predIndex.count)];
+          text = [text stringByAppendingString:decodedText[0]];
+          confidenceScore = @([confidenceScore floatValue] +
+                              [[recognitionResult objectAtIndex:1] floatValue]);
+        } else {
+          croppedCharacters.push_back(croppedCharacter);
+        }
+      }
+
+      if (self->independentCharacters) {
+        confidenceScore = @([confidenceScore floatValue] / boxResult.count);
+      } else {
+        cv::Mat mergedCharacters;
+        cv::hconcat(croppedCharacters.data(), (int)croppedCharacters.size(),
+                    mergedCharacters);
+        mergedCharacters = [OCRUtils resizeWithPadding:mergedCharacters
+                                          desiredWidth:largeRecognizerWidth
+                                         desiredHeight:recognizerHeight];
+        mergedCharacters = [RecognizerUtils normalizeForRecognizer:mergedCharacters
+                                                    adjustContrast:0.0
+                                                        isVertical:NO];
+        NSArray *recognitionResult = [recognizer runModel:mergedCharacters];
+        NSArray *predIndex = [recognitionResult objectAtIndex:0];
+        NSArray *decodedText = [converter decodeGreedy:predIndex
+                                                 length:(int)(predIndex.count)];
+        text = [text stringByAppendingString:decodedText[0]];
+        confidenceScore = @([confidenceScore floatValue] +
+                            [[recognitionResult objectAtIndex:1] floatValue]);
+      }
+
+      NSMutableArray *newCoords = [NSMutableArray arrayWithCapacity:4];
+      for (NSValue *cord in cords) {
+        const CGPoint point = [cord CGPointValue];
+
+        [newCoords addObject:@{
+          @"x" : @((point.x - [paddings[@"left"] intValue]) *
+                   [paddings[@"resizeRatio"] floatValue]),
+          @"y" : @((point.y - [paddings[@"top"] intValue]) *
+                   [paddings[@"resizeRatio"] floatValue])
+        }];
+      }
+
+      NSDictionary *res = @{
+        @"text" : text,
+        @"bbox" : newCoords,
+        @"score" : confidenceScore
+      };
+      [predictions addObject:res];
+    }
+
+    resolve(predictions);
+  } @catch (NSException *exception) {
+    reject(@"forward_error",
+           [NSString stringWithFormat:@"%@", exception.reason], nil);
+  }
+}
+
+- (std::shared_ptr<facebook::react::TurboModule>)getTurboModule:
+    (const facebook::react::ObjCTurboModule::InitParams &)params {
+  return std::make_shared<facebook::react::NativeVerticalOCRSpecJSI>(params);
+}
+
+@end
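Both platforms decode recognizer logits with decodeGreedy. The usual greedy CTC rule it implements is sketched below (assumptions: index 0 is the CTC blank and the symbol table is offset by one; the real CTCLabelConverter may differ in details):

function decodeGreedy(indices: number[], symbols: string): string {
  let out = '';
  let prev = -1;
  for (const idx of indices) {
    // Collapse consecutive repeats, then drop blanks.
    if (idx !== prev && idx !== 0) out += symbols[idx - 1];
    prev = idx;
  }
  return out;
}

console.log(decodeGreedy([0, 5, 5, 0, 3, 3], 'abcde')); // "ec"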
diff --git a/ios/RnExecutorch/models/ocr/Detector.h b/ios/RnExecutorch/models/ocr/Detector.h
index 0f67e93b84..1644135931 100644
--- a/ios/RnExecutorch/models/ocr/Detector.h
+++ b/ios/RnExecutorch/models/ocr/Detector.h
@@ -1,21 +1,7 @@
 #import "BaseModel.h"
 #import "RecognitionHandler.h"
 #import "opencv2/opencv.hpp"
-
-constexpr CGFloat textThreshold = 0.4;
-constexpr CGFloat linkThreshold = 0.4;
-constexpr CGFloat lowTextThreshold = 0.7;
-constexpr CGFloat centerThreshold = 0.5;
-constexpr CGFloat distanceThreshold = 2.0;
-constexpr CGFloat heightThreshold = 2.0;
-constexpr CGFloat restoreRatio = 3.2;
-constexpr int minSideThreshold = 15;
-constexpr int maxSideThreshold = 30;
-constexpr int maxWidth = largeModelWidth + (largeModelWidth * 0.15);
-constexpr int minSize = 20;
-
-const cv::Scalar mean(0.485, 0.456, 0.406);
-const cv::Scalar variance(0.229, 0.224, 0.225);
+#import "utils/Constants.h"
 
 @interface Detector : BaseModel
diff --git a/ios/RnExecutorch/models/ocr/Detector.mm b/ios/RnExecutorch/models/ocr/Detector.mm
index 20b82b5ee7..5bec88369b 100644
--- a/ios/RnExecutorch/models/ocr/Detector.mm
+++ b/ios/RnExecutorch/models/ocr/Detector.mm
@@ -19,8 +19,8 @@ @implementation Detector {
   }
 
   NSArray *inputShape = [module getInputShape:@0];
-  NSNumber *widthNumber = inputShape.lastObject;
-  NSNumber *heightNumber = inputShape[inputShape.count - 2];
+  NSNumber *widthNumber = inputShape[inputShape.count - 2];
+  NSNumber *heightNumber = inputShape.lastObject;
 
   const int height = [heightNumber intValue];
   const int width = [widthNumber intValue];
@@ -36,7 +36,6 @@ - (NSArray *)preprocess:(cv::Mat &)input {
    original aspect ratio and the missing parts are filled with padding.
   */
   self->originalSize = cv::Size(input.cols, input.rows);
-  cv::Size modelImageSize = [self getModelImageSize];
 
   cv::Mat resizedImage;
   resizedImage = [OCRUtils resizeWithPadding:input
@@ -79,6 +78,7 @@ group each character into a single instance (sequence) Both matrices are
                              lowTextThreshold:lowTextThreshold];
   bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList
                             usingRestoreRatio:restoreRatio];
+
   bBoxesList = [DetectorUtils groupTextBoxes:bBoxesList
                              centerThreshold:centerThreshold
                            distanceThreshold:distanceThreshold
diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.h b/ios/RnExecutorch/models/ocr/RecognitionHandler.h
index 412504370e..7f674d9892 100644
--- a/ios/RnExecutorch/models/ocr/RecognitionHandler.h
+++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.h
@@ -1,12 +1,5 @@
 #import "opencv2/opencv.hpp"
 
-constexpr int modelHeight = 64;
-constexpr int largeModelWidth = 512;
-constexpr int mediumModelWidth = 256;
-constexpr int smallModelWidth = 128;
-constexpr CGFloat lowConfidenceThreshold = 0.3;
-constexpr CGFloat adjustContrast = 0.2;
-
 @interface RecognitionHandler : NSObject
 
 - (instancetype)initWithSymbols:(NSString *)symbols;
diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm
index 60616b9099..a263297720 100644
--- a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm
+++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm
@@ -3,6 +3,7 @@
 #import "./utils/CTCLabelConverter.h"
 #import "./utils/OCRUtils.h"
 #import "./utils/RecognizerUtils.h"
+#import "./utils/Constants.h"
 #import "ExecutorchLib/ETModel.h"
 #import "Recognizer.h"
@@ -72,9 +73,9 @@ - (void)loadRecognizers:(NSString *)largeRecognizerPath
 - (NSArray *)runModel:(cv::Mat)croppedImage {
   NSArray *result;
-  if (croppedImage.cols >= largeModelWidth) {
+  if (croppedImage.cols >= largeRecognizerWidth) {
     result = [recognizerLarge runModel:croppedImage];
-  } else if (croppedImage.cols >= mediumModelWidth) {
+  } else if (croppedImage.cols >= mediumRecognizerWidth) {
     result = [recognizerMedium runModel:croppedImage];
   } else {
     result = [recognizerSmall runModel:croppedImage];
@@ -103,12 +104,12 @@ - (NSArray *)recognize:(NSArray *)bBoxesList
   for (NSDictionary *box in bBoxesList) {
     cv::Mat croppedImage = [RecognizerUtils getCroppedImage:box
                                                       image:imgGray
-                                                modelHeight:modelHeight];
+                                                modelHeight:recognizerHeight];
     if (croppedImage.empty()) {
       continue;
     }
     croppedImage = [RecognizerUtils normalizeForRecognizer:croppedImage
-                                             adjustContrast:adjustContrast];
+                                             adjustContrast:adjustContrast isVertical:NO];
 
     NSArray *result = [self runModel:croppedImage];
     NSNumber *confidenceScore = [result objectAtIndex:1];
diff --git a/ios/RnExecutorch/models/ocr/Recognizer.mm b/ios/RnExecutorch/models/ocr/Recognizer.mm
index 8b339bc238..e3ee908984 100644
--- a/ios/RnExecutorch/models/ocr/Recognizer.mm
+++ b/ios/RnExecutorch/models/ocr/Recognizer.mm
@@ -14,8 +14,8 @@ @implementation Recognizer {
 
 - (cv::Size)getModelImageSize {
   NSArray *inputShape = [module getInputShape:@0];
-  NSNumber *widthNumber = inputShape.lastObject;
-  NSNumber *heightNumber = inputShape[inputShape.count - 2];
+  NSNumber *widthNumber = inputShape[inputShape.count - 2];
+  NSNumber *heightNumber = inputShape.lastObject;
 
   const int height = [heightNumber intValue];
   const int width = [widthNumber intValue];
@@ -24,8 +24,8 @@ @implementation Recognizer {
 
 - (cv::Size)getModelOutputSize {
   NSArray *outputShape = [module getOutputShape:@0];
-  NSNumber *widthNumber = outputShape.lastObject;
-  NSNumber *heightNumber = outputShape[outputShape.count - 2];
+  NSNumber *widthNumber = outputShape[outputShape.count - 2];
+  NSNumber *heightNumber = outputShape.lastObject;
 
   const int height = [heightNumber intValue];
   const int width = [widthNumber intValue];
diff --git a/ios/RnExecutorch/models/ocr/VerticalDetector.h b/ios/RnExecutorch/models/ocr/VerticalDetector.h
new file mode 100644
index 0000000000..1c1fcd2ec1
--- /dev/null
+++ b/ios/RnExecutorch/models/ocr/VerticalDetector.h
@@ -0,0 +1,12 @@
+#import "BaseModel.h"
+#import "RecognitionHandler.h"
+#import "opencv2/opencv.hpp"
+#import "utils/Constants.h"
+
+@interface VerticalDetector : BaseModel
+
+- (instancetype)initWithDetectSingleCharacters:(BOOL)detectSingleCharacters;
+- (cv::Size)getModelImageSize;
+- (NSArray *)runModel:(cv::Mat &)input;
+
+@end
diff --git a/ios/RnExecutorch/models/ocr/VerticalDetector.mm b/ios/RnExecutorch/models/ocr/VerticalDetector.mm
new file mode 100644
index 0000000000..087604dd5d
--- /dev/null
+++ b/ios/RnExecutorch/models/ocr/VerticalDetector.mm
@@ -0,0 +1,117 @@
+#import "VerticalDetector.h"
+#import "../../utils/ImageProcessor.h"
+#import "utils/DetectorUtils.h"
+#import "utils/OCRUtils.h"
+
+/*
+ The model used as the detector is based on the CRAFT (Character Region
+ Awareness for Text Detection) paper. https://arxiv.org/pdf/1904.01941
+ */
+
+@implementation VerticalDetector {
+  cv::Size originalSize;
+  cv::Size modelSize;
+  BOOL detectSingleCharacters;
+}
+
+- (instancetype)initWithDetectSingleCharacters:(BOOL)detectSingleCharacters {
+  self = [super init];
+  if (self) {
+    self->detectSingleCharacters = detectSingleCharacters;
+  }
+  return self;
+}
+
+- (cv::Size)getModelImageSize {
+  if (!modelSize.empty()) {
+    return modelSize;
+  }
+
+  NSArray *inputShape = [module getInputShape:@0];
+  NSNumber *widthNumber = inputShape[inputShape.count - 2];
+  NSNumber *heightNumber = inputShape.lastObject;
+
+  const int height = [heightNumber intValue];
+  const int width = [widthNumber intValue];
+  modelSize = cv::Size(height, width);
+
+  return cv::Size(height, width);
+}
+
+- (NSArray *)preprocess:(cv::Mat &)input {
+  /*
+   The detector accepts an input tensor of shape [1, 3, H, W]. Because
+   resizing strongly affects recognition quality, the image keeps its
+   original aspect ratio and the missing parts are filled with padding.
+   */
+  self->originalSize = cv::Size(input.cols, input.rows);
+  cv::Size modelImageSize = [self getModelImageSize];
+  cv::Mat resizedImage;
+  resizedImage = [OCRUtils resizeWithPadding:input
+                                desiredWidth:modelImageSize.width
+                               desiredHeight:modelImageSize.height];
+  NSArray *modelInput = [ImageProcessor matToNSArray:resizedImage
+                                                mean:mean
+                                            variance:variance];
+  return modelInput;
+}
+
+- (NSArray *)postprocess:(NSArray *)output {
+  /*
+   The output of the model consists of two matrices (heat maps):
+   1. ScoreText (score map) - the probability of a region containing a
+   character.
+   2. ScoreAffinity (affinity map) - the affinity between characters, used to
+   group characters into a single instance (sequence). Both matrices are half
+   the input resolution.
+
+   The result of this step is a list of bounding boxes that contain text.
+   */
+  NSArray *predictions = [output objectAtIndex:0];
+
+  cv::Size modelImageSize = [self getModelImageSize];
+  cv::Mat scoreTextCV, scoreAffinityCV;
+  /*
+   The model output holds the two heat maps at half the size of the input
+   image, which is why the width and height are divided by 2.
+   */
+  [DetectorUtils interleavedArrayToMats:predictions
+                             outputMat1:scoreTextCV
+                             outputMat2:scoreAffinityCV
+                               withSize:cv::Size(modelImageSize.width / 2,
+                                                 modelImageSize.height / 2)];
+  CGFloat txtThreshold = (self->detectSingleCharacters) ? textThreshold
+                                                        : textThresholdVertical;
+
+  NSArray *bBoxesList = [DetectorUtils
+      getDetBoxesFromTextMapVertical:scoreTextCV
+                         affinityMap:scoreAffinityCV
+                  usingTextThreshold:txtThreshold
+                       linkThreshold:linkThreshold
+               independentCharacters:self->detectSingleCharacters];
+  bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList
+                             usingRestoreRatio:restoreRatioVertical];
+
+  if (self->detectSingleCharacters) {
+    return bBoxesList;
+  }
+
+  bBoxesList = [DetectorUtils groupTextBoxes:bBoxesList
+                             centerThreshold:centerThreshold
+                           distanceThreshold:distanceThreshold
+                             heightThreshold:heightThreshold
+                            minSideThreshold:minSideThreshold
+                            maxSideThreshold:maxSideThreshold
+                                    maxWidth:maxWidth];
+
+  return bBoxesList;
+}
+
+- (NSArray *)runModel:(cv::Mat &)input {
+  NSArray *modelInput = [self preprocess:input];
+  NSArray *modelResult = [self forward:modelInput];
+  NSArray *result = [self postprocess:modelResult];
+  return result;
+}
+
+@end
diff --git a/ios/RnExecutorch/models/ocr/utils/Constants.h b/ios/RnExecutorch/models/ocr/utils/Constants.h
new file mode 100644
index 0000000000..ba1e162227
--- /dev/null
+++ b/ios/RnExecutorch/models/ocr/utils/Constants.h
@@ -0,0 +1,26 @@
+constexpr CGFloat textThreshold = 0.4;
+constexpr CGFloat textThresholdVertical = 0.3;
+constexpr CGFloat linkThreshold = 0.4;
+constexpr CGFloat lowTextThreshold = 0.7;
+constexpr CGFloat centerThreshold = 0.5;
+constexpr CGFloat distanceThreshold = 2.0;
+constexpr CGFloat heightThreshold = 2.0;
+constexpr CGFloat restoreRatio = 3.2;
+constexpr CGFloat restoreRatioVertical = 2.0;
+constexpr CGFloat singleCharacterCenterThreshold = 0.3;
+constexpr CGFloat lowConfidenceThreshold = 0.3;
+constexpr CGFloat adjustContrast = 0.2;
+constexpr int minSideThreshold = 15;
+constexpr int maxSideThreshold = 30;
+constexpr int recognizerHeight = 64;
+constexpr int largeRecognizerWidth = 512;
+constexpr int mediumRecognizerWidth = 256;
+constexpr int smallRecognizerWidth = 128;
+constexpr int smallVerticalRecognizerWidth = 64;
+constexpr int maxWidth = largeRecognizerWidth + (largeRecognizerWidth * 0.15);
+constexpr int minSize = 20;
+constexpr int singleCharacterMinSize = 70;
+constexpr int recognizerImageSize = 1280;
+
+const cv::Scalar mean(0.485, 0.456, 0.406);
+const cv::Scalar variance(0.229, 0.224, 0.225);
\ No newline at end of file
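With recognizerImageSize in Constants.h, the ratio that OCR.h used to hard-code as 1.6 is now derived from the detector's input width (see the OCR.mm hunk above). Worked out in TypeScript:

const recognizerImageSize = 1280;
console.log(recognizerImageSize / 800);  // 1.6 -> the old constant, for the 800x800 CRAFT detector
console.log(recognizerImageSize / 1280); // 1.0 -> for the 1280x1280 vertical detector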
heightThreshold:(CGFloat)heightThreshold diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm index 8ee7424d00..0bdd6a766a 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm @@ -22,6 +22,98 @@ + (void)interleavedArrayToMats:(NSArray *)array } } ++ (NSArray *)getDetBoxesFromTextMapVertical:(cv::Mat)textMap + affinityMap:(cv::Mat)affinityMap + usingTextThreshold:(CGFloat)textThreshold + linkThreshold:(CGFloat)linkThreshold + independentCharacters:(BOOL)independentCharacters { + const int imgH = textMap.rows; + const int imgW = textMap.cols; + cv::Mat textScore; + cv::Mat affinityScore; + cv::threshold(textMap, textScore, textThreshold, 1, cv::THRESH_BINARY); + cv::threshold(affinityMap, affinityScore, linkThreshold, 1, + cv::THRESH_BINARY); + cv::Mat textScoreComb; + if (independentCharacters) { + textScoreComb = textScore - affinityScore; + cv::threshold(textScoreComb, textScoreComb, 0.0, 0, cv::THRESH_TOZERO); + cv::threshold(textScoreComb, textScoreComb, 1.0, 1.0, cv::THRESH_TRUNC); + cv::erode(textScoreComb, textScoreComb, + cv::getStructuringElement(cv::MORPH_RECT, cv::Size(3, 3)), + cv::Point(-1, -1), 1); + cv::dilate(textScoreComb, textScoreComb, + cv::getStructuringElement(cv::MORPH_RECT, cv::Size(3, 3)), + cv::Point(-1, -1), 4); + } else { + textScoreComb = textScore + affinityScore; + cv::threshold(textScoreComb, textScoreComb, 0.0, 0, cv::THRESH_TOZERO); + cv::threshold(textScoreComb, textScoreComb, 1.0, 1.0, cv::THRESH_TRUNC); + cv::dilate(textScoreComb, textScoreComb, + cv::getStructuringElement(cv::MORPH_RECT, cv::Size(3, 3)), + cv::Point(-1, -1), 2); + } + + cv::Mat binaryMat; + textScoreComb.convertTo(binaryMat, CV_8UC1); + + cv::Mat labels, stats, centroids; + const int nLabels = + cv::connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4); + + NSMutableArray *detectedBoxes = [NSMutableArray array]; + for (int i = 1; i < nLabels; i++) { + const int area = stats.at(i, cv::CC_STAT_AREA); + if (area < 20) + continue; + const int width = stats.at(i, cv::CC_STAT_WIDTH); + const int height = stats.at(i, cv::CC_STAT_HEIGHT); + + if (!independentCharacters && height < width) + continue; + + cv::Mat mask = (labels == i); + + cv::Mat segMap = cv::Mat::zeros(textMap.size(), CV_8U); + segMap.setTo(255, mask); + + const int x = stats.at(i, cv::CC_STAT_LEFT); + const int y = stats.at(i, cv::CC_STAT_TOP); + const int w = stats.at(i, cv::CC_STAT_WIDTH); + const int h = stats.at(i, cv::CC_STAT_HEIGHT); + const int dilationRadius = (int)(sqrt((double)(area / MAX(w, h))) * 2.0); + const int sx = MAX(x - dilationRadius, 0); + const int ex = MIN(x + w + dilationRadius + 1, imgW); + const int sy = MAX(y - dilationRadius, 0); + const int ey = MIN(y + h + dilationRadius + 1, imgH); + + cv::Rect roi(sx, sy, ex - sx, ey - sy); + cv::Mat kernel = cv::getStructuringElement( + cv::MORPH_RECT, cv::Size(1 + dilationRadius, 1 + dilationRadius)); + cv::Mat roiSegMap = segMap(roi); + cv::dilate(roiSegMap, roiSegMap, kernel, cv::Point(-1, -1), 2); + + std::vector> contours; + cv::findContours(segMap, contours, cv::RETR_EXTERNAL, + cv::CHAIN_APPROX_SIMPLE); + if (!contours.empty()) { + cv::RotatedRect minRect = cv::minAreaRect(contours[0]); + cv::Point2f vertices[4]; + minRect.points(vertices); + NSMutableArray *pointsArray = [NSMutableArray arrayWithCapacity:4]; + for (int j = 0; j < 4; j++) { + const CGPoint point = CGPointMake(vertices[j].x, vertices[j].y); + 
[pointsArray addObject:[NSValue valueWithCGPoint:point]]; + } + NSDictionary *dict = + @{@"bbox" : pointsArray, @"angle" : @(minRect.angle)}; + [detectedBoxes addObject:dict]; + } + } + + return detectedBoxes; +} + /** * This method applies a series of image processing operations to identify * likely areas of text in the textMap and return the bounding boxes for single @@ -545,7 +637,7 @@ + (NSDictionary *)findClosestBox:(NSArray *)boxes * criteria. * 4. Sort the final array of boxes by their vertical positions. */ -+ (NSArray *)groupTextBoxes: ++ (NSArray *)groupTextBoxes: (NSMutableArray *)boxes centerThreshold:(CGFloat)centerThreshold distanceThreshold:(CGFloat)distanceThreshold @@ -635,7 +727,7 @@ + (NSDictionary *)findClosestBox:(NSArray *)boxes usingMinSideThreshold:minSideThreshold maxSideThreshold:maxSideThreshold]; - NSArray *sortedBoxes = [mergedArray + NSArray *sortedBoxes = [mergedArray sortedArrayUsingComparator:^NSComparisonResult(NSDictionary *obj1, NSDictionary *obj2) { NSArray *coords1 = obj1[@"bbox"]; @@ -646,8 +738,17 @@ + (NSDictionary *)findClosestBox:(NSArray *)boxes : (minY1 > minY2) ? NSOrderedDescending : NSOrderedSame; }]; - - return sortedBoxes; + + NSMutableArray *orderedSortedBoxes = [[NSMutableArray alloc] initWithCapacity:[sortedBoxes count]]; + for (NSDictionary *dict in sortedBoxes) { + NSMutableDictionary *mutableDict = [dict mutableCopy]; + NSArray *originalBBox = mutableDict[@"bbox"]; + NSArray *orderedBBox = [self orderPointsClockwise:originalBBox]; + mutableDict[@"bbox"] = orderedBBox; + [orderedSortedBoxes addObject:mutableDict]; + } + + return orderedSortedBoxes; } @end diff --git a/ios/RnExecutorch/models/ocr/utils/OCRUtils.h b/ios/RnExecutorch/models/ocr/utils/OCRUtils.h index dca8b9bba5..90a8fa7a43 100644 --- a/ios/RnExecutorch/models/ocr/utils/OCRUtils.h +++ b/ios/RnExecutorch/models/ocr/utils/OCRUtils.h @@ -5,5 +5,6 @@ + (cv::Mat)resizeWithPadding:(cv::Mat)img desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight; ++ (cv::Rect)extractBoundingBox:(NSArray *)coords; @end diff --git a/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm index f530dac2da..eed17a1520 100644 --- a/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm @@ -1,4 +1,5 @@ #import "OCRUtils.h" +#import "RecognizerUtils.h" @implementation OCRUtils @@ -52,4 +53,16 @@ @implementation OCRUtils return centeredImg; } ++ (cv::Rect)extractBoundingBox:(NSArray *)coords { + std::vector<cv::Point2f> points; + points.reserve(coords.count); + for (NSValue *value in coords) { + const CGPoint point = [value CGPointValue]; + + points.emplace_back(point.x, point.y); + } + + return cv::boundingRect(points); +} + @end diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h index 7af748f58c..51d93638a3 100644 --- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h @@ -8,7 +8,8 @@ height:(int)height modelHeight:(int)modelHeight; + (cv::Mat)normalizeForRecognizer:(cv::Mat)image - adjustContrast:(double)adjustContrast; + adjustContrast:(double)adjustContrast + isVertical:(BOOL)isVertical; + (cv::Mat)adjustContrastGrey:(cv::Mat)img target:(double)target; + (cv::Mat)divideMatrix:(cv::Mat)matrix byVector:(NSArray *)vector; + (cv::Mat)softmax:(cv::Mat)inputs; @@ -24,5 +25,11 @@ + (NSArray *)findMaxValuesAndIndices:(cv::Mat)probabilities; + (double)computeConfidenceScore:(NSArray *)valuesArray 
indicesArray:(NSArray *)indicesArray; ++ (cv::Mat)cropImageWithBoundingBox:(cv::Mat &)img + bbox:(NSArray *)bbox + originalBbox:(NSArray *)originalBbox + paddings:(NSDictionary *)paddings + originalPaddings:(NSDictionary *)originalPaddings; ++ (cv::Mat)cropSingleCharacter:(cv::Mat)img; @end diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm index 65c088b361..1908ad6f99 100644 --- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm @@ -1,5 +1,6 @@ #import "RecognizerUtils.h" #import "OCRUtils.h" +#import "Constants.h" @implementation RecognizerUtils @@ -56,21 +57,23 @@ + (CGFloat)calculateRatio:(int)width height:(int)height { } + (cv::Mat)normalizeForRecognizer:(cv::Mat)image - adjustContrast:(double)adjustContrast { + adjustContrast:(double)adjustContrast + isVertical:(BOOL)isVertical { if (adjustContrast > 0) { image = [self adjustContrastGrey:image target:adjustContrast]; } - int desiredWidth = 128; - if (image.cols >= 512) { - desiredWidth = 512; - } else if (image.cols >= 256) { - desiredWidth = 256; + int desiredWidth = (isVertical) ? smallVerticalRecognizerWidth : smallRecognizerWidth; + + if (image.cols >= largeRecognizerWidth) { + desiredWidth = largeRecognizerWidth; + } else if (image.cols >= mediumRecognizerWidth) { + desiredWidth = mediumRecognizerWidth; } image = [OCRUtils resizeWithPadding:image desiredWidth:desiredWidth - desiredHeight:64]; + desiredHeight:recognizerHeight]; image.convertTo(image, CV_32F, 1.0 / 255.0); image = (image - 0.5) * 2.0; @@ -220,4 +223,105 @@ + (double)computeConfidenceScore:(NSArray *)valuesArray return pow(product, 2.0 / sqrt(predsMaxProb.count)); } ++ (cv::Mat)cropImageWithBoundingBox:(cv::Mat &)img + bbox:(NSArray *)bbox + originalBbox:(NSArray *)originalBbox + paddings:(NSDictionary *)paddings + originalPaddings:(NSDictionary *)originalPaddings { + CGPoint topLeft = [originalBbox[0] CGPointValue]; + std::vector<cv::Point2f> points; + points.reserve(bbox.count); + for (NSValue *coords in bbox) { + CGPoint point = [coords CGPointValue]; + + point.x = point.x - [paddings[@"left"] intValue]; + point.y = point.y - [paddings[@"top"] intValue]; + + point.x = point.x * [paddings[@"resizeRatio"] floatValue]; + point.y = point.y * [paddings[@"resizeRatio"] floatValue]; + + point.x = point.x + topLeft.x; + point.y = point.y + topLeft.y; + + point.x = point.x - [originalPaddings[@"left"] intValue]; + point.y = point.y - [originalPaddings[@"top"] intValue]; + + point.x = point.x * [originalPaddings[@"resizeRatio"] floatValue]; + point.y = point.y * [originalPaddings[@"resizeRatio"] floatValue]; + + points.emplace_back(cv::Point2f(point.x, point.y)); + } + + cv::Rect rect = cv::boundingRect(points); + cv::Mat croppedImage = img(rect); + return croppedImage; +} + ++ (cv::Mat)cropSingleCharacter:(cv::Mat)img { + cv::cvtColor(img, img, cv::COLOR_BGR2GRAY); + cv::resize(img, img, cv::Size(smallVerticalRecognizerWidth, recognizerHeight), 0, 0, + cv::INTER_AREA); + cv::medianBlur(img, img, 1); + + cv::Mat histogram; + + int histSize = 256; + float range[] = {0, 256}; + const float *histRange = {range}; + bool uniform = true, accumulate = false; + + cv::calcHist(&img, 1, 0, cv::Mat(), histogram, 1, &histSize, &histRange, uniform, + accumulate); + + int midPoint = histSize / 2; + + double sumLeft = 0.0, sumRight = 0.0; + for (int i = 0; i < midPoint; i++) { + sumLeft += histogram.at<float>(i); + } + for (int i = midPoint; i < histSize; i++) { + sumRight += 
histogram.at<float>(i); + } + + const int thresholdType = (sumLeft < sumRight) ? cv::THRESH_BINARY_INV : cv::THRESH_BINARY; + + cv::Mat thresh; + cv::threshold(img, thresh, 0, 255, thresholdType + cv::THRESH_OTSU); + + cv::Mat labels, stats, centroids; + const int numLabels = + connectedComponentsWithStats(thresh, labels, stats, centroids, 8); + const CGFloat centralThreshold = singleCharacterCenterThreshold; + const int height = thresh.rows; + const int width = thresh.cols; + + const int minX = centralThreshold * width; + const int maxX = (1 - centralThreshold) * width; + const int minY = centralThreshold * height; + const int maxY = (1 - centralThreshold) * height; + + int selectedComponent = -1; + + for (int i = 1; i < numLabels; i++) { + const int area = stats.at<int>(i, cv::CC_STAT_AREA); + const double cx = centroids.at<double>(i, 0); + const double cy = centroids.at<double>(i, 1); + + if (minX < cx && cx < maxX && minY < cy && cy < maxY && area > singleCharacterMinSize) { + if (selectedComponent == -1 || + area > stats.at<int>(selectedComponent, cv::CC_STAT_AREA)) { + selectedComponent = i; + } + } + } + cv::Mat mask = cv::Mat::zeros(img.size(), CV_8UC1); + if (selectedComponent != -1) { + mask = (labels == selectedComponent) / 255; + } + cv::Mat resultImage = cv::Mat::zeros(img.size(), img.type()); + img.copyTo(resultImage, mask); + cv::bitwise_not(resultImage, resultImage); + return resultImage; +} + @end diff --git a/src/constants/modelUrls.ts b/src/constants/modelUrls.ts index 2f57331c44..30e3847909 100644 --- a/src/constants/modelUrls.ts +++ b/src/constants/modelUrls.ts @@ -46,6 +46,24 @@ export const STYLE_TRANSFER_UDNIE = ? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/v0.2.0/coreml/style_transfer_udnie_coreml.pte' : 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/v0.2.0/xnnpack/style_transfer_udnie_xnnpack.pte'; +// OCR + +export const DETECTOR_CRAFT_1280 = + 'https://huggingface.co/software-mansion/react-native-executorch-detector-craft/resolve/v0.3.0/xnnpack/xnnpack_craft_1280.pte'; +export const DETECTOR_CRAFT_800 = + 'https://huggingface.co/software-mansion/react-native-executorch-detector-craft/resolve/v0.3.0/xnnpack/xnnpack_craft_800.pte'; +export const DETECTOR_CRAFT_320 = + 'https://huggingface.co/software-mansion/react-native-executorch-detector-craft/resolve/v0.3.0/xnnpack/xnnpack_craft_320.pte'; + +export const RECOGNIZER_EN_CRNN_512 = + 'https://huggingface.co/software-mansion/react-native-executorch-recognizer-crnn.en/resolve/v0.3.0/xnnpack/xnnpack_crnn_en_512.pte'; +export const RECOGNIZER_EN_CRNN_256 = + 'https://huggingface.co/software-mansion/react-native-executorch-recognizer-crnn.en/resolve/v0.3.0/xnnpack/xnnpack_crnn_en_256.pte'; +export const RECOGNIZER_EN_CRNN_128 = + 'https://huggingface.co/software-mansion/react-native-executorch-recognizer-crnn.en/resolve/v0.3.0/xnnpack/xnnpack_crnn_en_128.pte'; +export const RECOGNIZER_EN_CRNN_64 = + 'https://huggingface.co/software-mansion/react-native-executorch-recognizer-crnn.en/resolve/v0.3.0/xnnpack/xnnpack_crnn_en_64.pte'; + // Backward compatibility export const LLAMA3_2_3B_URL = LLAMA3_2_3B; export const LLAMA3_2_3B_QLORA_URL = LLAMA3_2_3B_QLORA; diff --git a/src/constants/ocr/languageDicts.ts b/src/constants/ocr/languageDicts.ts deleted file mode 100644 index fcd189b53c..0000000000 --- a/src/constants/ocr/languageDicts.ts +++ /dev/null @@ -1,4 +0,0 @@ -export const languageDicts: { [key: string]: string } = { - en: 
'https://huggingface.co/nklockiewicz/ocr/resolve/main/en.txt', - pl: 'https://huggingface.co/nklockiewicz/ocr/resolve/main/pl.txt', -}; diff --git a/src/controllers/OCRController.ts b/src/controllers/OCRController.ts new file mode 100644 index 0000000000..a6cf1a5da1 --- /dev/null +++ b/src/controllers/OCRController.ts @@ -0,0 +1,111 @@ +import { symbols } from '../constants/ocr/symbols'; +import { ETError, getError } from '../Error'; +import { _OCRModule } from '../native/RnExecutorchModules'; +import { ResourceSource } from '../types/common'; +import { OCRLanguage } from '../types/ocr'; +import { + fetchResource, + calculateDownloadProgres, +} from '../utils/fetchResource'; + +export class OCRController { + private nativeModule: _OCRModule; + public isReady: boolean = false; + public isGenerating: boolean = false; + public error: string | null = null; + private modelDownloadProgressCallback: (downloadProgress: number) => void; + private isReadyCallback: (isReady: boolean) => void; + private isGeneratingCallback: (isGenerating: boolean) => void; + private errorCallback: (error: string) => void; + + constructor({ + modelDownloadProgressCallback = (_downloadProgress: number) => {}, + isReadyCallback = (_isReady: boolean) => {}, + isGeneratingCallback = (_isGenerating: boolean) => {}, + errorCallback = (_error: string) => {}, + }) { + this.nativeModule = new _OCRModule(); + this.modelDownloadProgressCallback = modelDownloadProgressCallback; + this.isReadyCallback = isReadyCallback; + this.isGeneratingCallback = isGeneratingCallback; + this.errorCallback = errorCallback; + } + + public loadModel = async ( + detectorSource: ResourceSource, + recognizerSources: { + recognizerLarge: ResourceSource; + recognizerMedium: ResourceSource; + recognizerSmall: ResourceSource; + }, + language: OCRLanguage + ) => { + try { + if (!detectorSource || Object.keys(recognizerSources).length !== 3) + return; + + if (!symbols[language]) { + throw new Error(getError(ETError.LanguageNotSupported)); + } + this.isReady = false; + this.isReadyCallback(false); + + const detectorPath = await fetchResource( + detectorSource, + calculateDownloadProgres(4, 0, this.modelDownloadProgressCallback) + ); + + const recognizerPaths = { + recognizerLarge: await fetchResource( + recognizerSources.recognizerLarge, + calculateDownloadProgres(4, 1, this.modelDownloadProgressCallback) + ), + recognizerMedium: await fetchResource( + recognizerSources.recognizerMedium, + calculateDownloadProgres(4, 2, this.modelDownloadProgressCallback) + ), + recognizerSmall: await fetchResource( + recognizerSources.recognizerSmall, + calculateDownloadProgres(4, 3, this.modelDownloadProgressCallback) + ), + }; + + await this.nativeModule.loadModule( + detectorPath, + recognizerPaths.recognizerLarge, + recognizerPaths.recognizerMedium, + recognizerPaths.recognizerSmall, + symbols[language] + ); + + this.isReady = true; + this.isReadyCallback(this.isReady); + } catch (e) { + if (this.errorCallback) { + this.errorCallback(getError(e)); + } else { + throw new Error(getError(e)); + } + } + }; + + public forward = async (input: string) => { + if (!this.isReady) { + throw new Error(getError(ETError.ModuleNotLoaded)); + } + if (this.isGenerating) { + throw new Error(getError(ETError.ModelGenerating)); + } + + try { + this.isGenerating = true; + this.isGeneratingCallback(this.isGenerating); + return await this.nativeModule.forward(input); + } catch (e) { + throw new Error(getError(e)); + } finally { + this.isGenerating = false; + 
this.isGeneratingCallback(this.isGenerating); + } + }; +} diff --git a/src/controllers/VerticalOCRController.ts b/src/controllers/VerticalOCRController.ts new file mode 100644 index 0000000000..f09e70a7e8 --- /dev/null +++ b/src/controllers/VerticalOCRController.ts @@ -0,0 +1,119 @@ +import { symbols } from '../constants/ocr/symbols'; +import { ETError, getError } from '../Error'; +import { _VerticalOCRModule } from '../native/RnExecutorchModules'; +import { ResourceSource } from '../types/common'; +import { OCRLanguage } from '../types/ocr'; +import { + fetchResource, + calculateDownloadProgres, +} from '../utils/fetchResource'; + +export class VerticalOCRController { + private nativeModule: _VerticalOCRModule; + public isReady: boolean = false; + public isGenerating: boolean = false; + public error: string | null = null; + private modelDownloadProgressCallback: (downloadProgress: number) => void; + private isReadyCallback: (isReady: boolean) => void; + private isGeneratingCallback: (isGenerating: boolean) => void; + private errorCallback: (error: string) => void; + + constructor({ + modelDownloadProgressCallback = (_downloadProgress: number) => {}, + isReadyCallback = (_isReady: boolean) => {}, + isGeneratingCallback = (_isGenerating: boolean) => {}, + errorCallback = (_error: string) => {}, + }) { + this.nativeModule = new _VerticalOCRModule(); + this.modelDownloadProgressCallback = modelDownloadProgressCallback; + this.isReadyCallback = isReadyCallback; + this.isGeneratingCallback = isGeneratingCallback; + this.errorCallback = errorCallback; + } + + public loadModel = async ( + detectorSources: { + detectorLarge: ResourceSource; + detectorNarrow: ResourceSource; + }, + recognizerSources: { + recognizerLarge: ResourceSource; + recognizerSmall: ResourceSource; + }, + language: OCRLanguage, + independentCharacters: boolean + ) => { + try { + if ( + Object.keys(detectorSources).length !== 2 || + Object.keys(recognizerSources).length !== 2 + ) + return; + + if (!symbols[language]) { + throw new Error(getError(ETError.LanguageNotSupported)); + } + + this.isReady = false; + this.isReadyCallback(this.isReady); + + const recognizerPath = independentCharacters + ? 
await fetchResource( + recognizerSources.recognizerSmall, + calculateDownloadProgres(3, 0, this.modelDownloadProgressCallback) + ) + : await fetchResource( + recognizerSources.recognizerLarge, + calculateDownloadProgres(3, 0, this.modelDownloadProgressCallback) + ); + + const detectorPaths = { + detectorLarge: await fetchResource( + detectorSources.detectorLarge, + calculateDownloadProgres(3, 1, this.modelDownloadProgressCallback) + ), + detectorNarrow: await fetchResource( + detectorSources.detectorNarrow, + calculateDownloadProgres(3, 2, this.modelDownloadProgressCallback) + ), + }; + + await this.nativeModule.loadModule( + detectorPaths.detectorLarge, + detectorPaths.detectorNarrow, + recognizerPath, + symbols[language], + independentCharacters + ); + + this.isReady = true; + this.isReadyCallback(this.isReady); + } catch (e) { + if (this.errorCallback) { + this.errorCallback(getError(e)); + } else { + throw new Error(getError(e)); + } + } + }; + + public forward = async (input: string) => { + if (!this.isReady) { + throw new Error(getError(ETError.ModuleNotLoaded)); + } + if (this.isGenerating) { + throw new Error(getError(ETError.ModelGenerating)); + } + + try { + this.isGenerating = true; + this.isGeneratingCallback(this.isGenerating); + return await this.nativeModule.forward(input); + } catch (e) { + throw new Error(getError(e)); + } finally { + this.isGenerating = false; + this.isGeneratingCallback(this.isGenerating); + } + }; +} diff --git a/src/hooks/computer_vision/useOCR.ts b/src/hooks/computer_vision/useOCR.ts index 56ee04e412..faa52a9cb5 100644 --- a/src/hooks/computer_vision/useOCR.ts +++ b/src/hooks/computer_vision/useOCR.ts @@ -1,11 +1,7 @@ import { useEffect, useState } from 'react'; -import { fetchResource } from '../../utils/fetchResource'; -import { languageDicts } from '../../constants/ocr/languageDicts'; -import { symbols } from '../../constants/ocr/symbols'; -import { getError, ETError } from '../../Error'; -import { OCR } from '../../native/RnExecutorchModules'; import { ResourceSource } from '../../types/common'; -import { OCRDetection } from '../../types/ocr'; +import { OCRDetection, OCRLanguage } from '../../types/ocr'; +import { OCRController } from '../../controllers/OCRController'; interface OCRModule { error: string | null; @@ -26,84 +22,37 @@ export const useOCR = ({ recognizerMedium: ResourceSource; recognizerSmall: ResourceSource; }; - language?: string; + language?: OCRLanguage; }): OCRModule => { const [error, setError] = useState<string | null>(null); const [isReady, setIsReady] = useState(false); const [isGenerating, setIsGenerating] = useState(false); const [downloadProgress, setDownloadProgress] = useState(0); + const [model, _] = useState( + () => + new OCRController({ + modelDownloadProgressCallback: setDownloadProgress, + isReadyCallback: setIsReady, + isGeneratingCallback: setIsGenerating, + errorCallback: setError, + }) + ); + useEffect(() => { const loadModel = async () => { - try { - if (!detectorSource || Object.keys(recognizerSources).length === 0) - return; - - const recognizerPaths = {} as { - recognizerLarge: string; - recognizerMedium: string; - recognizerSmall: string; - }; - - if (!symbols[language] || !languageDicts[language]) { - setError(getError(ETError.LanguageNotSupported)); - return; - } - - const detectorPath = await fetchResource(detectorSource); - - await Promise.all([ - fetchResource(recognizerSources.recognizerLarge, setDownloadProgress), - fetchResource(recognizerSources.recognizerMedium), - 
fetchResource(recognizerSources.recognizerSmall), - ]).then((values) => { - recognizerPaths.recognizerLarge = values[0]; - recognizerPaths.recognizerMedium = values[1]; - recognizerPaths.recognizerSmall = values[2]; - }); - - setIsReady(false); - await OCR.loadModule( - detectorPath, - recognizerPaths.recognizerLarge, - recognizerPaths.recognizerMedium, - recognizerPaths.recognizerSmall, - symbols[language] - ); - setIsReady(true); - } catch (e) { - setError(getError(e)); - } + await model.loadModel(detectorSource, recognizerSources, language); }; loadModel(); // eslint-disable-next-line react-hooks/exhaustive-deps }, [detectorSource, language, JSON.stringify(recognizerSources)]); - const forward = async (input: string) => { - if (!isReady) { - throw new Error(getError(ETError.ModuleNotLoaded)); - } - if (isGenerating) { - throw new Error(getError(ETError.ModelGenerating)); - } - - try { - setIsGenerating(true); - const output = await OCR.forward(input); - return output; - } catch (e) { - throw new Error(getError(e)); - } finally { - setIsGenerating(false); - } - }; - return { error, isReady, isGenerating, - forward, + forward: model.forward, downloadProgress, }; }; diff --git a/src/hooks/computer_vision/useVerticalOCR.ts b/src/hooks/computer_vision/useVerticalOCR.ts new file mode 100644 index 0000000000..65e9ed7305 --- /dev/null +++ b/src/hooks/computer_vision/useVerticalOCR.ts @@ -0,0 +1,74 @@ +import { useEffect, useState } from 'react'; +import { ResourceSource } from '../../types/common'; +import { OCRDetection, OCRLanguage } from '../../types/ocr'; +import { VerticalOCRController } from '../../controllers/VerticalOCRController'; + +interface OCRModule { + error: string | null; + isReady: boolean; + isGenerating: boolean; + forward: (input: string) => Promise<OCRDetection[]>; + downloadProgress: number; +} + +export const useVerticalOCR = ({ + detectorSources, + recognizerSources, + language = 'en', + independentCharacters = false, +}: { + detectorSources: { + detectorLarge: ResourceSource; + detectorNarrow: ResourceSource; + }; + recognizerSources: { + recognizerLarge: ResourceSource; + recognizerSmall: ResourceSource; + }; + language?: OCRLanguage; + independentCharacters?: boolean; +}): OCRModule => { + const [error, setError] = useState<string | null>(null); + const [isReady, setIsReady] = useState(false); + const [isGenerating, setIsGenerating] = useState(false); + const [downloadProgress, setDownloadProgress] = useState(0); + + const [model, _] = useState( + () => + new VerticalOCRController({ + modelDownloadProgressCallback: setDownloadProgress, + isReadyCallback: setIsReady, + isGeneratingCallback: setIsGenerating, + errorCallback: setError, + }) + ); + + useEffect(() => { + const loadModel = async () => { + await model.loadModel( + detectorSources, + recognizerSources, + language, + independentCharacters + ); + }; + + loadModel(); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [ + // eslint-disable-next-line react-hooks/exhaustive-deps + JSON.stringify(detectorSources), + language, + independentCharacters, + // eslint-disable-next-line react-hooks/exhaustive-deps + JSON.stringify(recognizerSources), + ]); + + return { + error, + isReady, + isGenerating, + forward: model.forward, + downloadProgress, + }; +}; diff --git a/src/index.tsx b/src/index.tsx index f5bfa1854d..9d50e77610 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -3,6 +3,7 @@ export * from './hooks/computer_vision/useClassification'; export * from './hooks/computer_vision/useObjectDetection'; export * from 
'./hooks/computer_vision/useStyleTransfer'; export * from './hooks/computer_vision/useOCR'; +export * from './hooks/computer_vision/useVerticalOCR'; export * from './hooks/natural_language_processing/useLLM'; @@ -13,6 +14,7 @@ export * from './modules/computer_vision/ClassificationModule'; export * from './modules/computer_vision/ObjectDetectionModule'; export * from './modules/computer_vision/StyleTransferModule'; export * from './modules/computer_vision/OCRModule'; +export * from './modules/computer_vision/VerticalOCRModule'; export * from './modules/natural_language_processing/LLMModule'; diff --git a/src/modules/computer_vision/OCRModule.ts b/src/modules/computer_vision/OCRModule.ts index 26ea6f4e89..c7a28ef622 100644 --- a/src/modules/computer_vision/OCRModule.ts +++ b/src/modules/computer_vision/OCRModule.ts @@ -1,11 +1,10 @@ -import { languageDicts } from '../../constants/ocr/languageDicts'; -import { symbols } from '../../constants/ocr/symbols'; -import { getError, ETError } from '../../Error'; -import { OCR } from '../../native/RnExecutorchModules'; +import { OCRController } from '../../controllers/OCRController'; import { ResourceSource } from '../../types/common'; -import { fetchResource } from '../../utils/fetchResource'; +import { OCRLanguage } from '../../types/ocr'; export class OCRModule { + static module: OCRController; + static onDownloadProgressCallback = (_downloadProgress: number) => {}; static async load( @@ -15,55 +14,17 @@ export class OCRModule { recognizerMedium: ResourceSource; recognizerSmall: ResourceSource; }, - language = 'en' + language: OCRLanguage = 'en' ) { - try { - if (!detectorSource || Object.keys(recognizerSources).length === 0) - return; - - const recognizerPaths = {} as { - recognizerLarge: string; - recognizerMedium: string; - recognizerSmall: string; - }; - - if (!symbols[language] || !languageDicts[language]) { - throw new Error(getError(ETError.LanguageNotSupported)); - } - - const detectorPath = await fetchResource(detectorSource); - - await Promise.all([ - fetchResource( - recognizerSources.recognizerLarge, - this.onDownloadProgressCallback - ), - fetchResource(recognizerSources.recognizerMedium), - fetchResource(recognizerSources.recognizerSmall), - ]).then((values) => { - recognizerPaths.recognizerLarge = values[0]; - recognizerPaths.recognizerMedium = values[1]; - recognizerPaths.recognizerSmall = values[2]; - }); + this.module = new OCRController({ + modelDownloadProgressCallback: this.onDownloadProgressCallback, + }); - await OCR.loadModule( - detectorPath, - recognizerPaths.recognizerLarge, - recognizerPaths.recognizerMedium, - recognizerPaths.recognizerSmall, - symbols[language] - ); - } catch (e) { - throw new Error(getError(e)); - } + await this.module.loadModel(detectorSource, recognizerSources, language); } static async forward(input: string) { - try { - return await OCR.forward(input); - } catch (e) { - throw new Error(getError(e)); - } + return await this.module.forward(input); } static onDownloadProgress(callback: (downloadProgress: number) => void) { diff --git a/src/modules/computer_vision/VerticalOCRModule.ts b/src/modules/computer_vision/VerticalOCRModule.ts new file mode 100644 index 0000000000..4c8b1120de --- /dev/null +++ b/src/modules/computer_vision/VerticalOCRModule.ts @@ -0,0 +1,41 @@ +import { VerticalOCRController } from '../../controllers/VerticalOCRController'; +import { ResourceSource } from '../../types/common'; +import { OCRLanguage } from '../../types/ocr'; + +export class VerticalOCRModule { + static module: 
VerticalOCRController; + + static onDownloadProgressCallback = (_downloadProgress: number) => {}; + + static async load( + detectorSources: { + detectorLarge: ResourceSource; + detectorNarrow: ResourceSource; + }, + recognizerSources: { + recognizerLarge: ResourceSource; + recognizerSmall: ResourceSource; + }, + language: OCRLanguage = 'en', + independentCharacters: boolean = false + ) { + this.module = new VerticalOCRController({ + modelDownloadProgressCallback: this.onDownloadProgressCallback, + }); + + await this.module.loadModel( + detectorSources, + recognizerSources, + language, + independentCharacters + ); + } + + static async forward(input: string) { + return await this.module.forward(input); + } + + static onDownloadProgress(callback: (downloadProgress: number) => void) { + this.onDownloadProgressCallback = callback; + } +} diff --git a/src/native/NativeVerticalOCR.ts b/src/native/NativeVerticalOCR.ts new file mode 100644 index 0000000000..2aca8cbebc --- /dev/null +++ b/src/native/NativeVerticalOCR.ts @@ -0,0 +1,16 @@ +import type { TurboModule } from 'react-native'; +import { TurboModuleRegistry } from 'react-native'; +import { OCRDetection } from '../types/ocr'; + +export interface Spec extends TurboModule { + loadModule( + detectorLargeSource: string, + detectorNarrowSource: string, + recognizerSource: string, + symbols: string, + independentCharacters: boolean + ): Promise<number>; + forward(input: string): Promise<OCRDetection[]>; +} + +export default TurboModuleRegistry.get<Spec>('VerticalOCR'); diff --git a/src/native/RnExecutorchModules.ts b/src/native/RnExecutorchModules.ts index c8044aa473..49ac1e52a3 100644 --- a/src/native/RnExecutorchModules.ts +++ b/src/native/RnExecutorchModules.ts @@ -3,6 +3,8 @@ import { Spec as ClassificationInterface } from './NativeClassification'; import { Spec as ObjectDetectionInterface } from './NativeObjectDetection'; import { Spec as StyleTransferInterface } from './NativeStyleTransfer'; import { Spec as ETModuleInterface } from './NativeETModule'; +import { Spec as OCRInterface } from './NativeOCR'; +import { Spec as VerticalOCRInterface } from './NativeVerticalOCR'; const LINKING_ERROR = `The package 'react-native-executorch' doesn't seem to be linked. Make sure: \n\n` + @@ -101,6 +103,19 @@ const OCR = OCRSpec } ); +const VerticalOCRSpec = require('./NativeVerticalOCR').default; + +const VerticalOCR = VerticalOCRSpec + ? 
VerticalOCRSpec + : new Proxy( + {}, + { + get() { + throw new Error(LINKING_ERROR); + }, + } + ); + class _ObjectDetectionModule { async forward( input: string @@ -154,6 +169,50 @@ class _ClassificationModule { } } +class _OCRModule { + async forward(input: string): ReturnType<OCRInterface['forward']> { + return await OCR.forward(input); + } + + async loadModule( + detectorSource: string, + recognizerSourceLarge: string, + recognizerSourceMedium: string, + recognizerSourceSmall: string, + symbols: string + ) { + return await OCR.loadModule( + detectorSource, + recognizerSourceLarge, + recognizerSourceMedium, + recognizerSourceSmall, + symbols + ); + } +} + +class _VerticalOCRModule { + async forward(input: string): ReturnType<VerticalOCRInterface['forward']> { + return await VerticalOCR.forward(input); + } + + async loadModule( + detectorLargeSource: string, + detectorMediumSource: string, + recognizerSource: string, + symbols: string, + independentCharacters: boolean + ): ReturnType<VerticalOCRInterface['loadModule']> { + return await VerticalOCR.loadModule( + detectorLargeSource, + detectorMediumSource, + recognizerSource, + symbols, + independentCharacters + ); + } +} + class _ETModule { async forward( inputs: number[][], @@ -182,9 +241,12 @@ export { StyleTransfer, SpeechToText, OCR, + VerticalOCR, _ETModule, _ClassificationModule, _StyleTransferModule, _ObjectDetectionModule, _SpeechToTextModule, + _OCRModule, + _VerticalOCRModule, }; diff --git a/src/types/ocr.ts b/src/types/ocr.ts index f5f2e6d35e..f633265fc3 100644 --- a/src/types/ocr.ts +++ b/src/types/ocr.ts @@ -8,3 +8,5 @@ export interface OCRBbox { x: number; y: number; } + +export type OCRLanguage = 'en'; diff --git a/src/utils/fetchResource.ts b/src/utils/fetchResource.ts index 9885758ee5..ecaec03468 100644 --- a/src/utils/fetchResource.ts +++ b/src/utils/fetchResource.ts @@ -80,3 +80,21 @@ export const fetchResource = async ( return fileUri; }; + +export const calculateDownloadProgres = + ( + numberOfFiles: number, + currentFileIndex: number, + setProgress: (downloadProgress: number) => void + ) => + (progress: number) => { + if (progress === 1 && currentFileIndex === numberOfFiles - 1) { + setProgress(1); + return; + } + const contributionPerFile = 1 / numberOfFiles; + const baseProgress = contributionPerFile * currentFileIndex; + const scaledProgress = progress * contributionPerFile; + const updatedProgress = baseProgress + scaledProgress; + setProgress(updatedProgress); + };
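
Usage sketch: the snippet below shows how the new useVerticalOCR hook is expected to be wired up in an app, mirroring the existing useOCR pattern. It reuses the model-URL constants this diff adds to src/constants/modelUrls.ts; pairing DETECTOR_CRAFT_1280 with DETECTOR_CRAFT_320 as the large/narrow detectors and RECOGNIZER_EN_CRNN_512 with RECOGNIZER_EN_CRNN_64 as the large/small recognizers is an assumption for illustration, not something the diff prescribes.

import {
  useVerticalOCR,
  DETECTOR_CRAFT_1280,
  DETECTOR_CRAFT_320,
  RECOGNIZER_EN_CRNN_512,
  RECOGNIZER_EN_CRNN_64,
} from 'react-native-executorch';

// Inside a React component:
const { error, isReady, isGenerating, forward, downloadProgress } =
  useVerticalOCR({
    detectorSources: {
      detectorLarge: DETECTOR_CRAFT_1280, // assumed large/narrow pairing
      detectorNarrow: DETECTOR_CRAFT_320,
    },
    recognizerSources: {
      recognizerLarge: RECOGNIZER_EN_CRNN_512,
      recognizerSmall: RECOGNIZER_EN_CRNN_64,
    },
    language: 'en',
    independentCharacters: true, // routes loading to the small recognizer
  });

// forward() takes an image source string and resolves to OCRDetection[].
const scanImage = async (imageUri: string) => {
  if (!isReady || isGenerating) return;
  const detections = await forward(imageUri);
  console.log(detections);
};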
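
A quick worked example of the calculateDownloadProgres helper added in src/utils/fetchResource.ts: each of numberOfFiles downloads gets an equal 1/numberOfFiles share of the aggregate progress, and only the final file's completion is allowed to report exactly 1, which keeps the reported value monotonic while several files download in sequence. The relative import path here is for illustration only.

import { calculateDownloadProgres } from './src/utils/fetchResource';

const onProgress = (p: number) => console.log(p.toFixed(2));

// Three files downloaded sequentially, as in VerticalOCRController.loadModel:
const firstFile = calculateDownloadProgres(3, 0, onProgress);
const secondFile = calculateDownloadProgres(3, 1, onProgress);
const lastFile = calculateDownloadProgres(3, 2, onProgress);

firstFile(0.5); // 0.17 -> halfway through file 1 of 3
firstFile(1); // 0.33 -> file 1 done; capped below 1 because it is not the last file
secondFile(0.5); // 0.50 -> one full share plus half of the second share
lastFile(1); // 1.00 -> only the final file may emit exactly 1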