From 4dccf7de9a9e0ca50d781466d976b361a3ed96a8 Mon Sep 17 00:00:00 2001 From: pjb157 Date: Wed, 6 May 2026 17:47:51 +0100 Subject: [PATCH] fix(responses): pass through client-side tool_calls instead of dispatching them MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The /v1/responses warm path always engages multi-step orchestration and dispatches every tool_call returned by the model through HttpToolExecutor. When a client supplies their own function tools in the request body (client-side tool calling per the OpenAI Responses spec), the executor looks them up in the server-side tool_sources registry, doesn't find them, and fails the step with `Tool not found`. Fix: thread the set of registered tool names from `resolve_tools_for_request` through PendingResponseInput into the transition function. When a model_call returns tool_calls and any name is missing from that set, complete the response with the model's payload — assembly already emits `function_call` output items in OpenAI Responses shape, so the client receives the calls and can execute them locally. The whole tool_call batch passes through (not just the unregistered ones) because partial dispatch would leave the upstream conversation expecting tool results for calls the loop never ran. --- dwctl/src/responses/middleware.rs | 26 ++++-- dwctl/src/responses/store.rs | 16 +++- dwctl/src/responses/transition.rs | 117 ++++++++++++++++++++++++-- dwctl/src/test/multi_step_executor.rs | 1 + dwctl/src/test/responses.rs | 1 + 5 files changed, 143 insertions(+), 18 deletions(-) diff --git a/dwctl/src/responses/middleware.rs b/dwctl/src/responses/middleware.rs index 8156e544d..cb74def80 100644 --- a/dwctl/src/responses/middleware.rs +++ b/dwctl/src/responses/middleware.rs @@ -602,13 +602,6 @@ async fn warm_path_setup( model: &str, ) -> Option<(uuid::Uuid, Arc, onwards::UpstreamTarget)> { let created_by = response_store::lookup_created_by(&state.dwctl_pool, Some(api_key)).await; - let pending = response_store::PendingResponseInput { - body: request_value.to_string(), - api_key: Some(api_key.to_string()), - created_by, - base_url: state.loopback_base_url.clone(), - }; - let head_step_uuid = state.response_store.register_pending(pending); let resolved = match crate::tool_injection::resolve_tools_for_request(&state.dwctl_pool, api_key, Some(model)).await { Ok(Some(set)) => Arc::new(set), @@ -625,6 +618,25 @@ async fn warm_path_setup( } }; + // The transition function uses these names to decide which + // tool_calls returned by the model can be auto-dispatched and + // which must be passed through to the client as `function_call` + // output items. Any tool the user supplies in their request body + // that isn't registered in `tool_sources` ends up outside this set + // and gets the client-side passthrough treatment — without this, + // HttpToolExecutor would try to dispatch the unknown name and the + // step would fail with `Tool not found`. + let resolved_tool_names = resolved.tools.keys().cloned().collect(); + + let pending = response_store::PendingResponseInput { + body: request_value.to_string(), + api_key: Some(api_key.to_string()), + created_by, + base_url: state.loopback_base_url.clone(), + resolved_tool_names, + }; + let head_step_uuid = state.response_store.register_pending(pending); + let upstream = onwards::UpstreamTarget { url: format!("{}/v1/chat/completions", state.loopback_base_url), api_key: Some(api_key.to_string()), diff --git a/dwctl/src/responses/store.rs b/dwctl/src/responses/store.rs index a20b9d45e..5ae6f3814 100644 --- a/dwctl/src/responses/store.rs +++ b/dwctl/src/responses/store.rs @@ -4,7 +4,7 @@ //! All fusillade operations go through the `Storage` trait via `request_manager`. //! The only raw SQL is the `api_keys` lookup which queries a dwctl-owned table. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::{Arc, RwLock}; use async_trait::async_trait; @@ -45,6 +45,18 @@ pub struct PendingResponseInput { /// `base_url` column points at the dwctl loopback so onwards can /// pick a target / honor strict mode at fire time. pub base_url: String, + /// Names of server-side tools registered for this request (resolved + /// from `tool_sources` joined with the user's groups + the deployment). + /// The transition function uses this to decide which tool_calls + /// returned by the model can be auto-dispatched server-side and which + /// must be passed through to the client as `function_call` output items. + /// + /// When a tool_call's name is missing from this set, it's treated as a + /// client-side tool: the loop completes with the model's response, and + /// `assemble_response` surfaces the call as a `function_call` item per + /// the OpenAI Responses contract — the client is expected to execute + /// it and submit the result via a follow-up request. + pub resolved_tool_names: HashSet, } /// Header set by the responses middleware so the outlet handler knows which @@ -741,7 +753,7 @@ impl MultiStepStore for Fusilla let pending = self.pending_input(request_id)?; let parsed = super::transition::parse_parent_request(&pending.body).map_err(StoreError::StorageError)?; let chain = ::list_chain(self, request_id, scope_parent).await?; - Ok(super::transition::decide_next_action(&parsed, &chain)) + Ok(super::transition::decide_next_action(&parsed, &chain, &pending.resolved_tool_names)) } async fn record_step( diff --git a/dwctl/src/responses/transition.rs b/dwctl/src/responses/transition.rs index 56b1f41d6..0411d850d 100644 --- a/dwctl/src/responses/transition.rs +++ b/dwctl/src/responses/transition.rs @@ -56,6 +56,8 @@ //! tests) stay decoupled from the strict-mode adapter — wiring those in //! is a future cleanup once the multi-step path is live. +use std::collections::HashSet; + use onwards::{ChainStep, NextAction, StepDescriptor, StepKind, StepState}; use serde_json::{Value, json}; @@ -239,10 +241,16 @@ pub(crate) fn prepare_followup_model_call(parsed: &ParsedRequest, chain: &[Chain /// Decide the next action given: /// - `parsed`: the parent fusillade request body /// - `chain`: completed/failed steps in the current scope, in sequence order +/// - `resolved_tool_names`: names of server-side tools registered for this +/// request. Tool_calls whose name is in this set get auto-dispatched as +/// server-side `ToolCall` steps; tool_calls outside the set are +/// passed through to the client as `function_call` output items by +/// completing the response with the model's payload (the assembly step +/// surfaces them per the OpenAI Responses contract). /// /// Returns the action the loop should take. Pure function over its inputs; /// no I/O. -pub(crate) fn decide_next_action(parsed: &ParsedRequest, chain: &[ChainStep]) -> NextAction { +pub(crate) fn decide_next_action(parsed: &ParsedRequest, chain: &[ChainStep], resolved_tool_names: &HashSet) -> NextAction { if chain.is_empty() { return NextAction::AppendSteps(vec![prepare_initial_model_call(parsed)]); } @@ -275,9 +283,31 @@ pub(crate) fn decide_next_action(parsed: &ParsedRequest, chain: &[ChainStep]) -> let tool_calls = extract_tool_calls(&response); if tool_calls.is_empty() { // No tool calls — the model returned final output. - NextAction::Complete(response) - } else { + return NextAction::Complete(response); + } + + // Server-side dispatch is only safe when every tool_call + // names a server-registered tool (i.e., one with a row in + // `tool_sources` for this request's user/deployment). If + // any name is unregistered, it's a client-side function + // tool — the model must have seen it because the user put + // it in the request body — and we cannot dispatch it. The + // OpenAI Responses contract for that case is to surface + // every tool_call as a `function_call` output item and let + // the client run them and submit results in a follow-up. + // + // We bail out for the *whole* fan-out (rather than partial + // dispatch) because the model expects results for every + // call it emitted before producing its next message; a + // mixed dispatch would leave the conversation in a state + // the upstream model can't reason about. + let all_registered = tool_calls + .iter() + .all(|step| tool_call_name(&step.request_payload).is_some_and(|name| resolved_tool_names.contains(name))); + if all_registered { NextAction::AppendSteps(tool_calls) + } else { + NextAction::Complete(response) } } StepKind::ToolCall => { @@ -290,6 +320,10 @@ pub(crate) fn decide_next_action(parsed: &ParsedRequest, chain: &[ChainStep]) -> } } +fn tool_call_name(payload: &Value) -> Option<&str> { + payload.get("name").and_then(|n| n.as_str()) +} + #[cfg(test)] mod tests { use super::*; @@ -322,6 +356,10 @@ mod tests { assert_eq!(p.initial_messages.len(), 1); } + fn names(items: &[&str]) -> HashSet { + items.iter().map(|s| s.to_string()).collect() + } + #[test] fn empty_chain_emits_initial_model_call() { let parsed = ParsedRequest { @@ -330,7 +368,7 @@ mod tests { tools: None, stream: false, }; - match decide_next_action(&parsed, &[]) { + match decide_next_action(&parsed, &[], &HashSet::new()) { NextAction::AppendSteps(steps) => { assert_eq!(steps.len(), 1); assert!(matches!(steps[0].kind, StepKind::ModelCall)); @@ -341,7 +379,7 @@ mod tests { } #[test] - fn model_call_with_tool_calls_emits_fan_out() { + fn model_call_with_registered_tool_calls_emits_fan_out() { let parsed = ParsedRequest { model: "m".into(), initial_messages: vec![], @@ -360,7 +398,7 @@ mod tests { }] }); let chain = vec![step("s1", 1, StepKind::ModelCall, StepState::Completed, Some(response))]; - match decide_next_action(&parsed, &chain) { + match decide_next_action(&parsed, &chain, &names(&["a", "b"])) { NextAction::AppendSteps(steps) => { assert_eq!(steps.len(), 2); assert_eq!(steps[0].request_payload["name"], "a"); @@ -372,6 +410,67 @@ mod tests { } } + #[test] + fn model_call_with_unregistered_tool_completes_for_client_dispatch() { + // The user supplied a client-side function tool in the request + // body; the model emits a tool_call for it. With no row in + // `tool_sources` for this name, the loop must NOT try to + // dispatch it (HttpToolExecutor would fail with NotFound) — + // instead it completes with the model's response so assembly + // can surface a `function_call` output item to the client. + let parsed = ParsedRequest { + model: "m".into(), + initial_messages: vec![], + tools: None, + stream: false, + }; + let response = json!({ + "choices": [{ + "message": { + "role": "assistant", + "tool_calls": [ + {"id": "call_1", "type": "function", "function": {"name": "read_pages", "arguments": "{\"id\":1}"}}, + ] + } + }] + }); + let chain = vec![step("s1", 1, StepKind::ModelCall, StepState::Completed, Some(response.clone()))]; + match decide_next_action(&parsed, &chain, &HashSet::new()) { + NextAction::Complete(v) => assert_eq!(v, response), + other => panic!("expected Complete for unregistered tool, got {other:?}"), + } + } + + #[test] + fn model_call_with_mixed_registered_and_unregistered_completes() { + // If even one tool_call in a fan-out is unregistered, the whole + // batch passes through to the client. Partial dispatch would + // leave the model expecting results for tool_calls the loop + // never ran. + let parsed = ParsedRequest { + model: "m".into(), + initial_messages: vec![], + tools: None, + stream: false, + }; + let response = json!({ + "choices": [{ + "message": { + "role": "assistant", + "tool_calls": [ + {"id": "call_1", "type": "function", "function": {"name": "weather", "arguments": "{}"}}, + {"id": "call_2", "type": "function", "function": {"name": "client_only", "arguments": "{}"}}, + ] + } + }] + }); + let chain = vec![step("s1", 1, StepKind::ModelCall, StepState::Completed, Some(response.clone()))]; + match decide_next_action(&parsed, &chain, &names(&["weather"])) { + NextAction::Complete(v) => assert_eq!(v, response), + other => panic!("expected Complete for mixed tool_calls, got {other:?}"), + } + } + #[test] fn model_call_without_tool_calls_completes() { let parsed = ParsedRequest { @@ -386,7 +485,7 @@ mod tests { }] }); let chain = vec![step("s1", 1, StepKind::ModelCall, StepState::Completed, Some(response.clone()))]; - match decide_next_action(&parsed, &chain) { + match decide_next_action(&parsed, &chain, &HashSet::new()) { NextAction::Complete(v) => assert_eq!(v, response), _ => panic!("expected Complete"), } @@ -412,7 +511,7 @@ mod tests { step("s1", 1, StepKind::ModelCall, StepState::Completed, Some(model_response)), step("s2", 2, StepKind::ToolCall, StepState::Completed, Some(json!({"result": 1}))), ]; - match decide_next_action(&parsed, &chain) { + match decide_next_action(&parsed, &chain, &names(&["a"])) { NextAction::AppendSteps(steps) => { assert_eq!(steps.len(), 1); assert!(matches!(steps[0].kind, StepKind::ModelCall)); @@ -437,7 +536,7 @@ mod tests { }; let mut s = step("s1", 1, StepKind::ModelCall, StepState::Failed, None); s.error = Some(json!({"type": "upstream_500"})); - match decide_next_action(&parsed, &[s]) { + match decide_next_action(&parsed, &[s], &HashSet::new()) { NextAction::Fail(v) => assert_eq!(v, json!({"type": "upstream_500"})), _ => panic!("expected Fail"), } diff --git a/dwctl/src/test/multi_step_executor.rs b/dwctl/src/test/multi_step_executor.rs index e292743db..ecf946589 100644 --- a/dwctl/src/test/multi_step_executor.rs +++ b/dwctl/src/test/multi_step_executor.rs @@ -62,6 +62,7 @@ where api_key: None, created_by: None, base_url: base_url.to_string(), + resolved_tool_names: std::collections::HashSet::new(), }) .to_string() } diff --git a/dwctl/src/test/responses.rs b/dwctl/src/test/responses.rs index dbf8cc862..62d8b13a7 100644 --- a/dwctl/src/test/responses.rs +++ b/dwctl/src/test/responses.rs @@ -315,6 +315,7 @@ async fn test_multi_step_chain_assembles_and_is_retrievable_via_get(pool: PgPool api_key: None, created_by: Some("test-user".to_string()), base_url: "http://upstream-mock".to_string(), + resolved_tool_names: std::collections::HashSet::new(), }); let request_id = head_uuid.to_string();