Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions android/src/main/java/com/swmansion/rnexecutorch/OCR.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
package com.swmansion.rnexecutorch

import android.util.Log
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.utils.ETError
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.android.OpenCVLoader
import com.swmansion.rnexecutorch.models.ocr.Detector
import com.swmansion.rnexecutorch.models.ocr.RecognitionHandler
import com.swmansion.rnexecutorch.models.ocr.utils.Constants
import com.swmansion.rnexecutorch.utils.Fetcher
import com.swmansion.rnexecutorch.utils.ResourceType
import org.opencv.imgproc.Imgproc

class OCR(reactContext: ReactApplicationContext) :
NativeOCRSpec(reactContext) {

private lateinit var detector: Detector
private lateinit var recognitionHandler: RecognitionHandler

companion object {
const val NAME = "OCR"
}

init {
if (!OpenCVLoader.initLocal()) {
Log.d("rn_executorch", "OpenCV not loaded")
} else {
Log.d("rn_executorch", "OpenCV loaded")
}
}

override fun loadModule(
detectorSource: String,
recognizerSourceLarge: String,
recognizerSourceMedium: String,
recognizerSourceSmall: String,
symbols: String,
languageDictPath: String,
promise: Promise
) {
try {
detector = Detector(reactApplicationContext)
detector.loadModel(detectorSource)
Fetcher.downloadResource(
reactApplicationContext,
languageDictPath,
ResourceType.TXT,
false,
{ path, error ->
if (error != null) {
throw Error(error.message!!)
}

recognitionHandler = RecognitionHandler(
symbols,
path!!,
reactApplicationContext
)

recognitionHandler.loadRecognizers(
recognizerSourceLarge,
recognizerSourceMedium,
recognizerSourceSmall
) { _, errorRecognizer ->
if (errorRecognizer != null) {
throw Error(errorRecognizer.message!!)
}

promise.resolve(0)
}
})
} catch (e: Exception) {
promise.reject(e.message!!, ETError.InvalidModelSource.toString())
}
}

override fun forward(input: String, promise: Promise) {
try {
val inputImage = ImageProcessor.readImage(input)
val bBoxesList = detector.runModel(inputImage)
val detectorSize = detector.getModelImageSize()
Imgproc.cvtColor(inputImage, inputImage, Imgproc.COLOR_BGR2GRAY)
val result = recognitionHandler.recognize(
bBoxesList,
inputImage,
(detectorSize.width * Constants.RECOGNIZER_RATIO).toInt(),
(detectorSize.height * Constants.RECOGNIZER_RATIO).toInt()
)
promise.resolve(result)
} catch (e: Exception) {
Log.d("rn_executorch", "Error running model: ${e.message}")
promise.reject(e.message!!, e.message)
}
}

override fun getName(): String {
return NAME
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class RnExecutorchPackage : TurboReactPackage() {
Classification(reactContext)
} else if (name == ObjectDetection.NAME) {
ObjectDetection(reactContext)
} else if (name == OCR.NAME){
OCR(reactContext)
}
else {
null
Expand Down Expand Up @@ -74,6 +76,14 @@ class RnExecutorchPackage : TurboReactPackage() {
false, // isCxxModule
true
)
moduleInfos[OCR.NAME] = ReactModuleInfo(
OCR.NAME,
OCR.NAME,
false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)
moduleInfos
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package com.swmansion.rnexecutorch.models.ocr

import android.util.Log
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.models.ocr.utils.Constants
import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils
import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Mat
import org.opencv.core.Scalar
import org.opencv.core.Size
import org.pytorch.executorch.EValue

class Detector(reactApplicationContext: ReactApplicationContext) :
BaseModel<Mat, List<OCRbBox>>(reactApplicationContext) {
private lateinit var originalSize: Size

fun getModelImageSize(): Size {
val inputShape = module.getInputShape(0)
val width = inputShape[inputShape.lastIndex]
val height = inputShape[inputShape.lastIndex - 1]

val modelImageSize = Size(height.toDouble(), width.toDouble())

return modelImageSize
}

override fun preprocess(input: Mat): EValue {
originalSize = Size(input.cols().toDouble(), input.rows().toDouble())
val resizedImage = ImageProcessor.resizeWithPadding(
input,
getModelImageSize().width.toInt(),
getModelImageSize().height.toInt()
)

return ImageProcessor.matToEValue(
resizedImage,
module.getInputShape(0),
Constants.MEAN,
Constants.VARIANCE
)
}

override fun postprocess(output: Array<EValue>): List<OCRbBox> {
val outputTensor = output[0].toTensor()
val outputArray = outputTensor.dataAsFloatArray
val modelImageSize = getModelImageSize()

val (scoreText, scoreLink) = DetectorUtils.interleavedArrayToMats(
outputArray,
Size(modelImageSize.width / 2, modelImageSize.height / 2)
)
var bBoxesList = DetectorUtils.getDetBoxesFromTextMap(
scoreText,
scoreLink,
Constants.TEXT_THRESHOLD,
Constants.LINK_THRESHOLD,
Constants.LOW_TEXT_THRESHOLD
)
bBoxesList =
DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RECOGNIZER_RATIO * 2).toFloat())
bBoxesList = DetectorUtils.groupTextBoxes(
bBoxesList,
Constants.CENTER_THRESHOLD,
Constants.DISTANCE_THRESHOLD,
Constants.HEIGHT_THRESHOLD,
Constants.MIN_SIDE_THRESHOLD,
Constants.MAX_SIDE_THRESHOLD,
Constants.MAX_WIDTH
)

return bBoxesList.toList()
}

override fun runModel(input: Mat): List<OCRbBox> {
return postprocess(forward(preprocess(input)))
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package com.swmansion.rnexecutorch.models.ocr

import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.WritableArray
import com.swmansion.rnexecutorch.models.ocr.utils.CTCLabelConverter
import com.swmansion.rnexecutorch.models.ocr.utils.Constants
import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox
import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Core
import org.opencv.core.Mat

class RecognitionHandler(
symbols: String,
languageDictPath: String,
reactApplicationContext: ReactApplicationContext
) {
private val recognizerLarge = Recognizer(reactApplicationContext)
private val recognizerMedium = Recognizer(reactApplicationContext)
private val recognizerSmall = Recognizer(reactApplicationContext)
private val converter = CTCLabelConverter(symbols, mapOf(languageDictPath to "key"))

private fun runModel(croppedImage: Mat): Pair<List<Int>, Double> {
val result: Pair<List<Int>, Double> = if (croppedImage.cols() >= Constants.LARGE_MODEL_WIDTH) {
recognizerLarge.runModel(croppedImage)
} else if (croppedImage.cols() >= Constants.MEDIUM_MODEL_WIDTH) {
recognizerMedium.runModel(croppedImage)
} else {
recognizerSmall.runModel(croppedImage)
}

return result
}

fun loadRecognizers(
largeRecognizerPath: String,
mediumRecognizerPath: String,
smallRecognizerPath: String,
onComplete: (Int, Exception?) -> Unit
) {
try {
recognizerLarge.loadModel(largeRecognizerPath)
recognizerMedium.loadModel(mediumRecognizerPath)
recognizerSmall.loadModel(smallRecognizerPath)
onComplete(0, null)
} catch (e: Exception) {
onComplete(1, e)
}
Comment on lines +47 to +49
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we tell from error which model failed?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without separate try catches for every model I don't think so

}

fun recognize(
bBoxesList: List<OCRbBox>,
imgGray: Mat,
desiredWidth: Int,
desiredHeight: Int
): WritableArray {
val res: WritableArray = Arguments.createArray()
val ratioAndPadding = RecognizerUtils.calculateResizeRatioAndPaddings(
imgGray.width(),
imgGray.height(),
desiredWidth,
desiredHeight
)

val left = ratioAndPadding["left"] as Int
val top = ratioAndPadding["top"] as Int
val resizeRatio = ratioAndPadding["resizeRatio"] as Float
val resizedImg = ImageProcessor.resizeWithPadding(
imgGray,
desiredWidth,
desiredHeight
)

for (box in bBoxesList) {
var croppedImage = RecognizerUtils.getCroppedImage(box, resizedImg, Constants.MODEL_HEIGHT)
if (croppedImage.empty()) {
continue
}

croppedImage = RecognizerUtils.normalizeForRecognizer(croppedImage, Constants.ADJUST_CONTRAST)

var result = runModel(croppedImage)
var confidenceScore = result.second

if (confidenceScore < Constants.LOW_CONFIDENCE_THRESHOLD) {
Core.rotate(croppedImage, croppedImage, Core.ROTATE_180)
val rotatedResult = runModel(croppedImage)
val rotatedConfidenceScore = rotatedResult.second
if (rotatedConfidenceScore > confidenceScore) {
result = rotatedResult
confidenceScore = rotatedConfidenceScore
}
Comment on lines +90 to +93
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we skip processing if confidence is still below Constants.LOW_CONFIDENCE_THRESHOLD?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sometimes the confidence score is still low and the result is correct, we are returning a confidence score to user so I think we should leave handling those cases for him

}

val predIndex = result.first
val decodedTexts = converter.decodeGreedy(predIndex, predIndex.size)

for (bBox in box.bBox) {
bBox.x = (bBox.x - left) * resizeRatio
bBox.y = (bBox.y - top) * resizeRatio
}

val resMap = Arguments.createMap()

resMap.putString("text", decodedTexts[0])
resMap.putArray("bbox", box.toWritableArray())
resMap.putDouble("confidence", confidenceScore)

res.pushMap(resMap)
}

return res
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.swmansion.rnexecutorch.models.ocr

import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Mat
import org.opencv.core.Size
import org.pytorch.executorch.EValue

class Recognizer(reactApplicationContext: ReactApplicationContext) :
BaseModel<Mat, Pair<List<Int>, Double>>(reactApplicationContext) {

private fun getModelOutputSize(): Size {
val outputShape = module.getOutputShape(0)
val width = outputShape[outputShape.lastIndex]
val height = outputShape[outputShape.lastIndex - 1]

return Size(height.toDouble(), width.toDouble())
}

override fun preprocess(input: Mat): EValue {
return ImageProcessor.matToEValueGray(input)
}

override fun postprocess(output: Array<EValue>): Pair<List<Int>, Double> {
val modelOutputHeight = getModelOutputSize().height.toInt()
val tensor = output[0].toTensor().dataAsFloatArray
val numElements = tensor.size
val numRows = (numElements + modelOutputHeight - 1) / modelOutputHeight
val resultMat = Mat(numRows, modelOutputHeight, org.opencv.core.CvType.CV_32F)
var counter = 0
var currentRow = 0
for (num in tensor) {
resultMat.put(currentRow, counter, floatArrayOf(num))
counter++
if (counter >= modelOutputHeight) {
counter = 0
currentRow++
}
}

var probabilities = RecognizerUtils.softmax(resultMat)
val predsNorm = RecognizerUtils.sumProbabilityRows(probabilities, modelOutputHeight)
probabilities = RecognizerUtils.divideMatrixByVector(probabilities, predsNorm)
val (values, indices) = RecognizerUtils.findMaxValuesAndIndices(probabilities)

val confidenceScore = RecognizerUtils.computeConfidenceScore(values, indices)
return Pair(indices, confidenceScore)
}


override fun runModel(input: Mat): Pair<List<Int>, Double> {
return postprocess(module.forward(preprocess(input)))
}
}
Loading