Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 219 additions & 0 deletions staged/src-tauri/examples/acp_stream_probe.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};

use acp_client::{AcpDriver, AgentDriver, MessageWriter, Store};
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use tokio_util::sync::CancellationToken;

struct NoopStore;

#[async_trait]
impl Store for NoopStore {
fn set_agent_session_id(&self, session_id: &str, agent_session_id: &str) -> Result<(), String> {
println!(
"[probe] set_agent_session_id: session_id={} agent_session_id={}",
session_id, agent_session_id
);
Ok(())
}
}

#[derive(Default)]
struct ProbeState {
tool_call_counts: HashMap<String, usize>,
total_tool_calls: usize,
total_tool_title_updates: usize,
total_tool_results: usize,
assistant_chunks: usize,
}

struct ProbeWriter {
state: Arc<Mutex<ProbeState>>,
}

impl ProbeWriter {
fn new(state: Arc<Mutex<ProbeState>>) -> Self {
Self { state }
}

fn truncate(s: &str) -> String {
const MAX: usize = 180;
let mut out = String::new();
for (idx, ch) in s.chars().enumerate() {
if idx >= MAX {
out.push_str("...");
break;
}
out.push(ch);
}
out.replace('\n', "\\n")
}
}

#[async_trait]
impl MessageWriter for ProbeWriter {
async fn append_text(&self, text: &str) {
let mut state = self.state.lock().expect("probe state lock poisoned");
state.assistant_chunks += 1;
println!(
"[probe] assistant_chunk #{}: {}",
state.assistant_chunks,
Self::truncate(text)
);
}

async fn finalize(&self) {
println!("[probe] finalize");
}

async fn record_tool_call(&self, tool_call_id: &str, title: &str) {
let mut state = self.state.lock().expect("probe state lock poisoned");
state.total_tool_calls += 1;
let count_for_id = *state
.tool_call_counts
.entry(tool_call_id.to_string())
.and_modify(|count| *count += 1)
.or_insert(1);
println!(
"[probe] tool_call #{} id={} seen_for_id={} title={}",
state.total_tool_calls,
tool_call_id,
count_for_id,
Self::truncate(title)
);
}

async fn update_tool_call_title(&self, tool_call_id: &str, title: &str) {
let mut state = self.state.lock().expect("probe state lock poisoned");
state.total_tool_title_updates += 1;
println!(
"[probe] tool_call_update #{} id={} title={}",
state.total_tool_title_updates,
tool_call_id,
Self::truncate(title)
);
}

async fn record_tool_result(&self, content: &str) {
let mut state = self.state.lock().expect("probe state lock poisoned");
state.total_tool_results += 1;
println!(
"[probe] tool_result #{}: {}",
state.total_tool_results,
Self::truncate(content)
);
}
}

fn print_usage() {
eprintln!(
"Usage:
cargo run --manifest-path src-tauri/Cargo.toml --example acp_stream_probe -- \\
--provider <provider-id> --workdir <path> --prompt <text>

Defaults:
--provider codex
--workdir .
--prompt \"Run `echo hello` and summarize the output in one sentence.\""
);
}

fn parse_args() -> Result<(String, PathBuf, String)> {
let mut provider = "codex".to_string();
let mut workdir = PathBuf::from(".");
let mut prompt = "Run `echo hello` and summarize the output in one sentence.".to_string();

let mut args = std::env::args().skip(1).peekable();
while let Some(arg) = args.next() {
match arg.as_str() {
"--provider" => {
provider = args
.next()
.ok_or_else(|| anyhow!("missing value for --provider"))?;
}
"--workdir" => {
workdir = PathBuf::from(
args.next()
.ok_or_else(|| anyhow!("missing value for --workdir"))?,
);
}
"--prompt" => {
prompt = args
.next()
.ok_or_else(|| anyhow!("missing value for --prompt"))?;
}
"--help" | "-h" => {
print_usage();
std::process::exit(0);
}
other => {
return Err(anyhow!("unknown argument: {other}"));
}
}
}

Ok((provider, workdir, prompt))
}

fn main() -> Result<()> {
let (provider, workdir, prompt) = parse_args()?;
println!("[probe] provider={provider} workdir={}", workdir.display());
println!("[probe] prompt={}", ProbeWriter::truncate(&prompt));

let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
let local = tokio::task::LocalSet::new();

local.block_on(&rt, async move {
let driver = AcpDriver::new(&provider).map_err(|e| anyhow!(e))?;
let store = Arc::new(NoopStore) as Arc<dyn Store>;
let state = Arc::new(Mutex::new(ProbeState::default()));
let writer = Arc::new(ProbeWriter::new(Arc::clone(&state))) as Arc<dyn MessageWriter>;
let cancel_token = CancellationToken::new();

let result = driver
.run(
"probe-session",
&prompt,
&workdir,
&store,
&writer,
&cancel_token,
None,
)
.await;

let state = state.lock().expect("probe state lock poisoned");
println!("[probe] result={result:?}");
println!(
"[probe] summary: assistant_chunks={} tool_calls={} tool_call_updates={} tool_results={}",
state.assistant_chunks,
state.total_tool_calls,
state.total_tool_title_updates,
state.total_tool_results
);

let mut duplicate_ids: Vec<(&String, &usize)> = state
.tool_call_counts
.iter()
.filter(|(_, count)| **count > 1)
.collect();
duplicate_ids.sort_by(|(a, _), (b, _)| a.cmp(b));

if duplicate_ids.is_empty() {
println!("[probe] duplicate_tool_call_ids=none");
} else {
println!("[probe] duplicate_tool_call_ids:");
for (id, count) in duplicate_ids {
println!(" - id={id} count={count}");
}
}

Ok::<(), anyhow::Error>(())
})?;

Ok(())
}
77 changes: 75 additions & 2 deletions staged/src-tauri/src/agent/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ pub struct MessageWriter {
last_flush_at: Mutex<Instant>,
/// Maps external tool-call IDs → DB row IDs.
tool_call_rows: Mutex<HashMap<String, i64>>,
/// DB row id of the currently streaming tool result.
///
/// ACP can send multiple content updates for one tool call; we update
/// the same row instead of inserting duplicates.
current_tool_result_msg_id: Mutex<Option<i64>>,
}

impl MessageWriter {
Expand All @@ -50,6 +55,7 @@ impl MessageWriter {
current_text: Mutex::new(String::new()),
last_flush_at: Mutex::new(Instant::now()),
tool_call_rows: Mutex::new(HashMap::new()),
current_tool_result_msg_id: Mutex::new(None),
}
}

Expand Down Expand Up @@ -87,6 +93,14 @@ impl MessageWriter {
/// to maintain correct message ordering.
pub async fn record_tool_call(&self, tool_call_id: &str, title: &str) {
self.finalize().await;
*self.current_tool_result_msg_id.lock().await = None;

// Some providers may resend ToolCall for the same ID while streaming.
// Treat those as updates to the existing row.
if let Some(&row_id) = self.tool_call_rows.lock().await.get(tool_call_id) {
let _ = self.store.update_message_content(row_id, title);
return;
}

match self
.store
Expand All @@ -112,9 +126,19 @@ impl MessageWriter {

/// Record the result/output of a tool call.
pub async fn record_tool_result(&self, content: &str) {
let _ = self
let mut current_result_id = self.current_tool_result_msg_id.lock().await;
if let Some(id) = *current_result_id {
let _ = self.store.update_message_content(id, content);
return;
}

match self
.store
.add_session_message(&self.session_id, MessageRole::ToolResult, content);
.add_session_message(&self.session_id, MessageRole::ToolResult, content)
{
Ok(id) => *current_result_id = Some(id),
Err(e) => log::error!("Failed to insert tool_result message: {e}"),
}
}

// =====================================================================
Expand Down Expand Up @@ -181,3 +205,52 @@ impl acp_client::MessageWriter for MessageWriter {
self.record_tool_result(content).await
}
}

#[cfg(test)]
mod tests {
use std::path::Path;
use std::sync::Arc;

use super::MessageWriter;
use crate::store::{MessageRole, Session, Store};

fn setup_writer() -> (Arc<Store>, String, MessageWriter) {
let store = Arc::new(Store::in_memory().expect("in-memory store"));
let session = Session::new_running("test prompt", Path::new("."));
store.create_session(&session).expect("create session");
let writer = MessageWriter::new(session.id.clone(), Arc::clone(&store));
(store, session.id, writer)
}

#[tokio::test]
async fn record_tool_result_updates_existing_row_for_streaming_updates() {
let (store, session_id, writer) = setup_writer();

writer.record_tool_call("tc-1", "Run echo hello").await;
writer.record_tool_result("first chunk").await;
writer.record_tool_result("second chunk").await;

let messages = store
.get_session_messages(&session_id)
.expect("query messages");
assert_eq!(messages.len(), 2);
assert_eq!(messages[0].role, MessageRole::ToolCall);
assert_eq!(messages[1].role, MessageRole::ToolResult);
assert_eq!(messages[1].content, "second chunk");
}

#[tokio::test]
async fn record_tool_call_same_id_updates_instead_of_inserting() {
let (store, session_id, writer) = setup_writer();

writer.record_tool_call("tc-dup", "Run first title").await;
writer.record_tool_call("tc-dup", "Run updated title").await;

let messages = store
.get_session_messages(&session_id)
.expect("query messages");
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].role, MessageRole::ToolCall);
assert_eq!(messages[0].content, "Run updated title");
}
}