From 6220832021d515718930d88a1e63e51c77e0b65f Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 16 Jan 2025 15:20:58 +0100 Subject: [PATCH 01/19] feat: implementation of detector pre and post processing(ios) --- examples/computer-vision/App.tsx | 5 + .../computer-vision/screens/OCRScreen.tsx | 82 +++++ ios/RnExecutorch.xcodeproj/project.pbxproj | 14 + ios/RnExecutorch/OCR.h | 5 + ios/RnExecutorch/OCR.mm | 59 ++++ ios/RnExecutorch/models/ocr/Detector.h | 11 + ios/RnExecutorch/models/ocr/Detector.mm | 75 +++++ .../models/ocr/utils/DetectorUtils.h | 16 + .../models/ocr/utils/DetectorUtils.mm | 279 ++++++++++++++++++ ios/RnExecutorch/models/ocr/utils/OCRUtils.h | 7 + ios/RnExecutorch/models/ocr/utils/OCRUtils.mm | 50 ++++ src/OCR.ts | 77 +++++ src/native/NativeOCR.ts | 13 + src/native/RnExecutorchModules.ts | 14 + 14 files changed, 707 insertions(+) create mode 100644 examples/computer-vision/screens/OCRScreen.tsx create mode 100644 ios/RnExecutorch/OCR.h create mode 100644 ios/RnExecutorch/OCR.mm create mode 100644 ios/RnExecutorch/models/ocr/Detector.h create mode 100644 ios/RnExecutorch/models/ocr/Detector.mm create mode 100644 ios/RnExecutorch/models/ocr/utils/DetectorUtils.h create mode 100644 ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm create mode 100644 ios/RnExecutorch/models/ocr/utils/OCRUtils.h create mode 100644 ios/RnExecutorch/models/ocr/utils/OCRUtils.mm create mode 100644 src/OCR.ts create mode 100644 src/native/NativeOCR.ts diff --git a/examples/computer-vision/App.tsx b/examples/computer-vision/App.tsx index 8d01269fd0..488c61cd56 100644 --- a/examples/computer-vision/App.tsx +++ b/examples/computer-vision/App.tsx @@ -8,11 +8,13 @@ import { SafeAreaProvider, SafeAreaView } from 'react-native-safe-area-context'; import { View, StyleSheet } from 'react-native'; import { ClassificationScreen } from './screens/ClassificationScreen'; import { ObjectDetectionScreen } from './screens/ObjectDetectionScreen'; +import { OCRScreen } from 
'./screens/OCRScreen'; enum ModelType { STYLE_TRANSFER, OBJECT_DETECTION, CLASSIFICATION, + OCR, } export default function App() { @@ -46,6 +48,8 @@ export default function App() { return ( ); + case ModelType.OCR: + return ; default: return ( @@ -64,6 +68,7 @@ export default function App() { 'Style Transfer', 'Object Detection', 'Classification', + 'OCR', ]} onValueChange={(_, selectedIndex) => { handleModeChange(selectedIndex); diff --git a/examples/computer-vision/screens/OCRScreen.tsx b/examples/computer-vision/screens/OCRScreen.tsx new file mode 100644 index 0000000000..1493c30360 --- /dev/null +++ b/examples/computer-vision/screens/OCRScreen.tsx @@ -0,0 +1,82 @@ +import Spinner from 'react-native-loading-spinner-overlay'; +import { BottomBar } from '../components/BottomBar'; +import { getImage } from '../utils'; +import { useOCR } from 'react-native-executorch'; +import { View, StyleSheet, Image } from 'react-native'; + +export const OCRScreen = ({ + imageUri, + setImageUri, +}: { + imageUri: string; + setImageUri: (imageUri: string) => void; +}) => { + const model = useOCR({ + detectorSource: require('../assets/models/xnnpack_craft.pte'), + recognizerSources: [require('../assets/models/xnnpack_crnn_128.pte')], + }); + + const handleCameraPress = async (isCamera: boolean) => { + const image = await getImage(isCamera); + const uri = image?.uri; + if (typeof uri === 'string') { + setImageUri(uri as string); + } + }; + + const shape = [1, 1, 64, 128]; + const input = new Float32Array(shape[1] * shape[2] * shape[3]); + + for (let i = 0; i < shape[1] * shape[2] * shape[3]; i++) { + input[i] = Math.random() * 255; + } + + const runForward = async () => { + try { + const output = await model.forward(imageUri); + console.log(output[0]); + console.log(output[1]); + } catch (e) { + console.error(e); + } + }; + + if (!model.isReady) { + return ( + + ); + } + + return ( + <> + + + + + + ); +}; + +const styles = StyleSheet.create({ + imageContainer: { + flex: 6, + width: 
'100%', + padding: 16, + }, + image: { + flex: 1, + borderRadius: 8, + width: '100%', + }, +}); diff --git a/ios/RnExecutorch.xcodeproj/project.pbxproj b/ios/RnExecutorch.xcodeproj/project.pbxproj index 3fad88ed1c..68e367a8e8 100644 --- a/ios/RnExecutorch.xcodeproj/project.pbxproj +++ b/ios/RnExecutorch.xcodeproj/project.pbxproj @@ -35,12 +35,20 @@ LLM.h, ); }; + 552754CC2D394AC9006B38A2 /* Exceptions for "RnExecutorch" folder in "Compile Sources" phase from "RnExecutorch" target */ = { + isa = PBXFileSystemSynchronizedGroupBuildPhaseMembershipExceptionSet; + buildPhase = 550986852CEF541900FECBB8 /* Sources */; + membershipExceptions = ( + models/ocr/utils/DetectorUtils.h, + ); + }; /* End PBXFileSystemSynchronizedGroupBuildPhaseMembershipExceptionSet section */ /* Begin PBXFileSystemSynchronizedRootGroup section */ 5509868B2CEF541900FECBB8 /* RnExecutorch */ = { isa = PBXFileSystemSynchronizedRootGroup; exceptions = ( + 552754CC2D394AC9006B38A2 /* Exceptions for "RnExecutorch" folder in "Compile Sources" phase from "RnExecutorch" target */, 550986902CEF541900FECBB8 /* Exceptions for "RnExecutorch" folder in "Copy Files" phase from "RnExecutorch" target */, ); path = RnExecutorch; @@ -119,6 +127,7 @@ TargetAttributes = { 550986882CEF541900FECBB8 = { CreatedOnToolsVersion = 16.1; + LastSwiftMigration = 1610; }; }; }; @@ -271,6 +280,7 @@ 550986942CEF541900FECBB8 /* Debug */ = { isa = XCBuildConfiguration; buildSettings = { + CLANG_ENABLE_MODULES = YES; CODE_SIGN_STYLE = Automatic; OTHER_LDFLAGS = "-ObjC"; PRODUCT_NAME = "$(TARGET_NAME)"; @@ -279,6 +289,8 @@ SUPPORTS_MACCATALYST = NO; SUPPORTS_MAC_DESIGNED_FOR_IPHONE_IPAD = NO; SUPPORTS_XR_DESIGNED_FOR_IPHONE_IPAD = NO; + SWIFT_OPTIMIZATION_LEVEL = "-Onone"; + SWIFT_VERSION = 6.0; TARGETED_DEVICE_FAMILY = "1,2"; }; name = Debug; @@ -286,6 +298,7 @@ 550986952CEF541900FECBB8 /* Release */ = { isa = XCBuildConfiguration; buildSettings = { + CLANG_ENABLE_MODULES = YES; CODE_SIGN_STYLE = Automatic; OTHER_LDFLAGS = "-ObjC"; 
PRODUCT_NAME = "$(TARGET_NAME)"; @@ -294,6 +307,7 @@ SUPPORTS_MACCATALYST = NO; SUPPORTS_MAC_DESIGNED_FOR_IPHONE_IPAD = NO; SUPPORTS_XR_DESIGNED_FOR_IPHONE_IPAD = NO; + SWIFT_VERSION = 6.0; TARGETED_DEVICE_FAMILY = "1,2"; }; name = Release; diff --git a/ios/RnExecutorch/OCR.h b/ios/RnExecutorch/OCR.h new file mode 100644 index 0000000000..4994108bce --- /dev/null +++ b/ios/RnExecutorch/OCR.h @@ -0,0 +1,5 @@ +#import + +@interface OCR : NSObject + +@end diff --git a/ios/RnExecutorch/OCR.mm b/ios/RnExecutorch/OCR.mm new file mode 100644 index 0000000000..a35db266d4 --- /dev/null +++ b/ios/RnExecutorch/OCR.mm @@ -0,0 +1,59 @@ +#import "OCR.h" +#import "models/object_detection/SSDLiteLargeModel.hpp" +#import +#import +#import "utils/ImageProcessor.h" +#import "models/ocr/Detector.h" + +@implementation OCR { + Detector *detector; +} + +RCT_EXPORT_MODULE() + +- (void)loadModule:(NSString *)detectorSource + recognizerSources:(NSArray *)recognizerSources + language:(NSString *)language + resolve:(RCTPromiseResolveBlock)resolve + reject:(RCTPromiseRejectBlock)reject { + NSLog(@"TEST"); + detector = [[Detector alloc] init]; + [detector loadModel:[NSURL URLWithString:detectorSource] + completion:^(BOOL success, NSNumber *errorCode) { + if (success) { + resolve(errorCode); + return; + } + + NSError *error = [NSError + errorWithDomain:@"OCRErrorDomain" + code:[errorCode intValue] + userInfo:@{ + NSLocalizedDescriptionKey : [NSString + stringWithFormat:@"%ld", (long)[errorCode longValue]] + }]; + reject(@"init_module_error", error.localizedDescription, error); + return; + }]; +} + +- (void)forward:(NSString *)input + resolve:(RCTPromiseResolveBlock)resolve + reject:(RCTPromiseRejectBlock)reject { + @try { + cv::Mat image = [ImageProcessor readImage:input]; + NSArray* result = [detector runModel:image]; + resolve(result); + } @catch (NSException *exception) { + reject(@"forward_error", [NSString stringWithFormat:@"%@", exception.reason], + nil); + } +} + +- 
(std::shared_ptr)getTurboModule: +(const facebook::react::ObjCTurboModule::InitParams &)params { + return std::make_shared( + params); +} + +@end diff --git a/ios/RnExecutorch/models/ocr/Detector.h b/ios/RnExecutorch/models/ocr/Detector.h new file mode 100644 index 0000000000..1c8619bd38 --- /dev/null +++ b/ios/RnExecutorch/models/ocr/Detector.h @@ -0,0 +1,11 @@ +#import "BaseModel.h" +#import "opencv2/opencv.hpp" + +@interface Detector : BaseModel + +- (cv::Size)getModelImageSize; +- (NSArray *)preprocess:(cv::Mat &)input; +- (NSArray *)postprocess:(NSArray *)output; +- (NSArray *)runModel:(cv::Mat &)input; + +@end diff --git a/ios/RnExecutorch/models/ocr/Detector.mm b/ios/RnExecutorch/models/ocr/Detector.mm new file mode 100644 index 0000000000..2c91da8301 --- /dev/null +++ b/ios/RnExecutorch/models/ocr/Detector.mm @@ -0,0 +1,75 @@ +#import "opencv2/opencv.hpp" +#import "Detector.h" +#import "../../utils/ImageProcessor.h" +#import "utils/DetectorUtils.h" +#import "utils/OCRUtils.h" + +@implementation Detector { + cv::Size originalSize; +} + +- (cv::Size)getModelImageSize{ + NSArray * inputShape = [module getInputShape: @0]; + NSNumber *widthNumber = inputShape.lastObject; + NSNumber *heightNumber = inputShape[inputShape.count - 2]; + + int height = [heightNumber intValue]; + int width = [widthNumber intValue]; + return cv::Size(height, width); +} + +- (NSArray *)preprocess:(cv::Mat &)input { + self->originalSize = cv::Size(input.cols, input.rows); + + cv::Size modelImageSize = [self getModelImageSize]; + cv::Mat resizedImage; + resizedImage = [OCRUtils resizeWithPadding:input desiredWidth:modelImageSize.width desiredHeight:modelImageSize.height]; + + NSArray *modelInput = [DetectorUtils matToNSArray: resizedImage]; + return modelInput; +} + +- (NSArray *)postprocess:(NSArray *)output { + NSArray *predictions = [output objectAtIndex:0]; + + NSDictionary *splittedData = [DetectorUtils splitInterleavedNSArray:predictions]; + NSArray *scoreText = 
splittedData[@"ScoreText"]; + NSArray *scoreLink = splittedData[@"ScoreLink"]; + + cv::Mat scoreTextCV; + cv::Mat scoreLinkCV; + + scoreTextCV = [DetectorUtils arrayToMat:scoreText width:640 height:640]; + scoreLinkCV = [DetectorUtils arrayToMat:scoreLink width:640 height:640]; + + NSArray* boxes = [DetectorUtils getDetBoxes:scoreTextCV linkMap:scoreLinkCV textThreshold:0.7 linkThreshold:0.4 lowText:0.4]; + NSMutableArray *single_img_result = [NSMutableArray array]; + for (NSUInteger i = 0; i < [boxes count]; i++) { + NSArray *box = boxes[i]; + NSMutableArray *boxArray = [NSMutableArray arrayWithCapacity:4]; + // Iterate over each point in the box + for (NSValue *value in box) { + CGPoint point = [value CGPointValue]; + point.x *= 2; + point.y *= 2; + [boxArray addObject:@((int)point.x)]; + [boxArray addObject:@((int)point.y)]; + } + [single_img_result addObject:boxArray]; + } + + NSArray* horizontalList = [DetectorUtils groupTextBox:single_img_result slopeThs:0.1 ycenterThs:0.5 heightThs:0.5 widthThs:1.0 addMargin:0.1]; + + return horizontalList; +} + +- (NSArray *)runModel:(cv::Mat &)input { + NSArray *modelInput = [self preprocess:input]; + NSArray *modelResult = [self forward:modelInput]; + NSArray *result = [self postprocess:modelResult]; + NSLog(@"Running Inference with detector model"); + + return result; +} + +@end diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h new file mode 100644 index 0000000000..a6b25969d1 --- /dev/null +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h @@ -0,0 +1,16 @@ +#import + +@interface DetectorUtils : NSObject + ++ (NSArray *)matToNSArray:(const cv::Mat &)mat; ++ (NSDictionary *)splitInterleavedNSArray:(NSArray *)array; ++ (cv::Mat)arrayToMat:(NSArray *)array width:(int)width height:(int)height; ++ (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold:(double)textThreshold linkThreshold:(double)linkThreshold lowText:(double)lowText; 
++ (NSArray *> *)groupTextBox:(NSArray *> *)polys + slopeThs:(CGFloat)slopeThs + ycenterThs:(CGFloat)ycenterThs + heightThs:(CGFloat)heightThs + widthThs:(CGFloat)widthThs + addMargin:(CGFloat)addMargin; + +@end diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm new file mode 100644 index 0000000000..8b6edb12c1 --- /dev/null +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm @@ -0,0 +1,279 @@ +#import "DetectorUtils.h" + +@implementation DetectorUtils + ++ (NSArray *)matToNSArray:(const cv::Mat &)mat { + cv::Scalar mean(0.485, 0.456, 0.406); + cv::Scalar variance(0.229, 0.224, 0.225); + + int pixelCount = mat.cols * mat.rows; + NSMutableArray *floatArray = [[NSMutableArray alloc] initWithCapacity:pixelCount * 3]; + for (NSUInteger k = 0; k < pixelCount * 3; k++) { + [floatArray addObject:@0.0]; + } + + for (int i = 0; i < pixelCount; i++) { + int row = i / mat.cols; + int col = i % mat.cols; + cv::Vec3b pixel = mat.at(row, col); + floatArray[0 * pixelCount + i] = @((pixel[0] - mean[0] * 255.0) / (variance[0] * 255.0)); + floatArray[1 * pixelCount + i] = @((pixel[1] - mean[1] * 255.0) / (variance[1] * 255.0)); + floatArray[2 * pixelCount + i] = @((pixel[2] - mean[2] * 255.0) / (variance[2] * 255.0)); + } + + return floatArray; +} + ++ (NSDictionary *)splitInterleavedNSArray:(NSArray *)array { + NSMutableArray *scoreText = [[NSMutableArray alloc] init]; + NSMutableArray *scoreLink = [[NSMutableArray alloc] init]; + + // Iterate through the array and distribute elements to scoreText or scoreLink + [array enumerateObjectsUsingBlock:^(id element, NSUInteger idx, BOOL *stop) { + if (idx % 2 == 0) { // Even index, belongs to scoreText + [scoreText addObject:element]; + } else { // Odd index, belongs to scoreLink + [scoreLink addObject:element]; + } + }]; + + return @{@"ScoreText": scoreText, @"ScoreLink": scoreLink}; +} + ++ (cv::Mat)arrayToMat:(NSArray *)array width:(int)width height:(int)height { + 
cv::Mat mat(height, width, CV_32F); + + int pixelCount = width * height; + for (int i = 0; i < pixelCount; i++) { + int row = i / width; + int col = i % width; + float value = [array[i] floatValue]; + mat.at(row, col) = value; + } + + return mat; +} + ++ (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold:(double)textThreshold linkThreshold:(double)linkThreshold lowText:(double)lowText { + cv::Mat textmapCopy = textmap.clone(); + cv::Mat linkmapCopy = linkmap.clone(); + int img_h = textmap.rows; + int img_w = textmap.cols; + cv::Mat textScore, linkScore; + cv::threshold(textmapCopy, textScore, lowText, 1, 0); + cv::threshold(linkmapCopy, linkScore, linkThreshold, 1, 0); + cv::Mat textScoreComb = textScore + linkScore; + cv::threshold(textScoreComb, textScoreComb, 0, 1, cv::THRESH_BINARY); + cv::Mat binaryMat; + textScoreComb.convertTo(binaryMat, CV_8UC1); + + cv::Mat labels, stats, centroids; + int nLabels = cv::connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4); + + NSMutableArray *detectedBoxes = [NSMutableArray array]; + for (int i = 1; i < nLabels; i++) { + int area = stats.at(i, cv::CC_STAT_AREA); + if (area < 10) continue; + + cv::Mat mask = (labels == i); + double maxVal; + cv::minMaxLoc(textmapCopy, NULL, &maxVal, NULL, NULL, mask); + if (maxVal < textThreshold) continue; + + // Create mask for segmented area + cv::Mat segMap = cv::Mat::zeros(textmap.size(), CV_8U); + segMap.setTo(255, (labels == i)); + + // Dilate the segmented area + int x = stats.at(i, cv::CC_STAT_LEFT); + int y = stats.at(i, cv::CC_STAT_TOP); + int w = stats.at(i, cv::CC_STAT_WIDTH); + int h = stats.at(i, cv::CC_STAT_HEIGHT); + + int niter = sqrt(area * MIN(w, h) / (w * h)) * 2; + int sx = x - niter; + int ex = x + w + niter + 1; + int sy = y - niter; + int ey = y + h + niter + 1; + + if (sx < 0) sx = 0; + if (sy < 0) sy = 0; + if (ex >= img_w) ex = img_w; + if (ey >= img_h) ey = img_h; + + cv::Rect roi(sx, sy, ex - sx, ey - sy); // (x, y, 
width, height) of the ROI + + // Generate the kernel for dilation + cv::Mat kernel = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(1 + niter, 1 + niter)); + cv::Mat roiSegMap = segMap(roi); // Reference a sub-region of segMap for dilation + cv::dilate(roiSegMap, roiSegMap, kernel); + + // Find contours and fit rotated rectangle + std::vector> contours; + cv::findContours(segMap, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE); + if (!contours.empty()) { + cv::RotatedRect minRect = cv::minAreaRect(contours[0]); + cv::Point2f vertices[4]; + minRect.points(vertices); + NSMutableArray *pointsArray = [NSMutableArray arrayWithCapacity:4]; + for (int j = 0; j < 4; j++) { + CGPoint point = CGPointMake(vertices[j].x, vertices[j].y); + [pointsArray addObject:[NSValue valueWithCGPoint:point]]; + } + [detectedBoxes addObject:pointsArray]; + } + } + + return detectedBoxes; +} + ++ (NSArray *> *)groupTextBox:(NSArray *> *)polys + slopeThs:(CGFloat)slopeThs + ycenterThs:(CGFloat)ycenterThs + heightThs:(CGFloat)heightThs + widthThs:(CGFloat)widthThs + addMargin:(CGFloat)addMargin +{ + NSMutableArray *> *horizontalList = [NSMutableArray array]; + NSMutableArray *> *> *combinedList = [NSMutableArray array]; + NSMutableArray *> *mergedList = [NSMutableArray array]; + + for (NSArray *poly in polys) { + NSArray *xCoords = @[poly[0], poly[2], poly[4], poly[6]]; + + // Array of y coordinates + NSArray *yCoords = @[poly[1], poly[3], poly[5], poly[7]]; + + // Calculating max and min values for x coordinates + NSNumber *xMaxNumber = [xCoords valueForKeyPath:@"@max.self"]; + NSNumber *xMinNumber = [xCoords valueForKeyPath:@"@min.self"]; + int xMax = [xMaxNumber intValue]; // Convert max float value to int + int xMin = [xMinNumber intValue]; // Convert min float value to int + + // Calculating max and min values for y coordinates + NSNumber *yMaxNumber = [yCoords valueForKeyPath:@"@max.self"]; + NSNumber *yMinNumber = [yCoords valueForKeyPath:@"@min.self"]; + int yMax = [yMaxNumber 
intValue]; // Convert max float value to int + int yMin = [yMinNumber intValue]; // Convert min float value to int + + [horizontalList addObject:[@[@(xMin), @(xMax), @(yMin), @(yMax), @((yMin + yMax) / 2.0), @(yMax - yMin)] mutableCopy]]; + } + + [horizontalList sortUsingComparator:^NSComparisonResult(NSMutableArray *obj1, NSMutableArray *obj2) { + return [obj1[4] compare:obj2[4]]; // Sorting by y_center + }]; + + NSMutableArray *newBox = [NSMutableArray array]; + NSMutableArray *bHeight = [NSMutableArray array]; + NSMutableArray *bYcenter = [NSMutableArray array]; + + for (NSArray *box in horizontalList) { + if (newBox.count == 0) { + [bHeight addObject:box[5]]; + [bYcenter addObject:box[4]]; + [newBox addObject:box]; + } else { + if (fabs([[bYcenter valueForKeyPath:@"@avg.self"] floatValue] - [box[4] floatValue]) < ycenterThs * [[bHeight valueForKeyPath:@"@avg.self"] floatValue]) { + [bHeight addObject:box[5]]; + [bYcenter addObject:box[4]]; + [newBox addObject:box]; + } else { + [combinedList addObject:[newBox copy]]; + [newBox removeAllObjects]; + [newBox addObject:box]; + bHeight = [@[box[5]] mutableCopy]; + bYcenter = [@[box[4]] mutableCopy]; + } + } + } + + [combinedList addObject:[newBox copy]]; + + for (NSArray *boxes in combinedList) { + if ([boxes count] == 1) { // If there is only one box in the line + NSArray *box = boxes[0]; + int margin = (int)(addMargin * MIN([box[1] floatValue] - [box[0] floatValue], [box[5] floatValue])); + [mergedList addObject:@[@([box[0] intValue] - margin), + @([box[1] intValue] + margin), + @([box[2] intValue] - margin), + @([box[3] intValue] + margin)]]; + } else { // There are multiple boxes to be merged + NSArray *sortedBoxes = [boxes sortedArrayUsingComparator:^NSComparisonResult(NSArray *obj1, NSArray *obj2) { + return [@([obj1[0] intValue]) compare:@([obj2[0] intValue])]; // Sort boxes by x_min + }]; + + NSMutableArray *mergedBox = [NSMutableArray array]; + NSMutableArray *newBox = [NSMutableArray array]; + int xMax = 
0; + NSMutableArray *bHeight = [NSMutableArray array]; + + for (NSArray *box in sortedBoxes) { + if ([newBox count] == 0) { + [bHeight addObject:box[5]]; + xMax = [box[1] intValue]; + [newBox addObject:box]; + } else { + int currHeight = [box[5] intValue]; + float meanHeight = [[bHeight valueForKeyPath:@"@avg.self"] floatValue]; + if (fabs(meanHeight - currHeight) < heightThs * meanHeight && + ([box[0] intValue] - xMax) < widthThs * ([box[3] intValue] - [box[2] intValue])) { + // merge condition is met + [bHeight addObject:box[5]]; + xMax = [box[1] intValue]; + [newBox addObject:box]; + } else { + [mergedBox addObject:[newBox copy]]; + newBox = [@[box] mutableCopy]; + bHeight = [@[box[5]] mutableCopy]; + xMax = [box[1] intValue]; + } + } + } + if ([newBox count] > 0) { + [mergedBox addObject:newBox]; + } + + // Create merged boxes from merged box array + for (NSArray *mbox in mergedBox) { + if ([mbox count] != 1) { + NSNumber *xMin = [mbox[0] objectAtIndex:0]; // minX + NSNumber *xMax = [mbox[0] objectAtIndex:1]; // maxX + NSNumber *yMin = [mbox[0] objectAtIndex:2]; // minY + NSNumber *yMax = [mbox[0] objectAtIndex:3]; // maxY + // Iterate over each box in the mbox array to find min and max + for (NSArray *box in mbox) { + if ([box[0] intValue] < [xMin intValue]) { + xMin = box[0]; + } + if([box[1] intValue] > [xMax intValue]) { + xMax = box[1]; + } + if ([box[2] intValue] < [yMin intValue]) { + yMin = box[2]; + } + if ([box[3] intValue] > [yMax intValue]) { + yMax = box[3]; + } + } + + int margin = (int)(addMargin * MIN([xMax floatValue] - [xMin floatValue], [yMax floatValue] - [yMin floatValue])); + [mergedList addObject:@[@([xMin intValue] - margin), + @([xMax intValue] + margin), + @([yMin intValue] - margin), + @([yMax intValue] + margin)]]; + } else { + NSArray *box = mbox[0]; + int margin = (int)(addMargin * MIN([box[1] floatValue] - [box[0] floatValue], [box[3] floatValue] - [box[2] floatValue])); + [mergedList addObject:@[@([box[0] intValue] - margin), + 
@([box[1] intValue] + margin), + @([box[2] intValue] - margin), + @([box[3] intValue] + margin)]]; + } + } + } + } + + return mergedList; +} + +@end diff --git a/ios/RnExecutorch/models/ocr/utils/OCRUtils.h b/ios/RnExecutorch/models/ocr/utils/OCRUtils.h new file mode 100644 index 0000000000..0304ad37e3 --- /dev/null +++ b/ios/RnExecutorch/models/ocr/utils/OCRUtils.h @@ -0,0 +1,7 @@ +#import + +@interface OCRUtils : NSObject + ++ (cv::Mat)resizeWithPadding:(cv::Mat)img desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight; + +@end diff --git a/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm new file mode 100644 index 0000000000..071ea2bd9c --- /dev/null +++ b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm @@ -0,0 +1,50 @@ +#import "OCRUtils.h" + +@implementation OCRUtils + ++ (cv::Mat)resizeWithPadding:(cv::Mat)img desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight { + const int height = img.rows; + const int width = img.cols; + const float heightRatio = (float)desiredHeight / height; + const float widthRatio = (float)desiredWidth / width; + const float resizeRatio = MIN(heightRatio, widthRatio); + + const int newWidth = width * resizeRatio; + const int newHeight = height * resizeRatio; + + cv::Mat resizedImg; + cv::resize(img, resizedImg, cv::Size(newWidth, newHeight), 0, 0, cv::INTER_AREA); + + // Estimating the background color by sampling from the corners of the image + const int cornerPatchSize = MAX(1, MIN(height, width) / 30); + std::vector corners = { + img(cv::Rect(0, 0, cornerPatchSize, cornerPatchSize)), + img(cv::Rect(width - cornerPatchSize, 0, cornerPatchSize, cornerPatchSize)), + img(cv::Rect(0, height - cornerPatchSize, cornerPatchSize, cornerPatchSize)), + img(cv::Rect(width - cornerPatchSize, height - cornerPatchSize, cornerPatchSize, cornerPatchSize)) + }; + + cv::Scalar backgroundScalar = cv::mean(corners[0]); + for (int i = 1; i < corners.size(); i++) { + backgroundScalar += 
cv::mean(corners[i]); + } + backgroundScalar /= (double)corners.size(); + + backgroundScalar[0] = cvFloor(backgroundScalar[0]); + backgroundScalar[1] = cvFloor(backgroundScalar[1]); + backgroundScalar[2] = cvFloor(backgroundScalar[2]); + + const int deltaW = desiredWidth - newWidth; + const int deltaH = desiredHeight - newHeight; + const int top = deltaH / 2; + const int bottom = deltaH - top; + const int left = deltaW / 2; + const int right = deltaW - left; + + cv::Mat centeredImg; + cv::copyMakeBorder(resizedImg, centeredImg, top, bottom, left, right, cv::BORDER_CONSTANT, backgroundScalar); + + return centeredImg; +} + +@end diff --git a/src/OCR.ts b/src/OCR.ts new file mode 100644 index 0000000000..43c36ecbff --- /dev/null +++ b/src/OCR.ts @@ -0,0 +1,77 @@ +import { useEffect, useState } from 'react'; +import { ResourceSource } from './types/common'; +import { OCR } from './native/RnExecutorchModules'; +import { ETError, getError } from './Error'; +import { Image } from 'react-native'; + +interface OCRModule { + error: string | null; + isReady: boolean; + isGenerating: boolean; + forward: (input: string) => Promise; +} + +const getModelPath = (source: ResourceSource) => { + if (typeof source === 'number') { + return Image.resolveAssetSource(source).uri; + } + return source; +}; + +export const useOCR = ({ + detectorSource, + recognizerSources, + language = 'en', +}: { + detectorSource: ResourceSource; + recognizerSources: ResourceSource[]; + language?: string; +}): OCRModule => { + const [error, setError] = useState(null); + const [isReady, setIsReady] = useState(false); + const [isGenerating, setIsGenerating] = useState(false); + + useEffect(() => { + const loadModel = async () => { + if (!detectorSource || recognizerSources.length === 0) return; + + const detectorPath = getModelPath(detectorSource); + const recognizerPaths = recognizerSources.map(getModelPath); + try { + setIsReady(false); + await OCR.loadModule(detectorPath, recognizerPaths, language); + 
setIsReady(true); + } catch (e) { + setError(getError(e)); + } + }; + + loadModel(); + }, [detectorSource, language, recognizerSources.length]); + + const forward = async (input: string) => { + if (!isReady) { + throw new Error(getError(ETError.ModuleNotLoaded)); + } + if (isGenerating) { + throw new Error(getError(ETError.ModelGenerating)); + } + + try { + setIsGenerating(true); + const output = await OCR.forward(input); + return output; + } catch (e) { + throw new Error(getError(e)); + } finally { + setIsGenerating(false); + } + }; + + return { + error, + isReady, + isGenerating, + forward, + }; +}; diff --git a/src/native/NativeOCR.ts b/src/native/NativeOCR.ts new file mode 100644 index 0000000000..f819ca8975 --- /dev/null +++ b/src/native/NativeOCR.ts @@ -0,0 +1,13 @@ +import type { TurboModule } from 'react-native'; +import { TurboModuleRegistry } from 'react-native'; + +export interface Spec extends TurboModule { + loadModule( + detectorSource: string, + recognizerSources: string[], + language: string + ): Promise; + forward(input: string): Promise; +} + +export default TurboModuleRegistry.get('OCR'); diff --git a/src/native/RnExecutorchModules.ts b/src/native/RnExecutorchModules.ts index fc1cfe28db..c8044aa473 100644 --- a/src/native/RnExecutorchModules.ts +++ b/src/native/RnExecutorchModules.ts @@ -88,6 +88,19 @@ const SpeechToText = SpeechToTextSpec } ); +const OCRSpec = require('./NativeOCR').default; + +const OCR = OCRSpec + ? 
OCRSpec + : new Proxy( + {}, + { + get() { + throw new Error(LINKING_ERROR); + }, + } + ); + class _ObjectDetectionModule { async forward( input: string @@ -168,6 +181,7 @@ export { ObjectDetection, StyleTransfer, SpeechToText, + OCR, _ETModule, _ClassificationModule, _StyleTransferModule, From a0e65a4b7f15da0e61e5ef49cf008b8233932394 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Fri, 17 Jan 2025 17:29:06 +0100 Subject: [PATCH 02/19] fix: fixes to groupTextBox and getDetBox function to make it return similar result to original version --- ios/RnExecutorch/models/ocr/Detector.mm | 20 ++++++++++++++++-- .../models/ocr/utils/DetectorUtils.h | 1 - .../models/ocr/utils/DetectorUtils.mm | 21 +++++++------------ 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/ios/RnExecutorch/models/ocr/Detector.mm b/ios/RnExecutorch/models/ocr/Detector.mm index 2c91da8301..23f3b2f9aa 100644 --- a/ios/RnExecutorch/models/ocr/Detector.mm +++ b/ios/RnExecutorch/models/ocr/Detector.mm @@ -52,14 +52,30 @@ - (NSArray *)postprocess:(NSArray *)output { CGPoint point = [value CGPointValue]; point.x *= 2; point.y *= 2; +// NSLog(@"%d %d", (int)point.x, (int)point.y); [boxArray addObject:@((int)point.x)]; [boxArray addObject:@((int)point.y)]; } [single_img_result addObject:boxArray]; } - NSArray* horizontalList = [DetectorUtils groupTextBox:single_img_result slopeThs:0.1 ycenterThs:0.5 heightThs:0.5 widthThs:1.0 addMargin:0.1]; - + NSArray* horizontalList = [DetectorUtils groupTextBox:single_img_result ycenterThs:0.5 heightThs:0.5 widthThs:0.5 addMargin:0.1]; + NSLog(@"%lu", (unsigned long)[horizontalList count]); + + NSMutableArray *boxesToKeep = [NSMutableArray array]; // Create a new array to keep the boxes that fit the condition + + for (NSArray *box in horizontalList) { + if (MAX([box[1] intValue] - [box[0] intValue], [box[3] intValue] - [box[2] intValue]) >= 20) { + + [boxesToKeep addObject:box]; // Add the box to the new array only if it meets the condition + } + } 
+ + horizontalList = [NSMutableArray arrayWithArray:boxesToKeep]; + NSLog(@"%lu", (unsigned long)[horizontalList count]); + for(NSArray *box in horizontalList){ + NSLog(@"%d %d %d %d", [box[0] intValue], [box[1] intValue], [box[2] intValue], [box[3] intValue]); + } return horizontalList; } diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h index a6b25969d1..716d0e402e 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h @@ -7,7 +7,6 @@ + (cv::Mat)arrayToMat:(NSArray *)array width:(int)width height:(int)height; + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold:(double)textThreshold linkThreshold:(double)linkThreshold lowText:(double)lowText; + (NSArray *> *)groupTextBox:(NSArray *> *)polys - slopeThs:(CGFloat)slopeThs ycenterThs:(CGFloat)ycenterThs heightThs:(CGFloat)heightThs widthThs:(CGFloat)widthThs diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm index 8b6edb12c1..21aa76a9cb 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm @@ -89,18 +89,15 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold int y = stats.at(i, cv::CC_STAT_TOP); int w = stats.at(i, cv::CC_STAT_WIDTH); int h = stats.at(i, cv::CC_STAT_HEIGHT); - - int niter = sqrt(area * MIN(w, h) / (w * h)) * 2; + int niter = (int)(sqrt((double)(area * MIN(w, h)) / (double)(w * h)) * 2.0); int sx = x - niter; int ex = x + w + niter + 1; int sy = y - niter; int ey = y + h + niter + 1; - if (sx < 0) sx = 0; if (sy < 0) sy = 0; if (ex >= img_w) ex = img_w; if (ey >= img_h) ey = img_h; - cv::Rect roi(sx, sy, ex - sx, ey - sy); // (x, y, width, height) of the ROI // Generate the kernel for dilation @@ -128,7 +125,6 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap 
textThreshold } + (NSArray *> *)groupTextBox:(NSArray *> *)polys - slopeThs:(CGFloat)slopeThs ycenterThs:(CGFloat)ycenterThs heightThs:(CGFloat)heightThs widthThs:(CGFloat)widthThs @@ -147,14 +143,14 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold // Calculating max and min values for x coordinates NSNumber *xMaxNumber = [xCoords valueForKeyPath:@"@max.self"]; NSNumber *xMinNumber = [xCoords valueForKeyPath:@"@min.self"]; - int xMax = [xMaxNumber intValue]; // Convert max float value to int - int xMin = [xMinNumber intValue]; // Convert min float value to int + float xMax = [xMaxNumber floatValue]; // Convert max float value to int + float xMin = [xMinNumber floatValue]; // Convert min float value to int // Calculating max and min values for y coordinates NSNumber *yMaxNumber = [yCoords valueForKeyPath:@"@max.self"]; NSNumber *yMinNumber = [yCoords valueForKeyPath:@"@min.self"]; - int yMax = [yMaxNumber intValue]; // Convert max float value to int - int yMin = [yMinNumber intValue]; // Convert min float value to int + float yMax = [yMaxNumber floatValue]; // Convert max float value to int + float yMin = [yMinNumber floatValue]; // Convert min float value to int [horizontalList addObject:[@[@(xMin), @(xMax), @(yMin), @(yMax), @((yMin + yMax) / 2.0), @(yMax - yMin)] mutableCopy]]; } @@ -162,11 +158,10 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold [horizontalList sortUsingComparator:^NSComparisonResult(NSMutableArray *obj1, NSMutableArray *obj2) { return [obj1[4] compare:obj2[4]]; // Sorting by y_center }]; - + NSMutableArray *newBox = [NSMutableArray array]; NSMutableArray *bHeight = [NSMutableArray array]; NSMutableArray *bYcenter = [NSMutableArray array]; - for (NSArray *box in horizontalList) { if (newBox.count == 0) { [bHeight addObject:box[5]]; @@ -188,9 +183,9 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold } [combinedList addObject:[newBox copy]]; - 
for (NSArray *boxes in combinedList) { if ([boxes count] == 1) { // If there is only one box in the line + NSLog(@"One in line"); NSArray *box = boxes[0]; int margin = (int)(addMargin * MIN([box[1] floatValue] - [box[0] floatValue], [box[5] floatValue])); [mergedList addObject:@[@([box[0] intValue] - margin), @@ -272,7 +267,7 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold } } } - + NSLog(@"Merged List Count: %lu", (unsigned long)[mergedList count]); return mergedList; } From 4190ced895f9ecc51bb72be6d8fb0380d82a56a3 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 22 Jan 2025 15:22:58 +0100 Subject: [PATCH 03/19] feat: implemented recognition (ios) --- ios/RnExecutorch/OCR.mm | 7 +- ios/RnExecutorch/models/ocr/Detector.mm | 24 ++-- .../models/ocr/RecognitionHandler.h | 7 + .../models/ocr/RecognitionHandler.mm | 134 ++++++++++++++++++ .../models/ocr/utils/CTCLabelConverter.h | 15 ++ .../models/ocr/utils/CTCLabelConverter.mm | 98 +++++++++++++ ios/RnExecutorch/models/ocr/utils/OCRUtils.h | 5 + ios/RnExecutorch/models/ocr/utils/OCRUtils.mm | 90 ++++++++++++ 8 files changed, 363 insertions(+), 17 deletions(-) create mode 100644 ios/RnExecutorch/models/ocr/RecognitionHandler.h create mode 100644 ios/RnExecutorch/models/ocr/RecognitionHandler.mm create mode 100644 ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h create mode 100644 ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm diff --git a/ios/RnExecutorch/OCR.mm b/ios/RnExecutorch/OCR.mm index a35db266d4..96a6175016 100644 --- a/ios/RnExecutorch/OCR.mm +++ b/ios/RnExecutorch/OCR.mm @@ -4,9 +4,11 @@ #import #import "utils/ImageProcessor.h" #import "models/ocr/Detector.h" +#import "models/ocr/RecognitionHandler.h" @implementation OCR { Detector *detector; + RecognitionHandler *handler; } RCT_EXPORT_MODULE() @@ -16,7 +18,6 @@ - (void)loadModule:(NSString *)detectorSource language:(NSString *)language resolve:(RCTPromiseResolveBlock)resolve 
reject:(RCTPromiseRejectBlock)reject { - NSLog(@"TEST"); detector = [[Detector alloc] init]; [detector loadModel:[NSURL URLWithString:detectorSource] completion:^(BOOL success, NSNumber *errorCode) { @@ -43,6 +44,10 @@ - (void)forward:(NSString *)input @try { cv::Mat image = [ImageProcessor readImage:input]; NSArray* result = [detector runModel:image]; + cv::cvtColor(image, image, cv::COLOR_BGR2GRAY); + handler = [[RecognitionHandler alloc] init]; + NSLog(@"TEST"); + [handler recognize:result imgGray:image desiredWidth:1280 desiredHeight:1280]; resolve(result); } @catch (NSException *exception) { reject(@"forward_error", [NSString stringWithFormat:@"%@", exception.reason], diff --git a/ios/RnExecutorch/models/ocr/Detector.mm b/ios/RnExecutorch/models/ocr/Detector.mm index 23f3b2f9aa..6038cc101f 100644 --- a/ios/RnExecutorch/models/ocr/Detector.mm +++ b/ios/RnExecutorch/models/ocr/Detector.mm @@ -47,12 +47,10 @@ - (NSArray *)postprocess:(NSArray *)output { for (NSUInteger i = 0; i < [boxes count]; i++) { NSArray *box = boxes[i]; NSMutableArray *boxArray = [NSMutableArray arrayWithCapacity:4]; - // Iterate over each point in the box for (NSValue *value in box) { CGPoint point = [value CGPointValue]; point.x *= 2; point.y *= 2; -// NSLog(@"%d %d", (int)point.x, (int)point.y); [boxArray addObject:@((int)point.x)]; [boxArray addObject:@((int)point.y)]; } @@ -60,22 +58,17 @@ - (NSArray *)postprocess:(NSArray *)output { } NSArray* horizontalList = [DetectorUtils groupTextBox:single_img_result ycenterThs:0.5 heightThs:0.5 widthThs:0.5 addMargin:0.1]; - NSLog(@"%lu", (unsigned long)[horizontalList count]); - - NSMutableArray *boxesToKeep = [NSMutableArray array]; // Create a new array to keep the boxes that fit the condition - + + NSMutableArray *boxesToKeep = [NSMutableArray array]; + for (NSArray *box in horizontalList) { - if (MAX([box[1] intValue] - [box[0] intValue], [box[3] intValue] - [box[2] intValue]) >= 20) { - - [boxesToKeep addObject:box]; // Add the box to the 
new array only if it meets the condition - } + if (MAX([box[1] intValue] - [box[0] intValue], [box[3] intValue] - [box[2] intValue]) >= 20) { + + [boxesToKeep addObject:box]; + } } - + horizontalList = [NSMutableArray arrayWithArray:boxesToKeep]; - NSLog(@"%lu", (unsigned long)[horizontalList count]); - for(NSArray *box in horizontalList){ - NSLog(@"%d %d %d %d", [box[0] intValue], [box[1] intValue], [box[2] intValue], [box[3] intValue]); - } return horizontalList; } @@ -83,7 +76,6 @@ - (NSArray *)runModel:(cv::Mat &)input { NSArray *modelInput = [self preprocess:input]; NSArray *modelResult = [self forward:modelInput]; NSArray *result = [self postprocess:modelResult]; - NSLog(@"Running Inference with detector model"); return result; } diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.h b/ios/RnExecutorch/models/ocr/RecognitionHandler.h new file mode 100644 index 0000000000..707700a97a --- /dev/null +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.h @@ -0,0 +1,7 @@ +#import "opencv2/opencv.hpp" + +@interface RecognitionHandler: NSObject + +- (NSArray *)recognize: (NSArray *)horizontalList imgGray:(cv::Mat)imgGray desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight; + +@end diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm new file mode 100644 index 0000000000..0f5e891e7e --- /dev/null +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm @@ -0,0 +1,134 @@ +#import "RecognitionHandler.h" +#import +#import "./utils/OCRUtils.h" +#import "../../utils/ImageProcessor.h" +#import "./utils/CTCLabelConverter.h" +#import "ExecutorchLib/ETModel.h" + +@implementation RecognitionHandler + +- (NSArray *)indicesOfMaxValuesInMatrix:(cv::Mat)matrix { + // Ensure the matrix is 2D and has more than one column to avoid trivial results. 
+ NSAssert(matrix.dims == 2 && matrix.cols > 1, @"Matrix must be 2D with more than one column."); + + NSMutableArray *maxIndices = [NSMutableArray array]; + + // Iterating over each row to find the index of the max element + for (int i = 0; i < matrix.rows; i++) { + double maxVal; // Variable to store the maximum value (not used) + cv::Point maxLoc; // This will store the location of the maximum value + cv::minMaxLoc(matrix.row(i), NULL, &maxVal, NULL, &maxLoc); + [maxIndices addObject:@(maxLoc.x)]; // Add the index of the max value to the array + } + + return [maxIndices copy]; // Return an NSArray copy of the mutable array +} + + +- (cv::Mat)divideMatrix:(cv::Mat)matrix byVector:(NSArray *)vector { + // Ensure the vector's length matches the number of rows in the matrix + NSAssert(matrix.rows == vector.count, @"Vector length must match number of matrix rows."); + + cv::Mat result = matrix.clone(); // Clone the matrix to keep the original unchanged + + // Iterate through each element in the matrix and divide by the corresponding vector element + for (int i = 0; i < matrix.rows; i++) { + float divisor = [vector[i] floatValue]; // Get the CGFloat value from NSArray + for (int j = 0; j < matrix.cols; j++) { + result.at(i, j) /= divisor; + } + } + + return result; +} + +- (cv::Mat)softmax:(cv::Mat) inputs { + cv::Mat maxVal; + cv::reduce(inputs, maxVal, 1, cv::REDUCE_MAX, CV_32F); // Find max per row for numerical stability + cv::Mat expInputs; + cv::exp(inputs - cv::repeat(maxVal, 1, inputs.cols), expInputs); // Compute exp(values - max) + cv::Mat sumExp; + cv::reduce(expInputs, sumExp, 1, cv::REDUCE_SUM, CV_32F); // Sum of exp per row + cv::Mat softmaxOutput = expInputs / cv::repeat(sumExp, 1, inputs.cols); // Divide by sum(exp) + return softmaxOutput; +} + +- (NSArray *)recognize: (NSArray *)horizontalList imgGray:(cv::Mat)imgGray desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight { + NSLog(@"Before padding"); + NSString *modelPath = [[NSBundle 
mainBundle] pathForResource:@"xnnpack_crnn_512" ofType:@"pte"]; + ETModel *recognizer_512 = [[ETModel alloc] init]; + [recognizer_512 loadModel:modelPath]; + ETModel *recognizer_256 = [[ETModel alloc] init]; + modelPath = [[NSBundle mainBundle] pathForResource:@"xnnpack_crnn_256" ofType:@"pte"]; + [recognizer_256 loadModel:modelPath]; + ETModel *recognizer_128 = [[ETModel alloc] init]; + modelPath = [[NSBundle mainBundle] pathForResource:@"xnnpack_crnn_128" ofType:@"pte"]; + [recognizer_128 loadModel:modelPath]; + + imgGray = [OCRUtils resizeWithPadding:imgGray desiredWidth:desiredWidth desiredHeight:desiredHeight]; + for (NSArray *box in horizontalList) { + int maximum_y = imgGray.rows; + int maximum_x = imgGray.cols; + + int x_min = MAX(0, [box[0] intValue]); + int x_max = MIN([box[1] intValue], maximum_x); + int y_min = MAX(0, [box[2] intValue]); + int y_max = MIN([box[3] intValue], maximum_y); + cv::Mat croppedImage = [OCRUtils getCroppedImage:x_max x_min:x_min y_max:y_max y_min:y_min image:imgGray modelHeight:64]; + + + croppedImage = [OCRUtils normalizeForRecognizer:croppedImage adjustContrast:0.0]; + NSArray* modelInput = [ImageProcessor matToNSArrayForGrayscale:croppedImage]; + NSArray *result; + if(croppedImage.cols >= 512) { + result = [recognizer_512 forward:modelInput shape:[recognizer_512 getInputShape:0] inputType:[recognizer_512 getInputType:0]]; + } else if (croppedImage.cols >= 256) { + result = [recognizer_256 forward:modelInput shape:[recognizer_256 getInputShape:0] inputType:[recognizer_256 getInputType:0]]; + } else { + result = [recognizer_128 forward:modelInput shape:[recognizer_128 getInputShape:0] inputType:[recognizer_128 getInputType:0]]; + } + + NSInteger totalNumbers = [result.firstObject count]; + NSInteger numRows = (totalNumbers + 96) / 97; // Each row has 97 columns, round up if needed + + // Initialize the matrix with appropriate size + cv::Mat resultMat = cv::Mat::zeros(numRows, 97, CV_32F); // 97 columns, floating point values + 
+ // Counter for columns and row tracker + NSInteger counter = 0; + NSInteger currentRow = 0; + + for (NSNumber *num in result.firstObject) { + // Set the value in the matrix + resultMat.at(currentRow, counter) = [num floatValue]; + + counter++; + if (counter >= 97) { + counter = 0; // Reset counter if 97 columns are filled + currentRow++; // Move to the next row + } + } + + cv::Mat probabilities = [self softmax:resultMat]; + NSMutableArray* pred_norm = [NSMutableArray arrayWithCapacity:probabilities.rows]; + for(int i = 0; i < probabilities.rows; i++) { + float sum = 0.0; + for(int j = 0; j < 97; j++) { + sum += probabilities.at(i, j); + } + [pred_norm addObject:@(sum)]; + } + + probabilities = [self divideMatrix:probabilities byVector:pred_norm]; + NSString *dictPath = [[NSBundle mainBundle] pathForResource:@"en" ofType:@"txt"]; + CTCLabelConverter *converter = [[CTCLabelConverter alloc] initWithCharacters:@"0123456789!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ €ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" separatorList:@{} dictPathList:@{@"en": dictPath}]; + NSArray* preds_index = [self indicesOfMaxValuesInMatrix:probabilities]; + NSArray* decodedTexts = [converter decodeGreedyWithTextIndex:preds_index length:(int)(preds_index.count)]; + NSLog(@"%@", decodedTexts[0]); + } + + return [NSArray init]; +} + +@end + diff --git a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h new file mode 100644 index 0000000000..7b879166a6 --- /dev/null +++ b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h @@ -0,0 +1,15 @@ +#import + +@interface CTCLabelConverter : NSObject + +@property (strong, nonatomic) NSMutableDictionary *dict; +@property (strong, nonatomic) NSArray *character; +@property (strong, nonatomic) NSDictionary *separatorList; +@property (strong, nonatomic) NSArray *ignoreIdx; +@property (strong, nonatomic) NSDictionary *dictList; + +- (instancetype)initWithCharacters:(NSString *)characters 
separatorList:(NSDictionary *)separatorList dictPathList:(NSDictionary *)dictPathList; +- (void)loadDictionariesWithDictPathList:(NSDictionary *)dictPathList; +- (NSArray *)decodeGreedyWithTextIndex:(NSArray *)textIndex length:(NSInteger)length; + +@end diff --git a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm new file mode 100644 index 0000000000..e8b8d0fbc5 --- /dev/null +++ b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm @@ -0,0 +1,98 @@ +#import "CTCLabelConverter.h" + +@implementation CTCLabelConverter + +- (instancetype)initWithCharacters:(NSString *)characters separatorList:(NSDictionary *)separatorList dictPathList:(NSDictionary *)dictPathList { + self = [super init]; + if (self) { + _dict = [NSMutableDictionary dictionary]; + NSMutableArray *mutableCharacters = [NSMutableArray arrayWithObject:@"[blank]"]; + + for (NSUInteger i = 0; i < [characters length]; i++) { + NSString *charStr = [NSString stringWithFormat:@"%C", [characters characterAtIndex:i]]; + [mutableCharacters addObject:charStr]; + self.dict[charStr] = @(i + 1); + } + + _character = [mutableCharacters copy]; + _separatorList = separatorList; + + NSMutableArray *ignoreIndexes = [NSMutableArray arrayWithObject:@(0)]; + for (NSString *sep in separatorList.allValues) { + NSUInteger index = [characters rangeOfString:sep].location; + if (index != NSNotFound) { + [ignoreIndexes addObject:@(index)]; + } + } + _ignoreIdx = [ignoreIndexes copy]; + _dictList = [NSDictionary dictionary]; + [self loadDictionariesWithDictPathList:dictPathList]; + } + return self; +} + +- (void)loadDictionariesWithDictPathList:(NSDictionary *)dictPathList { + NSMutableDictionary *tempDictList = [NSMutableDictionary dictionary]; + for (NSString *lang in dictPathList.allKeys) { + NSString *dictPath = dictPathList[lang]; + NSError *error; + NSString *fileContents = [NSString stringWithContentsOfFile:dictPath encoding:NSUTF8StringEncoding error:&error]; + 
if (error) { + NSLog(@"Error reading file: %@", error.localizedDescription); + continue; + } + NSArray *lines = [fileContents componentsSeparatedByCharactersInSet:[NSCharacterSet newlineCharacterSet]]; + [tempDictList setObject:lines forKey:lang]; + } + _dictList = [tempDictList copy]; +} + +- (NSArray *)decodeGreedyWithTextIndex:(NSArray *)textIndex length:(NSInteger)length { + NSMutableArray *texts = [NSMutableArray array]; + NSUInteger index = 0; + + // Loop until you've processed all characters + while (index < textIndex.count) { + NSUInteger segmentLength = MIN(length, textIndex.count - index); // Calculate size of the current segment + NSRange range = NSMakeRange(index, segmentLength); + NSArray *subArray = [textIndex subarrayWithRange:range]; + + NSMutableString *text = [NSMutableString string]; + NSNumber *lastChar = nil; + + // Creating mutable arrays to store states like in Python with `a` and `b` + NSMutableArray *isNotRepeated = [NSMutableArray arrayWithObject:@YES]; // First character is always not repeated + NSMutableArray *isNotIgnored = [NSMutableArray array]; + + for (NSUInteger i = 0; i < subArray.count; i++) { + NSNumber *currentChar = subArray[i]; + // Check if character is repeated + if (i > 0) { // From second character onward + [isNotRepeated addObject:@(![lastChar isEqualToNumber:currentChar])]; + } + // Check if the current character is in the ignore list + [isNotIgnored addObject:@(![self.ignoreIdx containsObject:currentChar])]; + + lastChar = currentChar; // Update lastChar to current character + } + + // Combine `isNotRepeated` and `isNotIgnored` conditions just like combining 'a' and 'b' in Python + for (NSUInteger j = 0; j < subArray.count; j++) { + if ([isNotRepeated[j] boolValue] && [isNotIgnored[j] boolValue]) { + NSUInteger charIndex = [subArray[j] unsignedIntegerValue]; + [text appendString:self.character[charIndex]]; + } + } + + [texts addObject:text.copy]; + index += segmentLength; // Move index forward + + if (segmentLength < 
length) { // If reached the end of textIndex + break; + } + } + + return texts.copy; +} + +@end diff --git a/ios/RnExecutorch/models/ocr/utils/OCRUtils.h b/ios/RnExecutorch/models/ocr/utils/OCRUtils.h index 0304ad37e3..ab9c6859ef 100644 --- a/ios/RnExecutorch/models/ocr/utils/OCRUtils.h +++ b/ios/RnExecutorch/models/ocr/utils/OCRUtils.h @@ -3,5 +3,10 @@ @interface OCRUtils : NSObject + (cv::Mat)resizeWithPadding:(cv::Mat)img desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight; ++ (cv::Mat)getCroppedImage:(int)x_max x_min:(int)x_min y_max:(int)y_max y_min:(int)y_min image:(cv::Mat)image modelHeight:(int)modelHeight; ++ (CGFloat)calculateRatioWithWidth:(int)width height:(int)height; ++ (cv::Mat)computeRatioAndResize:(cv::Mat)img width:(int)width height:(int)height modelHeight:(int)modelHeight; ++ (cv::Mat)normalizeForRecognizer:(cv::Mat)image adjustContrast:(double)adjustContrast; ++ (cv::Mat)adjustContrastGrey:(cv::Mat)img target:(double)target; @end diff --git a/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm index 071ea2bd9c..24110f8e31 100644 --- a/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm @@ -47,4 +47,94 @@ @implementation OCRUtils return centeredImg; } ++ (CGFloat)calculateRatioWithWidth:(int)width height:(int)height { + CGFloat ratio = (CGFloat)width / (CGFloat)height; + if (ratio < 1.0) { + ratio = 1.0 / ratio; + } + return ratio; +} + ++ (cv::Mat)computeRatioAndResize:(cv::Mat)img width:(int)width height:(int)height modelHeight:(int)modelHeight { + CGFloat ratio = (CGFloat)width / (CGFloat)height; + if (ratio < 1.0) { + ratio = [self calculateRatioWithWidth:width height:height]; + cv::resize(img, img, cv::Size(modelHeight, (int)(modelHeight * ratio)), 0, 0, cv::INTER_LANCZOS4); + } else { + cv::resize(img, img, cv::Size((int)(modelHeight * ratio), modelHeight), 0, 0, cv::INTER_LANCZOS4); + } + return img; +} + ++ 
(cv::Mat)getCroppedImage:(int)x_max x_min:(int)x_min y_max:(int)y_max y_min:(int)y_min image:(cv::Mat)image modelHeight:(int)modelHeight { + cv::Rect region(x_min, y_min, x_max - x_min, y_max - y_min); + cv::Mat crop_img = image(region); + + int width = x_max - x_min; + int height = y_max - y_min; + + CGFloat ratio = [OCRUtils calculateRatioWithWidth:width height:height]; + int new_width = (int)(modelHeight * ratio); + + if (new_width == 0) { + return crop_img; // Return nil if calculated new_width is zero to avoid further processing + } + + crop_img = [OCRUtils computeRatioAndResize:crop_img width:width height:height modelHeight:modelHeight]; + + return crop_img; +} + ++ (cv::Mat)adjustContrastGrey:(cv::Mat)img target:(double)target { + double contrast = 0.0; + int high = 0; + int low = 255; + + // Calculate existing contrast, high, and low + for (int i = 0; i < img.rows; ++i) { + for (int j = 0; j < img.cols; ++j) { + uchar pixel = img.at(i, j); + high = MAX(high, pixel); + low = MIN(low, pixel); + } + } + contrast = (high - low) / 255.0; + + // Adjust contrast if below the target + if (contrast < target) { + double ratio = 200.0 / MAX(10, high - low); + img.convertTo(img, CV_32F); // Convert to float for scaling operations + img = ((img - low + 25) * ratio); + + // Clipping values to ensure they remain within valid range + cv::threshold(img, img, 255, 255, cv::THRESH_TRUNC); // Cap values at 255 + cv::threshold(img, img, 0, 0, cv::THRESH_TOZERO); // Ensure no negative values + + img.convertTo(img, CV_8U); // Convert back to 8-bit pixel values + } + + return img; +} + ++ (cv::Mat)normalizeForRecognizer:(cv::Mat)image adjustContrast:(double)adjustContrast { + if (adjustContrast > 0) { + image = [OCRUtils adjustContrastGrey:image target:adjustContrast]; // Make sure this method exists and works as expected + } + + int desiredWidth = 128; + if (image.cols >= 512) { + desiredWidth = 512; + } else if (image.cols >= 256) { + desiredWidth = 256; + } + + image = 
[OCRUtils resizeWithPadding:image desiredWidth:desiredWidth desiredHeight:64]; + + // Normalization: (image / 255.0 - 0.5) * 2.0 + image.convertTo(image, CV_32F, 1.0 / 255.0); // Scale pixel values to [0,1] + image = (image - 0.5) * 2.0; // Shift to [-1,1] + + return image; +} + @end From d6d2e1d2c4086aa1e8b69202d18eb7caac421688 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 22 Jan 2025 16:35:20 +0100 Subject: [PATCH 04/19] fix: add missing function to ImageProcessor --- ios/RnExecutorch/utils/ImageProcessor.h | 1 + ios/RnExecutorch/utils/ImageProcessor.mm | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/ios/RnExecutorch/utils/ImageProcessor.h b/ios/RnExecutorch/utils/ImageProcessor.h index 4bb7034e87..e2c6d34651 100644 --- a/ios/RnExecutorch/utils/ImageProcessor.h +++ b/ios/RnExecutorch/utils/ImageProcessor.h @@ -5,6 +5,7 @@ + (NSArray *)matToNSArray:(const cv::Mat &)mat; + (cv::Mat)arrayToMat:(NSArray *)array width:(int)width height:(int)height; ++ (NSArray *)matToArrayForGrayscale:(const cv::Mat &)mat; + (NSString *)saveToTempFile:(const cv::Mat &)image; + (cv::Mat)readImage:(NSString *)source; diff --git a/ios/RnExecutorch/utils/ImageProcessor.mm b/ios/RnExecutorch/utils/ImageProcessor.mm index feab17f608..4b932663a5 100644 --- a/ios/RnExecutorch/utils/ImageProcessor.mm +++ b/ios/RnExecutorch/utils/ImageProcessor.mm @@ -22,6 +22,26 @@ + (NSArray *)matToNSArray:(const cv::Mat &)mat { return floatArray; } ++ (NSArray *)matToArrayForGrayscale:(const cv::Mat &)mat { + if (mat.empty() || mat.type() != CV_32F) { + NSLog(@"Invalid or empty matrix or matrix not of type CV_32F."); + return @[]; + } + + NSMutableArray *pixelArray = [[NSMutableArray alloc] initWithCapacity:mat.cols * mat.rows]; + + // Iterate through every pixel in the matrix + for (int row = 0; row < mat.rows; row++) { + for (int col = 0; col < mat.cols; col++) { + // Access and add the pixel value directly as a float, store as NSNumber + float pixelValue = 
mat.at(row, col); + [pixelArray addObject:@(pixelValue)]; + } + } + + return pixelArray; +} + + (cv::Mat)arrayToMat:(NSArray *)array width:(int)width height:(int)height { cv::Mat mat(height, width, CV_8UC3); From a2efb615caaaa55664eeb975d656a470b88728c7 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 23 Jan 2025 12:51:12 +0100 Subject: [PATCH 05/19] feat: finished recognition, added confidence score and bounding boxes in returned object --- .../computer-vision/screens/OCRScreen.tsx | 71 +++++++++----- ios/RnExecutorch/OCR.mm | 3 +- .../models/ocr/RecognitionHandler.h | 4 +- .../models/ocr/RecognitionHandler.mm | 92 +++++++++++++------ .../models/ocr/utils/CTCLabelConverter.h | 10 +- 5 files changed, 117 insertions(+), 63 deletions(-) diff --git a/examples/computer-vision/screens/OCRScreen.tsx b/examples/computer-vision/screens/OCRScreen.tsx index 1493c30360..2006ffaa75 100644 --- a/examples/computer-vision/screens/OCRScreen.tsx +++ b/examples/computer-vision/screens/OCRScreen.tsx @@ -2,7 +2,9 @@ import Spinner from 'react-native-loading-spinner-overlay'; import { BottomBar } from '../components/BottomBar'; import { getImage } from '../utils'; import { useOCR } from 'react-native-executorch'; -import { View, StyleSheet, Image } from 'react-native'; +import { View, StyleSheet, Image, Text } from 'react-native'; +import { useState } from 'react'; +import ImageWithBboxes from '../components/ImageWithBboxes'; export const OCRScreen = ({ imageUri, @@ -11,6 +13,12 @@ export const OCRScreen = ({ imageUri: string; setImageUri: (imageUri: string) => void; }) => { + const [results, setResults] = useState([]); + const [imageDimensions, setImageDimensions] = useState<{ + width: number; + height: number; + }>(); + const [detectedText, setDetectedText] = useState(''); const model = useOCR({ detectorSource: require('../assets/models/xnnpack_craft.pte'), recognizerSources: [require('../assets/models/xnnpack_crnn_128.pte')], @@ -18,24 +26,27 @@ export const OCRScreen = 
({ const handleCameraPress = async (isCamera: boolean) => { const image = await getImage(isCamera); + const width = image?.width; + const height = image?.height; + setImageDimensions({ width: width as number, height: height as number }); const uri = image?.uri; if (typeof uri === 'string') { setImageUri(uri as string); + setResults([]); } }; - const shape = [1, 1, 64, 128]; - const input = new Float32Array(shape[1] * shape[2] * shape[3]); - - for (let i = 0; i < shape[1] * shape[2] * shape[3]; i++) { - input[i] = Math.random() * 255; - } - const runForward = async () => { try { const output = await model.forward(imageUri); - console.log(output[0]); - console.log(output[1]); + setResults(output); + console.log(output); + let txt = ''; + output.forEach((detection: any) => { + txt += detection.text + ' '; + }); + console.log(txt); + setDetectedText(txt); } catch (e) { console.error(e); } @@ -46,19 +57,29 @@ export const OCRScreen = ({ ); } - + console.log(imageDimensions?.width, imageDimensions?.height); return ( <> - + + {imageUri && imageDimensions?.width && imageDimensions?.height ? ( + + ) : ( + + )} + + {detectedText} *)indicesOfMaxValuesInMatrix:(cv::Mat)matrix { - // Ensure the matrix is 2D and has more than one column to avoid trivial results. 
- NSAssert(matrix.dims == 2 && matrix.cols > 1, @"Matrix must be 2D with more than one column."); - NSMutableArray *maxIndices = [NSMutableArray array]; - // Iterating over each row to find the index of the max element for (int i = 0; i < matrix.rows; i++) { - double maxVal; // Variable to store the maximum value (not used) - cv::Point maxLoc; // This will store the location of the maximum value + double maxVal; + cv::Point maxLoc; cv::minMaxLoc(matrix.row(i), NULL, &maxVal, NULL, &maxLoc); - [maxIndices addObject:@(maxLoc.x)]; // Add the index of the max value to the array + [maxIndices addObject:@(maxLoc.x)]; } - return [maxIndices copy]; // Return an NSArray copy of the mutable array + return [maxIndices copy]; } - (cv::Mat)divideMatrix:(cv::Mat)matrix byVector:(NSArray *)vector { - // Ensure the vector's length matches the number of rows in the matrix - NSAssert(matrix.rows == vector.count, @"Vector length must match number of matrix rows."); - - cv::Mat result = matrix.clone(); // Clone the matrix to keep the original unchanged + cv::Mat result = matrix.clone(); - // Iterate through each element in the matrix and divide by the corresponding vector element for (int i = 0; i < matrix.rows; i++) { - float divisor = [vector[i] floatValue]; // Get the CGFloat value from NSArray + float divisor = [vector[i] floatValue]; for (int j = 0; j < matrix.cols; j++) { result.at(i, j) /= divisor; } @@ -44,17 +36,29 @@ @implementation RecognitionHandler - (cv::Mat)softmax:(cv::Mat) inputs { cv::Mat maxVal; - cv::reduce(inputs, maxVal, 1, cv::REDUCE_MAX, CV_32F); // Find max per row for numerical stability + cv::reduce(inputs, maxVal, 1, cv::REDUCE_MAX, CV_32F); cv::Mat expInputs; - cv::exp(inputs - cv::repeat(maxVal, 1, inputs.cols), expInputs); // Compute exp(values - max) + cv::exp(inputs - cv::repeat(maxVal, 1, inputs.cols), expInputs); cv::Mat sumExp; - cv::reduce(expInputs, sumExp, 1, cv::REDUCE_SUM, CV_32F); // Sum of exp per row - cv::Mat softmaxOutput = expInputs / 
cv::repeat(sumExp, 1, inputs.cols); // Divide by sum(exp) + cv::reduce(expInputs, sumExp, 1, cv::REDUCE_SUM, CV_32F); + cv::Mat softmaxOutput = expInputs / cv::repeat(sumExp, 1, inputs.cols); return softmaxOutput; } - (NSArray *)recognize: (NSArray *)horizontalList imgGray:(cv::Mat)imgGray desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight { - NSLog(@"Before padding"); + const float newRatioH = (float)desiredHeight / imgGray.rows; + const float newRatioW = (float)desiredWidth / imgGray.cols; + float resizeRatio = MIN(newRatioH, newRatioW); + const int newWidth = imgGray.cols * resizeRatio; + const int newHeight = imgGray.rows * resizeRatio; + const int deltaW = desiredWidth - newWidth; + const int deltaH = desiredHeight - newHeight; + const int top = deltaH / 2; + const int left= deltaW / 2; + float heightRatio = (float)imgGray.rows / desiredHeight; + float widthRatio = (float)imgGray.cols / desiredWidth; + resizeRatio = MAX(heightRatio, widthRatio); + NSString *modelPath = [[NSBundle mainBundle] pathForResource:@"xnnpack_crnn_512" ofType:@"pte"]; ETModel *recognizer_512 = [[ETModel alloc] init]; [recognizer_512 loadModel:modelPath]; @@ -66,6 +70,7 @@ - (NSArray *)recognize: (NSArray *)horizontalList imgGray:(cv::Mat)imgGray desir [recognizer_128 loadModel:modelPath]; imgGray = [OCRUtils resizeWithPadding:imgGray desiredWidth:desiredWidth desiredHeight:desiredHeight]; + NSMutableArray *predictions = [NSMutableArray array]; for (NSArray *box in horizontalList) { int maximum_y = imgGray.rows; int maximum_x = imgGray.cols; @@ -78,7 +83,7 @@ - (NSArray *)recognize: (NSArray *)horizontalList imgGray:(cv::Mat)imgGray desir croppedImage = [OCRUtils normalizeForRecognizer:croppedImage adjustContrast:0.0]; - NSArray* modelInput = [ImageProcessor matToNSArrayForGrayscale:croppedImage]; + NSArray* modelInput = [ImageProcessor matToArrayForGrayscale:croppedImage]; NSArray *result; if(croppedImage.cols >= 512) { result = [recognizer_512 forward:modelInput 
shape:[recognizer_512 getInputShape:0] inputType:[recognizer_512 getInputType:0]]; @@ -89,23 +94,20 @@ - (NSArray *)recognize: (NSArray *)horizontalList imgGray:(cv::Mat)imgGray desir } NSInteger totalNumbers = [result.firstObject count]; - NSInteger numRows = (totalNumbers + 96) / 97; // Each row has 97 columns, round up if needed + NSInteger numRows = (totalNumbers + 96) / 97; - // Initialize the matrix with appropriate size - cv::Mat resultMat = cv::Mat::zeros(numRows, 97, CV_32F); // 97 columns, floating point values + cv::Mat resultMat = cv::Mat::zeros(numRows, 97, CV_32F); - // Counter for columns and row tracker NSInteger counter = 0; NSInteger currentRow = 0; for (NSNumber *num in result.firstObject) { - // Set the value in the matrix resultMat.at(currentRow, counter) = [num floatValue]; counter++; if (counter >= 97) { - counter = 0; // Reset counter if 97 columns are filled - currentRow++; // Move to the next row + counter = 0; + currentRow++; } } @@ -124,10 +126,42 @@ - (NSArray *)recognize: (NSArray *)horizontalList imgGray:(cv::Mat)imgGray desir CTCLabelConverter *converter = [[CTCLabelConverter alloc] initWithCharacters:@"0123456789!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ €ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" separatorList:@{} dictPathList:@{@"en": dictPath}]; NSArray* preds_index = [self indicesOfMaxValuesInMatrix:probabilities]; NSArray* decodedTexts = [converter decodeGreedyWithTextIndex:preds_index length:(int)(preds_index.count)]; - NSLog(@"%@", decodedTexts[0]); + NSMutableArray *valuesArray = [NSMutableArray array]; + NSMutableArray *indicesArray = [NSMutableArray array]; + for (int i = 0; i < probabilities.rows; i++) { + double maxVal = 0; + cv::Point maxLoc; + cv::minMaxLoc(probabilities.row(i), NULL, &maxVal, NULL, &maxLoc); + + [valuesArray addObject:@(maxVal)]; + [indicesArray addObject:@(maxLoc.x)]; + } + + NSMutableArray *predsMaxProb = [NSMutableArray array]; + + for (NSUInteger index = 0; index < indicesArray.count; index++) 
{ + NSNumber *indicator = indicesArray[index]; + if ([indicator intValue] != 0) { + [predsMaxProb addObject:valuesArray[index]]; + } + } + + + if (predsMaxProb.count == 0) { + [predsMaxProb addObject:@(0)]; + } + + double product = 1.0; + for (NSNumber *prob in predsMaxProb) { + product *= [prob doubleValue]; + } + + double confidenceScore = pow(product, 2.0 / sqrt(predsMaxProb.count)); + NSDictionary *res = @{@"text": decodedTexts[0], @"bbox": @{@"x1": @((int)((x_min - left) * resizeRatio)), @"x2": @((int)((x_max - left) * resizeRatio)), @"y1": @((int)((y_min - top) * resizeRatio)), @"y2":@((int)((y_max - top) * resizeRatio))}, @"score": @(confidenceScore)}; + [predictions addObject:res]; } - return [NSArray init]; + return predictions; } @end diff --git a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h index 7b879166a6..7cc167f2b1 100644 --- a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h +++ b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h @@ -2,11 +2,11 @@ @interface CTCLabelConverter : NSObject -@property (strong, nonatomic) NSMutableDictionary *dict; -@property (strong, nonatomic) NSArray *character; -@property (strong, nonatomic) NSDictionary *separatorList; -@property (strong, nonatomic) NSArray *ignoreIdx; -@property (strong, nonatomic) NSDictionary *dictList; +@property(strong, nonatomic) NSMutableDictionary *dict; +@property(strong, nonatomic) NSArray *character; +@property(strong, nonatomic) NSDictionary *separatorList; +@property(strong, nonatomic) NSArray *ignoreIdx; +@property(strong, nonatomic) NSDictionary *dictList; - (instancetype)initWithCharacters:(NSString *)characters separatorList:(NSDictionary *)separatorList dictPathList:(NSDictionary *)dictPathList; - (void)loadDictionariesWithDictPathList:(NSDictionary *)dictPathList; From 9d2c9a1ab635ab646cb48e4b701578f40701dab4 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 23 Jan 2025 15:47:13 +0100 Subject: 
[PATCH 06/19] refactor: first part of native ocr code refactor --- .../computer-vision/screens/OCRScreen.tsx | 10 +- ios/RnExecutorch/OCR.mm | 39 ++-- ios/RnExecutorch/models/ocr/Detector.h | 10 +- ios/RnExecutorch/models/ocr/Detector.mm | 23 ++- .../models/ocr/RecognitionHandler.h | 6 + .../models/ocr/RecognitionHandler.mm | 189 ++++++------------ ios/RnExecutorch/models/ocr/Recognizer.h | 11 + ios/RnExecutorch/models/ocr/Recognizer.mm | 105 ++++++++++ .../models/ocr/utils/DetectorUtils.mm | 32 ++- ios/RnExecutorch/models/ocr/utils/OCRUtils.h | 1 - ios/RnExecutorch/models/ocr/utils/OCRUtils.mm | 27 +-- .../models/ocr/utils/RecognizerUtils.h | 11 + .../models/ocr/utils/RecognizerUtils.mm | 84 ++++++++ src/OCR.ts | 30 ++- src/native/NativeOCR.ts | 7 +- src/types/ocr.ts | 7 + 16 files changed, 385 insertions(+), 207 deletions(-) create mode 100644 ios/RnExecutorch/models/ocr/Recognizer.h create mode 100644 ios/RnExecutorch/models/ocr/Recognizer.mm create mode 100644 ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h create mode 100644 ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm create mode 100644 src/types/ocr.ts diff --git a/examples/computer-vision/screens/OCRScreen.tsx b/examples/computer-vision/screens/OCRScreen.tsx index 2006ffaa75..e6e4dbcdc1 100644 --- a/examples/computer-vision/screens/OCRScreen.tsx +++ b/examples/computer-vision/screens/OCRScreen.tsx @@ -21,7 +21,12 @@ export const OCRScreen = ({ const [detectedText, setDetectedText] = useState(''); const model = useOCR({ detectorSource: require('../assets/models/xnnpack_craft.pte'), - recognizerSources: [require('../assets/models/xnnpack_crnn_128.pte')], + recognizerSources: { + recognizer512: + 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_512.pte', + recognizer256: require('../assets/models/xnnpack_crnn_256.pte'), + recognizer128: require('../assets/models/xnnpack_crnn_128.pte'), + }, }); const handleCameraPress = async (isCamera: boolean) => { @@ -45,7 +50,6 @@ export const 
OCRScreen = ({ output.forEach((detection: any) => { txt += detection.text + ' '; }); - console.log(txt); setDetectedText(txt); } catch (e) { console.error(e); @@ -57,7 +61,7 @@ export const OCRScreen = ({ ); } - console.log(imageDimensions?.width, imageDimensions?.height); + return ( <> diff --git a/ios/RnExecutorch/OCR.mm b/ios/RnExecutorch/OCR.mm index 5cb23fe084..66957e879f 100644 --- a/ios/RnExecutorch/OCR.mm +++ b/ios/RnExecutorch/OCR.mm @@ -8,33 +8,40 @@ @implementation OCR { Detector *detector; - RecognitionHandler *handler; + RecognitionHandler *recognitionHandler; } RCT_EXPORT_MODULE() - (void)loadModule:(NSString *)detectorSource - recognizerSources:(NSArray *)recognizerSources +recognizerSource512:(NSString *)recognizerSource512 +recognizerSource256:(NSString *)recognizerSource256 +recognizerSource128:(NSString *)recognizerSource128 language:(NSString *)language resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject { detector = [[Detector alloc] init]; - [detector loadModel:[NSURL URLWithString:detectorSource] - completion:^(BOOL success, NSNumber *errorCode) { - if (success) { - resolve(errorCode); + recognitionHandler = [[RecognitionHandler alloc] init]; + + [detector loadModel:[NSURL URLWithString:detectorSource] completion:^(BOOL success, NSNumber *errorCode) { + if (!success) { + NSError *error = [NSError errorWithDomain:@"OCRErrorDomain" + code:[errorCode intValue] + userInfo:@{NSLocalizedDescriptionKey: [NSString stringWithFormat:@"%ld", (long)[errorCode longValue]]}]; + reject(@"init_module_error", @"Failed to initialize detector module", error); return; } - NSError *error = [NSError - errorWithDomain:@"OCRErrorDomain" - code:[errorCode intValue] - userInfo:@{ - NSLocalizedDescriptionKey : [NSString - stringWithFormat:@"%ld", (long)[errorCode longValue]] + [self->recognitionHandler loadRecognizers:recognizerSource512 mediumRecognizerPath:recognizerSource256 smallRecognizerPath:recognizerSource128 completion:^(BOOL 
allModelsLoaded, NSNumber *errorCode) { + if (allModelsLoaded) { + resolve(@(YES)); + } else { + NSError *error = [NSError errorWithDomain:@"OCRErrorDomain" + code:[errorCode intValue] + userInfo:@{NSLocalizedDescriptionKey: [NSString stringWithFormat:@"%ld", (long)[errorCode longValue]]}]; + reject(@"init_recognizer_error", @"Failed to initialize one or more recognizer models", error); + } }]; - reject(@"init_module_error", error.localizedDescription, error); - return; }]; } @@ -44,9 +51,9 @@ - (void)forward:(NSString *)input @try { cv::Mat image = [ImageProcessor readImage:input]; NSArray* result = [detector runModel:image]; + cv::Size detectorSize = [detector getModelImageSize]; cv::cvtColor(image, image, cv::COLOR_BGR2GRAY); - handler = [[RecognitionHandler alloc] init]; - result = [handler recognize:result imgGray:image desiredWidth:1280 desiredHeight:1280]; + result = [self->recognitionHandler recognize:result imgGray:image desiredWidth:detectorSize.width desiredHeight:detectorSize.height]; resolve(result); } @catch (NSException *exception) { reject(@"forward_error", [NSString stringWithFormat:@"%@", exception.reason], diff --git a/ios/RnExecutorch/models/ocr/Detector.h b/ios/RnExecutorch/models/ocr/Detector.h index 1c8619bd38..9c026e3289 100644 --- a/ios/RnExecutorch/models/ocr/Detector.h +++ b/ios/RnExecutorch/models/ocr/Detector.h @@ -1,8 +1,16 @@ #import "BaseModel.h" #import "opencv2/opencv.hpp" -@interface Detector : BaseModel +const float textThreshold = 0.7; +const float linkThreshold = 0.4; +const float lowText = 0.4; +const float yCenterThs = 0.5; +const float heightThs = 0.5; +const float widthThs = 0.5; +const float addMargin = 0.1; +const int minSize = 20; +@interface Detector : BaseModel - (cv::Size)getModelImageSize; - (NSArray *)preprocess:(cv::Mat &)input; - (NSArray *)postprocess:(NSArray *)output; diff --git a/ios/RnExecutorch/models/ocr/Detector.mm b/ios/RnExecutorch/models/ocr/Detector.mm index 6038cc101f..4633cd1d17 100644 --- 
a/ios/RnExecutorch/models/ocr/Detector.mm +++ b/ios/RnExecutorch/models/ocr/Detector.mm @@ -6,15 +6,22 @@ @implementation Detector { cv::Size originalSize; + cv::Size modelSize; } - (cv::Size)getModelImageSize{ + if(!modelSize.empty()) { + return modelSize; + } + NSArray * inputShape = [module getInputShape: @0]; NSNumber *widthNumber = inputShape.lastObject; NSNumber *heightNumber = inputShape[inputShape.count - 2]; int height = [heightNumber intValue]; int width = [widthNumber intValue]; + modelSize = cv::Size(height, width); + return cv::Size(height, width); } @@ -38,11 +45,11 @@ - (NSArray *)postprocess:(NSArray *)output { cv::Mat scoreTextCV; cv::Mat scoreLinkCV; + cv::Size modelImageSize = [self getModelImageSize]; + scoreTextCV = [DetectorUtils arrayToMat:scoreText width:modelImageSize.width / 2 height:modelImageSize.height / 2]; + scoreLinkCV = [DetectorUtils arrayToMat:scoreLink width:modelImageSize.width / 2 height:modelImageSize.height / 2]; - scoreTextCV = [DetectorUtils arrayToMat:scoreText width:640 height:640]; - scoreLinkCV = [DetectorUtils arrayToMat:scoreLink width:640 height:640]; - - NSArray* boxes = [DetectorUtils getDetBoxes:scoreTextCV linkMap:scoreLinkCV textThreshold:0.7 linkThreshold:0.4 lowText:0.4]; + NSArray* boxes = [DetectorUtils getDetBoxes:scoreTextCV linkMap:scoreLinkCV textThreshold:textThreshold linkThreshold:linkThreshold lowText:lowText]; NSMutableArray *single_img_result = [NSMutableArray array]; for (NSUInteger i = 0; i < [boxes count]; i++) { NSArray *box = boxes[i]; @@ -57,19 +64,17 @@ - (NSArray *)postprocess:(NSArray *)output { [single_img_result addObject:boxArray]; } - NSArray* horizontalList = [DetectorUtils groupTextBox:single_img_result ycenterThs:0.5 heightThs:0.5 widthThs:0.5 addMargin:0.1]; + NSArray* horizontalList = [DetectorUtils groupTextBox:single_img_result ycenterThs:yCenterThs heightThs:heightThs widthThs:widthThs addMargin:addMargin]; NSMutableArray *boxesToKeep = [NSMutableArray array]; for (NSArray *box 
in horizontalList) { - if (MAX([box[1] intValue] - [box[0] intValue], [box[3] intValue] - [box[2] intValue]) >= 20) { - + if (MAX([box[1] intValue] - [box[0] intValue], [box[3] intValue] - [box[2] intValue]) >= minSize) { [boxesToKeep addObject:box]; } } - horizontalList = [NSMutableArray arrayWithArray:boxesToKeep]; - return horizontalList; + return boxesToKeep; } - (NSArray *)runModel:(cv::Mat &)input { diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.h b/ios/RnExecutorch/models/ocr/RecognitionHandler.h index 84e55b9d43..b3d367c9f6 100644 --- a/ios/RnExecutorch/models/ocr/RecognitionHandler.h +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.h @@ -1,7 +1,13 @@ #import "opencv2/opencv.hpp" +const int modelHeight = 64; +const int largeModelWidth = 512; +const int mediumModelWidth = 256; +const int smallModelWidth = 128; + @interface RecognitionHandler : NSObject +- (void)loadRecognizers:(NSString *)largeRecognizerPath mediumRecognizerPath:(NSString *)mediumRecognizerPath smallRecognizerPath:(NSString *)smallRecognizerPath completion:(void (^)(BOOL, NSNumber *))completion; - (NSArray *)recognize:(NSArray *)horizontalList imgGray:(cv::Mat)imgGray desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight; @end diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm index 7a1903fcd3..4a5abc9187 100644 --- a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm @@ -1,75 +1,71 @@ #import "RecognitionHandler.h" #import -#import "./utils/OCRUtils.h" +#import "ExecutorchLib/ETModel.h" +#import "ExecutorchLib/ETModel.h" #import "../../utils/ImageProcessor.h" +#import "../../utils/Fetcher.h" +#import "./utils/RecognizerUtils.h" +#import "./utils/OCRUtils.h" #import "./utils/CTCLabelConverter.h" -#import "ExecutorchLib/ETModel.h" +#import "Recognizer.h" -@implementation RecognitionHandler +@implementation RecognitionHandler { + Recognizer 
*recognizerLarge; + Recognizer *recognizerMedium; + Recognizer *recognizerSmall; + CTCLabelConverter *converter; +} -- (NSArray *)indicesOfMaxValuesInMatrix:(cv::Mat)matrix { - NSMutableArray *maxIndices = [NSMutableArray array]; - - for (int i = 0; i < matrix.rows; i++) { - double maxVal; - cv::Point maxLoc; - cv::minMaxLoc(matrix.row(i), NULL, &maxVal, NULL, &maxLoc); - [maxIndices addObject:@(maxLoc.x)]; +- (instancetype)init { + self = [super init]; + if (self) { + recognizerLarge = [[Recognizer alloc] init]; + recognizerMedium = [[Recognizer alloc] init]; + recognizerSmall = [[Recognizer alloc] init]; + NSString *dictPath = [[NSBundle mainBundle] pathForResource:@"en" ofType:@"txt"]; + converter = [[CTCLabelConverter alloc] initWithCharacters:@"0123456789!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ €ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" separatorList:@{} dictPathList:@{@"en": dictPath}]; } - - return [maxIndices copy]; + return self; } - -- (cv::Mat)divideMatrix:(cv::Mat)matrix byVector:(NSArray *)vector { - cv::Mat result = matrix.clone(); +- (void)loadRecognizers:(NSString *)largeRecognizerPath mediumRecognizerPath:(NSString *)mediumRecognizerPath smallRecognizerPath:(NSString *)smallRecognizerPath completion:(void (^)(BOOL, NSNumber *))completion { + dispatch_group_t group = dispatch_group_create(); + __block BOOL allSuccessful = YES; - for (int i = 0; i < matrix.rows; i++) { - float divisor = [vector[i] floatValue]; - for (int j = 0; j < matrix.cols; j++) { - result.at(i, j) /= divisor; - } + NSArray *recognizers = @[recognizerLarge, recognizerMedium, recognizerSmall]; + NSArray *paths = @[largeRecognizerPath, mediumRecognizerPath, smallRecognizerPath]; + + for (NSInteger i = 0; i < recognizers.count; i++) { + Recognizer *recognizer = recognizers[i]; + NSString *path = paths[i]; + + dispatch_group_enter(group); + [recognizer loadModel:[NSURL URLWithString: path] completion:^(BOOL success, NSNumber *errorCode) { + if (!success) { + allSuccessful = NO; 
+ dispatch_group_leave(group); + completion(NO, errorCode); + return; + } + dispatch_group_leave(group); + }]; } - return result; -} - -- (cv::Mat)softmax:(cv::Mat) inputs { - cv::Mat maxVal; - cv::reduce(inputs, maxVal, 1, cv::REDUCE_MAX, CV_32F); - cv::Mat expInputs; - cv::exp(inputs - cv::repeat(maxVal, 1, inputs.cols), expInputs); - cv::Mat sumExp; - cv::reduce(expInputs, sumExp, 1, cv::REDUCE_SUM, CV_32F); - cv::Mat softmaxOutput = expInputs / cv::repeat(sumExp, 1, inputs.cols); - return softmaxOutput; + dispatch_group_notify(group, dispatch_get_main_queue(), ^{ + if (allSuccessful) { + completion(YES, @(0)); + } + }); } - (NSArray *)recognize: (NSArray *)horizontalList imgGray:(cv::Mat)imgGray desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight { - const float newRatioH = (float)desiredHeight / imgGray.rows; - const float newRatioW = (float)desiredWidth / imgGray.cols; - float resizeRatio = MIN(newRatioH, newRatioW); - const int newWidth = imgGray.cols * resizeRatio; - const int newHeight = imgGray.rows * resizeRatio; - const int deltaW = desiredWidth - newWidth; - const int deltaH = desiredHeight - newHeight; - const int top = deltaH / 2; - const int left= deltaW / 2; - float heightRatio = (float)imgGray.rows / desiredHeight; - float widthRatio = (float)imgGray.cols / desiredWidth; - resizeRatio = MAX(heightRatio, widthRatio); - - NSString *modelPath = [[NSBundle mainBundle] pathForResource:@"xnnpack_crnn_512" ofType:@"pte"]; - ETModel *recognizer_512 = [[ETModel alloc] init]; - [recognizer_512 loadModel:modelPath]; - ETModel *recognizer_256 = [[ETModel alloc] init]; - modelPath = [[NSBundle mainBundle] pathForResource:@"xnnpack_crnn_256" ofType:@"pte"]; - [recognizer_256 loadModel:modelPath]; - ETModel *recognizer_128 = [[ETModel alloc] init]; - modelPath = [[NSBundle mainBundle] pathForResource:@"xnnpack_crnn_128" ofType:@"pte"]; - [recognizer_128 loadModel:modelPath]; + NSDictionary* ratioAndPadding = [RecognizerUtils 
calculateResizeRatioAndPaddings:imgGray.cols height:imgGray.rows desiredWidth:desiredWidth desiredHeight:desiredHeight]; + int left = [ratioAndPadding[@"left"] intValue]; + int top = [ratioAndPadding[@"top"] intValue]; + float resizeRatio = [ratioAndPadding[@"resizeRatio"] floatValue]; imgGray = [OCRUtils resizeWithPadding:imgGray desiredWidth:desiredWidth desiredHeight:desiredHeight]; + NSMutableArray *predictions = [NSMutableArray array]; for (NSArray *box in horizontalList) { int maximum_y = imgGray.rows; @@ -79,85 +75,26 @@ - (NSArray *)recognize: (NSArray *)horizontalList imgGray:(cv::Mat)imgGray desir int x_max = MIN([box[1] intValue], maximum_x); int y_min = MAX(0, [box[2] intValue]); int y_max = MIN([box[3] intValue], maximum_y); - cv::Mat croppedImage = [OCRUtils getCroppedImage:x_max x_min:x_min y_max:y_max y_min:y_min image:imgGray modelHeight:64]; + + cv::Mat croppedImage = [RecognizerUtils getCroppedImage:x_max x_min:x_min y_max:y_max y_min:y_min image:imgGray modelHeight:modelHeight]; croppedImage = [OCRUtils normalizeForRecognizer:croppedImage adjustContrast:0.0]; - NSArray* modelInput = [ImageProcessor matToArrayForGrayscale:croppedImage]; - NSArray *result; - if(croppedImage.cols >= 512) { - result = [recognizer_512 forward:modelInput shape:[recognizer_512 getInputShape:0] inputType:[recognizer_512 getInputType:0]]; - } else if (croppedImage.cols >= 256) { - result = [recognizer_256 forward:modelInput shape:[recognizer_256 getInputShape:0] inputType:[recognizer_256 getInputType:0]]; + NSArray *result; + if(croppedImage.cols >= largeModelWidth) { + result = [recognizerLarge runModel:croppedImage]; + } else if (croppedImage.cols >= mediumModelWidth) { + result = [recognizerMedium runModel: croppedImage]; } else { - result = [recognizer_128 forward:modelInput shape:[recognizer_128 getInputShape:0] inputType:[recognizer_128 getInputType:0]]; + result = [recognizerSmall runModel: croppedImage]; } - NSInteger totalNumbers = [result.firstObject count]; - 
NSInteger numRows = (totalNumbers + 96) / 97; + NSNumber *confidenceScore = [result objectAtIndex:1]; + NSArray *pred_index = [result objectAtIndex:0]; - cv::Mat resultMat = cv::Mat::zeros(numRows, 97, CV_32F); - - NSInteger counter = 0; - NSInteger currentRow = 0; - - for (NSNumber *num in result.firstObject) { - resultMat.at(currentRow, counter) = [num floatValue]; - - counter++; - if (counter >= 97) { - counter = 0; - currentRow++; - } - } - - cv::Mat probabilities = [self softmax:resultMat]; - NSMutableArray* pred_norm = [NSMutableArray arrayWithCapacity:probabilities.rows]; - for(int i = 0; i < probabilities.rows; i++) { - float sum = 0.0; - for(int j = 0; j < 97; j++) { - sum += probabilities.at(i, j); - } - [pred_norm addObject:@(sum)]; - } - - probabilities = [self divideMatrix:probabilities byVector:pred_norm]; - NSString *dictPath = [[NSBundle mainBundle] pathForResource:@"en" ofType:@"txt"]; - CTCLabelConverter *converter = [[CTCLabelConverter alloc] initWithCharacters:@"0123456789!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ €ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" separatorList:@{} dictPathList:@{@"en": dictPath}]; - NSArray* preds_index = [self indicesOfMaxValuesInMatrix:probabilities]; - NSArray* decodedTexts = [converter decodeGreedyWithTextIndex:preds_index length:(int)(preds_index.count)]; - NSMutableArray *valuesArray = [NSMutableArray array]; - NSMutableArray *indicesArray = [NSMutableArray array]; - for (int i = 0; i < probabilities.rows; i++) { - double maxVal = 0; - cv::Point maxLoc; - cv::minMaxLoc(probabilities.row(i), NULL, &maxVal, NULL, &maxLoc); - - [valuesArray addObject:@(maxVal)]; - [indicesArray addObject:@(maxLoc.x)]; - } - - NSMutableArray *predsMaxProb = [NSMutableArray array]; - - for (NSUInteger index = 0; index < indicesArray.count; index++) { - NSNumber *indicator = indicesArray[index]; - if ([indicator intValue] != 0) { - [predsMaxProb addObject:valuesArray[index]]; - } - } - - - if (predsMaxProb.count == 0) { - 
[predsMaxProb addObject:@(0)]; - } - - double product = 1.0; - for (NSNumber *prob in predsMaxProb) { - product *= [prob doubleValue]; - } + NSArray* decodedTexts = [converter decodeGreedyWithTextIndex:pred_index length:(int)(pred_index.count)]; - double confidenceScore = pow(product, 2.0 / sqrt(predsMaxProb.count)); - NSDictionary *res = @{@"text": decodedTexts[0], @"bbox": @{@"x1": @((int)((x_min - left) * resizeRatio)), @"x2": @((int)((x_max - left) * resizeRatio)), @"y1": @((int)((y_min - top) * resizeRatio)), @"y2":@((int)((y_max - top) * resizeRatio))}, @"score": @(confidenceScore)}; + NSDictionary *res = @{@"text": decodedTexts[0], @"bbox": @{@"x1": @((int)((x_min - left) * resizeRatio)), @"x2": @((int)((x_max - left) * resizeRatio)), @"y1": @((int)((y_min - top) * resizeRatio)), @"y2":@((int)((y_max - top) * resizeRatio))}, @"score": confidenceScore}; [predictions addObject:res]; } diff --git a/ios/RnExecutorch/models/ocr/Recognizer.h b/ios/RnExecutorch/models/ocr/Recognizer.h new file mode 100644 index 0000000000..f47769996a --- /dev/null +++ b/ios/RnExecutorch/models/ocr/Recognizer.h @@ -0,0 +1,11 @@ +#import "BaseModel.h" +#import "opencv2/opencv.hpp" + +@interface Recognizer : BaseModel + +- (cv::Size)getModelImageSize; +- (NSArray *)preprocess:(cv::Mat &)input; +- (NSArray *)postprocess:(NSArray *)output; +- (NSArray *)runModel:(cv::Mat &)input; + +@end diff --git a/ios/RnExecutorch/models/ocr/Recognizer.mm b/ios/RnExecutorch/models/ocr/Recognizer.mm new file mode 100644 index 0000000000..f4e0fd7ca1 --- /dev/null +++ b/ios/RnExecutorch/models/ocr/Recognizer.mm @@ -0,0 +1,105 @@ +#import "opencv2/opencv.hpp" +#import "Recognizer.h" +#import "../../utils/ImageProcessor.h" +#import "utils/OCRUtils.h" +#import "RecognizerUtils.h" + +@implementation Recognizer { + cv::Size originalSize; +} + +- (cv::Size)getModelImageSize{ + NSArray * inputShape = [module getInputShape: @0]; + NSNumber *widthNumber = inputShape.lastObject; + NSNumber *heightNumber = 
inputShape[inputShape.count - 2]; + + int height = [heightNumber intValue]; + int width = [widthNumber intValue]; + return cv::Size(height, width); +} + +- (NSArray *)preprocess:(cv::Mat &)input { + return [NSArray init]; +} + +- (NSArray *)postprocess:(NSArray *)output { + NSInteger totalNumbers = [output.firstObject count]; + NSInteger numRows = (totalNumbers + 96) / 97; + + cv::Mat resultMat = cv::Mat::zeros(numRows, 97, CV_32F); + + NSInteger counter = 0; + NSInteger currentRow = 0; + + for (NSNumber *num in output.firstObject) { + resultMat.at(currentRow, counter) = [num floatValue]; + + counter++; + if (counter >= 97) { + counter = 0; + currentRow++; + } + } + + cv::Mat probabilities = [RecognizerUtils softmax:resultMat]; + NSMutableArray* pred_norm = [NSMutableArray arrayWithCapacity:probabilities.rows]; + for(int i = 0; i < probabilities.rows; i++) { + float sum = 0.0; + for(int j = 0; j < 97; j++) { + sum += probabilities.at(i, j); + } + [pred_norm addObject:@(sum)]; + } + + probabilities = [RecognizerUtils divideMatrix:probabilities byVector:pred_norm]; + NSArray* preds_index = [RecognizerUtils indicesOfMaxValuesInMatrix:probabilities]; + + NSMutableArray *valuesArray = [NSMutableArray array]; + NSMutableArray *indicesArray = [NSMutableArray array]; + for (int i = 0; i < probabilities.rows; i++) { + double maxVal = 0; + cv::Point maxLoc; + cv::minMaxLoc(probabilities.row(i), NULL, &maxVal, NULL, &maxLoc); + + [valuesArray addObject:@(maxVal)]; + [indicesArray addObject:@(maxLoc.x)]; + } + + NSMutableArray *predsMaxProb = [NSMutableArray array]; + + for (NSUInteger index = 0; index < indicesArray.count; index++) { + NSNumber *indicator = indicesArray[index]; + if ([indicator intValue] != 0) { + [predsMaxProb addObject:valuesArray[index]]; + } + } + + + if (predsMaxProb.count == 0) { + [predsMaxProb addObject:@(0)]; + } + + double product = 1.0; + for (NSNumber *prob in predsMaxProb) { + product *= [prob doubleValue]; + } + + double confidenceScore = 
pow(product, 2.0 / sqrt(predsMaxProb.count)); + + NSMutableArray* result = [[NSMutableArray alloc] init]; + + [result addObject:preds_index]; + [result addObject: @(confidenceScore)]; + + return result; +} + +- (NSArray *)runModel:(cv::Mat &)input { + NSArray* modelInput = [ImageProcessor matToArrayForGrayscale:input]; + NSArray *modelResult = [self forward:modelInput]; + NSArray *result = [self postprocess:modelResult]; + + return result; +} + +@end diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm index 21aa76a9cb..82dec21566 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm @@ -28,11 +28,10 @@ + (NSDictionary *)splitInterleavedNSArray:(NSArray *)array { NSMutableArray *scoreText = [[NSMutableArray alloc] init]; NSMutableArray *scoreLink = [[NSMutableArray alloc] init]; - // Iterate through the array and distribute elements to scoreText or scoreLink [array enumerateObjectsUsingBlock:^(id element, NSUInteger idx, BOOL *stop) { - if (idx % 2 == 0) { // Even index, belongs to scoreText + if (idx % 2 == 0) { [scoreText addObject:element]; - } else { // Odd index, belongs to scoreLink + } else { [scoreLink addObject:element]; } }]; @@ -98,14 +97,12 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold if (sy < 0) sy = 0; if (ex >= img_w) ex = img_w; if (ey >= img_h) ey = img_h; - cv::Rect roi(sx, sy, ex - sx, ey - sy); // (x, y, width, height) of the ROI + cv::Rect roi(sx, sy, ex - sx, ey - sy); - // Generate the kernel for dilation cv::Mat kernel = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(1 + niter, 1 + niter)); - cv::Mat roiSegMap = segMap(roi); // Reference a sub-region of segMap for dilation + cv::Mat roiSegMap = segMap(roi); cv::dilate(roiSegMap, roiSegMap, kernel); - // Find contours and fit rotated rectangle std::vector> contours; cv::findContours(segMap, contours, cv::RETR_EXTERNAL, 
cv::CHAIN_APPROX_SIMPLE); if (!contours.empty()) { @@ -158,7 +155,7 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold [horizontalList sortUsingComparator:^NSComparisonResult(NSMutableArray *obj1, NSMutableArray *obj2) { return [obj1[4] compare:obj2[4]]; // Sorting by y_center }]; - + NSMutableArray *newBox = [NSMutableArray array]; NSMutableArray *bHeight = [NSMutableArray array]; NSMutableArray *bYcenter = [NSMutableArray array]; @@ -184,17 +181,16 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold [combinedList addObject:[newBox copy]]; for (NSArray *boxes in combinedList) { - if ([boxes count] == 1) { // If there is only one box in the line - NSLog(@"One in line"); + if ([boxes count] == 1) { NSArray *box = boxes[0]; int margin = (int)(addMargin * MIN([box[1] floatValue] - [box[0] floatValue], [box[5] floatValue])); [mergedList addObject:@[@([box[0] intValue] - margin), @([box[1] intValue] + margin), @([box[2] intValue] - margin), @([box[3] intValue] + margin)]]; - } else { // There are multiple boxes to be merged + } else { NSArray *sortedBoxes = [boxes sortedArrayUsingComparator:^NSComparisonResult(NSArray *obj1, NSArray *obj2) { - return [@([obj1[0] intValue]) compare:@([obj2[0] intValue])]; // Sort boxes by x_min + return [@([obj1[0] intValue]) compare:@([obj2[0] intValue])]; }]; NSMutableArray *mergedBox = [NSMutableArray array]; @@ -228,14 +224,12 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold [mergedBox addObject:newBox]; } - // Create merged boxes from merged box array for (NSArray *mbox in mergedBox) { if ([mbox count] != 1) { - NSNumber *xMin = [mbox[0] objectAtIndex:0]; // minX - NSNumber *xMax = [mbox[0] objectAtIndex:1]; // maxX - NSNumber *yMin = [mbox[0] objectAtIndex:2]; // minY - NSNumber *yMax = [mbox[0] objectAtIndex:3]; // maxY - // Iterate over each box in the mbox array to find min and max + NSNumber *xMin = [mbox[0] objectAtIndex:0]; + 
NSNumber *xMax = [mbox[0] objectAtIndex:1]; + NSNumber *yMin = [mbox[0] objectAtIndex:2]; + NSNumber *yMax = [mbox[0] objectAtIndex:3]; for (NSArray *box in mbox) { if ([box[0] intValue] < [xMin intValue]) { xMin = box[0]; @@ -267,7 +261,7 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold } } } - NSLog(@"Merged List Count: %lu", (unsigned long)[mergedList count]); + return mergedList; } diff --git a/ios/RnExecutorch/models/ocr/utils/OCRUtils.h b/ios/RnExecutorch/models/ocr/utils/OCRUtils.h index ab9c6859ef..2fefc167f8 100644 --- a/ios/RnExecutorch/models/ocr/utils/OCRUtils.h +++ b/ios/RnExecutorch/models/ocr/utils/OCRUtils.h @@ -3,7 +3,6 @@ @interface OCRUtils : NSObject + (cv::Mat)resizeWithPadding:(cv::Mat)img desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight; -+ (cv::Mat)getCroppedImage:(int)x_max x_min:(int)x_min y_max:(int)y_max y_min:(int)y_min image:(cv::Mat)image modelHeight:(int)modelHeight; + (CGFloat)calculateRatioWithWidth:(int)width height:(int)height; + (cv::Mat)computeRatioAndResize:(cv::Mat)img width:(int)width height:(int)height modelHeight:(int)modelHeight; + (cv::Mat)normalizeForRecognizer:(cv::Mat)image adjustContrast:(double)adjustContrast; diff --git a/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm index 24110f8e31..1d0536a8b4 100644 --- a/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm @@ -66,25 +66,6 @@ + (CGFloat)calculateRatioWithWidth:(int)width height:(int)height { return img; } -+ (cv::Mat)getCroppedImage:(int)x_max x_min:(int)x_min y_max:(int)y_max y_min:(int)y_min image:(cv::Mat)image modelHeight:(int)modelHeight { - cv::Rect region(x_min, y_min, x_max - x_min, y_max - y_min); - cv::Mat crop_img = image(region); - - int width = x_max - x_min; - int height = y_max - y_min; - - CGFloat ratio = [OCRUtils calculateRatioWithWidth:width height:height]; - int new_width = (int)(modelHeight * ratio); 
- - if (new_width == 0) { - return crop_img; // Return nil if calculated new_width is zero to avoid further processing - } - - crop_img = [OCRUtils computeRatioAndResize:crop_img width:width height:height modelHeight:modelHeight]; - - return crop_img; -} - + (cv::Mat)adjustContrastGrey:(cv::Mat)img target:(double)target { double contrast = 0.0; int high = 0; @@ -118,8 +99,7 @@ + (CGFloat)calculateRatioWithWidth:(int)width height:(int)height { + (cv::Mat)normalizeForRecognizer:(cv::Mat)image adjustContrast:(double)adjustContrast { if (adjustContrast > 0) { - image = [OCRUtils adjustContrastGrey:image target:adjustContrast]; // Make sure this method exists and works as expected - } + image = [OCRUtils adjustContrastGrey:image target:adjustContrast]; } int desiredWidth = 128; if (image.cols >= 512) { @@ -130,9 +110,8 @@ + (CGFloat)calculateRatioWithWidth:(int)width height:(int)height { image = [OCRUtils resizeWithPadding:image desiredWidth:desiredWidth desiredHeight:64]; - // Normalization: (image / 255.0 - 0.5) * 2.0 - image.convertTo(image, CV_32F, 1.0 / 255.0); // Scale pixel values to [0,1] - image = (image - 0.5) * 2.0; // Shift to [-1,1] + image.convertTo(image, CV_32F, 1.0 / 255.0); + image = (image - 0.5) * 2.0; return image; } diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h new file mode 100644 index 0000000000..b9c1278282 --- /dev/null +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h @@ -0,0 +1,11 @@ +#import + +@interface RecognizerUtils : NSObject + ++ (NSArray *)indicesOfMaxValuesInMatrix:(cv::Mat)matrix; ++ (cv::Mat)divideMatrix:(cv::Mat)matrix byVector:(NSArray *)vector; ++ (cv::Mat)softmax:(cv::Mat)inputs; ++ (NSDictionary *)calculateResizeRatioAndPaddings:(int)width height:(int)height desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight; ++ (cv::Mat)getCroppedImage:(int)x_max x_min:(int)x_min y_max:(int)y_max y_min:(int)y_min image:(cv::Mat)image 
modelHeight:(int)modelHeight; + +@end diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm new file mode 100644 index 0000000000..75a3585cbb --- /dev/null +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm @@ -0,0 +1,84 @@ +#import "RecognizerUtils.h" +#import "OCRUtils.h" + +@implementation RecognizerUtils + ++ (NSArray *)indicesOfMaxValuesInMatrix:(cv::Mat)matrix { + NSMutableArray *maxIndices = [NSMutableArray array]; + + for (int i = 0; i < matrix.rows; i++) { + double maxVal; + cv::Point maxLoc; + cv::minMaxLoc(matrix.row(i), NULL, &maxVal, NULL, &maxLoc); + [maxIndices addObject:@(maxLoc.x)]; + } + + return [maxIndices copy]; +} + ++ (cv::Mat)divideMatrix:(cv::Mat)matrix byVector:(NSArray *)vector { + cv::Mat result = matrix.clone(); + + for (int i = 0; i < matrix.rows; i++) { + float divisor = [vector[i] floatValue]; + for (int j = 0; j < matrix.cols; j++) { + result.at(i, j) /= divisor; + } + } + + return result; +} + ++ (cv::Mat)softmax:(cv::Mat) inputs { + cv::Mat maxVal; + cv::reduce(inputs, maxVal, 1, cv::REDUCE_MAX, CV_32F); + cv::Mat expInputs; + cv::exp(inputs - cv::repeat(maxVal, 1, inputs.cols), expInputs); + cv::Mat sumExp; + cv::reduce(expInputs, sumExp, 1, cv::REDUCE_SUM, CV_32F); + cv::Mat softmaxOutput = expInputs / cv::repeat(sumExp, 1, inputs.cols); + return softmaxOutput; +} + ++ (NSDictionary *)calculateResizeRatioAndPaddings:(int)width height:(int)height desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight { + const float newRatioH = (float)desiredHeight / height; + const float newRatioW = (float)desiredWidth / width; + float resizeRatio = MIN(newRatioH, newRatioW); + const int newWidth = width * resizeRatio; + const int newHeight = height * resizeRatio; + const int deltaW = desiredWidth - newWidth; + const int deltaH = desiredHeight - newHeight; + const int top = deltaH / 2; + const int left = deltaW / 2; + float heightRatio = (float)height / 
desiredHeight; + float widthRatio = (float)width / desiredWidth; + + resizeRatio = MAX(heightRatio, widthRatio); + + return @{ + @"resizeRatio": @(resizeRatio), + @"top": @(top), + @"left": @(left), + }; +} + ++ (cv::Mat)getCroppedImage:(int)x_max x_min:(int)x_min y_max:(int)y_max y_min:(int)y_min image:(cv::Mat)image modelHeight:(int)modelHeight { + cv::Rect region(x_min, y_min, x_max - x_min, y_max - y_min); + cv::Mat crop_img = image(region); + + int width = x_max - x_min; + int height = y_max - y_min; + + CGFloat ratio = [OCRUtils calculateRatioWithWidth:width height:height]; + int new_width = (int)(modelHeight * ratio); + + if (new_width == 0) { + return crop_img; + } + + crop_img = [OCRUtils computeRatioAndResize:crop_img width:width height:height modelHeight:modelHeight]; + + return crop_img; +} + +@end diff --git a/src/OCR.ts b/src/OCR.ts index 43c36ecbff..a7649c708a 100644 --- a/src/OCR.ts +++ b/src/OCR.ts @@ -3,12 +3,13 @@ import { ResourceSource } from './types/common'; import { OCR } from './native/RnExecutorchModules'; import { ETError, getError } from './Error'; import { Image } from 'react-native'; +import { OCRDetection } from './types/ocr'; interface OCRModule { error: string | null; isReady: boolean; isGenerating: boolean; - forward: (input: string) => Promise; + forward: (input: string) => Promise; } const getModelPath = (source: ResourceSource) => { @@ -24,7 +25,9 @@ export const useOCR = ({ language = 'en', }: { detectorSource: ResourceSource; - recognizerSources: ResourceSource[]; + recognizerSources: { + [key: string]: ResourceSource; + }; language?: string; }): OCRModule => { const [error, setError] = useState(null); @@ -33,13 +36,27 @@ export const useOCR = ({ useEffect(() => { const loadModel = async () => { - if (!detectorSource || recognizerSources.length === 0) return; + if (!detectorSource || Object.keys(recognizerSources).length === 0) + return; const detectorPath = getModelPath(detectorSource); - const recognizerPaths = 
recognizerSources.map(getModelPath); + const recognizerPaths: { + [key: string]: string; + } = {}; + Object.keys(recognizerSources).forEach((key: string) => { + recognizerPaths[key] = getModelPath( + recognizerSources[key] as ResourceSource + ); + }); try { setIsReady(false); - await OCR.loadModule(detectorPath, recognizerPaths, language); + await OCR.loadModule( + detectorPath, + recognizerPaths.recognizer512, + recognizerPaths.recognizer256, + recognizerPaths.recognizer128, + language + ); setIsReady(true); } catch (e) { setError(getError(e)); @@ -47,7 +64,8 @@ export const useOCR = ({ }; loadModel(); - }, [detectorSource, language, recognizerSources.length]); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [detectorSource, language, JSON.stringify(recognizerSources)]); const forward = async (input: string) => { if (!isReady) { diff --git a/src/native/NativeOCR.ts b/src/native/NativeOCR.ts index f819ca8975..4ddbfe353e 100644 --- a/src/native/NativeOCR.ts +++ b/src/native/NativeOCR.ts @@ -1,13 +1,16 @@ import type { TurboModule } from 'react-native'; import { TurboModuleRegistry } from 'react-native'; +import { OCRDetection } from '../types/ocr'; export interface Spec extends TurboModule { loadModule( detectorSource: string, - recognizerSources: string[], + recognizerSource512: string, + recognizerSource256: string, + recognizerSource128: string, language: string ): Promise; - forward(input: string): Promise; + forward(input: string): Promise; } export default TurboModuleRegistry.get('OCR'); diff --git a/src/types/ocr.ts b/src/types/ocr.ts new file mode 100644 index 0000000000..697dbcaaa9 --- /dev/null +++ b/src/types/ocr.ts @@ -0,0 +1,7 @@ +import { Bbox } from './object_detection'; + +export interface OCRDetection { + bbox: Bbox; + text: string; + score: number; +} From 22c7aa2dbfedf0a8f46048bdb0b2935c6becb0d2 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Sun, 26 Jan 2025 16:40:03 +0100 Subject: [PATCH 07/19] reformat: reformat of 
detector and recognizer code --- .../computer-vision/screens/OCRScreen.tsx | 7 +- ios/RnExecutorch/OCR.mm | 18 ++- ios/RnExecutorch/models/ocr/Detector.h | 7 +- ios/RnExecutorch/models/ocr/Detector.mm | 43 ++++--- .../models/ocr/RecognitionHandler.mm | 18 +-- ios/RnExecutorch/models/ocr/Recognizer.h | 5 +- ios/RnExecutorch/models/ocr/Recognizer.mm | 94 +++++--------- .../models/ocr/utils/CTCLabelConverter.h | 2 +- .../models/ocr/utils/CTCLabelConverter.mm | 19 ++- .../models/ocr/utils/DetectorUtils.h | 3 +- .../models/ocr/utils/DetectorUtils.mm | 67 ++++------ ios/RnExecutorch/models/ocr/utils/OCRUtils.h | 4 - ios/RnExecutorch/models/ocr/utils/OCRUtils.mm | 70 ----------- .../models/ocr/utils/RecognizerUtils.h | 8 +- .../models/ocr/utils/RecognizerUtils.mm | 118 ++++++++++++++++-- src/OCR.ts | 31 +++-- src/native/NativeOCR.ts | 6 +- 17 files changed, 254 insertions(+), 266 deletions(-) diff --git a/examples/computer-vision/screens/OCRScreen.tsx b/examples/computer-vision/screens/OCRScreen.tsx index e6e4dbcdc1..0b097293fd 100644 --- a/examples/computer-vision/screens/OCRScreen.tsx +++ b/examples/computer-vision/screens/OCRScreen.tsx @@ -22,10 +22,10 @@ export const OCRScreen = ({ const model = useOCR({ detectorSource: require('../assets/models/xnnpack_craft.pte'), recognizerSources: { - recognizer512: + recognizerLarge: 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_512.pte', - recognizer256: require('../assets/models/xnnpack_crnn_256.pte'), - recognizer128: require('../assets/models/xnnpack_crnn_128.pte'), + recognizerMedium: require('../assets/models/xnnpack_crnn_256.pte'), + recognizerSmall: require('../assets/models/xnnpack_crnn_128.pte'), }, }); @@ -38,6 +38,7 @@ export const OCRScreen = ({ if (typeof uri === 'string') { setImageUri(uri as string); setResults([]); + setDetectedText(''); } }; diff --git a/ios/RnExecutorch/OCR.mm b/ios/RnExecutorch/OCR.mm index 66957e879f..56ee04c0e1 100644 --- a/ios/RnExecutorch/OCR.mm +++ 
b/ios/RnExecutorch/OCR.mm @@ -1,7 +1,6 @@ -#import "OCR.h" -#import "models/object_detection/SSDLiteLargeModel.hpp" #import #import +#import "OCR.h" #import "utils/ImageProcessor.h" #import "models/ocr/Detector.h" #import "models/ocr/RecognitionHandler.h" @@ -14,9 +13,9 @@ @implementation OCR { RCT_EXPORT_MODULE() - (void)loadModule:(NSString *)detectorSource -recognizerSource512:(NSString *)recognizerSource512 -recognizerSource256:(NSString *)recognizerSource256 -recognizerSource128:(NSString *)recognizerSource128 +recognizerSourceLarge:(NSString *)recognizerSourceLarge +recognizerSourceMedium:(NSString *)recognizerSourceMedium +recognizerSourceSmall:(NSString *)recognizerSourceSmall language:(NSString *)language resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject { @@ -32,7 +31,7 @@ - (void)loadModule:(NSString *)detectorSource return; } - [self->recognitionHandler loadRecognizers:recognizerSource512 mediumRecognizerPath:recognizerSource256 smallRecognizerPath:recognizerSource128 completion:^(BOOL allModelsLoaded, NSNumber *errorCode) { + [self->recognitionHandler loadRecognizers:recognizerSourceLarge mediumRecognizerPath:recognizerSourceMedium smallRecognizerPath:recognizerSourceSmall completion:^(BOOL allModelsLoaded, NSNumber *errorCode) { if (allModelsLoaded) { resolve(@(YES)); } else { @@ -48,6 +47,13 @@ - (void)loadModule:(NSString *)detectorSource - (void)forward:(NSString *)input resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject { + /* + The OCR consists of two phases: + 1. Detection - detecting text regions in the image, the result of this phase is a list of bounding boxes. + 2. Recognition - recognizing the text in the bounding boxes, the result is a list of strings and corresponding confidence scores. + + Recognition uses three models, each model is responsible for recognizing text of different sizes (e.g. large - 512x64, medium - 256x64, small - 128x64). 
+ */ @try { cv::Mat image = [ImageProcessor readImage:input]; NSArray* result = [detector runModel:image]; diff --git a/ios/RnExecutorch/models/ocr/Detector.h b/ios/RnExecutorch/models/ocr/Detector.h index 9c026e3289..5d7df1b8e1 100644 --- a/ios/RnExecutorch/models/ocr/Detector.h +++ b/ios/RnExecutorch/models/ocr/Detector.h @@ -1,5 +1,5 @@ -#import "BaseModel.h" #import "opencv2/opencv.hpp" +#import "BaseModel.h" const float textThreshold = 0.7; const float linkThreshold = 0.4; @@ -9,11 +9,12 @@ const float heightThs = 0.5; const float widthThs = 0.5; const float addMargin = 0.1; const int minSize = 20; +const cv::Scalar mean(0.485, 0.456, 0.406); +const cv::Scalar variance(0.229, 0.224, 0.225); @interface Detector : BaseModel + - (cv::Size)getModelImageSize; -- (NSArray *)preprocess:(cv::Mat &)input; -- (NSArray *)postprocess:(NSArray *)output; - (NSArray *)runModel:(cv::Mat &)input; @end diff --git a/ios/RnExecutorch/models/ocr/Detector.mm b/ios/RnExecutorch/models/ocr/Detector.mm index 4633cd1d17..e499f78461 100644 --- a/ios/RnExecutorch/models/ocr/Detector.mm +++ b/ios/RnExecutorch/models/ocr/Detector.mm @@ -1,9 +1,13 @@ -#import "opencv2/opencv.hpp" #import "Detector.h" #import "../../utils/ImageProcessor.h" #import "utils/DetectorUtils.h" #import "utils/OCRUtils.h" +/* + The model used as detector is based on CRAFT (Character Region Awareness for Text Detection) paper. + https://arxiv.org/pdf/1904.01941 + */ + @implementation Detector { cv::Size originalSize; cv::Size modelSize; @@ -26,17 +30,30 @@ @implementation Detector { } - (NSArray *)preprocess:(cv::Mat &)input { + /* + Detector as an input accepts tensor with a shape of [1, 3, 1280, 1280]. + Due to big influence of resize to quality of recognition the image preserves original + aspect ratio and the missing parts are filled with padding. 
+ */ self->originalSize = cv::Size(input.cols, input.rows); cv::Size modelImageSize = [self getModelImageSize]; cv::Mat resizedImage; resizedImage = [OCRUtils resizeWithPadding:input desiredWidth:modelImageSize.width desiredHeight:modelImageSize.height]; - NSArray *modelInput = [DetectorUtils matToNSArray: resizedImage]; + NSArray *modelInput = [ImageProcessor matToNSArray: resizedImage mean:mean variance:variance]; return modelInput; } - (NSArray *)postprocess:(NSArray *)output { + /* + The output of the model consists of two matrices: + 1. ScoreText(Score map) - The probability of a region containing character + 2. ScoreLink(Affinity map) - The probability of a region being a part of a text line + Both matrices are 640x640 + + The result of this step is a list of bounding boxes that contain text. + */ NSArray *predictions = [output objectAtIndex:0]; NSDictionary *splittedData = [DetectorUtils splitInterleavedNSArray:predictions]; @@ -46,25 +63,13 @@ - (NSArray *)postprocess:(NSArray *)output { cv::Mat scoreTextCV; cv::Mat scoreLinkCV; cv::Size modelImageSize = [self getModelImageSize]; - scoreTextCV = [DetectorUtils arrayToMat:scoreText width:modelImageSize.width / 2 height:modelImageSize.height / 2]; - scoreLinkCV = [DetectorUtils arrayToMat:scoreLink width:modelImageSize.width / 2 height:modelImageSize.height / 2]; - NSArray* boxes = [DetectorUtils getDetBoxes:scoreTextCV linkMap:scoreLinkCV textThreshold:textThreshold linkThreshold:linkThreshold lowText:lowText]; - NSMutableArray *single_img_result = [NSMutableArray array]; - for (NSUInteger i = 0; i < [boxes count]; i++) { - NSArray *box = boxes[i]; - NSMutableArray *boxArray = [NSMutableArray arrayWithCapacity:4]; - for (NSValue *value in box) { - CGPoint point = [value CGPointValue]; - point.x *= 2; - point.y *= 2; - [boxArray addObject:@((int)point.x)]; - [boxArray addObject:@((int)point.y)]; - } - [single_img_result addObject:boxArray]; - } + scoreTextCV = [ImageProcessor arrayToMatGray:scoreText 
width:modelImageSize.width / 2 height:modelImageSize.height / 2]; + scoreLinkCV = [ImageProcessor arrayToMatGray:scoreLink width:modelImageSize.width / 2 height:modelImageSize.height / 2]; - NSArray* horizontalList = [DetectorUtils groupTextBox:single_img_result ycenterThs:yCenterThs heightThs:heightThs widthThs:widthThs addMargin:addMargin]; + NSArray* horizontalList = [DetectorUtils getDetBoxes:scoreTextCV linkMap:scoreLinkCV textThreshold:textThreshold linkThreshold:linkThreshold lowText:lowText]; + horizontalList = [DetectorUtils restoreBboxRatio:horizontalList]; + horizontalList = [DetectorUtils groupTextBox:horizontalList ycenterThs:yCenterThs heightThs:heightThs widthThs:widthThs addMargin:addMargin]; NSMutableArray *boxesToKeep = [NSMutableArray array]; diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm index 4a5abc9187..0fdd17180a 100644 --- a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm @@ -1,13 +1,17 @@ -#import "RecognitionHandler.h" #import #import "ExecutorchLib/ETModel.h" -#import "ExecutorchLib/ETModel.h" -#import "../../utils/ImageProcessor.h" #import "../../utils/Fetcher.h" -#import "./utils/RecognizerUtils.h" -#import "./utils/OCRUtils.h" +#import "../../utils/ImageProcessor.h" #import "./utils/CTCLabelConverter.h" +#import "./utils/OCRUtils.h" +#import "./utils/RecognizerUtils.h" #import "Recognizer.h" +#import "RecognitionHandler.h" + +/* + RecognitionHandler class is responsible for loading and choosing the appropriate recognizer model based on the input image size, + it also handles converting the model output to text. 
+ */ @implementation RecognitionHandler { Recognizer *recognizerLarge; @@ -79,7 +83,7 @@ - (NSArray *)recognize: (NSArray *)horizontalList imgGray:(cv::Mat)imgGray desir cv::Mat croppedImage = [RecognizerUtils getCroppedImage:x_max x_min:x_min y_max:y_max y_min:y_min image:imgGray modelHeight:modelHeight]; - croppedImage = [OCRUtils normalizeForRecognizer:croppedImage adjustContrast:0.0]; + croppedImage = [RecognizerUtils normalizeForRecognizer:croppedImage adjustContrast:0.0]; NSArray *result; if(croppedImage.cols >= largeModelWidth) { result = [recognizerLarge runModel:croppedImage]; @@ -92,7 +96,7 @@ - (NSArray *)recognize: (NSArray *)horizontalList imgGray:(cv::Mat)imgGray desir NSNumber *confidenceScore = [result objectAtIndex:1]; NSArray *pred_index = [result objectAtIndex:0]; - NSArray* decodedTexts = [converter decodeGreedyWithTextIndex:pred_index length:(int)(pred_index.count)]; + NSArray* decodedTexts = [converter decodeGreedy:pred_index length:(int)(pred_index.count)]; NSDictionary *res = @{@"text": decodedTexts[0], @"bbox": @{@"x1": @((int)((x_min - left) * resizeRatio)), @"x2": @((int)((x_max - left) * resizeRatio)), @"y1": @((int)((y_min - top) * resizeRatio)), @"y2":@((int)((y_max - top) * resizeRatio))}, @"score": confidenceScore}; [predictions addObject:res]; diff --git a/ios/RnExecutorch/models/ocr/Recognizer.h b/ios/RnExecutorch/models/ocr/Recognizer.h index f47769996a..63047ac00a 100644 --- a/ios/RnExecutorch/models/ocr/Recognizer.h +++ b/ios/RnExecutorch/models/ocr/Recognizer.h @@ -1,11 +1,8 @@ -#import "BaseModel.h" #import "opencv2/opencv.hpp" +#import "BaseModel.h" @interface Recognizer : BaseModel -- (cv::Size)getModelImageSize; -- (NSArray *)preprocess:(cv::Mat &)input; -- (NSArray *)postprocess:(NSArray *)output; - (NSArray *)runModel:(cv::Mat &)input; @end diff --git a/ios/RnExecutorch/models/ocr/Recognizer.mm b/ios/RnExecutorch/models/ocr/Recognizer.mm index f4e0fd7ca1..10af525ce6 100644 --- a/ios/RnExecutorch/models/ocr/Recognizer.mm 
+++ b/ios/RnExecutorch/models/ocr/Recognizer.mm @@ -1,15 +1,19 @@ -#import "opencv2/opencv.hpp" #import "Recognizer.h" +#import "RecognizerUtils.h" #import "../../utils/ImageProcessor.h" #import "utils/OCRUtils.h" -#import "RecognizerUtils.h" + +/* + The model used as recognizer is based on CRNN paper. + https://arxiv.org/pdf/1507.05717 + */ @implementation Recognizer { cv::Size originalSize; } - (cv::Size)getModelImageSize{ - NSArray * inputShape = [module getInputShape: @0]; + NSArray *inputShape = [module getInputShape: @0]; NSNumber *widthNumber = inputShape.lastObject; NSNumber *heightNumber = inputShape[inputShape.count - 2]; @@ -18,84 +22,46 @@ @implementation Recognizer { return cv::Size(height, width); } +- (cv::Size)getModelOutputSize{ + NSArray *outputShape = [module getOutputShape: @0]; + NSNumber *widthNumber = outputShape.lastObject; + NSNumber *heightNumber = outputShape[outputShape.count - 2]; + + int height = [heightNumber intValue]; + int width = [widthNumber intValue]; + return cv::Size(height, width); +} + - (NSArray *)preprocess:(cv::Mat &)input { - return [NSArray init]; + return [ImageProcessor matToNSArrayGray:input]; } - (NSArray *)postprocess:(NSArray *)output { - NSInteger totalNumbers = [output.firstObject count]; - NSInteger numRows = (totalNumbers + 96) / 97; - - cv::Mat resultMat = cv::Mat::zeros(numRows, 97, CV_32F); - + int modelOutputHeight = [self getModelOutputSize].height; + NSInteger numElements = [output.firstObject count]; + NSInteger numRows = (numElements + modelOutputHeight - 1) / modelOutputHeight; + cv::Mat resultMat = cv::Mat::zeros(numRows, modelOutputHeight, CV_32F); NSInteger counter = 0; NSInteger currentRow = 0; - for (NSNumber *num in output.firstObject) { resultMat.at(currentRow, counter) = [num floatValue]; - counter++; - if (counter >= 97) { - counter = 0; - currentRow++; + if (counter >= modelOutputHeight) { + counter = 0; currentRow++; } } cv::Mat probabilities = [RecognizerUtils softmax:resultMat]; - 
NSMutableArray* pred_norm = [NSMutableArray arrayWithCapacity:probabilities.rows]; - for(int i = 0; i < probabilities.rows; i++) { - float sum = 0.0; - for(int j = 0; j < 97; j++) { - sum += probabilities.at(i, j); - } - [pred_norm addObject:@(sum)]; - } - - probabilities = [RecognizerUtils divideMatrix:probabilities byVector:pred_norm]; - NSArray* preds_index = [RecognizerUtils indicesOfMaxValuesInMatrix:probabilities]; + NSMutableArray *predsNorm = [RecognizerUtils sumProbabilityRows:probabilities modelOutputHeight:modelOutputHeight]; + probabilities = [RecognizerUtils divideMatrix:probabilities byVector:predsNorm]; + NSArray *maxValuesIndices = [RecognizerUtils findMaxValuesAndIndices:probabilities]; + double confidenceScore = [RecognizerUtils computeConfidenceScore:maxValuesIndices[0] indicesArray:maxValuesIndices[1]]; - NSMutableArray *valuesArray = [NSMutableArray array]; - NSMutableArray *indicesArray = [NSMutableArray array]; - for (int i = 0; i < probabilities.rows; i++) { - double maxVal = 0; - cv::Point maxLoc; - cv::minMaxLoc(probabilities.row(i), NULL, &maxVal, NULL, &maxLoc); - - [valuesArray addObject:@(maxVal)]; - [indicesArray addObject:@(maxLoc.x)]; - } - - NSMutableArray *predsMaxProb = [NSMutableArray array]; - - for (NSUInteger index = 0; index < indicesArray.count; index++) { - NSNumber *indicator = indicesArray[index]; - if ([indicator intValue] != 0) { - [predsMaxProb addObject:valuesArray[index]]; - } - } - - - if (predsMaxProb.count == 0) { - [predsMaxProb addObject:@(0)]; - } - - double product = 1.0; - for (NSNumber *prob in predsMaxProb) { - product *= [prob doubleValue]; - } - - double confidenceScore = pow(product, 2.0 / sqrt(predsMaxProb.count)); - - NSMutableArray* result = [[NSMutableArray alloc] init]; - - [result addObject:preds_index]; - [result addObject: @(confidenceScore)]; - - return result; + return @[maxValuesIndices[1], @(confidenceScore)]; } - (NSArray *)runModel:(cv::Mat &)input { - NSArray* modelInput = [ImageProcessor 
matToArrayForGrayscale:input]; + NSArray* modelInput = [self preprocess:input]; NSArray *modelResult = [self forward:modelInput]; NSArray *result = [self postprocess:modelResult]; diff --git a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h index 7cc167f2b1..037782f4bb 100644 --- a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h +++ b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h @@ -10,6 +10,6 @@ - (instancetype)initWithCharacters:(NSString *)characters separatorList:(NSDictionary *)separatorList dictPathList:(NSDictionary *)dictPathList; - (void)loadDictionariesWithDictPathList:(NSDictionary *)dictPathList; -- (NSArray *)decodeGreedyWithTextIndex:(NSArray *)textIndex length:(NSInteger)length; +- (NSArray *)decodeGreedy:(NSArray *)textIndex length:(NSInteger)length; @end diff --git a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm index e8b8d0fbc5..644a29e213 100644 --- a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm +++ b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm @@ -47,36 +47,31 @@ - (void)loadDictionariesWithDictPathList:(NSDictionary * _dictList = [tempDictList copy]; } -- (NSArray *)decodeGreedyWithTextIndex:(NSArray *)textIndex length:(NSInteger)length { +- (NSArray *)decodeGreedy:(NSArray *)textIndex length:(NSInteger)length { NSMutableArray *texts = [NSMutableArray array]; NSUInteger index = 0; - // Loop until you've processed all characters while (index < textIndex.count) { - NSUInteger segmentLength = MIN(length, textIndex.count - index); // Calculate size of the current segment + NSUInteger segmentLength = MIN(length, textIndex.count - index); NSRange range = NSMakeRange(index, segmentLength); NSArray *subArray = [textIndex subarrayWithRange:range]; NSMutableString *text = [NSMutableString string]; NSNumber *lastChar = nil; - // Creating mutable arrays to store states like in Python with `a` 
and `b` - NSMutableArray *isNotRepeated = [NSMutableArray arrayWithObject:@YES]; // First character is always not repeated + NSMutableArray *isNotRepeated = [NSMutableArray arrayWithObject:@YES]; NSMutableArray *isNotIgnored = [NSMutableArray array]; for (NSUInteger i = 0; i < subArray.count; i++) { NSNumber *currentChar = subArray[i]; - // Check if character is repeated - if (i > 0) { // From second character onward + if (i > 0) { [isNotRepeated addObject:@(![lastChar isEqualToNumber:currentChar])]; } - // Check if the current character is in the ignore list [isNotIgnored addObject:@(![self.ignoreIdx containsObject:currentChar])]; - lastChar = currentChar; // Update lastChar to current character + lastChar = currentChar; } - // Combine `isNotRepeated` and `isNotIgnored` conditions just like combining 'a' and 'b' in Python for (NSUInteger j = 0; j < subArray.count; j++) { if ([isNotRepeated[j] boolValue] && [isNotIgnored[j] boolValue]) { NSUInteger charIndex = [subArray[j] unsignedIntegerValue]; @@ -85,9 +80,9 @@ - (void)loadDictionariesWithDictPathList:(NSDictionary * } [texts addObject:text.copy]; - index += segmentLength; // Move index forward + index += segmentLength; - if (segmentLength < length) { // If reached the end of textIndex + if (segmentLength < length) { break; } } diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h index 716d0e402e..3e4c78226e 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h @@ -2,9 +2,8 @@ @interface DetectorUtils : NSObject -+ (NSArray *)matToNSArray:(const cv::Mat &)mat; + (NSDictionary *)splitInterleavedNSArray:(NSArray *)array; -+ (cv::Mat)arrayToMat:(NSArray *)array width:(int)width height:(int)height; ++ (NSArray *)restoreBboxRatio:(NSArray *)boxes; + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold:(double)textThreshold linkThreshold:(double)linkThreshold 
lowText:(double)lowText; + (NSArray *> *)groupTextBox:(NSArray *> *)polys ycenterThs:(CGFloat)ycenterThs diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm index 82dec21566..2dbaa69677 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm @@ -2,28 +2,6 @@ @implementation DetectorUtils -+ (NSArray *)matToNSArray:(const cv::Mat &)mat { - cv::Scalar mean(0.485, 0.456, 0.406); - cv::Scalar variance(0.229, 0.224, 0.225); - - int pixelCount = mat.cols * mat.rows; - NSMutableArray *floatArray = [[NSMutableArray alloc] initWithCapacity:pixelCount * 3]; - for (NSUInteger k = 0; k < pixelCount * 3; k++) { - [floatArray addObject:@0.0]; - } - - for (int i = 0; i < pixelCount; i++) { - int row = i / mat.cols; - int col = i % mat.cols; - cv::Vec3b pixel = mat.at(row, col); - floatArray[0 * pixelCount + i] = @((pixel[0] - mean[0] * 255.0) / (variance[0] * 255.0)); - floatArray[1 * pixelCount + i] = @((pixel[1] - mean[1] * 255.0) / (variance[1] * 255.0)); - floatArray[2 * pixelCount + i] = @((pixel[2] - mean[2] * 255.0) / (variance[2] * 255.0)); - } - - return floatArray; -} - + (NSDictionary *)splitInterleavedNSArray:(NSArray *)array { NSMutableArray *scoreText = [[NSMutableArray alloc] init]; NSMutableArray *scoreLink = [[NSMutableArray alloc] init]; @@ -39,21 +17,28 @@ + (NSDictionary *)splitInterleavedNSArray:(NSArray *)array { return @{@"ScoreText": scoreText, @"ScoreLink": scoreLink}; } -+ (cv::Mat)arrayToMat:(NSArray *)array width:(int)width height:(int)height { - cv::Mat mat(height, width, CV_32F); - - int pixelCount = width * height; - for (int i = 0; i < pixelCount; i++) { - int row = i / width; - int col = i % width; - float value = [array[i] floatValue]; - mat.at(row, col) = value; ++ (NSArray *)restoreBboxRatio:(NSArray *)boxes { + NSMutableArray *result = [NSMutableArray array]; + for (NSUInteger i = 0; i < [boxes count]; i++) { + 
NSArray *box = boxes[i]; + NSMutableArray *boxArray = [NSMutableArray arrayWithCapacity:4]; + for (NSValue *value in box) { + CGPoint point = [value CGPointValue]; + point.x *= 2; + point.y *= 2; + [boxArray addObject:@((int)point.x)]; + [boxArray addObject:@((int)point.y)]; + } + [result addObject:boxArray]; } - return mat; + return result; } + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold:(double)textThreshold linkThreshold:(double)linkThreshold lowText:(double)lowText { + /* + The getDetBoxes function uses scoreMap and affinityMap to generate bounding boxes which contain text. + */ cv::Mat textmapCopy = textmap.clone(); cv::Mat linkmapCopy = linkmap.clone(); int img_h = textmap.rows; @@ -79,11 +64,9 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold cv::minMaxLoc(textmapCopy, NULL, &maxVal, NULL, NULL, mask); if (maxVal < textThreshold) continue; - // Create mask for segmented area cv::Mat segMap = cv::Mat::zeros(textmap.size(), CV_8U); segMap.setTo(255, (labels == i)); - // Dilate the segmented area int x = stats.at(i, cv::CC_STAT_LEFT); int y = stats.at(i, cv::CC_STAT_TOP); int w = stats.at(i, cv::CC_STAT_WIDTH); @@ -134,26 +117,23 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold for (NSArray *poly in polys) { NSArray *xCoords = @[poly[0], poly[2], poly[4], poly[6]]; - // Array of y coordinates NSArray *yCoords = @[poly[1], poly[3], poly[5], poly[7]]; - // Calculating max and min values for x coordinates NSNumber *xMaxNumber = [xCoords valueForKeyPath:@"@max.self"]; NSNumber *xMinNumber = [xCoords valueForKeyPath:@"@min.self"]; - float xMax = [xMaxNumber floatValue]; // Convert max float value to int - float xMin = [xMinNumber floatValue]; // Convert min float value to int + float xMax = [xMaxNumber floatValue]; + float xMin = [xMinNumber floatValue]; - // Calculating max and min values for y coordinates NSNumber *yMaxNumber = [yCoords 
valueForKeyPath:@"@max.self"]; NSNumber *yMinNumber = [yCoords valueForKeyPath:@"@min.self"]; - float yMax = [yMaxNumber floatValue]; // Convert max float value to int - float yMin = [yMinNumber floatValue]; // Convert min float value to int + float yMax = [yMaxNumber floatValue]; + float yMin = [yMinNumber floatValue]; [horizontalList addObject:[@[@(xMin), @(xMax), @(yMin), @(yMax), @((yMin + yMax) / 2.0), @(yMax - yMin)] mutableCopy]]; } [horizontalList sortUsingComparator:^NSComparisonResult(NSMutableArray *obj1, NSMutableArray *obj2) { - return [obj1[4] compare:obj2[4]]; // Sorting by y_center + return [obj1[4] compare:obj2[4]]; }]; NSMutableArray *newBox = [NSMutableArray array]; @@ -208,7 +188,6 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold float meanHeight = [[bHeight valueForKeyPath:@"@avg.self"] floatValue]; if (fabs(meanHeight - currHeight) < heightThs * meanHeight && ([box[0] intValue] - xMax) < widthThs * ([box[3] intValue] - [box[2] intValue])) { - // merge condition is met [bHeight addObject:box[5]]; xMax = [box[1] intValue]; [newBox addObject:box]; @@ -261,7 +240,7 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold } } } - + return mergedList; } diff --git a/ios/RnExecutorch/models/ocr/utils/OCRUtils.h b/ios/RnExecutorch/models/ocr/utils/OCRUtils.h index 2fefc167f8..0304ad37e3 100644 --- a/ios/RnExecutorch/models/ocr/utils/OCRUtils.h +++ b/ios/RnExecutorch/models/ocr/utils/OCRUtils.h @@ -3,9 +3,5 @@ @interface OCRUtils : NSObject + (cv::Mat)resizeWithPadding:(cv::Mat)img desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight; -+ (CGFloat)calculateRatioWithWidth:(int)width height:(int)height; -+ (cv::Mat)computeRatioAndResize:(cv::Mat)img width:(int)width height:(int)height modelHeight:(int)modelHeight; -+ (cv::Mat)normalizeForRecognizer:(cv::Mat)image adjustContrast:(double)adjustContrast; -+ (cv::Mat)adjustContrastGrey:(cv::Mat)img target:(double)target; @end diff --git 
a/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm index 1d0536a8b4..3bec624450 100644 --- a/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm @@ -15,7 +15,6 @@ @implementation OCRUtils cv::Mat resizedImg; cv::resize(img, resizedImg, cv::Size(newWidth, newHeight), 0, 0, cv::INTER_AREA); - // Estimating the background color by sampling from the corners of the image const int cornerPatchSize = MAX(1, MIN(height, width) / 30); std::vector corners = { img(cv::Rect(0, 0, cornerPatchSize, cornerPatchSize)), @@ -47,73 +46,4 @@ @implementation OCRUtils return centeredImg; } -+ (CGFloat)calculateRatioWithWidth:(int)width height:(int)height { - CGFloat ratio = (CGFloat)width / (CGFloat)height; - if (ratio < 1.0) { - ratio = 1.0 / ratio; - } - return ratio; -} - -+ (cv::Mat)computeRatioAndResize:(cv::Mat)img width:(int)width height:(int)height modelHeight:(int)modelHeight { - CGFloat ratio = (CGFloat)width / (CGFloat)height; - if (ratio < 1.0) { - ratio = [self calculateRatioWithWidth:width height:height]; - cv::resize(img, img, cv::Size(modelHeight, (int)(modelHeight * ratio)), 0, 0, cv::INTER_LANCZOS4); - } else { - cv::resize(img, img, cv::Size((int)(modelHeight * ratio), modelHeight), 0, 0, cv::INTER_LANCZOS4); - } - return img; -} - -+ (cv::Mat)adjustContrastGrey:(cv::Mat)img target:(double)target { - double contrast = 0.0; - int high = 0; - int low = 255; - - // Calculate existing contrast, high, and low - for (int i = 0; i < img.rows; ++i) { - for (int j = 0; j < img.cols; ++j) { - uchar pixel = img.at(i, j); - high = MAX(high, pixel); - low = MIN(low, pixel); - } - } - contrast = (high - low) / 255.0; - - // Adjust contrast if below the target - if (contrast < target) { - double ratio = 200.0 / MAX(10, high - low); - img.convertTo(img, CV_32F); // Convert to float for scaling operations - img = ((img - low + 25) * ratio); - - // Clipping values to ensure they remain within 
valid range - cv::threshold(img, img, 255, 255, cv::THRESH_TRUNC); // Cap values at 255 - cv::threshold(img, img, 0, 0, cv::THRESH_TOZERO); // Ensure no negative values - - img.convertTo(img, CV_8U); // Convert back to 8-bit pixel values - } - - return img; -} - -+ (cv::Mat)normalizeForRecognizer:(cv::Mat)image adjustContrast:(double)adjustContrast { - if (adjustContrast > 0) { - image = [OCRUtils adjustContrastGrey:image target:adjustContrast]; } - - int desiredWidth = 128; - if (image.cols >= 512) { - desiredWidth = 512; - } else if (image.cols >= 256) { - desiredWidth = 256; - } - - image = [OCRUtils resizeWithPadding:image desiredWidth:desiredWidth desiredHeight:64]; - - image.convertTo(image, CV_32F, 1.0 / 255.0); - image = (image - 0.5) * 2.0; - - return image; -} - @end diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h index b9c1278282..d3a9965383 100644 --- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h @@ -2,10 +2,16 @@ @interface RecognizerUtils : NSObject -+ (NSArray *)indicesOfMaxValuesInMatrix:(cv::Mat)matrix; ++ (CGFloat)calculateRatio:(int)width height:(int)height; ++ (cv::Mat)computeRatioAndResize:(cv::Mat)img width:(int)width height:(int)height modelHeight:(int)modelHeight; ++ (cv::Mat)normalizeForRecognizer:(cv::Mat)image adjustContrast:(double)adjustContrast; ++ (cv::Mat)adjustContrastGrey:(cv::Mat)img target:(double)target; + (cv::Mat)divideMatrix:(cv::Mat)matrix byVector:(NSArray *)vector; + (cv::Mat)softmax:(cv::Mat)inputs; + (NSDictionary *)calculateResizeRatioAndPaddings:(int)width height:(int)height desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight; + (cv::Mat)getCroppedImage:(int)x_max x_min:(int)x_min y_max:(int)y_max y_min:(int)y_min image:(cv::Mat)image modelHeight:(int)modelHeight; ++ (NSMutableArray *)sumProbabilityRows:(cv::Mat)probabilities modelOutputHeight:(int)modelOutputHeight; ++ 
(NSArray *)findMaxValuesAndIndices:(cv::Mat)probabilities; ++ (double)computeConfidenceScore:(NSArray *)valuesArray indicesArray:(NSArray *)indicesArray; @end diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm index 75a3585cbb..9199a828b4 100644 --- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm @@ -1,19 +1,72 @@ -#import "RecognizerUtils.h" #import "OCRUtils.h" +#import "RecognizerUtils.h" @implementation RecognizerUtils -+ (NSArray *)indicesOfMaxValuesInMatrix:(cv::Mat)matrix { - NSMutableArray *maxIndices = [NSMutableArray array]; ++ (CGFloat)calculateRatio:(int)width height:(int)height { + CGFloat ratio = (CGFloat)width / (CGFloat)height; + if (ratio < 1.0) { + ratio = 1.0 / ratio; + } + return ratio; +} + ++ (cv::Mat)computeRatioAndResize:(cv::Mat)img width:(int)width height:(int)height modelHeight:(int)modelHeight { + CGFloat ratio = (CGFloat)width / (CGFloat)height; + if (ratio < 1.0) { + ratio = [self calculateRatio:width height:height]; + cv::resize(img, img, cv::Size(modelHeight, (int)(modelHeight * ratio)), 0, 0, cv::INTER_LANCZOS4); + } else { + cv::resize(img, img, cv::Size((int)(modelHeight * ratio), modelHeight), 0, 0, cv::INTER_LANCZOS4); + } + return img; +} + ++ (cv::Mat)adjustContrastGrey:(cv::Mat)img target:(double)target { + double contrast = 0.0; + int high = 0; + int low = 255; - for (int i = 0; i < matrix.rows; i++) { - double maxVal; - cv::Point maxLoc; - cv::minMaxLoc(matrix.row(i), NULL, &maxVal, NULL, &maxLoc); - [maxIndices addObject:@(maxLoc.x)]; + for (int i = 0; i < img.rows; ++i) { + for (int j = 0; j < img.cols; ++j) { + uchar pixel = img.at(i, j); + high = MAX(high, pixel); + low = MIN(low, pixel); + } + } + contrast = (high - low) / 255.0; + + if (contrast < target) { + double ratio = 200.0 / MAX(10, high - low); + img.convertTo(img, CV_32F); + img = ((img - low + 25) * ratio); + + 
cv::threshold(img, img, 255, 255, cv::THRESH_TRUNC); + cv::threshold(img, img, 0, 0, cv::THRESH_TOZERO); + + img.convertTo(img, CV_8U); + } + + return img; +} + ++ (cv::Mat)normalizeForRecognizer:(cv::Mat)image adjustContrast:(double)adjustContrast { + if (adjustContrast > 0) { + image = [self adjustContrastGrey:image target:adjustContrast]; } + + int desiredWidth = 128; + if (image.cols >= 512) { + desiredWidth = 512; + } else if (image.cols >= 256) { + desiredWidth = 256; } - return [maxIndices copy]; + image = [OCRUtils resizeWithPadding:image desiredWidth:desiredWidth desiredHeight:64]; + + image.convertTo(image, CV_32F, 1.0 / 255.0); + image = (image - 0.5) * 2.0; + + return image; } + (cv::Mat)divideMatrix:(cv::Mat)matrix byVector:(NSArray *)vector { @@ -69,16 +122,59 @@ + (NSDictionary *)calculateResizeRatioAndPaddings:(int)width height:(int)height int width = x_max - x_min; int height = y_max - y_min; - CGFloat ratio = [OCRUtils calculateRatioWithWidth:width height:height]; + CGFloat ratio = [self calculateRatio:width height:height]; int new_width = (int)(modelHeight * ratio); if (new_width == 0) { return crop_img; } - crop_img = [OCRUtils computeRatioAndResize:crop_img width:width height:height modelHeight:modelHeight]; + crop_img = [self computeRatioAndResize:crop_img width:width height:height modelHeight:modelHeight]; return crop_img; } ++ (NSMutableArray *)sumProbabilityRows:(cv::Mat)probabilities modelOutputHeight:(int)modelOutputHeight { + NSMutableArray *predsNorm = [NSMutableArray arrayWithCapacity:probabilities.rows]; + for (int i = 0; i < probabilities.rows; i++) { + float sum = 0.0; + for (int j = 0; j < modelOutputHeight; j++) { + sum += probabilities.at(i, j); + } + [predsNorm addObject:@(sum)]; + } + return predsNorm; +} + ++ (NSArray *)findMaxValuesAndIndices:(cv::Mat)probabilities { + NSMutableArray *valuesArray = [NSMutableArray array]; + NSMutableArray *indicesArray = [NSMutableArray array]; + for (int i = 0; i < probabilities.rows; i++) { 
+ double maxVal = 0; + cv::Point maxLoc; + cv::minMaxLoc(probabilities.row(i), NULL, &maxVal, NULL, &maxLoc); + [valuesArray addObject:@(maxVal)]; + [indicesArray addObject:@(maxLoc.x)]; + } + return @[valuesArray, indicesArray]; +} + ++ (double)computeConfidenceScore:(NSArray *)valuesArray indicesArray:(NSArray *)indicesArray { + NSMutableArray *predsMaxProb = [NSMutableArray array]; + for (NSUInteger index = 0; index < indicesArray.count; index++) { + NSNumber *indicator = indicesArray[index]; + if ([indicator intValue] != 0) { + [predsMaxProb addObject:valuesArray[index]]; + } + } + if (predsMaxProb.count == 0) { + [predsMaxProb addObject:@(0)]; + } + double product = 1.0; + for (NSNumber *prob in predsMaxProb) { + product *= [prob doubleValue]; + } + return pow(product, 2.0 / sqrt(predsMaxProb.count)); +} + @end diff --git a/src/OCR.ts b/src/OCR.ts index a7649c708a..f76c015af9 100644 --- a/src/OCR.ts +++ b/src/OCR.ts @@ -26,7 +26,9 @@ export const useOCR = ({ }: { detectorSource: ResourceSource; recognizerSources: { - [key: string]: ResourceSource; + recognizerLarge: ResourceSource; + recognizerMedium: ResourceSource; + recognizerSmall: ResourceSource; }; language?: string; }): OCRModule => { @@ -40,21 +42,26 @@ export const useOCR = ({ return; const detectorPath = getModelPath(detectorSource); - const recognizerPaths: { - [key: string]: string; - } = {}; - Object.keys(recognizerSources).forEach((key: string) => { - recognizerPaths[key] = getModelPath( - recognizerSources[key] as ResourceSource - ); - }); + const recognizerPaths = {} as { + recognizerLarge: string; + recognizerMedium: string; + recognizerSmall: string; + }; + + for (const key in recognizerSources) { + if (recognizerSources.hasOwnProperty(key)) { + recognizerPaths[key as keyof typeof recognizerPaths] = getModelPath( + recognizerSources[key as keyof typeof recognizerSources] + ); + } + } try { setIsReady(false); await OCR.loadModule( detectorPath, - recognizerPaths.recognizer512, - 
recognizerPaths.recognizer256, - recognizerPaths.recognizer128, + recognizerPaths.recognizerLarge, + recognizerPaths.recognizerMedium, + recognizerPaths.recognizerSmall, language ); setIsReady(true); diff --git a/src/native/NativeOCR.ts b/src/native/NativeOCR.ts index 4ddbfe353e..16ca3c0b7f 100644 --- a/src/native/NativeOCR.ts +++ b/src/native/NativeOCR.ts @@ -5,9 +5,9 @@ import { OCRDetection } from '../types/ocr'; export interface Spec extends TurboModule { loadModule( detectorSource: string, - recognizerSource512: string, - recognizerSource256: string, - recognizerSource128: string, + recognizerSourceLarge: string, + recognizerSourceMedium: string, + recognizerSourceSmall: string, language: string ): Promise; forward(input: string): Promise; From 3db318195a6c14596ecab7f2d35d6c07f9b11c9d Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 27 Jan 2025 09:19:07 +0100 Subject: [PATCH 08/19] refactor: split groupTextBoxes into smaller functions, add functions to ImageProcessor --- .../models/ocr/utils/DetectorUtils.h | 3 + .../models/ocr/utils/DetectorUtils.mm | 109 ++++++++---------- ios/RnExecutorch/utils/ImageProcessor.h | 6 +- ios/RnExecutorch/utils/ImageProcessor.mm | 59 ++++++---- 4 files changed, 94 insertions(+), 83 deletions(-) diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h index 3e4c78226e..272a43f1a4 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h @@ -5,6 +5,9 @@ + (NSDictionary *)splitInterleavedNSArray:(NSArray *)array; + (NSArray *)restoreBboxRatio:(NSArray *)boxes; + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold:(double)textThreshold linkThreshold:(double)linkThreshold lowText:(double)lowText; ++ (NSMutableArray *)prepareBoxesFromPolys:(NSArray *)polys; ++ (NSMutableArray *)combineBoxes:(NSMutableArray *)boxes withYCenterThs:(CGFloat)ycenterThs heightThs:(CGFloat)heightThs; 
++ (NSArray *)mergeBoxes:(NSMutableArray *)combinedBoxes withWidthThs:(CGFloat)widthThs addMargin:(CGFloat)addMargin; + (NSArray *> *)groupTextBox:(NSArray *> *)polys ycenterThs:(CGFloat)ycenterThs heightThs:(CGFloat)heightThs diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm index 2dbaa69677..51a5587419 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm @@ -36,9 +36,6 @@ + (NSArray *)restoreBboxRatio:(NSArray *)boxes { } + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold:(double)textThreshold linkThreshold:(double)linkThreshold lowText:(double)lowText { - /* - The getDetBoxes function uses scoreMap and affinityMap to generate bounding boxes which contain text. - */ cv::Mat textmapCopy = textmap.clone(); cv::Mat linkmapCopy = linkmap.clone(); int img_h = textmap.rows; @@ -104,72 +101,54 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold return detectedBoxes; } -+ (NSArray *> *)groupTextBox:(NSArray *> *)polys - ycenterThs:(CGFloat)ycenterThs - heightThs:(CGFloat)heightThs - widthThs:(CGFloat)widthThs - addMargin:(CGFloat)addMargin -{ - NSMutableArray *> *horizontalList = [NSMutableArray array]; - NSMutableArray *> *> *combinedList = [NSMutableArray array]; - NSMutableArray *> *mergedList = [NSMutableArray array]; - - for (NSArray *poly in polys) { ++ (NSMutableArray *)prepareBoxesFromPolys:(NSArray *)polys { + NSMutableArray *boxes = [NSMutableArray array]; + for (NSArray *poly in polys) { NSArray *xCoords = @[poly[0], poly[2], poly[4], poly[6]]; - NSArray *yCoords = @[poly[1], poly[3], poly[5], poly[7]]; - - NSNumber *xMaxNumber = [xCoords valueForKeyPath:@"@max.self"]; - NSNumber *xMinNumber = [xCoords valueForKeyPath:@"@min.self"]; - float xMax = [xMaxNumber floatValue]; - float xMin = [xMinNumber floatValue]; - - NSNumber *yMaxNumber = [yCoords 
valueForKeyPath:@"@max.self"]; - NSNumber *yMinNumber = [yCoords valueForKeyPath:@"@min.self"]; - float yMax = [yMaxNumber floatValue]; - float yMin = [yMinNumber floatValue]; - - [horizontalList addObject:[@[@(xMin), @(xMax), @(yMin), @(yMax), @((yMin + yMax) / 2.0), @(yMax - yMin)] mutableCopy]]; + NSNumber *xMin = [xCoords valueForKeyPath:@"@min.self"]; + NSNumber *xMax = [xCoords valueForKeyPath:@"@max.self"]; + NSNumber *yMin = [yCoords valueForKeyPath:@"@min.self"]; + NSNumber *yMax = [yCoords valueForKeyPath:@"@max.self"]; + [boxes addObject:@[xMin, xMax, yMin, yMax, @(([yMin floatValue] + [yMax floatValue]) / 2.0), @([yMax floatValue] - [yMin floatValue])]]; } + return boxes; +} + ++ (NSMutableArray *)combineBoxes:(NSMutableArray *)boxes withYCenterThs:(CGFloat)ycenterThs heightThs:(CGFloat)heightThs { + NSMutableArray *combinedBoxes = [NSMutableArray array]; + NSMutableArray *currentGroup = [NSMutableArray array]; - [horizontalList sortUsingComparator:^NSComparisonResult(NSMutableArray *obj1, NSMutableArray *obj2) { - return [obj1[4] compare:obj2[4]]; - }]; - - NSMutableArray *newBox = [NSMutableArray array]; - NSMutableArray *bHeight = [NSMutableArray array]; - NSMutableArray *bYcenter = [NSMutableArray array]; - for (NSArray *box in horizontalList) { - if (newBox.count == 0) { - [bHeight addObject:box[5]]; - [bYcenter addObject:box[4]]; - [newBox addObject:box]; + for (NSArray *box in boxes) { + if (currentGroup.count == 0) { + [currentGroup addObject:box]; } else { - if (fabs([[bYcenter valueForKeyPath:@"@avg.self"] floatValue] - [box[4] floatValue]) < ycenterThs * [[bHeight valueForKeyPath:@"@avg.self"] floatValue]) { - [bHeight addObject:box[5]]; - [bYcenter addObject:box[4]]; - [newBox addObject:box]; + NSArray *lastBox = [currentGroup lastObject]; + BOOL closeYCenter = fabs([[lastBox objectAtIndex:4] floatValue] - [[box objectAtIndex:4] floatValue]) < ycenterThs * [[lastBox objectAtIndex:5] floatValue]; + if (closeYCenter) { + [currentGroup 
addObject:box]; } else { - [combinedList addObject:[newBox copy]]; - [newBox removeAllObjects]; - [newBox addObject:box]; - bHeight = [@[box[5]] mutableCopy]; - bYcenter = [@[box[4]] mutableCopy]; + [combinedBoxes addObject:[currentGroup copy]]; + currentGroup = [@[box] mutableCopy]; } } } + if (currentGroup.count > 0) { + [combinedBoxes addObject:[currentGroup copy]]; + } + return combinedBoxes; +} + ++ (NSArray *)mergeBoxes:(NSMutableArray *)combinedBoxes withWidthThs:(CGFloat)widthThs heightThs:(CGFloat)heightThs addMargin:(CGFloat)addMargin { + NSMutableArray *mergedList = [NSMutableArray array]; - [combinedList addObject:[newBox copy]]; - for (NSArray *boxes in combinedList) { - if ([boxes count] == 1) { - NSArray *box = boxes[0]; - int margin = (int)(addMargin * MIN([box[1] floatValue] - [box[0] floatValue], [box[5] floatValue])); - [mergedList addObject:@[@([box[0] intValue] - margin), - @([box[1] intValue] + margin), - @([box[2] intValue] - margin), - @([box[3] intValue] + margin)]]; + for (NSArray *group in combinedBoxes) { + if (group.count == 1) { + NSArray *box = group[0]; + float margin = addMargin * MIN([[box objectAtIndex:1] floatValue] - [[box objectAtIndex:0] floatValue], [[box objectAtIndex:5] floatValue]); + [mergedList addObject:@[@([[box objectAtIndex:0] floatValue] - margin), @([[box objectAtIndex:1] floatValue] + margin), @([[box objectAtIndex:2] floatValue] - margin), @([[box objectAtIndex:3] floatValue] + margin)]]; } else { - NSArray *sortedBoxes = [boxes sortedArrayUsingComparator:^NSComparisonResult(NSArray *obj1, NSArray *obj2) { + NSArray *sortedBoxes = [group sortedArrayUsingComparator:^NSComparisonResult(NSArray *obj1, NSArray *obj2) { return [@([obj1[0] intValue]) compare:@([obj2[0] intValue])]; }]; @@ -238,10 +217,22 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold @([box[3] intValue] + margin)]]; } } + } } - return mergedList; + return [mergedList copy]; +} + ++ (NSArray *> 
*)groupTextBox:(NSArray *> *)polys + ycenterThs:(CGFloat)ycenterThs + heightThs:(CGFloat)heightThs + widthThs:(CGFloat)widthThs + addMargin:(CGFloat)addMargin +{ + NSMutableArray *horizontalList = [self prepareBoxesFromPolys:polys]; + NSMutableArray *combinedList = [self combineBoxes:horizontalList withYCenterThs:ycenterThs heightThs:heightThs]; + return [self mergeBoxes:combinedList withWidthThs:widthThs heightThs:heightThs addMargin:addMargin]; } @end diff --git a/ios/RnExecutorch/utils/ImageProcessor.h b/ios/RnExecutorch/utils/ImageProcessor.h index e2c6d34651..c65182d0a6 100644 --- a/ios/RnExecutorch/utils/ImageProcessor.h +++ b/ios/RnExecutorch/utils/ImageProcessor.h @@ -3,9 +3,13 @@ @interface ImageProcessor : NSObject ++ (NSArray *)matToNSArray:(const cv::Mat &)mat + mean:(cv::Scalar)mean + variance:(cv::Scalar)variance; + (NSArray *)matToNSArray:(const cv::Mat &)mat; + (cv::Mat)arrayToMat:(NSArray *)array width:(int)width height:(int)height; -+ (NSArray *)matToArrayForGrayscale:(const cv::Mat &)mat; ++ (cv::Mat)arrayToMatGray:(NSArray *)array width:(int)width height:(int)height; ++ (NSArray *)matToNSArrayGray:(const cv::Mat &)mat; + (NSString *)saveToTempFile:(const cv::Mat &)image; + (cv::Mat)readImage:(NSString *)source; diff --git a/ios/RnExecutorch/utils/ImageProcessor.mm b/ios/RnExecutorch/utils/ImageProcessor.mm index 4b932663a5..a8617c262f 100644 --- a/ios/RnExecutorch/utils/ImageProcessor.mm +++ b/ios/RnExecutorch/utils/ImageProcessor.mm @@ -4,6 +4,12 @@ @implementation ImageProcessor + (NSArray *)matToNSArray:(const cv::Mat &)mat { + return [ImageProcessor matToNSArray:mat mean:cv::Scalar(0.0, 0.0, 0.0) variance:cv::Scalar(1.0, 1.0, 1.0)]; +} + ++ (NSArray *)matToNSArray:(const cv::Mat &)mat + mean:(cv::Scalar)mean + variance:(cv::Scalar)variance { int pixelCount = mat.cols * mat.rows; NSMutableArray *floatArray = [[NSMutableArray alloc] initWithCapacity:pixelCount * 3]; for (NSUInteger k = 0; k < pixelCount * 3; k++) { @@ -14,32 +20,25 @@ + 
(NSArray *)matToNSArray:(const cv::Mat &)mat { int row = i / mat.cols; int col = i % mat.cols; cv::Vec3b pixel = mat.at(row, col); - floatArray[0 * pixelCount + i] = @(pixel[2] / 255.0f); - floatArray[1 * pixelCount + i] = @(pixel[1] / 255.0f); - floatArray[2 * pixelCount + i] = @(pixel[0] / 255.0f); + floatArray[0 * pixelCount + i] = @((pixel[0] - mean[0] * 255.0) / (variance[0] * 255.0)); + floatArray[1 * pixelCount + i] = @((pixel[1] - mean[1] * 255.0) / (variance[1] * 255.0)); + floatArray[2 * pixelCount + i] = @((pixel[2] - mean[2] * 255.0) / (variance[2] * 255.0)); } return floatArray; } -+ (NSArray *)matToArrayForGrayscale:(const cv::Mat &)mat { - if (mat.empty() || mat.type() != CV_32F) { - NSLog(@"Invalid or empty matrix or matrix not of type CV_32F."); - return @[]; - } - - NSMutableArray *pixelArray = [[NSMutableArray alloc] initWithCapacity:mat.cols * mat.rows]; - - // Iterate through every pixel in the matrix - for (int row = 0; row < mat.rows; row++) { - for (int col = 0; col < mat.cols; col++) { - // Access and add the pixel value directly as a float, store as NSNumber - float pixelValue = mat.at(row, col); - [pixelArray addObject:@(pixelValue)]; - } ++ (NSArray *)matToNSArrayGray:(const cv::Mat &)mat { + NSMutableArray *pixelArray = [[NSMutableArray alloc] initWithCapacity:mat.cols * mat.rows]; + + for (int row = 0; row < mat.rows; row++) { + for (int col = 0; col < mat.cols; col++) { + float pixelValue = mat.at(row, col); + [pixelArray addObject:@(pixelValue)]; } - - return pixelArray; + } + + return pixelArray; } + (cv::Mat)arrayToMat:(NSArray *)array width:(int)width height:(int)height { @@ -62,6 +61,20 @@ + (NSArray *)matToArrayForGrayscale:(const cv::Mat &)mat { return mat; } ++ (cv::Mat)arrayToMatGray:(NSArray *)array width:(int)width height:(int)height { + cv::Mat mat(height, width, CV_32F); + + int pixelCount = width * height; + for (int i = 0; i < pixelCount; i++) { + int row = i / width; + int col = i % width; + float value = [array[i] 
floatValue]; + mat.at(row, col) = value; + } + + return mat; +} + + (NSString *)saveToTempFile:(const cv::Mat&)image { NSString *uniqueID = [[NSUUID UUID] UUIDString]; NSString *filename = [NSString stringWithFormat:@"rn_executorch_%@.png", uniqueID]; @@ -85,9 +98,9 @@ + (NSString *)saveToTempFile:(const cv::Mat&)image { //base64 NSArray *parts = [source componentsSeparatedByString:@","]; if ([parts count] < 2) { - @throw [NSException exceptionWithName:@"readImage_error" - reason:[NSString stringWithFormat:@"%ld", (long)InvalidArgument] - userInfo:nil]; + @throw [NSException exceptionWithName:@"readImage_error" + reason:[NSString stringWithFormat:@"%ld", (long)InvalidArgument] + userInfo:nil]; } NSString *encodedString = parts[1]; NSData *data = [[NSData alloc] initWithBase64EncodedString:encodedString options:NSDataBase64DecodingIgnoreUnknownCharacters]; From b7cab0a932020e8df9453ad47deb3b676c6ab64d Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 27 Jan 2025 09:22:25 +0100 Subject: [PATCH 09/19] fix: add missing argument in header file --- ios/RnExecutorch/models/ocr/utils/DetectorUtils.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h index 272a43f1a4..afddf2ffa9 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h @@ -7,7 +7,7 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold:(double)textThreshold linkThreshold:(double)linkThreshold lowText:(double)lowText; + (NSMutableArray *)prepareBoxesFromPolys:(NSArray *)polys; + (NSMutableArray *)combineBoxes:(NSMutableArray *)boxes withYCenterThs:(CGFloat)ycenterThs heightThs:(CGFloat)heightThs; -+ (NSArray *)mergeBoxes:(NSMutableArray *)combinedBoxes withWidthThs:(CGFloat)widthThs addMargin:(CGFloat)addMargin; ++ (NSArray *)mergeBoxes:(NSMutableArray *)combinedBoxes 
withWidthThs:(CGFloat)widthThs heightThs:(CGFloat)heightThs addMargin:(CGFloat)addMargin; + (NSArray *> *)groupTextBox:(NSArray *> *)polys ycenterThs:(CGFloat)ycenterThs heightThs:(CGFloat)heightThs From ad813938695e4b440b861ff1f49cd9339de13963 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 27 Jan 2025 11:32:34 +0100 Subject: [PATCH 10/19] feat: automatically load list of words and symbols for converter --- .../computer-vision/screens/OCRScreen.tsx | 1 + ios/RnExecutorch/OCR.mm | 31 ++++++++++++------- .../models/ocr/RecognitionHandler.h | 1 + .../models/ocr/RecognitionHandler.mm | 6 ++-- src/Error.ts | 1 + src/OCR.ts | 24 ++++++++++---- src/constants/ocr/languageDicts.ts | 3 ++ src/constants/ocr/symbols.ts | 4 +++ src/native/NativeOCR.ts | 3 +- 9 files changed, 52 insertions(+), 22 deletions(-) create mode 100644 src/constants/ocr/languageDicts.ts create mode 100644 src/constants/ocr/symbols.ts diff --git a/examples/computer-vision/screens/OCRScreen.tsx b/examples/computer-vision/screens/OCRScreen.tsx index 0b097293fd..5562ddb9fa 100644 --- a/examples/computer-vision/screens/OCRScreen.tsx +++ b/examples/computer-vision/screens/OCRScreen.tsx @@ -27,6 +27,7 @@ export const OCRScreen = ({ recognizerMedium: require('../assets/models/xnnpack_crnn_256.pte'), recognizerSmall: require('../assets/models/xnnpack_crnn_128.pte'), }, + language: 'en', }); const handleCameraPress = async (isCamera: boolean) => { diff --git a/ios/RnExecutorch/OCR.mm b/ios/RnExecutorch/OCR.mm index 56ee04c0e1..d7237242f8 100644 --- a/ios/RnExecutorch/OCR.mm +++ b/ios/RnExecutorch/OCR.mm @@ -1,6 +1,7 @@ #import #import #import "OCR.h" +#import "utils/Fetcher.h" #import "utils/ImageProcessor.h" #import "models/ocr/Detector.h" #import "models/ocr/RecognitionHandler.h" @@ -16,12 +17,11 @@ - (void)loadModule:(NSString *)detectorSource recognizerSourceLarge:(NSString *)recognizerSourceLarge recognizerSourceMedium:(NSString *)recognizerSourceMedium recognizerSourceSmall:(NSString 
*)recognizerSourceSmall - language:(NSString *)language + symbols:(NSString *)symbols + languageDictPath:(NSString *)languageDictPath resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject { detector = [[Detector alloc] init]; - recognitionHandler = [[RecognitionHandler alloc] init]; - [detector loadModel:[NSURL URLWithString:detectorSource] completion:^(BOOL success, NSNumber *errorCode) { if (!success) { NSError *error = [NSError errorWithDomain:@"OCRErrorDomain" @@ -30,16 +30,23 @@ - (void)loadModule:(NSString *)detectorSource reject(@"init_module_error", @"Failed to initialize detector module", error); return; } - - [self->recognitionHandler loadRecognizers:recognizerSourceLarge mediumRecognizerPath:recognizerSourceMedium smallRecognizerPath:recognizerSourceSmall completion:^(BOOL allModelsLoaded, NSNumber *errorCode) { - if (allModelsLoaded) { - resolve(@(YES)); - } else { - NSError *error = [NSError errorWithDomain:@"OCRErrorDomain" - code:[errorCode intValue] - userInfo:@{NSLocalizedDescriptionKey: [NSString stringWithFormat:@"%ld", (long)[errorCode longValue]]}]; - reject(@"init_recognizer_error", @"Failed to initialize one or more recognizer models", error); + [Fetcher fetchResource:[NSURL URLWithString:languageDictPath] resourceType:ResourceType::TXT completionHandler:^(NSString *filePath, NSError *error) { + if (error) { + reject(@"init_module_error", @"Failed to initialize converter module", error); + return; } + + self->recognitionHandler = [[RecognitionHandler alloc] initWithSymbols:symbols languageDictPath:filePath]; + [self->recognitionHandler loadRecognizers:recognizerSourceLarge mediumRecognizerPath:recognizerSourceMedium smallRecognizerPath:recognizerSourceSmall completion:^(BOOL allModelsLoaded, NSNumber *errorCode) { + if (allModelsLoaded) { + resolve(@(YES)); + } else { + NSError *error = [NSError errorWithDomain:@"OCRErrorDomain" + code:[errorCode intValue] + userInfo:@{NSLocalizedDescriptionKey: [NSString 
stringWithFormat:@"%ld", (long)[errorCode longValue]]}]; + reject(@"init_recognizer_error", @"Failed to initialize one or more recognizer models", error); + } + }]; }]; }]; } diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.h b/ios/RnExecutorch/models/ocr/RecognitionHandler.h index b3d367c9f6..93dbebd82e 100644 --- a/ios/RnExecutorch/models/ocr/RecognitionHandler.h +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.h @@ -7,6 +7,7 @@ const int smallModelWidth = 128; @interface RecognitionHandler : NSObject +- (instancetype)initWithSymbols:(NSString *)symbols languageDictPath:(NSString *)languageDictPath; - (void)loadRecognizers:(NSString *)largeRecognizerPath mediumRecognizerPath:(NSString *)mediumRecognizerPath smallRecognizerPath:(NSString *)smallRecognizerPath completion:(void (^)(BOOL, NSNumber *))completion; - (NSArray *)recognize:(NSArray *)horizontalList imgGray:(cv::Mat)imgGray desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight; diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm index 0fdd17180a..aca54e91e8 100644 --- a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm @@ -20,14 +20,14 @@ @implementation RecognitionHandler { CTCLabelConverter *converter; } -- (instancetype)init { +- (instancetype)initWithSymbols:(NSString *)symbols languageDictPath:(NSString *)languageDictPath { self = [super init]; if (self) { recognizerLarge = [[Recognizer alloc] init]; recognizerMedium = [[Recognizer alloc] init]; recognizerSmall = [[Recognizer alloc] init]; - NSString *dictPath = [[NSBundle mainBundle] pathForResource:@"en" ofType:@"txt"]; - converter = [[CTCLabelConverter alloc] initWithCharacters:@"0123456789!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~ €ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz" separatorList:@{} dictPathList:@{@"en": dictPath}]; + + converter = [[CTCLabelConverter alloc] initWithCharacters:symbols 
separatorList:@{} dictPathList:@{@"key": languageDictPath}]; } return self; } diff --git a/src/Error.ts b/src/Error.ts index 767856393c..955b62a95e 100644 --- a/src/Error.ts +++ b/src/Error.ts @@ -4,6 +4,7 @@ export enum ETError { ModuleNotLoaded = 0x66, FileWriteFailed = 0x67, ModelGenerating = 0x68, + LanguageNotSupported = 0x69, InvalidModelSource = 0xff, // ExecuTorch mapped errors diff --git a/src/OCR.ts b/src/OCR.ts index f76c015af9..ed18831f46 100644 --- a/src/OCR.ts +++ b/src/OCR.ts @@ -4,6 +4,8 @@ import { OCR } from './native/RnExecutorchModules'; import { ETError, getError } from './Error'; import { Image } from 'react-native'; import { OCRDetection } from './types/ocr'; +import { symbols } from './constants/ocr/symbols'; +import { languageDicts } from './constants/ocr/languageDicts'; interface OCRModule { error: string | null; @@ -12,7 +14,7 @@ interface OCRModule { forward: (input: string) => Promise; } -const getModelPath = (source: ResourceSource) => { +const getResourcePath = (source: ResourceSource) => { if (typeof source === 'number') { return Image.resolveAssetSource(source).uri; } @@ -41,20 +43,29 @@ export const useOCR = ({ if (!detectorSource || Object.keys(recognizerSources).length === 0) return; - const detectorPath = getModelPath(detectorSource); + const detectorPath = getResourcePath(detectorSource); const recognizerPaths = {} as { recognizerLarge: string; recognizerMedium: string; recognizerSmall: string; }; + if (!symbols[language] || !languageDicts[language]) { + setError(getError(ETError.LanguageNotSupported)); + return; + } + for (const key in recognizerSources) { if (recognizerSources.hasOwnProperty(key)) { - recognizerPaths[key as keyof typeof recognizerPaths] = getModelPath( - recognizerSources[key as keyof typeof recognizerSources] - ); + recognizerPaths[key as keyof typeof recognizerPaths] = + getResourcePath( + recognizerSources[key as keyof typeof recognizerSources] + ); } } + + const languageDictPath = 
getResourcePath(languageDicts[language]); + try { setIsReady(false); await OCR.loadModule( @@ -62,7 +73,8 @@ export const useOCR = ({ recognizerPaths.recognizerLarge, recognizerPaths.recognizerMedium, recognizerPaths.recognizerSmall, - language + symbols.default + symbols[language]!, + languageDictPath ); setIsReady(true); } catch (e) { diff --git a/src/constants/ocr/languageDicts.ts b/src/constants/ocr/languageDicts.ts new file mode 100644 index 0000000000..7eaa1f146b --- /dev/null +++ b/src/constants/ocr/languageDicts.ts @@ -0,0 +1,3 @@ +export const languageDicts: { [key: string]: string } = { + en: 'https://huggingface.co/nklockiewicz/ocr/resolve/main/en.txt', +}; diff --git a/src/constants/ocr/symbols.ts b/src/constants/ocr/symbols.ts new file mode 100644 index 0000000000..4ead5265f6 --- /dev/null +++ b/src/constants/ocr/symbols.ts @@ -0,0 +1,4 @@ +export const symbols: { [key: string]: string } = { + default: '0123456789!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ €', + en: 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz', +}; diff --git a/src/native/NativeOCR.ts b/src/native/NativeOCR.ts index 16ca3c0b7f..305bf01273 100644 --- a/src/native/NativeOCR.ts +++ b/src/native/NativeOCR.ts @@ -8,7 +8,8 @@ export interface Spec extends TurboModule { recognizerSourceLarge: string, recognizerSourceMedium: string, recognizerSourceSmall: string, - language: string + symbols: string, + languageDictPath: string ): Promise; forward(input: string): Promise; } From 7637be2e218f862c60b111c00fcddf0a267fb886 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 27 Jan 2025 15:28:13 +0100 Subject: [PATCH 11/19] feat: add polish language support --- examples/computer-vision/screens/OCRScreen.tsx | 9 +++++---- src/OCR.ts | 2 +- src/constants/ocr/languageDicts.ts | 1 + src/constants/ocr/symbols.ts | 4 ++-- 4 files changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/computer-vision/screens/OCRScreen.tsx b/examples/computer-vision/screens/OCRScreen.tsx index 
5562ddb9fa..ac94fb7c6e 100644 --- a/examples/computer-vision/screens/OCRScreen.tsx +++ b/examples/computer-vision/screens/OCRScreen.tsx @@ -23,11 +23,11 @@ export const OCRScreen = ({ detectorSource: require('../assets/models/xnnpack_craft.pte'), recognizerSources: { recognizerLarge: - 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_512.pte', - recognizerMedium: require('../assets/models/xnnpack_crnn_256.pte'), - recognizerSmall: require('../assets/models/xnnpack_crnn_128.pte'), + 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_latin_512.pte', + recognizerMedium: require('../assets/models/xnnpack_latin_256.pte'), + recognizerSmall: require('../assets/models/xnnpack_latin_128.pte'), }, - language: 'en', + language: 'pl', }); const handleCameraPress = async (isCamera: boolean) => { @@ -44,6 +44,7 @@ export const OCRScreen = ({ }; const runForward = async () => { + console.log('RUnning forward'); try { const output = await model.forward(imageUri); setResults(output); diff --git a/src/OCR.ts b/src/OCR.ts index ed18831f46..17c4aafcdb 100644 --- a/src/OCR.ts +++ b/src/OCR.ts @@ -73,7 +73,7 @@ export const useOCR = ({ recognizerPaths.recognizerLarge, recognizerPaths.recognizerMedium, recognizerPaths.recognizerSmall, - symbols.default + symbols[language]!, + symbols[language], languageDictPath ); setIsReady(true); diff --git a/src/constants/ocr/languageDicts.ts b/src/constants/ocr/languageDicts.ts index 7eaa1f146b..fcd189b53c 100644 --- a/src/constants/ocr/languageDicts.ts +++ b/src/constants/ocr/languageDicts.ts @@ -1,3 +1,4 @@ export const languageDicts: { [key: string]: string } = { en: 'https://huggingface.co/nklockiewicz/ocr/resolve/main/en.txt', + pl: 'https://huggingface.co/nklockiewicz/ocr/resolve/main/pl.txt', }; diff --git a/src/constants/ocr/symbols.ts b/src/constants/ocr/symbols.ts index 4ead5265f6..229c0613d1 100644 --- a/src/constants/ocr/symbols.ts +++ b/src/constants/ocr/symbols.ts @@ -1,4 +1,4 @@ export const symbols: { 
[key: string]: string } = { - default: '0123456789!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ €', - en: 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz', + en: '0123456789!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ €ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz', + pl: ' !"#$%&\'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~ªÀÁÂÃÄÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖØÙÚÛÜÝÞßàáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿĀāĂ㥹ĆćČčĎďĐđĒēĖėĘęĚěĞğĨĩĪīĮįİıĶķĹĺĻļĽľŁłŃńŅņŇňŒœŔŕŘřŚśŞşŠšŤťŨũŪūŮůŲųŸŹźŻżŽžƏƠơƯưȘșȚțə̇ḌḍḶḷṀṁṂṃṄṅṆṇṬṭẠạẢảẤấẦầẨẩẪẫẬậẮắẰằẲẳẴẵẶặẸẹẺẻẼẽẾếỀềỂểỄễỆệỈỉỊịỌọỎỏỐốỒồỔổỖỗỘộỚớỜờỞởỠỡỢợỤụỦủỨứỪừỬửỮữỰựỲỳỴỵỶỷỸỹ€', }; From dfdc8110d035331bc268c26ac67bc57d6a62281f Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Wed, 5 Feb 2025 16:21:10 +0100 Subject: [PATCH 12/19] feat: implemented upgraded mid processing pipeline --- ios/RnExecutorch/models/ocr/Detector.h | 15 +- ios/RnExecutorch/models/ocr/Detector.mm | 33 +- .../models/ocr/RecognitionHandler.mm | 44 +- .../models/ocr/utils/DetectorUtils.h | 12 +- .../models/ocr/utils/DetectorUtils.mm | 484 +++++++++++++----- .../models/ocr/utils/RecognizerUtils.h | 2 +- .../models/ocr/utils/RecognizerUtils.mm | 54 +- 7 files changed, 450 insertions(+), 194 deletions(-) diff --git a/ios/RnExecutorch/models/ocr/Detector.h b/ios/RnExecutorch/models/ocr/Detector.h index 5d7df1b8e1..1e09073b94 100644 --- a/ios/RnExecutorch/models/ocr/Detector.h +++ b/ios/RnExecutorch/models/ocr/Detector.h @@ -1,14 +1,13 @@ #import "opencv2/opencv.hpp" #import "BaseModel.h" -const float textThreshold = 0.7; -const float linkThreshold = 0.4; -const float lowText = 0.4; -const float yCenterThs = 0.5; -const float heightThs = 0.5; -const float widthThs = 0.5; -const float addMargin = 0.1; -const int minSize = 20; +constexpr float textThreshold = 0.7; +constexpr float linkThreshold = 0.4; +constexpr float lowText = 0.4; +constexpr CGFloat centerThreshold = 0.5; +constexpr CGFloat distanceThreshold = 2.0; +constexpr CGFloat heightThreshold = 2.0; 
+constexpr int minSize = 20; const cv::Scalar mean(0.485, 0.456, 0.406); const cv::Scalar variance(0.229, 0.224, 0.225); diff --git a/ios/RnExecutorch/models/ocr/Detector.mm b/ios/RnExecutorch/models/ocr/Detector.mm index e499f78461..69621061b5 100644 --- a/ios/RnExecutorch/models/ocr/Detector.mm +++ b/ios/RnExecutorch/models/ocr/Detector.mm @@ -69,23 +69,36 @@ - (NSArray *)postprocess:(NSArray *)output { NSArray* horizontalList = [DetectorUtils getDetBoxes:scoreTextCV linkMap:scoreLinkCV textThreshold:textThreshold linkThreshold:linkThreshold lowText:lowText]; horizontalList = [DetectorUtils restoreBboxRatio:horizontalList]; - horizontalList = [DetectorUtils groupTextBox:horizontalList ycenterThs:yCenterThs heightThs:heightThs widthThs:widthThs addMargin:addMargin]; + horizontalList = [DetectorUtils groupTextBoxes:horizontalList centerThreshold:centerThreshold distanceThreshold:distanceThreshold heightThreshold:heightThreshold]; - NSMutableArray *boxesToKeep = [NSMutableArray array]; - - for (NSArray *box in horizontalList) { - if (MAX([box[1] intValue] - [box[0] intValue], [box[3] intValue] - [box[2] intValue]) >= minSize) { - [boxesToKeep addObject:box]; - } - } - - return boxesToKeep; + return horizontalList; } - (NSArray *)runModel:(cv::Mat &)input { + NSDate *startTime; + NSDate *endTime; + NSTimeInterval executionTime; + + // Preprocessing + startTime = [NSDate date]; NSArray *modelInput = [self preprocess:input]; + endTime = [NSDate date]; + executionTime = [endTime timeIntervalSinceDate:startTime]; + NSLog(@"Preprocessing time: %f seconds", executionTime); + + // Running the model + startTime = [NSDate date]; NSArray *modelResult = [self forward:modelInput]; + endTime = [NSDate date]; + executionTime = [endTime timeIntervalSinceDate:startTime]; + NSLog(@"Model forwarding time: %f seconds", executionTime); + + // Postprocessing + startTime = [NSDate date]; NSArray *result = [self postprocess:modelResult]; + endTime = [NSDate date]; + executionTime = 
[endTime timeIntervalSinceDate:startTime]; + NSLog(@"Postprocessing time: %f seconds", executionTime); return result; } diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm index aca54e91e8..4979884f78 100644 --- a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm @@ -63,27 +63,12 @@ - (void)loadRecognizers:(NSString *)largeRecognizerPath mediumRecognizerPath:(NS } - (NSArray *)recognize: (NSArray *)horizontalList imgGray:(cv::Mat)imgGray desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight { - NSDictionary* ratioAndPadding = [RecognizerUtils calculateResizeRatioAndPaddings:imgGray.cols height:imgGray.rows desiredWidth:desiredWidth desiredHeight:desiredHeight]; - - int left = [ratioAndPadding[@"left"] intValue]; - int top = [ratioAndPadding[@"top"] intValue]; - float resizeRatio = [ratioAndPadding[@"resizeRatio"] floatValue]; imgGray = [OCRUtils resizeWithPadding:imgGray desiredWidth:desiredWidth desiredHeight:desiredHeight]; - NSMutableArray *predictions = [NSMutableArray array]; - for (NSArray *box in horizontalList) { - int maximum_y = imgGray.rows; - int maximum_x = imgGray.cols; - - int x_min = MAX(0, [box[0] intValue]); - int x_max = MIN([box[1] intValue], maximum_x); - int y_min = MAX(0, [box[2] intValue]); - int y_max = MIN([box[3] intValue], maximum_y); - - cv::Mat croppedImage = [RecognizerUtils getCroppedImage:x_max x_min:x_min y_max:y_max y_min:y_min image:imgGray modelHeight:modelHeight]; - - - croppedImage = [RecognizerUtils normalizeForRecognizer:croppedImage adjustContrast:0.0]; + NSLog(@"%@", horizontalList); + for (NSDictionary *box in horizontalList) { + cv::Mat croppedImage = [RecognizerUtils getCroppedImage:box image:imgGray modelHeight:modelHeight]; + croppedImage = [RecognizerUtils normalizeForRecognizer:croppedImage adjustContrast:0.2]; NSArray *result; if(croppedImage.cols >= largeModelWidth) { result = 
[recognizerLarge runModel:croppedImage]; @@ -94,11 +79,28 @@ - (NSArray *)recognize: (NSArray *)horizontalList imgGray:(cv::Mat)imgGray desir } NSNumber *confidenceScore = [result objectAtIndex:1]; + if([confidenceScore floatValue] < 0.3){ + cv::rotate(croppedImage, croppedImage, cv::ROTATE_180); + } + NSArray *rotatedResult; + if(croppedImage.cols >= largeModelWidth) { + rotatedResult = [recognizerLarge runModel:croppedImage]; + } else if (croppedImage.cols >= mediumModelWidth) { + rotatedResult = [recognizerMedium runModel: croppedImage]; + } else { + rotatedResult = [recognizerSmall runModel: croppedImage]; + } + NSNumber *rotatedConfidenceScore = [rotatedResult objectAtIndex:1]; + + if ([rotatedConfidenceScore floatValue] > [confidenceScore floatValue]) { + result = rotatedResult; + } + NSArray *pred_index = [result objectAtIndex:0]; NSArray* decodedTexts = [converter decodeGreedy:pred_index length:(int)(pred_index.count)]; - - NSDictionary *res = @{@"text": decodedTexts[0], @"bbox": @{@"x1": @((int)((x_min - left) * resizeRatio)), @"x2": @((int)((x_max - left) * resizeRatio)), @"y1": @((int)((y_min - top) * resizeRatio)), @"y2":@((int)((y_max - top) * resizeRatio))}, @"score": confidenceScore}; + + NSDictionary *res = @{@"text": decodedTexts[0], @"bbox": box[@"box"], @"score": confidenceScore}; [predictions addObject:res]; } diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h index afddf2ffa9..cc527d7ad2 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h @@ -5,13 +5,9 @@ + (NSDictionary *)splitInterleavedNSArray:(NSArray *)array; + (NSArray *)restoreBboxRatio:(NSArray *)boxes; + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold:(double)textThreshold linkThreshold:(double)linkThreshold lowText:(double)lowText; -+ (NSMutableArray *)prepareBoxesFromPolys:(NSArray *)polys; -+ (NSMutableArray 
*)combineBoxes:(NSMutableArray *)boxes withYCenterThs:(CGFloat)ycenterThs heightThs:(CGFloat)heightThs; -+ (NSArray *)mergeBoxes:(NSMutableArray *)combinedBoxes withWidthThs:(CGFloat)widthThs heightThs:(CGFloat)heightThs addMargin:(CGFloat)addMargin; -+ (NSArray *> *)groupTextBox:(NSArray *> *)polys - ycenterThs:(CGFloat)ycenterThs - heightThs:(CGFloat)heightThs - widthThs:(CGFloat)widthThs - addMargin:(CGFloat)addMargin; ++ (NSArray *)groupTextBoxes:(NSArray *)polys + centerThreshold:(CGFloat)centerThreshold + distanceThreshold:(CGFloat)distanceThreshold + heightThreshold:(CGFloat)heightThreshold; @end diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm index 51a5587419..ec089190c6 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm @@ -2,6 +2,179 @@ @implementation DetectorUtils ++ (CGFloat)normalizeAngle:(CGFloat)angle { + /* + Normalize the angle returned by OpenCV's minAreaRect. 
+ */ + if (angle > 45) { + return angle - 90; + } + return angle; +} + ++ (CGFloat)distance:(CGPoint)p1 p2:(CGPoint)p2 { + double xDist = (p2.x - p1.x); + double yDist = (p2.y - p1.y); + return sqrt(xDist * xDist + yDist * yDist); +} + ++ (CGPoint)midpoint:(CGPoint)p1 p2:(CGPoint)p2 { + return CGPointMake((p1.x + p2.x) / 2, (p1.y + p2.y) / 2); +} + ++ (CGPoint)centerOfBox:(NSArray *)box { + return CGPointMake(([box[0] CGPointValue].x + [box[2] CGPointValue].x) / 2, ([box[0] CGPointValue].y + [box[2] CGPointValue].y) / 2); +} + ++ (CGFloat)maxSideLength:(NSArray *)points { + CGFloat maxSideLength = 0; + NSInteger numOfPoints = points.count; + for (NSInteger i = 0; i < numOfPoints; i++) { + CGPoint currentPoint = [points[i] CGPointValue]; + CGPoint nextPoint = [points[(i + 1) % numOfPoints] CGPointValue]; + + CGFloat sideLength = [self distance:currentPoint p2:nextPoint]; + if (sideLength > maxSideLength) { + maxSideLength = sideLength; + } + } + return maxSideLength; +} + ++ (CGFloat)minSideLength:(NSArray *)points { + CGFloat minSideLength = CGFLOAT_MAX; + NSInteger numOfPoints = points.count; + + for (NSInteger i = 0; i < numOfPoints; i++) { + CGPoint currentPoint = [points[i] CGPointValue]; + CGPoint nextPoint = [points[(i + 1) % numOfPoints] CGPointValue]; + + CGFloat sideLength = [self distance:currentPoint p2:nextPoint]; + if (sideLength < minSideLength) { + minSideLength = sideLength; + } + } + + return minSideLength; +} + ++ (NSArray *)orderPointsClockwise:(NSArray *)points{ + CGPoint topLeft, topRight, bottomRight, bottomLeft; + float minSum = FLT_MAX; + float maxSum = -FLT_MAX; + float minDiff = FLT_MAX; + float maxDiff = -FLT_MAX; + + for (NSValue *value in points) { + CGPoint pt = [value CGPointValue]; + float sum = pt.x + pt.y; + float diff = pt.y - pt.x; + + // For top-left and bottom-right determination + if (sum < minSum) { + minSum = sum; + topLeft = pt; + } + if (sum > maxSum) { + maxSum = sum; + bottomRight = pt; + } + + // For top-right and 
bottom-left determination + if (diff < minDiff) { + minDiff = diff; + topRight = pt; + } + if (diff > maxDiff) { + maxDiff = diff; + bottomLeft = pt; + } + } + + NSArray *rect = @[[NSValue valueWithCGPoint:topLeft], + [NSValue valueWithCGPoint:topRight], + [NSValue valueWithCGPoint:bottomRight], + [NSValue valueWithCGPoint:bottomLeft]]; + + return rect; +} + ++ (NSArray *)rotateBox:(NSArray *)box withAngle:(CGFloat)angle { + // Calculate the center of the rectangle + CGPoint center = [self centerOfBox:box]; + + // Convert angle from degrees to radians + CGFloat radians = angle * M_PI / 180.0; + + // Prepare an array to hold the rotated points + NSMutableArray *rotatedPoints = [NSMutableArray arrayWithCapacity:4]; + for (NSValue *value in box) { + CGPoint point = [value CGPointValue]; + + // Translate point to origin + CGFloat translatedX = point.x - center.x; + CGFloat translatedY = point.y - center.y; + + // Rotate point + CGFloat rotatedX = translatedX * cos(radians) - translatedY * sin(radians); + CGFloat rotatedY = translatedX * sin(radians) + translatedY * cos(radians); + + // Translate point back + CGPoint rotatedPoint = CGPointMake(rotatedX + center.x, rotatedY + center.y); + [rotatedPoints addObject:[NSValue valueWithCGPoint:rotatedPoint]]; + } + + return rotatedPoints; +} + ++ (std::vector)pointsFromNSValues:(NSArray *)nsValues { + std::vector points; + for (NSValue *value in nsValues) { + CGPoint point = [value CGPointValue]; + points.emplace_back(point.x, point.y); + } + return points; +} + ++ (NSArray *)nsValuesFromPoints:(cv::Point2f *)points count:(int)count { + NSMutableArray *nsValues = [[NSMutableArray alloc] initWithCapacity:count]; + for (int i = 0; i < count; i++) { + [nsValues addObject:[NSValue valueWithCGPoint:CGPointMake(points[i].x, points[i].y)]]; + } + return nsValues; +} + ++ (NSArray *)mergeRotatedBoxes:(NSArray *)box1 withBox:(NSArray *)box2 { + box1 = [self orderPointsClockwise:box1]; + box2 = [self orderPointsClockwise:box2]; + + 
std::vector points1 = [self pointsFromNSValues:box1]; + std::vector points2 = [self pointsFromNSValues:box2]; + + // Collect all points from both rectangles + std::vector allPoints; + allPoints.insert(allPoints.end(), points1.begin(), points1.end()); + allPoints.insert(allPoints.end(), points2.begin(), points2.end()); + + // Calculate the convex hull of all points + std::vector hullIndices; + cv::convexHull(allPoints, hullIndices, false); + + std::vector hullPoints; + for (int idx : hullIndices) { + hullPoints.push_back(allPoints[idx]); + } + + // Get the minimum area rectangle that bounds the convex hull + cv::RotatedRect minAreaRect = cv::minAreaRect(hullPoints); + + cv::Point2f rectPoints[4]; + minAreaRect.points(rectPoints); + + // Convert rotated rectangle points back to NSArray + return [self nsValuesFromPoints:rectPoints count:4]; +} + + (NSDictionary *)splitInterleavedNSArray:(NSArray *)array { NSMutableArray *scoreText = [[NSMutableArray alloc] init]; NSMutableArray *scoreLink = [[NSMutableArray alloc] init]; @@ -17,32 +190,133 @@ + (NSDictionary *)splitInterleavedNSArray:(NSArray *)array { return @{@"ScoreText": scoreText, @"ScoreLink": scoreLink}; } ++ (CGFloat)calculateMinimalDistance:(NSArray *)corners1 corners2:(NSArray *)corners2 { + CGFloat minDistance = CGFLOAT_MAX; + for (NSValue *value1 in corners1) { + CGPoint corner1 = [value1 CGPointValue]; + for (NSValue *value2 in corners2) { + CGPoint corner2 = [value2 CGPointValue]; + CGFloat distance = [self distance:corner1 p2:corner2]; + if (distance < minDistance) { + minDistance = distance; + } + } + } + return minDistance; +} + ++ (NSDictionary *)fitLineToShortestSides:(NSArray *)points { + // Calculate distances and find midpoints + NSMutableArray *sides = [NSMutableArray array]; + NSMutableArray *midpoints = [NSMutableArray array]; + + for (int i = 0; i < 4; i++) { + CGPoint p1 = [points[i] CGPointValue]; + CGPoint p2 = [points[(i + 1) % 4] CGPointValue]; + + CGFloat sideLength = [self distance:p1 
p2:p2]; + [sides addObject:@{@"length": @(sideLength), @"index": @(i)}]; + [midpoints addObject:[NSValue valueWithCGPoint:[self midpoint:p1 p2:p2]]]; + } + + // Sort indices by distances + [sides sortUsingDescriptors:@[[NSSortDescriptor sortDescriptorWithKey:@"length" ascending:YES]]]; + + CGPoint midpoint1 = [midpoints[[sides[0][@"index"] intValue]] CGPointValue]; + CGPoint midpoint2 = [midpoints[[sides[1][@"index"] intValue]] CGPointValue]; + CGFloat dx = fabs(midpoint2.x - midpoint1.x); + + float m, c; + BOOL isVertical; + + std::vector cvMidPoints = {cv::Point2f(midpoint1.x, midpoint1.y), cv::Point2f(midpoint2.x, midpoint2.y)}; + cv::Vec4f line; + + if (dx < 20) { + // If almost vertical, fit x = my + c + for (auto &pt : cvMidPoints) std::swap(pt.x, pt.y); + cv::fitLine(cvMidPoints, line, cv::DIST_L2, 0, 0.01, 0.01); + m = line[1] / line[0]; + c = line[3] - m * line[2]; + isVertical = YES; + } else { + // Fit y = mx + c + cv::fitLine(cvMidPoints, line, cv::DIST_L2, 0, 0.01, 0.01); + m = line[1] / line[0]; + c = line[3] - m * line[2]; + isVertical = NO; + } + + return @{@"slope": @(m), @"intercept": @(c), @"isVertical": @(isVertical)}; +} + ++ (NSDictionary *)findClosestBox:(NSArray *)polys + ignoredIdxs:(NSSet *)ignoredIdxs + currentBox:(NSArray *)currentBox + isVertical:(BOOL)isVertical + m:(CGFloat)m + c:(CGFloat)c + centerThreshold:(CGFloat)centerThreshold +{ + CGFloat smallestDistance = CGFLOAT_MAX; + NSDictionary *boxToMerge = nil; + NSInteger idx = -1; + CGFloat boxHeight = 0; + CGPoint centerOfCurrentBox = [self centerOfBox:currentBox]; + + for (NSUInteger i = 0; i < polys.count; i++) { + if ([ignoredIdxs containsObject:@(i)]) { + continue; + } + NSArray *coords = polys[i][@"box"]; + CGFloat angle = [polys[i][@"angle"] doubleValue]; + CGPoint centerOfProcessedBox = [self centerOfBox:coords]; + CGFloat distanceBetweenCenters = [self distance:centerOfCurrentBox p2:centerOfProcessedBox]; + + if (distanceBetweenCenters >= smallestDistance) { + continue; + } 
+ + boxHeight = [self minSideLength:coords]; + + CGFloat lineDistance = (isVertical ? + fabs(centerOfProcessedBox.x - (m * centerOfProcessedBox.y + c)) : + fabs(centerOfProcessedBox.y - (m * centerOfProcessedBox.x + c))); + + if (lineDistance < boxHeight * centerThreshold) { + boxToMerge = @{@"coords": coords, @"angle": @(angle)}; + idx = i; + smallestDistance = distanceBetweenCenters; + } + } + + return boxToMerge ? @{@"boxToMerge": boxToMerge, @"idx": @(idx), @"boxHeight": @(boxHeight)} : nil; +} + + (NSArray *)restoreBboxRatio:(NSArray *)boxes { NSMutableArray *result = [NSMutableArray array]; for (NSUInteger i = 0; i < [boxes count]; i++) { - NSArray *box = boxes[i]; + NSDictionary *box = boxes[i]; NSMutableArray *boxArray = [NSMutableArray arrayWithCapacity:4]; - for (NSValue *value in box) { + for (NSValue *value in box[@"box"]) { CGPoint point = [value CGPointValue]; - point.x *= 2; - point.y *= 2; - [boxArray addObject:@((int)point.x)]; - [boxArray addObject:@((int)point.y)]; + point.x *= 2 * 1.6; + point.y *= 2 * 1.6; + [boxArray addObject:[NSValue valueWithCGPoint:point]]; } - [result addObject:boxArray]; + NSDictionary *dict = @{@"box": boxArray, @"angle": box[@"angle"]}; + [result addObject:dict]; } return result; } + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold:(double)textThreshold linkThreshold:(double)linkThreshold lowText:(double)lowText { - cv::Mat textmapCopy = textmap.clone(); - cv::Mat linkmapCopy = linkmap.clone(); int img_h = textmap.rows; int img_w = textmap.cols; cv::Mat textScore, linkScore; - cv::threshold(textmapCopy, textScore, lowText, 1, 0); - cv::threshold(linkmapCopy, linkScore, linkThreshold, 1, 0); + cv::threshold(textmap, textScore, lowText, 1, 0); + cv::threshold(linkmap, linkScore, linkThreshold, 1, 0); cv::Mat textScoreComb = textScore + linkScore; cv::threshold(textScoreComb, textScoreComb, 0, 1, cv::THRESH_BINARY); cv::Mat binaryMat; @@ -58,7 +332,7 @@ + (NSArray 
*)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold cv::Mat mask = (labels == i); double maxVal; - cv::minMaxLoc(textmapCopy, NULL, &maxVal, NULL, NULL, mask); + cv::minMaxLoc(textmap, NULL, &maxVal, NULL, NULL, mask); if (maxVal < textThreshold) continue; cv::Mat segMap = cv::Mat::zeros(textmap.size(), CV_8U); @@ -83,6 +357,7 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold cv::Mat roiSegMap = segMap(roi); cv::dilate(roiSegMap, roiSegMap, kernel); + // Find minimal area rect std::vector> contours; cv::findContours(segMap, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE); if (!contours.empty()) { @@ -94,145 +369,84 @@ + (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold CGPoint point = CGPointMake(vertices[j].x, vertices[j].y); [pointsArray addObject:[NSValue valueWithCGPoint:point]]; } - [detectedBoxes addObject:pointsArray]; + NSDictionary *dict = @{@"box": pointsArray, @"angle": @(minRect.angle)}; + [detectedBoxes addObject:dict]; } } return detectedBoxes; } -+ (NSMutableArray *)prepareBoxesFromPolys:(NSArray *)polys { - NSMutableArray *boxes = [NSMutableArray array]; - for (NSArray *poly in polys) { - NSArray *xCoords = @[poly[0], poly[2], poly[4], poly[6]]; - NSArray *yCoords = @[poly[1], poly[3], poly[5], poly[7]]; - NSNumber *xMin = [xCoords valueForKeyPath:@"@min.self"]; - NSNumber *xMax = [xCoords valueForKeyPath:@"@max.self"]; - NSNumber *yMin = [yCoords valueForKeyPath:@"@min.self"]; - NSNumber *yMax = [yCoords valueForKeyPath:@"@max.self"]; - [boxes addObject:@[xMin, xMax, yMin, yMax, @(([yMin floatValue] + [yMax floatValue]) / 2.0), @([yMax floatValue] - [yMin floatValue])]]; - } - return boxes; -} - -+ (NSMutableArray *)combineBoxes:(NSMutableArray *)boxes withYCenterThs:(CGFloat)ycenterThs heightThs:(CGFloat)heightThs { - NSMutableArray *combinedBoxes = [NSMutableArray array]; - NSMutableArray *currentGroup = [NSMutableArray array]; - - for (NSArray *box in boxes) { 
- if (currentGroup.count == 0) { - [currentGroup addObject:box]; - } else { - NSArray *lastBox = [currentGroup lastObject]; - BOOL closeYCenter = fabs([[lastBox objectAtIndex:4] floatValue] - [[box objectAtIndex:4] floatValue]) < ycenterThs * [[lastBox objectAtIndex:5] floatValue]; - if (closeYCenter) { - [currentGroup addObject:box]; - } else { - [combinedBoxes addObject:[currentGroup copy]]; - currentGroup = [@[box] mutableCopy]; - } ++ (CGFloat)minimumYFromBox:(NSArray *)box { + __block CGFloat minY = CGFLOAT_MAX; + [box enumerateObjectsUsingBlock:^(NSValue * _Nonnull obj, NSUInteger idx, BOOL * _Nonnull stop) { + CGPoint pt = [obj CGPointValue]; + if (pt.y < minY) { + minY = pt.y; } - } - if (currentGroup.count > 0) { - [combinedBoxes addObject:[currentGroup copy]]; - } - return combinedBoxes; + }]; + return minY; } -+ (NSArray *)mergeBoxes:(NSMutableArray *)combinedBoxes withWidthThs:(CGFloat)widthThs heightThs:(CGFloat)heightThs addMargin:(CGFloat)addMargin { - NSMutableArray *mergedList = [NSMutableArray array]; ++ (NSArray *)groupTextBoxes:(NSArray *)polys + centerThreshold:(CGFloat)centerThreshold + distanceThreshold:(CGFloat)distanceThreshold + heightThreshold:(CGFloat)heightThreshold +{ + // Sort polys by max side length in descending order + NSMutableArray *sortedPolys = [polys sortedArrayUsingComparator:^NSComparisonResult(NSDictionary *obj1, NSDictionary *obj2) { + CGFloat maxLen1 = [self maxSideLength:obj1[@"box"]]; + CGFloat maxLen2 = [self maxSideLength:obj2[@"box"]]; + return (maxLen1 < maxLen2) ? NSOrderedDescending : (maxLen1 > maxLen2) ? 
NSOrderedAscending : NSOrderedSame; + }].mutableCopy; - for (NSArray *group in combinedBoxes) { - if (group.count == 1) { - NSArray *box = group[0]; - float margin = addMargin * MIN([[box objectAtIndex:1] floatValue] - [[box objectAtIndex:0] floatValue], [[box objectAtIndex:5] floatValue]); - [mergedList addObject:@[@([[box objectAtIndex:0] floatValue] - margin), @([[box objectAtIndex:1] floatValue] + margin), @([[box objectAtIndex:2] floatValue] - margin), @([[box objectAtIndex:3] floatValue] + margin)]]; - } else { - NSArray *sortedBoxes = [group sortedArrayUsingComparator:^NSComparisonResult(NSArray *obj1, NSArray *obj2) { - return [@([obj1[0] intValue]) compare:@([obj2[0] intValue])]; - }]; - - NSMutableArray *mergedBox = [NSMutableArray array]; - NSMutableArray *newBox = [NSMutableArray array]; - int xMax = 0; - NSMutableArray *bHeight = [NSMutableArray array]; - - for (NSArray *box in sortedBoxes) { - if ([newBox count] == 0) { - [bHeight addObject:box[5]]; - xMax = [box[1] intValue]; - [newBox addObject:box]; - } else { - int currHeight = [box[5] intValue]; - float meanHeight = [[bHeight valueForKeyPath:@"@avg.self"] floatValue]; - if (fabs(meanHeight - currHeight) < heightThs * meanHeight && - ([box[0] intValue] - xMax) < widthThs * ([box[3] intValue] - [box[2] intValue])) { - [bHeight addObject:box[5]]; - xMax = [box[1] intValue]; - [newBox addObject:box]; - } else { - [mergedBox addObject:[newBox copy]]; - newBox = [@[box] mutableCopy]; - bHeight = [@[box[5]] mutableCopy]; - xMax = [box[1] intValue]; - } - } - } - if ([newBox count] > 0) { - [mergedBox addObject:newBox]; + NSMutableArray *mergedList = [NSMutableArray array]; + CGFloat angleDegrees; + while (sortedPolys.count > 0) { + NSMutableDictionary *currentBox = [sortedPolys[0] mutableCopy]; + [sortedPolys removeObjectAtIndex:0]; + CGFloat currentAngle = [self normalizeAngle:[currentBox[@"angle"] floatValue]]; + NSMutableArray *ignoredIdxs = [NSMutableArray array]; + + while (YES) { + NSDictionary 
*lineFit = [self fitLineToShortestSides:currentBox[@"box"]]; + NSLog(@"lineFit: %@", lineFit); + angleDegrees = atan([lineFit[@"slope"] floatValue]) * 180 / M_PI; + if ([lineFit[@"isVertical"] boolValue]){ + angleDegrees = -90; } + CGFloat mergedHeight = [self minSideLength:currentBox[@"box"]]; + NSDictionary *closestBoxInfo = [self findClosestBox:sortedPolys ignoredIdxs:[NSSet setWithArray:ignoredIdxs] currentBox:currentBox[@"box"] isVertical:[lineFit[@"isVertical"] boolValue] m:[lineFit[@"slope"] floatValue] c:[lineFit[@"intercept"] floatValue] centerThreshold:centerThreshold]; + if (closestBoxInfo == nil) break; - for (NSArray *mbox in mergedBox) { - if ([mbox count] != 1) { - NSNumber *xMin = [mbox[0] objectAtIndex:0]; - NSNumber *xMax = [mbox[0] objectAtIndex:1]; - NSNumber *yMin = [mbox[0] objectAtIndex:2]; - NSNumber *yMax = [mbox[0] objectAtIndex:3]; - for (NSArray *box in mbox) { - if ([box[0] intValue] < [xMin intValue]) { - xMin = box[0]; - } - if([box[1] intValue] > [xMax intValue]) { - xMax = box[1]; - } - if ([box[2] intValue] < [yMin intValue]) { - yMin = box[2]; - } - if ([box[3] intValue] > [yMax intValue]) { - yMax = box[3]; - } - } - - int margin = (int)(addMargin * MIN([xMax floatValue] - [xMin floatValue], [yMax floatValue] - [yMin floatValue])); - [mergedList addObject:@[@([xMin intValue] - margin), - @([xMax intValue] + margin), - @([yMin intValue] - margin), - @([yMax intValue] + margin)]]; - } else { - NSArray *box = mbox[0]; - int margin = (int)(addMargin * MIN([box[1] floatValue] - [box[0] floatValue], [box[3] floatValue] - [box[2] floatValue])); - [mergedList addObject:@[@([box[0] intValue] - margin), - @([box[1] intValue] + margin), - @([box[2] intValue] - margin), - @([box[3] intValue] + margin)]]; - } + NSMutableDictionary *candidateBox = [closestBoxInfo[@"boxToMerge"] mutableCopy]; + NSInteger candidateIdx = [closestBoxInfo[@"idx"] integerValue]; + CGFloat candidateHeight = [closestBoxInfo[@"boxHeight"] floatValue]; + if 
(([candidateBox[@"angle"] isEqual: @90] && ![lineFit[@"isVertical"] boolValue]) || ([candidateBox[@"angle"] isEqual: @0] && [lineFit[@"isVertical"] boolValue])) { + candidateBox[@"coords"] = [self rotateBox:candidateBox[@"coords"] withAngle:currentAngle]; + } + CGFloat minDistance = [self calculateMinimalDistance:candidateBox[@"coords"] corners2:currentBox[@"box"]]; + if (minDistance < distanceThreshold * candidateHeight && fabs(mergedHeight - candidateHeight) < mergedHeight * heightThreshold) { + currentBox[@"box"] = [self mergeRotatedBoxes:currentBox[@"box"] withBox:candidateBox[@"coords"]]; + [sortedPolys removeObjectAtIndex:candidateIdx]; + [ignoredIdxs removeAllObjects]; // Restart with new merged box + } else { + [ignoredIdxs addObject:@(candidateIdx)]; } - } + [mergedList addObject:@{@"box" : currentBox[@"box"], @"angle" : @(angleDegrees)}]; } - return [mergedList copy]; -} - -+ (NSArray *> *)groupTextBox:(NSArray *> *)polys - ycenterThs:(CGFloat)ycenterThs - heightThs:(CGFloat)heightThs - widthThs:(CGFloat)widthThs - addMargin:(CGFloat)addMargin -{ - NSMutableArray *horizontalList = [self prepareBoxesFromPolys:polys]; - NSMutableArray *combinedList = [self combineBoxes:horizontalList withYCenterThs:ycenterThs heightThs:heightThs]; - return [self mergeBoxes:combinedList withWidthThs:widthThs heightThs:heightThs addMargin:addMargin]; + // Optionally sort by angle if needed + NSArray *sortedBoxes = [mergedList sortedArrayUsingComparator:^NSComparisonResult(NSDictionary *obj1, NSDictionary *obj2) { + NSArray *coords1 = obj1[@"box"]; + NSArray *coords2 = obj2[@"box"]; + CGFloat minY1 = [self minimumYFromBox:coords1]; + CGFloat minY2 = [self minimumYFromBox:coords2]; + return (minY1 < minY2) ? NSOrderedAscending : (minY1 > minY2) ? 
NSOrderedDescending : NSOrderedSame; + }]; + + return sortedBoxes; } @end diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h index d3a9965383..337cdc9f94 100644 --- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h @@ -9,7 +9,7 @@ + (cv::Mat)divideMatrix:(cv::Mat)matrix byVector:(NSArray *)vector; + (cv::Mat)softmax:(cv::Mat)inputs; + (NSDictionary *)calculateResizeRatioAndPaddings:(int)width height:(int)height desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight; -+ (cv::Mat)getCroppedImage:(int)x_max x_min:(int)x_min y_max:(int)y_max y_min:(int)y_min image:(cv::Mat)image modelHeight:(int)modelHeight; ++ (cv::Mat)getCroppedImage:(NSDictionary *)box image:(cv::Mat)image modelHeight:(int)modelHeight; + (NSMutableArray *)sumProbabilityRows:(cv::Mat)probabilities modelOutputHeight:(int)modelOutputHeight; + (NSArray *)findMaxValuesAndIndices:(cv::Mat)probabilities; + (double)computeConfidenceScore:(NSArray *)valuesArray indicesArray:(NSArray *)indicesArray; diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm index 9199a828b4..3e6adccde8 100644 --- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm @@ -115,23 +115,55 @@ + (NSDictionary *)calculateResizeRatioAndPaddings:(int)width height:(int)height }; } -+ (cv::Mat)getCroppedImage:(int)x_max x_min:(int)x_min y_max:(int)y_max y_min:(int)y_min image:(cv::Mat)image modelHeight:(int)modelHeight { - cv::Rect region(x_min, y_min, x_max - x_min, y_max - y_min); - cv::Mat crop_img = image(region); ++ (cv::Mat)getCroppedImage:(NSDictionary *)box image:(cv::Mat)image modelHeight:(int)modelHeight { + // Convert NSValue array to cv::Point2f vector + NSArray *coords = box[@"box"]; + CGFloat angle = [box[@"angle"] floatValue]; + + std::vector points; + for (NSValue 
*value in coords) { + CGPoint point = [value CGPointValue]; + points.emplace_back(static_cast(point.x), static_cast(point.y)); + } + + // Obtain the rotated rectangle from the points + cv::RotatedRect rotatedRect = cv::minAreaRect(points); + + // Compute the rotation matrix for the angle of the rotated rectangle + cv::Point2f imageCenter = cv::Point2f(image.cols / 2.0, image.rows / 2.0); + cv::Mat rotationMatrix = cv::getRotationMatrix2D(imageCenter, angle, 1.0); + + // Rotate the entire image + cv::Mat rotatedImage; + cv::warpAffine(image, rotatedImage, rotationMatrix, image.size(), cv::INTER_LINEAR); - int width = x_max - x_min; - int height = y_max - y_min; + // Get vertices of the minimal rotated rectangle + cv::Point2f rectPoints[4]; + rotatedRect.points(rectPoints); - CGFloat ratio = [self calculateRatio:width height:height]; - int new_width = (int)(modelHeight * ratio); + // Transform points using the rotation matrix + std::vector transformedPoints(4); + cv::Mat rectMat(4, 2, CV_32FC2, rectPoints); + cv::transform(rectMat, rectMat, rotationMatrix); - if (new_width == 0) { - return crop_img; + // Convert points back to array of points to compute bounding box afterward + for (int i = 0; i < 4; ++i) { + transformedPoints[i] = rectPoints[i]; } - crop_img = [self computeRatioAndResize:crop_img width:width height:height modelHeight:modelHeight]; + // Compute the bounding box of transformed points + cv::Rect boundingBox = cv::boundingRect(transformedPoints); + + // Make sure the bounding box fits within the image + boundingBox &= cv::Rect(0, 0, rotatedImage.cols, rotatedImage.rows); + + // Optional: Crop to this bounding box if necessary + cv::Mat croppedImage = rotatedImage(boundingBox); + + // If specified to resize according to modelHeight, adjust the aspect ratio and resize + croppedImage = [self computeRatioAndResize:croppedImage width:boundingBox.width height:boundingBox.height modelHeight:modelHeight]; - return crop_img; + return croppedImage; } + 
(NSMutableArray *)sumProbabilityRows:(cv::Mat)probabilities modelOutputHeight:(int)modelOutputHeight { From b7e0635445898026fed2eaa47259caf9d6c11598 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 6 Feb 2025 16:04:25 +0100 Subject: [PATCH 13/19] refactor: refactored code for mid processing pipeline, adjusted code to changes suggested in review --- .../components/ImageWithOCRBboxes.tsx | 103 ++++ .../computer-vision/screens/OCRScreen.tsx | 17 +- ios/RnExecutorch/OCR.h | 2 + ios/RnExecutorch/OCR.mm | 2 +- ios/RnExecutorch/models/ocr/Detector.h | 12 +- ios/RnExecutorch/models/ocr/Detector.mm | 63 +- .../models/ocr/RecognitionHandler.h | 12 +- .../models/ocr/RecognitionHandler.mm | 73 ++- ios/RnExecutorch/models/ocr/Recognizer.mm | 15 +- .../models/ocr/utils/DetectorUtils.h | 16 +- .../models/ocr/utils/DetectorUtils.mm | 572 ++++++++++-------- .../models/ocr/utils/RecognizerUtils.mm | 40 +- src/index.tsx | 1 + src/types/ocr.ts | 9 +- 14 files changed, 570 insertions(+), 367 deletions(-) create mode 100644 examples/computer-vision/components/ImageWithOCRBboxes.tsx diff --git a/examples/computer-vision/components/ImageWithOCRBboxes.tsx b/examples/computer-vision/components/ImageWithOCRBboxes.tsx new file mode 100644 index 0000000000..1c8fe616af --- /dev/null +++ b/examples/computer-vision/components/ImageWithOCRBboxes.tsx @@ -0,0 +1,103 @@ +// Import necessary components +import React from 'react'; +import { Image, StyleSheet, View } from 'react-native'; +import Svg, { Polygon } from 'react-native-svg'; +import { OCRDetection } from 'react-native-executorch'; + +interface Props { + imageUri: string; + detections: OCRDetection[]; + imageWidth: number; + imageHeight: number; +} + +export default function ImageWithOCRBboxes({ + imageUri, + detections, + imageWidth, + imageHeight, +}: Props) { + const [layout, setLayout] = React.useState({ width: 0, height: 0 }); + + const calculateAdjustedDimensions = () => { + const imageRatio = imageWidth / imageHeight; + 
const layoutRatio = layout.width / layout.height; + let sx, sy; + if (imageRatio > layoutRatio) { + sx = layout.width / imageWidth; + sy = layout.width / imageRatio / imageHeight; + } else { + sy = layout.height / imageHeight; + sx = (layout.height * imageRatio) / imageWidth; + } + return { + scaleX: sx, + scaleY: sy, + offsetX: (layout.width - imageWidth * sx) / 2, + offsetY: (layout.height - imageHeight * sy) / 2, + }; + }; + + return ( + { + const { width, height } = event.nativeEvent.layout; + setLayout({ width, height }); + }} + > + + + {detections.map((detection, index) => { + const { scaleX, scaleY, offsetX, offsetY } = + calculateAdjustedDimensions(); + const points = detection.bbox.map((point) => ({ + x: point.x * scaleX + offsetX, + y: point.y * scaleY + offsetY, + })); + + const pointsString = points + .map((point) => `${point.x},${point.y}`) + .join(' '); + + return ( + + ); + })} + + + ); +} + +const styles = StyleSheet.create({ + container: { + flex: 1, + position: 'relative', + }, + image: { + flex: 1, + width: '100%', + height: '100%', + }, + svgContainer: { + position: 'absolute', + top: 0, + left: 0, + right: 0, + bottom: 0, + }, +}); diff --git a/examples/computer-vision/screens/OCRScreen.tsx b/examples/computer-vision/screens/OCRScreen.tsx index ac94fb7c6e..ebcdf32eba 100644 --- a/examples/computer-vision/screens/OCRScreen.tsx +++ b/examples/computer-vision/screens/OCRScreen.tsx @@ -4,7 +4,7 @@ import { getImage } from '../utils'; import { useOCR } from 'react-native-executorch'; import { View, StyleSheet, Image, Text } from 'react-native'; import { useState } from 'react'; -import ImageWithBboxes from '../components/ImageWithBboxes'; +import ImageWithBboxes2 from '../components/ImageWithOCRBboxes'; export const OCRScreen = ({ imageUri, @@ -20,14 +20,17 @@ export const OCRScreen = ({ }>(); const [detectedText, setDetectedText] = useState(''); const model = useOCR({ - detectorSource: require('../assets/models/xnnpack_craft.pte'), + 
detectorSource: + 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_craft_800.pte', recognizerSources: { recognizerLarge: - 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_latin_512.pte', - recognizerMedium: require('../assets/models/xnnpack_latin_256.pte'), - recognizerSmall: require('../assets/models/xnnpack_latin_128.pte'), + 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_512.pte', + recognizerMedium: + 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_256.pte', + recognizerSmall: + 'https://huggingface.co/nklockiewicz/ocr/resolve/main/xnnpack_crnn_128.pte', }, - language: 'pl', + language: 'en', }); const handleCameraPress = async (isCamera: boolean) => { @@ -70,7 +73,7 @@ export const OCRScreen = ({ {imageUri && imageDimensions?.width && imageDimensions?.height ? ( - +constexpr CGFloat recognizerRatio = 1.6; + @interface OCR : NSObject @end diff --git a/ios/RnExecutorch/OCR.mm b/ios/RnExecutorch/OCR.mm index d7237242f8..cd58d2c4b8 100644 --- a/ios/RnExecutorch/OCR.mm +++ b/ios/RnExecutorch/OCR.mm @@ -66,7 +66,7 @@ - (void)forward:(NSString *)input NSArray* result = [detector runModel:image]; cv::Size detectorSize = [detector getModelImageSize]; cv::cvtColor(image, image, cv::COLOR_BGR2GRAY); - result = [self->recognitionHandler recognize:result imgGray:image desiredWidth:detectorSize.width desiredHeight:detectorSize.height]; + result = [self->recognitionHandler recognize:result imgGray:image desiredWidth:detectorSize.width * recognizerRatio desiredHeight:detectorSize.height * recognizerRatio]; resolve(result); } @catch (NSException *exception) { reject(@"forward_error", [NSString stringWithFormat:@"%@", exception.reason], diff --git a/ios/RnExecutorch/models/ocr/Detector.h b/ios/RnExecutorch/models/ocr/Detector.h index 1e09073b94..346069720a 100644 --- a/ios/RnExecutorch/models/ocr/Detector.h +++ b/ios/RnExecutorch/models/ocr/Detector.h @@ -1,13 +1,19 @@ #import "opencv2/opencv.hpp" #import 
"BaseModel.h" +#import "RecognitionHandler.h" -constexpr float textThreshold = 0.7; -constexpr float linkThreshold = 0.4; -constexpr float lowText = 0.4; +constexpr CGFloat textThreshold = 0.4; +constexpr CGFloat linkThreshold = 0.4; +constexpr CGFloat lowTextThreshold = 0.7; constexpr CGFloat centerThreshold = 0.5; constexpr CGFloat distanceThreshold = 2.0; constexpr CGFloat heightThreshold = 2.0; +constexpr CGFloat restoreRatio = 3.2; +constexpr int minSideThreshold = 15; +constexpr int maxSideThreshold = 30; +constexpr int maxWidth = largeModelWidth + (largeModelWidth * 0.15); constexpr int minSize = 20; + const cv::Scalar mean(0.485, 0.456, 0.406); const cv::Scalar variance(0.229, 0.224, 0.225); diff --git a/ios/RnExecutorch/models/ocr/Detector.mm b/ios/RnExecutorch/models/ocr/Detector.mm index 69621061b5..355e0b337a 100644 --- a/ios/RnExecutorch/models/ocr/Detector.mm +++ b/ios/RnExecutorch/models/ocr/Detector.mm @@ -18,12 +18,12 @@ @implementation Detector { return modelSize; } - NSArray * inputShape = [module getInputShape: @0]; + NSArray *inputShape = [module getInputShape: @0]; NSNumber *widthNumber = inputShape.lastObject; NSNumber *heightNumber = inputShape[inputShape.count - 2]; - int height = [heightNumber intValue]; - int width = [widthNumber intValue]; + const int height = [heightNumber intValue]; + const int width = [widthNumber intValue]; modelSize = cv::Size(height, width); return cv::Size(height, width); @@ -31,7 +31,7 @@ @implementation Detector { - (NSArray *)preprocess:(cv::Mat &)input { /* - Detector as an input accepts tensor with a shape of [1, 3, 1280, 1280]. + Detector as an input accepts tensor with a shape of [1, 3, 800, 800]. Due to big influence of resize to quality of recognition the image preserves original aspect ratio and the missing parts are filled with padding. 
*/ @@ -47,59 +47,36 @@ - (NSArray *)preprocess:(cv::Mat &)input { - (NSArray *)postprocess:(NSArray *)output { /* - The output of the model consists of two matrices: + The output of the model consists of two matrices (heat maps): 1. ScoreText(Score map) - The probability of a region containing character - 2. ScoreLink(Affinity map) - The probability of a region being a part of a text line - Both matrices are 640x640 + 2. ScoreAffinity(Affinity map) - affinity between characters, used to group each character into a single instance (sequence) + Both matrices are 400x400 The result of this step is a list of bounding boxes that contain text. */ NSArray *predictions = [output objectAtIndex:0]; - NSDictionary *splittedData = [DetectorUtils splitInterleavedNSArray:predictions]; - NSArray *scoreText = splittedData[@"ScoreText"]; - NSArray *scoreLink = splittedData[@"ScoreLink"]; - - cv::Mat scoreTextCV; - cv::Mat scoreLinkCV; cv::Size modelImageSize = [self getModelImageSize]; + cv::Mat scoreTextCV, scoreAffinityCV; + /* + The output of the model is a matrix the size of the input image containing two matrices representing heat maps. + Those two matrices are in the size of half of the input image, that's why the width and height are divided by 2. 
+ */ + [DetectorUtils interleavedArrayToMats:predictions + outputMat1:scoreTextCV + outputMat2:scoreAffinityCV + withSize:cv::Size(modelImageSize.width / 2, modelImageSize.height / 2)]; + NSArray* bBoxesList = [DetectorUtils getDetBoxesFromTextMap:scoreTextCV affinityMap:scoreAffinityCV usingTextThreshold:textThreshold linkThreshold:linkThreshold lowTextThreshold:lowTextThreshold]; + bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList usingRestoreRatio: restoreRatio]; + bBoxesList = [DetectorUtils groupTextBoxes:bBoxesList centerThreshold:centerThreshold distanceThreshold:distanceThreshold heightThreshold:heightThreshold minSideThreshold:minSideThreshold maxSideThreshold:maxSideThreshold maxWidth:maxWidth]; - scoreTextCV = [ImageProcessor arrayToMatGray:scoreText width:modelImageSize.width / 2 height:modelImageSize.height / 2]; - scoreLinkCV = [ImageProcessor arrayToMatGray:scoreLink width:modelImageSize.width / 2 height:modelImageSize.height / 2]; - - NSArray* horizontalList = [DetectorUtils getDetBoxes:scoreTextCV linkMap:scoreLinkCV textThreshold:textThreshold linkThreshold:linkThreshold lowText:lowText]; - horizontalList = [DetectorUtils restoreBboxRatio:horizontalList]; - horizontalList = [DetectorUtils groupTextBoxes:horizontalList centerThreshold:centerThreshold distanceThreshold:distanceThreshold heightThreshold:heightThreshold]; - - return horizontalList; + return bBoxesList; } - (NSArray *)runModel:(cv::Mat &)input { - NSDate *startTime; - NSDate *endTime; - NSTimeInterval executionTime; - - // Preprocessing - startTime = [NSDate date]; NSArray *modelInput = [self preprocess:input]; - endTime = [NSDate date]; - executionTime = [endTime timeIntervalSinceDate:startTime]; - NSLog(@"Preprocessing time: %f seconds", executionTime); - - // Running the model - startTime = [NSDate date]; NSArray *modelResult = [self forward:modelInput]; - endTime = [NSDate date]; - executionTime = [endTime timeIntervalSinceDate:startTime]; - NSLog(@"Model forwarding time: %f 
seconds", executionTime); - - // Postprocessing - startTime = [NSDate date]; NSArray *result = [self postprocess:modelResult]; - endTime = [NSDate date]; - executionTime = [endTime timeIntervalSinceDate:startTime]; - NSLog(@"Postprocessing time: %f seconds", executionTime); - return result; } diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.h b/ios/RnExecutorch/models/ocr/RecognitionHandler.h index 93dbebd82e..72ec004ff1 100644 --- a/ios/RnExecutorch/models/ocr/RecognitionHandler.h +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.h @@ -1,14 +1,16 @@ #import "opencv2/opencv.hpp" -const int modelHeight = 64; -const int largeModelWidth = 512; -const int mediumModelWidth = 256; -const int smallModelWidth = 128; +constexpr int modelHeight = 64; +constexpr int largeModelWidth = 512; +constexpr int mediumModelWidth = 256; +constexpr int smallModelWidth = 128; +constexpr CGFloat lowConfidenceThreshold = 0.3; +constexpr CGFloat adjustContrast = 0.2; @interface RecognitionHandler : NSObject - (instancetype)initWithSymbols:(NSString *)symbols languageDictPath:(NSString *)languageDictPath; - (void)loadRecognizers:(NSString *)largeRecognizerPath mediumRecognizerPath:(NSString *)mediumRecognizerPath smallRecognizerPath:(NSString *)smallRecognizerPath completion:(void (^)(BOOL, NSNumber *))completion; -- (NSArray *)recognize:(NSArray *)horizontalList imgGray:(cv::Mat)imgGray desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight; +- (NSArray *)recognize:(NSArray *)bBoxesList imgGray:(cv::Mat)imgGray desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight; @end diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm index 4979884f78..b1d8559995 100644 --- a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm @@ -62,50 +62,63 @@ - (void)loadRecognizers:(NSString *)largeRecognizerPath mediumRecognizerPath:(NS }); } -- (NSArray *)recognize: 
(NSArray *)horizontalList imgGray:(cv::Mat)imgGray desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight { +- (NSArray *)runModel:(cv::Mat)croppedImage { + NSArray *result; + if(croppedImage.cols >= largeModelWidth) { + result = [recognizerLarge runModel:croppedImage]; + } else if (croppedImage.cols >= mediumModelWidth) { + result = [recognizerMedium runModel: croppedImage]; + } else { + result = [recognizerSmall runModel: croppedImage]; + } + + return result; +} + +- (NSArray *)recognize: (NSArray *)bBoxesList imgGray:(cv::Mat)imgGray desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight { + NSDictionary* ratioAndPadding = [RecognizerUtils calculateResizeRatioAndPaddings:imgGray.cols height:imgGray.rows desiredWidth:desiredWidth desiredHeight:desiredHeight]; + const int left = [ratioAndPadding[@"left"] intValue]; + const int top = [ratioAndPadding[@"top"] intValue]; + const CGFloat resizeRatio = [ratioAndPadding[@"resizeRatio"] floatValue]; imgGray = [OCRUtils resizeWithPadding:imgGray desiredWidth:desiredWidth desiredHeight:desiredHeight]; + NSMutableArray *predictions = [NSMutableArray array]; - NSLog(@"%@", horizontalList); - for (NSDictionary *box in horizontalList) { + for (NSDictionary *box in bBoxesList) { cv::Mat croppedImage = [RecognizerUtils getCroppedImage:box image:imgGray modelHeight:modelHeight]; - croppedImage = [RecognizerUtils normalizeForRecognizer:croppedImage adjustContrast:0.2]; - NSArray *result; - if(croppedImage.cols >= largeModelWidth) { - result = [recognizerLarge runModel:croppedImage]; - } else if (croppedImage.cols >= mediumModelWidth) { - result = [recognizerMedium runModel: croppedImage]; - } else { - result = [recognizerSmall runModel: croppedImage]; + if (croppedImage.empty()) { + continue; } + croppedImage = [RecognizerUtils normalizeForRecognizer:croppedImage adjustContrast:adjustContrast]; + NSArray *result = [self runModel: croppedImage]; + NSNumber *confidenceScore = [result objectAtIndex:1]; - 
if([confidenceScore floatValue] < 0.3){ + if([confidenceScore floatValue] < lowConfidenceThreshold){ cv::rotate(croppedImage, croppedImage, cv::ROTATE_180); + + NSArray *rotatedResult = [self runModel: croppedImage]; + NSNumber *rotatedConfidenceScore = [rotatedResult objectAtIndex:1]; + + if ([rotatedConfidenceScore floatValue] > [confidenceScore floatValue]) { + result = rotatedResult; + confidenceScore = rotatedConfidenceScore; + } } - NSArray *rotatedResult; - if(croppedImage.cols >= largeModelWidth) { - rotatedResult = [recognizerLarge runModel:croppedImage]; - } else if (croppedImage.cols >= mediumModelWidth) { - rotatedResult = [recognizerMedium runModel: croppedImage]; - } else { - rotatedResult = [recognizerSmall runModel: croppedImage]; - } - NSNumber *rotatedConfidenceScore = [rotatedResult objectAtIndex:1]; - if ([rotatedConfidenceScore floatValue] > [confidenceScore floatValue]) { - result = rotatedResult; - } + NSArray *predIndex = [result objectAtIndex:0]; + NSArray* decodedTexts = [converter decodeGreedy:predIndex length:(int)(predIndex.count)]; - NSArray *pred_index = [result objectAtIndex:0]; + NSMutableArray *bbox = [NSMutableArray arrayWithCapacity:4]; + for (NSValue *coords in box[@"bbox"]){ + const CGPoint point = [coords CGPointValue]; + [bbox addObject: @{@"x": @((point.x - left) * resizeRatio), @"y": @((point.y - top) * resizeRatio)}]; + } - NSArray* decodedTexts = [converter decodeGreedy:pred_index length:(int)(pred_index.count)]; - - NSDictionary *res = @{@"text": decodedTexts[0], @"bbox": box[@"box"], @"score": confidenceScore}; + NSDictionary *res = @{@"text": decodedTexts[0], @"bbox": bbox, @"score": confidenceScore}; [predictions addObject:res]; } return predictions; } -@end - +@end \ No newline at end of file diff --git a/ios/RnExecutorch/models/ocr/Recognizer.mm b/ios/RnExecutorch/models/ocr/Recognizer.mm index 10af525ce6..a6d9f7137d 100644 --- a/ios/RnExecutorch/models/ocr/Recognizer.mm +++ 
b/ios/RnExecutorch/models/ocr/Recognizer.mm @@ -17,8 +17,8 @@ @implementation Recognizer { NSNumber *widthNumber = inputShape.lastObject; NSNumber *heightNumber = inputShape[inputShape.count - 2]; - int height = [heightNumber intValue]; - int width = [widthNumber intValue]; + const int height = [heightNumber intValue]; + const int width = [widthNumber intValue]; return cv::Size(height, width); } @@ -27,8 +27,8 @@ @implementation Recognizer { NSNumber *widthNumber = outputShape.lastObject; NSNumber *heightNumber = outputShape[outputShape.count - 2]; - int height = [heightNumber intValue]; - int width = [widthNumber intValue]; + const int height = [heightNumber intValue]; + const int width = [widthNumber intValue]; return cv::Size(height, width); } @@ -37,7 +37,7 @@ - (NSArray *)preprocess:(cv::Mat &)input { } - (NSArray *)postprocess:(NSArray *)output { - int modelOutputHeight = [self getModelOutputSize].height; + const int modelOutputHeight = [self getModelOutputSize].height; NSInteger numElements = [output.firstObject count]; NSInteger numRows = (numElements + modelOutputHeight - 1) / modelOutputHeight; cv::Mat resultMat = cv::Mat::zeros(numRows, modelOutputHeight, CV_32F); @@ -55,17 +55,18 @@ - (NSArray *)postprocess:(NSArray *)output { NSMutableArray *predsNorm = [RecognizerUtils sumProbabilityRows:probabilities modelOutputHeight:modelOutputHeight]; probabilities = [RecognizerUtils divideMatrix:probabilities byVector:predsNorm]; NSArray *maxValuesIndices = [RecognizerUtils findMaxValuesAndIndices:probabilities]; - double confidenceScore = [RecognizerUtils computeConfidenceScore:maxValuesIndices[0] indicesArray:maxValuesIndices[1]]; + const CGFloat confidenceScore = [RecognizerUtils computeConfidenceScore:maxValuesIndices[0] indicesArray:maxValuesIndices[1]]; return @[maxValuesIndices[1], @(confidenceScore)]; } - (NSArray *)runModel:(cv::Mat &)input { - NSArray* modelInput = [self preprocess:input]; + NSArray *modelInput = [self preprocess:input]; NSArray 
*modelResult = [self forward:modelInput]; NSArray *result = [self postprocess:modelResult]; return result; } + @end diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h index cc527d7ad2..8330cf9891 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h @@ -1,13 +1,21 @@ #import +constexpr int verticalLineThreshold = 20; + @interface DetectorUtils : NSObject -+ (NSDictionary *)splitInterleavedNSArray:(NSArray *)array; -+ (NSArray *)restoreBboxRatio:(NSArray *)boxes; -+ (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold:(double)textThreshold linkThreshold:(double)linkThreshold lowText:(double)lowText; ++ (void)interleavedArrayToMats:(NSArray *)array + outputMat1:(cv::Mat &)mat1 + outputMat2:(cv::Mat &)mat2 + withSize:(cv::Size)size; ++ (NSArray *)getDetBoxesFromTextMap:(cv::Mat)textMap affinityMap:(cv::Mat)affinityMap usingTextThreshold:(CGFloat)textThreshold linkThreshold:(CGFloat)linkThreshold lowTextThreshold:(CGFloat)lowTextThreshold; ++ (NSArray *)restoreBboxRatio:(NSArray *)boxes usingRestoreRatio:(CGFloat)restoreRatio; + (NSArray *)groupTextBoxes:(NSArray *)polys centerThreshold:(CGFloat)centerThreshold distanceThreshold:(CGFloat)distanceThreshold - heightThreshold:(CGFloat)heightThreshold; + heightThreshold:(CGFloat)heightThreshold + minSideThreshold:(int)minSideThreshold + maxSideThreshold:(int)maxSideThreshold + maxWidth:(int)maxWidth; @end diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm index ec089190c6..d79d654c35 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm @@ -2,38 +2,151 @@ @implementation DetectorUtils ++ (void)interleavedArrayToMats:(NSArray *)array + outputMat1:(cv::Mat &)mat1 + outputMat2:(cv::Mat &)mat2 + withSize:(cv::Size)size { + mat1 = 
cv::Mat(size.height, size.width, CV_32F); + mat2 = cv::Mat(size.height, size.width, CV_32F); + + for (NSUInteger idx = 0; idx < array.count; idx++) { + const CGFloat value = [array[idx] doubleValue]; + const int x = (idx / 2) % size.width; + const int y = (idx / 2) / size.width; + + if (idx % 2 == 0) { + mat1.at(y, x) = value; + } else { + mat2.at(y, x) = value; + } + } +} + +/** + * This method applies a series of image processing operations to identify likely areas of text in the textMap and return the bounding boxes for single words. + * + * @param textMap A cv::Mat representing a heat map of character presence in an image. + * @param affinityMap A cv::Mat representing a heat map of the affinity between characters. + * @param textThreshold A CGFloat representing the threshold for the text map. + * @param linkThreshold A CGFloat representing the threshold for the affinity map. + * @param lowTextThreshold A CGFloat representing the low-text score threshold used to reject weak detections. + * + * @return An NSArray containing NSDictionary objects. Each dictionary includes: + * - "bbox": an NSArray of CGPoint values representing the vertices of the detected text box. + * - "angle": an NSNumber representing the rotation angle of the box. 
+ */ ++ (NSArray *)getDetBoxesFromTextMap:(cv::Mat)textMap affinityMap:(cv::Mat)affinityMap usingTextThreshold:(CGFloat)textThreshold linkThreshold:(CGFloat)linkThreshold lowTextThreshold:(CGFloat)lowTextThreshold { + const int imgH = textMap.rows; + const int imgW = textMap.cols; + cv::Mat textScore; + cv::Mat affinityScore; + cv::threshold(textMap, textScore, textThreshold, 1, cv::THRESH_BINARY); + cv::threshold(affinityMap, affinityScore, linkThreshold, 1, cv::THRESH_BINARY); + cv::Mat textScoreComb = textScore + affinityScore; + cv::threshold(textScoreComb, textScoreComb, 0, 1, cv::THRESH_BINARY); + cv::Mat binaryMat; + textScoreComb.convertTo(binaryMat, CV_8UC1); + + cv::Mat labels, stats, centroids; + const int nLabels = cv::connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4); + + NSMutableArray *detectedBoxes = [NSMutableArray array]; + for (int i = 1; i < nLabels; i++) { + const int area = stats.at(i, cv::CC_STAT_AREA); + if (area < 10) continue; + + cv::Mat mask = (labels == i); + CGFloat maxVal; + cv::minMaxLoc(textMap, NULL, &maxVal, NULL, NULL, mask); + if (maxVal < lowTextThreshold) continue; + + cv::Mat segMap = cv::Mat::zeros(textMap.size(), CV_8U); + segMap.setTo(255, mask); + + const int x = stats.at(i, cv::CC_STAT_LEFT); + const int y = stats.at(i, cv::CC_STAT_TOP); + const int w = stats.at(i, cv::CC_STAT_WIDTH); + const int h = stats.at(i, cv::CC_STAT_HEIGHT); + const int dilationRadius = (int)(sqrt((double)(area / MAX(w, h)) ) * 2.0); + const int sx = MAX(x - dilationRadius, 0); + const int ex = MIN(x + w + dilationRadius + 1, imgW); + const int sy = MAX(y - dilationRadius, 0); + const int ey = MIN(y + h + dilationRadius + 1, imgH); + + cv::Rect roi(sx, sy, ex - sx, ey - sy); + cv::Mat kernel = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(1 + dilationRadius, 1 + dilationRadius)); + cv::Mat roiSegMap = segMap(roi); + cv::dilate(roiSegMap, roiSegMap, kernel); + + std::vector> contours; + cv::findContours(segMap, contours, 
cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE); + if (!contours.empty()) { + cv::RotatedRect minRect = cv::minAreaRect(contours[0]); + cv::Point2f vertices[4]; + minRect.points(vertices); + NSMutableArray *pointsArray = [NSMutableArray arrayWithCapacity:4]; + for (int j = 0; j < 4; j++) { + const CGPoint point = CGPointMake(vertices[j].x, vertices[j].y); + [pointsArray addObject:[NSValue valueWithCGPoint:point]]; + } + NSDictionary *dict = @{@"bbox": pointsArray, @"angle": @(minRect.angle)}; + [detectedBoxes addObject:dict]; + } + } + + return detectedBoxes; +} + ++ (NSArray *)restoreBboxRatio:(NSArray *)boxes usingRestoreRatio:(CGFloat)restoreRatio { + NSMutableArray *result = [NSMutableArray array]; + for (NSUInteger i = 0; i < [boxes count]; i++) { + NSDictionary *box = boxes[i]; + NSMutableArray *boxArray = [NSMutableArray arrayWithCapacity:4]; + for (NSValue *value in box[@"bbox"]) { + CGPoint point = [value CGPointValue]; + point.x *= restoreRatio; + point.y *= restoreRatio; + [boxArray addObject:[NSValue valueWithCGPoint:point]]; + } + NSDictionary *dict = @{@"bbox": boxArray, @"angle": box[@"angle"]}; + [result addObject:dict]; + } + + return result; +} + +/** + * This method normalizes angle returned from cv::minAreaRect function which ranges from 0 to 90 degrees. + **/ + (CGFloat)normalizeAngle:(CGFloat)angle { - /* - Normalize the angle returned by OpenCV's minAreaRect. 
- */ if (angle > 45) { return angle - 90; } return angle; } -+ (CGFloat)distance:(CGPoint)p1 p2:(CGPoint)p2 { - double xDist = (p2.x - p1.x); - double yDist = (p2.y - p1.y); - return sqrt(xDist * xDist + yDist * yDist); ++ (CGPoint)midpointBetweenPoint:(CGPoint)p1 andPoint:(CGPoint)p2 { + return CGPointMake((p1.x + p2.x) / 2, (p1.y + p2.y) / 2); } -+ (CGPoint)midpoint:(CGPoint)p1 p2:(CGPoint)p2 { - return CGPointMake((p1.x + p2.x) / 2, (p1.y + p2.y) / 2); ++ (CGFloat)distanceFromPoint:(CGPoint)p1 toPoint:(CGPoint)p2 { + const CGFloat xDist = (p2.x - p1.x); + const CGFloat yDist = (p2.y - p1.y); + return sqrt(xDist * xDist + yDist * yDist); } -+ (CGPoint)centerOfBox:(NSArray *)box { - return CGPointMake(([box[0] CGPointValue].x + [box[2] CGPointValue].x) / 2, ([box[0] CGPointValue].y + [box[2] CGPointValue].y) / 2); ++ (CGPoint)centerOfBox:(NSArray *)box { + return [self midpointBetweenPoint:[box[0] CGPointValue] andPoint:[box[2] CGPointValue]]; } + (CGFloat)maxSideLength:(NSArray *)points { CGFloat maxSideLength = 0; NSInteger numOfPoints = points.count; for (NSInteger i = 0; i < numOfPoints; i++) { - CGPoint currentPoint = [points[i] CGPointValue]; - CGPoint nextPoint = [points[(i + 1) % numOfPoints] CGPointValue]; + const CGPoint currentPoint = [points[i] CGPointValue]; + const CGPoint nextPoint = [points[(i + 1) % numOfPoints] CGPointValue]; - CGFloat sideLength = [self distance:currentPoint p2:nextPoint]; + const CGFloat sideLength = [self distanceFromPoint:currentPoint toPoint:nextPoint]; if (sideLength > maxSideLength) { maxSideLength = sideLength; } @@ -46,10 +159,10 @@ + (CGFloat)minSideLength:(NSArray *)points { NSInteger numOfPoints = points.count; for (NSInteger i = 0; i < numOfPoints; i++) { - CGPoint currentPoint = [points[i] CGPointValue]; - CGPoint nextPoint = [points[(i + 1) % numOfPoints] CGPointValue]; + const CGPoint currentPoint = [points[i] CGPointValue]; + const CGPoint nextPoint = [points[(i + 1) % numOfPoints] CGPointValue]; - CGFloat 
sideLength = [self distance:currentPoint p2:nextPoint]; + const CGFloat sideLength = [self distanceFromPoint:currentPoint toPoint:nextPoint]; if (sideLength < minSideLength) { minSideLength = sideLength; } @@ -58,19 +171,68 @@ + (CGFloat)minSideLength:(NSArray *)points { return minSideLength; } ++ (CGFloat)calculateMinimalDistanceBetweenBox:(NSArray *)box1 andBox:(NSArray *)box2 { + CGFloat minDistance = CGFLOAT_MAX; + for (NSValue *value1 in box1) { + const CGPoint corner1 = [value1 CGPointValue]; + for (NSValue *value2 in box2) { + const CGPoint corner2 = [value2 CGPointValue]; + const CGFloat distance = [self distanceFromPoint:corner1 toPoint:corner2]; + if (distance < minDistance) { + minDistance = distance; + } + } + } + return minDistance; +} + ++ (NSArray *)rotateBox:(NSArray *)box withAngle:(CGFloat)angle { + const CGPoint center = [self centerOfBox:box]; + + const CGFloat radians = angle * M_PI / 180.0; + + NSMutableArray *rotatedPoints = [NSMutableArray arrayWithCapacity:4]; + for (NSValue *value in box) { + const CGPoint point = [value CGPointValue]; + + const CGFloat translatedX = point.x - center.x; + const CGFloat translatedY = point.y - center.y; + + const CGFloat rotatedX = translatedX * cos(radians) - translatedY * sin(radians); + const CGFloat rotatedY = translatedX * sin(radians) + translatedY * cos(radians); + + const CGPoint rotatedPoint = CGPointMake(rotatedX + center.x, rotatedY + center.y); + [rotatedPoints addObject:[NSValue valueWithCGPoint:rotatedPoint]]; + } + + return rotatedPoints; +} + +/** + * Orders a set of points in a clockwise direction starting with the top-left point. + * + * Process: + * 1. It iterates through each CGPoint extracted from the NSValues. + * 2. For each point, it calculates the sum (x + y) and difference (y - x) of the coordinates. + * 3. Points are classified into: + * - Top-left: Minimum sum. + * - Bottom-right: Maximum sum. + * - Top-right: Minimum difference. + * - Bottom-left: Maximum difference. + * 4. 
The points are ordered starting from the top-left in a clockwise manner: top-left, top-right, bottom-right, bottom-left. + */ + (NSArray *)orderPointsClockwise:(NSArray *)points{ CGPoint topLeft, topRight, bottomRight, bottomLeft; - float minSum = FLT_MAX; - float maxSum = -FLT_MAX; - float minDiff = FLT_MAX; - float maxDiff = -FLT_MAX; + CGFloat minSum = FLT_MAX; + CGFloat maxSum = -FLT_MAX; + CGFloat minDiff = FLT_MAX; + CGFloat maxDiff = -FLT_MAX; for (NSValue *value in points) { - CGPoint pt = [value CGPointValue]; - float sum = pt.x + pt.y; - float diff = pt.y - pt.x; + const CGPoint pt = [value CGPointValue]; + const CGFloat sum = pt.x + pt.y; + const CGFloat diff = pt.y - pt.x; - // For top-left and bottom-right determination if (sum < minSum) { minSum = sum; topLeft = pt; @@ -79,8 +241,6 @@ + (NSArray *)orderPointsClockwise:(NSArray *)points{ maxSum = sum; bottomRight = pt; } - - // For top-right and bottom-left determination if (diff < minDiff) { minDiff = diff; topRight = pt; @@ -99,38 +259,10 @@ + (NSArray *)orderPointsClockwise:(NSArray *)points{ return rect; } -+ (NSArray *)rotateBox:(NSArray *)box withAngle:(CGFloat)angle { - // Calculate the center of the rectangle - CGPoint center = [self centerOfBox:box]; - - // Convert angle from degrees to radians - CGFloat radians = angle * M_PI / 180.0; - - // Prepare an array to hold the rotated points - NSMutableArray *rotatedPoints = [NSMutableArray arrayWithCapacity:4]; - for (NSValue *value in box) { - CGPoint point = [value CGPointValue]; - - // Translate point to origin - CGFloat translatedX = point.x - center.x; - CGFloat translatedY = point.y - center.y; - - // Rotate point - CGFloat rotatedX = translatedX * cos(radians) - translatedY * sin(radians); - CGFloat rotatedY = translatedX * sin(radians) + translatedY * cos(radians); - - // Translate point back - CGPoint rotatedPoint = CGPointMake(rotatedX + center.x, rotatedY + center.y); - [rotatedPoints addObject:[NSValue valueWithCGPoint:rotatedPoint]]; - 
} - - return rotatedPoints; -} - + (std::vector)pointsFromNSValues:(NSArray *)nsValues { std::vector points; for (NSValue *value in nsValues) { - CGPoint point = [value CGPointValue]; + const CGPoint point = [value CGPointValue]; points.emplace_back(point.x, point.y); } return points; @@ -151,12 +283,10 @@ + (NSArray *)orderPointsClockwise:(NSArray *)points{ std::vector points1 = [self pointsFromNSValues:box1]; std::vector points2 = [self pointsFromNSValues:box2]; - // Collect all points from both rectangles std::vector allPoints; allPoints.insert(allPoints.end(), points1.begin(), points1.end()); allPoints.insert(allPoints.end(), points2.begin(), points2.end()); - // Calculate the convex hull of all points std::vector hullIndices; cv::convexHull(allPoints, hullIndices, false); @@ -165,92 +295,113 @@ + (NSArray *)orderPointsClockwise:(NSArray *)points{ hullPoints.push_back(allPoints[idx]); } - // Get the minimum area rectangle that bounds the convex hull cv::RotatedRect minAreaRect = cv::minAreaRect(hullPoints); cv::Point2f rectPoints[4]; minAreaRect.points(rectPoints); - // Convert rotated rectangle points back to NSArray return [self nsValuesFromPoints:rectPoints count:4]; } -+ (NSDictionary *)splitInterleavedNSArray:(NSArray *)array { - NSMutableArray *scoreText = [[NSMutableArray alloc] init]; - NSMutableArray *scoreLink = [[NSMutableArray alloc] init]; ++ (NSMutableArray *)removeSmallBoxesFromArray:(NSArray *)boxes usingMinSideThreshold:(CGFloat)minSideThreshold maxSideThreshold:(CGFloat)maxSideThreshold { + NSMutableArray *filteredBoxes = [NSMutableArray array]; - [array enumerateObjectsUsingBlock:^(id element, NSUInteger idx, BOOL *stop) { - if (idx % 2 == 0) { - [scoreText addObject:element]; - } else { - [scoreLink addObject:element]; + for (NSDictionary *box in boxes) { + const CGFloat maxSideLength = [self maxSideLength:box[@"bbox"]]; + const CGFloat minSideLength = [self minSideLength:box[@"bbox"]]; + if (minSideLength > minSideThreshold && maxSideLength 
> maxSideThreshold) { + [filteredBoxes addObject:box]; } - }]; + } - return @{@"ScoreText": scoreText, @"ScoreLink": scoreLink}; + return filteredBoxes; } -+ (CGFloat)calculateMinimalDistance:(NSArray *)corners1 corners2:(NSArray *)corners2 { - CGFloat minDistance = CGFLOAT_MAX; - for (NSValue *value1 in corners1) { - CGPoint corner1 = [value1 CGPointValue]; - for (NSValue *value2 in corners2) { - CGPoint corner2 = [value2 CGPointValue]; - CGFloat distance = [self distance:corner1 p2:corner2]; - if (distance < minDistance) { - minDistance = distance; - } ++ (CGFloat)minimumYFromBox:(NSArray *)box { + __block CGFloat minY = CGFLOAT_MAX; + [box enumerateObjectsUsingBlock:^(NSValue * _Nonnull obj, NSUInteger idx, BOOL * _Nonnull stop) { + const CGPoint pt = [obj CGPointValue]; + if (pt.y < minY) { + minY = pt.y; } - } - return minDistance; + }]; + return minY; } +/** + * This method calculates the distances between each sequential pair of points in a presumed quadrilateral, + * identifies the two shortest sides, and fits a linear model to the midpoints of these sides. It also evaluates + * whether the resulting line should be considered vertical based on a predefined threshold for the x-coordinate differences. + * + * If the line is vertical it is fitted as a function of x = my + c, otherwise as y = mx + c. + * + * @return A NSDictionary containing: + * - "slope": NSNumber representing the slope (m) of the line. + * - "intercept": NSNumber representing the line's intercept (c) with y-axis. + * - "isVertical": NSNumber (boolean) indicating whether the line is considered vertical. 
+ */ + (NSDictionary *)fitLineToShortestSides:(NSArray *)points { - // Calculate distances and find midpoints NSMutableArray *sides = [NSMutableArray array]; NSMutableArray *midpoints = [NSMutableArray array]; for (int i = 0; i < 4; i++) { - CGPoint p1 = [points[i] CGPointValue]; - CGPoint p2 = [points[(i + 1) % 4] CGPointValue]; - - CGFloat sideLength = [self distance:p1 p2:p2]; - [sides addObject:@{@"length": @(sideLength), @"index": @(i)}]; - [midpoints addObject:[NSValue valueWithCGPoint:[self midpoint:p1 p2:p2]]]; + const CGPoint p1 = [points[i] CGPointValue]; + const CGPoint p2 = [points[(i + 1) % 4] CGPointValue]; + + const CGFloat sideLength = [self distanceFromPoint:p1 toPoint:p2]; + [sides addObject:@{@"length": @(sideLength), @"index": @(i)}]; + [midpoints addObject:[NSValue valueWithCGPoint:[self midpointBetweenPoint:p1 andPoint:p2]]]; } - // Sort indices by distances [sides sortUsingDescriptors:@[[NSSortDescriptor sortDescriptorWithKey:@"length" ascending:YES]]]; - CGPoint midpoint1 = [midpoints[[sides[0][@"index"] intValue]] CGPointValue]; - CGPoint midpoint2 = [midpoints[[sides[1][@"index"] intValue]] CGPointValue]; - CGFloat dx = fabs(midpoint2.x - midpoint1.x); + const CGPoint midpoint1 = [midpoints[[sides[0][@"index"] intValue]] CGPointValue]; + const CGPoint midpoint2 = [midpoints[[sides[1][@"index"] intValue]] CGPointValue]; + const CGFloat dx = fabs(midpoint2.x - midpoint1.x); - float m, c; + CGFloat m, c; BOOL isVertical; std::vector cvMidPoints = {cv::Point2f(midpoint1.x, midpoint1.y), cv::Point2f(midpoint2.x, midpoint2.y)}; cv::Vec4f line; - if (dx < 20) { - // If almost vertical, fit x = my + c + if (dx < verticalLineThreshold) { for (auto &pt : cvMidPoints) std::swap(pt.x, pt.y); cv::fitLine(cvMidPoints, line, cv::DIST_L2, 0, 0.01, 0.01); m = line[1] / line[0]; c = line[3] - m * line[2]; isVertical = YES; } else { - // Fit y = mx + c cv::fitLine(cvMidPoints, line, cv::DIST_L2, 0, 0.01, 0.01); m = line[1] / line[0]; c = line[3] - m * 
line[2]; isVertical = NO; } - + return @{@"slope": @(m), @"intercept": @(c), @"isVertical": @(isVertical)}; } -+ (NSDictionary *)findClosestBox:(NSArray *)polys +/** + * This method assesses each box from a provided array, checks its center against the center of a "current box", + * and evaluates its alignment with a specified line equation. The function specifically searches for the box + * whose center is closest to the current box, that has not been ignored, and fits within a defined distance from the line. + * + * @param boxes An NSArray of NSDictionary objects where each dictionary represents a box with keys "bbox" and "angle". + * "bbox" is an NSArray of NSValue objects each encapsulating CGPoint that define the box vertices. + * "angle" is a NSNumber representing the box's rotation angle. + * @param ignoredIdxs An NSSet of NSNumber objects representing indices of boxes to ignore in the evaluation. + * @param currentBox An NSArray of NSValue objects encapsulating CGPoints representing the current box to compare against. + * @param isVertical A pointer to a BOOL indicating if the line to compare distance to is vertical. + * @param m The slope (gradient) of the line against which the box's alignment is checked. + * @param c The y-intercept of the line equation y = mx + c. + * @param centerThreshold A multiplier to determine the threshold for the distance between the box's center and the line. + * + * @return A NSDictionary containing: + * - "idx" : NSNumber indicating the index of the found box in the original NSArray. + * - "boxHeight" : NSNumber representing the shortest side length of the found box. + * Returns nil if no suitable box is found. 
+ */ ++ (NSDictionary *)findClosestBox:(NSArray *)boxes ignoredIdxs:(NSSet *)ignoredIdxs currentBox:(NSArray *)currentBox isVertical:(BOOL)isVertical @@ -259,190 +410,133 @@ + (NSDictionary *)findClosestBox:(NSArray *)polys centerThreshold:(CGFloat)centerThreshold { CGFloat smallestDistance = CGFLOAT_MAX; - NSDictionary *boxToMerge = nil; NSInteger idx = -1; CGFloat boxHeight = 0; - CGPoint centerOfCurrentBox = [self centerOfBox:currentBox]; + const CGPoint centerOfCurrentBox = [self centerOfBox:currentBox]; - for (NSUInteger i = 0; i < polys.count; i++) { + for (NSUInteger i = 0; i < boxes.count; i++) { if ([ignoredIdxs containsObject:@(i)]) { continue; } - NSArray *coords = polys[i][@"box"]; - CGFloat angle = [polys[i][@"angle"] doubleValue]; - CGPoint centerOfProcessedBox = [self centerOfBox:coords]; - CGFloat distanceBetweenCenters = [self distance:centerOfCurrentBox p2:centerOfProcessedBox]; + NSArray *bbox = boxes[i][@"bbox"]; + const CGPoint centerOfProcessedBox = [self centerOfBox:bbox]; + const CGFloat distanceBetweenCenters = [self distanceFromPoint:centerOfCurrentBox toPoint:centerOfProcessedBox]; if (distanceBetweenCenters >= smallestDistance) { continue; } - boxHeight = [self minSideLength:coords]; + boxHeight = [self minSideLength:bbox]; - CGFloat lineDistance = (isVertical ? + const CGFloat lineDistance = (isVertical ? fabs(centerOfProcessedBox.x - (m * centerOfProcessedBox.y + c)) : fabs(centerOfProcessedBox.y - (m * centerOfProcessedBox.x + c))); if (lineDistance < boxHeight * centerThreshold) { - boxToMerge = @{@"coords": coords, @"angle": @(angle)}; idx = i; smallestDistance = distanceBetweenCenters; } } - return boxToMerge ? @{@"boxToMerge": boxToMerge, @"idx": @(idx), @"boxHeight": @(boxHeight)} : nil; + return idx != -1 ? 
@{@"idx": @(idx), @"boxHeight": @(boxHeight)} : nil; } -+ (NSArray *)restoreBboxRatio:(NSArray *)boxes { - NSMutableArray *result = [NSMutableArray array]; - for (NSUInteger i = 0; i < [boxes count]; i++) { - NSDictionary *box = boxes[i]; - NSMutableArray *boxArray = [NSMutableArray arrayWithCapacity:4]; - for (NSValue *value in box[@"box"]) { - CGPoint point = [value CGPointValue]; - point.x *= 2 * 1.6; - point.y *= 2 * 1.6; - [boxArray addObject:[NSValue valueWithCGPoint:point]]; - } - NSDictionary *dict = @{@"box": boxArray, @"angle": box[@"angle"]}; - [result addObject:dict]; - } - - return result; -} - -+ (NSArray *)getDetBoxes:(cv::Mat)textmap linkMap:(cv::Mat)linkmap textThreshold:(double)textThreshold linkThreshold:(double)linkThreshold lowText:(double)lowText { - int img_h = textmap.rows; - int img_w = textmap.cols; - cv::Mat textScore, linkScore; - cv::threshold(textmap, textScore, lowText, 1, 0); - cv::threshold(linkmap, linkScore, linkThreshold, 1, 0); - cv::Mat textScoreComb = textScore + linkScore; - cv::threshold(textScoreComb, textScoreComb, 0, 1, cv::THRESH_BINARY); - cv::Mat binaryMat; - textScoreComb.convertTo(binaryMat, CV_8UC1); - - cv::Mat labels, stats, centroids; - int nLabels = cv::connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4); - - NSMutableArray *detectedBoxes = [NSMutableArray array]; - for (int i = 1; i < nLabels; i++) { - int area = stats.at(i, cv::CC_STAT_AREA); - if (area < 10) continue; - - cv::Mat mask = (labels == i); - double maxVal; - cv::minMaxLoc(textmap, NULL, &maxVal, NULL, NULL, mask); - if (maxVal < textThreshold) continue; - - cv::Mat segMap = cv::Mat::zeros(textmap.size(), CV_8U); - segMap.setTo(255, (labels == i)); - - int x = stats.at(i, cv::CC_STAT_LEFT); - int y = stats.at(i, cv::CC_STAT_TOP); - int w = stats.at(i, cv::CC_STAT_WIDTH); - int h = stats.at(i, cv::CC_STAT_HEIGHT); - int niter = (int)(sqrt((double)(area * MIN(w, h)) / (double)(w * h)) * 2.0); - int sx = x - niter; - int ex = x + w + 
niter + 1; - int sy = y - niter; - int ey = y + h + niter + 1; - if (sx < 0) sx = 0; - if (sy < 0) sy = 0; - if (ex >= img_w) ex = img_w; - if (ey >= img_h) ey = img_h; - cv::Rect roi(sx, sy, ex - sx, ey - sy); - - cv::Mat kernel = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(1 + niter, 1 + niter)); - cv::Mat roiSegMap = segMap(roi); - cv::dilate(roiSegMap, roiSegMap, kernel); - - // Find minimal area rect - std::vector> contours; - cv::findContours(segMap, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE); - if (!contours.empty()) { - cv::RotatedRect minRect = cv::minAreaRect(contours[0]); - cv::Point2f vertices[4]; - minRect.points(vertices); - NSMutableArray *pointsArray = [NSMutableArray arrayWithCapacity:4]; - for (int j = 0; j < 4; j++) { - CGPoint point = CGPointMake(vertices[j].x, vertices[j].y); - [pointsArray addObject:[NSValue valueWithCGPoint:point]]; - } - NSDictionary *dict = @{@"box": pointsArray, @"angle": @(minRect.angle)}; - [detectedBoxes addObject:dict]; - } - } - - return detectedBoxes; -} - -+ (CGFloat)minimumYFromBox:(NSArray *)box { - __block CGFloat minY = CGFLOAT_MAX; - [box enumerateObjectsUsingBlock:^(NSValue * _Nonnull obj, NSUInteger idx, BOOL * _Nonnull stop) { - CGPoint pt = [obj CGPointValue]; - if (pt.y < minY) { - minY = pt.y; - } - }]; - return minY; -} - -+ (NSArray *)groupTextBoxes:(NSArray *)polys +/** + * This method processes an array of text box dictionaries, each containing details about individual text boxes, + * and attempts to group and merge these boxes based on specified criteria including proximity, alignment, + * and size thresholds. It prioritizes merging of boxes that are aligned closely in angle, are near each other, + * and whose sizes are compatible based on the given thresholds. + * + * @param boxes An array of NSDictionary objects where each dictionary represents a text box. 
Each dictionary must have + * at least a "bbox" key with an NSArray of NSValue wrapping CGPoints defining the box vertices, + * and an "angle" key indicating the orientation of the box. + * @param centerThreshold A CGFloat representing the threshold for considering the distance between center and fitted line. + * @param distanceThreshold A CGFloat that defines the maximum allowed distance between boxes for them to be considered for merging. + * @param heightThreshold A CGFloat representing the maximum allowed difference in height between boxes for merging. + * @param minSideThreshold An int that defines the minimum dimension threshold to filter out small boxes after grouping. + * @param maxSideThreshold An int that specifies the maximum dimension threshold for filtering boxes post-grouping. + * @param maxWidth An int that represents the maximum width allowable for a merged box. + * + * @return An NSArray of NSDictionary objects representing the merged boxes. Each dictionary contains: + * - "bbox": An NSArray of NSValue each containing a CGPoint that defines the vertices of the merged box. + * - "angle": NSNumber representing the computed orientation of the merged box. + * + * Processing Steps: + * 1. Sort initial boxes based on their maximum side length. + * 2. Sequentially merge boxes considering alignment, proximity, and size compatibility. + * 3. Post-processing to remove any boxes that are too small or exceed max side criteria. + * 4. Sort the final array of boxes by their vertical positions. 
+ */ ++ (NSArray *)groupTextBoxes:(NSMutableArray *)boxes centerThreshold:(CGFloat)centerThreshold distanceThreshold:(CGFloat)distanceThreshold heightThreshold:(CGFloat)heightThreshold + minSideThreshold:(int)minSideThreshold + maxSideThreshold:(int)maxSideThreshold + maxWidth:(int)maxWidth { - // Sort polys by max side length in descending order - NSMutableArray *sortedPolys = [polys sortedArrayUsingComparator:^NSComparisonResult(NSDictionary *obj1, NSDictionary *obj2) { - CGFloat maxLen1 = [self maxSideLength:obj1[@"box"]]; - CGFloat maxLen2 = [self maxSideLength:obj2[@"box"]]; + // Sort boxes based on their maximum side length + boxes = [boxes sortedArrayUsingComparator:^NSComparisonResult(NSDictionary *obj1, NSDictionary *obj2) { + const CGFloat maxLen1 = [self maxSideLength:obj1[@"bbox"]]; + const CGFloat maxLen2 = [self maxSideLength:obj2[@"bbox"]]; return (maxLen1 < maxLen2) ? NSOrderedDescending : (maxLen1 > maxLen2) ? NSOrderedAscending : NSOrderedSame; }].mutableCopy; - NSMutableArray *mergedList = [NSMutableArray array]; - CGFloat angleDegrees; - while (sortedPolys.count > 0) { - NSMutableDictionary *currentBox = [sortedPolys[0] mutableCopy]; - [sortedPolys removeObjectAtIndex:0]; - CGFloat currentAngle = [self normalizeAngle:[currentBox[@"angle"] floatValue]]; + NSMutableArray *mergedArray = [NSMutableArray array]; + CGFloat lineAngle; + while (boxes.count > 0) { + NSMutableDictionary *currentBox = [boxes[0] mutableCopy]; + [boxes removeObjectAtIndex:0]; NSMutableArray *ignoredIdxs = [NSMutableArray array]; while (YES) { - NSDictionary *lineFit = [self fitLineToShortestSides:currentBox[@"box"]]; - NSLog(@"lineFit: %@", lineFit); - angleDegrees = atan([lineFit[@"slope"] floatValue]) * 180 / M_PI; - if ([lineFit[@"isVertical"] boolValue]){ - angleDegrees = -90; + //Find all aligned boxes and merge them until max_size is reached or no more boxes can be merged + NSDictionary *fittedLine = [self fitLineToShortestSides:currentBox[@"bbox"]]; + const CGFloat 
slope = [fittedLine[@"slope"] floatValue]; + const CGFloat intercept = [fittedLine[@"intercept"] floatValue]; + const BOOL isVertical = [fittedLine[@"isVertical"] boolValue]; + + lineAngle = atan(slope) * 180 / M_PI; + if (isVertical){ + lineAngle = -90; } - CGFloat mergedHeight = [self minSideLength:currentBox[@"box"]]; - NSDictionary *closestBoxInfo = [self findClosestBox:sortedPolys ignoredIdxs:[NSSet setWithArray:ignoredIdxs] currentBox:currentBox[@"box"] isVertical:[lineFit[@"isVertical"] boolValue] m:[lineFit[@"slope"] floatValue] c:[lineFit[@"intercept"] floatValue] centerThreshold:centerThreshold]; + + NSDictionary *closestBoxInfo = [self findClosestBox:boxes ignoredIdxs:[NSSet setWithArray:ignoredIdxs] currentBox:currentBox[@"bbox"] isVertical:isVertical m:slope c:intercept centerThreshold:centerThreshold]; if (closestBoxInfo == nil) break; - NSMutableDictionary *candidateBox = [closestBoxInfo[@"boxToMerge"] mutableCopy]; NSInteger candidateIdx = [closestBoxInfo[@"idx"] integerValue]; - CGFloat candidateHeight = [closestBoxInfo[@"boxHeight"] floatValue]; - if (([candidateBox[@"angle"] isEqual: @90] && ![lineFit[@"isVertical"] boolValue]) || ([candidateBox[@"angle"] isEqual: @0] && [lineFit[@"isVertical"] boolValue])) { - candidateBox[@"coords"] = [self rotateBox:candidateBox[@"coords"] withAngle:currentAngle]; + NSMutableDictionary *candidateBox = [boxes[candidateIdx] mutableCopy]; + const CGFloat candidateHeight = [closestBoxInfo[@"boxHeight"] floatValue]; + + if (([candidateBox[@"angle"] isEqual: @90] && !isVertical) || ([candidateBox[@"angle"] isEqual: @0] && isVertical)) { + candidateBox[@"bbox"] = [self rotateBox:candidateBox[@"bbox"] withAngle:[currentBox[@"angle"] floatValue]]; } - CGFloat minDistance = [self calculateMinimalDistance:candidateBox[@"coords"] corners2:currentBox[@"box"]]; - if (minDistance < distanceThreshold * candidateHeight && fabs(mergedHeight - candidateHeight) < mergedHeight * heightThreshold) { - currentBox[@"box"] = [self 
mergeRotatedBoxes:currentBox[@"box"] withBox:candidateBox[@"coords"]]; - [sortedPolys removeObjectAtIndex:candidateIdx]; - [ignoredIdxs removeAllObjects]; // Restart with new merged box + + const CGFloat minDistance = [self calculateMinimalDistanceBetweenBox:candidateBox[@"bbox"] andBox:currentBox[@"bbox"]]; + const CGFloat mergedHeight = [self minSideLength:currentBox[@"bbox"]]; + if (minDistance < distanceThreshold * candidateHeight && fabs(mergedHeight - candidateHeight) < candidateHeight * heightThreshold) { + currentBox[@"bbox"] = [self mergeRotatedBoxes:currentBox[@"bbox"] withBox:candidateBox[@"bbox"]]; + [boxes removeObjectAtIndex:candidateIdx]; + [ignoredIdxs removeAllObjects]; + if ([self maxSideLength:currentBox[@"bbox"]] > maxWidth){ + break; + } } else { [ignoredIdxs addObject:@(candidateIdx)]; } } - [mergedList addObject:@{@"box" : currentBox[@"box"], @"angle" : @(angleDegrees)}]; + + [mergedArray addObject:@{@"bbox" : currentBox[@"bbox"], @"angle" : @(lineAngle)}]; } - // Optionally sort by angle if needed - NSArray *sortedBoxes = [mergedList sortedArrayUsingComparator:^NSComparisonResult(NSDictionary *obj1, NSDictionary *obj2) { - NSArray *coords1 = obj1[@"box"]; - NSArray *coords2 = obj2[@"box"]; - CGFloat minY1 = [self minimumYFromBox:coords1]; - CGFloat minY2 = [self minimumYFromBox:coords2]; + // Remove small boxes and sort by vertical + mergedArray = [self removeSmallBoxesFromArray:mergedArray usingMinSideThreshold:minSideThreshold maxSideThreshold:maxSideThreshold]; + + NSArray *sortedBoxes = [mergedArray sortedArrayUsingComparator:^NSComparisonResult(NSDictionary *obj1, NSDictionary *obj2) { + NSArray *coords1 = obj1[@"bbox"]; + NSArray *coords2 = obj2[@"bbox"]; + const CGFloat minY1 = [self minimumYFromBox:coords1]; + const CGFloat minY2 = [self minimumYFromBox:coords2]; return (minY1 < minY2) ? NSOrderedAscending : (minY1 > minY2) ? 
NSOrderedDescending : NSOrderedSame; }]; diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm index 3e6adccde8..2154a1923b 100644 --- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm @@ -37,7 +37,7 @@ + (CGFloat)calculateRatio:(int)width height:(int)height { contrast = (high - low) / 255.0; if (contrast < target) { - double ratio = 200.0 / MAX(10, high - low); + const double ratio = 200.0 / MAX(10, high - low); img.convertTo(img, CV_32F); img = ((img - low + 25) * ratio); @@ -52,7 +52,8 @@ + (CGFloat)calculateRatio:(int)width height:(int)height { + (cv::Mat)normalizeForRecognizer:(cv::Mat)image adjustContrast:(double)adjustContrast { if (adjustContrast > 0) { - image = [self adjustContrastGrey:image target:adjustContrast]; } + image = [self adjustContrastGrey:image target:adjustContrast]; + } int desiredWidth = 128; if (image.cols >= 512) { @@ -73,7 +74,7 @@ + (CGFloat)calculateRatio:(int)width height:(int)height { cv::Mat result = matrix.clone(); for (int i = 0; i < matrix.rows; i++) { - float divisor = [vector[i] floatValue]; + const float divisor = [vector[i] floatValue]; for (int j = 0; j < matrix.cols; j++) { result.at(i, j) /= divisor; } @@ -103,8 +104,8 @@ + (NSDictionary *)calculateResizeRatioAndPaddings:(int)width height:(int)height const int deltaH = desiredHeight - newHeight; const int top = deltaH / 2; const int left = deltaW / 2; - float heightRatio = (float)height / desiredHeight; - float widthRatio = (float)width / desiredWidth; + const float heightRatio = (float)height / desiredHeight; + const float widthRatio = (float)width / desiredWidth; resizeRatio = MAX(heightRatio, widthRatio); @@ -116,51 +117,40 @@ + (NSDictionary *)calculateResizeRatioAndPaddings:(int)width height:(int)height } + (cv::Mat)getCroppedImage:(NSDictionary *)box image:(cv::Mat)image modelHeight:(int)modelHeight { - // Convert NSValue array to 
cv::Point2f vector - NSArray *coords = box[@"box"]; - CGFloat angle = [box[@"angle"] floatValue]; + NSArray *coords = box[@"bbox"]; + const CGFloat angle = [box[@"angle"] floatValue]; std::vector points; for (NSValue *value in coords) { - CGPoint point = [value CGPointValue]; + const CGPoint point = [value CGPointValue]; points.emplace_back(static_cast(point.x), static_cast(point.y)); } - // Obtain the rotated rectangle from the points cv::RotatedRect rotatedRect = cv::minAreaRect(points); - - // Compute the rotation matrix for the angle of the rotated rectangle + cv::Point2f imageCenter = cv::Point2f(image.cols / 2.0, image.rows / 2.0); cv::Mat rotationMatrix = cv::getRotationMatrix2D(imageCenter, angle, 1.0); - - // Rotate the entire image cv::Mat rotatedImage; cv::warpAffine(image, rotatedImage, rotationMatrix, image.size(), cv::INTER_LINEAR); - - // Get vertices of the minimal rotated rectangle cv::Point2f rectPoints[4]; rotatedRect.points(rectPoints); - - // Transform points using the rotation matrix std::vector transformedPoints(4); cv::Mat rectMat(4, 2, CV_32FC2, rectPoints); cv::transform(rectMat, rectMat, rotationMatrix); - // Convert points back to array of points to compute bounding box afterward for (int i = 0; i < 4; ++i) { transformedPoints[i] = rectPoints[i]; } - // Compute the bounding box of transformed points cv::Rect boundingBox = cv::boundingRect(transformedPoints); - - // Make sure the bounding box fits within the image boundingBox &= cv::Rect(0, 0, rotatedImage.cols, rotatedImage.rows); - - // Optional: Crop to this bounding box if necessary cv::Mat croppedImage = rotatedImage(boundingBox); + if (boundingBox.width == 0 || boundingBox.height == 0){ + croppedImage = cv::Mat().empty(); + + return croppedImage; + } - // If specified to resize according to modelHeight, adjust the aspect ratio and resize croppedImage = [self computeRatioAndResize:croppedImage width:boundingBox.width height:boundingBox.height modelHeight:modelHeight]; return 
croppedImage; diff --git a/src/index.tsx b/src/index.tsx index ec7dd8c128..066e14df4f 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -21,6 +21,7 @@ export * from './utils/listDownloadedResources'; // types export * from './types/object_detection'; +export * from './types/ocr'; // constants export * from './constants/modelUrls'; diff --git a/src/types/ocr.ts b/src/types/ocr.ts index 697dbcaaa9..f5f2e6d35e 100644 --- a/src/types/ocr.ts +++ b/src/types/ocr.ts @@ -1,7 +1,10 @@ -import { Bbox } from './object_detection'; - export interface OCRDetection { - bbox: Bbox; + bbox: OCRBbox[]; text: string; score: number; } + +export interface OCRBbox { + x: number; + y: number; +} From 70ced9bb99263a7147f2ad88f7391964f33ad0bf Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 6 Feb 2025 16:42:27 +0100 Subject: [PATCH 14/19] fix: add missing angle normalization --- examples/computer-vision/screens/OCRScreen.tsx | 1 - ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm | 3 ++- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/computer-vision/screens/OCRScreen.tsx b/examples/computer-vision/screens/OCRScreen.tsx index ebcdf32eba..9d17118afb 100644 --- a/examples/computer-vision/screens/OCRScreen.tsx +++ b/examples/computer-vision/screens/OCRScreen.tsx @@ -47,7 +47,6 @@ export const OCRScreen = ({ }; const runForward = async () => { - console.log('RUnning forward'); try { const output = await model.forward(imageUri); setResults(output); diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm index d79d654c35..c2dd3fc20f 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm @@ -486,6 +486,7 @@ + (NSDictionary *)findClosestBox:(NSArray *)boxes CGFloat lineAngle; while (boxes.count > 0) { NSMutableDictionary *currentBox = [boxes[0] mutableCopy]; + CGFloat normalizedAngle = [self normalizeAngle:[currentBox[@"angle"] 
floatValue]]; [boxes removeObjectAtIndex:0]; NSMutableArray *ignoredIdxs = [NSMutableArray array]; @@ -509,7 +510,7 @@ + (NSDictionary *)findClosestBox:(NSArray *)boxes const CGFloat candidateHeight = [closestBoxInfo[@"boxHeight"] floatValue]; if (([candidateBox[@"angle"] isEqual: @90] && !isVertical) || ([candidateBox[@"angle"] isEqual: @0] && isVertical)) { - candidateBox[@"bbox"] = [self rotateBox:candidateBox[@"bbox"] withAngle:[currentBox[@"angle"] floatValue]]; + candidateBox[@"bbox"] = [self rotateBox:candidateBox[@"bbox"] withAngle:normalizedAngle]; } const CGFloat minDistance = [self calculateMinimalDistanceBetweenBox:candidateBox[@"bbox"] andBox:currentBox[@"bbox"]]; From 659295f134996e6b109c147420c032ec11945514 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 10 Feb 2025 16:12:44 +0100 Subject: [PATCH 15/19] reformat: fix formatting of long lines --- ios/RnExecutorch/models/ocr/Detector.mm | 10 +++++----- ios/RnExecutorch/models/ocr/RecognitionHandler.mm | 4 ++-- ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm | 10 +++++----- ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm | 2 +- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/ios/RnExecutorch/models/ocr/Detector.mm b/ios/RnExecutorch/models/ocr/Detector.mm index 355e0b337a..411c178d8b 100644 --- a/ios/RnExecutorch/models/ocr/Detector.mm +++ b/ios/RnExecutorch/models/ocr/Detector.mm @@ -40,7 +40,6 @@ - (NSArray *)preprocess:(cv::Mat &)input { cv::Size modelImageSize = [self getModelImageSize]; cv::Mat resizedImage; resizedImage = [OCRUtils resizeWithPadding:input desiredWidth:modelImageSize.width desiredHeight:modelImageSize.height]; - NSArray *modelInput = [ImageProcessor matToNSArray: resizedImage mean:mean variance:variance]; return modelInput; } @@ -60,13 +59,14 @@ The output of the model consists of two matrices (heat maps): cv::Mat scoreTextCV, scoreAffinityCV; /* The output of the model is a matrix in size of input image containing two matrices representing 
heatmap. - Those two matrices are in the size of half of the input image, that's why the width and height is divided by 2. + Those two matrices are in the size of half of the input image, that's why the width and height is divided by 2. */ [DetectorUtils interleavedArrayToMats:predictions - outputMat1:scoreTextCV - outputMat2:scoreAffinityCV - withSize:cv::Size(modelImageSize.width / 2, modelImageSize.height / 2)]; + outputMat1:scoreTextCV + outputMat2:scoreAffinityCV + withSize:cv::Size(modelImageSize.width / 2, modelImageSize.height / 2)]; NSArray* bBoxesList = [DetectorUtils getDetBoxesFromTextMap:scoreTextCV affinityMap:scoreAffinityCV usingTextThreshold:textThreshold linkThreshold:linkThreshold lowTextThreshold:lowTextThreshold]; + NSLog(@"Detected boxes: %lu", (unsigned long)bBoxesList.count); bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList usingRestoreRatio: restoreRatio]; bBoxesList = [DetectorUtils groupTextBoxes:bBoxesList centerThreshold:centerThreshold distanceThreshold:distanceThreshold heightThreshold:heightThreshold minSideThreshold:minSideThreshold maxSideThreshold:maxSideThreshold maxWidth:maxWidth]; diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm index b1d8559995..50e303df0e 100644 --- a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm @@ -90,7 +90,7 @@ - (NSArray *)recognize: (NSArray *)bBoxesList imgGray:(cv::Mat)i } croppedImage = [RecognizerUtils normalizeForRecognizer:croppedImage adjustContrast:adjustContrast]; NSArray *result = [self runModel: croppedImage]; - + NSNumber *confidenceScore = [result objectAtIndex:1]; if([confidenceScore floatValue] < lowConfidenceThreshold){ @@ -121,4 +121,4 @@ - (NSArray *)recognize: (NSArray *)bBoxesList imgGray:(cv::Mat)i return predictions; } -@end \ No newline at end of file +@end diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm 
b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm index c2dd3fc20f..5e49f1f0a9 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm @@ -429,8 +429,8 @@ + (NSDictionary *)findClosestBox:(NSArray *)boxes boxHeight = [self minSideLength:bbox]; const CGFloat lineDistance = (isVertical ? - fabs(centerOfProcessedBox.x - (m * centerOfProcessedBox.y + c)) : - fabs(centerOfProcessedBox.y - (m * centerOfProcessedBox.x + c))); + fabs(centerOfProcessedBox.x - (m * centerOfProcessedBox.y + c)) : + fabs(centerOfProcessedBox.y - (m * centerOfProcessedBox.x + c))); if (lineDistance < boxHeight * centerThreshold) { idx = i; @@ -471,9 +471,9 @@ + (NSDictionary *)findClosestBox:(NSArray *)boxes centerThreshold:(CGFloat)centerThreshold distanceThreshold:(CGFloat)distanceThreshold heightThreshold:(CGFloat)heightThreshold - minSideThreshold:(int)minSideThreshold - maxSideThreshold:(int)maxSideThreshold - maxWidth:(int)maxWidth + minSideThreshold:(int)minSideThreshold + maxSideThreshold:(int)maxSideThreshold + maxWidth:(int)maxWidth { // Sort boxes based on their maximum side length boxes = [boxes sortedArrayUsingComparator:^NSComparisonResult(NSDictionary *obj1, NSDictionary *obj2) { diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm index 2154a1923b..74048e208c 100644 --- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm @@ -127,7 +127,7 @@ + (NSDictionary *)calculateResizeRatioAndPaddings:(int)width height:(int)height } cv::RotatedRect rotatedRect = cv::minAreaRect(points); - + cv::Point2f imageCenter = cv::Point2f(image.cols / 2.0, image.rows / 2.0); cv::Mat rotationMatrix = cv::getRotationMatrix2D(imageCenter, angle, 1.0); cv::Mat rotatedImage; From 83580dfdfa9c6170b0b19dad2931b92c5500c43f Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Thu, 13 Feb 2025 15:57:19 +0100 
Subject: [PATCH 16/19] format: format with clang format --- ios/RnExecutorch/OCR.mm | 127 +++-- ios/RnExecutorch/models/ocr/Detector.h | 2 +- ios/RnExecutorch/models/ocr/Detector.mm | 66 ++- .../models/ocr/RecognitionHandler.h | 13 +- .../models/ocr/RecognitionHandler.mm | 128 +++-- ios/RnExecutorch/models/ocr/Recognizer.h | 2 +- ios/RnExecutorch/models/ocr/Recognizer.mm | 40 +- .../models/ocr/utils/CTCLabelConverter.h | 10 +- .../models/ocr/utils/CTCLabelConverter.mm | 52 +- .../models/ocr/utils/DetectorUtils.h | 9 +- .../models/ocr/utils/DetectorUtils.mm | 466 +++++++++++------- ios/RnExecutorch/models/ocr/utils/OCRUtils.h | 4 +- ios/RnExecutorch/models/ocr/utils/OCRUtils.mm | 38 +- .../models/ocr/utils/RecognizerUtils.h | 23 +- .../models/ocr/utils/RecognizerUtils.mm | 103 ++-- 15 files changed, 672 insertions(+), 411 deletions(-) diff --git a/ios/RnExecutorch/OCR.mm b/ios/RnExecutorch/OCR.mm index cd58d2c4b8..975c82989b 100644 --- a/ios/RnExecutorch/OCR.mm +++ b/ios/RnExecutorch/OCR.mm @@ -1,10 +1,10 @@ -#import -#import #import "OCR.h" -#import "utils/Fetcher.h" -#import "utils/ImageProcessor.h" #import "models/ocr/Detector.h" #import "models/ocr/RecognitionHandler.h" +#import "utils/Fetcher.h" +#import "utils/ImageProcessor.h" +#import +#import @implementation OCR { Detector *detector; @@ -14,41 +14,68 @@ @implementation OCR { RCT_EXPORT_MODULE() - (void)loadModule:(NSString *)detectorSource -recognizerSourceLarge:(NSString *)recognizerSourceLarge -recognizerSourceMedium:(NSString *)recognizerSourceMedium -recognizerSourceSmall:(NSString *)recognizerSourceSmall - symbols:(NSString *)symbols - languageDictPath:(NSString *)languageDictPath - resolve:(RCTPromiseResolveBlock)resolve - reject:(RCTPromiseRejectBlock)reject { + recognizerSourceLarge:(NSString *)recognizerSourceLarge + recognizerSourceMedium:(NSString *)recognizerSourceMedium + recognizerSourceSmall:(NSString *)recognizerSourceSmall + symbols:(NSString *)symbols + languageDictPath:(NSString 
*)languageDictPath + resolve:(RCTPromiseResolveBlock)resolve + reject:(RCTPromiseRejectBlock)reject { detector = [[Detector alloc] init]; - [detector loadModel:[NSURL URLWithString:detectorSource] completion:^(BOOL success, NSNumber *errorCode) { - if (!success) { - NSError *error = [NSError errorWithDomain:@"OCRErrorDomain" - code:[errorCode intValue] - userInfo:@{NSLocalizedDescriptionKey: [NSString stringWithFormat:@"%ld", (long)[errorCode longValue]]}]; - reject(@"init_module_error", @"Failed to initialize detector module", error); - return; - } - [Fetcher fetchResource:[NSURL URLWithString:languageDictPath] resourceType:ResourceType::TXT completionHandler:^(NSString *filePath, NSError *error) { - if (error) { - reject(@"init_module_error", @"Failed to initialize converter module", error); - return; - } - - self->recognitionHandler = [[RecognitionHandler alloc] initWithSymbols:symbols languageDictPath:filePath]; - [self->recognitionHandler loadRecognizers:recognizerSourceLarge mediumRecognizerPath:recognizerSourceMedium smallRecognizerPath:recognizerSourceSmall completion:^(BOOL allModelsLoaded, NSNumber *errorCode) { - if (allModelsLoaded) { - resolve(@(YES)); - } else { - NSError *error = [NSError errorWithDomain:@"OCRErrorDomain" - code:[errorCode intValue] - userInfo:@{NSLocalizedDescriptionKey: [NSString stringWithFormat:@"%ld", (long)[errorCode longValue]]}]; - reject(@"init_recognizer_error", @"Failed to initialize one or more recognizer models", error); + [detector + loadModel:[NSURL URLWithString:detectorSource] + completion:^(BOOL success, NSNumber *errorCode) { + if (!success) { + NSError *error = [NSError + errorWithDomain:@"OCRErrorDomain" + code:[errorCode intValue] + userInfo:@{ + NSLocalizedDescriptionKey : [NSString + stringWithFormat:@"%ld", (long)[errorCode longValue]] + }]; + reject(@"init_module_error", @"Failed to initialize detector module", + error); + return; } + [Fetcher fetchResource:[NSURL URLWithString:languageDictPath] + 
resourceType:ResourceType::TXT + completionHandler:^(NSString *filePath, NSError *error) { + if (error) { + reject(@"init_module_error", + @"Failed to initialize converter module", error); + return; + } + + self->recognitionHandler = + [[RecognitionHandler alloc] initWithSymbols:symbols + languageDictPath:filePath]; + [self->recognitionHandler + loadRecognizers:recognizerSourceLarge + mediumRecognizerPath:recognizerSourceMedium + smallRecognizerPath:recognizerSourceSmall + completion:^(BOOL allModelsLoaded, + NSNumber *errorCode) { + if (allModelsLoaded) { + resolve(@(YES)); + } else { + NSError *error = [NSError + errorWithDomain:@"OCRErrorDomain" + code:[errorCode intValue] + userInfo:@{ + NSLocalizedDescriptionKey : + [NSString stringWithFormat: + @"%ld", + (long)[errorCode + longValue]] + }]; + reject(@"init_recognizer_error", + @"Failed to initialize one or more " + @"recognizer models", + error); + } + }]; + }]; }]; - }]; - }]; } - (void)forward:(NSString *)input @@ -56,28 +83,34 @@ - (void)forward:(NSString *)input reject:(RCTPromiseRejectBlock)reject { /* The OCR consists of two phases: - 1. Detection - detecting text regions in the image, the result of this phase is a list of bounding boxes. - 2. Recognition - recognizing the text in the bounding boxes, the result is a list of strings and corresponding confidence scores. - - Recognition uses three models, each model is resposible for recognizing text of different sizes (e.g. large - 512x64, medium - 256x64, small - 128x64). + 1. Detection - detecting text regions in the image, the result of this phase + is a list of bounding boxes. + 2. Recognition - recognizing the text in the bounding boxes, the result is a + list of strings and corresponding confidence scores. + + Recognition uses three models, each model is resposible for recognizing text + of different sizes (e.g. large - 512x64, medium - 256x64, small - 128x64). 
*/ @try { cv::Mat image = [ImageProcessor readImage:input]; - NSArray* result = [detector runModel:image]; + NSArray *result = [detector runModel:image]; cv::Size detectorSize = [detector getModelImageSize]; cv::cvtColor(image, image, cv::COLOR_BGR2GRAY); - result = [self->recognitionHandler recognize:result imgGray:image desiredWidth:detectorSize.width * recognizerRatio desiredHeight:detectorSize.height * recognizerRatio]; + result = [self->recognitionHandler + recognize:result + imgGray:image + desiredWidth:detectorSize.width * recognizerRatio + desiredHeight:detectorSize.height * recognizerRatio]; resolve(result); } @catch (NSException *exception) { - reject(@"forward_error", [NSString stringWithFormat:@"%@", exception.reason], - nil); + reject(@"forward_error", + [NSString stringWithFormat:@"%@", exception.reason], nil); } } - (std::shared_ptr)getTurboModule: -(const facebook::react::ObjCTurboModule::InitParams &)params { - return std::make_shared( - params); + (const facebook::react::ObjCTurboModule::InitParams &)params { + return std::make_shared(params); } @end diff --git a/ios/RnExecutorch/models/ocr/Detector.h b/ios/RnExecutorch/models/ocr/Detector.h index 346069720a..0f67e93b84 100644 --- a/ios/RnExecutorch/models/ocr/Detector.h +++ b/ios/RnExecutorch/models/ocr/Detector.h @@ -1,6 +1,6 @@ -#import "opencv2/opencv.hpp" #import "BaseModel.h" #import "RecognitionHandler.h" +#import "opencv2/opencv.hpp" constexpr CGFloat textThreshold = 0.4; constexpr CGFloat linkThreshold = 0.4; diff --git a/ios/RnExecutorch/models/ocr/Detector.mm b/ios/RnExecutorch/models/ocr/Detector.mm index 411c178d8b..56604ac607 100644 --- a/ios/RnExecutorch/models/ocr/Detector.mm +++ b/ios/RnExecutorch/models/ocr/Detector.mm @@ -4,8 +4,8 @@ #import "utils/OCRUtils.h" /* - The model used as detector is based on CRAFT (Character Region Awareness for Text Detection) paper. 
- https://arxiv.org/pdf/1904.01941 + The model used as detector is based on CRAFT (Character Region Awareness for + Text Detection) paper. https://arxiv.org/pdf/1904.01941 */ @implementation Detector { @@ -13,34 +13,38 @@ @implementation Detector { cv::Size modelSize; } -- (cv::Size)getModelImageSize{ - if(!modelSize.empty()) { +- (cv::Size)getModelImageSize { + if (!modelSize.empty()) { return modelSize; } - - NSArray *inputShape = [module getInputShape: @0]; + + NSArray *inputShape = [module getInputShape:@0]; NSNumber *widthNumber = inputShape.lastObject; NSNumber *heightNumber = inputShape[inputShape.count - 2]; - + const int height = [heightNumber intValue]; const int width = [widthNumber intValue]; modelSize = cv::Size(height, width); - + return cv::Size(height, width); } - (NSArray *)preprocess:(cv::Mat &)input { /* Detector as an input accepts tensor with a shape of [1, 3, 800, 800]. - Due to big influence of resize to quality of recognition the image preserves original - aspect ratio and the missing parts are filled with padding. + Due to big influence of resize to quality of recognition the image preserves + original aspect ratio and the missing parts are filled with padding. */ self->originalSize = cv::Size(input.cols, input.rows); - + cv::Size modelImageSize = [self getModelImageSize]; cv::Mat resizedImage; - resizedImage = [OCRUtils resizeWithPadding:input desiredWidth:modelImageSize.width desiredHeight:modelImageSize.height]; - NSArray *modelInput = [ImageProcessor matToNSArray: resizedImage mean:mean variance:variance]; + resizedImage = [OCRUtils resizeWithPadding:input + desiredWidth:modelImageSize.width + desiredHeight:modelImageSize.height]; + NSArray *modelInput = [ImageProcessor matToNSArray:resizedImage + mean:mean + variance:variance]; return modelInput; } @@ -48,28 +52,42 @@ - (NSArray *)postprocess:(NSArray *)output { /* The output of the model consists of two matrices (heat maps): 1. 
ScoreText(Score map) - The probability of a region containing character - 2. ScoreAffinity(Affinity map) - affinity between characters, used to to group each character into a single instance (sequence) - Both matrices are 400x400 - + 2. ScoreAffinity(Affinity map) - affinity between characters, used to to + group each character into a single instance (sequence) Both matrices are + 400x400 + The result of this step is a list of bounding boxes that contain text. */ NSArray *predictions = [output objectAtIndex:0]; - + cv::Size modelImageSize = [self getModelImageSize]; cv::Mat scoreTextCV, scoreAffinityCV; /* - The output of the model is a matrix in size of input image containing two matrices representing heatmap. - Those two matrices are in the size of half of the input image, that's why the width and height is divided by 2. + The output of the model is a matrix in size of input image containing two + matrices representing heatmap. Those two matrices are in the size of half of + the input image, that's why the width and height is divided by 2. 
*/ [DetectorUtils interleavedArrayToMats:predictions outputMat1:scoreTextCV outputMat2:scoreAffinityCV - withSize:cv::Size(modelImageSize.width / 2, modelImageSize.height / 2)]; - NSArray* bBoxesList = [DetectorUtils getDetBoxesFromTextMap:scoreTextCV affinityMap:scoreAffinityCV usingTextThreshold:textThreshold linkThreshold:linkThreshold lowTextThreshold:lowTextThreshold]; + withSize:cv::Size(modelImageSize.width / 2, + modelImageSize.height / 2)]; + NSArray *bBoxesList = [DetectorUtils getDetBoxesFromTextMap:scoreTextCV + affinityMap:scoreAffinityCV + usingTextThreshold:textThreshold + linkThreshold:linkThreshold + lowTextThreshold:lowTextThreshold]; NSLog(@"Detected boxes: %lu", (unsigned long)bBoxesList.count); - bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList usingRestoreRatio: restoreRatio]; - bBoxesList = [DetectorUtils groupTextBoxes:bBoxesList centerThreshold:centerThreshold distanceThreshold:distanceThreshold heightThreshold:heightThreshold minSideThreshold:minSideThreshold maxSideThreshold:maxSideThreshold maxWidth:maxWidth]; - + bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList + usingRestoreRatio:restoreRatio]; + bBoxesList = [DetectorUtils groupTextBoxes:bBoxesList + centerThreshold:centerThreshold + distanceThreshold:distanceThreshold + heightThreshold:heightThreshold + minSideThreshold:minSideThreshold + maxSideThreshold:maxSideThreshold + maxWidth:maxWidth]; + return bBoxesList; } diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.h b/ios/RnExecutorch/models/ocr/RecognitionHandler.h index 72ec004ff1..b638eff0e6 100644 --- a/ios/RnExecutorch/models/ocr/RecognitionHandler.h +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.h @@ -9,8 +9,15 @@ constexpr CGFloat adjustContrast = 0.2; @interface RecognitionHandler : NSObject -- (instancetype)initWithSymbols:(NSString *)symbols languageDictPath:(NSString *)languageDictPath; -- (void)loadRecognizers:(NSString *)largeRecognizerPath mediumRecognizerPath:(NSString 
*)mediumRecognizerPath smallRecognizerPath:(NSString *)smallRecognizerPath completion:(void (^)(BOOL, NSNumber *))completion; -- (NSArray *)recognize:(NSArray *)bBoxesList imgGray:(cv::Mat)imgGray desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight; +- (instancetype)initWithSymbols:(NSString *)symbols + languageDictPath:(NSString *)languageDictPath; +- (void)loadRecognizers:(NSString *)largeRecognizerPath + mediumRecognizerPath:(NSString *)mediumRecognizerPath + smallRecognizerPath:(NSString *)smallRecognizerPath + completion:(void (^)(BOOL, NSNumber *))completion; +- (NSArray *)recognize:(NSArray *)bBoxesList + imgGray:(cv::Mat)imgGray + desiredWidth:(int)desiredWidth + desiredHeight:(int)desiredHeight; @end diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm index 50e303df0e..57cc419a58 100644 --- a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm @@ -1,16 +1,17 @@ -#import -#import "ExecutorchLib/ETModel.h" +#import "RecognitionHandler.h" #import "../../utils/Fetcher.h" #import "../../utils/ImageProcessor.h" #import "./utils/CTCLabelConverter.h" #import "./utils/OCRUtils.h" #import "./utils/RecognizerUtils.h" +#import "ExecutorchLib/ETModel.h" #import "Recognizer.h" -#import "RecognitionHandler.h" +#import /* - RecognitionHandler class is responsible for loading and choosing the appropriate recognizer model based on the input image size, - it also handles converting the model output to text. + RecognitionHandler class is responsible for loading and choosing the + appropriate recognizer model based on the input image size, it also handles + converting the model output to text. 
*/ @implementation RecognitionHandler { @@ -20,41 +21,51 @@ @implementation RecognitionHandler { CTCLabelConverter *converter; } -- (instancetype)initWithSymbols:(NSString *)symbols languageDictPath:(NSString *)languageDictPath { +- (instancetype)initWithSymbols:(NSString *)symbols + languageDictPath:(NSString *)languageDictPath { self = [super init]; if (self) { recognizerLarge = [[Recognizer alloc] init]; recognizerMedium = [[Recognizer alloc] init]; recognizerSmall = [[Recognizer alloc] init]; - - converter = [[CTCLabelConverter alloc] initWithCharacters:symbols separatorList:@{} dictPathList:@{@"key": languageDictPath}]; + + converter = [[CTCLabelConverter alloc] + initWithCharacters:symbols + separatorList:@{} + dictPathList:@{@"key" : languageDictPath}]; } return self; } -- (void)loadRecognizers:(NSString *)largeRecognizerPath mediumRecognizerPath:(NSString *)mediumRecognizerPath smallRecognizerPath:(NSString *)smallRecognizerPath completion:(void (^)(BOOL, NSNumber *))completion { +- (void)loadRecognizers:(NSString *)largeRecognizerPath + mediumRecognizerPath:(NSString *)mediumRecognizerPath + smallRecognizerPath:(NSString *)smallRecognizerPath + completion:(void (^)(BOOL, NSNumber *))completion { dispatch_group_t group = dispatch_group_create(); __block BOOL allSuccessful = YES; - - NSArray *recognizers = @[recognizerLarge, recognizerMedium, recognizerSmall]; - NSArray *paths = @[largeRecognizerPath, mediumRecognizerPath, smallRecognizerPath]; - + + NSArray *recognizers = + @[ recognizerLarge, recognizerMedium, recognizerSmall ]; + NSArray *paths = + @[ largeRecognizerPath, mediumRecognizerPath, smallRecognizerPath ]; + for (NSInteger i = 0; i < recognizers.count; i++) { Recognizer *recognizer = recognizers[i]; NSString *path = paths[i]; - + dispatch_group_enter(group); - [recognizer loadModel:[NSURL URLWithString: path] completion:^(BOOL success, NSNumber *errorCode) { - if (!success) { - allSuccessful = NO; - dispatch_group_leave(group); - completion(NO, 
errorCode); - return; - } - dispatch_group_leave(group); - }]; + [recognizer loadModel:[NSURL URLWithString:path] + completion:^(BOOL success, NSNumber *errorCode) { + if (!success) { + allSuccessful = NO; + dispatch_group_leave(group); + completion(NO, errorCode); + return; + } + dispatch_group_leave(group); + }]; } - + dispatch_group_notify(group, dispatch_get_main_queue(), ^{ if (allSuccessful) { completion(YES, @(0)); @@ -64,60 +75,79 @@ - (void)loadRecognizers:(NSString *)largeRecognizerPath mediumRecognizerPath:(NS - (NSArray *)runModel:(cv::Mat)croppedImage { NSArray *result; - if(croppedImage.cols >= largeModelWidth) { + if (croppedImage.cols >= largeModelWidth) { result = [recognizerLarge runModel:croppedImage]; } else if (croppedImage.cols >= mediumModelWidth) { - result = [recognizerMedium runModel: croppedImage]; + result = [recognizerMedium runModel:croppedImage]; } else { - result = [recognizerSmall runModel: croppedImage]; + result = [recognizerSmall runModel:croppedImage]; } - + return result; } -- (NSArray *)recognize: (NSArray *)bBoxesList imgGray:(cv::Mat)imgGray desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight { - NSDictionary* ratioAndPadding = [RecognizerUtils calculateResizeRatioAndPaddings:imgGray.cols height:imgGray.rows desiredWidth:desiredWidth desiredHeight:desiredHeight]; +- (NSArray *)recognize:(NSArray *)bBoxesList + imgGray:(cv::Mat)imgGray + desiredWidth:(int)desiredWidth + desiredHeight:(int)desiredHeight { + NSDictionary *ratioAndPadding = + [RecognizerUtils calculateResizeRatioAndPaddings:imgGray.cols + height:imgGray.rows + desiredWidth:desiredWidth + desiredHeight:desiredHeight]; const int left = [ratioAndPadding[@"left"] intValue]; const int top = [ratioAndPadding[@"top"] intValue]; const CGFloat resizeRatio = [ratioAndPadding[@"resizeRatio"] floatValue]; - imgGray = [OCRUtils resizeWithPadding:imgGray desiredWidth:desiredWidth desiredHeight:desiredHeight]; - + imgGray = [OCRUtils resizeWithPadding:imgGray + 
desiredWidth:desiredWidth + desiredHeight:desiredHeight]; + NSMutableArray *predictions = [NSMutableArray array]; for (NSDictionary *box in bBoxesList) { - cv::Mat croppedImage = [RecognizerUtils getCroppedImage:box image:imgGray modelHeight:modelHeight]; + cv::Mat croppedImage = [RecognizerUtils getCroppedImage:box + image:imgGray + modelHeight:modelHeight]; if (croppedImage.empty()) { continue; } - croppedImage = [RecognizerUtils normalizeForRecognizer:croppedImage adjustContrast:adjustContrast]; - NSArray *result = [self runModel: croppedImage]; - - + croppedImage = [RecognizerUtils normalizeForRecognizer:croppedImage + adjustContrast:adjustContrast]; + NSArray *result = [self runModel:croppedImage]; + NSNumber *confidenceScore = [result objectAtIndex:1]; - if([confidenceScore floatValue] < lowConfidenceThreshold){ + if ([confidenceScore floatValue] < lowConfidenceThreshold) { cv::rotate(croppedImage, croppedImage, cv::ROTATE_180); - - NSArray *rotatedResult = [self runModel: croppedImage]; + + NSArray *rotatedResult = [self runModel:croppedImage]; NSNumber *rotatedConfidenceScore = [rotatedResult objectAtIndex:1]; - + if ([rotatedConfidenceScore floatValue] > [confidenceScore floatValue]) { result = rotatedResult; confidenceScore = rotatedConfidenceScore; } } - + NSArray *predIndex = [result objectAtIndex:0]; - NSArray* decodedTexts = [converter decodeGreedy:predIndex length:(int)(predIndex.count)]; - + NSArray *decodedTexts = [converter decodeGreedy:predIndex + length:(int)(predIndex.count)]; + NSMutableArray *bbox = [NSMutableArray arrayWithCapacity:4]; - for (NSValue *coords in box[@"bbox"]){ + for (NSValue *coords in box[@"bbox"]) { const CGPoint point = [coords CGPointValue]; - [bbox addObject: @{@"x": @((point.x - left) * resizeRatio), @"y": @((point.y - top) * resizeRatio)}]; + [bbox addObject:@{ + @"x" : @((point.x - left) * resizeRatio), + @"y" : @((point.y - top) * resizeRatio) + }]; } - - NSDictionary *res = @{@"text": decodedTexts[0], @"bbox": bbox, 
@"score": confidenceScore}; + + NSDictionary *res = @{ + @"text" : decodedTexts[0], + @"bbox" : bbox, + @"score" : confidenceScore + }; [predictions addObject:res]; } - + return predictions; } diff --git a/ios/RnExecutorch/models/ocr/Recognizer.h b/ios/RnExecutorch/models/ocr/Recognizer.h index 63047ac00a..4b301dbef7 100644 --- a/ios/RnExecutorch/models/ocr/Recognizer.h +++ b/ios/RnExecutorch/models/ocr/Recognizer.h @@ -1,5 +1,5 @@ -#import "opencv2/opencv.hpp" #import "BaseModel.h" +#import "opencv2/opencv.hpp" @interface Recognizer : BaseModel diff --git a/ios/RnExecutorch/models/ocr/Recognizer.mm b/ios/RnExecutorch/models/ocr/Recognizer.mm index a6d9f7137d..8b339bc238 100644 --- a/ios/RnExecutorch/models/ocr/Recognizer.mm +++ b/ios/RnExecutorch/models/ocr/Recognizer.mm @@ -1,6 +1,6 @@ #import "Recognizer.h" -#import "RecognizerUtils.h" #import "../../utils/ImageProcessor.h" +#import "RecognizerUtils.h" #import "utils/OCRUtils.h" /* @@ -12,21 +12,21 @@ @implementation Recognizer { cv::Size originalSize; } -- (cv::Size)getModelImageSize{ - NSArray *inputShape = [module getInputShape: @0]; +- (cv::Size)getModelImageSize { + NSArray *inputShape = [module getInputShape:@0]; NSNumber *widthNumber = inputShape.lastObject; NSNumber *heightNumber = inputShape[inputShape.count - 2]; - + const int height = [heightNumber intValue]; const int width = [widthNumber intValue]; return cv::Size(height, width); } -- (cv::Size)getModelOutputSize{ - NSArray *outputShape = [module getOutputShape: @0]; +- (cv::Size)getModelOutputSize { + NSArray *outputShape = [module getOutputShape:@0]; NSNumber *widthNumber = outputShape.lastObject; NSNumber *heightNumber = outputShape[outputShape.count - 2]; - + const int height = [heightNumber intValue]; const int width = [widthNumber intValue]; return cv::Size(height, width); @@ -47,26 +47,32 @@ - (NSArray *)postprocess:(NSArray *)output { resultMat.at(currentRow, counter) = [num floatValue]; counter++; if (counter >= modelOutputHeight) { - 
counter = 0; currentRow++; + counter = 0; + currentRow++; } } - + cv::Mat probabilities = [RecognizerUtils softmax:resultMat]; - NSMutableArray *predsNorm = [RecognizerUtils sumProbabilityRows:probabilities modelOutputHeight:modelOutputHeight]; - probabilities = [RecognizerUtils divideMatrix:probabilities byVector:predsNorm]; - NSArray *maxValuesIndices = [RecognizerUtils findMaxValuesAndIndices:probabilities]; - const CGFloat confidenceScore = [RecognizerUtils computeConfidenceScore:maxValuesIndices[0] indicesArray:maxValuesIndices[1]]; - - return @[maxValuesIndices[1], @(confidenceScore)]; + NSMutableArray *predsNorm = + [RecognizerUtils sumProbabilityRows:probabilities + modelOutputHeight:modelOutputHeight]; + probabilities = [RecognizerUtils divideMatrix:probabilities + byVector:predsNorm]; + NSArray *maxValuesIndices = + [RecognizerUtils findMaxValuesAndIndices:probabilities]; + const CGFloat confidenceScore = + [RecognizerUtils computeConfidenceScore:maxValuesIndices[0] + indicesArray:maxValuesIndices[1]]; + + return @[ maxValuesIndices[1], @(confidenceScore) ]; } - (NSArray *)runModel:(cv::Mat &)input { NSArray *modelInput = [self preprocess:input]; NSArray *modelResult = [self forward:modelInput]; NSArray *result = [self postprocess:modelResult]; - + return result; } - @end diff --git a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h index 037782f4bb..cae07be437 100644 --- a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h +++ b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h @@ -8,8 +8,12 @@ @property(strong, nonatomic) NSArray *ignoreIdx; @property(strong, nonatomic) NSDictionary *dictList; -- (instancetype)initWithCharacters:(NSString *)characters separatorList:(NSDictionary *)separatorList dictPathList:(NSDictionary *)dictPathList; -- (void)loadDictionariesWithDictPathList:(NSDictionary *)dictPathList; -- (NSArray *)decodeGreedy:(NSArray *)textIndex length:(NSInteger)length; +- 
(instancetype)initWithCharacters:(NSString *)characters + separatorList:(NSDictionary *)separatorList + dictPathList:(NSDictionary *)dictPathList; +- (void)loadDictionariesWithDictPathList: + (NSDictionary *)dictPathList; +- (NSArray *)decodeGreedy:(NSArray *)textIndex + length:(NSInteger)length; @end diff --git a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm index 644a29e213..ca0fd30da0 100644 --- a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm +++ b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm @@ -2,21 +2,25 @@ @implementation CTCLabelConverter -- (instancetype)initWithCharacters:(NSString *)characters separatorList:(NSDictionary *)separatorList dictPathList:(NSDictionary *)dictPathList { +- (instancetype)initWithCharacters:(NSString *)characters + separatorList:(NSDictionary *)separatorList + dictPathList:(NSDictionary *)dictPathList { self = [super init]; if (self) { _dict = [NSMutableDictionary dictionary]; - NSMutableArray *mutableCharacters = [NSMutableArray arrayWithObject:@"[blank]"]; - + NSMutableArray *mutableCharacters = + [NSMutableArray arrayWithObject:@"[blank]"]; + for (NSUInteger i = 0; i < [characters length]; i++) { - NSString *charStr = [NSString stringWithFormat:@"%C", [characters characterAtIndex:i]]; + NSString *charStr = + [NSString stringWithFormat:@"%C", [characters characterAtIndex:i]]; [mutableCharacters addObject:charStr]; self.dict[charStr] = @(i + 1); } - + _character = [mutableCharacters copy]; _separatorList = separatorList; - + NSMutableArray *ignoreIndexes = [NSMutableArray arrayWithObject:@(0)]; for (NSString *sep in separatorList.allValues) { NSUInteger index = [characters rangeOfString:sep].location; @@ -31,62 +35,70 @@ - (instancetype)initWithCharacters:(NSString *)characters separatorList:(NSDicti return self; } -- (void)loadDictionariesWithDictPathList:(NSDictionary *)dictPathList { +- (void)loadDictionariesWithDictPathList: + (NSDictionary 
*)dictPathList { NSMutableDictionary *tempDictList = [NSMutableDictionary dictionary]; for (NSString *lang in dictPathList.allKeys) { NSString *dictPath = dictPathList[lang]; NSError *error; - NSString *fileContents = [NSString stringWithContentsOfFile:dictPath encoding:NSUTF8StringEncoding error:&error]; + NSString *fileContents = + [NSString stringWithContentsOfFile:dictPath + encoding:NSUTF8StringEncoding + error:&error]; if (error) { NSLog(@"Error reading file: %@", error.localizedDescription); continue; } - NSArray *lines = [fileContents componentsSeparatedByCharactersInSet:[NSCharacterSet newlineCharacterSet]]; + NSArray *lines = [fileContents + componentsSeparatedByCharactersInSet:[NSCharacterSet + newlineCharacterSet]]; [tempDictList setObject:lines forKey:lang]; } _dictList = [tempDictList copy]; } -- (NSArray *)decodeGreedy:(NSArray *)textIndex length:(NSInteger)length { +- (NSArray *)decodeGreedy:(NSArray *)textIndex + length:(NSInteger)length { NSMutableArray *texts = [NSMutableArray array]; NSUInteger index = 0; - + while (index < textIndex.count) { NSUInteger segmentLength = MIN(length, textIndex.count - index); NSRange range = NSMakeRange(index, segmentLength); NSArray *subArray = [textIndex subarrayWithRange:range]; - + NSMutableString *text = [NSMutableString string]; NSNumber *lastChar = nil; - - NSMutableArray *isNotRepeated = [NSMutableArray arrayWithObject:@YES]; + + NSMutableArray *isNotRepeated = + [NSMutableArray arrayWithObject:@YES]; NSMutableArray *isNotIgnored = [NSMutableArray array]; - + for (NSUInteger i = 0; i < subArray.count; i++) { NSNumber *currentChar = subArray[i]; if (i > 0) { [isNotRepeated addObject:@(![lastChar isEqualToNumber:currentChar])]; } [isNotIgnored addObject:@(![self.ignoreIdx containsObject:currentChar])]; - + lastChar = currentChar; } - + for (NSUInteger j = 0; j < subArray.count; j++) { if ([isNotRepeated[j] boolValue] && [isNotIgnored[j] boolValue]) { NSUInteger charIndex = [subArray[j] unsignedIntegerValue]; 
[text appendString:self.character[charIndex]]; } } - + [texts addObject:text.copy]; index += segmentLength; - + if (segmentLength < length) { break; } } - + return texts.copy; } diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h index 8330cf9891..3f205b8ebd 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.h @@ -8,8 +8,13 @@ constexpr int verticalLineThreshold = 20; outputMat1:(cv::Mat &)mat1 outputMat2:(cv::Mat &)mat2 withSize:(cv::Size)size; -+ (NSArray *)getDetBoxesFromTextMap:(cv::Mat)textMap affinityMap:(cv::Mat)affinityMap usingTextThreshold:(CGFloat)textThreshold linkThreshold:(CGFloat)linkThreshold lowTextThreshold:(CGFloat)lowTextThreshold; -+ (NSArray *)restoreBboxRatio:(NSArray *)boxes usingRestoreRatio:(CGFloat)restoreRatio; ++ (NSArray *)getDetBoxesFromTextMap:(cv::Mat)textMap + affinityMap:(cv::Mat)affinityMap + usingTextThreshold:(CGFloat)textThreshold + linkThreshold:(CGFloat)linkThreshold + lowTextThreshold:(CGFloat)lowTextThreshold; ++ (NSArray *)restoreBboxRatio:(NSArray *)boxes + usingRestoreRatio:(CGFloat)restoreRatio; + (NSArray *)groupTextBoxes:(NSArray *)polys centerThreshold:(CGFloat)centerThreshold distanceThreshold:(CGFloat)distanceThreshold diff --git a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm index 5e49f1f0a9..8ee7424d00 100644 --- a/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/DetectorUtils.mm @@ -8,12 +8,12 @@ + (void)interleavedArrayToMats:(NSArray *)array withSize:(cv::Size)size { mat1 = cv::Mat(size.height, size.width, CV_32F); mat2 = cv::Mat(size.height, size.width, CV_32F); - + for (NSUInteger idx = 0; idx < array.count; idx++) { const CGFloat value = [array[idx] doubleValue]; const int x = (idx / 2) % size.width; const int y = (idx / 2) / size.width; - + if (idx % 2 == 0) { mat1.at(y, x) = 
value; } else { @@ -23,63 +23,79 @@ + (void)interleavedArrayToMats:(NSArray *)array } /** - * This method applies a series of image processing operations to identify likely areas of text in the textMap and return the bounding boxes for single words. + * This method applies a series of image processing operations to identify + * likely areas of text in the textMap and return the bounding boxes for single + * words. * - * @param textMap A cv::Mat representing a heat map of the characters of text being present in an image. - * @param affinityMap A cv::Mat representing a heat map of the affinity between characters. + * @param textMap A cv::Mat representing a heat map of the characters of text + * being present in an image. + * @param affinityMap A cv::Mat representing a heat map of the affinity between + * characters. * @param textThreshold A CGFloat representing the threshold for the text map. - * @param linkThreshold A CGFloat representing the threshold for the affinity map. + * @param linkThreshold A CGFloat representing the threshold for the affinity + * map. * @param lowTextThreshold A CGFloat representing the low text. * * @return An NSArray containing NSDictionary objects. Each dictionary includes: - * - "bbox": an NSArray of CGPoint values representing the vertices of the detected text box. + * - "bbox": an NSArray of CGPoint values representing the vertices of the + * detected text box. * - "angle": an NSNumber representing the rotation angle of the box. 
*/ -+ (NSArray *)getDetBoxesFromTextMap:(cv::Mat)textMap affinityMap:(cv::Mat)affinityMap usingTextThreshold:(CGFloat)textThreshold linkThreshold:(CGFloat)linkThreshold lowTextThreshold:(CGFloat)lowTextThreshold { ++ (NSArray *)getDetBoxesFromTextMap:(cv::Mat)textMap + affinityMap:(cv::Mat)affinityMap + usingTextThreshold:(CGFloat)textThreshold + linkThreshold:(CGFloat)linkThreshold + lowTextThreshold:(CGFloat)lowTextThreshold { const int imgH = textMap.rows; const int imgW = textMap.cols; cv::Mat textScore; cv::Mat affinityScore; cv::threshold(textMap, textScore, textThreshold, 1, cv::THRESH_BINARY); - cv::threshold(affinityMap, affinityScore, linkThreshold, 1, cv::THRESH_BINARY); + cv::threshold(affinityMap, affinityScore, linkThreshold, 1, + cv::THRESH_BINARY); cv::Mat textScoreComb = textScore + affinityScore; cv::threshold(textScoreComb, textScoreComb, 0, 1, cv::THRESH_BINARY); cv::Mat binaryMat; textScoreComb.convertTo(binaryMat, CV_8UC1); - + cv::Mat labels, stats, centroids; - const int nLabels = cv::connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4); - + const int nLabels = + cv::connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4); + NSMutableArray *detectedBoxes = [NSMutableArray array]; for (int i = 1; i < nLabels; i++) { const int area = stats.at(i, cv::CC_STAT_AREA); - if (area < 10) continue; - + if (area < 10) + continue; + cv::Mat mask = (labels == i); CGFloat maxVal; cv::minMaxLoc(textMap, NULL, &maxVal, NULL, NULL, mask); - if (maxVal < lowTextThreshold) continue; - + if (maxVal < lowTextThreshold) + continue; + cv::Mat segMap = cv::Mat::zeros(textMap.size(), CV_8U); segMap.setTo(255, mask); - + const int x = stats.at(i, cv::CC_STAT_LEFT); const int y = stats.at(i, cv::CC_STAT_TOP); const int w = stats.at(i, cv::CC_STAT_WIDTH); const int h = stats.at(i, cv::CC_STAT_HEIGHT); - const int dilationRadius = (int)(sqrt((double)(area / MAX(w, h)) ) * 2.0); + const int dilationRadius = (int)(sqrt((double)(area / MAX(w, 
h))) * 2.0); const int sx = MAX(x - dilationRadius, 0); const int ex = MIN(x + w + dilationRadius + 1, imgW); const int sy = MAX(y - dilationRadius, 0); const int ey = MIN(y + h + dilationRadius + 1, imgH); - + cv::Rect roi(sx, sy, ex - sx, ey - sy); - cv::Mat kernel = cv::getStructuringElement(cv::MORPH_RECT, cv::Size(1 + dilationRadius, 1 + dilationRadius)); + cv::Mat kernel = cv::getStructuringElement( + cv::MORPH_RECT, cv::Size(1 + dilationRadius, 1 + dilationRadius)); cv::Mat roiSegMap = segMap(roi); cv::dilate(roiSegMap, roiSegMap, kernel); - + std::vector> contours; - cv::findContours(segMap, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE); + cv::findContours(segMap, contours, cv::RETR_EXTERNAL, + cv::CHAIN_APPROX_SIMPLE); if (!contours.empty()) { cv::RotatedRect minRect = cv::minAreaRect(contours[0]); cv::Point2f vertices[4]; @@ -89,15 +105,17 @@ + (NSArray *)getDetBoxesFromTextMap:(cv::Mat)textMap affinityMap:(cv::Mat)affini const CGPoint point = CGPointMake(vertices[j].x, vertices[j].y); [pointsArray addObject:[NSValue valueWithCGPoint:point]]; } - NSDictionary *dict = @{@"bbox": pointsArray, @"angle": @(minRect.angle)}; + NSDictionary *dict = + @{@"bbox" : pointsArray, @"angle" : @(minRect.angle)}; [detectedBoxes addObject:dict]; } } - + return detectedBoxes; } -+ (NSArray *)restoreBboxRatio:(NSArray *)boxes usingRestoreRatio:(CGFloat)restoreRatio { ++ (NSArray *)restoreBboxRatio:(NSArray *)boxes + usingRestoreRatio:(CGFloat)restoreRatio { NSMutableArray *result = [NSMutableArray array]; for (NSUInteger i = 0; i < [boxes count]; i++) { NSDictionary *box = boxes[i]; @@ -108,15 +126,16 @@ + (NSArray *)getDetBoxesFromTextMap:(cv::Mat)textMap affinityMap:(cv::Mat)affini point.y *= restoreRatio; [boxArray addObject:[NSValue valueWithCGPoint:point]]; } - NSDictionary *dict = @{@"bbox": boxArray, @"angle": box[@"angle"]}; + NSDictionary *dict = @{@"bbox" : boxArray, @"angle" : box[@"angle"]}; [result addObject:dict]; } - + return result; } /** - * This 
method normalizes angle returned from cv::minAreaRect function which ranges from 0 to 90 degrees. + * This method normalizes angle returned from cv::minAreaRect function which + *ranges from 0 to 90 degrees. **/ + (CGFloat)normalizeAngle:(CGFloat)angle { if (angle > 45) { @@ -135,8 +154,9 @@ + (CGFloat)distanceFromPoint:(CGPoint)p1 toPoint:(CGPoint)p2 { return sqrt(xDist * xDist + yDist * yDist); } -+ (CGPoint)centerOfBox:(NSArray *)box { - return [self midpointBetweenPoint:[box[0] CGPointValue] andPoint:[box[2] CGPointValue]]; ++ (CGPoint)centerOfBox:(NSArray *)box { + return [self midpointBetweenPoint:[box[0] CGPointValue] + andPoint:[box[2] CGPointValue]]; } + (CGFloat)maxSideLength:(NSArray *)points { @@ -145,8 +165,9 @@ + (CGFloat)maxSideLength:(NSArray *)points { for (NSInteger i = 0; i < numOfPoints; i++) { const CGPoint currentPoint = [points[i] CGPointValue]; const CGPoint nextPoint = [points[(i + 1) % numOfPoints] CGPointValue]; - - const CGFloat sideLength = [self distanceFromPoint:currentPoint toPoint:nextPoint]; + + const CGFloat sideLength = [self distanceFromPoint:currentPoint + toPoint:nextPoint]; if (sideLength > maxSideLength) { maxSideLength = sideLength; } @@ -157,21 +178,23 @@ + (CGFloat)maxSideLength:(NSArray *)points { + (CGFloat)minSideLength:(NSArray *)points { CGFloat minSideLength = CGFLOAT_MAX; NSInteger numOfPoints = points.count; - + for (NSInteger i = 0; i < numOfPoints; i++) { const CGPoint currentPoint = [points[i] CGPointValue]; const CGPoint nextPoint = [points[(i + 1) % numOfPoints] CGPointValue]; - - const CGFloat sideLength = [self distanceFromPoint:currentPoint toPoint:nextPoint]; + + const CGFloat sideLength = [self distanceFromPoint:currentPoint + toPoint:nextPoint]; if (sideLength < minSideLength) { minSideLength = sideLength; } } - + return minSideLength; } -+ (CGFloat)calculateMinimalDistanceBetweenBox:(NSArray *)box1 andBox:(NSArray *)box2 { ++ (CGFloat)calculateMinimalDistanceBetweenBox:(NSArray *)box1 + andBox:(NSArray 
*)box2 { CGFloat minDistance = CGFLOAT_MAX; for (NSValue *value1 in box1) { const CGPoint corner1 = [value1 CGPointValue]; @@ -186,53 +209,61 @@ + (CGFloat)calculateMinimalDistanceBetweenBox:(NSArray *)box1 andBox: return minDistance; } -+ (NSArray *)rotateBox:(NSArray *)box withAngle:(CGFloat)angle { ++ (NSArray *)rotateBox:(NSArray *)box + withAngle:(CGFloat)angle { const CGPoint center = [self centerOfBox:box]; - + const CGFloat radians = angle * M_PI / 180.0; - - NSMutableArray *rotatedPoints = [NSMutableArray arrayWithCapacity:4]; + + NSMutableArray *rotatedPoints = + [NSMutableArray arrayWithCapacity:4]; for (NSValue *value in box) { const CGPoint point = [value CGPointValue]; - + const CGFloat translatedX = point.x - center.x; const CGFloat translatedY = point.y - center.y; - - const CGFloat rotatedX = translatedX * cos(radians) - translatedY * sin(radians); - const CGFloat rotatedY = translatedX * sin(radians) + translatedY * cos(radians); - - const CGPoint rotatedPoint = CGPointMake(rotatedX + center.x, rotatedY + center.y); + + const CGFloat rotatedX = + translatedX * cos(radians) - translatedY * sin(radians); + const CGFloat rotatedY = + translatedX * sin(radians) + translatedY * cos(radians); + + const CGPoint rotatedPoint = + CGPointMake(rotatedX + center.x, rotatedY + center.y); [rotatedPoints addObject:[NSValue valueWithCGPoint:rotatedPoint]]; } - + return rotatedPoints; } /** - * Orders a set of points in a clockwise direction starting with the top-left point. + * Orders a set of points in a clockwise direction starting with the top-left + * point. * * Process: * 1. It iterates through each CGPoint extracted from the NSValues. - * 2. For each point, it calculates the sum (x + y) and difference (y - x) of the coordinates. + * 2. For each point, it calculates the sum (x + y) and difference (y - x) of + * the coordinates. * 3. Points are classified into: * - Top-left: Minimum sum. * - Bottom-right: Maximum sum. * - Top-right: Minimum difference. 
* - Bottom-left: Maximum difference. - * 4. The points are ordered starting from the top-left in a clockwise manner: top-left, top-right, bottom-right, bottom-left. + * 4. The points are ordered starting from the top-left in a clockwise manner: + * top-left, top-right, bottom-right, bottom-left. */ -+ (NSArray *)orderPointsClockwise:(NSArray *)points{ ++ (NSArray *)orderPointsClockwise:(NSArray *)points { CGPoint topLeft, topRight, bottomRight, bottomLeft; CGFloat minSum = FLT_MAX; CGFloat maxSum = -FLT_MAX; CGFloat minDiff = FLT_MAX; CGFloat maxDiff = -FLT_MAX; - + for (NSValue *value in points) { const CGPoint pt = [value CGPointValue]; const CGFloat sum = pt.x + pt.y; const CGFloat diff = pt.y - pt.x; - + if (sum < minSum) { minSum = sum; topLeft = pt; @@ -250,12 +281,13 @@ + (NSArray *)orderPointsClockwise:(NSArray *)points{ bottomLeft = pt; } } - - NSArray *rect = @[[NSValue valueWithCGPoint:topLeft], - [NSValue valueWithCGPoint:topRight], - [NSValue valueWithCGPoint:bottomRight], - [NSValue valueWithCGPoint:bottomLeft]]; - + + NSArray *rect = @[ + [NSValue valueWithCGPoint:topLeft], [NSValue valueWithCGPoint:topRight], + [NSValue valueWithCGPoint:bottomRight], + [NSValue valueWithCGPoint:bottomLeft] + ]; + return rect; } @@ -268,44 +300,51 @@ + (NSArray *)orderPointsClockwise:(NSArray *)points{ return points; } -+ (NSArray *)nsValuesFromPoints:(cv::Point2f *)points count:(int)count { - NSMutableArray *nsValues = [[NSMutableArray alloc] initWithCapacity:count]; ++ (NSArray *)nsValuesFromPoints:(cv::Point2f *)points + count:(int)count { + NSMutableArray *nsValues = + [[NSMutableArray alloc] initWithCapacity:count]; for (int i = 0; i < count; i++) { - [nsValues addObject:[NSValue valueWithCGPoint:CGPointMake(points[i].x, points[i].y)]]; + [nsValues addObject:[NSValue valueWithCGPoint:CGPointMake(points[i].x, + points[i].y)]]; } return nsValues; } -+ (NSArray *)mergeRotatedBoxes:(NSArray *)box1 withBox:(NSArray *)box2 { ++ (NSArray *)mergeRotatedBoxes:(NSArray 
*)box1 + withBox:(NSArray *)box2 { box1 = [self orderPointsClockwise:box1]; box2 = [self orderPointsClockwise:box2]; - + std::vector points1 = [self pointsFromNSValues:box1]; std::vector points2 = [self pointsFromNSValues:box2]; - + std::vector allPoints; allPoints.insert(allPoints.end(), points1.begin(), points1.end()); allPoints.insert(allPoints.end(), points2.begin(), points2.end()); - + std::vector hullIndices; cv::convexHull(allPoints, hullIndices, false); - + std::vector hullPoints; for (int idx : hullIndices) { hullPoints.push_back(allPoints[idx]); } - + cv::RotatedRect minAreaRect = cv::minAreaRect(hullPoints); - + cv::Point2f rectPoints[4]; minAreaRect.points(rectPoints); - + return [self nsValuesFromPoints:rectPoints count:4]; } -+ (NSMutableArray *)removeSmallBoxesFromArray:(NSArray *)boxes usingMinSideThreshold:(CGFloat)minSideThreshold maxSideThreshold:(CGFloat)maxSideThreshold { ++ (NSMutableArray *) + removeSmallBoxesFromArray:(NSArray *)boxes + usingMinSideThreshold:(CGFloat)minSideThreshold + maxSideThreshold:(CGFloat)maxSideThreshold { NSMutableArray *filteredBoxes = [NSMutableArray array]; - + for (NSDictionary *box in boxes) { const CGFloat maxSideLength = [self maxSideLength:box[@"bbox"]]; const CGFloat minSideLength = [self minSideLength:box[@"bbox"]]; @@ -313,13 +352,14 @@ + (NSArray *)orderPointsClockwise:(NSArray *)points{ [filteredBoxes addObject:box]; } } - + return filteredBoxes; } + (CGFloat)minimumYFromBox:(NSArray *)box { __block CGFloat minY = CGFLOAT_MAX; - [box enumerateObjectsUsingBlock:^(NSValue * _Nonnull obj, NSUInteger idx, BOOL * _Nonnull stop) { + [box enumerateObjectsUsingBlock:^(NSValue *_Nonnull obj, NSUInteger idx, + BOOL *_Nonnull stop) { const CGPoint pt = [obj CGPointValue]; if (pt.y < minY) { minY = pt.y; @@ -329,44 +369,57 @@ + (CGFloat)minimumYFromBox:(NSArray *)box { } /** - * This method calculates the distances between each sequential pair of points in a presumed quadrilateral, - * identifies the two shortest 
sides, and fits a linear model to the midpoints of these sides. It also evaluates - * whether the resulting line should be considered vertical based on a predefined threshold for the x-coordinate differences. + * This method calculates the distances between each sequential pair of points + * in a presumed quadrilateral, identifies the two shortest sides, and fits a + * linear model to the midpoints of these sides. It also evaluates whether the + * resulting line should be considered vertical based on a predefined threshold + * for the x-coordinate differences. * - * If the line is vertical it is fitted as a function of x = my + c, otherwise as y = mx + c. + * If the line is vertical it is fitted as a function of x = my + c, otherwise + * as y = mx + c. * * @return A NSDictionary containing: * - "slope": NSNumber representing the slope (m) of the line. * - "intercept": NSNumber representing the line's intercept (c) with y-axis. - * - "isVertical": NSNumber (boolean) indicating whether the line is considered vertical. + * - "isVertical": NSNumber (boolean) indicating whether the line is + * considered vertical. 
*/ + (NSDictionary *)fitLineToShortestSides:(NSArray *)points { NSMutableArray *sides = [NSMutableArray array]; NSMutableArray *midpoints = [NSMutableArray array]; - + for (int i = 0; i < 4; i++) { const CGPoint p1 = [points[i] CGPointValue]; const CGPoint p2 = [points[(i + 1) % 4] CGPointValue]; - + const CGFloat sideLength = [self distanceFromPoint:p1 toPoint:p2]; - [sides addObject:@{@"length": @(sideLength), @"index": @(i)}]; - [midpoints addObject:[NSValue valueWithCGPoint:[self midpointBetweenPoint:p1 andPoint:p2]]]; + [sides addObject:@{@"length" : @(sideLength), @"index" : @(i)}]; + [midpoints + addObject:[NSValue valueWithCGPoint:[self midpointBetweenPoint:p1 + andPoint:p2]]]; } - - [sides sortUsingDescriptors:@[[NSSortDescriptor sortDescriptorWithKey:@"length" ascending:YES]]]; - - const CGPoint midpoint1 = [midpoints[[sides[0][@"index"] intValue]] CGPointValue]; - const CGPoint midpoint2 = [midpoints[[sides[1][@"index"] intValue]] CGPointValue]; + + [sides + sortUsingDescriptors:@[ [NSSortDescriptor sortDescriptorWithKey:@"length" + ascending:YES] ]]; + + const CGPoint midpoint1 = + [midpoints [[sides [0] [@"index"] intValue]] CGPointValue]; + const CGPoint midpoint2 = + [midpoints [[sides [1] [@"index"] intValue]] CGPointValue]; const CGFloat dx = fabs(midpoint2.x - midpoint1.x); - + CGFloat m, c; BOOL isVertical; - - std::vector cvMidPoints = {cv::Point2f(midpoint1.x, midpoint1.y), cv::Point2f(midpoint2.x, midpoint2.y)}; + + std::vector cvMidPoints = { + cv::Point2f(midpoint1.x, midpoint1.y), + cv::Point2f(midpoint2.x, midpoint2.y)}; cv::Vec4f line; - + if (dx < verticalLineThreshold) { - for (auto &pt : cvMidPoints) std::swap(pt.x, pt.y); + for (auto &pt : cvMidPoints) + std::swap(pt.x, pt.y); cv::fitLine(cvMidPoints, line, cv::DIST_L2, 0, 0.01, 0.01); m = line[1] / line[0]; c = line[3] - m * line[2]; @@ -377,29 +430,38 @@ + (NSDictionary *)fitLineToShortestSides:(NSArray *)points { c = line[3] - m * line[2]; isVertical = NO; } - - return @{@"slope": 
@(m), @"intercept": @(c), @"isVertical": @(isVertical)}; + + return @{@"slope" : @(m), @"intercept" : @(c), @"isVertical" : @(isVertical)}; } /** - * This method assesses each box from a provided array, checks its center against the center of a "current box", - * and evaluates its alignment with a specified line equation. The function specifically searches for the box - * whose center is closest to the current box, that has not been ignored, and fits within a defined distance from the line. + * This method assesses each box from a provided array, checks its center + * against the center of a "current box", and evaluates its alignment with a + * specified line equation. The function specifically searches for the box whose + * center is closest to the current box, that has not been ignored, and fits + * within a defined distance from the line. * - * @param boxes An NSArray of NSDictionary objects where each dictionary represents a box with keys "bbox" and "angle". - * "bbox" is an NSArray of NSValue objects each encapsulating CGPoint that define the box vertices. + * @param boxes An NSArray of NSDictionary objects where each dictionary + * represents a box with keys "bbox" and "angle". "bbox" is an NSArray of + * NSValue objects each encapsulating CGPoint that define the box vertices. * "angle" is a NSNumber representing the box's rotation angle. - * @param ignoredIdxs An NSSet of NSNumber objects representing indices of boxes to ignore in the evaluation. - * @param currentBox An NSArray of NSValue objects encapsulating CGPoints representing the current box to compare against. - * @param isVertical A pointer to a BOOL indicating if the line to compare distance to is vertical. - * @param m The slope (gradient) of the line against which the box's alignment is checked. + * @param ignoredIdxs An NSSet of NSNumber objects representing indices of boxes + * to ignore in the evaluation. 
+ * @param currentBox An NSArray of NSValue objects encapsulating CGPoints + * representing the current box to compare against. + * @param isVertical A pointer to a BOOL indicating if the line to compare + * distance to is vertical. + * @param m The slope (gradient) of the line against which the box's alignment + * is checked. * @param c The y-intercept of the line equation y = mx + c. - * @param centerThreshold A multiplier to determine the threshold for the distance between the box's center and the line. + * @param centerThreshold A multiplier to determine the threshold for the + * distance between the box's center and the line. * * @return A NSDictionary containing: - * - "idx" : NSNumber indicating the index of the found box in the original NSArray. - * - "boxHeight" : NSNumber representing the shortest side length of the found box. - * Returns nil if no suitable box is found. + * - "idx" : NSNumber indicating the index of the found box in the + * original NSArray. + * - "boxHeight" : NSNumber representing the shortest side length of the + * found box. Returns nil if no suitable box is found. 
*/ + (NSDictionary *)findClosestBox:(NSArray *)boxes ignoredIdxs:(NSSet *)ignoredIdxs @@ -407,140 +469,184 @@ + (NSDictionary *)findClosestBox:(NSArray *)boxes isVertical:(BOOL)isVertical m:(CGFloat)m c:(CGFloat)c - centerThreshold:(CGFloat)centerThreshold -{ + centerThreshold:(CGFloat)centerThreshold { CGFloat smallestDistance = CGFLOAT_MAX; NSInteger idx = -1; CGFloat boxHeight = 0; const CGPoint centerOfCurrentBox = [self centerOfBox:currentBox]; - + for (NSUInteger i = 0; i < boxes.count; i++) { if ([ignoredIdxs containsObject:@(i)]) { continue; } NSArray *bbox = boxes[i][@"bbox"]; const CGPoint centerOfProcessedBox = [self centerOfBox:bbox]; - const CGFloat distanceBetweenCenters = [self distanceFromPoint:centerOfCurrentBox toPoint:centerOfProcessedBox]; - + const CGFloat distanceBetweenCenters = + [self distanceFromPoint:centerOfCurrentBox + toPoint:centerOfProcessedBox]; + if (distanceBetweenCenters >= smallestDistance) { continue; } - + boxHeight = [self minSideLength:bbox]; - - const CGFloat lineDistance = (isVertical ? - fabs(centerOfProcessedBox.x - (m * centerOfProcessedBox.y + c)) : - fabs(centerOfProcessedBox.y - (m * centerOfProcessedBox.x + c))); - + + const CGFloat lineDistance = + (isVertical + ? fabs(centerOfProcessedBox.x - (m * centerOfProcessedBox.y + c)) + : fabs(centerOfProcessedBox.y - (m * centerOfProcessedBox.x + c))); + if (lineDistance < boxHeight * centerThreshold) { idx = i; smallestDistance = distanceBetweenCenters; } } - - return idx != -1 ? @{@"idx": @(idx), @"boxHeight": @(boxHeight)} : nil; + + return idx != -1 ? @{@"idx" : @(idx), @"boxHeight" : @(boxHeight)} : nil; } /** - * This method processes an array of text box dictionaries, each containing details about individual text boxes, - * and attempts to group and merge these boxes based on specified criteria including proximity, alignment, - * and size thresholds. 
It prioritizes merging of boxes that are aligned closely in angle, are near each other, - * and whose sizes are compatible based on the given thresholds. + * This method processes an array of text box dictionaries, each containing + * details about individual text boxes, and attempts to group and merge these + * boxes based on specified criteria including proximity, alignment, and size + * thresholds. It prioritizes merging of boxes that are aligned closely in + * angle, are near each other, and whose sizes are compatible based on the given + * thresholds. * - * @param boxes An array of NSDictionary objects where each dictionary represents a text box. Each dictionary must have - * at least a "bbox" key with an NSArray of NSValue wrapping CGPoints defining the box vertices, - * and an "angle" key indicating the orientation of the box. - * @param centerThreshold A CGFloat representing the threshold for considering the distance between center and fitted line. - * @param distanceThreshold A CGFloat that defines the maximum allowed distance between boxes for them to be considered for merging. - * @param heightThreshold A CGFloat representing the maximum allowed difference in height between boxes for merging. - * @param minSideThreshold An int that defines the minimum dimension threshold to filter out small boxes after grouping. - * @param maxSideThreshold An int that specifies the maximum dimension threshold for filtering boxes post-grouping. - * @param maxWidth An int that represents the maximum width allowable for a merged box. + * @param boxes An array of NSDictionary objects where each dictionary + * represents a text box. Each dictionary must have at least a "bbox" key with + * an NSArray of NSValue wrapping CGPoints defining the box vertices, and an + * "angle" key indicating the orientation of the box. + * @param centerThreshold A CGFloat representing the threshold for considering + * the distance between center and fitted line. 
+ * @param distanceThreshold A CGFloat that defines the maximum allowed distance + * between boxes for them to be considered for merging. + * @param heightThreshold A CGFloat representing the maximum allowed difference + * in height between boxes for merging. + * @param minSideThreshold An int that defines the minimum dimension threshold + * to filter out small boxes after grouping. + * @param maxSideThreshold An int that specifies the maximum dimension threshold + * for filtering boxes post-grouping. + * @param maxWidth An int that represents the maximum width allowable for a + * merged box. * - * @return An NSArray of NSDictionary objects representing the merged boxes. Each dictionary contains: - * - "bbox": An NSArray of NSValue each containing a CGPoint that defines the vertices of the merged box. - * - "angle": NSNumber representing the computed orientation of the merged box. + * @return An NSArray of NSDictionary objects representing the merged boxes. + * Each dictionary contains: + * - "bbox": An NSArray of NSValue each containing a CGPoint that + * defines the vertices of the merged box. + * - "angle": NSNumber representing the computed orientation of the + * merged box. * * Processing Steps: * 1. Sort initial boxes based on their maximum side length. - * 2. Sequentially merge boxes considering alignment, proximity, and size compatibility. - * 3. Post-processing to remove any boxes that are too small or exceed max side criteria. + * 2. Sequentially merge boxes considering alignment, proximity, and size + * compatibility. + * 3. Post-processing to remove any boxes that are too small or exceed max side + * criteria. * 4. Sort the final array of boxes by their vertical positions. 
*/ -+ (NSArray *)groupTextBoxes:(NSMutableArray *)boxes ++ (NSArray *)groupTextBoxes: + (NSMutableArray *)boxes centerThreshold:(CGFloat)centerThreshold distanceThreshold:(CGFloat)distanceThreshold heightThreshold:(CGFloat)heightThreshold minSideThreshold:(int)minSideThreshold maxSideThreshold:(int)maxSideThreshold - maxWidth:(int)maxWidth -{ + maxWidth:(int)maxWidth { // Sort boxes based on their maximum side length - boxes = [boxes sortedArrayUsingComparator:^NSComparisonResult(NSDictionary *obj1, NSDictionary *obj2) { - const CGFloat maxLen1 = [self maxSideLength:obj1[@"bbox"]]; - const CGFloat maxLen2 = [self maxSideLength:obj2[@"bbox"]]; - return (maxLen1 < maxLen2) ? NSOrderedDescending : (maxLen1 > maxLen2) ? NSOrderedAscending : NSOrderedSame; - }].mutableCopy; - + boxes = [boxes sortedArrayUsingComparator:^NSComparisonResult( + NSDictionary *obj1, NSDictionary *obj2) { + const CGFloat maxLen1 = [self maxSideLength:obj1[@"bbox"]]; + const CGFloat maxLen2 = [self maxSideLength:obj2[@"bbox"]]; + return (maxLen1 < maxLen2) ? NSOrderedDescending + : (maxLen1 > maxLen2) ? 
NSOrderedAscending + : NSOrderedSame; + }].mutableCopy; + NSMutableArray *mergedArray = [NSMutableArray array]; CGFloat lineAngle; while (boxes.count > 0) { NSMutableDictionary *currentBox = [boxes[0] mutableCopy]; - CGFloat normalizedAngle = [self normalizeAngle:[currentBox[@"angle"] floatValue]]; + CGFloat normalizedAngle = + [self normalizeAngle:[currentBox[@"angle"] floatValue]]; [boxes removeObjectAtIndex:0]; NSMutableArray *ignoredIdxs = [NSMutableArray array]; - + while (YES) { - //Find all aligned boxes and merge them until max_size is reached or no more boxes can be merged - NSDictionary *fittedLine = [self fitLineToShortestSides:currentBox[@"bbox"]]; + // Find all aligned boxes and merge them until max_size is reached or no + // more boxes can be merged + NSDictionary *fittedLine = + [self fitLineToShortestSides:currentBox[@"bbox"]]; const CGFloat slope = [fittedLine[@"slope"] floatValue]; const CGFloat intercept = [fittedLine[@"intercept"] floatValue]; const BOOL isVertical = [fittedLine[@"isVertical"] boolValue]; - + lineAngle = atan(slope) * 180 / M_PI; - if (isVertical){ + if (isVertical) { lineAngle = -90; } - - NSDictionary *closestBoxInfo = [self findClosestBox:boxes ignoredIdxs:[NSSet setWithArray:ignoredIdxs] currentBox:currentBox[@"bbox"] isVertical:isVertical m:slope c:intercept centerThreshold:centerThreshold]; - if (closestBoxInfo == nil) break; - + + NSDictionary *closestBoxInfo = + [self findClosestBox:boxes + ignoredIdxs:[NSSet setWithArray:ignoredIdxs] + currentBox:currentBox[@"bbox"] + isVertical:isVertical + m:slope + c:intercept + centerThreshold:centerThreshold]; + if (closestBoxInfo == nil) + break; + NSInteger candidateIdx = [closestBoxInfo[@"idx"] integerValue]; NSMutableDictionary *candidateBox = [boxes[candidateIdx] mutableCopy]; const CGFloat candidateHeight = [closestBoxInfo[@"boxHeight"] floatValue]; - - if (([candidateBox[@"angle"] isEqual: @90] && !isVertical) || ([candidateBox[@"angle"] isEqual: @0] && isVertical)) { - 
candidateBox[@"bbox"] = [self rotateBox:candidateBox[@"bbox"] withAngle:normalizedAngle]; + + if (([candidateBox[@"angle"] isEqual:@90] && !isVertical) || + ([candidateBox[@"angle"] isEqual:@0] && isVertical)) { + candidateBox[@"bbox"] = [self rotateBox:candidateBox[@"bbox"] + withAngle:normalizedAngle]; } - - const CGFloat minDistance = [self calculateMinimalDistanceBetweenBox:candidateBox[@"bbox"] andBox:currentBox[@"bbox"]]; + + const CGFloat minDistance = + [self calculateMinimalDistanceBetweenBox:candidateBox[@"bbox"] + andBox:currentBox[@"bbox"]]; const CGFloat mergedHeight = [self minSideLength:currentBox[@"bbox"]]; - if (minDistance < distanceThreshold * candidateHeight && fabs(mergedHeight - candidateHeight) < candidateHeight * heightThreshold) { - currentBox[@"bbox"] = [self mergeRotatedBoxes:currentBox[@"bbox"] withBox:candidateBox[@"bbox"]]; + if (minDistance < distanceThreshold * candidateHeight && + fabs(mergedHeight - candidateHeight) < + candidateHeight * heightThreshold) { + currentBox[@"bbox"] = [self mergeRotatedBoxes:currentBox[@"bbox"] + withBox:candidateBox[@"bbox"]]; [boxes removeObjectAtIndex:candidateIdx]; [ignoredIdxs removeAllObjects]; - if ([self maxSideLength:currentBox[@"bbox"]] > maxWidth){ + if ([self maxSideLength:currentBox[@"bbox"]] > maxWidth) { break; } } else { [ignoredIdxs addObject:@(candidateIdx)]; } } - - [mergedArray addObject:@{@"bbox" : currentBox[@"bbox"], @"angle" : @(lineAngle)}]; + + [mergedArray + addObject:@{@"bbox" : currentBox[@"bbox"], @"angle" : @(lineAngle)}]; } - + // Remove small boxes and sort by vertical - mergedArray = [self removeSmallBoxesFromArray:mergedArray usingMinSideThreshold:minSideThreshold maxSideThreshold:maxSideThreshold]; - - NSArray *sortedBoxes = [mergedArray sortedArrayUsingComparator:^NSComparisonResult(NSDictionary *obj1, NSDictionary *obj2) { - NSArray *coords1 = obj1[@"bbox"]; - NSArray *coords2 = obj2[@"bbox"]; - const CGFloat minY1 = [self minimumYFromBox:coords1]; - const CGFloat 
minY2 = [self minimumYFromBox:coords2]; - return (minY1 < minY2) ? NSOrderedAscending : (minY1 > minY2) ? NSOrderedDescending : NSOrderedSame; - }]; - + mergedArray = [self removeSmallBoxesFromArray:mergedArray + usingMinSideThreshold:minSideThreshold + maxSideThreshold:maxSideThreshold]; + + NSArray *sortedBoxes = [mergedArray + sortedArrayUsingComparator:^NSComparisonResult(NSDictionary *obj1, + NSDictionary *obj2) { + NSArray *coords1 = obj1[@"bbox"]; + NSArray *coords2 = obj2[@"bbox"]; + const CGFloat minY1 = [self minimumYFromBox:coords1]; + const CGFloat minY2 = [self minimumYFromBox:coords2]; + return (minY1 < minY2) ? NSOrderedAscending + : (minY1 > minY2) ? NSOrderedDescending + : NSOrderedSame; + }]; + return sortedBoxes; } diff --git a/ios/RnExecutorch/models/ocr/utils/OCRUtils.h b/ios/RnExecutorch/models/ocr/utils/OCRUtils.h index 0304ad37e3..dca8b9bba5 100644 --- a/ios/RnExecutorch/models/ocr/utils/OCRUtils.h +++ b/ios/RnExecutorch/models/ocr/utils/OCRUtils.h @@ -2,6 +2,8 @@ @interface OCRUtils : NSObject -+ (cv::Mat)resizeWithPadding:(cv::Mat)img desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight; ++ (cv::Mat)resizeWithPadding:(cv::Mat)img + desiredWidth:(int)desiredWidth + desiredHeight:(int)desiredHeight; @end diff --git a/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm index 3bec624450..f530dac2da 100644 --- a/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/OCRUtils.mm @@ -2,47 +2,53 @@ @implementation OCRUtils -+ (cv::Mat)resizeWithPadding:(cv::Mat)img desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight { ++ (cv::Mat)resizeWithPadding:(cv::Mat)img + desiredWidth:(int)desiredWidth + desiredHeight:(int)desiredHeight { const int height = img.rows; const int width = img.cols; const float heightRatio = (float)desiredHeight / height; const float widthRatio = (float)desiredWidth / width; const float resizeRatio = MIN(heightRatio, widthRatio); - + 
const int newWidth = width * resizeRatio; const int newHeight = height * resizeRatio; - + cv::Mat resizedImg; - cv::resize(img, resizedImg, cv::Size(newWidth, newHeight), 0, 0, cv::INTER_AREA); - + cv::resize(img, resizedImg, cv::Size(newWidth, newHeight), 0, 0, + cv::INTER_AREA); + const int cornerPatchSize = MAX(1, MIN(height, width) / 30); std::vector corners = { - img(cv::Rect(0, 0, cornerPatchSize, cornerPatchSize)), - img(cv::Rect(width - cornerPatchSize, 0, cornerPatchSize, cornerPatchSize)), - img(cv::Rect(0, height - cornerPatchSize, cornerPatchSize, cornerPatchSize)), - img(cv::Rect(width - cornerPatchSize, height - cornerPatchSize, cornerPatchSize, cornerPatchSize)) - }; - + img(cv::Rect(0, 0, cornerPatchSize, cornerPatchSize)), + img(cv::Rect(width - cornerPatchSize, 0, cornerPatchSize, + cornerPatchSize)), + img(cv::Rect(0, height - cornerPatchSize, cornerPatchSize, + cornerPatchSize)), + img(cv::Rect(width - cornerPatchSize, height - cornerPatchSize, + cornerPatchSize, cornerPatchSize))}; + cv::Scalar backgroundScalar = cv::mean(corners[0]); for (int i = 1; i < corners.size(); i++) { backgroundScalar += cv::mean(corners[i]); } backgroundScalar /= (double)corners.size(); - + backgroundScalar[0] = cvFloor(backgroundScalar[0]); backgroundScalar[1] = cvFloor(backgroundScalar[1]); backgroundScalar[2] = cvFloor(backgroundScalar[2]); - + const int deltaW = desiredWidth - newWidth; const int deltaH = desiredHeight - newHeight; const int top = deltaH / 2; const int bottom = deltaH - top; const int left = deltaW / 2; const int right = deltaW - left; - + cv::Mat centeredImg; - cv::copyMakeBorder(resizedImg, centeredImg, top, bottom, left, right, cv::BORDER_CONSTANT, backgroundScalar); - + cv::copyMakeBorder(resizedImg, centeredImg, top, bottom, left, right, + cv::BORDER_CONSTANT, backgroundScalar); + return centeredImg; } diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h index 
337cdc9f94..7af748f58c 100644 --- a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.h @@ -3,15 +3,26 @@ @interface RecognizerUtils : NSObject + (CGFloat)calculateRatio:(int)width height:(int)height; -+ (cv::Mat)computeRatioAndResize:(cv::Mat)img width:(int)width height:(int)height modelHeight:(int)modelHeight; -+ (cv::Mat)normalizeForRecognizer:(cv::Mat)image adjustContrast:(double)adjustContrast; ++ (cv::Mat)computeRatioAndResize:(cv::Mat)img + width:(int)width + height:(int)height + modelHeight:(int)modelHeight; ++ (cv::Mat)normalizeForRecognizer:(cv::Mat)image + adjustContrast:(double)adjustContrast; + (cv::Mat)adjustContrastGrey:(cv::Mat)img target:(double)target; + (cv::Mat)divideMatrix:(cv::Mat)matrix byVector:(NSArray *)vector; + (cv::Mat)softmax:(cv::Mat)inputs; -+ (NSDictionary *)calculateResizeRatioAndPaddings:(int)width height:(int)height desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight; -+ (cv::Mat)getCroppedImage:(NSDictionary *)box image:(cv::Mat)image modelHeight:(int)modelHeight; -+ (NSMutableArray *)sumProbabilityRows:(cv::Mat)probabilities modelOutputHeight:(int)modelOutputHeight; ++ (NSDictionary *)calculateResizeRatioAndPaddings:(int)width + height:(int)height + desiredWidth:(int)desiredWidth + desiredHeight:(int)desiredHeight; ++ (cv::Mat)getCroppedImage:(NSDictionary *)box + image:(cv::Mat)image + modelHeight:(int)modelHeight; ++ (NSMutableArray *)sumProbabilityRows:(cv::Mat)probabilities + modelOutputHeight:(int)modelOutputHeight; + (NSArray *)findMaxValuesAndIndices:(cv::Mat)probabilities; -+ (double)computeConfidenceScore:(NSArray *)valuesArray indicesArray:(NSArray *)indicesArray; ++ (double)computeConfidenceScore:(NSArray *)valuesArray + indicesArray:(NSArray *)indicesArray; @end diff --git a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm index 74048e208c..65c088b361 100644 --- 
a/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm +++ b/ios/RnExecutorch/models/ocr/utils/RecognizerUtils.mm @@ -1,5 +1,5 @@ -#import "OCRUtils.h" #import "RecognizerUtils.h" +#import "OCRUtils.h" @implementation RecognizerUtils @@ -11,13 +11,18 @@ + (CGFloat)calculateRatio:(int)width height:(int)height { return ratio; } -+ (cv::Mat)computeRatioAndResize:(cv::Mat)img width:(int)width height:(int)height modelHeight:(int)modelHeight { ++ (cv::Mat)computeRatioAndResize:(cv::Mat)img + width:(int)width + height:(int)height + modelHeight:(int)modelHeight { CGFloat ratio = (CGFloat)width / (CGFloat)height; if (ratio < 1.0) { ratio = [self calculateRatio:width height:height]; - cv::resize(img, img, cv::Size(modelHeight, (int)(modelHeight * ratio)), 0, 0, cv::INTER_LANCZOS4); + cv::resize(img, img, cv::Size(modelHeight, (int)(modelHeight * ratio)), 0, + 0, cv::INTER_LANCZOS4); } else { - cv::resize(img, img, cv::Size((int)(modelHeight * ratio), modelHeight), 0, 0, cv::INTER_LANCZOS4); + cv::resize(img, img, cv::Size((int)(modelHeight * ratio), modelHeight), 0, + 0, cv::INTER_LANCZOS4); } return img; } @@ -26,7 +31,7 @@ + (CGFloat)calculateRatio:(int)width height:(int)height { double contrast = 0.0; int high = 0; int low = 255; - + for (int i = 0; i < img.rows; ++i) { for (int j = 0; j < img.cols; ++j) { uchar pixel = img.at(i, j); @@ -35,55 +40,58 @@ + (CGFloat)calculateRatio:(int)width height:(int)height { } } contrast = (high - low) / 255.0; - + if (contrast < target) { const double ratio = 200.0 / MAX(10, high - low); img.convertTo(img, CV_32F); img = ((img - low + 25) * ratio); - + cv::threshold(img, img, 255, 255, cv::THRESH_TRUNC); cv::threshold(img, img, 0, 0, cv::THRESH_TOZERO); - + img.convertTo(img, CV_8U); } - + return img; } -+ (cv::Mat)normalizeForRecognizer:(cv::Mat)image adjustContrast:(double)adjustContrast { ++ (cv::Mat)normalizeForRecognizer:(cv::Mat)image + adjustContrast:(double)adjustContrast { if (adjustContrast > 0) { image = [self 
adjustContrastGrey:image target:adjustContrast]; } - + int desiredWidth = 128; if (image.cols >= 512) { desiredWidth = 512; } else if (image.cols >= 256) { desiredWidth = 256; } - - image = [OCRUtils resizeWithPadding:image desiredWidth:desiredWidth desiredHeight:64]; - + + image = [OCRUtils resizeWithPadding:image + desiredWidth:desiredWidth + desiredHeight:64]; + image.convertTo(image, CV_32F, 1.0 / 255.0); image = (image - 0.5) * 2.0; - + return image; } + (cv::Mat)divideMatrix:(cv::Mat)matrix byVector:(NSArray *)vector { cv::Mat result = matrix.clone(); - + for (int i = 0; i < matrix.rows; i++) { const float divisor = [vector[i] floatValue]; for (int j = 0; j < matrix.cols; j++) { result.at(i, j) /= divisor; } } - + return result; } -+ (cv::Mat)softmax:(cv::Mat) inputs { ++ (cv::Mat)softmax:(cv::Mat)inputs { cv::Mat maxVal; cv::reduce(inputs, maxVal, 1, cv::REDUCE_MAX, CV_32F); cv::Mat expInputs; @@ -94,7 +102,10 @@ + (CGFloat)calculateRatio:(int)width height:(int)height { return softmaxOutput; } -+ (NSDictionary *)calculateResizeRatioAndPaddings:(int)width height:(int)height desiredWidth:(int)desiredWidth desiredHeight:(int)desiredHeight { ++ (NSDictionary *)calculateResizeRatioAndPaddings:(int)width + height:(int)height + desiredWidth:(int)desiredWidth + desiredHeight:(int)desiredHeight { const float newRatioH = (float)desiredHeight / height; const float newRatioW = (float)desiredWidth / width; float resizeRatio = MIN(newRatioH, newRatioW); @@ -106,58 +117,67 @@ + (NSDictionary *)calculateResizeRatioAndPaddings:(int)width height:(int)height const int left = deltaW / 2; const float heightRatio = (float)height / desiredHeight; const float widthRatio = (float)width / desiredWidth; - + resizeRatio = MAX(heightRatio, widthRatio); - + return @{ - @"resizeRatio": @(resizeRatio), - @"top": @(top), - @"left": @(left), + @"resizeRatio" : @(resizeRatio), + @"top" : @(top), + @"left" : @(left), }; } -+ (cv::Mat)getCroppedImage:(NSDictionary *)box image:(cv::Mat)image 
modelHeight:(int)modelHeight { ++ (cv::Mat)getCroppedImage:(NSDictionary *)box + image:(cv::Mat)image + modelHeight:(int)modelHeight { NSArray *coords = box[@"bbox"]; const CGFloat angle = [box[@"angle"] floatValue]; - + std::vector points; for (NSValue *value in coords) { const CGPoint point = [value CGPointValue]; - points.emplace_back(static_cast(point.x), static_cast(point.y)); + points.emplace_back(static_cast(point.x), + static_cast(point.y)); } - + cv::RotatedRect rotatedRect = cv::minAreaRect(points); - + cv::Point2f imageCenter = cv::Point2f(image.cols / 2.0, image.rows / 2.0); cv::Mat rotationMatrix = cv::getRotationMatrix2D(imageCenter, angle, 1.0); cv::Mat rotatedImage; - cv::warpAffine(image, rotatedImage, rotationMatrix, image.size(), cv::INTER_LINEAR); + cv::warpAffine(image, rotatedImage, rotationMatrix, image.size(), + cv::INTER_LINEAR); cv::Point2f rectPoints[4]; rotatedRect.points(rectPoints); std::vector transformedPoints(4); cv::Mat rectMat(4, 2, CV_32FC2, rectPoints); cv::transform(rectMat, rectMat, rotationMatrix); - + for (int i = 0; i < 4; ++i) { transformedPoints[i] = rectPoints[i]; } - + cv::Rect boundingBox = cv::boundingRect(transformedPoints); boundingBox &= cv::Rect(0, 0, rotatedImage.cols, rotatedImage.rows); cv::Mat croppedImage = rotatedImage(boundingBox); - if (boundingBox.width == 0 || boundingBox.height == 0){ + if (boundingBox.width == 0 || boundingBox.height == 0) { croppedImage = cv::Mat().empty(); - + return croppedImage; } - - croppedImage = [self computeRatioAndResize:croppedImage width:boundingBox.width height:boundingBox.height modelHeight:modelHeight]; - + + croppedImage = [self computeRatioAndResize:croppedImage + width:boundingBox.width + height:boundingBox.height + modelHeight:modelHeight]; + return croppedImage; } -+ (NSMutableArray *)sumProbabilityRows:(cv::Mat)probabilities modelOutputHeight:(int)modelOutputHeight { - NSMutableArray *predsNorm = [NSMutableArray arrayWithCapacity:probabilities.rows]; ++ 
(NSMutableArray *)sumProbabilityRows:(cv::Mat)probabilities + modelOutputHeight:(int)modelOutputHeight { + NSMutableArray *predsNorm = + [NSMutableArray arrayWithCapacity:probabilities.rows]; for (int i = 0; i < probabilities.rows; i++) { float sum = 0.0; for (int j = 0; j < modelOutputHeight; j++) { @@ -178,10 +198,11 @@ + (NSArray *)findMaxValuesAndIndices:(cv::Mat)probabilities { [valuesArray addObject:@(maxVal)]; [indicesArray addObject:@(maxLoc.x)]; } - return @[valuesArray, indicesArray]; + return @[ valuesArray, indicesArray ]; } -+ (double)computeConfidenceScore:(NSArray *)valuesArray indicesArray:(NSArray *)indicesArray { ++ (double)computeConfidenceScore:(NSArray *)valuesArray + indicesArray:(NSArray *)indicesArray { NSMutableArray *predsMaxProb = [NSMutableArray array]; for (NSUInteger index = 0; index < indicesArray.count; index++) { NSNumber *indicator = indicesArray[index]; From 9538fbf6a07385f1722b0d5c9afef3e17742a52f Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Fri, 21 Feb 2025 13:31:10 +0100 Subject: [PATCH 17/19] feat: ocr(android) (#96) - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) - [ ] iOS - [x] Android - [x] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [x] My changes generate no new warnings --- .../java/com/swmansion/rnexecutorch/OCR.kt | 101 ++++ .../rnexecutorch/RnExecutorchPackage.kt | 11 + .../rnexecutorch/models/ocr/Detector.kt | 79 +++ .../models/ocr/RecognitionHandler.kt | 115 +++++ .../rnexecutorch/models/ocr/Recognizer.kt | 56 +++ .../models/ocr/utils/CTCLabelConverter.kt | 75 +++ .../models/ocr/utils/Constants.kt | 27 + 
.../models/ocr/utils/DetectorUtils.kt | 468 ++++++++++++++++++ .../models/ocr/utils/RecognizerUtils.kt | 269 ++++++++++ .../rnexecutorch/utils/ImageProcessor.kt | 110 +++- 10 files changed, 1304 insertions(+), 7 deletions(-) create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/OCR.kt create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Recognizer.kt create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt create mode 100644 android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt diff --git a/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt b/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt new file mode 100644 index 0000000000..85acf06260 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt @@ -0,0 +1,101 @@ +package com.swmansion.rnexecutorch + +import android.util.Log +import com.facebook.react.bridge.Promise +import com.facebook.react.bridge.ReactApplicationContext +import com.swmansion.rnexecutorch.utils.ETError +import com.swmansion.rnexecutorch.utils.ImageProcessor +import org.opencv.android.OpenCVLoader +import com.swmansion.rnexecutorch.models.ocr.Detector +import com.swmansion.rnexecutorch.models.ocr.RecognitionHandler +import com.swmansion.rnexecutorch.models.ocr.utils.Constants +import com.swmansion.rnexecutorch.utils.Fetcher +import com.swmansion.rnexecutorch.utils.ResourceType +import org.opencv.imgproc.Imgproc + +class OCR(reactContext: ReactApplicationContext) : + NativeOCRSpec(reactContext) { + + private 
lateinit var detector: Detector + private lateinit var recognitionHandler: RecognitionHandler + + companion object { + const val NAME = "OCR" + } + + init { + if (!OpenCVLoader.initLocal()) { + Log.d("rn_executorch", "OpenCV not loaded") + } else { + Log.d("rn_executorch", "OpenCV loaded") + } + } + + override fun loadModule( + detectorSource: String, + recognizerSourceLarge: String, + recognizerSourceMedium: String, + recognizerSourceSmall: String, + symbols: String, + languageDictPath: String, + promise: Promise + ) { + try { + detector = Detector(reactApplicationContext) + detector.loadModel(detectorSource) + Fetcher.downloadResource( + reactApplicationContext, + languageDictPath, + ResourceType.TXT, + false, + { path, error -> + if (error != null) { + throw Error(error.message!!) + } + + recognitionHandler = RecognitionHandler( + symbols, + path!!, + reactApplicationContext + ) + + recognitionHandler.loadRecognizers( + recognizerSourceLarge, + recognizerSourceMedium, + recognizerSourceSmall + ) { _, errorRecognizer -> + if (errorRecognizer != null) { + throw Error(errorRecognizer.message!!) 
+ } + + promise.resolve(0) + } + }) + } catch (e: Exception) { + promise.reject(e.message!!, ETError.InvalidModelSource.toString()) + } + } + + override fun forward(input: String, promise: Promise) { + try { + val inputImage = ImageProcessor.readImage(input) + val bBoxesList = detector.runModel(inputImage) + val detectorSize = detector.getModelImageSize() + Imgproc.cvtColor(inputImage, inputImage, Imgproc.COLOR_BGR2GRAY) + val result = recognitionHandler.recognize( + bBoxesList, + inputImage, + (detectorSize.width * Constants.RECOGNIZER_RATIO).toInt(), + (detectorSize.height * Constants.RECOGNIZER_RATIO).toInt() + ) + promise.resolve(result) + } catch (e: Exception) { + Log.d("rn_executorch", "Error running model: ${e.message}") + promise.reject(e.message!!, e.message) + } + } + + override fun getName(): String { + return NAME + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt index fb7fe1f63b..0ec2a51c4f 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/RnExecutorchPackage.kt @@ -25,6 +25,8 @@ class RnExecutorchPackage : TurboReactPackage() { ObjectDetection(reactContext) } else if (name == SpeechToText.NAME) { SpeechToText(reactContext) + } else if (name == OCR.NAME){ + OCR(reactContext) } else { null @@ -85,6 +87,15 @@ class RnExecutorchPackage : TurboReactPackage() { false, // isCxxModule true ) + + moduleInfos[OCR.NAME] = ReactModuleInfo( + OCR.NAME, + OCR.NAME, + false, // canOverrideExistingModule + false, // needsEagerInit + false, // isCxxModule + true + ) moduleInfos } } diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt new file mode 100644 index 0000000000..85976e2281 --- /dev/null +++ 
b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Detector.kt @@ -0,0 +1,79 @@ +package com.swmansion.rnexecutorch.models.ocr + +import android.util.Log +import com.facebook.react.bridge.ReactApplicationContext +import com.swmansion.rnexecutorch.models.BaseModel +import com.swmansion.rnexecutorch.models.ocr.utils.Constants +import com.swmansion.rnexecutorch.models.ocr.utils.DetectorUtils +import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox +import com.swmansion.rnexecutorch.utils.ImageProcessor +import org.opencv.core.Mat +import org.opencv.core.Scalar +import org.opencv.core.Size +import org.pytorch.executorch.EValue + +class Detector(reactApplicationContext: ReactApplicationContext) : + BaseModel>(reactApplicationContext) { + private lateinit var originalSize: Size + + fun getModelImageSize(): Size { + val inputShape = module.getInputShape(0) + val width = inputShape[inputShape.lastIndex] + val height = inputShape[inputShape.lastIndex - 1] + + val modelImageSize = Size(height.toDouble(), width.toDouble()) + + return modelImageSize + } + + override fun preprocess(input: Mat): EValue { + originalSize = Size(input.cols().toDouble(), input.rows().toDouble()) + val resizedImage = ImageProcessor.resizeWithPadding( + input, + getModelImageSize().width.toInt(), + getModelImageSize().height.toInt() + ) + + return ImageProcessor.matToEValue( + resizedImage, + module.getInputShape(0), + Constants.MEAN, + Constants.VARIANCE + ) + } + + override fun postprocess(output: Array): List { + val outputTensor = output[0].toTensor() + val outputArray = outputTensor.dataAsFloatArray + val modelImageSize = getModelImageSize() + + val (scoreText, scoreLink) = DetectorUtils.interleavedArrayToMats( + outputArray, + Size(modelImageSize.width / 2, modelImageSize.height / 2) + ) + var bBoxesList = DetectorUtils.getDetBoxesFromTextMap( + scoreText, + scoreLink, + Constants.TEXT_THRESHOLD, + Constants.LINK_THRESHOLD, + Constants.LOW_TEXT_THRESHOLD + ) + bBoxesList = + 
DetectorUtils.restoreBoxRatio(bBoxesList, (Constants.RECOGNIZER_RATIO * 2).toFloat()) + bBoxesList = DetectorUtils.groupTextBoxes( + bBoxesList, + Constants.CENTER_THRESHOLD, + Constants.DISTANCE_THRESHOLD, + Constants.HEIGHT_THRESHOLD, + Constants.MIN_SIDE_THRESHOLD, + Constants.MAX_SIDE_THRESHOLD, + Constants.MAX_WIDTH + ) + + return bBoxesList.toList() + } + + override fun runModel(input: Mat): List { + return postprocess(forward(preprocess(input))) + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt new file mode 100644 index 0000000000..1aeae02e22 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt @@ -0,0 +1,115 @@ +package com.swmansion.rnexecutorch.models.ocr + +import com.facebook.react.bridge.Arguments +import com.facebook.react.bridge.ReactApplicationContext +import com.facebook.react.bridge.WritableArray +import com.swmansion.rnexecutorch.models.ocr.utils.CTCLabelConverter +import com.swmansion.rnexecutorch.models.ocr.utils.Constants +import com.swmansion.rnexecutorch.models.ocr.utils.OCRbBox +import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils +import com.swmansion.rnexecutorch.utils.ImageProcessor +import org.opencv.core.Core +import org.opencv.core.Mat + +class RecognitionHandler( + symbols: String, + languageDictPath: String, + reactApplicationContext: ReactApplicationContext +) { + private val recognizerLarge = Recognizer(reactApplicationContext) + private val recognizerMedium = Recognizer(reactApplicationContext) + private val recognizerSmall = Recognizer(reactApplicationContext) + private val converter = CTCLabelConverter(symbols, mapOf(languageDictPath to "key")) + + private fun runModel(croppedImage: Mat): Pair, Double> { + val result: Pair, Double> = if (croppedImage.cols() >= Constants.LARGE_MODEL_WIDTH) { + recognizerLarge.runModel(croppedImage) 
+ } else if (croppedImage.cols() >= Constants.MEDIUM_MODEL_WIDTH) { + recognizerMedium.runModel(croppedImage) + } else { + recognizerSmall.runModel(croppedImage) + } + + return result + } + + fun loadRecognizers( + largeRecognizerPath: String, + mediumRecognizerPath: String, + smallRecognizerPath: String, + onComplete: (Int, Exception?) -> Unit + ) { + try { + recognizerLarge.loadModel(largeRecognizerPath) + recognizerMedium.loadModel(mediumRecognizerPath) + recognizerSmall.loadModel(smallRecognizerPath) + onComplete(0, null) + } catch (e: Exception) { + onComplete(1, e) + } + } + + fun recognize( + bBoxesList: List, + imgGray: Mat, + desiredWidth: Int, + desiredHeight: Int + ): WritableArray { + val res: WritableArray = Arguments.createArray() + val ratioAndPadding = RecognizerUtils.calculateResizeRatioAndPaddings( + imgGray.width(), + imgGray.height(), + desiredWidth, + desiredHeight + ) + + val left = ratioAndPadding["left"] as Int + val top = ratioAndPadding["top"] as Int + val resizeRatio = ratioAndPadding["resizeRatio"] as Float + val resizedImg = ImageProcessor.resizeWithPadding( + imgGray, + desiredWidth, + desiredHeight + ) + + for (box in bBoxesList) { + var croppedImage = RecognizerUtils.getCroppedImage(box, resizedImg, Constants.MODEL_HEIGHT) + if (croppedImage.empty()) { + continue + } + + croppedImage = RecognizerUtils.normalizeForRecognizer(croppedImage, Constants.ADJUST_CONTRAST) + + var result = runModel(croppedImage) + var confidenceScore = result.second + + if (confidenceScore < Constants.LOW_CONFIDENCE_THRESHOLD) { + Core.rotate(croppedImage, croppedImage, Core.ROTATE_180) + val rotatedResult = runModel(croppedImage) + val rotatedConfidenceScore = rotatedResult.second + if (rotatedConfidenceScore > confidenceScore) { + result = rotatedResult + confidenceScore = rotatedConfidenceScore + } + } + + val predIndex = result.first + val decodedTexts = converter.decodeGreedy(predIndex, predIndex.size) + + for (bBox in box.bBox) { + bBox.x = (bBox.x - 
left) * resizeRatio + bBox.y = (bBox.y - top) * resizeRatio + } + + val resMap = Arguments.createMap() + + resMap.putString("text", decodedTexts[0]) + resMap.putArray("bbox", box.toWritableArray()) + resMap.putDouble("confidence", confidenceScore) + + res.pushMap(resMap) + } + + return res + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Recognizer.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Recognizer.kt new file mode 100644 index 0000000000..2772cc4a98 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/Recognizer.kt @@ -0,0 +1,56 @@ +package com.swmansion.rnexecutorch.models.ocr + +import com.facebook.react.bridge.ReactApplicationContext +import com.swmansion.rnexecutorch.models.BaseModel +import com.swmansion.rnexecutorch.models.ocr.utils.RecognizerUtils +import com.swmansion.rnexecutorch.utils.ImageProcessor +import org.opencv.core.Mat +import org.opencv.core.Size +import org.pytorch.executorch.EValue + +class Recognizer(reactApplicationContext: ReactApplicationContext) : + BaseModel, Double>>(reactApplicationContext) { + + private fun getModelOutputSize(): Size { + val outputShape = module.getOutputShape(0) + val width = outputShape[outputShape.lastIndex] + val height = outputShape[outputShape.lastIndex - 1] + + return Size(height.toDouble(), width.toDouble()) + } + + override fun preprocess(input: Mat): EValue { + return ImageProcessor.matToEValueGray(input) + } + + override fun postprocess(output: Array): Pair, Double> { + val modelOutputHeight = getModelOutputSize().height.toInt() + val tensor = output[0].toTensor().dataAsFloatArray + val numElements = tensor.size + val numRows = (numElements + modelOutputHeight - 1) / modelOutputHeight + val resultMat = Mat(numRows, modelOutputHeight, org.opencv.core.CvType.CV_32F) + var counter = 0 + var currentRow = 0 + for (num in tensor) { + resultMat.put(currentRow, counter, floatArrayOf(num)) + counter++ + if (counter >= 
modelOutputHeight) { + counter = 0 + currentRow++ + } + } + + var probabilities = RecognizerUtils.softmax(resultMat) + val predsNorm = RecognizerUtils.sumProbabilityRows(probabilities, modelOutputHeight) + probabilities = RecognizerUtils.divideMatrixByVector(probabilities, predsNorm) + val (values, indices) = RecognizerUtils.findMaxValuesAndIndices(probabilities) + + val confidenceScore = RecognizerUtils.computeConfidenceScore(values, indices) + return Pair(indices, confidenceScore) + } + + + override fun runModel(input: Mat): Pair, Double> { + return postprocess(module.forward(preprocess(input))) + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt new file mode 100644 index 0000000000..336d2f600f --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt @@ -0,0 +1,75 @@ +package com.swmansion.rnexecutorch.models.ocr.utils + +import java.io.File + +class CTCLabelConverter( + characters: String, + dictPathList: Map +) { + private val dict = mutableMapOf() + val character: List + private val ignoreIdx: List + private val dictList: Map> + + init { + val mutableCharacters = mutableListOf("[blank]") + characters.forEachIndexed { index, char -> + mutableCharacters.add(char.toString()) + dict[char.toString()] = index + 1 + } + character = mutableCharacters.toList() + + val ignoreIndexes = mutableListOf(0) + + ignoreIdx = ignoreIndexes.toList() + + dictList = loadDictionariesWithDictPathList(dictPathList) + } + + private fun loadDictionariesWithDictPathList(dictPathList: Map): Map> { + val tempDictList = mutableMapOf>() + dictPathList.forEach { (lang, dictPath) -> + runCatching { + File(dictPath).readLines() + }.onSuccess { lines -> + tempDictList[lang] = lines + }.onFailure { error -> + println("Error reading file: ${error.localizedMessage}") + } + } + return tempDictList.toMap() + 
} + + fun decodeGreedy(textIndex: List, length: Int): List { + val texts = mutableListOf() + var index = 0 + while (index < textIndex.size) { + val segmentLength = minOf(length, textIndex.size - index) + val subArray = textIndex.subList(index, index + segmentLength) + + val text = StringBuilder() + var lastChar: Int? = null + val isNotRepeated = mutableListOf(true) + val isNotIgnored = mutableListOf() + + subArray.forEachIndexed { i, currentChar -> + if (i > 0) { + isNotRepeated.add(lastChar != currentChar) + } + isNotIgnored.add(!ignoreIdx.contains(currentChar)) + lastChar = currentChar + } + + subArray.forEachIndexed { j, charIndex -> + if (isNotRepeated[j] && isNotIgnored[j]) { + text.append(character[charIndex]) + } + } + + texts.add(text.toString()) + index += segmentLength + if (segmentLength < length) break + } + return texts.toList() + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt new file mode 100644 index 0000000000..b49232f41a --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/Constants.kt @@ -0,0 +1,27 @@ +package com.swmansion.rnexecutorch.models.ocr.utils + +import org.opencv.core.Scalar + +class Constants { + companion object { + const val RECOGNIZER_RATIO = 1.6 + const val MODEL_HEIGHT = 64 + const val LARGE_MODEL_WIDTH = 512 + const val MEDIUM_MODEL_WIDTH = 256 + const val SMALL_MODEL_WIDTH = 128 + const val LOW_CONFIDENCE_THRESHOLD = 0.3 + const val ADJUST_CONTRAST = 0.2 + const val TEXT_THRESHOLD = 0.4 + const val LINK_THRESHOLD = 0.4 + const val LOW_TEXT_THRESHOLD = 0.7 + const val CENTER_THRESHOLD = 0.5 + const val DISTANCE_THRESHOLD = 2.0 + const val HEIGHT_THRESHOLD = 2.0 + const val MIN_SIDE_THRESHOLD = 15 + const val MAX_SIDE_THRESHOLD = 30 + const val MAX_WIDTH = (LARGE_MODEL_WIDTH + (LARGE_MODEL_WIDTH * 0.15)).toInt() + const val MIN_SIZE = 20 + val MEAN = Scalar(0.485, 
0.456, 0.406) + val VARIANCE = Scalar(0.229, 0.224, 0.225) + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt new file mode 100644 index 0000000000..4beb7ecf45 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/DetectorUtils.kt @@ -0,0 +1,468 @@ +package com.swmansion.rnexecutorch.models.ocr.utils + +import com.facebook.react.bridge.Arguments +import com.facebook.react.bridge.WritableArray +import org.opencv.core.Core +import org.opencv.core.CvType +import org.opencv.core.Mat +import org.opencv.core.MatOfFloat4 +import org.opencv.core.MatOfInt +import org.opencv.core.MatOfPoint +import org.opencv.core.MatOfPoint2f +import org.opencv.core.Point +import org.opencv.core.Rect +import org.opencv.core.Scalar +import org.opencv.core.Size +import org.opencv.imgproc.Imgproc +import kotlin.math.abs +import kotlin.math.atan +import kotlin.math.cos +import kotlin.math.max +import kotlin.math.min +import kotlin.math.pow +import kotlin.math.sin +import kotlin.math.sqrt + +class DetectorUtils { + companion object { + private fun normalizeAngle(angle: Double): Double { + if (angle > 45.0) { + return angle - 90.0 + } + + return angle + } + + private fun midpoint(p1: BBoxPoint, p2: BBoxPoint): BBoxPoint { + val midpoint = BBoxPoint((p1.x + p2.x) / 2, (p1.y + p2.y) / 2) + return midpoint + } + + private fun distanceBetweenPoints(p1: BBoxPoint, p2: BBoxPoint): Double { + return sqrt((p1.x - p2.x).pow(2.0) + (p1.y - p2.y).pow(2.0)) + } + + private fun centerOfBox(box: OCRbBox): BBoxPoint { + val p1 = box.bBox[0] + val p2 = box.bBox[2] + return midpoint(p1, p2) + } + + private fun maxSideLength(box: OCRbBox): Double { + var maxSideLength = 0.0 + val numOfPoints = box.bBox.size + for (i in 0 until numOfPoints) { + val currentPoint = box.bBox[i] + val nextPoint = box.bBox[(i + 1) % numOfPoints] + val sideLength = 
distanceBetweenPoints(currentPoint, nextPoint) + if (sideLength > maxSideLength) { + maxSideLength = sideLength + } + } + return maxSideLength + } + + private fun minSideLength(box: OCRbBox): Double { + var minSideLength = Double.MAX_VALUE + val numOfPoints = box.bBox.size + for (i in 0 until numOfPoints) { + val currentPoint = box.bBox[i] + val nextPoint = box.bBox[(i + 1) % numOfPoints] + val sideLength = distanceBetweenPoints(currentPoint, nextPoint) + if (sideLength < minSideLength) { + minSideLength = sideLength + } + } + return minSideLength + } + + + private fun calculateMinimalDistanceBetweenBoxes(box1: OCRbBox, box2: OCRbBox): Double { + var minDistance = Double.MAX_VALUE + for (i in 0 until 4) { + for (j in 0 until 4) { + val distance = distanceBetweenPoints(box1.bBox[i], box2.bBox[j]) + if (distance < minDistance) { + minDistance = distance + } + } + } + + return minDistance + } + + private fun rotateBox(box: OCRbBox, angle: Double): OCRbBox { + val center = centerOfBox(box) + val radians = angle * Math.PI / 180 + val newBBox = box.bBox.map { point -> + val translatedX = point.x - center.x + val translatedY = point.y - center.y + val rotatedX = translatedX * cos(radians) - translatedY * sin(radians) + val rotatedY = translatedX * sin(radians) + translatedY * cos(radians) + BBoxPoint(rotatedX + center.x, rotatedY + center.y) + } + + return OCRbBox(newBBox, box.angle) + } + + private fun orderPointsClockwise(box: OCRbBox): OCRbBox { + var topLeft = box.bBox[0] + var topRight = box.bBox[1] + var bottomRight = box.bBox[2] + var bottomLeft = box.bBox[3] + var minSum = Double.MAX_VALUE + var maxSum = -Double.MAX_VALUE + var minDiff = Double.MAX_VALUE + var maxDiff = -Double.MAX_VALUE + + for (point in box.bBox) { + val sum = point.x + point.y + val diff = point.x - point.y + if (sum < minSum) { + minSum = sum + topLeft = point + } + if (sum > maxSum) { + maxSum = sum + bottomRight = point + } + if (diff < minDiff) { + minDiff = diff + bottomLeft = point + } + 
if (diff > maxDiff) { + maxDiff = diff + topRight = point + } + } + + return OCRbBox(listOf(topLeft, topRight, bottomRight, bottomLeft), box.angle) + } + + private fun mergeRotatedBoxes(box1: OCRbBox, box2: OCRbBox): OCRbBox { + val orderedBox1 = orderPointsClockwise(box1) + val orderedBox2 = orderPointsClockwise(box2) + + val allPoints = arrayListOf() + allPoints.addAll(orderedBox1.bBox.map { Point(it.x, it.y) }) + allPoints.addAll(orderedBox2.bBox.map { Point(it.x, it.y) }) + + val matOfAllPoints = MatOfPoint() + matOfAllPoints.fromList(allPoints) + + val hullIndices = MatOfInt() + Imgproc.convexHull(matOfAllPoints, hullIndices, false) + + val hullPoints = hullIndices.toArray().map { allPoints[it] } + + val matOfHullPoints = MatOfPoint2f() + matOfHullPoints.fromList(hullPoints) + + val minAreaRect = Imgproc.minAreaRect(matOfHullPoints) + val rectPoints = arrayOfNulls(4) + minAreaRect.points(rectPoints) + + val bBoxPoints = rectPoints.filterNotNull().map { BBoxPoint(it.x, it.y) } + + return OCRbBox(bBoxPoints, minAreaRect.angle) + } + + private fun removeSmallBoxes( + boxes: MutableList, + minSideThreshold: Int, + maxSideThreshold: Int + ): MutableList { + return boxes.filter { minSideLength(it) > minSideThreshold && maxSideLength(it) > maxSideThreshold } + .toMutableList() + } + + private fun minimumYFromBox(box: List): Double = box.minOf { it.y } + + private fun fitLineToShortestSides(box: OCRbBox): LineInfo { + val sides = mutableListOf>() + val midpoints = mutableListOf() + + for (i in box.bBox.indices) { + val p1 = box.bBox[i] + val p2 = box.bBox[(i + 1) % 4] + val sideLength = distanceBetweenPoints(p1, p2) + sides.add(sideLength to i) + midpoints.add(midpoint(p1, p2)) + } + + sides.sortBy { it.first } + + val midpoint1 = midpoints[sides[0].second] + val midpoint2 = midpoints[sides[1].second] + + val dx = abs(midpoint2.x - midpoint1.x) + val line = MatOfFloat4() + + val isVertical = if (dx < 20) { + for (point in arrayOf(midpoint1, midpoint2)) { + val temp = 
point.x + point.x = point.y + point.y = temp + } + Imgproc.fitLine( + MatOfPoint2f( + Point(midpoint1.x, midpoint1.y), + Point(midpoint2.x, midpoint2.y) + ), line, Imgproc.DIST_L2, 0.0, 0.01, 0.01 + ) + true + } else { + Imgproc.fitLine( + MatOfPoint2f( + Point(midpoint1.x, midpoint1.y), + Point(midpoint2.x, midpoint2.y) + ), line, Imgproc.DIST_L2, 0.0, 0.01, 0.01 + ) + false + } + + val m = line.get(1, 0)[0] / line.get(0, 0)[0] // slope + val c = line.get(3, 0)[0] - m * line.get(2, 0)[0] // intercept + return LineInfo(m, c, isVertical) + } + + private fun findClosestBox( + boxes: MutableList, + ignoredIds: Set, + currentBox: OCRbBox, + isVertical: Boolean, + m: Double, + c: Double, + centerThreshold: Double + ): Pair? { + var smallestDistance = Double.MAX_VALUE + var idx = -1 + var boxHeight = 0.0 + val centerOfCurrentBox = centerOfBox(currentBox) + boxes.forEachIndexed { i, box -> + if (ignoredIds.contains(i)) { + return@forEachIndexed + } + val centerOfProcessedBox = centerOfBox(box) + val distanceBetweenCenters = distanceBetweenPoints(centerOfCurrentBox, centerOfProcessedBox) + if (distanceBetweenCenters >= smallestDistance) { + return@forEachIndexed + } + boxHeight = minSideLength(box) + val lineDistance = if (isVertical) + abs(centerOfProcessedBox.x - (m * centerOfProcessedBox.y + c)) + else + abs(centerOfProcessedBox.y - (m * centerOfProcessedBox.x + c)) + + if (lineDistance < boxHeight * centerThreshold) { + idx = i + smallestDistance = distanceBetweenCenters + } + } + + return if (idx == -1) null else Pair(idx, boxHeight) + } + + private fun createMaskFromLabels(labels: Mat, labelValue: Int): Mat { + val mask = Mat.zeros(labels.size(), CvType.CV_8U) + + Core.compare(labels, Scalar(labelValue.toDouble()), mask, Core.CMP_EQ) + + return mask + } + + fun interleavedArrayToMats(array: FloatArray, size: Size): Pair { + val mat1 = Mat(size.height.toInt(), size.width.toInt(), CvType.CV_32F) + val mat2 = Mat(size.height.toInt(), size.width.toInt(), CvType.CV_32F) + 
+ array.forEachIndexed { index, value -> + val x = (index / 2) % (size.width.toInt()) + val y = (index / 2) / size.width.toInt() + if (index % 2 == 0) { + mat1.put(y, x, value.toDouble()) + } else { + mat2.put(y, x, value.toDouble()) + } + } + + return Pair(mat1, mat2) + } + + fun getDetBoxesFromTextMap( + textMap: Mat, + affinityMap: Mat, + textThreshold: Double, + linkThreshold: Double, + lowTextThreshold: Double + ): MutableList { + val imgH = textMap.rows() + val imgW = textMap.cols() + + val textScore = Mat() + val affinityScore = Mat() + Imgproc.threshold(textMap, textScore, textThreshold, 1.0, Imgproc.THRESH_BINARY) + Imgproc.threshold(affinityMap, affinityScore, linkThreshold, 1.0, Imgproc.THRESH_BINARY) + val textScoreComb = Mat() + Core.add(textScore, affinityScore, textScoreComb) + Imgproc.threshold(textScoreComb, textScoreComb, 0.0, 1.0, Imgproc.THRESH_BINARY) + + val binaryMat = Mat() + textScoreComb.convertTo(binaryMat, CvType.CV_8UC1) + + val labels = Mat() + val stats = Mat() + val centroids = Mat() + val nLabels = Imgproc.connectedComponentsWithStats(binaryMat, labels, stats, centroids, 4) + + val detectedBoxes = mutableListOf() + for (i in 1 until nLabels) { + val area = stats.get(i, Imgproc.CC_STAT_AREA)[0].toInt() + if (area < 10) continue + val mask = createMaskFromLabels(labels, i) + val maxValResult = Core.minMaxLoc(textMap, mask) + val maxVal = maxValResult.maxVal + if (maxVal < lowTextThreshold) continue + val segMap = Mat.zeros(textMap.size(), CvType.CV_8U) + segMap.setTo(Scalar(255.0), mask) + + val x = stats.get(i, Imgproc.CC_STAT_LEFT)[0].toInt() + val y = stats.get(i, Imgproc.CC_STAT_TOP)[0].toInt() + val w = stats.get(i, Imgproc.CC_STAT_WIDTH)[0].toInt() + val h = stats.get(i, Imgproc.CC_STAT_HEIGHT)[0].toInt() + val dilationRadius = (sqrt(area / max(w, h).toDouble()) * 2.0).toInt() + val sx = max(x - dilationRadius, 0) + val ex = min(x + w + dilationRadius + 1, imgW) + val sy = max(y - dilationRadius, 0) + val ey = min(y + h + 
dilationRadius + 1, imgH) + val roi = Rect(sx, sy, ex - sx, ey - sy) + val kernel = Imgproc.getStructuringElement( + Imgproc.MORPH_RECT, + Size((1 + dilationRadius).toDouble(), (1 + dilationRadius).toDouble()) + ) + val roiSegMap = Mat(segMap, roi) + Imgproc.dilate(roiSegMap, roiSegMap, kernel) + + val contours: List = ArrayList() + Imgproc.findContours( + segMap, + contours, + Mat(), + Imgproc.RETR_EXTERNAL, + Imgproc.CHAIN_APPROX_SIMPLE + ) + if (contours.isNotEmpty()) { + val minRect = Imgproc.minAreaRect(MatOfPoint2f(*contours[0].toArray())) + val points = Array(4) { Point() } + minRect.points(points) + val pointsList = points.map { point -> BBoxPoint(point.x, point.y) } + val boxInfo = OCRbBox(pointsList, minRect.angle) + detectedBoxes.add(boxInfo) + } + } + + return detectedBoxes + } + + fun restoreBoxRatio(boxes: MutableList, restoreRatio: Float): MutableList { + for (box in boxes) { + for (b in box.bBox) { + b.x *= restoreRatio + b.y *= restoreRatio + } + } + + return boxes + } + + fun groupTextBoxes( + boxes: MutableList, + centerThreshold: Double, + distanceThreshold: Double, + heightThreshold: Double, + minSideThreshold: Int, + maxSideThreshold: Int, + maxWidth: Int + ): MutableList { + boxes.sortByDescending { maxSideLength(it) } + var mergedArray = mutableListOf() + + while (boxes.isNotEmpty()) { + var currentBox = boxes.removeAt(0) + val normalizedAngle = normalizeAngle(currentBox.angle) + val ignoredIds = mutableSetOf() + var lineAngle: Double + while (true) { + val fittedLine = + fitLineToShortestSides(currentBox) + val slope = fittedLine.slope + val intercept = fittedLine.intercept + val isVertical = fittedLine.isVertical + + lineAngle = atan(slope) * 180 / Math.PI + if (isVertical) { + lineAngle = -90.0 + } + + val closestBoxInfo = findClosestBox( + boxes, ignoredIds, currentBox, + isVertical, slope, intercept, centerThreshold + ) ?: break + + val candidateIdx = closestBoxInfo.first + var candidateBox = boxes[candidateIdx] + val candidateHeight = 
closestBoxInfo.second + if ((candidateBox.angle == 90.0 && !isVertical) || (candidateBox.angle == 0.0 && isVertical)) { + candidateBox = + rotateBox(candidateBox, normalizedAngle) + } + val minDistance = + calculateMinimalDistanceBetweenBoxes(candidateBox, currentBox) + val mergedHeight = minSideLength(currentBox) + if (minDistance < distanceThreshold * candidateHeight && abs(mergedHeight - candidateHeight) < candidateHeight * heightThreshold) { + currentBox = mergeRotatedBoxes(currentBox, candidateBox) + boxes.removeAt(candidateIdx) + ignoredIds.clear() + if (maxSideLength(currentBox) > maxWidth) { + break + } + } else { + ignoredIds.add(candidateIdx) + } + } + mergedArray.add(currentBox.copy(angle = lineAngle)) + } + + mergedArray = removeSmallBoxes(mergedArray, minSideThreshold, maxSideThreshold) + mergedArray = mergedArray.sortedWith(compareBy { minimumYFromBox(it.bBox) }).toMutableList() + + return mergedArray + } + } +} + +data class BBoxPoint( + var x: Double, + var y: Double, +) + +data class OCRbBox( + val bBox: List, + val angle: Double, +) { + fun toWritableArray(): WritableArray { + val array = Arguments.createArray() + bBox.forEach { point -> + val pointMap = Arguments.createMap() + pointMap.putDouble("x", point.x) + pointMap.putDouble("y", point.y) + array.pushMap(pointMap) + } + return array + } +} + +data class LineInfo( + val slope: Double, + val intercept: Double, + val isVertical: Boolean +) diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt new file mode 100644 index 0000000000..99adcad9f0 --- /dev/null +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/RecognizerUtils.kt @@ -0,0 +1,269 @@ +package com.swmansion.rnexecutorch.models.ocr.utils + +import com.swmansion.rnexecutorch.utils.ImageProcessor +import org.opencv.core.Core +import org.opencv.core.CvType +import org.opencv.core.Mat +import 
org.opencv.core.MatOfPoint2f +import org.opencv.core.Point +import org.opencv.core.Rect +import org.opencv.core.Scalar +import org.opencv.core.Size +import org.opencv.imgproc.Imgproc +import kotlin.math.max +import kotlin.math.min +import kotlin.math.pow +import kotlin.math.sqrt + +class RecognizerUtils { + companion object { + private fun calculateRatio(width: Int, height: Int): Double { + var ratio = width.toDouble() / height.toDouble() + if (ratio < 1.0) { + ratio = 1.0 / ratio + } + + return ratio + } + + private fun findIntersection(r1: Rect, r2: Rect): Rect { + val aLeft = r1.x + val aTop = r1.y + val aRight = r1.x + r1.width + val aBottom = r1.y + r1.height + + val bLeft = r2.x + val bTop = r2.y + val bRight = r2.x + r2.width + val bBottom = r2.y + r2.height + + val iLeft = max(aLeft, bLeft) + val iTop = max(aTop, bTop) + val iRight = min(aRight, bRight) + val iBottom = min(aBottom, bBottom) + + return if (iRight > iLeft && iBottom > iTop) { + Rect(iLeft, iTop, iRight - iLeft, iBottom - iTop) + } else { + Rect() + } + } + + private fun adjustContrastGrey(img: Mat, target: Double): Mat { + var high = 0 + var low = 255 + + for (i in 0 until img.rows()) { + for (j in 0 until img.cols()) { + val pixel = img.get(i, j)[0].toInt() + high = maxOf(high, pixel) + low = minOf(low, pixel) + } + } + + val contrast = (high - low) / 255.0 + + if (contrast < target) { + val ratio = 200.0 / maxOf(10, high - low) + val tempImg = Mat() + img.convertTo(tempImg, CvType.CV_32F) + Core.subtract(tempImg, Scalar(low.toDouble() - 25), tempImg) + Core.multiply(tempImg, Scalar(ratio), tempImg) + Imgproc.threshold(tempImg, tempImg, 255.0, 255.0, Imgproc.THRESH_TRUNC) + Imgproc.threshold(tempImg, tempImg, 0.0, 255.0, Imgproc.THRESH_TOZERO) + tempImg.convertTo(tempImg, CvType.CV_8U) + + return tempImg + } + + return img + } + + private fun computeRatioAndResize(img: Mat, width: Int, height: Int, modelHeight: Int): Mat { + var ratio = width.toDouble() / height.toDouble() + + if (ratio < 
1.0) { + ratio = + calculateRatio(width, height) + Imgproc.resize( + img, img, Size(modelHeight.toDouble(), (modelHeight * ratio)), + 0.0, 0.0, Imgproc.INTER_LANCZOS4 + ) + } else { + Imgproc.resize( + img, img, Size((modelHeight * ratio), modelHeight.toDouble()), + 0.0, 0.0, Imgproc.INTER_LANCZOS4 + ) + } + + return img + } + + fun softmax(inputs: Mat): Mat { + val maxVal = Mat() + Core.reduce(inputs, maxVal, 1, Core.REDUCE_MAX, CvType.CV_32F) + + val tiledMaxVal = Mat() + Core.repeat(maxVal, 1, inputs.width(), tiledMaxVal) + val expInputs = Mat() + Core.subtract(inputs, tiledMaxVal, expInputs) + Core.exp(expInputs, expInputs) + + val sumExp = Mat() + Core.reduce(expInputs, sumExp, 1, Core.REDUCE_SUM, CvType.CV_32F) + + val tiledSumExp = Mat() + Core.repeat(sumExp, 1, inputs.width(), tiledSumExp) + val softmaxOutput = Mat() + Core.divide(expInputs, tiledSumExp, softmaxOutput) + + return softmaxOutput + } + + fun sumProbabilityRows(probabilities: Mat, modelOutputHeight: Int): FloatArray { + val predsNorm = FloatArray(probabilities.rows()) + + for (i in 0 until probabilities.rows()) { + var sum = 0.0 + for (j in 0 until modelOutputHeight) { + sum += probabilities.get(i, j)[0] + } + predsNorm[i] = sum.toFloat() + } + + return predsNorm + } + + fun divideMatrixByVector(matrix: Mat, vector: FloatArray): Mat { + for (i in 0 until matrix.rows()) { + for (j in 0 until matrix.cols()) { + val value = matrix.get(i, j)[0] / vector[i] + matrix.put(i, j, value) + } + } + + return matrix + } + + fun findMaxValuesAndIndices(probabilities: Mat): Pair> { + val values = DoubleArray(probabilities.rows()) + val indices = mutableListOf() + + for (i in 0 until probabilities.rows()) { + val row = probabilities.row(i) + val minMaxLocResult = Core.minMaxLoc(row) + + values[i] = minMaxLocResult.maxVal + indices.add(minMaxLocResult.maxLoc.x.toInt()) + } + + return Pair(values, indices) + } + + fun computeConfidenceScore(valuesArray: DoubleArray, indicesArray: List): Double { + val 
predsMaxProb = mutableListOf() + for ((index, value) in indicesArray.withIndex()) { + if (value != 0) predsMaxProb.add(valuesArray[index]) + } + + val nonZeroValues = + if (predsMaxProb.isEmpty()) doubleArrayOf(0.0) else predsMaxProb.toDoubleArray() + val product = nonZeroValues.reduce { acc, d -> acc * d } + val score = product.pow(2.0 / sqrt(nonZeroValues.size.toDouble())) + + return score + } + + fun calculateResizeRatioAndPaddings( + width: Int, + height: Int, + desiredWidth: Int, + desiredHeight: Int + ): Map { + val newRatioH = desiredHeight.toFloat() / height + val newRatioW = desiredWidth.toFloat() / width + var resizeRatio = minOf(newRatioH, newRatioW) + + val newWidth = (width * resizeRatio).toInt() + val newHeight = (height * resizeRatio).toInt() + + val deltaW = desiredWidth - newWidth + val deltaH = desiredHeight - newHeight + + val top = deltaH / 2 + val left = deltaW / 2 + + val heightRatio = height.toFloat() / desiredHeight + val widthRatio = width.toFloat() / desiredWidth + + resizeRatio = maxOf(heightRatio, widthRatio) + + return mapOf( + "resizeRatio" to resizeRatio, + "top" to top, + "left" to left + ) + } + + fun getCroppedImage(box: OCRbBox, image: Mat, modelHeight: Int): Mat { + val cords = box.bBox + val angle = box.angle + val points = ArrayList() + + cords.forEach { point -> + points.add(Point(point.x, point.y)) + } + + val rotatedRect = Imgproc.minAreaRect(MatOfPoint2f(*points.toTypedArray())) + val imageCenter = Point((image.cols() / 2.0), (image.rows() / 2.0)) + val rotationMatrix = Imgproc.getRotationMatrix2D(imageCenter, angle, 1.0) + val rotatedImage = Mat() + Imgproc.warpAffine(image, rotatedImage, rotationMatrix, image.size(), Imgproc.INTER_LINEAR) + + val rectPoints = Array(4) { Point() } + rotatedRect.points(rectPoints) + val transformedPoints = arrayOfNulls(4) + val rectMat = Mat(4, 2, CvType.CV_32FC2) + for (i in 0 until 4) { + rectMat.put(i, 0, *doubleArrayOf(rectPoints[i].x, rectPoints[i].y)) + } + Core.transform(rectMat, 
rectMat, rotationMatrix) + + for (i in 0 until 4) { + transformedPoints[i] = Point(rectMat.get(i, 0)[0], rectMat.get(i, 0)[1]) + } + + var boundingBox = + Imgproc.boundingRect(MatOfPoint2f(*transformedPoints.filterNotNull().toTypedArray())) + val validRegion = Rect(0, 0, rotatedImage.cols(), rotatedImage.rows()) + boundingBox = findIntersection(boundingBox, validRegion) + val croppedImage = Mat(rotatedImage, boundingBox) + if (croppedImage.empty()) { + return croppedImage + } + + return computeRatioAndResize(croppedImage, boundingBox.width, boundingBox.height, modelHeight) + } + + fun normalizeForRecognizer(image: Mat, adjustContrast: Double): Mat { + var img = image.clone() + + if (adjustContrast > 0) { + img = adjustContrastGrey(img, adjustContrast) + } + + val desiredWidth = when { + img.width() >= Constants.LARGE_MODEL_WIDTH -> Constants.LARGE_MODEL_WIDTH + img.width() >= Constants.MEDIUM_MODEL_WIDTH -> Constants.MEDIUM_MODEL_WIDTH + else -> Constants.SMALL_MODEL_WIDTH + } + + img = ImageProcessor.resizeWithPadding(img, desiredWidth, Constants.MODEL_HEIGHT) + img.convertTo(img, CvType.CV_32F, 1.0 / 255.0) + Core.subtract(img, Scalar(0.5), img) + Core.multiply(img, Scalar(2.0), img) + + return img + } + } +} diff --git a/android/src/main/java/com/swmansion/rnexecutorch/utils/ImageProcessor.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/ImageProcessor.kt index 5488ecb476..1e00aa4807 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/utils/ImageProcessor.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/utils/ImageProcessor.kt @@ -3,20 +3,30 @@ package com.swmansion.rnexecutorch.utils import android.content.Context import android.net.Uri import android.util.Base64 +import android.util.Log +import org.opencv.core.Core import org.opencv.core.CvType import org.opencv.core.Mat +import org.opencv.core.Scalar +import org.opencv.core.Size import org.opencv.imgcodecs.Imgcodecs +import org.opencv.imgproc.Imgproc import 
org.pytorch.executorch.EValue import org.pytorch.executorch.Tensor import java.io.File import java.io.InputStream import java.net.URL import java.util.UUID +import kotlin.math.floor class ImageProcessor { companion object { fun matToEValue(mat: Mat, shape: LongArray): EValue { + return matToEValue(mat, shape, Scalar(0.0, 0.0, 0.0), Scalar(1.0, 1.0, 1.0)) + } + + fun matToEValue(mat: Mat, shape: LongArray, mean: Scalar, variance: Scalar): EValue { val pixelCount = mat.cols() * mat.rows() val floatArray = FloatArray(pixelCount * 3) @@ -26,19 +36,38 @@ class ImageProcessor { val pixel = mat.get(row, col) if (mat.type() == CvType.CV_8UC3 || mat.type() == CvType.CV_8UC4) { - val b = pixel[0] / 255.0f - val g = pixel[1] / 255.0f - val r = pixel[2] / 255.0f + val b = (pixel[0] - mean.`val`[0] * 255.0f) / (variance.`val`[0] * 255.0f) + val g = (pixel[1] - mean.`val`[1] * 255.0f) / (variance.`val`[1] * 255.0f) + val r = (pixel[2] - mean.`val`[2] * 255.0f) / (variance.`val`[2] * 255.0f) - floatArray[i] = r.toFloat() - floatArray[pixelCount + i] = g.toFloat() - floatArray[2 * pixelCount + i] = b.toFloat() + floatArray[0 * pixelCount + i] = b.toFloat() + floatArray[1 * pixelCount + i] = g.toFloat() + floatArray[2 * pixelCount + i] = r.toFloat() } } return EValue.from(Tensor.fromBlob(floatArray, shape)) } + fun matToEValueGray(mat: Mat): EValue { + val pixelCount = mat.cols() * mat.rows() + val floatArray = FloatArray(pixelCount) + + for (i in 0 until pixelCount) { + val row = i / mat.cols() + val col = i % mat.cols() + val pixel = mat.get(row, col) + floatArray[i] = pixel[0].toFloat() + } + + return EValue.from( + Tensor.fromBlob( + floatArray, + longArrayOf(1, 1, mat.rows().toLong(), mat.cols().toLong()) + ) + ) + } + fun EValueToMat(array: FloatArray, width: Int, height: Int): Mat { val mat = Mat(height, width, CvType.CV_8UC3) @@ -64,7 +93,7 @@ class ImageProcessor { Imgcodecs.imwrite(tempFile.absolutePath, mat) return "file://${tempFile.absolutePath}" - }catch (e: 
Exception) { + } catch (e: Exception) { throw Exception(ETError.FileWriteFailed.toString()) } } @@ -89,11 +118,13 @@ class ImageProcessor { } inputImage = Imgcodecs.imdecode(encodedData, Imgcodecs.IMREAD_COLOR) } + scheme.equals("file", ignoreCase = true) -> { //device storage val path = uri.path inputImage = Imgcodecs.imread(path, Imgcodecs.IMREAD_COLOR) } + else -> { //external source val url = URL(source) @@ -117,5 +148,70 @@ class ImageProcessor { return inputImage } + + fun resizeWithPadding(img: Mat, desiredWidth: Int, desiredHeight: Int): Mat { + val height = img.rows() + val width = img.cols() + val heightRatio = desiredHeight.toFloat() / height + val widthRatio = desiredWidth.toFloat() / width + val resizeRatio = minOf(heightRatio, widthRatio) + val newWidth = (width * resizeRatio).toInt() + val newHeight = (height * resizeRatio).toInt() + + val resizedImg = Mat() + Imgproc.resize( + img, + resizedImg, + Size(newWidth.toDouble(), newHeight.toDouble()), + 0.0, + 0.0, + Imgproc.INTER_AREA + ) + + val cornerPatchSize = maxOf(1, minOf(width, height) / 30) + val corners = listOf( + img.submat(0, cornerPatchSize, 0, cornerPatchSize), + img.submat(0, cornerPatchSize, width - cornerPatchSize, width), + img.submat(height - cornerPatchSize, height, 0, cornerPatchSize), + img.submat(height - cornerPatchSize, height, width - cornerPatchSize, width) + ) + + var backgroundScalar = Core.mean(corners[0]) + for (i in 1 until corners.size) { + val mean = Core.mean(corners[i]) + backgroundScalar = Scalar( + backgroundScalar.`val`[0] + mean.`val`[0], + backgroundScalar.`val`[1] + mean.`val`[1], + backgroundScalar.`val`[2] + mean.`val`[2] + ) + } + + backgroundScalar = Scalar( + floor(backgroundScalar.`val`[0] / corners.size), + floor(backgroundScalar.`val`[1] / corners.size), + floor(backgroundScalar.`val`[2] / corners.size) + ) + + val deltaW = desiredWidth - newWidth + val deltaH = desiredHeight - newHeight + val top = deltaH / 2 + val bottom = deltaH - top + val left = 
deltaW / 2 + val right = deltaW - left + + val centeredImg = Mat() + Core.copyMakeBorder( + resizedImg, + centeredImg, + top, + bottom, + left, + right, + Core.BORDER_CONSTANT, + backgroundScalar + ) + + return centeredImg + } } } From d0090719d4d0d4382de0fd78ea6beb2d0eb1266c Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 24 Feb 2025 15:42:23 +0100 Subject: [PATCH 18/19] feat: implemented hookless api, also added fetching with expo file system --- .../java/com/swmansion/rnexecutorch/OCR.kt | 42 ++++------ .../rnexecutorch/models/BaseModel.kt | 1 + .../models/ocr/RecognitionHandler.kt | 3 +- .../models/ocr/utils/CTCLabelConverter.kt | 20 +---- .../rnexecutorch/utils/ArrayUtils.kt | 4 +- .../utils/OkHttpClientSingleton.kt | 7 -- ios/RnExecutorch/OCR.mm | 65 ++++++---------- ios/RnExecutorch/models/ocr/Detector.mm | 1 - .../models/ocr/RecognitionHandler.h | 3 +- .../models/ocr/RecognitionHandler.mm | 7 +- .../models/ocr/utils/CTCLabelConverter.h | 5 +- .../models/ocr/utils/CTCLabelConverter.mm | 27 +------ .../computer_vision/useOCR.ts} | 76 +++++++++---------- src/index.tsx | 2 + src/modules/computer_vision/OCRModule.ts | 74 ++++++++++++++++++ src/native/NativeOCR.ts | 3 +- src/native/RnExecutorchModules.ts | 24 ++++++ 17 files changed, 185 insertions(+), 179 deletions(-) delete mode 100644 android/src/main/java/com/swmansion/rnexecutorch/utils/OkHttpClientSingleton.kt rename src/{OCR.ts => hooks/computer_vision/useOCR.ts} (51%) create mode 100644 src/modules/computer_vision/OCRModule.ts diff --git a/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt b/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt index 85acf06260..4f0b926b37 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/OCR.kt @@ -9,8 +9,6 @@ import org.opencv.android.OpenCVLoader import com.swmansion.rnexecutorch.models.ocr.Detector import com.swmansion.rnexecutorch.models.ocr.RecognitionHandler import 
com.swmansion.rnexecutorch.models.ocr.utils.Constants -import com.swmansion.rnexecutorch.utils.Fetcher -import com.swmansion.rnexecutorch.utils.ResourceType import org.opencv.imgproc.Imgproc class OCR(reactContext: ReactApplicationContext) : @@ -37,40 +35,28 @@ class OCR(reactContext: ReactApplicationContext) : recognizerSourceMedium: String, recognizerSourceSmall: String, symbols: String, - languageDictPath: String, promise: Promise ) { try { detector = Detector(reactApplicationContext) detector.loadModel(detectorSource) - Fetcher.downloadResource( - reactApplicationContext, - languageDictPath, - ResourceType.TXT, - false, - { path, error -> - if (error != null) { - throw Error(error.message!!) - } - recognitionHandler = RecognitionHandler( - symbols, - path!!, - reactApplicationContext - ) + recognitionHandler = RecognitionHandler( + symbols, + reactApplicationContext + ) - recognitionHandler.loadRecognizers( - recognizerSourceLarge, - recognizerSourceMedium, - recognizerSourceSmall - ) { _, errorRecognizer -> - if (errorRecognizer != null) { - throw Error(errorRecognizer.message!!) - } + recognitionHandler.loadRecognizers( + recognizerSourceLarge, + recognizerSourceMedium, + recognizerSourceSmall + ) { _, errorRecognizer -> + if (errorRecognizer != null) { + throw Error(errorRecognizer.message!!) 
+ } - promise.resolve(0) - } - }) + promise.resolve(0) + } } catch (e: Exception) { promise.reject(e.message!!, ETError.InvalidModelSource.toString()) } diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/BaseModel.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/BaseModel.kt index 19921443c8..764827f542 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/models/BaseModel.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/BaseModel.kt @@ -4,6 +4,7 @@ import android.content.Context import com.swmansion.rnexecutorch.utils.ETError import org.pytorch.executorch.EValue import org.pytorch.executorch.Module +import org.pytorch.executorch.Tensor import java.net.URL diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt index 1aeae02e22..90fd61280e 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/RecognitionHandler.kt @@ -13,13 +13,12 @@ import org.opencv.core.Mat class RecognitionHandler( symbols: String, - languageDictPath: String, reactApplicationContext: ReactApplicationContext ) { private val recognizerLarge = Recognizer(reactApplicationContext) private val recognizerMedium = Recognizer(reactApplicationContext) private val recognizerSmall = Recognizer(reactApplicationContext) - private val converter = CTCLabelConverter(symbols, mapOf(languageDictPath to "key")) + private val converter = CTCLabelConverter(symbols) private fun runModel(croppedImage: Mat): Pair, Double> { val result: Pair, Double> = if (croppedImage.cols() >= Constants.LARGE_MODEL_WIDTH) { diff --git a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt index 336d2f600f..007e7e7c29 100644 --- 
a/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/models/ocr/utils/CTCLabelConverter.kt @@ -4,12 +4,10 @@ import java.io.File class CTCLabelConverter( characters: String, - dictPathList: Map ) { private val dict = mutableMapOf() - val character: List + private val character: List private val ignoreIdx: List - private val dictList: Map> init { val mutableCharacters = mutableListOf("[blank]") @@ -22,22 +20,6 @@ class CTCLabelConverter( val ignoreIndexes = mutableListOf(0) ignoreIdx = ignoreIndexes.toList() - - dictList = loadDictionariesWithDictPathList(dictPathList) - } - - private fun loadDictionariesWithDictPathList(dictPathList: Map): Map> { - val tempDictList = mutableMapOf>() - dictPathList.forEach { (lang, dictPath) -> - runCatching { - File(dictPath).readLines() - }.onSuccess { lines -> - tempDictList[lang] = lines - }.onFailure { error -> - println("Error reading file: ${error.localizedMessage}") - } - } - return tempDictList.toMap() } fun decodeGreedy(textIndex: List, length: Int): List { diff --git a/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt index 9c295499e3..3dcbbed45d 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt @@ -19,9 +19,7 @@ class ArrayUtils { fun createCharArray(input: ReadableArray): CharArray { return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toChar() }.toCharArray() } - fun createByteArray(input: ReadableArray): ByteArray { - return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toByte() }.toByteArray() - } + fun createIntArray(input: ReadableArray): IntArray { return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index) }.toIntArray() } diff --git 
a/android/src/main/java/com/swmansion/rnexecutorch/utils/OkHttpClientSingleton.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/OkHttpClientSingleton.kt deleted file mode 100644 index 7a2dda79f5..0000000000 --- a/android/src/main/java/com/swmansion/rnexecutorch/utils/OkHttpClientSingleton.kt +++ /dev/null @@ -1,7 +0,0 @@ -package com.swmansion.rnexecutorch.utils - -import okhttp3.OkHttpClient - -object OkHttpClientSingleton { - val instance = OkHttpClient() -} \ No newline at end of file diff --git a/ios/RnExecutorch/OCR.mm b/ios/RnExecutorch/OCR.mm index 975c82989b..59740c90bb 100644 --- a/ios/RnExecutorch/OCR.mm +++ b/ios/RnExecutorch/OCR.mm @@ -1,7 +1,6 @@ #import "OCR.h" #import "models/ocr/Detector.h" #import "models/ocr/RecognitionHandler.h" -#import "utils/Fetcher.h" #import "utils/ImageProcessor.h" #import #import @@ -18,7 +17,6 @@ - (void)loadModule:(NSString *)detectorSource recognizerSourceMedium:(NSString *)recognizerSourceMedium recognizerSourceSmall:(NSString *)recognizerSourceSmall symbols:(NSString *)symbols - languageDictPath:(NSString *)languageDictPath resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject { detector = [[Detector alloc] init]; @@ -37,44 +35,31 @@ - (void)loadModule:(NSString *)detectorSource error); return; } - [Fetcher fetchResource:[NSURL URLWithString:languageDictPath] - resourceType:ResourceType::TXT - completionHandler:^(NSString *filePath, NSError *error) { - if (error) { - reject(@"init_module_error", - @"Failed to initialize converter module", error); - return; - } - - self->recognitionHandler = - [[RecognitionHandler alloc] initWithSymbols:symbols - languageDictPath:filePath]; - [self->recognitionHandler - loadRecognizers:recognizerSourceLarge - mediumRecognizerPath:recognizerSourceMedium - smallRecognizerPath:recognizerSourceSmall - completion:^(BOOL allModelsLoaded, - NSNumber *errorCode) { - if (allModelsLoaded) { - resolve(@(YES)); - } else { - NSError *error = [NSError - 
errorWithDomain:@"OCRErrorDomain" - code:[errorCode intValue] - userInfo:@{ - NSLocalizedDescriptionKey : - [NSString stringWithFormat: - @"%ld", - (long)[errorCode - longValue]] - }]; - reject(@"init_recognizer_error", - @"Failed to initialize one or more " - @"recognizer models", - error); - } - }]; - }]; + self->recognitionHandler = + [[RecognitionHandler alloc] initWithSymbols:symbols]; + [self->recognitionHandler + loadRecognizers:recognizerSourceLarge + mediumRecognizerPath:recognizerSourceMedium + smallRecognizerPath:recognizerSourceSmall + completion:^(BOOL allModelsLoaded, NSNumber *errorCode) { + if (allModelsLoaded) { + resolve(@(YES)); + } else { + NSError *error = [NSError + errorWithDomain:@"OCRErrorDomain" + code:[errorCode intValue] + userInfo:@{ + NSLocalizedDescriptionKey : [NSString + stringWithFormat:@"%ld", + (long)[errorCode + longValue]] + }]; + reject(@"init_recognizer_error", + @"Failed to initialize one or more " + @"recognizer models", + error); + } + }]; }]; } diff --git a/ios/RnExecutorch/models/ocr/Detector.mm b/ios/RnExecutorch/models/ocr/Detector.mm index 56604ac607..20b82b5ee7 100644 --- a/ios/RnExecutorch/models/ocr/Detector.mm +++ b/ios/RnExecutorch/models/ocr/Detector.mm @@ -77,7 +77,6 @@ group each character into a single instance (sequence) Both matrices are usingTextThreshold:textThreshold linkThreshold:linkThreshold lowTextThreshold:lowTextThreshold]; - NSLog(@"Detected boxes: %lu", (unsigned long)bBoxesList.count); bBoxesList = [DetectorUtils restoreBboxRatio:bBoxesList usingRestoreRatio:restoreRatio]; bBoxesList = [DetectorUtils groupTextBoxes:bBoxesList diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.h b/ios/RnExecutorch/models/ocr/RecognitionHandler.h index b638eff0e6..412504370e 100644 --- a/ios/RnExecutorch/models/ocr/RecognitionHandler.h +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.h @@ -9,8 +9,7 @@ constexpr CGFloat adjustContrast = 0.2; @interface RecognitionHandler : NSObject -- 
(instancetype)initWithSymbols:(NSString *)symbols - languageDictPath:(NSString *)languageDictPath; +- (instancetype)initWithSymbols:(NSString *)symbols; - (void)loadRecognizers:(NSString *)largeRecognizerPath mediumRecognizerPath:(NSString *)mediumRecognizerPath smallRecognizerPath:(NSString *)smallRecognizerPath diff --git a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm index 57cc419a58..60616b9099 100644 --- a/ios/RnExecutorch/models/ocr/RecognitionHandler.mm +++ b/ios/RnExecutorch/models/ocr/RecognitionHandler.mm @@ -1,5 +1,4 @@ #import "RecognitionHandler.h" -#import "../../utils/Fetcher.h" #import "../../utils/ImageProcessor.h" #import "./utils/CTCLabelConverter.h" #import "./utils/OCRUtils.h" @@ -21,8 +20,7 @@ @implementation RecognitionHandler { CTCLabelConverter *converter; } -- (instancetype)initWithSymbols:(NSString *)symbols - languageDictPath:(NSString *)languageDictPath { +- (instancetype)initWithSymbols:(NSString *)symbols { self = [super init]; if (self) { recognizerLarge = [[Recognizer alloc] init]; @@ -31,8 +29,7 @@ - (instancetype)initWithSymbols:(NSString *)symbols converter = [[CTCLabelConverter alloc] initWithCharacters:symbols - separatorList:@{} - dictPathList:@{@"key" : languageDictPath}]; + separatorList:@{}]; } return self; } diff --git a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h index cae07be437..498710dd03 100644 --- a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h +++ b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.h @@ -9,10 +9,7 @@ @property(strong, nonatomic) NSDictionary *dictList; - (instancetype)initWithCharacters:(NSString *)characters - separatorList:(NSDictionary *)separatorList - dictPathList:(NSDictionary *)dictPathList; -- (void)loadDictionariesWithDictPathList: - (NSDictionary *)dictPathList; + separatorList:(NSDictionary *)separatorList; - (NSArray *)decodeGreedy:(NSArray *)textIndex 
length:(NSInteger)length; diff --git a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm index ca0fd30da0..7d50e3813f 100644 --- a/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm +++ b/ios/RnExecutorch/models/ocr/utils/CTCLabelConverter.mm @@ -3,8 +3,7 @@ @implementation CTCLabelConverter - (instancetype)initWithCharacters:(NSString *)characters - separatorList:(NSDictionary *)separatorList - dictPathList:(NSDictionary *)dictPathList { + separatorList:(NSDictionary *)separatorList{ self = [super init]; if (self) { _dict = [NSMutableDictionary dictionary]; @@ -29,34 +28,10 @@ - (instancetype)initWithCharacters:(NSString *)characters } } _ignoreIdx = [ignoreIndexes copy]; - _dictList = [NSDictionary dictionary]; - [self loadDictionariesWithDictPathList:dictPathList]; } return self; } -- (void)loadDictionariesWithDictPathList: - (NSDictionary *)dictPathList { - NSMutableDictionary *tempDictList = [NSMutableDictionary dictionary]; - for (NSString *lang in dictPathList.allKeys) { - NSString *dictPath = dictPathList[lang]; - NSError *error; - NSString *fileContents = - [NSString stringWithContentsOfFile:dictPath - encoding:NSUTF8StringEncoding - error:&error]; - if (error) { - NSLog(@"Error reading file: %@", error.localizedDescription); - continue; - } - NSArray *lines = [fileContents - componentsSeparatedByCharactersInSet:[NSCharacterSet - newlineCharacterSet]]; - [tempDictList setObject:lines forKey:lang]; - } - _dictList = [tempDictList copy]; -} - - (NSArray *)decodeGreedy:(NSArray *)textIndex length:(NSInteger)length { NSMutableArray *texts = [NSMutableArray array]; diff --git a/src/OCR.ts b/src/hooks/computer_vision/useOCR.ts similarity index 51% rename from src/OCR.ts rename to src/hooks/computer_vision/useOCR.ts index 17c4aafcdb..95bb896d3b 100644 --- a/src/OCR.ts +++ b/src/hooks/computer_vision/useOCR.ts @@ -1,26 +1,20 @@ import { useEffect, useState } from 'react'; -import { ResourceSource } 
from './types/common'; -import { OCR } from './native/RnExecutorchModules'; -import { ETError, getError } from './Error'; -import { Image } from 'react-native'; -import { OCRDetection } from './types/ocr'; -import { symbols } from './constants/ocr/symbols'; -import { languageDicts } from './constants/ocr/languageDicts'; +import { fetchResource } from '../../utils/fetchResource'; +import { languageDicts } from '../../constants/ocr/languageDicts'; +import { symbols } from '../../constants/ocr/symbols'; +import { getError, ETError } from '../../Error'; +import { _OCRModule } from '../../native/RnExecutorchModules'; +import { ResourceSource } from '../../types/common'; +import { OCRDetection } from '../../types/ocr'; interface OCRModule { error: string | null; isReady: boolean; isGenerating: boolean; forward: (input: string) => Promise; + downloadProgress: number; } -const getResourcePath = (source: ResourceSource) => { - if (typeof source === 'number') { - return Image.resolveAssetSource(source).uri; - } - return source; -}; - export const useOCR = ({ detectorSource, recognizerSources, @@ -34,47 +28,48 @@ export const useOCR = ({ }; language?: string; }): OCRModule => { + const [module, _] = useState(() => new _OCRModule()); const [error, setError] = useState(null); const [isReady, setIsReady] = useState(false); const [isGenerating, setIsGenerating] = useState(false); + const [downloadProgress, setDownloadProgress] = useState(0); useEffect(() => { const loadModel = async () => { - if (!detectorSource || Object.keys(recognizerSources).length === 0) - return; - - const detectorPath = getResourcePath(detectorSource); - const recognizerPaths = {} as { - recognizerLarge: string; - recognizerMedium: string; - recognizerSmall: string; - }; + try { + if (!detectorSource || Object.keys(recognizerSources).length === 0) + return; - if (!symbols[language] || !languageDicts[language]) { - setError(getError(ETError.LanguageNotSupported)); - return; - } + const recognizerPaths = {} 
as { + recognizerLarge: string; + recognizerMedium: string; + recognizerSmall: string; + }; - for (const key in recognizerSources) { - if (recognizerSources.hasOwnProperty(key)) { - recognizerPaths[key as keyof typeof recognizerPaths] = - getResourcePath( - recognizerSources[key as keyof typeof recognizerSources] - ); + if (!symbols[language] || !languageDicts[language]) { + setError(getError(ETError.LanguageNotSupported)); + return; } - } - const languageDictPath = getResourcePath(languageDicts[language]); + const detectorPath = await fetchResource(detectorSource); + + await Promise.all([ + fetchResource(recognizerSources.recognizerLarge, setDownloadProgress), + fetchResource(recognizerSources.recognizerMedium), + fetchResource(recognizerSources.recognizerSmall), + ]).then((values) => { + recognizerPaths.recognizerLarge = values[0]; + recognizerPaths.recognizerMedium = values[1]; + recognizerPaths.recognizerSmall = values[2]; + }); - try { setIsReady(false); - await OCR.loadModule( + await module.loadModule( detectorPath, recognizerPaths.recognizerLarge, recognizerPaths.recognizerMedium, recognizerPaths.recognizerSmall, - symbols[language], - languageDictPath + symbols[language] ); setIsReady(true); } catch (e) { @@ -96,7 +91,7 @@ export const useOCR = ({ try { setIsGenerating(true); - const output = await OCR.forward(input); + const output = await module.forward(input); return output; } catch (e) { throw new Error(getError(e)); @@ -110,5 +105,6 @@ export const useOCR = ({ isReady, isGenerating, forward, + downloadProgress, }; }; diff --git a/src/index.tsx b/src/index.tsx index 066e14df4f..f5bfa1854d 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -2,6 +2,7 @@ export * from './hooks/computer_vision/useClassification'; export * from './hooks/computer_vision/useObjectDetection'; export * from './hooks/computer_vision/useStyleTransfer'; +export * from './hooks/computer_vision/useOCR'; export * from './hooks/natural_language_processing/useLLM'; @@ -11,6 +12,7 @@ 
export * from './hooks/general/useExecutorchModule'; export * from './modules/computer_vision/ClassificationModule'; export * from './modules/computer_vision/ObjectDetectionModule'; export * from './modules/computer_vision/StyleTransferModule'; +export * from './modules/computer_vision/OCRModule'; export * from './modules/natural_language_processing/LLMModule'; diff --git a/src/modules/computer_vision/OCRModule.ts b/src/modules/computer_vision/OCRModule.ts new file mode 100644 index 0000000000..e154204d04 --- /dev/null +++ b/src/modules/computer_vision/OCRModule.ts @@ -0,0 +1,74 @@ +import { languageDicts } from '../../constants/ocr/languageDicts'; +import { symbols } from '../../constants/ocr/symbols'; +import { getError, ETError } from '../../Error'; +import { _OCRModule } from '../../native/RnExecutorchModules'; +import { ResourceSource } from '../../types/common'; +import { fetchResource } from '../../utils/fetchResource'; + +export class OCRModule { + static module = new _OCRModule(); + + static onDownloadProgressCallback = (_downloadProgress: number) => {}; + + static async load( + detectorSource: ResourceSource, + recognizerSources: { + recognizerLarge: ResourceSource; + recognizerMedium: ResourceSource; + recognizerSmall: ResourceSource; + }, + language = 'en' + ) { + try { + if (!detectorSource || Object.keys(recognizerSources).length === 0) + return; + + const recognizerPaths = {} as { + recognizerLarge: string; + recognizerMedium: string; + recognizerSmall: string; + }; + + if (!symbols[language] || !languageDicts[language]) { + throw new Error(getError(ETError.LanguageNotSupported)); + } + + const detectorPath = await fetchResource(detectorSource); + + await Promise.all([ + fetchResource( + recognizerSources.recognizerLarge, + this.onDownloadProgressCallback + ), + fetchResource(recognizerSources.recognizerMedium), + fetchResource(recognizerSources.recognizerSmall), + ]).then((values) => { + recognizerPaths.recognizerLarge = values[0]; + 
recognizerPaths.recognizerMedium = values[1]; + recognizerPaths.recognizerSmall = values[2]; + }); + + await this.module.loadModule( + detectorPath, + recognizerPaths.recognizerLarge, + recognizerPaths.recognizerMedium, + recognizerPaths.recognizerSmall, + symbols[language] + ); + } catch (e) { + throw new Error(getError(e)); + } + } + + static async forward(input: string) { + try { + return await this.module.forward(input); + } catch (e) { + throw new Error(getError(e)); + } + } + + static onDownloadProgress(callback: (downloadProgress: number) => void) { + this.onDownloadProgressCallback = callback; + } +} diff --git a/src/native/NativeOCR.ts b/src/native/NativeOCR.ts index 305bf01273..2c14c6ac0d 100644 --- a/src/native/NativeOCR.ts +++ b/src/native/NativeOCR.ts @@ -8,8 +8,7 @@ export interface Spec extends TurboModule { recognizerSourceLarge: string, recognizerSourceMedium: string, recognizerSourceSmall: string, - symbols: string, - languageDictPath: string + symbols: string ): Promise; forward(input: string): Promise; } diff --git a/src/native/RnExecutorchModules.ts b/src/native/RnExecutorchModules.ts index c8044aa473..42d554dc35 100644 --- a/src/native/RnExecutorchModules.ts +++ b/src/native/RnExecutorchModules.ts @@ -3,6 +3,7 @@ import { Spec as ClassificationInterface } from './NativeClassification'; import { Spec as ObjectDetectionInterface } from './NativeObjectDetection'; import { Spec as StyleTransferInterface } from './NativeStyleTransfer'; import { Spec as ETModuleInterface } from './NativeETModule'; +import { Spec as OCRInterface } from './NativeOCR'; const LINKING_ERROR = `The package 'react-native-executorch' doesn't seem to be linked. 
Make sure: \n\n` + @@ -125,6 +126,28 @@ class _StyleTransferModule { } } +class _OCRModule { + async forward(input: string): ReturnType { + return await OCR.forward(input); + } + + async loadModule( + detectorSource: string | number, + recognizerLarge: string | number, + recognizerMedium: string | number, + recognizerSmall: string | number, + symbols: string + ): ReturnType { + return await OCR.loadModule( + detectorSource, + recognizerLarge, + recognizerMedium, + recognizerSmall, + symbols + ); + } +} + class _SpeechToTextModule { async generate(waveform: number[]): Promise { return await SpeechToText.generate(waveform); @@ -187,4 +210,5 @@ export { _StyleTransferModule, _ObjectDetectionModule, _SpeechToTextModule, + _OCRModule, }; From 6850891d13e2214ce9b3581faa7b591d2e3d9522 Mon Sep 17 00:00:00 2001 From: Norbert Klockiewicz Date: Mon, 24 Feb 2025 16:18:27 +0100 Subject: [PATCH 19/19] refactor: remove unnecessary _OCRModule --- src/hooks/computer_vision/useOCR.ts | 7 +++---- src/modules/computer_vision/OCRModule.ts | 8 +++----- src/native/RnExecutorchModules.ts | 24 ------------------------ 3 files changed, 6 insertions(+), 33 deletions(-) diff --git a/src/hooks/computer_vision/useOCR.ts b/src/hooks/computer_vision/useOCR.ts index 95bb896d3b..56ee04e412 100644 --- a/src/hooks/computer_vision/useOCR.ts +++ b/src/hooks/computer_vision/useOCR.ts @@ -3,7 +3,7 @@ import { fetchResource } from '../../utils/fetchResource'; import { languageDicts } from '../../constants/ocr/languageDicts'; import { symbols } from '../../constants/ocr/symbols'; import { getError, ETError } from '../../Error'; -import { _OCRModule } from '../../native/RnExecutorchModules'; +import { OCR } from '../../native/RnExecutorchModules'; import { ResourceSource } from '../../types/common'; import { OCRDetection } from '../../types/ocr'; @@ -28,7 +28,6 @@ export const useOCR = ({ }; language?: string; }): OCRModule => { - const [module, _] = useState(() => new _OCRModule()); const [error, setError] 
= useState(null); const [isReady, setIsReady] = useState(false); const [isGenerating, setIsGenerating] = useState(false); @@ -64,7 +63,7 @@ export const useOCR = ({ }); setIsReady(false); - await module.loadModule( + await OCR.loadModule( detectorPath, recognizerPaths.recognizerLarge, recognizerPaths.recognizerMedium, @@ -91,7 +90,7 @@ export const useOCR = ({ try { setIsGenerating(true); - const output = await module.forward(input); + const output = await OCR.forward(input); return output; } catch (e) { throw new Error(getError(e)); diff --git a/src/modules/computer_vision/OCRModule.ts b/src/modules/computer_vision/OCRModule.ts index e154204d04..26ea6f4e89 100644 --- a/src/modules/computer_vision/OCRModule.ts +++ b/src/modules/computer_vision/OCRModule.ts @@ -1,13 +1,11 @@ import { languageDicts } from '../../constants/ocr/languageDicts'; import { symbols } from '../../constants/ocr/symbols'; import { getError, ETError } from '../../Error'; -import { _OCRModule } from '../../native/RnExecutorchModules'; +import { OCR } from '../../native/RnExecutorchModules'; import { ResourceSource } from '../../types/common'; import { fetchResource } from '../../utils/fetchResource'; export class OCRModule { - static module = new _OCRModule(); - static onDownloadProgressCallback = (_downloadProgress: number) => {}; static async load( @@ -48,7 +46,7 @@ export class OCRModule { recognizerPaths.recognizerSmall = values[2]; }); - await this.module.loadModule( + await OCR.loadModule( detectorPath, recognizerPaths.recognizerLarge, recognizerPaths.recognizerMedium, @@ -62,7 +60,7 @@ export class OCRModule { static async forward(input: string) { try { - return await this.module.forward(input); + return await OCR.forward(input); } catch (e) { throw new Error(getError(e)); } diff --git a/src/native/RnExecutorchModules.ts b/src/native/RnExecutorchModules.ts index 42d554dc35..c8044aa473 100644 --- a/src/native/RnExecutorchModules.ts +++ b/src/native/RnExecutorchModules.ts @@ -3,7 +3,6 @@ 
import { Spec as ClassificationInterface } from './NativeClassification'; import { Spec as ObjectDetectionInterface } from './NativeObjectDetection'; import { Spec as StyleTransferInterface } from './NativeStyleTransfer'; import { Spec as ETModuleInterface } from './NativeETModule'; -import { Spec as OCRInterface } from './NativeOCR'; const LINKING_ERROR = `The package 'react-native-executorch' doesn't seem to be linked. Make sure: \n\n` + @@ -126,28 +125,6 @@ class _StyleTransferModule { } } -class _OCRModule { - async forward(input: string): ReturnType { - return await OCR.forward(input); - } - - async loadModule( - detectorSource: string | number, - recognizerLarge: string | number, - recognizerMedium: string | number, - recognizerSmall: string | number, - symbols: string - ): ReturnType { - return await OCR.loadModule( - detectorSource, - recognizerLarge, - recognizerMedium, - recognizerSmall, - symbols - ); - } -} - class _SpeechToTextModule { async generate(waveform: number[]): Promise { return await SpeechToText.generate(waveform); @@ -210,5 +187,4 @@ export { _StyleTransferModule, _ObjectDetectionModule, _SpeechToTextModule, - _OCRModule, };