Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion codex-rs/rmcp-client/src/elicitation_client_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ use serde_json::Value;

use crate::logging_client_handler::LoggingClientHandler;
use crate::rmcp_client::Elicitation;
use crate::rmcp_client::ElicitationPauseState;
use crate::rmcp_client::ElicitationResponse;
use crate::rmcp_client::SendElicitation;

Expand All @@ -26,17 +27,23 @@ const MCP_PROGRESS_TOKEN_META_KEY: &str = "progressToken";
pub(crate) struct ElicitationClientService {
handler: LoggingClientHandler,
send_elicitation: Arc<SendElicitation>,
pause_state: ElicitationPauseState,
}

impl ElicitationClientService {
pub(crate) fn new(client_info: ClientInfo, send_elicitation: SendElicitation) -> Self {
pub(crate) fn new(
client_info: ClientInfo,
send_elicitation: SendElicitation,
pause_state: ElicitationPauseState,
) -> Self {
let send_elicitation = Arc::new(send_elicitation);
Self {
handler: LoggingClientHandler::new(
client_info,
clone_send_elicitation(Arc::clone(&send_elicitation)),
),
send_elicitation,
pause_state,
}
}

Expand All @@ -47,6 +54,7 @@ impl ElicitationClientService {
) -> Result<ElicitationResponse, rmcp::ErrorData> {
let RequestContext { id, meta, .. } = context;
let request = restore_context_meta(request, meta);
let _pause = self.pause_state.enter();
(self.send_elicitation)(id, request)
.await
.map_err(|err| rmcp::ErrorData::internal_error(err.to_string(), None))
Expand Down
169 changes: 156 additions & 13 deletions codex-rs/rmcp-client/src/rmcp_client.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::ffi::OsString;
use std::future::Future;
use std::io;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::time::Instant;

use anyhow::Result;
use anyhow::anyhow;
Expand Down Expand Up @@ -63,6 +67,7 @@ use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tokio::process::Command;
use tokio::sync::Mutex;
use tokio::sync::watch;
use tokio::time;
use tracing::info;
use tracing::warn;
Expand Down Expand Up @@ -410,6 +415,93 @@ struct InitializeContext {
client_service: ElicitationClientService,
}

#[derive(Clone)]
pub(crate) struct ElicitationPauseState {
active_count: Arc<AtomicUsize>,
paused: watch::Sender<bool>,
}

impl ElicitationPauseState {
fn new() -> Self {
let (paused, _rx) = watch::channel(false);
Self {
active_count: Arc::new(AtomicUsize::new(0)),
paused,
}
}

pub(crate) fn enter(&self) -> ElicitationPauseGuard {
if self.active_count.fetch_add(1, Ordering::AcqRel) == 0 {
self.paused.send_replace(true);
}
ElicitationPauseGuard {
pause_state: self.clone(),
}
}

fn subscribe(&self) -> watch::Receiver<bool> {
self.paused.subscribe()
}
}

pub(crate) struct ElicitationPauseGuard {
pause_state: ElicitationPauseState,
}

impl Drop for ElicitationPauseGuard {
fn drop(&mut self) {
if self.pause_state.active_count.fetch_sub(1, Ordering::AcqRel) == 1 {
self.pause_state.paused.send_replace(false);
}
}
}

async fn active_time_timeout<T, Fut>(
duration: Duration,
mut pause_state: watch::Receiver<bool>,
operation: Fut,
) -> std::result::Result<T, ()>
where
Fut: Future<Output = T>,
{
let mut remaining = duration;
tokio::pin!(operation);

loop {
if *pause_state.borrow_and_update() {
tokio::select! {
result = &mut operation => return Ok(result),
changed = pause_state.changed() => {
if changed.is_err() {
return time::timeout(remaining, operation).await.map_err(|_| ());
}
let _paused = *pause_state.borrow_and_update();
}
}
continue;
}

let active_start = Instant::now();
tokio::select! {
result = &mut operation => return Ok(result),
_ = time::sleep(remaining) => {
return Err(());
}
changed = pause_state.changed() => {
if changed.is_err() {
return time::timeout(remaining, operation).await.map_err(|_| ());
}
if *pause_state.borrow_and_update() {
remaining = remaining.saturating_sub(active_start.elapsed());
if remaining.is_zero() {
return Err(());
}
}
}
}
}
}

#[derive(Debug, thiserror::Error)]
enum ClientOperationError {
#[error(transparent)]
Expand Down Expand Up @@ -472,6 +564,7 @@ pub struct RmcpClient {
transport_recipe: TransportRecipe,
initialize_context: Mutex<Option<InitializeContext>>,
session_recovery_lock: Mutex<()>,
elicitation_pause_state: ElicitationPauseState,
}

impl RmcpClient {
Expand Down Expand Up @@ -500,6 +593,7 @@ impl RmcpClient {
transport_recipe,
initialize_context: Mutex::new(None),
session_recovery_lock: Mutex::new(()),
elicitation_pause_state: ElicitationPauseState::new(),
})
}

Expand Down Expand Up @@ -528,6 +622,7 @@ impl RmcpClient {
transport_recipe,
initialize_context: Mutex::new(None),
session_recovery_lock: Mutex::new(()),
elicitation_pause_state: ElicitationPauseState::new(),
})
}

Expand All @@ -539,7 +634,11 @@ impl RmcpClient {
timeout: Option<Duration>,
send_elicitation: SendElicitation,
) -> Result<InitializeResult> {
let client_service = ElicitationClientService::new(params.clone(), send_elicitation);
let client_service = ElicitationClientService::new(
params.clone(),
send_elicitation,
self.elicitation_pause_state.clone(),
);
let pending_transport = {
let mut guard = self.state.lock().await;
match &mut *guard {
Expand Down Expand Up @@ -1052,16 +1151,28 @@ impl RmcpClient {
Fut: std::future::Future<Output = std::result::Result<T, rmcp::service::ServiceError>>,
{
let service = self.service().await?;
match Self::run_service_operation_once(Arc::clone(&service), label, timeout, &operation)
.await
match Self::run_service_operation_once(
Arc::clone(&service),
label,
timeout,
self.elicitation_pause_state.clone(),
&operation,
)
.await
{
Ok(result) => Ok(result),
Err(error) if Self::is_session_expired_404(&error) => {
self.reinitialize_after_session_expiry(&service).await?;
let recovered_service = self.service().await?;
Self::run_service_operation_once(recovered_service, label, timeout, &operation)
.await
.map_err(Into::into)
Self::run_service_operation_once(
recovered_service,
label,
timeout,
self.elicitation_pause_state.clone(),
&operation,
)
.await
.map_err(Into::into)
}
Err(error) => Err(error.into()),
}
Expand All @@ -1071,20 +1182,23 @@ impl RmcpClient {
service: Arc<RunningService<RoleClient, ElicitationClientService>>,
label: &str,
timeout: Option<Duration>,
pause_state: ElicitationPauseState,
operation: &F,
) -> std::result::Result<T, ClientOperationError>
where
F: Fn(Arc<RunningService<RoleClient, ElicitationClientService>>) -> Fut,
Fut: std::future::Future<Output = std::result::Result<T, rmcp::service::ServiceError>>,
{
match timeout {
Some(duration) => time::timeout(duration, operation(service))
.await
.map_err(|_| ClientOperationError::Timeout {
label: label.to_string(),
duration,
})?
.map_err(ClientOperationError::from),
Some(duration) => {
active_time_timeout(duration, pause_state.subscribe(), operation(service))
.await
.map_err(|_| ClientOperationError::Timeout {
label: label.to_string(),
duration,
})?
.map_err(ClientOperationError::from)
}
None => operation(service).await.map_err(ClientOperationError::from),
}
}
Expand Down Expand Up @@ -1207,3 +1321,32 @@ async fn create_oauth_transport_and_runtime(

Ok((transport, runtime))
}

#[cfg(test)]
mod tests {
use std::time::Duration;

use pretty_assertions::assert_eq;
use tokio::time;

use super::*;

#[tokio::test]
async fn active_time_timeout_pauses_while_elicitation_is_pending() {
let pause_state = ElicitationPauseState::new();
let pause = pause_state.enter();
tokio::spawn(async move {
time::sleep(Duration::from_millis(75)).await;
drop(pause);
});

let result =
active_time_timeout(Duration::from_millis(50), pause_state.subscribe(), async {
time::sleep(Duration::from_millis(90)).await;
"done"
})
.await;

assert_eq!(Ok("done"), result);
}
}
Loading