diff --git a/.cspell-wordlist.txt b/.cspell-wordlist.txt index 3742847485..03c32ae473 100644 --- a/.cspell-wordlist.txt +++ b/.cspell-wordlist.txt @@ -53,3 +53,5 @@ QINT FNUZ wordlist jitpack +coreml +mobilenetv diff --git a/apps/computer-vision/app/classification/index.tsx b/apps/computer-vision/app/classification/index.tsx index e473022540..8769ba9e3c 100644 --- a/apps/computer-vision/app/classification/index.tsx +++ b/apps/computer-vision/app/classification/index.tsx @@ -13,9 +13,7 @@ export default function ClassificationScreen() { ); const [imageUri, setImageUri] = useState(''); - const model = useClassification({ - modelSource: EFFICIENTNET_V2_S, - }); + const model = useClassification({ model: EFFICIENTNET_V2_S }); const { setGlobalGenerating } = useContext(GeneratingContext); useEffect(() => { setGlobalGenerating(model.isGenerating); diff --git a/apps/computer-vision/app/image_segmentation/index.tsx b/apps/computer-vision/app/image_segmentation/index.tsx index 4b15f11074..8dfc7594e5 100644 --- a/apps/computer-vision/app/image_segmentation/index.tsx +++ b/apps/computer-vision/app/image_segmentation/index.tsx @@ -61,9 +61,7 @@ const numberToColor: number[][] = [ ]; export default function ImageSegmentationScreen() { - const model = useImageSegmentation({ - modelSource: DEEPLAB_V3_RESNET50, - }); + const model = useImageSegmentation({ model: DEEPLAB_V3_RESNET50 }); const { setGlobalGenerating } = useContext(GeneratingContext); useEffect(() => { setGlobalGenerating(model.isGenerating); diff --git a/apps/computer-vision/app/object_detection/index.tsx b/apps/computer-vision/app/object_detection/index.tsx index 4838418575..3ce52c409f 100644 --- a/apps/computer-vision/app/object_detection/index.tsx +++ b/apps/computer-vision/app/object_detection/index.tsx @@ -20,9 +20,7 @@ export default function ObjectDetectionScreen() { height: number; }>(); - const ssdLite = useObjectDetection({ - modelSource: SSDLITE_320_MOBILENET_V3_LARGE, - }); + const ssdLite = useObjectDetection({ model: SSDLITE_320_MOBILENET_V3_LARGE }); const { setGlobalGenerating } = useContext(GeneratingContext); useEffect(() => { setGlobalGenerating(ssdLite.isGenerating); diff --git a/apps/computer-vision/app/ocr/index.tsx b/apps/computer-vision/app/ocr/index.tsx index 755f11b007..b2ba8d04dc 100644 --- a/apps/computer-vision/app/ocr/index.tsx +++ b/apps/computer-vision/app/ocr/index.tsx @@ -1,13 +1,7 @@ import Spinner from 'react-native-loading-spinner-overlay'; import { BottomBar } from '../../components/BottomBar'; import { getImage } from '../../utils'; -import { - DETECTOR_CRAFT_800, - RECOGNIZER_EN_CRNN_128, - RECOGNIZER_EN_CRNN_256, - RECOGNIZER_EN_CRNN_512, - useOCR, -} from 'react-native-executorch'; +import { useOCR, OCR_ENGLISH } from 'react-native-executorch'; import { View, StyleSheet, Image, Text, ScrollView } from 'react-native'; import ImageWithBboxes2 from '../../components/ImageWithOCRBboxes'; import React, { useContext, useEffect, useState } from 'react'; @@ -22,15 +16,7 @@ export default function OCRScreen() { height: number; }>(); - const model = useOCR({ - detectorSource: DETECTOR_CRAFT_800, - recognizerSources: { - recognizerLarge: RECOGNIZER_EN_CRNN_512, - recognizerMedium: RECOGNIZER_EN_CRNN_256, - recognizerSmall: RECOGNIZER_EN_CRNN_128, - }, - language: 'en', - }); + const model = useOCR({ model: OCR_ENGLISH }); const { setGlobalGenerating } = useContext(GeneratingContext); useEffect(() => { setGlobalGenerating(model.isGenerating); diff --git a/apps/computer-vision/app/ocr_vertical/index.tsx 
b/apps/computer-vision/app/ocr_vertical/index.tsx index 333d431074..040c709c63 100644 --- a/apps/computer-vision/app/ocr_vertical/index.tsx +++ b/apps/computer-vision/app/ocr_vertical/index.tsx @@ -1,13 +1,7 @@ import Spinner from 'react-native-loading-spinner-overlay'; import { BottomBar } from '../../components/BottomBar'; import { getImage } from '../../utils'; -import { - DETECTOR_CRAFT_1280, - DETECTOR_CRAFT_320, - RECOGNIZER_EN_CRNN_512, - RECOGNIZER_EN_CRNN_64, - useVerticalOCR, -} from 'react-native-executorch'; +import { useVerticalOCR, VERTICAL_OCR_ENGLISH } from 'react-native-executorch'; import { View, StyleSheet, Image, Text, ScrollView } from 'react-native'; import ImageWithBboxes2 from '../../components/ImageWithOCRBboxes'; import React, { useContext, useEffect, useState } from 'react'; @@ -22,15 +16,7 @@ export default function VerticalOCRScree() { height: number; }>(); const model = useVerticalOCR({ - detectorSources: { - detectorLarge: DETECTOR_CRAFT_1280, - detectorNarrow: DETECTOR_CRAFT_320, - }, - recognizerSources: { - recognizerLarge: RECOGNIZER_EN_CRNN_512, - recognizerSmall: RECOGNIZER_EN_CRNN_64, - }, - language: 'en', + model: VERTICAL_OCR_ENGLISH, independentCharacters: true, }); const { setGlobalGenerating } = useContext(GeneratingContext); diff --git a/apps/computer-vision/app/style_transfer/index.tsx b/apps/computer-vision/app/style_transfer/index.tsx index 91bf71238c..0075d1a532 100644 --- a/apps/computer-vision/app/style_transfer/index.tsx +++ b/apps/computer-vision/app/style_transfer/index.tsx @@ -11,9 +11,7 @@ import { GeneratingContext } from '../../context'; import ScreenWrapper from '../../ScreenWrapper'; export default function StyleTransferScreen() { - const model = useStyleTransfer({ - modelSource: STYLE_TRANSFER_CANDY, - }); + const model = useStyleTransfer({ model: STYLE_TRANSFER_CANDY }); const { setGlobalGenerating } = useContext(GeneratingContext); useEffect(() => { setGlobalGenerating(model.isGenerating); diff --git a/apps/llm/app/llm/index.tsx b/apps/llm/app/llm/index.tsx index f59e9e6894..f49bab99fe 100644 --- a/apps/llm/app/llm/index.tsx +++ b/apps/llm/app/llm/index.tsx @@ -12,12 +12,7 @@ import { } from 'react-native'; import SendIcon from '../../assets/icons/send_icon.svg'; import Spinner from 'react-native-loading-spinner-overlay'; -import { - LLAMA3_2_1B_QLORA, - LLAMA3_2_TOKENIZER, - LLAMA3_2_TOKENIZER_CONFIG, - useLLM, -} from 'react-native-executorch'; +import { useLLM, LLAMA3_2_1B_QLORA } from 'react-native-executorch'; import PauseIcon from '../../assets/icons/pause_icon.svg'; import ColorPalette from '../../colors'; import Messages from '../../components/Messages'; @@ -35,11 +30,7 @@ function LLMScreen() { const textInputRef = useRef(null); const { setGlobalGenerating } = useContext(GeneratingContext); - const llm = useLLM({ - modelSource: LLAMA3_2_1B_QLORA, - tokenizerSource: LLAMA3_2_TOKENIZER, - tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, - }); + const llm = useLLM({ model: LLAMA3_2_1B_QLORA }); useEffect(() => { if (llm.error) { diff --git a/apps/llm/app/llm_structured_output/index.tsx b/apps/llm/app/llm_structured_output/index.tsx index dd6a7275be..d850a61986 100644 --- a/apps/llm/app/llm_structured_output/index.tsx +++ b/apps/llm/app/llm_structured_output/index.tsx @@ -13,8 +13,6 @@ import { import SendIcon from '../../assets/icons/send_icon.svg'; import Spinner from 'react-native-loading-spinner-overlay'; import { - QWEN3_TOKENIZER, - QWEN3_TOKENIZER_CONFIG, useLLM, fixAndValidateStructuredOutput, 
getStructuredOutputPrompt, @@ -75,12 +73,7 @@ function LLMScreen() { const textInputRef = useRef(null); const { setGlobalGenerating } = useContext(GeneratingContext); - const llm = useLLM({ - // try out 4B model it this one struggles with following structured output - modelSource: QWEN3_1_7B_QUANTIZED, - tokenizerSource: QWEN3_TOKENIZER, - tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, - }); + const llm = useLLM({ model: QWEN3_1_7B_QUANTIZED }); // try out 4B model if 1.7B struggles with following structured output useEffect(() => { setGlobalGenerating(llm.isGenerating); @@ -89,10 +82,6 @@ function LLMScreen() { const { configure } = llm; useEffect(() => { const formattingInstructions = getStructuredOutputPrompt(responseSchema); - // const formattingInstructionsWithZod = getStructuredOutputPrompt( - // responseSchemaWithZod - // ); - const prompt = `Your goal is to parse user's messages and return them in JSON format. Don't respond to user. Simply return JSON with user's question parsed. \n${formattingInstructions}\n /no_think`; configure({ diff --git a/apps/llm/app/llm_tool_calling/index.tsx b/apps/llm/app/llm_tool_calling/index.tsx index 5d5995f882..df177f7c67 100644 --- a/apps/llm/app/llm_tool_calling/index.tsx +++ b/apps/llm/app/llm_tool_calling/index.tsx @@ -14,11 +14,9 @@ import SWMIcon from '../../assets/icons/swm_icon.svg'; import SendIcon from '../../assets/icons/send_icon.svg'; import Spinner from 'react-native-loading-spinner-overlay'; import { - HAMMER2_1_1_5B, - HAMMER2_1_TOKENIZER, - HAMMER2_1_TOKENIZER_CONFIG, useLLM, DEFAULT_SYSTEM_PROMPT, + HAMMER2_1_1_5B_QUANTIZED, } from 'react-native-executorch'; import PauseIcon from '../../assets/icons/pause_icon.svg'; import ColorPalette from '../../colors'; @@ -41,11 +39,7 @@ function LLMToolCallingScreen() { const textInputRef = useRef(null); const { setGlobalGenerating } = useContext(GeneratingContext); - const llm = useLLM({ - modelSource: HAMMER2_1_1_5B, - tokenizerSource: HAMMER2_1_TOKENIZER, - tokenizerConfigSource: HAMMER2_1_TOKENIZER_CONFIG, - }); + const llm = useLLM({ model: HAMMER2_1_1_5B_QUANTIZED }); useEffect(() => { setGlobalGenerating(llm.isGenerating); diff --git a/apps/llm/app/voice_chat/index.tsx b/apps/llm/app/voice_chat/index.tsx index c4baaa9870..bab7dfb115 100644 --- a/apps/llm/app/voice_chat/index.tsx +++ b/apps/llm/app/voice_chat/index.tsx @@ -16,8 +16,7 @@ import { useSpeechToText, useLLM, QWEN3_0_6B_QUANTIZED, - QWEN3_TOKENIZER, - QWEN3_TOKENIZER_CONFIG, + MOONSHINE_TINY, } from 'react-native-executorch'; import PauseIcon from '../../assets/icons/pause_icon.svg'; import MicIcon from '../../assets/icons/mic_icon.svg'; @@ -68,13 +67,9 @@ function VoiceChatScreen() { const messageRecorded = useRef(false); const { setGlobalGenerating } = useContext(GeneratingContext); - const llm = useLLM({ - modelSource: QWEN3_0_6B_QUANTIZED, - tokenizerSource: QWEN3_TOKENIZER, - tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, - }); + const llm = useLLM({ model: QWEN3_0_6B_QUANTIZED }); const speechToText = useSpeechToText({ - modelName: 'moonshine', + model: MOONSHINE_TINY, windowSize: 3, overlapSeconds: 1.2, }); diff --git a/apps/speech-to-text/screens/SpeechToTextScreen.tsx b/apps/speech-to-text/screens/SpeechToTextScreen.tsx index 188c8cc50c..0d9922e5dd 100644 --- a/apps/speech-to-text/screens/SpeechToTextScreen.tsx +++ b/apps/speech-to-text/screens/SpeechToTextScreen.tsx @@ -1,4 +1,4 @@ -import { useSpeechToText } from 'react-native-executorch'; +import { MOONSHINE_TINY, useSpeechToText } from 'react-native-executorch'; 
import React from 'react'; import { Text, @@ -54,7 +54,7 @@ export const SpeechToTextScreen = () => { sequence, error, transcribe, - } = useSpeechToText({ modelName: 'moonshine', streamingConfig: 'balanced' }); + } = useSpeechToText({ model: MOONSHINE_TINY }); const loadAudio = async (url: string) => { const audioContext = new AudioContext({ sampleRate: 16e3 }); diff --git a/apps/text-embeddings/app/clip-embeddings/index.tsx b/apps/text-embeddings/app/clip-embeddings/index.tsx index 924d1f1a13..7a53a77dfb 100644 --- a/apps/text-embeddings/app/clip-embeddings/index.tsx +++ b/apps/text-embeddings/app/clip-embeddings/index.tsx @@ -28,8 +28,8 @@ export default function ClipEmbeddingsScreenWrapper() { } function ClipEmbeddingsScreen() { - const textModel = useTextEmbeddings(CLIP_VIT_BASE_PATCH32_TEXT); - const imageModel = useImageEmbeddings(CLIP_VIT_BASE_PATCH32_IMAGE); + const textModel = useTextEmbeddings({ model: CLIP_VIT_BASE_PATCH32_TEXT }); + const imageModel = useImageEmbeddings({ model: CLIP_VIT_BASE_PATCH32_IMAGE }); const [inputSentence, setInputSentence] = useState(''); const [sentencesWithEmbeddings, setSentencesWithEmbeddings] = useState< diff --git a/apps/text-embeddings/app/text-embeddings/index.tsx b/apps/text-embeddings/app/text-embeddings/index.tsx index bc742f1dfa..3e43004dde 100644 --- a/apps/text-embeddings/app/text-embeddings/index.tsx +++ b/apps/text-embeddings/app/text-embeddings/index.tsx @@ -11,11 +11,7 @@ import { Platform, } from 'react-native'; import { Ionicons } from '@expo/vector-icons'; -import { - useTextEmbeddings, - ALL_MINILM_L6_V2, - ALL_MINILM_L6_V2_TOKENIZER, -} from 'react-native-executorch'; +import { useTextEmbeddings, ALL_MINILM_L6_V2 } from 'react-native-executorch'; import { useIsFocused } from '@react-navigation/native'; import { dotProduct } from '../../utils/math'; @@ -26,10 +22,7 @@ export default function TextEmbeddingsScreenWrapper() { } function TextEmbeddingsScreen() { - const model = useTextEmbeddings({ - modelSource: ALL_MINILM_L6_V2, - tokenizerSource: ALL_MINILM_L6_V2_TOKENIZER, - }); + const model = useTextEmbeddings({ model: ALL_MINILM_L6_V2 }); const [inputSentence, setInputSentence] = useState(''); const [sentencesWithEmbeddings, setSentencesWithEmbeddings] = useState< diff --git a/docs/docs/02-hooks/01-natural-language-processing/useLLM.md b/docs/docs/02-hooks/01-natural-language-processing/useLLM.md index a24dcf04ea..58ac30a4ee 100644 --- a/docs/docs/02-hooks/01-natural-language-processing/useLLM.md +++ b/docs/docs/02-hooks/01-natural-language-processing/useLLM.md @@ -39,18 +39,9 @@ Given computational constraints, our architecture is designed to support only on In order to load a model into the app, you need to run the following code: ```typescript -import { - useLLM, - LLAMA3_2_1B, - LLAMA3_2_TOKENIZER, - LLAMA3_2_TOKENIZER_CONFIG, -} from 'react-native-executorch'; - -const llm = useLLM({ - modelSource: LLAMA3_2_1B, - tokenizerSource: LLAMA3_2_TOKENIZER, - tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, -}); +import { useLLM, LLAMA3_2_1B } from 'react-native-executorch'; + +const llm = useLLM({ model: LLAMA3_2_1B }); ```
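For self-hosted binaries, the `model` argument can presumably be assembled by hand as well — per the Arguments section below, it is simply an object of three `ResourceSource` fields. A minimal sketch, with placeholder URLs you would replace with your own:

```typescript
import { useLLM } from 'react-native-executorch';

// Hypothetical self-hosted sources -- substitute your own URLs or require(...) assets.
const SELF_HOSTED_LLAMA = {
  modelSource: 'https://example.com/llama3_2_1b/model.pte',
  tokenizerSource: 'https://example.com/llama3_2_1b/tokenizer.json',
  tokenizerConfigSource: 'https://example.com/llama3_2_1b/tokenizer_config.json',
};

const llm = useLLM({ model: SELF_HOSTED_LLAMA });
```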
@@ -59,14 +50,18 @@ The code snippet above fetches the model from the specified URL, loads it into m ### Arguments -**`modelSource`** - `ResourceSource` that specifies the location of the model binary. For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) section. +**`model`** - Object containing the model source, tokenizer source, and tokenizer config source. + +- **`modelSource`** - `ResourceSource` that specifies the location of the model binary. -**`tokenizerSource`** - `ResourceSource` pointing to the JSON file which contains the tokenizer. +- **`tokenizerSource`** - `ResourceSource` pointing to the JSON file which contains the tokenizer. -**`tokenizerConfigSource`** - `ResourceSource` pointing to the JSON file which contains the tokenizer config. +- **`tokenizerConfigSource`** - `ResourceSource` pointing to the JSON file which contains the tokenizer config. **`preventLoad?`** - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. +For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. + ### Returns | Field | Type | Description | @@ -89,14 +84,14 @@ The code snippet above fetches the model from the specified URL, loads it into m ```typescript const useLLM: ({ - modelSource, - tokenizerSource, - tokenizerConfigSource, - preventLoad = false, + model, + preventLoad, }: { - modelSource: ResourceSource; - tokenizerSource: ResourceSource; - tokenizerConfigSource: ResourceSource; + model: { + modelSource: ResourceSource; + tokenizerSource: ResourceSource; + tokenizerConfigSource: ResourceSource; + }; preventLoad?: boolean; }) => LLMType; @@ -167,11 +162,7 @@ You can use functions returned from this hooks in two manners: To perform chat completion you can use the `generate` function. There is no return value. Instead, the `response` value is updated with each token. ```tsx -const llm = useLLM({ - modelSource: LLAMA3_2_1B, - tokenizerSource: LLAMA3_2_TOKENIZER, - tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, -}); +const llm = useLLM({ model: LLAMA3_2_1B }); const handleGenerate = () => { const chat = [ @@ -226,11 +217,7 @@ const TOOL_DEFINITIONS: LLMTool[] = [ }, ]; -const llm = useLLM({ - modelSource: HAMMER2_1_1_5B, - tokenizerSource: HAMMER2_1_1_5B_TOKENIZER, - tokenizerConfigSource: HAMMER2_1_1_5B_TOKENIZER_CONFIG, -}); +const llm = useLLM({ model: HAMMER2_1_1_5B }); const handleGenerate = () => { const chat = [ @@ -289,11 +276,7 @@ To configure model (i.e. change system prompt, load initial conversation history In order to send a message to the model, one can use the following code: ```tsx -const llm = useLLM({ - modelSource: LLAMA3_2_1B, - tokenizerSource: LLAMA3_2_TOKENIZER, - tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, -}); +const llm = useLLM({ model: LLAMA3_2_1B }); const send = () => { const message = 'Hi, who are you?'; @@ -338,11 +321,7 @@ const TOOL_DEFINITIONS: LLMTool[] = [ }, ]; -const llm = useLLM({ - modelSource: HAMMER2_1_1_5B, - tokenizerSource: HAMMER2_1_1_5B_TOKENIZER, - tokenizerConfigSource: HAMMER2_1_1_5B_TOKENIZER_CONFIG, -}); +const llm = useLLM({ model: HAMMER2_1_1_5B }); useEffect(() => { llm.configure({ @@ -418,11 +397,7 @@ const responseSchemaWithZod = z.object({ currency: z.optional(z.string().meta({ description: 'Currency of offer.' 
})), }); -const llm = useLLM({ - modelSource: QWEN3_4B_QUANTIZED, - tokenizerSource: QWEN3_TOKENIZER, - tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, -}); +const llm = useLLM({ model: QWEN3_4B_QUANTIZED }); useEffect(() => { const formattingInstructions = getStructuredOutputPrompt(responseSchema); diff --git a/docs/docs/02-hooks/01-natural-language-processing/useSpeechToText.md b/docs/docs/02-hooks/01-natural-language-processing/useSpeechToText.md index 4115408fe2..d6a6914065 100644 --- a/docs/docs/02-hooks/01-natural-language-processing/useSpeechToText.md +++ b/docs/docs/02-hooks/01-natural-language-processing/useSpeechToText.md @@ -33,13 +33,11 @@ It is recommended to use models provided by us, which are available at our [Hugg You can obtain waveform from audio in any way most suitable to you, however in the snippet below we utilize `react-native-audio-api` library to process a mp3 file. ```typescript -import { useSpeechToText } from 'react-native-executorch'; +import { useSpeechToText, MOONSHINE_TINY } from 'react-native-executorch'; import { AudioContext } from 'react-native-audio-api'; import * as FileSystem from 'expo-file-system'; -const { transcribe, error } = useSpeechToText({ - modelName: 'moonshine', -}); +const { transcribe, error } = useSpeechToText({ model: MOONSHINE_TINY }); const loadAudio = async (url: string) => { const audioContext = new AudioContext({ sampleRate: 16e3 }); @@ -68,41 +66,38 @@ Given that STT models can process audio no longer than 30 seconds, there is a ne ### Arguments -**`modelName`** -A literal of `"moonshine" | "whisper" | "whisperMultilingual` which serves as an identifier for which model should be used. +**`model`** - Object containing the model name, encoder source, decoder source, and tokenizer source. -**`encoderSource?`** -A string that specifies the location of a .pte file for the encoder. For further information on passing model sources, check out [Loading Models](../../01-fundamentals/02-loading-models.md). Defaults to [constants](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/constants/modelUrls.ts) for given model. +- **`modelName`** - An enum `AvailableModels` value that serves as an identifier for which model should be used. -**`decoderSource?`** -Analogous to the encoderSource, this takes in a string which is a source for the decoder part of the model. Defaults to [constants](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/constants/modelUrls.ts) for given model. +- **`encoderSource?`** - A string that specifies the location of a .pte file for the encoder. Defaults to [constants](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/constants/modelUrls.ts) for given model. -**`tokenizerSource?`** -A string that specifies the location to the tokenizer for the model. This works just as the encoder and decoder do. Defaults to [constants](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/constants/modelUrls.ts) for given model. +- **`decoderSource?`** - Analogous to the encoderSource, this takes in a string which is a source for the decoder part of the model. Defaults to [constants](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/constants/modelUrls.ts) for given model. 
-**`overlapSeconds?`** -Specifies the length of overlap between consecutive audio chunks (expressed in seconds). Overrides `streamingConfig` argument. +- **`tokenizerSource?`** - A string that specifies the location to the tokenizer for the model. This works just as the encoder and decoder do. Defaults to [constants](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/constants/modelUrls.ts) for given model. -**`windowSize?`** -Specifies the size of each audio chunk (expressed in seconds). Overrides `streamingConfig` argument. +**`overlapSeconds?`** - Specifies the length of overlap between consecutive audio chunks (expressed in seconds). Overrides `streamingConfig` argument. -**`streamingConfig?`** -Specifies config for both `overlapSeconds` and `windowSize` values. Three options are available: `fast`, `balanced` and `quality`. We discourage using `fast` config with `Whisper` model which while has the lowest latency to first token has the slowest overall speed. +**`windowSize?`** - Specifies the size of each audio chunk (expressed in seconds). Overrides `streamingConfig` argument. + +**`streamingConfig?`** - Specifies config for both `overlapSeconds` and `windowSize` values. Three options are available: `fast`, `balanced` and `quality`. We discourage using `fast` config with `Whisper` model which while has the lowest latency to first token has the slowest overall speed. **`preventLoad?`** - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. +For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. + ### Returns -| Field | Type | Description | -| --------------------- | ------------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `transcribe` | `(waveform: number[], audioLanguage?: SpeechToTextLanguage) => Promise` | Starts a transcription process for a given input array, which should be a waveform at 16kHz. Resolves a promise with the output transcription when the model is finished. For multilingual models, you have to specify the audioLanguage flag, which is the language of the spoken language in the audio. Returns error when called when module is in use (i.e. in process of `streamingTranscribe` action) | -| `streamingTranscribe` | `(streamingAction: STREAMING_ACTION, waveform?: number[], audioLanguage?: SpeechToTextLanguage) => Promise` | This allows for running transcription process on-line, which means where the whole audio is not known beforehand i.e. when transcribing from a live microphone feed. `streamingAction` defines the type of package sent to the model:
  • `START` - initializes the process, allows for optional `waveform` data
  • `DATA` - this package should contain consecutive audio data chunks sampled in 16k Hz
  • `STOP` - the last data chunk for this transcription, ends the transcription process and flushes internal buffers
  • Each call returns most recent transcription. Returns error when called when module is in use (i.e. processing `transcribe` call) | -| `error` | Error | undefined | Contains the error message if the model failed to load. | -| `sequence` | string | This property is updated with each generated token. If you're looking to obtain tokens as they're generated, you should use this property. | -| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. | -| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. | -| `configureStreaming` | (overlapSeconds?: number, windowSize?: number, streamingConfig?: 'fast' | 'balanced' | 'quality') | Configures options for the streaming algorithm:
    • `overlapSeconds` determines how much adjacent audio chunks overlap (increasing it slows down transcription, decreases probability of weird wording at the chunks intersection, setting it larger than 3 seconds generally is discouraged),
    • `windowSize` describes size of the audio chunks (increasing it speeds up the end to end transcription time, but increases latency for the first token to be returned),
    • `streamingConfig` predefined configs for `windowSize` and `overlapSeconds` values.
    Keep `windowSize + 2 * overlapSeconds <= 30`. | -| `downloadProgress` | `number` | Tracks the progress of the model download process. | +| Field | Type | Description | +| --------------------- | ------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `transcribe` | `(waveform: number[], audioLanguage?: SpeechToTextLanguage) => Promise` | Starts a transcription process for a given input array, which should be a waveform at 16kHz. Resolves a promise with the output transcription when the model is finished. For multilingual models, you have to specify the `audioLanguage` flag, which is the language spoken in the audio. Returns an error if called while the module is in use (i.e. in the process of a `streamingTranscribe` action). | +| `streamingTranscribe` | `(streamingAction: STREAMING_ACTION, waveform?: number[], audioLanguage?: SpeechToTextLanguage) => Promise` | Allows running the transcription process online, i.e. when the whole audio is not known beforehand, such as when transcribing from a live microphone feed. `streamingAction` defines the type of package sent to the model:
    • `START` - initializes the process, allows for optional `waveform` data
    • `DATA` - this package should contain consecutive audio data chunks sampled in 16k Hz
    • `STOP` - the last data chunk for this transcription, ends the transcription process and flushes internal buffers
    Each call returns the most recent transcription. Returns an error if called while the module is in use (i.e. processing a `transcribe` call). | +| `error` | Error | undefined | Contains the error message if the model failed to load. | +| `sequence` | string | This property is updated with each generated token. If you're looking to obtain tokens as they're generated, you should use this property. | +| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. | +| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. | +| `configureStreaming` | (overlapSeconds?: number, windowSize?: number, streamingConfig?: 'fast' | 'balanced' | 'quality') | Configures options for the streaming algorithm:
    • `overlapSeconds` determines how much adjacent audio chunks overlap (increasing it slows down transcription, decreases probability of weird wording at the chunks intersection, setting it larger than 3 seconds generally is discouraged),
    • `windowSize` describes size of the audio chunks (increasing it speeds up the end to end transcription time, but increases latency for the first token to be returned),
    • `streamingConfig` predefined configs for `windowSize` and `overlapSeconds` values.
    Keep `windowSize + 2 * overlapSeconds <= 30`. | +| `downloadProgress` | `number` | Tracks the progress of the model download process. |
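The constraint in the last row is easy to violate when tuning by hand. A short sketch of how the `configureStreaming` signature documented above might be used — the numbers are illustrative assumptions, not recommended settings, and `configureStreaming` is the function returned by the hook:

```typescript
import { useSpeechToText, MOONSHINE_TINY } from 'react-native-executorch';

const { configureStreaming } = useSpeechToText({ model: MOONSHINE_TINY });

// Illustrative tuning: shorter windows cut first-token latency,
// more overlap smooths wording at chunk boundaries.
const overlapSeconds = 1.5;
const windowSize = 6;

// Documented invariant: windowSize + 2 * overlapSeconds <= 30.
if (windowSize + 2 * overlapSeconds <= 30) {
  configureStreaming(overlapSeconds, windowSize);
} else {
  // Fall back to a predefined preset instead of hand-tuned values.
  configureStreaming(undefined, undefined, 'balanced');
}
```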
    Type definitions @@ -216,13 +211,13 @@ await model.transcribe(mySpanishAudio, SpeechToTextLanguage.Spanish); ```tsx import { Button, Text, View } from 'react-native'; -import { useSpeechToText } from 'react-native-executorch'; +import { useSpeechToText, WHISPER_TINY } from 'react-native-executorch'; import * as FileSystem from 'expo-file-system'; import { AudioContext } from 'react-native-audio-api'; function App() { const { transcribe, sequence, error } = useSpeechToText({ - modelName: 'whisper', + model: WHISPER_TINY, }); const loadAudio = async (url: string) => { @@ -255,7 +250,11 @@ function App() { ### Live data (microphone) transcription ```tsx -import { STREAMING_ACTION, useSpeechToText } from 'react-native-executorch'; +import { + STREAMING_ACTION, + useSpeechToText, + MOONSHINE_TINY, +} from 'react-native-executorch'; import LiveAudioStream from 'react-native-live-audio-stream'; import { useState } from 'react'; import { Buffer } from 'buffer'; @@ -291,7 +290,7 @@ const float32ArrayFromPCMBinaryBuffer = (b64EncodedBuffer: string) => { function App() { const [isRecording, setIsRecording] = useState(false); const speechToText = useSpeechToText({ - modelName: 'moonshine', + model: MOONSHINE_TINY, windowSize: 3, overlapSeconds: 1.2, }); diff --git a/docs/docs/02-hooks/01-natural-language-processing/useTextEmbeddings.md b/docs/docs/02-hooks/01-natural-language-processing/useTextEmbeddings.md index a8dfcbdc2d..c40d19e94d 100644 --- a/docs/docs/02-hooks/01-natural-language-processing/useTextEmbeddings.md +++ b/docs/docs/02-hooks/01-natural-language-processing/useTextEmbeddings.md @@ -24,16 +24,9 @@ It is recommended to use models provided by us, which are available at our [Hugg ## Reference ```typescript -import { - useTextEmbeddings, - ALL_MINILM_L6_V2, - ALL_MINILM_L6_V2_TOKENIZER, -} from 'react-native-executorch'; +import { useTextEmbeddings, ALL_MINILM_L6_V2 } from 'react-native-executorch'; -const model = useTextEmbeddings({ - modelSource: ALL_MINILM_L6_V2, - tokenizerSource: ALL_MINILM_L6_V2_TOKENIZER, -}); +const model = useTextEmbeddings({ model: ALL_MINILM_L6_V2 }); try { const embedding = await model.forward('Hello World!'); @@ -44,14 +37,16 @@ try { ### Arguments -**`modelSource`** -A string that specifies the location of the model binary. For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. +**`model`** - Object containing the model source and tokenizer source. -**`tokenizerSource`** -A string that specifies the location of the tokenizer JSON file. +- **`modelSource`** - A string that specifies the location of the model binary. + +- **`tokenizerSource`** - A string that specifies the location of the tokenizer JSON file. **`preventLoad?`** - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. +For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. + ### Returns | Field | Type | Description | @@ -69,11 +64,7 @@ To run the model, you can use the `forward` method. 
It accepts one argument, whi ## Example ```typescript -import { - useTextEmbeddings, - ALL_MINILM_L6_V2, - ALL_MINILM_L6_V2_TOKENIZER, -} from 'react-native-executorch'; +import { useTextEmbeddings, ALL_MINILM_L6_V2 } from 'react-native-executorch'; const dotProduct = (a: number[], b: number[]) => a.reduce((sum, val, i) => sum + val * b[i], 0); @@ -86,10 +77,7 @@ const cosineSimilarity = (a: number[], b: number[]) => { }; function App() { - const model = useTextEmbeddings({ - modelSource: ALL_MINILM_L6_V2, - tokenizerSource: ALL_MINILM_L6_V2_TOKENIZER, - }); + const model = useTextEmbeddings({ model: ALL_MINILM_L6_V2 }); // ... diff --git a/docs/docs/02-hooks/01-natural-language-processing/useTokenizer.md b/docs/docs/02-hooks/01-natural-language-processing/useTokenizer.md index a5c6cd7549..23ad40803e 100644 --- a/docs/docs/02-hooks/01-natural-language-processing/useTokenizer.md +++ b/docs/docs/02-hooks/01-natural-language-processing/useTokenizer.md @@ -25,14 +25,9 @@ We are using [Hugging Face Tokenizers](https://huggingface.co/docs/tokenizers/in ## Reference ```typescript -import { - useTokenizer, - ALL_MINILM_L6_V2_TOKENIZER, -} from 'react-native-executorch'; +import { useTokenizer, ALL_MINILM_L6_V2 } from 'react-native-executorch'; -const tokenizer = useTokenizer({ - tokenizerSource: ALL_MINILM_L6_V2_TOKENIZER, -}); +const tokenizer = useTokenizer({ tokenizer: ALL_MINILM_L6_V2 }); const text = 'Hello, world!'; @@ -51,10 +46,14 @@ try { ## Arguments -**`tokenizerSource`** - A string that specifies the path or URI of the tokenizer JSON file. +**`tokenizer`** - Object containing the tokenizer source. + +- **`tokenizerSource`** - A string that specifies the location of the tokenizer JSON file. **`preventLoad?`** - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. +For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. + ### Returns | Field | Type | Description | @@ -72,15 +71,10 @@ try { ## Example ```typescript -import { - useTokenizer, - ALL_MINILM_L6_V2_TOKENIZER, -} from 'react-native-executorch'; +import { useTokenizer, ALL_MINILM_L6_V2 } from 'react-native-executorch'; function App() { - const tokenizer = useTokenizer({ - tokenizerSource: ALL_MINILM_L6_V2_TOKENIZER, - }); + const tokenizer = useTokenizer({ tokenizer: ALL_MINILM_L6_V2 }); // ... diff --git a/docs/docs/02-hooks/02-computer-vision/useClassification.md b/docs/docs/02-hooks/02-computer-vision/useClassification.md index 30f8c6390a..cdb96f7327 100644 --- a/docs/docs/02-hooks/02-computer-vision/useClassification.md +++ b/docs/docs/02-hooks/02-computer-vision/useClassification.md @@ -17,9 +17,7 @@ It is recommended to use models provided by us, which are available at our [Hugg ```typescript import { useClassification, EFFICIENTNET_V2_S } from 'react-native-executorch'; -const model = useClassification({ - modelSource: EFFICIENTNET_V2_S, -}); +const model = useClassification({ model: EFFICIENTNET_V2_S }); const imageUri = 'file::///Users/.../cute_puppy.png'; @@ -32,11 +30,14 @@ try { ### Arguments -**`modelSource`** -A string that specifies the location of the model binary. For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. +**`model`** - Object containing the model source. + +- **`modelSource`** - A string that specifies the location of the model binary. 
**`preventLoad?`** - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. +For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. + ### Returns | Field | Type | Description | @@ -61,9 +62,7 @@ Images from external sources are stored in your application's temporary director import { useClassification, EFFICIENTNET_V2_S } from 'react-native-executorch'; function App() { - const model = useClassification({ - modulePath: EFFICIENTNET_V2_S, - }); + const model = useClassification({ model: EFFICIENTNET_V2_S }); // ... const imageUri = 'file:///Users/.../cute_puppy.png'; diff --git a/docs/docs/02-hooks/02-computer-vision/useImageEmbeddings.md b/docs/docs/02-hooks/02-computer-vision/useImageEmbeddings.md index f5758d60e3..96bf974c63 100644 --- a/docs/docs/02-hooks/02-computer-vision/useImageEmbeddings.md +++ b/docs/docs/02-hooks/02-computer-vision/useImageEmbeddings.md @@ -30,7 +30,7 @@ import { CLIP_VIT_BASE_PATCH32_IMAGE, } from 'react-native-executorch'; -const model = useImageEmbeddings(CLIP_VIT_BASE_PATCH32_IMAGE); +const model = useImageEmbeddings({ model: CLIP_VIT_BASE_PATCH32_IMAGE }); try { const imageEmbedding = await model.forward('https://url-to-image.jpg'); @@ -41,11 +41,14 @@ try { ### Arguments -**`modelSource`** -A string that specifies the location of the model binary. For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. +**`model`** - Object containing the model source. + +- **`modelSource`** - A string that specifies the location of the model binary. **`preventLoad?`** - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. +For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. + ### Returns | Field | Type | Description | diff --git a/docs/docs/02-hooks/02-computer-vision/useImageSegmentation.md b/docs/docs/02-hooks/02-computer-vision/useImageSegmentation.md index 8f43c87d69..fa91ef37fc 100644 --- a/docs/docs/02-hooks/02-computer-vision/useImageSegmentation.md +++ b/docs/docs/02-hooks/02-computer-vision/useImageSegmentation.md @@ -16,9 +16,7 @@ import { DEEPLAB_V3_RESNET50, } from 'react-native-executorch'; -const model = useImageSegmentation({ - modelSource: DEEPLAB_V3_RESNET50, -}); +const model = useImageSegmentation({ model: DEEPLAB_V3_RESNET50 }); const imageUri = 'file::///Users/.../cute_cat.png'; @@ -31,11 +29,14 @@ try { ### Arguments -**`modelSource`** -A string that specifies the location of the model binary. For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. +**`model`** - Object containing the model source. + +- **`modelSource`** - A string that specifies the location of the model binary. **`preventLoad?`** - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. +For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. + ### Returns | Field | Type | Description | @@ -67,9 +68,7 @@ Setting `resize` to true will make `forward` slower. ```typescript function App() { - const model = useImageSegmentation({ - modelSource: DEEPLAB_V3_RESNET50, - }); + const model = useImageSegmentation({ model: DEEPLAB_V3_RESNET50 }); // ... 
const imageUri = 'file::///Users/.../cute_cat.png'; diff --git a/docs/docs/02-hooks/02-computer-vision/useOCR.md b/docs/docs/02-hooks/02-computer-vision/useOCR.md index 93017354fb..ad25998d76 100644 --- a/docs/docs/02-hooks/02-computer-vision/useOCR.md +++ b/docs/docs/02-hooks/02-computer-vision/useOCR.md @@ -11,24 +11,10 @@ It is recommended to use models provided by us, which are available at our [Hugg ## Reference ```tsx -import { - useOCR, - DETECTOR_CRAFT_800, - RECOGNIZER_EN_CRNN_512, - RECOGNIZER_EN_CRNN_256, - RECOGNIZER_EN_CRNN_128, -} from 'react-native-executorch'; +import { useOCR, OCR_ENGLISH } from 'react-native-executorch'; function App() { - const model = useOCR({ - detectorSource: DETECTOR_CRAFT_800, - recognizerSources: { - recognizerLarge: RECOGNIZER_EN_CRNN_512, - recognizerMedium: RECOGNIZER_EN_CRNN_256, - recognizerSmall: RECOGNIZER_EN_CRNN_128, - }, - language: 'en', - }); + const model = useOCR({ model: OCR_ENGLISH }); // ... for (const ocrDetection of await model.forward('https://url-to-image.jpg')) { @@ -132,20 +118,18 @@ interface OCRDetection { ### Arguments -**`detectorSource`** - A string that specifies the location of the detector binary. For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) section. +**`model`** - Object containing the detector source, recognizer sources, and language. -**`recognizerSources`** - An object that specifies locations of the recognizers binary files. Each recognizer is composed of three models tailored to process images of varying widths. - -- `recognizerLarge` - A string that specifies the location of the recognizer binary file which accepts input images with a width of 512 pixels. -- `recognizerMedium` - A string that specifies the location of the recognizer binary file which accepts input images with a width of 256 pixels. -- `recognizerSmall` - A string that specifies the location of the recognizer binary file which accepts input images with a width of 128 pixels. - -For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) section. - -**`language`** - A parameter that specifies the language of the text to be recognized by the OCR. +- **`detectorSource`** - A string that specifies the location of the detector binary. +- **`recognizerLarge`** - A string that specifies the location of the recognizer binary file which accepts input images with a width of 512 pixels. +- **`recognizerMedium`** - A string that specifies the location of the recognizer binary file which accepts input images with a width of 256 pixels. +- **`recognizerSmall`** - A string that specifies the location of the recognizer binary file which accepts input images with a width of 128 pixels. +- **`language`** - A parameter that specifies the language of the text to be recognized by the OCR. **`preventLoad?`** - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. +For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. + ### Returns The hook returns an object with the following properties: @@ -185,24 +169,10 @@ The `text` property contains the text recognized within detected text region. 
Th ## Example ```tsx -import { - useOCR, - DETECTOR_CRAFT_800, - RECOGNIZER_EN_CRNN_512, - RECOGNIZER_EN_CRNN_256, - RECOGNIZER_EN_CRNN_128, -} from 'react-native-executorch'; +import { useOCR, OCR_ENGLISH } from 'react-native-executorch'; function App() { - const model = useOCR({ - detectorSource: DETECTOR_CRAFT_800, - recognizerSources: { - recognizerLarge: RECOGNIZER_EN_CRNN_512, - recognizerMedium: RECOGNIZER_EN_CRNN_256, - recognizerSmall: RECOGNIZER_EN_CRNN_128, - }, - language: 'en', - }); + const model = useOCR({ model: OCR_ENGLISH }); const runModel = async () => { const ocrDetections = await model.forward('https://url-to-image.jpg'); diff --git a/docs/docs/02-hooks/02-computer-vision/useObjectDetection.md b/docs/docs/02-hooks/02-computer-vision/useObjectDetection.md index 53a48369b2..0eca74f5bb 100644 --- a/docs/docs/02-hooks/02-computer-vision/useObjectDetection.md +++ b/docs/docs/02-hooks/02-computer-vision/useObjectDetection.md @@ -18,9 +18,7 @@ import { } from 'react-native-executorch'; function App() { - const ssdlite = useObjectDetection({ - modelSource: SSDLITE_320_MOBILENET_V3_LARGE, // alternatively, you can use require(...) - }); + const ssdlite = useObjectDetection({ model: SSDLITE_320_MOBILENET_V3_LARGE }); // ... for (const detection of await ssdlite.forward('https://url-to-image.jpg')) { @@ -54,11 +52,14 @@ interface Detection { ### Arguments -**`modelSource`** - A string that specifies the path to the model file. You can download the model from our [HuggingFace repository](https://huggingface.co/software-mansion/react-native-executorch-ssdlite320-mobilenet-v3-large/tree/main). -For more information on that topic, you can check out the [Loading models](../../01-fundamentals/02-loading-models.md) page. +**`model`** - Object containing the model source. + +- **`modelSource`** - A string that specifies the path to the model file. You can download the model from our [HuggingFace repository](https://huggingface.co/software-mansion/react-native-executorch-ssdlite320-mobilenet-v3-large/tree/main). **`preventLoad?`** - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. +For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. + ### Returns The hook returns an object with the following properties: @@ -106,9 +107,7 @@ import { } from 'react-native-executorch'; function App() { - const ssdlite = useObjectDetection({ - modelSource: SSDLITE_320_MOBILENET_V3_LARGE, - }); + const ssdlite = useObjectDetection({ model: SSDLITE_320_MOBILENET_V3_LARGE }); const runModel = async () => { const detections = await ssdlite.forward('https://url-to-image.jpg'); diff --git a/docs/docs/02-hooks/02-computer-vision/useStyleTransfer.md b/docs/docs/02-hooks/02-computer-vision/useStyleTransfer.md index 3fefe6d934..a48b95cd7a 100644 --- a/docs/docs/02-hooks/02-computer-vision/useStyleTransfer.md +++ b/docs/docs/02-hooks/02-computer-vision/useStyleTransfer.md @@ -16,9 +16,7 @@ import { STYLE_TRANSFER_CANDY, } from 'react-native-executorch'; -const model = useStyleTransfer({ - modelSource: STYLE_TRANSFER_CANDY, -}); +const model = useStyleTransfer({ model: STYLE_TRANSFER_CANDY }); const imageUri = 'file::///Users/.../cute_cat.png'; @@ -31,11 +29,14 @@ try { ### Arguments -**`modelSource`** -A string that specifies the location of the model binary. For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. 
+**`model`** - Object containing the model source. + +- **`modelSource`** - A string that specifies the location of the model binary. **`preventLoad?`** - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. +For more information on loading resources, take a look at the [loading models](../../01-fundamentals/02-loading-models.md) page. + ### Returns | Field | Type | Description | @@ -58,9 +59,7 @@ Images from external sources and the generated image are stored in your applicat ```typescript function App() { - const model = useStyleTransfer({ - modelSource: STYLE_TRANSFER_CANDY, - }); + const model = useStyleTransfer({ model: STYLE_TRANSFER_CANDY }); // ... const imageUri = 'file::///Users/.../cute_cat.png'; diff --git a/docs/docs/02-hooks/02-computer-vision/useVerticalOCR.md b/docs/docs/02-hooks/02-computer-vision/useVerticalOCR.md index 1cd8668dcc..b449d9f07c 100644 --- a/docs/docs/02-hooks/02-computer-vision/useVerticalOCR.md +++ b/docs/docs/02-hooks/02-computer-vision/useVerticalOCR.md @@ -15,25 +15,11 @@ It is recommended to use models provided by us, which are available at our [Hugg ## Reference ```tsx -import { - DETECTOR_CRAFT_1280, - DETECTOR_CRAFT_320, - RECOGNIZER_EN_CRNN_512, - RECOGNIZER_EN_CRNN_64, - useVerticalOCR, -} from 'react-native-executorch'; +import { useVerticalOCR, VERTICAL_OCR_ENGLISH } from 'react-native-executorch'; function App() { const model = useVerticalOCR({ - detectorSources: { - detectorLarge: DETECTOR_CRAFT_1280, - detectorNarrow: DETECTOR_CRAFT_320, - }, - recognizerSources: { - recognizerLarge: RECOGNIZER_EN_CRNN_512, - recognizerSmall: RECOGNIZER_EN_CRNN_64, - }, - language: 'en', + model: VERTICAL_OCR_ENGLISH, independentCharacters: true, }); @@ -143,19 +129,12 @@ interface OCRDetection { ### Arguments -**`detectorSources`** - An object that specifies the location of the detectors binary files. Each detector is composed of two models tailored to process images of varying widths. +**`model`** - Object containing the detector sources, recognizer sources, and language. -- `detectorLarge` - A string that specifies the location of the recognizer binary file which accepts input images with a width of 1280 pixels. -- `detectorNarrow` - A string that specifies the location of the detector binary file which accepts input images with a width of 320 pixels. - -For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) section. - -**`recognizerSources`** - An object that specifies the locations of the recognizers binary files. Each recognizer is composed of two models tailored to process images of varying widths. - -**`recognizerLarge`** - A string that specifies the location of the recognizer binary file which accepts input images with a width of 512 pixels. -- `recognizerSmall` - A string that specifies the location of the recognizer binary file which accepts input images with a width of 64 pixels. - -For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) section. +- **`detectorLarge`** - A string that specifies the location of the detector binary file which accepts input images with a width of 1280 pixels. +- **`detectorNarrow`** - A string that specifies the location of the detector binary file which accepts input images with a width of 320 pixels. +- **`recognizerLarge`** - A string that specifies the location of the recognizer binary file which accepts input images with a width of 512 pixels.
+- **`recognizerSmall`** - A string that specifies the location of the recognizer binary file which accepts input images with a width of 64 pixels. **`language`** - A parameter that specifies the language of the text to be recognized by the OCR. @@ -163,6 +142,8 @@ For more information, take a look at [loading models](../../01-fundamentals/02-l **`preventLoad?`** - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. +For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. + ### Returns The hook returns an object with the following properties: @@ -202,25 +183,11 @@ The `text` property contains the text recognized within detected text region. Th ## Example ```tsx -import { - DETECTOR_CRAFT_1280, - DETECTOR_CRAFT_320, - RECOGNIZER_EN_CRNN_512, - RECOGNIZER_EN_CRNN_64, - useVerticalOCR, -} from 'react-native-executorch'; +import { useVerticalOCR, VERTICAL_OCR_ENGLISH } from 'react-native-executorch'; function App() { const model = useVerticalOCR({ - detectorSources: { - detectorLarge: DETECTOR_CRAFT_1280, - detectorNarrow: DETECTOR_CRAFT_320, - }, - recognizerSources: { - recognizerLarge: RECOGNIZER_EN_CRNN_512, - recognizerSmall: RECOGNIZER_EN_CRNN_64, - }, - language: 'en', + model: VERTICAL_OCR_ENGLISH, independentCharacters: true, }); diff --git a/docs/docs/02-hooks/03-executorch-bindings/useExecutorchModule.md b/docs/docs/02-hooks/03-executorch-bindings/useExecutorchModule.md index e13673ed8b..55d5d8e2b8 100644 --- a/docs/docs/02-hooks/03-executorch-bindings/useExecutorchModule.md +++ b/docs/docs/02-hooks/03-executorch-bindings/useExecutorchModule.md @@ -20,7 +20,9 @@ const executorchModule = useExecutorchModule({ }); ``` -The `modelSource` parameter expects a location string pointing to the model binary. For more details on how to specify model sources, refer to the [loading models](../../01-fundamentals/02-loading-models.md) documentation. +The `modelSource` parameter expects a location string pointing to the model binary. + +For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. 
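As a concrete illustration, both remote URLs and bundled assets should work here — a sketch assuming the usual `ResourceSource` variants (the `require(...)` form is mentioned elsewhere in these docs and presumably needs the `.pte` extension registered with Metro):

```typescript
import { useExecutorchModule } from 'react-native-executorch';

// Remote URL: downloaded and cached on first load.
const remoteModule = useExecutorchModule({
  modelSource: 'https://example.com/models/my_model.pte',
});

// Bundled asset: assumes .pte is listed in metro.config.js assetExts.
const bundledModule = useExecutorchModule({
  modelSource: require('../assets/my_model.pte'),
});
```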
### Arguments diff --git a/docs/docs/03-typescript-api/01-natural-language-processing/LLMModule.md b/docs/docs/03-typescript-api/01-natural-language-processing/LLMModule.md index a3aef792a8..14d3b4d748 100644 --- a/docs/docs/03-typescript-api/01-natural-language-processing/LLMModule.md +++ b/docs/docs/03-typescript-api/01-natural-language-processing/LLMModule.md @@ -7,16 +7,7 @@ TypeScript API implementation of the [useLLM](../../02-hooks/01-natural-language ## Reference ```typescript -import { - LLAMA3_2_1B_QLORA, - LLAMA3_2_TOKENIZER, - LLAMA3_2_TOKENIZER_CONFIG, - LLMModule, -} from 'react-native-executorch'; - -const printDownloadProgress = (progress: number) => { - console.log(progress); -}; +import { LLMModule, LLAMA3_2_1B_QLORA } from 'react-native-executorch'; // Creating an instance const llm = new LLMModule({ @@ -25,12 +16,7 @@ const llm = new LLMModule({ }); // Loading the model -await llm.load({ - modelSource: LLAMA3_2_1B_QLORA, - tokenizerSource: LLAMA3_2_TOKENIZER, - tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, - onDownloadProgressCallback: printDownloadProgress, -}); +await llm.load(LLAMA3_2_1B_QLORA, (progress) => console.log(progress)); // Running the model await llm.sendMessage('Hello, World!'); @@ -44,18 +30,18 @@ llm.delete(); ### Methods -| Method | Type | Description | -| ------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `constructor` | `({tokenCallback?: (token: string) => void, responseCallback?: (response: string) => void, messageHistoryCallback?: (messageHistory: Message[]) => void})` | Creates a new instance of LLMModule with optional callbacks. | -| `load` | `({modelSource: ResourceSource, tokenizerSource: ResourceSource, tokenizerConfigSource: ResourceSource, onDownloadProgressCallback?: (downloadProgress: number) => void}) => Promise` | Loads the model. Checkout the [loading the model](#loading-the-model) section for details. | -| `setTokenCallback` | `{tokenCallback: (token: string) => void}) => void` | Sets new token callback. | -| `generate` | `(messages: Message[], tools?: LLMTool[]) => Promise` | Runs model to complete chat passed in `messages` argument. It doesn't manage conversation context. | -| `forward` | `(input: string) => Promise` | Runs model inference with raw input string. You need to provide entire conversation and prompt (in correct format and with special tokens!) in input string to this method. It doesn't manage conversation context. It is intended for users that need access to the model itself without any wrapper. If you want a simple chat with model the consider using`sendMessage` | -| `configure` | `({chatConfig?: Partial, toolsConfig?: ToolsConfig}) => void` | Configures chat and tool calling. See more details in [configuring the model](#configuring-the-model). | -| `sendMessage` | `(message: string) => Promise` | Method to add user message to conversation. After model responds it will call `messageHistoryCallback()`containing both user message and model response. It also returns them. 
| -| `deleteMessage` | `(index: number) => void` | Deletes all messages starting with message on `index` position. After deletion it will call `messageHistoryCallback()` containing new history. It also returns it. | -| `delete` | `() => void` | Method to delete the model from memory. Note you cannot delete model while it's generating. You need to interrupt it first and make sure model stopped generation. | -| `interrupt` | `() => void` | Interrupts model generation. It may return one more token after interrupt. | +| Method | Type | Description | +| ------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `constructor` | `({tokenCallback?: (token: string) => void, responseCallback?: (response: string) => void, messageHistoryCallback?: (messageHistory: Message[]) => void})` | Creates a new instance of LLMModule with optional callbacks. | +| `load` | `(model: { modelSource: ResourceSource; tokenizerSource: ResourceSource; tokenizerConfigSource: ResourceSource }, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model. | +| `setTokenCallback` | `({tokenCallback: (token: string) => void}) => void` | Sets a new token callback. | +| `generate` | `(messages: Message[], tools?: LLMTool[]) => Promise` | Runs the model to complete the chat passed in the `messages` argument. It doesn't manage conversation context. | +| `forward` | `(input: string) => Promise` | Runs model inference on a raw input string. You need to provide the entire conversation and prompt (in the correct format, with special tokens!) in the input string. It doesn't manage conversation context. It is intended for users who need access to the model itself without any wrapper. If you want a simple chat with the model, consider using `sendMessage`. | +| `configure` | `({chatConfig?: Partial, toolsConfig?: ToolsConfig}) => void` | Configures chat and tool calling. See more details in [configuring the model](#configuring-the-model). | +| `sendMessage` | `(message: string) => Promise` | Method to add a user message to the conversation. After the model responds, it will call `messageHistoryCallback()` containing both the user message and the model response. It also returns them. | +| `deleteMessage` | `(index: number) => void` | Deletes all messages starting with the message at position `index`. After deletion it will call `messageHistoryCallback()` containing the new history. It also returns it. | +| `delete` | `() => void` | Method to delete the model from memory. Note that you cannot delete the model while it's generating; interrupt it first and make sure generation has stopped. | +| `interrupt` | `() => void` | Interrupts model generation. It may return one more token after the interrupt. |
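The `interrupt`/`delete` contract in the last rows benefits from a concrete shape. A minimal teardown sketch — `isStillGenerating` is an illustrative app-level flag (e.g. toggled from the constructor callbacks), since the module reports generation progress only through those callbacks:

```typescript
import { LLMModule } from 'react-native-executorch';

// Hypothetical helper: never delete a module mid-generation.
async function shutdownSafely(llm: LLMModule, isStillGenerating: () => boolean) {
  llm.interrupt(); // may still emit one final token after the interrupt
  while (isStillGenerating()) {
    // Poll until the app-level flag reports that generation has stopped.
    await new Promise((resolve) => setTimeout(resolve, 50));
  }
  llm.delete();
}
```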
Type definitions @@ -104,15 +90,19 @@ To create a new instance of LLMModule, use the constructor with optional callbac Then, to load the model, use the `load` method. It accepts an object with the following fields:
-**`modelSource`** - A string that specifies the location of the model binary.
+**`model`** - Object containing the model source, tokenizer source, and tokenizer config source.
-**`tokenizerSource`** - URL to the JSON file which contains the tokenizer.
+- **`modelSource`** - `ResourceSource` specifying the location of the model binary.
-**`tokenizerConfigSource`** - URL to the JSON file which contains the tokenizer config.
+- **`tokenizerSource`** - `ResourceSource` specifying the location of the tokenizer.
+
+- **`tokenizerConfigSource`** - `ResourceSource` specifying the location of the tokenizer config.
**`onDownloadProgressCallback`** - (Optional) Function called on download progress.
-This method returns a promise, which can resolve to an error or void. It only works in managed chat (i.e. when you use `sendMessage`)
+This method returns a promise, which can resolve to an error or void.
+
+For more information on loading resources, take a look at the [loading models](../../01-fundamentals/02-loading-models.md) page.
## Listening for download progress
diff --git a/docs/docs/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md b/docs/docs/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md
index 7eed5211f9..104c78b0b0 100644
--- a/docs/docs/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md
+++ b/docs/docs/03-typescript-api/01-natural-language-processing/SpeechToTextModule.md
@@ -7,7 +7,7 @@ TypeScript API implementation of the [useSpeechToText](../../02-hooks/01-natural ## Reference ```typescript
-import { SpeechToTextModule } from 'react-native-executorch';
+import { SpeechToTextModule, MOONSHINE_TINY } from 'react-native-executorch';
import { AudioContext } from 'react-native-audio-api';
import * as FileSystem from 'expo-file-system';
@@ -32,10 +32,7 @@ const stt = new SpeechToTextModule({ });
// Loading the model
-await stt.load({
-  modelName: 'moonshine',
-  onDownloadProgressCallback: (progress) => console.log(progress),
-});
+await stt.load(MOONSHINE_TINY, (progress) => console.log(progress));
// Loading the audio and running the model
const waveform = await loadAudio(audioUrl);
@@ -44,15 +41,15 @@ const transcribedText = await stt.transcribe(waveform); ### Methods
-| Method | Type | Description |
-| --------------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------ |
-| `constructor` | `({transcribeCallback?: (sequence: string) => void, overlapSeconds?: number,
windowSize?: number, streamingConfig?: 'fast' \| 'balanced' \| 'quality'})` | Creates a new instance of SpeechToTextModule with an optional transcription callback and streaming configuration. | -| `load` | `({modelName: 'whisper' \| 'moonshine' \| 'whisperMultilingual', encoderSource?: ResourceSource, decoderSource?: ResourceSource, tokenizerSource?: ResourceSource, onDownloadProgressCallback?: (downloadProgress: number) => void}) => Promise` | Loads the model specified with `modelName`, where `encoderSource`, `decoderSource`, `tokenizerSource` are strings specifying the location of the binaries for the models. `onDownloadProgressCallback` allows you to monitor the current progress of the model download | -| `transcribe` | `(waveform: number[], audioLanguage?: SpeechToTextLanguage): Promise` | Starts a transcription process for a given input array, which should be a waveform at 16kHz. Resolves a promise with the output transcription when the model is finished. For multilingual models, you have to specify the audioLanguage flag, which is the language of the spoken language in the audio. | -| `streamingTranscribe` | `(streamingAction: STREAMING_ACTION, waveform?: number[], audioLanguage?: SpeechToTextLanguage) => Promise` | This allows for running transcription process on-line, which means where the whole audio is not known beforehand i.e. when transcribing from a live microphone feed. `streamingAction` defines the type of package sent to the model:
  • `START` - initializes the process, allows for optional `waveform` data
  • `DATA` - this package should contain consecutive audio data chunks sampled in 16k Hz
  • `STOP` - the last data chunk for this transcription, ends the transcription process and flushes internal buffers
  • Each call returns most recent transcription. Returns error when called when module is in use (i.e. processing `transcribe` call) | -| `encode` | `(waveform: number[]) => Promise` | Runs the encoding part of the model. It doesn't return the encodings. Instead, it stores the result internally, reducing data transfer overhead. | -| `decode` | `(tokens: number[]) => Promise` | Runs the decoder of the model. Returns a single token representing the next token in the output. It uses internal cached encodings from the most recent `encode` call, meaning that you have to call `encode` prior to decoding. | -| `configureStreaming` | `(overlapSeconds?: number, windowSize?: number, streamingConfig?: 'fast' \| 'balanced' \| 'quality') => void` | Configures options for the streaming algorithm:
    • `overlapSeconds` determines how much adjacent audio chunks overlap (increasing it slows down transcription, decreases probability of weird wording at the chunks intersection, setting it larger than 3 seconds is generally discouraged),
    • `windowSize` describes size of the audio chunks (increasing it speeds up the end to end transcription time, but increases latency for the first token to be returned),
    • `streamingConfig` predefined configs for `windowSize` and `overlapSeconds` values.
Keep `windowSize + 2 * overlapSeconds <= 30`. |
+| Method | Type | Description |
+| --------------------- | ---------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------- |
+| `constructor` | `({transcribeCallback?: (sequence: string) => void, overlapSeconds?: number, windowSize?: number, streamingConfig?: 'fast' \| 'balanced' \| 'quality'})` | Creates a new instance of SpeechToTextModule with an optional transcription callback and streaming configuration. |
+| `load` | `(model: { modelName: AvailableModels; encoderSource?: ResourceSource; decoderSource?: ResourceSource; tokenizerSource?: ResourceSource }, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model specified by `modelName`, where `encoderSource`, `decoderSource`, and `tokenizerSource` specify the locations of the binaries for the models. `onDownloadProgressCallback` allows you to monitor the current progress of the model download. |
+| `transcribe` | `(waveform: number[], audioLanguage?: SpeechToTextLanguage): Promise` | Starts a transcription process for a given input array, which should be a waveform at 16 kHz. Resolves a promise with the output transcription when the model is finished. For multilingual models, you have to specify the `audioLanguage` flag, which indicates the language spoken in the audio. |
+| `streamingTranscribe` | `(streamingAction: STREAMING_ACTION, waveform?: number[], audioLanguage?: SpeechToTextLanguage) => Promise` | Allows for running the transcription process online, i.e. when the whole audio is not known beforehand, such as when transcribing a live microphone feed. `streamingAction` defines the type of package sent to the model:
    • `START` - initializes the process, allows for optional `waveform` data
• `DATA` - this package should contain consecutive audio data chunks sampled at 16 kHz
    • `STOP` - the last data chunk for this transcription, ends the transcription process and flushes internal buffers
Each call returns the most recent transcription. Returns an error when called while the module is in use (i.e. processing a `transcribe` call) |
+| `encode` | `(waveform: number[]) => Promise` | Runs the encoding part of the model. It doesn't return the encodings. Instead, it stores the result internally, reducing data transfer overhead. |
+| `decode` | `(tokens: number[]) => Promise` | Runs the decoder of the model. Returns a single token representing the next token in the output. It uses internally cached encodings from the most recent `encode` call, meaning that you have to call `encode` prior to decoding. |
+| `configureStreaming` | `(overlapSeconds?: number, windowSize?: number, streamingConfig?: 'fast' \| 'balanced' \| 'quality') => void` | Configures options for the streaming algorithm:
• `overlapSeconds` determines how much adjacent audio chunks overlap (increasing it slows down transcription and decreases the probability of garbled wording at chunk boundaries; setting it larger than 3 seconds is generally discouraged),
• `windowSize` describes the size of the audio chunks (increasing it shortens the end-to-end transcription time but increases the latency before the first token is returned),
• `streamingConfig` selects predefined configurations for the `windowSize` and `overlapSeconds` values.
    Keep `windowSize + 2 * overlapSeconds <= 30`. |
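The `streamingTranscribe` rows above describe a START/DATA/STOP protocol, which composes as follows. This is a minimal sketch; it assumes `STREAMING_ACTION` is exported as shown in the type definitions, and `getAudioChunks` is a hypothetical helper yielding consecutive 16 kHz waveform chunks:

```typescript
import {
  SpeechToTextModule,
  STREAMING_ACTION,
  MOONSHINE_TINY,
} from 'react-native-executorch';

// Hypothetical helper: yields consecutive 16 kHz waveform chunks,
// e.g. captured from a live microphone feed.
declare function getAudioChunks(): AsyncIterable<number[]>;

async function streamTranscription() {
  const stt = new SpeechToTextModule({
    transcribeCallback: (sequence) => console.log('partial:', sequence),
  });
  await stt.load(MOONSHINE_TINY, (progress) => console.log('download:', progress));

  // START initializes the streaming session; waveform data is optional here.
  await stt.streamingTranscribe(STREAMING_ACTION.START);

  // DATA feeds consecutive chunks; each call returns the most recent transcription.
  for await (const chunk of getAudioChunks()) {
    const soFar = await stt.streamingTranscribe(STREAMING_ACTION.DATA, chunk);
    console.log('so far:', soFar);
  }

  // STOP flushes internal buffers and ends the session.
  const finalText = await stt.streamingTranscribe(STREAMING_ACTION.STOP);
  console.log('final:', finalText);
}
```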
Type definitions @@ -66,6 +63,12 @@ enum STREAMING_ACTION { STOP, }
+enum AvailableModels {
+  WHISPER = 'whisper',
+  MOONSHINE = 'moonshine',
+  WHISPER_MULTILINGUAL = 'whisperMultilingual',
+}
+
enum SpeechToTextLanguage { Afrikaans = 'af', Albanian = 'sq', @@ -161,17 +164,21 @@ To create a new instance of SpeechToTextModule, use the constructor with optiona Then, to load the model, use the `load` method. It accepts an object with the following fields:
-**`modelName`** - Identifier for which model to use ('whisper', 'moonshine', or 'whisperMultilingual').
+**`model`** - Object containing the model name, encoder source, decoder source, and tokenizer source.
-**`encoderSource`** - (Optional) String that specifies the location of the encoder binary.
+- **`modelName`** - Identifier for which model to use ('whisper', 'moonshine', or 'whisperMultilingual').
-**`decoderSource`** - (Optional) String that specifies the location of the decoder binary.
+- **`encoderSource`** - (Optional) String that specifies the location of the encoder binary.
-**`tokenizerSource`** - (Optional) String that specifies the location of the tokenizer.
+- **`decoderSource`** - (Optional) String that specifies the location of the decoder binary.
+
+- **`tokenizerSource`** - (Optional) String that specifies the location of the tokenizer.
**`onDownloadProgressCallback`** - (Optional) Function that will be called on download progress.
-For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. This method returns a promise, which can resolve to an error or void.
+This method returns a promise, which can resolve to an error or void.
+
+For more information on loading resources, take a look at the [loading models](../../01-fundamentals/02-loading-models.md) page.
## Running the model
diff --git a/docs/docs/03-typescript-api/01-natural-language-processing/TextEmbeddingsModule.md b/docs/docs/03-typescript-api/01-natural-language-processing/TextEmbeddingsModule.md
index 5c6b33c4b2..7f59268f97 100644
--- a/docs/docs/03-typescript-api/01-natural-language-processing/TextEmbeddingsModule.md
+++ b/docs/docs/03-typescript-api/01-natural-language-processing/TextEmbeddingsModule.md
@@ -10,23 +10,25 @@ TypeScript API implementation of the [useTextEmbeddings](../../02-hooks/01-natur import { TextEmbeddingsModule, ALL_MINILM_L6_V2,
- All_MINILM_L6_V2_TOKENIZER,
} from 'react-native-executorch';
+// Creating an instance
+const textEmbeddingsModule = new TextEmbeddingsModule();
+
// Loading the model
-await TextEmbeddingsModule.load(ALL_MINILM_L6_V2, All_MINILM_L6_V2_TOKENIZER);
+await textEmbeddingsModule.load(ALL_MINILM_L6_V2);
// Running the model
-const embedding = await TextEmbeddingsModule.forward('Hello World!');
+const embedding = await textEmbeddingsModule.forward('Hello World!');
``` ### Methods
-| Method | Type | Description |
-| -------------------- | --------------------------------------------------------------------------------- | --------------------------------------------------------------------------------- |
-| `load` | `(modelSource: ResourceSource, tokenizerSource: ResourceSource): Promise` | Loads the model, where `modelSource` is a string that specifies the location of the model binary, `tokenizerSource` is a string that specifies the location of the tokenizer JSON file.
|
-| `forward` | `(input: string): Promise` | Executes the model's forward pass, where `input` is a text that will be embedded. |
-| `onDownloadProgress` | `(callback: (downloadProgress: number) => void): any` | Subscribe to the download progress event. |
+| Method | Type | Description |
+| -------------------- | ----------------------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------------------- |
+| `load` | `(model: { modelSource: ResourceSource; tokenizerSource: ResourceSource }, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model, where `modelSource` is a string that specifies the location of the model binary and `tokenizerSource` is a string that specifies the location of the tokenizer JSON file. |
+| `forward` | `(input: string): Promise` | Executes the model's forward pass, where `input` is the text to be embedded. |
+| `onDownloadProgress` | `(callback: (downloadProgress: number) => void): any` | Subscribe to the download progress event. |
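Since `forward` resolves to an embedding vector, a natural next step is comparing two texts. A minimal sketch, assuming the resolved value is a plain `number[]` (the `Promise` type parameter is not spelled out in the table above):

```typescript
import { TextEmbeddingsModule, ALL_MINILM_L6_V2 } from 'react-native-executorch';

// Plain cosine similarity between two vectors of equal length.
const cosineSimilarity = (a: number[], b: number[]): number => {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
};

async function compareTexts() {
  const textEmbeddingsModule = new TextEmbeddingsModule();
  await textEmbeddingsModule.load(ALL_MINILM_L6_V2);

  const first = await textEmbeddingsModule.forward('The weather is lovely today.');
  const second = await textEmbeddingsModule.forward("It's so sunny outside!");

  // Values close to 1 indicate semantically similar texts.
  console.log('similarity:', cosineSimilarity(first, second));
}
```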
Type definitions @@ -39,7 +41,18 @@ type ResourceSource = string | number | object; ## Loading the model
-To load the model, use the `load` method. It accepts the `modelSource` which is a string that specifies the location of the model binary, `tokenizerSource` which is a string that specifies the location of the tokenizer JSON file. For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. This method returns a promise, which can resolve to an error or void.
+To load the model, use the `load` method. It accepts an object:
+
+**`model`** - Object containing the model source and tokenizer source.
+
+- **`modelSource`** - A string that specifies the location of the model binary.
+- **`tokenizerSource`** - A string that specifies the location of the tokenizer JSON file.
+
+**`onDownloadProgressCallback`** - (Optional) Function called on download progress.
+
+This method returns a promise, which can resolve to an error or void.
+
+For more information on loading resources, take a look at the [loading models](../../01-fundamentals/02-loading-models.md) page.
## Running the model
diff --git a/docs/docs/03-typescript-api/01-natural-language-processing/TokenizerModule.md b/docs/docs/03-typescript-api/01-natural-language-processing/TokenizerModule.md
index a68f66b53f..41ad2b0279 100644
--- a/docs/docs/03-typescript-api/01-natural-language-processing/TokenizerModule.md
+++ b/docs/docs/03-typescript-api/01-natural-language-processing/TokenizerModule.md
@@ -7,48 +7,48 @@ TypeScript API implementation of the [useTokenizer](../../02-hooks/01-natural-la ## Reference ```typescript
-import {
-  TokenizerModule,
-  ALL_MINILM_L6_V2_TOKENIZER,
-} from 'react-native-executorch';
+import { TokenizerModule, ALL_MINILM_L6_V2 } from 'react-native-executorch';
+
+// Creating an instance
+const tokenizerModule = new TokenizerModule();
// Load the tokenizer
-await TokenizerModule.load(ALL_MINILM_L6_V2_TOKENIZER);
+await tokenizerModule.load(ALL_MINILM_L6_V2);
console.log('Tokenizer loaded');
// Get tokenizers vocabulary size
-const vocabSize = await TokenizerModule.getVocabSize();
+const vocabSize = await tokenizerModule.getVocabSize();
console.log('Vocabulary size:', vocabSize);
const text = 'Hello, world!';
// Tokenize the text
-const tokens = await TokenizerModule.encode(text);
+const tokens = await tokenizerModule.encode(text);
console.log('Token IDs:', tokens);
// Decode the tokens back to text
-const decoded = await TokenizerModule.decode(tokens);
+const decoded = await tokenizerModule.decode(tokens);
console.log('Decoded text:', decoded);
// Get the token ID for a specific token
-const tokenId = await TokenizerModule.tokenToId('hello');
+const tokenId = await tokenizerModule.tokenToId('hello');
console.log('Token ID for "Hello":', tokenId);
// Get the token for a specific ID
-const token = await TokenizerModule.idToToken(tokenId);
+const token = await tokenizerModule.idToToken(tokenId);
console.log('Token for ID:', token);
``` ### Methods
-| Method | Type | Description |
-| -------------- | -------------------------------------------------- | -------------------------------------------------- |
-| `load` | `(tokenizerSource: ResourceSource): Promise` | Loads the tokenizer from the specified source. `tokenizerSource` is a string that points to the location of the tokenizer JSON file. |
-| `encode` | `(input: string): Promise` | Converts a string into an array of token IDs.
|
-| `decode` | `(input: number[]): Promise` | Converts an array of token IDs into a string. |
-| `getVocabSize` | `(): Promise` | Returns the size of the tokenizer's vocabulary. |
-| `idToToken` | `(tokenId: number): Promise` | Returns the token associated to the ID. |
-| `tokenToId` | `(token: string): Promise` | Returns the ID associated to the token. |
+| Method | Type | Description |
+| -------------- | ---------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------- |
+| `load` | `(tokenizer: { tokenizerSource: ResourceSource }, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the tokenizer from the specified source. `tokenizerSource` is a string that points to the location of the tokenizer JSON file. |
+| `encode` | `(input: string): Promise` | Converts a string into an array of token IDs. |
+| `decode` | `(input: number[]): Promise` | Converts an array of token IDs into a string. |
+| `getVocabSize` | `(): Promise` | Returns the size of the tokenizer's vocabulary. |
+| `idToToken` | `(tokenId: number): Promise` | Returns the token associated with the ID. |
+| `tokenToId` | `(token: string): Promise` | Returns the ID associated with the token. |
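One detail the reference above does not show is the optional download progress callback on the new `load` signature. A minimal sketch, assuming the exported `ALL_MINILM_L6_V2` constant satisfies the `{ tokenizerSource }` shape and that progress is reported in the 0 to 1 range:

```typescript
import { TokenizerModule, ALL_MINILM_L6_V2 } from 'react-native-executorch';

async function loadAndInspect() {
  const tokenizerModule = new TokenizerModule();

  // The second argument reports download progress (assumed 0..1 here).
  await tokenizerModule.load(ALL_MINILM_L6_V2, (progress) =>
    console.log(`tokenizer download: ${Math.round(progress * 100)}%`)
  );

  // Every ID produced by encode should fall within the vocabulary.
  const vocabSize = await tokenizerModule.getVocabSize();
  const ids = await tokenizerModule.encode('Hello, world!');
  console.log(ids.every((id) => id >= 0 && id < vocabSize)); // true
}
```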
    Type definitions diff --git a/docs/docs/03-typescript-api/02-computer-vision/ClassificationModule.md b/docs/docs/03-typescript-api/02-computer-vision/ClassificationModule.md index ca1cf99b58..7ddf0f0183 100644 --- a/docs/docs/03-typescript-api/02-computer-vision/ClassificationModule.md +++ b/docs/docs/03-typescript-api/02-computer-vision/ClassificationModule.md @@ -14,22 +14,23 @@ import { const imageUri = 'path/to/image.png'; -const module = new ClassificationModule(); +// Creating an instance +const classificationModule = new ClassificationModule(); // Loading the model -await module.load(EFFICIENTNET_V2_S); +await classificationModule.load(EFFICIENTNET_V2_S); // Running the model -const classesWithProbabilities = await module.forward(imageUri); +const classesWithProbabilities = await classificationModule.forward(imageUri); ``` ### Methods -| Method | Type | Description | -| --------- | -------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| `load` | `(modelSource: ResourceSource, onDownloadProgressCallback: (_: number) => void () => {}): Promise` | Loads the model, where `modelSource` is a string that specifies the location of the model binary. To track the download progress, supply a callback function `onDownloadProgressCallback`. | -| `forward` | `(input: string): Promise<{ [category: string]: number }>` | Executes the model's forward pass, where `input` can be a fetchable resource or a Base64-encoded string. | -| `delete` | `(): void` | Release the memory held by the module. Calling `forward` afterwards is invalid. | +| Method | Type | Description | +| --------- | ------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `load` | `(model: { modelSource: ResourceSource }, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model, where `modelSource` is a string that specifies the location of the model binary. To track the download progress, supply a callback function `onDownloadProgressCallback`. | +| `forward` | `(input: string): Promise<{ [category: string]: number }>` | Executes the model's forward pass, where `input` can be a fetchable resource or a Base64-encoded string. | +| `delete` | `(): void` | Release the memory held by the module. Calling `forward` afterwards is invalid. |
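Because `forward` resolves to a category-to-probability map, extracting the top predictions is a small `Object.entries` transform. A minimal sketch using only the signatures above; the printed output is illustrative:

```typescript
import { ClassificationModule, EFFICIENTNET_V2_S } from 'react-native-executorch';

async function classifyTopK(imageUri: string, k = 5) {
  const classificationModule = new ClassificationModule();
  await classificationModule.load(EFFICIENTNET_V2_S, (progress) =>
    console.log('download:', progress)
  );

  const classesWithProbabilities = await classificationModule.forward(imageUri);

  // Sort the { [category]: number } map by probability, descending.
  const topK = Object.entries(classesWithProbabilities)
    .sort(([, a], [, b]) => b - a)
    .slice(0, k);
  console.log(topK); // e.g. [['tabby cat', 0.93], ...]

  // Release native memory once done; forward is invalid afterwards.
  classificationModule.delete();
}
```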
Type definitions @@ -42,7 +43,17 @@ type ResourceSource = string | number | object; ## Loading the model
-To load the model, create a new instance of the module and use the `load` method on it. It accepts the `modelSource` which is a string that specifies the location of the model binary. For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. This method returns a promise, which can resolve to an error or void.
+To load the model, create a new instance of the module and use the `load` method on it. It accepts an object:
+
+**`model`** - Object containing the model source.
+
+- **`modelSource`** - A string that specifies the location of the model binary.
+
+**`onDownloadProgressCallback`** - (Optional) Function called on download progress.
+
+This method returns a promise, which can resolve to an error or void.
+
+For more information on loading resources, take a look at the [loading models](../../01-fundamentals/02-loading-models.md) page.
## Running the model
diff --git a/docs/docs/03-typescript-api/02-computer-vision/ImageEmbeddingsModule.md b/docs/docs/03-typescript-api/02-computer-vision/ImageEmbeddingsModule.md
index 04dd942db6..2333f57b49 100644
--- a/docs/docs/03-typescript-api/02-computer-vision/ImageEmbeddingsModule.md
+++ b/docs/docs/03-typescript-api/02-computer-vision/ImageEmbeddingsModule.md
@@ -9,25 +9,28 @@ TypeScript API implementation of the [useImageEmbeddings](../../02-hooks/02-comp ```typescript import { ImageEmbeddingsModule,
- CLIP_VIT_BASE_PATCH_32_IMAGE_ENCODER,
+ CLIP_VIT_BASE_PATCH32_IMAGE,
} from 'react-native-executorch';
+// Creating an instance
+const imageEmbeddingsModule = new ImageEmbeddingsModule();
+
// Loading the model
-await ImageEmbeddingsModule.load(CLIP_VIT_BASE_PATCH_32_IMAGE_ENCODER);
+await imageEmbeddingsModule.load(CLIP_VIT_BASE_PATCH32_IMAGE);
// Running the model
-const embedding = await ImageEmbeddingsModule.forward(
+const embedding = await imageEmbeddingsModule.forward(
'https://url-to-image.jpg'
);
``` ### Methods
-| Method | Type | Description |
-| -------------------- | ----------------------------------------------------- | ------------------------------------------------------------------------------------------------- |
-| `load` | `(modelSource: ResourceSource): Promise` | Loads the model, where `modelSource` is a string that specifies the location of the model binary. |
-| `forward` | `(input: string): Promise` | Executes the model's forward pass, where `input` is a URI/URL to image that will be embedded. |
-| `onDownloadProgress` | `(callback: (downloadProgress: number) => void): any` | Subscribe to the download progress event. |
+| Method | Type | Description |
+| -------------------- | ------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------- |
+| `load` | `(model: { modelSource: ResourceSource }, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model, where `modelSource` is a string that specifies the location of the model binary. |
+| `forward` | `(input: string): Promise` | Executes the model's forward pass, where `input` is a URI/URL of the image to be embedded. |
+| `onDownloadProgress` | `(callback: (downloadProgress: number) => void): any` | Subscribe to the download progress event. |
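As with text embeddings, image embeddings become useful when compared. A minimal sketch, assuming `forward` resolves to a plain `number[]`:

```typescript
import {
  ImageEmbeddingsModule,
  CLIP_VIT_BASE_PATCH32_IMAGE,
} from 'react-native-executorch';

// Cosine similarity between two equal-length vectors.
const cosineSimilarity = (a: number[], b: number[]): number => {
  const dot = a.reduce((sum, v, i) => sum + v * b[i], 0);
  const norm = (v: number[]) => Math.sqrt(v.reduce((s, x) => s + x * x, 0));
  return dot / (norm(a) * norm(b));
};

async function compareImages(firstUrl: string, secondUrl: string) {
  const imageEmbeddingsModule = new ImageEmbeddingsModule();
  await imageEmbeddingsModule.load(CLIP_VIT_BASE_PATCH32_IMAGE);

  const first = await imageEmbeddingsModule.forward(firstUrl);
  const second = await imageEmbeddingsModule.forward(secondUrl);

  // Higher values mean the two images are more alike.
  console.log('similarity:', cosineSimilarity(first, second));
}
```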
Type definitions @@ -40,7 +43,17 @@ type ResourceSource = string | number | object; ## Loading the model
-To load the model, use the `load` method. It accepts the `modelSource` which is a string that specifies the location of the model binary. For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. This method returns a promise, which can resolve to an error or void.
+To load the model, use the `load` method. It accepts an object:
+
+**`model`** - Object containing the model source.
+
+- **`modelSource`** - A string that specifies the location of the model binary.
+
+**`onDownloadProgressCallback`** - (Optional) Function called on download progress.
+
+This method returns a promise, which can resolve to an error or void.
+
+For more information on loading resources, take a look at the [loading models](../../01-fundamentals/02-loading-models.md) page.
## Running the model
diff --git a/docs/docs/03-typescript-api/02-computer-vision/ImageSegmentationModule.md b/docs/docs/03-typescript-api/02-computer-vision/ImageSegmentationModule.md
index 7934717db6..9d4e1e9291 100644
--- a/docs/docs/03-typescript-api/02-computer-vision/ImageSegmentationModule.md
+++ b/docs/docs/03-typescript-api/02-computer-vision/ImageSegmentationModule.md
@@ -14,20 +14,21 @@ import { ... const imageUri = 'path/to/image.png';
-const module = new ImageSegmentationModule();
+// Creating an instance
+const imageSegmentationModule = new ImageSegmentationModule();
// Loading the model
-await module.load(DEEPLAB_V3_RESNET50);
+await imageSegmentationModule.load(DEEPLAB_V3_RESNET50);
// Running the model
-const outputDict = await module.forward(imageUri);
+const outputDict = await imageSegmentationModule.forward(imageUri);
``` ### Methods
| Method | Type | Description |
| --------- | ---------------------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------- |
-| `load` | `(modelSource: ResourceSource, onDownloadProgressCallback: (_: number) => void () => {}): Promise` | Loads the model, where `modelSource` is a string that specifies the location of the model binary. To track the download progress, supply a callback function `onDownloadProgressCallback`. |
+| `load` | `(model: { modelSource: ResourceSource }, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model, where `modelSource` is a string that specifies the location of the model binary. To track the download progress, supply a callback function `onDownloadProgressCallback`. |
| `forward` | `(input: string, classesOfInterest?: DeeplabLabel[], resize?: boolean) => Promise<{[key in DeeplabLabel]?: number[]}>` | Executes the model's forward pass, where :
    \* `input` can be a fetchable resource or a Base64-encoded string.
    \* `classesOfInterest` is an optional list of `DeeplabLabel` used to indicate additional arrays of probabilities to output (see section "Running the model"). The default is an empty list.
    \* `resize` is an optional boolean to indicate whether the output should be resized to the original image dimensions, or left in the size of the model (see section "Running the model"). The default is `false`.

    The return is a dictionary containing:
    \* for the key `DeeplabLabel.ARGMAX` an array of integers corresponding to the most probable class for each pixel
an array of floats for each class from `classesOfInterest` corresponding to the probabilities for this class. |
| `delete` | `(): void` | Release the memory held by the module. Calling `forward` afterwards is invalid. |
@@ -42,7 +43,17 @@ type ResourceSource = string | number | object; ## Loading the model
-To load the model, create a new instance of the module and use the `load` method on it. It accepts the `modelSource` which is a string that specifies the location of the model binary. For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. This method returns a promise, which can resolve to an error or void.
+To load the model, create a new instance of the module and use the `load` method on it. It accepts an object:
+
+**`model`** - Object containing the model source.
+
+- **`modelSource`** - A string that specifies the location of the model binary.
+
+**`onDownloadProgressCallback`** - (Optional) Function called on download progress.
+
+This method returns a promise, which can resolve to an error or void.
+
+For more information on loading resources, take a look at the [loading models](../../01-fundamentals/02-loading-models.md) page.
## Running the model
diff --git a/docs/docs/03-typescript-api/02-computer-vision/OCRModule.md b/docs/docs/03-typescript-api/02-computer-vision/OCRModule.md
index 4785d150a0..f709ffe1a4 100644
--- a/docs/docs/03-typescript-api/02-computer-vision/OCRModule.md
+++ b/docs/docs/03-typescript-api/02-computer-vision/OCRModule.md
@@ -7,48 +7,31 @@ TypeScript API implementation of the [useOCR](../../02-hooks/02-computer-vision/ ## Reference ```typescript
-import {
-  OCRModule,
-  DETECTOR_CRAFT_800,
-  RECOGNIZER_EN_CRNN_512,
-  RECOGNIZER_EN_CRNN_256,
-  RECOGNIZER_EN_CRNN_128,
-} from 'react-native-executorch';
+import { OCRModule, OCR_ENGLISH } from 'react-native-executorch';
const imageUri = 'path/to/image.png';
+// Creating an instance
+const ocrModule = new OCRModule();
+
// Loading the model
-await OCRModule.load({
-  detectorSource: DETECTOR_CRAFT_800,
-  recognizerSources: {
-    recognizerLarge: RECOGNIZER_EN_CRNN_512,
-    recognizerMedium: RECOGNIZER_EN_CRNN_256,
-    recognizerSmall: RECOGNIZER_EN_CRNN_128,
-  },
-  language: 'en',
-});
+await ocrModule.load(OCR_ENGLISH);
// Running the model
-const ocrDetections = await OCRModule.forward(imageUri);
+const detections = await ocrModule.forward(imageUri);
``` ### Methods
-| Method | Type | Description |
-| -------------------- | ------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------- |
-| `load` | `(detectorSource: string, recognizerSources: RecognizerSources, language: OCRLanguage): Promise` | Loads the detector and recognizers, which sources are represented by `RecognizerSources`. |
-| `forward` | `(input: string): Promise` | Executes the model's forward pass, where `input` can be a fetchable resource or a Base64-encoded string. |
-| `onDownloadProgress` | `(callback: (downloadProgress: number) => void): any` | Subscribe to the download progress event.
|
+| Method | Type | Description |
+| -------------------- | ------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------ |
+| `load` | `(model: { detectorSource: ResourceSource; recognizerLarge: ResourceSource; recognizerMedium: ResourceSource; recognizerSmall: ResourceSource; language: OCRLanguage }, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model. `detectorSource` specifies the location of the detector binary; `recognizerLarge`, `recognizerMedium`, and `recognizerSmall` specify the locations of the recognizer binaries accepting input images with widths of 512, 256, and 128 pixels respectively; `language` specifies the language of the text to be recognized by the OCR. |
+| `forward` | `(input: string): Promise` | Executes the model's forward pass, where `input` can be a fetchable resource or a Base64-encoded string. |
+| `onDownloadProgress` | `(callback: (downloadProgress: number) => void): any` | Subscribe to the download progress event. |
Type definitions ```typescript
-interface RecognizerSources {
-  recognizerLarge: string | number;
-  recognizerMedium: string | number;
-  recognizerSmall: string | number;
-}
-
type OCRLanguage = | 'abq' | 'ady' @@ -131,25 +114,21 @@ interface OCRDetection { ## Loading the model
-To load the model, use the `load` method. It accepts:
+To load the model, use the `load` method. It accepts an object:
-**`detectorSource`** - A string that specifies the location of the detector binary. For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) section.
+**`model`** - Object containing the detector source, recognizer sources, and language.
-**`recognizerSources`** - An object that specifies locations of the recognizers binary files. Each recognizer is composed of three models tailored to process images of varying widths.
+- **`detectorSource`** - A string that specifies the location of the detector binary.
+- **`recognizerLarge`** - A string that specifies the location of the recognizer binary file which accepts input images with a width of 512 pixels.
+- **`recognizerMedium`** - A string that specifies the location of the recognizer binary file which accepts input images with a width of 256 pixels.
+- **`recognizerSmall`** - A string that specifies the location of the recognizer binary file which accepts input images with a width of 128 pixels.
+- **`language`** - A parameter that specifies the language of the text to be recognized by the OCR.
-- `recognizerLarge` - A string that specifies the location of the recognizer binary file which accepts input images with a width of 512 pixels.
-- `recognizerMedium` - A string that specifies the location of the recognizer binary file which accepts input images with a width of 256 pixels.
-- `recognizerSmall` - A string that specifies the location of the recognizer binary file which accepts input images with a width of 128 pixels.
-
-For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) section.
-
-**`language`** - A parameter that specifies the language of the text to be recognized by the OCR.
+**`onDownloadProgressCallback`** - (Optional) Function called on download progress.
This method returns a promise, which can resolve to an error or void.
-## Listening for download progress
-
-To subscribe to the download progress event, you can use the `onDownloadProgress` method. It accepts a callback function that will be called whenever the download progress changes.
+For more information on loading resources, take a look at the [loading models](../../01-fundamentals/02-loading-models.md) page.
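Putting the pieces together, a typical flow loads the bundled English model, runs `forward`, and walks the detections. A minimal sketch; the `text` field on each detection is assumed from the `OCRDetection` interface, whose full definition is elided above:

```typescript
import { OCRModule, OCR_ENGLISH } from 'react-native-executorch';

async function readText(imageUri: string) {
  const ocrModule = new OCRModule();
  await ocrModule.load(OCR_ENGLISH, (progress) => console.log('download:', progress));

  const detections = await ocrModule.forward(imageUri);

  // Each detection carries a recognized text fragment
  // (field name assumed from the OCRDetection interface).
  for (const detection of detections) {
    console.log(detection.text);
  }
}
```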
## Running the model diff --git a/docs/docs/03-typescript-api/02-computer-vision/ObjectDetectionModule.md b/docs/docs/03-typescript-api/02-computer-vision/ObjectDetectionModule.md index 9b2ee878f2..07506a5c57 100644 --- a/docs/docs/03-typescript-api/02-computer-vision/ObjectDetectionModule.md +++ b/docs/docs/03-typescript-api/02-computer-vision/ObjectDetectionModule.md @@ -14,22 +14,23 @@ import { const imageUri = 'path/to/image.png'; -const module = new ObjectDetectionModule(); +// Creating an instance +const objectDetectionModule = new ObjectDetectionModule(); // Loading the model -await module.load(SSDLITE_320_MOBILENET_V3_LARGE); +await objectDetectionModule.load(SSDLITE_320_MOBILENET_V3_LARGE); // Running the model -const detections = await module.forward(imageUri); +const detections = await objectDetectionModule.forward(imageUri); ``` ### Methods -| Method | Type | Description | -| --------- | -------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| `load` | `(modelSource: ResourceSource, onDownloadProgressCallback: (_: number) => void () => {}): Promise` | Loads the model, where `modelSource` is a string that specifies the location of the model binary. To track the download progress, supply a callback function `onDownloadProgressCallback`. | -| `forward` | `(input: string, detectionThreshold: number = 0.7): Promise` | Executes the model's forward pass, where `input` can be a fetchable resource or a Base64-encoded string. `detectionThreshold` can be supplied to alter the sensitivity of the detection. | -| `delete` | `(): void` | Release the memory held by the module. Calling `forward` afterwards is invalid. | +| Method | Type | Description | +| --------- | ------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| `load` | `(model: { modelSource: ResourceSource }, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model, where `modelSource` is a string that specifies the location of the model binary. To track the download progress, supply a callback function `onDownloadProgressCallback`. | +| `forward` | `(input: string, detectionThreshold: number = 0.7): Promise` | Executes the model's forward pass, where `input` can be a fetchable resource or a Base64-encoded string. `detectionThreshold` can be supplied to alter the sensitivity of the detection. | +| `delete` | `(): void` | Release the memory held by the module. Calling `forward` afterwards is invalid. |
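The `detectionThreshold` parameter is the main tuning knob: raising it keeps only confident detections. A minimal sketch; the `label` and `score` fields are assumed from the `Detection` interface, whose body is elided below:

```typescript
import {
  ObjectDetectionModule,
  SSDLITE_320_MOBILENET_V3_LARGE,
} from 'react-native-executorch';

async function detectObjects(imageUri: string) {
  const objectDetectionModule = new ObjectDetectionModule();
  await objectDetectionModule.load(SSDLITE_320_MOBILENET_V3_LARGE);

  // A higher threshold than the default 0.7 trades recall for precision.
  const detections = await objectDetectionModule.forward(imageUri, 0.8);

  for (const detection of detections) {
    // Field names assumed from the Detection interface.
    console.log(`${detection.label}: ${detection.score}`);
  }

  // Release native memory; forward is invalid afterwards.
  objectDetectionModule.delete();
}
```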
Type definitions @@ -55,7 +56,17 @@ interface Detection { ## Loading the model
-To load the model, create a new instance of the module and use the `load` method on it. It accepts the `modelSource` which is a string that specifies the location of the model binary. For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. This method returns a promise, which can resolve to an error or void.
+To load the model, create a new instance of the module and use the `load` method on it. It accepts an object:
+
+**`model`** - Object containing the model source.
+
+- **`modelSource`** - A string that specifies the location of the model binary.
+
+**`onDownloadProgressCallback`** - (Optional) Function called on download progress.
+
+This method returns a promise, which can resolve to an error or void.
+
+For more information on loading resources, take a look at the [loading models](../../01-fundamentals/02-loading-models.md) page.
## Running the model
diff --git a/docs/docs/03-typescript-api/02-computer-vision/StyleTransferModule.md b/docs/docs/03-typescript-api/02-computer-vision/StyleTransferModule.md
index 0beabbb863..87ba6c955d 100644
--- a/docs/docs/03-typescript-api/02-computer-vision/StyleTransferModule.md
+++ b/docs/docs/03-typescript-api/02-computer-vision/StyleTransferModule.md
@@ -14,22 +14,23 @@ import { ... const imageUri = 'path/to/image.png';
-const module = new StyleTransferModule();
+// Creating an instance
+const styleTransferModule = new StyleTransferModule();
// Loading the model
-await module.load(STYLE_TRANSFER_CANDY);
+await styleTransferModule.load(STYLE_TRANSFER_CANDY);
// Running the model
-const generatedImageUrl = await module.forward(imageUri);
+const generatedImageUrl = await styleTransferModule.forward(imageUri);
``` ### Methods
-| Method | Type | Description |
-| --------- | ---------------------------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------------------- |
-| `load` | `(modelSource: ResourceSource, onDownloadProgressCallback: (_: number) => void () => {}): Promise` | Loads the model, where `modelSource` is a string that specifies the location of the model binary. To track the download progress, supply a callback function `onDownloadProgressCallback`. |
-| `forward` | `(input: string): Promise` | Executes the model's forward pass, where `input` can be a fetchable resource or a Base64-encoded string. |
-| `delete` | `(): void` | Release the memory held by the module. Calling `forward` afterwards is invalid. |
+| Method | Type | Description |
+| --------- | ------------------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------------------------- |
+| `load` | `(model: { modelSource: ResourceSource }, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model, where `modelSource` is a string that specifies the location of the model binary. To track the download progress, supply a callback function `onDownloadProgressCallback`. |
+| `forward` | `(input: string): Promise` | Executes the model's forward pass, where `input` can be a fetchable resource or a Base64-encoded string.
| +| `delete` | `(): void` | Release the memory held by the module. Calling `forward` afterwards is invalid. |
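Since `forward` resolves to a URL of the generated image, the result can be handed straight to an `Image` component. A minimal sketch of the load, transform, and release lifecycle (whether the generated file remains available after `delete()` is an assumption here):

```typescript
import { StyleTransferModule, STYLE_TRANSFER_CANDY } from 'react-native-executorch';

async function stylize(imageUri: string): Promise<string> {
  const styleTransferModule = new StyleTransferModule();
  await styleTransferModule.load(STYLE_TRANSFER_CANDY, (progress) =>
    console.log('download:', progress)
  );

  // Resolves to a URL pointing at the generated image,
  // e.g. usable as the source of an <Image> component.
  const generatedImageUrl = await styleTransferModule.forward(imageUri);

  // Release the module's memory; the generated file is assumed to
  // remain available at the returned URL.
  styleTransferModule.delete();
  return generatedImageUrl;
}
```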
Type definitions @@ -42,7 +43,17 @@ type ResourceSource = string | number | object; ## Loading the model
-To load the model, create a new instance of the module and use the `load` method on it. It accepts the `modelSource` which is a string that specifies the location of the model binary. For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. This method returns a promise, which can resolve to an error or void.
+To load the model, create a new instance of the module and use the `load` method on it. It accepts an object:
+
+**`model`** - Object containing the model source.
+
+- **`modelSource`** - A string that specifies the location of the model binary.
+
+**`onDownloadProgressCallback`** - (Optional) Function called on download progress.
+
+This method returns a promise, which can resolve to an error or void.
+
+For more information on loading resources, take a look at the [loading models](../../01-fundamentals/02-loading-models.md) page.
## Running the model
diff --git a/docs/docs/03-typescript-api/02-computer-vision/VerticalOCRModule.md b/docs/docs/03-typescript-api/02-computer-vision/VerticalOCRModule.md
index da165b5120..bf3b56c7e5 100644
--- a/docs/docs/03-typescript-api/02-computer-vision/VerticalOCRModule.md
+++ b/docs/docs/03-typescript-api/02-computer-vision/VerticalOCRModule.md
@@ -8,40 +8,29 @@ TypeScript API implementation of the [useVerticalOCR](../../02-hooks/02-computer ```typescript import {
- DETECTOR_CRAFT_1280,
- DETECTOR_CRAFT_320,
- RECOGNIZER_EN_CRNN_512,
- RECOGNIZER_EN_CRNN_64,
- useVerticalOCR,
+ VerticalOCRModule,
+ VERTICAL_OCR_ENGLISH,
} from 'react-native-executorch';
const imageUri = 'path/to/image.png';
+// Creating an instance
+const verticalOCRModule = new VerticalOCRModule();
+
// Loading the model
-await VerticalOCRModule.load({
-  detectorSources: {
-    detectorLarge: DETECTOR_CRAFT_1280,
-    detectorNarrow: DETECTOR_CRAFT_320,
-  },
-  recognizerSources: {
-    recognizerLarge: RECOGNIZER_EN_CRNN_512,
-    recognizerSmall: RECOGNIZER_EN_CRNN_64,
-  },
-  language: 'en',
-  independentCharacters: true,
-});
+await verticalOCRModule.load(VERTICAL_OCR_ENGLISH);
// Running the model
-const ocrDetections = await VerticalOCRModule.forward(imageUri);
+const detections = await verticalOCRModule.forward(imageUri);
``` ### Methods
-| Method | Type | Description |
-| -------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------ |
-| `load` | `(detectorSources: DetectorSources, recognizerSources: RecognizerSources, language: OCRLanguage independentCharacters: boolean): Promise` | Loads detectors and recognizers, which sources are represented by `DetectorSources` and `RecognizerSources`. |
-| `forward` | `(input: string): Promise` | Executes the model's forward pass, where `input` can be a fetchable resource or a Base64-encoded string. |
-| `onDownloadProgress` | `(callback: (downloadProgress: number) => void): any` | Subscribe to the download progress event.
|
+| Method | Type | Description |
+| -------------------- | ------------------------------------------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------ |
+| `load` | `(model: { detectorLarge: ResourceSource; detectorNarrow: ResourceSource; recognizerLarge: ResourceSource; recognizerSmall: ResourceSource; language: OCRLanguage }, independentCharacters: boolean, onDownloadProgressCallback?: (progress: number) => void): Promise` | Loads the model. `detectorLarge` specifies the location of the detector binary which accepts input images with a width of 1280 pixels; `detectorNarrow` specifies the location of the detector binary which accepts input images with a width of 320 pixels; `recognizerLarge` and `recognizerSmall` specify the locations of the recognizer binaries accepting input images with widths of 512 and 64 pixels respectively; `language` specifies the language of the text to be recognized by the OCR. |
+| `forward` | `(input: string): Promise` | Executes the model's forward pass, where `input` can be a fetchable resource or a Base64-encoded string. |
+| `onDownloadProgress` | `(callback: (downloadProgress: number) => void): any` | Subscribe to the download progress event. |
Type definitions @@ -141,29 +130,21 @@ interface OCRDetection { To load the model, use the `load` method. It accepts:
-**`detectorSources`** - An object that specifies the location of the detectors binary files. Each detector is composed of two models tailored to process images of varying widths.
-
-- `detectorLarge` - A string that specifies the location of the recognizer binary file which accepts input images with a width of 1280 pixels.
-- `detectorNarrow` - A string that specifies the location of the detector binary file which accepts input images with a width of 320 pixels.
-
-For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) section.
-
-**`recognizerSources`** - An object that specifies the locations of the recognizers binary files. Each recognizer is composed of two models tailored to process images of varying widths.
+**`model`** - Object containing the detector sources, recognizer sources, and language.
-- `recognizerLarge` - A string that specifies the location of the recognizer binary file which accepts input images with a width of 512 pixels.
-- `recognizerSmall` - A string that specifies the location of the recognizer binary file which accepts input images with a width of 64 pixels.
-
-For more information, take a look at [loading models](../../01-fundamentals/02-loading-models.md) section.
-
-**`language`** - A parameter that specifies the language of the text to be recognized by the OCR.
+- **`detectorLarge`** - A string that specifies the location of the detector binary file which accepts input images with a width of 1280 pixels.
+- **`detectorNarrow`** - A string that specifies the location of the detector binary file which accepts input images with a width of 320 pixels.
+- **`recognizerLarge`** - A string that specifies the location of the recognizer binary file which accepts input images with a width of 512 pixels.
+- **`recognizerSmall`** - A string that specifies the location of the recognizer binary file which accepts input images with a width of 64 pixels.
+- **`language`** - A parameter that specifies the language of the text to be recognized by the OCR.
**`independentCharacters`** – A boolean parameter that indicates whether the text in the image consists of a random sequence of characters. If set to true, the algorithm will scan each character individually instead of reading them as continuous text.
-This method returns a promise, which can resolve to an error or void.
+**`onDownloadProgressCallback`** - (Optional) Function called on download progress.
-## Listening for download progress
+This method returns a promise, which can resolve to an error or void.
-To subscribe to the download progress event, you can use the `onDownloadProgress` method. It accepts a callback function that will be called whenever the download progress changes.
+For more information on loading resources, take a look at the [loading models](../../01-fundamentals/02-loading-models.md) page.
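The extra positional `independentCharacters` argument is what distinguishes this `load` from the other modules. A minimal sketch for scanning text where characters do not form words, such as serial numbers:

```typescript
import {
  VerticalOCRModule,
  VERTICAL_OCR_ENGLISH,
} from 'react-native-executorch';

async function readVerticalText(imageUri: string) {
  const verticalOCRModule = new VerticalOCRModule();

  // `true` makes the algorithm scan each character individually
  // instead of reading them as continuous text.
  await verticalOCRModule.load(VERTICAL_OCR_ENGLISH, true, (progress) =>
    console.log('download:', progress)
  );

  const detections = await verticalOCRModule.forward(imageUri);
  console.log(detections);
}
```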
## Running the model diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index 71730b1857..574e1945be 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -1,212 +1,391 @@ import { Platform } from 'react-native'; +import { AvailableModels } from '../types/stt'; -// LLM's +const URL_PREFIX = + 'https://huggingface.co/software-mansion/react-native-executorch'; +const VERSION_TAG = 'resolve/v0.4.0'; +const NEXT_VERSION_TAG = 'resolve/v0.5.0'; + +// LLMs // LLAMA 3.2 -export const LLAMA3_2_3B = - 'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.4.0/llama-3.2-3B/original/llama3_2_3B_bf16.pte'; -export const LLAMA3_2_3B_QLORA = - 'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.4.0/llama-3.2-3B/QLoRA/llama3_2-3B_qat_lora.pte'; -export const LLAMA3_2_3B_SPINQUANT = - 'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.4.0/llama-3.2-3B/spinquant/llama3_2_3B_spinquant.pte'; -export const LLAMA3_2_1B = - 'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.4.0/llama-3.2-1B/original/llama3_2_bf16.pte'; -export const LLAMA3_2_1B_QLORA = - 'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.4.0/llama-3.2-1B/QLoRA/llama3_2_qat_lora.pte'; -export const LLAMA3_2_1B_SPINQUANT = - 'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.4.0/llama-3.2-1B/spinquant/llama3_2_spinquant.pte'; -export const LLAMA3_2_TOKENIZER = - 'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.4.0/tokenizer.json'; -export const LLAMA3_2_TOKENIZER_CONFIG = - 'https://huggingface.co/software-mansion/react-native-executorch-llama-3.2/resolve/v0.4.0/tokenizer_config.json'; +const LLAMA3_2_3B_MODEL = `${URL_PREFIX}-llama-3.2/${VERSION_TAG}/llama-3.2-3B/original/llama3_2_3B_bf16.pte`; +const LLAMA3_2_3B_QLORA_MODEL = `${URL_PREFIX}-llama-3.2/${VERSION_TAG}/llama-3.2-3B/QLoRA/llama3_2-3B_qat_lora.pte`; +const LLAMA3_2_3B_SPINQUANT_MODEL = `${URL_PREFIX}-llama-3.2/${VERSION_TAG}/llama-3.2-3B/spinquant/llama3_2_3B_spinquant.pte`; +const LLAMA3_2_1B_MODEL = `${URL_PREFIX}-llama-3.2/${VERSION_TAG}/llama-3.2-1B/original/llama3_2_bf16.pte`; +const LLAMA3_2_1B_QLORA_MODEL = `${URL_PREFIX}-llama-3.2/${VERSION_TAG}/llama-3.2-1B/QLoRA/llama3_2_qat_lora.pte`; +const LLAMA3_2_1B_SPINQUANT_MODEL = `${URL_PREFIX}-llama-3.2/${VERSION_TAG}/llama-3.2-1B/spinquant/llama3_2_spinquant.pte`; +const LLAMA3_2_TOKENIZER = `${URL_PREFIX}-llama-3.2/${VERSION_TAG}/tokenizer.json`; +const LLAMA3_2_TOKENIZER_CONFIG = `${URL_PREFIX}-llama-3.2/${VERSION_TAG}/tokenizer_config.json`; + +export const LLAMA3_2_3B = { + modelSource: LLAMA3_2_3B_MODEL, + tokenizerSource: LLAMA3_2_TOKENIZER, + tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, +}; + +export const LLAMA3_2_3B_QLORA = { + modelSource: LLAMA3_2_3B_QLORA_MODEL, + tokenizerSource: LLAMA3_2_TOKENIZER, + tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, +}; + +export const LLAMA3_2_3B_SPINQUANT = { + modelSource: LLAMA3_2_3B_SPINQUANT_MODEL, + tokenizerSource: LLAMA3_2_TOKENIZER, + tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, +}; + +export const LLAMA3_2_1B = { + modelSource: LLAMA3_2_1B_MODEL, + tokenizerSource: LLAMA3_2_TOKENIZER, + tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, +}; + +export const 
LLAMA3_2_1B_QLORA = { + modelSource: LLAMA3_2_1B_QLORA_MODEL, + tokenizerSource: LLAMA3_2_TOKENIZER, + tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, +}; + +export const LLAMA3_2_1B_SPINQUANT = { + modelSource: LLAMA3_2_1B_SPINQUANT_MODEL, + tokenizerSource: LLAMA3_2_TOKENIZER, + tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, +}; // QWEN 3 -export const QWEN3_0_6B = - 'https://huggingface.co/software-mansion/react-native-executorch-qwen-3/resolve/v0.4.0/qwen-3-0.6B/original/qwen3_0_6b_bf16.pte'; -export const QWEN3_0_6B_QUANTIZED = - 'https://huggingface.co/software-mansion/react-native-executorch-qwen-3/resolve/v0.4.0/qwen-3-0.6B/quantized/qwen3_0_6b_8da4w.pte'; -export const QWEN3_1_7B = - 'https://huggingface.co/software-mansion/react-native-executorch-qwen-3/resolve/v0.4.0/qwen-3-1.7B/original/qwen3_1_7b_bf16.pte'; -export const QWEN3_1_7B_QUANTIZED = - 'https://huggingface.co/software-mansion/react-native-executorch-qwen-3/resolve/v0.4.0/qwen-3-1.7B/quantized/qwen3_1_7b_8da4w.pte'; -export const QWEN3_4B = - 'https://huggingface.co/software-mansion/react-native-executorch-qwen-3/resolve/v0.4.0/qwen-3-4B/original/qwen3_4b_bf16.pte'; -export const QWEN3_4B_QUANTIZED = - 'https://huggingface.co/software-mansion/react-native-executorch-qwen-3/resolve/v0.4.0/qwen-3-4B/quantized/qwen3_4b_8da4w.pte'; -export const QWEN3_TOKENIZER = - 'https://huggingface.co/software-mansion/react-native-executorch-qwen-3/resolve/v0.4.0/tokenizer.json'; -export const QWEN3_TOKENIZER_CONFIG = - 'https://huggingface.co/software-mansion/react-native-executorch-qwen-3/resolve/v0.4.0/tokenizer_config.json'; +const QWEN3_0_6B_MODEL = `${URL_PREFIX}-qwen-3/${VERSION_TAG}/qwen-3-0.6B/original/qwen3_0_6b_bf16.pte`; +const QWEN3_0_6B_QUANTIZED_MODEL = `${URL_PREFIX}-qwen-3/${VERSION_TAG}/qwen-3-0.6B/quantized/qwen3_0_6b_8da4w.pte`; +const QWEN3_1_7B_MODEL = `${URL_PREFIX}-qwen-3/${VERSION_TAG}/qwen-3-1.7B/original/qwen3_1_7b_bf16.pte`; +const QWEN3_1_7B_QUANTIZED_MODEL = `${URL_PREFIX}-qwen-3/${VERSION_TAG}/qwen-3-1.7B/quantized/qwen3_1_7b_8da4w.pte`; +const QWEN3_4B_MODEL = `${URL_PREFIX}-qwen-3/${VERSION_TAG}/qwen-3-4B/original/qwen3_4b_bf16.pte`; +const QWEN3_4B_QUANTIZED_MODEL = `${URL_PREFIX}-qwen-3/${VERSION_TAG}/qwen-3-4B/quantized/qwen3_4b_8da4w.pte`; +const QWEN3_TOKENIZER = `${URL_PREFIX}-qwen-3/${VERSION_TAG}/tokenizer.json`; +const QWEN3_TOKENIZER_CONFIG = `${URL_PREFIX}-qwen-3/${VERSION_TAG}/tokenizer_config.json`; + +export const QWEN3_0_6B = { + modelSource: QWEN3_0_6B_MODEL, + tokenizerSource: QWEN3_TOKENIZER, + tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, +}; + +export const QWEN3_0_6B_QUANTIZED = { + modelSource: QWEN3_0_6B_QUANTIZED_MODEL, + tokenizerSource: QWEN3_TOKENIZER, + tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, +}; + +export const QWEN3_1_7B = { + modelSource: QWEN3_1_7B_MODEL, + tokenizerSource: QWEN3_TOKENIZER, + tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, +}; + +export const QWEN3_1_7B_QUANTIZED = { + modelSource: QWEN3_1_7B_QUANTIZED_MODEL, + tokenizerSource: QWEN3_TOKENIZER, + tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, +}; + +export const QWEN3_4B = { + modelSource: QWEN3_4B_MODEL, + tokenizerSource: QWEN3_TOKENIZER, + tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, +}; + +export const QWEN3_4B_QUANTIZED = { + modelSource: QWEN3_4B_QUANTIZED_MODEL, + tokenizerSource: QWEN3_TOKENIZER, + tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, +}; // HAMMER 2.1 -export const HAMMER2_1_0_5B = - 
'https://huggingface.co/software-mansion/react-native-executorch-hammer-2.1/resolve/v0.4.0/hammer-2.1-0.5B/original/hammer2_1_0_5B_bf16.pte'; -export const HAMMER2_1_0_5B_QUANTIZED = - 'https://huggingface.co/software-mansion/react-native-executorch-hammer-2.1/resolve/v0.4.0/hammer-2.1-0.5B/quantized/hammer2_1_0_5B_8da4w.pte'; -export const HAMMER2_1_1_5B = - 'https://huggingface.co/software-mansion/react-native-executorch-hammer-2.1/resolve/v0.4.0/hammer-2.1-1.5B/original/hammer2_1_1_5B_bf16.pte'; -export const HAMMER2_1_1_5B_QUANTIZED = - 'https://huggingface.co/software-mansion/react-native-executorch-hammer-2.1/resolve/v0.4.0/hammer-2.1-1.5B/quantized/hammer2_1_1_5B_8da4w.pte'; -export const HAMMER2_1_3B = - 'https://huggingface.co/software-mansion/react-native-executorch-hammer-2.1/resolve/v0.4.0/hammer-2.1-3B/original/hammer2_1_3B_bf16.pte'; -export const HAMMER2_1_3B_QUANTIZED = - 'https://huggingface.co/software-mansion/react-native-executorch-hammer-2.1/resolve/v0.4.0/hammer-2.1-3B/quantized/hammer2_1_3B_8da4w.pte'; -export const HAMMER2_1_TOKENIZER = - 'https://huggingface.co/software-mansion/react-native-executorch-hammer-2.1/resolve/v0.4.0/tokenizer.json'; -export const HAMMER2_1_TOKENIZER_CONFIG = - 'https://huggingface.co/software-mansion/react-native-executorch-hammer-2.1/resolve/v0.4.0/tokenizer_config.json'; +const HAMMER2_1_0_5B_MODEL = `${URL_PREFIX}-hammer-2.1/${VERSION_TAG}/hammer-2.1-0.5B/original/hammer2_1_0_5B_bf16.pte`; +const HAMMER2_1_0_5B_QUANTIZED_MODEL = `${URL_PREFIX}-hammer-2.1/${VERSION_TAG}/hammer-2.1-0.5B/quantized/hammer2_1_0_5B_8da4w.pte`; +const HAMMER2_1_1_5B_MODEL = `${URL_PREFIX}-hammer-2.1/${VERSION_TAG}/hammer-2.1-1.5B/original/hammer2_1_1_5B_bf16.pte`; +const HAMMER2_1_1_5B_QUANTIZED_MODEL = `${URL_PREFIX}-hammer-2.1/${VERSION_TAG}/hammer-2.1-1.5B/quantized/hammer2_1_1_5B_8da4w.pte`; +const HAMMER2_1_3B_MODEL = `${URL_PREFIX}-hammer-2.1/${VERSION_TAG}/hammer-2.1-3B/original/hammer2_1_3B_bf16.pte`; +const HAMMER2_1_3B_QUANTIZED_MODEL = `${URL_PREFIX}-hammer-2.1/${VERSION_TAG}/hammer-2.1-3B/quantized/hammer2_1_3B_8da4w.pte`; +const HAMMER2_1_TOKENIZER = `${URL_PREFIX}-hammer-2.1/${VERSION_TAG}/tokenizer.json`; +const HAMMER2_1_TOKENIZER_CONFIG = `${URL_PREFIX}-hammer-2.1/${VERSION_TAG}/tokenizer_config.json`; + +export const HAMMER2_1_0_5B = { + modelSource: HAMMER2_1_0_5B_MODEL, + tokenizerSource: HAMMER2_1_TOKENIZER, + tokenizerConfigSource: HAMMER2_1_TOKENIZER_CONFIG, +}; + +export const HAMMER2_1_0_5B_QUANTIZED = { + modelSource: HAMMER2_1_0_5B_QUANTIZED_MODEL, + tokenizerSource: HAMMER2_1_TOKENIZER, + tokenizerConfigSource: HAMMER2_1_TOKENIZER_CONFIG, +}; + +export const HAMMER2_1_1_5B = { + modelSource: HAMMER2_1_1_5B_MODEL, + tokenizerSource: HAMMER2_1_TOKENIZER, + tokenizerConfigSource: HAMMER2_1_TOKENIZER_CONFIG, +}; + +export const HAMMER2_1_1_5B_QUANTIZED = { + modelSource: HAMMER2_1_1_5B_QUANTIZED_MODEL, + tokenizerSource: HAMMER2_1_TOKENIZER, + tokenizerConfigSource: HAMMER2_1_TOKENIZER_CONFIG, +}; + +export const HAMMER2_1_3B = { + modelSource: HAMMER2_1_3B_MODEL, + tokenizerSource: HAMMER2_1_TOKENIZER, + tokenizerConfigSource: HAMMER2_1_TOKENIZER_CONFIG, +}; + +export const HAMMER2_1_3B_QUANTIZED = { + modelSource: HAMMER2_1_3B_QUANTIZED_MODEL, + tokenizerSource: HAMMER2_1_TOKENIZER, + tokenizerConfigSource: HAMMER2_1_TOKENIZER_CONFIG, +}; // SMOLLM2 -export const SMOLLM2_1_135M = - 'https://huggingface.co/software-mansion/react-native-executorch-smolLm-2/resolve/v0.4.0/smolLm-2-135M/original/smolLm2_135M_bf16.pte'; -export const 
SMOLLM2_1_135M_QUANTIZED = - 'https://huggingface.co/software-mansion/react-native-executorch-smolLm-2/resolve/v0.4.0/smolLm-2-135M/quantized/smolLm2_135M_8da4w.pte'; -export const SMOLLM2_1_360M = - 'https://huggingface.co/software-mansion/react-native-executorch-smolLm-2/resolve/v0.4.0/smolLm-2-360M/original/smolLm2_360M_bf16.pte'; -export const SMOLLM2_1_360M_QUANTIZED = - 'https://huggingface.co/software-mansion/react-native-executorch-smolLm-2/resolve/v0.4.0/smolLm-2-360M/quantized/smolLm2_360M_8da4w.pte'; -export const SMOLLM2_1_1_7B = - 'https://huggingface.co/software-mansion/react-native-executorch-smolLm-2/resolve/v0.4.0/smolLm-2-1.7B/original/smolLm2_1_7B_bf16.pte'; -export const SMOLLM2_1_1_7B_QUANTIZED = - 'https://huggingface.co/software-mansion/react-native-executorch-smolLm-2/resolve/v0.4.0/smolLm-2-1.7B/quantized/smolLm2_1_7B_8da4w.pte'; -export const SMOLLM2_1_TOKENIZER = - 'https://huggingface.co/software-mansion/react-native-executorch-smolLm-2/resolve/v0.4.0/tokenizer.json'; -export const SMOLLM2_1_TOKENIZER_CONFIG = - 'https://huggingface.co/software-mansion/react-native-executorch-smolLm-2/resolve/v0.4.0/tokenizer_config.json'; +const SMOLLM2_1_135M_MODEL = `${URL_PREFIX}-smolLm-2/${VERSION_TAG}/smolLm-2-135M/original/smolLm2_135M_bf16.pte`; +const SMOLLM2_1_135M_QUANTIZED_MODEL = `${URL_PREFIX}-smolLm-2/${VERSION_TAG}/smolLm-2-135M/quantized/smolLm2_135M_8da4w.pte`; +const SMOLLM2_1_360M_MODEL = `${URL_PREFIX}-smolLm-2/${VERSION_TAG}/smolLm-2-360M/original/smolLm2_360M_bf16.pte`; +const SMOLLM2_1_360M_QUANTIZED_MODEL = `${URL_PREFIX}-smolLm-2/${VERSION_TAG}/smolLm-2-360M/quantized/smolLm2_360M_8da4w.pte`; +const SMOLLM2_1_1_7B_MODEL = `${URL_PREFIX}-smolLm-2/${VERSION_TAG}/smolLm-2-1.7B/original/smolLm2_1_7B_bf16.pte`; +const SMOLLM2_1_1_7B_QUANTIZED_MODEL = `${URL_PREFIX}-smolLm-2/${VERSION_TAG}/smolLm-2-1.7B/quantized/smolLm2_1_7B_8da4w.pte`; +const SMOLLM2_1_TOKENIZER = `${URL_PREFIX}-smolLm-2/${VERSION_TAG}/tokenizer.json`; +const SMOLLM2_1_TOKENIZER_CONFIG = `${URL_PREFIX}-smolLm-2/${VERSION_TAG}/tokenizer_config.json`; + +export const SMOLLM2_1_135M = { + modelSource: SMOLLM2_1_135M_MODEL, + tokenizerSource: SMOLLM2_1_TOKENIZER, + tokenizerConfigSource: SMOLLM2_1_TOKENIZER_CONFIG, +}; + +export const SMOLLM2_1_135M_QUANTIZED = { + modelSource: SMOLLM2_1_135M_QUANTIZED_MODEL, + tokenizerSource: SMOLLM2_1_TOKENIZER, + tokenizerConfigSource: SMOLLM2_1_TOKENIZER_CONFIG, +}; + +export const SMOLLM2_1_360M = { + modelSource: SMOLLM2_1_360M_MODEL, + tokenizerSource: SMOLLM2_1_TOKENIZER, + tokenizerConfigSource: SMOLLM2_1_TOKENIZER_CONFIG, +}; + +export const SMOLLM2_1_360M_QUANTIZED = { + modelSource: SMOLLM2_1_360M_QUANTIZED_MODEL, + tokenizerSource: SMOLLM2_1_TOKENIZER, + tokenizerConfigSource: SMOLLM2_1_TOKENIZER_CONFIG, +}; + +export const SMOLLM2_1_1_7B = { + modelSource: SMOLLM2_1_1_7B_MODEL, + tokenizerSource: SMOLLM2_1_TOKENIZER, + tokenizerConfigSource: SMOLLM2_1_TOKENIZER_CONFIG, +}; + +export const SMOLLM2_1_1_7B_QUANTIZED = { + modelSource: SMOLLM2_1_1_7B_QUANTIZED_MODEL, + tokenizerSource: SMOLLM2_1_TOKENIZER, + tokenizerConfigSource: SMOLLM2_1_TOKENIZER_CONFIG, +}; // QWEN 2.5 -export const QWEN2_5_0_5B = - 'https://huggingface.co/software-mansion/react-native-executorch-qwen-2.5/resolve/v0.4.0/qwen-2.5-0.5B/original/qwen2_5_0_5b_bf16.pte'; -export const QWEN2_5_0_5B_QUANTIZED = - 'https://huggingface.co/software-mansion/react-native-executorch-qwen-2.5/resolve/v0.4.0/qwen-2.5-0.5B/quantized/qwen2_5_0_5b_8da4w.pte'; -export const QWEN2_5_1_5B = - 
'https://huggingface.co/software-mansion/react-native-executorch-qwen-2.5/resolve/v0.4.0/qwen-2.5-1.5B/original/qwen2_5_1_5b_bf16.pte'; -export const QWEN2_5_1_5B_QUANTIZED = - 'https://huggingface.co/software-mansion/react-native-executorch-qwen-2.5/resolve/v0.4.0/qwen-2.5-1.5B/quantized/qwen2_5_1_5b_8da4w.pte'; -export const QWEN2_5_3B = - 'https://huggingface.co/software-mansion/react-native-executorch-qwen-2.5/resolve/v0.4.0/qwen-2.5-3B/original/qwen2_5_3b_bf16.pte'; -export const QWEN2_5_3B_QUANTIZED = - 'https://huggingface.co/software-mansion/react-native-executorch-qwen-2.5/resolve/v0.4.0/qwen-2.5-3B/quantized/qwen2_5_3b_8da4w.pte'; -export const QWEN2_5_TOKENIZER = - 'https://huggingface.co/software-mansion/react-native-executorch-qwen-2.5/resolve/v0.4.0/tokenizer.json'; -export const QWEN2_5_TOKENIZER_CONFIG = - 'https://huggingface.co/software-mansion/react-native-executorch-qwen-2.5/resolve/v0.4.0/tokenizer_config.json'; +const QWEN2_5_0_5B_MODEL = `${URL_PREFIX}-qwen-2.5/${VERSION_TAG}/qwen-2.5-0.5B/original/qwen2_5_0_5b_bf16.pte`; +const QWEN2_5_0_5B_QUANTIZED_MODEL = `${URL_PREFIX}-qwen-2.5/${VERSION_TAG}/qwen-2.5-0.5B/quantized/qwen2_5_0_5b_8da4w.pte`; +const QWEN2_5_1_5B_MODEL = `${URL_PREFIX}-qwen-2.5/${VERSION_TAG}/qwen-2.5-1.5B/original/qwen2_5_1_5b_bf16.pte`; +const QWEN2_5_1_5B_QUANTIZED_MODEL = `${URL_PREFIX}-qwen-2.5/${VERSION_TAG}/qwen-2.5-1.5B/quantized/qwen2_5_1_5b_8da4w.pte`; +const QWEN2_5_3B_MODEL = `${URL_PREFIX}-qwen-2.5/${VERSION_TAG}/qwen-2.5-3B/original/qwen2_5_3b_bf16.pte`; +const QWEN2_5_3B_QUANTIZED_MODEL = `${URL_PREFIX}-qwen-2.5/${VERSION_TAG}/qwen-2.5-3B/quantized/qwen2_5_3b_8da4w.pte`; +const QWEN2_5_TOKENIZER = `${URL_PREFIX}-qwen-2.5/${VERSION_TAG}/tokenizer.json`; +const QWEN2_5_TOKENIZER_CONFIG = `${URL_PREFIX}-qwen-2.5/${VERSION_TAG}/tokenizer_config.json`; + +export const QWEN2_5_0_5B = { + modelSource: QWEN2_5_0_5B_MODEL, + tokenizerSource: QWEN2_5_TOKENIZER, + tokenizerConfigSource: QWEN2_5_TOKENIZER_CONFIG, +}; + +export const QWEN2_5_0_5B_QUANTIZED = { + modelSource: QWEN2_5_0_5B_QUANTIZED_MODEL, + tokenizerSource: QWEN2_5_TOKENIZER, + tokenizerConfigSource: QWEN2_5_TOKENIZER_CONFIG, +}; + +export const QWEN2_5_1_5B = { + modelSource: QWEN2_5_1_5B_MODEL, + tokenizerSource: QWEN2_5_TOKENIZER, + tokenizerConfigSource: QWEN2_5_TOKENIZER_CONFIG, +}; + +export const QWEN2_5_1_5B_QUANTIZED = { + modelSource: QWEN2_5_1_5B_QUANTIZED_MODEL, + tokenizerSource: QWEN2_5_TOKENIZER, + tokenizerConfigSource: QWEN2_5_TOKENIZER_CONFIG, +}; + +export const QWEN2_5_3B = { + modelSource: QWEN2_5_3B_MODEL, + tokenizerSource: QWEN2_5_TOKENIZER, + tokenizerConfigSource: QWEN2_5_TOKENIZER_CONFIG, +}; + +export const QWEN2_5_3B_QUANTIZED = { + modelSource: QWEN2_5_3B_QUANTIZED_MODEL, + tokenizerSource: QWEN2_5_TOKENIZER, + tokenizerConfigSource: QWEN2_5_TOKENIZER_CONFIG, +}; // PHI 4 -export const PHI_4_MINI_4B = - 'https://huggingface.co/software-mansion/react-native-executorch-phi-4-mini/resolve/v0.4.0/original/phi-4-mini_bf16.pte'; -export const PHI_4_MINI_4B_QUANTIZED = - 'https://huggingface.co/software-mansion/react-native-executorch-phi-4-mini/resolve/v0.4.0/quantized/phi-4-mini_8da4w.pte'; -export const PHI_4_MINI_TOKENIZER = - 'https://huggingface.co/software-mansion/react-native-executorch-phi-4-mini/resolve/v0.4.0/tokenizer.json'; -export const PHI_4_MINI_TOKENIZER_CONFIG = - 'https://huggingface.co/software-mansion/react-native-executorch-phi-4-mini/resolve/v0.4.0/tokenizer_config.json'; +const PHI_4_MINI_4B_MODEL = 
`${URL_PREFIX}-phi-4-mini/${VERSION_TAG}/original/phi-4-mini_bf16.pte`; +const PHI_4_MINI_4B_QUANTIZED_MODEL = `${URL_PREFIX}-phi-4-mini/${VERSION_TAG}/quantized/phi-4-mini_8da4w.pte`; +const PHI_4_MINI_TOKENIZER = `${URL_PREFIX}-phi-4-mini/${VERSION_TAG}/tokenizer.json`; +const PHI_4_MINI_TOKENIZER_CONFIG = `${URL_PREFIX}-phi-4-mini/${VERSION_TAG}/tokenizer_config.json`; + +export const PHI_4_MINI_4B = { + modelSource: PHI_4_MINI_4B_MODEL, + tokenizerSource: PHI_4_MINI_TOKENIZER, + tokenizerConfigSource: PHI_4_MINI_TOKENIZER_CONFIG, +}; + +export const PHI_4_MINI_4B_QUANTIZED = { + modelSource: PHI_4_MINI_4B_QUANTIZED_MODEL, + tokenizerSource: PHI_4_MINI_TOKENIZER, + tokenizerConfigSource: PHI_4_MINI_TOKENIZER_CONFIG, +}; // Classification -export const EFFICIENTNET_V2_S = - Platform.OS === 'ios' - ? 'https://huggingface.co/software-mansion/react-native-executorch-efficientnet-v2-s/resolve/v0.4.0/coreml/efficientnet_v2_s_coreml_all.pte' - : 'https://huggingface.co/software-mansion/react-native-executorch-efficientnet-v2-s/resolve/v0.4.0/xnnpack/efficientnet_v2_s_xnnpack.pte'; +const EFFICIENTNET_V2_S_MODEL = + Platform.OS === 'ios' + ? `${URL_PREFIX}-efficientnet-v2-s/${VERSION_TAG}/coreml/efficientnet_v2_s_coreml_all.pte` + : `${URL_PREFIX}-efficientnet-v2-s/${VERSION_TAG}/xnnpack/efficientnet_v2_s_xnnpack.pte`; + +export const EFFICIENTNET_V2_S = { + modelSource: EFFICIENTNET_V2_S_MODEL, +}; // Object detection -export const SSDLITE_320_MOBILENET_V3_LARGE = - 'https://huggingface.co/software-mansion/react-native-executorch-ssdlite320-mobilenet-v3-large/resolve/v0.4.0/ssdlite320-mobilenetv3-large.pte'; +const SSDLITE_320_MOBILENET_V3_LARGE_MODEL = `${URL_PREFIX}-ssdlite320-mobilenet-v3-large/${VERSION_TAG}/ssdlite320-mobilenetv3-large.pte`; + +export const SSDLITE_320_MOBILENET_V3_LARGE = { + modelSource: SSDLITE_320_MOBILENET_V3_LARGE_MODEL, +}; // Style transfer -export const STYLE_TRANSFER_CANDY = - Platform.OS === 'ios' - ? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-candy/resolve/v0.4.0/coreml/style_transfer_candy_coreml.pte' - : 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-candy/resolve/v0.4.0/xnnpack/style_transfer_candy_xnnpack.pte'; -export const STYLE_TRANSFER_MOSAIC = - Platform.OS === 'ios' - ? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-mosaic/resolve/v0.4.0/coreml/style_transfer_mosaic_coreml.pte' - : 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-mosaic/resolve/v0.4.0/xnnpack/style_transfer_mosaic_xnnpack.pte'; -export const STYLE_TRANSFER_RAIN_PRINCESS = - Platform.OS === 'ios' - ? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-rain-princess/resolve/v0.4.0/coreml/style_transfer_rain_princess_coreml.pte' - : 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-rain-princess/resolve/v0.4.0/xnnpack/style_transfer_rain_princess_xnnpack.pte'; -export const STYLE_TRANSFER_UDNIE = - Platform.OS === 'ios' - ? 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/v0.4.0/coreml/style_transfer_udnie_coreml.pte' - : 'https://huggingface.co/software-mansion/react-native-executorch-style-transfer-udnie/resolve/v0.4.0/xnnpack/style_transfer_udnie_xnnpack.pte'; +const STYLE_TRANSFER_CANDY_MODEL = + Platform.OS === 'ios' + ? 
`${URL_PREFIX}-style-transfer-candy/${VERSION_TAG}/coreml/style_transfer_candy_coreml.pte` + : `${URL_PREFIX}-style-transfer-candy/${VERSION_TAG}/xnnpack/style_transfer_candy_xnnpack.pte`; +const STYLE_TRANSFER_MOSAIC_MODEL = + Platform.OS === 'ios' + ? `${URL_PREFIX}-style-transfer-mosaic/${VERSION_TAG}/coreml/style_transfer_mosaic_coreml.pte` + : `${URL_PREFIX}-style-transfer-mosaic/${VERSION_TAG}/xnnpack/style_transfer_mosaic_xnnpack.pte`; +const STYLE_TRANSFER_RAIN_PRINCESS_MODEL = + Platform.OS === 'ios' + ? `${URL_PREFIX}-style-transfer-rain-princess/${VERSION_TAG}/coreml/style_transfer_rain_princess_coreml.pte` + : `${URL_PREFIX}-style-transfer-rain-princess/${VERSION_TAG}/xnnpack/style_transfer_rain_princess_xnnpack.pte`; +const STYLE_TRANSFER_UDNIE_MODEL = + Platform.OS === 'ios' + ? `${URL_PREFIX}-style-transfer-udnie/${VERSION_TAG}/coreml/style_transfer_udnie_coreml.pte` + : `${URL_PREFIX}-style-transfer-udnie/${VERSION_TAG}/xnnpack/style_transfer_udnie_xnnpack.pte`; + +export const STYLE_TRANSFER_CANDY = { + modelSource: STYLE_TRANSFER_CANDY_MODEL, +}; + +export const STYLE_TRANSFER_MOSAIC = { + modelSource: STYLE_TRANSFER_MOSAIC_MODEL, +}; + +export const STYLE_TRANSFER_RAIN_PRINCESS = { + modelSource: STYLE_TRANSFER_RAIN_PRINCESS_MODEL, +}; + +export const STYLE_TRANSFER_UDNIE = { + modelSource: STYLE_TRANSFER_UDNIE_MODEL, +}; // S2T -export const MOONSHINE_TINY_DECODER = - 'https://huggingface.co/software-mansion/react-native-executorch-moonshine-tiny/resolve/v0.4.0/xnnpack/moonshine_tiny_xnnpack_decoder.pte'; -export const MOONSHINE_TINY_ENCODER = - 'https://huggingface.co/software-mansion/react-native-executorch-moonshine-tiny/resolve/v0.4.0/xnnpack/moonshine_tiny_xnnpack_encoder.pte'; -export const MOONSHINE_TOKENIZER = - 'https://huggingface.co/software-mansion/react-native-executorch-moonshine-tiny/resolve/v0.4.0/moonshine_tiny_tokenizer.json'; -export const WHISPER_TOKENIZER = - 'https://huggingface.co/software-mansion/react-native-executorch-whisper-tiny.en/resolve/v0.4.0/whisper_tokenizer.json'; -export const WHISPER_TINY_DECODER = - 'https://huggingface.co/software-mansion/react-native-executorch-whisper-tiny.en/resolve/v0.4.0/xnnpack/whisper_tiny_en_xnnpack_decoder.pte'; -export const WHISPER_TINY_ENCODER = - 'https://huggingface.co/software-mansion/react-native-executorch-whisper-tiny.en/resolve/v0.4.0/xnnpack/whisper_tiny_en_xnnpack_encoder.pte'; -export const WHISPER_TINY_MULTILINGUAL_ENCODER = - 'https://huggingface.co/software-mansion/react-native-executorch-whisper-tiny/resolve/v0.4.0/xnnpack/xnnpack_whisper_encoder.pte'; -export const WHISPER_TINY_MULTILINGUAL_DECODER = - 'https://huggingface.co/software-mansion/react-native-executorch-whisper-tiny/resolve/v0.4.0/xnnpack/xnnpack_whisper_decoder.pte'; -export const WHISPER_TINY_MULTILINGUAL_TOKENIZER = - 'https://huggingface.co/software-mansion/react-native-executorch-whisper-tiny/resolve/v0.4.0/tokenizer.json'; - -// OCR -export const DETECTOR_CRAFT_1280 = - 'https://huggingface.co/software-mansion/react-native-executorch-detector-craft/resolve/v0.4.0/xnnpack/xnnpack_craft_1280.pte'; -export const DETECTOR_CRAFT_800 = - 'https://huggingface.co/software-mansion/react-native-executorch-detector-craft/resolve/v0.4.0/xnnpack/xnnpack_craft_800.pte'; -export const DETECTOR_CRAFT_320 = - 'https://huggingface.co/software-mansion/react-native-executorch-detector-craft/resolve/v0.4.0/xnnpack/xnnpack_craft_320.pte'; +const MOONSHINE_TINY_DECODER_MODEL = 
`${URL_PREFIX}-moonshine-tiny/${VERSION_TAG}/xnnpack/moonshine_tiny_xnnpack_decoder.pte`; +const MOONSHINE_TINY_ENCODER_MODEL = `${URL_PREFIX}-moonshine-tiny/${VERSION_TAG}/xnnpack/moonshine_tiny_xnnpack_encoder.pte`; +const MOONSHINE_TOKENIZER = `${URL_PREFIX}-moonshine-tiny/${VERSION_TAG}/moonshine_tiny_tokenizer.json`; +const WHISPER_TOKENIZER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/whisper_tokenizer.json`; +const WHISPER_TINY_DECODER_MODEL = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_xnnpack_decoder.pte`; +const WHISPER_TINY_ENCODER_MODEL = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/xnnpack/whisper_tiny_en_xnnpack_encoder.pte`; +const WHISPER_TINY_MULTILINGUAL_ENCODER_MODEL = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/xnnpack/xnnpack_whisper_encoder.pte`; +const WHISPER_TINY_MULTILINGUAL_DECODER_MODEL = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/xnnpack/xnnpack_whisper_decoder.pte`; +const WHISPER_TINY_MULTILINGUAL_TOKENIZER = `${URL_PREFIX}-whisper-tiny/${VERSION_TAG}/tokenizer.json`; + +export const MOONSHINE_TINY = { + modelName: AvailableModels.MOONSHINE, + decoderSource: MOONSHINE_TINY_DECODER_MODEL, + encoderSource: MOONSHINE_TINY_ENCODER_MODEL, + tokenizerSource: MOONSHINE_TOKENIZER, +}; + +export const WHISPER_TINY = { + modelName: AvailableModels.WHISPER, + decoderSource: WHISPER_TINY_DECODER_MODEL, + encoderSource: WHISPER_TINY_ENCODER_MODEL, + tokenizerSource: WHISPER_TOKENIZER, +}; + +export const WHISPER_TINY_MULTILINGUAL = { + modelName: AvailableModels.WHISPER_MULTILINGUAL, + decoderSource: WHISPER_TINY_MULTILINGUAL_DECODER_MODEL, + encoderSource: WHISPER_TINY_MULTILINGUAL_ENCODER_MODEL, + tokenizerSource: WHISPER_TINY_MULTILINGUAL_TOKENIZER, +}; // Image segmentation -export const DEEPLAB_V3_RESNET50 = - 'https://huggingface.co/software-mansion/react-native-executorch-deeplab-v3/resolve/v0.4.0/xnnpack/deeplabV3_xnnpack_fp32.pte'; +const DEEPLAB_V3_RESNET50_MODEL = `${URL_PREFIX}-deeplab-v3/${VERSION_TAG}/xnnpack/deeplabV3_xnnpack_fp32.pte`; + +export const DEEPLAB_V3_RESNET50 = { + modelSource: DEEPLAB_V3_RESNET50_MODEL, +}; // Image Embeddings -export const CLIP_VIT_BASE_PATCH32_IMAGE_MODEL = - 'https://huggingface.co/software-mansion/react-native-executorch-clip-vit-base-patch32/resolve/v0.5.0/clip-vit-base-patch32-vision_xnnpack.pte'; +const CLIP_VIT_BASE_PATCH32_IMAGE_MODEL = `${URL_PREFIX}-clip-vit-base-patch32/${NEXT_VERSION_TAG}/clip-vit-base-patch32-vision_xnnpack.pte`; + +export const CLIP_VIT_BASE_PATCH32_IMAGE = { + modelSource: CLIP_VIT_BASE_PATCH32_IMAGE_MODEL, +}; // Text Embeddings -export const ALL_MINILM_L6_V2 = - 'https://huggingface.co/software-mansion/react-native-executorch-all-MiniLM-L6-v2/resolve/v0.5.0/all-MiniLM-L6-v2_xnnpack.pte'; -export const ALL_MINILM_L6_V2_TOKENIZER = - 'https://huggingface.co/software-mansion/react-native-executorch-all-MiniLM-L6-v2/resolve/v0.5.0/tokenizer.json'; - -export const ALL_MPNET_BASE_V2 = - 'https://huggingface.co/software-mansion/react-native-executorch-all-mpnet-base-v2/resolve/v0.5.0/all-mpnet-base-v2_xnnpack.pte'; -export const ALL_MPNET_BASE_V2_TOKENIZER = - 'https://huggingface.co/software-mansion/react-native-executorch-all-mpnet-base-v2/resolve/v0.5.0/tokenizer.json'; - -export const MULTI_QA_MINILM_L6_COS_V1 = - 'https://huggingface.co/software-mansion/react-native-executorch-multi-qa-MiniLM-L6-cos-v1/resolve/v0.5.0/multi-qa-MiniLM-L6-cos-v1_xnnpack.pte'; -export const MULTI_QA_MINILM_L6_COS_V1_TOKENIZER = - 
'https://huggingface.co/software-mansion/react-native-executorch-multi-qa-MiniLM-L6-cos-v1/resolve/v0.5.0/tokenizer.json'; - -export const MULTI_QA_MPNET_BASE_DOT_V1 = - 'https://huggingface.co/software-mansion/react-native-executorch-multi-qa-mpnet-base-dot-v1/resolve/v0.5.0/multi-qa-mpnet-base-dot-v1_xnnpack.pte'; -export const MULTI_QA_MPNET_BASE_DOT_V1_TOKENIZER = - 'https://huggingface.co/software-mansion/react-native-executorch-multi-qa-mpnet-base-dot-v1/resolve/v0.5.0/tokenizer.json'; - -export const CLIP_VIT_BASE_PATCH32_TEXT_MODEL = - 'https://huggingface.co/software-mansion/react-native-executorch-clip-vit-base-patch32/resolve/v0.5.0/clip-vit-base-patch32-text_xnnpack.pte'; -export const CLIP_VIT_BASE_PATCH32_TEXT_TOKENIZER = - 'https://huggingface.co/software-mansion/react-native-executorch-clip-vit-base-patch32/resolve/v0.5.0/tokenizer.json'; +const ALL_MINILM_L6_V2_MODEL = `${URL_PREFIX}-all-MiniLM-L6-v2/${NEXT_VERSION_TAG}/all-MiniLM-L6-v2_xnnpack.pte`; +const ALL_MINILM_L6_V2_TOKENIZER = `${URL_PREFIX}-all-MiniLM-L6-v2/${NEXT_VERSION_TAG}/tokenizer.json`; +const ALL_MPNET_BASE_V2_MODEL = `${URL_PREFIX}-all-mpnet-base-v2/${NEXT_VERSION_TAG}/all-mpnet-base-v2_xnnpack.pte`; +const ALL_MPNET_BASE_V2_TOKENIZER = `${URL_PREFIX}-all-mpnet-base-v2/${NEXT_VERSION_TAG}/tokenizer.json`; +const MULTI_QA_MINILM_L6_COS_V1_MODEL = `${URL_PREFIX}-multi-qa-MiniLM-L6-cos-v1/${NEXT_VERSION_TAG}/multi-qa-MiniLM-L6-cos-v1_xnnpack.pte`; +const MULTI_QA_MINILM_L6_COS_V1_TOKENIZER = `${URL_PREFIX}-multi-qa-MiniLM-L6-cos-v1/${NEXT_VERSION_TAG}/tokenizer.json`; +const MULTI_QA_MPNET_BASE_DOT_V1_MODEL = `${URL_PREFIX}-multi-qa-mpnet-base-dot-v1/${NEXT_VERSION_TAG}/multi-qa-mpnet-base-dot-v1_xnnpack.pte`; +const MULTI_QA_MPNET_BASE_DOT_V1_TOKENIZER = `${URL_PREFIX}-multi-qa-mpnet-base-dot-v1/${NEXT_VERSION_TAG}/tokenizer.json`; +const CLIP_VIT_BASE_PATCH32_TEXT_MODEL = `${URL_PREFIX}-clip-vit-base-patch32/${NEXT_VERSION_TAG}/clip-vit-base-patch32-text_xnnpack.pte`; +const CLIP_VIT_BASE_PATCH32_TEXT_TOKENIZER = `${URL_PREFIX}-clip-vit-base-patch32/${NEXT_VERSION_TAG}/tokenizer.json`; -export const CLIP_VIT_BASE_PATCH32_TEXT = { - modelSource: CLIP_VIT_BASE_PATCH32_TEXT_MODEL, - tokenizerSource: CLIP_VIT_BASE_PATCH32_TEXT_TOKENIZER, +export const ALL_MINILM_L6_V2 = { + modelSource: ALL_MINILM_L6_V2_MODEL, + tokenizerSource: ALL_MINILM_L6_V2_TOKENIZER, }; -export const CLIP_VIT_BASE_PATCH32_IMAGE = { - modelSource: CLIP_VIT_BASE_PATCH32_IMAGE_MODEL, +export const ALL_MPNET_BASE_V2 = { + modelSource: ALL_MPNET_BASE_V2_MODEL, + tokenizerSource: ALL_MPNET_BASE_V2_TOKENIZER, +}; + +export const MULTI_QA_MINILM_L6_COS_V1 = { + modelSource: MULTI_QA_MINILM_L6_COS_V1_MODEL, + tokenizerSource: MULTI_QA_MINILM_L6_COS_V1_TOKENIZER, }; -// Backward compatibility -export const LLAMA3_2_3B_URL = LLAMA3_2_3B; -export const LLAMA3_2_3B_QLORA_URL = LLAMA3_2_3B_QLORA; -export const LLAMA3_2_3B_SPINQUANT_URL = LLAMA3_2_3B_SPINQUANT; -export const LLAMA3_2_1B_URL = LLAMA3_2_1B; -export const LLAMA3_2_1B_QLORA_URL = LLAMA3_2_1B_QLORA; -export const LLAMA3_2_1B_SPINQUANT_URL = LLAMA3_2_1B_SPINQUANT; -export const LLAMA3_2_1B_TOKENIZER = LLAMA3_2_TOKENIZER; -export const LLAMA3_2_3B_TOKENIZER = LLAMA3_2_TOKENIZER; +export const MULTI_QA_MPNET_BASE_DOT_V1 = { + modelSource: MULTI_QA_MPNET_BASE_DOT_V1_MODEL, + tokenizerSource: MULTI_QA_MPNET_BASE_DOT_V1_TOKENIZER, +}; + +export const CLIP_VIT_BASE_PATCH32_TEXT = { + modelSource: CLIP_VIT_BASE_PATCH32_TEXT_MODEL, + tokenizerSource: CLIP_VIT_BASE_PATCH32_TEXT_TOKENIZER, +}; 
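The modelUrls.ts changes above collapse loose per-asset URL exports into self-contained config objects, so a call site imports a single constant instead of wiring model, tokenizer, and tokenizer-config URLs together by hand. A minimal TypeScript sketch of the pattern, under illustrative names (EmbeddingsConfig, EXAMPLE_MODEL, and loadEmbeddings below are assumptions for the example, not exports of this file; the URLs mirror the all-MiniLM-L6-v2 constants defined above):

// Sketch of the consolidated config-object shape used in modelUrls.ts.
type EmbeddingsConfig = {
  modelSource: string; // URL of the exported .pte model
  tokenizerSource: string; // URL of the matching tokenizer.json
};

// Same shape as the exported ALL_MINILM_L6_V2 constant above.
const EXAMPLE_MODEL: EmbeddingsConfig = {
  modelSource:
    'https://huggingface.co/software-mansion/react-native-executorch-all-MiniLM-L6-v2/resolve/v0.5.0/all-MiniLM-L6-v2_xnnpack.pte',
  tokenizerSource:
    'https://huggingface.co/software-mansion/react-native-executorch-all-MiniLM-L6-v2/resolve/v0.5.0/tokenizer.json',
};

// Consumers take the whole object through a single `model` option, matching
// the `{ model: SOME_CONSTANT }` call sites updated earlier in this diff.
function loadEmbeddings({ model }: { model: EmbeddingsConfig }) {
  // Hypothetical loader body; the real hooks resolve both URLs internally.
  console.log(`model: ${model.modelSource}`);
  console.log(`tokenizer: ${model.tokenizerSource}`);
}

loadEmbeddings({ model: EXAMPLE_MODEL });

diff 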
--git a/packages/react-native-executorch/src/constants/ocr/models.ts b/packages/react-native-executorch/src/constants/ocr/models.ts index 4f198612c0..f16ea5728c 100644 --- a/packages/react-native-executorch/src/constants/ocr/models.ts +++ b/packages/react-native-executorch/src/constants/ocr/models.ts @@ -1,453 +1,884 @@ -const createHFRecognizerDownloadUrl = (alphabet: string, size: number) => { - return `https://huggingface.co/software-mansion/react-native-executorch-recognizer-crnn.en/resolve/v0.4.0/xnnpack/${alphabet}/xnnpack_crnn_${alphabet}_${size}.pte`; -}; -export const RECOGNIZER_ENGLISH_CRNN_512 = createHFRecognizerDownloadUrl( +import { alphabets, symbols } from './symbols'; + +const URL_PREFIX = + 'https://huggingface.co/software-mansion/react-native-executorch'; +const VERSION_TAG = 'resolve/v0.4.0'; + +const DETECTOR_CRAFT_1280_MODEL = `${URL_PREFIX}-detector-craft/${VERSION_TAG}/xnnpack/xnnpack_craft_1280.pte`; +const DETECTOR_CRAFT_800_MODEL = `${URL_PREFIX}-detector-craft/${VERSION_TAG}/xnnpack/xnnpack_craft_800.pte`; +const DETECTOR_CRAFT_320_MODEL = `${URL_PREFIX}-detector-craft/${VERSION_TAG}/xnnpack/xnnpack_craft_320.pte`; + +type RecognizerSize = 64 | 128 | 256 | 512; + +const createHFRecognizerDownloadUrl = ( + alphabet: keyof typeof alphabets, + size: RecognizerSize +) => + `${URL_PREFIX}-recognizer-crnn.en/${VERSION_TAG}/xnnpack/${alphabet}/xnnpack_crnn_${alphabet}_${size}.pte`; + +const RECOGNIZER_ENGLISH_CRNN_512 = createHFRecognizerDownloadUrl( 'english', 512 ); -export const RECOGNIZER_ENGLISH_CRNN_256 = createHFRecognizerDownloadUrl( +const RECOGNIZER_ENGLISH_CRNN_256 = createHFRecognizerDownloadUrl( 'english', 256 ); -export const RECOGNIZER_ENGLISH_CRNN_128 = createHFRecognizerDownloadUrl( +const RECOGNIZER_ENGLISH_CRNN_128 = createHFRecognizerDownloadUrl( 'english', 128 ); -export const RECOGNIZER_ENGLISH_CRNN_64 = createHFRecognizerDownloadUrl( - 'english', - 64 -); +const RECOGNIZER_ENGLISH_CRNN_64 = createHFRecognizerDownloadUrl('english', 64); -export const RECOGNIZER_LATIN_CRNN_512 = createHFRecognizerDownloadUrl( - 'latin', - 512 -); -export const RECOGNIZER_LATIN_CRNN_256 = createHFRecognizerDownloadUrl( - 'latin', - 256 -); -export const RECOGNIZER_LATIN_CRNN_128 = createHFRecognizerDownloadUrl( - 'latin', - 128 -); -export const RECOGNIZER_LATIN_CRNN_64 = createHFRecognizerDownloadUrl( - 'latin', - 64 -); +const RECOGNIZER_LATIN_CRNN_512 = createHFRecognizerDownloadUrl('latin', 512); +const RECOGNIZER_LATIN_CRNN_256 = createHFRecognizerDownloadUrl('latin', 256); +const RECOGNIZER_LATIN_CRNN_128 = createHFRecognizerDownloadUrl('latin', 128); +const RECOGNIZER_LATIN_CRNN_64 = createHFRecognizerDownloadUrl('latin', 64); -export const RECOGNIZER_JAPANESE_CRNN_512 = createHFRecognizerDownloadUrl( +const RECOGNIZER_JAPANESE_CRNN_512 = createHFRecognizerDownloadUrl( 'japanese', 512 ); -export const RECOGNIZER_JAPANESE_CRNN_256 = createHFRecognizerDownloadUrl( +const RECOGNIZER_JAPANESE_CRNN_256 = createHFRecognizerDownloadUrl( 'japanese', 256 ); -export const RECOGNIZER_JAPANESE_CRNN_128 = createHFRecognizerDownloadUrl( +const RECOGNIZER_JAPANESE_CRNN_128 = createHFRecognizerDownloadUrl( 'japanese', 128 ); -export const RECOGNIZER_JAPANESE_CRNN_64 = createHFRecognizerDownloadUrl( +const RECOGNIZER_JAPANESE_CRNN_64 = createHFRecognizerDownloadUrl( 'japanese', 64 ); -export const RECOGNIZER_KANNADA_CRNN_512 = createHFRecognizerDownloadUrl( +const RECOGNIZER_KANNADA_CRNN_512 = createHFRecognizerDownloadUrl( 'kannada', 512 ); -export const 
RECOGNIZER_KANNADA_CRNN_256 = createHFRecognizerDownloadUrl( +const RECOGNIZER_KANNADA_CRNN_256 = createHFRecognizerDownloadUrl( 'kannada', 256 ); -export const RECOGNIZER_KANNADA_CRNN_128 = createHFRecognizerDownloadUrl( +const RECOGNIZER_KANNADA_CRNN_128 = createHFRecognizerDownloadUrl( 'kannada', 128 ); -export const RECOGNIZER_KANNADA_CRNN_64 = createHFRecognizerDownloadUrl( - 'kannada', - 64 -); +const RECOGNIZER_KANNADA_CRNN_64 = createHFRecognizerDownloadUrl('kannada', 64); + +const RECOGNIZER_KOREAN_CRNN_512 = createHFRecognizerDownloadUrl('korean', 512); +const RECOGNIZER_KOREAN_CRNN_256 = createHFRecognizerDownloadUrl('korean', 256); +const RECOGNIZER_KOREAN_CRNN_128 = createHFRecognizerDownloadUrl('korean', 128); +const RECOGNIZER_KOREAN_CRNN_64 = createHFRecognizerDownloadUrl('korean', 64); + +const RECOGNIZER_TELUGU_CRNN_512 = createHFRecognizerDownloadUrl('telugu', 512); +const RECOGNIZER_TELUGU_CRNN_256 = createHFRecognizerDownloadUrl('telugu', 256); +const RECOGNIZER_TELUGU_CRNN_128 = createHFRecognizerDownloadUrl('telugu', 128); +const RECOGNIZER_TELUGU_CRNN_64 = createHFRecognizerDownloadUrl('telugu', 64); -export const RECOGNIZER_KOREAN_CRNN_512 = createHFRecognizerDownloadUrl( - 'korean', +const RECOGNIZER_ZH_SIM_CRNN_512 = createHFRecognizerDownloadUrl('zhSim', 512); +const RECOGNIZER_ZH_SIM_CRNN_256 = createHFRecognizerDownloadUrl('zhSim', 256); +const RECOGNIZER_ZH_SIM_CRNN_128 = createHFRecognizerDownloadUrl('zhSim', 128); +const RECOGNIZER_ZH_SIM_CRNN_64 = createHFRecognizerDownloadUrl('zhSim', 64); + +const RECOGNIZER_CYRILLIC_CRNN_512 = createHFRecognizerDownloadUrl( + 'cyrillic', 512 ); -export const RECOGNIZER_KOREAN_CRNN_256 = createHFRecognizerDownloadUrl( - 'korean', +const RECOGNIZER_CYRILLIC_CRNN_256 = createHFRecognizerDownloadUrl( + 'cyrillic', 256 ); -export const RECOGNIZER_KOREAN_CRNN_128 = createHFRecognizerDownloadUrl( - 'korean', +const RECOGNIZER_CYRILLIC_CRNN_128 = createHFRecognizerDownloadUrl( + 'cyrillic', 128 ); -export const RECOGNIZER_KOREAN_CRNN_64 = createHFRecognizerDownloadUrl( - 'korean', +const RECOGNIZER_CYRILLIC_CRNN_64 = createHFRecognizerDownloadUrl( + 'cyrillic', 64 ); -export const RECOGNIZER_TELUGU_CRNN_512 = createHFRecognizerDownloadUrl( - 'telugu', - 512 +const createOCRObject = ( + recognizerLarge: string, + recognizerMedium: string, + recognizerSmall: string, + language: keyof typeof symbols +) => { + return { + detectorSource: DETECTOR_CRAFT_800_MODEL, + recognizerLarge, + recognizerMedium, + recognizerSmall, + language, + }; +}; + +const createVerticalOCRObject = ( + recognizerLarge: string, + recognizerSmall: string, + language: keyof typeof symbols +) => { + return { + detectorLarge: DETECTOR_CRAFT_1280_MODEL, + detectorNarrow: DETECTOR_CRAFT_320_MODEL, + recognizerLarge, + recognizerSmall, + language, + }; +}; + +export const OCR_ABAZA = createOCRObject( + RECOGNIZER_CYRILLIC_CRNN_512, + RECOGNIZER_CYRILLIC_CRNN_256, + RECOGNIZER_CYRILLIC_CRNN_128, + 'abq' ); -export const RECOGNIZER_TELUGU_CRNN_256 = createHFRecognizerDownloadUrl( - 'telugu', - 256 +export const VERTICAL_OCR_ABAZA = createVerticalOCRObject( + RECOGNIZER_CYRILLIC_CRNN_512, + RECOGNIZER_CYRILLIC_CRNN_64, + 'abq' ); -export const RECOGNIZER_TELUGU_CRNN_128 = createHFRecognizerDownloadUrl( - 'telugu', - 128 + +export const OCR_ADYGHE = createOCRObject( + RECOGNIZER_CYRILLIC_CRNN_512, + RECOGNIZER_CYRILLIC_CRNN_256, + RECOGNIZER_CYRILLIC_CRNN_128, + 'ady' ); -export const RECOGNIZER_TELUGU_CRNN_64 = createHFRecognizerDownloadUrl( - 'telugu', - 64 +export 
const VERTICAL_OCR_ADYGHE = createVerticalOCRObject( + RECOGNIZER_CYRILLIC_CRNN_512, + RECOGNIZER_CYRILLIC_CRNN_64, + 'ady' ); -export const RECOGNIZER_ZH_SIM_CRNN_512 = createHFRecognizerDownloadUrl( - 'zh-sim', - 512 +export const OCR_AFRIKAANS = createOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_256, + RECOGNIZER_LATIN_CRNN_128, + 'af' ); -export const RECOGNIZER_ZH_SIM_CRNN_256 = createHFRecognizerDownloadUrl( - 'zh-sim', - 256 +export const VERTICAL_OCR_AFRIKAANS = createVerticalOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_64, + 'af' ); -export const RECOGNIZER_ZH_SIM_CRNN_128 = createHFRecognizerDownloadUrl( - 'zh-sim', - 128 + +export const OCR_AVAR = createOCRObject( + RECOGNIZER_CYRILLIC_CRNN_512, + RECOGNIZER_CYRILLIC_CRNN_256, + RECOGNIZER_CYRILLIC_CRNN_128, + 'ava' ); -export const RECOGNIZER_ZH_SIM_CRNN_64 = createHFRecognizerDownloadUrl( - 'zh-sim', - 64 +export const VERTICAL_OCR_AVAR = createVerticalOCRObject( + RECOGNIZER_CYRILLIC_CRNN_512, + RECOGNIZER_CYRILLIC_CRNN_64, + 'ava' ); -export const RECOGNIZER_CYRILLIC_CRNN_512 = createHFRecognizerDownloadUrl( - 'cyrillic', - 512 +export const OCR_AZERBAIJANI = createOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_256, + RECOGNIZER_LATIN_CRNN_128, + 'az' ); -export const RECOGNIZER_CYRILLIC_CRNN_256 = createHFRecognizerDownloadUrl( - 'cyrillic', - 256 +export const VERTICAL_OCR_AZERBAIJANI = createVerticalOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_64, + 'az' ); -export const RECOGNIZER_CYRILLIC_CRNN_128 = createHFRecognizerDownloadUrl( - 'cyrillic', - 128 + +export const OCR_BELARUSIAN = createOCRObject( + RECOGNIZER_CYRILLIC_CRNN_512, + RECOGNIZER_CYRILLIC_CRNN_256, + RECOGNIZER_CYRILLIC_CRNN_128, + 'be' ); -export const RECOGNIZER_CYRILLIC_CRNN_64 = createHFRecognizerDownloadUrl( - 'cyrillic', - 64 +export const VERTICAL_OCR_BELARUSIAN = createVerticalOCRObject( + RECOGNIZER_CYRILLIC_CRNN_512, + RECOGNIZER_CYRILLIC_CRNN_64, + 'be' +); + +export const OCR_BULGARIAN = createOCRObject( + RECOGNIZER_CYRILLIC_CRNN_512, + RECOGNIZER_CYRILLIC_CRNN_256, + RECOGNIZER_CYRILLIC_CRNN_128, + 'bg' +); +export const VERTICAL_OCR_BULGARIAN = createVerticalOCRObject( + RECOGNIZER_CYRILLIC_CRNN_512, + RECOGNIZER_CYRILLIC_CRNN_64, + 'bg' +); + +export const OCR_BOSNIAN = createOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_256, + RECOGNIZER_LATIN_CRNN_128, + 'bs' +); +export const VERTICAL_OCR_BOSNIAN = createVerticalOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_64, + 'bs' +); + +export const OCR_SIMPLIFIED_CHINESE = createOCRObject( + RECOGNIZER_ZH_SIM_CRNN_512, + RECOGNIZER_ZH_SIM_CRNN_256, + RECOGNIZER_ZH_SIM_CRNN_128, + 'chSim' +); +export const VERTICAL_OCR_SIMPLIFIED_CHINESE = createVerticalOCRObject( + RECOGNIZER_ZH_SIM_CRNN_512, + RECOGNIZER_ZH_SIM_CRNN_64, + 'chSim' +); + +export const OCR_CHECHEN = createOCRObject( + RECOGNIZER_CYRILLIC_CRNN_512, + RECOGNIZER_CYRILLIC_CRNN_256, + RECOGNIZER_CYRILLIC_CRNN_128, + 'che' +); +export const VERTICAL_OCR_CHECHEN = createVerticalOCRObject( + RECOGNIZER_CYRILLIC_CRNN_512, + RECOGNIZER_CYRILLIC_CRNN_64, + 'che' +); + +export const OCR_CZECH = createOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_256, + RECOGNIZER_LATIN_CRNN_128, + 'cs' +); +export const VERTICAL_OCR_CZECH = createVerticalOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_64, + 'cs' +); + +export const OCR_WELSH = createOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_256, + 
RECOGNIZER_LATIN_CRNN_128, + 'cy' +); +export const VERTICAL_OCR_WELSH = createVerticalOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_64, + 'cy' +); + +export const OCR_DANISH = createOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_256, + RECOGNIZER_LATIN_CRNN_128, + 'da' +); +export const VERTICAL_OCR_DANISH = createVerticalOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_64, + 'da' ); -export const RECOGNIZER_ABQ_CRNN_512 = RECOGNIZER_CYRILLIC_CRNN_512; -export const RECOGNIZER_ABQ_CRNN_256 = RECOGNIZER_CYRILLIC_CRNN_256; -export const RECOGNIZER_ABQ_CRNN_128 = RECOGNIZER_CYRILLIC_CRNN_128; -export const RECOGNIZER_ABQ_CRNN_64 = RECOGNIZER_CYRILLIC_CRNN_64; - -export const RECOGNIZER_ADY_CRNN_512 = RECOGNIZER_CYRILLIC_CRNN_512; -export const RECOGNIZER_ADY_CRNN_256 = RECOGNIZER_CYRILLIC_CRNN_256; -export const RECOGNIZER_ADY_CRNN_128 = RECOGNIZER_CYRILLIC_CRNN_128; -export const RECOGNIZER_ADY_CRNN_64 = RECOGNIZER_CYRILLIC_CRNN_64; - -export const RECOGNIZER_AF_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_AF_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_AF_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_AF_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_AVA_CRNN_512 = RECOGNIZER_CYRILLIC_CRNN_512; -export const RECOGNIZER_AVA_CRNN_256 = RECOGNIZER_CYRILLIC_CRNN_256; -export const RECOGNIZER_AVA_CRNN_128 = RECOGNIZER_CYRILLIC_CRNN_128; -export const RECOGNIZER_AVA_CRNN_64 = RECOGNIZER_CYRILLIC_CRNN_64; - -export const RECOGNIZER_AZ_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_AZ_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_AZ_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_AZ_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_BE_CRNN_512 = RECOGNIZER_CYRILLIC_CRNN_512; -export const RECOGNIZER_BE_CRNN_256 = RECOGNIZER_CYRILLIC_CRNN_256; -export const RECOGNIZER_BE_CRNN_128 = RECOGNIZER_CYRILLIC_CRNN_128; -export const RECOGNIZER_BE_CRNN_64 = RECOGNIZER_CYRILLIC_CRNN_64; - -export const RECOGNIZER_BG_CRNN_512 = RECOGNIZER_CYRILLIC_CRNN_512; -export const RECOGNIZER_BG_CRNN_256 = RECOGNIZER_CYRILLIC_CRNN_256; -export const RECOGNIZER_BG_CRNN_128 = RECOGNIZER_CYRILLIC_CRNN_128; -export const RECOGNIZER_BG_CRNN_64 = RECOGNIZER_CYRILLIC_CRNN_64; - -export const RECOGNIZER_BS_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_BS_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_BS_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_BS_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_CH_SIM_CRNN_512 = RECOGNIZER_ZH_SIM_CRNN_512; -export const RECOGNIZER_CH_SIM_CRNN_256 = RECOGNIZER_ZH_SIM_CRNN_256; -export const RECOGNIZER_CH_SIM_CRNN_128 = RECOGNIZER_ZH_SIM_CRNN_128; -export const RECOGNIZER_CH_SIM_CRNN_64 = RECOGNIZER_ZH_SIM_CRNN_64; - -export const RECOGNIZER_CHE_CRNN_512 = RECOGNIZER_CYRILLIC_CRNN_512; -export const RECOGNIZER_CHE_CRNN_256 = RECOGNIZER_CYRILLIC_CRNN_256; -export const RECOGNIZER_CHE_CRNN_128 = RECOGNIZER_CYRILLIC_CRNN_128; -export const RECOGNIZER_CHE_CRNN_64 = RECOGNIZER_CYRILLIC_CRNN_64; - -export const RECOGNIZER_CS_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_CS_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_CS_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_CS_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_CY_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_CY_CRNN_256 = 
RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_CY_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_CY_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_DA_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_DA_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_DA_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_DA_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_DAR_CRNN_512 = RECOGNIZER_CYRILLIC_CRNN_512; -export const RECOGNIZER_DAR_CRNN_256 = RECOGNIZER_CYRILLIC_CRNN_256; -export const RECOGNIZER_DAR_CRNN_128 = RECOGNIZER_CYRILLIC_CRNN_128; -export const RECOGNIZER_DAR_CRNN_64 = RECOGNIZER_CYRILLIC_CRNN_64; - -export const RECOGNIZER_DE_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_DE_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_DE_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_DE_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_EN_CRNN_512 = RECOGNIZER_ENGLISH_CRNN_512; -export const RECOGNIZER_EN_CRNN_256 = RECOGNIZER_ENGLISH_CRNN_256; -export const RECOGNIZER_EN_CRNN_128 = RECOGNIZER_ENGLISH_CRNN_128; -export const RECOGNIZER_EN_CRNN_64 = RECOGNIZER_ENGLISH_CRNN_64; - -export const RECOGNIZER_ES_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_ES_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_ES_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_ES_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_ET_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_ET_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_ET_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_ET_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_FR_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_FR_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_FR_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_FR_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_GA_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_GA_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_GA_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_GA_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_HR_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_HR_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_HR_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_HR_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_HU_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_HU_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_HU_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_HU_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_ID_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_ID_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_ID_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_ID_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_INH_CRNN_512 = RECOGNIZER_CYRILLIC_CRNN_512; -export const RECOGNIZER_INH_CRNN_256 = RECOGNIZER_CYRILLIC_CRNN_256; -export const RECOGNIZER_INH_CRNN_128 = RECOGNIZER_CYRILLIC_CRNN_128; -export const RECOGNIZER_INH_CRNN_64 = RECOGNIZER_CYRILLIC_CRNN_64; - -export const RECOGNIZER_IC_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_IC_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_IC_CRNN_128 = 
RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_IC_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_IT_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_IT_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_IT_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_IT_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_JA_CRNN_512 = RECOGNIZER_JAPANESE_CRNN_512; -export const RECOGNIZER_JA_CRNN_256 = RECOGNIZER_JAPANESE_CRNN_256; -export const RECOGNIZER_JA_CRNN_128 = RECOGNIZER_JAPANESE_CRNN_128; -export const RECOGNIZER_JA_CRNN_64 = RECOGNIZER_JAPANESE_CRNN_64; - -export const RECOGNIZER_KBD_CRNN_512 = RECOGNIZER_CYRILLIC_CRNN_512; -export const RECOGNIZER_KBD_CRNN_256 = RECOGNIZER_CYRILLIC_CRNN_256; -export const RECOGNIZER_KBD_CRNN_128 = RECOGNIZER_CYRILLIC_CRNN_128; -export const RECOGNIZER_KBD_CRNN_64 = RECOGNIZER_CYRILLIC_CRNN_64; - -export const RECOGNIZER_KN_CRNN_512 = RECOGNIZER_KANNADA_CRNN_512; -export const RECOGNIZER_KN_CRNN_256 = RECOGNIZER_KANNADA_CRNN_256; -export const RECOGNIZER_KN_CRNN_128 = RECOGNIZER_KANNADA_CRNN_128; -export const RECOGNIZER_KN_CRNN_64 = RECOGNIZER_KANNADA_CRNN_64; - -export const RECOGNIZER_KO_CRNN_512 = RECOGNIZER_KOREAN_CRNN_512; -export const RECOGNIZER_KO_CRNN_256 = RECOGNIZER_KOREAN_CRNN_256; -export const RECOGNIZER_KO_CRNN_128 = RECOGNIZER_KOREAN_CRNN_128; -export const RECOGNIZER_KO_CRNN_64 = RECOGNIZER_KOREAN_CRNN_64; - -export const RECOGNIZER_KU_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_KU_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_KU_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_KU_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_LA_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_LA_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_LA_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_LA_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_LBE_CRNN_512 = RECOGNIZER_CYRILLIC_CRNN_512; -export const RECOGNIZER_LBE_CRNN_256 = RECOGNIZER_CYRILLIC_CRNN_256; -export const RECOGNIZER_LBE_CRNN_128 = RECOGNIZER_CYRILLIC_CRNN_128; -export const RECOGNIZER_LBE_CRNN_64 = RECOGNIZER_CYRILLIC_CRNN_64; - -export const RECOGNIZER_LEZ_CRNN_512 = RECOGNIZER_CYRILLIC_CRNN_512; -export const RECOGNIZER_LEZ_CRNN_256 = RECOGNIZER_CYRILLIC_CRNN_256; -export const RECOGNIZER_LEZ_CRNN_128 = RECOGNIZER_CYRILLIC_CRNN_128; -export const RECOGNIZER_LEZ_CRNN_64 = RECOGNIZER_CYRILLIC_CRNN_64; - -export const RECOGNIZER_LT_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_LT_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_LT_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_LT_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_LV_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_LV_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_LV_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_LV_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_MI_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_MI_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_MI_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_MI_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_MN_CRNN_512 = RECOGNIZER_CYRILLIC_CRNN_512; -export const RECOGNIZER_MN_CRNN_256 = RECOGNIZER_CYRILLIC_CRNN_256; -export const RECOGNIZER_MN_CRNN_128 = RECOGNIZER_CYRILLIC_CRNN_128; -export const 
RECOGNIZER_MN_CRNN_64 = RECOGNIZER_CYRILLIC_CRNN_64; - -export const RECOGNIZER_MS_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_MS_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_MS_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_MS_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_MT_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_MT_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_MT_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_MT_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_NL_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_NL_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_NL_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_NL_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_NO_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_NO_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_NO_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_NO_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_OC_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_OC_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_OC_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_OC_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_PI_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_PI_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_PI_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_PI_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_PL_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_PL_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_PL_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_PL_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_PT_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_PT_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_PT_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_PT_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_RO_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_RO_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_RO_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_RO_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_RU_CRNN_512 = RECOGNIZER_CYRILLIC_CRNN_512; -export const RECOGNIZER_RU_CRNN_256 = RECOGNIZER_CYRILLIC_CRNN_256; -export const RECOGNIZER_RU_CRNN_128 = RECOGNIZER_CYRILLIC_CRNN_128; -export const RECOGNIZER_RU_CRNN_64 = RECOGNIZER_CYRILLIC_CRNN_64; - -export const RECOGNIZER_RS_CYRILLIC_CRNN_512 = RECOGNIZER_CYRILLIC_CRNN_512; -export const RECOGNIZER_RS_CYRILLIC_CRNN_256 = RECOGNIZER_CYRILLIC_CRNN_256; -export const RECOGNIZER_RS_CYRILLIC_CRNN_128 = RECOGNIZER_CYRILLIC_CRNN_128; -export const RECOGNIZER_RS_CYRILLIC_CRNN_64 = RECOGNIZER_CYRILLIC_CRNN_64; - -export const RECOGNIZER_RS_LATIN_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_RS_LATIN_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_RS_LATIN_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_RS_LATIN_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_SK_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_SK_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_SK_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_SK_CRNN_64 = 
RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_SL_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_SL_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_SL_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_SL_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_SQ_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_SQ_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_SQ_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_SQ_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_SV_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_SV_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_SV_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_SV_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_SW_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_SW_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_SW_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_SW_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_TAB_CRNN_512 = RECOGNIZER_CYRILLIC_CRNN_512; -export const RECOGNIZER_TAB_CRNN_256 = RECOGNIZER_CYRILLIC_CRNN_256; -export const RECOGNIZER_TAB_CRNN_128 = RECOGNIZER_CYRILLIC_CRNN_128; -export const RECOGNIZER_TAB_CRNN_64 = RECOGNIZER_CYRILLIC_CRNN_64; - -export const RECOGNIZER_TE_CRNN_512 = RECOGNIZER_TELUGU_CRNN_512; -export const RECOGNIZER_TE_CRNN_256 = RECOGNIZER_TELUGU_CRNN_256; -export const RECOGNIZER_TE_CRNN_128 = RECOGNIZER_TELUGU_CRNN_128; -export const RECOGNIZER_TE_CRNN_64 = RECOGNIZER_TELUGU_CRNN_64; - -export const RECOGNIZER_TJK_CRNN_512 = RECOGNIZER_CYRILLIC_CRNN_512; -export const RECOGNIZER_TJK_CRNN_256 = RECOGNIZER_CYRILLIC_CRNN_256; -export const RECOGNIZER_TJK_CRNN_128 = RECOGNIZER_CYRILLIC_CRNN_128; -export const RECOGNIZER_TJK_CRNN_64 = RECOGNIZER_CYRILLIC_CRNN_64; - -export const RECOGNIZER_TL_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_TL_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_TL_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_TL_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_TR_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_TR_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_TR_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_TR_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_UK_CRNN_512 = RECOGNIZER_CYRILLIC_CRNN_512; -export const RECOGNIZER_UK_CRNN_256 = RECOGNIZER_CYRILLIC_CRNN_256; -export const RECOGNIZER_UK_CRNN_128 = RECOGNIZER_CYRILLIC_CRNN_128; -export const RECOGNIZER_UK_CRNN_64 = RECOGNIZER_CYRILLIC_CRNN_64; - -export const RECOGNIZER_UZ_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_UZ_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_UZ_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_UZ_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; - -export const RECOGNIZER_VI_CRNN_512 = RECOGNIZER_LATIN_CRNN_512; -export const RECOGNIZER_VI_CRNN_256 = RECOGNIZER_LATIN_CRNN_256; -export const RECOGNIZER_VI_CRNN_128 = RECOGNIZER_LATIN_CRNN_128; -export const RECOGNIZER_VI_CRNN_64 = RECOGNIZER_LATIN_CRNN_64; +export const OCR_DARGWA = createOCRObject( + RECOGNIZER_CYRILLIC_CRNN_512, + RECOGNIZER_CYRILLIC_CRNN_256, + RECOGNIZER_CYRILLIC_CRNN_128, + 'dar' +); +export const VERTICAL_OCR_DARGWA = createVerticalOCRObject( + RECOGNIZER_CYRILLIC_CRNN_512, + RECOGNIZER_CYRILLIC_CRNN_64, + 'dar' +); + +export const 
OCR_GERMAN = createOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_256, + RECOGNIZER_LATIN_CRNN_128, + 'de' +); +export const VERTICAL_OCR_GERMAN = createVerticalOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_64, + 'de' +); + +export const OCR_ENGLISH = createOCRObject( + RECOGNIZER_ENGLISH_CRNN_512, + RECOGNIZER_ENGLISH_CRNN_256, + RECOGNIZER_ENGLISH_CRNN_128, + 'en' +); +export const VERTICAL_OCR_ENGLISH = createVerticalOCRObject( + RECOGNIZER_ENGLISH_CRNN_512, + RECOGNIZER_ENGLISH_CRNN_64, + 'en' +); + +export const OCR_SPANISH = createOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_256, + RECOGNIZER_LATIN_CRNN_128, + 'es' +); +export const VERTICAL_OCR_SPANISH = createVerticalOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_64, + 'es' +); + +export const OCR_ESTONIAN = createOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_256, + RECOGNIZER_LATIN_CRNN_128, + 'et' +); +export const VERTICAL_OCR_ESTONIAN = createVerticalOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_64, + 'et' +); + +export const OCR_FRENCH = createOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_256, + RECOGNIZER_LATIN_CRNN_128, + 'fr' +); +export const VERTICAL_OCR_FRENCH = createVerticalOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_64, + 'fr' +); + +export const OCR_IRISH = createOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_256, + RECOGNIZER_LATIN_CRNN_128, + 'ga' +); +export const VERTICAL_OCR_IRISH = createVerticalOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_64, + 'ga' +); + +export const OCR_CROATIAN = createOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_256, + RECOGNIZER_LATIN_CRNN_128, + 'hr' +); +export const VERTICAL_OCR_CROATIAN = createVerticalOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_64, + 'hr' +); + +export const OCR_HUNGARIAN = createOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_256, + RECOGNIZER_LATIN_CRNN_128, + 'hu' +); +export const VERTICAL_OCR_HUNGARIAN = createVerticalOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_64, + 'hu' +); + +export const OCR_INDONESIAN = createOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_256, + RECOGNIZER_LATIN_CRNN_128, + 'id' +); +export const VERTICAL_OCR_INDONESIAN = createVerticalOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_64, + 'id' +); + +export const OCR_INGUSH = createOCRObject( + RECOGNIZER_CYRILLIC_CRNN_512, + RECOGNIZER_CYRILLIC_CRNN_256, + RECOGNIZER_CYRILLIC_CRNN_128, + 'inh' +); +export const VERTICAL_OCR_INGUSH = createVerticalOCRObject( + RECOGNIZER_CYRILLIC_CRNN_512, + RECOGNIZER_CYRILLIC_CRNN_64, + 'inh' +); + +export const OCR_ICELANDIC = createOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_256, + RECOGNIZER_LATIN_CRNN_128, + 'ic' +); +export const VERTICAL_OCR_ICELANDIC = createVerticalOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_64, + 'ic' +); + +export const OCR_ITALIAN = createOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_256, + RECOGNIZER_LATIN_CRNN_128, + 'it' +); +export const VERTICAL_OCR_ITALIAN = createVerticalOCRObject( + RECOGNIZER_LATIN_CRNN_512, + RECOGNIZER_LATIN_CRNN_64, + 'it' +); + +export const OCR_JAPANESE = createOCRObject( + RECOGNIZER_JAPANESE_CRNN_512, + RECOGNIZER_JAPANESE_CRNN_256, + RECOGNIZER_JAPANESE_CRNN_128, + 'ja' +); +export const VERTICAL_OCR_JAPANESE = createVerticalOCRObject( + 
+  RECOGNIZER_JAPANESE_CRNN_512,
+  RECOGNIZER_JAPANESE_CRNN_64,
+  'ja'
+);
+
+export const OCR_KARBADIAN = createOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_256,
+  RECOGNIZER_CYRILLIC_CRNN_128,
+  'kbd'
+);
+export const VERTICAL_OCR_KARBADIAN = createVerticalOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_64,
+  'kbd'
+);
+
+export const OCR_KANNADA = createOCRObject(
+  RECOGNIZER_KANNADA_CRNN_512,
+  RECOGNIZER_KANNADA_CRNN_256,
+  RECOGNIZER_KANNADA_CRNN_128,
+  'kn'
+);
+export const VERTICAL_OCR_KANNADA = createVerticalOCRObject(
+  RECOGNIZER_KANNADA_CRNN_512,
+  RECOGNIZER_KANNADA_CRNN_64,
+  'kn'
+);
+
+export const OCR_KOREAN = createOCRObject(
+  RECOGNIZER_KOREAN_CRNN_512,
+  RECOGNIZER_KOREAN_CRNN_256,
+  RECOGNIZER_KOREAN_CRNN_128,
+  'ko'
+);
+export const VERTICAL_OCR_KOREAN = createVerticalOCRObject(
+  RECOGNIZER_KOREAN_CRNN_512,
+  RECOGNIZER_KOREAN_CRNN_64,
+  'ko'
+);
+
+export const OCR_KURDISH = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'ku'
+);
+export const VERTICAL_OCR_KURDISH = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'ku'
+);
+
+export const OCR_LATIN = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'la'
+);
+export const VERTICAL_OCR_LATIN = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'la'
+);
+
+export const OCR_LAK = createOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_256,
+  RECOGNIZER_CYRILLIC_CRNN_128,
+  'lbe'
+);
+export const VERTICAL_OCR_LAK = createVerticalOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_64,
+  'lbe'
+);
+
+export const OCR_LEZGHIAN = createOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_256,
+  RECOGNIZER_CYRILLIC_CRNN_128,
+  'lez'
+);
+export const VERTICAL_OCR_LEZGHIAN = createVerticalOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_64,
+  'lez'
+);
+
+export const OCR_LITHUANIAN = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'lt'
+);
+export const VERTICAL_OCR_LITHUANIAN = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'lt'
+);
+
+export const OCR_LATVIAN = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'lv'
+);
+export const VERTICAL_OCR_LATVIAN = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'lv'
+);
+
+export const OCR_MAORI = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'mi'
+);
+export const VERTICAL_OCR_MAORI = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'mi'
+);
+
+export const OCR_MONGOLIAN = createOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_256,
+  RECOGNIZER_CYRILLIC_CRNN_128,
+  'mn'
+);
+export const VERTICAL_OCR_MONGOLIAN = createVerticalOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_64,
+  'mn'
+);
+
+export const OCR_MALAY = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'ms'
+);
+export const VERTICAL_OCR_MALAY = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'ms'
+);
+
+export const OCR_MALTESE = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'mt'
+);
+export const VERTICAL_OCR_MALTESE = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'mt'
+);
+
+export const OCR_DUTCH = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'nl'
+);
+export const VERTICAL_OCR_DUTCH = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'nl'
+);
+
+export const OCR_NORWEGIAN = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'no'
+);
+export const VERTICAL_OCR_NORWEGIAN = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'no'
+);
+
+export const OCR_OCCITAN = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'oc'
+);
+export const VERTICAL_OCR_OCCITAN = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'oc'
+);
+
+export const OCR_PALI = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'pi'
+);
+export const VERTICAL_OCR_PALI = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'pi'
+);
+
+export const OCR_POLISH = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'pl'
+);
+export const VERTICAL_OCR_POLISH = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'pl'
+);
+
+export const OCR_PORTUGUESE = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'pt'
+);
+export const VERTICAL_OCR_PORTUGUESE = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'pt'
+);
+
+export const OCR_ROMANIAN = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'ro'
+);
+export const VERTICAL_OCR_ROMANIAN = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'ro'
+);
+
+export const OCR_RUSSIAN = createOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_256,
+  RECOGNIZER_CYRILLIC_CRNN_128,
+  'ru'
+);
+export const VERTICAL_OCR_RUSSIAN = createVerticalOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_64,
+  'ru'
+);
+
+export const OCR_SERBIAN_CYRILLIC = createOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_256,
+  RECOGNIZER_CYRILLIC_CRNN_128,
+  'rsCyrillic'
+);
+export const VERTICAL_OCR_SERBIAN_CYRILLIC = createVerticalOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_64,
+  'rsCyrillic'
+);
+
+export const OCR_SERBIAN_LATIN = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'rsLatin'
+);
+export const VERTICAL_OCR_SERBIAN_LATIN = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'rsLatin'
+);
+
+export const OCR_SLOVAK = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'sk'
+);
+export const VERTICAL_OCR_SLOVAK = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'sk'
+);
+
+export const OCR_SLOVENIAN = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'sl'
+);
+export const VERTICAL_OCR_SLOVENIAN = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'sl'
+);
+
+export const OCR_ALBANIAN = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'sq'
+);
+export const VERTICAL_OCR_ALBANIAN = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'sq'
+);
+
+export const OCR_SWEDISH = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'sv'
+);
+export const VERTICAL_OCR_SWEDISH = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'sv'
+);
+
+export const OCR_SWAHILI = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'sw'
+);
+export const VERTICAL_OCR_SWAHILI = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'sw'
+);
+
+export const OCR_TABASSARAN = createOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_256,
+  RECOGNIZER_CYRILLIC_CRNN_128,
+  'tab'
+);
+export const VERTICAL_OCR_TABASSARAN = createVerticalOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_64,
+  'tab'
+);
+
+export const OCR_TELUGU = createOCRObject(
+  RECOGNIZER_TELUGU_CRNN_512,
+  RECOGNIZER_TELUGU_CRNN_256,
+  RECOGNIZER_TELUGU_CRNN_128,
+  'te'
+);
+export const VERTICAL_OCR_TELUGU = createVerticalOCRObject(
+  RECOGNIZER_TELUGU_CRNN_512,
+  RECOGNIZER_TELUGU_CRNN_64,
+  'te'
+);
+
+export const OCR_TAJIK = createOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_256,
+  RECOGNIZER_CYRILLIC_CRNN_128,
+  'tjk'
+);
+export const VERTICAL_OCR_TAJIK = createVerticalOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_64,
+  'tjk'
+);
+
+export const OCR_TAGALOG = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'tl'
+);
+export const VERTICAL_OCR_TAGALOG = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'tl'
+);
+
+export const OCR_TURKISH = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'tr'
+);
+export const VERTICAL_OCR_TURKISH = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'tr'
+);
+
+export const OCR_UKRAINIAN = createOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_256,
+  RECOGNIZER_CYRILLIC_CRNN_128,
+  'uk'
+);
+export const VERTICAL_OCR_UKRAINIAN = createVerticalOCRObject(
+  RECOGNIZER_CYRILLIC_CRNN_512,
+  RECOGNIZER_CYRILLIC_CRNN_64,
+  'uk'
+);
+
+export const OCR_UZBEK = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'uz'
+);
+export const VERTICAL_OCR_UZBEK = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'uz'
+);
+
+export const OCR_VIETNAMESE = createOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_256,
+  RECOGNIZER_LATIN_CRNN_128,
+  'vi'
+);
+export const VERTICAL_OCR_VIETNAMESE = createVerticalOCRObject(
+  RECOGNIZER_LATIN_CRNN_512,
+  RECOGNIZER_LATIN_CRNN_64,
+  'vi'
+);
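Note: createOCRObject and createVerticalOCRObject are defined earlier in this file, outside this excerpt. Judging only by how the bundles are consumed in the useOCR/useVerticalOCR hooks further down in this diff, they presumably return plain objects of roughly the following shape; this is a sketch, and the baked-in detector constant is an assumption, not the library's actual implementation:

type ResourceSource = string | number; // stand-in for the library's real type
declare const DETECTOR_CRAFT_800: ResourceSource; // assumed default detector

// Hypothetical sketch of the factory, under the assumptions above.
const createOCRObject = (
  recognizerLarge: ResourceSource,
  recognizerMedium: ResourceSource,
  recognizerSmall: ResourceSource,
  language: string
) => ({
  detectorSource: DETECTOR_CRAFT_800,
  recognizerLarge,
  recognizerMedium,
  recognizerSmall,
  language,
});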
diff --git a/packages/react-native-executorch/src/constants/ocr/symbols.ts b/packages/react-native-executorch/src/constants/ocr/symbols.ts
index 0bbc92ac5c..99471c4709 100644
--- a/packages/react-native-executorch/src/constants/ocr/symbols.ts
+++ b/packages/react-native-executorch/src/constants/ocr/symbols.ts
@@ -19,130 +19,130 @@ export const alphabets = {
 };
 
 export const symbols = {
-  //Abaza
+  // Abaza
   abq: alphabets.cyrillic,
-  //Adyghe
+  // Adyghe
   ady: alphabets.cyrillic,
-  //Africans
+  // Afrikaans
   af: alphabets.latin,
-  //Avar
+  // Avar
   ava: alphabets.cyrillic,
-  //Azerbaijani
+  // Azerbaijani
   az: alphabets.latin,
-  //Belarusian
+  // Belarusian
   be: alphabets.cyrillic,
-  //Bulgarian
+  // Bulgarian
   bg: alphabets.cyrillic,
-  //Bosnian
+  // Bosnian
   bs: alphabets.latin,
-  //Simplified Chinese
+  // Simplified Chinese
   chSim: alphabets.zhSim,
-  //Chechen
+  // Chechen
   che: alphabets.cyrillic,
-  //Czech
+  // Czech
   cs: alphabets.latin,
-  //Welsh
+  // Welsh
   cy: alphabets.latin,
-  //Danish
+  // Danish
   da: alphabets.latin,
-  //Dargwa
+  // Dargwa
   dar: alphabets.cyrillic,
-  //German
+  // German
   de: alphabets.latin,
-  //English
+  // English
   en: alphabets.english,
-  //Spanish
+  // Spanish
   es: alphabets.latin,
-  //Estonian
+  // Estonian
   et: alphabets.latin,
-  //French
+  // French
   fr: alphabets.latin,
-  //Irish
+  // Irish
   ga: alphabets.latin,
-  //Croatian
+  // Croatian
   hr: alphabets.latin,
-  //Hungarian
+  // Hungarian
   hu: alphabets.latin,
-  //Indonesian
+  // Indonesian
   id: alphabets.latin,
-  //Ingush
+  // Ingush
   inh: alphabets.cyrillic,
-  //Icelandic
+  // Icelandic
   ic: alphabets.latin,
-  //Italian
+  // Italian
   it: alphabets.latin,
-  //Japanese
+  // Japanese
   ja: alphabets.japanese,
-  //Karbadian
+  // Karbadian
   kbd: alphabets.cyrillic,
-  //Kannada
+  // Kannada
   kn: alphabets.kannada,
-  //Korean
+  // Korean
   ko: alphabets.korean,
-  //Kurdish
+  // Kurdish
   ku: alphabets.latin,
-  //Latin
+  // Latin
   la: alphabets.latin,
-  //Lak
+  // Lak
   lbe: alphabets.cyrillic,
-  //Lezghian
+  // Lezghian
   lez: alphabets.cyrillic,
-  //Lithuanian
+  // Lithuanian
   lt: alphabets.latin,
-  //Latvian
+  // Latvian
   lv: alphabets.latin,
-  //Maori
+  // Maori
   mi: alphabets.latin,
-  //Mongolian
+  // Mongolian
   mn: alphabets.cyrillic,
-  //Malay
+  // Malay
   ms: alphabets.latin,
-  //Maltese
+  // Maltese
   mt: alphabets.latin,
-  //Dutch
+  // Dutch
   nl: alphabets.latin,
-  //Norwegian
+  // Norwegian
   no: alphabets.latin,
-  //Occitan
+  // Occitan
   oc: alphabets.latin,
-  //Pali
+  // Pali
   pi: alphabets.latin,
-  //Polish
+  // Polish
   pl: alphabets.latin,
-  //Portuguese
+  // Portuguese
   pt: alphabets.latin,
-  //Romanian
+  // Romanian
   ro: alphabets.latin,
-  //Russian
+  // Russian
   ru: alphabets.cyrillic,
-  //Serbian (cyrillic)
+  // Serbian (cyrillic)
   rsCyrillic: alphabets.cyrillic,
-  //Serbian (latin)
+  // Serbian (latin)
   rsLatin: alphabets.latin,
-  //Slovak
+  // Slovak
   sk: alphabets.latin,
-  //Slovenian
+  // Slovenian
   sl: alphabets.latin,
-  //Albanian
+  // Albanian
   sq: alphabets.latin,
-  //Swedish
+  // Swedish
   sv: alphabets.latin,
-  //Swahili
+  // Swahili
   sw: alphabets.latin,
-  //Tabassaran
+  // Tabassaran
   tab: alphabets.cyrillic,
-  //Telugu
+  // Telugu
   te: alphabets.telugu,
-  //Tajik
+  // Tajik
   tjk: alphabets.cyrillic,
-  //Tagalog
+  // Tagalog
   tl: alphabets.latin,
-  //Turkish:
+  // Turkish
   tr: alphabets.latin,
-  //Ukrainian
+  // Ukrainian
   uk: alphabets.cyrillic,
-  //Uzbek
+  // Uzbek
   uz: alphabets.latin,
-  //Vietnamese
+  // Vietnamese
   vi: alphabets.latin,
 };
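The symbols map above ties each language code to a shared per-script alphabet, which is why a single Latin or Cyrillic recognizer can serve dozens of languages. A tiny illustration, using only names from the file above:

// symbols['pl'] and symbols['de'] resolve to the same alphabets.latin charset,
// so both languages decode through the same Latin recognizer.
const charactersFor = (lang: keyof typeof symbols) => symbols[lang];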
diff --git a/packages/react-native-executorch/src/constants/sttDefaults.ts b/packages/react-native-executorch/src/constants/sttDefaults.ts
index 647c9a1468..586040a7ab 100644
--- a/packages/react-native-executorch/src/constants/sttDefaults.ts
+++ b/packages/react-native-executorch/src/constants/sttDefaults.ts
@@ -1,13 +1,7 @@
 import {
-  MOONSHINE_TINY_ENCODER,
-  MOONSHINE_TINY_DECODER,
-  MOONSHINE_TOKENIZER,
-  WHISPER_TINY_ENCODER,
-  WHISPER_TINY_DECODER,
-  WHISPER_TOKENIZER,
-  WHISPER_TINY_MULTILINGUAL_ENCODER,
-  WHISPER_TINY_MULTILINGUAL_DECODER,
-  WHISPER_TINY_MULTILINGUAL_TOKENIZER,
+  MOONSHINE_TINY,
+  WHISPER_TINY,
+  WHISPER_TINY_MULTILINGUAL,
 } from './modelUrls';
 import { AvailableModels, ModelConfig } from '../types/stt';
@@ -17,11 +11,11 @@ export const HAMMING_DIST_THRESHOLD = 1;
 
 const whisperTinyModelConfig = {
   sources: {
-    encoder: WHISPER_TINY_ENCODER,
-    decoder: WHISPER_TINY_DECODER,
+    encoder: WHISPER_TINY.encoderSource,
+    decoder: WHISPER_TINY.decoderSource,
   },
   tokenizer: {
-    source: WHISPER_TOKENIZER,
+    source: WHISPER_TINY.tokenizerSource,
     bos: 50257, // FIXME: this is a placeholder and needs to be changed
     eos: 50256, // FIXME: this is a placeholder and needs to be changed
   },
@@ -30,11 +24,11 @@
 
 const moonshineTinyModelConfig = {
   sources: {
-    encoder: MOONSHINE_TINY_ENCODER,
-    decoder: MOONSHINE_TINY_DECODER,
+    encoder: MOONSHINE_TINY.encoderSource,
+    decoder: MOONSHINE_TINY.decoderSource,
   },
   tokenizer: {
-    source: MOONSHINE_TOKENIZER,
+    source: MOONSHINE_TINY.tokenizerSource,
     bos: 1, // FIXME: this is a placeholder and needs to be changed
     eos: 2, // FIXME: this is a placeholder and needs to be changed
   },
@@ -43,11 +37,11 @@
 
 const whisperTinyMultilingualModelConfig = {
   sources: {
-    encoder: WHISPER_TINY_MULTILINGUAL_ENCODER,
-    decoder: WHISPER_TINY_MULTILINGUAL_DECODER,
+    encoder: WHISPER_TINY_MULTILINGUAL.encoderSource,
+    decoder: WHISPER_TINY_MULTILINGUAL.decoderSource,
   },
   tokenizer: {
-    source: WHISPER_TINY_MULTILINGUAL_TOKENIZER,
+    source: WHISPER_TINY_MULTILINGUAL.tokenizerSource,
     bos: 50258, // FIXME: this is a placeholder and needs to be changed
     eos: 50257, // FIXME: this is a placeholder and needs to be changed
   },
@@ -84,3 +78,5 @@ export enum STREAMING_ACTION {
   DATA,
   STOP,
 }
+
+export { AvailableModels };
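The three config objects above now read their sources from bundled model constants instead of nine separate URL exports. The bundles exported from modelUrls presumably look roughly like the sketch below; the field names are taken from how they are read above, and the URLs are placeholders, not the real ones:

// Sketch only; not the actual contents of constants/modelUrls.ts.
export const WHISPER_TINY = {
  encoderSource: 'https://example.com/whisper_tiny_encoder.pte', // placeholder
  decoderSource: 'https://example.com/whisper_tiny_decoder.pte', // placeholder
  tokenizerSource: 'https://example.com/whisper_tokenizer.json', // placeholder
};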
diff --git a/packages/react-native-executorch/src/controllers/SpeechToTextController.ts b/packages/react-native-executorch/src/controllers/SpeechToTextController.ts
index 3a53074057..3661b9938f 100644
--- a/packages/react-native-executorch/src/controllers/SpeechToTextController.ts
+++ b/packages/react-native-executorch/src/controllers/SpeechToTextController.ts
@@ -100,9 +100,9 @@
     this.config = MODEL_CONFIGS[modelName];
 
     try {
-      const tokenizerLoadPromise = this.tokenizerModule.load(
-        tokenizerSource || this.config.tokenizer.source
-      );
+      const tokenizerLoadPromise = this.tokenizerModule.load({
+        tokenizerSource: tokenizerSource || this.config.tokenizer.source,
+      });
       const pathsPromise = ResourceFetcher.fetch(
         onDownloadProgressCallback,
         encoderSource || this.config.sources.encoder,
@@ -127,7 +127,7 @@
       // create a separate class for multilingual version of Whisper, since it is the same. We just need
      // the distinction here, in TS, for start tokens and such. If we introduce
      // more versions of Whisper, such as the small one, this should be refactored.
-      modelName = 'whisper';
+      modelName = AvailableModels.WHISPER;
    }
 
    try {
diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts b/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts
index 2c4dcec699..06088ec764 100644
--- a/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts
+++ b/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts
@@ -3,16 +3,13 @@ import { useNonStaticModule } from '../useNonStaticModule';
 import { ClassificationModule } from '../../modules/computer_vision/ClassificationModule';
 
 interface Props {
-  modelSource: ResourceSource;
+  model: { modelSource: ResourceSource };
   preventLoad?: boolean;
 }
 
-export const useClassification = ({
-  modelSource,
-  preventLoad = false,
-}: Props) =>
+export const useClassification = ({ model, preventLoad = false }: Props) =>
   useNonStaticModule({
     module: ClassificationModule,
-    loadArgs: [modelSource],
+    model,
     preventLoad: preventLoad,
   });
diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts
index e81de1855c..2fff4d0412 100644
--- a/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts
+++ b/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts
@@ -2,15 +2,14 @@ import { ImageEmbeddingsModule } from '../../modules/computer_vision/ImageEmbedd
 import { ResourceSource } from '../../types/common';
 import { useNonStaticModule } from '../useNonStaticModule';
 
-export const useImageEmbeddings = ({
-  modelSource,
-  preventLoad = false,
-}: {
-  modelSource: ResourceSource;
+interface Props {
+  model: { modelSource: ResourceSource };
   preventLoad?: boolean;
-}) =>
+}
+
+export const useImageEmbeddings = ({ model, preventLoad = false }: Props) =>
   useNonStaticModule({
     module: ImageEmbeddingsModule,
-    loadArgs: [modelSource],
+    model,
     preventLoad,
   });
diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts
index ff2ce941d2..31e26d059e 100644
--- a/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts
+++ b/packages/react-native-executorch/src/hooks/computer_vision/useImageSegmentation.ts
@@ -3,16 +3,13 @@ import { ImageSegmentationModule } from '../../modules/computer_vision/ImageSegm
 import { ResourceSource } from '../../types/common';
 
 interface Props {
-  modelSource: ResourceSource;
+  model: { modelSource: ResourceSource };
   preventLoad?: boolean;
 }
 
-export const useImageSegmentation = ({
-  modelSource,
-  preventLoad = false,
-}: Props) =>
+export const useImageSegmentation = ({ model, preventLoad = false }: Props) =>
   useNonStaticModule({
     module: ImageSegmentationModule,
-    loadArgs: [modelSource],
+    model,
     preventLoad,
   });
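All of the single-source vision hooks above now share the same Props shape: a model object wrapping modelSource, plus an optional preventLoad. A hedged usage sketch; the URL and the screenVisible flag are illustrative only, not part of the library:

declare const screenVisible: boolean; // hypothetical UI state

const segmentation = useImageSegmentation({
  model: { modelSource: 'https://example.com/deeplab_v3.pte' }, // placeholder
  preventLoad: !screenVisible, // defer the download until the screen is shown
});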
diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts b/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts
index e534bca28c..1aa5a2e51c 100644
--- a/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts
+++ b/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts
@@ -12,18 +12,16 @@ interface OCRModule {
 }
 
 export const useOCR = ({
-  detectorSource,
-  recognizerSources,
-  language = 'en',
+  model,
   preventLoad = false,
 }: {
-  detectorSource: ResourceSource;
-  recognizerSources: {
+  model: {
+    detectorSource: ResourceSource;
     recognizerLarge: ResourceSource;
     recognizerMedium: ResourceSource;
     recognizerSmall: ResourceSource;
+    language: OCRLanguage;
   };
-  language?: OCRLanguage;
   preventLoad?: boolean;
 }): OCRModule => {
   const [error, setError] = useState<string | null>(null);
@@ -31,7 +29,7 @@ export const useOCR = ({
   const [isGenerating, setIsGenerating] = useState(false);
   const [downloadProgress, setDownloadProgress] = useState(0);
 
-  const model = useMemo(
+  const controllerInstance = useMemo(
     () =>
       new OCRController({
         modelDownloadProgressCallback: setDownloadProgress,
@@ -44,19 +42,27 @@ export const useOCR = ({
   useEffect(() => {
     const loadModel = async () => {
-      await model.loadModel(detectorSource, recognizerSources, language);
+      await controllerInstance.loadModel(
+        model.detectorSource,
+        {
+          recognizerLarge: model.recognizerLarge,
+          recognizerMedium: model.recognizerMedium,
+          recognizerSmall: model.recognizerSmall,
+        },
+        model.language
+      );
     };
 
     if (!preventLoad) {
       loadModel();
     }
-    // eslint-disable-next-line react-hooks/exhaustive-deps
   }, [
-    model,
-    detectorSource,
-    language,
-    // eslint-disable-next-line react-hooks/exhaustive-deps
-    JSON.stringify(recognizerSources),
+    controllerInstance,
+    model.detectorSource,
+    model.recognizerLarge,
+    model.recognizerMedium,
+    model.recognizerSmall,
+    model.language,
     preventLoad,
   ]);
@@ -64,7 +70,7 @@
     error,
     isReady,
     isGenerating,
-    forward: model.forward,
+    forward: controllerInstance.forward,
     downloadProgress,
   };
 };
diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts b/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts
index 4d674f3e6b..181b770f4a 100644
--- a/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts
+++ b/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts
@@ -3,16 +3,13 @@ import { useNonStaticModule } from '../useNonStaticModule';
 import { ObjectDetectionModule } from '../../modules/computer_vision/ObjectDetectionModule';
 
 interface Props {
-  modelSource: ResourceSource;
+  model: { modelSource: ResourceSource };
   preventLoad?: boolean;
 }
 
-export const useObjectDetection = ({
-  modelSource,
-  preventLoad = false,
-}: Props) =>
+export const useObjectDetection = ({ model, preventLoad = false }: Props) =>
   useNonStaticModule({
     module: ObjectDetectionModule,
-    loadArgs: [modelSource],
+    model,
     preventLoad: preventLoad,
   });
diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts b/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts
index 07a68f4a94..5e70dc9707 100644
--- a/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts
+++ b/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts
@@ -3,13 +3,13 @@ import { useNonStaticModule } from '../useNonStaticModule';
 import { StyleTransferModule } from '../../modules/computer_vision/StyleTransferModule';
 
 interface Props {
-  modelSource: ResourceSource;
+  model: { modelSource: ResourceSource };
   preventLoad?: boolean;
 }
 
-export const useStyleTransfer = ({ modelSource, preventLoad = false }: Props) =>
+export const useStyleTransfer = ({ model, preventLoad = false }: Props) =>
   useNonStaticModule({
     module: StyleTransferModule,
-    loadArgs: [modelSource],
+    model,
     preventLoad: preventLoad,
   });
diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts b/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts
index 7f193c3719..31ea5832d8 100644
--- a/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts
+++ b/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts
@@ -12,21 +12,17 @@ interface OCRModule {
 }
 
 export const useVerticalOCR = ({
-  detectorSources,
-  recognizerSources,
-  language = 'en',
+  model,
   independentCharacters = false,
   preventLoad = false,
 }: {
-  detectorSources: {
+  model: {
     detectorLarge: ResourceSource;
     detectorNarrow: ResourceSource;
-  };
-  recognizerSources: {
     recognizerLarge: ResourceSource;
     recognizerSmall: ResourceSource;
+    language: OCRLanguage;
   };
-  language?: OCRLanguage;
   independentCharacters?: boolean;
   preventLoad?: boolean;
 }): OCRModule => {
@@ -35,7 +31,7 @@ export const useVerticalOCR = ({
   const [isGenerating, setIsGenerating] = useState(false);
   const [downloadProgress, setDownloadProgress] = useState(0);
 
-  const model = useMemo(
+  const controllerInstance = useMemo(
     () =>
       new VerticalOCRController({
         modelDownloadProgressCallback: setDownloadProgress,
@@ -47,27 +43,30 @@ export const useVerticalOCR = ({
   );
 
   useEffect(() => {
-    const loadModel = async () => {
-      await model.loadModel(
-        detectorSources,
-        recognizerSources,
-        language,
+    if (preventLoad) return;
+
+    (async () => {
+      await controllerInstance.loadModel(
+        {
+          detectorLarge: model.detectorLarge,
+          detectorNarrow: model.detectorNarrow,
+        },
+        {
+          recognizerLarge: model.recognizerLarge,
+          recognizerSmall: model.recognizerSmall,
+        },
+        model.language,
         independentCharacters
       );
-    };
-
-    if (!preventLoad) {
-      loadModel();
-    }
-    // eslint-disable-next-line react-hooks/exhaustive-deps
+    })();
  }, [
-    model,
-    // eslint-disable-next-line react-hooks/exhaustive-deps
-    JSON.stringify(detectorSources),
-    language,
+    controllerInstance,
+    model.detectorLarge,
+    model.detectorNarrow,
+    model.recognizerLarge,
+    model.recognizerSmall,
+    model.language,
     independentCharacters,
-    // eslint-disable-next-line react-hooks/exhaustive-deps
-    JSON.stringify(recognizerSources),
     preventLoad,
   ]);
@@ -75,7 +74,7 @@
     error,
     isReady,
     isGenerating,
-    forward: model.forward,
+    forward: controllerInstance.forward,
     downloadProgress,
   };
 };
diff --git a/packages/react-native-executorch/src/hooks/general/useExecutorchModule.ts b/packages/react-native-executorch/src/hooks/general/useExecutorchModule.ts
index 42826a07a5..72cd1999f5 100644
--- a/packages/react-native-executorch/src/hooks/general/useExecutorchModule.ts
+++ b/packages/react-native-executorch/src/hooks/general/useExecutorchModule.ts
@@ -13,6 +13,6 @@ export const useExecutorchModule = ({
 }: Props) =>
   useNonStaticModule({
     module: ExecutorchModule,
-    loadArgs: [modelSource],
+    model: { modelSource },
     preventLoad,
   });
diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts
index c8522fa1e6..967d9913ee 100644
--- a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts
+++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts
@@ -13,14 +13,14 @@ import { LLMController } from '../../controllers/LLMController';
 
 /*
   Hook version of LLMModule
 */
 export const useLLM = ({
-  modelSource,
-  tokenizerSource,
-  tokenizerConfigSource,
+  model,
   preventLoad = false,
 }: {
-  modelSource: ResourceSource;
-  tokenizerSource: ResourceSource;
-  tokenizerConfigSource: ResourceSource;
+  model: {
+    modelSource: ResourceSource;
+    tokenizerSource: ResourceSource;
+    tokenizerConfigSource: ResourceSource;
+  };
   preventLoad?: boolean;
 }): LLMType => {
   const [token, setToken] = useState('');
@@ -36,7 +36,7 @@ export const useLLM = ({
     setResponse((prevResponse) => prevResponse + newToken);
   }, []);
 
-  const model = useMemo(
+  const controllerInstance = useMemo(
     () =>
       new LLMController({
         tokenCallback: tokenCallback,
@@ -55,10 +55,10 @@ export const useLLM = ({
 
     (async () => {
       try {
-        await model.load({
-          modelSource,
-          tokenizerSource,
-          tokenizerConfigSource,
+        await controllerInstance.load({
+          modelSource: model.modelSource,
+          tokenizerSource: model.tokenizerSource,
+          tokenizerConfigSource: model.tokenizerConfigSource,
           onDownloadProgressCallback: setDownloadProgress,
         });
       } catch (e) {
@@ -67,9 +67,15 @@ export const useLLM = ({
     })();
 
     return () => {
-      model.delete();
+      controllerInstance.delete();
     };
-  }, [modelSource, tokenizerSource, tokenizerConfigSource, preventLoad, model]);
+  }, [
+    controllerInstance,
+    model.modelSource,
+    model.tokenizerSource,
+    model.tokenizerConfigSource,
+    preventLoad,
+  ]);
 
   // memoization of returned functions
   const configure = useCallback(
@@ -79,31 +85,35 @@ export const useLLM = ({
     }: {
       chatConfig?: Partial<ChatConfig>;
       toolsConfig?: ToolsConfig;
-    }) => model.configure({ chatConfig, toolsConfig }),
-    [model]
+    }) => controllerInstance.configure({ chatConfig, toolsConfig }),
+    [controllerInstance]
   );
 
   const generate = useCallback(
     (messages: Message[], tools?: LLMTool[]) => {
       setResponse('');
-      return model.generate(messages, tools);
+      return controllerInstance.generate(messages, tools);
     },
-    [model]
   );
 
   const sendMessage = useCallback(
     (message: string) => {
       setResponse('');
-      return model.sendMessage(message);
+      return controllerInstance.sendMessage(message);
     },
-    [model]
+    [controllerInstance]
   );
 
   const deleteMessage = useCallback(
-    (index: number) => model.deleteMessage(index),
-    [model]
+    (index: number) => controllerInstance.deleteMessage(index),
+    [controllerInstance]
+  );
+
+  const interrupt = useCallback(
+    () => controllerInstance.interrupt(),
+    [controllerInstance]
   );
-  const interrupt = useCallback(() => model.interrupt(), [model]);
 
   return {
     messageHistory,
diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts
index e22b055b91..d1549bf70f 100644
--- a/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts
+++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts
@@ -23,19 +23,18 @@ interface SpeechToTextModule {
 }
 
 export const useSpeechToText = ({
-  modelName,
-  encoderSource,
-  decoderSource,
-  tokenizerSource,
+  model,
   overlapSeconds,
   windowSize,
   streamingConfig,
   preventLoad = false,
 }: {
-  modelName: AvailableModels;
-  encoderSource?: ResourceSource;
-  decoderSource?: ResourceSource;
-  tokenizerSource?: ResourceSource;
+  model: {
+    modelName: AvailableModels;
+    encoderSource: ResourceSource;
+    decoderSource: ResourceSource;
+    tokenizerSource: ResourceSource;
+  };
   overlapSeconds?: ConstructorParameters<
     typeof SpeechToTextController
   >['0']['overlapSeconds'];
@@ -53,7 +52,7 @@ export const useSpeechToText = ({
   const [isGenerating, setIsGenerating] = useState(false);
   const [error, setError] = useState<string>();
 
-  const model = useMemo(
+  const controllerInstance = useMemo(
     () =>
       new SpeechToTextController({
         transcribeCallback: setSequence,
@@ -65,16 +64,20 @@
   );
 
   useEffect(() => {
-    model.configureStreaming(overlapSeconds, windowSize, streamingConfig);
-  }, [model, overlapSeconds, windowSize, streamingConfig]);
+    controllerInstance.configureStreaming(
+      overlapSeconds,
+      windowSize,
+      streamingConfig
+    );
+  }, [controllerInstance, overlapSeconds, windowSize, streamingConfig]);
 
   useEffect(() => {
     const loadModel = async () => {
-      await model.load({
-        modelName,
-        encoderSource,
-        decoderSource,
-        tokenizerSource,
+      await controllerInstance.load({
+        modelName: model.modelName,
+        encoderSource: model.encoderSource,
+        decoderSource: model.decoderSource,
+        tokenizerSource: model.tokenizerSource,
         onDownloadProgressCallback: setDownloadProgress,
       });
     };
@@ -82,11 +85,11 @@
       loadModel();
     }
   }, [
-    model,
-    modelName,
-    encoderSource,
-    decoderSource,
-    tokenizerSource,
+    controllerInstance,
+    model.modelName,
+    model.encoderSource,
+    model.decoderSource,
+    model.tokenizerSource,
     preventLoad,
   ]);
@@ -94,15 +97,20 @@
     isReady,
     isGenerating,
     downloadProgress,
-    configureStreaming: model.configureStreaming,
+    configureStreaming: controllerInstance.configureStreaming,
     sequence,
     error,
     transcribe: (waveform: number[], audioLanguage?: SpeechToTextLanguage) =>
-      model.transcribe(waveform, audioLanguage),
+      controllerInstance.transcribe(waveform, audioLanguage),
     streamingTranscribe: (
       streamAction: STREAMING_ACTION,
       waveform?: number[],
      audioLanguage?: SpeechToTextLanguage
-    ) => model.streamingTranscribe(streamAction, waveform, audioLanguage),
+    ) =>
+      controllerInstance.streamingTranscribe(
+        streamAction,
+        waveform,
+        audioLanguage
+      ),
   };
 };
diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts
index 058ed77eb9..18a7e1e453 100644
--- a/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts
+++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts
@@ -3,18 +3,16 @@ import { ResourceSource } from '../../types/common';
 import { useNonStaticModule } from '../useNonStaticModule';
 
 interface Props {
-  modelSource: ResourceSource;
-  tokenizerSource: ResourceSource;
+  model: {
+    modelSource: ResourceSource;
+    tokenizerSource: ResourceSource;
+  };
   preventLoad?: boolean;
 }
 
-export const useTextEmbeddings = ({
-  modelSource,
-  tokenizerSource,
-  preventLoad = false,
-}: Props) =>
+export const useTextEmbeddings = ({ model, preventLoad = false }: Props) =>
   useNonStaticModule({
     module: TextEmbeddingsModule,
-    loadArgs: [modelSource, tokenizerSource],
+    model,
     preventLoad,
   });
diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useTokenizer.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useTokenizer.ts
index 9748d0fafe..3c01f81fdb 100644
--- a/packages/react-native-executorch/src/hooks/natural_language_processing/useTokenizer.ts
+++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useTokenizer.ts
@@ -4,17 +4,17 @@ import { ResourceSource } from '../../types/common';
 import { ETError, getError } from '../../Error';
 
 export const useTokenizer = ({
-  tokenizerSource,
+  tokenizer,
   preventLoad = false,
 }: {
-  tokenizerSource: ResourceSource;
+  tokenizer: { tokenizerSource: ResourceSource };
   preventLoad?: boolean;
 }) => {
   const [error, setError] = useState<string | null>(null);
   const [isReady, setIsReady] = useState(false);
   const [isGenerating, setIsGenerating] = useState(false);
   const [downloadProgress, setDownloadProgress] = useState(0);
-  const model = useMemo(() => new TokenizerModule(), []);
+  const _tokenizer = useMemo(() => new TokenizerModule(), []);
 
   useEffect(() => {
     if (preventLoad) return;
@@ -23,13 +23,16 @@ export const useTokenizer = ({
       setError(null);
       try {
         setIsReady(false);
-        await model.load(tokenizerSource, setDownloadProgress);
+        await _tokenizer.load(
+          { tokenizerSource: tokenizer.tokenizerSource },
+          setDownloadProgress
+        );
         setIsReady(true);
       } catch (err) {
         setError((err as Error).message);
       }
     })();
-  }, [model, tokenizerSource, preventLoad]);
+  }, [_tokenizer, tokenizer.tokenizerSource, preventLoad]);
 
   const stateWrapper = <T extends (...args: any[]) => Promise<any>>(fn: T) => {
     return (...args: Parameters<T>): Promise<ReturnType<T>> => {
@@ -37,7 +40,7 @@
       if (isGenerating) throw new Error(getError(ETError.ModelGenerating));
       try {
         setIsGenerating(true);
-        return fn.apply(model, args);
+        return fn.apply(_tokenizer, args);
       } finally {
         setIsGenerating(false);
       }
diff --git a/packages/react-native-executorch/src/hooks/useNonStaticModule.ts b/packages/react-native-executorch/src/hooks/useNonStaticModule.ts
index 04a1c8e2d9..fe0e9860ca 100644
--- a/packages/react-native-executorch/src/hooks/useNonStaticModule.ts
+++ b/packages/react-native-executorch/src/hooks/useNonStaticModule.ts
@@ -18,47 +18,47 @@ export const useNonStaticModule = <
   ForwardReturn extends Awaited<ReturnType<ModuleInstance['forward']>>,
 >({
   module,
-  loadArgs,
+  model,
   preventLoad = false,
 }: {
   module: ModuleConstructor;
-  loadArgs: LoadArgs;
+  model: LoadArgs[0];
   preventLoad?: boolean;
 }) => {
   const [error, setError] = useState<string | null>(null);
   const [isReady, setIsReady] = useState(false);
   const [isGenerating, setIsGenerating] = useState(false);
   const [downloadProgress, setDownloadProgress] = useState(0);
-  const model = useMemo(() => new module(), [module]);
+  const moduleInstance = useMemo(() => new module(), [module]);
 
   useEffect(() => {
-    if (!preventLoad) {
-      (async () => {
-        setDownloadProgress(0);
-        setError(null);
-        try {
-          setIsReady(false);
-          await model.load(...loadArgs, setDownloadProgress);
-          setIsReady(true);
-        } catch (err) {
-          setError((err as Error).message);
-        }
-      })();
+    if (preventLoad) return;
+
+    (async () => {
+      setDownloadProgress(0);
+      setError(null);
+      try {
+        setIsReady(false);
+        await moduleInstance.load(model, setDownloadProgress);
+        setIsReady(true);
+      } catch (err) {
+        setError((err as Error).message);
+      }
+    })();
+
+    return () => {
+      moduleInstance.delete();
+    };
 
-      return () => {
-        model.delete();
-      };
-    }
-    return () => {};
     // eslint-disable-next-line react-hooks/exhaustive-deps
-  }, [...loadArgs, preventLoad]);
+  }, [moduleInstance, ...Object.values(model), preventLoad]);
 
   const forward = async (...input: ForwardArgs): Promise<ForwardReturn> => {
     if (!isReady) throw new Error(getError(ETError.ModuleNotLoaded));
     if (isGenerating) throw new Error(getError(ETError.ModelGenerating));
     try {
       setIsGenerating(true);
-      return await model.forward(...input);
+      return await moduleInstance.forward(...input);
     } finally {
       setIsGenerating(false);
     }
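A note on the reworked dependency list: spreading Object.values(model) into the deps replaces the earlier JSON.stringify workaround, so the effect reruns only when one of the model's source values changes, not when the wrapping object's identity changes. This assumes the model object keeps the same set of keys between renders, which holds for the fixed Props shapes above. A sketch, with a placeholder URL:

const MODEL_URL = 'https://example.com/style_transfer.pte'; // placeholder

function useCandyStyle() {
  // A fresh `model` literal is created on every render, but the load effect
  // depends on the values inside it, so the model is not reloaded.
  return useStyleTransfer({ model: { modelSource: MODEL_URL } });
}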
diff --git a/packages/react-native-executorch/src/index.tsx b/packages/react-native-executorch/src/index.tsx
index 2f38868e1d..48471df8e8 100644
--- a/packages/react-native-executorch/src/index.tsx
+++ b/packages/react-native-executorch/src/index.tsx
@@ -86,4 +86,8 @@ export { SpeechToTextLanguage };
 export * from './constants/modelUrls';
 export * from './constants/ocr/models';
 export * from './constants/llmDefaults';
-export { STREAMING_ACTION, MODES } from './constants/sttDefaults';
+export {
+  STREAMING_ACTION,
+  MODES,
+  AvailableModels,
+} from './constants/sttDefaults';
diff --git a/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts
index edbb36ba95..c7e23b3e52 100644
--- a/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts
+++ b/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts
@@ -5,12 +5,12 @@ import { BaseNonStaticModule } from '../BaseNonStaticModule';
 
 export class ClassificationModule extends BaseNonStaticModule {
   async load(
-    modelSource: ResourceSource,
-    onDownloadProgressCallback: (_: number) => void = () => {}
+    model: { modelSource: ResourceSource },
+    onDownloadProgressCallback: (progress: number) => void = () => {}
   ): Promise<void> {
     const paths = await ResourceFetcher.fetch(
       onDownloadProgressCallback,
-      modelSource
+      model.modelSource
     );
     if (paths === null || paths.length < 1) {
       throw new Error('Download interrupted.');
diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts
index 3ff8b2a5a0..133f31707d 100644
--- a/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts
+++ b/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts
@@ -5,12 +5,12 @@ export class ImageEmbeddingsModule extends BaseNonStaticModule {
   async load(
-    modelSource: ResourceSource,
-    onDownloadProgressCallback: (_: number) => void = () => {}
+    model: { modelSource: ResourceSource },
+    onDownloadProgressCallback: (progress: number) => void = () => {}
   ): Promise<void> {
     const paths = await ResourceFetcher.fetch(
       onDownloadProgressCallback,
-      modelSource
+      model.modelSource
     );
     if (paths === null || paths.length < 1) {
       throw new Error('Download interrupted.');
diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts
index 733219fed4..a7d2b5d93c 100644
--- a/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts
+++ b/packages/react-native-executorch/src/modules/computer_vision/ImageSegmentationModule.ts
@@ -6,12 +6,12 @@ export class ImageSegmentationModule extends BaseNonStaticModule {
   async load(
-    modelSource: ResourceSource,
-    onDownloadProgressCallback: (_: number) => void = () => {}
+    model: { modelSource: ResourceSource },
+    onDownloadProgressCallback: (progress: number) => void = () => {}
   ): Promise<void> {
     const paths = await ResourceFetcher.fetch(
       onDownloadProgressCallback,
-      modelSource
+      model.modelSource
     );
     if (paths === null || paths.length < 1) {
       throw new Error('Download interrupted.');
diff --git a/packages/react-native-executorch/src/modules/computer_vision/OCRModule.ts b/packages/react-native-executorch/src/modules/computer_vision/OCRModule.ts
index c7a28ef622..e7ffd4cd9a 100644
--- a/packages/react-native-executorch/src/modules/computer_vision/OCRModule.ts
+++ b/packages/react-native-executorch/src/modules/computer_vision/OCRModule.ts
@@ -5,29 +5,32 @@ import { OCRLanguage } from '../../types/ocr';
 
 export class OCRModule {
   static module: OCRController;
-  static onDownloadProgressCallback = (_downloadProgress: number) => {};
-
   static async load(
-    detectorSource: ResourceSource,
-    recognizerSources: {
+    model: {
+      detectorSource: ResourceSource;
       recognizerLarge: ResourceSource;
       recognizerMedium: ResourceSource;
       recognizerSmall: ResourceSource;
+      language: OCRLanguage;
     },
-    language: OCRLanguage = 'en'
+    onDownloadProgressCallback: (progress: number) => void = () => {}
   ) {
     this.module = new OCRController({
-      modelDownloadProgressCallback: this.onDownloadProgressCallback,
+      modelDownloadProgressCallback: onDownloadProgressCallback,
     });
-    await this.module.loadModel(detectorSource, recognizerSources, language);
+    await this.module.loadModel(
+      model.detectorSource,
+      {
+        recognizerLarge: model.recognizerLarge,
+        recognizerMedium: model.recognizerMedium,
+        recognizerSmall: model.recognizerSmall,
+      },
+      model.language
+    );
   }
 
   static async forward(input: string) {
     return await this.module.forward(input);
   }
-
-  static onDownloadProgress(callback: (downloadProgress: number) => void) {
-    this.onDownloadProgressCallback = callback;
-  }
 }
diff --git a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts
index 83e66b4da6..1fc16c4924 100644
--- a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts
+++ b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts
@@ -6,12 +6,12 @@ export class ObjectDetectionModule extends BaseNonStaticModule {
   async load(
-    modelSource: ResourceSource,
-    onDownloadProgressCallback: (_: number) => void = () => {}
+    model: { modelSource: ResourceSource },
+    onDownloadProgressCallback: (progress: number) => void = () => {}
   ): Promise<void> {
     const paths = await ResourceFetcher.fetch(
       onDownloadProgressCallback,
-      modelSource
+      model.modelSource
     );
     if (paths === null || paths.length < 1) {
       throw new Error('Download interrupted.');
diff --git a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts
index 4a8ae94965..2969194a43 100644
--- a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts
+++ b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts
@@ -5,12 +5,12 @@ export class StyleTransferModule extends BaseNonStaticModule {
   async load(
-    modelSource: ResourceSource,
-    onDownloadProgressCallback: (_: number) => void = () => {}
+    model: { modelSource: ResourceSource },
+    onDownloadProgressCallback: (progress: number) => void = () => {}
   ): Promise<void> {
     const paths = await ResourceFetcher.fetch(
       onDownloadProgressCallback,
-      modelSource
+      model.modelSource
     );
     if (paths === null || paths.length < 1) {
       throw new Error('Download interrupted.');
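With the static OCRModule, the progress callback now rides along with load instead of a separate onDownloadProgress registration. A sketch of the new call shape; OCR_ENGLISH is the bundle defined in constants/ocr/models.ts earlier in this diff, and imageUri is a hypothetical image path variable:

declare const imageUri: string; // hypothetical input image

await OCRModule.load(OCR_ENGLISH, (progress) => console.log(progress));
const detections = await OCRModule.forward(imageUri);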
diff --git a/packages/react-native-executorch/src/modules/computer_vision/VerticalOCRModule.ts b/packages/react-native-executorch/src/modules/computer_vision/VerticalOCRModule.ts
index 4c8b1120de..aaa4e83a10 100644
--- a/packages/react-native-executorch/src/modules/computer_vision/VerticalOCRModule.ts
+++ b/packages/react-native-executorch/src/modules/computer_vision/VerticalOCRModule.ts
@@ -5,28 +5,31 @@ import { OCRLanguage } from '../../types/ocr';
 
 export class VerticalOCRModule {
   static module: VerticalOCRController;
-  static onDownloadProgressCallback = (_downloadProgress: number) => {};
-
   static async load(
-    detectorSources: {
+    model: {
       detectorLarge: ResourceSource;
       detectorNarrow: ResourceSource;
-    },
-    recognizerSources: {
       recognizerLarge: ResourceSource;
       recognizerSmall: ResourceSource;
+      language: OCRLanguage;
     },
-    language: OCRLanguage = 'en',
-    independentCharacters: boolean = false
+    independentCharacters: boolean,
+    onDownloadProgressCallback: (progress: number) => void = () => {}
   ) {
     this.module = new VerticalOCRController({
-      modelDownloadProgressCallback: this.onDownloadProgressCallback,
+      modelDownloadProgressCallback: onDownloadProgressCallback,
     });
     await this.module.loadModel(
-      detectorSources,
-      recognizerSources,
-      language,
+      {
+        detectorLarge: model.detectorLarge,
+        detectorNarrow: model.detectorNarrow,
+      },
+      {
+        recognizerLarge: model.recognizerLarge,
+        recognizerSmall: model.recognizerSmall,
+      },
+      model.language,
       independentCharacters
     );
   }
@@ -34,8 +37,4 @@ export class VerticalOCRModule {
   static async forward(input: string) {
     return await this.module.forward(input);
   }
-
-  static onDownloadProgress(callback: (downloadProgress: number) => void) {
-    this.onDownloadProgressCallback = callback;
-  }
 }
diff --git a/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts b/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts
index 8c8e7ec0a7..6ae6337538 100644
--- a/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts
+++ b/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts
@@ -6,7 +6,7 @@ import { ResourceFetcher } from '../../utils/ResourceFetcher';
 export class ExecutorchModule extends BaseNonStaticModule {
   async load(
     modelSource: ResourceSource,
-    onDownloadProgressCallback: (_: number) => void = () => {}
+    onDownloadProgressCallback: (progress: number) => void = () => {}
   ): Promise<void> {
     const paths = await ResourceFetcher.fetch(
       onDownloadProgressCallback,
diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts
index 520f07591d..2993308078 100644
--- a/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts
+++ b/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts
@@ -21,21 +21,16 @@
     });
   }
 
-  async load({
-    modelSource,
-    tokenizerSource,
-    tokenizerConfigSource,
-    onDownloadProgressCallback,
-  }: {
-    modelSource: ResourceSource;
-    tokenizerSource: ResourceSource;
-    tokenizerConfigSource: ResourceSource;
-    onDownloadProgressCallback?: (_downloadProgress: number) => void;
-  }) {
+  async load(
+    model: {
+      modelSource: ResourceSource;
+      tokenizerSource: ResourceSource;
+      tokenizerConfigSource: ResourceSource;
+    },
+    onDownloadProgressCallback: (progress: number) => void = () => {}
+  ) {
     await this.controller.load({
-      modelSource,
-      tokenizerSource,
-      tokenizerConfigSource,
+      ...model,
       onDownloadProgressCallback,
     });
   }
diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts
index 4f5ebf9a08..6bda902e38 100644
--- a/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts
+++ b/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts
@@ -31,24 +31,20 @@
     });
   }
 
-  async load({
-    modelName,
-    encoderSource,
-    decoderSource,
-    tokenizerSource,
-    onDownloadProgressCallback,
-  }: {
-    modelName: AvailableModels;
-    encoderSource?: ResourceSource;
-    decoderSource?: ResourceSource;
-    tokenizerSource?: ResourceSource;
-    onDownloadProgressCallback?: (downloadProgress: number) => void;
-  }) {
+  async load(
+    model: {
+      modelName: AvailableModels;
+      encoderSource?: ResourceSource;
+      decoderSource?: ResourceSource;
+      tokenizerSource?: ResourceSource;
+    },
+    onDownloadProgressCallback: (progress: number) => void = () => {}
+  ) {
     await this.module.load({
-      modelName,
-      encoderSource,
-      decoderSource,
-      tokenizerSource,
+      modelName: model.modelName,
+      encoderSource: model.encoderSource,
+      decoderSource: model.decoderSource,
+      tokenizerSource: model.tokenizerSource,
       onDownloadProgressCallback,
     });
   }
diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts
index c883de3ed5..db19ecbb66 100644
--- a/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts
+++ b/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts
@@ -4,15 +4,17 @@ import { BaseNonStaticModule } from '../BaseNonStaticModule';
 
 export class TextEmbeddingsModule extends BaseNonStaticModule {
   async load(
-    modelSource: ResourceSource,
-    tokenizerSource: ResourceSource,
-    onDownloadProgressCallback: (_: number) => void = () => {}
+    model: { modelSource: ResourceSource; tokenizerSource: ResourceSource },
+    onDownloadProgressCallback: (progress: number) => void = () => {}
   ): Promise<void> {
     const modelPromise = ResourceFetcher.fetch(
       onDownloadProgressCallback,
-      modelSource
+      model.modelSource
+    );
+    const tokenizerPromise = ResourceFetcher.fetch(
+      undefined,
+      model.tokenizerSource
     );
-    const tokenizerPromise = ResourceFetcher.fetch(undefined, tokenizerSource);
     const [modelResult, tokenizerResult] = await Promise.all([
       modelPromise,
       tokenizerPromise,
diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts
index 87e1965a37..47f3dd2f91 100644
--- a/packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts
+++ b/packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts
@@ -5,12 +5,12 @@ export class TokenizerModule {
   nativeModule: any;
 
   async load(
-    modelSource: ResourceSource,
-    onDownloadProgressCallback: (_: number) => void = () => {}
+    tokenizer: { tokenizerSource: ResourceSource },
+    onDownloadProgressCallback: (progress: number) => void = () => {}
   ): Promise<void> {
     const paths = await ResourceFetcher.fetch(
       onDownloadProgressCallback,
-      modelSource
+      tokenizer.tokenizerSource
     );
     const path = paths?.[0];
     if (!path) {
diff --git a/packages/react-native-executorch/src/types/stt.ts b/packages/react-native-executorch/src/types/stt.ts
index 783316e857..a8074ff786 100644
--- a/packages/react-native-executorch/src/types/stt.ts
+++ b/packages/react-native-executorch/src/types/stt.ts
@@ -90,4 +90,8 @@ export enum SpeechToTextLanguage {
   Yiddish = 'yi',
 }
 
-export type AvailableModels = 'whisper' | 'moonshine' | 'whisperMultilingual';
+export enum AvailableModels {
+  WHISPER = 'whisper',
+  MOONSHINE = 'moonshine',
+  WHISPER_MULTILINGUAL = 'whisperMultilingual',
+}
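Since the enum members keep the old string values, lookups keyed by model name keep working unchanged. A sketch; MODEL_CONFIGS is the map consulted in SpeechToTextController above, declared here only to keep the snippet self-contained:

declare const MODEL_CONFIGS: Record<AvailableModels, unknown>; // stand-in

const name = AvailableModels.WHISPER_MULTILINGUAL; // runtime value: 'whisperMultilingual'
const config = MODEL_CONFIGS[name]; // same lookup as with the old union type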