diff --git a/internal/toolcall/toolcalls_markup.go b/internal/toolcall/toolcalls_markup.go index 9a6ad2c..56546eb 100644 --- a/internal/toolcall/toolcalls_markup.go +++ b/internal/toolcall/toolcalls_markup.go @@ -2,6 +2,7 @@ package toolcall import ( "encoding/json" + "html" "regexp" "strings" ) @@ -92,7 +93,7 @@ func parseMarkupSingleToolCall(attrs string, inner string) ParsedToolCall { } func parseMarkupInput(raw string) map[string]any { - raw = strings.TrimSpace(raw) + raw = strings.TrimSpace(html.UnescapeString(raw)) if raw == "" { return map[string]any{} } @@ -102,7 +103,7 @@ func parseMarkupInput(raw string) map[string]any { if kv := parseMarkupKVObject(raw); len(kv) > 0 { return kv } - return map[string]any{"_raw": stripTagText(raw)} + return map[string]any{"_raw": html.UnescapeString(stripTagText(raw))} } func parseMarkupKVObject(text string) map[string]any { @@ -123,7 +124,7 @@ func parseMarkupKVObject(text string) map[string]any { if !strings.EqualFold(key, endKey) { continue } - value := strings.TrimSpace(stripTagText(m[2])) + value := strings.TrimSpace(html.UnescapeString(stripTagText(m[2]))) if value == "" { continue } diff --git a/internal/toolcall/toolcalls_parse_markup.go b/internal/toolcall/toolcalls_parse_markup.go index fa41036..d269e40 100644 --- a/internal/toolcall/toolcalls_parse_markup.go +++ b/internal/toolcall/toolcalls_parse_markup.go @@ -3,6 +3,7 @@ package toolcall import ( "encoding/json" "encoding/xml" + "html" "regexp" "strings" ) @@ -114,10 +115,11 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) { if err := dec.DecodeElement(&node, &t); err == nil { inner := strings.TrimSpace(node.Inner) if inner != "" { - if parsed := parseToolCallInput(inner); len(parsed) > 0 { + unescapedInner := html.UnescapeString(inner) + if parsed := parseToolCallInput(unescapedInner); len(parsed) > 0 { if len(parsed) == 1 { if _, onlyRaw := parsed["_raw"]; onlyRaw { - if kv := parseMarkupKVObject(inner); len(kv) > 0 { + if kv := 
parseMarkupKVObject(unescapedInner); len(kv) > 0 { for k, vv := range kv { params[k] = vv } @@ -128,7 +130,7 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) { for k, vv := range parsed { params[k] = vv } - } else if kv := parseMarkupKVObject(inner); len(kv) > 0 { + } else if kv := parseMarkupKVObject(unescapedInner); len(kv) > 0 { for k, vv := range kv { params[k] = vv } @@ -143,12 +145,12 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) { params[t.Name.Local] = strings.TrimSpace(v) break } - name = strings.TrimSpace(v) + name = strings.TrimSpace(html.UnescapeString(v)) } case "input", "arguments", "argument", "args", "params": var v string if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" { - if parsed := parseToolCallInput(strings.TrimSpace(v)); len(parsed) > 0 { + if parsed := parseToolCallInput(strings.TrimSpace(html.UnescapeString(v))); len(parsed) > 0 { for k, vv := range parsed { params[k] = vv } @@ -158,7 +160,7 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) { if inParams || inTool { var v string if err := dec.DecodeElement(&v, &t); err == nil { - params[t.Name.Local] = strings.TrimSpace(v) + params[t.Name.Local] = strings.TrimSpace(html.UnescapeString(v)) } } } @@ -173,12 +175,12 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) { } } if strings.TrimSpace(name) == "" { - name = strings.TrimSpace(extractXMLToolNameByRegex(stripTopLevelXMLParameters(inner))) + name = strings.TrimSpace(html.UnescapeString(extractXMLToolNameByRegex(stripTopLevelXMLParameters(inner)))) } if strings.TrimSpace(name) == "" { return ParsedToolCall{}, false } - return ParsedToolCall{Name: strings.TrimSpace(name), Input: params}, true + return ParsedToolCall{Name: strings.TrimSpace(html.UnescapeString(name)), Input: params}, true } func stripTopLevelXMLParameters(inner string) string { @@ -231,7 +233,7 @@ func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) { if 
len(m) < 2 { return ParsedToolCall{}, false } - name := strings.TrimSpace(m[1]) + name := strings.TrimSpace(html.UnescapeString(m[1])) if name == "" { return ParsedToolCall{}, false } @@ -241,7 +243,7 @@ func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) { continue } key := strings.TrimSpace(pm[1]) - val := strings.TrimSpace(pm[2]) + val := strings.TrimSpace(html.UnescapeString(pm[2])) if key != "" { input[key] = val } @@ -270,11 +272,11 @@ func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) { if len(m) < 3 { return ParsedToolCall{}, false } - name := strings.TrimSpace(m[1]) + name := strings.TrimSpace(html.UnescapeString(m[1])) if name == "" { return ParsedToolCall{}, false } - body := strings.TrimSpace(m[2]) + body := strings.TrimSpace(html.UnescapeString(m[2])) input := map[string]any{} if strings.HasPrefix(body, "{") { if err := json.Unmarshal([]byte(body), &input); err == nil { @@ -291,7 +293,7 @@ func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) { continue } k := strings.TrimSpace(am[1]) - v := strings.TrimSpace(am[2]) + v := strings.TrimSpace(html.UnescapeString(am[2])) if k != "" { input[k] = v } @@ -304,7 +306,7 @@ func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) { if len(m) < 3 { return ParsedToolCall{}, false } - name := strings.TrimSpace(m[1]) + name := strings.TrimSpace(html.UnescapeString(m[1])) if name == "" { return ParsedToolCall{}, false } @@ -314,7 +316,7 @@ func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) { continue } k := strings.TrimSpace(pm[1]) - v := strings.TrimSpace(pm[2]) + v := strings.TrimSpace(html.UnescapeString(pm[2])) if k != "" { input[k] = v } @@ -334,7 +336,7 @@ func parseToolUseFunctionStyle(text string) (ParsedToolCall, bool) { if len(m) < 3 { return ParsedToolCall{}, false } - name := strings.TrimSpace(m[1]) + name := strings.TrimSpace(html.UnescapeString(m[1])) if name == "" { return ParsedToolCall{}, false } @@ -345,7 
+347,7 @@ func parseToolUseFunctionStyle(text string) (ParsedToolCall, bool) { continue } k := strings.TrimSpace(pm[1]) - v := strings.TrimSpace(pm[2]) + v := strings.TrimSpace(html.UnescapeString(pm[2])) if k != "" { input[k] = v } @@ -358,11 +360,11 @@ func parseToolUseNameParametersStyle(text string) (ParsedToolCall, bool) { if len(m) < 3 { return ParsedToolCall{}, false } - name := strings.TrimSpace(m[1]) + name := strings.TrimSpace(html.UnescapeString(m[1])) if name == "" { return ParsedToolCall{}, false } - raw := strings.TrimSpace(m[2]) + raw := strings.TrimSpace(html.UnescapeString(m[2])) input := map[string]any{} if raw != "" { if parsed := parseToolCallInput(raw); len(parsed) > 0 { @@ -379,11 +381,11 @@ func parseToolUseFunctionNameParametersStyle(text string) (ParsedToolCall, bool) if len(m) < 3 { return ParsedToolCall{}, false } - name := strings.TrimSpace(m[1]) + name := strings.TrimSpace(html.UnescapeString(m[1])) if name == "" { return ParsedToolCall{}, false } - raw := strings.TrimSpace(m[2]) + raw := strings.TrimSpace(html.UnescapeString(m[2])) input := map[string]any{} if raw != "" { if parsed := parseToolCallInput(raw); len(parsed) > 0 { @@ -400,11 +402,11 @@ func parseToolUseToolNameBodyStyle(text string) (ParsedToolCall, bool) { if len(m) < 3 { return ParsedToolCall{}, false } - name := strings.TrimSpace(m[1]) + name := strings.TrimSpace(html.UnescapeString(m[1])) if name == "" { return ParsedToolCall{}, false } - body := strings.TrimSpace(m[2]) + body := strings.TrimSpace(html.UnescapeString(m[2])) input := map[string]any{} if body != "" { if kv := parseXMLChildKV(body); len(kv) > 0 { diff --git a/internal/toolcall/toolcalls_test.go b/internal/toolcall/toolcalls_test.go index 663b895..faa7322 100644 --- a/internal/toolcall/toolcalls_test.go +++ b/internal/toolcall/toolcalls_test.go @@ -691,3 +691,27 @@ func TestRepairLooseJSONWithNestedObjects(t *testing.T) { } } } + +func TestParseToolCallsUnescapesHTMLEntityArguments(t *testing.T) { + text 
:= `Bash{"command":"echo a > out.txt"}` + calls := ParseToolCalls(text, []string{"bash"}) + if len(calls) != 1 { + t.Fatalf("expected one call, got %#v", calls) + } + cmd, _ := calls[0].Input["command"].(string) + if cmd != "echo a > out.txt" { + t.Fatalf("expected html entities to be unescaped in command, got %q", cmd) + } +} + +func TestParseToolCallsJSONPayloadKeepsLiteralEntities(t *testing.T) { + text := `{"tool_calls":[{"name":"bash","input":{"command":"echo > literally"}}]}` + calls := ParseToolCalls(text, []string{"bash"}) + if len(calls) != 1 { + t.Fatalf("expected one call, got %#v", calls) + } + cmd, _ := calls[0].Input["command"].(string) + if cmd != "echo > literally" { + t.Fatalf("expected json payload to keep literal entities, got %q", cmd) + } +} diff --git a/internal/translatorcliproxy/bridge.go b/internal/translatorcliproxy/bridge.go index e5dc5ac..c5d6741 100644 --- a/internal/translatorcliproxy/bridge.go +++ b/internal/translatorcliproxy/bridge.go @@ -3,6 +3,7 @@ package translatorcliproxy import ( "bytes" "context" + "encoding/json" "strings" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" @@ -15,7 +16,12 @@ func ToOpenAI(from sdktranslator.Format, model string, raw []byte, stream bool) func FromOpenAINonStream(to sdktranslator.Format, model string, originalReq, translatedReq, raw []byte) []byte { var param any - return sdktranslator.TranslateNonStream(context.Background(), sdktranslator.FormatOpenAI, to, model, originalReq, translatedReq, raw, &param) + converted := sdktranslator.TranslateNonStream(context.Background(), sdktranslator.FormatOpenAI, to, model, originalReq, translatedReq, raw, &param) + usage, ok := extractOpenAIUsageFromJSON(raw) + if !ok { + return converted + } + return injectNonStreamUsageMetadata(converted, to, usage) } func FromOpenAIStream(to sdktranslator.Format, model string, originalReq, translatedReq, streamBody []byte) []byte { @@ -65,3 +71,57 @@ func ParseFormat(name string) sdktranslator.Format { func 
ToOpenAIByName(formatName, model string, raw []byte, stream bool) []byte { return ToOpenAI(ParseFormat(formatName), model, raw, stream) } + +func extractOpenAIUsageFromJSON(raw []byte) (openAIUsage, bool) { + payload := map[string]any{} + if err := json.Unmarshal(raw, &payload); err != nil { + return openAIUsage{}, false + } + usageObj, _ := payload["usage"].(map[string]any) + if usageObj == nil { + return openAIUsage{}, false + } + p := toInt(usageObj["prompt_tokens"]) + c := toInt(usageObj["completion_tokens"]) + t := toInt(usageObj["total_tokens"]) + if p <= 0 { + p = toInt(usageObj["input_tokens"]) + } + if c <= 0 { + c = toInt(usageObj["output_tokens"]) + } + if t <= 0 { + t = p + c + } + if p <= 0 && c <= 0 && t <= 0 { + return openAIUsage{}, false + } + return openAIUsage{PromptTokens: p, CompletionTokens: c, TotalTokens: t}, true +} + +func injectNonStreamUsageMetadata(converted []byte, target sdktranslator.Format, usage openAIUsage) []byte { + obj := map[string]any{} + if err := json.Unmarshal(converted, &obj); err != nil { + return converted + } + switch target { + case sdktranslator.FormatClaude: + obj["usage"] = map[string]any{ + "input_tokens": usage.PromptTokens, + "output_tokens": usage.CompletionTokens, + } + case sdktranslator.FormatGemini: + obj["usageMetadata"] = map[string]any{ + "promptTokenCount": usage.PromptTokens, + "candidatesTokenCount": usage.CompletionTokens, + "totalTokenCount": usage.TotalTokens, + } + default: + return converted + } + out, err := json.Marshal(obj) + if err != nil { + return converted + } + return out +} diff --git a/internal/translatorcliproxy/bridge_test.go b/internal/translatorcliproxy/bridge_test.go index cdd9cf7..9dbfe30 100644 --- a/internal/translatorcliproxy/bridge_test.go +++ b/internal/translatorcliproxy/bridge_test.go @@ -46,6 +46,22 @@ func TestFromOpenAINonStreamGeminiPreservesUsageFromOpenAI(t *testing.T) { } } +func TestFromOpenAINonStreamPreservesResponsesUsageShape(t *testing.T) { + original := 
[]byte(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`) + translatedReq := []byte(`{"model":"gemini-2.5-pro","messages":[{"role":"user","content":"hi"}],"stream":false}`) + openaibody := []byte(`{"id":"resp_1","object":"response","model":"gemini-2.5-pro","usage":{"input_tokens":"11","output_tokens":"29","total_tokens":"40"}}`) + gotGemini := string(FromOpenAINonStream(sdktranslator.FormatGemini, "gemini-2.5-pro", original, translatedReq, openaibody)) + if !strings.Contains(gotGemini, `"promptTokenCount":11`) || !strings.Contains(gotGemini, `"candidatesTokenCount":29`) || !strings.Contains(gotGemini, `"totalTokenCount":40`) { + t.Fatalf("expected gemini usageMetadata from input/output usage fields, got: %s", gotGemini) + } + + origClaude := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":false}`) + gotClaude := string(FromOpenAINonStream(sdktranslator.FormatClaude, "claude-sonnet-4-5", origClaude, origClaude, openaibody)) + if !strings.Contains(gotClaude, `"input_tokens":11`) || !strings.Contains(gotClaude, `"output_tokens":29`) { + t.Fatalf("expected claude usage from input/output usage fields, got: %s", gotClaude) + } +} + func TestParseFormatAliases(t *testing.T) { cases := map[string]sdktranslator.Format{ "responses": sdktranslator.FormatOpenAIResponse, diff --git a/internal/translatorcliproxy/stream_writer.go b/internal/translatorcliproxy/stream_writer.go index e80ce69..ac7fc41 100644 --- a/internal/translatorcliproxy/stream_writer.go +++ b/internal/translatorcliproxy/stream_writer.go @@ -5,6 +5,7 @@ import ( "context" "encoding/json" "net/http" + "strconv" "strings" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" @@ -149,6 +150,12 @@ func extractOpenAIUsage(line []byte) (openAIUsage, bool) { p := toInt(usageObj["prompt_tokens"]) c := toInt(usageObj["completion_tokens"]) t := toInt(usageObj["total_tokens"]) + if p <= 0 { + p = toInt(usageObj["input_tokens"]) + } + if c <= 0 { + c = 
toInt(usageObj["output_tokens"]) + } if p <= 0 && c <= 0 && t <= 0 { return openAIUsage{}, false } @@ -221,6 +228,12 @@ func toInt(v any) int { return int(x) case float32: return int(x) + case string: + n, err := strconv.Atoi(strings.TrimSpace(x)) + if err != nil { + return 0 + } + return n default: return 0 } diff --git a/internal/translatorcliproxy/stream_writer_test.go b/internal/translatorcliproxy/stream_writer_test.go index 94d70b8..f4758d4 100644 --- a/internal/translatorcliproxy/stream_writer_test.go +++ b/internal/translatorcliproxy/stream_writer_test.go @@ -75,3 +75,14 @@ func TestInjectStreamUsageMetadataPreservesSSEFrameTerminator(t *testing.T) { t.Fatalf("expected usageMetadata injected, got %q", string(got)) } } + +func TestExtractOpenAIUsageSupportsResponsesUsageFields(t *testing.T) { + line := []byte(`data: {"usage":{"input_tokens":"11","output_tokens":"29","total_tokens":"40"}}`) + got, ok := extractOpenAIUsage(line) + if !ok { + t.Fatal("expected usage extracted from input/output usage fields") + } + if got.PromptTokens != 11 || got.CompletionTokens != 29 || got.TotalTokens != 40 { + t.Fatalf("unexpected usage extracted: %#v", got) + } +}