Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions internal/toolcall/toolcalls_markup.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package toolcall

import (
"encoding/json"
"html"
"regexp"
"strings"
)
Expand Down Expand Up @@ -92,7 +93,7 @@ func parseMarkupSingleToolCall(attrs string, inner string) ParsedToolCall {
}

func parseMarkupInput(raw string) map[string]any {
raw = strings.TrimSpace(raw)
raw = strings.TrimSpace(html.UnescapeString(raw))
if raw == "" {
return map[string]any{}
}
Expand All @@ -102,7 +103,7 @@ func parseMarkupInput(raw string) map[string]any {
if kv := parseMarkupKVObject(raw); len(kv) > 0 {
return kv
}
return map[string]any{"_raw": stripTagText(raw)}
return map[string]any{"_raw": html.UnescapeString(stripTagText(raw))}
}

func parseMarkupKVObject(text string) map[string]any {
Expand All @@ -123,7 +124,7 @@ func parseMarkupKVObject(text string) map[string]any {
if !strings.EqualFold(key, endKey) {
continue
}
value := strings.TrimSpace(stripTagText(m[2]))
value := strings.TrimSpace(html.UnescapeString(stripTagText(m[2])))
if value == "" {
continue
}
Expand Down
48 changes: 25 additions & 23 deletions internal/toolcall/toolcalls_parse_markup.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package toolcall
import (
"encoding/json"
"encoding/xml"
"html"
"regexp"
"strings"
)
Expand Down Expand Up @@ -114,10 +115,11 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
if err := dec.DecodeElement(&node, &t); err == nil {
inner := strings.TrimSpace(node.Inner)
if inner != "" {
if parsed := parseToolCallInput(inner); len(parsed) > 0 {
unescapedInner := html.UnescapeString(inner)
if parsed := parseToolCallInput(unescapedInner); len(parsed) > 0 {
if len(parsed) == 1 {
if _, onlyRaw := parsed["_raw"]; onlyRaw {
if kv := parseMarkupKVObject(inner); len(kv) > 0 {
if kv := parseMarkupKVObject(unescapedInner); len(kv) > 0 {
for k, vv := range kv {
params[k] = vv
}
Expand All @@ -128,7 +130,7 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
for k, vv := range parsed {
params[k] = vv
}
} else if kv := parseMarkupKVObject(inner); len(kv) > 0 {
} else if kv := parseMarkupKVObject(unescapedInner); len(kv) > 0 {
for k, vv := range kv {
params[k] = vv
}
Expand All @@ -143,12 +145,12 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
params[t.Name.Local] = strings.TrimSpace(v)
break
}
name = strings.TrimSpace(v)
name = strings.TrimSpace(html.UnescapeString(v))
}
case "input", "arguments", "argument", "args", "params":
var v string
if err := dec.DecodeElement(&v, &t); err == nil && strings.TrimSpace(v) != "" {
if parsed := parseToolCallInput(strings.TrimSpace(v)); len(parsed) > 0 {
if parsed := parseToolCallInput(strings.TrimSpace(html.UnescapeString(v))); len(parsed) > 0 {
for k, vv := range parsed {
params[k] = vv
}
Expand All @@ -158,7 +160,7 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
if inParams || inTool {
var v string
if err := dec.DecodeElement(&v, &t); err == nil {
params[t.Name.Local] = strings.TrimSpace(v)
params[t.Name.Local] = strings.TrimSpace(html.UnescapeString(v))
}
}
}
Expand All @@ -173,12 +175,12 @@ func parseSingleXMLToolCall(block string) (ParsedToolCall, bool) {
}
}
if strings.TrimSpace(name) == "" {
name = strings.TrimSpace(extractXMLToolNameByRegex(stripTopLevelXMLParameters(inner)))
name = strings.TrimSpace(html.UnescapeString(extractXMLToolNameByRegex(stripTopLevelXMLParameters(inner))))
}
if strings.TrimSpace(name) == "" {
return ParsedToolCall{}, false
}
return ParsedToolCall{Name: strings.TrimSpace(name), Input: params}, true
return ParsedToolCall{Name: strings.TrimSpace(html.UnescapeString(name)), Input: params}, true
}

func stripTopLevelXMLParameters(inner string) string {
Expand Down Expand Up @@ -231,7 +233,7 @@ func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) {
if len(m) < 2 {
return ParsedToolCall{}, false
}
name := strings.TrimSpace(m[1])
name := strings.TrimSpace(html.UnescapeString(m[1]))
if name == "" {
return ParsedToolCall{}, false
}
Expand All @@ -241,7 +243,7 @@ func parseFunctionCallTagStyle(text string) (ParsedToolCall, bool) {
continue
}
key := strings.TrimSpace(pm[1])
val := strings.TrimSpace(pm[2])
val := strings.TrimSpace(html.UnescapeString(pm[2]))
if key != "" {
input[key] = val
}
Expand Down Expand Up @@ -270,11 +272,11 @@ func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) {
if len(m) < 3 {
return ParsedToolCall{}, false
}
name := strings.TrimSpace(m[1])
name := strings.TrimSpace(html.UnescapeString(m[1]))
if name == "" {
return ParsedToolCall{}, false
}
body := strings.TrimSpace(m[2])
body := strings.TrimSpace(html.UnescapeString(m[2]))
input := map[string]any{}
if strings.HasPrefix(body, "{") {
if err := json.Unmarshal([]byte(body), &input); err == nil {
Expand All @@ -291,7 +293,7 @@ func parseSingleAntmlFunctionCallMatch(m []string) (ParsedToolCall, bool) {
continue
}
k := strings.TrimSpace(am[1])
v := strings.TrimSpace(am[2])
v := strings.TrimSpace(html.UnescapeString(am[2]))
if k != "" {
input[k] = v
}
Expand All @@ -304,7 +306,7 @@ func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) {
if len(m) < 3 {
return ParsedToolCall{}, false
}
name := strings.TrimSpace(m[1])
name := strings.TrimSpace(html.UnescapeString(m[1]))
if name == "" {
return ParsedToolCall{}, false
}
Expand All @@ -314,7 +316,7 @@ func parseInvokeFunctionCallStyle(text string) (ParsedToolCall, bool) {
continue
}
k := strings.TrimSpace(pm[1])
v := strings.TrimSpace(pm[2])
v := strings.TrimSpace(html.UnescapeString(pm[2]))
if k != "" {
input[k] = v
}
Expand All @@ -334,7 +336,7 @@ func parseToolUseFunctionStyle(text string) (ParsedToolCall, bool) {
if len(m) < 3 {
return ParsedToolCall{}, false
}
name := strings.TrimSpace(m[1])
name := strings.TrimSpace(html.UnescapeString(m[1]))
if name == "" {
return ParsedToolCall{}, false
}
Expand All @@ -345,7 +347,7 @@ func parseToolUseFunctionStyle(text string) (ParsedToolCall, bool) {
continue
}
k := strings.TrimSpace(pm[1])
v := strings.TrimSpace(pm[2])
v := strings.TrimSpace(html.UnescapeString(pm[2]))
if k != "" {
input[k] = v
}
Expand All @@ -358,11 +360,11 @@ func parseToolUseNameParametersStyle(text string) (ParsedToolCall, bool) {
if len(m) < 3 {
return ParsedToolCall{}, false
}
name := strings.TrimSpace(m[1])
name := strings.TrimSpace(html.UnescapeString(m[1]))
if name == "" {
return ParsedToolCall{}, false
}
raw := strings.TrimSpace(m[2])
raw := strings.TrimSpace(html.UnescapeString(m[2]))
input := map[string]any{}
if raw != "" {
if parsed := parseToolCallInput(raw); len(parsed) > 0 {
Expand All @@ -379,11 +381,11 @@ func parseToolUseFunctionNameParametersStyle(text string) (ParsedToolCall, bool)
if len(m) < 3 {
return ParsedToolCall{}, false
}
name := strings.TrimSpace(m[1])
name := strings.TrimSpace(html.UnescapeString(m[1]))
if name == "" {
return ParsedToolCall{}, false
}
raw := strings.TrimSpace(m[2])
raw := strings.TrimSpace(html.UnescapeString(m[2]))
input := map[string]any{}
if raw != "" {
if parsed := parseToolCallInput(raw); len(parsed) > 0 {
Expand All @@ -400,11 +402,11 @@ func parseToolUseToolNameBodyStyle(text string) (ParsedToolCall, bool) {
if len(m) < 3 {
return ParsedToolCall{}, false
}
name := strings.TrimSpace(m[1])
name := strings.TrimSpace(html.UnescapeString(m[1]))
if name == "" {
return ParsedToolCall{}, false
}
body := strings.TrimSpace(m[2])
body := strings.TrimSpace(html.UnescapeString(m[2]))
input := map[string]any{}
if body != "" {
if kv := parseXMLChildKV(body); len(kv) > 0 {
Expand Down
24 changes: 24 additions & 0 deletions internal/toolcall/toolcalls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -691,3 +691,27 @@ func TestRepairLooseJSONWithNestedObjects(t *testing.T) {
}
}
}

func TestParseToolCallsUnescapesHTMLEntityArguments(t *testing.T) {
text := `<tool_call><tool_name>Bash</tool_name><parameters>{"command":"echo a &gt; out.txt"}</parameters></tool_call>`
calls := ParseToolCalls(text, []string{"bash"})
if len(calls) != 1 {
t.Fatalf("expected one call, got %#v", calls)
}
cmd, _ := calls[0].Input["command"].(string)
if cmd != "echo a > out.txt" {
t.Fatalf("expected html entities to be unescaped in command, got %q", cmd)
}
}

func TestParseToolCallsJSONPayloadKeepsLiteralEntities(t *testing.T) {
text := `{"tool_calls":[{"name":"bash","input":{"command":"echo &gt; literally"}}]}`
calls := ParseToolCalls(text, []string{"bash"})
if len(calls) != 1 {
t.Fatalf("expected one call, got %#v", calls)
}
cmd, _ := calls[0].Input["command"].(string)
if cmd != "echo &gt; literally" {
t.Fatalf("expected json payload to keep literal entities, got %q", cmd)
}
}
62 changes: 61 additions & 1 deletion internal/translatorcliproxy/bridge.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package translatorcliproxy
import (
"bytes"
"context"
"encoding/json"
"strings"

sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
Expand All @@ -15,7 +16,12 @@ func ToOpenAI(from sdktranslator.Format, model string, raw []byte, stream bool)

func FromOpenAINonStream(to sdktranslator.Format, model string, originalReq, translatedReq, raw []byte) []byte {
var param any
return sdktranslator.TranslateNonStream(context.Background(), sdktranslator.FormatOpenAI, to, model, originalReq, translatedReq, raw, &param)
converted := sdktranslator.TranslateNonStream(context.Background(), sdktranslator.FormatOpenAI, to, model, originalReq, translatedReq, raw, &param)
usage, ok := extractOpenAIUsageFromJSON(raw)
if !ok {
return converted
}
return injectNonStreamUsageMetadata(converted, to, usage)
}

func FromOpenAIStream(to sdktranslator.Format, model string, originalReq, translatedReq, streamBody []byte) []byte {
Expand Down Expand Up @@ -65,3 +71,57 @@ func ParseFormat(name string) sdktranslator.Format {
func ToOpenAIByName(formatName, model string, raw []byte, stream bool) []byte {
return ToOpenAI(ParseFormat(formatName), model, raw, stream)
}

func extractOpenAIUsageFromJSON(raw []byte) (openAIUsage, bool) {
payload := map[string]any{}
if err := json.Unmarshal(raw, &payload); err != nil {
return openAIUsage{}, false
}
usageObj, _ := payload["usage"].(map[string]any)
if usageObj == nil {
return openAIUsage{}, false
}
p := toInt(usageObj["prompt_tokens"])
c := toInt(usageObj["completion_tokens"])
t := toInt(usageObj["total_tokens"])
if p <= 0 {
p = toInt(usageObj["input_tokens"])
}
if c <= 0 {
c = toInt(usageObj["output_tokens"])
}
if t <= 0 {
t = p + c
}
if p <= 0 && c <= 0 && t <= 0 {
return openAIUsage{}, false
}
return openAIUsage{PromptTokens: p, CompletionTokens: c, TotalTokens: t}, true
}

func injectNonStreamUsageMetadata(converted []byte, target sdktranslator.Format, usage openAIUsage) []byte {
obj := map[string]any{}
if err := json.Unmarshal(converted, &obj); err != nil {
return converted
}
switch target {
case sdktranslator.FormatClaude:
obj["usage"] = map[string]any{
"input_tokens": usage.PromptTokens,
"output_tokens": usage.CompletionTokens,
}
case sdktranslator.FormatGemini:
obj["usageMetadata"] = map[string]any{
"promptTokenCount": usage.PromptTokens,
"candidatesTokenCount": usage.CompletionTokens,
"totalTokenCount": usage.TotalTokens,
}
default:
return converted
}
out, err := json.Marshal(obj)
if err != nil {
return converted
}
return out
}
16 changes: 16 additions & 0 deletions internal/translatorcliproxy/bridge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,22 @@ func TestFromOpenAINonStreamGeminiPreservesUsageFromOpenAI(t *testing.T) {
}
}

func TestFromOpenAINonStreamPreservesResponsesUsageShape(t *testing.T) {
original := []byte(`{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}`)
translatedReq := []byte(`{"model":"gemini-2.5-pro","messages":[{"role":"user","content":"hi"}],"stream":false}`)
openaibody := []byte(`{"id":"resp_1","object":"response","model":"gemini-2.5-pro","usage":{"input_tokens":"11","output_tokens":"29","total_tokens":"40"}}`)
gotGemini := string(FromOpenAINonStream(sdktranslator.FormatGemini, "gemini-2.5-pro", original, translatedReq, openaibody))
if !strings.Contains(gotGemini, `"promptTokenCount":11`) || !strings.Contains(gotGemini, `"candidatesTokenCount":29`) || !strings.Contains(gotGemini, `"totalTokenCount":40`) {
t.Fatalf("expected gemini usageMetadata from input/output usage fields, got: %s", gotGemini)
}

origClaude := []byte(`{"model":"claude-sonnet-4-5","messages":[{"role":"user","content":"hi"}],"stream":false}`)
gotClaude := string(FromOpenAINonStream(sdktranslator.FormatClaude, "claude-sonnet-4-5", origClaude, origClaude, openaibody))
if !strings.Contains(gotClaude, `"input_tokens":11`) || !strings.Contains(gotClaude, `"output_tokens":29`) {
t.Fatalf("expected claude usage from input/output usage fields, got: %s", gotClaude)
}
}

func TestParseFormatAliases(t *testing.T) {
cases := map[string]sdktranslator.Format{
"responses": sdktranslator.FormatOpenAIResponse,
Expand Down
13 changes: 13 additions & 0 deletions internal/translatorcliproxy/stream_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"encoding/json"
"net/http"
"strconv"
"strings"

sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
Expand Down Expand Up @@ -149,6 +150,12 @@ func extractOpenAIUsage(line []byte) (openAIUsage, bool) {
p := toInt(usageObj["prompt_tokens"])
c := toInt(usageObj["completion_tokens"])
t := toInt(usageObj["total_tokens"])
if p <= 0 {
p = toInt(usageObj["input_tokens"])
}
if c <= 0 {
c = toInt(usageObj["output_tokens"])
}
if p <= 0 && c <= 0 && t <= 0 {
return openAIUsage{}, false
}
Expand Down Expand Up @@ -221,6 +228,12 @@ func toInt(v any) int {
return int(x)
case float32:
return int(x)
case string:
n, err := strconv.Atoi(strings.TrimSpace(x))
if err != nil {
return 0
}
return n
default:
return 0
}
Expand Down
Loading
Loading