Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
6220832
feat: implementation of detector pre and post processing(ios)
NorbertKlockiewicz Jan 16, 2025
a0e65a4
fix: fixes to groupTextBox and getDetBox function to make it return s…
NorbertKlockiewicz Jan 17, 2025
4190ced
feat: implemented recognition (ios)
NorbertKlockiewicz Jan 22, 2025
d6d2e1d
fix: add missing function to ImageProcessor
NorbertKlockiewicz Jan 22, 2025
a2efb61
feat: finished recognition, added confidence score and bounding boxes…
NorbertKlockiewicz Jan 23, 2025
9d2c9a1
refactor: first part of native ocr code refactor
NorbertKlockiewicz Jan 23, 2025
22c7aa2
reformat: reformat of detector and recognizer code
NorbertKlockiewicz Jan 26, 2025
3db3181
refactor: split groupTextBoxes into smaller functions, add functions …
NorbertKlockiewicz Jan 27, 2025
b7cab0a
fix: add missing argument in header file
NorbertKlockiewicz Jan 27, 2025
ad81393
feat: automatically load list of words and symbols for converter
NorbertKlockiewicz Jan 27, 2025
7637be2
feat: add polish language support
NorbertKlockiewicz Jan 27, 2025
dfdc811
feat: implemented upgraded mid processing pipeline
NorbertKlockiewicz Feb 5, 2025
b7e0635
refactor: refactored code for mid processing pipeline, adjusted code …
NorbertKlockiewicz Feb 6, 2025
70ced9b
fix: add missing angle normalization
NorbertKlockiewicz Feb 6, 2025
659295f
reformat: fix formatting of long lines
NorbertKlockiewicz Feb 10, 2025
83580df
format: format with clang format
NorbertKlockiewicz Feb 13, 2025
9538fbf
feat: ocr(android) (#96)
NorbertKlockiewicz Feb 21, 2025
d009071
feat: implemented hookless api, also added fetching with expo file sy…
NorbertKlockiewicz Feb 24, 2025
ad57615
Merge branch 'main' into @norbertklockiewicz/ocr-implementation
NorbertKlockiewicz Feb 24, 2025
6850891
refactor: remove unnecessary _OCRModule
NorbertKlockiewicz Feb 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions android/src/main/java/com/swmansion/rnexecutorch/OCR.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
package com.swmansion.rnexecutorch

import android.util.Log
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.utils.ETError
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.android.OpenCVLoader
import com.swmansion.rnexecutorch.models.ocr.Detector
import com.swmansion.rnexecutorch.models.ocr.RecognitionHandler
import com.swmansion.rnexecutorch.models.ocr.utils.Constants
import org.opencv.imgproc.Imgproc

/**
 * React Native module exposing OCR: a [Detector] finds text boxes and a
 * [RecognitionHandler] reads the text inside each box.
 *
 * [loadModule] must complete before [forward] is called; both report failures
 * by rejecting the supplied [Promise].
 */
class OCR(reactContext: ReactApplicationContext) :
  NativeOCRSpec(reactContext) {

  private lateinit var detector: Detector
  private lateinit var recognitionHandler: RecognitionHandler

  companion object {
    const val NAME = "OCR"
  }

  init {
    if (!OpenCVLoader.initLocal()) {
      Log.d("rn_executorch", "OpenCV not loaded")
    } else {
      Log.d("rn_executorch", "OpenCV loaded")
    }
  }

  /**
   * Loads the detector model and the three recognizer models.
   *
   * Resolves [promise] with `0` on success. Any exception raised while loading
   * rejects the promise with [ETError.InvalidModelSource].
   *
   * @param detectorSource source of the detector model
   * @param recognizerSourceLarge source of the large recognizer model
   * @param recognizerSourceMedium source of the medium recognizer model
   * @param recognizerSourceSmall source of the small recognizer model
   * @param symbols symbol set handed to the [RecognitionHandler]
   * @param promise promise resolved or rejected with the outcome
   */
  override fun loadModule(
    detectorSource: String,
    recognizerSourceLarge: String,
    recognizerSourceMedium: String,
    recognizerSourceSmall: String,
    symbols: String,
    promise: Promise
  ) {
    try {
      detector = Detector(reactApplicationContext)
      detector.loadModel(detectorSource)

      recognitionHandler = RecognitionHandler(
        symbols,
        reactApplicationContext
      )

      recognitionHandler.loadRecognizers(
        recognizerSourceLarge,
        recognizerSourceMedium,
        recognizerSourceSmall
      ) { _, errorRecognizer ->
        if (errorRecognizer != null) {
          // Rethrow the original exception. Wrapping it in `Error` would bypass
          // the `catch (e: Exception)` below and leave the promise unsettled.
          throw errorRecognizer
        }

        promise.resolve(0)
      }
    } catch (e: Exception) {
      // NOTE(review): the exception text is passed as the reject code and the
      // ETError value as the message; argument order kept as in the original.
      promise.reject(e.message ?: e.toString(), ETError.InvalidModelSource.toString())
    }
  }

  /**
   * Runs detection and recognition on the image referenced by [input].
   *
   * Resolves [promise] with the value returned by [RecognitionHandler.recognize];
   * rejects it if any step throws an exception.
   *
   * @param input image source understood by `ImageProcessor.readImage`
   * @param promise promise resolved with the recognition result
   */
  override fun forward(input: String, promise: Promise) {
    try {
      val inputImage = ImageProcessor.readImage(input)
      val bBoxesList = detector.runModel(inputImage)
      val detectorSize = detector.getModelImageSize()
      // The detector runs on the image as read; the recognizer gets a grayscale copy
      // (converted in place, so this must happen after detection).
      Imgproc.cvtColor(inputImage, inputImage, Imgproc.COLOR_BGR2GRAY)
      val result = recognitionHandler.recognize(
        bBoxesList,
        inputImage,
        (detectorSize.width * Constants.RECOGNIZER_RATIO).toInt(),
        (detectorSize.height * Constants.RECOGNIZER_RATIO).toInt()
      )
      promise.resolve(result)
    } catch (e: Exception) {
      Log.d("rn_executorch", "Error running model: ${e.message}")
      // Exceptions may carry no message; never throw from inside the handler.
      val reason = e.message ?: e.toString()
      promise.reject(reason, reason)
    }
  }

  override fun getName(): String {
    return NAME
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class RnExecutorchPackage : TurboReactPackage() {
ObjectDetection(reactContext)
} else if (name == SpeechToText.NAME) {
SpeechToText(reactContext)
} else if (name == OCR.NAME){
OCR(reactContext)
}
else {
null
Expand Down Expand Up @@ -85,6 +87,15 @@ class RnExecutorchPackage : TurboReactPackage() {
false, // isCxxModule
true
)

moduleInfos[OCR.NAME] = ReactModuleInfo(
OCR.NAME,
OCR.NAME,
false, // canOverrideExistingModule
false, // needsEagerInit
false, // isCxxModule
true
)
moduleInfos
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package com.swmansion.rnexecutorch.models.ocr

import android.util.Log
import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.models.ocr.utils.Constants
import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils
import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Mat
import org.opencv.core.Scalar
import org.opencv.core.Size
import org.pytorch.executorch.EValue

/**
 * Text detection model: turns an image [Mat] into a list of grouped text boxes.
 *
 * The pipeline is [preprocess] -> inherited `forward` -> [postprocess].
 */
class Detector(reactApplicationContext: ReactApplicationContext) :
  BaseModel<Mat, List<OCRbBox>>(reactApplicationContext) {

  /**
   * Returns the image size expected by the model's first input.
   *
   * The last dimension of the input shape is used as the width and the one
   * before it as the height.
   */
  fun getModelImageSize(): Size {
    val inputShape = module.getInputShape(0)
    val width = inputShape[inputShape.lastIndex]
    val height = inputShape[inputShape.lastIndex - 1]

    // org.opencv.core.Size is constructed as (width, height). Passing them the
    // other way round only goes unnoticed while the model input is square.
    return Size(width.toDouble(), height.toDouble())
  }

  /**
   * Resizes [input] (with padding) to the model image size and converts it to
   * an [EValue] using [Constants.MEAN] and [Constants.VARIANCE].
   */
  override fun preprocess(input: Mat): EValue {
    // Query the model shape once rather than once per dimension.
    val modelImageSize = getModelImageSize()
    val resizedImage = ImageProcessor.resizeWithPadding(
      input,
      modelImageSize.width.toInt(),
      modelImageSize.height.toInt()
    )

    return ImageProcessor.matToEValue(
      resizedImage,
      module.getInputShape(0),
      Constants.MEAN,
      Constants.VARIANCE
    )
  }

  /**
   * Converts the first output tensor into grouped text boxes.
   *
   * The tensor is split into a text-score and a link-score map of half the
   * model image size, boxes are extracted from those maps, rescaled by
   * `Constants.RECOGNIZER_RATIO * 2`, and finally grouped.
   */
  override fun postprocess(output: Array<EValue>): List<OCRbBox> {
    val outputArray = output[0].toTensor().dataAsFloatArray
    val modelImageSize = getModelImageSize()

    val (scoreText, scoreLink) = DetectorUtils.interleavedArrayToMats(
      outputArray,
      Size(modelImageSize.width / 2, modelImageSize.height / 2)
    )
    val detectedBoxes = DetectorUtils.getDetBoxesFromTextMap(
      scoreText,
      scoreLink,
      Constants.TEXT_THRESHOLD,
      Constants.LINK_THRESHOLD,
      Constants.LOW_TEXT_THRESHOLD
    )
    val rescaledBoxes =
      DetectorUtils.restoreBoxRatio(detectedBoxes, (Constants.RECOGNIZER_RATIO * 2).toFloat())
    val groupedBoxes = DetectorUtils.groupTextBoxes(
      rescaledBoxes,
      Constants.CENTER_THRESHOLD,
      Constants.DISTANCE_THRESHOLD,
      Constants.HEIGHT_THRESHOLD,
      Constants.MIN_SIDE_THRESHOLD,
      Constants.MAX_SIDE_THRESHOLD,
      Constants.MAX_WIDTH
    )

    return groupedBoxes.toList()
  }

  /** Runs the full detection pipeline on [input]. */
  override fun runModel(input: Mat): List<OCRbBox> {
    return postprocess(forward(preprocess(input)))
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package com.swmansion.rnexecutorch.models.ocr

import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.WritableArray
import com.swmansion.rnexecutorch.models.ocr.utils.CTCLabelConverter
import com.swmansion.rnexecutorch.models.ocr.utils.Constants
import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox
import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Core
import org.opencv.core.Mat

/**
 * Owns three [Recognizer] instances (large, medium, small) and a
 * [CTCLabelConverter], and turns detected boxes into recognized text.
 *
 * @param symbols symbol set passed to the [CTCLabelConverter]
 * @param reactApplicationContext context passed to every [Recognizer]
 */
class RecognitionHandler(
  symbols: String,
  reactApplicationContext: ReactApplicationContext
) {
  private val recognizerLarge = Recognizer(reactApplicationContext)
  private val recognizerMedium = Recognizer(reactApplicationContext)
  private val recognizerSmall = Recognizer(reactApplicationContext)
  private val converter = CTCLabelConverter(symbols)

  /**
   * Runs the recognizer selected by the width (column count) of [croppedImage].
   *
   * @return the predicted indices and their confidence score
   */
  private fun runModel(croppedImage: Mat): Pair<List<Int>, Double> {
    val imageWidth = croppedImage.cols()
    return when {
      imageWidth >= Constants.LARGE_MODEL_WIDTH -> recognizerLarge.runModel(croppedImage)
      imageWidth >= Constants.MEDIUM_MODEL_WIDTH -> recognizerMedium.runModel(croppedImage)
      else -> recognizerSmall.runModel(croppedImage)
    }
  }

  /**
   * Loads the three recognizer models and reports the outcome exactly once.
   *
   * @param largeRecognizerPath source of the large recognizer model
   * @param mediumRecognizerPath source of the medium recognizer model
   * @param smallRecognizerPath source of the small recognizer model
   * @param onComplete called with `(0, null)` on success or `(1, exception)`
   *   when loading throws; exceptions thrown by the callback itself propagate
   *   to the caller
   */
  fun loadRecognizers(
    largeRecognizerPath: String,
    mediumRecognizerPath: String,
    smallRecognizerPath: String,
    onComplete: (Int, Exception?) -> Unit
  ) {
    // Keep the callback outside the try block: otherwise an exception thrown by
    // onComplete(0, null) would be caught here and the callback invoked again
    // with an error status.
    val loadFailure: Exception? = try {
      recognizerLarge.loadModel(largeRecognizerPath)
      recognizerMedium.loadModel(mediumRecognizerPath)
      recognizerSmall.loadModel(smallRecognizerPath)
      null
    } catch (e: Exception) {
      e
    }

    if (loadFailure == null) {
      onComplete(0, null)
    } else {
      onComplete(1, loadFailure)
    }
  }

  /**
   * Recognizes the text inside every box of [bBoxesList].
   *
   * [imgGray] is resized with padding to [desiredWidth] x [desiredHeight];
   * each box is cropped from that image and run through a recognizer. When the
   * confidence is below [Constants.LOW_CONFIDENCE_THRESHOLD], the crop is
   * rotated by 180 degrees and retried, and the better result is kept.
   *
   * Note: the corner points of each box are updated in place (padding
   * subtracted, then multiplied by the resize ratio).
   *
   * @return an array of maps with the keys `text`, `bbox` and `confidence`;
   *   boxes whose crop is empty are omitted
   */
  fun recognize(
    bBoxesList: List<OCRbBox>,
    imgGray: Mat,
    desiredWidth: Int,
    desiredHeight: Int
  ): WritableArray {
    val res: WritableArray = Arguments.createArray()
    val ratioAndPadding = RecognizerUtils.calculateResizeRatioAndPaddings(
      imgGray.width(),
      imgGray.height(),
      desiredWidth,
      desiredHeight
    )

    val left = ratioAndPadding["left"] as Int
    val top = ratioAndPadding["top"] as Int
    val resizeRatio = ratioAndPadding["resizeRatio"] as Float
    val resizedImg = ImageProcessor.resizeWithPadding(
      imgGray,
      desiredWidth,
      desiredHeight
    )

    for (box in bBoxesList) {
      var croppedImage = RecognizerUtils.getCroppedImage(box, resizedImg, Constants.MODEL_HEIGHT)
      if (croppedImage.empty()) {
        continue
      }

      croppedImage = RecognizerUtils.normalizeForRecognizer(croppedImage, Constants.ADJUST_CONTRAST)

      var result = runModel(croppedImage)
      var confidenceScore = result.second

      if (confidenceScore < Constants.LOW_CONFIDENCE_THRESHOLD) {
        // Low confidence: retry with the crop turned upside down.
        Core.rotate(croppedImage, croppedImage, Core.ROTATE_180)
        val rotatedResult = runModel(croppedImage)
        val rotatedConfidenceScore = rotatedResult.second
        if (rotatedConfidenceScore > confidenceScore) {
          result = rotatedResult
          confidenceScore = rotatedConfidenceScore
        }
      }

      val predIndex = result.first
      val decodedTexts = converter.decodeGreedy(predIndex, predIndex.size)

      // Undo the padding offset and the resize applied by resizeWithPadding.
      for (bBox in box.bBox) {
        bBox.x = (bBox.x - left) * resizeRatio
        bBox.y = (bBox.y - top) * resizeRatio
      }

      val resMap = Arguments.createMap()

      resMap.putString("text", decodedTexts[0])
      resMap.putArray("bbox", box.toWritableArray())
      resMap.putDouble("confidence", confidenceScore)

      res.pushMap(resMap)
    }

    return res
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package com.swmansion.rnexecutorch.models.ocr

import com.facebook.react.bridge.ReactApplicationContext
import com.swmansion.rnexecutorch.models.BaseModel
import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils
import com.swmansion.rnexecutorch.utils.ImageProcessor
import org.opencv.core.Mat
import org.opencv.core.Size
import org.pytorch.executorch.EValue

/**
 * Text recognition model: maps a cropped image [Mat] to a list of predicted
 * indices and a confidence score.
 */
class Recognizer(reactApplicationContext: ReactApplicationContext) :
  BaseModel<Mat, Pair<List<Int>, Double>>(reactApplicationContext) {

  /** Returns the size of the last dimension of the model's first output. */
  private fun getModelOutputRowLength(): Int {
    val outputShape = module.getOutputShape(0)
    return outputShape[outputShape.lastIndex].toInt()
  }

  /** Converts [input] to an [EValue] via `ImageProcessor.matToEValueGray`. */
  override fun preprocess(input: Mat): EValue {
    return ImageProcessor.matToEValueGray(input)
  }

  /**
   * Turns the first output tensor into predicted indices and a confidence score.
   *
   * The flat tensor data is laid out row by row into a matrix whose row length
   * equals the last output dimension. The matrix is passed through softmax,
   * each row is divided by its probability sum, and the per-row maxima and
   * their indices are extracted.
   *
   * @return the indices of the row maxima and the computed confidence score
   */
  override fun postprocess(output: Array<EValue>): Pair<List<Int>, Double> {
    val rowLength = getModelOutputRowLength()
    val rawOutput = output[0].toTensor().dataAsFloatArray
    val numRows = (rawOutput.size + rowLength - 1) / rowLength
    val resultMat = Mat(numRows, rowLength, org.opencv.core.CvType.CV_32F)
    if (rawOutput.isNotEmpty()) {
      // One bulk copy starting at (0, 0) fills the matrix in row-major order,
      // replacing a separate put() call and array allocation per element.
      resultMat.put(0, 0, rawOutput)
    }

    val softmaxed = RecognizerUtils.softmax(resultMat)
    val predsNorm = RecognizerUtils.sumProbabilityRows(softmaxed, rowLength)
    val probabilities = RecognizerUtils.divideMatrixByVector(softmaxed, predsNorm)
    val (values, indices) = RecognizerUtils.findMaxValuesAndIndices(probabilities)

    val confidenceScore = RecognizerUtils.computeConfidenceScore(values, indices)
    return Pair(indices, confidenceScore)
  }

  /** Runs the full recognition pipeline on [input]. */
  override fun runModel(input: Mat): Pair<List<Int>, Double> {
    // NOTE(review): this calls module.forward directly, whereas Detector uses
    // the inherited forward() — confirm the difference is intentional.
    return postprocess(module.forward(preprocess(input)))
  }
}
Loading