From 851b396dc1d6df22c374d24e2977c9c12ed42051 Mon Sep 17 00:00:00 2001 From: jakmro Date: Tue, 14 Jan 2025 14:32:15 +0100 Subject: [PATCH 01/15] Move computer vision hooks to separate directory --- .../computer_vision/useClassification.ts} | 4 +-- .../computer_vision/useObjectDetection.ts} | 6 ++-- .../computer_vision/useStyleTransfer.ts} | 4 +-- src/index.tsx | 15 ++++++-- src/modules/computer_vision/BaseModule.ts | 34 +++++++++++++++++++ .../computer_vision/ClassificationModule.ts | 8 +++++ .../computer_vision/ObjectDetectionModule.ts | 8 +++++ .../computer_vision/StyleTransferModule.ts | 8 +++++ 8 files changed, 77 insertions(+), 10 deletions(-) rename src/{models/Classification.ts => hooks/computer_vision/useClassification.ts} (83%) rename src/{models/ObjectDetection.ts => hooks/computer_vision/useObjectDetection.ts} (76%) rename src/{models/StyleTransfer.ts => hooks/computer_vision/useStyleTransfer.ts} (83%) create mode 100644 src/modules/computer_vision/BaseModule.ts create mode 100644 src/modules/computer_vision/ClassificationModule.ts create mode 100644 src/modules/computer_vision/ObjectDetectionModule.ts create mode 100644 src/modules/computer_vision/StyleTransferModule.ts diff --git a/src/models/Classification.ts b/src/hooks/computer_vision/useClassification.ts similarity index 83% rename from src/models/Classification.ts rename to src/hooks/computer_vision/useClassification.ts index 6bd8dfb0ae..be10938194 100644 --- a/src/models/Classification.ts +++ b/src/hooks/computer_vision/useClassification.ts @@ -1,6 +1,6 @@ import { useState } from 'react'; -import { _ClassificationModule } from '../native/RnExecutorchModules'; -import { useModule } from '../useModule'; +import { _ClassificationModule } from '../../native/RnExecutorchModules'; +import { useModule } from '../../useModule'; interface Props { modelSource: string | number; diff --git a/src/models/ObjectDetection.ts b/src/hooks/computer_vision/useObjectDetection.ts similarity index 76% rename from 
src/models/ObjectDetection.ts rename to src/hooks/computer_vision/useObjectDetection.ts index fda2fd0188..9a5ccd1243 100644 --- a/src/models/ObjectDetection.ts +++ b/src/hooks/computer_vision/useObjectDetection.ts @@ -1,7 +1,7 @@ import { useState } from 'react'; -import { _ObjectDetectionModule } from '../native/RnExecutorchModules'; -import { useModule } from '../useModule'; -import { Detection } from '../types/object_detection'; +import { _ObjectDetectionModule } from '../../native/RnExecutorchModules'; +import { useModule } from '../../useModule'; +import { Detection } from '../../types/object_detection'; interface Props { modelSource: string | number; diff --git a/src/models/StyleTransfer.ts b/src/hooks/computer_vision/useStyleTransfer.ts similarity index 83% rename from src/models/StyleTransfer.ts rename to src/hooks/computer_vision/useStyleTransfer.ts index 215f5aea01..12c45ff4fc 100644 --- a/src/models/StyleTransfer.ts +++ b/src/hooks/computer_vision/useStyleTransfer.ts @@ -1,6 +1,6 @@ import { useState } from 'react'; -import { _StyleTransferModule } from '../native/RnExecutorchModules'; -import { useModule } from '../useModule'; +import { _StyleTransferModule } from '../../native/RnExecutorchModules'; +import { useModule } from '../../useModule'; interface Props { modelSource: string | number; diff --git a/src/index.tsx b/src/index.tsx index 74cfd13e34..1408f1f21b 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -1,7 +1,16 @@ export * from './ETModule'; export * from './LLM'; export * from './constants/modelUrls'; -export * from './models/Classification'; -export * from './models/ObjectDetection'; -export * from './models/StyleTransfer'; + +// hooks +export * from './hooks/computer_vision/useClassification'; +export * from './hooks/computer_vision/useObjectDetection'; +export * from './hooks/computer_vision/useStyleTransfer'; + +// modules +export * from './modules/computer_vision/ClassificationModule'; +export * from 
'./modules/computer_vision/ObjectDetectionModule'; +export * from './modules/computer_vision/StyleTransferModule'; + +// types export * from './types/object_detection'; diff --git a/src/modules/computer_vision/BaseModule.ts b/src/modules/computer_vision/BaseModule.ts new file mode 100644 index 0000000000..9e4aee0658 --- /dev/null +++ b/src/modules/computer_vision/BaseModule.ts @@ -0,0 +1,34 @@ +import { Image } from 'react-native'; +import { getError } from '../../Error'; + +export class BaseModule { + protected module: any; + + constructor(module: any) { + this.module = module; + } + + async loadModule(modelSource: string | number) { + if (!modelSource) return; + + let path = modelSource; + + if (typeof modelSource === 'number') { + path = Image.resolveAssetSource(modelSource).uri; + } + + try { + await this.module.loadModule(path); + } catch (e) { + throw new Error(getError(e)); + } + } + + async forward(input: string) { + try { + return await this.module.forward(input); + } catch (e) { + throw new Error(getError(e)); + } + } +} diff --git a/src/modules/computer_vision/ClassificationModule.ts b/src/modules/computer_vision/ClassificationModule.ts new file mode 100644 index 0000000000..2b9e989de6 --- /dev/null +++ b/src/modules/computer_vision/ClassificationModule.ts @@ -0,0 +1,8 @@ +import { BaseModule } from './BaseModule'; +import { _ClassificationModule } from '../../native/RnExecutorchModules'; + +export class ClassificationModule extends BaseModule { + constructor() { + super(new _ClassificationModule()); + } +} diff --git a/src/modules/computer_vision/ObjectDetectionModule.ts b/src/modules/computer_vision/ObjectDetectionModule.ts new file mode 100644 index 0000000000..fa5f758920 --- /dev/null +++ b/src/modules/computer_vision/ObjectDetectionModule.ts @@ -0,0 +1,8 @@ +import { BaseModule } from './BaseModule'; +import { _ObjectDetectionModule } from '../../native/RnExecutorchModules'; + +export class ObjectDetectionModule extends BaseModule { + constructor() 
{ + super(new _ObjectDetectionModule()); + } +} diff --git a/src/modules/computer_vision/StyleTransferModule.ts b/src/modules/computer_vision/StyleTransferModule.ts new file mode 100644 index 0000000000..830bd0579f --- /dev/null +++ b/src/modules/computer_vision/StyleTransferModule.ts @@ -0,0 +1,8 @@ +import { BaseModule } from './BaseModule'; +import { _StyleTransferModule } from '../../native/RnExecutorchModules'; + +export class StyleTransfer extends BaseModule { + constructor() { + super(new _StyleTransferModule()); + } +} From 2e71bc4b977f6c16f1491e15fa4306d7f0a402cd Mon Sep 17 00:00:00 2001 From: jakmro Date: Tue, 14 Jan 2025 14:39:59 +0100 Subject: [PATCH 02/15] Move hooks to appropriate directories --- .../bindings/useExecutorchModule.ts} | 8 ++++---- .../natural_language_processing/useLLM.ts} | 6 +++--- src/index.tsx | 4 ++-- 3 files changed, 9 insertions(+), 9 deletions(-) rename src/{ETModule.ts => hooks/bindings/useExecutorchModule.ts} (77%) rename src/{LLM.ts => hooks/natural_language_processing/useLLM.ts} (95%) diff --git a/src/ETModule.ts b/src/hooks/bindings/useExecutorchModule.ts similarity index 77% rename from src/ETModule.ts rename to src/hooks/bindings/useExecutorchModule.ts index 416c1f4c16..09dc1157b8 100644 --- a/src/ETModule.ts +++ b/src/hooks/bindings/useExecutorchModule.ts @@ -1,8 +1,8 @@ import { useState } from 'react'; -import { _ETModule } from './native/RnExecutorchModules'; -import { getError } from './Error'; -import { ExecutorchModule } from './types/common'; -import { useModule } from './useModule'; +import { _ETModule } from '../../native/RnExecutorchModules'; +import { getError } from '../../Error'; +import { ExecutorchModule } from '../../types/common'; +import { useModule } from '../../useModule'; interface Props { modelSource: string | number; diff --git a/src/LLM.ts b/src/hooks/natural_language_processing/useLLM.ts similarity index 95% rename from src/LLM.ts rename to src/hooks/natural_language_processing/useLLM.ts index 
4219fdcadc..386750c459 100644 --- a/src/LLM.ts +++ b/src/hooks/natural_language_processing/useLLM.ts @@ -1,12 +1,12 @@ import { useCallback, useEffect, useRef, useState } from 'react'; import { EventSubscription, Image } from 'react-native'; -import { ResourceSource, Model } from './types/common'; +import { ResourceSource, Model } from '../../types/common'; import { DEFAULT_CONTEXT_WINDOW_LENGTH, DEFAULT_SYSTEM_PROMPT, EOT_TOKEN, -} from './constants/llamaDefaults'; -import { LLM } from './native/RnExecutorchModules'; +} from '../../constants/llamaDefaults'; +import { LLM } from '../../native/RnExecutorchModules'; const interrupt = () => { LLM.interrupt(); diff --git a/src/index.tsx b/src/index.tsx index 1408f1f21b..645b122139 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -1,5 +1,5 @@ -export * from './ETModule'; -export * from './LLM'; +export * from './hooks/bindings/useExecutorchModule'; +export * from './hooks/natural_language_processing/useLLM'; export * from './constants/modelUrls'; // hooks From 656a447f2f1c849bb7baf3ec56266df4bee77b5e Mon Sep 17 00:00:00 2001 From: jakmro Date: Tue, 14 Jan 2025 15:35:05 +0100 Subject: [PATCH 03/15] Add executorch bindings to hookless api --- src/modules/bindings/ExecutorchModule.ts | 60 ++++++++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 src/modules/bindings/ExecutorchModule.ts diff --git a/src/modules/bindings/ExecutorchModule.ts b/src/modules/bindings/ExecutorchModule.ts new file mode 100644 index 0000000000..2e53d9a8c5 --- /dev/null +++ b/src/modules/bindings/ExecutorchModule.ts @@ -0,0 +1,60 @@ +import { Image } from 'react-native'; +import { ETError, getError } from '../../Error'; +import { _ETModule } from '../../native/RnExecutorchModules'; +import { ETInput } from '../../types/common'; + +const getTypeIdentifier = (arr: ETInput): number => { + if (arr instanceof Int8Array) return 0; + if (arr instanceof Int32Array) return 1; + if (arr instanceof BigInt64Array) return 2; + if (arr 
instanceof Float32Array) return 3; + if (arr instanceof Float64Array) return 4; + + return -1; +}; + +export class ExecutorchModule { + protected module = new _ETModule(); + + async loadModule(modelSource: string) { + if (!modelSource) return; + + let path = modelSource; + + if (typeof modelSource === 'number') { + path = Image.resolveAssetSource(modelSource).uri; + } + + try { + await this.module.loadModule(path); + } catch (e) { + throw new Error(getError(e)); + } + } + + async forward(input: ETInput, shape: number[]) { + const inputType = getTypeIdentifier(input); + if (inputType === -1) { + throw new Error(getError(ETError.InvalidArgument)); + } + + try { + const numberArray = [...input] as number[]; + return await this.module.forward(numberArray, shape, inputType); + } catch (e) { + throw new Error(getError(e)); + } + } + + async loadMethod(methodName: string) { + try { + await this.module.loadMethod(methodName); + } catch (e) { + throw new Error(getError(e)); + } + } + + async loadForward() { + await this.loadMethod('forward'); + } +} From 1f95ad607765083b836af949b62759251bf269aa Mon Sep 17 00:00:00 2001 From: jakmro Date: Tue, 14 Jan 2025 15:40:02 +0100 Subject: [PATCH 04/15] Move getTypeIdentifier to common types file --- src/modules/bindings/ExecutorchModule.ts | 12 +----------- src/types/common.ts | 10 ++++++++++ src/useModule.ts | 12 +----------- 3 files changed, 12 insertions(+), 22 deletions(-) diff --git a/src/modules/bindings/ExecutorchModule.ts b/src/modules/bindings/ExecutorchModule.ts index 2e53d9a8c5..cf9322b637 100644 --- a/src/modules/bindings/ExecutorchModule.ts +++ b/src/modules/bindings/ExecutorchModule.ts @@ -1,17 +1,7 @@ import { Image } from 'react-native'; import { ETError, getError } from '../../Error'; import { _ETModule } from '../../native/RnExecutorchModules'; -import { ETInput } from '../../types/common'; - -const getTypeIdentifier = (arr: ETInput): number => { - if (arr instanceof Int8Array) return 0; - if (arr instanceof 
Int32Array) return 1; - if (arr instanceof BigInt64Array) return 2; - if (arr instanceof Float32Array) return 3; - if (arr instanceof Float64Array) return 4; - - return -1; -}; +import { ETInput, getTypeIdentifier } from '../../types/common'; export class ExecutorchModule { protected module = new _ETModule(); diff --git a/src/types/common.ts b/src/types/common.ts index f12643d4c3..d3a03aa3f6 100644 --- a/src/types/common.ts +++ b/src/types/common.ts @@ -26,6 +26,16 @@ export type ETInput = | Float32Array | Float64Array; +export const getTypeIdentifier = (arr: ETInput): number => { + if (arr instanceof Int8Array) return 0; + if (arr instanceof Int32Array) return 1; + if (arr instanceof BigInt64Array) return 2; + if (arr instanceof Float32Array) return 3; + if (arr instanceof Float64Array) return 4; + + return -1; +}; + export interface ExecutorchModule { error: string | null; isReady: boolean; diff --git a/src/useModule.ts b/src/useModule.ts index 66c2fd49b6..9842f8e38a 100644 --- a/src/useModule.ts +++ b/src/useModule.ts @@ -1,17 +1,7 @@ import { useEffect, useState } from 'react'; import { Image } from 'react-native'; import { ETError, getError } from './Error'; -import { ETInput, module } from './types/common'; - -const getTypeIdentifier = (arr: ETInput): number => { - if (arr instanceof Int8Array) return 0; - if (arr instanceof Int32Array) return 1; - if (arr instanceof BigInt64Array) return 2; - if (arr instanceof Float32Array) return 3; - if (arr instanceof Float64Array) return 4; - - return -1; -}; +import { ETInput, module, getTypeIdentifier } from './types/common'; interface Props { modelSource: string | number; From 390c5d1698c0ed3a238a0101f871a6d147e5cc6c Mon Sep 17 00:00:00 2001 From: jakmro Date: Tue, 14 Jan 2025 16:21:23 +0100 Subject: [PATCH 05/15] Add LLM hookless api --- src/modules/bindings/ExecutorchModule.ts | 2 +- .../natural_language_processing/LLMModule.ts | 58 +++++++++++++++++++ 2 files changed, 59 insertions(+), 1 deletion(-) create mode 
100644 src/modules/natural_language_processing/LLMModule.ts diff --git a/src/modules/bindings/ExecutorchModule.ts b/src/modules/bindings/ExecutorchModule.ts index cf9322b637..971f36e8a5 100644 --- a/src/modules/bindings/ExecutorchModule.ts +++ b/src/modules/bindings/ExecutorchModule.ts @@ -4,7 +4,7 @@ import { _ETModule } from '../../native/RnExecutorchModules'; import { ETInput, getTypeIdentifier } from '../../types/common'; export class ExecutorchModule { - protected module = new _ETModule(); + private module = new _ETModule(); async loadModule(modelSource: string) { if (!modelSource) return; diff --git a/src/modules/natural_language_processing/LLMModule.ts b/src/modules/natural_language_processing/LLMModule.ts new file mode 100644 index 0000000000..af0a616dee --- /dev/null +++ b/src/modules/natural_language_processing/LLMModule.ts @@ -0,0 +1,58 @@ +import { LLM } from '../../native/RnExecutorchModules'; +import { Image } from 'react-native'; +import { ResourceSource } from '../../types/common'; + +export class LLMModule { + async loadModel( + modelSource: ResourceSource, + tokenizerSource: ResourceSource, + systemPrompt?: string, + contextWindowLength?: number + ) { + try { + let modelUrl = modelSource; + let tokenizerUrl = tokenizerSource; + + if (typeof modelSource === 'number') { + modelUrl = Image.resolveAssetSource(modelSource).uri; + } + + if (typeof tokenizerSource === 'number') { + tokenizerUrl = Image.resolveAssetSource(tokenizerSource).uri; + } + + await LLM.loadLLM( + modelUrl as string, + tokenizerUrl as string, + systemPrompt, + contextWindowLength + ); + } catch (err) { + throw new Error((err as Error).message); + } + } + + async generate(input: string): Promise { + try { + await LLM.runInference(input); + } catch (err) { + throw new Error((err as Error).message); + } + } + + onDownloadProgress(callback: (data: number) => void) { + return LLM.onDownloadProgress(callback); + } + + onToken(callback: (data: string | undefined) => void) { + return 
LLM.onToken(callback); + } + + interrupt() { + LLM.interrupt(); + } + + deleteModule() { + LLM.deleteModule(); + } +} From 70c4d99bd0ccb8292825db1e70bf2611cf501c53 Mon Sep 17 00:00:00 2001 From: jakmro Date: Tue, 14 Jan 2025 16:33:45 +0100 Subject: [PATCH 06/15] Rename directory and add exports --- .../{bindings => general}/useExecutorchModule.ts | 0 src/index.tsx | 15 +++++++++++---- .../{bindings => general}/ExecutorchModule.ts | 0 3 files changed, 11 insertions(+), 4 deletions(-) rename src/hooks/{bindings => general}/useExecutorchModule.ts (100%) rename src/modules/{bindings => general}/ExecutorchModule.ts (100%) diff --git a/src/hooks/bindings/useExecutorchModule.ts b/src/hooks/general/useExecutorchModule.ts similarity index 100% rename from src/hooks/bindings/useExecutorchModule.ts rename to src/hooks/general/useExecutorchModule.ts diff --git a/src/index.tsx b/src/index.tsx index 645b122139..1c21b25e08 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -1,16 +1,23 @@ -export * from './hooks/bindings/useExecutorchModule'; -export * from './hooks/natural_language_processing/useLLM'; -export * from './constants/modelUrls'; - // hooks export * from './hooks/computer_vision/useClassification'; export * from './hooks/computer_vision/useObjectDetection'; export * from './hooks/computer_vision/useStyleTransfer'; +export * from './hooks/natural_language_processing/useLLM'; + +export * from './hooks/general/useExecutorchModule'; + // modules export * from './modules/computer_vision/ClassificationModule'; export * from './modules/computer_vision/ObjectDetectionModule'; export * from './modules/computer_vision/StyleTransferModule'; +export * from './modules/natural_language_processing/LLMModule'; + +export * from './modules/general/ExecutorchModule'; + // types export * from './types/object_detection'; + +// constants +export * from './constants/modelUrls'; diff --git a/src/modules/bindings/ExecutorchModule.ts b/src/modules/general/ExecutorchModule.ts similarity index 
100% rename from src/modules/bindings/ExecutorchModule.ts rename to src/modules/general/ExecutorchModule.ts From 2c531ca8474270c9a9e8502fbee56534f7a0772b Mon Sep 17 00:00:00 2001 From: jakmro Date: Tue, 14 Jan 2025 16:37:11 +0100 Subject: [PATCH 07/15] Fix lint warnings --- examples/computer-vision/components/ImageWithBboxes.tsx | 9 ++------- .../computer-vision/screens/ObjectDetectionScreen.tsx | 6 +++++- examples/llama/screens/ChatScreen.tsx | 5 ++++- src/types/common.ts | 2 +- src/useModule.ts | 4 ++-- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/examples/computer-vision/components/ImageWithBboxes.tsx b/examples/computer-vision/components/ImageWithBboxes.tsx index 7d08e33275..65d345d6cb 100644 --- a/examples/computer-vision/components/ImageWithBboxes.tsx +++ b/examples/computer-vision/components/ImageWithBboxes.tsx @@ -68,13 +68,7 @@ export default function ImageWithBboxes({ const height = (y2 - y1) * scaleY; return ( - + {detection.label} ({(detection.score * 100).toFixed(1)}%) @@ -98,6 +92,7 @@ const styles = StyleSheet.create({ bbox: { position: 'absolute', borderWidth: 2, + borderColor: 'red', }, label: { position: 'absolute', diff --git a/examples/computer-vision/screens/ObjectDetectionScreen.tsx b/examples/computer-vision/screens/ObjectDetectionScreen.tsx index 280e3c5722..ea82dfd493 100644 --- a/examples/computer-vision/screens/ObjectDetectionScreen.tsx +++ b/examples/computer-vision/screens/ObjectDetectionScreen.tsx @@ -76,7 +76,7 @@ export const ObjectDetectionScreen = ({ /> ) : ( @@ -127,4 +127,8 @@ const styles = StyleSheet.create({ flex: 1, marginRight: 4, }, + fullSizeImage: { + width: '100%', + height: '100%', + }, }); diff --git a/examples/llama/screens/ChatScreen.tsx b/examples/llama/screens/ChatScreen.tsx index 4d2f707d75..40d9e70be7 100644 --- a/examples/llama/screens/ChatScreen.tsx +++ b/examples/llama/screens/ChatScreen.tsx @@ -63,7 +63,7 @@ export default function ChatScreen() { @@ -133,6 +133,9 @@ const styles = 
StyleSheet.create({ container: { flex: 1, }, + keyboardAvoidingView: { + flex: 1, + }, topContainer: { height: 68, width: '100%', diff --git a/src/types/common.ts b/src/types/common.ts index d3a03aa3f6..1d186d5081 100644 --- a/src/types/common.ts +++ b/src/types/common.ts @@ -45,7 +45,7 @@ export interface ExecutorchModule { loadForward: () => Promise; } -export type module = +export type Module = | _ClassificationModule | _StyleTransferModule | _ObjectDetectionModule diff --git a/src/useModule.ts b/src/useModule.ts index 9842f8e38a..e4080a55f9 100644 --- a/src/useModule.ts +++ b/src/useModule.ts @@ -1,11 +1,11 @@ import { useEffect, useState } from 'react'; import { Image } from 'react-native'; import { ETError, getError } from './Error'; -import { ETInput, module, getTypeIdentifier } from './types/common'; +import { ETInput, Module, getTypeIdentifier } from './types/common'; interface Props { modelSource: string | number; - module: module; + module: Module; } interface _Module { From 41d6b4216298cfb1ca1eca85c56316c4728536c7 Mon Sep 17 00:00:00 2001 From: jakmro Date: Wed, 15 Jan 2025 13:28:36 +0100 Subject: [PATCH 08/15] Make classes static --- src/hooks/computer_vision/useClassification.ts | 2 +- .../computer_vision/useObjectDetection.ts | 2 +- src/hooks/computer_vision/useStyleTransfer.ts | 2 +- src/hooks/general/useExecutorchModule.ts | 2 +- src/modules/computer_vision/BaseModule.ts | 14 ++++---------- .../computer_vision/ClassificationModule.ts | 10 +++++++--- .../computer_vision/ObjectDetectionModule.ts | 11 ++++++++--- .../computer_vision/StyleTransferModule.ts | 10 +++++++--- src/modules/general/ExecutorchModule.ts | 16 +++++++--------- .../natural_language_processing/LLMModule.ts | 12 ++++++------ src/native/RnExecutorchModules.ts | 18 +++++++++--------- 11 files changed, 52 insertions(+), 47 deletions(-) diff --git a/src/hooks/computer_vision/useClassification.ts b/src/hooks/computer_vision/useClassification.ts index be10938194..5a82455618 100644 --- 
a/src/hooks/computer_vision/useClassification.ts +++ b/src/hooks/computer_vision/useClassification.ts @@ -16,7 +16,7 @@ interface ClassificationModule { export const useClassification = ({ modelSource, }: Props): ClassificationModule => { - const [module, _] = useState(() => new _ClassificationModule()); + const [module, _] = useState(() => _ClassificationModule); const { error, isReady, diff --git a/src/hooks/computer_vision/useObjectDetection.ts b/src/hooks/computer_vision/useObjectDetection.ts index 9a5ccd1243..e08d757edb 100644 --- a/src/hooks/computer_vision/useObjectDetection.ts +++ b/src/hooks/computer_vision/useObjectDetection.ts @@ -17,7 +17,7 @@ interface ObjectDetectionModule { export const useObjectDetection = ({ modelSource, }: Props): ObjectDetectionModule => { - const [module, _] = useState(() => new _ObjectDetectionModule()); + const [module, _] = useState(() => _ObjectDetectionModule); const { error, isReady, diff --git a/src/hooks/computer_vision/useStyleTransfer.ts b/src/hooks/computer_vision/useStyleTransfer.ts index 12c45ff4fc..5e7624bf85 100644 --- a/src/hooks/computer_vision/useStyleTransfer.ts +++ b/src/hooks/computer_vision/useStyleTransfer.ts @@ -16,7 +16,7 @@ interface StyleTransferModule { export const useStyleTransfer = ({ modelSource, }: Props): StyleTransferModule => { - const [module, _] = useState(() => new _StyleTransferModule()); + const [module, _] = useState(() => _StyleTransferModule); const { error, isReady, diff --git a/src/hooks/general/useExecutorchModule.ts b/src/hooks/general/useExecutorchModule.ts index 09dc1157b8..eb90383a0f 100644 --- a/src/hooks/general/useExecutorchModule.ts +++ b/src/hooks/general/useExecutorchModule.ts @@ -11,7 +11,7 @@ interface Props { export const useExecutorchModule = ({ modelSource, }: Props): ExecutorchModule => { - const [module] = useState(() => new _ETModule()); + const [module] = useState(() => _ETModule); const { error, isReady, diff --git a/src/modules/computer_vision/BaseModule.ts 
b/src/modules/computer_vision/BaseModule.ts index 9e4aee0658..a09b0a5c22 100644 --- a/src/modules/computer_vision/BaseModule.ts +++ b/src/modules/computer_vision/BaseModule.ts @@ -2,13 +2,7 @@ import { Image } from 'react-native'; import { getError } from '../../Error'; export class BaseModule { - protected module: any; - - constructor(module: any) { - this.module = module; - } - - async loadModule(modelSource: string | number) { + static async load(module: any, modelSource: string | number) { if (!modelSource) return; let path = modelSource; @@ -18,15 +12,15 @@ export class BaseModule { } try { - await this.module.loadModule(path); + await module.loadModule(path); } catch (e) { throw new Error(getError(e)); } } - async forward(input: string) { + static async forward(module: any, input: string) { try { - return await this.module.forward(input); + return await module.forward(input); } catch (e) { throw new Error(getError(e)); } diff --git a/src/modules/computer_vision/ClassificationModule.ts b/src/modules/computer_vision/ClassificationModule.ts index 2b9e989de6..c21e4594de 100644 --- a/src/modules/computer_vision/ClassificationModule.ts +++ b/src/modules/computer_vision/ClassificationModule.ts @@ -1,8 +1,12 @@ import { BaseModule } from './BaseModule'; import { _ClassificationModule } from '../../native/RnExecutorchModules'; -export class ClassificationModule extends BaseModule { - constructor() { - super(new _ClassificationModule()); +export class ClassificationModule { + static async load(modelSource: string | number) { + await BaseModule.load(_ClassificationModule, modelSource); + } + + static async forward(input: string): Promise<{ [category: string]: number }> { + return await BaseModule.forward(_ClassificationModule, input); } } diff --git a/src/modules/computer_vision/ObjectDetectionModule.ts b/src/modules/computer_vision/ObjectDetectionModule.ts index fa5f758920..3c3dd9ddce 100644 --- a/src/modules/computer_vision/ObjectDetectionModule.ts +++ 
b/src/modules/computer_vision/ObjectDetectionModule.ts @@ -1,8 +1,13 @@ import { BaseModule } from './BaseModule'; import { _ObjectDetectionModule } from '../../native/RnExecutorchModules'; +import { Detection } from '../../types/object_detection'; -export class ObjectDetectionModule extends BaseModule { - constructor() { - super(new _ObjectDetectionModule()); +export class ObjectDetectionModule { + static async load(modelSource: string | number) { + await BaseModule.load(_ObjectDetectionModule, modelSource); + } + + static async forward(input: string): Promise { + return await BaseModule.forward(_ObjectDetectionModule, input); } } diff --git a/src/modules/computer_vision/StyleTransferModule.ts b/src/modules/computer_vision/StyleTransferModule.ts index 830bd0579f..d7b55e1403 100644 --- a/src/modules/computer_vision/StyleTransferModule.ts +++ b/src/modules/computer_vision/StyleTransferModule.ts @@ -1,8 +1,12 @@ import { BaseModule } from './BaseModule'; import { _StyleTransferModule } from '../../native/RnExecutorchModules'; -export class StyleTransfer extends BaseModule { - constructor() { - super(new _StyleTransferModule()); +export class StyleTransferModule { + static async load(modelSource: string | number) { + await BaseModule.load(_StyleTransferModule, modelSource); + } + + static async forward(input: string): Promise { + return await BaseModule.forward(_StyleTransferModule, input); } } diff --git a/src/modules/general/ExecutorchModule.ts b/src/modules/general/ExecutorchModule.ts index 971f36e8a5..ca4f33894c 100644 --- a/src/modules/general/ExecutorchModule.ts +++ b/src/modules/general/ExecutorchModule.ts @@ -4,9 +4,7 @@ import { _ETModule } from '../../native/RnExecutorchModules'; import { ETInput, getTypeIdentifier } from '../../types/common'; export class ExecutorchModule { - private module = new _ETModule(); - - async loadModule(modelSource: string) { + static async load(modelSource: string) { if (!modelSource) return; let path = modelSource; @@ -16,13 
+14,13 @@ export class ExecutorchModule { } try { - await this.module.loadModule(path); + await _ETModule.loadModule(path); } catch (e) { throw new Error(getError(e)); } } - async forward(input: ETInput, shape: number[]) { + static async forward(input: ETInput, shape: number[]) { const inputType = getTypeIdentifier(input); if (inputType === -1) { throw new Error(getError(ETError.InvalidArgument)); @@ -30,21 +28,21 @@ export class ExecutorchModule { try { const numberArray = [...input] as number[]; - return await this.module.forward(numberArray, shape, inputType); + return await _ETModule.forward(numberArray, shape, inputType); } catch (e) { throw new Error(getError(e)); } } - async loadMethod(methodName: string) { + static async loadMethod(methodName: string) { try { - await this.module.loadMethod(methodName); + await _ETModule.loadMethod(methodName); } catch (e) { throw new Error(getError(e)); } } - async loadForward() { + static async loadForward() { await this.loadMethod('forward'); } } diff --git a/src/modules/natural_language_processing/LLMModule.ts b/src/modules/natural_language_processing/LLMModule.ts index af0a616dee..c0638582c8 100644 --- a/src/modules/natural_language_processing/LLMModule.ts +++ b/src/modules/natural_language_processing/LLMModule.ts @@ -3,7 +3,7 @@ import { Image } from 'react-native'; import { ResourceSource } from '../../types/common'; export class LLMModule { - async loadModel( + static async load( modelSource: ResourceSource, tokenizerSource: ResourceSource, systemPrompt?: string, @@ -32,7 +32,7 @@ export class LLMModule { } } - async generate(input: string): Promise { + static async generate(input: string) { try { await LLM.runInference(input); } catch (err) { @@ -40,19 +40,19 @@ export class LLMModule { } } - onDownloadProgress(callback: (data: number) => void) { + static onDownloadProgress(callback: (data: number) => void) { return LLM.onDownloadProgress(callback); } - onToken(callback: (data: string | undefined) => void) { + 
static onToken(callback: (data: string | undefined) => void) { return LLM.onToken(callback); } - interrupt() { + static interrupt() { LLM.interrupt(); } - deleteModule() { + static deleteModule() { LLM.deleteModule(); } } diff --git a/src/native/RnExecutorchModules.ts b/src/native/RnExecutorchModules.ts index 8a80b59590..ca79367b1b 100644 --- a/src/native/RnExecutorchModules.ts +++ b/src/native/RnExecutorchModules.ts @@ -72,44 +72,44 @@ const StyleTransfer = StyleTransferSpec ); class _ObjectDetectionModule { - async forward(input: string) { + static async forward(input: string) { return await ObjectDetection.forward(input); } - async loadModule(modelSource: string | number) { + static async loadModule(modelSource: string | number) { return await ObjectDetection.loadModule(modelSource); } } class _StyleTransferModule { - async forward(input: string) { + static async forward(input: string) { return await StyleTransfer.forward(input); } - async loadModule(modelSource: string | number) { + static async loadModule(modelSource: string | number) { return await StyleTransfer.loadModule(modelSource); } } class _ClassificationModule { - async forward(input: string) { + static async forward(input: string) { return await Classification.forward(input); } - async loadModule(modelSource: string | number) { + static async loadModule(modelSource: string | number) { return await Classification.loadModule(modelSource); } } class _ETModule { - async forward( + static async forward( input: number[], shape: number[], inputType: number ): Promise { return await ETModule.forward(input, shape, inputType); } - async loadModule(modelSource: string) { + static async loadModule(modelSource: string) { return await ETModule.loadModule(modelSource); } - async loadMethod(methodName: string): Promise { + static async loadMethod(methodName: string): Promise { return await ETModule.loadMethod(methodName); } } From 031f399c87d2ce6683c2cbc04a7612634e5b33d6 Mon Sep 17 00:00:00 2001 From: jakmro Date: 
Sun, 19 Jan 2025 20:32:09 +0100 Subject: [PATCH 09/15] Fix llm hookles api --- src/hooks/general/useExecutorchModule.ts | 2 +- src/modules/natural_language_processing/LLMModule.ts | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/hooks/general/useExecutorchModule.ts b/src/hooks/general/useExecutorchModule.ts index eb90383a0f..9f559acbb1 100644 --- a/src/hooks/general/useExecutorchModule.ts +++ b/src/hooks/general/useExecutorchModule.ts @@ -11,7 +11,7 @@ interface Props { export const useExecutorchModule = ({ modelSource, }: Props): ExecutorchModule => { - const [module] = useState(() => _ETModule); + const [module, _] = useState(() => _ETModule); const { error, isReady, diff --git a/src/modules/natural_language_processing/LLMModule.ts b/src/modules/natural_language_processing/LLMModule.ts index c0638582c8..12ca7a576a 100644 --- a/src/modules/natural_language_processing/LLMModule.ts +++ b/src/modules/natural_language_processing/LLMModule.ts @@ -1,13 +1,17 @@ import { LLM } from '../../native/RnExecutorchModules'; import { Image } from 'react-native'; +import { + DEFAULT_CONTEXT_WINDOW_LENGTH, + DEFAULT_SYSTEM_PROMPT, +} from '../../constants/llamaDefaults'; import { ResourceSource } from '../../types/common'; export class LLMModule { static async load( modelSource: ResourceSource, tokenizerSource: ResourceSource, - systemPrompt?: string, - contextWindowLength?: number + systemPrompt = DEFAULT_SYSTEM_PROMPT, + contextWindowLength = DEFAULT_CONTEXT_WINDOW_LENGTH ) { try { let modelUrl = modelSource; From 46e9957a9a708a12e71d282e3f1ed9e17cabd2f4 Mon Sep 17 00:00:00 2001 From: jakmro Date: Tue, 21 Jan 2025 11:28:52 +0100 Subject: [PATCH 10/15] Rename TS interfaces --- src/hooks/computer_vision/useClassification.ts | 4 ++-- src/hooks/computer_vision/useObjectDetection.ts | 4 ++-- src/hooks/computer_vision/useStyleTransfer.ts | 6 ++---- src/hooks/general/useExecutorchModule.ts | 15 ++++++++++++--- src/types/common.ts | 9 --------- 5 files 
changed, 18 insertions(+), 20 deletions(-) diff --git a/src/hooks/computer_vision/useClassification.ts b/src/hooks/computer_vision/useClassification.ts index 5a82455618..d3b30ef855 100644 --- a/src/hooks/computer_vision/useClassification.ts +++ b/src/hooks/computer_vision/useClassification.ts @@ -6,7 +6,7 @@ interface Props { modelSource: string | number; } -interface ClassificationModule { +interface UseClassification { error: string | null; isReady: boolean; isGenerating: boolean; @@ -15,7 +15,7 @@ interface ClassificationModule { export const useClassification = ({ modelSource, -}: Props): ClassificationModule => { +}: Props): UseClassification => { const [module, _] = useState(() => _ClassificationModule); const { error, diff --git a/src/hooks/computer_vision/useObjectDetection.ts b/src/hooks/computer_vision/useObjectDetection.ts index e08d757edb..596cf660e5 100644 --- a/src/hooks/computer_vision/useObjectDetection.ts +++ b/src/hooks/computer_vision/useObjectDetection.ts @@ -7,7 +7,7 @@ interface Props { modelSource: string | number; } -interface ObjectDetectionModule { +interface UseObjectDetection { error: string | null; isReady: boolean; isGenerating: boolean; @@ -16,7 +16,7 @@ interface ObjectDetectionModule { export const useObjectDetection = ({ modelSource, -}: Props): ObjectDetectionModule => { +}: Props): UseObjectDetection => { const [module, _] = useState(() => _ObjectDetectionModule); const { error, diff --git a/src/hooks/computer_vision/useStyleTransfer.ts b/src/hooks/computer_vision/useStyleTransfer.ts index 5e7624bf85..167d617a55 100644 --- a/src/hooks/computer_vision/useStyleTransfer.ts +++ b/src/hooks/computer_vision/useStyleTransfer.ts @@ -6,16 +6,14 @@ interface Props { modelSource: string | number; } -interface StyleTransferModule { +interface UseStyleTransfer { error: string | null; isReady: boolean; isGenerating: boolean; forward: (input: string) => Promise; } -export const useStyleTransfer = ({ - modelSource, -}: Props): 
StyleTransferModule => { +export const useStyleTransfer = ({ modelSource }: Props): UseStyleTransfer => { const [module, _] = useState(() => _StyleTransferModule); const { error, diff --git a/src/hooks/general/useExecutorchModule.ts b/src/hooks/general/useExecutorchModule.ts index 9f559acbb1..efed44122b 100644 --- a/src/hooks/general/useExecutorchModule.ts +++ b/src/hooks/general/useExecutorchModule.ts @@ -1,16 +1,25 @@ import { useState } from 'react'; import { _ETModule } from '../../native/RnExecutorchModules'; -import { getError } from '../../Error'; -import { ExecutorchModule } from '../../types/common'; import { useModule } from '../../useModule'; +import { ETInput } from '../../types/common'; +import { getError } from '../../Error'; interface Props { modelSource: string | number; } +interface UseExecutorchModule { + error: string | null; + isReady: boolean; + isGenerating: boolean; + forward: (input: ETInput, shape: number[]) => Promise; + loadMethod: (methodName: string) => Promise; + loadForward: () => Promise; +} + export const useExecutorchModule = ({ modelSource, -}: Props): ExecutorchModule => { +}: Props): UseExecutorchModule => { const [module, _] = useState(() => _ETModule); const { error, diff --git a/src/types/common.ts b/src/types/common.ts index 1d186d5081..ec0daa2c02 100644 --- a/src/types/common.ts +++ b/src/types/common.ts @@ -36,15 +36,6 @@ export const getTypeIdentifier = (arr: ETInput): number => { return -1; }; -export interface ExecutorchModule { - error: string | null; - isReady: boolean; - isGenerating: boolean; - forward: (input: ETInput, shape: number[]) => Promise; - loadMethod: (methodName: string) => Promise; - loadForward: () => Promise; -} - export type Module = | _ClassificationModule | _StyleTransferModule From 8a1278101744852b1fb0a995ae79790e43cd56bc Mon Sep 17 00:00:00 2001 From: jakmro Date: Tue, 21 Jan 2025 11:33:15 +0100 Subject: [PATCH 11/15] Update docs --- docs/docs/computer-vision/useClassification.mdx | 2 +- 
docs/docs/computer-vision/useObjectDetection.mdx | 2 +- docs/docs/computer-vision/useStyleTransfer.mdx | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/docs/computer-vision/useClassification.mdx b/docs/docs/computer-vision/useClassification.mdx index cb39d07a60..bb5cee7b7d 100644 --- a/docs/docs/computer-vision/useClassification.mdx +++ b/docs/docs/computer-vision/useClassification.mdx @@ -35,7 +35,7 @@ try { Type definitions ```typescript -interface ClassificationModule { +interface UseClassification { error: string | null; isReady: boolean; isGenerating: boolean; diff --git a/docs/docs/computer-vision/useObjectDetection.mdx b/docs/docs/computer-vision/useObjectDetection.mdx index 0e40076487..2bcf387947 100644 --- a/docs/docs/computer-vision/useObjectDetection.mdx +++ b/docs/docs/computer-vision/useObjectDetection.mdx @@ -46,7 +46,7 @@ interface Detection { score: number; } -interface ObjectDetectionModule { +interface UseObjectDetection { error: string | null; isReady: boolean; isGenerating: boolean; diff --git a/docs/docs/computer-vision/useStyleTransfer.mdx b/docs/docs/computer-vision/useStyleTransfer.mdx index b6f4f4bed6..a1d373f2fc 100644 --- a/docs/docs/computer-vision/useStyleTransfer.mdx +++ b/docs/docs/computer-vision/useStyleTransfer.mdx @@ -34,7 +34,7 @@ try { Type definitions ```typescript -interface StyleTransferModule { +interface UseStyleTransfer { error: string | null; isReady: boolean; isGenerating: boolean; From 8703a5609bb510d3020438a40ee1285efd8ec507 Mon Sep 17 00:00:00 2001 From: jakmro Date: Thu, 23 Jan 2025 09:56:45 +0100 Subject: [PATCH 12/15] Add suggested changes --- .../computer_vision/useClassification.ts | 12 +++--- .../computer_vision/useObjectDetection.ts | 12 +++--- src/hooks/computer_vision/useStyleTransfer.ts | 10 ++--- src/hooks/general/useExecutorchModule.ts | 12 +++--- src/modules/computer_vision/BaseCVModule.ts | 38 +++++++++++++++++++ src/modules/computer_vision/BaseModule.ts | 28 -------------- 
.../computer_vision/ClassificationModule.ts | 14 +++---- .../computer_vision/ObjectDetectionModule.ts | 15 ++++---- .../computer_vision/StyleTransferModule.ts | 14 +++---- src/modules/general/ExecutorchModule.ts | 8 ++-- src/native/NativeETModule.ts | 1 + src/native/RnExecutorchModules.ts | 36 +++++++++++++----- 12 files changed, 111 insertions(+), 89 deletions(-) create mode 100644 src/modules/computer_vision/BaseCVModule.ts delete mode 100644 src/modules/computer_vision/BaseModule.ts diff --git a/src/hooks/computer_vision/useClassification.ts b/src/hooks/computer_vision/useClassification.ts index d3b30ef855..836a0fc938 100644 --- a/src/hooks/computer_vision/useClassification.ts +++ b/src/hooks/computer_vision/useClassification.ts @@ -6,17 +6,15 @@ interface Props { modelSource: string | number; } -interface UseClassification { +export const useClassification = ({ + modelSource, +}: Props): { error: string | null; isReady: boolean; isGenerating: boolean; forward: (input: string) => Promise<{ [category: string]: number }>; -} - -export const useClassification = ({ - modelSource, -}: Props): UseClassification => { - const [module, _] = useState(() => _ClassificationModule); +} => { + const [module, _] = useState(() => new _ClassificationModule()); const { error, isReady, diff --git a/src/hooks/computer_vision/useObjectDetection.ts b/src/hooks/computer_vision/useObjectDetection.ts index 596cf660e5..8456ee338a 100644 --- a/src/hooks/computer_vision/useObjectDetection.ts +++ b/src/hooks/computer_vision/useObjectDetection.ts @@ -7,17 +7,15 @@ interface Props { modelSource: string | number; } -interface UseObjectDetection { +export const useObjectDetection = ({ + modelSource, +}: Props): { error: string | null; isReady: boolean; isGenerating: boolean; forward: (input: string) => Promise; -} - -export const useObjectDetection = ({ - modelSource, -}: Props): UseObjectDetection => { - const [module, _] = useState(() => _ObjectDetectionModule); +} => { + const [module, _] = 
useState(() => new _ObjectDetectionModule()); const { error, isReady, diff --git a/src/hooks/computer_vision/useStyleTransfer.ts b/src/hooks/computer_vision/useStyleTransfer.ts index 167d617a55..20c400b41e 100644 --- a/src/hooks/computer_vision/useStyleTransfer.ts +++ b/src/hooks/computer_vision/useStyleTransfer.ts @@ -6,15 +6,15 @@ interface Props { modelSource: string | number; } -interface UseStyleTransfer { +export const useStyleTransfer = ({ + modelSource, +}: Props): { error: string | null; isReady: boolean; isGenerating: boolean; forward: (input: string) => Promise; -} - -export const useStyleTransfer = ({ modelSource }: Props): UseStyleTransfer => { - const [module, _] = useState(() => _StyleTransferModule); +} => { + const [module, _] = useState(() => new _StyleTransferModule()); const { error, isReady, diff --git a/src/hooks/general/useExecutorchModule.ts b/src/hooks/general/useExecutorchModule.ts index efed44122b..5a180fdff0 100644 --- a/src/hooks/general/useExecutorchModule.ts +++ b/src/hooks/general/useExecutorchModule.ts @@ -8,19 +8,17 @@ interface Props { modelSource: string | number; } -interface UseExecutorchModule { +export const useExecutorchModule = ({ + modelSource, +}: Props): { error: string | null; isReady: boolean; isGenerating: boolean; forward: (input: ETInput, shape: number[]) => Promise; loadMethod: (methodName: string) => Promise; loadForward: () => Promise; -} - -export const useExecutorchModule = ({ - modelSource, -}: Props): UseExecutorchModule => { - const [module, _] = useState(() => _ETModule); +} => { + const [module] = useState(() => new _ETModule()); const { error, isReady, diff --git a/src/modules/computer_vision/BaseCVModule.ts b/src/modules/computer_vision/BaseCVModule.ts new file mode 100644 index 0000000000..43d76a549d --- /dev/null +++ b/src/modules/computer_vision/BaseCVModule.ts @@ -0,0 +1,38 @@ +import { Image } from 'react-native'; +import { + _StyleTransferModule, + _ObjectDetectionModule, + _ClassificationModule, 
+} from '../../native/RnExecutorchModules'; +import { getError } from '../../Error'; + +export class BaseCVModule { + static module: + | _StyleTransferModule + | _ObjectDetectionModule + | _ClassificationModule; + + static async load(modelSource: string | number) { + if (!modelSource) return; + + let path = modelSource; + + if (typeof modelSource === 'number') { + path = Image.resolveAssetSource(modelSource).uri; + } + + try { + await this.module.loadModule(path); + } catch (e) { + throw new Error(getError(e)); + } + } + + static async forward(input: string) { + try { + return await this.module.forward(input); + } catch (e) { + throw new Error(getError(e)); + } + } +} diff --git a/src/modules/computer_vision/BaseModule.ts b/src/modules/computer_vision/BaseModule.ts deleted file mode 100644 index a09b0a5c22..0000000000 --- a/src/modules/computer_vision/BaseModule.ts +++ /dev/null @@ -1,28 +0,0 @@ -import { Image } from 'react-native'; -import { getError } from '../../Error'; - -export class BaseModule { - static async load(module: any, modelSource: string | number) { - if (!modelSource) return; - - let path = modelSource; - - if (typeof modelSource === 'number') { - path = Image.resolveAssetSource(modelSource).uri; - } - - try { - await module.loadModule(path); - } catch (e) { - throw new Error(getError(e)); - } - } - - static async forward(module: any, input: string) { - try { - return await module.forward(input); - } catch (e) { - throw new Error(getError(e)); - } - } -} diff --git a/src/modules/computer_vision/ClassificationModule.ts b/src/modules/computer_vision/ClassificationModule.ts index c21e4594de..ba1bf10718 100644 --- a/src/modules/computer_vision/ClassificationModule.ts +++ b/src/modules/computer_vision/ClassificationModule.ts @@ -1,12 +1,12 @@ -import { BaseModule } from './BaseModule'; +import { BaseCVModule } from './BaseCVModule'; import { _ClassificationModule } from '../../native/RnExecutorchModules'; -export class ClassificationModule { - static 
async load(modelSource: string | number) { - await BaseModule.load(_ClassificationModule, modelSource); - } +export class ClassificationModule extends BaseCVModule { + static module = new _ClassificationModule(); - static async forward(input: string): Promise<{ [category: string]: number }> { - return await BaseModule.forward(_ClassificationModule, input); + static async forward( + input: string + ): ReturnType<_ClassificationModule['forward']> { + return await super.forward(input); } } diff --git a/src/modules/computer_vision/ObjectDetectionModule.ts b/src/modules/computer_vision/ObjectDetectionModule.ts index 3c3dd9ddce..b9de061c62 100644 --- a/src/modules/computer_vision/ObjectDetectionModule.ts +++ b/src/modules/computer_vision/ObjectDetectionModule.ts @@ -1,13 +1,12 @@ -import { BaseModule } from './BaseModule'; +import { BaseCVModule } from './BaseCVModule'; import { _ObjectDetectionModule } from '../../native/RnExecutorchModules'; -import { Detection } from '../../types/object_detection'; -export class ObjectDetectionModule { - static async load(modelSource: string | number) { - await BaseModule.load(_ObjectDetectionModule, modelSource); - } +export class ObjectDetectionModule extends BaseCVModule { + static module = new _ObjectDetectionModule(); - static async forward(input: string): Promise { - return await BaseModule.forward(_ObjectDetectionModule, input); + static async forward( + input: string + ): ReturnType<_ObjectDetectionModule['forward']> { + return await super.forward(input); } } diff --git a/src/modules/computer_vision/StyleTransferModule.ts b/src/modules/computer_vision/StyleTransferModule.ts index d7b55e1403..48959646c2 100644 --- a/src/modules/computer_vision/StyleTransferModule.ts +++ b/src/modules/computer_vision/StyleTransferModule.ts @@ -1,12 +1,12 @@ -import { BaseModule } from './BaseModule'; +import { BaseCVModule } from './BaseCVModule'; import { _StyleTransferModule } from '../../native/RnExecutorchModules'; -export class 
StyleTransferModule { - static async load(modelSource: string | number) { - await BaseModule.load(_StyleTransferModule, modelSource); - } +export class StyleTransferModule extends BaseCVModule { + static module = new _StyleTransferModule(); - static async forward(input: string): Promise { - return await BaseModule.forward(_StyleTransferModule, input); + static async forward( + input: string + ): ReturnType<_StyleTransferModule['forward']> { + return await super.forward(input); } } diff --git a/src/modules/general/ExecutorchModule.ts b/src/modules/general/ExecutorchModule.ts index ca4f33894c..e6d5ef5c84 100644 --- a/src/modules/general/ExecutorchModule.ts +++ b/src/modules/general/ExecutorchModule.ts @@ -4,6 +4,8 @@ import { _ETModule } from '../../native/RnExecutorchModules'; import { ETInput, getTypeIdentifier } from '../../types/common'; export class ExecutorchModule { + static module = new _ETModule(); + static async load(modelSource: string) { if (!modelSource) return; @@ -14,7 +16,7 @@ export class ExecutorchModule { } try { - await _ETModule.loadModule(path); + await this.module.loadModule(path); } catch (e) { throw new Error(getError(e)); } @@ -28,7 +30,7 @@ export class ExecutorchModule { try { const numberArray = [...input] as number[]; - return await _ETModule.forward(numberArray, shape, inputType); + return await this.module.forward(numberArray, shape, inputType); } catch (e) { throw new Error(getError(e)); } @@ -36,7 +38,7 @@ export class ExecutorchModule { static async loadMethod(methodName: string) { try { - await _ETModule.loadMethod(methodName); + await this.module.loadMethod(methodName); } catch (e) { throw new Error(getError(e)); } diff --git a/src/native/NativeETModule.ts b/src/native/NativeETModule.ts index 6d4bfd09ee..d04da1abf7 100644 --- a/src/native/NativeETModule.ts +++ b/src/native/NativeETModule.ts @@ -9,6 +9,7 @@ export interface Spec extends TurboModule { shape: number[], inputType: number ): Promise; + loadMethod(methodName: string): 
Promise; } diff --git a/src/native/RnExecutorchModules.ts b/src/native/RnExecutorchModules.ts index ca79367b1b..e898216a6d 100644 --- a/src/native/RnExecutorchModules.ts +++ b/src/native/RnExecutorchModules.ts @@ -1,4 +1,8 @@ import { Platform } from 'react-native'; +import { Spec as ClassificationInterface } from './NativeClassification'; +import { Spec as ObjectDetectionInterface } from './NativeObjectDetection'; +import { Spec as StyleTransferInterface } from './NativeStyleTransfer'; +import { Spec as ETModuleInterface } from './NativeETModule'; const LINKING_ERROR = `The package 'react-native-executorch' doesn't seem to be linked. Make sure: \n\n` + @@ -72,44 +76,56 @@ const StyleTransfer = StyleTransferSpec ); class _ObjectDetectionModule { - static async forward(input: string) { + async forward( + input: string + ): ReturnType { return await ObjectDetection.forward(input); } - static async loadModule(modelSource: string | number) { + async loadModule( + modelSource: string | number + ): ReturnType { return await ObjectDetection.loadModule(modelSource); } } class _StyleTransferModule { - static async forward(input: string) { + async forward(input: string): ReturnType { return await StyleTransfer.forward(input); } - static async loadModule(modelSource: string | number) { + async loadModule( + modelSource: string | number + ): ReturnType { return await StyleTransfer.loadModule(modelSource); } } class _ClassificationModule { - static async forward(input: string) { + async forward(input: string): ReturnType { return await Classification.forward(input); } - static async loadModule(modelSource: string | number) { + async loadModule( + modelSource: string | number + ): ReturnType { return await Classification.loadModule(modelSource); } } class _ETModule { - static async forward( + async forward( input: number[], shape: number[], inputType: number - ): Promise { + ): ReturnType { return await ETModule.forward(input, shape, inputType); } - static async 
loadModule(modelSource: string) { + async loadModule( + modelSource: string + ): ReturnType { return await ETModule.loadModule(modelSource); } - static async loadMethod(methodName: string): Promise { + async loadMethod( + methodName: string + ): ReturnType { return await ETModule.loadMethod(methodName); } } From 6ecffa5c715891093f00db293084109016215b39 Mon Sep 17 00:00:00 2001 From: jakmro Date: Thu, 23 Jan 2025 09:58:03 +0100 Subject: [PATCH 13/15] Fix typing issue --- src/modules/computer_vision/ClassificationModule.ts | 8 ++++---- src/modules/computer_vision/ObjectDetectionModule.ts | 8 ++++---- src/modules/computer_vision/StyleTransferModule.ts | 8 ++++---- src/modules/natural_language_processing/LLMModule.ts | 2 +- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/src/modules/computer_vision/ClassificationModule.ts b/src/modules/computer_vision/ClassificationModule.ts index ba1bf10718..2c6392cb52 100644 --- a/src/modules/computer_vision/ClassificationModule.ts +++ b/src/modules/computer_vision/ClassificationModule.ts @@ -4,9 +4,9 @@ import { _ClassificationModule } from '../../native/RnExecutorchModules'; export class ClassificationModule extends BaseCVModule { static module = new _ClassificationModule(); - static async forward( - input: string - ): ReturnType<_ClassificationModule['forward']> { - return await super.forward(input); + static async forward(input: string) { + return await (super.forward(input) as ReturnType< + _ClassificationModule['forward'] + >); } } diff --git a/src/modules/computer_vision/ObjectDetectionModule.ts b/src/modules/computer_vision/ObjectDetectionModule.ts index b9de061c62..c50ce02611 100644 --- a/src/modules/computer_vision/ObjectDetectionModule.ts +++ b/src/modules/computer_vision/ObjectDetectionModule.ts @@ -4,9 +4,9 @@ import { _ObjectDetectionModule } from '../../native/RnExecutorchModules'; export class ObjectDetectionModule extends BaseCVModule { static module = new _ObjectDetectionModule(); - static async 
forward( - input: string - ): ReturnType<_ObjectDetectionModule['forward']> { - return await super.forward(input); + static async forward(input: string) { + return await (super.forward(input) as ReturnType< + _ObjectDetectionModule['forward'] + >); } } diff --git a/src/modules/computer_vision/StyleTransferModule.ts b/src/modules/computer_vision/StyleTransferModule.ts index 48959646c2..830a8c5113 100644 --- a/src/modules/computer_vision/StyleTransferModule.ts +++ b/src/modules/computer_vision/StyleTransferModule.ts @@ -4,9 +4,9 @@ import { _StyleTransferModule } from '../../native/RnExecutorchModules'; export class StyleTransferModule extends BaseCVModule { static module = new _StyleTransferModule(); - static async forward( - input: string - ): ReturnType<_StyleTransferModule['forward']> { - return await super.forward(input); + static async forward(input: string) { + return await (super.forward(input) as ReturnType< + _StyleTransferModule['forward'] + >); } } diff --git a/src/modules/natural_language_processing/LLMModule.ts b/src/modules/natural_language_processing/LLMModule.ts index 12ca7a576a..0dafb7d01c 100644 --- a/src/modules/natural_language_processing/LLMModule.ts +++ b/src/modules/natural_language_processing/LLMModule.ts @@ -56,7 +56,7 @@ export class LLMModule { LLM.interrupt(); } - static deleteModule() { + static delete() { LLM.deleteModule(); } } From 09995c5dc65115042f9cf4b08fd7e466f3ea627a Mon Sep 17 00:00:00 2001 From: jakmro Date: Thu, 23 Jan 2025 11:00:56 +0100 Subject: [PATCH 14/15] Update docs --- .../computer-vision/useClassification.mdx | 14 ----- .../computer-vision/useObjectDetection.mdx | 53 +++++++++---------- .../docs/computer-vision/useStyleTransfer.mdx | 14 ----- 3 files changed, 26 insertions(+), 55 deletions(-) diff --git a/docs/docs/computer-vision/useClassification.mdx b/docs/docs/computer-vision/useClassification.mdx index bb5cee7b7d..1043088b21 100644 --- a/docs/docs/computer-vision/useClassification.mdx +++ 
b/docs/docs/computer-vision/useClassification.mdx @@ -31,20 +31,6 @@ try { } ``` -
-Type definitions - -```typescript -interface UseClassification { - error: string | null; - isReady: boolean; - isGenerating: boolean; - forward: (input: string) => Promise<{ [category: string]: number }>; -} -``` - -
- ### Arguments **`modelSource`** diff --git a/docs/docs/computer-vision/useObjectDetection.mdx b/docs/docs/computer-vision/useObjectDetection.mdx index 2bcf387947..5de3da41cc 100644 --- a/docs/docs/computer-vision/useObjectDetection.mdx +++ b/docs/docs/computer-vision/useObjectDetection.mdx @@ -11,6 +11,7 @@ It is recommended to use models provided by us, which are available at our [Hugg ::: ## Reference + ```jsx import { useObjectDetection, SSDLITE_320_MOBILENET_V3_LARGE } from 'react-native-executorch'; @@ -45,42 +46,36 @@ interface Detection { label: keyof typeof CocoLabel; score: number; } - -interface UseObjectDetection { - error: string | null; - isReady: boolean; - isGenerating: boolean; - forward: (input: string) => Promise; -} ``` + ### Arguments `modelSource` -A string that specifies the path to the model file. You can download the model from our [HuggingFace repository](https://huggingface.co/software-mansion/react-native-executorch-ssdlite320-mobilenet-v3-large/tree/main). +A string that specifies the path to the model file. You can download the model from our [HuggingFace repository](https://huggingface.co/software-mansion/react-native-executorch-ssdlite320-mobilenet-v3-large/tree/main). For more information on that topic, you can check out the [Loading models](https://docs.swmansion.com/react-native-executorch/fundamentals/loading-models) page. ### Returns The hook returns an object with the following properties: - -| **Field** | **Type** | **Description** | -|-----------------------|---------------------------------------|------------------------------------------------------------------------------------------------------------------| -| `forward` | `(input: string) => Promise` | A function that accepts an image (url, b64) and returns an array of `Detection` objects. | -| `error` | string | null | Contains the error message if the model loading failed. 
| -| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. | -| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. | - +| Field | Type | Description | +| -------------- | ----------------------------------------- | ---------------------------------------------------------------------------------------- | +| `forward` | `(input: string) => Promise` | A function that accepts an image (url, b64) and returns an array of `Detection` objects. | +| `error` | string | null | Contains the error message if the model loading failed. | +| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. | +| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. | ## Running the model To run the model, you can use the `forward` method. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image. The function returns an array of `Detection` objects. Each object contains coordinates of the bounding box, the label of the detected object, and the confidence score. For more information, please refer to the reference or type definitions. ## Detection object + The detection object is specified as follows: + ```typescript interface Bbox { x1: number; @@ -95,14 +90,17 @@ interface Detection { score: number; } ``` + The `bbox` property contains information about the bounding box of detected objects. It is represented as two points: one at the bottom-left corner of the bounding box (`x1`, `y1`) and the other at the top-right corner (`x2`, `y2`). The `label` property contains the name of the detected object, which corresponds to one of the `CocoLabels`. The `score` represents the confidence score of the detected object. 
- - ## Example + ```tsx -import { useObjectDetection, SSDLITE_320_MOBILENET_V3_LARGE } from 'react-native-executorch'; +import { + useObjectDetection, + SSDLITE_320_MOBILENET_V3_LARGE, +} from 'react-native-executorch'; function App() { const ssdlite = useObjectDetection({ @@ -110,18 +108,19 @@ function App() { }); const runModel = async () => { - const detections = await ssdlite.forward("https://url-to-image.jpg"); + const detections = await ssdlite.forward('https://url-to-image.jpg'); + for (const detection of detections) { - console.log("Bounding box: ", detection.bbox); - console.log("Bounding label: ", detection.label); - console.log("Bounding score: ", detection.score); + console.log('Bounding box: ', detection.bbox); + console.log('Bounding label: ', detection.label); + console.log('Bounding score: ', detection.score); } - } + }; } ``` ## Supported models -| Model | Number of classes | Class list | -| --------------------------------------------------------------------------------------------------------------- | ----------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| [SSDLite320 MobileNetV3 Large](https://pytorch.org/vision/main/models/generated/torchvision.models.detection.ssdlite320_mobilenet_v3_large.html#torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights) | 91 | [COCO](https://github.com/software-mansion/react-native-executorch/blob/69802ee1ca161d9df00def1dabe014d36341cfa9/src/types/object_detection.ts#L14) | \ No newline at end of file +| Model | Number of classes | Class list | +| ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ----------------- | 
--------------------------------------------------------------------------------------------------------------------------------------------------- | +| [SSDLite320 MobileNetV3 Large](https://pytorch.org/vision/main/models/generated/torchvision.models.detection.ssdlite320_mobilenet_v3_large.html#torchvision.models.detection.SSDLite320_MobileNet_V3_Large_Weights) | 91 | [COCO](https://github.com/software-mansion/react-native-executorch/blob/69802ee1ca161d9df00def1dabe014d36341cfa9/src/types/object_detection.ts#L14) | diff --git a/docs/docs/computer-vision/useStyleTransfer.mdx b/docs/docs/computer-vision/useStyleTransfer.mdx index a1d373f2fc..c5a5e3e0d2 100644 --- a/docs/docs/computer-vision/useStyleTransfer.mdx +++ b/docs/docs/computer-vision/useStyleTransfer.mdx @@ -30,20 +30,6 @@ try { } ``` -
-Type definitions - -```typescript -interface UseStyleTransfer { - error: string | null; - isReady: boolean; - isGenerating: boolean; - forward: (input: string) => Promise; -} -``` - -
- ### Arguments **`modelSource`** From 6f51f6cf0383cbf9759b3239e393806511c3f59c Mon Sep 17 00:00:00 2001 From: jakmro Date: Fri, 24 Jan 2025 16:19:02 +0100 Subject: [PATCH 15/15] Add suggested changes v2 --- src/modules/BaseModule.ts | 36 +++++++++++++++++++ src/modules/computer_vision/BaseCVModule.ts | 20 ++--------- src/modules/general/ExecutorchModule.ts | 20 ++--------- .../natural_language_processing/LLMModule.ts | 21 ++++++----- 4 files changed, 50 insertions(+), 47 deletions(-) create mode 100644 src/modules/BaseModule.ts diff --git a/src/modules/BaseModule.ts b/src/modules/BaseModule.ts new file mode 100644 index 0000000000..a93591cee1 --- /dev/null +++ b/src/modules/BaseModule.ts @@ -0,0 +1,36 @@ +import { Image } from 'react-native'; +import { + _StyleTransferModule, + _ObjectDetectionModule, + _ClassificationModule, + _ETModule, +} from '../native/RnExecutorchModules'; +import { ResourceSource } from '../types/common'; +import { getError } from '../Error'; + +export class BaseModule { + static module: + | _StyleTransferModule + | _ObjectDetectionModule + | _ClassificationModule + | _ETModule; + + static async load(modelSource: ResourceSource) { + if (!modelSource) return; + + let path = + typeof modelSource === 'number' + ? 
Image.resolveAssetSource(modelSource).uri + : modelSource; + + try { + await this.module.loadModule(path); + } catch (e) { + throw new Error(getError(e)); + } + } + + static async forward(..._: any[]): Promise<any> { + throw new Error('The forward method is not implemented.'); + } +} diff --git a/src/modules/computer_vision/BaseCVModule.ts b/src/modules/computer_vision/BaseCVModule.ts index 43d76a549d..c61987d33f 100644 --- a/src/modules/computer_vision/BaseCVModule.ts +++ b/src/modules/computer_vision/BaseCVModule.ts @@ -1,4 +1,4 @@ -import { Image } from 'react-native'; +import { BaseModule } from '../BaseModule'; import { _StyleTransferModule, _ObjectDetectionModule, @@ -6,28 +6,12 @@ import { } from '../../native/RnExecutorchModules'; import { getError } from '../../Error'; -export class BaseCVModule { +export class BaseCVModule extends BaseModule { static module: | _StyleTransferModule | _ObjectDetectionModule | _ClassificationModule; - static async load(modelSource: string | number) { - if (!modelSource) return; - - let path = modelSource; - - if (typeof modelSource === 'number') { - path = Image.resolveAssetSource(modelSource).uri; - } - - try { - await this.module.loadModule(path); - } catch (e) { - throw new Error(getError(e)); - } - } - static async forward(input: string) { try { return await this.module.forward(input); diff --git a/src/modules/general/ExecutorchModule.ts b/src/modules/general/ExecutorchModule.ts index e6d5ef5c84..5d5990c407 100644 --- a/src/modules/general/ExecutorchModule.ts +++ b/src/modules/general/ExecutorchModule.ts @@ -1,27 +1,11 @@ -import { Image } from 'react-native'; +import { BaseModule } from '../BaseModule'; import { ETError, getError } from '../../Error'; import { _ETModule } from '../../native/RnExecutorchModules'; import { ETInput, getTypeIdentifier } from '../../types/common'; -export class ExecutorchModule { +export class ExecutorchModule extends BaseModule { static module = new _ETModule(); - static async load(modelSource:
string) { - if (!modelSource) return; - - let path = modelSource; - - if (typeof modelSource === 'number') { - path = Image.resolveAssetSource(modelSource).uri; - } - - try { - await this.module.loadModule(path); - } catch (e) { - throw new Error(getError(e)); - } - } - static async forward(input: ETInput, shape: number[]) { const inputType = getTypeIdentifier(input); if (inputType === -1) { diff --git a/src/modules/natural_language_processing/LLMModule.ts b/src/modules/natural_language_processing/LLMModule.ts index 0dafb7d01c..20914eddd9 100644 --- a/src/modules/natural_language_processing/LLMModule.ts +++ b/src/modules/natural_language_processing/LLMModule.ts @@ -14,20 +14,19 @@ export class LLMModule { contextWindowLength = DEFAULT_CONTEXT_WINDOW_LENGTH ) { try { - let modelUrl = modelSource; - let tokenizerUrl = tokenizerSource; + let modelUrl = + typeof modelSource === 'number' + ? Image.resolveAssetSource(modelSource).uri + : modelSource; - if (typeof modelSource === 'number') { - modelUrl = Image.resolveAssetSource(modelSource).uri; - } - - if (typeof tokenizerSource === 'number') { - tokenizerUrl = Image.resolveAssetSource(tokenizerSource).uri; - } + let tokenizerUrl = + typeof tokenizerSource === 'number' + ? Image.resolveAssetSource(tokenizerSource).uri + : tokenizerSource; await LLM.loadLLM( - modelUrl as string, - tokenizerUrl as string, + modelUrl, + tokenizerUrl, systemPrompt, contextWindowLength );