Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions internal/adapter/openai/handler_toolcall_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,35 @@ func TestHandleNonStreamUnknownToolIntercepted(t *testing.T) {
}
}

func TestHandleNonStreamDoesNotPromoteToolCallsWhenNoToolsProvided(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
`data: {"p":"response/content","v":"{\"tool_calls\":[{\"name\":\"search\",\"input\":{\"q\":\"go\"}}]}"}`,
`data: [DONE]`,
)
rec := httptest.NewRecorder()

h.handleNonStream(rec, context.Background(), resp, "cid2b2", "deepseek-chat", "prompt", false, nil)
if rec.Code != http.StatusOK {
t.Fatalf("unexpected status: %d", rec.Code)
}

out := decodeJSONBody(t, rec.Body.String())
choices, _ := out["choices"].([]any)
choice, _ := choices[0].(map[string]any)
if choice["finish_reason"] != "stop" {
t.Fatalf("expected finish_reason=stop, got %#v", choice["finish_reason"])
}
msg, _ := choice["message"].(map[string]any)
if msg["tool_calls"] != nil {
t.Fatalf("expected no tool_calls, got %#v", msg["tool_calls"])
}
content, _ := msg["content"].(string)
if !strings.Contains(content, `"tool_calls"`) {
t.Fatalf("expected tool_calls json preserved as content, got %#v", content)
}
}

func TestHandleNonStreamEmbeddedToolCallExamplePromotesToolCall(t *testing.T) {
h := &Handler{}
resp := makeSSEHTTPResponse(
Expand Down
12 changes: 7 additions & 5 deletions internal/format/openai/render_chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@ import (
)

func BuildChatCompletion(completionID, model, finalPrompt, finalThinking, finalText string, toolNames []string) map[string]any {
detected := toolcall.ParseStandaloneToolCallsDetailed(finalText, toolNames)
finishReason := "stop"
messageObj := map[string]any{"role": "assistant", "content": finalText}
if strings.TrimSpace(finalThinking) != "" {
messageObj["reasoning_content"] = finalThinking
}
if len(detected.Calls) > 0 {
finishReason = "tool_calls"
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected.Calls)
messageObj["content"] = nil
if len(toolNames) > 0 {
detected := toolcall.ParseStandaloneToolCallsDetailed(finalText, toolNames)
if len(detected.Calls) > 0 {
finishReason = "tool_calls"
messageObj["tool_calls"] = toolcall.FormatOpenAIToolCalls(detected.Calls)
messageObj["content"] = nil
}
}

return map[string]any{
Expand Down
40 changes: 40 additions & 0 deletions internal/js/helpers/stream-tool-sieve/parse_payload.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const TOOL_CALL_MARKUP_ARGS_PATTERNS = [
/<(?:[a-z0-9_:-]+:)?args\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?args>/i,
/<(?:[a-z0-9_:-]+:)?params\b[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?params>/i,
];
const NAMED_PARAMETER_PATTERN = /<(?:[a-z0-9_:-]+:)?parameter\b[^>]*\bname="([^"]+)"[^>]*>([\s\S]*?)<\/(?:[a-z0-9_:-]+:)?parameter>/gi;
const TEXT_KV_NAME_PATTERN = /function\.name:\s*([a-zA-Z0-9_.-]+)/gi;

const {
Expand Down Expand Up @@ -266,6 +267,10 @@ function parseMarkupInput(raw) {
if (!s) {
return {};
}
const named = parseNamedParameterElements(s);
if (Object.keys(named).length > 0) {
return named;
}
const parsed = parseToolCallInput(s);
if (parsed && typeof parsed === 'object' && !Array.isArray(parsed) && Object.keys(parsed).length > 0) {
return parsed;
Expand All @@ -277,6 +282,41 @@ function parseMarkupInput(raw) {
return { _raw: stripTagText(s) };
}

function parseNamedParameterElements(text) {
const raw = toStringSafe(text);
if (!raw) {
return {};
}
const out = {};
for (const m of raw.matchAll(NAMED_PARAMETER_PATTERN)) {
const key = toStringSafe(m[1]).trim();
if (!key) {
continue;
}
let value = toStringSafe(m[2]);
value = unwrapCDATA(value);
value = normalizeParamValue(value);
out[key] = value;
}
return out;
}

function unwrapCDATA(value) {
const t = toStringSafe(value).trim();
if (t.startsWith('<![CDATA[') && t.endsWith(']]>')) {
return t.slice('<![CDATA['.length, t.length - ']]>'.length);
}
return value;
}

function normalizeParamValue(value) {
const s = toStringSafe(value);
if (s.includes('\n')) {
return s.replace(/^\n+/, '').replace(/\n+$/, '');
}
return s.trim();
}

function parseMarkupKVObject(text) {
const raw = toStringSafe(text).trim();
if (!raw) {
Expand Down
67 changes: 65 additions & 2 deletions internal/toolcall/toolcalls_parse_markup.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,12 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
if err := dec.DecodeElement(&node, &t); err == nil {
inner := strings.TrimSpace(node.Inner)
if inner != "" {
if named := parseXMLNamedParameters(inner); len(named) > 0 {
for k, vv := range named {
params[k] = vv
}
break
}
unescapedInner := html.UnescapeString(inner)
if parsed := parseToolCallInput(unescapedInner); len(parsed) > 0 {
if len(parsed) == 1 {
Expand Down Expand Up @@ -183,6 +189,55 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
return ParsedToolCall{Name: strings.TrimSpace(html.UnescapeString(name)), Input: params}, true
}

func parseXMLNamedParameters(innerXML string) map[string]any {
raw := strings.TrimSpace(innerXML)
if raw == "" {
return nil
}
dec := xml.NewDecoder(strings.NewReader("<root>" + raw + "</root>"))
out := map[string]any{}
for {
tok, err := dec.Token()
if err != nil {
break
}
start, ok := tok.(xml.StartElement)
if !ok {
continue
}
if !strings.EqualFold(start.Name.Local, "parameter") {
continue
}
key := ""
for _, attr := range start.Attr {
if strings.EqualFold(strings.TrimSpace(attr.Name.Local), "name") {
key = strings.TrimSpace(attr.Value)
break
}
}
var v string
if err := dec.DecodeElement(&v, &start); err != nil {
continue
}
val := normalizeToolParamValue(v)
if key == "" {
continue
}
out[key] = val
}
if len(out) == 0 {
return nil
}
return out
}

func normalizeToolParamValue(v string) string {
if strings.Contains(v, "\n") {
return strings.Trim(v, "\n")
}
return strings.TrimSpace(v)
}

func stripTopLevelXMLParameters(inner string) string {
out := strings.TrimSpace(inner)
for {
Expand Down Expand Up @@ -316,7 +371,7 @@ func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) {
continue
}
k := strings.TrimSpace(pm[1])
v := strings.TrimSpace(html.UnescapeString(pm[2]))
v := normalizeToolParamValue(unwrapXMLCDATA(html.UnescapeString(pm[2])))
if k != "" {
input[k] = v
}
Expand Down Expand Up @@ -347,7 +402,7 @@ func parseToolUseFunctionStyle(text string) (ParsedToolCall, bool) {
continue
}
k := strings.TrimSpace(pm[1])
v := strings.TrimSpace(html.UnescapeString(pm[2]))
v := normalizeToolParamValue(unwrapXMLCDATA(html.UnescapeString(pm[2])))
if k != "" {
input[k] = v
}
Expand Down Expand Up @@ -457,3 +512,11 @@ func asString(v any) string {
s, _ := v.(string)
return s
}

func unwrapXMLCDATA(s string) string {
t := strings.TrimSpace(s)
if strings.HasPrefix(t, "<![CDATA[") && strings.HasSuffix(t, "]]>") && len(t) >= len("<![CDATA[")+len("]]>") {
return t[len("<![CDATA[") : len(t)-len("]]>")]
}
return s
}
21 changes: 21 additions & 0 deletions internal/toolcall/toolcalls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,27 @@ func TestParseToolCallsSupportsInvokeFunctionCallStyle(t *testing.T) {
}
}

func TestParseToolCallsSupportsXMLNamedParameterStyle(t *testing.T) {
text := `<tool_calls><tool_call><tool_name>write_to_file</tool_name><parameters><parameter name="path">novel_chapter1.md</parameter><parameter name="content"><![CDATA[# 第一章:启航

这里是正文
]]></parameter></parameters></tool_call></tool_calls>`
calls := ParseToolCalls(text, []string{"write_to_file"})
if len(calls) != 1 {
t.Fatalf("expected 1 call, got %#v", calls)
}
if calls[0].Name != "write_to_file" {
t.Fatalf("expected tool name write_to_file, got %q", calls[0].Name)
}
if calls[0].Input["path"] != "novel_chapter1.md" {
t.Fatalf("expected path argument, got %#v", calls[0].Input)
}
content, _ := calls[0].Input["content"].(string)
if !strings.Contains(content, "第一章") || !strings.Contains(content, "这里是正文") {
t.Fatalf("expected content to preserve text, got %#v", calls[0].Input)
}
}

func TestParseToolCallsSupportsGeminiFunctionCallJSON(t *testing.T) {
text := `{"functionCall":{"name":"search_web","args":{"query":"latest"}}}`
calls := ParseToolCalls(text, []string{"search_web"})
Expand Down