Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions android/src/main/java/com/swmansion/rnexecutorch/LLM.kt
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ package com.swmansion.rnexecutorch
import android.util.Log
import com.facebook.react.bridge.Promise
import com.facebook.react.bridge.ReactApplicationContext
import com.facebook.react.bridge.ReadableArray
import com.swmansion.rnexecutorch.utils.llms.ChatRole
import com.swmansion.rnexecutorch.utils.llms.ConversationManager
import com.swmansion.rnexecutorch.utils.llms.END_OF_TEXT_TOKEN
import org.pytorch.executorch.LlamaCallback
import org.pytorch.executorch.LlamaModule
import com.swmansion.rnexecutorch.utils.ArrayUtils
import java.net.URL

class LLM(reactContext: ReactApplicationContext) :
NativeLLMSpec(reactContext), LlamaCallback {
class LLM(reactContext: ReactApplicationContext) : NativeLLMSpec(reactContext), LlamaCallback {

private var llamaModule: LlamaModule? = null
private var tempLlamaResponse = StringBuilder()
Expand All @@ -38,11 +39,14 @@ class LLM(reactContext: ReactApplicationContext) :
modelSource: String,
tokenizerSource: String,
systemPrompt: String,
messageHistory: ReadableArray,
contextWindowLength: Double,
promise: Promise
) {
try {
this.conversationManager = ConversationManager(contextWindowLength.toInt(), systemPrompt)
this.conversationManager = ConversationManager(
contextWindowLength.toInt(), systemPrompt, ArrayUtils.createMapArray<String>(messageHistory)
)
llamaModule = LlamaModule(1, URL(modelSource).path, URL(tokenizerSource).path, 0.7f)
this.tempLlamaResponse.clear()
promise.resolve("Model loaded successfully")
Expand All @@ -51,10 +55,7 @@ class LLM(reactContext: ReactApplicationContext) :
}
}

override fun runInference(
input: String,
promise: Promise
) {
override fun runInference(input: String, promise: Promise) {
this.conversationManager.addResponse(input, ChatRole.USER)
val conversation = this.conversationManager.getConversation()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.swmansion.rnexecutorch.utils
import android.util.Log
import com.facebook.react.bridge.Arguments
import com.facebook.react.bridge.ReadableArray
import com.facebook.react.bridge.ReadableMap
import org.pytorch.executorch.DType
import org.pytorch.executorch.Tensor

Expand Down Expand Up @@ -35,6 +36,15 @@ class ArrayUtils {
// Extracts every element of the bridge array as a Double and returns
// the values as a primitive DoubleArray (avoids boxed Array<Double>).
fun createDoubleArray(input: ReadableArray): DoubleArray =
  createTypedArrayFromReadableArray(input) { array, index -> array.getDouble(index) }
    .toDoubleArray()

// Converts a React Native bridge array of maps into an Array of plain
// Kotlin maps, one entry per element of [input].
// NOTE(review): assumes every element of [input] is a map — same as the
// original, which would also throw on a non-map element.
fun <V> createMapArray(input: ReadableArray): Array<Map<String, V>> {
  // toHashMap() is untyped, so the element cast is unavoidable here.
  @Suppress("UNCHECKED_CAST")
  return Array(input.size()) { index ->
    input.getMap(index).toHashMap() as Map<String, V>
  }
}

fun createReadableArrayFromTensor(result: Tensor): ReadableArray {
val resultArray = Arguments.createArray()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ const val END_OF_TEXT_TOKEN = "<|eot_id|>"
const val START_HEADER_ID_TOKEN = "<|start_header_id|>"
const val END_HEADER_ID_TOKEN = "<|end_header_id|>"

class ConversationManager(private val numMessagesContextWindow: Int, systemPrompt: String) {
class ConversationManager(
private val numMessagesContextWindow: Int,
systemPrompt: String,
messageHistory: Array<Map<String, String>>
) {
private val basePrompt: String;
private val messages = ArrayDeque<String>();

Expand All @@ -22,6 +26,13 @@ class ConversationManager(private val numMessagesContextWindow: Int, systemPromp
systemPrompt +
END_OF_TEXT_TOKEN +
getHeaderTokenFromRole(ChatRole.USER)

messageHistory.forEach { message ->
when (message["role"]) {
"user" -> addResponse(message["content"]!!, ChatRole.USER)
"assistant" -> addResponse(message["content"]!!, ChatRole.ASSISTANT)
}
}
}

fun addResponse(text: String, senderRole: ChatRole) {
Expand All @@ -31,9 +42,7 @@ class ConversationManager(private val numMessagesContextWindow: Int, systemPromp
val formattedMessage: String = if (senderRole == ChatRole.ASSISTANT) {
text + getHeaderTokenFromRole(ChatRole.USER)
} else {
text +
END_OF_TEXT_TOKEN +
getHeaderTokenFromRole(ChatRole.ASSISTANT)
text + END_OF_TEXT_TOKEN + getHeaderTokenFromRole(ChatRole.ASSISTANT)
}
this.messages.add(formattedMessage)
}
Expand Down
10 changes: 7 additions & 3 deletions examples/llama/components/MessageItem.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@ interface MessageItemProps {

const MessageItem = memo(({ message }: MessageItemProps) => {
return (
<View style={message.from === 'ai' ? styles.aiMessage : styles.userMessage}>
{message.from === 'ai' && (
<View
style={
message.role === 'assistant' ? styles.aiMessage : styles.userMessage
}
>
{message.role === 'assistant' && (
<View style={styles.aiMessageIconContainer}>
<LlamaIcon width={24} height={24} />
</View>
)}
<MarkdownComponent text={message.text} />
<MarkdownComponent text={message.content} />
</View>
);
});
Expand Down
9 changes: 3 additions & 6 deletions examples/llama/screens/ChatScreen.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,12 @@ export default function ChatScreen() {
const textInputRef = useRef<TextInput>(null);
useEffect(() => {
if (llama.response && !llama.isGenerating) {
appendToMessageHistory(llama.response, 'ai');
appendToMessageHistory(llama.response, 'assistant');
}
}, [llama.response, llama.isGenerating]);

const appendToMessageHistory = (input: string, role: SenderType) => {
setChatHistory((prevHistory) => [
...prevHistory,
{ text: input, from: role },
]);
const appendToMessageHistory = (content: string, role: SenderType) => {
setChatHistory((prevHistory) => [...prevHistory, { role, content }]);
};

const sendMessage = async () => {
Expand Down
6 changes: 3 additions & 3 deletions examples/llama/types.d.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
export type SenderType = 'user' | 'ai';
export type SenderType = 'user' | 'assistant';

export interface MessageType {
text: string;
from: SenderType;
role: SenderType;
content: string;
}
4 changes: 3 additions & 1 deletion ios/RnExecutorch/LLM.mm
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ - (void)onResult:(NSString *)token prompt:(NSString *)prompt {
- (void)loadLLM:(NSString *)modelSource
tokenizerSource:(NSString *)tokenizerSource
systemPrompt:(NSString *)systemPrompt
messageHistory:(NSArray *)messageHistory
contextWindowLength:(double)contextWindowLength
resolve:(RCTPromiseResolveBlock)resolve
reject:(RCTPromiseRejectBlock)reject {
Expand All @@ -55,7 +56,8 @@ - (void)loadLLM:(NSString *)modelSource

self->conversationManager = [[ConversationManager alloc]
initWithNumMessagesContextWindow:contextWindowLengthUInt
systemPrompt:systemPrompt];
systemPrompt:systemPrompt
messageHistory:messageHistory];

self->tempLlamaResponse = [NSMutableString string];
resolve(@"Model and tokenizer loaded successfully");
Expand Down
22 changes: 12 additions & 10 deletions ios/RnExecutorch/models/StyleTransferModel.mm
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,45 @@ @implementation StyleTransferModel {
cv::Size originalSize;
}

- (cv::Size)getModelImageSize{
NSArray * inputShape = [module getInputShape: @0];
- (cv::Size)getModelImageSize {
NSArray *inputShape = [module getInputShape:@0];
NSNumber *widthNumber = inputShape.lastObject;
NSNumber *heightNumber = inputShape[inputShape.count - 2];

int height = [heightNumber intValue];
int width = [widthNumber intValue];

return cv::Size(height, width);
}

- (NSArray *)preprocess:(cv::Mat &)input {
self->originalSize = cv::Size(input.cols, input.rows);

cv::Size modelImageSize = [self getModelImageSize];
cv::Mat output;
cv::resize(input, output, modelImageSize);
NSArray *modelInput = [ImageProcessor matToNSArray: output];

NSArray *modelInput = [ImageProcessor matToNSArray:output];
return modelInput;
}

- (cv::Mat)postprocess:(NSArray *)output {
cv::Size modelImageSize = [self getModelImageSize];
cv::Mat processedImage = [ImageProcessor arrayToMat: output width:modelImageSize.width height:modelImageSize.height];
cv::Mat processedImage = [ImageProcessor arrayToMat:output
width:modelImageSize.width
height:modelImageSize.height];

cv::Mat processedOutput;
cv::resize(processedImage, processedOutput, originalSize);

return processedOutput;
}

// Runs the full style-transfer pipeline on [input]: preprocess (resize +
// array conversion), forward pass, then postprocess back to an image.
// The result is written through the caller's reference AND returned,
// matching the original in-place behavior.
- (cv::Mat)runModel:(cv::Mat &)input {
  NSArray *networkInput = [self preprocess:input];
  NSArray *networkOutput = [self forward:networkInput];
  input = [self postprocess:networkOutput[0]];
  return input;
}

Expand Down
22 changes: 11 additions & 11 deletions ios/RnExecutorch/models/classification/ClassificationModel.mm
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
#import "ClassificationModel.h"
#import "opencv2/opencv.hpp"
#import "Utils.h"
#import "Constants.h"
#import "../../utils/ImageProcessor.h"
#import "Constants.h"
#import "Utils.h"
#import "opencv2/opencv.hpp"

@implementation ClassificationModel

- (cv::Size)getModelImageSize {
NSArray * inputShape = [module getInputShape: 0];
NSArray *inputShape = [module getInputShape:0];
NSNumber *widthNumber = inputShape.lastObject;
NSNumber *heightNumber = inputShape[inputShape.count - 2];

int height = [heightNumber intValue];
int width = [widthNumber intValue];

return cv::Size(height, width);
}

- (NSArray *)preprocess:(cv::Mat &)input {
cv::Size modelImageSize = [self getModelImageSize];
cv::resize(input, input, modelImageSize);
NSArray *modelInput = [ImageProcessor matToNSArray: input];

NSArray *modelInput = [ImageProcessor matToNSArray:input];
return modelInput;
}

Expand All @@ -32,16 +32,16 @@ - (NSDictionary *)postprocess:(NSArray *)output {
for (NSUInteger i = 0; i < output.count; ++i) {
outputVector[i] = [output[i] doubleValue];
}

std::vector<double> probabilities = softmax(outputVector);
NSMutableDictionary *result = [NSMutableDictionary dictionary];

for (int i = 0; i < probabilities.size(); ++i) {
NSString *className = @(imagenet1k_v1_labels[i].c_str());
NSNumber *probability = @(probabilities[i]);
result[className] = probability;
}

return result;
}

Expand Down
2 changes: 1 addition & 1 deletion ios/RnExecutorch/models/classification/Utils.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#include <vector>

std::vector<double> softmax(const std::vector<double>& v);
std::vector<double> softmax(const std::vector<double> &v);
28 changes: 14 additions & 14 deletions ios/RnExecutorch/models/classification/Utils.mm
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
#include "Utils.h"
#include <vector>
#include <cmath>
#include <vector>

std::vector<double> softmax(const std::vector<double>& v) {
std::vector<double> result(v.size());
double maxVal = *std::max_element(v.begin(), v.end());
std::vector<double> softmax(const std::vector<double> &v) {
std::vector<double> result(v.size());
double maxVal = *std::max_element(v.begin(), v.end());

double sumExp = 0.0;
for (size_t i = 0; i < v.size(); ++i) {
result[i] = std::exp(v[i] - maxVal);
sumExp += result[i];
}
double sumExp = 0.0;
for (size_t i = 0; i < v.size(); ++i) {
result[i] = std::exp(v[i] - maxVal);
sumExp += result[i];
}

for (size_t i = 0; i < v.size(); ++i) {
result[i] /= sumExp;
}
for (size_t i = 0; i < v.size(); ++i) {
result[i] /= sumExp;
}

return result;
}
return result;
}
8 changes: 4 additions & 4 deletions ios/RnExecutorch/utils/ETError.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@ typedef NS_ENUM(NSUInteger, ETError) {
ModuleNotLoaded = 0x66,
FileWriteFailed = 0x67,
InvalidModelSource = 0xff,

Ok = 0x00,
Internal = 0x01,
InvalidState = 0x02,
EndOfMethod = 0x03,

NotSupported = 0x10,
NotImplemented = 0x11,
InvalidArgument = 0x12,
InvalidType = 0x13,
OperatorMissing = 0x14,

NotFound = 0x20,
MemoryAllocationFailed = 0x21,
AccessFailed = 0x22,
InvalidProgram = 0x23,

DelegateInvalidCompatibility = 0x30,
DelegateMemoryAllocationFailed = 0x31,
DelegateInvalidHandle = 0x32
Expand Down
Loading