diff --git a/internal/proxy/handler_test.go b/internal/proxy/handler_test.go index 04d03c0..c85301e 100644 --- a/internal/proxy/handler_test.go +++ b/internal/proxy/handler_test.go @@ -343,8 +343,8 @@ func TestHandlerRejectsMixedManagedAndNativeOpenAIToolCalls(t *testing.T) { "message":{ "role":"assistant", "tool_calls":[ - {"id":"call_managed_1","type":"function","function":{"name":"` + presentedName + `","arguments":"{}"}}, - {"id":"call_native_1","type":"function","function":{"name":"runner_local","arguments":"{}"}} + {"id":"call_native_1","type":"function","function":{"name":"runner_local","arguments":"{}"}}, + {"id":"call_managed_1","type":"function","function":{"name":"` + presentedName + `","arguments":"{}"}} ] } }] @@ -373,7 +373,7 @@ func TestHandlerRejectsMixedManagedAndNativeOpenAIToolCalls(t *testing.T) { if w.Code != http.StatusBadGateway { t.Fatalf("expected 502, got %d: %s", w.Code, w.Body.String()) } - if !strings.Contains(w.Body.String(), "mixed managed and runner-native tool calls") { + if !strings.Contains(w.Body.String(), "managed service tools come first") { t.Fatalf("expected mixed tool-call error, got %s", w.Body.String()) } if toolCalls != 0 { @@ -381,6 +381,196 @@ func TestHandlerRejectsMixedManagedAndNativeOpenAIToolCalls(t *testing.T) { } } +func TestHandlerSerializesManagedPrefixBeforeNativeOpenAIToolCalls(t *testing.T) { + presentedName := managedToolPresentedNameForCanonical("trading-api.get_market_context") + xaiCalls := 0 + toolCalls := 0 + var logs bytes.Buffer + toolSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + toolCalls++ + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"balance":5000}`)) + })) + defer toolSrv.Close() + + xaiBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + xaiCalls++ + w.Header().Set("Content-Type", "application/json") + switch xaiCalls { + case 1: + _, _ = w.Write([]byte(`{ + 
"id":"chatcmpl-mixed-first", + "choices":[{ + "finish_reason":"tool_calls", + "message":{ + "role":"assistant", + "content":"Checking managed context first.", + "tool_calls":[ + {"id":"call_managed_1","type":"function","function":{"name":"` + presentedName + `","arguments":"{}"}}, + {"id":"call_native_1","type":"function","function":{"name":"runner_local","arguments":"{}"}} + ] + } + }] + }`)) + case 2: + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read xai body: %v", err) + } + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("unmarshal xai request: %v", err) + } + rawMessages, _ := payload["messages"].([]any) + conversation := make([]map[string]any, 0, len(rawMessages)) + for _, raw := range rawMessages { + msg, _ := raw.(map[string]any) + if msg == nil { + continue + } + if role, _ := msg["role"].(string); role == "system" { + continue + } + conversation = append(conversation, msg) + } + if len(conversation) != 3 { + t.Fatalf("expected only managed prefix to be serialized before rerun, got %+v", payload) + } + if conversation[0]["role"] != "user" || conversation[0]["content"] != "hi" { + t.Fatalf("unexpected first conversation message: %+v", conversation[0]) + } + managedAssistant := conversation[1] + managedToolCalls, _ := managedAssistant["tool_calls"].([]any) + if managedAssistant["role"] != "assistant" || managedAssistant["content"] != "Checking managed context first." 
|| len(managedToolCalls) != 1 { + t.Fatalf("expected managed-only assistant message, got %+v", managedAssistant) + } + if managedAssistantCall := managedToolCalls[0].(map[string]any); managedAssistantCall["id"] != "call_managed_1" { + t.Fatalf("expected managed tool call id preserved, got %+v", managedAssistantCall) + } + managedTool := conversation[2] + if managedTool["role"] != "tool" || managedTool["tool_call_id"] != "call_managed_1" { + t.Fatalf("expected managed tool result after serialization, got %+v", managedTool) + } + _, _ = w.Write([]byte(`{ + "id":"chatcmpl-native-second", + "choices":[{ + "finish_reason":"tool_calls", + "message":{ + "role":"assistant", + "tool_calls":[{"id":"call_native_1","type":"function","function":{"name":"runner_local","arguments":"{}"}}] + } + }] + }`)) + case 3: + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read xai body: %v", err) + } + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("unmarshal xai request: %v", err) + } + rawMessages, _ := payload["messages"].([]any) + conversation := make([]map[string]any, 0, len(rawMessages)) + for _, raw := range rawMessages { + msg, _ := raw.(map[string]any) + if msg == nil { + continue + } + if role, _ := msg["role"].(string); role == "system" { + continue + } + conversation = append(conversation, msg) + } + if len(conversation) < 5 { + t.Fatalf("expected managed handoff continuity injection, got %+v", payload) + } + hiddenAssistant := conversation[1] + hiddenToolCalls, _ := hiddenAssistant["tool_calls"].([]any) + if hiddenAssistant["role"] != "assistant" || len(hiddenToolCalls) != 1 { + t.Fatalf("expected hidden managed assistant before native tool call, got %+v", hiddenAssistant) + } + hiddenTool := conversation[2] + if hiddenTool["role"] != "tool" || hiddenTool["tool_call_id"] != "call_managed_1" { + t.Fatalf("expected hidden managed tool result, got %+v", hiddenTool) + } + nativeAssistant := conversation[3] + 
nativeToolCalls, _ := nativeAssistant["tool_calls"].([]any) + if nativeAssistant["role"] != "assistant" || len(nativeToolCalls) != 1 { + t.Fatalf("expected native tool call after hidden managed rounds, got %+v", nativeAssistant) + } + nativeToolResult := conversation[4] + if nativeToolResult["role"] != "tool" || nativeToolResult["tool_call_id"] != "call_native_1" { + t.Fatalf("expected native tool result after injected hidden rounds, got %+v", nativeToolResult) + } + _, _ = w.Write([]byte(`{ + "id":"chatcmpl-final", + "choices":[{"message":{"role":"assistant","content":"native handoff complete"}}] + }`)) + default: + t.Fatalf("unexpected xai round %d", xaiCalls) + } + })) + defer xaiBackend.Close() + + reg := provider.NewRegistry("") + reg.Set("xai", &provider.Provider{ + Name: "xai", BaseURL: xaiBackend.URL + "/v1", APIKey: "xai-real", Auth: "bearer", + }) + + h := NewHandler(reg, stubContextLoaderWithTools("tiverton", "tiverton:dummy123", managedToolManifestForURL(toolSrv.URL, http.MethodGet, "/api/v1/market_context/{claw_id}", "")), logging.New(&logs)) + body := `{ + "model":"xai/grok-4.1-fast", + "messages":[{"role":"user","content":"hi"}], + "tools":[{"type":"function","function":{"name":"runner_local","description":"native runner tool","parameters":{"type":"object"}}}] + }` + req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewBufferString(body)) + req.Header.Set("Authorization", "Bearer tiverton:dummy123") + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "\"runner_local\"") { + t.Fatalf("expected native tool call to be returned to runner, got %s", w.Body.String()) + } + if xaiCalls != 2 { + t.Fatalf("expected two model rounds before handoff, got %d", xaiCalls) + } + if toolCalls != 1 { + t.Fatalf("expected one managed tool execution 
before handoff, got %d", toolCalls) + } + + followupReq := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", bytes.NewBufferString(`{ + "model":"xai/grok-4.1-fast", + "messages":[ + {"role":"user","content":"hi"}, + {"role":"assistant","tool_calls":[{"id":"call_native_1","type":"function","function":{"name":"runner_local","arguments":"{}"}}]}, + {"role":"tool","tool_call_id":"call_native_1","content":"{\"native\":true}"} + ], + "tools":[{"type":"function","function":{"name":"runner_local","description":"native runner tool","parameters":{"type":"object"}}}] + }`)) + followupReq.Header.Set("Authorization", "Bearer tiverton:dummy123") + followupReq.Header.Set("Content-Type", "application/json") + followupW := httptest.NewRecorder() + + h.ServeHTTP(followupW, followupReq) + + if followupW.Code != http.StatusOK { + t.Fatalf("expected follow-up request 200, got %d: %s", followupW.Code, followupW.Body.String()) + } + if !strings.Contains(followupW.Body.String(), "native handoff complete") { + t.Fatalf("expected final text after native handoff, got %s", followupW.Body.String()) + } + if xaiCalls != 3 { + t.Fatalf("expected third model round after runner tool result, got %d", xaiCalls) + } + assertInterventionLogged(t, logs.Bytes(), managedMixedPrefixSerializedIntervention) +} + func TestHandlerHandsOffNativeOpenAIToolCallsAfterManagedRounds(t *testing.T) { xaiCalls := 0 toolCalls := 0 @@ -2248,8 +2438,8 @@ func TestHandlerRejectsMixedManagedAndNativeAnthropicToolUse(t *testing.T) { "id":"msg_mixed", "type":"message", "content":[ - {"type":"tool_use","id":"toolu_managed_1","name":"` + presentedName + `","input":{}}, - {"type":"tool_use","id":"toolu_native_1","name":"runner_local","input":{}} + {"type":"tool_use","id":"toolu_native_1","name":"runner_local","input":{}}, + {"type":"tool_use","id":"toolu_managed_1","name":"` + presentedName + `","input":{}} ], "stop_reason":"tool_use", "usage":{"input_tokens":8,"output_tokens":3} @@ -2279,7 +2469,7 @@ func 
TestHandlerRejectsMixedManagedAndNativeAnthropicToolUse(t *testing.T) { if w.Code != http.StatusBadGateway { t.Fatalf("expected 502, got %d: %s", w.Code, w.Body.String()) } - if !strings.Contains(w.Body.String(), "mixed managed and runner-native tool calls") { + if !strings.Contains(w.Body.String(), "managed service tools come first") { t.Fatalf("expected mixed anthropic tool-use error, got %s", w.Body.String()) } if toolCalls != 0 { @@ -2287,6 +2477,214 @@ func TestHandlerRejectsMixedManagedAndNativeAnthropicToolUse(t *testing.T) { } } +func TestHandlerSerializesManagedPrefixBeforeNativeAnthropicToolUse(t *testing.T) { + presentedName := managedToolPresentedNameForCanonical("trading-api.get_market_context") + anthropicCalls := 0 + toolCalls := 0 + var logs bytes.Buffer + toolSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + toolCalls++ + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"balance":5000}`)) + })) + defer toolSrv.Close() + + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + anthropicCalls++ + w.Header().Set("Content-Type", "application/json") + switch anthropicCalls { + case 1: + _, _ = w.Write([]byte(`{ + "id":"msg_mixed_first", + "type":"message", + "content":[ + {"type":"text","text":"Checking managed context first."}, + {"type":"tool_use","id":"toolu_managed_1","name":"` + presentedName + `","input":{}}, + {"type":"tool_use","id":"toolu_native_1","name":"runner_local","input":{}} + ], + "stop_reason":"tool_use", + "usage":{"input_tokens":10,"output_tokens":3} + }`)) + case 2: + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read anthropic body: %v", err) + } + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("unmarshal anthropic request: %v", err) + } + rawMessages, _ := payload["messages"].([]any) + if len(rawMessages) != 3 { + t.Fatalf("expected only managed prefix to be 
serialized before rerun, got %+v", payload) + } + first := rawMessages[0].(map[string]any) + if first["role"] != "user" || first["content"] != "hi" { + t.Fatalf("unexpected first conversation message: %+v", first) + } + managedAssistant := rawMessages[1].(map[string]any) + managedBlocks, _ := managedAssistant["content"].([]any) + if managedAssistant["role"] != "assistant" || len(managedBlocks) != 2 { + t.Fatalf("expected managed-only anthropic assistant, got %+v", managedAssistant) + } + if managedBlocks[0].(map[string]any)["type"] != "text" || managedBlocks[1].(map[string]any)["id"] != "toolu_managed_1" { + t.Fatalf("expected managed prefix preserved, got %+v", managedBlocks) + } + managedResult := rawMessages[2].(map[string]any) + managedResultBlocks, _ := managedResult["content"].([]any) + if managedResult["role"] != "user" || len(managedResultBlocks) != 1 || managedResultBlocks[0].(map[string]any)["tool_use_id"] != "toolu_managed_1" { + t.Fatalf("expected managed tool_result after serialization, got %+v", managedResult) + } + _, _ = w.Write([]byte(`{ + "id":"msg_native_second", + "type":"message", + "content":[{"type":"tool_use","id":"toolu_native_1","name":"runner_local","input":{}}], + "stop_reason":"tool_use", + "usage":{"input_tokens":7,"output_tokens":4} + }`)) + case 3: + body, err := io.ReadAll(r.Body) + if err != nil { + t.Fatalf("read anthropic body: %v", err) + } + var payload map[string]any + if err := json.Unmarshal(body, &payload); err != nil { + t.Fatalf("unmarshal anthropic request: %v", err) + } + rawMessages, _ := payload["messages"].([]any) + if len(rawMessages) < 5 { + t.Fatalf("expected managed handoff continuity injection, got %+v", payload) + } + hiddenAssistant := rawMessages[1].(map[string]any) + hiddenAssistantBlocks, _ := hiddenAssistant["content"].([]any) + if hiddenAssistant["role"] != "assistant" || len(hiddenAssistantBlocks) != 2 { + t.Fatalf("expected hidden managed assistant before native tool use, got %+v", hiddenAssistant) + } 
+ hiddenUser := rawMessages[2].(map[string]any) + hiddenUserBlocks, _ := hiddenUser["content"].([]any) + if hiddenUser["role"] != "user" || len(hiddenUserBlocks) != 1 || hiddenUserBlocks[0].(map[string]any)["tool_use_id"] != "toolu_managed_1" { + t.Fatalf("expected hidden managed tool_result after tool_use, got %+v", hiddenUser) + } + nativeAssistant := rawMessages[3].(map[string]any) + nativeAssistantBlocks, _ := nativeAssistant["content"].([]any) + if nativeAssistant["role"] != "assistant" || len(nativeAssistantBlocks) != 1 { + t.Fatalf("expected native tool_use after injected hidden rounds, got %+v", nativeAssistant) + } + nativeUser := rawMessages[4].(map[string]any) + nativeUserBlocks, _ := nativeUser["content"].([]any) + if nativeUser["role"] != "user" || len(nativeUserBlocks) != 1 || nativeUserBlocks[0].(map[string]any)["tool_use_id"] != "toolu_native_1" { + t.Fatalf("expected native tool_result after injected hidden rounds, got %+v", nativeUser) + } + _, _ = w.Write([]byte(`{ + "id":"msg_final", + "type":"message", + "content":[{"type":"text","text":"native handoff complete"}], + "stop_reason":"end_turn", + "usage":{"input_tokens":6,"output_tokens":5} + }`)) + default: + t.Fatalf("unexpected anthropic round %d", anthropicCalls) + } + })) + defer backend.Close() + + reg := provider.NewRegistry("") + reg.Set("anthropic", &provider.Provider{ + Name: "anthropic", BaseURL: backend.URL + "/v1", APIKey: "sk-ant-real", Auth: "x-api-key", APIFormat: "anthropic", + }) + + h := NewHandler(reg, stubContextLoaderWithTools("nano-bot", "nano-bot:dummy456", managedToolManifestForURL(toolSrv.URL, http.MethodGet, "/api/v1/market_context/{claw_id}", "")), logging.New(&logs)) + body := `{ + "model":"claude-sonnet-4-20250514", + "messages":[{"role":"user","content":"hi"}], + "tools":[{"name":"runner_local","description":"native runner tool","input_schema":{"type":"object"}}] + }` + req := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewBufferString(body)) + 
req.Header.Set("Authorization", "Bearer nano-bot:dummy456") + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Anthropic-Version", "2023-06-01") + w := httptest.NewRecorder() + + h.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + if !strings.Contains(w.Body.String(), "\"runner_local\"") { + t.Fatalf("expected native anthropic tool_use to be returned to runner, got %s", w.Body.String()) + } + if anthropicCalls != 2 { + t.Fatalf("expected two anthropic rounds before handoff, got %d", anthropicCalls) + } + if toolCalls != 1 { + t.Fatalf("expected one managed tool execution before handoff, got %d", toolCalls) + } + + followupReq := httptest.NewRequest(http.MethodPost, "/v1/messages", bytes.NewBufferString(`{ + "model":"claude-sonnet-4-20250514", + "messages":[ + {"role":"user","content":"hi"}, + {"role":"assistant","content":[{"type":"tool_use","id":"toolu_native_1","name":"runner_local","input":{}}]}, + {"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_native_1","content":"{\"native\":true}"}]} + ], + "tools":[{"name":"runner_local","description":"native runner tool","input_schema":{"type":"object"}}] + }`)) + followupReq.Header.Set("Authorization", "Bearer nano-bot:dummy456") + followupReq.Header.Set("Content-Type", "application/json") + followupReq.Header.Set("Anthropic-Version", "2023-06-01") + followupW := httptest.NewRecorder() + + h.ServeHTTP(followupW, followupReq) + + if followupW.Code != http.StatusOK { + t.Fatalf("expected follow-up request 200, got %d: %s", followupW.Code, followupW.Body.String()) + } + if !strings.Contains(followupW.Body.String(), "native handoff complete") { + t.Fatalf("expected final text after native handoff, got %s", followupW.Body.String()) + } + if anthropicCalls != 3 { + t.Fatalf("expected third anthropic round after runner tool result, got %d", anthropicCalls) + } + assertInterventionLogged(t, logs.Bytes(), 
managedMixedPrefixSerializedIntervention) +} + +func TestBuildAnthropicAssistantMessageRetainsManagedBlockWhenUpstreamIDMissing(t *testing.T) { + assistantMessage, toolUses, err := parseAnthropicToolResponse([]byte(`{ + "content":[ + {"type":"text","text":"Checking managed context first."}, + {"type":"tool_use","name":"trading-api.get_market_context","input":{}}, + {"type":"tool_use","id":"toolu_native_1","name":"runner_local","input":{}} + ] + }`)) + if err != nil { + t.Fatalf("parse anthropic response: %v", err) + } + if len(toolUses) != 2 { + t.Fatalf("expected two parsed tool uses, got %d", len(toolUses)) + } + if toolUses[0].ID != "toolu_2" { + t.Fatalf("expected synthetic id for id-less managed tool_use, got %q", toolUses[0].ID) + } + + managedAssistant := buildAnthropicAssistantMessage(assistantMessage, toolUses[:1], true) + blocks, _ := managedAssistant["content"].([]any) + if len(blocks) != 2 { + t.Fatalf("expected text plus managed tool_use after filtering, got %+v", managedAssistant["content"]) + } + + first, _ := blocks[0].(map[string]any) + if first == nil || first["type"] != "text" { + t.Fatalf("expected leading text block preserved, got %+v", blocks[0]) + } + second, _ := blocks[1].(map[string]any) + if second == nil || second["type"] != "tool_use" || second["name"] != "trading-api.get_market_context" { + t.Fatalf("expected managed tool_use preserved after synthetic-id filtering, got %+v", blocks[1]) + } + if _, ok := second["id"]; ok { + t.Fatalf("expected filtered managed tool_use to preserve missing upstream id, got %+v", second) + } +} + func TestHandlerHandsOffNativeAnthropicToolUseAfterManagedRounds(t *testing.T) { anthropicCalls := 0 toolCalls := 0 diff --git a/internal/proxy/toolmediation.go b/internal/proxy/toolmediation.go index 366e34d..4a5b714 100644 --- a/internal/proxy/toolmediation.go +++ b/internal/proxy/toolmediation.go @@ -38,6 +38,10 @@ var ( const managedToolModeMessage = "This request is in mediated mode. 
Action required: re-emit only managed service tools for this turn, or respond in text." +const mixedToolOrderMessage = "mixed managed and runner-native tool calls are not supported in one model response unless managed service tools come first. Re-emit managed service tools first, then emit runner-native tool calls in a later response." + +const managedMixedPrefixSerializedIntervention = "managed_prefix_native_suffix_serialized" + type capturedResponse struct { StatusCode int Header http.Header @@ -117,6 +121,24 @@ type managedStreamKeepalive struct { started bool } +type openAIToolOwnership int + +const ( + openAIToolsAllManaged openAIToolOwnership = iota + openAIToolsAllNative + openAIToolsManagedThenNative + openAIToolsUnsafeMixed +) + +type anthropicToolOwnership int + +const ( + anthropicToolsAllManaged anthropicToolOwnership = iota + anthropicToolsAllNative + anthropicToolsManagedThenNative + anthropicToolsUnsafeMixed +) + type limitedReadResult struct { Body []byte Truncated bool @@ -219,8 +241,8 @@ func (h *Handler) handleManagedOpenAI(w http.ResponseWriter, r *http.Request, ag return } - managedCalls, nativeCalls := partitionManagedOpenAIToolCalls(agentCtx, toolCalls) - if len(managedCalls) == 0 { + managedCalls, _, ownership := classifyOpenAIToolCalls(agentCtx, toolCalls) + if ownership == openAIToolsAllNative { responseBytes := resp.Body if downstreamStream { sse, synthErr := synthesizeOpenAIToolCallStream(resp.Body, resp.UpstreamModel, usageAgg, downstreamIncludeUsage) @@ -256,8 +278,8 @@ func (h *Handler) handleManagedOpenAI(w http.ResponseWriter, r *http.Request, ag return } - if len(nativeCalls) > 0 { - msg := "mixed managed and runner-native tool calls are not supported in one model response" + if ownership == openAIToolsUnsafeMixed { + msg := mixedToolOrderMessage h.recordManagedFailure(agentID, resp.ProviderName, requestedModel, resp.UpstreamModel, r.URL.Path, requestOriginal, requestEffective, http.StatusBadGateway, jsonErrorPayload(msg), 
usageAgg, toolTrace) if streamKeepalive != nil && downstreamStream { streamKeepalive.writeOpenAIError(jsonErrorPayload(msg)) @@ -267,13 +289,15 @@ func (h *Handler) handleManagedOpenAI(w http.ResponseWriter, r *http.Request, ag h.fail(w, http.StatusBadGateway, msg, agentID, requestedModel, start, fmt.Errorf(msg)) return } - if len(toolTrace) >= policy.MaxRounds { msg := fmt.Sprintf("managed tool max rounds exceeded (%d)", policy.MaxRounds) h.recordManagedFailure(agentID, resp.ProviderName, requestedModel, resp.UpstreamModel, r.URL.Path, requestOriginal, requestEffective, http.StatusBadGateway, jsonErrorPayload(msg), usageAgg, toolTrace) h.fail(w, http.StatusBadGateway, msg, agentID, requestedModel, start, fmt.Errorf(msg)) return } + if ownership == openAIToolsManagedThenNative { + h.logger.LogIntervention(agentID, requestedModel, managedMixedPrefixSerializedIntervention) + } toolMessages := make([]any, 0, len(toolCalls)) roundTrace := sessionhistory.ToolRoundTrace{ @@ -309,8 +333,15 @@ func (h *Handler) handleManagedOpenAI(w http.ResponseWriter, r *http.Request, ag }) } toolTrace = append(toolTrace, roundTrace) - appendOpenAIAssistantAndToolMessages(payload, assistantMessage, toolMessages) - hiddenMessages = appendManagedOpenAIContinuityMessages(hiddenMessages, assistantMessage, toolMessages) + managedAssistant := assistantMessage + if ownership == openAIToolsManagedThenNative { + managedAssistant = buildOpenAIAssistantMessage(assistantMessage, managedCalls, true) + } + appendOpenAIAssistantAndToolMessages(payload, managedAssistant, toolMessages) + // Persist the filtered managed-only assistant so the hidden continuity + // transcript matches the serialized round the model actually saw before + // the runner-native handoff. 
+ hiddenMessages = appendManagedOpenAIContinuityMessages(hiddenMessages, managedAssistant, toolMessages) } } @@ -410,8 +441,8 @@ func (h *Handler) handleManagedAnthropic(w http.ResponseWriter, r *http.Request, return } - managedToolUses, nativeToolUses := partitionManagedAnthropicToolUses(agentCtx, toolUses) - if len(managedToolUses) == 0 { + managedToolUses, _, ownership := classifyAnthropicToolUses(agentCtx, toolUses) + if ownership == anthropicToolsAllNative { responseBytes := resp.Body if downstreamStream { sse, synthErr := synthesizeAnthropicToolUseStream(resp.Body, resp.UpstreamModel, usageAgg) @@ -447,8 +478,8 @@ func (h *Handler) handleManagedAnthropic(w http.ResponseWriter, r *http.Request, return } - if len(nativeToolUses) > 0 { - msg := "mixed managed and runner-native tool calls are not supported in one model response" + if ownership == anthropicToolsUnsafeMixed { + msg := mixedToolOrderMessage h.recordManagedFailure(agentID, resp.ProviderName, requestedModel, resp.UpstreamModel, r.URL.Path, requestOriginal, requestEffective, http.StatusBadGateway, jsonErrorPayload(msg), usageAgg, toolTrace) if streamKeepalive != nil && downstreamStream { streamKeepalive.writeAnthropicError(msg) @@ -458,13 +489,15 @@ func (h *Handler) handleManagedAnthropic(w http.ResponseWriter, r *http.Request, h.fail(w, http.StatusBadGateway, msg, agentID, requestedModel, start, fmt.Errorf(msg)) return } - if len(toolTrace) >= policy.MaxRounds { msg := fmt.Sprintf("managed tool max rounds exceeded (%d)", policy.MaxRounds) h.recordManagedFailure(agentID, resp.ProviderName, requestedModel, resp.UpstreamModel, r.URL.Path, requestOriginal, requestEffective, http.StatusBadGateway, jsonErrorPayload(msg), usageAgg, toolTrace) h.fail(w, http.StatusBadGateway, msg, agentID, requestedModel, start, fmt.Errorf(msg)) return } + if ownership == anthropicToolsManagedThenNative { + h.logger.LogIntervention(agentID, requestedModel, managedMixedPrefixSerializedIntervention) + } toolResults := 
make([]map[string]any, 0, len(toolUses)) roundTrace := sessionhistory.ToolRoundTrace{ @@ -496,8 +529,15 @@ func (h *Handler) handleManagedAnthropic(w http.ResponseWriter, r *http.Request, toolResults = append(toolResults, anthropicToolResultBlock(call.ID, outcome.RawJSON)) } toolTrace = append(toolTrace, roundTrace) - toolResultMessage := appendAnthropicAssistantAndToolResultMessages(payload, assistantMessage, toolResults) - hiddenMessages = appendManagedAnthropicContinuityMessages(hiddenMessages, assistantMessage, toolResultMessage) + managedAssistant := assistantMessage + if ownership == anthropicToolsManagedThenNative { + managedAssistant = buildAnthropicAssistantMessage(assistantMessage, managedToolUses, true) + } + toolResultMessage := appendAnthropicAssistantAndToolResultMessages(payload, managedAssistant, toolResults) + // Persist the filtered managed-only assistant so the hidden continuity + // transcript matches the serialized round the model actually saw before + // the runner-native handoff. 
+ hiddenMessages = appendManagedAnthropicContinuityMessages(hiddenMessages, managedAssistant, toolResultMessage) } } @@ -1025,36 +1065,163 @@ func parseAnthropicToolResponse(body []byte) (map[string]any, []anthropicToolUse } } -func partitionManagedOpenAIToolCalls(agentCtx *agentctx.AgentContext, calls []openAIToolCall) ([]openAIToolCall, []openAIToolCall) { +func classifyOpenAIToolCalls(agentCtx *agentctx.AgentContext, calls []openAIToolCall) ([]openAIToolCall, []openAIToolCall, openAIToolOwnership) { if len(calls) == 0 { - return nil, nil + return nil, nil, openAIToolsAllManaged } managed := make([]openAIToolCall, 0, len(calls)) native := make([]openAIToolCall, 0, len(calls)) + sawManaged := false + sawNative := false for _, call := range calls { if _, ok := resolveManagedTool(agentCtx, call.Name); ok { + if sawNative { + managed = append(managed, call) + return managed, native, openAIToolsUnsafeMixed + } + sawManaged = true managed = append(managed, call) continue } + sawNative = true native = append(native, call) } - return managed, native + switch { + case sawManaged && sawNative: + return managed, native, openAIToolsManagedThenNative + case sawManaged: + return managed, nil, openAIToolsAllManaged + default: + return nil, native, openAIToolsAllNative + } } -func partitionManagedAnthropicToolUses(agentCtx *agentctx.AgentContext, calls []anthropicToolUse) ([]anthropicToolUse, []anthropicToolUse) { +func classifyAnthropicToolUses(agentCtx *agentctx.AgentContext, calls []anthropicToolUse) ([]anthropicToolUse, []anthropicToolUse, anthropicToolOwnership) { if len(calls) == 0 { - return nil, nil + return nil, nil, anthropicToolsAllManaged } managed := make([]anthropicToolUse, 0, len(calls)) native := make([]anthropicToolUse, 0, len(calls)) + sawManaged := false + sawNative := false for _, call := range calls { if _, ok := resolveManagedTool(agentCtx, call.Name); ok { + if sawNative { + managed = append(managed, call) + return managed, native, 
anthropicToolsUnsafeMixed + } + sawManaged = true managed = append(managed, call) continue } + sawNative = true native = append(native, call) } - return managed, native + switch { + case sawManaged && sawNative: + return managed, native, anthropicToolsManagedThenNative + case sawManaged: + return managed, nil, anthropicToolsAllManaged + default: + return nil, native, anthropicToolsAllNative + } +} + +func buildOpenAIAssistantMessage(base map[string]any, calls []openAIToolCall, includeContent bool) map[string]any { + msg := cloneAnyMap(base) + msg["role"] = "assistant" + if !includeContent { + delete(msg, "content") + } + delete(msg, "tool_calls") + delete(msg, "function_call") + msg["tool_calls"] = serializeOpenAIToolCalls(calls) + return msg +} + +func buildAnthropicAssistantMessage(base map[string]any, calls []anthropicToolUse, includeText bool) map[string]any { + msg := cloneAnyMap(base) + msg["role"] = "assistant" + msg["content"] = filterAnthropicToolUseContent(base["content"], calls, includeText) + return msg +} + +func serializeOpenAIToolCalls(calls []openAIToolCall) []any { + out := make([]any, 0, len(calls)) + for _, call := range calls { + args := strings.TrimSpace(string(call.ArgumentsRaw)) + if args == "" { + if len(call.Arguments) > 0 { + if raw, err := json.Marshal(call.Arguments); err == nil { + args = string(raw) + } + } + if args == "" { + args = "{}" + } + } + out = append(out, map[string]any{ + "id": call.ID, + "type": "function", + "function": map[string]any{ + "name": call.Name, + "arguments": args, + }, + }) + } + return out +} + +func filterAnthropicToolUseContent(content any, calls []anthropicToolUse, includeText bool) []any { + selected := make(map[string]struct{}, len(calls)) + for _, call := range calls { + if id := strings.TrimSpace(call.ID); id != "" { + selected[id] = struct{}{} + } + } + blocks, _ := content.([]any) + out := make([]any, 0, len(blocks)) + for i, raw := range blocks { + block, _ := raw.(map[string]any) + if block == nil 
{ + if includeText { + out = append(out, raw) + } + continue + } + blockType, _ := block["type"].(string) + if blockType == "tool_use" { + blockID, _ := block["id"].(string) + selectedID := strings.TrimSpace(blockID) + if selectedID == "" { + // parseAnthropicToolResponse synthesizes toolu_N ids (N is the 1-based + // content-block index) for id-less tool_use blocks; mirror that mapping + // here so the managed prefix can be filtered back out consistently. + selectedID = fmt.Sprintf("toolu_%d", i+1) + } + if _, ok := selected[selectedID]; ok { + out = append(out, cloneAnyMap(block)) + } + continue + } + if includeText { + out = append(out, cloneAnyMap(block)) + } + } + return out +} + +func cloneAnyMap(in map[string]any) map[string]any { + if in == nil { + return map[string]any{} + } + // Intentionally shallow: callers replace top-level keys but otherwise treat + // nested values as immutable snapshots of the captured assistant message. + out := make(map[string]any, len(in)) + for key, value := range in { + out[key] = value + } + return out } func appendOpenAIAssistantAndToolMessages(payload map[string]any, assistantMessage map[string]any, toolMessages []any) {