Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 19 additions & 7 deletions dwctl/src/responses/middleware.rs
Original file line number Diff line number Diff line change
Expand Up @@ -602,13 +602,6 @@ async fn warm_path_setup<P: PoolProvider + Clone + Send + Sync + 'static>(
model: &str,
) -> Option<(uuid::Uuid, Arc<crate::tool_executor::ResolvedToolSet>, onwards::UpstreamTarget)> {
let created_by = response_store::lookup_created_by(&state.dwctl_pool, Some(api_key)).await;
let pending = response_store::PendingResponseInput {
body: request_value.to_string(),
api_key: Some(api_key.to_string()),
created_by,
base_url: state.loopback_base_url.clone(),
};
let head_step_uuid = state.response_store.register_pending(pending);

let resolved = match crate::tool_injection::resolve_tools_for_request(&state.dwctl_pool, api_key, Some(model)).await {
Ok(Some(set)) => Arc::new(set),
Expand All @@ -625,6 +618,25 @@ async fn warm_path_setup<P: PoolProvider + Clone + Send + Sync + 'static>(
}
};

// The transition function uses these names to decide which
// tool_calls returned by the model can be auto-dispatched and
// which must be passed through to the client as `function_call`
// output items. Any tool the user supplies in their request body
// that isn't registered in `tool_sources` ends up outside this set
// and gets the client-side passthrough treatment — without this,
// HttpToolExecutor would try to dispatch the unknown name and the
// step would fail with `Tool not found`.
let resolved_tool_names = resolved.tools.keys().cloned().collect();

let pending = response_store::PendingResponseInput {
body: request_value.to_string(),
api_key: Some(api_key.to_string()),
created_by,
base_url: state.loopback_base_url.clone(),
resolved_tool_names,
};
let head_step_uuid = state.response_store.register_pending(pending);

let upstream = onwards::UpstreamTarget {
url: format!("{}/v1/chat/completions", state.loopback_base_url),
api_key: Some(api_key.to_string()),
Expand Down
16 changes: 14 additions & 2 deletions dwctl/src/responses/store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
//! All fusillade operations go through the `Storage` trait via `request_manager`.
//! The only raw SQL is the `api_keys` lookup which queries a dwctl-owned table.

use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, RwLock};

use async_trait::async_trait;
Expand Down Expand Up @@ -45,6 +45,18 @@ pub struct PendingResponseInput {
/// `base_url` column points at the dwctl loopback so onwards can
/// pick a target / honor strict mode at fire time.
pub base_url: String,
/// Names of server-side tools registered for this request (resolved
/// from `tool_sources` joined with the user's groups + the deployment).
/// The transition function uses this to decide which tool_calls
/// returned by the model can be auto-dispatched server-side and which
/// must be passed through to the client as `function_call` output items.
///
/// When a tool_call's name is missing from this set, it's treated as a
/// client-side tool: the loop completes with the model's response, and
/// `assemble_response` surfaces the call as a `function_call` item per
/// the OpenAI Responses contract — the client is expected to execute
/// it and submit the result via a follow-up request.
pub resolved_tool_names: HashSet<String>,
}

/// Header set by the responses middleware so the outlet handler knows which
Expand Down Expand Up @@ -741,7 +753,7 @@ impl<P: PoolProvider + Clone + Send + Sync + 'static> MultiStepStore for Fusilla
let pending = self.pending_input(request_id)?;
let parsed = super::transition::parse_parent_request(&pending.body).map_err(StoreError::StorageError)?;
let chain = <Self as MultiStepStore>::list_chain(self, request_id, scope_parent).await?;
Ok(super::transition::decide_next_action(&parsed, &chain))
Ok(super::transition::decide_next_action(&parsed, &chain, &pending.resolved_tool_names))
}

async fn record_step(
Expand Down
117 changes: 108 additions & 9 deletions dwctl/src/responses/transition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
//! tests) stay decoupled from the strict-mode adapter — wiring those in
//! is a future cleanup once the multi-step path is live.

use std::collections::HashSet;

use onwards::{ChainStep, NextAction, StepDescriptor, StepKind, StepState};
use serde_json::{Value, json};

Expand Down Expand Up @@ -239,10 +241,16 @@ pub(crate) fn prepare_followup_model_call(parsed: &ParsedRequest, chain: &[Chain
/// Decide the next action given:
/// - `parsed`: the parent fusillade request body
/// - `chain`: completed/failed steps in the current scope, in sequence order
/// - `resolved_tool_names`: names of server-side tools registered for this
/// request. Tool_calls whose name is in this set get auto-dispatched as
/// server-side `ToolCall` steps; tool_calls outside the set are
/// passed through to the client as `function_call` output items by
/// completing the response with the model's payload (the assembly step
/// surfaces them per the OpenAI Responses contract).
///
/// Returns the action the loop should take. Pure function over its inputs;
/// no I/O.
pub(crate) fn decide_next_action(parsed: &ParsedRequest, chain: &[ChainStep]) -> NextAction {
pub(crate) fn decide_next_action(parsed: &ParsedRequest, chain: &[ChainStep], resolved_tool_names: &HashSet<String>) -> NextAction {
if chain.is_empty() {
return NextAction::AppendSteps(vec![prepare_initial_model_call(parsed)]);
}
Expand Down Expand Up @@ -275,9 +283,31 @@ pub(crate) fn decide_next_action(parsed: &ParsedRequest, chain: &[ChainStep]) ->
let tool_calls = extract_tool_calls(&response);
if tool_calls.is_empty() {
// No tool calls — the model returned final output.
NextAction::Complete(response)
} else {
return NextAction::Complete(response);
}

// Server-side dispatch is only safe when every tool_call
// names a server-registered tool (i.e., one with a row in
// `tool_sources` for this request's user/deployment). If
// any name is unregistered, it's a client-side function
// tool — the model must have seen it because the user put
// it in the request body — and we cannot dispatch it. The
// OpenAI Responses contract for that case is to surface
// every tool_call as a `function_call` output item and let
// the client run them and submit results in a follow-up.
//
// We bail out for the *whole* fan-out (rather than partial
// dispatch) because the model expects results for every
// call it emitted before producing its next message; a
// mixed dispatch would leave the conversation in a state
// the upstream model can't reason about.
let all_registered = tool_calls
.iter()
.all(|step| tool_call_name(&step.request_payload).is_some_and(|name| resolved_tool_names.contains(name)));
if all_registered {
NextAction::AppendSteps(tool_calls)
} else {
NextAction::Complete(response)
}
}
StepKind::ToolCall => {
Expand All @@ -290,6 +320,10 @@ pub(crate) fn decide_next_action(parsed: &ParsedRequest, chain: &[ChainStep]) ->
}
}

fn tool_call_name(payload: &Value) -> Option<&str> {
payload.get("name").and_then(|n| n.as_str())
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -322,6 +356,10 @@ mod tests {
assert_eq!(p.initial_messages.len(), 1);
}

fn names(items: &[&str]) -> HashSet<String> {
items.iter().map(|s| s.to_string()).collect()
}

#[test]
fn empty_chain_emits_initial_model_call() {
let parsed = ParsedRequest {
Expand All @@ -330,7 +368,7 @@ mod tests {
tools: None,
stream: false,
};
match decide_next_action(&parsed, &[]) {
match decide_next_action(&parsed, &[], &HashSet::new()) {
NextAction::AppendSteps(steps) => {
assert_eq!(steps.len(), 1);
assert!(matches!(steps[0].kind, StepKind::ModelCall));
Expand All @@ -341,7 +379,7 @@ mod tests {
}

#[test]
fn model_call_with_tool_calls_emits_fan_out() {
fn model_call_with_registered_tool_calls_emits_fan_out() {
let parsed = ParsedRequest {
model: "m".into(),
initial_messages: vec![],
Expand All @@ -360,7 +398,7 @@ mod tests {
}]
});
let chain = vec![step("s1", 1, StepKind::ModelCall, StepState::Completed, Some(response))];
match decide_next_action(&parsed, &chain) {
match decide_next_action(&parsed, &chain, &names(&["a", "b"])) {
NextAction::AppendSteps(steps) => {
assert_eq!(steps.len(), 2);
assert_eq!(steps[0].request_payload["name"], "a");
Expand All @@ -372,6 +410,67 @@ mod tests {
}
}

#[test]
fn model_call_with_unregistered_tool_completes_for_client_dispatch() {
// The user supplied a client-side function tool in the request
// body; the model emits a tool_call for it. With no row in
// `tool_sources` for this name, the loop must NOT try to
// dispatch it (HttpToolExecutor would fail with NotFound) —
// instead it completes with the model's response so assembly
// can surface a `function_call` output item to the client.
let parsed = ParsedRequest {
model: "m".into(),
initial_messages: vec![],
tools: None,
stream: false,
};
let response = json!({
"choices": [{
"message": {
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "read_pages", "arguments": "{\"id\":1}"}},
]
}
}]
});
let chain = vec![step("s1", 1, StepKind::ModelCall, StepState::Completed, Some(response.clone()))];
match decide_next_action(&parsed, &chain, &HashSet::new()) {
NextAction::Complete(v) => assert_eq!(v, response),
other => panic!("expected Complete for unregistered tool, got {other:?}"),
}
}

#[test]
fn model_call_with_mixed_registered_and_unregistered_completes() {
// If even one tool_call in a fan-out is unregistered, the whole
// batch passes through to the client. Partial dispatch would
// leave the model expecting results for tool_calls the loop
// never ran.
let parsed = ParsedRequest {
model: "m".into(),
initial_messages: vec![],
tools: None,
stream: false,
};
let response = json!({
"choices": [{
"message": {
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "weather", "arguments": "{}"}},
{"id": "call_2", "type": "function", "function": {"name": "client_only", "arguments": "{}"}},
]
}
}]
});
let chain = vec![step("s1", 1, StepKind::ModelCall, StepState::Completed, Some(response.clone()))];
match decide_next_action(&parsed, &chain, &names(&["weather"])) {
NextAction::Complete(v) => assert_eq!(v, response),
other => panic!("expected Complete for mixed tool_calls, got {other:?}"),
}
}

#[test]
fn model_call_without_tool_calls_completes() {
let parsed = ParsedRequest {
Expand All @@ -386,7 +485,7 @@ mod tests {
}]
});
let chain = vec![step("s1", 1, StepKind::ModelCall, StepState::Completed, Some(response.clone()))];
match decide_next_action(&parsed, &chain) {
match decide_next_action(&parsed, &chain, &HashSet::new()) {
NextAction::Complete(v) => assert_eq!(v, response),
_ => panic!("expected Complete"),
}
Expand All @@ -412,7 +511,7 @@ mod tests {
step("s1", 1, StepKind::ModelCall, StepState::Completed, Some(model_response)),
step("s2", 2, StepKind::ToolCall, StepState::Completed, Some(json!({"result": 1}))),
];
match decide_next_action(&parsed, &chain) {
match decide_next_action(&parsed, &chain, &names(&["a"])) {
NextAction::AppendSteps(steps) => {
assert_eq!(steps.len(), 1);
assert!(matches!(steps[0].kind, StepKind::ModelCall));
Expand All @@ -437,7 +536,7 @@ mod tests {
};
let mut s = step("s1", 1, StepKind::ModelCall, StepState::Failed, None);
s.error = Some(json!({"type": "upstream_500"}));
match decide_next_action(&parsed, &[s]) {
match decide_next_action(&parsed, &[s], &HashSet::new()) {
NextAction::Fail(v) => assert_eq!(v, json!({"type": "upstream_500"})),
_ => panic!("expected Fail"),
}
Expand Down
1 change: 1 addition & 0 deletions dwctl/src/test/multi_step_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ where
api_key: None,
created_by: None,
base_url: base_url.to_string(),
resolved_tool_names: std::collections::HashSet::new(),
})
.to_string()
}
Expand Down
1 change: 1 addition & 0 deletions dwctl/src/test/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ async fn test_multi_step_chain_assembles_and_is_retrievable_via_get(pool: PgPool
api_key: None,
created_by: Some("test-user".to_string()),
base_url: "http://upstream-mock".to_string(),
resolved_tool_names: std::collections::HashSet::new(),
});
let request_id = head_uuid.to_string();

Expand Down
Loading