diff --git a/android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt b/android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt index 7f1e73fc46..929817e713 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/ETModule.kt @@ -7,12 +7,13 @@ import com.facebook.react.bridge.ReadableArray import com.swmansion.rnexecutorch.utils.ArrayUtils import com.swmansion.rnexecutorch.utils.ETError import com.swmansion.rnexecutorch.utils.TensorUtils +import org.pytorch.executorch.EValue import org.pytorch.executorch.Module import java.net.URL class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(reactContext) { private lateinit var module: Module - + private var reactApplicationContext = reactContext; override fun getName(): String { return NAME } @@ -33,26 +34,40 @@ class ETModule(reactContext: ReactApplicationContext) : NativeETModuleSpec(react } override fun forward( - input: ReadableArray, - shape: ReadableArray, - inputType: Double, + inputs: ReadableArray, + shapes: ReadableArray, + inputTypes: ReadableArray, promise: Promise ) { + val inputEValues = ArrayList() try { - val executorchInput = - TensorUtils.getExecutorchInput(input, ArrayUtils.createLongArray(shape), inputType.toInt()) + for (i in 0 until inputs.size()) { + val currentInput = inputs.getArray(i) + ?: throw Exception(ETError.InvalidArgument.code.toString()) + val currentShape = shapes.getArray(i) + ?: throw Exception(ETError.InvalidArgument.code.toString()) + val currentInputType = inputTypes.getInt(i) - val result = module.forward(executorchInput) - val resultArray = Arguments.createArray() + val currentEValue = TensorUtils.getExecutorchInput( + currentInput, + ArrayUtils.createLongArray(currentShape), + currentInputType + ) - for (evalue in result) { - resultArray.pushArray(ArrayUtils.createReadableArray(evalue.toTensor())) + inputEValues.add(currentEValue) } - promise.resolve(resultArray) - return + val forwardOutputs = module.forward(*inputEValues.toTypedArray()); + val outputArray = Arguments.createArray() + + for (output in forwardOutputs) { + val arr = ArrayUtils.createReadableArrayFromTensor(output.toTensor()) + outputArray.pushArray(arr) + } + promise.resolve(outputArray) + } catch (e: IllegalArgumentException) { - //The error is thrown when transformation to Tensor fails + // The error is thrown when transformation to Tensor fails promise.reject("Forward Failed Execution", ETError.InvalidArgument.code.toString()) return } catch (e: Exception) { diff --git a/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt index 10ffe29eb8..8693a69dcb 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt @@ -7,82 +7,52 @@ import org.pytorch.executorch.Tensor class ArrayUtils { companion object { - fun createByteArray(input: ReadableArray): ByteArray { - val byteArray = ByteArray(input.size()) - for (i in 0 until input.size()) { - byteArray[i] = input.getInt(i).toByte() - } - return byteArray + private inline fun createTypedArrayFromReadableArray(input: ReadableArray, transform: (ReadableArray, Int) -> T): Array { + return Array(input.size()) { index -> transform(input, index) } } + fun createByteArray(input: ReadableArray): ByteArray { + return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toByte() }.toByteArray() + } fun createIntArray(input: ReadableArray): IntArray { - val intArray = IntArray(input.size()) - for (i in 0 until input.size()) { - intArray[i] = input.getInt(i) - } - return intArray + return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index) }.toIntArray() } fun createFloatArray(input: ReadableArray): FloatArray { - val floatArray = FloatArray(input.size()) - for (i in 0 until input.size()) { - floatArray[i] = input.getDouble(i).toFloat() - } - return floatArray + return createTypedArrayFromReadableArray(input) { array, index -> array.getDouble(index).toFloat() }.toFloatArray() } fun createLongArray(input: ReadableArray): LongArray { - val longArray = LongArray(input.size()) - for (i in 0 until input.size()) { - longArray[i] = input.getInt(i).toLong() - } - return longArray + return createTypedArrayFromReadableArray(input) { array, index -> array.getInt(index).toLong() }.toLongArray() } fun createDoubleArray(input: ReadableArray): DoubleArray { - val doubleArray = DoubleArray(input.size()) - for (i in 0 until input.size()) { - doubleArray[i] = input.getDouble(i) - } - return doubleArray + return createTypedArrayFromReadableArray(input) { array, index -> array.getDouble(index) }.toDoubleArray() } - - fun createReadableArray(result: Tensor): ReadableArray { + fun createReadableArrayFromTensor(result: Tensor): ReadableArray { val resultArray = Arguments.createArray() + when (result.dtype()) { DType.UINT8 -> { - val byteArray = result.dataAsByteArray - for (i in byteArray) { - resultArray.pushInt(i.toInt()) - } + result.dataAsByteArray.forEach { resultArray.pushInt(it.toInt()) } } DType.INT32 -> { - val intArray = result.dataAsIntArray - for (i in intArray) { - resultArray.pushInt(i) - } + result.dataAsIntArray.forEach { resultArray.pushInt(it) } } DType.FLOAT -> { - val longArray = result.dataAsFloatArray - for (i in longArray) { - resultArray.pushDouble(i.toDouble()) - } + result.dataAsFloatArray.forEach { resultArray.pushDouble(it.toDouble()) } } DType.DOUBLE -> { - val floatArray = result.dataAsDoubleArray - for (i in floatArray) { - resultArray.pushDouble(i) - } + result.dataAsDoubleArray.forEach { resultArray.pushDouble(it) } } DType.INT64 -> { - val doubleArray = result.dataAsLongArray - for (i in doubleArray) { - resultArray.pushLong(i) - } + // TODO: Do something to handle or deprecate long dtype + // https://github.com/facebook/react-native/issues/12506 + result.dataAsLongArray.forEach { resultArray.pushInt(it.toInt()) } } else -> { diff --git a/android/src/main/java/com/swmansion/rnexecutorch/utils/TensorUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/TensorUtils.kt index 2cf217a89d..212b415eb4 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/utils/TensorUtils.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/utils/TensorUtils.kt @@ -12,27 +12,23 @@ class TensorUtils { fun getExecutorchInput(input: ReadableArray, shape: LongArray, type: Int): EValue { try { when (type) { - 0 -> { + 1 -> { val inputTensor = Tensor.fromBlob(ArrayUtils.createByteArray(input), shape) return EValue.from(inputTensor) } - - 1 -> { + 3 -> { val inputTensor = Tensor.fromBlob(ArrayUtils.createIntArray(input), shape) return EValue.from(inputTensor) } - - 2 -> { + 4 -> { val inputTensor = Tensor.fromBlob(ArrayUtils.createLongArray(input), shape) return EValue.from(inputTensor) } - - 3 -> { + 6 -> { val inputTensor = Tensor.fromBlob(ArrayUtils.createFloatArray(input), shape) return EValue.from(inputTensor) } - - 4 -> { + 7 -> { val inputTensor = Tensor.fromBlob(ArrayUtils.createDoubleArray(input), shape) return EValue.from(inputTensor) } diff --git a/ios/RnExecutorch/ETModule.mm b/ios/RnExecutorch/ETModule.mm index abd202796b..7f368d4364 100644 --- a/ios/RnExecutorch/ETModule.mm +++ b/ios/RnExecutorch/ETModule.mm @@ -1,5 +1,6 @@ #import "ETModule.h" #import +#include #import #include @@ -36,20 +37,23 @@ - (void)loadModule:(NSString *)modelSource resolve(result); } -- (void)forward:(NSArray *)input - shape:(NSArray *)shape - inputType:(double)inputType +- (void)forward:(NSArray *)inputs + shapes:(NSArray *)shapes + inputTypes:(NSArray *)inputTypes resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject { @try { - NSArray *result = [module forward:input - shape:shape - inputType:[NSNumber numberWithInt:inputType]]; + NSArray *result = [module forward:inputs + shapes:shapes + inputTypes:inputTypes]; resolve(result); } @catch (NSException *exception) { - NSLog(@"An exception occurred: %@, %@", exception.name, exception.reason); - reject(@"result_error", [NSString stringWithFormat:@"%@", exception.reason], - nil); + NSLog(@"An exception occurred in forward: %@, %@", exception.name, + exception.reason); + reject( + @"forward_error", + [NSString stringWithFormat:@"An error occurred: %@", exception.reason], + nil); } } diff --git a/ios/RnExecutorch/models/BaseModel.h b/ios/RnExecutorch/models/BaseModel.h index 147215b687..b06e1cfcf6 100644 --- a/ios/RnExecutorch/models/BaseModel.h +++ b/ios/RnExecutorch/models/BaseModel.h @@ -8,6 +8,11 @@ } - (NSArray *)forward:(NSArray *)input; + +- (NSArray *)forward:(NSArray *)inputs + shapes:(NSArray *)shapes + inputTypes:(NSArray *)inputTypes; + - (void)loadModel:(NSURL *)modelURL completion:(void (^)(BOOL success, NSNumber *code))completion; diff --git a/ios/RnExecutorch/models/BaseModel.mm b/ios/RnExecutorch/models/BaseModel.mm index 587eddf68b..26cebd6e6a 100644 --- a/ios/RnExecutorch/models/BaseModel.mm +++ b/ios/RnExecutorch/models/BaseModel.mm @@ -4,16 +4,29 @@ @implementation BaseModel - (NSArray *)forward:(NSArray *)input { - NSArray *result = [module forward:input - shape:[module getInputShape:@0] - inputType:[module getInputType:@0]]; + NSMutableArray *shapes = [NSMutableArray new]; + NSMutableArray *inputTypes = [NSMutableArray new]; + NSNumber *numberOfInputs = [module getNumberOfInputs]; + + for (NSUInteger i = 0; i < [numberOfInputs intValue]; i++) { + [shapes addObject:[module getInputShape:[NSNumber numberWithInt:i]]]; + [inputTypes addObject:[module getInputType:[NSNumber numberWithInt:i]]]; + } + + NSArray *result = [module forward:@[input] shapes:shapes inputTypes:inputTypes]; + return result; +} + +- (NSArray *)forward:(NSArray *)inputs + shapes:(NSArray *)shapes + inputTypes:(NSArray *)inputTypes { + NSArray *result = [module forward:inputs shapes:shapes inputTypes:inputTypes]; return result; } - (void)loadModel:(NSURL *)modelURL completion:(void (^)(BOOL success, NSNumber *code))completion { module = [[ETModel alloc] init]; - NSNumber *result = [self->module loadModel:modelURL.path]; if ([result intValue] != 0) { completion(NO, result); diff --git a/lefthook.yml b/lefthook.yml index 335348649a..30305336c1 100644 --- a/lefthook.yml +++ b/lefthook.yml @@ -6,4 +6,4 @@ pre-commit: run: npx eslint {staged_files} types: glob: '*.{js,ts, jsx, tsx}' - run: npx tsc + run: npx tsc --noEmit diff --git a/src/hooks/general/useExecutorchModule.ts b/src/hooks/general/useExecutorchModule.ts index 99e2a2fcf8..6e174fbe35 100644 --- a/src/hooks/general/useExecutorchModule.ts +++ b/src/hooks/general/useExecutorchModule.ts @@ -15,7 +15,10 @@ export const useExecutorchModule = ({ isReady: boolean; isGenerating: boolean; downloadProgress: number; - forward: (input: ETInput, shape: number[]) => Promise; + forward: ( + input: ETInput | ETInput[], + shape: number[] | number[][] + ) => Promise; loadMethod: (methodName: string) => Promise; loadForward: () => Promise; } => { diff --git a/src/hooks/useModule.ts b/src/hooks/useModule.ts index bba9638b70..5a85522eb0 100644 --- a/src/hooks/useModule.ts +++ b/src/hooks/useModule.ts @@ -1,7 +1,17 @@ import { useEffect, useState } from 'react'; import { fetchResource } from '../utils/fetchResource'; import { ETError, getError } from '../Error'; -import { ETInput, Module, getTypeIdentifier } from '../types/common'; +import { ETInput, Module } from '../types/common'; +import { _ETModule } from '../native/RnExecutorchModules'; + +export const getTypeIdentifier = (input: ETInput): number => { + if (input instanceof Int8Array) return 1; + if (input instanceof Int32Array) return 3; + if (input instanceof BigInt64Array) return 4; + if (input instanceof Float32Array) return 6; + if (input instanceof Float64Array) return 7; + return -1; +}; interface Props { modelSource: string | number; @@ -13,7 +23,10 @@ interface _Module { isReady: boolean; isGenerating: boolean; downloadProgress: number; - forwardETInput: (input: ETInput, shape: number[]) => Promise; + forwardETInput: ( + input: ETInput[] | ETInput, + shape: number[][] | number[] + ) => ReturnType<_ETModule['forward']>; forwardImage: (input: string) => Promise; } @@ -59,7 +72,10 @@ export const useModule = ({ modelSource, module }: Props): _Module => { } }; - const forwardETInput = async (input: ETInput, shape: number[]) => { + const forwardETInput = async ( + input: ETInput[] | ETInput, + shape: number[][] | number[] + ) => { if (!isReady) { throw new Error(getError(ETError.ModuleNotLoaded)); } @@ -67,15 +83,36 @@ export const useModule = ({ modelSource, module }: Props): _Module => { throw new Error(getError(ETError.ModelGenerating)); } - const inputType = getTypeIdentifier(input); - if (inputType === -1) { - throw new Error(getError(ETError.InvalidArgument)); + // Since the native module expects an array of inputs and an array of shapes, + // if the user provides a single ETInput, we want to "unsqueeze" the array so + // the data is properly processed on the native side + if (!Array.isArray(input)) { + input = [input]; + } + + if (!Array.isArray(shape[0])) { + shape = [shape] as number[][]; + } + + let inputTypeIdentifiers: any[] = []; + let modelInputs: any[] = []; + + for (let idx = 0; idx < input.length; idx++) { + let currentInputTypeIdentifier = getTypeIdentifier(input[idx] as ETInput); + if (currentInputTypeIdentifier === -1) { + throw new Error(getError(ETError.InvalidArgument)); + } + inputTypeIdentifiers.push(currentInputTypeIdentifier); + modelInputs.push([...(input[idx] as ETInput)]); } try { - const numberArray = [...input]; setIsGenerating(true); - const output = await module.forward(numberArray, shape, inputType); + const output = await module.forward( + modelInputs, + shape, + inputTypeIdentifiers + ); setIsGenerating(false); return output; } catch (e) { diff --git a/src/modules/general/ExecutorchModule.ts b/src/modules/general/ExecutorchModule.ts index 5d5990c407..9ae1e27a7d 100644 --- a/src/modules/general/ExecutorchModule.ts +++ b/src/modules/general/ExecutorchModule.ts @@ -1,20 +1,35 @@ import { BaseModule } from '../BaseModule'; import { ETError, getError } from '../../Error'; import { _ETModule } from '../../native/RnExecutorchModules'; -import { ETInput, getTypeIdentifier } from '../../types/common'; +import { ETInput } from '../../types/common'; +import { getTypeIdentifier } from '../../hooks/useModule'; export class ExecutorchModule extends BaseModule { static module = new _ETModule(); - static async forward(input: ETInput, shape: number[]) { - const inputType = getTypeIdentifier(input); - if (inputType === -1) { - throw new Error(getError(ETError.InvalidArgument)); + static async forward(input: ETInput[] | ETInput, shape: number[][]) { + if (!Array.isArray(input)) { + input = [input]; + } + + let inputTypeIdentifiers = []; + let modelInputs = []; + + for (let idx = 0; idx < input.length; idx++) { + let currentInputTypeIdentifier = getTypeIdentifier(input[idx] as ETInput); + if (currentInputTypeIdentifier === -1) { + throw new Error(getError(ETError.InvalidArgument)); + } + inputTypeIdentifiers.push(currentInputTypeIdentifier); + modelInputs.push([...(input[idx] as unknown as number[])]); } try { - const numberArray = [...input] as number[]; - return await this.module.forward(numberArray, shape, inputType); + return await this.module.forward( + modelInputs, + shape, + inputTypeIdentifiers + ); } catch (e) { throw new Error(getError(e)); } diff --git a/src/native/NativeETModule.ts b/src/native/NativeETModule.ts index d04da1abf7..feb9850938 100644 --- a/src/native/NativeETModule.ts +++ b/src/native/NativeETModule.ts @@ -5,11 +5,10 @@ export interface Spec extends TurboModule { loadModule(modelSource: string): Promise; forward( - input: number[], - shape: number[], - inputType: number - ): Promise; - + inputs: number[][], + shapes: number[][], + inputTypes: number[] + ): Promise; loadMethod(methodName: string): Promise; } diff --git a/src/native/RnExecutorchModules.ts b/src/native/RnExecutorchModules.ts index e898216a6d..7918281d0d 100644 --- a/src/native/RnExecutorchModules.ts +++ b/src/native/RnExecutorchModules.ts @@ -112,11 +112,11 @@ class _ClassificationModule { class _ETModule { async forward( - input: number[], - shape: number[], - inputType: number + inputs: number[][], + shapes: number[][], + inputTypes: number[] ): ReturnType { - return await ETModule.forward(input, shape, inputType); + return await ETModule.forward(inputs, shapes, inputTypes); } async loadModule( modelSource: string diff --git a/src/types/common.ts b/src/types/common.ts index ec0daa2c02..b1f3602824 100644 --- a/src/types/common.ts +++ b/src/types/common.ts @@ -3,6 +3,7 @@ import { _StyleTransferModule, _ObjectDetectionModule, ETModule, + _ETModule, } from '../native/RnExecutorchModules'; export type ResourceSource = string | number; @@ -26,15 +27,17 @@ export type ETInput = | Float32Array | Float64Array; -export const getTypeIdentifier = (arr: ETInput): number => { - if (arr instanceof Int8Array) return 0; - if (arr instanceof Int32Array) return 1; - if (arr instanceof BigInt64Array) return 2; - if (arr instanceof Float32Array) return 3; - if (arr instanceof Float64Array) return 4; - - return -1; -}; +export interface ExecutorchModule { + error: string | null; + isReady: boolean; + isGenerating: boolean; + forward: ( + inputs: ETInput[] | ETInput, + shapes: number[][] + ) => ReturnType<_ETModule['forward']>; + loadMethod: (methodName: string) => Promise; + loadForward: () => Promise; +} export type Module = | _ClassificationModule diff --git a/third-party/ios/ExecutorchLib/ExecutorchLib.xcodeproj/project.pbxproj b/third-party/ios/ExecutorchLib/ExecutorchLib.xcodeproj/project.pbxproj index 73b86b0bec..96a4e2dbe1 100644 --- a/third-party/ios/ExecutorchLib/ExecutorchLib.xcodeproj/project.pbxproj +++ b/third-party/ios/ExecutorchLib/ExecutorchLib.xcodeproj/project.pbxproj @@ -90,7 +90,6 @@ 55EA2C562CB90E7D004315B3 /* Accelerate.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Accelerate.framework; path = System/Library/Frameworks/Accelerate.framework; sourceTree = SDKROOT; }; 55EA2C582CB90E80004315B3 /* CoreML.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreML.framework; path = System/Library/Frameworks/CoreML.framework; sourceTree = SDKROOT; }; 55EA2C5A2CB90E85004315B3 /* libsqlite3.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = libsqlite3.tbd; path = usr/lib/libsqlite3.tbd; sourceTree = SDKROOT; }; - A84198832D02DF29006D4D5E /* InputType.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = InputType.h; sourceTree = ""; }; A851C4042CF9F1B600424E93 /* Utils.hpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.h; path = Utils.hpp; sourceTree = ""; }; /* End PBXFileReference section */ @@ -146,7 +145,6 @@ 55EA2C322CB90C7A004315B3 /* sampler */, 55EA2C3E2CB90C7A004315B3 /* tokenizer */, A851C4042CF9F1B600424E93 /* Utils.hpp */, - A84198832D02DF29006D4D5E /* InputType.h */, ); path = ExecutorchLib; sourceTree = ""; diff --git a/third-party/ios/ExecutorchLib/ExecutorchLib/Exported/ETModel.h b/third-party/ios/ExecutorchLib/ExecutorchLib/Exported/ETModel.h index 8ac3b477e2..9c78377e25 100644 --- a/third-party/ios/ExecutorchLib/ExecutorchLib/Exported/ETModel.h +++ b/third-party/ios/ExecutorchLib/ExecutorchLib/Exported/ETModel.h @@ -8,9 +8,9 @@ - (NSNumber *)loadModel:(NSString *)filePath; - (NSNumber *)loadMethod:(NSString *)methodName; - (NSNumber *)loadForward; -- (NSArray *)forward:(NSArray *)input - shape:(NSArray *)shape - inputType:(NSNumber *)inputType; +- (NSArray *)forward:(NSArray *)inputs + shapes:(NSArray *)shapes + inputTypes: (NSArray *)inputTypes; - (NSNumber *)getNumberOfInputs; - (NSNumber *)getInputType:(NSNumber *)index; - (NSArray *)getInputShape:(NSNumber *)index; diff --git a/third-party/ios/ExecutorchLib/ExecutorchLib/Exported/ETModel.mm b/third-party/ios/ExecutorchLib/ExecutorchLib/Exported/ETModel.mm index d975b184b8..1f4c0d1300 100644 --- a/third-party/ios/ExecutorchLib/ExecutorchLib/Exported/ETModel.mm +++ b/third-party/ios/ExecutorchLib/ExecutorchLib/Exported/ETModel.mm @@ -1,6 +1,5 @@ #import "ETModel.h" #include "Utils.hpp" -#include "InputType.h" #include #include #include @@ -34,58 +33,65 @@ - (NSNumber *)getNumberOfInputs { const auto method_meta = _model->method_meta("forward"); if (!method_meta.ok()) { @throw [NSException - exceptionWithName:@"get_number_of_inputs_error" - reason:[NSString stringWithFormat:@"%ld", (long)method_meta.error()] - userInfo:nil]; + exceptionWithName:@"get_number_of_inputs_error" + reason:[NSString stringWithFormat:@"%ld", + (long)method_meta.error()] + userInfo:nil]; } - + return @(method_meta->num_inputs()); } - (NSNumber *)getInputType:(NSNumber *)index { const auto method_meta = _model->method_meta("forward"); - if(!method_meta.ok()){ + if (!method_meta.ok()) { @throw [NSException - exceptionWithName:@"get_input_type_error" - reason:[NSString stringWithFormat:@"%ld", (long)method_meta.error()] - userInfo:nil]; + exceptionWithName:@"get_input_type_error" + reason:[NSString stringWithFormat:@"%ld", + (long)method_meta.error()] + userInfo:nil]; } - - const auto input_meta = method_meta->input_tensor_meta([index unsignedLongValue]); - if(!input_meta.ok()){ + + const auto input_meta = + method_meta->input_tensor_meta([index unsignedLongValue]); + if (!input_meta.ok()) { @throw [NSException - exceptionWithName:@"get_input_type_error" - reason:[NSString stringWithFormat:@"%ld", (long)input_meta.error()] - userInfo:nil]; + exceptionWithName:@"get_input_type_error" + reason:[NSString + stringWithFormat:@"%ld", (long)input_meta.error()] + userInfo:nil]; } - - return [self getTypeAsNumber:input_meta->scalar_type()]; + + return scalarTypeToNSNumber(input_meta->scalar_type()); }; - (NSArray *)getInputShape:(NSNumber *)index { const auto method_meta = _model->method_meta("forward"); - if(!method_meta.ok()){ + if (!method_meta.ok()) { @throw [NSException - exceptionWithName:@"get_input_shape_error" - reason:[NSString stringWithFormat:@"%ld", (long)method_meta.error()] - userInfo:nil]; + exceptionWithName:@"get_input_shape_error" + reason:[NSString stringWithFormat:@"%ld", + (long)method_meta.error()] + userInfo:nil]; } - - const auto input_meta = method_meta->input_tensor_meta([index unsignedLongValue]); - if(!input_meta.ok()){ + + const auto input_meta = + method_meta->input_tensor_meta([index unsignedLongValue]); + if (!input_meta.ok()) { @throw [NSException - exceptionWithName:@"get_input_shape_error" - reason:[NSString stringWithFormat:@"%ld", (long)input_meta.error()] - userInfo:nil]; + exceptionWithName:@"get_input_shape_error" + reason:[NSString + stringWithFormat:@"%ld", (long)input_meta.error()] + userInfo:nil]; } - + const auto shape = input_meta->sizes(); NSMutableArray *nsShape = [[NSMutableArray alloc] init]; - - for(int i = 0; i < shape.size(); i++) { + + for (int i = 0; i < shape.size(); i++) { [nsShape addObject:@(shape[i])]; } - + return [nsShape copy]; }; @@ -93,120 +99,139 @@ - (NSNumber *)getNumberOfOutputs { const auto method_meta = _model->method_meta("forward"); if (!method_meta.ok()) { @throw [NSException - exceptionWithName:@"get_number_of_outputs_error" - reason:[NSString stringWithFormat:@"%ld", (long)method_meta.error()] - userInfo:nil]; + exceptionWithName:@"get_number_of_outputs_error" + reason:[NSString stringWithFormat:@"%ld", + (long)method_meta.error()] + userInfo:nil]; } - + return @(method_meta->num_outputs()); } - (NSNumber *)getOutputType:(NSNumber *)index { const auto method_meta = _model->method_meta("forward"); - if(!method_meta.ok()){ + if (!method_meta.ok()) { @throw [NSException - exceptionWithName:@"get_output_type_error" - reason:[NSString stringWithFormat:@"%ld", (long)method_meta.error()] - userInfo:nil]; + exceptionWithName:@"get_output_type_error" + reason:[NSString stringWithFormat:@"%ld", + (long)method_meta.error()] + userInfo:nil]; } - - const auto output_meta = method_meta->output_tensor_meta([index unsignedLongValue]); - if(!output_meta.ok()){ + + const auto output_meta = + method_meta->output_tensor_meta([index unsignedLongValue]); + if (!output_meta.ok()) { @throw [NSException - exceptionWithName:@"get_output_type_error" - reason:[NSString stringWithFormat:@"%ld", (long)output_meta.error()] - userInfo:nil]; + exceptionWithName:@"get_output_type_error" + reason:[NSString stringWithFormat:@"%ld", + (long)output_meta.error()] + userInfo:nil]; } - - return [self getTypeAsNumber:output_meta->scalar_type()]; + + return scalarTypeToNSNumber(output_meta->scalar_type()); }; - (NSArray *)getOutputShape:(NSNumber *)index { const auto method_meta = _model->method_meta("forward"); - if(!method_meta.ok()){ + if (!method_meta.ok()) { @throw [NSException - exceptionWithName:@"get_output_shape_error" - reason:[NSString stringWithFormat:@"%ld", (long)method_meta.error()] - userInfo:nil]; + exceptionWithName:@"get_output_shape_error" + reason:[NSString stringWithFormat:@"%ld", + (long)method_meta.error()] + userInfo:nil]; } - - const auto output_meta = method_meta->output_tensor_meta([index unsignedLongValue]); - if(!output_meta.ok()){ + + const auto output_meta = + method_meta->output_tensor_meta([index unsignedLongValue]); + if (!output_meta.ok()) { @throw [NSException - exceptionWithName:@"get_output_shape_error" - reason:[NSString stringWithFormat:@"%ld", (long)output_meta.error()] - userInfo:nil]; + exceptionWithName:@"get_output_shape_error" + reason:[NSString stringWithFormat:@"%ld", + (long)output_meta.error()] + userInfo:nil]; } - + const auto shape = output_meta->sizes(); NSMutableArray *nsShape = [[NSMutableArray alloc] init]; - - for(int i = 0; i < shape.size(); i++) { + + for (int i = 0; i < shape.size(); i++) { [nsShape addObject:@(shape[i])]; } - + return [nsShape copy]; }; -- (NSNumber *) getTypeAsNumber:(ScalarType)scalarType { - switch(scalarType) { - case ScalarType::Byte: return @(InputTypeInt8); - case ScalarType::Int: return @(InputTypeInt32); - case ScalarType::Long: return @(InputTypeInt64); - case ScalarType::Float: return @(InputTypeFloat32); - case ScalarType::Double: return @(InputTypeFloat64); - - default: - return @-1; - } -} - -- (NSArray *)forward:(NSArray *)input - shape:(NSArray *)shape - inputType:(NSNumber *)inputType { - int inputTypeIntValue = [inputType intValue]; - std::vector shapes = NSArrayToIntVector(shape); - @try { - switch (inputTypeIntValue) { - case InputTypeInt8: { - std::vector> output = - runForwardFromNSArray(input, shapes, _model); - return arrayToNSArray(output); - } - case InputTypeInt32: { - std::vector> output = - runForwardFromNSArray(input, shapes, _model); - return arrayToNSArray(output); - } - case InputTypeInt64: { - std::vector> output = - runForwardFromNSArray(input, shapes, _model); - return arrayToNSArray(output); - } - case InputTypeFloat32: { - std::vector> output = - runForwardFromNSArray(input, shapes, _model); - return arrayToNSArray(output); - } - case InputTypeFloat64: { - std::vector> output = - runForwardFromNSArray(input, shapes, _model); - return arrayToNSArray(output); - } +/** + * @brief Processes inputs through the forward pass of the model. + * + * This method takes input tensors, their corresponding shapes, and types, + * and performs a forward pass using _model. It supports both + * single and multiple inputs. + * + * @param inputs NSArray* of inputs where each element is an NSArray + * representing the data for a tensor. + * @param shapes An array of shapes corresponding to the input tensors. + * Each element is an NSArray of integers defining the dimensions. + * @param inputTypes An array of NSNumber objects representing the ScalarType of + * the input tensors + * + * @return An NSArray containing the results of the forward pass. Each element + * represents the output of the corresponding input. + * + * @throws NSException Throws an exception with name "forward_error" if + * an error occurs during input processing or model + * execution. + * + * @warning Ensure that the inputs, shapes, and inputTypes arrays have the + * same number of elements. Mismatched sizes can lead to runtime + * errors. + **/ +- (NSArray *)forward:(NSArray *)inputs + shapes:(NSArray *)shapes + inputTypes:(NSArray *)inputTypes { + std::vector inputTensors; + std::vector inputTensorPtrs; + + for (NSUInteger i = 0; i < [inputTypes count]; i++) { + NSArray *inputShapeNSArray = [shapes objectAtIndex:i]; + + std::vector inputShape = NSArrayToIntVector(inputShapeNSArray); + int inputType = [[inputTypes objectAtIndex:i] intValue]; + + NSArray *input = [inputs objectAtIndex:i]; + + TensorPtr currentTensor = NSArrayToTensorPtr(input, inputShape, inputType); + if (!currentTensor) { + throw [NSException + exceptionWithName:@"forward_error" + reason:[NSString stringWithFormat:@"%d", Error::InvalidArgument] + userInfo:nil]; } - } @catch (NSException *exception) { - NSInteger originalCode = [exception.reason integerValue]; - @throw [NSException - exceptionWithName:@"forward_error" - reason:[NSString stringWithFormat:@"%ld", (long)originalCode] - userInfo:nil]; + + // Since pushing back to inputTensors would cast to EValue (forward accepts a vector of EValues) + // We also push back to inputTensorPtrs to keep the underlying tensor alive. + // inputTensorPtrs vector retains shared ownership to prevent premature destruction + inputTensors.push_back(*currentTensor); + inputTensorPtrs.push_back(currentTensor); + } + + Result result = _model->forward(inputTensors); + + if (!result.ok()) { + throw [NSException + exceptionWithName:@"forward_error" + reason:[NSString stringWithFormat:@"%d", result.error()] + userInfo:nil]; } - // throwing an RN-ET exception - @ - throw [NSException exceptionWithName:@"forward_error" - reason:[NSString stringWithFormat:@"%d", - 0x65] // 101 - userInfo:nil]; + + NSMutableArray *output = [NSMutableArray new]; + for (int i = 0; i < result->size(); i++) { + auto currentResultTensor = result->at(i).toTensor(); + NSArray *currentOutput = arrayToNsArray(currentResultTensor.const_data_ptr(), currentResultTensor.numel(), currentResultTensor.scalar_type()); + [output addObject:currentOutput]; + } + return output; + } @end diff --git a/third-party/ios/ExecutorchLib/ExecutorchLib/Utils.hpp b/third-party/ios/ExecutorchLib/ExecutorchLib/Utils.hpp index 095c68baab..7c3fd7cb17 100644 --- a/third-party/ios/ExecutorchLib/ExecutorchLib/Utils.hpp +++ b/third-party/ios/ExecutorchLib/ExecutorchLib/Utils.hpp @@ -5,9 +5,9 @@ #include #include #include +#include #include #include -#include #ifdef __OBJC__ #import @@ -17,17 +17,21 @@ using namespace ::executorch::extension; using namespace ::torch::executor; template T getValueFromNSNumber(NSNumber *number) { - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { return static_cast([number charValue]); // `charValue` for 8-bit integers } else if constexpr (std::is_same::value) { return static_cast([number intValue]); // `intValue` for 32-bit integers - } else if constexpr (std::is_same::value) { + } else if constexpr (std::is_same::value || + std::is_same::value) { return static_cast( - [number longLongValue]); // `longLongValue` for 64-bit integers + [number longLongValue]); // Use `longLongValue` for 64-bit integers } else if constexpr (std::is_same::value) { return static_cast([number floatValue]); } else if constexpr (std::is_same::value) { return static_cast([number doubleValue]); + } else { + static_assert(std::is_same::value, + "Unsupported type for getValueFromNSNumber"); } } @@ -48,6 +52,99 @@ std::unique_ptr NSArrayToTypedArray(NSArray *nsArray) { return typedArray; } +std::function getDeleterForScalarType(ScalarType scalarType) { + switch (scalarType) { + case ScalarType::Char: + return [](void *ptr) { delete[] static_cast(ptr); }; + case ScalarType::Int: + return [](void *ptr) { delete[] static_cast(ptr); }; + case ScalarType::Long: + return [](void *ptr) { delete[] static_cast(ptr); }; + case ScalarType::Float: + return [](void *ptr) { delete[] static_cast(ptr); }; + case ScalarType::Double: + return [](void *ptr) { delete[] static_cast(ptr); }; + default: + throw std::invalid_argument( + "Unsupported ScalarType passed to getDeleterForScalarType!"); + } +} + +ScalarType intValueToScalarType(int intValue) { + // Check if the intValue is within the valid range of ScalarType + if (intValue < 0 || intValue >= static_cast(ScalarType::NumOptions)) { + throw std::out_of_range("Invalid ScalarType integer value: " + + std::to_string(intValue)); + } + return static_cast(intValue); +} + +NSNumber *scalarTypeToNSNumber(ScalarType scalarType) { + return @(static_cast(scalarType)); +} + +NSArray* flattenArray(NSArray *array) { + NSMutableArray *flatArray = [NSMutableArray array]; + + for (id element in array) { + if ([element isKindOfClass:[NSArray class]]) { + NSArray *nestedArray = flattenArray(element); + [flatArray addObjectsFromArray:nestedArray]; + } else { + [flatArray addObject:element]; + } + } + + return [flatArray copy]; +} + +void *NSArrayToVoidArray(NSArray *nsArray, ScalarType inputScalarType, + size_t &outSize) { + // This function assumes that the passed array may not be flattened, + // that's why we flatten it here + NSArray *flattenedArray = flattenArray(nsArray); + outSize = [flattenedArray count]; + + switch (inputScalarType) { + case ScalarType::Char: { + auto typedArray = NSArrayToTypedArray(flattenedArray); + return typedArray.release(); + } + case ScalarType::Long: { + auto typedArray = NSArrayToTypedArray(flattenedArray); + return typedArray.release(); + } + + case ScalarType::Int: { + auto typedArray = NSArrayToTypedArray(flattenedArray); + return typedArray.release(); + } + case ScalarType::Float: { + auto typedArray = NSArrayToTypedArray(flattenedArray); + return typedArray.release(); + } + case ScalarType::Double: { + auto typedArray = NSArrayToTypedArray(flattenedArray); + return typedArray.release(); + } + default: + throw std::invalid_argument( + "Unsupported ScalarType passed to NSArrayToVoidArray!"); + } +} + +TensorPtr NSArrayToTensorPtr(NSArray *nsArray, std::vector shape, + int inputType) { + ScalarType inputScalarType = intValueToScalarType(inputType); + size_t arraySize; + void *data = NSArrayToVoidArray(nsArray, inputScalarType, arraySize); + std::function deleter = + getDeleterForScalarType(inputScalarType); + auto tensor = make_tensor_ptr(shape, data, inputScalarType, TensorShapeDynamism::DYNAMIC_UNBOUND, deleter); + + return tensor; +} + template NSArray *arrayToNSArray(const void *array, ssize_t numel) { const T *typedArray = static_cast(array); @@ -62,15 +159,44 @@ NSArray *arrayToNSArray(const void *array, ssize_t numel) { template NSArray *arrayToNSArray(const std::vector> &dataPtrVec) { - NSMutableArray *nsArray = [NSMutableArray array]; - for (const auto &span : dataPtrVec) { - NSMutableArray *innerArray = [NSMutableArray arrayWithCapacity:span.size()]; - for(auto x : span) { - [innerArray addObject:@(x)]; - } - [nsArray addObject:[innerArray copy]]; + NSMutableArray *nsArray = [NSMutableArray array]; + for (const auto &span : dataPtrVec) { + NSMutableArray *innerArray = [NSMutableArray arrayWithCapacity:span.size()]; + for (auto x : span) { + [innerArray addObject:@(x)]; } - return [nsArray copy]; + [nsArray addObject:[innerArray copy]]; + } + return [nsArray copy]; +} + +NSArray *arrayToNsArray(const void *dataPtr, size_t numel, ScalarType scalarType) { + switch (scalarType) { + case ScalarType::Char: { + NSArray *outputArray = arrayToNSArray(dataPtr, numel); + return outputArray; + } + case ScalarType::Long: { + NSArray *outputArray = arrayToNSArray(dataPtr, numel); + return outputArray; + } + + case ScalarType::Int: { + NSArray *outputArray = arrayToNSArray(dataPtr, numel); + return outputArray; + } + case ScalarType::Float: { + NSArray *outputArray = arrayToNSArray(dataPtr, numel); + return outputArray; + } + case ScalarType::Double: { + NSArray *outputArray = arrayToNSArray(dataPtr, numel); + return outputArray; + } + default: + throw std::invalid_argument( + "Unsupported ScalarType passed to arrayToNSArray!"); + } } std::vector NSArrayToIntVector(NSArray *inputArray) { @@ -86,30 +212,4 @@ std::vector NSArrayToIntVector(NSArray *inputArray) { return output; } -template -std::vector> -runForwardFromNSArray(NSArray *inputArray, std::vector shapes, - std::unique_ptr &model) { - std::unique_ptr inputPtr = NSArrayToTypedArray(inputArray); - - TensorPtr inputTensor = from_blob(inputPtr.get(), shapes); - Result result = model->forward(inputTensor); - - if (result.ok()) { - std::vector> outputVec; - - for (const auto ¤tResult : *result) { - Tensor currentTensor = currentResult.toTensor(); - std::span currentSpan(currentTensor.const_data_ptr(), currentTensor.numel()); - outputVec.push_back(std::move(currentSpan)); - } - return outputVec; - } - - @throw [NSException - exceptionWithName:@"forward_error" - reason:[NSString stringWithFormat:@"%d", (int)result.error()] - userInfo:nil]; -} - #endif // Utils_hpp