From c5a3db1799278b1f551ec108d7a7960d7f8b7bf4 Mon Sep 17 00:00:00 2001 From: POf-L Date: Thu, 16 Apr 2026 07:58:48 +0000 Subject: [PATCH] =?UTF-8?q?feat:=20=E4=BC=98=E5=8C=96=20Tool=20Call=20?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E8=A7=A3=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: traeagent --- .../adapter/openai/handler_toolcall_test.go | 29 ++++++++ internal/format/openai/render_chat.go | 12 ++-- .../stream-tool-sieve/parse_payload.js | 40 +++++++++++ internal/toolcall/toolcalls_parse_markup.go | 67 ++++++++++++++++++- internal/toolcall/toolcalls_test.go | 21 ++++++ 5 files changed, 162 insertions(+), 7 deletions(-) diff --git a/internal/adapter/openai/handler_toolcall_test.go b/internal/adapter/openai/handler_toolcall_test.go index 6092461..be30b5b 100644 --- a/internal/adapter/openai/handler_toolcall_test.go +++ b/internal/adapter/openai/handler_toolcall_test.go @@ -208,6 +208,35 @@ func TestHandleNonStreamUnknownToolIntercepted(t *testing.T) { } } +func TestHandleNonStreamDoesNotPromoteToolCallsWhenNoToolsProvided(t *testing.T) { + h := &Handler{} + resp := makeSSEHTTPResponse( + `data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`, + `data: [DONE]`, + ) + rec := httptest.NewRecorder() + + h.handleNonStream(rec, context.Background(), resp, "cid2b2", "deepseek-chat", "prompt", false, nil) + if rec.Code != http.StatusOK { + t.Fatalf("unexpected status: %d", rec.Code) + } + + out := decodeJSONBody(t, rec.Body.String()) + choices, _ := out["choices"].([]any) + choice, _ := choices[0].(map[string]any) + if choice["finish_reason"] != "stop" { + t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"]) + } + msg, _ := choice["message"].(map[string]any) + if msg["tool_calls"] != nil { + t.Fatalf("expected no tool_calls, got %#v", msg["tool_calls"]) + } + content, _ := msg["content"].(string) + if !strings.Contains(content, `"tool_calls"`) { + t.Fatalf("expected tool_calls json preserved as content, got %#v", content) + } +} + func TestHandleNonStreamEmbeddedToolCallExamplePromotesToolCall(t *testing.T) { h := &Handler{} resp := makeSSEHTTPResponse( diff --git a/internal/format/openai/render_chat.go b/internal/format/openai/render_chat.go index c09e870..e85db5b 100644 --- a/internal/format/openai/render_chat.go +++ b/internal/format/openai/render_chat.go @@ -7,16 +7,18 @@ import ( ) func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any { - detected := toolcall.ParseStandaloneToolCallsDetailed(finalText, toolNames) finishReason := "stop" messageObj := map[string]any{"role": "assistant", "content": finalText} if strings.TrimSpace(finalThinking) != "" { messageObj["reasoning_content"] = finalThinking } - if len(detected.Calls) > 0 { - finishReason = "tool_calls" - messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected.Calls) - messageObj["content"] = nil + if len(toolNames) > 0 { + detected := toolcall.ParseStandaloneToolCallsDetailed(finalText, toolNames) + if len(detected.Calls) > 0 { + finishReason = "tool_calls" + messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected.Calls) + messageObj["content"] = nil + } } return map[string]any{ diff --git a/internal/js/helpers/stream-tool-sieve/parse_payload.js b/internal/js/helpers/stream-tool-sieve/parse_payload.js index 2970613..027e0b6 100644 --- a/internal/js/helpers/stream-tool-sieve/parse_payload.js +++ b/internal/js/helpers/stream-tool-sieve/parse_payload.js @@ -20,6 +20,7 @@ const TOOL_CALL_MARKUP_ARGS_PATTERNS = [ /<(?:[a-z0-9_:-]+:)?args\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?args>/i, /<(?:[a-z0-9_:-]+:)?params\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?params>/i, ]; +const NAMED_PARAMETER_PATTERN = /<(?:[a-z0-9_:-]+:)?parameter\b[^>]*\bname="([^"]+)"[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?parameter>/gi; const TEXT_KV_NAME_PATTERN = /function\.name:\s*([a-zA-Z0-9_.-]+)/gi; const { @@ -266,6 +267,10 @@ function parseMarkupInput(raw) { if (!s) { return {}; } + const named = parseNamedParameterElements(s); + if (Object.keys(named).length > 0) { + return named; + } const parsed = parseToolCallInput(s); if (parsed && typeof parsed === 'object' && !Array.isArray(parsed) && Object.keys(parsed).length > 0) { return parsed; @@ -277,6 +282,41 @@ function parseMarkupInput(raw) { return { _raw: stripTagText(s) }; } +function parseNamedParameterElements(text) { + const raw = toStringSafe(text); + if (!raw) { + return {}; + } + const out = {}; + for (const m of raw.matchAll(NAMED_PARAMETER_PATTERN)) { + const key = toStringSafe(m[1]).trim(); + if (!key) { + continue; + } + let value = toStringSafe(m[2]); + value = unwrapCDATA(value); + value = normalizeParamValue(value); + out[key] = value; + } + return out; +} + +function unwrapCDATA(value) { + const t = toStringSafe(value).trim(); + if (t.startsWith('')) { + return t.slice(''.length); + } + return value; +} + +function normalizeParamValue(value) { + const s = toStringSafe(value); + if (s.includes('\n')) { + return s.replace(/^\n+/, '').replace(/\n+$/, ''); + } + return s.trim(); +} + function parseMarkupKVObject(text) { const raw = toStringSafe(text).trim(); if (!raw) { diff --git a/internal/toolcall/toolcalls_parse_markup.go b/internal/toolcall/toolcalls_parse_markup.go index 899f8f3..d7cbd64 100644 --- a/internal/toolcall/toolcalls_parse_markup.go +++ b/internal/toolcall/toolcalls_parse_markup.go @@ -115,6 +115,12 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) { if err := dec.DecodeElement(&node, &t); err == nil { inner := strings.TrimSpace(node.Inner) if inner != "" { + if named := parseXMLNamedParameters(inner); len(named) > 0 { + for k, vv := range named { + params[k] = vv + } + break + } unescapedInner := html.UnescapeString(inner) if parsed := parseToolCallInput(unescapedInner); len(parsed) > 0 { if len(parsed) == 1 { @@ -183,6 +189,55 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) { return ParsedToolCall{Name: strings.TrimSpace(html.UnescapeString(name)), Input: params}, true } +func parseXMLNamedParameters(innerXML string) map[string]any { + raw := strings.TrimSpace(innerXML) + if raw == "" { + return nil + } + dec := xml.NewDecoder(strings.NewReader("" + raw + "")) + out := map[string]any{} + for { + tok, err := dec.Token() + if err != nil { + break + } + start, ok := tok.(xml.StartElement) + if !ok { + continue + } + if !strings.EqualFold(start.Name.Local, "parameter") { + continue + } + key := "" + for _, attr := range start.Attr { + if strings.EqualFold(strings.TrimSpace(attr.Name.Local), "name") { + key = strings.TrimSpace(attr.Value) + break + } + } + var v string + if err := dec.DecodeElement(&v, &start); err != nil { + continue + } + val := normalizeToolParamValue(v) + if key == "" { + continue + } + out[key] = val + } + if len(out) == 0 { + return nil + } + return out +} + +func normalizeToolParamValue(v string) string { + if strings.Contains(v, "\n") { + return strings.Trim(v, "\n") + } + return strings.TrimSpace(v) +} + func stripTopLevelXMLParameters(inner string) string { out := strings.TrimSpace(inner) for { @@ -316,7 +371,7 @@ func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) { continue } k := strings.TrimSpace(pm[1]) - v := strings.TrimSpace(html.UnescapeString(pm[2])) + v := normalizeToolParamValue(unwrapXMLCDATA(html.UnescapeString(pm[2]))) if k != "" { input[k] = v } @@ -347,7 +402,7 @@ func parseToolUseFunctionStyle(text string) (ParsedToolCall, bool) { continue } k := strings.TrimSpace(pm[1]) - v := strings.TrimSpace(html.UnescapeString(pm[2])) + v := normalizeToolParamValue(unwrapXMLCDATA(html.UnescapeString(pm[2]))) if k != "" { input[k] = v } @@ -457,3 +512,11 @@ func asString(v any) string { s, _ := v.(string) return s } + +func unwrapXMLCDATA(s string) string { + t := strings.TrimSpace(s) + if strings.HasPrefix(t, "") && len(t) >= len("") { + return t[len("")] + } + return s +} diff --git a/internal/toolcall/toolcalls_test.go b/internal/toolcall/toolcalls_test.go index faa7322..89f7af2 100644 --- a/internal/toolcall/toolcalls_test.go +++ b/internal/toolcall/toolcalls_test.go @@ -318,6 +318,27 @@ func TestParseToolCallsSupportsInvokeFunctionCallStyle(t *testing.T) { } } +func TestParseToolCallsSupportsXMLNamedParameterStyle(t *testing.T) { + text := `write_to_filenovel_chapter1.md` + calls := ParseToolCalls(text, []string{"write_to_file"}) + if len(calls) != 1 { + t.Fatalf("expected 1 call, got %#v", calls) + } + if calls[0].Name != "write_to_file" { + t.Fatalf("expected tool name write_to_file, got %q", calls[0].Name) + } + if calls[0].Input["path"] != "novel_chapter1.md" { + t.Fatalf("expected path argument, got %#v", calls[0].Input) + } + content, _ := calls[0].Input["content"].(string) + if !strings.Contains(content, "第一章") || !strings.Contains(content, "这里是正文") { + t.Fatalf("expected content to preserve text, got %#v", calls[0].Input) + } +} + func TestParseToolCallsSupportsGeminiFunctionCallJSON(t *testing.T) { text := `{"functionCall":{"name":"search_web","args":{"query":"latest"}}}` calls := ParseToolCalls(text, []string{"search_web"})