diff --git a/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt b/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt index 4ca13ca51c..1626658e4a 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/LLM.kt @@ -3,15 +3,16 @@ package com.swmansion.rnexecutorch import android.util.Log import com.facebook.react.bridge.Promise import com.facebook.react.bridge.ReactApplicationContext +import com.facebook.react.bridge.ReadableArray import com.swmansion.rnexecutorch.utils.llms.ChatRole import com.swmansion.rnexecutorch.utils.llms.ConversationManager import com.swmansion.rnexecutorch.utils.llms.END_OF_TEXT_TOKEN import org.pytorch.executorch.LlamaCallback import org.pytorch.executorch.LlamaModule +import com.swmansion.rnexecutorch.utils.ArrayUtils import java.net.URL -class LLM(reactContext: ReactApplicationContext) : - NativeLLMSpec(reactContext), LlamaCallback { +class LLM(reactContext: ReactApplicationContext) : NativeLLMSpec(reactContext), LlamaCallback { private var llamaModule: LlamaModule? = null private var tempLlamaResponse = StringBuilder() @@ -38,11 +39,14 @@ class LLM(reactContext: ReactApplicationContext) : modelSource: String, tokenizerSource: String, systemPrompt: String, + messageHistory: ReadableArray, contextWindowLength: Double, promise: Promise ) { try { - this.conversationManager = ConversationManager(contextWindowLength.toInt(), systemPrompt) + this.conversationManager = ConversationManager( + contextWindowLength.toInt(), systemPrompt, ArrayUtils.createMapArray(messageHistory) + ) llamaModule = LlamaModule(1, URL(modelSource).path, URL(tokenizerSource).path, 0.7f) this.tempLlamaResponse.clear() promise.resolve("Model loaded successfully") @@ -51,10 +55,7 @@ class LLM(reactContext: ReactApplicationContext) : } } - override fun runInference( - input: String, - promise: Promise - ) { + override fun runInference(input: String, promise: Promise) { this.conversationManager.addResponse(input, ChatRole.USER) val conversation = this.conversationManager.getConversation() diff --git a/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt index 3dcbbed45d..ef070bcabf 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/utils/ArrayUtils.kt @@ -3,6 +3,7 @@ package com.swmansion.rnexecutorch.utils import android.util.Log import com.facebook.react.bridge.Arguments import com.facebook.react.bridge.ReadableArray +import com.facebook.react.bridge.ReadableMap import org.pytorch.executorch.DType import org.pytorch.executorch.Tensor @@ -35,6 +36,15 @@ class ArrayUtils { fun createDoubleArray(input: ReadableArray): DoubleArray { return createTypedArrayFromReadableArray(input) { array, index -> array.getDouble(index) }.toDoubleArray() } + + fun createMapArray(input: ReadableArray): Array> { + val mapArray = Array>(input.size()) { mapOf() } + for (i in 0 until input.size()) { + mapArray[i] = input.getMap(i).toHashMap() as Map + } + return mapArray + } + fun createReadableArrayFromTensor(result: Tensor): ReadableArray { val resultArray = Arguments.createArray() diff --git a/android/src/main/java/com/swmansion/rnexecutorch/utils/llms/ConversationManager.kt b/android/src/main/java/com/swmansion/rnexecutorch/utils/llms/ConversationManager.kt index 2f59353090..78654b295a 100644 --- a/android/src/main/java/com/swmansion/rnexecutorch/utils/llms/ConversationManager.kt +++ b/android/src/main/java/com/swmansion/rnexecutorch/utils/llms/ConversationManager.kt @@ -11,7 +11,11 @@ const val END_OF_TEXT_TOKEN = "<|eot_id|>" const val START_HEADER_ID_TOKEN = "<|start_header_id|>" const val END_HEADER_ID_TOKEN = "<|end_header_id|>" -class ConversationManager(private val numMessagesContextWindow: Int, systemPrompt: String) { +class ConversationManager( + private val numMessagesContextWindow: Int, + systemPrompt: String, + messageHistory: Array> +) { private val basePrompt: String; private val messages = ArrayDeque(); @@ -22,6 +26,13 @@ class ConversationManager(private val numMessagesContextWindow: Int, systemPromp systemPrompt + END_OF_TEXT_TOKEN + getHeaderTokenFromRole(ChatRole.USER) + + messageHistory.forEach { message -> + when (message["role"]) { + "user" -> addResponse(message["content"]!!, ChatRole.USER) + "assistant" -> addResponse(message["content"]!!, ChatRole.ASSISTANT) + } + } } fun addResponse(text: String, senderRole: ChatRole) { @@ -31,9 +42,7 @@ class ConversationManager(private val numMessagesContextWindow: Int, systemPromp val formattedMessage: String = if (senderRole == ChatRole.ASSISTANT) { text + getHeaderTokenFromRole(ChatRole.USER) } else { - text + - END_OF_TEXT_TOKEN + - getHeaderTokenFromRole(ChatRole.ASSISTANT) + text + END_OF_TEXT_TOKEN + getHeaderTokenFromRole(ChatRole.ASSISTANT) } this.messages.add(formattedMessage) } diff --git a/examples/llama/components/MessageItem.tsx b/examples/llama/components/MessageItem.tsx index 9f2c757a0d..5690c6588d 100644 --- a/examples/llama/components/MessageItem.tsx +++ b/examples/llama/components/MessageItem.tsx @@ -11,13 +11,17 @@ interface MessageItemProps { const MessageItem = memo(({ message }: MessageItemProps) => { return ( - - {message.from === 'ai' && ( + + {message.role === 'assistant' && ( )} - + ); }); diff --git a/examples/llama/screens/ChatScreen.tsx b/examples/llama/screens/ChatScreen.tsx index 40d9e70be7..a113ef8abe 100644 --- a/examples/llama/screens/ChatScreen.tsx +++ b/examples/llama/screens/ChatScreen.tsx @@ -32,15 +32,12 @@ export default function ChatScreen() { const textInputRef = useRef(null); useEffect(() => { if (llama.response && !llama.isGenerating) { - appendToMessageHistory(llama.response, 'ai'); + appendToMessageHistory(llama.response, 'assistant'); } }, [llama.response, llama.isGenerating]); - const appendToMessageHistory = (input: string, role: SenderType) => { - setChatHistory((prevHistory) => [ - ...prevHistory, - { text: input, from: role }, - ]); + const appendToMessageHistory = (content: string, role: SenderType) => { + setChatHistory((prevHistory) => [...prevHistory, { role, content }]); }; const sendMessage = async () => { diff --git a/examples/llama/types.d.ts b/examples/llama/types.d.ts index dc3880a96d..7ae077192e 100644 --- a/examples/llama/types.d.ts +++ b/examples/llama/types.d.ts @@ -1,6 +1,6 @@ -export type SenderType = 'user' | 'ai'; +export type SenderType = 'user' | 'assistant'; export interface MessageType { - text: string; - from: SenderType; + role: SenderType; + content: string; } diff --git a/ios/RnExecutorch/LLM.mm b/ios/RnExecutorch/LLM.mm index 462c114279..65f3971022 100644 --- a/ios/RnExecutorch/LLM.mm +++ b/ios/RnExecutorch/LLM.mm @@ -43,6 +43,7 @@ - (void)onResult:(NSString *)token prompt:(NSString *)prompt { - (void)loadLLM:(NSString *)modelSource tokenizerSource:(NSString *)tokenizerSource systemPrompt:(NSString *)systemPrompt + messageHistory:(NSArray *)messageHistory contextWindowLength:(double)contextWindowLength resolve:(RCTPromiseResolveBlock)resolve reject:(RCTPromiseRejectBlock)reject { @@ -55,7 +56,8 @@ - (void)loadLLM:(NSString *)modelSource self->conversationManager = [[ConversationManager alloc] initWithNumMessagesContextWindow:contextWindowLengthUInt - systemPrompt:systemPrompt]; + systemPrompt:systemPrompt + messageHistory:messageHistory]; self->tempLlamaResponse = [NSMutableString string]; resolve(@"Model and tokenizer loaded successfully"); diff --git a/ios/RnExecutorch/models/StyleTransferModel.mm b/ios/RnExecutorch/models/StyleTransferModel.mm index 5fbd53de19..6051e24b1f 100644 --- a/ios/RnExecutorch/models/StyleTransferModel.mm +++ b/ios/RnExecutorch/models/StyleTransferModel.mm @@ -6,35 +6,37 @@ @implementation StyleTransferModel { cv::Size originalSize; } -- (cv::Size)getModelImageSize{ - NSArray * inputShape = [module getInputShape: @0]; +- (cv::Size)getModelImageSize { + NSArray *inputShape = [module getInputShape:@0]; NSNumber *widthNumber = inputShape.lastObject; NSNumber *heightNumber = inputShape[inputShape.count - 2]; - + int height = [heightNumber intValue]; int width = [widthNumber intValue]; - + return cv::Size(height, width); } - (NSArray *)preprocess:(cv::Mat &)input { self->originalSize = cv::Size(input.cols, input.rows); - + cv::Size modelImageSize = [self getModelImageSize]; cv::Mat output; cv::resize(input, output, modelImageSize); - - NSArray *modelInput = [ImageProcessor matToNSArray: output]; + + NSArray *modelInput = [ImageProcessor matToNSArray:output]; return modelInput; } - (cv::Mat)postprocess:(NSArray *)output { cv::Size modelImageSize = [self getModelImageSize]; - cv::Mat processedImage = [ImageProcessor arrayToMat: output width:modelImageSize.width height:modelImageSize.height]; + cv::Mat processedImage = [ImageProcessor arrayToMat:output + width:modelImageSize.width + height:modelImageSize.height]; cv::Mat processedOutput; cv::resize(processedImage, processedOutput, originalSize); - + return processedOutput; } @@ -42,7 +44,7 @@ - (NSArray *)preprocess:(cv::Mat &)input { NSArray *modelInput = [self preprocess:input]; NSArray *result = [self forward:modelInput]; input = [self postprocess:result[0]]; - + return input; } diff --git a/ios/RnExecutorch/models/classification/ClassificationModel.mm b/ios/RnExecutorch/models/classification/ClassificationModel.mm index f1b3f947cb..8e7973e265 100644 --- a/ios/RnExecutorch/models/classification/ClassificationModel.mm +++ b/ios/RnExecutorch/models/classification/ClassificationModel.mm @@ -1,27 +1,27 @@ #import "ClassificationModel.h" -#import "opencv2/opencv.hpp" -#import "Utils.h" -#import "Constants.h" #import "../../utils/ImageProcessor.h" +#import "Constants.h" +#import "Utils.h" +#import "opencv2/opencv.hpp" @implementation ClassificationModel - (cv::Size)getModelImageSize { - NSArray * inputShape = [module getInputShape: 0]; + NSArray *inputShape = [module getInputShape:0]; NSNumber *widthNumber = inputShape.lastObject; NSNumber *heightNumber = inputShape[inputShape.count - 2]; - + int height = [heightNumber intValue]; int width = [widthNumber intValue]; - + return cv::Size(height, width); } - (NSArray *)preprocess:(cv::Mat &)input { cv::Size modelImageSize = [self getModelImageSize]; cv::resize(input, input, modelImageSize); - - NSArray *modelInput = [ImageProcessor matToNSArray: input]; + + NSArray *modelInput = [ImageProcessor matToNSArray:input]; return modelInput; } @@ -32,16 +32,16 @@ - (NSDictionary *)postprocess:(NSArray *)output { for (NSUInteger i = 0; i < output.count; ++i) { outputVector[i] = [output[i] doubleValue]; } - + std::vector probabilities = softmax(outputVector); NSMutableDictionary *result = [NSMutableDictionary dictionary]; - + for (int i = 0; i < probabilities.size(); ++i) { NSString *className = @(imagenet1k_v1_labels[i].c_str()); NSNumber *probability = @(probabilities[i]); result[className] = probability; } - + return result; } diff --git a/ios/RnExecutorch/models/classification/Utils.h b/ios/RnExecutorch/models/classification/Utils.h index 5785a5c4be..102a2bc546 100644 --- a/ios/RnExecutorch/models/classification/Utils.h +++ b/ios/RnExecutorch/models/classification/Utils.h @@ -1,3 +1,3 @@ #include -std::vector softmax(const std::vector& v); \ No newline at end of file +std::vector softmax(const std::vector &v); diff --git a/ios/RnExecutorch/models/classification/Utils.mm b/ios/RnExecutorch/models/classification/Utils.mm index da613019d0..84e4bd6295 100644 --- a/ios/RnExecutorch/models/classification/Utils.mm +++ b/ios/RnExecutorch/models/classification/Utils.mm @@ -1,20 +1,20 @@ #include "Utils.h" -#include #include +#include -std::vector softmax(const std::vector& v) { - std::vector result(v.size()); - double maxVal = *std::max_element(v.begin(), v.end()); +std::vector softmax(const std::vector &v) { + std::vector result(v.size()); + double maxVal = *std::max_element(v.begin(), v.end()); - double sumExp = 0.0; - for (size_t i = 0; i < v.size(); ++i) { - result[i] = std::exp(v[i] - maxVal); - sumExp += result[i]; - } + double sumExp = 0.0; + for (size_t i = 0; i < v.size(); ++i) { + result[i] = std::exp(v[i] - maxVal); + sumExp += result[i]; + } - for (size_t i = 0; i < v.size(); ++i) { - result[i] /= sumExp; - } + for (size_t i = 0; i < v.size(); ++i) { + result[i] /= sumExp; + } - return result; -} \ No newline at end of file + return result; +} diff --git a/ios/RnExecutorch/utils/ETError.h b/ios/RnExecutorch/utils/ETError.h index f1394a011c..2b95d0f143 100644 --- a/ios/RnExecutorch/utils/ETError.h +++ b/ios/RnExecutorch/utils/ETError.h @@ -3,23 +3,23 @@ typedef NS_ENUM(NSUInteger, ETError) { ModuleNotLoaded = 0x66, FileWriteFailed = 0x67, InvalidModelSource = 0xff, - + Ok = 0x00, Internal = 0x01, InvalidState = 0x02, EndOfMethod = 0x03, - + NotSupported = 0x10, NotImplemented = 0x11, InvalidArgument = 0x12, InvalidType = 0x13, OperatorMissing = 0x14, - + NotFound = 0x20, MemoryAllocationFailed = 0x21, AccessFailed = 0x22, InvalidProgram = 0x23, - + DelegateInvalidCompatibility = 0x30, DelegateMemoryAllocationFailed = 0x31, DelegateInvalidHandle = 0x32 diff --git a/ios/RnExecutorch/utils/ImageProcessor.mm b/ios/RnExecutorch/utils/ImageProcessor.mm index a8617c262f..be38974620 100644 --- a/ios/RnExecutorch/utils/ImageProcessor.mm +++ b/ios/RnExecutorch/utils/ImageProcessor.mm @@ -11,11 +11,12 @@ + (NSArray *)matToNSArray:(const cv::Mat &)mat mean:(cv::Scalar)mean variance:(cv::Scalar)variance { int pixelCount = mat.cols * mat.rows; - NSMutableArray *floatArray = [[NSMutableArray alloc] initWithCapacity:pixelCount * 3]; + NSMutableArray *floatArray = + [[NSMutableArray alloc] initWithCapacity:pixelCount * 3]; for (NSUInteger k = 0; k < pixelCount * 3; k++) { [floatArray addObject:@0.0]; } - + for (int i = 0; i < pixelCount; i++) { int row = i / mat.cols; int col = i % mat.cols; @@ -24,7 +25,7 @@ + (NSArray *)matToNSArray:(const cv::Mat &)mat floatArray[1 * pixelCount + i] = @((pixel[1] - mean[1] * 255.0) / (variance[1] * 255.0)); floatArray[2 * pixelCount + i] = @((pixel[2] - mean[2] * 255.0) / (variance[2] * 255.0)); } - + return floatArray; } @@ -43,21 +44,21 @@ + (NSArray *)matToNSArrayGray:(const cv::Mat &)mat { + (cv::Mat)arrayToMat:(NSArray *)array width:(int)width height:(int)height { cv::Mat mat(height, width, CV_8UC3); - + int pixelCount = width * height; for (int i = 0; i < pixelCount; i++) { int row = i / width; int col = i % width; float r = 0, g = 0, b = 0; - - r = [[array objectAtIndex: 0 * pixelCount + i] floatValue]; - g = [[array objectAtIndex: 1 * pixelCount + i] floatValue]; - b = [[array objectAtIndex: 2 * pixelCount + i] floatValue]; - + + r = [[array objectAtIndex:0 * pixelCount + i] floatValue]; + g = [[array objectAtIndex:1 * pixelCount + i] floatValue]; + b = [[array objectAtIndex:2 * pixelCount + i] floatValue]; + cv::Vec3b color((uchar)(b * 255), (uchar)(g * 255), (uchar)(r * 255)); mat.at(row, col) = color; } - + return mat; } @@ -75,54 +76,65 @@ + (NSArray *)matToNSArrayGray:(const cv::Mat &)mat { return mat; } -+ (NSString *)saveToTempFile:(const cv::Mat&)image { ++ (NSString *)saveToTempFile:(const cv::Mat &)image { NSString *uniqueID = [[NSUUID UUID] UUIDString]; - NSString *filename = [NSString stringWithFormat:@"rn_executorch_%@.png", uniqueID]; - NSString *outputPath = [NSTemporaryDirectory() stringByAppendingPathComponent:filename]; - + NSString *filename = + [NSString stringWithFormat:@"rn_executorch_%@.png", uniqueID]; + NSString *outputPath = + [NSTemporaryDirectory() stringByAppendingPathComponent:filename]; + std::string filePath = [outputPath UTF8String]; if (!cv::imwrite(filePath, image)) { - @throw [NSException exceptionWithName:@"ImageSaveException" - reason:[NSString stringWithFormat:@"%ld", (long)FileWriteFailed] - userInfo:nil]; + @throw [NSException + exceptionWithName:@"ImageSaveException" + reason:[NSString + stringWithFormat:@"%ld", (long)FileWriteFailed] + userInfo:nil]; } - + return [NSString stringWithFormat:@"file://%@", outputPath]; } + (cv::Mat)readImage:(NSString *)source { NSURL *url = [NSURL URLWithString:source]; - + cv::Mat inputImage; - if([[url scheme] isEqualToString: @"data"]){ - //base64 + if ([[url scheme] isEqualToString:@"data"]) { + // base64 NSArray *parts = [source componentsSeparatedByString:@","]; if ([parts count] < 2) { - @throw [NSException exceptionWithName:@"readImage_error" - reason:[NSString stringWithFormat:@"%ld", (long)InvalidArgument] - userInfo:nil]; + @throw [NSException + exceptionWithName:@"readImage_error" + reason:[NSString + stringWithFormat:@"%ld", (long)InvalidArgument] + userInfo:nil]; } NSString *encodedString = parts[1]; - NSData *data = [[NSData alloc] initWithBase64EncodedString:encodedString options:NSDataBase64DecodingIgnoreUnknownCharacters]; + NSData *data = [[NSData alloc] + initWithBase64EncodedString:encodedString + options: + NSDataBase64DecodingIgnoreUnknownCharacters]; cv::Mat encodedData(1, [data length], CV_8UC1, (void *)data.bytes); inputImage = cv::imdecode(encodedData, cv::IMREAD_COLOR); - } - else if([[url scheme] isEqualToString: @"file"]){ - //local file + } else if ([[url scheme] isEqualToString:@"file"]) { + // local file inputImage = cv::imread([[url path] UTF8String], cv::IMREAD_COLOR); - } - else { - //external file + } else { + // external file NSData *data = [NSData dataWithContentsOfURL:url]; - inputImage = cv::imdecode(cv::Mat(1, [data length], CV_8UC1, (void*)data.bytes), cv::IMREAD_COLOR); + inputImage = + cv::imdecode(cv::Mat(1, [data length], CV_8UC1, (void *)data.bytes), + cv::IMREAD_COLOR); } - - if(inputImage.empty()){ - @throw [NSException exceptionWithName:@"readImage_error" - reason:[NSString stringWithFormat:@"%ld", (long)InvalidArgument] - userInfo:nil]; + + if (inputImage.empty()) { + @throw [NSException + exceptionWithName:@"readImage_error" + reason:[NSString + stringWithFormat:@"%ld", (long)InvalidArgument] + userInfo:nil]; } - + return inputImage; } diff --git a/ios/RnExecutorch/utils/llms/Constants.mm b/ios/RnExecutorch/utils/llms/Constants.mm index 2a30b02b4b..7ca223759c 100644 --- a/ios/RnExecutorch/utils/llms/Constants.mm +++ b/ios/RnExecutorch/utils/llms/Constants.mm @@ -1,20 +1,23 @@ #import "Constants.h" -#import #import "ConversationManager.h" +#import -NSString * const END_OF_TEXT_TOKEN_NS = [[NSString alloc] initWithBytes:END_OF_TEXT_TOKEN.data() - length:END_OF_TEXT_TOKEN.size() - encoding:NSUTF8StringEncoding]; - -NSString * const BEGIN_OF_TEXT_TOKEN_NS = [[NSString alloc] initWithBytes:BEGIN_OF_TEXT_TOKEN.data() - length:BEGIN_OF_TEXT_TOKEN.size() - encoding:NSUTF8StringEncoding]; +NSString *const END_OF_TEXT_TOKEN_NS = + [[NSString alloc] initWithBytes:END_OF_TEXT_TOKEN.data() + length:END_OF_TEXT_TOKEN.size() + encoding:NSUTF8StringEncoding]; -NSString * const START_HEADER_ID_TOKEN_NS = [[NSString alloc] initWithBytes:START_HEADER_ID_TOKEN.data() - length:START_HEADER_ID_TOKEN.size() - encoding:NSUTF8StringEncoding]; +NSString *const BEGIN_OF_TEXT_TOKEN_NS = + [[NSString alloc] initWithBytes:BEGIN_OF_TEXT_TOKEN.data() + length:BEGIN_OF_TEXT_TOKEN.size() + encoding:NSUTF8StringEncoding]; +NSString *const START_HEADER_ID_TOKEN_NS = + [[NSString alloc] initWithBytes:START_HEADER_ID_TOKEN.data() + length:START_HEADER_ID_TOKEN.size() + encoding:NSUTF8StringEncoding]; -NSString * const END_HEADER_ID_TOKEN_NS = [[NSString alloc] initWithBytes:END_HEADER_ID_TOKEN.data() - length:END_HEADER_ID_TOKEN.size() - encoding:NSUTF8StringEncoding]; +NSString *const END_HEADER_ID_TOKEN_NS = + [[NSString alloc] initWithBytes:END_HEADER_ID_TOKEN.data() + length:END_HEADER_ID_TOKEN.size() + encoding:NSUTF8StringEncoding]; diff --git a/ios/RnExecutorch/utils/llms/ConversationManager.h b/ios/RnExecutorch/utils/llms/ConversationManager.h index 20059a595b..f3d3a95e1c 100644 --- a/ios/RnExecutorch/utils/llms/ConversationManager.h +++ b/ios/RnExecutorch/utils/llms/ConversationManager.h @@ -1,29 +1,24 @@ #import +#import #import #import -#import -enum class ChatRole -{ - SYSTEM, - USER, - ASSISTANT -}; +enum class ChatRole { SYSTEM, USER, ASSISTANT }; inline constexpr std::string_view BEGIN_OF_TEXT_TOKEN = "<|begin_of_text|>"; inline constexpr std::string_view END_OF_TEXT_TOKEN = "<|eot_id|>"; inline constexpr std::string_view START_HEADER_ID_TOKEN = "<|start_header_id|>"; inline constexpr std::string_view END_HEADER_ID_TOKEN = "<|end_header_id|>"; -@interface ConversationManager : NSObject -{ +@interface ConversationManager : NSObject { NSUInteger numMessagesContextWindow; std::string basePrompt; std::deque messages; } - (instancetype)initWithNumMessagesContextWindow:(NSUInteger)numMessages - systemPrompt:(NSString *)systemPrompt; + systemPrompt:(NSString *)systemPrompt + messageHistory:(NSArray *)messageHistory; - (void)addResponse:(NSString *)text senderRole:(ChatRole)senderRole; - (NSString *)getConversation; diff --git a/ios/RnExecutorch/utils/llms/ConversationManager.mm b/ios/RnExecutorch/utils/llms/ConversationManager.mm index a241a51533..758c5a8198 100644 --- a/ios/RnExecutorch/utils/llms/ConversationManager.mm +++ b/ios/RnExecutorch/utils/llms/ConversationManager.mm @@ -3,7 +3,8 @@ @implementation ConversationManager - (instancetype)initWithNumMessagesContextWindow:(NSUInteger)numMessages - systemPrompt:(NSString *)systemPrompt { + systemPrompt:(NSString *)systemPrompt + messageHistory:(NSArray *)messageHistory { self = [super init]; if (self) { numMessagesContextWindow = numMessages; @@ -12,6 +13,16 @@ - (instancetype)initWithNumMessagesContextWindow:(NSUInteger)numMessages basePrompt += [systemPrompt UTF8String]; basePrompt += std::string(END_OF_TEXT_TOKEN); basePrompt += [self getHeaderTokenFromRole:ChatRole::USER]; + + for (const NSDictionary *message in messageHistory) { + NSString *role = message[@"role"]; + NSString *content = message[@"content"]; + if ([role isEqualToString:@"user"]) { + [self addResponse:content senderRole:ChatRole::USER]; + } else if ([role isEqualToString:@"assistant"]) { + [self addResponse:content senderRole:ChatRole::ASSISTANT]; + } + } } return self; } @@ -20,7 +31,7 @@ - (void)addResponse:(NSString *)text senderRole:(ChatRole)senderRole { if (messages.size() >= numMessagesContextWindow) { messages.pop_front(); } - + std::string formattedMessage; if (senderRole == ChatRole::ASSISTANT) { formattedMessage = [text UTF8String]; @@ -35,7 +46,7 @@ - (void)addResponse:(NSString *)text senderRole:(ChatRole)senderRole { - (NSString *)getConversation { std::string prompt = basePrompt; - for (const auto& elem : messages) { + for (const auto &elem : messages) { prompt += elem; } return [NSString stringWithUTF8String:prompt.c_str()]; @@ -43,14 +54,17 @@ - (NSString *)getConversation { - (std::string)getHeaderTokenFromRole:(ChatRole)role { switch (role) { - case ChatRole::SYSTEM: - return std::string(START_HEADER_ID_TOKEN) + "system" + std::string(END_HEADER_ID_TOKEN); - case ChatRole::USER: - return std::string(START_HEADER_ID_TOKEN) + "user" + std::string(END_HEADER_ID_TOKEN); - case ChatRole::ASSISTANT: - return std::string(START_HEADER_ID_TOKEN) + "assistant" + std::string(END_HEADER_ID_TOKEN); - default: - return ""; + case ChatRole::SYSTEM: + return std::string(START_HEADER_ID_TOKEN) + "system" + + std::string(END_HEADER_ID_TOKEN); + case ChatRole::USER: + return std::string(START_HEADER_ID_TOKEN) + "user" + + std::string(END_HEADER_ID_TOKEN); + case ChatRole::ASSISTANT: + return std::string(START_HEADER_ID_TOKEN) + "assistant" + + std::string(END_HEADER_ID_TOKEN); + default: + return ""; } } diff --git a/src/constants/llamaDefaults.ts b/src/constants/llamaDefaults.ts index 2b1202fa33..234b7f9e06 100644 --- a/src/constants/llamaDefaults.ts +++ b/src/constants/llamaDefaults.ts @@ -1,5 +1,9 @@ +import { MessageType } from '../types/common'; + export const DEFAULT_SYSTEM_PROMPT = "You are a knowledgeable, efficient, and direct AI assistant. Provide concise answers, focusing on the key information needed. Offer suggestions tactfully when appropriate to improve outcomes. Engage in productive collaboration with the user. Don't return too much text."; +export const DEFAULT_MESSAGE_HISTORY: MessageType[] = []; + export const DEFAULT_CONTEXT_WINDOW_LENGTH = 3; export const EOT_TOKEN = '<|eot_id|>'; diff --git a/src/hooks/natural_language_processing/useLLM.ts b/src/hooks/natural_language_processing/useLLM.ts index 5864d40bc4..c86a78d1ce 100644 --- a/src/hooks/natural_language_processing/useLLM.ts +++ b/src/hooks/natural_language_processing/useLLM.ts @@ -2,9 +2,10 @@ import { useCallback, useEffect, useRef, useState } from 'react'; import { EventSubscription } from 'react-native'; import { LLM } from '../../native/RnExecutorchModules'; import { fetchResource } from '../../utils/fetchResource'; -import { ResourceSource, Model } from '../../types/common'; +import { ResourceSource, Model, MessageType } from '../../types/common'; import { DEFAULT_CONTEXT_WINDOW_LENGTH, + DEFAULT_MESSAGE_HISTORY, DEFAULT_SYSTEM_PROMPT, EOT_TOKEN, } from '../../constants/llamaDefaults'; @@ -17,11 +18,13 @@ export const useLLM = ({ modelSource, tokenizerSource, systemPrompt = DEFAULT_SYSTEM_PROMPT, + messageHistory = DEFAULT_MESSAGE_HISTORY, contextWindowLength = DEFAULT_CONTEXT_WINDOW_LENGTH, }: { modelSource: ResourceSource; tokenizerSource: ResourceSource; systemPrompt?: string; + messageHistory?: MessageType[]; contextWindowLength?: number; }): Model => { const [error, setError] = useState(null); @@ -46,6 +49,7 @@ export const useLLM = ({ modelFileUri, tokenizerFileUri, systemPrompt, + messageHistory, contextWindowLength ); @@ -79,7 +83,13 @@ export const useLLM = ({ tokenGeneratedListener.current = null; LLM.deleteModule(); }; - }, [contextWindowLength, modelSource, systemPrompt, tokenizerSource]); + }, [ + modelSource, + tokenizerSource, + systemPrompt, + messageHistory, + contextWindowLength, + ]); const generate = useCallback( async (input: string): Promise => { diff --git a/src/modules/natural_language_processing/LLMModule.ts b/src/modules/natural_language_processing/LLMModule.ts index 7b44e8a90f..c9994e4bac 100644 --- a/src/modules/natural_language_processing/LLMModule.ts +++ b/src/modules/natural_language_processing/LLMModule.ts @@ -2,6 +2,7 @@ import { LLM } from '../../native/RnExecutorchModules'; import { fetchResource } from '../../utils/fetchResource'; import { DEFAULT_CONTEXT_WINDOW_LENGTH, + DEFAULT_MESSAGE_HISTORY, DEFAULT_SYSTEM_PROMPT, } from '../../constants/llamaDefaults'; import { ResourceSource } from '../../types/common'; @@ -13,6 +14,7 @@ export class LLMModule { modelSource: ResourceSource, tokenizerSource: ResourceSource, systemPrompt = DEFAULT_SYSTEM_PROMPT, + messageHistory = DEFAULT_MESSAGE_HISTORY, contextWindowLength = DEFAULT_CONTEXT_WINDOW_LENGTH ) { try { @@ -26,6 +28,7 @@ export class LLMModule { modelFileUri, tokenizerFileUri, systemPrompt, + messageHistory, contextWindowLength ); } catch (err) { diff --git a/src/native/NativeLLM.ts b/src/native/NativeLLM.ts index b610e30e39..edd9e59480 100644 --- a/src/native/NativeLLM.ts +++ b/src/native/NativeLLM.ts @@ -1,12 +1,14 @@ import type { TurboModule } from 'react-native'; import { TurboModuleRegistry } from 'react-native'; import type { EventEmitter } from 'react-native/Libraries/Types/CodegenTypes'; +import { MessageType } from '../types/common'; export interface Spec extends TurboModule { loadLLM( modelSource: string, tokenizerSource: string, systemPrompt: string, + messageHistory: MessageType[], contextWindowLength: number ): Promise; runInference(input: string): Promise; diff --git a/src/types/common.ts b/src/types/common.ts index 10b2b2600a..b98be61593 100644 --- a/src/types/common.ts +++ b/src/types/common.ts @@ -53,3 +53,8 @@ export type Module = | _StyleTransferModule | _ObjectDetectionModule | typeof ETModule; + +export interface MessageType { + role: 'user' | 'assistant'; + content: string; +}