-
Notifications
You must be signed in to change notification settings - Fork 74
feat: ocr(android) #96
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,101 @@ | ||
| package com.swmansion.rnexecutorch | ||
|
|
||
| import android.util.Log | ||
| import com.facebook.react.bridge.Promise | ||
| import com.facebook.react.bridge.ReactApplicationContext | ||
| import com.swmansion.rnexecutorch.utils.ETError | ||
| import com.swmansion.rnexecutorch.utils.ImageProcessor | ||
| import org.opencv.android.OpenCVLoader | ||
| import com.swmansion.rnexecutorch.models.ocr.Detector | ||
| import com.swmansion.rnexecutorch.models.ocr.RecognitionHandler | ||
| import com.swmansion.rnexecutorch.models.ocr.utils.Constants | ||
| import com.swmansion.rnexecutorch.utils.Fetcher | ||
| import com.swmansion.rnexecutorch.utils.ResourceType | ||
| import org.opencv.imgproc.Imgproc | ||
|
|
||
class OCR(reactContext: ReactApplicationContext) :
  NativeOCRSpec(reactContext) {

  private lateinit var detector: Detector
  private lateinit var recognitionHandler: RecognitionHandler

  companion object {
    const val NAME = "OCR"
  }

  init {
    // OpenCV must be initialized before any Mat/Imgproc usage in this module.
    if (!OpenCVLoader.initLocal()) {
      Log.d("rn_executorch", "OpenCV not loaded")
    } else {
      Log.d("rn_executorch", "OpenCV loaded")
    }
  }

  /**
   * Loads the detector model, downloads the language dictionary, and loads the
   * three recognizer models (large/medium/small). Resolves the promise with 0
   * on success; rejects it with [ETError.InvalidModelSource] on any failure.
   */
  override fun loadModule(
    detectorSource: String,
    recognizerSourceLarge: String,
    recognizerSourceMedium: String,
    recognizerSourceSmall: String,
    symbols: String,
    languageDictPath: String,
    promise: Promise
  ) {
    try {
      detector = Detector(reactApplicationContext)
      detector.loadModel(detectorSource)
      Fetcher.downloadResource(
        reactApplicationContext,
        languageDictPath,
        ResourceType.TXT,
        false,
        { path, error ->
          // Reject here instead of throwing: if this callback is invoked
          // asynchronously, a thrown Error would escape the enclosing
          // try/catch and crash instead of rejecting the promise.
          if (error != null) {
            promise.reject(
              error.message ?: "Failed to download the language dictionary",
              ETError.InvalidModelSource.toString()
            )
            return@downloadResource
          }

          recognitionHandler = RecognitionHandler(
            symbols,
            path!!, // non-null when error == null, per Fetcher's callback contract — TODO confirm
            reactApplicationContext
          )

          recognitionHandler.loadRecognizers(
            recognizerSourceLarge,
            recognizerSourceMedium,
            recognizerSourceSmall
          ) { _, errorRecognizer ->
            // Same reasoning as above: reject rather than throw from a callback.
            if (errorRecognizer != null) {
              promise.reject(
                errorRecognizer.message ?: "Failed to load recognizer models",
                ETError.InvalidModelSource.toString()
              )
            } else {
              promise.resolve(0)
            }
          }
        })
    } catch (e: Exception) {
      // NOTE(review): Promise.reject(code, message) — the exception message is
      // used as the error code here, mirroring the original call order.
      promise.reject(
        e.message ?: "Failed to load the OCR module",
        ETError.InvalidModelSource.toString()
      )
    }
  }

  /**
   * Runs detection on the image at [input] (path/URI handled by
   * [ImageProcessor.readImage]), then recognizes text inside each detected
   * box. Resolves with an array of { text, bbox, confidence } maps.
   */
  override fun forward(input: String, promise: Promise) {
    try {
      val inputImage = ImageProcessor.readImage(input)
      val bBoxesList = detector.runModel(inputImage)
      val detectorSize = detector.getModelImageSize()
      // Recognizers consume grayscale crops; convert in place.
      Imgproc.cvtColor(inputImage, inputImage, Imgproc.COLOR_BGR2GRAY)
      val result = recognitionHandler.recognize(
        bBoxesList,
        inputImage,
        (detectorSize.width * Constants.RECOGNIZER_RATIO).toInt(),
        (detectorSize.height * Constants.RECOGNIZER_RATIO).toInt()
      )
      promise.resolve(result)
    } catch (e: Exception) {
      Log.d("rn_executorch", "Error running model: ${e.message}")
      // Elvis instead of !!: e.message may legitimately be null.
      promise.reject(e.message ?: "Error running model", e.message)
    }
  }

  override fun getName(): String {
    return NAME
  }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| package com.swmansion.rnexecutorch.models.ocr | ||
|
|
||
| import android.util.Log | ||
| import com.facebook.react.bridge.ReactApplicationContext | ||
| import com.swmansion.rnexecutorch.models.BaseModel | ||
| import com.swmansion.rnexecutorch.models.ocr.utils.Constants | ||
| import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils | ||
| import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox | ||
| import com.swmansion.rnexecutorch.utils.ImageProcessor | ||
| import org.opencv.core.Mat | ||
| import org.opencv.core.Scalar | ||
| import org.opencv.core.Size | ||
| import org.pytorch.executorch.EValue | ||
|
|
||
class Detector(reactApplicationContext: ReactApplicationContext) :
  BaseModel<Mat, List<OCRbBox>>(reactApplicationContext) {
  // NOTE(review): written in preprocess() but never read within this class —
  // consider removing it or exposing the pre-resize dimensions to callers.
  private lateinit var originalSize: Size

  /**
   * Returns the spatial size the loaded model expects, derived from the
   * module's input shape (last two dimensions).
   */
  fun getModelImageSize(): Size {
    val inputShape = module.getInputShape(0)
    val width = inputShape[inputShape.lastIndex]
    val height = inputShape[inputShape.lastIndex - 1]
    // NOTE(review): height/width are deliberately passed transposed into
    // Size — confirm against the exported model's input layout.
    return Size(height.toDouble(), width.toDouble())
  }

  /**
   * Resizes [input] (with padding) to the model's expected size and converts
   * it to a normalized EValue tensor using the OCR mean/variance constants.
   */
  override fun preprocess(input: Mat): EValue {
    originalSize = Size(input.cols().toDouble(), input.rows().toDouble())
    // Hoisted: getModelImageSize() queries the module's input shape — call it
    // once instead of twice.
    val modelImageSize = getModelImageSize()
    val resizedImage = ImageProcessor.resizeWithPadding(
      input,
      modelImageSize.width.toInt(),
      modelImageSize.height.toInt()
    )

    return ImageProcessor.matToEValue(
      resizedImage,
      module.getInputShape(0),
      Constants.MEAN,
      Constants.VARIANCE
    )
  }

  /**
   * Converts the raw model output into grouped text bounding boxes.
   * The model emits two interleaved score maps (text and link/affinity) at
   * half the model input resolution.
   */
  override fun postprocess(output: Array<EValue>): List<OCRbBox> {
    val outputArray = output[0].toTensor().dataAsFloatArray
    val modelImageSize = getModelImageSize()

    val (scoreText, scoreLink) = DetectorUtils.interleavedArrayToMats(
      outputArray,
      Size(modelImageSize.width / 2, modelImageSize.height / 2)
    )
    var bBoxesList = DetectorUtils.getDetBoxesFromTextMap(
      scoreText,
      scoreLink,
      Constants.TEXT_THRESHOLD,
      Constants.LINK_THRESHOLD,
      Constants.LOW_TEXT_THRESHOLD
    )
    // Scale boxes from the half-resolution score maps back toward image
    // coordinates, then merge nearby boxes into text lines.
    bBoxesList =
      DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RECOGNIZER_RATIO * 2).toFloat())
    bBoxesList = DetectorUtils.groupTextBoxes(
      bBoxesList,
      Constants.CENTER_THRESHOLD,
      Constants.DISTANCE_THRESHOLD,
      Constants.HEIGHT_THRESHOLD,
      Constants.MIN_SIDE_THRESHOLD,
      Constants.MAX_SIDE_THRESHOLD,
      Constants.MAX_WIDTH
    )

    return bBoxesList.toList()
  }

  override fun runModel(input: Mat): List<OCRbBox> {
    return postprocess(forward(preprocess(input)))
  }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,115 @@ | ||
| package com.swmansion.rnexecutorch.models.ocr | ||
|
|
||
| import com.facebook.react.bridge.Arguments | ||
| import com.facebook.react.bridge.ReactApplicationContext | ||
| import com.facebook.react.bridge.WritableArray | ||
| import com.swmansion.rnexecutorch.models.ocr.utils.CTCLabelConverter | ||
| import com.swmansion.rnexecutorch.models.ocr.utils.Constants | ||
| import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox | ||
| import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils | ||
| import com.swmansion.rnexecutorch.utils.ImageProcessor | ||
| import org.opencv.core.Core | ||
| import org.opencv.core.Mat | ||
|
|
||
class RecognitionHandler(
  symbols: String,
  languageDictPath: String,
  reactApplicationContext: ReactApplicationContext
) {
  private val recognizerLarge = Recognizer(reactApplicationContext)
  private val recognizerMedium = Recognizer(reactApplicationContext)
  private val recognizerSmall = Recognizer(reactApplicationContext)
  private val converter = CTCLabelConverter(symbols, mapOf(languageDictPath to "key"))

  // Dispatches a crop to the recognizer matching its width bucket.
  private fun runModel(croppedImage: Mat): Pair<List<Int>, Double> = when {
    croppedImage.cols() >= Constants.LARGE_MODEL_WIDTH -> recognizerLarge.runModel(croppedImage)
    croppedImage.cols() >= Constants.MEDIUM_MODEL_WIDTH -> recognizerMedium.runModel(croppedImage)
    else -> recognizerSmall.runModel(croppedImage)
  }

  // Loads all three recognizer models; reports (0, null) on success or
  // (1, exception) on the first failure.
  fun loadRecognizers(
    largeRecognizerPath: String,
    mediumRecognizerPath: String,
    smallRecognizerPath: String,
    onComplete: (Int, Exception?) -> Unit
  ) {
    try {
      recognizerLarge.loadModel(largeRecognizerPath)
      recognizerMedium.loadModel(mediumRecognizerPath)
      recognizerSmall.loadModel(smallRecognizerPath)
      onComplete(0, null)
    } catch (e: Exception) {
      onComplete(1, e)
    }
  }

  // Runs recognition over every detected box and returns an array of
  // { text, bbox, confidence } maps, in box order.
  fun recognize(
    bBoxesList: List<OCRbBox>,
    imgGray: Mat,
    desiredWidth: Int,
    desiredHeight: Int
  ): WritableArray {
    val results: WritableArray = Arguments.createArray()
    val padding = RecognizerUtils.calculateResizeRatioAndPaddings(
      imgGray.width(),
      imgGray.height(),
      desiredWidth,
      desiredHeight
    )

    val paddingLeft = padding["left"] as Int
    val paddingTop = padding["top"] as Int
    val resizeRatio = padding["resizeRatio"] as Float
    val paddedImage = ImageProcessor.resizeWithPadding(
      imgGray,
      desiredWidth,
      desiredHeight
    )

    for (box in bBoxesList) {
      val rawCrop = RecognizerUtils.getCroppedImage(box, paddedImage, Constants.MODEL_HEIGHT)
      if (rawCrop.empty()) {
        continue
      }

      val crop = RecognizerUtils.normalizeForRecognizer(rawCrop, Constants.ADJUST_CONTRAST)

      var prediction = runModel(crop)
      var confidence = prediction.second

      // A low score may mean the text is upside down — retry on a
      // 180°-rotated crop and keep whichever result scores higher.
      if (confidence < Constants.LOW_CONFIDENCE_THRESHOLD) {
        Core.rotate(crop, crop, Core.ROTATE_180)
        val rotatedPrediction = runModel(crop)
        if (rotatedPrediction.second > confidence) {
          prediction = rotatedPrediction
          confidence = rotatedPrediction.second
        }
      }

      val predIndex = prediction.first
      val decodedTexts = converter.decodeGreedy(predIndex, predIndex.size)

      // Map box coordinates back into the original (unpadded) image space.
      for (point in box.bBox) {
        point.x = (point.x - paddingLeft) * resizeRatio
        point.y = (point.y - paddingTop) * resizeRatio
      }

      val entry = Arguments.createMap()
      entry.putString("text", decodedTexts[0])
      entry.putArray("bbox", box.toWritableArray())
      entry.putDouble("confidence", confidence)

      results.pushMap(entry)
    }

    return results
  }
}
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| package com.swmansion.rnexecutorch.models.ocr | ||
|
|
||
| import com.facebook.react.bridge.ReactApplicationContext | ||
| import com.swmansion.rnexecutorch.models.BaseModel | ||
| import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils | ||
| import com.swmansion.rnexecutorch.utils.ImageProcessor | ||
| import org.opencv.core.Mat | ||
| import org.opencv.core.Size | ||
| import org.pytorch.executorch.EValue | ||
|
|
||
class Recognizer(reactApplicationContext: ReactApplicationContext) :
  BaseModel<Mat, Pair<List<Int>, Double>>(reactApplicationContext) {

  /**
   * Returns the model's output size, derived from the last two dimensions of
   * the module's output shape.
   */
  private fun getModelOutputSize(): Size {
    val outputShape = module.getOutputShape(0)
    val width = outputShape[outputShape.lastIndex]
    val height = outputShape[outputShape.lastIndex - 1]
    // NOTE(review): height/width are deliberately passed transposed into
    // Size, mirroring Detector.getModelImageSize — confirm the layout.
    return Size(height.toDouble(), width.toDouble())
  }

  override fun preprocess(input: Mat): EValue {
    return ImageProcessor.matToEValueGray(input)
  }

  /**
   * Reshapes the flat output tensor into a (numRows x modelOutputHeight)
   * matrix of per-step class scores, normalizes it to probabilities, and
   * returns the greedy label indices with an overall confidence score.
   */
  override fun postprocess(output: Array<EValue>): Pair<List<Int>, Double> {
    val modelOutputHeight = getModelOutputSize().height.toInt()
    val tensor = output[0].toTensor().dataAsFloatArray
    // Round up so a partially-filled last row still fits; any trailing cells
    // in that row are left untouched, same as the original per-element fill.
    val numRows = (tensor.size + modelOutputHeight - 1) / modelOutputHeight
    val resultMat = Mat(numRows, modelOutputHeight, org.opencv.core.CvType.CV_32F)
    // Single bulk row-major copy instead of one JNI put() call per element —
    // identical contents, far fewer native crossings.
    resultMat.put(0, 0, tensor)

    var probabilities = RecognizerUtils.softmax(resultMat)
    val predsNorm = RecognizerUtils.sumProbabilityRows(probabilities, modelOutputHeight)
    probabilities = RecognizerUtils.divideMatrixByVector(probabilities, predsNorm)
    val (values, indices) = RecognizerUtils.findMaxValuesAndIndices(probabilities)

    val confidenceScore = RecognizerUtils.computeConfidenceScore(values, indices)
    return Pair(indices, confidenceScore)
  }

  override fun runModel(input: Mat): Pair<List<Int>, Double> {
    // NOTE(review): Detector.runModel calls the inherited forward(...) while
    // this calls module.forward(...) directly — confirm whether BaseModel's
    // forward adds behavior and align the two if so.
    return postprocess(module.forward(preprocess(input)))
  }
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we tell from error which model failed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Without separate try/catch blocks for every model, I don't think so.