diff --git a/CODEOWNERS b/CODEOWNERS index 9d252c9b8dc..675c27b2522 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -11,6 +11,8 @@ /common/base64.hpp.* @ggerganov /common/build-info.* @ggerganov /common/chat.* @pwilkin +/common/chat-auto*.* @pwilkin +/common/chat-diff-analyzer.* @pwilkin /common/chat-peg-parser.* @aldehir /common/common.* @ggerganov /common/console.* @ggerganov @@ -89,12 +91,13 @@ /src/llama-vocab.* @CISC /src/models/ @CISC /tests/ @ggerganov -/tests/test-chat-.* @pwilkin +/tests/test-chat.* @pwilkin /tools/batched-bench/ @ggerganov /tools/cli/ @ngxson /tools/completion/ @ggerganov /tools/mtmd/ @ngxson /tools/perplexity/ @ggerganov +/tools/parser/ @pwilkin /tools/quantize/ @ggerganov /tools/rpc/ @rgerganov /tools/server/* @ngxson @ggerganov # no subdir diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 7545e790f82..0fe627f4e7f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -159,7 +159,7 @@ Maintainers reserve the right to decline review or close pull requests for any r # Code maintenance -- Existing code should have designated collaborators and/or maintainers specified in the [CODEOWNERS](CODEOWNERS) file reponsible for: +- Existing code should have designated collaborators and/or maintainers specified in the [CODEOWNERS](CODEOWNERS) file responsible for: - Reviewing and merging related PRs - Fixing related bugs - Providing developer guidance/support diff --git a/README.md b/README.md index 5c11f38048a..125cb3f3700 100644 --- a/README.md +++ b/README.md @@ -287,7 +287,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo | [IBM zDNN](docs/backend/zDNN.md) | IBM Z & LinuxONE | | [WebGPU [In Progress]](docs/build.md#webgpu) | All | | [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All | -| [Hexagon [In Progress]](docs/backend/hexagon/README.md) | Snapdragon | +| [Hexagon [In Progress]](docs/backend/snapdragon/README.md) | Snapdragon | | [VirtGPU](docs/backend/VirtGPU.md) | VirtGPU APIR | ## Obtaining and quantizing models diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 27ca335be37..51bff1c44bf 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -47,10 +47,10 @@ add_library(${TARGET} STATIC arg.cpp arg.h base64.hpp - chat-parser.cpp - chat-parser.h - chat-parser-xml-toolcall.h - chat-parser-xml-toolcall.cpp + chat-auto-parser-generator.cpp + chat-auto-parser-helpers.cpp + chat-auto-parser.h + chat-diff-analyzer.cpp chat-peg-parser.cpp chat-peg-parser.h chat.cpp diff --git a/common/arg.cpp b/common/arg.cpp index 05f4a5244e7..0d8561dbb3c 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1279,13 +1279,20 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_env("LLAMA_ARG_SWA_FULL")); add_opt(common_arg( - {"--ctx-checkpoints", "--swa-checkpoints"}, "N", + {"-ctxcp", "--ctx-checkpoints", "--swa-checkpoints"}, "N", string_format("max number of context checkpoints to create per slot (default: %d)" "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_ctx_checkpoints), [](common_params & params, int value) { params.n_ctx_checkpoints = value; } ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); + add_opt(common_arg( + {"-cpent", "--checkpoint-every-n-tokens"}, "N", + string_format("create a checkpoint every n tokens during prefill (processing), -1 to disable (default: %d)", params.checkpoint_every_nt), + [](common_params & params, int value) { + params.checkpoint_every_nt = value; + } + ).set_env("LLAMA_ARG_CHECKPOINT_EVERY_NT").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); add_opt(common_arg( {"-cram", "--cache-ram"}, "N", string_format("set the maximum cache size in MiB (default: %d, -1 - no limit, 0 - disable)" @@ -2399,7 +2406,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.fit_params = false; } else { throw std::runtime_error( - string_format("error: unkown value for --fit: '%s'\n", value.c_str())); + string_format("error: unknown value for --fit: '%s'\n", value.c_str())); } } ).set_env("LLAMA_ARG_FIT")); @@ -2827,6 +2834,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.webui_config_json = read_file(value); } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_CONFIG_FILE")); + add_opt(common_arg( + {"--webui-mcp-proxy"}, + {"--no-webui-mcp-proxy"}, + string_format("experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: %s)", params.webui_mcp_proxy ? "enabled" : "disabled"), + [](common_params & params, bool value) { + params.webui_mcp_proxy = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_MCP_PROXY")); add_opt(common_arg( {"--webui"}, {"--no-webui"}, diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp new file mode 100644 index 00000000000..9a936f892b0 --- /dev/null +++ b/common/chat-auto-parser-generator.cpp @@ -0,0 +1,424 @@ +#include "chat-auto-parser.h" +#include "chat-peg-parser.h" +#include "chat.h" +#include "json-schema-to-grammar.h" +#include "nlohmann/json.hpp" + +#include +#include + +using json = nlohmann::ordered_json; + +// Helper to iterate over tools/functions +static void foreach_function(const json & tools, const std::function & fn) { + for (const auto & tool : tools) { + if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) { + continue; + } + fn(tool); + } +} + +namespace autoparser { + +parser_build_context::parser_build_context(common_chat_peg_builder & p, const templates_params & inputs) : + p(p), + inputs(inputs), + reasoning_parser(p.eps()) {} + +common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl, + const struct templates_params & inputs) { + // Run differential analysis to extract template structure + struct autoparser autoparser; + autoparser.analyze_template(tmpl); + return generate_parser(tmpl, inputs, autoparser); +} + +common_chat_params peg_generator::generate_parser(const common_chat_template & tmpl, + const struct templates_params & inputs, + const autoparser & autoparser) { + // Build the parser using the analysis results + auto parser = autoparser.build_parser(inputs); + + // Create the result structure + common_chat_params data; + data.prompt = common_chat_template_direct_apply(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.preserved_tokens = autoparser.preserved_tokens; + data.parser = parser.save(); + + // Build grammar if tools are present + bool has_tools = + autoparser.tools.format.mode != tool_format::NONE && inputs.tools.is_array() && !inputs.tools.empty(); + std::string trigger_marker = !autoparser.tools.format.section_start.empty() ? autoparser.tools.format.section_start : + autoparser.tools.format.per_call_start; + bool include_grammar = + has_tools && ((inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO && !trigger_marker.empty()) || + inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED); + + if (include_grammar) { + data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + auto schema = function.at("parameters"); + builder.resolve_refs(schema); + }); + parser.build_grammar(builder, data.grammar_lazy); + }); + + // Set grammar triggers based on tool section markers (fall back to per-call markers) + if (data.grammar_lazy) { // only do triggers on lazy grammar + data.grammar_triggers = { + { COMMON_GRAMMAR_TRIGGER_TYPE_WORD, trigger_marker } + }; + } + } + + return data; +} + +common_peg_arena autoparser::build_parser(const templates_params & inputs) const { + if (!analysis_complete) { + throw std::invalid_argument("Cannot call build_parser on autoparser without performing analysis first, call analyze_template(...)"); + } + return build_chat_peg_parser([&](common_chat_peg_builder & p) { + // If the template uses Python dict format (single-quoted strings in JSON structures), + // pre-register a json-string rule that accepts both quote styles. This must happen + // before any call to p.json() so that all JSON parsing inherits the flexible rule. + if (tools.format.uses_python_dicts) { + p.rule("json-string", [&]() { return p.choice({ p.double_quoted_string(), p.single_quoted_string() }); }); + } + + parser_build_context ctx(p, inputs); + bool extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; + bool enable_thinking = inputs.enable_thinking; + + ctx.extracting_reasoning = extract_reasoning && enable_thinking && reasoning.mode != reasoning_mode::NONE; + ctx.content = &content; + + // Build reasoning parser + ctx.reasoning_parser = reasoning.build_parser(ctx); + + bool has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + bool has_response_format = inputs.json_schema.is_object() && !inputs.json_schema.empty(); + + if (has_response_format) { + return ctx.reasoning_parser + p.space() + + p.content(p.schema(p.json(), "response-format", inputs.json_schema)) + p.end(); + } + + if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && jinja_caps.supports_tool_calls) { + return tools.build_parser(ctx); + } + + return content.build_parser(ctx); + }); +} + +common_peg_parser analyze_reasoning::build_parser(parser_build_context & ctx) const { + auto & p = ctx.p; + + if (!ctx.extracting_reasoning) { + return p.eps(); + } + + bool thinking_forced_open = (mode == reasoning_mode::FORCED_OPEN); + bool thinking_forced_closed = (mode == reasoning_mode::FORCED_CLOSED); + + if (thinking_forced_open || thinking_forced_closed) { + // Thinking is forced open OR forced closed with enable_thinking=true + // In both cases, expect only the closing tag (opening was in template) + return p.reasoning(p.until(end)) + end; + } + if (mode == reasoning_mode::TAG_BASED || mode == reasoning_mode::TOOLS_ONLY) { + // Standard tag-based reasoning OR tools-only mode (reasoning appears with tools) + // Both use the same tag-based pattern if markers are available + if (!start.empty() && !end.empty()) { + return p.optional(start + p.reasoning(p.until(end)) + end); + } + } else if (mode == reasoning_mode::DELIMITER) { + return p.optional(p.reasoning(p.until(end)) + end); + } + + return p.eps(); +} + +common_peg_parser analyze_content::build_parser(parser_build_context & ctx) const { + auto & p = ctx.p; + + if (is_always_wrapped()) { + if (ctx.extracting_reasoning) { + return ctx.reasoning_parser + start + p.content(p.until(end)) + end + p.end(); + } + return p.content(p.until(start)) + start + p.content(p.until(end)) + end + p.end(); + } + return ctx.reasoning_parser + p.content(p.rest()) + p.end(); +} + +common_peg_parser analyze_content::build_optional_wrapped(parser_build_context & ctx) const { + auto & p = ctx.p; + + if (is_always_wrapped()) { + return p.optional(start + p.content(p.until(end)) + end); + } + return p.eps(); +} + +common_peg_parser analyze_tools::build_parser(parser_build_context & ctx) const { + switch (format.mode) { + case tool_format::JSON_NATIVE: + return build_tool_parser_json_native(ctx); + case tool_format::TAG_WITH_JSON: + return build_tool_parser_tag_json(ctx); + case tool_format::TAG_WITH_TAGGED: + return build_tool_parser_tag_tagged(ctx); + default: + GGML_ABORT("Unable to create tool parser"); + } +} + +common_peg_parser analyze_tools::build_tool_parser_json_native(parser_build_context & ctx) const { + auto & p = ctx.p; + const auto & inputs = ctx.inputs; + bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + // Build effective field names with dot notation if function_field is set + std::string name_field = format.name_field; + std::string args_field = format.args_field; + + if (!format.function_field.empty() && format.function_field != "function" && + name_field.find('.') == std::string::npos) { + name_field = format.function_field + "." + name_field; + args_field = format.function_field + "." + args_field; + } + + auto tools_parser = p.standard_json_tools( + format.section_start, format.section_end, inputs.tools, inputs.parallel_tool_calls, + inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED, name_field, args_field, format.tools_array_wrapped, + format.fun_name_is_key, format.id_field, format.gen_id_field, format.parameter_order); + + // Handle content wrappers if present + if (ctx.content && ctx.content->is_always_wrapped()) { + auto wrapped_content = ctx.content->build_optional_wrapped(ctx); + return ctx.reasoning_parser + wrapped_content + tools_parser + p.end(); + } + + std::string tool_start = "{"; + if (!format.section_start.empty()) { + tool_start = format.section_start; + } else if (!format.per_call_start.empty()) { + tool_start = format.per_call_start; + } + + return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(p.until(tool_start)))) + tools_parser + + p.end(); +} + +common_peg_parser analyze_tools::build_tool_parser_tag_json(parser_build_context & ctx) const { + auto & p = ctx.p; + const auto & inputs = ctx.inputs; + bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + common_peg_parser tool_choice = p.choice(); + + foreach_function(inputs.tools, [&](const json & tool) { + const auto & func = tool.at("function"); + std::string name = func.at("name"); + const auto & schema = func.at("parameters"); + + // Build call_id parser based on position (if supported) + common_peg_parser call_id_section = p.eps(); + if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() && + !call_id.suffix.empty()) { + call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix))) + call_id.suffix; + } + + auto func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + + call_id_section + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)); + if (!function.close.empty()) { + func_parser = func_parser + function.close; + } + func_parser = p.atomic(func_parser); + + tool_choice |= p.rule("tool-" + name, func_parser); + }); + + auto require_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + common_peg_parser tool_calls = p.eps(); + + if (!format.per_call_start.empty()) { + auto wrapped_call = format.per_call_start + tool_choice + format.per_call_end; + if (inputs.parallel_tool_calls) { + tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call)); + } else { + tool_calls = p.trigger_rule("tool-call", wrapped_call); + } + if (!format.section_start.empty()) { + tool_calls = p.trigger_rule("tool-calls", + p.literal(format.section_start) + p.space() + tool_calls + p.space() + + (format.section_end.empty() ? p.end() : p.literal(format.section_end))); + } + } else { + std::string separator = ", "; // Default + if (inputs.parallel_tool_calls) { + tool_calls = p.trigger_rule("tool-call", format.section_start + tool_choice + + p.zero_or_more(separator + tool_choice) + format.section_end); + } else { + tool_calls = p.trigger_rule("tool-call", format.section_start + tool_choice + format.section_end); + } + } + + if (!require_calls) { + tool_calls = p.optional(tool_calls); + } + + std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start; + auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker); + return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls + + p.end(); +} + +common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_context & ctx) const { + auto & p = ctx.p; + const auto & inputs = ctx.inputs; + bool force_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + common_peg_parser tool_choice = p.choice(); + + foreach_function(inputs.tools, [&](const json & tool) { + const auto & func = tool.at("function"); + std::string name = func.at("name"); + const auto & params = func.at("parameters"); + + if (!params.contains("properties") || !params.at("properties").is_object()) { + return; + } + + const auto & properties = params.at("properties"); + std::set required; + if (params.contains("required") && params.at("required").is_array()) { + params.at("required").get_to(required); + } + + // Build parser for each argument, separating required and optional + std::vector required_parsers; + std::vector optional_parsers; + for (const auto & [param_name, param_schema] : properties.items()) { + bool is_required = required.find(param_name) != required.end(); + std::string type = "object"; + auto type_obj = param_schema.contains("type") ? param_schema.at("type") : json::object(); + if (type_obj.is_string()) { + type_obj.get_to(type); + } else if (type_obj.is_object()) { + if (type_obj.contains("type") && type_obj.at("type").is_string()) { + type_obj.at("type").get_to(type); + } + } + + auto arg = p.tool_arg( + p.tool_arg_open(arguments.name_prefix + p.tool_arg_name(p.literal(param_name)) + + arguments.name_suffix) + + arguments.value_prefix + + (type == "string" ? p.tool_arg_string_value(p.schema(p.until(arguments.value_suffix), + "tool-" + name + "-arg-" + param_name + "-schema", + param_schema, true)) : + p.tool_arg_json_value(p.schema( + p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, format.uses_python_dicts)) + + p.space()) + + p.tool_arg_close(p.literal(arguments.value_suffix))); + + auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg); + if (is_required) { + required_parsers.push_back(named_arg); + } else { + optional_parsers.push_back(named_arg); + } + } + + // Build required arg sequence in definition order + common_peg_parser args_seq = p.eps(); + for (size_t i = 0; i < required_parsers.size(); i++) { + if (i > 0) { + args_seq = args_seq + p.space(); + } + args_seq = args_seq + required_parsers[i]; + } + + // Build optional args with flexible ordering + if (!optional_parsers.empty()) { + common_peg_parser any_opt = p.choice(); + for (const auto & opt : optional_parsers) { + any_opt |= opt; + } + args_seq = args_seq + p.repeat(p.space() + any_opt, 0, (int) optional_parsers.size()); + } + + // Build call_id parser based on position (if supported) + common_peg_parser call_id_section = p.eps(); + if (call_id.pos == call_id_position::BETWEEN_FUNC_AND_ARGS && !call_id.prefix.empty() && + !call_id.suffix.empty()) { + call_id_section = p.optional(call_id.prefix + p.tool_id(p.until(call_id.suffix))) + call_id.suffix; + } + + auto func_parser = p.tool_open(function.name_prefix + p.tool_name(p.literal(name)) + function.name_suffix) + + call_id_section + p.space() + args_seq; + + if (!function.close.empty()) { + func_parser = func_parser + p.space() + p.tool_close(p.literal(function.close)); + } else if (!format.per_call_end.empty()) { + // When there's no func_close but there is a per_call_end marker, use peek() to ensure + // we only emit tool_close when we can actually see the closing marker. This prevents + // premature closing during partial parsing when we've seen e.g. "" (end) or "" prefix that failed to match. + func_parser = func_parser + p.tool_close(p.peek(p.literal(format.per_call_end))); + } else { + func_parser = + func_parser + p.tool_close(p.space()); // force this to process tool closing callbacks in mapper + } + + func_parser = p.atomic(func_parser); + tool_choice |= p.rule("tool-" + name, func_parser); + }); + + auto require_tools = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED; + + common_peg_parser tool_calls = p.eps(); + + if (!format.per_call_start.empty()) { + auto wrapped_call = format.per_call_start + p.space() + tool_choice + p.space() + format.per_call_end; + if (inputs.parallel_tool_calls) { + tool_calls = p.trigger_rule("tool-call", wrapped_call + p.zero_or_more(p.space() + wrapped_call)); + } else { + tool_calls = p.trigger_rule("tool-call", wrapped_call); + } + if (!format.section_start.empty()) { + tool_calls = p.trigger_rule("tool-calls", + p.literal(format.section_start) + p.space() + tool_calls + p.space() + + (format.section_end.empty() ? p.end() : p.literal(format.section_end))); + } + } else { + std::string separator = ", "; // Default + + if (inputs.parallel_tool_calls) { + tool_calls = p.trigger_rule("tool-call", format.section_start + p.space() + tool_choice + + p.zero_or_more(separator + tool_choice) + p.space() + + format.section_end); + } else { + tool_calls = p.trigger_rule( + "tool-call", format.section_start + p.space() + tool_choice + p.space() + format.section_end); + } + } + + if (!require_tools) { + tool_calls = p.optional(tool_calls); + } + + std::string trigger_marker = !format.section_start.empty() ? format.section_start : format.per_call_start; + auto content_before_tools = trigger_marker.empty() ? p.eps() : p.until(trigger_marker); + return ctx.reasoning_parser + (force_tools ? p.eps() : p.optional(p.content(content_before_tools))) + tool_calls + + p.end(); +} + +} // namespace autoparser diff --git a/common/chat-auto-parser-helpers.cpp b/common/chat-auto-parser-helpers.cpp new file mode 100644 index 00000000000..1519d8bc6cf --- /dev/null +++ b/common/chat-auto-parser-helpers.cpp @@ -0,0 +1,347 @@ +#include "chat-auto-parser-helpers.h" + +#include "chat-auto-parser.h" +#include "chat.h" +#include "log.h" +#include "nlohmann/json.hpp" + +#include +#include + +using json = nlohmann::ordered_json; + +std::string trim_whitespace(const std::string & str) { + size_t start = 0; + while (start < str.length() && std::isspace(static_cast(str[start]))) { + start++; + } + + if (start == str.length()) { + return ""; + } + + size_t end = str.length() - 1; + while (end > start && std::isspace(static_cast(str[end]))) { + end--; + } + + return str.substr(start, end - start + 1); +} + +std::string trim_leading_whitespace(const std::string & str) { + size_t start = 0; + while (start < str.length() && std::isspace(static_cast(str[start]))) { + start++; + } + + return str.substr(start); +} + +std::string trim_trailing_whitespace(const std::string & str) { + if (str.empty()) { + return ""; + } + + size_t end = str.length() - 1; + while (end > 0 && std::isspace(static_cast(str[end]))) { + end--; + } + + // If first char is also whitespace, return empty string + if (end == 0 && std::isspace(static_cast(str[0]))) { + return ""; + } + + return str.substr(0, end + 1); +} + +std::string trim_trailing_newlines(const std::string & str) { + size_t end = str.length(); + while (end > 0 && str[end - 1] == '\n') { + end--; + } + + return str.substr(0, end); +} + +static size_t common_prefix_len(const std::string & left, const std::string & right) { + size_t prefix_len = 0; + size_t min_len = std::min(left.length(), right.length()); + while (prefix_len < min_len && left[prefix_len] == right[prefix_len]) { + prefix_len++; + } + return prefix_len; +} + +static size_t common_suffix_len(const std::string & left, const std::string & right) { + size_t suffix_len = 0; + size_t min_len = std::min(left.length(), right.length()); + while (suffix_len < min_len && left[left.length() - 1 - suffix_len] == right[right.length() - 1 - suffix_len]) { + suffix_len++; + } + return suffix_len; +} + +diff_split calculate_diff_split(const std::string & left, const std::string & right) { + diff_split result; + + auto left_seg = segmentize_markers(left); + auto right_seg = segmentize_markers(right); + + if (left_seg.empty()) { + result.right = right; + return result; + } + if (right_seg.empty()) { + result.left = left; + return result; + } + + auto left_start = left_seg.begin(); + auto left_end = --left_seg.end(); + auto right_start = right_seg.begin(); + auto right_end = --right_seg.end(); + + auto test = [&] () { + return left_start != left_end && right_start != right_end; + }; + + bool left_fully_consumed = false; + bool right_fully_consumed = false; + + while (test()) { + bool advanced = false; + if (*left_start == *right_start) { + result.prefix.append(left_start->value); + left_start++; + right_start++; + advanced = true; + } + if (*left_end == *right_end) { + result.suffix = left_end->value + result.suffix; + if (left_start != left_end) { + left_end--; + } else { + left_fully_consumed = true; + } + if (right_start != right_end) { + right_end--; + } else { + right_fully_consumed = true; + } + advanced = true; + } + if (!advanced) { + break; + } + } + + if (left_start == left_end && right_start != right_end) { + if (*left_start == *right_end) { + result.suffix = right_end->value + result.suffix; + right_end--; + left_fully_consumed = true; + } else if (*left_start == *right_start) { + result.prefix.append(right_start->value); + right_start++; + left_fully_consumed = true; + } + } else if (right_start == right_end && left_start != left_end) { + if (*left_end == *right_start) { + result.suffix = left_end->value + result.suffix; + left_end--; + right_fully_consumed = true; + } else if (*left_start == *right_start) { + result.prefix.append(left_start->value); + left_start++; + right_fully_consumed = true; + } + } else if (left_start == left_end && right_start == right_end && *left_start == *right_start && left_start->type == segment_type::MARKER) { + result.prefix.append(right_start->value); + left_fully_consumed = true; + right_fully_consumed = true; + } + + auto eat_segment = [](std::string & str, segment & seg) -> std::string { return str.append(seg.value); }; + + bool can_have_text_suffix = left_end->type == segment_type::TEXT && right_end->type == segment_type::TEXT; + bool can_have_text_prefix = right_start->type == segment_type::TEXT && left_start->type == segment_type::TEXT; + + std::string remainder_left = std::accumulate(left_start, left_fully_consumed ? left_end : ++left_end, std::string(), eat_segment); + std::string remainder_right = std::accumulate(right_start, right_fully_consumed ? right_end : ++right_end, std::string(), eat_segment); + + size_t suffix_len = can_have_text_suffix ? common_suffix_len(remainder_left, remainder_right) : 0; + // avoid overlaps between prefix and suffix + size_t prefix_len = can_have_text_prefix ? common_prefix_len(remainder_left.substr(0, remainder_left.size() - suffix_len), + remainder_right.substr(0, remainder_right.size() - suffix_len)) : 0; + + result.prefix.append(remainder_left.substr(0, prefix_len)); + result.suffix = remainder_left.substr(remainder_left.length() - suffix_len, suffix_len) + result.suffix; + result.left = remainder_left.substr(prefix_len, remainder_left.length() - prefix_len - suffix_len); + result.right = remainder_right.substr(prefix_len, remainder_right.length() - prefix_len - suffix_len); + + if (result.left == "" && result.right == "") { + // degenerate case, no diff + result.prefix = left; + result.suffix = ""; + // pick prefix = all as representation + } + return result; +} + +// Returns the prefix of `full` up until the first occurrence of the common prefix of `left` and `right` +std::string until_common_prefix(const std::string & full, const std::string & left, const std::string & right) { + // Find the common prefix of left and right + size_t common_prefix_len = 0; + size_t min_len = std::min(left.length(), right.length()); + while (common_prefix_len < min_len && left[common_prefix_len] == right[common_prefix_len]) { + common_prefix_len++; + } + + // If there's no common prefix, return empty string + if (common_prefix_len == 0) { + return ""; + } + + // Find the common prefix in the full string + std::string common_prefix = left.substr(0, common_prefix_len); + size_t pos = full.find(common_prefix); + + // If not found, return empty string + if (pos == std::string::npos) { + return ""; + } + + // Return everything before the common prefix + return full.substr(0, pos); +} + +// Returns the suffix of `full` after the last occurrence of the common suffix of `left` and `right` +std::string after_common_suffix(const std::string & full, const std::string & left, const std::string & right) { + // Find the common suffix of left and right (compare from the end) + size_t common_suffix_len = 0; + size_t min_len = std::min(left.length(), right.length()); + while (common_suffix_len < min_len && + left[left.length() - 1 - common_suffix_len] == right[right.length() - 1 - common_suffix_len]) { + common_suffix_len++; + } + + // If there's no common suffix, return empty string + if (common_suffix_len == 0) { + return ""; + } + + // Extract the common suffix + std::string common_suffix = left.substr(left.length() - common_suffix_len); + + // Find the last occurrence of the common suffix in the full string + size_t pos = full.rfind(common_suffix); + + // If not found, return empty string + if (pos == std::string::npos) { + return ""; + } + + // Return everything after the common suffix + return full.substr(pos + common_suffix_len); +} + +// TODO: segmentize will treat a JSON array inside tags as a tag: [{ "fun": { ... } }] will be three markers +// not too worried about that because it hasn't turned out as a problem anywhere, but noting here in case it will +// Might have to put some restrictions on tag contents as well (like "no { }") +std::vector segmentize_markers(const std::string & text) { + std::vector retval; + bool in_marker = false; + char marker_opener = '\0'; + + auto is_marker_opener = [](char c) -> bool { return c == '<' || c == '['; }; + auto is_marker_closer = [](char op, char c) -> bool { return (op == '<' && c == '>') || (op == '[' && c == ']'); }; + + size_t last_border = 0; + + for (size_t cur_pos = 0; cur_pos < text.length(); cur_pos++) { + if (!in_marker && is_marker_opener(text[cur_pos])) { + if (last_border < cur_pos) { + retval.push_back(segment(segment_type::TEXT, text.substr(last_border, cur_pos - last_border))); + } + last_border = cur_pos; + in_marker = true; + marker_opener = text[cur_pos]; + } else if (in_marker && is_marker_closer(marker_opener, text[cur_pos])) { + // no need to check because last_border will always be smaller + retval.push_back(segment(segment_type::MARKER, text.substr(last_border, cur_pos - last_border + 1))); + last_border = cur_pos + 1; + in_marker = false; + marker_opener = '\0'; + } + } + if (last_border < text.length()) { + retval.push_back(segment(segment_type::TEXT, text.substr(last_border))); + } + return retval; +} + +std::vector prune_whitespace_segments(const std::vector & segments) { + std::vector result; + for (const auto & seg : segments) { + if (!trim_whitespace(seg.value).empty()) { + result.push_back(seg); + } + } + return result; +} + +namespace autoparser { + +std::string apply_template(const common_chat_template & tmpl, const template_params & params) { + templates_params tmpl_params; + tmpl_params.messages = params.messages; + tmpl_params.tools = params.tools; + tmpl_params.add_generation_prompt = params.add_generation_prompt; + tmpl_params.enable_thinking = params.enable_thinking; + + if (params.extra_context) { + tmpl_params.extra_context = *params.extra_context; + } + tmpl_params.extra_context["enable_thinking"] = params.enable_thinking; + + try { + return common_chat_template_direct_apply(tmpl, tmpl_params); + } catch (const std::exception & e) { + LOG_DBG("Template application failed: %s\n", e.what()); + return ""; + } +} + +std::optional compare_variants( + const common_chat_template & tmpl, + const template_params & params_A, + const std::function & params_modifier) { + // Create variant B by copying A + template_params params_B = params_A; + + // Apply modifier to create variant B + if (params_modifier) { + params_modifier(params_B); + } + + // Apply template to both variants + std::string output_A = apply_template(tmpl, params_A); + std::string output_B = apply_template(tmpl, params_B); + + // Check for template application failures + if (output_A.empty() || output_B.empty()) { + return std::nullopt; + } + + // Calculate diff and return result with both outputs + compare_variants_result result; + result.diff = calculate_diff_split(output_A, output_B); + result.output_A = output_A; + result.output_B = output_B; + + return result; +} + +} // namespace autoparser + diff --git a/common/chat-auto-parser-helpers.h b/common/chat-auto-parser-helpers.h new file mode 100644 index 00000000000..6e3df79db8a --- /dev/null +++ b/common/chat-auto-parser-helpers.h @@ -0,0 +1,73 @@ +#pragma once + +#include "chat-auto-parser.h" +#include +#include +#include + +std::string trim_whitespace(const std::string & str); +std::string trim_leading_whitespace(const std::string & str); +std::string trim_trailing_whitespace(const std::string & str); +std::string trim_trailing_newlines(const std::string & str); + +// calculate a diff split (longest common prefix, longest common suffix excluding prefix, +// mismatched part on the left, mismatched part on the right) between two strings +// account for markers - align prefix and suffix endings so that they end on markers +// * eg.: +// calculate_diff_split("
", "

Something

") -> +// { "prefix": "" (not: "<"), "suffix": "", "left": "
", "right": "

Something

" } +// calculate_diff_split("Something", "") -> +// { "prefix": "", "suffix": "", "left": "Something", "right": "" } +diff_split calculate_diff_split(const std::string & left, const std::string & right); + +// Returns the prefix of `full` up until the first occurrence of the common prefix of `left` and `right` +// Returns empty string if there's no common prefix +// * eg.: +// until_common_prefix("really want a FUNCTION call", "FUNCTION alpha", "FUNCTION beta") -> "really want a " +// until_common_prefix("", "", "") -> "" +// until_common_prefix("some text", "1234", "abcd") -> "" +// until_common_prefix("one arg two args three args four", "argument alpha", "argument beta") -> "one "" +std::string until_common_prefix(const std::string & full, const std::string & left, const std::string & right); + +// Returns the suffix of `full` after the last occurrence of the common suffix of `left` and `right` +// Returns empty string if there's no common suffix +// Mirror function of `until_common_prefix` +// * eg.: +// after_common_suffix("really want a FUNCTION call", "first FUNCTION", "second FUNCTION") -> " call" +// after_common_suffix("one arg two-args three args four", "alpha-args", "beta-args") -> " three args four" +std::string after_common_suffix(const std::string & full, const std::string & left, const std::string & right); + +// Segmentize text into markers and non-marker fragments +// * eg.: +// segmentize_markers("The site title
Here's some content
" -> +// [ (MARKER, ""), (MARKER, ""), (MARKER, ""), (TEXT, "The site title"), (MARKER, ""), +// (MARKER, ""), (MARKER, "
"), (TEXT, "Here's some "), (MARKER, ""), (TEXT, "content"), (MARKER, ""), +// (MARKER, "
"), (MARKER, ""), (MARKER, "") +// ] +// segmentize_markers("<|tool_call|>[args]{ are here }[/args]<|tool_call_end|>") -> +// [ (MARKER, "<|tool_call|>"), (MARKER, "[args]"), (TEXT, "{ are here }"), (MARKER, "[/args]"), (MARKER, "<|tool_call_end|>") ] +std::vector segmentize_markers(const std::string & text); + +// Prune whitespace-only segments from a vector of segments +// * eg.: +// segmentize_markers("\n\n\n \n\n\n") -> +// X = [ (MARKER, ""), (TEXT, "\n"), (MARKER, ""), (TEXT, "\n"), (MARKER, ""), (TEXT, "\n \n"), +// (MARKER, ""), (TEXT, "\n"), (MARKER, ""), (TEXT, "\n"), (MARKER, "") ] +// prune_whitespace_segments(X) -> [ (MARKER, ""), (MARKER, ""), (MARKER, ""), (MARKER, ""), +// (MARKER, ""), (MARKER, "") ] +std::vector prune_whitespace_segments(const std::vector & segments); + +namespace autoparser { + +// Apply a template with the given parameters, returning the rendered string (empty on failure) +std::string apply_template(const common_chat_template & tmpl, const template_params & params); + +// Factorized differential comparison function +// Takes base params and a single modifier lambda to create variant B +// Returns compare_variants_result containing diff and both outputs, or std::nullopt on failure +std::optional compare_variants( + const common_chat_template & tmpl, + const template_params & params_A, + const std::function & params_modifier); + +} // namespace autoparser diff --git a/common/chat-auto-parser.h b/common/chat-auto-parser.h new file mode 100644 index 00000000000..52c6488f4b4 --- /dev/null +++ b/common/chat-auto-parser.h @@ -0,0 +1,433 @@ +#pragma once + +#include "chat.h" +#include "common.h" +#include "jinja/caps.h" +#include "peg-parser.h" + +#include +#include +#include +#include +#include + +using json = nlohmann::ordered_json; + +class common_chat_peg_builder; + +// ============================================================================ +// Parameters for template application (low-level, used by diff analysis) +// ============================================================================ +struct template_params { + json messages; + json tools; + bool add_generation_prompt = false; + bool enable_thinking = true; + std::optional extra_context = std::nullopt; +}; + +struct diff_split { + std::string prefix; + std::string suffix; + std::string left; + std::string right; + + bool operator==(struct diff_split & other) const { + return prefix == other.prefix && suffix == other.suffix && left == other.left && right == other.right; + } +}; + +// Result of compare_variants containing diff and original outputs +struct compare_variants_result { + diff_split diff; + std::string output_A; + std::string output_B; +}; + +namespace autoparser { + +// ============================================================================ +// High-level params for parser generation +// ============================================================================ + +struct templates_params { + json messages; + json tools; + common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; + json json_schema; + bool parallel_tool_calls = true; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO; + bool stream = true; + std::string grammar; + bool add_generation_prompt = false; + bool enable_thinking = true; + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + json extra_context; + bool add_bos = false; + bool add_eos = false; + bool is_inference = true; + bool add_inference = false; + bool mark_input = true; // whether to mark input strings in the jinja context +}; + +// ============================================================================ +// Analysis Result Enums +// ============================================================================ + +// Reasoning handling mode (derived from R1-R3 comparisons) +enum class reasoning_mode { + NONE, // No reasoning markers detected + TAG_BASED, // Standard tag-based: ... + DELIMITER, // Delimiter-based: [BEGIN FINAL RESPONSE] (reasoning ends at delimiter) + FORCED_OPEN, // Template ends with open reasoning tag (empty start, non-empty end) + FORCED_CLOSED, // Template ends with open reasoning tag on enabled thinking but + // with both opened and closed tag for disabled thinking + TOOLS_ONLY // Only reason on tool calls, not on normal content +}; + +inline std::ostream & operator<<(std::ostream & os, const reasoning_mode & mode) { + switch (mode) { + case reasoning_mode::NONE: + return os << "NONE"; + case reasoning_mode::TAG_BASED: + return os << "TAG_BASED"; + case reasoning_mode::DELIMITER: + return os << "DELIMITER"; + case reasoning_mode::FORCED_OPEN: + return os << "FORCED_OPEN"; + case reasoning_mode::FORCED_CLOSED: + return os << "FORCED_CLOSED"; + case reasoning_mode::TOOLS_ONLY: + return os << "TOOLS_ONLY"; + default: + return os << "UNKNOWN"; + } +} + +// Content wrapping mode (derived from C1 comparison) +enum class content_mode { + PLAIN, // No content markers + ALWAYS_WRAPPED, // Content always wrapped with markers + WRAPPED_WITH_REASONING, // Content wrapped only when reasoning present +}; + +inline std::ostream & operator<<(std::ostream & os, const content_mode & mode) { + switch (mode) { + case content_mode::PLAIN: + return os << "PLAIN"; + case content_mode::ALWAYS_WRAPPED: + return os << "ALWAYS_WRAPPED"; + case content_mode::WRAPPED_WITH_REASONING: + return os << "WRAPPED_WITH_REASONING"; + default: + return os << "UNKNOWN"; + } +} + +// Call ID position in tool calls (for non-JSON formats) +enum class call_id_position { + NONE, // No call ID support detected + PRE_FUNC_NAME, // Call ID before function name: [CALL_ID]id[FUNC]name{args} + BETWEEN_FUNC_AND_ARGS, // Call ID between function and args: [FUNC]name[CALL_ID]id{args} + POST_ARGS, // Call ID after arguments: [FUNC]name{args}[CALL_ID]id +}; + +inline std::ostream & operator<<(std::ostream & os, const call_id_position & pos) { + switch (pos) { + case call_id_position::NONE: + return os << "NONE"; + case call_id_position::PRE_FUNC_NAME: + return os << "PRE_FUNC_NAME"; + case call_id_position::BETWEEN_FUNC_AND_ARGS: + return os << "BETWEEN_FUNC_AND_ARGS"; + case call_id_position::POST_ARGS: + return os << "POST_ARGS"; + default: + return os << "UNKNOWN"; + } +} + +// Tool call format classification (derived from T1-T5, A1-A3 comparisons) +enum class tool_format { + NONE, // No tool support detected + JSON_NATIVE, // Pure JSON: {"name": "X", "arguments": {...}} + TAG_WITH_JSON, // Tag-based with JSON args: {...} + TAG_WITH_TAGGED, // Tag-based with tagged args: value +}; + +inline std::ostream & operator<<(std::ostream & os, const tool_format & format) { + switch (format) { + case tool_format::NONE: + return os << "NONE"; + case tool_format::JSON_NATIVE: + return os << "JSON_NATIVE"; + case tool_format::TAG_WITH_JSON: + return os << "TAG_WITH_JSON"; + case tool_format::TAG_WITH_TAGGED: + return os << "TAG_WITH_TAGGED"; + default: + return os << "UNKNOWN"; + } +} + +// ============================================================================ +// Sub-structs for tool analysis +// ============================================================================ + +struct tool_format_analysis { + tool_format mode = tool_format::NONE; + + std::string section_start; // e.g., "", "[TOOL_CALLS]", "" + std::string section_end; // e.g., "", "" + std::string per_call_start; // e.g., "<|tool_call_begin|>", "" (for multi-call templates) + std::string per_call_end; // e.g., "<|tool_call_end|>", "" + + bool fun_name_is_key = false; // In JSON format function name is JSON key, i.e. { "": { ... arguments ... } } + bool tools_array_wrapped = false; // Tool calls wrapped in JSON array [...] + bool uses_python_dicts = false; // Tool call args use Python dict format (single-quoted strings) + + std::string function_field = "function"; + std::string name_field = "name"; + std::string args_field = "arguments"; + std::string id_field; + std::string gen_id_field; + std::vector parameter_order; +}; + +struct tool_function_analysis { + std::string name_prefix; // e.g., "", "\"", ":0" + std::string close; // e.g., "", "" (for tag-based) +}; + +struct tool_arguments_analysis { + std::string start; // e.g., "<|tool_call_argument_begin|>", "" + std::string end; // e.g., "<|tool_call_argument_end|>", "" + std::string name_prefix; // e.g., "", "\"" + std::string name_suffix; // e.g., ">", "
", "\":" + std::string value_prefix; // e.g., "", "", "" + std::string value_suffix; // e.g., "", "", "" + std::string separator; // e.g., "", "\n", "," +}; + +struct tool_id_analysis { + call_id_position pos = call_id_position::NONE; + + std::string prefix; // e.g., "[CALL_ID]" (marker before call ID value) + std::string suffix; // e.g., "" (marker after call ID value, before next section) +}; + +// ============================================================================ +// Parser build context (shared interface for build_parser methods) +// ============================================================================ + +struct analyze_content; + +struct parser_build_context { + common_chat_peg_builder & p; + const templates_params & inputs; + common_peg_parser reasoning_parser; + bool extracting_reasoning = false; + const analyze_content * content = nullptr; + + parser_build_context(common_chat_peg_builder & p, const templates_params & inputs); +}; + +// ============================================================================ +// Base class for analyzers with parser building +// ============================================================================ + +struct analyze_base { + virtual ~analyze_base() = default; + virtual common_peg_parser build_parser(parser_build_context & ctx) const = 0; + + protected: + const common_chat_template * tmpl = nullptr; + + analyze_base() = default; + explicit analyze_base(const common_chat_template & tmpl) : tmpl(&tmpl) {} +}; + +// ============================================================================ +// Reasoning analyzer +// ============================================================================ + +struct analyze_reasoning : analyze_base { + reasoning_mode mode = reasoning_mode::NONE; + + std::string start; // e.g., "", "[THINK]", "<|START_THINKING|>", "" + std::string end; // e.g., "", "[BEGIN FINAL RESPONSE]", "<|END_THINKING|>" + + analyze_reasoning() = default; + analyze_reasoning(const common_chat_template & tmpl, bool supports_tools); + + common_peg_parser build_parser(parser_build_context & ctx) const override; + + private: + // Look for reasoning markers in rendered content + void compare_reasoning_presence(); + + // Compare generation prompt with enable_thinking=true vs false + void compare_thinking_enabled(); + + // Check if reasoning is always possible or only in tool calls + void compare_reasoning_scope(); +}; + +// ============================================================================ +// Content analyzer +// ============================================================================ + +struct analyze_content : analyze_base { + content_mode mode = content_mode::PLAIN; + + std::string start; // e.g., "", ">>>all\n", "" + std::string end; // e.g., "", "" + + bool requires_nonnull_content = false; + + analyze_content() = default; + analyze_content(const common_chat_template & tmpl, const analyze_reasoning & reasoning); + + common_peg_parser build_parser(parser_build_context & ctx) const override; + + bool is_always_wrapped() const; + common_peg_parser build_optional_wrapped(parser_build_context & ctx) const; +}; + +// ============================================================================ +// Tool analyzer +// ============================================================================ + +struct analyze_tools : analyze_base { + tool_format_analysis format; + tool_function_analysis function; + tool_arguments_analysis arguments; + tool_id_analysis call_id; + + analyze_tools() = default; + analyze_tools(const common_chat_template & tmpl, + const jinja::caps & caps, + const analyze_reasoning & reasoning); + + common_peg_parser build_parser(parser_build_context & ctx) const override; + + private: + // Extract tool calling 'haystack' for further analysis and delegate further analysis based on format + void analyze_tool_calls(const analyze_reasoning & reasoning); + + // Analyze format based on position of function and argument name in needle + void analyze_tool_call_format(const std::string & haystack, + const std::string & fun_name_needle, + const std::string & arg_name_needle, + const analyze_reasoning & reasoning); + + // Analyze specifics of JSON native format (entire tool call is a JSON object) + void analyze_tool_call_format_json_native(const std::string & clean_haystack, + const std::string & fun_name_needle, + const std::string & arg_name_needle); + + // Analyze specifics of non-JSON native format (tags for function name or for function name and arguments) + void analyze_tool_call_format_non_json(const std::string & clean_haystack, + const std::string & fun_name_needle); + + // Check for and extract specific per-call markers for non-native-JSON templates with parallel call support + void check_per_call_markers(); + + // Extract function name markers + void extract_function_markers(); + + // Delegates to separate functions for: separator analysis, argument name analysis, argument value analysis + void analyze_arguments(); + + // Extract argument name markers + void extract_argument_name_markers(); + + // Extract argument value markers + void extract_argument_value_markers(); + + // Extract argument separator, if specified (eg. ......) + void extract_argument_separator(); + + // Extract argument wrapper markers, if present (eg. '......') + void extract_args_markers(); + + // Extract call ID markers, if present + void extract_call_id_markers(); + + // Per-format tool parser builders + common_peg_parser build_tool_parser_json_native(parser_build_context & ctx) const; + common_peg_parser build_tool_parser_tag_json(parser_build_context & ctx) const; + common_peg_parser build_tool_parser_tag_tagged(parser_build_context & ctx) const; +}; + +// ============================================================================ +// Main autoparser class +// ============================================================================ + +struct autoparser { + jinja::caps jinja_caps; + analyze_reasoning reasoning; + analyze_content content; + analyze_tools tools; + bool analysis_complete = false; + + // Preserved tokens for tokenizer (union of all non-empty markers) + std::vector preserved_tokens; + + autoparser() = default; + + // Run full differential analysis on a template + void analyze_template(const common_chat_template & tmpl); + + // Build the PEG parser for this template + common_peg_arena build_parser(const templates_params & inputs) const; + + private: + // Collect tokens from entire analysis to preserve + void collect_preserved_tokens(); +}; + +// ============================================================================ +// Parser generator +// ============================================================================ + +class peg_generator { + public: + static common_chat_params generate_parser(const common_chat_template & tmpl, + const struct templates_params & inputs); + + static common_chat_params generate_parser(const common_chat_template & tmpl, + const struct templates_params & inputs, + const autoparser & autoparser); +}; + +} // namespace autoparser + +enum segment_type { TEXT, MARKER }; + +inline std::ostream & operator<<(std::ostream & os, const segment_type & type) { + switch (type) { + case segment_type::TEXT: + return os << "TEXT"; + case segment_type::MARKER: + return os << "MARKER"; + default: + return os << "UNKNOWN"; + } +} + +struct segment { + segment_type type; + std::string value; + + segment(segment_type type, std::string value) : type(type), value(std::move(value)) {} + + bool operator==(const segment & other) const { + return type == other.type && value == other.value; + } + + bool operator!=(const segment & other) const { + return !(*this == other); + } +}; diff --git a/common/chat-diff-analyzer.cpp b/common/chat-diff-analyzer.cpp new file mode 100644 index 00000000000..4068340a5c0 --- /dev/null +++ b/common/chat-diff-analyzer.cpp @@ -0,0 +1,1330 @@ +#include "chat-auto-parser.h" +#include "chat-auto-parser-helpers.h" +#include "chat-peg-parser.h" +#include "chat.h" +#include "log.h" +#include "nlohmann/json.hpp" +#include "peg-parser.h" + +#include + +#define ANSI_RESET "\033[0m" +#define ANSI_PURPLE "\033[1m\x1b[38;5;126m" +#define ANSI_ORANGE "\033[1m\x1b[38;5;214m" +#define ANSI_RED "\033[1m\x1b[38;5;196m" + +using json = nlohmann::ordered_json; + +namespace autoparser { + +static const std::string FUN_FIRST = "FFF_FIRST_FUN_F"; +static const std::string FUN_SECOND = "SSS_SECOND_FUN_S"; +static const std::string ARG_FIRST = "AA_ARG_FST_AA"; +static const std::string ARG_SECOND = "BB_ARG_SND_BB"; +static const std::string USER_MSG = "U_USER_MSG Hello END_U"; +static const std::string ASSISTANT_MSG = "A_ASST_MSG I can help END_A"; +static const std::string THINKING_CONTENT = "REASON_PART I am thinking END_R"; + +static std::vector> workarounds( + { // Old reasoning Qwen templates - they don't really display reasoning content, but we still want to + // support reasoning on them + [](const common_chat_template & tmpl, autoparser & analysis) -> void { + if (tmpl.src.find("content.split('')") != std::string::npos && + tmpl.src.find("reasoning_content") == std::string::npos && + analysis.reasoning.mode == reasoning_mode::NONE) { + analysis.reasoning.mode = reasoning_mode::FORCED_OPEN; + analysis.reasoning.start = ""; + analysis.reasoning.end = ""; + analysis.preserved_tokens.push_back(""); + analysis.preserved_tokens.push_back(""); + LOG_DBG(ANSI_ORANGE "[Patch: old Qwen/Deepseek thinking template]\n" ANSI_RESET); + } + }, + // Granite 3.3, with separate reasoning and content markers + [](const common_chat_template & tmpl, autoparser & analysis) -> void { + if (tmpl.src.find("Write your thoughts between and write your response between " + "") != std::string::npos) { + analysis.reasoning.mode = reasoning_mode::TAG_BASED; + analysis.reasoning.start = ""; + analysis.reasoning.end = ""; + analysis.preserved_tokens.push_back(""); + analysis.preserved_tokens.push_back(""); + analysis.content.mode = content_mode::WRAPPED_WITH_REASONING; + analysis.content.start = ""; + analysis.content.end = ""; + analysis.preserved_tokens.push_back(""); + analysis.preserved_tokens.push_back(""); + LOG_DBG(ANSI_ORANGE "[Patch: Granite 3.3]\n" ANSI_RESET); + } + }, + // Cohere Command R+ - content wrapped in <|CHATBOT_TOKEN|>...<|END_OF_TURN_TOKEN|> + [](const common_chat_template & tmpl, autoparser & analysis) -> void { + if (tmpl.src.find("<|CHATBOT_TOKEN|>") != std::string::npos && + tmpl.src.find("<|END_OF_TURN_TOKEN|>") != std::string::npos && analysis.content.start.empty()) { + analysis.content.mode = content_mode::ALWAYS_WRAPPED; + analysis.content.start = "<|CHATBOT_TOKEN|>"; + analysis.content.end = "<|END_OF_TURN_TOKEN|>"; + analysis.preserved_tokens.push_back("<|CHATBOT_TOKEN|>"); + analysis.preserved_tokens.push_back("<|END_OF_TURN_TOKEN|>"); + LOG_DBG(ANSI_ORANGE "[Patch: Cohere Command R+]\n" ANSI_RESET); + } + }, + // Functionary - no tool call section delimiter + [](const common_chat_template & tmpl, autoparser & analysis) -> void { + if (tmpl.src.find("set has_code_interpreter = tools | selectattr(\"type\", \"equalto\", " + "\"code_interpreter\") | list | length > 0") != std::string::npos) { + analysis.content.mode = content_mode::PLAIN; + analysis.content.end = ""; + analysis.tools.function.name_prefix = ""; + analysis.tools.format.section_start = ""; + analysis.tools.format.section_end = ""; + analysis.tools.format.per_call_start = ""); + analysis.preserved_tokens.push_back("<|eom_id|>"); + analysis.preserved_tokens.push_back(""); + analysis.preserved_tokens.push_back(""); + LOG_DBG(ANSI_ORANGE "[Patch: Functionary 3.1]\n" ANSI_RESET); + } + }, + // DeepSeek-R1-Distill-Qwen + [](const common_chat_template & tmpl, autoparser & analysis) -> void { + if (tmpl.src.find( + "{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>'") != + std::string::npos) { + analysis.tools.format.section_start = "<|tool▁calls▁begin|>"; + analysis.tools.format.section_end = "<|tool▁calls▁end|>"; + analysis.tools.format.per_call_start = "<|tool▁call▁begin|>function"; + analysis.tools.function.name_prefix = "<|tool▁sep|>"; + analysis.tools.format.per_call_end = "<|tool▁call▁end|>"; + analysis.tools.function.close = "```"; + } + } + }); + +// Common JSON structures +static json params_schema = { + { "type", "object" }, + { "properties", + { { ARG_FIRST, { { "type", "string" }, { "description", "First argument" } } }, + { ARG_SECOND, { { "type", "string" }, { "description", "Second argument" } } } } }, + { "required", json::array({}) } +}; + +static json tools = json::array({ + { { "type", "function" }, + { "function", + json{ { "name", FUN_FIRST }, { "description", "Test function foo" }, { "parameters", params_schema } } } }, + { { "type", "function" }, + { "function", + json{ { "name", FUN_SECOND }, { "description", "Test function bar" }, { "parameters", params_schema } } } } +}); + +static json user_msg = json{ + { "role", "user" }, + { "content", USER_MSG } +}; + +static json build_tool_call(const std::string & name, const json & args, const std::string & id = "call00001") { + return json{ + { "id", id }, + { "type", "function" }, + { "function", json{ { "name", name }, { "arguments", args } } } + }; +} + +static json first_tool_call_zero_args = build_tool_call(FUN_FIRST, json::object(), "call00001"); +static json first_tool_call_one_arg = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "XXXX" }}, "call00001"); +static json first_tool_call_one_arg_other_val = build_tool_call(FUN_FIRST, {{ ARG_FIRST, "YYYY" }}, "call00001"); +static json first_tool_call_other_arg = build_tool_call(FUN_FIRST, {{ ARG_SECOND, "YYYY" }}, "call00001"); + +static json first_tool_call = + build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call00001"); +static json second_tool_call = + build_tool_call(FUN_SECOND, json{ { ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call00002"); +static json first_tool_call_alt_id = + build_tool_call(FUN_FIRST, json{{ ARG_FIRST, "XXXX" }, { ARG_SECOND, "YYYY" }}, "call99999"); + +template +static std::string mode_to_str(T mode) { + std::ostringstream os; + os << mode; + return os.str(); +} + +void autoparser::analyze_template(const common_chat_template & tmpl) { + jinja_caps = tmpl.original_caps(); + reasoning = analyze_reasoning(tmpl, jinja_caps.supports_tool_calls); + content = analyze_content(tmpl, reasoning); + tools = analyze_tools(jinja_caps.supports_tool_calls ? analyze_tools(tmpl, jinja_caps, reasoning) : analyze_tools()); + collect_preserved_tokens(); + + for (auto & workaround : workarounds) { + workaround(tmpl, *this); + } + + LOG_DBG("\n--- Reasoning & Content Structure ---\n"); + LOG_DBG("reasoning_mode: %s\n", mode_to_str(reasoning.mode).c_str()); + LOG_DBG("reasoning_start: '%s'\n", reasoning.start.c_str()); + LOG_DBG("reasoning_end: '%s'\n", reasoning.end.c_str()); + LOG_DBG("content_mode: %s\n", mode_to_str(content.mode).c_str()); + LOG_DBG("content_start: '%s'\n", content.start.c_str()); + LOG_DBG("content_end: '%s'\n", content.end.c_str()); + + LOG_DBG("\n--- Tool Call Structure ---\n"); + LOG_DBG("tool_mode: %s\n", mode_to_str(tools.format.mode).c_str()); + LOG_DBG("supports_tools: %s\n", jinja_caps.supports_tools ? "true" : "false"); + LOG_DBG("supports_parallel_calls: %s\n", jinja_caps.supports_parallel_tool_calls ? "true" : "false"); + LOG_DBG("tool_section_start: '%s'\n", tools.format.section_start.c_str()); + LOG_DBG("tool_section_end: '%s'\n", tools.format.section_end.c_str()); + LOG_DBG("per_call_start: '%s'\n", tools.format.per_call_start.c_str()); + LOG_DBG("per_call_end: '%s'\n", tools.format.per_call_end.c_str()); + LOG_DBG("func_name_prefix: '%s'\n", tools.function.name_prefix.c_str()); + LOG_DBG("func_name_suffix: '%s'\n", tools.function.name_suffix.c_str()); + LOG_DBG("func_close: '%s'\n", tools.function.close.c_str()); + LOG_DBG("python_dict_format: %s\n", tools.format.uses_python_dicts ? "true" : "false"); + LOG_DBG("arg_name_prefix: '%s'\n", tools.arguments.name_prefix.c_str()); + LOG_DBG("arg_name_suffix: '%s'\n", tools.arguments.name_suffix.c_str()); + LOG_DBG("arg_value_prefix: '%s'\n", tools.arguments.value_prefix.c_str()); + LOG_DBG("arg_value_suffix: '%s'\n", tools.arguments.value_suffix.c_str()); + LOG_DBG("name_field: '%s'\n", tools.format.name_field.c_str()); + LOG_DBG("args_field: '%s'\n", tools.format.args_field.c_str()); + LOG_DBG("id_field: '%s'\n", tools.format.id_field.c_str()); + LOG_DBG("gen_id_field: '%s'\n", tools.format.gen_id_field.c_str()); + LOG_DBG("parameter_order: '%s'\n", std::accumulate(tools.format.parameter_order.begin(), tools.format.parameter_order.end(), + std::string(""), [] (const std::string & a, const std::string & b) { return a.empty() ? b : a + ", " + b; } + ).c_str()); + + LOG_DBG(ANSI_PURPLE "=== Differential analysis complete ===\n" ANSI_RESET); + analysis_complete = true; +} + +void autoparser::collect_preserved_tokens() { + auto add_token = [this](const std::string & org_token) { + std::string token = trim_whitespace(org_token); + if (!token.empty()) { + // Avoid duplicates + if (std::find(preserved_tokens.begin(), preserved_tokens.end(), token) == preserved_tokens.end()) { + preserved_tokens.push_back(token); + } + } + }; + + add_token(reasoning.start); + add_token(reasoning.end); + add_token(content.start); + add_token(content.end); + add_token(tools.format.section_start); + add_token(tools.format.section_end); + add_token(tools.format.per_call_start); + add_token(tools.format.per_call_end); + add_token(tools.function.name_prefix); + add_token(tools.function.name_suffix); + add_token(tools.function.close); + add_token(tools.arguments.start); + add_token(tools.arguments.end); + add_token(tools.arguments.name_prefix); + add_token(tools.arguments.name_suffix); + add_token(tools.arguments.separator); + add_token(tools.arguments.value_prefix); + add_token(tools.arguments.value_suffix); + add_token(tools.call_id.prefix); + add_token(tools.call_id.suffix); +} + +analyze_reasoning::analyze_reasoning(const common_chat_template & tmpl, bool supports_tools) + : analyze_base(tmpl) { + LOG_DBG(ANSI_PURPLE "=== Starting differential analysis ===\n" ANSI_RESET); + LOG_DBG(ANSI_ORANGE "Phase 1: Reasoning analysis\n" ANSI_RESET); + + compare_reasoning_presence(); + compare_thinking_enabled(); + if (supports_tools) { + compare_reasoning_scope(); + } +} + +void analyze_reasoning::compare_reasoning_presence() { + json user_msg = json{ + { "role", "user" }, + { "content", USER_MSG } + }; + + json assistant_no_reasoning = json{ + { "role", "assistant" }, + { "content", ASSISTANT_MSG } + }; + + json assistant_with_reasoning = json{ + { "role", "assistant" }, + { "content", ASSISTANT_MSG }, + { "reasoning_content", THINKING_CONTENT } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_no_reasoning }); + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_with_reasoning }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed, skipping reasoning detection\n" ANSI_RESET, __func__); + return; + } + + const auto & diff = comparison->diff; + + const std::string reasoning_content = THINKING_CONTENT; + + if (!diff.right.empty() && diff.right.find(reasoning_content) != std::string::npos) { + auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space())) + p.rest()); + }); + auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("pre", p.marker()) + p.space() + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest(); + }); + // try the more aggressive parse first, if it fails, fall back to the delimiter one + auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B); + if (!result.result.success()) { + result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B); + } + if (result.result.success()) { + if (!result.tags["pre"].empty() && !result.tags["post"].empty()) { + if (parser_wrapped.parse_anywhere_and_extract(diff.right).result.success()) { // both tags in the diff = no forced close + mode = reasoning_mode::TAG_BASED; + } else { + mode = reasoning_mode::FORCED_CLOSED; + } + start = trim_whitespace(result.tags["pre"]); + end = result.tags["post"]; + } else if (!result.tags["post"].empty()) { + mode = reasoning_mode::DELIMITER; + end = result.tags["post"]; + } + } + } +} + +void analyze_reasoning::compare_thinking_enabled() { + json user_msg = json{ + { "role", "user" }, + { "content", USER_MSG } + }; + + template_params params; + params.messages = json::array({ user_msg }); + params.add_generation_prompt = true; + params.enable_thinking = false; + + auto comparison = compare_variants(*tmpl, params, [&](template_params & p) { p.enable_thinking = true; }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET , __func__); + return; + } + + const auto & diff = comparison->diff; + + std::string left_trimmed = trim_whitespace(diff.left); + + if (left_trimmed.empty() && !diff.right.empty()) { + std::string right_trimmed = trim_whitespace(diff.right); + + if (!right_trimmed.empty() && string_ends_with(comparison->output_B, right_trimmed)) { + if (start.empty()) { + start = right_trimmed; + mode = reasoning_mode::FORCED_OPEN; + } + } + } + + if (start.empty() && !end.empty()) { + mode = reasoning_mode::DELIMITER; + } + + // Check for FORCED_CLOSED: when enable_thinking=false produces both start and end markers, + // but enable_thinking=true produces only the start marker + if (!comparison->output_A.empty() && !comparison->output_B.empty()) { + auto parser_start = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.literal(start) + p.space() + p.literal(end) + p.rest(); + }); + auto parser_start_end = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("pre", p.literal(start)) + p.space() + p.negate(p.literal(end)) + p.rest(); + }); + if (!start.empty() && parser_start_end.parse_anywhere_and_extract(comparison->output_A).result.success() && + parser_start.parse_anywhere_and_extract(comparison->output_B).result.success()) { + mode = reasoning_mode::FORCED_CLOSED; + } else if (!end.empty()) { // we extract the starting marker now since we didn't get it earlier + auto result = parser_start_end.parse_anywhere_and_extract(comparison->output_A); + if (result.result.success()) { + start = result.tags["pre"]; + mode = reasoning_mode::FORCED_CLOSED; + } + } + } + + if (start.empty() && end.empty()) { // we might still have the case of "just open" and "just close" + if (!diff.left.empty() && !diff.right.empty()) { + auto seg_A = segmentize_markers(trim_trailing_whitespace(diff.left)); + auto seg_B = segmentize_markers(trim_trailing_whitespace(diff.right)); + if (seg_A.size() == 1 && seg_B.size() == 1) { + mode = reasoning_mode::FORCED_CLOSED; + start = seg_B[0].value; + end = seg_A[0].value; + } + } + } +} + +void analyze_reasoning::compare_reasoning_scope() { + json assistant_reasoning_content = json{ + { "role", "assistant" }, + { "content", ASSISTANT_MSG }, + { "reasoning_content", THINKING_CONTENT } + }; + + json assistant_reasoning_tools = json{ + { "role", "assistant" }, + { "content", nullptr }, + { "reasoning_content", THINKING_CONTENT }, + { "tool_calls", + json::array({ build_tool_call(FUN_FIRST, json{ { ARG_FIRST, "VVVV" }, { ARG_SECOND, "XXXX" } }) }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_reasoning_content }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_reasoning_tools }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__); + return; + } + + std::string reasoning_content = THINKING_CONTENT; + + // Check if reasoning only appears in variant B (with tools) + bool reasoning_in_A = comparison->output_A.find(reasoning_content) != std::string::npos; + bool reasoning_in_B = comparison->output_B.find(reasoning_content) != std::string::npos; + + if (!reasoning_in_A && reasoning_in_B) { + mode = reasoning_mode::TOOLS_ONLY; + LOG_DBG(ANSI_ORANGE "%s: Detected TOOLS_ONLY reasoning mode\n" ANSI_RESET, __func__); + + auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("pre", p.marker()) + p.space() + p.literal(reasoning_content) + p.space() + p.tag("post", (p.marker() + p.space())); + }); + auto result = parser_wrapped.parse_anywhere_and_extract(comparison->output_B); + if (result.result.success()) { + start = result.tags["pre"]; + end = result.tags["post"]; + } else { + auto parser_delimiter = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.literal(reasoning_content) + p.space() + p.optional(p.tag("post", (p.marker() + p.space()))); + }); + result = parser_delimiter.parse_anywhere_and_extract(comparison->output_B); + if (result.result.success()) { + end = result.tags["post"]; + } else { + LOG_DBG(ANSI_ORANGE "%s: Unable to extracft reasoning markers, falling back to reasoning = NONE\n" ANSI_RESET, __func__); + mode = reasoning_mode::NONE; + } + } + } +} + +analyze_content::analyze_content(const common_chat_template & tmpl, const analyze_reasoning & reasoning) + : analyze_base(tmpl) { + LOG_DBG(ANSI_ORANGE "Phase 2: Content analysis\n" ANSI_RESET); + + json assistant_content_only = json{ + { "role", "assistant" }, + { "content", ASSISTANT_MSG } + }; + + json assistant_with_tools = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ build_tool_call("test_func", json{ { "arg1", "value1" } }) }) } + }; + + json assistant_with_reasoning = json{ + { "role", "assistant" }, + { "content", "" }, + { "reasoning_content", THINKING_CONTENT } + }; + + template_params params_content_only; + params_content_only.messages = json::array({ user_msg, assistant_content_only }); + params_content_only.add_generation_prompt = false; + params_content_only.enable_thinking = true; + params_content_only.tools = tools; + + auto comparison_with_tools = compare_variants(tmpl, params_content_only, [&](template_params & p) { + p.messages = json::array({ user_msg, assistant_with_tools }); + }); + + auto comparison_with_reasoning = compare_variants(tmpl, params_content_only, [&](template_params & p) { + p.messages = json::array({ user_msg, assistant_with_reasoning }); + }); + + if (!comparison_with_tools || !comparison_with_reasoning) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__); + } + + const auto & diff_tools = comparison_with_tools->diff; + const auto & diff_reasoning = comparison_with_reasoning->diff; + + std::string response = ASSISTANT_MSG; + + bool found_plain_content = false; + if (trim_whitespace(diff_tools.left) == response) { + auto parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + return p.space() + diff_reasoning.left + p.space() + p.optional(p.marker()) + p.space() + p.end(); + }); + if (parser.parse_and_extract(diff_reasoning.left).result.success()) { + // We only have the content text in the diff (possibly with a stray EOG marker), so no markers + mode = content_mode::PLAIN; + found_plain_content = true; + } else if (reasoning.mode != reasoning_mode::NONE && !reasoning.end.empty()) { + auto post_reasoning_parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + return p.literal(reasoning.end) + p.space() + p.literal(response); + }); + if (post_reasoning_parser.parse_anywhere_and_extract(diff_reasoning.left).result.success()) { + mode = content_mode::PLAIN; + found_plain_content = true; + } + } + } + if (!found_plain_content) { + std::string rdiff = diff_reasoning.left; + if (!reasoning.end.empty() && rdiff.find(reasoning.end) != std::string::npos) { + rdiff = rdiff.substr(rdiff.find(reasoning.end) + reasoning.end.length()); + } + // Take the more promising diff + std::string pure_content = rdiff.length() > diff_tools.left.length() ? rdiff : diff_tools.left; + auto parser_wrapped = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("pre", p.marker()) + p.space() + p.literal(response) + p.space() + p.tag("post", (p.marker() + p.space())) + p.rest(); + }); + auto result = parser_wrapped.parse_anywhere_and_extract(pure_content); + start = result.tags["pre"]; + end = result.tags["post"]; + // TODO: WRAPPED_WITH_REASONING + } + + // Determine content mode + if (!start.empty() || !end.empty()) { + mode = content_mode::ALWAYS_WRAPPED; + // TODO: END_DELIMITED content mode - delimited at end but not at start? + } +} + +bool analyze_content::is_always_wrapped() const { + return mode == content_mode::ALWAYS_WRAPPED && !start.empty() && !end.empty(); +} + +analyze_tools::analyze_tools(const common_chat_template & tmpl, + const jinja::caps & caps, + const analyze_reasoning & reasoning) + : analyze_base(tmpl) { + LOG_DBG(ANSI_ORANGE "Phase 3: Tool call analysis\n" ANSI_RESET); + + analyze_tool_calls(reasoning); + + if (format.mode != tool_format::NONE && format.mode != tool_format::JSON_NATIVE) { + if (caps.supports_parallel_tool_calls) { + check_per_call_markers(); + } + extract_function_markers(); + if (format.mode == tool_format::TAG_WITH_TAGGED) { + analyze_arguments(); + } + extract_argument_separator(); + extract_args_markers(); + extract_call_id_markers(); + } +} + +void analyze_tools::analyze_tool_calls(const analyze_reasoning & reasoning) { + json assistant_no_tools = json{ + { "role", "assistant" }, + { "content", ASSISTANT_MSG } + }; + + json assistant_with_tools = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_no_tools }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_with_tools }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__); + return; + } + + const auto & diff = comparison->diff; + + std::string tool_section = diff.right; + + if (tool_section.empty()) { + return; + } + + analyze_tool_call_format(tool_section, FUN_FIRST, ARG_FIRST, reasoning); +} + +void analyze_tools::analyze_tool_call_format(const std::string & haystack, + const std::string & fun_name_needle, + const std::string & arg_name_needle, + const analyze_reasoning & reasoning) { + if (fun_name_needle.empty() || arg_name_needle.empty() || haystack.empty()) { + return; + } + + enum class json_quote_style { NONE, DOUBLE_QUOTES, SINGLE_QUOTES }; + + auto in_json_haystack = [&haystack](const std::string & needle) -> json_quote_style { + auto parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.choice({ p.literal("{"), p.literal(":") }) << p.choice({ + p.tag("sq", p.literal("'") + p.literal(needle) + p.literal("'")), + p.tag("dq", p.literal("\"") + p.literal(needle) + p.literal("\"")) }); + }); + auto result = parser.parse_anywhere_and_extract(haystack); + if (!result.result.success()) { + return json_quote_style::NONE; + } + return result.tags.count("sq") && !result.tags["sq"].empty() + ? json_quote_style::SINGLE_QUOTES + : json_quote_style::DOUBLE_QUOTES; + }; + + auto fun_quote = in_json_haystack(fun_name_needle); + auto arg_quote = in_json_haystack(arg_name_needle); + + if (fun_quote != json_quote_style::NONE) { + // no need to check further, we're in JSON land + format.mode = tool_format::JSON_NATIVE; + format.uses_python_dicts = (fun_quote == json_quote_style::SINGLE_QUOTES); + } else if (arg_quote != json_quote_style::NONE) { + format.mode = tool_format::TAG_WITH_JSON; + format.uses_python_dicts = (arg_quote == json_quote_style::SINGLE_QUOTES); + } else { + format.mode = tool_format::TAG_WITH_TAGGED; + } + + // first, remove any reasoning markers + std::string clean_haystack = haystack; + if (!reasoning.start.empty()) { + auto pos = haystack.find(reasoning.start); + if (pos != std::string::npos) { + clean_haystack = haystack.substr(0, pos) + haystack.substr(pos + reasoning.start.length()); + } + } + if (!reasoning.end.empty()) { + auto pos = clean_haystack.find(reasoning.end); + if (pos != std::string::npos) { + clean_haystack = clean_haystack.substr(0, pos) + clean_haystack.substr(pos + reasoning.end.length()); + } + } + + if (format.mode == tool_format::JSON_NATIVE) { + analyze_tool_call_format_json_native(clean_haystack, fun_name_needle, arg_name_needle); + } else { + analyze_tool_call_format_non_json(clean_haystack, fun_name_needle); + } + // always relax whitespace requirements on ending markers since they don't influence content + format.section_end = trim_whitespace(format.section_end); + format.per_call_end = trim_whitespace(format.per_call_end); +} + +void analyze_tools::analyze_tool_call_format_json_native(const std::string & clean_haystack, + const std::string & fun_name_needle, + const std::string & arg_name_needle) { + // we might not have the typical OpenAI tool calling structure + int json_start = clean_haystack.find_first_of('{'); + int json_end = clean_haystack.find_last_of('}'); + std::string cut = clean_haystack.substr(json_start, json_end - json_start + 1); + json call_struct = json::parse(cut); + auto register_field = [&](const std::string & prefix, const nlohmann::detail::iteration_proxy_value & subel) { + if (subel.value().is_string() && std::string(subel.value()).find("call0000") != std::string::npos) { + format.id_field = !prefix.empty() ? prefix + "." + subel.key() : subel.key(); + } else if (subel.value().is_string() && std::string(subel.value()) == fun_name_needle) { + format.name_field = !prefix.empty() ? prefix + "." + subel.key() : subel.key(); + } else if (subel.value().dump().find(arg_name_needle) != + std::string::npos) { // handle both string and JSON obj variants + format.args_field = !prefix.empty() ? prefix + "." + subel.key() : subel.key(); + } else if (subel.key().find("id") != std::string::npos) { + // heuristics for generated id field + format.gen_id_field = !prefix.empty() ? prefix + "." + subel.key() : subel.key(); + } + }; + for (const auto & el : call_struct.items()) { + if (el.key() == fun_name_needle) { + format.fun_name_is_key = true; + // When function name is the key, there's no name field and args are direct + format.name_field.clear(); + format.args_field.clear(); + // Don't register this element - the function name IS the key, not a field + } else { + if (el.value().is_object() && + el.value().dump().find(arg_name_needle) == std::string::npos) { // not the args object + format.function_field = el.key(); + for (const auto & subel : el.value().items()) { + register_field(el.key(), subel); + } + } + // Register this element as a potential field + register_field("", el); + } + } + auto array_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("pre", p.literal("[") + p.space()) + p.literal(cut) + p.tag("post", p.space() + p.literal("]")); + }); + + auto ar_parse_res = array_parser.parse_anywhere_and_extract(clean_haystack); + if (ar_parse_res.result.success()) { + format.tools_array_wrapped = true; + json_start -= ar_parse_res.tags["pre"].length(); + json_end += ar_parse_res.tags["post"].length(); + } + json_end++; // we want to move past the closing char for end marker extraction + + std::vector> located_params; + if (!format.name_field.empty()) { + located_params.push_back({ clean_haystack.find(format.name_field), format.name_field }); + } + if (!format.args_field.empty()) { + located_params.push_back({ clean_haystack.find(format.args_field), format.args_field }); + } + if (!format.id_field.empty()) { + located_params.push_back({ clean_haystack.find(format.id_field), format.id_field }); + } + if (!format.gen_id_field.empty()) { + located_params.push_back({ clean_haystack.find(format.gen_id_field), format.gen_id_field }); + } + std::sort(located_params.begin(), located_params.end()); + for (auto & pair : located_params) { + format.parameter_order.push_back(pair.second); + } + // we can immediately extract tool calling markers too + format.section_start = trim_leading_whitespace(clean_haystack.substr(0, json_start)); + format.section_end = trim_whitespace(clean_haystack.substr(json_end)); + // When tools_array_wrapped is true, the closing bracket is part of the array structure, + // not a separate section end marker. Clear tool_section_end to avoid duplicate brackets. + if (format.tools_array_wrapped && format.section_end == "]") { + format.section_end.clear(); + } +} + +void analyze_tools::analyze_tool_call_format_non_json(const std::string & clean_haystack, + const std::string & fun_name_needle) { + // first, let's find out if the function is inside a tag or standalone + auto fun_marker_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("fun_marker", p.choice({ + p.tag("fun_pre", p.literal("<") + p.until_one_of({ ">", fun_name_needle })) + p.literal(fun_name_needle) + + p.tag("fun_post", p.negate(p.space() + p.literal("<")) + p.until(">") + p.literal(">")) + p.space(), + p.tag("fun_pre", p.literal("[") + p.until_one_of({ "]", fun_name_needle })) + p.literal(fun_name_needle) + + p.tag("fun_post", p.negate(p.space() + p.literal("[") + p.until("]") + p.literal("]")) + p.space()) })); + }); + auto fun_res = fun_marker_parser.parse_anywhere_and_extract(clean_haystack); + std::string fun_marker = fun_name_needle; + if (fun_res.result.success()) { + fun_marker = fun_res.tags["fun_marker"]; + } + // now, consume up to two markers, then treat everything up to the function marker as function name prefix + auto per_tool_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("sec_start", p.marker() + p.space()) + p.tag("call_start", p.marker() + p.space()) + + p.tag("fun_pre", p.until(fun_marker)) + fun_marker + p.tag("rest", p.rest()); + }); + auto section_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("sec_start", p.marker() + p.space()) + fun_marker + p.tag("rest", p.rest()); + }); + auto result = per_tool_parser.parse_anywhere_and_extract(clean_haystack); + tagged_parse_result result_end; + if (result.result.success()) { + auto double_closer_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("call_end", p.marker() + p.space()) + p.tag("sec_end", p.marker() + p.space()) + p.end(); + }); + result_end = double_closer_parser.parse_anywhere_and_extract(result.tags["rest"]); + function.name_prefix = fun_res.tags["fun_pre"] + function.name_prefix; + } else { + result = section_parser.parse_anywhere_and_extract(clean_haystack); + auto single_closer_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("sec_end", p.marker() + p.space()) + p.end(); + }); + result_end = single_closer_parser.parse_anywhere_and_extract(result.tags["rest"]); + } + format.per_call_start = result.tags["call_start"]; + format.per_call_end = result_end.tags["call_end"]; + format.section_start = result.tags["sec_start"]; + format.section_end = result_end.tags["sec_end"]; +} + +void analyze_tools::check_per_call_markers() { + json assistant_one_tool = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call }) } + }; + + json assistant_two_tools = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call, second_tool_call }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_one_tool }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto one_vs_two = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_two_tools }); }); + + if (!one_vs_two) { + LOG_DBG(ANSI_ORANGE "%s: Generating double tool call comparison failed\n" ANSI_RESET, __func__); + return; + } + + diff_split filter_common_call_part = calculate_diff_split(one_vs_two->diff.suffix, one_vs_two->diff.right); + + std::string second_tool_content = trim_leading_whitespace(filter_common_call_part.right); + if (!format.section_start.empty() && + second_tool_content.find(format.section_start) == 0) { + format.per_call_start = format.section_start; + format.per_call_end = format.section_end; + format.section_start.clear(); + format.section_end.clear(); + } +} + +void analyze_tools::extract_function_markers() { + json assistant_nocall = json{ + { "role", "assistant" }, + { "content", ASSISTANT_MSG }, + }; + + json assistant_foofoo = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call }) } + }; + + json assistant_barbar = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ second_tool_call }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_foofoo }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_barbar }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__); + return; + } + + const auto & diff = comparison->diff; + + if (diff.left.find(FUN_FIRST) != std::string::npos && diff.right.find(FUN_SECOND) != std::string::npos) { + std::string prefix_marker; + if (!format.per_call_start.empty()) { + prefix_marker = format.per_call_start; + } else { + prefix_marker = format.section_start; + } + if (!prefix_marker.empty() && diff.prefix.rfind(prefix_marker) != std::string::npos) { + function.name_prefix = + diff.prefix.substr(diff.prefix.rfind(prefix_marker) + prefix_marker.size()); + } + + // Extract name prefix/suffix from diff.left (stop at the next marker boundary) + auto name_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("pre", p.until(FUN_FIRST)) + p.literal(FUN_FIRST) + + p.tag("post", p.zero_or_more(p.negate(p.marker()) + p.any())); + }); + auto name_result = name_parser.parse_and_extract(diff.left); + if (name_result.result.success()) { + function.name_prefix += name_result.tags["pre"]; + function.name_suffix = name_result.tags["post"]; + } + + // Extend name_suffix with content from diff.suffix before args begin + if (format.mode == tool_format::TAG_WITH_JSON) { + // For JSON: name_suffix extends to the first non-marker { or [, including any + // markers along the way. Only applies if there's at least one marker after + // the JSON content (matching the original "stop < seg_suf.size() - 1" guard). + auto suffix_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + auto non_json = p.marker() | (p.negate(p.literal("{")) + p.negate(p.literal("[")) + p.any()); + auto after_json = p.zero_or_more(p.negate(p.marker()) + p.any()) + p.marker(); + return p.tag("ext", p.zero_or_more(non_json)) + after_json; + }); + auto suf_result = suffix_parser.parse_and_extract(diff.suffix); + if (suf_result.result.success()) { + function.name_suffix += suf_result.tags["ext"]; + } + } else { + // For tagged: name_suffix extends to the first marker (arg marker) + auto suffix_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.tag("ext", p.zero_or_more(p.negate(p.marker()) + p.any())); + }); + auto suf_result = suffix_parser.parse_and_extract(diff.suffix); + if (suf_result.result.success()) { + function.name_suffix += suf_result.tags["ext"]; + } + } + + // Extract the closer (between last arg and call/section end marker) + std::string suffix_marker; + if (!format.per_call_end.empty()) { + suffix_marker = format.per_call_end; + } else { + suffix_marker = format.section_end; + } + std::string closer_suffix; + if (suffix_marker.empty()) { + // we'll have to rely on an extra diff with no-calls version + auto notool_comp = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_nocall }); }); + auto nt_diff = notool_comp->diff; + closer_suffix = nt_diff.left.substr(nt_diff.left.find("YYYY") + 4); + } else { + closer_suffix = diff.suffix.substr(0, diff.suffix.find(suffix_marker)); + } + if (!closer_suffix.empty()) { + if (format.mode == tool_format::TAG_WITH_TAGGED) { + // After last arg value, skip the closing arg marker, rest is closer + auto closer_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.until("YYYY") + p.literal("YYYY") + p.space() + + p.marker() + p.space() + + p.tag("close", p.rest()); + }); + auto close_result = closer_parser.parse_and_extract(closer_suffix); + if (close_result.result.success()) { + function.close = close_result.tags["close"]; + } + } else if (format.mode == tool_format::TAG_WITH_JSON) { + // After last arg value, find end of JSON args, rest is closer + auto closer_parser = build_tagged_peg_parser([&](common_peg_parser_builder &p) { + return p.until("YYYY") + p.literal("YYYY") + p.tag("post_val", p.rest()); + }); + auto close_result = closer_parser.parse_and_extract(closer_suffix); + if (close_result.result.success()) { + const auto & post = close_result.tags["post_val"]; + size_t pos = post.find_last_of("}]"); + if (pos != std::string::npos && pos < post.size() - 1) { + function.close = trim_leading_whitespace(post.substr(pos + 1)); + } + } + } + } + function.close = trim_leading_whitespace(function.close); + } +} + +void analyze_tools::analyze_arguments() { + LOG_DBG(ANSI_ORANGE "Phase 4: Argument analysis\n" ANSI_RESET); + + extract_argument_name_markers(); + extract_argument_value_markers(); +} + +void analyze_tools::extract_argument_name_markers() { + json assistant_first_arg = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call_one_arg }) } + }; + + json assistant_second_arg = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call_other_arg }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_first_arg }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_second_arg }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__); + return; + } + + const auto & diff = comparison->diff; + + if (!diff.left.empty() && !diff.right.empty()) { + // Parse both sides to find ARG_FIRST/ARG_SECOND and extract the surrounding structure + auto left_parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + return p.tag("pre", p.until(ARG_FIRST)) + p.literal(ARG_FIRST) + + p.tag("suffix", p.until_one_of({"\"", "X"})); + }); + auto right_parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + return p.tag("pre", p.until(ARG_SECOND)) + p.literal(ARG_SECOND) + + p.tag("suffix", p.until_one_of({"\"", "Y"})); + }); + auto left_result = left_parser.parse_anywhere_and_extract(diff.left); + auto right_result = right_parser.parse_anywhere_and_extract(diff.right); + + if (left_result.result.success() && right_result.result.success() && + !left_result.tags["pre"].empty() && + left_result.tags["pre"] == right_result.tags["pre"] && + left_result.tags["suffix"] == right_result.tags["suffix"]) { + // Name is inside a structure (e.g., JSON key): prefix is the shared wrapper + arguments.name_prefix = trim_whitespace(left_result.tags["pre"]); + arguments.name_suffix = trim_leading_whitespace(left_result.tags["suffix"]); + } else if (diff.left.substr(0, ARG_FIRST.length()) == ARG_FIRST && diff.right.substr(0, ARG_SECOND.length()) == ARG_SECOND) { + // Name is directly in the diff: prefix comes from last marker in diff.prefix + auto pre_parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + auto last_marker = p.marker() + p.zero_or_more(p.negate(p.marker()) + p.any()) + p.end(); + return p.zero_or_more(p.negate(last_marker) + p.any()) + p.tag("name_prefix", last_marker); + }); + auto pre_result = pre_parser.parse_and_extract(diff.prefix); + arguments.name_prefix = pre_result.result.success() + ? pre_result.tags["name_prefix"] : diff.prefix; + + // Suffix extends from after ARG_FIRST to the first marker (+ optional whitespace). + // The marker could be in diff.left itself or in diff.suffix, so we concatenate. + std::string after_first = diff.left.substr(ARG_FIRST.length()) + diff.suffix; + auto suffix_parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + return p.tag("suffix", p.zero_or_more(p.negate(p.marker()) + p.any()) + + p.marker() + p.space()); + }); + auto suf_result = suffix_parser.parse_anywhere_and_extract(after_first); + if (suf_result.result.success()) { + arguments.name_suffix = suf_result.tags["suffix"]; + } + } + } +} + +void analyze_tools::extract_argument_value_markers() { + json assistant_val_X = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call_one_arg }) } + }; + + json assistant_val_Y = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call_one_arg_other_val }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_val_X }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_val_Y }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__); + return; + } + + const auto & diff = comparison->diff; + + if (diff.left == "XXXX" && diff.right == "YYYY") { + std::string arg_name_ending = ARG_FIRST + arguments.name_suffix; + std::string prefix = diff.prefix; + if (prefix.rfind(arg_name_ending) != std::string::npos) { + prefix = prefix.substr(prefix.rfind(arg_name_ending) + arg_name_ending.size()); + } + if (!prefix.empty()) { + // Find the last marker + any trailing non-marker text to end + auto prefix_parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + auto last_marker = p.marker() + p.zero_or_more(p.negate(p.marker()) + p.any()) + p.end(); + return p.zero_or_more(p.negate(last_marker) + p.any()) + p.tag("val_prefix", last_marker); + }); + auto pre_result = prefix_parser.parse_and_extract(prefix); + arguments.value_prefix = pre_result.result.success() ? pre_result.tags["val_prefix"] : prefix; + } + + std::string value_suffix = diff.suffix; + if (!function.close.empty()) { + size_t func_close_pos = value_suffix.find(function.close); + if (func_close_pos != std::string::npos) { + value_suffix = value_suffix.substr(0, func_close_pos); + } + } else if (!format.per_call_end.empty() || !format.section_end.empty()) { + std::string end_marker = + !format.per_call_end.empty() ? format.per_call_end : format.section_end; + size_t end_marker_pos = value_suffix.find(end_marker); + if (end_marker_pos != std::string::npos) { + value_suffix = value_suffix.substr(0, end_marker_pos); + } + } + value_suffix = trim_leading_whitespace(value_suffix); + if (!value_suffix.empty()) { + arguments.value_suffix = value_suffix; + } + } +} + +void analyze_tools::extract_argument_separator() { + json assistant_one_arg = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call_one_arg }) } + }; + + json assistant_two_args = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_one_arg }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_two_args }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__); + return; + } + + const auto & diff = comparison->diff; + + if (!diff.right.empty()) { + std::string separator = until_common_prefix(diff.right, ARG_FIRST, ARG_SECOND); + arguments.separator = separator; + } +} + +void analyze_tools::extract_args_markers() { + json assistant_no_args = json{ + { "role", "assistant"}, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call_zero_args }) } + }; + + json assistant_with_args = json{ + { "role", "assistant"}, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call_one_arg }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_no_args }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_with_args }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed\n" ANSI_RESET, __func__); + return; + } + + const auto & diff = comparison->diff; + + if (format.mode != tool_format::JSON_NATIVE) { + std::string prefix_marker = !format.section_start.empty() ? format.section_start : format.per_call_start; + std::string suffix_marker = !format.section_end.empty() ? format.section_end : format.per_call_end; + // these might happen earlier in the tools section as an example or somewhere else, so we need to find the closest ones + size_t prefix_pos = prefix_marker.empty() ? 0 : diff.prefix.rfind(prefix_marker); + size_t suffix_pos = suffix_marker.empty() ? diff.suffix.size() : diff.suffix.find(suffix_marker); + if (prefix_pos == std::string::npos) { + prefix_pos = 0; + } + if (suffix_pos == std::string::npos) { + suffix_pos = diff.suffix.size(); + } + std::string prefix_cut = diff.prefix.substr(prefix_pos + prefix_marker.size()); + std::string suffix_cut = diff.suffix.substr(0, suffix_pos); + std::string args_start = until_common_prefix(prefix_cut, "{}", "{\"first\":"); + std::string args_end = after_common_suffix(suffix_cut, "{}", "\"XXXX\"}"); + + if (!args_start.empty() || !args_end.empty()) { + size_t find_fun = args_start.find(FUN_FIRST); + if (find_fun != std::string::npos) { + args_start = args_start.substr(find_fun + FUN_FIRST.size(), args_start.size() - find_fun - FUN_FIRST.size()); + } + arguments.start = args_start; + arguments.end = args_end; + } + } +} + +void analyze_tools::extract_call_id_markers() { + json assistant_id1 = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call }) } + }; + + json assistant_id2 = json{ + { "role", "assistant" }, + { "content", "" }, + { "tool_calls", json::array({ first_tool_call_alt_id }) } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_id1 }); + params.tools = tools; + params.add_generation_prompt = false; + params.enable_thinking = true; + + auto comparison = compare_variants( + *tmpl, params, [&](template_params & p) { p.messages = json::array({ user_msg, assistant_id2 }); }); + + if (!comparison) { + LOG_DBG(ANSI_ORANGE "%s: Template application failed for call_id detection\n" ANSI_RESET, __func__); + return; + } + + const auto & diff = comparison->diff; + + if (diff.left.empty() && diff.right.empty()) { + return; + } + + std::string id_value_1 = "call00001"; + std::string id_value_2 = "call99999"; + + size_t common_id_prefix_len = 0; + for (size_t i = 0; i < std::min(id_value_1.length(), id_value_2.length()); i++) { + if (id_value_1[i] == id_value_2[i]) { + common_id_prefix_len++; + } else { + break; + } + } + std::string common_id_part = id_value_1.substr(0, common_id_prefix_len); + + // Check if the function name is in the prefix (normal case: BETWEEN_FUNC_AND_ARGS or POST_ARGS) + // or in the suffix (call_id is PRE_FUNC_NAME) + std::string func_name = FUN_FIRST; + size_t func_name_in_prefix = diff.prefix.rfind(func_name); + size_t func_name_in_suffix = diff.suffix.find(func_name); + + // Helper: find the last marker in a string (returns just the marker, not trailing text) + auto find_last_marker = [](const std::string & str) -> std::string { + auto parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + auto last = p.marker() + p.zero_or_more(p.negate(p.marker()) + p.any()) + p.end(); + return p.zero_or_more(p.negate(last) + p.any()) + p.tag("m", p.marker()); + }); + auto res = parser.parse_anywhere_and_extract(str); + return res.result.success() ? res.tags["m"] : ""; + }; + + // Helper: find the first marker in a string + auto find_first_marker = [](const std::string & str) -> std::string { + auto parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + return p.tag("m", p.marker()); + }); + auto res = parser.parse_anywhere_and_extract(str); + return res.result.success() ? res.tags["m"] : ""; + }; + + if (func_name_in_prefix != std::string::npos && func_name_in_suffix == std::string::npos) { + // Function name is only in prefix - call_id is BETWEEN_FUNC_AND_ARGS or POST_ARGS + // Check if args indicator "{" is in prefix or suffix + size_t args_in_prefix = diff.prefix.find('{', func_name_in_prefix); + size_t args_in_suffix = diff.suffix.find('{'); + + if (args_in_suffix != std::string::npos && + (args_in_prefix == std::string::npos || args_in_prefix > diff.prefix.length())) { + // Args are in suffix, so call_id is BETWEEN_FUNC_AND_ARGS + call_id.pos = call_id_position::BETWEEN_FUNC_AND_ARGS; + + // Find call_id_prefix: marker immediately preceding common_id_part (no intervening markers) + std::string after_func = diff.prefix.substr(func_name_in_prefix + func_name.length()); + auto id_prefix_parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + return p.tag("prefix", p.marker()) + + p.zero_or_more(p.negate(p.marker()) + p.negate(p.literal(common_id_part)) + p.any()) + + p.literal(common_id_part); + }); + auto id_res = id_prefix_parser.parse_anywhere_and_extract(after_func); + if (id_res.result.success()) { + call_id.prefix = id_res.tags["prefix"]; + } else { + // Fallback: use the last marker in after_func + call_id.prefix = find_last_marker(after_func); + } + + // Extract call_id_suffix: the first marker in the suffix before args "{" + auto suffix_parser = build_tagged_peg_parser([&](common_peg_parser_builder & p) { + return p.zero_or_more(p.negate(p.marker()) + p.negate(p.literal("{")) + p.any()) + + p.tag("suffix", p.marker()); + }); + auto suf_res = suffix_parser.parse_anywhere_and_extract(diff.suffix); + if (suf_res.result.success()) { + call_id.suffix = suf_res.tags["suffix"]; + } + } else if (args_in_prefix != std::string::npos) { + // Args are in prefix, so call_id is POST_ARGS + call_id.pos = call_id_position::POST_ARGS; + + // Extract last marker between args closing brace and the ID + std::string after_args = diff.prefix.substr(args_in_prefix); + size_t closing_brace = after_args.rfind('}'); + if (closing_brace != std::string::npos) { + std::string between_args_and_id = after_args.substr(closing_brace + 1); + call_id.prefix = find_last_marker(between_args_and_id); + } + + // call_id_suffix: first marker in diff.suffix + call_id.suffix = find_first_marker(diff.suffix); + } + } else if (func_name_in_suffix != std::string::npos && func_name_in_prefix == std::string::npos) { + // Function name is only in suffix - call_id is PRE_FUNC_NAME + call_id.pos = call_id_position::PRE_FUNC_NAME; + + // call_id_prefix: last marker in diff.prefix + call_id.prefix = find_last_marker(diff.prefix); + + // call_id_suffix: first marker in the portion of diff.suffix before func_name + std::string before_func = diff.suffix.substr(0, func_name_in_suffix); + call_id.suffix = find_first_marker(before_func); + } + + // When call_id is detected, per_call_end may have been incorrectly set to include + // the call_id_suffix and sample args. Clear it if it starts with call_id_suffix. + if (call_id.pos != call_id_position::NONE && !call_id.suffix.empty() && + format.per_call_end.find(call_id.suffix) == 0) { + format.per_call_end.clear(); + } +} + +} // namespace autoparser diff --git a/common/chat-parser-xml-toolcall.cpp b/common/chat-parser-xml-toolcall.cpp deleted file mode 100644 index ba359fdbf46..00000000000 --- a/common/chat-parser-xml-toolcall.cpp +++ /dev/null @@ -1,879 +0,0 @@ -#include "chat.h" -#include "chat-parser.h" -#include "common.h" -#include "json-partial.h" -#include "json-schema-to-grammar.h" -#include "log.h" -#include "regex-partial.h" - -using json = nlohmann::ordered_json; - -class xml_toolcall_syntax_exception : public std::runtime_error { - public: - xml_toolcall_syntax_exception(const std::string & message) : std::runtime_error(message) {} -}; - -template -inline void sort_uniq(std::vector &vec) { - std::sort(vec.begin(), vec.end()); - vec.erase(std::unique(vec.begin(), vec.end()), vec.end()); -} - -template -inline bool all_space(const T &str) { - return std::all_of(str.begin(), str.end(), [](unsigned char ch) { return std::isspace(ch); }); -} - -static size_t utf8_truncate_safe(const std::string_view s) { - size_t len = s.size(); - if (len == 0) return 0; - size_t i = len; - for (size_t back = 0; back < 4 && i > 0; ++back) { - --i; - unsigned char c = s[i]; - if ((c & 0x80) == 0) { - return len; - } else if ((c & 0xC0) == 0xC0) { - size_t expected_len = 0; - if ((c & 0xE0) == 0xC0) expected_len = 2; - else if ((c & 0xF0) == 0xE0) expected_len = 3; - else if ((c & 0xF8) == 0xF0) expected_len = 4; - else return i; - if (len - i >= expected_len) { - return len; - } else { - return i; - } - } - } - return len - std::min(len, size_t(3)); -} - -inline void utf8_truncate_safe_resize(std::string &s) { - s.resize(utf8_truncate_safe(s)); -} - -inline std::string_view utf8_truncate_safe_view(const std::string_view s) { - return s.substr(0, utf8_truncate_safe(s)); -} - -static std::optional try_find_2_literal_splited_by_spaces(common_chat_msg_parser & builder, const std::string & literal1, const std::string & literal2) { - if (literal1.size() == 0) return builder.try_find_literal(literal2); - const auto saved_pos = builder.pos(); - while (auto res = builder.try_find_literal(literal1)) { - builder.consume_spaces(); - const auto match_len = std::min(literal2.size(), builder.input().size() - builder.pos()); - if (builder.input().compare(builder.pos(), match_len, literal2, 0, match_len) == 0) { - if (res->prelude.size() != res->groups[0].begin - saved_pos) { - res->prelude = builder.str({saved_pos, res->groups[0].begin}); - } - builder.move_to(builder.pos() + match_len); - res->groups[0].end = builder.pos(); - GGML_ASSERT(res->groups[0].begin != res->groups[0].end); - return res; - } - builder.move_to(res->groups[0].begin + 1); - } - builder.move_to(saved_pos); - return std::nullopt; -} - -/** - * make a GBNF that accept any strings except those containing any of the forbidden strings. - */ -std::string make_gbnf_excluding(std::vector forbids) { - constexpr auto charclass_escape = [](unsigned char c) -> std::string { - if (c == '\\' || c == ']' || c == '^' || c == '-') { - std::string s = "\\"; - s.push_back((char)c); - return s; - } - if (isprint(c)) { - return std::string(1, (char)c); - } - char buf[16]; - snprintf(buf, 15, "\\x%02X", c); - return std::string(buf); - }; - constexpr auto build_expr = [charclass_escape](auto self, const std::vector& forbids, int l, int r, int depth) -> std::string { - std::vector>> children; - int i = l; - while (i < r) { - const std::string &s = forbids[i]; - if ((int)s.size() == depth) { - ++i; - continue; - } - unsigned char c = (unsigned char)s[depth]; - int j = i; - while (j < r && (int)forbids[j].size() > depth && - (unsigned char)forbids[j][depth] == c) { - ++j; - } - children.push_back({c, {i, j}}); - i = j; - } - std::vector alts; - if (!children.empty()) { - std::string cls; - for (auto &ch : children) cls += charclass_escape(ch.first); - alts.push_back(std::string("[^") + cls + "]"); - } - for (auto &ch : children) { - std::string childExpr = self(self, forbids, ch.second.first, ch.second.second, depth+1); - if (!childExpr.empty()) { - std::string quoted_ch = "\""; - if (ch.first == '\\') quoted_ch += "\\\\"; - else if (ch.first == '"') quoted_ch += "\\\""; - else if (isprint(ch.first)) quoted_ch.push_back(ch.first); - else { - char buf[16]; - snprintf(buf, 15, "\\x%02X", ch.first); - quoted_ch += buf; - } - quoted_ch += "\""; - std::string branch = quoted_ch + std::string(" ") + childExpr; - alts.push_back(branch); - } - } - if (alts.empty()) return ""; - std::ostringstream oss; - oss << "( "; - for (size_t k = 0; k < alts.size(); ++k) { - if (k) oss << " | "; - oss << alts[k]; - } - oss << " )"; - return oss.str(); - }; - if (forbids.empty()) return "( . )*"; - sort(forbids.begin(), forbids.end()); - std::string expr = build_expr(build_expr, forbids, 0, forbids.size(), 0); - if (expr.empty()) { - std::string cls; - for (auto &s : forbids) if (!s.empty()) cls += charclass_escape((unsigned char)s[0]); - expr = std::string("( [^") + cls + "] )"; - } - if (forbids.size() == 1) - return expr + "*"; - else - return std::string("( ") + expr + " )*"; -} - -/** - * Build grammar for xml-style tool call - * form.scope_start and form.scope_end can be empty. - * Requires data.format for model-specific hacks. - */ -void build_grammar_xml_tool_call(common_chat_params & data, const json & tools, const struct xml_tool_call_format & form) { - GGML_ASSERT(!form.tool_start.empty()); - GGML_ASSERT(!form.tool_sep.empty()); - GGML_ASSERT(!form.key_start.empty()); - GGML_ASSERT(!form.val_end.empty()); - GGML_ASSERT(!form.tool_end.empty()); - - std::string key_val_sep = form.key_val_sep; - if (form.key_val_sep2) { - key_val_sep += "\n"; - key_val_sep += *form.key_val_sep2; - } - GGML_ASSERT(!key_val_sep.empty()); - - if (tools.is_array() && !tools.empty()) { - data.grammar = build_grammar([&](const common_grammar_builder &builder) { - auto string_arg_val = form.last_val_end ? - builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end, *form.last_val_end})) : - builder.add_rule("string-arg-val", make_gbnf_excluding({form.val_end})); - - std::vector tool_rules; - for (const auto & tool : tools) { - if (!tool.contains("type") || tool.at("type") != "function" || !tool.contains("function")) { - LOG_WRN("Skipping tool without function: %s", tool.dump(2).c_str()); - continue; - } - const auto & function = tool.at("function"); - if (!function.contains("name") || !function.at("name").is_string()) { - LOG_WRN("Skipping invalid function (invalid name): %s", function.dump(2).c_str()); - continue; - } - if (!function.contains("parameters") || !function.at("parameters").is_object()) { - LOG_WRN("Skipping invalid function (invalid parameters): %s", function.dump(2).c_str()); - continue; - } - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - - struct parameter_rule { - std::string symbol_name; - bool is_required; - }; - std::vector arg_rules; - if (!parameters.contains("properties") || !parameters.at("properties").is_object()) { - LOG_WRN("Skipping invalid function (invalid properties): %s", function.dump(2).c_str()); - continue; - } else { - std::vector requiredParameters; - if (parameters.contains("required")) { - try { parameters.at("required").get_to(requiredParameters); } - catch (const std::runtime_error&) { - LOG_WRN("Invalid function required parameters, ignoring: %s", function.at("required").dump(2).c_str()); - } - } - sort_uniq(requiredParameters); - for (const auto & [key, value] : parameters.at("properties").items()) { - std::string quoted_key = key; - bool required = std::binary_search(requiredParameters.begin(), requiredParameters.end(), key); - if (form.key_start.back() == '"' && key_val_sep[0] == '"') { - quoted_key = gbnf_format_literal(key); - quoted_key = quoted_key.substr(1, quoted_key.size() - 2); - } - arg_rules.push_back(parameter_rule {builder.add_rule("func-" + name + "-kv-" + key, - gbnf_format_literal(form.key_start) + " " + - gbnf_format_literal(quoted_key) + " " + - gbnf_format_literal(key_val_sep) + " " + - ((value.contains("type") && value["type"].is_string() && value["type"] == "string" && (!form.raw_argval || *form.raw_argval)) ? - (form.raw_argval ? - string_arg_val : - "( " + string_arg_val + " | " + builder.add_schema(name + "-arg-" + key, value) + " )" - ) : - builder.add_schema(name + "-arg-" + key, value) - ) - ), required}); - } - } - - auto next_arg_with_sep = builder.add_rule(name + "-last-arg-end", form.last_val_end ? gbnf_format_literal(*form.last_val_end) : gbnf_format_literal(form.val_end)); - decltype(next_arg_with_sep) next_arg = "\"\""; - for (auto i = arg_rules.size() - 1; /* i >= 0 && */ i < arg_rules.size(); --i) { - std::string include_this_arg = arg_rules[i].symbol_name + " " + next_arg_with_sep; - next_arg = builder.add_rule(name + "-arg-after-" + std::to_string(i), arg_rules[i].is_required ? - include_this_arg : "( " + include_this_arg + " ) | " + next_arg - ); - include_this_arg = gbnf_format_literal(form.val_end) + " " + include_this_arg; - next_arg_with_sep = builder.add_rule(name + "-arg-after-" + std::to_string(i) + "-with-sep", arg_rules[i].is_required ? - include_this_arg : "( " + include_this_arg + " ) | " + next_arg_with_sep - ); - } - - std::string quoted_name = name; - if (form.tool_start.back() == '"' && form.tool_sep[0] == '"') { - quoted_name = gbnf_format_literal(name); - quoted_name = quoted_name.substr(1, quoted_name.size() - 2); - } - quoted_name = gbnf_format_literal(quoted_name); - // Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name - if (data.format == COMMON_CHAT_FORMAT_KIMI_K2) { - quoted_name = "\"functions.\" " + quoted_name + " \":\" [0-9]+"; - } - tool_rules.push_back(builder.add_rule(name + "-call", - gbnf_format_literal(form.tool_start) + " " + - quoted_name + " " + - gbnf_format_literal(form.tool_sep) + " " + - next_arg - )); - } - - auto tool_call_once = builder.add_rule("root-tool-call-once", string_join(tool_rules, " | ")); - auto tool_call_more = builder.add_rule("root-tool-call-more", gbnf_format_literal(form.tool_end) + " " + tool_call_once); - auto call_end = builder.add_rule("root-call-end", form.last_tool_end ? gbnf_format_literal(*form.last_tool_end) : gbnf_format_literal(form.tool_end)); - auto tool_call_multiple_with_end = builder.add_rule("root-tool-call-multiple-with-end", tool_call_once + " " + tool_call_more + "* " + call_end); - builder.add_rule("root", - (form.scope_start.empty() ? "" : gbnf_format_literal(form.scope_start) + " ") + - tool_call_multiple_with_end + "?" + - (form.scope_end.empty() ? "" : " " + gbnf_format_literal(form.scope_end)) - ); - }); - - // grammar trigger for tool call - data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, form.scope_start + form.tool_start }); - } -} - -/** - * Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched. - * Throws xml_toolcall_syntax_exception if there is invalid syntax and cannot recover the original status for common_chat_msg_parser. - * form.scope_start, form.tool_sep and form.scope_end can be empty. - */ -inline bool parse_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form) { - GGML_ASSERT(!form.tool_start.empty()); - GGML_ASSERT(!form.key_start.empty()); - GGML_ASSERT(!form.key_val_sep.empty()); - GGML_ASSERT(!form.val_end.empty()); - GGML_ASSERT(!form.tool_end.empty()); - - // Helper to choose return false or throw error - constexpr auto return_error = [](common_chat_msg_parser & builder, auto &start_pos, const bool &recovery) { - LOG_DBG("Failed to parse XML-Style tool call at position: %s\n", gbnf_format_literal(builder.consume_rest().substr(0, 20)).c_str()); - if (recovery) { - builder.move_to(start_pos); - return false; - } else throw xml_toolcall_syntax_exception("Tool call parsing failed with unrecoverable errors. Try using a grammar to constrain the model’s output."); - }; - // Drop substring from needle to end from a JSON - constexpr auto partial_json = [](std::string &json_str, std::string_view needle = "XML_TOOL_CALL_PARTIAL_FLAG") { - auto pos = json_str.rfind(needle); - if (pos == std::string::npos) { - return false; - } - for (auto i = pos + needle.size(); i < json_str.size(); ++i) { - unsigned char ch = static_cast(json_str[i]); - if (ch != '\'' && ch != '"' && ch != '}' && ch != ':' && !std::isspace(ch)) { - return false; - } - } - if (pos != 0 && json_str[pos - 1] == '"') { - --pos; - } - json_str.resize(pos); - return true; - }; - // Helper to generate a partial argument JSON - constexpr auto gen_partial_json = [partial_json](auto set_partial_arg, auto &arguments, auto &builder, auto &function_name) { - auto rest = builder.consume_rest(); - utf8_truncate_safe_resize(rest); - set_partial_arg(rest, "XML_TOOL_CALL_PARTIAL_FLAG"); - auto tool_str = arguments.dump(); - if (partial_json(tool_str)) { - if (builder.add_tool_call(function_name, "", tool_str)) { - return; - } - } - LOG_DBG("Failed to parse partial XML-Style tool call, fallback to non-partial: %s\n", tool_str.c_str()); - }; - // Helper to find a close (because there may be form.last_val_end or form.last_tool_end) - constexpr auto try_find_close = []( - common_chat_msg_parser & builder, - const std::string & end, - const std::optional & alt_end, - const std::string & end_next, - const std::optional & alt_end_next - ) { - auto saved_pos = builder.pos(); - auto tc = builder.try_find_literal(end); - auto val_end_size = end.size(); - if (alt_end) { - auto pos_1 = builder.pos(); - builder.move_to(saved_pos); - auto tc2 = try_find_2_literal_splited_by_spaces(builder, *alt_end, end_next); - if (alt_end_next) { - builder.move_to(saved_pos); - auto tc3 = try_find_2_literal_splited_by_spaces(builder, *alt_end, *alt_end_next); - if (tc3 && (!tc2 || tc2->prelude.size() > tc3->prelude.size())) { - tc2 = tc3; - } - } - if (tc2 && (!tc || tc->prelude.size() > tc2->prelude.size())) { - tc = tc2; - tc->groups[0].end = std::min(builder.input().size(), tc->groups[0].begin + alt_end->size()); - builder.move_to(tc->groups[0].end); - val_end_size = alt_end->size(); - } else { - builder.move_to(pos_1); - } - } - return std::make_pair(val_end_size, tc); - }; - // Helper to find a val_end or last_val_end, returns matched pattern size - const auto try_find_val_end = [try_find_close, &builder, &form]() { - return try_find_close(builder, form.val_end, form.last_val_end, form.tool_end, form.last_tool_end); - }; - // Helper to find a tool_end or last_tool_end, returns matched pattern size - const auto try_find_tool_end = [try_find_close, &builder, &form]() { - return try_find_close(builder, form.tool_end, form.last_tool_end, form.scope_end, std::nullopt); - }; - - bool recovery = true; - const auto start_pos = builder.pos(); - if (!all_space(form.scope_start)) { - if (auto tc = builder.try_find_literal(form.scope_start)) { - if (all_space(tc->prelude)) { - if (form.scope_start.size() != tc->groups[0].end - tc->groups[0].begin) - throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.scope_start)); - } else { - builder.move_to(start_pos); - return false; - } - } else return false; - } - while (auto tc = builder.try_find_literal(form.tool_start)) { - if (!all_space(tc->prelude)) { - LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n", - gbnf_format_literal(form.tool_start).c_str(), - gbnf_format_literal(tc->prelude).c_str() - ); - builder.move_to(tc->groups[0].begin - tc->prelude.size()); - break; - } - - // Find tool name - auto func_name = builder.try_find_literal(all_space(form.tool_sep) ? form.key_start : form.tool_sep); - if (!func_name) { - auto [sz, tc] = try_find_tool_end(); - func_name = tc; - } - if (!func_name) { - // Partial tool name not supported - throw common_chat_msg_partial_exception("incomplete tool_call"); - } - // If the model generate multiple tool call and the first tool call has no argument - if (func_name->prelude.find(form.tool_end) != std::string::npos || (form.last_tool_end ? func_name->prelude.find(*form.last_tool_end) != std::string::npos : false)) { - builder.move_to(func_name->groups[0].begin - func_name->prelude.size()); - auto [sz, tc] = try_find_tool_end(); - func_name = tc; - } - - // Parse tool name - builder.move_to(all_space(form.tool_sep) ? func_name->groups[0].begin : func_name->groups[0].end); - std::string function_name = string_strip(func_name->prelude); - // Kimi-K2 uses functions.{{ tool_call['function']['name'] }}:{{ loop.index }} as function name - if (builder.syntax().format == COMMON_CHAT_FORMAT_KIMI_K2) { - if (string_starts_with(function_name, "functions.")) { - static const std::regex re(":\\d+$"); - if (std::regex_search(function_name, re)) { - function_name = function_name.substr(10, function_name.rfind(":") - 10); - } - } - } - - // Argument JSON - json arguments = json::object(); - - // Helper to generate a partial argument JSON - const auto gen_partial_args = [&](auto set_partial_arg) { - gen_partial_json(set_partial_arg, arguments, builder, function_name); - }; - - // Parse all arg_key/arg_value pairs - while (auto tc = builder.try_find_literal(form.key_start)) { - if (!all_space(tc->prelude)) { - LOG_DBG("XML-Style tool call: Expected %s, but found %s, trying to match next pattern\n", - gbnf_format_literal(form.key_start).c_str(), - gbnf_format_literal(tc->prelude).c_str() - ); - builder.move_to(tc->groups[0].begin - tc->prelude.size()); - break; - } - if (tc->groups[0].end - tc->groups[0].begin != form.key_start.size()) { - auto tool_call_arg = arguments.dump(); - if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') { - tool_call_arg.resize(tool_call_arg.size() - 1); - } - builder.add_tool_call(function_name, "", tool_call_arg); - throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_start)); - } - - // Parse arg_key - auto key_res = builder.try_find_literal(form.key_val_sep); - if (!key_res) { - gen_partial_args([&](auto &rest, auto &needle) {arguments[rest + needle] = "";}); - throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.key_val_sep) + " after " + gbnf_format_literal(form.key_start)); - } - if (key_res->groups[0].end - key_res->groups[0].begin != form.key_val_sep.size()) { - gen_partial_args([&](auto &, auto &needle) {arguments[key_res->prelude + needle] = "";}); - throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.key_val_sep)); - } - auto &key = key_res->prelude; - recovery = false; - - // Parse arg_value - if (form.key_val_sep2) { - if (auto tc = builder.try_find_literal(*form.key_val_sep2)) { - if (!all_space(tc->prelude)) { - LOG_DBG("Failed to parse XML-Style tool call: Unexcepted %s between %s and %s\n", - gbnf_format_literal(tc->prelude).c_str(), - gbnf_format_literal(form.key_val_sep).c_str(), - gbnf_format_literal(*form.key_val_sep2).c_str() - ); - return return_error(builder, start_pos, false); - } - if (tc->groups[0].end - tc->groups[0].begin != form.key_val_sep2->size()) { - gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;}); - throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(*form.key_val_sep2)); - } - } else { - gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;}); - throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(*form.key_val_sep2) + " after " + gbnf_format_literal(form.key_val_sep)); - } - } - auto val_start = builder.pos(); - - // Test if arg_val is a partial JSON - std::optional value_json = std::nullopt; - if (!form.raw_argval || !*form.raw_argval) { - try { value_json = builder.try_consume_json(); } - catch (const std::runtime_error&) { builder.move_to(val_start); } - // TODO: Delete this when json_partial adds top-level support for null/true/false - if (builder.pos() == val_start) { - const static std::regex number_regex(R"([0-9-][0-9]*(\.\d*)?([eE][+-]?\d*)?)"); - builder.consume_spaces(); - std::string_view sv = utf8_truncate_safe_view(builder.input()); - sv.remove_prefix(builder.pos()); - std::string rest = "a"; - if (sv.size() < 6) rest = sv; - if (string_starts_with("null", rest) || string_starts_with("true", rest) || string_starts_with("false", rest) || std::regex_match(sv.begin(), sv.end(), number_regex)) { - value_json = {123, {"123", "123"}}; - builder.consume_rest(); - } else { - builder.move_to(val_start); - } - } - } - - // If it is a JSON and followed by , parse as json - // cannot support streaming because it may be a plain text starting with JSON - if (value_json) { - auto json_end = builder.pos(); - builder.consume_spaces(); - if (builder.pos() == builder.input().size()) { - if (form.raw_argval && !*form.raw_argval && (value_json->json.is_string() || value_json->json.is_object() || value_json->json.is_array())) { - arguments[key] = value_json->json; - auto json_str = arguments.dump(); - if (!value_json->healing_marker.json_dump_marker.empty()) { - GGML_ASSERT(std::string::npos != json_str.rfind(value_json->healing_marker.json_dump_marker)); - json_str.resize(json_str.rfind(value_json->healing_marker.json_dump_marker)); - } else { - GGML_ASSERT(json_str.back() == '}'); - json_str.resize(json_str.size() - 1); - } - builder.add_tool_call(function_name, "", json_str); - } else { - gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;}); - } - LOG_DBG("Possible JSON arg_value: %s\n", value_json->json.dump().c_str()); - throw common_chat_msg_partial_exception("JSON arg_value detected. Waiting for more tokens for validations."); - } - builder.move_to(json_end); - auto [val_end_size, tc] = try_find_val_end(); - if (tc && all_space(tc->prelude) && value_json->healing_marker.marker.empty()) { - if (tc->groups[0].end - tc->groups[0].begin != val_end_size) { - gen_partial_args([&](auto &, auto &needle) {arguments[key] = needle;}); - LOG_DBG("Possible terminated JSON arg_value: %s\n", value_json->json.dump().c_str()); - throw common_chat_msg_partial_exception("Partial literal: " + gbnf_format_literal(form.val_end) + (form.last_val_end ? gbnf_format_literal(*form.last_val_end) : "")); - } else arguments[key] = value_json->json; - } else builder.move_to(val_start); - } - - // If not, parse as plain text - if (val_start == builder.pos()) { - if (auto [val_end_size, value_plain] = try_find_val_end(); value_plain) { - auto &value_str = value_plain->prelude; - if (form.trim_raw_argval) value_str = string_strip(value_str); - if (value_plain->groups[0].end - value_plain->groups[0].begin != val_end_size) { - gen_partial_args([&](auto &, auto &needle) {arguments[key] = value_str + needle;}); - throw common_chat_msg_partial_exception( - "Expected " + gbnf_format_literal(form.val_end) + - " after " + gbnf_format_literal(form.key_val_sep) + - (form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "") - ); - } - arguments[key] = value_str; - } else { - if (form.trim_raw_argval) { - gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = string_strip(rest) + needle;}); - } else { - gen_partial_args([&](auto &rest, auto &needle) {arguments[key] = rest + needle;}); - } - throw common_chat_msg_partial_exception( - "Expected " + gbnf_format_literal(form.val_end) + - " after " + gbnf_format_literal(form.key_val_sep) + - (form.key_val_sep2 ? " " + gbnf_format_literal(*form.key_val_sep2) : "") - ); - } - } - } - - // Consume closing tag - if (auto [tool_end_size, tc] = try_find_tool_end(); tc) { - if (!all_space(tc->prelude)) { - LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n", - gbnf_format_literal(form.tool_end).c_str(), - gbnf_format_literal(tc->prelude).c_str() - ); - return return_error(builder, start_pos, recovery); - } - if (tc->groups[0].end - tc->groups[0].begin == tool_end_size) { - // Add the parsed tool call - if (!builder.add_tool_call(function_name, "", arguments.dump())) { - throw common_chat_msg_partial_exception("Failed to add XML-Style tool call"); - } - recovery = false; - continue; - } - } - - auto tool_call_arg = arguments.dump(); - if (tool_call_arg.size() != 0 && tool_call_arg[tool_call_arg.size() - 1] == '}') { - tool_call_arg.resize(tool_call_arg.size() - 1); - } - builder.add_tool_call(function_name, "", tool_call_arg); - throw common_chat_msg_partial_exception("Expected " + gbnf_format_literal(form.tool_end) + " after " + gbnf_format_literal(form.val_end)); - } - if (auto tc = builder.try_find_literal(form.scope_end)) { - if (!all_space(tc->prelude)) { - LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n", - gbnf_format_literal(form.scope_end).c_str(), - gbnf_format_literal(tc->prelude).c_str() - ); - return return_error(builder, start_pos, recovery); - } - } else { - if (all_space(form.scope_end)) return true; - builder.consume_spaces(); - if (builder.pos() == builder.input().size()) - throw common_chat_msg_partial_exception("incomplete tool calls"); - LOG_DBG("Failed to parse XML-Style tool call: Expected %s, but found %s\n", - gbnf_format_literal(form.scope_end).c_str(), - gbnf_format_literal(builder.consume_rest()).c_str() - ); - return return_error(builder, start_pos, recovery); - } - - return true; -} - -/** - * Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched. - * May cause std::runtime_error if there is invalid syntax because partial valid tool call is already sent out to client. - * form.scope_start, form.tool_sep and form.scope_end can be empty. - */ -bool common_chat_msg_parser::try_consume_xml_tool_calls(const struct xml_tool_call_format & form) { - auto pos = pos_; - auto tsize = result_.tool_calls.size(); - try { return parse_xml_tool_calls(*this, form); } - catch (const xml_toolcall_syntax_exception&) {} - move_to(pos); - result_.tool_calls.resize(tsize); - return false; -} - -/** - * Parse content uses reasoning and XML-Style tool call - * TODO: Note that form.allow_toolcall_in_think is not tested yet. If anyone confirms it works, this comment can be removed. - */ -inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, const struct xml_tool_call_format & form, const std::string & start_think = "", const std::string & end_think = "") { - constexpr auto rstrip = [](std::string &s) { - s.resize(std::distance(s.begin(), std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { return !std::isspace(ch); }).base())); - }; - // Erase substring from l to r, along with additional spaces nearby - constexpr auto erase_spaces = [](auto &str, size_t l, size_t r) { - while (/* l > -1 && */ --l < str.size() && std::isspace(static_cast(str[l]))); - ++l; - while (++r < str.size() && std::isspace(static_cast(str[r]))); - if (l < r) str[l] = '\n'; - if (l + 1 < r) str[l + 1] = '\n'; - if (l != 0) l += 2; - str.erase(l, r - l); - return l; - }; - constexpr auto trim_suffix = [](std::string &content, std::initializer_list list) { - auto best_match = content.size(); - for (auto pattern: list) { - if (pattern.size() == 0) continue; - for (auto match_idx = content.size() - std::min(pattern.size(), content.size()); content.size() > match_idx; match_idx++) { - auto match_len = content.size() - match_idx; - if (content.compare(match_idx, match_len, pattern.data(), match_len) == 0 && best_match > match_idx) { - best_match = match_idx; - } - } - } - if (content.size() > best_match) { - content.erase(best_match); - } - }; - const auto trim_potential_partial_word = [&start_think, &end_think, &form, trim_suffix](std::string &content) { - return trim_suffix(content, { - start_think, end_think, form.scope_start, form.tool_start, form.tool_sep, form.key_start, - form.key_val_sep, form.key_val_sep2 ? form.key_val_sep2->c_str() : "", - form.val_end, form.last_val_end ? form.last_val_end->c_str() : "", - form.tool_end, form.last_tool_end ? form.last_tool_end->c_str() : "", - form.scope_end - }); - }; - - - // Trim leading spaces without affecting keyword matching - static const common_regex spaces_regex("\\s*"); - { - auto tc = builder.consume_regex(spaces_regex); - auto spaces = builder.str(tc.groups[0]); - auto s1 = spaces.size(); - trim_potential_partial_word(spaces); - auto s2 = spaces.size(); - builder.move_to(builder.pos() - (s1 - s2)); - } - - // Parse content - bool reasoning_unclosed = builder.syntax().thinking_forced_open; - std::string unclosed_reasoning_content(""); - for (;;) { - auto tc = try_find_2_literal_splited_by_spaces(builder, form.scope_start, form.tool_start); - std::string content; - std::string tool_call_start; - - if (tc) { - content = std::move(tc->prelude); - tool_call_start = builder.str(tc->groups[0]); - LOG_DBG("Matched tool start: %s\n", gbnf_format_literal(tool_call_start).c_str()); - } else { - content = builder.consume_rest(); - utf8_truncate_safe_resize(content); - } - - // Handle unclosed think block - if (reasoning_unclosed) { - if (auto pos = content.find(end_think); pos == std::string::npos && builder.pos() != builder.input().size()) { - unclosed_reasoning_content += content; - if (!(form.allow_toolcall_in_think && tc)) { - unclosed_reasoning_content += tool_call_start; - continue; - } - } else { - reasoning_unclosed = false; - std::string reasoning_content; - if (pos == std::string::npos) { - reasoning_content = std::move(content); - } else { - reasoning_content = content.substr(0, pos); - content.erase(0, pos + end_think.size()); - } - if (builder.pos() == builder.input().size() && all_space(content)) { - rstrip(reasoning_content); - trim_potential_partial_word(reasoning_content); - rstrip(reasoning_content); - if (reasoning_content.empty()) { - rstrip(unclosed_reasoning_content); - trim_potential_partial_word(unclosed_reasoning_content); - rstrip(unclosed_reasoning_content); - if (unclosed_reasoning_content.empty()) continue; - } - } - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) { - builder.add_content(start_think); - builder.add_content(unclosed_reasoning_content); - builder.add_content(reasoning_content); - if (builder.pos() != builder.input().size() || !all_space(content)) - builder.add_content(end_think); - } else { - builder.add_reasoning_content(unclosed_reasoning_content); - builder.add_reasoning_content(reasoning_content); - } - unclosed_reasoning_content.clear(); - } - } - - // Handle multiple think block - bool toolcall_in_think = false; - for (auto think_start = content.find(start_think); think_start != std::string::npos; think_start = content.find(start_think, think_start)) { - if (auto think_end = content.find(end_think, think_start + start_think.size()); think_end != std::string::npos) { - if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) { - auto reasoning_content = content.substr(think_start + start_think.size(), think_end - think_start - start_think.size()); - builder.add_reasoning_content(reasoning_content); - think_start = erase_spaces(content, think_start, think_end + end_think.size() - 1); - } else { - think_start = think_end + end_think.size() - 1; - } - } else { - // This start is in thinking block, skip this tool call - // This start is in thinking block - if (form.allow_toolcall_in_think) { - unclosed_reasoning_content = content.substr(think_start + start_think.size()); - } else { - unclosed_reasoning_content = content.substr(think_start + start_think.size()) + tool_call_start; - } - reasoning_unclosed = true; - content.resize(think_start); - toolcall_in_think = true; - } - } - - if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) { - rstrip(content); - // Handle unclosed token from content: delete all token - if (auto pos = content.rfind(end_think); pos != std::string::npos) { - while (pos != std::string::npos) { - pos = erase_spaces(content, pos, pos + end_think.size() - 1); - pos = content.rfind(end_think, pos); - } - } - // Strip if needed - if (content.size() > 0 && std::isspace(static_cast(content[0]))) { - content = string_strip(content); - } - } - - // remove potential partial suffix - if (builder.pos() == builder.input().size() && builder.is_partial()) { - if (unclosed_reasoning_content.empty()) { - rstrip(content); - trim_potential_partial_word(content); - rstrip(content); - } else { - rstrip(unclosed_reasoning_content); - trim_potential_partial_word(unclosed_reasoning_content); - rstrip(unclosed_reasoning_content); - } - } - - // consume unclosed_reasoning_content if allow_toolcall_in_think is set - if (form.allow_toolcall_in_think && !unclosed_reasoning_content.empty()) { - if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content) { - builder.add_reasoning_content(unclosed_reasoning_content); - } else { - if (content.empty()) { - content = start_think + unclosed_reasoning_content; - } else { - content += "\n\n" + start_think; - content += unclosed_reasoning_content; - } - } - unclosed_reasoning_content.clear(); - } - - // Add content - if (!content.empty()) { - // If there are multiple content blocks - if (builder.syntax().reasoning_format != COMMON_REASONING_FORMAT_NONE && !builder.syntax().reasoning_in_content && builder.result().content.size() != 0) { - builder.add_content("\n\n"); - } - builder.add_content(content); - } - - // This start is in thinking block and toolcall_in_think not set, skip this tool call - if (toolcall_in_think && !form.allow_toolcall_in_think) { - continue; - } - - // There is no tool call and all content is parsed - if (!tc) { - GGML_ASSERT(builder.pos() == builder.input().size()); - GGML_ASSERT(unclosed_reasoning_content.empty()); - if (!form.allow_toolcall_in_think) GGML_ASSERT(!reasoning_unclosed); - break; - } - - builder.move_to(tc->groups[0].begin); - if (builder.try_consume_xml_tool_calls(form)) { - auto end_of_tool = builder.pos(); - builder.consume_spaces(); - if (builder.pos() != builder.input().size()) { - builder.move_to(end_of_tool); - if (!builder.result().content.empty()) { - builder.add_content("\n\n"); - } - } - } else { - static const common_regex next_char_regex("."); - auto c = builder.str(builder.consume_regex(next_char_regex).groups[0]); - rstrip(c); - builder.add_content(c); - } - } -} - -/** - * Parse content uses reasoning and XML-Style tool call - */ -void common_chat_msg_parser::consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think, const std::string & end_think) { - parse_msg_with_xml_tool_calls(*this, form, start_think, end_think); -} diff --git a/common/chat-parser-xml-toolcall.h b/common/chat-parser-xml-toolcall.h deleted file mode 100644 index b309fb66705..00000000000 --- a/common/chat-parser-xml-toolcall.h +++ /dev/null @@ -1,45 +0,0 @@ -#pragma once - -#include "chat.h" - -#include - -#include -#include -#include - - -// Sample config: -// MiniMax-M2 (left): \n\nvalue\n...\n... -// GLM 4.5 (right): function_name\nkey\nvalue\n -struct xml_tool_call_format { - std::string scope_start; // \n // \n // can be empty - std::string tool_start; // - std::string tool_sep; // \">\n // \n // can be empty only for parse_xml_tool_calls - std::string key_start; // - std::string key_val_sep; // \"> // \n - std::string val_end; // \n // \n - std::string tool_end; // \n // \n - std::string scope_end; // // // can be empty - // Set this if there can be dynamic spaces inside key_val_sep. - // e.g. key_val_sep= key_val_sep2= for GLM4.5 - std::optional key_val_sep2 = std::nullopt; - // Set true if argval should only be raw string. e.g. Hello "world" hi - // Set false if argval should only be json string. e.g. "Hello \"world\" hi" - // Defaults to std::nullopt, both will be allowed. - std::optional raw_argval = std::nullopt; - std::optional last_val_end = std::nullopt; - std::optional last_tool_end = std::nullopt; - bool trim_raw_argval = false; - bool allow_toolcall_in_think = false; -}; - -// make a GBNF that accept any strings except those containing any of the forbidden strings. -std::string make_gbnf_excluding(std::vector forbids); - -/** - * Build grammar for xml-style tool call - * form.scope_start and form.scope_end can be empty. - * Requires data.format for model-specific hacks. - */ -void build_grammar_xml_tool_call(common_chat_params & data, const nlohmann::ordered_json & tools, const struct xml_tool_call_format & form); diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp deleted file mode 100644 index 060578f0b70..00000000000 --- a/common/chat-parser.cpp +++ /dev/null @@ -1,1649 +0,0 @@ -#include "chat-parser.h" -#include "chat-peg-parser.h" -#include "common.h" -#include "log.h" -#include "peg-parser.h" -#include "regex-partial.h" - -#include -#include -#include -#include -#include -#include -#include - -using json = nlohmann::ordered_json; - -static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder, - const common_regex & prefix, - size_t rstrip_prefix = 0) { - static const std::vector> args_paths = { { "arguments" } }; - if (auto res = builder.try_find_regex(prefix)) { - builder.move_back(rstrip_prefix); - auto tool_calls = builder.consume_json_with_dumped_args(args_paths); - if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call array"); - } - } else { - builder.add_content(builder.consume_rest()); - } -} - -static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) { - std::string arguments; - if (builder.is_partial()) { - arguments = (json{ - { "code", code + builder.healing_marker() } - }) - .dump(); - auto idx = arguments.find(builder.healing_marker()); - if (idx != std::string::npos) { - arguments.resize(idx); - } - } else { - arguments = (json{ - { "code", code } - }) - .dump(); - } - return arguments; -} - -/** - * Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between. - * Aggregates the prefix, suffix and in-between text into the content. - */ -static void parse_json_tool_calls( - common_chat_msg_parser & builder, - const std::optional & block_open, - const std::optional & function_regex_start_only, - const std::optional & function_regex, - const common_regex & close_regex, - const std::optional & block_close, - bool allow_raw_python = false, - const std::function & get_function_name = - nullptr) { - auto parse_tool_calls = [&]() { - size_t from = std::string::npos; - auto first = true; - while (true) { - auto start_pos = builder.pos(); - auto res = function_regex_start_only && first ? builder.try_consume_regex(*function_regex_start_only) : - function_regex ? builder.try_find_regex(*function_regex, from) : - std::nullopt; - - if (res) { - std::string name; - if (get_function_name) { - name = get_function_name(*res); - } else { - GGML_ASSERT(res->groups.size() == 2); - name = builder.str(res->groups[1]); - } - first = false; - if (name.empty()) { - // get_function_name signalled us that we should skip this match and treat it as content. - from = res->groups[0].begin + 1; - continue; - } - from = std::string::npos; - - auto maybe_raw_python = name == "python" && allow_raw_python; - if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) { - if (auto arguments = builder.try_consume_json_with_dumped_args({ {} })) { - if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_regex(close_regex); - } - continue; - } - if (maybe_raw_python) { - auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); - if (!builder.add_tool_call(name, "", arguments)) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - return; - } - throw common_chat_msg_partial_exception("incomplete tool call"); - } else { - builder.move_to(start_pos); - } - break; - } - if (block_close) { - builder.consume_regex(*block_close); - } - builder.consume_spaces(); - builder.add_content(builder.consume_rest()); - }; - if (block_open) { - if (auto res = builder.try_find_regex(*block_open)) { - parse_tool_calls(); - } else { - builder.add_content(builder.consume_rest()); - } - } else { - parse_tool_calls(); - } -} - -common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax) - : input_(input), is_partial_(is_partial), syntax_(syntax) -{ - result_.role = "assistant"; - - while (true) { - std::string id = std::to_string(std::rand()); - if (input.find(id) == std::string::npos) { - healing_marker_ = id; - break; - } - } -} - -std::string common_chat_msg_parser::str(const common_string_range & rng) const { - GGML_ASSERT(rng.begin <= rng.end); - return input_.substr(rng.begin, rng.end - rng.begin); -} - -void common_chat_msg_parser::add_content(const std::string &content) { - result_.content += content; -} - -void common_chat_msg_parser::add_reasoning_content(const std::string &reasoning_content) { - result_.reasoning_content += reasoning_content; -} - -bool common_chat_msg_parser::add_tool_call(const std::string & name, const std::string & id, const std::string & arguments) { - if (name.empty()) { - return false; - } - - common_chat_tool_call tool_call; - tool_call.name = name; - tool_call.arguments = arguments; - tool_call.id = id; - - // LOG_DBG("Tool call arguments:\n\traw: %s\n\tresult: %s\n", arguments.c_str(), tool_call.arguments.c_str()); - result_.tool_calls.emplace_back(tool_call); - - return true; -} -bool common_chat_msg_parser::add_tool_call(const json & tool_call) { - std::string name = tool_call.contains("name") ? tool_call.at("name") : ""; - std::string id = tool_call.contains("id") ? tool_call.at("id") : ""; - std::string arguments = ""; - if (tool_call.contains("arguments")) { - if (tool_call.at("arguments").is_object()) { - arguments = tool_call.at("arguments").dump(); - } else { - arguments = tool_call.at("arguments"); - } - } - - return add_tool_call(name, id, arguments); -} - -bool common_chat_msg_parser::add_tool_calls(const json & arr) { - for (const auto & item : arr) { - if (!add_tool_call(item)) { - return false; - } - } - return true; -} - -bool common_chat_msg_parser::add_tool_call_short_form(const json & tool_call) { - if (!tool_call.is_object() || tool_call.size() != 1) { - return false; - } - - // Get the tool name (the single key in the object) - auto it = tool_call.begin(); - std::string name = it.key(); - - if (name.empty()) { - return false; - } - - // Get the arguments (the nested object) - const json & args_json = it.value(); - std::string arguments = ""; - - if (args_json.is_object()) { - arguments = args_json.dump(); - } else if (args_json.is_string()) { - arguments = args_json; - } else if (!args_json.is_null()) { - // For other types, convert to string representation - arguments = args_json.dump(); - } - - return add_tool_call(name, "", arguments); -} -void common_chat_msg_parser::finish() { - if (!is_partial_ && pos_ != input_.size()) { - throw std::runtime_error("Unexpected content at end of input");// + input_.substr(pos_)); - } -} - -bool common_chat_msg_parser::consume_spaces() { - const auto length = input_.size(); - auto consumed = false; - while (pos_ < length && std::isspace(input_[pos_])) { - ++pos_; - consumed = true; - } - return consumed; -} - -bool common_chat_msg_parser::try_consume_literal(const std::string & literal) { - auto pos = pos_; - for (auto i = 0u; i < literal.size(); ++i) { - if (pos >= input_.size()) { - return false; - } - if (input_[pos] != literal[i]) { - return false; - } - ++pos; - } - pos_ = pos; - return true; -} - -std::optional common_chat_msg_parser::try_find_literal(const std::string & literal) { - auto idx = input_.find(literal, pos_); - if (idx != std::string::npos) { - find_regex_result res; - res.prelude = input_.substr(pos_, idx - pos_); - auto end = idx + literal.size(); - res.groups.emplace_back(common_string_range{idx, end}); - move_to(end); - return res; - } - if (is_partial_) { - idx = string_find_partial_stop(input_, literal); - if (idx != std::string::npos && idx >= pos_) { - find_regex_result res; - res.prelude = input_.substr(pos_, idx - pos_); - auto end = input_.size(); - res.groups.emplace_back(common_string_range{idx, end}); - move_to(end); - return res; - } - } - return std::nullopt; -} - -void common_chat_msg_parser::consume_literal(const std::string & literal) { - if (!try_consume_literal(literal)) { - throw common_chat_msg_partial_exception(literal); - } -} - -bool common_chat_msg_parser::try_parse_reasoning(const std::string & start_think, const std::string & end_think) { - std::string pending_reasoning_prefix; - - if (syntax_.reasoning_format == COMMON_REASONING_FORMAT_NONE) { - return false; - } - - auto set_reasoning_prefix = [&](size_t prefix_pos) { - if (!syntax_.thinking_forced_open || syntax_.reasoning_in_content) { - return; - } - if (prefix_pos + start_think.size() > input_.size()) { - pending_reasoning_prefix.clear(); - return; - } - // Capture the exact literal that opened the reasoning section so we can - // surface it back to callers. This ensures formats that force the - // reasoning tag open (e.g. DeepSeek R1) retain their original prefix - // instead of dropping it during parsing. - pending_reasoning_prefix = input_.substr(prefix_pos, start_think.size()); - }; - - auto handle_reasoning = [&](const std::string & reasoning, bool closed) { - auto stripped_reasoning = string_strip(reasoning); - if (stripped_reasoning.empty()) { - return; - } - if (syntax_.reasoning_in_content) { - add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "" : start_think); - add_content(stripped_reasoning); - if (closed) { - add_content(syntax_.reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK ? "" : end_think); - } - } else { - if (!pending_reasoning_prefix.empty()) { - add_reasoning_content(pending_reasoning_prefix); - pending_reasoning_prefix.clear(); - } - add_reasoning_content(stripped_reasoning); - } - }; - - const size_t saved_pos = pos_; - const size_t saved_content_size = result_.content.size(); - const size_t saved_reasoning_size = result_.reasoning_content.size(); - - auto restore_state = [&]() { - move_to(saved_pos); - result_.content.resize(saved_content_size); - result_.reasoning_content.resize(saved_reasoning_size); - }; - - // Allow leading whitespace to be preserved as content when reasoning is present at the start - size_t cursor = pos_; - size_t whitespace_end = cursor; - while (whitespace_end < input_.size() && std::isspace(static_cast(input_[whitespace_end]))) { - ++whitespace_end; - } - - if (whitespace_end >= input_.size()) { - restore_state(); - if (syntax_.thinking_forced_open) { - auto rest = input_.substr(saved_pos); - if (!rest.empty()) { - handle_reasoning(rest, /* closed */ !is_partial()); - } - move_to(input_.size()); - return true; - } - return false; - } - - cursor = whitespace_end; - const size_t remaining = input_.size() - cursor; - const size_t start_prefix = std::min(start_think.size(), remaining); - const bool has_start_tag = input_.compare(cursor, start_prefix, start_think, 0, start_prefix) == 0; - - if (has_start_tag && start_prefix < start_think.size()) { - move_to(input_.size()); - return true; - } - - if (has_start_tag) { - if (whitespace_end > pos_) { - add_content(input_.substr(pos_, whitespace_end - pos_)); - } - set_reasoning_prefix(cursor); - cursor += start_think.size(); - } else if (syntax_.thinking_forced_open) { - cursor = whitespace_end; - } else { - restore_state(); - return false; - } - while (true) { - if (cursor >= input_.size()) { - move_to(input_.size()); - return true; - } - - size_t end_pos = input_.find(end_think, cursor); - if (end_pos == std::string::npos) { - std::string_view remaining_view(input_.data() + cursor, input_.size() - cursor); - size_t partial_off = string_find_partial_stop(remaining_view, end_think); - size_t reasoning_end = partial_off == std::string::npos ? input_.size() : cursor + partial_off; - if (reasoning_end > cursor) { - handle_reasoning(input_.substr(cursor, reasoning_end - cursor), /* closed */ partial_off == std::string::npos && !is_partial()); - } - move_to(input_.size()); - return true; - } - - if (end_pos > cursor) { - handle_reasoning(input_.substr(cursor, end_pos - cursor), /* closed */ true); - } else { - handle_reasoning("", /* closed */ true); - } - - cursor = end_pos + end_think.size(); - - while (cursor < input_.size() && std::isspace(static_cast(input_[cursor]))) { - ++cursor; - } - - const size_t next_remaining = input_.size() - cursor; - if (next_remaining == 0) { - move_to(cursor); - return true; - } - - const size_t next_prefix = std::min(start_think.size(), next_remaining); - if (input_.compare(cursor, next_prefix, start_think, 0, next_prefix) == 0) { - if (next_prefix < start_think.size()) { - move_to(input_.size()); - return true; - } - set_reasoning_prefix(cursor); - cursor += start_think.size(); - continue; - } - - move_to(cursor); - return true; - } -} - -std::string common_chat_msg_parser::consume_rest() { - auto rest = input_.substr(pos_); - pos_ = input_.size(); - return rest; -} - -// Tries to find the regex, consumes it (pos right after it) and gives the prelude (right before it) and the groups to the callback. -std::optional common_chat_msg_parser::try_find_regex(const common_regex & regex, size_t from, bool add_prelude_to_content) { - auto m = regex.search(input_, from == std::string::npos ? pos_ : from); - if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) { - return std::nullopt; - } - auto prelude = input_.substr(pos_, m.groups[0].begin - pos_); - pos_ = m.groups[0].end; - - if (add_prelude_to_content) { - add_content(prelude); - } - if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) { - if (is_partial()) { - throw common_chat_msg_partial_exception(regex.str()); - } - return std::nullopt; - } - return find_regex_result{prelude, m.groups}; -} - -common_chat_msg_parser::find_regex_result common_chat_msg_parser::consume_regex(const common_regex & regex) { - if (auto result = try_consume_regex(regex)) { - return *result; - } - throw common_chat_msg_partial_exception(regex.str()); -} - -std::optional common_chat_msg_parser::try_consume_regex(const common_regex & regex) { - auto m = regex.search(input_, pos_); - if (m.type == COMMON_REGEX_MATCH_TYPE_NONE) { - return std::nullopt; - } - if (m.type == COMMON_REGEX_MATCH_TYPE_PARTIAL) { - if (is_partial()) { - throw common_chat_msg_partial_exception(regex.str()); - } - return std::nullopt; - } - if (m.groups[0].begin != pos_) { - // Didn't match at the current position. - return std::nullopt; - } - pos_ = m.groups[0].end; - - return find_regex_result { - /* .prelude = */ "", - m.groups, - }; -} - -std::optional common_chat_msg_parser::try_consume_json() { - auto it = input_.cbegin() + pos_; - const auto end = input_.cend(); - common_json result; - if (!common_json_parse(it, end, healing_marker_, result)) { - return std::nullopt; - } - pos_ = std::distance(input_.cbegin(), it); - if (result.healing_marker.marker.empty()) { - // No healing marker, just return the parsed json - return result; - } - if (!is_partial()) { - throw common_chat_msg_partial_exception("JSON"); - } - return result; -} - -common_json common_chat_msg_parser::consume_json() { - if (auto result = try_consume_json()) { - return *result; - } - throw common_chat_msg_partial_exception("JSON"); -} - -common_chat_msg_parser::consume_json_result common_chat_msg_parser::consume_json_with_dumped_args( - const std::vector> & args_paths, - const std::vector> & content_paths -) { - if (auto result = try_consume_json_with_dumped_args(args_paths, content_paths)) { - return *result; - } - throw common_chat_msg_partial_exception("JSON"); -} - -std::optional common_chat_msg_parser::try_consume_json_with_dumped_args( - const std::vector> & args_paths, - const std::vector> & content_paths -) { - auto partial = try_consume_json(); - if (!partial) { - return std::nullopt; - } - auto is_arguments_path = [&](const std::vector & path) { - return std::find(args_paths.begin(), args_paths.end(), path) != args_paths.end(); - }; - auto is_content_path = [&](const std::vector & path) { - return std::find(content_paths.begin(), content_paths.end(), path) != content_paths.end(); - }; - - if (partial->healing_marker.marker.empty()) { - if (args_paths.empty()) { - // No arguments to dump, and JSON was parsed fully. - return consume_json_result { - partial->json, - /* .is_partial = */ false, - }; - } - if (is_arguments_path({})) { - // Entire JSON is the arguments and was parsed fully. - return consume_json_result { - partial->json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true), - /* .is_partial = */ false, - }; - } - } - - LOG_DBG("Parsed partial JSON: %s (json_healing_marker: %s)\n", partial->json.dump().c_str(), partial->healing_marker.json_dump_marker.c_str()); - - auto found_healing_marker = false; - std::vector path; - std::function remove_unsupported_healings_and_dump_args = [&](const json & j) -> json { - if (is_arguments_path(path)) { - auto arguments = j.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true); - if (is_partial() && !partial->healing_marker.marker.empty()) { - auto idx = arguments.find(partial->healing_marker.json_dump_marker); - if (idx != std::string::npos) { - arguments.resize(idx); - found_healing_marker = true; - } - if (arguments == "\"") { - // This happens because of completing `:"$magic` after `"arguments"` - arguments = ""; - } - } - return arguments; - } - if (is_content_path(path)) { - if (!j.is_string()) { - throw std::runtime_error("Content path must be a string"); - } - std::string str = j; - auto idx = str.find(partial->healing_marker.marker); // not using json_dump_marker as we're inside a string - if (idx != std::string::npos) { - str.resize(idx); - found_healing_marker = true; - } - return str; - } - if (j.is_object()) { - auto obj = json::object(); - for (const auto & p : j.items()) { - const auto & key = p.key(); - const auto & value = p.value(); - const std::string key_str = key; // NOLINT - auto idx = key_str.find(healing_marker_); - if (idx != std::string::npos) { - found_healing_marker = true; - break; - } - path.push_back(key_str); - if (value.is_string()) { - const std::string value_str = value; - if (value_str.find(healing_marker_) != std::string::npos) { - found_healing_marker = true; - if (is_content_path(path)) { - if (partial->healing_marker.marker == partial->healing_marker.json_dump_marker) { - // The healing occurred inside the string: good. Otherwise we just ditch the entire key/value pair. - obj[key] = remove_unsupported_healings_and_dump_args(value); - } - } - break; - } - obj[key] = value; - } else { - obj[key] = remove_unsupported_healings_and_dump_args(value); - } - path.pop_back(); - } - return obj; - } - if (j.is_array()) { - auto arr = json::array(); - for (const auto & value : j) { - if (value.is_string()) { - std::string str = value; - auto idx = str.find(healing_marker_); - if (idx != std::string::npos) { - // Don't heal array values that aren't in the arguments. - found_healing_marker = true; - break; - } - } - arr.push_back(remove_unsupported_healings_and_dump_args(value)); - } - return arr; - } - return j; - }; - - auto cleaned = remove_unsupported_healings_and_dump_args(partial->json); - LOG_DBG("Cleaned up JSON %s to %s (json_healing_marker : '%s')\n", partial->json.dump().c_str(), cleaned.dump().c_str(), partial->healing_marker.json_dump_marker.c_str()); - return consume_json_result { - cleaned, - /* .is_partial = */ found_healing_marker, - }; -} - -void common_chat_msg_parser::clear_tools() { - result_.tool_calls.clear(); -} - -/** - * All common_chat_parse_* moved from chat.cpp to chat-parser.cpp below - * to reduce incremental compile time for parser changes. - */ -static void common_chat_parse_generic(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - static const std::vector> content_paths = { - {"response"}, - }; - static const std::vector> args_paths = { - {"tool_call", "arguments"}, - {"tool_calls", "arguments"}, - }; - auto data = builder.consume_json_with_dumped_args(args_paths, content_paths); - if (data.value.contains("tool_calls")) { - if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool calls"); - } - } else if (data.value.contains("tool_call")) { - if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } else if (data.value.contains("response")) { - const auto & response = data.value.at("response"); - builder.add_content(response.is_string() ? response.template get() : response.dump(2)); - if (data.is_partial) { - throw common_chat_msg_partial_exception("incomplete response"); - } - } else { - throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON"); - } -} - -static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex prefix(regex_escape("[TOOL_CALLS]")); - parse_prefixed_json_tool_call_array(builder, prefix); -} - -static void common_chat_parse_magistral(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("[THINK]", "[/THINK]"); - - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex prefix(regex_escape("[TOOL_CALLS]")); - parse_prefixed_json_tool_call_array(builder, prefix); -} - -static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); - - static const common_regex start_action_regex("<\\|START_ACTION\\|>"); - static const common_regex end_action_regex("<\\|END_ACTION\\|>"); - static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); - static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); - - if (auto res = builder.try_find_regex(start_action_regex)) { - // If we didn't extract thoughts, prelude includes them. - auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}}); - for (const auto & tool_call : tool_calls.value) { - std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; - std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; - std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : ""; - if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - if (tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_regex(end_action_regex); - } else if (auto res = builder.try_find_regex(start_response_regex)) { - if (!builder.try_find_regex(end_response_regex)) { - builder.add_content(builder.consume_rest()); - throw common_chat_msg_partial_exception(end_response_regex.str()); - } - } else { - builder.add_content(builder.consume_rest()); - } -} - -static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { - builder.try_parse_reasoning("", ""); - - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex function_regex( - "\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: "); - static const common_regex close_regex("\\}\\s*"); - - static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\("); - static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*"); - - if (with_builtin_tools) { - static const common_regex builtin_call_regex("<\\|python_tag\\|>"); - if (auto res = builder.try_find_regex(builtin_call_regex)) { - auto fun_res = builder.consume_regex(function_name_regex); - auto function_name = builder.str(fun_res.groups[1]); - - common_healing_marker healing_marker; - json args = json::object(); - while (true) { - if (auto arg_res = builder.try_consume_regex(arg_name_regex)) { - auto arg_name = builder.str(arg_res->groups[1]); - auto partial = builder.consume_json(); - args[arg_name] = partial.json; - healing_marker.marker = partial.healing_marker.marker; - healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker; - builder.consume_spaces(); - if (!builder.try_consume_literal(",")) { - break; - } - } else { - break; - } - } - builder.consume_literal(")"); - builder.consume_spaces(); - - auto arguments = args.dump(); - if (!builder.add_tool_call(function_name, "", arguments)) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - return; - } - } - parse_json_tool_calls( - builder, - /* block_open= */ std::nullopt, - /* function_regex_start_only= */ function_regex, - /* function_regex= */ std::nullopt, - close_regex, - std::nullopt); - -} - -static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); - static const common_regex tool_calls_end("<|tool▁calls▁end|>"); - static const common_regex function_regex("(?:<|tool▁call▁begin|>)?function<|tool▁sep|>([^\n]+)\n```json\n"); - static const common_regex close_regex("```[\\s\\r\\n]*<|tool▁call▁end|>"); - - parse_json_tool_calls( - builder, - /* block_open= */ tool_calls_begin, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - tool_calls_end); -} - -static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) { - static const common_regex function_regex("(?:<|tool▁call▁begin|>)?([^\\n<]+)(?:<|tool▁sep|>)"); - - static const common_regex close_regex("(?:[\\s]*)?<|tool▁call▁end|>"); - static const common_regex tool_calls_begin("(?:<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)"); - static const common_regex tool_calls_end("<|tool▁calls▁end|>"); - - if (!builder.syntax().parse_tool_calls) { - LOG_DBG("%s: not parse_tool_calls\n", __func__); - builder.add_content(builder.consume_rest()); - return; - } - - LOG_DBG("%s: parse_tool_calls\n", __func__); - - parse_json_tool_calls( - builder, - /* block_open= */ tool_calls_begin, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - tool_calls_end); -} - -static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) { - // DeepSeek V3.1 outputs reasoning content between "" and "" tags, followed by regular content - // First try to parse using the standard reasoning parsing method - LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str()); - - auto start_pos = builder.pos(); - auto found_end_think = builder.try_find_literal(""); - builder.move_to(start_pos); - - if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) { - LOG_DBG("%s: no end_think, not partial, adding content\n", __func__); - common_chat_parse_deepseek_v3_1_content(builder); - } else if (builder.try_parse_reasoning("", "")) { - // If reasoning was parsed successfully, the remaining content is regular content - LOG_DBG("%s: parsed reasoning, adding content\n", __func__); - // <|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>NAME\n```json\nJSON\n```<|tool▁call▁end|><|tool▁calls▁end|> - common_chat_parse_deepseek_v3_1_content(builder); - } else { - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) { - LOG_DBG("%s: reasoning_format none, adding content\n", __func__); - common_chat_parse_deepseek_v3_1_content(builder); - return; - } - // If no reasoning tags found, check if we should treat everything as reasoning - if (builder.syntax().thinking_forced_open) { - // If thinking is forced open but no tags found, treat everything as reasoning - LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__); - builder.add_reasoning_content(builder.consume_rest()); - } else { - LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__); - // <|tool▁call▁begin|>NAME<|tool▁sep|>JSON<|tool▁call▁end|> - common_chat_parse_deepseek_v3_1_content(builder); - } - } -} - -static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* form.key_start = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - -static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = "<|tool_calls_section_begin|>"; - form.tool_start = "<|tool_call_begin|>"; - form.tool_sep = "<|tool_call_argument_begin|>{"; - form.key_start = "\""; - form.key_val_sep = "\":"; - form.val_end = ","; - form.tool_end = "}<|tool_call_end|>"; - form.scope_end = "<|tool_calls_section_end|>"; - form.raw_argval = false; - form.last_val_end = ""; - form.allow_toolcall_in_think = true; - return form; - })(); - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - -static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = "["; - form.tool_start = "{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}, "; - form.scope_end = "]"; - form.raw_argval = false; - form.last_val_end = ""; - form.last_tool_end = "}"; - return form; - })(); - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - -static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = ""; - form.tool_start = "\n{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}\n"; - form.scope_end = ""; - form.raw_argval = false; - form.last_val_end = ""; - return form; - })(); - builder.consume_reasoning_with_xml_tool_calls(form); -} - -static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) { - static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))"; - static const std::string recipient("(?: to=functions\\.([^<\\s]+))"); - - static const common_regex start_regex("<\\|start\\|>assistant"); - static const common_regex analysis_regex("<\\|channel\\|>analysis"); - static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?"); - static const common_regex preamble_regex("<\\|channel\\|>commentary"); - static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?"); - static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?"); - - auto consume_end = [&](bool include_end = false) { - if (auto res = builder.try_find_literal("<|end|>")) { - return res->prelude + (include_end ? builder.str(res->groups[0]) : ""); - } - return builder.consume_rest(); - }; - - auto handle_tool_call = [&](const std::string & name) { - if (auto args = builder.try_consume_json_with_dumped_args({{}})) { - if (builder.syntax().parse_tool_calls) { - if (!builder.add_tool_call(name, "", args->value) || args->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } else if (args->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - }; - - auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional { - auto match = regex.search(input, 0, true); - if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) { - return match; - } - return std::nullopt; - }; - - do { - auto header_start_pos = builder.pos(); - auto content_start = builder.try_find_literal("<|message|>"); - if (!content_start) { - throw common_chat_msg_partial_exception("incomplete header"); - } - - auto header = content_start->prelude; - - if (auto match = regex_match(tool_call1_regex, header)) { - auto group = match->groups[1]; - auto name = header.substr(group.begin, group.end - group.begin); - handle_tool_call(name); - continue; - } - - if (auto match = regex_match(tool_call2_regex, header)) { - auto group = match->groups[2]; - auto name = header.substr(group.begin, group.end - group.begin); - handle_tool_call(name); - continue; - } - - if (regex_match(analysis_regex, header)) { - builder.move_to(header_start_pos); - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) { - builder.add_content(consume_end(true)); - } else { - builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>"); - } - continue; - } - - if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) { - builder.add_content(consume_end()); - continue; - } - - // Possibly a malformed message, attempt to recover by rolling - // back to pick up the next <|start|> - LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str()); - builder.move_to(header_start_pos); - } while (builder.try_find_regex(start_regex, std::string::npos, false)); - - auto remaining = builder.consume_rest(); - if (!remaining.empty()) { - LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str()); - } -} - -static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* form.tool_sep = */ "", - /* form.key_start = */ "", - /* form.key_val_sep = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - /* form.key_val_sep2 = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - -static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - static const common_regex prefix(regex_escape(" functools[")); - parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1); -} - -static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) { - static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))"); - static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))"); - static const common_regex close_regex(R"(\s*)"); - - parse_json_tool_calls( - builder, - std::nullopt, - function_regex_start_only, - function_regex, - close_regex, - std::nullopt, - /* allow_raw_python= */ true, - /* get_function_name= */ [&](const auto & res) -> std::string { - auto at_start = res.groups[0].begin == 0; - auto name = builder.str(res.groups[1]); - if (!name.empty() && name.back() == '{') { - // Unconsume the opening brace '{' to ensure the JSON parsing goes well. - builder.move_back(1); - } - auto idx = name.find_last_not_of("\n{"); - name = name.substr(0, idx + 1); - if (at_start && name == "all") { - return ""; - } - return name; - }); -} - -static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - // This version of Functionary still supports the llama 3.1 tool call format for the python tool. - static const common_regex python_tag_regex(regex_escape("<|python_tag|>")); - - static const common_regex function_regex(R"()"); - static const common_regex close_regex(R"()"); - - parse_json_tool_calls( - builder, - /* block_open= */ std::nullopt, - /* function_regex_start_only= */ std::nullopt, - function_regex, - close_regex, - std::nullopt); - - if (auto res = builder.try_find_regex(python_tag_regex)) { - auto arguments = wrap_code_as_arguments(builder, builder.consume_rest()); - builder.add_tool_call("python", "", arguments); - return; - } -} - -static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - static const common_regex open_regex( - "(?:" - "(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start) - "(" // match 2 (open_tag) - "" - "|" - "|" - "|" - "|" - "|" - "|" - "|" - ")?" - "(\\s*\\{\\s*\"name\")" // match 3 (named tool call) - ")" - "|]+)>" // match 4 (function name) - "|" // match 5 (function name again) - ); - - while (auto res = builder.try_find_regex(open_regex)) { - const auto & block_start = res->groups[1]; - std::string block_end = block_start.empty() ? "" : "```"; - - const auto & open_tag = res->groups[2]; - std::string close_tag; - - if (!res->groups[3].empty()) { - builder.move_to(res->groups[3].begin); - close_tag = open_tag.empty() ? "" : "value) || tool_call->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_spaces(); - builder.consume_literal(close_tag); - builder.consume_spaces(); - if (!block_end.empty()) { - builder.consume_literal(block_end); - builder.consume_spaces(); - } - } else { - throw common_chat_msg_partial_exception("failed to parse tool call"); - } - } else { - auto function_name = builder.str(res->groups[4]); - if (function_name.empty()) { - function_name = builder.str(res->groups[5]); - } - GGML_ASSERT(!function_name.empty()); - - close_tag = ""; - - if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) { - if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_spaces(); - builder.consume_literal(close_tag); - builder.consume_spaces(); - if (!block_end.empty()) { - builder.consume_literal(block_end); - builder.consume_spaces(); - } - } - } - } - - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse_granite(common_chat_msg_parser & builder) { - // Parse thinking tags - static const common_regex start_think_regex(regex_escape("")); - static const common_regex end_think_regex(regex_escape("")); - // Granite models output partial tokens such as "<" and "groups[0].begin); - builder.try_find_regex(end_think_regex, std::string::npos, false); - // Restore position for try_parse_reasoning() - builder.move_to(res->groups[0].begin); - } - builder.try_parse_reasoning("", ""); - - // Parse response tags - static const common_regex start_response_regex(regex_escape("")); - static const common_regex end_response_regex(regex_escape("")); - // Granite models output partial tokens such as "<" and "")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - // Expect JSON array of tool calls - if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) { - if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - } else { - builder.add_content(builder.consume_rest()); - } -} - -static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) { - // Parse thinking tags - builder.try_parse_reasoning("", ""); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // Look for tool calls - static const common_regex tool_call_regex(regex_escape("")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - // Expect JSON array of tool calls - auto tool_calls_data = builder.consume_json(); - if (tool_calls_data.json.is_array()) { - if (!builder.try_consume_literal("")) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - builder.add_tool_calls(tool_calls_data.json); - } else { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse_apertus(common_chat_msg_parser & builder) { - // Parse thinking tags - builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>"); - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // Look for tool calls - static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>")); - if (auto res = builder.try_find_regex(tool_call_regex)) { - builder.move_to(res->groups[0].end); - - auto tool_calls_data = builder.consume_json(); - if (tool_calls_data.json.is_array()) { - builder.consume_spaces(); - if (!builder.try_consume_literal("<|tools_suffix|>")) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - for (const auto & value : tool_calls_data.json) { - if (value.is_object()) { - builder.add_tool_call_short_form(value); - } - } - } else { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - builder.add_content(builder.consume_rest()); -} - - -static void common_chat_parse_lfm2(common_chat_msg_parser & builder) { - if (!builder.syntax().parse_tool_calls) { - builder.add_content(builder.consume_rest()); - return; - } - - // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|> - static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>")); - static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>")); - - // Loop through all tool calls - while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) { - builder.move_to(res->groups[0].end); - - // Parse JSON array format: [{"name": "...", "arguments": {...}}] - auto tool_calls_data = builder.consume_json(); - - // Consume end marker - builder.consume_spaces(); - if (!builder.try_consume_regex(tool_call_end_regex)) { - throw common_chat_msg_partial_exception("Expected <|tool_call_end|>"); - } - - // Process each tool call in the array - if (tool_calls_data.json.is_array()) { - for (const auto & tool_call : tool_calls_data.json) { - if (!tool_call.is_object()) { - throw common_chat_msg_partial_exception("Tool call must be an object"); - } - - if (!tool_call.contains("name")) { - throw common_chat_msg_partial_exception("Tool call missing 'name' field"); - } - - std::string function_name = tool_call.at("name"); - std::string arguments = "{}"; - - if (tool_call.contains("arguments")) { - if (tool_call.at("arguments").is_object()) { - arguments = tool_call.at("arguments").dump(); - } else if (tool_call.at("arguments").is_string()) { - arguments = tool_call.at("arguments"); - } - } - - if (!builder.add_tool_call(function_name, "", arguments)) { - throw common_chat_msg_partial_exception("Incomplete tool call"); - } - } - } else { - throw common_chat_msg_partial_exception("Expected JSON array for tool calls"); - } - - // Consume any trailing whitespace after this tool call - builder.consume_spaces(); - } - - // Consume any remaining content after all tool calls - auto remaining = builder.consume_rest(); - if (!string_strip(remaining).empty()) { - builder.add_content(remaining); - } -} - -static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "", - /* form.key_start = */ "", - /* form.val_end = */ "", - /* form.tool_end = */ "", - /* form.scope_end = */ "", - }; - builder.consume_reasoning_with_xml_tool_calls(form, "", ""); -} - -static void common_chat_parse_solar_open(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("<|think|>", "<|end|><|begin|>assistant<|content|>"); - - // TODO: Tool calling - - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse_exaone_moe_content(common_chat_msg_parser & builder) { - // 1) { "name": "...", "arguments": {...} } - // 2) { "id": "...", "type": "function", "function": { "name": "...", "arguments": {...} } } - static const common_regex tool_call_open(R"(]*>)"); - - if (!builder.syntax().parse_tool_calls) { - LOG_DBG("%s: not parse_tool_calls\n", __func__); - builder.add_content(builder.consume_rest()); - return; - } - - LOG_DBG("%s: parse_tool_calls\n", __func__); - - // Find all blocks - while (auto first = builder.try_find_regex(tool_call_open, std::string::npos, /* add_prelude_to_content= */ true)) { - builder.move_to(first->groups[0].end); - builder.consume_spaces(); - - builder.try_consume_literal("```json"); - builder.try_consume_literal("```"); - builder.consume_spaces(); - - // Consume JSON object - auto data = builder.consume_json(); - - builder.consume_spaces(); - builder.try_consume_literal("```"); - builder.consume_spaces(); - - if (!builder.try_consume_literal("")) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_spaces(); - - // Extract name and arguments - std::string name; - std::string id; - nlohmann::ordered_json arguments; - - const auto extract_args = [&](const nlohmann::ordered_json & obj) -> bool { - if (!obj.contains("name") || !obj.contains("arguments")) { - return false; - } - name = obj.at("name").get(); - arguments = obj.at("arguments"); - if (obj.contains("id") && obj.at("id").is_string()) { - id = obj.at("id").get(); - } - return true; - }; - - if (!extract_args(data.json)) { - if (data.json.contains("function") && data.json.at("function").is_object()) { - auto fn = data.json.at("function"); - extract_args(fn); - if (id.empty() && data.json.contains("id") && data.json.at("id").is_string()) { - id = data.json.at("id").get(); - } - } - } - - // If name is empty, treat the JSON object as content - if (name.empty()) { - LOG_DBG("%s: tool call missing name, treating as content\n", __func__); - builder.add_content(data.json.dump()); - continue; - } - - std::string args_str = arguments.dump(); - if (!builder.add_tool_call(name, id, args_str)) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse_exaone_moe(common_chat_msg_parser & builder) { - LOG_DBG("%s: parsing exaone_moe\n", __func__); - // EXAONE MoE outputs reasoning content between "" and "" tags, followed by regular content - // First try to parse using the standard reasoning parsing method - LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str()); - - auto start_pos = builder.pos(); - auto found_end_think = builder.try_find_literal(""); - builder.move_to(start_pos); - - if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) { - LOG_DBG("%s: no end_think, not partial, adding content\n", __func__); - common_chat_parse_exaone_moe_content(builder); - } else if (builder.try_parse_reasoning("", "")) { - // If reasoning was parsed successfully, the remaining content is regular content - LOG_DBG("%s: parsed reasoning, adding content\n", __func__); - common_chat_parse_exaone_moe_content(builder); - } else { - if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) { - LOG_DBG("%s: reasoning_format none, adding content\n", __func__); - common_chat_parse_exaone_moe_content(builder); - return; - } - // If no reasoning tags found, check if we should treat everything as reasoning - if (builder.syntax().thinking_forced_open) { - // If thinking is forced open but no tags found, treat everything as reasoning - LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__); - builder.add_reasoning_content(builder.consume_rest()); - } else { - LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__); - common_chat_parse_exaone_moe_content(builder); - } - } -} - -static void common_chat_parse_content_only(common_chat_msg_parser & builder) { - builder.try_parse_reasoning("", ""); - builder.add_content(builder.consume_rest()); -} - -static void common_chat_parse(common_chat_msg_parser & builder) { - LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str()); - - switch (builder.syntax().format) { - case COMMON_CHAT_FORMAT_CONTENT_ONLY: - common_chat_parse_content_only(builder); - break; - case COMMON_CHAT_FORMAT_GENERIC: - common_chat_parse_generic(builder); - break; - case COMMON_CHAT_FORMAT_MISTRAL_NEMO: - common_chat_parse_mistral_nemo(builder); - break; - case COMMON_CHAT_FORMAT_MAGISTRAL: - common_chat_parse_magistral(builder); - break; - case COMMON_CHAT_FORMAT_LLAMA_3_X: - common_chat_parse_llama_3_1(builder); - break; - case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: - common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true); - break; - case COMMON_CHAT_FORMAT_DEEPSEEK_R1: - common_chat_parse_deepseek_r1(builder); - break; - case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: - common_chat_parse_deepseek_v3_1(builder); - break; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: - common_chat_parse_functionary_v3_2(builder); - break; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: - common_chat_parse_functionary_v3_1_llama_3_1(builder); - break; - case COMMON_CHAT_FORMAT_HERMES_2_PRO: - common_chat_parse_hermes_2_pro(builder); - break; - case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: - common_chat_parse_firefunction_v2(builder); - break; - case COMMON_CHAT_FORMAT_COMMAND_R7B: - common_chat_parse_command_r7b(builder); - break; - case COMMON_CHAT_FORMAT_GRANITE: - common_chat_parse_granite(builder); - break; - case COMMON_CHAT_FORMAT_GPT_OSS: - common_chat_parse_gpt_oss(builder); - break; - case COMMON_CHAT_FORMAT_SEED_OSS: - common_chat_parse_seed_oss(builder); - break; - case COMMON_CHAT_FORMAT_NEMOTRON_V2: - common_chat_parse_nemotron_v2(builder); - break; - case COMMON_CHAT_FORMAT_APERTUS: - common_chat_parse_apertus(builder); - break; - case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: - common_chat_parse_lfm2(builder); - break; - case COMMON_CHAT_FORMAT_MINIMAX_M2: - common_chat_parse_minimax_m2(builder); - break; - case COMMON_CHAT_FORMAT_GLM_4_5: - common_chat_parse_glm_4_5(builder); - break; - case COMMON_CHAT_FORMAT_KIMI_K2: - common_chat_parse_kimi_k2(builder); - break; - case COMMON_CHAT_FORMAT_APRIEL_1_5: - common_chat_parse_apriel_1_5(builder); - break; - case COMMON_CHAT_FORMAT_XIAOMI_MIMO: - common_chat_parse_xiaomi_mimo(builder); - break; - case COMMON_CHAT_FORMAT_SOLAR_OPEN: - common_chat_parse_solar_open(builder); - break; - case COMMON_CHAT_FORMAT_EXAONE_MOE: - common_chat_parse_exaone_moe(builder); - break; - default: - throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); - } - builder.finish(); -} - -common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax) { - if (syntax.format == COMMON_CHAT_FORMAT_PEG_SIMPLE || - syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE || - syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) { - return common_chat_peg_parse(syntax.parser, input, is_partial, syntax); - } - common_chat_msg_parser builder(input, is_partial, syntax); - try { - common_chat_parse(builder); - } catch (const common_chat_msg_partial_exception & ex) { - LOG_DBG("Partial parse: %s\n", ex.what()); - if (!is_partial) { - builder.clear_tools(); - builder.move_to(0); - common_chat_parse_content_only(builder); - } - } - auto msg = builder.result(); - if (!is_partial) { - LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); - } - return msg; -} - -common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax) { - if (parser.empty()) { - throw std::runtime_error("Failed to parse due to missing parser definition."); - } - - LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(syntax.format), input.c_str()); - - common_peg_parse_context ctx(input, is_partial); - auto result = parser.parse(ctx); - if (result.fail()) { - throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end)); - } - - common_chat_msg msg; - msg.role = "assistant"; - - if (syntax.format == COMMON_CHAT_FORMAT_PEG_NATIVE) { - auto mapper = common_chat_peg_native_mapper(msg); - mapper.from_ast(ctx.ast, result); - } else if (syntax.format == COMMON_CHAT_FORMAT_PEG_CONSTRUCTED) { - auto mapper = common_chat_peg_constructed_mapper(msg); - mapper.from_ast(ctx.ast, result); - } else { - // Generic mapper - auto mapper = common_chat_peg_mapper(msg); - mapper.from_ast(ctx.ast, result); - } - if (!is_partial) { - LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({msg}).at(0).dump().c_str()); - } - return msg; -} diff --git a/common/chat-parser.h b/common/chat-parser.h deleted file mode 100644 index 3ed9c30a2b3..00000000000 --- a/common/chat-parser.h +++ /dev/null @@ -1,133 +0,0 @@ -#pragma once - -#include "chat.h" -#include "chat-parser-xml-toolcall.h" -#include "json-partial.h" -#include "regex-partial.h" - -#include - -#include -#include -#include - -class common_chat_msg_partial_exception : public std::runtime_error { - public: - common_chat_msg_partial_exception(const std::string & message) : std::runtime_error(message) {} -}; - -class common_chat_msg_parser { - std::string input_; - bool is_partial_; - common_chat_parser_params syntax_; // TODO: rename to params - std::string healing_marker_; - - size_t pos_ = 0; - common_chat_msg result_; - - public: - common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_parser_params & syntax); - const std::string & input() const { return input_; } - size_t pos() const { return pos_; } - const std::string & healing_marker() const { return healing_marker_; } - const bool & is_partial() const { return is_partial_; } - const common_chat_msg & result() const { return result_; } - const common_chat_parser_params & syntax() const { return syntax_; } - - void move_to(size_t pos) { - if (pos > input_.size()) { - throw std::runtime_error("Invalid position!"); - } - pos_ = pos; - } - void move_back(size_t n) { - if (pos_ < n) { - throw std::runtime_error("Can't move back that far!"); - } - pos_ -= n; - } - - // Get the substring of the input at the given range - std::string str(const common_string_range & rng) const; - - // Appends to the result.content field - void add_content(const std::string & content); - - // Appends to the result.reasoning_content field - void add_reasoning_content(const std::string & reasoning_content); - - // Adds a tool call to the result. If the tool call is too incomplete (e.g. name empty), it won't add anything. - bool add_tool_call(const std::string & name, const std::string & id, const std::string & arguments); - - // Adds a tool call using the "name", "id" and "arguments" fields of the json object - bool add_tool_call(const nlohmann::ordered_json & tool_call); - - // Adds an array of tool calls using their "name", "id" and "arguments" fields. - bool add_tool_calls(const nlohmann::ordered_json & arr); - - // Adds a tool call using the short form: { "tool_name": { "arg1": val, "arg2": val } } - bool add_tool_call_short_form(const nlohmann::ordered_json & tool_call); - - void finish(); - - bool consume_spaces(); - - void consume_literal(const std::string & literal); - - bool try_parse_reasoning(const std::string & start_think, const std::string & end_think); - - std::string consume_rest(); - - struct find_regex_result { - std::string prelude; - std::vector groups; - }; - - std::optional try_find_regex(const common_regex & regex, size_t from = std::string::npos, bool add_prelude_to_content = true); - - bool try_consume_literal(const std::string & literal); - - std::optional try_find_literal(const std::string & literal); - - find_regex_result consume_regex(const common_regex & regex); - - std::optional try_consume_regex(const common_regex & regex); - - std::optional try_consume_json(); - common_json consume_json(); - - struct consume_json_result { - nlohmann::ordered_json value; - bool is_partial; - }; - - /* - Consume (possibly partial) json and converts specific subtrees to (possibly truncated) JSON strings. - - By default, object keys can't be truncated, nor can string values (their corresponding key is removed, - e.g. `{"foo": "bar", "baz": "b` -> `{"foo": "bar"}` - - But one can allow subpaths to be kept truncated, and possibly json-dumped to truncated json strings - - with `content_paths={{"foo"}}` -> `{"foo": "b` -> {"foo": "b"}` - - with `args_paths={{"foo"}}` -> `{"foo": {"b` -> `{"foo": "{b"}` - */ - consume_json_result consume_json_with_dumped_args( - const std::vector> & args_paths = {}, - const std::vector> & content_paths = {} - ); - std::optional try_consume_json_with_dumped_args( - const std::vector> & args_paths = {}, - const std::vector> & content_paths = {} - ); - - /** - * Parse XML-Style tool call for given xml_tool_call_format. Return false for invalid syntax and get the position untouched. - * form.scope_start, form.tool_sep and form.scope_end can be empty. - */ - bool try_consume_xml_tool_calls(const struct xml_tool_call_format & form); - - // Parse content uses reasoning and XML-Style tool call - void consume_reasoning_with_xml_tool_calls(const struct xml_tool_call_format & form, const std::string & start_think = "", const std::string & end_think = ""); - - void clear_tools(); -}; diff --git a/common/chat-peg-parser.cpp b/common/chat-peg-parser.cpp index 1bcba9cd866..ef9dec5935a 100644 --- a/common/chat-peg-parser.cpp +++ b/common/chat-peg-parser.cpp @@ -1,13 +1,17 @@ #include "chat-peg-parser.h" +#include "chat-auto-parser.h" +#include "ggml.h" +#include "peg-parser.h" + #include -using json = nlohmann::json; +using json = nlohmann::ordered_json; static std::string_view trim_trailing_space(std::string_view sv, int max = -1) { int count = 0; while (!sv.empty() && std::isspace(static_cast(sv.back()))) { - if (max != -1 && count <= max) { + if (max != -1 && count >= max) { break; } sv.remove_suffix(1); @@ -16,109 +20,753 @@ static std::string_view trim_trailing_space(std::string_view sv, int max = -1) { return sv; } -void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) { - arena.visit(result, [this](const common_peg_ast_node & node) { - map(node); - }); +static std::string_view trim_leading_space(std::string_view sv, int max = -1) { + int count = 0; + while (!sv.empty() && std::isspace(static_cast(sv.front()))) { + if (max != -1 && count >= max) { + break; + } + sv.remove_prefix(1); + count++; + } + return sv; } -void common_chat_peg_mapper::map(const common_peg_ast_node & node) { - bool is_reasoning = node.tag == common_chat_peg_builder::REASONING; - bool is_content = node.tag == common_chat_peg_builder::CONTENT; +static std::string_view trim(std::string_view sv) { + return trim_trailing_space(trim_leading_space(sv, 1)); +} - if (is_reasoning) { - result.reasoning_content = std::string(trim_trailing_space(node.text)); +// Count the number of unclosed '{' braces in a JSON-like string, +// properly skipping braces inside quoted strings. +static int json_brace_depth(const std::string & s) { + int depth = 0; + bool in_string = false; + bool escaped = false; + for (char c : s) { + if (escaped) { + escaped = false; + continue; + } + if (c == '\\' && in_string) { + escaped = true; + continue; + } + if (c == '"') { + in_string = !in_string; + continue; + } + if (!in_string) { + if (c == '{') { + depth++; + } else if (c == '}') { + depth--; + } + } } + return depth; +} - if (is_content) { - result.content = std::string(trim_trailing_space(node.text)); +// JSON-escape a string and return the inner content (without surrounding quotes). +static std::string escape_json_string_inner(const std::string & s) { + std::string escaped = json(s).dump(); + if (escaped.size() >= 2 && escaped.front() == '"' && escaped.back() == '"') { + return escaped.substr(1, escaped.size() - 2); } + return escaped; } -void common_chat_peg_native_mapper::map(const common_peg_ast_node & node) { - common_chat_peg_mapper::map(node); +// Convert Python-style single-quoted strings to JSON double-quoted strings +// Only converts outer string delimiters, properly handling escape sequences: +// - {'key': 'value'} -> {"key": "value"} +// - {'code': 'print(\'hello\')'} -> {"code": "print('hello')"} +// - {'msg': 'He said "hi"'} -> {"msg": "He said \"hi\""} +static std::string normalize_quotes_to_json(const std::string & input) { + std::string result; + result.reserve(input.size() + 16); // May need extra space for escaping - bool is_tool_open = node.tag == common_chat_peg_native_builder::TOOL_OPEN; - bool is_tool_name = node.tag == common_chat_peg_native_builder::TOOL_NAME; - bool is_tool_id = node.tag == common_chat_peg_native_builder::TOOL_ID; - bool is_tool_args = node.tag == common_chat_peg_native_builder::TOOL_ARGS; + bool in_single_quoted = false; + bool in_double_quoted = false; - if (is_tool_open) { - result.tool_calls.emplace_back(); - current_tool = &result.tool_calls.back(); + for (size_t i = 0; i < input.size(); ++i) { + char c = input[i]; + + // Handle escape sequences + if (c == '\\' && i + 1 < input.size()) { + char next = input[i + 1]; + + if (in_single_quoted) { + // Inside a single-quoted string being converted to double quotes + if (next == '\'') { + // \' -> ' (escaped single quote becomes unescaped in double-quoted string) + result += '\''; + ++i; + continue; + } + if (next == '"') { + // \" stays as \" (already escaped, works in double-quoted string) + result += "\\\""; + ++i; + continue; + } + // Other escapes (\n, \\, etc.): pass through both characters + result += c; + result += next; + ++i; + continue; + } + + if (in_double_quoted) { + // Inside a double-quoted string - pass through escape sequences as-is + result += c; + result += next; + ++i; + continue; + } + + // Outside any string - just pass through the backslash + result += c; + continue; + } + + // Handle quote characters + if (c == '"') { + if (in_single_quoted) { + // Unescaped double quote inside single-quoted string -> must escape for JSON + result += "\\\""; + } else { + // Double quote as string delimiter or outside strings + in_double_quoted = !in_double_quoted; + result += c; + } + } else if (c == '\'') { + if (in_double_quoted) { + // Single quote inside double-quoted string -> pass through + result += c; + } else if (in_single_quoted) { + // Closing single quote -> convert to double quote + in_single_quoted = false; + result += '"'; + } else { + // Opening single quote -> convert to double quote + in_single_quoted = true; + result += '"'; + } + } else { + result += c; + } } - if (is_tool_id && current_tool) { - current_tool->id = std::string(trim_trailing_space(node.text)); + return result; +} + +void tag_based_peg_mapper::from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result) { + arena.visit(result, [this](const common_peg_ast_node & node) { + if (!node.tag.empty()) { + tags[node.tag] = std::string(node.text); + } + }); +} + +tagged_parse_result tagged_peg_parser::parse_and_extract(const std::string & input, bool is_partial) const { + common_peg_parse_context ctx(input, is_partial); + auto parse_result = arena.parse(ctx); + + tag_based_peg_mapper mapper; + mapper.from_ast(ctx.ast, parse_result); + + return { std::move(parse_result), std::move(mapper.tags) }; +} + +tagged_parse_result tagged_peg_parser::parse_anywhere_and_extract(const std::string & input) const { + if (input.empty()) { + return parse_and_extract(input, false); + } + for (size_t i = 0; i < input.size(); i++) { + common_peg_parse_context ctx(input, false); + ctx.debug = debug; + auto parse_result = arena.parse(ctx, i); + if (parse_result.success() || i == input.size() - 1) { + tag_based_peg_mapper mapper; + mapper.from_ast(ctx.ast, parse_result); + return { std::move(parse_result), std::move(mapper.tags) }; + } } + GGML_ABORT("Should not happen"); +} - if (is_tool_name && current_tool) { - current_tool->name = std::string(trim_trailing_space(node.text)); +tagged_peg_parser build_tagged_peg_parser( + const std::function & fn) { + common_peg_parser_builder builder; + builder.set_root(fn(builder)); + return { builder.build() }; +} + +common_peg_parser common_chat_peg_builder::tag_with_safe_content(const std::string & tag_name, + const std::string & marker, + const common_peg_parser & p) { + if (marker.empty()) { + return zero_or_more(choice({ p, rule(tag_name, content(any())) })); } + auto content_chunk = rule(tag_name, content(negate(literal(marker)) + any() + until(marker))); + return zero_or_more(choice({ p, content_chunk })); +} - if (is_tool_args && current_tool) { - current_tool->arguments = std::string(trim_trailing_space(node.text)); +std::string & common_chat_peg_mapper::args_target() { + return (current_tool && !current_tool->name.empty()) ? current_tool->arguments : args_buffer; +} + +void common_chat_peg_mapper::from_ast(const common_peg_ast_arena & arena, + const common_peg_parse_result & parse_result_arg) { + arena.visit(parse_result_arg, [this](const common_peg_ast_node & node) { map(node); }); + // Flush any pending tool call that was started but never got a name + // This happens during partial parsing when the tool call is incomplete + if (pending_tool_call.has_value() && !pending_tool_call->name.empty()) { + if (!args_buffer.empty()) { + pending_tool_call->arguments = args_buffer; + } + if (closing_quote_pending && !pending_tool_call->arguments.empty()) { + pending_tool_call->arguments += "\""; + } + result.tool_calls.push_back(pending_tool_call.value()); + pending_tool_call.reset(); } } -void common_chat_peg_constructed_mapper::map(const common_peg_ast_node & node) { - common_chat_peg_mapper::map(node); +void common_chat_peg_mapper::map(const common_peg_ast_node & node) { + // Handle reasoning/content tags + bool is_reasoning = node.tag == common_chat_peg_builder::REASONING; + bool is_content = node.tag == common_chat_peg_builder::CONTENT; - bool is_tool_open = node.tag == common_chat_peg_constructed_builder::TOOL_OPEN; - bool is_tool_name = node.tag == common_chat_peg_constructed_builder::TOOL_NAME; - bool is_tool_close = node.tag == common_chat_peg_constructed_builder::TOOL_CLOSE; - bool is_arg_open = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_OPEN; - bool is_arg_close = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_CLOSE; - bool is_arg_name = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_NAME; - bool is_arg_string = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_STRING_VALUE; - bool is_arg_json = node.tag == common_chat_peg_constructed_builder::TOOL_ARG_JSON_VALUE; + if (is_reasoning) { // GPT OSS can have more than 1 reasoning block, so concatenate here + result.reasoning_content += std::string(node.text); + } + + if (is_content) { + // Concatenate content from multiple content nodes (e.g., when reasoning markers + // are preserved before content markers in reasoning_format=NONE mode) + result.content += std::string(node.text); + } + + // Handle tool-related tags (supporting both JSON and tagged formats) + bool is_tool_open = node.tag == common_chat_peg_builder::TOOL_OPEN; + bool is_tool_close = node.tag == common_chat_peg_builder::TOOL_CLOSE; + bool is_tool_name = node.tag == common_chat_peg_builder::TOOL_NAME; + bool is_tool_id = node.tag == common_chat_peg_builder::TOOL_ID; + bool is_tool_args = node.tag == common_chat_peg_builder::TOOL_ARGS; + bool is_arg_open = node.tag == common_chat_peg_builder::TOOL_ARG_OPEN; + bool is_arg_close = node.tag == common_chat_peg_builder::TOOL_ARG_CLOSE; + bool is_arg_name = node.tag == common_chat_peg_builder::TOOL_ARG_NAME; + bool is_arg_value = node.tag == common_chat_peg_builder::TOOL_ARG_VALUE; + bool is_arg_string_value = node.tag == common_chat_peg_builder::TOOL_ARG_STRING_VALUE; if (is_tool_open) { - result.tool_calls.emplace_back(); - current_tool = &result.tool_calls.back(); - arg_count = 0; + pending_tool_call = common_chat_tool_call(); + current_tool = &pending_tool_call.value(); + arg_count = 0; + args_buffer.clear(); + closing_quote_pending = false; + } + + if (is_tool_id && current_tool) { + auto text = trim_trailing_space(node.text); + if (text.size() >= 2 && text.front() == '"' && text.back() == '"') { + text = text.substr(1, text.size() - 2); + } + current_tool->id = std::string(text); + } + + if (is_tool_name && current_tool) { + current_tool->name = std::string(trim_trailing_space(node.text)); + // Now that we have the name, populate the arguments from the buffer + if (!args_buffer.empty()) { + current_tool->arguments = args_buffer; + args_buffer.clear(); + } else if (current_tool->arguments.empty()) { + current_tool->arguments = "{"; + } + // Add the tool call to results so streaming can see it + if (pending_tool_call.has_value()) { + result.tool_calls.push_back(pending_tool_call.value()); + pending_tool_call.reset(); + current_tool = &result.tool_calls.back(); + } } - if (is_tool_name) { - current_tool->name = std::string(node.text); - current_tool->arguments = "{"; + if (is_tool_args && current_tool) { + // For JSON format: arguments come as a complete JSON object + // For tagged format: built up from individual arg_name/arg_value nodes + auto text = trim_trailing_space(node.text); + if (!text.empty() && text.front() == '{') { + args_target() = std::string(text); + } } if (is_arg_open) { - needs_closing_quote = false; + closing_quote_pending = false; } if (is_arg_name && current_tool) { + std::string arg_entry; if (arg_count > 0) { - current_tool->arguments += ","; + arg_entry = ","; } - current_tool->arguments += json(trim_trailing_space(node.text)).dump() + ":"; + arg_entry += json(trim(node.text)).dump() + ":"; ++arg_count; + + auto & target = args_target(); + if (target.empty()) { + target = "{"; + } + target += arg_entry; } - if (is_arg_string && current_tool) { - // Serialize to JSON, but exclude the end quote - std::string dumped = json(trim_trailing_space(node.text)).dump(); - current_tool->arguments += dumped.substr(0, dumped.size() - 1); - needs_closing_quote = true; + if ((is_arg_value || is_arg_string_value) && current_tool) { + std::string value_content = std::string(trim_trailing_space(trim_leading_space(node.text, 1), 1)); + + std::string value_to_add; + if (value_content.empty() && is_arg_string_value) { + // Empty string value - arg_close will add the closing quote + value_to_add = "\""; + closing_quote_pending = true; + } else if (!value_content.empty() && is_arg_string_value) { + // Schema declares this as string type - always treat as literal string value + if (!closing_quote_pending) { + value_to_add = "\""; + closing_quote_pending = true; + } + value_to_add += escape_json_string_inner(value_content); + } else if (!value_content.empty()) { + // For potential containers, normalize Python-style single quotes to JSON double quotes + bool is_potential_container = value_content[0] == '[' || value_content[0] == '{'; + if (is_potential_container) { + value_content = normalize_quotes_to_json(value_content); + } + + // Try to parse as JSON value (number, bool, null, object, array) + try { + json parsed = json::parse(value_content); + if (parsed.is_string()) { + // Don't add closing quote yet (added by arg_close) for monotonic streaming + std::string escaped = parsed.dump(); + if (!escaped.empty() && escaped.back() == '"') { + escaped.pop_back(); + } + value_to_add = escaped; + closing_quote_pending = true; + } else { + // Non-string values: use raw content to preserve whitespace for monotonicity + value_to_add = value_content; + } + } catch (...) { + if (node.is_partial && is_potential_container) { + // Partial container: pass through the already-normalized content + value_to_add = value_content; + } else { + // Not valid JSON - treat as string value + if (!closing_quote_pending) { + value_to_add = "\""; + closing_quote_pending = true; + } + value_to_add += escape_json_string_inner(value_content); + } + } + } + + args_target() += value_to_add; } if (is_arg_close && current_tool) { - if (needs_closing_quote) { + if (closing_quote_pending) { + args_target() += "\""; + closing_quote_pending = false; + } + } + + if (is_tool_close && current_tool) { + // Flush buffer to arguments if tool name was never seen + if (current_tool->name.empty() && !args_buffer.empty()) { + current_tool->arguments = args_buffer; + args_buffer.clear(); + } + // Close any pending string quote + if (closing_quote_pending) { current_tool->arguments += "\""; - needs_closing_quote = false; + closing_quote_pending = false; + } + // Close any unclosed braces (accounts for nested objects) + for (int d = json_brace_depth(current_tool->arguments); d > 0; d--) { + current_tool->arguments += "}"; + } + // Add tool call to results if named; otherwise discard + if (pending_tool_call.has_value()) { + if (!current_tool->name.empty()) { + result.tool_calls.push_back(pending_tool_call.value()); + } + pending_tool_call.reset(); + } + } +} + +common_peg_parser common_chat_peg_builder::standard_constructed_tools( + const std::map & markers, + const nlohmann::json & tools, + bool parallel_tool_calls, + bool force_tool_calls) { + if (!tools.is_array() || tools.empty()) { + return eps(); + } + + // Extract markers with defaults + auto get_marker = [&markers](const std::string & key, const std::string & default_val = "") -> std::string { + auto it = markers.find(key); + return it != markers.end() ? it->second : default_val; + }; + + std::string section_start = get_marker("tool_call_start_marker", ""); + std::string section_end = get_marker("tool_call_end_marker", ""); + std::string func_opener = get_marker("function_opener", ""); + std::string func_closer = get_marker("function_closer", ""); + std::string param_key_prefix = get_marker("parameter_key_prefix", ""); + std::string param_closer = get_marker("parameter_closer", ""); + + // Build tool choices for tagged format + auto tool_choices = choice(); + + for (const auto & tool_def : tools) { + if (!tool_def.contains("function")) { + continue; } + const auto & function = tool_def.at("function"); + std::string name = function.at("name"); + nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object(); + + // Build argument parsers + auto args = eps(); + if (params.contains("properties") && !params["properties"].empty()) { + auto arg_choice = choice(); + for (const auto & el : params["properties"].items()) { + const std::string & prop_name = el.key(); + + auto arg_name_parser = + choice({ literal(prop_name), literal("\"" + prop_name + "\""), literal("'" + prop_name + "'") }); + + auto arg_rule = tool_arg(tool_arg_open(literal(param_key_prefix)) + tool_arg_name(arg_name_parser) + + literal(param_key_suffix) + tool_arg_value(until(param_closer)) + + tool_arg_close(literal(param_closer))); + arg_choice |= arg_rule; + } + args = zero_or_more(arg_choice + space()); + } + + // Build function parser: args + auto tool_parser = tool(tool_open(literal(func_opener) + tool_name(literal(name)) + literal(func_name_suffix)) + + space() + tool_args(args) + space() + tool_close(literal(func_closer))); + + tool_choices |= rule("tool-" + name, tool_parser); } - if (is_arg_json && current_tool) { - current_tool->arguments += std::string(trim_trailing_space(node.text)); + // Build the section with markers + auto section = + parallel_tool_calls ? + trigger_rule("tool-call", literal(section_start) + space() + one_or_more(tool_choices + space()) + + literal(section_end)) : + trigger_rule("tool-call", literal(section_start) + space() + tool_choices + space() + literal(section_end)); + + return force_tool_calls ? section : optional(section); +} + +// Helper: Parse dot notation key into prefix and field name +static std::pair parse_key_spec(const std::string & key) { + auto dot_pos = key.find('.'); + if (dot_pos == std::string::npos) { + return {"", key}; // Top-level field } + return {key.substr(0, dot_pos), key.substr(dot_pos + 1)}; +} - if (is_tool_close && current_tool) { - if (needs_closing_quote) { - current_tool->arguments += "\""; - needs_closing_quote = false; +// Mode 1: function_is_key — parse {"function_name": {...}} +common_peg_parser common_chat_peg_builder::build_json_tools_function_is_key( + const nlohmann::json & tools, + const std::string & args_key, + const std::string & effective_args_key, + const std::string & call_id_key, + const std::string & gen_call_id_key) { + + auto tool_choices = choice(); + + for (const auto & tool_def : tools) { + if (!tool_def.contains("function")) { + continue; + } + const auto & function = tool_def.at("function"); + std::string name = function.at("name"); + nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object(); + + // Build inner object fields + std::vector inner_fields; + + if (!call_id_key.empty()) { + auto id_parser = atomic( + literal("\"" + call_id_key + "\"") + space() + literal(":") + space() + + literal("\"") + tool_id(json_string_content()) + literal("\"") + ); + inner_fields.push_back(optional(id_parser + space() + optional(literal(",") + space()))); + } + + if (!gen_call_id_key.empty()) { + auto gen_id_parser = atomic( + literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() + + choice({ + literal("\"") + tool_id(json_string_content()) + literal("\""), + tool_id(json_number()) + }) + ); + inner_fields.push_back(optional(gen_id_parser + space() + optional(literal(",") + space()))); + } + + // Arguments — either wrapped in args_key or parsed directly + common_peg_parser args_parser = eps(); + if (args_key.empty()) { + args_parser = tool_args(schema(json(), "tool-" + name + "-schema", params)); + } else { + args_parser = literal("\"" + effective_args_key + "\"") + space() + literal(":") + space() + + tool_args(schema(json(), "tool-" + name + "-schema", params)); + } + inner_fields.push_back(args_parser); + + // Build inner object parser + common_peg_parser inner_object = eps(); + if (args_key.empty() && inner_fields.size() == 1) { + inner_object = inner_fields[0]; + } else { + inner_object = literal("{") + space(); + for (size_t i = 0; i < inner_fields.size(); i++) { + inner_object = inner_object + inner_fields[i]; + if (i < inner_fields.size() - 1) { + inner_object = inner_object + space(); + } + } + inner_object = inner_object + space() + literal("}"); + } + + auto tool_parser = tool( + tool_open(literal("{")) + space() + + literal("\"") + tool_name(literal(name)) + literal("\"") + + space() + literal(":") + space() + + inner_object + + space() + tool_close(literal("}")) + ); + + tool_choices |= rule("tool-" + name, tool_parser); + } + + return tool_choices; +} + +// Mode 2: Nested keys (dot notation like "function.name") +common_peg_parser common_chat_peg_builder::build_json_tools_nested_keys( + const nlohmann::json & tools, + const std::string & effective_name_key, + const std::string & effective_args_key, + const std::string & call_id_key, + const std::string & gen_call_id_key) { + + auto tool_choices = choice(); + + auto name_spec = parse_key_spec(effective_name_key); + auto args_spec = parse_key_spec(effective_args_key); + + std::string nested_prefix = !name_spec.first.empty() ? name_spec.first : args_spec.first; + std::string nested_name_field = !name_spec.first.empty() ? name_spec.second : effective_name_key; + std::string nested_args_field = !args_spec.first.empty() ? args_spec.second : effective_args_key; + + for (const auto & tool_def : tools) { + if (!tool_def.contains("function")) { + continue; } - current_tool->arguments += "}"; + const auto & function = tool_def.at("function"); + std::string name = function.at("name"); + nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object(); + + auto nested_name = literal("\"" + nested_name_field + "\"") + space() + literal(":") + space() + + literal("\"") + tool_name(literal(name)) + literal("\""); + auto nested_args = literal("\"" + nested_args_field + "\"") + space() + literal(":") + space() + + tool_args(schema(json(), "tool-" + name + "-schema", params)); + + auto nested_object = literal("{") + space() + + nested_name + space() + literal(",") + space() + + nested_args + + space() + literal("}"); + + // Format: { id?, "function": {...} } + auto tool_parser_body = tool_open(literal("{")) + space(); + + if (!call_id_key.empty()) { + auto id_spec = parse_key_spec(call_id_key); + if (id_spec.first.empty()) { + auto id_parser = atomic( + literal("\"" + call_id_key + "\"") + space() + literal(":") + space() + + literal("\"") + tool_id(json_string_content()) + literal("\"") + ); + tool_parser_body = tool_parser_body + optional(id_parser + space() + literal(",") + space()); + } + } + + if (!gen_call_id_key.empty()) { + auto gen_id_spec = parse_key_spec(gen_call_id_key); + if (gen_id_spec.first.empty()) { + auto gen_id_parser = atomic( + literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() + + choice({ + literal("\"") + tool_id(json_string_content()) + literal("\""), + tool_id(json_number()) + }) + ); + tool_parser_body = tool_parser_body + optional(gen_id_parser + space() + literal(",") + space()); + } + } + + auto nested_field = literal("\"" + nested_prefix + "\"") + space() + literal(":") + space() + nested_object; + tool_parser_body = tool_parser_body + nested_field + space() + tool_close(literal("}")); + + tool_choices |= rule("tool-" + name, tool(tool_parser_body)); + } + + return tool_choices; +} + +// Mode 3: Flat keys with optional ID fields and parameter ordering +common_peg_parser common_chat_peg_builder::build_json_tools_flat_keys( + const nlohmann::json & tools, + const std::string & effective_name_key, + const std::string & effective_args_key, + const std::string & call_id_key, + const std::string & gen_call_id_key, + const std::vector & parameters_order) { + + auto tool_choices = choice(); + auto name_key_parser = literal("\"" + effective_name_key + "\""); + auto args_key_parser = literal("\"" + effective_args_key + "\""); + + for (const auto & tool_def : tools) { + if (!tool_def.contains("function")) { + continue; + } + const auto & function = tool_def.at("function"); + std::string name = function.at("name"); + nlohmann::json params = function.contains("parameters") ? function.at("parameters") : nlohmann::json::object(); + + auto tool_name_ = name_key_parser + space() + literal(":") + space() + + literal("\"") + tool_name(literal(name)) + literal("\""); + auto tool_args_ = args_key_parser + space() + literal(":") + space() + + tool_args(schema(json(), "tool-" + name + "-schema", params)); + + // Build ID parsers if keys are provided + common_peg_parser id_parser = eps(); + if (!call_id_key.empty()) { + id_parser = atomic( + literal("\"" + call_id_key + "\"") + space() + literal(":") + space() + + choice({ + literal("\"") + tool_id(json_string_content()) + literal("\""), + tool_id(json_number()) + }) + ); + } + + common_peg_parser gen_id_parser = eps(); + if (!gen_call_id_key.empty()) { + gen_id_parser = atomic( + literal("\"" + gen_call_id_key + "\"") + space() + literal(":") + space() + + choice({ + literal("\"") + tool_id(json_string_content()) + literal("\""), + tool_id(json_number()) + }) + ); + } + + // Create (parser, key) pairs for all fields, then sort by parameters_order + std::vector> parser_pairs; + parser_pairs.emplace_back(tool_name_, effective_name_key); + parser_pairs.emplace_back(tool_args_, effective_args_key); + if (!call_id_key.empty()) { + parser_pairs.emplace_back(optional(id_parser), call_id_key); + } + if (!gen_call_id_key.empty()) { + parser_pairs.emplace_back(optional(gen_id_parser), gen_call_id_key); + } + + std::sort(parser_pairs.begin(), parser_pairs.end(), + [¶meters_order](const auto & a, const auto & b) { + auto pos_a = std::find(parameters_order.begin(), parameters_order.end(), a.second); + auto pos_b = std::find(parameters_order.begin(), parameters_order.end(), b.second); + size_t idx_a = (pos_a == parameters_order.end()) ? parameters_order.size() : std::distance(parameters_order.begin(), pos_a); + size_t idx_b = (pos_b == parameters_order.end()) ? parameters_order.size() : std::distance(parameters_order.begin(), pos_b); + return idx_a < idx_b; + }); + + auto ordered_body = tool_open(literal("{")) + space(); + for (size_t i = 0; i < parser_pairs.size(); i++) { + ordered_body = ordered_body + parser_pairs[i].first; + if (i < parser_pairs.size() - 1) { + ordered_body = ordered_body + space() + literal(",") + space(); + } + } + ordered_body = ordered_body + space() + tool_close(literal("}")); + + tool_choices |= rule("tool-" + name, tool(ordered_body)); + } + + return tool_choices; +} + +common_peg_parser common_chat_peg_builder::standard_json_tools( + const std::string & section_start, + const std::string & section_end, + const nlohmann::json & tools, + bool parallel_tool_calls, + bool force_tool_calls, + const std::string & name_key, + const std::string & args_key, + bool array_wrapped, + bool function_is_key, + const std::string & call_id_key, + const std::string & gen_call_id_key, + const std::vector & parameters_order) { + if (!tools.is_array() || tools.empty()) { + return eps(); } + + std::string effective_name_key = name_key.empty() ? "name" : name_key; + std::string effective_args_key = args_key.empty() ? "arguments" : args_key; + + // Dispatch to the appropriate builder based on the JSON layout mode + common_peg_parser tool_choices = eps(); + if (function_is_key) { + tool_choices = build_json_tools_function_is_key(tools, args_key, effective_args_key, call_id_key, gen_call_id_key); + } else { + auto name_spec = parse_key_spec(effective_name_key); + auto args_spec = parse_key_spec(effective_args_key); + if (!name_spec.first.empty() || !args_spec.first.empty()) { + tool_choices = build_json_tools_nested_keys(tools, effective_name_key, effective_args_key, call_id_key, gen_call_id_key); + } else { + tool_choices = build_json_tools_flat_keys(tools, effective_name_key, effective_args_key, call_id_key, gen_call_id_key, parameters_order); + } + } + + // Build the section with markers + auto tool_calls = tool_choices; + if (parallel_tool_calls) { + tool_calls = tool_calls + zero_or_more(space() + literal(",") + space() + tool_choices); + } + + if (array_wrapped) { + tool_calls = literal("[") + space() + tool_calls + space() + literal("]"); + } + + auto section = + trigger_rule("tool-call", literal(section_start) + space() + tool_calls + space() + literal(section_end)); + + return force_tool_calls ? section : optional(section); } diff --git a/common/chat-peg-parser.h b/common/chat-peg-parser.h index b84cbed2069..e130ceea5ff 100644 --- a/common/chat-peg-parser.h +++ b/common/chat-peg-parser.h @@ -3,22 +3,9 @@ #include "chat.h" #include "peg-parser.h" -class common_chat_peg_builder : public common_peg_parser_builder { - public: - static constexpr const char * REASONING_BLOCK = "reasoning-block"; - static constexpr const char * REASONING = "reasoning"; - static constexpr const char * CONTENT = "content"; - - common_peg_parser reasoning_block(const common_peg_parser & p) { return tag(REASONING_BLOCK, p); } - common_peg_parser reasoning(const common_peg_parser & p) { return tag(REASONING, p); } - common_peg_parser content(const common_peg_parser & p) { return tag(CONTENT, p); } -}; - -inline common_peg_arena build_chat_peg_parser(const std::function & fn) { - common_chat_peg_builder builder; - builder.set_root(fn(builder)); - return builder.build(); -} +#include +#include +#include class common_chat_peg_mapper { public: @@ -26,80 +13,164 @@ class common_chat_peg_mapper { common_chat_peg_mapper(common_chat_msg & msg) : result(msg) {} + virtual ~common_chat_peg_mapper() = default; + virtual void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result); virtual void map(const common_peg_ast_node & node); + private: + // Tool call handling state + std::optional pending_tool_call; // Tool call waiting for name + common_chat_tool_call * current_tool = nullptr; + int arg_count = 0; + bool closing_quote_pending = false; + std::string args_buffer; // Buffer to delay arguments until tool name is known + + // Returns a reference to the active argument destination string. + // Before tool_name is known, writes go to args_buffer; after, to current_tool->arguments. + std::string & args_target(); }; -class common_chat_peg_native_builder : public common_chat_peg_builder { - public: - static constexpr const char * TOOL = "tool"; - static constexpr const char * TOOL_OPEN = "tool-open"; - static constexpr const char * TOOL_CLOSE = "tool-close"; - static constexpr const char * TOOL_ID = "tool-id"; - static constexpr const char * TOOL_NAME = "tool-name"; - static constexpr const char * TOOL_ARGS = "tool-args"; - - common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); } - common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); } - common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); } - common_peg_parser tool_id(const common_peg_parser & p) { return atomic(tag(TOOL_ID, p)); } - common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); } - common_peg_parser tool_args(const common_peg_parser & p) { return tag(TOOL_ARGS, p); } -}; - -class common_chat_peg_native_mapper : public common_chat_peg_mapper { - common_chat_tool_call * current_tool; +struct content_structure; +struct tool_call_structure; +class common_chat_peg_builder : public common_peg_parser_builder { public: - common_chat_peg_native_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {} + // Tag constants (from former common_chat_peg_base_builder) + static constexpr const char * REASONING_BLOCK = "reasoning-block"; + static constexpr const char * REASONING = "reasoning"; + static constexpr const char * CONTENT = "content"; + + // Tag constants + static constexpr const char * TOOL = "tool"; + static constexpr const char * TOOL_OPEN = "tool-open"; + static constexpr const char * TOOL_CLOSE = "tool-close"; + static constexpr const char * TOOL_ID = "tool-id"; + static constexpr const char * TOOL_NAME = "tool-name"; + static constexpr const char * TOOL_ARGS = "tool-args"; + static constexpr const char * TOOL_ARG = "tool-arg"; + static constexpr const char * TOOL_ARG_OPEN = "tool-arg-open"; + static constexpr const char * TOOL_ARG_CLOSE = "tool-arg-close"; + static constexpr const char * TOOL_ARG_NAME = "tool-arg-name"; + static constexpr const char * TOOL_ARG_VALUE = "tool-arg-value"; + static constexpr const char * TOOL_ARG_STRING_VALUE = "tool-arg-string-value"; // For schema-declared string types - void map(const common_peg_ast_node & node) override; -}; + // Low-level tag methods (from former common_chat_peg_base_builder) + common_peg_parser reasoning_block(const common_peg_parser & p) { return tag(REASONING_BLOCK, p); } -inline common_peg_arena build_chat_peg_native_parser(const std::function & fn) { - common_chat_peg_native_builder builder; - builder.set_root(fn(builder)); - return builder.build(); -} + common_peg_parser reasoning(const common_peg_parser & p) { return tag(REASONING, p); } -class common_chat_peg_constructed_builder : public common_chat_peg_builder { - public: - static constexpr const char * TOOL = "tool"; - static constexpr const char * TOOL_OPEN = "tool-open"; - static constexpr const char * TOOL_CLOSE = "tool-close"; - static constexpr const char * TOOL_NAME = "tool-name"; - static constexpr const char * TOOL_ARG = "tool-arg"; - static constexpr const char * TOOL_ARG_OPEN = "tool-arg-open"; - static constexpr const char * TOOL_ARG_CLOSE = "tool-arg-close"; - static constexpr const char * TOOL_ARG_NAME = "tool-arg-name"; - static constexpr const char * TOOL_ARG_STRING_VALUE = "tool-arg-string-value"; - static constexpr const char * TOOL_ARG_JSON_VALUE = "tool-arg-json-value"; + common_peg_parser content(const common_peg_parser & p) { return tag(CONTENT, p); } + + common_peg_parser tag_with_safe_content(const std::string & tag_name, + const std::string & marker, + const common_peg_parser & p); + // Low-level tag methods common_peg_parser tool(const common_peg_parser & p) { return tag(TOOL, p); } common_peg_parser tool_open(const common_peg_parser & p) { return atomic(tag(TOOL_OPEN, p)); } common_peg_parser tool_close(const common_peg_parser & p) { return atomic(tag(TOOL_CLOSE, p)); } + common_peg_parser tool_id(const common_peg_parser & p) { return atomic(tag(TOOL_ID, p)); } common_peg_parser tool_name(const common_peg_parser & p) { return atomic(tag(TOOL_NAME, p)); } + common_peg_parser tool_args(const common_peg_parser & p) { return tag(TOOL_ARGS, p); } common_peg_parser tool_arg(const common_peg_parser & p) { return tag(TOOL_ARG, p); } common_peg_parser tool_arg_open(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_OPEN, p)); } common_peg_parser tool_arg_close(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_CLOSE, p)); } common_peg_parser tool_arg_name(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_NAME, p)); } + common_peg_parser tool_arg_value(const common_peg_parser & p) { return tag(TOOL_ARG_VALUE, p); } + + // Use for schema-declared string types - won't be treated as potential JSON container common_peg_parser tool_arg_string_value(const common_peg_parser & p) { return tag(TOOL_ARG_STRING_VALUE, p); } - common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return tag(TOOL_ARG_JSON_VALUE, p); } + common_peg_parser tool_arg_json_value(const common_peg_parser & p) { return atomic(tag(TOOL_ARG_VALUE, p)); } + + // Legacy-compatible helper for building standard JSON tool calls + // Used by tests and manual parsers + // name_key/args_key: JSON key names for function name and arguments + // Empty or "name"/"arguments" will accept both common variations + // Supports dot notation for nested objects (e.g., "function.name") + // array_wrapped: if true, tool calls are wrapped in JSON array [...] + // function_is_key: if true, function name is the JSON key (e.g., {"func_name": {...}}) + // call_id_key: JSON key for string call ID (e.g., "id") + // gen_call_id_key: JSON key for generated integer call ID (e.g., "tool_call_id") + // parameters_order: order in which JSON fields should be parsed + common_peg_parser standard_json_tools(const std::string & section_start, + const std::string & section_end, + const nlohmann::json & tools, + bool parallel_tool_calls, + bool force_tool_calls, + const std::string & name_key = "", + const std::string & args_key = "", + bool array_wrapped = false, + bool function_is_key = false, + const std::string & call_id_key = "", + const std::string & gen_call_id_key = "", + const std::vector & parameters_order = {}); + + // Legacy-compatible helper for building XML/tagged style tool calls + // Used by tests and manual parsers + common_peg_parser standard_constructed_tools(const std::map & markers, + const nlohmann::json & tools, + bool parallel_tool_calls, + bool force_tool_calls); + + private: + // Implementation helpers for standard_json_tools — one per JSON tool call layout mode + common_peg_parser build_json_tools_function_is_key(const nlohmann::json & tools, + const std::string & args_key, + const std::string & effective_args_key, + const std::string & call_id_key, + const std::string & gen_call_id_key); + + common_peg_parser build_json_tools_nested_keys(const nlohmann::json & tools, + const std::string & effective_name_key, + const std::string & effective_args_key, + const std::string & call_id_key, + const std::string & gen_call_id_key); + + common_peg_parser build_json_tools_flat_keys(const nlohmann::json & tools, + const std::string & effective_name_key, + const std::string & effective_args_key, + const std::string & call_id_key, + const std::string & gen_call_id_key, + const std::vector & parameters_order); }; -class common_chat_peg_constructed_mapper : public common_chat_peg_mapper { - common_chat_tool_call * current_tool; - int arg_count = 0; - bool needs_closing_quote = false; +inline common_peg_arena build_chat_peg_parser( + const std::function & fn) { + common_chat_peg_builder builder; + builder.set_root(fn(builder)); + return builder.build(); +} +class tag_based_peg_mapper { public: - common_chat_peg_constructed_mapper(common_chat_msg & msg) : common_chat_peg_mapper(msg) {} + std::map tags; - void map(const common_peg_ast_node & node) override; + void from_ast(const common_peg_ast_arena & arena, const common_peg_parse_result & result); }; -inline common_peg_arena build_chat_peg_constructed_parser(const std::function & fn) { - common_chat_peg_constructed_builder builder; - builder.set_root(fn(builder)); - return builder.build(); -} +struct tagged_parse_result { + common_peg_parse_result result; + std::map tags; +}; + +struct tagged_peg_parser { + common_peg_arena arena; + bool debug = false; + + tagged_peg_parser & withDebug() { + debug = true; + return *this; + } + + tagged_peg_parser & withoutDebug() { + debug = false; + return *this; + } + + tagged_parse_result parse_and_extract(const std::string & input, bool is_partial = false) const; + tagged_parse_result parse_anywhere_and_extract(const std::string & input) const; +}; + +tagged_peg_parser build_tagged_peg_parser( + const std::function & fn); + diff --git a/common/chat.cpp b/common/chat.cpp index 52780c59ad1..81c23430af9 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -1,24 +1,25 @@ #include "chat.h" -#include "chat-parser.h" + +#include "chat-auto-parser.h" #include "chat-peg-parser.h" #include "common.h" -#include "json-partial.h" +#include "ggml.h" #include "json-schema-to-grammar.h" #include "log.h" -#include "regex-partial.h" -#include "jinja/parser.h" #include "jinja/value.h" #include "jinja/runtime.h" #include "jinja/caps.h" +#include "peg-parser.h" -#include #include -#include +#include +#include #include #include -#include + #include +#include #include #include #include @@ -26,14 +27,26 @@ using json = nlohmann::ordered_json; static std::string format_time(const std::chrono::system_clock::time_point & now, const std::string & format) { - auto time = std::chrono::system_clock::to_time_t(now); - auto local_time = *std::localtime(&time); + auto time = std::chrono::system_clock::to_time_t(now); + auto local_time = *std::localtime(&time); std::ostringstream ss; ss << std::put_time(&local_time, format.c_str()); auto res = ss.str(); return res; } +static json safe_args_parse(const std::string & to_parse) { + std::string stripped = to_parse; + if (to_parse.at(0) == '"' && to_parse.at(to_parse.length() - 1) == '"') { + stripped = to_parse.substr(1, to_parse.length() - 1); + } + try { + return json::parse(stripped); + } catch (json::exception & e) { + return stripped; + } +} + static std::string string_diff(const std::string & last, const std::string & current) { if (last.empty()) { return current; @@ -116,7 +129,7 @@ json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const { {"type", "function"}, {"function", { {"name", tool_call.name}, - {"arguments", tool_call.arguments}, + {"arguments", json::parse(tool_call.arguments)}, }}, }; if (!tool_call.id.empty()) { @@ -133,7 +146,8 @@ json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const { return jmsg; } -std::vector common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new) { +std::vector common_chat_msg_diff::compute_diffs(const common_chat_msg & msg_prv, + const common_chat_msg & msg_new) { std::vector diffs; if (msg_new.tool_calls.size() > msg_prv.tool_calls.size()) { diffs.reserve(msg_new.tool_calls.size() - msg_prv.tool_calls.size() + 3); @@ -143,38 +157,56 @@ std::vector common_chat_msg_diff::compute_diffs(const comm // TODO: these can become expensive for long messages - how to optimize? if (msg_prv.reasoning_content != msg_new.reasoning_content) { - auto & diff = diffs.emplace_back(); + auto & diff = diffs.emplace_back(); diff.reasoning_content_delta = string_diff(msg_prv.reasoning_content, msg_new.reasoning_content); } if (msg_prv.content != msg_new.content) { - auto & diff = diffs.emplace_back(); + auto & diff = diffs.emplace_back(); diff.content_delta = string_diff(msg_prv.content, msg_new.content); } if (msg_new.tool_calls.size() < msg_prv.tool_calls.size()) { - throw std::runtime_error("Invalid diff: now finding less tool calls!"); + std::string err = "Invalid diff: now finding less tool calls!\n"; + err += " Previous (" + std::to_string(msg_prv.tool_calls.size()) + "):\n"; + for (const auto & tc : msg_prv.tool_calls) { + err += " - name: '" + tc.name + "', args: '" + tc.arguments + "'\n"; + } + err += " Current (" + std::to_string(msg_new.tool_calls.size()) + "):\n"; + for (const auto & tc : msg_new.tool_calls) { + err += " - name: '" + tc.name + "', args: '" + tc.arguments + "'\n"; + } + err += " Current msg text content:\n" + msg_new.content + "\n"; + throw std::runtime_error(err); } if (!msg_prv.tool_calls.empty()) { - const auto idx = msg_prv.tool_calls.size() - 1; + const auto idx = msg_prv.tool_calls.size() - 1; const auto & pref = msg_prv.tool_calls[idx]; const auto & newf = msg_new.tool_calls[idx]; - if (pref.name != newf.name) { - throw std::runtime_error("Invalid diff: tool call mismatch!"); + // Allow tool name to change during incremental parsing: + // - empty -> non-empty (initial discovery) + // - prefix -> longer string (name grows as more input is parsed) + if (pref.name != newf.name && !pref.name.empty() && !newf.name.empty()) { + // Check if one is a prefix of the other (for incremental parsing where names grow or shrink) + bool is_prefix = (newf.name.rfind(pref.name, 0) == 0); + if (!is_prefix) { + LOG_ERR("Tool call mismatch: prev='%s' new='%s'\n", pref.name.c_str(), newf.name.c_str()); + throw std::runtime_error("Invalid diff: tool call mismatch!"); + } } const auto args_diff = string_diff(pref.arguments, newf.arguments); - if (!args_diff.empty() || pref.id != newf.id) { - auto & diff = diffs.emplace_back(); + if (!args_diff.empty() || pref.id != newf.id || pref.name != newf.name) { + auto & diff = diffs.emplace_back(); diff.tool_call_index = idx; - if (pref.id != newf.id) { - diff.tool_call_delta.id = newf.id; + if (pref.id != newf.id || pref.name != newf.name) { + diff.tool_call_delta.id = newf.id; diff.tool_call_delta.name = newf.name; } diff.tool_call_delta.arguments = args_diff; } } for (size_t idx = msg_prv.tool_calls.size(); idx < msg_new.tool_calls.size(); ++idx) { - auto & diff = diffs.emplace_back(); + auto & diff = diffs.emplace_back(); diff.tool_call_index = idx; diff.tool_call_delta = msg_new.tool_calls[idx]; } @@ -184,94 +216,14 @@ std::vector common_chat_msg_diff::compute_diffs(const comm using chat_template_caps = jinja::caps; -struct common_chat_template { - jinja::program prog; - std::string bos_tok; - std::string eos_tok; - std::string src; - chat_template_caps caps; - - common_chat_template(const std::string & src, const std::string & bos_token, const std::string & eos_token) { - jinja::lexer lexer; - auto lexer_res = lexer.tokenize(src); - this->prog = jinja::parse_from_tokens(lexer_res); - - this->src = lexer_res.source; - this->bos_tok = bos_token; - this->eos_tok = eos_token; - - this->caps = jinja::caps_get(prog); - // LOG_INF("%s: caps:\n%s\n", __func__, this->caps.to_string().c_str()); - } - - const std::string & source() const { return src; } - const std::string & bos_token() const { return bos_tok; } - const std::string & eos_token() const { return eos_tok; } - - // TODO: this is ugly, refactor it somehow - json add_system(const json & messages, const std::string & system_prompt) const { - GGML_ASSERT(messages.is_array()); - auto msgs_copy = messages; - if (!caps.supports_system_role) { - if (msgs_copy.empty()) { - msgs_copy.insert(msgs_copy.begin(), json{ - {"role", "user"}, - {"content", system_prompt} - }); - } else { - auto & first_msg = msgs_copy[0]; - if (!first_msg.contains("content")) { - first_msg["content"] = ""; - } - first_msg["content"] = system_prompt + "\n\n" - + first_msg["content"].get(); - } - } else { - if (msgs_copy.empty() || msgs_copy[0].at("role") != "system") { - msgs_copy.insert(msgs_copy.begin(), json{ - {"role", "system"}, - {"content", system_prompt} - }); - } else if (msgs_copy[0].at("role") == "system") { - msgs_copy[0]["content"] = system_prompt; - } - } - return msgs_copy; - } - - chat_template_caps original_caps() const { - return caps; - } - -}; - struct common_chat_templates { bool add_bos; bool add_eos; - bool has_explicit_template; // Model had builtin template or template overridde was specified. - std::unique_ptr template_default; // always set (defaults to chatml) + bool has_explicit_template; // Model had builtin template or template overridde was specified. + std::unique_ptr template_default; // always set (defaults to chatml) std::unique_ptr template_tool_use; }; -struct templates_params { - json messages; - json tools; - common_chat_tool_choice tool_choice; - json json_schema; - bool parallel_tool_calls; - common_reasoning_format reasoning_format; - bool stream; - std::string grammar; - bool add_generation_prompt = true; - bool enable_thinking = true; - std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); - json extra_context; - bool add_bos; - bool add_eos; - bool is_inference = true; - bool mark_input = true; // whether to mark input strings in the jinja context -}; - common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice) { if (tool_choice == "auto") { return COMMON_CHAT_TOOL_CHOICE_AUTO; @@ -286,23 +238,24 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin } bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates) { - common_chat_templates_inputs dummy_inputs; + common_chat_templates_inputs inputs; + inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; common_chat_msg msg; - msg.role = "user"; + msg.role = "user"; msg.content = "test"; - dummy_inputs.messages = {msg}; - dummy_inputs.enable_thinking = false; - const auto rendered_no_thinking = common_chat_templates_apply(chat_templates, dummy_inputs); - dummy_inputs.enable_thinking = true; - const auto rendered_with_thinking = common_chat_templates_apply(chat_templates, dummy_inputs); - return rendered_no_thinking.prompt != rendered_with_thinking.prompt; + inputs.messages = { msg }; + inputs.enable_thinking = true; + inputs.add_generation_prompt = true; + inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + + auto params = common_chat_templates_apply(chat_templates, inputs); + return params.supports_thinking; } std::vector common_chat_msgs_parse_oaicompat(const json & messages) { std::vector msgs; try { - if (!messages.is_array()) { throw std::invalid_argument("Expected 'messages' to be an array, got " + messages.dump()); } @@ -318,7 +271,7 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa } msg.role = message.at("role"); - auto has_content = message.contains("content"); + auto has_content = message.contains("content"); auto has_tool_calls = message.contains("tool_calls"); if (has_content) { const auto & content = message.at("content"); @@ -339,7 +292,9 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa msg.content_parts.push_back(msg_part); } } else if (!content.is_null()) { - throw std::invalid_argument("Invalid 'content' type: expected string or array, got " + content.dump() + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); + throw std::invalid_argument("Invalid 'content' type: expected string or array, got " + + content.dump() + + " (ref: https://github.com/ggml-org/llama.cpp/issues/8367)"); } } if (has_tool_calls) { @@ -359,8 +314,13 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa if (!fc.contains("name")) { throw std::invalid_argument("Missing tool call name: " + tool_call.dump()); } - tc.name = fc.at("name"); - tc.arguments = fc.at("arguments"); + tc.name = fc.at("name"); + const auto & args = fc.at("arguments"); + if (args.is_string()) { + tc.arguments = args; + } else { + tc.arguments = args.dump(); + } if (tool_call.contains("id")) { tc.id = tool_call.at("id"); } @@ -368,7 +328,9 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa } } if (!has_content && !has_tool_calls) { - throw std::invalid_argument("Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & https://github.com/ggml-org/llama.cpp/issues/12279)"); + throw std::invalid_argument( + "Expected 'content' or 'tool_calls' (ref: https://github.com/ggml-org/llama.cpp/issues/8367 & " + "https://github.com/ggml-org/llama.cpp/issues/12279)"); } if (message.contains("reasoning_content")) { msg.reasoning_content = message.at("reasoning_content"); @@ -474,12 +436,13 @@ json common_chat_tools_to_json_oaicompat(const std::vector & t auto result = json::array(); for (const auto & tool : tools) { result.push_back({ - {"type", "function"}, - {"function", { - {"name", tool.name}, - {"description", tool.description}, - {"parameters", json::parse(tool.parameters)}, - }}, + { "type", "function" }, + { "function", + { + { "name", tool.name }, + { "description", tool.description }, + { "parameters", json::parse(tool.parameters) }, + } }, }); } return result; @@ -497,16 +460,20 @@ json common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff) { json tool_call; tool_call["index"] = diff.tool_call_index; if (!diff.tool_call_delta.id.empty()) { - tool_call["id"] = diff.tool_call_delta.id; + tool_call["id"] = diff.tool_call_delta.id; tool_call["type"] = "function"; } - json function = json::object(); - if (!diff.tool_call_delta.name.empty()) { - function["name"] = diff.tool_call_delta.name; + if (!diff.tool_call_delta.name.empty() || !diff.tool_call_delta.arguments.empty()) { + json function = json::object(); + if (!diff.tool_call_delta.name.empty()) { + function["name"] = diff.tool_call_delta.name; + } + if (!diff.tool_call_delta.arguments.empty()) { + function["arguments"] = diff.tool_call_delta.arguments; + } + tool_call["function"] = function; } - function["arguments"] = diff.tool_call_delta.arguments; - tool_call["function"] = function; - delta["tool_calls"] = json::array({tool_call}); + delta["tool_calls"] = json::array({ tool_call }); } return delta; } @@ -515,13 +482,13 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { if (use_jinja) { try { common_chat_msg msg; - msg.role = "user"; + msg.role = "user"; msg.content = "test"; auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl); common_chat_templates_inputs inputs; - inputs.messages = {msg}; + inputs.messages = { msg }; common_chat_templates_apply(tmpls.get(), inputs); return true; @@ -530,28 +497,28 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja) { return false; } } - llama_chat_message chat[] = {{"user", "test"}}; + llama_chat_message chat[] = { + { "user", "test" } + }; const int res = llama_chat_apply_template(tmpl.c_str(), chat, 1, true, nullptr, 0); return res >= 0; } -std::string common_chat_format_single( - const struct common_chat_templates * tmpls, - const std::vector & past_msg, - const common_chat_msg & new_msg, - bool add_ass, - bool use_jinja) { - +std::string common_chat_format_single(const struct common_chat_templates * tmpls, + const std::vector & past_msg, + const common_chat_msg & new_msg, + bool add_ass, + bool use_jinja) { common_chat_templates_inputs inputs; inputs.use_jinja = use_jinja; - inputs.add_bos = tmpls->add_bos; - inputs.add_eos = tmpls->add_eos; + inputs.add_bos = tmpls->add_bos; + inputs.add_eos = tmpls->add_eos; std::string fmt_past_msg; if (!past_msg.empty()) { - inputs.messages = past_msg; + inputs.messages = past_msg; inputs.add_generation_prompt = false; - fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt; + fmt_past_msg = common_chat_templates_apply(tmpls, inputs).prompt; } std::ostringstream ss; // if the past_msg ends with a newline, we must preserve it in the formatted version @@ -561,37 +528,39 @@ std::string common_chat_format_single( // format chat with new_msg inputs.messages.push_back(new_msg); inputs.add_generation_prompt = add_ass; - auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt; + auto fmt_new_msg = common_chat_templates_apply(tmpls, inputs).prompt; // get the diff part ss << fmt_new_msg.substr(fmt_past_msg.size(), fmt_new_msg.size() - fmt_past_msg.size()); return ss.str(); } -std::string common_chat_format_example(const struct common_chat_templates * tmpls, bool use_jinja, const std::map & chat_template_kwargs) { +std::string common_chat_format_example(const struct common_chat_templates * tmpls, + bool use_jinja, + const std::map & chat_template_kwargs) { common_chat_templates_inputs inputs; - inputs.use_jinja = use_jinja; - inputs.add_bos = tmpls->add_bos; - inputs.add_eos = tmpls->add_eos; + inputs.use_jinja = use_jinja; + inputs.add_bos = tmpls->add_bos; + inputs.add_eos = tmpls->add_eos; inputs.chat_template_kwargs = chat_template_kwargs; - auto add_simple_msg = [&](auto role, auto content) { + auto add_simple_msg = [&](auto role, auto content) { common_chat_msg msg; - msg.role = role; + msg.role = role; msg.content = content; inputs.messages.push_back(msg); }; - add_simple_msg("system", "You are a helpful assistant"); - add_simple_msg("user", "Hello"); + add_simple_msg("system", "You are a helpful assistant"); + add_simple_msg("user", "Hello"); add_simple_msg("assistant", "Hi there"); - add_simple_msg("user", "How are you?"); + add_simple_msg("user", "How are you?"); return common_chat_templates_apply(tmpls, inputs).prompt; } -#define CHATML_TEMPLATE_SRC \ - "{%- for message in messages -%}\n" \ +#define CHATML_TEMPLATE_SRC \ + "{%- for message in messages -%}\n" \ " {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' -}}\n" \ - "{%- endfor -%}\n" \ - "{%- if add_generation_prompt -%}\n" \ - " {{- '<|im_start|>assistant\n' -}}\n" \ + "{%- endfor -%}\n" \ + "{%- if add_generation_prompt -%}\n" \ + " {{- '<|im_start|>assistant\n' -}}\n" \ "{%- endif -%}" void common_chat_templates_free(struct common_chat_templates * tmpls) { @@ -609,19 +578,16 @@ std::string common_chat_templates_source(const struct common_chat_templates * tm return tmpls->template_tool_use->source(); } return ""; - } else { - LOG_DBG("%s: unknown template variant: %s\n", __func__, variant.c_str()); } + LOG_DBG("%s: unknown template variant: %s\n", __func__, variant.c_str()); } return tmpls->template_default->source(); } -common_chat_templates_ptr common_chat_templates_init( - const struct llama_model * model, - const std::string & chat_template_override, - const std::string & bos_token_override, - const std::string & eos_token_override) -{ +common_chat_templates_ptr common_chat_templates_init(const struct llama_model * model, + const std::string & chat_template_override, + const std::string & bos_token_override, + const std::string & eos_token_override) { std::string default_template_src; std::string template_tool_use_src; @@ -630,7 +596,7 @@ common_chat_templates_ptr common_chat_templates_init( GGML_ASSERT(model != nullptr); const auto * str = llama_model_chat_template(model, /* name */ nullptr); if (str) { - default_template_src = str; + default_template_src = str; has_explicit_template = true; } str = llama_model_chat_template(model, /* name */ "tool_use"); @@ -652,34 +618,40 @@ common_chat_templates_ptr common_chat_templates_init( // TODO @ngxson : this is a temporary hack to prevent chat template from throwing an error // Ref: https://github.com/ggml-org/llama.cpp/pull/15230#issuecomment-3173959633 if (default_template_src.find("<|channel|>") != std::string::npos - // search for the error message and patch it - && default_template_src.find("in message.content or") != std::string::npos) { + // search for the error message and patch it + && default_template_src.find("in message.content or") != std::string::npos) { string_replace_all(default_template_src, - "{%- if \"<|channel|>analysis<|message|>\" in message.content or \"<|channel|>final<|message|>\" in message.content %}", - "{%- if false %}"); + "{%- if \"<|channel|>analysis<|message|>\" in message.content or " + "\"<|channel|>final<|message|>\" in message.content %}", + "{%- if false %}"); } // TODO @aldehir : this is a temporary fix, pending Minja changes // Ref: https://github.com/ggml-org/llama.cpp/pull/17713#issuecomment-3631342664 if (default_template_src.find("[TOOL_CALLS]") != std::string::npos - // search for the error message and patch it - && default_template_src.find("if (message['content'] is none or") != std::string::npos) { + // search for the error message and patch it + && default_template_src.find("if (message['content'] is none or") != std::string::npos) { string_replace_all(default_template_src, - "{%- if (message['content'] is none or message['content'] == '' or message['content']|length == 0) and (message['tool_calls'] is not defined or message['tool_calls'] is none or message['tool_calls']|length == 0) %}", - "{%- if false %}"); + "{%- if (message['content'] is none or message['content'] == '' or " + "message['content']|length == 0) and (message['tool_calls'] is not defined or " + "message['tool_calls'] is none or message['tool_calls']|length == 0) %}", + "{%- if false %}"); } std::string token_bos = bos_token_override; std::string token_eos = eos_token_override; - bool add_bos = false; - bool add_eos = false; + bool add_bos = false; + bool add_eos = false; if (model) { - const auto * vocab = llama_model_get_vocab(model); - const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { + const auto * vocab = llama_model_get_vocab(model); + const auto get_token = [&](llama_token token, const char * name, const char * jinja_variable_name) { if (token == LLAMA_TOKEN_NULL) { - if (default_template_src.find(jinja_variable_name) != std::string::npos - || template_tool_use_src.find(jinja_variable_name) != std::string::npos) { - LOG_WRN("common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't work as intended.\n", name); + if (default_template_src.find(jinja_variable_name) != std::string::npos || + template_tool_use_src.find(jinja_variable_name) != std::string::npos) { + LOG_WRN( + "common_chat_templates_init: warning: vocab does not have a %s token, jinja template won't " + "work as intended.\n", + name); } return std::string(); } @@ -687,13 +659,13 @@ common_chat_templates_ptr common_chat_templates_init( }; token_bos = get_token(llama_vocab_bos(vocab), "BOS", "bos_token"); token_eos = get_token(llama_vocab_eos(vocab), "EOS", "eos_token"); - add_bos = llama_vocab_get_add_bos(vocab); - add_eos = llama_vocab_get_add_eos(vocab); + add_bos = llama_vocab_get_add_bos(vocab); + add_eos = llama_vocab_get_add_eos(vocab); } common_chat_templates_ptr tmpls(new common_chat_templates()); tmpls->has_explicit_template = has_explicit_template; - tmpls->add_bos = add_bos; - tmpls->add_eos = add_eos; + tmpls->add_bos = add_bos; + tmpls->add_eos = add_eos; try { tmpls->template_default = std::make_unique(default_template_src, token_bos, token_eos); } catch (const std::exception & e) { @@ -714,35 +686,12 @@ common_chat_templates_ptr common_chat_templates_init( const char * common_chat_format_name(common_chat_format format) { switch (format) { - case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only"; - case COMMON_CHAT_FORMAT_GENERIC: return "Generic"; - case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo"; - case COMMON_CHAT_FORMAT_MAGISTRAL: return "Magistral"; - case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x"; - case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools"; - case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1"; - case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2"; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2"; - case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; - case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: return "DeepSeek V3.1"; - case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; - case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; - case COMMON_CHAT_FORMAT_GRANITE: return "Granite"; - case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS"; - case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS"; - case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2"; - case COMMON_CHAT_FORMAT_APERTUS: return "Apertus"; - case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools"; - case COMMON_CHAT_FORMAT_MINIMAX_M2: return "MiniMax-M2"; - case COMMON_CHAT_FORMAT_GLM_4_5: return "GLM 4.5"; - case COMMON_CHAT_FORMAT_KIMI_K2: return "Kimi K2"; - case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5"; - case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo"; - case COMMON_CHAT_FORMAT_SOLAR_OPEN: return "Solar Open"; - case COMMON_CHAT_FORMAT_EXAONE_MOE: return "EXAONE MoE"; - case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple"; - case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native"; - case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed"; + case COMMON_CHAT_FORMAT_CONTENT_ONLY: + return "Content-only"; + case COMMON_CHAT_FORMAT_PEG_SIMPLE: + return "peg-simple"; + case COMMON_CHAT_FORMAT_PEG_NATIVE: + return "peg-native"; default: throw std::runtime_error("Unknown chat format"); } @@ -750,10 +699,14 @@ const char * common_chat_format_name(common_chat_format format) { const char * common_reasoning_format_name(common_reasoning_format format) { switch (format) { - case COMMON_REASONING_FORMAT_NONE: return "none"; - case COMMON_REASONING_FORMAT_AUTO: return "auto"; - case COMMON_REASONING_FORMAT_DEEPSEEK: return "deepseek"; - case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: return "deepseek-legacy"; + case COMMON_REASONING_FORMAT_NONE: + return "none"; + case COMMON_REASONING_FORMAT_AUTO: + return "auto"; + case COMMON_REASONING_FORMAT_DEEPSEEK: + return "deepseek"; + case COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY: + return "deepseek-legacy"; default: throw std::runtime_error("Unknown reasoning format"); } @@ -762,11 +715,14 @@ const char * common_reasoning_format_name(common_reasoning_format format) { common_reasoning_format common_reasoning_format_from_name(const std::string & format) { if (format == "none") { return COMMON_REASONING_FORMAT_NONE; - } else if (format == "auto") { + } + if (format == "auto") { return COMMON_REASONING_FORMAT_AUTO; - } else if (format == "deepseek") { + } + if (format == "deepseek") { return COMMON_REASONING_FORMAT_DEEPSEEK; - } else if (format == "deepseek-legacy") { + } + if (format == "deepseek-legacy") { return COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY; } throw std::runtime_error("Unknown reasoning format: " + format); @@ -782,7 +738,8 @@ static void foreach_function(const json & tools, const std::function & fn) { +static void foreach_parameter(const json & function, + const std::function & fn) { if (!function.contains("parameters") || !function.at("parameters").is_object()) { return; } @@ -790,7 +747,7 @@ static void foreach_parameter(const json & function, const std::function required; if (params.contains("required") && params.at("required").is_array()) { params.at("required").get_to(required); @@ -801,19 +758,19 @@ static void foreach_parameter(const json & function, const std::function & messages_override = std::nullopt, - const std::optional & tools_override = std::nullopt, - const std::optional & additional_context = std::nullopt) -{ + const autoparser::templates_params & inputs, + const std::optional & messages_override, + const std::optional & tools_override, + const std::optional & additional_context) { jinja::context ctx(tmpl.source()); nlohmann::ordered_json inp = nlohmann::ordered_json{ {"messages", messages_override.has_value() ? *messages_override : inputs.messages}, {"bos_token", tmpl.bos_token()}, {"eos_token", tmpl.eos_token()}, + {"enable_thinking", inputs.enable_thinking}, }; if (tools_override.has_value() || !inputs.tools.empty()) { inp["tools"] = tools_override.has_value() ? *tools_override : inputs.tools; @@ -839,7 +796,7 @@ static std::string apply( // render jinja::runtime runtime(ctx); const jinja::value results = runtime.execute(tmpl.prog); - auto parts = runtime.gather_string_parts(results); + auto parts = jinja::runtime::gather_string_parts(results); std::string result = parts->as_string().str(); @@ -853,265 +810,8 @@ static std::string apply( return result; } -static common_chat_params common_chat_params_init_generic(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - - auto tool_call_schemas = json::array(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - auto tool_schema = json { - {"type", "object"}, - {"properties", { - {"name", { - {"type", "string"}, - {"const", function.at("name")}, - }}, - {"arguments", function.at("parameters")}, - }}, - {"required", json::array({"name", "arguments"})}, - }; - if (function.contains("description")) { - tool_schema["description"] = function.at("description"); - } - if (inputs.parallel_tool_calls) { - tool_schema.at("properties")["id"] = { - {"type", "string"}, - {"minLength", 4}, - }; - tool_schema.at("required").push_back("id"); - } - tool_call_schemas.emplace_back(tool_schema); - }); - const auto tool_call = - inputs.parallel_tool_calls - ? json { - {"type", "object"}, - {"properties", { - {"tool_calls", { - {"type", "array"}, - {"items", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json { - {"anyOf", tool_call_schemas}, - }}, - {"minItems", 1}, - }}, - }}, - {"required", json::array({"tool_calls"})}, - } - : json { - {"type", "object"}, - {"properties", { - {"tool_call", tool_call_schemas.size() == 1 ? tool_call_schemas[0] : json { - {"anyOf", tool_call_schemas}, - }}, - }}, - {"required", json::array({"tool_call"})}, - }; - const auto schema = - inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED - ? json { - {"anyOf", json::array({ - tool_call, - { - {"type", "object"}, - {"properties", { - {"response", inputs.json_schema.is_null() - ? json {{"type", "string"}} - : inputs.json_schema - }, - }}, - {"required", json::array({"response"})}, - }, - })} - } - : tool_call; - - data.grammar_lazy = false; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - builder.add_schema("root", schema); - }); - - auto tweaked_messages = tmpl.add_system( - inputs.messages, - "Respond in JSON format, either with `tool_call` (a request to call tools) or with `response` reply to the user's request"); - - // ensure all messages has "content" field - for (auto & message : tweaked_messages) { - if (!message.contains("content") || message["content"].is_null()) { - message["content"] = ""; - } - } - - data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages); - data.format = COMMON_CHAT_FORMAT_GENERIC; - return data; -} - -static common_chat_params common_chat_params_init_mistral_nemo(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schemas = json::array(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - schemas.push_back({ - {"type", "object"}, - {"properties", { - // Important note: the model is probably trained to take a JSON stringified arguments value. - // It's hard to constrain that for now (while reusing the JSON schema conversion), so we're just expecting a plain object. - {"name", { - {"type", "string"}, - {"const", function.at("name")}, - }}, - {"arguments", function.at("parameters")}, - {"id", { - {"type", "string"}, - // Nemo's template expects a 9-character alphanumeric ID. - {"pattern", "^[a-zA-Z0-9]{9}$"}, - }}, - }}, - {"required", json::array({"name", "arguments", "id"})}, - }); - }); - auto schema = json { - {"type", "array"}, - {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, - {"minItems", 1}, - }; - if (!inputs.parallel_tool_calls) { - schema["maxItems"] = 1; - } - builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema)); - }); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"}); - data.preserved_tokens = { - "[TOOL_CALLS]", - }; - data.prompt = apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO; - return data; -} - - -// Case-insensitive find -static size_t ifind_string(const std::string & haystack, const std::string & needle, size_t pos = 0) { - auto it = std::search( - haystack.begin() + pos, haystack.end(), - needle.begin(), needle.end(), - [](char a, char b) { return std::tolower(a) == std::tolower(b); } - ); - return (it == haystack.end()) ? std::string::npos : std::distance(haystack.begin(), it); -} - -static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - const auto is_json_schema_provided = !inputs.json_schema.is_null(); - const auto is_grammar_provided = !inputs.grammar.empty(); - const auto are_tools_provided = inputs.tools.is_array() && !inputs.tools.empty(); - - // the logic requires potentially modifying the messages - auto tweaked_messages = inputs.messages; - - auto replace_json_schema_marker = [](json & messages) -> bool { - static std::string marker1 = "force json schema.\n"; - static std::string marker2 = "force json schema."; - - if (messages.empty() || messages.at(0).at("role") != "system") { - return false; - } - - std::string content = messages.at(0).at("content"); - - for (const auto & marker : {marker1, marker2}) { - const auto pos = ifind_string(content, marker); - if (pos != std::string::npos) { - content.replace(pos, marker.length(), ""); - // inject modified content back into the messages - messages.at(0).at("content") = content; - return true; - } - } - - return false; - }; - - // Lfm2 model does not natively work with json, but can generally understand the tools structure - // - // Example of the pytorch dialog structure: - // <|startoftext|><|im_start|>system - // List of tools: <|tool_list_start|>[{"name": "get_candidate_status", "description": "Retrieves the current status of a candidate in the recruitment process", "parameters": {"type": "object", "properties": {"candidate_id": {"type": "string", "description": "Unique identifier for the candidate"}}, "required": ["candidate_id"]}}]<|tool_list_end|><|im_end|> - // <|im_start|>user - // What is the current status of candidate ID 12345?<|im_end|> - // <|im_start|>assistant - // <|tool_call_start|>[get_candidate_status(candidate_id="12345")]<|tool_call_end|>Checking the current status of candidate ID 12345.<|im_end|> - // <|im_start|>tool - // <|tool_response_start|>{"candidate_id": "12345", "status": "Interview Scheduled", "position": "Clinical Research Associate", "date": "2023-11-20"}<|tool_response_end|><|im_end|> - // <|im_start|>assistant - // The candidate with ID 12345 is currently in the "Interview Scheduled" stage for the position of Clinical Research Associate, with an interview date set for 2023-11-20.<|im_end|> - // - // For the llama server compatibility with json tools semantic, - // the client can add "Follow json schema." line into the system message prompt to force the json output. - // - if (are_tools_provided && (is_json_schema_provided || is_grammar_provided)) { - // server/utils.hpp prohibits that branch for the custom grammar anyways - throw std::runtime_error("Tools call must not use \"json_schema\" or \"grammar\", use non-tool invocation if you want to use custom grammar"); - } else if (are_tools_provided && replace_json_schema_marker(tweaked_messages)) { - LOG_INF("%s: Using tools to build a grammar\n", __func__); - - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schemas = json::array(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - schemas.push_back({ - {"type", "object"}, - {"properties", { - {"name", { - {"type", "string"}, - {"const", function.at("name")}, - }}, - {"arguments", function.at("parameters")}, - }}, - {"required", json::array({"name", "arguments", "id"})}, - }); - }); - auto schema = json { - {"type", "array"}, - {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, - {"minItems", 1}, - }; - if (!inputs.parallel_tool_calls) { - schema["maxItems"] = 1; - } - - builder.add_rule("root", "\"<|tool_call_start|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tool_call_end|>\""); - }); - // model has no concept of tool selection mode choice, - // if the system prompt rendered correctly it will produce a tool call - // the grammar goes inside the tool call body - data.grammar_lazy = true; - data.grammar_triggers = {{COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, "\\s*<\\|tool_call_start\\|>\\s*\\["}}; - data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"}; - data.format = COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS; - } else if (are_tools_provided && (!is_json_schema_provided && !is_grammar_provided)) { - LOG_INF("%s: Using tools without json schema or grammar\n", __func__); - // output those tokens - data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"}; - } else if (is_json_schema_provided) { - LOG_INF("%s: Using provided json schema to build a grammar\n", __func__); - data.grammar = json_schema_to_grammar(inputs.json_schema); - } else if (is_grammar_provided) { - LOG_INF("%s: Using provided grammar\n", __func__); - data.grammar = inputs.grammar; - } else { - LOG_INF("%s: Using content relying on the template\n", __func__); - } - - data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages); - LOG_DBG("%s: Prompt: %s\n", __func__, data.prompt.c_str()); - - return data; -} - -static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl, const struct templates_params & inputs) { +static common_chat_params common_chat_params_init_ministral_3(const common_chat_template & tmpl, + const autoparser::templates_params & inputs) { common_chat_params data; // Build up messages to follow the format: https://huggingface.co/mistralai/Ministral-3-14B-Reasoning-2512/blob/main/chat_template.jinja @@ -1129,8 +829,8 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ // If message contains `reasoning_content`, add it as a block of type `thinking` if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) { content.push_back({ - {"type", "thinking"}, - {"thinking", msg.at("reasoning_content").get()}, + { "type", "thinking" }, + { "thinking", msg.at("reasoning_content").get() }, }); } @@ -1138,8 +838,8 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ if (msg.contains("content")) { if (msg.at("content").is_string()) { content.push_back({ - {"type", "text"}, - {"text", msg.at("content").get()}, + { "type", "text" }, + { "text", msg.at("content").get() }, }); } else if (msg.at("content").is_array()) { auto blocks = msg.at("content"); @@ -1147,32 +847,35 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ } } - auto adjusted = msg; + auto adjusted = msg; adjusted["content"] = content; adjusted.erase("reasoning_content"); adjusted_messages.push_back(adjusted); } - auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; - auto include_grammar = true; + auto include_grammar = true; - data.prompt = apply(tmpl, inputs, /* messages_override = */ adjusted_messages); - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.preserved_tokens = { + data.supports_thinking = true; + data.prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override = */ adjusted_messages); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.preserved_tokens = { "[THINK]", "[/THINK]", "[TOOL_CALLS]", "[ARGS]", }; - auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { - auto reasoning = extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps(); + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + auto reasoning = + extract_reasoning ? p.optional("[THINK]" + p.reasoning(p.until("[/THINK]")) + "[/THINK]") : p.eps(); // Response format parser if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) { // Ministral wants to emit json surrounded by code fences - return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) << "```"; + return reasoning << "```json" << p.content(p.schema(p.json(), "response-format", inputs.json_schema)) + << "```"; } // Tool call parser @@ -1180,17 +883,16 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ auto tool_choice = p.choice(); foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool.at("function"); - std::string name = function.at("name"); - const auto & schema = function.at("parameters"); + std::string name = function.at("name"); + const auto & schema = function.at("parameters"); - tool_choice |= p.rule("tool-" + name, - p.tool_open(p.tool_name(p.literal(name)) + "[ARGS]") - + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)) - ); + tool_choice |= + p.rule("tool-" + name, p.tool_open(p.tool_name(p.literal(name)) + "[ARGS]") + + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))); }); - auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0; - auto max_calls = inputs.parallel_tool_calls ? -1 : 1; + auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0; + auto max_calls = inputs.parallel_tool_calls ? -1 : 1; auto tool_calls = p.trigger_rule("tool-call", p.repeat("[TOOL_CALLS]" + tool_choice, min_calls, max_calls)); return reasoning << p.content(p.until("[TOOL_CALLS]")) << tool_calls; @@ -1209,1722 +911,369 @@ static common_chat_params common_chat_params_init_ministral_3(const common_chat_ data.grammar = build_grammar([&](const common_grammar_builder & builder) { foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool.at("function"); - auto schema = function.at("parameters"); + auto schema = function.at("parameters"); builder.resolve_refs(schema); }); parser.build_grammar(builder, data.grammar_lazy); }); data.grammar_triggers = { - {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"} + { COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]" } }; } return data; } -static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - data.prompt = apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_MAGISTRAL; - data.preserved_tokens = { - "[THINK]", - "[/THINK]", - }; - - if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schemas = json::array(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - schemas.push_back({ - {"type", "object"}, - {"properties", { - {"name", { - {"type", "string"}, - {"const", function.at("name")}, - }}, - {"arguments", function.at("parameters")}, - {"id", { - {"type", "string"}, - {"pattern", "^[a-zA-Z0-9]{9}$"}, - }}, - }}, - {"required", json::array({"name", "arguments", "id"})}, - }); - }); - auto schema = json { - {"type", "array"}, - {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, - {"minItems", 1}, - }; - if (!inputs.parallel_tool_calls) { - schema["maxItems"] = 1; - } - builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema)); - }); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"}); - data.preserved_tokens.push_back("[TOOL_CALLS]"); - } else { - data.grammar_lazy = false; - if (!inputs.json_schema.is_null()) { - if (!inputs.grammar.empty()) { - throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); - } - data.grammar = json_schema_to_grammar(inputs.json_schema); - } else { - data.grammar = inputs.grammar; - } - } - - return data; -} - -static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) { +static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, + const autoparser::templates_params & inputs) { common_chat_params data; + // Copy reasoning to the "thinking" field as expected by the gpt-oss template auto adjusted_messages = json::array(); for (const auto & msg : inputs.messages) { auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string(); - auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array(); + auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array(); + if (has_reasoning_content && has_tool_calls) { - auto adjusted_message = msg; - adjusted_message["tool_plan"] = msg.at("reasoning_content"); - adjusted_message.erase("reasoning_content"); + auto adjusted_message = msg; + adjusted_message["thinking"] = msg.at("reasoning_content"); adjusted_messages.push_back(adjusted_message); } else { adjusted_messages.push_back(msg); } } - data.prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages); - data.format = COMMON_CHAT_FORMAT_COMMAND_R7B; - if (string_ends_with(data.prompt, "<|START_THINKING|>")) { - if (!inputs.enable_thinking) { - data.prompt += "<|END_THINKING|>"; - } else { - data.thinking_forced_open = true; - } - } else if (!inputs.enable_thinking && string_ends_with(data.prompt, "<|CHATBOT_TOKEN|>")) { - data.prompt += "<|START_THINKING|><|END_THINKING|>"; - } - - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schemas = json::array(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - schemas.push_back({ - {"type", "object"}, - {"properties", { - {"tool_call_id", { - {"type", "string"}, - // Command-R's template expects an integer string. - {"pattern", "^[0-9]{1,10}$"}, - }}, - {"tool_name", { - {"type", "string"}, - {"const", function.at("name")}, - }}, - {"parameters", function.at("parameters")}, - }}, - {"required", json::array({"tool_call_id", "tool_name", "parameters"})}, - }); - }); - auto schema = json { - {"type", "array"}, - {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, - {"minItems", 1}, - }; - if (!inputs.parallel_tool_calls) { - schema["maxItems"] = 1; - } - builder.add_rule("root", - std::string(data.thinking_forced_open ? "( \"<|END_THINKING|>\" space )? " : "") + - "\"<|START_ACTION|>\" " + builder.add_schema("tool_calls", schema) + " \"<|END_ACTION|>\""); - }); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - // If thinking_forced_open, then we capture the tag in the grammar, - // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) - std::string(data.thinking_forced_open ? "[\\s\\S]*?(<\\|END_THINKING\\|>\\s*)" : "(?:<\\|START_THINKING\\|>[\\s\\S]*?<\\|END_THINKING\\|>\\s*)?") + - "(<\\|START_ACTION\\|>)[\\s\\S]*" - }); - data.preserved_tokens = { - "<|START_ACTION|>", - "<|END_ACTION|>", - "<|START_RESPONSE|>", - "<|END_RESPONSE|>", - "<|START_THINKING|>", - "<|END_THINKING|>", - }; - return data; -} - -static void expect_tool_parameters(const std::string & name, const json & parameters, const std::vector & expected_properties) { - if (!parameters.is_object() || !parameters.contains("type") || parameters.at("type") != "object" || !parameters.contains("properties") || !parameters.contains("required")) { - throw std::runtime_error("Parameters of tool " + name + " must be an object w/ required properties"); - } - const auto & parameters_properties = parameters.at("properties"); - const auto & parameters_required = parameters.at("required"); - for (const auto & prop : expected_properties) { - if (!parameters_properties.contains(prop)) { - throw std::runtime_error("Parameters of tool " + name + " is missing property: " + prop); // NOLINT - } - if (std::find(parameters_required.begin(), parameters_required.end(), json(prop)) == parameters_required.end()) { - throw std::runtime_error("Parameters of tool " + name + " must have property marked as required: " + prop); // NOLINT - } - } - if (parameters_properties.size() != expected_properties.size()) { - throw std::runtime_error("Parameters of tool " + name + " must only have these properties:" + string_join(expected_properties, ", ")); - } -} - -static common_chat_params common_chat_params_init_llama_3_x(const common_chat_template & tmpl, const struct templates_params & inputs, bool allow_python_tag_builtin_tools) { - auto builtin_tools = json::array(); - common_chat_params data; - if (!inputs.tools.is_null()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - - auto handle_builtin_tool = [&](const std::string & name, const json & parameters) { - if (name == "wolfram_alpha" || name == "web_search" || name == "brave_search") { - // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/wolfram_alpha/wolfram_alpha.py - // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/tool_runtime/brave_search/brave_search.py - expect_tool_parameters(name, parameters, {"query"}); - } else if (name == "python" || name == "code_interpreter") { - // https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py - expect_tool_parameters(name, parameters, {"code"}); - } else { - return false; - } - - std::vector kvs; - for (const auto & [key, value] : parameters.at("properties").items()) { - kvs.push_back("\"" + key + "=\" " + builder.add_schema(name + "-args-" + key, value)); // NOLINT - } - - tool_rules.push_back( - builder.add_rule( - name + "-call", - "\"<|python_tag|>" + name + ".call(\" " + string_join(kvs, " \", \" ") + " \")\"")); - builtin_tools.push_back(name); - - return true; - }; - - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - - // https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote/tool_runtime - if (allow_python_tag_builtin_tools) { - handle_builtin_tool(name, parameters); - } - tool_rules.push_back( - builder.add_rule( - name + "-call", - "\"{\" space " - "( \"\\\"type\\\"\" space \":\" space \"\\\"function\\\"\" space \",\" space )? " - " \"\\\"name\\\"\" space \":\" space \"\\\"" + name + "\\\"\" space \",\" space " - " \"\\\"parameters\\\"\" space \":\" space " + builder.add_schema(name + "-args", parameters) + " " - "\"}\" space")); - }); - // Small models may hallucinate function names so we match anything (*at the start*) that looks like the JSON of a function call, regardless of the name. - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - "(\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\")[\\s\\S]*", // + name + "\"[\\s\\S]*", - }); - if (!builtin_tools.empty()) { - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"}); - data.preserved_tokens.push_back("<|python_tag|>"); - } - // Allow a few empty lines on top of the usual constrained json schema space rule. - builder.add_rule("root", string_join(tool_rules, " | ")); - data.additional_stops.push_back("<|eom_id|>"); - }); - data.format = allow_python_tag_builtin_tools && !builtin_tools.empty() - ? COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS - : COMMON_CHAT_FORMAT_LLAMA_3_X; - } else { - data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - } - data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override= */ std::nullopt, json { - {"date_string", format_time(inputs.now, "%d %b %Y")}, - {"tools_in_user_message", false}, - {"builtin_tools", builtin_tools}, - }); - return data; -} - -static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - // Generate the prompt using the apply() function with the template - data.prompt = apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_NEMOTRON_V2; + auto prompt = common_chat_template_direct_apply(tmpl, inputs, /* messages_override= */ adjusted_messages); - // Handle thinking tags appropriately based on inputs.enable_thinking - if (string_ends_with(data.prompt, "\n")) { - if (!inputs.enable_thinking) { - data.prompt += ""; - } else { - data.thinking_forced_open = true; + // Check if we need to replace the return token with end token during + // inference and without generation prompt. For more details see: + // https://github.com/ggml-org/llama.cpp/issues/15417 + if (inputs.is_inference && !inputs.add_generation_prompt) { + static constexpr std::string_view return_token = "<|return|>"; + static constexpr std::string_view end_token = "<|end|>"; + if (size_t pos = prompt.rfind(return_token); pos != std::string::npos) { + prompt.replace(pos, return_token.length(), end_token); } } - // When tools are present, build grammar for the format, similar to CommandR, but without tool call ID - if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = true; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schemas = json::array(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - schemas.push_back({ - { "type", "object" }, - { "properties", - { - { "name", - { - { "type", "string" }, - { "const", function.at("name") }, - } }, - { "arguments", function.at("parameters") }, - } }, - { "required", json::array({ "name", "arguments" }) }, - }); - }); - auto schema = json{ - { "type", "array" }, - { "items", schemas.size() == 1 ? schemas[0] : json{ { "anyOf", schemas } } }, - { "minItems", 1 }, - }; - if (!inputs.parallel_tool_calls) { - schema["maxItems"] = 1; - } - builder.add_rule("root", - std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + - "\"\" " + builder.add_schema("tool_calls", schema) + - " \"\""); - }); - data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - // If thinking_forced_open, then we capture the tag in the grammar, - // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) - std::string(data.thinking_forced_open ? - "[\\s\\S]*?(\\s*)" : - "(?:[\\s\\S]*?\\s*)?") + - "()[\\s\\S]*" }); - } - return data; -} - -static common_chat_params common_chat_params_init_qwen3_coder(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - - data.prompt = apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_PEG_CONSTRUCTED; - - // Nemotron Nano 3 and Step-3.5-Flash use the Qwen3 Coder tool calling with thinking - bool supports_reasoning = (tmpl.source().find("") != std::string::npos); - - // Handle thinking tags appropriately based on inputs.enable_thinking - if (supports_reasoning && string_ends_with(data.prompt, "\n")) { - if (!inputs.enable_thinking) { - data.prompt += ""; - } else { - data.thinking_forced_open = true; - } - } + data.prompt = prompt; + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.supports_thinking = true; + // These special tokens are required to parse properly, so we include them + // even if parse_tool_calls is false. data.preserved_tokens = { - "", - "", + "<|channel|>", "<|constrain|>", "<|message|>", "<|start|>", "<|end|>", }; - if (supports_reasoning) { - data.preserved_tokens.insert(data.preserved_tokens.end(), {"", ""}); - } - - auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; - auto include_grammar = true; - - auto parser = build_chat_peg_constructed_parser([&](auto & p) { - auto reasoning = p.eps(); - if (supports_reasoning && inputs.enable_thinking && extract_reasoning) { - auto reasoning_content = p.reasoning(p.until("")) + ("" | p.end()); - if (data.thinking_forced_open) { - reasoning = reasoning_content; - } - } + auto include_grammar = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && has_tools; + + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + const std::string END = "<|end|>"; + const std::string START = "<|start|>"; + const std::string MESSAGE = "<|message|>"; + const std::string CHANNEL = "<|channel|>"; + const std::string CONSTRAIN = "<|constrain|>"; + const std::string START_ASSISTANT = START + "assistant"; + const std::string CHANNEL_ANALYSIS = CHANNEL + "analysis"; + const std::string CHANNEL_COMMENTARY = CHANNEL + "commentary"; + const std::string CHANNEL_FINAL = CHANNEL + "final"; + + auto the_end = END | p.end(); + + const std::string analysis_header = CHANNEL_ANALYSIS + MESSAGE; + auto segment_content = p.until(END); + auto analysis_segment = extract_reasoning ? + p.literal(analysis_header) + p.reasoning(segment_content) + p.until(END) + the_end : + p.content(analysis_header + p.until(END) + the_end); + + auto channel_header_content = p.until_one_of({ " to=functions.", MESSAGE }); + auto content_header = p.choice({ p.literal(CHANNEL_COMMENTARY), p.literal(CHANNEL_FINAL) }); + auto content_segment = p.rule("content-segment", content_header + channel_header_content + MESSAGE + + p.content(segment_content) + the_end); - // Response format parser - if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) { - return reasoning << p.content(p.schema(p.json(), "response-format", inputs.json_schema)); + if (!inputs.json_schema.is_null()) { + auto final_header = p.literal(CHANNEL_FINAL); + auto constraint = p.optional(p.space() + p.literal(CONSTRAIN) + channel_header_content); + return p.optional(analysis_segment) + final_header + constraint + MESSAGE + + p.content(p.schema(p.json(), "response-format", inputs.json_schema)); } + auto segment = p.optional(START_ASSISTANT + p.space()) + p.choice({ content_segment, analysis_segment }); + auto contents = p.optional(segment + p.repeat(p.optional(p.space()) + segment, 0, -1)) + p.end(); + // Tool call parser if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { auto tool_choice = p.choice(); + foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - - auto schema_info = common_schema_info(); - schema_info.resolve_refs(parameters); - - auto tool_open = "\n"; - auto tool_close = p.literal("\n"); - auto args = p.sequence(); - auto arg_string = p.rule("xml-arg-string", p.until_one_of({ - "\n", - "\n" - })); - - foreach_parameter(function, [&](const auto & param_name, const json & param_schema, bool is_required) { - auto rule_name = "tool-" + name + "-arg-" + param_name; - - auto arg_open = "\n"; - auto arg_close = p.literal("\n"); - auto arg_value = p.eps(); - - if (schema_info.resolves_to_string(param_schema)) { - arg_value = p.tool_arg_string_value(arg_string) + "\n"; - } else { - arg_value = p.tool_arg_json_value(p.schema(p.json(), rule_name + "-schema", param_schema)); - } + std::string name = function.at("name"); + const auto & params = function.at("parameters"); - // Model may or my not close with - auto arg_rule = p.rule(rule_name, p.tool_arg_open(arg_open) + arg_value + p.optional(p.tool_arg_close(arg_close))); - args += p.repeat(arg_rule, /* min = */ is_required ? 1 : 0, /* max = */ 1); - }); + // Tool call can appear as: + // 1. In role header: " to=functions.NAME<|channel|>..." + // 2. In channel: "<|channel|>(analysis|commentary) to=functions.NAME..." + auto func_name = p.literal(" to=functions.") + p.tool_name(p.literal(name)); + + auto channel = p.literal(CHANNEL_COMMENTARY) | p.literal(CHANNEL_ANALYSIS); + auto constraint = p.space() + p.optional(p.literal(CONSTRAIN) + channel_header_content); + auto args = p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", params)); - tool_choice |= p.rule("tool-" + name, p.tool_open(tool_open) + args + p.tool_close(tool_close)); + // Pattern 1: recipient in role header + // " to=functions.NAME<|channel|>(analysis|commentary)[constraint]<|message|>ARGS" + auto tool_in_role = p.tool(p.tool_open(func_name + channel) + constraint + MESSAGE + args); + + // Pattern 2: recipient in channel header + // "<|channel|>(analysis|commentary) to=functions.NAME[constraint]<|message|>ARGS" + auto tool_in_channel = p.tool(channel + p.tool_open(func_name + constraint + MESSAGE) + args); + + tool_choice |= tool_in_role | tool_in_channel; }); auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0; auto max_calls = inputs.parallel_tool_calls ? -1 : 1; - auto tool_call = p.rule("tool-call", "\n" + tool_choice + "" + p.space()); - auto tool_calls = p.trigger_rule("tool-call-root", p.repeat(tool_call, /* min = */ min_calls, /* max = */ max_calls)); - return reasoning << p.content(p.until("")) << tool_calls; + auto role_start = p.optional(p.space() + p.literal(START_ASSISTANT)); + auto tool_call = p.rule("tool-call", p.repeat(role_start + tool_choice, min_calls, max_calls) + p.end()); + + return p.choice({ p.trigger_rule("single-tool", tool_call), p.trigger_rule("tools", p.one_or_more(segment) + tool_call) }); } - // Content only parser - include_grammar = false; - return reasoning << p.content(p.rest()); + return contents; }); data.parser = parser.save(); if (include_grammar) { data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO; - - data.grammar = build_grammar([&](const common_grammar_builder & builder) { + data.grammar = build_grammar([&](const common_grammar_builder & builder) { foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool.at("function"); - auto schema = function.at("parameters"); + auto schema = function.at("parameters"); builder.resolve_refs(schema); }); parser.build_grammar(builder, data.grammar_lazy); }); data.grammar_triggers = { - {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, ""} + { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "^(?:<\\|start\\|>assistant\\s*)?(\\s+to=functions)" }, + { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, "(?:<\\|end\\|>)(?:<\\|start\\|>assistant\\s*)?(\\s+to=functions)" }, + { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, + "(?:<\\|start\\|>assistant\\s*)?(<\\|channel\\|>(?:commentary|analysis)\\s+to=functions)" } }; } return data; } - -static common_chat_params common_chat_params_init_apertus(const common_chat_template & tmpl, const struct templates_params & inputs) { +// Functionary v3.2 - uses recipient-based format: >>>recipient\n{content} +static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, + const autoparser::templates_params & inputs) { common_chat_params data; - // Generate the prompt using the apply() function with the template - data.prompt = apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_APERTUS; + data.prompt = common_chat_template_direct_apply(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.preserved_tokens = { + ">>>all", + }; - // Handle thinking tags appropriately based on inputs.enable_thinking - if (string_ends_with(data.prompt, "<|inner_prefix|>")) { - if (!inputs.enable_thinking) { - data.prompt += "<|inner_suffix|>"; - } else { - data.thinking_forced_open = true; - } - } + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; - // When tools are present, build grammar for the <|tools_prefix|> format - if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = true; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schemas = json::array(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - schemas.push_back({ - { "type", "object" }, - { "properties", - { - { function.at("name"), function.at("parameters") } - } }, - { "required", json::array({ function.at("name") }) }, - }); - }); - auto schema = json{ - { "type", "array" }, - { "items", schemas.size() == 1 ? schemas[0] : json{ { "anyOf", schemas } } }, - { "minItems", 1 }, - }; - if (!inputs.parallel_tool_calls) { - schema["maxItems"] = 1; - } - builder.add_rule("root", - std::string(data.thinking_forced_open ? "( \"<|inner_suffix|>\" space )? " : "") + - "\"<|tools_prefix|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tools_suffix|>\""); - }); - data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - // If thinking_forced_open, then we capture the <|inner_suffix|> tag in the grammar, - // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) - std::string(data.thinking_forced_open ? - "[\\s\\S]*?(<\\|inner_suffix\\|>\\s*)" : - "(?:<\\|inner_prefix\\|>[\\s\\S]*?<\\|inner_suffix\\|>\\s*)?") + - "(<\\|tools_prefix\\|>)[\\s\\S]*" }); - data.preserved_tokens = { - "<|system_start|>", - "<|system_end|>", - "<|developer_start|>", - "<|developer_end|>", - "<|user_start|>", - "<|user_end|>", - "<|assistant_start|>", - "<|assistant_end|>", - "<|inner_prefix|>", - "<|inner_suffix|>", - "<|tools_prefix|>", - "<|tools_suffix|>", - }; - } - return data; -} - -static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - auto prompt = apply(tmpl, inputs); - - // Hacks to fix the official (broken) prompt. - // It is advisable to use --chat-template-file models/templates/llama-cpp-deepseek-r1.jinja instead, - // until the official template is fixed. - if (tmpl.source().find("{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}") != std::string::npos) { - // Don't leave the chat dangling after tool results - if (string_ends_with(prompt, "<|tool▁outputs▁end|>")) { - prompt += "<|end▁of▁sentence|>"; - if (inputs.add_generation_prompt) { - prompt += "<|Assistant|>"; - } - } - // Fix up tool call delta example added by Minja - prompt = std::regex_replace( - prompt, - std::regex("(<|tool▁call▁end|>)[\\s\\r\\n]*(<|tool▁outputs▁begin|>|<|User|>)"), - "$1<|tool▁calls▁end|><|end▁of▁sentence|>$2"); - } - data.prompt = prompt; - data.format = COMMON_CHAT_FORMAT_DEEPSEEK_R1; - if (string_ends_with(data.prompt, "\n")) { - if (!inputs.enable_thinking) { - data.prompt += ""; - } else { - data.thinking_forced_open = true; - } - } - - if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null(); - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - tool_rules.push_back(builder.add_rule(name + "-call", - "( \"<|tool▁call▁begin|>\" )? \"function<|tool▁sep|>" + name + "\\n" - "```json\\n\" " + builder.add_schema(name + "-args", parameters) + " " - "\"```<|tool▁call▁end|>\"")); - }); - // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, - // so we accept common variants (then it's all constrained) - builder.add_rule("root", - std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + - "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" | \"<|tool▁calls|>\" ) " - "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " " - "\"<|tool▁calls▁end|>\"" - " space"); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - // If thinking_forced_open, then we capture the tag in the grammar, - // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) - std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + - "(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)[\\s\\S]*" - }); - data.preserved_tokens = { - "", - "", - "<|tool▁calls▁begin|>", - "<|tool▁call▁begin|>", - "<|tool▁sep|>", - "<|tool▁call▁end|>", - "<|tool▁calls▁end|", - }; - }); - } - return data; -} - -static common_chat_params common_chat_params_init_deepseek_v3_1(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - - // Pass thinking context for DeepSeek V3.1 template - json additional_context = { - {"thinking", inputs.enable_thinking}, - }; - - auto prompt = apply(tmpl, inputs, - /* messages_override= */ inputs.messages, - /* tools_override= */ std::nullopt, - additional_context); - data.prompt = prompt; - data.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; - if (string_ends_with(data.prompt, "")) { - if (!inputs.enable_thinking) { - data.prompt += ""; - } else { - data.thinking_forced_open = true; - } - } - if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null(); - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - tool_rules.push_back(builder.add_rule(name + "-call", - "( \"<|tool▁call▁begin|>\" )? \"" + name + "<|tool▁sep|>" - "\" " + builder.add_schema(name + "-args", parameters) + " " - "\"<|tool▁call▁end|>\"")); - }); - // Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag, - // so we accept common variants (then it's all constrained) - builder.add_rule("root", - std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + - "( \"<|tool▁calls▁begin|>\" | \"<|tool_calls_begin|>\" | \"<|tool calls begin|>\" | \"<|tool\\\\_calls\\\\_begin|>\" | \"<|tool▁calls|>\" ) " - "(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " " - "\"<|tool▁calls▁end|>\"" - " space"); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - // If thinking_forced_open, then we capture the tag in the grammar, - // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) - std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + - "(<|tool▁calls▁begin|>|<|tool_calls_begin|>|<|tool calls begin|>|<|tool\\\\_calls\\\\_begin|>|<|tool▁calls|>)[\\s\\S]*" - }); - data.preserved_tokens = { - "", - "", - "<|tool▁calls▁begin|>", - "<|tool▁call▁begin|>", - "<|tool▁sep|>", - "<|tool▁call▁end|>", - "<|tool▁calls▁end|>", - }; - }); - } - return data; -} - -static common_chat_params common_chat_params_init_minimax_m2(const common_chat_template & tmpl, const struct templates_params & params) { - common_chat_params data; - data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - - data.prompt = apply(tmpl, params); - data.format = COMMON_CHAT_FORMAT_MINIMAX_M2; - - // Handle thinking tags based on prompt ending - if (string_ends_with(data.prompt, "\n")) { - if (!params.enable_thinking) { - // Close the thinking tag immediately if thinking is disabled - data.prompt += "\n\n"; - } else { - // Mark thinking as forced open (template started with ) - data.thinking_forced_open = true; - } - } - - // Preserve MiniMax-M2 special tokens - data.preserved_tokens = { - "", - "", - "", - "", - }; - - // build grammar for tool call - static const xml_tool_call_format form { - /* form.scope_start = */ "\n", - /* form.tool_start = */ "\n", - /* form.key_start = */ "", - /* form.val_end = */ "\n", - /* form.tool_end = */ "\n", - /* form.scope_end = */ "", - }; - build_grammar_xml_tool_call(data, params.tools, form); - - return data; -} - -static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, const struct templates_params & params) { - common_chat_params data; - data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - - data.prompt = apply(tmpl, params); - data.format = COMMON_CHAT_FORMAT_KIMI_K2; - - data.preserved_tokens = { - "", - "", - "<|tool_calls_section_begin|>", - "<|tool_call_begin|>", - "<|tool_call_argument_begin|>", - "<|tool_call_end|>", - "<|tool_calls_section_end|>", - "<|im_end|>", - "<|im_system|>", - "<|im_middle|>", - }; - - data.additional_stops.insert(data.additional_stops.end(), { - "<|im_end|>", - "<|im_middle|>" - }); - // build grammar for tool call - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = "<|tool_calls_section_begin|>"; - form.tool_start = "<|tool_call_begin|>"; - form.tool_sep = "<|tool_call_argument_begin|>{"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}<|tool_call_end|>"; - form.scope_end = "<|tool_calls_section_end|>"; - form.raw_argval = false; - form.last_val_end = ""; - return form; - })(); - build_grammar_xml_tool_call(data, params.tools, form); - - return data; -} - -static common_chat_params common_chat_params_init_apriel_1_5(const common_chat_template & tmpl, const struct templates_params & params) { - common_chat_params data; - data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - - data.prompt = apply(tmpl, params); - data.format = COMMON_CHAT_FORMAT_APRIEL_1_5; - - data.preserved_tokens = { - "", - "", - "", - "", - }; - - // build grammar for tool call - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = "["; - form.tool_start = "{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}, "; - form.scope_end = "]"; - form.raw_argval = false; - form.last_val_end = ""; - form.last_tool_end = "}"; - return form; - })(); - build_grammar_xml_tool_call(data, params.tools, form); - - return data; -} - -static common_chat_params common_chat_params_init_xiaomi_mimo(const common_chat_template & tmpl, const struct templates_params & params) { - common_chat_params data; - data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - - data.prompt = apply(tmpl, params); - data.format = COMMON_CHAT_FORMAT_XIAOMI_MIMO; - - data.preserved_tokens = { - "", - "", - }; - - // build grammar for tool call - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = "\n"; - form.tool_start = "\n{\"name\": \""; - form.tool_sep = "\", \"arguments\": {"; - form.key_start = "\""; - form.key_val_sep = "\": "; - form.val_end = ", "; - form.tool_end = "}\n"; - form.scope_end = ""; - form.raw_argval = false; - form.last_val_end = ""; - return form; - })(); - build_grammar_xml_tool_call(data, params.tools, form); - - return data; -} - -static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - - // Copy reasoning to the "thinking" field as expected by the gpt-oss template - auto adjusted_messages = json::array(); - for (const auto & msg : inputs.messages) { - auto has_reasoning_content = msg.contains("reasoning_content") && msg.at("reasoning_content").is_string(); - auto has_tool_calls = msg.contains("tool_calls") && msg.at("tool_calls").is_array(); - - if (has_reasoning_content && has_tool_calls) { - auto adjusted_message = msg; - adjusted_message["thinking"] = msg.at("reasoning_content"); - adjusted_message.erase("content"); - adjusted_messages.push_back(adjusted_message); - } else { - adjusted_messages.push_back(msg); - } - } + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + // Functionary v3.2 format: + // - Normal content: >>>all\n{content} + // - Tool calls: >>>function_name\n{json_args} + // Generation prompt ends with ">>>" so model outputs recipient immediately - auto prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages); + // Build content parser for >>>all\n{content} + // When tools are present, content stops before the next ">>>" (tool call) + // When no tools, content goes until end + auto content_until_tool = p.literal(">>>all\n") + p.content(p.until(">>>")); + auto content_until_end = p.literal(">>>all\n") + p.content(p.rest()); - // Check if we need to replace the return token with end token during - // inference and without generation prompt. For more details see: - // https://github.com/ggml-org/llama.cpp/issues/15417 - if (inputs.is_inference && !inputs.add_generation_prompt) { - static constexpr std::string_view return_token = "<|return|>"; - static constexpr std::string_view end_token = "<|end|>"; - if (size_t pos = prompt.rfind(return_token); pos != std::string::npos) { - prompt.replace(pos, return_token.length(), end_token); + // If no tools or tool_choice is NONE, just parse content + if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { + // When no tools, just match the prefix and capture everything after + return content_until_end + p.end(); } - } - - data.prompt = prompt; - data.format = COMMON_CHAT_FORMAT_GPT_OSS; - // These special tokens are required to parse properly, so we include them - // even if parse_tool_calls is false. - data.preserved_tokens = { - "<|channel|>", - "<|constrain|>", - "<|message|>", - "<|start|>", - "<|end|>", - }; + // Build tool call parsers for each available function + auto tool_choice = p.choice(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + const auto & schema = function.at("parameters"); - if (!inputs.json_schema.is_null()) { - data.grammar_lazy = false; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schema = inputs.json_schema; - builder.resolve_refs(schema); - - auto not_end = builder.add_rule("not-end", - "[^<] | \"<\" [^|] | \"<|\" [^e] | \"<|e\" [^n] | \"<|en\" [^d] | \"<|end\" [^|] | \"<|end|\" [^>]"); - auto analysis = builder.add_rule("analysis", - "\"<|channel|>analysis<|message|>\" ( " + not_end + " )* \"<|end|>\""); - auto constraint = builder.add_rule("constraint", "\"<|constrain|>\"? [a-zA-Z0-9_-]+"); - auto final = builder.add_rule("final", - "\"<|channel|>final\" ( \" \" " + constraint + " )? \"<|message|>\" " + - builder.add_schema("response", schema) + // Tool format: >>>function_name\n{json_args} + auto tool_parser = p.tool( + p.tool_open(p.literal(">>>") + p.tool_name(p.literal(name)) + p.literal("\n")) + + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)) ); - builder.add_rule("root", "( " + analysis + " \"<|start|>assistant\" )? " + final); + tool_choice |= p.rule("tool-" + name, tool_parser); }); - } - - if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - // tool calls can appear in commentary or analysis channels - auto channel = builder.add_rule("channel", "\"<|channel|>\" ( \"commentary\" | \"analysis\" )"); - std::vector tool_rules_recipient_in_role; - std::vector tool_rules_recipient_in_channel; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - - tool_rules_recipient_in_role.push_back( - builder.add_rule(name + "-call", - "\"" + name + "\"" + channel + " \" <|constrain|>json\"? \"<|message|>\" " + - builder.add_schema(name + "-args", parameters) - ) - ); - - tool_rules_recipient_in_channel.push_back( - builder.add_rule(name + "-call", - "\"" + name + "\"" + " \" <|constrain|>json\"? \"<|message|>\" " + - builder.add_schema(name + "-args", parameters) - ) - ); - }); - - auto recipient_in_channel = builder.add_rule("recipient_in_channel", - channel + " \" to=functions.\" ( " + - string_join(tool_rules_recipient_in_channel, " | ") + " )" - ); + auto content_only = content_until_end; + auto tools_only = p.trigger_rule("tools", p.one_or_more(tool_choice)); + auto content_and_tools = content_until_tool + tools_only; - if (data.grammar_lazy) { - auto recipient_in_role = builder.add_rule("recipient_in_role", - "\"<|start|>assistant\"? \" to=functions.\" ( " + - string_join(tool_rules_recipient_in_role, " | ") + " )" - ); - - builder.add_rule("root", recipient_in_role + " | " + recipient_in_channel); - } else { - auto not_end = builder.add_rule("not-end", - "[^<] | \"<\" [^|] | \"<|\" [^e] | \"<|e\" [^n] | \"<|en\" [^d] | \"<|end\" [^|] | \"<|end|\" [^>]"); - auto analysis = builder.add_rule("analysis", - "\"<|channel|>analysis<|message|>\" ( " + not_end + " )* \"<|end|>\""); - auto commentary = builder.add_rule("commentary", - "\"<|channel|>commentary<|message|>\" ( " + not_end + " )* \"<|end|>\""); - - auto recipient_in_role = builder.add_rule("recipient_in_role", - "\" to=functions.\" ( " + string_join(tool_rules_recipient_in_role, " | ") + " )" - ); - - builder.add_rule("root", - "( " + analysis + " \"<|start|>assistant\" )? " + - "( " + commentary + " \"<|start|>assistant\" )? " + - "( " + recipient_in_role + " | " + recipient_in_channel + " )" - ); + if (inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED) { + if (inputs.parallel_tool_calls) { + return p.choice({ content_and_tools, tools_only }) + p.end(); } - - // Trigger on tool calls that appear in the commentary channel - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, - "<\\|channel\\|>(?:commentary|analysis) to" - }); - - // Trigger tool calls that appear in the role section, either at the - // start or in the middle. - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - "^ to" - }); - - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, - "<\\|start\\|>assistant to" - }); - }); - } - - return data; -} - -static common_chat_params common_chat_params_init_glm_4_5(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - data.grammar_lazy = inputs.tools.is_array() && !inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - - std::string prompt = apply(tmpl, inputs); - - // match the existing trimming behavior - if (inputs.add_bos && string_starts_with(prompt, tmpl.bos_token())) { - prompt.erase(0, tmpl.bos_token().size()); - } - if (inputs.add_eos && string_ends_with(prompt, tmpl.eos_token())) { - prompt.erase(prompt.size() - tmpl.eos_token().size()); - } - if (string_ends_with(prompt, "")) { - if (!inputs.enable_thinking) { - prompt += ""; - } else { - data.thinking_forced_open = true; + return p.choice({ content_until_tool + tool_choice, tools_only }) + p.end(); } - } - - // add GLM preserved tokens - data.preserved_tokens = { - "<|endoftext|>", - "[MASK]", - "[gMASK]", - "[sMASK]", - "", - "", - "<|system|>", - "<|user|>", - "<|assistant|>", - "<|observation|>", - "<|begin_of_image|>", - "<|end_of_image|>", - "<|begin_of_video|>", - "<|end_of_video|>", - "<|begin_of_audio|>", - "<|end_of_audio|>", - "<|begin_of_transcription|>", - "<|end_of_transcription|>", - "<|code_prefix|>", - "<|code_middle|>", - "<|code_suffix|>", - "/nothink", - "", - "", - "", - "", - "", - "", - "", - "" - }; - - // extra GLM 4.5 stop word - data.additional_stops.insert(data.additional_stops.end(), { - "<|user|>", - "<|observation|>" + if (inputs.parallel_tool_calls) { + return p.choice({ content_and_tools, content_only, tools_only }) + p.end(); + } + auto content_and_tool = content_until_tool + tool_choice; + return p.choice({ content_and_tool, content_only, tool_choice }) + p.end(); }); - // build grammar for tool call - static const xml_tool_call_format form { - /* form.scope_start = */ "", - /* form.tool_start = */ "\n", - /* form.tool_sep = */ "\n", - /* form.key_start = */ "", - /* form.key_val_sep = */ "\n", - /* form.val_end = */ "\n", - /* form.tool_end = */ "\n", - /* form.scope_end = */ "", - }; - build_grammar_xml_tool_call(data, inputs.tools, form); - - data.prompt = prompt; - data.format = COMMON_CHAT_FORMAT_GLM_4_5; - return data; -} - -static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) { - LOG_DBG("%s\n", __func__); - common_chat_params data; - const std::optional additional_context = json { - {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")}, - {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))}, - }; - data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override =*/ std::nullopt, additional_context); - if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - auto schemas = json::array(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - schemas.push_back({ - {"type", "object"}, - {"properties", { - {"name", { - {"type", "string"}, - {"const", function.at("name")}, - }}, - {"arguments", function.at("parameters")}, - }}, - {"required", json::array({"name", "arguments", "id"})}, - }); - }); - auto schema = json { - {"type", "array"}, - {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}}, - {"minItems", 1}, - }; - if (!inputs.parallel_tool_calls) { - schema["maxItems"] = 1; - } - builder.add_rule("root", "\" functools\"? " + builder.add_schema("tool_calls", schema)); - }); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, " functools["}); - data.preserved_tokens = { - " functools[", - }; - data.format = COMMON_CHAT_FORMAT_FIREFUNCTION_V2; - } else { - data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - } - return data; -} - -static common_chat_params common_chat_params_init_functionary_v3_2(const common_chat_template & tmpl, const struct templates_params & inputs) { - // >>>all\nlet's call functions>>>fn1\n{"arg1": 1...}\n>>>fn2\n{"arg1": 1...}... - // Using ">>>f1\n", ">>>f2\n"... as trigger words for the grammar - // If the function is python, we also allow raw python code (if the line after `python\n` doesn't start w/ opening `{`), which the model seems to prefer for multiline code. - common_chat_params data; - data.prompt = apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2; - if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector first_tool_rules; - std::vector subsequent_tool_rules; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - std::string args_pattern = "[\\s\\S]*"; - auto args_rule = builder.add_schema(name + "-args", parameters); - if (name == "python") { - args_rule = builder.add_rule(name + "-maybe-raw-args", args_rule + " | [^{] .*"); - } else { - args_pattern = "\\{" + args_pattern; - } - auto call_rule = builder.add_rule(name + "-call", "\"" + name + "\\n\" " + args_rule); - first_tool_rules.push_back(call_rule); - if (inputs.parallel_tool_calls) { - subsequent_tool_rules.push_back(builder.add_rule(name + "-call2", "\">>>\" " + call_rule)); - } - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - "((?:[\\s\\S]+?>>>)?" + regex_escape(name) + "\n)" + args_pattern, - }); - }); - data.preserved_tokens = { - "<|end_header_id|>", - }; - auto first_rule = first_tool_rules.empty() ? "" : builder.add_rule("first_tool_call", string_join(first_tool_rules, " | ")) + " space"; - if (inputs.parallel_tool_calls) { - auto subsequent_rule = builder.add_rule("subsequent_tool_call", string_join(subsequent_tool_rules, " | ")) + " space"; - builder.add_rule("root", first_rule + " (" + subsequent_rule + ")*"); - } else { - builder.add_rule("root", first_rule); - } - - }); - } - return data; -} - -static common_chat_params common_chat_params_init_functionary_v3_1_llama_3_1(const common_chat_template & tmpl, const struct templates_params & inputs) { - // https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v3-llama3.1.txt - common_chat_params data; + data.parser = parser.save(); - if (!inputs.tools.is_null()) { - std::string python_code_argument_name; - auto has_raw_python = false; + if (include_grammar) { + data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO; - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool.at("function"); - const auto & parameters = function.at("parameters"); - std::string name = function.at("name"); - if (name == "python" || name == "ipython") { - if (!parameters.contains("type")) { - throw std::runtime_error("Missing type in python tool"); - } - has_raw_python = true; - const auto & type = parameters.at("type"); - if (type == "object") { - auto properties = parameters.at("properties"); - for (auto it = properties.begin(); it != properties.end(); ++it) { - if (it.value().at("type") == "string") { - if (!python_code_argument_name.empty()) { - throw std::runtime_error("Multiple string arguments found in python tool"); - } - python_code_argument_name = it.key(); - } - } - if (python_code_argument_name.empty()) { - throw std::runtime_error("No string argument found in python tool"); - } - } else if (type != "string") { - throw std::runtime_error("Invalid type in python tool: " + type.dump()); - } - } - tool_rules.push_back(builder.add_rule(name + "-call", "\"\" " + builder.add_schema(name + "-args", parameters) + " \"\" space")); + auto schema = function.at("parameters"); + builder.resolve_refs(schema); }); - if (has_raw_python) { - tool_rules.push_back(builder.add_rule("python-call", "\"<|python_tag|>\" .*")); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|python_tag|>"}); - data.preserved_tokens.push_back("<|python_tag|>"); - } - auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")) + " space"; - builder.add_rule("root", inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call); - data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "\n")) { - if (!extra_context["enable_thinking"]) { - data.prompt += ""; - } else { - data.thinking_forced_open = true; - } - } - - if (!inputs.tools.is_null()) { - // (content)?({"name": "foo", "arguments": {"a": 1}})* - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - std::vector tool_call_alts; - std::vector escaped_names; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - tool_rules.push_back(builder.add_schema(name + "-call", { - {"type", "object"}, - {"properties", json { - {"name", json {{"const", name}}}, - {"arguments", parameters}, - }}, - {"required", json::array({"name", "arguments"})}, - })); - tool_call_alts.push_back(builder.add_rule( - name + "-function-tag", - "\"\" space " + - builder.add_schema(name + "-args", parameters) + " " - "\"\" space")); - - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_WORD, - "", - }); - auto escaped_name = regex_escape(name); - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, - " alt_tags { - any_tool_call, - "\"\" space " + any_tool_call + " \"\"", - // The rest is just to accommodate common "good bad" outputs. - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - "\"\" space " + any_tool_call + " \"\"", - }; - auto wrappable_tool_call = builder.add_rule("wrappable_tool_call", "( " + string_join(alt_tags, " | ") + " ) space"); - tool_call_alts.push_back(wrappable_tool_call); - tool_call_alts.push_back( - "( \"```\\n\" | \"```json\\n\" | \"```xml\\n\" ) space " + wrappable_tool_call + " space \"```\" space "); - auto tool_call = builder.add_rule("tool_call", string_join(tool_call_alts, " | ")); - builder.add_rule("root", - std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + - (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call)); - // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives) - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, - // If thinking_forced_open, then we capture the tag in the grammar, - // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) - std::string(data.thinking_forced_open ? "(\\s*)" : "") + ( - "\\s*(" - "(?:" - "||||)?" - "\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\"" - ")" - ")" - ), - }); - data.preserved_tokens = { - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "", - "```", - "```json", - "```xml", - }; - }); + // Grammar trigger for when the model starts outputting a tool call + // (after the initial ">>>" in the generation prompt but recipient other than "all") + data.grammar_triggers = { + { COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, ">>>(?!all)" } + }; } return data; } -static common_chat_params common_chat_params_init_granite(const common_chat_template & tmpl, const struct templates_params & inputs) { +// Kimi K2 Thinking - uses unique tool call ID format: functions.: +// The ID contains both the function name and an incrementing counter +static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, + const autoparser::templates_params & inputs) { common_chat_params data; - // Pass thinking context for Granite template - json additional_context = { - {"thinking", inputs.enable_thinking}, + data.prompt = common_chat_template_direct_apply(tmpl, inputs); + data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; + data.supports_thinking = true; + data.preserved_tokens = { + "<|tool_calls_section_begin|>", + "<|tool_calls_section_end|>", + "<|tool_call_begin|>", + "<|tool_call_argument_begin|>", + "<|tool_call_end|>", + "", + "", }; - data.prompt = apply(tmpl, inputs, /* messages_override= */ std::nullopt, /* tools_override= */ std::nullopt, additional_context); - data.format = COMMON_CHAT_FORMAT_GRANITE; - - if (string_ends_with(data.prompt, "\n") || string_ends_with(data.prompt, "")) { - if (!inputs.enable_thinking) { - data.prompt += ""; - } else { - data.thinking_forced_open = true; - } - } - - if (!inputs.tools.is_null()) { - // Granite uses <|tool_call|> followed by JSON list - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - tool_rules.push_back(builder.add_rule(name + "-call", builder.add_schema(name + -"-args", { - {"type", "object"}, - {"properties", { - {"name", {{"const", name}}}, - {"arguments", parameters}, - }}, - {"required", json::array({"name", "arguments"})}, - }))); - }); - - auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")); - auto tool_list = builder.add_rule("tool_list", "\"[\" space " + tool_call + " (\",\" space " + tool_call + ")* space \"]\""); - - if (data.thinking_forced_open) { - builder.add_rule("root", "\"\" space \"\" space [^<]* \"\" space \"<|tool_call|>\" space " + tool_list); - } else { - builder.add_rule("root", "\"<|tool_call|>\" space " + tool_list); - } - - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_WORD, - "<|tool_call|>" - }); + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); + auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; + auto include_grammar = has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE; + + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + // Kimi K2 Thinking format: + // - Reasoning: {reasoning} + // - Content: text after reasoning + // - Tool calls section: + // <|tool_calls_section_begin|> + // <|tool_call_begin|>functions.:<|tool_call_argument_begin|>{json_args}<|tool_call_end|> + // ... + // <|tool_calls_section_end|> + // The ID format is: functions.: where counter is 0, 1, 2, ... + + // Tool call markers + const std::string SECTION_BEGIN = "<|tool_calls_section_begin|>"; + const std::string SECTION_END = "<|tool_calls_section_end|>"; + const std::string CALL_BEGIN = "<|tool_call_begin|>"; + const std::string ARGS_BEGIN = "<|tool_call_argument_begin|>"; + const std::string CALL_END = "<|tool_call_end|>"; + + const std::string THINK_START = ""; + const std::string THINK_END = ""; + + auto end = p.end(); + + // Note: this model is CRAZY. It can diverge from its supposed tool calling pattern in so many ways it's not funny. + // For example, it can call tools at the end of reasoning without closing reasoning... + auto reasoning = extract_reasoning ? p.optional(THINK_START + p.reasoning( + p.until_one_of({ THINK_END, "<|tool_calls_section_begin|>", "<|tool_call_begin|>" })) + + p.optional(p.literal(THINK_END))) : p.eps(); + + + // Content only parser (no tools) + if (!has_tools || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { + return reasoning + p.content(p.rest()) + end; + } + + // Build tool call parsers for each available function + // The ID format is: functions.: + // We need to match: functions.: + auto tool_choice = p.choice(); + foreach_function(inputs.tools, [&](const json & tool) { + const auto & function = tool.at("function"); + std::string name = function.at("name"); + const auto & schema = function.at("parameters"); + + // Match: functions.: + // Capture the full call id (functions.:) using tool_id tag + auto tool_id = p.tool_id(p.literal("functions.") + p.tool_name(p.literal(name)) + p.literal(":") + p.chars("[0-9]", 1, -1)); + auto tool_parser = p.tool( + p.tool_open(tool_id + p.literal(ARGS_BEGIN)) + + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)) + + p.tool_close(p.optional((p.literal(CALL_END)))) + ); - data.preserved_tokens = { - "", - "", - "", - "", - "<|tool_call|>", - }; + tool_choice |= p.rule("tool-" + name, tool_parser); }); - } else { - // Handle thinking tags for non-tool responses - if (data.thinking_forced_open && inputs.enable_thinking) { - data.grammar_lazy = false; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - builder.add_rule("root", "\"\" space \"\" space .* \"\" space"); - }); - data.preserved_tokens = { - "", - "", - "", - "", - }; - } - } - - return data; -} - -static common_chat_params common_chat_params_init_solar_open(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - - // Copy `reasoning_content` to `reasoning` - auto adjusted_messages = json::array(); - for (const auto & msg : inputs.messages) { - if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) { - auto adjusted_message = msg; - adjusted_message["reasoning"] = msg.at("reasoning_content"); - adjusted_message.erase("reasoning_content"); - adjusted_messages.push_back(adjusted_message); - } else { - adjusted_messages.push_back(msg); - } - } - - auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); - auto include_grammar = true; - - auto prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages); - - // Check if we need to replace the flush token with end token during inference and without generation prompt. - if (inputs.is_inference && !inputs.add_generation_prompt) { - static constexpr std::string_view return_token = "<|flush|>"; - static constexpr std::string_view end_token = "<|end|>"; - if (size_t pos = prompt.rfind(return_token); pos != std::string::npos) { - prompt.replace(pos, return_token.length(), end_token); - } - } - - data.prompt = prompt; - data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; - data.preserved_tokens = { - "<|think|>", - "<|content|>", - "<|begin|>", - "<|end|>", - "<|tool_calls|>", - "<|tool_call:begin|>", - "<|tool_call:end|>", - "<|tool_call:name|>", - "<|tool_call:args|>", - }; - - auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { - auto lit_think = p.atomic(p.literal("<|think|>")); - auto lit_assistant_begin = p.atomic(p.literal("<|begin|>assistant")); - auto lit_content = p.atomic(p.literal("<|content|>")); - auto lit_end = p.atomic(p.literal("<|end|>")); - auto parser_until_end = p.until("<|end|>"); - - // reasoning <- "<|think|>" (!"<|end|>" .)* - auto parser_reasoning = p.rule("reasoning", lit_think + p.reasoning(parser_until_end)); - - // content <- "<|content|>" (!"<|end|>" .)* - auto parser_content = p.rule("content", lit_content + p.content(parser_until_end)); - - // wrap_choice(items) <- item-choice wrapped* - // item-choice <- items[0] / ... / items[n] - // wrapped <- "<|end|><|begin|>assistant" item-choice - auto wrap_choice = [&](const std::vector & items) { - auto choice = p.choice(items); - return choice + p.zero_or_more(lit_end + lit_assistant_begin + choice); - }; - - // wrap_seq(items) <- item[0] "<|end|><|begin|>assistant" item[1] ... - auto wrap_seq = [&](const std::vector & items) { - auto seq = p.sequence(); - for (auto i = 0u; i < items.size(); i++) { - if (i == 0) { - seq += items[i]; - continue; - } - seq += lit_end + lit_assistant_begin + items[i]; - } - return seq; - }; - - // Response format parser - if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) { - auto parser_response_format = lit_content + p.content(p.schema(p.json(), "response-format", inputs.json_schema)); - return p.choice({ - wrap_seq({parser_reasoning, parser_response_format}), - wrap_seq({parser_response_format}) - }); - } - auto lit_tool_call_begin = p.literal("<|tool_call:begin|>"); - auto lit_tool_call_name = p.literal("<|tool_call:name|>"); - auto lit_tool_call_args = p.literal("<|tool_call:args|>"); - auto lit_tool_call_end = p.literal("<|tool_call:end|>"); + // Tool calls section: <|tool_calls_section_begin|> tool_calls <|tool_calls_section_end|> + auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0; + auto max_calls = inputs.parallel_tool_calls ? -1 : 1; + // Use trigger_rule so grammar generator knows where to start generating rules + auto tool_calls = p.rule("tool-calls", + p.optional(p.literal(SECTION_BEGIN)) + + p.trigger_rule("tool-call", p.repeat(CALL_BEGIN + tool_choice, min_calls, max_calls) + + p.optional(p.literal(SECTION_END))) + ); - // Tool call parser - if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { - auto parser_tool_call = p.choice(); - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - const auto & schema = function.at("parameters"); + auto content_before_tools = p.content(p.until_one_of({ SECTION_BEGIN, CALL_BEGIN })); - // tool(name, schema) <- name "<|tool_call:args|>" schema - parser_tool_call |= p.rule("tool-" + name, - p.atomic(p.tool_name(p.literal(name)) + lit_tool_call_args) - + p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))); - }); - - auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0; - auto max_calls = inputs.parallel_tool_calls ? -1 : 1; - - // tool-calls <- "<|tool_calls|>" tool-call+ - // tool-call <- "<|tool_call:begin|> call-id "<|tool_call:name|>" &([^<]+ "<|tool_call:args|>") tool-choice "<|tool_call:end|>" - // call-id <- [a-zA-Z0-9_-]+ - // tool-choice <- tool(t[0].name, t[0].schema) / ... / tool(t[n].name, t[n].schema) - auto parser_tool_calls = p.trigger_rule("tool-calls", - p.atomic(p.literal("<|tool_calls|>")) - + p.repeat( - p.tool_open( - lit_tool_call_begin - + p.tool_id(p.chars("[a-zA-Z0-9_-]", 1, -1)) - + lit_tool_call_name - + p.peek(p.chars("[^<]", 1, -1) + lit_tool_call_args)) - + parser_tool_call - + p.tool_close(lit_tool_call_end), - /* min = */ 1, - /* max = */ max_calls)); - - if (min_calls == 1) { - // If required, then try any combination of the reasoning, content, and tool call - return p.choice({ - wrap_seq({parser_reasoning, parser_content, parser_tool_calls}), - wrap_seq({parser_reasoning, parser_tool_calls}), - wrap_seq({parser_content, parser_tool_calls}), - wrap_seq({parser_tool_calls}) - }); - } - - return wrap_choice({parser_reasoning, parser_content, parser_tool_calls}); - } - - // Content only parser - include_grammar = false; - return wrap_choice({parser_reasoning, parser_content}); + return reasoning + content_before_tools + tool_calls + end; }); data.parser = parser.save(); if (include_grammar) { - data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO; - - data.grammar = build_grammar([&](const common_grammar_builder & builder) { + data.grammar_lazy = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO; + data.grammar = build_grammar([&](const common_grammar_builder & builder) { foreach_function(inputs.tools, [&](const json & tool) { const auto & function = tool.at("function"); - auto schema = function.at("parameters"); + auto schema = function.at("parameters"); builder.resolve_refs(schema); }); parser.build_grammar(builder, data.grammar_lazy); }); data.grammar_triggers = { - {COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls|>"} + { COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_call_begin|>" } }; } return data; } -static common_chat_params common_chat_params_init_exaone_moe(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - - data.prompt = apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_EXAONE_MOE; - if (string_ends_with(data.prompt, "\n")) { - if (!inputs.enable_thinking) { - data.prompt += "\n\n"; - } else { - data.thinking_forced_open = true; - } - } - - if (inputs.tools.is_array() && !inputs.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null(); - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - foreach_function(inputs.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - // Expect: {"name": "", "arguments": {...}} - tool_rules.push_back(builder.add_rule( - name + "-call", - "\"\" space " + - builder.add_schema(name + "-obj", json{ - {"type", "object"}, - {"properties", { - {"name", json{{"const", name}}}, - {"arguments", parameters}, - }}, - {"required", json::array({"name", "arguments"})}, - }) + - " space \"\" space")); - }); - - auto tool_call = builder.add_rule("tool_call", string_join(tool_rules, " | ")); - builder.add_rule("root", - std::string(data.thinking_forced_open ? "( \"\" space )? " : "") + - (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call)); - - data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, - std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)?" : "") + - "()[\\s\\S]*" - }); - data.preserved_tokens = { - "", - "", - "", - "", - }; - }); - } - - return data; -} - -static common_chat_params common_chat_params_init_translate_gemma(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - - // This template does not support tools or reasoning - // we just need to transform the messages into the correct schema - - templates_params inputs_new = inputs; - json & messages = inputs_new.messages; - - // default to chat_template_kwargs, or en-GB if not specified - std::string default_src_lang = inputs.extra_context.value("source_lang_code", "en-GB"); - std::string default_tgt_lang = inputs.extra_context.value("target_lang_code", "en-GB"); - - GGML_ASSERT(messages.is_array()); - for (auto & message : messages) { - if (message.contains("role") && message["role"].get() != "user") { - continue; - } - if (!message.contains("content")) { - message["content"] = json::array(); - } - if (message.contains("content") && !message["content"].is_array()) { - auto content_str = message["content"].get(); - // default to en-GB if not specified (to make common_chat_format_example works) - auto src_lang = message.contains("source_lang_code") - ? message["source_lang_code"].get() : default_src_lang; - auto tgt_lang = message.contains("target_lang_code") - ? message["target_lang_code"].get() : default_tgt_lang; - message["content"] = json::array({ - json{ - {"type", "text"}, - {"text", content_str}, - {"source_lang_code", src_lang}, - {"target_lang_code", tgt_lang}, - } - }); - } - } - - data.prompt = apply(tmpl, inputs_new, std::nullopt, std::nullopt); - data.format = COMMON_CHAT_FORMAT_GENERIC; - - return data; -} - -static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { - common_chat_params data; - data.prompt = apply(tmpl, inputs); - data.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - data.grammar_lazy = false; - if (!inputs.json_schema.is_null()) { - if (!inputs.grammar.empty()) { - throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both"); - } - data.grammar = json_schema_to_grammar(inputs.json_schema); - } else { - data.grammar = inputs.grammar; - } - return data; -} - -static common_chat_params common_chat_params_init_seed_oss( - const common_chat_template & tmpl, - templates_params & params, - const common_chat_templates_inputs & inputs) -{ - common_chat_params data; - data.prompt = apply(tmpl, params); - data.format = COMMON_CHAT_FORMAT_SEED_OSS; - if (string_ends_with(data.prompt, "")) { - if (!inputs.enable_thinking) { - data.prompt += ""; - } else { - data.thinking_forced_open = true; - } - } - - if (params.tools.is_array() && !params.tools.empty()) { - data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - data.grammar = build_grammar([&](const common_grammar_builder & builder) { - std::vector tool_rules; - foreach_function(params.tools, [&](const json & tool) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - builder.resolve_refs(parameters); - - // Create rule for Seed-OSS function call format - std::string param_rules; - if (parameters.contains("properties")) { - for (const auto & [key, value] : parameters.at("properties").items()) { - param_rules += "\"\"" + builder.add_schema(name + "-arg-" + key, value) + - "\"\""; - } - } - - tool_rules.push_back(builder.add_rule(name + "-call", - "\"\" space \"\" space " + - param_rules + - " \"\" space \"\"")); - }); - - data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "" }); - - data.preserved_tokens = { - "", "", "", "", - "", "", - }; - - builder.add_rule("root", string_join(tool_rules, " | ")); - }); - } - return data; -} - -// various workarounds for known issues with certain templates or model behaviors -// TODO @ngxson : improve this (how?) namespace workaround { // if first message is system and template does not support it, merge it with next message @@ -2944,6 +1293,15 @@ static void system_message_not_supported(json & messages) { } } +static void requires_non_null_content(json & messages) { + GGML_ASSERT(messages.is_array()); + for (auto & message : messages) { + if (message.contains("tool_calls") && !message.contains("content")) { + message["content"] = ""; + } + } +} + static void func_args_not_string(json & messages) { GGML_ASSERT(messages.is_array()); for (auto & message : messages) { @@ -2964,71 +1322,21 @@ static void func_args_not_string(json & messages) { } } -static void move_tool_calls_to_content(json & messages, int indent_spaces = 2) { - GGML_ASSERT(messages.is_array()); - for (auto & message : messages) { - if (message.contains("tool_calls")) { - auto tool_calls_new = json{ - {"tool_calls", message.at("tool_calls")} - }; - message.erase("tool_calls"); - auto content = message.at("content"); - std::string content_new = content.is_null() ? "" : content.get(); - message["content"] = content_new + tool_calls_new.dump(indent_spaces, ' ', false, json::error_handler_t::replace); - } - } } -// TODO @ngxson : we may remove support for generic schema in the future -static void use_generic_schema(json & messages) { - GGML_ASSERT(messages.is_array()); - for (auto & message : messages) { - if (message.contains("tool_calls") && message.at("tool_calls").is_array()) { - auto & tool_calls = message.at("tool_calls"); - for (auto & tool_call : tool_calls) { - if (tool_call.contains("type") && tool_call.at("type") == "function" && - tool_call.contains("function") && tool_call.at("function").is_object()) { - // Copy values before erasing to avoid use-after-free - json name_value; - json arguments_value; - json id_value; - const auto & function = tool_call.at("function"); - if (function.contains("name")) { - name_value = function.at("name"); - } - if (function.contains("arguments")) { - arguments_value = function.at("arguments"); - } - if (tool_call.contains("id")) { - id_value = tool_call.at("id"); - } - // Now safely erase and assign in the correct order - tool_call.erase("type"); - tool_call.erase("function"); - tool_call.erase("id"); - // Reassign in desired order: name, arguments, id - if (!name_value.is_null()) { - tool_call["name"] = name_value; - } - if (!arguments_value.is_null()) { - tool_call["arguments"] = arguments_value; - } - if (!id_value.is_null()) { - tool_call["id"] = id_value; - } - } - } - } - } +static json common_chat_extra_context() { + json ctx = json::object(); + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + std::string datetime_str = format_time(now, "%b %d %Y"); + std::string date_str = format_time(now, "%d %b %Y"); + ctx["datetime"] = datetime_str; + ctx["date_string"] = date_str; + return ctx; } -} // namespace workaround - -static common_chat_params common_chat_templates_apply_jinja( - const struct common_chat_templates * tmpls, - const struct common_chat_templates_inputs & inputs) -{ - templates_params params; +static common_chat_params common_chat_templates_apply_jinja(const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) { + autoparser::templates_params params; params.tools = common_chat_tools_to_json_oaicompat(inputs.tools); const auto & tmpl = params.tools.is_array() && tmpls->template_tool_use ? *tmpls->template_tool_use @@ -3049,7 +1357,14 @@ static common_chat_params common_chat_templates_apply_jinja( workaround::system_message_not_supported(params.messages); } - params.extra_context = json::object(); + if (tmpl.original_caps().supports_tool_calls) { + // some templates will require the content field in tool call messages + // to still be non-null, this puts an empty string everywhere where the + // content field is null + workaround::requires_non_null_content(params.messages); + } + + params.extra_context = common_chat_extra_context(); for (auto el : inputs.chat_template_kwargs) { params.extra_context[el.first] = json::parse(el.second); } @@ -3058,229 +1373,71 @@ static common_chat_params common_chat_templates_apply_jinja( params.json_schema = json::parse(inputs.json_schema); } - if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { - LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); - params.parallel_tool_calls = false; - } else { - params.parallel_tool_calls = inputs.parallel_tool_calls; - } + // if (inputs.parallel_tool_calls && !tmpl.original_caps().supports_parallel_tool_calls) { + // LOG_DBG("Disabling parallel_tool_calls because the template does not support it\n"); + // params.parallel_tool_calls = false; + // } else { + params.parallel_tool_calls = inputs.parallel_tool_calls; + //} if (params.tools.is_array()) { if (params.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE && !params.grammar.empty()) { throw std::runtime_error("Cannot specify grammar with tools"); } if (caps.supports_tool_calls && !caps.supports_tools) { - LOG_WRN("Template supports tool calls but does not natively describe tools. The fallback behaviour used may produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n"); - } - } - - // DeepSeek V3.1: detect based on specific patterns in the template - if (src.find("message['prefix'] is defined and message['prefix'] and thinking") != std::string::npos && - params.json_schema.is_null()) { - return common_chat_params_init_deepseek_v3_1(tmpl, params); - } - - // DeepSeek R1: use handler in all cases except json schema (thinking / tools). - if (src.find("<|tool▁calls▁begin|>") != std::string::npos && params.json_schema.is_null()) { - return common_chat_params_init_deepseek_r1(tmpl, params); - } - - // Command R7B: : use handler in all cases except json schema (thinking / tools). - if (src.find("<|END_THINKING|><|START_ACTION|>") != std::string::npos && params.json_schema.is_null()) { - workaround::func_args_not_string(params.messages); - return common_chat_params_init_command_r7b(tmpl, params); - } - - // Granite (IBM) - detects thinking / tools support - if (src.find("elif thinking") != std::string::npos && src.find("<|tool_call|>") != std::string::npos) { - workaround::func_args_not_string(params.messages); - workaround::use_generic_schema(params.messages); - workaround::move_tool_calls_to_content(params.messages); - return common_chat_params_init_granite(tmpl, params); - } - - // GLM 4.5: detect by and tags (check before Hermes since both use ) - if (src.find("[gMASK]") != std::string::npos && - src.find("") != std::string::npos && - src.find("") != std::string::npos && - params.json_schema.is_null()) { - workaround::func_args_not_string(params.messages); - if (!params.extra_context.contains("clear_thinking")) { - // by default, do not clear reasoning_content (added since GLM-4.7) - params.extra_context["clear_thinking"] = false; + LOG_WRN( + "Template supports tool calls but does not natively describe tools. The fallback behaviour used may " + "produce bad results, inspect prompt w/ --verbose & consider overriding the template.\n"); } - return common_chat_params_init_glm_4_5(tmpl, params); - } - - // Qwen3-Coder XML format detection (must come before Hermes 2 Pro) - // Detect via XML markers: , , and blocks. - // Also matches Step-3.5-Flash and Nemotron 3 Nano which use the same output format. - if (src.find("") != std::string::npos && - src.find("") != std::string::npos && - src.find("# Tools") != std::string::npos && - src.find("") != std::string::npos && - src.find("") != std::string::npos && - src.find("") != std::string::npos && - src.find("") != std::string::npos) { - return common_chat_params_init_xiaomi_mimo(tmpl, params); } - // EXAONE MoE format detection - if (src.find("") != std::string::npos && - src.find("") != std::string::npos && - src.find("<|tool_declare|>") != std::string::npos) { - return common_chat_params_init_exaone_moe(tmpl, params); - } - - // Hermes 2/3 Pro, Qwen 2.5 Instruct (w/ tools) - if (src.find("") != std::string::npos && params.json_schema.is_null()) { - return common_chat_params_init_hermes_2_pro(tmpl, params); + // Ministral/Mistral Large 3 - uses special reasoning structure fixes, can't use autoparser + // Note: Mistral Small 3.2 uses [CALL_ID] which Ministral doesn't have, so we can distinguish them + if (src.find("[SYSTEM_PROMPT]") != std::string::npos && src.find("[TOOL_CALLS]") != std::string::npos && + src.find("[ARGS]") != std::string::npos && src.find("[CALL_ID]") == std::string::npos) { + LOG_DBG("Using specialized template: Ministral/Magistral Large 3\n"); + return common_chat_params_init_ministral_3(tmpl, params); } - // GPT-OSS + // GPT-OSS - has unique channel-based structure that needs dedicated handler if (src.find("<|channel|>") != std::string::npos) { + LOG_DBG("Using specialized template: GPT-OSS\n"); return common_chat_params_init_gpt_oss(tmpl, params); } - // Seed-OSS - if (src.find("") != std::string::npos) { - workaround::func_args_not_string(params.messages); - return common_chat_params_init_seed_oss(tmpl, params, inputs); - } - - // Nemotron v2 - if (src.find("") != std::string::npos) { - return common_chat_params_init_nemotron_v2(tmpl, params); - } - - // Apertus format detection - if (src.find("<|system_start|>") != std::string::npos && src.find("<|tools_prefix|>") != std::string::npos) { - return common_chat_params_init_apertus(tmpl, params); - } - - // LFM2 (w/ tools) - if (src.find("List of tools: <|tool_list_start|>[") != std::string::npos && - src.find("]<|tool_list_end|>") != std::string::npos) { - return common_chat_params_init_lfm2(tmpl, params); - } - - // MiniMax-M2 format detection - if (src.find("]~!b[") != std::string::npos && src.find("]~b]") != std::string::npos) { - workaround::func_args_not_string(params.messages); - return common_chat_params_init_minimax_m2(tmpl, params); - } - - // Kimi K2 format detection - if (src.find("<|im_system|>tool_declare<|im_middle|>") != std::string::npos && - src.find("<|tool_calls_section_begin|>") != std::string::npos && - src.find("## Return of") != std::string::npos) { - return common_chat_params_init_kimi_k2(tmpl, params); - } - - // Apriel 1.5 format detection - if (src.find("") != std::string::npos && - src.find("") != std::string::npos && - src.find("") != std::string::npos && - src.find("<|assistant|>") != std::string::npos && - src.find("<|tool_result|>") != std::string::npos && - src.find("[") != std::string::npos && - src.find("]") != std::string::npos) { - return common_chat_params_init_apriel_1_5(tmpl, params); - } - - // Solar Open - if (src.find("<|tool_response:begin|>") != std::string::npos && - src.find("<|tool_response:name|>") != std::string::npos && - src.find("<|tool_response:result|>") != std::string::npos) { - return common_chat_params_init_solar_open(tmpl, params); - } - - // Use generic handler when mixing tools + JSON schema. - // TODO: support that mix in handlers below. - if ((params.tools.is_array() && params.json_schema.is_object())) { - return common_chat_params_init_generic(tmpl, params); - } - - // Functionary prepends "all\n" to plain content outputs, so we use its handler in all cases. - if (src.find(">>>all") != std::string::npos) { + // Functionary v3.2 - uses recipient-based format with >>>recipient\n{content} + // Detection: template has ">>>all" for content and ">>>" prefix for tool calls + if (src.find(">>>all") != std::string::npos && src.find(">>>${recipient}") != std::string::npos) { + LOG_DBG("Using specialized template: Functionary v3.2\n"); return common_chat_params_init_functionary_v3_2(tmpl, params); } - // Firefunction v2 requires datetime and functions in the context even w/o tools, so we also use its handler in all cases. - if (src.find(" functools[") != std::string::npos) { - return common_chat_params_init_firefunction_v2(tmpl, params); - } - - // Functionary v3.1 (w/ tools) - if (src.find("<|start_header_id|>") != std::string::npos - && src.find("ipython<|end_header_id|>") != std::string::npos) { - auto allow_python_tag_builtin_tools = src.find("<|python_tag|>") != std::string::npos; - workaround::func_args_not_string(params.messages); - return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools); - } - - // Ministral/Mistral Large 3 - if (src.find("[SYSTEM_PROMPT]") != std::string::npos && - src.find("[TOOL_CALLS]") != std::string::npos && - src.find("[ARGS]") != std::string::npos) { - return common_chat_params_init_ministral_3(tmpl, params); - } - - if (src.find("[THINK]") != std::string::npos && src.find("[/THINK]") != std::string::npos) { - return common_chat_params_init_magistral(tmpl, params); - } - - // Solar Open - if (src.find("<|tool_response:begin|>") != std::string::npos && - src.find("<|tool_response:name|>") != std::string::npos && - src.find("<|tool_response:result|>") != std::string::npos) { - return common_chat_params_init_solar_open(tmpl, params); - } - - // TranslateGemma - if (src.find("[source_lang_code]") != std::string::npos && - src.find("[target_lang_code]") != std::string::npos) { - return common_chat_params_init_translate_gemma(tmpl, params); - } - - // Plain handler (no tools) - if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { - return common_chat_params_init_without_tools(tmpl, params); + // Kimi K2 Thinking - uses unique tool call ID format: functions.: + // Detection: template has "<|tool_calls_section_begin|>" and "functions." prefix in tool call IDs + if (src.find("<|tool_calls_section_begin|>") != std::string::npos && + src.find("<|tool_call_begin|>") != std::string::npos) { + LOG_DBG("Using specialized template: Kimi K2 Thinking\n"); + return common_chat_params_init_kimi_k2(tmpl, params); } - // Mistral Nemo (w/ tools) - if (src.find("[TOOL_CALLS]") != std::string::npos) { - workaround::func_args_not_string(params.messages); - return common_chat_params_init_mistral_nemo(tmpl, params); + try { + LOG_DBG("Using differential autoparser\n"); + struct autoparser::autoparser autoparser; + autoparser.analyze_template(tmpl); + auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser); + auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE; + return auto_params; + } catch (const std::exception & e) { + throw std::invalid_argument(std::string("Unable to generate parser for this template. Automatic parser generation failed: ") + e.what()); } - - // Generic fallback - workaround::func_args_not_string(params.messages); - workaround::use_generic_schema(params.messages); - workaround::move_tool_calls_to_content(params.messages); - return common_chat_params_init_generic(tmpl, params); } // Legacy template route (adhoc C++ implementation of known templates), forward to llama_chat_apply_template. -static common_chat_params common_chat_templates_apply_legacy( - const struct common_chat_templates * tmpls, - const struct common_chat_templates_inputs & inputs) -{ - size_t alloc_size = 0; +static common_chat_params common_chat_templates_apply_legacy(const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) { + size_t alloc_size = 0; std::vector chat; - std::vector contents; + std::vector contents; for (const auto & msg : inputs.messages) { auto content = msg.content; @@ -3290,25 +1447,27 @@ static common_chat_params common_chat_templates_apply_legacy( continue; } if (!content.empty()) { - content += "\n";; + content += "\n"; + ; } content += part.text; } contents.emplace_back(std::move(content)); } for (size_t i = 0; i < contents.size(); ++i) { - const auto & msg = inputs.messages[i]; + const auto & msg = inputs.messages[i]; const auto & content = contents[i]; - chat.push_back({msg.role.c_str(), content.c_str()}); + chat.push_back({ msg.role.c_str(), content.c_str() }); size_t msg_size = msg.role.size() + content.size(); - alloc_size += msg_size + (msg_size / 4); // == msg_size * 1.25 but avoiding float ops + alloc_size += msg_size + (msg_size / 4); // == msg_size * 1.25 but avoiding float ops } std::vector buf(alloc_size); // run the first time to get the total output length const auto & src = tmpls->template_default->source(); - int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); + int32_t res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, + buf.data(), buf.size()); // error: chat template is not supported if (res < 0) { @@ -3320,7 +1479,8 @@ static common_chat_params common_chat_templates_apply_legacy( // if it turns out that our buffer is too small, we resize it if ((size_t) res > buf.size()) { buf.resize(res); - res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), buf.size()); + res = llama_chat_apply_template(src.c_str(), chat.data(), chat.size(), inputs.add_generation_prompt, buf.data(), + buf.size()); } // for safety, we check the result again @@ -3338,14 +1498,72 @@ static common_chat_params common_chat_templates_apply_legacy( return params; } -common_chat_params common_chat_templates_apply( - const struct common_chat_templates * tmpls, - const struct common_chat_templates_inputs & inputs) -{ +common_chat_params common_chat_templates_apply(const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs) { GGML_ASSERT(tmpls != nullptr); - return inputs.use_jinja - ? common_chat_templates_apply_jinja(tmpls, inputs) - : common_chat_templates_apply_legacy(tmpls, inputs); + return inputs.use_jinja ? common_chat_templates_apply_jinja(tmpls, inputs) : + common_chat_templates_apply_legacy(tmpls, inputs); +} + +common_chat_msg common_chat_parse(const std::string & input, + bool is_partial, + const common_chat_parser_params & params) { + return common_chat_peg_parse(params.parser, input, is_partial, params); +} + +common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, + const std::string & input, + bool is_partial, + const common_chat_parser_params & params) { + const common_peg_arena & parser = src_parser.empty() ? + build_chat_peg_parser([](common_chat_peg_builder & p) { return p.content(p.rest()) + p.end(); }) : + src_parser; + + if (src_parser.empty()) { + LOG_WRN("No parser definition detected, assuming pure content parser."); + } + + LOG_DBG("Parsing PEG input with format %s: %s\n", common_chat_format_name(params.format), input.c_str()); + + common_peg_parse_context ctx(input, is_partial); + ctx.debug = params.debug; + auto result = parser.parse(ctx); + + if (result.fail()) { + // During partial parsing, return partial results if any AST nodes were captured + // This allows streaming to work correctly for formats like FUNC_MARKDOWN_CODE_BLOCK + if (is_partial && result.end > 0) { + // Try to extract any partial results from what was successfully parsed + common_chat_msg msg; + msg.role = "assistant"; + auto mapper = common_chat_peg_mapper(msg); + mapper.from_ast(ctx.ast, result); + + if (ctx.debug) { + fprintf(stderr, "\nAST for partial parse (fail):\n%s\n", ctx.ast.dump().c_str()); + fflush(stderr); + } + return msg; + } + throw std::runtime_error(std::string("Failed to parse input at pos ") + std::to_string(result.end) + ": " + + input.substr(result.end)); + } + + common_chat_msg msg; + msg.role = "assistant"; + + auto mapper = common_chat_peg_mapper(msg); + mapper.from_ast(ctx.ast, result); + + if (ctx.debug) { + fprintf(stderr, "\nAST for %s parse:\n%s\n", is_partial ? "partial" : "full", ctx.ast.dump().c_str()); + fflush(stderr); + } + + if (!is_partial) { + LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat({ msg }).at(0).dump().c_str()); + } + return msg; } std::map common_chat_templates_get_caps(const common_chat_templates * chat_templates) { @@ -3353,3 +1571,4 @@ std::map common_chat_templates_get_caps(const common_chat_tem GGML_ASSERT(chat_templates->template_default != nullptr); return chat_templates->template_default->caps.to_map(); } + diff --git a/common/chat.h b/common/chat.h index 6f0b9409ec9..005cc5c8b3f 100644 --- a/common/chat.h +++ b/common/chat.h @@ -3,17 +3,30 @@ #pragma once #include "common.h" +#include "jinja/parser.h" +#include "nlohmann/json_fwd.hpp" #include "peg-parser.h" -#include +#include "jinja/runtime.h" +#include "jinja/caps.h" +#include "nlohmann/json.hpp" + #include +#include +#include #include #include -#include + +using chat_template_caps = jinja::caps; +using json = nlohmann::ordered_json; #include struct common_chat_templates; +namespace autoparser { +struct templates_params; +} // namespace autoparser + struct common_chat_tool_call { std::string name; std::string arguments; @@ -38,21 +51,85 @@ struct common_chat_msg_content_part { } }; +struct common_chat_template { + jinja::program prog; + std::string bos_tok; + std::string eos_tok; + std::string src; + chat_template_caps caps; + + common_chat_template(const std::string & src, const std::string & bos_token, const std::string & eos_token) { + jinja::lexer lexer; + auto lexer_res = lexer.tokenize(src); + this->prog = jinja::parse_from_tokens(lexer_res); + + this->src = lexer_res.source; + this->bos_tok = bos_token; + this->eos_tok = eos_token; + + this->caps = jinja::caps_get(prog); + // LOG_INF("%s: caps:\n%s\n", __func__, this->caps.to_string().c_str()); + } + + const std::string & source() const { return src; } + const std::string & bos_token() const { return bos_tok; } + const std::string & eos_token() const { return eos_tok; } + + // TODO: this is ugly, refactor it somehow + json add_system(const json & messages, const std::string & system_prompt) const { + GGML_ASSERT(messages.is_array()); + auto msgs_copy = messages; + if (!caps.supports_system_role) { + if (msgs_copy.empty()) { + msgs_copy.insert(msgs_copy.begin(), json{ + {"role", "user"}, + {"content", system_prompt} + }); + } else { + auto & first_msg = msgs_copy[0]; + if (!first_msg.contains("content")) { + first_msg["content"] = ""; + } + first_msg["content"] = system_prompt + "\n\n" + + first_msg["content"].get(); + } + } else { + if (msgs_copy.empty() || msgs_copy[0].at("role") != "system") { + msgs_copy.insert(msgs_copy.begin(), json{ + {"role", "system"}, + {"content", system_prompt} + }); + } else if (msgs_copy[0].at("role") == "system") { + msgs_copy[0]["content"] = system_prompt; + } + } + return msgs_copy; + } + + chat_template_caps original_caps() const { + return caps; + } + +}; + struct common_chat_msg { - std::string role; - std::string content; + std::string role; + std::string content; std::vector content_parts; - std::vector tool_calls; - std::string reasoning_content; - std::string tool_name; - std::string tool_call_id; + std::vector tool_calls; + std::string reasoning_content; + std::string tool_name; + std::string tool_call_id; nlohmann::ordered_json to_json_oaicompat(bool concat_typed_text = false) const; bool empty() const { - return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty(); + return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && + tool_name.empty() && tool_call_id.empty(); } - void set_tool_call_ids(std::vector & ids_cache, const std::function & gen_tool_call_id) { + + void set_tool_call_ids(std::vector & ids_cache, + const std::function & gen_tool_call_id) { for (auto i = 0u; i < tool_calls.size(); i++) { if (ids_cache.size() <= i) { auto id = tool_calls[i].id; @@ -64,32 +141,28 @@ struct common_chat_msg { tool_calls[i].id = ids_cache[i]; } } + bool operator==(const common_chat_msg & other) const { - return role == other.role - && content == other.content - && content_parts == other.content_parts - && tool_calls == other.tool_calls - && reasoning_content == other.reasoning_content - && tool_name == other.tool_name - && tool_call_id == other.tool_call_id; - } - bool operator!=(const common_chat_msg & other) const { - return !(*this == other); + return role == other.role && content == other.content && content_parts == other.content_parts && + tool_calls == other.tool_calls && reasoning_content == other.reasoning_content && + tool_name == other.tool_name && tool_call_id == other.tool_call_id; } + + bool operator!=(const common_chat_msg & other) const { return !(*this == other); } }; struct common_chat_msg_diff { - std::string reasoning_content_delta; - std::string content_delta; - size_t tool_call_index = std::string::npos; + std::string reasoning_content_delta; + std::string content_delta; + size_t tool_call_index = std::string::npos; common_chat_tool_call tool_call_delta; - static std::vector compute_diffs(const common_chat_msg & msg_prv, const common_chat_msg & msg_new); + static std::vector compute_diffs(const common_chat_msg & msg_prv, + const common_chat_msg & msg_new); bool operator==(const common_chat_msg_diff & other) const { - return content_delta == other.content_delta - && tool_call_index == other.tool_call_index - && tool_call_delta == other.tool_call_delta; + return content_delta == other.content_delta && tool_call_index == other.tool_call_index && + tool_call_delta == other.tool_call_delta; } }; @@ -107,64 +180,39 @@ enum common_chat_tool_choice { enum common_chat_format { COMMON_CHAT_FORMAT_CONTENT_ONLY, - COMMON_CHAT_FORMAT_GENERIC, - COMMON_CHAT_FORMAT_MISTRAL_NEMO, - COMMON_CHAT_FORMAT_MAGISTRAL, - COMMON_CHAT_FORMAT_LLAMA_3_X, - COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, - COMMON_CHAT_FORMAT_DEEPSEEK_R1, - COMMON_CHAT_FORMAT_FIREFUNCTION_V2, - COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, - COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, - COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - COMMON_CHAT_FORMAT_HERMES_2_PRO, - COMMON_CHAT_FORMAT_COMMAND_R7B, - COMMON_CHAT_FORMAT_GRANITE, - COMMON_CHAT_FORMAT_GPT_OSS, - COMMON_CHAT_FORMAT_SEED_OSS, - COMMON_CHAT_FORMAT_NEMOTRON_V2, - COMMON_CHAT_FORMAT_APERTUS, - COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS, - COMMON_CHAT_FORMAT_GLM_4_5, - COMMON_CHAT_FORMAT_MINIMAX_M2, - COMMON_CHAT_FORMAT_KIMI_K2, - COMMON_CHAT_FORMAT_APRIEL_1_5, - COMMON_CHAT_FORMAT_XIAOMI_MIMO, - COMMON_CHAT_FORMAT_SOLAR_OPEN, - COMMON_CHAT_FORMAT_EXAONE_MOE, // These are intended to be parsed by the PEG parser COMMON_CHAT_FORMAT_PEG_SIMPLE, COMMON_CHAT_FORMAT_PEG_NATIVE, - COMMON_CHAT_FORMAT_PEG_CONSTRUCTED, - COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats + COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats }; struct common_chat_templates_inputs { - std::vector messages; - std::string grammar; - std::string json_schema; - bool add_generation_prompt = true; - bool use_jinja = true; + std::vector messages; + std::string grammar; + std::string json_schema; + bool add_generation_prompt = true; + bool use_jinja = true; // Parameters below only supported when use_jinja is true - std::vector tools; - common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; - bool parallel_tool_calls = false; - common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool enable_thinking" - bool enable_thinking = true; - std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); - std::map chat_template_kwargs; - bool add_bos = false; - bool add_eos = false; + std::vector tools; + common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; + bool parallel_tool_calls = false; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool enable_thinking" + bool enable_thinking = true; + std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + std::map chat_template_kwargs; + bool add_bos = false; + bool add_eos = false; }; struct common_chat_params { common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; std::string prompt; std::string grammar; - bool grammar_lazy = false; + bool grammar_lazy = false; bool thinking_forced_open = false; + bool supports_thinking = false; std::vector grammar_triggers; std::vector preserved_tokens; std::vector additional_stops; @@ -174,13 +222,14 @@ struct common_chat_params { // per-message parsing syntax // should be derived from common_chat_params struct common_chat_parser_params { - common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning" + common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; // TODO: refactor this to "bool parse_reasoning" // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode) - bool reasoning_in_content = false; - bool thinking_forced_open = false; - bool parse_tool_calls = true; - common_peg_arena parser = {}; + bool reasoning_in_content = false; + bool thinking_forced_open = false; + bool parse_tool_calls = true; + bool debug = false; // Enable debug output for PEG parser + common_peg_arena parser = {}; common_chat_parser_params() = default; common_chat_parser_params(const common_chat_params & chat_params) { format = chat_params.format; @@ -193,45 +242,42 @@ bool common_chat_verify_template(const std::string & tmpl, bool use_jinja); void common_chat_templates_free(struct common_chat_templates * tmpls); -struct common_chat_templates_deleter { void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } }; +struct common_chat_templates_deleter { + void operator()(common_chat_templates * tmpls) { common_chat_templates_free(tmpls); } +}; typedef std::unique_ptr common_chat_templates_ptr; -common_chat_templates_ptr common_chat_templates_init( - const struct llama_model * model, - const std::string & chat_template_override, - const std::string & bos_token_override = "", - const std::string & eos_token_override = ""); +common_chat_templates_ptr common_chat_templates_init(const struct llama_model * model, + const std::string & chat_template_override, + const std::string & bos_token_override = "", + const std::string & eos_token_override = ""); bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls); std::string common_chat_templates_source(const struct common_chat_templates * tmpls, const std::string & variant = ""); - -struct common_chat_params common_chat_templates_apply( - const struct common_chat_templates * tmpls, - const struct common_chat_templates_inputs & inputs); +struct common_chat_params common_chat_templates_apply(const struct common_chat_templates * tmpls, + const struct common_chat_templates_inputs & inputs); // Format single message, while taking into account the position of that message in chat history -std::string common_chat_format_single( - const struct common_chat_templates * tmpls, - const std::vector & past_msg, - const common_chat_msg & new_msg, - bool add_ass, - bool use_jinja); +std::string common_chat_format_single(const struct common_chat_templates * tmpls, + const std::vector & past_msg, + const common_chat_msg & new_msg, + bool add_ass, + bool use_jinja); // Returns an example of formatted chat -std::string common_chat_format_example( - const struct common_chat_templates * tmpls, - bool use_jinja, - const std::map & chat_template_kwargs); +std::string common_chat_format_example(const struct common_chat_templates * tmpls, + bool use_jinja, + const std::map & chat_template_kwargs); -const char* common_chat_format_name(common_chat_format format); -common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & syntax); -common_chat_msg common_chat_peg_parse(const common_peg_arena & parser, const std::string & input, bool is_partial, const common_chat_parser_params & syntax); +const char * common_chat_format_name(common_chat_format format); +common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_parser_params & params); +common_chat_msg common_chat_peg_parse(const common_peg_arena & src_parser, const std::string & input, bool is_partial, const common_chat_parser_params & params); // used by arg and server -const char * common_reasoning_format_name(common_reasoning_format format); -common_reasoning_format common_reasoning_format_from_name(const std::string & format); +const char * common_reasoning_format_name(common_reasoning_format format); +common_reasoning_format common_reasoning_format_from_name(const std::string & format); common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice); @@ -250,3 +296,10 @@ nlohmann::ordered_json common_chat_msg_diff_to_json_oaicompat(const common_chat_ // get template caps, useful for reporting to server /props endpoint std::map common_chat_templates_get_caps(const common_chat_templates * chat_templates); + +std::string common_chat_template_direct_apply( + const common_chat_template & tmpl, + const autoparser::templates_params & inputs, + const std::optional & messages_override = std::nullopt, + const std::optional & tools_override = std::nullopt, + const std::optional & additional_context = std::nullopt); diff --git a/common/common.cpp b/common/common.cpp index 53bddc4ef2f..cc423d3439f 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -676,7 +676,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { size_t offset = 0; while (offset < filename.size()) { - utf8_parse_result result = parse_utf8_codepoint(filename, offset); + utf8_parse_result result = common_parse_utf8_codepoint(filename, offset); if (result.status != utf8_parse_result::SUCCESS) { return false; diff --git a/common/common.h b/common/common.h index c5a80375713..3e1b23f5d46 100644 --- a/common/common.h +++ b/common/common.h @@ -516,14 +516,15 @@ struct common_params { std::string cls_sep = "\t"; // separator of classification sequences // server params - int32_t port = 8080; // server listens on this network port - int32_t timeout_read = 600; // http read timeout in seconds - int32_t timeout_write = timeout_read; // http write timeout in seconds - int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) - int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting - bool cache_prompt = true; // whether to enable prompt caching - int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot - int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. + int32_t port = 8080; // server listens on this network port + int32_t timeout_read = 600; // http read timeout in seconds + int32_t timeout_write = timeout_read; // http write timeout in seconds + int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) + int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting + bool cache_prompt = true; // whether to enable prompt caching + int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot + int32_t checkpoint_every_nt = 8192; // make a checkpoint every n tokens during prefill + int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. std::string hostname = "127.0.0.1"; std::string public_path = ""; // NOLINT @@ -545,6 +546,7 @@ struct common_params { // webui configs bool webui = true; + bool webui_mcp_proxy = false; std::string webui_config_json; // "advanced" endpoints are disabled by default for better security @@ -869,7 +871,7 @@ std::string common_detokenize( // Embedding utils // -// TODO: repace embd_norm with an enum +// TODO: replace embd_norm with an enum void common_embd_normalize(const float * inp, float * out, int n, int embd_norm); float common_embd_similarity_cos(const float * embd1, const float * embd2, int n); diff --git a/common/console.cpp b/common/console.cpp index 2ea178f81ed..a770416ab7a 100644 --- a/common/console.cpp +++ b/common/console.cpp @@ -80,6 +80,8 @@ namespace console { static termios initial_state; #endif + static completion_callback completion_cb = nullptr; + // // Init and cleanup // @@ -493,7 +495,7 @@ namespace console { } static void set_line_contents(std::string new_line, std::string & line, std::vector & widths, size_t & char_pos, - size_t & byte_pos) { + size_t & byte_pos, int cursor_byte_pos = -1) { move_to_line_start(char_pos, byte_pos, widths); clear_current_line(widths); @@ -503,6 +505,7 @@ namespace console { char_pos = 0; size_t idx = 0; + int back_width = 0; while (idx < line.size()) { size_t advance = 0; char32_t cp = decode_utf8(line, idx, advance); @@ -511,8 +514,15 @@ namespace console { if (real_width < 0) real_width = 0; widths.push_back(real_width); idx += advance; - ++char_pos; - byte_pos = idx; + if (cursor_byte_pos >= 0 && static_cast(cursor_byte_pos) < idx) { + back_width += real_width; + } else { + ++char_pos; + byte_pos = idx; + } + } + if (cursor_byte_pos >= 0) { + move_cursor(-back_width); } } @@ -784,6 +794,20 @@ namespace console { break; } + if (completion_cb && input_char == '\t') { + auto candidates = completion_cb(line, byte_pos); + + if (!candidates.empty()) { + if (candidates.size() > 1 || candidates[0].first != line) { + // TODO?: Display all candidates + set_line_contents(candidates[0].first, line, widths, char_pos, byte_pos, candidates[0].second); + } else { + // TODO: Move cursor to new byte_pos + } + continue; + } + } + if (input_char == (char32_t) WEOF || input_char == 0x04 /* Ctrl+D */) { end_of_stream = true; break; @@ -1062,6 +1086,10 @@ namespace console { return readline_advanced(line, multiline_input); } + void set_completion_callback(completion_callback cb) { + completion_cb = cb; + } + namespace spinner { static const char LOADING_CHARS[] = {'|', '/', '-', '\\'}; static std::condition_variable cv_stop; diff --git a/common/console.h b/common/console.h index fad6d395316..72781bea6f6 100644 --- a/common/console.h +++ b/common/console.h @@ -4,7 +4,9 @@ #include "common.h" +#include #include +#include enum display_type { DISPLAY_TYPE_RESET = 0, @@ -21,6 +23,9 @@ namespace console { void set_display(display_type display); bool readline(std::string & line, bool multiline_input); + using completion_callback = std::function>(std::string_view, size_t)>; + void set_completion_callback(completion_callback cb); + namespace spinner { void start(); void stop(); diff --git a/common/debug.h b/common/debug.h index 0c559632586..e563b40d68f 100644 --- a/common/debug.h +++ b/common/debug.h @@ -18,7 +18,7 @@ template void common_debug_print_tensor(uint8_t * data, ggml // prints tensors that are processed in the computation graph // by default prints all tensors, but can be configured by creating a `base_callback_data` instance with // non-empty filter_patterns. See examples/debug.ccp for possible usage patterns -// The template parameter determins whether an error should be thrown whenever a NaN is encountered +// The template parameter determines whether an error should be thrown whenever a NaN is encountered // in a tensor (useful for stopping debug sessions on first erroneous tensor) // The callback data will be passed as the third parameter (user_data) template bool common_debug_cb_eval(struct ggml_tensor * t, bool ask, void * user_data); diff --git a/common/jinja/README.md b/common/jinja/README.md index 7059105ee39..8291240767e 100644 --- a/common/jinja/README.md +++ b/common/jinja/README.md @@ -63,7 +63,7 @@ The llama.cpp Jinja engine introduces `jinja::string` (see `jinja/string.h`), wh - **One-to-many** (e.g., split): result is marked `is_input` **only if ALL** input parts are marked `is_input` - **Many-to-one** (e.g., join): same as one-to-many -For string concatenation, string parts will be appended to the new string as-is, while perserving the `is_input` flag. +For string concatenation, string parts will be appended to the new string as-is, while preserving the `is_input` flag. **Enabling Input Marking:** diff --git a/common/jinja/caps.cpp b/common/jinja/caps.cpp index dbaaed500a8..1158d5e5d6d 100644 --- a/common/jinja/caps.cpp +++ b/common/jinja/caps.cpp @@ -1,3 +1,4 @@ +#include "log.h" #include "value.h" #include "runtime.h" #include "caps.h" @@ -36,12 +37,16 @@ static void caps_try_execute(jinja::program & prog, auto tools = ctx.get_val("tools"); bool success = false; + std::string result; try { jinja::runtime runtime(ctx); - runtime.execute(prog); + auto results = runtime.execute(prog); + auto parts = jinja::runtime::gather_string_parts(results); + result = parts->as_string().str(); success = true; } catch (const std::exception & e) { JJ_DEBUG("Exception during execution: %s", e.what()); + result = ""; // ignore exceptions during capability analysis } @@ -90,6 +95,8 @@ caps caps_get(jinja::program & prog) { return v->stats.ops.find(op_name) != v->stats.ops.end(); }; + JJ_DEBUG("%s\n", ">>> Running capability check: typed content"); + // case: typed content support caps_try_execute( prog, @@ -120,6 +127,7 @@ caps caps_get(jinja::program & prog) { } ); + JJ_DEBUG("%s\n", ">>> Running capability check: system prompt"); // case: system prompt support caps_try_execute( @@ -150,7 +158,9 @@ caps caps_get(jinja::program & prog) { } ); - // case: tools support + JJ_DEBUG("%s\n", ">>> Running capability check: single tool support"); + + // case: tools support: single call caps_try_execute( prog, [&]() { @@ -162,10 +172,10 @@ caps caps_get(jinja::program & prog) { }, { {"role", "assistant"}, - {"content", "Assistant message"}, + {"content", ""}, // Some templates expect content to be empty with tool calls {"tool_calls", json::array({ { - {"id", "call1"}, + {"id", "call00001"}, {"type", "function"}, {"function", { {"name", "tool1"}, @@ -173,19 +183,18 @@ caps caps_get(jinja::program & prog) { {"arg", "value"} }} }} - }, - { - {"id", "call2"}, - {"type", "function"}, - {"function", { - {"name", "tool2"}, - {"arguments", { - {"arg", "value"} - }} - }} } })} }, + { + {"role", "tool"}, + {"content", "Tool response"}, + {"tool_call_id", "call00001"} + }, + { + {"role", "assistant"}, + {"content", "The tool response was 'tool response'"} + }, { {"role", "user"}, {"content", "User message"}, @@ -199,7 +208,7 @@ caps caps_get(jinja::program & prog) { {"name", "tool"}, {"type", "function"}, {"function", { - {"name", "tool"}, + {"name", "tool1"}, {"description", "Tool description"}, {"parameters", { {"type", "object"}, @@ -224,6 +233,7 @@ caps caps_get(jinja::program & prog) { auto & tool_name = tools->at(0)->at("function")->at("name"); caps_print_stats(tool_name, "tools[0].function.name"); + caps_print_stats(tools, "tools"); if (!tool_name->stats.used) { result.supports_tools = false; } @@ -233,6 +243,93 @@ caps caps_get(jinja::program & prog) { if (!tool_calls->stats.used) { result.supports_tool_calls = false; } + } + ); + + JJ_DEBUG("%s\n", ">>> Running capability check: parallel tool support"); + + // case: tools support: parallel calls + caps_try_execute( + prog, + [&]() { + // messages + return json::array({ + { + {"role", "user"}, + {"content", "User message"}, + }, + { + {"role", "assistant"}, + {"content", ""}, // Some templates expect content to be empty with tool calls + {"tool_calls", json::array({ + { + {"id", "call00001"}, + {"type", "function"}, + {"function", { + {"name", "tool1"}, + {"arguments", { + {"arg", "value"} + }} + }} + }, + { + {"id", "call00002"}, + {"type", "function"}, + {"function", { + {"name", "tool1"}, + {"arguments", { + {"arg", "value"} + }} + }} + } + })} + }, + { + {"role", "tool"}, + {"content", "Tool response"}, + {"tool_call_id", "call00001"} + }, + { + {"role", "assistant"}, + {"content", "The tool response was 'tool response'"} + }, + { + {"role", "user"}, + {"content", "User message"}, + }, + }); + }, + [&]() { + // tools + return json::array({ + { + {"name", "tool"}, + {"type", "function"}, + {"function", { + {"name", "tool1"}, + {"description", "Tool description"}, + {"parameters", { + {"type", "object"}, + {"properties", { + {"arg", { + {"type", "string"}, + {"description", "Arg description"}, + }}, + }}, + {"required", json::array({ "arg" })}, + }}, + }}, + }, + }); + }, + [&](bool success, value & messages, value & /*tools*/) { + if (!success) { + result.supports_parallel_tool_calls = false; + return; + } + + auto & tool_calls = messages->at(1)->at("tool_calls");; + caps_print_stats(tool_calls, "messages[1].tool_calls"); // check for second tool call usage auto & tool_call_1 = tool_calls->at(1)->at("function"); @@ -243,6 +340,8 @@ caps caps_get(jinja::program & prog) { } ); + JJ_DEBUG("%s\n", ">>> Running capability check: preserve reasoning"); + // case: preserve reasoning content in chat history caps_try_execute( prog, diff --git a/common/jinja/runtime.cpp b/common/jinja/runtime.cpp index 5757c76b7a1..af2282c5469 100644 --- a/common/jinja/runtime.cpp +++ b/common/jinja/runtime.cpp @@ -114,8 +114,10 @@ value binary_expression::execute_impl(context & ctx) { // Logical operators if (op.value == "and") { + JJ_DEBUG("Executing logical test: %s AND %s", left->type().c_str(), right->type().c_str()); return left_val->as_bool() ? right->execute(ctx) : std::move(left_val); } else if (op.value == "or") { + JJ_DEBUG("Executing logical test: %s OR %s", left->type().c_str(), right->type().c_str()); return left_val->as_bool() ? std::move(left_val) : right->execute(ctx); } @@ -838,7 +840,7 @@ value call_expression::execute_impl(context & ctx) { for (auto & arg_stmt : this->args) { auto arg_val = arg_stmt->execute(ctx); JJ_DEBUG(" Argument type: %s", arg_val->type().c_str()); - args.push_back(std::move(arg_val)); + args.push_back(arg_val); } // execute callee value callee_val = callee->execute(ctx); diff --git a/common/jinja/value.h b/common/jinja/value.h index 07e447ff696..6cbedefd96e 100644 --- a/common/jinja/value.h +++ b/common/jinja/value.h @@ -12,8 +12,8 @@ #include #include #include -#include #include +#include namespace jinja { diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index 2f67c74d796..27f13f034ed 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -27,11 +27,11 @@ static std::string build_repetition(const std::string & item_rule, int min_items if (separator_rule.empty()) { if (min_items == 1 && !has_max) { return item_rule + "+"; - } else if (min_items == 0 && !has_max) { + } + if (min_items == 0 && !has_max) { return item_rule + "*"; - } else { - return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}"; } + return item_rule + "{" + std::to_string(min_items) + "," + (has_max ? std::to_string(max_items) : "") + "}"; } auto result = item_rule + " " + build_repetition("(" + separator_rule + " " + item_rule + ")", min_items == 0 ? 0 : min_items - 1, has_max ? max_items - 1 : max_items); @@ -41,7 +41,7 @@ static std::string build_repetition(const std::string & item_rule, int min_items return result; } -static void _build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) { +static void build_min_max_int(int64_t min_value, int64_t max_value, std::stringstream & out, int decimals_left = 16, bool top_level = true) { auto has_min = min_value != std::numeric_limits::min(); auto has_max = max_value != std::numeric_limits::max(); @@ -128,14 +128,14 @@ static void _build_min_max_int(int64_t min_value, int64_t max_value, std::string if (has_min && has_max) { if (min_value < 0 && max_value < 0) { out << "\"-\" ("; - _build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true); + build_min_max_int(-max_value, -min_value, out, decimals_left, /* top_level= */ true); out << ")"; return; } if (min_value < 0) { out << "\"-\" ("; - _build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true); + build_min_max_int(0, -min_value, out, decimals_left, /* top_level= */ true); out << ") | "; min_value = 0; } @@ -159,7 +159,7 @@ static void _build_min_max_int(int64_t min_value, int64_t max_value, std::string if (has_min) { if (min_value < 0) { out << "\"-\" ("; - _build_min_max_int(std::numeric_limits::min(), -min_value, out, decimals_left, /* top_level= */ false); + build_min_max_int(std::numeric_limits::min(), -min_value, out, decimals_left, /* top_level= */ false); out << ") | [0] | [1-9] "; more_digits(0, decimals_left - 1); } else if (min_value == 0) { @@ -194,7 +194,7 @@ static void _build_min_max_int(int64_t min_value, int64_t max_value, std::string } digit_range(c, c); out << " ("; - _build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits::max(), out, less_decimals, /* top_level= */ false); + build_min_max_int(std::stoll(min_s.substr(1)), std::numeric_limits::max(), out, less_decimals, /* top_level= */ false); out << ")"; if (c < '9') { out << " | "; @@ -213,10 +213,10 @@ static void _build_min_max_int(int64_t min_value, int64_t max_value, std::string more_digits(0, less_decimals); out << " | "; } - _build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true); + build_min_max_int(0, max_value, out, decimals_left, /* top_level= */ true); } else { out << "\"-\" ("; - _build_min_max_int(-max_value, std::numeric_limits::max(), out, decimals_left, /* top_level= */ false); + build_min_max_int(-max_value, std::numeric_limits::max(), out, decimals_left, /* top_level= */ false); out << ")"; } return; @@ -232,7 +232,7 @@ struct BuiltinRule { std::vector deps; }; -std::unordered_map PRIMITIVE_RULES = { +static std::unordered_map PRIMITIVE_RULES = { {"boolean", {"(\"true\" | \"false\") space", {}}}, {"decimal-part", {"[0-9]{1,16}", {}}}, {"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}}, @@ -247,7 +247,7 @@ std::unordered_map PRIMITIVE_RULES = { {"null", {"\"null\" space", {}}}, }; -std::unordered_map STRING_FORMAT_RULES = { +static std::unordered_map STRING_FORMAT_RULES = { {"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}}, {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}}, {"date-time", {"date \"T\" time", {"date", "time"}}}, @@ -260,22 +260,26 @@ static bool is_reserved_name(const std::string & name) { static const std::unordered_set RESERVED_NAMES = [] { std::unordered_set s; s.insert("root"); - for (const auto & p : PRIMITIVE_RULES) s.insert(p.first); - for (const auto & p : STRING_FORMAT_RULES) s.insert(p.first); + for (const auto & p : PRIMITIVE_RULES) { + s.insert(p.first); + } + for (const auto & p : STRING_FORMAT_RULES) { + s.insert(p.first); + } return s; }(); return RESERVED_NAMES.find(name) != RESERVED_NAMES.end(); } -std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+"); -std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]"); -std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]"); -std::unordered_map GRAMMAR_LITERAL_ESCAPES = { +static std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+"); +static std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]"); +static std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]"); +static std::unordered_map GRAMMAR_LITERAL_ESCAPES = { {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"} }; -std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; -std::unordered_set ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; +static std::unordered_set NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; +static std::unordered_set ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = {'^', '$', '.', '[', ']', '(', ')', '|', '{', '}', '*', '+', '?'}; static std::string replacePattern(const std::string & input, const std::regex & regex, const std::function & replacement) { std::smatch match; @@ -322,19 +326,19 @@ class common_schema_converter { if (_rules.find(esc_name) == _rules.end() || _rules[esc_name] == rule) { _rules[esc_name] = rule; return esc_name; - } else { - int i = 0; - while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) { - i++; - } - std::string key = esc_name + std::to_string(i); - _rules[key] = rule; - return key; } + int i = 0; + while (_rules.find(esc_name + std::to_string(i)) != _rules.end() && _rules[esc_name + std::to_string(i)] != rule) { + i++; + } + std::string key = esc_name + std::to_string(i); + _rules[key] = rule; + return key; } std::string _generate_union_rule(const std::string & name, const std::vector & alt_schemas) { std::vector rules; + rules.reserve(alt_schemas.size()); for (size_t i = 0; i < alt_schemas.size(); i++) { rules.push_back(visit(alt_schemas[i], name + (name.empty() ? "alternative-" : "-") + std::to_string(i))); } @@ -398,6 +402,7 @@ class common_schema_converter { flush_literal(); std::vector results; + results.reserve(ret.size()); for (const auto & item : ret) { results.push_back(to_rule(item)); } @@ -551,7 +556,7 @@ class common_schema_converter { TrieNode() : is_end_of_string(false) {} void insert(const std::string & string) { - auto node = this; + auto *node = this; for (char c : string) { node = &node->children[c]; } @@ -676,7 +681,7 @@ class common_schema_converter { if (ks.empty()) { return res; } - std::string k = ks[0]; + const std::string& k = ks[0]; std::string kv_rule_name = prop_kv_rule_names[k]; std::string comma_ref = "( \",\" space " + kv_rule_name + " )"; if (first_is_optional) { @@ -779,7 +784,7 @@ class common_schema_converter { std::string pointer = ref.substr(ref.find('#') + 1); std::vector tokens = string_split(pointer, "/"); for (size_t i = 1; i < tokens.size(); ++i) { - std::string sel = tokens[i]; + const std::string& sel = tokens[i]; if (target.is_object() && target.contains(sel)) { target = target[sel]; } else if (target.is_array()) { @@ -802,7 +807,7 @@ class common_schema_converter { _refs[ref] = target; } } else { - for (auto & kv : n.items()) { + for (const auto & kv : n.items()) { visit_refs(kv.value()); } } @@ -812,7 +817,7 @@ class common_schema_converter { visit_refs(schema); } - std::string _generate_constant_rule(const json & value) { + static std::string _generate_constant_rule(const json & value) { return format_literal(value.dump()); } @@ -823,10 +828,12 @@ class common_schema_converter { if (schema.contains("$ref")) { return _add_rule(rule_name, _resolve_ref(schema["$ref"])); - } else if (schema.contains("oneOf") || schema.contains("anyOf")) { + } + if (schema.contains("oneOf") || schema.contains("anyOf")) { std::vector alt_schemas = schema.contains("oneOf") ? schema["oneOf"].get>() : schema["anyOf"].get>(); return _add_rule(rule_name, _generate_union_rule(name, alt_schemas)); - } else if (schema_type.is_array()) { + } + if (schema_type.is_array()) { std::vector schema_types; for (const auto & t : schema_type) { json schema_copy(schema); @@ -834,15 +841,18 @@ class common_schema_converter { schema_types.push_back(schema_copy); } return _add_rule(rule_name, _generate_union_rule(name, schema_types)); - } else if (schema.contains("const")) { + } + if (schema.contains("const")) { return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space"); - } else if (schema.contains("enum")) { + } + if (schema.contains("enum")) { std::vector enum_values; for (const auto & v : schema["enum"]) { enum_values.push_back(_generate_constant_rule(v)); } return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space"); - } else if ((schema_type.is_null() || schema_type == "object") + } + if ((schema_type.is_null() || schema_type == "object") && (schema.contains("properties") || (schema.contains("additionalProperties") && schema["additionalProperties"] != true))) { std::unordered_set required; @@ -863,11 +873,12 @@ class common_schema_converter { _build_object_rule( properties, required, name, schema.contains("additionalProperties") ? schema["additionalProperties"] : json())); - } else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) { + } + if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) { std::unordered_set required; std::vector> properties; std::map enum_values; - std::string hybrid_name = name; + const std::string& hybrid_name = name; std::function add_component = [&](const json & comp_schema, bool is_required) { if (comp_schema.contains("$ref")) { add_component(_refs[comp_schema["$ref"]], is_required); @@ -890,9 +901,9 @@ class common_schema_converter { // todo warning } }; - for (auto & t : schema["allOf"]) { + for (const auto & t : schema["allOf"]) { if (t.contains("anyOf")) { - for (auto & tt : t["anyOf"]) { + for (const auto & tt : t["anyOf"]) { add_component(tt, false); } } else { @@ -911,7 +922,8 @@ class common_schema_converter { } } return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json())); - } else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) { + } + if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) { json items = schema.contains("items") ? schema["items"] : schema["prefixItems"]; if (items.is_array()) { std::string rule = "\"[\" space "; @@ -923,27 +935,31 @@ class common_schema_converter { } rule += " \"]\" space"; return _add_rule(rule_name, rule); - } else { - std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item"); - int min_items = schema.contains("minItems") ? schema["minItems"].get() : 0; - json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json(); - int max_items = max_items_json.is_number_integer() ? max_items_json.get() : std::numeric_limits::max(); - - return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space"); } - } else if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) { + std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item"); + int min_items = schema.contains("minItems") ? schema["minItems"].get() : 0; + json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json(); + int max_items = max_items_json.is_number_integer() ? max_items_json.get() : std::numeric_limits::max(); + + return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space"); + } + if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) { return _visit_pattern(schema["pattern"], rule_name); - } else if ((schema_type.is_null() || schema_type == "string") && std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) { + } + if ((schema_type.is_null() || schema_type == "string") && std::regex_match(schema_format, std::regex("^uuid[1-5]?$"))) { return _add_primitive(rule_name == "root" ? "root" : schema_format, PRIMITIVE_RULES.at("uuid")); - } else if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) { + } + if ((schema_type.is_null() || schema_type == "string") && STRING_FORMAT_RULES.find(schema_format + "-string") != STRING_FORMAT_RULES.end()) { auto prim_name = schema_format + "-string"; return _add_rule(rule_name, _add_primitive(prim_name, STRING_FORMAT_RULES.at(prim_name))); - } else if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) { + } + if (schema_type == "string" && (schema.contains("minLength") || schema.contains("maxLength"))) { std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char")); int min_len = schema.contains("minLength") ? schema["minLength"].get() : 0; int max_len = schema.contains("maxLength") ? schema["maxLength"].get() : std::numeric_limits::max(); return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space"); - } else if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) { + } + if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) { int64_t min_value = std::numeric_limits::min(); int64_t max_value = std::numeric_limits::max(); if (schema.contains("minimum")) { @@ -958,19 +974,24 @@ class common_schema_converter { } std::stringstream out; out << "("; - _build_min_max_int(min_value, max_value, out); + build_min_max_int(min_value, max_value, out); out << ") space"; return _add_rule(rule_name, out.str()); - } else if (schema.empty() || schema_type == "object") { + } + if (schema.empty() || schema_type == "object") { return _add_rule(rule_name, _add_primitive("object", PRIMITIVE_RULES.at("object"))); - } else { - if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get()) == PRIMITIVE_RULES.end()) { - _errors.push_back("Unrecognized schema: " + schema.dump()); - return ""; - } - // TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero - return _add_primitive(rule_name == "root" ? "root" : schema_type.get(), PRIMITIVE_RULES.at(schema_type.get())); } + if (schema_type.is_null() && schema.is_object()) { + // No type constraint and no recognized structural keywords (e.g. {"description": "..."}). + // Per JSON Schema semantics this is equivalent to {} and accepts any value. + return _add_rule(rule_name, _add_primitive("value", PRIMITIVE_RULES.at("value"))); + } + if (!schema_type.is_string() || PRIMITIVE_RULES.find(schema_type.get()) == PRIMITIVE_RULES.end()) { + _errors.push_back("Unrecognized schema: " + schema.dump()); + return ""; + } + // TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero + return _add_primitive(rule_name == "root" ? "root" : schema_type.get(), PRIMITIVE_RULES.at(schema_type.get())); } void check_errors() { @@ -985,7 +1006,7 @@ class common_schema_converter { std::string format_grammar() { std::stringstream ss; for (const auto & kv : _rules) { - ss << kv.first << " ::= " << kv.second << std::endl; + ss << kv.first << " ::= " << kv.second << '\n'; } return ss.str(); } diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp index f2fc84500f7..48379f1ec89 100644 --- a/common/peg-parser.cpp +++ b/common/peg-parser.cpp @@ -1,14 +1,15 @@ -#include "common.h" #include "peg-parser.h" + +#include "common.h" #include "json-schema-to-grammar.h" +#include "log.h" #include "unicode.h" -#include - #include #include #include #include +#include #include #include #include @@ -34,8 +35,7 @@ static bool is_hex_digit(const char c) { // This is used in common_peg_until_parser and to build a GBNF exclusion grammar struct trie { struct node { - size_t depth = 0; - std::map children; + std::map children; // Use uint32_t to store Unicode codepoints bool is_word; }; @@ -55,15 +55,22 @@ struct trie { size_t current = 0; // Start at root size_t pos = start_pos; + // LOG_DBG("%s: checking at pos %zu, sv='%s'\n", __func__, start_pos, std::string(sv).c_str()); + while (pos < sv.size()) { - auto it = nodes[current].children.find(sv[pos]); + auto result = common_parse_utf8_codepoint(sv, pos); + if (result.status != utf8_parse_result::SUCCESS) { + break; + } + + auto it = nodes[current].children.find(result.codepoint); if (it == nodes[current].children.end()) { // Can't continue matching return match_result{match_result::NO_MATCH}; } current = it->second; - pos++; + pos += result.bytes_consumed; // Check if we've matched a complete word if (nodes[current].is_word) { @@ -82,22 +89,22 @@ struct trie { } struct prefix_and_next { - std::string prefix; - std::string next_chars; + std::vector prefix; + std::vector next_chars; }; std::vector collect_prefix_and_next() { - std::string prefix; + std::vector prefix; std::vector result; collect_prefix_and_next(0, prefix, result); return result; } private: - void collect_prefix_and_next(size_t index, std::string & prefix, std::vector & out) { + void collect_prefix_and_next(size_t index, std::vector & prefix, std::vector & out) { if (!nodes[index].is_word) { if (!nodes[index].children.empty()) { - std::string chars; + std::vector chars; chars.reserve(nodes[index].children.size()); for (const auto & p : nodes[index].children) { chars.push_back(p.first); @@ -107,7 +114,7 @@ struct trie { } for (const auto & p : nodes[index].children) { - unsigned char ch = p.first; + uint32_t ch = p.first; auto child = p.second; prefix.push_back(ch); collect_prefix_and_next(child, prefix, out); @@ -123,11 +130,19 @@ struct trie { void insert(const std::string & word) { size_t current = 0; - for (unsigned char ch : word) { + size_t pos = 0; + while (pos < word.length()) { + auto result = common_parse_utf8_codepoint(word, pos); + if (result.status != utf8_parse_result::SUCCESS) { + break; + } + + uint32_t ch = result.codepoint; + pos += result.bytes_consumed; + auto it = nodes[current].children.find(ch); if (it == nodes[current].children.end()) { size_t child = create_node(); - nodes[child].depth = nodes[current].depth + 1; nodes[current].children[ch] = child; current = child; } else { @@ -286,6 +301,32 @@ struct parser_executor { parser_executor(const common_peg_arena & arena, common_peg_parse_context & ctx, size_t start) : arena(arena), ctx(ctx), start_pos(start) {} + std::string debug_indent() const { return std::string(ctx.parse_depth * 2, ' '); } + + std::string debug_input_snippet(size_t pos, size_t len = 60) const { + if (pos >= ctx.input.size()) { + return ""; + } + auto snippet = ctx.input.substr(pos, len); + // Escape newlines for display + std::string result; + for (char c : snippet) { + if (c == '\n') { + result += "\\n"; + } else if (c == '\r') { + result += "\\r"; + } else if (c == '\t') { + result += "\\t"; + } else { + result += c; + } + } + if (pos + len < ctx.input.size()) { + result += "..."; + } + return result; + } + common_peg_parse_result operator()(const common_peg_epsilon_parser & /* p */) const { return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos); } @@ -323,12 +364,39 @@ struct parser_executor { } common_peg_parse_result operator()(const common_peg_sequence_parser & p) { + if (ctx.debug) { + LOG_DBG("%sSEQ start at %zu '%s' (%zu children)\n", debug_indent().c_str(), start_pos, + debug_input_snippet(start_pos).c_str(), p.children.size()); + } + ctx.parse_depth++; + auto pos = start_pos; std::vector nodes; - for (const auto & child_id : p.children) { + for (size_t i = 0; i < p.children.size(); i++) { + const auto & child_id = p.children[i]; + if (ctx.debug) { + fprintf(stderr, "%sSEQ child %zu: %s\n", debug_indent().c_str(), i, arena.dump(child_id).c_str()); + } auto result = arena.parse(child_id, ctx, pos); + + if (ctx.debug) { + fprintf(stderr, "%sSEQ child %zu: %s at %zu->%zu\n", debug_indent().c_str(), i, + common_peg_parse_result_type_name(result.type), result.start, result.end); + } + if (result.fail()) { + ctx.parse_depth--; + if (ctx.is_partial && result.end >= ctx.input.size()) { + if (ctx.debug) { + fprintf(stderr, "%sSEQ -> NEED_MORE (child failed at end)\n", debug_indent().c_str()); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, + std::move(nodes)); + } + if (ctx.debug) { + fprintf(stderr, "%sSEQ -> FAIL\n", debug_indent().c_str()); + } return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, result.end); } @@ -337,28 +405,65 @@ struct parser_executor { } if (result.need_more_input()) { + ctx.parse_depth--; + if (ctx.debug) { + fprintf(stderr, "%sSEQ -> NEED_MORE\n", debug_indent().c_str()); + } return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, std::move(nodes)); } pos = result.end; } + ctx.parse_depth--; + if (ctx.debug) { + fprintf(stderr, "%sSEQ -> SUCCESS at %zu->%zu\n", debug_indent().c_str(), start_pos, pos); + } return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos, std::move(nodes)); } common_peg_parse_result operator()(const common_peg_choice_parser & p) { + if (ctx.debug) { + fprintf(stderr, "%sCHOICE start at %zu '%s' (%zu options)\n", debug_indent().c_str(), start_pos, + debug_input_snippet(start_pos).c_str(), p.children.size()); + } + ctx.parse_depth++; + auto pos = start_pos; - for (const auto & child_id : p.children) { + for (size_t i = 0; i < p.children.size(); i++) { + const auto & child_id = p.children[i]; + if (ctx.debug) { + fprintf(stderr, "%sCHOICE option %zu: %s\n", debug_indent().c_str(), i, arena.dump(child_id).c_str()); + } auto result = arena.parse(child_id, ctx, pos); + if (ctx.debug) { + fprintf(stderr, "%sCHOICE option %zu: %s\n", debug_indent().c_str(), i, + common_peg_parse_result_type_name(result.type)); + } if (!result.fail()) { + ctx.parse_depth--; + if (ctx.debug) { + fprintf(stderr, "%sCHOICE -> %s (option %zu)\n", debug_indent().c_str(), + common_peg_parse_result_type_name(result.type), i); + } return result; } } + ctx.parse_depth--; + if (ctx.debug) { + fprintf(stderr, "%sCHOICE -> FAIL (no options matched)\n", debug_indent().c_str()); + } return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); } common_peg_parse_result operator()(const common_peg_repetition_parser & p) { + if (ctx.debug) { + fprintf(stderr, "%sREPEAT start at %zu '%s' (min=%d, max=%d)\n", debug_indent().c_str(), start_pos, + debug_input_snippet(start_pos).c_str(), p.min_count, p.max_count); + } + ctx.parse_depth++; + auto pos = start_pos; int match_count = 0; std::vector nodes; @@ -366,14 +471,26 @@ struct parser_executor { // Try to match up to max_count times (or unlimited if max_count is -1) while (p.max_count == -1 || match_count < p.max_count) { if (pos >= ctx.input.size()) { + if (ctx.debug) { + fprintf(stderr, "%sREPEAT: at end of input, count=%d\n", debug_indent().c_str(), match_count); + } break; } auto result = arena.parse(p.child, ctx, pos); + if (ctx.debug) { + fprintf(stderr, "%sREPEAT iter %d: %s at %zu->%zu, nodes=%zu\n", debug_indent().c_str(), match_count, + common_peg_parse_result_type_name(result.type), result.start, result.end, result.nodes.size()); + fprintf(stderr, "%sREPEAT CHILD: %s\n", debug_indent().c_str(), arena.dump(p.child).c_str()); + } + if (result.success()) { // Prevent infinite loop on empty matches if (result.end == pos) { + if (ctx.debug) { + fprintf(stderr, "%s REPEAT: empty match, stopping\n", debug_indent().c_str()); + } break; } @@ -391,21 +508,43 @@ struct parser_executor { nodes.insert(nodes.end(), result.nodes.begin(), result.nodes.end()); } + ctx.parse_depth--; + if (ctx.debug) { + fprintf(stderr, "%sREPEAT -> NEED_MORE (count=%d, nodes=%zu)\n", debug_indent().c_str(), + match_count, nodes.size()); + } return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, result.end, std::move(nodes)); } // Child failed - stop trying + if (ctx.debug) { + fprintf(stderr, "%sREPEAT: child failed, stopping\n", debug_indent().c_str()); + } break; } // Check if we got enough matches if (p.min_count > 0 && match_count < p.min_count) { + ctx.parse_depth--; if (pos >= ctx.input.size() && ctx.is_partial) { + if (ctx.debug) { + fprintf(stderr, "%sREPEAT -> NEED_MORE (not enough matches: %d < %d)\n", debug_indent().c_str(), + match_count, p.min_count); + } return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos, std::move(nodes)); } + if (ctx.debug) { + fprintf(stderr, "%sREPEAT -> FAIL (not enough matches: %d < %d)\n", debug_indent().c_str(), match_count, + p.min_count); + } return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos); } + ctx.parse_depth--; + if (ctx.debug) { + fprintf(stderr, "%sREPEAT -> SUCCESS (count=%d, nodes=%zu)\n", debug_indent().c_str(), match_count, + nodes.size()); + } return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos, std::move(nodes)); } @@ -434,7 +573,7 @@ struct parser_executor { common_peg_parse_result operator()(const common_peg_any_parser & /* p */) const { // Parse a single UTF-8 codepoint (not just a single byte) - auto result = parse_utf8_codepoint(ctx.input, start_pos); + auto result = common_parse_utf8_codepoint(ctx.input, start_pos); if (result.status == utf8_parse_result::INCOMPLETE) { if (!ctx.is_partial) { @@ -468,7 +607,7 @@ struct parser_executor { // Try to match up to max_count times (or unlimited if max_count is -1) while (p.max_count == -1 || match_count < p.max_count) { - auto result = parse_utf8_codepoint(ctx.input, pos); + auto result = common_parse_utf8_codepoint(ctx.input, pos); if (result.status == utf8_parse_result::INCOMPLETE) { if (match_count >= p.min_count) { @@ -537,6 +676,7 @@ struct parser_executor { switch (ctx.input[pos]) { case '"': + case '\'': case '\\': case '/': case 'b': @@ -589,7 +729,49 @@ struct parser_executor { return result; } } else { - auto utf8_result = parse_utf8_codepoint(ctx.input, pos); + auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos); + + if (utf8_result.status == utf8_parse_result::INCOMPLETE) { + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + + if (utf8_result.status == utf8_parse_result::INVALID) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos); + } + + pos += utf8_result.bytes_consumed; + } + } + + // Reached end without finding closing quote + if (!ctx.is_partial) { + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_FAIL, start_pos, pos); + } + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_NEED_MORE_INPUT, start_pos, pos); + } + + common_peg_parse_result operator()(const common_peg_python_dict_string_parser & /* p */) { + auto pos = start_pos; + + // Parse string content (without quotes) + while (pos < ctx.input.size()) { + char c = ctx.input[pos]; + + if (c == '\'') { + // Found closing quote - success (don't consume it) + return common_peg_parse_result(COMMON_PEG_PARSE_RESULT_SUCCESS, start_pos, pos); + } + + if (c == '\\') { + auto result = handle_escape_sequence(ctx, start_pos, pos); + if (!result.success()) { + return result; + } + } else { + auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos); if (utf8_result.status == utf8_parse_result::INCOMPLETE) { if (!ctx.is_partial) { @@ -621,7 +803,7 @@ struct parser_executor { size_t last_valid_pos = start_pos; while (pos < ctx.input.size()) { - auto utf8_result = parse_utf8_codepoint(ctx.input, pos); + auto utf8_result = common_parse_utf8_codepoint(ctx.input, pos); if (utf8_result.status == utf8_parse_result::INCOMPLETE) { // Incomplete UTF-8 sequence @@ -694,6 +876,9 @@ struct parser_executor { common_peg_parse_result operator()(const common_peg_tag_parser & p) { // Parse the child + if (ctx.debug) { + fprintf(stderr, "%sTAG: %s\n", debug_indent().c_str(), p.tag.c_str()); + } auto result = arena.parse(p.child, ctx, start_pos); if (!result.fail()) { @@ -755,6 +940,31 @@ common_peg_parser_id common_peg_arena::resolve_ref(common_peg_parser_id id) { return id; } +static void bfs_node(common_peg_ast_arena &arena, std::ostringstream & oss, const common_peg_ast_node & node, int indent) { + for (int i = 0; i < indent; i++) { + oss << " "; + } + oss << "NODE " << node.id; + if (!node.rule.empty()) { + oss << " (rule " << node.rule << ")"; + } + if (!node.tag.empty()) { + oss << " (tag " << node.tag << ")"; + } + oss << " ['" << node.text << "']\n"; + for (const auto child : node.children) { + bfs_node(arena, oss, arena.get(child), indent + 1); + } +} + +std::string common_peg_ast_arena::dump() { + std::ostringstream oss; + for (auto & node : nodes_) { + bfs_node(*this, oss, node, 0); + } + return oss.str(); +} + void common_peg_arena::resolve_refs() { // Walk through all parsers and replace refs with their corresponding rule IDs for (auto & parser : parsers_) { @@ -786,6 +996,7 @@ void common_peg_arena::resolve_refs() { std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v || std::is_same_v || std::is_same_v) { @@ -803,9 +1014,21 @@ void common_peg_arena::resolve_refs() { } std::string common_peg_arena::dump(common_peg_parser_id id) const { + std::unordered_set visited; + return dump_impl(id, visited); +} + +std::string common_peg_arena::dump_impl(common_peg_parser_id id, + std::unordered_set & visited) const { + // Check for cycles + if (visited.count(id)) { + return "[cycle]"; + } + visited.insert(id); + const auto & parser = parsers_.at(id); - return std::visit([this](const auto & p) -> std::string { + return std::visit([this, &visited](const auto & p) -> std::string { using T = std::decay_t; if constexpr (std::is_same_v) { @@ -819,24 +1042,27 @@ std::string common_peg_arena::dump(common_peg_parser_id id) const { } else if constexpr (std::is_same_v) { std::vector parts; for (const auto & child : p.children) { - parts.push_back(dump(child)); + parts.push_back(dump_impl(child, visited)); } return "Sequence(" + string_join(parts, ", ") + ")"; } else if constexpr (std::is_same_v) { std::vector parts; for (const auto & child : p.children) { - parts.push_back(dump(child)); + parts.push_back(dump_impl(child, visited)); } return "Choice(" + string_join(parts, ", ") + ")"; } else if constexpr (std::is_same_v) { if (p.max_count == -1) { - return "Repetition(" + dump(p.child) + ", " + std::to_string(p.min_count) + ", unbounded)"; + return "Repetition(" + dump_impl(p.child, visited) + ", " + std::to_string(p.min_count) + + ", unbounded)"; } - return "Repetition(" + dump(p.child) + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")"; + return "Repetition(" + dump_impl(p.child, visited) + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")"; } else if constexpr (std::is_same_v) { - return "And(" + dump(p.child) + ")"; + return "And(" + dump_impl(p.child, visited) + ")"; } else if constexpr (std::is_same_v) { - return "Not(" + dump(p.child) + ")"; + return "Not(" + dump_impl(p.child, visited) + ")"; + } else if constexpr (std::is_same_v) { + return "Atomic(" + dump_impl(p.child, visited) + ")"; } else if constexpr (std::is_same_v) { return "Any"; } else if constexpr (std::is_same_v) { @@ -848,14 +1074,20 @@ std::string common_peg_arena::dump(common_peg_parser_id id) const { return "CharRepeat(" + p.pattern + ", " + std::to_string(p.min_count) + ", " + std::to_string(p.max_count) + ")"; } else if constexpr (std::is_same_v) { return "JsonString()"; + } else if constexpr (std::is_same_v) { + return "PythonDictString()"; } else if constexpr (std::is_same_v) { return "Until(" + string_join(p.delimiters, " | ") + ")"; } else if constexpr (std::is_same_v) { - return "Schema(" + dump(p.child) + ", " + (p.schema ? p.schema->dump() : "null") + ")"; + return "Schema(" + dump_impl(p.child, visited) + ", " + (p.schema ? p.schema->dump() : "null") + ")"; } else if constexpr (std::is_same_v) { - return "Rule(" + p.name + ", " + dump(p.child) + ")"; + return "Rule(" + p.name + ", " + dump_impl(p.child, visited) + ")"; } else if constexpr (std::is_same_v) { return "Ref(" + p.name + ")"; + } else if constexpr (std::is_same_v) { + return "Tag(" + p.tag + ", " + dump(p.child) + ")"; + } else if constexpr (std::is_same_v) { + return "Atomic(" + dump(p.child) + ")"; } else { return "Unknown"; } @@ -1054,7 +1286,54 @@ common_peg_arena common_peg_parser_builder::build() { return std::move(arena_); } +// String primitives + +common_peg_parser common_peg_parser_builder::json_string_content() { + return wrap(arena_.add_parser(common_peg_json_string_parser{})); +} + +common_peg_parser common_peg_parser_builder::single_quoted_string_content() { + return wrap(arena_.add_parser(common_peg_python_dict_string_parser{})); +} + +common_peg_parser common_peg_parser_builder::double_quoted_string() { + return rule("dq-string", + [this]() { return sequence({ literal("\""), json_string_content(), literal("\""), space() }); }); +} + +common_peg_parser common_peg_parser_builder::single_quoted_string() { + return rule("sq-string", + [this]() { return sequence({ literal("'"), single_quoted_string_content(), literal("'"), space() }); }); +} + +common_peg_parser common_peg_parser_builder::flexible_string() { + return rule("flexible-string", [this]() { return choice({ double_quoted_string(), single_quoted_string() }); }); +} + +// Generic helpers for object/array structure + +common_peg_parser common_peg_parser_builder::generic_object(const std::string & name, + const common_peg_parser & string_parser, + const common_peg_parser & value_parser) { + return rule(name, [this, string_parser, value_parser]() { + auto ws = space(); + auto member = sequence({ string_parser, ws, literal(":"), ws, value_parser }); + auto members = sequence({ member, zero_or_more(sequence({ ws, literal(","), ws, member })) }); + return sequence({ literal("{"), ws, choice({ literal("}"), sequence({ members, ws, literal("}") }) }) }); + }); +} + +common_peg_parser common_peg_parser_builder::generic_array(const std::string & name, + const common_peg_parser & value_parser) { + return rule(name, [this, value_parser]() { + auto ws = space(); + auto elements = sequence({ value_parser, zero_or_more(sequence({ literal(","), ws, value_parser })) }); + return sequence({ literal("["), ws, choice({ literal("]"), sequence({ elements, ws, literal("]") }) }) }); + }); +} + // JSON parsers + common_peg_parser common_peg_parser_builder::json_number() { return rule("json-number", [this]() { auto digit1_9 = chars("[1-9]", 1, 1); @@ -1062,7 +1341,11 @@ common_peg_parser common_peg_parser_builder::json_number() { auto int_part = choice({literal("0"), sequence({digit1_9, chars("[0-9]", 0, -1)})}); auto frac = sequence({literal("."), digits}); auto exp = sequence({choice({literal("e"), literal("E")}), optional(chars("[+-]", 1, 1)), digits}); - return sequence({optional(literal("-")), int_part, optional(frac), optional(exp), space()}); + // Negative lookahead: only commit the number when the next character can't extend it. + // At EOF in partial mode, chars returns NEED_MORE → negate propagates NEED_MORE → number not committed. + // This prevents premature commits of partial numbers (e.g. "3" when "3.14" is incoming). + auto not_number_continuation = negate(chars("[0-9.eE+-]", 1, 1)); + return sequence({ optional(literal("-")), int_part, optional(frac), optional(exp), not_number_continuation, space() }); }); } @@ -1085,36 +1368,11 @@ common_peg_parser common_peg_parser_builder::json_null() { } common_peg_parser common_peg_parser_builder::json_object() { - return rule("json-object", [this]() { - auto ws = space(); - auto member = sequence({json_string(), ws, literal(":"), ws, json()}); - auto members = sequence({member, zero_or_more(sequence({ws, literal(","), ws, member}))}); - return sequence({ - literal("{"), - ws, - choice({ - literal("}"), - sequence({members, ws, literal("}")}) - }), - ws - }); - }); + return generic_object("json-object", json_string(), json()); } common_peg_parser common_peg_parser_builder::json_array() { - return rule("json-array", [this]() { - auto ws = space(); - auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))}); - return sequence({ - literal("["), - ws, - choice({ - literal("]"), - sequence({elements, ws, literal("]")}) - }), - ws - }); - }); + return generic_array("json-array", json()); } common_peg_parser common_peg_parser_builder::json() { @@ -1130,8 +1388,40 @@ common_peg_parser common_peg_parser_builder::json() { }); } -common_peg_parser common_peg_parser_builder::json_string_content() { - return wrap(arena_.add_parser(common_peg_json_string_parser{})); +common_peg_parser common_peg_parser_builder::python_string() { + return rule("python-string", [this]() { return choice({ double_quoted_string(), single_quoted_string() }); }); +} + +common_peg_parser common_peg_parser_builder::python_number() { + return json_number(); +} + +common_peg_parser common_peg_parser_builder::python_bool() { + return rule("python-bool", [this]() { return sequence({ choice({ literal("True"), literal("False") }), space() }); }); +} + +common_peg_parser common_peg_parser_builder::python_null() { + return rule("python-none", [this]() { return sequence({ literal("None"), space() }); }); +} + +common_peg_parser common_peg_parser_builder::python_dict() { + return generic_object("python-dict", python_string(), python_value()); +} + +common_peg_parser common_peg_parser_builder::python_array() { + return generic_array("python-array", python_value()); +} + +common_peg_parser common_peg_parser_builder::python_value() { + return rule("python-value", [this]() { + return choice({ python_dict(), python_array(), python_string(), python_number(), python_bool(), python_null() }); + }); +} + +common_peg_parser common_peg_parser_builder::marker() { + auto sharp_bracket_parser = literal("<") + until(">") + literal(">"); + auto square_bracket_parser = literal("[") + until("]") + literal("]"); + return choice({ sharp_bracket_parser, square_bracket_parser }); } common_peg_parser common_peg_parser_builder::json_member(const std::string & key, const common_peg_parser & p) { @@ -1145,17 +1435,54 @@ common_peg_parser common_peg_parser_builder::json_member(const std::string & key }); } +static std::string gbnf_escape_char_class(uint32_t c) { + if (c == '-' || c == ']' || c == '[' || c == '\\') { + return "\\" + std::string(1, (char) c); + } + // Escape whitespace control characters + if (c == '\n') { + return "\\n"; + } + if (c == '\t') { + return "\\t"; + } + if (c == '\r') { + return "\\r"; + } + + // Printable ASCII + if (c >= 0x20 && c <= 0x7E) { + return std::string(1, (char) c); + } + + // Hex escape + char buf[16]; + const char * hex = "0123456789ABCDEF"; -static std::string gbnf_escape_char_class(char c) { - switch (c) { - case '\n': return "\\n"; - case '\t': return "\\t"; - case '\r': return "\\r"; - case '\\': return "\\\\"; - case ']': return "\\]"; - case '[': return "\\["; - default: return std::string(1, c); + if (c <= 0xFF) { + buf[0] = '\\'; + buf[1] = 'x'; + buf[2] = hex[(c >> 4) & 0xF]; + buf[3] = hex[c & 0xF]; + buf[4] = '\0'; + } else if (c <= 0xFFFF) { + buf[0] = '\\'; + buf[1] = 'u'; + buf[2] = hex[(c >> 12) & 0xF]; + buf[3] = hex[(c >> 8) & 0xF]; + buf[4] = hex[(c >> 4) & 0xF]; + buf[5] = hex[c & 0xF]; + buf[6] = '\0'; + } else { + buf[0] = '\\'; + buf[1] = 'U'; + for (int i = 0; i < 8; i++) { + buf[2 + i] = hex[(c >> ((7 - i) * 4)) & 0xF]; + } + buf[10] = '\0'; } + + return std::string(buf); } static std::string gbnf_excluding_pattern(const std::vector & strings) { @@ -1173,12 +1500,12 @@ static std::string gbnf_excluding_pattern(const std::vector & strin std::string cls; cls.reserve(chars.size()); - for (const auto & ch : chars) { + for (uint32_t ch : chars) { cls += gbnf_escape_char_class(ch); } if (!pre.empty()) { - pattern += gbnf_format_literal(pre) + " [^" + cls + "]"; + pattern += gbnf_format_literal(common_unicode_cpts_to_utf8(pre)) + " [^" + cls + "]"; } else { pattern += "[^" + cls + "]"; } @@ -1208,7 +1535,8 @@ static std::unordered_set collect_reachable_rules( std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || + std::is_same_v) { // These parsers do not have any children } else if constexpr (std::is_same_v) { for (auto child : p.children) { @@ -1346,6 +1674,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo return result + "{" + std::to_string(p.min_count) + "," + std::to_string(p.max_count) + "}"; } else if constexpr (std::is_same_v) { return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)"; + } else if constexpr (std::is_same_v) { + return R"(( [^"\\] | "\\" ( ["\\/ bfnrt] | "u" [0-9a-fA-F]{4} ) )*)"; } else if constexpr (std::is_same_v) { if (p.delimiters.empty()) { return ".*"; @@ -1477,6 +1807,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant & }; } else if constexpr (std::is_same_v) { return json{{"type", "json_string"}}; + } else if constexpr (std::is_same_v) { + return json{{ "type", "python_dict_string" }}; } else if constexpr (std::is_same_v) { return json{{"type", "until"}, {"delimiters", p.delimiters}}; } else if constexpr (std::is_same_v) { @@ -1606,6 +1938,9 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json if (type == "json_string") { return common_peg_json_string_parser{}; } + if (type == "python_dict_string") { + return common_peg_python_dict_string_parser{}; + } if (type == "until") { if (!j.contains("delimiters") || !j["delimiters"].is_array()) { throw std::runtime_error("until parser missing or invalid 'delimiters' field"); diff --git a/common/peg-parser.h b/common/peg-parser.h index 1cd640365f2..57d4bcd8eaa 100644 --- a/common/peg-parser.h +++ b/common/peg-parser.h @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -111,6 +112,8 @@ class common_peg_ast_arena { void visit(common_peg_ast_id id, const common_peg_ast_visitor & visitor) const; void visit(const common_peg_parse_result & result, const common_peg_ast_visitor & visitor) const; + + std::string dump(); }; struct common_peg_parse_result { @@ -139,6 +142,7 @@ struct common_peg_parse_result { struct common_peg_parse_context { std::string input; bool is_partial; + bool debug = false; // Enable debug output for parser tracing common_peg_ast_arena ast; int parse_depth; @@ -207,6 +211,7 @@ struct common_peg_chars_parser { }; struct common_peg_json_string_parser {}; +struct common_peg_python_dict_string_parser {}; struct common_peg_until_parser { std::vector delimiters; @@ -255,6 +260,7 @@ using common_peg_parser_variant = std::variant< common_peg_space_parser, common_peg_chars_parser, common_peg_json_string_parser, + common_peg_python_dict_string_parser, common_peg_until_parser, common_peg_schema_parser, common_peg_rule_parser, @@ -299,6 +305,8 @@ class common_peg_arena { friend class common_peg_parser_builder; private: + std::string dump_impl(common_peg_parser_id id, std::unordered_set & visited) const; + common_peg_parser_id add_parser(common_peg_parser_variant parser); void add_rule(const std::string & name, common_peg_parser_id id); @@ -311,6 +319,10 @@ class common_peg_parser_builder { common_peg_parser wrap(common_peg_parser_id id) { return common_peg_parser(id, *this); } common_peg_parser add(const common_peg_parser_variant & p) { return wrap(arena_.add_parser(p)); } + // Generic helpers for building object/array structures with configurable string/value parsers. + common_peg_parser generic_object(const std::string & name, const common_peg_parser & string_parser, const common_peg_parser & value_parser); + common_peg_parser generic_array(const std::string & name, const common_peg_parser & value_parser); + public: common_peg_parser_builder(); @@ -404,6 +416,21 @@ class common_peg_parser_builder { // S -> A{n} common_peg_parser repeat(const common_peg_parser & p, int n) { return repeat(p, n, n); } + // Matches a double-quoted string: '"' content '"' space + common_peg_parser double_quoted_string(); + + // Matches a single-quoted string: "'" content "'" space + common_peg_parser single_quoted_string(); + + // Matches a string that accepts both double-quoted and single-quoted styles. + common_peg_parser flexible_string(); + + // Matches double-quoted string content without the surrounding quotes. + common_peg_parser json_string_content(); + + // Matches single-quoted string content without the surrounding quotes. + common_peg_parser single_quoted_string_content(); + // Creates a complete JSON parser supporting objects, arrays, strings, numbers, booleans, and null. // value -> object | array | string | number | true | false | null common_peg_parser json(); @@ -414,14 +441,24 @@ class common_peg_parser_builder { common_peg_parser json_bool(); common_peg_parser json_null(); - // Matches JSON string content without the surrounding quotes. - // Useful for extracting content within a JSON string. - common_peg_parser json_string_content(); - // Matches a JSON object member with a key and associated parser as the // value. common_peg_parser json_member(const std::string & key, const common_peg_parser & p); + // Creates a complete Python format parser supporting dicts, arrays, strings, numbers, booleans, and None. + // Differs from JSON: uses True/False/None, accepts both single and double-quoted strings. + // value -> dict | array | string | number | True | False | None + common_peg_parser python_value(); + common_peg_parser python_dict(); + common_peg_parser python_string(); + common_peg_parser python_array(); + common_peg_parser python_number(); + common_peg_parser python_bool(); + common_peg_parser python_null(); + + // A marker, i.e. text delimited by a pair of <> or [] + common_peg_parser marker(); + // Wraps a parser with JSON schema metadata for grammar generation. // Used internally to convert JSON schemas to GBNF grammar rules. common_peg_parser schema(const common_peg_parser & p, const std::string & name, const nlohmann::ordered_json & schema, bool raw = false); diff --git a/common/unicode.cpp b/common/unicode.cpp index 56ab0f468e0..c0ef6d02926 100644 --- a/common/unicode.cpp +++ b/common/unicode.cpp @@ -1,14 +1,18 @@ #include "unicode.h" +#include +#include +#include +#include // implementation adopted from src/unicode.cpp -size_t utf8_sequence_length(unsigned char first_byte) { +size_t common_utf8_sequence_length(unsigned char first_byte) { const size_t lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 3, 4 }; uint8_t highbits = static_cast(first_byte) >> 4; return lookup[highbits]; } -utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) { +utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset) { if (offset >= input.size()) { return utf8_parse_result(utf8_parse_result::INCOMPLETE); } @@ -62,3 +66,43 @@ utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset) { // Invalid first byte return utf8_parse_result(utf8_parse_result::INVALID); } + +std::string common_unicode_cpts_to_utf8(const std::vector & cps) { + std::string result; + for (size_t i = 0; i < cps.size(); ++i) { + result.append(common_unicode_cpt_to_utf8(cps[i])); + } + return result; +} + +std::string common_unicode_cpt_to_utf8(uint32_t cpt) { + std::string result; + + if (/* 0x00 <= cpt && */ cpt <= 0x7f) { + result.push_back(cpt); + return result; + } + if (0x80 <= cpt && cpt <= 0x7ff) { + result.push_back(0xc0 | ((cpt >> 6) & 0x1f)); + result.push_back(0x80 | (cpt & 0x3f)); + return result; + } + if (0x800 <= cpt && cpt <= 0xffff) { + result.push_back(0xe0 | ((cpt >> 12) & 0x0f)); + result.push_back(0x80 | ((cpt >> 6) & 0x3f)); + result.push_back(0x80 | (cpt & 0x3f)); + return result; + } + if (0x10000 <= cpt && cpt <= 0x10ffff) { + result.push_back(0xf0 | ((cpt >> 18) & 0x07)); + result.push_back(0x80 | ((cpt >> 12) & 0x3f)); + result.push_back(0x80 | ((cpt >> 6) & 0x3f)); + result.push_back(0x80 | (cpt & 0x3f)); + return result; + } + + throw std::invalid_argument("invalid codepoint"); +} + + + diff --git a/common/unicode.h b/common/unicode.h index 9d9e8e1227a..87bcc0ffcaf 100644 --- a/common/unicode.h +++ b/common/unicode.h @@ -2,6 +2,8 @@ #include #include +#include +#include // UTF-8 parsing utilities for streaming-aware unicode support @@ -16,7 +18,10 @@ struct utf8_parse_result { // Determine the expected length of a UTF-8 sequence from its first byte // Returns 0 for invalid first bytes -size_t utf8_sequence_length(unsigned char first_byte); +size_t common_utf8_sequence_length(unsigned char first_byte); // Parse a single UTF-8 codepoint from input -utf8_parse_result parse_utf8_codepoint(std::string_view input, size_t offset); +utf8_parse_result common_parse_utf8_codepoint(std::string_view input, size_t offset); + +std::string common_unicode_cpts_to_utf8(const std::vector & cps); +std::string common_unicode_cpt_to_utf8(uint32_t cpt); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 09544173981..083b5bca9e9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4031,7 +4031,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter # split Conv3D into Conv2Ds c1, c2, kt, kh, kw = data_torch.shape del c1, c2, kh, kw # unused - assert kt == 2, "Current implmentation only support temporal_patch_size of 2" + assert kt == 2, "Current implementation only support temporal_patch_size of 2" yield (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight" , data_torch[:, :, 0, ...]) yield (gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH] + ".weight.1", data_torch[:, :, 1, ...]) else: @@ -4842,12 +4842,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter yield from super().modify_tensors(data_torch, name, bid) -@ModelBase.register("Qwen3_5ForConditionalGeneration") +@ModelBase.register("Qwen3_5ForConditionalGeneration", "Qwen3_5ForCausalLM") class Qwen3_5TextModel(_LinearAttentionVReorderBase): model_arch = gguf.MODEL_ARCH.QWEN35 -@ModelBase.register("Qwen3_5MoeForConditionalGeneration") +@ModelBase.register("Qwen3_5MoeForConditionalGeneration", "Qwen3_5MoeForCausalLM") class Qwen3_5MoeTextModel(_LinearAttentionVReorderBase): model_arch = gguf.MODEL_ARCH.QWEN35MOE @@ -5404,7 +5404,7 @@ def set_gguf_parameters(self): # Get ssm_d_conv from linear_attn_config.short_conv_kernel_size or ssm_d_conv linear_attn_config = self.hparams["linear_attn_config"] # n_head == 0 for KDA layers, n_head > 0 for MLA layers - # full_attention_layers list will be used to distingush layer type + # full_attention_layers list will be used to distinguish layer type _num_kv_heads = list() _full_attn_layers = linear_attn_config["full_attn_layers"] for il in range(self.hparams["num_hidden_layers"]): @@ -6505,7 +6505,7 @@ def set_gguf_parameters(self): super().set_gguf_parameters() hparams = self.hparams self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.GEMMA3) - # default values below are taken from HF tranformers code + # default values below are taken from HF transformers code self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6)) self.gguf_writer.add_vision_use_gelu(True) # calculate proj_scale_factor (used by tinygemma3 test model) @@ -7097,7 +7097,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter if bid == 0 and "time_mix_a" in new_name: # dummy v0/v1/v2 on first layer - # easist way to make llama happy + # easiest way to make llama happy yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch) yield (new_name, data_torch) @@ -9596,7 +9596,7 @@ def __init__(self, *args, **kwargs): # NOTE: Explicitly include hparam prefix prefix for d_model to # disambiguate with top-level head_dim # NOTE 2: If needed for future models, this can be isolated in a method - # to separate the prefix setting and teh keys used + # to separate the prefix setting and the keys used self.d_model = self.find_hparam([f"{self.hparam_prefixes[0]}_head_dim", "hidden_size", "d_model"]) self.n_group = self.find_hparam(["n_groups", "num_groups"]) self.d_inner = self.find_hparam(["expand", "num_heads"]) * self.d_model @@ -9743,7 +9743,7 @@ def set_gguf_parameters(self): self.gguf_writer.add_value_length(self.head_dim) # Set feed_forward_length - # NOTE: This will trigger an override warning. This is preferrable to + # NOTE: This will trigger an override warning. This is preferable to # duplicating all the parent logic if not self.is_moe: n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"]) diff --git a/docs/autoparser.md b/docs/autoparser.md new file mode 100644 index 00000000000..686b2c249b6 --- /dev/null +++ b/docs/autoparser.md @@ -0,0 +1,525 @@ +# Auto-Parser Architecture + +The auto-parser automatically analyzes chat templates to determine how to parse model outputs, including content, reasoning, and tool calls. + +## Overview + +The unified auto-parser uses a pure differential, compositional approach (inspired by the `git diff` algorithm) to analyze chat templates: + +**Core Philosophy**: + +- **Minimize Hardcoded Patterns**: All markers extracted through template comparison (the only heuristic is JSON detection to distinguish `JSON_NATIVE` from tag-based formats) +- **Compositional Architecture**: Separate analyzer structs for reasoning, content, and tools — each responsible for its own analysis and parser construction + +**Analysis + Parser Building in Two Steps**: + +1. `autoparser::autoparser tmpl_analysis(tmpl)` — runs all differential comparisons and populates the analysis structs +2. `autoparser::peg_generator::generate_parser(tmpl, params, tmpl_analysis)` — uses the analysis to build a PEG parser and optional GBNF grammar + +## Data Structures + +All structs are defined in [common/chat-auto-parser.h](common/chat-auto-parser.h). + +### Top-Level: `autoparser` (main analyzer and generator) + +[common/chat-auto-parser.h:367-388](common/chat-auto-parser.h#L367-L388) — top-level analysis result aggregating `jinja_caps`, `reasoning`, `content`, and `tools` sub-analyses, plus `preserved_tokens` (union of all non-empty markers). + +### `analyze_reasoning` + +[common/chat-auto-parser.h:254-274](common/chat-auto-parser.h#L254-L274) — reasoning analysis result: `mode` enum, `start` marker (e.g. ``), and `end` marker (e.g. ``). + +### `analyze_content` + +[common/chat-auto-parser.h:280-295](common/chat-auto-parser.h#L280-L295) — content analysis result: `mode` enum, `start`/`end` markers, and `requires_nonnull_content` flag. + +### `analyze_tools` and its sub-structs + +- [common/chat-auto-parser.h:176-194](common/chat-auto-parser.h#L176-L194) — `tool_format_analysis`: `mode` enum, `section_start/end`, `per_call_start/end`, JSON field names (`function_field`, `name_field`, `args_field`, `id_field`, `gen_id_field`), and format flags (`fun_name_is_key`, `tools_array_wrapped`, `uses_python_dicts`) +- [common/chat-auto-parser.h:196-200](common/chat-auto-parser.h#L196-L200) — `tool_function_analysis`: `name_prefix`, `name_suffix`, `close` markers around function names +- [common/chat-auto-parser.h:202-210](common/chat-auto-parser.h#L202-L210) — `tool_arguments_analysis`: `start/end` container markers, `name_prefix/suffix`, `value_prefix/suffix`, `separator` +- [common/chat-auto-parser.h:212-217](common/chat-auto-parser.h#L212-L217) — `tool_id_analysis`: `pos` enum, `prefix`/`suffix` markers around call ID values +- [common/chat-auto-parser.h:301-361](common/chat-auto-parser.h#L301-L361) — `analyze_tools`: aggregates the four sub-structs above + +### Enums + +**`reasoning_mode`**: How the template handles reasoning/thinking blocks. + +| Value | Description | +|-----------------|-----------------------------------------------------------------------------------| +| `NONE` | No reasoning markers detected | +| `TAG_BASED` | Standard tag-based: `...` | +| `DELIMITER` | Delimiter-based: reasoning ends at a delimiter (e.g., `[BEGIN FINAL RESPONSE]`) | +| `FORCED_OPEN` | Template ends with open reasoning tag when `enable_thinking=true` | +| `FORCED_CLOSED` | `enable_thinking=false` emits both tags; `enable_thinking=true` emits only start | +| `TOOLS_ONLY` | Reasoning only appears in tool call responses, not plain content | + +**`content_mode`**: How the template wraps assistant content. + +| Value | Description | +|--------------------------|----------------------------------------------------------------| +| `PLAIN` | No content markers | +| `ALWAYS_WRAPPED` | Content always wrapped: `...` | +| `WRAPPED_WITH_REASONING` | Content wrapped only when reasoning is present | + +**`tool_format`**: Classification of tool call structure. + +| Value | Description | +|------------------|------------------------------------------------------------------| +| `NONE` | No tool support detected | +| `JSON_NATIVE` | Pure JSON: `{"name": "X", "arguments": {...}}` | +| `TAG_WITH_JSON` | Tag-based with JSON args: `{...}` | +| `TAG_WITH_TAGGED`| Tag-based with tagged args: `value` | + +**`call_id_position`**: Where call IDs appear in tag-based formats. + +| Value | Description | +|--------------------------|----------------------------------------------| +| `NONE` | No call ID support detected | +| `PRE_FUNC_NAME` | Before function name | +| `BETWEEN_FUNC_AND_ARGS` | Between function name and arguments | +| `POST_ARGS` | After arguments | + +## Tool Calling Formats + +### JSON_NATIVE + +**Structure**: The entire tool call (function name, arguments, values) is in JSON format. Optional enclosing tags around the section. + +**Detection**: Function name appears inside a JSON structure (quotes preceded by `{` or `:`). + +**Examples**: + +Standard OpenAI-style: + +```json + +{"name": "get_weather", "arguments": {"location": "Paris", "unit": "celsius"}} + +``` + +Mistral Nemo with array wrapper: + +```json +[TOOL_CALLS] +[{"name": "calculate", "arguments": {"expr": "2+2"}}] +``` + +Function name as JSON key (Apertus style): + +```json +{"get_weather": {"location": "Paris"}} +``` + +--- + +### TAG_WITH_JSON + +**Structure**: Function name is outside JSON, in tag attributes or XML-style tags. Arguments are a JSON object. + +**Detection**: Function name not in JSON, but argument names appear in JSON context. + +**Examples**: + +Functionary v3.1: + +```xml +{"location": "Paris", "unit": "celsius"} +``` + +MiniMax: + +```xml + +calculate +{"expr": "2+2"} + +``` + +--- + +### TAG_WITH_TAGGED + +**Structure**: Both function name and argument names are in XML-style tags. String values are unquoted; non-string values are JSON-formatted. + +**Detection**: Neither function name nor argument names appear in a JSON context. + +**Examples**: + +Qwen/Hermes XML format: + +```xml + +Paris +celsius + +``` + +Mixed types: + +```xml + +2+2 +2 +{"round": true} + +``` + +String values (`Paris`, `celsius`, `2+2`) are unquoted; `options` (object type) is JSON-formatted. + +--- + +## Analysis Flow + +```text +autoparser::autoparser(tmpl) + | + |-- Phase 1: analyze_reasoning(tmpl, jinja_caps.supports_tool_calls) + | |-- R1: compare_reasoning_presence() — with/without reasoning_content field + | |-- R2: compare_thinking_enabled() — enable_thinking=false vs true + | '-- R3: compare_reasoning_scope() — reasoning+content vs reasoning+tools + | (only if supports_tool_calls) + | + |-- Phase 2: analyze_content(tmpl, reasoning) + | '-- C1: compares content-only vs tools output and content-only vs reasoning output + | + |-- Phase 3: analyze_tools(tmpl, jinja_caps, reasoning) + | (skipped entirely if !jinja_caps.supports_tool_calls) + | | + | |-- T1: analyze_tool_calls() — no tools vs with tools; classifies format + | | |-- JSON path → analyze_tool_call_format_json_native() + | | '-- tag path → analyze_tool_call_format_non_json() + | | + | (if format != NONE and format != JSON_NATIVE:) + | | + | |-- T2: check_per_call_markers() — 1 call vs 2 calls; moves section→per-call if needed + | | (only if supports_parallel_tool_calls) + | | + | |-- T3: extract_function_markers() — func_alpha vs func_beta; extracts name prefix/suffix/close + | | + | |-- T4: analyze_arguments() — (TAG_WITH_TAGGED only) + | | |-- A1: extract_argument_name_markers() — arg_name_A vs arg_name_B + | | '-- A2: extract_argument_value_markers() — value "XXXX" vs "YYYY" + | | + | |-- T5: extract_argument_separator() — 1 arg vs 2 args; finds separator between args + | | + | |-- T6: extract_args_markers() — 0 args vs 1 arg; finds args container markers + | | + | '-- T7: extract_call_id_markers() — call_id "call00001" vs "call99999" + | + '-- collect_preserved_tokens() — union of all non-empty markers + | + '-- apply workarounds() — post-hoc patches for edge-case templates + | + v +autoparser (analysis result) + | + v +autoparser::peg_generator::generate_parser(tmpl, inputs, analysis) + |-- analysis.build_parser(inputs) — builds PEG parser arena + | |-- reasoning.build_parser(ctx) — reasoning parser (mode-dependent) + | |-- content.build_parser(ctx) — content parser (mode-dependent) + | '-- tools.build_parser(ctx) — tool parser (dispatches by tool_format) + | |-- build_tool_parser_json_native() + | |-- build_tool_parser_tag_json() + | '-- build_tool_parser_tag_tagged() + | + |-- Build GBNF grammar (if tools present and trigger_marker non-empty) + '-- Set grammar_triggers from section_start or per_call_start + | + v +common_chat_params (prompt, parser, grammar, triggers, preserved_tokens) +``` + +## Entry Point + +The auto-parser is invoked in [common/chat.cpp:1280-1310](common/chat.cpp#L1280-L1310) in `common_chat_templates_apply_jinja`. A few specialized templates are handled first (Ministral/Magistral Large 3, GPT-OSS with `<|channel|>`, Functionary v3.2 with `>>>all`), then the auto-parser handles everything else via `autoparser::autoparser` + `peg_generator::generate_parser`. + +## Algorithm Details + +### Core Mechanism: Differential Comparison + +All analysis phases use the same factorized comparison function declared in [common/chat-auto-parser-helpers.h:68](common/chat-auto-parser-helpers.h#L68): + +```cpp +compare_variants(tmpl, params_A, params_modifier) +``` + +This creates variant B by applying a modifier lambda to a copy of `params_A`, renders both through the template, and computes a `diff_split` ([common/chat-auto-parser.h:28-37](common/chat-auto-parser.h#L28-L37)): + +- `prefix` — common prefix between A and B +- `suffix` — common suffix between A and B +- `left` — unique to variant A +- `right` — unique to variant B + +The diff is computed via `calculate_diff_split()`, which finds the longest-common-prefix and longest-common-suffix, then iteratively moves incomplete `<...>` or `[...]` markers from the prefix/suffix into left/right until stable (tag boundary fixing). + +Text is segmentized into markers and non-marker fragments using `segmentize_markers()`, which splits on `<...>` and `[...]` boundaries. + +### Phase 1: Reasoning Analysis + +**R1 — `compare_reasoning_presence()`**: Compares assistant message with vs without a `reasoning_content` field. + +- Searches `diff.right` (output with reasoning) for the reasoning content needle +- Uses PEG parsers to find surrounding markers: + - If both pre/post markers found in `diff.right` → `TAG_BASED` (both tags visible in diff = no forced close) + - If both found but post marker only in the full output B → `FORCED_CLOSED` + - If only post marker found → `DELIMITER` +- Sets `reasoning.start` and `reasoning.end` + +**R2 — `compare_thinking_enabled()`**: Compares `enable_thinking=false` vs `true` with a generation prompt. + +- Detects `FORCED_OPEN`: `enable_thinking=true` adds a non-empty marker at the end of the prompt (where model will start generating) — sets `reasoning.start`, mode = `FORCED_OPEN` +- Detects `FORCED_CLOSED`: `enable_thinking=false` produces both start+end markers; `enable_thinking=true` produces only start marker +- Handles the reverse case: if both start and end are still empty, looks for a single-segment diff on each side to extract both markers + +**R3 — `compare_reasoning_scope()`**: Compares assistant message with reasoning+text-content vs reasoning+tool-calls. + +- Only runs if `jinja_caps.supports_tool_calls` +- Detects `TOOLS_ONLY`: reasoning content present in B (with tools) but not in A (with text content) +- Extracts reasoning markers from the tool call output using PEG parsers + +### Phase 2: Content Analysis + +**C1**: Two comparisons in the `analyze_content` constructor: + +- Comparison 1: content-only output vs tool-call output → `diff_tools` +- Comparison 2: content-only output vs reasoning+empty-content output → `diff_reasoning` + +Classification logic: + +- `PLAIN`: `diff_tools.left` equals the response string (content is the entire diff, no wrapper) +- `ALWAYS_WRAPPED`: markers found surrounding the content text in `pure_content` → extracts `start`/`end` + +### Phase 3: Tool Call Analysis + +**T1 — `analyze_tool_calls()`**: Compares no-tools vs with-tools output. + +- Extracts the tool call section as `diff.right` +- Calls `analyze_tool_call_format()` which first strips reasoning markers from the haystack, then: + - Calls `in_json_haystack()` for both function name and argument name needles + - `in_json_haystack()` uses a PEG parser to check whether the needle appears in a JSON context (preceded by `{` or `:` with surrounding quotes) + - If function name is in JSON → `JSON_NATIVE` → `analyze_tool_call_format_json_native()` + - If function name not in JSON, arg name is in JSON → `TAG_WITH_JSON` + - If neither in JSON → `TAG_WITH_TAGGED` + - `analyze_tool_call_format_json_native()`: parses the JSON object, matches field values to needles to populate `name_field`, `args_field`, `id_field`, `gen_id_field`; detects `tools_array_wrapped`; extracts `section_start`/`section_end` + - `analyze_tool_call_format_non_json()`: uses PEG parsers on the haystack to find up to two opening markers (section + per-call) then up to two closing markers + +**T2 — `check_per_call_markers()`**: Compares 1 call vs 2 calls. + +- Computes a secondary diff of the second call portion vs the common suffix +- If the second call content starts with `section_start` → the section marker is actually per-call → moves `section_start/end` to `per_call_start/end` and clears the section markers + +**T3 — `extract_function_markers()`**: Compares function name `FUN_FIRST` vs `FUN_SECOND` (two different named functions). + +- Finds where the function name appears in `diff.left` +- Extracts `function.name_prefix` from the common prefix up to the function marker, and `function.name_suffix` from after the name up to the next marker +- Extends `name_suffix` into `diff.suffix` (to the first marker for TAG_WITH_TAGGED; to the first `{` or `[` for TAG_WITH_JSON) +- Extracts `function.close` from after the last argument value up to the per-call/section end marker + +**T4 — `analyze_arguments()`** (TAG_WITH_TAGGED only): + +- **A1 `extract_argument_name_markers()`**: Compares `arg_name_A` vs `arg_name_B` (two different argument names). + - Finds shared surrounding structure → `arguments.name_prefix`, `arguments.name_suffix` +- **A2 `extract_argument_value_markers()`**: Compares argument value `"XXXX"` vs `"YYYY"` (same arg, different value). + - Finds markers surrounding the value → `arguments.value_prefix`, `arguments.value_suffix` + +**T5 — `extract_argument_separator()`**: Compares 1 argument vs 2 arguments (same function). + +- Uses `until_common_prefix(diff.right, ARG_FIRST, ARG_SECOND)` to find what separates the two argument blocks + +**T6 — `extract_args_markers()`**: Compares 0 arguments vs 1 argument. + +- Uses `until_common_prefix()` and `after_common_suffix()` with the empty and single-arg JSON strings as anchors to find container markers (`arguments.start`, `arguments.end`) + +**T7 — `extract_call_id_markers()`**: Compares call IDs `"call00001"` vs `"call99999"`. + +- Determines whether function name appears in `diff.prefix` or `diff.suffix` to classify position: + - Function name in prefix only → `BETWEEN_FUNC_AND_ARGS` or `POST_ARGS` (further distinguished by where `{` appears) + - Function name in suffix only → `PRE_FUNC_NAME` +- Extracts `call_id.prefix` and `call_id.suffix` markers around the call ID value +- Clears `per_call_end` if it incorrectly incorporated the call ID suffix + +### Workarounds + +A workaround array in `common/chat-diff-analyzer.cpp` applies post-hoc patches after analysis. Each workaround is a lambda that inspects the template source and overrides analysis results. Current workarounds: + +1. **Old Qwen/DeepSeek thinking templates** — source contains `content.split('')`: sets `reasoning.mode = FORCED_OPEN` with ``/`` markers if no reasoning was detected +2. **Granite 3.3** — source contains specific "Write your thoughts" text: forces `TAG_BASED` reasoning with ``/`` and `WRAPPED_WITH_REASONING` content with ``/`` +3. **Cohere Command R+** — source contains `<|CHATBOT_TOKEN|>`: sets `ALWAYS_WRAPPED` content mode if no content start is already set +4. **Functionary 3.1** — source contains `set has_code_interpreter`: forces `PLAIN` content, specific `per_call_start/end`, clears preserved tokens to only keep Functionary-specific markers +5. **DeepSeek-R1-Distill-Qwen** — source contains `tool▁calls▁begin` markers: overrides tool section/per-call markers with the correct Unicode block characters + +### Parser Building + +Each analyzer struct (`analyze_reasoning`, `analyze_content`, `analyze_tools`) implements `build_parser(parser_build_context&)`. They share a `parser_build_context` that carries the PEG builder, inference inputs, the pre-built reasoning parser, and a pointer to the content analyzer. + +#### Reasoning Parser (`analyze_reasoning::build_parser`) + +| Mode | Parser | +|-----------------------------------|---------------------------------------------------------------------| +| Not extracting reasoning | `eps()` | +| `FORCED_OPEN` or `FORCED_CLOSED` | `reasoning(until(end)) + end` — opening tag was in the prompt | +| `TAG_BASED` or `TOOLS_ONLY` | `optional(start + reasoning(until(end)) + end)` | +| `DELIMITER` | `optional(reasoning(until(end)) + end)` — no start marker | + +#### Content Parser (`analyze_content::build_parser`) + +| Condition | Parser | +|----------------------------------------|---------------------------------------------------------------------------------| +| `json_schema` present | `reasoning + space() + content(schema(json(), "response-format", ...)) + end()` | +| Tools present | Dispatches to `analyze_tools::build_parser()` | +| `ALWAYS_WRAPPED` with reasoning | `reasoning + start + content(until(end)) + end + end()` | +| `ALWAYS_WRAPPED` without reasoning | `content(until(start)) + start + content(until(end)) + end + end()` | +| Default (PLAIN) | `reasoning + content(rest()) + end()` | + +#### Tool Parsers (`analyze_tools::build_parser`) + +Dispatches by `format.mode`: + +**`build_tool_parser_json_native()`**: Calls `p.standard_json_tools()` which internally dispatches to: + +- `build_json_tools_function_is_key()` — function name is the JSON key: `{"get_weather": {...}}` +- `build_json_tools_nested_keys()` — nested: `{"function": {"name": "X", "arguments": {...}}}` +- `build_json_tools_flat_keys()` — flat: `{"name": "X", "arguments": {...}}` + +Handles content wrappers, array wrapping (`tools_array_wrapped`), parallel calls, and `parameter_order`. + +**`build_tool_parser_tag_json()`**: For each tool function: + +```text +tool_open(name_prefix + tool_name(literal(name)) + name_suffix) + + call_id_section + + tool_args(schema(json(), tool_schema)) + [+ function.close if non-empty] +``` + +Wrapped in per-call markers (with optional parallel call repetition) then optionally in section markers. + +**`build_tool_parser_tag_tagged()`**: For each tool function, builds one parser per argument: + +- String types: `tool_arg_string_value(schema(until(value_suffix), ...))` +- JSON types: `tool_arg_json_value(schema(json(), ...))` +- Required args are plain; optional args wrapped in `optional()` +- Arguments joined with `space()` between consecutive parsers + +For closing: uses `function.close` if present; otherwise uses `peek(per_call_end)` to avoid premature close during partial streaming; falls back to `tool_close(space())` to trigger mapper callbacks. + +All three tool parsers return: + +```text +reasoning + optional(content(until(trigger_marker))) + tool_calls + end() +``` + +### Python Dict Format + +When `format.uses_python_dicts` is true (detected when single-quoted strings appear in JSON argument context), `build_parser()` pre-registers a `json-string` rule that accepts both single-quoted and double-quoted strings. This is done before any `p.json()` call so all JSON parsing inherits the flexible rule. + +## Mapper + +`common_chat_peg_mapper` maps PEG parse results (AST nodes) into `common_chat_msg` structures. Key design: + +- **Buffered arguments**: Before `tool_name` is known, argument text goes to `args_buffer`; once the name is set, the buffer is flushed to `current_tool->arguments` +- **`args_target()`**: Returns a reference to whichever destination is currently active (buffer or tool args), eliminating branching +- **`closing_quote_pending`**: Tracks whether a closing `"` needs to be appended when a string argument value is finalized (for schema-declared string types in tagged format) +- **Quote normalization**: Python-style quotes (`'key': 'value'`) are converted to JSON (`"key": "value"`) +- **Brace auto-closing**: At tool close, unclosed `{` braces are closed automatically + +## Files + +| File | Purpose | +|-------------------------------------------|----------------------------------------------------------------------| +| `common/chat-auto-parser.h` | All analysis structs, enums, `autoparser`, `peg_generator`, `templates_params` | +| `common/chat-auto-parser-generator.cpp` | Parser generator: `generate_parser()` and `build_parser()` methods | +| `common/chat-diff-analyzer.cpp` | Differential analysis implementation and workarounds | +| `common/chat-auto-parser-helpers.h/cpp` | `calculate_diff_split()`, `segmentize_markers()`, | +| | `compare_variants()`, string helpers | +| `common/chat-peg-parser.h/cpp` | `common_chat_peg_builder`, `common_chat_peg_mapper`, and helpers | +| `common/chat.cpp` | Entry point: `common_chat_templates_apply_jinja()` | +| `tools/parser/debug-template-parser.cpp` | Debug tool for template analysis | +| `tools/parser/template-analysis.cpp` | Template analysis tool | + +## Testing & Debugging + +### Debug Tools + +**Template Debugger**: `tools/parser/debug-template-parser.cpp` + +- Usage: `./bin/llama-debug-template-parser path/to/template.jinja` +- Shows detected format, markers, generated parser, and GBNF grammar + +**Template Analysis**: `tools/parser/template-analysis.cpp` + +- Usage: `./bin/llama-template-analysis path/to/template.jinja` + +**Debug Logging**: Enable with `LLAMA_LOG_VERBOSITY=2` + +- Shows detailed analysis steps, pattern extraction results, and generated parser structure + +**PEG Test Builder**: Fluent API for creating test cases — see [tests/test-chat.cpp:947-1043](tests/test-chat.cpp#L947-L1043). Example usage: + +```cpp +auto tst = peg_tester("models/templates/Template.jinja"); +tst.test("input text") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({tool_json}) + .parallel_tool_calls(true) + .enable_thinking(true) + .expect(expected_message) + .run(); +``` + +### Tested Templates + +The following templates have active tests in `tests/test-chat.cpp`: + +| Template | Format | Notes | +| -------- | ------ | ----- | +| Ministral-3-14B-Reasoning | Reasoning | `[THINK]...[/THINK]` tags (specialized handler) | +| NVIDIA-Nemotron-3-Nano-30B | TAG_WITH_TAGGED | Reasoning + tools | +| CohereForAI Command-R7B | JSON_NATIVE | `<\|START_THINKING\|>`/`<\|START_RESPONSE\|>` markers | +| Google Gemma 2 2B | Content only | No tool support | +| Qwen-QwQ-32B | Reasoning | Forced-open thinking | +| NousResearch Hermes 2 Pro | JSON_NATIVE | `` wrapper | +| IBM Granite 3.3 | JSON_NATIVE | `` + `` | +| ByteDance Seed-OSS | TAG_WITH_TAGGED | Custom `` and `` tags | +| Qwen3-Coder | TAG_WITH_TAGGED | XML-style tool format | +| DeepSeek V3.1 | JSON_NATIVE | Forced thinking mode | +| GLM-4.6 | TAG_WITH_TAGGED | `name\n......` format | +| GLM-4.7-Flash | TAG_WITH_TAGGED | Updated GLM format | +| Kimi-K2-Thinking | JSON_NATIVE | Reasoning + JSON tools | +| Apertus-8B-Instruct | JSON_NATIVE | Function name as JSON key | +| MiniMax-M2 | TAG_WITH_JSON | XML invoke with JSON args | +| NVIDIA-Nemotron-Nano-v2 | JSON_NATIVE | `` wrapper (nested) | +| CohereForAI Command-R Plus | JSON_NATIVE | Markdown code block format | +| Mistral-Nemo-Instruct-2407 | JSON_NATIVE | `[TOOL_CALLS]` wrapper with ID field | +| Functionary v3.1 | TAG_WITH_JSON | `` format | +| Functionary v3.2 | Specialized | `>>>` recipient delimiter (dedicated handler) | +| Fireworks Firefunction v2 | TAG_WITH_JSON | Fireworks tool format | +| DeepSeek R1 Distill (Llama/Qwen) | Reasoning | Forced-open thinking | +| llama-cpp-deepseek-r1 | Reasoning | Forced-open thinking | +| Kimi-K2 / Kimi-K2-Instruct | JSON_NATIVE | JSON tools with special markers | +| Llama 3.1/3.2/3.3 | JSON_NATIVE | Standard Llama tool format | +| OpenAI GPT-OSS | Specialized | Channel-based (dedicated handler) | +| Apriel 1.5 | JSON_NATIVE | `` wrapper with JSON array | +| Apriel 1.6 Thinker | Reasoning | Implicit reasoning start | +| Mistral Small 3.2 | JSON_NATIVE | `[TOOL_CALLS]func[ARGS]{...}` with call ID | +| Devstral | JSON_NATIVE | `[TOOL_CALLS]func[ARGS]{...}` without call ID | +| StepFun 3.5 Flash | TAG_WITH_TAGGED | `` format | + +## Adding Support for New Templates + +To support a new template format: + +1. **If it follows standard patterns** — The auto-parser should detect it automatically. Run `llama-debug-template-parser` to verify markers are correctly extracted. +2. **If differential analysis extracts incorrect markers** — Add a workaround lambda to the `workarounds` vector in `common/chat-diff-analyzer.cpp`. Inspect the template source for a unique identifying substring. +3. **If it needs fundamentally different handling** — Add a dedicated handler function in `chat.cpp` before the auto-parser block (as done for GPT-OSS, Functionary v3.2, and Ministral). + +## Edge Cases and Quirks + +1. **Forced Thinking**: When `enable_thinking=true` and the model prompt ends with an open reasoning tag (e.g., ``), the parser enters forced thinking mode and immediately expects reasoning content without waiting for a start marker. +2. **Per-Call vs Per-Section Markers**: Some templates wrap each tool call individually (`per_call_start/end`); others wrap the entire section (`section_start/end`). T2 (`check_per_call_markers()`) disambiguates by checking if the second call in a two-call output starts with the section marker. +3. **Python Dict Format**: The Seed template family uses single-quoted JSON (`'key': 'value'`). The `uses_python_dicts` flag causes the PEG builder to register a flexible `json-string` rule accepting both quote styles before any JSON rules are built. +4. **Tag Boundary Fixing**: `calculate_diff_split()` iteratively adjusts prefix/suffix boundaries to avoid splitting `` or `[marker]` tokens, ensuring clean extraction. +5. **Call ID Side Effects**: When a call ID is detected, `per_call_end` may have been incorrectly set to include the call ID suffix. T7 clears `per_call_end` in this case. +6. **Tool Analysis Gating**: `analyze_tools` is only constructed (and all tool analysis phases run) when `jinja_caps.supports_tool_calls` is true. Within tool analysis, `check_per_call_markers()` (T2) only runs if `jinja_caps.supports_parallel_tool_calls`. +7. **`analyze_arguments()` Gating**: Within tool analysis, A1 and A2 (argument name/value marker extraction) only run for `TAG_WITH_TAGGED` format. `extract_argument_separator()` and `extract_args_markers()` run for all non-`JSON_NATIVE` formats. diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md index 23b6a627634..51adaaf95f5 100755 --- a/docs/backend/CANN.md +++ b/docs/backend/CANN.md @@ -20,7 +20,7 @@ **Llama.cpp + CANN** -The llama.cpp CANN backend is designed to support Ascend NPU. It utilize the ability of AscendC and ACLNN which are intergrated to CANN Toolkit and kernels to using Ascend NPU directly. +The llama.cpp CANN backend is designed to support Ascend NPU. It utilize the ability of AscendC and ACLNN which are integrated to CANN Toolkit and kernels to using Ascend NPU directly. ## News @@ -210,7 +210,7 @@ docker run --name llamacpp --device /dev/davinci0 --device /dev/davinci_manager # and install driver. sudo sh Ascend-hdk-910b-npu-firmware_x.x.x.x.X.run --full ``` - If the following messaage appers, firmware is installed successfully. + If the following message appears, firmware is installed successfully. ```sh Firmware package installed successfully! ``` diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index 07c68be5cbd..dd4c66dbe95 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -708,7 +708,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512 - Remove **build** folder or try a clean-build. -- I can **not** see `[ext_oneapi_level_zero:gpu]` afer installing the GPU driver on Linux. +- I can **not** see `[ext_oneapi_level_zero:gpu]` after installing the GPU driver on Linux. Please double-check with `sudo sycl-ls`. diff --git a/docs/backend/snapdragon/README.md b/docs/backend/snapdragon/README.md index 2c3f88e91a2..0783555ce8a 100644 --- a/docs/backend/snapdragon/README.md +++ b/docs/backend/snapdragon/README.md @@ -116,7 +116,7 @@ Llama-3.2-1B-Instruct-Q4_0.gguf: 1 file pushed, 0 skipped. 38.3 MB/s (773025920 ### Windows All artifacts are already installed in the `pkg-snapdragon` folder. -To run, adapt below instructions to use Powershell scrits in `scripts/snapdragon/windows`. +To run, adapt below instructions to use Powershell scripts in `scripts/snapdragon/windows`. ## How to Run diff --git a/docs/backend/snapdragon/windows.md b/docs/backend/snapdragon/windows.md index e9346ccadf1..6307e1b69f1 100644 --- a/docs/backend/snapdragon/windows.md +++ b/docs/backend/snapdragon/windows.md @@ -144,7 +144,7 @@ Once the build is complete HTP ops libraries will be installed like this -a---- 1/22/2026 6:01 PM 4139 libggml-htp.cat ``` -The .cat file, the signature and proper certicate installation can be verified with +The .cat file, the signature and proper certificate installation can be verified with ``` > signtool.exe verify /v /pa .\pkg-snapdragon\lib\libggml-htp.cat diff --git a/docs/build.md b/docs/build.md index fd447424c78..772731f6418 100644 --- a/docs/build.md +++ b/docs/build.md @@ -108,7 +108,7 @@ Building through oneAPI compilers will make avx_vnni instruction set available f - Using oneAPI docker image: If you do not want to source the environment vars and install oneAPI manually, you can also build the code using intel docker container: [oneAPI-basekit](https://hub.docker.com/r/intel/oneapi-basekit). Then, you can use the commands given above. -Check [Optimizing and Running LLaMA2 on Intel® CPU](https://www.intel.com/content/www/us/en/content-details/791610/optimizing-and-running-llama2-on-intel-cpu.html) for more information. +Check [Optimizing and Running LLaMA2 on Intel® CPU](https://builders.intel.com/solutionslibrary/optimizing-and-running-llama2-on-intel-cpu) for more information. ### Other BLAS libraries @@ -595,7 +595,7 @@ You can verify that KleidiAI is being used by running ```bash ./build/bin/llama-cli -m PATH_TO_MODEL -p "What is a car?" ``` -If KleidiAI is enabled, the ouput will contain a line similar to: +If KleidiAI is enabled, the output will contain a line similar to: ``` load_tensors: CPU_KLEIDIAI model buffer size = 3474.00 MiB ``` @@ -699,7 +699,7 @@ To read documentation for how to build on Android, [click here](./android.md) ## WebGPU [In Progress] -The WebGPU backend relies on [Dawn](https://dawn.googlesource.com/dawn). Follow the instructions [here](https://dawn.googlesource.com/dawn/+/refs/heads/main/docs/quickstart-cmake.md) to install Dawn locally so that llama.cpp can find it using CMake. The currrent implementation is up-to-date with Dawn commit `bed1a61`. +The WebGPU backend relies on [Dawn](https://dawn.googlesource.com/dawn). Follow the instructions [here](https://dawn.googlesource.com/dawn/+/refs/heads/main/docs/quickstart-cmake.md) to install Dawn locally so that llama.cpp can find it using CMake. The current implementation is up-to-date with Dawn commit `bed1a61`. In the llama.cpp directory, build with CMake: diff --git a/docs/development/parsing.md b/docs/development/parsing.md index dbb989bf08e..a41057db2b8 100644 --- a/docs/development/parsing.md +++ b/docs/development/parsing.md @@ -22,7 +22,7 @@ Below is a contrived example demonstrating how to use the PEG parser to parse output from a model that emits arguments as JSON. ```cpp -auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { +auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { // Build a choice of all available tools auto tool_choice = p.choice(); for (const auto & tool : tools) { @@ -212,7 +212,7 @@ mapper.from_ast(ctx.ast, result); ### Native -The `common_chat_peg_native_builder` builds a `native` parser suitable for +The `common_chat_peg_builder` builds a `native` parser suitable for models that emit tool arguments as a direct JSON object. - **`reasoning(p)`** - Tag node for `reasoning_content` @@ -225,7 +225,7 @@ models that emit tool arguments as a direct JSON object. - **`tool_args(p)`** - Tag the tool arguments ```cpp -build_chat_peg_native_parser([&](common_chat_peg_native_parser & p) { +build_chat_peg_parser([&](common_chat_peg_builder & p) { auto get_weather_tool = p.tool(p.sequence({ p.tool_open(p.literal("{")), p.json_member("name", "\"" + p.tool_name(p.literal("get_weather")) + "\""), @@ -246,7 +246,7 @@ build_chat_peg_native_parser([&](common_chat_peg_native_parser & p) { ### Constructed -The `common_chat_peg_constructed_builder` builds a `constructed` parser +The `common_chat_peg_builder` builds a `constructed` parser suitable for models that emit tool arguments as separate entities, such as XML tags. @@ -264,7 +264,7 @@ tags. - **`tool_arg_json_value(p)`** - Tag JSON value for the argument ```cpp -build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) { +build_chat_peg_parser([&](common_chat_peg_builder & p) { auto location_arg = p.tool_arg( p.tool_arg_open(""), p.tool_arg_string_value(p.until("")), diff --git a/docs/multimodal/MobileVLM.md b/docs/multimodal/MobileVLM.md index 3bfab9f3d22..6c17dbf902e 100644 --- a/docs/multimodal/MobileVLM.md +++ b/docs/multimodal/MobileVLM.md @@ -281,7 +281,7 @@ llama_print_timings: total time = 5990.25 ms / 202 tokens Just the same as above. -**ouput** +**output** ```sh encode_image_with_clip: image embedding created: 144 tokens @@ -305,7 +305,7 @@ llama_print_timings: total time = 15513.95 ms / 412 tokens ## Run on Intel(R) Core(TM) Ultra7 115H ### operation system Windows11 -### comiple +### compile ```sh make -j32 ``` diff --git a/docs/ops.md b/docs/ops.md index 296c0ba1d45..8213bc6abfb 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -24,7 +24,7 @@ Legend: | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ | | CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ✅ | 🟡 | ✅ | ❌ | ❌ | | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | -| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ | +| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ✅ | ❌ | ❌ | | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ | | CONV_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | | CONV_2D_DW | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | diff --git a/docs/ops/WebGPU.csv b/docs/ops/WebGPU.csv index e2ed3e2cfad..9e081e7605f 100644 --- a/docs/ops/WebGPU.csv +++ b/docs/ops/WebGPU.csv @@ -9535,38 +9535,38 @@ "WebGPU: WebGPU","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=1,v=0,inplace=1","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=24,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=0,v=0,inplace=1","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=24,n_ctx=512,fs=1.424500,ef=0.746500,af=1.424500,ff=1,v=0,inplace=1","support","1","yes","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","0","no","WebGPU" -"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","0","no","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=2","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=2","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=2","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=2","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=0,v=3","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=1,v=3","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=2,v=3","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=f32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","WebGPU" +"WebGPU: WebGPU","CONCAT","type=i32,ne_a=[11,12,13,14],ne_b_d=7,dim=3,v=3","support","1","yes","WebGPU" "WebGPU: WebGPU","ARGSORT","type=f32,ne=[3,1,1,1],order=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ARGSORT","type=f32,ne=[4,1,1,1],order=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ARGSORT","type=f32,ne=[7,1,1,1],order=0","support","1","yes","WebGPU" diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index a8c19a6aba6..d2b2e336e75 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -5,6 +5,7 @@ #include "sampling.h" #include +#include #include #include #include @@ -16,6 +17,8 @@ static void print_usage(int, char ** argv) { } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + common_params params; params.prompt = "Hello my name is"; diff --git a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp index 767198aafa2..702bc74bee2 100644 --- a/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp +++ b/examples/convert-llama2c-to-ggml/convert-llama2c-to-ggml.cpp @@ -5,14 +5,16 @@ #include "common.h" #include "log.h" -#include -#include +#include #include +#include #include -#include +#include #include -#include +#include #include +#include +#include #include #include #include @@ -874,6 +876,8 @@ static std::string basename(const std::string &path) { } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + common_init(); struct train_params params = get_default_train_params(); diff --git a/examples/debug/README.md b/examples/debug/README.md index 28e00c93427..2ea716eb543 100644 --- a/examples/debug/README.md +++ b/examples/debug/README.md @@ -2,7 +2,7 @@ This is a utility intended to help debug a model by registering a callback that logs GGML operations and tensor data. It can also store the generated logits or -embeddings as well as the prompt and token ids for comparision with the original +embeddings as well as the prompt and token ids for comparison with the original model. ### Usage diff --git a/examples/deprecation-warning/deprecation-warning.cpp b/examples/deprecation-warning/deprecation-warning.cpp index 11f5147328a..0cde17f6e99 100644 --- a/examples/deprecation-warning/deprecation-warning.cpp +++ b/examples/deprecation-warning/deprecation-warning.cpp @@ -1,11 +1,14 @@ // Warns users that this filename was deprecated, and provides a link for more information. +#include #include #include #include // Main int main(int argc, char** argv) { + std::setlocale(LC_NUMERIC, "C"); + std::string filename = "main"; if (argc >= 1) { filename = argv[0]; diff --git a/examples/diffusion/README.md b/examples/diffusion/README.md index f71d2413193..b3942002147 100644 --- a/examples/diffusion/README.md +++ b/examples/diffusion/README.md @@ -43,12 +43,12 @@ Choose one of the following scheduling methods: - `-b`: Batch size ### Examples -#### Dream architechture: +#### Dream architecture: ``` llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual ``` -#### LLaDA architechture: +#### LLaDA architecture: ``` llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual ``` diff --git a/examples/diffusion/diffusion-cli.cpp b/examples/diffusion/diffusion-cli.cpp index d50f754092d..d38bfe7f82d 100644 --- a/examples/diffusion/diffusion-cli.cpp +++ b/examples/diffusion/diffusion-cli.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -538,6 +539,8 @@ static std::string format_input_text(const std::string & prompt, const std::stri } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + ggml_time_init(); common_params params; diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index d8eaaa2691f..33ef2a7521f 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -3,6 +3,7 @@ #include "log.h" #include "llama.h" +#include #include #include @@ -94,6 +95,8 @@ static void print_raw_embeddings(const float * emb, } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + common_params params; if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_EMBEDDING)) { diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index bd587349798..17d162d95d3 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -4,6 +4,8 @@ #include "log.h" #include "llama.h" #include "llama-cpp.h" + +#include #include #include @@ -29,6 +31,8 @@ static bool run(llama_context * ctx, const common_params & params) { } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + base_callback_data cb_data; common_params params; diff --git a/examples/gen-docs/gen-docs.cpp b/examples/gen-docs/gen-docs.cpp index 0aa33e8245b..7ba7d79f721 100644 --- a/examples/gen-docs/gen-docs.cpp +++ b/examples/gen-docs/gen-docs.cpp @@ -1,6 +1,7 @@ #include "arg.h" #include "common.h" +#include #include #include #include @@ -100,6 +101,8 @@ static void write_help(std::ostringstream & ss, const md_file & md) { } int main(int, char **) { + std::setlocale(LC_NUMERIC, "C"); + for (const auto & md : md_files) { std::ifstream infile(md.fname); if (!infile.is_open()) { diff --git a/examples/gguf-hash/gguf-hash.cpp b/examples/gguf-hash/gguf-hash.cpp index 9523ec122f5..331de301ffc 100644 --- a/examples/gguf-hash/gguf-hash.cpp +++ b/examples/gguf-hash/gguf-hash.cpp @@ -1,13 +1,14 @@ #include "ggml.h" #include "gguf.h" -#include /* abort() */ +#include +#include #include #include -#include -#include -#include +#include /* abort() */ #include +#include +#include #include #include @@ -626,6 +627,8 @@ static hash_exit_code_t gguf_hash(const hash_params & hash_params) { } int main(int argc, const char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + hash_params params; manifest_check_params manifest_check; hash_params_parse(argc, argv, params); diff --git a/examples/gguf/gguf.cpp b/examples/gguf/gguf.cpp index 499cfacc92a..79ad38711e3 100644 --- a/examples/gguf/gguf.cpp +++ b/examples/gguf/gguf.cpp @@ -1,6 +1,7 @@ #include "ggml.h" #include "gguf.h" +#include #include #include #include @@ -240,6 +241,8 @@ static bool gguf_ex_read_1(const std::string & fname, bool check_data) { } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + if (argc < 3) { printf("usage: %s data.gguf r|w [n]\n", argv[0]); printf("r: read data.gguf file\n"); diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 9fc90a3c987..35f7d47f3c8 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -689,6 +689,11 @@ def add_component(comp_schema, is_required): elif (schema_type == 'object') or (len(schema) == 0): return self._add_rule(rule_name, self._add_primitive('object', PRIMITIVE_RULES['object'])) + elif schema_type is None and isinstance(schema, dict): + # No type constraint and no recognized structural keywords (e.g. {"description": "..."}). + # Per JSON Schema semantics this is equivalent to {} and accepts any value. + return self._add_rule(rule_name, self._add_primitive('value', PRIMITIVE_RULES['value'])) + else: assert schema_type in PRIMITIVE_RULES, f'Unrecognized schema: {schema}' # TODO: support minimum, maximum, exclusiveMinimum, exclusiveMaximum at least for zero diff --git a/examples/llama.vim b/examples/llama.vim index 736802d3655..23a281fc333 100644 --- a/examples/llama.vim +++ b/examples/llama.vim @@ -52,8 +52,8 @@ highlight llama_hl_info guifg=#77ff2f ctermfg=119 " n_prefix: number of lines before the cursor location to include in the local prefix " n_suffix: number of lines after the cursor location to include in the local suffix " n_predict: max number of tokens to predict -" t_max_prompt_ms: max alloted time for the prompt processing (TODO: not yet supported) -" t_max_predict_ms: max alloted time for the prediction +" t_max_prompt_ms: max allotted time for the prompt processing (TODO: not yet supported) +" t_max_predict_ms: max allotted time for the prediction " show_info: show extra info about the inference (0 - disabled, 1 - statusline, 2 - inline) " auto_fim: trigger FIM completion automatically on cursor movement " max_line_suffix: do not auto-trigger FIM completion if there are more than this number of characters to the right of the cursor diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index aa6efa62b3b..d5fde081c59 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -4,10 +4,11 @@ #include "log.h" #include "llama.h" +#include +#include #include #include #include -#include struct ngram_data { bool active = false; @@ -38,6 +39,8 @@ struct ngram_container { }; int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + common_params params; if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { diff --git a/examples/lookup/lookup-create.cpp b/examples/lookup/lookup-create.cpp index f7b6ea1b190..439e3f726ee 100644 --- a/examples/lookup/lookup-create.cpp +++ b/examples/lookup/lookup-create.cpp @@ -3,10 +3,13 @@ #include "ngram-cache.h" #include "llama.h" +#include #include #include int main(int argc, char ** argv){ + std::setlocale(LC_NUMERIC, "C"); + common_params params; if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) { diff --git a/examples/lookup/lookup-merge.cpp b/examples/lookup/lookup-merge.cpp index 6871c0f5fdb..ee3c7249cf1 100644 --- a/examples/lookup/lookup-merge.cpp +++ b/examples/lookup/lookup-merge.cpp @@ -3,6 +3,7 @@ #include "common.h" #include "ngram-cache.h" +#include #include #include #include @@ -17,6 +18,8 @@ static void print_usage(char* argv0) { } int main(int argc, char ** argv){ + std::setlocale(LC_NUMERIC, "C"); + if (argc < 3) { print_usage(argv[0]); exit(1); diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp index ae28b2e6e86..c3158281c75 100644 --- a/examples/lookup/lookup-stats.cpp +++ b/examples/lookup/lookup-stats.cpp @@ -5,14 +5,17 @@ #include "llama.h" #include "ggml.h" +#include +#include #include #include -#include #include #include #include int main(int argc, char ** argv){ + std::setlocale(LC_NUMERIC, "C"); + common_params params; if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) { diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index c7552ddde14..bd216035c0b 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -6,6 +6,7 @@ #include "log.h" #include "llama.h" +#include #include #include #include @@ -13,6 +14,8 @@ #include int main(int argc, char ** argv){ + std::setlocale(LC_NUMERIC, "C"); + common_params params; if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_LOOKUP)) { diff --git a/examples/model-conversion/README.md b/examples/model-conversion/README.md index 637870a5c15..c43e642fee7 100644 --- a/examples/model-conversion/README.md +++ b/examples/model-conversion/README.md @@ -69,7 +69,7 @@ Command line arguments take precedence over environment variables when both are In cases where the transformer implementation for the model has not been released yet it is possible to set the environment variable `UNRELEASED_MODEL_NAME` which -will then cause the transformer implementation to be loaded explicitely and not +will then cause the transformer implementation to be loaded explicitly and not use AutoModelForCausalLM: ``` export UNRELEASED_MODEL_NAME=SomeNewModel @@ -120,7 +120,7 @@ The converted model can be inspected using the following command: (venv) $ make causal-run-converted-model ``` -### Model logits verfication +### Model logits verification The following target will run the original model and the converted model and compare the logits: ```console @@ -235,7 +235,7 @@ new model the model can be converted to GGUF format using the following command: (venv) $ make embedding-run-converted-model ``` -### Model logits verfication +### Model logits verification The following target will run the original model and the converted model (which was done manually in the previous steps) and compare the logits: ```console @@ -335,7 +335,7 @@ $ make perplexity-run-full QUANTIZED_MODEL=~/path/to/quantized/model-Qxx.gguf LO ## HuggingFace utilities The following targets are useful for creating collections and model repositories -on Hugging Face in the the ggml-org. These can be used when preparing a relase +on Hugging Face in the the ggml-org. These can be used when preparing a release to script the process for new model releases. For the following targets a `HF_TOKEN` environment variable is required. @@ -347,7 +347,7 @@ For the following targets a `HF_TOKEN` environment variable is required. > $ unset HF_TOKEN ### Create a new Hugging Face Model (model repository) -This will create a new model repsository on Hugging Face with the specified +This will create a new model repository on Hugging Face with the specified model name. ```console (venv) $ make hf-create-model MODEL_NAME='TestModel' NAMESPACE="danbev" ORIGINAL_BASE_MODEL="some-base-model" diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index c92173ae291..1700ceefbf7 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -7,12 +7,13 @@ #include "log.h" #include "llama.h" +#include +#include #include #include #include #include #include -#include // trim whitespace from the beginning and end of a string static std::string trim(const std::string & str) { @@ -153,6 +154,8 @@ static std::vector split_string(const std::string& input, char deli } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + srand(1234); common_params params; diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index 8a4faa383bf..665191047a4 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -3,6 +3,7 @@ #include "log.h" #include "llama.h" +#include #include #include #include @@ -16,6 +17,8 @@ static void print_usage(int, char ** argv) { } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + common_params params; params.n_junk = 250; diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 3f2afd4346e..9e05fc22337 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -4,6 +4,7 @@ #include "llama.h" #include +#include #include #include // TODO: remove me @@ -112,6 +113,8 @@ static void batch_process(llama_context * ctx, llama_batch & batch, float * outp } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + common_params params; if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_RETRIEVAL, print_usage)) { diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 5e35dcd6030..174c8c75854 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -2,11 +2,14 @@ #include "common.h" #include "llama.h" +#include #include #include int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + common_params params; params.prompt = "The quick brown fox"; diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index 57195df3316..97e9dc9842f 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -1,4 +1,5 @@ #include "llama.h" +#include #include #include #include @@ -12,6 +13,8 @@ static void print_usage(int, char ** argv) { } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + std::string model_path; int ngl = 99; int n_ctx = 2048; diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index d09771d1045..9f0a25d713f 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -1,4 +1,5 @@ #include "llama.h" +#include #include #include #include @@ -11,6 +12,8 @@ static void print_usage(int, char ** argv) { } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + // path to the model gguf file std::string model_path; // prompt to generate text from diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index d8b1f5a480c..8a1cbd96c25 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -5,12 +5,15 @@ #include "log.h" #include "llama.h" +#include #include #include #include #include int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + common_params params; if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SPECULATIVE)) { diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 3e5cf5f46b5..250c5b7c62d 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -5,6 +5,7 @@ #include "llama.h" #include +#include #include #include #include @@ -30,6 +31,8 @@ struct seq_draft { }; int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + common_params params; // needed to get candidate probs even for temp <= 0.0 diff --git a/examples/sycl/README.md b/examples/sycl/README.md index 8819d87f56e..29143dd6176 100644 --- a/examples/sycl/README.md +++ b/examples/sycl/README.md @@ -6,11 +6,11 @@ This example program provides the tools for llama.cpp for SYCL on Intel GPU. |Tool Name| Function|Status| |-|-|-| -|llama-ls-sycl-device| List all SYCL devices with ID, compute capability, max work group size, ect.|Support| +|llama-ls-sycl-device| List all SYCL devices with ID, compute capability, max work group size, etc.|Support| ### llama-ls-sycl-device -List all SYCL devices with ID, compute capability, max work group size, ect. +List all SYCL devices with ID, compute capability, max work group size, etc. 1. Build the llama.cpp for SYCL for the specified target *(using GGML_SYCL_TARGET)*. diff --git a/examples/sycl/ls-sycl-device.cpp b/examples/sycl/ls-sycl-device.cpp index 74a8b7fd814..3bdc4059825 100644 --- a/examples/sycl/ls-sycl-device.cpp +++ b/examples/sycl/ls-sycl-device.cpp @@ -6,8 +6,10 @@ #include "ggml-sycl.h" +#include int main() { + std::setlocale(LC_NUMERIC, "C"); ggml_backend_sycl_print_sycl_devices(); return 0; } diff --git a/examples/training/finetune.cpp b/examples/training/finetune.cpp index c82de8d35d0..e20f89488f2 100644 --- a/examples/training/finetune.cpp +++ b/examples/training/finetune.cpp @@ -3,6 +3,7 @@ #include "log.h" #include "llama.h" +#include #include #include #include @@ -14,6 +15,8 @@ #endif int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + common_params params; params.escape = false; diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index a9d1778641e..9fd3f7f32a0 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -259,7 +259,7 @@ extern "C" { Example usage: // operations that use tensors allocated in a buffer with USAGE_WEIGHTS will be assigned - // preferrably to run on the same backend as the buffer + // preferably to run on the same backend as the buffer ggml_backend_buffer_set_usage(buf_weights, GGML_BACKEND_BUFFER_USAGE_WEIGHTS); sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, NULL, num_backends, GGML_DEFAULT_GRAPH_SIZE, false, true); diff --git a/ggml/include/ggml-opt.h b/ggml/include/ggml-opt.h index 4703a05afe1..1c2ed79b774 100644 --- a/ggml/include/ggml-opt.h +++ b/ggml/include/ggml-opt.h @@ -138,7 +138,7 @@ extern "C" { GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params); GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx); - // set gradients to zero, initilize loss, and optionally reset the optimizer + // set gradients to zero, initialize loss, and optionally reset the optimizer GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer); GGML_API bool ggml_opt_static_graphs(ggml_opt_context_t opt_ctx); // whether the graphs are allocated_statically diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index fcc51f1f71a..784d69206b4 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2575,7 +2575,7 @@ extern "C" { struct ggml_tensor * grad, struct ggml_tensor * sgd_params); // alpha, weight decay - // build forward mutiple tensors and select one of them for computing + // build forward multiple tensors and select one of them for computing // this is useful for creating graphs that have constant topology but compute different things based on the input // ref: https://github.com/ggml-org/llama.cpp/pull/18550 // diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index 22c656996cc..bc57df20ba2 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -1455,6 +1455,10 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s int split_backend_id = split->backend_id; ggml_backend_t split_backend = sched->backends[split_backend_id]; + if (sched->events[split_backend_id][sched->cur_copy] == NULL) { + ggml_backend_synchronize(split_backend); + } + // copy the input tensors to the split backend for (int input_id = 0; input_id < split->n_inputs; input_id++) { ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[input_id]); @@ -1465,16 +1469,12 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s // inputs from the user must be copied immediately to prevent the user overwriting the data before the copy is done if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_synchronize(sched->events[split_backend_id][sched->cur_copy]); - } else { - ggml_backend_synchronize(split_backend); } - ggml_backend_tensor_copy(input, input_cpy); + ggml_backend_tensor_copy_async(input_backend, split_backend, input, input_cpy); } else { // wait for the split backend to finish using the input before overwriting it if (sched->events[split_backend_id][sched->cur_copy] != NULL) { ggml_backend_event_wait(split_backend, sched->events[split_backend_id][sched->cur_copy]); - } else { - ggml_backend_synchronize(split_backend); } // when offloading MoE weights, we can reduce the amount of data copied by copying only the experts that are used @@ -1578,6 +1578,10 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } } + if (sched->events[split_backend_id][sched->cur_copy] == NULL) { + ggml_backend_synchronize(split_backend); + } + if (!sched->callback_eval) { enum ggml_status ec = ggml_backend_graph_compute_async(split_backend, &split->graph); if (ec != GGML_STATUS_SUCCESS) { diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index 2e9ddf2240d..5de64b816fc 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -339,8 +339,8 @@ static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t } static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - // TODO - *free = 0; + // no memory to report + *free = 0; *total = 0; GGML_UNUSED(dev); diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 3dc948e4d8e..6ca3176a2f2 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -566,9 +566,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # Fetch KleidiAI sources: include(FetchContent) - set(KLEIDIAI_COMMIT_TAG "v1.16.0") + set(KLEIDIAI_COMMIT_TAG "v1.22.0") set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") - set(KLEIDIAI_ARCHIVE_MD5 "0a9e9008adb6031f9e8cf70dff4a3321") + set(KLEIDIAI_ARCHIVE_MD5 "54049037570ab0ee0a0d126b2ba5ece1") if (POLICY CMP0135) cmake_policy(SET CMP0135 NEW) @@ -608,6 +608,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/ ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/) set(ARCH_FLAGS_TEMP "${ARCH_FLAGS}") @@ -648,7 +649,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (NOT SME_ENABLED MATCHES -1) list(APPEND GGML_KLEIDIAI_SOURCES - ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa_asm.S @@ -656,10 +656,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4vlx4_1x4vl_sme2_dot_asm.S ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_f16p_qsi4c32p/kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa_asm.S ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_f16pmrx2_f32_neon.c ${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S) - set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2") + set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2+sme2+fp16") endif() if (NOT SVE_ENABLED MATCHES -1) diff --git a/ggml/src/ggml-cpu/amx/common.h b/ggml/src/ggml-cpu/amx/common.h index f392e898518..26a6ec1a2d0 100644 --- a/ggml/src/ggml-cpu/amx/common.h +++ b/ggml/src/ggml-cpu/amx/common.h @@ -9,6 +9,8 @@ #if defined(GGML_USE_OPENMP) #include +#else +#include #endif #define TILE_M 16 @@ -56,18 +58,40 @@ inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) { } template -inline void parallel_for(int n, const func_t& f) { +inline void parallel_for(int n, const func_t & f) { + if (n <= 0) { + return; + } #if defined(GGML_USE_OPENMP) -#pragma omp parallel -{ - int nth = omp_get_num_threads(); - int ith = omp_get_thread_num(); - int tbegin, tend; - balance211(n, nth, ith, tbegin, tend); - f(tbegin, tend); -} + #pragma omp parallel + { + int nth = omp_get_num_threads(); + int ith = omp_get_thread_num(); + int tbegin, tend; + balance211(n, nth, ith, tbegin, tend); + f(tbegin, tend); + } #else - f(0, n); + int nth = std::thread::hardware_concurrency(); + if (nth <= 1) { + f(0, n); + return; + } + if (nth > n) { + nth = n; + } + std::vector threads; + threads.reserve(nth); + for (int ith = 0; ith < nth; ++ith) { + threads.emplace_back([&f, n, ith, nth] { + int tbegin, tend; + balance211(n, nth, ith, tbegin, tend); + f(tbegin, tend); + }); + } + for (auto & t : threads) { + t.join(); + } #endif } diff --git a/ggml/src/ggml-cpu/amx/mmq.cpp b/ggml/src/ggml-cpu/amx/mmq.cpp index b5aca76633c..93a6d397f79 100644 --- a/ggml/src/ggml-cpu/amx/mmq.cpp +++ b/ggml/src/ggml-cpu/amx/mmq.cpp @@ -195,7 +195,7 @@ struct tile_config_t{ // will be needed. // // Here another commonly used pattern 1-3-3 is skipped, as it is mostly used when m <=16; -// and the sinlge batch gemm (m=1) has a special fast path with `avx512-vnni`. +// and the single batch gemm (m=1) has a special fast path with `avx512-vnni`. // // ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/ // advanced-matrix-extensions-intrinsics-functions.html @@ -1379,8 +1379,8 @@ struct tinygemm_kernel_vnni 4 #if _WIN32_WINNT >= 0x0602 diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp index d114f2d49bf..40f7c0df650 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -1,4 +1,4 @@ -// SPDX-FileCopyrightText: Copyright 2025 Arm Limited and/or its affiliates +// SPDX-FileCopyrightText: Copyright 2025-2026 Arm Limited and/or its affiliates // SPDX-License-Identifier: MIT // @@ -9,7 +9,6 @@ #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod.h" #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm.h" -#include "kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.h" #include "kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.h" #include "kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.h" #include "kai_matmul_clamp_f32_qai8dxp1vlx4_qsi8cxp4vlx4_1vlx4vl_sme2_mopa.h" @@ -20,6 +19,7 @@ #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm.h" #include "kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p8x8_16x8_sve_i8mm.h" #include "kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p8x8_1x8_sve_dotprod.h" +#include "kai_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa.h" #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h" #include "kai_lhs_quant_pack_qsi8d32p_f32.h" @@ -31,6 +31,7 @@ #include "kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.h" #include "kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.h" #include "kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon.h" +#include "kai_lhs_pack_f16pmrx2_f32_neon.h" #include "kai_common.h" @@ -309,24 +310,24 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { { /* SME GEMM */ /* .kern_info = */ { - /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_mr = */ kai_get_mr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_nr = */ kai_get_nr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_kr = */ kai_get_kr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_sr = */ kai_get_sr_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, - /* .get_lhs_offset_ex = */ &kernel_offs_fn3, - /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, - /* .run_kernel_ex = */ &kernel_run_fn11, + /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_n_step = */ kai_get_n_step_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_mr = */ kai_get_mr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_nr = */ kai_get_nr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_kr = */ kai_get_kr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_sr = */ kai_get_sr_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_dst_offset = */ kai_get_dst_offset_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_f16p1vlx2_qsi4c32p4vlx2_1vlx4vl_sme2_mopa, + /* .get_lhs_offset_ex = */ &kernel_offs_fn3, + /* .get_rhs_packed_offset_ex = */ &kernel_offs_fn3, + /* .run_kernel_ex = */ &kernel_run_fn11, }, /* .gemm_lhs_info = */ { - /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon, - /* .get_packed_offset_ex = */ &lhs_offs_fn6, - /* .packed_size_ex = */ &lhs_ps_fn6, - /* .pack_func_ex = */ &lhs_pack_float_fn10, + /* .get_offset = */ kai_get_lhs_offset_lhs_pack_f16pmrx2_f32_neon, + /* .get_packed_offset_ex = */ &lhs_offs_fn6, + /* .packed_size_ex = */ &lhs_ps_fn6, + /* .pack_func_ex = */ &lhs_pack_void_fn10, }, /* SME GEMV */ /* .kern_info = */ { diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index da412fd009b..c89e5076f26 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -533,7 +533,7 @@ class tinyBLAS { if constexpr (RN > 1) { return mnpack(m, n, SIZE_N, BN); } else { - GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N); + GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N); GGML_ASSERT(false); // we have miss something. } } @@ -711,7 +711,7 @@ class tinyBLAS_RVV { if constexpr (RN > 1) { return mnpack(m, n, SIZE_N, BN); } else { - GGML_LOG_ERROR("mnpack<%d, %d> bloc size not supported\n", RM, (int)SIZE_N); + GGML_LOG_ERROR("mnpack<%d, %d> block size not supported\n", RM, (int)SIZE_N); GGML_ASSERT(false); // we have miss something. } } @@ -2497,7 +2497,7 @@ class tinyBLAS_Q0_PPC { for (int r = 0; r < 8; r++) { const block_q4_0 * current_blk = rows_base[r] + blk; vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d)); - vector signed char v_qs = reinterpret_cast(vec_xl(0, current_blk->qs)); + vector signed char v_qs = vec_xl(0, (const vector signed char *)current_blk->qs); vector signed char c1, c2; unpack_q4_to_q8(v_qs, c1, c2); convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]); @@ -2611,14 +2611,14 @@ class tinyBLAS_Q0_PPC { i = (cols >> 2); if (i > 0) { do { - c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); - c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); - c5[1] = reinterpret_cast(vec_xl(0, aoffset5->qs)); - c6[1] = reinterpret_cast(vec_xl(0, aoffset6->qs)); - c7[1] = reinterpret_cast(vec_xl(0, aoffset7->qs)); - c8[1] = reinterpret_cast(vec_xl(0, aoffset8->qs)); + c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs); + c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs); + c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs); + c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs); + c5[1] = vec_xl(0, (const vector signed char *)aoffset5->qs); + c6[1] = vec_xl(0, (const vector signed char *)aoffset6->qs); + c7[1] = vec_xl(0, (const vector signed char *)aoffset7->qs); + c8[1] = vec_xl(0, (const vector signed char *)aoffset8->qs); process_q4_elements(c1, & comparray[0]); process_q4_elements(c2, & comparray[1]); @@ -2657,10 +2657,10 @@ class tinyBLAS_Q0_PPC { i = (cols >> 2); if (i > 0) { do { - c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); - c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); + c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs); + c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs); + c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs); + c4[1] = vec_xl(0, (const vector signed char *)aoffset4->qs); process_q4_elements(c1, & comparray[0]); process_q4_elements(c2, & comparray[1]); @@ -2686,9 +2686,9 @@ class tinyBLAS_Q0_PPC { if (i > 0) { do { switch(rows) { - case 3: c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); - case 2: c2[1] = reinterpret_cast(vec_xl(0, aoffset2->qs)); - case 1: c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); + case 3: c3[1] = vec_xl(0, (const vector signed char *)aoffset3->qs); + case 2: c2[1] = vec_xl(0, (const vector signed char *)aoffset2->qs); + case 1: c1[1] = vec_xl(0, (const vector signed char *)aoffset1->qs); break; } process_q4_elements(c1, & comparray[0]); diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index b7a70e06f1d..2c372f9635b 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -375,7 +375,7 @@ static void ggml_compute_forward_dup_bytes( const size_t rs = ne00 * type_size; if (nb00 == type_size) { - // src0 is contigous on first dimension, copy by rows + // src0 is contiguous on first dimension, copy by rows for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { id += rs * ir0; @@ -1795,7 +1795,7 @@ void ggml_compute_forward_repeat( { ggml_compute_forward_repeat_f32(params, dst); } break; - // TODO: templateify the implemenation and support for I64 + // TODO: templateify the implementation and support for I64 // ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225 //case GGML_TYPE_I64: // { @@ -2129,12 +2129,12 @@ static void ggml_compute_forward_gelu_f32( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k]; GGML_UNUSED(x); assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2176,13 +2176,13 @@ static void ggml_compute_forward_gelu_f16( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k]; const float v = GGML_CPU_FP16_TO_FP32(x); GGML_UNUSED(v); assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -2325,12 +2325,12 @@ static void ggml_compute_forward_gelu_erf_f32( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k]; GGML_UNUSED(x); assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2372,13 +2372,13 @@ static void ggml_compute_forward_gelu_erf_f16( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k]; const float v = GGML_CPU_FP16_TO_FP32(x); GGML_UNUSED(v); assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -2444,12 +2444,12 @@ static void ggml_compute_forward_gelu_quick_f32( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k]; GGML_UNUSED(x); assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2491,13 +2491,13 @@ static void ggml_compute_forward_gelu_quick_f16( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k]; const float v = GGML_CPU_FP16_TO_FP32(x); GGML_UNUSED(v); assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -2563,12 +2563,12 @@ static void ggml_compute_forward_silu_f32( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k]; + const float x = ((float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*(dst->nb[1])))[k]; GGML_UNUSED(x); assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2610,13 +2610,13 @@ static void ggml_compute_forward_silu_f16( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])))[k]; + const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*( dst->nb[1])))[k]; const float v = GGML_CPU_FP16_TO_FP32(x); GGML_UNUSED(v); assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -2766,7 +2766,7 @@ static void ggml_compute_forward_silu_back_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2802,7 +2802,7 @@ static void ggml_compute_forward_silu_back_f16( (ggml_fp16_t *) ((char *) src1->data + i1*(src1->nb[1])), (ggml_fp16_t *) ((char *) grad->data + i1*(grad->nb[1]))); - #ifndef NDEBUG +#ifndef NDEBUG for (int k = 0; k < nc; k++) { const float x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k]; const float v = GGML_CPU_FP16_TO_FP32(x); @@ -2810,7 +2810,7 @@ static void ggml_compute_forward_silu_back_f16( assert(!isnan(v)); assert(!isinf(v)); } - #endif +#endif // NDEBUG } } @@ -2893,7 +2893,7 @@ static void ggml_compute_forward_reglu_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -2953,7 +2953,7 @@ static void ggml_compute_forward_reglu_f16( assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -3036,7 +3036,7 @@ static void ggml_compute_forward_geglu_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -3096,7 +3096,7 @@ static void ggml_compute_forward_geglu_f16( assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -3179,7 +3179,7 @@ static void ggml_compute_forward_swiglu_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -3239,7 +3239,7 @@ static void ggml_compute_forward_swiglu_f16( assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -3330,7 +3330,7 @@ static void ggml_compute_forward_swiglu_oai_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -3409,7 +3409,7 @@ static void ggml_compute_forward_geglu_erf_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -3469,7 +3469,7 @@ static void ggml_compute_forward_geglu_erf_f16( assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -3552,7 +3552,7 @@ static void ggml_compute_forward_geglu_quick_f32( assert(!isnan(x)); assert(!isinf(x)); } -#endif +#endif // NDEBUG } } @@ -3612,7 +3612,7 @@ static void ggml_compute_forward_geglu_quick_f16( assert(!isnan(v)); assert(!isinf(v)); } -#endif +#endif // NDEBUG } } @@ -5303,7 +5303,7 @@ static void ggml_compute_forward_soft_max_f32( //printf("p[%d] = %f\n", i, p[i]); assert(!isnan(wp[i])); } -#endif +#endif // NDEBUG float max = -INFINITY; ggml_vec_max_f32(ne00, &max, wp); @@ -5328,7 +5328,7 @@ static void ggml_compute_forward_soft_max_f32( assert(!isnan(dp[i])); assert(!isinf(dp[i])); } -#endif +#endif // NDEBUG } } } @@ -5402,7 +5402,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32( assert(!isnan(dy[i])); assert(!isnan(y[i])); } -#endif +#endif // NDEBUG // Jii = yi - yi*yi // Jij = -yi*yj // J = diag(y)-y.T*y @@ -5435,7 +5435,7 @@ static void ggml_compute_forward_soft_max_ext_back_f32( assert(!isnan(dx[i])); assert(!isinf(dx[i])); } -#endif +#endif // NDEBUG } } @@ -5803,28 +5803,33 @@ static void ggml_compute_forward_rope_flt( const int32_t * pos = (const int32_t *) src1->data; + int64_t last_i2 = -1; + for (int64_t i3 = 0; i3 < ne3; i3++) { // batch for (int64_t i2 = 0; i2 < ne2; i2++) { // seq-len - - float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; - if (!mrope_used) { - const int64_t p = pos[i2]; - ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - } - else { - const int64_t p_t = pos[i2]; - const int64_t p_h = pos[i2 + ne2]; - const int64_t p_w = pos[i2 + ne2 * 2]; - const int64_t p_e = pos[i2 + ne2 * 3]; - ggml_mrope_cache_init( - p_t, p_h, p_w, p_e, sections, is_imrope, is_vision, - freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); - } - for (int64_t i1 = 0; i1 < ne1; i1++) { // attn-heads - if (ir++ < ir0) continue; + if (ir++ < ir0) continue; // skip rows mapped to other threads if (ir > ir1) break; + float * cache = (float *) params->wdata + (ne0 + CACHE_LINE_SIZE_F32)*ith; + if (last_i2 != i2) { + if (!mrope_used) { + const int64_t p = pos[i2]; + ggml_rope_cache_init(p, freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } + else { + const int64_t p_t = pos[i2]; + const int64_t p_h = pos[i2 + ne2]; + const int64_t p_w = pos[i2 + ne2 * 2]; + const int64_t p_e = pos[i2 + ne2 * 3]; + ggml_mrope_cache_init( + p_t, p_h, p_w, p_e, sections, is_imrope, is_vision, + freq_scale, freq_factors, corr_dims, ne0, ext_factor, attn_factor, cache, sin_sign, theta_scale); + } + + last_i2 = i2; + } + T * src = (T *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); T * dst_data = (T *)((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1); @@ -10700,7 +10705,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32( assert(!isnan(s0[i])); assert(!isnan(s1[i])); } -#endif +#endif // NDEBUG float max = -INFINITY; ggml_vec_max_f32(nc, &max, s0); @@ -10719,7 +10724,7 @@ static void ggml_compute_forward_cross_entropy_loss_f32( assert(!isnan(st[i])); assert(!isinf(st[i])); } -#endif +#endif // NDEBUG } sums[ith] = sum_thread; ggml_barrier(params->threadpool); @@ -10792,7 +10797,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( assert(!isnan(s0[i])); assert(!isnan(s1[i])); } -#endif +#endif // NDEBUG // soft_max float max = -INFINITY; @@ -10810,7 +10815,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( assert(!isnan(ds0[i])); assert(!isinf(ds0[i])); } -#endif +#endif // NDEBUG } } diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 5edba4212f6..02c3cc3119b 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -3032,7 +3032,7 @@ template src[1])); - size = GGML_PAD(size, sizeof(int64_t)); // + padding for next bloc. + size = GGML_PAD(size, sizeof(int64_t)); // + padding for next block. const int64_t ne02 = op->src[0]->ne[2]; // n_as, n_expert const int64_t ne12 = op->src[1]->ne[2]; // n_tokens @@ -3297,7 +3297,7 @@ template wdata; auto * wdata_src1_end = (char *)wdata + GGML_PAD(nbw3, sizeof(int64_t)); - // total of [n_as][ne12 + 1] elemets of type mmid_row_mapping (2*int32_t = int64_t) + // total of [n_as][ne12 + 1] elements of type mmid_row_mapping (2*int32_t = int64_t) auto * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as] struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *) (matrix_row_counts + n_as); // [n_as][ne12] diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 09b6d5db6a0..b70492c7d6c 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -16,27 +16,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __ return; } - const int64_t i01 = blockIdx.y; - - for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { - const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); - const int64_t i02 = dm.y; - const int64_t i03 = dm.x; - - const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; - - const int64_t ib = ibx0 + i00/qk; // block index - const int64_t iqs = (i00%qk)/qr; // quant index - const int64_t iybs = i00 - i00%qk; // y block start index - const int64_t y_offset = qr == 1 ? 1 : qk/2; - - // dequantize - float2 v; - dequantize_kernel(vx, ib, iqs, v); - - const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs; - y[iy0 + 0] = ggml_cuda_cast(v.x); - y[iy0 + y_offset] = ggml_cuda_cast(v.y); + for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) { + for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { + const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); + const int64_t i02 = dm.y; + const int64_t i03 = dm.x; + + const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; + + const int64_t ib = ibx0 + i00/qk; // block index + const int64_t iqs = (i00%qk)/qr; // quant index + const int64_t iybs = i00 - i00%qk; // y block start index + const int64_t y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + float2 v; + dequantize_kernel(vx, ib, iqs, v); + + const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs; + y[iy0 + 0] = ggml_cuda_cast(v.x); + y[iy0 + y_offset] = ggml_cuda_cast(v.y); + } } } @@ -492,7 +492,7 @@ static void dequantize_block_cuda(const void * vx, dst_t * y, const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { const int64_t ne0203 = ne02*ne03; const uint3 ne02_fdv = init_fastdiv_values(ne02); - const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535)); + const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535)); dequantize_block<<>> (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03); } @@ -628,18 +628,18 @@ static __global__ void convert_unary( return; } - const int64_t i01 = blockIdx.y; - const src_t * x = (const src_t *) vx; - for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { - const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); - const int64_t i02 = dm.y; - const int64_t i03 = dm.x; + for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) { + for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { + const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); + const int64_t i02 = dm.y; + const int64_t i03 = dm.x; - const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; - const int64_t iy = (i0203*ne01 + i01)*ne00 + i00; - y[iy] = ggml_cuda_cast(x[ix]); + const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; + const int64_t iy = (i0203*ne01 + i01)*ne00 + i00; + y[iy] = ggml_cuda_cast(x[ix]); + } } } @@ -649,7 +649,7 @@ static void convert_unary_cuda(const void * vx, dst_t * y, const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { const int64_t ne0203 = ne02*ne03; const uint3 ne02_fdv = init_fastdiv_values(ne02); - const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535)); + const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535)); convert_unary<<>> (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index beb7e32e4fc..fff70c8eb89 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -1215,7 +1215,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } // If attention sinks are used, potentially re-scale if KQ_max is small. - // Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum + // Also add the sink as a value to KQ_rowsum, this is done after synchronization of KQ_rowsum // so it's being done unconditionally for every thread. if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) { float KQ_max_scale[cols_per_thread]; diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index 3f4a78cc6e5..7cbe32633e5 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -10,7 +10,7 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() { return 128; } -// Currenlty llvm with the amdgcn target does not support unrolling loops +// Currently llvm with the amdgcn target does not support unrolling loops // that contain a break that can not be resolved at compile time. #ifdef __clang__ #pragma clang diagnostic push diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index cd3bfd4051a..aaf711a618c 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -18,7 +18,7 @@ #if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1 #define GGML_USE_WMMA_FATTN #elif defined(RDNA4) -#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance" +#warning "rocwmma fattn is not supported on RDNA4 on rocwmma < v2.0.0, expect degraded performance" #endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1 #endif // defined(GGML_HIP_ROCWMMA_FATTN) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 7e6d3303549..54dc43bc088 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2803,11 +2803,14 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ ggml_backend_buffer_t buf_src = src->view_src ? src->view_src->buffer : src->buffer; ggml_backend_buffer_t buf_dst = dst->view_src ? dst->view_src->buffer : dst->buffer; - if (!ggml_backend_is_cuda(backend_src) || !ggml_backend_is_cuda(backend_dst)) { + //enables async copies from CPU to CUDA, instead of only CUDA-to-CUDA + bool copy_from_host = ggml_backend_buffer_is_host(buf_src) && ggml_backend_dev_type(backend_src->device) == GGML_BACKEND_DEVICE_TYPE_CPU; + + if (!(copy_from_host || ggml_backend_is_cuda(backend_src)) || !ggml_backend_is_cuda(backend_dst)) { return false; } - if (!ggml_backend_buffer_is_cuda(src->buffer) || !ggml_backend_buffer_is_cuda(dst->buffer)) { + if (!(copy_from_host || ggml_backend_buffer_is_cuda(buf_src)) || !ggml_backend_buffer_is_cuda(dst->buffer)) { return false; } @@ -2818,14 +2821,17 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_ ggml_backend_cuda_buffer_context * buf_ctx_src = (ggml_backend_cuda_buffer_context *)buf_src->context; ggml_backend_cuda_buffer_context * buf_ctx_dst = (ggml_backend_cuda_buffer_context *)buf_dst->context; - if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) { + if ((copy_from_host && cuda_ctx_dst->device != buf_ctx_dst->device) || + !copy_from_host && (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device)) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__); #endif return false; } - if (backend_src != backend_dst) { + if (copy_from_host) { + CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyHostToDevice, cuda_ctx_dst->stream())); + } else if (backend_src != backend_dst) { // copy on src stream if (cuda_ctx_src->device == cuda_ctx_dst->device) { CUDA_CHECK(cudaMemcpyAsync(dst->data, src->data, ggml_nbytes(dst), cudaMemcpyDeviceToDevice, cuda_ctx_src->stream())); @@ -3330,7 +3336,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return false; } - //rms_norm kernel assumes contigous rows + //rms_norm kernel assumes contiguous rows if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { return false; } @@ -3342,6 +3348,46 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return true; } + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_UNARY + && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) { + const ggml_tensor * ssm_conv = cgraph->nodes[node_idx]; + const ggml_tensor * silu = cgraph->nodes[node_idx+1]; + + if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { + return false; + } + + return true; + } + + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL + && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) { + const ggml_tensor * unary = cgraph->nodes[node_idx]; + const ggml_tensor * mul = cgraph->nodes[node_idx+1]; + + if (ggml_get_unary_op(unary) != unary_ops.begin()[0]) { + return false; + } + + if (unary->type != GGML_TYPE_F32 && unary->type != GGML_TYPE_F16) { + return false; + } + + if (unary->type != mul->type) { + return false; + } + + const ggml_tensor * other = (mul->src[0] == unary) ? mul->src[1] : mul->src[0]; + if (other->type != unary->type) { + return false; + } + if (!ggml_is_contiguous_1(other) || !ggml_is_contiguous_1(unary->src[0]) || !ggml_are_same_shape(other, unary)) { + return false; + } + + return true; + } + if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SCALE && ops.begin()[1] == GGML_OP_UNARY && ops.begin()[2] == GGML_OP_SCALE && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_TANH) { const ggml_tensor *scale = cgraph->nodes[node_idx]; @@ -3366,6 +3412,69 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return false; } +// returns whether the write (out) nodes overwrite the read nodes in operation +static bool ggml_cuda_check_fusion_memory_ranges(ggml_cgraph * cgraph, + int node_idx, + int node_count, + int * out_nodes, + int out_count) { + auto nodes_overlap = [&](const ggml_tensor * a, const ggml_tensor * b) { + const int64_t a_start = (int64_t) a->data; + const int64_t a_end = a_start + ggml_nbytes(a); + + const int64_t b_start = (int64_t) b->data; + const int64_t b_end = b_start + ggml_nbytes(b); + + if ((b_start <= a_start && a_start < b_end) || (a_start <= b_start && b_start < a_end)) { + return true; + } + + return false; + }; + + bool is_ok = true; + // for nrows=1, all fusion operations correctly read the src before writing dst or do it elementwise, so we should be ok + if (ggml_nrows(cgraph->nodes[node_idx]) == 1) { + return true; + } + + for (int i = 0; i < out_count; ++i) { + const ggml_tensor * dst = cgraph->nodes[out_nodes[i]]; + + for (int j = node_idx; j < node_idx + node_count; ++j) { + // Loop over all srcs of all nodes in the fusion. If the src overlaps + // the destination and the src is not an intermediate node that's being + // elided, then disable fusion. + + for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { + const ggml_tensor * src = cgraph->nodes[j]->src[src_idx]; + + if (!src || src->op == GGML_OP_NONE) { + continue; + } + + if (nodes_overlap(dst, src)) { + bool found = false; + + for (int k = node_idx; k < j; ++k) { + if (cgraph->nodes[k] == src) { + found = true; + break; + } + } + + if (!found) { + is_ok = false; + break; + } + } + } + } + } + + return is_ok; +} + static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required, const void * graph_key) { bool graph_evaluated_or_captured = false; @@ -3562,7 +3671,8 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud out_nodes[1] = i + ops.size() - 1; if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && - ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) { + ggml_cuda_should_use_topk_moe(node, logits, weights, ids) && + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) { ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); i += ops.size() - 1; continue; @@ -3577,7 +3687,8 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud int out_nodes[2] = { i + 1, i + 5 }; if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) && - ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) { + ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids) && + ggml_cuda_check_fusion_memory_ranges(cgraph, i, ops.size(), out_nodes, 2)) { ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args); i += ops.size() - 1; continue; @@ -3830,6 +3941,20 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud continue; } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { + ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i+1]); + i++; + continue; + } + + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SILU }) || + ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SIGMOID }) || + ggml_cuda_can_fuse(cgraph, i, { GGML_OP_UNARY, GGML_OP_MUL }, { GGML_UNARY_OP_SOFTPLUS })) { + ggml_cuda_op_unary_mul(*cuda_ctx, node, cgraph->nodes[i+1]); + i++; + continue; + } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { i += 2; ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node); diff --git a/ggml/src/ggml-cuda/quantize.cu b/ggml/src/ggml-cuda/quantize.cu index a8c68e44b16..4300ffc148c 100644 --- a/ggml/src/ggml-cuda/quantize.cu +++ b/ggml/src/ggml-cuda/quantize.cu @@ -235,7 +235,7 @@ static __global__ void quantize_mmq_q8_1( q.z = roundf(xi.z*d_inv); q.w = roundf(xi.w*d_inv); - // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth: + // Write back 4 int8 values as a single 32 bit value for better memory bandwidth: char4 * yqs4 = (char4 *) y[ib].qs; yqs4[iqs/4] = q; diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index dc06d06930e..285c0e9543a 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -46,7 +46,7 @@ struct soft_max_params { }; // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. -// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. +// As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here. #ifdef __clang__ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wpass-failed" diff --git a/ggml/src/ggml-cuda/solve_tri.cu b/ggml/src/ggml-cuda/solve_tri.cu index 177ffc268f1..07ca33f513b 100644 --- a/ggml/src/ggml-cuda/solve_tri.cu +++ b/ggml/src/ggml-cuda/solve_tri.cu @@ -83,7 +83,7 @@ static void solve_tri_f32_cublas(ggml_backend_cuda_context & ctx, // ====================== // When ncols_template == 0 the bounds for the loops in this function are not // known and can't be unrolled. As we want to keep pragma unroll for all other -// cases we supress the clang transformation warning here. +// cases we suppress the clang transformation warning here. #ifdef __clang__ # pragma clang diagnostic push # pragma clang diagnostic ignored "-Wpass-failed" diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 6d5ea704c65..85e82b5a422 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -1,6 +1,7 @@ #include "ssm-conv.cuh" +#include "unary.cuh" -template +template static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, @@ -41,11 +42,11 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float for (size_t j = 0; j < d_conv; j++) { sumf += x[(i + j) % d_conv] * w[j]; } - y_block[i * stride_y + tid] = sumf; + y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; } } -template +template static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, @@ -65,36 +66,46 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const int stride_w = src1_nb1 / sizeof(float); const int stride_y = dst_nb1 / sizeof(float); - float x[d_conv] = { 0.0f }; - float w[d_conv] = { 0.0f }; + const int64_t local_n_t = min(split_n_t, n_t - bidz * split_n_t); + const int n_cols = d_conv - 1 + split_n_t; + + extern __shared__ float smem[]; + constexpr int load_cols = d_conv - 1 + split_n_t; + constexpr int total_elems = split_d_inner * load_cols; + int row = tid / load_cols; + int col = tid % load_cols; #pragma unroll - for (size_t j = 0; j < d_conv; j++) { - w[j] = w_block[tid * stride_w + j]; + for (int idx = tid; idx < total_elems; idx += split_d_inner) { + if (row < (int)split_d_inner) { + smem[row * n_cols + col] = x_block[row * stride_x + col]; + } + + col += split_d_inner; + row += col / load_cols; + col = col % load_cols; } + __syncthreads(); + // Load weights into registers (done once, small) + float w[d_conv] = { 0.0f }; #pragma unroll - for (int64_t i = 0; i < split_n_t; i++) { - if (bidz * split_n_t + i < n_t) { - float sumf = 0.0f; - - if (i == 0) { - for (size_t j = 0; j < d_conv; j++) { - x[j] = x_block[tid * stride_x + j]; - } - } else { - x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1]; - } + for (size_t j = 0; j < d_conv; j++) { + w[j] = w_block[tid * stride_w + j]; + } + // Compute from shared memory + for (int64_t i = 0; i < local_n_t; i++) { + float sumf = 0.0f; #pragma unroll - for (size_t j = 0; j < d_conv; j++) { - sumf += x[(i + j) % d_conv] * w[j]; - } - y_block[i * stride_y + tid] = sumf; + for (size_t j = 0; j < d_conv; j++) { + sumf += smem[tid * n_cols + i + j] * w[j]; } + y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; } } +template static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t, @@ -106,12 +117,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int constexpr int kNC = decltype(NC)::value; if (n_t <= 32) { const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); - ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, + ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { const int64_t split_n_t = 32; dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); - ssm_conv_long_token_f32<<>>( + const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float); + ssm_conv_long_token_f32<<>>( src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } }; @@ -124,27 +136,36 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int } } -void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) { const struct ggml_tensor * src0 = dst->src[0]; // conv_x const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight + const bool fuse_silu = silu_dst != nullptr; + + // When fusing, write to silu_dst (the node downstream references). + const struct ggml_tensor * out = fuse_silu ? silu_dst : dst; const int64_t nc = src1->ne[0]; // d_conv const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = dst->ne[1]; // tokens per sequence - const int64_t n_s = dst->ne[2]; // number of sequences in the batch + const int64_t n_t = out->ne[1]; // tokens per sequence + const int64_t n_s = out->ne[2]; // number of sequences in the batch - GGML_ASSERT(dst->ne[0] == nr); + GGML_ASSERT(out->ne[0] == nr); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float)); const float * src0_d = (const float *) src0->data; const float * src1_d = (const float *) src1->data; - float * dst_d = (float *) dst->data; + float * dst_d = (float *) out->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(dst->type == GGML_TYPE_F32); - ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1], - dst->nb[2], nc, nr, n_t, n_s, stream); + GGML_ASSERT(out->type == GGML_TYPE_F32); + if (fuse_silu) { + ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], + out->nb[2], nc, nr, n_t, n_s, stream); + } else { + ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], + out->nb[2], nc, nr, n_t, n_s, stream); + } } diff --git a/ggml/src/ggml-cuda/ssm-conv.cuh b/ggml/src/ggml-cuda/ssm-conv.cuh index 8e6c1f00bfa..f96a1cd2484 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cuh +++ b/ggml/src/ggml-cuda/ssm-conv.cuh @@ -1,3 +1,3 @@ #include "common.cuh" -void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr); diff --git a/ggml/src/ggml-cuda/topk-moe.cu b/ggml/src/ggml-cuda/topk-moe.cu index 08a88990dde..3020e5c7433 100644 --- a/ggml/src/ggml-cuda/topk-moe.cu +++ b/ggml/src/ggml-cuda/topk-moe.cu @@ -119,6 +119,18 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * } } + // Sanitize NaN to -FLT_MAX so the iterative argmax produces unique expert IDs. + // NaN comparisons always return false, which would cause the same expert to be + // selected repeatedly. -FLT_MAX compares normally and is still excluded by the + // -INFINITY sentinel used after each selection round. + // More relevant for the cuBLAS path. See https://github.com/ggml-org/llama.cpp/issues/19659 +#pragma unroll + for (int i = 0; i < experts_per_thread; i++) { + if (__isnanf(wt[i])) { + wt[i] = -FLT_MAX; + } + } + // selection_wt is only needed when bias is present (selection uses wt + bias) // when no bias, we use wt directly for both selection and weight values float selection_wt[has_bias ? experts_per_thread : 1]; diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index d4866067a4f..4ad30fa1f35 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -560,3 +560,58 @@ void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) leaky_relu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), negative_slope, stream); } } + +/* fused unary + mul */ + +template +static void ggml_cuda_op_unary_mul_impl(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) { + // unary_node: UNARY op applied to unary_node->src[0] + // mul_node: MUL(a, b) where one of a/b is unary_node + // Output goes to mul_node->data + + const ggml_tensor * unary_src = unary_node->src[0]; // input to the unary op + const ggml_tensor * other_src = (mul_node->src[0] == unary_node) ? mul_node->src[1] : mul_node->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(unary_src)); + GGML_ASSERT(unary_src->nb[0] == ggml_element_size(unary_src)); + GGML_ASSERT(ggml_is_contiguous_1(other_src)); + GGML_ASSERT(other_src->nb[0] == ggml_element_size(other_src)); + GGML_ASSERT(ggml_are_same_shape(unary_src, other_src)); + + GGML_ASSERT(unary_src->type == GGML_TYPE_F32 || unary_src->type == GGML_TYPE_F16); + GGML_ASSERT(unary_src->type == other_src->type); + GGML_ASSERT(unary_src->type == mul_node->type); + + cudaStream_t stream = ctx.stream(); + + const int64_t k = ggml_nelements(mul_node); + const int64_t nc = unary_src->ne[0]; + const int64_t unary_stride = unary_src->nb[1]; + const int64_t other_stride = other_src->nb[1]; + + if (unary_src->type == GGML_TYPE_F16) { + unary_gated_cuda((const half *) unary_src->data, (const half *) other_src->data, + (half *) mul_node->data, k, nc, + unary_stride / sizeof(half), other_stride / sizeof(half), stream); + } else { + unary_gated_cuda((const float *) unary_src->data, (const float *) other_src->data, + (float *) mul_node->data, k, nc, + unary_stride / sizeof(float), other_stride / sizeof(float), stream); + } +} + +void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node) { + switch (ggml_get_unary_op(unary_node)) { + case GGML_UNARY_OP_SILU: + ggml_cuda_op_unary_mul_impl(ctx, unary_node, mul_node); + break; + case GGML_UNARY_OP_SIGMOID: + ggml_cuda_op_unary_mul_impl(ctx, unary_node, mul_node); + break; + case GGML_UNARY_OP_SOFTPLUS: + ggml_cuda_op_unary_mul_impl(ctx, unary_node, mul_node); + break; + default: + GGML_ABORT("Unsupported unary op for fused unary+mul"); + } +} diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index 609046e5694..f1dd2183a6c 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -89,6 +89,8 @@ void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_unary_mul(ggml_backend_cuda_context & ctx, ggml_tensor * unary_node, ggml_tensor * mul_node); + __device__ __forceinline__ float ggml_cuda_op_silu_single(float x) { return x / (1.0f + expf(-x)); } diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 7a44443a8a3..d6e9776b878 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -139,7 +139,7 @@ struct ggml_hexagon_session { }; void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_buffer *bufs, uint32_t n_bufs, bool sync) { - // Bump pending flag (cleared in the session::flush once we get the responce) + // Bump pending flag (cleared in the session::flush once we get the response) this->op_pending++; // atomic inc int err = dspqueue_write(this->queue, @@ -443,7 +443,7 @@ static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Repack the scales ggml_half * d = (ggml_half *) (y_d + i * dblk_size); @@ -503,7 +503,7 @@ static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size); @@ -552,7 +552,7 @@ static void init_row_q4x4x2(block_q4_0 * x, int64_t k) { // Init the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales x[i * 8 + 0].d = 0; @@ -770,7 +770,7 @@ static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) { // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Repack the scales ggml_half * d = (ggml_half *) (y_d + i * dblk_size); @@ -829,7 +829,7 @@ static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) { // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size); @@ -878,7 +878,7 @@ static void init_row_q8x4x2(block_q8_0 * x, int64_t k) { // Init the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q8_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales x[i * 8 + 0].d = 0; @@ -1120,7 +1120,7 @@ static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Repack the scales uint8_t * e = (uint8_t *) (y_e + i * eblk_size); @@ -1180,7 +1180,7 @@ static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) // Repack the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4_0x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales const uint8_t * e = (const uint8_t *) (y_e + i * eblk_size); @@ -1229,7 +1229,7 @@ static void init_row_mxfp4x4x2(block_mxfp4 * x, int64_t k) { // Init the scales // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2) - // the last block is truncated and overriden by the scales. + // the last block is truncated and overridden by the scales. for (int i = 0; i < nb; i++) { // Unpack the scales x[i * 8 + 0].e = 0; @@ -1865,15 +1865,26 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se const struct ggml_tensor * src1 = op->src[1]; const struct ggml_tensor * dst = op; - if (src0->type != GGML_TYPE_F32) { - return false; + if (src0->type == GGML_TYPE_F32) { + if (src1->type != GGML_TYPE_F32) { + return false; + } + if (dst->type != GGML_TYPE_F32) { + return false; + } } - if (src1->type != GGML_TYPE_F32) { - return false; + else if (src0->type == GGML_TYPE_F16) { + if (src1->type != GGML_TYPE_F16) { + return false; + } + if (dst->type != GGML_TYPE_F16) { + return false; + } } - if (dst->type != GGML_TYPE_F32) { + else { return false; } + if (!ggml_are_same_shape(src0, dst)) { return false; } @@ -2141,6 +2152,44 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess return true; } +static bool ggml_hexagon_supported_ssm_conv(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * dst = op; + + // Only support FP32 for now + if (src0->type != GGML_TYPE_F32 || src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + return false; + } + + // Check IO tensor shapes and dims + if (src0->ne[3] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1 || dst->ne[3] != 1) { + return false; // src0 should be effectively 3D + } + + const int d_conv = src1->ne[0]; + const int d_inner = src0->ne[1]; + const int n_t = dst->ne[1]; + const int n_s = dst->ne[2]; + + if (src0->ne[0] != d_conv - 1 + n_t || src0->ne[1] != d_inner || src0->ne[2] != n_s) { + return false; + } + if (src1->ne[0] != d_conv || src1->ne[1] != d_inner) { + return false; + } + if (dst->ne[0] != d_inner || dst->ne[1] != n_t || dst->ne[2] != n_s) { + return false; + } + + // TODO: add support for non-contiguous tensors + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + return false; + } + + return true; +} + enum dspqbuf_type { DSPQBUF_TYPE_DSP_WRITE_CPU_READ = 0, DSPQBUF_TYPE_CPU_WRITE_DSP_READ, @@ -2457,6 +2506,17 @@ static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buf return n_bufs; } +static inline size_t init_ssm_conv_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_SSM_CONV; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CONSTANT); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->name.c_str(); @@ -2595,6 +2655,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_SSM_CONV: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); } @@ -2670,7 +2734,7 @@ static std::vector ggml_hexagon_graph_optimize_reorder(const std::vectorn_threads; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); size_t src0_row_size = src0->nb[1]; size_t src1_row_size = src1->nb[1]; // zero bytes if src1 is not used @@ -748,13 +748,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - uint32_t n_jobs = MIN(n_threads, src0_nrows); - // Prepare context struct htp_act_context actx; actx.octx = octx; - actx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + actx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; actx.src0_row_size = src0_row_size; actx.src1_row_size = src1_row_size; @@ -794,7 +792,7 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { actx.data_src1 = data_src1; actx.data_dst = (uint8_t *) dst->data; - worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c index a4cee980be8..170220e8f80 100644 --- a/ggml/src/ggml-hexagon/htp/argsort-ops.c +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -241,6 +241,9 @@ int op_argsort(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } + const uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3]; + const uint32_t n_threads = MIN(total_rows, octx->n_threads); + // Allocate scratchpad // We need 1 row of float + 1 row of int32 per thread. uint32_t ne00 = octx->src0.ne[0]; @@ -251,7 +254,7 @@ int op_argsort(struct htp_ops_context * octx) { // Make sure we round up to 256 for alignment requirements spad_per_thread = hex_round_up(spad_per_thread, 256); - size_t total_spad_size = spad_per_thread * octx->n_threads; + size_t total_spad_size = spad_per_thread * n_threads; if (octx->ctx->vtcm_size < total_spad_size) { FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size); @@ -267,15 +270,12 @@ int op_argsort(struct htp_ops_context * octx) { octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3], octx->src0.data, octx->dst.data); - uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3]; - uint32_t n_jobs = MIN(total_rows, octx->n_threads); - struct htp_argsort_context actx; actx.octx = octx; - actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs; + actx.nrows_per_thread = (total_rows + n_threads - 1) / n_threads; // Run jobs - worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/binary-ops.c b/ggml/src/ggml-hexagon/htp/binary-ops.c index 00dbcf87986..ec90f22de52 100644 --- a/ggml/src/ggml-hexagon/htp/binary-ops.c +++ b/ggml/src/ggml-hexagon/htp/binary-ops.c @@ -95,43 +95,87 @@ static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_ } // Macro for scalar op switch -#define COMPUTE_SCALAR_OP(DST, SRC, VAL, N) \ - switch (octx->op) { \ - case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, VAL, N); break; \ - case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, VAL, N); break; \ - case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, VAL, N); break; \ - case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (VAL), N); break; \ - default: break; \ +#define COMPUTE_SCALAR_OP(DST, SRC, VAL, TYPE, N) \ + if(TYPE == HTP_TYPE_F32) { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \ + case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \ + case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, *(float *)VAL, N); break; \ + case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (*(float *)VAL), N); break; \ + default: break; \ + } \ + } \ + else { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \ + case HTP_OP_SUB: hvx_sub_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \ + case HTP_OP_MUL: hvx_mul_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \ + case HTP_OP_DIV: hvx_div_scalar_f16_aa(DST, SRC, *(_Float16 *)VAL, N); break; \ + default: break; \ + } \ } // Macro for vector op switch (All Aligned) -#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, N) \ - switch (octx->op) { \ - case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \ - case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \ - case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \ - case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \ - default: break; \ +#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, TYPE, N) \ + if(TYPE == HTP_TYPE_F32) { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ + } \ + else { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f16_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f16_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f16_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f16_aaa(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ } // Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned) -#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, N) \ - switch (octx->op) { \ - case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \ - case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \ - case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \ - case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \ - default: break; \ +#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, TYPE, N) \ + if(TYPE == HTP_TYPE_F32) { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ + } \ + else { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f16_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f16_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f16_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f16_aau(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ } // Macro for vector op switch (All Unaligned - generic loop used in element repeat) -#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, N) \ - switch (octx->op) { \ - case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \ - case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \ - case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \ - case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \ - default: break; \ +#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, TYPE, N) \ + if(TYPE == HTP_TYPE_F32) { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ + } \ + else { \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f16_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f16_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f16_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f16_uuu(DST, SRC0, SRC1, N); break; \ + default: break; \ + } \ } // 1. Scalar src1 (ne10 == 1) @@ -140,6 +184,8 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; + const uint32_t src0_type = octx->src0.type; + const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); @@ -170,7 +216,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); - dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } @@ -199,13 +245,12 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { for (uint32_t r = 0; r < current_block_size; r++) { uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; - float val = *(float *)src1_ptr; + COMPUTE_SCALAR_OP(r_dst, r_src0, src1_ptr, src0_type, ne00); src1_ptr += s1_stride; - COMPUTE_SCALAR_OP(r_dst, r_src0, val, ne00); } uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; - dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); @@ -216,7 +261,7 @@ static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; - dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } ir += current_block_size; @@ -230,6 +275,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; + const uint32_t src0_type = octx->src0.type; + const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); @@ -268,8 +315,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); - dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); - dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } @@ -284,7 +331,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned; uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; - COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00); + COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00); } uint32_t i03, i02, i01, rem; @@ -293,7 +340,7 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi i02 = fastdiv(rem, &bctx->dim1_div); i01 = rem - i02 * ne01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; - dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); @@ -310,8 +357,8 @@ static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, voi uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11; - dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); - dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } @@ -326,6 +373,8 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; + const uint32_t src0_type = octx->src0.type; + const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); @@ -359,7 +408,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); - dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } @@ -373,7 +422,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; - COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00); + COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, src0_type, ne00); } uint32_t i03, i02, i01, rem; @@ -382,7 +431,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, i02 = fastdiv(rem, &bctx->dim1_div); i01 = rem - i02 * ne01; uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; - dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); @@ -392,7 +441,7 @@ static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, p02 = fastdiv(prem, &bctx->dim1_div); p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; - dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } ir += current_block_size; @@ -406,6 +455,8 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; + const uint32_t src0_type = octx->src0.type; + const uint32_t row_size_bytes = (src0_type == HTP_TYPE_F32) ? ne00 * sizeof(float) : ne00 * sizeof(_Float16); const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); @@ -435,7 +486,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); - dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } @@ -462,11 +513,11 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; // Read src1 from DDR (unaligned) - COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, ne00); + COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, src0_type, ne00); } uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; - dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); @@ -476,7 +527,7 @@ static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * p02 = fastdiv(prem, &bctx->dim1_div); p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; - dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } ir += current_block_size; @@ -490,6 +541,9 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; + const uint32_t src0_type = octx->src0.type; + const uint32_t elem_size_bytes = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); + const uint32_t row_size_bytes = ne00 * elem_size_bytes;; const uint32_t total_rows = ne01 * ne02 * ne03; const uint32_t start_row = bctx->nrows_per_thread * ith; const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); @@ -519,7 +573,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); - dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, row_size_bytes, current_block_size); ir_prefetch += current_block_size; spad_idx ^= 1; } @@ -549,12 +603,12 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * for (uint32_t c = 0; c < ne00; c += ne10) { uint32_t len = MIN(ne10, ne00 - c); // Use UUU for speed and simplicity - COMPUTE_VECTOR_OP_UUU(r_dst + c * sizeof(float), r_src0 + c * sizeof(float), r_src1_row, len); + COMPUTE_VECTOR_OP_UUU(r_dst + c * elem_size_bytes, r_src0 + c * elem_size_bytes, r_src1_row, src0_type, len); } } uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; - dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, row_size_bytes, current_block_size); if (ir_prefetch < end_row) { uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); @@ -564,7 +618,7 @@ static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * p02 = fastdiv(prem, &bctx->dim1_div); p01 = prem - p02 * ne01; uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; - dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, row_size_bytes, next_block_size); ir_prefetch += next_block_size; } ir += current_block_size; @@ -672,18 +726,20 @@ static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { dma_queue_flush(q); } -static int execute_op_binary_f32(struct htp_ops_context * octx) { +static int execute_op_binary(struct htp_ops_context * octx) { const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; struct htp_tensor * dst = &octx->dst; - const uint32_t n_threads = octx->n_threads; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); // Use packed row sizes for VTCM allocation - const size_t src0_row_size = src0->ne[0] * sizeof(float); - const size_t src1_row_size = src1->ne[0] * sizeof(float); - const size_t dst_row_size = dst->ne[0] * sizeof(float); + const uint32_t src0_type = octx->src0.type; + const size_t elem_size = (src0_type == HTP_TYPE_F32) ? sizeof(float) : sizeof(_Float16); + const size_t src0_row_size = src0->ne[0] * elem_size; + const size_t src1_row_size = src1->ne[0] * elem_size; + const size_t dst_row_size = dst->ne[0] * elem_size; // Align to VLEN const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); @@ -694,7 +750,7 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { bool is_scalar = !is_add_id && (src1->ne[0] == 1); // Determine which kernel we will use to alloc memory and dispatch - bool use_vector_same = !is_add_id && !is_scalar && src1->ne[0] == src0->ne[0] && + bool use_vector_same = !is_add_id && !is_scalar && ((src0->nb[1] % VLEN) == 0) && (src1->ne[0] == src0->ne[0]) && (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) && (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) && (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1); @@ -726,7 +782,7 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { } if (rows_per_buffer < 1) { - FARF(ERROR, "binary-f32: VTCM too small\n"); + FARF(ERROR, "binary: VTCM too small\n"); return HTP_STATUS_VTCM_TOO_SMALL; } @@ -761,16 +817,14 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - uint32_t n_jobs = MIN(n_threads, src0_nrows); - dma_queue * q = octx->ctx->dma[0]; if (is_row_bcast) { - dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * sizeof(float), 1); + dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * elem_size, 1); } struct htp_binary_context bctx; bctx.octx = octx; - bctx.nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + bctx.nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; bctx.block_max = rows_per_buffer; bctx.src0_row_size_aligned = src0_row_size_aligned; bctx.src1_row_size_aligned = src1_row_size_aligned; @@ -814,14 +868,24 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { dma_queue_pop(q); } - worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_threads); return HTP_STATUS_OK; } int op_binary(struct htp_ops_context * octx) { - if (octx->src0.type == HTP_TYPE_F32) { - return execute_op_binary_f32(octx); + + // Does not support permutations of src1 + const struct htp_tensor * src1 = &octx->src1; + if (src1->nb[1] < src1->nb[0]) { + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t src0_type = octx->src0.type; + if ((src0_type == HTP_TYPE_F32) || (src0_type == HTP_TYPE_F16)) { + return execute_op_binary(octx); } + return HTP_STATUS_NO_SUPPORT; } + diff --git a/ggml/src/ggml-hexagon/htp/cpy-ops.c b/ggml/src/ggml-hexagon/htp/cpy-ops.c index 559ca183789..a40d866b9c3 100644 --- a/ggml/src/ggml-hexagon/htp/cpy-ops.c +++ b/ggml/src/ggml-hexagon/htp/cpy-ops.c @@ -202,6 +202,8 @@ static void cpy_work_func(unsigned int n, unsigned int i, void *data) { int op_cpy(struct htp_ops_context * octx) { cpy_preamble; + const uint32_t n_threads = MIN(nr, octx->n_threads); + struct htp_copy_context ct; ct.octx = octx; @@ -227,8 +229,7 @@ int op_cpy(struct htp_ops_context * octx) { const bool transposed = (nb00 > nb01) || (nb0 > nb1); const bool sameshape = !transposed && (ne00 == ne0 && ne01 == ne1 && ne02 == ne2 && ne03 == ne3); - const uint32_t n_jobs = MIN(nr, octx->n_threads); - ct.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + ct.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads; if (sametype && sameshape) { ct.copy = cpy_thread_sametype_sameshape; @@ -245,7 +246,7 @@ int op_cpy(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, cpy_work_func, &ct, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index 74c777d4c3e..6dc978dd68a 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -10,6 +10,7 @@ #include "hex-dma.h" #include "hvx-utils.h" +#include "hvx-dump.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -17,6 +18,16 @@ #include "htp-msg.h" #include "htp-ops.h" +// Must be multiple of 32 +#define FLASH_ATTN_BLOCK_SIZE (32 * 2) + +// This is a bit of a hack because the compiler is strugling to properly inline +// the default hvx_vec_f32_to_f16 with output into the local array. +static void __attribute__((noinline)) hvx_vec_f32_to_f16_a(void *ptr, HVX_Vector v0, HVX_Vector v1) +{ + *(HVX_Vector *) ptr = hvx_vec_f32_to_f16(v0, v1); +} + // Dot product of two F16 vectors, accumulating to float static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) { const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 @@ -25,175 +36,184 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); uint32_t i = 0; #pragma unroll(4) for (i = 0; i < nvec; i++) { - HVX_Vector y_hf = vy[i]; - HVX_Vector x_hf = vx[i]; - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); + rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, vx[i], vy[i]); } if (nloe) { - // Load x (fp16) and zero-out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]); - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); + rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf); } - rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); - hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum)); + HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p))); + rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum))); + hvx_vec_store_u(r, 4, rsum); } -static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, - const void * restrict y, - const void * restrict x0, - const void * restrict x1, - unsigned int n, - float s) { - const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16 - const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 - const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 - - uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors - uint32_t nloe = n % VLEN_FP16; // leftover elements - - HVX_Vector rsum0 = Q6_V_vsplat_R(0); - HVX_Vector rsum1 = Q6_V_vsplat_R(0); +static inline HVX_Vector hvx_dot_f16_f16_aa_rx4(const void * restrict y, + const uint8_t * restrict x, + const size_t stride_x, + const size_t nvec, + const size_t nloe) { + const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x; // fp16 + const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) (x + stride_x); // fp16 + const HVX_Vector * restrict vx2 = (const HVX_Vector * restrict) (x + stride_x * 2); // fp16 + const HVX_Vector * restrict vx3 = (const HVX_Vector * restrict) (x + stride_x * 3); // fp16 + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 + + HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum2_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum3_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); uint32_t i = 0; - #pragma unroll(4) for (i = 0; i < nvec; i++) { HVX_Vector y_hf = vy[i]; HVX_Vector x0_hf = vx0[i]; HVX_Vector x1_hf = vx1[i]; + HVX_Vector x2_hf = vx2[i]; + HVX_Vector x3_hf = vx3[i]; - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); - - rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); + rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf); + rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf); + rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf); + rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf); } if (nloe) { // Load x (fp16) and zero-out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); - HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); - HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); + HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); + HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); + HVX_Vector x2_hf = Q6_V_vand_QV(bmask, vx2[i]); + HVX_Vector x3_hf = Q6_V_vand_QV(bmask, vx3[i]); + + rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf); + rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf); + rsum2_p = hvx_vec_mpyacc_f32_f16(rsum2_p, x2_hf, y_hf); + rsum3_p = hvx_vec_mpyacc_f32_f16(rsum3_p, x3_hf, y_hf); + } + + HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p))); + HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p))); + HVX_Vector rsum2 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum2_p), Q6_V_hi_W(rsum2_p))); + HVX_Vector rsum3 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum3_p), Q6_V_hi_W(rsum3_p))); - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + HVX_Vector_x4 rsum0123 = { .v = { rsum0, rsum1, rsum2, rsum3 } }; + return hvx_vec_reduce_sum_f32x4(rsum0123); +} - rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); +static inline HVX_Vector hvx_dot_f16_f16_aa_rx32(const void * restrict y, + const uint8_t * restrict x, + const size_t stride_x, + const size_t n, + float s) { + + const size_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + const size_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector sums; // initialize at j = 0 + const size_t stride_x_4 = stride_x * 4; + for (uint32_t j = 0; j < VLEN_FP32; j += 4) { + HVX_Vector sums_x4 = hvx_dot_f16_f16_aa_rx4(y, x, stride_x, nvec, nloe); + HVX_VectorPred pred = Q6_Q_vsetq_R(j * SIZEOF_FP32); + sums = Q6_V_vmux_QVV(pred, sums, sums_x4); + x += stride_x_4; } - HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1)); - hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); + sums = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), sums); + return Q6_Vsf_equals_Vqf32(sums); } -// MAD: y (F32) += x (F16) * s (F32) -static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) { - const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x; - HVX_Vector * restrict ptr_y = (HVX_Vector *) y; +// MAD: y (F32) += x (F16) * s (F16) +static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, const __fp16 * restrict s, int n) { + const HVX_Vector * restrict vx0 = (const HVX_Vector *) x; + + HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y; + HVX_Vector * restrict vy = (HVX_Vector *) y; uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector S = hvx_vec_splat_f16(s); + HVX_Vector S0 = hvx_vec_splat_f16(*s); uint32_t i = 0; - #pragma unroll(4) + + #pragma unroll(2) for (i = 0; i < nvec; ++i) { - // Multiply x * s -> pair of F32 vectors - HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); - ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2])); - ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1])); + vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0); } if (nloe) { - HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); + HVX_VectorPair xy_p = vy_p[i]; + xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0); - HVX_Vector xs = Q6_V_lo_W(xs_p); - i = 2 * i; // index for ptr_y + HVX_Vector xy = Q6_V_lo_W(xy_p); + i = 2 * i; // index for vy - if (nloe >= 32) { - ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); - nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p); + if (nloe >= VLEN_FP32) { + vy[i] = xy; + nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p); } if (nloe) { - HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); - hvx_vec_store_a(&ptr_y[i], nloe * 4, xy); + hvx_vec_store_a(&vy[i], nloe * 4, xy); } } } -// MAD: y (F32) += x0 (F16) * s0 (F32) + x1 (F16) * s1 (F32) -static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, - const void * restrict x0, - const void * restrict x1, - float s0, - float s1, - int n) { - const HVX_Vector * restrict ptr_x0 = (const HVX_Vector *) x0; - const HVX_Vector * restrict ptr_x1 = (const HVX_Vector *) x1; - HVX_Vector * restrict ptr_y = (HVX_Vector *) y; +// MAD: y (F32) += x0 (F16) * s0 (F16) + x1 (F16) * s1 (F16) +static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, const void * restrict x0, const void * restrict x1, + const __fp16 * restrict s0, const __fp16 * restrict s1, int n) { + const HVX_Vector * restrict vx0 = (const HVX_Vector *) x0; + const HVX_Vector * restrict vx1 = (const HVX_Vector *) x1; + + HVX_VectorPair * restrict vy_p = (HVX_VectorPair *) y; + HVX_Vector * restrict vy = (HVX_Vector *) y; uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector S0 = hvx_vec_splat_f16(s0); - HVX_Vector S1 = hvx_vec_splat_f16(s1); + HVX_Vector S0 = hvx_vec_splat_f16(*s0); + HVX_Vector S1 = hvx_vec_splat_f16(*s1); uint32_t i = 0; + #pragma unroll(2) for (i = 0; i < nvec; ++i) { - // Multiply x * s -> pair of F32 vectors - HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0); - HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1); - - HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p)); - HVX_Vector xs_p_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p)); - - ptr_y[i * 2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_lo, ptr_y[i * 2])); - ptr_y[i * 2 + 1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_hi, ptr_y[i * 2 + 1])); + vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx0[i]), S0); + vy_p[i] = hvx_vec_mpyacc_f32_f16(vy_p[i], Q6_Vh_vshuff_Vh(vx1[i]), S1); } if (nloe) { - HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0); - HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1); + HVX_VectorPair xy_p = vy_p[i]; + xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx0[i]), S0); + xy_p = hvx_vec_mpyacc_f32_f16(xy_p, Q6_Vh_vshuff_Vh(vx1[i]), S1); - HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p)); - HVX_Vector xs = xs_p_lo; - i = 2 * i; // index for ptr_y + HVX_Vector xy = Q6_V_lo_W(xy_p); + i = 2 * i; // index for vy - if (nloe >= 32) { - ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); - nloe -= 32; ++i; - xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p)); + if (nloe >= VLEN_FP32) { + vy[i] = xy; + nloe -= VLEN_FP32; ++i; xy = Q6_V_hi_W(xy_p); } if (nloe) { - HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); - hvx_vec_store_a(&ptr_y[i], nloe * 4, xy); + hvx_vec_store_a(&vy[i], nloe * 4, xy); } } } -#define FLASH_ATTN_BLOCK_SIZE 128 - struct htp_fa_context { const struct htp_ops_context * octx; @@ -226,7 +246,12 @@ struct htp_fa_context { size_t size_v_block; size_t size_m_block; + uint32_t qrows; + uint32_t qrows_per_thread; + bool is_q_fp32; + + uint64_t t_start; }; static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) { @@ -296,9 +321,8 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * const uint32_t nb3 = dst->nb[3]; // total rows in q - const uint32_t nr = neq1*neq2*neq3; - - const uint32_t dr = (nr + nth - 1) / nth; + const uint32_t nr = factx->qrows; + const uint32_t dr = factx->qrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = MIN(ir0 + dr, nr); @@ -337,15 +361,8 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3); dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1); - const uint32_t h = iq2; // head index - const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f; - - HVX_Vector S_vec = hvx_vec_splat_f32(0.0f); - HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY); - - // Clear accumulator - hvx_splat_f32_a(spad_a, 0, DV); - float * VKQ32 = (float *) spad_a; + // FARF(HIGH, "fa %u: prefetch Q: ir %u iq1 %u iq2 %u iq3 %u q_row_ptr %p size %u : usec %u", ith, ir, iq1, iq2, iq3, q_row_ptr, size_q_row, + // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start)); const __fp16 * mp_base = NULL; if (mask) { @@ -376,8 +393,23 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * // Mask is 1D contiguous for this row dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); } + + // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u", + // ith, ir, ib, iq1, iq2, iq3, + // size_k_row, size_v_row, current_block_size, + // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start)); } + const uint32_t h = iq2; // head index + const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f; + + HVX_Vector S_vec = hvx_vec_splat_f32(0.0f); + HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY); + + // Clear accumulator + hvx_splat_f32_a(spad_a, 0, DV); + float * VKQ32 = (float *) (spad_a + 0); + uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; if (factx->is_q_fp32) { hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16 @@ -393,23 +425,19 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * uint8_t * v_base = dma_queue_pop(dma).dst; // V __fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M + // FARF(HIGH, "fa %u: process: ir %u ib %u : iq1 %u iq2 %u iq3 %u q_ptr_vtcm %p : usec %u", + // ith, ir, ib, iq1, iq2, iq3, q_ptr_vtcm, + // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start)); + // Inner loop processing the block from VTCM uint32_t ic = 0; - // Process in blocks of 32 (VLEN_FP32) - static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage"); - HVX_Vector_x4 scores_x4; + // Process in sub-blocks of 32 (VLEN_FP32) + HVX_Vector sb_scores[FLASH_ATTN_BLOCK_SIZE / VLEN_FP32]; HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY); for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) { // 1. Compute scores - float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; - for (uint32_t j = 0; j < VLEN_FP32; j += 2) { - const uint32_t cur_ic = ic + j; - const uint8_t * k_ptr = k_base + cur_ic * factx->size_k_row_padded; - hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + factx->size_k_row_padded, DK, factx->scale); - } - - HVX_Vector scores = *(HVX_Vector *) scores_arr; + HVX_Vector scores = hvx_dot_f16_f16_aa_rx32(q_ptr_vtcm, k_base + ic * factx->size_k_row_padded, factx->size_k_row_padded, DK, factx->scale); // 2. Softcap if (factx->logit_softcap != 0.0f) { @@ -428,35 +456,35 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * scores = Q6_Vsf_equals_Vqf32(scores); } - scores_x4.v[iv] = scores; + sb_scores[iv] = scores; v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max } { // 4. Online Softmax Update HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec); - HVX_Vector diff_vec = Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec); - HVX_Vector ms_vec = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(diff_vec)); + HVX_Vector diff_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec)); + HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); M_vec = M_new_vec; hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f); for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) { - HVX_Vector scores = scores_x4.v[iv]; + HVX_Vector scores = sb_scores[iv]; HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec); HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P)); // 5. Accumulate V - float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; - *(HVX_Vector *) p_arr = P; + __fp16 __attribute__((aligned(VLEN))) p_arr[VLEN_FP16]; + hvx_vec_f32_to_f16_a(p_arr, P, hvx_vec_splat_f32(0)); for (uint32_t j = 0; j < VLEN_FP32; j += 2) { const uint32_t cur_ic = ic2 + j; const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded; - hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, p_arr[j], p_arr[j + 1], DV); + hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, (p_arr + j), (p_arr + j + 1), DV); } } @@ -464,47 +492,50 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec)); } - // Sync scalars for leftover/next block if needed - float M = hvx_vec_get_f32(M_vec); - float S = hvx_vec_get_f32(S_vec); + if (ic < current_block_size) { + // Sync scalars for leftover/next block if needed + float M = hvx_vec_get_f32(M_vec); + float S = hvx_vec_get_f32(S_vec); + + // Leftover + for (; ic < current_block_size; ++ic) { + float s_val; + const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded; + hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale); + if (factx->logit_softcap != 0.0f) { + s_val = factx->logit_softcap * tanhf(s_val); + } - // Leftover - for (; ic < current_block_size; ++ic) { - float s_val; - const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded; - hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale); - if (factx->logit_softcap != 0.0f) { - s_val = factx->logit_softcap * tanhf(s_val); - } + if (mask) { + const float m_val = m_base[ic]; + s_val += slope * m_val; + } - if (mask) { - const float m_val = m_base[ic]; - s_val += slope * m_val; - } + const float Mold = M; + __fp16 vs = 1.0f; + + if (s_val > M) { + M = s_val; + HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M); + HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); + hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); + + float ms = hvx_vec_get_f32(ms_vec); + S = S * ms + vs; + } else { + HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M); + vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec)); + S += vs; + } - const float Mold = M; - float vs = 1.0f; - - if (s_val > M) { - M = s_val; - HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M); - HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); - hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); - - float ms = hvx_vec_get_f32(ms_vec); - S = S * ms + vs; - } else { - HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M); - vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec)); - S += vs; - } + const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded; - const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded; + hvx_mad_f32_f16_aa(VKQ32, v_ptr, &vs, DV); + } - hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs); + M_vec = hvx_vec_splat_f32(M); + S_vec = hvx_vec_splat_f32(S); } - M_vec = hvx_vec_splat_f32(M); - S_vec = hvx_vec_splat_f32(S); // Issue DMA for next+1 block (if exists) if (ib + 2 < factx->n_blocks) { @@ -525,6 +556,11 @@ static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start); dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); } + + // FARF(HIGH, "fa %u: prefetch KVM: ir %u ib %u : iq1 %u iq2 %u iq3 %u : size_k_row %u size_v_row %u bs %u: usec %u", + // ith, ir, next_ib, iq1, iq2, iq3, + // size_k_row, size_v_row, next_block_size, + // (unsigned)HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - factx->t_start)); } } @@ -586,6 +622,8 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { struct htp_fa_context factx; factx.octx = octx; + factx.t_start = HAP_perf_get_qtimer_count(); + factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]); factx.src0_div1 = init_fastdiv_values(q->ne[1]); @@ -632,6 +670,15 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2); factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); + // total rows in q + const uint32_t neq0 = q->ne[0]; + const uint32_t neq1 = q->ne[1]; + const uint32_t neq2 = q->ne[2]; + const uint32_t neq3 = q->ne[3]; + + factx.qrows = neq1*neq2*neq3; + factx.qrows_per_thread = (factx.qrows + octx->n_threads - 1) / octx->n_threads; + size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32 octx->src0_spad.size_per_thread = size_q_block * 1; diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c index bf24bbda70a..047d2850aaa 100644 --- a/ggml/src/ggml-hexagon/htp/get-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -82,6 +82,8 @@ static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *da int op_get_rows(struct htp_ops_context * octx) { get_rows_preamble; + const uint32_t n_threads = MIN(nr, octx->n_threads); + if (octx->src0.type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } @@ -103,9 +105,8 @@ int op_get_rows(struct htp_ops_context * octx) { grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]); grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]); - const uint32_t n_jobs = MIN(nr, octx->n_threads); - grctx.src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + grctx.src1_nrows_per_thread = (nr + n_threads - 1) / n_threads; - worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index 25403bb1126..52dcc36d8f7 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -68,6 +68,7 @@ enum htp_op { HTP_OP_SQR, HTP_OP_SQRT, HTP_OP_SUM_ROWS, + HTP_OP_SSM_CONV, INVALID }; diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 127ab1d6659..2ef20936f1b 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -41,9 +41,6 @@ struct htp_ops_context { worker_pool_context_t * wpool; // worker pool uint32_t n_threads; // num threads - uint32_t src0_nrows_per_thread; - uint32_t src1_nrows_per_thread; - uint32_t flags; }; @@ -61,5 +58,6 @@ int op_set_rows(struct htp_ops_context * octx); int op_get_rows(struct htp_ops_context * octx); int op_cpy(struct htp_ops_context * octx); int op_argsort(struct htp_ops_context * octx); +int op_ssm_conv(struct htp_ops_context * octx); #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-arith.h b/ggml/src/ggml-hexagon/htp/hvx-arith.h index 2577cdd0418..82e3416970b 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-arith.h +++ b/ggml/src/ggml-hexagon/htp/hvx-arith.h @@ -13,14 +13,15 @@ // Binary operations (add, mul, sub) // -#define hvx_arith_loop_body(dst_type, src0_type, src1_type, vec_store, vec_op) \ +#define UNUSED(x) (void)(x) + +#define hvx_arith_loop_body(dst_type, src0_type, src1_type, elem_size, vec_store, vec_op) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ src0_type * restrict vsrc0 = (src0_type *) src0; \ src1_type * restrict vsrc1 = (src1_type *) src1; \ \ - const uint32_t elem_size = sizeof(float); \ - const uint32_t epv = 128 / elem_size; \ + const uint32_t epv = 128 / (elem_size); \ const uint32_t nvec = n / epv; \ const uint32_t nloe = n % epv; \ \ @@ -32,62 +33,74 @@ } \ if (nloe) { \ HVX_Vector v = vec_op(vsrc0[i], vsrc1[i]); \ - vec_store((void *) &vdst[i], nloe * elem_size, v); \ + vec_store((void *) &vdst[i], nloe * (elem_size), v); \ } \ } while(0) #if __HVX_ARCH__ < 79 -#define HVX_OP_ADD(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) -#define HVX_OP_SUB(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)) -#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) + +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) +#define HVX_OP_SUB_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_VsfVsf(a, b)) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) + #else -#define HVX_OP_ADD(a, b) Q6_Vsf_vadd_VsfVsf(a, b) -#define HVX_OP_SUB(a, b) Q6_Vsf_vsub_VsfVsf(a, b) -#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) + +#define HVX_OP_ADD_F32(a, b) Q6_Vsf_vadd_VsfVsf(a, b) +#define HVX_OP_SUB_F32(a, b) Q6_Vsf_vsub_VsfVsf(a, b) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) + #endif +#define HVX_OP_ADD_F16(a, b) hvx_vec_add_f16_f16(a, b) +#define HVX_OP_SUB_F16(a, b) hvx_vec_sub_f16_f16(a, b) +#define HVX_OP_MUL_F16(a, b) hvx_vec_mul_f16_f16(a, b) + // Generic macro to define alignment permutations for an op -#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \ +#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO, ELEM_TYPE) \ static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) dst % 128 == 0); \ assert((uintptr_t) src0 % 128 == 0); \ assert((uintptr_t) src1 % 128 == 0); \ - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \ + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ } \ static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) dst % 128 == 0); \ assert((uintptr_t) src0 % 128 == 0); \ - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \ + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ } \ static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) dst % 128 == 0); \ assert((uintptr_t) src1 % 128 == 0); \ - hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \ + hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ } \ static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) dst % 128 == 0); \ - hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \ + hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ } \ static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) src0 % 128 == 0); \ assert((uintptr_t) src1 % 128 == 0); \ - hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \ + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ } \ static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) src0 % 128 == 0); \ - hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \ + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ } \ static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ assert((uintptr_t) src1 % 128 == 0); \ - hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \ + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ } \ static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ - hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \ + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ } \ -DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD) -DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB) -DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD_F32, float) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB_F32, float) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL_F32, float) + +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f16, HVX_OP_ADD_F16, _Float16) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f16, HVX_OP_SUB_F16, _Float16) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f16, HVX_OP_MUL_F16, _Float16) // Dispatcher logic #define HVX_BINARY_DISPATCHER(OP_NAME) \ @@ -115,6 +128,10 @@ HVX_BINARY_DISPATCHER(hvx_add_f32) HVX_BINARY_DISPATCHER(hvx_sub_f32) HVX_BINARY_DISPATCHER(hvx_mul_f32) +HVX_BINARY_DISPATCHER(hvx_add_f16) +HVX_BINARY_DISPATCHER(hvx_sub_f16) +HVX_BINARY_DISPATCHER(hvx_mul_f16) + // Mul-Mul Optimized static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) { assert((unsigned long) dst % 128 == 0); @@ -136,26 +153,25 @@ static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * re _Pragma("unroll(4)") for (; i < nvec; i++) { - HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]); + HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]); vdst[i] = HVX_OP_MUL(v1, vsrc2[i]); } if (nloe) { - HVX_Vector v1 = HVX_OP_MUL(vsrc0[i], vsrc1[i]); - HVX_Vector v2 = HVX_OP_MUL(v1, vsrc2[i]); + HVX_Vector v1 = HVX_OP_MUL_F32(vsrc0[i], vsrc1[i]); + HVX_Vector v2 = HVX_OP_MUL_F32(v1, vsrc2[i]); hvx_vec_store_a((void *) &vdst[i], nloe * elem_size, v2); } } // Scalar Operations -#define hvx_scalar_loop_body(dst_type, src_type, vec_store, scalar_op_macro) \ +#define hvx_scalar_loop_body(dst_type, src_type, elem_size, vec_store, scalar_op_macro) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ src_type * restrict vsrc = (src_type *) src; \ \ - const uint32_t elem_size = sizeof(float); \ - const uint32_t epv = 128 / elem_size; \ + const uint32_t epv = 128 / (elem_size); \ const uint32_t nvec = n / epv; \ const uint32_t nloe = n % epv; \ \ @@ -169,138 +185,88 @@ static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * re if (nloe) { \ HVX_Vector v = vsrc[i]; \ v = scalar_op_macro(v); \ - vec_store((void *) &vdst[i], nloe * elem_size, v); \ + vec_store((void *) &vdst[i], nloe * (elem_size), v); \ } \ } while(0) -#define HVX_OP_ADD_SCALAR(v) \ +#define HVX_OP_ADD_SCALAR_F32(v) \ ({ \ const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VwVw(inf, v); \ - HVX_Vector out = HVX_OP_ADD(v, val_vec); \ + HVX_Vector out = HVX_OP_ADD_F32(v, val_vec); \ Q6_V_vmux_QVV(pred_inf, inf, out); \ }) -#define HVX_OP_MUL_SCALAR(v) HVX_OP_MUL(v, val_vec) -#define HVX_OP_SUB_SCALAR(v) HVX_OP_SUB(v, val_vec) - -// Add Scalar Variants - -static inline void hvx_add_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - const HVX_Vector inf = hvx_vec_splat_f32(INFINITY); - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD_SCALAR); -} - -static inline void hvx_add_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - const HVX_Vector inf = hvx_vec_splat_f32(INFINITY); - assert((unsigned long) dst % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD_SCALAR); -} - -static inline void hvx_add_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - const HVX_Vector inf = hvx_vec_splat_f32(INFINITY); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD_SCALAR); -} - -static inline void hvx_add_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - static const float kInf = INFINITY; - const HVX_Vector inf = hvx_vec_splat_f32(kInf); - hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD_SCALAR); -} - -// Sub Scalar Variants +#define HVX_OP_MUL_SCALAR_F32(v) HVX_OP_MUL_F32(v, val_vec) +#define HVX_OP_SUB_SCALAR_F32(v) HVX_OP_SUB_F32(v, val_vec) -static inline void hvx_sub_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB_SCALAR); -} - -static inline void hvx_sub_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) dst % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB_SCALAR); -} - -static inline void hvx_sub_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB_SCALAR); -} - -static inline void hvx_sub_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB_SCALAR); -} +#define HVX_OP_ADD_SCALAR_F16(v) \ + ({ \ + const HVX_VectorPred pred_inf = Q6_Q_vcmp_eq_VhVh(inf, v); \ + HVX_Vector out = HVX_OP_ADD_F16(v, val_vec); \ + Q6_V_vmux_QVV(pred_inf, inf, out); \ + }) -// Mul Scalar Variants +#define HVX_OP_MUL_SCALAR_F16(v) HVX_OP_MUL_F16(v, val_vec) +#define HVX_OP_SUB_SCALAR_F16(v) HVX_OP_SUB_F16(v, val_vec) -static inline void hvx_mul_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL_SCALAR); -} +// Scalar Variants -static inline void hvx_mul_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) dst % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL_SCALAR); -} +// Generic macro to define alignment permutations for an op +#define DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(OP_NAME, OP_MACRO, SPLAT_MACRO, ELEM_TYPE) \ +static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \ + const HVX_Vector val_vec = SPLAT_MACRO(val); \ + const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src % 128 == 0); \ + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \ + const HVX_Vector val_vec = SPLAT_MACRO(val); \ + const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \ + assert((uintptr_t) dst % 128 == 0); \ + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \ + const HVX_Vector val_vec = SPLAT_MACRO(val); \ + const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \ + assert((uintptr_t) src % 128 == 0); \ + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, uint32_t n) { \ + const HVX_Vector val_vec = SPLAT_MACRO(val); \ + const HVX_Vector inf = SPLAT_MACRO((ELEM_TYPE)INFINITY); UNUSED(inf); \ + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(ELEM_TYPE), hvx_vec_store_u, OP_MACRO); \ +} \ -static inline void hvx_mul_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL_SCALAR); -} +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f32, HVX_OP_ADD_SCALAR_F32, hvx_vec_splat_f32, float) +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f32, HVX_OP_SUB_SCALAR_F32, hvx_vec_splat_f32, float) +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f32, HVX_OP_MUL_SCALAR_F32, hvx_vec_splat_f32, float) -static inline void hvx_mul_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { - const HVX_Vector val_vec = hvx_vec_splat_f32(val); - hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL_SCALAR); -} +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_add_scalar_f16, HVX_OP_ADD_SCALAR_F16, hvx_vec_splat_f16, _Float16) +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_sub_scalar_f16, HVX_OP_SUB_SCALAR_F16, hvx_vec_splat_f16, _Float16) +DEFINE_HVX_BINARY_SCALAR_OP_VARIANTS(hvx_mul_scalar_f16, HVX_OP_MUL_SCALAR_F16, hvx_vec_splat_f16, _Float16) -static inline void hvx_add_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { - hvx_add_scalar_f32_aa(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) dst, 128)) { - hvx_add_scalar_f32_au(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) src, 128)) { - hvx_add_scalar_f32_ua(dst, src, val, num_elems); - } else { - hvx_add_scalar_f32_uu(dst, src, val, num_elems); - } +// Dispatcher logic +#define HVX_BINARY_SCALAR_DISPATCHER(OP_NAME, ELEM_TYPE) \ +static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const ELEM_TYPE val, const uint32_t num_elems) { \ + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \ + OP_NAME##_aa(dst, src, val, num_elems); \ + } else if (hex_is_aligned((void *) dst, 128)) { \ + OP_NAME##_au(dst, src, val, num_elems); \ + } else if (hex_is_aligned((void *) src, 128)) { \ + OP_NAME##_ua(dst, src, val, num_elems); \ + } else { \ + OP_NAME##_uu(dst, src, val, num_elems); \ + } \ } -static inline void hvx_mul_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { - hvx_mul_scalar_f32_aa(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) dst, 128)) { - hvx_mul_scalar_f32_au(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) src, 128)) { - hvx_mul_scalar_f32_ua(dst, src, val, num_elems); - } else { - hvx_mul_scalar_f32_uu(dst, src, val, num_elems); - } -} +HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f32, float) +HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f32, float) +HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f32, float) -static inline void hvx_sub_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { - hvx_sub_scalar_f32_aa(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) dst, 128)) { - hvx_sub_scalar_f32_au(dst, src, val, num_elems); - } else if (hex_is_aligned((void *) src, 128)) { - hvx_sub_scalar_f32_ua(dst, src, val, num_elems); - } else { - hvx_sub_scalar_f32_uu(dst, src, val, num_elems); - } -} +HVX_BINARY_SCALAR_DISPATCHER(hvx_add_scalar_f16, _Float16) +HVX_BINARY_SCALAR_DISPATCHER(hvx_sub_scalar_f16, _Float16) +HVX_BINARY_SCALAR_DISPATCHER(hvx_mul_scalar_f16, _Float16) // MIN Scalar variants @@ -310,24 +276,24 @@ static inline void hvx_min_scalar_f32_aa(uint8_t * restrict dst, const uint8_t * const HVX_Vector val_vec = hvx_vec_splat_f32(val); assert((unsigned long) dst % 128 == 0); assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MIN_SCALAR); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR); } static inline void hvx_min_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { const HVX_Vector val_vec = hvx_vec_splat_f32(val); assert((unsigned long) dst % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MIN_SCALAR); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_MIN_SCALAR); } static inline void hvx_min_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { const HVX_Vector val_vec = hvx_vec_splat_f32(val); assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_MIN_SCALAR); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR); } static inline void hvx_min_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float val, uint32_t n) { const HVX_Vector val_vec = hvx_vec_splat_f32(val); - hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MIN_SCALAR); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_MIN_SCALAR); } static inline void hvx_min_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float val, const int num_elems) { @@ -357,27 +323,27 @@ static inline void hvx_clamp_scalar_f32_aa(uint8_t * restrict dst, const uint8_t const HVX_Vector max_vec = hvx_vec_splat_f32(max); assert((unsigned long) dst % 128 == 0); assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); + hvx_scalar_loop_body(HVX_Vector, HVX_Vector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); } static inline void hvx_clamp_scalar_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { const HVX_Vector min_vec = hvx_vec_splat_f32(min); const HVX_Vector max_vec = hvx_vec_splat_f32(max); assert((unsigned long) dst % 128 == 0); - hvx_scalar_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); + hvx_scalar_loop_body(HVX_Vector, HVX_UVector, sizeof(float), hvx_vec_store_a, HVX_OP_CLAMP_SCALAR); } static inline void hvx_clamp_scalar_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { const HVX_Vector min_vec = hvx_vec_splat_f32(min); const HVX_Vector max_vec = hvx_vec_splat_f32(max); assert((unsigned long) src % 128 == 0); - hvx_scalar_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); + hvx_scalar_loop_body(HVX_UVector, HVX_Vector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); } static inline void hvx_clamp_scalar_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, uint32_t n) { const HVX_Vector min_vec = hvx_vec_splat_f32(min); const HVX_Vector max_vec = hvx_vec_splat_f32(max); - hvx_scalar_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); + hvx_scalar_loop_body(HVX_UVector, HVX_UVector, sizeof(float), hvx_vec_store_u, HVX_OP_CLAMP_SCALAR); } static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * restrict src, const float min, const float max, const int num_elems) { @@ -396,7 +362,7 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * // Square // -#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \ +#define hvx_sqr_f32_loop_body(dst_type, src_type, vec_store) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ src_type * restrict vsrc = (src_type *) src; \ @@ -410,10 +376,10 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * \ _Pragma("unroll(4)") \ for (; i < nvec; i++) { \ - vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \ + vdst[i] = HVX_OP_MUL_F32(vsrc[i], vsrc[i]); \ } \ if (nloe) { \ - HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \ + HVX_Vector v = HVX_OP_MUL_F32(vsrc[i], vsrc[i]); \ vec_store((void *) &vdst[i], nloe * elem_size, v); \ } \ } while(0) @@ -421,21 +387,21 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { assert((unsigned long) dst % 128 == 0); assert((unsigned long) src % 128 == 0); - hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); + hvx_sqr_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); } static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { assert((unsigned long) dst % 128 == 0); - hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); + hvx_sqr_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); } static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { assert((unsigned long) src % 128 == 0); - hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); + hvx_sqr_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); } static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); + hvx_sqr_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); } static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) { @@ -454,17 +420,24 @@ static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict } } -#undef HVX_OP_ADD -#undef HVX_OP_SUB -#undef HVX_OP_MUL +#undef HVX_OP_ADD_F32 +#undef HVX_OP_SUB_F32 +#undef HVX_OP_MUL_F32 +#undef HVX_OP_ADD_F16 +#undef HVX_OP_SUB_F16 +#undef HVX_OP_MUL_F16 #undef hvx_arith_loop_body -#undef HVX_OP_ADD_SCALAR -#undef HVX_OP_SUB_SCALAR -#undef HVX_OP_MUL_SCALAR +#undef HVX_OP_ADD_SCALAR_F32 +#undef HVX_OP_SUB_SCALAR_F32 +#undef HVX_OP_MUL_SCALAR_F32 +#undef HVX_OP_ADD_SCALAR_F16 +#undef HVX_OP_SUB_SCALAR_F16 +#undef HVX_OP_MUL_SCALAR_F16 #undef hvx_scalar_loop_body #undef HVX_OP_MIN_SCALAR #undef HVX_OP_CLAMP_SCALAR #undef DEFINE_HVX_BINARY_OP_VARIANTS #undef HVX_BINARY_DISPATCHER +#undef UNUSED #endif // HVX_ARITH_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index 12a1b7f1288..578ca288fb6 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -38,7 +38,7 @@ static inline HVX_Vector hvx_vec_splat_f32(float v) { return Q6_V_vsplat_R(u.i); } -static inline HVX_Vector hvx_vec_splat_f16(float v) { +static inline HVX_Vector hvx_vec_splat_f16(_Float16 v) { union { __fp16 f; uint16_t i; } u = { .f = v }; return Q6_Vh_vsplat_R(u.i); } @@ -170,4 +170,71 @@ static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) { return Q6_Vh_vround_VwVw_sat(vsf_1, vsf_0); } +#if __HVX_ARCH__ < 79 + +static inline HVX_VectorPair hvx_vec_mpyacc_f32_f16(HVX_VectorPair acc, HVX_Vector x, HVX_Vector y) +{ + HVX_VectorPair m = Q6_Wqf32_vmpy_VhfVhf(x, y); + HVX_Vector a0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(m), Q6_V_lo_W(acc))); + HVX_Vector a1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(m), Q6_V_hi_W(acc))); + return Q6_W_vcombine_VV(a1, a0); +} + +#else + +static inline HVX_VectorPair hvx_vec_mpyacc_f32_f16(HVX_VectorPair acc, HVX_Vector x, HVX_Vector y) +{ + return Q6_Wsf_vmpyacc_WsfVhfVhf(acc, x, y); +} + +#endif + +#if __HVX_ARCH__ < 79 + +static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b) +{ + const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16 + const HVX_Vector one = Q6_Vh_vsplat_R(0x3C00); // 1.0 in IEEE FP16 + HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one); + HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone); + HVX_Vector a0 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p)); + HVX_Vector a1 = Q6_Vqf32_vsub_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p)); + return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0)); +} + +static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b) +{ + const HVX_Vector negone = Q6_Vh_vsplat_R(0xBC00); // -1.0 in IEEE FP16 + const HVX_Vector one = Q6_Vh_vsplat_R(0x3C00); // 1.0 in IEEE FP16 + HVX_VectorPair a_p = Q6_Wqf32_vmpy_VhfVhf(a, one); + HVX_VectorPair b_p = Q6_Wqf32_vmpy_VhfVhf(b, negone); + HVX_Vector a0 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(a_p), Q6_V_lo_W(b_p)); + HVX_Vector a1 = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(a_p), Q6_V_hi_W(b_p)); + return Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(a1, a0)); +} + +static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) +{ + return Q6_Vhf_equals_Wqf32(Q6_Wqf32_vmpy_VhfVhf(a, b)); +} + +#else + +static inline HVX_Vector hvx_vec_add_f16_f16(HVX_Vector a, HVX_Vector b) +{ + return Q6_Vhf_vadd_VhfVhf(a, b); +} + +static inline HVX_Vector hvx_vec_sub_f16_f16(HVX_Vector a, HVX_Vector b) +{ + return Q6_Vhf_vsub_VhfVhf(a, b); +} + +static inline HVX_Vector hvx_vec_mul_f16_f16(HVX_Vector a, HVX_Vector b) +{ + return Q6_Vhf_vmpy_VhfVhf(a, b); +} + +#endif // __HVX_ARCH__ < 79 + #endif /* HVX_BASE_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-copy.h b/ggml/src/ggml-hexagon/htp/hvx-copy.h index ae0dbed0306..851482e01b2 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-copy.h +++ b/ggml/src/ggml-hexagon/htp/hvx-copy.h @@ -42,11 +42,11 @@ static inline void hvx_splat_f32_u(uint8_t * restrict dst, float v, uint32_t n) hvx_splat_u(dst, hvx_vec_splat_f32(v), n, sizeof(float)); } -static inline void hvx_splat_f16_a(uint8_t * restrict dst, float v, uint32_t n) { +static inline void hvx_splat_f16_a(uint8_t * restrict dst, _Float16 v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); } -static inline void hvx_splat_f16_u(uint8_t * restrict dst, float v, uint32_t n) { +static inline void hvx_splat_f16_u(uint8_t * restrict dst, _Float16 v, uint32_t n) { hvx_splat_u(dst, hvx_vec_splat_f16(v), n, sizeof(__fp16)); } diff --git a/ggml/src/ggml-hexagon/htp/hvx-div.h b/ggml/src/ggml-hexagon/htp/hvx-div.h index 7dae012e0ed..05cefea039f 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-div.h +++ b/ggml/src/ggml-hexagon/htp/hvx-div.h @@ -15,11 +15,144 @@ #include "hvx-arith.h" #if __HVX_ARCH__ < 79 -#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) #else -#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) #endif +// Compute div by scaler in f32. Requires first by expanding fp32 to fp16 and converting the result back to fp32. +static inline HVX_Vector hvx_div_mul_f16_const_using_f32(HVX_Vector vec1_hf, HVX_Vector vec2_sf_const, HVX_Vector vec_hf_one_1_0) { +#if __HVX_ARCH__ < 79 + HVX_VectorPair src_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0); + HVX_Vector src_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(src_to_f32)); + HVX_Vector src_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(src_to_f32)); +#else + HVX_VectorPair src_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec1_hf, vec_hf_one_1_0); + HVX_Vector src_to_f32_0 = Q6_V_lo_W(src_to_f32); + HVX_Vector src_to_f32_1 = Q6_V_hi_W(src_to_f32); +#endif + + HVX_Vector div_f32_0 = HVX_OP_MUL_F32(src_to_f32_0, vec2_sf_const); + HVX_Vector div_f32_1 = HVX_OP_MUL_F32(src_to_f32_1, vec2_sf_const); + +#if __HVX_ARCH__ < 79 + HVX_Vector res = hvx_vec_f32_to_f16(div_f32_0, div_f32_1); +#else + HVX_Vector res = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1); +#endif + return res; +} + +#define hvx_div_scaler_f16_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \ + \ + const uint32_t nvec = n / VLEN_FP16; \ + const uint32_t nloe = n % VLEN_FP16; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \ + vdst[i] = res; \ + } \ + if (nloe) { \ + HVX_Vector res = hvx_div_mul_f16_const_using_f32(vsrc[i], val_vec_f32, hf_one); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \ + } \ + } while(0) + +static inline void hvx_div_scalar_f16_aa(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) { + const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val)); + assert((uintptr_t) dst % 128 == 0); + assert((uintptr_t) src % 128 == 0); + hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} +static inline void hvx_div_scalar_f16_au(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) { + const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val)); + assert((uintptr_t) dst % 128 == 0); + hvx_div_scaler_f16_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} +static inline void hvx_div_scalar_f16_ua(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) { + const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val)); + assert((uintptr_t) src % 128 == 0); + hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} +static inline void hvx_div_scalar_f16_uu(uint8_t * restrict dst, const uint8_t * restrict src, const _Float16 val, uint32_t n) { + const HVX_Vector val_vec_f32 = hvx_vec_splat_f32(1.0f/((float)val)); + hvx_div_scaler_f16_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +// Compute div by using hvx_vec_inverse_f32_guard. Requires first by exapnding fp32 to fp16 and convert the result back to fp32. +static inline HVX_Vector hvx_vec_div_f16_using_f32(HVX_Vector vec1, HVX_Vector vec2, HVX_Vector f32_nan_inf_mask, HVX_Vector vec_hf_one_1_0) { +#if __HVX_ARCH__ < 79 + // Convert first input to fp32 + HVX_VectorPair vec1_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec1, vec_hf_one_1_0); // *1.0 + HVX_Vector vec1_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec1_to_f32)); + HVX_Vector vec1_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec1_to_f32)); + + // Convert second input to fp32 + HVX_VectorPair vec2_to_f32 = Q6_Wqf32_vmpy_VhfVhf(vec2, vec_hf_one_1_0); // *1.0 + HVX_Vector vec2_to_f32_0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vec2_to_f32)); + HVX_Vector vec2_to_f32_1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vec2_to_f32)); +#else + // Convert first input to fp32 + HVX_VectorPair vec1_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec1, vec_hf_one_1_0); // *1.0 + HVX_Vector vec1_to_f32_0 = Q6_V_lo_W(vec1_to_f32); + HVX_Vector vec1_to_f32_1 = Q6_V_hi_W(vec1_to_f32); + + // Convert second input to fp32 + HVX_VectorPair vec2_to_f32 = Q6_Wsf_vmpy_VhfVhf(vec2, vec_hf_one_1_0); // *1.0 + HVX_Vector vec2_to_f32_0 = Q6_V_lo_W(vec2_to_f32); + HVX_Vector vec2_to_f32_1 = Q6_V_hi_W(vec2_to_f32); +#endif + + // Inverse second input in fp32 + HVX_Vector vec2_inv_f32_0 = hvx_vec_inverse_f32_guard(vec2_to_f32_0, f32_nan_inf_mask); + HVX_Vector vec2_inv_f32_1 = hvx_vec_inverse_f32_guard(vec2_to_f32_1, f32_nan_inf_mask); + + // Multiply first input by inverse of second, in fp32 + HVX_Vector div_f32_0 = HVX_OP_MUL_F32(vec1_to_f32_0, vec2_inv_f32_0); + HVX_Vector div_f32_1 = HVX_OP_MUL_F32(vec1_to_f32_1, vec2_inv_f32_1); + + // Convert back to fp16 +#if __HVX_ARCH__ < 79 + HVX_Vector recip = hvx_vec_f32_to_f16(div_f32_0, div_f32_1); +#else + HVX_Vector recip = Q6_Vhf_vcvt_VsfVsf(div_f32_0, div_f32_1); +#endif + + return recip; +} + +#define hvx_div_f16_loop_body(dst_type, src0_type, src1_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src0_type * restrict vsrc0 = (src0_type *) src0; \ + src1_type * restrict vsrc1 = (src1_type *) src1; \ + \ + const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + const HVX_Vector hf_one = Q6_Vh_vsplat_R(0x3C00); \ + \ + const uint32_t nvec = n / VLEN_FP16; \ + const uint32_t nloe = n % VLEN_FP16; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \ + vdst[i] = res; \ + } \ + if (nloe) { \ + HVX_Vector res = hvx_vec_div_f16_using_f32(vsrc0[i], vsrc1[i], nan_inf_mask, hf_one); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, res); \ + } \ + } while(0) + #define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \ do { \ dst_type * restrict vdst = (dst_type *) dst; \ @@ -36,81 +169,83 @@ _Pragma("unroll(4)") \ for (; i < nvec; i++) { \ HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \ - HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \ + HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1); \ vdst[i] = res; \ } \ if (nloe) { \ HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \ - HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \ + HVX_Vector res = HVX_OP_MUL_F32(vsrc0[i], inv_src1); \ vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \ } \ } while(0) -// 3-letter suffix variants -static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) dst % 128 == 0); - assert((uintptr_t) src0 % 128 == 0); - assert((uintptr_t) src1 % 128 == 0); - hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a); -} - -static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) dst % 128 == 0); - assert((uintptr_t) src0 % 128 == 0); - hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a); +// Generic macro to define alignment permutations for an op +#define DEFINE_HVX_DIV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \ +static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u); \ +} \ +static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + OP_LOOP_BODY(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u); \ +} \ +static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src1 % 128 == 0); \ + OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u); \ +} \ +static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + OP_LOOP_BODY(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u); \ +} \ + +// Dispatcher logic +#define HVX_DIV_DISPATCHER(OP_NAME) \ +static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \ + if (hex_is_aligned((void *) dst, 128)) { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \ + else OP_NAME##_aau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \ + else OP_NAME##_auu(dst, src0, src1, num_elems); \ + } \ + } else { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \ + else OP_NAME##_uau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \ + else OP_NAME##_uuu(dst, src0, src1, num_elems); \ + } \ + } \ } -static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) dst % 128 == 0); - assert((uintptr_t) src1 % 128 == 0); - hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a); -} +DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f32, hvx_div_f32_loop_body) +DEFINE_HVX_DIV_OP_VARIANTS(hvx_div_f16, hvx_div_f16_loop_body) -static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) dst % 128 == 0); - hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a); -} - -static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) src0 % 128 == 0); - assert((uintptr_t) src1 % 128 == 0); - hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u); -} - -static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) src0 % 128 == 0); - hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u); -} - -static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((uintptr_t) src1 % 128 == 0); - hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u); -} - -static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u); -} - -static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { - if (hex_is_aligned((void *) dst, 128)) { - if (hex_is_aligned((void *) src0, 128)) { - if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems); - else hvx_div_f32_aau(dst, src0, src1, num_elems); - } else { - if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems); - else hvx_div_f32_auu(dst, src0, src1, num_elems); - } - } else { - if (hex_is_aligned((void *) src0, 128)) { - if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems); - else hvx_div_f32_uau(dst, src0, src1, num_elems); - } else { - if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems); - else hvx_div_f32_uuu(dst, src0, src1, num_elems); - } - } -} +HVX_DIV_DISPATCHER(hvx_div_f32) +HVX_DIV_DISPATCHER(hvx_div_f16) -#undef HVX_OP_MUL +#undef HVX_OP_MUL_F32 #endif // HVX_DIV_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-inverse.h b/ggml/src/ggml-hexagon/htp/hvx-inverse.h index 49f3efabbcc..f2054f45bac 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-inverse.h +++ b/ggml/src/ggml-hexagon/htp/hvx-inverse.h @@ -67,7 +67,7 @@ static inline HVX_Vector hvx_vec_inverse_f16(HVX_Vector vals) { HVX_Vector vcl0 = Q6_Vuh_vcl0_Vuh(rm); //count leading zeros - // Get mantissa for 16-bit represenation + // Get mantissa for 16-bit representation HVX_Vector mant_recip = Q6_V_vand_VV(Q6_Vh_vasr_VhR(Q6_Vh_vasl_VhVh(rm, vcl0), 5), Q6_Vh_vsplat_R(0x03FF)); //Compute Reciprocal Exponent @@ -137,40 +137,74 @@ static inline HVX_Vector hvx_vec_inverse_f32_guard(HVX_Vector v_sf, HVX_Vector n } \ } while(0) -static inline void hvx_inverse_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src % 128 == 0); - hvx_inverse_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); -} +static inline HVX_Vector hvx_vec_inverse_f16_guard(HVX_Vector v_sf, HVX_Vector nan_inf_mask) { + HVX_Vector out = hvx_vec_inverse_f16(v_sf); -static inline void hvx_inverse_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - hvx_inverse_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); -} + HVX_Vector masked_out = Q6_V_vand_VV(out, nan_inf_mask); + const HVX_VectorPred pred = Q6_Q_vcmp_eq_VhVh(nan_inf_mask, masked_out); -static inline void hvx_inverse_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - assert((unsigned long) src % 128 == 0); - hvx_inverse_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); + return Q6_V_vmux_QVV(pred, Q6_V_vzero(), out); } -static inline void hvx_inverse_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { - hvx_inverse_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); -} +#define hvx_inverse_f16_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const HVX_Vector nan_inf_mask = Q6_Vh_vsplat_R(0x7c00); \ + \ + const uint32_t nvec = n / VLEN_FP16; \ + const uint32_t nloe = n % VLEN_FP16; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask); \ + } \ + if (nloe) { \ + HVX_Vector v = hvx_vec_inverse_f16_guard(vsrc[i], nan_inf_mask); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP16, v); \ + } \ + } while(0) -static inline void hvx_inverse_f32(uint8_t * restrict dst, uint8_t * restrict src, const int num_elems) { - if ((unsigned long) dst % 128 == 0) { - if ((unsigned long) src % 128 == 0) { - hvx_inverse_f32_aa(dst, src, num_elems); - } else { - hvx_inverse_f32_au(dst, src, num_elems); - } - } else { - if ((unsigned long) src % 128 == 0) { - hvx_inverse_f32_ua(dst, src, num_elems); - } else { - hvx_inverse_f32_uu(dst, src, num_elems); - } - } +// Generic macro to define alignment permutations for an op +#define DEFINE_HVX_INV_OP_VARIANTS(OP_NAME, OP_LOOP_BODY) \ +static inline void OP_NAME##_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_Vector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + OP_LOOP_BODY(HVX_Vector, HVX_UVector, hvx_vec_store_a); \ +} \ +static inline void OP_NAME##_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \ + assert((uintptr_t) src % 128 == 0); \ + OP_LOOP_BODY(HVX_UVector, HVX_Vector, hvx_vec_store_u); \ +} \ +static inline void OP_NAME##_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { \ + OP_LOOP_BODY(HVX_UVector, HVX_UVector, hvx_vec_store_u); \ +} \ + +// Dispatcher logic +#define HVX_INV_DISPATCHER(OP_NAME) \ +static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) { \ + if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src, 128)) { \ + OP_NAME##_aa(dst, src, num_elems); \ + } else if (hex_is_aligned((void *) dst, 128)) { \ + OP_NAME##_au(dst, src, num_elems); \ + } else if (hex_is_aligned((void *) src, 128)) { \ + OP_NAME##_ua(dst, src, num_elems); \ + } else { \ + OP_NAME##_uu(dst, src, num_elems); \ + } \ } +DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f32, hvx_inverse_f32_loop_body) +DEFINE_HVX_INV_OP_VARIANTS(hvx_inverse_f16, hvx_inverse_f16_loop_body) + +HVX_INV_DISPATCHER(hvx_inverse_f32) +HVX_INV_DISPATCHER(hvx_inverse_f16) + #endif // HVX_INVERSE_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-reduce.h b/ggml/src/ggml-hexagon/htp/hvx-reduce.h index 1ca7c05d983..3c0073ef6d8 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-reduce.h +++ b/ggml/src/ggml-hexagon/htp/hvx-reduce.h @@ -46,6 +46,21 @@ static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) { #if __HVX_ARCH__ > 75 +static inline HVX_Vector hvx_vec_reduce_sum_f32x4(HVX_Vector_x4 in) { + HVX_VectorPair sum_p01 = Q6_W_vshuff_VVR(in.v[1], in.v[0], 4); + HVX_VectorPair sum_p23 = Q6_W_vshuff_VVR(in.v[3], in.v[2], 4); + HVX_Vector sum_sf01 = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p01), Q6_V_hi_W(sum_p01)); + HVX_Vector sum_sf23 = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p23), Q6_V_hi_W(sum_p23)); + + HVX_VectorPair sum_p0123 = Q6_W_vshuff_VVR(sum_sf23, sum_sf01, 8); + HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sum_p0123), Q6_V_hi_W(sum_p0123)); + + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4)); + sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8)); + return sum_sf; +} + static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) { HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4); HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump)); @@ -72,6 +87,21 @@ static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) #else +static inline HVX_Vector hvx_vec_reduce_sum_f32x4(HVX_Vector_x4 in) { + HVX_VectorPair sum_p01 = Q6_W_vshuff_VVR(in.v[1], in.v[0], 4); + HVX_VectorPair sum_p23 = Q6_W_vshuff_VVR(in.v[3], in.v[2], 4); + HVX_Vector sum_qf01 = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p01), Q6_V_hi_W(sum_p01)); + HVX_Vector sum_qf23 = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p23), Q6_V_hi_W(sum_p23)); + + HVX_VectorPair sum_p0123 = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(sum_qf23), Q6_Vsf_equals_Vqf32(sum_qf01), 8); + HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sum_p0123), Q6_V_hi_W(sum_p0123)); + + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4)); + sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8)); + return Q6_Vsf_equals_Vqf32(sum_qf); +} + static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) { HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4); HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump)); diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index a518ad37331..08343798794 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -15,4 +15,12 @@ #include "hvx-div.h" #include "hvx-base.h" +#ifndef GATHER_TYPE +# if defined(__hexagon__) +# define GATHER_TYPE(_a) (intptr_t) _a +# else +# define GATHER_TYPE(_a) (HVX_Vector *) _a +# endif +#endif + #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 92a1422896c..3f99dbb32c4 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -757,6 +757,47 @@ static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_ssm_conv_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; + + // We've written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup OP context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_ssm_conv(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_activations_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs, @@ -1142,6 +1183,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_argsort_req(ctx, &req, bufs); break; + case HTP_OP_SSM_CONV: + if (n_bufs != 3) { + FARF(ERROR, "Bad ssm-conv-req buffer list"); + continue; + } + proc_ssm_conv_req(ctx, &req, bufs); + break; + default: FARF(ERROR, "Unknown Op %u", req.op); break; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 6f6f51f01f5..9ca74aedfef 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -1234,27 +1234,24 @@ static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_VectorPair rsum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); uint32_t i = 0; #pragma unroll(4) for (i = 0; i < nvec; i++) { - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x[i], y[i]); } if (nloe) { HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]); HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + rsum_p = hvx_vec_mpyacc_f32_f16(rsum_p, x_hf, y_hf); } - rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum)); - hvx_vec_store_u(&s[0], 4, rsum); + HVX_Vector rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum_p), Q6_V_hi_W(rsum_p))); + hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum)); } static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, @@ -1267,35 +1264,30 @@ static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, uint32_t nvec = n / VLEN_FP16; uint32_t nloe = n % VLEN_FP16; - HVX_Vector rsum0 = Q6_V_vsplat_R(0); - HVX_Vector rsum1 = Q6_V_vsplat_R(0); + HVX_VectorPair rsum0_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair rsum1_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); uint32_t i = 0; #pragma unroll(2) for (i = 0; i < nvec; i++) { HVX_Vector y_hf = y[i]; - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf); - - rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); - rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); + rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0[i], y_hf); + rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1[i], y_hf); } if (nloe) { HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]); HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]); - HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); - - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); - - rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); - rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); + rsum0_p = hvx_vec_mpyacc_f32_f16(rsum0_p, x0_hf, y_hf); + rsum1_p = hvx_vec_mpyacc_f32_f16(rsum1_p, x1_hf, y_hf); } - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1)); + HVX_Vector rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum0_p), Q6_V_hi_W(rsum0_p))); + HVX_Vector rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(rsum1_p), Q6_V_hi_W(rsum1_p))); + HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(rsum0, rsum1); hvx_vec_store_u(s0, 8, rsum); } @@ -1311,10 +1303,10 @@ static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * res uint32_t nloe = n % VLEN_FP16; // Row sums (sf) - 4 accumulators for 2×2 tile - HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); - HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + HVX_VectorPair r0_c0_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair r0_c1_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair r1_c0_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); + HVX_VectorPair r1_c1_sum_p = Q6_W_vcombine_VV(Q6_V_vsplat_R(0), Q6_V_vsplat_R(0)); uint32_t i = 0; @@ -1326,20 +1318,10 @@ static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * res HVX_Vector c1_hf = y1[i]; // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 - HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf); - HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf); - HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf); - HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf); - - HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p)); - HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p)); - HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p)); - HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p)); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum)); + r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf); + r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf); + r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf); + r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf); } if (nloe) { @@ -1350,23 +1332,17 @@ static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * res HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]); HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]); - HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf); - HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf); - HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf); - HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf); - - HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p)); - HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p)); - HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p)); - HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p)); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum)); - + r0_c0_sum_p = hvx_vec_mpyacc_f32_f16(r0_c0_sum_p, r0_hf, c0_hf); + r0_c1_sum_p = hvx_vec_mpyacc_f32_f16(r0_c1_sum_p, r0_hf, c1_hf); + r1_c0_sum_p = hvx_vec_mpyacc_f32_f16(r1_c0_sum_p, r1_hf, c0_hf); + r1_c1_sum_p = hvx_vec_mpyacc_f32_f16(r1_c1_sum_p, r1_hf, c1_hf); } + HVX_Vector r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c0_sum_p), Q6_V_hi_W(r0_c0_sum_p))); + HVX_Vector r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r0_c1_sum_p), Q6_V_hi_W(r0_c1_sum_p))); + HVX_Vector r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c0_sum_p), Q6_V_hi_W(r1_c0_sum_p))); + HVX_Vector r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(r1_c1_sum_p), Q6_V_hi_W(r1_c1_sum_p))); + // Reduce and store results HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index aa6a6c9008d..be9469538f6 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -18,7 +18,7 @@ #include "htp-msg.h" #include "htp-ops.h" -// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we cant include ggml.h +// Redefined the types GGML_ROPE_TYPE_NORMAL & GGML_ROPE_TYPE_NEOX as we can't include ggml.h #define HTP_ROPE_TYPE_NORMAL 0 #define HTP_ROPE_TYPE_NEOX 2 @@ -400,7 +400,9 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - const uint32_t n_threads = octx->n_threads; + const uint32_t ne0 = dst->ne[0]; + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); const size_t src0_row_size = src0->nb[1]; const size_t dst_row_size = dst->nb[1]; @@ -465,17 +467,14 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { rctx.dst_row_size_aligned = dst_row_size_aligned; rctx.theta_cache_offset = theta_cache_size_aligned; - uint32_t ne0 = dst->ne[0]; - uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; rctx.src0_nrows = src0_nrows; + rctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0, rctx.ext_factor, rctx.theta_scale, rctx.attn_factor); if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - rctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_threads); } return err; diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c index 2fd6c907724..4b6967749f8 100644 --- a/ggml/src/ggml-hexagon/htp/set-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -128,6 +128,8 @@ static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *da int op_set_rows(struct htp_ops_context * octx) { set_rows_preamble; + const uint32_t n_threads = MIN(nr, octx->n_threads); + if (octx->src0.type != HTP_TYPE_F32) { return HTP_STATUS_NO_SUPPORT; } @@ -149,15 +151,14 @@ int op_set_rows(struct htp_ops_context * octx) { srctx.div_ne12 = init_fastdiv_values(ne12); srctx.div_ne11 = init_fastdiv_values(ne11); - const uint32_t n_jobs = MIN(nr, octx->n_threads); - srctx.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + srctx.src0_nrows_per_thread = (nr + n_threads - 1) / n_threads; switch(octx->dst.type) { case HTP_TYPE_F32: - worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_threads); break; case HTP_TYPE_F16: - worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_threads); break; default: return HTP_STATUS_NO_SUPPORT; diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index 6e22eb6a639..8dae7f1ed55 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -353,7 +353,8 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - const uint32_t n_threads = octx->n_threads; + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); const size_t src0_row_size = src0->nb[1]; const size_t src1_row_size = src0_row_size; @@ -393,12 +394,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; - uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - smctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_jobs); + smctx.src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads; + worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_threads); } return err; diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c new file mode 100644 index 00000000000..b3c1ef9572e --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -0,0 +1,339 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "hex-dma.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" + +#define htp_ssm_conv_tensors_preamble \ + struct htp_tensor * restrict src0 = &octx->src0; \ + struct htp_tensor * restrict src1 = &octx->src1; \ + struct htp_tensor * restrict dst = &octx->dst; \ + struct htp_spad * restrict src0_spad = &octx->src0_spad; \ + struct htp_spad * restrict src1_spad = &octx->src1_spad; \ + struct htp_spad * restrict dst_spad = &octx->dst_spad; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; + +struct htp_ssm_conv_context { + struct htp_ops_context * octx; + uint32_t nrows_per_thread; + uint64_t t_start; +}; + +#define htp_ssm_conv_preamble \ + struct htp_ssm_conv_context * scctx = (struct htp_ssm_conv_context *) data; \ + struct htp_ops_context * octx = scctx->octx; \ + htp_ssm_conv_tensors_preamble; \ + dma_queue * dma_queue = octx->ctx->dma[ith]; + +// Scalar FP32 SSM_CONV implementation +static void ssm_conv_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { + htp_ssm_conv_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t d_conv = src1->ne[0]; + const uint32_t d_inner = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; + const uint32_t n_s = dst->ne[2]; + + const uint32_t src0_stride_inner = src0->nb[1] / sizeof(float); // stride for inner dimension + const uint32_t src0_stride_seq = src0->nb[2] / sizeof(float); // stride for sequence dimension + const uint32_t src1_stride_inner = src1->nb[1] / sizeof(float); // stride for inner dimension + const uint32_t dst_stride_token = dst->nb[1] / sizeof(float); // stride for token dimension + const uint32_t dst_stride_seq = dst->nb[2] / sizeof(float); // stride for sequence dimension + + const float * src0_data = (const float *) src0->data; + const float * src1_data = (const float *) src1->data; + float * dst_data = (float *) dst->data; + + // Calculate row range for this thread + const uint32_t d_inner_per_thread = scctx->nrows_per_thread; + const uint32_t d_inner_start = d_inner_per_thread * ith; + const uint32_t d_inner_end = MIN(d_inner_start + d_inner_per_thread, d_inner); + + // No work for this thread + if (d_inner_start >= d_inner_end) { + return; + } + + for (uint32_t i3 = 0; i3 < n_s; ++i3) { + for (uint32_t i2 = 0; i2 < n_t; ++i2) { + for (uint32_t i1 = d_inner_start; i1 < d_inner_end; ++i1) { + float sumf = 0.0f; + + for (uint32_t i0 = 0; i0 < d_conv; ++i0) { + const uint32_t src0_idx = (i2 + i0) + i1 * src0_stride_inner + i3 * src0_stride_seq; + const uint32_t src1_idx = i0 + i1 * src1_stride_inner; + + sumf += src0_data[src0_idx] * src1_data[src1_idx]; + } + + const uint32_t dst_idx = i1 + i2 * dst_stride_token + i3 * dst_stride_seq; + dst_data[dst_idx] = sumf; + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "ssm-conv-f32 %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], d_inner_start, d_inner_end, + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], + dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// HVX FP32 SSM_CONV implementation - vectorizes across d_inner dimension +static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void *data) { + htp_ssm_conv_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const int nc = src1->ne[0]; // d_conv + const int ncs = src0->ne[0]; // d_conv - 1 + n_t + + const uint32_t d_conv = src1->ne[0]; + const uint32_t d_inner = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; + const uint32_t n_s = dst->ne[2]; + + const float * src0_data = (const float *) src0->data; + const float * src1_data = (const float *) src1->data; + float * dst_data = (float *) dst->data; + + // Calculate row range for this thread + const int dr = scctx->nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = MIN(ir0 + dr, d_inner); + const int ir = ir1 - ir0; + + if (ir0 >= ir1) { + return; // No work for this thread + } + + // src0 and src1 gather offsets + uint32_t __attribute__((aligned(VLEN))) src0_offsets[VLEN_FP32] = { 0 }; + uint32_t __attribute__((aligned(VLEN))) src1_offsets[VLEN_FP32] = { 0 }; + + for (uint32_t i = 0; i < VLEN_FP32; ++i) { + src0_offsets[i] = i * (ncs) * sizeof(float); + src1_offsets[i] = i * (d_conv) * sizeof(float); + } + + const uint32_t src0_gather_len = VLEN * ncs; + const uint32_t src1_gather_len = VLEN * d_conv; + + // gather scratchpads + HVX_Vector * src0_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + 0); + HVX_Vector * src1_vec = (HVX_Vector *) (octx->ctx->vtcm_base + ith * VLEN*2 + VLEN); + + float * data_src0 = (float *) ((char *) src0->data + ir0 * src0->nb[1]); + float * data_src1 = (float *) ((char *) src1->data + ir0 * src1->nb[1]); + + uint8_t * spad_src0 = octx->src0_spad.data + ith * octx->src0_spad.size_per_thread; + uint8_t * spad_src1 = octx->src1_spad.data + ith * octx->src1_spad.size_per_thread; + + // copy src1 workload to VTCM + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src1, data_src1), nb11, nb11, ir); + + // FARF(HIGH, "ssm-conv-src1-fetch %d: ir0 %u size %u\n", ith, ir0, nb11 * ir); + + for (uint32_t i3 = 0; i3 < n_s; ++i3) { + float * src0_data_ptr = (float *) ((char *) data_src0 + i3 * (src0->nb[2])); + + // copy src0 workload to VTCM + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0, src0_data_ptr), nb01, nb01, ir); + + // FARF(HIGH, "ssm-conv-src0-fetch %d: ir0 %u i3 %u size %u\n", ith, ir0, i3, nb01 * ir); + + dma_queue_flush(dma_queue); + + for (uint32_t i2 = 0; i2 < n_t; ++i2) { + float * dst_ptr = (float *) ((char *) dst->data + ir0 * (dst->nb[0]) + i2 * (dst->nb[1]) + i3 * (dst->nb[2])); + + const uint32_t nvec = ir / VLEN_FP32; + const uint32_t nloe = ir % VLEN_FP32; + uint32_t i1 = 0; + + for (uint32_t vi1 = 0; vi1 < nvec; vi1++) { + HVX_Vector acc_vec = Q6_V_vsplat_R(0); + + for (uint32_t i0 = 0; i0 < d_conv; ++i0) { + Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])), + src0_gather_len, (*(const HVX_Vector *) src0_offsets)); + Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)), + src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + + HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); + acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); + } + + *(HVX_UVector *) (dst_ptr + i1) = Q6_Vsf_equals_Vqf32(acc_vec); + i1 += VLEN_FP32; + } + + if (nloe) { + HVX_Vector acc_vec = Q6_V_vsplat_R(0); + + for (uint32_t i0 = 0; i0 < d_conv; ++i0) { + Q6_vgather_ARMVw(src0_vec, GATHER_TYPE(spad_src0 + (i0 + i1 * ncs) * sizeof(float) + i2 * (src0->nb[0])), + src0_gather_len, (*(const HVX_Vector *) src0_offsets)); + Q6_vgather_ARMVw(src1_vec, GATHER_TYPE(spad_src1 + (i0 + i1 * nc) * sizeof(float)), + src1_gather_len, (*(const HVX_Vector *) src1_offsets)); + + HVX_Vector prod = Q6_Vqf32_vmpy_VsfVsf(*(const HVX_Vector *) src0_vec, *(const HVX_Vector *) src1_vec); + acc_vec = Q6_Vqf32_vadd_Vqf32Vqf32(acc_vec, prod); + } + + hvx_vec_store_u(dst_ptr + i1, (ir - i1) * 4, Q6_Vsf_equals_Vqf32(acc_vec)); + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "ssm-conv-f32-hvx %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", + ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0, ir1, + src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], + dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +int op_ssm_conv_f32(struct htp_ops_context * octx) { + htp_ssm_conv_tensors_preamble; + + if (src0->type != HTP_TYPE_F32 || src1->type != HTP_TYPE_F32 || dst->type != HTP_TYPE_F32) { + FARF(ERROR, "ssm_conv: only (F32 x F32 -> F32) OPs supported"); + return HTP_STATUS_NO_SUPPORT; + } + + struct htp_ssm_conv_context scctx = { 0 }; + scctx.octx = octx; + + const uint32_t d_conv = src1->ne[0]; + const uint32_t d_inner = src0->ne[1]; + const uint32_t n_t = dst->ne[1]; // tokens per sequence + const uint32_t n_s = dst->ne[2]; // number of sequences in the batch + + const uint32_t n_threads = MIN(octx->n_threads, d_inner); + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + uint32_t use_hvx = 0; + if (d_inner >= VLEN_FP32 && d_inner % VLEN_FP32 == 0) { + int is_aligned = hex_is_aligned((void *) src0->data, VLEN) && + hex_is_aligned((void *) src1->data, VLEN) && + hex_is_aligned((void *) dst->data, VLEN); + + if (is_aligned) { + use_hvx = 1; + } + } + + if (use_hvx) { + scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; // d_inner chunks per thread + scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); // round up to even + + octx->src0_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb01, 256); + octx->src1_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * nb11, 256); + octx->dst_spad.size_per_thread = hex_round_up(scctx.nrows_per_thread * sizeof(float), 256); + + octx->src0_spad.size = octx->src0_spad.size_per_thread * n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * n_threads; + + // Compute gather scratchpad size for src0 and src1 + const size_t gather_spad_size = n_threads * VLEN * 2; + + octx->src0_spad.data = octx->ctx->vtcm_base + gather_spad_size; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + + FARF(HIGH, "ssm_conv-f32: gather-spad:%zu spad-per-thread:(%u:%u:%u) spad-sizes:(%u:%u:%u) spad-data:(%p:%p:%p)\n", + gather_spad_size, octx->src0_spad.size_per_thread, octx->src1_spad.size_per_thread, + octx->dst_spad.size_per_thread, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, + octx->src0_spad.data, octx->src1_spad.data, octx->dst_spad.data); + + const size_t total_spad_size = + gather_spad_size + octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + + if (total_spad_size > octx->ctx->vtcm_size) { + FARF(HIGH, "ssm_conv-f32: HVX scratchpad size %zu exceeds VTCM size %zu", total_spad_size, + octx->ctx->vtcm_size); + use_hvx = 0; + } + } + + FARF(HIGH, "ssm-conv-f32: (%ux%ux%ux%u) x (%ux%ux%ux%u) -> (%ux%ux%ux%u) : use_hvx %d\n", src0->ne[0], + src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], + dst->ne[1], dst->ne[2], dst->ne[3], use_hvx); + + if (use_hvx) { + worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32_hvx, &scctx, n_threads); + } else { + worker_pool_run_func(octx->ctx->worker_pool, ssm_conv_thread_f32_f32, &scctx, n_threads); + } + } + + return HTP_STATUS_OK; +} + +int op_ssm_conv(struct htp_ops_context * octx) { + int err = HTP_STATUS_OK; + struct htp_tensor * dst = &octx->dst; + + switch (dst->type) { + case HTP_TYPE_F32: + err = op_ssm_conv_f32(octx); + break; + default: + err = HTP_STATUS_NO_SUPPORT; + break; + } + + return err; +} diff --git a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c index 04fa72182a3..352650b689b 100644 --- a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c @@ -102,11 +102,9 @@ int op_sum_rows(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - const int n_threads = octx->n_threads; const uint32_t src0_nrows = ne01 * ne02 * ne03; - - uint32_t n_jobs = MIN(n_threads, src0_nrows); - uint32_t rows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); + const uint32_t rows_per_thread = (src0_nrows + n_threads - 1) / n_threads; bool opt_path = false; if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) { @@ -124,7 +122,7 @@ int op_sum_rows(struct htp_ops_context * octx) { .opt_path = opt_path, }; - worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_threads); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 98135c50ab8..5bbd5040d3d 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -301,8 +301,8 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } - const int n_threads = octx->n_threads; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t n_threads = MIN(octx->n_threads, src0_nrows); const size_t src0_row_size = src0->nb[1]; const size_t dst_row_size = dst->nb[1]; @@ -338,11 +338,9 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - struct htp_unary_context uctx = { .octx = octx, - .src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs, + .src0_nrows_per_thread = (src0_nrows + n_threads - 1) / n_threads, .src0_nrows = src0_nrows, .data_src0 = (const uint8_t *)src0->data, @@ -361,7 +359,7 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { .nc = src0->ne[0], }; - worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_threads); } return err; diff --git a/ggml/src/ggml-hexagon/htp/worker-pool.c b/ggml/src/ggml-hexagon/htp/worker-pool.c index 894815f46a5..172e28908eb 100644 --- a/ggml/src/ggml-hexagon/htp/worker-pool.c +++ b/ggml/src/ggml-hexagon/htp/worker-pool.c @@ -56,7 +56,7 @@ static void worker_pool_main(void * context) { unsigned int n = atomic_load(&pool->n_jobs); unsigned int i = atomic_fetch_add(&pool->next_job, 1); if (i >= n) { - // Spurios wakeup + // Spurious wakeup continue; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 3db7f126291..4cce414abfe 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1281,7 +1281,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te bool use_residency_sets; // optional MTLResidencySet - // note: cannot use explicity "id" here because it is not available on certain OSes + // note: cannot use explicitly "id" here because it is not available on certain OSes id rset; // pointers to global device diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 3d5db0b79f5..b3390352ffc 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -631,7 +631,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { const bool inplace = (bool) ((const int32_t *) op->op_params)[4]; if (!inplace) { - // run a separete kernel to cpy src->dst + // run a separate kernel to cpy src->dst // not sure how to avoid this // TODO: make a simpler cpy_bytes kernel @@ -1644,7 +1644,7 @@ int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) { const bool inplace = (bool) ((const int32_t *) op->op_params)[4]; if (!inplace) { - // run a separete kernel to cpy src->dst + // run a separate kernel to cpy src->dst // not sure how to avoid this // TODO: make a simpler cpy_bytes kernel @@ -2005,7 +2005,7 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { const int16_t r0ptg = nypsg*nsg; // num src0 rows per threadgroup int16_t r1ptg = 4; // num src1 rows per threadgroup - // note: not sure how optimal are those across all different hardware. there might be someting cleverer + // note: not sure how optimal are those across all different hardware. there might be something cleverer switch (ne11) { case 2: r1ptg = 2; break; diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp index 1c705362fb7..9382ce53b36 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -14,7 +14,7 @@ #define GGML_METAL_MAX_DEVICES 16 // number of Metal devices -// note: can be overriden with GGML_METAL_DEVICES env to simulate virtual devices +// note: can be overridden with GGML_METAL_DEVICES env to simulate virtual devices static int g_devices = 1; //////////////////////////////////////////////////////////////////////////////// diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 6c349aa0c92..a58e641ad86 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -4218,7 +4218,7 @@ kernel void kernel_im2col( template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col; template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col; -// TODO: obolete -- remove +// TODO: obsolete -- remove //typedef void (im2col_ext_t)( // constant ggml_metal_kargs_im2col & args, // device const float * x, diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index f3891936911..fb3ae17eaf4 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -63,6 +63,7 @@ set(GGML_OPENCL_KERNELS cpy cvt diag_mask_inf + diag div gelu gemv_noshuffle_general @@ -108,8 +109,11 @@ set(GGML_OPENCL_KERNELS mul_mm_q8_0_f32_l4_lm mul_mm_q6_k_f32_l4_lm mul_mm_q8_0_f32_8x4 + gemv_noshuffle_q4_1_f32 + gemm_noshuffle_q4_1_f32 gemv_noshuffle_general_q8_0_f32 mul + neg norm relu rms_norm @@ -132,6 +136,7 @@ set(GGML_OPENCL_KERNELS tsembd upscale tanh + exp expm1 softplus pad diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 3da022ed86c..0a2c86c6e22 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -313,7 +313,7 @@ struct ProfilingInfo { cl_ulong cmd_duration_ns; // The time for the kernel to complete - COMPLETE - END cl_ulong cmd_complete_duration_ns; - // Total time to finish the kernel - COMPELTE - QUEUED + // Total time to finish the kernel - COMPLETE - QUEUED cl_ulong cmd_total_duration_ns; // Global and local work sizes. size_t global_size[3]; @@ -416,7 +416,6 @@ struct ggml_backend_opencl_context { cl_program program_add; cl_program program_add_id; cl_program program_clamp; - cl_program program_cpy; cl_program program_cvt; cl_program program_diag_mask_inf; cl_program program_gelu; @@ -500,6 +499,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_rms_norm, kernel_rms_norm_mul; cl_kernel kernel_group_norm, kernel_group_norm_mul_add; cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8; + cl_kernel kernel_diag_f32; cl_kernel kernel_soft_max, kernel_soft_max_4; cl_kernel kernel_soft_max_f16, kernel_soft_max_4_f16; std::map, cl_kernel> kernels_flash_attn_f16; @@ -514,7 +514,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_set_rows_f32_i64, kernel_set_rows_f32_i32, kernel_set_rows_f16_i64, kernel_set_rows_f16_i32; cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16; cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16; - cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32; + cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32, kernel_cpy_i32_i32; cl_kernel kernel_mul_mat_f32_f32; cl_kernel kernel_mul_mat_f16_f16; cl_kernel kernel_mul_mat_f16_f32_1row; @@ -531,6 +531,8 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; cl_kernel kernel_convert_block_q4_0_noshuffle; cl_kernel kernel_restore_block_q4_0_noshuffle; + cl_kernel kernel_convert_block_q4_1_noshuffle; + cl_kernel kernel_restore_block_q4_1_noshuffle; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; cl_kernel kernel_mul_mv_q4_1_f32; @@ -548,6 +550,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_pad; cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc; cl_kernel kernel_tanh_f16, kernel_tanh_f16_4, kernel_tanh_f16_nc; + cl_kernel kernel_neg_f32, kernel_neg_f32_4, kernel_neg_f32_nc; + cl_kernel kernel_neg_f16, kernel_neg_f16_4, kernel_neg_f16_nc; + cl_kernel kernel_exp_f32, kernel_exp_f32_4, kernel_exp_f32_nc; + cl_kernel kernel_exp_f16, kernel_exp_f16_4, kernel_exp_f16_nc; cl_kernel kernel_expm1_f32, kernel_expm1_f32_4, kernel_expm1_f32_nc; cl_kernel kernel_expm1_f16, kernel_expm1_f16_4, kernel_expm1_f16_nc; cl_kernel kernel_softplus_f32, kernel_softplus_f32_4, kernel_softplus_f32_nc; @@ -683,7 +689,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_transpose_32; cl_kernel kernel_transpose_32_16; cl_kernel kernel_transpose_16; + cl_kernel kernel_transpose_8_buf; cl_kernel kernel_transpose_16_buf; + cl_kernel kernel_transpose_32_buf; cl_kernel kernel_transpose_16_4x1; // Gemm and Gemv related programs, kernels, etc @@ -699,6 +707,8 @@ struct ggml_backend_opencl_context { cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_4096_1_4096; cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_11008_1_4096; cl_kernel CL_mul_mat_vec_q4_0_f32_1d_4x_flat_32000_1_4096; + cl_kernel kernel_gemv_noshuffle_q4_1_f32; + cl_kernel kernel_gemm_noshuffle_q4_1_f32; cl_kernel kernel_mul_mm_q8_0_f32_8x4; cl_kernel CL_mul_mat_vec_q8_0_f32; #endif // GGML_OPENCL_USE_ADRENO_KERNELS @@ -867,13 +877,14 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #else const std::string kernel_src = read_file("cpy.cl"); #endif - backend_ctx->program_cpy = + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f16_f32", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f16", &err), err)); - CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(backend_ctx->program_cpy, "kernel_cpy_f32_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f16_f16 = clCreateKernel(prog, "kernel_cpy_f16_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f16_f32 = clCreateKernel(prog, "kernel_cpy_f16_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f16 = clCreateKernel(prog, "kernel_cpy_f32_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_f32_f32 = clCreateKernel(prog, "kernel_cpy_f32_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_cpy_i32_i32 = clCreateKernel(prog, "kernel_cpy_i32_i32", &err), err)); GGML_LOG_CONT("."); } @@ -893,6 +904,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1_noshuffle", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_1_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); @@ -924,6 +937,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // diag + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "diag.cl.h" + }; +#else + const std::string kernel_src = read_file("diag.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_diag_f32 = clCreateKernel(prog, "kernel_diag_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // gelu { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1971,6 +2001,48 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // neg + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "neg.cl.h" + }; +#else + const std::string kernel_src = read_file("neg.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_neg_f32 = clCreateKernel(prog, "kernel_neg_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_neg_f32_4 = clCreateKernel(prog, "kernel_neg_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_neg_f32_nc = clCreateKernel(prog, "kernel_neg_f32_nc", &err), err)); + CL_CHECK((backend_ctx->kernel_neg_f16 = clCreateKernel(prog, "kernel_neg_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_neg_f16_4 = clCreateKernel(prog, "kernel_neg_f16_4", &err), err)); + CL_CHECK((backend_ctx->kernel_neg_f16_nc = clCreateKernel(prog, "kernel_neg_f16_nc", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // exp + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "exp.cl.h" + }; +#else + const std::string kernel_src = read_file("exp.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_exp_f32 = clCreateKernel(prog, "kernel_exp_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_exp_f32_4 = clCreateKernel(prog, "kernel_exp_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_exp_f32_nc = clCreateKernel(prog, "kernel_exp_f32_nc", &err), err)); + CL_CHECK((backend_ctx->kernel_exp_f16 = clCreateKernel(prog, "kernel_exp_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_exp_f16_4 = clCreateKernel(prog, "kernel_exp_f16_4", &err), err)); + CL_CHECK((backend_ctx->kernel_exp_f16_nc = clCreateKernel(prog, "kernel_exp_f16_nc", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // expm1 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2258,7 +2330,9 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_transpose_32_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_16", &err), err)); CL_CHECK((backend_ctx->kernel_transpose_32 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32", &err), err)); CL_CHECK((backend_ctx->kernel_transpose_16 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_8_buf = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_8_buf", &err), err)); CL_CHECK((backend_ctx->kernel_transpose_16_buf = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_buf", &err), err)); + CL_CHECK((backend_ctx->kernel_transpose_32_buf = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_32_buf", &err), err)); CL_CHECK((backend_ctx->kernel_transpose_16_4x1 = clCreateKernel(backend_ctx->program_transpose, "kernel_transpose_16_4x1", &err), err)); GGML_LOG_CONT("."); } @@ -2378,6 +2452,45 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // gemm_noshuffle_q4_1_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemm_noshuffle_q4_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemm_noshuffle_q4_1_f32.cl"); +#endif + cl_program prog = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_gemm_noshuffle_q4_1_f32 = clCreateKernel(prog, "kernel_gemm_noshuffle_q4_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // gemv_noshuffle_q4_1_f32 + { + std::string CL_gemv_compile_opts = std::string("-cl-std=") + opencl_c_std + + " -cl-mad-enable "; + if (backend_ctx->has_vector_subgroup_broadcast) { + CL_gemv_compile_opts += " -DVECTOR_SUB_GROUP_BROADCAT "; + } + +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "gemv_noshuffle_q4_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("gemv_noshuffle_q4_1_f32.cl"); +#endif + + cl_program prog = build_program_from_source( + backend_ctx->context, backend_ctx->device, kernel_src.c_str(), CL_gemv_compile_opts); + + CL_CHECK((backend_ctx->kernel_gemv_noshuffle_q4_1_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q4_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_q8_0_f32_8x4 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2413,7 +2526,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve cl_program prog = build_program_from_source( backend_ctx->context, backend_ctx->device, kernel_src_CL_gemv_general.c_str(), CL_gemv_compile_opts); - CL_CHECK((backend_ctx->CL_mul_mat_vec_q8_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle", &err), err)); + CL_CHECK((backend_ctx->CL_mul_mat_vec_q8_0_f32 = clCreateKernel(prog, "kernel_gemv_noshuffle_q8_0_f32", &err), err)); CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); } @@ -2506,7 +2619,7 @@ static std::vector ggml_opencl_probe_devices(ggml_backend_r cl_platform_id platform_ids[NPLAT]; if (clGetPlatformIDs(NPLAT, platform_ids, &n_platforms) != CL_SUCCESS) { - GGML_LOG_ERROR("ggml_opencl: plaform IDs not available.\n"); + GGML_LOG_ERROR("ggml_opencl: platform IDs not available.\n"); return found_devices; } @@ -2923,6 +3036,82 @@ static void ggml_cl2_free(ggml_backend_t backend) { } } +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS +static void transpose_2d( + ggml_backend_opencl_context * backend_ctx, + cl_kernel kernel, + cl_mem src, cl_mem dst, size_t size, + cl_int stride, cl_int rows, + bool blocking = true +) { + static ggml_cl_buffer buf; + + cl_event evt; + cl_int err; + + buf.allocate(backend_ctx->context, size); + + cl_mem trans; + cl_buffer_region region; + + region.origin = 0; + region.size = size; + CL_CHECK((trans = clCreateSubBuffer( + buf.buffer, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &src)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_int), &stride)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_int), &rows)); + + size_t local_size[3] = {64, 1, 1}; + size_t global_size[3] = {(size_t)stride, (size_t)rows, 1};; + CL_CHECK(clEnqueueNDRangeKernel(backend_ctx->queue, kernel, 3, NULL, + global_size, local_size, 0, NULL, NULL)); + + if (blocking) { + CL_CHECK(clEnqueueCopyBuffer(backend_ctx->queue, trans, dst, 0, 0, size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseEvent(evt)); + } else { + CL_CHECK(clEnqueueCopyBuffer(backend_ctx->queue, trans, dst, 0, 0, size, 0, NULL, NULL)); + } + + CL_CHECK(clReleaseMemObject(trans)); +} + +static void transpose_2d_as_8b( + ggml_backend_opencl_context * backend_ctx, + cl_mem src, cl_mem dst, size_t size, + cl_int stride, cl_int rows, + bool blocking = true +) { + transpose_2d(backend_ctx, backend_ctx->kernel_transpose_8_buf, + src, dst, size, stride, rows, blocking); +} + +static void transpose_2d_as_16b( + ggml_backend_opencl_context * backend_ctx, + cl_mem src, cl_mem dst, size_t size, + cl_int stride, cl_int rows, + bool blocking = true +) { + transpose_2d(backend_ctx, backend_ctx->kernel_transpose_16_buf, + src, dst, size, stride, rows, blocking); +} + +static void transpose_2d_as_32b( + ggml_backend_opencl_context * backend_ctx, + cl_mem src, cl_mem dst, size_t size, + cl_int stride, cl_int rows, + bool blocking = true +) { + transpose_2d(backend_ctx, backend_ctx->kernel_transpose_32_buf, + src, dst, size, stride, rows, blocking); +} +#endif // GGML_OPENCL_USE_ADRENO_KERNELS + //------------------------------------------------------------------------------ // Tensor extra management //------------------------------------------------------------------------------ @@ -3214,7 +3403,7 @@ static void ggml_backend_opencl_synchronize(ggml_backend_t backend) { CL_CHECK(clReleaseEvent(evt)); } -// Syncronizes the 'backend_ctx's device with others so that commands +// Synchronizes the 'backend_ctx's device with others so that commands // enqueued to it won't start until commands in the other devices have // completed. static void sync_with_other_backends(ggml_backend_opencl_context * backend_ctx) { @@ -3419,9 +3608,21 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te default: return false; } + case GGML_TYPE_I32: + switch (op->type) { + case GGML_TYPE_I32: + return true; + default: + return false; + } default: return false; } + case GGML_OP_SET: { + return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_I32) && + op->type == op->src[0]->type && + op->type == op->src[1]->type; + } case GGML_OP_SCALE: return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); case GGML_OP_ADD: @@ -3455,6 +3656,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_UNARY_OP_SIGMOID: return ggml_is_contiguous(op->src[0]); case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_EXP: return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_UNARY_OP_EXPM1: return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; @@ -3540,6 +3743,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_PERMUTE: case GGML_OP_TRANSPOSE: return true; + case GGML_OP_DIAG: + return true; case GGML_OP_DIAG_MASK_INF: return op->ne[3] == 1; case GGML_OP_ROPE: { @@ -3860,7 +4065,7 @@ struct ggml_backend_opencl_buffer_context { // The buffer_context is initially created by ggml_backend_buft_alloc_buffer // before any tensor is initialized (at the beginning of alloc_tensor_range). - // Hence, there is alway a buffer object in this vector. When each tensor is + // Hence, there is always a buffer object in this vector. When each tensor is // being initialized, this original buffer object will be released if both // flattening and small allocation are enabled, and additional buffer // objects will be created in init_tensor to represent flattened quantized @@ -3995,7 +4200,7 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, //GGML_ASSERT(offset == 0); // We create subbuffers from the original tensor buffer for scales and - // quants - i.e., scales and quants are aliases into the buffer obejct + // quants - i.e., scales and quants are aliases into the buffer object // that backs the original tensor. This is a cleaner way to adapt to the // new memory management. // In the old code, we allocate new buffers for scales and quants @@ -4271,7 +4476,15 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); CL_CHECK(err); + #ifdef GGML_OPENCL_USE_ADRENO_KERNELS cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1; + + if (use_adreno_kernels(backend_ctx, tensor)) { + kernel = backend_ctx->kernel_convert_block_q4_1_noshuffle; + } + #else + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1; + #endif // GGML_OPENCL_USE_ADRENO_KERNELS CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); @@ -4287,6 +4500,22 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, tensor->extra = extra; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + + int M = tensor->ne[1]; + int K = tensor->ne[0]; + + GGML_ASSERT(K % 32 == 0); + + // Transpose q as ushort + transpose_2d_as_16b(backend_ctx, extra->q, extra->q, size_q, K/4, M); + // Transpose d as ushort + transpose_2d_as_16b(backend_ctx, extra->d, extra->d, size_d, K/32, M); + // Transpose m as ushort + transpose_2d_as_16b(backend_ctx, extra->m, extra->m, size_m, K/32, M); + } +#endif // GGML_OPENCL_USE_ADRENO_KERNELS return; } if (tensor->type == GGML_TYPE_MXFP4) { @@ -4795,6 +5024,53 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, if (tensor->type == GGML_TYPE_Q4_1) { ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra; +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + if (use_adreno_kernels(backend_ctx, tensor)) { + static ggml_cl_buffer buf_trans_q; + static ggml_cl_buffer buf_trans_m; + static ggml_cl_buffer buf_trans_d; + static ggml_cl_buffer buf_unpacked; + + cl_int M = tensor->ne[1]; + cl_int K = tensor->ne[0]; + + GGML_ASSERT(K % ggml_blck_size(tensor->type) == 0); + + size_t size_q = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*ggml_blck_size(tensor->type)/2; + size_t size_d = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + size_t size_m = (ggml_nelements(tensor)/ggml_blck_size(tensor->type))*sizeof(ggml_fp16_t); + GGML_ASSERT(size_d + size_q + size_m == ggml_nbytes(tensor) && "Incorrect tensor size"); + + buf_trans_q.allocate(backend_ctx->context, size_q); + buf_trans_m.allocate(backend_ctx->context, size_m); + buf_trans_d.allocate(backend_ctx->context, size_d); + buf_unpacked.allocate(backend_ctx->context, ggml_nbytes(tensor)); + + // transpose q, d, m back + transpose_2d_as_16b(backend_ctx, extra->q, buf_trans_q.buffer, size_q, M, K/4); + transpose_2d_as_16b(backend_ctx, extra->d, buf_trans_d.buffer, size_d, M, K/32); + transpose_2d_as_16b(backend_ctx, extra->m, buf_trans_m.buffer, size_m, M, K/32); + + cl_uchar mask_0F = 0x0F; + cl_uchar mask_F0 = 0xF0; + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1_noshuffle; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &buf_trans_q.buffer)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &buf_trans_d.buffer)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &buf_trans_m.buffer)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &buf_unpacked.buffer)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_uchar), &mask_0F)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_uchar), &mask_F0)); + + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, NULL)); + CL_CHECK(clEnqueueReadBuffer(queue, buf_unpacked.buffer, CL_TRUE, offset, size, data, 0, NULL, NULL)); + return; + } +#endif + cl_int err; cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, ggml_nbytes(tensor), NULL, &err); @@ -4886,8 +5162,8 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, int ne00 = tensor->ne[0]; int ne01 = tensor->ne[1]; - GGML_ASSERT(tensor->ne[2] == 1); // ??? - GGML_ASSERT(tensor->ne[3] == 1); // ??? + GGML_ASSERT(tensor->ne[2] == 1); + GGML_ASSERT(tensor->ne[3] == 1); CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); @@ -5069,7 +5345,8 @@ static const char * ggml_backend_opencl_device_get_description(ggml_backend_dev_ } static void ggml_backend_opencl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { - *free = 0; + // no memory to report + *free = 0; *total = 0; GGML_UNUSED(dev); @@ -7373,6 +7650,170 @@ static void ggml_cl_tanh(ggml_backend_t backend, const ggml_tensor * src0, const } } +static void ggml_cl_neg(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); + + cl_kernel kernel; + + if (ggml_is_contiguous(src0)) { + // Handle contiguous input + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_neg_f32_4; + } else { + kernel = backend_ctx->kernel_neg_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_neg_f32; + } else { + kernel = backend_ctx->kernel_neg_f16; + } + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &n)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n, 64)*64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } else { + // Handle non-contiguous input + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_neg_f32_nc; + } else { + kernel = backend_ctx->kernel_neg_f16_nc; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3)); + + int nth = 64; + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } +} + +static void ggml_cl_exp(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); + + cl_kernel kernel; + + if (ggml_is_contiguous(src0)) { + // Handle contiguous input + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_exp_f32_4; + } else { + kernel = backend_ctx->kernel_exp_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_exp_f32; + } else { + kernel = backend_ctx->kernel_exp_f16; + } + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_int), &n)); + + size_t global_work_size[] = {(size_t)CEIL_DIV(n, 64)*64, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } else { + // Handle non-contiguous input + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_exp_f32_nc; + } else { + kernel = backend_ctx->kernel_exp_f16_nc; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3)); + + int nth = 64; + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } +} + static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -8371,6 +8812,180 @@ static void ggml_cl_mul_mat_kq_kqv_adreno(ggml_backend_t backend, const ggml_ten CL_CHECK(clReleaseMemObject(D_sub_buffer)); } +static void ggml_cl_mul_mat_q4_1_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { +#ifdef GGML_OPENCL_USE_ADRENO_KERNELS + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + + const int ne1 = dst->ne[1]; + + GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0); + + cl_context context = backend_ctx->context; + cl_kernel kernel; + + cl_int err; + cl_image_format img_fmt; + cl_image_desc img_desc; + cl_buffer_region region; + + int M = ne01; + int N = ne1; + int K = ne00; + + if (ne1 == 1) { + cl_mem q_img = nullptr; + cl_mem b_sub_buf = nullptr; + cl_mem b_img = nullptr; + + // image for q + img_fmt = { CL_R, CL_UNSIGNED_INT32}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = M * K / 2 / 4; + img_desc.buffer = extra0_q4_1->q; + CL_CHECK((q_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + kernel = backend_ctx->kernel_gemv_noshuffle_q4_1_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &q_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne01)); + + size_t local_work_size[3] = {64, 4, 1}; + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne01/2, 64)*64, 4, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(q_img)); + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_img)); + } else { + cl_mem b_sub_buf = nullptr; + cl_mem b_sub_buf_trans = nullptr; + cl_mem b_img = nullptr; + cl_mem b_img_trans = nullptr; + + // subbuffer for activations + region.origin = offset1; + region.size = K * N * sizeof(float); + CL_CHECK((b_sub_buf = clCreateSubBuffer(extra1->data_device, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for activations + img_fmt = {CL_RGBA, CL_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * N / 4; + img_desc.buffer = b_sub_buf; + CL_CHECK((b_img = clCreateImage(context, CL_MEM_READ_ONLY, &img_fmt, &img_desc, NULL, &err), err)); + + // pad N to multiple of 8 + int extra_elements = N % 8; + int padding = 0; + if (extra_elements > 0){ + padding = 8 - extra_elements; + } + + // subbuffer for transposed activations + region.origin = 0; + region.size = K * (N + padding) * sizeof(float)/2; + backend_ctx->prealloc_act_trans.allocate(context, region.size); + CL_CHECK((b_sub_buf_trans = clCreateSubBuffer(backend_ctx->prealloc_act_trans.buffer, 0, CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err), err)); + + // image for transposed activations + img_fmt = {CL_RGBA, CL_HALF_FLOAT}; + memset(&img_desc, 0, sizeof(img_desc)); + img_desc.image_type = CL_MEM_OBJECT_IMAGE1D_BUFFER; + img_desc.image_width = K * (N + padding) / 4; + img_desc.buffer = b_sub_buf_trans; + CL_CHECK((b_img_trans = clCreateImage(context, 0, &img_fmt, &img_desc, NULL, &err), err)); + + // transpose activations + int height_B = N/4; + if (height_B == 0) { + height_B = 1; + } + int width_B = K/4; + int padded_height_B = (N + padding)/4; + + kernel = backend_ctx->kernel_transpose_32_16; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &b_img)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(int), &height_B)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &width_B)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &padded_height_B)); + + size_t local_work_size_t[2] = { 1, 16 }; + size_t global_work_size_t[2] = { (size_t)width_B, (size_t)padded_height_B }; + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size_t, local_work_size_t, dst); + + // gemm + kernel = backend_ctx->kernel_gemm_noshuffle_q4_1_f32; + int padded_N = N + padding; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &b_img_trans)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &padded_N)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_int), &ne1)); + + size_t global_work_size[3] = {(size_t)CEIL_DIV(ne1, 8), (size_t)CEIL_DIV(ne01, 4), 1}; + size_t local_work_size[3] = {1, 128, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + + CL_CHECK(clReleaseMemObject(b_sub_buf)); + CL_CHECK(clReleaseMemObject(b_sub_buf_trans)); + CL_CHECK(clReleaseMemObject(b_img)); + CL_CHECK(clReleaseMemObject(b_img_trans)); + } +#else + GGML_UNUSED(backend); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(dst); +#endif +} + static void ggml_cl_mul_mat_q8_0_f32_adreno(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { #ifdef GGML_OPENCL_USE_ADRENO_KERNELS GGML_ASSERT(src0); @@ -8736,6 +9351,16 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co int padding; // <--------------------------------------------> // + // NOTE: Kernels using image1d_buffer_t (e.g., src0_q) would normally require + // a limit check, but q4_0 / q4_1 tensors are very unlikely to exceed that + // limit, so the check is omitted. + + // q4_1 x fp32 + if (src0t == GGML_TYPE_Q4_1 && src1t == GGML_TYPE_F32) { + ggml_cl_mul_mat_q4_1_f32_adreno(backend, src0, src1, dst); + return; + } + // q8_0 x fp32 if (src0t == GGML_TYPE_Q8_0 && src1t == GGML_TYPE_F32 && enable_adreno_trans_weight(backend_ctx, src0)) { @@ -10402,28 +11027,13 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const // GGML_OP_DUP and GGML_OP_CONT happen between src0 and dst. UNUSED(dst); - const int ne00 = src0 ? src0->ne[0] : 0; - const int ne01 = src0 ? src0->ne[1] : 0; - const int ne02 = src0 ? src0->ne[2] : 0; - const int ne03 = src0 ? src0->ne[3] : 0; - - const cl_ulong nb00 = src0 ? src0->nb[0] : 0; - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; - const cl_ulong nb02 = src0 ? src0->nb[2] : 0; - const cl_ulong nb03 = src0 ? src0->nb[3] : 0; - - const int ne10 = src1 ? src1->ne[0] : 0; - const int ne11 = src1 ? src1->ne[1] : 0; - const int ne12 = src1 ? src1->ne[2] : 0; - const int ne13 = src1 ? src1->ne[3] : 0; - - const cl_ulong nb10 = src1 ? src1->nb[0] : 0; - const cl_ulong nb11 = src1 ? src1->nb[1] : 0; - const cl_ulong nb12 = src1 ? src1->nb[2] : 0; - const cl_ulong nb13 = src1 ? src1->nb[3] : 0; + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne1, src1, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb); - const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; - const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; + const enum ggml_type src0t = src0->type; + const enum ggml_type src1t = src1->type; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -10460,6 +11070,15 @@ static void ggml_cl_cpy(ggml_backend_t backend, const ggml_tensor * src0, const GGML_ASSERT(false && "not implemented"); } break; + case GGML_TYPE_I32: + switch (src1t) { + case GGML_TYPE_I32: + kernel = backend_ctx->kernel_cpy_i32_i32; + break; + default: + GGML_ASSERT(false && "not implemented"); + } + break; default: GGML_ASSERT(false && "not implemented"); } @@ -10498,6 +11117,89 @@ static void ggml_cl_dup(ggml_backend_t backend, const ggml_tensor * src0, const UNUSED(src1); } +static void ggml_cl_set(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32) && + src1->type == src0->type && dst->type == src0->type); + + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne1, src1, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb1, src1, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const cl_ulong pnb1 = ((const int32_t *)dst->op_params)[0]; + const cl_ulong pnb2 = ((const int32_t *)dst->op_params)[1]; + const cl_ulong pnb3 = ((const int32_t *)dst->op_params)[2]; + const cl_ulong offs = ((const int32_t *)dst->op_params)[3]; + const bool inplace = (bool)((const int32_t *)dst->op_params)[4]; + + cl_kernel kernel = nullptr; + + // for inplace case, dst is a view of src0 and is updated on top of it + // so for non-inplace case, copy src0 to dst first + if (!inplace) { + ggml_cl_cpy(backend, src0, dst, nullptr); + } + + // then copy src1 to dst with specified offset + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_cpy_f32_f32; + } else if (src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { + kernel = backend_ctx->kernel_cpy_i32_i32; + } else { + GGML_ASSERT(false && "not implemented"); + } + + offsetd += offs; + cl_ulong nb = ggml_element_size(dst); + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &pnb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &pnb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &pnb3)); + + int max_local_size = backend_ctx->get_kernel_workgroup_size(kernel); + + const int nth = MIN(max_local_size, ne00); + + size_t global_work_size[] = {(size_t)ne11*nth, (size_t)ne12, (size_t)ne13}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -10560,6 +11262,49 @@ static void ggml_cl_diag_mask_inf(ggml_backend_t backend, const ggml_tensor * sr } } +static void ggml_cl_diag(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + UNUSED(src1); + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); + GGML_TENSOR_LOCALS(int, ne, dst, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb, dst, nb); + + cl_kernel kernel = backend_ctx->kernel_diag_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb3)); + + int nth = 64; + + size_t global_work_size[] = {(size_t)ne1*nth, (size_t)ne2, (size_t)ne3}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -11271,6 +12016,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_cpy; break; + case GGML_OP_SET: + if (!any_on_device) { + return false; + } + func = ggml_cl_set; + break; case GGML_OP_DUP: case GGML_OP_CONT: if (!any_on_device) { @@ -11370,6 +12121,18 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_tanh; break; + case GGML_UNARY_OP_NEG: + if (!any_on_device) { + return false; + } + func = ggml_cl_neg; + break; + case GGML_UNARY_OP_EXP: + if (!any_on_device) { + return false; + } + func = ggml_cl_exp; + break; case GGML_UNARY_OP_EXPM1: if (!any_on_device) { return false; @@ -11496,6 +12259,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } func = ggml_cl_nop; break; + case GGML_OP_DIAG: + if (!any_on_device) { + return false; + } + func = ggml_cl_diag; + break; case GGML_OP_DIAG_MASK_INF: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/cpy.cl b/ggml/src/ggml-opencl/kernels/cpy.cl index 9369351a60c..820aa538a34 100644 --- a/ggml/src/ggml-opencl/kernels/cpy.cl +++ b/ggml/src/ggml-opencl/kernels/cpy.cl @@ -182,3 +182,48 @@ kernel void kernel_cpy_f32_f32( dst_data[i00] = src[0]; } } + +kernel void kernel_cpy_i32_i32( + global int * src0, + ulong offset0, + global int * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + int ne1, + int ne2, + int ne3, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = (global int*)((global char*)src0 + offset0); + dst = (global int*)((global char*)dst + offsetd); + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + int n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + + int i3 = n / (ne2*ne1*ne0); + int i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); + int i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; + int i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + + global int * dst_data = (global int *) ((global char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { + global const int * src = (global int *)((global char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + + dst_data[i00] = src[0]; + } +} diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 2c244ce3215..78ef9c177f6 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -199,6 +199,58 @@ kernel void kernel_restore_block_q4_1( } } +kernel void kernel_convert_block_q4_1_noshuffle( + global struct block_q4_1 * src0, + global uchar * dst_q, + global half * dst_d, + global half * dst_m +) { + global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * m = (global half *) dst_m + get_global_id(0); + + *d = b->d; + *m = b->m; + for (int i = 0; i < QK4_1/4; ++i) { + uchar x0 = b->qs[2*i + 0]; + uchar x1 = b->qs[2*i + 1]; + + q[i + 0 ] = convert_uchar(x0 & 0x0F) | convert_uchar((x1 & 0x0F) << 4); + q[i + QK4_1/4] = convert_uchar((x0 & 0xF0) >> 4) | convert_uchar(x1 & 0xF0); + +#ifdef ADRENO_GPU + if (get_global_id(0) == 65536*4096) { + printf("%04x - %02x\n", *(global ushort*)d, ((x0 & 0xF0) >> 4) | (x1 & 0xF0)); + } +#endif + } +} + +kernel void kernel_restore_block_q4_1_noshuffle( + global uchar * src_q, + global half * src_d, + global half * src_m, + global struct block_q4_1 * dst, + uchar mask_0F, + uchar mask_F0 +) { + global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * m = (global half *) src_m + get_global_id(0); + + b->d = *d; + b->m = *m; + for (int i = 0; i < QK4_1/4; ++i) { + uchar x0 = q[i + 0 ] ; + uchar x1 = q[i + QK4_1/4]; + + b->qs[2*i + 0] = convert_uchar((x0 & mask_0F) | ((x1 & mask_0F) << 4)); + b->qs[2*i + 1] = convert_uchar(((x0 & mask_F0) >> 4) | (x1 & mask_F0)); + } +} + //------------------------------------------------------------------------------ // block_mxfp4 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/diag.cl b/ggml/src/ggml-opencl/kernels/diag.cl new file mode 100644 index 00000000000..884efa08fdd --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/diag.cl @@ -0,0 +1,27 @@ +kernel void kernel_diag_f32( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + ulong nb01, + ulong nb02, + ulong nb03, + int ne0, + ulong nb0, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + int i3 = get_group_id(2); + int i2 = get_group_id(1); + int i1 = get_group_id(0); + + global const float * src0_ptr = (global const float *)(src0 + i2*nb02 + i3*nb03); + global float * dst_ptr = (global float *)(dst + i1*nb01 + i2*nb2 + i3*nb3); + + for (int i0 = get_local_id(0); i0 < ne0; i0 += get_local_size(0)) { + dst_ptr[i0] = i0 == i1 ? src0_ptr[i0] : 0.0f; + } +} diff --git a/ggml/src/ggml-opencl/kernels/exp.cl b/ggml/src/ggml-opencl/kernels/exp.cl new file mode 100644 index 00000000000..a2458b6579c --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/exp.cl @@ -0,0 +1,125 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_exp_f32( + global const float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]); +} + +kernel void kernel_exp_f32_4( + global const float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]); +} + +kernel void kernel_exp_f16( + global const half * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]); +} + +kernel void kernel_exp_f16_4( + global const half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]); +} + +kernel void kernel_exp_f32_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = exp(*x); + } +} + +kernel void kernel_exp_f16_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = exp(*x); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl new file mode 100644 index 00000000000..5c4d5cc8e2c --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemm_noshuffle_q4_1_f32.cl @@ -0,0 +1,132 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_128 +#endif + +kernel void kernel_gemm_noshuffle_q4_1_f32( + global const ushort * src0_q, + global const half * src0_d, + global const half * src0_m, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int m, + int n, + int k, + int n_no_padding +) { + dst = (global float *)((global char *)dst + offsetd); + + int m_4 = m >> 2; + int n_4 = n >> 2; + + int gy = get_global_id(0); + int gx = get_global_id(1); + int gx_2 = gx << 2; + + half8 c0 = 0, c1 = 0, c2 = 0, c3 = 0; + half8 B; + half4 dequantized_weights; + + global const ushort* weight_ptr = src0_q + gx_2; + global const half* scale_ptr = src0_d + gx_2; + global const half* min_ptr = src0_m + gx_2; + + for(int i = 0; i < k; i += 4) { + B.s0123 = read_imageh(src1, gy*2 + (i)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i)*(n_4)+1); + + ushort4 bits4 = vload4(0, weight_ptr + (i/4)*(m)); + + half4 scale = vload4(0, scale_ptr + (i/32)*(m)); + half4 minv = vload4(0, min_ptr + (i/32)*(m)); + + // j=0 + dequantized_weights.s0 = (bits4.s0 & (0x000F)) * scale.s0 + minv.s0; + dequantized_weights.s1 = (bits4.s1 & (0x000F)) * scale.s1 + minv.s1; + dequantized_weights.s2 = (bits4.s2 & (0x000F)) * scale.s2 + minv.s2; + dequantized_weights.s3 = (bits4.s3 & (0x000F)) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=1 + B.s0123 = read_imageh(src1, gy*2 + (i+1)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+1)*(n_4)+1); + dequantized_weights.s0 = ((bits4.s0 & (0x00F0)) >> 4) * scale.s0 + minv.s0; + dequantized_weights.s1 = ((bits4.s1 & (0x00F0)) >> 4) * scale.s1 + minv.s1; + dequantized_weights.s2 = ((bits4.s2 & (0x00F0)) >> 4) * scale.s2 + minv.s2; + dequantized_weights.s3 = ((bits4.s3 & (0x00F0)) >> 4) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=2 + B.s0123 = read_imageh(src1, gy*2 + (i+2)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+2)*(n_4)+1); + dequantized_weights.s0 = ((bits4.s0 & (0x0F00)) >> 8) * scale.s0 + minv.s0; + dequantized_weights.s1 = ((bits4.s1 & (0x0F00)) >> 8) * scale.s1 + minv.s1; + dequantized_weights.s2 = ((bits4.s2 & (0x0F00)) >> 8) * scale.s2 + minv.s2; + dequantized_weights.s3 = ((bits4.s3 & (0x0F00)) >> 8) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + + // j=3 + B.s0123 = read_imageh(src1, gy*2 + (i+3)*(n_4)); + B.s4567 = read_imageh(src1, gy*2 + (i+3)*(n_4)+1); + dequantized_weights.s0 = ((bits4.s0 & (0xF000)) >> 12) * scale.s0 + minv.s0; + dequantized_weights.s1 = ((bits4.s1 & (0xF000)) >> 12) * scale.s1 + minv.s1; + dequantized_weights.s2 = ((bits4.s2 & (0xF000)) >> 12) * scale.s2 + minv.s2; + dequantized_weights.s3 = ((bits4.s3 & (0xF000)) >> 12) * scale.s3 + minv.s3; + c0 += B * dequantized_weights.s0; + c1 += B * dequantized_weights.s1; + c2 += B * dequantized_weights.s2; + c3 += B * dequantized_weights.s3; + } + + int idx = (gy<<3)*m + (gx<<2); + + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s0, c1.s0, c2.s0, c3.s0), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s1, c1.s1, c2.s1, c3.s1), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s2, c1.s2, c2.s2, c3.s2), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s3, c1.s3, c2.s3, c3.s3), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s4, c1.s4, c2.s4, c3.s4), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s5, c1.s5, c2.s5, c3.s5), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s6, c1.s6, c2.s6, c3.s6), 0, dst + idx); + idx += m; + } + if(idx+3 < m*n_no_padding){ + vstore4((float4)(c0.s7, c1.s7, c2.s7, c3.s7), 0, dst + idx); + } +} diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl index f944ef3a992..9703b693e56 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_general_q8_0_f32.cl @@ -121,7 +121,7 @@ #ifdef ADRENO_GPU REQD_SUBGROUP_SIZE_64 #endif -__kernel void kernel_gemv_noshuffle( +__kernel void kernel_gemv_noshuffle_q8_0_f32( __read_only image1d_buffer_t src0_q, // quantized A global half * src0_d, // A scales __read_only image1d_buffer_t src1, // B diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl new file mode 100644 index 00000000000..fdc1472454f --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q4_1_f32.cl @@ -0,0 +1,283 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable + +#ifdef cl_qcom_reqd_sub_group_size +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#endif + +#define QK4_0 32 +#define NSUBGROUPS 4 +#define SUBGROUP_SIZE 64 + +#define dequantizeBlockAccum_ns_sgbroadcast_1_hi(total_sums, bits4, scale, minv, y) \ + float shared_y; \ + shared_y = sub_group_broadcast(y.s0, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 0); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 0); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 0); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 0); \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 0); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 0); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 0); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 1); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 1); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 1); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 1); \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 1); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 1); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 1); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_1_lo(total_sums, bits4, scale, minv, y) \ + shared_y = sub_group_broadcast(y.s0, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 2); \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 2); \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 2); \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 2); \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 2); \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 2); \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 2); \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s0, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s1, 3); \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s2, 3); \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s3, 3); \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s4, 3); \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s5, 3); \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s6, 3); \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y; \ + shared_y = sub_group_broadcast(y.s7, 3); \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_hi(total_sums, bits4, scale, minv, y) \ + float8 shared_y; \ + shared_y = sub_group_broadcast(y, 0); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 1); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \ + + +#define dequantizeBlockAccum_ns_sgbroadcast_8_lo(total_sums, bits4, scale, minv, y) \ + shared_y = sub_group_broadcast(y, 2); \ + total_sums.s0 += ((bits4.s0 & 0x000F) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s0 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s0 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s0 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s2 & 0x000F) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s2 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s2 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s2 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s1 & 0x000F) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s1 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s1 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s1 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s3 & 0x000F) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s3 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s3 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s3 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \ + shared_y = sub_group_broadcast(y, 3); \ + total_sums.s0 += ((bits4.s4 & 0x000F) * scale.s0 + minv.s0) * shared_y.s0; \ + total_sums.s0 += (((bits4.s4 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s1; \ + total_sums.s0 += (((bits4.s4 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s2; \ + total_sums.s0 += (((bits4.s4 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s3; \ + total_sums.s0 += ((bits4.s6 & 0x000F) * scale.s0 + minv.s0) * shared_y.s4; \ + total_sums.s0 += (((bits4.s6 & 0x00F0) >> 4) * scale.s0 + minv.s0) * shared_y.s5; \ + total_sums.s0 += (((bits4.s6 & 0x0F00) >> 8) * scale.s0 + minv.s0) * shared_y.s6; \ + total_sums.s0 += (((bits4.s6 & 0xF000) >> 12) * scale.s0 + minv.s0) * shared_y.s7; \ + total_sums.s1 += ((bits4.s5 & 0x000F) * scale.s1 + minv.s1) * shared_y.s0; \ + total_sums.s1 += (((bits4.s5 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s1; \ + total_sums.s1 += (((bits4.s5 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s2; \ + total_sums.s1 += (((bits4.s5 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s3; \ + total_sums.s1 += ((bits4.s7 & 0x000F) * scale.s1 + minv.s1) * shared_y.s4; \ + total_sums.s1 += (((bits4.s7 & 0x00F0) >> 4) * scale.s1 + minv.s1) * shared_y.s5; \ + total_sums.s1 += (((bits4.s7 & 0x0F00) >> 8) * scale.s1 + minv.s1) * shared_y.s6; \ + total_sums.s1 += (((bits4.s7 & 0xF000) >> 12) * scale.s1 + minv.s1) * shared_y.s7; \ + +#ifdef ADRENO_GPU +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_gemv_noshuffle_q4_1_f32( + read_only image1d_buffer_t src0_q, + global half2 * src0_d, + global half2 * src0_m, + read_only image1d_buffer_t src1, + global float * dst, + ulong offsetd, + int ne00, + int ne01) +{ + uint groupId = get_local_id(1); + uint gid = get_global_id(0); + ushort slid = get_sub_group_local_id(); + + uint K = ne00; + uint M = ne01; + + uint LINE_STRIDE_A = M / 2; + uint BLOCK_STRIDE_A = NSUBGROUPS * M; + + private uint4 regA; + private half2 regS; + private half2 regM; + private float8 regB; + + private float2 totalSum = (float2)(0.0f); + + // loop along K in block granularity, skip 4 blocks every iter + for (uint k = groupId; k < (K / QK4_0); k += NSUBGROUPS) { + regS = src0_d[gid + k * LINE_STRIDE_A]; // each fiber loads scale of two rows + regM = src0_m[gid + k * LINE_STRIDE_A]; // each fiber loads min of two rows + // first 4 fibers in each wave load 8 B values to its private scope + if (slid < 4) { + regB.s0123 = read_imagef(src1, (slid * 2 + k * 8)); + regB.s4567 = read_imagef(src1, (1 + slid * 2 + k * 8)); + } + + // load half weights for two blocks in consecutive rows + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 0)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 1)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 2)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 3)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_hi(totalSum, as_ushort8(regA), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_hi(totalSum, as_ushort8(regA), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + + regA.s0 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 4)).x; + regA.s1 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 5)).x; + regA.s2 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; + regA.s3 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; +#ifdef VECTOR_SUB_GROUP_BROADCAT + dequantizeBlockAccum_ns_sgbroadcast_8_lo(totalSum, as_ushort8(regA), regS, regM, regB); +#else + dequantizeBlockAccum_ns_sgbroadcast_1_lo(totalSum, as_ushort8(regA), regS, regM, regB); +#endif // VECTOR_SUB_GROUP_BROADCAT + } + + // reduction in local memory, assumes #wave=4 + local float2 reduceLM[SUBGROUP_SIZE * 3]; + if (groupId == 1) { + reduceLM[SUBGROUP_SIZE * 0 + slid] = totalSum; + } + if (groupId == 2) { + reduceLM[SUBGROUP_SIZE * 1 + slid] = totalSum; + } + if (groupId == 3) { + reduceLM[SUBGROUP_SIZE * 2 + slid] = totalSum; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 0 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 1 + slid]; + } + if (groupId == 0) { + totalSum += reduceLM[SUBGROUP_SIZE * 2 + slid]; + } + + // 2 outputs per fiber in wave 0 + if (groupId == 0) { + dst = (global float*)((global char*)dst + offsetd); + vstore2(totalSum, 0, &(dst[gid * 2])); + } + +} diff --git a/ggml/src/ggml-opencl/kernels/neg.cl b/ggml/src/ggml-opencl/kernels/neg.cl new file mode 100644 index 00000000000..a862d8bc585 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/neg.cl @@ -0,0 +1,125 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +kernel void kernel_neg_f32( + global const float * src0, + ulong offset0, + global float * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = -src0[get_global_id(0)]; +} + +kernel void kernel_neg_f32_4( + global const float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = -src0[get_global_id(0)]; +} + +kernel void kernel_neg_f16( + global const half * src0, + ulong offset0, + global half * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = -src0[get_global_id(0)]; +} + +kernel void kernel_neg_f16_4( + global const half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd, + int n +) { + if (get_global_id(0) >= n) { + return; + } + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = -src0[get_global_id(0)]; +} + +kernel void kernel_neg_f32_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = -*x; + } +} + +kernel void kernel_neg_f16_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + ulong nb00, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + + *y = -*x; + } +} diff --git a/ggml/src/ggml-opencl/kernels/transpose.cl b/ggml/src/ggml-opencl/kernels/transpose.cl index 1279b6531b9..ad89bdcbdec 100644 --- a/ggml/src/ggml-opencl/kernels/transpose.cl +++ b/ggml/src/ggml-opencl/kernels/transpose.cl @@ -44,6 +44,19 @@ kernel void kernel_transpose_16_4x1( write_imageh(output, i * rows + j, (half4)(temp0, temp1, temp2, temp3)); } +// Transpose treating each element as 8-bit using buffer +kernel void kernel_transpose_8_buf( + global const uchar * input, + global uchar * output, + const int ldi, + const int ldo +) { + const int x = get_global_id(0); + const int y = get_global_id(1); + + output[x*ldo + y] = input[y*ldi + x]; +} + // Transpose treating each element as 16-bit using buffer kernel void kernel_transpose_16_buf( global const ushort * input, @@ -57,6 +70,19 @@ kernel void kernel_transpose_16_buf( output[x*ldo + y] = input[y*ldi + x]; } +// Transpose treating each element as 32-bit using buffer +kernel void kernel_transpose_32_buf( + global const uint * input, + global uint * output, + const int ldi, + const int ldo +) { + const int x = get_global_id(0); + const int y = get_global_id(1); + + output[x*ldo + y] = input[y*ldi + x]; +} + // 32-bit transpose, loading/storing a 4x4 tile of elements kernel void kernel_transpose_32( __read_only image1d_buffer_t input, diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index de5cbd75e86..e8e25633fb8 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -3104,6 +3104,11 @@ static void quantize_row_iq2_xxs_impl(const float * GGML_RESTRICT x, void * GGML } float scale = make_qp_quants(32, kMaxQ+1, xval, (uint8_t*)L, weight); float eff_max = scale*kMaxQ; + if (eff_max <= 0) { + scales[ib] = 0; + memset(L, 0, 32); + continue; + } float best = 0; for (int is = -6; is <= 6; ++is) { float id = (2*kMaxQ-1+is*0.1f)/eff_max; @@ -3273,9 +3278,9 @@ static void quantize_row_iq2_xs_impl(const float * GGML_RESTRICT x, void * GGML_ } float max = xval[0]; for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]); + memset(L, 0, 16); if (max < GROUP_MAX_EPS) { scales[ib] = 0; - memset(L, 0, 16); continue; } float best = 0; @@ -3714,9 +3719,9 @@ static void quantize_row_iq3_xxs_impl(int grid_size, const float * GGML_RESTRICT } float max = xval[0]; for (int i = 1; i < 32; ++i) max = MAX(max, xval[i]); + memset(L, 0, 32); if (max < GROUP_MAX_EPS_IQ3_XXS) { scales[ib] = 0; - memset(L, 0, 32); continue; } float best = 0; @@ -3922,6 +3927,7 @@ static void quantize_row_iq3_s_impl(int block_size, const float * GGML_RESTRICT } float max = xval[0]; for (int i = 1; i < block_size; ++i) max = MAX(max, xval[i]); + memset(L, 0, block_size); if (!max) { scales[ib] = 0; continue; @@ -4245,6 +4251,7 @@ static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_R for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); if (max < GROUP_MAX_EPS_IQ1_S) { scales[ib] = 0; + shifts[ib] = 1; memset(L, 1, block_size); continue; } @@ -4285,7 +4292,12 @@ static void quantize_row_iq1_s_impl(const float * GGML_RESTRICT x, void * GGML_R } } } - GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_shift != 0); + if (besti1 < 0 || besti2 < 0 || best_shift == 0) { + scales[ib] = 0; + shifts[ib] = 1; + memset(L, 1, block_size); + continue; + } for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; @@ -4429,6 +4441,7 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R for (int i = 1; i < block_size; ++i) max = MAX(max, fabsf(xb[i])); if (max < GROUP_MAX_EPS_IQ1_M) { scales[ib] = 0; + shifts[ib] = 0; memset(L, 1, block_size); continue; } @@ -4527,7 +4540,12 @@ static void quantize_row_iq1_m_impl(const float * GGML_RESTRICT x, void * GGML_R } } } - GGML_ASSERT(besti1 >= 0 && besti2 >= 0 && best_k >= 0); + if (besti1 < 0 || besti2 < 0 || best_k < 0) { + scales[ib] = 0; + shifts[ib] = 0; + memset(L, 1, block_size); + continue; + } for (int j = 0; j < besti1; ++j) L[idx[2*j]] = 0; for (int j = besti1; j < besti2; ++j) L[idx[2*j]] = 1; for (int j = besti2; j < block_size; ++j) L[idx[2*j]] = 2; @@ -4874,6 +4892,7 @@ static void quantize_row_iq2_s_impl(const float * GGML_RESTRICT x, void * GGML_R } float max = xval[0]; for (int i = 1; i < 16; ++i) max = MAX(max, xval[i]); + memset(L, 0, 16); if (max < GROUP_MAX_EPS_IQ2_S) { scales[ib] = 0; continue; diff --git a/ggml/src/ggml-sycl/common.hpp b/ggml/src/ggml-sycl/common.hpp index 519638fd416..04c9e1d7864 100644 --- a/ggml/src/ggml-sycl/common.hpp +++ b/ggml/src/ggml-sycl/common.hpp @@ -76,10 +76,10 @@ extern int g_ggml_sycl_prioritize_dmmv; #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP -#define VER_4VEC 610 // todo for hardward optimize. -#define VER_GEN9 700 // todo for hardward optimize. -#define VER_GEN12 1000000 // todo for hardward optimize. -#define VER_GEN13 (VER_GEN12 + 1030) // todo for hardward optimize. +#define VER_4VEC 610 // todo for hardware optimize. +#define VER_GEN9 700 // todo for hardware optimize. +#define VER_GEN12 1000000 // todo for hardware optimize. +#define VER_GEN13 (VER_GEN12 + 1030) // todo for hardware optimize. #define GGML_SYCL_MAX_NODES 8192 // TODO: adapt to hardwares diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp index d0d5ac9a4e8..14490fea5be 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -29,7 +29,7 @@ namespace ggml_sycl_reordered { // [qs0, qs1, qs2, ..., qsN] [d0, d1, d2, ..., dN] // // Notes: out-of-bounds qs will run into d values -// Aligment relies on the allocated size of qs +// Alignment relies on the allocated size of qs template struct block_q_t; diff --git a/ggml/src/ggml-sycl/softmax.cpp b/ggml/src/ggml-sycl/softmax.cpp index b41124acc13..15d92e5e04c 100644 --- a/ggml/src/ggml-sycl/softmax.cpp +++ b/ggml/src/ggml-sycl/softmax.cpp @@ -37,7 +37,7 @@ struct soft_max_params { }; // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled. -// As we want to keep pragma unroll for all other cases we supress the clang transformation warning here. +// As we want to keep pragma unroll for all other cases we suppress the clang transformation warning here. #ifdef __clang__ #pragma clang diagnostic push #pragma clang diagnostic ignored "-Wpass-failed" diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index de01336cd3f..715a263a6d0 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -90,7 +90,7 @@ if (Vulkan_FOUND) target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build - # Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector + # Possibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0) endif() diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 0fae68628b6..23d6d39e0e8 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -590,6 +590,7 @@ struct vk_device_struct { vk_queue transfer_queue; bool single_queue; bool support_async; + bool async_use_transfer_queue; uint32_t subgroup_size; uint32_t subgroup_size_log2; uint32_t shader_core_count; @@ -1858,6 +1859,10 @@ struct ggml_backend_vk_context { vk_context_ref compute_ctx; + vk_context_ref transfer_ctx; + vk_semaphore transfer_semaphore; + uint64_t transfer_semaphore_last_submitted {}; + std::vector tensor_ctxs; std::vector descriptor_pools; @@ -1866,6 +1871,7 @@ struct ggml_backend_vk_context { uint32_t pipeline_descriptor_set_requirements {}; vk_command_pool compute_cmd_pool; + vk_command_pool transfer_cmd_pool; // number of additional consecutive nodes that are being fused with the // node currently being processed @@ -5391,13 +5397,19 @@ static vk_device ggml_vk_get_device(size_t idx) { ggml_vk_load_shaders(device); + const bool prefers_transfer_queue = device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN; + if (!device->single_queue) { const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); + + device->async_use_transfer_queue = prefers_transfer_queue || (getenv("GGML_VK_ASYNC_USE_TRANSFER_QUEUE") != nullptr); } else { // TODO: Use pointer or reference to avoid copy device->transfer_queue.copyFrom(device->compute_queue); device->transfer_queue.cmd_pool.init(device, &device->transfer_queue); + + device->async_use_transfer_queue = false; } device->buffer_type = { @@ -5871,6 +5883,15 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { ctx->almost_ready_fence = ctx->device->device.createFence({}); ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue); + if (ctx->device->async_use_transfer_queue) { + vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 }; + vk::SemaphoreCreateInfo ci{}; + ci.setPNext(&tci); + ctx->transfer_semaphore.s = ctx->device->device.createSemaphore(ci); + ctx->transfer_semaphore.value = 0; + + ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue); + } if (vk_perf_logger_enabled) { ctx->perf_logger = std::unique_ptr(new vk_perf_logger()); @@ -6419,6 +6440,47 @@ static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) { subctx->s = subctx->seqs[subctx->seqs.size() - 1].data(); } +static vk_context ggml_vk_get_compute_ctx(ggml_backend_vk_context * ctx) { + if (!ctx->compute_ctx.expired()) { + return ctx->compute_ctx.lock(); + } + + vk_context result = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + + ctx->compute_ctx = result; + ggml_vk_ctx_begin(ctx->device, result); + + if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) { + result->s->wait_semaphores.push_back(ctx->transfer_semaphore); + ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value; + } + + return result; +} + +// Submit any pending transfer queue work and signal the transfer semaphore. +// The next compute context created via ggml_vk_get_compute_ctx will wait on this semaphore. +// Returns true if work was submitted. +static bool ggml_vk_submit_transfer_ctx(ggml_backend_vk_context * ctx) { + if (!ctx->device->async_use_transfer_queue || ctx->transfer_ctx.expired()) { + return false; + } + + vk_context cpy_ctx = ctx->transfer_ctx.lock(); + ggml_vk_ctx_end(cpy_ctx); + + for (auto& cpy : cpy_ctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + ctx->transfer_semaphore.value++; + cpy_ctx->seqs.back().back().signal_semaphores.push_back(ctx->transfer_semaphore); + + ggml_vk_submit(cpy_ctx, {}); + ctx->transfer_ctx.reset(); + return true; +} + static size_t ggml_vk_align_size(size_t width, size_t align) { VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")"); return CEIL_DIV(width, align) * align; @@ -7512,6 +7574,18 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ return false; } + if (device->driver_id == vk::DriverId::eIntelProprietaryWindows) { + // Intel Windows proprietary driver tuning + switch (src0_type) { + case GGML_TYPE_MXFP4: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + return false; + default: + return true; + } + } + switch (src0_type) { // From tests on A770 Linux, may need more tuning case GGML_TYPE_Q4_0: @@ -12529,15 +12603,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr } } - vk_context compute_ctx; - - if (ctx->compute_ctx.expired()) { - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); - } + vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); { // This logic detects dependencies between modes in the graph and calls ggml_vk_sync_buffers @@ -13055,6 +13121,9 @@ static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); + if (ctx->device->async_use_transfer_queue) { + ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); + } for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s }); @@ -13116,6 +13185,11 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ctx->descriptor_sets.clear(); ctx->compute_cmd_pool.destroy(ctx->device->device); + if (ctx->device->async_use_transfer_queue) { + ctx->device->device.destroySemaphore(ctx->transfer_semaphore.s); + + ctx->transfer_cmd_pool.destroy(ctx->device->device); + } if (vk_perf_logger_enabled) { ctx->perf_logger->print_timings(true); } @@ -13387,34 +13461,38 @@ static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; - vk_context compute_ctx; + vk_context cpy_ctx; - if (ctx->compute_ctx.expired()) { - // Initialize new transfer context - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); + if (ctx->device->async_use_transfer_queue) { + if (ctx->transfer_ctx.expired()) { + // Initialize new transfer context + cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); + ctx->transfer_ctx = cpy_ctx; + ggml_vk_ctx_begin(ctx->device, cpy_ctx); + } else { + cpy_ctx = ctx->transfer_ctx.lock(); + } } else { - compute_ctx = ctx->compute_ctx.lock(); + cpy_ctx = ggml_vk_get_compute_ctx(ctx); } vk_buffer buf = buf_ctx->dev_buffer; auto dst_offset = vk_tensor_offset(tensor) + tensor->view_offs + offset; - bool ret = ggml_vk_buffer_write_async(compute_ctx, buf, dst_offset, data, size); + bool ret = ggml_vk_buffer_write_async(cpy_ctx, buf, dst_offset, data, size); if (!ret) { ggml_vk_ensure_sync_staging_buffer(ctx, size); - ggml_vk_sync_buffers(nullptr, compute_ctx); + ggml_vk_sync_buffers(nullptr, cpy_ctx); vk::BufferCopy buffer_cpy; buffer_cpy.srcOffset = 0; buffer_cpy.dstOffset = dst_offset; buffer_cpy.size = size; - compute_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy }); - deferred_memcpy(ctx->sync_staging->ptr, data, size, &compute_ctx->in_memcpys); + cpy_ctx->s->buffer.copyBuffer(ctx->sync_staging->buffer, buf->buffer, { buffer_cpy }); + deferred_memcpy(ctx->sync_staging->ptr, data, size, &cpy_ctx->in_memcpys); ggml_vk_synchronize(ctx); } } @@ -13426,16 +13504,7 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_ ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; - vk_context compute_ctx; - - if (ctx->compute_ctx.expired()) { - // Initialize new transfer context - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); - } + vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); vk_buffer buf = buf_ctx->dev_buffer; @@ -13458,31 +13527,60 @@ static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_ } } -static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { +static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend_src, ggml_backend_t backend_dst, const ggml_tensor * src, ggml_tensor * dst) { VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()"); - ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; - if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) { - ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; - ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend_dst->context; + + if (dst->buffer->buft != ggml_backend_vk_get_default_buffer_type(backend_dst)) { + return false; + } - vk_context compute_ctx; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + vk_buffer dst_buf = dst_buf_ctx->dev_buffer; - if (ctx->compute_ctx.expired()) { - // Initialize new transfer context - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); + if (ggml_backend_buffer_is_vk(src->buffer)) { + ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; + + // Async copy only works within the same device + if (src_buf_ctx->dev_buffer->device != dst_buf->device) { + return false; } - vk_buffer src_buf = src_buf_ctx->dev_buffer; - vk_buffer dst_buf = dst_buf_ctx->dev_buffer; + vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); - ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); + ggml_vk_buffer_copy_async(compute_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, + src_buf_ctx->dev_buffer, vk_tensor_offset(src) + src->view_offs, + ggml_nbytes(src)); return true; } + if (ggml_backend_buffer_is_host(src->buffer)) { + vk_buffer pinned_buf = nullptr; + size_t pinned_offset = 0; + ggml_vk_host_get(ctx->device, src->data, pinned_buf, pinned_offset); + if (pinned_buf == nullptr) { + return false; + } + + vk_context cpy_ctx; + if (ctx->device->async_use_transfer_queue) { + if (ctx->transfer_ctx.expired()) { + cpy_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); + ctx->transfer_ctx = cpy_ctx; + ggml_vk_ctx_begin(ctx->device, cpy_ctx); + } else { + cpy_ctx = ctx->transfer_ctx.lock(); + } + } else { + cpy_ctx = ggml_vk_get_compute_ctx(ctx); + } + + return ggml_vk_buffer_write_async(cpy_ctx, dst_buf, + vk_tensor_offset(dst) + dst->view_offs, + src->data, ggml_nbytes(src)); + } + + GGML_UNUSED(backend_src); return false; } @@ -13491,6 +13589,10 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) { bool do_transfer = !ctx->compute_ctx.expired(); + if (ggml_vk_submit_transfer_ctx(ctx)) { + ctx->submit_pending = true; + } + vk_context compute_ctx; if (do_transfer) { compute_ctx = ctx->compute_ctx.lock(); @@ -13506,7 +13608,22 @@ static void ggml_vk_synchronize(ggml_backend_vk_context * ctx) { } if (ctx->submit_pending) { - { + if (ctx->device->async_use_transfer_queue && ctx->transfer_semaphore_last_submitted < ctx->transfer_semaphore.value) { + vk::TimelineSemaphoreSubmitInfo tl_info{ + 1, &ctx->transfer_semaphore.value, + 0, nullptr, + }; + vk::PipelineStageFlags stage = ctx->device->transfer_queue.stage_flags; + vk::SubmitInfo si{ + 1, &ctx->transfer_semaphore.s, &stage, + 0, nullptr, + 0, nullptr, + }; + si.setPNext(&tl_info); + std::lock_guard guard(queue_mutex); + ctx->device->compute_queue.queue.submit({ si }, ctx->fence); + ctx->transfer_semaphore_last_submitted = ctx->transfer_semaphore.value; + } else { std::lock_guard guard(queue_mutex); ctx->device->compute_queue.queue.submit({}, ctx->fence); } @@ -13972,6 +14089,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg bool first_node_in_batch = true; // true if next node will be first node in a batch int submit_node_idx = 0; // index to first node in a batch + ggml_vk_submit_transfer_ctx(ctx); + vk_context compute_ctx; if (vk_perf_logger_enabled) { // allocate/resize the query pool @@ -13997,9 +14116,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg std::fill(ctx->query_node_idx.begin(), ctx->query_node_idx.end(), 0); GGML_ASSERT(ctx->compute_ctx.expired()); - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); + compute_ctx = ggml_vk_get_compute_ctx(ctx); ctx->query_idx = 0; compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->query_pool, ctx->query_idx++); } @@ -14009,13 +14126,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg if (ctx->prealloc_size_add_rms_partials) { ggml_vk_preallocate_buffers(ctx, nullptr); - if (ctx->compute_ctx.expired()) { - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); - } + compute_ctx = ggml_vk_get_compute_ctx(ctx); // initialize partial sums to zero. ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials); ggml_vk_sync_buffers(ctx, compute_ctx); @@ -14238,13 +14349,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, i + ctx->num_additional_fused_ops >= last_node, almost_ready, submit); if (vk_perf_logger_enabled && enqueued) { - if (ctx->compute_ctx.expired()) { - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); - } + compute_ctx = ggml_vk_get_compute_ctx(ctx); if (!vk_perf_logger_concurrent) { // track a single node/fusion for the current query ctx->query_nodes[ctx->query_idx] = cgraph->nodes[i]; @@ -14579,16 +14684,9 @@ static void ggml_backend_vk_event_record(ggml_backend_t backend, ggml_backend_ev ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; vk_event *vkev = (vk_event *)event->context; - vk_context compute_ctx; + ggml_vk_submit_transfer_ctx(ctx); - if (ctx->compute_ctx.expired()) { - // Initialize new transfer context - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); - } + vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); // the backend interface doesn't have an explicit reset, so reset it here // before we record the command to set it @@ -14609,16 +14707,7 @@ static void ggml_backend_vk_event_wait(ggml_backend_t backend, ggml_backend_even ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; vk_event *vkev = (vk_event *)event->context; - vk_context compute_ctx; - - if (ctx->compute_ctx.expired()) { - // Initialize new transfer context - compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); - ctx->compute_ctx = compute_ctx; - ggml_vk_ctx_begin(ctx->device, compute_ctx); - } else { - compute_ctx = ctx->compute_ctx.lock(); - } + vk_context compute_ctx = ggml_vk_get_compute_ctx(ctx); ggml_vk_wait_events(compute_ctx, {vkev->event}); ggml_vk_ctx_end(compute_ctx); @@ -14631,7 +14720,7 @@ static ggml_backend_i ggml_backend_vk_interface = { /* .free = */ ggml_backend_vk_free, /* .set_tensor_async = */ ggml_backend_vk_set_tensor_async, /* .get_tensor_async = */ ggml_backend_vk_get_tensor_async, - /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async, + /* .cpy_tensor_async = */ ggml_backend_vk_cpy_tensor_async, /* .synchronize = */ ggml_backend_vk_synchronize, /* .graph_plan_create = */ NULL, /* .graph_plan_free = */ NULL, @@ -15367,11 +15456,25 @@ static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_ba return buft_ctx->device->idx == ctx->device; } +static int64_t ggml_vk_get_op_batch_size(const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_GET_ROWS: + return 0; + case GGML_OP_MUL_MAT: + return op->ne[1]; + case GGML_OP_MUL_MAT_ID: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + return op->ne[2]; + default: + return ggml_nrows(op); + } +} + static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { ggml_backend_vk_device_context * dev_ctx = (ggml_backend_vk_device_context *)dev->context; - return (op->ne[1] >= dev_ctx->op_offload_min_batch_size && op->op != GGML_OP_GET_ROWS) || - (op->ne[2] >= dev_ctx->op_offload_min_batch_size && op->op == GGML_OP_MUL_MAT_ID); + return ggml_vk_get_op_batch_size(op) >= dev_ctx->op_offload_min_batch_size; } static ggml_backend_event_t ggml_backend_vk_device_event_new(ggml_backend_dev_t dev) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 0d5a818dacb..17c5e0fb51f 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -68,6 +68,7 @@ struct ggml_webgpu_shader_lib_context { size_t wg_mem_limit_bytes = 0; bool inplace = false; bool overlap = false; + bool src_overlap = false; bool supports_subgroup_matrix = false; uint32_t sg_mat_m = 0; uint32_t sg_mat_n = 0; @@ -172,6 +173,22 @@ struct ggml_webgpu_scale_pipeline_key_hash { } }; +/** Concat **/ + +struct ggml_webgpu_concat_pipeline_key { + int type; + + bool operator==(const ggml_webgpu_concat_pipeline_key & other) const { return type == other.type; } +}; + +struct ggml_webgpu_concat_pipeline_key_hash { + size_t operator()(const ggml_webgpu_concat_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + return seed; + } +}; + /** Binary **/ struct ggml_webgpu_binary_pipeline_key { @@ -179,9 +196,10 @@ struct ggml_webgpu_binary_pipeline_key { int op; bool inplace; bool overlap; + bool src_overlap; bool operator==(const ggml_webgpu_binary_pipeline_key & other) const { - return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap; + return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap && src_overlap == other.src_overlap; } }; @@ -192,6 +210,7 @@ struct ggml_webgpu_binary_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.op); ggml_webgpu_hash_combine(seed, key.inplace); ggml_webgpu_hash_combine(seed, key.overlap); + ggml_webgpu_hash_combine(seed, key.src_overlap); return seed; } }; @@ -400,6 +419,8 @@ class ggml_webgpu_shader_lib { pad_pipelines; // circular/non-circular std::unordered_map binary_pipelines; // type/op/inplace/overlap + std::unordered_map + concat_pipelines; // type std::unordered_map flash_attn_pipelines; std::unordered_mapop, .inplace = context.inplace, .overlap = context.overlap, + .src_overlap = context.src_overlap, }; auto it = binary_pipelines.find(key); @@ -1076,6 +1098,9 @@ class ggml_webgpu_shader_lib { } else if (key.overlap) { defines.push_back("OVERLAP"); variant += "_overlap"; + } else if (key.src_overlap) { + defines.push_back("SRC_OVERLAP"); + variant += "_src_overlap"; } defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); @@ -1089,6 +1114,43 @@ class ggml_webgpu_shader_lib { return binary_pipelines[key]; } + webgpu_pipeline get_concat_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_concat_pipeline_key key = { + .type = context.dst->type, + }; + + auto it = concat_pipelines.find(key); + if (it != concat_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "concat"; + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_I32: + defines.push_back("TYPE_I32"); + variant += "_i32"; + break; + default: + GGML_ABORT("Unsupported type for concat shader"); + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_concat, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + concat_pipelines[key] = pipeline; + return concat_pipelines[key]; + } + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { const bool has_mask = context.src3 != nullptr; const bool has_sinks = context.src4 != nullptr; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 1c00d3cb2b1..b2ef2d59010 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -31,6 +31,13 @@ #define ROUNDUP_POW2(x, pow2) (((x) + ((pow2) - 1)) & ~((pow2) - 1)) #define CEIL_DIV(M, N) (((M) + (N) - 1) / (N)) +// Return a rectangular grid of workgroups with minimal over-provisioned workgroups. +// Assumes that the total number of workgroups does not exceed max_per_dim^2. +static inline void compute_2d_workgroups(uint32_t total_wg, uint32_t max_per_dim, uint32_t & wg_x, uint32_t & wg_y) { + wg_y = std::max(1u, CEIL_DIV(total_wg, max_per_dim)); + wg_x = CEIL_DIV(total_wg, wg_y); +} + #ifdef GGML_WEBGPU_DEBUG # define WEBGPU_LOG_DEBUG(msg) std::cout << msg << std::endl # define WEBGPU_DEBUG_BUF_ELEMS 512 @@ -69,8 +76,8 @@ /* Constants */ -#define WEBGPU_NUM_PARAM_BUFS 16u -#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u +#define WEBGPU_NUM_PARAM_BUFS 48u +#define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 16u #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 // Maximum number of in-flight submissions per-thread, to avoid exhausting the // parameter buffer pool @@ -116,11 +123,6 @@ struct webgpu_pool_bufs { wgpu::Buffer dev_buf; }; -// The futures to wait on for a single queue submission -struct webgpu_submission_futures { - std::vector futures; -}; - // Holds a pool of parameter buffers for WebGPU operations struct webgpu_buf_pool { std::vector free; @@ -133,12 +135,28 @@ struct webgpu_buf_pool { // which can run on a different thread than the calling thread. std::mutex mutex; std::condition_variable cv; + size_t cur_pool_size; + size_t max_pool_size; + wgpu::Device device; + wgpu::BufferUsage host_buf_usage; + wgpu::BufferUsage dev_buf_usage; + size_t buf_size; + bool should_grow; void init(wgpu::Device device, int num_bufs, size_t buf_size, wgpu::BufferUsage dev_buf_usage, - wgpu::BufferUsage host_buf_usage) { + wgpu::BufferUsage host_buf_usage, + bool should_grow = false, + size_t max_pool_size = WEBGPU_NUM_PARAM_BUFS * 2) { + this->max_pool_size = max_pool_size; + this->cur_pool_size = num_bufs; + this->device = device; + this->host_buf_usage = host_buf_usage; + this->dev_buf_usage = dev_buf_usage; + this->buf_size = buf_size; + this->should_grow = should_grow; for (int i = 0; i < num_bufs; i++) { wgpu::Buffer host_buf; wgpu::Buffer dev_buf; @@ -150,6 +168,25 @@ struct webgpu_buf_pool { webgpu_pool_bufs alloc_bufs() { std::unique_lock lock(mutex); + if (!free.empty()) { + webgpu_pool_bufs bufs = free.back(); + free.pop_back(); + return bufs; + } + + // Try growing the pool if no free buffers + if (free.empty() && cur_pool_size < max_pool_size && should_grow) { + cur_pool_size++; + wgpu::Buffer host_buf; + wgpu::Buffer dev_buf; + ggml_webgpu_create_buffer(device, host_buf, buf_size, host_buf_usage, "ggml_webgpu_host_pool_buf"); + ggml_webgpu_create_buffer(device, dev_buf, buf_size, dev_buf_usage, "ggml_webgpu_dev_pool_buf"); + + if (!(host_buf && dev_buf)) { + GGML_ABORT("webgpu_buf_pool: failed to allocate buffers"); + } + return webgpu_pool_bufs{ host_buf, dev_buf }; + } cv.wait(lock, [this] { return !free.empty(); }); webgpu_pool_bufs bufs = free.back(); free.pop_back(); @@ -243,6 +280,7 @@ struct webgpu_gpu_profile_buf_pool { #endif struct webgpu_command { + uint32_t num_kernels; wgpu::CommandBuffer commands; std::vector params_bufs; std::optional set_rows_error_bufs; @@ -280,7 +318,6 @@ struct webgpu_global_context_struct { webgpu_buf_pool memset_buf_pool; std::map memset_pipelines; // variant or type index - std::atomic_uint inflight_threads = 0; #ifdef GGML_WEBGPU_CPU_PROFILE // Profiling: labeled CPU time in ms (total) @@ -421,30 +458,60 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, /** End WebGPU object initializations */ /** WebGPU Actions */ +static void erase_completed(std::vector & futures) { + futures.erase(std::remove_if(futures.begin(), futures.end(), + [](const wgpu::FutureWaitInfo & info) { return info.completed; }), + futures.end()); +} // Wait for the queue to finish processing all submitted work -static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, - std::vector & futures, - bool block = true) { - // If we have too many in-flight submissions, wait on the oldest one first. If - // there are many threads, inflight_max may be 0, meaning that we must wait on - // all futures. - uint64_t timeout_ms = block ? UINT64_MAX : 0; - uint32_t inflight_threads = ctx->inflight_threads; - uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); - while (futures.size() >= inflight_max && futures.size() > 0) { - ctx->instance.WaitAny(futures[0].futures.size(), futures[0].futures.data(), UINT64_MAX); - futures.erase(futures.begin()); - } - size_t i = 0; - while (i < futures.size()) { - auto waitStatus = ctx->instance.WaitAny(futures[i].futures.size(), futures[i].futures.data(), timeout_ms); +static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, + std::vector & futures, + bool block = true) { + // If we have too many in-flight submissions, wait on the oldest one first. + if (futures.empty()) { + return; + } + uint64_t timeout_ms = block ? UINT64_MAX : 0; + while (futures.size() >= WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD) { + auto waitStatus = ctx->instance.WaitAny(1, &futures[0], UINT64_MAX); + if (waitStatus == wgpu::WaitStatus::Error) { + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); + } + if (futures[0].completed) { + futures.erase(futures.begin()); + } + } + + if (futures.empty()) { + return; + } + + if (block) { + while (!futures.empty()) { + auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); + switch (waitStatus) { + case wgpu::WaitStatus::Success: + // WaitAny doesn't tell us which future completed, so we must check all futures to see which finished. + erase_completed(futures); + break; + case wgpu::WaitStatus::Error: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); + break; + default: + GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an unknown status\n"); + break; + } + } + } else { + // Poll once and return + auto waitStatus = ctx->instance.WaitAny(futures.size(), futures.data(), timeout_ms); switch (waitStatus) { case wgpu::WaitStatus::Success: - futures.erase(futures.begin() + i); + // WaitAny doesn't tell us which future completed, so we must check all futures to see which finished. + erase_completed(futures); break; case wgpu::WaitStatus::TimedOut: - i++; break; case wgpu::WaitStatus::Error: GGML_LOG_ERROR("ggml_webgpu: WaitAny returned an error\n"); @@ -487,10 +554,11 @@ static void ggml_backend_webgpu_debug(webgpu_global_context & ctx) { } #endif -static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_context ctx, - std::vector commands, - webgpu_buf_pool & param_buf_pool, - webgpu_buf_pool * set_rows_error_buf_pool = nullptr) { +static std::vector ggml_backend_webgpu_submit( + webgpu_global_context ctx, + std::vector commands, + webgpu_buf_pool & param_buf_pool, + webgpu_buf_pool * set_rows_error_buf_pool = nullptr) { std::vector command_buffers; std::vector params_bufs; std::vector set_rows_error_bufs; @@ -562,7 +630,7 @@ static webgpu_submission_futures ggml_backend_webgpu_submit(webgpu_global_contex futures.push_back({ f }); } #endif - return { futures }; + return futures; } static webgpu_command ggml_backend_webgpu_build_multi( @@ -651,6 +719,7 @@ static webgpu_command ggml_backend_webgpu_build_multi( result.commands = commands; result.params_bufs = params_bufs_list; result.set_rows_error_bufs = set_rows_error_bufs; + result.num_kernels = pipelines.size(); #ifdef GGML_WEBGPU_GPU_PROFILE result.timestamp_query_bufs = ts_bufs; // TODO: handle multiple pipeline names @@ -688,8 +757,7 @@ static void ggml_backend_webgpu_buffer_memset(webgpu_global_context & ctx, webgpu_command command = ggml_backend_webgpu_build(ctx, ctx->memset_buf_pool, ctx->memset_pipelines[0], params, entries, wg_x); - std::vector futures = { ggml_backend_webgpu_submit(ctx, { command }, - ctx->memset_buf_pool) }; + auto futures = ggml_backend_webgpu_submit(ctx, { command }, ctx->memset_buf_pool); ggml_backend_webgpu_wait(ctx, futures); } @@ -788,6 +856,7 @@ static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) { struct binary_overlap_flags { bool inplace; // src0 == dst bool overlap; // src1 == dst + bool src_overlap; }; static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0, @@ -796,6 +865,7 @@ static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0 binary_overlap_flags flags = {}; flags.inplace = ggml_webgpu_tensor_equal(src0, dst); flags.overlap = ggml_webgpu_tensor_overlap(src1, dst); + flags.src_overlap = ggml_webgpu_tensor_overlap(src0, src1); return flags; } @@ -1112,8 +1182,9 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, }; // Calculate workgroup dimensions - uint32_t wg_x = 1; - uint32_t wg_y = 1; + uint32_t wg_x = 1; + uint32_t wg_y = 1; + const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; if (use_fast && is_vec) { auto decisions = static_cast(pipeline.context.get()); @@ -1121,9 +1192,7 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, uint32_t batches = dst->ne[2] * dst->ne[3]; uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); uint32_t total_wg = output_groups * batches; - // TODO: split large sizes into multiple batches to avoid way over-provisioning workgroups - wg_x = std::min(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); - wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } else if (use_fast) { auto decisions = static_cast(pipeline.context.get()); @@ -1142,12 +1211,14 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, wg_m = CEIL_DIV(dst->ne[0], tile_m_s); wg_n = CEIL_DIV(dst->ne[1], tile_n_s); } - wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + uint32_t total_wg = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); + } else { // legacy auto decisions = static_cast(pipeline.context.get()); uint32_t wg_size = decisions->wg_size; - wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); - wg_y = 1; + uint32_t total_wg = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); + compute_2d_workgroups(total_wg, max_wg_per_dim, wg_x, wg_y); } return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); @@ -1353,6 +1424,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, .inplace = flags.inplace, .overlap = flags.overlap, + .src_overlap = flags.src_overlap, }; webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); @@ -1361,11 +1433,28 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, uint32_t ne = (uint32_t) ggml_nelements(dst); + size_t src0_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src0); + size_t src1_webgpu_tensor_align_offset = ggml_webgpu_tensor_align_offset(ctx, src1); + + uint32_t offset_merged_src0 = 0; + uint32_t offset_merged_src1 = 0; + if (flags.src_overlap) { + size_t min_off = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); + offset_merged_src0 = (uint32_t) ((src0_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type)); + offset_merged_src1 = (uint32_t) ((src1_webgpu_tensor_align_offset - min_off) / ggml_type_size(src0->type)); + } + std::vector params = { ne, (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + offset_merged_src0, + offset_merged_src1, + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), @@ -1381,31 +1470,111 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, std::vector entries; - entries.push_back({ - .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0), - }); - - entries.push_back({ - .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1), - }); - - if (!flags.inplace && !flags.overlap) { - entries.push_back({ .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + if (flags.src_overlap) { + size_t merged_offset = std::min(src0_webgpu_tensor_align_offset, src1_webgpu_tensor_align_offset); + size_t merged_end = std::max(src0_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src0), + src1_webgpu_tensor_align_offset + ggml_webgpu_tensor_binding_size(ctx, src1)); + entries.push_back({ + .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = merged_offset, + .size = merged_end - merged_offset, + }); + entries.push_back({ + .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst), + }); + } else { + entries.push_back({ + .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = src0_webgpu_tensor_align_offset, + .size = ggml_webgpu_tensor_binding_size(ctx, src0), + }); + entries.push_back({ + .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = src1_webgpu_tensor_align_offset, + .size = ggml_webgpu_tensor_binding_size(ctx, src1), + }); + if (!flags.inplace && !flags.overlap) { + entries.push_back({ + .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst), + }); + } } uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } +static webgpu_command ggml_webgpu_concat(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * dst) { + uint32_t ne = (uint32_t) ggml_nelements(dst); + uint32_t dim = (uint32_t) dst->op_params[0]; + + std::vector params = { + ne, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[0] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[0] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + dim, + (uint32_t)src0->ne[dim] + }; + + std::vector entries = { + { + .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) + }, + { + .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) + }, + { + .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) + } + }; + + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_concat_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); +} + static webgpu_command ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { int inplace = ggml_webgpu_tensor_equal(src, dst); @@ -1990,6 +2159,8 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx, case GGML_OP_MUL: case GGML_OP_DIV: return ggml_webgpu_binary_op(ctx, src0, src1, node); + case GGML_OP_CONCAT: + return ggml_webgpu_concat(ctx, src0, src1, node); case GGML_OP_RMS_NORM: return ggml_webgpu_rms_norm(ctx, src0, node); case GGML_OP_ROPE: @@ -2043,21 +2214,20 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str WEBGPU_CPU_PROFILE_TOTAL_START(graph_compute); - ctx->global_ctx->inflight_threads++; - - std::vector commands; - std::vector futures; + std::vector commands; + std::vector futures; + uint32_t num_batched_kernels = 0; for (int i = 0; i < cgraph->n_nodes; i++) { if (auto cmd = ggml_webgpu_encode_node(ctx, cgraph->nodes[i])) { commands.push_back(*cmd); + num_batched_kernels += cmd.value().num_kernels; } - // compute the batch size based on the number of inflight threads - uint32_t inflight_threads = ctx->global_ctx->inflight_threads; - uint32_t batch_size = std::min(std::max(1u, WEBGPU_NUM_PARAM_BUFS / std::max(inflight_threads, 1u)), - WEBGPU_COMMAND_SUBMIT_BATCH_SIZE); - if (commands.size() >= batch_size) { - futures.push_back(ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, - &ctx->set_rows_error_buf_pool)); + + if (num_batched_kernels >= WEBGPU_COMMAND_SUBMIT_BATCH_SIZE) { + num_batched_kernels = 0; + std::vector compute_futures = ggml_backend_webgpu_submit( + ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool); + futures.insert(futures.end(), compute_futures.begin(), compute_futures.end()); // Process events and check for completed submissions ctx->global_ctx->instance.ProcessEvents(); ggml_backend_webgpu_wait(ctx->global_ctx, futures, false); @@ -2065,13 +2235,12 @@ static ggml_status ggml_backend_webgpu_graph_compute(ggml_backend_t backend, str } } if (!commands.empty()) { - webgpu_submission_futures new_futures = + auto new_futures = ggml_backend_webgpu_submit(ctx->global_ctx, commands, ctx->param_buf_pool, &ctx->set_rows_error_buf_pool); - futures.push_back(new_futures); + futures.insert(futures.end(), new_futures.begin(), new_futures.end()); } ggml_backend_webgpu_wait(ctx->global_ctx, futures); - ctx->global_ctx->inflight_threads--; WEBGPU_CPU_PROFILE_TOTAL_END(graph_compute, ctx->global_ctx); return GGML_STATUS_SUCCESS; } @@ -2689,7 +2858,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, - wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); + wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite, true); webgpu_ctx->set_rows_error_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_SET_ROWS_ERROR_BUFS, WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, @@ -2816,10 +2985,11 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_SUB: case GGML_OP_MUL: case GGML_OP_DIV: - // TODO: support non-contiguous tensors, e.g. for MOE_EXPERT_REDUCE - // see https://github.com/ggml-org/llama.cpp/pull/16857 supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) && - (src1->type == op->type) && ggml_is_contiguous(src0) && ggml_is_contiguous(src1); + (src1->type == op->type); + break; + case GGML_OP_CONCAT: + supports_op = (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32); break; case GGML_OP_CPY: case GGML_OP_CONT: diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl index 55dd66408a3..a748dc1b86c 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/binary.wgsl @@ -7,6 +7,13 @@ struct Params { offset_src0: u32, offset_src1: u32, offset_dst: u32, + offset_merged_src0: u32, + offset_merged_src1: u32, + + stride_src0_0: u32, + stride_src0_1: u32, + stride_src0_2: u32, + stride_src0_3: u32, stride_src1_0: u32, stride_src1_1: u32, @@ -23,6 +30,21 @@ struct Params { b_ne3: u32, }; +fn src0_index(_i: u32) -> u32 { + var i = _i; + let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); + i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0); + let a_i2 = i / (params.a_ne1 * params.a_ne0); + i = i % (params.a_ne1 * params.a_ne0); + let a_i1 = i / params.a_ne0; + let a_i0 = i % params.a_ne0; + + return a_i0 * params.stride_src0_0 + + a_i1 * params.stride_src0_1 + + a_i2 * params.stride_src0_2 + + a_i3 * params.stride_src0_3; +} + fn src1_index(_i: u32) -> u32 { var i = _i; let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0); @@ -53,17 +75,22 @@ fn src1_index(_i: u32) -> u32 { #define DataType f16 #endif +#ifdef SRC_OVERLAP @group(0) @binding(0) -var src0: array; +var merged_src: array; @group(0) @binding(1) -var src1 : array; +var dst: array; -#ifdef INPLACE @group(0) @binding(2) var params: Params; +#else +@group(0) @binding(0) +var src0: array; -#elif defined(OVERLAP) +@group(0) @binding(1) +var src1 : array; +#if defined(INPLACE) || defined(OVERLAP) @group(0) @binding(2) var params: Params; @@ -74,6 +101,7 @@ var dst: array; @group(0) @binding(3) var params: Params; #endif +#endif fn op(a: DataType, b: DataType) -> DataType { #ifdef OP_ADD @@ -87,13 +115,17 @@ fn op(a: DataType, b: DataType) -> DataType { #endif } -fn update(dst_i: u32, src0_i: u32, src1_i: u32){ +fn update(dst_i: u32, src0_i: u32, src1_i: u32) { +#ifdef SRC_OVERLAP + let result = op(merged_src[src0_i], merged_src[src1_i]); +#else let result = op(src0[src0_i], src1[src1_i]); +#endif #ifdef INPLACE - src0[dst_i] = result; + src0[src0_i] = result; #elif defined(OVERLAP) - src1[dst_i] = result; + src1[src1_i] = result; #else dst[dst_i] = result; #endif @@ -102,6 +134,8 @@ fn update(dst_i: u32, src0_i: u32, src1_i: u32){ @compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x < params.ne) { - update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x)); + let src0_i = params.offset_src0 + params.offset_merged_src0 + src0_index(gid.x); + let src1_i = params.offset_src1 + params.offset_merged_src1 + src1_index(gid.x); + update(params.offset_dst + gid.x, src0_i, src1_i); } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl new file mode 100644 index 00000000000..a22d245d2cc --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/concat.wgsl @@ -0,0 +1,75 @@ +struct Params { + ne: u32, + + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + stride_src0_0: u32, + stride_src0_1: u32, + stride_src0_2: u32, + stride_src0_3: u32, + + stride_src1_0: u32, + stride_src1_1: u32, + stride_src1_2: u32, + stride_src1_3: u32, + + ne0: u32, + ne1: u32, + ne2: u32, + ne3: u32, + + dim: u32, + src0_nedim: u32 +}; + +#ifdef TYPE_F32 +#define DataType f32 +#endif +#ifdef TYPE_I32 +#define DataType i32 +#endif + +@group(0) @binding(0) +var src0: array; + +@group(0) @binding(1) +var src1 : array; + +@group(0) @binding(2) +var dst: array; + +@group(0) @binding(3) +var params: Params; + +@compute @workgroup_size(WG_SIZE) +fn main(@builtin(global_invocation_id) gid: vec3) { + + if (gid.x < params.ne) { + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + var ni = array(i0, i1, i2, i3); + + if (ni[params.dim] < params.src0_nedim) { + let src_i = ni[0] * params.stride_src0_0 + + ni[1] * params.stride_src0_1 + + ni[2] * params.stride_src0_2 + + ni[3] * params.stride_src0_3; + dst[params.offset_dst + gid.x] = src0[params.offset_src0 + src_i]; + } else { + ni[params.dim] -= params.src0_nedim; + let src_i = ni[0] * params.stride_src1_0 + + ni[1] * params.stride_src1_1 + + ni[2] * params.stride_src1_2 + + ni[3] * params.stride_src1_3; + dst[params.offset_dst + gid.x] = src1[params.offset_src1 + src_i]; + } + } +} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl index 6aba47317c6..5b9f5b36224 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl @@ -679,19 +679,24 @@ struct MulMatParams { @group(0) @binding(3) var params: MulMatParams; @compute @workgroup_size(256) -fn main(@builtin(global_invocation_id) global_id: vec3) { +fn main(@builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { + let wg_linear = wg_id.y * num_wg.x + wg_id.x; + let global_idx = wg_linear * 256u + local_id.x; + let total = params.m * params.n * params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; - if (global_id.x >= total) { + if (global_idx >= total) { return; } let dst2_stride = params.m * params.n; let dst3_stride = dst2_stride * params.bs02 * params.broadcast2; - let dst3_idx = global_id.x / dst3_stride; + let dst3_idx = global_idx / dst3_stride; let src03_idx = dst3_idx / params.broadcast3; // src0 may be broadcast along the third dimension let src13_idx = dst3_idx; // src1 is not broadcast - let dst3_rem = global_id.x % dst3_stride; + let dst3_rem = global_idx % dst3_stride; let dst2_idx = dst3_rem / dst2_stride; let src02_idx = dst2_idx / params.broadcast2; // src0 may also be broadcast along the second dimension diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl index 771e5cd1ee3..761e3017c14 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl @@ -54,7 +54,8 @@ var shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, - @builtin(local_invocation_id) local_id: vec3) { + @builtin(local_invocation_id) local_id: vec3, + @builtin(num_workgroups) num_wg: vec3) { let thread_id = local_id.x; let local_m = get_local_m(thread_id); @@ -64,9 +65,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let wg_m_count = (params.m + WORKGROUP_SIZE_M * TILE_M - 1u) / (WORKGROUP_SIZE_M * TILE_M); let wg_per_matrix = wg_m_count * wg_n_count; - let batch_idx = wg_id.x / wg_per_matrix; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; - let wg_in_batch = wg_id.x % wg_per_matrix; + let batch_idx = wg_linear / wg_per_matrix; + + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + if (batch_idx >= total_batches) { + return; + } + + let wg_in_batch = wg_linear % wg_per_matrix; let wg_m = wg_in_batch % wg_m_count; let wg_n = wg_in_batch / wg_m_count; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl index 64529e03cdc..9f9ef279f29 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl @@ -69,7 +69,8 @@ var shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) fn main(@builtin(workgroup_id) wg_id: vec3, @builtin(local_invocation_id) local_id: vec3, - @builtin(subgroup_id) subgroup_id: u32) { + @builtin(subgroup_id) subgroup_id: u32, + @builtin(num_workgroups) num_wg: vec3) { let thread_id = local_id.x; let subgroup_m = subgroup_id % SUBGROUP_M; @@ -79,9 +80,16 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let wg_n_count = (params.n + WG_N_SG_TILE_SIZE - 1) / WG_N_SG_TILE_SIZE; let wg_per_matrix = wg_m_count * wg_n_count; - let batch_idx = wg_id.x / wg_per_matrix; + let wg_linear = wg_id.y * num_wg.x + wg_id.x; - let wg_in_batch = wg_id.x % wg_per_matrix; + let batch_idx = wg_linear / wg_per_matrix; + + let total_batches = params.bs02 * params.broadcast2 * params.bs03 * params.broadcast3; + if (batch_idx >= total_batches) { + return; + } + + let wg_in_batch = wg_linear % wg_per_matrix; let wg_m = wg_in_batch % wg_m_count; let wg_n = wg_in_batch / wg_m_count; diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index e9529fbb662..d644cca8a6e 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1410,16 +1410,14 @@ static bool ggml_is_contiguous_n(const struct ggml_tensor * tensor, int n) { } next_nb *= tensor->ne[0]/ggml_blck_size(tensor->type); for (int i = 1; i < GGML_MAX_DIMS; i++) { - if (tensor->ne[i] != 1) { - if (i > n) { - if (tensor->nb[i] != next_nb) { - return false; - } - next_nb *= tensor->ne[i]; - } else { - // this dimension does not need to be contiguous - next_nb = tensor->ne[i]*tensor->nb[i]; + if (i > n) { + if (tensor->ne[i] != 1 && tensor->nb[i] != next_nb) { + return false; } + next_nb *= tensor->ne[i]; + } else { + // this dimension does not need to be contiguous + next_nb = tensor->ne[i]*tensor->nb[i]; } } return true; diff --git a/gguf-py/gguf/metadata.py b/gguf-py/gguf/metadata.py index e0d478ce95d..e954644e28f 100644 --- a/gguf-py/gguf/metadata.py +++ b/gguf-py/gguf/metadata.py @@ -186,7 +186,7 @@ def load_model_card(model_path: Optional[Path] = None) -> dict[str, Any]: # Quick hack to fix the Norway problem # https://hitchdev.com/strictyaml/why/implicit-typing-removed/ yaml_content = yaml_content.replace("- no\n", "- \"no\"\n") - # yaml should use 2 spaces insted of tab + # yaml should use 2 spaces instead of tab # this issue has came up with the Qwen/Qwen3-235B-A22B-Instruct-2507 model card # (I've also sent a pr tp fix the modelcard too) yaml_content = yaml_content.replace("\t", " ") diff --git a/gguf-py/tests/test_metadata.py b/gguf-py/tests/test_metadata.py index 40d484f4eaa..b77c563ff25 100755 --- a/gguf-py/tests/test_metadata.py +++ b/gguf-py/tests/test_metadata.py @@ -164,7 +164,7 @@ def test_get_model_id_components(self): self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B"), ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration-LoRA', None, '8B')) - # Negative size --> output is a LoRA adaper --> prune "LoRA" out of the name to avoid redundancy with the suffix + # Negative size --> output is a LoRA adapter --> prune "LoRA" out of the name to avoid redundancy with the suffix self.assertEqual(gguf.Metadata.get_model_id_components("Llama-3-Instruct-abliteration-LoRA-8B", -1234), ('Llama-3-Instruct-abliteration-LoRA-8B', None, 'Llama-3', 'Instruct-abliteration', None, '8B')) diff --git a/include/llama.h b/include/llama.h index 077f66dc651..a84d56a8850 100644 --- a/include/llama.h +++ b/include/llama.h @@ -973,7 +973,7 @@ extern "C" { // Logits for the ith token. For positive indices, Equivalent to: // llama_get_logits(ctx) + ctx->output_ids[i]*n_vocab - // Negative indicies can be used to access logits in reverse order, -1 is the last logit. + // Negative indices can be used to access logits in reverse order, -1 is the last logit. // returns NULL for invalid ids. LLAMA_API float * llama_get_logits_ith(struct llama_context * ctx, int32_t i); @@ -988,7 +988,7 @@ extern "C" { // Get the embeddings for the ith token. For positive indices, Equivalent to: // llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd - // Negative indicies can be used to access embeddings in reverse order, -1 is the last embedding. + // Negative indices can be used to access embeddings in reverse order, -1 is the last embedding. // shape: [n_embd] (1-dimensional) // returns NULL for invalid ids. LLAMA_API float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i); @@ -1008,9 +1008,9 @@ extern "C" { // Returns LLAMA_TOKEN_NULL if no token was sampled. LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i); - // Get the backend sampled probabilites for the ith token + // Get the backend sampled probabilities for the ith token // The index matches llama_get_sampled_token_ith(). - // Returns NULL if no probabilites were generated. + // Returns NULL if no probabilities were generated. LLAMA_API float * llama_get_sampled_probs_ith (struct llama_context * ctx, int32_t i); LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i); @@ -1337,7 +1337,7 @@ extern "C" { float tau, float eta); - /// @details Intializes a GBNF grammar, see grammars/README.md for details. + /// @details Initializes a GBNF grammar, see grammars/README.md for details. /// @param vocab The vocabulary that this grammar will be used with. /// @param grammar_str The production rules for the grammar, encoded as a string. Returns an empty grammar if empty. Returns NULL if parsing of grammar_str fails. /// @param grammar_root The name of the start symbol for the grammar. diff --git a/models/templates/Apertus-8B-Instruct.jinja b/models/templates/Apertus-8B-Instruct.jinja index 10826ff6901..432ae59a406 100644 --- a/models/templates/Apertus-8B-Instruct.jinja +++ b/models/templates/Apertus-8B-Instruct.jinja @@ -97,20 +97,20 @@ {%- macro render_tools(tools) -%} {%- for tool in tools %} - {{- "// " + tool.description + "\n" }} - {{- "type "+ tool.name + " = " }} - {%- if tool.parameters and tool.parameters.properties %} + {{- "// " + tool.function.description + "\n" }} + {{- "type "+ tool.function.name + " = " }} + {%- if tool.function.parameters and tool.function.parameters.properties %} {{- "(_: {\n" }} - {%- for param_name, param_spec in tool.parameters.properties.items() %} + {%- for param_name, param_spec in tool.function.parameters.properties.items() %} {%- if param_spec.description %} {{- "// " + param_spec.description + "\n" }} {%- endif %} {{- param_name }} - {%- if param_name not in (tool.parameters.required or []) -%} + {%- if param_name not in (tool.function.parameters.required or []) -%} {{- "?" }} {%- endif -%} {{- ": " }} - {{- render_typescript_type(param_spec, tool.parameters.required or []) }} + {{- render_typescript_type(param_spec, tool.function.parameters.required or []) }} {%- if param_spec.default is defined -%} {%- if param_spec.enum %} {{- ", // default: " + param_spec.default }} @@ -294,7 +294,7 @@ {%- for tool_call in message.tool_calls -%} {%- if tool_call.type == 'function' -%} {%- set function = tool_call.function -%} - {{- '{"' + function.name + '": ' + function.arguments + '}' }} + {{- '{"' + function.name + '": ' + function.arguments|tojson + '}' }} {%- if not loop.last -%} {{- ", " }} {%- endif -%} diff --git a/models/templates/Apriel-1.6-15b-Thinker-fixed.jinja b/models/templates/Apriel-1.6-15b-Thinker-fixed.jinja new file mode 100755 index 00000000000..a60a95f44d2 --- /dev/null +++ b/models/templates/Apriel-1.6-15b-Thinker-fixed.jinja @@ -0,0 +1,172 @@ +{# ---------------------------------------------------------------------- #} +{# ƛƬ Default setup and flags #} +{# ---------------------------------------------------------------------- #} +{%- set messages = messages or [] -%} +{%- set tools = tools or [] -%} +{%- set add_generation_prompt = add_generation_prompt or false -%} +{%- set available_tool_string = '' -%} +{%- set add_tool_id = true -%} +{%- set add_thoughts = true -%} {# whether to include reasoning blocks #} +{%- set add_generation_prompt = true -%} {# whether to emit reasoning starter before assistant response #} +{# Optional token placeholders (safe defaults) #} +{%- set bos_token = bos_token or '' -%} +{%- set eos_token = eos_token or '' -%} +{# ---------------------------------------------------------------------- #} +{# Core reasoning prompt and assistant reasoning prefix #} +{# ---------------------------------------------------------------------- #} +{%- set reasoning_prompt -%} + You are a thoughtful, systematic AI assistant from ServiceNow Language Models (SLAM) lab. + Analyze each question carefully, present your reasoning step-by-step, then provide the final + response after the marker [BEGIN FINAL RESPONSE]. +{%- endset -%} +{%- set reasoning_asst_turn_start = 'Here are my reasoning steps:\n' -%} +{# ---------------------------------------------------------------------- #} +{# Tool list and tool call output format #} +{# ---------------------------------------------------------------------- #} +{%- if tools|length > 0 -%} + {%- set available_tool_string -%} + You are provided with function signatures within XML tags. + You may call one or more functions to assist with the user query. + Don't make assumptions about the arguments. You should infer the argument values from previous + user responses and the system message. + Here are the available tools: + + {% for tool in tools %}{{ tool|string }}{% endfor %} + + . + + Return all function calls as a list of JSON objects within XML tags. + Each JSON object should contain a function name and arguments as follows: + [ + {"name": , "arguments": }, + {"name": , "arguments": }, + ... + ] + {%- endset -%} +{%- endif -%} +{# ---------------------------------------------------------------------- #} +{# Start system block if first message is not system #} +{# ---------------------------------------------------------------------- #} +{%- if messages|length > 0 and messages[0]['role'] != 'system' -%} + {%- if tools|length > 0 -%} + {{ bos_token + '<|begin_system|>\n' + reasoning_prompt + '\n' + available_tool_string + '\n' }} + {%- else -%} + {{ bos_token + '<|begin_system|>\n' + reasoning_prompt + '\n' }} + {%- endif -%} +{%- endif -%} +{# ---------------------------------------------------------------------- #} +{# Iterate through messages #} +{# ---------------------------------------------------------------------- #} +{%- for message in messages -%} + + {# ---------------- USER MESSAGE ---------------- #} + {%- if message['role'] == 'user' -%} + {{ '<|begin_user|>\n' }} + {%- if message['content'] is not string -%} + {%- for chunk in message['content'] -%} + {%- if chunk['type'] == 'text' -%} + {{ chunk['text'] }} + {%- elif chunk['type'] in ['image', 'image_url'] -%} + {{ '[IMG]' }} + {%- else -%} + {{ raise_exception('Unrecognized content type!') }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{ message['content'] }} + {%- endif -%} + + {# ---------------- SYSTEM MESSAGE ---------------- #} + {%- elif message['role'] == 'system' -%} + {%- set sys_content = message.get('content', '') -%} + {%- if sys_content and sys_content|length > 0 -%} + {%- if sys_content is string -%} + {%- set system_message = sys_content -%} + {%- else -%} + {%- set system_message = sys_content[0]['text'] -%} + {%- endif -%} + {%- else -%} + {%- set system_message = '' -%} + {%- endif -%} + + {%- if tools|length > 0 -%} + {{ bos_token + '<|begin_system|>\n' + reasoning_prompt + '\n' + system_message + '\n' + available_tool_string + '\n' }} + {%- else -%} + {{ bos_token + '<|begin_system|>\n' + reasoning_prompt + '\n' + system_message + '\n' }} + {%- endif -%} + + {# ---------------- ASSISTANT MESSAGE ---------------- #} + {%- elif message['role'] == 'assistant' -%} + {%- if loop.last -%} + {%- set add_tool_id = false -%} + {%- endif -%} + + {{ '\n<|begin_assistant|>\n' }} + + {%- if add_thoughts and message.get('reasoning_content') and loop.last -%} + {{ message['reasoning_content'] + '\n[BEGIN FINAL RESPONSE]\n' }} + {%- endif -%} + + {%- set asst_content = message.get('content', '') -%} + {%- if asst_content and asst_content|length > 0 -%} + {%- if asst_content is not string -%} + {%- set asst_text = asst_content[0]['text'] -%} + {%- else -%} + {%- set asst_text = asst_content -%} + {%- endif -%} + {# For historical turns (not the last), strip reasoning and keep only final response #} + {%- if not loop.last and '[BEGIN FINAL RESPONSE]' in asst_text -%} + {{- asst_text.split('[BEGIN FINAL RESPONSE]')[-1] | trim -}} + {%- else -%} + {{- asst_text -}} + {%- endif -%} + {%- elif message.get('chosen') and message['chosen']|length > 0 -%} + {{ message['chosen'][0] }} + {%- endif -%} + + {# Tool call output #} + {%- set tool_calls = message.get('tool_calls', []) -%} + {%- if tool_calls and tool_calls|length > 0 -%} + {{ '\n[' }} + {%- for tool_call in tool_calls -%} + {{ '{"name": "' + tool_call['function']['name'] + '", "arguments": ' + tool_call['function']['arguments']|tojson }} + {%- if add_tool_id == true and 'id' in tool_call -%} + {{ ', "id": "' + tool_call['id'] + '"' }} + {%- endif -%} + {{ '}' }} + {%- if not loop.last -%}{{ ', ' }}{%- endif -%} + {%- endfor -%} + {{ ']' }} + {%- endif -%} + + {%- set training_prompt = training_prompt if (training_prompt is defined) else false -%} + {%- if not loop.last or training_prompt -%} + {{ '\n<|end|>\n' }} + {%- endif -%} + + {# ---------------- TOOL RESULT MESSAGE ---------------- #} + {%- elif message['role'] == 'tool' -%} + {%- set tool_content = message.get('content', '') -%} + {%- if tool_content is string -%} + {%- set tool_message = tool_content -%} + {%- else -%} + {%- set tool_message = tool_content[0]['text'] if tool_content else '' -%} + {%- endif -%} + {{ '<|begin_tool_result|>\n' + tool_message|string + '\n' }} + + {# ---------------- CONTENT MESSAGE ---------------- #} + {%- elif message['role'] == 'content' -%} + {%- set msg_content = message.get('content', '') -%} + {%- if msg_content is not string -%} + {{ '<|begin_content|>\n' + msg_content[0]['text'] + '\n' }} + {%- else -%} + {{ '<|begin_content|>\n' + msg_content + '\n' }} + {%- endif -%} + {%- endif -%} + + {# ---------------- REASONING PROMPT BEFORE NEXT ASSISTANT ---------------- #} + {%- if loop.last and add_generation_prompt and message['role'] != 'assistant' -%} + {{ '\n<|begin_assistant|>\n' + reasoning_asst_turn_start }} + {%- endif -%} + +{%- endfor -%} diff --git a/models/templates/Bielik-11B-v3.0-Instruct.jinja b/models/templates/Bielik-11B-v3.0-Instruct.jinja new file mode 100644 index 00000000000..40ef50076eb --- /dev/null +++ b/models/templates/Bielik-11B-v3.0-Instruct.jinja @@ -0,0 +1,77 @@ +{{ bos_token }} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content'] %} + {%- set loop_start_index = 1 %} +{%- else %} + {%- set system_message = "" %} + {%- set loop_start_index = 0 %} +{%- endif %} + +{%- if system_message or tools %} + {{- '<|im_start|>system\n' }} + {%- if system_message %} + {{- system_message }} + {%- endif %} + {%- if tools %} + {{- '\n\nMasz dostęp do następujących narzędzi. Definicje narzędzi znajdują się poniżej wewnątrz znaczników <|function_list|>:\n<|function_list|>\n' }} + {{- '[' }} + {%- for tool in tools %} + {{- tool | tojson }} + {%- if not loop.last %} + {{- ',\n' }} + {%- endif %} + {%- endfor %} + {{- ']\n<|function_list|>\n\nAby wywołać narzędzie, użyj formatu {"name": "nazwa_narzędzia", "arguments": {"argument": "wartość"}}. Wyniki działania narzędzi zostaną przekazane z markerem <|function_output|>.\n' }} + {%- endif %} + {%- if enable_thinking %} + {{- '\n\nZanim odpowiesz na pytanie, najpierw przemyśl swoje kroki i umieść swoje myśli wewnątrz tagów .... Musisz najpierw pomyśleć, zanim udzielisz odpowiedzi. ' -}} + {{- 'WAŻNE: Powinieneś myśleć w tym samym języku, co pytanie użytkownika. Jeśli pytanie jest zadane po polsku, powinieneś również myśleć po polsku. Jeśli pytanie jest po angielsku, myślisz również po angielsku itd. ' }} + {{- '** PAMIĘTAJ! ** Pytanie po polsku -> myślenie po polsku -> odpowiedź po polsku!' -}} + {{- '*** BARDZO WAŻNE!!! *** Jesteś Bielikiem, polskim modelem językowym. Twoją główną cechą jest umiejętność pisania po polsku. Jeśli użytkownik zadaje Ci pytania po polsku, ZAWSZE odpowiadaj po polsku. ' -}} + {{- 'Nawet, jeśli korzystasz z narzędzia, którego większość instrukcji jest po angielsku, powinieneś przede wszystkim odpowiadać po polsku, jeśli użytkownik zadaje pytanie w tym języku. ' -}} + {%- endif %} + {{- '<|im_end|>\n' }} +{%- endif %} + +{%- for message in messages[loop_start_index:] %} + {%- if message['role'] == 'user' %} + {{- '<|im_start|>user\n' + message['content'] + '<|im_end|>\n' }} + {%- elif message['role'] == 'assistant' %} + {{- '<|im_start|>assistant\n' }} + {%- set content = message.content | default('') %} + {%- set reasoning_content = message.reasoning_content | default('') %} + {%- if not reasoning_content and '' in content and '' in content %} + {%- set reasoning_parts = content.split('') %} + {%- set reasoning_content = reasoning_parts[0].split('')[-1] %} + {%- set content = reasoning_parts[1:] | join('') %} + {%- endif %} + {%- if reasoning_content %} + {{- '\n' + reasoning_content.strip() + '\n\n' }} + {%- endif %} + {{- content.lstrip() }} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n{"name": "' + tool_call.name + '", "arguments": ' + (tool_call.arguments if tool_call.arguments is string else tool_call.arguments | tojson) + '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message['role'] == 'tool' %} + {%- if loop.index0 == 0 or messages[loop.index0 - 1]['role'] != 'tool' %} + {{- '<|im_start|>user\n' }} + {%- endif %} + {{- '<|function_output|>' + message['content'] }} + {%- if loop.last or messages[loop.index0 + 1]['role'] != 'tool' %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} + +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking %} + {{- '\n' }} + {%- endif %} +{%- endif %} \ No newline at end of file diff --git a/models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja b/models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja index 078e9f5458e..fcf1259d33c 100644 --- a/models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja +++ b/models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja @@ -132,7 +132,7 @@ The following instructions take precedence over instructions in the default prea {%- elif message.role|lower == 'user' %} <|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{ message.content }}<|END_OF_TURN_TOKEN|>{%- if documents and not sent_documents.value %}{%- set sent_documents.value = true %}{% set tool_idx.value = tool_idx.value + 1 %}{{ document_turn(documents) }}{% endif %} {%- elif message.role|lower == 'assistant' or message.role|lower == 'chatbot' %} -<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.tool_plan}}<|END_THINKING|><|START_ACTION|>[ +<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{% if message.tool_calls %}<|START_THINKING|>{{message.reasoning_content}}<|END_THINKING|><|START_ACTION|>[ {% for tc in message.tool_calls %} {"tool_call_id": "{{ tool_idx.value }}", "tool_name": "{{ tc['function']['name'] }}", "parameters": {{ tc['function']['arguments']|tojson }}}{% if not loop.last %},{% endif %} @@ -153,4 +153,4 @@ The following instructions take precedence over instructions in the default prea ]<|END_TOOL_RESULT|><|END_OF_TURN_TOKEN|> {%- endif %} -{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|> \ No newline at end of file +{%- endfor %}<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>{%- if not enable_thinking -%}<|START_THINKING|><|END_THINKING|>{%- endif %} \ No newline at end of file diff --git a/models/templates/GLM-4.7-Flash.jinja b/models/templates/GLM-4.7-Flash.jinja new file mode 100644 index 00000000000..2ab98ef068d --- /dev/null +++ b/models/templates/GLM-4.7-Flash.jinja @@ -0,0 +1,86 @@ +[gMASK] +{%- if tools -%} +<|system|> +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{% for tool in tools %} +{{ tool | tojson(ensure_ascii=False) }} +{% endfor %} + + +For each function call, output the function name and arguments within the following XML format: +{function-name}{arg-key-1}{arg-value-1}{arg-key-2}{arg-value-2}...{%- endif -%} +{%- macro visible_text(content) -%} + {%- if content is string -%} + {{- content }} + {%- elif content is iterable and content is not mapping -%} + {%- for item in content -%} + {%- if item is mapping and item.type == 'text' -%} + {{- item.text }} + {%- elif item is string -%} + {{- item }} + {%- endif -%} + {%- endfor -%} + {%- else -%} + {{- content }} + {%- endif -%} +{%- endmacro -%} +{%- set ns = namespace(last_user_index=-1) %} +{%- for m in messages %} + {%- if m.role == 'user' %} + {% set ns.last_user_index = loop.index0 -%} + {%- endif %} +{%- endfor %} +{% for m in messages %} +{%- if m.role == 'user' -%}<|user|>{{ visible_text(m.content) }} +{%- elif m.role == 'assistant' -%} +<|assistant|> +{%- set reasoning_content = '' %} +{%- set content = visible_text(m.content) %} +{%- if m.reasoning_content is string %} + {%- set reasoning_content = m.reasoning_content %} +{%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} +{%- endif %} +{%- if ((clear_thinking is defined and not clear_thinking) or loop.index0 > ns.last_user_index) and reasoning_content -%} +{{ '' + reasoning_content.strip() + ''}} +{%- else -%} +{{ '' }} +{%- endif -%} +{%- if content.strip() -%} +{{ content.strip() }} +{%- endif -%} +{% if m.tool_calls %} +{% for tc in m.tool_calls %} +{%- if tc.function %} + {%- set tc = tc.function %} +{%- endif %} +{{- '' + tc.name -}} +{% set _args = tc.arguments %}{% for k, v in _args.items() %}{{ k }}{{ v | tojson(ensure_ascii=False) if v is not string else v }}{% endfor %}{% endfor %} +{% endif %} +{%- elif m.role == 'tool' -%} +{%- if m.content is string -%} +{%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|observation|>' }} +{%- endif %} +{{- '' }} +{{- m.content }} +{{- '' }} +{%- else -%} +<|observation|>{% for tr in m.content %} +{{ tr.output if tr.output is defined else tr }}{% endfor -%} +{% endif -%} +{%- elif m.role == 'system' -%} +<|system|>{{ visible_text(m.content) }} +{%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt -%} + <|assistant|>{{- '' if (enable_thinking is defined and not enable_thinking) else '' -}} +{%- endif -%} \ No newline at end of file diff --git a/models/templates/LFM2-8B-A1B.jinja b/models/templates/LFM2-8B-A1B.jinja new file mode 100644 index 00000000000..3738b3d145b --- /dev/null +++ b/models/templates/LFM2-8B-A1B.jinja @@ -0,0 +1,47 @@ +{{- bos_token -}} +{%- set system_prompt = "" -%} +{%- set ns = namespace(system_prompt="") -%} +{%- if messages[0]["role"] == "system" -%} + {%- set ns.system_prompt = messages[0]["content"] -%} + {%- set messages = messages[1:] -%} +{%- endif -%} +{%- if tools -%} + {%- set ns.system_prompt = ns.system_prompt + ("\n" if ns.system_prompt else "") + "You can use the following tools: <|tool_list_start|>[" -%} + {%- for tool in tools -%} + {%- if tool is not string -%} + {%- set tool = tool | tojson -%} + {%- endif -%} + {%- set ns.system_prompt = ns.system_prompt + tool -%} + {%- if not loop.last -%} + {%- set ns.system_prompt = ns.system_prompt + ", " -%} + {%- endif -%} + {%- endfor -%} + {%- set ns.system_prompt = ns.system_prompt + "]<|tool_list_end|>" -%} + {{- '**IMPORTANT**: The syntax for calling the tools is: <|tool_call_start|>JSON tool call goes here<|tool_call_end|>. Please only call tools in the specified manner.' -}} +{%- endif -%} +{%- if ns.system_prompt -%} + {{- "<|im_start|>system\n" + ns.system_prompt + "<|im_end|>\n" -}} +{%- endif -%} +{%- for message in messages -%} + {{- "<|im_start|>" + message["role"] + "\n" -}} + {%- set content = message["content"] -%} + {%- if content is not string -%} + {%- set content = content | tojson -%} + {%- endif -%} + {%- if message["role"] == "tool" -%} + {%- set content = "<|tool_response_start|>" + content + "<|tool_response_end|>" -%} + {%- elif message["role"] == "assistant" -%} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n<|tool_call_start|>\n{"name": "' + tool_call.name + '", "arguments": ' + (tool_call.arguments if tool_call.arguments is string else tool_call.arguments | tojson) + '}\n<|tool_call_end|>\n' }} + {%- endfor %} + {%- endif %} + {%- endif -%} + {{- content + "<|im_end|>\n" -}} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{- "<|im_start|>assistant\n" -}} +{%- endif -%} diff --git a/models/templates/Qwen-QwQ-32B.jinja b/models/templates/Qwen-QwQ-32B.jinja index d475f706873..ce314a039f6 100644 --- a/models/templates/Qwen-QwQ-32B.jinja +++ b/models/templates/Qwen-QwQ-32B.jinja @@ -59,4 +59,5 @@ {%- endfor %} {%- if add_generation_prompt %} {{- '<|im_start|>assistant\n\n' }} + {%- if not enable_thinking -%}{{- '' -}}{%- endif -%} {%- endif %} diff --git a/models/templates/Qwen3-Coder.jinja b/models/templates/Qwen3-Coder.jinja index 49b0e8d0ee7..cde8c0e43db 100644 --- a/models/templates/Qwen3-Coder.jinja +++ b/models/templates/Qwen3-Coder.jinja @@ -29,7 +29,7 @@ {%- endif %} {%- endif %} {%- if tools is iterable and tools | length > 0 %} - {{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }} + {{- "\n\n# Tools\n\nYou have access to the following tools:\n\n" }} {{- "" }} {%- for tool in tools %} {%- if tool.function is defined %} @@ -63,7 +63,7 @@ {{- '\n' }} {%- endfor %} {{- "\n" }} - {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} + {{- '\n\nIf you choose to call a tool ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nvalue_2\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: the tool calling block MUST begin with an opening tag and end with a closing tag.\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n' }} {%- endif %} {%- if system_message is defined %} {{- '<|im_end|>\n' }} diff --git a/models/templates/StepFun3.5-Flash.jinja b/models/templates/StepFun3.5-Flash.jinja new file mode 100644 index 00000000000..c09ea497dad --- /dev/null +++ b/models/templates/StepFun3.5-Flash.jinja @@ -0,0 +1,80 @@ +{% macro render_content(content) %}{% if content is none %}{{- '' }}{% elif content is string %}{{- content }}{% elif content is mapping %}{{- content['value'] if 'value' in content else content['text'] }}{% elif content is iterable %}{% for item in content %}{% if item.type == 'text' %}{{- item['value'] if 'value' in item else item['text'] }}{% elif item.type == 'image' %}{% endif %}{% endfor %}{% endif %}{% endmacro %} +{{bos_token}}{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- render_content(messages[0].content) + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou have access to the following functions in JSONSchema format:\n\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson(ensure_ascii=False) }} + {%- endfor %} + {{- "\n\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner \n...\n block must be nested within \n...\n XML tags\n- Required parameters MUST be specified\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + render_content(messages[0].content) + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and render_content(message.content) is string and not(render_content(message.content).startswith('') and render_content(message.content).endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- set content = render_content(message.content) %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {%- set role_name = 'observation' if (message.role == "system" and not loop.first and message.name == 'observation') else message.role %} + {{- '<|im_start|>' + role_name + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = render_content(message.reasoning_content) %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- else %} + {%- set reasoning_content = '' %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content + '\n\n' + content }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if tool_call.function is defined %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n\n' }} + {%- if tool_call.arguments is defined %} + {%- set arguments = tool_call.arguments %} + {%- for args_name, args_value in arguments|items %} + {{- '\n' }} + {%- set args_value = args_value | tojson(ensure_ascii=False) | safe if args_value is mapping or (args_value is sequence and args_value is not string) else args_value | string %} + {{- args_value }} + {{- '\n\n' }} + {%- endfor %} + {%- endif %} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>tool_response\n' }} + {%- endif %} + {{- '' }} + {{- content }} + {{- '' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n\n' }} +{%- endif %} diff --git a/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja index c2066bd7391..299f7a7ff12 100644 --- a/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja +++ b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja @@ -1 +1,44 @@ -{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>\n'}}{% endif %} \ No newline at end of file +{% if not add_generation_prompt is defined -%} + {%- set add_generation_prompt = false -%} +{%- endif -%} +{%- set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') -%} +{%- for message in messages -%} + {%- if message['role'] == 'system' -%} + {%- set ns.system_prompt = message['content'] -%} + {%- endif -%} +{%- endfor -%}{{bos_token}}{{ns.system_prompt}} +{%- for message in messages -%} + {%- if message['role'] == 'user' -%} + {%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}} + {%- endif -%} + {%- if message['role'] == 'assistant' and message['content'] is none -%} + {%- set ns.is_tool = false -%} + {%- for tool in message['tool_calls']-%} + {%- if not ns.is_first -%}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}} + {%- set ns.is_first = true -%} + {%- else -%}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + {%- if message['role'] == 'assistant' and message['content'] is not none -%} + {%- if ns.is_tool -%}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}} + {%- set ns.is_tool = false -%} + {%- else -%} + {%- set content = message['content'] -%} + {%- if '' in content -%} + {%- set content = content.split('')[-1] -%} + {%- endif -%}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}} + {%- endif -%} + {%- endif -%} + {%- if message['role'] == 'tool' -%} + {%- set ns.is_tool = true -%} + {%- if ns.is_output_first -%}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- set ns.is_output_first = false -%} + {%- else -%}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- endif -%} + {%- endif -%} +{%- endfor -%} +{%- if ns.is_tool -%}{{'<|tool▁outputs▁end|>'}} +{%- endif -%} +{%- if add_generation_prompt and not ns.is_tool -%}{{'<|Assistant|>\n'}} +{%- endif %} \ No newline at end of file diff --git a/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja index c2066bd7391..9e6ec845d39 100644 --- a/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja +++ b/models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja @@ -1 +1,47 @@ -{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') %}{%- for message in messages %}{%- if message['role'] == 'system' %}{% set ns.system_prompt = message['content'] %}{%- endif %}{%- endfor %}{{bos_token}}{{ns.system_prompt}}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is none %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls']%}{%- if not ns.is_first %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{%- set ns.is_first = true -%}{%- else %}{{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- endfor %}{%- endif %}{%- if message['role'] == 'assistant' and message['content'] is not none %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>\n'}}{% endif %} \ No newline at end of file +{% if not add_generation_prompt is defined -%} + {%- set add_generation_prompt = false -%} +{%- endif -%} +{%- set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='') -%} +{%- for message in messages -%} + {%- if message['role'] == 'system' -%} + {%- set ns.system_prompt = message['content'] -%} + {%- endif -%} +{%- endfor -%}{{bos_token}}{{ns.system_prompt}} +{%- for message in messages -%} + {%- if message['role'] == 'user' -%} + {%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}} + {%- endif -%} + {%- if message['role'] == 'assistant' and message['tool_calls'] -%} + {%- set ns.is_tool = false -%} + {%- for tool in message['tool_calls']-%} + {%- if not ns.is_first -%} + {{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}} + {%- set ns.is_first = true -%} + {%- else -%} + {{'\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\n' + '```json' + '\n' + tool['function']['arguments'] + '\n' + '```' + '<|tool▁call▁end|>'}} + {%- endif -%} + {%- endfor -%} + {{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} + {%- endif -%} + {%- if message['role'] == 'assistant' and message['content'] is not none -%} + {%- if ns.is_tool -%}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}} + {%- set ns.is_tool = false -%} + {%- else -%} + {%- set content = message['content'] -%} + {%- if '' in content -%} + {%- set content = content.split('')[-1] -%} + {%- endif -%}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}} + {%- endif -%} + {%- endif -%} + {%- if message['role'] == 'tool' -%} + {%- set ns.is_tool = true -%} + {%- if ns.is_output_first -%}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- set ns.is_output_first = false -%} + {%- else -%}{{'\n<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- endif -%} + {%- endif -%} +{%- endfor -%} +{%- if ns.is_tool -%}{{'<|tool▁outputs▁end|>'}} +{%- endif -%} +{%- if add_generation_prompt and not ns.is_tool -%}{{'<|Assistant|>\n'}}{% if not enable_thinking %}{{- '' -}}{% endif %} +{%- endif %} \ No newline at end of file diff --git a/models/templates/deepseek-ai-DeepSeek-V3.1.jinja b/models/templates/deepseek-ai-DeepSeek-V3.1.jinja index e5656196a3f..2fd1c415b88 100644 --- a/models/templates/deepseek-ai-DeepSeek-V3.1.jinja +++ b/models/templates/deepseek-ai-DeepSeek-V3.1.jinja @@ -1,3 +1,71 @@ -{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% if not thinking is defined %}{% set thinking = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + ' +{% if not add_generation_prompt is defined -%} + {%- set add_generation_prompt = false -%} +{%- endif -%} +{%- if not thinking is defined -%} + {%- if enable_thinking is defined -%} + {%- set thinking = enable_thinking -%} + {%- else -%} + {%- set thinking = false -%} + {%- endif -%} +{%- endif -%} +{%- set ns = namespace(is_first=false, is_tool=false, system_prompt='', is_first_sp=true, is_last_user=false) -%} +{%- for message in messages -%} + {%- if message['role'] == 'system' -%} + {%- if ns.is_first_sp -%} + {%- set ns.system_prompt = ns.system_prompt + message['content'] -%} + {%- set ns.is_first_sp = false -%} + {%- else -%} + {%- set ns.system_prompt = ns.system_prompt + ' -' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{%- set ns.is_first = false -%}{%- set ns.is_last_user = true -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and message['tool_calls'] is defined and message['tool_calls'] is not none %}{%- if ns.is_last_user %}{{'<|Assistant|>'}}{%- endif %}{%- set ns.is_last_user = false -%}{%- set ns.is_first = false %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}{%- else %}{{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and (message['tool_calls'] is not defined or message['tool_calls'] is none) %}{%- if ns.is_last_user %}{{'<|Assistant|>'}}{%- if message['prefix'] is defined and message['prefix'] and thinking %}{{''}} {%- else %}{{''}}{%- endif %}{%- endif %}{%- set ns.is_last_user = false -%}{%- if ns.is_tool %}{{message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{%- set content = message['content'] -%}{%- if '' in content %}{%- set content = content.split('', 1)[1] -%}{%- endif %}{{content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_last_user = false -%}{%- set ns.is_tool = true -%}{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endfor -%}{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool %}{{'<|Assistant|>'}}{%- if not thinking %}{{''}}{%- else %}{{''}}{%- endif %}{% endif %} \ No newline at end of file +' + message['content'] -%} + {%- endif -%} + {%- endif -%} +{%- endfor -%}{{ bos_token }}{{ ns.system_prompt }} +{%- for message in messages -%} + {%- if message['role'] == 'user' -%} + {%- set ns.is_tool = false -%} + {%- set ns.is_first = false -%} + {%- set ns.is_last_user = true -%}{{'<|User|>' + message['content']}} + {%- endif -%} + {%- if message['role'] == 'assistant' and message['tool_calls'] -%} + {%- if ns.is_last_user -%}{{'<|Assistant|>'}} + {%- endif -%} + {%- set ns.is_last_user = false -%} + {%- set ns.is_first = false -%} + {%- set ns.is_tool = false -%} + {%- for tool in message['tool_calls'] -%} + {%- if not ns.is_first -%} + {%- if not message['content'] -%}{{'<|tool▁calls▁begin|><|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}} + {%- else -%}{{message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}} + {%- endif -%} + {%- set ns.is_first = true -%} + {%- else -%}{{'<|tool▁call▁begin|>'+ tool['function']['name'] + '<|tool▁sep|>' + tool['function']['arguments'] + '<|tool▁call▁end|>'}} + {%- endif -%} + {%- endfor -%}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}} + {%- endif -%} + {%- if message['role'] == 'assistant' and not message['tool_calls'] -%} + {%- if ns.is_last_user -%}{{'<|Assistant|>'}} + {%- if message['prefix'] is defined and message['prefix'] and thinking -%}{{''}} + {%- else -%}{{''}} + {%- endif -%} + {%- endif -%} + {%- set ns.is_last_user = false -%} + {%- if ns.is_tool -%}{{message['content'] + '<|end▁of▁sentence|>'}} + {%- set ns.is_tool = false -%} + {%- else -%} + {%- set content = message['content'] -%} + {%- if '' in content -%} + {%- set content = content.split('', 1)[1] -%} + {%- endif -%}{{content + '<|end▁of▁sentence|>'}} + {%- endif -%} + {%- endif -%} + {%- if message['role'] == 'tool' -%} + {%- set ns.is_last_user = false -%} + {%- set ns.is_tool = true -%}{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}} + {%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt and ns.is_last_user and not ns.is_tool -%}{{'<|Assistant|>'}} + {%- if not thinking -%}{{''}} + {%- else -%}{{''}} + {%- endif -%} +{%- endif %} \ No newline at end of file diff --git a/models/templates/fireworks-ai-llama-3-firefunction-v2.jinja b/models/templates/fireworks-ai-llama-3-firefunction-v2.jinja index 9b8136df73b..b94cfd4d9bd 100644 --- a/models/templates/fireworks-ai-llama-3-firefunction-v2.jinja +++ b/models/templates/fireworks-ai-llama-3-firefunction-v2.jinja @@ -46,7 +46,7 @@ Available functions as JSON spec: {%- if 'tool_calls' in message and message['tool_calls'] -%} {%- set tool = namespace(calls=[]) -%} {%- for call in message['tool_calls'] -%} - {%- set tool.calls = tool.calls + ['{"name": "' + call['function']['name'] + '", "arguments": ' + call['function']['arguments'] + '}'] -%} + {%- set tool.calls = tool.calls + ['{"name": "' + call['function']['name'] + '", "arguments": ' + call['function']['arguments']|tojson + '}'] -%} {%- endfor -%} {%- set ns.content = ns.content + ' functools[' + tool.calls | join(', ') + ']' -%} {%- endif -%} diff --git a/models/templates/moonshotai-Kimi-K2.jinja b/models/templates/moonshotai-Kimi-K2.jinja index ecb49a21085..e286d8a7b5b 100644 --- a/models/templates/moonshotai-Kimi-K2.jinja +++ b/models/templates/moonshotai-Kimi-K2.jinja @@ -1,43 +1,43 @@ -{%- if tools -%} - <|im_system|>tool_declare<|im_middle|>{{ tools | tojson }}<|im_end|> -{%- endif -%} -{%- for message in messages -%} - {%- if loop.first and messages[0]['role'] != 'system' -%} - <|im_system|>system<|im_middle|>You are a helpful assistant<|im_end|> - {%- endif -%} - {%- if message['role'] == 'system' -%} - <|im_system|>system<|im_middle|> - {%- elif message['role'] == 'user' -%} - <|im_user|>user<|im_middle|> - {%- elif message['role'] == 'assistant' -%} - <|im_assistant|>assistant<|im_middle|> - {%- elif message['role'] == 'tool' -%} - <|im_system|>tool<|im_middle|> - {%- endif -%} - {%- if message['role'] == 'assistant' and message.get('tool_calls') -%} - {%- if message['content'] -%}{{ message['content'] }}{%- endif -%} - <|tool_calls_section_begin|> - {%- for tool_call in message['tool_calls'] -%} - {%- set func_name = tool_call['function']['name'] -%} - {%- set formatted_id = 'functions.' + func_name + ':' + loop.index0|string -%} - <|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{{ tool_call['function']['arguments'] | tojson}}<|tool_call_end|> - {%- endfor -%} - <|tool_calls_section_end|> - {%- elif message['role'] == 'tool' -%} - ## Return of {{ message.tool_call_id }}\n{{ message['content'] }} - {%- elif message['content'] is string -%} - {{ message['content'] }} - {%- elif message['content'] is not none -%} - {% for content in message['content'] -%} - {% if content['type'] == 'image' or 'image' in content or 'image_url' in content -%} - <|media_start|>image<|media_content|><|media_pad|><|media_end|> - {% else -%} - {{ content['text'] }} - {%- endif -%} - {%- endfor -%} - {%- endif -%} - <|im_end|> -{%- endfor -%} -{%- if add_generation_prompt -%} - <|im_assistant|>assistant<|im_middle|> -{%- endif -%} +{%- if tools -%} + <|im_system|>tool_declare<|im_middle|>{{ tools | tojson }}<|im_end|> +{%- endif -%} +{%- for message in messages -%} + {%- if loop.first and messages[0]['role'] != 'system' -%} + <|im_system|>system<|im_middle|>You are a helpful assistant<|im_end|> + {%- endif -%} + {%- if message['role'] == 'system' -%} + <|im_system|>system<|im_middle|> + {%- elif message['role'] == 'user' -%} + <|im_user|>user<|im_middle|> + {%- elif message['role'] == 'assistant' -%} + <|im_assistant|>assistant<|im_middle|> + {%- elif message['role'] == 'tool' -%} + <|im_system|>tool<|im_middle|> + {%- endif -%} + {%- if message['role'] == 'assistant' and message.get('tool_calls') -%} + {%- if message['content'] -%}{{ message['content'] }}{%- endif -%} + <|tool_calls_section_begin|> + {%- for tool_call in message['tool_calls'] -%} + {%- set func_name = tool_call['function']['name'] -%} + {%- set formatted_id = 'functions.' + func_name + ':' + loop.index0|string -%} + <|tool_call_begin|>{{ formatted_id }}<|tool_call_argument_begin|>{{ tool_call['function']['arguments'] | tojson}}<|tool_call_end|> + {%- endfor -%} + <|tool_calls_section_end|> + {%- elif message['role'] == 'tool' -%} + ## Return of {{ message.tool_call_id }}\n{{ message['content'] }} + {%- elif message['content'] is string -%} + {{ message['content'] }} + {%- elif message['content'] is not none -%} + {% for content in message['content'] -%} + {% if content['type'] == 'image' or 'image' in content or 'image_url' in content -%} + <|media_start|>image<|media_content|><|media_pad|><|media_end|> + {% else -%} + {{ content['text'] }} + {%- endif -%} + {%- endfor -%} + {%- endif -%} + <|im_end|> +{%- endfor -%} +{%- if add_generation_prompt -%} + <|im_assistant|>assistant<|im_middle|> +{%- endif -%} diff --git a/models/templates/unsloth-Apriel-1.5.jinja b/models/templates/unsloth-Apriel-1.5.jinja index 29e582fbf63..1639b639015 100644 --- a/models/templates/unsloth-Apriel-1.5.jinja +++ b/models/templates/unsloth-Apriel-1.5.jinja @@ -86,22 +86,22 @@ Prior to generating the function calls, you should generate the reasoning for wh {%- set add_tool_id = false -%} {%- endif -%} {{- '<|assistant|>\n' -}} - {%- if message['content'] is not none and message['content']|length > 0 -%} + {%- if message['content'] is defined and message['content'] is not none and message['content']|length > 0 -%} {%- if message['content'] is not string and message['content'][0]['text'] is not none %} {{- message['content'][0]['text'] }} {%- else %} {{- message['content'] -}} {%- endif -%} - {%- elif message['chosen'] is not none and message['chosen']|length > 0 -%} + {%- elif message['chosen'] is defined and message['chosen'] is not none and message['chosen']|length > 0 -%} {{- message['chosen'][0] -}} {%- endif -%} {%- if add_thoughts and 'thought' in message and message['thought'] is not none -%} {{- '' + message['thought'] + '' -}} {%- endif -%} - {%- if message['tool_calls'] is not none and message['tool_calls']|length > 0 -%} + {%- if message['tool_calls'] is defined and message['tool_calls'] is not none and message['tool_calls']|length > 0 -%} {{- '\n[' -}} {%- for tool_call in message["tool_calls"] -%} - {{- '{"name": "' + tool_call['function']['name'] + '", "arguments": ' + tool_call['function']['arguments']|string -}} + {{- '{"name": "' + tool_call['function']['name'] + '", "arguments": ' + tool_call['function']['arguments']|tojson -}} {%- if add_tool_id == true -%} {{- ', "id": "' + tool_call['id'] + '"' -}} {%- endif -%} diff --git a/scripts/get-wikitext-2.sh b/scripts/get-wikitext-2.sh index 67b0b0118b4..bd03ad35263 100755 --- a/scripts/get-wikitext-2.sh +++ b/scripts/get-wikitext-2.sh @@ -1,11 +1,43 @@ -#!/usr/bin/env bash +#!/bin/sh +# vim: set ts=4 sw=4 et: -wget https://huggingface.co/datasets/ggml-org/ci/resolve/main/wikitext-2-raw-v1.zip -unzip wikitext-2-raw-v1.zip +ZIP="wikitext-2-raw-v1.zip" +FILE="wikitext-2-raw/wiki.test.raw" +URL="https://huggingface.co/datasets/ggml-org/ci/resolve/main/$ZIP" -echo "Usage:" -echo "" -echo " ./llama-perplexity -m model.gguf -f wikitext-2-raw/wiki.test.raw [other params]" -echo "" +die() { + printf "%s\n" "$@" >&2 + exit 1 +} -exit 0 +have_cmd() { + for cmd; do + command -v "$cmd" >/dev/null || return + done +} + +dl() { + [ -f "$2" ] && return + if have_cmd wget; then + wget "$1" -O "$2" + elif have_cmd curl; then + curl -L "$1" -o "$2" + else + die "Please install wget or curl" + fi +} + +have_cmd unzip || die "Please install unzip" + +if [ ! -f "$FILE" ]; then + dl "$URL" "$ZIP" || exit + unzip -o "$ZIP" || exit + rm -f -- "$ZIP" +fi + +cat <= 1000. Negative values mean no seed.") + "Correlations between seeds can occur when set >= 1000. Negative values mean no seed.") args = parser.parse_args() benchmark(**vars(args)) diff --git a/scripts/server-test-model.py b/scripts/server-test-model.py new file mode 100644 index 00000000000..9049d80279a --- /dev/null +++ b/scripts/server-test-model.py @@ -0,0 +1,202 @@ +import argparse +import json +import requests +import logging +import sys + +handler = logging.StreamHandler(sys.stdout) +handler.terminator = "" # ← no newline +logging.basicConfig(level=logging.INFO, format='%(message)s', handlers=[handler]) +logger = logging.getLogger("server-test-model") + + +def run_query(url, messages, tools=None, stream=False, tool_choice=None): + payload = { + "messages": messages, + "stream": stream, + "max_tokens": 5000, + } + if tools: + payload["tools"] = tools + if tool_choice: + payload["tool_choice"] = tool_choice + + try: + response = requests.post(url, json=payload, stream=stream) + response.raise_for_status() + except requests.exceptions.RequestException as e: + if e.response is not None: + logger.info(f"Response error: {e} for {e.response.content}\n") + else: + logger.info(f"Error connecting to server: {e}\n") + return None + + full_content = "" + reasoning_content = "" + tool_calls = [] + + if stream: + logger.info(f"--- Streaming response (Tools: {bool(tools)}) ---\n") + for line in response.iter_lines(): + if line: + decoded_line = line.decode("utf-8") + if decoded_line.startswith("data: "): + data_str = decoded_line[6:] + if data_str == "[DONE]": + break + try: + data = json.loads(data_str) + if "choices" in data and len(data["choices"]) > 0: + delta = data["choices"][0].get("delta", {}) + + # Content + content_chunk = delta.get("content", "") + if content_chunk: + full_content += content_chunk + logger.info(content_chunk) + + # Reasoning + reasoning_chunk = delta.get("reasoning_content", "") + if reasoning_chunk: + reasoning_content += reasoning_chunk + logger.info(f"\x1B[3m{reasoning_chunk}\x1B[0m") + + # Tool calls + if "tool_calls" in delta: + for tc in delta["tool_calls"]: + index = tc.get("index") + if index is not None: + while len(tool_calls) <= index: + # Using "function" as type default but could be flexible + tool_calls.append( + { + "id": "", + "type": "function", + "function": { + "name": "", + "arguments": "", + }, + } + ) + + if "id" in tc: + tool_calls[index]["id"] += tc["id"] + if "function" in tc: + if "name" in tc["function"]: + tool_calls[index]["function"][ + "name" + ] += tc["function"]["name"] + if "arguments" in tc["function"]: + tool_calls[index]["function"][ + "arguments" + ] += tc["function"]["arguments"] + + except json.JSONDecodeError: + logger.info(f"Failed to decode JSON: {data_str}\n") + logger.info("\n--- End of Stream ---\n") + else: + logger.info(f"--- Non-streaming response (Tools: {bool(tools)}) ---\n") + data = response.json() + if "choices" in data and len(data["choices"]) > 0: + message = data["choices"][0].get("message", {}) + full_content = message.get("content", "") + reasoning_content = message.get("reasoning_content", "") + tool_calls = message.get("tool_calls", []) + logger.info(full_content) + logger.info("--- End of Response ---\n") + + return { + "content": full_content, + "reasoning_content": reasoning_content, + "tool_calls": tool_calls, + } + + +def test_chat(url, stream): + logger.info(f"\n=== Testing Chat (Stream={stream}) ===\n") + messages = [{"role": "user", "content": "What is the capital of France?"}] + result = run_query(url, messages, stream=stream) + + if result: + if result["content"]: + logger.info("PASS: Output received.\n") + else: + logger.info("WARN: No content received (valid if strict tool call, but unexpected here).\n") + + if result.get("reasoning_content"): + logger.info(f"INFO: Reasoning content detected ({len(result['reasoning_content'])} chars).\n") + else: + logger.info("INFO: No reasoning content detected (Standard model behavior).\n") + else: + logger.info("FAIL: No result.\n") + + +def test_tool_call(url, stream): + logger.info(f"\n=== Testing Tool Call (Stream={stream}) ===\n") + messages = [ + { + "role": "user", + "content": "What is the weather in London? Please use the get_weather tool.", + } + ] + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + + result = run_query(url, messages, tools=tools, tool_choice="auto", stream=stream) + + if result: + tcs = result.get("tool_calls") + if tcs and len(tcs) > 0: + logger.info("PASS: Tool calls detected.") + for tc in tcs: + func = tc.get("function", {}) + logger.info(f" Tool: {func.get('name')}, Args: {func.get('arguments')}\n") + else: + logger.info(f"FAIL: No tool calls. Content: {result['content']}\n") + + if result.get("reasoning_content"): + logger.info( + f"INFO: Reasoning content detected during tool call ({len(result['reasoning_content'])} chars).\n" + ) + else: + logger.info("FAIL: Query failed.\n") + + +def main(): + parser = argparse.ArgumentParser(description="Test llama-server functionality.") + parser.add_argument("--host", default="localhost", help="Server host") + parser.add_argument("--port", default=8080, type=int, help="Server port") + args = parser.parse_args() + + base_url = f"http://{args.host}:{args.port}/v1/chat/completions" + logger.info(f"Testing server at {base_url}\n") + + # Non-streaming tests + test_chat(base_url, stream=False) + test_tool_call(base_url, stream=False) + + # Streaming tests + test_chat(base_url, stream=True) + test_tool_call(base_url, stream=True) + + +if __name__ == "__main__": + main() diff --git a/scripts/snapdragon/windows/run-cli.ps1 b/scripts/snapdragon/windows/run-cli.ps1 index 40c7acc430f..5891c894a9f 100644 --- a/scripts/snapdragon/windows/run-cli.ps1 +++ b/scripts/snapdragon/windows/run-cli.ps1 @@ -46,8 +46,8 @@ if ($null -ne $env:NDEV) { $env:ADSP_LIBRARY_PATH="$basedir\lib" -& "$basedir\bin\llama-completion.exe" ` - --no-mmap -no-cnv -m $basedir\..\..\gguf\$model ` +& "$basedir\bin\llama-cli.exe" ` + --no-mmap -m $basedir\..\..\gguf\$model ` --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 ` --ctx-size 8192 --ubatch-size 128 -fa on ` -ngl 99 --device $device $cli_opts diff --git a/scripts/snapdragon/windows/run-completion.ps1 b/scripts/snapdragon/windows/run-completion.ps1 new file mode 100644 index 00000000000..8a48d2d7486 --- /dev/null +++ b/scripts/snapdragon/windows/run-completion.ps1 @@ -0,0 +1,53 @@ + +#!/usr/bin/env pwsh + +# Basedir on device +$basedir=".\pkg-snapdragon" + +$cli_opts=$args + +$model="Llama-3.2-3B-Instruct-Q4_0.gguf" +if ($null -ne $env:M) { + $model=$env:M +} + +$device="HTP0" +if ($null -ne $env:D) { + $device=$env:D +} + +if ($null -ne $env:V) { + $env:GGML_HEXAGON_VERBOSE=$env:V +} + +if ($null -ne $env:E) { + $env:GGML_HEXAGON_EXPERIMENTAL=$env:E +} + +if ($null -ne $env:SCHED) { + $env:GGML_SCHED_DEBUG=$env:SCHED; $cli_opts="$cli_opts -v" +} + +if ($null -ne $env:PROF) { + $env:GGML_HEXAGON_PROFILE=$env:PROF; $env:GGML_HEXAGON_OPSYNC=1 +} + +if ($null -ne $env:OPMASK) { + $env:GGML_HEXAGON_OPMASK=$env:OPMASK +} + +if ($null -ne $env:NHVX) { + $env:GGML_HEXAGON_NHVX=$env:NHVX +} + +if ($null -ne $env:NDEV) { + $env:GGML_HEXAGON_NDEV=$env:NDEV +} + +$env:ADSP_LIBRARY_PATH="$basedir\lib" + +& "$basedir\bin\llama-completion.exe" ` + --no-mmap -m $basedir\..\..\gguf\$model ` + --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 ` + --ctx-size 8192 --batch-size 128 -fa on ` + -ngl 99 -no-cnv --device $device $cli_opts diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 386fab04ac9..6bf76939cdd 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -394,11 +394,13 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t clear(); split_reset(); + const int64_t n_pos_all = (int64_t) n_tokens*n_pos_per_embd; + auto udata = std::make_shared(); udata->token .resize(n_tokens); udata->embd .clear(); - udata->pos .resize(n_tokens); + udata->pos .resize(n_pos_all); udata->n_seq_id .resize(n_tokens); udata->seq_id .resize(n_tokens); udata->seq_id_unq.resize(0); diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 98d055d34ef..d050604b2a0 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -158,7 +158,7 @@ llama_context::llama_context( cparams.op_offload = params.op_offload; cparams.kv_unified = params.kv_unified; - // intialized later + // initialized later cparams.pipeline_parallel = false; { @@ -1039,11 +1039,15 @@ void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_a bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); - if (n_adapters != loras->size()) { - return false; - } + // Adapters with a zero scale are never added to `loras`, so also ignore them for the comparison. + size_t n_non_zero = 0; for (size_t i = 0; i < n_adapters; i ++) { + if (scales[i] == 0.0f) { + continue; + } + n_non_zero++; + auto it = loras->find(adapters[i]); if (it == loras->end() || it->second != scales[i]) { @@ -1051,6 +1055,10 @@ bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_ } } + if (n_non_zero != loras->size()) { + return false; + } + return true; } @@ -1981,7 +1989,7 @@ ggml_cgraph * llama_context::graph_reserve( ggml_backend_sched_reset(sched.get()); - // when the scheduler is reset, we cannnot reuse the old graph, so we reset the previous graph result to prevent that + // when the scheduler is reset, we cannot reuse the old graph, so we reset the previous graph result to prevent that gf_res_prev->reset(); // store the n_outputs as it is, and restore it afterwards diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 23a86ea2905..b8126ce5081 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1616,7 +1616,7 @@ ggml_tensor * llm_graph_context::build_inp_attn_scale() const { ggml_tensor * llm_graph_context::build_inp_out_ids() const { // note: when all tokens are output, we could skip this optimization to spare the ggml_get_rows() calls, // but this would make the graph topology depend on the number of output tokens, which can interere with - // features that require constant topology such as pipline parallelism + // features that require constant topology such as pipeline parallelism // ref: https://github.com/ggml-org/llama.cpp/pull/14275#issuecomment-2987424471 //if (n_outputs < n_tokens) { // return nullptr; @@ -1779,7 +1779,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( if (v_mla) { #if 0 // v_mla can be applied as a matrix-vector multiplication with broadcasting across dimension 3 == n_tokens. - // However, the code is optimized for dimensions 0 and 1 being large, so this is ineffient. + // However, the code is optimized for dimensions 0 and 1 being large, so this is inefficient. cur = ggml_reshape_4d(ctx0, cur, v_mla->ne[0], 1, n_head, n_tokens); cur = ggml_mul_mat(ctx0, v_mla, cur); #else diff --git a/src/llama-impl.cpp b/src/llama-impl.cpp index 710a5a1e08d..4c0188ee722 100644 --- a/src/llama-impl.cpp +++ b/src/llama-impl.cpp @@ -100,9 +100,9 @@ std::string format(const char * fmt, ...) { std::string llama_format_tensor_shape(const std::vector & ne) { char buf[256]; - snprintf(buf, sizeof(buf), "%5" PRId64, ne.at(0)); + snprintf(buf, sizeof(buf), "%6" PRId64, ne.at(0)); for (size_t i = 1; i < ne.size(); i++) { - snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %5" PRId64, ne.at(i)); + snprintf(buf + strlen(buf), sizeof(buf) - strlen(buf), ", %6" PRId64, ne.at(i)); } return buf; } diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 6b668ee9abd..d80e8a70bc2 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -583,7 +583,7 @@ llama_kv_cache::slot_info_vec_t llama_kv_cache::prepare(const std::vector seq_srct; std::unordered_map> seq_idxs; @@ -1760,8 +1760,10 @@ void llama_kv_cache::state_write_meta(llama_io_write_i & io, const cell_ranges_t io.write(&pos, sizeof(pos)); io.write(&n_seq_id, sizeof(n_seq_id)); - // TODO: we also need to save llama_kv_cell_ext when apply_ubatch() support loading it - // see: https://github.com/ggml-org/llama.cpp/pull/16825#issuecomment-3460868350 + if (hparams.n_pos_per_embd() > 1) { + const llama_kv_cell_ext ext = cells.ext_get(i); + io.write(&ext, sizeof(ext)); + } for (const auto & seq_id : seq_ids) { io.write(&seq_id, sizeof(seq_id)); @@ -1895,6 +1897,14 @@ bool llama_kv_cache::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32 return false; } + if (hparams.n_pos_per_embd() > 1) { + llama_kv_cell_ext ext; + io.read_to(&ext, sizeof(ext)); + + ubatch.pos[i + ubatch.n_tokens] = ext.y; + ubatch.pos[i + ubatch.n_tokens*2] = ext.x; + } + // read the sequence id, but directly discard it - we will use dest_seq_id instead { llama_seq_id seq_id; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index dabf3b3086e..924e5708cde 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -61,6 +61,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_0_3B: return "0.3B"; case LLM_TYPE_0_5B: return "0.5B"; case LLM_TYPE_0_6B: return "0.6B"; + case LLM_TYPE_0_8B: return "0.8B"; case LLM_TYPE_1B: return "1B"; case LLM_TYPE_1_2B: return "1.2B"; case LLM_TYPE_1_3B: return "1.3B"; @@ -132,12 +133,14 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_100B_A6B: return "100B.A6B"; case LLM_TYPE_102B_A12B: return "102B.A12B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; + case LLM_TYPE_122B_A10B: return "122B.A10B"; case LLM_TYPE_196B_A11B: return "196B.A11B"; case LLM_TYPE_230B_A10B: return "230B.A10B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; case LLM_TYPE_300B_A47B: return "300B.A47B"; case LLM_TYPE_310B_A15B: return "310B.A15B"; case LLM_TYPE_355B_A32B: return "355B.A32B"; + case LLM_TYPE_397B_A17B: return "397B.A17B"; case LLM_TYPE_744B_A40B: return "744B.A40B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; @@ -1524,7 +1527,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { } switch (hparams.n_layer) { - // TODO: Jamba layers are a bit heterogenous, so naming this is hard. + // TODO: Jamba layers are a bit heterogeneous, so naming this is hard. case 12: // 900M 8x???M case 32: // 51B 16x?B default: type = LLM_TYPE_UNKNOWN; @@ -2528,7 +2531,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { } switch (hparams.n_layer) { - case 24: type = LLM_TYPE_2B; break; + case 24: type = hparams.n_embd == 1024 ? LLM_TYPE_0_8B : LLM_TYPE_2B; break; + case 32: type = hparams.n_embd == 2560 ? LLM_TYPE_4B : LLM_TYPE_9B; break; + case 64: type = LLM_TYPE_27B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -2557,8 +2562,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { } switch (hparams.n_layer) { - case 28: type = LLM_TYPE_35B_A3B; break; - case 48: type = LLM_TYPE_80B_A3B; break; + case 40: type = LLM_TYPE_35B_A3B; break; + case 48: type = LLM_TYPE_122B_A10B; break; + case 60: type = LLM_TYPE_397B_A17B; break; default: type = LLM_TYPE_UNKNOWN; } } break; diff --git a/src/llama-model.h b/src/llama-model.h index d7c3e7d1c1a..5ecb8344a25 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -54,6 +54,7 @@ enum llm_type { LLM_TYPE_0_3B, LLM_TYPE_0_5B, LLM_TYPE_0_6B, + LLM_TYPE_0_8B, LLM_TYPE_1B, LLM_TYPE_1_2B, LLM_TYPE_1_3B, @@ -125,12 +126,14 @@ enum llm_type { LLM_TYPE_100B_A6B, LLM_TYPE_102B_A12B, // Solar-Open LLM_TYPE_106B_A12B, // GLM-4.5-Air + LLM_TYPE_122B_A10B, // Qwen3.5 LLM_TYPE_196B_A11B, // Step3.5-Flash LLM_TYPE_230B_A10B, // Minimax M2 LLM_TYPE_235B_A22B, LLM_TYPE_300B_A47B, // Ernie MoE big LLM_TYPE_310B_A15B, // /MiMo-V2-Flash LLM_TYPE_355B_A32B, // GLM-4.5 + LLM_TYPE_397B_A17B, // Qwen3.5 LLM_TYPE_744B_A40B, // GLM-5 LLM_TYPE_E2B, LLM_TYPE_E4B, diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 194eed238ec..ce83361dc79 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1833,7 +1833,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx); precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap); #if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ - // correct endiannes of data in precompiled_charsmap binary blob + // correct endianness of data in precompiled_charsmap binary blob uint32_t * xcda_blob_size = (uint32_t *) &precompiled_charsmap[0]; *xcda_blob_size = __builtin_bswap32(*xcda_blob_size); assert(*xcda_blob_size + sizeof(uint32_t) < n_precompiled_charsmap); diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index b608396e50e..be81709c50b 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -146,7 +146,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr cb(Qcur, "Qcur_attn_temp_scaled", il); } - // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group) + // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group) cur = build_attn(inp_attn_k, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il); diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp index 99f1fdd9538..c57abbb5b74 100644 --- a/src/models/delta-net-base.cpp +++ b/src/models/delta-net-base.cpp @@ -1,7 +1,5 @@ #include "models.h" -#define CHUNK_SIZE 64 - // utility to get one slice from the third dimension // input dim: [x, y, c, b] // output dim: [x, y, 1, b] @@ -57,7 +55,7 @@ std::pair llm_build_delta_net_base::build_delta_ne g = ggml_permute(ctx0, g, 0, 2, 1, 3); // [g_0, n_tokens, H_v, n_seqs] b = ggml_permute(ctx0, b, 0, 2, 1, 3); // [ 1, n_tokens, H_v, n_seqs] - const int CS = CHUNK_SIZE; + const int CS = kda ? 16 : 64; // chunk size const int pad = (CS - n_tokens % CS) % CS; const int n_chunks = (n_tokens + pad) / CS; diff --git a/src/models/models.h b/src/models/models.h index 0712d03d8d9..cf9ba04e7f7 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -3,7 +3,7 @@ #include "llama-model.h" #include "llama-graph.h" -// note: almost all graphs require atleast sqrtf, so include cmath globally +// note: almost all graphs require at least sqrtf, so include cmath globally #include // diff --git a/src/unicode.cpp b/src/unicode.cpp index 1475b53b659..122c8ca04a5 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -773,7 +773,7 @@ static std::vector unicode_regex_split_custom(const std::string & text, // tiny_aya digit grouping pattern from tokenizer.json: // {"type": "Split", "pattern": {"Regex": "\\d{1,3}(?=(?:\\d{3})*\\b)"}, "behavior": "Isolated"} // Splits digits into groups of 3 from the right (e.g., 1234567 -> 1, 234, 567) - // TODO: Revisit this regex, incase there are any subtle tokenization differences with the original regex. + // TODO: Revisit this regex, in case there are any subtle tokenization differences with the original regex. bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets); } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 7e0b17a7c1f..46ab7a0cef0 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -187,11 +187,11 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS) # llama_build_and_test(test-double-float.cpp) # SLOW endif() -llama_build_and_test(test-chat-parser.cpp) llama_build_and_test(test-chat-peg-parser.cpp peg-parser/simple-tokenize.cpp) -llama_build_and_test(test-chat-template.cpp) llama_build_and_test(test-jinja.cpp) llama_test(test-jinja NAME test-jinja-py ARGS -py LABEL python) +llama_build_and_test(test-chat-auto-parser.cpp WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}) +llama_build_and_test(test-chat-template.cpp) llama_build_and_test(test-json-partial.cpp) llama_build_and_test(test-log.cpp) llama_build_and_test( @@ -201,6 +201,7 @@ llama_build_and_test( peg-parser/test-gbnf-generation.cpp peg-parser/test-json-parser.cpp peg-parser/test-json-serialization.cpp + peg-parser/test-python-dict-parser.cpp peg-parser/test-unicode.cpp peg-parser/tests.h ) @@ -279,3 +280,5 @@ target_link_libraries(${TEST_TARGET} PRIVATE llama) llama_build_and_test(test-alloc.cpp) target_include_directories(test-alloc PRIVATE ${PROJECT_SOURCE_DIR}/ggml/src) + + diff --git a/tests/peg-parser/test-basic.cpp b/tests/peg-parser/test-basic.cpp index 1bda6f2e690..872f16a78df 100644 --- a/tests/peg-parser/test-basic.cpp +++ b/tests/peg-parser/test-basic.cpp @@ -1,3 +1,4 @@ +#include "peg-parser.h" #include "tests.h" void test_basic(testing & t) { @@ -450,5 +451,21 @@ void test_basic(testing & t) { t.assert_equal("result_is_fail", true, result.fail()); }); + + // Test markers + t.test("marker", [](testing &t) { + auto bracket_parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.marker(); + }); + + common_peg_parse_context ctx_square("[marker]", false); + common_peg_parse_context ctx_sharp("", false); + + auto result_square = bracket_parser.parse(ctx_square); + auto result_sharp = bracket_parser.parse(ctx_sharp); + + t.assert_true("result_square_is_success", result_square.success()); + t.assert_true("result_sharp_is_success", result_sharp.success()); + }); }); } diff --git a/tests/peg-parser/test-python-dict-parser.cpp b/tests/peg-parser/test-python-dict-parser.cpp new file mode 100644 index 00000000000..d9946a4916f --- /dev/null +++ b/tests/peg-parser/test-python-dict-parser.cpp @@ -0,0 +1,318 @@ +#include "tests.h" + +void test_python_dict_parser(testing &t) { + // Test parsing a simple Python dict object with single quotes + t.test("simple Python dict object parsing", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); }); + + std::string input = "{'name': 'test', 'value': 42, 'flag': True}"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test parsing a Python array with mixed types + t.test("Python array with mixed types", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); }); + + std::string input = "[1, 'hello', True, None, 3.14]"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test parsing nested Python dict with objects and arrays + t.test("nested Python dict with objects and arrays", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); }); + + std::string input = + "{'users': [{'id': 1, 'name': 'Alice'}, {'id': 2, 'name': 'Bob'}], 'count': 2, 'metadata': {'version': '1.0', 'tags': ['admin', 'user']}}"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test parsing Python dict with escaped single quotes + t.test("Python dict with escaped single quotes", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); }); + + std::string input = "{'message': 'It\\'s working!'}"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test parsing Python dict with double quotes inside single quotes + t.test("Python dict with double quotes inside single quotes", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); }); + + std::string input = "{'quote': 'He said \"Hello\"'}"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test the example from the requirements + t.test("complex Python dict example from requirements", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); }); + + std::string input = "{ 'obj' : { 'something': 1, 'other \"something\"' : 'foo\\'s bar' } }"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test need_more_input() parsing - incomplete object + t.test("need_more_input() parsing - incomplete object", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); }); + + std::string input = "{'name': 'test', 'value': "; + common_peg_parse_context ctx(input, true); + + auto result = parser.parse(ctx); + + t.assert_equal("result_is_need_more_input", true, result.need_more_input()); + }); + + // Test need_more_input() parsing - incomplete single-quoted string + t.test("need_more_input() parsing - incomplete single-quoted string", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); }); + + std::string input = "{'name': 'test"; + common_peg_parse_context ctx(input, true); + + auto result = parser.parse(ctx); + + t.assert_equal("result_is_need_more_input", true, result.need_more_input()); + }); + + // Test unicode in Python dict strings + t.test("unicode in Python dict strings", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); }); + + std::string input = "{'message': 'Hello, 世界!'}"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test Python dict with unicode escapes + t.test("Python dict with unicode escapes", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); }); + + std::string input = "{'unicode': 'Hello\\u0041'}"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test that Python parser accepts double-quoted strings too + t.test("Python parser accepts double-quoted strings", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); }); + + std::string input = "{\"name\": \"test\"}"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test Python parser with mixed quote styles + t.test("Python parser with mixed quote styles", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); }); + + std::string input = "{\"name\": 'test', 'value': \"hello\"}"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + + t.assert_equal("result_is_success", true, result.success()); + t.assert_equal("result_end", input.size(), result.end); + }); + + // Test Python True/False/None + t.test("Python True/False/None", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { return p.python_value(); }); + + t.test("True", [&](testing &t) { + std::string input = "True"; + common_peg_parse_context ctx(input); + auto result = parser.parse(ctx); + t.assert_true("success", result.success()); + t.assert_equal("end", input.size(), result.end); + }); + + t.test("False", [&](testing &t) { + std::string input = "False"; + common_peg_parse_context ctx(input); + auto result = parser.parse(ctx); + t.assert_true("success", result.success()); + t.assert_equal("end", input.size(), result.end); + }); + + t.test("None", [&](testing &t) { + std::string input = "None"; + common_peg_parse_context ctx(input); + auto result = parser.parse(ctx); + t.assert_true("success", result.success()); + t.assert_equal("end", input.size(), result.end); + }); + + t.test("rejects JSON-style true/false/null", [&](testing &t) { + for (const auto & kw : {"true", "false", "null"}) { + std::string input = kw; + common_peg_parse_context ctx(input); + auto result = parser.parse(ctx); + t.assert_true(std::string("rejects ") + kw, result.fail()); + } + }); + }); + + // Test single-quoted string content parser directly + t.test("single-quoted string content parser", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.sequence({ p.literal("'"), p.single_quoted_string_content(), p.literal("'"), p.space() }); + }); + + t.test("simple string", [&](testing &t) { + std::string input = "'hello'"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + t.assert_true("success", result.success()); + t.assert_equal("end", input.size(), result.end); + }); + + t.test("string with escaped single quote", [&](testing &t) { + std::string input = "'it\\'s'"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + t.assert_true("success", result.success()); + t.assert_equal("end", input.size(), result.end); + }); + + t.test("string with double quotes", [&](testing &t) { + std::string input = "'say \"hello\"'"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + t.assert_true("success", result.success()); + t.assert_equal("end", input.size(), result.end); + }); + + t.test("incomplete string", [&](testing &t) { + std::string input = "'hello"; + common_peg_parse_context ctx(input, true); + + auto result = parser.parse(ctx); + t.assert_true("need_more_input", result.need_more_input()); + }); + }); + + // Test json() with pre-registered flexible json-string rule (python dict support) + t.test("json() parser with flexible json-string rule", [](testing &t) { + t.test("json() rejects single quotes by default", [&](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.json(); + }); + + std::string input = "{'name': 'test'}"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + t.assert_true("fail", result.fail()); + }); + + t.test("json() accepts single quotes with pre-registered flexible json-string rule", [&](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + // Pre-register json-string rule with both quote styles + p.rule("json-string", [&]() { + return p.choice({ p.double_quoted_string(), p.single_quoted_string() }); + }); + return p.json(); + }); + + std::string input = "{'name': 'test'}"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + t.assert_true("success", result.success()); + t.assert_equal("end", input.size(), result.end); + }); + + t.test("json() still accepts double quotes with flexible json-string rule", [&](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("json-string", [&]() { + return p.choice({ p.double_quoted_string(), p.single_quoted_string() }); + }); + return p.json(); + }); + + std::string input = "{\"name\": \"test\"}"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + t.assert_true("success", result.success()); + t.assert_equal("end", input.size(), result.end); + }); + + t.test("json() accepts mixed quote styles with flexible json-string rule", [&](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("json-string", [&]() { + return p.choice({ p.double_quoted_string(), p.single_quoted_string() }); + }); + return p.json(); + }); + + std::string input = "{\"name\": 'test', 'value': \"hello\"}"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + t.assert_true("success", result.success()); + t.assert_equal("end", input.size(), result.end); + }); + + t.test("complex nested structure with flexible json-string rule", [&](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + p.rule("json-string", [&]() { + return p.choice({ p.double_quoted_string(), p.single_quoted_string() }); + }); + return p.json(); + }); + + std::string input = "{ 'obj' : { 'something': 1, 'other \"something\"' : 'foo\\'s bar' } }"; + common_peg_parse_context ctx(input); + + auto result = parser.parse(ctx); + t.assert_true("success", result.success()); + t.assert_equal("end", input.size(), result.end); + }); + }); +} diff --git a/tests/peg-parser/tests.h b/tests/peg-parser/tests.h index 4d3f4e9eaf5..debd4286c50 100644 --- a/tests/peg-parser/tests.h +++ b/tests/peg-parser/tests.h @@ -22,3 +22,4 @@ void test_json_parser(testing &t); void test_gbnf_generation(testing &t); void test_unicode(testing &t); void test_json_serialization(testing &t); +void test_python_dict_parser(testing &t); diff --git a/tests/test-alloc.cpp b/tests/test-alloc.cpp index 95e09c97b02..7ae739ad2ef 100644 --- a/tests/test-alloc.cpp +++ b/tests/test-alloc.cpp @@ -285,7 +285,7 @@ static void test_max_size_too_many_tensors() { GGML_ASSERT(backend.context->allocated_total() <= 16 + 16); } -// Scenario where there is some space left in the first buffer, but not enough to accomodate +// Scenario where there is some space left in the first buffer, but not enough to accommodate // a larger tensor, so a second buffer is required static void test_max_size_tensor_too_large() { dummy_backend backend = dummy_backend_init(32); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index e8e237c6ec8..faa771e0869 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1868,9 +1868,9 @@ struct test_case { }; -// ################################### -// ## Section 2: GGML Op Defintions ## -// ################################### +// #################################### +// ## Section 2: GGML Op Definitions ## +// #################################### // The following is an example showing the bare minimum for creating a test for a GGML op. @@ -2977,6 +2977,7 @@ struct test_bin_bcast : public test_case { const std::array nr; int nf; // number of fused ops, nf == 1 -> single op (no fusion) bool perm1; // permute src1? + bool src_overlap; // src0 and src1 are overlapping views of the same buffer bool run_whole_graph() override { return nf > 1; } @@ -2992,8 +2993,8 @@ struct test_bin_bcast : public test_case { std::array ne = {10, 10, 1, 1}, std::array nr = {1, 2, 1, 1}, int nf = 1, - bool perm1 = false) - : op(op), type(type), ne(ne), nr(nr), nf(nf), perm1(perm1) {} + bool perm1 = false, bool src_overlap = false) + : op(op), type(type), ne(ne), nr(nr), nf(nf), perm1(perm1), src_overlap(src_overlap) {} ggml_tensor * build_graph(ggml_context * ctx) override { GGML_ASSERT(nf <= 16); @@ -3008,6 +3009,8 @@ struct test_bin_bcast : public test_case { b[i] = ggml_new_tensor_4d(ctx, type, ne[p[0]], ne[p[1]], ne[p[2]], ne[p[3]]); b[i] = ggml_permute(ctx, b[i], p[0], p[1], p[2], p[3]); + } else if (src_overlap) { + b[i] = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], 2 * (ne[3] / 3), a->nb[1], a->nb[2], a->nb[3], (ne[3] / 3) * a->nb[3]); } else { b[i] = ggml_new_tensor(ctx, type, 4, ne.data()); } @@ -3021,7 +3024,13 @@ struct test_bin_bcast : public test_case { ggml_set_param(b[0]); } - ggml_tensor * out = a; + ggml_tensor *out; + + if (src_overlap) { + out = ggml_view_4d(ctx, a, ne[0], ne[1], ne[2], 2 * (ne[3] / 3), a->nb[1], a->nb[2], a->nb[3], 0); + } else { + out = a; + } for (int i = 0; i < nf; ++i) { out = op(ctx, out, b[i]); @@ -6213,7 +6222,7 @@ struct test_flash_attn_ext : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { if (strcmp(t->name, "s") == 0) { - // make the sink values more noticable in order to trigger a test failure when the implementation is wrong + // make the sink values more noticeable in order to trigger a test failure when the implementation is wrong init_tensor_uniform(t, -10.0f, 10.0f); } else if (strcmp(t->name, "m") == 0) { init_tensor_kq_mask(t); @@ -7527,9 +7536,9 @@ static std::vector> make_test_cases_eval() { } } - auto add_test_bin_bcast = [&](ggml_type type, std::array ne, std::array nr, bool perm1 = false) { + auto add_test_bin_bcast = [&](ggml_type type, std::array ne, std::array nr, bool perm1 = false, bool src_overlap = false) { for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) { - test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr, 1, perm1)); + test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr, 1, perm1, src_overlap)); } }; for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) { @@ -7549,6 +7558,12 @@ static std::vector> make_test_cases_eval() { add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2}, perm1); } + // src_overlap + add_test_bin_bcast(type, {10, 5, 4, 6}, {1, 1, 1, 1}, false, true); + add_test_bin_bcast(type, {10, 5, 4, 5}, {1, 1, 1, 1}, false, true); + add_test_bin_bcast(type, {1, 1, 120, 120}, {1, 1, 1, 1}, false, true); + add_test_bin_bcast(type, {1, 1, 4, 320}, {1, 1, 1, 1}, false, true); + // test case for k_bin_bcast_unravel in CUDA backend add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1}); @@ -7648,6 +7663,9 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {2 * d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv, d_inner, 4, 1}, {d_conv, d_inner, 1, 1})); + // long token (n_t > 32, exercises the long_token kernel path) + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 1, 1}, {d_conv, d_inner, 1, 1})); + test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 4, 1}, {d_conv, d_inner, 1, 1})); } } @@ -7802,6 +7820,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(type_a, GGML_TYPE_F32, 1, 64, 256, {1, 1}, {1, 1})); } + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 6, 4096, 5120, {1, 1}, {1, 1})); + #if 0 // test the mat-mat path for Metal for (int k = 1; k < 512; ++k) { diff --git a/tests/test-chat-auto-parser.cpp b/tests/test-chat-auto-parser.cpp new file mode 100644 index 00000000000..f2364862c59 --- /dev/null +++ b/tests/test-chat-auto-parser.cpp @@ -0,0 +1,1889 @@ +#include "chat-auto-parser-helpers.h" +#include "chat-auto-parser.h" +#include "chat-peg-parser.h" +#include "chat.h" +#include "peg-parser.h" +#include "testing.h" + +#include +#include +#include +#include + +using namespace autoparser; + +static void test_calculate_diff_split_basic(testing & t); +static void test_calculate_diff_split_identical(testing & t); +static void test_calculate_diff_split_common_prefix(testing & t); +static void test_calculate_diff_split_common_suffix(testing & t); +static void test_calculate_diff_split_common_both(testing & t); +static void test_calculate_diff_split_empty_cases(testing & t); +static void test_calculate_diff_split_no_common(testing & t); +static void test_calculate_diff_split_single_char(testing & t); +static void test_calculate_diff_split_overlaps(testing & t); +static void test_calculate_diff_split_tag_boundaries(testing & t); +static void test_calculate_diff_split(testing & t); + +static void test_until_common_prefix_basic(testing & t); +static void test_until_common_prefix(testing & t); + +static void test_after_common_suffix_basic(testing & t); +static void test_after_common_suffix(testing & t); + +static void test_analyze_tool_call_pure_json(testing & t); +static void test_analyze_tool_call_function_name_markers(testing & t); +static void test_analyze_tool_call_full_markers(testing & t); +static void test_analyze_tool_call_edge_cases(testing & t); + +static void test_compare_variants_basic(testing & t); +static void test_compare_variants_messages_modifier(testing & t); +static void test_compare_variants_tools_modifier(testing & t); +static void test_compare_variants_both_modifiers(testing & t); +static void test_compare_variants_template_failure(testing & t); +static void test_compare_variants_identity(testing & t); +static void test_compare_variants(testing & t); + +// Seed-OSS template tool calling analysis tests +static void test_seed_oss_tool_analysis(testing & t); +static void test_seed_oss_tool_presence(testing & t); +static void test_seed_oss_call_count(testing & t); +static void test_seed_oss_function_names(testing & t); +static void test_seed_oss_argument_count(testing & t); +static void test_seed_oss_args_presence(testing & t); +static void test_seed_oss_tool_with_reasoning(testing & t); + +// Nemotron template analysis tests +static void test_nemotron_analysis(testing & t); +static void test_nemotron_reasoning_detection(testing & t); +static void test_nemotron_tool_format(testing & t); + +// CohereForAI template analysis tests +static void test_cohere_reasoning_detection(testing & t); +static void test_cohere_analysis(testing & t); + +// Marker separation +static void test_marker_separation(testing & t); + +// standard_json_tools format tests +static void test_standard_json_tools_formats(testing & t); +static void test_standard_json_tools_openai(testing & t); +static void test_standard_json_tools_cohere(testing & t); +static void test_standard_json_tools_function_key(testing & t); + +// normalize_quotes_to_json tests +static void test_normalize_quotes_to_json(testing & t); +static void test_normalize_quotes_with_embedded_quotes(testing & t); + +// TAG_WITH_TAGGED argument parsing tests +static void test_tagged_args_with_embedded_quotes(testing & t); + +int main(int argc, char * argv[]) { + testing t(std::cout); + t.verbose = true; + + // usage: test-chat-auto-parser-helpers [filter_regex] + + if (argc > 1) { + t.set_filter(argv[1]); + } + + t.test("diff_split", test_calculate_diff_split); + t.test("common_prefix", test_until_common_prefix); + t.test("common_suffix", test_after_common_suffix); + t.test("compare_variants", test_compare_variants); + t.test("segments", test_marker_separation); + t.test("seed_oss_diffs", test_seed_oss_tool_analysis); + t.test("cohere", test_cohere_analysis); + t.test("nemotron", test_nemotron_analysis); + t.test("standard_json_tools", test_standard_json_tools_formats); + t.test("normalize_quotes_to_json", test_normalize_quotes_to_json); + t.test("tagged_args_embedded_quotes", test_tagged_args_with_embedded_quotes); + + return t.summary(); +} + +static void test_marker_separation(testing & t) { + auto single_square_marker = segmentize_markers("pre_marker[marker]post_marker"); + auto single_diag_marker = segmentize_markers("pre_markerpost_marker"); + auto paired_markers = segmentize_markers("world"); + auto double_different_markers = segmentize_markers("[hello][world]"); + auto in_between = segmentize_markers("imdabada[hey]"); + + t.test("single_square_marker", [&] (testing & t) { + t.assert_equal("first is text", segment_type::TEXT, single_square_marker[0].type); + t.assert_equal("second is marker", segment_type::MARKER, single_square_marker[1].type); + t.assert_equal("last is text", segment_type::TEXT, single_square_marker[2].type); + + t.assert_equal("first is 'pre_marker'", "pre_marker", single_square_marker[0].value); + t.assert_equal("second is '[marker]'", "[marker]", single_square_marker[1].value); + t.assert_equal("last is 'post_marker'", "post_marker", single_square_marker[2].value); + }); + + t.test("single_diagonal_marker", [&] (testing & t) { + t.assert_equal("first is text", segment_type::TEXT, single_diag_marker[0].type); + t.assert_equal("second is marker", segment_type::MARKER, single_diag_marker[1].type); + t.assert_equal("last is text", segment_type::TEXT, single_diag_marker[2].type); + + t.assert_equal("first is 'pre_marker'", "pre_marker", single_diag_marker[0].value); + t.assert_equal("second is ''", "", single_diag_marker[1].value); + t.assert_equal("last is 'post_marker'", "post_marker", single_diag_marker[2].value); + }); + + t.test("paired_markers", [&] (testing & t) { + t.assert_equal("first is marker", segment_type::MARKER, paired_markers[0].type); + t.assert_equal("second is text", segment_type::TEXT, paired_markers[1].type); + t.assert_equal("third is marker", segment_type::MARKER, paired_markers[2].type); + + t.assert_equal("first is ''", "", paired_markers[0].value); + t.assert_equal("second is 'world'", "world", paired_markers[1].value); + t.assert_equal("third is ''", "", paired_markers[2].value); + }); + + t.test("double_different_markers", [&] (testing & t) { + t.assert_equal("first is marker", segment_type::MARKER, double_different_markers[0].type); + t.assert_equal("second is marker", segment_type::MARKER, double_different_markers[1].type); + t.assert_equal("third is marker", segment_type::MARKER, double_different_markers[2].type); + t.assert_equal("fourth is marker", segment_type::MARKER, double_different_markers[3].type); + + t.assert_equal("first is ''", "", double_different_markers[0].value); + t.assert_equal("second is '[hello]'", "[hello]", double_different_markers[1].value); + t.assert_equal("third is ''", "", double_different_markers[2].value); + t.assert_equal("fourth is '[world]'", "[world]", double_different_markers[3].value); + }); + + t.test("in_between", [&] (testing & t) { + t.assert_equal("first is text", segment_type::TEXT, in_between[0].type); + t.assert_equal("second is marker", segment_type::MARKER, in_between[1].type); + t.assert_equal("third is text", segment_type::TEXT, in_between[2].type); + t.assert_equal("fourth is marker", segment_type::MARKER, in_between[3].type); + t.assert_equal("fifth is text", segment_type::TEXT, in_between[4].type); + t.assert_equal("sixth is marker", segment_type::MARKER, in_between[5].type); + + t.assert_equal("first is 'im'", "im", in_between[0].value); + t.assert_equal("second is ''", "", in_between[1].value); + t.assert_equal("third is 'daba'", "daba", in_between[2].value); + t.assert_equal("fourth is ''", "", in_between[3].value); + t.assert_equal("fifth is 'da'", "da", in_between[4].value); + t.assert_equal("sixth is '[hey]'", "[hey]", in_between[5].value); + }); +} + +static void test_calculate_diff_split(testing & t) { + t.test("calculate_diff_split basic", test_calculate_diff_split_basic); + t.test("calculate_diff_split identical", test_calculate_diff_split_identical); + t.test("calculate_diff_split common prefix", test_calculate_diff_split_common_prefix); + t.test("calculate_diff_split common suffix", test_calculate_diff_split_common_suffix); + t.test("calculate_diff_split common both", test_calculate_diff_split_common_both); + t.test("calculate_diff_split empty cases", test_calculate_diff_split_empty_cases); + t.test("calculate_diff_split no common", test_calculate_diff_split_no_common); + t.test("calculate_diff_split single char", test_calculate_diff_split_single_char); + t.test("calculate_diff_split overlaps", test_calculate_diff_split_overlaps); + t.test("calculate_diff_split tag boundaries", test_calculate_diff_split_tag_boundaries); +} + +static void test_calculate_diff_split_basic(testing & t) { + diff_split result = calculate_diff_split("hello world", "hello test"); + t.assert_equal("prefix should be 'hello '", "hello ", result.prefix); + t.assert_equal("left should be 'world'", "world", result.left); + t.assert_equal("right should be 'test'", "test", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + result = calculate_diff_split("abc", "xyz"); + t.assert_equal("prefix should be empty", "", result.prefix); + t.assert_equal("left should be 'abc'", "abc", result.left); + t.assert_equal("right should be 'xyz'", "xyz", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + result = calculate_diff_split("prefixA suffix", "prefixB suffix"); + t.assert_equal("prefix should be 'prefix'", "prefix", result.prefix); + t.assert_equal("left should be 'A'", "A", result.left); + t.assert_equal("right should be 'B'", "B", result.right); + t.assert_equal("suffix should be ' suffix'", " suffix", result.suffix); +} + +static void test_calculate_diff_split_identical(testing & t) { + diff_split result = calculate_diff_split("hello", "hello"); + t.assert_equal("prefix should be 'hello'", "hello", result.prefix); + t.assert_equal("left should be empty", "", result.left); + t.assert_equal("right should be empty", "", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + result = calculate_diff_split("", ""); + t.assert_equal("prefix should be empty", "", result.prefix); + t.assert_equal("left should be empty", "", result.left); + t.assert_equal("right should be empty", "", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + result = calculate_diff_split("a", "a"); + t.assert_equal("prefix should be 'a'", "a", result.prefix); + t.assert_equal("left should be empty", "", result.left); + t.assert_equal("right should be empty", "", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + result = calculate_diff_split("", ""); + t.assert_equal("prefix should be ''", "", result.prefix); + t.assert_equal("left should be empty", "", result.left); + t.assert_equal("right should be empty", "", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); +} + +static void test_calculate_diff_split_common_prefix(testing & t) { + diff_split result = calculate_diff_split("abcdef", "abcxyz"); + t.assert_equal("prefix should be 'abc'", "abc", result.prefix); + t.assert_equal("left should be 'def'", "def", result.left); + t.assert_equal("right should be 'xyz'", "xyz", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + result = calculate_diff_split("same", "sameagain"); + t.assert_equal("prefix should be 'same'", "same", result.prefix); + t.assert_equal("left should be empty", "", result.left); + t.assert_equal("right should be 'again'", "again", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + result = calculate_diff_split("test", "testing"); + t.assert_equal("prefix should be 'test'", "test", result.prefix); + t.assert_equal("left should be empty", "", result.left); + t.assert_equal("right should be 'ing'", "ing", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); +} + +static void test_calculate_diff_split_common_suffix(testing & t) { + diff_split result = calculate_diff_split("123end", "456end"); + t.assert_equal("prefix should be empty", "", result.prefix); + t.assert_equal("left should be '123'", "123", result.left); + t.assert_equal("right should be '456'", "456", result.right); + t.assert_equal("suffix should be 'end'", "end", result.suffix); + + result = calculate_diff_split("start", "end"); + t.assert_equal("prefix should be empty", "", result.prefix); + t.assert_equal("left should be 'start'", "start", result.left); + t.assert_equal("right should be 'end'", "end", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + result = calculate_diff_split("abcsuffix", "xyzsuffix"); + t.assert_equal("prefix should be empty", "", result.prefix); + t.assert_equal("left should be 'abc'", "abc", result.left); + t.assert_equal("right should be 'xyz'", "xyz", result.right); + t.assert_equal("suffix should be 'suffix'", "suffix", result.suffix); +} + +static void test_calculate_diff_split_common_both(testing & t) { + diff_split result = calculate_diff_split("helloXworld", "helloYworld"); + t.assert_equal("prefix should be 'hello'", "hello", result.prefix); + t.assert_equal("left should be 'X'", "X", result.left); + t.assert_equal("right should be 'Y'", "Y", result.right); + t.assert_equal("suffix should be 'world'", "world", result.suffix); + + result = calculate_diff_split("ABCmiddleXYZ", "ABCdifferentXYZ"); + t.assert_equal("prefix should be 'ABC'", "ABC", result.prefix); + t.assert_equal("left should be 'middle'", "middle", result.left); + t.assert_equal("right should be 'different'", "different", result.right); + t.assert_equal("suffix should be 'XYZ'", "XYZ", result.suffix); + + result = calculate_diff_split("startAend", "startBend"); + t.assert_equal("prefix should be 'start'", "start", result.prefix); + t.assert_equal("left should be 'A'", "A", result.left); + t.assert_equal("right should be 'B'", "B", result.right); + t.assert_equal("suffix should be 'end'", "end", result.suffix); + + // Edge case: common prefix and suffix overlap + result = calculate_diff_split("aa", "ab"); + t.assert_equal("prefix should be 'a'", "a", result.prefix); + t.assert_equal("left should be 'a'", "a", result.left); + t.assert_equal("right should be 'b'", "b", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); +} + +static void test_calculate_diff_split_empty_cases(testing & t) { + // Empty left, non-empty right + diff_split result = calculate_diff_split("", "hello"); + t.assert_equal("prefix should be empty", "", result.prefix); + t.assert_equal("left should be empty", "", result.left); + t.assert_equal("right should be 'hello'", "hello", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + // Non-empty left, empty right + result = calculate_diff_split("hello", ""); + t.assert_equal("prefix should be empty", "", result.prefix); + t.assert_equal("left should be 'hello'", "hello", result.left); + t.assert_equal("right should be empty", "", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + // Both empty + result = calculate_diff_split("", ""); + t.assert_equal("prefix should be empty", "", result.prefix); + t.assert_equal("left should be empty", "", result.left); + t.assert_equal("right should be empty", "", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + // Left single char, empty right + result = calculate_diff_split("a", ""); + t.assert_equal("prefix should be empty", "", result.prefix); + t.assert_equal("left should be 'a'", "a", result.left); + t.assert_equal("right should be empty", "", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + // Empty left, right single char + result = calculate_diff_split("", "a"); + t.assert_equal("prefix should be empty", "", result.prefix); + t.assert_equal("left should be empty", "", result.left); + t.assert_equal("right should be 'a'", "a", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); +} + +static void test_calculate_diff_split_no_common(testing & t) { + diff_split result = calculate_diff_split("abc", "xyz"); + t.assert_equal("prefix should be empty", "", result.prefix); + t.assert_equal("left should be 'abc'", "abc", result.left); + t.assert_equal("right should be 'xyz'", "xyz", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + result = calculate_diff_split("left", "right"); + // The algorithm finds "t" as a common suffix since both strings end with 't' + // This is the algorithm's actual behavior - it finds maximal common suffix + t.assert_equal("prefix should be empty", "", result.prefix); + t.assert_equal("left should be 'lef'", "lef", result.left); + t.assert_equal("right should be 'righ'", "righ", result.right); + t.assert_equal("suffix should be 't'", "t", result.suffix); + + result = calculate_diff_split("123", "456"); + t.assert_equal("prefix should be empty", "", result.prefix); + t.assert_equal("left should be '123'", "123", result.left); + t.assert_equal("right should be '456'", "456", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); +} + +static void test_calculate_diff_split_single_char(testing & t) { + diff_split result = calculate_diff_split("a", "b"); + t.assert_equal("prefix should be empty", "", result.prefix); + t.assert_equal("left should be 'a'", "a", result.left); + t.assert_equal("right should be 'b'", "b", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + result = calculate_diff_split("a", "a"); + t.assert_equal("prefix should be 'a'", "a", result.prefix); + t.assert_equal("left should be empty", "", result.left); + t.assert_equal("right should be empty", "", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + result = calculate_diff_split("a", "ab"); + t.assert_equal("prefix should be 'a'", "a", result.prefix); + t.assert_equal("left should be empty", "", result.left); + t.assert_equal("right should be 'b'", "b", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + result = calculate_diff_split("ab", "a"); + t.assert_equal("prefix should be 'a'", "a", result.prefix); + t.assert_equal("left should be 'b'", "b", result.left); + t.assert_equal("right should be empty", "", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); +} + +static void test_calculate_diff_split_overlaps(testing & t) { + // One string is substring of another + diff_split result = calculate_diff_split("test", "testing"); + t.assert_equal("prefix should be 'test'", "test", result.prefix); + t.assert_equal("left should be empty", "", result.left); + t.assert_equal("right should be 'ing'", "ing", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + result = calculate_diff_split("testing", "test"); + t.assert_equal("prefix should be 'test'", "test", result.prefix); + t.assert_equal("left should be 'ing'", "ing", result.left); + t.assert_equal("right should be empty", "", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + // Similar strings with one extra char at start + result = calculate_diff_split("Xtest", "Ytest"); + // The algorithm finds "test" as a common suffix since both strings end with "test" + // This is the algorithm's actual behavior - it finds maximal common suffix + t.assert_equal("prefix should be empty", "", result.prefix); + t.assert_equal("left should be 'X'", "X", result.left); + t.assert_equal("right should be 'Y'", "Y", result.right); + t.assert_equal("suffix should be 'test'", "test", result.suffix); + + // Similar strings with one extra char at end + result = calculate_diff_split("testX", "testY"); + t.assert_equal("prefix should be 'test'", "test", result.prefix); + t.assert_equal("left should be 'X'", "X", result.left); + t.assert_equal("right should be 'Y'", "Y", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + // Strings that are reverses + result = calculate_diff_split("abc", "cba"); + t.assert_equal("prefix should be empty", "", result.prefix); + t.assert_equal("left should be 'abc'", "abc", result.left); + t.assert_equal("right should be 'cba'", "cba", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); +} + +static void test_calculate_diff_split_tag_boundaries(testing & t) { + // Test with unclosed XML tags + diff_split result = calculate_diff_split("testcontent"); + // The fix_tag_boundaries should move incomplete tags appropriately + t.assert_true("prefix should start with 'test'", result.prefix.find("test") == 0); + t.assert_true("should handle tag boundaries", result.left != "" || result.right != "" || result.suffix != ""); + + // Test with unclosed brackets + result = calculate_diff_split("test[", "test]value"); + t.assert_true("should handle bracket boundaries", result.left != "" || result.right != "" || result.suffix != ""); + + // Test with partial tags on both sides + result = calculate_diff_split("prefix", "prefixsuffix"); + // fix_tag_boundaries moves the incomplete '<' from prefix to left/right + t.assert_equal("prefix should be 'prefix'", "prefix", result.prefix); + t.assert_equal("left should be ''", "", result.left); + t.assert_equal("right should be 'suffix'", "suffix", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + // Test with complex nested tags + result = calculate_diff_split("prefix
content
", "prefix
different
"); + // Algorithm finds "ent" as a common suffix because both strings end with it + // This is the actual algorithm behavior, though not semantically ideal + t.assert_equal("prefix should be 'prefix
'", "prefix
", result.prefix); + t.assert_equal("left should be 'cont'", "cont", result.left); + t.assert_equal("right should be 'differ'", "differ", result.right); + t.assert_equal("suffix should be 'ent
'", "ent
", result.suffix); + + // Test with unclosed angle bracket + result = calculate_diff_split("Hello ", "Hello test"); + t.assert_equal("prefix should be 'Hello '", "Hello ", result.prefix); + t.assert_true("left should contain ''", result.left.find("") != std::string::npos); + t.assert_equal("right should be 'test'", "test", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + // Test with unclosed square bracket + result = calculate_diff_split("test [array]", "test other"); + t.assert_equal("prefix should be 'test '", "test ", result.prefix); + t.assert_true("left should contain '[array]'", result.left.find("[array]") != std::string::npos); + t.assert_equal("right should be 'other'", "other", result.right); + t.assert_equal("suffix should be empty", "", result.suffix); + + // Test empty prefix and suffix with tags + result = calculate_diff_split("left", "righ"); + t.assert_equal("prefix should be ''", "", result.prefix); + t.assert_equal("left should be 'left'", "left", result.left); + t.assert_equal("right should be 'righ'", "righ", result.right); + t.assert_equal("suffix should be ''", "", result.suffix); + + { + // real case from template tests, simplified + std::string left = "PREFIX
Sure"; + std::string right = "PREFIXLemme thinkSure"; + result = calculate_diff_split(left, right); + t.assert_equal("prefix should be PREFIX", "PREFIX", result.prefix); + t.assert_equal("suffix should be Sure", "Sure", result.suffix); + t.assert_equal("left should be empty", "", result.left); + t.assert_equal("right should be Lemme think", "Lemme think", result.right); + } + + { + // Real case: special tokens with |> boundary issue + // The suffix starts with |> which should be moved to complete <|END_RESPONSE and <|END_ACTION + std::string prefix = "SOME_PREFIX"; + std::string suffix = "|><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"; + std::string left_diff = "<|START_RESPONSE|>Let me help you.<|END_RESPONSE"; + std::string right_diff = + "<|START_THINKING|><|END_THINKING|><|START_ACTION|>[\n" + " {\"tool_call_id\": \"0\", \"tool_name\": \"test_function_name\", " + "\"parameters\": {\"param1\": \"value1\", \"param2\": \"value2\"}}\n" + "]<|END_ACTION"; + + std::string left = prefix + left_diff + suffix; + std::string right = prefix + right_diff + suffix; + result = calculate_diff_split(left, right); + + t.assert_equal("special token prefix", prefix, result.prefix); + // The |> should be moved from suffix to complete the tokens + t.assert_equal("special token left", "<|START_RESPONSE|>Let me help you.<|END_RESPONSE|>", result.left); + t.assert_true("special token right ends with |>", result.right.find("<|END_ACTION|>") != std::string::npos); + t.assert_equal("special token suffix", "<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>", + result.suffix); + } +} + +static void test_until_common_prefix(testing & t) { + t.test("until_common_prefix basic", test_until_common_prefix_basic); +} + +static void test_until_common_prefix_basic(testing & t) { + // Test case from the user request + std::string result = until_common_prefix("", "", ""); + t.assert_equal("untilCommonPrefix should return ''", "", result); + + // Additional test cases to ensure robustness + // Test with different common prefix lengths + result = until_common_prefix("prefixsuffix", "different", "other"); + t.assert_equal("should return 'prefix'", "prefix", result); + + // Test when common prefix is at the start + result = until_common_prefix("rest", "left", "right"); + t.assert_equal("should return empty string when common prefix at start", "", result); + + // Test when there's no common prefix + result = until_common_prefix("something", "left", "right"); + t.assert_equal("should return empty string when no common prefix", "", result); + + // Test with empty strings + result = until_common_prefix("test", "", "right"); + t.assert_equal("should return empty string when left is empty", "", result); + + // Test with longer common prefix + result = until_common_prefix("abcXYZrest", "left", "right"); + t.assert_equal("should return 'abcXYZ'", "abcXYZ", result); +} + +static void test_after_common_suffix(testing & t) { + t.test("after_common_suffix basic", test_after_common_suffix_basic); +} + +static void test_after_common_suffix_basic(testing & t) { + // Test case from the user request + std::string result = after_common_suffix("100", + "100", + "535"); + t.assert_equal("afterCommonSuffix should return ''", "", result); + + // Test when common suffix is at the end + result = after_common_suffix("rest", "left", "right"); + t.assert_equal("should return empty string when common suffix at end", "", result); + + // Test with empty strings + result = after_common_suffix("test", "left", ""); + t.assert_equal("should return empty string when right is empty", "", result); + + // Test case with XML-like structure similar to the main example + result = after_common_suffix("value", + "value", + "different"); + t.assert_equal("should return ''", "", result); + + // Test with longer common suffix appearing at the end of full + result = after_common_suffix("prefixrest", "prefixleft", "prefixright"); + t.assert_equal("should return '' when common suffix is at end of full", "", result); + + // Test with common suffix appearing in middle but not at end + result = after_common_suffix("content", "value", "other"); + t.assert_equal("should return '' when common suffix appears before end", "", result); + + // Test with multi-character common suffix at the very end of full + result = after_common_suffix("startend", "prefixleft", "prefixright"); + t.assert_equal("should return '' when common suffix is at end of full", "", result); +} + +static void test_compare_variants(testing & t) { + t.test("compare_variants basic", test_compare_variants_basic); + t.test("compare_variants messages modifier", test_compare_variants_messages_modifier); + t.test("compare_variants tools modifier", test_compare_variants_tools_modifier); + t.test("compare_variants both modifiers", test_compare_variants_both_modifiers); + t.test("compare_variants template failure", test_compare_variants_template_failure); + t.test("compare_variants identity", test_compare_variants_identity); +} + +static void test_compare_variants_basic(testing & t) { + // Create a simple template that just echoes messages + common_chat_template tmpl("{{ messages[0]['content'] }}", "", ""); + + template_params params; + params.messages = json::array({ + json {{"role", "user"}, {"content", "Hello"}} + }); + + auto modifier = [](template_params & p) { + p.messages[0]["content"] = "World"; + }; + + auto result = ::compare_variants(tmpl, params, modifier); + + if (!t.assert_true("result should have value", result.has_value())) { + return; + } + // The template might not output anything if messages is empty or format is different + // Check that we get a valid result + t.assert_true("prefix or left should have content", !result->diff.prefix.empty() || !result->diff.left.empty()); +} + +static void test_compare_variants_messages_modifier(testing & t) { + // Test with messages modifier only + common_chat_template tmpl("{% for message in messages %}{{ message['role'] }}:{{ message['content'] }}{% endfor %}", "", ""); + + template_params params; + params.messages = json::array({ + json {{"role", "user"}, {"content", "A"}} + }); + + auto modifier = [](template_params & p) { + p.messages[0]["content"] = "B"; + }; + + std::optional result = ::compare_variants(tmpl, params, modifier); + + if (!t.assert_true("result should have value", result.has_value())) { + return; + } + t.assert_equal("left should be 'A'", "A", result->diff.left); + t.assert_equal("right should be 'B'", "B", result->diff.right); +} + +static void test_compare_variants_tools_modifier(testing & t) { + // Test with tools modifier only + common_chat_template tmpl( + "{% for tool in tools %}{{ tool['name'] }}{% endfor %}", "", ""); + + template_params params; + params.tools = json::array({ + json {{"name", "foo"}} + }); + + auto modifier = [](template_params & p) { + p.tools[0]["name"] = "bar"; + }; + + auto result = ::compare_variants(tmpl, params, modifier); + + if (!t.assert_true("result should have value", result.has_value())) { + return; + } + t.assert_equal("left should be 'foo'", "foo", result->diff.left); + t.assert_equal("right should be 'bar'", "bar", result->diff.right); +} + +static void test_compare_variants_both_modifiers(testing & t) { + // Test with both messages and tools modifiers using the for loop approach + common_chat_template tmpl( + "{% for message in messages %}{{ message['role'] }}:{{ message['content'] }}{% endfor %}", "", ""); + + template_params params; + params.messages = json::array({ + json {{"role", "user"}, {"content", "A"}} + }); + + auto modifier = [](template_params & p) { + p.messages[0]["content"] = "B"; + p.messages[0]["role"] = "newuser"; + }; + + auto result = ::compare_variants(tmpl, params, modifier); + + if (!t.assert_true("result should have value", result.has_value())) { + return; + } + t.assert_equal("left should be 'user:A'", "user:A", result->diff.left); + t.assert_equal("right should be 'newuser:B'", "newuser:B", result->diff.right); +} + +static void test_compare_variants_template_failure(testing & t) { + // Test with template that causes failure during application (not construction) + // We use a valid template syntax but one that will fail during application + common_chat_template tmpl("{{ messages[0]['nonexistent_field'] }}", "", ""); + + template_params params; + params.messages = json::array({ + json {{"role", "user"}, {"content", "Hello"}} + }); + + auto modifier = [](template_params & p) { + p.messages[0]["content"] = "World"; + }; + + auto result = ::compare_variants(tmpl, params, modifier); + + t.assert_true("result should be nullopt on template failure", !result.has_value()); +} + +static void test_compare_variants_identity(testing & t) { + // Test with identity modifier (no change) + common_chat_template tmpl("{{ messages[0]['content'] }}", "", ""); + + template_params params; + params.messages = json::array({ + json {{"role", "user"}, {"content", "Hello"}} + }); + + // No modifier - should use identity + auto result = ::compare_variants(tmpl, params, nullptr); + + if (!t.assert_true("result should have value", result.has_value())) { + return; + } + t.assert_equal("prefix should be 'Hello'", "Hello", result->diff.prefix); + t.assert_equal("left should be empty", "", result->diff.left); + t.assert_equal("right should be empty", "", result->diff.right); + t.assert_equal("suffix should be empty", "", result->diff.suffix); +} + +// ============================================================================ +// Seed-OSS Template Tool Calling Analysis Tests +// ============================================================================ + +static void test_seed_oss_tool_analysis(testing & t) { + t.test("Seed-OSS tool presence", test_seed_oss_tool_presence); + t.test("Seed-OSS call count", test_seed_oss_call_count); + t.test("Seed-OSS function names", test_seed_oss_function_names); + t.test("Seed-OSS argument count", test_seed_oss_argument_count); + t.test("Seed-OSS args presence", test_seed_oss_args_presence); + t.test("Seed-OSS tool with reasoning", test_seed_oss_tool_with_reasoning); +} + +// Helper to load Seed-OSS template +static common_chat_template load_seed_oss_template(testing & t) { + std::string template_path = "models/templates/ByteDance-Seed-OSS.jinja"; + std::ifstream fin(template_path, std::ios::binary); + std::ostringstream buf; + if (fin.is_open()) { + buf << fin.rdbuf(); + } + std::string template_source = buf.str(); + common_chat_template tmpl(template_source, "", ""); + t.assert_true("Seed-OSS template loaded successfully", template_source.length() > 0); + return tmpl; +} + +// Helper to build tool call JSON +static json build_tool_call(const std::string & name, const json & args, const std::string & id = "call_001") { + return json{ + {"id", id}, + {"type", "function"}, + {"function", json{ + {"name", name}, + {"arguments", args} + }} + }; +} + +// Helper to build tools definition +static json build_tools_definition() { + json parameters_schema = json::object(); + parameters_schema["type"] = "object"; + parameters_schema["properties"] = json::object(); + parameters_schema["properties"]["param1"] = json::object({ + {"type", "string"}, + {"description", "First parameter"} + }); + parameters_schema["properties"]["param2"] = json::object({ + {"type", "string"}, + {"description", "Second parameter"} + }); + parameters_schema["required"] = json::array({"param1", "param2"}); + + return json::array({ + json{ + {"type", "function"}, + {"function", json{ + {"name", "test_function_name"}, + {"description", "A test function for debugging"}, + {"parameters", parameters_schema} + }} + } + }); +} + +// T1: Compare with/without tool call (user, assistant) +static void test_seed_oss_tool_presence(testing & t) { + common_chat_template tmpl = load_seed_oss_template(t); + + json assistant_no_tools = json{ + {"role", "assistant"}, + {"content", "Let me help you."} + }; + + json assistant_with_tools = json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("test_function_name", json::object({{"param1", "value1"}, {"param2", "value2"}})) + })} + }; + + json user_msg = json{ + {"role", "user"}, + {"content", "Hello, please help me."} + }; + + template_params params_no_tools; + params_no_tools.messages = json::array({user_msg, assistant_no_tools}); + params_no_tools.tools = build_tools_definition(); + params_no_tools.add_generation_prompt = false; + params_no_tools.enable_thinking = true; + + template_params params_with_tools; + params_with_tools.messages = json::array({user_msg, assistant_with_tools}); + params_with_tools.tools = build_tools_definition(); + params_with_tools.add_generation_prompt = false; + params_with_tools.enable_thinking = true; + + auto result = ::compare_variants(tmpl, params_no_tools, + [&](template_params & p) { + p.messages = params_with_tools.messages; + }); + + if (!t.assert_true("T1 result should have value", result.has_value())) { + return; + } + + const auto & diff = result->diff; + t.assert_true("T1 prefix should contain system", diff.prefix.find("system") != std::string::npos); + t.assert_true("T1 prefix should contain user", diff.prefix.find("user") != std::string::npos); + t.assert_true("T1 prefix should contain assistant", diff.prefix.find("assistant") != std::string::npos); + + // Left should be the assistant content without tool + t.assert_equal("T1 left should contain 'Let me help you.'", "Let me help you.", diff.left); + + // Right should contain the tool call markers + t.assert_true("T1 right should contain tool_call begin", diff.right.find("") != std::string::npos); + t.assert_true("T1 right should contain function tag", diff.right.find("") != std::string::npos); + t.assert_true("T1 right should contain parameter=param1", diff.right.find("") != std::string::npos); + t.assert_true("T1 right should contain parameter=param2", diff.right.find("") != std::string::npos); + t.assert_true("T1 right should contain value1", diff.right.find("value1") != std::string::npos); + t.assert_true("T1 right should contain value2", diff.right.find("value2") != std::string::npos); + t.assert_true("T1 right should contain tool_call end", diff.right.find("") != std::string::npos); + + // Suffix should be the eos token + t.assert_equal("T1 suffix should be ''", "", diff.suffix); +} + +// T2: Compare one vs two tool calls +static void test_seed_oss_call_count(testing & t) { + common_chat_template tmpl = load_seed_oss_template(t); + + json assistant_one_call = json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("test_function_name", json::object({{"param1", "value1"}, {"param2", "value2"}})) + })} + }; + + json assistant_two_calls = json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("test_function_name", json::object({{"param1", "value1"}, {"param2", "value2"}})), + build_tool_call("test_function_name", json::object({{"param1", "value3"}, {"param2", "value4"}}), "call_002") + })} + }; + + json user_msg = json{ + {"role", "user"}, + {"content", "Hello, please help me."} + }; + + template_params params_one; + params_one.messages = json::array({user_msg, assistant_one_call}); + params_one.tools = build_tools_definition(); + params_one.add_generation_prompt = false; + params_one.enable_thinking = true; + + auto result = ::compare_variants(tmpl, params_one, + [&](template_params & p) { + p.messages = json::array({user_msg, assistant_two_calls}); + }); + + if (!t.assert_true("T2 result should have value", result.has_value())) { + return; + } + + const auto & diff = result->diff; + + // Prefix should include the first tool call + t.assert_true("T2 prefix should contain first tool_call begin", diff.prefix.find("") != std::string::npos); + t.assert_true("T2 prefix should contain first function", diff.prefix.find("") != std::string::npos); + t.assert_true("T2 prefix should contain value1", diff.prefix.find("value1") != std::string::npos); + t.assert_true("T2 prefix should contain value2", diff.prefix.find("value2") != std::string::npos); + t.assert_true("T2 prefix should contain first tool_call end", diff.prefix.find("") != std::string::npos); + + // Left should be empty (no second tool call in variant A) + t.assert_equal("T2 left should be empty", "", diff.left); + + // Right should contain the second tool call + t.assert_true("T2 right should contain second tool_call begin", diff.right.find("") != std::string::npos); + t.assert_true("T2 right should contain second function", diff.right.find("") != std::string::npos); + t.assert_true("T2 right should contain value3", diff.right.find("value3") != std::string::npos); + t.assert_true("T2 right should contain value4", diff.right.find("value4") != std::string::npos); + t.assert_true("T2 right should contain second tool_call end", diff.right.find("") != std::string::npos); + + // Suffix should end with the eos token + t.assert_equal("T2 suffix should end with ''", "", diff.suffix.substr(diff.suffix.length() - 10, 10)); +} + +// T3: Compare different function names +static void test_seed_oss_function_names(testing & t) { + common_chat_template tmpl = load_seed_oss_template(t); + + // Build tools with two different function names + json parameters_schema = json::object(); + parameters_schema["type"] = "object"; + parameters_schema["properties"] = json::object(); + parameters_schema["properties"]["arg1"] = json::object({ + {"type", "string"}, + {"description", "Argument 1"} + }); + parameters_schema["required"] = json::array({"arg1"}); + + json tools = json::array({ + json{ + {"type", "function"}, + {"function", json{ + {"name", "func_alpha"}, + {"description", "First function"}, + {"parameters", parameters_schema} + }} + }, + json{ + {"type", "function"}, + {"function", json{ + {"name", "func_beta"}, + {"description", "Second function"}, + {"parameters", parameters_schema} + }} + } + }); + + json assistant_func_alpha = json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("func_alpha", json::object({{"arg1", "test_value"}})) + })} + }; + + json assistant_func_beta = json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("func_beta", json::object({{"arg1", "test_value"}})) + })} + }; + + json user_msg = json{ + {"role", "user"}, + {"content", "Hello"} + }; + + template_params params_alpha; + params_alpha.messages = json::array({user_msg, assistant_func_alpha}); + params_alpha.tools = tools; + params_alpha.add_generation_prompt = false; + params_alpha.enable_thinking = true; + + auto result = ::compare_variants(tmpl, params_alpha, + [&](template_params & p) { + p.messages = json::array({user_msg, assistant_func_beta}); + }); + + if (!t.assert_true("T3 result should have value", result.has_value())) { + return; + } + + const auto & diff = result->diff; + + bool func_alpha_in_left = diff.left.find("func_alpha") != std::string::npos; + bool func_alpha_in_prefix = diff.prefix.find("func_alpha") != std::string::npos; + bool func_beta_in_right = diff.right.find("func_beta") != std::string::npos; + bool func_beta_in_prefix = diff.prefix.find("func_beta") != std::string::npos; + bool func_beta_in_suffix = diff.suffix.find("func_beta") != std::string::npos; + + // Left should contain func_alpha (or be in prefix) + t.assert_true("T3 left should contain func_alpha (or prefix)", func_alpha_in_left || func_alpha_in_prefix); + + // Right should contain func_beta + t.assert_true("T3 right should contain func_beta", func_beta_in_right || func_beta_in_prefix || func_beta_in_suffix); + + // Both should have the same parameter value (in common parts, not in diffs) + // Since both have same args, test_value will be in prefix/suffix + t.assert_true("T3 diff should contain test_value (in prefix or suffix)", + diff.prefix.find("test_value") != std::string::npos || diff.suffix.find("test_value") != std::string::npos); +} + +// T4: Compare different argument counts (zero, one, two parameters) +static void test_seed_oss_argument_count(testing & t) { + common_chat_template tmpl = load_seed_oss_template(t); + + // Build tools with 0, 1, or 2 required parameters + json params_2_required = json::object(); + params_2_required["type"] = "object"; + params_2_required["properties"] = json::object(); + params_2_required["properties"]["arg1"] = json::object({ + {"type", "string"}, + {"description", "Argument 1"} + }); + params_2_required["properties"]["arg2"] = json::object({ + {"type", "string"}, + {"description", "Argument 2"} + }); + params_2_required["required"] = json::array({"arg1", "arg2"}); + + json params_1_required = json::object(); + params_1_required["type"] = "object"; + params_1_required["properties"] = json::object(); + params_1_required["properties"]["arg1"] = json::object({ + {"type", "string"}, + {"description", "Argument 1"} + }); + params_1_required["required"] = json::array({"arg1"}); + + json tools = json::array({ + json{ + {"type", "function"}, + {"function", json{ + {"name", "test_func"}, + {"description", "Test function"}, + {"parameters", params_2_required} + }} + } + }); + + // Test: zero args vs one arg + json assistant_zero_args = json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("test_func", json::object()) + })} + }; + + json assistant_one_arg = json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("test_func", json::object({{"arg1", "value1"}})) + })} + }; + + json assistant_two_args = json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("test_func", json::object({{"arg1", "value1"}, {"arg2", "value2"}})) + })} + }; + + json user_msg = json{ + {"role", "user"}, + {"content", "Hello"} + }; + + // Test zero vs one + template_params params_zero; + params_zero.messages = json::array({user_msg, assistant_zero_args}); + params_zero.tools = tools; + params_zero.add_generation_prompt = false; + params_zero.enable_thinking = true; + + auto result_zero_one = ::compare_variants(tmpl, params_zero, + [&](template_params & p) { + p.messages = json::array({user_msg, assistant_one_arg}); + }); + + if (!t.assert_true("T4 zero vs one result should have value", result_zero_one.has_value())) { + return; + } + t.assert_true("T4 zero vs one left should be empty or minimal", result_zero_one->diff.left.empty() || result_zero_one->diff.left == ""); + t.assert_true("T4 zero vs one right should contain arg1", result_zero_one->diff.right.find("arg1") != std::string::npos); + + // Test one vs two + template_params params_one; + params_one.messages = json::array({user_msg, assistant_one_arg}); + params_one.tools = tools; + params_one.add_generation_prompt = false; + params_one.enable_thinking = true; + + auto result_one_two = ::compare_variants(tmpl, params_one, + [&](template_params & p) { + p.messages = json::array({user_msg, assistant_two_args}); + }); + + if (!t.assert_true("T4 one vs two result should have value", result_one_two.has_value())) { + return; + } + + const auto & diff4 = result_one_two->diff; + t.assert_true("T4 one vs two left should contain arg1 (or prefix)", + diff4.left.find("arg1") != std::string::npos || diff4.prefix.find("arg1") != std::string::npos); + t.assert_true("T4 one vs two right should contain arg1 (or prefix)", + diff4.right.find("arg1") != std::string::npos || diff4.prefix.find("arg1") != std::string::npos); + t.assert_true("T4 one vs two right should contain arg2 (or prefix/suffix)", + diff4.right.find("arg2") != std::string::npos || diff4.prefix.find("arg2") != std::string::npos || diff4.suffix.find("arg2") != std::string::npos); +} + +// T5: Compare different argument values +static void test_seed_oss_args_presence(testing & t) { + common_chat_template tmpl = load_seed_oss_template(t); + + json assistant_same_arg = json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("test_function_name", json::object({{"param1", "value1"}})) + })} + }; + + json assistant_other_arg = json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("test_function_name", json::object({{"param2", "value2"}})) + })} + }; + + json assistant_both_args = json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("test_function_name", json::object({{"param1", "value1"}, {"param2", "value2"}})) + })} + }; + + json user_msg = json{ + {"role", "user"}, + {"content", "Hello"} + }; + + template_params params_same; + params_same.messages = json::array({user_msg, assistant_same_arg}); + params_same.tools = build_tools_definition(); + params_same.add_generation_prompt = false; + params_same.enable_thinking = true; + + // Test same arg vs other arg + auto result_same_other = ::compare_variants(tmpl, params_same, + [&](template_params & p) { + p.messages = json::array({user_msg, assistant_other_arg}); + }); + + if (!t.assert_true("T5 same vs other result should have value", result_same_other.has_value())) { + return; + } + const auto & diff5a = result_same_other->diff; + t.assert_true("T5 same vs other left should contain param1 (or prefix/suffix)", + diff5a.left.find("param1") != std::string::npos || diff5a.prefix.find("param1") != std::string::npos || diff5a.suffix.find("param1") != std::string::npos); + t.assert_true("T5 same vs other left should contain value1 (or prefix/suffix)", + diff5a.left.find("value1") != std::string::npos || diff5a.prefix.find("value1") != std::string::npos); + t.assert_true("T5 same vs other right should contain param2 (or prefix/suffix)", + diff5a.right.find("param2") != std::string::npos || diff5a.prefix.find("param2") != std::string::npos || diff5a.suffix.find("param2") != std::string::npos); + t.assert_true("T5 same vs other right should contain value2 (or prefix/suffix)", + diff5a.right.find("value2") != std::string::npos || diff5a.prefix.find("value2") != std::string::npos || diff5a.suffix.find("value2") != std::string::npos); + + // Test same arg vs both args + auto result_same_both = ::compare_variants(tmpl, params_same, + [&](template_params & p) { + p.messages = json::array({user_msg, assistant_both_args}); + }); + + if (!t.assert_true("T5 same vs both result should have value", result_same_both.has_value())) { + return; + } + const auto & diff5b = result_same_both->diff; + t.assert_true("T5 same vs both left should contain param1 (or prefix/suffix)", + diff5b.left.find("param1") != std::string::npos || diff5b.prefix.find("param1") != std::string::npos || diff5b.suffix.find("param1") != std::string::npos); + t.assert_true("T5 same vs both right should contain param1 (or prefix/suffix)", + diff5b.right.find("param1") != std::string::npos || diff5b.prefix.find("param1") != std::string::npos || diff5b.suffix.find("param1") != std::string::npos); + t.assert_true("T5 same vs both right should contain param2 (or prefix/suffix)", + diff5b.right.find("param2") != std::string::npos || diff5b.prefix.find("param2") != std::string::npos || diff5b.suffix.find("param2") != std::string::npos); +} + +// T6: Tool call with vs without reasoning_content +static void test_seed_oss_tool_with_reasoning(testing & t) { + common_chat_template tmpl = load_seed_oss_template(t); + + json assistant_tool_only = json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("test_function_name", json::object({{"param1", "value1"}, {"param2", "value2"}})) + })} + }; + + json assistant_tool_with_reasoning = json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("test_function_name", json::object({{"param1", "value1"}, {"param2", "value2"}})) + })}, + {"reasoning_content", "I need to call the tool first."} + }; + + json user_msg = json{ + {"role", "user"}, + {"content", "Hello, please help me."} + }; + + template_params params_tool_only; + params_tool_only.messages = json::array({user_msg, assistant_tool_only}); + params_tool_only.tools = build_tools_definition(); + params_tool_only.add_generation_prompt = false; + params_tool_only.enable_thinking = true; + + auto result = ::compare_variants(tmpl, params_tool_only, + [&](template_params & p) { + p.messages = json::array({user_msg, assistant_tool_with_reasoning}); + }); + + if (!t.assert_true("T6 result should have value", result.has_value())) { + return; + } + + const auto & diff = result->diff; + + // Left should be empty (no reasoning in variant A) + t.assert_equal("T6 left should be empty", "", diff.left); + + // Right should contain the thinking token with reasoning content + t.assert_true("T6 right should contain think begin", diff.right.find("") != std::string::npos); + t.assert_true("T6 right should contain reasoning content", diff.right.find("I need to call the tool first.") != std::string::npos); + t.assert_true("T6 right should contain think end", diff.right.find("") != std::string::npos); + + // Prefix should contain the assistant role + t.assert_true("T6 prefix should contain assistant", diff.prefix.find("assistant") != std::string::npos); + + // Suffix should contain the tool call + t.assert_true("T6 suffix should contain tool_call begin", diff.suffix.find("") != std::string::npos); + t.assert_true("T6 suffix should contain function name", diff.suffix.find("test_function_name") != std::string::npos); + t.assert_true("T6 suffix should contain eos", diff.suffix.find("") != std::string::npos); +} + +static common_chat_template load_template(testing & t, const std::string & template_path) { + std::ifstream fin(template_path, std::ios::binary); + std::ostringstream buf; + if (fin.is_open()) { + buf << fin.rdbuf(); + } + std::string template_source = buf.str(); + common_chat_template tmpl(template_source, "", ""); + t.assert_true("Nemotron template loaded successfully", template_source.length() > 0); + return tmpl; +} + +// ============================================================================ +// Nemotron Template Analysis Tests +// ============================================================================ +static common_chat_template load_nemotron_template(testing & t) { + return load_template(t, "models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja"); +} + +static void test_nemotron_analysis(testing & t) { + t.test("Nemotron reasoning detection", test_nemotron_reasoning_detection); + t.test("Nemotron tool format", test_nemotron_tool_format); +} + +static void test_nemotron_reasoning_detection(testing & t) { + common_chat_template tmpl = load_nemotron_template(t); + + // Test the comparison manually to see what's happening + json user_msg = json{ { "role", "user" }, { "content", "Hello" } }; + json assistant_no_reasoning = json{ + { "role", "assistant" }, + { "content", "I can help." } + }; + json assistant_with_reasoning = json{ + { "role", "assistant" }, + { "content", "I can help." }, + { "reasoning_content", "Let me think about this." } + }; + + template_params params; + params.messages = json::array({ user_msg, assistant_no_reasoning }); + params.add_generation_prompt = false; + params.enable_thinking = true; + + // Run differential analysis + struct autoparser analysis; + analysis.analyze_template(tmpl); + + // Check reasoning markers + t.assert_equal("reasoning_start should be ''", "", analysis.reasoning.start); + t.assert_equal("reasoning_end should be '\\n'", "\n", analysis.reasoning.end); + + // Check reasoning mode detection + // Nemotron uses forced closed reasoning with add_generation_prompt + t.assert_equal("reasoning should be FORCED_CLOSED", reasoning_mode::FORCED_CLOSED, analysis.reasoning.mode); + + // Make sure reasoning markers don't spill over to content markers + t.assert_equal("content start should be empty", "", analysis.content.start); + t.assert_equal("content end should be empty", "", analysis.content.end); + + t.assert_equal("content should be PLAIN", content_mode::PLAIN, analysis.content.mode); +} + +static void test_nemotron_tool_format(testing & t) { + common_chat_template tmpl = load_nemotron_template(t); + + // Run differential analysis + struct autoparser analysis; + analysis.analyze_template(tmpl); + + // Check tool markers - Nemotron uses per-call wrapping (each call individually wrapped) + t.assert_equal("tool_section_start should be empty (per-call format)", "", analysis.tools.format.section_start); + t.assert_equal("tool_section_end should be empty (per-call format)", "", analysis.tools.format.section_end); + t.assert_equal("per_call_start should be '\\n'", "\n", analysis.tools.format.per_call_start); + t.assert_equal("per_call_end should be ''", "", analysis.tools.format.per_call_end); + t.assert_true("should support parallel calls", analysis.jinja_caps.supports_parallel_tool_calls); + + // Check function markers + t.assert_equal("func_name_prefix should be '\\n'", ">\n", analysis.tools.function.name_suffix); + t.assert_equal("func_close should be '\\n'", "\n", analysis.tools.function.close); + + // Check argument markers (note: markers retain trailing newlines for proper parsing) + t.assert_equal("arg_name_prefix should be '\\n'", ">\n", analysis.tools.arguments.name_suffix); + t.assert_equal("arg_value_suffix should be '\\n'", "\n", analysis.tools.arguments.value_suffix); + + // Check format classification + t.assert_true("tool format should be TAG_WITH_TAGGED", analysis.tools.format.mode == tool_format::TAG_WITH_TAGGED); + + // Verify tool support + t.assert_true("should support tools", analysis.jinja_caps.supports_tools); +} + +static common_chat_template load_cohere_template(testing & t) { + return load_template(t, "models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"); +} + +static void test_cohere_analysis(testing & t) { + t.test("Cohere reasoning detection", test_cohere_reasoning_detection); +} + +static void test_cohere_reasoning_detection(testing & t) { + common_chat_template tmpl = load_cohere_template(t); + + // Run differential analysis + struct autoparser analysis; + analysis.analyze_template(tmpl); + + // Check reasoning markers - Cohere uses special token format + t.assert_equal("reasoning_start should be '<|START_THINKING|>'", "<|START_THINKING|>", analysis.reasoning.start); + t.assert_equal("reasoning_end should be '<|END_THINKING|>'", "<|END_THINKING|>", analysis.reasoning.end); + + // Check reasoning mode - Cohere only shows reasoning with tool calls (TOOLS_ONLY) + t.assert_equal("reasoning should be TOOLS_ONLY", reasoning_mode::TOOLS_ONLY, analysis.reasoning.mode); + + // Check content markers - Cohere wraps all content with START/END_RESPONSE + t.assert_equal("content_start should be '<|START_RESPONSE|>'", "<|START_RESPONSE|>", analysis.content.start); + t.assert_equal("content_end should be '<|END_RESPONSE|>'", "<|END_RESPONSE|>", analysis.content.end); + + // Content is always wrapped (both with and without tools) + t.assert_equal("content should be ALWAYS_WRAPPED", content_mode::ALWAYS_WRAPPED, analysis.content.mode); +} + +static void test_tool_format_cohere(testing & t) { + common_chat_template tmpl = load_cohere_template(t); + + // Run differential analysis + struct autoparser analysis; + analysis.analyze_template(tmpl); + + // Check tool section markers - Cohere uses ACTION markers + t.assert_equal("tool_section_start should be '<|START_ACTION|>'", "<|START_ACTION|>", analysis.tools.format.section_start); + t.assert_equal("tool_section_end should be '<|END_ACTION|>'", "<|END_ACTION|>", analysis.tools.format.section_end); + + // JSON_NATIVE format has no per-call markers + t.assert_equal("per_call_start should be empty", "", analysis.tools.format.per_call_start); + t.assert_equal("per_call_end should be empty", "", analysis.tools.format.per_call_end); + + // JSON_NATIVE format has empty function markers (no XML-style markers) + t.assert_equal("func_name_prefix should be empty", "", analysis.tools.function.name_prefix); + t.assert_equal("func_name_suffix should be empty", "", analysis.tools.function.name_suffix); + t.assert_equal("func_close should be empty", "", analysis.tools.function.close); + + // JSON_NATIVE format has empty args markers + t.assert_equal("args_start should be empty", "", analysis.tools.arguments.start); + t.assert_equal("args_end should be empty", "", analysis.tools.arguments.end); + + // JSON_NATIVE format has empty argument markers + t.assert_equal("arg_name_prefix should be empty", "", analysis.tools.arguments.name_prefix); + t.assert_equal("arg_name_suffix should be empty", "", analysis.tools.arguments.name_suffix); + t.assert_equal("arg_value_prefix should be empty", "", analysis.tools.arguments.value_prefix); + t.assert_equal("arg_value_suffix should be empty", "", analysis.tools.arguments.value_suffix); + t.assert_equal("arg_separator should be empty", "", analysis.tools.arguments.separator); + + // Check JSON field names - Cohere uses non-standard names + t.assert_equal("name_field should be 'tool_name'", "tool_name", analysis.tools.format.name_field); + t.assert_equal("args_field should be 'parameters'", "parameters", analysis.tools.format.args_field); + // This isn't a real tool call id field, i.e. with the OpenAI tool call ID format + t.assert_equal("id_field should be 'tool_call_id'", "", analysis.tools.format.id_field); + + // Check format classification + t.assert_equal("tool format should be JSON_NATIVE", tool_format::JSON_NATIVE, analysis.tools.format.mode); + + // Check flags + t.assert_true("should support tools", analysis.jinja_caps.supports_tools); + t.assert_true("should support parallel calls", analysis.jinja_caps.supports_parallel_tool_calls); + t.assert_true("should not require nonnull content", !analysis.content.requires_nonnull_content); + t.assert_true("tools_array_wrapped should be true", analysis.tools.format.tools_array_wrapped); +} + +// ============================================================================ +// standard_json_tools Format Tests +// ============================================================================ + +// Helper to build tools definition for tests +static json build_test_tools() { + json parameters_schema = json::object(); + parameters_schema["type"] = "object"; + parameters_schema["properties"] = json::object(); + parameters_schema["properties"]["location"] = json::object({ + {"type", "string"}, + {"description", "The city and state"} + }); + parameters_schema["properties"]["unit"] = json::object({ + {"type", "string"}, + {"description", "Temperature unit"}, + {"enum", json::array({"celsius", "fahrenheit"})} + }); + parameters_schema["required"] = json::array({"location"}); + + return json::array({ + json{ + {"type", "function"}, + {"function", json{ + {"name", "get_current_weather"}, + {"description", "Get the current weather in a given location"}, + {"parameters", parameters_schema} + }} + } + }); +} + +static void test_standard_json_tools_formats(testing & t) { + t.test("OpenAI format", test_standard_json_tools_openai); + t.test("Cohere format", test_standard_json_tools_cohere); + t.test("function-as-key format", test_standard_json_tools_function_key); +} + +// Test 1: OpenAI Standard Format +// {"id": "call_abc", "function": {"name": "get_weather", "arguments": {"location": "NYC"}}} +static void test_standard_json_tools_openai(testing & t) { + json tools = build_test_tools(); + + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + auto tool_call = p.standard_json_tools( + "", "", tools, + /* parallel */ true, + /* force */ false, + /* name_key */ "function.name", + /* args_key */ "function.arguments", + /* array_wrapped */ false, + /* function_is_key */ false, + /* call_id_key */ "id", + /* gen_call_id_key */ "", + /* parameters_order */ {} + ); + return p.content(p.until("")) + p.optional(tool_call) + p.end(); + }); + + std::string input = + "Let me check the weather." + "" + R"({"id": "call_abc123", "function": {"name": "get_current_weather", "arguments": {"location": "NYC"}}})" + ""; + + common_peg_parse_context ctx(input, false); + auto result = parser.parse(ctx); + + if (!t.assert_true("parse success", result.success())) { + return; + } + + common_chat_msg msg; + auto mapper = common_chat_peg_mapper(msg); + mapper.from_ast(ctx.ast, result); + + t.assert_equal("tool calls count", 1u, msg.tool_calls.size()); + if (!msg.tool_calls.empty()) { + t.assert_equal("tool name", "get_current_weather", msg.tool_calls[0].name); + t.assert_equal("tool id", "call_abc123", msg.tool_calls[0].id); + } + t.assert_true("content present", msg.content.find("Let me check the weather") != std::string::npos); +} + +// Test 2: Cohere Format +// {"tool_call_id": 0, "tool_name": "get_weather", "parameters": {"location": "NYC"}} +static void test_standard_json_tools_cohere(testing & t) { + json tools = build_test_tools(); + + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + auto tool_call = p.standard_json_tools( + "<|START_ACTION|>[", "]<|END_ACTION|>", tools, + /* parallel */ true, + /* force */ false, + /* name_key */ "tool_name", + /* args_key */ "parameters", + /* array_wrapped */ false, // Brackets are part of section markers + /* function_is_key */ false, + /* call_id_key */ "", + /* gen_call_id_key */ "tool_call_id", + /* parameters_order */ {"tool_call_id", "tool_name", "parameters"} + ); + return p.content(p.until("<|START_ACTION|>")) + p.optional(tool_call) + p.end(); + }); + + std::string input = + "Let me search for that." + "<|START_ACTION|>[" + R"({"tool_call_id": 0, "tool_name": "get_current_weather", "parameters": {"location": "NYC", "unit": "celsius"}})" + "]<|END_ACTION|>"; + + common_peg_parse_context ctx(input, false); + auto result = parser.parse(ctx); + + if (!t.assert_true("parse success", result.success())) { + return; + } + + common_chat_msg msg; + auto mapper = common_chat_peg_mapper(msg); + mapper.from_ast(ctx.ast, result); + + t.assert_equal("tool calls count", 1u, msg.tool_calls.size()); + if (!msg.tool_calls.empty()) { + t.assert_equal("tool name", "get_current_weather", msg.tool_calls[0].name); + t.assert_equal("tool id", "0", msg.tool_calls[0].id); + } + t.assert_true("content present", msg.content.find("Let me search") != std::string::npos); +} + +// Test 3: Function-as-Key Format +// {"get_current_weather": {"id": "call-0001", "args": {"location": "NYC"}}} +static void test_standard_json_tools_function_key(testing & t) { + json tools = build_test_tools(); + + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + auto tool_call = p.standard_json_tools( + "[", "]", tools, + /* parallel */ true, + /* force */ false, + /* name_key */ "", // Name is the key itself + /* args_key */ "args", + /* array_wrapped */ false, + /* function_is_key */ true, + /* call_id_key */ "id", + /* gen_call_id_key */ "", + /* parameters_order */ {} + ); + return p.content(p.until("")) + p.optional(tool_call) + p.end(); + }); + + std::string input = + "I'll call the weather function." + "[" + R"({"get_current_weather": {"id": "call-0001", "args": {"location": "NYC", "unit": "celsius"}}})" + "]"; + + common_peg_parse_context ctx(input, false); + auto result = parser.parse(ctx); + + if (!t.assert_true("parse success", result.success())) { + return; + } + + common_chat_msg msg; + auto mapper = common_chat_peg_mapper(msg); + mapper.from_ast(ctx.ast, result); + + t.assert_equal("tool calls count", 1u, msg.tool_calls.size()); + if (!msg.tool_calls.empty()) { + t.assert_equal("tool name", "get_current_weather", msg.tool_calls[0].name); + t.assert_equal("tool id", "call-0001", msg.tool_calls[0].id); + } + t.assert_true("content present", msg.content.find("I'll call the weather") != std::string::npos); +} + +// ============================================================================ +// normalize_quotes_to_json Tests +// ============================================================================ + +// Copy of the function for isolated testing (original is static in chat-peg-parser.cpp) +static std::string normalize_quotes_to_json(const std::string & input) { + std::string result; + result.reserve(input.size() + 16); + + bool in_single_quoted = false; + bool in_double_quoted = false; + + for (size_t i = 0; i < input.size(); ++i) { + char c = input[i]; + + if (c == '\\' && i + 1 < input.size()) { + char next = input[i + 1]; + + if (in_single_quoted) { + if (next == '\'') { + result += '\''; + ++i; + continue; + } + if (next == '"') { + result += "\\\""; + ++i; + continue; + } + result += c; + result += next; + ++i; + continue; + } + + if (in_double_quoted) { + result += c; + result += next; + ++i; + continue; + } + + result += c; + continue; + } + + if (c == '"') { + if (in_single_quoted) { + result += "\\\""; + } else { + in_double_quoted = !in_double_quoted; + result += c; + } + } else if (c == '\'') { + if (in_double_quoted) { + result += c; + } else if (in_single_quoted) { + in_single_quoted = false; + result += '"'; + } else { + in_single_quoted = true; + result += '"'; + } + } else { + result += c; + } + } + + return result; +} + +static void test_normalize_quotes_to_json(testing & t) { + t.test("basic single to double quotes", [](testing & t) { + std::string input = "{'key': 'value'}"; + std::string expected = "{\"key\": \"value\"}"; + std::string result = normalize_quotes_to_json(input); + t.assert_equal("basic conversion", expected, result); + }); + + t.test("escaped single quote inside single-quoted string", [](testing & t) { + std::string input = "{'code': 'print(\\'hello\\')'}"; + std::string expected = "{\"code\": \"print('hello')\"}"; + std::string result = normalize_quotes_to_json(input); + t.assert_equal("escaped single quote", expected, result); + }); + + t.test("double quote inside single-quoted string", [](testing & t) { + std::string input = "{'msg': 'He said \"hi\"'}"; + std::string expected = "{\"msg\": \"He said \\\"hi\\\"\"}"; + std::string result = normalize_quotes_to_json(input); + t.assert_equal("double quote escaping", expected, result); + }); + + t.test("nested backslash escapes", [](testing & t) { + std::string input = "{'path': 'C:\\\\Users\\\\test'}"; + std::string expected = "{\"path\": \"C:\\\\Users\\\\test\"}"; + std::string result = normalize_quotes_to_json(input); + t.assert_equal("backslash escaping", expected, result); + }); + + t.test("newline escapes", [](testing & t) { + std::string input = "{'text': 'line1\\nline2'}"; + std::string expected = "{\"text\": \"line1\\nline2\"}"; + std::string result = normalize_quotes_to_json(input); + t.assert_equal("newline escaping", expected, result); + }); + + t.test("mixed quotes", [](testing & t) { + std::string input = "{\"already_double\": 'single_value'}"; + std::string expected = "{\"already_double\": \"single_value\"}"; + std::string result = normalize_quotes_to_json(input); + t.assert_equal("mixed quotes", expected, result); + }); + + t.test("embedded quotes - the test case", test_normalize_quotes_with_embedded_quotes); +} + +// Test case that mirrors the Seed-OSS failing test scenario +static void test_normalize_quotes_with_embedded_quotes(testing & t) { + // This is similar to the Seed-OSS template test case + // The input has embedded double quotes like "14" and "bar" inside string values + std::string input = "{'filename': 'foo.cpp', 'oldString': 'def foo(arg = \"14\"):\\n return arg + \"bar\"\\n', 'newString': 'def foo(arg = \"15\"):\\n pass\\n'}"; + + // Expected: Python single quotes -> JSON double quotes, internal double quotes escaped + std::string expected = "{\"filename\": \"foo.cpp\", \"oldString\": \"def foo(arg = \\\"14\\\"):\\n return arg + \\\"bar\\\"\\n\", \"newString\": \"def foo(arg = \\\"15\\\"):\\n pass\\n\"}"; + + std::string result = normalize_quotes_to_json(input); + + t.assert_equal("normalize quotes with embedded double quotes", expected, result); + + // Also verify the result is valid JSON + try { + json parsed = json::parse(result); + t.assert_true("result is valid JSON", true); + t.assert_equal("filename field", "foo.cpp", parsed["filename"].get()); + t.assert_true("oldString contains embedded quotes", + parsed["oldString"].get().find("\"14\"") != std::string::npos); + t.assert_true("newString contains embedded quotes", + parsed["newString"].get().find("\"15\"") != std::string::npos); + } catch (const std::exception & e) { + t.assert_true(std::string("JSON parse failed: ") + e.what(), false); + } +} + +// ============================================================================ +// TAG_WITH_TAGGED Argument Parsing Tests +// ============================================================================ + +// Build tools definition for edit function +static json build_edit_tool() { + json parameters_schema = json::object(); + parameters_schema["type"] = "object"; + parameters_schema["properties"] = json::object(); + parameters_schema["properties"]["filename"] = json::object({ + {"type", "string"}, + {"description", "Path of file to edit"} + }); + parameters_schema["properties"]["oldString"] = json::object({ + {"type", "string"}, + {"description", "String to replace"} + }); + parameters_schema["properties"]["newString"] = json::object({ + {"type", "string"}, + {"description", "New (replacement) value"} + }); + parameters_schema["required"] = json::array({"filename", "oldString", "newString"}); + + return json::array({ + json{ + {"type", "function"}, + {"function", json{ + {"name", "edit"}, + {"description", "Edit a file"}, + {"parameters", parameters_schema} + }} + } + }); +} + +// Test that reproduces the Seed-OSS template issue with embedded quotes +static void test_tagged_args_with_embedded_quotes(testing & t) { + json tools = build_edit_tool(); + + // Build a parser for TAG_WITH_TAGGED format like Seed-OSS/Nemotron + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + // Build tool choice for the edit function + auto tool_choice = p.choice(); + + for (const auto & tool_def : tools) { + if (!tool_def.contains("function")) { continue; } + const auto & function = tool_def.at("function"); + std::string name = function.at("name"); + const auto & params = function.at("parameters"); + + if (!params.contains("properties") || !params.at("properties").is_object()) { continue; } + + const auto & properties = params.at("properties"); + + // Build argument parsers + std::vector arg_parsers; + for (const auto & [param_name, param_schema] : properties.items()) { + auto arg = p.tool_arg( + p.tool_arg_open(p.literal("")) + + p.space() + + p.tool_arg_string_value(p.until("")) + + p.space() + + p.tool_arg_close(p.literal("")) + ); + arg_parsers.push_back(p.optional(p.rule("arg-" + param_name, arg))); + } + + // Build arg sequence with space() between + common_peg_parser args_seq = p.eps(); + for (size_t i = 0; i < arg_parsers.size(); i++) { + if (i > 0) { + args_seq = args_seq + p.space(); + } + args_seq = args_seq + arg_parsers[i]; + } + + auto func_parser = + p.tool_open(p.literal("")) + + p.space() + args_seq + p.space() + + p.tool_close(p.literal("")); + + tool_choice |= p.rule("tool-" + name, p.tool(func_parser)); + } + + auto tool_section = + p.literal("") + p.space() + + tool_choice + + p.space() + p.literal(""); + + return p.content(p.until("")) + p.optional(tool_section) + p.end(); + }); + + // The exact input from the failing test + std::string input = + "\n" + "\n" + "\n" + "foo.cpp\n" + "\n" + "" + "def foo(arg = \"14\"):\n" + " return arg + \"bar\"\n" + "\n" + "\n" + "" + "def foo(arg = \"15\"):\n" + " pass\n" + "\n" + "\n" + "\n" + ""; + + common_peg_parse_context ctx(input, false); + auto result = parser.parse(ctx); + + if (!t.assert_true("parse success", result.success())) { + return; + } + + common_chat_msg msg; + auto mapper = common_chat_peg_mapper(msg); + mapper.from_ast(ctx.ast, result); + + t.assert_equal("tool calls count", 1u, msg.tool_calls.size()); + + if (!msg.tool_calls.empty()) { + t.assert_equal("tool name", "edit", msg.tool_calls[0].name); + + // Parse the arguments as JSON to verify they're valid + std::string args = msg.tool_calls[0].arguments; + + try { + json parsed = json::parse(args); + t.assert_true("arguments is valid JSON", true); + + // Verify each field has proper value + t.assert_equal("filename", "foo.cpp", parsed.value("filename", "")); + + std::string oldString = parsed.value("oldString", ""); + t.assert_true("oldString contains embedded quotes", + oldString.find("\"14\"") != std::string::npos); + t.assert_true("oldString contains bar with quotes", + oldString.find("\"bar\"") != std::string::npos); + + std::string newString = parsed.value("newString", ""); + t.assert_true("newString contains embedded quotes", + newString.find("\"15\"") != std::string::npos); + + } catch (const std::exception & e) { + t.assert_true(std::string("arguments should be valid JSON: ") + e.what(), false); + } + } +} + diff --git a/tests/test-chat-parser.cpp b/tests/test-chat-parser.cpp deleted file mode 100644 index 6f44a2b4211..00000000000 --- a/tests/test-chat-parser.cpp +++ /dev/null @@ -1,617 +0,0 @@ -// Tests chat handling, including grammar generation and parsing for tool calling, for various templates. -// -// Also acts as a CLI to generate a Markdown summary of the formats of Jinja templates, -// e.g. given Minja (http://github.com/google/minja) checked out in parent dir: -// -// cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null -// -#include -#include -#include - -#include "chat-parser.h" -#include "common.h" -#include "log.h" -#include "regex-partial.h" - -template -static void assert_equals(const std::string_view label, const T & expected, const T & actual) { - if (expected != actual) { - std::cerr << label << std::endl; - std::cerr << "Expected: " << expected << std::endl; - std::cerr << "Actual: " << actual << std::endl; - std::cerr << std::flush; - throw std::runtime_error("Test failed"); - } -} - -template -static void assert_equals(const T & expected, const T & actual) { - assert_equals("", expected, actual); -} -static void assert_equals(const char * expected, const std::string & actual) { - return assert_equals(expected, actual); -} - -static void assert_throws(const std::function & fn, const std::string & expected_exception_pattern = "") { - try { - fn(); - } catch (const std::exception & e) { - if (expected_exception_pattern.empty()) { - return; - } - std::regex expected_exception_regex(expected_exception_pattern); - std::string actual_message = e.what(); - if (std::regex_search(actual_message, expected_exception_regex)) { - return; - } - throw std::runtime_error("Exception doesn't match expected pattern: " + actual_message + " (pattern: " + expected_exception_pattern + ")"); - throw std::runtime_error("Exception of unexpected type: " + std::string(e.what())); - } - throw std::runtime_error("Exception was expected but not thrown"); -} - -static void test_reasoning() { - //common_log_set_verbosity_thold(LOG_DEFAULT_DEBUG); - { - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - params.reasoning_format = COMMON_REASONING_FORMAT_NONE; - params.reasoning_in_content = false; - params.thinking_forced_open = false; - common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, params); - assert_equals(false, builder.try_parse_reasoning("", "")); - assert_equals("CogitoErgo sum", builder.consume_rest()); - } - { - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; - params.reasoning_in_content = false; - params.thinking_forced_open = false; - common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, params); - assert_equals(true, builder.try_parse_reasoning("", "")); - assert_equals(std::string("Cogito"), builder.result().reasoning_content); - assert_equals("Ergo sum", builder.consume_rest()); - } - { - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - params.reasoning_format = COMMON_REASONING_FORMAT_NONE; - params.reasoning_in_content = false; - params.thinking_forced_open = false; - common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, params); - assert_equals(false, builder.try_parse_reasoning("", "")); - assert_equals("CogitoErgo sum", builder.consume_rest()); - } - { - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; - params.reasoning_in_content = false; - params.thinking_forced_open = true; - common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, params); - assert_equals(true, builder.try_parse_reasoning("", "")); - assert_equals(std::string("Cogito"), builder.result().reasoning_content); - assert_equals("Ergo sum", builder.consume_rest()); - } - { - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; - params.reasoning_in_content = true; - params.thinking_forced_open = true; - common_chat_msg_parser builder("CogitoErgo sum", /* is_partial= */ false, params); - assert_equals(true, builder.try_parse_reasoning("", "")); - assert_equals("Cogito", builder.result().content); - assert_equals("Ergo sum", builder.consume_rest()); - } - { - const std::string variant("content_only_inline_think"); - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; - params.reasoning_in_content = false; - params.thinking_forced_open = false; - params.parse_tool_calls = false; - const std::string input = "PenseBonjour"; - auto msg = common_chat_parse(input, false, params); - assert_equals(variant, std::string("Pense"), msg.reasoning_content); - assert_equals(variant, std::string("Bonjour"), msg.content); - } - { - const std::string variant("llama_3_inline_think"); - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_LLAMA_3_X; - params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; - params.reasoning_in_content = false; - params.thinking_forced_open = false; - params.parse_tool_calls = false; - const std::string input = "PlanRéponse"; - auto msg = common_chat_parse(input, false, params); - assert_equals(variant, std::string("Plan"), msg.reasoning_content); - assert_equals(variant, std::string("Réponse"), msg.content); - } - // Test DeepSeek V3.1 parsing - reasoning content followed by "" and then regular content - { - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; - params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; - params.reasoning_in_content = false; - params.thinking_forced_open = true; - params.parse_tool_calls = true; - const std::string variant("deepseek_v3_1_reasoning_format_deepseek"); - common_chat_msg_parser builder("REASONINGok", /* is_partial= */ false, params); - assert_equals(variant, true, builder.try_parse_reasoning("", "")); - assert_equals(variant, std::string("REASONING"), builder.result().reasoning_content); - assert_equals(variant, std::string("ok"), builder.consume_rest()); - } - // Test DeepSeek V3.1 parsing - reasoning_format none - reasoning content followed by "" and then regular content - { - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; - params.reasoning_format = COMMON_REASONING_FORMAT_NONE; - params.reasoning_in_content = false; - params.thinking_forced_open = true; - params.parse_tool_calls = true; - const std::string variant("deepseek_v3_1_reasoning_format_none"); - const std::string input = "REASONINGok"; - auto msg = common_chat_parse(input, false, params); - assert_equals(variant, std::string("REASONINGok"), msg.content); - assert_equals(variant, std::string(""), msg.reasoning_content); - } -} - -static void test_regex() { - auto test_throws = [](const std::string & input, const std::string & regex, const std::string & expected_exception_pattern = "") { - common_chat_msg_parser builder(input, /* is_partial= */ false, {}); - assert_throws([&]() { builder.consume_regex(common_regex(regex)); }, expected_exception_pattern); - }; - - test_throws("Hello, world!", "abc", "^abc$"); - test_throws("Hello, world!", "e", "^e$"); - - { - common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {}); - builder.consume_regex(common_regex("Hello")); - assert_equals(", world!", builder.consume_rest()); - } - - { - // When in non partial mode, we can say whether the regex was consumed or not. - common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {}); - assert_equals(false, builder.try_consume_regex(common_regex("Hello, world!")).has_value()); - } - { - common_chat_msg_parser builder("Hello,", /* is_partial= */ false, {}); - auto res = builder.try_consume_regex(common_regex("H(el)l(?:o, world!)?")); - assert_equals(true, res.has_value()); - // Verify captures - assert_equals(2, res->groups.size()); - assert_equals("Hell", builder.str(res->groups[0])); - assert_equals("el", builder.str(res->groups[1])); - // Verify position is after the match - assert_equals(4, builder.pos()); - assert_equals("o,", builder.consume_rest()); - } - { - // But in partial mode, we have a partial final match / can't decide, so we throw a partial exception. - common_chat_msg_parser builder("Hello,", /* is_partial= */ true, {}); - assert_throws([&]() { - builder.try_consume_regex(common_regex("Hello, world!")); - }, "^Hello, world!$"); - } - - // Now regardless of the mode, we can tell these aren't a match. - for (const auto is_partial : {false, true}) { - common_chat_msg_parser builder("Hello,", is_partial, {}); - assert_equals(false, builder.try_consume_regex(common_regex("a(b|c)(d|e)f")).has_value()); - } - for (const auto is_partial : {false, true}) { - common_chat_msg_parser builder("Hello,", is_partial, {}); - assert_equals(false, builder.try_consume_literal("Oh")); - } -} - -const std::vector barely_healable_jsons = { - "{", - "{\"", - "{\"\\", - "{\"n", - "{\"name\"", - "{\"name\":", - "{\"name\":\"", - "{\"name\":\"\\", - "{\"name\":\"python", - "{\"name\":\"python\\", - "{\",", - "{\":", - "{\"[", - "{\"]", - "{\"{", - "{\"}", - "{\"1", - "{\"name\":\",", - "{\"name\":\":", - "{\"name\":\"[", - "{\"name\":\"]", - "{\"name\":\"{", - "{\"name\":\"}", - "{\"name\":\"1", -}; - -static void test(const std::string & input, bool is_partial, const std::vector> & args_paths, const std::vector> & content_paths, const std::string & expected) { - common_chat_msg_parser builder(input, is_partial, {}); - auto js = builder.try_consume_json_with_dumped_args(args_paths, content_paths); - assert_equals(true, js.has_value()); - assert_equals(is_partial, js->is_partial); - assert_equals(expected, args_paths.size() == 1 && args_paths[0].empty() ? js->value.get() : js->value.dump()); -} - -static void test_deepseek_v3_1_tool_calls() { - //common_log_set_verbosity_thold(LOG_DEFAULT_DEBUG); - // variant: happy path for when it works as the model card says it should - const std::string variant("simple"); - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; - params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; - params.reasoning_in_content = false; - params.thinking_forced_open = false; - params.parse_tool_calls = true; - const std::string input = "<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"; - auto msg = common_chat_parse(input, false, params); - assert_equals(variant, 1, msg.tool_calls.size()); - assert_equals(variant, std::string("get_time"), msg.tool_calls[0].name); - // JSON arguments are dumped without spaces - assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), msg.tool_calls[0].arguments); - assert_equals(variant, std::string(""), msg.content); - assert_equals(variant, std::string(""), msg.reasoning_content); - - // variant: simple + thinking open - { - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; - params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; - params.reasoning_in_content = false; - params.thinking_forced_open = true; - params.parse_tool_calls = true; - const std::string variant("simple_thinking"); - const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"; - auto m = common_chat_parse(in, false, params); - assert_equals(variant, 1, m.tool_calls.size()); - assert_equals(variant, std::string("get_time"), m.tool_calls[0].name); - assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), m.tool_calls[0].arguments); - assert_equals(variant, std::string(""), m.content); - assert_equals(variant, std::string("REASONING"), m.reasoning_content); - } - // variant: simple + multiple tool calls - { - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; - params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; - params.reasoning_in_content = false; - params.thinking_forced_open = false; - params.parse_tool_calls = true; - const std::string variant("simple_multiple_tool_calls"); - const std::string in = "CONTENT<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Paris\"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"city\": \"Paris\"}<|tool▁call▁end|><|tool▁calls▁end|>"; - auto m = common_chat_parse(in, false, params); - assert_equals(variant, 2, m.tool_calls.size()); - assert_equals(variant, std::string("get_time"), m.tool_calls[0].name); - assert_equals(variant, std::string("{\"city\":\"Paris\"}"), m.tool_calls[0].arguments); - assert_equals(variant, std::string("get_weather"), m.tool_calls[1].name); - assert_equals(variant, std::string("{\"city\":\"Paris\"}"), m.tool_calls[1].arguments); - assert_equals(variant, std::string("CONTENT"), m.content); - assert_equals(variant, std::string(""), m.reasoning_content); - } - - - // variant: thinking forced open + tool call in reasoning content - { - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; - params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; - params.reasoning_in_content = false; - params.thinking_forced_open = true; - params.parse_tool_calls = true; - const std::string variant("thinking_forced_open_tool_call_in_reasoning"); - const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time2<|tool▁sep|>{\"city\": \"Tokyo2\"}<|tool▁call▁end|><|tool▁calls▁end|>REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"; - auto m = common_chat_parse(in, false, params); - assert_equals(variant, 1, m.tool_calls.size()); - assert_equals(variant, std::string("get_time"), m.tool_calls[0].name); - assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), m.tool_calls[0].arguments); - assert_equals(variant, std::string(""), m.content); - assert_equals(variant, std::string("REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time2<|tool▁sep|>{\"city\": \"Tokyo2\"}<|tool▁call▁end|><|tool▁calls▁end|>REASONING"), m.reasoning_content); - } - - // variant: thinking forced open + tool call in reasoning content + no closing think + not partial - // This is a bit of a fine tuning issue on the model's part IMO. It really should not be attempting - // to make tool calls in reasoning content according to the model card, but it does sometimes, so - // add the reasoning content as regular content and parse the tool calls. - { - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; - params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; - params.reasoning_in_content = false; - params.thinking_forced_open = true; - params.parse_tool_calls = true; - const std::string variant("thinking_forced_open_tool_call_in_reasoning_no_closing_think_not_partial"); - const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"; - auto m = common_chat_parse(in, false, params); - assert_equals(variant, std::string("REASONING"), m.content); - assert_equals(variant, std::string(""), m.reasoning_content); - assert_equals(variant, 1, m.tool_calls.size()); - assert_equals(variant, std::string("get_time"), m.tool_calls[0].name); - assert_equals(variant, std::string("{\"city\":\"Tokyo\"}"), m.tool_calls[0].arguments); - } - - // variant: thinking forced open + tool call in reasoning content + no closing think + partial - { - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; - params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; - params.reasoning_in_content = false; - params.thinking_forced_open = true; - params.parse_tool_calls = true; - const std::string variant("thinking_forced_open_tool_call_in_reasoning_no_closing_think_partial"); - const std::string in = "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"; - auto m = common_chat_parse(in, /* is_partial= */ true, params); - assert_equals(variant, std::string("REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>"), m.reasoning_content); - assert_equals(variant, std::string(""), m.content); - assert_equals(variant, 0, m.tool_calls.size()); - } - - // variant: thinking not forced open + reasoning + regular content + no tool calls - { - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; - params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; - params.reasoning_in_content = false; - params.thinking_forced_open = true; - params.parse_tool_calls = true; - const std::string variant("thinking_forced_open_reasoning_regular_content_no_tool_calls"); - const std::string in = "REASONINGCONTENT"; - auto m = common_chat_parse(in, false, params); - assert_equals(variant, 0, m.tool_calls.size()); - assert_equals(variant, std::string("CONTENT"), m.content); - assert_equals(variant, std::string("REASONING"), m.reasoning_content); - } - // variant: thinking not forced open + missing reasoning + no tool calls - { - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1; - params.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; - params.reasoning_in_content = false; - params.thinking_forced_open = false; - params.parse_tool_calls = true; - const std::string variant("thinking_not_forced_open_missing_reasoning_no_tool_calls"); - const std::string in = "CONTENT"; - auto m = common_chat_parse(in, false, params); - assert_equals(variant, 0, m.tool_calls.size()); - assert_equals(variant, std::string("CONTENT"), m.content); - assert_equals(variant, std::string(""), m.reasoning_content); - } -} - -static void test_with_args(const std::string & input, const std::string & expected, bool parse_as_partial = true, bool is_partial = true) { - common_chat_msg_parser builder(input, parse_as_partial, {}); - auto js = builder.try_consume_json_with_dumped_args({{"args"}}, {}); - assert_equals(true, js.has_value()); - assert_equals(is_partial, js->is_partial); - assert_equals(expected, js->value.dump()); -} - -static void test_json_with_dumped_args_no_args() { - // Normal JSON, nothing to heal, nothing to dump - test("{\"name\": \"python\"}", false, {}, {}, "{\"name\":\"python\"}"); - // Full json is args - test("{\"name\": \"python\"}", false, {{}}, {}, "{\"name\":\"python\"}"); - - // If the arguments are further down, don't heal partial content. - for (const auto & src : barely_healable_jsons) { - test(src, true, {{"arguments"}}, {}, "{}"); - } - // But heal content that isn't partial. - test("{\"name\": \"python\"", true, {{"arguments"}}, {}, "{\"name\":\"python\"}"); -} - -static void test_json_with_dumped_args() { - - // Partial content. - test("{\"content\": \"t", true, {}, {{"content"}}, "{\"content\":\"t\"}"); - test("{\"content\": \"", true, {}, {{"content"}}, "{\"content\":\"\"}"); - test("{\"content\": ", true, {}, {{"content"}}, "{}"); - - // If the entire JSON is the arguments, healing it them dumping it produces the same output as the input (just reformatted). - test("{\"name\": \"python", true, {{}}, {}, "{\"name\":\"python"); - for (const auto & src : barely_healable_jsons) { - test(src, true, {{}}, {}, src); - } - - // Full JSON w/ args - for (auto parse_as_partial : {true, false}) { - test_with_args( - R"({"name": "python", "args": {"arg1": 1}})", - R"({"name":"python","args":"{\"arg1\":1}"})", - parse_as_partial, - /* is_partial= */ false - ); - } - - // Partial JSON w/ partial args - test_with_args( - R"({"foo": "bar", "args": {")", - R"({"foo":"bar","args":"{\""})" - ); - // Partial args broken in object key - test_with_args( - R"({"foo": "bar", "args": {"ar)", - R"({"foo":"bar","args":"{\"ar"})" - ); - // Partial args broken after object key - test_with_args( - R"({"foo": "bar", "args": {"arg1")", - R"({"foo":"bar","args":"{\"arg1\""})" - ); - // Partial args broken before object value - test_with_args( - R"({"foo": "bar", "args": {"arg1":)", - R"({"foo":"bar","args":"{\"arg1\":"})" - ); - // Partial args broken before object value (space) - test_with_args( - R"({"foo": "bar", "args": {"arg1": )", - R"({"foo":"bar","args":"{\"arg1\":"})" - ); - // Partial args broken in object value that may not be complete (int) - test_with_args( - R"({"foo": "bar", "args": {"arg1": 1)", - R"({"foo":"bar","args":"{\"arg1\":"})" - ); - // Partial args broken in object value that is complete (int) - test_with_args( - R"({"foo": "bar", "args": {"arg1": 1 )", - R"({"foo":"bar","args":"{\"arg1\":1"})" - ); - // Partial args broken in object value that is incomplete (string) - test_with_args( - R"({"foo": "bar", "args": {"arg1": ")", - R"({"foo":"bar","args":"{\"arg1\":\""})" - ); - // Partial args broken in object value that is complete (string) - test_with_args( - R"({"foo": "bar", "args": {"arg1": "1")", - R"({"foo":"bar","args":"{\"arg1\":\"1\""})" - ); - // Partial args broken on array opening - test_with_args( - R"({"foo": "bar", "args": [)", - R"({"foo":"bar","args":"["})" - ); - // Partial args broken on array value that is incomplete (int) - test_with_args( - R"({"foo": "bar", "args": [1)", - R"({"foo":"bar","args":"["})" - ); - // Partial args broken on array value that is complete (int) - test_with_args( - R"({"foo": "bar", "args": [1 )", - R"({"foo":"bar","args":"[1"})" - ); - // Partial args broken on array value that is complete (string) - test_with_args( - R"({"foo": "bar", "args": ["1")", - R"({"foo":"bar","args":"[\"1\""})" - ); - // Partial args broken after array value - test_with_args( - R"({"foo": "bar", "args": [1,)", - R"({"foo":"bar","args":"[1,"})" - ); - // Partial args broken on nested array - test_with_args( - R"({"foo": "bar", "args": {"arg1": [)", - R"({"foo":"bar","args":"{\"arg1\":["})" - ); - - // Unicode tests - test_with_args( - R"({"foo": "bar", "args": {"arg1": "\u)", - R"({"foo":"bar","args":"{\"arg1\":\"\\u"})" - ); - test_with_args( - R"({"foo": "bar", "args": {"arg1": "\u0)", - R"({"foo":"bar","args":"{\"arg1\":\"\\u0"})" - ); - test_with_args( - R"({"foo": "bar", "args": {"arg1": "\u00)", - R"({"foo":"bar","args":"{\"arg1\":\"\\u00"})" - ); - test_with_args( - R"({"foo": "bar", "args": {"arg1": "\u000)", - R"({"foo":"bar","args":"{\"arg1\":\"\\u000"})" - ); - test_with_args( - R"({"foo": "bar", "args": {"arg1": "\u0000)", - R"({"foo":"bar","args":"{\"arg1\":\"\\u0000"})" - ); - test_with_args( - R"({"foo": "bar", "args": {"arg1": "\ud8)", - R"({"foo":"bar","args":"{\"arg1\":\"\\ud8"})" - ); - test_with_args( - R"({"foo": "bar", "args": {"arg1": "\ud80)", - R"({"foo":"bar","args":"{\"arg1\":\"\\ud80"})" - ); - test_with_args( - R"({"foo": "bar", "args": {"arg1": "\ud800)", - R"({"foo":"bar","args":"{\"arg1\":\"\\ud800"})" - ); - test_with_args( - R"({"foo": "bar", "args": {"arg1": "\ud800\)", - R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\"})" - ); - test_with_args( - R"({"foo": "bar", "args": {"arg1": "\ud800\u)", - R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\u"})" - ); - test_with_args( - R"({"foo": "bar", "args": {"arg1": "\ud800\ud)", - R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\ud"})" - ); - test_with_args( - R"({"foo": "bar", "args": {"arg1": "\ud800\udc)", - R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc"})" - ); - test_with_args( - R"({"foo": "bar", "args": {"arg1": "\ud800\udc0)", - R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc0"})" - ); - test_with_args( - R"({"foo": "bar", "args": {"arg1": "\ud800\udc00)", - R"({"foo":"bar","args":"{\"arg1\":\"\\ud800\\udc00"})" - ); -} - -static void test_positions() { - { - common_chat_msg_parser builder("Hello, world!", /* is_partial= */ false, {}); - assert_equals(0, builder.pos()); - assert_throws([&]() { builder.move_to(100); }); - assert_equals(0, builder.pos()); - assert_throws([&]() { builder.move_back(1); }); - assert_equals(0, builder.pos()); - - builder.move_to(8); - assert_equals(8, builder.pos()); - builder.move_back(1); - assert_equals(7, builder.pos()); - assert_equals("world!", builder.consume_rest()); - - builder.move_to(0); - assert_equals(0, builder.pos()); - - assert_throws([&]() { builder.finish(); }); - assert_equals(0, builder.pos()); - - builder.move_to(builder.input().size()); - builder.finish(); - } - { - common_chat_msg_parser builder("Hello, world!", /* is_partial= */ true, {}); - - builder.move_to(builder.input().size()); - assert_equals(builder.input().size(), builder.pos()); - builder.finish(); - } -} - -int main() { - test_positions(); - test_json_with_dumped_args_no_args(); - test_json_with_dumped_args(); - test_reasoning(); - test_regex(); - test_deepseek_v3_1_tool_calls(); - std::cout << "All tests passed!\n"; - return 0; -} diff --git a/tests/test-chat-peg-parser.cpp b/tests/test-chat-peg-parser.cpp index f767c73c27a..7626ca12dbd 100644 --- a/tests/test-chat-peg-parser.cpp +++ b/tests/test-chat-peg-parser.cpp @@ -1,8 +1,3 @@ -#include -#include -#include - -#include "chat-parser.h" #include "chat-peg-parser.h" #include "chat.h" #include "common.h" @@ -10,6 +5,11 @@ #include "peg-parser.h" #include "testing.h" #include "peg-parser/simple-tokenize.h" + +#include +#include +#include + #include "nlohmann/json.hpp" using json = nlohmann::ordered_json; @@ -17,9 +17,12 @@ using json = nlohmann::ordered_json; static json create_tools(); static void test_example_native(testing & t); static void test_example_qwen3_coder(testing & t); +static void test_example_qwen3_non_coder(testing & t); static void test_command7_parser_compare(testing & t); +static void test_prefix_tool_names(testing & t); +static void test_tagged_peg_parser(testing & t); -int main(int argc, char *argv[]) { +int main(int argc, char * argv[]) { testing t(std::cout); if (argc >= 2) { t.set_filter(argv[1]); @@ -32,7 +35,10 @@ int main(int argc, char *argv[]) { t.test("native", test_example_native); t.test("qwen3 coder", test_example_qwen3_coder); + t.test("qwen3 non-coder", test_example_qwen3_non_coder); t.test("comparison", test_command7_parser_compare); + t.test("prefix tool names", test_prefix_tool_names); + t.test("tagged peg parser", test_tagged_peg_parser); return t.summary(); } @@ -41,87 +47,75 @@ static json create_tools() { json tools = json::array(); json tool_weather = { - {"type", "function"}, - {"function", { - {"name", "get_current_weather"}, - {"description", "Get the current weather in a given location"}, - {"parameters", { - {"type", "object"}, - {"properties", { - {"location", { - {"type", "string"}, - {"description", "The city and state, e.g. San Francisco, CA"} - }}, - {"unit", { - {"type", "string"}, - {"enum", {"celsius", "fahrenheit"}}, - {"description", "The temperature unit to use. Infer this from the users location."} - }} - }}, - {"required", {"location", "unit"}}, - }}, - }} + { "type", "function" }, + { "function", + { + { "name", "get_current_weather" }, + { "description", "Get the current weather in a given location" }, + { "parameters", + { + { "type", "object" }, + { "properties", + { { "location", + { { "type", "string" }, { "description", "The city and state, e.g. San Francisco, CA" } } }, + { "unit", + { { "type", "string" }, + { "enum", { "celsius", "fahrenheit" } }, + { "description", + "The temperature unit to use. Infer this from the users location." } } } } }, + { "required", { "location", "unit" } }, + } }, + } } }; tools.push_back(tool_weather); json tool_forecast = { - {"type", "function"}, - {"function", { - {"name", "get_forecast"}, - {"description", "Get the weather forecast for a given location"}, - {"parameters", { - {"type", "object"}, - {"properties", { - {"location", { - {"type", "string"}, - {"description", "The city and state, e.g. San Francisco, CA"} - }}, - {"unit", { - {"type", "string"}, - {"enum", {"celsius", "fahrenheit"}}, - {"description", "The temperature unit to use. Infer this from the users location."} - }}, - {"days", { - {"type", "integer"}, - {"description", "Number of days to forecast (1-10)"}, - {"minimum", 1}, - {"maximum", 10} - }} - }}, - {"required", {"location", "unit"}}, - }}, - }} + { "type", "function" }, + { "function", + { + { "name", "get_forecast" }, + { "description", "Get the weather forecast for a given location" }, + { "parameters", + { + { "type", "object" }, + { "properties", + { { "location", + { { "type", "string" }, { "description", "The city and state, e.g. San Francisco, CA" } } }, + { "unit", + { { "type", "string" }, + { "enum", { "celsius", "fahrenheit" } }, + { "description", "The temperature unit to use. Infer this from the users location." } } }, + { "days", + { { "type", "integer" }, + { "description", "Number of days to forecast (1-10)" }, + { "minimum", 1 }, + { "maximum", 10 } } } } }, + { "required", { "location", "unit" } }, + } }, + } } }; tools.push_back(tool_forecast); json tool_search = { - {"type", "function"}, - {"function", { - {"name", "search_knowledge_base"}, - {"description", "Search the internal technical documentation knowledge base."}, - {"parameters", { - {"type", "object"}, - {"properties", { - {"query", { - {"type", "string"}, - {"description", "The search query string."} - }}, - {"max_results", { - {"type", "integer"}, - {"description", "The maximum number of results to return."}, - {"default", 5} - }}, - {"category", { - {"type", "string"}, - {"enum", {"api", "troubleshooting", "billing", "general"}}, - {"description", "Filter search by specific category."} - }} - }}, - {"required", {"query", "category"}}, - {"additionalProperties", false} - }}, - {"strict", true} - }} + { "type", "function" }, + { "function", + { { "name", "search_knowledge_base" }, + { "description", "Search the internal technical documentation knowledge base." }, + { "parameters", + { { "type", "object" }, + { "properties", + { { "query", { { "type", "string" }, { "description", "The search query string." } } }, + { "max_results", + { { "type", "integer" }, + { "description", "The maximum number of results to return." }, + { "default", 5 } } }, + { "category", + { { "type", "string" }, + { "enum", { "api", "troubleshooting", "billing", "general" } }, + { "description", "Filter search by specific category." } } } } }, + { "required", { "query", "category" } }, + { "additionalProperties", false } } }, + { "strict", true } } } }; tools.push_back(tool_search); @@ -131,39 +125,39 @@ static json create_tools() { struct tool_argument { std::string name; std::string type; - bool is_required; - json schema; + bool is_required; + json schema; }; struct tool_definition { - std::string name; + std::string name; std::vector arguments; - json schema; + json schema; }; // Test fictitious model output that emits arguments as JSON. static void test_example_native(testing & t) { struct test_case { // Parameters - std::string name; - json tools; + std::string name; + json tools; common_chat_tool_choice tool_choice; common_reasoning_format reasoning_format; - json json_schema; - bool parallel_tool_calls; - bool thinking_forced_open; - std::string input; + json json_schema; + bool parallel_tool_calls; + bool thinking_forced_open; + std::string input; // Expect - std::string expect_reasoning; - std::string expect_content; + std::string expect_reasoning; + std::string expect_content; std::vector expect_tool_calls; }; auto build_parser = [](const test_case & tc) { - return build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) { + return build_chat_peg_parser([&](common_chat_peg_builder & p) { auto reasoning_in_content = (tc.reasoning_format == COMMON_REASONING_FORMAT_NONE); - auto reasoning = p.eps(); + auto reasoning = p.eps(); if (tc.thinking_forced_open) { // If thinking is forced open, expect a closing tag reasoning = p.reasoning(p.until("")) + "" + p.space(); @@ -174,231 +168,188 @@ static void test_example_native(testing & t) { // tool calling parser if (tc.tools.is_array() && !tc.tools.empty()) { - auto tools = p.choice(); - for (const auto & tool : tc.tools) { - const auto & function = tool.at("function"); - std::string name = function.at("name"); - const auto & schema = function.at("parameters"); - - auto tool_name = p.json_member("name", "\"" + p.tool_name(p.literal(name)) + "\""); - auto tool_args = p.json_member("arguments", p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema))); + auto tool_call = + p.standard_json_tools("[", "]", tc.tools, tc.parallel_tool_calls, + tc.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED); - tools |= p.rule("tool-" + name, p.tool_open(p.literal("{")) << tool_name << "," << tool_args << "}"); - }; - - auto parallel_calls = p.eps(); - if (tc.parallel_tool_calls) { - parallel_calls = p.zero_or_more("," << tools); - } - - auto tool_call = p.trigger_rule("tool-call", - p.sequence({ - p.literal("["), - tools, - parallel_calls, - p.literal("]") - }) - ); - - return p.sequence({ - (reasoning_in_content ? p.eps() : reasoning), - p.content(p.until("")), - p.optional(p.space() + tool_call), - p.space(), - p.end() - }); + return p.sequence({ (reasoning_in_content ? p.eps() : reasoning), p.content(p.until("")), + p.optional(p.space() + tool_call), p.space(), p.end() }); } // response_format parser if (tc.json_schema.is_object() && !tc.json_schema.empty()) { - return p.sequence({ - (reasoning_in_content ? p.eps() : reasoning), - p.content(p.schema(p.json(), "response-output", tc.json_schema)), - p.space(), - p.end() - }); + return p.sequence({ (reasoning_in_content ? p.eps() : reasoning), + p.content(p.schema(p.json(), "response-output", tc.json_schema)), p.space(), + p.end() }); } // Content-only parser - return p.sequence({ - (reasoning_in_content ? p.eps() : reasoning), - p.content(p.rest()), - p.end() - }); + return p.sequence({ (reasoning_in_content ? p.eps() : reasoning), p.content(p.rest()), p.end() }); }); }; std::vector test_cases = std::vector{ { - /* .name = */ "content with thinking_forced_open = false", - /* .tools = */ {}, - /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - /* .json_schema = */ {}, - /* .parallel_tool_calls = */ false, - /* .thinking_forced_open = */ false, - /* .input = */ ( - "The user said hello, I must say hello back\nHello" - ), - /* .expect_reasoning = */ "The user said hello, I must say hello back", - /* .expect_content = */ "Hello", - /* .expect_tool_calls = */ {}, - }, + /* .name = */ "content with thinking_forced_open = false", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ false, + /* .input = */ ("The user said hello, I must say hello back\nHello"), + /* .expect_reasoning = */ "The user said hello, I must say hello back", + /* .expect_content = */ "Hello", + /* .expect_tool_calls = */ {}, + }, { - /* .name = */ "content with thinking_forced_open = false and no reasoning", - /* .tools = */ {}, - /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - /* .json_schema = */ {}, - /* .parallel_tool_calls = */ false, - /* .thinking_forced_open = */ false, - /* .input = */ ( - "Hello" - ), - /* .expect_reasoning = */ "", - /* .expect_content = */ "Hello", - /* .expect_tool_calls = */ {}, - }, + /* .name = */ "content with thinking_forced_open = false and no reasoning", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ false, + /* .input = */ ("Hello"), + /* .expect_reasoning = */ "", + /* .expect_content = */ "Hello", + /* .expect_tool_calls = */ {}, + }, { - /* .name = */ "content with thinking_forced_open = false and reasoning_format = none", - /* .tools = */ {}, - /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, - /* .json_schema = */ {}, - /* .parallel_tool_calls = */ false, - /* .thinking_forced_open = */ true, - /* .input = */ ( - "The user said hello, I must say hello back\nHello" - ), - /* .expect_reasoning = */ "", - /* .expect_content = */ "The user said hello, I must say hello back\nHello", - /* .expect_tool_calls = */ {}, - }, + /* .name = */ "content with thinking_forced_open = false and reasoning_format = none", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ("The user said hello, I must say hello back\nHello"), + /* .expect_reasoning = */ "", + /* .expect_content = */ "The user said hello, I must say hello back\nHello", + /* .expect_tool_calls = */ {}, + }, { - /* .name = */ "content with thinking_forced_open = true", - /* .tools = */ {}, - /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - /* .json_schema = */ {}, - /* .parallel_tool_calls = */ false, - /* .thinking_forced_open = */ true, - /* .input = */ ( - "The user said hello, I must say hello back\nHello" - ), - /* .expect_reasoning = */ "The user said hello, I must say hello back", - /* .expect_content = */ "Hello", - /* .expect_tool_calls = */ {}, - }, + /* .name = */ "content with thinking_forced_open = true", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ("The user said hello, I must say hello back\nHello"), + /* .expect_reasoning = */ "The user said hello, I must say hello back", + /* .expect_content = */ "Hello", + /* .expect_tool_calls = */ {}, + }, { - /* .name = */ "content with thinking_forced_open = true and reasoning_format = none", - /* .tools = */ {}, - /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, - /* .json_schema = */ {}, - /* .parallel_tool_calls = */ false, - /* .thinking_forced_open = */ true, - /* .input = */ ( - "The user said hello, I must say hello back\nHello" - ), - /* .expect_reasoning = */ "", - /* .expect_content = */ "The user said hello, I must say hello back\nHello", - /* .expect_tool_calls = */ {}, - }, + /* .name = */ "content with thinking_forced_open = true and reasoning_format = none", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ ("The user said hello, I must say hello back\nHello"), + /* .expect_reasoning = */ "", + /* .expect_content = */ "The user said hello, I must say hello back\nHello", + /* .expect_tool_calls = */ {}, + }, { - /* .name = */ "tools with tool_choice = auto and no parallel_tool_calls", - /* .tools = */ create_tools(), - /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - /* .json_schema = */ {}, - /* .parallel_tool_calls = */ false, - /* .thinking_forced_open = */ true, - /* .input = */ ( - "I must get the weather in New York\n" - "[" - R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})" - "]" - ), - /* .expect_reasoning = */ "I must get the weather in New York", - /* .expect_content = */ "", - /* .expect_tool_calls = */ {{ + /* .name = */ "tools with tool_choice = auto and no parallel_tool_calls", + /* .tools = */ create_tools(), + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ + ("I must get the weather in New York\n" + "[" + R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})" + "]"), + /* .expect_reasoning = */ "I must get the weather in New York", + /* .expect_content = */ "", + /* .expect_tool_calls = */ + { { /* .name = */ "get_current_weather", /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})", /* .id = */ "", - }}, - }, + } }, + }, { - /* .name = */ "tools with tool_choice = auto and parallel_tool_calls", - /* .tools = */ create_tools(), - /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - /* .json_schema = */ {}, - /* .parallel_tool_calls = */ true, - /* .thinking_forced_open = */ true, - /* .input = */ ( - "I must get the weather in New York and San Francisco and a 3 day forecast of each.\nLet me search that for you." - "[" - R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})" - ", " - R"({"name": "get_current_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}})" - ", " - R"({"name": "get_forecast", "arguments": {"location": "New York City, NY", "unit": "fahrenheit", "days": 3}})" - ", " - R"({"name": "get_forecast", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3}})" - "]" - ), - /* .expect_reasoning = */ "I must get the weather in New York and San Francisco and a 3 day forecast of each.", - /* .expect_content = */ "Let me search that for you.", - /* .expect_tool_calls = */ {{ - /* .name = */ "get_current_weather", - /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})", - /* .id = */ "", - }, { - /* .name = */ "get_current_weather", - /* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit"})", - /* .id = */ "", - }, { - /* .name = */ "get_forecast", - /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit", "days": 3})", - /* .id = */ "", - }, { - /* .name = */ "get_forecast", - /* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3})", - /* .id = */ "", - }}, - }, + /* .name = */ "tools with tool_choice = auto and parallel_tool_calls", + /* .tools = */ create_tools(), + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_AUTO, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ {}, + /* .parallel_tool_calls = */ true, + /* .thinking_forced_open = */ true, + /* .input = */ + ("I must get the weather in New York and San Francisco and a 3 day forecast of each.\nLet me " + "search that for you." + "[" + R"({"name": "get_current_weather", "arguments": {"location": "New York City, NY", "unit": "fahrenheit"}})" + ", " + R"({"name": "get_current_weather", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit"}})" + ", " + R"({"name": "get_forecast", "arguments": {"location": "New York City, NY", "unit": "fahrenheit", "days": 3}})" + ", " + R"({"name": "get_forecast", "arguments": {"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3}})" + "]"), + /* .expect_reasoning = */ + "I must get the weather in New York and San Francisco and a 3 day forecast of each.", /* .expect_content = */ "Let me search that for you.", + /* .expect_tool_calls = */ + { { + /* .name = */ "get_current_weather", + /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit"})", + /* .id = */ "", + }, + { + /* .name = */ "get_current_weather", + /* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit"})", + /* .id = */ "", + }, + { + /* .name = */ "get_forecast", + /* .arguments = */ R"({"location": "New York City, NY", "unit": "fahrenheit", "days": 3})", + /* .id = */ "", + }, + { + /* .name = */ "get_forecast", + /* .arguments = */ R"({"location": "San Francisco, CA", "unit": "fahrenheit", "days": 3})", + /* .id = */ "", + } }, + }, { - /* .name = */ "response_format with thinking_forced_open = true", - /* .tools = */ {}, - /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - /* .json_schema = */ { - {"type", "object"}, - {"properties", { - {"invoice_number", {{"type", "string"}}}, - {"amount", {{"type", "number"}}}, - {"due_date", {{"type", "string"}}} - }}, - {"required", {"invoice_number", "amount", "due_date"}} - }, - /* .parallel_tool_calls = */ false, - /* .thinking_forced_open = */ true, - /* .input = */ ( - "I must produce the invoice in the requested format\n" - R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})" - ), - /* .expect_reasoning = */ "I must produce the invoice in the requested format", - /* .expect_content = */ R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})", - /* .expect_tool_calls = */ {}, - }, + /* .name = */ "response_format with thinking_forced_open = true", + /* .tools = */ {}, + /* .tool_choice = */ COMMON_CHAT_TOOL_CHOICE_NONE, + /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, + /* .json_schema = */ + { { "type", "object" }, + { "properties", + { { "invoice_number", { { "type", "string" } } }, + { "amount", { { "type", "number" } } }, + { "due_date", { { "type", "string" } } } } }, + { "required", { "invoice_number", "amount", "due_date" } } }, + /* .parallel_tool_calls = */ false, + /* .thinking_forced_open = */ true, + /* .input = */ + ("I must produce the invoice in the requested format\n" + R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})"), + /* .expect_reasoning = */ "I must produce the invoice in the requested format", + /* .expect_content = */ + R"({"invoice_number": "INV-2025-001", "amount": 1250.50, "due_date": "2025-12-31"})", /* .expect_tool_calls = */ {}, + }, }; for (const auto & tc : test_cases) { t.test(tc.name, [&](testing & t) { - auto parser = build_parser(tc); - auto lazy = !tc.tools.empty() && tc.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; + auto parser = build_parser(tc); + auto lazy = !tc.tools.empty() && tc.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; auto grammar = build_grammar([&](const common_grammar_builder & builder) { - for (auto const & def : tc.tools) { - auto function = def.at("function"); + for (const auto & def : tc.tools) { + auto function = def.at("function"); auto parameters = function.at("parameters"); builder.resolve_refs(parameters); }; @@ -406,17 +357,17 @@ static void test_example_native(testing & t) { }); t.log("Grammar:"); - for (auto const & line : string_split(grammar, "\n")) { + for (const auto & line : string_split(grammar, "\n")) { t.log(line); } common_peg_parse_context ctx(tc.input, false); - auto result = parser.parse(ctx); + auto result = parser.parse(ctx); t.assert_true("success", result.success()); common_chat_msg msg; - auto mapper = common_chat_peg_native_mapper(msg); + auto mapper = common_chat_peg_mapper(msg); mapper.from_ast(ctx.ast, result); t.assert_equal("content equal", tc.expect_content, msg.content); @@ -431,16 +382,16 @@ static void test_example_native(testing & t) { } static void test_example_qwen3_coder(testing & t) { - auto tools = create_tools(); - auto parser = build_chat_peg_constructed_parser([&](common_chat_peg_constructed_builder & p) { + auto tools = create_tools(); + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { auto content = p.rule("content", p.content(p.until(""))); std::vector tool_parsers; - for (auto const & def : tools) { - auto function = def.at("function"); - std::string name = function.at("name"); - auto parameters = function.at("parameters"); - auto properties = parameters.at("properties"); + for (const auto & def : tools) { + auto function = def.at("function"); + std::string name = function.at("name"); + auto parameters = function.at("parameters"); + auto properties = parameters.at("properties"); std::set required_properties; if (function.contains("required")) { @@ -450,59 +401,36 @@ static void test_example_qwen3_coder(testing & t) { std::vector arg_parsers; for (const auto & [param_name, param_schema] : properties.items()) { bool is_required = required_properties.find(param_name) != required_properties.end(); - auto type = param_schema.value("type", "object"); - - auto arg = p.tool_arg(p.sequence({ - p.tool_arg_open(""), - (type == "string" ? - p.tool_arg_string_value( - p.schema( - p.until_one_of({ - "\n\n" - }), - "tool-" + name + "-arg-" + param_name + "-schema", - param_schema, - true - ) - ) : p.tool_arg_json_value( - p.schema( - p.json(), - "tool-" + name + "-arg-" + param_name + "-schema", - param_schema - ) - ) - ), - p.tool_arg_close( - "\n" + - p.peek(p.literal("")) - ) - })); - - arg_parsers.push_back(is_required ? - p.rule("tool-" + name + "-arg-" + param_name, arg) : - p.optional(p.rule("tool-" + name + "-arg-" + param_name, arg))); + auto type = param_schema.value("type", "object"); + + auto arg = p.tool_arg( + p.sequence({ p.tool_arg_open(""), + (type == "string" ? + p.tool_arg_string_value(p.schema( + p.until_one_of({ "\n\n" }), + "tool-" + name + "-arg-" + param_name + "-schema", param_schema, true)) : + p.tool_arg_json_value(p.schema( + p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema))), + p.tool_arg_close("\n" + + p.peek(p.literal(""))) })); + + arg_parsers.push_back(is_required ? p.rule("tool-" + name + "-arg-" + param_name, arg) : + p.optional(p.rule("tool-" + name + "-arg-" + param_name, arg))); } - tool_parsers.push_back(p.rule("tool-" + name, - p.tool_open("") - << p.sequence(arg_parsers) - << p.tool_close(p.literal("")) - )); + tool_parsers.push_back(p.rule("tool-" + name, p.tool_open("") + << p.sequence(arg_parsers) + << p.tool_close(p.literal("")))); }; - auto tool_call = p.trigger_rule("tool-call", - "" - << p.choice(tool_parsers) - << "" - ); + auto tool_call = p.trigger_rule("tool-call", "" << p.choice(tool_parsers) << ""); return content + p.zero_or_more(p.space() + tool_call) + p.end(); }); auto grammar = build_grammar([&](const common_grammar_builder & builder) { - for (auto const & def : tools) { - auto function = def.at("function"); + for (const auto & def : tools) { + auto function = def.at("function"); auto parameters = function.at("parameters"); builder.resolve_refs(parameters); }; @@ -510,11 +438,11 @@ static void test_example_qwen3_coder(testing & t) { }); t.log("Grammar:"); - for (auto const & line : string_split(grammar, "\n")) { + for (const auto & line : string_split(grammar, "\n")) { t.log(line); } - t.test("incremental parsing", [&](testing &t) { + t.test("incremental parsing", [&](testing & t) { std::string input = "Let me search the knowledge base for cat pictures." "\n" @@ -538,7 +466,105 @@ static void test_example_qwen3_coder(testing & t) { } common_chat_msg msg; - auto mapper = common_chat_peg_constructed_mapper(msg); + auto mapper = common_chat_peg_mapper(msg); + mapper.from_ast(ctx.ast, result); + + //t.log("Input: " + input); + t.log("==========================================="); + t.log("Iteration " + std::to_string(in.size())); + t.log("Reasoning: " + msg.reasoning_content); + t.log("Content : " + msg.content); + for (const auto & tc : msg.tool_calls) { + t.log("Tool name: " + tc.name); + t.log("Tool args: " + tc.arguments); + } + + try { + // This shouldn't emit any runtime errors + auto diffs = common_chat_msg_diff::compute_diffs(prev, msg); + } catch (const std::exception & e) { + t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end)); + t.assert_true(std::string("failed with ") + e.what(), false); + } + + prev = msg; + } + }); +} + +static void test_example_qwen3_non_coder(testing & t) { + auto tools = create_tools(); + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + // tool calling parser using standard JSON format + auto tool_call = p.standard_json_tools("", "", tools, true, false); + + return p.sequence({ p.content(p.until("")), p.optional(p.space() + tool_call), p.end() }); + }); + + auto grammar = build_grammar([&](const common_grammar_builder & builder) { + for (const auto & def : tools) { + auto function = def.at("function"); + auto parameters = function.at("parameters"); + builder.resolve_refs(parameters); + }; + parser.build_grammar(builder); + }); + + t.log("Grammar:"); + for (const auto & line : string_split(grammar, "\n")) { + t.log(line); + } + + t.test("tool call parsing", [&](testing & t) { + std::string input = + "I need to get the weather.\n" + "" + "{\"name\": \"get_current_weather\", \"arguments\": {\"location\": \"New York City, NY\", \"unit\": " + "\"fahrenheit\"}}" + ""; + + common_peg_parse_context ctx(input, false); + auto result = parser.parse(ctx); + + t.assert_true("success", result.success()); + + common_chat_msg msg; + auto mapper = common_chat_peg_mapper(msg); + mapper.from_ast(ctx.ast, result); + + t.assert_equal("content", "I need to get the weather.\n", msg.content); + t.assert_equal("reasoning", "", msg.reasoning_content); + t.assert_equal("tool calls count", 1u, msg.tool_calls.size()); + if (!msg.tool_calls.empty()) { + t.assert_equal("tool name", "get_current_weather", msg.tool_calls[0].name); + t.assert_equal("tool args", "{\"location\": \"New York City, NY\", \"unit\": \"fahrenheit\"}", + msg.tool_calls[0].arguments); + } + }); + + t.test("incremental parsing", [&](testing & t) { + std::string input = + "I need to get the weather.\n" + "" + "{\"name\": \"get_current_weather\", \"arguments\": {\"location\": \"New York City, NY\", \"unit\": " + "\"fahrenheit\"}}" + ""; + + std::vector tokens = simple_tokenize(input); + + common_chat_msg prev; + for (auto it = tokens.begin(); it != tokens.end(); it++) { + std::string in = std::accumulate(tokens.begin(), it + 1, std::string()); + + common_peg_parse_context ctx(in, it + 1 < tokens.end()); + + auto result = parser.parse(ctx); + if (!t.assert_equal("not fail", false, result.fail())) { + t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end)); + } + + common_chat_msg msg; + auto mapper = common_chat_peg_mapper(msg); mapper.from_ast(ctx.ast, result); //t.log("Input: " + input); @@ -554,7 +580,7 @@ static void test_example_qwen3_coder(testing & t) { try { // This shouldn't emit any runtime errors auto diffs = common_chat_msg_diff::compute_diffs(prev, msg); - } catch(const std::exception & e) { + } catch (const std::exception & e) { t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end)); t.assert_true(std::string("failed with ") + e.what(), false); } @@ -565,38 +591,37 @@ static void test_example_qwen3_coder(testing & t) { } void test_command7_parser_compare(testing & t) { - auto parser = build_chat_peg_native_parser([](common_chat_peg_native_builder & p) { - auto thinking = p.reasoning_block( - "<|START_THINKING|>" << p.reasoning(p.until("<|END_THINKING|>")) << "<|END_THINKING|>"); + auto parser = build_chat_peg_parser([](common_chat_peg_builder & p) { + auto thinking = + p.reasoning_block("<|START_THINKING|>" << p.reasoning(p.until("<|END_THINKING|>")) << "<|END_THINKING|>"); auto response = "<|START_RESPONSE|>" << p.content(p.until("<|END_RESPONSE|>")) << "<|END_RESPONSE|>"; auto tool_call_id = p.atomic("\"tool_call_id\"" << (":" << ("\"" + p.tool_id(p.json_string_content()) + "\""))); - auto tool_call_name = p.atomic("\"tool_name\"" << (":" << ("\"" + p.tool_name(p.json_string_content()) + "\""))); + auto tool_call_name = + p.atomic("\"tool_name\"" << (":" << ("\"" + p.tool_name(p.json_string_content()) + "\""))); auto tool_call_args = "\"parameters\"" << (":" << p.tool_args(p.json())); auto tool_call_fields = p.rule("tool-call-fields", tool_call_id | tool_call_name | tool_call_args); - auto tool_call = p.rule("tool-call", p.tool( - p.tool_open(p.literal("{")) - << tool_call_fields - << p.zero_or_more( p.literal(",") << tool_call_fields) - << p.tool_close(p.literal("}")) - )); - - auto tool_calls = p.rule("tool-calls", - "<|START_ACTION|>" - << ("[" << tool_call << p.zero_or_more(p.literal(",") << tool_call) << "]") - << "<|END_ACTION|>"); + auto tool_call = + p.rule("tool-call", p.tool(p.tool_open(p.literal("{")) + << tool_call_fields << p.zero_or_more(p.literal(",") << tool_call_fields) + << p.tool_close(p.literal("}")))); + + auto tool_calls = p.rule( + "tool-calls", "<|START_ACTION|>" << ("[" << tool_call << p.zero_or_more(p.literal(",") << tool_call) << "]") + << "<|END_ACTION|>"); return p.optional(thinking) << (tool_calls | response) + p.end(); }); - auto test_current = [&](const common_peg_arena & p, const std::string & input, bool is_partial, bool print_results) { + auto test_current = [&](const common_peg_arena & p, const std::string & input, bool is_partial, + bool print_results) { common_peg_parse_context ctx(input, is_partial); - auto result = p.parse(ctx); + auto result = p.parse(ctx); common_chat_msg msg; - auto mapper = common_chat_peg_native_mapper(msg); + auto mapper = common_chat_peg_mapper(msg); mapper.from_ast(ctx.ast, result); if (print_results) { @@ -614,79 +639,19 @@ void test_command7_parser_compare(testing & t) { } }; - auto test_legacy = [&](const std::string & input, bool need_more_input, bool print_results) { - // Original common_chat_combinator_parser taken from chat.cpp - common_chat_parser_params params; - params.format = COMMON_CHAT_FORMAT_GENERIC; - params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - params.reasoning_in_content = false; - params.thinking_forced_open = false; - common_chat_msg_parser builder( - input, - /* .is_partial = */ need_more_input, - params - ); - - builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>"); - - static const common_regex start_action_regex("<\\|START_ACTION\\|>"); - static const common_regex end_action_regex("<\\|END_ACTION\\|>"); - static const common_regex start_response_regex("<\\|START_RESPONSE\\|>"); - static const common_regex end_response_regex("<\\|END_RESPONSE\\|>"); - - if (auto res = builder.try_find_regex(start_action_regex)) { - // If we didn't extract thoughts, prelude includes them. - auto tool_calls = builder.consume_json_with_dumped_args({ { "parameters" } }); - for (const auto & tool_call : tool_calls.value) { - std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : ""; - std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : ""; - std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : ""; - if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - } - if (tool_calls.is_partial) { - throw common_chat_msg_partial_exception("incomplete tool call"); - } - builder.consume_regex(end_action_regex); - } else if (auto res = builder.try_find_regex(start_response_regex)) { - if (!builder.try_find_regex(end_response_regex)) { - builder.add_content(builder.consume_rest()); - throw common_chat_msg_partial_exception(end_response_regex.str()); - } - } else { - builder.add_content(builder.consume_rest()); - } - - if (print_results) { - std::cout << "== Parsed (legacy) ==\n"; - std::cout << "=== Reasoning ===\n"; - std::cout << builder.result().reasoning_content << "\n"; - std::cout << "\n\n=== Content ===\n"; - std::cout << builder.result().content << "\n"; - std::cout << "\n\n=== Tool Calls ===\n"; - for (const auto & tc : builder.result().tool_calls) { - std::cout << "id: " << tc.id << "\n"; - std::cout << "name: " << tc.name << "\n"; - std::cout << "args: " << tc.arguments << "\n"; - } - } - }; - - std::string reasoning = "To plan an effective trip to Japan that includes both historical sites and modern attractions within a " - "budget of $4000 for a two-week stay, we need to:\n\n" - "1. Identify key historical sites and modern attractions in Japan.\n" - "2. Find affordable accommodation options that provide a balance between comfort and cost.\n" - "3. Determine the best modes of transportation for getting around Japan.\n" - "4. Create a day-by-day itinerary that ensures the user gets to see a variety of attractions without " - "overspending.\n" - "5. Provide a detailed cost breakdown that includes accommodation, transportation, meals, and entry fees " - "to attractions."; - - std::vector> tool_calls = {{ - "call_0", - "plan_trip", - nlohmann::json::parse(R"({ + std::string reasoning = + "To plan an effective trip to Japan that includes both historical sites and modern attractions within a " + "budget of $4000 for a two-week stay, we need to:\n\n" + "1. Identify key historical sites and modern attractions in Japan.\n" + "2. Find affordable accommodation options that provide a balance between comfort and cost.\n" + "3. Determine the best modes of transportation for getting around Japan.\n" + "4. Create a day-by-day itinerary that ensures the user gets to see a variety of attractions without " + "overspending.\n" + "5. Provide a detailed cost breakdown that includes accommodation, transportation, meals, and entry fees " + "to attractions."; + + std::vector> tool_calls = { + { "call_0", "plan_trip", nlohmann::json::parse(R"({ "destination": "Japan", "duration": 14, "budget": 4000, @@ -694,8 +659,8 @@ void test_command7_parser_compare(testing & t) { "accommodation_preferences": "affordable", "transportation_preferences": "efficient", "meal_preferences": "local cuisine" - })") - }}; + })") } + }; std::vector tokens; @@ -712,10 +677,10 @@ void test_command7_parser_compare(testing & t) { auto json = nlohmann::json::array(); for (const auto & tc : tool_calls) { - auto tc_json = nlohmann::json::object(); + auto tc_json = nlohmann::json::object(); tc_json["tool_call_id"] = std::get<0>(tc); - tc_json["tool_name"] = std::get<1>(tc); - tc_json["parameters"] = std::get<2>(tc); + tc_json["tool_name"] = std::get<1>(tc); + tc_json["parameters"] = std::get<2>(tc); json.push_back(tc_json); } @@ -727,42 +692,284 @@ void test_command7_parser_compare(testing & t) { std::string input = std::accumulate(tokens.begin(), tokens.end(), std::string()); - // Run tests - t.test("legacy_parse", [&](testing & /* t */) { - test_legacy(input, false, false); - }); + t.test("current_parse", [&](testing & /* t */) { test_current(parser, input, false, false); }); + t.bench("current_parse_benchmark complete", [&]() { test_current(parser, input, false, false); }, 100); + t.bench( + "current_parse_benchmark incremental", + [&]() { + std::string in; + for (auto i = 0u; i < tokens.size(); i++) { + in += tokens[i]; + test_current(parser, in, i + 1 < tokens.size(), false); + } + }, + 20); +} + +// Test that tool names that are proper prefixes of other tool names don't cause +// premature matching during incremental parsing. +// For example, "special_function" should not match when parsing "special_function_with_opt". +static void test_prefix_tool_names(testing & t) { + // Create tools where one name is a proper prefix of another + json tools = json::array(); - t.test("current_parse", [&](testing & /* t */) { - test_current(parser, input, false, false); + json tool_short = { + { "type", "function" }, + { "function", + { + { "name", "special_function" }, + { "description", "A special function" }, + { "parameters", + { + { "type", "object" }, + { "properties", + { + { "arg1", { { "type", "integer" } } }, + } }, + { "required", { "arg1" } }, + } }, + } } + }; + tools.push_back(tool_short); + + json tool_long = { + { "type", "function" }, + { "function", + { + { "name", "special_function_with_opt" }, + { "description", "A special function with optional params" }, + { "parameters", + { + { "type", "object" }, + { "properties", + { + { "arg1", { { "type", "integer" } } }, + { "arg2", { { "type", "integer" } } }, + } }, + { "required", { "arg1" } }, + } }, + } } + }; + tools.push_back(tool_long); + + // Use standard_constructed_tools which had the prefix matching bug + std::map markers = { + { "tool_call_start_marker", "" }, + { "tool_call_end_marker", "" }, + { "function_opener", "" }, + { "function_name_suffix", ">" }, + { "parameter_key_prefix", "" }, + { "parameter_closer", "" }, + }; + + auto parser = build_chat_peg_parser([&](common_chat_peg_builder & p) { + auto content = p.rule("content", p.content(p.until(""))); + auto tool_call = p.standard_constructed_tools(markers, tools, false, false); + return content + p.zero_or_more(p.space() + tool_call) + p.end(); }); - // Run benchmarks - t.bench("legacy_parse_benchmark complete", [&]() { - test_legacy(input, false, false); + // Test parsing the long tool name - this should NOT trigger the short tool name + t.test("parse long tool name", [&](testing & t) { + std::string input = + "Let me call the function." + "" + "" + "42" + "" + ""; + + common_peg_parse_context ctx(input, false); + auto result = parser.parse(ctx); + + t.assert_true("success", result.success()); + + common_chat_msg msg; + auto mapper = common_chat_peg_mapper(msg); + mapper.from_ast(ctx.ast, result); + + t.assert_equal("content", "Let me call the function.", msg.content); + t.assert_equal("tool calls count", 1u, msg.tool_calls.size()); + if (!msg.tool_calls.empty()) { + t.assert_equal("tool name", "special_function_with_opt", msg.tool_calls[0].name); + } }); - t.bench("legacy_parse_benchmark incremental", [&]() { - std::string in; - for (auto i = 0u; i < tokens.size(); i++) { - in += tokens[i]; + // Test incremental parsing - the key test case + // This ensures that when incrementally parsing "special_function_with_opt", + // we don't prematurely emit "special_function" as a tool call + t.test("incremental parse long tool name", [&](testing & t) { + std::string input = + "Let me call the function." + "" + "" + "42" + "" + ""; + + std::vector tokens = simple_tokenize(input); + + common_chat_msg prev; + for (auto it = tokens.begin(); it != tokens.end(); it++) { + std::string in = std::accumulate(tokens.begin(), it + 1, std::string()); + + common_peg_parse_context ctx(in, it + 1 < tokens.end()); + auto result = parser.parse(ctx); + + if (!t.assert_equal("not fail", false, result.fail())) { + t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end)); + return; + } + + common_chat_msg msg; + auto mapper = common_chat_peg_mapper(msg); + mapper.from_ast(ctx.ast, result); + + // The critical check: during incremental parsing, we should never + // see "special_function" as the tool name when parsing "special_function_with_opt" + for (const auto & tc : msg.tool_calls) { + if (!t.assert_equal("tool name should not be short prefix", false, + tc.name == "special_function")) { + t.log("Premature tool name match at input: " + in); + return; + } + } try { - test_legacy(in, i + 1 < tokens.size(), false); - } catch (common_chat_msg_partial_exception & /* e */) { - // Do nothing, this is expected + auto diffs = common_chat_msg_diff::compute_diffs(prev, msg); + } catch (const std::exception & e) { + t.log(in.substr(0, result.end) + "[failed->]" + in.substr(result.end)); + t.assert_true(std::string("diff failed with ") + e.what(), false); + return; } + + prev = msg; } - }, 20); - t.bench("current_parse_benchmark complete", [&]() { - test_current(parser, input, false, false); - }, 100); + // Final check: the complete parse should have the correct tool name + t.assert_equal("final tool calls count", 1u, prev.tool_calls.size()); + if (!prev.tool_calls.empty()) { + t.assert_equal("final tool name", "special_function_with_opt", prev.tool_calls[0].name); + } + }); - t.bench("current_parse_benchmark incremental", [&]() { - std::string in; - for (auto i = 0u; i < tokens.size(); i++) { - in += tokens[i]; - test_current(parser, in, i + 1 < tokens.size(), false); + // Test parsing the short tool name still works + t.test("parse short tool name", [&](testing & t) { + std::string input = + "Let me call the function." + "" + "" + "42" + "" + ""; + + common_peg_parse_context ctx(input, false); + auto result = parser.parse(ctx); + + t.assert_true("success", result.success()); + + common_chat_msg msg; + auto mapper = common_chat_peg_mapper(msg); + mapper.from_ast(ctx.ast, result); + + t.assert_equal("content", "Let me call the function.", msg.content); + t.assert_equal("tool calls count", 1u, msg.tool_calls.size()); + if (!msg.tool_calls.empty()) { + t.assert_equal("tool name", "special_function", msg.tool_calls[0].name); } - }, 20); + }); +} + +static void test_tagged_peg_parser(testing & t) { + t.test("basic tag extraction", [&](testing & t) { + auto parser = build_tagged_peg_parser([](common_peg_parser_builder & p) { + return p.tag("greeting", p.until(" ")) + " " + p.tag("name", p.rest()) + p.end(); + }); + + auto result = parser.parse_and_extract("Hello World"); + t.assert_true("success", result.result.success()); + t.assert_equal("greeting tag", "Hello", result.tags.at("greeting")); + t.assert_equal("name tag", "World", result.tags.at("name")); + }); + + t.test("duplicate tags overwrite", [&](testing & t) { + auto parser = build_tagged_peg_parser([](common_peg_parser_builder & p) { + return p.tag("item", p.until(",")) + "," + p.tag("item", p.rest()) + p.end(); + }); + + auto result = parser.parse_and_extract("first,second"); + t.assert_true("success", result.result.success()); + t.assert_equal("item tag", "second", result.tags.at("item")); + }); + + t.test("no tags extracted", [&](testing & t) { + auto parser = build_tagged_peg_parser([](common_peg_parser_builder & p) { + return p.rest() + p.end(); + }); + + auto result = parser.parse_and_extract("Hello"); + t.assert_true("success", result.result.success()); + t.assert_equal("empty tags", 0u, result.tags.size()); + }); + + t.test("structured extraction", [&](testing & t) { + auto parser = build_tagged_peg_parser([](common_peg_parser_builder & p) { + auto header = p.tag("header", p.until("\n")); + auto body = p.tag("body", p.rest()); + return header + "\n" + body + p.end(); + }); + + auto result = parser.parse_and_extract("Title\nBody content here"); + t.assert_true("success", result.result.success()); + t.assert_equal("header", "Title", result.tags.at("header")); + t.assert_equal("body", "Body content here", result.tags.at("body")); + }); + + t.test("partial parse", [&](testing & t) { + auto parser = build_tagged_peg_parser([](common_peg_parser_builder & p) { + return p.tag("prefix", p.until(":")) + ":" + p.tag("value", p.rest()) + p.end(); + }); + + auto result = parser.parse_and_extract("key:val", true); + t.assert_true("not fail", !result.result.fail()); + t.assert_equal("prefix tag", "key", result.tags.at("prefix")); + t.assert_equal("value tag", "val", result.tags.at("value")); + }); + + t.test("find in the middle", [&](testing & t) { + auto parser = build_tagged_peg_parser([](common_peg_parser_builder & p) { + return p.choice({ p.literal("{"), p.literal(":") }) + p.space() + p.literal("\"") + p.atomic(p.literal("fun_name")); + }); + + std::string tpl = "This is a very long jinja template string. We have tools. We will try to call them now: { \"fun_name\" : { \"arg\" : 1 }"; + auto result = parser.parse_anywhere_and_extract(tpl); + t.assert_true("success", result.result.success()); + }); + + t.test("fail find in the middle", [&](testing & t) { + auto parser = build_tagged_peg_parser([](common_peg_parser_builder & p) { + return p.choice({ p.literal("{"), p.literal(":") }) + p.space() + p.literal("\"") + p.atomic(p.literal("fun_name")); + }); + + std::string tpl = "This is a very long jinja template string. We have tools. We will try to call them now: 1"; + auto result = parser.parse_anywhere_and_extract(tpl); + t.assert_true("failure", result.result.fail()); + }); + + t.test("find function tag with name", [&](testing &t) { + std::string haystack = "\n\n\n\nXXXX\n\n\nYYYY\n\n\n\n"; + auto parser = build_tagged_peg_parser([](common_peg_parser_builder & p) { + std::string needle = "foofoo"; + return p.tag("fun_marker", p.choice({ + p.tag("fun_pre", p.literal("<") + p.until_one_of({ ">", needle })) + p.literal(needle) + + p.tag("fun_post", p.negate(p.space() + p.literal("<")) + p.until(">") + p.literal(">")) + p.space(), + p.tag("fun_pre", p.literal("[") + p.until_one_of({ "]", needle })) + p.literal(needle) + + p.tag("fun_post", p.negate(p.space() + p.literal("[") + p.until("]") + p.literal("]")) + p.space()) })); + }); + auto result = parser.parse_anywhere_and_extract(haystack); + t.assert_true("success", result.result.success()); + t.assert_equal("fun_pre should be ''", ">", result.tags["fun_post"]); + }); } diff --git a/tests/test-chat-template.cpp b/tests/test-chat-template.cpp index 27b537a0369..6cc132131c1 100644 --- a/tests/test-chat-template.cpp +++ b/tests/test-chat-template.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -21,17 +22,16 @@ using json = nlohmann::ordered_json; -int main_automated_tests(void); +static int main_automated_tests(void); -void run_multiple(std::string dir_path, bool stop_on_first_failure, json input, bool use_common = false); -void run_single(std::string contents, json input, bool use_common = false, const std::string & output_path = ""); +static void run_multiple(const std::string& dir_path, bool stop_on_first_failure, const json& input, bool use_common = false); +static void run_single(const std::string& contents, json input, bool use_common = false, const std::string & output_path = ""); - - -std::string HELP = R"( +static std::string HELP = R"( Usage: test-chat-template [OPTIONS] PATH_TO_TEMPLATE Options: -h, --help Show this help message and exit. + --with-tools Add a tool and a tool call to the default JSON input --json Path to the JSON input file. --stop-on-first-fail Stop testing on the first failure (default: false). --no-common Use direct Jinja engine instead of common chat templates (default: use common). @@ -41,7 +41,23 @@ If PATH_TO_TEMPLATE is a directory, runs all .jinja files in that directory. If PATH_TO_TEMPLATE is omitted, runs automated tests (default CI mode). )"; -std::string DEFAULT_JSON = R"({ +static std::string DEFAULT_JSON = R"({ + "messages": [ + { + "role": "user", + "content": "Hello, how are you?" + }, + { + "role": "assistant", + "content": "I am fine, thank you!" + } + ], + "bos_token": "", + "eos_token": "", + "add_generation_prompt": true +})"; + +static std::string DEFAULT_JSON_WITH_TOOLS = R"({ "messages": [ { "role": "user", @@ -50,6 +66,41 @@ std::string DEFAULT_JSON = R"({ { "role": "assistant", "content": "I am fine, thank you!" + }, + { + "role": "user", + "content": "Call a tool!" + }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call00001", + "type": "function", + "function": { + "name": "test", + "arguments": { "arg": "hello" } + } + } + ] + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "test", + "description": "Test", + "parameters": { + "type": "object", + "properties": { + "arg": { + "type": "string" + } + } + }, + "required": ["arg"] + } } ], "bos_token": "", @@ -57,12 +108,14 @@ std::string DEFAULT_JSON = R"({ "add_generation_prompt": true })"; + int main(int argc, char ** argv) { std::vector args(argv, argv + argc); std::string tmpl_path; std::string json_path; std::string output_path; + std::string & json_to_use = DEFAULT_JSON; bool stop_on_first_fail = false; bool use_common = true; @@ -70,9 +123,12 @@ int main(int argc, char ** argv) { if (args[i] == "--help" || args[i] == "-h") { std::cout << HELP << "\n"; return 0; - } else if (args[i] == "--json" && i + 1 < args.size()) { + } + if (args[i] == "--json" && i + 1 < args.size()) { json_path = args[i + 1]; i++; + } else if (args[i] == "--with-tools") { + json_to_use = DEFAULT_JSON_WITH_TOOLS; } else if (args[i] == "--stop-on-first-fail") { stop_on_first_fail = true; } else if (args[i] == "--output" && i + 1 < args.size()) { @@ -105,7 +161,7 @@ int main(int argc, char ** argv) { std::istreambuf_iterator()); input_json = json::parse(content); } else { - input_json = json::parse(DEFAULT_JSON); + input_json = json::parse(json_to_use); } std::filesystem::path p(tmpl_path); @@ -125,7 +181,7 @@ int main(int argc, char ** argv) { return 0; } -void run_multiple(std::string dir_path, bool stop_on_first_fail, json input, bool use_common) { +void run_multiple(const std::string& dir_path, bool stop_on_first_fail, const json& input, bool use_common) { std::vector failed_tests; // list all files in models/templates/ and run each @@ -180,7 +236,7 @@ static std::string format_using_common( common_chat_templates_inputs inputs; inputs.use_jinja = true; inputs.messages = messages; - inputs.tools = tools; + inputs.tools = std::move(tools); inputs.add_generation_prompt = true; auto output = common_chat_templates_apply(tmpls.get(), inputs).prompt; output = normalize_newlines(output); @@ -209,7 +265,7 @@ static jinja::value_string format_using_direct_engine( jinja::runtime runtime(ctx); const jinja::value results = runtime.execute(ast); - auto parts = runtime.gather_string_parts(results); + auto parts = jinja::runtime::gather_string_parts(results); std::cout << "\n=== RESULTS ===\n"; for (const auto & part : parts->as_string().parts) { @@ -220,7 +276,7 @@ static jinja::value_string format_using_direct_engine( } -void run_single(std::string contents, json input, bool use_common, const std::string & output_path) { +void run_single(const std::string& contents, json input, bool use_common, const std::string & output_path) { jinja::enable_debug(true); jinja::value_string output_parts; @@ -560,7 +616,7 @@ int main_automated_tests(void) { supported_tmpl.resize(res); res = llama_chat_builtin_templates(supported_tmpl.data(), supported_tmpl.size()); std::cout << "Built-in chat templates:\n"; - for (auto tmpl : supported_tmpl) { + for (const auto *tmpl : supported_tmpl) { std::cout << " " << tmpl << "\n"; } @@ -592,6 +648,7 @@ int main_automated_tests(void) { } std::vector messages; + messages.reserve(conversation.size()); for (const auto & msg : conversation) { messages.push_back(simple_msg(msg.role, msg.content)); } @@ -622,58 +679,6 @@ int main_automated_tests(void) { } } - // TODO: llama_chat_format_single will be deprecated, remove these tests later - - // test llama_chat_format_single for system message - std::cout << "\n\n=== llama_chat_format_single (system message) ===\n\n"; - std::vector chat2; - auto sys_msg = simple_msg("system", "You are a helpful assistant"); - - auto fmt_sys = [&](std::string tmpl_str) { - auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str); - auto output = common_chat_format_single(tmpls.get(), chat2, sys_msg, false, /* use_jinja= */ false); - std::cout << "fmt_sys(" << tmpl_str << ") : " << output << "\n"; - std::cout << "-------------------------\n"; - return output; - }; - assert(fmt_sys("chatml") == "<|im_start|>system\nYou are a helpful assistant<|im_end|>\n"); - assert(fmt_sys("mistral-v1") == " [INST] You are a helpful assistant\n\n"); - assert(fmt_sys("mistral-v3") == "[INST] You are a helpful assistant\n\n"); - assert(fmt_sys("mistral-v3-tekken") == "[INST]You are a helpful assistant\n\n"); - assert(fmt_sys("mistral-v7") == "[SYSTEM_PROMPT] You are a helpful assistant[/SYSTEM_PROMPT]"); - assert(fmt_sys("llama2") == "[INST] You are a helpful assistant\n"); - assert(fmt_sys("llama2-sys") == "[INST] <>\nYou are a helpful assistant\n<>\n\n"); - assert(fmt_sys("mistral") == "[INST] You are a helpful assistant\n"); // for old pre-v1 templates - assert(fmt_sys("gemma") == ""); // for gemma, system message is merged with user message - assert(fmt_sys("llama3") == "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful assistant<|eot_id|>"); - assert(fmt_sys("gigachat") == "You are a helpful assistant<|message_sep|>"); - - - // test llama_chat_format_single for user message - std::cout << "\n\n=== llama_chat_format_single (user message) ===\n\n"; - chat2.push_back(simple_msg("system", "You are a helpful assistant")); - chat2.push_back(simple_msg("user", "Hello")); - chat2.push_back(simple_msg("assistant", "I am assistant")); - auto new_msg = simple_msg("user", "How are you"); - - auto fmt_single = [&](const std::string & tmpl_str) { - auto tmpls = common_chat_templates_init(/* model= */ nullptr, tmpl_str.c_str()); - auto output = common_chat_format_single(tmpls.get(), chat2, new_msg, true, /* use_jinja= */ false); - std::cout << "fmt_single(" << tmpl_str << ") : " << output << "\n"; - std::cout << "-------------------------\n"; - return output; - }; - assert(fmt_single("chatml") == "\n<|im_start|>user\nHow are you<|im_end|>\n<|im_start|>assistant\n"); - assert(fmt_single("mistral-v1") == " [INST] How are you [/INST]"); - assert(fmt_single("mistral-v3") == "[INST] How are you[/INST]"); - assert(fmt_single("mistral-v3-tekken") == "[INST]How are you[/INST]"); - assert(fmt_single("mistral-v7") == "[INST] How are you[/INST]"); - assert(fmt_single("llama2") == "[INST] How are you [/INST]"); - assert(fmt_single("mistral") == "[INST] How are you [/INST]"); // for old pre-v1 templates - assert(fmt_single("gemma") == "\nuser\nHow are you\nmodel\n"); - assert(fmt_single("llama3") == "<|start_header_id|>user<|end_header_id|>\n\nHow are you<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"); - // assert(fmt_single("gigachat") == "user<|role_sep|>How are you<|message_sep|>available functions<|role_sep|>[]<|message_sep|>assistant<|role_sep|>"); - std::cout << "\nOK: All tests passed successfully.\n"; return 0; diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index f3d19118b58..0bd95af5d43 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -5,18 +5,22 @@ // // cmake -B build && cmake --build build --parallel && ./build/bin/test-chat ../minja/build/tests/*.jinja 2>/dev/null // +#include "../src/llama-grammar.h" +#include "../src/unicode.h" +#include "chat-auto-parser.h" #include "chat.h" - +#include "common.h" +#include "ggml.h" #include "log.h" -#include "../src/unicode.h" -#include "../src/llama-grammar.h" - -#include - +#include +#include #include -#include #include +#include +#include +#include +#include #include using json = nlohmann::ordered_json; @@ -33,6 +37,7 @@ static std::ostream & operator<<(std::ostream & os, const common_chat_msg_diff & os << "}"; return os; } + // operator<< for vector: static std::ostream & operator<<(std::ostream & os, const std::vector & diffs) { os << "[\n"; @@ -42,6 +47,7 @@ static std::ostream & operator<<(std::ostream & os, const std::vector -bool equals(const common_chat_msg & expected, const common_chat_msg & actual) { +template <> bool equals(const common_chat_msg & expected, const common_chat_msg & actual) { return normalize(expected) == normalize(actual); } template static void assert_equals(const T & expected, const T & actual) { if (!equals(expected, actual)) { - std::cerr << "Expected:```\n" << expected << "\n```" << std::endl; - std::cerr << "Actual:```\n" << actual << "\n```" << std::endl; - std::cerr << std::flush; + std::ostringstream oss_expected; + oss_expected << expected; + std::ostringstream oss_actual; + oss_actual << actual; + LOG_ERR("Expected: %s\n", oss_expected.str().c_str()); + LOG_ERR("Actual: %s\n", oss_actual.str().c_str()); + common_log_flush(common_log_main()); throw std::runtime_error("Test failed"); } } static std::string read_file(const std::string & path) { - std::cerr << "# Reading: " << path << '\n' << std::flush; std::ifstream fs(path, std::ios_base::binary); if (!fs.is_open()) { fs = std::ifstream("../" + path, std::ios_base::binary); @@ -118,6 +125,207 @@ static std::unique_ptr build_grammar(const std::string & grammar_ llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root", false, nullptr, 0, nullptr, 0)); } +// Helper to format a code point as a readable string +static std::string format_codepoint(uint32_t cp) { + if (cp >= 32 && cp < 127) { + return std::string("'") + static_cast(cp) + "'"; + } else if (cp == '\n') { + return "'\\n'"; + } else if (cp == '\r') { + return "'\\r'"; + } else if (cp == '\t') { + return "'\\t'"; + } else { + return "U+" + std::to_string(cp); + } +} + +// Helper to format expected element from grammar stack +static std::string format_expected_element(const llama_grammar_rules & /* rules*/, const llama_grammar_element * elem) { + if (!elem) { + return ""; + } + + switch (elem->type) { + case LLAMA_GRETYPE_END: + return ""; + case LLAMA_GRETYPE_ALT: + return ""; + case LLAMA_GRETYPE_RULE_REF: + { + // Find rule name - just show rule ID for now + return "value) + ">"; + } + case LLAMA_GRETYPE_CHAR: + { + std::string result; + const llama_grammar_element * pos = elem; + bool first = true; + + do { + if (!first) { + result += " | "; + } + first = false; + + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + // Range like [a-z] + result += "[" + format_codepoint(pos->value) + "-" + format_codepoint(pos[1].value) + "]"; + pos += 2; + } else { + result += format_codepoint(pos->value); + pos += 1; + } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + + return result; + } + case LLAMA_GRETYPE_CHAR_NOT: + { + std::string result = "[^"; + const llama_grammar_element * pos = elem; + bool first = true; + + do { + if (!first) { + result += " "; + } + first = false; + + if (pos[1].type == LLAMA_GRETYPE_CHAR_RNG_UPPER) { + result += format_codepoint(pos->value) + "-" + format_codepoint(pos[1].value); + pos += 2; + } else { + result += format_codepoint(pos->value); + pos += 1; + } + } while (pos->type == LLAMA_GRETYPE_CHAR_ALT); + + return result + "]"; + } + case LLAMA_GRETYPE_CHAR_ANY: + return ""; + case LLAMA_GRETYPE_TOKEN: + return "value) + ">"; + case LLAMA_GRETYPE_TOKEN_NOT: + return "value) + ">"; + default: + return ""; + } +} + +// Get description of what the grammar expects at current position +static std::string get_expected_description(const llama_grammar_rules & rules, const llama_grammar_stacks & stacks) { + if (stacks.empty()) { + return ""; + } + + std::string result; + std::set seen; + + for (const auto & stack : stacks) { + if (stack.empty()) { + if (seen.insert("").second) { + if (!result.empty()) { + result += " OR "; + } + result += ""; + } + continue; + } + + const llama_grammar_element * elem = stack.back(); + std::string desc = format_expected_element(rules, elem); + if (seen.insert(desc).second) { + if (!result.empty()) { + result += " OR "; + } + result += desc; + } + } + + return result; +} + +// Result of a detailed grammar match attempt +struct grammar_match_result { + bool success = false; // Did the string fully match the grammar? + size_t matched_bytes = 0; // Bytes successfully matched before failure + size_t matched_codepoints = 0; // Codepoints successfully matched before failure + size_t total_bytes = 0; // Total bytes in input + size_t total_codepoints = 0; // Total codepoints in input + std::string matched_prefix; // The portion that was successfully matched + std::string failing_char; // The character that caused failure (if any) + std::string expected_description; // What the grammar expected at failure point + bool incomplete = false; // True if matched all input but grammar expects more +}; + +// Detailed version of match_string that returns failure information +static grammar_match_result match_string_detailed(const std::string & input, llama_grammar * grammar) { + grammar_match_result result; + result.total_bytes = input.size(); + + const auto cpts = unicode_cpts_from_utf8(input); + result.total_codepoints = cpts.size(); + + auto & stacks_cur = llama_grammar_get_stacks(grammar); + const auto & rules = llama_grammar_get_rules(grammar); + + size_t byte_pos = 0; + + for (size_t i = 0; i < cpts.size(); i++) { + const auto & cpt = cpts[i]; + + // Get expected before accepting (for error reporting) + std::string expected_before = get_expected_description(rules, stacks_cur); + + llama_grammar_accept(grammar, cpt); + + // Calculate byte position for this codepoint + size_t cpt_bytes = 0; + if (cpt < 0x80) { + cpt_bytes = 1; + } else if (cpt < 0x800) { + cpt_bytes = 2; + } else if (cpt < 0x10000) { + cpt_bytes = 3; + } else { + cpt_bytes = 4; + } + + if (stacks_cur.empty()) { + // Grammar failed to match at this point + result.matched_bytes = byte_pos; + result.matched_codepoints = i; + result.matched_prefix = input.substr(0, byte_pos); + result.failing_char = format_codepoint(cpt); + result.expected_description = expected_before; + result.incomplete = false; + return result; + } + + byte_pos += cpt_bytes; + } + + // All input matched - check if grammar is complete + result.matched_bytes = input.size(); + result.matched_codepoints = cpts.size(); + result.matched_prefix = input; + + if (std::any_of(stacks_cur.begin(), stacks_cur.end(), [](const auto & stack) { return stack.empty(); })) { + // An empty stack means that the grammar has been completed + result.success = true; + result.incomplete = false; + } else { + // Grammar expects more input + result.success = false; + result.incomplete = true; + result.expected_description = get_expected_description(rules, stacks_cur); + } + + return result; +} + // TODO: extract to common helper (copied from test-grammar-integration.cpp) static bool match_string(const std::string & input, llama_grammar * grammar) { const auto cpts = unicode_cpts_from_utf8(input); @@ -146,11 +354,13 @@ static std::string renormalize_json(const std::string & json_str) { auto json_obj = json::parse(json_str); return json_obj.dump(); } catch (const std::exception & e) { - std::cerr << "Failed to parse JSON: " << e.what() << '\n'; - return json_str; + return ""; // ignore parial JSON contents for comparison purposes } } -static void assert_msg_equals(const common_chat_msg & expected, const common_chat_msg & actual, bool ignore_whitespace_differences = false) { + +static void assert_msg_equals(const common_chat_msg & expected, + const common_chat_msg & actual, + bool ignore_whitespace_differences = false) { assert_equals(expected.role, actual.role); if (ignore_whitespace_differences) { assert_equals(string_strip(expected.content), string_strip(actual.content)); @@ -183,7 +393,7 @@ static void assert_msg_equals(const common_chat_msg & expected, const common_cha } } -common_chat_tool special_function_tool { +static common_chat_tool special_function_tool{ /* .name = */ "special_function", /* .description = */ "I'm special", /* .parameters = */ R"({ @@ -197,7 +407,7 @@ common_chat_tool special_function_tool { "required": ["arg1"] })", }; -common_chat_tool special_function_tool_with_optional_param { +static common_chat_tool special_function_tool_with_optional_param{ /* .name = */ "special_function_with_opt", /* .description = */ "I'm special but have optional stuff", /* .parameters = */ R"({ @@ -215,7 +425,15 @@ common_chat_tool special_function_tool_with_optional_param { "required": ["arg1"] })", }; -common_chat_tool python_tool { +static common_chat_tool empty_args_tool{ + /* .name = */ "empty_args", + /* .description = */ "A tool that takes no arguments", + /* .parameters = */ R"({ + "type": "object", + "properties": {} + })", +}; +static common_chat_tool python_tool{ /* .name = */ "python", /* .description = */ "an ipython interpreter", /* .parameters = */ R"({ @@ -229,7 +447,53 @@ common_chat_tool python_tool { "required": ["code"] })", }; -common_chat_tool todo_list_tool { + +static common_chat_tool html_tool{ + /* .name = */ "html", + /* .description = */ "an html validator", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "markup": { + "type": "string", + "description": "HTML markup to validate." + } + }, + "required": ["markup"] + })", +}; + +static common_chat_tool get_time_tool{ + /* .name = */ "get_time", + /* .description = */ "Get the current time in a city", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name" + } + }, + "required": ["city"] + })", +}; + +static common_chat_tool get_weather_tool{ + /* .name = */ "get_weather", + /* .description = */ "Get the current weather in a city", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "City name" + } + }, + "required": ["city"] + })", +}; + +static common_chat_tool todo_list{ /* .name = */ "todo_list", /* .description = */ "Create or update the todo list", /* .parameters = */ R"({ @@ -243,44 +507,310 @@ common_chat_tool todo_list_tool { "required": ["todos"] })", }; -common_chat_tool code_interpreter_tool { - /* .name = */ "code_interpreter", - /* .description = */ "an ipython interpreter", + +static common_chat_tool edit_tool{ + /* .name = */ "edit", + /* .description = */ "Edit file", /* .parameters = */ R"({ "type": "object", "properties": { - "code": { + "filename": { "type": "string", - "description": "Python code to execute." + "description": "Path of file to edit" + }, + "oldString": { + "type": "string", + "description": "String to replace" + }, + "newString": { + "type": "string", + "description": "New (replacement) value" } }, - "required": ["code"] + "required": ["filename", "oldString", "newString"] })", }; -std::vector tools { special_function_tool, special_function_tool_with_optional_param, python_tool }; -std::vector llama_3_1_tools { special_function_tool, code_interpreter_tool }; -struct delta_data { - std::string delta; - common_chat_params params; +static common_chat_tool magic_tool{ + /* .name = */ "magic", + /* .description = */ "Magic tool that takes a hash", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "ref": { + "type": "string" + } + }, + "required": ["name", "ref"] + })", +}; + +static common_chat_tool magic_int_tool{ + /* .name = */ "magic_int", + /* .description = */ "Magic tool that takes a hash", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "ref": { + "type": "integer" + }, + "name": { + "type": "string" + } + }, + "required": ["ref"] + })", +}; + +static common_chat_tool amount_tool{ + /* .name = */ "amount", + /* .description = */ "Amount converter", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "orig": { + "type": "number" + } + }, + "required": ["orig"] + })", +}; + +static common_chat_tool imaginary_number_tool{ + /* .name = */ "imaginary_number", + /* .description = */ "Imaginary number converter", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "number": { + "type": "object", + "properties": { + "real": { + "type": "number" + }, + "imaginary": { + "type": "number" + } + }, + "required": ["real", "imaginary"] + } + }, + "required": ["number"] + })", +}; + +static common_chat_tool string_param_tool{ + /* .name = */ "string_param", + /* .description = */ "Tool with string parameter for testing", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "text": { + "type": "string", + "description": "A text parameter" + } + }, + "required": [] + })", +}; + +static common_chat_tool quoted_unquoted_tool{ + /* .name = */ "quoted_unquoted", + /* .description = */ "Tool with two string parameters, one for quoted string, one for unquoted", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "quoted": { + "type": "string", + "description": "Quoted value" + }, + "unquoted": { + "type": "string", + "description": "Unquoted value" + } + }, + "required": ["quoted", "unquoted"] + })", +}; + + +static common_chat_tool tool_2req_4opt{ + /* .name = */ "tool_2req_4opt", + /* .description = */ "Tool with 2 required and 4 optional params", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "req1": { "type": "string", "description": "Required string" }, + "req2": { "type": "integer", "description": "Required int" }, + "opt1": { "type": "string", "description": "Optional string 1" }, + "opt2": { "type": "integer", "description": "Optional int 1" }, + "opt3": { "type": "string", "description": "Optional string 2" }, + "opt4": { "type": "integer", "description": "Optional int 2" } + }, + "required": ["req1", "req2"] + })", +}; + +static common_chat_tool tool_2req_5opt{ + /* .name = */ "tool_2req_5opt", + /* .description = */ "Tool with 2 required and 5 optional params", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "req1": { "type": "string", "description": "Required string" }, + "req2": { "type": "integer", "description": "Required int" }, + "opt1": { "type": "string", "description": "Optional string 1" }, + "opt2": { "type": "integer", "description": "Optional int 1" }, + "opt3": { "type": "string", "description": "Optional string 2" }, + "opt4": { "type": "integer", "description": "Optional int 2" }, + "opt5": { "type": "string", "description": "Optional string 3" } + }, + "required": ["req1", "req2"] + })", +}; + +static std::vector tools{ special_function_tool, special_function_tool_with_optional_param, + python_tool, html_tool, todo_list }; + +const common_chat_msg message_user{ + "user", + "Hey there!", + /* .content_parts = */ {}, + /* .tool_calls = */ {}, + /* .reasoning_content = */ "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", +}; + +const common_chat_msg message_user_parts{ + "user", + /* .content = */ "", + /* .content_parts = */ + { + { "text", "Hey" }, + { "text", "there" }, + }, + /* .tool_calls = */ + { }, + /* .reasoning_content = */ + "", + /* .tool_name = */ "", + /* .tool_call_id = */ "", }; -static common_chat_msg simple_assist_msg(const std::string & content, const std::string & reasoning_content = "", const std::string & tool_name = "", const std::string & arguments = "", const std::string & id = "") { +static common_chat_msg simple_assist_msg(const std::string & content, + const std::string & reasoning_content = "", + const std::string & tool_name = "", + const std::string & arguments = "", + const std::string & id = "") { common_chat_msg msg; - msg.role = "assistant"; - msg.content = content; + msg.role = "assistant"; + msg.content = content; msg.reasoning_content = reasoning_content; - if (!tool_name.empty()) { + if (!tool_name.empty() || !id.empty()) { msg.tool_calls.push_back({ tool_name, arguments, id }); } return msg; } -static delta_data init_delta(const struct common_chat_templates * tmpls, const std::vector & end_tokens, - const common_chat_msg & user_message, - const common_chat_msg & delta_message, +static common_chat_msg message_with_tool_calls(const std::string & tool_name, const std::string & arguments) { + return simple_assist_msg("", "", tool_name, arguments); +} + +static common_chat_msg message_with_tool_calls_and_reasoning(const std::string & tool_name, + const std::string & arguments, + const std::string & reasoning) { + return simple_assist_msg("", reasoning, tool_name, arguments); +} + +static common_chat_msg message_with_reasoning_content_and_multiple_tool_calls( + const std::string & reasoning, + const std::string & content, + const std::vector> & tool_calls) { + common_chat_msg msg; + msg.role = "assistant"; + msg.content = content; + msg.reasoning_content = reasoning; + for (const auto & [name, args] : tool_calls) { + msg.tool_calls.push_back({ name, args, "" }); + } + return msg; +} + +static common_chat_msg message_with_content_and_tool_call(const std::string & content, + const std::string & tool_name, + const std::string & arguments) { + return simple_assist_msg(content, "", tool_name, arguments); +} + +static common_chat_msg message_with_reasoning_and_tool_call(const std::string & reasoning, + const std::string & tool_name, + const std::string & arguments) { + return simple_assist_msg("", reasoning, tool_name, arguments); +} + +const common_chat_msg message_assist = simple_assist_msg("Hello, world!\nWhat's up?"); +const common_chat_msg message_assist_empty = simple_assist_msg(""); +const common_chat_msg message_assist_thoughts_unparsed_deepseek = + simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?"); +const common_chat_msg message_assist_thoughts_unparsed_md = + simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```"); +const common_chat_msg message_assist_thoughts_unparsed_md_partial = + simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}"); + +const common_chat_msg message_assist_thoughts_unparsed_r7b = + simple_assist_msg("<|START_THINKING|>I'm\nthinking<|END_THINKING|>Hello, world!\nWhat's up?"); +const common_chat_msg message_assist_thoughts_unparsed_magistral = + simple_assist_msg("[THINK]raisonnement[/THINK]Réponse"); +const common_chat_msg message_assist_thoughts = simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"); +const common_chat_msg message_assist_thoughts_unopened_unparsed = + simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?"); +const common_chat_msg message_assist_thoughts_no_content = simple_assist_msg("", "I'm\nthinking"); +const common_chat_msg message_assist_call = simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"); +const common_chat_msg message_assist_call_noopt = + simple_assist_msg("", "", "special_function_with_opt", "{\"arg1\": 1}"); +const common_chat_msg message_assist_call_withopt = + simple_assist_msg("", "", "special_function_with_opt", "{\"arg1\": 1, \"arg2\": 2}"); +const common_chat_msg message_assist_call_content = + simple_assist_msg("Hello, world!\nWhat's up?", "", "special_function", "{\"arg1\":1}"); +const common_chat_msg message_assist_call_empty_args = simple_assist_msg("", "", "special_function"); +const common_chat_msg message_assist_call_cutoff_args = simple_assist_msg("", "", "special_function", "{\"arg"); +const common_chat_msg message_assist_call_thoughts = + simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\":1}"); +const common_chat_msg message_assist_call_thoughts_unparsed = + simple_assist_msg("I'm\nthinking\n\n", "", "special_function", "{\"arg1\": 1}"); +const common_chat_msg message_assist_call_thoughts_content = + simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": 1}"); +const common_chat_msg message_assist_call_id = + simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "123456789"); +const common_chat_msg message_assist_call_idx = + simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "0"); +const common_chat_msg message_assist_thoughts_call_idx = + simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}", /* id = */ "0"); +const common_chat_msg message_assist_thoughts_partial_call = + simple_assist_msg("", "I'm\nthinking", "special_function", "", /* id = */ "0"); +const common_chat_msg message_assist_call_python = simple_assist_msg("", "", "python", "{\"code\":\"print('hey')\"}"); +const common_chat_msg message_assist_call_python_lines = + simple_assist_msg("", "", "python", "{\"code\":\"# This is a program:\\nprint('hey')\"}"); +const common_chat_msg message_assist_call_python_lines_unclosed = + simple_assist_msg("", "", "python", "{\"code\":\"# This is a program:\\nprint('hey')"); +const common_chat_msg message_assist_json_content = + simple_assist_msg("{\n \"response\": \"Hello, world!\\nWhat's up?\"\n}"); + +struct delta_data { + std::string delta; + common_chat_params params; +}; + +static delta_data init_delta(const struct common_chat_templates * tmpls, + const std::vector & end_tokens, + const common_chat_msg & user_message, + const common_chat_msg & delta_message, const std::vector & tools, - const common_chat_tool_choice & tool_choice) { + const common_chat_tool_choice & tool_choice) { common_chat_templates_inputs inputs; inputs.parallel_tool_calls = true; inputs.messages.push_back(user_message); @@ -331,20 +861,27 @@ static delta_data init_delta(const struct common_chat_templates * tmpls, const s gets the diff, removes any end tokens and parses the result w/ the grammar, checking that the parsed message is the same as the test_message */ -static void test_templates(const struct common_chat_templates * tmpls, const std::vector & end_tokens, - const common_chat_msg & test_message, - const std::vector & tools = {}, - const std::string & expected_delta = "", - bool expect_grammar_triggered = true, - bool test_grammar_if_triggered = true, - common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE, - bool ignore_whitespace_differences = false - ) { +static void test_templates(const struct common_chat_templates * tmpls, + const std::vector & end_tokens, + const common_chat_msg & test_message, + const std::vector & tools = {}, + const std::string & expected_delta = "", + bool expect_grammar_triggered = true, + bool test_grammar_if_triggered = true, + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE, + bool ignore_whitespace_differences = false) { common_chat_msg user_message; - user_message.role = "user"; + user_message.role = "user"; user_message.content = "Hello, world!"; - for (const auto & tool_choice : std::vector {COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED}) { + common_chat_templates_inputs inputs_tools; + inputs_tools.messages = { message_user }; + inputs_tools.tools = { special_function_tool }; + + common_chat_params params = common_chat_templates_apply(tmpls, inputs_tools); + + for (const auto & tool_choice : + std::vector{ COMMON_CHAT_TOOL_CHOICE_AUTO, COMMON_CHAT_TOOL_CHOICE_REQUIRED }) { auto data = init_delta(tmpls, end_tokens, user_message, test_message, tools, tool_choice); if (!expected_delta.empty()) { if (ignore_whitespace_differences) { @@ -356,10 +893,14 @@ static void test_templates(const struct common_chat_templates * tmpls, const std if (expect_grammar_triggered) { // TODO @ngxson : refactor common_chat_parse to avoid passing format/reasoning_format every time - common_chat_parser_params params; - params.format = data.params.format; - params.reasoning_format = reasoning_format; - const auto msg = common_chat_parse(data.delta, /* is_partial= */ false, params); + common_chat_parser_params parser_params; + parser_params.format = data.params.format; + parser_params.reasoning_format = reasoning_format; + if (!parser_params.parser.empty()) { + parser_params.parser = common_peg_arena(); + parser_params.parser.load(params.parser); + } + const auto msg = common_chat_parse(data.delta, /* is_partial= */ false, parser_params); assert_msg_equals(test_message, msg, ignore_whitespace_differences); } @@ -372,43 +913,43 @@ static void test_templates(const struct common_chat_templates * tmpls, const std throw std::runtime_error("Failed to build grammar"); } auto earliest_trigger_pos = std::string::npos; - auto constrained = data.delta; + auto constrained = data.delta; for (const auto & trigger : data.params.grammar_triggers) { - size_t pos = std::string::npos; + size_t pos = std::string::npos; std::smatch match; switch (trigger.type) { case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: - { - const auto & word = trigger.value; - pos = constrained.find(word); - break; - } + { + const auto & word = trigger.value; + pos = constrained.find(word); + break; + } case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: - { - const auto & pattern = trigger.value; - if (std::regex_search(constrained, match, std::regex(pattern))) { - pos = match.position(1); + { + const auto & pattern = std::regex(trigger.value); + if (std::regex_search(constrained, match, pattern)) { + pos = match.position(pattern.mark_count()); + } + break; } - break; - } case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL: - { - const auto & pattern = trigger.value; - if (std::regex_match(constrained, match, std::regex(pattern))) { - auto mpos = std::string::npos; - for (size_t i = 1; i < match.size(); ++i) { - if (match[i].length() > 0) { - mpos = match.position(i); - break; + { + const auto & pattern = trigger.value; + if (std::regex_match(constrained, match, std::regex(pattern))) { + auto mpos = std::string::npos; + for (size_t i = 1; i < match.size(); ++i) { + if (match[i].length() > 0) { + mpos = match.position(i); + break; + } } + if (mpos == std::string::npos) { + mpos = match.position(0); + } + pos = mpos; } - if (mpos == std::string::npos) { - mpos = match.position(0); - } - pos = mpos; + break; } - break; - } default: throw std::runtime_error("Unknown trigger type"); } @@ -421,7 +962,7 @@ static void test_templates(const struct common_chat_templates * tmpls, const std } auto grammar_triggered = false; if (earliest_trigger_pos != std::string::npos) { - constrained = constrained.substr(earliest_trigger_pos); + constrained = constrained.substr(earliest_trigger_pos); grammar_triggered = true; } if (data.params.grammar_lazy) { @@ -430,39 +971,45 @@ static void test_templates(const struct common_chat_templates * tmpls, const std if (grammar_triggered && test_grammar_if_triggered && !match_string(constrained, grammar.get())) { throw std::runtime_error("Failed to match delta against grammar:\n\n" + data.delta + - "\n\nConstrained: " + constrained + - "\n\nGrammar: " + data.params.grammar); + "\n\nConstrained: " + constrained + "\n\nGrammar: " + data.params.grammar); } } } } /** - * Test if streaming=true is consistant with streaming=false for given partial parser + * Test if streaming=true is consistent with streaming=false for given partial parser * Also test if there is any problem with partial message */ template static void test_parser_with_streaming(const common_chat_msg & expected, const std::string & raw_message, T parse_msg) { constexpr auto utf8_truncate_safe_len = [](const std::string_view s) -> size_t { auto len = s.size(); - if (len == 0) return 0; + if (len == 0) { + return 0; + } auto i = len; for (size_t back = 0; back < 4 && i > 0; ++back) { --i; unsigned char c = s[i]; if ((c & 0x80) == 0) { return len; - } else if ((c & 0xC0) == 0xC0) { + } + if ((c & 0xC0) == 0xC0) { size_t expected_len = 0; - if ((c & 0xE0) == 0xC0) expected_len = 2; - else if ((c & 0xF0) == 0xE0) expected_len = 3; - else if ((c & 0xF8) == 0xF0) expected_len = 4; - else return i; - if (len - i >= expected_len) { - return len; + if ((c & 0xE0) == 0xC0) { + expected_len = 2; + } else if ((c & 0xF0) == 0xE0) { + expected_len = 3; + } else if ((c & 0xF8) == 0xF0) { + expected_len = 4; } else { return i; } + if (len - i >= expected_len) { + return len; + } + return i; } } return len - std::min(len, size_t(3)); @@ -471,13 +1018,15 @@ static void test_parser_with_streaming(const common_chat_msg & expected, const s return s.substr(0, utf8_truncate_safe_len(s)); }; - auto merged = simple_assist_msg(""); + auto merged = simple_assist_msg(""); auto last_msg = parse_msg(""); for (size_t i = 1; i <= raw_message.size(); ++i) { auto curr_msg = parse_msg(std::string(utf8_truncate_safe_view(std::string_view(raw_message).substr(0, i)))); - if (curr_msg == simple_assist_msg("")) continue; - LOG_INF("Streaming msg: %s\n", common_chat_msgs_to_json_oaicompat({curr_msg}).dump().c_str()); - for (auto diff: common_chat_msg_diff::compute_diffs(last_msg, curr_msg)) { + if (curr_msg == simple_assist_msg("")) { + continue; + } + LOG_INF("Streaming msg: %s\n", common_chat_msgs_to_json_oaicompat({ curr_msg }).dump().c_str()); + for (auto diff : common_chat_msg_diff::compute_diffs(last_msg, curr_msg)) { LOG_INF("Streaming diff: %s\n", common_chat_msg_diff_to_json_oaicompat(diff).dump().c_str()); if (!diff.reasoning_content_delta.empty()) { merged.reasoning_content += diff.reasoning_content_delta; @@ -487,14 +1036,14 @@ static void test_parser_with_streaming(const common_chat_msg & expected, const s } if (diff.tool_call_index != std::string::npos) { if (!diff.tool_call_delta.name.empty()) { - merged.tool_calls.push_back({diff.tool_call_delta.name, "", ""}); + merged.tool_calls.push_back({ diff.tool_call_delta.name, "", "" }); } if (!diff.tool_call_delta.arguments.empty()) { GGML_ASSERT(!merged.tool_calls.empty()); merged.tool_calls.back().arguments += diff.tool_call_delta.arguments; } } - LOG_INF("Streaming merged: %s\n", common_chat_msgs_to_json_oaicompat({merged}).dump().c_str()); + LOG_INF("Streaming merged: %s\n", common_chat_msgs_to_json_oaicompat({ merged }).dump().c_str()); } assert_msg_equals(curr_msg, merged, true); last_msg = curr_msg; @@ -503,99 +1052,95 @@ static void test_parser_with_streaming(const common_chat_msg & expected, const s assert_msg_equals(expected, merged, true); } -const common_chat_msg message_user { - "user", - "Hey there!", - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "", - /* .tool_name = */ "", - /* .tool_call_id = */ "", -}; - -const common_chat_msg message_user_parts { - "user", - /* .content = */ "", - /* .content_parts = */ { - { "text", "Hey" }, - { "text", "there" }, - }, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "", - /* .tool_name = */ "", - /* .tool_call_id = */ "", -}; - -const common_chat_msg message_assist = simple_assist_msg("Hello, world!\nWhat's up?"); -const common_chat_msg message_assist_empty = simple_assist_msg(""); -const common_chat_msg message_assist_thoughts_unparsed_deepseek = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?"); -const common_chat_msg message_assist_thoughts_unparsed_md = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```"); -const common_chat_msg message_assist_thoughts_unparsed_md_partial = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}"); - -const common_chat_msg message_assist_thoughts_unparsed_r7b = simple_assist_msg("<|START_THINKING|>I'm\nthinking<|END_THINKING|>Hello, world!\nWhat's up?"); -const common_chat_msg message_assist_thoughts_unparsed_magistral = simple_assist_msg("[THINK]raisonnement[/THINK]Réponse"); -const common_chat_msg message_assist_thoughts = simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"); -const common_chat_msg message_assist_thoughts_unopened_unparsed = simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?"); -const common_chat_msg message_assist_thoughts_no_content = simple_assist_msg("", "I'm\nthinking"); -const common_chat_msg message_assist_call = simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"); -const common_chat_msg message_assist_call_noopt = simple_assist_msg("", "", "special_function_with_opt", "{\"arg1\": 1}"); -const common_chat_msg message_assist_call_withopt = simple_assist_msg("", "", "special_function_with_opt", "{\"arg1\": 1, \"arg2\": 2}"); -const common_chat_msg message_assist_call_content = simple_assist_msg("Hello, world!\nWhat's up?", "", "special_function", "{\"arg1\":1}"); -const common_chat_msg message_assist_call_empty_args = simple_assist_msg("", "", "special_function"); -const common_chat_msg message_assist_call_cutoff_args = simple_assist_msg("", "", "special_function", "{\"arg"); -const common_chat_msg message_assist_call_thoughts = simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\":1}"); -const common_chat_msg message_assist_call_thoughts_unparsed = simple_assist_msg("I'm\nthinking\n\n", "", "special_function", "{\"arg1\": 1}"); -const common_chat_msg message_assist_call_thoughts_content = simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": 1}"); -const common_chat_msg message_assist_call_id = simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "123456789"); -const common_chat_msg message_assist_call_idx = simple_assist_msg("", "", "special_function", "{\"arg1\":1}", /* .id = */ "0"); -const common_chat_msg message_assist_thoughts_call_idx = simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}", /* id = */ "0"); -const common_chat_msg message_assist_call_python = simple_assist_msg("", "", "python", "{\"code\":\"print('hey')\"}"); -const common_chat_msg message_assist_call_python_lines = simple_assist_msg("", "", "python", "{\"code\":\"# This is a program:\\nprint('hey')\"}"); -const common_chat_msg message_assist_call_python_lines_unclosed = simple_assist_msg("", "", "python", "{\"code\":\"# This is a program:\\nprint('hey')"); -const common_chat_msg message_assist_call_code_interpreter = simple_assist_msg("", "", "code_interpreter", "{\"code\":\"print('hey')\"}"); - // Use for PEG parser implementations struct peg_test_case { common_chat_templates_inputs params; - std::string input; - common_chat_msg expect; + std::string input; + common_chat_msg expect; + bool is_partial = false; }; struct make_peg_parser { common_chat_params params_; - common_peg_arena arena_; - - make_peg_parser(common_chat_templates * tmpls, const common_chat_templates_inputs & inputs) { - params_ = common_chat_templates_apply(tmpls, inputs); + common_peg_arena arena_; + bool detailed_debug_; + + make_peg_parser(common_chat_templates * tmpls, + const common_chat_templates_inputs & inputs, + bool detailed_debug = false) { + detailed_debug_ = detailed_debug; + params_ = common_chat_templates_apply(tmpls, inputs); arena_.load(params_.parser); } - common_chat_msg parse(const std::string & msg, bool is_partial) { + common_chat_msg parse(const std::string & msg, bool is_partial) const { common_chat_parser_params parser_params; parser_params.format = params_.format; + parser_params.debug = detailed_debug_; return common_chat_peg_parse(arena_, msg, is_partial, parser_params); } }; -static void test_peg_parser(common_chat_templates * tmpls, const std::function & init) { +static void test_peg_parser(common_chat_templates * tmpls, + const std::function & init, + bool detailed_debug) { + // UTF-8-safe truncation helper (same as in test_parser_with_streaming) + constexpr auto utf8_truncate_safe_len = [](const std::string_view s) -> size_t { + auto len = s.size(); + if (len == 0) { + return 0; + } + auto i = len; + for (size_t back = 0; back < 4 && i > 0; ++back) { + --i; + unsigned char c = s[i]; + if ((c & 0x80) == 0) { + return len; + } + if ((c & 0xC0) == 0xC0) { + size_t expected_len = 0; + if ((c & 0xE0) == 0xC0) { + expected_len = 2; + } else if ((c & 0xF0) == 0xE0) { + expected_len = 3; + } else if ((c & 0xF8) == 0xF0) { + expected_len = 4; + } else { + return i; + } + if (len - i >= expected_len) { + return len; + } + return i; + } + } + return len - std::min(len, size_t(3)); + }; + peg_test_case tc; init(tc); if (tc.params.messages.empty()) { - tc.params.messages = {message_user}; + tc.params.messages = { message_user }; } if (tc.expect.role.empty()) { tc.expect.role = "assistant"; } - auto parser = make_peg_parser(tmpls, tc.params); + auto parser = make_peg_parser(tmpls, tc.params, detailed_debug); + if (detailed_debug) { + LOG_DBG("Using parser: \n%s\n", parser.arena_.dump(parser.arena_.root()).c_str()); + } common_chat_msg msg_accum; common_chat_msg msg_prev; msg_accum.role = msg_prev.role = "assistant"; for (size_t i = 1; i <= tc.input.size(); ++i) { - auto is_partial = i < tc.input.size(); - common_chat_msg msg_current = parser.parse(tc.input.substr(0, i), is_partial); + auto is_partial = i < tc.input.size() || tc.is_partial; + // Use UTF-8 safe truncation to avoid corrupting multi-byte characters + size_t safe_len = utf8_truncate_safe_len(std::string_view(tc.input).substr(0, i)); + std::string prefix = tc.input.substr(0, safe_len); + common_chat_msg msg_current = parser.parse(prefix, is_partial); for (const auto & diff : common_chat_msg_diff::compute_diffs(msg_prev, msg_current)) { if (!diff.reasoning_content_delta.empty()) { @@ -605,24 +1150,245 @@ static void test_peg_parser(common_chat_templates * tmpls, const std::function 0) { + mpos = match.position(i); + break; + } + } + if (mpos == std::string::npos) { + mpos = match.position(0); + } + pos = mpos; + } + break; + } + default: + throw std::runtime_error("Unknown trigger type"); + } + if (pos != std::string::npos) { + if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) { + earliest_trigger_pos = pos; + } + } + } + + // Determine the constrained portion of input to test against grammar + std::string constrained = tc.input; + bool grammar_triggered = false; + if (earliest_trigger_pos != std::string::npos) { + constrained = tc.input.substr(earliest_trigger_pos); + grammar_triggered = true; + } else if (!parser.params_.grammar_lazy) { + // For non-lazy grammars, the entire input should match + grammar_triggered = true; + } + + // Test the constrained portion against the grammar + if (grammar_triggered && !tc.is_partial) { + auto result = match_string_detailed(constrained, grammar.get()); + if (!result.success) { + std::string error_msg; + if (result.incomplete) { + error_msg = + "Grammar matched all input but expects more:\n\n" + ">>> Input: " + tc.input + + "\n\n>>> Constrained: " + constrained + + "\n\n>>> Matched prefix (" + std::to_string(result.matched_bytes) + " bytes, " + + std::to_string(result.matched_codepoints) + " codepoints): " + + (result.matched_prefix.size() > 100 ? result.matched_prefix.substr(0, 100) + "..." : result.matched_prefix) + + "\n\n>>> Expected next: " + result.expected_description + + "\n\n>>> Grammar: " + parser.params_.grammar; + } else { + error_msg = + "Grammar match failed:\n\n" + ">>> Input: " + tc.input + + "\n\n>>> Constrained: " + constrained + + "\n\n>>> Matched prefix (" + std::to_string(result.matched_bytes) + " bytes, " + + std::to_string(result.matched_codepoints) + " codepoints): " + + (result.matched_prefix.size() > 100 ? result.matched_prefix.substr(0, 100) + "..." : result.matched_prefix) + + "\n\n>>> Failing character: " + result.failing_char + + "\n\n>>> Expected: " + result.expected_description + + "\n\n>>> Grammar: " + parser.params_.grammar; + } + throw std::runtime_error(error_msg); + } + } + } +} + +// Global template filter for --template flag +static std::string g_template_filter; + +// Fluent builder for PEG parser tests +class peg_test_builder; + +class peg_tester { + common_chat_templates_ptr tmpls_; + std::string template_path_; + bool detailed_debug_; + friend class peg_test_builder; + + public: + explicit peg_tester(const std::string & template_path, const bool detailed_debug = false) : + tmpls_(read_templates(template_path)), + template_path_(template_path), + detailed_debug_(detailed_debug) {} + + const std::string & template_path() const { return template_path_; } + + peg_test_builder test(const std::string & input); +}; + +class peg_test_builder { + peg_tester & tester_; + peg_test_case tc_; + + public: + peg_test_builder(peg_tester & tester, const std::string & input) : tester_(tester) { tc_.input = input; } + + // Parameter setters + peg_test_builder & reasoning_format(common_reasoning_format fmt) { + tc_.params.reasoning_format = fmt; + return *this; + } + + peg_test_builder & tools(std::vector tools) { + tc_.params.tools = std::move(tools); + return *this; + } + + peg_test_builder & enable_thinking(bool val) { + tc_.params.enable_thinking = val; + return *this; + } + + peg_test_builder & parallel_tool_calls(bool val) { + tc_.params.parallel_tool_calls = val; + return *this; + } + + peg_test_builder & json_schema(const std::string & schema) { + tc_.params.json_schema = schema; + return *this; + } + + peg_test_builder & is_partial(bool val) { + tc_.is_partial = val; + return *this; + } + + // Expect setters + peg_test_builder & expect(const common_chat_msg & msg) { + tc_.expect = msg; + return *this; + } + + peg_test_builder & expect_content(const std::string & content) { + tc_.expect.content = content; + return *this; + } + + peg_test_builder & expect_reasoning(const std::string & reasoning) { + tc_.expect.reasoning_content = reasoning; + return *this; + } + + peg_test_builder & expect_tool_calls(std::vector calls) { + tc_.expect.tool_calls = std::move(calls); + return *this; + } + + // Execute the test + void run() { + // Check template filter + if (!g_template_filter.empty()) { + // Case-insensitive substring match + std::string template_path_lower = tester_.template_path(); + std::string filter_lower = g_template_filter; + std::transform(template_path_lower.begin(), template_path_lower.end(), template_path_lower.begin(), + ::tolower); + std::transform(filter_lower.begin(), filter_lower.end(), filter_lower.begin(), ::tolower); + if (template_path_lower.find(filter_lower) == std::string::npos) { + // Skip this test + return; + } + } + LOG_INF("\n\x1b[38;5;126m[%s]\x1b[0m\n%s\n\n", tester_.template_path().c_str(), tc_.input.c_str()); + test_peg_parser(tester_.tmpls_.get(), [this](peg_test_case & t) { t = tc_; }, tester_.detailed_debug_); + } +}; + +peg_test_builder peg_tester::test(const std::string & input) { + return peg_test_builder(*this, input); } static void test_msgs_oaicompat_json_conversion() { - printf("[%s]\n", __func__); + LOG_DBG("%s\n", __func__); std::vector msgs{ message_user, message_user_parts, @@ -633,54 +1399,50 @@ static void test_msgs_oaicompat_json_conversion() { message_assist_call_id, message_assist_call_idx, message_assist_call_python, - message_assist_call_code_interpreter, }; for (const auto & msg : msgs) { - auto oai_json = common_chat_msgs_to_json_oaicompat({msg}); - auto msgs2 = common_chat_msgs_parse_oaicompat(oai_json); + auto oai_json = common_chat_msgs_to_json_oaicompat({ msg }); + auto msgs2 = common_chat_msgs_parse_oaicompat(oai_json); assert_equals((size_t) 1, msgs2.size()); - auto msg2 = msgs2[0]; + const auto & msg2 = msgs2[0]; assert_msg_equals(msg, msg2); } - assert_equals( - std::string( - "[\n" - " {\n" - " \"role\": \"user\",\n" - " \"content\": [\n" - " {\n" - " \"type\": \"text\",\n" - " \"text\": \"Hey\"\n" - " },\n" - " {\n" - " \"type\": \"text\",\n" - " \"text\": \"there\"\n" - " }\n" - " ]\n" - " }\n" - "]" - ), - common_chat_msgs_to_json_oaicompat({message_user_parts}).dump(2)); - - assert_equals( - std::string( - "[\n" - " {\n" - " \"role\": \"assistant\",\n" - " \"content\": \"\",\n" - " \"tool_calls\": [\n" - " {\n" - " \"type\": \"function\",\n" - " \"function\": {\n" - " \"name\": \"python\",\n" - " \"arguments\": \"{\\\"code\\\":\\\"print('hey')\\\"}\"\n" - " }\n" - " }\n" - " ]\n" - " }\n" - "]" - ), - common_chat_msgs_to_json_oaicompat({message_assist_call_python}).dump(2)); + assert_equals(std::string("[\n" + " {\n" + " \"role\": \"user\",\n" + " \"content\": [\n" + " {\n" + " \"type\": \"text\",\n" + " \"text\": \"Hey\"\n" + " },\n" + " {\n" + " \"type\": \"text\",\n" + " \"text\": \"there\"\n" + " }\n" + " ]\n" + " }\n" + "]"), + common_chat_msgs_to_json_oaicompat({ message_user_parts }).dump(2)); + + // Note: content is "" instead of null due to workaround for templates that render null as "None" + assert_equals(std::string("[\n" + " {\n" + " \"role\": \"assistant\",\n" + " \"content\": \"\",\n" + " \"tool_calls\": [\n" + " {\n" + " \"type\": \"function\",\n" + " \"function\": {\n" + " \"name\": \"python\",\n" + " \"arguments\": {\n" + " \"code\": \"print('hey')\"\n" + " }\n" + " }\n" + " }\n" + " ]\n" + " }\n" + "]"), + common_chat_msgs_to_json_oaicompat({ message_assist_call_python }).dump(2)); auto res = common_chat_msgs_parse_oaicompat(json::parse("[{\"role\": \"assistant\", \"tool_calls\": []}]")); assert_equals(1, res.size()); @@ -699,16 +1461,15 @@ static void test_msgs_oaicompat_json_conversion() { } static void test_tools_oaicompat_json_conversion() { - printf("[%s]\n", __func__); + LOG_DBG("%s\n", __func__); std::vector tools{ special_function_tool, python_tool, - code_interpreter_tool, }; for (const auto & tool : tools) { - auto oai_json = common_chat_tools_to_json_oaicompat({tool}); - auto tools2 = common_chat_tools_parse_oaicompat(oai_json); + auto oai_json = common_chat_tools_to_json_oaicompat({ tool }); + auto tools2 = common_chat_tools_parse_oaicompat(oai_json); assert_equals((size_t) 1, tools2.size()); auto tool2 = tools2[0]; assert_equals(tool.name, tool2.name); @@ -716,3040 +1477,1480 @@ static void test_tools_oaicompat_json_conversion() { assert_equals(json::parse(tool.parameters).dump(2), json::parse(tool2.parameters).dump(2)); } - assert_equals( - std::string( - "[\n" - " {\n" - " \"type\": \"function\",\n" - " \"function\": {\n" - " \"name\": \"special_function\",\n" - " \"description\": \"I'm special\",\n" - " \"parameters\": {\n" - " \"type\": \"object\",\n" - " \"properties\": {\n" - " \"arg1\": {\n" - " \"type\": \"integer\",\n" - " \"description\": \"The arg.\"\n" - " }\n" - " },\n" - " \"required\": [\n" - " \"arg1\"\n" - " ]\n" - " }\n" - " }\n" - " }\n" - "]" - ), - common_chat_tools_to_json_oaicompat({special_function_tool}).dump(2)); + assert_equals(std::string("[\n" + " {\n" + " \"type\": \"function\",\n" + " \"function\": {\n" + " \"name\": \"special_function\",\n" + " \"description\": \"I'm special\",\n" + " \"parameters\": {\n" + " \"type\": \"object\",\n" + " \"properties\": {\n" + " \"arg1\": {\n" + " \"type\": \"integer\",\n" + " \"description\": \"The arg.\"\n" + " }\n" + " },\n" + " \"required\": [\n" + " \"arg1\"\n" + " ]\n" + " }\n" + " }\n" + " }\n" + "]"), + common_chat_tools_to_json_oaicompat({ special_function_tool }).dump(2)); +} + +static void test_template_output_peg_parsers(bool detailed_debug) { + LOG_DBG("%s\n", __func__); + + // JSON schemas + const char * invoice_schema = R"({ + "type": "object", + "properties": { + "amount": {"type": "number"}, + "date": {"type": "string"} + } + })"; { - auto tools_no_params = common_chat_tools_parse_oaicompat(json::parse( - R"([{"type": "function", "function": {"name": "test_func", "description": "A test"}}])")); - assert_equals((size_t) 1, tools_no_params.size()); - assert_equals(std::string("test_func"), tools_no_params[0].name); - assert_equals(std::string("A test"), tools_no_params[0].description); - assert_equals(std::string("{}"), tools_no_params[0].parameters); + // Ministral-3-14B-Reasoning-2512 + auto tst = peg_tester("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja", detailed_debug); + + tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + + tst.test("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?") + .expect_content("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?") + .run(); + + tst.test("[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .expect(message_assist_thoughts) + .run(); + + tst.test(R"([TOOL_CALLS]special_function[ARGS]{"arg1":1})") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + + tst.test( + "[THINK]I'm\nthinking[/THINK]" + R"([TOOL_CALLS]special_function[ARGS]{"arg1":1})") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ special_function_tool }) + .expect(message_assist_call_thoughts) + .run(); + + tst.test(R"([TOOL_CALLS]special_function[ARGS]{"arg1": 1})" + R"([TOOL_CALLS]special_function_with_opt[ARGS]{"arg1": 1, "arg2": 2})") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .parallel_tool_calls(true) + .tools({ + special_function_tool, special_function_tool_with_optional_param + }) + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", {} }, + { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} }, + }) + .run(); + + tst.test( + "[THINK]I need to output the invoice details in JSON[/THINK]" + "```json\n" + R"({"amount": 123.45, "date": "2025-12-03"})" + "\n```") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .json_schema(invoice_schema) + .expect_reasoning("I need to output the invoice details in JSON") + .expect_content(R"({"amount": 123.45, "date": "2025-12-03"})") + .run(); } + { - auto tools_no_desc = common_chat_tools_parse_oaicompat(json::parse( - R"([{"type": "function", "function": {"name": "test_func", "parameters": {"type": "object"}}}])")); - assert_equals((size_t) 1, tools_no_desc.size()); - assert_equals(std::string("test_func"), tools_no_desc[0].name); - assert_equals(std::string(""), tools_no_desc[0].description); + // NVIDIA Nemotron-3 Nano + auto tst = peg_tester("models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja", detailed_debug); + + tst.test("Hello, world!\nWhat's up?").enable_thinking(false).expect(message_assist).run(); + + tst.test("I'm\nthinking\n\nHello, world!\nWhat's up?") + .enable_thinking(false) + .reasoning_format(COMMON_REASONING_FORMAT_NONE) + .expect_content("I'm\nthinking\n\nHello, world!\nWhat's up?") + .run(); + + tst.test("I'm\nthinking\n\nHello, world!\nWhat's up?") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .expect(message_assist_thoughts) + .run(); + + tst.test( + "\n" + "\n" + "\n1\n\n" + "\n" + "") + .enable_thinking(false) + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + + tst.test( + "I'm\nthinking\n\n" + "\n" + "\n" + "\n1\n\n" + "\n" + "") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ special_function_tool }) + .expect(message_assist_call_thoughts) + .run(); + + tst.test( + "\n" + "\n" + "\n1\n\n" + "\n" + "\n" + "\n" + "\n" + "\n1\n\n" + "\n2\n\n" + "\n" + "") + .enable_thinking(false) + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .parallel_tool_calls(true) + .tools({ + special_function_tool, special_function_tool_with_optional_param + }) + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", {} }, + { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} }, + }) + .run(); + + tst.test( + "\n" + "\n" + "\n" + "def hello():\n" + " print(\"Hello, world!\")\n" + "\n" + "hello()\n" + "\n" + "\n" + "") + .enable_thinking(false) + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ + python_tool + }) + .expect_tool_calls({ + { "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} }, + }) + .run(); + + tst.test( + "I need to output the invoice details in JSON\n" + "\n" + R"({"amount": 123.45, "date": "2025-12-03"})") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .json_schema(invoice_schema) + .expect_reasoning("I need to output the invoice details in JSON") + .expect_content(R"({"amount": 123.45, "date": "2025-12-03"})") + .run(); } + { - auto tools_minimal = common_chat_tools_parse_oaicompat(json::parse( - R"([{"type": "function", "function": {"name": "test_func"}}])")); - assert_equals((size_t) 1, tools_minimal.size()); - assert_equals(std::string("test_func"), tools_minimal[0].name); - assert_equals(std::string(""), tools_minimal[0].description); - assert_equals(std::string("{}"), tools_minimal[0].parameters); + // CohereForAI Command-R 7B (2024-tool_use) + auto tst = peg_tester("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja", detailed_debug); + + tst.test("<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>").expect(message_assist).run(); + + tst.test( + "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" + "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .expect(message_assist_thoughts) + .run(); + + tst.test( + "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" + "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>") + .expect(message_assist_thoughts_unparsed_r7b) + .run(); + + tst.test( + "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" + "<|START_ACTION|>[\n" + " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" + "]<|END_ACTION|>") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ special_function_tool }) + .expect(message_assist_thoughts_call_idx) + .run(); + + tst.test( + "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" + "<|START_ACTION|>[\n" + " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", ") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ special_function_tool }) + .is_partial(true) + .expect(message_assist_thoughts_partial_call) + .run(); + + tst.test( + "<|START_THINKING|><|END_THINKING|>" + "<|START_ACTION|>[\n" + " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" + "]<|END_ACTION|>") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ special_function_tool }) + .expect(message_assist_call_idx) + .run(); } -} - -// for compat; ref: https://github.com/ggml-org/llama.cpp/pull/18961 -struct test_parser_params { - common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; - common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; - bool reasoning_in_content = false; - bool thinking_forced_open = false; - bool parse_tool_calls = true; -}; -static common_chat_msg test_chat_parse(const std::string & input, bool is_partial, const test_parser_params & syntax) { - common_chat_parser_params params; - params.format = syntax.format; - params.reasoning_format = syntax.reasoning_format; - params.reasoning_in_content = syntax.reasoning_in_content; - params.thinking_forced_open = syntax.thinking_forced_open; - params.parse_tool_calls = syntax.parse_tool_calls; - return common_chat_parse(input, is_partial, params); -} + { + // Google Gemma 2 2B - does not support tool calling + auto tst = peg_tester("models/templates/google-gemma-2-2b-it.jinja"); -static void test_template_output_parsers() { - printf("[%s]\n", __func__); + tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).run(); - common_chat_templates_inputs inputs_no_tools; - inputs_no_tools.messages = {message_user}; + tst.test("Line 1\nLine 2\nLine 3").expect(simple_assist_msg("Line 1\nLine 2\nLine 3")).run(); + } - common_chat_templates_inputs inputs_tools; - inputs_tools.messages = {message_user}; - inputs_tools.tools = {special_function_tool}; + { + // Qwen-QwQ-32B (reasoning model) + auto tst = peg_tester("models/templates/Qwen-QwQ-32B.jinja"); - common_chat_templates_inputs inputs_tools_builtin; - inputs_tools_builtin.messages = {message_user}; - inputs_tools_builtin.tools = {python_tool}; + // QwQ always has thinking forced open - input starts after the \n in the prompt + tst.test("Let me think about this...\n\nThe answer is 42.") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .expect(simple_assist_msg("The answer is 42.", "Let me think about this...")) + .run(); + tst.test("Hello, world!").expect(simple_assist_msg("Hello, world!")).run(); + } { - // Not supported yet - auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja"); - assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + // NousResearch-Hermes-2-Pro and Hermes-3 (tool calling models) + auto tst = peg_tester("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", detailed_debug); + + tst.test( + "\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + + tst.test( + "Hello, world!\nWhat's up?\n" + "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" + "") + .tools({ special_function_tool }) + .expect(message_assist_call_content) + .run(); + + // Note: Hermes template doesn't support thinking/reasoning natively + // Note: We only support one tool calling format per template, no alternate formats } { - auto tmpls = read_templates("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja"); - std::vector end_tokens{ "<|END_OF_TURN_TOKEN|>" }; + // Test simple content-only template + auto tst = peg_tester("models/templates/google-gemma-2-2b-it.jinja", detailed_debug); - for (const auto & inputs : { inputs_no_tools, inputs_tools }) { - auto params = common_chat_templates_apply(tmpls.get(), inputs); - assert_equals(COMMON_CHAT_FORMAT_COMMAND_R7B, params.format); - assert_equals(false, params.thinking_forced_open); - } - - assert_msg_equals(message_assist, - test_chat_parse( - "Hello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_COMMAND_R7B})); - assert_msg_equals(message_assist, - test_chat_parse( - "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_COMMAND_R7B})); - assert_msg_equals(message_assist_thoughts, - test_chat_parse( - "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" - "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - assert_msg_equals(message_assist_thoughts_unparsed_deepseek, - test_chat_parse( - "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" - "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ true, - /* .thinking_forced_open = */ false, - })); - assert_msg_equals(message_assist_thoughts_unparsed_r7b, - test_chat_parse( - "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" - "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_COMMAND_R7B})); - assert_msg_equals(message_assist_thoughts, - test_chat_parse( - "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" - "<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - assert_msg_equals(message_assist_thoughts_call_idx, - test_chat_parse( - "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" - "<|START_ACTION|>[\n" - " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" - "]<|END_ACTION|>", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - assert_msg_equals(message_assist_thoughts_no_content, - test_chat_parse( - "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" - "<|START_ACTION|>[\n" - " {\"tool_call_id\": \"0\", \"tool_name\": \"special", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_COMMAND_R7B, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - - test_templates(tmpls.get(), end_tokens, message_assist_call_idx, tools, - "<|START_THINKING|><|END_THINKING|>" - "<|START_ACTION|>[\n" - " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" - "]<|END_ACTION|>", - /* expect_grammar_triggered= */ true, - /* test_grammar_if_triggered= */ true, - COMMON_REASONING_FORMAT_DEEPSEEK); - test_templates(tmpls.get(), end_tokens, message_assist, tools, - "<|START_RESPONSE|>Hello, world!\n" - "What's up?<|END_RESPONSE|>", - /* expect_grammar_triggered= */ false); - } - // TODO @ngxson : generic tool calls is too costly to maintain, consider removing it in the future - { - auto tmpls = read_templates("models/templates/google-gemma-2-2b-it.jinja"); - std::vector end_tokens{ "" }; - - assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_GENERIC, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - assert_equals(COMMON_CHAT_FORMAT_GENERIC, - common_chat_templates_apply( - read_templates("models/templates/microsoft-Phi-3.5-mini-instruct.jinja").get(), - inputs_tools) - .format); - - // Generic tool calls doesn't generate / parse content-only messages symmetrically. - - assert_equals( - simple_assist_msg("{ \"tool_call\" : { \"name\" : \"t"), - test_chat_parse( - "{ \"tool_call\" : { \"name\" : \"t", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_GENERIC, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - /* .parse_tool_calls = */ false, - })); - assert_equals( - message_assist_empty, - test_chat_parse( - "{ \"tool_call\" : { \"name\" : \"t", - /* is_partial= */ true, - {COMMON_CHAT_FORMAT_GENERIC})); - - assert_equals( - simple_assist_msg("", "", "puppeteer_screenshot", "{\"name\":\"servethehome_homepage\","), - test_chat_parse( - R"({"tool_call": {"name": "puppeteer_screenshot", "arguments": {"name": "servethehome_homepage",)", - /* is_partial= */ true, - {COMMON_CHAT_FORMAT_GENERIC})); - - assert_equals( - message_assist_call_empty_args, - test_chat_parse( - "{ \"tool_call\" : { \"name\" : \"special_function\"", - /* is_partial= */ true, - {COMMON_CHAT_FORMAT_GENERIC})); - assert_equals( - message_assist_call_cutoff_args, - test_chat_parse( - "{ \"tool_call\" : { \"name\" : \"special_function\", \"arguments\" : { \"arg", - /* is_partial= */ true, - {COMMON_CHAT_FORMAT_GENERIC})); - - assert_msg_equals(message_assist, - test_chat_parse( - "{\n" - " \"response\": \"Hello, world!\\nWhat's up?\"\n" - "}", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_GENERIC})); -#if 0 - test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools, - "{\n" - " \"tool_calls\": [\n" - " {\n" - " \"name\": \"special_function\",\n" - " \"arguments\": {\n" - " \"arg1\": 1\n" - " },\n" - " \"id\": \"123456789\"\n" - " }\n" - " ],\n" - " \"content\": \"\"\n" - "}"); -#endif + tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); } { - auto tmpls = read_templates("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja"); - std::vector end_tokens{ "" }; + // IBM Granite (reasoning and tool calling model) + auto tst = peg_tester("models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja", detailed_debug); + + tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); - assert_equals(COMMON_CHAT_FORMAT_MISTRAL_NEMO, common_chat_templates_apply(tmpls.get(), inputs_tools).format); + tst.test("I'm\nthinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .expect(message_assist_thoughts) + .run(); - test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_templates( - tmpls.get(), end_tokens, message_assist_call_id, tools, - "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]"); + // TODO: pending support for WRAPPED_WITH_REASONING + // tst.test("I'm\nthinkingHello, world!\nWhat's up?") + // .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + // .expect(message_assist_thoughts) + // .run(); } + { - assert_msg_equals( - simple_assist_msg("Réponse", "raisonnement"), - test_chat_parse( - message_assist_thoughts_unparsed_magistral.content, - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_MAGISTRAL, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - })); + // ByteDance-Seed-OSS (reasoning and tool calling model) + auto tst = peg_tester("models/templates/ByteDance-Seed-OSS.jinja", detailed_debug); + + tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + + tst.test("I'm thinking about the answer\nHello, world!") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .expect(simple_assist_msg("Hello, world!", "I'm thinking about the answer")) + .run(); + + tst.test( + "\n" + "\n" + "1\n" + "\n" + "") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + + tst.test( + "\n" + "\n" + "1\n" + "\n" + "\n" + "\n" + "\n" + "1\n" + "2\n" + "\n" + "") + .parallel_tool_calls(true) + .tools({ + special_function_tool, special_function_tool_with_optional_param + }) + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", {} }, + { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} }, + }) + .run(); + + tst.test( + "\n" + "\n" + "[{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]\n" + "\n" + "") + .tools({ + todo_list + }) + .expect_tool_calls({ + { "todo_list", "{\"todos\": [{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]}", {} }, + }) + .run(); + + // tool call with inside quotes + tst.test( + "\n" + "\n" + "\n" + "foo.cpp\n" + "\n" + "" + "def foo(arg = \"14\"):\n" + " return arg + \"bar\"\n" + "\n" + "\n" + "" + "def foo(arg = \"15\"):\n" + " pass\n" + "\n" + "\n" + "\n" + "") + .tools({ + edit_tool + }) + .expect_tool_calls({ + { "edit", "{\"filename\": \"foo.cpp\", " + "\"oldString\": \"def foo(arg = \\\"14\\\"):\\n return arg + \\\"bar\\\"\\n\", " + "\"newString\": \"def foo(arg = \\\"15\\\"):\\n pass\\n\"}", {} + } + }) + .run(); } - { - auto tmpls = read_templates("models/templates/Qwen-QwQ-32B.jinja"); - std::vector end_tokens{ "<|im_end|>" }; - assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - } { - auto tmpls = read_templates("models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja"); - std::vector end_tokens{ "<|im_end|>" }; - - assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_HERMES_2_PRO, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - assert_equals( - COMMON_CHAT_FORMAT_HERMES_2_PRO, - common_chat_templates_apply( - read_templates("models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja").get(), - inputs_tools) - .format); - assert_equals( - COMMON_CHAT_FORMAT_HERMES_2_PRO, - common_chat_templates_apply( - read_templates("models/templates/Qwen-Qwen2.5-7B-Instruct.jinja").get(), - inputs_tools) - .format); - - // Test parsing - assert_msg_equals( - simple_assist_msg("", "", "python", ""), - test_chat_parse( - "```json\n" - " { \"name\" : \"python\"", - /* is_partial= */ true, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals( - simple_assist_msg("Let's call something\n"), - test_chat_parse( - "Let's call something\n" - "{\"name\"", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - assert_msg_equals( - simple_assist_msg("Let's call something\n"), - test_chat_parse( - "Let's call something\n" - "{\"name", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - assert_msg_equals(message_assist_call_thoughts, - test_chat_parse( - // QwQ-32B's template adds a trailing if add_generation_prompt - "I'm\nthinking\n" - "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - })); - assert_msg_equals( - message_assist_call, - test_chat_parse( - "\n" - "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals(message_assist_call_content, - test_chat_parse( - "Hello, world!\nWhat's up?\n" - "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals( - message_assist_call, - test_chat_parse( - "{\"arg1\": 1}", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals( - message_assist_call, - test_chat_parse( - "\n" - "{\"arg1\": 1}\n" - "", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals( - message_assist_call, - test_chat_parse( - "\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals( - message_assist_call, - test_chat_parse( - "\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals( - message_assist_call, - test_chat_parse( - "\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals( - message_assist_call, - test_chat_parse( - "```xml\n" - "\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "\n" - "```", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals( - message_assist_call, - test_chat_parse( - "```xml\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "```", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals( - message_assist_call, - test_chat_parse( - "```\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "```", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals( - message_assist_call, - test_chat_parse( - "```\n" - "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "```", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals( - message_assist_call, - test_chat_parse( - "```json\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "```", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals( - message_assist_call, - test_chat_parse( - "```json\n" - "\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}} \n" - " \n" - "``` ", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals( - message_assist_call, - test_chat_parse( - "\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals( - message_assist_call, - test_chat_parse( - "\n" - " {\n" - " \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}\n" - " }\n" - "", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals( - message_assist_call, - test_chat_parse( - "\n" - " {\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals( - message_assist_call, - test_chat_parse( - "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals( - message_assist_call, - test_chat_parse( - "{\n \"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - - // Test multiple tool calls - common_chat_msg message_assist_multiple_calls; - message_assist_multiple_calls.role = "assistant"; - message_assist_multiple_calls.content = ""; - message_assist_multiple_calls.tool_calls.push_back({"special_function", "{\"arg1\": 1}", ""}); - message_assist_multiple_calls.tool_calls.push_back({"python", "{\"code\":\"print('hello')\"}", ""}); - - assert_msg_equals( - message_assist_multiple_calls, - test_chat_parse( - "\n" - "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "\n" - "\n" - "{\"name\": \"python\", \"arguments\": {\"code\":\"print('hello')\"}}\n" - "", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - - assert_msg_equals( - message_assist_multiple_calls, - test_chat_parse( - "{\"arg1\": 1}\n" - "{\"code\":\"print('hello')\"}", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - - assert_msg_equals( - simple_assist_msg( - "This is not a tool call:", - "", - "special_function", - "{\"arg1\": 1}"), - test_chat_parse( - "This is not a tool call:\n" - "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals(message_assist, - test_chat_parse( - "Hello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - assert_msg_equals(message_assist_thoughts_unparsed_deepseek, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_HERMES_2_PRO})); - // assert_msg_equals(message_assist_thoughts_unparsed_deepseek, - // test_chat_parse( - // "I'm\nthinkingHello, world!\nWhat's up?", - // COMMON_CHAT_FORMAT_HERMES_2_PRO)); - assert_msg_equals(message_assist_thoughts, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - assert_msg_equals(message_assist_thoughts, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - assert_msg_equals(message_assist_thoughts_unparsed_md, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ true, - /* .thinking_forced_open = */ false, - /* .parse_tool_calls = */ false, - })); - assert_msg_equals(message_assist_thoughts_unparsed_md_partial, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?\n```json\n{}```", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ true, - /* .thinking_forced_open = */ false, - })); - assert_msg_equals(message_assist_thoughts_unopened_unparsed, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - assert_msg_equals(message_assist_thoughts, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - })); - - test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_templates(tmpls.get(), end_tokens, message_assist_call, tools, - "\n" - "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - ""); - - // Test multiple tool calls with template - common_chat_msg message_assist_multiple_calls_template; - message_assist_multiple_calls_template.role = "assistant"; - message_assist_multiple_calls_template.content = ""; - message_assist_multiple_calls_template.tool_calls.push_back({"special_function", "{\"arg1\": 1}", ""}); - message_assist_multiple_calls_template.tool_calls.push_back({"python", "{\"code\":\"print('test')\"}", ""}); - - test_templates(tmpls.get(), end_tokens, message_assist_multiple_calls_template, tools, - "\n" - "{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n" - "\n" - "\n" - "{\"name\": \"python\", \"arguments\": {\"code\":\"print('test')\"}}\n" - ""); - - test_templates(tmpls.get(), end_tokens, message_assist_call_python_lines, tools, - "\n" - "{\"name\": \"python\", \"arguments\": {\"code\":\"# This is a program:\\nprint('hey')\"}}\n" - ""); - assert_msg_equals( - simple_assist_msg("", /* reasoning_content= */ "nah uhg"), - test_chat_parse( - "nah uhg", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_HERMES_2_PRO, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); + // Qwen3-Coder (tool calling with XML-style format) + auto tst = peg_tester("models/templates/Qwen3-Coder.jinja", detailed_debug); + + tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + + tst.test( + "\n" + "\n" + "\n" + "1\n" + "\n" + "\n" + "") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + + tst.test( + "\n" + "\n" + "\n" + "1\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "1\n" + "\n" + "\n" + "2\n" + "\n" + "\n" + "") + .parallel_tool_calls(true) + .tools({ + special_function_tool, special_function_tool_with_optional_param + }) + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", {} }, + { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} }, + }) + .run(); + + // Test with code content (multiline) + tst.test( + "\n" + "\n" + "\n" + "def hello():\n" + " print(\"Hello, world!\")\n" + "\n" + "hello()\n" + "\n" + "\n" + "") + .tools({ + python_tool + }) + .expect_tool_calls({ + { "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} }, + }) + .run(); + + // Test with code content (asian unicode chars) + tst.test( + "\n" + "\n" + "\n" + "格\n" + "\n" + "\n" + "") + .tools({ + python_tool + }) + .expect_tool_calls({ + { "python", "{\"code\": \"格\"}", {} }, + }) + .run(); + + // Test with HTML tag content + tst.test( + "\n" + "\n" + "\n" + "\n" + " \n" + " Hello!\n" + " \n" + "\n" + "\n" + "\n" + "") + .tools({ + html_tool + }) + .expect_tool_calls({ + { "html", "{\"markup\": \"\\n \\n Hello!\\n \\n\"}", {} }, + }) + .run(); + + // Test with TODO list (array of objects) + tst.test( + "\n" + "\n" + "\n" + "[{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]\n" + "\n" + "\n" + "") + .tools({ + todo_list + }) + .expect_tool_calls({ + { "todo_list", "{\"todos\": [{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]}", {} }, + }) + .run(); + + // Test flexible optional argument ordering (2 required + 4 optional, reversed optional order) + tst.test( + "\n" + "\n" + "\nhello\n\n" + "\n42\n\n" + "\n100\n\n" + "\n200\n\n" + "\n" + "") + .tools({ tool_2req_4opt }) + .expect_tool_calls({ + { "tool_2req_4opt", R"({"req1": "hello", "req2": 42, "opt4": 100, "opt2": 200})", {} }, + }) + .run(); + + // Test flexible optional argument ordering (2 required + 5 optional, reversed optional order) + tst.test( + "\n" + "\n" + "\nworld\n\n" + "\n7\n\n" + "\nlast\n\n" + "\nmiddle\n\n" + "\nfirst\n\n" + "\n" + "") + .tools({ tool_2req_5opt }) + .expect_tool_calls({ + { "tool_2req_5opt", R"({"req1": "world", "req2": 7, "opt5": "last", "opt3": "middle", "opt1": "first"})", {} }, + }) + .run(); + + // Test flexible optional argument ordering (2 required + 5 optional, all 5 in shuffled order) + tst.test( + "\n" + "\n" + "\ntest\n\n" + "\n99\n\n" + "\nc\n\n" + "\na\n\n" + "\ne\n\n" + "\n4\n\n" + "\n2\n\n" + "\n" + "") + .tools({ tool_2req_5opt }) + .expect_tool_calls({ + { "tool_2req_5opt", R"({"req1": "test", "req2": 99, "opt3": "c", "opt1": "a", "opt5": "e", "opt4": 4, "opt2": 2})", {} }, + }) + .run(); } { - auto tmpls = read_templates("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja"); - std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; - - assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, - common_chat_templates_apply(tmpls.get(), inputs_tools_builtin).format); - assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS, - common_chat_templates_apply( - read_templates("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja").get(), - inputs_tools_builtin) - .format); - - assert_equals( - message_assist_call, - test_chat_parse( - "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_LLAMA_3_X})); - - // test_templates(tmpls.get(), end_tokens, message_assist, tools, R"(?)", /* expect_grammar_triggered= */ false); - test_templates(tmpls.get(), end_tokens, message_assist_call_code_interpreter, llama_3_1_tools, - "<|python_tag|>code_interpreter.call(code=\"print('hey')\")"); - test_templates(tmpls.get(), end_tokens, message_assist_call_python, tools, - "<|python_tag|>python.call(code=\"print('hey')\")"); - test_templates(tmpls.get(), end_tokens, message_assist_call, tools, - "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); + auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-V3.1.jinja", detailed_debug); + tst.test( + "<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": " + "\"XYZCITY\"}<|tool▁call▁end|><|tool▁calls▁end|>") + .tools({ get_time_tool }) + .expect(message_with_tool_calls("get_time", "{\"city\":\"XYZCITY\"}")) + .run(); } - { - auto tmpls = read_templates("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja"); - std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; - assert_equals(COMMON_CHAT_FORMAT_LLAMA_3_X, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - - test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_templates(tmpls.get(), end_tokens, message_assist_call, tools, - "{\"name\": \"special_function\", \"parameters\": {\"arg1\": 1}}"); - } { - auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.1.jinja"); - std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; - - assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, - common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, - common_chat_templates_apply(tmpls.get(), inputs_tools).format); - assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, - common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - - for (auto is_partial : { false, true }) { - assert_equals( - message_assist_call, - test_chat_parse( - "{\"arg1\": 1}", - is_partial, - {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1})); - } - - assert_equals( - message_assist_call, - test_chat_parse( - "{\"arg1\": 1}<", - /* is_partial= */ true, - {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1})); - - test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_templates(tmpls.get(), end_tokens, message_assist_call, tools, - "{\"arg1\": 1}"); + auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-V3.1.jinja", detailed_debug); + tst.test( + "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": " + "\"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ get_time_tool }) + .expect(message_with_tool_calls_and_reasoning("get_time", "{\"city\":\"Tokyo\"}", "REASONING")) + .run(); } + { - auto tmpls = read_templates("models/templates/meetkai-functionary-medium-v3.2.jinja"); - std::vector end_tokens{ "<|eom_id|>", "<|eot_id|>" }; - - assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - - assert_msg_equals( - simple_assist_msg( - "Hello, world!\nnono\nWhat's up?", - "", - "special_function", - "{\"arg1\": 1}"), - test_chat_parse( - "all\n" - "Hello, world!\n" - "nono\n" - "What's up?>>>special_function\n" - "{\"arg1\": 1}\n", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2})); - assert_msg_equals(message_assist_call_python_lines, - test_chat_parse( - "python\n" - "# This is a program:\n" - "print('hey')", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2})); - assert_msg_equals(message_assist_call_python_lines_unclosed, - test_chat_parse( - "python\n" - "# This is a program:\n" - "print('hey')", - /* is_partial= */ true, - {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2})); - assert_msg_equals(message_assist_call, - test_chat_parse( - "special_function\n" - "{\"arg1\": 1} \n ", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2})); - assert_msg_equals(message_assist, - test_chat_parse( - "all\n" - "Hello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2})); - - test_templates(tmpls.get(), end_tokens, message_assist, {}, - "all\n" - "Hello, world!\n" - "What's up?", - /* expect_grammar_triggered= */ false); - test_templates(tmpls.get(), end_tokens, message_assist_call, tools, - "special_function\n" - "{\"arg1\": 1}"); + auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-V3.1.jinja", detailed_debug); + tst.test( + "REASONINGCONTENT<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": " + "\"Paris\"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"city\": " + "\"Paris\"}<|tool▁call▁end|><|tool▁calls▁end|>") + .tools({ + get_time_tool, get_weather_tool + }) + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .parallel_tool_calls(true) + .expect(message_with_reasoning_content_and_multiple_tool_calls( + "REASONING", "CONTENT", + { { "get_time", "{\"city\":\"Paris\"}" }, { "get_weather", "{\"city\":\"Paris\"}" } })) + .run(); } - { - auto tmpls = read_templates("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja"); - std::vector end_tokens{ "<|eot_id|>" }; - assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_FIREFUNCTION_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - - test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_templates(tmpls.get(), end_tokens, message_assist_call, tools, - " functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]"); - } { - // Original DeepSeek R1 template. Leaves <|tool▁calls▁begin|> and others unclosed. Our logic fixes the prompt. - auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja"); - std::vector end_tokens{ "<|end▁of▁sentence|>" }; - - for (const auto & inputs : { inputs_no_tools, inputs_tools }) { - auto params = common_chat_templates_apply(tmpls.get(), inputs); - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, params.format); - assert_equals(true, params.thinking_forced_open); - } - - test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - assert_msg_equals( - simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"), - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - { - COMMON_CHAT_FORMAT_DEEPSEEK_R1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - })); - assert_msg_equals( - simple_assist_msg("", "I need to remember the correct syntax. It starts with <|tool▁calls▁begin|> and ends with"), - test_chat_parse( - "I need to remember the correct syntax. It starts with <|tool▁calls▁begin|> and ends with", - /* is_partial= */ true, - { - COMMON_CHAT_FORMAT_DEEPSEEK_R1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - })); - assert_msg_equals(message_assist_thoughts, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - assert_msg_equals(message_assist_thoughts_unopened_unparsed, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - assert_msg_equals(message_assist_thoughts, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - })); - assert_msg_equals(message_assist_thoughts, - // Latest template update (ast of 20250209) adds a trailing \n if add_generation_prompt is true. - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - })); - // test_templates(tmpls.get(), end_tokens, message_assist_call, tools, - // "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" - // "```json\n" - // "{\"arg1\": 1}\n" - // // Look what's not here: <|tool▁calls▁end|> (also missing the <|end▁of▁sentence|>, but that is removed lazily by the test's delta logic) - // "```<|tool▁call▁end|>", - // /* expect_grammar_triggered= */ true, - // /* test_grammar_if_triggered= */ false); + auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-V3.1.jinja", detailed_debug); + tst.test("REASONING\nCONTENT") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .expect(simple_assist_msg("CONTENT", "REASONING\n")) + .run(); } + { - // Replacement DeepSeek R1 template. Makes the Distill Qwen 7B/32B models happy to call tools and all. - auto tmpls = read_templates("models/templates/llama-cpp-deepseek-r1.jinja"); - std::vector end_tokens{ "<|end▁of▁sentence|>" }; - - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_R1, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - - test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - assert_msg_equals(message_assist_thoughts_unparsed_deepseek, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_DEEPSEEK_R1})); - assert_msg_equals(message_assist_thoughts, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - assert_msg_equals(message_assist_thoughts, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - })); - - assert_msg_equals(message_assist_call_thoughts_unparsed, - test_chat_parse( - "I'm\nthinking\n\n" - "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" - "```json\n" - "{\"arg1\": 1}\n" - "```<|tool▁call▁end|><|tool▁calls▁end|>", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_DEEPSEEK_R1})); - assert_msg_equals(message_assist_call, - test_chat_parse( - "<|tool▁calls|>function<|tool▁sep|>special_function\n" - "```json\n" - "{\"arg1\": 1}\n" - "```<|tool▁call▁end|><|tool▁calls▁end|>", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_DEEPSEEK_R1})); - - assert_msg_equals(message_assist_call_thoughts, - test_chat_parse( - "I'm\nthinking\n\n" - "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" - "```json\n" - "{\"arg1\": 1}\n" - "```<|tool▁call▁end|><|tool▁calls▁end|>", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_DEEPSEEK_R1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - test_templates(tmpls.get(), end_tokens, message_assist_call, tools, - "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" - "```json\n" - "{\"arg1\": 1}\n" - "```<|tool▁call▁end|><|tool▁calls▁end|>"); + auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-V3.1.jinja", detailed_debug); + tst.test("CONTENT").expect(simple_assist_msg("CONTENT", "")).run(); } + + // GLM-4.6 tests - format: function_name\n...\n...\n { - auto tmpls = read_templates("models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja"); - std::vector end_tokens{ "<|end_of_text|>" }; - - assert_equals(COMMON_CHAT_FORMAT_GRANITE, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - - assert_equals(COMMON_CHAT_FORMAT_GRANITE, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - - // Test parsing regular content - assert_msg_equals(message_assist, - test_chat_parse( - "Hello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_GRANITE})); - assert_msg_equals( - message_assist, - test_chat_parse( - "Hello, world!\nWhat's up?", - /* is_partial= */ true, - {COMMON_CHAT_FORMAT_GRANITE})); - - // Test parsing content with thinking - assert_msg_equals(message_assist_thoughts, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_GRANITE, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - assert_msg_equals(message_assist_thoughts_unparsed_deepseek, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_GRANITE})); - assert_msg_equals(message_assist_thoughts, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_GRANITE, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - assert_msg_equals(message_assist_thoughts, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_GRANITE, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - assert_msg_equals(simple_assist_msg("I'm\nthinkingHello, world!\nWhat's up?"), - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_GRANITE})); - assert_msg_equals(message_assist_empty, - test_chat_parse( - "I'm\nthinking", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_GRANITE, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - assert_msg_equals( - message_assist_empty, - test_chat_parse( - "I'm\nthinking[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_GRANITE})); - assert_msg_equals( - message_assist_call_empty_args, - test_chat_parse( - "<|tool_call|>[{\"name\": \"special_function\"", - /* is_partial= */ true, - {COMMON_CHAT_FORMAT_GRANITE})); - assert_msg_equals( - message_assist_call_cutoff_args, - test_chat_parse( - "<|tool_call|>[{\"name\": \"special_function\", \"arguments\": {\"arg", - /* is_partial= */ true, - {COMMON_CHAT_FORMAT_GRANITE})); - assert_msg_equals( - message_assist_call_cutoff_args, - test_chat_parse( - "<|tool_call|>[{\"name\": \"special_function\", \"arguments\": {\"arg", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_GRANITE, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - - // Test parsing tool calls with thinking - assert_msg_equals( - message_assist_call_thoughts, - test_chat_parse( - "I'm\nthinking<|tool_call|>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, {", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_GRANITE, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - - // Test template generation for regular content - test_templates(tmpls.get(), end_tokens, message_assist, tools, - "Hello, world!\nWhat's up?", - /* expect_grammar_triggered= */ false); - // TODO @ngxson : generic tool call should be removed in the future -#if 0 - // Test template generation for tool calls - test_templates(tmpls.get(), end_tokens, message_assist_call_id, tools, - "{\n" - " \"tool_calls\": [\n" - " {\n" - " \"name\": \"special_function\",\n" - " \"arguments\": {\n" - " \"arg1\": 1\n" - " },\n" - " \"id\": \"123456789\"\n" - " }\n" - " ],\n" - " \"content\": \"\"\n" - "}", - /* expect_grammar_triggered= */ false - ); -#endif - } + auto tst = peg_tester("models/templates/GLM-4.6.jinja", detailed_debug); + tst.test( + "special_function\n" + "arg1\n1\n" + "") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + } + + // GLM-4.7-Flash tests - format: function_name...... + // Note: Template uses forced-open thinking mode (prompt ends with ) { - auto tmpls = read_templates("models/templates/openai-gpt-oss-120b.jinja"); - std::vector end_tokens{ "<|return|>", "<|call|>" }; - - assert_equals(COMMON_CHAT_FORMAT_GPT_OSS, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_GPT_OSS, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - - assert_msg_equals(simple_assist_msg("", "I'm\nthink"), - test_chat_parse( - "<|channel|>analysis<|message|>I'm\nthink", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_GPT_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - })); - assert_msg_equals(simple_assist_msg("", "I'm\nthinking"), - test_chat_parse( - "<|channel|>analysis<|message|>I'm\nthinking<|end|>", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_GPT_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - })); - assert_msg_equals(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"), - test_chat_parse( - "<|channel|>analysis<|message|>I'm\nthinking<|end|>" - "<|start|>assistant<|channel|>final<|message|>Hello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_GPT_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - })); - assert_msg_equals(simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1"), - test_chat_parse( - "<|channel|>analysis<|message|>I'm\nthinking<|end|>" - "<|start|>assistant<|channel|>commentary to=functions.special_function <|constrain|>json<|message|>{\"arg1", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_GPT_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - })); - assert_msg_equals(simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1"), - test_chat_parse( - "<|channel|>analysis<|message|>I'm\nthinking<|end|>" - "<|start|>assistant<|channel|>commentary to=functions.special_function<|message|>{\"arg1", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_GPT_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - })); - assert_msg_equals(simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}"), - test_chat_parse( - "<|channel|>analysis<|message|>I'm\nthinking<|end|>" - "<|start|>assistant<|channel|>commentary to=functions.special_function <|constrain|>json<|message|>{\"arg1\": 1}", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_GPT_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - })); - assert_msg_equals(simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}"), - test_chat_parse( - "<|channel|>analysis<|message|>I'm\nthinking<|end|>" - "<|start|>assistant<|channel|>analysis to=functions.special_function <|constrain|>json<|message|>{\"arg1\": 1}", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_GPT_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - })); - assert_msg_equals(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"), - test_chat_parse( - "<|channel|>analysis<|message|>I'm\nthinking<|end|>" - "<|start|>assistant<|channel|>commentary<|message|>Hello, world!\nWhat's up?", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_GPT_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - })); - assert_msg_equals(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": 1}"), - test_chat_parse( - "<|channel|>analysis<|message|>I'm\nthinking<|end|>" - "<|start|>assistant<|channel|>commentary<|message|>Hello, world!\nWhat's up?<|end|>" - "<|start|>assistant<|channel|>commentary to=functions.special_function <|constrain|>json<|message|>{\"arg1\": 1}", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_GPT_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - })); - - // Test parse_tool_calls == false - assert_msg_equals( - simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"), - test_chat_parse( - "<|channel|>analysis<|message|>I'm\nthinking<|end|>" - "<|start|>assistant<|channel|>final<|message|>Hello, world!\nWhat's up?", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_GPT_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, - /* .parse_tool_calls = */ false, - })); - assert_msg_equals( - simple_assist_msg("", "I'm\nthinking"), - test_chat_parse( - "<|channel|>analysis<|message|>I'm\nthinking<|end|>" - "<|start|>assistant<|channel|>commentary to=functions.special_function<|message|>{\"arg1", - /* is_partial= */ true, - { - /* .format = */ COMMON_CHAT_FORMAT_GPT_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, - /* .parse_tool_calls = */ false, - })); - assert_msg_equals( - simple_assist_msg("", "I'm\nthinking"), - test_chat_parse( - "<|channel|>analysis<|message|>I'm\nthinking<|end|>" - "<|start|>assistant<|channel|>commentary to=functions.special_function <|constrain|>json<|message|>{\"arg1\": 1}", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_GPT_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, - /* .parse_tool_calls = */ false, - })); - - // Test reasoning formats - assert_msg_equals( - simple_assist_msg( - "<|channel|>analysis<|message|>I'm\nthinking<|end|>Hello, world!\nWhat's up?"), - test_chat_parse( - "<|channel|>analysis<|message|>I'm\nthinking<|end|>" - "<|start|>assistant<|channel|>final<|message|>Hello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_GPT_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, - })); - - assert_msg_equals( - simple_assist_msg( - "<|channel|>analysis<|message|>I'm\nthinking<|end|>Hello, world!\nWhat's up?"), - test_chat_parse( - "<|channel|>analysis<|message|>I'm\nthinking<|end|>" - "<|start|>assistant<|channel|>final<|message|>Hello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_GPT_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - /* .reasoning_in_content = */ true, - })); - - // Test tool calling in role header - assert_msg_equals(simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"), - test_chat_parse( - " to=functions.special_function<|channel|>commentary <|constrain|>json<|message|>{\"arg1\": 1}", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_GPT_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - })); - assert_msg_equals(simple_assist_msg("", "", "special_function", "{\"arg1\": 1}"), - test_chat_parse( - " to=functions.special_function<|channel|>analysis <|constrain|>json<|message|>{\"arg1\": 1}", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_GPT_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - })); - assert_msg_equals(simple_assist_msg("", "I'm\nthinking", "special_function", "{\"arg1\": 1}"), - test_chat_parse( - "<|channel|>analysis<|message|>I'm\nthinking<|end|>" - "<|start|>assistant to=functions.special_function<|channel|>analysis <|constrain|>json<|message|>{\"arg1\": 1}", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_GPT_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO, - })); - } + auto tst = peg_tester("models/templates/GLM-4.7-Flash.jinja", detailed_debug); + + // Pure content (no reasoning) + tst.test("Hello, world!\nWhat's up?") + .enable_thinking(false) + .expect(message_assist) + .run(); + + // Reasoning with content (forced-open mode - input starts after ) + tst.test("I'm\nthinkingHello, world!\nWhat's up?") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .expect(message_assist_thoughts) + .run(); + + // Tool call without reasoning + tst.test( + "special_function" + "arg11" + "") + .enable_thinking(false) + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + + // Tool call with reasoning (forced-open mode) + tst.test( + "I'm\nthinking" + "special_function" + "arg11" + "") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ special_function_tool }) + .expect(message_assist_call_thoughts) + .run(); + + tst.test( + "special_function" + "arg11" + "" + "special_function_with_opt" + "arg11" + "arg22" + "") + .parallel_tool_calls(true) + .tools({ + special_function_tool, special_function_tool_with_optional_param + }) + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", {} }, + { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} }, + }) + .run(); + } + + // Kimi-K2-Thinking tests - custom parser + // Unique feature: tool call ID embeds function name as functions.: { - // Seed-OSS format tests - auto tmpls = read_templates("models/templates/ByteDance-Seed-OSS.jinja"); - std::vector end_tokens{ "" }; - - assert_equals(COMMON_CHAT_FORMAT_SEED_OSS, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_SEED_OSS, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - - test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - - // Test simple reasoning content - assert_msg_equals( - simple_assist_msg("Hello, world!", "I'm thinking about the answer"), - test_chat_parse( - "I'm thinking about the answerHello, world!", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_SEED_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - - // Test budget reflection tags - common_chat_msg msg_budget_reflect; - msg_budget_reflect.role = "assistant"; - msg_budget_reflect.content = "Token usage: 45/1000\nI should continue thinking to find the best solution.I need to calculate this step by step."; - msg_budget_reflect.reasoning_content = "Token usage: 45/1000\nI should continue thinking to find the best solution."; - assert_msg_equals( - msg_budget_reflect, - test_chat_parse( - "Token usage: 45/1000\nI should continue thinking to find the best solution." - "Token usage: 45/1000\nI should continue thinking to find the best solution." - "I need to calculate this step by step.", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_SEED_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - - // Test tool calls with Seed-OSS format - common_chat_msg msg_tool_call; - msg_tool_call.role = "assistant"; - msg_tool_call.tool_calls.push_back({"calculate_sum", "{\"numbers\": [1, 2, 3]}", ""}); - assert_msg_equals( - msg_tool_call, - test_chat_parse( - "\n" - "\n" - "[1, 2, 3]\n" - "\n" - "", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_SEED_OSS})); - - // Test reasoning + tool call combination - common_chat_msg msg_reasoning_tool; - msg_reasoning_tool.role = "assistant"; - msg_reasoning_tool.content = ""; - msg_reasoning_tool.reasoning_content = "I need to calculate the sum of these numbers"; - msg_reasoning_tool.tool_calls.push_back({"calculate_sum", "{\"numbers\": [1, 2, 3]}", ""}); - assert_msg_equals( - msg_reasoning_tool, - test_chat_parse( - "I need to calculate the sum of these numbers" - "\n" - "\n" - "[1, 2, 3]\n" - "\n" - "", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_SEED_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - - // Test deltas: the number of tool calls in partial parses should never decrease - std::string tool_msg = "\n" - "\n" - "[1, 2, 3]\n" - ""; - std::size_t previousToolCalls = 0; - for (std::size_t i = std::string("").length(); i < tool_msg.length() - 1; i++) { - auto partial = tool_msg.substr(0, i); - auto partial_res = test_chat_parse(partial, true, { COMMON_CHAT_FORMAT_SEED_OSS, COMMON_REASONING_FORMAT_DEEPSEEK }); - if (partial_res.tool_calls.size() < previousToolCalls) { - throw std::runtime_error("Tool call size decreased on partial: " + partial + " from " + std::to_string(previousToolCalls) + " to " + std::to_string(partial_res.tool_calls.size())); - } - previousToolCalls = partial_res.tool_calls.size(); - } - - // Test multiple parameters in tool call - common_chat_msg msg_multi_param; - msg_multi_param.role = "assistant"; - msg_multi_param.tool_calls.push_back({"process_data", "{\"input\": \"test\", \"format\": \"json\"}", ""}); - assert_msg_equals( - msg_multi_param, - test_chat_parse( - "\n" - "\n" - "test\n" - "json\n" - "\n" - "", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_SEED_OSS})); - - // Test partial parsing for incomplete tool call - don't actually add the call until parsing parameters is done - assert_msg_equals( - simple_assist_msg("", "", "calculate_sum", "{\"numbers\":"), - test_chat_parse( - "\n" - "\n" - "[1,\n", - /* is_partial= */ true, - {COMMON_CHAT_FORMAT_SEED_OSS})); - - // Test incomplete reasoning tag - assert_msg_equals( - simple_assist_msg("", "I was thinking"), - test_chat_parse( - "I was thinking", - /* is_partial= */ true, + auto tst = peg_tester("models/templates/Kimi-K2-Thinking.jinja", detailed_debug); + + // Basic content only + tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + + // Single tool call + tst.test( + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>" + "{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>") + .tools({ special_function_tool }) + .expect(simple_assist_msg("", "", "special_function", "{\"arg1\": 1}", "functions.special_function:0")) + .run(); + + // Single tool call with reasoning + tst.test( + "I'm thinking about this" + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>" + "{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ special_function_tool }) + .expect(simple_assist_msg("", "I'm thinking about this", "special_function", "{\"arg1\": 1}", "functions.special_function:0")) + .run(); + + // Tool call with content + tst.test( + "Hello, world!\nWhat's up?" + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>" + "{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>") + .tools({ special_function_tool }) + .expect(simple_assist_msg("Hello, world!\nWhat's up?", "", "special_function", "{\"arg1\": 1}", "functions.special_function:0")) + .run(); + + // Multiple tool calls (parallel) - tests the indexing behavior + tst.test( + "<|tool_calls_section_begin|>" + "<|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|>" + "<|tool_call_begin|>functions.special_function_with_opt:1<|tool_call_argument_begin|>{\"arg1\": 1, \"arg2\": 2}<|tool_call_end|>" + "<|tool_calls_section_end|>") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .parallel_tool_calls(true) + .tools({ + special_function_tool, special_function_tool_with_optional_param + }) + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", "functions.special_function:0" }, + { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", "functions.special_function_with_opt:1" }, + }) + .run(); + + // Multiple tool calls with reasoning + tst.test( + "I need to call two functions" + "<|tool_calls_section_begin|>" + "<|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|>" + "<|tool_call_begin|>functions.python:1<|tool_call_argument_begin|>{\"code\": \"print('hey')\"}<|tool_call_end|>" + "<|tool_calls_section_end|>") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .parallel_tool_calls(true) + .tools({ + special_function_tool, python_tool + }) + .expect_reasoning("I need to call two functions") + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", "functions.special_function:0" }, + { "python", "{\"code\": \"print('hey')\"}", "functions.python:1" }, + }) + .run(); + + // Python tool with multiline code + tst.test( + "<|tool_calls_section_begin|><|tool_call_begin|>functions.python:0<|tool_call_argument_begin|>" + "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}<|tool_call_end|><|tool_calls_section_end|>") + .tools({ python_tool }) + .expect_tool_calls({ + { "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", "functions.python:0" }, + }) + .run(); + + // Tool call with empty arguments + tst.test( + "<|tool_calls_section_begin|><|tool_call_begin|>functions.empty_args:0<|tool_call_argument_begin|>" + "{}<|tool_call_end|><|tool_calls_section_end|>") + .tools({ empty_args_tool }) + .expect(simple_assist_msg("", "", "empty_args", "{}", "functions.empty_args:0")) + .run(); + + // Partial tool call (streaming) + tst.test( + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>" + "{\"arg1\": ") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ special_function_tool }) + .is_partial(true) + .expect(simple_assist_msg("", "", "special_function", "{\"arg1\": ", "functions.special_function:0")) + .run(); + + // Three tool calls to verify counter continues incrementing + tst.test( + "<|tool_calls_section_begin|>" + "<|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|>" + "<|tool_call_begin|>functions.python:1<|tool_call_argument_begin|>{\"code\": \"print(1)\"}<|tool_call_end|>" + "<|tool_call_begin|>functions.html:2<|tool_call_argument_begin|>{\"markup\": \"

test

\"}<|tool_call_end|>" + "<|tool_calls_section_end|>") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .parallel_tool_calls(true) + .tools({ + special_function_tool, python_tool, html_tool + }) + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", "functions.special_function:0" }, + { "python", "{\"code\": \"print(1)\"}", "functions.python:1" }, + { "html", "{\"markup\": \"

test

\"}", "functions.html:2" }, + }) + .run(); + + // Multiple tool calls with reasoning, call *inside thinking block* + tst.test( + "I need to call two functions" + "<|tool_calls_section_begin|>" + "<|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|>" + "<|tool_call_begin|>functions.python:1<|tool_call_argument_begin|>{\"code\": \"print('hey')\"}<|tool_call_end|>" + "<|tool_calls_section_end|>") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .parallel_tool_calls(true) + .tools({ + special_function_tool, python_tool + }) + .expect_reasoning("I need to call two functions") + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", "functions.special_function:0" }, + { "python", "{\"code\": \"print('hey')\"}", "functions.python:1" }, + }) + .run(); + + // Multiple tool calls with reasoning, call *inside thinking block* and *without section markers or end markers + tst.test( + "I need to call two functions" + "<|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}" + "<|tool_call_begin|>functions.python:1<|tool_call_argument_begin|>{\"code\": \"print('hey')\"}") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .parallel_tool_calls(true) + .tools({ + special_function_tool, python_tool + }) + .expect_reasoning("I need to call two functions") + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", "functions.special_function:0" }, + { "python", "{\"code\": \"print('hey')\"}", "functions.python:1" }, + }) + .run(); + + // Real life test - execute_command + tst.test("<|tool_call_begin|>functions.execute_command:0<|tool_call_argument_begin|>{\"command\": \"ls -lah\"" + ", \"cwd\": \"/home/jarvis/development/exllamav3\", \"timeout\": 10}") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .parallel_tool_calls(true) + .tools({ + { + /* .name = */ "execute_command", + /* .description = */ "Execute shell command", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "Shell command to execute" + }, + "cwd": { + "type": "string", + "description": "Working directory" + }, + "timeout": { + "type": "integer", + "description": "The timeout in seconds" + } + }, + "required": ["command"] + })" + } + }). + expect_tool_calls({ { - /* .format = */ COMMON_CHAT_FORMAT_SEED_OSS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - - // Test content without reasoning - assert_msg_equals( - simple_assist_msg("This is a simple response without reasoning."), - test_chat_parse( - "This is a simple response without reasoning.", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_SEED_OSS})); + "execute_command", + R"({"command": "ls -lah", "cwd": "/home/jarvis/development/exllamav3", "timeout": 10})", + "functions.execute_command:0" + } + }) + .run(); } - { - auto tmpls = read_templates("models/templates/NVIDIA-Nemotron-Nano-v2.jinja"); - std::vector end_tokens{ "" }; - - assert_equals(COMMON_CHAT_FORMAT_NEMOTRON_V2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_NEMOTRON_V2, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - - // Test parsing regular content - assert_msg_equals(message_assist, - test_chat_parse( - "Hello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_NEMOTRON_V2})); - - // Test parsing content with thinking - assert_msg_equals(message_assist_thoughts, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_NEMOTRON_V2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - - // Test parsing tool calls - assert_msg_equals(message_assist_call, - test_chat_parse( - "[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_NEMOTRON_V2})); - - // Test parsing tool calls with thinking - assert_msg_equals(message_assist_call_thoughts, - test_chat_parse( - "I'm\nthinking[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_NEMOTRON_V2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - })); - - // Test tool calls with extra content - assert_msg_equals(message_assist_call_content, - test_chat_parse( - "[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]Hello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_NEMOTRON_V2} - )); - - // Test tool calls with extra content AND thinking - assert_msg_equals(message_assist_call_thoughts_content, - test_chat_parse( - "I'm\nthinking[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]Hello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_NEMOTRON_V2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - })); - // Test template generation for regular content - test_templates(tmpls.get(), end_tokens, message_assist, tools, - "Hello, world!\nWhat's up?\n", - /* expect_grammar_triggered= */ false); - - // Test template generation for tool calls - test_templates(tmpls.get(), end_tokens, message_assist_call, tools, - "[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]", - /* expect_grammar_triggered= */ true - ); - } { - auto tmpls = read_templates("models/templates/deepseek-ai-DeepSeek-V3.1.jinja"); - std::vector end_tokens{ "<|end▁of▁sentence|>" }; - - for (const auto & inputs : { inputs_no_tools, inputs_tools }) { - auto params = common_chat_templates_apply(tmpls.get(), inputs); - assert_equals(COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, params.format); - assert_equals(true, params.thinking_forced_open); - } - - test_templates(tmpls.get(), end_tokens, message_assist, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - test_templates(tmpls.get(), end_tokens, message_assist_thoughts, tools, "Hello, world!\nWhat's up?", /* expect_grammar_triggered= */ false); - assert_msg_equals( - simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking"), - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - { - COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - })); - // variant: thinking forced open, reasoning_format none - assert_msg_equals( - simple_assist_msg("REASONINGok", ""), - test_chat_parse( - "REASONINGok", - /* is_partial= */ false, - { - COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - /* .parse_tool_calls = */ true, - })); - // variant: happy path for when it works as the model card says it should - assert_msg_equals( - simple_assist_msg("", "", "get_time", "{\"city\":\"Tokyo\"}"), - test_chat_parse( - "<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", - /* is_partial= */ false, - { - COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, - /* .parse_tool_calls = */ true, - })); - // variant: simple + thinking open - assert_msg_equals( - simple_assist_msg("", "REASONING", "get_time", "{\"city\":\"Tokyo\"}"), - test_chat_parse( - "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", - /* is_partial= */ false, - { - COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - /* .parse_tool_calls = */ true, - })); - // variant: simple + multiple tool calls - common_chat_msg message_assist_multiple_calls; - message_assist_multiple_calls.role = "assistant"; - message_assist_multiple_calls.content = "CONTENT"; - message_assist_multiple_calls.tool_calls.push_back({"get_time", "{\"city\":\"Paris\"}", ""}); - message_assist_multiple_calls.tool_calls.push_back({"get_weather", "{\"city\":\"Paris\"}", ""}); - assert_msg_equals( - message_assist_multiple_calls, - test_chat_parse( - "CONTENT<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Paris\"}<|tool▁call▁end|><|tool▁call▁begin|>get_weather<|tool▁sep|>{\"city\": \"Paris\"}<|tool▁call▁end|><|tool▁calls▁end|>", - /* is_partial= */ false, - { - COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, - /* .parse_tool_calls = */ true, - })); - // variant: thinking forced open + tool call in reasoning content - assert_msg_equals( - simple_assist_msg("", "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time2<|tool▁sep|>{\"city\": \"Tokyo2\"}<|tool▁call▁end|><|tool▁calls▁end|>REASONING", "get_time", "{\"city\":\"Tokyo\"}"), - test_chat_parse( - "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time2<|tool▁sep|>{\"city\": \"Tokyo2\"}<|tool▁call▁end|><|tool▁calls▁end|>REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", - /* is_partial= */ false, - { - COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - /* .parse_tool_calls = */ true, - })); - // variant: thinking forced open + tool call in reasoning content + no closing think + not partial - // This is a bit of a fine tuning issue on the model's part IMO. It really should not be attempting - // to make tool calls in reasoning content according to the model card, but it does sometimes, so - // add the reasoning content as regular content and parse the tool calls. - assert_msg_equals( - simple_assist_msg("REASONING", "", "get_time", "{\"city\":\"Tokyo\"}"), - test_chat_parse( - "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", - /* is_partial= */ false, - { - COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - /* .parse_tool_calls = */ true, - })); - // variant: thinking forced open + tool call in reasoning content + no closing think + partial - assert_msg_equals( - simple_assist_msg("", "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", "", ""), - test_chat_parse( - "REASONING<|tool▁calls▁begin|><|tool▁call▁begin|>get_time<|tool▁sep|>{\"city\": \"Tokyo\"}<|tool▁call▁end|><|tool▁calls▁end|>", - /* is_partial= */ true, - { - COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ true, - /* .parse_tool_calls = */ true, - })); - // variant: thinking not forced open + missing reasoning + no tool calls - assert_msg_equals( - simple_assist_msg("CONTENT", ""), - test_chat_parse( - "CONTENT", - /* is_partial= */ false, - { - COMMON_CHAT_FORMAT_DEEPSEEK_V3_1, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* .reasoning_in_content = */ false, - /* .thinking_forced_open = */ false, - /* .parse_tool_calls = */ true, - })); - } + auto kimi_id_special_func_tool_call = + simple_assist_msg("", "", "special_function", "{\"arg1\": 1}", "functions.special_function:0"); + + // Kimi-K2 old template + auto tst = peg_tester("models/templates/moonshotai-Kimi-K2.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test( + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>" + "{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>") + .tools({ special_function_tool }) + .expect(kimi_id_special_func_tool_call) + .run(); + + // Kimi-K2-Instruct + auto tst2 = peg_tester("models/templates/Kimi-K2-Instruct.jinja", detailed_debug); + tst2.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst2.test( + "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>" + "{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>") + .tools({ special_function_tool }) + .expect(kimi_id_special_func_tool_call) + .run(); + } + + // Apertus-8B-Instruct tests - FUNC_NAME_AS_KEY format + // Format: <|tools_prefix|>[{"function_name": {...arguments...}}]<|tools_suffix|> { - auto tmpls = read_templates("models/templates/Apertus-8B-Instruct.jinja"); - std::vector end_tokens{ "<|assistant_end|>" }; - - assert_equals(COMMON_CHAT_FORMAT_APERTUS, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_APERTUS, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - - // Test parsing regular content - assert_msg_equals(message_assist, - test_chat_parse( - "Hello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_APERTUS})); - - // Test parsing content with thinking - assert_msg_equals(message_assist_thoughts, - test_chat_parse( - "<|inner_prefix|>I'm\nthinking<|inner_suffix|>Hello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_APERTUS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - - // Test parsing tool calls - assert_msg_equals(message_assist_call, - test_chat_parse( - "<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_APERTUS})); - - // Test parsing tool calls with thinking - assert_msg_equals(message_assist_call_thoughts, - test_chat_parse( - "<|inner_prefix|>I'm\nthinking<|inner_suffix|><|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_APERTUS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - })); - - // Test tool calls with extra content - assert_msg_equals(message_assist_call_content, - test_chat_parse( - "<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>Hello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_APERTUS} - )); - - // Test tool calls with extra content AND thinking - assert_msg_equals(message_assist_call_thoughts_content, - test_chat_parse( - "<|inner_prefix|>I'm\nthinking<|inner_suffix|><|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>Hello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_APERTUS, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - })); - - // Test template generation for regular content - test_templates(tmpls.get(), end_tokens, message_assist, tools, - "Hello, world!\nWhat's up?", - /* expect_grammar_triggered= */ false); - - // Test template generation for tool calls - test_templates(tmpls.get(), end_tokens, message_assist_call, tools, - "<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>", - /* expect_grammar_triggered= */ true - ); - - // TODO @ngxson : not sure why this fails, but not very important for now - // assert_equals(true, common_chat_templates_support_enable_thinking(tmpls.get())); + auto tst = peg_tester("models/templates/Apertus-8B-Instruct.jinja", detailed_debug); + tst.test("<|tools_prefix|>[{\"special_function\": {\"arg1\": 1}}]<|tools_suffix|>") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); } - { - // LFM2 format tests - auto tmpls = read_templates("models/templates/llama-cpp-lfm2.jinja"); - std::vector end_tokens{ "<|im_end|>" }; - - auto inputs_tools_forced_json_schema = std::invoke([&]() -> common_chat_templates_inputs { - common_chat_templates_inputs inputs; - inputs.messages = { - std::invoke([&]() -> common_chat_msg { - common_chat_msg msg; - msg.role = "system"; - msg.content = "force json schema.\n"; - return msg; - }), - message_user, - }; - inputs.tools = {special_function_tool}; - return inputs; - }); - - { - auto params = common_chat_templates_apply(tmpls.get(), inputs_no_tools); - assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params.format); - assert_equals(false, params.grammar_lazy); - assert_equals(std::string(R"(<|im_start|>user -Hey there!<|im_end|> -<|im_start|>assistant -)"), params.prompt); - } - - { - auto params = common_chat_templates_apply(tmpls.get(), inputs_tools); - assert_equals(COMMON_CHAT_FORMAT_CONTENT_ONLY, params.format); - assert_equals(false, params.grammar_lazy); - assert_equals(std::string(R"(<|im_start|>system -List of tools: <|tool_list_start|>[{"type": "function", "function": {"name": "special_function", "description": "I'm special", "parameters": {"type": "object", "properties": {"arg1": {"type": "integer", "description": "The arg."}}, "required": ["arg1"]}}}]<|tool_list_end|><|im_end|> -<|im_start|>user -Hey there!<|im_end|> -<|im_start|>assistant -)"), params.prompt); - assert_equals(true, params.grammar.empty()); - } - - { - auto params = common_chat_templates_apply(tmpls.get(), inputs_tools_forced_json_schema); - assert_equals(COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS, params.format); - assert_equals(true, params.grammar_lazy); - assert_equals(std::string(R"(<|im_start|>system -List of tools: <|tool_list_start|>[{"type": "function", "function": {"name": "special_function", "description": "I'm special", "parameters": {"type": "object", "properties": {"arg1": {"type": "integer", "description": "The arg."}}, "required": ["arg1"]}}}]<|tool_list_end|><|im_end|> -<|im_start|>user -Hey there!<|im_end|> -<|im_start|>assistant -)"), params.prompt); - assert_equals(false, params.grammar.empty()); - } - // Test parsing regular content - assert_msg_equals(message_assist, - test_chat_parse( - "Hello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); - - // Test single tool call with JSON format - common_chat_msg msg_single_tool_call; - msg_single_tool_call.role = "assistant"; - msg_single_tool_call.tool_calls.push_back({"special_function", "{\"arg1\":1}", ""}); - assert_msg_equals( - msg_single_tool_call, - test_chat_parse( - "<|tool_call_start|>[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]<|tool_call_end|>", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); - - // Test tool call with string argument - common_chat_msg msg_tool_call_string; - msg_tool_call_string.role = "assistant"; - msg_tool_call_string.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); - assert_msg_equals( - msg_tool_call_string, - test_chat_parse( - "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); - - // Test tool call with multiple arguments - common_chat_msg msg_multi_args; - msg_multi_args.role = "assistant"; - msg_multi_args.tool_calls.push_back({"calculate", "{\"x\":10,\"y\":20,\"operation\":\"add\"}", ""}); - assert_msg_equals( - msg_multi_args, - test_chat_parse( - "<|tool_call_start|>[{\"name\": \"calculate\", \"arguments\": {\"x\": 10, \"y\": 20, \"operation\": \"add\"}}]<|tool_call_end|>", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); - - // Test multiple tool calls in single array - common_chat_msg msg_multiple_tools; - msg_multiple_tools.role = "assistant"; - msg_multiple_tools.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); - msg_multiple_tools.tool_calls.push_back({"get_time", "{\"timezone\":\"UTC\"}", ""}); - assert_msg_equals( - msg_multiple_tools, - test_chat_parse( - "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}, {\"name\": \"get_time\", \"arguments\": {\"timezone\": \"UTC\"}}]<|tool_call_end|>", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); - - // Test tool call with content before - common_chat_msg msg_content_before_tool; - msg_content_before_tool.role = "assistant"; - msg_content_before_tool.content = "Let me check the weather for you."; - msg_content_before_tool.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); - assert_msg_equals( - msg_content_before_tool, - test_chat_parse( - "Let me check the weather for you.<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); - - // Test tool call with content after - common_chat_msg msg_content_after_tool; - msg_content_after_tool.role = "assistant"; - msg_content_after_tool.content = "Here's the result."; - msg_content_after_tool.tool_calls.push_back({"get_weather", "{\"location\":\"Paris\"}", ""}); - assert_msg_equals( - msg_content_after_tool, - test_chat_parse( - "<|tool_call_start|>[{\"name\": \"get_weather\", \"arguments\": {\"location\": \"Paris\"}}]<|tool_call_end|>Here's the result.", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); - - // Test tool call with newlines (common in LLM output) - common_chat_msg msg_tool_call_newlines; - msg_tool_call_newlines.role = "assistant"; - msg_tool_call_newlines.tool_calls.push_back({"get_current_time", "{\"location\":\"Paris\"}", ""}); - assert_msg_equals( - msg_tool_call_newlines, - test_chat_parse( - "<|tool_call_start|>[{\n \"name\": \"get_current_time\",\n \"arguments\": {\n \"location\": \"Paris\"\n }\n}]<|tool_call_end|>", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS})); - - // Note: LFM2 uses JSON format for tool calls: [{"name": "...", "arguments": {...}}] - // Unlike other formats, LFM2 template does not render tool calls in conversation history, - // so we don't use test_templates() for tool call generation. Instead, the parsing tests - // above verify edge cases and format variations for the tool call output format. + // MiniMax-M2 tests - XML invoke format with parameter tags + // Format: value + { + auto tst = peg_tester("models/templates/MiniMax-M2.jinja", detailed_debug); + tst.test( + "\n\n1\n\n") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); } + // NVIDIA-Nemotron-Nano-v2 tests - ... format + // Format: [{"name": "func", "arguments": {...}}] { - auto tmpls = read_templates("models/templates/MiniMax-M2.jinja"); - std::vector end_tokens{ "[e~[" }; - - assert_equals(COMMON_CHAT_FORMAT_MINIMAX_M2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_MINIMAX_M2, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - - // Test parsing regular content - assert_msg_equals(message_assist, - test_chat_parse( - "Hello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_MINIMAX_M2})); - - // Test parsing content with thinking - assert_msg_equals(message_assist_thoughts, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - - // Test parsing tool calls - assert_msg_equals(message_assist_call, - test_chat_parse( - "1", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_MINIMAX_M2})); - - // Test parsing tool calls with thinking - assert_msg_equals(message_assist_call_thoughts, - test_chat_parse( - "I'm\nthinking1", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - })); - - // Test tool calls with extra content - assert_msg_equals(message_assist_call_content, - test_chat_parse( - "1Hello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_MINIMAX_M2} - )); - - // Test tool calls with extra content AND thinking - assert_msg_equals(message_assist_call_thoughts_content, - test_chat_parse( - "I'm\nthinking1Hello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - })); - - // Test streaming - test_parser_with_streaming(message_assist_call_thoughts_content, - "I'm\nthinking\nHello, world!\nWhat's up?\n1", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - }); }); - test_parser_with_streaming(message_assist_call_thoughts_unparsed, - "I'm\nthinking\n\n1", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE - }); }); - test_parser_with_streaming(message_assist_call_thoughts_content, - "I'm\nthinking\n\n\nHello, world!\nWhat's up?\n\n\n\n1\n\n\n", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - }); }); - test_parser_with_streaming(message_assist_call_withopt, - "\n\n1\n2\n\n", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - /* .format = */ COMMON_CHAT_FORMAT_MINIMAX_M2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE - }); }); - - // Test template generation for regular content - test_templates(tmpls.get(), end_tokens, message_assist, tools, - "Hello, world!\nWhat's up?", - /* expect_grammar_triggered= */ false); - - // Test template generation for tool calls - test_templates(tmpls.get(), end_tokens, message_assist_call, tools, - "\n\n1\n\n", - /* expect_grammar_triggered= */ true, - /* test_grammar_if_triggered= */ true, - /* common_reasoning_format= */ COMMON_REASONING_FORMAT_NONE, - /* ignore_whitespace_differences= */ true - ); - - // Test template generation for tools with optional parameters - test_templates(tmpls.get(), end_tokens, message_assist_call_noopt, tools, - "\n\n1\n\n", - /* expect_grammar_triggered= */ true, - /* test_grammar_if_triggered= */ true, - /* common_reasoning_format= */ COMMON_REASONING_FORMAT_NONE, - /* ignore_whitespace_differences= */ true - ); - test_templates(tmpls.get(), end_tokens, message_assist_call_withopt, tools, - "\n\n1\n2\n\n", - /* expect_grammar_triggered= */ true, - /* test_grammar_if_triggered= */ true, - /* common_reasoning_format= */ COMMON_REASONING_FORMAT_NONE, - /* ignore_whitespace_differences= */ true - ); + auto tst = peg_tester("models/templates/NVIDIA-Nemotron-Nano-v2.jinja", detailed_debug); + tst.test("[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); } + // CohereForAI-c4ai-command-r7b (uses START_RESPONSE/END_RESPONSE, START_THINKING/END_THINKING, START_ACTION/END_ACTION) { - auto tmpls = read_templates("models/templates/GLM-4.6.jinja"); - std::vector end_tokens{ "<|assistant|>", "<|observation|>" }; - - assert_equals(COMMON_CHAT_FORMAT_GLM_4_5, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_GLM_4_5, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - - // Test parsing regular content - assert_msg_equals(message_assist, - test_chat_parse( - "Hello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_GLM_4_5})); - - // Test parsing content with thinking - assert_msg_equals(message_assist_thoughts, - test_chat_parse( - "\nI'm\nthinking\nHello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - }), true); - - // Test parsing tool calls - assert_msg_equals(message_assist_call, - test_chat_parse( - "\nspecial_function\narg1\n1\n", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_GLM_4_5}), true); - - // Test parsing tool calls with thinking - assert_msg_equals(message_assist_call_thoughts, - test_chat_parse( - "\nI'm\nthinking\nspecial_function\narg1\n1\n", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - }), true); - - // Test tool calls with extra content - assert_msg_equals(message_assist_call_content, - test_chat_parse( - "\nspecial_function\narg1\n1\nHello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_GLM_4_5} - ), true); - - // Test tool calls with extra content AND thinking - assert_msg_equals(message_assist_call_thoughts_content, - test_chat_parse( - "\nI'm\nthinkingHello, world!\nWhat's up?\nspecial_function\narg1\n1\n", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - }), true); - - // Test streaming - test_parser_with_streaming(message_assist_call_thoughts_content, - "\nI'm\nthinkingHello, world!\nWhat's up?\nspecial_function\narg1\n1\n", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - }); }); - test_parser_with_streaming(message_assist_call_thoughts_unparsed, - "\nI'm\nthinking\n\nspecial_function\narg1\n1\n", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE - }); }); - test_parser_with_streaming(message_assist_call_withopt, - "\n\nspecial_function_with_opt\narg1\n1\narg2\n2\n\n", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - }); }); - test_parser_with_streaming( - simple_assist_msg("", "", "complex_function", "{\"name\":\"John Doe\",\"age\":30,\"active\":true,\"score\":95.5}"), - "complex_function\n" - "name\n" - "John Doe\n" - "age\n" - "30\n" - "active\n" - "true\n" - "score\n" - "95.5\n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_GLM_4_5}); }); - test_parser_with_streaming( - simple_assist_msg("", "", "web_search", "{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}"), - "web_search\n" - "query\n" - "\"From Zero\" Linkin Park album tracklist complete songs\n" - "limit\n" - "3\n" - "type\n" - "text\n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_GLM_4_5}); }); - - // Test interleaved thinking - test_parser_with_streaming(simple_assist_msg("Hello, world!\n\nWhat's up?", "I'm\nthinkingThinking2", "special_function", "{\"arg1\": 1}"), - "\nI'm\nthinkingHello, world!\nThinking2What's up?\nspecial_function\narg1\n1\n", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - }); }); - test_parser_with_streaming(simple_assist_msg("\nI'm\nthinkingHello, world!\nThinking2What's up?", "", "special_function", "{\"arg1\": 1}"), - "\nI'm\nthinkingHello, world!\nThinking2What's up?\nspecial_function\narg1\n1\n", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - /* .format = */ COMMON_CHAT_FORMAT_GLM_4_5, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE - }); }); - - // Test template generation for regular content - test_templates(tmpls.get(), end_tokens, message_assist, tools, - "\n\nHello, world!\nWhat's up?", - /* expect_grammar_triggered= */ false); - - // Test template generation for tool calls - test_templates(tmpls.get(), end_tokens, message_assist_call, tools, - "\n\nspecial_function\narg1\n1\n\n", - /* expect_grammar_triggered= */ true, - /* test_grammar_if_triggered= */ false, - /* common_reasoning_format= */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* ignore_whitespace_differences= */ true - ); - - // Test template generation for tools with optional parameters - test_templates(tmpls.get(), end_tokens, message_assist_call_noopt, tools, - "\n\nspecial_function_with_opt\narg1\n1\n\n", - /* expect_grammar_triggered= */ true, - /* test_grammar_if_triggered= */ false, - /* common_reasoning_format= */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* ignore_whitespace_differences= */ true - ); - test_templates(tmpls.get(), end_tokens, message_assist_call_withopt, tools, - "\n\nspecial_function_with_opt\narg1\n1\narg2\n2\n\n", - /* expect_grammar_triggered= */ true, - /* test_grammar_if_triggered= */ false, - /* common_reasoning_format= */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* ignore_whitespace_differences= */ true - ); + auto tst = peg_tester("models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja", detailed_debug); + tst.test("<|START_RESPONSE|>Hello, world!\nWhat's up?<|END_RESPONSE|>").expect(message_assist).run(); + tst.test( + "<|START_THINKING|>I'm\nthinking<|END_THINKING|>" + "<|START_ACTION|>[\n" + " {\"tool_call_id\": \"0\", \"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}\n" + "]<|END_ACTION|>") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ special_function_tool }) + .expect(message_assist_thoughts_call_idx) + .run(); + } + // CohereForAI-c4ai-command-r-plus (uses markdown code block format) + { + auto tst = peg_tester("models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", detailed_debug); + tst.test("<|CHATBOT_TOKEN|>Hello, world!\nWhat's up?<|END_OF_TURN_TOKEN|>").expect(message_assist).run(); + // Tool calls: Action: followed by JSON code block + tst.test( + "Action:\n" + "```json\n" + "[{\"tool_name\": \"special_function\", \"parameters\": {\"arg1\": 1}}]\n" + "```") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + } + + // mistralai-Mistral-Nemo-Instruct-2407.jinja + { + auto tst = peg_tester("models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]") + .tools({ special_function_tool }) + .expect(message_assist_call_id) + .run(); } - { - auto tmpls = read_templates("models/templates/Kimi-K2-Thinking.jinja"); - std::vector end_tokens{ "<|im_end|>" }; - - assert_equals(COMMON_CHAT_FORMAT_KIMI_K2, common_chat_templates_apply(tmpls.get(), inputs_no_tools).format); - assert_equals(COMMON_CHAT_FORMAT_KIMI_K2, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - - // Test parsing regular content - assert_msg_equals(message_assist, - test_chat_parse( - "Hello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_KIMI_K2})); - - // Test parsing content with thinking - assert_msg_equals(message_assist_thoughts, - test_chat_parse( - "I'm\nthinkingHello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK, - })); - - // Test parsing tool calls - assert_msg_equals(message_assist_call, - test_chat_parse( - "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_KIMI_K2})); - - // Test parsing tool calls with thinking - assert_msg_equals(message_assist_call_thoughts, - test_chat_parse( - "I'm\nthinking<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - })); - - // Test tool calls with extra content - assert_msg_equals(message_assist_call_content, - test_chat_parse( - "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>Hello, world!\nWhat's up?", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_KIMI_K2} - )); - - // Test tool calls with extra content AND thinking - assert_msg_equals(message_assist_call_thoughts_content, - test_chat_parse( - "I'm\nthinking<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>Hello, world!\nWhat's up?", - /* is_partial= */ false, - { - /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - })); - - // Test streaming - test_parser_with_streaming(message_assist_call_thoughts_content, - "I'm\nthinking\nHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - }); }); - test_parser_with_streaming(message_assist_call_thoughts_unparsed, - "I'm\nthinking\n\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE - }); }); - test_parser_with_streaming(message_assist_call_thoughts_content, - "I'm\nthinking\n\n\nHello, world!\nWhat's up?\n\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>\n", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - }); }); - test_parser_with_streaming(message_assist_call_withopt, - "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function_with_opt:0<|tool_call_argument_begin|>{\"arg1\": 1, \"arg2\": 2}<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_NONE - }); }); - test_parser_with_streaming(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": \"123456\"}"), - "I'm\nthinkingHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": \"123456\"}<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - }); }); - test_parser_with_streaming(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": [1, 2, \"345\", 6]}"), - "I'm\nthinkingHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": [1, 2, \"345\", 6]}<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - }); }); - test_parser_with_streaming(simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking", "special_function", "{\"arg1\": {\"12\": 34, \"5\": [67, 8], \"9\": \"10\"}}"), - "I'm\nthinkingHello, world!\nWhat's up?\n<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": {\"12\": 34, \"5\": [67, 8], \"9\": \"10\"}}<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - /* .format = */ COMMON_CHAT_FORMAT_KIMI_K2, - /* .reasoning_format = */ COMMON_REASONING_FORMAT_DEEPSEEK - }); }); - test_parser_with_streaming( - simple_assist_msg("", "", "complex_function", "{\"name\":\"John Doe\",\"age\":30,\"active\":true,\"score\":95.5}"), - "<|tool_calls_section_begin|><|tool_call_begin|>functions.complex_function:0<|tool_call_argument_begin|>" - "{\"name\": \"John Doe\", \"age\": 30, \"active\": true, \"score\": 95.5}" - "<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_KIMI_K2}); }); - test_parser_with_streaming( - simple_assist_msg("", "", "web_search", "{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}"), - "<|tool_calls_section_begin|><|tool_call_begin|>functions.web_search:0<|tool_call_argument_begin|>" - "{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}" - "<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_KIMI_K2}); }); - test_parser_with_streaming( - simple_assist_msg("", "", "read_file", "{\"args\": [{\"path\": \"src/providers/ThemeProvider.tsx\"}, {\"path\": \"src/components/Header.tsx\"}, {\"path\": \"src/components/ThemeToggle.tsx\"}, {\"path\": \"src/app/globals.css\"}, {\"path\": \"src/app/layout.tsx\"}]}"), - "<|tool_calls_section_begin|><|tool_call_begin|>functions.read_file:0<|tool_call_argument_begin|>" - "{\"args\": [{\"path\": \"src/providers/ThemeProvider.tsx\"}, {\"path\": \"src/components/Header.tsx\"}, {\"path\": \"src/components/ThemeToggle.tsx\"}, {\"path\": \"src/app/globals.css\"}, {\"path\": \"src/app/layout.tsx\"}]}" - "<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_KIMI_K2}); }); - test_parser_with_streaming( - simple_assist_msg( - "Let me start by examining the relevant files to understand the current implementation.", "", - "read_file", - "{\"files\": [{\"path\": \"src/app/Partners.tsx\", \"line_ranges\": [\"1-100\"]}]}"), - "Let me start by examining the relevant files to understand the current implementation." - "<|tool_calls_section_begin|><|tool_call_begin|>functions.read_file:0<|tool_call_argument_begin|>" - "{\"files\":[{\"path\":\"src/app/Partners.tsx\",\"line_ranges\":[\"1-100\"]}]}" - "<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_KIMI_K2}); }); - auto multi_tool_msg = simple_assist_msg("Let me call multiple tools.", "I'm thinking."); - multi_tool_msg.tool_calls.push_back({ "read_file", "{\"files\": [{\"path\": \"src/app/Partners.tsx\", \"line_ranges\": [\"1-100\"]}]}", "" }); - multi_tool_msg.tool_calls.push_back({ "web_search", "{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}", "" }); - multi_tool_msg.tool_calls.push_back({ "complex_function", "{\"name\": \"John Doe\", \"age\": 30, \"active\": true, \"score\": 95.5}", "" }); - multi_tool_msg.tool_calls.push_back({ "emoji_function", "{\"message\":\"Hello! 👋 🌟 🚀 Testing emojis: 😀😃😄😁 and symbols: ∑∏∆∇\"}", "" }); - test_parser_with_streaming(multi_tool_msg, - "I'm thinking.Let me call multiple tools." - "<|tool_calls_section_begin|>" - "<|tool_call_begin|>functions.read_file:0<|tool_call_argument_begin|>" - "{\"files\":[{\"path\":\"src/app/Partners.tsx\",\"line_ranges\":[\"1-100\"]}]}" - "<|tool_call_end|>" - "<|tool_call_begin|>functions.web_search:1<|tool_call_argument_begin|>" - "{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}" - "<|tool_call_end|>" - "<|tool_call_begin|>functions.complex_function:2<|tool_call_argument_begin|>" - "{\"name\": \"John Doe\", \"age\": 30, \"active\": true, \"score\": 95.5}" - "<|tool_call_end|>" - "<|tool_call_begin|>functions.emoji_function:3<|tool_call_argument_begin|>" - "{\"message\":\"Hello! 👋 🌟 🚀 Testing emojis: 😀😃😄😁 and symbols: ∑∏∆∇\"}" - "<|tool_call_end|>" - "<|tool_calls_section_end|>", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - COMMON_CHAT_FORMAT_KIMI_K2, - COMMON_REASONING_FORMAT_DEEPSEEK - }); }); - test_parser_with_streaming( - simple_assist_msg("", "I'm thinking", "complex_function_in_think", "{\"name\":\"John Doe\",\"age\":30,\"active\":true,\"score\":95.5}"), - "I'm thinking<|tool_calls_section_begin|><|tool_call_begin|>functions.complex_function_in_think:0<|tool_call_argument_begin|>" - "{\"name\": \"John Doe\", \"age\": 30, \"active\": true, \"score\": 95.5}" - "<|tool_call_end|><|tool_calls_section_end|>", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - COMMON_CHAT_FORMAT_KIMI_K2, - COMMON_REASONING_FORMAT_DEEPSEEK - }); }); - test_parser_with_streaming( - simple_assist_msg("Hello", "I'm thinkingI'm still thinking", "complex_function_in_think", "{\"name\":\"John Doe\",\"age\":30,\"active\":true,\"score\":95.5}"), - "I'm thinking<|tool_calls_section_begin|><|tool_call_begin|>functions.complex_function_in_think:0<|tool_call_argument_begin|>" - "{\"name\": \"John Doe\", \"age\": 30, \"active\": true, \"score\": 95.5}" - "<|tool_call_end|><|tool_calls_section_end|>I'm still thinkingHello", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, { - COMMON_CHAT_FORMAT_KIMI_K2, - COMMON_REASONING_FORMAT_DEEPSEEK - }); }); - - // Test template rendering - common_chat_templates_inputs conversation_with_tools = inputs_tools; - conversation_with_tools.messages.push_back(simple_assist_msg("Let's do it", "Think first", "complex_function", "{\"name\":\"John Doe\",\"age\":30,\"active\":true,\"score\":95.5}")); - conversation_with_tools.messages.push_back({ - "tool", - "Tool response 1", - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "", - /* .tool_name = */ "complex_function", - /* .tool_call_id = */ "", - }); - conversation_with_tools.messages.push_back(simple_assist_msg("Continue", "Think next", "web_search", "{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}")); - conversation_with_tools.messages.push_back({ - "tool", - "Tool response 2", - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "", - /* .tool_name = */ "web_search", - /* .tool_call_id = */ "", - }); - conversation_with_tools.messages.push_back(simple_assist_msg("CC", "Think last", "read_file", "{\"args\": [{\"path\": \"src/providers/ThemeProvider.tsx\"}, {\"path\": \"src/components/Header.tsx\"}, {\"path\": \"src/components/ThemeToggle.tsx\"}, {\"path\": \"src/app/globals.css\"}, {\"path\": \"src/app/layout.tsx\"}]}")); - conversation_with_tools.messages.push_back({ - "tool", - "Tool response 3", - /* .content_parts = */ {}, - /* .tool_calls = */ {}, - /* .reasoning_content = */ "", - /* .tool_name = */ "read_file", - /* .tool_call_id = */ "", - }); - assert_equals(common_chat_templates_apply(tmpls.get(), conversation_with_tools).prompt, std::string("<|im_system|>tool_declare<|im_middle|>[{\"type\": \"function\", \"function\": {\"name\": \"special_function\", \"description\": \"I'm special\", \"parameters\": {\"type\": \"object\", \"properties\": {\"arg1\": {\"type\": \"integer\", \"description\": \"The arg.\"}}, \"required\": [\"arg1\"]}}}]<|im_end|><|im_system|>system<|im_middle|>You are Kimi, an AI assistant created by Moonshot AI.<|im_end|><|im_user|>user<|im_middle|>Hey there!<|im_end|><|im_assistant|>assistant<|im_middle|>Think firstLet's do it<|tool_calls_section_begin|><|tool_call_begin|>functions.complex_function:0<|tool_call_argument_begin|>{\"name\":\"John Doe\",\"age\":30,\"active\":true,\"score\":95.5}<|tool_call_end|><|tool_calls_section_end|><|im_end|><|im_system|>complex_function<|im_middle|>## Return of functions.complex_function:0\nTool response 1<|im_end|><|im_assistant|>assistant<|im_middle|>Think nextContinue<|tool_calls_section_begin|><|tool_call_begin|>functions.web_search:1<|tool_call_argument_begin|>{\"query\":\"\\\"From Zero\\\" Linkin Park album tracklist complete songs\",\"limit\":3,\"type\":\"text\"}<|tool_call_end|><|tool_calls_section_end|><|im_end|><|im_system|>web_search<|im_middle|>## Return of functions.web_search:1\nTool response 2<|im_end|><|im_assistant|>assistant<|im_middle|>Think lastCC<|tool_calls_section_begin|><|tool_call_begin|>functions.read_file:2<|tool_call_argument_begin|>{\"args\": [{\"path\": \"src/providers/ThemeProvider.tsx\"}, {\"path\": \"src/components/Header.tsx\"}, {\"path\": \"src/components/ThemeToggle.tsx\"}, {\"path\": \"src/app/globals.css\"}, {\"path\": \"src/app/layout.tsx\"}]}<|tool_call_end|><|tool_calls_section_end|><|im_end|><|im_system|>read_file<|im_middle|>## Return of functions.read_file:2\nTool response 3<|im_end|><|im_assistant|>assistant<|im_middle|>")); - - // Test template generation for regular content - test_templates(tmpls.get(), end_tokens, message_assist, tools, - "Hello, world!\nWhat's up?", - /* expect_grammar_triggered= */ false); - - // Test template generation for tool calls - test_templates(tmpls.get(), end_tokens, message_assist_call, tools, - "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", - /* expect_grammar_triggered= */ true, - /* test_grammar_if_triggered= */ true, - /* common_reasoning_format= */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* ignore_whitespace_differences= */ true - ); - - // Test template generation for tools with optional parameters - test_templates(tmpls.get(), end_tokens, message_assist_call_noopt, tools, - "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function_with_opt:0<|tool_call_argument_begin|>{\"arg1\": 1}<|tool_call_end|><|tool_calls_section_end|>", - /* expect_grammar_triggered= */ true, - /* test_grammar_if_triggered= */ true, - /* common_reasoning_format= */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* ignore_whitespace_differences= */ true - ); - test_templates(tmpls.get(), end_tokens, message_assist_call_withopt, tools, - "<|tool_calls_section_begin|><|tool_call_begin|>functions.special_function_with_opt:0<|tool_call_argument_begin|>{\"arg1\": 1, \"arg2\": 2}<|tool_call_end|><|tool_calls_section_end|>", - /* expect_grammar_triggered= */ true, - /* test_grammar_if_triggered= */ true, - /* common_reasoning_format= */ COMMON_REASONING_FORMAT_DEEPSEEK, - /* ignore_whitespace_differences= */ true - ); + auto tst = peg_tester("models/templates/meetkai-functionary-medium-v3.1.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("{\"arg1\": 1}") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + } + // Functionary v3.2 - recipient-based format: >>>recipient\n{content} + { + auto tst = peg_tester("models/templates/meetkai-functionary-medium-v3.2.jinja", detailed_debug); + tst.test(">>>all\nHello, world!\nWhat's up?").expect(message_assist).run(); + tst.test(">>>special_function\n{\"arg1\": 1}") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); } + // FireFunction { - // Step-3.5-Flash template: uses same XML output format as Qwen3-Coder and Nemotron v3, - // but with support. Routes to the Nemotron v3 PEG parser for streaming and - // schema-aware parameter parsing. - auto tmpls = read_templates("models/templates/stepfun-ai-Step-3.5-Flash.jinja"); - assert_equals(COMMON_CHAT_FORMAT_PEG_CONSTRUCTED, common_chat_templates_apply(tmpls.get(), inputs_tools).format); - - // Grammar and PEG parser should be generated with thinking_forced_open - { - common_chat_templates_inputs inputs; - inputs.messages = { message_user }; - inputs.tools = { special_function_tool }; - auto params = common_chat_templates_apply(tmpls.get(), inputs); - assert_equals(COMMON_CHAT_FORMAT_PEG_CONSTRUCTED, params.format); - assert_equals(true, params.thinking_forced_open); - assert_equals(false, params.grammar.empty()); - assert_equals(false, params.parser.empty()); - auto grammar = build_grammar(params.grammar); - GGML_ASSERT(grammar && "Failed to build Step-3.5-Flash grammar"); - } + auto tst = peg_tester("models/templates/fireworks-ai-llama-3-firefunction-v2.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test(" functools[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); } -} -static void test_template_output_peg_parsers() { - printf("[%s]\n", __func__); + // DeepSeek R1 Distill Llama 8B - reasoning tests only (forced open thinking) + // Note: Template uses forced-open mode (prompt ends with ), so input shouldn't include opening tag + { + auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?") + .enable_thinking(true) // Forced open + .expect(message_assist) + .run(); + tst.test("I'm\nthinkingHello, world!\nWhat's up?") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .expect(message_assist_thoughts) + .run(); + } + // llama-cpp DeepSeek R1 template (always forced-open thinking) + { + auto tst = peg_tester("models/templates/llama-cpp-deepseek-r1.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("I'm\nthinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .expect(message_assist_thoughts) + .run(); + tst.test( + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" + "```json\n{\"arg1\": 1}```<|tool▁call▁end|><|tool▁calls▁end|>") + .tools({ special_function_tool }) + .parallel_tool_calls(true) + .expect(message_assist_call) + .run(); + } + // DeepSeek R1 Distill Qwen 32B - reasoning tests only (forced open thinking) + // Note: Template uses forced-open mode (prompt ends with ), so input shouldn't include opening tag + { + auto tst = peg_tester("models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?").enable_thinking(true).expect(message_assist).run(); + tst.test("I'm\nthinkingHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .expect(message_assist_thoughts) + .run(); + tst.test( + "<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>special_function\n" + "```json\n{\"arg1\": 1}```<|tool▁call▁end|><|tool▁calls▁end|>") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + } + + // MiMo-VL / Hermes 3 / Qwen 2.5 (Common JSON format) + for (const auto & path : + { "models/templates/MiMo-VL.jinja", "models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", + "models/templates/Qwen-Qwen2.5-7B-Instruct.jinja" }) { + auto tst = peg_tester(path, detailed_debug); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("\n{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}\n") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + } + + // Apriel 1.5 + { + auto tst = peg_tester("models/templates/unsloth-Apriel-1.5.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("[{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}}]") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + } - // JSON schemas - const char * invoice_schema = R"({ - "type": "object", - "properties": { - "amount": {"type": "number"}, - "date": {"type": "string"} - } - })"; + // Apriel 1.6 Thinker (reasoning-only support) + { + auto tst = peg_tester("models/templates/Apriel-1.6-15b-Thinker-fixed.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + + // Implicit reasoning start (forced open) + tst.test("I'm\nthinking\n[BEGIN FINAL RESPONSE]\nHello, world!\nWhat's up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .expect(message_assist_thoughts) + .run(); + + // Reasoning + Tool calls + tst.test( + "I'm\nthinking\n[BEGIN FINAL RESPONSE]\n[{\"name\": \"special_function\", \"arguments\": " + "{\"arg1\": 1}}]") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ special_function_tool }) + .expect(message_assist_call_thoughts) + .run(); + } + + // Mistral Small 3.2 - FUNC_BRACKET_TAG format: [TOOL_CALLS]func_name[CALL_ID]id[ARGS]{...} + { + auto tst = peg_tester("models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("[TOOL_CALLS]special_function[CALL_ID]123456789[ARGS]{\"arg1\": 1}") + .tools({ special_function_tool }) + .expect(message_assist_call_id) + .run(); + } + // Devstral + { + auto tst = peg_tester("models/templates/unsloth-mistral-Devstral-Small-2507.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?").expect(message_assist).run(); + tst.test("[TOOL_CALLS]special_function[ARGS]{\"arg1\": 1}") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + tst.test("Hello, world!\nWhat's up?[TOOL_CALLS]special_function[ARGS]{\"arg1\": 1}") + .tools({ special_function_tool }) + .expect(message_assist_call_content) + .run(); + } { - // Ministral-3-14B-Reasoning-2512 - auto tmpls = read_templates("models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja"); - - // Test basic message - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "Hello, world!\nWhat's up?"; - t.expect = message_assist; - }); - - // Test basic message and reasoning with reasoning_format = none - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?"; - t.expect.content = "[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?"; - }); - - // Test basic message and reasoning with reasoning_format = auto - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "[THINK]I'm\nthinking[/THINK]Hello, world!\nWhat's up?"; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - - t.expect = message_assist_thoughts; - }); - - // Test tool call - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = R"([TOOL_CALLS]special_function[ARGS]{"arg1":1})"; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - t.params.tools = {special_function_tool}; - - t.expect = message_assist_call; - }); - - // Test tool call with reasoning - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "[THINK]I'm\nthinking[/THINK]" - R"([TOOL_CALLS]special_function[ARGS]{"arg1":1})"; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - t.params.tools = {special_function_tool}; - - t.expect = message_assist_call_thoughts; - }); - - // Test parallel tool calls - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = R"([TOOL_CALLS]special_function[ARGS]{"arg1": 1})" - R"([TOOL_CALLS]special_function_with_opt[ARGS]{"arg1": 1, "arg2": 2})"; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - t.params.parallel_tool_calls = true; - t.params.tools = {special_function_tool, special_function_tool_with_optional_param}; - - t.expect.tool_calls = {{ - /* .name = */ "special_function", - /* .arguments = */ R"({"arg1": 1})", - /* .id = */ {}, - }, { - /* .name = */ "special_function_with_opt", - /* .arguments = */ R"({"arg1": 1, "arg2": 2})", - /* .id = */ {}, - }}; - }); - - // Test response format - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "[THINK]I need to output the invoice details in JSON[/THINK]" - "```json\n" - R"({"amount": 123.45, "date": "2025-12-03"})" - "\n```"; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - t.params.json_schema = invoice_schema; - - t.expect.reasoning_content = "I need to output the invoice details in JSON"; - t.expect.content =R"({"amount": 123.45, "date": "2025-12-03"})"; - }); + // Llama 3.1 + auto tst = peg_tester("models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).run(); } { - // Qwen3-Coder - auto tmpls = read_templates("models/templates/Qwen3-Coder.jinja"); - - // Test basic message - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "Hello, world!\nWhat's up?"; - t.expect = message_assist; - }); - - // Test tool call - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "\n" - "\n" - "\n" - "1\n" - "\n" - "\n" - ""; - t.params.tools = {special_function_tool}; - t.expect = message_assist_call; - }); - - // Test parallel tool calls - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "\n" - "\n" - "\n" - "1\n" - "\n" - "\n" - "\n" - "\n" - "\n" - "\n" - "1\n" - "\n" - "\n" - "2\n" - "\n" - "\n" - ""; - t.params.parallel_tool_calls = true; - t.params.tools = {special_function_tool, special_function_tool_with_optional_param}; - - t.expect.tool_calls = {{ - /* .name = */ "special_function", - /* .arguments = */ R"({"arg1": 1})", - /* .id = */ {}, - }, { - /* .name = */ "special_function_with_opt", - /* .arguments = */ R"({"arg1": 1, "arg2": 2})", - /* .id = */ {}, - }}; - }); - - // Test tool call with string parameter - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "\n" - "\n" - "\n" - "def hello():\n" - " print(\"Hello, world!\")\n" - "\n" - "hello()\n" - "\n" - "\n" - ""; - t.params.tools = {python_tool}; - - t.expect.tool_calls = {{ - /* .name = */ "python", - /* .arguments = */ "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", - /* .id = */ {}, - }}; - }); - - // Test tool call with JSON parameter - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "\n" - "\n" - "\n" - "[{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]\n" - "\n" - "\n" - ""; - t.params.tools = {todo_list_tool}; - - t.expect.tool_calls = {{ - /* .name = */ "todo_list", - /* .arguments = */ "{\"todos\": [{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]}", - /* .id = */ {}, - }}; - }); - - // Test tool call with string parameter and no closing tag - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "\n" - "\n" - "\n" - "def hello():\n" - " print(\"Hello, world!\")\n" - "\n" - "hello()\n" - "\n" - ""; - t.params.tools = {python_tool}; - - t.expect.tool_calls = {{ - /* .name = */ "python", - /* .arguments = */ "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", - /* .id = */ {}, - }}; - }); - - // Test response format - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = R"({"amount": 123.45, "date": "2025-12-03"})"; - t.params.json_schema = invoice_schema; - - t.expect.content = R"({"amount": 123.45, "date": "2025-12-03"})"; - }); + // Llama 3.2 + auto tst = peg_tester("models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).run(); } { - // NVIDIA Nemotron-3 Nano - auto tmpls = read_templates("models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja"); - - // Test basic message - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "Hello, world!\nWhat's up?"; - t.expect = message_assist; - }); - - // Test basic message and reasoning with reasoning_format = none - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "I'm\nthinking\n\nHello, world!\nWhat's up?"; - t.expect.content = "I'm\nthinking\n\nHello, world!\nWhat's up?"; - }); - - // Test basic message and reasoning with reasoning_format = auto - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "I'm\nthinking\n\nHello, world!\nWhat's up?"; - t.params.enable_thinking = true; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - - t.expect = message_assist_thoughts; - }); - - // Test tool call - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "\n" - "\n" - "\n" - "1\n" - "\n" - "\n" - ""; - t.params.enable_thinking = false; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - t.params.tools = {special_function_tool}; - - t.expect = message_assist_call; - }); - - // Test tool call with reasoning - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "I'm\nthinking\n\n" - "\n" - "\n" - "\n" - "1\n" - "\n" - "\n" - ""; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - t.params.tools = {special_function_tool}; - - t.expect = message_assist_call_thoughts; - }); - - // Test parallel tool calls - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "\n" - "\n" - "\n" - "1\n" - "\n" - "\n" - "\n" - "\n" - "\n" - "\n" - "1\n" - "\n" - "\n" - "2\n" - "\n" - "\n" - ""; - t.params.enable_thinking = false; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - t.params.parallel_tool_calls = true; - t.params.tools = {special_function_tool, special_function_tool_with_optional_param}; - - t.expect.tool_calls = {{ - /* .name = */ "special_function", - /* .arguments = */ R"({"arg1": 1})", - /* .id = */ {}, - }, { - /* .name = */ "special_function_with_opt", - /* .arguments = */ R"({"arg1": 1, "arg2": 2})", - /* .id = */ {}, - }}; - }); - - // Test tool call with string parameter - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "\n" - "\n" - "\n" - "def hello():\n" - " print(\"Hello, world!\")\n" - "\n" - "hello()\n" - "\n" - "\n" - ""; - t.params.enable_thinking = false; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - t.params.tools = {python_tool}; - - t.expect.tool_calls = {{ - /* .name = */ "python", - /* .arguments = */ "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", - /* .id = */ {}, - }}; - }); - - // Test tool call with string parameter and no closing tag - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "\n" - "\n" - "\n" - "def hello():\n" - " print(\"Hello, world!\")\n" - "\n" - "hello()\n" - "\n" - ""; - t.params.enable_thinking = false; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - t.params.tools = {python_tool}; - - t.expect.tool_calls = {{ - /* .name = */ "python", - /* .arguments = */ "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", - /* .id = */ {}, - }}; - }); - - // Test response format - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "I need to output the invoice details in JSON\n" - "\n" - R"({"amount": 123.45, "date": "2025-12-03"})"; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - t.params.json_schema = invoice_schema; - - t.expect.reasoning_content = "I need to output the invoice details in JSON"; - t.expect.content = R"({"amount": 123.45, "date": "2025-12-03"})"; - }); + // Llama 3.3 + auto tst = peg_tester("models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja", detailed_debug); + tst.test("Hello, world!\nWhat's up?").tools({ python_tool }).expect(message_assist).run(); } + // GPT-OSS format tests { - // Step-3.5-Flash (uses Nemotron v3 PEG parser with thinking_forced_open) - // Unlike Nemotron, Step-3.5-Flash always emits regardless of enable_thinking, - // so all inputs must include a delimiter. - auto tmpls = read_templates("models/templates/stepfun-ai-Step-3.5-Flash.jinja"); - - // Test basic message with reasoning - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "I'm\nthinking\n\nHello, world!\nWhat's up?"; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - - t.expect = message_assist_thoughts; - }); - - // Test basic message without thinking content - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "\nHello, world!\nWhat's up?"; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - - t.expect = message_assist; - }); - - // Test tool call without thinking content - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "\n" - "\n" - "\n" - "\n" - "1\n" - "\n" - "\n" - ""; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - t.params.tools = {special_function_tool}; - - t.expect = message_assist_call; - }); - - // Test tool call with thinking - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "I'm\nthinking\n\n" - "\n" - "\n" - "\n" - "1\n" - "\n" - "\n" - ""; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - t.params.tools = {special_function_tool}; - - t.expect = message_assist_call_thoughts; - }); - - // Test parallel tool calls with thinking - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "I'm\nthinking\n\n" - "\n" - "\n" - "\n" - "1\n" - "\n" - "\n" - "\n" - "\n" - "\n" - "\n" - "1\n" - "\n" - "\n" - "2\n" - "\n" - "\n" - ""; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - t.params.parallel_tool_calls = true; - t.params.tools = {special_function_tool, special_function_tool_with_optional_param}; - - t.expect.reasoning_content = "I'm\nthinking"; - t.expect.tool_calls = {{ - /* .name = */ "special_function", - /* .arguments = */ R"({"arg1": 1})", - /* .id = */ {}, - }, { - /* .name = */ "special_function_with_opt", - /* .arguments = */ R"({"arg1": 1, "arg2": 2})", - /* .id = */ {}, - }}; - }); - - // Test parallel tool calls without thinking content - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "\n" - "\n" - "\n" - "\n" - "1\n" - "\n" - "\n" - "\n" - "\n" - "\n" - "\n" - "1\n" - "\n" - "\n" - "2\n" - "\n" - "\n" - ""; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - t.params.parallel_tool_calls = true; - t.params.tools = {special_function_tool, special_function_tool_with_optional_param}; - - t.expect.tool_calls = {{ - /* .name = */ "special_function", - /* .arguments = */ R"({"arg1": 1})", - /* .id = */ {}, - }, { - /* .name = */ "special_function_with_opt", - /* .arguments = */ R"({"arg1": 1, "arg2": 2})", - /* .id = */ {}, - }}; - }); - - // Test tool call with code string parameter - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "\n" - "\n" - "\n" - "\n" - "def hello():\n" - " print(\"Hello, world!\")\n" - "\n" - "hello()\n" - "\n" - "\n" - ""; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - t.params.tools = {python_tool}; - - t.expect.tool_calls = {{ - /* .name = */ "python", - /* .arguments = */ "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", - /* .id = */ {}, - }}; - }); - - // Test tool call with string parameter and no closing tag - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "\n" - "\n" - "\n" - "\n" - "def hello():\n" - " print(\"Hello, world!\")\n" - "\n" - "hello()\n" - "\n" - ""; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - t.params.tools = {python_tool}; - - t.expect.tool_calls = {{ - /* .name = */ "python", - /* .arguments = */ "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", - /* .id = */ {}, - }}; - }); - - // Test response format (JSON schema with thinking) - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = - "I need to output the invoice details in JSON\n" - "\n" - R"({"amount": 123.45, "date": "2025-12-03"})"; - t.params.reasoning_format = COMMON_REASONING_FORMAT_AUTO; - t.params.json_schema = invoice_schema; - - t.expect.reasoning_content = "I need to output the invoice details in JSON"; - t.expect.content = R"({"amount": 123.45, "date": "2025-12-03"})"; - }); + auto tst = peg_tester("models/templates/openai-gpt-oss-120b.jinja", detailed_debug); + + // Basic content only - final channel + tst.test("<|channel|>final<|message|>Hello, world!\nWhat's up?").expect(message_assist).run(); + + // Basic content only - commentary channel + tst.test("<|channel|>commentary<|message|>Hello, world!\nWhat's up?").expect(message_assist).run(); + + // Analysis channel (reasoning) with final channel (content) + tst.test( + "<|channel|>analysis<|message|>I'm\nthinking<|end|>\n<|channel|>final<|message|>Hello, world!\nWhat's " + "up?") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .expect(message_assist_thoughts) + .run(); + + // Analysis channel only (partial) - still works when reasoning format is set + tst.test("<|channel|>analysis<|message|>I'm\nthinking") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .is_partial(true) + .expect_reasoning("I'm\nthinking") + .run(); + + // Reasoning format none - reasoning stays in content + tst.test( + "<|channel|>analysis<|message|>I'm\nthinking<|end|>\n<|channel|>final<|message|>Hello, world!\nWhat's " + "up?") + .reasoning_format(COMMON_REASONING_FORMAT_NONE) + .expect_content( + "<|channel|>analysis<|message|>I'm\nthinking<|end|>Hello, world!\nWhat's up?") + .run(); + + // Tool call with recipient in role header: " to=functions.NAME<|channel|>analysis<|message|>JSON" + tst.test(" to=functions.special_function<|channel|>analysis<|message|>{\"arg1\": 1}") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + + // Tool call with recipient in channel header: "<|channel|>analysis to=functions.NAME<|message|>JSON" + tst.test("<|channel|>analysis to=functions.special_function<|message|>{\"arg1\": 1}") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + + // Tool call with constraint: " to=functions.NAME<|channel|>analysis <|constrain|>json<|message|>JSON" + tst.test(" to=functions.special_function<|channel|>analysis <|constrain|>json<|message|>{\"arg1\": 1}") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + + // Tool call in commentary channel (channel header variant) + tst.test("<|channel|>commentary to=functions.special_function<|message|>{\"arg1\": 1}") + .tools({ special_function_tool }) + .expect(message_assist_call) + .run(); + + // Tool call with reasoning + content (analysis first, then tool call) + tst.test( + "<|channel|>analysis<|message|>I'm\nthinking<|end|>\n" + "<|start|>assistant to=functions.special_function<|channel|>analysis<|message|>{\"arg1\": 1}") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ special_function_tool }) + .expect(message_assist_call_thoughts) + .run(); + + // Tool calling with extra channel before + tst.test( + "<|channel|>analysis<|message|>I'm\nthinking<|end|><|start|>assistant<|channel|>commentary" + " to=functions.special_function <|message|>{\"arg1\": 1}") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ special_function_tool }) + .expect(message_assist_call_thoughts) + .run(); + + // Reasoning after final channel + // Tool calling after final channel + tst.test( + "<|channel|>final<|message|><|end|>" + "<|start|>assistant<|channel|>analysis<|message|>Thinking about edit..." + ) + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .expect_reasoning("Thinking about edit...") + .expect_content("") + .run(); + + // Tool calling after final channel + tst.test( + "<|channel|>final<|message|><|end|>" + "<|start|>assistant<|channel|>analysis<|message|>Thinking about edit...<|end|>" + "<|start|>assistant<|channel|>commentary to=functions.edit <|constrain|>json" + "<|message|>{\"oldString\": \"if (part < railCount - 1) {\", \"newString\": \"if (part < 4) {\", \"replaceAll\": false}" + ) + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ + { + /* .name = */ "edit", + /* .description = */ "Edit a file", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "oldString": { + "type": "string", + "description": "Old string to replace." + }, + "newString": { + "type": "string", + "description": "New replacement string." + }, + "replaceAll": { + "type": "boolean", + "description": "Whether to replace all occurences." + } + }, + "required": ["oldString", "newString"] + })", + } + }) + .expect_reasoning("Thinking about edit...") + .expect_tool_calls({ + { "edit", R"({"oldString": "if (part < railCount - 1) {", "newString": "if (part < 4) {", "replaceAll": false})", {} } + }) + .run(); + + // Parallel tool calls + tst.test( + " to=functions.special_function<|channel|>analysis<|message|>{\"arg1\": 1}\n" + "<|start|>assistant to=functions.special_function_with_opt<|channel|>analysis<|message|>{\"arg1\": 1, " + "\"arg2\": 2}") + .parallel_tool_calls(true) + .tools({ + special_function_tool, special_function_tool_with_optional_param + }) + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", {} }, + { "special_function_with_opt", R"({"arg1": 1, "arg2": 2})", {} }, + }) + .run(); } { - // Solar-Open-100B - auto tmpls = read_templates("models/templates/upstage-Solar-Open-100B.jinja"); - - // Test basic message - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "<|content|>Hello, world!\nWhat's up?"; - t.expect = message_assist; - }); - - // Test basic message and reasoning - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "<|think|>I'm\nthinking<|end|><|begin|>assistant<|content|>Hello, world!\nWhat's up?"; - t.expect = message_assist_thoughts; - }); - - // Test basic message and reasoning_effort = low - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "<|content|>Hello, world!\nWhat's up?"; - t.params.chat_template_kwargs["reasoning_effort"] = "\"low\""; - t.expect = message_assist; - }); - - // Test tool call - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "<|tool_calls|>" - "<|tool_call:begin|>123456789" - "<|tool_call:name|>special_function" - "<|tool_call:args|>{\"arg1\":1}" - "<|tool_call:end|>"; - - t.params.chat_template_kwargs["reasoning_effort"] = "\"low\""; - t.params.tools = {special_function_tool}; - t.expect = message_assist_call_id; - }); - - // Test tool call with reasoning - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "<|think|>I'm\nthinking<|end|>" - "<|begin|>assistant<|tool_calls|>" - "<|tool_call:begin|>0" - "<|tool_call:name|>special_function" - "<|tool_call:args|>{\"arg1\":1}" - "<|tool_call:end|>"; - - t.params.tools = {special_function_tool}; - t.expect = message_assist_thoughts_call_idx; - }); - - // Test tool call with reasoning and tool_choice = required - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "<|think|>I'm\nthinking<|end|>" - "<|begin|>assistant<|tool_calls|>" - "<|tool_call:begin|>0" - "<|tool_call:name|>special_function" - "<|tool_call:args|>{\"arg1\":1}" - "<|tool_call:end|>"; - - t.params.tools = {special_function_tool}; - t.params.tool_choice = COMMON_CHAT_TOOL_CHOICE_REQUIRED; - t.expect = message_assist_thoughts_call_idx; - }); - - // Test tool call without reasoning and tool_choice = required - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "<|tool_calls|>" - "<|tool_call:begin|>0" - "<|tool_call:name|>special_function" - "<|tool_call:args|>{\"arg1\":1}" - "<|tool_call:end|>"; - - t.params.tools = {special_function_tool}; - t.params.tool_choice = COMMON_CHAT_TOOL_CHOICE_REQUIRED; - t.params.chat_template_kwargs["reasoning_effort"] = "\"low\""; - t.expect = message_assist_call_idx; - }); - - // Test parallel tool calls - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "<|think|>I'm\nthinking<|end|>" - "<|begin|>assistant<|tool_calls|>" - "<|tool_call:begin|>0" - "<|tool_call:name|>special_function" - "<|tool_call:args|>{\"arg1\":1}" - "<|tool_call:end|>" - "<|tool_call:begin|>1" - "<|tool_call:name|>special_function_with_opt" - "<|tool_call:args|>{\"arg1\": 1, \"arg2\": 2}" - "<|tool_call:end|>"; - - t.params.parallel_tool_calls = true; - t.params.tools = {special_function_tool, special_function_tool_with_optional_param}; - - t.expect.reasoning_content = "I'm\nthinking"; - t.expect.tool_calls = {{ - /* .name = */ "special_function", - /* .arguments = */ R"({"arg1": 1})", - /* .id = */ "0", - }, { - /* .name = */ "special_function_with_opt", - /* .arguments = */ R"({"arg1": 1, "arg2": 2})", - /* .id = */ "1", - }}; - }); - - // Test response format - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "<|think|>I need to output the invoice details in JSON<|end|>" - "<|begin|>assistant<|content|>" - R"({"amount": 123.45, "date": "2025-12-03"})"; - - t.params.json_schema = invoice_schema; - - t.expect.reasoning_content = "I need to output the invoice details in JSON"; - t.expect.content =R"({"amount": 123.45, "date": "2025-12-03"})"; - }); - - // Test response format no reasoning - test_peg_parser(tmpls.get(), [&](auto & t) { - t.input = "<|content|>" - R"({"amount": 123.45, "date": "2025-12-03"})"; - - t.params.chat_template_kwargs["reasoning_effort"] = "\"low\""; - t.params.json_schema = invoice_schema; - - t.expect.content =R"({"amount": 123.45, "date": "2025-12-03"})"; - }); + auto tst = peg_tester("models/templates/StepFun3.5-Flash.jinja", detailed_debug); + tst.test("I was thinking\nNow I'm not."). + enable_thinking(true). + reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK). + expect_reasoning("I was thinking"). + expect_content("Now I'm not.") + .run(); + + // Test that numeric-looking string values are coerced to strings per the schema + tst.test( + "Let me call the magic tool\n" + "\n" + "\n" + "\n" + "\nfooBar\n\n" + "\n5123123\n\n" + "\n" + "") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ magic_tool }) + .expect_reasoning("Let me call the magic tool") + .expect_tool_calls({ + { "magic", R"({"name": "fooBar", "ref": "5123123"})", {} }, + }) + .run(); + + // Test that numeric values are correctly interpreted as numbers when schema calls for number + tst.test( + "Let me call the special function\n" + "\n" + "\n" + "\n" + "\n42555916\n\n" + "\n" + "") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ special_function_tool }) + .expect_reasoning("Let me call the special function") + .expect_tool_calls({ + { "special_function", R"({"arg1": 42555916})", {} }, + }) + .run(); + + tst.test( + "Let me call the special function with opt\n" + "\n" + "\n" + "\n" + "\n42555916\n\n" + "\n" + "") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ special_function_tool_with_optional_param }) + .expect_reasoning("Let me call the special function with opt") + .expect_tool_calls({ + { "special_function_with_opt", R"({"arg1": 42555916})", {} }, + }) + .run(); + + tst.test( + "Let me call the magic_int function\n" + "\n" + "\n" + "\n" + "\n42555916\n\n" + "\nbaz\n\n" + "\n" + "") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ magic_int_tool }) + .expect_reasoning("Let me call the magic_int function") + .expect_tool_calls({ + { "magic_int", R"({"ref": 42555916, "name": "baz"})", {} }, + }) + .run(); + + tst.test( + "Call string_param with empty text\n" + "\n" + "\n" + "\n" + "\n\n\n" + "\n" + "") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ string_param_tool }) + .expect_reasoning("Call string_param with empty text") + .expect_tool_calls({ + { "string_param", R"({"text": ""})", {} }, + }) + .run(); + + tst.test( + "Test simple quoted unquoted\n" + "\n" + "\n" + "\n" + "\n\"foo\"\n\n" + "\nfoo\n\n" + "\n" + "") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ quoted_unquoted_tool }) + .expect_reasoning("Test simple quoted unquoted") + .expect_tool_calls({ + { "quoted_unquoted", R"({"quoted": "\"foo\"", "unquoted": "foo"})", {} }, + }) + .run(); + + tst.test( + "Test complex quoted unquoted\n" + "\n" + "\n" + "\n" + "\n\"printf(\\\"foo\\\");\"\n\n" + "\nprintf(\"foo\");\n\n" + "\n" + "") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ quoted_unquoted_tool }) + .expect_reasoning("Test complex quoted unquoted") + .expect_tool_calls({ + { "quoted_unquoted", R"({ "quoted" : "\"printf(\\\"foo\\\");\"", "unquoted": "printf(\"foo\");" })", {} } + }) + .run(); + + tst.test( + "Test negative number\n" + "\n" + "\n" + "\n" + "\n-14\n\n" + "\n" + "") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ magic_int_tool }) + .expect_reasoning("Test negative number") + .expect_tool_calls({ + { "magic_int", R"({ "ref" : -14 })", {} } + }) + .run(); + + tst.test( + "Test decimal number\n" + "\n" + "\n" + "\n" + "\n3.14\n\n" + "\n" + "") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ amount_tool }) + .expect_reasoning("Test decimal number") + .expect_tool_calls({ + { "amount", R"({ "orig" : 3.14 })", {} } + }) + .run(); + + tst.test( + "Test imaginary number\n" + "\n" + "\n" + "\n" + "\n" + "{ \"real\": 3.14, \"imaginary\": 2.71 }\n" + "\n" + "\n" + "") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_DEEPSEEK) + .tools({ imaginary_number_tool }) + .expect_reasoning("Test imaginary number") + .expect_tool_calls({ + { "imaginary_number", R"({ "number" : {"real":3.14,"imaginary":2.71 } })", {} } + }) + .run(); + } } static void test_msg_diffs_compute() { - printf("[%s]\n", __func__); + LOG_DBG("%s\n", __func__); { common_chat_msg msg1; @@ -3759,9 +2960,7 @@ static void test_msg_diffs_compute() { common_chat_msg_diff diff; diff.content_delta = "Hello, world!"; - assert_equals( - {diff}, - common_chat_msg_diff::compute_diffs(msg1, msg2)); + assert_equals({ diff }, common_chat_msg_diff::compute_diffs(msg1, msg2)); } { common_chat_msg msg1; @@ -3773,37 +2972,35 @@ static void test_msg_diffs_compute() { common_chat_msg_diff diff; diff.content_delta = " world!"; - assert_equals( - {diff}, - common_chat_msg_diff::compute_diffs(msg1, msg2)); + assert_equals({ diff }, common_chat_msg_diff::compute_diffs(msg1, msg2)); } { common_chat_msg msg0; common_chat_msg msg1; - msg1.tool_calls = { { "special_function", "{\"ar", /* .id = */ "123" } }; + msg1.tool_calls = { + { "special_function", "{\"ar", /* .id = */ "123" } + }; common_chat_msg msg2; - msg2.tool_calls = { { "special_function", "{\"arg1\": 1}", /* .id = */ "123" } }; + msg2.tool_calls = { + { "special_function", "{\"arg1\": 1}", /* .id = */ "123" } + }; common_chat_msg_diff diff01; - diff01.tool_call_index = 0; - diff01.tool_call_delta.name = "special_function"; - diff01.tool_call_delta.id = "123"; + diff01.tool_call_index = 0; + diff01.tool_call_delta.name = "special_function"; + diff01.tool_call_delta.id = "123"; diff01.tool_call_delta.arguments = "{\"ar"; - assert_equals( - {diff01}, - common_chat_msg_diff::compute_diffs(msg0, msg1)); + assert_equals({ diff01 }, common_chat_msg_diff::compute_diffs(msg0, msg1)); common_chat_msg_diff diff12; - diff12.tool_call_index = 0; + diff12.tool_call_index = 0; // Note: neither id nor name change here. diff12.tool_call_delta.arguments = "g1\": 1}"; - assert_equals( - {diff12}, - common_chat_msg_diff::compute_diffs(msg1, msg2)); + assert_equals({ diff12 }, common_chat_msg_diff::compute_diffs(msg1, msg2)); } { common_chat_msg msg0; @@ -3815,68 +3012,81 @@ static void test_msg_diffs_compute() { }; common_chat_msg_diff diff1; - diff1.tool_call_index = 0; - diff1.tool_call_delta.name = "f1"; - diff1.tool_call_delta.id = "123"; + diff1.tool_call_index = 0; + diff1.tool_call_delta.name = "f1"; + diff1.tool_call_delta.id = "123"; diff1.tool_call_delta.arguments = "{\"arg1\": 1}"; common_chat_msg_diff diff2; - diff2.tool_call_index = 1; - diff2.tool_call_delta.name = "f2"; - diff2.tool_call_delta.id = "222"; + diff2.tool_call_index = 1; + diff2.tool_call_delta.name = "f2"; + diff2.tool_call_delta.id = "222"; diff2.tool_call_delta.arguments = "{\"arg2\": 2}"; - assert_equals( - {diff1, diff2}, - common_chat_msg_diff::compute_diffs(msg0, msg2)); + assert_equals({ diff1, diff2 }, common_chat_msg_diff::compute_diffs(msg0, msg2)); } } int main(int argc, char ** argv) { - common_log_set_verbosity_thold(999); + bool detailed_debug = false; + bool only_run_filtered = false; + + // Check for --template flag + for (int i = 1; i < argc; i++) { + std::string arg = argv[i]; + if (arg == "--template" && i + 1 < argc) { + g_template_filter = argv[++i]; + // Only run PEG parser tests with the filter + only_run_filtered = true; + } + if (arg == "--detailed") { + detailed_debug = true; + common_log_set_verbosity_thold(999); + } + } + + if (only_run_filtered) { + test_template_output_peg_parsers(detailed_debug); + std::cout << "\n[chat] All template tests passed!" << '\n'; + return 0; + } - // try { #ifndef _WIN32 - if (argc > 1) { - common_chat_templates_inputs inputs; - common_chat_msg msg; - msg.role = "user"; - msg.content = "Hey"; - inputs.messages = {msg}; - inputs.tools = { special_function_tool }; - - std::cout << "| Template | Format |\n"; - std::cout << "|----------|--------|\n"; - - for (int i = 1; i < argc; i++) { - try { - std::string path = argv[i]; - if (path.rfind(".jinja") != path.size() - 6) { - std::cerr << "Skipping non-jinja file: " << path << '\n'; - continue; - } - auto tmpls = read_templates(path); - auto parts = string_split(path, "/"); - auto name = parts[parts.size() - 1]; - auto format = common_chat_format_name(common_chat_templates_apply(tmpls.get(), inputs).format); - std::cout << "| " << name << " | " << format << " |\n"; - } catch (const std::exception & e) { - std::cerr << "Failed to process " << argv[i] << ": " << e.what() << '\n'; + if (argc > 1) { + common_chat_templates_inputs inputs; + common_chat_msg msg; + msg.role = "user"; + msg.content = "Hey"; + inputs.messages = { msg }; + inputs.tools = { special_function_tool }; + + std::cout << "| Template | Format |\n"; + std::cout << "|----------|--------|\n"; + + for (int i = 1; i < argc; i++) { + try { + std::string path = argv[i]; + if (path.rfind(".jinja") != path.size() - 6) { + std::cerr << "Skipping non-jinja file: " << path << '\n'; + continue; } + auto tmpls = read_templates(path); + auto parts = string_split(path, "/"); + const auto & name = parts[parts.size() - 1]; + const auto * format = common_chat_format_name(common_chat_templates_apply(tmpls.get(), inputs).format); + std::cout << "| " << name << " | " << format << " |\n"; + } catch (const std::exception & e) { + std::cerr << "Failed to process " << argv[i] << ": " << e.what() << '\n'; } - } else -#endif - { - test_msg_diffs_compute(); - test_msgs_oaicompat_json_conversion(); - test_tools_oaicompat_json_conversion(); - test_template_output_parsers(); - test_template_output_peg_parsers(); - std::cout << "\n[chat] All tests passed!" << '\n'; } - return 0; - // } catch (const std::exception & e) { - // std::cerr << "Error: " << e.what() << '\n'; - // return 1; - // } + } else +#endif + { + test_msg_diffs_compute(); + test_msgs_oaicompat_json_conversion(); + test_tools_oaicompat_json_conversion(); + test_template_output_peg_parsers(detailed_debug); + std::cout << "\n[chat] All tests passed!" << '\n'; + } + return 0; } diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index a8e9ff33a43..eb33804c9a7 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -1340,6 +1340,26 @@ static void test_all(const std::string & lang, std::function +#include #include #include #include @@ -15,6 +16,8 @@ static void print_usage(int, char ** argv) { } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + common_params params; if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_BENCH, print_usage)) { diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index e57bf52e36c..d43d1054907 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -1,3 +1,4 @@ +#include "chat.h" #include "common.h" #include "arg.h" #include "console.h" @@ -6,7 +7,10 @@ #include "server-context.h" #include "server-task.h" +#include #include +#include +#include #include #include #include @@ -188,13 +192,130 @@ struct cli_context { inputs.use_jinja = chat_params.use_jinja; inputs.parallel_tool_calls = false; inputs.add_generation_prompt = true; - inputs.enable_thinking = chat_params.enable_thinking; + inputs.reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; + inputs.enable_thinking = common_chat_templates_support_enable_thinking(chat_params.tmpls.get()); // Apply chat template to the list of messages return common_chat_templates_apply(chat_params.tmpls.get(), inputs); } }; +// TODO?: Make this reusable, enums, docs +static const std::array cmds = { + "/audio ", + "/clear", + "/exit", + "/image ", + "/read ", + "/regen", +}; + +static std::vector> auto_completion_callback(std::string_view line, size_t cursor_byte_pos) { + std::vector> matches; + std::string cmd; + + if (line.length() > 1 && line[0] == '/' && !std::any_of(cmds.begin(), cmds.end(), [line](const std::string & prefix) { + return string_starts_with(line, prefix); + })) { + auto it = cmds.begin(); + + while ((it = std::find_if(it, cmds.end(), [line](const std::string & cmd_line) { + return string_starts_with(cmd_line, line); + })) != cmds.end()) { + matches.emplace_back(*it, (*it).length()); + ++it; + } + } else { + auto it = std::find_if(cmds.begin(), cmds.end(), [line](const std::string & prefix) { + return prefix.back() == ' ' && string_starts_with(line, prefix); + }); + + if (it != cmds.end()) { + cmd = *it; + } + } + + if (!cmd.empty() && line.length() >= cmd.length() && cursor_byte_pos >= cmd.length()) { + const std::string path_prefix = std::string(line.substr(cmd.length(), cursor_byte_pos - cmd.length())); + const std::string path_postfix = std::string(line.substr(cursor_byte_pos)); + auto cur_dir = std::filesystem::current_path(); + std::string cur_dir_str = cur_dir.string(); + std::string expanded_prefix = path_prefix; + +#if !defined(_WIN32) + if (string_starts_with(path_prefix, "~")) { + const char * home = std::getenv("HOME"); + if (home && home[0]) { + expanded_prefix = std::string(home) + path_prefix.substr(1); + } + } + if (string_starts_with(expanded_prefix, "/")) { +#else + if (std::isalpha(expanded_prefix[0]) && expanded_prefix.find(':') == 1) { +#endif + cur_dir = std::filesystem::path(expanded_prefix).parent_path(); + cur_dir_str = ""; + } else if (!path_prefix.empty()) { + cur_dir /= std::filesystem::path(path_prefix).parent_path(); + } + + std::error_code ec; + for (const auto & entry : std::filesystem::directory_iterator(cur_dir, ec)) { + if (ec) { + break; + } + if (!entry.exists(ec)) { + ec.clear(); + continue; + } + + const std::string path_full = entry.path().string(); + std::string path_entry = !cur_dir_str.empty() && string_starts_with(path_full, cur_dir_str) ? path_full.substr(cur_dir_str.length() + 1) : path_full; + + if (entry.is_directory(ec)) { + path_entry.push_back(std::filesystem::path::preferred_separator); + } + + if (expanded_prefix.empty() || string_starts_with(path_entry, expanded_prefix)) { + std::string updated_line = cmd + path_entry; + matches.emplace_back(updated_line + path_postfix, updated_line.length()); + } + + if (ec) { + ec.clear(); + } + } + + if (matches.empty()) { + std::string updated_line = cmd + path_prefix; + matches.emplace_back(updated_line + path_postfix, updated_line.length()); + } + + // Add the longest common prefix + if (!expanded_prefix.empty() && matches.size() > 1) { + const std::string_view match0(matches[0].first); + const std::string_view match1(matches[1].first); + auto it = std::mismatch(match0.begin(), match0.end(), match1.begin(), match1.end()); + size_t len = it.first - match0.begin(); + + for (size_t i = 2; i < matches.size(); ++i) { + const std::string_view matchi(matches[i].first); + auto cmp = std::mismatch(match0.begin(), match0.end(), matchi.begin(), matchi.end()); + len = std::min(len, static_cast(cmp.first - match0.begin())); + } + + std::string updated_line = std::string(match0.substr(0, len)); + matches.emplace_back(updated_line + path_postfix, updated_line.length()); + } + + std::sort(matches.begin(), matches.end(), [](const auto & a, const auto & b) { + return a.first.compare(0, a.second, b.first, 0, b.second) < 0; + }); + } + + return matches; +} + int main(int argc, char ** argv) { common_params params; @@ -223,6 +344,7 @@ int main(int argc, char ** argv) { atexit([]() { console::cleanup(); }); console::set_display(DISPLAY_TYPE_RESET); + console::set_completion_callback(auto_completion_callback); #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) struct sigaction sigint_action; @@ -262,12 +384,15 @@ int main(int argc, char ** argv) { modalities += ", audio"; } - if (!params.system_prompt.empty()) { - ctx_cli.messages.push_back({ - {"role", "system"}, - {"content", params.system_prompt} - }); - } + auto add_system_prompt = [&]() { + if (!params.system_prompt.empty()) { + ctx_cli.messages.push_back({ + {"role", "system"}, + {"content", params.system_prompt} + }); + } + }; + add_system_prompt(); console::log("\n"); console::log("%s\n", LLAMA_ASCII_LOGO); @@ -357,6 +482,8 @@ int main(int argc, char ** argv) { } } else if (string_starts_with(buffer, "/clear")) { ctx_cli.messages.clear(); + add_system_prompt(); + ctx_cli.input_files.clear(); console::log("Chat history cleared.\n"); continue; diff --git a/tools/completion/README.md b/tools/completion/README.md index bcc08876592..f868c2c7d7d 100644 --- a/tools/completion/README.md +++ b/tools/completion/README.md @@ -480,7 +480,7 @@ Example usage: `--mirostat 2 --mirostat-lr 0.05 --mirostat-ent 3.0` Exclude Top Choices (XTC) is a unique sampler that is designed to remove top tokens from consideration and avoid more obvious and repetitive outputs. With a chance of `xtc-probability` it searches for tokens with probabilities of `xtc-threshold` and above, then removes all such tokens except the least probable one. -By removing top tokens XTC can improve the variety of answers, break writing clichés and inhibit repition, since clichés and repeated phrases are usually more likely to appear. By keeping the last token above the threshold, XTC ensures that the answer is still coherent. XTC is meant to be used for creative tasks, but feel free to experiment with different settings for different models. +By removing top tokens XTC can improve the variety of answers, break writing clichés and inhibit repetition, since clichés and repeated phrases are usually more likely to appear. By keeping the last token above the threshold, XTC ensures that the answer is still coherent. XTC is meant to be used for creative tasks, but feel free to experiment with different settings for different models. Being experimental and unique, XTC is disabled by default. The recommended combination of samplers is Min-P followed by XTC on its default settings: `--sampling-seq mx --min-p 0.02 --xtc-probability 0.5`. @@ -531,7 +531,7 @@ These options help improve the performance and memory usage of the LLaMA models. ### NUMA support -- `--numa distribute`: Pin an equal proportion of the threads to the cores on each NUMA node. This will spread the load amongst all cores on the system, utilitizing all memory channels at the expense of potentially requiring memory to travel over the slow links between nodes. +- `--numa distribute`: Pin an equal proportion of the threads to the cores on each NUMA node. This will spread the load amongst all cores on the system, utilizing all memory channels at the expense of potentially requiring memory to travel over the slow links between nodes. - `--numa isolate`: Pin all threads to the NUMA node that the program starts on. This limits the number of cores and amount of memory that can be used, but guarantees all memory access remains local to the NUMA node. - `--numa numactl`: Pin threads to the CPUMAP that is passed to the program by starting it with the numactl utility. This is the most flexible mode, and allow arbitrary core usage patterns, for example a map that uses all the cores on one NUMA nodes, and just enough cores on a second node to saturate the inter-node memory bus. diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp index aed2c0e38fb..2e0f0871847 100644 --- a/tools/completion/completion.cpp +++ b/tools/completion/completion.cpp @@ -6,6 +6,7 @@ #include "llama.h" #include "chat.h" +#include #include #include #include @@ -84,6 +85,8 @@ static void sigint_handler(int signo) { #endif int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + common_params params; g_params = ¶ms; @@ -376,7 +379,7 @@ int main(int argc, char ** argv) { // remove any "future" tokens that we might have inherited from the previous session if (session_tokens.size() > n_match) { if (!llama_memory_seq_rm(mem, -1, n_match, -1)) { - LOG_WRN("%s: unable to resuse common prefix (for example, when the memory is recurrent)\n", __func__); + LOG_WRN("%s: unable to reuse common prefix (for example, when the memory is recurrent)\n", __func__); llama_memory_clear(mem, true); session_tokens.clear(); n_match = 0; diff --git a/tools/cvector-generator/cvector-generator.cpp b/tools/cvector-generator/cvector-generator.cpp index 3ba7c529506..dcce0e98418 100644 --- a/tools/cvector-generator/cvector-generator.cpp +++ b/tools/cvector-generator/cvector-generator.cpp @@ -7,6 +7,8 @@ #include "pca.hpp" #include "mean.hpp" +#include + #ifdef GGML_USE_CUDA #include "ggml-cuda.h" #endif @@ -108,7 +110,7 @@ struct callback_data { auto diff_filtered = filter_nonzero_rows(v_pos[il]); v_diff_filtered.push_back(diff_filtered); } - return v_diff_filtered; // for convinient, we return the result std::vector + return v_diff_filtered; // for convenient, we return the result std::vector } // delete zero rows from a given 2D tensor @@ -392,6 +394,8 @@ static int prepare_entries(common_params & params, train_context & ctx_train) { } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + common_params params; params.out_file = "control_vector.gguf"; diff --git a/tools/export-lora/export-lora.cpp b/tools/export-lora/export-lora.cpp index 41f426208f8..50774c59bae 100644 --- a/tools/export-lora/export-lora.cpp +++ b/tools/export-lora/export-lora.cpp @@ -5,6 +5,7 @@ #include "arg.h" #include "common.h" +#include #include #include #include @@ -411,6 +412,8 @@ static void print_usage(int, char ** argv) { } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + common_params params; params.out_file = "ggml-lora-merged-f16.gguf"; diff --git a/tools/gguf-split/gguf-split.cpp b/tools/gguf-split/gguf-split.cpp index 30e771564e8..f99f0299b9c 100644 --- a/tools/gguf-split/gguf-split.cpp +++ b/tools/gguf-split/gguf-split.cpp @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -567,6 +568,8 @@ static void gguf_merge(const split_params & split_params) { } int main(int argc, const char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + split_params params; split_params_parse(argc, argv, params); diff --git a/tools/imatrix/README.md b/tools/imatrix/README.md index 4505cb4ce8c..4cbe4fd0cf7 100644 --- a/tools/imatrix/README.md +++ b/tools/imatrix/README.md @@ -95,4 +95,4 @@ Weighted averages of Σ(Act²), ZD Score and CosSim are also calculated. #### Important note on the computed Statistics When using these statistics, please note that they are computed on the squared activations, **not on the actual (raw) activations**. -Whilst the results are still useful, they're less realiable than using the raw values, and in the case of the cosine similarity, could be misleading if the tensor contains opposite vectors. +Whilst the results are still useful, they're less reliable than using the raw values, and in the case of the cosine similarity, could be misleading if the tensor contains opposite vectors. diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp index e025c114b48..bbedb159cd4 100644 --- a/tools/imatrix/imatrix.cpp +++ b/tools/imatrix/imatrix.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -1191,6 +1192,8 @@ static bool show_statistics(const common_params & params) { } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + common_params params; params.out_file = "imatrix.gguf"; diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index 7da6c3957c7..7a750265505 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -2034,8 +2034,9 @@ static std::unique_ptr create_printer(output_formats format) { } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); // try to set locale for unicode characters in markdown - setlocale(LC_CTYPE, ".UTF-8"); + std::setlocale(LC_CTYPE, ".UTF-8"); #if !defined(NDEBUG) fprintf(stderr, "warning: asserts enabled, performance may be affected\n"); diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index a30c32ed42b..0c3cf8670a4 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -68,7 +68,7 @@ #define TN_POS_EMBD "%s.position_embd.weight" #define TN_CLASS_EMBD "v.class_embd" -#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backwrad compat +#define TN_PATCH_EMBD "v.patch_embd.weight" // not rename tensor with ".0" postfix for backward compat #define TN_PATCH_EMBD_1 "v.patch_embd.weight.1" #define TN_PATCH_BIAS "v.patch_embd.bias" #define TN_NORM_EMBD "v.norm_embd.%s" diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index e0eb9b32c8f..eeb8da58e08 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -46,7 +46,7 @@ struct clip_hparams { float image_std[3]; // for models using dynamic image size, we need to have a smaller image size to warmup - // otherwise, user will get OOM everytime they load the model + // otherwise, user will get OOM every time they load the model int32_t warmup_image_size = 0; int32_t warmup_audio_size = 3000; @@ -221,7 +221,7 @@ struct clip_model { // embeddings ggml_tensor * class_embedding = nullptr; ggml_tensor * patch_embeddings_0 = nullptr; - ggml_tensor * patch_embeddings_1 = nullptr; // second Conv2D kernel when we decouple Conv3D along temproal dimension (Qwen2VL) + ggml_tensor * patch_embeddings_1 = nullptr; // second Conv2D kernel when we decouple Conv3D along temporal dimension (Qwen2VL) ggml_tensor * patch_bias = nullptr; ggml_tensor * position_embeddings = nullptr; ggml_tensor * norm_embd_w = nullptr; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 607d4b83731..b70bad33b68 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2287,7 +2287,7 @@ static void normalize_image_u8_to_f32(const clip_image_u8 & src, clip_image_f32 } } -// set of tools to manupulate images +// set of tools to manipulate images // in the future, we can have HW acceleration by allowing this struct to access 3rd party lib like imagick or opencv struct img_tool { enum resize_algo { diff --git a/tools/mtmd/deprecation-warning.cpp b/tools/mtmd/deprecation-warning.cpp index dded0a56af9..2b31a9d8b0b 100644 --- a/tools/mtmd/deprecation-warning.cpp +++ b/tools/mtmd/deprecation-warning.cpp @@ -1,7 +1,10 @@ +#include #include #include int main(int argc, char** argv) { + std::setlocale(LC_NUMERIC, "C"); + std::string filename = "main"; if (argc >= 1) { filename = argv[0]; diff --git a/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py b/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py index 944037e703e..1f563fbfc59 100644 --- a/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py +++ b/tools/mtmd/legacy-models/minicpmv-convert-image-encoder-to-gguf.py @@ -186,7 +186,7 @@ def trunc_normal_tf_( best when :math:`a \\leq \text{mean} \\leq b`. NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 - and the result is subsquently scaled and shifted by the mean and std args. + and the result is subsequently scaled and shifted by the mean and std args. Args: tensor: an n-dimensional `torch.Tensor` mean: the mean of the normal distribution diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp index e8eef035ff5..447f61aaa40 100644 --- a/tools/mtmd/mtmd-audio.cpp +++ b/tools/mtmd/mtmd-audio.cpp @@ -560,7 +560,7 @@ bool mtmd_audio_preprocessor_whisper::preprocess(const float * s for (size_t off = 0; off < (size_t) out_full.n_len; off += frames_per_chunk) { int n_len = std::min(frames_per_chunk, (size_t) out_full.n_len - off); if ((size_t) n_len < frames_per_chunk) { - break; // last uncomplete chunk will always be a padded chunk, safe to ignore + break; // last incomplete chunk will always be a padded chunk, safe to ignore } mtmd_audio_mel out_chunk; diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index 054c7faa6af..ba00e08534f 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) #include @@ -274,6 +275,8 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) { } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + ggml_time_init(); common_params params; diff --git a/tools/parser/CMakeLists.txt b/tools/parser/CMakeLists.txt new file mode 100644 index 00000000000..55e0c634375 --- /dev/null +++ b/tools/parser/CMakeLists.txt @@ -0,0 +1,20 @@ +if (NOT WIN32 OR NOT BUILD_SHARED_LIBS) + # this tool is disabled on Windows when building with shared libraries because it uses internal functions not exported with LLAMA_API + set(TARGET llama-debug-template-parser) + add_executable(${TARGET} debug-template-parser.cpp) + target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) + target_compile_features(${TARGET} PRIVATE cxx_std_17) + + if(LLAMA_TOOLS_INSTALL) + install(TARGETS ${TARGET} RUNTIME) + endif() +endif() + +set(TARGET llama-template-analysis) +add_executable(${TARGET} template-analysis.cpp) +target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PRIVATE cxx_std_17) + +if(LLAMA_TOOLS_INSTALL) + install(TARGETS ${TARGET} RUNTIME) +endif() diff --git a/tools/parser/debug-template-parser.cpp b/tools/parser/debug-template-parser.cpp new file mode 100644 index 00000000000..ffa3a5af7ab --- /dev/null +++ b/tools/parser/debug-template-parser.cpp @@ -0,0 +1,452 @@ +#include "../src/llama-grammar.h" +#include "chat-auto-parser.h" +#include "chat.h" +#include "common.h" +#include "gguf.h" +#include "jinja/runtime.h" +#include "log.h" + +#include +#include +#include +#include + +#include "nlohmann/json.hpp" +#include "peg-parser.h" + +using json = nlohmann::ordered_json; + +enum class output_mode { + ANALYSIS, // Only output analysis results (default) + TEMPLATE, // Only output rendered template + BOTH // Output both +}; + +enum class input_message_type { + NONE, // Don't render any message scenarios (only analysis) + CONTENT_ONLY, // Simple assistant message with content + REASONING_CONTENT, // Message with reasoning_content + content + TOOL_CALL_ONLY, // Message with tool_calls only + CONTENT_TOOL_CALL, // Message with content + tool_calls + REASONING_TOOL_CALL, // Message with reasoning_content + tool_calls + CONTENT_FAKE_TOOL_CALL, // Message with content but no actual tool_calls (for testing) + ALL // Render all scenarios +}; + +struct debug_options { + std::string template_path; + bool with_tools = true; + bool generation_prompt = true; + bool enable_reasoning = true; + bool debug_jinja = false; + bool force_tool_call = false; + output_mode mode = output_mode::BOTH; + input_message_type input_message = input_message_type::NONE; +}; + +static std::string read_file(const std::string & path) { + std::ifstream fin(path, std::ios::binary); + if (!fin.is_open()) { + throw std::runtime_error("Could not open file: " + path); + } + std::ostringstream buf; + buf << fin.rdbuf(); + return buf.str(); +} + +static std::string read_gguf_chat_template(const std::string & path) { + struct gguf_init_params params = { /*no_alloc =*/true, // We only need metadata, not tensor data + /*ctx=*/nullptr }; + + struct gguf_context * ctx = gguf_init_from_file(path.c_str(), params); + if (ctx == nullptr) { + throw std::runtime_error("Could not open GGUF file: " + path); + } + + const char * key = "tokenizer.chat_template"; + int64_t key_id = gguf_find_key(ctx, key); + + if (key_id == -1) { + gguf_free(ctx); + throw std::runtime_error("GGUF file does not contain chat template key: " + std::string(key)); + } + + const char * template_str = gguf_get_val_str(ctx, key_id); + if (template_str == nullptr) { + gguf_free(ctx); + throw std::runtime_error("GGUF file contains chat template key but value is null"); + } + + std::string result = template_str; + gguf_free(ctx); + return result; +} + +static void print_usage(const char * program_name) { + LOG_ERR("Usage: %s [options]\n", program_name); + LOG_ERR("\nOptions:\n"); + LOG_ERR(" --no-tools Disable tool definitions\n"); + LOG_ERR(" --force-tool-call Set tool calls to forced\n"); + LOG_ERR(" --generation-prompt=0|1 Set add_generation_prompt (default: 1)\n"); + LOG_ERR(" --enable-reasoning=0|1 Enable reasoning parsing (default: 1)\n"); + LOG_ERR(" --output=MODE Output mode: analysis, template, both (default: both)\n"); + LOG_ERR(" --debug-jinja Enable Jinja fine-grained debug\n"); + LOG_ERR(" --input-message=TYPE Message type to render:\n"); + LOG_ERR(" content_only, reasoning_content, tool_call_only,\n"); + LOG_ERR(" content_tool_call, reasoning_tool_call,\n"); + LOG_ERR(" content_fake_tool_call, all\n"); + LOG_ERR("\nExamples:\n"); + LOG_ERR(" %s template.jinja --input-message=all --generation-prompt=1\n", program_name); + LOG_ERR(" %s template.jinja --output=template --input-message=tool_call_only\n", program_name); +} + +static bool parse_bool_option(const std::string & value) { + return value == "1" || value == "true" || value == "yes"; +} + +static bool parse_options(int argc, char ** argv, debug_options & opts) { + if (argc < 2) { + print_usage(argv[0]); + return false; + } + + opts.template_path = argv[1]; + + for (int i = 2; i < argc; ++i) { + std::string arg = argv[i]; + + if (arg == "--force-tool-call") { + opts.force_tool_call = true; + } else if (arg == "--debug-jinja") { + opts.debug_jinja = true; + } else if (arg == "--no-tools") { + opts.with_tools = false; + } else if (arg.rfind("--generation-prompt=", 0) == 0) { + opts.generation_prompt = parse_bool_option(arg.substr(20)); + } else if (arg.rfind("--enable-reasoning=", 0) == 0) { + opts.enable_reasoning = parse_bool_option(arg.substr(19)); + } else if (arg.rfind("--output=", 0) == 0) { + std::string mode = arg.substr(9); + if (mode == "analysis") { + opts.mode = output_mode::ANALYSIS; + } else if (mode == "template") { + opts.mode = output_mode::TEMPLATE; + } else if (mode == "both") { + opts.mode = output_mode::BOTH; + } else { + LOG_ERR("Unknown output mode: %s\n", mode.c_str()); + return false; + } + } else if (arg.rfind("--input-message=", 0) == 0) { + std::string type = arg.substr(16); + if (type == "content_only") { + opts.input_message = input_message_type::CONTENT_ONLY; + } else if (type == "reasoning_content") { + opts.input_message = input_message_type::REASONING_CONTENT; + } else if (type == "tool_call_only") { + opts.input_message = input_message_type::TOOL_CALL_ONLY; + } else if (type == "content_tool_call") { + opts.input_message = input_message_type::CONTENT_TOOL_CALL; + } else if (type == "reasoning_tool_call") { + opts.input_message = input_message_type::REASONING_TOOL_CALL; + } else if (type == "content_fake_tool_call") { + opts.input_message = input_message_type::CONTENT_FAKE_TOOL_CALL; + } else if (type == "all") { + opts.input_message = input_message_type::ALL; + } else { + LOG_ERR("Unknown input message type: %s\n", type.c_str()); + return false; + } + } else { + LOG_ERR("Unknown option: %s\n", arg.c_str()); + print_usage(argv[0]); + return false; + } + } + + return true; +} + +static json build_user_message() { + return json{ + { "role", "user" }, + { "content", "Hello, please help me with a task." } + }; +} + +static json build_content_only_message() { + return json{ + { "role", "assistant" }, + { "content", "Hello! I'm here to help you with your task." } + }; +} + +static json build_reasoning_content_message() { + return json{ + { "role", "assistant" }, + { "content", "Hello! I'm here to help you with your task." }, + { "reasoning_content", "The user is greeting me and asking for help. I should respond politely." } + }; +} + +static json build_tool_call_only_message() { + return json{ + { "role", "assistant" }, + { "content", nullptr }, + { "tool_calls", + json::array({ json{ + { "type", "function" }, + { "function", json{ { "name", "test_function_name" }, + { "arguments", json::object({ { "param1", "value1" }, { "param2", "value2" } }) } } }, + { "id", "123456789" } } }) } + }; +} + +static json build_content_tool_call_message() { + return json{ + { "role", "assistant" }, + { "content", "I'll help you by calling a function." }, + { "tool_calls", + json::array({ json{ + { "type", "function" }, + { "function", + json{ { "name", "test_function_name" }, + { "arguments", json::object({ { "param1", "value1" }, { "param2", "value2" } }) } } } } }) } + }; +} + +static json build_reasoning_tool_call_message() { + return json{ + { "role", "assistant" }, + { "content", nullptr }, + { "reasoning_content", "I need to call a function to help with this task." }, + { "tool_calls", + json::array({ json{ + { "type", "function" }, + { "function", + json{ { "name", "test_function_name" }, + { "arguments", json::object({ { "param1", "value1" }, { "param2", "value2" } }) } } } } }) } + }; +} + +static json build_content_fake_tool_call_message() { + // This message has content but NO tool_calls field + // It's used to test if a template renders tool definitions but not tool calls + return json{ + { "role", "assistant" }, + { "content", "I'll help you by calling a function." } + }; +} + +static json build_tools_definition() { + json parameters_schema = json::object(); + parameters_schema["type"] = "object"; + parameters_schema["properties"] = json::object(); + parameters_schema["properties"]["param1"] = json::object({ + { "type", "string" }, + { "description", "First parameter" } + }); + parameters_schema["properties"]["param2"] = json::object({ + { "type", "string" }, + { "description", "Second parameter" } + }); + parameters_schema["required"] = json::array({ "param1" }); + + return json::array({ + json{ { "type", "function" }, + { "function", json{ { "name", "test_function_name" }, + { "description", "A test function for debugging" }, + { "parameters", parameters_schema } } } } + }); +} + +static void render_scenario(const common_chat_template & tmpl, + const std::string & scenario_name, + const json & messages, + const json & tools, + bool add_generation_prompt, + bool enable_thinking) { + LOG_ERR("\n=== Scenario: %s ===\n", scenario_name.c_str()); + LOG_ERR("add_generation_prompt: %s, enable_thinking: %s\n", add_generation_prompt ? "true" : "false", + enable_thinking ? "true" : "false"); + + // When add_generation_prompt is true, add a trailing user message to trigger the prompt + json final_messages = messages; + if (add_generation_prompt && !messages.empty() && messages.back().value("role", "") == "assistant") { + final_messages.push_back(json{ + { "role", "user" }, + { "content", "Now please continue with another response." } + }); + } + + LOG_ERR("Messages:\n%s\n", final_messages.dump(2).c_str()); + + try { + autoparser::templates_params inputs; + inputs.messages = final_messages; + inputs.add_generation_prompt = add_generation_prompt; + inputs.extra_context["enable_thinking"] = enable_thinking; + + if (!tools.is_null() && tools.is_array() && !tools.empty()) { + inputs.tools = tools; + } + + std::string output = common_chat_template_direct_apply(tmpl, inputs); + + LOG_ERR("\n--- Rendered Output ---\n"); + LOG_ERR("%s\n", output.c_str()); + LOG_ERR("--- End Output (length: %zu) ---\n", output.length()); + } catch (const std::exception & e) { + LOG_ERR("Rendering failed: %s\n", e.what()); + } +} + +static void render_all_scenarios(const common_chat_template & tmpl, + const json & tools, + bool add_generation_prompt, + bool enable_thinking, + input_message_type message_type) { + json user_msg = build_user_message(); + + auto render_if = [&](input_message_type type, const std::string & name, const json & assistant_msg) { + if (message_type == input_message_type::ALL || message_type == type) { + json messages = json::array({ user_msg, assistant_msg }); + render_scenario(tmpl, name, messages, tools, add_generation_prompt, enable_thinking); + } + }; + + render_if(input_message_type::CONTENT_ONLY, "content_only", build_content_only_message()); + render_if(input_message_type::REASONING_CONTENT, "reasoning_content", build_reasoning_content_message()); + render_if(input_message_type::TOOL_CALL_ONLY, "tool_call_only", build_tool_call_only_message()); + render_if(input_message_type::CONTENT_TOOL_CALL, "content_tool_call", build_content_tool_call_message()); + render_if(input_message_type::REASONING_TOOL_CALL, "reasoning_tool_call", build_reasoning_tool_call_message()); + render_if(input_message_type::CONTENT_FAKE_TOOL_CALL, "content_fake_tool_call", + build_content_fake_tool_call_message()); + + // Also render with add_generation_prompt=true to show the prompt ending + if (message_type == input_message_type::ALL) { + LOG_ERR("\n\n=== Generation Prompt Scenarios (add_generation_prompt=true) ===\n"); + + json prompt_messages = json::array({ user_msg }); + render_scenario(tmpl, "generation_prompt_only", prompt_messages, tools, true, enable_thinking); + + // With enable_thinking toggled + render_scenario(tmpl, "generation_prompt_thinking_disabled", prompt_messages, tools, true, false); + } +} + +int main(int argc, char ** argv) { + // Set log level to most verbose to capture all debug output + common_log_set_verbosity_thold(99); + + debug_options opts; + if (!parse_options(argc, argv, opts)) { + return 1; + } + + if (opts.debug_jinja || std::getenv("LLAMA_DEBUG_JINJA") != nullptr) { + jinja::enable_debug(true); + } + + std::string template_source; + try { + // Check if the file is a GGUF file + if (opts.template_path.size() >= 5 && + opts.template_path.compare(opts.template_path.size() - 5, 5, ".gguf") == 0) { + template_source = read_gguf_chat_template(opts.template_path); + } else { + template_source = read_file(opts.template_path); + } + } catch (const std::exception & e) { + LOG_ERR("Error reading template: %s\n", e.what()); + return 1; + } + + LOG_ERR("Analyzing template: %s\n", opts.template_path.c_str()); + LOG_ERR("Options: with_tools=%s, generation_prompt=%s, enable_reasoning=%s\n", opts.with_tools ? "true" : "false", + opts.generation_prompt ? "true" : "false", opts.enable_reasoning ? "true" : "false"); + + try { + common_chat_template chat_template(template_source, "", ""); + + // Build tools definition + json tools = opts.with_tools ? build_tools_definition() : json(); + + // Render template scenarios if requested + if (opts.input_message != input_message_type::NONE && + (opts.mode == output_mode::TEMPLATE || opts.mode == output_mode::BOTH)) { + LOG_ERR("\n"); + LOG_ERR("================================================================================\n"); + LOG_ERR(" TEMPLATE RENDERING OUTPUT\n"); + LOG_ERR("================================================================================\n"); + + render_all_scenarios(chat_template, tools, opts.generation_prompt, opts.enable_reasoning, + opts.input_message); + } + + // Output analysis if requested + if (opts.mode == output_mode::ANALYSIS || opts.mode == output_mode::BOTH) { + LOG_ERR("\n"); + LOG_ERR("================================================================================\n"); + LOG_ERR(" TEMPLATE ANALYSIS\n"); + LOG_ERR("================================================================================\n"); + + autoparser::autoparser analysis; + analysis.analyze_template(chat_template); + + // Generate Parser + autoparser::templates_params params; + params.messages = json::array({ build_user_message() }); + params.reasoning_format = + opts.enable_reasoning ? COMMON_REASONING_FORMAT_DEEPSEEK : COMMON_REASONING_FORMAT_NONE; + params.enable_thinking = opts.enable_reasoning; + params.add_generation_prompt = opts.generation_prompt; + + if (opts.with_tools) { + params.tools = tools; + params.tool_choice = opts.force_tool_call ? COMMON_CHAT_TOOL_CHOICE_REQUIRED : COMMON_CHAT_TOOL_CHOICE_AUTO; + } else { + params.tools = json(); + params.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE; + } + params.parallel_tool_calls = false; + + auto parser_data = autoparser::peg_generator::generate_parser(chat_template, params, analysis); + + LOG_ERR("\n=== Generated Parser ===\n"); + common_peg_arena arena; + arena.load(parser_data.parser); + LOG_ERR("%s\n", arena.dump(arena.root()).c_str()); + + LOG_ERR("\n=== Generated Grammar ===\n"); + LOG_ERR("%s\n", parser_data.grammar.c_str()); + + LOG_ERR("\n=== Generated Lazy Grammar ===\n"); + LOG_ERR("%d\n", parser_data.grammar_lazy); + + LOG_ERR("\n=== Generated Grammar Triggers ===\n"); + for (const common_grammar_trigger & cgt : parser_data.grammar_triggers) { + LOG_ERR("Token: %d | Type: %d | Value: %s\n", cgt.token, cgt.type, cgt.value.c_str()); + } + + LOG_ERR("\n=== Preserved Tokens ===\n"); + for (const std::string & token : parser_data.preserved_tokens) { + LOG_ERR(" '%s'\n", token.c_str()); + } + + if (!parser_data.grammar.empty()) { + LOG_ERR("\n=== Verifying created grammar ===\n"); + auto * grammar = llama_grammar_init_impl(nullptr, parser_data.grammar.c_str(), "root", + parser_data.grammar_lazy, nullptr, 0, nullptr, 0); + if (grammar != nullptr) { + LOG_ERR("\n=== Grammar successfully created ===\n"); + } + } + } + } catch (const std::exception & e) { + LOG_ERR("Analysis failed: %s\n", e.what()); + return 1; + } + + return 0; +} diff --git a/tools/parser/template-analysis.cpp b/tools/parser/template-analysis.cpp new file mode 100644 index 00000000000..a92e104ac0f --- /dev/null +++ b/tools/parser/template-analysis.cpp @@ -0,0 +1,611 @@ +#include "chat-auto-parser.h" +#include "chat-auto-parser-helpers.h" +#include "chat.h" +#include "log.h" +#include "jinja/caps.h" +#include "jinja/runtime.h" + +#include +#include +#include +#include +#include + +#include "nlohmann/json.hpp" + +using json = nlohmann::ordered_json; + +// ANSI color codes - using 256-color palette for brighter colors (all bold) +#define ANSI_RESET "\033[0m" +#define ANSI_PURPLE "\033[1m\x1b[38;5;126m" // Bold bright purple for main headers +#define ANSI_CYAN "\033[1m\x1b[38;5;81m" // Bold bright cyan for section headers +#define ANSI_BLUE "\033[1m\x1b[38;5;12m" // Bold bright blue for labels +#define ANSI_ORANGE "\033[1m\x1b[38;5;209m" // Bold orange for right differences +#define ANSI_GREEN "\033[1m\x1b[38;5;83m" // Bold bright green for left differences +#define ANSI_GRAY "\033[1m\x1b[38;5;240m" // Bold gray (used for "no variables" message) +#define ANSI_BOLD "\033[1m" // Standalone bold +#define ANSI_PREFIX "\033[1m\x1b[38;5;176m" // Bold color for common prefix +#define ANSI_SUFFIX "\033[1m\x1b[38;5;61m" // Bold color for common suffix + +// All template paths extracted from tests/test-chat.cpp +static const std::vector ALL_TEMPLATE_PATHS = { + "models/templates/Apertus-8B-Instruct.jinja", + "models/templates/Apriel-1.6-15b-Thinker-fixed.jinja", + "models/templates/ByteDance-Seed-OSS.jinja", + "models/templates/CohereForAI-c4ai-command-r-plus-tool_use.jinja", + "models/templates/CohereForAI-c4ai-command-r7b-12-2024-tool_use.jinja", + "models/templates/GLM-4.6.jinja", + "models/templates/GLM-4.7-Flash.jinja", + "models/templates/Kimi-K2-Instruct.jinja", + "models/templates/Kimi-K2-Thinking.jinja", + "models/templates/MiMo-VL.jinja", + "models/templates/MiniMax-M2.jinja", + "models/templates/Mistral-Small-3.2-24B-Instruct-2506.jinja", + "models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja", + "models/templates/NVIDIA-Nemotron-Nano-v2.jinja", + "models/templates/NousResearch-Hermes-2-Pro-Llama-3-8B-tool_use.jinja", + "models/templates/NousResearch-Hermes-3-Llama-3.1-8B-tool_use.jinja", + "models/templates/Qwen-QwQ-32B.jinja", + "models/templates/Qwen-Qwen2.5-7B-Instruct.jinja", + "models/templates/Qwen3-Coder.jinja", + "models/templates/deepseek-ai-DeepSeek-R1-Distill-Llama-8B.jinja", + "models/templates/deepseek-ai-DeepSeek-R1-Distill-Qwen-32B.jinja", + "models/templates/deepseek-ai-DeepSeek-V3.1.jinja", + "models/templates/fireworks-ai-llama-3-firefunction-v2.jinja", + "models/templates/google-gemma-2-2b-it.jinja", + "models/templates/ibm-granite-granite-3.3-2B-Instruct.jinja", + "models/templates/llama-cpp-deepseek-r1.jinja", + "models/templates/meetkai-functionary-medium-v3.1.jinja", + "models/templates/meetkai-functionary-medium-v3.2.jinja", + "models/templates/meta-llama-Llama-3.1-8B-Instruct.jinja", + "models/templates/meta-llama-Llama-3.2-3B-Instruct.jinja", + "models/templates/meta-llama-Llama-3.3-70B-Instruct.jinja", + "models/templates/mistralai-Ministral-3-14B-Reasoning-2512.jinja", + "models/templates/mistralai-Mistral-Nemo-Instruct-2407.jinja", + "models/templates/moonshotai-Kimi-K2.jinja", + "models/templates/openai-gpt-oss-120b.jinja", + "models/templates/unsloth-Apriel-1.5.jinja", + "models/templates/unsloth-mistral-Devstral-Small-2507.jinja", +}; + +struct analysis_options { + std::vector template_paths; + bool analyze_all = false; +}; + +static std::string read_file(const std::string & path) { + std::ifstream fin(path, std::ios::binary); + if (!fin.is_open()) { + throw std::runtime_error("Could not open file: " + path); + } + std::ostringstream buf; + buf << fin.rdbuf(); + return buf.str(); +} + +static void print_usage(const char * program_name) { + LOG_ERR("Usage: %s [options]\n", program_name); + LOG_ERR("\nOptions:\n"); + LOG_ERR(" --template Analyze specific template from test suite (e.g., 'deepseek' or 'DeepSeek-V3.1')\n"); + LOG_ERR(" --template-file Analyze custom template file\n"); + LOG_ERR(" --all Analyze all templates from test suite\n"); + LOG_ERR("\nExamples:\n"); + LOG_ERR(" %s --all\n", program_name); + LOG_ERR(" %s --template deepseek\n", program_name); + LOG_ERR(" %s --template-file my-template.jinja\n", program_name); +} + +static bool parse_options(int argc, char ** argv, analysis_options & opts) { + if (argc < 2) { + print_usage(argv[0]); + return false; + } + + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + + if (arg == "--all") { + opts.analyze_all = true; + } else if (arg == "--template") { + if (i + 1 >= argc) { + LOG_ERR("--template requires an argument\n"); + return false; + } + std::string pattern = argv[++i]; + std::transform(pattern.begin(), pattern.end(), pattern.begin(), ::tolower); + + // Find matching templates + bool found = false; + for (const auto & path : ALL_TEMPLATE_PATHS) { + std::string path_lower = path; + std::transform(path_lower.begin(), path_lower.end(), path_lower.begin(), ::tolower); + if (path_lower.find(pattern) != std::string::npos) { + opts.template_paths.push_back(path); + found = true; + } + } + + if (!found) { + LOG_ERR("No templates found matching: %s\n", pattern.c_str()); + return false; + } + } else if (arg == "--template-file") { + if (i + 1 >= argc) { + LOG_ERR("--template-file requires an argument\n"); + return false; + } + opts.template_paths.push_back(argv[++i]); + } else { + LOG_ERR("Unknown option: %s\n", arg.c_str()); + print_usage(argv[0]); + return false; + } + } + + if (opts.analyze_all) { + opts.template_paths = ALL_TEMPLATE_PATHS; + } + + if (opts.template_paths.empty()) { + LOG_ERR("No templates specified\n"); + print_usage(argv[0]); + return false; + } + + return true; +} + +static json build_tools_definition() { + json parameters_schema = json::object(); + parameters_schema["type"] = "object"; + parameters_schema["properties"] = json::object(); + parameters_schema["properties"]["param1"] = json::object({ + { "type", "string" }, + { "description", "First parameter" } + }); + parameters_schema["properties"]["param2"] = json::object({ + { "type", "string" }, + { "description", "Second parameter" } + }); + parameters_schema["required"] = json::array({ "param1", "param2" }); + + return json::array({ + json{ { "type", "function" }, + { "function", json{ { "name", "test_function_name" }, + { "description", "A test function for debugging" }, + { "parameters", parameters_schema } } } } + }); +} + +// Helper to create a tool call with arguments as JSON object +static json build_tool_call(const std::string & name, const json & args_object, const std::string & id = "call_001") { + return json{ + {"id", id}, + {"type", "function"}, + {"function", json{ + {"name", name}, + {"arguments", args_object} // Pass as JSON object, not serialized string + }} + }; +} + +// Helper functions to create repeating message definitions +static json make_user_msg() { + return json{ + {"role", "user"}, + {"content", "Hello, please help me."} + }; +} + +static json make_user_msg2() { + return json{ + {"role", "user"}, + {"content", "Thank you."} + }; +} + +static json make_user_msg2_continue() { + return json{ + {"role", "user"}, + {"content", "Continue."} + }; +} + +static json make_assistant_no_tool() { + return json{ + {"role", "assistant"}, + {"content", "Let me help you."} + }; +} + +static json make_assistant_one_tool() { + return json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("test_function_name", json::object({{"param1", "value1"}, {"param2", "value2"}})) + })} + }; +} + +static json make_assistant_two_tools() { + return json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("test_function_name", json::object({{"param1", "value1"}, {"param2", "value2"}})), + build_tool_call("test_function_name", json::object({{"param1", "value3"}, {"param2", "value4"}}), "call_002") + })} + }; +} + +static json make_assistant_no_reasoning() { + return json{ + {"role", "assistant"}, + {"content", "I can help you with that."} + }; +} + +static json make_assistant_with_reasoning() { + return json{ + {"role", "assistant"}, + {"content", "I can help you with that."}, + {"reasoning_content", "The user is asking for help. I should respond positively."} + }; +} + +static json make_assistant_one_tool_with_reasoning() { + return json{ + {"role", "assistant"}, + {"content", nullptr}, + {"tool_calls", json::array({ + build_tool_call("test_function_name", json::object({{"param1", "value1"}, {"param2", "value2"}})) + })}, + {"reasoning_content", "I need to call the tool first."} + }; +} + +static void print_diff_split(const std::string & title, const diff_split & diff) { + LOG_ERR("\n%s=== %s ===%s\n", ANSI_CYAN, title.c_str(), ANSI_RESET); + LOG_ERR("%sCommon Prefix:%s '%s'\n", ANSI_PREFIX, ANSI_RESET, diff.prefix.c_str()); + LOG_ERR("%sCommon Suffix:%s '%s'\n", ANSI_SUFFIX, ANSI_RESET, diff.suffix.c_str()); + LOG_ERR("%sLeft (difference):%s '%s'\n", ANSI_GREEN, ANSI_RESET, diff.left.c_str()); + LOG_ERR("%sRight (difference):%s '%s'\n", ANSI_ORANGE, ANSI_RESET, diff.right.c_str()); +} + +static void check_reasoning_variables(const common_chat_template & tmpl) { + LOG_ERR("\n%s=== Checking Reasoning Variables ===%s\n", ANSI_CYAN, ANSI_RESET); + + try { + // Create a list of candidate reasoning/thinking variable names to probe + std::vector candidate_vars = { + "enable_reasoning", + "use_reasoning", + "reasoning_enabled", + "has_reasoning", + "reasoning_mode", + "reasoning_format", + "reasoning_active", + "with_reasoning", + "use_thinking", + "thinking_enabled", + "has_thinking", + "thinking_mode", + "thinking_format", + "thinking_active", + "with_thinking", + "enable_reason", + "reason_enabled", + "enable_think", + "think_enabled", + }; + + jinja::context ctx; + ctx.is_get_stats = true; + + json messages = json::array({ + json{ + {"role", "user"}, + {"content", "Test message"} + }, + json{ + {"role", "assistant"}, + {"content", "Response"}, + {"reasoning_content", "Some reasoning"} + } + }); + + // Set up base context + jinja::global_from_json(ctx, json{ + {"messages", messages}, + {"tools", json::array()}, + {"bos_token", ""}, + {"eos_token", ""}, + {"add_generation_prompt", false}, + {"enable_thinking", true} // Already passed, so we'll exclude this from results + }, true); + + // Add candidate variables as undefined to probe which ones are accessed + for (const auto & var_name : candidate_vars) { + ctx.set_val(var_name, jinja::mk_val(var_name)); + } + + try { + jinja::runtime runtime(ctx); + runtime.execute(tmpl.prog); + } catch (const std::exception & e) { + // Execution may fail, that's okay - we just want to see what variables were accessed + } + + // Check which candidate variables were accessed (stats.used = true) + std::vector accessed_vars; + for (const auto & var_name : candidate_vars) { + auto val = ctx.get_val(var_name); + if (!val->is_undefined()) { + // Variable was overwritten, skip it + continue; + } + if (val->stats.used) { + accessed_vars.push_back(var_name); + } + } + + if (accessed_vars.empty()) { + LOG_ERR("%sNo reasoning/thinking-related variables were queried by the template%s\n", ANSI_GRAY, ANSI_RESET); + } else { + LOG_ERR("Template queries the following reasoning/thinking-related variables:\n"); + for (const auto & var : accessed_vars) { + LOG_ERR(" %s- %s%s\n", ANSI_ORANGE, var.c_str(), ANSI_RESET); + } + } + + } catch (const std::exception & e) { + LOG_ERR("Error checking reasoning variables: %s\n", e.what()); + } +} + +static void analyze_template(const std::string & template_path) { + LOG_ERR("\n"); + LOG_ERR("%s", ANSI_PURPLE); + LOG_ERR("================================================================================\n"); + LOG_ERR(" ANALYZING TEMPLATE: %s\n", template_path.c_str()); + LOG_ERR("================================================================================\n"); + LOG_ERR("%s", ANSI_RESET); + + std::string template_source; + try { + template_source = read_file(template_path); + } catch (const std::exception & e) { + LOG_ERR("Error reading template: %s\n", e.what()); + return; + } + + try { + common_chat_template chat_template(template_source, "", ""); + json tools = build_tools_definition(); + + // ===== CAPABILITIES ANALYSIS ===== + LOG_ERR("\n%s=== Template Capabilities (from jinja::caps) ===%s\n", ANSI_CYAN, ANSI_RESET); + auto caps = chat_template.original_caps(); + LOG_ERR("%ssupports_tools:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_tools ? "true" : "false"); + LOG_ERR("%ssupports_tool_calls:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_tool_calls ? "true" : "false"); + LOG_ERR("%ssupports_system_role:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_system_role ? "true" : "false"); + LOG_ERR("%ssupports_parallel_tool_calls:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_parallel_tool_calls ? "true" : "false"); + LOG_ERR("%ssupports_typed_content:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_typed_content ? "true" : "false"); + LOG_ERR("%ssupports_string_content:%s %s\n", ANSI_BLUE, ANSI_RESET, caps.supports_string_content ? "true" : "false"); + + // ===== DIFFERENTIAL ANALYSIS ===== + + // Test 1: With and without tools (single user message) + { + json user_msg = make_user_msg(); + + autoparser::templates_params params_no_tools; + params_no_tools.messages = json::array({ user_msg }); + params_no_tools.add_generation_prompt = false; + params_no_tools.tools = json::array(); + + autoparser::templates_params params_with_tools = params_no_tools; + params_with_tools.tools = tools; + + std::string output_no_tools = common_chat_template_direct_apply(chat_template, params_no_tools); + std::string output_with_tools = common_chat_template_direct_apply(chat_template, params_with_tools); + + auto diff = calculate_diff_split(output_no_tools, output_with_tools); + print_diff_split("Diff: With vs Without Tools (single user message)", diff); + } + + // Test 2: With and without add_generation_prompt (single user message) + { + json user_msg = make_user_msg(); + + autoparser::templates_params params_no_prompt; + params_no_prompt.messages = json::array({ user_msg }); + params_no_prompt.add_generation_prompt = false; + params_no_prompt.tools = json::array(); + + autoparser::templates_params params_with_prompt = params_no_prompt; + params_with_prompt.add_generation_prompt = true; + + std::string output_no_prompt = common_chat_template_direct_apply(chat_template, params_no_prompt); + std::string output_with_prompt = common_chat_template_direct_apply(chat_template, params_with_prompt); + + auto diff = calculate_diff_split(output_no_prompt, output_with_prompt); + print_diff_split("Diff: With vs Without add_generation_prompt (single user message)", diff); + } + + // Test 3: Assistant with reasoning_content (user, assistant) + { + json user_msg = make_user_msg(); + + autoparser::templates_params params_no_reasoning; + params_no_reasoning.messages = json::array({ user_msg, make_assistant_no_reasoning() }); + params_no_reasoning.add_generation_prompt = false; + params_no_reasoning.enable_thinking = true; + + autoparser::templates_params params_with_reasoning = params_no_reasoning; + params_with_reasoning.messages = json::array({ user_msg, make_assistant_with_reasoning() }); + + std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning); + std::string output_with_reasoning = common_chat_template_direct_apply(chat_template, params_with_reasoning); + + auto diff = calculate_diff_split(output_no_reasoning, output_with_reasoning); + print_diff_split("Diff: With vs Without reasoning_content (user, assistant)", diff); + } + + // Test 4: Assistant with reasoning_content (user, assistant, user) + { + json user_msg = make_user_msg(); + json user_msg2 = make_user_msg2(); + + autoparser::templates_params params_no_reasoning; + params_no_reasoning.messages = json::array({ user_msg, make_assistant_no_reasoning(), user_msg2 }); + params_no_reasoning.add_generation_prompt = false; + params_no_reasoning.enable_thinking = true; + + autoparser::templates_params params_with_reasoning = params_no_reasoning; + params_with_reasoning.messages = json::array({ user_msg, make_assistant_with_reasoning(), user_msg2 }); + + std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning); + std::string output_with_reasoning = common_chat_template_direct_apply(chat_template, params_with_reasoning); + + auto diff = calculate_diff_split(output_no_reasoning, output_with_reasoning); + print_diff_split("Diff: With vs Without reasoning_content (user, assistant, user)", diff); + } + + // Test 5: Tool call in last assistant message (user, assistant) + { + json user_msg = make_user_msg(); + + autoparser::templates_params params_no_tool; + params_no_tool.messages = json::array({ user_msg, make_assistant_no_tool() }); + params_no_tool.add_generation_prompt = false; + params_no_tool.tools = tools; + + autoparser::templates_params params_with_tool = params_no_tool; + params_with_tool.messages = json::array({ user_msg, make_assistant_one_tool() }); + + std::string output_no_tool = common_chat_template_direct_apply(chat_template, params_no_tool); + std::string output_with_tool = common_chat_template_direct_apply(chat_template, params_with_tool); + + auto diff = calculate_diff_split(output_no_tool, output_with_tool); + print_diff_split("Diff: With vs Without tool call (user, assistant)", diff); + } + + // Test 6: Tool call in last assistant message (user, assistant, user) + { + json user_msg = make_user_msg(); + json user_msg2 = make_user_msg2_continue(); + + autoparser::templates_params params_no_tool; + params_no_tool.messages = json::array({ user_msg, make_assistant_no_tool(), user_msg2 }); + params_no_tool.add_generation_prompt = false; + params_no_tool.tools = tools; + + autoparser::templates_params params_with_tool = params_no_tool; + params_with_tool.messages = json::array({ user_msg, make_assistant_one_tool(), user_msg2 }); + + std::string output_no_tool = common_chat_template_direct_apply(chat_template, params_no_tool); + std::string output_with_tool = common_chat_template_direct_apply(chat_template, params_with_tool); + + auto diff = calculate_diff_split(output_no_tool, output_with_tool); + print_diff_split("Diff: With vs Without tool call (user, assistant, user)", diff); + } + + // Test 7: One vs two tool calls (user, assistant) + { + json user_msg = make_user_msg(); + + autoparser::templates_params params_one_tool; + params_one_tool.messages = json::array({ user_msg, make_assistant_one_tool() }); + params_one_tool.add_generation_prompt = false; + params_one_tool.tools = tools; + + autoparser::templates_params params_two_tools = params_one_tool; + params_two_tools.messages = json::array({ user_msg, make_assistant_two_tools() }); + + std::string output_one_tool = common_chat_template_direct_apply(chat_template, params_one_tool); + std::string output_two_tools = common_chat_template_direct_apply(chat_template, params_two_tools); + + auto diff = calculate_diff_split(output_one_tool, output_two_tools); + print_diff_split("Diff: One vs Two tool calls (user, assistant)", diff); + } + + // Test 8: One vs two tool calls (user, assistant, user) + { + json user_msg = make_user_msg(); + json user_msg2 = make_user_msg2_continue(); + + autoparser::templates_params params_one_tool; + params_one_tool.messages = json::array({ user_msg, make_assistant_one_tool(), user_msg2 }); + params_one_tool.add_generation_prompt = false; + params_one_tool.tools = tools; + + autoparser::templates_params params_two_tools = params_one_tool; + params_two_tools.messages = json::array({ user_msg, make_assistant_two_tools(), user_msg2 }); + + std::string output_one_tool = common_chat_template_direct_apply(chat_template, params_one_tool); + std::string output_two_tools = common_chat_template_direct_apply(chat_template, params_two_tools); + + auto diff = calculate_diff_split(output_one_tool, output_two_tools); + print_diff_split("Diff: One vs Two tool calls (user, assistant, user)", diff); + } + + // Test 9: Tool call with vs without reasoning_content (user, assistant) + { + json user_msg = make_user_msg(); + + autoparser::templates_params params_no_reasoning; + params_no_reasoning.messages = json::array({ user_msg, make_assistant_one_tool() }); + params_no_reasoning.add_generation_prompt = false; + params_no_reasoning.tools = tools; + params_no_reasoning.enable_thinking = true; + + autoparser::templates_params params_with_reasoning = params_no_reasoning; + params_with_reasoning.messages = json::array({ user_msg, make_assistant_one_tool_with_reasoning() }); + + std::string output_no_reasoning = common_chat_template_direct_apply(chat_template, params_no_reasoning); + std::string output_with_reasoning = common_chat_template_direct_apply(chat_template, params_with_reasoning); + + auto diff = calculate_diff_split(output_no_reasoning, output_with_reasoning); + print_diff_split("Diff: Tool call with vs without reasoning_content (user, assistant)", diff); + } + + // Check reasoning variables + check_reasoning_variables(chat_template); + + } catch (const std::exception & e) { + LOG_ERR("Analysis failed: %s\n", e.what()); + } +} + +int main(int argc, char ** argv) { + // Set log level to capture all output + common_log_set_verbosity_thold(99); + + analysis_options opts; + if (!parse_options(argc, argv, opts)) { + return 1; + } + + LOG_ERR("\n"); + LOG_ERR("%s", ANSI_PURPLE); + LOG_ERR("================================================================================\n"); + LOG_ERR(" TEMPLATE ANALYSIS TOOL\n"); + LOG_ERR("================================================================================\n"); + LOG_ERR("%s", ANSI_RESET); + LOG_ERR("Analyzing %s%zu%s template(s)\n", ANSI_CYAN, opts.template_paths.size(), ANSI_RESET); + + for (const auto & path : opts.template_paths) { + analyze_template(path); + } + + LOG_ERR("\n"); + LOG_ERR("%s", ANSI_GREEN); + LOG_ERR("================================================================================\n"); + LOG_ERR(" ANALYSIS COMPLETE\n"); + LOG_ERR("================================================================================\n"); + LOG_ERR("%s", ANSI_RESET); + + return 0; +} diff --git a/tools/perplexity/README.md b/tools/perplexity/README.md index eb3846072ea..f82d34c8a25 100644 --- a/tools/perplexity/README.md +++ b/tools/perplexity/README.md @@ -27,10 +27,10 @@ In addition to the KL divergence the following statistics are calculated with `- * Ratio of mean FP16 PPL and quantized PPL. Uncertainty is estimated on logits, then propagated. The logarithm of this metric is also calculated and printed, it is 0 if the logit distributions are the same. * Difference of mean FP16 PPL and quantized PPL. Uncertainty is estimated on logits, then propagated. * Mean change in "correct" token probability. Positive values mean the model gets better at prediction, negative values mean it gets worse. -* Pearson correlation coefficient of the "correct" token probabilites between models. +* Pearson correlation coefficient of the "correct" token probabilities between models. * Percentiles of change in "correct" token probability. Positive values mean the model gets better at prediction, negative values mean it gets worse. Can be used to judge noise vs. quality loss from quantization. If the percentiles are symmetric then the quantization is essentially just adding noise. If the negative values are significantly larger than the positive values then this indicates that the model is actually becoming worse from the quantization. * The root mean square of the change in token probabilities. If you were to assume that the quantization simply causes Gaussian noise on the token probabilities then this would be the standard deviation of said noise. The uncertainty on the value is calculated that the change in token probabilities follows a Gaussian distribution. Related discussion: https://github.com/ggml-org/llama.cpp/discussions/2875 . -* Same top p: Percentage of how often the token was assigned the highest probabilites by both models. The uncertainty is calculated from the Gaussian approximation of the binomial distribution. +* Same top p: Percentage of how often the token was assigned the highest probabilities by both models. The uncertainty is calculated from the Gaussian approximation of the binomial distribution. ## LLaMA 3 8b Scoreboard diff --git a/tools/perplexity/perplexity.cpp b/tools/perplexity/perplexity.cpp index 433b747f0d4..cc5ea99c4df 100644 --- a/tools/perplexity/perplexity.cpp +++ b/tools/perplexity/perplexity.cpp @@ -3,10 +3,11 @@ #include "log.h" #include "llama.h" -#include #include #include #include +#include +#include #include #include #include @@ -2004,6 +2005,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + common_params params; params.n_ctx = 512; diff --git a/tools/quantize/README.md b/tools/quantize/README.md index 22f07102867..b8c225124b3 100644 --- a/tools/quantize/README.md +++ b/tools/quantize/README.md @@ -100,7 +100,7 @@ Examples: ## Memory/Disk Requirements When running the larger models, make sure you have enough disk space to store all the intermediate files. -As the models are currently fully loaded into memory, you will need adequate disk space to save them and sufficient RAM to load them. At the moment, memory and disk requirements are the same. For exmaple (Llama 3.1): +As the models are currently fully loaded into memory, you will need adequate disk space to save them and sufficient RAM to load them. At the moment, memory and disk requirements are the same. For example (Llama 3.1): | Model | Original size | Quantized size (Q4_K_M) | | ----: | ------------: | ----------------------: | diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index 59bf9bd3fd0..0a483328ee5 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -2,6 +2,10 @@ #include "llama.h" #include "gguf.h" +#include +#include +#include +#include #include #include #include @@ -485,6 +489,8 @@ static bool parse_layer_prune(const char * data, std::vector & prune_layers } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + if (argc < 3) { usage(argv[0]); } diff --git a/tools/rpc/rpc-server.cpp b/tools/rpc/rpc-server.cpp index 6feb0e91f32..03ab78e5f05 100644 --- a/tools/rpc/rpc-server.cpp +++ b/tools/rpc/rpc-server.cpp @@ -10,12 +10,15 @@ # include # include #endif -#include -#include -#include #include -#include +#include +#include +#include #include +#include +#include +#include +#include #if defined(__linux__) #include @@ -285,6 +288,8 @@ static std::vector get_devices(const rpc_server_params & par } int main(int argc, char * argv[]) { + std::setlocale(LC_NUMERIC, "C"); + ggml_backend_load_all(); rpc_server_params params; diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index a5465fcd132..ed3fc127b77 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/public_legacy/index-new.html b/tools/server/public_legacy/index-new.html index e2f39d6687e..2cee7f3c3c4 100644 --- a/tools/server/public_legacy/index-new.html +++ b/tools/server/public_legacy/index-new.html @@ -36,7 +36,7 @@ const params = signal({ n_predict: 358, // 358 is a nice number - temperature: 0.8, // adapt all following parameters to optimized min-p requierements. If for non-english, set to 0.6 or lower + temperature: 0.8, // adapt all following parameters to optimized min-p requirements. If for non-english, set to 0.6 or lower repeat_last_n: 0, // 0 = disable penalty, -1 = context size repeat_penalty: 1.0, // 1.0 = disabled dry_multiplier: 0.0, // 0.0 = disabled, 0.8 works well @@ -108,7 +108,7 @@ let importedTemplates = local_storage_getDataAsObject('user_templates') if (importedTemplates) { - // saved templates were successfuly imported. + // saved templates were successfully imported. console.log('Processing saved templates and updating default template') params.value = { ...params.value, image_data: [] }; @@ -129,7 +129,7 @@ } function userTemplateResetToDefault() { - console.log('Reseting themplate to default') + console.log('Reseting template to default') selectedUserTemplate.value.name = 'default'; selectedUserTemplate.value.data = savedUserTemplates.value['default']; } diff --git a/tools/server/public_legacy/json-schema-to-grammar.mjs b/tools/server/public_legacy/json-schema-to-grammar.mjs index 38576c45fa0..bb25887a144 100644 --- a/tools/server/public_legacy/json-schema-to-grammar.mjs +++ b/tools/server/public_legacy/json-schema-to-grammar.mjs @@ -729,6 +729,10 @@ export class SchemaConverter { return this._addRule(ruleName, out.join('')); } else if ((schemaType === 'object') || (Object.keys(schema).length === 0)) { return this._addRule(ruleName, this._addPrimitive('object', PRIMITIVE_RULES['object'])); + } else if (schemaType === undefined && typeof schema === 'object' && !Array.isArray(schema) && schema !== null) { + // No type constraint and no recognized structural keywords (e.g. {"description": "..."}). + // Per JSON Schema semantics this is equivalent to {} and accepts any value. + return this._addRule(ruleName, this._addPrimitive('value', PRIMITIVE_RULES['value'])); } else { if (!(schemaType in PRIMITIVE_RULES)) { throw new Error(`Unrecognized schema: ${JSON.stringify(schema)}`); diff --git a/tools/server/public_simplechat/datautils.mjs b/tools/server/public_simplechat/datautils.mjs index 75159d6b167..08ccc219bfd 100644 --- a/tools/server/public_simplechat/datautils.mjs +++ b/tools/server/public_simplechat/datautils.mjs @@ -63,7 +63,7 @@ export function trim_repeat_garbage_at_end(sIn, maxSubL=10, maxMatchLenThreshold /** - * Simple minded logic to help remove repeating garbage at end of the string, till it cant. + * Simple minded logic to help remove repeating garbage at end of the string, till it can't. * If its not able to trim, then it will try to skip a char at end and then trim, a few times. * This ensures that even if there are multiple runs of garbage with different patterns, the * logic still tries to munch through them. diff --git a/tools/server/public_simplechat/readme.md b/tools/server/public_simplechat/readme.md index 24e026d455b..cc86d62494c 100644 --- a/tools/server/public_simplechat/readme.md +++ b/tools/server/public_simplechat/readme.md @@ -30,7 +30,7 @@ The UI follows a responsive web design so that the layout can adapt to available enough manner, in general. Allows developer/end-user to control some of the behaviour by updating gMe members from browser's devel-tool -console. Parallely some of the directly useful to end-user settings can also be changed using the provided +console. Parallelly some of the directly useful to end-user settings can also be changed using the provided settings ui. NOTE: Current web service api doesnt expose the model context length directly, so client logic doesnt provide @@ -38,7 +38,7 @@ any adaptive culling of old messages nor of replacing them with summary of their is a optional sliding window based chat logic, which provides a simple minded culling of old messages from the chat history before sending to the ai model. -NOTE: Wrt options sent with the request, it mainly sets temperature, max_tokens and optionaly stream for now. +NOTE: Wrt options sent with the request, it mainly sets temperature, max_tokens and optionally stream for now. However if someone wants they can update the js file or equivalent member in gMe as needed. NOTE: One may be able to use this to chat with openai api web-service /chat/completions endpoint, in a very @@ -88,7 +88,7 @@ Once inside then the end user needs to enter the same. This keeps the logic simple, while still giving flexibility to the end user to manage any templating/tagging requirement wrt their messages to the model. - * the logic doesnt insert newline at the begining and end wrt the prompt message generated. + * the logic doesnt insert newline at the beginning and end wrt the prompt message generated. However if the chat being sent to /completions end point has more than one role's message, then insert newline when moving from one role's message to the next role's message, so that it can be clearly identified/distinguished. @@ -101,8 +101,8 @@ Once inside Normally Completion mode doesnt need system prompt, while Chat mode can generate better/interesting responses with a suitable system prompt. * if chat.add_system_begin is used - * you cant change the system prompt, after it is has been submitted once along with user query. - * you cant set a system prompt, after you have submitted any user query + * you can't change the system prompt, after it is has been submitted once along with user query. + * you can't set a system prompt, after you have submitted any user query * if chat.add_system_anytime is used * one can change the system prompt any time during chat, by changing the contents of system prompt. * inturn the updated/changed system prompt will be inserted into the chat session. @@ -129,7 +129,7 @@ Once inside ### Reason behind this -The idea is to be easy enough to use for basic purposes, while also being simple and easily discernable +The idea is to be easy enough to use for basic purposes, while also being simple and easily discernible by developers who may not be from web frontend background (so inturn may not be familiar with template / end-use-specific-language-extensions driven flows) so that they can use it to explore/experiment things. @@ -167,7 +167,7 @@ It is attached to the document object. Some of these can also be updated using t messages that get inserted into prompt field wrt /Completion endpoint. bTrimGarbage - whether garbage repeatation at the end of the generated ai response, should be - trimmed or left as is. If enabled, it will be trimmed so that it wont be sent back as part of + trimmed or left as is. If enabled, it will be trimmed so that it won't be sent back as part of subsequent chat history. At the same time the actual trimmed text is shown to the user, once when it was generated, so user can check if any useful info/data was there in the response. @@ -244,7 +244,7 @@ full chat history. This way if there is any response with garbage/repeatation, i mess with things beyond the next question/request/query, in some ways. The trim garbage option also tries to help avoid issues with garbage in the context to an extent. -Set max_tokens to 1024, so that a relatively large previous reponse doesnt eat up the space +Set max_tokens to 1024, so that a relatively large previous response doesnt eat up the space available wrt next query-response. However dont forget that the server when started should also be started with a model context size of 1k or more, to be on safe side. diff --git a/tools/server/public_simplechat/simplechat.js b/tools/server/public_simplechat/simplechat.js index 2fcd24a860b..c67577d5ae7 100644 --- a/tools/server/public_simplechat/simplechat.js +++ b/tools/server/public_simplechat/simplechat.js @@ -318,7 +318,7 @@ class SimpleChat { } /** - * Allow setting of system prompt, but only at begining. + * Allow setting of system prompt, but only at beginning. * @param {string} sysPrompt * @param {string} msgTag */ @@ -333,7 +333,7 @@ class SimpleChat { console.error(`ERRR:SimpleChat:SC:${msgTag}:You need to specify system prompt before any user query, ignoring...`); } else { if (this.xchat[0].content !== sysPrompt) { - console.error(`ERRR:SimpleChat:SC:${msgTag}:You cant change system prompt, mid way through, ignoring...`); + console.error(`ERRR:SimpleChat:SC:${msgTag}:You can't change system prompt, mid way through, ignoring...`); } } } diff --git a/tools/server/public_simplechat/ui.mjs b/tools/server/public_simplechat/ui.mjs index b2d5b9aeab7..afa619a0663 100644 --- a/tools/server/public_simplechat/ui.mjs +++ b/tools/server/public_simplechat/ui.mjs @@ -44,7 +44,7 @@ export function el_create_button(id, callback, name=undefined, innerText=undefin } /** - * Create a para and set it up. Optionaly append it to a passed parent. + * Create a para and set it up. Optionally append it to a passed parent. * @param {string} text * @param {HTMLElement | undefined} elParent * @param {string | undefined} id @@ -111,7 +111,7 @@ export function el_creatediv_boolbutton(id, label, texts, defaultValue, cb, clas /** * Create a select ui element, with a set of options to select from. * * options: an object which contains name-value pairs - * * defaultOption: the value whose name should be choosen, by default. + * * defaultOption: the value whose name should be chosen, by default. * * cb : the call back returns the name string of the option selected. * * @param {string} id diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index ff3c6d3c2b0..13ea8c690f3 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1463,6 +1463,7 @@ json convert_anthropic_to_oai(const json & body) { json tool_calls = json::array(); json converted_content = json::array(); json tool_results = json::array(); + std::string reasoning_content; bool has_tool_calls = false; for (const auto & block : content) { @@ -1470,6 +1471,8 @@ json convert_anthropic_to_oai(const json & body) { if (type == "text") { converted_content.push_back(block); + } else if (type == "thinking") { + reasoning_content += json_value(block, "thinking", std::string()); } else if (type == "image") { json source = json_value(block, "source", json::object()); std::string source_type = json_value(source, "type", std::string()); @@ -1528,16 +1531,19 @@ json convert_anthropic_to_oai(const json & body) { } } - if (!converted_content.empty() || has_tool_calls) { + if (!converted_content.empty() || has_tool_calls || !reasoning_content.empty()) { json new_msg = {{"role", role}}; if (!converted_content.empty()) { new_msg["content"] = converted_content; - } else if (has_tool_calls) { + } else if (has_tool_calls || !reasoning_content.empty()) { new_msg["content"] = ""; } if (!tool_calls.empty()) { new_msg["tool_calls"] = tool_calls; } + if (!reasoning_content.empty()) { + new_msg["reasoning_content"] = reasoning_content; + } oai_messages.push_back(new_msg); } diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index aafed495020..9dbd6d798a3 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -12,6 +12,7 @@ #include "mtmd.h" #include "mtmd-helper.h" +#include #include #include #include @@ -2348,8 +2349,10 @@ struct server_context_impl { const auto it = std::find_if( slot.prompt.checkpoints.rbegin(), slot.prompt.checkpoints.rend(), - [&](const auto & cur) { + [&, func_name = __func__](const auto & cur) { // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] + LOG_INF("slot %12.*s: id %2d | task %d | Checking checkpoint with [%d, %d] against %d...\n", 12, + func_name, (slot).id, ((slot).task ? (slot).task->id : -1), cur.pos_min, cur.pos_max, pos_min_thold); return cur.pos_min < pos_min_thold; } ); @@ -2533,47 +2536,65 @@ struct server_context_impl { slot.i_batch = batch.n_tokens - 1; slot.init_sampler(); + SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens); + } else { + // only do non-end checkpoints if the "checkpoint every n tokens" option is set + do_checkpoint = do_checkpoint && params_base.checkpoint_every_nt > 0; + if (do_checkpoint) { + llama_pos last_checkpoint = 0; + if (!slot.prompt.checkpoints.empty()) { + last_checkpoint = slot.prompt.checkpoints.back().n_tokens; + } + do_checkpoint = do_checkpoint && slot.prompt.n_tokens() - batch.n_tokens - last_checkpoint >= params_base.checkpoint_every_nt; + if (do_checkpoint) { + SLT_INF(slot, "%d tokens since last checkpoint at %d, creating new checkpoint during processing at position %d\n", params_base.checkpoint_every_nt, last_checkpoint, slot.prompt.n_tokens()); + } + } + SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); + } - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); - const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); + const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); + const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); - // no need for empty or small checkpoints - do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); + // no need for empty or small checkpoints + do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); - // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); + // no need to create checkpoints that are too close together + do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); - // note: we create the checkpoint before calling llama_decode(), so the current batch is not - // yet processed and therefore it is not part of the checkpoint. - if (do_checkpoint) { - while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { - // make room for the new checkpoint, if needed - const auto & cur = slot.prompt.checkpoints.front(); + // note: we create the checkpoint before calling llama_decode(), so the current batch is not + // yet processed and therefore it is not part of the checkpoint. + if (do_checkpoint) { + while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { + // make room for the new checkpoint, if needed + const auto & cur = slot.prompt.checkpoints.front(); - SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); + SLT_WRN(slot, + "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 + ", size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); - slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); - } - - const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); + } - auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ - /*.pos_min = */ pos_min, - /*.pos_max = */ pos_max, - /*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens, - /*.data = */ std::vector(checkpoint_size), - }); + const size_t checkpoint_size = + llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ + /*.pos_min = */ pos_min, + /*.pos_max = */ pos_max, + /*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens, + /*.data = */ std::vector(checkpoint_size), + }); - SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", - (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); - } + llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, + LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens); - } else { - SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); + SLT_WRN(slot, + "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 + ", size = %.3f MiB)\n", + (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, + cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); } } diff --git a/tools/server/server-cors-proxy.h b/tools/server/server-cors-proxy.h new file mode 100644 index 00000000000..bca50b53df5 --- /dev/null +++ b/tools/server/server-cors-proxy.h @@ -0,0 +1,56 @@ +#pragma once + +#include "common.h" +#include "http.h" + +#include +#include +#include +#include + +#include "server-http.h" + +static server_http_res_ptr proxy_request(const server_http_req & req, std::string method) { + std::string target_url = req.get_param("url"); + common_http_url parsed_url = common_http_parse_url(target_url); + + if (parsed_url.host.empty()) { + throw std::runtime_error("invalid target URL: missing host"); + } + + if (parsed_url.path.empty()) { + parsed_url.path = "/"; + } + + if (!parsed_url.password.empty()) { + throw std::runtime_error("authentication in target URL is not supported"); + } + + if (parsed_url.scheme != "http" && parsed_url.scheme != "https") { + throw std::runtime_error("unsupported URL scheme in target URL: " + parsed_url.scheme); + } + + SRV_INF("proxying %s request to %s://%s%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.path.c_str()); + + auto proxy = std::make_unique( + method, + parsed_url.host, + parsed_url.scheme == "http" ? 80 : 443, + parsed_url.path, + req.headers, + req.body, + req.should_stop, + 600, // timeout_read (default to 10 minutes) + 600 // timeout_write (default to 10 minutes) + ); + + return proxy; +} + +static server_http_context::handler_t proxy_handler_post = [](const server_http_req & req) -> server_http_res_ptr { + return proxy_request(req, "POST"); +}; + +static server_http_context::handler_t proxy_handler_get = [](const server_http_req & req) -> server_http_res_ptr { + return proxy_request(req, "GET"); +}; diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index bc601237b7d..5f87ba9a212 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -1089,11 +1089,20 @@ server_http_proxy::server_http_proxy( int32_t timeout_write ) { // shared between reader and writer threads - auto cli = std::make_shared(host, port); + auto cli = std::make_shared(host, port); auto pipe = std::make_shared>(); + if (port == 443) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + cli.reset(new httplib::SSLClient(host, port)); +#else + throw std::runtime_error("HTTPS requested but CPPHTTPLIB_OPENSSL_SUPPORT is not defined"); +#endif + } + // setup Client - cli->set_connection_timeout(0, 200000); // 200 milliseconds + cli->set_follow_location(true); + cli->set_connection_timeout(5, 0); // 5 seconds cli->set_write_timeout(timeout_read, 0); // reversed for cli (client) vs srv (server) cli->set_read_timeout(timeout_write, 0); this->status = 500; // to be overwritten upon response @@ -1142,7 +1151,15 @@ server_http_proxy::server_http_proxy( req.method = method; req.path = path; for (const auto & [key, value] : headers) { - req.set_header(key, value); + if (key == "Accept-Encoding") { + // disable Accept-Encoding to avoid compressed responses + continue; + } + if (key == "Host" || key == "host") { + req.set_header(key, host); + } else { + req.set_header(key, value); + } } req.body = body; req.response_handler = response_handler; diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index d3aba18489b..32c0d8f481d 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -1,12 +1,12 @@ -#include "server-common.h" #include "server-task.h" +#include "chat.h" #include "common.h" +#include "json-schema-to-grammar.h" #include "llama.h" -#include "chat.h" #include "sampling.h" #include "speculative.h" -#include "json-schema-to-grammar.h" +#include "server-common.h" using json = nlohmann::ordered_json; @@ -157,7 +157,8 @@ json task_params::to_json(bool only_metrics) const { common_chat_msg task_result_state::update_chat_msg( const std::string & text_added, bool is_partial, - std::vector & diffs) { + std::vector & diffs, + bool filter_tool_calls) { generated_text += text_added; auto msg_prv_copy = chat_msg; SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); @@ -168,7 +169,64 @@ common_chat_msg task_result_state::update_chat_msg( if (!new_msg.empty()) { new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id); chat_msg = new_msg; - diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, new_msg.empty() ? msg_prv_copy : new_msg); + auto all_diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, chat_msg); + + if (!filter_tool_calls) { + diffs = std::move(all_diffs); + } else { + for (auto & d : all_diffs) { + // If this is a new type of delta, flush all currently pending tool call names + for (size_t i = 0; i < chat_msg.tool_calls.size(); ++i) { + if (sent_tool_call_names.count(i) || chat_msg.tool_calls[i].name.empty()) { + continue; + } + if (d.tool_call_index != i || !d.tool_call_delta.arguments.empty()) { + common_chat_msg_diff header; + header.tool_call_index = i; + header.tool_call_delta.id = chat_msg.tool_calls[i].id; + header.tool_call_delta.name = chat_msg.tool_calls[i].name; + diffs.push_back(std::move(header)); + sent_tool_call_names.insert(i); + } + } + + if (d.tool_call_index == std::string::npos) { + diffs.push_back(std::move(d)); + } else { + size_t i = d.tool_call_index; + if (sent_tool_call_names.count(i)) { + if (!d.tool_call_delta.arguments.empty()) { + d.tool_call_delta.name = ""; + d.tool_call_delta.id = ""; + diffs.push_back(std::move(d)); + } + } else { + // Not sent yet. + if (!d.tool_call_delta.arguments.empty() || !is_partial) { + d.tool_call_delta.name = chat_msg.tool_calls[i].name; + d.tool_call_delta.id = chat_msg.tool_calls[i].id; + diffs.push_back(std::move(d)); + sent_tool_call_names.insert(i); + } else { + // Suppress + } + } + } + } + // Final check at EOF + if (!is_partial) { + for (size_t i = 0; i < chat_msg.tool_calls.size(); ++i) { + if (!sent_tool_call_names.count(i) && !chat_msg.tool_calls[i].name.empty()) { + common_chat_msg_diff header; + header.tool_call_index = i; + header.tool_call_delta.id = chat_msg.tool_calls[i].id; + header.tool_call_delta.name = chat_msg.tool_calls[i].name; + diffs.push_back(std::move(header)); + sent_tool_call_names.insert(i); + } + } + } + } } return chat_msg; } diff --git a/tools/server/server-task.h b/tools/server/server-task.h index e2e3e5a5828..1e342531d8e 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -98,6 +98,7 @@ struct task_result_state { common_chat_msg chat_msg; std::string generated_text; // append new chunks of generated text here std::vector generated_tool_call_ids; + std::unordered_set sent_tool_call_names; // for OpenAI Responses and Anthropic streaming API: // track output item / content block state across chunks @@ -120,7 +121,8 @@ struct task_result_state { common_chat_msg update_chat_msg( const std::string & text_added, bool is_partial, - std::vector & diffs); + std::vector & diffs, + bool filter_tool_calls = false); }; struct server_task { diff --git a/tools/server/server.cpp b/tools/server/server.cpp index f353dcdde7b..0bd6fda17d2 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -1,6 +1,7 @@ #include "server-context.h" #include "server-http.h" #include "server-models.h" +#include "server-cors-proxy.h" #include "arg.h" #include "common.h" @@ -8,6 +9,7 @@ #include "log.h" #include +#include #include #include #include // for std::thread::hardware_concurrency @@ -67,6 +69,8 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t } int main(int argc, char ** argv) { + std::setlocale(LC_NUMERIC, "C"); + // own arguments required by this example common_params params; @@ -198,6 +202,15 @@ int main(int argc, char ** argv) { // Save & load slots ctx_http.get ("/slots", ex_wrapper(routes.get_slots)); ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots)); + // CORS proxy (EXPERIMENTAL, only used by the Web UI for MCP) + if (params.webui_mcp_proxy) { + SRV_WRN("%s", "-----------------\n"); + SRV_WRN("%s", "CORS proxy is enabled, do not expose server to untrusted environments\n"); + SRV_WRN("%s", "This feature is EXPERIMENTAL and may be removed or changed in future versions\n"); + SRV_WRN("%s", "-----------------\n"); + ctx_http.get ("/cors-proxy", ex_wrapper(proxy_handler_get)); + ctx_http.post("/cors-proxy", ex_wrapper(proxy_handler_post)); + } // // Start the server diff --git a/tools/server/tests/README.md b/tools/server/tests/README.md index a60d3f8ea1a..f566b43644b 100644 --- a/tools/server/tests/README.md +++ b/tools/server/tests/README.md @@ -57,7 +57,7 @@ To run a single test: ./tests.sh unit/test_chat_completion.py::test_invalid_chat_completion_req ``` -Hint: You can compile and run test in single command, useful for local developement: +Hint: You can compile and run test in single command, useful for local development: ```shell cmake --build build -j --target llama-server && ./tools/server/tests/tests.sh diff --git a/tools/server/tests/unit/test_compat_anthropic.py b/tools/server/tests/unit/test_compat_anthropic.py index e16e0235c64..93ff03be6b4 100644 --- a/tools/server/tests/unit/test_compat_anthropic.py +++ b/tools/server/tests/unit/test_compat_anthropic.py @@ -809,6 +809,139 @@ def test_anthropic_vs_openai_different_response_format(): # Extended thinking tests with reasoning models +# The next two tests cover the input path (conversation history): +# Client sends thinking blocks -> convert_anthropic_to_oai -> reasoning_content -> template + +def test_anthropic_thinking_history_in_count_tokens(): + """Test that interleaved thinking blocks in conversation history are not dropped during conversion.""" + global server + server.jinja = True + server.chat_template_file = '../../../models/templates/Qwen-Qwen3-0.6B.jinja' + server.start() + + tool = { + "name": "list_files", + "description": "List files", + "input_schema": { + "type": "object", + "properties": {"path": {"type": "string"}}, + "required": ["path"] + } + } + + messages_without_thinking = [ + {"role": "user", "content": "Fix the bug"}, + { + "role": "assistant", + "content": [ + {"type": "tool_use", "id": "call_1", "name": "list_files", "input": {"path": "."}} + ] + }, + { + "role": "user", + "content": [ + {"type": "tool_result", "tool_use_id": "call_1", "content": "main.py"} + ] + }, + ] + + messages_with_thinking = [ + {"role": "user", "content": "Fix the bug"}, + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "I should check the project structure first to understand the codebase layout."}, + {"type": "tool_use", "id": "call_1", "name": "list_files", "input": {"path": "."}} + ] + }, + { + "role": "user", + "content": [ + {"type": "tool_result", "tool_use_id": "call_1", "content": "main.py"} + ] + }, + ] + + res_without = server.make_request("POST", "/v1/messages/count_tokens", data={ + "model": "test", + "messages": messages_without_thinking, + "tools": [tool], + }) + assert res_without.status_code == 200, f"Expected 200: {res_without.body}" + + res_with = server.make_request("POST", "/v1/messages/count_tokens", data={ + "model": "test", + "messages": messages_with_thinking, + "tools": [tool], + }) + assert res_with.status_code == 200, f"Expected 200: {res_with.body}" + + # Thinking blocks should increase the token count + assert res_with.body["input_tokens"] > res_without.body["input_tokens"], \ + f"Expected more tokens with thinking ({res_with.body['input_tokens']}) than without ({res_without.body['input_tokens']})" + + +def test_anthropic_thinking_history_in_template(): + """Test that reasoning_content from converted interleaved thinking blocks renders in the prompt.""" + global server + server.jinja = True + server.chat_template_file = '../../../models/templates/Qwen-Qwen3-0.6B.jinja' + server.start() + + reasoning_1 = "I should check the project structure first." + reasoning_2 = "Now I need to read the main file." + + res = server.make_request("POST", "/apply-template", data={ + "messages": [ + {"role": "user", "content": "Fix the bug in main.py"}, + { + "role": "assistant", + "content": "", + "reasoning_content": reasoning_1, + "tool_calls": [{ + "id": "call_1", + "type": "function", + "function": {"name": "list_files", "arguments": "{\"path\": \".\"}"} + }] + }, + {"role": "tool", "tool_call_id": "call_1", "content": "main.py\nutils.py"}, + { + "role": "assistant", + "content": "", + "reasoning_content": reasoning_2, + "tool_calls": [{ + "id": "call_2", + "type": "function", + "function": {"name": "read_file", "arguments": "{\"path\": \"main.py\"}"} + }] + }, + {"role": "tool", "tool_call_id": "call_2", "content": "print('hello')"}, + ], + "tools": [{ + "type": "function", + "function": { + "name": "list_files", + "description": "List files", + "parameters": {"type": "object", "properties": {"path": {"type": "string"}}, "required": ["path"]} + } + }, { + "type": "function", + "function": { + "name": "read_file", + "description": "Read a file", + "parameters": {"type": "object", "properties": {"path": {"type": "string"}}, "required": ["path"]} + } + }], + }) + assert res.status_code == 200, f"Expected 200, got {res.status_code}: {res.body}" + prompt = res.body["prompt"] + + # Both reasoning_content values should be rendered in tags + assert reasoning_1 in prompt, f"Expected first reasoning text in prompt: {prompt}" + assert reasoning_2 in prompt, f"Expected second reasoning text in prompt: {prompt}" + assert prompt.count("") >= 2, f"Expected at least 2 blocks in prompt: {prompt}" + + @pytest.mark.slow @pytest.mark.parametrize("stream", [False, True]) def test_anthropic_thinking_with_reasoning_model(stream): diff --git a/tools/server/tests/unit/test_tool_call.py b/tools/server/tests/unit/test_tool_call.py index b8f0f10863f..ba41cd44ea9 100755 --- a/tools/server/tests/unit/test_tool_call.py +++ b/tools/server/tests/unit/test_tool_call.py @@ -100,18 +100,19 @@ def do_test_completion_with_required_tool_tiny(server: ServerProcess, tool: dict assert choice["message"].get("content") in (None, ""), f'Expected no content in {choice["message"]}' # assert len(tool_call.get("id", "")) > 0, f'Expected non empty tool call id in {tool_call}' expected_function_name = "python" if tool["type"] == "code_interpreter" else tool["function"]["name"] - assert expected_function_name == tool_call["function"]["name"] + assert expected_function_name == tool_call["function"]["name"], f'Expected tool name to be {tool_call["function"]["name"]} in {choice["message"]}' actual_arguments = tool_call["function"]["arguments"] - assert isinstance(actual_arguments, str) + assert isinstance(actual_arguments, dict) or isinstance(actual_arguments, str), f'Expected arguments to be a dict or str, got: {actual_arguments}' if argument_key is not None: - actual_arguments = json.loads(actual_arguments) - assert argument_key in actual_arguments, f"tool arguments: {json.dumps(actual_arguments)}, expected: {argument_key}" + if (isinstance(actual_arguments, str)): + actual_arguments = json.loads(actual_arguments) + assert argument_key in actual_arguments, f"tool arguments: {actual_arguments}, expected: {argument_key}" @pytest.mark.parametrize("stream", [CompletionMode.NORMAL, CompletionMode.STREAMED]) @pytest.mark.parametrize("template_name,tool,argument_key", [ - ("google-gemma-2-2b-it", TEST_TOOL, "success"), - ("google-gemma-2-2b-it", TEST_TOOL, "success"), + ("Qwen3-Coder", TEST_TOOL, "success"), + ("Qwen3-Coder", TEST_TOOL, "success"), ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"), ("meta-llama-Llama-3.3-70B-Instruct", TEST_TOOL, "success"), ("meta-llama-Llama-3.3-70B-Instruct", PYTHON_TOOL, "code"), diff --git a/tools/server/webui/docs/architecture/high-level-architecture-simplified.md b/tools/server/webui/docs/architecture/high-level-architecture-simplified.md index a6cb1e9c394..500f477c9a4 100644 --- a/tools/server/webui/docs/architecture/high-level-architecture-simplified.md +++ b/tools/server/webui/docs/architecture/high-level-architecture-simplified.md @@ -12,9 +12,13 @@ flowchart TB C_Form["ChatForm"] C_Messages["ChatMessages"] C_Message["ChatMessage"] + C_ChatMessageAgenticContent["ChatMessageAgenticContent"] C_MessageEditForm["ChatMessageEditForm"] C_ModelsSelector["ModelsSelector"] C_Settings["ChatSettings"] + C_McpSettings["McpServersSettings"] + C_McpResourceBrowser["McpResourceBrowser"] + C_McpServersSelector["McpServersSelector"] end subgraph Hooks["🪝 Hooks"] @@ -24,10 +28,13 @@ flowchart TB subgraph Stores["🗄️ Stores"] S1["chatStore
Chat interactions & streaming"] - S2["conversationsStore
Conversation data & messages"] + SA["agenticStore
Multi-turn agentic loop orchestration"] + S2["conversationsStore
Conversation data, messages & MCP overrides"] S3["modelsStore
Model selection & loading"] S4["serverStore
Server props & role detection"] - S5["settingsStore
User configuration"] + S5["settingsStore
User configuration incl. MCP"] + S6["mcpStore
MCP servers, tools, prompts"] + S7["mcpResourceStore
MCP resources & attachments"] end subgraph Services["⚙️ Services"] @@ -36,11 +43,12 @@ flowchart TB SV3["PropsService"] SV4["DatabaseService"] SV5["ParameterSyncService"] + SV6["MCPService
protocol operations"] end subgraph Storage["💾 Storage"] ST1["IndexedDB
conversations, messages"] - ST2["LocalStorage
config, userOverrides"] + ST2["LocalStorage
config, userOverrides, mcpServers"] end subgraph APIs["🌐 llama-server API"] @@ -50,15 +58,27 @@ flowchart TB API4["/v1/models"] end + subgraph ExternalMCP["🔌 External MCP Servers"] + EXT1["MCP Server 1
WebSocket/HTTP/SSE"] + EXT2["MCP Server N"] + end + %% Routes → Components R1 & R2 --> C_Screen RL --> C_Sidebar + %% Layout runs MCP health checks + RL --> S6 + %% Component hierarchy C_Screen --> C_Form & C_Messages & C_Settings C_Messages --> C_Message + C_Message --> C_ChatMessageAgenticContent C_Message --> C_MessageEditForm C_Form & C_MessageEditForm --> C_ModelsSelector + C_Form --> C_McpServersSelector + C_Settings --> C_McpSettings + C_McpSettings --> C_McpResourceBrowser %% Components → Hooks → Stores C_Form & C_Messages --> H1 & H2 @@ -70,6 +90,15 @@ flowchart TB C_Sidebar --> S2 C_ModelsSelector --> S3 & S4 C_Settings --> S5 + C_McpSettings --> S6 + C_McpResourceBrowser --> S6 & S7 + C_McpServersSelector --> S6 + C_Form --> S6 + + %% chatStore → agenticStore → mcpStore (agentic loop) + S1 --> SA + SA --> SV1 + SA --> S6 %% Stores → Services S1 --> SV1 & SV4 @@ -77,6 +106,8 @@ flowchart TB S3 --> SV2 & SV3 S4 --> SV3 S5 --> SV5 + S6 --> SV6 + S7 --> SV6 %% Services → Storage SV4 --> ST1 @@ -87,6 +118,9 @@ flowchart TB SV2 --> API3 & API4 SV3 --> API2 + %% MCP → External Servers + SV6 --> EXT1 & EXT2 + %% Styling classDef routeStyle fill:#e1f5fe,stroke:#01579b,stroke-width:2px classDef componentStyle fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px @@ -95,12 +129,17 @@ flowchart TB classDef serviceStyle fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px classDef storageStyle fill:#fce4ec,stroke:#c2185b,stroke-width:2px classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px + classDef mcpStyle fill:#e0f2f1,stroke:#00695c,stroke-width:2px + classDef agenticStyle fill:#e8eaf6,stroke:#283593,stroke-width:2px + classDef externalStyle fill:#f3e5f5,stroke:#6a1b9a,stroke-width:2px,stroke-dasharray: 5 5 class R1,R2,RL routeStyle - class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message,C_MessageEditForm,C_ModelsSelector,C_Settings componentStyle + class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message,C_ChatMessageAgenticContent,C_MessageEditForm,C_ModelsSelector,C_Settings componentStyle + class C_McpSettings,C_McpResourceBrowser,C_McpServersSelector componentStyle class H1,H2 hookStyle - class S1,S2,S3,S4,S5 storeStyle - class SV1,SV2,SV3,SV4,SV5 serviceStyle + class S1,S2,S3,S4,S5,SA,S6,S7 storeStyle + class SV1,SV2,SV3,SV4,SV5,SV6 serviceStyle class ST1,ST2 storageStyle class API1,API2,API3,API4 apiStyle + class EXT1,EXT2 externalStyle ``` diff --git a/tools/server/webui/docs/architecture/high-level-architecture.md b/tools/server/webui/docs/architecture/high-level-architecture.md index c5ec4d69095..42ddb3f4f5b 100644 --- a/tools/server/webui/docs/architecture/high-level-architecture.md +++ b/tools/server/webui/docs/architecture/high-level-architecture.md @@ -22,6 +22,13 @@ end C_ModelsSelector["ModelsSelector"] C_Settings["ChatSettings"] end + subgraph MCPComponents["MCP UI"] + C_McpSettings["McpServersSettings"] + C_McpServerCard["McpServerCard"] + C_McpResourceBrowser["McpResourceBrowser"] + C_McpResourcePreview["McpResourcePreview"] + C_McpServersSelector["McpServersSelector"] + end end subgraph Hooks["🪝 Hooks"] @@ -43,14 +50,20 @@ end S1Edit["Editing:
editAssistantMessage()
editUserMessagePreserveResponses()
editMessageWithBranching()
clearEditMode()
isEditModeActive()
getAddFilesHandler()
setEditModeActive()"] S1Utils["Utilities:
getApiOptions()
parseTimingData()
getOrCreateAbortController()
getConversationModel()"] end + subgraph SA["agenticStore"] + SAState["State:
sessions (Map)
isAnyRunning"] + SASession["Session Management:
getSession()
updateSession()
clearSession()
getActiveSessions()
isRunning()
currentTurn()
totalToolCalls()
lastError()
streamingToolCall()"] + SAConfig["Configuration:
getConfig()
maxTurns, maxToolPreviewLines"] + SAFlow["Agentic Loop:
runAgenticFlow()
executeAgenticLoop()
normalizeToolCalls()
emitToolCallResult()
extractBase64Attachments()"] + end subgraph S2["conversationsStore"] - S2State["State:
conversations
activeConversation
activeMessages
usedModalities
isInitialized
titleUpdateConfirmationCallback"] - S2Modal["Modalities:
getModalitiesUpToMessage()
calculateModalitiesFromMessages()"] + S2State["State:
conversations
activeConversation
activeMessages
isInitialized
pendingMcpServerOverrides
titleUpdateConfirmationCallback"] S2Lifecycle["Lifecycle:
initialize()
loadConversations()
clearActiveConversation()"] - S2ConvCRUD["Conversation CRUD:
createConversation()
loadConversation()
deleteConversation()
updateConversationName()
updateConversationTitleWithConfirmation()"] + S2ConvCRUD["Conversation CRUD:
createConversation()
loadConversation()
deleteConversation()
deleteAll()
updateConversationName()
updateConversationTitleWithConfirmation()"] S2MsgMgmt["Message Management:
refreshActiveMessages()
addMessageToActive()
updateMessageAtIndex()
findMessageIndex()
sliceActiveMessages()
removeMessageAtIndex()
getConversationMessages()"] S2Nav["Navigation:
navigateToSibling()
updateCurrentNode()
updateConversationTimestamp()"] - S2Export["Import/Export:
downloadConversation()
exportAllConversations()
importConversations()
triggerDownload()"] + S2McpOverrides["MCP Per-Chat Overrides:
getMcpServerOverride()
getAllMcpServerOverrides()
setMcpServerOverride()
toggleMcpServerForChat()
removeMcpServerOverride()
isMcpServerEnabledForChat()
clearPendingMcpServerOverrides()"] + S2Export["Import/Export:
downloadConversation()
exportAllConversations()
importConversations()
importConversationsData()
triggerDownload()"] S2Utils["Utilities:
setTitleUpdateConfirmationCallback()"] end subgraph S3["modelsStore"] @@ -77,6 +90,21 @@ end S5Sync["Server Sync:
syncWithServerDefaults()
forceSyncWithServerDefaults()"] S5Utils["Utilities:
getConfig()
getAllConfig()
getParameterInfo()
getParameterDiff()
getServerDefaults()
clearAllUserOverrides()"] end + subgraph S6["mcpStore"] + S6State["State:
isInitializing, error
toolCount, connectedServers
healthChecks (Map)
connections (Map)
toolsIndex (Map)"] + S6Lifecycle["Lifecycle:
ensureInitialized()
initialize()
shutdown()
acquireConnection()
releaseConnection()"] + S6Health["Health Checks:
runHealthCheck()
runHealthChecksForServers()
updateHealthCheck()
getHealthCheckState()
clearHealthCheck()"] + S6Servers["Server Management:
getServers()
addServer()
updateServer()
removeServer()
getServerById()
getServerDisplayName()"] + S6Tools["Tool Operations:
getToolDefinitionsForLLM()
getToolNames()
hasTool()
getToolServer()
executeTool()
executeToolByName()"] + S6Prompts["Prompt Operations:
getAllPrompts()
getPrompt()
hasPromptsCapability()
getPromptCompletions()"] + end + subgraph S7["mcpResourceStore"] + S7State["State:
serverResources (Map)
cachedResources (Map)
subscriptions (Map)
attachments[]
isLoading"] + S7Resources["Resource Discovery:
setServerResources()
getServerResources()
getAllResourceInfos()
getAllTemplateInfos()
clearServerResources()"] + S7Cache["Caching:
cacheResourceContent()
getCachedContent()
invalidateCache()
clearCache()"] + S7Subs["Subscriptions:
addSubscription()
removeSubscription()
isSubscribed()
handleResourceUpdate()"] + S7Attach["Attachments:
addAttachment()
updateAttachmentContent()
removeAttachment()
clearAttachments()
toMessageExtras()"] + end subgraph ReactiveExports["⚡ Reactive Exports"] direction LR @@ -95,12 +123,19 @@ end RE9c["setEditModeActive()"] RE9d["clearEditMode()"] end + subgraph AgenticExports["agenticStore"] + REA1["agenticIsRunning()"] + REA2["agenticCurrentTurn()"] + REA3["agenticTotalToolCalls()"] + REA4["agenticLastError()"] + REA5["agenticStreamingToolCall()"] + REA6["agenticIsAnyRunning()"] + end subgraph ConvExports["conversationsStore"] RE10["conversations()"] RE11["activeConversation()"] RE12["activeMessages()"] RE13["isConversationsInitialized()"] - RE14["usedModalities()"] end subgraph ModelsExports["modelsStore"] RE15["modelOptions()"] @@ -131,6 +166,13 @@ end RE36["theme()"] RE37["isInitialized()"] end + subgraph MCPExports["mcpStore / mcpResourceStore"] + RE38["mcpResources()"] + RE39["mcpResourceAttachments()"] + RE40["mcpHasResourceAttachments()"] + RE41["mcpTotalResourceCount()"] + RE42["mcpResourcesLoading()"] + end end end @@ -138,9 +180,9 @@ end direction TB subgraph SV1["ChatService"] SV1Msg["Messaging:
sendMessage()"] - SV1Stream["Streaming:
handleStreamResponse()
parseSSEChunk()"] - SV1Convert["Conversion:
convertMessageToChatData()
convertExtraToApiFormat()"] - SV1Utils["Utilities:
extractReasoningContent()
getServerProps()
getModels()"] + SV1Stream["Streaming:
handleStreamResponse()
handleNonStreamResponse()"] + SV1Convert["Conversion:
convertDbMessageToApiChatMessageData()
mergeToolCallDeltas()"] + SV1Utils["Utilities:
stripReasoningContent()
extractModelName()
parseErrorResponse()"] end subgraph SV2["ModelsService"] SV2List["Listing:
list()
listRouter()"] @@ -152,7 +194,7 @@ end end subgraph SV4["DatabaseService"] SV4Conv["Conversations:
createConversation()
getConversation()
getAllConversations()
updateConversation()
deleteConversation()"] - SV4Msg["Messages:
createMessageBranch()
createRootMessage()
getConversationMessages()
updateMessage()
deleteMessage()
deleteMessageCascading()"] + SV4Msg["Messages:
createMessageBranch()
createRootMessage()
createSystemMessage()
getConversationMessages()
updateMessage()
deleteMessage()
deleteMessageCascading()"] SV4Node["Navigation:
updateCurrentNode()"] SV4Import["Import:
importConversations()"] end @@ -162,6 +204,19 @@ end SV5Info["Info:
getParameterInfo()
canSyncParameter()
getSyncableParameterKeys()
validateServerParameter()"] SV5Diff["Diff:
createParameterDiff()"] end + subgraph SV6["MCPService"] + SV6Transport["Transport:
createTransport()
WebSocket / StreamableHTTP / SSE"] + SV6Conn["Connection:
connect()
disconnect()"] + SV6Tools["Tools:
listTools()
callTool()"] + SV6Prompts["Prompts:
listPrompts()
getPrompt()"] + SV6Resources["Resources:
listResources()
listResourceTemplates()
readResource()
subscribeResource()
unsubscribeResource()"] + SV6Complete["Completions:
complete()"] + end + end + + subgraph ExternalMCP["🔌 External MCP Servers"] + EXT1["MCP Server 1
(WebSocket/StreamableHTTP/SSE)"] + EXT2["MCP Server N"] end subgraph Storage["💾 Storage"] @@ -171,6 +226,7 @@ end ST5["LocalStorage"] ST6["config"] ST7["userOverrides"] + ST8["mcpServers"] end subgraph APIs["🌐 llama-server API"] @@ -185,6 +241,9 @@ end R2 --> C_Screen RL --> C_Sidebar + %% Layout runs MCP health checks on startup + RL --> S6 + %% Component hierarchy C_Screen --> C_Form & C_Messages & C_Settings C_Messages --> C_Message @@ -194,8 +253,15 @@ end C_MessageEditForm --> C_Attach C_Form --> C_ModelsSelector C_Form --> C_Attach + C_Form --> C_McpServersSelector C_Message --> C_Attach + %% MCP Components hierarchy + C_Settings --> C_McpSettings + C_McpSettings --> C_McpServerCard + C_McpServerCard --> C_McpResourceBrowser + C_McpResourceBrowser --> C_McpResourcePreview + %% Components use Hooks C_Form --> H1 C_Message --> H1 & H2 @@ -210,17 +276,29 @@ end C_Screen --> S1 & S2 C_Messages --> S2 C_Message --> S1 & S2 & S3 - C_Form --> S1 & S3 + C_Form --> S1 & S3 & S6 C_Sidebar --> S2 C_ModelsSelector --> S3 & S4 C_Settings --> S5 + C_McpSettings --> S6 + C_McpServerCard --> S6 + C_McpResourceBrowser --> S6 & S7 + C_McpServersSelector --> S6 %% Stores export Reactive State S1 -. exports .-> ChatExports + SA -. exports .-> AgenticExports S2 -. exports .-> ConvExports S3 -. exports .-> ModelsExports S4 -. exports .-> ServerExports S5 -. exports .-> SettingsExports + S6 -. exports .-> MCPExports + S7 -. exports .-> MCPExports + + %% chatStore → agenticStore (agentic loop orchestration) + S1 --> SA + SA --> SV1 + SA --> S6 %% Stores use Services S1 --> SV1 & SV4 @@ -228,28 +306,35 @@ end S3 --> SV2 & SV3 S4 --> SV3 S5 --> SV5 + S6 --> SV6 + S7 --> SV6 %% Services to Storage SV4 --> ST1 ST1 --> ST2 & ST3 SV5 --> ST5 - ST5 --> ST6 & ST7 + ST5 --> ST6 & ST7 & ST8 %% Services to APIs SV1 --> API1 SV2 --> API3 & API4 SV3 --> API2 + %% MCP → External Servers + SV6 --> EXT1 & EXT2 + %% Styling classDef routeStyle fill:#e1f5fe,stroke:#01579b,stroke-width:2px classDef componentStyle fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px classDef componentGroupStyle fill:#e1bee7,stroke:#7b1fa2,stroke-width:1px + classDef hookStyle fill:#fff8e1,stroke:#ff8f00,stroke-width:2px classDef storeStyle fill:#fff3e0,stroke:#e65100,stroke-width:2px classDef stateStyle fill:#ffe0b2,stroke:#e65100,stroke-width:1px classDef methodStyle fill:#ffecb3,stroke:#e65100,stroke-width:1px classDef reactiveStyle fill:#fffde7,stroke:#f9a825,stroke-width:1px classDef serviceStyle fill:#e8f5e9,stroke:#2e7d32,stroke-width:2px classDef serviceMStyle fill:#c8e6c9,stroke:#2e7d32,stroke-width:1px + classDef externalStyle fill:#f3e5f5,stroke:#6a1b9a,stroke-width:2px,stroke-dasharray: 5 5 classDef storageStyle fill:#fce4ec,stroke:#c2185b,stroke-width:2px classDef apiStyle fill:#e3f2fd,stroke:#1565c0,stroke-width:2px @@ -257,23 +342,32 @@ end class C_Sidebar,C_Screen,C_Form,C_Messages,C_Message,C_MessageUser,C_MessageEditForm componentStyle class C_ModelsSelector,C_Settings componentStyle class C_Attach componentStyle - class H1,H2,H3 methodStyle - class LayoutComponents,ChatUIComponents componentGroupStyle - class Hooks storeStyle - class S1,S2,S3,S4,S5 storeStyle - class S1State,S2State,S3State,S4State,S5State stateStyle + class C_McpSettings,C_McpServerCard,C_McpResourceBrowser,C_McpResourcePreview,C_McpServersSelector componentStyle + class H1,H2,H3 hookStyle + class LayoutComponents,ChatUIComponents,MCPComponents componentGroupStyle + class Hooks hookStyle + classDef agenticStyle fill:#e8eaf6,stroke:#283593,stroke-width:2px + classDef agenticMethodStyle fill:#c5cae9,stroke:#283593,stroke-width:1px + + class S1,S2,S3,S4,S5,SA,S6,S7 storeStyle + class S1State,S2State,S3State,S4State,S5State,SAState,S6State,S7State stateStyle class S1Msg,S1Regen,S1Edit,S1Stream,S1LoadState,S1ProcState,S1Error,S1Utils methodStyle - class S2Lifecycle,S2ConvCRUD,S2MsgMgmt,S2Nav,S2Modal,S2Export,S2Utils methodStyle + class SASession,SAConfig,SAFlow methodStyle + class S2Lifecycle,S2ConvCRUD,S2MsgMgmt,S2Nav,S2McpOverrides,S2Export,S2Utils methodStyle class S3Getters,S3Modal,S3Status,S3Fetch,S3Select,S3LoadUnload,S3Utils methodStyle class S4Getters,S4Data,S4Utils methodStyle class S5Lifecycle,S5Update,S5Reset,S5Sync,S5Utils methodStyle - class ChatExports,ConvExports,ModelsExports,ServerExports,SettingsExports reactiveStyle - class SV1,SV2,SV3,SV4,SV5 serviceStyle + class S6Lifecycle,S6Health,S6Servers,S6Tools,S6Prompts methodStyle + class S7Resources,S7Cache,S7Subs,S7Attach methodStyle + class ChatExports,AgenticExports,ConvExports,ModelsExports,ServerExports,SettingsExports,MCPExports reactiveStyle + class SV1,SV2,SV3,SV4,SV5,SV6 serviceStyle + class SV6Transport,SV6Conn,SV6Tools,SV6Prompts,SV6Resources,SV6Complete serviceMStyle + class EXT1,EXT2 externalStyle class SV1Msg,SV1Stream,SV1Convert,SV1Utils serviceMStyle class SV2List,SV2LoadUnload,SV2Status serviceMStyle class SV3Fetch serviceMStyle class SV4Conv,SV4Msg,SV4Node,SV4Import serviceMStyle class SV5Extract,SV5Merge,SV5Info,SV5Diff serviceMStyle - class ST1,ST2,ST3,ST5,ST6,ST7 storageStyle + class ST1,ST2,ST3,ST5,ST6,ST7,ST8 storageStyle class API1,API2,API3,API4 apiStyle ``` diff --git a/tools/server/webui/docs/flows/chat-flow.md b/tools/server/webui/docs/flows/chat-flow.md index 05e1df385a7..296693c6a54 100644 --- a/tools/server/webui/docs/flows/chat-flow.md +++ b/tools/server/webui/docs/flows/chat-flow.md @@ -2,8 +2,10 @@ sequenceDiagram participant UI as 🧩 ChatForm / ChatMessage participant chatStore as 🗄️ chatStore + participant agenticStore as 🗄️ agenticStore participant convStore as 🗄️ conversationsStore participant settingsStore as 🗄️ settingsStore + participant mcpStore as 🗄️ mcpStore participant ChatSvc as ⚙️ ChatService participant DbSvc as ⚙️ DatabaseService participant API as 🌐 /v1/chat/completions @@ -25,6 +27,9 @@ sequenceDiagram Note over convStore: → see conversations-flow.mmd end + chatStore->>mcpStore: consumeResourceAttachmentsAsExtras() + Note right of mcpStore: Converts pending MCP resource
attachments into message extras + chatStore->>chatStore: addMessage("user", content, extras) chatStore->>DbSvc: createMessageBranch(userMsg, parentId) chatStore->>convStore: addMessageToActive(userMsg) @@ -38,7 +43,7 @@ sequenceDiagram deactivate chatStore %% ═══════════════════════════════════════════════════════════════════════════ - Note over UI,API: 🌊 STREAMING + Note over UI,API: 🌊 STREAMING (with agentic flow detection) %% ═══════════════════════════════════════════════════════════════════════════ activate chatStore @@ -52,10 +57,17 @@ sequenceDiagram chatStore->>chatStore: getApiOptions() Note right of chatStore: Merge from settingsStore.config:
temperature, max_tokens, top_p, etc. - chatStore->>ChatSvc: sendMessage(messages, options, signal) + alt agenticConfig.enabled && mcpStore has connected servers + chatStore->>agenticStore: runAgenticFlow(convId, messages, assistantMsg, options, signal) + Note over agenticStore: Multi-turn agentic loop:
1. Call ChatService.sendMessage()
2. If response has tool_calls → execute via mcpStore
3. Append tool results as messages
4. Loop until no more tool_calls or maxTurns
→ see agentic flow details below + agenticStore-->>chatStore: final response with timings + else standard (non-agentic) flow + chatStore->>ChatSvc: sendMessage(messages, options, signal) + end + activate ChatSvc - ChatSvc->>ChatSvc: convertMessageToChatData(messages) + ChatSvc->>ChatSvc: convertDbMessageToApiChatMessageData(messages) Note right of ChatSvc: DatabaseMessage[] → ApiChatMessageData[]
Process attachments (images, PDFs, audio) ChatSvc->>API: POST /v1/chat/completions @@ -63,7 +75,7 @@ sequenceDiagram loop SSE chunks API-->>ChatSvc: data: {"choices":[{"delta":{...}}]} - ChatSvc->>ChatSvc: parseSSEChunk(line) + ChatSvc->>ChatSvc: handleStreamResponse(response) alt content chunk ChatSvc-->>chatStore: onChunk(content) @@ -154,12 +166,15 @@ sequenceDiagram Note over UI,API: ✏️ EDIT USER MESSAGE %% ═══════════════════════════════════════════════════════════════════════════ - UI->>chatStore: editUserMessagePreserveResponses(msgId, newContent) + UI->>chatStore: editMessageWithBranching(msgId, newContent, extras) activate chatStore chatStore->>chatStore: Get parent of target message chatStore->>DbSvc: createMessageBranch(editedMsg, parentId) chatStore->>convStore: refreshActiveMessages() Note right of chatStore: Creates new branch, original preserved + chatStore->>chatStore: createAssistantMessage(editedMsg.id) + chatStore->>chatStore: streamChatCompletion(...) + Note right of chatStore: Automatically regenerates response deactivate chatStore %% ═══════════════════════════════════════════════════════════════════════════ @@ -171,4 +186,43 @@ sequenceDiagram Note right of chatStore: errorDialogState = {type: 'timeout'|'server', message} chatStore->>convStore: removeMessageAtIndex(failedMsgIdx) chatStore->>DbSvc: deleteMessage(failedMsgId) + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,API: 🤖 AGENTIC LOOP (when agenticConfig.enabled) + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over agenticStore: agenticStore.runAgenticFlow(convId, messages, assistantMsg, options, signal) + activate agenticStore + agenticStore->>agenticStore: getSession(convId) or create new + agenticStore->>agenticStore: updateSession(turn: 0, running: true) + + loop executeAgenticLoop (until no tool_calls or maxTurns) + agenticStore->>agenticStore: turn++ + agenticStore->>ChatSvc: sendMessage(messages, options, signal) + ChatSvc->>API: POST /v1/chat/completions + API-->>ChatSvc: response with potential tool_calls + ChatSvc-->>agenticStore: onComplete(content, reasoning, timings, toolCalls) + + alt response has tool_calls + agenticStore->>agenticStore: normalizeToolCalls(toolCalls) + loop for each tool_call + agenticStore->>agenticStore: updateSession(streamingToolCall) + agenticStore->>mcpStore: executeTool(mcpCall, signal) + mcpStore-->>agenticStore: tool result + agenticStore->>agenticStore: extractBase64Attachments(result) + agenticStore->>agenticStore: emitToolCallResult(convId, ...) + agenticStore->>convStore: addMessageToActive(toolResultMsg) + agenticStore->>DbSvc: createMessageBranch(toolResultMsg) + end + agenticStore->>agenticStore: Create new assistantMsg for next turn + Note right of agenticStore: Continue loop with updated messages + else no tool_calls (final response) + agenticStore->>agenticStore: buildFinalTimings(allTurns) + Note right of agenticStore: Break loop, return final response + end + end + + agenticStore->>agenticStore: updateSession(running: false) + agenticStore-->>chatStore: final content, timings, model + deactivate agenticStore ``` diff --git a/tools/server/webui/docs/flows/conversations-flow.md b/tools/server/webui/docs/flows/conversations-flow.md index 185ed16e0cd..bd2309bc03e 100644 --- a/tools/server/webui/docs/flows/conversations-flow.md +++ b/tools/server/webui/docs/flows/conversations-flow.md @@ -6,7 +6,7 @@ sequenceDiagram participant DbSvc as ⚙️ DatabaseService participant IDB as 💾 IndexedDB - Note over convStore: State:
conversations: DatabaseConversation[]
activeConversation: DatabaseConversation | null
activeMessages: DatabaseMessage[]
isInitialized: boolean
usedModalities: $derived({vision, audio}) + Note over convStore: State:
conversations: DatabaseConversation[]
activeConversation: DatabaseConversation | null
activeMessages: DatabaseMessage[]
isInitialized: boolean
pendingMcpServerOverrides: Map<string, McpServerOverride> %% ═══════════════════════════════════════════════════════════════════════════ Note over UI,IDB: 🚀 INITIALIZATION @@ -37,6 +37,13 @@ sequenceDiagram convStore->>convStore: conversations.unshift(conversation) convStore->>convStore: activeConversation = $state(conversation) convStore->>convStore: activeMessages = $state([]) + + alt pendingMcpServerOverrides has entries + loop each pending override + convStore->>DbSvc: Store MCP server override for new conversation + end + convStore->>convStore: clearPendingMcpServerOverrides() + end deactivate convStore %% ═══════════════════════════════════════════════════════════════════════════ @@ -58,8 +65,7 @@ sequenceDiagram Note right of convStore: Filter to show only current branch path convStore->>convStore: activeMessages = $state(filtered) - convStore->>chatStore: syncLoadingStateForChat(convId) - Note right of chatStore: Sync isLoading/currentResponse if streaming + Note right of convStore: Route (+page.svelte) then calls:
chatStore.syncLoadingStateForChat(convId) deactivate convStore %% ═══════════════════════════════════════════════════════════════════════════ @@ -121,16 +127,36 @@ sequenceDiagram end deactivate convStore + UI->>convStore: deleteAll() + activate convStore + convStore->>DbSvc: Delete all conversations and messages + convStore->>convStore: conversations = [] + convStore->>convStore: clearActiveConversation() + deactivate convStore + %% ═══════════════════════════════════════════════════════════════════════════ - Note over UI,IDB: 📊 MODALITY TRACKING + Note over UI,IDB: � MCP SERVER PER-CHAT OVERRIDES %% ═══════════════════════════════════════════════════════════════════════════ - Note over convStore: usedModalities = $derived.by(() => {
calculateModalitiesFromMessages(activeMessages)
}) + Note over convStore: Conversations can override which MCP servers are enabled. + Note over convStore: Uses pendingMcpServerOverrides before conversation
is created, then persists to conversation metadata. + + UI->>convStore: setMcpServerOverride(convId, serverName, override) + Note right of convStore: override = {enabled: boolean} + + UI->>convStore: toggleMcpServerForChat(convId, serverName, enabled) + activate convStore + convStore->>convStore: setMcpServerOverride(convId, serverName, {enabled}) + deactivate convStore + + UI->>convStore: isMcpServerEnabledForChat(convId, serverName) + Note right of convStore: Check override → fall back to global MCP config - Note over convStore: Scans activeMessages for attachments:
- IMAGE → vision: true
- PDF (processedAsImages) → vision: true
- AUDIO → audio: true + UI->>convStore: getAllMcpServerOverrides(convId) + Note right of convStore: Returns all overrides for a conversation - UI->>convStore: getModalitiesUpToMessage(msgId) - Note right of convStore: Used for regeneration validation
Only checks messages BEFORE target + UI->>convStore: removeMcpServerOverride(convId, serverName) + UI->>convStore: getMcpServerOverride(convId, serverName) %% ═══════════════════════════════════════════════════════════════════════════ Note over UI,IDB: 📤 EXPORT / 📥 IMPORT @@ -148,8 +174,10 @@ sequenceDiagram UI->>convStore: importConversations(file) activate convStore convStore->>convStore: Parse JSON file + convStore->>convStore: importConversationsData(parsed) convStore->>DbSvc: importConversations(parsed) - DbSvc->>IDB: Bulk INSERT conversations + messages + Note right of DbSvc: Skips duplicate conversations
(checks existing by ID) + DbSvc->>IDB: INSERT conversations + messages (skip existing) convStore->>convStore: loadConversations() deactivate convStore ``` diff --git a/tools/server/webui/docs/flows/database-flow.md b/tools/server/webui/docs/flows/database-flow.md index 50f8284e3c3..38cd6941cf7 100644 --- a/tools/server/webui/docs/flows/database-flow.md +++ b/tools/server/webui/docs/flows/database-flow.md @@ -66,6 +66,14 @@ sequenceDiagram DbSvc-->>Store: rootMessageId deactivate DbSvc + Store->>DbSvc: createSystemMessage(convId, content, parentId) + activate DbSvc + DbSvc->>DbSvc: Create message {role: "system", parent: parentId} + DbSvc->>Dexie: db.messages.add(systemMsg) + Dexie->>IDB: INSERT + DbSvc-->>Store: DatabaseMessage + deactivate DbSvc + Store->>DbSvc: createMessageBranch(message, parentId) activate DbSvc DbSvc->>DbSvc: Generate UUID for new message @@ -116,6 +124,13 @@ sequenceDiagram end DbSvc->>Dexie: db.messages.delete(msgId) Dexie->>IDB: DELETE target message + + alt target message has a parent + DbSvc->>Dexie: db.messages.get(parentId) + DbSvc->>DbSvc: parent.children.filter(id !== msgId) + DbSvc->>Dexie: db.messages.update(parentId, {children}) + Note right of DbSvc: Remove deleted message from parent's children[] + end deactivate DbSvc %% ═══════════════════════════════════════════════════════════════════════════ @@ -125,12 +140,16 @@ sequenceDiagram Store->>DbSvc: importConversations(data) activate DbSvc loop each conversation in data - DbSvc->>DbSvc: Generate new UUIDs (avoid conflicts) - DbSvc->>Dexie: db.conversations.add(conversation) - Dexie->>IDB: INSERT conversation - loop each message - DbSvc->>Dexie: db.messages.add(message) - Dexie->>IDB: INSERT message + DbSvc->>Dexie: db.conversations.get(conv.id) + alt conversation already exists + Note right of DbSvc: Skip duplicate (keep existing) + else conversation is new + DbSvc->>Dexie: db.conversations.add(conversation) + Dexie->>IDB: INSERT conversation + loop each message + DbSvc->>Dexie: db.messages.add(message) + Dexie->>IDB: INSERT message + end end end deactivate DbSvc diff --git a/tools/server/webui/docs/flows/mcp-flow.md b/tools/server/webui/docs/flows/mcp-flow.md new file mode 100644 index 00000000000..c8aa6665993 --- /dev/null +++ b/tools/server/webui/docs/flows/mcp-flow.md @@ -0,0 +1,226 @@ +```mermaid +sequenceDiagram + participant UI as 🧩 McpServersSettings / ChatForm + participant chatStore as 🗄️ chatStore + participant mcpStore as 🗄️ mcpStore + participant mcpResStore as 🗄️ mcpResourceStore + participant convStore as 🗄️ conversationsStore + participant MCPSvc as ⚙️ MCPService + participant LS as 💾 LocalStorage + participant ExtMCP as 🔌 External MCP Server + + Note over mcpStore: State:
isInitializing, error
toolCount, connectedServers
healthChecks (Map)
connections (Map)
toolsIndex (Map)
serverConfigs (Map) + + Note over mcpResStore: State:
serverResources (Map)
cachedResources (Map)
subscriptions (Map)
attachments[] + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,ExtMCP: 🚀 INITIALIZATION (App Startup) + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>mcpStore: ensureInitialized() + activate mcpStore + + mcpStore->>LS: get(MCP_SERVERS_LOCALSTORAGE_KEY) + LS-->>mcpStore: MCPServerSettingsEntry[] + + mcpStore->>mcpStore: parseServerSettings(servers) + Note right of mcpStore: Filter enabled servers
Build MCPServerConfig objects
Per-chat overrides checked via convStore + + loop For each enabled server + mcpStore->>mcpStore: runHealthCheck(serverId) + mcpStore->>mcpStore: updateHealthCheck(id, CONNECTING) + + mcpStore->>MCPSvc: connect(serverName, config, clientInfo, capabilities, onPhase) + activate MCPSvc + + MCPSvc->>MCPSvc: createTransport(config) + Note right of MCPSvc: WebSocket / StreamableHTTP / SSE
with optional CORS proxy + + MCPSvc->>ExtMCP: Transport handshake + ExtMCP-->>MCPSvc: Connection established + + MCPSvc->>ExtMCP: Initialize request + Note right of ExtMCP: Exchange capabilities
Server info, protocol version + + ExtMCP-->>MCPSvc: InitializeResult (serverInfo, capabilities) + + MCPSvc->>ExtMCP: listTools() + ExtMCP-->>MCPSvc: Tool[] + + MCPSvc-->>mcpStore: MCPConnection + deactivate MCPSvc + + mcpStore->>mcpStore: connections.set(serverName, connection) + mcpStore->>mcpStore: indexTools(connection.tools, serverName) + Note right of mcpStore: toolsIndex.set(toolName, serverName)
Handle name conflicts with prefixes + + mcpStore->>mcpStore: updateHealthCheck(id, SUCCESS) + mcpStore->>mcpStore: _connectedServers.push(serverName) + + alt Server supports resources + mcpStore->>MCPSvc: listAllResources(connection) + MCPSvc->>ExtMCP: listResources() + ExtMCP-->>MCPSvc: MCPResource[] + MCPSvc-->>mcpStore: resources + + mcpStore->>MCPSvc: listAllResourceTemplates(connection) + MCPSvc->>ExtMCP: listResourceTemplates() + ExtMCP-->>MCPSvc: MCPResourceTemplate[] + MCPSvc-->>mcpStore: templates + + mcpStore->>mcpResStore: setServerResources(serverName, resources, templates) + end + end + + mcpStore->>mcpStore: _isInitializing = false + deactivate mcpStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,ExtMCP: 🔧 TOOL EXECUTION (Chat with Tools) + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>mcpStore: executeTool(mcpCall: MCPToolCall, signal?) + activate mcpStore + + mcpStore->>mcpStore: toolsIndex.get(mcpCall.function.name) + Note right of mcpStore: Resolve serverName from toolsIndex
MCPToolCall = {id, type, function: {name, arguments}} + + mcpStore->>mcpStore: acquireConnection() + Note right of mcpStore: activeFlowCount++
Prevent shutdown during execution + + mcpStore->>mcpStore: connection = connections.get(serverName) + + mcpStore->>MCPSvc: callTool(connection, {name, arguments}, signal) + activate MCPSvc + + MCPSvc->>MCPSvc: throwIfAborted(signal) + MCPSvc->>ExtMCP: callTool(name, arguments) + + alt Tool execution success + ExtMCP-->>MCPSvc: ToolCallResult (content, isError) + MCPSvc->>MCPSvc: formatToolResult(result) + Note right of MCPSvc: Handle text, image (base64),
embedded resource content + MCPSvc-->>mcpStore: ToolExecutionResult + else Tool execution error + ExtMCP-->>MCPSvc: Error + MCPSvc-->>mcpStore: throw Error + else Aborted + MCPSvc-->>mcpStore: throw AbortError + end + + deactivate MCPSvc + + mcpStore->>mcpStore: releaseConnection() + Note right of mcpStore: activeFlowCount-- + + mcpStore-->>UI: ToolExecutionResult + deactivate mcpStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,ExtMCP: � RESOURCE ATTACHMENT CONSUMPTION + %% ═══════════════════════════════════════════════════════════════════════════ + + chatStore->>mcpStore: consumeResourceAttachmentsAsExtras() + activate mcpStore + mcpStore->>mcpResStore: getAttachments() + mcpResStore-->>mcpStore: MCPResourceAttachment[] + mcpStore->>mcpStore: Convert attachments to message extras + mcpStore->>mcpResStore: clearAttachments() + mcpStore-->>chatStore: MessageExtra[] (for user message) + deactivate mcpStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,ExtMCP: �📝 PROMPT OPERATIONS + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>mcpStore: getAllPrompts() + activate mcpStore + + loop For each connected server with prompts capability + mcpStore->>MCPSvc: listPrompts(connection) + MCPSvc->>ExtMCP: listPrompts() + ExtMCP-->>MCPSvc: Prompt[] + MCPSvc-->>mcpStore: prompts + end + + mcpStore-->>UI: MCPPromptInfo[] (with serverName) + deactivate mcpStore + + UI->>mcpStore: getPrompt(serverName, promptName, args?) + activate mcpStore + + mcpStore->>MCPSvc: getPrompt(connection, name, args) + MCPSvc->>ExtMCP: getPrompt({name, arguments}) + ExtMCP-->>MCPSvc: GetPromptResult (messages) + MCPSvc-->>mcpStore: GetPromptResult + + mcpStore-->>UI: GetPromptResult + deactivate mcpStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,ExtMCP: 📁 RESOURCE OPERATIONS + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>mcpResStore: addAttachment(resourceInfo) + activate mcpResStore + mcpResStore->>mcpResStore: Create MCPResourceAttachment (loading: true) + mcpResStore-->>UI: attachment + + UI->>mcpStore: readResource(serverName, uri) + activate mcpStore + + mcpStore->>MCPSvc: readResource(connection, uri) + MCPSvc->>ExtMCP: readResource({uri}) + ExtMCP-->>MCPSvc: MCPReadResourceResult (contents) + MCPSvc-->>mcpStore: contents + + mcpStore-->>UI: MCPResourceContent[] + deactivate mcpStore + + UI->>mcpResStore: updateAttachmentContent(attachmentId, content) + mcpResStore->>mcpResStore: cacheResourceContent(resource, content) + deactivate mcpResStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,ExtMCP: 🔄 AUTO-RECONNECTION + %% ═══════════════════════════════════════════════════════════════════════════ + + Note over mcpStore: On WebSocket close or connection error: + mcpStore->>mcpStore: autoReconnect(serverName, attempt) + activate mcpStore + + mcpStore->>mcpStore: Calculate backoff delay + Note right of mcpStore: delay = min(30s, 1s * 2^attempt) + + mcpStore->>mcpStore: Wait for delay + mcpStore->>mcpStore: reconnectServer(serverName) + + alt Reconnection success + mcpStore->>mcpStore: updateHealthCheck(id, SUCCESS) + else Max attempts reached + mcpStore->>mcpStore: updateHealthCheck(id, ERROR) + end + deactivate mcpStore + + %% ═══════════════════════════════════════════════════════════════════════════ + Note over UI,ExtMCP: 🛑 SHUTDOWN + %% ═══════════════════════════════════════════════════════════════════════════ + + UI->>mcpStore: shutdown() + activate mcpStore + + mcpStore->>mcpStore: Wait for activeFlowCount == 0 + + loop For each connection + mcpStore->>MCPSvc: disconnect(connection) + MCPSvc->>MCPSvc: transport.onclose = undefined + MCPSvc->>ExtMCP: close() + end + + mcpStore->>mcpStore: connections.clear() + mcpStore->>mcpStore: toolsIndex.clear() + mcpStore->>mcpStore: _connectedServers = [] + + mcpStore->>mcpResStore: clear() + deactivate mcpStore +``` diff --git a/tools/server/webui/package-lock.json b/tools/server/webui/package-lock.json index 8d13e5a535f..361144915f0 100644 --- a/tools/server/webui/package-lock.json +++ b/tools/server/webui/package-lock.json @@ -8,6 +8,7 @@ "name": "webui", "version": "1.0.0", "dependencies": { + "@modelcontextprotocol/sdk": "^1.25.1", "highlight.js": "^11.11.1", "mode-watcher": "^1.1.0", "pdfjs-dist": "^5.4.54", @@ -19,7 +20,8 @@ "remark-html": "^16.0.1", "remark-rehype": "^11.1.2", "svelte-sonner": "^1.0.5", - "unist-util-visit": "^5.0.0" + "unist-util-visit": "^5.0.0", + "zod": "^4.2.1" }, "devDependencies": { "@chromatic-com/storybook": "^5.0.0", @@ -853,6 +855,18 @@ "dev": true, "license": "MIT" }, + "node_modules/@hono/node-server": { + "version": "1.19.9", + "resolved": "https://registry.npmjs.org/@hono/node-server/-/node-server-1.19.9.tgz", + "integrity": "sha512-vHL6w3ecZsky+8P5MD+eFfaGTyCeOHUIFYMGpQGbrBTSmNNoxv0if69rEZ5giu36weC5saFuznL411gRX7bJDw==", + "license": "MIT", + "engines": { + "node": ">=18.14.1" + }, + "peerDependencies": { + "hono": "^4" + } + }, "node_modules/@humanfs/core": { "version": "0.19.1", "resolved": "https://registry.npmjs.org/@humanfs/core/-/core-0.19.1.tgz", @@ -1044,6 +1058,68 @@ "react": ">=16" } }, + "node_modules/@modelcontextprotocol/sdk": { + "version": "1.26.0", + "resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.26.0.tgz", + "integrity": "sha512-Y5RmPncpiDtTXDbLKswIJzTqu2hyBKxTNsgKqKclDbhIgg1wgtf1fRuvxgTnRfcnxtvvgbIEcqUOzZrJ6iSReg==", + "license": "MIT", + "dependencies": { + "@hono/node-server": "^1.19.9", + "ajv": "^8.17.1", + "ajv-formats": "^3.0.1", + "content-type": "^1.0.5", + "cors": "^2.8.5", + "cross-spawn": "^7.0.5", + "eventsource": "^3.0.2", + "eventsource-parser": "^3.0.0", + "express": "^5.2.1", + "express-rate-limit": "^8.2.1", + "hono": "^4.11.4", + "jose": "^6.1.3", + "json-schema-typed": "^8.0.2", + "pkce-challenge": "^5.0.0", + "raw-body": "^3.0.0", + "zod": "^3.25 || ^4.0", + "zod-to-json-schema": "^3.25.1" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@cfworker/json-schema": "^4.1.1", + "zod": "^3.25 || ^4.0" + }, + "peerDependenciesMeta": { + "@cfworker/json-schema": { + "optional": true + }, + "zod": { + "optional": false + } + } + }, + "node_modules/@modelcontextprotocol/sdk/node_modules/ajv": { + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/@modelcontextprotocol/sdk/node_modules/json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "license": "MIT" + }, "node_modules/@napi-rs/canvas": { "version": "0.1.76", "resolved": "https://registry.npmjs.org/@napi-rs/canvas/-/canvas-0.1.76.tgz", @@ -2164,9 +2240,9 @@ } }, "node_modules/@sveltejs/kit": { - "version": "2.52.0", - "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.52.0.tgz", - "integrity": "sha512-zG+HmJuSF7eC0e7xt2htlOcEMAdEtlVdb7+gAr+ef08EhtwUsjLxcAwBgUCJY3/5p08OVOxVZti91WfXeuLvsg==", + "version": "2.50.2", + "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.50.2.tgz", + "integrity": "sha512-875hTUkEbz+MyJIxWbQjfMaekqdmEKUUfR7JyKcpfMRZqcGyrO9Gd+iS1D/Dx8LpE5FEtutWGOtlAh4ReSAiOA==", "dev": true, "license": "MIT", "peer": true, @@ -3282,6 +3358,19 @@ "url": "https://opencollective.com/vitest" } }, + "node_modules/accepts": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/accepts/-/accepts-2.0.0.tgz", + "integrity": "sha512-5cvg6CtKwfgdmVqY1WIiXKc3Q1bkRqGLi+2W/6ao+6Y7gu/RCwRuAhGEzh5B4KlszSuTLgZYuqFqo5bImjNKng==", + "license": "MIT", + "dependencies": { + "mime-types": "^3.0.0", + "negotiator": "^1.0.0" + }, + "engines": { + "node": ">= 0.6" + } + }, "node_modules/acorn": { "version": "8.15.0", "resolved": "https://registry.npmjs.org/acorn/-/acorn-8.15.0.tgz", @@ -3322,6 +3411,45 @@ "url": "https://github.com/sponsors/epoberezkin" } }, + "node_modules/ajv-formats": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/ajv-formats/-/ajv-formats-3.0.1.tgz", + "integrity": "sha512-8iUql50EUR+uUcdRQ3HDqa6EVyo3docL8g5WJ3FNcWmu62IbkGUue/pEyLBW8VGKKucTPgqeks4fIU1DA4yowQ==", + "license": "MIT", + "dependencies": { + "ajv": "^8.0.0" + }, + "peerDependencies": { + "ajv": "^8.0.0" + }, + "peerDependenciesMeta": { + "ajv": { + "optional": true + } + } + }, + "node_modules/ajv-formats/node_modules/ajv": { + "version": "8.17.1", + "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz", + "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==", + "license": "MIT", + "dependencies": { + "fast-deep-equal": "^3.1.3", + "fast-uri": "^3.0.1", + "json-schema-traverse": "^1.0.0", + "require-from-string": "^2.0.2" + }, + "funding": { + "type": "github", + "url": "https://github.com/sponsors/epoberezkin" + } + }, + "node_modules/ajv-formats/node_modules/json-schema-traverse": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz", + "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==", + "license": "MIT" + }, "node_modules/ansi-regex": { "version": "5.0.1", "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-5.0.1.tgz", @@ -3464,9 +3592,9 @@ } }, "node_modules/bits-ui": { - "version": "2.15.7", - "resolved": "https://registry.npmjs.org/bits-ui/-/bits-ui-2.15.7.tgz", - "integrity": "sha512-M9VrQAJXnT3xfhN/joEtVXhO794yBPmadZfNtDT4t4QwI8wgCBmDuv8FlH6K4v0q0Ugw07tumAPfym9MU2BGpg==", + "version": "2.15.5", + "resolved": "https://registry.npmjs.org/bits-ui/-/bits-ui-2.15.5.tgz", + "integrity": "sha512-WhS+P+E//ClLfKU6KqjKC17nGDRLnz+vkwoP6ClFUPd5m1fFVDxTElPX8QVsduLj5V1KFDxlnv6sW2G5Lqk+vw==", "dev": true, "license": "MIT", "dependencies": { @@ -3534,6 +3662,46 @@ "svelte": "^5.30.2" } }, + "node_modules/body-parser": { + "version": "2.2.2", + "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-2.2.2.tgz", + "integrity": "sha512-oP5VkATKlNwcgvxi0vM0p/D3n2C3EReYVX+DNYs5TjZFn/oQt2j+4sVJtSMr18pdRr8wjTcBl6LoV+FUwzPmNA==", + "license": "MIT", + "dependencies": { + "bytes": "^3.1.2", + "content-type": "^1.0.5", + "debug": "^4.4.3", + "http-errors": "^2.0.0", + "iconv-lite": "^0.7.0", + "on-finished": "^2.4.1", + "qs": "^6.14.1", + "raw-body": "^3.0.1", + "type-is": "^2.0.1" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/body-parser/node_modules/iconv-lite": { + "version": "0.7.2", + "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.7.2.tgz", + "integrity": "sha512-im9DjEDQ55s9fL4EYzOAv0yMqmMBSZp6G0VvFyTMPKWxiSBHUj9NW/qqLmXUwXrrM7AvqSlTCfvqRb0cM8yYqw==", + "license": "MIT", + "dependencies": { + "safer-buffer": ">= 2.1.2 < 3.0.0" + }, + "engines": { + "node": ">=0.10.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, "node_modules/brace-expansion": { "version": "1.1.12", "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-1.1.12.tgz", @@ -3575,6 +3743,15 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/bytes": { + "version": "3.1.2", + "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.2.tgz", + "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, "node_modules/cac": { "version": "6.7.14", "resolved": "https://registry.npmjs.org/cac/-/cac-6.7.14.tgz", @@ -3589,7 +3766,6 @@ "version": "1.0.2", "resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz", "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==", - "dev": true, "license": "MIT", "dependencies": { "es-errors": "^1.3.0", @@ -3603,7 +3779,6 @@ "version": "1.0.4", "resolved": "https://registry.npmjs.org/call-bound/-/call-bound-1.0.4.tgz", "integrity": "sha512-+ys997U96po4Kx/ABpBCqhA9EuxJaQWDQg7295H4hBphv3IZg0boBKuwYpt4YXp6MZ5AmZQnU/tyMTlRpaSejg==", - "dev": true, "license": "MIT", "dependencies": { "call-bind-apply-helpers": "^1.0.2", @@ -3816,6 +3991,28 @@ "dev": true, "license": "MIT" }, + "node_modules/content-disposition": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-1.0.1.tgz", + "integrity": "sha512-oIXISMynqSqm241k6kcQ5UwttDILMK4BiurCfGEREw6+X9jkkpEe5T9FZaApyLGGOnFuyMWZpdolTXMtvEJ08Q==", + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/content-type": { + "version": "1.0.5", + "resolved": "https://registry.npmjs.org/content-type/-/content-type-1.0.5.tgz", + "integrity": "sha512-nTjqfcBFEipKdXCv4YDQWCfmcLZKm81ldF0pAopTvyrFGVbcR6P/VAAd5G7N+0tTr8QqiU0tFadD6FK4NtJwOA==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, "node_modules/cookie": { "version": "0.6.0", "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.6.0.tgz", @@ -3826,6 +4023,28 @@ "node": ">= 0.6" } }, + "node_modules/cookie-signature": { + "version": "1.2.2", + "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.2.2.tgz", + "integrity": "sha512-D76uU73ulSXrD1UXF4KE2TMxVVwhsnCgfAyTg9k8P6KGZjlXKrOLe4dJQKI3Bxi5wjesZoFXJWElNWBjPZMbhg==", + "license": "MIT", + "engines": { + "node": ">=6.6.0" + } + }, + "node_modules/cors": { + "version": "2.8.5", + "resolved": "https://registry.npmjs.org/cors/-/cors-2.8.5.tgz", + "integrity": "sha512-KIHbLJqu73RGr/hnbrO9uBeixNGuvSQjul/jdFvS/KFSIH1hWVd1ng7zOHx+YrEfInLG7q4n6GHQ9cDtxv/P6g==", + "license": "MIT", + "dependencies": { + "object-assign": "^4", + "vary": "^1" + }, + "engines": { + "node": ">= 0.10" + } + }, "node_modules/corser": { "version": "2.0.1", "resolved": "https://registry.npmjs.org/corser/-/corser-2.0.1.tgz", @@ -3840,7 +4059,6 @@ "version": "7.0.6", "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz", "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==", - "dev": true, "license": "MIT", "dependencies": { "path-key": "^3.1.0", @@ -4000,6 +4218,15 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/depd": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz", + "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, "node_modules/dequal": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/dequal/-/dequal-2.0.3.tgz", @@ -4056,7 +4283,6 @@ "version": "1.0.1", "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz", "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==", - "dev": true, "license": "MIT", "dependencies": { "call-bind-apply-helpers": "^1.0.1", @@ -4074,6 +4300,12 @@ "dev": true, "license": "MIT" }, + "node_modules/ee-first": { + "version": "1.1.1", + "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz", + "integrity": "sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow==", + "license": "MIT" + }, "node_modules/emoji-regex": { "version": "9.2.2", "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-9.2.2.tgz", @@ -4081,6 +4313,15 @@ "dev": true, "license": "MIT" }, + "node_modules/encodeurl": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz", + "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, "node_modules/enhanced-resolve": { "version": "5.18.2", "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.18.2.tgz", @@ -4112,7 +4353,6 @@ "version": "1.0.1", "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz", "integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==", - "dev": true, "license": "MIT", "engines": { "node": ">= 0.4" @@ -4122,7 +4362,6 @@ "version": "1.3.0", "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz", "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==", - "dev": true, "license": "MIT", "engines": { "node": ">= 0.4" @@ -4139,7 +4378,6 @@ "version": "1.1.1", "resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz", "integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==", - "dev": true, "license": "MIT", "dependencies": { "es-errors": "^1.3.0" @@ -4202,6 +4440,12 @@ "@esbuild/win32-x64": "0.25.8" } }, + "node_modules/escape-html": { + "version": "1.0.3", + "resolved": "https://registry.npmjs.org/escape-html/-/escape-html-1.0.3.tgz", + "integrity": "sha512-NiSupZ4OeuGwr68lGIeym/ksIZMJodUGOSCZ/FSnTxcrekbvqrgdUxlJOMpijaKZVjAJrWrGs/6Jy8OMuyj9ow==", + "license": "MIT" + }, "node_modules/escape-string-regexp": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/escape-string-regexp/-/escape-string-regexp-4.0.0.tgz", @@ -4293,9 +4537,9 @@ } }, "node_modules/eslint-plugin-storybook": { - "version": "10.2.9", - "resolved": "https://registry.npmjs.org/eslint-plugin-storybook/-/eslint-plugin-storybook-10.2.9.tgz", - "integrity": "sha512-nmPxjPw2KfmosqAUb/W0jmEfAZzK97kyJ8W5KMuweCblwjIL0hI/GMsWSP8CCBPnhQ9LnuxtT8JtQUOsslcbwA==", + "version": "10.2.4", + "resolved": "https://registry.npmjs.org/eslint-plugin-storybook/-/eslint-plugin-storybook-10.2.4.tgz", + "integrity": "sha512-D8a6Y+iun2MSOpgps0Vd/t8y9Y5ZZ7O2VeKqw2PCv2+b7yInqogOS2VBMSRZVfP8TTGQgDpbUK67k7KZEUC7Ng==", "dev": true, "license": "MIT", "dependencies": { @@ -4303,7 +4547,174 @@ }, "peerDependencies": { "eslint": ">=8", - "storybook": "^10.2.9" + "storybook": "^10.2.4" + } + }, + "node_modules/eslint-plugin-storybook/node_modules/@typescript-eslint/project-service": { + "version": "8.54.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/project-service/-/project-service-8.54.0.tgz", + "integrity": "sha512-YPf+rvJ1s7MyiWM4uTRhE4DvBXrEV+d8oC3P9Y2eT7S+HBS0clybdMIPnhiATi9vZOYDc7OQ1L/i6ga6NFYK/g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/tsconfig-utils": "^8.54.0", + "@typescript-eslint/types": "^8.54.0", + "debug": "^4.4.3" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/eslint-plugin-storybook/node_modules/@typescript-eslint/scope-manager": { + "version": "8.54.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-8.54.0.tgz", + "integrity": "sha512-27rYVQku26j/PbHYcVfRPonmOlVI6gihHtXFbTdB5sb6qA0wdAQAbyXFVarQ5t4HRojIz64IV90YtsjQSSGlQg==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.54.0", + "@typescript-eslint/visitor-keys": "8.54.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/eslint-plugin-storybook/node_modules/@typescript-eslint/tsconfig-utils": { + "version": "8.54.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/tsconfig-utils/-/tsconfig-utils-8.54.0.tgz", + "integrity": "sha512-dRgOyT2hPk/JwxNMZDsIXDgyl9axdJI3ogZ2XWhBPsnZUv+hPesa5iuhdYt2gzwA9t8RE5ytOJ6xB0moV0Ujvw==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/eslint-plugin-storybook/node_modules/@typescript-eslint/types": { + "version": "8.54.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-8.54.0.tgz", + "integrity": "sha512-PDUI9R1BVjqu7AUDsRBbKMtwmjWcn4J3le+5LpcFgWULN3LvHC5rkc9gCVxbrsrGmO1jfPybN5s6h4Jy+OnkAA==", + "dev": true, + "license": "MIT", + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/eslint-plugin-storybook/node_modules/@typescript-eslint/typescript-estree": { + "version": "8.54.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-8.54.0.tgz", + "integrity": "sha512-BUwcskRaPvTk6fzVWgDPdUndLjB87KYDrN5EYGetnktoeAvPtO4ONHlAZDnj5VFnUANg0Sjm7j4usBlnoVMHwA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/project-service": "8.54.0", + "@typescript-eslint/tsconfig-utils": "8.54.0", + "@typescript-eslint/types": "8.54.0", + "@typescript-eslint/visitor-keys": "8.54.0", + "debug": "^4.4.3", + "minimatch": "^9.0.5", + "semver": "^7.7.3", + "tinyglobby": "^0.2.15", + "ts-api-utils": "^2.4.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/eslint-plugin-storybook/node_modules/@typescript-eslint/utils": { + "version": "8.54.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-8.54.0.tgz", + "integrity": "sha512-9Cnda8GS57AQakvRyG0PTejJNlA2xhvyNtEVIMlDWOOeEyBkYWhGPnfrIAnqxLMTSTo6q8g12XVjjev5l1NvMA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@eslint-community/eslint-utils": "^4.9.1", + "@typescript-eslint/scope-manager": "8.54.0", + "@typescript-eslint/types": "8.54.0", + "@typescript-eslint/typescript-estree": "8.54.0" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + }, + "peerDependencies": { + "eslint": "^8.57.0 || ^9.0.0", + "typescript": ">=4.8.4 <6.0.0" + } + }, + "node_modules/eslint-plugin-storybook/node_modules/@typescript-eslint/visitor-keys": { + "version": "8.54.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-8.54.0.tgz", + "integrity": "sha512-VFlhGSl4opC0bprJiItPQ1RfUhGDIBokcPwaFH4yiBCaNPeld/9VeXbiPO1cLyorQi1G1vL+ecBk1x8o1axORA==", + "dev": true, + "license": "MIT", + "dependencies": { + "@typescript-eslint/types": "8.54.0", + "eslint-visitor-keys": "^4.2.1" + }, + "engines": { + "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/typescript-eslint" + } + }, + "node_modules/eslint-plugin-storybook/node_modules/brace-expansion": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/eslint-plugin-storybook/node_modules/minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" } }, "node_modules/eslint-plugin-svelte": { @@ -4474,6 +4885,15 @@ "node": ">=0.10.0" } }, + "node_modules/etag": { + "version": "1.8.1", + "resolved": "https://registry.npmjs.org/etag/-/etag-1.8.1.tgz", + "integrity": "sha512-aIL5Fx7mawVa300al2BnEE4iNvo1qETxLrPI/o05L7z6go7fCw1J6EQmbK4FmJ2AS7kgVF/KEZWufBfdClMcPg==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, "node_modules/eventemitter3": { "version": "4.0.7", "resolved": "https://registry.npmjs.org/eventemitter3/-/eventemitter3-4.0.7.tgz", @@ -4481,6 +4901,27 @@ "dev": true, "license": "MIT" }, + "node_modules/eventsource": { + "version": "3.0.7", + "resolved": "https://registry.npmjs.org/eventsource/-/eventsource-3.0.7.tgz", + "integrity": "sha512-CRT1WTyuQoD771GW56XEZFQ/ZoSfWid1alKGDYMmkt2yl8UXrVR4pspqWNEcqKvVIzg6PAltWjxcSSPrboA4iA==", + "license": "MIT", + "dependencies": { + "eventsource-parser": "^3.0.1" + }, + "engines": { + "node": ">=18.0.0" + } + }, + "node_modules/eventsource-parser": { + "version": "3.0.6", + "resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-3.0.6.tgz", + "integrity": "sha512-Vo1ab+QXPzZ4tCa8SwIHJFaSzy4R6SHf7BY79rFBDf0idraZWAkYrDjDj8uWaSm3S2TK+hJ7/t1CEmZ7jXw+pg==", + "license": "MIT", + "engines": { + "node": ">=18.0.0" + } + }, "node_modules/expect-type": { "version": "1.2.2", "resolved": "https://registry.npmjs.org/expect-type/-/expect-type-1.2.2.tgz", @@ -4491,6 +4932,76 @@ "node": ">=12.0.0" } }, + "node_modules/express": { + "version": "5.2.1", + "resolved": "https://registry.npmjs.org/express/-/express-5.2.1.tgz", + "integrity": "sha512-hIS4idWWai69NezIdRt2xFVofaF4j+6INOpJlVOLDO8zXGpUVEVzIYk12UUi2JzjEzWL3IOAxcTubgz9Po0yXw==", + "license": "MIT", + "dependencies": { + "accepts": "^2.0.0", + "body-parser": "^2.2.1", + "content-disposition": "^1.0.0", + "content-type": "^1.0.5", + "cookie": "^0.7.1", + "cookie-signature": "^1.2.1", + "debug": "^4.4.0", + "depd": "^2.0.0", + "encodeurl": "^2.0.0", + "escape-html": "^1.0.3", + "etag": "^1.8.1", + "finalhandler": "^2.1.0", + "fresh": "^2.0.0", + "http-errors": "^2.0.0", + "merge-descriptors": "^2.0.0", + "mime-types": "^3.0.0", + "on-finished": "^2.4.1", + "once": "^1.4.0", + "parseurl": "^1.3.3", + "proxy-addr": "^2.0.7", + "qs": "^6.14.0", + "range-parser": "^1.2.1", + "router": "^2.2.0", + "send": "^1.1.0", + "serve-static": "^2.2.0", + "statuses": "^2.0.1", + "type-is": "^2.0.1", + "vary": "^1.1.2" + }, + "engines": { + "node": ">= 18" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/express-rate-limit": { + "version": "8.2.1", + "resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-8.2.1.tgz", + "integrity": "sha512-PCZEIEIxqwhzw4KF0n7QF4QqruVTcF73O5kFKUnGOyjbCCgizBBiFaYpd/fnBLUMPw/BWw9OsiN7GgrNYr7j6g==", + "license": "MIT", + "dependencies": { + "ip-address": "10.0.1" + }, + "engines": { + "node": ">= 16" + }, + "funding": { + "url": "https://github.com/sponsors/express-rate-limit" + }, + "peerDependencies": { + "express": ">= 4.11" + } + }, + "node_modules/express/node_modules/cookie": { + "version": "0.7.2", + "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.2.tgz", + "integrity": "sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, "node_modules/extend": { "version": "3.0.2", "resolved": "https://registry.npmjs.org/extend/-/extend-3.0.2.tgz", @@ -4501,7 +5012,6 @@ "version": "3.1.3", "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz", "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==", - "dev": true, "license": "MIT" }, "node_modules/fast-json-stable-stringify": { @@ -4518,6 +5028,22 @@ "dev": true, "license": "MIT" }, + "node_modules/fast-uri": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.1.0.tgz", + "integrity": "sha512-iPeeDKJSWf4IEOasVVrknXpaBV0IApz/gp7S2bb7Z4Lljbl2MGJRqInZiUrQwV16cpzw/D3S5j5Julj/gT52AA==", + "funding": [ + { + "type": "github", + "url": "https://github.com/sponsors/fastify" + }, + { + "type": "opencollective", + "url": "https://opencollective.com/fastify" + } + ], + "license": "BSD-3-Clause" + }, "node_modules/fdir": { "version": "6.5.0", "resolved": "https://registry.npmjs.org/fdir/-/fdir-6.5.0.tgz", @@ -4580,6 +5106,27 @@ "node": ">=8" } }, + "node_modules/finalhandler": { + "version": "2.1.1", + "resolved": "https://registry.npmjs.org/finalhandler/-/finalhandler-2.1.1.tgz", + "integrity": "sha512-S8KoZgRZN+a5rNwqTxlZZePjT/4cnm0ROV70LedRHZ0p8u9fRID0hJUZQpkKLzro8LfmC8sx23bY6tVNxv8pQA==", + "license": "MIT", + "dependencies": { + "debug": "^4.4.0", + "encodeurl": "^2.0.0", + "escape-html": "^1.0.3", + "on-finished": "^2.4.1", + "parseurl": "^1.3.3", + "statuses": "^2.0.1" + }, + "engines": { + "node": ">= 18.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, "node_modules/find-up": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/find-up/-/find-up-5.0.0.tgz", @@ -4656,6 +5203,24 @@ "url": "https://github.com/sponsors/isaacs" } }, + "node_modules/forwarded": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz", + "integrity": "sha512-buRG0fpBtRHSTCOASe6hD258tEubFoRLb4ZNA6NxMVHNw2gOcwHo9wyablzMzOA5z9xA9L1KNjk/Nt6MT9aYow==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/fresh": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/fresh/-/fresh-2.0.0.tgz", + "integrity": "sha512-Rx/WycZ60HOaqLKAi6cHRKKI7zxWbJ31MhntmtwMoaTeF7XFH9hhBp8vITaMidfljRQ6eYWCKkaTK+ykVJHP2A==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, "node_modules/fsevents": { "version": "2.3.2", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz", @@ -4675,7 +5240,6 @@ "version": "1.1.2", "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz", "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==", - "dev": true, "license": "MIT", "funding": { "url": "https://github.com/sponsors/ljharb" @@ -4685,7 +5249,6 @@ "version": "1.3.0", "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz", "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==", - "dev": true, "license": "MIT", "dependencies": { "call-bind-apply-helpers": "^1.0.2", @@ -4710,7 +5273,6 @@ "version": "1.0.1", "resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz", "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==", - "dev": true, "license": "MIT", "dependencies": { "dunder-proto": "^1.0.1", @@ -4797,7 +5359,6 @@ "version": "1.2.0", "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz", "integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==", - "dev": true, "license": "MIT", "engines": { "node": ">= 0.4" @@ -4827,7 +5388,6 @@ "version": "1.1.0", "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz", "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==", - "dev": true, "license": "MIT", "engines": { "node": ">= 0.4" @@ -4840,7 +5400,6 @@ "version": "2.0.2", "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz", "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==", - "dev": true, "license": "MIT", "dependencies": { "function-bind": "^1.1.2" @@ -5108,6 +5667,16 @@ "node": ">=12.0.0" } }, + "node_modules/hono": { + "version": "4.11.7", + "resolved": "https://registry.npmjs.org/hono/-/hono-4.11.7.tgz", + "integrity": "sha512-l7qMiNee7t82bH3SeyUCt9UF15EVmaBvsppY2zQtrbIhl/yzBTny+YUxsVjSjQ6gaqaeVtZmGocom8TzBlA4Yw==", + "license": "MIT", + "peer": true, + "engines": { + "node": ">=16.9.0" + } + }, "node_modules/html-encoding-sniffer": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/html-encoding-sniffer/-/html-encoding-sniffer-3.0.0.tgz", @@ -5138,6 +5707,26 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/http-errors": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/http-errors/-/http-errors-2.0.1.tgz", + "integrity": "sha512-4FbRdAX+bSdmo4AUFuS0WNiPz8NgFt+r8ThgNWmlrjQjt1Q7ZR9+zTlce2859x4KSXrwIsaeTqDoKQmtP8pLmQ==", + "license": "MIT", + "dependencies": { + "depd": "~2.0.0", + "inherits": "~2.0.4", + "setprototypeof": "~1.2.0", + "statuses": "~2.0.2", + "toidentifier": "~1.0.1" + }, + "engines": { + "node": ">= 0.8" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, "node_modules/http-proxy": { "version": "1.18.1", "resolved": "https://registry.npmjs.org/http-proxy/-/http-proxy-1.18.1.tgz", @@ -5248,12 +5837,36 @@ "node": ">=8" } }, + "node_modules/inherits": { + "version": "2.0.4", + "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz", + "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==", + "license": "ISC" + }, "node_modules/inline-style-parser": { "version": "0.2.4", "resolved": "https://registry.npmjs.org/inline-style-parser/-/inline-style-parser-0.2.4.tgz", "integrity": "sha512-0aO8FkhNZlj/ZIbNi7Lxxr12obT7cL1moPfE4tg1LkX7LlLfC6DeX4l2ZEud1ukP9jNQyNnfzQVqwbwmAATY4Q==", "license": "MIT" }, + "node_modules/ip-address": { + "version": "10.0.1", + "resolved": "https://registry.npmjs.org/ip-address/-/ip-address-10.0.1.tgz", + "integrity": "sha512-NWv9YLW4PoW2B7xtzaS3NCot75m6nK7Icdv0o3lfMceJVRfSoQwqD4wEH5rLwoKJwUiZ/rfpiVBhnaF0FK4HoA==", + "license": "MIT", + "engines": { + "node": ">= 12" + } + }, + "node_modules/ipaddr.js": { + "version": "1.9.1", + "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.9.1.tgz", + "integrity": "sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g==", + "license": "MIT", + "engines": { + "node": ">= 0.10" + } + }, "node_modules/is-docker": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/is-docker/-/is-docker-3.0.0.tgz", @@ -5345,6 +5958,12 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/is-promise": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/is-promise/-/is-promise-4.0.0.tgz", + "integrity": "sha512-hvpoI6korhJMnej285dSg6nu1+e6uxs7zG3BYAm5byqDsgJNWwxzM6z6iZiAgQR4TJ30JmBTOwqZUw3WlyH3AQ==", + "license": "MIT" + }, "node_modules/is-wsl": { "version": "3.1.0", "resolved": "https://registry.npmjs.org/is-wsl/-/is-wsl-3.1.0.tgz", @@ -5365,7 +5984,6 @@ "version": "2.0.0", "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==", - "dev": true, "license": "ISC" }, "node_modules/istanbul-lib-coverage": { @@ -5448,6 +6066,15 @@ "jiti": "lib/jiti-cli.mjs" } }, + "node_modules/jose": { + "version": "6.1.3", + "resolved": "https://registry.npmjs.org/jose/-/jose-6.1.3.tgz", + "integrity": "sha512-0TpaTfihd4QMNwrz/ob2Bp7X04yuxJkjRGi4aKmOqwhov54i6u79oCv7T+C7lo70MKH6BesI3vscD1yb/yzKXQ==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/panva" + } + }, "node_modules/js-tokens": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-4.0.0.tgz", @@ -5482,6 +6109,12 @@ "dev": true, "license": "MIT" }, + "node_modules/json-schema-typed": { + "version": "8.0.2", + "resolved": "https://registry.npmjs.org/json-schema-typed/-/json-schema-typed-8.0.2.tgz", + "integrity": "sha512-fQhoXdcvc3V28x7C7BMs4P5+kNlgUURe2jmUT1T//oBRMDrqy1QPelJimwZGo7Hg9VPV3EQV5Bnq4hbFy2vetA==", + "license": "BSD-2-Clause" + }, "node_modules/json-stable-stringify-without-jsonify": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/json-stable-stringify-without-jsonify/-/json-stable-stringify-without-jsonify-1.0.1.tgz", @@ -5959,7 +6592,6 @@ "version": "1.1.0", "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz", "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==", - "dev": true, "license": "MIT", "engines": { "node": ">= 0.4" @@ -6313,6 +6945,27 @@ "url": "https://opencollective.com/unified" } }, + "node_modules/media-typer": { + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/media-typer/-/media-typer-1.1.0.tgz", + "integrity": "sha512-aisnrDP4GNe06UcKFnV5bfMNPBUw4jsLGaWwWfnH3v02GnBuXX2MCVn5RbrWo0j3pczUilYblq7fQ7Nw2t5XKw==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/merge-descriptors": { + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-2.0.0.tgz", + "integrity": "sha512-Snk314V5ayFLhp3fkUREub6WtjBfPdCPY1Ln8/8munuLuiYhsABgBVWsozAG+MWMbVEvcdcpbi9R7ww22l9Q3g==", + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/micromark": { "version": "4.0.2", "resolved": "https://registry.npmjs.org/micromark/-/micromark-4.0.2.tgz", @@ -6938,6 +7591,31 @@ "node": ">=4" } }, + "node_modules/mime-db": { + "version": "1.54.0", + "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.54.0.tgz", + "integrity": "sha512-aU5EJuIN2WDemCcAp2vFBfp/m4EAhWJnUNSSw0ixs7/kXbd6Pg64EmwJkNdFhB8aWt1sH2CTXrLxo/iAGV3oPQ==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/mime-types": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-3.0.2.tgz", + "integrity": "sha512-Lbgzdk0h4juoQ9fCKXW4by0UJqj+nOOrI9MJ1sSj4nI8aI2eo1qmvQEie4VD1glsS250n15LsWsYtCugiStS5A==", + "license": "MIT", + "dependencies": { + "mime-db": "^1.54.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, "node_modules/min-indent": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/min-indent/-/min-indent-1.0.1.tgz", @@ -7069,6 +7747,15 @@ "dev": true, "license": "MIT" }, + "node_modules/negotiator": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-1.0.0.tgz", + "integrity": "sha512-8Ofs/AUQh8MaEcrlq5xOX0CQ9ypTF5dl78mjlMNfOK08fzpgTHQRQPBxcPlEtIw0yRpws+Zo/3r+5WRby7u3Gg==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, "node_modules/node-addon-api": { "version": "7.1.1", "resolved": "https://registry.npmjs.org/node-addon-api/-/node-addon-api-7.1.1.tgz", @@ -7077,11 +7764,19 @@ "license": "MIT", "optional": true }, + "node_modules/object-assign": { + "version": "4.1.1", + "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz", + "integrity": "sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/object-inspect": { "version": "1.13.4", "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz", "integrity": "sha512-W67iLl4J2EXEGTbfeHCffrjDfitvLANg0UlX3wFUUSTx92KXRFegMHUVgSqE+wvhAbi4WqjGg9czysTV2Epbew==", - "dev": true, "license": "MIT", "engines": { "node": ">= 0.4" @@ -7090,6 +7785,27 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/on-finished": { + "version": "2.4.1", + "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.4.1.tgz", + "integrity": "sha512-oVlzkg3ENAhCk2zdv7IJwd/QUD4z2RxRwpkcGY8psCVcCYZNq4wYnVWALHM+brtuJjePWiYF/ClmuDr8Ch5+kg==", + "license": "MIT", + "dependencies": { + "ee-first": "1.1.1" + }, + "engines": { + "node": ">= 0.8" + } + }, + "node_modules/once": { + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", + "integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==", + "license": "ISC", + "dependencies": { + "wrappy": "1" + } + }, "node_modules/open": { "version": "10.2.0", "resolved": "https://registry.npmjs.org/open/-/open-10.2.0.tgz", @@ -7202,6 +7918,15 @@ "url": "https://github.com/inikulin/parse5?sponsor=1" } }, + "node_modules/parseurl": { + "version": "1.3.3", + "resolved": "https://registry.npmjs.org/parseurl/-/parseurl-1.3.3.tgz", + "integrity": "sha512-CiyeOxFT/JZyN5m0z9PfXw4SCBJ6Sygz1Dpl0wqjlhDEGGBP1GnsUVEL0p63hoG1fcj3fHynXi9NYO4nWOL+qQ==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, "node_modules/path-exists": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/path-exists/-/path-exists-4.0.0.tgz", @@ -7216,7 +7941,6 @@ "version": "3.1.1", "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz", "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==", - "dev": true, "license": "MIT", "engines": { "node": ">=8" @@ -7239,6 +7963,16 @@ "url": "https://github.com/sponsors/isaacs" } }, + "node_modules/path-to-regexp": { + "version": "8.3.0", + "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-8.3.0.tgz", + "integrity": "sha512-7jdwVIRtsP8MYpdXSwOS0YdD0Du+qOoF/AEPIt88PcCFrZCzx41oxku1jD88hZBwbNUIEfpqvuhjFaMAqMTWnA==", + "license": "MIT", + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, "node_modules/pathe": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/pathe/-/pathe-2.0.3.tgz", @@ -7288,6 +8022,15 @@ "url": "https://github.com/sponsors/jonschlinkert" } }, + "node_modules/pkce-challenge": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/pkce-challenge/-/pkce-challenge-5.0.1.tgz", + "integrity": "sha512-wQ0b/W4Fr01qtpHlqSqspcj3EhBvimsdh0KlHhH8HRZnMsEa0ea2fTULOXOS9ccQr3om+GcGRk4e+isrZWV8qQ==", + "license": "MIT", + "engines": { + "node": ">=16.20.0" + } + }, "node_modules/playwright": { "version": "1.56.1", "resolved": "https://registry.npmjs.org/playwright/-/playwright-1.56.1.tgz", @@ -7653,6 +8396,19 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/proxy-addr": { + "version": "2.0.7", + "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz", + "integrity": "sha512-llQsMLSUDUPT44jdrU/O37qlnifitDP+ZwrmmZcoSKyLKvtZxpyV0n2/bD/N4tBAAZ/gJEdZU7KMraoK1+XYAg==", + "license": "MIT", + "dependencies": { + "forwarded": "0.2.0", + "ipaddr.js": "1.9.1" + }, + "engines": { + "node": ">= 0.10" + } + }, "node_modules/punycode": { "version": "2.3.1", "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz", @@ -7664,10 +8420,9 @@ } }, "node_modules/qs": { - "version": "6.15.0", - "resolved": "https://registry.npmjs.org/qs/-/qs-6.15.0.tgz", - "integrity": "sha512-mAZTtNCeetKMH+pSjrb76NAM8V9a05I9aBZOHztWy/UqcJdQYNsf59vrRKWnojAT9Y+GbIvoTBC++CPHqpDBhQ==", - "dev": true, + "version": "6.14.1", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.14.1.tgz", + "integrity": "sha512-4EK3+xJl8Ts67nLYNwqw/dsFVnCf+qR7RgXSK9jEEm9unao3njwMDdmsdvoKBKHzxd7tCYz5e5M+SnMjdtXGQQ==", "license": "BSD-3-Clause", "dependencies": { "side-channel": "^1.1.0" @@ -7679,6 +8434,46 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/range-parser": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.1.tgz", + "integrity": "sha512-Hrgsx+orqoygnmhFbKaHE6c296J+HTAQXoxEF6gNupROmmGJRoyzfG3ccAveqCBrwr/2yxQ5BVd/GTl5agOwSg==", + "license": "MIT", + "engines": { + "node": ">= 0.6" + } + }, + "node_modules/raw-body": { + "version": "3.0.2", + "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-3.0.2.tgz", + "integrity": "sha512-K5zQjDllxWkf7Z5xJdV0/B0WTNqx6vxG70zJE4N0kBs4LovmEYWJzQGxC9bS9RAKu3bgM40lrd5zoLJ12MQ5BA==", + "license": "MIT", + "dependencies": { + "bytes": "~3.1.2", + "http-errors": "~2.0.1", + "iconv-lite": "~0.7.0", + "unpipe": "~1.0.0" + }, + "engines": { + "node": ">= 0.10" + } + }, + "node_modules/raw-body/node_modules/iconv-lite": { + "version": "0.7.2", + "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.7.2.tgz", + "integrity": "sha512-im9DjEDQ55s9fL4EYzOAv0yMqmMBSZp6G0VvFyTMPKWxiSBHUj9NW/qqLmXUwXrrM7AvqSlTCfvqRb0cM8yYqw==", + "license": "MIT", + "dependencies": { + "safer-buffer": ">= 2.1.2 < 3.0.0" + }, + "engines": { + "node": ">=0.10.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, "node_modules/react": { "version": "19.1.0", "resolved": "https://registry.npmjs.org/react/-/react-19.1.0.tgz", @@ -7939,6 +8734,15 @@ "url": "https://opencollective.com/unified" } }, + "node_modules/require-from-string": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/require-from-string/-/require-from-string-2.0.2.tgz", + "integrity": "sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==", + "license": "MIT", + "engines": { + "node": ">=0.10.0" + } + }, "node_modules/requires-port": { "version": "1.0.0", "resolved": "https://registry.npmjs.org/requires-port/-/requires-port-1.0.0.tgz", @@ -7997,6 +8801,22 @@ "fsevents": "~2.3.2" } }, + "node_modules/router": { + "version": "2.2.0", + "resolved": "https://registry.npmjs.org/router/-/router-2.2.0.tgz", + "integrity": "sha512-nLTrUKm2UyiL7rlhapu/Zl45FwNgkZGaCpZbIHajDYgwlJCOzLSk+cIPAnsEqV955GjILJnKbdQC1nVPz+gAYQ==", + "license": "MIT", + "dependencies": { + "debug": "^4.4.0", + "depd": "^2.0.0", + "is-promise": "^4.0.0", + "parseurl": "^1.3.3", + "path-to-regexp": "^8.0.0" + }, + "engines": { + "node": ">= 18" + } + }, "node_modules/run-applescript": { "version": "7.1.0", "resolved": "https://registry.npmjs.org/run-applescript/-/run-applescript-7.1.0.tgz", @@ -8049,7 +8869,6 @@ "version": "2.1.2", "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz", "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==", - "dev": true, "license": "MIT" }, "node_modules/sass": { @@ -8108,6 +8927,51 @@ "node": ">=10" } }, + "node_modules/send": { + "version": "1.2.1", + "resolved": "https://registry.npmjs.org/send/-/send-1.2.1.tgz", + "integrity": "sha512-1gnZf7DFcoIcajTjTwjwuDjzuz4PPcY2StKPlsGAQ1+YH20IRVrBaXSWmdjowTJ6u8Rc01PoYOGHXfP1mYcZNQ==", + "license": "MIT", + "dependencies": { + "debug": "^4.4.3", + "encodeurl": "^2.0.0", + "escape-html": "^1.0.3", + "etag": "^1.8.1", + "fresh": "^2.0.0", + "http-errors": "^2.0.1", + "mime-types": "^3.0.2", + "ms": "^2.1.3", + "on-finished": "^2.4.1", + "range-parser": "^1.2.1", + "statuses": "^2.0.2" + }, + "engines": { + "node": ">= 18" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, + "node_modules/serve-static": { + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-2.2.1.tgz", + "integrity": "sha512-xRXBn0pPqQTVQiC8wyQrKs2MOlX24zQ0POGaj0kultvoOCstBQM5yvOhAVSUwOMjQtTvsPWoNCHfPGwaaQJhTw==", + "license": "MIT", + "dependencies": { + "encodeurl": "^2.0.0", + "escape-html": "^1.0.3", + "parseurl": "^1.3.3", + "send": "^1.2.0" + }, + "engines": { + "node": ">= 18" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/express" + } + }, "node_modules/set-cookie-parser": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/set-cookie-parser/-/set-cookie-parser-3.0.1.tgz", @@ -8115,11 +8979,16 @@ "dev": true, "license": "MIT" }, + "node_modules/setprototypeof": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz", + "integrity": "sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw==", + "license": "ISC" + }, "node_modules/shebang-command": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz", "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==", - "dev": true, "license": "MIT", "dependencies": { "shebang-regex": "^3.0.0" @@ -8132,7 +9001,6 @@ "version": "3.0.0", "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz", "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==", - "dev": true, "license": "MIT", "engines": { "node": ">=8" @@ -8142,7 +9010,6 @@ "version": "1.1.0", "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.1.0.tgz", "integrity": "sha512-ZX99e6tRweoUXqR+VBrslhda51Nh5MTQwou5tnUDgbtyM0dBgmhEDtWGP/xbKn6hqfPRHujUNwz5fy/wbbhnpw==", - "dev": true, "license": "MIT", "dependencies": { "es-errors": "^1.3.0", @@ -8162,7 +9029,6 @@ "version": "1.0.0", "resolved": "https://registry.npmjs.org/side-channel-list/-/side-channel-list-1.0.0.tgz", "integrity": "sha512-FCLHtRD/gnpCiCHEiJLOwdmFP+wzCmDEkc9y7NsYxeF4u7Btsn1ZuwgwJGxImImHicJArLP4R0yX4c2KCrMrTA==", - "dev": true, "license": "MIT", "dependencies": { "es-errors": "^1.3.0", @@ -8179,7 +9045,6 @@ "version": "1.0.1", "resolved": "https://registry.npmjs.org/side-channel-map/-/side-channel-map-1.0.1.tgz", "integrity": "sha512-VCjCNfgMsby3tTdo02nbjtM/ewra6jPHmpThenkTYh8pG9ucZ/1P8So4u4FGBek/BjpOVsDCMoLA/iuBKIFXRA==", - "dev": true, "license": "MIT", "dependencies": { "call-bound": "^1.0.2", @@ -8198,7 +9063,6 @@ "version": "1.0.2", "resolved": "https://registry.npmjs.org/side-channel-weakmap/-/side-channel-weakmap-1.0.2.tgz", "integrity": "sha512-WPS/HvHQTYnHisLo9McqBHOJk2FkHO/tlpvldyrnem4aeQp4hai3gythswg6p01oSoTl58rcpiFAjF2br2Ak2A==", - "dev": true, "license": "MIT", "dependencies": { "call-bound": "^1.0.2", @@ -8286,6 +9150,15 @@ "dev": true, "license": "MIT" }, + "node_modules/statuses": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.2.tgz", + "integrity": "sha512-DvEy55V3DB7uknRo+4iOGT5fP1slR8wQohVdknigZPMpMstaKJQWhwiYBACJE3Ul2pTnATihhBYnRhZQHGBiRw==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, "node_modules/std-env": { "version": "3.9.0", "resolved": "https://registry.npmjs.org/std-env/-/std-env-3.9.0.tgz", @@ -8294,9 +9167,9 @@ "license": "MIT" }, "node_modules/storybook": { - "version": "10.2.9", - "resolved": "https://registry.npmjs.org/storybook/-/storybook-10.2.9.tgz", - "integrity": "sha512-DGok7XwIwdPWF+a49Yw+4madER5DZWRo9CdyySBLT3zeuxiEPt0Ua7ouJHm/y6ojnb/FVKZcQe8YmrE71s0qPQ==", + "version": "10.2.4", + "resolved": "https://registry.npmjs.org/storybook/-/storybook-10.2.4.tgz", + "integrity": "sha512-LwF0VZsT4qkgx66Ad/q0QgZZrU2a5WftaADDEcJ3bGq3O2fHvwWPlSZjM1HiXD4vqP9U5JiMqQkV1gkyH0XJkw==", "dev": true, "license": "MIT", "peer": true, @@ -8805,9 +9678,9 @@ } }, "node_modules/tar": { - "version": "7.5.9", - "resolved": "https://registry.npmjs.org/tar/-/tar-7.5.9.tgz", - "integrity": "sha512-BTLcK0xsDh2+PUe9F6c2TlRp4zOOBMTkoQHQIWSIzI0R7KG46uEwq4OPk2W7bZcprBMsuaeFsqwYr7pjh6CuHg==", + "version": "7.5.7", + "resolved": "https://registry.npmjs.org/tar/-/tar-7.5.7.tgz", + "integrity": "sha512-fov56fJiRuThVFXD6o6/Q354S7pnWMJIVlDBYijsTNx6jKSE4pvrDTs6lUnmGvNyfJwFQQwWy3owKz1ucIhveQ==", "dev": true, "license": "BlueOak-1.0.0", "dependencies": { @@ -8944,6 +9817,15 @@ "node": ">=8.0" } }, + "node_modules/toidentifier": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/toidentifier/-/toidentifier-1.0.1.tgz", + "integrity": "sha512-o5sSPKEkg/DIQNmH43V0/uerLrpzVedkUh8tGNvaeXpfpuwjKenlSox/2O/BTlZUtEe+JG7s5YhEz608PlAHRA==", + "license": "MIT", + "engines": { + "node": ">=0.6" + } + }, "node_modules/totalist": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/totalist/-/totalist-3.0.1.tgz", @@ -9040,6 +9922,20 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/type-is": { + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/type-is/-/type-is-2.0.1.tgz", + "integrity": "sha512-OZs6gsjF4vMp32qrCbiVSkrFmXtG/AZhY3t0iAMrMBiAZyV9oALtXO8hsrHbMXF9x6L3grlFuwW2oAz7cav+Gw==", + "license": "MIT", + "dependencies": { + "content-type": "^1.0.5", + "media-typer": "^1.1.0", + "mime-types": "^3.0.0" + }, + "engines": { + "node": ">= 0.6" + } + }, "node_modules/typescript": { "version": "5.8.3", "resolved": "https://registry.npmjs.org/typescript/-/typescript-5.8.3.tgz", @@ -9268,6 +10164,15 @@ "node": ">= 10.0.0" } }, + "node_modules/unpipe": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/unpipe/-/unpipe-1.0.0.tgz", + "integrity": "sha512-pjy2bYhSsufwWlKwPc+l3cN7+wuJlK6uz0YdJEOlQDbl6jo/YlPi4mb8agUkVC8BF7V8NuzeyPNqRksA3hztKQ==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, "node_modules/unplugin": { "version": "2.3.11", "resolved": "https://registry.npmjs.org/unplugin/-/unplugin-2.3.11.tgz", @@ -9332,6 +10237,15 @@ "uuid": "dist-node/bin/uuid" } }, + "node_modules/vary": { + "version": "1.1.2", + "resolved": "https://registry.npmjs.org/vary/-/vary-1.1.2.tgz", + "integrity": "sha512-BNGbWLfd0eUPabhkXUVm0j8uuvREyTh5ovRa/dyow/BqAbZJyC+5fU+IzQOzmAKzYqYRAISoRhdQr3eIZ/PXqg==", + "license": "MIT", + "engines": { + "node": ">= 0.8" + } + }, "node_modules/vfile": { "version": "6.0.3", "resolved": "https://registry.npmjs.org/vfile/-/vfile-6.0.3.tgz", @@ -9704,7 +10618,6 @@ "version": "2.0.2", "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz", "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==", - "dev": true, "license": "ISC", "dependencies": { "isexe": "^2.0.0" @@ -9828,6 +10741,12 @@ "url": "https://github.com/chalk/ansi-styles?sponsor=1" } }, + "node_modules/wrappy": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz", + "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==", + "license": "ISC" + }, "node_modules/ws": { "version": "8.18.3", "resolved": "https://registry.npmjs.org/ws/-/ws-8.18.3.tgz", @@ -9895,6 +10814,25 @@ "integrity": "sha512-rAbqEGa8ovJy4pyBxZM70hg4pE6gDgaQ0Sl9M3enG3I0d6H4XSAM3GeNGLKnsBpuijUow064sf7ww1nutC5/3w==", "license": "MIT" }, + "node_modules/zod": { + "version": "4.2.1", + "resolved": "https://registry.npmjs.org/zod/-/zod-4.2.1.tgz", + "integrity": "sha512-0wZ1IRqGGhMP76gLqz8EyfBXKk0J2qo2+H3fi4mcUP/KtTocoX08nmIAHl1Z2kJIZbZee8KOpBCSNPRgauucjw==", + "license": "MIT", + "peer": true, + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + }, + "node_modules/zod-to-json-schema": { + "version": "3.25.1", + "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.25.1.tgz", + "integrity": "sha512-pM/SU9d3YAggzi6MtR4h7ruuQlqKtad8e9S0fmxcMi+ueAK5Korys/aWcV9LIIHTVbj01NdzxcnXSN+O74ZIVA==", + "license": "ISC", + "peerDependencies": { + "zod": "^3.25 || ^4" + } + }, "node_modules/zwitch": { "version": "2.0.4", "resolved": "https://registry.npmjs.org/zwitch/-/zwitch-2.0.4.tgz", diff --git a/tools/server/webui/package.json b/tools/server/webui/package.json index 0b74e301b1d..f5cdc9e47f0 100644 --- a/tools/server/webui/package.json +++ b/tools/server/webui/package.json @@ -79,6 +79,7 @@ "vitest-browser-svelte": "^0.1.0" }, "dependencies": { + "@modelcontextprotocol/sdk": "^1.25.1", "highlight.js": "^11.11.1", "mode-watcher": "^1.1.0", "pdfjs-dist": "^5.4.54", @@ -90,6 +91,7 @@ "remark-html": "^16.0.1", "remark-rehype": "^11.1.2", "svelte-sonner": "^1.0.5", - "unist-util-visit": "^5.0.0" + "unist-util-visit": "^5.0.0", + "zod": "^4.2.1" } } diff --git a/tools/server/webui/src/lib/components/app/actions/ActionIcon.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIcon.svelte index 4494ea880b9..c676e224a72 100644 --- a/tools/server/webui/src/lib/components/app/actions/ActionIcon.svelte +++ b/tools/server/webui/src/lib/components/app/actions/ActionIcon.svelte @@ -8,6 +8,7 @@ tooltip: string; variant?: 'default' | 'destructive' | 'outline' | 'secondary' | 'ghost' | 'link'; size?: 'default' | 'sm' | 'lg' | 'icon'; + iconSize?: string; class?: string; disabled?: boolean; onclick: () => void; @@ -21,6 +22,7 @@ size = 'sm', class: className = '', disabled = false, + iconSize = 'h-3 w-3', onclick, 'aria-label': ariaLabel }: Props = $props(); @@ -38,7 +40,7 @@ > {@const IconComponent = icon} - + diff --git a/tools/server/webui/src/lib/components/app/actions/ActionIconRemove.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIconRemove.svelte index 1ae3d217747..11f1c17d988 100644 --- a/tools/server/webui/src/lib/components/app/actions/ActionIconRemove.svelte +++ b/tools/server/webui/src/lib/components/app/actions/ActionIconRemove.svelte @@ -6,21 +6,22 @@ id: string; onRemove?: (id: string) => void; class?: string; + iconSize?: number; } - let { id, onRemove, class: className = '' }: Props = $props(); + let { id, onRemove, class: className = '', iconSize = 3 }: Props = $props(); diff --git a/tools/server/webui/src/lib/components/app/badges/BadgeModality.svelte b/tools/server/webui/src/lib/components/app/badges/BadgeModality.svelte index a0d5e863c2a..15936691a6a 100644 --- a/tools/server/webui/src/lib/components/app/badges/BadgeModality.svelte +++ b/tools/server/webui/src/lib/components/app/badges/BadgeModality.svelte @@ -1,6 +1,6 @@ + +
+ + + {#if !readonly && onRemove} +
+ onRemove?.()} /> +
+ {/if} +
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentMcpResource.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentMcpResource.svelte new file mode 100644 index 00000000000..258fcac80e7 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentMcpResource.svelte @@ -0,0 +1,86 @@ + + + + + + + + +
+ {#if favicon} + { + (e.currentTarget as HTMLImageElement).style.display = 'none'; + }} + /> + {/if} + + + {serverName} + +
+
+
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentMcpResources.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentMcpResources.svelte new file mode 100644 index 00000000000..341bf32c058 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentMcpResources.svelte @@ -0,0 +1,41 @@ + + +{#if hasAttachments} +
+ + {#each attachments as attachment, i (attachment.id)} + handleResourceClick(attachment.resource.uri)} + /> + {/each} + +
+{/if} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsList.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsList.svelte index 6248d84fb0d..a3d37b42a3b 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsList.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentsList.svelte @@ -1,12 +1,21 @@ @@ -168,6 +231,103 @@

{systemMessageTooltip}

+ + + + + + + + MCP Servers + + + + +
+ {#each filteredMcpServers as server (server.id)} + {@const healthState = mcpStore.getHealthCheckState(server.id)} + {@const hasError = healthState.status === HealthCheckStatus.ERROR} + {@const isEnabledForChat = isServerEnabledForChat(server.id)} + + + {/each} +
+ + {#snippet footer()} + + + + Manage MCP Servers + + {/snippet} +
+
+
+ + {#if hasMcpPromptsSupport} + + + + MCP Prompt + + {/if} + + {#if hasMcpResourcesSupport} + + + + MCP Resources + + {/if} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionAttachmentsSheet.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionAttachmentsSheet.svelte new file mode 100644 index 00000000000..bf643dd7f25 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionAttachmentsSheet.svelte @@ -0,0 +1,170 @@ + + +
+ + + + + + Add to chat + + + Add files, system prompt or configure MCP servers + + + +
+ + + + + + + + + + + + + + + {#if hasMcpPromptsSupport} + + {/if} + + {#if hasMcpResourcesSupport} + + {/if} +
+
+
+
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActions.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActions.svelte index 54b11c86249..85017769339 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActions.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActions.svelte @@ -3,17 +3,24 @@ import { Button } from '$lib/components/ui/button'; import { ChatFormActionAttachmentsDropdown, + ChatFormActionAttachmentsSheet, ChatFormActionRecord, ChatFormActionSubmit, - ModelsSelector + McpServersSelector, + ModelsSelector, + ModelsSelectorSheet } from '$lib/components/app'; + import { DialogChatSettings } from '$lib/components/app/dialogs'; + import { SETTINGS_SECTION_TITLES } from '$lib/constants'; + import { mcpStore } from '$lib/stores/mcp.svelte'; import { FileTypeCategory } from '$lib/enums'; import { getFileTypeCategory } from '$lib/utils'; import { config } from '$lib/stores/settings.svelte'; import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte'; - import { isRouterMode } from '$lib/stores/server.svelte'; + import { isRouterMode, serverError } from '$lib/stores/server.svelte'; import { chatStore } from '$lib/stores/chat.svelte'; - import { activeMessages } from '$lib/stores/conversations.svelte'; + import { activeMessages, conversationsStore } from '$lib/stores/conversations.svelte'; + import { IsMobile } from '$lib/hooks/is-mobile.svelte'; interface Props { canSend?: boolean; @@ -27,6 +34,8 @@ onMicClick?: () => void; onStop?: () => void; onSystemPromptClick?: () => void; + onMcpPromptClick?: () => void; + onMcpResourcesClick?: () => void; } let { @@ -40,11 +49,14 @@ onFileUpload, onMicClick, onStop, - onSystemPromptClick + onSystemPromptClick, + onMcpPromptClick, + onMcpResourcesClick }: Props = $props(); let currentConfig = $derived(config()); let isRouter = $derived(isRouterMode()); + let isOffline = $derived(!!serverError()); let conversationModel = $derived( chatStore.getConversationModel(activeMessages() as DatabaseMessage[]) @@ -55,7 +67,10 @@ $effect(() => { if (conversationModel && conversationModel !== previousConversationModel) { previousConversationModel = conversationModel; - modelsStore.selectModelByName(conversationModel); + + if (!isRouter || modelsStore.isModelLoaded(conversationModel)) { + modelsStore.selectModelByName(conversationModel); + } } }); @@ -148,32 +163,83 @@ return ''; }); - let selectorModelRef: ModelsSelector | undefined = $state(undefined); + let selectorModelRef: ModelsSelector | ModelsSelectorSheet | undefined = $state(undefined); + + let isMobile = new IsMobile(); export function openModelSelector() { selectorModelRef?.open(); } + + let showChatSettingsDialogWithMcpSection = $state(false); + + let hasMcpPromptsSupport = $derived.by(() => { + const perChatOverrides = conversationsStore.getAllMcpServerOverrides(); + + return mcpStore.hasPromptsCapability(perChatOverrides); + }); + + let hasMcpResourcesSupport = $derived.by(() => { + const perChatOverrides = conversationsStore.getAllMcpServerOverrides(); + + return mcpStore.hasResourcesCapability(perChatOverrides); + });
- (showChatSettingsDialogWithMcpSection = true)} + /> + {:else} + (showChatSettingsDialogWithMcpSection = true)} + /> + {/if} + + (showChatSettingsDialogWithMcpSection = true)} />
- + {#if isMobile.current} + + {:else} + + {/if}
{#if isLoading} @@ -201,3 +267,9 @@ /> {/if}
+ + (showChatSettingsDialogWithMcpSection = open)} + initialSection={SETTINGS_SECTION_TITLES.MCP} +/> diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormHelperText.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormHelperText.svelte index f8246f249c3..a8f1f76c7cf 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormHelperText.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormHelperText.svelte @@ -8,7 +8,7 @@ {#if show} -
+