Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 51 additions & 2 deletions examples/llama.android/app/src/main/cpp/llama-android.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,45 @@ jclass la_int_var;
jmethodID la_int_var_value;
jmethodID la_int_var_inc;

std::string cached_token_chars;

bool is_valid_utf8(const char * string) {
if (!string) {
return true;
}

const unsigned char * bytes = (const unsigned char *)string;
int num;

while (*bytes != 0x00) {
if ((*bytes & 0x80) == 0x00) {
// U+0000 to U+007F
num = 1;
} else if ((*bytes & 0xE0) == 0xC0) {
// U+0080 to U+07FF
num = 2;
} else if ((*bytes & 0xF0) == 0xE0) {
// U+0800 to U+FFFF
num = 3;
} else if ((*bytes & 0xF8) == 0xF0) {
// U+10000 to U+10FFFF
num = 4;
} else {
return false;
}

bytes += 1;
for (int i = 1; i < num; ++i) {
if ((*bytes & 0xC0) != 0x80) {
return false;
}
bytes += 1;
}
}

return true;
}

static void log_callback(ggml_log_level level, const char * fmt, void * data) {
if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
Expand Down Expand Up @@ -295,6 +334,8 @@ Java_com_example_llama_Llm_completion_1init(
jint n_len
) {

cached_token_chars.clear();

const auto text = env->GetStringUTFChars(jtext, 0);
const auto context = reinterpret_cast<llama_context *>(context_pointer);
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
Expand Down Expand Up @@ -372,8 +413,16 @@ Java_com_example_llama_Llm_completion_1loop(
}

auto new_token_chars = llama_token_to_piece(context, new_token_id);
LOGi("new_token_chars: `%s`", new_token_chars.c_str());
auto new_token = env->NewStringUTF(new_token_chars.c_str());
cached_token_chars += new_token_chars;

jstring new_token = nullptr;
if (is_valid_utf8(cached_token_chars.c_str())) {
new_token = env->NewStringUTF(cached_token_chars.c_str());
LOGi("cached: %s, new_token_chars: `%s`, id: %d", cached_token_chars.c_str(), new_token_chars.c_str(), new_token_id);
cached_token_chars.clear();
} else {
new_token = env->NewStringUTF("");
}

llama_batch_clear(*batch);
llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class Llm {
batch: Long,
nLen: Int,
ncur: IntVar
): String
): String?

private external fun kv_cache_clear(context: Long)

Expand Down Expand Up @@ -115,7 +115,7 @@ class Llm {
val ncur = IntVar(completion_init(state.context, state.batch, message, nlen))
while (ncur.value <= nlen) {
val str = completion_loop(state.context, state.batch, nlen, ncur)
if (str.isEmpty()) {
if (str == null) {
break
}
emit(str)
Expand Down