diff --git a/staged/src-tauri/examples/acp_stream_probe.rs b/staged/src-tauri/examples/acp_stream_probe.rs new file mode 100644 index 00000000..60dec1b5 --- /dev/null +++ b/staged/src-tauri/examples/acp_stream_probe.rs @@ -0,0 +1,219 @@ +use std::collections::HashMap; +use std::path::PathBuf; +use std::sync::{Arc, Mutex}; + +use acp_client::{AcpDriver, AgentDriver, MessageWriter, Store}; +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use tokio_util::sync::CancellationToken; + +struct NoopStore; + +#[async_trait] +impl Store for NoopStore { + fn set_agent_session_id(&self, session_id: &str, agent_session_id: &str) -> Result<(), String> { + println!( + "[probe] set_agent_session_id: session_id={} agent_session_id={}", + session_id, agent_session_id + ); + Ok(()) + } +} + +#[derive(Default)] +struct ProbeState { + tool_call_counts: HashMap, + total_tool_calls: usize, + total_tool_title_updates: usize, + total_tool_results: usize, + assistant_chunks: usize, +} + +struct ProbeWriter { + state: Arc>, +} + +impl ProbeWriter { + fn new(state: Arc>) -> Self { + Self { state } + } + + fn truncate(s: &str) -> String { + const MAX: usize = 180; + let mut out = String::new(); + for (idx, ch) in s.chars().enumerate() { + if idx >= MAX { + out.push_str("..."); + break; + } + out.push(ch); + } + out.replace('\n', "\\n") + } +} + +#[async_trait] +impl MessageWriter for ProbeWriter { + async fn append_text(&self, text: &str) { + let mut state = self.state.lock().expect("probe state lock poisoned"); + state.assistant_chunks += 1; + println!( + "[probe] assistant_chunk #{}: {}", + state.assistant_chunks, + Self::truncate(text) + ); + } + + async fn finalize(&self) { + println!("[probe] finalize"); + } + + async fn record_tool_call(&self, tool_call_id: &str, title: &str) { + let mut state = self.state.lock().expect("probe state lock poisoned"); + state.total_tool_calls += 1; + let count_for_id = *state + .tool_call_counts + .entry(tool_call_id.to_string()) + .and_modify(|count| *count += 1) + .or_insert(1); + println!( + "[probe] tool_call #{} id={} seen_for_id={} title={}", + state.total_tool_calls, + tool_call_id, + count_for_id, + Self::truncate(title) + ); + } + + async fn update_tool_call_title(&self, tool_call_id: &str, title: &str) { + let mut state = self.state.lock().expect("probe state lock poisoned"); + state.total_tool_title_updates += 1; + println!( + "[probe] tool_call_update #{} id={} title={}", + state.total_tool_title_updates, + tool_call_id, + Self::truncate(title) + ); + } + + async fn record_tool_result(&self, content: &str) { + let mut state = self.state.lock().expect("probe state lock poisoned"); + state.total_tool_results += 1; + println!( + "[probe] tool_result #{}: {}", + state.total_tool_results, + Self::truncate(content) + ); + } +} + +fn print_usage() { + eprintln!( + "Usage: + cargo run --manifest-path src-tauri/Cargo.toml --example acp_stream_probe -- \\ + --provider --workdir --prompt + +Defaults: + --provider codex + --workdir . + --prompt \"Run `echo hello` and summarize the output in one sentence.\"" + ); +} + +fn parse_args() -> Result<(String, PathBuf, String)> { + let mut provider = "codex".to_string(); + let mut workdir = PathBuf::from("."); + let mut prompt = "Run `echo hello` and summarize the output in one sentence.".to_string(); + + let mut args = std::env::args().skip(1).peekable(); + while let Some(arg) = args.next() { + match arg.as_str() { + "--provider" => { + provider = args + .next() + .ok_or_else(|| anyhow!("missing value for --provider"))?; + } + "--workdir" => { + workdir = PathBuf::from( + args.next() + .ok_or_else(|| anyhow!("missing value for --workdir"))?, + ); + } + "--prompt" => { + prompt = args + .next() + .ok_or_else(|| anyhow!("missing value for --prompt"))?; + } + "--help" | "-h" => { + print_usage(); + std::process::exit(0); + } + other => { + return Err(anyhow!("unknown argument: {other}")); + } + } + } + + Ok((provider, workdir, prompt)) +} + +fn main() -> Result<()> { + let (provider, workdir, prompt) = parse_args()?; + println!("[probe] provider={provider} workdir={}", workdir.display()); + println!("[probe] prompt={}", ProbeWriter::truncate(&prompt)); + + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build()?; + let local = tokio::task::LocalSet::new(); + + local.block_on(&rt, async move { + let driver = AcpDriver::new(&provider).map_err(|e| anyhow!(e))?; + let store = Arc::new(NoopStore) as Arc; + let state = Arc::new(Mutex::new(ProbeState::default())); + let writer = Arc::new(ProbeWriter::new(Arc::clone(&state))) as Arc; + let cancel_token = CancellationToken::new(); + + let result = driver + .run( + "probe-session", + &prompt, + &workdir, + &store, + &writer, + &cancel_token, + None, + ) + .await; + + let state = state.lock().expect("probe state lock poisoned"); + println!("[probe] result={result:?}"); + println!( + "[probe] summary: assistant_chunks={} tool_calls={} tool_call_updates={} tool_results={}", + state.assistant_chunks, + state.total_tool_calls, + state.total_tool_title_updates, + state.total_tool_results + ); + + let mut duplicate_ids: Vec<(&String, &usize)> = state + .tool_call_counts + .iter() + .filter(|(_, count)| **count > 1) + .collect(); + duplicate_ids.sort_by(|(a, _), (b, _)| a.cmp(b)); + + if duplicate_ids.is_empty() { + println!("[probe] duplicate_tool_call_ids=none"); + } else { + println!("[probe] duplicate_tool_call_ids:"); + for (id, count) in duplicate_ids { + println!(" - id={id} count={count}"); + } + } + + Ok::<(), anyhow::Error>(()) + })?; + + Ok(()) +} diff --git a/staged/src-tauri/src/agent/writer.rs b/staged/src-tauri/src/agent/writer.rs index b708b218..76e8c5e9 100644 --- a/staged/src-tauri/src/agent/writer.rs +++ b/staged/src-tauri/src/agent/writer.rs @@ -39,6 +39,11 @@ pub struct MessageWriter { last_flush_at: Mutex, /// Maps external tool-call IDs → DB row IDs. tool_call_rows: Mutex>, + /// DB row id of the currently streaming tool result. + /// + /// ACP can send multiple content updates for one tool call; we update + /// the same row instead of inserting duplicates. + current_tool_result_msg_id: Mutex>, } impl MessageWriter { @@ -50,6 +55,7 @@ impl MessageWriter { current_text: Mutex::new(String::new()), last_flush_at: Mutex::new(Instant::now()), tool_call_rows: Mutex::new(HashMap::new()), + current_tool_result_msg_id: Mutex::new(None), } } @@ -87,6 +93,14 @@ impl MessageWriter { /// to maintain correct message ordering. pub async fn record_tool_call(&self, tool_call_id: &str, title: &str) { self.finalize().await; + *self.current_tool_result_msg_id.lock().await = None; + + // Some providers may resend ToolCall for the same ID while streaming. + // Treat those as updates to the existing row. + if let Some(&row_id) = self.tool_call_rows.lock().await.get(tool_call_id) { + let _ = self.store.update_message_content(row_id, title); + return; + } match self .store @@ -112,9 +126,19 @@ impl MessageWriter { /// Record the result/output of a tool call. pub async fn record_tool_result(&self, content: &str) { - let _ = self + let mut current_result_id = self.current_tool_result_msg_id.lock().await; + if let Some(id) = *current_result_id { + let _ = self.store.update_message_content(id, content); + return; + } + + match self .store - .add_session_message(&self.session_id, MessageRole::ToolResult, content); + .add_session_message(&self.session_id, MessageRole::ToolResult, content) + { + Ok(id) => *current_result_id = Some(id), + Err(e) => log::error!("Failed to insert tool_result message: {e}"), + } } // ===================================================================== @@ -181,3 +205,52 @@ impl acp_client::MessageWriter for MessageWriter { self.record_tool_result(content).await } } + +#[cfg(test)] +mod tests { + use std::path::Path; + use std::sync::Arc; + + use super::MessageWriter; + use crate::store::{MessageRole, Session, Store}; + + fn setup_writer() -> (Arc, String, MessageWriter) { + let store = Arc::new(Store::in_memory().expect("in-memory store")); + let session = Session::new_running("test prompt", Path::new(".")); + store.create_session(&session).expect("create session"); + let writer = MessageWriter::new(session.id.clone(), Arc::clone(&store)); + (store, session.id, writer) + } + + #[tokio::test] + async fn record_tool_result_updates_existing_row_for_streaming_updates() { + let (store, session_id, writer) = setup_writer(); + + writer.record_tool_call("tc-1", "Run echo hello").await; + writer.record_tool_result("first chunk").await; + writer.record_tool_result("second chunk").await; + + let messages = store + .get_session_messages(&session_id) + .expect("query messages"); + assert_eq!(messages.len(), 2); + assert_eq!(messages[0].role, MessageRole::ToolCall); + assert_eq!(messages[1].role, MessageRole::ToolResult); + assert_eq!(messages[1].content, "second chunk"); + } + + #[tokio::test] + async fn record_tool_call_same_id_updates_instead_of_inserting() { + let (store, session_id, writer) = setup_writer(); + + writer.record_tool_call("tc-dup", "Run first title").await; + writer.record_tool_call("tc-dup", "Run updated title").await; + + let messages = store + .get_session_messages(&session_id) + .expect("query messages"); + assert_eq!(messages.len(), 1); + assert_eq!(messages[0].role, MessageRole::ToolCall); + assert_eq!(messages[0].content, "Run updated title"); + } +}