diff --git a/README.md b/README.md index e88417e..df35da5 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ It is a single Go binary with zero dependencies. 15 MB distroless image. Two por ```mermaid flowchart LR A[Agent
bearer token] -->|request| P[cllama-passthrough
identity → route → swap key
extract usage → record cost] - P -->|real key| U[Provider
OpenAI · Anthropic
OpenRouter · Ollama
] + P -->|real key| U[Provider
OpenAI · Anthropic
OpenRouter · Google · Ollama
] U -->|response| P P -->|response| A P --- D[:8081 dashboard
providers · pod · costs · api] @@ -84,6 +84,7 @@ Or with Docker: docker run -p 8080:8080 -p 8081:8081 \ -e ANTHROPIC_API_KEY=sk-ant-... \ -e OPENROUTER_API_KEY=sk-or-... \ + -e GEMINI_API_KEY=sk-gemini-... \ -v ./context:/claw/context:ro \ ghcr.io/mostlydev/cllama:latest ``` @@ -105,6 +106,9 @@ docker run -p 8080:8080 -p 8081:8081 \ | `OPENAI_API_KEY` | | Provider key override | | `ANTHROPIC_API_KEY` | | Provider key override | | `OPENROUTER_API_KEY` | | Provider key override | +| `GEMINI_API_KEY` | | Primary Google Gemini provider key override | +| `GOOGLE_API_KEY` | | Lower-priority alias for the Google Gemini provider key | +| `GOOGLE_BASE_URL` | | Override for Google's OpenAI-compatible base URL | Environment variables override keys saved via the web UI. @@ -153,6 +157,11 @@ When orchestrated by Clawdapus, `claw up` generates all of this — tokens via ` "api_key": "sk-or-...", "auth": "bearer" }, + "google": { + "base_url": "https://generativelanguage.googleapis.com/v1beta/openai", + "api_key": "sk-gemini-...", + "auth": "bearer" + }, "ollama": { "base_url": "http://ollama:11434/v1", "auth": "none" @@ -161,7 +170,7 @@ When orchestrated by Clawdapus, `claw up` generates all of this — tokens via ` } ``` -Auth schemes: `bearer` (OpenAI, OpenRouter), `x-api-key` (Anthropic), `none` (Ollama, local models). +Auth schemes: `bearer` (OpenAI, OpenRouter, Google), `x-api-key` (Anthropic), `none` (Ollama, local models). --- diff --git a/internal/cost/pricing.go b/internal/cost/pricing.go index c4aed25..a1e1430 100644 --- a/internal/cost/pricing.go +++ b/internal/cost/pricing.go @@ -67,10 +67,15 @@ func DefaultPricing() *Pricing { }, "openrouter": { // OpenRouter passes through to upstream providers; rates match origin pricing. - "anthropic/claude-sonnet-4": {InputPerMTok: 3.0, OutputPerMTok: 15.0}, - "anthropic/claude-haiku-3-5": {InputPerMTok: 0.80, OutputPerMTok: 4.0}, - "google/gemini-2.5-pro": {InputPerMTok: 1.25, OutputPerMTok: 10.0}, - "google/gemini-2.5-flash": {InputPerMTok: 0.15, OutputPerMTok: 0.60}, + "anthropic/claude-sonnet-4": {InputPerMTok: 3.0, OutputPerMTok: 15.0}, + "anthropic/claude-haiku-3-5": {InputPerMTok: 0.80, OutputPerMTok: 4.0}, + "google/gemini-2.5-pro": {InputPerMTok: 1.25, OutputPerMTok: 10.0}, + "google/gemini-2.5-flash": {InputPerMTok: 0.15, OutputPerMTok: 0.60}, + }, + // Google pricing is simplified to the standard <=200k-token text tier. + "google": { + "gemini-2.5-pro": {InputPerMTok: 1.25, OutputPerMTok: 10.0}, + "gemini-2.5-flash": {InputPerMTok: 0.30, OutputPerMTok: 2.50}, }, }} } diff --git a/internal/cost/pricing_test.go b/internal/cost/pricing_test.go index 4839a94..23e01bc 100644 --- a/internal/cost/pricing_test.go +++ b/internal/cost/pricing_test.go @@ -32,6 +32,28 @@ func TestLookupOpenAIModel(t *testing.T) { } } +func TestLookupGoogleGeminiFlashModel(t *testing.T) { + p := DefaultPricing() + rate, ok := p.Lookup("google", "gemini-2.5-flash") + if !ok { + t.Fatal("expected to find google/gemini-2.5-flash") + } + if rate.InputPerMTok != 0.30 || rate.OutputPerMTok != 2.50 { + t.Fatalf("unexpected google/gemini-2.5-flash rates: %+v", rate) + } +} + +func TestLookupGoogleGeminiProModel(t *testing.T) { + p := DefaultPricing() + rate, ok := p.Lookup("google", "gemini-2.5-pro") + if !ok { + t.Fatal("expected to find google/gemini-2.5-pro") + } + if rate.InputPerMTok != 1.25 || rate.OutputPerMTok != 10.0 { + t.Fatalf("unexpected google/gemini-2.5-pro rates: %+v", rate) + } +} + func TestComputeCost(t *testing.T) { rate := Rate{InputPerMTok: 3.0, OutputPerMTok: 15.0} cost := rate.Compute(1000, 500) diff --git a/internal/provider/provider.go b/internal/provider/provider.go index cb844fe..ed44f53 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -77,6 +77,7 @@ var knownProviders = map[string]string{ "xai": "https://api.x.ai/v1", "anthropic": "https://api.anthropic.com/v1", "openrouter": "https://openrouter.ai/api/v1", + "google": "https://generativelanguage.googleapis.com/v1beta/openai", "ollama": "http://ollama:11434/v1", } @@ -92,6 +93,9 @@ var envKeyMap = map[string]string{ "ANTHROPIC_API_KEY_1": "anthropic", "OPENROUTER_API_KEY": "openrouter", "OPENROUTER_API_KEY_1": "openrouter", + "GEMINI_API_KEY": "google", + "GEMINI_API_KEY_1": "google", + "GOOGLE_API_KEY": "google", } var envBaseURLMap = map[string]string{ @@ -99,6 +103,7 @@ var envBaseURLMap = map[string]string{ "XAI_BASE_URL": "xai", "ANTHROPIC_BASE_URL": "anthropic", "OPENROUTER_BASE_URL": "openrouter", + "GOOGLE_BASE_URL": "google", "OLLAMA_BASE_URL": "ollama", } @@ -270,6 +275,11 @@ func (r *Registry) LoadFromEnv() { {"OPENROUTER_API_KEY", "seed:OPENROUTER_API_KEY", "primary"}, {"OPENROUTER_API_KEY_1", "seed:OPENROUTER_API_KEY_1", "backup-1"}, }, + "google": { + {"GEMINI_API_KEY", "seed:GEMINI_API_KEY", "primary"}, + {"GEMINI_API_KEY_1", "seed:GEMINI_API_KEY_1", "backup-1"}, + {"GOOGLE_API_KEY", "seed:GOOGLE_API_KEY", "backup-2"}, + }, } for provName, defs := range envKeysByProvider { diff --git a/internal/provider/provider_test.go b/internal/provider/provider_test.go index 81b20e6..7ddf293 100644 --- a/internal/provider/provider_test.go +++ b/internal/provider/provider_test.go @@ -147,6 +147,113 @@ func TestLoadFromEnvAppliesXAIBaseURLOverride(t *testing.T) { } } +func TestLoadFromEnvSeedsGoogleProviderFromGeminiKey(t *testing.T) { + t.Setenv("GEMINI_API_KEY", "sk-gemini-primary") + + r := NewRegistry("") + r.LoadFromEnv() + + p, err := r.Get("google") + if err != nil { + t.Fatalf("google: %v", err) + } + if p.APIKey != "sk-gemini-primary" { + t.Fatalf("expected google key from GEMINI_API_KEY, got %q", p.APIKey) + } + if p.BaseURL != "https://generativelanguage.googleapis.com/v1beta/openai" { + t.Fatalf("unexpected google base URL: %q", p.BaseURL) + } + if p.Auth != "bearer" { + t.Fatalf("expected google auth=bearer, got %q", p.Auth) + } + if p.APIFormat != "openai" { + t.Fatalf("expected google api_format=openai, got %q", p.APIFormat) + } + + state := r.All()["google"] + if state == nil { + t.Fatal("expected google provider state") + } + if state.ActiveKeyID != "seed:GEMINI_API_KEY" { + t.Fatalf("active_key_id = %q, want seed:GEMINI_API_KEY", state.ActiveKeyID) + } + if len(state.Keys) != 1 || state.Keys[0].ID != "seed:GEMINI_API_KEY" { + t.Fatalf("expected single google seed key, got %+v", state.Keys) + } +} + +func TestLoadFromEnvSeedsGoogleProviderFromGoogleAlias(t *testing.T) { + t.Setenv("GOOGLE_API_KEY", "sk-google-alias") + + r := NewRegistry("") + r.LoadFromEnv() + + p, err := r.Get("google") + if err != nil { + t.Fatalf("google: %v", err) + } + if p.APIKey != "sk-google-alias" { + t.Fatalf("expected google key from GOOGLE_API_KEY, got %q", p.APIKey) + } + + state := r.All()["google"] + if state == nil { + t.Fatal("expected google provider state") + } + if state.ActiveKeyID != "seed:GOOGLE_API_KEY" { + t.Fatalf("active_key_id = %q, want seed:GOOGLE_API_KEY", state.ActiveKeyID) + } + if len(state.Keys) != 1 || state.Keys[0].ID != "seed:GOOGLE_API_KEY" { + t.Fatalf("expected GOOGLE_API_KEY alias to seed google provider, got %+v", state.Keys) + } +} + +func TestLoadFromEnvPrefersGeminiKeyOverGoogleAlias(t *testing.T) { + t.Setenv("GEMINI_API_KEY", "sk-gemini-primary") + t.Setenv("GOOGLE_API_KEY", "sk-google-alias") + + r := NewRegistry("") + r.LoadFromEnv() + + p, err := r.Get("google") + if err != nil { + t.Fatalf("google: %v", err) + } + if p.APIKey != "sk-gemini-primary" { + t.Fatalf("expected GEMINI_API_KEY to win, got %q", p.APIKey) + } + + state := r.All()["google"] + if state == nil { + t.Fatal("expected google provider state") + } + if state.ActiveKeyID != "seed:GEMINI_API_KEY" { + t.Fatalf("active_key_id = %q, want seed:GEMINI_API_KEY", state.ActiveKeyID) + } + if len(state.Keys) != 2 { + t.Fatalf("expected 2 google keys, got %+v", state.Keys) + } + if state.Keys[0].ID != "seed:GEMINI_API_KEY" || state.Keys[1].ID != "seed:GOOGLE_API_KEY" { + t.Fatalf("expected GEMINI primary then GOOGLE alias ordering, got %+v", state.Keys) + } +} + +func TestLoadFromEnvAppliesGoogleBaseURLOverride(t *testing.T) { + t.Setenv("GEMINI_API_KEY", "sk-gemini-primary") + t.Setenv("GOOGLE_BASE_URL", "https://proxy.example.test/google") + + r := NewRegistry("") + r.LoadFromEnv() + + p, err := r.Get("google") + if err != nil { + t.Fatalf("google: %v", err) + } + if p.BaseURL != "https://proxy.example.test/google" { + t.Fatalf("expected google base URL override, got %q", p.BaseURL) + } +} + func TestGetUnknownProviderErrors(t *testing.T) { r := NewRegistry("") _, err := r.Get("nonexistent") diff --git a/internal/proxy/handler.go b/internal/proxy/handler.go index 26d3a36..a887c0a 100644 --- a/internal/proxy/handler.go +++ b/internal/proxy/handler.go @@ -854,11 +854,21 @@ func injectManagedOpenAITools(payload map[string]any, agentCtx *agentctx.AgentCo payload["stream"] = false delete(payload, "stream_options") } - payload["tools"] = buildOpenAIToolSchemas(agentCtx.Tools.Tools) + tools := existingOpenAIToolSchemas(payload) + for _, tool := range buildOpenAIToolSchemas(agentCtx.Tools.Tools) { + tools = append(tools, tool) + } + payload["tools"] = tools + if _, ok := payload["tool_choice"]; !ok { + if toolChoice, ok := legacyFunctionCallToToolChoice(payload["function_call"]); ok { + payload["tool_choice"] = toolChoice + } + } + if toolChoice, ok := payload["tool_choice"]; ok { + payload["tool_choice"] = rewriteManagedOpenAIToolChoice(toolChoice, agentCtx) + } delete(payload, "functions") delete(payload, "function_call") - delete(payload, "tool_choice") - delete(payload, "parallel_tool_calls") return nil } @@ -869,8 +879,15 @@ func injectManagedAnthropicTools(payload map[string]any, agentCtx *agentctx.Agen if requestedStream(payload) { payload["stream"] = false } - payload["tools"] = buildAnthropicToolSchemas(agentCtx.Tools.Tools) - delete(payload, "tool_choice") + tools, _ := payload["tools"].([]any) + merged := append([]any{}, tools...) + for _, tool := range buildAnthropicToolSchemas(agentCtx.Tools.Tools) { + merged = append(merged, tool) + } + payload["tools"] = merged + if toolChoice, ok := payload["tool_choice"]; ok { + payload["tool_choice"] = rewriteManagedAnthropicToolChoice(toolChoice, agentCtx) + } return nil } @@ -878,6 +895,88 @@ func hasManagedTools(agentCtx *agentctx.AgentContext) bool { return agentCtx != nil && agentCtx.Tools != nil && len(agentCtx.Tools.Tools) > 0 } +func existingOpenAIToolSchemas(payload map[string]any) []any { + var tools []any + if existing, ok := payload["tools"].([]any); ok { + tools = append(tools, existing...) + } + if functions, ok := payload["functions"].([]any); ok { + for _, raw := range functions { + function, _ := raw.(map[string]any) + if function == nil { + continue + } + tools = append(tools, map[string]any{ + "type": "function", + "function": function, + }) + } + } + return tools +} + +func legacyFunctionCallToToolChoice(raw any) (any, bool) { + switch typed := raw.(type) { + case string: + value := strings.TrimSpace(typed) + if value == "" { + return nil, false + } + return value, true + case map[string]any: + name, _ := typed["name"].(string) + name = strings.TrimSpace(name) + if name == "" { + return nil, false + } + return map[string]any{ + "type": "function", + "function": map[string]any{ + "name": name, + }, + }, true + default: + return nil, false + } +} + +func rewriteManagedOpenAIToolChoice(toolChoice any, agentCtx *agentctx.AgentContext) any { + choice, _ := toolChoice.(map[string]any) + if choice == nil { + return toolChoice + } + function, _ := choice["function"].(map[string]any) + if function == nil { + return toolChoice + } + name, _ := function["name"].(string) + resolved, ok := resolveManagedTool(agentCtx, name) + if !ok { + return toolChoice + } + function["name"] = resolved.PresentedName + choice["function"] = function + return choice +} + +func rewriteManagedAnthropicToolChoice(toolChoice any, agentCtx *agentctx.AgentContext) any { + choice, _ := toolChoice.(map[string]any) + if choice == nil { + return toolChoice + } + kind, _ := choice["type"].(string) + if !strings.EqualFold(strings.TrimSpace(kind), "tool") { + return toolChoice + } + name, _ := choice["name"].(string) + resolved, ok := resolveManagedTool(agentCtx, name) + if !ok { + return toolChoice + } + choice["name"] = resolved.PresentedName + return choice +} + func buildOpenAIToolSchemas(tools []agentctx.ToolManifestEntry) []map[string]any { schemas := make([]map[string]any, 0, len(tools)) for _, tool := range tools { diff --git a/internal/proxy/handler_test.go b/internal/proxy/handler_test.go index 8db1231..1db7125 100644 --- a/internal/proxy/handler_test.go +++ b/internal/proxy/handler_test.go @@ -112,22 +112,28 @@ func TestHandlerInjectsManagedToolsIntoOpenAIRequests(t *testing.T) { t.Fatalf("unmarshal backend body: %v", err) } tools, _ := payload["tools"].([]any) - if len(tools) != 1 { - t.Fatalf("expected 1 managed tool, got %+v", payload["tools"]) + if len(tools) != 3 { + t.Fatalf("expected runner-local tools plus managed tools, got %+v", payload["tools"]) } - first, _ := tools[0].(map[string]any) - function, _ := first["function"].(map[string]any) - if first["type"] != "function" || function["name"] != expectedName { - t.Fatalf("unexpected managed tool payload: %+v", first) + var names []string + for _, raw := range tools { + tool, _ := raw.(map[string]any) + function, _ := tool["function"].(map[string]any) + name, _ := function["name"].(string) + names = append(names, name) + } + joined := strings.Join(names, ",") + if !strings.Contains(joined, "runner_local") || !strings.Contains(joined, "legacy") || !strings.Contains(joined, expectedName) { + t.Fatalf("expected merged tool list to include runner_local, legacy, and %q, got %+v", expectedName, payload["tools"]) } if _, ok := payload["functions"]; ok { t.Fatalf("expected legacy functions field removed, got %+v", payload) } - if _, ok := payload["tool_choice"]; ok { - t.Fatalf("expected tool_choice removed, got %+v", payload) + if payload["tool_choice"] != "auto" { + t.Fatalf("expected tool_choice preserved, got %+v", payload) } - if _, ok := payload["parallel_tool_calls"]; ok { - t.Fatalf("expected parallel_tool_calls removed, got %+v", payload) + if enabled, _ := payload["parallel_tool_calls"].(bool); !enabled { + t.Fatalf("expected parallel_tool_calls preserved, got %+v", payload) } } @@ -234,6 +240,221 @@ func TestHandlerRestreamsManagedOpenAITools(t *testing.T) { } } +func TestHandlerPassesThroughNativeOpenAIToolCallsWhenManagedToolsArePresent(t *testing.T) { + var xaiBodies [][]byte + toolCalls := 0 + toolSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + toolCalls++ + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"balance":5000}`)) + })) + defer toolSrv.Close() + + xaiBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read xai body: %v", err) + } + xaiBodies = append(xaiBodies, body) + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("unmarshal xai request: %v", err) + } + if stream, _ := payload["stream"].(bool); stream { + t.Fatalf("expected managed/native upstream request to force stream=false, got %+v", payload) + } + tools, _ := payload["tools"].([]any) + if len(tools) != 2 { + t.Fatalf("expected merged tool list upstream, got %+v", payload["tools"]) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "id":"chatcmpl-native", + "model":"grok-4.1-fast", + "choices":[{ + "finish_reason":"tool_calls", + "message":{ + "role":"assistant", + "tool_calls":[{"id":"call_native_1","type":"function","function":{"name":"runner_local","arguments":"{\"ticker\":\"NVDA\"}"}}] + } + }], + "usage":{"prompt_tokens":9,"completion_tokens":4,"total_tokens":13} + }`)) + })) + defer xaiBackend.Close() + + reg := provider.NewRegistry("") + reg.Set("xai", &provider.Provider{ + Name: "xai", BaseURL: xaiBackend.URL + "/v1", APIKey: "xai-real", Auth: "bearer", + }) + + h := NewHandler(reg, stubContextLoaderWithTools("tiverton", "tiverton:dummy123", managedToolManifestForURL(toolSrv.URL, http.MethodGet, "/api/v1/market_context/{claw_id}", "")), logging.New(io.Discard)) + body := `{ + "model":"xai/grok-4.1-fast", + "stream":true, + "stream_options":{"include_usage":true}, + "messages":[{"role":"user","content":"hi"}], + "tools":[{"type":"function","function":{"name":"runner_local","description":"native runner tool","parameters":{"type":"object"}}}] + }` + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer tiverton:dummy123") + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + if toolCalls != 0 { + t.Fatalf("expected managed tool executor to stay idle for native tool call, got %d", toolCalls) + } + if len(xaiBodies) != 1 { + t.Fatalf("expected one upstream model round, got %d", len(xaiBodies)) + } + if got := w.Header().Get("Content-Type"); !strings.Contains(got, "text/event-stream") { + t.Fatalf("expected SSE content-type, got %q", got) + } + events := parseSSEEvents(t, w.Body.String()) + if !sseHasToolCall(events, "runner_local") { + t.Fatalf("expected re-streamed native tool call, got %+v", events) + } + if !sseHasUsage(events, 9, 4, 13) { + t.Fatalf("expected usage chunk in re-streamed native tool response, got %+v", events) + } +} + +func TestHandlerRejectsMixedManagedAndNativeOpenAIToolCalls(t *testing.T) { + presentedName := managedToolPresentedNameForCanonical("trading-api.get_market_context") + toolCalls := 0 + toolSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + toolCalls++ + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"balance":5000}`)) + })) + defer toolSrv.Close() + + xaiBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "id":"chatcmpl-mixed", + "choices":[{ + "finish_reason":"tool_calls", + "message":{ + "role":"assistant", + "tool_calls":[ + {"id":"call_managed_1","type":"function","function":{"name":"` + presentedName + `","arguments":"{}"}}, + {"id":"call_native_1","type":"function","function":{"name":"runner_local","arguments":"{}"}} + ] + } + }] + }`)) + })) + defer xaiBackend.Close() + + reg := provider.NewRegistry("") + reg.Set("xai", &provider.Provider{ + Name: "xai", BaseURL: xaiBackend.URL + "/v1", APIKey: "xai-real", Auth: "bearer", + }) + + h := NewHandler(reg, stubContextLoaderWithTools("tiverton", "tiverton:dummy123", managedToolManifestForURL(toolSrv.URL, http.MethodGet, "/api/v1/market_context/{claw_id}", "")), logging.New(io.Discard)) + body := `{ + "model":"xai/grok-4.1-fast", + "messages":[{"role":"user","content":"hi"}], + "tools":[{"type":"function","function":{"name":"runner_local","description":"native runner tool","parameters":{"type":"object"}}}] + }` + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer tiverton:dummy123") + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + if w.Code != http.StatusBadGateway { + t.Fatalf("expected 502, got %d: %s", w.Code, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "mixed managed and runner-native tool calls") { + t.Fatalf("expected mixed tool-call error, got %s", w.Body.String()) + } + if toolCalls != 0 { + t.Fatalf("expected no managed tool execution on mixed tool-call failure, got %d", toolCalls) + } +} + +func TestHandlerRejectsNativeOpenAIToolCallsAfterManagedRounds(t *testing.T) { + xaiCalls := 0 + toolCalls := 0 + toolSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + toolCalls++ + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"balance":5000}`)) + })) + defer toolSrv.Close() + + xaiBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + xaiCalls++ + w.Header().Set("Content-Type", "application/json") + switch xaiCalls { + case 1: + _, _ = w.Write([]byte(`{ + "id":"chatcmpl-managed-first", + "choices":[{ + "finish_reason":"tool_calls", + "message":{ + "role":"assistant", + "tool_calls":[{"id":"call_managed_1","type":"function","function":{"name":"trading-api.get_market_context","arguments":"{}"}}] + } + }] + }`)) + case 2: + _, _ = w.Write([]byte(`{ + "id":"chatcmpl-native-second", + "choices":[{ + "finish_reason":"tool_calls", + "message":{ + "role":"assistant", + "tool_calls":[{"id":"call_native_1","type":"function","function":{"name":"runner_local","arguments":"{}"}}] + } + }] + }`)) + default: + t.Fatalf("unexpected xai round %d", xaiCalls) + } + })) + defer xaiBackend.Close() + + reg := provider.NewRegistry("") + reg.Set("xai", &provider.Provider{ + Name: "xai", BaseURL: xaiBackend.URL + "/v1", APIKey: "xai-real", Auth: "bearer", + }) + + h := NewHandler(reg, stubContextLoaderWithTools("tiverton", "tiverton:dummy123", managedToolManifestForURL(toolSrv.URL, http.MethodGet, "/api/v1/market_context/{claw_id}", "")), logging.New(io.Discard)) + body := `{ + "model":"xai/grok-4.1-fast", + "messages":[{"role":"user","content":"hi"}], + "tools":[{"type":"function","function":{"name":"runner_local","description":"native runner tool","parameters":{"type":"object"}}}] + }` + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer tiverton:dummy123") + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + if w.Code != http.StatusBadGateway { + t.Fatalf("expected 502, got %d: %s", w.Code, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "runner-native tool calls after managed tool rounds") { + t.Fatalf("expected interleaving error, got %s", w.Body.String()) + } + if xaiCalls != 2 { + t.Fatalf("expected two model rounds before failure, got %d", xaiCalls) + } + if toolCalls != 1 { + t.Fatalf("expected one managed tool execution before failure, got %d", toolCalls) + } +} + func TestHandlerStreamsManagedOpenAIKeepaliveComments(t *testing.T) { presentedName := managedToolPresentedNameForCanonical("trading-api.get_market_context") toolRelease := make(chan struct{}) @@ -1599,15 +1820,15 @@ func TestHandlerExecutesManagedAnthropicTools(t *testing.T) { t.Fatalf("unmarshal anthropic request: %v", err) } tools, _ := payload["tools"].([]any) - if len(tools) != 1 { - t.Fatalf("expected 1 managed anthropic tool, got %+v", payload["tools"]) + if len(tools) != 2 { + t.Fatalf("expected runner-native plus managed anthropic tools, got %+v", payload["tools"]) } - first, _ := tools[0].(map[string]any) - if first["name"] != presentedName { - t.Fatalf("expected provider-safe managed anthropic tool name %q, got %+v", presentedName, first) + if !anthropicToolsIncludeName(tools, "runner_local") || !anthropicToolsIncludeName(tools, presentedName) { + t.Fatalf("expected merged anthropic tool list to include runner_local and %q, got %+v", presentedName, payload["tools"]) } - if _, ok := payload["tool_choice"]; ok { - t.Fatalf("expected tool_choice removed from managed anthropic request, got %+v", payload) + toolChoice, _ := payload["tool_choice"].(map[string]any) + if toolChoice == nil || toolChoice["type"] != "tool" || toolChoice["name"] != "runner_local" { + t.Fatalf("expected native anthropic tool_choice preserved, got %+v", payload) } _, _ = w.Write([]byte(`{ "id":"msg_01", @@ -1642,6 +1863,7 @@ func TestHandlerExecutesManagedAnthropicTools(t *testing.T) { body := `{ "model":"claude-sonnet-4-20250514", "messages":[{"role":"user","content":"hi"}], + "tools":[{"name":"runner_local","description":"native runner tool","input_schema":{"type":"object"}}], "tool_choice":{"type":"tool","name":"runner_local"} }` req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewBufferString(body)) @@ -1851,6 +2073,137 @@ func TestHandlerRestreamsManagedAnthropicTools(t *testing.T) { } } +func TestHandlerPassesThroughNativeAnthropicToolUseWhenManagedToolsArePresent(t *testing.T) { + var anthropicBodies [][]byte + toolCalls := 0 + toolSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + toolCalls++ + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"balance":5000}`)) + })) + defer toolSrv.Close() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read anthropic body: %v", err) + } + anthropicBodies = append(anthropicBodies, body) + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("unmarshal anthropic request: %v", err) + } + if stream, _ := payload["stream"].(bool); stream { + t.Fatalf("expected managed/native anthropic upstream request to force stream=false, got %+v", payload) + } + tools, _ := payload["tools"].([]any) + if len(tools) != 2 { + t.Fatalf("expected merged anthropic tool list upstream, got %+v", payload["tools"]) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "id":"msg_native", + "type":"message", + "model":"claude-sonnet-4-20250514", + "content":[{"type":"tool_use","id":"toolu_native_1","name":"runner_local","input":{"ticker":"NVDA"}}], + "stop_reason":"tool_use", + "usage":{"input_tokens":8,"output_tokens":3} + }`)) + })) + defer backend.Close() + + reg := provider.NewRegistry("") + reg.Set("anthropic", &provider.Provider{ + Name: "anthropic", BaseURL: backend.URL + "/v1", APIKey: "sk-ant-real", Auth: "x-api-key", APIFormat: "anthropic", + }) + + h := NewHandler(reg, stubContextLoaderWithTools("nano-bot", "nano-bot:dummy456", managedToolManifestForURL(toolSrv.URL, http.MethodGet, "/api/v1/market_context/{claw_id}", "")), logging.New(io.Discard)) + body := `{ + "model":"claude-sonnet-4-20250514", + "stream":true, + "messages":[{"role":"user","content":"hi"}], + "tools":[{"name":"runner_local","description":"native runner tool","input_schema":{"type":"object"}}] + }` + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer nano-bot:dummy456") + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Anthropic-Version", "2023-06-01") + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + if toolCalls != 0 { + t.Fatalf("expected managed tool executor to stay idle for native anthropic tool_use, got %d", toolCalls) + } + if len(anthropicBodies) != 1 { + t.Fatalf("expected one upstream anthropic round, got %d", len(anthropicBodies)) + } + if got := w.Header().Get("Content-Type"); !strings.Contains(got, "text/event-stream") { + t.Fatalf("expected SSE content-type, got %q", got) + } + if !anthropicSSEHasToolUse(parseSSEEvents(t, w.Body.String()), "runner_local") { + t.Fatalf("expected re-streamed native anthropic tool_use, got %s", w.Body.String()) + } +} + +func TestHandlerRejectsMixedManagedAndNativeAnthropicToolUse(t *testing.T) { + presentedName := managedToolPresentedNameForCanonical("trading-api.get_market_context") + toolCalls := 0 + toolSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + toolCalls++ + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"balance":5000}`)) + })) + defer toolSrv.Close() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{ + "id":"msg_mixed", + "type":"message", + "content":[ + {"type":"tool_use","id":"toolu_managed_1","name":"` + presentedName + `","input":{}}, + {"type":"tool_use","id":"toolu_native_1","name":"runner_local","input":{}} + ], + "stop_reason":"tool_use", + "usage":{"input_tokens":8,"output_tokens":3} + }`)) + })) + defer backend.Close() + + reg := provider.NewRegistry("") + reg.Set("anthropic", &provider.Provider{ + Name: "anthropic", BaseURL: backend.URL + "/v1", APIKey: "sk-ant-real", Auth: "x-api-key", APIFormat: "anthropic", + }) + + h := NewHandler(reg, stubContextLoaderWithTools("nano-bot", "nano-bot:dummy456", managedToolManifestForURL(toolSrv.URL, http.MethodGet, "/api/v1/market_context/{claw_id}", "")), logging.New(io.Discard)) + body := `{ + "model":"claude-sonnet-4-20250514", + "messages":[{"role":"user","content":"hi"}], + "tools":[{"name":"runner_local","description":"native runner tool","input_schema":{"type":"object"}}] + }` + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer nano-bot:dummy456") + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Anthropic-Version", "2023-06-01") + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + if w.Code != http.StatusBadGateway { + t.Fatalf("expected 502, got %d: %s", w.Code, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "mixed managed and runner-native tool calls") { + t.Fatalf("expected mixed anthropic tool-use error, got %s", w.Body.String()) + } + if toolCalls != 0 { + t.Fatalf("expected no managed tool execution on mixed anthropic failure, got %d", toolCalls) + } +} + func TestHandlerStreamsManagedAnthropicKeepaliveComments(t *testing.T) { toolRelease := make(chan struct{}) toolSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -3508,6 +3861,47 @@ func sseHasContent(events []map[string]any, want string) bool { return false } +func sseHasToolCall(events []map[string]any, want string) bool { + for _, event := range events { + choices, _ := event["choices"].([]any) + for _, rawChoice := range choices { + choice, _ := rawChoice.(map[string]any) + if choice == nil { + continue + } + delta, _ := choice["delta"].(map[string]any) + if delta == nil { + continue + } + toolCalls, _ := delta["tool_calls"].([]any) + for _, rawTool := range toolCalls { + tool, _ := rawTool.(map[string]any) + function, _ := tool["function"].(map[string]any) + if function["name"] == want { + return true + } + } + } + } + return false +} + +func anthropicSSEHasToolUse(events []map[string]any, want string) bool { + for _, event := range events { + if event["type"] != "content_block_start" { + continue + } + block, _ := event["content_block"].(map[string]any) + if block == nil { + continue + } + if block["type"] == "tool_use" && block["name"] == want { + return true + } + } + return false +} + func sseHasUsage(events []map[string]any, prompt, completion, total int) bool { for _, event := range events { usage, _ := event["usage"].(map[string]any) @@ -3523,6 +3917,19 @@ func sseHasUsage(events []map[string]any, prompt, completion, total int) bool { return false } +func anthropicToolsIncludeName(tools []any, want string) bool { + for _, raw := range tools { + tool, _ := raw.(map[string]any) + if tool == nil { + continue + } + if tool["name"] == want { + return true + } + } + return false +} + func assertInterventionLogged(t *testing.T, raw []byte, reason string) { t.Helper() entries := parseLogEntries(t, raw) diff --git a/internal/proxy/toolmediation.go b/internal/proxy/toolmediation.go index ab1407d..665d480 100644 --- a/internal/proxy/toolmediation.go +++ b/internal/proxy/toolmediation.go @@ -219,6 +219,60 @@ func (h *Handler) handleManagedOpenAI(w http.ResponseWriter, r *http.Request, ag return } + managedCalls, nativeCalls := partitionManagedOpenAIToolCalls(agentCtx, toolCalls) + if len(managedCalls) == 0 { + if len(toolTrace) > 0 || len(hiddenMessages) > 0 { + msg := "runner-native tool calls after managed tool rounds are not supported in one request" + h.recordManagedFailure(agentID, resp.ProviderName, requestedModel, resp.UpstreamModel, r.URL.Path, requestOriginal, requestEffective, http.StatusBadGateway, jsonErrorPayload(msg), usageAgg, toolTrace) + if streamKeepalive != nil && downstreamStream { + streamKeepalive.writeOpenAIError(jsonErrorPayload(msg)) + h.logger.LogError(agentID, requestedModel, http.StatusBadGateway, time.Since(start).Milliseconds(), fmt.Errorf(msg)) + return + } + h.fail(w, http.StatusBadGateway, msg, agentID, requestedModel, start, fmt.Errorf(msg)) + return + } + + responseBytes := resp.Body + responseHeader := resp.Header.Clone() + if downstreamStream { + sse, synthErr := synthesizeOpenAIToolCallStream(resp.Body, resp.UpstreamModel, usageAgg, downstreamIncludeUsage) + if synthErr != nil { + h.recordManagedFailure(agentID, resp.ProviderName, requestedModel, resp.UpstreamModel, r.URL.Path, requestOriginal, requestEffective, http.StatusBadGateway, jsonErrorPayload("failed to synthesize managed stream"), usageAgg, toolTrace) + if streamKeepalive != nil { + streamKeepalive.writeOpenAIError(jsonErrorPayload("failed to synthesize managed stream")) + h.logger.LogError(agentID, requestedModel, http.StatusBadGateway, time.Since(start).Milliseconds(), synthErr) + return + } + h.fail(w, http.StatusBadGateway, "failed to synthesize managed stream", agentID, requestedModel, start, synthErr) + return + } + streamKeepalive.writeFinal(sse) + responseBytes = sse + responseHeader = syntheticSSEHeader() + } else { + copyResponseHeaders(w.Header(), resp.Header) + w.WriteHeader(resp.StatusCode) + if len(resp.Body) > 0 { + _, _ = w.Write(resp.Body) + } + } + h.recordResponse(agentID, agentCtx, resp.ProviderName, requestedModel, resp.UpstreamModel, r.URL.Path, requestOriginal, requestEffective, resp.StatusCode, responseHeader, responseBytes, start) + return + } + + if len(nativeCalls) > 0 { + msg := "mixed managed and runner-native tool calls are not supported in one model response" + h.recordManagedFailure(agentID, resp.ProviderName, requestedModel, resp.UpstreamModel, r.URL.Path, requestOriginal, requestEffective, http.StatusBadGateway, jsonErrorPayload(msg), usageAgg, toolTrace) + if streamKeepalive != nil && downstreamStream { + streamKeepalive.writeOpenAIError(jsonErrorPayload(msg)) + h.logger.LogError(agentID, requestedModel, http.StatusBadGateway, time.Since(start).Milliseconds(), fmt.Errorf(msg)) + return + } + h.fail(w, http.StatusBadGateway, msg, agentID, requestedModel, start, fmt.Errorf(msg)) + return + } + if len(toolTrace) >= policy.MaxRounds { msg := fmt.Sprintf("managed tool max rounds exceeded (%d)", policy.MaxRounds) h.recordManagedFailure(agentID, resp.ProviderName, requestedModel, resp.UpstreamModel, r.URL.Path, requestOriginal, requestEffective, http.StatusBadGateway, jsonErrorPayload(msg), usageAgg, toolTrace) @@ -235,7 +289,7 @@ func (h *Handler) handleManagedOpenAI(w http.ResponseWriter, r *http.Request, ag ReportedCostUSD: usage.ReportedCostUSD, }, } - for _, call := range toolCalls { + for _, call := range managedCalls { execResult := waitWithManagedKeepalive(streamKeepalive, managedToolWaitComment(len(toolTrace)+1, managedToolDisplayName(agentCtx, call.Name)), func() managedOpenAIToolExecResult { outcome, execErr := h.executeManagedOpenAITool(loopCtx, agentID, agentCtx, call, policy) return managedOpenAIToolExecResult{Outcome: outcome, Err: execErr} @@ -361,6 +415,60 @@ func (h *Handler) handleManagedAnthropic(w http.ResponseWriter, r *http.Request, return } + managedToolUses, nativeToolUses := partitionManagedAnthropicToolUses(agentCtx, toolUses) + if len(managedToolUses) == 0 { + if len(toolTrace) > 0 || len(hiddenMessages) > 0 { + msg := "runner-native tool calls after managed tool rounds are not supported in one request" + h.recordManagedFailure(agentID, resp.ProviderName, requestedModel, resp.UpstreamModel, r.URL.Path, requestOriginal, requestEffective, http.StatusBadGateway, jsonErrorPayload(msg), usageAgg, toolTrace) + if streamKeepalive != nil && downstreamStream { + streamKeepalive.writeAnthropicError(msg) + h.logger.LogError(agentID, requestedModel, http.StatusBadGateway, time.Since(start).Milliseconds(), fmt.Errorf(msg)) + return + } + h.fail(w, http.StatusBadGateway, msg, agentID, requestedModel, start, fmt.Errorf(msg)) + return + } + + responseBytes := resp.Body + responseHeader := resp.Header.Clone() + if downstreamStream { + sse, synthErr := synthesizeAnthropicToolUseStream(resp.Body, resp.UpstreamModel, usageAgg) + if synthErr != nil { + h.recordManagedFailure(agentID, resp.ProviderName, requestedModel, resp.UpstreamModel, r.URL.Path, requestOriginal, requestEffective, http.StatusBadGateway, jsonErrorPayload("failed to synthesize managed stream"), usageAgg, toolTrace) + if streamKeepalive != nil { + streamKeepalive.writeAnthropicError("failed to synthesize managed stream") + h.logger.LogError(agentID, requestedModel, http.StatusBadGateway, time.Since(start).Milliseconds(), synthErr) + return + } + h.fail(w, http.StatusBadGateway, "failed to synthesize managed stream", agentID, requestedModel, start, synthErr) + return + } + streamKeepalive.writeFinal(sse) + responseBytes = sse + responseHeader = syntheticSSEHeader() + } else { + copyResponseHeaders(w.Header(), resp.Header) + w.WriteHeader(resp.StatusCode) + if len(resp.Body) > 0 { + _, _ = w.Write(resp.Body) + } + } + h.recordResponse(agentID, agentCtx, resp.ProviderName, requestedModel, resp.UpstreamModel, r.URL.Path, requestOriginal, requestEffective, resp.StatusCode, responseHeader, responseBytes, start) + return + } + + if len(nativeToolUses) > 0 { + msg := "mixed managed and runner-native tool calls are not supported in one model response" + h.recordManagedFailure(agentID, resp.ProviderName, requestedModel, resp.UpstreamModel, r.URL.Path, requestOriginal, requestEffective, http.StatusBadGateway, jsonErrorPayload(msg), usageAgg, toolTrace) + if streamKeepalive != nil && downstreamStream { + streamKeepalive.writeAnthropicError(msg) + h.logger.LogError(agentID, requestedModel, http.StatusBadGateway, time.Since(start).Milliseconds(), fmt.Errorf(msg)) + return + } + h.fail(w, http.StatusBadGateway, msg, agentID, requestedModel, start, fmt.Errorf(msg)) + return + } + if len(toolTrace) >= policy.MaxRounds { msg := fmt.Sprintf("managed tool max rounds exceeded (%d)", policy.MaxRounds) h.recordManagedFailure(agentID, resp.ProviderName, requestedModel, resp.UpstreamModel, r.URL.Path, requestOriginal, requestEffective, http.StatusBadGateway, jsonErrorPayload(msg), usageAgg, toolTrace) @@ -377,7 +485,7 @@ func (h *Handler) handleManagedAnthropic(w http.ResponseWriter, r *http.Request, ReportedCostUSD: usage.ReportedCostUSD, }, } - for _, call := range toolUses { + for _, call := range managedToolUses { execResult := waitWithManagedKeepalive(streamKeepalive, managedToolWaitComment(len(toolTrace)+1, managedToolDisplayName(agentCtx, call.Name)), func() managedAnthropicToolExecResult { outcome, execErr := h.executeManagedAnthropicTool(loopCtx, agentID, agentCtx, call, policy) return managedAnthropicToolExecResult{Outcome: outcome, Err: execErr} @@ -927,6 +1035,38 @@ func parseAnthropicToolResponse(body []byte) (map[string]any, []anthropicToolUse } } +func partitionManagedOpenAIToolCalls(agentCtx *agentctx.AgentContext, calls []openAIToolCall) ([]openAIToolCall, []openAIToolCall) { + if len(calls) == 0 { + return nil, nil + } + managed := make([]openAIToolCall, 0, len(calls)) + native := make([]openAIToolCall, 0, len(calls)) + for _, call := range calls { + if _, ok := resolveManagedTool(agentCtx, call.Name); ok { + managed = append(managed, call) + continue + } + native = append(native, call) + } + return managed, native +} + +func partitionManagedAnthropicToolUses(agentCtx *agentctx.AgentContext, calls []anthropicToolUse) ([]anthropicToolUse, []anthropicToolUse) { + if len(calls) == 0 { + return nil, nil + } + managed := make([]anthropicToolUse, 0, len(calls)) + native := make([]anthropicToolUse, 0, len(calls)) + for _, call := range calls { + if _, ok := resolveManagedTool(agentCtx, call.Name); ok { + managed = append(managed, call) + continue + } + native = append(native, call) + } + return managed, native +} + func appendOpenAIAssistantAndToolMessages(payload map[string]any, assistantMessage map[string]any, toolMessages []any) { messages, _ := payload["messages"].([]any) messages = append(messages, assistantMessage) @@ -1027,6 +1167,113 @@ func synthesizeOpenAIStream(finalBody []byte, upstreamModel string, usage manage return stream.Bytes(), nil } +func synthesizeOpenAIToolCallStream(finalBody []byte, upstreamModel string, usage managedUsageAggregate, includeUsage bool) ([]byte, error) { + var payload map[string]any + if err := json.Unmarshal(finalBody, &payload); err != nil { + return nil, err + } + id, _ := payload["id"].(string) + if strings.TrimSpace(id) == "" { + id = "chatcmpl-managed" + } + model, _ := payload["model"].(string) + if strings.TrimSpace(model) == "" { + model = upstreamModel + } + created := time.Now().Unix() + if rawCreated, ok := payload["created"].(float64); ok { + created = int64(rawCreated) + } + + assistantMessage, toolCalls, err := parseOpenAIToolResponse(finalBody) + if err != nil { + return nil, err + } + if len(toolCalls) == 0 { + return nil, fmt.Errorf("cannot synthesize streamed tool_call payload without tool calls") + } + + var stream bytes.Buffer + writeSSEChunk(&stream, map[string]any{ + "id": id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": []map[string]any{{ + "index": 0, + "delta": map[string]any{"role": "assistant"}, + "finish_reason": nil, + }}, + }) + if content := openAIMessageContent(assistantMessage); content != "" { + writeSSEChunk(&stream, map[string]any{ + "id": id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": []map[string]any{{ + "index": 0, + "delta": map[string]any{"content": content}, + "finish_reason": nil, + }}, + }) + } + for i, call := range toolCalls { + args := strings.TrimSpace(string(call.ArgumentsRaw)) + if args == "" { + args = "{}" + } + writeSSEChunk(&stream, map[string]any{ + "id": id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": []map[string]any{{ + "index": 0, + "delta": map[string]any{ + "tool_calls": []map[string]any{{ + "index": i, + "id": call.ID, + "type": "function", + "function": map[string]any{ + "name": call.Name, + "arguments": args, + }, + }}, + }, + "finish_reason": nil, + }}, + }) + } + writeSSEChunk(&stream, map[string]any{ + "id": id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": []map[string]any{{ + "index": 0, + "delta": map[string]any{}, + "finish_reason": "tool_calls", + }}, + }) + if includeUsage { + writeSSEChunk(&stream, map[string]any{ + "id": id, + "object": "chat.completion.chunk", + "created": created, + "model": model, + "choices": []any{}, + "usage": map[string]any{ + "prompt_tokens": usage.PromptTokens, + "completion_tokens": usage.CompletionTokens, + "total_tokens": usage.TotalTokens, + }, + }) + } + stream.WriteString("data: [DONE]\n\n") + return stream.Bytes(), nil +} + func synthesizeAnthropicStream(finalBody []byte, upstreamModel string, usage managedUsageAggregate) ([]byte, error) { var payload map[string]any if err := json.Unmarshal(finalBody, &payload); err != nil { @@ -1110,6 +1357,132 @@ func synthesizeAnthropicStream(finalBody []byte, upstreamModel string, usage man return stream.Bytes(), nil } +func synthesizeAnthropicToolUseStream(finalBody []byte, upstreamModel string, usage managedUsageAggregate) ([]byte, error) { + var payload map[string]any + if err := json.Unmarshal(finalBody, &payload); err != nil { + return nil, err + } + assistantMessage, toolUses, err := parseAnthropicToolResponse(finalBody) + if err != nil { + return nil, err + } + if len(toolUses) == 0 { + return nil, fmt.Errorf("cannot synthesize streamed tool_use payload without tool calls") + } + + id, _ := payload["id"].(string) + if strings.TrimSpace(id) == "" { + id = "msg_managed" + } + model, _ := payload["model"].(string) + if strings.TrimSpace(model) == "" { + model = upstreamModel + } + stopReason, _ := payload["stop_reason"].(string) + if strings.TrimSpace(stopReason) == "" { + stopReason = "tool_use" + } + + var stream bytes.Buffer + writeAnthropicSSEEvent(&stream, "message_start", map[string]any{ + "type": "message_start", + "message": map[string]any{ + "id": id, + "type": "message", + "role": "assistant", + "model": model, + "content": []any{}, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]any{ + "input_tokens": usage.PromptTokens, + }, + }, + }) + + content, _ := assistantMessage["content"].([]any) + for idx, raw := range content { + block, _ := raw.(map[string]any) + if block == nil { + continue + } + blockType, _ := block["type"].(string) + switch blockType { + case "tool_use": + blockID, _ := block["id"].(string) + name, _ := block["name"].(string) + input := "{}" + if block["input"] != nil { + encoded, err := json.Marshal(block["input"]) + if err != nil { + return nil, err + } + input = string(encoded) + } + writeAnthropicSSEEvent(&stream, "content_block_start", map[string]any{ + "type": "content_block_start", + "index": idx, + "content_block": map[string]any{ + "type": "tool_use", + "id": blockID, + "name": name, + }, + }) + writeAnthropicSSEEvent(&stream, "content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": idx, + "delta": map[string]any{ + "type": "input_json_delta", + "partial_json": input, + }, + }) + writeAnthropicSSEEvent(&stream, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": idx, + }) + case "text", "": + text, _ := block["text"].(string) + writeAnthropicSSEEvent(&stream, "content_block_start", map[string]any{ + "type": "content_block_start", + "index": idx, + "content_block": map[string]any{ + "type": "text", + "text": "", + }, + }) + if text != "" { + writeAnthropicSSEEvent(&stream, "content_block_delta", map[string]any{ + "type": "content_block_delta", + "index": idx, + "delta": map[string]any{ + "type": "text_delta", + "text": text, + }, + }) + } + writeAnthropicSSEEvent(&stream, "content_block_stop", map[string]any{ + "type": "content_block_stop", + "index": idx, + }) + } + } + + writeAnthropicSSEEvent(&stream, "message_delta", map[string]any{ + "type": "message_delta", + "delta": map[string]any{ + "stop_reason": stopReason, + "stop_sequence": payload["stop_sequence"], + }, + "usage": map[string]any{ + "output_tokens": usage.CompletionTokens, + }, + }) + writeAnthropicSSEEvent(&stream, "message_stop", map[string]any{ + "type": "message_stop", + }) + return stream.Bytes(), nil +} + func openAIMessageContent(message map[string]any) string { if message == nil { return "" @@ -1180,6 +1553,14 @@ func writeSyntheticSSE(w http.ResponseWriter, stream []byte) { _, _ = w.Write(stream) } +func syntheticSSEHeader() http.Header { + header := http.Header{} + header.Set("Content-Type", "text/event-stream") + header.Set("Cache-Control", "no-cache") + header.Set("Connection", "keep-alive") + return header +} + func newManagedStreamKeepalive(w http.ResponseWriter, enabled bool) *managedStreamKeepalive { if !enabled { return nil