diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 575d2407917..74f078f5edd 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -936,75 +936,158 @@ static void test_peg_parser(common_chat_templates * tmpls, throw std::runtime_error("Failed to build grammar: " + parser.params_.grammar); } + // In production, grammar triggers match against the full generated text + // including the generation prompt. All positions are in full_input coordinates. + const auto & gen_prompt = parser.params_.generation_prompt; + std::string full_input = gen_prompt + tc.input; + + // Determine whether the reasoning-budget sampler path applies: tool-call grammar + // with all WORD triggers and thinking tags present. In production, the reasoning + // budget sampler inhibits grammar application while inside thinking blocks — + // triggers inside <think>...</think> are suppressed. + bool use_reasoning_budget_path = false; + if (parser.params_.grammar_lazy && !parser.params_.thinking_end_tag.empty()) { + use_reasoning_budget_path = true; + for (const auto & trigger : parser.params_.grammar_triggers) { + if (trigger.type != COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { + use_reasoning_budget_path = false; + break; + } + } + } + // Find the earliest trigger position to determine the constrained portion auto earliest_trigger_pos = std::string::npos; - for (const auto & trigger : parser.params_.grammar_triggers) { - size_t pos = std::string::npos; - std::smatch match; - switch (trigger.type) { - case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: - { - const auto & word = trigger.value; - pos = tc.input.find(word); - break; - } - case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: - { - const auto & pattern = std::regex(trigger.value); - if (std::regex_search(tc.input, match, pattern)) { - pos = match.position(pattern.mark_count()); + + if (use_reasoning_budget_path) { + // Reasoning-budget path: simulate thinking-aware trigger detection. 
+ // Walk through full_input tracking thinking state; only match triggers + // when outside thinking blocks. + const auto & think_start = parser.params_.thinking_start_tag; + const auto & think_end = parser.params_.thinking_end_tag; + + bool in_thinking = false; + for (size_t i = 0; i < full_input.size(); ++i) { + if (!in_thinking && !think_start.empty() + && full_input.compare(i, think_start.size(), think_start) == 0) { + in_thinking = true; + i += think_start.size() - 1; + continue; + } + if (in_thinking && full_input.compare(i, think_end.size(), think_end) == 0) { + in_thinking = false; + i += think_end.size() - 1; + continue; + } + if (in_thinking) { + continue; + } + // Outside thinking — check if any trigger word starts here + for (const auto & trigger : parser.params_.grammar_triggers) { + if (full_input.compare(i, trigger.value.size(), trigger.value) == 0) { + if (earliest_trigger_pos == std::string::npos || i < earliest_trigger_pos) { + earliest_trigger_pos = i; } - break; } - case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL: - { - const auto & pattern = trigger.value; - if (std::regex_match(tc.input, match, std::regex(pattern))) { - auto mpos = std::string::npos; - for (size_t i = 1; i < match.size(); ++i) { - if (match[i].length() > 0) { - mpos = match.position(i); + } + if (earliest_trigger_pos != std::string::npos) { + break; // found the earliest + } + } + + // If the reasoning-budget path found no trigger outside thinking but the test + // expects tool calls, this template nests tool calls inside thinking + // blocks (e.g. Kimi). Fall back to the legacy path for this case. 
+ if (earliest_trigger_pos == std::string::npos && !tc.expect.tool_calls.empty()) { + use_reasoning_budget_path = false; + } + } + + if (!use_reasoning_budget_path) { + // Legacy path: find triggers without thinking-awareness + for (const auto & trigger : parser.params_.grammar_triggers) { + size_t pos = std::string::npos; + std::smatch match; + switch (trigger.type) { + case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: + { + const auto & word = trigger.value; + pos = full_input.find(word); + break; + } + case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: + { + const auto & compiled = std::regex(trigger.value); + if (std::regex_search(full_input, match, compiled)) { + pos = match.position(compiled.mark_count()); + } + break; + } + case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL: + { + // In production, PATTERN_FULL triggers are checked against + // the text generated so far, growing token by token. Simulate + // by trying every prefix of full_input. + const auto & compiled = std::regex(trigger.value); + for (size_t end = gen_prompt.size(); end <= full_input.size(); ++end) { + std::string prefix = full_input.substr(0, end); + if (std::regex_match(prefix, match, compiled)) { + pos = std::string::npos; + for (size_t gi = 1; gi < match.size(); ++gi) { + if (match[gi].length() > 0) { + pos = match.position(gi); + break; + } + } + if (pos == std::string::npos) { + pos = match.position(0); + } break; } } - if (mpos == std::string::npos) { - mpos = match.position(0); - } - pos = mpos; + break; } - break; + default: + throw std::runtime_error("Unknown trigger type"); + } + if (pos != std::string::npos) { + if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) { + earliest_trigger_pos = pos; } - default: - throw std::runtime_error("Unknown trigger type"); - } - if (pos != std::string::npos) { - if (earliest_trigger_pos == std::string::npos || pos < earliest_trigger_pos) { - earliest_trigger_pos = pos; } } } - // Determine the constrained portion of input to test against 
grammar - std::string constrained = tc.input; + // If the test expects tool calls and the grammar is lazy, the trigger must fire. + // Otherwise the grammar would never activate in production and tool calls wouldn't + // be constrained. A silent skip here would hide broken triggers. + if (parser.params_.grammar_lazy && !tc.expect.tool_calls.empty() && !tc.is_partial + && earliest_trigger_pos == std::string::npos) { + std::string trigger_desc; + for (const auto & trigger : parser.params_.grammar_triggers) { + trigger_desc += "\n [type=" + std::to_string(trigger.type) + "] " + trigger.value; + } + throw std::runtime_error( + "Grammar trigger did not fire, but test expects tool calls (lazy grammar).\n" + ">>> Input: " + full_input + "\n" + ">>> Triggers (" + std::to_string(parser.params_.grammar_triggers.size()) + "):" + trigger_desc); + } + + // Determine the constrained portion of input to test against grammar. + // If the trigger position falls inside the generation prompt, the grammar + // sampler was already active before model output began — constrain from the + // start of the model output (i.e. tc.input). + std::string constrained = full_input; bool grammar_triggered = false; if (earliest_trigger_pos != std::string::npos) { - constrained = tc.input.substr(earliest_trigger_pos); + auto constrain_from = std::max(earliest_trigger_pos, gen_prompt.size()); + constrained = full_input.substr(constrain_from); grammar_triggered = true; } else if (!parser.params_.grammar_lazy) { // For non-lazy grammars, the entire input should match grammar_triggered = true; } - // For non-lazy grammars, prepend reasoning prefill to grammar input, just like - // PEG parsing does. The grammar includes the full reasoning pattern (e.g. optional - // <think>...</think>), but the model output may start mid-reasoning if the template - // already placed the opening tag in the prompt. 
- // For lazy grammars, the grammar only activates from the trigger position, so the - // reasoning prefill is irrelevant — reasoning is handled by the PEG parser. - if (!parser.params_.generation_prompt.empty() && earliest_trigger_pos == std::string::npos) { - constrained = parser.params_.generation_prompt + constrained; - } - // Test the constrained portion against the grammar if (grammar_triggered && !tc.is_partial) { auto result = match_string_detailed(constrained, grammar.get()); @@ -1323,6 +1406,19 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect_reasoning("I need to output the invoice details in JSON") .expect_content(R"({"amount": 123.45, "date": "2025-12-03"})") .run(); + + // fake tool call marker in reasoning + tst.test( + "[THINK]Let me think about [TOOL_CALLS]special_function[ARGS]{\"arg1\":1} and more[/THINK]" + R"([TOOL_CALLS]special_function[ARGS]{"arg1": 1})") + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .enable_thinking(true) + .tools({ special_function_tool }) + .expect_reasoning("Let me think about [TOOL_CALLS]special_function[ARGS]{\"arg1\":1} and more") + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", {} }, + }) + .run(); } { @@ -1425,6 +1521,50 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .expect_reasoning("I need to output the invoice details in JSON") .expect_content(R"({"amount": 123.45, "date": "2025-12-03"})") .run(); + + // tool call segment in reasoning + tst.test( + "Let's call a tool: \n" + "\n" + "\n" + "def hello():\n" + " print(\"Not the real call!\")\n" + "\n" + "hello()\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "def hello():\n" + " print(\"Hello, world!\")\n" + "\n" + "hello()\n" + "\n" + "\n" + "" + ) + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ + python_tool + }) + .expect_reasoning("Let's call a tool: \n" + "\n" + "\n" + "def hello():\n" + " print(\"Not the real call!\")\n" + "\n" + "hello()\n" + "\n" + 
"\n" + "") + .expect_tool_calls({ + { "python", "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", {} }, + }) + .run(); + } { @@ -2297,6 +2437,19 @@ static void test_template_output_peg_parsers(bool detailed_debug) { .tools({ empty_args_tool }) .expect(simple_assist_msg("", "", "empty_args", "{}")) .run(); + + // fake tool call marker in reasoning + tst.test( + "Let me think about <|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|> hmm" + "<|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|>") + .enable_thinking(true) + .reasoning_format(COMMON_REASONING_FORMAT_AUTO) + .tools({ special_function_tool }) + .expect_reasoning("Let me think about <|tool_call_start|>[special_function(arg1=1)]<|tool_call_end|> hmm") + .expect_tool_calls({ + { "special_function", R"({"arg1": 1})", {} }, + }) + .run(); } // Apertus-8B-Instruct tests - FUNC_NAME_AS_KEY format