From a63d4e1bad052273e48d0ce3c713048f9f9a921d Mon Sep 17 00:00:00 2001 From: jakmro Date: Sun, 23 Feb 2025 13:51:05 +0100 Subject: [PATCH 1/3] Add LLM messageHistory --- .../main/java/com/swmansion/rnexecutorch/LLM.kt | 15 ++++++++------- .../swmansion/rnexecutorch/utils/ArrayUtils.kt | 9 +++++++++ .../utils/llms/ConversationManager.kt | 17 +++++++++++++---- ios/RnExecutorch/LLM.mm | 4 +++- .../utils/llms/ConversationManager.h | 3 ++- .../utils/llms/ConversationManager.mm | 13 ++++++++++++- src/constants/llamaDefaults.ts | 4 ++++ src/hooks/natural_language_processing/useLLM.ts | 14 ++++++++++++-- .../natural_language_processing/LLMModule.ts | 3 +++ src/native/NativeLLM.ts | 2 ++ src/types/common.ts | 5 +++++ 11 files changed, 73 insertions(+), 16 deletions(-) diff --git a/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt b/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt index 4ca13ca51c..1626658e4a 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt @@ -3,15 +3,16 @@ package com.swmansion.rnexecutorch import android.util.Log import com.facebook.react.bridge.Promise import com.facebook.react.bridge.ReactApplicationContext +import com.facebook.react.bridge.ReadableArray import com.swmansion.rnexecutorch.utils.llms.ChatRole import com.swmansion.rnexecutorch.utils.llms.ConversationManager import com.swmansion.rnexecutorch.utils.llms.END_OF_TEXT_TOKEN import org.pytorch.executorch.LlamaCallback import org.pytorch.executorch.LlamaModule +import com.swmansion.rnexecutorch.utils.ArrayUtils import java.net.URL -class LLM(reactContext: ReactApplicationContext) : - NativeLLMSpec(reactContext), LlamaCallback { +class LLM(reactContext: ReactApplicationContext) : NativeLLMSpec(reactContext), LlamaCallback { private var llamaModule: LlamaModule? = null private var tempLlamaResponse = StringBuilder() @@ -38,11 +39,14 @@ class LLM(reactContext: ReactApplicationContext) : modelSource: String, tokenizerSource: String, systemPrompt: String, + messageHistory: ReadableArray, contextWindowLength: Double, promise: Promise ) { try { - this.conversationManager = ConversationManager(contextWindowLength.toInt(), systemPrompt) + this.conversationManager = ConversationManager( + contextWindowLength.toInt(), systemPrompt, ArrayUtils.createMapArray(messageHistory) + ) llamaModule = LlamaModule(1, URL(modelSource).path, URL(tokenizerSource).path, 0.7f) this.tempLlamaResponse.clear() promise.resolve("Model loaded successfully") @@ -51,10 +55,7 @@ class LLM(reactContext: ReactApplicationContext) : } } - override fun runInference( - input: String, - promise: Promise - ) { + override fun runInference(input: String, promise: Promise) { this.conversationManager.addResponse(input, ChatRole.USER) val conversation = this.conversationManager.getConversation() diff --git a/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt index 10ffe29eb8..8fa9bbf965 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt @@ -2,6 +2,7 @@ package com.swmansion.rnexecutorch.utils import com.facebook.react.bridge.Arguments import com.facebook.react.bridge.ReadableArray +import com.facebook.react.bridge.ReadableMap import org.pytorch.executorch.DType import org.pytorch.executorch.Tensor @@ -47,6 +48,14 @@ class ArrayUtils { return doubleArray } + fun createMapArray(input: ReadableArray): Array> { + val mapArray = Array>(input.size()) { mapOf() } + for (i in 0 until input.size()) { + mapArray[i] = input.getMap(i).toHashMap() as Map + } + return mapArray + } + fun createReadableArray(result: Tensor): ReadableArray { val resultArray = Arguments.createArray() when (result.dtype()) { diff --git a/android/src/main/java/com/swmansion/rnexecutorch/utils/llms/ConversationManager.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/llms/ConversationManager.kt index 2f59353090..78654b295a 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/utils/llms/ConversationManager.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/utils/llms/ConversationManager.kt @@ -11,7 +11,11 @@ const val END_OF_TEXT_TOKEN = "<|eot_id|>" const val START_HEADER_ID_TOKEN = "<|start_header_id|>" const val END_HEADER_ID_TOKEN = "<|end_header_id|>" -class ConversationManager(private val numMessagesContextWindow: Int, systemPrompt: String) { +class ConversationManager( + private val numMessagesContextWindow: Int, + systemPrompt: String, + messageHistory: Array> +) { private val basePrompt: String; private val messages = ArrayDeque(); @@ -22,6 +26,13 @@ class ConversationManager(private val numMessagesContextWindow: Int, systemPromp systemPrompt + END_OF_TEXT_TOKEN + getHeaderTokenFromRole(ChatRole.USER) + + messageHistory.forEach { message -> + when (message["role"]) { + "user" -> addResponse(message["content"]!!, ChatRole.USER) + "assistant" -> addResponse(message["content"]!!, ChatRole.ASSISTANT) + } + } } fun addResponse(text: String, senderRole: ChatRole) { @@ -31,9 +42,7 @@ class ConversationManager(private val numMessagesContextWindow: Int, systemPromp val formattedMessage: String = if (senderRole == ChatRole.ASSISTANT) { text + getHeaderTokenFromRole(ChatRole.USER) } else { - text + - END_OF_TEXT_TOKEN + - getHeaderTokenFromRole(ChatRole.ASSISTANT) + text + END_OF_TEXT_TOKEN + getHeaderTokenFromRole(ChatRole.ASSISTANT) } this.messages.add(formattedMessage) } diff --git a/ios/RnExecutorch/LLM.mm b/ios/RnExecutorch/LLM.mm index 462c114279..34d40c285f 100644 --- a/ios/RnExecutorch/LLM.mm +++ b/ios/RnExecutorch/LLM.mm @@ -43,6 +43,7 @@ - (void)onResult:(NSString *)token prompt:(NSString *)prompt { - (void)loadLLM:(NSString *)modelSource tokenizerSource:(NSString *)tokenizerSource systemPrompt:(NSString *)systemPrompt + messageHistory:(NSArray *)messageHistory contextWindowLength:(double)contextWindowLength resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject { @@ -55,7 +56,8 @@ - (void)loadLLM:(NSString *)modelSource self->conversationManager = [[ConversationManager alloc] initWithNumMessagesContextWindow:contextWindowLengthUInt - systemPrompt:systemPrompt]; + systemPrompt:systemPrompt + messageHistory:messageHistory]; self->tempLlamaResponse = [NSMutableString string]; resolve(@"Model and tokenizer loaded successfully"); diff --git a/ios/RnExecutorch/utils/llms/ConversationManager.h b/ios/RnExecutorch/utils/llms/ConversationManager.h index 20059a595b..0571eb5db6 100644 --- a/ios/RnExecutorch/utils/llms/ConversationManager.h +++ b/ios/RnExecutorch/utils/llms/ConversationManager.h @@ -23,7 +23,8 @@ inline constexpr std::string_view END_HEADER_ID_TOKEN = "<|end_header_id|>"; } - (instancetype)initWithNumMessagesContextWindow:(NSUInteger)numMessages - systemPrompt:(NSString *)systemPrompt; + systemPrompt:(NSString *)systemPrompt + messageHistory:(NSArray *)messageHistory; - (void)addResponse:(NSString *)text senderRole:(ChatRole)senderRole; - (NSString *)getConversation; diff --git a/ios/RnExecutorch/utils/llms/ConversationManager.mm b/ios/RnExecutorch/utils/llms/ConversationManager.mm index a241a51533..4473cb4ef2 100644 --- a/ios/RnExecutorch/utils/llms/ConversationManager.mm +++ b/ios/RnExecutorch/utils/llms/ConversationManager.mm @@ -3,7 +3,8 @@ @implementation ConversationManager - (instancetype)initWithNumMessagesContextWindow:(NSUInteger)numMessages - systemPrompt:(NSString *)systemPrompt { + systemPrompt:(NSString *)systemPrompt + messageHistory:(NSArray *)messageHistory { self = [super init]; if (self) { numMessagesContextWindow = numMessages; @@ -12,6 +13,16 @@ - (instancetype)initWithNumMessagesContextWindow:(NSUInteger)numMessages basePrompt += [systemPrompt UTF8String]; basePrompt += std::string(END_OF_TEXT_TOKEN); basePrompt += [self getHeaderTokenFromRole:ChatRole::USER]; + + for (const NSDictionary *elem in messageHistory) { + NSString *role = elem[@"role"]; + NSString *content = elem[@"content"]; + if ([role isEqualToString:@"user"]) { + [self addResponse:content senderRole:ChatRole::USER]; + } else if ([role isEqualToString:@"assistant"]) { + [self addResponse:content senderRole:ChatRole::ASSISTANT]; + } + } } return self; } diff --git a/src/constants/llamaDefaults.ts b/src/constants/llamaDefaults.ts index 2b1202fa33..234b7f9e06 100644 --- a/src/constants/llamaDefaults.ts +++ b/src/constants/llamaDefaults.ts @@ -1,5 +1,9 @@ +import { MessageType } from '../types/common'; + export const DEFAULT_SYSTEM_PROMPT = "You are a knowledgeable, efficient, and direct AI assistant. Provide concise answers, focusing on the key information needed. Offer suggestions tactfully when appropriate to improve outcomes. Engage in productive collaboration with the user. Don't return too much text."; +export const DEFAULT_MESSAGE_HISTORY: MessageType[] = []; + export const DEFAULT_CONTEXT_WINDOW_LENGTH = 3; export const EOT_TOKEN = '<|eot_id|>'; diff --git a/src/hooks/natural_language_processing/useLLM.ts b/src/hooks/natural_language_processing/useLLM.ts index 5864d40bc4..c86a78d1ce 100644 --- a/src/hooks/natural_language_processing/useLLM.ts +++ b/src/hooks/natural_language_processing/useLLM.ts @@ -2,9 +2,10 @@ import { useCallback, useEffect, useRef, useState } from 'react'; import { EventSubscription } from 'react-native'; import { LLM } from '../../native/RnExecutorchModules'; import { fetchResource } from '../../utils/fetchResource'; -import { ResourceSource, Model } from '../../types/common'; +import { ResourceSource, Model, MessageType } from '../../types/common'; import { DEFAULT_CONTEXT_WINDOW_LENGTH, + DEFAULT_MESSAGE_HISTORY, DEFAULT_SYSTEM_PROMPT, EOT_TOKEN, } from '../../constants/llamaDefaults'; @@ -17,11 +18,13 @@ export const useLLM = ({ modelSource, tokenizerSource, systemPrompt = DEFAULT_SYSTEM_PROMPT, + messageHistory = DEFAULT_MESSAGE_HISTORY, contextWindowLength = DEFAULT_CONTEXT_WINDOW_LENGTH, }: { modelSource: ResourceSource; tokenizerSource: ResourceSource; systemPrompt?: string; + messageHistory?: MessageType[]; contextWindowLength?: number; }): Model => { const [error, setError] = useState(null); @@ -46,6 +49,7 @@ export const useLLM = ({ modelFileUri, tokenizerFileUri, systemPrompt, + messageHistory, contextWindowLength ); @@ -79,7 +83,13 @@ export const useLLM = ({ tokenGeneratedListener.current = null; LLM.deleteModule(); }; - }, [contextWindowLength, modelSource, systemPrompt, tokenizerSource]); + }, [ + modelSource, + tokenizerSource, + systemPrompt, + messageHistory, + contextWindowLength, + ]); const generate = useCallback( async (input: string): Promise => { diff --git a/src/modules/natural_language_processing/LLMModule.ts b/src/modules/natural_language_processing/LLMModule.ts index 7b44e8a90f..c9994e4bac 100644 --- a/src/modules/natural_language_processing/LLMModule.ts +++ b/src/modules/natural_language_processing/LLMModule.ts @@ -2,6 +2,7 @@ import { LLM } from '../../native/RnExecutorchModules'; import { fetchResource } from '../../utils/fetchResource'; import { DEFAULT_CONTEXT_WINDOW_LENGTH, + DEFAULT_MESSAGE_HISTORY, DEFAULT_SYSTEM_PROMPT, } from '../../constants/llamaDefaults'; import { ResourceSource } from '../../types/common'; @@ -13,6 +14,7 @@ export class LLMModule { modelSource: ResourceSource, tokenizerSource: ResourceSource, systemPrompt = DEFAULT_SYSTEM_PROMPT, + messageHistory = DEFAULT_MESSAGE_HISTORY, contextWindowLength = DEFAULT_CONTEXT_WINDOW_LENGTH ) { try { @@ -26,6 +28,7 @@ export class LLMModule { modelFileUri, tokenizerFileUri, systemPrompt, + messageHistory, contextWindowLength ); } catch (err) { diff --git a/src/native/NativeLLM.ts b/src/native/NativeLLM.ts index b610e30e39..edd9e59480 100644 --- a/src/native/NativeLLM.ts +++ b/src/native/NativeLLM.ts @@ -1,12 +1,14 @@ import type { TurboModule } from 'react-native'; import { TurboModuleRegistry } from 'react-native'; import type { EventEmitter } from 'react-native/Libraries/Types/CodegenTypes'; +import { MessageType } from '../types/common'; export interface Spec extends TurboModule { loadLLM( modelSource: string, tokenizerSource: string, systemPrompt: string, + messageHistory: MessageType[], contextWindowLength: number ): Promise; runInference(input: string): Promise; diff --git a/src/types/common.ts b/src/types/common.ts index ec0daa2c02..17ba2761fe 100644 --- a/src/types/common.ts +++ b/src/types/common.ts @@ -41,3 +41,8 @@ export type Module = | _StyleTransferModule | _ObjectDetectionModule | typeof ETModule; + +export interface MessageType { + role: 'user' | 'assistant'; + content: string; +} From 40d1859fd08e0d942893c5480be891b81f952671 Mon Sep 17 00:00:00 2001 From: jakmro Date: Sun, 23 Feb 2025 15:08:49 +0100 Subject: [PATCH 2/3] Change MessageType in llama example to be similar to MessageType used in useLLM's messageHistory --- examples/llama/components/MessageItem.tsx | 10 +++++++--- examples/llama/screens/ChatScreen.tsx | 9 +++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/examples/llama/components/MessageItem.tsx b/examples/llama/components/MessageItem.tsx index 9f2c757a0d..5690c6588d 100644 --- a/examples/llama/components/MessageItem.tsx +++ b/examples/llama/components/MessageItem.tsx @@ -11,13 +11,17 @@ interface MessageItemProps { const MessageItem = memo(({ message }: MessageItemProps) => { return ( - - {message.from === 'ai' && ( + + {message.role === 'assistant' && ( )} - + ); }); diff --git a/examples/llama/screens/ChatScreen.tsx b/examples/llama/screens/ChatScreen.tsx index 40d9e70be7..a113ef8abe 100644 --- a/examples/llama/screens/ChatScreen.tsx +++ b/examples/llama/screens/ChatScreen.tsx @@ -32,15 +32,12 @@ export default function ChatScreen() { const textInputRef = useRef(null); useEffect(() => { if (llama.response && !llama.isGenerating) { - appendToMessageHistory(llama.response, 'ai'); + appendToMessageHistory(llama.response, 'assistant'); } }, [llama.response, llama.isGenerating]); - const appendToMessageHistory = (input: string, role: SenderType) => { - setChatHistory((prevHistory) => [ - ...prevHistory, - { text: input, from: role }, - ]); + const appendToMessageHistory = (content: string, role: SenderType) => { + setChatHistory((prevHistory) => [...prevHistory, { role, content }]); }; const sendMessage = async () => { From ed9d21218001e59f0d1470eb6ce3b210cab6441c Mon Sep 17 00:00:00 2001 From: jakmro Date: Mon, 24 Feb 2025 09:18:57 +0100 Subject: [PATCH 3/3] Reformat code --- examples/llama/types.d.ts | 6 +- ios/RnExecutorch/LLM.mm | 4 +- ios/RnExecutorch/models/StyleTransferModel.mm | 22 ++--- .../classification/ClassificationModel.mm | 22 ++--- .../models/classification/Utils.h | 2 +- .../models/classification/Utils.mm | 28 +++--- ios/RnExecutorch/utils/ETError.h | 8 +- ios/RnExecutorch/utils/ImageProcessor.mm | 88 +++++++++++-------- ios/RnExecutorch/utils/llms/Constants.mm | 31 ++++--- .../utils/llms/ConversationManager.h | 12 +-- .../utils/llms/ConversationManager.mm | 29 +++--- 11 files changed, 133 insertions(+), 119 deletions(-) diff --git a/examples/llama/types.d.ts b/examples/llama/types.d.ts index dc3880a96d..7ae077192e 100644 --- a/examples/llama/types.d.ts +++ b/examples/llama/types.d.ts @@ -1,6 +1,6 @@ -export type SenderType = 'user' | 'ai'; +export type SenderType = 'user' | 'assistant'; export interface MessageType { - text: string; - from: SenderType; + role: SenderType; + content: string; } diff --git a/ios/RnExecutorch/LLM.mm b/ios/RnExecutorch/LLM.mm index 34d40c285f..65f3971022 100644 --- a/ios/RnExecutorch/LLM.mm +++ b/ios/RnExecutorch/LLM.mm @@ -43,7 +43,7 @@ - (void)onResult:(NSString *)token prompt:(NSString *)prompt { - (void)loadLLM:(NSString *)modelSource tokenizerSource:(NSString *)tokenizerSource systemPrompt:(NSString *)systemPrompt - messageHistory:(NSArray *)messageHistory + messageHistory:(NSArray *)messageHistory contextWindowLength:(double)contextWindowLength resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject { @@ -57,7 +57,7 @@ - (void)loadLLM:(NSString *)modelSource self->conversationManager = [[ConversationManager alloc] initWithNumMessagesContextWindow:contextWindowLengthUInt systemPrompt:systemPrompt - messageHistory:messageHistory]; + messageHistory:messageHistory]; self->tempLlamaResponse = [NSMutableString string]; resolve(@"Model and tokenizer loaded successfully"); diff --git a/ios/RnExecutorch/models/StyleTransferModel.mm b/ios/RnExecutorch/models/StyleTransferModel.mm index 5fbd53de19..6051e24b1f 100644 --- a/ios/RnExecutorch/models/StyleTransferModel.mm +++ b/ios/RnExecutorch/models/StyleTransferModel.mm @@ -6,35 +6,37 @@ @implementation StyleTransferModel { cv::Size originalSize; } -- (cv::Size)getModelImageSize{ - NSArray * inputShape = [module getInputShape: @0]; +- (cv::Size)getModelImageSize { + NSArray *inputShape = [module getInputShape:@0]; NSNumber *widthNumber = inputShape.lastObject; NSNumber *heightNumber = inputShape[inputShape.count - 2]; - + int height = [heightNumber intValue]; int width = [widthNumber intValue]; - + return cv::Size(height, width); } - (NSArray *)preprocess:(cv::Mat &)input { self->originalSize = cv::Size(input.cols, input.rows); - + cv::Size modelImageSize = [self getModelImageSize]; cv::Mat output; cv::resize(input, output, modelImageSize); - - NSArray *modelInput = [ImageProcessor matToNSArray: output]; + + NSArray *modelInput = [ImageProcessor matToNSArray:output]; return modelInput; } - (cv::Mat)postprocess:(NSArray *)output { cv::Size modelImageSize = [self getModelImageSize]; - cv::Mat processedImage = [ImageProcessor arrayToMat: output width:modelImageSize.width height:modelImageSize.height]; + cv::Mat processedImage = [ImageProcessor arrayToMat:output + width:modelImageSize.width + height:modelImageSize.height]; cv::Mat processedOutput; cv::resize(processedImage, processedOutput, originalSize); - + return processedOutput; } @@ -42,7 +44,7 @@ - (NSArray *)preprocess:(cv::Mat &)input { NSArray *modelInput = [self preprocess:input]; NSArray *result = [self forward:modelInput]; input = [self postprocess:result[0]]; - + return input; } diff --git a/ios/RnExecutorch/models/classification/ClassificationModel.mm b/ios/RnExecutorch/models/classification/ClassificationModel.mm index f1b3f947cb..8e7973e265 100644 --- a/ios/RnExecutorch/models/classification/ClassificationModel.mm +++ b/ios/RnExecutorch/models/classification/ClassificationModel.mm @@ -1,27 +1,27 @@ #import "ClassificationModel.h" -#import "opencv2/opencv.hpp" -#import "Utils.h" -#import "Constants.h" #import "../../utils/ImageProcessor.h" +#import "Constants.h" +#import "Utils.h" +#import "opencv2/opencv.hpp" @implementation ClassificationModel - (cv::Size)getModelImageSize { - NSArray * inputShape = [module getInputShape: 0]; + NSArray *inputShape = [module getInputShape:0]; NSNumber *widthNumber = inputShape.lastObject; NSNumber *heightNumber = inputShape[inputShape.count - 2]; - + int height = [heightNumber intValue]; int width = [widthNumber intValue]; - + return cv::Size(height, width); } - (NSArray *)preprocess:(cv::Mat &)input { cv::Size modelImageSize = [self getModelImageSize]; cv::resize(input, input, modelImageSize); - - NSArray *modelInput = [ImageProcessor matToNSArray: input]; + + NSArray *modelInput = [ImageProcessor matToNSArray:input]; return modelInput; } @@ -32,16 +32,16 @@ - (NSDictionary *)postprocess:(NSArray *)output { for (NSUInteger i = 0; i < output.count; ++i) { outputVector[i] = [output[i] doubleValue]; } - + std::vector probabilities = softmax(outputVector); NSMutableDictionary *result = [NSMutableDictionary dictionary]; - + for (int i = 0; i < probabilities.size(); ++i) { NSString *className = @(imagenet1k_v1_labels[i].c_str()); NSNumber *probability = @(probabilities[i]); result[className] = probability; } - + return result; } diff --git a/ios/RnExecutorch/models/classification/Utils.h b/ios/RnExecutorch/models/classification/Utils.h index 5785a5c4be..102a2bc546 100644 --- a/ios/RnExecutorch/models/classification/Utils.h +++ b/ios/RnExecutorch/models/classification/Utils.h @@ -1,3 +1,3 @@ #include -std::vector softmax(const std::vector& v); \ No newline at end of file +std::vector softmax(const std::vector &v); diff --git a/ios/RnExecutorch/models/classification/Utils.mm b/ios/RnExecutorch/models/classification/Utils.mm index da613019d0..84e4bd6295 100644 --- a/ios/RnExecutorch/models/classification/Utils.mm +++ b/ios/RnExecutorch/models/classification/Utils.mm @@ -1,20 +1,20 @@ #include "Utils.h" -#include #include +#include -std::vector softmax(const std::vector& v) { - std::vector result(v.size()); - double maxVal = *std::max_element(v.begin(), v.end()); +std::vector softmax(const std::vector &v) { + std::vector result(v.size()); + double maxVal = *std::max_element(v.begin(), v.end()); - double sumExp = 0.0; - for (size_t i = 0; i < v.size(); ++i) { - result[i] = std::exp(v[i] - maxVal); - sumExp += result[i]; - } + double sumExp = 0.0; + for (size_t i = 0; i < v.size(); ++i) { + result[i] = std::exp(v[i] - maxVal); + sumExp += result[i]; + } - for (size_t i = 0; i < v.size(); ++i) { - result[i] /= sumExp; - } + for (size_t i = 0; i < v.size(); ++i) { + result[i] /= sumExp; + } - return result; -} \ No newline at end of file + return result; +} diff --git a/ios/RnExecutorch/utils/ETError.h b/ios/RnExecutorch/utils/ETError.h index f1394a011c..2b95d0f143 100644 --- a/ios/RnExecutorch/utils/ETError.h +++ b/ios/RnExecutorch/utils/ETError.h @@ -3,23 +3,23 @@ typedef NS_ENUM(NSUInteger, ETError) { ModuleNotLoaded = 0x66, FileWriteFailed = 0x67, InvalidModelSource = 0xff, - + Ok = 0x00, Internal = 0x01, InvalidState = 0x02, EndOfMethod = 0x03, - + NotSupported = 0x10, NotImplemented = 0x11, InvalidArgument = 0x12, InvalidType = 0x13, OperatorMissing = 0x14, - + NotFound = 0x20, MemoryAllocationFailed = 0x21, AccessFailed = 0x22, InvalidProgram = 0x23, - + DelegateInvalidCompatibility = 0x30, DelegateMemoryAllocationFailed = 0x31, DelegateInvalidHandle = 0x32 diff --git a/ios/RnExecutorch/utils/ImageProcessor.mm b/ios/RnExecutorch/utils/ImageProcessor.mm index feab17f608..a0f349ebf3 100644 --- a/ios/RnExecutorch/utils/ImageProcessor.mm +++ b/ios/RnExecutorch/utils/ImageProcessor.mm @@ -5,11 +5,12 @@ @implementation ImageProcessor + (NSArray *)matToNSArray:(const cv::Mat &)mat { int pixelCount = mat.cols * mat.rows; - NSMutableArray *floatArray = [[NSMutableArray alloc] initWithCapacity:pixelCount * 3]; + NSMutableArray *floatArray = + [[NSMutableArray alloc] initWithCapacity:pixelCount * 3]; for (NSUInteger k = 0; k < pixelCount * 3; k++) { [floatArray addObject:@0.0]; } - + for (int i = 0; i < pixelCount; i++) { int row = i / mat.cols; int col = i % mat.cols; @@ -18,78 +19,89 @@ + (NSArray *)matToNSArray:(const cv::Mat &)mat { floatArray[1 * pixelCount + i] = @(pixel[1] / 255.0f); floatArray[2 * pixelCount + i] = @(pixel[0] / 255.0f); } - + return floatArray; } + (cv::Mat)arrayToMat:(NSArray *)array width:(int)width height:(int)height { cv::Mat mat(height, width, CV_8UC3); - + int pixelCount = width * height; for (int i = 0; i < pixelCount; i++) { int row = i / width; int col = i % width; float r = 0, g = 0, b = 0; - - r = [[array objectAtIndex: 0 * pixelCount + i] floatValue]; - g = [[array objectAtIndex: 1 * pixelCount + i] floatValue]; - b = [[array objectAtIndex: 2 * pixelCount + i] floatValue]; - + + r = [[array objectAtIndex:0 * pixelCount + i] floatValue]; + g = [[array objectAtIndex:1 * pixelCount + i] floatValue]; + b = [[array objectAtIndex:2 * pixelCount + i] floatValue]; + cv::Vec3b color((uchar)(b * 255), (uchar)(g * 255), (uchar)(r * 255)); mat.at(row, col) = color; } - + return mat; } -+ (NSString *)saveToTempFile:(const cv::Mat&)image { ++ (NSString *)saveToTempFile:(const cv::Mat &)image { NSString *uniqueID = [[NSUUID UUID] UUIDString]; - NSString *filename = [NSString stringWithFormat:@"rn_executorch_%@.png", uniqueID]; - NSString *outputPath = [NSTemporaryDirectory() stringByAppendingPathComponent:filename]; - + NSString *filename = + [NSString stringWithFormat:@"rn_executorch_%@.png", uniqueID]; + NSString *outputPath = + [NSTemporaryDirectory() stringByAppendingPathComponent:filename]; + std::string filePath = [outputPath UTF8String]; if (!cv::imwrite(filePath, image)) { - @throw [NSException exceptionWithName:@"ImageSaveException" - reason:[NSString stringWithFormat:@"%ld", (long)FileWriteFailed] - userInfo:nil]; + @throw [NSException + exceptionWithName:@"ImageSaveException" + reason:[NSString + stringWithFormat:@"%ld", (long)FileWriteFailed] + userInfo:nil]; } - + return [NSString stringWithFormat:@"file://%@", outputPath]; } + (cv::Mat)readImage:(NSString *)source { NSURL *url = [NSURL URLWithString:source]; - + cv::Mat inputImage; - if([[url scheme] isEqualToString: @"data"]){ - //base64 + if ([[url scheme] isEqualToString:@"data"]) { + // base64 NSArray *parts = [source componentsSeparatedByString:@","]; if ([parts count] < 2) { - @throw [NSException exceptionWithName:@"readImage_error" - reason:[NSString stringWithFormat:@"%ld", (long)InvalidArgument] - userInfo:nil]; + @throw [NSException + exceptionWithName:@"readImage_error" + reason:[NSString + stringWithFormat:@"%ld", (long)InvalidArgument] + userInfo:nil]; } NSString *encodedString = parts[1]; - NSData *data = [[NSData alloc] initWithBase64EncodedString:encodedString options:NSDataBase64DecodingIgnoreUnknownCharacters]; + NSData *data = [[NSData alloc] + initWithBase64EncodedString:encodedString + options: + NSDataBase64DecodingIgnoreUnknownCharacters]; cv::Mat encodedData(1, [data length], CV_8UC1, (void *)data.bytes); inputImage = cv::imdecode(encodedData, cv::IMREAD_COLOR); - } - else if([[url scheme] isEqualToString: @"file"]){ - //local file + } else if ([[url scheme] isEqualToString:@"file"]) { + // local file inputImage = cv::imread([[url path] UTF8String], cv::IMREAD_COLOR); - } - else { - //external file + } else { + // external file NSData *data = [NSData dataWithContentsOfURL:url]; - inputImage = cv::imdecode(cv::Mat(1, [data length], CV_8UC1, (void*)data.bytes), cv::IMREAD_COLOR); + inputImage = + cv::imdecode(cv::Mat(1, [data length], CV_8UC1, (void *)data.bytes), + cv::IMREAD_COLOR); } - - if(inputImage.empty()){ - @throw [NSException exceptionWithName:@"readImage_error" - reason:[NSString stringWithFormat:@"%ld", (long)InvalidArgument] - userInfo:nil]; + + if (inputImage.empty()) { + @throw [NSException + exceptionWithName:@"readImage_error" + reason:[NSString + stringWithFormat:@"%ld", (long)InvalidArgument] + userInfo:nil]; } - + return inputImage; } diff --git a/ios/RnExecutorch/utils/llms/Constants.mm b/ios/RnExecutorch/utils/llms/Constants.mm index 2a30b02b4b..7ca223759c 100644 --- a/ios/RnExecutorch/utils/llms/Constants.mm +++ b/ios/RnExecutorch/utils/llms/Constants.mm @@ -1,20 +1,23 @@ #import "Constants.h" -#import #import "ConversationManager.h" +#import -NSString * const END_OF_TEXT_TOKEN_NS = [[NSString alloc] initWithBytes:END_OF_TEXT_TOKEN.data() - length:END_OF_TEXT_TOKEN.size() - encoding:NSUTF8StringEncoding]; - -NSString * const BEGIN_OF_TEXT_TOKEN_NS = [[NSString alloc] initWithBytes:BEGIN_OF_TEXT_TOKEN.data() - length:BEGIN_OF_TEXT_TOKEN.size() - encoding:NSUTF8StringEncoding]; +NSString *const END_OF_TEXT_TOKEN_NS = + [[NSString alloc] initWithBytes:END_OF_TEXT_TOKEN.data() + length:END_OF_TEXT_TOKEN.size() + encoding:NSUTF8StringEncoding]; -NSString * const START_HEADER_ID_TOKEN_NS = [[NSString alloc] initWithBytes:START_HEADER_ID_TOKEN.data() - length:START_HEADER_ID_TOKEN.size() - encoding:NSUTF8StringEncoding]; +NSString *const BEGIN_OF_TEXT_TOKEN_NS = + [[NSString alloc] initWithBytes:BEGIN_OF_TEXT_TOKEN.data() + length:BEGIN_OF_TEXT_TOKEN.size() + encoding:NSUTF8StringEncoding]; +NSString *const START_HEADER_ID_TOKEN_NS = + [[NSString alloc] initWithBytes:START_HEADER_ID_TOKEN.data() + length:START_HEADER_ID_TOKEN.size() + encoding:NSUTF8StringEncoding]; -NSString * const END_HEADER_ID_TOKEN_NS = [[NSString alloc] initWithBytes:END_HEADER_ID_TOKEN.data() - length:END_HEADER_ID_TOKEN.size() - encoding:NSUTF8StringEncoding]; +NSString *const END_HEADER_ID_TOKEN_NS = + [[NSString alloc] initWithBytes:END_HEADER_ID_TOKEN.data() + length:END_HEADER_ID_TOKEN.size() + encoding:NSUTF8StringEncoding]; diff --git a/ios/RnExecutorch/utils/llms/ConversationManager.h b/ios/RnExecutorch/utils/llms/ConversationManager.h index 0571eb5db6..f3d3a95e1c 100644 --- a/ios/RnExecutorch/utils/llms/ConversationManager.h +++ b/ios/RnExecutorch/utils/llms/ConversationManager.h @@ -1,22 +1,16 @@ #import +#import #import #import -#import -enum class ChatRole -{ - SYSTEM, - USER, - ASSISTANT -}; +enum class ChatRole { SYSTEM, USER, ASSISTANT }; inline constexpr std::string_view BEGIN_OF_TEXT_TOKEN = "<|begin_of_text|>"; inline constexpr std::string_view END_OF_TEXT_TOKEN = "<|eot_id|>"; inline constexpr std::string_view START_HEADER_ID_TOKEN = "<|start_header_id|>"; inline constexpr std::string_view END_HEADER_ID_TOKEN = "<|end_header_id|>"; -@interface ConversationManager : NSObject -{ +@interface ConversationManager : NSObject { NSUInteger numMessagesContextWindow; std::string basePrompt; std::deque messages; diff --git a/ios/RnExecutorch/utils/llms/ConversationManager.mm b/ios/RnExecutorch/utils/llms/ConversationManager.mm index 4473cb4ef2..758c5a8198 100644 --- a/ios/RnExecutorch/utils/llms/ConversationManager.mm +++ b/ios/RnExecutorch/utils/llms/ConversationManager.mm @@ -14,9 +14,9 @@ - (instancetype)initWithNumMessagesContextWindow:(NSUInteger)numMessages basePrompt += std::string(END_OF_TEXT_TOKEN); basePrompt += [self getHeaderTokenFromRole:ChatRole::USER]; - for (const NSDictionary *elem in messageHistory) { - NSString *role = elem[@"role"]; - NSString *content = elem[@"content"]; + for (const NSDictionary *message in messageHistory) { + NSString *role = message[@"role"]; + NSString *content = message[@"content"]; if ([role isEqualToString:@"user"]) { [self addResponse:content senderRole:ChatRole::USER]; } else if ([role isEqualToString:@"assistant"]) { @@ -31,7 +31,7 @@ - (void)addResponse:(NSString *)text senderRole:(ChatRole)senderRole { if (messages.size() >= numMessagesContextWindow) { messages.pop_front(); } - + std::string formattedMessage; if (senderRole == ChatRole::ASSISTANT) { formattedMessage = [text UTF8String]; @@ -46,7 +46,7 @@ - (void)addResponse:(NSString *)text senderRole:(ChatRole)senderRole { - (NSString *)getConversation { std::string prompt = basePrompt; - for (const auto& elem : messages) { + for (const auto &elem : messages) { prompt += elem; } return [NSString stringWithUTF8String:prompt.c_str()]; @@ -54,14 +54,17 @@ - (NSString *)getConversation { - (std::string)getHeaderTokenFromRole:(ChatRole)role { switch (role) { - case ChatRole::SYSTEM: - return std::string(START_HEADER_ID_TOKEN) + "system" + std::string(END_HEADER_ID_TOKEN); - case ChatRole::USER: - return std::string(START_HEADER_ID_TOKEN) + "user" + std::string(END_HEADER_ID_TOKEN); - case ChatRole::ASSISTANT: - return std::string(START_HEADER_ID_TOKEN) + "assistant" + std::string(END_HEADER_ID_TOKEN); - default: - return ""; + case ChatRole::SYSTEM: + return std::string(START_HEADER_ID_TOKEN) + "system" + + std::string(END_HEADER_ID_TOKEN); + case ChatRole::USER: + return std::string(START_HEADER_ID_TOKEN) + "user" + + std::string(END_HEADER_ID_TOKEN); + case ChatRole::ASSISTANT: + return std::string(START_HEADER_ID_TOKEN) + "assistant" + + std::string(END_HEADER_ID_TOKEN); + default: + return ""; } }