common/arg.cpp: 2 changes (1 addition, 1 deletion)
@@ -3435,7 +3435,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.use_jinja = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_JINJA"));
+    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA"));
     add_opt(common_arg(
         {"--reasoning-format"}, "FORMAT",
         "controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"
tools/mtmd/mtmd-cli.cpp: 47 changes (32 additions, 15 deletions)
@@ -76,9 +76,11 @@ struct mtmd_cli_context {

     mtmd::bitmaps bitmaps;

-    // note: we know that gemma3 template is "linear", meaning each turn is completely separated to another
-    // so here we don't need to keep track of chat history
+    // chat template
     common_chat_templates_ptr tmpls;
+    std::vector<common_chat_msg> chat_history;
+    bool use_jinja = false;
+    // TODO: support for --system-prompt with /clear command

     // support for legacy templates (models not having EOT token)
     llama_tokens antiprompt_tokens;
Expand Down Expand Up @@ -108,6 +110,8 @@ struct mtmd_cli_context {
}

tmpls = common_chat_templates_init(model, params.chat_template);
use_jinja = params.use_jinja;
chat_history.clear();
LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(tmpls.get(), params.use_jinja, params.default_template_kwargs).c_str());

init_vision_context(params);
Expand Down Expand Up @@ -193,19 +197,33 @@ static int generate_response(mtmd_cli_context & ctx, int n_predict) {
return 1;
}
}

std::string generated_text = common_detokenize(ctx.lctx, generated_tokens);
common_chat_msg msg;
msg.role = "assistant";
msg.content = generated_text;
ctx.chat_history.push_back(std::move(msg));

return 0;
}

static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg, bool add_bos = false) {
common_chat_templates_inputs tmpl_inputs;
tmpl_inputs.messages = {msg};
tmpl_inputs.add_generation_prompt = true;
tmpl_inputs.use_jinja = false; // jinja is buggy here
auto formatted_chat = common_chat_templates_apply(ctx.tmpls.get(), tmpl_inputs);
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.prompt.c_str());
static std::string chat_add_and_format(mtmd_cli_context & ctx, common_chat_msg & new_msg) {
LOG_DBG("chat_add_and_format: new_msg.role='%s', new_msg.content='%s'\n",
new_msg.role.c_str(), new_msg.content.c_str());
auto formatted = common_chat_format_single(ctx.tmpls.get(), ctx.chat_history,
new_msg, new_msg.role == "user",
ctx.use_jinja);
ctx.chat_history.push_back(new_msg);
return formatted;
}

static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) {
bool add_bos = ctx.chat_history.empty();
auto formatted_chat = chat_add_and_format(ctx, msg);
LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.c_str());

mtmd_input_text text;
text.text = formatted_chat.prompt.c_str();
text.text = formatted_chat.c_str();
text.add_special = add_bos;
text.parse_special = true;

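The new chat_add_and_format() relies on common_chat_format_single() returning only the delta that the new message appends to the already-formatted history, so each turn can be fed to the model incrementally instead of re-rendering the whole transcript. Below is a rough, self-contained sketch of that prefix-diff idea; the render() helper and its tag format are made up for illustration and are not the llama.cpp template engine:

    #include <iostream>
    #include <string>
    #include <vector>

    struct Msg { std::string role, content; };

    // Hypothetical stand-in for applying a chat template to a message list.
    static std::string render(const std::vector<Msg> & msgs, bool add_generation_prompt) {
        std::string out;
        for (const auto & m : msgs) {
            out += "<|" + m.role + "|>" + m.content + "<|end|>\n";
        }
        if (add_generation_prompt) {
            out += "<|assistant|>";
        }
        return out;
    }

    // Prefix-diff: render the history with and without the new message and
    // return only the suffix, the same idea common_chat_format_single uses.
    static std::string format_single(std::vector<Msg> history, const Msg & new_msg) {
        const std::string before = render(history, /*add_generation_prompt=*/false);
        history.push_back(new_msg);
        const bool add_ass = new_msg.role == "user"; // prompt the assistant after user turns
        const std::string after = render(history, add_ass);
        return after.substr(before.size());
    }

    int main() {
        std::vector<Msg> history = {{"user", "hi"}, {"assistant", "hello!"}};
        std::cout << format_single(history, {"user", "describe this image"});
        // prints only the new user turn plus the trailing <|assistant|> prompt
    }

This also explains why eval_message() no longer needs an explicit add_bos parameter: formatting is always relative to what the model has already seen, and BOS is wanted only when the history is empty.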
Expand Down Expand Up @@ -303,7 +321,7 @@ int main(int argc, char ** argv) {
return 1; // error is already printed by libmtmd
}
}
if (eval_message(ctx, msg, true)) {
if (eval_message(ctx, msg)) {
return 1;
}
if (!g_is_interrupted && generate_response(ctx, n_predict)) {
@@ -322,7 +340,6 @@ int main(int argc, char ** argv) {
     LOG("\n /quit or /exit    exit the program");
     LOG("\n");

-    bool is_first_msg = true;
     std::string content;

     while (!g_is_interrupted) {
@@ -342,7 +359,8 @@
         }
         if (line == "/clear") {
             ctx.n_past = 0;
-            llama_memory_seq_rm(llama_get_memory(ctx.lctx), 0, 1, -1); // keep BOS
+            ctx.chat_history.clear();
+            llama_memory_clear(llama_get_memory(ctx.lctx), true);
             LOG("Chat history cleared\n\n");
             continue;
         }

Review thread on the llama_memory_clear() line:

    Member: This will no longer keep the BOS - just making sure it is intended.

    Contributor Author: Yes, it's expected, as the BOS token will always be added along with the first formatted message.
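The author's reply follows from eval_message() deriving add_bos from ctx.chat_history.empty(): once /clear wipes both the history and the KV cache, the very next message is formatted as a first message again, so add_special re-adds BOS. A minimal sketch of that invariant with toy types (not the llama.cpp API):

    #include <cassert>
    #include <string>
    #include <vector>

    struct Msg { std::string role, content; };

    struct Session {
        std::vector<Msg> history;

        // Mirrors eval_message(): BOS is requested only when the history is empty.
        bool next_adds_bos() const { return history.empty(); }

        void send(const Msg & m) { history.push_back(m); }
        void clear() { history.clear(); } // the /clear command path
    };

    int main() {
        Session s;
        assert(s.next_adds_bos());   // first message carries BOS
        s.send({"user", "hi"});
        assert(!s.next_adds_bos());  // follow-up turns do not
        s.clear();                   // history (and, in the CLI, the KV cache) wiped
        assert(s.next_adds_bos());   // BOS comes back with the next turn
    }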
@@ -367,7 +385,7 @@ int main(int argc, char ** argv) {
             common_chat_msg msg;
             msg.role = "user";
             msg.content = content;
-            int ret = eval_message(ctx, msg, is_first_msg);
+            int ret = eval_message(ctx, msg);
             if (ret) {
                 return 1;
             }
@@ -376,7 +394,6 @@ int main(int argc, char ** argv) {
                 return 1;
             }
             content.clear();
-            is_first_msg = false;
         }
     }
     if (g_is_interrupted) LOG("\nInterrupted by user\n");