From 4a4a0c3c1e4b82327cc6f83eea7921797c181dc5 Mon Sep 17 00:00:00 2001 From: Dylan Hurd Date: Sun, 12 Apr 2026 13:30:17 -0700 Subject: [PATCH] fix(mcp) pause timer for elicitations --- .../src/elicitation_client_service.rs | 10 +- codex-rs/rmcp-client/src/rmcp_client.rs | 169 ++++++++++++++++-- 2 files changed, 165 insertions(+), 14 deletions(-) diff --git a/codex-rs/rmcp-client/src/elicitation_client_service.rs b/codex-rs/rmcp-client/src/elicitation_client_service.rs index 1d8fb954ccf7..49f11f0a7637 100644 --- a/codex-rs/rmcp-client/src/elicitation_client_service.rs +++ b/codex-rs/rmcp-client/src/elicitation_client_service.rs @@ -17,6 +17,7 @@ use serde_json::Value; use crate::logging_client_handler::LoggingClientHandler; use crate::rmcp_client::Elicitation; +use crate::rmcp_client::ElicitationPauseState; use crate::rmcp_client::ElicitationResponse; use crate::rmcp_client::SendElicitation; @@ -26,10 +27,15 @@ const MCP_PROGRESS_TOKEN_META_KEY: &str = "progressToken"; pub(crate) struct ElicitationClientService { handler: LoggingClientHandler, send_elicitation: Arc, + pause_state: ElicitationPauseState, } impl ElicitationClientService { - pub(crate) fn new(client_info: ClientInfo, send_elicitation: SendElicitation) -> Self { + pub(crate) fn new( + client_info: ClientInfo, + send_elicitation: SendElicitation, + pause_state: ElicitationPauseState, + ) -> Self { let send_elicitation = Arc::new(send_elicitation); Self { handler: LoggingClientHandler::new( @@ -37,6 +43,7 @@ impl ElicitationClientService { clone_send_elicitation(Arc::clone(&send_elicitation)), ), send_elicitation, + pause_state, } } @@ -47,6 +54,7 @@ impl ElicitationClientService { ) -> Result { let RequestContext { id, meta, .. } = context; let request = restore_context_meta(request, meta); + let _pause = self.pause_state.enter(); (self.send_elicitation)(id, request) .await .map_err(|err| rmcp::ErrorData::internal_error(err.to_string(), None)) diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 4f23526d1130..415354fee4e0 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -1,11 +1,15 @@ use std::borrow::Cow; use std::collections::HashMap; use std::ffi::OsString; +use std::future::Future; use std::io; use std::path::PathBuf; use std::process::Stdio; use std::sync::Arc; +use std::sync::atomic::AtomicUsize; +use std::sync::atomic::Ordering; use std::time::Duration; +use std::time::Instant; use anyhow::Result; use anyhow::anyhow; @@ -63,6 +67,7 @@ use tokio::io::AsyncBufReadExt; use tokio::io::BufReader; use tokio::process::Command; use tokio::sync::Mutex; +use tokio::sync::watch; use tokio::time; use tracing::info; use tracing::warn; @@ -410,6 +415,93 @@ struct InitializeContext { client_service: ElicitationClientService, } +#[derive(Clone)] +pub(crate) struct ElicitationPauseState { + active_count: Arc, + paused: watch::Sender, +} + +impl ElicitationPauseState { + fn new() -> Self { + let (paused, _rx) = watch::channel(false); + Self { + active_count: Arc::new(AtomicUsize::new(0)), + paused, + } + } + + pub(crate) fn enter(&self) -> ElicitationPauseGuard { + if self.active_count.fetch_add(1, Ordering::AcqRel) == 0 { + self.paused.send_replace(true); + } + ElicitationPauseGuard { + pause_state: self.clone(), + } + } + + fn subscribe(&self) -> watch::Receiver { + self.paused.subscribe() + } +} + +pub(crate) struct ElicitationPauseGuard { + pause_state: ElicitationPauseState, +} + +impl Drop for ElicitationPauseGuard { + fn drop(&mut self) { + if self.pause_state.active_count.fetch_sub(1, Ordering::AcqRel) == 1 { + self.pause_state.paused.send_replace(false); + } + } +} + +async fn active_time_timeout( + duration: Duration, + mut pause_state: watch::Receiver, + operation: Fut, +) -> std::result::Result +where + Fut: Future, +{ + let mut remaining = duration; + tokio::pin!(operation); + + loop { + if *pause_state.borrow_and_update() { + tokio::select! { + result = &mut operation => return Ok(result), + changed = pause_state.changed() => { + if changed.is_err() { + return time::timeout(remaining, operation).await.map_err(|_| ()); + } + let _paused = *pause_state.borrow_and_update(); + } + } + continue; + } + + let active_start = Instant::now(); + tokio::select! { + result = &mut operation => return Ok(result), + _ = time::sleep(remaining) => { + return Err(()); + } + changed = pause_state.changed() => { + if changed.is_err() { + return time::timeout(remaining, operation).await.map_err(|_| ()); + } + if *pause_state.borrow_and_update() { + remaining = remaining.saturating_sub(active_start.elapsed()); + if remaining.is_zero() { + return Err(()); + } + } + } + } + } +} + #[derive(Debug, thiserror::Error)] enum ClientOperationError { #[error(transparent)] @@ -472,6 +564,7 @@ pub struct RmcpClient { transport_recipe: TransportRecipe, initialize_context: Mutex>, session_recovery_lock: Mutex<()>, + elicitation_pause_state: ElicitationPauseState, } impl RmcpClient { @@ -500,6 +593,7 @@ impl RmcpClient { transport_recipe, initialize_context: Mutex::new(None), session_recovery_lock: Mutex::new(()), + elicitation_pause_state: ElicitationPauseState::new(), }) } @@ -528,6 +622,7 @@ impl RmcpClient { transport_recipe, initialize_context: Mutex::new(None), session_recovery_lock: Mutex::new(()), + elicitation_pause_state: ElicitationPauseState::new(), }) } @@ -539,7 +634,11 @@ impl RmcpClient { timeout: Option, send_elicitation: SendElicitation, ) -> Result { - let client_service = ElicitationClientService::new(params.clone(), send_elicitation); + let client_service = ElicitationClientService::new( + params.clone(), + send_elicitation, + self.elicitation_pause_state.clone(), + ); let pending_transport = { let mut guard = self.state.lock().await; match &mut *guard { @@ -1052,16 +1151,28 @@ impl RmcpClient { Fut: std::future::Future>, { let service = self.service().await?; - match Self::run_service_operation_once(Arc::clone(&service), label, timeout, &operation) - .await + match Self::run_service_operation_once( + Arc::clone(&service), + label, + timeout, + self.elicitation_pause_state.clone(), + &operation, + ) + .await { Ok(result) => Ok(result), Err(error) if Self::is_session_expired_404(&error) => { self.reinitialize_after_session_expiry(&service).await?; let recovered_service = self.service().await?; - Self::run_service_operation_once(recovered_service, label, timeout, &operation) - .await - .map_err(Into::into) + Self::run_service_operation_once( + recovered_service, + label, + timeout, + self.elicitation_pause_state.clone(), + &operation, + ) + .await + .map_err(Into::into) } Err(error) => Err(error.into()), } @@ -1071,6 +1182,7 @@ impl RmcpClient { service: Arc>, label: &str, timeout: Option, + pause_state: ElicitationPauseState, operation: &F, ) -> std::result::Result where @@ -1078,13 +1190,15 @@ impl RmcpClient { Fut: std::future::Future>, { match timeout { - Some(duration) => time::timeout(duration, operation(service)) - .await - .map_err(|_| ClientOperationError::Timeout { - label: label.to_string(), - duration, - })? - .map_err(ClientOperationError::from), + Some(duration) => { + active_time_timeout(duration, pause_state.subscribe(), operation(service)) + .await + .map_err(|_| ClientOperationError::Timeout { + label: label.to_string(), + duration, + })? + .map_err(ClientOperationError::from) + } None => operation(service).await.map_err(ClientOperationError::from), } } @@ -1207,3 +1321,32 @@ async fn create_oauth_transport_and_runtime( Ok((transport, runtime)) } + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use pretty_assertions::assert_eq; + use tokio::time; + + use super::*; + + #[tokio::test] + async fn active_time_timeout_pauses_while_elicitation_is_pending() { + let pause_state = ElicitationPauseState::new(); + let pause = pause_state.enter(); + tokio::spawn(async move { + time::sleep(Duration::from_millis(75)).await; + drop(pause); + }); + + let result = + active_time_timeout(Duration::from_millis(50), pause_state.subscribe(), async { + time::sleep(Duration::from_millis(90)).await; + "done" + }) + .await; + + assert_eq!(Ok("done"), result); + } +}