From 23500597899268530b837759e1a7a748c100835d Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Fri, 27 Feb 2026 18:24:21 -0800 Subject: [PATCH 01/17] app-server: Add transport for remote control --- codex-rs/Cargo.lock | 3 + codex-rs/app-server/Cargo.toml | 3 + codex-rs/app-server/README.md | 1 + codex-rs/app-server/src/app_server_tracing.rs | 1 + codex-rs/app-server/src/in_process.rs | 11 +- codex-rs/app-server/src/lib.rs | 36 +- codex-rs/app-server/src/main.rs | 2 +- codex-rs/app-server/src/message_processor.rs | 10 +- .../src/message_processor/tracing_tests.rs | 10 +- codex-rs/app-server/src/transport/mod.rs | 111 +- .../remote_control/client_tracker.rs | 422 +++++++ .../src/transport/remote_control/enroll.rs | 399 ++++++ .../src/transport/remote_control/mod.rs | 59 + .../src/transport/remote_control/protocol.rs | 139 +++ .../src/transport/remote_control/tests.rs | 1090 ++++++++++++++++ .../src/transport/remote_control/websocket.rs | 1094 +++++++++++++++++ codex-rs/app-server/src/transport/stdio.rs | 4 +- .../app-server/src/transport/websocket.rs | 7 +- codex-rs/cli/src/main.rs | 8 +- codex-rs/core/config.schema.json | 6 + codex-rs/features/src/lib.rs | 8 + codex-rs/features/src/tests.rs | 6 + .../0023_remote_control_enrollments.sql | 8 + codex-rs/state/src/runtime.rs | 1 + codex-rs/state/src/runtime/remote_control.rs | 197 +++ 25 files changed, 3548 insertions(+), 88 deletions(-) create mode 100644 codex-rs/app-server/src/transport/remote_control/client_tracker.rs create mode 100644 codex-rs/app-server/src/transport/remote_control/enroll.rs create mode 100644 codex-rs/app-server/src/transport/remote_control/mod.rs create mode 100644 codex-rs/app-server/src/transport/remote_control/protocol.rs create mode 100644 codex-rs/app-server/src/transport/remote_control/tests.rs create mode 100644 codex-rs/app-server/src/transport/remote_control/websocket.rs create mode 100644 codex-rs/state/migrations/0023_remote_control_enrollments.sql create mode 100644 codex-rs/state/src/runtime/remote_control.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index c5846fbe1988..681187b27f3c 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1423,9 +1423,11 @@ dependencies = [ "codex-utils-cli", "codex-utils-json-to-toml", "codex-utils-pty", + "codex-utils-rustls-provider", "constant_time_eq", "core_test_support", "futures", + "gethostname", "hmac", "jsonwebtoken", "opentelemetry", @@ -1448,6 +1450,7 @@ dependencies = [ "tracing", "tracing-opentelemetry", "tracing-subscriber", + "url", "uuid", "wiremock", ] diff --git a/codex-rs/app-server/Cargo.toml b/codex-rs/app-server/Cargo.toml index ee12e87acb4e..74ac27955073 100644 --- a/codex-rs/app-server/Cargo.toml +++ b/codex-rs/app-server/Cargo.toml @@ -52,10 +52,12 @@ codex-state = { workspace = true } codex-tools = { workspace = true } codex-utils-absolute-path = { workspace = true } codex-utils-json-to-toml = { workspace = true } +codex-utils-rustls-provider = { workspace = true } chrono = { workspace = true } clap = { workspace = true, features = ["derive"] } constant_time_eq = { workspace = true } futures = { workspace = true } +gethostname = { workspace = true } hmac = { workspace = true } jsonwebtoken = { workspace = true } owo-colors = { workspace = true, features = ["supports-colors"] } @@ -76,6 +78,7 @@ tokio-util = { workspace = true } tokio-tungstenite = { workspace = true } tracing = { workspace = true, features = ["log"] } tracing-subscriber = { workspace = true, features = ["env-filter", "fmt", "json"] } +url = { workspace = true } uuid = { workspace = true, features = ["serde", "v7"] } [dev-dependencies] diff --git a/codex-rs/app-server/README.md b/codex-rs/app-server/README.md index cabea022a1e3..4fed3052c815 100644 --- a/codex-rs/app-server/README.md +++ b/codex-rs/app-server/README.md @@ -25,6 +25,7 @@ Supported transports: - stdio (`--listen stdio://`, default): newline-delimited JSON (JSONL) - websocket (`--listen ws://IP:PORT`): one JSON-RPC message per websocket text frame (**experimental / unsupported**) +- off (`--listen off`): do not expose a local transport When running with `--listen ws://IP:PORT`, the same listener also serves basic HTTP health probes: diff --git a/codex-rs/app-server/src/app_server_tracing.rs b/codex-rs/app-server/src/app_server_tracing.rs index 26fe8ca99971..b06a8e52c48d 100644 --- a/codex-rs/app-server/src/app_server_tracing.rs +++ b/codex-rs/app-server/src/app_server_tracing.rs @@ -86,6 +86,7 @@ fn transport_name(transport: AppServerTransport) -> &'static str { match transport { AppServerTransport::Stdio => "stdio", AppServerTransport::WebSocket { .. } => "websocket", + AppServerTransport::Off => "off", } } diff --git a/codex-rs/app-server/src/in_process.rs b/codex-rs/app-server/src/in_process.rs index 0405c7225909..248ff85badf4 100644 --- a/codex-rs/app-server/src/in_process.rs +++ b/codex-rs/app-server/src/in_process.rs @@ -74,6 +74,7 @@ use codex_app_server_protocol::Result; use codex_app_server_protocol::ServerNotification; use codex_app_server_protocol::ServerRequest; use codex_arg0::Arg0DispatchPaths; +use codex_core::AuthManager; use codex_core::config::Config; use codex_core::config_loader::CloudRequirementsLoader; use codex_core::config_loader::LoaderOverrides; @@ -377,6 +378,14 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle { } }); + let auth_manager = AuthManager::shared( + args.config.codex_home.clone(), + args.enable_codex_api_key_env, + args.config.cli_auth_credentials_store_mode, + ); + auth_manager + .set_forced_chatgpt_workspace_id(args.config.forced_chatgpt_workspace_id.clone()); + let processor_outgoing = Arc::clone(&outgoing_message_sender); let (processor_tx, mut processor_rx) = mpsc::channel::(channel_capacity); let mut processor_handle = tokio::spawn(async move { @@ -392,7 +401,7 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle { log_db: None, config_warnings: args.config_warnings, session_source: args.session_source, - enable_codex_api_key_env: args.enable_codex_api_key_env, + auth_manager, }); let mut thread_created_rx = processor.thread_created_receiver(); let mut session = ConnectionSessionState::default(); diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index e85a8f1290b1..31c0a9fe739c 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -8,6 +8,7 @@ use codex_core::config::ConfigBuilder; use codex_core::config_loader::CloudRequirementsLoader; use codex_core::config_loader::ConfigLayerStackOrdering; use codex_core::config_loader::LoaderOverrides; +use codex_features::Feature; use codex_utils_cli::CliConfigOverrides; use std::collections::HashMap; use std::collections::HashSet; @@ -29,6 +30,7 @@ use crate::transport::OutboundConnectionState; use crate::transport::TransportEvent; use crate::transport::auth::policy_from_settings; use crate::transport::route_outgoing_envelope; +use crate::transport::start_remote_control; use crate::transport::start_stdio_connection; use crate::transport::start_websocket_acceptor; use codex_app_server_protocol::ConfigLayerSource; @@ -499,13 +501,13 @@ pub async fn run_main_with_transport( let feedback_layer = feedback.logger_layer(); let feedback_metadata_layer = feedback.metadata_layer(); - let log_db = codex_state::StateRuntime::init( + let state_db = codex_state::StateRuntime::init( config.sqlite_home.clone(), config.model_provider_id.clone(), ) .await - .ok() - .map(log_db::start); + .ok(); + let log_db = state_db.clone().map(log_db::start); let log_db_layer = log_db .clone() .map(|layer| layer.with_filter(Targets::new().with_default(Level::TRACE))); @@ -548,6 +550,32 @@ pub async fn run_main_with_transport( .await?; transport_accept_handles.push(accept_handle); } + AppServerTransport::Off => {} + } + + let auth_manager = AuthManager::shared( + config.codex_home.clone(), + /*enable_codex_api_key_env*/ false, + config.cli_auth_credentials_store_mode, + ); + auth_manager.set_forced_chatgpt_workspace_id(config.forced_chatgpt_workspace_id.clone()); + + if config.features.enabled(Feature::RemoteControl) { + let accept_handle = start_remote_control( + config.chatgpt_base_url.clone(), + state_db.clone(), + auth_manager.clone(), + transport_event_tx.clone(), + transport_shutdown_token.clone(), + ) + .await?; + transport_accept_handles.push(accept_handle); + } + if transport_accept_handles.is_empty() { + return Err(std::io::Error::new( + ErrorKind::InvalidInput, + "no transport configured; use --listen or enable remote control", + )); } let outbound_handle = tokio::spawn(async move { @@ -622,7 +650,7 @@ pub async fn run_main_with_transport( log_db, config_warnings, session_source, - enable_codex_api_key_env: false, + auth_manager, }); let mut thread_created_rx = processor.thread_created_receiver(); let mut running_turn_count_rx = processor.subscribe_running_assistant_turn_count(); diff --git a/codex-rs/app-server/src/main.rs b/codex-rs/app-server/src/main.rs index fa95f973ea59..9a23680fb9ce 100644 --- a/codex-rs/app-server/src/main.rs +++ b/codex-rs/app-server/src/main.rs @@ -16,7 +16,7 @@ const MANAGED_CONFIG_PATH_ENV_VAR: &str = "CODEX_APP_SERVER_MANAGED_CONFIG_PATH" #[derive(Debug, Parser)] struct AppServerArgs { /// Transport endpoint URL. Supported values: `stdio://` (default), - /// `ws://IP:PORT`. + /// `ws://IP:PORT`, `off`. #[arg( long = "listen", value_name = "URL", diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index 301ef9992c72..97d519fb79d4 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -187,7 +187,7 @@ pub(crate) struct MessageProcessorArgs { pub(crate) log_db: Option, pub(crate) config_warnings: Vec, pub(crate) session_source: SessionSource, - pub(crate) enable_codex_api_key_env: bool, + pub(crate) auth_manager: Arc, } impl MessageProcessor { @@ -206,13 +206,8 @@ impl MessageProcessor { log_db, config_warnings, session_source, - enable_codex_api_key_env, + auth_manager, } = args; - let auth_manager = AuthManager::shared( - config.codex_home.clone(), - enable_codex_api_key_env, - config.cli_auth_credentials_store_mode, - ); let thread_manager = Arc::new(ThreadManager::new( config.as_ref(), auth_manager.clone(), @@ -224,7 +219,6 @@ impl MessageProcessor { }, environment_manager, )); - auth_manager.set_forced_chatgpt_workspace_id(config.forced_chatgpt_workspace_id.clone()); auth_manager.set_external_auth_refresher(Arc::new(ExternalAuthRefreshBridge { outgoing: outgoing.clone(), })); diff --git a/codex-rs/app-server/src/message_processor/tracing_tests.rs b/codex-rs/app-server/src/message_processor/tracing_tests.rs index d2fe9c23d7db..4665653ef149 100644 --- a/codex-rs/app-server/src/message_processor/tracing_tests.rs +++ b/codex-rs/app-server/src/message_processor/tracing_tests.rs @@ -20,6 +20,7 @@ use codex_app_server_protocol::TurnStartParams; use codex_app_server_protocol::TurnStartResponse; use codex_app_server_protocol::UserInput; use codex_arg0::Arg0DispatchPaths; +use codex_core::AuthManager; use codex_core::config::Config; use codex_core::config::ConfigBuilder; use codex_core::config_loader::CloudRequirementsLoader; @@ -231,6 +232,13 @@ fn build_test_processor( MessageProcessor, mpsc::Receiver, ) { + let auth_manager = AuthManager::shared( + config.codex_home.clone(), + /*enable_codex_api_key_env*/ false, + config.cli_auth_credentials_store_mode, + ); + auth_manager.set_forced_chatgpt_workspace_id(config.forced_chatgpt_workspace_id.clone()); + let (outgoing_tx, outgoing_rx) = mpsc::channel(16); let outgoing = Arc::new(OutgoingMessageSender::new(outgoing_tx)); let processor = MessageProcessor::new(MessageProcessorArgs { @@ -245,7 +253,7 @@ fn build_test_processor( log_db: None, config_warnings: Vec::new(), session_source: SessionSource::VSCode, - enable_codex_api_key_env: false, + auth_manager, }); (processor, outgoing_rx) } diff --git a/codex-rs/app-server/src/transport/mod.rs b/codex-rs/app-server/src/transport/mod.rs index c0653b903bac..504a02f23c70 100644 --- a/codex-rs/app-server/src/transport/mod.rs +++ b/codex-rs/app-server/src/transport/mod.rs @@ -17,6 +17,7 @@ use std::str::FromStr; use std::sync::Arc; use std::sync::RwLock; use std::sync::atomic::AtomicBool; +use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; @@ -28,9 +29,11 @@ use tracing::warn; /// plenty for an interactive CLI. pub(crate) const CHANNEL_CAPACITY: usize = 128; +mod remote_control; mod stdio; mod websocket; +pub(crate) use remote_control::start_remote_control; pub(crate) use stdio::start_stdio_connection; pub(crate) use websocket::start_websocket_acceptor; @@ -38,6 +41,7 @@ pub(crate) use websocket::start_websocket_acceptor; pub enum AppServerTransport { Stdio, WebSocket { bind_address: SocketAddr }, + Off, } #[derive(Debug, Clone, Eq, PartialEq)] @@ -51,7 +55,7 @@ impl std::fmt::Display for AppServerTransportParseError { match self { AppServerTransportParseError::UnsupportedListenUrl(listen_url) => write!( f, - "unsupported --listen URL `{listen_url}`; expected `stdio://` or `ws://IP:PORT`" + "unsupported --listen URL `{listen_url}`; expected `stdio://`, `ws://IP:PORT`, or `off`" ), AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!( f, @@ -71,6 +75,10 @@ impl AppServerTransport { return Ok(Self::Stdio); } + if listen_url == "off" { + return Ok(Self::Off); + } + if let Some(socket_addr) = listen_url.strip_prefix("ws://") { let bind_address = socket_addr.parse::().map_err(|_| { AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string()) @@ -166,6 +174,12 @@ impl OutboundConnectionState { } } +static CONNECTION_ID_COUNTER: AtomicU64 = AtomicU64::new(0); + +fn next_connection_id() -> ConnectionId { + ConnectionId(CONNECTION_ID_COUNTER.fetch_add(1, Ordering::Relaxed)) +} + async fn forward_incoming_message( transport_event_tx: &mpsc::Sender, writer: &mpsc::Sender, @@ -378,8 +392,11 @@ pub(crate) async fn route_outgoing_envelope( #[cfg(test)] mod tests { use super::*; - use crate::error_code::OVERLOADED_ERROR_CODE; use codex_app_server_protocol::ConfigWarningNotification; + use codex_app_server_protocol::JSONRPCNotification; + use codex_app_server_protocol::JSONRPCRequest; + use codex_app_server_protocol::JSONRPCResponse; + use codex_app_server_protocol::RequestId; use codex_app_server_protocol::ServerNotification; use codex_utils_absolute_path::AbsolutePathBuf; use pretty_assertions::assert_eq; @@ -393,41 +410,10 @@ mod tests { } #[test] - fn app_server_transport_parses_stdio_listen_url() { - let transport = AppServerTransport::from_listen_url(AppServerTransport::DEFAULT_LISTEN_URL) - .expect("stdio listen URL should parse"); - assert_eq!(transport, AppServerTransport::Stdio); - } - - #[test] - fn app_server_transport_parses_websocket_listen_url() { - let transport = AppServerTransport::from_listen_url("ws://127.0.0.1:1234") - .expect("websocket listen URL should parse"); - assert_eq!( - transport, - AppServerTransport::WebSocket { - bind_address: "127.0.0.1:1234".parse().expect("valid socket address"), - } - ); - } - - #[test] - fn app_server_transport_rejects_invalid_websocket_listen_url() { - let err = AppServerTransport::from_listen_url("ws://localhost:1234") - .expect_err("hostname bind address should be rejected"); + fn listen_off_parses_as_off_transport() { assert_eq!( - err.to_string(), - "invalid websocket --listen URL `ws://localhost:1234`; expected `ws://IP:PORT`" - ); - } - - #[test] - fn app_server_transport_rejects_unsupported_listen_url() { - let err = AppServerTransport::from_listen_url("http://127.0.0.1:1234") - .expect_err("unsupported scheme should fail"); - assert_eq!( - err.to_string(), - "unsupported --listen URL `http://127.0.0.1:1234`; expected `stdio://` or `ws://IP:PORT`" + AppServerTransport::from_listen_url("off"), + Ok(AppServerTransport::Off) ); } @@ -437,11 +423,10 @@ mod tests { let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1); let (writer_tx, mut writer_rx) = mpsc::channel(1); - let first_message = - JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification { - method: "initialized".to_string(), - params: None, - }); + let first_message = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); transport_event_tx .send(TransportEvent::IncomingMessage { connection_id, @@ -450,8 +435,8 @@ mod tests { .await .expect("queue should accept first message"); - let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { - id: codex_app_server_protocol::RequestId::Integer(7), + let request = JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(7), method: "config/read".to_string(), params: Some(json!({ "includeLayers": false })), trace: None, @@ -499,11 +484,10 @@ mod tests { let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1); let (writer_tx, _writer_rx) = mpsc::channel(1); - let first_message = - JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification { - method: "initialized".to_string(), - params: None, - }); + let first_message = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); transport_event_tx .send(TransportEvent::IncomingMessage { connection_id, @@ -512,8 +496,8 @@ mod tests { .await .expect("queue should accept first message"); - let response = JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse { - id: codex_app_server_protocol::RequestId::Integer(7), + let response = JSONRPCMessage::Response(JSONRPCResponse { + id: RequestId::Integer(7), result: json!({"ok": true}), }); let transport_event_tx_for_enqueue = transport_event_tx.clone(); @@ -553,11 +537,10 @@ mod tests { match forwarded_event { TransportEvent::IncomingMessage { connection_id: queued_connection_id, - message: - JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse { id, result }), + message: JSONRPCMessage::Response(JSONRPCResponse { id, result }), } => { assert_eq!(queued_connection_id, connection_id); - assert_eq!(id, codex_app_server_protocol::RequestId::Integer(7)); + assert_eq!(id, RequestId::Integer(7)); assert_eq!(result, json!({"ok": true})); } _ => panic!("expected forwarded response message"), @@ -573,12 +556,10 @@ mod tests { transport_event_tx .send(TransportEvent::IncomingMessage { connection_id, - message: JSONRPCMessage::Notification( - codex_app_server_protocol::JSONRPCNotification { - method: "initialized".to_string(), - params: None, - }, - ), + message: JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }), }) .await .expect("transport queue should accept first message"); @@ -597,15 +578,15 @@ mod tests { .await .expect("writer queue should accept first message"); - let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { - id: codex_app_server_protocol::RequestId::Integer(7), + let request = JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(7), method: "config/read".to_string(), params: Some(json!({ "includeLayers": false })), trace: None, }); - let enqueue_result = tokio::time::timeout( - std::time::Duration::from_millis(100), + let enqueue_result = timeout( + Duration::from_millis(100), enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request), ) .await @@ -781,7 +762,7 @@ mod tests { OutgoingEnvelope::ToConnection { connection_id, message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval { - request_id: codex_app_server_protocol::RequestId::Integer(1), + request_id: RequestId::Integer(1), params: codex_app_server_protocol::CommandExecutionRequestApprovalParams { thread_id: "thr_123".to_string(), turn_id: "turn_123".to_string(), @@ -843,7 +824,7 @@ mod tests { OutgoingEnvelope::ToConnection { connection_id, message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval { - request_id: codex_app_server_protocol::RequestId::Integer(1), + request_id: RequestId::Integer(1), params: codex_app_server_protocol::CommandExecutionRequestApprovalParams { thread_id: "thr_123".to_string(), turn_id: "turn_123".to_string(), diff --git a/codex-rs/app-server/src/transport/remote_control/client_tracker.rs b/codex-rs/app-server/src/transport/remote_control/client_tracker.rs new file mode 100644 index 000000000000..f23a1a441b24 --- /dev/null +++ b/codex-rs/app-server/src/transport/remote_control/client_tracker.rs @@ -0,0 +1,422 @@ +use super::CHANNEL_CAPACITY; +use super::TransportEvent; +use super::next_connection_id; +use super::protocol::ClientEnvelope; +pub use super::protocol::ClientEvent; +pub use super::protocol::ClientId; +use super::protocol::PongStatus; +use super::protocol::ServerEvent; +use crate::outgoing_message::ConnectionId; +use crate::outgoing_message::QueuedOutgoingMessage; +use crate::transport::remote_control::QueuedServerEnvelope; +use codex_app_server_protocol::JSONRPCMessage; +use std::collections::HashMap; +use tokio::sync::mpsc; +use tokio::sync::watch; +use tokio::task::JoinSet; +use tokio::time::Duration; +use tokio::time::Instant; +use tokio_util::sync::CancellationToken; + +const REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT: Duration = Duration::from_secs(10 * 60); +pub(crate) const REMOTE_CONTROL_IDLE_SWEEP_INTERVAL: Duration = Duration::from_secs(30); + +#[derive(Debug)] +pub(crate) struct Stopped; + +struct ClientState { + connection_id: ConnectionId, + disconnect_token: CancellationToken, + last_activity_at: Instant, + last_inbound_seq_id: Option, + status_tx: watch::Sender, +} + +pub(crate) struct ClientTracker { + clients: HashMap, + join_set: JoinSet<()>, + server_event_tx: mpsc::Sender, + transport_event_tx: mpsc::Sender, + shutdown_token: CancellationToken, +} + +impl ClientTracker { + pub(crate) fn new( + server_event_tx: mpsc::Sender, + transport_event_tx: mpsc::Sender, + shutdown_token: &CancellationToken, + ) -> Self { + Self { + clients: HashMap::new(), + join_set: JoinSet::new(), + server_event_tx, + transport_event_tx, + shutdown_token: shutdown_token.child_token(), + } + } + + pub(crate) async fn bookkeep_join_set(&mut self) { + while self.join_set.join_next().await.is_some() {} + futures::future::pending().await + } + + pub(crate) async fn shutdown(&mut self) { + self.shutdown_token.cancel(); + + while let Some(client_id) = self.clients.keys().next().cloned() { + let _ = self.close_client(&client_id).await; + } + + self.drain_join_set().await; + } + + async fn drain_join_set(&mut self) { + while self.join_set.join_next().await.is_some() {} + } + + pub(crate) async fn handle_message( + &mut self, + client_envelope: ClientEnvelope, + ) -> Result<(), Stopped> { + let ClientEnvelope { + client_id, + event, + seq_id, + cursor: _, + } = client_envelope; + match event { + ClientEvent::ClientMessage { message } => { + let is_initialize = remote_control_message_starts_connection(&message); + if let Some(seq_id) = seq_id + && let Some(client) = self.clients.get(&client_id) + && client + .last_inbound_seq_id + .is_some_and(|last_seq_id| last_seq_id >= seq_id) + && !is_initialize + { + return Ok(()); + } + + if is_initialize && self.clients.contains_key(&client_id) { + self.close_client(&client_id).await?; + } + + if let Some(connection_id) = self.clients.get_mut(&client_id).map(|client| { + client.last_activity_at = Instant::now(); + if let Some(seq_id) = seq_id { + client.last_inbound_seq_id = Some(seq_id); + } + client.connection_id + }) { + self.transport_event_tx + .send(TransportEvent::IncomingMessage { + connection_id, + message, + }) + .await + .map_err(|_| Stopped)?; + return Ok(()); + } + + if !is_initialize { + return Ok(()); + } + + let connection_id = next_connection_id(); + let (writer_tx, writer_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let disconnect_token = self.shutdown_token.child_token(); + self.transport_event_tx + .send(TransportEvent::ConnectionOpened { + connection_id, + writer: writer_tx, + disconnect_sender: Some(disconnect_token.clone()), + }) + .await + .map_err(|_| Stopped)?; + + let (status_tx, status_rx) = watch::channel(PongStatus::Active); + self.join_set.spawn(Self::run_client_outbound( + client_id.clone(), + self.server_event_tx.clone(), + writer_rx, + status_rx, + disconnect_token.clone(), + )); + self.clients.insert( + client_id, + ClientState { + connection_id, + disconnect_token, + last_activity_at: Instant::now(), + last_inbound_seq_id: seq_id, + status_tx, + }, + ); + self.send_transport_event(TransportEvent::IncomingMessage { + connection_id, + message, + }) + .await + } + ClientEvent::Ack => Ok(()), + ClientEvent::Ping => { + if let Some(client) = self.clients.get_mut(&client_id) { + client.last_activity_at = Instant::now(); + let _ = client.status_tx.send(PongStatus::Active); + return Ok(()); + } + + let server_event_tx = self.server_event_tx.clone(); + self.join_set.spawn(async move { + let server_envelope = QueuedServerEnvelope { + event: ServerEvent::Pong { + status: PongStatus::Unknown, + }, + client_id, + write_complete_tx: None, + }; + let _ = server_event_tx.send(server_envelope).await; + }); + Ok(()) + } + ClientEvent::ClientClosed => self.close_client(&client_id).await, + } + } + + async fn run_client_outbound( + client_id: ClientId, + server_event_tx: mpsc::Sender, + mut writer_rx: mpsc::Receiver, + mut status_rx: watch::Receiver, + disconnect_token: CancellationToken, + ) { + loop { + let (event, write_complete_tx) = tokio::select! { + _ = disconnect_token.cancelled() => { + break; + } + queued_message = writer_rx.recv() => { + let Some(queued_message) = queued_message else { + break; + }; + let event = ServerEvent::ServerMessage { + message: Box::new(queued_message.message), + }; + (event, queued_message.write_complete_tx) + } + changed = status_rx.changed() => { + if changed.is_err() { + break; + } + let event = ServerEvent::Pong { status: status_rx.borrow().clone() }; + (event, None) + } + }; + let send_result = tokio::select! { + _ = disconnect_token.cancelled() => { + break; + } + send_result = server_event_tx.send(QueuedServerEnvelope { + event, + client_id: client_id.clone(), + write_complete_tx, + }) => send_result, + }; + if send_result.is_err() { + break; + } + } + } + + pub(crate) async fn close_expired_clients(&mut self) -> Result, Stopped> { + let now = Instant::now(); + let expired_client_ids: Vec = self + .clients + .iter() + .filter_map(|(client_id, client)| { + (!remote_control_client_is_alive(client, now)).then_some(client_id.clone()) + }) + .collect(); + for client_id in &expired_client_ids { + self.close_client(client_id).await?; + } + Ok(expired_client_ids) + } + + async fn close_client(&mut self, client_id: &ClientId) -> Result<(), Stopped> { + let Some(client) = self.clients.remove(client_id) else { + return Ok(()); + }; + client.disconnect_token.cancel(); + self.send_transport_event(TransportEvent::ConnectionClosed { + connection_id: client.connection_id, + }) + .await + } + + async fn send_transport_event(&self, event: TransportEvent) -> Result<(), Stopped> { + self.transport_event_tx + .send(event) + .await + .map_err(|_| Stopped) + } +} + +fn remote_control_message_starts_connection(message: &JSONRPCMessage) -> bool { + matches!( + message, + JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { method, .. }) + if method == "initialize" + ) +} + +fn remote_control_client_is_alive(client: &ClientState, now: Instant) -> bool { + now.duration_since(client.last_activity_at) < REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::outgoing_message::OutgoingMessage; + use crate::transport::remote_control::protocol::ClientEnvelope; + use crate::transport::remote_control::protocol::ClientEvent; + use codex_app_server_protocol::ConfigWarningNotification; + use codex_app_server_protocol::JSONRPCRequest; + use codex_app_server_protocol::RequestId; + use codex_app_server_protocol::ServerNotification; + use pretty_assertions::assert_eq; + use serde_json::json; + use tokio::time::timeout; + + fn initialize_envelope(client_id: &str) -> ClientEnvelope { + ClientEnvelope { + event: ClientEvent::ClientMessage { + message: JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(1), + method: "initialize".to_string(), + params: Some(json!({ + "clientInfo": { + "name": "remote-test-client", + "version": "0.1.0" + } + })), + trace: None, + }), + }, + client_id: ClientId(client_id.to_string()), + seq_id: Some(0), + cursor: None, + } + } + + #[tokio::test] + async fn cancelled_outbound_task_emits_connection_closed() { + let (server_event_tx, _server_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let mut client_tracker = + ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token); + + client_tracker + .handle_message(initialize_envelope("client-1")) + .await + .expect("initialize should open client"); + + let (connection_id, disconnect_sender) = match transport_event_rx + .recv() + .await + .expect("connection opened should be sent") + { + TransportEvent::ConnectionOpened { + connection_id, + disconnect_sender: Some(disconnect_sender), + .. + } => (connection_id, disconnect_sender), + other => panic!("expected connection opened, got {other:?}"), + }; + match transport_event_rx + .recv() + .await + .expect("initialize should be forwarded") + { + TransportEvent::IncomingMessage { + connection_id: incoming_connection_id, + .. + } => assert_eq!(incoming_connection_id, connection_id), + other => panic!("expected incoming initialize, got {other:?}"), + } + + disconnect_sender.cancel(); + timeout(Duration::from_secs(1), client_tracker.bookkeep_join_set()) + .await + .expect_err("bookkeeping should process the closed task and stay pending"); + + match transport_event_rx + .recv() + .await + .expect("connection closed should be sent") + { + TransportEvent::ConnectionClosed { + connection_id: closed_connection_id, + } => assert_eq!(closed_connection_id, connection_id), + other => panic!("expected connection closed, got {other:?}"), + } + } + + #[tokio::test] + async fn shutdown_cancels_blocked_outbound_forwarding() { + let (server_event_tx, _server_event_rx) = mpsc::channel(1); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let mut client_tracker = + ClientTracker::new(server_event_tx.clone(), transport_event_tx, &shutdown_token); + + server_event_tx + .send(QueuedServerEnvelope { + event: ServerEvent::Pong { + status: PongStatus::Unknown, + }, + client_id: ClientId("queued-client".to_string()), + write_complete_tx: None, + }) + .await + .expect("server event queue should accept prefill"); + + client_tracker + .handle_message(initialize_envelope("client-1")) + .await + .expect("initialize should open client"); + + let writer = match transport_event_rx + .recv() + .await + .expect("connection opened should be sent") + { + TransportEvent::ConnectionOpened { writer, .. } => writer, + other => panic!("expected connection opened, got {other:?}"), + }; + let _ = transport_event_rx + .recv() + .await + .expect("initialize should be forwarded"); + + writer + .send(QueuedOutgoingMessage::new( + OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { + summary: "test".to_string(), + details: None, + path: None, + range: None, + }, + )), + )) + .await + .expect("writer should accept queued message"); + + timeout(Duration::from_secs(1), client_tracker.shutdown()) + .await + .expect("shutdown should not hang on blocked server forwarding"); + } +} diff --git a/codex-rs/app-server/src/transport/remote_control/enroll.rs b/codex-rs/app-server/src/transport/remote_control/enroll.rs new file mode 100644 index 000000000000..99b71c9fd759 --- /dev/null +++ b/codex-rs/app-server/src/transport/remote_control/enroll.rs @@ -0,0 +1,399 @@ +use super::protocol::EnrollRemoteServerRequest; +use super::protocol::EnrollRemoteServerResponse; +use super::protocol::RemoteControlTarget; +use axum::http::HeaderMap; +use codex_core::default_client::build_reqwest_client; +use codex_state::StateRuntime; +use gethostname::gethostname; +use std::io; +use std::io::ErrorKind; +use tracing::warn; + +const REMOTE_CONTROL_ENROLL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30); +const REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES: usize = 4096; + +const REQUEST_ID_HEADER: &str = "x-request-id"; +const OAI_REQUEST_ID_HEADER: &str = "x-oai-request-id"; +const CF_RAY_HEADER: &str = "cf-ray"; +pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id"; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) struct RemoteControlEnrollment { + pub(super) account_id: Option, + pub(super) server_id: String, + pub(super) server_name: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) struct RemoteControlConnectionAuth { + pub(super) bearer_token: String, + pub(super) account_id: Option, +} + +pub(super) async fn load_persisted_remote_control_enrollment( + state_db: Option<&StateRuntime>, + remote_control_target: &RemoteControlTarget, + account_id: Option<&str>, +) -> Option { + let state_db = state_db?; + let enrollment = match state_db + .get_remote_control_enrollment(&remote_control_target.websocket_url, account_id) + .await + { + Ok(enrollment) => enrollment, + Err(err) => { + warn!("{err}"); + return None; + } + }; + + enrollment.map(|(server_id, server_name)| RemoteControlEnrollment { + account_id: account_id.map(&str::to_string), + server_id, + server_name, + }) +} + +pub(super) async fn update_persisted_remote_control_enrollment( + state_db: Option<&StateRuntime>, + remote_control_target: &RemoteControlTarget, + account_id: Option<&str>, + enrollment: Option<&RemoteControlEnrollment>, +) -> io::Result<()> { + let Some(state_db) = state_db else { + return Ok(()); + }; + if let &Some(enrollment) = &enrollment + && enrollment.account_id.as_deref() != account_id + { + return Err(io::Error::other(format!( + "enrollment account_id does not match expected account_id `{account_id:?}`" + ))); + } + + if let Some(enrollment) = enrollment { + state_db + .upsert_remote_control_enrollment( + &remote_control_target.websocket_url, + account_id, + &enrollment.server_id, + &enrollment.server_name, + ) + .await + .map_err(io::Error::other) + } else { + state_db + .delete_remote_control_enrollment(&remote_control_target.websocket_url, account_id) + .await + .map(|_| ()) + .map_err(io::Error::other) + } +} + +pub(crate) fn preview_remote_control_response_body(body: &[u8]) -> String { + let body = String::from_utf8_lossy(body); + let trimmed = body.trim(); + if trimmed.is_empty() { + return "".to_string(); + } + if trimmed.len() <= REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES { + return trimmed.to_string(); + } + + let mut cut = REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES; + while !trimmed.is_char_boundary(cut) { + cut = cut.saturating_sub(1); + } + let mut truncated = trimmed[..cut].to_string(); + truncated.push_str("..."); + truncated +} + +pub(crate) fn format_headers(headers: &HeaderMap) -> String { + let request_id_str = headers + .get(REQUEST_ID_HEADER) + .or_else(|| headers.get(OAI_REQUEST_ID_HEADER)) + .map(|value| value.to_str().unwrap_or("").to_owned()) + .unwrap_or_else(|| "".to_owned()); + let cf_ray_str = headers + .get(CF_RAY_HEADER) + .map(|value| value.to_str().unwrap_or("").to_owned()) + .unwrap_or_else(|| "".to_owned()); + format!("request-id: {request_id_str}, cf-ray: {cf_ray_str}") +} + +pub(super) async fn enroll_remote_control_server( + remote_control_target: &RemoteControlTarget, + auth: &RemoteControlConnectionAuth, +) -> io::Result { + let enroll_url = &remote_control_target.enroll_url; + let server_name = gethostname().to_string_lossy().trim().to_string(); + let request = EnrollRemoteServerRequest { + name: server_name.clone(), + os: std::env::consts::OS, + arch: std::env::consts::ARCH, + app_server_version: env!("CARGO_PKG_VERSION"), + }; + let client = build_reqwest_client(); + let mut http_request = client + .post(enroll_url) + .timeout(REMOTE_CONTROL_ENROLL_TIMEOUT) + .bearer_auth(&auth.bearer_token) + .json(&request); + let account_id = auth.account_id.as_deref(); + if let Some(account_id) = account_id { + http_request = http_request.header(REMOTE_CONTROL_ACCOUNT_ID_HEADER, account_id); + } + + let response = http_request.send().await.map_err(|err| { + io::Error::other(format!( + "failed to enroll remote control server at `{enroll_url}`: {err}" + )) + })?; + let headers = response.headers().clone(); + let status = response.status(); + let body = response.bytes().await.map_err(|err| { + io::Error::other(format!( + "failed to read remote control enrollment response from `{enroll_url}`: {err}" + )) + })?; + let body_preview = preview_remote_control_response_body(&body); + if !status.is_success() { + let headers_str = format_headers(&headers); + let error_kind = if matches!(status.as_u16(), 401 | 403) { + ErrorKind::PermissionDenied + } else { + ErrorKind::Other + }; + return Err(io::Error::new( + error_kind, + format!( + "remote control server enrollment failed at `{enroll_url}`: HTTP {status}, {headers_str}, body: {body_preview}" + ), + )); + } + + let enrollment = serde_json::from_slice::(&body).map_err(|err| { + let headers_str = format_headers(&headers); + io::Error::other(format!( + "failed to parse remote control enrollment response from `{enroll_url}`: HTTP {status}, {headers_str}, body: {body_preview}, decode error: {err}" + )) + })?; + + Ok(RemoteControlEnrollment { + account_id: account_id.map(&str::to_string), + server_id: enrollment.server_id, + server_name, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::transport::remote_control::protocol::normalize_remote_control_url; + use codex_state::StateRuntime; + use pretty_assertions::assert_eq; + use serde_json::json; + use std::sync::Arc; + use tempfile::TempDir; + use tokio::io::AsyncWriteExt; + use tokio::net::TcpListener; + use tokio::net::TcpStream; + use tokio::time::Duration; + use tokio::time::timeout; + + async fn remote_control_state_runtime(codex_home: &TempDir) -> Arc { + StateRuntime::init(codex_home.path().to_path_buf(), "test-provider".to_string()) + .await + .expect("state runtime should initialize") + } + + #[tokio::test] + async fn persisted_remote_control_enrollment_round_trips_by_target_and_account() { + let codex_home = TempDir::new().expect("temp dir should create"); + let state_db = remote_control_state_runtime(&codex_home).await; + let first_target = normalize_remote_control_url("http://example.com/remote/control") + .expect("first target should parse"); + let second_target = normalize_remote_control_url("http://example.com/other/control") + .expect("second target should parse"); + let first_enrollment = RemoteControlEnrollment { + account_id: Some("account-a".to_string()), + server_id: "srv_e_first".to_string(), + server_name: "first-server".to_string(), + }; + let second_enrollment = RemoteControlEnrollment { + account_id: Some("account-a".to_string()), + server_id: "srv_e_second".to_string(), + server_name: "second-server".to_string(), + }; + + update_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &first_target, + Some("account-a"), + Some(&first_enrollment), + ) + .await + .expect("first enrollment should persist"); + update_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &second_target, + Some("account-a"), + Some(&second_enrollment), + ) + .await + .expect("second enrollment should persist"); + + assert_eq!( + load_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &first_target, + Some("account-a"), + ) + .await, + Some(first_enrollment.clone()) + ); + assert_eq!( + load_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &first_target, + Some("account-b"), + ) + .await, + None + ); + assert_eq!( + load_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &second_target, + Some("account-a"), + ) + .await, + Some(second_enrollment) + ); + } + + #[tokio::test] + async fn clearing_persisted_remote_control_enrollment_removes_only_matching_entry() { + let codex_home = TempDir::new().expect("temp dir should create"); + let state_db = remote_control_state_runtime(&codex_home).await; + let first_target = normalize_remote_control_url("http://example.com/remote/control") + .expect("first target should parse"); + let second_target = normalize_remote_control_url("http://example.com/other/control") + .expect("second target should parse"); + let first_enrollment = RemoteControlEnrollment { + account_id: Some("account-a".to_string()), + server_id: "srv_e_first".to_string(), + server_name: "first-server".to_string(), + }; + let second_enrollment = RemoteControlEnrollment { + account_id: Some("account-a".to_string()), + server_id: "srv_e_second".to_string(), + server_name: "second-server".to_string(), + }; + + update_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &first_target, + Some("account-a"), + Some(&first_enrollment), + ) + .await + .expect("first enrollment should persist"); + update_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &second_target, + Some("account-a"), + Some(&second_enrollment), + ) + .await + .expect("second enrollment should persist"); + + update_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &first_target, + Some("account-a"), + None, + ) + .await + .expect("matching enrollment should clear"); + + assert_eq!( + load_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &first_target, + Some("account-a"), + ) + .await, + None + ); + assert_eq!( + load_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &second_target, + Some("account-a"), + ) + .await, + Some(second_enrollment) + ); + } + + #[tokio::test] + async fn enroll_remote_control_server_parse_failure_includes_response_body() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = format!( + "http://{}/backend-api/", + listener + .local_addr() + .expect("listener should have a local addr") + ); + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let enroll_url = remote_control_target.enroll_url.clone(); + let response_body = json!({ + "error": "not enrolled", + }); + let expected_body = response_body.to_string(); + let server_task = tokio::spawn(async move { + let (stream, _) = timeout(Duration::from_secs(5), listener.accept()) + .await + .expect("HTTP request should arrive in time") + .expect("listener accept should succeed"); + respond_with_json(stream, response_body).await; + }); + + let err = enroll_remote_control_server( + &remote_control_target, + &RemoteControlConnectionAuth { + bearer_token: "Access Token".to_string(), + account_id: Some("account_id".to_string()), + }, + ) + .await + .expect_err("invalid response should fail to parse"); + + server_task.await.expect("server task should succeed"); + assert_eq!( + err.to_string(), + format!( + "failed to parse remote control enrollment response from `{enroll_url}`: HTTP 200 OK, request-id: , cf-ray: , body: {expected_body}, decode error: missing field `server_id` at line 1 column {}", + expected_body.len() + ) + ); + } + + async fn respond_with_json(mut stream: TcpStream, body: serde_json::Value) { + let body = body.to_string(); + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}", + body.len() + ); + stream + .write_all(response.as_bytes()) + .await + .expect("response should write"); + stream.flush().await.expect("response should flush"); + } +} diff --git a/codex-rs/app-server/src/transport/remote_control/mod.rs b/codex-rs/app-server/src/transport/remote_control/mod.rs new file mode 100644 index 000000000000..11d9302bb46e --- /dev/null +++ b/codex-rs/app-server/src/transport/remote_control/mod.rs @@ -0,0 +1,59 @@ +mod client_tracker; +mod enroll; +mod protocol; +mod websocket; + +use crate::transport::remote_control::websocket::load_remote_control_auth; + +pub use self::protocol::ClientId; +use self::protocol::ServerEvent; +use self::protocol::normalize_remote_control_url; +use self::websocket::run_remote_control_websocket_loop; +use super::CHANNEL_CAPACITY; +use super::TransportEvent; +use super::next_connection_id; +use codex_core::AuthManager; +use codex_state::StateRuntime; +use std::io; +use std::sync::Arc; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::task::JoinHandle; +use tokio_util::sync::CancellationToken; + +pub(super) struct QueuedServerEnvelope { + pub(super) event: ServerEvent, + pub(super) client_id: ClientId, + pub(super) write_complete_tx: Option>, +} + +pub(crate) async fn start_remote_control( + remote_control_url: String, + state_db: Option>, + auth_manager: Arc, + transport_event_tx: mpsc::Sender, + shutdown_token: CancellationToken, +) -> io::Result> { + let remote_control_target = normalize_remote_control_url(&remote_control_url)?; + validate_remote_control_auth(&auth_manager).await?; + + Ok(tokio::spawn(async move { + run_remote_control_websocket_loop( + remote_control_target, + state_db, + auth_manager, + transport_event_tx, + shutdown_token.child_token(), + ) + .await; + })) +} + +pub(crate) async fn validate_remote_control_auth( + auth_manager: &Arc, +) -> io::Result<()> { + load_remote_control_auth(auth_manager).await.map(|_| ()) +} + +#[cfg(test)] +mod tests; diff --git a/codex-rs/app-server/src/transport/remote_control/protocol.rs b/codex-rs/app-server/src/transport/remote_control/protocol.rs new file mode 100644 index 000000000000..5981bbac0fd2 --- /dev/null +++ b/codex-rs/app-server/src/transport/remote_control/protocol.rs @@ -0,0 +1,139 @@ +use crate::outgoing_message::OutgoingMessage; +use codex_app_server_protocol::JSONRPCMessage; +use serde::Deserialize; +use serde::Serialize; +use std::io; +use std::io::ErrorKind; +use url::Url; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) struct RemoteControlTarget { + pub(super) websocket_url: String, + pub(super) enroll_url: String, +} + +#[derive(Debug, Serialize)] +pub(super) struct EnrollRemoteServerRequest { + pub(super) name: String, + pub(super) os: &'static str, + pub(super) arch: &'static str, + pub(super) app_server_version: &'static str, +} + +#[derive(Debug, Deserialize)] +pub(super) struct EnrollRemoteServerResponse { + pub(super) server_id: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(transparent)] +pub struct ClientId(pub String); + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ClientEvent { + ClientMessage { message: JSONRPCMessage }, + Ack, + Ping, + ClientClosed, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub(crate) struct ClientEnvelope { + #[serde(flatten)] + pub(crate) event: ClientEvent, + #[serde(rename = "client_id", alias = "clientId")] + pub(crate) client_id: ClientId, + #[serde( + rename = "seq_id", + alias = "seqId", + skip_serializing_if = "Option::is_none" + )] + pub(crate) seq_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) cursor: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum PongStatus { + Active, + Unknown, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ServerEvent { + ServerMessage { + message: Box, + }, + #[allow(dead_code)] + Ack, + Pong { + status: PongStatus, + }, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "snake_case")] +pub(crate) struct ServerEnvelope { + #[serde(flatten)] + pub(crate) event: ServerEvent, + #[serde(rename = "client_id", alias = "clientId")] + pub(crate) client_id: ClientId, + #[serde(rename = "seq_id", alias = "seqId")] + pub(crate) seq_id: u64, +} + +pub(super) fn normalize_remote_control_url( + remote_control_url: &str, +) -> io::Result { + let map_url_parse_error = |err: url::ParseError| -> io::Error { + io::Error::new( + ErrorKind::InvalidInput, + format!("invalid remote control URL `{remote_control_url}`: {err}"), + ) + }; + let map_scheme_error = |_: ()| -> io::Error { + io::Error::new( + ErrorKind::InvalidInput, + format!( + "invalid remote control URL `{remote_control_url}`; expected absolute URL with http:// or https:// scheme" + ), + ) + }; + + let mut remote_control_url = Url::parse(remote_control_url).map_err(map_url_parse_error)?; + match remote_control_url.scheme() { + "https" | "http" => {} + _ => return Err(map_scheme_error(())), + } + if !remote_control_url.path().ends_with('/') { + let normalized_path = format!("{}/", remote_control_url.path()); + remote_control_url.set_path(&normalized_path); + } + + let mut enroll_url = remote_control_url + .join("wham/remote/control/server/enroll") + .map_err(map_url_parse_error)?; + let mut websocket_url = remote_control_url + .join("wham/remote/control/server") + .map_err(map_url_parse_error)?; + match remote_control_url.scheme() { + "https" => { + enroll_url.set_scheme("https").map_err(map_scheme_error)?; + websocket_url.set_scheme("wss").map_err(map_scheme_error)?; + } + "http" => { + enroll_url.set_scheme("http").map_err(map_scheme_error)?; + websocket_url.set_scheme("ws").map_err(map_scheme_error)?; + } + _ => return Err(map_scheme_error(())), + } + + Ok(RemoteControlTarget { + websocket_url: websocket_url.to_string(), + enroll_url: enroll_url.to_string(), + }) +} diff --git a/codex-rs/app-server/src/transport/remote_control/tests.rs b/codex-rs/app-server/src/transport/remote_control/tests.rs new file mode 100644 index 000000000000..7f9a436343e6 --- /dev/null +++ b/codex-rs/app-server/src/transport/remote_control/tests.rs @@ -0,0 +1,1090 @@ +use super::enroll::REMOTE_CONTROL_ACCOUNT_ID_HEADER; +use super::enroll::RemoteControlEnrollment; +use super::enroll::load_persisted_remote_control_enrollment; +use super::enroll::update_persisted_remote_control_enrollment; +use super::protocol::ClientEnvelope; +use super::protocol::ClientEvent; +use super::protocol::ClientId; +use super::protocol::normalize_remote_control_url; +use super::websocket::REMOTE_CONTROL_PROTOCOL_VERSION; +use super::*; +use crate::outgoing_message::OutgoingMessage; +use crate::outgoing_message::QueuedOutgoingMessage; +use crate::transport::CHANNEL_CAPACITY; +use crate::transport::TransportEvent; +use base64::Engine; +use codex_app_server_protocol::ConfigWarningNotification; +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::ServerNotification; +use codex_core::AuthManager; +use codex_core::CodexAuth; +use codex_core::test_support::auth_manager_from_auth; +use codex_core::test_support::auth_manager_from_auth_with_home; +use codex_state::StateRuntime; +use futures::SinkExt; +use futures::StreamExt; +use gethostname::gethostname; +use pretty_assertions::assert_eq; +use serde_json::json; +use std::collections::BTreeMap; +use std::sync::Arc; +use tempfile::TempDir; +use tokio::io::AsyncBufReadExt; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; +use tokio::io::BufReader; +use tokio::net::TcpListener; +use tokio::net::TcpStream; +use tokio::sync::mpsc; +use tokio::time::Duration; +use tokio::time::timeout; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::accept_async; +use tokio_tungstenite::accept_hdr_async; +use tokio_tungstenite::tungstenite; +use tokio_util::sync::CancellationToken; + +fn remote_control_auth_manager() -> Arc { + auth_manager_from_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing()) +} + +fn remote_control_auth_manager_with_home(codex_home: &TempDir) -> Arc { + auth_manager_from_auth_with_home( + CodexAuth::create_dummy_chatgpt_auth_for_testing(), + codex_home.path().to_path_buf(), + ) +} + +async fn remote_control_state_runtime(codex_home: &TempDir) -> Arc { + StateRuntime::init(codex_home.path().to_path_buf(), "test-provider".to_string()) + .await + .expect("state runtime should initialize") +} + +#[tokio::test] +async fn remote_control_transport_manages_virtual_clients_and_routes_messages() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = format!( + "http://{}/backend-api/", + listener + .local_addr() + .expect("listener should have a local addr") + ); + let codex_home = TempDir::new().expect("temp dir should create"); + let (transport_event_tx, mut transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let remote_handle = start_remote_control( + remote_control_url, + Some(remote_control_state_runtime(&codex_home).await), + remote_control_auth_manager(), + transport_event_tx, + shutdown_token.clone(), + ) + .await + .expect("remote control should start"); + let enroll_request = accept_http_request(&listener).await; + assert_eq!( + enroll_request.request_line, + "POST /backend-api/wham/remote/control/server/enroll HTTP/1.1" + ); + respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await; + let mut websocket = accept_remote_control_connection(&listener).await; + + let client_id = ClientId("client-1".to_string()); + send_client_event( + &mut websocket, + ClientEnvelope { + event: ClientEvent::Ping, + client_id: client_id.clone(), + seq_id: None, + cursor: None, + }, + ) + .await; + assert_eq!( + read_server_event(&mut websocket).await, + json!({ + "type": "pong", + "client_id": "client-1", + "seq_id": 0, + "status": "unknown", + }) + ); + + send_client_event( + &mut websocket, + ClientEnvelope { + event: ClientEvent::ClientMessage { + message: JSONRPCMessage::Notification( + codex_app_server_protocol::JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }, + ), + }, + client_id: client_id.clone(), + seq_id: Some(0), + cursor: None, + }, + ) + .await; + assert!( + timeout(Duration::from_millis(100), transport_event_rx.recv()) + .await + .is_err(), + "non-initialize client messages should be ignored before connection creation" + ); + + let initialize_message = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { + id: codex_app_server_protocol::RequestId::Integer(1), + method: "initialize".to_string(), + params: Some(json!({ + "clientInfo": { + "name": "remote-test-client", + "version": "0.1.0" + } + })), + trace: None, + }); + send_client_event( + &mut websocket, + ClientEnvelope { + event: ClientEvent::ClientMessage { + message: initialize_message.clone(), + }, + client_id: client_id.clone(), + seq_id: Some(1), + cursor: None, + }, + ) + .await; + + let (connection_id, writer) = match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("connection open should arrive in time") + .expect("connection open should exist") + { + TransportEvent::ConnectionOpened { + connection_id, + writer, + .. + } => (connection_id, writer), + other => panic!("expected connection open event, got {other:?}"), + }; + + match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("initialize message should arrive in time") + .expect("initialize message should exist") + { + TransportEvent::IncomingMessage { + connection_id: incoming_connection_id, + message, + } => { + assert_eq!(incoming_connection_id, connection_id); + assert_eq!(message, initialize_message); + } + other => panic!("expected initialize incoming message, got {other:?}"), + } + + let followup_message = + JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + send_client_event( + &mut websocket, + ClientEnvelope { + event: ClientEvent::ClientMessage { + message: followup_message.clone(), + }, + client_id: client_id.clone(), + seq_id: Some(2), + cursor: None, + }, + ) + .await; + match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("followup message should arrive in time") + .expect("followup message should exist") + { + TransportEvent::IncomingMessage { + connection_id: incoming_connection_id, + message, + } => { + assert_eq!(incoming_connection_id, connection_id); + assert_eq!(message, followup_message); + } + other => panic!("expected followup incoming message, got {other:?}"), + } + + send_client_event( + &mut websocket, + ClientEnvelope { + event: ClientEvent::Ping, + client_id: client_id.clone(), + seq_id: None, + cursor: None, + }, + ) + .await; + assert_eq!( + read_server_event(&mut websocket).await, + json!({ + "type": "pong", + "client_id": "client-1", + "seq_id": 1, + "status": "active", + }) + ); + + writer + .send(QueuedOutgoingMessage::new( + OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { + summary: "test".to_string(), + details: None, + path: None, + range: None, + }, + )), + )) + .await + .expect("remote writer should accept outgoing message"); + assert_eq!( + read_server_event(&mut websocket).await, + json!({ + "type": "server_message", + "client_id": "client-1", + "seq_id": 2, + "message": { + "method": "configWarning", + "params": { + "summary": "test", + "details": null, + } + } + }) + ); + + send_client_event( + &mut websocket, + ClientEnvelope { + event: ClientEvent::ClientClosed, + client_id: client_id.clone(), + seq_id: None, + cursor: None, + }, + ) + .await; + match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("connection close should arrive in time") + .expect("connection close should exist") + { + TransportEvent::ConnectionClosed { + connection_id: closed_connection_id, + } => { + assert_eq!(closed_connection_id, connection_id); + } + other => panic!("expected connection close event, got {other:?}"), + } + + send_client_event( + &mut websocket, + ClientEnvelope { + event: ClientEvent::Ping, + client_id, + seq_id: None, + cursor: None, + }, + ) + .await; + assert_eq!( + read_server_event(&mut websocket).await, + json!({ + "type": "pong", + "client_id": "client-1", + "seq_id": 3, + "status": "unknown", + }) + ); + + shutdown_token.cancel(); + let _ = remote_handle.await; +} + +#[tokio::test] +async fn remote_control_transport_reconnects_after_disconnect() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = format!( + "http://{}/backend-api/", + listener + .local_addr() + .expect("listener should have a local addr") + ); + let codex_home = TempDir::new().expect("temp dir should create"); + let (transport_event_tx, mut transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let remote_handle = start_remote_control( + remote_control_url, + Some(remote_control_state_runtime(&codex_home).await), + remote_control_auth_manager(), + transport_event_tx, + shutdown_token.clone(), + ) + .await + .expect("remote control should start"); + + let enroll_request = accept_http_request(&listener).await; + assert_eq!( + enroll_request.request_line, + "POST /backend-api/wham/remote/control/server/enroll HTTP/1.1" + ); + respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await; + let mut first_websocket = accept_remote_control_connection(&listener).await; + first_websocket + .close(None) + .await + .expect("first websocket should close"); + drop(first_websocket); + + let mut second_websocket = accept_remote_control_connection(&listener).await; + send_client_event( + &mut second_websocket, + ClientEnvelope { + event: ClientEvent::ClientMessage { + message: JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { + id: codex_app_server_protocol::RequestId::Integer(2), + method: "initialize".to_string(), + params: Some(json!({ + "clientInfo": { + "name": "remote-test-client", + "version": "0.1.0" + } + })), + trace: None, + }), + }, + client_id: ClientId("client-2".to_string()), + seq_id: Some(0), + cursor: None, + }, + ) + .await; + + match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("reconnected initialize should arrive in time") + .expect("reconnected initialize should exist") + { + TransportEvent::ConnectionOpened { .. } => {} + other => panic!("expected connection open after reconnect, got {other:?}"), + } + + shutdown_token.cancel(); + let _ = remote_handle.await; +} + +#[tokio::test] +async fn remote_control_transport_clears_outgoing_buffer_when_client_closes() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = format!( + "http://{}/backend-api/", + listener + .local_addr() + .expect("listener should have a local addr") + ); + let codex_home = TempDir::new().expect("temp dir should create"); + let (transport_event_tx, mut transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let remote_handle = start_remote_control( + remote_control_url, + Some(remote_control_state_runtime(&codex_home).await), + remote_control_auth_manager(), + transport_event_tx, + shutdown_token.clone(), + ) + .await + .expect("remote control should start"); + + let enroll_request = accept_http_request(&listener).await; + respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await; + let mut first_websocket = accept_remote_control_connection(&listener).await; + + let client_id = ClientId("client-1".to_string()); + let initialize_message = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { + id: codex_app_server_protocol::RequestId::Integer(1), + method: "initialize".to_string(), + params: Some(json!({ + "clientInfo": { + "name": "remote-test-client", + "version": "0.1.0" + } + })), + trace: None, + }); + send_client_event( + &mut first_websocket, + ClientEnvelope { + event: ClientEvent::ClientMessage { + message: initialize_message, + }, + client_id: client_id.clone(), + seq_id: Some(0), + cursor: None, + }, + ) + .await; + + let writer = match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("connection open should arrive in time") + .expect("connection open should exist") + { + TransportEvent::ConnectionOpened { writer, .. } => writer, + other => panic!("expected connection open event, got {other:?}"), + }; + match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("initialize message should arrive in time") + .expect("initialize message should exist") + { + TransportEvent::IncomingMessage { .. } => {} + other => panic!("expected initialize incoming message, got {other:?}"), + } + + writer + .send(QueuedOutgoingMessage::new( + OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { + summary: "stale".to_string(), + details: None, + path: None, + range: None, + }, + )), + )) + .await + .expect("remote writer should accept outgoing message"); + assert_eq!( + read_server_event(&mut first_websocket).await, + json!({ + "type": "server_message", + "client_id": "client-1", + "seq_id": 0, + "message": { + "method": "configWarning", + "params": { + "summary": "stale", + "details": null, + } + } + }) + ); + + send_client_event( + &mut first_websocket, + ClientEnvelope { + event: ClientEvent::ClientClosed, + client_id: client_id.clone(), + seq_id: None, + cursor: None, + }, + ) + .await; + match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("connection close should arrive in time") + .expect("connection close should exist") + { + TransportEvent::ConnectionClosed { .. } => {} + other => panic!("expected connection close event, got {other:?}"), + } + + first_websocket + .close(None) + .await + .expect("first websocket should close"); + drop(first_websocket); + + let mut second_websocket = accept_remote_control_connection(&listener).await; + send_client_event( + &mut second_websocket, + ClientEnvelope { + event: ClientEvent::Ping, + client_id, + seq_id: None, + cursor: None, + }, + ) + .await; + assert_eq!( + read_server_event(&mut second_websocket).await, + json!({ + "type": "pong", + "client_id": "client-1", + "seq_id": 1, + "status": "unknown", + }) + ); + + shutdown_token.cancel(); + let _ = remote_handle.await; +} + +#[tokio::test] +async fn remote_control_http_mode_enrolls_before_connecting() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = format!( + "http://{}/backend-api/", + listener + .local_addr() + .expect("listener should have a local addr") + ); + let codex_home = TempDir::new().expect("temp dir should create"); + let (transport_event_tx, mut transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let expected_server_name = gethostname().to_string_lossy().trim().to_string(); + let shutdown_token = CancellationToken::new(); + let remote_handle = start_remote_control( + remote_control_url, + Some(remote_control_state_runtime(&codex_home).await), + remote_control_auth_manager(), + transport_event_tx, + shutdown_token.clone(), + ) + .await + .expect("remote control should start"); + + let enroll_request = accept_http_request(&listener).await; + assert_eq!( + enroll_request.request_line, + "POST /backend-api/wham/remote/control/server/enroll HTTP/1.1" + ); + assert_eq!( + enroll_request.headers.get("authorization"), + Some(&"Bearer Access Token".to_string()) + ); + assert_eq!( + enroll_request.headers.get(REMOTE_CONTROL_ACCOUNT_ID_HEADER), + Some(&"account_id".to_string()) + ); + assert_eq!( + serde_json::from_str::(&enroll_request.body) + .expect("enroll body should deserialize"), + json!({ + "name": expected_server_name, + "os": std::env::consts::OS, + "arch": std::env::consts::ARCH, + "app_server_version": env!("CARGO_PKG_VERSION"), + }) + ); + respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await; + + let (handshake_request, mut websocket) = + accept_remote_control_backend_connection(&listener).await; + assert_eq!( + handshake_request.path, + "/backend-api/wham/remote/control/server" + ); + assert_eq!( + handshake_request.headers.get("authorization"), + Some(&"Bearer Access Token".to_string()) + ); + assert_eq!( + handshake_request + .headers + .get(REMOTE_CONTROL_ACCOUNT_ID_HEADER), + Some(&"account_id".to_string()) + ); + assert_eq!( + handshake_request.headers.get("x-codex-server-id"), + Some(&"srv_e_test".to_string()) + ); + assert_eq!( + handshake_request.headers.get("x-codex-name"), + Some(&base64::engine::general_purpose::STANDARD.encode(&expected_server_name)) + ); + assert_eq!( + handshake_request.headers.get("x-codex-protocol-version"), + Some(&REMOTE_CONTROL_PROTOCOL_VERSION.to_string()) + ); + + let backend_client_id = ClientId("backend-test-client".to_string()); + let writer = { + let initialize_message = + JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { + id: codex_app_server_protocol::RequestId::Integer(11), + method: "initialize".to_string(), + params: Some(json!({ + "clientInfo": { + "name": "remote-backend-client", + "version": "0.1.0" + } + })), + trace: None, + }); + send_client_event( + &mut websocket, + ClientEnvelope { + event: ClientEvent::ClientMessage { + message: initialize_message.clone(), + }, + client_id: backend_client_id.clone(), + seq_id: Some(0), + cursor: None, + }, + ) + .await; + + let (connection_id, writer) = + match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("connection open should arrive in time") + .expect("connection open should exist") + { + TransportEvent::ConnectionOpened { + connection_id, + writer, + .. + } => (connection_id, writer), + other => panic!("expected connection open event, got {other:?}"), + }; + + match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("initialize message should arrive in time") + .expect("initialize message should exist") + { + TransportEvent::IncomingMessage { + connection_id: incoming_connection_id, + message, + } => { + assert_eq!(incoming_connection_id, connection_id); + assert_eq!(message, initialize_message); + } + other => panic!("expected initialize incoming message, got {other:?}"), + } + writer + }; + + writer + .send(QueuedOutgoingMessage::new(OutgoingMessage::Response( + crate::outgoing_message::OutgoingResponse { + id: codex_app_server_protocol::RequestId::Integer(11), + result: json!({ + "userAgent": "codex-test-agent" + }), + }, + ))) + .await + .expect("remote writer should accept initialize response"); + assert_eq!( + read_server_event(&mut websocket).await, + json!({ + "type": "server_message", + "client_id": backend_client_id.0.clone(), + "seq_id": 0, + "message": { + "id": 11, + "result": { + "userAgent": "codex-test-agent", + } + } + }) + ); + + writer + .send(QueuedOutgoingMessage::new( + OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { + summary: "backend".to_string(), + details: None, + path: None, + range: None, + }, + )), + )) + .await + .expect("remote writer should accept outgoing message"); + assert_eq!( + read_server_event(&mut websocket).await, + json!({ + "type": "server_message", + "client_id": backend_client_id.0.clone(), + "seq_id": 1, + "message": { + "method": "configWarning", + "params": { + "summary": "backend", + "details": null, + } + } + }) + ); + + shutdown_token.cancel(); + let _ = remote_handle.await; +} + +#[tokio::test] +async fn remote_control_http_mode_reuses_persisted_enrollment_before_reenrolling() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = format!( + "http://{}/backend-api/", + listener + .local_addr() + .expect("listener should have a local addr") + ); + let codex_home = TempDir::new().expect("temp dir should create"); + let state_db = remote_control_state_runtime(&codex_home).await; + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let persisted_enrollment = RemoteControlEnrollment { + account_id: Some("account_id".to_string()), + server_id: "srv_e_persisted".to_string(), + server_name: "persisted-server".to_string(), + }; + update_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &remote_control_target, + Some("account_id"), + Some(&persisted_enrollment), + ) + .await + .expect("persisted enrollment should save"); + + let (transport_event_tx, _transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let remote_handle = start_remote_control( + remote_control_url, + Some(state_db.clone()), + remote_control_auth_manager_with_home(&codex_home), + transport_event_tx, + shutdown_token.clone(), + ) + .await + .expect("remote control should start"); + + let (handshake_request, _websocket) = accept_remote_control_backend_connection(&listener).await; + assert_eq!( + handshake_request.path, + "/backend-api/wham/remote/control/server" + ); + assert_eq!( + handshake_request.headers.get("x-codex-server-id"), + Some(&persisted_enrollment.server_id) + ); + assert_eq!( + load_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &remote_control_target, + Some("account_id"), + ) + .await, + Some(persisted_enrollment) + ); + + shutdown_token.cancel(); + let _ = remote_handle.await; +} + +#[tokio::test] +async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = format!( + "http://{}/backend-api/", + listener + .local_addr() + .expect("listener should have a local addr") + ); + let codex_home = TempDir::new().expect("temp dir should create"); + let state_db = remote_control_state_runtime(&codex_home).await; + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let expected_server_name = gethostname().to_string_lossy().trim().to_string(); + let stale_enrollment = RemoteControlEnrollment { + account_id: Some("account_id".to_string()), + server_id: "srv_e_stale".to_string(), + server_name: "stale-server".to_string(), + }; + let refreshed_enrollment = RemoteControlEnrollment { + account_id: Some("account_id".to_string()), + server_id: "srv_e_refreshed".to_string(), + server_name: expected_server_name, + }; + update_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &remote_control_target, + Some("account_id"), + Some(&stale_enrollment), + ) + .await + .expect("stale enrollment should save"); + + let (transport_event_tx, _transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let remote_handle = start_remote_control( + remote_control_url, + Some(state_db.clone()), + remote_control_auth_manager_with_home(&codex_home), + transport_event_tx, + shutdown_token.clone(), + ) + .await + .expect("remote control should start"); + + let websocket_request = accept_http_request(&listener).await; + assert_eq!( + websocket_request.request_line, + "GET /backend-api/wham/remote/control/server HTTP/1.1" + ); + assert_eq!( + websocket_request.headers.get("x-codex-server-id"), + Some(&stale_enrollment.server_id) + ); + respond_with_status(websocket_request.stream, "404 Not Found", "").await; + + let enroll_request = accept_http_request(&listener).await; + assert_eq!( + enroll_request.request_line, + "POST /backend-api/wham/remote/control/server/enroll HTTP/1.1" + ); + respond_with_json( + enroll_request.stream, + json!({ "server_id": refreshed_enrollment.server_id }), + ) + .await; + + let (handshake_request, _websocket) = accept_remote_control_backend_connection(&listener).await; + assert_eq!( + handshake_request.headers.get("x-codex-server-id"), + Some(&refreshed_enrollment.server_id) + ); + assert_eq!( + load_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &remote_control_target, + Some("account_id"), + ) + .await, + Some(refreshed_enrollment) + ); + + shutdown_token.cancel(); + let _ = remote_handle.await; +} + +#[derive(Debug)] +struct CapturedHttpRequest { + stream: TcpStream, + request_line: String, + headers: BTreeMap, + body: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +struct CapturedWebSocketRequest { + path: String, + headers: BTreeMap, +} + +async fn accept_remote_control_connection(listener: &TcpListener) -> WebSocketStream { + let (stream, _) = timeout(Duration::from_secs(5), listener.accept()) + .await + .expect("remote control should connect in time") + .expect("listener accept should succeed"); + accept_async(stream) + .await + .expect("websocket handshake should succeed") +} + +async fn accept_http_request(listener: &TcpListener) -> CapturedHttpRequest { + let (stream, _) = timeout(Duration::from_secs(5), listener.accept()) + .await + .expect("HTTP request should arrive in time") + .expect("listener accept should succeed"); + let mut reader = BufReader::new(stream); + + let mut request_line = String::new(); + reader + .read_line(&mut request_line) + .await + .expect("request line should read"); + let request_line = request_line.trim_end_matches("\r\n").to_string(); + + let mut headers = BTreeMap::new(); + loop { + let mut line = String::new(); + reader + .read_line(&mut line) + .await + .expect("header line should read"); + if line == "\r\n" { + break; + } + let line = line.trim_end_matches("\r\n"); + let (name, value) = line.split_once(':').expect("header should contain colon"); + headers.insert(name.to_ascii_lowercase(), value.trim().to_string()); + } + + let content_length = headers + .get("content-length") + .and_then(|value| value.parse::().ok()) + .unwrap_or(0); + let mut body = vec![0; content_length]; + reader + .read_exact(&mut body) + .await + .expect("request body should read"); + + CapturedHttpRequest { + stream: reader.into_inner(), + request_line, + headers, + body: String::from_utf8(body).expect("body should be utf-8"), + } +} + +async fn respond_with_json(mut stream: TcpStream, body: serde_json::Value) { + let body = body.to_string(); + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}", + body.len() + ); + stream + .write_all(response.as_bytes()) + .await + .expect("response should write"); + stream.flush().await.expect("response should flush"); +} + +async fn respond_with_status(stream: TcpStream, status: &str, body: &str) { + respond_with_status_and_headers(stream, status, &[], body).await; +} + +async fn respond_with_status_and_headers( + mut stream: TcpStream, + status: &str, + headers: &[(&str, &str)], + body: &str, +) { + let extra_headers = headers + .iter() + .map(|(name, value)| format!("{name}: {value}\r\n")) + .collect::(); + let response = format!( + "HTTP/1.1 {status}\r\ncontent-type: text/plain\r\ncontent-length: {}\r\nconnection: close\r\n{extra_headers}\r\n{body}", + body.len(), + ); + stream + .write_all(response.as_bytes()) + .await + .expect("response should write"); + stream.flush().await.expect("response should flush"); +} + +async fn accept_remote_control_backend_connection( + listener: &TcpListener, +) -> (CapturedWebSocketRequest, WebSocketStream) { + let (stream, _) = timeout(Duration::from_secs(5), listener.accept()) + .await + .expect("websocket request should arrive in time") + .expect("listener accept should succeed"); + let captured_request = Arc::new(std::sync::Mutex::new(None::)); + let captured_request_for_callback = captured_request.clone(); + let websocket = accept_hdr_async( + stream, + move |request: &tungstenite::handshake::server::Request, + response: tungstenite::handshake::server::Response| { + let headers = request + .headers() + .iter() + .map(|(name, value)| { + ( + name.as_str().to_ascii_lowercase(), + value + .to_str() + .expect("header should be valid utf-8") + .to_string(), + ) + }) + .collect::>(); + *captured_request_for_callback + .lock() + .expect("capture lock should acquire") = Some(CapturedWebSocketRequest { + path: request.uri().path().to_string(), + headers, + }); + Ok(response) + }, + ) + .await + .expect("websocket handshake should succeed"); + let captured_request = captured_request + .lock() + .expect("capture lock should acquire") + .clone() + .expect("websocket request should be captured"); + (captured_request, websocket) +} + +async fn send_client_event( + websocket: &mut WebSocketStream, + client_envelope: ClientEnvelope, +) { + let payload = serde_json::to_string(&client_envelope).expect("client event should serialize"); + websocket + .send(tungstenite::Message::Text(payload.into())) + .await + .expect("client event should send"); +} + +async fn read_server_event(websocket: &mut WebSocketStream) -> serde_json::Value { + loop { + let frame = timeout(Duration::from_secs(5), websocket.next()) + .await + .expect("server event should arrive in time") + .expect("websocket should stay open") + .expect("websocket frame should be readable"); + match frame { + tungstenite::Message::Text(text) => { + return serde_json::from_str(text.as_ref()) + .expect("server event should deserialize"); + } + tungstenite::Message::Ping(payload) => { + websocket + .send(tungstenite::Message::Pong(payload)) + .await + .expect("websocket pong should send"); + } + tungstenite::Message::Pong(_) => {} + tungstenite::Message::Close(frame) => { + panic!("unexpected websocket close frame: {frame:?}"); + } + tungstenite::Message::Binary(_) => { + panic!("unexpected binary websocket frame"); + } + tungstenite::Message::Frame(_) => {} + } + } +} diff --git a/codex-rs/app-server/src/transport/remote_control/websocket.rs b/codex-rs/app-server/src/transport/remote_control/websocket.rs new file mode 100644 index 000000000000..7b743efa4480 --- /dev/null +++ b/codex-rs/app-server/src/transport/remote_control/websocket.rs @@ -0,0 +1,1094 @@ +use crate::transport::TransportEvent; +use crate::transport::remote_control::client_tracker::ClientTracker; +use crate::transport::remote_control::client_tracker::REMOTE_CONTROL_IDLE_SWEEP_INTERVAL; +use crate::transport::remote_control::enroll::RemoteControlConnectionAuth; +use crate::transport::remote_control::enroll::RemoteControlEnrollment; +use crate::transport::remote_control::enroll::enroll_remote_control_server; +use crate::transport::remote_control::enroll::format_headers; +use crate::transport::remote_control::enroll::load_persisted_remote_control_enrollment; +use crate::transport::remote_control::enroll::preview_remote_control_response_body; +use crate::transport::remote_control::enroll::update_persisted_remote_control_enrollment; + +use super::protocol::ClientEnvelope; +use super::protocol::ClientEvent; +use super::protocol::ClientId; +use super::protocol::RemoteControlTarget; +use super::protocol::ServerEnvelope; +use axum::http::HeaderValue; +use base64::Engine; +use codex_core::AuthManager; +use codex_core::auth::UnauthorizedRecovery; +use codex_core::util::backoff; +use codex_state::StateRuntime; +use codex_utils_rustls_provider::ensure_rustls_crypto_provider; +use futures::SinkExt; +use futures::StreamExt; +use futures::stream::SplitSink; +use futures::stream::SplitStream; +use std::collections::BTreeMap; +use std::collections::HashMap; +use std::io; +use std::io::ErrorKind; +use std::sync::Arc; +use tokio::net::TcpStream; +use tokio::sync::Mutex; +use tokio::sync::mpsc; +use tokio::sync::watch; +use tokio::time::MissedTickBehavior; +use tokio_tungstenite::MaybeTlsStream; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::connect_async; +use tokio_tungstenite::tungstenite; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tokio_util::sync::CancellationToken; +use tracing::error; +use tracing::info; +use tracing::warn; + +pub(super) const REMOTE_CONTROL_PROTOCOL_VERSION: &str = "2"; +pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id"; +const REMOTE_CONTROL_SUBSCRIBE_CURSOR_HEADER: &str = "x-codex-subscribe-cursor"; + +struct BoundedOutboundBuffer { + buffer_by_client: HashMap>, + used_tx: watch::Sender, +} + +impl BoundedOutboundBuffer { + fn new() -> (Self, watch::Receiver) { + let (used_tx, used_rx) = watch::channel(0); + let buffer = Self { + buffer_by_client: HashMap::new(), + used_tx, + }; + (buffer, used_rx) + } + + fn insert(&mut self, server_envelope: &ServerEnvelope) { + self.buffer_by_client + .entry(server_envelope.client_id.clone()) + .or_default() + .insert(server_envelope.seq_id, server_envelope.clone()); + self.used_tx.send_modify(|used| *used += 1); + } + + fn remove(&mut self, client_id: &ClientId) { + if let Some(buffer) = self.buffer_by_client.remove(client_id) { + self.used_tx.send_modify(|used| *used -= buffer.len()); + } + } + + fn ack(&mut self, client_id: &ClientId, acked_seq_id: u64) { + let Some(buffer) = self.buffer_by_client.get_mut(client_id) else { + return; + }; + while let Some(seq_id) = buffer.first_key_value().map(|(seq_id, _)| seq_id) + && *seq_id <= acked_seq_id + { + buffer.pop_first(); + self.used_tx.send_modify(|used| *used -= 1); + } + if buffer.is_empty() { + self.buffer_by_client.remove(client_id); + } + } + + fn server_envelopes(&self) -> impl Iterator { + self.buffer_by_client + .values() + .flat_map(|buffer| buffer.values()) + } +} + +struct WebsocketState { + outbound_buffer: BoundedOutboundBuffer, + subscribe_cursor: Option, + next_seq_id: u64, +} + +struct RemoteControlWebsocket { + remote_control_target: RemoteControlTarget, + state_db: Option>, + auth_manager: Arc, + shutdown_token: CancellationToken, + reconnect_attempt: u64, + enrollment: Option, + auth_recovery: UnauthorizedRecovery, + client_tracker: Arc>, + state: Arc>, + server_event_rx: Arc>>, + used_rx: watch::Receiver, +} + +impl RemoteControlWebsocket { + fn new( + remote_control_target: RemoteControlTarget, + state_db: Option>, + auth_manager: Arc, + transport_event_tx: mpsc::Sender, + shutdown_token: CancellationToken, + ) -> Self { + let (server_event_tx, server_event_rx) = mpsc::channel(super::CHANNEL_CAPACITY); + let client_tracker = + ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token); + let (outbound_buffer, used_rx) = BoundedOutboundBuffer::new(); + let auth_recovery = auth_manager.unauthorized_recovery(); + + Self { + remote_control_target, + state_db, + auth_manager, + shutdown_token, + reconnect_attempt: 0, + enrollment: None, + auth_recovery, + client_tracker: Arc::new(Mutex::new(client_tracker)), + state: Arc::new(Mutex::new(WebsocketState { + outbound_buffer, + subscribe_cursor: None, + next_seq_id: 0, + })), + server_event_rx: Arc::new(Mutex::new(server_event_rx)), + used_rx, + } + } + + async fn run(mut self) { + loop { + let shutdown_token = self.shutdown_token.child_token(); + let websocket_connection = match self.connect(&shutdown_token).await { + Some(websocket_connection) => websocket_connection, + None => break, + }; + + self.run_connection(websocket_connection, shutdown_token) + .await; + } + + self.client_tracker.lock().await.shutdown().await; + } + + async fn connect( + &mut self, + shutdown_token: &CancellationToken, + ) -> Option>> { + loop { + let subscribe_cursor = self.state.lock().await.subscribe_cursor.clone(); + tokio::select! { + _ = shutdown_token.cancelled() => return None, + connect_result = connect_remote_control_websocket( + &self.remote_control_target, + self.state_db.as_deref(), + &self.auth_manager, + &mut self.auth_recovery, + &mut self.enrollment, + subscribe_cursor.as_deref(), + ) => { + match connect_result { + Ok((websocket_connection, response)) => { + self.reconnect_attempt = 0; + self.auth_recovery = self.auth_manager.unauthorized_recovery(); + info!( + "connected to app-server remote control websocket: {}, {}", + self.remote_control_target.websocket_url, + format_headers(response.headers()) + ); + return Some(websocket_connection); + } + Err(err) => { + warn!("{err}"); + let reconnect_delay = backoff(self.reconnect_attempt); + self.reconnect_attempt += 1; + tokio::select! { + _ = shutdown_token.cancelled() => return None, + _ = tokio::time::sleep(reconnect_delay) => {} + } + } + } + } + } + } + } + + async fn run_connection( + &self, + websocket_connection: WebSocketStream>, + shutdown_token: CancellationToken, + ) { + let (websocket_writer, websocket_reader) = websocket_connection.split(); + let mut join_set = tokio::task::JoinSet::new(); + + join_set.spawn(Self::run_server_writer( + self.state.clone(), + self.server_event_rx.clone(), + self.used_rx.clone(), + websocket_writer, + shutdown_token.clone(), + )); + join_set.spawn(Self::run_websocket_reader( + self.client_tracker.clone(), + self.state.clone(), + websocket_reader, + shutdown_token.clone(), + )); + + tokio::select! { + _ = shutdown_token.cancelled() => {} + _ = join_set.join_next() => shutdown_token.cancel(), + } + + join_set.join_all().await; + } + + async fn run_server_writer( + state: Arc>, + server_event_rx: Arc>>, + used_rx: watch::Receiver, + websocket_writer: SplitSink< + WebSocketStream>, + tungstenite::Message, + >, + shutdown_token: CancellationToken, + ) { + let result = Self::run_server_writer_inner( + state, + server_event_rx, + used_rx, + websocket_writer, + shutdown_token, + ) + .await; + if let Err(err) = result { + warn!("remote control websocket writer disconnected, err: {err}"); + } else { + warn!("remote control websocket writer was stopped"); + } + } + + async fn run_server_writer_inner( + state: Arc>, + server_event_rx: Arc>>, + mut used_rx: watch::Receiver, + mut websocket_writer: SplitSink< + WebSocketStream>, + tungstenite::Message, + >, + shutdown_token: CancellationToken, + ) -> io::Result<()> { + for server_envelope in state.lock().await.outbound_buffer.server_envelopes() { + let payload = match serde_json::to_string(&server_envelope) { + Ok(payload) => payload, + Err(err) => { + error!("failed to serialize remote-control server event: {err}"); + continue; + } + }; + tokio::select! { + _ = shutdown_token.cancelled() => return Ok(()), + send_result = websocket_writer.send(tungstenite::Message::Text(payload.into())) => { + if let Err(err) = send_result { + return Err(io::Error::other(err)); + } + } + }; + } + + let mut server_event_rx = server_event_rx.lock().await; + loop { + tokio::select! { + _ = shutdown_token.cancelled() => return Ok(()), + _ = used_rx.wait_for(|used| *used < super::CHANNEL_CAPACITY) => {} + }; + let queued_server_envelope = tokio::select! { + _ = shutdown_token.cancelled() => return Ok(()), + recv_result = server_event_rx.recv() => { + match recv_result { + Some(queued_server_envelope) => queued_server_envelope, + None => { + return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "server event channel closed")); + } + } + } + }; + let (server_envelope, write_complete_tx) = { + let mut state = state.lock().await; + let seq_id = state.next_seq_id; + state.next_seq_id = state.next_seq_id.saturating_add(1); + + let server_envelope = ServerEnvelope { + event: queued_server_envelope.event, + client_id: queued_server_envelope.client_id, + seq_id, + }; + state.outbound_buffer.insert(&server_envelope); + + (server_envelope, queued_server_envelope.write_complete_tx) + }; + + let payload = match serde_json::to_string(&server_envelope) { + Ok(payload) => payload, + Err(err) => { + error!("failed to serialize remote-control server event: {err}"); + continue; + } + }; + + tokio::select! { + _ = shutdown_token.cancelled() => return Ok(()), + send_result = websocket_writer.send(tungstenite::Message::Text(payload.into())) => { + if let Err(err) = send_result { + return Err(io::Error::other(err)); + } + } + }; + if let Some(write_complete_tx) = write_complete_tx { + let _ = write_complete_tx.send(()); + } + } + } + + async fn run_websocket_reader( + client_tracker: Arc>, + state: Arc>, + websocket_reader: SplitStream>>, + shutdown_token: CancellationToken, + ) { + let result = Self::run_websocket_reader_inner( + client_tracker, + state, + websocket_reader, + shutdown_token, + ) + .await; + if let Err(err) = result { + warn!("remote control websocket reader disconnected, err: {err}"); + } else { + warn!("remote control websocket reader was stopped"); + } + } + + async fn run_websocket_reader_inner( + client_tracker: Arc>, + state: Arc>, + mut websocket_reader: SplitStream>>, + shutdown_token: CancellationToken, + ) -> io::Result<()> { + let mut client_tracker = client_tracker.lock().await; + let mut idle_sweep_interval = tokio::time::interval(REMOTE_CONTROL_IDLE_SWEEP_INTERVAL); + idle_sweep_interval.set_missed_tick_behavior(MissedTickBehavior::Skip); + + loop { + let incoming_message = tokio::select! { + _ = shutdown_token.cancelled() => return Ok(()), + _ = client_tracker.bookkeep_join_set() => continue, + _ = idle_sweep_interval.tick() => { + let expired_client_ids = match client_tracker.close_expired_clients().await { + Ok(expired_client_ids) => expired_client_ids, + Err(_) => return Ok(()), + }; + if !expired_client_ids.is_empty() { + let mut state = state.lock().await; + for client_id in expired_client_ids { + state.outbound_buffer.remove(&client_id); + } + } + continue; + } + incoming_message = websocket_reader.next() => { + match incoming_message { + Some(incoming_message) => incoming_message, + None => return Err(io::Error::new(ErrorKind::UnexpectedEof, "websocket stream ended")), + } + } + }; + let client_envelope = match incoming_message { + Ok(tungstenite::Message::Text(text)) => { + match serde_json::from_str::(&text) { + Ok(client_envelope) => client_envelope, + Err(err) => { + warn!("failed to deserialize remote-control client event: {err}"); + continue; + } + } + } + Ok(tungstenite::Message::Ping(_)) + | Ok(tungstenite::Message::Pong(_)) + | Ok(tungstenite::Message::Frame(_)) => continue, + Ok(tungstenite::Message::Binary(_)) => { + warn!("dropping unsupported binary remote-control websocket message"); + continue; + } + Ok(tungstenite::Message::Close(_)) => { + return Err(io::Error::new( + ErrorKind::ConnectionAborted, + "websocket disconnected", + )); + } + Err(err) => { + return Err(io::Error::new( + ErrorKind::InvalidData, + format!("failed to read from websocket: {err}"), + )); + } + }; + + let mut state = state.lock().await; + if let Some(cursor) = client_envelope.cursor.as_deref() { + state.subscribe_cursor = Some(cursor.to_string()); + } + if let ClientEvent::Ack = &client_envelope.event + && let Some(acked_seq_id) = client_envelope.seq_id + { + state + .outbound_buffer + .ack(&client_envelope.client_id, acked_seq_id); + } + if matches!(&client_envelope.event, ClientEvent::ClientClosed) + || remote_control_message_starts_connection(&client_envelope.event) + { + state.outbound_buffer.remove(&client_envelope.client_id); + } + drop(state); + + if client_tracker + .handle_message(client_envelope) + .await + .is_err() + { + return Ok(()); + } + } + } +} + +pub(super) async fn run_remote_control_websocket_loop( + remote_control_target: RemoteControlTarget, + state_db: Option>, + auth_manager: Arc, + transport_event_tx: mpsc::Sender, + shutdown_token: CancellationToken, +) { + RemoteControlWebsocket::new( + remote_control_target, + state_db, + auth_manager, + transport_event_tx, + shutdown_token, + ) + .run() + .await; +} + +fn remote_control_message_starts_connection(event: &ClientEvent) -> bool { + matches!( + event, + ClientEvent::ClientMessage { + message: codex_app_server_protocol::JSONRPCMessage::Request( + codex_app_server_protocol::JSONRPCRequest { method, .. } + ), + } if method == "initialize" + ) +} + +fn set_remote_control_header( + headers: &mut tungstenite::http::HeaderMap, + name: &'static str, + value: &str, +) -> io::Result<()> { + let header_value = HeaderValue::from_str(value).map_err(|err| { + io::Error::new( + ErrorKind::InvalidInput, + format!("invalid remote control header `{name}`: {err}"), + ) + })?; + headers.insert(name, header_value); + Ok(()) +} + +fn build_remote_control_websocket_request( + websocket_url: &str, + enrollment: &RemoteControlEnrollment, + auth: &RemoteControlConnectionAuth, + subscribe_cursor: Option<&str>, +) -> io::Result> { + let mut request = websocket_url.into_client_request().map_err(|err| { + io::Error::new( + ErrorKind::InvalidInput, + format!("invalid remote control websocket URL `{websocket_url}`: {err}"), + ) + })?; + let headers = request.headers_mut(); + set_remote_control_header(headers, "x-codex-server-id", &enrollment.server_id)?; + set_remote_control_header( + headers, + "x-codex-name", + &base64::engine::general_purpose::STANDARD.encode(&enrollment.server_name), + )?; + set_remote_control_header( + headers, + "x-codex-protocol-version", + REMOTE_CONTROL_PROTOCOL_VERSION, + )?; + set_remote_control_header( + headers, + "authorization", + &format!("Bearer {}", auth.bearer_token), + )?; + if let Some(account_id) = auth.account_id.as_deref() { + set_remote_control_header(headers, REMOTE_CONTROL_ACCOUNT_ID_HEADER, account_id)?; + } + if let Some(subscribe_cursor) = subscribe_cursor { + set_remote_control_header( + headers, + REMOTE_CONTROL_SUBSCRIBE_CURSOR_HEADER, + subscribe_cursor, + )?; + } + Ok(request) +} + +pub(crate) async fn load_remote_control_auth( + auth_manager: &Arc, +) -> io::Result { + let auth = match auth_manager.auth().await { + Some(auth) => auth, + None => { + auth_manager.reload(); + auth_manager.auth().await.ok_or_else(|| { + io::Error::new( + ErrorKind::PermissionDenied, + "remote control requires ChatGPT authentication", + ) + })? + } + }; + + if !auth.is_chatgpt_auth() { + return Err(io::Error::new( + ErrorKind::PermissionDenied, + "remote control requires ChatGPT authentication; API key auth is not supported", + )); + } + + Ok(RemoteControlConnectionAuth { + bearer_token: auth.get_token().map_err(io::Error::other)?, + account_id: auth.get_account_id(), + }) +} + +pub(super) async fn connect_remote_control_websocket( + remote_control_target: &RemoteControlTarget, + state_db: Option<&StateRuntime>, + auth_manager: &Arc, + auth_recovery: &mut UnauthorizedRecovery, + enrollment: &mut Option, + subscribe_cursor: Option<&str>, +) -> io::Result<( + WebSocketStream>, + tungstenite::http::Response<()>, +)> { + ensure_rustls_crypto_provider(); + + let auth = load_remote_control_auth(auth_manager).await?; + if auth.account_id.as_ref() + != enrollment + .as_ref() + .and_then(|enrollment| enrollment.account_id.as_ref()) + { + *enrollment = None; + } + + if enrollment.is_none() { + *enrollment = load_persisted_remote_control_enrollment( + state_db, + remote_control_target, + auth.account_id.as_deref(), + ) + .await; + } + + if enrollment.is_none() { + let new_enrollment = match enroll_remote_control_server(remote_control_target, &auth).await + { + Ok(new_enrollment) => new_enrollment, + Err(err) + if err.kind() == ErrorKind::PermissionDenied + && recover_remote_control_auth(auth_recovery).await => + { + return Err(io::Error::other(format!( + "{err}; retrying after auth recovery" + ))); + } + Err(err) => return Err(err), + }; + if let Err(err) = update_persisted_remote_control_enrollment( + state_db, + remote_control_target, + auth.account_id.as_deref(), + Some(&new_enrollment), + ) + .await + { + warn!("failed to persist remote control enrollment in sqlite state db: {err}"); + } + *enrollment = Some(new_enrollment); + } + + let enrollment_ref = enrollment.as_ref().ok_or_else(|| { + io::Error::other("missing remote control enrollment after enrollment step") + })?; + let request = build_remote_control_websocket_request( + &remote_control_target.websocket_url, + enrollment_ref, + &auth, + subscribe_cursor, + )?; + + match connect_async(request).await { + Ok((websocket_stream, response)) => Ok((websocket_stream, response.map(|_| ()))), + Err(err) => { + match &err { + tungstenite::Error::Http(response) if response.status().as_u16() == 404 => { + if let Err(clear_err) = update_persisted_remote_control_enrollment( + state_db, + remote_control_target, + auth.account_id.as_deref(), + /*enrollment*/ None, + ) + .await + { + warn!( + "failed to clear stale remote control enrollment in sqlite state db: {clear_err}" + ); + } + *enrollment = None; + } + tungstenite::Error::Http(response) + if matches!(response.status().as_u16(), 401 | 403) => + { + if recover_remote_control_auth(auth_recovery).await { + return Err(io::Error::other(format!( + "remote control websocket auth failed with HTTP {}; retrying after auth recovery", + response.status() + ))); + } + } + _ => {} + } + Err(io::Error::other( + format_remote_control_websocket_connect_error( + &remote_control_target.websocket_url, + &err, + ), + )) + } + } +} + +async fn recover_remote_control_auth(auth_recovery: &mut UnauthorizedRecovery) -> bool { + if !auth_recovery.has_next() { + return false; + } + + let mode = auth_recovery.mode_name(); + let step = auth_recovery.step_name(); + match auth_recovery.next().await { + Ok(step_result) => { + info!( + "remote control websocket auth recovery succeeded: mode={mode}, step={step}, auth_state_changed={:?}", + step_result.auth_state_changed() + ); + true + } + Err(err) => { + warn!("remote control websocket auth recovery failed: mode={mode}, step={step}: {err}"); + false + } + } +} + +fn format_remote_control_websocket_connect_error( + websocket_url: &str, + err: &tungstenite::Error, +) -> String { + let mut message = + format!("failed to connect app-server remote control websocket `{websocket_url}`: {err}"); + let tungstenite::Error::Http(response) = err else { + return message; + }; + + message.push_str(&format!(", {}", format_headers(response.headers()))); + if let Some(body) = response.body().as_ref() + && !body.is_empty() + { + let body_preview = preview_remote_control_response_body(body); + message.push_str(&format!(", body: {body_preview}")); + } + + message +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::transport::remote_control::protocol::normalize_remote_control_url; + use chrono::Utc; + use codex_app_server_protocol::AuthMode; + use codex_core::CodexAuth; + use codex_core::auth::AuthCredentialsStoreMode; + use codex_core::auth::AuthDotJson; + use codex_core::auth::save_auth; + use codex_core::test_support::auth_manager_from_auth; + use codex_login::token_data::TokenData; + use codex_login::token_data::parse_chatgpt_jwt_claims; + use codex_state::StateRuntime; + use pretty_assertions::assert_eq; + use std::sync::Arc; + use tempfile::TempDir; + use tokio::io::AsyncBufReadExt; + use tokio::io::AsyncWriteExt; + use tokio::io::BufReader; + use tokio::net::TcpListener; + use tokio::net::TcpStream; + use tokio::sync::mpsc; + use tokio::time::Duration; + use tokio::time::timeout; + + async fn remote_control_state_runtime(codex_home: &TempDir) -> Arc { + StateRuntime::init(codex_home.path().to_path_buf(), "test-provider".to_string()) + .await + .expect("state runtime should initialize") + } + + fn remote_control_auth_manager() -> Arc { + auth_manager_from_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing()) + } + + fn remote_control_auth_dot_json(access_token: &str) -> AuthDotJson { + #[derive(serde::Serialize)] + struct Header { + alg: &'static str, + typ: &'static str, + } + + let header = Header { + alg: "none", + typ: "JWT", + }; + let payload = serde_json::json!({ + "email": "user@example.com", + "https://api.openai.com/auth": { + "chatgpt_user_id": "user-12345", + "user_id": "user-12345", + "chatgpt_account_id": "account_id" + } + }); + let b64 = |bytes: &[u8]| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes); + let header_b64 = b64(&serde_json::to_vec(&header).expect("header should serialize")); + let payload_b64 = b64(&serde_json::to_vec(&payload).expect("payload should serialize")); + let fake_jwt = format!("{header_b64}.{payload_b64}.sig"); + + AuthDotJson { + auth_mode: Some(AuthMode::Chatgpt), + openai_api_key: None, + tokens: Some(TokenData { + id_token: parse_chatgpt_jwt_claims(&fake_jwt).expect("fake jwt should parse"), + access_token: access_token.to_string(), + refresh_token: "refresh-token".to_string(), + account_id: Some("account_id".to_string()), + }), + last_refresh: Some(Utc::now()), + } + } + + #[tokio::test] + async fn connect_remote_control_websocket_includes_http_error_details() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = format!( + "http://{}/backend-api/", + listener + .local_addr() + .expect("listener should have a local addr") + ); + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let expected_error = format!( + "failed to connect app-server remote control websocket `{}`: HTTP error: 503 Service Unavailable, request-id: , cf-ray: , body: upstream unavailable", + remote_control_target.websocket_url + ); + let server_task = tokio::spawn(async move { + let (stream, request_line) = accept_http_request(&listener).await; + assert_eq!( + request_line, + "GET /backend-api/wham/remote/control/server HTTP/1.1" + ); + respond_with_status_and_headers( + stream, + "503 Service Unavailable", + &[("x-trace-id", "trace-503"), ("x-region", "us-east-1")], + "upstream unavailable", + ) + .await; + }); + let codex_home = TempDir::new().expect("temp dir should create"); + let state_db = remote_control_state_runtime(&codex_home).await; + let auth_manager = remote_control_auth_manager(); + let mut auth_recovery = auth_manager.unauthorized_recovery(); + let mut enrollment = Some(RemoteControlEnrollment { + account_id: Some("account_id".to_string()), + server_id: "srv_e_test".to_string(), + server_name: "test-server".to_string(), + }); + + let err = match connect_remote_control_websocket( + &remote_control_target, + Some(state_db.as_ref()), + &auth_manager, + &mut auth_recovery, + &mut enrollment, + None, + ) + .await + { + Ok(_) => panic!("http error response should fail the websocket connect"), + Err(err) => err, + }; + + server_task.await.expect("server task should succeed"); + assert_eq!(err.to_string(), expected_error); + } + + #[tokio::test] + async fn connect_remote_control_websocket_recovers_after_unauthorized_reload() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = format!( + "http://{}/backend-api/", + listener + .local_addr() + .expect("listener should have a local addr") + ); + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let server_task = tokio::spawn(async move { + let (stream, request_line) = accept_http_request(&listener).await; + assert_eq!( + request_line, + "GET /backend-api/wham/remote/control/server HTTP/1.1" + ); + respond_with_status_and_headers(stream, "401 Unauthorized", &[], "unauthorized").await; + }); + let codex_home = TempDir::new().expect("temp dir should create"); + save_auth( + codex_home.path(), + &remote_control_auth_dot_json("stale-token"), + AuthCredentialsStoreMode::File, + ) + .expect("stale auth should save"); + let state_db = remote_control_state_runtime(&codex_home).await; + let auth_manager = AuthManager::shared( + codex_home.path().to_path_buf(), + /*enable_codex_api_key_env*/ false, + AuthCredentialsStoreMode::File, + ); + let mut auth_recovery = auth_manager.unauthorized_recovery(); + let mut enrollment = Some(RemoteControlEnrollment { + account_id: Some("account_id".to_string()), + server_id: "srv_e_test".to_string(), + server_name: "test-server".to_string(), + }); + save_auth( + codex_home.path(), + &remote_control_auth_dot_json("fresh-token"), + AuthCredentialsStoreMode::File, + ) + .expect("fresh auth should save"); + + let err = connect_remote_control_websocket( + &remote_control_target, + Some(state_db.as_ref()), + &auth_manager, + &mut auth_recovery, + &mut enrollment, + None, + ) + .await + .expect_err("unauthorized response should fail the websocket connect"); + + server_task.await.expect("server task should succeed"); + assert_eq!( + err.to_string(), + "remote control websocket auth failed with HTTP 401 Unauthorized; retrying after auth recovery" + ); + assert_eq!( + auth_manager + .auth() + .await + .expect("auth should remain available") + .get_token() + .expect("token should be readable"), + "fresh-token" + ); + } + + #[tokio::test] + async fn connect_remote_control_websocket_recovers_after_unauthorized_enrollment() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = format!( + "http://{}/backend-api/", + listener + .local_addr() + .expect("listener should have a local addr") + ); + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let enroll_url = remote_control_target.enroll_url.clone(); + let server_task = tokio::spawn(async move { + let (stream, request_line) = accept_http_request(&listener).await; + assert_eq!( + request_line, + "POST /backend-api/wham/remote/control/server/enroll HTTP/1.1" + ); + respond_with_status_and_headers(stream, "401 Unauthorized", &[], "unauthorized").await; + }); + let codex_home = TempDir::new().expect("temp dir should create"); + save_auth( + codex_home.path(), + &remote_control_auth_dot_json("stale-token"), + AuthCredentialsStoreMode::File, + ) + .expect("stale auth should save"); + let state_db = remote_control_state_runtime(&codex_home).await; + let auth_manager = AuthManager::shared( + codex_home.path().to_path_buf(), + /*enable_codex_api_key_env*/ false, + AuthCredentialsStoreMode::File, + ); + let mut auth_recovery = auth_manager.unauthorized_recovery(); + let mut enrollment = None; + save_auth( + codex_home.path(), + &remote_control_auth_dot_json("fresh-token"), + AuthCredentialsStoreMode::File, + ) + .expect("fresh auth should save"); + + let err = connect_remote_control_websocket( + &remote_control_target, + Some(state_db.as_ref()), + &auth_manager, + &mut auth_recovery, + &mut enrollment, + None, + ) + .await + .expect_err("unauthorized enrollment should fail the websocket connect"); + + server_task.await.expect("server task should succeed"); + assert_eq!( + err.to_string(), + format!( + "remote control server enrollment failed at `{enroll_url}`: HTTP 401 Unauthorized, request-id: , cf-ray: , body: unauthorized; retrying after auth recovery" + ) + ); + assert_eq!( + auth_manager + .auth() + .await + .expect("auth should remain available") + .get_token() + .expect("token should be readable"), + "fresh-token" + ); + } + + #[tokio::test] + async fn run_remote_control_websocket_loop_shutdown_cancels_reconnect_backoff() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = format!( + "http://{}/backend-api/", + listener + .local_addr() + .expect("listener should have a local addr") + ); + drop(listener); + + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let (transport_event_tx, transport_event_rx) = mpsc::channel(1); + drop(transport_event_rx); + let shutdown_token = CancellationToken::new(); + let websocket_task = tokio::spawn(run_remote_control_websocket_loop( + remote_control_target, + None, + remote_control_auth_manager(), + transport_event_tx, + shutdown_token.clone(), + )); + + tokio::time::sleep(Duration::from_millis(50)).await; + shutdown_token.cancel(); + + timeout(Duration::from_millis(100), websocket_task) + .await + .expect("shutdown should cancel reconnect backoff") + .expect("websocket task should join"); + } + + async fn accept_http_request(listener: &TcpListener) -> (TcpStream, String) { + let (stream, _) = timeout(Duration::from_secs(5), listener.accept()) + .await + .expect("HTTP request should arrive in time") + .expect("listener accept should succeed"); + let mut reader = BufReader::new(stream); + + let mut request_line = String::new(); + reader + .read_line(&mut request_line) + .await + .expect("request line should read"); + loop { + let mut line = String::new(); + reader + .read_line(&mut line) + .await + .expect("header line should read"); + if line == "\r\n" { + break; + } + } + + ( + reader.into_inner(), + request_line.trim_end_matches("\r\n").to_string(), + ) + } + + async fn respond_with_status_and_headers( + mut stream: TcpStream, + status: &str, + headers: &[(&str, &str)], + body: &str, + ) { + let extra_headers = headers + .iter() + .map(|(name, value)| format!("{name}: {value}\r\n")) + .collect::(); + let response = format!( + "HTTP/1.1 {status}\r\ncontent-type: text/plain\r\ncontent-length: {}\r\nconnection: close\r\n{extra_headers}\r\n{body}", + body.len(), + ); + stream + .write_all(response.as_bytes()) + .await + .expect("response should write"); + stream.flush().await.expect("response should flush"); + } +} diff --git a/codex-rs/app-server/src/transport/stdio.rs b/codex-rs/app-server/src/transport/stdio.rs index 4f2bf267455b..6d40593a6190 100644 --- a/codex-rs/app-server/src/transport/stdio.rs +++ b/codex-rs/app-server/src/transport/stdio.rs @@ -1,8 +1,8 @@ use super::CHANNEL_CAPACITY; use super::TransportEvent; use super::forward_incoming_message; +use super::next_connection_id; use super::serialize_outgoing_message; -use crate::outgoing_message::ConnectionId; use crate::outgoing_message::QueuedOutgoingMessage; use std::io::ErrorKind; use std::io::Result as IoResult; @@ -20,7 +20,7 @@ pub(crate) async fn start_stdio_connection( transport_event_tx: mpsc::Sender, stdio_handles: &mut Vec>, ) -> IoResult<()> { - let connection_id = ConnectionId(0); + let connection_id = next_connection_id(); let (writer_tx, mut writer_rx) = mpsc::channel::(CHANNEL_CAPACITY); let writer_tx_for_reader = writer_tx.clone(); transport_event_tx diff --git a/codex-rs/app-server/src/transport/websocket.rs b/codex-rs/app-server/src/transport/websocket.rs index 05dfe24b05a6..41b138e216dd 100644 --- a/codex-rs/app-server/src/transport/websocket.rs +++ b/codex-rs/app-server/src/transport/websocket.rs @@ -4,6 +4,7 @@ use super::auth::WebsocketAuthPolicy; use super::auth::authorize_upgrade; use super::auth::should_warn_about_unauthenticated_non_loopback_listener; use super::forward_incoming_message; +use super::next_connection_id; use super::serialize_outgoing_message; use crate::outgoing_message::ConnectionId; use crate::outgoing_message::QueuedOutgoingMessage; @@ -32,8 +33,6 @@ use owo_colors::Style; use std::io::Result as IoResult; use std::net::SocketAddr; use std::sync::Arc; -use std::sync::atomic::AtomicU64; -use std::sync::atomic::Ordering; use tokio::net::TcpListener; use tokio::sync::mpsc; use tokio::task::JoinHandle; @@ -75,7 +74,6 @@ fn print_websocket_startup_banner(addr: SocketAddr) { #[derive(Clone)] struct WebSocketListenerState { transport_event_tx: mpsc::Sender, - connection_counter: Arc, auth_policy: Arc, } @@ -113,7 +111,7 @@ async fn websocket_upgrade_handler( ); return (err.status_code(), err.message()).into_response(); } - let connection_id = ConnectionId(state.connection_counter.fetch_add(1, Ordering::Relaxed)); + let connection_id = next_connection_id(); info!(%peer_addr, "websocket client connected"); websocket .on_upgrade(move |stream| async move { @@ -146,7 +144,6 @@ pub(crate) async fn start_websocket_acceptor( .layer(middleware::from_fn(reject_requests_with_origin_header)) .with_state(WebSocketListenerState { transport_event_tx, - connection_counter: Arc::new(AtomicU64::new(1)), auth_policy: Arc::new(auth_policy), }); let server = axum::serve( diff --git a/codex-rs/cli/src/main.rs b/codex-rs/cli/src/main.rs index 12a531d35dc0..1f962bf2f489 100644 --- a/codex-rs/cli/src/main.rs +++ b/codex-rs/cli/src/main.rs @@ -328,7 +328,7 @@ struct AppServerCommand { subcommand: Option, /// Transport endpoint URL. Supported values: `stdio://` (default), - /// `ws://IP:PORT`. + /// `ws://IP:PORT`, `off`. #[arg( long = "listen", value_name = "URL", @@ -1930,6 +1930,12 @@ mod tests { ); } + #[test] + fn app_server_listen_off_parses() { + let app_server = app_server_from_args(["codex", "app-server", "--listen", "off"].as_ref()); + assert_eq!(app_server.listen, codex_app_server::AppServerTransport::Off); + } + #[test] fn app_server_listen_invalid_url_fails_to_parse() { let parse_result = diff --git a/codex-rs/core/config.schema.json b/codex-rs/core/config.schema.json index 3d091006c975..86c8788228b4 100644 --- a/codex-rs/core/config.schema.json +++ b/codex-rs/core/config.schema.json @@ -434,6 +434,9 @@ "realtime_conversation": { "type": "boolean" }, + "remote_control": { + "type": "boolean" + }, "remote_models": { "type": "boolean" }, @@ -2071,6 +2074,9 @@ "realtime_conversation": { "type": "boolean" }, + "remote_control": { + "type": "boolean" + }, "remote_models": { "type": "boolean" }, diff --git a/codex-rs/features/src/lib.rs b/codex-rs/features/src/lib.rs index 81cc95a47bee..a20a36eddda8 100644 --- a/codex-rs/features/src/lib.rs +++ b/codex-rs/features/src/lib.rs @@ -176,6 +176,8 @@ pub enum Feature { VoiceTranscription, /// Enable experimental realtime voice conversation mode in the TUI. RealtimeConversation, + /// Connect app-server to the ChatGPT remote control service. + RemoteControl, /// Route interactive startup to the app-server-backed TUI implementation. TuiAppServer, /// Prevent idle system sleep while a turn is actively running. @@ -819,6 +821,12 @@ pub const FEATURES: &[FeatureSpec] = &[ stage: Stage::UnderDevelopment, default_enabled: false, }, + FeatureSpec { + id: Feature::RemoteControl, + key: "remote_control", + stage: Stage::UnderDevelopment, + default_enabled: false, + }, FeatureSpec { id: Feature::TuiAppServer, key: "tui_app_server", diff --git a/codex-rs/features/src/tests.rs b/codex-rs/features/src/tests.rs index ecfa87b41bd1..b36286d2a320 100644 --- a/codex-rs/features/src/tests.rs +++ b/codex-rs/features/src/tests.rs @@ -159,6 +159,12 @@ fn image_detail_original_feature_is_under_development() { assert_eq!(Feature::ImageDetailOriginal.default_enabled(), false); } +#[test] +fn remote_control_is_under_development() { + assert_eq!(Feature::RemoteControl.stage(), Stage::UnderDevelopment); + assert_eq!(Feature::RemoteControl.default_enabled(), false); +} + #[test] fn collab_is_legacy_alias_for_multi_agent() { assert_eq!(feature_for_key("multi_agent"), Some(Feature::Collab)); diff --git a/codex-rs/state/migrations/0023_remote_control_enrollments.sql b/codex-rs/state/migrations/0023_remote_control_enrollments.sql new file mode 100644 index 000000000000..9a2081dd8f38 --- /dev/null +++ b/codex-rs/state/migrations/0023_remote_control_enrollments.sql @@ -0,0 +1,8 @@ +CREATE TABLE remote_control_enrollments ( + websocket_url TEXT NOT NULL, + account_id TEXT NOT NULL, + server_id TEXT NOT NULL, + server_name TEXT NOT NULL, + updated_at INTEGER NOT NULL, + PRIMARY KEY (websocket_url, account_id) +); diff --git a/codex-rs/state/src/runtime.rs b/codex-rs/state/src/runtime.rs index 645aa4269561..04eed8e0e2ea 100644 --- a/codex-rs/state/src/runtime.rs +++ b/codex-rs/state/src/runtime.rs @@ -53,6 +53,7 @@ mod agent_jobs; mod backfill; mod logs; mod memories; +mod remote_control; #[cfg(test)] mod test_support; mod threads; diff --git a/codex-rs/state/src/runtime/remote_control.rs b/codex-rs/state/src/runtime/remote_control.rs new file mode 100644 index 000000000000..307dac8184ba --- /dev/null +++ b/codex-rs/state/src/runtime/remote_control.rs @@ -0,0 +1,197 @@ +use super::*; + +const REMOTE_CONTROL_ACCOUNT_ID_NONE: &str = ""; + +fn remote_control_account_id_key(account_id: Option<&str>) -> &str { + account_id.unwrap_or(REMOTE_CONTROL_ACCOUNT_ID_NONE) +} + +impl StateRuntime { + pub async fn get_remote_control_enrollment( + &self, + websocket_url: &str, + account_id: Option<&str>, + ) -> anyhow::Result> { + let row = sqlx::query( + r#" +SELECT server_id, server_name +FROM remote_control_enrollments +WHERE websocket_url = ? AND account_id = ? + "#, + ) + .bind(websocket_url) + .bind(remote_control_account_id_key(account_id)) + .fetch_optional(self.pool.as_ref()) + .await?; + + row.map(|row| Ok((row.try_get("server_id")?, row.try_get("server_name")?))) + .transpose() + } + + pub async fn upsert_remote_control_enrollment( + &self, + websocket_url: &str, + account_id: Option<&str>, + server_id: &str, + server_name: &str, + ) -> anyhow::Result<()> { + sqlx::query( + r#" +INSERT INTO remote_control_enrollments ( + websocket_url, + account_id, + server_id, + server_name, + updated_at +) VALUES (?, ?, ?, ?, ?) +ON CONFLICT(websocket_url, account_id) DO UPDATE SET + server_id = excluded.server_id, + server_name = excluded.server_name, + updated_at = excluded.updated_at + "#, + ) + .bind(websocket_url) + .bind(remote_control_account_id_key(account_id)) + .bind(server_id) + .bind(server_name) + .bind(Utc::now().timestamp()) + .execute(self.pool.as_ref()) + .await?; + Ok(()) + } + + pub async fn delete_remote_control_enrollment( + &self, + websocket_url: &str, + account_id: Option<&str>, + ) -> anyhow::Result { + let result = sqlx::query( + r#" +DELETE FROM remote_control_enrollments +WHERE websocket_url = ? AND account_id = ? + "#, + ) + .bind(websocket_url) + .bind(remote_control_account_id_key(account_id)) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected()) + } +} + +#[cfg(test)] +mod tests { + use super::StateRuntime; + use super::test_support::unique_temp_dir; + use pretty_assertions::assert_eq; + + #[tokio::test] + async fn remote_control_enrollment_round_trips_by_target_and_account() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string()) + .await + .expect("initialize runtime"); + + runtime + .upsert_remote_control_enrollment( + "wss://example.com/backend-api/wham/remote/control/server", + Some("account-a"), + "srv_e_first", + "first-server", + ) + .await + .expect("insert first enrollment"); + runtime + .upsert_remote_control_enrollment( + "wss://example.com/backend-api/wham/remote/control/server", + Some("account-b"), + "srv_e_second", + "second-server", + ) + .await + .expect("insert second enrollment"); + + assert_eq!( + runtime + .get_remote_control_enrollment( + "wss://example.com/backend-api/wham/remote/control/server", + Some("account-a"), + ) + .await + .expect("load first enrollment"), + Some(("srv_e_first".to_string(), "first-server".to_string())) + ); + assert_eq!( + runtime + .get_remote_control_enrollment( + "wss://example.com/backend-api/wham/remote/control/server", + None, + ) + .await + .expect("load missing enrollment"), + None + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn delete_remote_control_enrollment_removes_only_matching_entry() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string()) + .await + .expect("initialize runtime"); + + runtime + .upsert_remote_control_enrollment( + "wss://example.com/backend-api/wham/remote/control/server", + None, + "srv_e_first", + "first-server", + ) + .await + .expect("insert first enrollment"); + runtime + .upsert_remote_control_enrollment( + "wss://example.com/backend-api/wham/remote/control/server", + Some("account-a"), + "srv_e_second", + "second-server", + ) + .await + .expect("insert second enrollment"); + + assert_eq!( + runtime + .delete_remote_control_enrollment( + "wss://example.com/backend-api/wham/remote/control/server", + None, + ) + .await + .expect("delete first enrollment"), + 1 + ); + assert_eq!( + runtime + .get_remote_control_enrollment( + "wss://example.com/backend-api/wham/remote/control/server", + None, + ) + .await + .expect("load deleted enrollment"), + None + ); + assert_eq!( + runtime + .get_remote_control_enrollment( + "wss://example.com/backend-api/wham/remote/control/server", + Some("account-a"), + ) + .await + .expect("load retained enrollment"), + Some(("srv_e_second".to_string(), "second-server".to_string())) + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } +} From a8d299f06532d94483f1b5b102c723a8fc4cf4c6 Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Fri, 27 Mar 2026 10:44:29 -0700 Subject: [PATCH 02/17] always https except localhost --- .../src/transport/remote_control/enroll.rs | 11 ++-- .../src/transport/remote_control/protocol.rs | 54 ++++++++++++++++++- .../src/transport/remote_control/tests.rs | 52 ++++++------------ .../src/transport/remote_control/websocket.rs | 38 +++++-------- 4 files changed, 88 insertions(+), 67 deletions(-) diff --git a/codex-rs/app-server/src/transport/remote_control/enroll.rs b/codex-rs/app-server/src/transport/remote_control/enroll.rs index 99b71c9fd759..5ac851f4a8a4 100644 --- a/codex-rs/app-server/src/transport/remote_control/enroll.rs +++ b/codex-rs/app-server/src/transport/remote_control/enroll.rs @@ -212,9 +212,9 @@ mod tests { async fn persisted_remote_control_enrollment_round_trips_by_target_and_account() { let codex_home = TempDir::new().expect("temp dir should create"); let state_db = remote_control_state_runtime(&codex_home).await; - let first_target = normalize_remote_control_url("http://example.com/remote/control") + let first_target = normalize_remote_control_url("https://example.com/remote/control") .expect("first target should parse"); - let second_target = normalize_remote_control_url("http://example.com/other/control") + let second_target = normalize_remote_control_url("https://example.com/other/control") .expect("second target should parse"); let first_enrollment = RemoteControlEnrollment { account_id: Some("account-a".to_string()), @@ -277,9 +277,9 @@ mod tests { async fn clearing_persisted_remote_control_enrollment_removes_only_matching_entry() { let codex_home = TempDir::new().expect("temp dir should create"); let state_db = remote_control_state_runtime(&codex_home).await; - let first_target = normalize_remote_control_url("http://example.com/remote/control") + let first_target = normalize_remote_control_url("https://example.com/remote/control") .expect("first target should parse"); - let second_target = normalize_remote_control_url("http://example.com/other/control") + let second_target = normalize_remote_control_url("https://example.com/other/control") .expect("second target should parse"); let first_enrollment = RemoteControlEnrollment { account_id: Some("account-a".to_string()), @@ -344,10 +344,11 @@ mod tests { .await .expect("listener should bind"); let remote_control_url = format!( - "http://{}/backend-api/", + "http://localhost:{}/backend-api/", listener .local_addr() .expect("listener should have a local addr") + .port() ); let remote_control_target = normalize_remote_control_url(&remote_control_url).expect("target should parse"); diff --git a/codex-rs/app-server/src/transport/remote_control/protocol.rs b/codex-rs/app-server/src/transport/remote_control/protocol.rs index 5981bbac0fd2..76d7a9fb1f39 100644 --- a/codex-rs/app-server/src/transport/remote_control/protocol.rs +++ b/codex-rs/app-server/src/transport/remote_control/protocol.rs @@ -99,14 +99,18 @@ pub(super) fn normalize_remote_control_url( io::Error::new( ErrorKind::InvalidInput, format!( - "invalid remote control URL `{remote_control_url}`; expected absolute URL with http:// or https:// scheme" + "invalid remote control URL `{remote_control_url}`; expected absolute HTTPS URL" ), ) }; let mut remote_control_url = Url::parse(remote_control_url).map_err(map_url_parse_error)?; match remote_control_url.scheme() { - "https" | "http" => {} + "https" => {} + "http" + if remote_control_url + .host_str() + .is_some_and(|host| host.eq_ignore_ascii_case("localhost")) => {} _ => return Err(map_scheme_error(())), } if !remote_control_url.path().ends_with('/') { @@ -137,3 +141,49 @@ pub(super) fn normalize_remote_control_url( enroll_url: enroll_url.to_string(), }) } + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn normalize_remote_control_url_accepts_https_urls() { + assert_eq!( + normalize_remote_control_url("https://example.com/backend-api") + .expect("https URL should normalize"), + RemoteControlTarget { + websocket_url: "wss://example.com/backend-api/wham/remote/control/server" + .to_string(), + enroll_url: "https://example.com/backend-api/wham/remote/control/server/enroll" + .to_string(), + } + ); + } + + #[test] + fn normalize_remote_control_url_accepts_localhost_http_urls() { + assert_eq!( + normalize_remote_control_url("http://localhost:8080/backend-api") + .expect("localhost http URL should normalize"), + RemoteControlTarget { + websocket_url: "ws://localhost:8080/backend-api/wham/remote/control/server" + .to_string(), + enroll_url: "http://localhost:8080/backend-api/wham/remote/control/server/enroll" + .to_string(), + } + ); + } + + #[test] + fn normalize_remote_control_url_rejects_non_localhost_http_urls() { + let err = normalize_remote_control_url("http://example.com/backend-api") + .expect_err("non-localhost http URL should be rejected"); + + assert_eq!(err.kind(), ErrorKind::InvalidInput); + assert_eq!( + err.to_string(), + "invalid remote control URL `http://example.com/backend-api`; expected absolute HTTPS URL" + ); + } +} diff --git a/codex-rs/app-server/src/transport/remote_control/tests.rs b/codex-rs/app-server/src/transport/remote_control/tests.rs index 7f9a436343e6..e0556ebfbeee 100644 --- a/codex-rs/app-server/src/transport/remote_control/tests.rs +++ b/codex-rs/app-server/src/transport/remote_control/tests.rs @@ -61,17 +61,22 @@ async fn remote_control_state_runtime(codex_home: &TempDir) -> Arc .expect("state runtime should initialize") } +fn remote_control_url_for_listener(listener: &TcpListener) -> String { + format!( + "http://localhost:{}/backend-api/", + listener + .local_addr() + .expect("listener should have a local addr") + .port() + ) +} + #[tokio::test] async fn remote_control_transport_manages_virtual_clients_and_routes_messages() { let listener = TcpListener::bind("127.0.0.1:0") .await .expect("listener should bind"); - let remote_control_url = format!( - "http://{}/backend-api/", - listener - .local_addr() - .expect("listener should have a local addr") - ); + let remote_control_url = remote_control_url_for_listener(&listener); let codex_home = TempDir::new().expect("temp dir should create"); let (transport_event_tx, mut transport_event_rx) = mpsc::channel::(CHANNEL_CAPACITY); @@ -323,12 +328,7 @@ async fn remote_control_transport_reconnects_after_disconnect() { let listener = TcpListener::bind("127.0.0.1:0") .await .expect("listener should bind"); - let remote_control_url = format!( - "http://{}/backend-api/", - listener - .local_addr() - .expect("listener should have a local addr") - ); + let remote_control_url = remote_control_url_for_listener(&listener); let codex_home = TempDir::new().expect("temp dir should create"); let (transport_event_tx, mut transport_event_rx) = mpsc::channel::(CHANNEL_CAPACITY); @@ -398,12 +398,7 @@ async fn remote_control_transport_clears_outgoing_buffer_when_client_closes() { let listener = TcpListener::bind("127.0.0.1:0") .await .expect("listener should bind"); - let remote_control_url = format!( - "http://{}/backend-api/", - listener - .local_addr() - .expect("listener should have a local addr") - ); + let remote_control_url = remote_control_url_for_listener(&listener); let codex_home = TempDir::new().expect("temp dir should create"); let (transport_event_tx, mut transport_event_rx) = mpsc::channel::(CHANNEL_CAPACITY); @@ -548,12 +543,7 @@ async fn remote_control_http_mode_enrolls_before_connecting() { let listener = TcpListener::bind("127.0.0.1:0") .await .expect("listener should bind"); - let remote_control_url = format!( - "http://{}/backend-api/", - listener - .local_addr() - .expect("listener should have a local addr") - ); + let remote_control_url = remote_control_url_for_listener(&listener); let codex_home = TempDir::new().expect("temp dir should create"); let (transport_event_tx, mut transport_event_rx) = mpsc::channel::(CHANNEL_CAPACITY); @@ -745,12 +735,7 @@ async fn remote_control_http_mode_reuses_persisted_enrollment_before_reenrolling let listener = TcpListener::bind("127.0.0.1:0") .await .expect("listener should bind"); - let remote_control_url = format!( - "http://{}/backend-api/", - listener - .local_addr() - .expect("listener should have a local addr") - ); + let remote_control_url = remote_control_url_for_listener(&listener); let codex_home = TempDir::new().expect("temp dir should create"); let state_db = remote_control_state_runtime(&codex_home).await; let remote_control_target = @@ -810,12 +795,7 @@ async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404() let listener = TcpListener::bind("127.0.0.1:0") .await .expect("listener should bind"); - let remote_control_url = format!( - "http://{}/backend-api/", - listener - .local_addr() - .expect("listener should have a local addr") - ); + let remote_control_url = remote_control_url_for_listener(&listener); let codex_home = TempDir::new().expect("temp dir should create"); let state_db = remote_control_state_runtime(&codex_home).await; let remote_control_target = diff --git a/codex-rs/app-server/src/transport/remote_control/websocket.rs b/codex-rs/app-server/src/transport/remote_control/websocket.rs index 7b743efa4480..a71a89797323 100644 --- a/codex-rs/app-server/src/transport/remote_control/websocket.rs +++ b/codex-rs/app-server/src/transport/remote_control/websocket.rs @@ -764,6 +764,16 @@ mod tests { auth_manager_from_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing()) } + fn remote_control_url_for_listener(listener: &TcpListener) -> String { + format!( + "http://localhost:{}/backend-api/", + listener + .local_addr() + .expect("listener should have a local addr") + .port() + ) + } + fn remote_control_auth_dot_json(access_token: &str) -> AuthDotJson { #[derive(serde::Serialize)] struct Header { @@ -806,12 +816,7 @@ mod tests { let listener = TcpListener::bind("127.0.0.1:0") .await .expect("listener should bind"); - let remote_control_url = format!( - "http://{}/backend-api/", - listener - .local_addr() - .expect("listener should have a local addr") - ); + let remote_control_url = remote_control_url_for_listener(&listener); let remote_control_target = normalize_remote_control_url(&remote_control_url).expect("target should parse"); let expected_error = format!( @@ -865,12 +870,7 @@ mod tests { let listener = TcpListener::bind("127.0.0.1:0") .await .expect("listener should bind"); - let remote_control_url = format!( - "http://{}/backend-api/", - listener - .local_addr() - .expect("listener should have a local addr") - ); + let remote_control_url = remote_control_url_for_listener(&listener); let remote_control_target = normalize_remote_control_url(&remote_control_url).expect("target should parse"); let server_task = tokio::spawn(async move { @@ -939,12 +939,7 @@ mod tests { let listener = TcpListener::bind("127.0.0.1:0") .await .expect("listener should bind"); - let remote_control_url = format!( - "http://{}/backend-api/", - listener - .local_addr() - .expect("listener should have a local addr") - ); + let remote_control_url = remote_control_url_for_listener(&listener); let remote_control_target = normalize_remote_control_url(&remote_control_url).expect("target should parse"); let enroll_url = remote_control_target.enroll_url.clone(); @@ -1012,12 +1007,7 @@ mod tests { let listener = TcpListener::bind("127.0.0.1:0") .await .expect("listener should bind"); - let remote_control_url = format!( - "http://{}/backend-api/", - listener - .local_addr() - .expect("listener should have a local addr") - ); + let remote_control_url = remote_control_url_for_listener(&listener); drop(listener); let remote_control_target = From 5537c4c01435bf4bcbe50f2fdf39a5e0821cdbcf Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Fri, 27 Mar 2026 12:24:29 -0700 Subject: [PATCH 03/17] lock to chatgpt.com --- .../src/transport/remote_control/enroll.rs | 14 ++-- .../src/transport/remote_control/protocol.rs | 79 ++++++++++++++----- 2 files changed, 66 insertions(+), 27 deletions(-) diff --git a/codex-rs/app-server/src/transport/remote_control/enroll.rs b/codex-rs/app-server/src/transport/remote_control/enroll.rs index 5ac851f4a8a4..ab689496b35d 100644 --- a/codex-rs/app-server/src/transport/remote_control/enroll.rs +++ b/codex-rs/app-server/src/transport/remote_control/enroll.rs @@ -212,10 +212,11 @@ mod tests { async fn persisted_remote_control_enrollment_round_trips_by_target_and_account() { let codex_home = TempDir::new().expect("temp dir should create"); let state_db = remote_control_state_runtime(&codex_home).await; - let first_target = normalize_remote_control_url("https://example.com/remote/control") + let first_target = normalize_remote_control_url("https://chatgpt.com/remote/control") .expect("first target should parse"); - let second_target = normalize_remote_control_url("https://example.com/other/control") - .expect("second target should parse"); + let second_target = + normalize_remote_control_url("https://api.chatgpt-staging.com/other/control") + .expect("second target should parse"); let first_enrollment = RemoteControlEnrollment { account_id: Some("account-a".to_string()), server_id: "srv_e_first".to_string(), @@ -277,10 +278,11 @@ mod tests { async fn clearing_persisted_remote_control_enrollment_removes_only_matching_entry() { let codex_home = TempDir::new().expect("temp dir should create"); let state_db = remote_control_state_runtime(&codex_home).await; - let first_target = normalize_remote_control_url("https://example.com/remote/control") + let first_target = normalize_remote_control_url("https://chatgpt.com/remote/control") .expect("first target should parse"); - let second_target = normalize_remote_control_url("https://example.com/other/control") - .expect("second target should parse"); + let second_target = + normalize_remote_control_url("https://api.chatgpt-staging.com/other/control") + .expect("second target should parse"); let first_enrollment = RemoteControlEnrollment { account_id: Some("account-a".to_string()), server_id: "srv_e_first".to_string(), diff --git a/codex-rs/app-server/src/transport/remote_control/protocol.rs b/codex-rs/app-server/src/transport/remote_control/protocol.rs index 76d7a9fb1f39..4785abb3d35f 100644 --- a/codex-rs/app-server/src/transport/remote_control/protocol.rs +++ b/codex-rs/app-server/src/transport/remote_control/protocol.rs @@ -99,18 +99,22 @@ pub(super) fn normalize_remote_control_url( io::Error::new( ErrorKind::InvalidInput, format!( - "invalid remote control URL `{remote_control_url}`; expected absolute HTTPS URL" + "invalid remote control URL `{remote_control_url}`; expected HTTPS URL for chatgpt.com or chatgpt-staging.com, or HTTP/HTTPS URL for localhost" ), ) }; let mut remote_control_url = Url::parse(remote_control_url).map_err(map_url_parse_error)?; + let host = remote_control_url.host_str(); + let is_localhost = host == Some("localhost"); + let is_allowed_chatgpt_host = host.is_some_and(|host| { + matches!(host, "chatgpt.com" | "chatgpt-staging.com") + || host.ends_with(".chatgpt.com") + || host.ends_with(".chatgpt-staging.com") + }); match remote_control_url.scheme() { - "https" => {} - "http" - if remote_control_url - .host_str() - .is_some_and(|host| host.eq_ignore_ascii_case("localhost")) => {} + "https" if is_localhost || is_allowed_chatgpt_host => {} + "http" if is_localhost => {} _ => return Err(map_scheme_error(())), } if !remote_control_url.path().ends_with('/') { @@ -148,21 +152,33 @@ mod tests { use pretty_assertions::assert_eq; #[test] - fn normalize_remote_control_url_accepts_https_urls() { + fn normalize_remote_control_url_accepts_chatgpt_https_urls() { assert_eq!( - normalize_remote_control_url("https://example.com/backend-api") - .expect("https URL should normalize"), + normalize_remote_control_url("https://chatgpt.com/backend-api") + .expect("chatgpt.com URL should normalize"), RemoteControlTarget { - websocket_url: "wss://example.com/backend-api/wham/remote/control/server" + websocket_url: "wss://chatgpt.com/backend-api/wham/remote/control/server" .to_string(), - enroll_url: "https://example.com/backend-api/wham/remote/control/server/enroll" + enroll_url: "https://chatgpt.com/backend-api/wham/remote/control/server/enroll" .to_string(), } ); + assert_eq!( + normalize_remote_control_url("https://api.chatgpt-staging.com/backend-api") + .expect("chatgpt-staging.com subdomain URL should normalize"), + RemoteControlTarget { + websocket_url: + "wss://api.chatgpt-staging.com/backend-api/wham/remote/control/server" + .to_string(), + enroll_url: + "https://api.chatgpt-staging.com/backend-api/wham/remote/control/server/enroll" + .to_string(), + } + ); } #[test] - fn normalize_remote_control_url_accepts_localhost_http_urls() { + fn normalize_remote_control_url_accepts_localhost_urls() { assert_eq!( normalize_remote_control_url("http://localhost:8080/backend-api") .expect("localhost http URL should normalize"), @@ -173,17 +189,38 @@ mod tests { .to_string(), } ); + assert_eq!( + normalize_remote_control_url("https://localhost:8443/backend-api") + .expect("localhost https URL should normalize"), + RemoteControlTarget { + websocket_url: "wss://localhost:8443/backend-api/wham/remote/control/server" + .to_string(), + enroll_url: "https://localhost:8443/backend-api/wham/remote/control/server/enroll" + .to_string(), + } + ); } #[test] - fn normalize_remote_control_url_rejects_non_localhost_http_urls() { - let err = normalize_remote_control_url("http://example.com/backend-api") - .expect_err("non-localhost http URL should be rejected"); - - assert_eq!(err.kind(), ErrorKind::InvalidInput); - assert_eq!( - err.to_string(), - "invalid remote control URL `http://example.com/backend-api`; expected absolute HTTPS URL" - ); + fn normalize_remote_control_url_rejects_unsupported_urls() { + for remote_control_url in [ + "http://chatgpt.com/backend-api", + "http://example.com/backend-api", + "https://example.com/backend-api", + "https://chatgpt.com.evil.com/backend-api", + "https://evilchatgpt.com/backend-api", + "https://foo.localhost/backend-api", + ] { + let err = normalize_remote_control_url(remote_control_url) + .expect_err("unsupported URL should be rejected"); + + assert_eq!(err.kind(), ErrorKind::InvalidInput); + assert_eq!( + err.to_string(), + format!( + "invalid remote control URL `{remote_control_url}`; expected HTTPS URL for chatgpt.com or chatgpt-staging.com, or HTTP/HTTPS URL for localhost" + ) + ); + } } } From db806e3aaa0b032773c945ca4fe017d28725b922 Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Fri, 27 Mar 2026 12:35:23 -0700 Subject: [PATCH 04/17] fix test --- .../remote_control/client_tracker.rs | 28 +++++++++++++------ .../src/transport/remote_control/websocket.rs | 10 ++++++- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/codex-rs/app-server/src/transport/remote_control/client_tracker.rs b/codex-rs/app-server/src/transport/remote_control/client_tracker.rs index f23a1a441b24..ef3af37942be 100644 --- a/codex-rs/app-server/src/transport/remote_control/client_tracker.rs +++ b/codex-rs/app-server/src/transport/remote_control/client_tracker.rs @@ -34,7 +34,7 @@ struct ClientState { pub(crate) struct ClientTracker { clients: HashMap, - join_set: JoinSet<()>, + join_set: JoinSet, server_event_tx: mpsc::Sender, transport_event_tx: mpsc::Sender, shutdown_token: CancellationToken, @@ -55,8 +55,13 @@ impl ClientTracker { } } - pub(crate) async fn bookkeep_join_set(&mut self) { - while self.join_set.join_next().await.is_some() {} + pub(crate) async fn bookkeep_join_set(&mut self) -> Option { + while let Some(join_result) = self.join_set.join_next().await { + let Ok(client_id) = join_result else { + continue; + }; + return Some(client_id); + } futures::future::pending().await } @@ -168,7 +173,7 @@ impl ClientTracker { } let server_event_tx = self.server_event_tx.clone(); - self.join_set.spawn(async move { + tokio::spawn(async move { let server_envelope = QueuedServerEnvelope { event: ServerEvent::Pong { status: PongStatus::Unknown, @@ -190,7 +195,7 @@ impl ClientTracker { mut writer_rx: mpsc::Receiver, mut status_rx: watch::Receiver, disconnect_token: CancellationToken, - ) { + ) -> ClientId { loop { let (event, write_complete_tx) = tokio::select! { _ = disconnect_token.cancelled() => { @@ -227,6 +232,7 @@ impl ClientTracker { break; } } + client_id } pub(crate) async fn close_expired_clients(&mut self) -> Result, Stopped> { @@ -244,7 +250,7 @@ impl ClientTracker { Ok(expired_client_ids) } - async fn close_client(&mut self, client_id: &ClientId) -> Result<(), Stopped> { + pub(super) async fn close_client(&mut self, client_id: &ClientId) -> Result<(), Stopped> { let Some(client) = self.clients.remove(client_id) else { return Ok(()); }; @@ -348,9 +354,15 @@ mod tests { } disconnect_sender.cancel(); - timeout(Duration::from_secs(1), client_tracker.bookkeep_join_set()) + let closed_client_id = timeout(Duration::from_secs(1), client_tracker.bookkeep_join_set()) + .await + .expect("bookkeeping should process the closed task") + .expect("closed task should return client id"); + assert_eq!(closed_client_id, ClientId("client-1".to_string())); + client_tracker + .close_client(&closed_client_id) .await - .expect_err("bookkeeping should process the closed task and stay pending"); + .expect("closed client should emit connection closed"); match transport_event_rx .recv() diff --git a/codex-rs/app-server/src/transport/remote_control/websocket.rs b/codex-rs/app-server/src/transport/remote_control/websocket.rs index a71a89797323..d5c54d2e2a24 100644 --- a/codex-rs/app-server/src/transport/remote_control/websocket.rs +++ b/codex-rs/app-server/src/transport/remote_control/websocket.rs @@ -380,7 +380,15 @@ impl RemoteControlWebsocket { loop { let incoming_message = tokio::select! { _ = shutdown_token.cancelled() => return Ok(()), - _ = client_tracker.bookkeep_join_set() => continue, + client_id = client_tracker.bookkeep_join_set() => { + let Some(client_id) = client_id else { + continue; + }; + if client_tracker.close_client(&client_id).await.is_err() { + return Ok(()); + } + continue; + } _ = idle_sweep_interval.tick() => { let expired_client_ids = match client_tracker.close_expired_clients().await { Ok(expired_client_ids) => expired_client_ids, From 558b357a6c3a4c4f890b322c19828be20407ebc8 Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Fri, 27 Mar 2026 12:40:38 -0700 Subject: [PATCH 05/17] allow ips --- codex-rs/app-server/src/transport/remote_control/protocol.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codex-rs/app-server/src/transport/remote_control/protocol.rs b/codex-rs/app-server/src/transport/remote_control/protocol.rs index 4785abb3d35f..03b854939f21 100644 --- a/codex-rs/app-server/src/transport/remote_control/protocol.rs +++ b/codex-rs/app-server/src/transport/remote_control/protocol.rs @@ -106,7 +106,7 @@ pub(super) fn normalize_remote_control_url( let mut remote_control_url = Url::parse(remote_control_url).map_err(map_url_parse_error)?; let host = remote_control_url.host_str(); - let is_localhost = host == Some("localhost"); + let is_localhost = matches!(host, Some("localhost") | Some("127.0.0.1") | Some("::1")); let is_allowed_chatgpt_host = host.is_some_and(|host| { matches!(host, "chatgpt.com" | "chatgpt-staging.com") || host.ends_with(".chatgpt.com") From bca7d04529c2dc5a0537e31aaf4286c1f8f54bab Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Fri, 27 Mar 2026 13:00:55 -0700 Subject: [PATCH 06/17] simplify, cover all loopback ips --- .../src/transport/remote_control/protocol.rs | 43 +++++++++++-------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/codex-rs/app-server/src/transport/remote_control/protocol.rs b/codex-rs/app-server/src/transport/remote_control/protocol.rs index 03b854939f21..4e057e565d75 100644 --- a/codex-rs/app-server/src/transport/remote_control/protocol.rs +++ b/codex-rs/app-server/src/transport/remote_control/protocol.rs @@ -4,6 +4,7 @@ use serde::Deserialize; use serde::Serialize; use std::io; use std::io::ErrorKind; +use url::Host; use url::Url; #[derive(Debug, Clone, PartialEq, Eq)] @@ -86,6 +87,25 @@ pub(crate) struct ServerEnvelope { pub(crate) seq_id: u64, } +fn is_allowed_chatgpt_host(host: &Option>) -> bool { + let Some(Host::Domain(host)) = *host else { + return false; + }; + host == "chatgpt.com" + || host == "chatgpt-staging.com" + || host.ends_with(".chatgpt.com") + || host.ends_with(".chatgpt-staging.com") +} + +fn is_localhost(host: &Option>) -> bool { + match host { + Some(Host::Domain("localhost")) => true, + Some(Host::Ipv4(ip)) => ip.is_loopback(), + Some(Host::Ipv6(ip)) => ip.is_loopback(), + _ => false, + } +} + pub(super) fn normalize_remote_control_url( remote_control_url: &str, ) -> io::Result { @@ -105,36 +125,23 @@ pub(super) fn normalize_remote_control_url( }; let mut remote_control_url = Url::parse(remote_control_url).map_err(map_url_parse_error)?; - let host = remote_control_url.host_str(); - let is_localhost = matches!(host, Some("localhost") | Some("127.0.0.1") | Some("::1")); - let is_allowed_chatgpt_host = host.is_some_and(|host| { - matches!(host, "chatgpt.com" | "chatgpt-staging.com") - || host.ends_with(".chatgpt.com") - || host.ends_with(".chatgpt-staging.com") - }); - match remote_control_url.scheme() { - "https" if is_localhost || is_allowed_chatgpt_host => {} - "http" if is_localhost => {} - _ => return Err(map_scheme_error(())), - } if !remote_control_url.path().ends_with('/') { let normalized_path = format!("{}/", remote_control_url.path()); remote_control_url.set_path(&normalized_path); } - let mut enroll_url = remote_control_url + let enroll_url = remote_control_url .join("wham/remote/control/server/enroll") .map_err(map_url_parse_error)?; let mut websocket_url = remote_control_url .join("wham/remote/control/server") .map_err(map_url_parse_error)?; - match remote_control_url.scheme() { - "https" => { - enroll_url.set_scheme("https").map_err(map_scheme_error)?; + let host = enroll_url.host(); + match enroll_url.scheme() { + "https" if is_localhost(&host) || is_allowed_chatgpt_host(&host) => { websocket_url.set_scheme("wss").map_err(map_scheme_error)?; } - "http" => { - enroll_url.set_scheme("http").map_err(map_scheme_error)?; + "http" if is_localhost(&host) => { websocket_url.set_scheme("ws").map_err(map_scheme_error)?; } _ => return Err(map_scheme_error(())), From f2623f20f3ff7ac7a622abf280103cfa0e96db4f Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Fri, 27 Mar 2026 13:33:48 -0700 Subject: [PATCH 07/17] maybe fix windows? --- codex-rs/app-server/src/transport/remote_control/enroll.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codex-rs/app-server/src/transport/remote_control/enroll.rs b/codex-rs/app-server/src/transport/remote_control/enroll.rs index ab689496b35d..276524ea45b1 100644 --- a/codex-rs/app-server/src/transport/remote_control/enroll.rs +++ b/codex-rs/app-server/src/transport/remote_control/enroll.rs @@ -346,7 +346,7 @@ mod tests { .await .expect("listener should bind"); let remote_control_url = format!( - "http://localhost:{}/backend-api/", + "http://127.0.0.1:{}/backend-api/", listener .local_addr() .expect("listener should have a local addr") From 457fabc4099f9a53c2f4f3560c4aef6b1e69676d Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Fri, 27 Mar 2026 13:43:50 -0700 Subject: [PATCH 08/17] rm run_remote_control_websocket_loop --- .../src/transport/remote_control/mod.rs | 7 +-- .../src/transport/remote_control/websocket.rs | 46 ++++++++----------- 2 files changed, 22 insertions(+), 31 deletions(-) diff --git a/codex-rs/app-server/src/transport/remote_control/mod.rs b/codex-rs/app-server/src/transport/remote_control/mod.rs index 11d9302bb46e..e9d91b17e1fa 100644 --- a/codex-rs/app-server/src/transport/remote_control/mod.rs +++ b/codex-rs/app-server/src/transport/remote_control/mod.rs @@ -3,12 +3,12 @@ mod enroll; mod protocol; mod websocket; +use crate::transport::remote_control::websocket::RemoteControlWebsocket; use crate::transport::remote_control::websocket::load_remote_control_auth; pub use self::protocol::ClientId; use self::protocol::ServerEvent; use self::protocol::normalize_remote_control_url; -use self::websocket::run_remote_control_websocket_loop; use super::CHANNEL_CAPACITY; use super::TransportEvent; use super::next_connection_id; @@ -38,13 +38,14 @@ pub(crate) async fn start_remote_control( validate_remote_control_auth(&auth_manager).await?; Ok(tokio::spawn(async move { - run_remote_control_websocket_loop( + RemoteControlWebsocket::new( remote_control_target, state_db, auth_manager, transport_event_tx, - shutdown_token.child_token(), + shutdown_token, ) + .run() .await; })) } diff --git a/codex-rs/app-server/src/transport/remote_control/websocket.rs b/codex-rs/app-server/src/transport/remote_control/websocket.rs index d5c54d2e2a24..74416642af0e 100644 --- a/codex-rs/app-server/src/transport/remote_control/websocket.rs +++ b/codex-rs/app-server/src/transport/remote_control/websocket.rs @@ -106,7 +106,7 @@ struct WebsocketState { next_seq_id: u64, } -struct RemoteControlWebsocket { +pub(crate) struct RemoteControlWebsocket { remote_control_target: RemoteControlTarget, state_db: Option>, auth_manager: Arc, @@ -121,13 +121,14 @@ struct RemoteControlWebsocket { } impl RemoteControlWebsocket { - fn new( + pub(crate) fn new( remote_control_target: RemoteControlTarget, state_db: Option>, auth_manager: Arc, transport_event_tx: mpsc::Sender, shutdown_token: CancellationToken, ) -> Self { + let shutdown_token = shutdown_token.child_token(); let (server_event_tx, server_event_rx) = mpsc::channel(super::CHANNEL_CAPACITY); let client_tracker = ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token); @@ -153,7 +154,7 @@ impl RemoteControlWebsocket { } } - async fn run(mut self) { + pub(crate) async fn run(mut self) { loop { let shutdown_token = self.shutdown_token.child_token(); let websocket_connection = match self.connect(&shutdown_token).await { @@ -469,24 +470,6 @@ impl RemoteControlWebsocket { } } -pub(super) async fn run_remote_control_websocket_loop( - remote_control_target: RemoteControlTarget, - state_db: Option>, - auth_manager: Arc, - transport_event_tx: mpsc::Sender, - shutdown_token: CancellationToken, -) { - RemoteControlWebsocket::new( - remote_control_target, - state_db, - auth_manager, - transport_event_tx, - shutdown_token, - ) - .run() - .await; -} - fn remote_control_message_starts_connection(event: &ClientEvent) -> bool { matches!( event, @@ -1023,13 +1006,20 @@ mod tests { let (transport_event_tx, transport_event_rx) = mpsc::channel(1); drop(transport_event_rx); let shutdown_token = CancellationToken::new(); - let websocket_task = tokio::spawn(run_remote_control_websocket_loop( - remote_control_target, - None, - remote_control_auth_manager(), - transport_event_tx, - shutdown_token.clone(), - )); + let websocket_task = tokio::spawn({ + let shutdown_token = shutdown_token.clone(); + async move { + RemoteControlWebsocket::new( + remote_control_target, + None, + remote_control_auth_manager(), + transport_event_tx, + shutdown_token, + ) + .run() + .await + } + }); tokio::time::sleep(Duration::from_millis(50)).await; shutdown_token.cancel(); From 0f1f511e3c4a6bd19babf5fd3e88854d3f3e35d1 Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Fri, 27 Mar 2026 14:17:41 -0700 Subject: [PATCH 09/17] maybe fix windows? --- .../src/transport/remote_control/enroll.rs | 33 ++++++++++++++++--- 1 file changed, 29 insertions(+), 4 deletions(-) diff --git a/codex-rs/app-server/src/transport/remote_control/enroll.rs b/codex-rs/app-server/src/transport/remote_control/enroll.rs index 276524ea45b1..23b190f96205 100644 --- a/codex-rs/app-server/src/transport/remote_control/enroll.rs +++ b/codex-rs/app-server/src/transport/remote_control/enroll.rs @@ -196,7 +196,9 @@ mod tests { use serde_json::json; use std::sync::Arc; use tempfile::TempDir; + use tokio::io::AsyncBufReadExt; use tokio::io::AsyncWriteExt; + use tokio::io::BufReader; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio::time::Duration; @@ -360,10 +362,7 @@ mod tests { }); let expected_body = response_body.to_string(); let server_task = tokio::spawn(async move { - let (stream, _) = timeout(Duration::from_secs(5), listener.accept()) - .await - .expect("HTTP request should arrive in time") - .expect("listener accept should succeed"); + let stream = accept_http_request(&listener).await; respond_with_json(stream, response_body).await; }); @@ -387,6 +386,32 @@ mod tests { ); } + async fn accept_http_request(listener: &TcpListener) -> TcpStream { + let (stream, _) = timeout(Duration::from_secs(5), listener.accept()) + .await + .expect("HTTP request should arrive in time") + .expect("listener accept should succeed"); + let mut reader = BufReader::new(stream); + + let mut request_line = String::new(); + reader + .read_line(&mut request_line) + .await + .expect("request line should read"); + loop { + let mut line = String::new(); + reader + .read_line(&mut line) + .await + .expect("header line should read"); + if line == "\r\n" { + break; + } + } + + reader.into_inner() + } + async fn respond_with_json(mut stream: TcpStream, body: serde_json::Value) { let body = body.to_string(); let response = format!( From e716db62b1f3d3ae6b570c8a2843c459fd7e2976 Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Fri, 27 Mar 2026 23:18:04 -0700 Subject: [PATCH 10/17] cr --- .../remote_control/client_tracker.rs | 222 ++++++++++++++---- .../src/transport/remote_control/enroll.rs | 20 +- .../src/transport/remote_control/mod.rs | 2 + .../src/transport/remote_control/protocol.rs | 27 ++- .../src/transport/remote_control/tests.rs | 55 ++++- .../src/transport/remote_control/websocket.rs | 57 +++-- .../0023_remote_control_enrollments.sql | 1 + codex-rs/state/src/runtime/remote_control.rs | 36 ++- 8 files changed, 333 insertions(+), 87 deletions(-) diff --git a/codex-rs/app-server/src/transport/remote_control/client_tracker.rs b/codex-rs/app-server/src/transport/remote_control/client_tracker.rs index ef3af37942be..be80ad523a49 100644 --- a/codex-rs/app-server/src/transport/remote_control/client_tracker.rs +++ b/codex-rs/app-server/src/transport/remote_control/client_tracker.rs @@ -6,6 +6,7 @@ pub use super::protocol::ClientEvent; pub use super::protocol::ClientId; use super::protocol::PongStatus; use super::protocol::ServerEvent; +use super::protocol::StreamId; use crate::outgoing_message::ConnectionId; use crate::outgoing_message::QueuedOutgoingMessage; use crate::transport::remote_control::QueuedServerEnvelope; @@ -33,8 +34,9 @@ struct ClientState { } pub(crate) struct ClientTracker { - clients: HashMap, - join_set: JoinSet, + clients: HashMap<(ClientId, StreamId), ClientState>, + legacy_stream_ids: HashMap, + join_set: JoinSet<(ClientId, StreamId)>, server_event_tx: mpsc::Sender, transport_event_tx: mpsc::Sender, shutdown_token: CancellationToken, @@ -48,6 +50,7 @@ impl ClientTracker { ) -> Self { Self { clients: HashMap::new(), + legacy_stream_ids: HashMap::new(), join_set: JoinSet::new(), server_event_tx, transport_event_tx, @@ -55,12 +58,12 @@ impl ClientTracker { } } - pub(crate) async fn bookkeep_join_set(&mut self) -> Option { + pub(crate) async fn bookkeep_join_set(&mut self) -> Option<(ClientId, StreamId)> { while let Some(join_result) = self.join_set.join_next().await { - let Ok(client_id) = join_result else { + let Ok(client_key) = join_result else { continue; }; - return Some(client_id); + return Some(client_key); } futures::future::pending().await } @@ -68,8 +71,8 @@ impl ClientTracker { pub(crate) async fn shutdown(&mut self) { self.shutdown_token.cancel(); - while let Some(client_id) = self.clients.keys().next().cloned() { - let _ = self.close_client(&client_id).await; + while let Some(client_key) = self.clients.keys().next().cloned() { + let _ = self.close_client(&client_key).await; } self.drain_join_set().await; @@ -79,6 +82,10 @@ impl ClientTracker { while self.join_set.join_next().await.is_some() {} } + pub(crate) fn legacy_stream_id(&self, client_id: &ClientId) -> Option { + self.legacy_stream_ids.get(client_id).cloned() + } + pub(crate) async fn handle_message( &mut self, client_envelope: ClientEnvelope, @@ -86,14 +93,40 @@ impl ClientTracker { let ClientEnvelope { client_id, event, + stream_id, seq_id, cursor: _, } = client_envelope; + let is_legacy_stream_id = stream_id.is_none(); + let is_initialize = matches!(&event, ClientEvent::ClientMessage { message } if remote_control_message_starts_connection(message)); + let stream_id = match stream_id { + Some(stream_id) => stream_id, + None if is_initialize => { + // TODO(ruslan): delete this fallback once all clients are updated to send stream_id. + self.legacy_stream_ids + .remove(&client_id) + .unwrap_or_else(StreamId::new_random) + } + None => self + .legacy_stream_ids + .get(&client_id) + .cloned() + .unwrap_or_else(|| { + if matches!(&event, ClientEvent::Ping) { + StreamId::new_random() + } else { + StreamId(String::new()) + } + }), + }; + if stream_id.0.is_empty() { + return Ok(()); + } + let client_key = (client_id.clone(), stream_id.clone()); match event { ClientEvent::ClientMessage { message } => { - let is_initialize = remote_control_message_starts_connection(&message); if let Some(seq_id) = seq_id - && let Some(client) = self.clients.get(&client_id) + && let Some(client) = self.clients.get(&client_key) && client .last_inbound_seq_id .is_some_and(|last_seq_id| last_seq_id >= seq_id) @@ -102,24 +135,22 @@ impl ClientTracker { return Ok(()); } - if is_initialize && self.clients.contains_key(&client_id) { - self.close_client(&client_id).await?; + if is_initialize && self.clients.contains_key(&client_key) { + self.close_client(&client_key).await?; } - if let Some(connection_id) = self.clients.get_mut(&client_id).map(|client| { + if let Some(connection_id) = self.clients.get_mut(&client_key).map(|client| { client.last_activity_at = Instant::now(); if let Some(seq_id) = seq_id { client.last_inbound_seq_id = Some(seq_id); } client.connection_id }) { - self.transport_event_tx - .send(TransportEvent::IncomingMessage { - connection_id, - message, - }) - .await - .map_err(|_| Stopped)?; + self.send_transport_event(TransportEvent::IncomingMessage { + connection_id, + message, + }) + .await?; return Ok(()); } @@ -131,33 +162,35 @@ impl ClientTracker { let (writer_tx, writer_rx) = mpsc::channel::(CHANNEL_CAPACITY); let disconnect_token = self.shutdown_token.child_token(); - self.transport_event_tx - .send(TransportEvent::ConnectionOpened { - connection_id, - writer: writer_tx, - disconnect_sender: Some(disconnect_token.clone()), - }) - .await - .map_err(|_| Stopped)?; + self.send_transport_event(TransportEvent::ConnectionOpened { + connection_id, + writer: writer_tx, + disconnect_sender: Some(disconnect_token.clone()), + }) + .await?; let (status_tx, status_rx) = watch::channel(PongStatus::Active); self.join_set.spawn(Self::run_client_outbound( client_id.clone(), + stream_id.clone(), self.server_event_tx.clone(), writer_rx, status_rx, disconnect_token.clone(), )); self.clients.insert( - client_id, + client_key, ClientState { connection_id, disconnect_token, last_activity_at: Instant::now(), - last_inbound_seq_id: seq_id, + last_inbound_seq_id: if is_legacy_stream_id { None } else { seq_id }, status_tx, }, ); + if is_legacy_stream_id { + self.legacy_stream_ids.insert(client_id.clone(), stream_id); + } self.send_transport_event(TransportEvent::IncomingMessage { connection_id, message, @@ -166,7 +199,7 @@ impl ClientTracker { } ClientEvent::Ack => Ok(()), ClientEvent::Ping => { - if let Some(client) = self.clients.get_mut(&client_id) { + if let Some(client) = self.clients.get_mut(&client_key) { client.last_activity_at = Instant::now(); let _ = client.status_tx.send(PongStatus::Active); return Ok(()); @@ -179,23 +212,25 @@ impl ClientTracker { status: PongStatus::Unknown, }, client_id, + stream_id, write_complete_tx: None, }; let _ = server_event_tx.send(server_envelope).await; }); Ok(()) } - ClientEvent::ClientClosed => self.close_client(&client_id).await, + ClientEvent::ClientClosed => self.close_client(&client_key).await, } } async fn run_client_outbound( client_id: ClientId, + stream_id: StreamId, server_event_tx: mpsc::Sender, mut writer_rx: mpsc::Receiver, mut status_rx: watch::Receiver, disconnect_token: CancellationToken, - ) -> ClientId { + ) -> (ClientId, StreamId) { loop { let (event, write_complete_tx) = tokio::select! { _ = disconnect_token.cancelled() => { @@ -225,6 +260,7 @@ impl ClientTracker { send_result = server_event_tx.send(QueuedServerEnvelope { event, client_id: client_id.clone(), + stream_id: stream_id.clone(), write_complete_tx, }) => send_result, }; @@ -232,28 +268,40 @@ impl ClientTracker { break; } } - client_id + (client_id, stream_id) } - pub(crate) async fn close_expired_clients(&mut self) -> Result, Stopped> { + pub(crate) async fn close_expired_clients( + &mut self, + ) -> Result, Stopped> { let now = Instant::now(); - let expired_client_ids: Vec = self + let expired_client_ids: Vec<(ClientId, StreamId)> = self .clients .iter() - .filter_map(|(client_id, client)| { - (!remote_control_client_is_alive(client, now)).then_some(client_id.clone()) + .filter_map(|(client_key, client)| { + (!remote_control_client_is_alive(client, now)).then_some(client_key.clone()) }) .collect(); - for client_id in &expired_client_ids { - self.close_client(client_id).await?; + for client_key in &expired_client_ids { + self.close_client(client_key).await?; } Ok(expired_client_ids) } - pub(super) async fn close_client(&mut self, client_id: &ClientId) -> Result<(), Stopped> { - let Some(client) = self.clients.remove(client_id) else { + pub(super) async fn close_client( + &mut self, + client_key: &(ClientId, StreamId), + ) -> Result<(), Stopped> { + let Some(client) = self.clients.remove(client_key) else { return Ok(()); }; + if self + .legacy_stream_ids + .get(&client_key.0) + .is_some_and(|stream_id| stream_id == &client_key.1) + { + self.legacy_stream_ids.remove(&client_key.0); + } client.disconnect_token.cancel(); self.send_transport_event(TransportEvent::ConnectionClosed { connection_id: client.connection_id, @@ -296,6 +344,13 @@ mod tests { use tokio::time::timeout; fn initialize_envelope(client_id: &str) -> ClientEnvelope { + initialize_envelope_with_stream_id(client_id, None) + } + + fn initialize_envelope_with_stream_id( + client_id: &str, + stream_id: Option<&str>, + ) -> ClientEnvelope { ClientEnvelope { event: ClientEvent::ClientMessage { message: JSONRPCMessage::Request(JSONRPCRequest { @@ -311,6 +366,7 @@ mod tests { }), }, client_id: ClientId(client_id.to_string()), + stream_id: stream_id.map(|stream_id| StreamId(stream_id.to_string())), seq_id: Some(0), cursor: None, } @@ -358,7 +414,7 @@ mod tests { .await .expect("bookkeeping should process the closed task") .expect("closed task should return client id"); - assert_eq!(closed_client_id, ClientId("client-1".to_string())); + assert_eq!(closed_client_id.0, ClientId("client-1".to_string())); client_tracker .close_client(&closed_client_id) .await @@ -390,6 +446,7 @@ mod tests { status: PongStatus::Unknown, }, client_id: ClientId("queued-client".to_string()), + stream_id: StreamId("queued-stream".to_string()), write_complete_tx: None, }) .await @@ -431,4 +488,85 @@ mod tests { .await .expect("shutdown should not hang on blocked server forwarding"); } + + #[tokio::test] + async fn initialize_with_new_stream_id_opens_new_connection_for_same_client() { + let (server_event_tx, _server_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let mut client_tracker = + ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token); + + client_tracker + .handle_message(initialize_envelope_with_stream_id( + "client-1", + Some("stream-1"), + )) + .await + .expect("first initialize should open client"); + let first_connection_id = match transport_event_rx.recv().await.expect("open event") { + TransportEvent::ConnectionOpened { connection_id, .. } => connection_id, + other => panic!("expected connection opened, got {other:?}"), + }; + let _ = transport_event_rx.recv().await.expect("initialize event"); + + client_tracker + .handle_message(initialize_envelope_with_stream_id( + "client-1", + Some("stream-2"), + )) + .await + .expect("second initialize should open client"); + let second_connection_id = match transport_event_rx.recv().await.expect("open event") { + TransportEvent::ConnectionOpened { connection_id, .. } => connection_id, + other => panic!("expected connection opened, got {other:?}"), + }; + + assert_ne!(first_connection_id, second_connection_id); + } + + #[tokio::test] + async fn legacy_initialize_without_stream_id_resets_inbound_seq_id() { + let (server_event_tx, _server_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let mut client_tracker = + ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token); + + client_tracker + .handle_message(initialize_envelope("client-1")) + .await + .expect("initialize should open client"); + let connection_id = match transport_event_rx.recv().await.expect("open event") { + TransportEvent::ConnectionOpened { connection_id, .. } => connection_id, + other => panic!("expected connection opened, got {other:?}"), + }; + let _ = transport_event_rx.recv().await.expect("initialize event"); + + client_tracker + .handle_message(ClientEnvelope { + event: ClientEvent::ClientMessage { + message: JSONRPCMessage::Notification( + codex_app_server_protocol::JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }, + ), + }, + client_id: ClientId("client-1".to_string()), + stream_id: None, + seq_id: Some(0), + cursor: None, + }) + .await + .expect("legacy followup should be forwarded"); + + match transport_event_rx.recv().await.expect("followup event") { + TransportEvent::IncomingMessage { + connection_id: incoming_connection_id, + .. + } => assert_eq!(incoming_connection_id, connection_id), + other => panic!("expected incoming message, got {other:?}"), + } + } } diff --git a/codex-rs/app-server/src/transport/remote_control/enroll.rs b/codex-rs/app-server/src/transport/remote_control/enroll.rs index 23b190f96205..f74721beacab 100644 --- a/codex-rs/app-server/src/transport/remote_control/enroll.rs +++ b/codex-rs/app-server/src/transport/remote_control/enroll.rs @@ -20,6 +20,7 @@ pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id"; #[derive(Debug, Clone, PartialEq, Eq)] pub(super) struct RemoteControlEnrollment { pub(super) account_id: Option, + pub(super) environment_id: String, pub(super) server_id: String, pub(super) server_name: String, } @@ -47,11 +48,14 @@ pub(super) async fn load_persisted_remote_control_enrollment( } }; - enrollment.map(|(server_id, server_name)| RemoteControlEnrollment { - account_id: account_id.map(&str::to_string), - server_id, - server_name, - }) + enrollment.map( + |(server_id, environment_id, server_name)| RemoteControlEnrollment { + account_id: account_id.map(&str::to_string), + environment_id, + server_id, + server_name, + }, + ) } pub(super) async fn update_persisted_remote_control_enrollment( @@ -77,6 +81,7 @@ pub(super) async fn update_persisted_remote_control_enrollment( &remote_control_target.websocket_url, account_id, &enrollment.server_id, + &enrollment.environment_id, &enrollment.server_name, ) .await @@ -182,6 +187,7 @@ pub(super) async fn enroll_remote_control_server( Ok(RemoteControlEnrollment { account_id: account_id.map(&str::to_string), + environment_id: enrollment.environment_id, server_id: enrollment.server_id, server_name, }) @@ -221,11 +227,13 @@ mod tests { .expect("second target should parse"); let first_enrollment = RemoteControlEnrollment { account_id: Some("account-a".to_string()), + environment_id: "env_first".to_string(), server_id: "srv_e_first".to_string(), server_name: "first-server".to_string(), }; let second_enrollment = RemoteControlEnrollment { account_id: Some("account-a".to_string()), + environment_id: "env_second".to_string(), server_id: "srv_e_second".to_string(), server_name: "second-server".to_string(), }; @@ -287,11 +295,13 @@ mod tests { .expect("second target should parse"); let first_enrollment = RemoteControlEnrollment { account_id: Some("account-a".to_string()), + environment_id: "env_first".to_string(), server_id: "srv_e_first".to_string(), server_name: "first-server".to_string(), }; let second_enrollment = RemoteControlEnrollment { account_id: Some("account-a".to_string()), + environment_id: "env_second".to_string(), server_id: "srv_e_second".to_string(), server_name: "second-server".to_string(), }; diff --git a/codex-rs/app-server/src/transport/remote_control/mod.rs b/codex-rs/app-server/src/transport/remote_control/mod.rs index e9d91b17e1fa..4dd5a68769eb 100644 --- a/codex-rs/app-server/src/transport/remote_control/mod.rs +++ b/codex-rs/app-server/src/transport/remote_control/mod.rs @@ -8,6 +8,7 @@ use crate::transport::remote_control::websocket::load_remote_control_auth; pub use self::protocol::ClientId; use self::protocol::ServerEvent; +use self::protocol::StreamId; use self::protocol::normalize_remote_control_url; use super::CHANNEL_CAPACITY; use super::TransportEvent; @@ -24,6 +25,7 @@ use tokio_util::sync::CancellationToken; pub(super) struct QueuedServerEnvelope { pub(super) event: ServerEvent, pub(super) client_id: ClientId, + pub(super) stream_id: StreamId, pub(super) write_complete_tx: Option>, } diff --git a/codex-rs/app-server/src/transport/remote_control/protocol.rs b/codex-rs/app-server/src/transport/remote_control/protocol.rs index 4e057e565d75..0bedf32dc5c9 100644 --- a/codex-rs/app-server/src/transport/remote_control/protocol.rs +++ b/codex-rs/app-server/src/transport/remote_control/protocol.rs @@ -24,12 +24,23 @@ pub(super) struct EnrollRemoteServerRequest { #[derive(Debug, Deserialize)] pub(super) struct EnrollRemoteServerResponse { pub(super) server_id: String, + pub(super) environment_id: String, } #[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] #[serde(transparent)] pub struct ClientId(pub String); +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(transparent)] +pub struct StreamId(pub String); + +impl StreamId { + pub fn new_random() -> Self { + Self(uuid::Uuid::now_v7().to_string()) + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ClientEvent { @@ -44,13 +55,11 @@ pub enum ClientEvent { pub(crate) struct ClientEnvelope { #[serde(flatten)] pub(crate) event: ClientEvent, - #[serde(rename = "client_id", alias = "clientId")] + #[serde(rename = "client_id")] pub(crate) client_id: ClientId, - #[serde( - rename = "seq_id", - alias = "seqId", - skip_serializing_if = "Option::is_none" - )] + #[serde(rename = "stream_id", skip_serializing_if = "Option::is_none")] + pub(crate) stream_id: Option, + #[serde(rename = "seq_id", skip_serializing_if = "Option::is_none")] pub(crate) seq_id: Option, #[serde(skip_serializing_if = "Option::is_none")] pub(crate) cursor: Option, @@ -81,9 +90,11 @@ pub enum ServerEvent { pub(crate) struct ServerEnvelope { #[serde(flatten)] pub(crate) event: ServerEvent, - #[serde(rename = "client_id", alias = "clientId")] + #[serde(rename = "client_id")] pub(crate) client_id: ClientId, - #[serde(rename = "seq_id", alias = "seqId")] + #[serde(rename = "stream_id")] + pub(crate) stream_id: StreamId, + #[serde(rename = "seq_id")] pub(crate) seq_id: u64, } diff --git a/codex-rs/app-server/src/transport/remote_control/tests.rs b/codex-rs/app-server/src/transport/remote_control/tests.rs index e0556ebfbeee..888fb4732030 100644 --- a/codex-rs/app-server/src/transport/remote_control/tests.rs +++ b/codex-rs/app-server/src/transport/remote_control/tests.rs @@ -95,7 +95,11 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages() enroll_request.request_line, "POST /backend-api/wham/remote/control/server/enroll HTTP/1.1" ); - respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await; + respond_with_json( + enroll_request.stream, + json!({ "server_id": "srv_e_test", "environment_id": "env_test" }), + ) + .await; let mut websocket = accept_remote_control_connection(&listener).await; let client_id = ClientId("client-1".to_string()); @@ -104,6 +108,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages() ClientEnvelope { event: ClientEvent::Ping, client_id: client_id.clone(), + stream_id: None, seq_id: None, cursor: None, }, @@ -131,6 +136,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages() ), }, client_id: client_id.clone(), + stream_id: None, seq_id: Some(0), cursor: None, }, @@ -161,6 +167,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages() message: initialize_message.clone(), }, client_id: client_id.clone(), + stream_id: None, seq_id: Some(1), cursor: None, }, @@ -207,6 +214,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages() message: followup_message.clone(), }, client_id: client_id.clone(), + stream_id: None, seq_id: Some(2), cursor: None, }, @@ -232,6 +240,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages() ClientEnvelope { event: ClientEvent::Ping, client_id: client_id.clone(), + stream_id: None, seq_id: None, cursor: None, }, @@ -281,6 +290,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages() ClientEnvelope { event: ClientEvent::ClientClosed, client_id: client_id.clone(), + stream_id: None, seq_id: None, cursor: None, }, @@ -304,6 +314,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages() ClientEnvelope { event: ClientEvent::Ping, client_id, + stream_id: None, seq_id: None, cursor: None, }, @@ -348,7 +359,11 @@ async fn remote_control_transport_reconnects_after_disconnect() { enroll_request.request_line, "POST /backend-api/wham/remote/control/server/enroll HTTP/1.1" ); - respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await; + respond_with_json( + enroll_request.stream, + json!({ "server_id": "srv_e_test", "environment_id": "env_test" }), + ) + .await; let mut first_websocket = accept_remote_control_connection(&listener).await; first_websocket .close(None) @@ -374,6 +389,7 @@ async fn remote_control_transport_reconnects_after_disconnect() { }), }, client_id: ClientId("client-2".to_string()), + stream_id: None, seq_id: Some(0), cursor: None, }, @@ -414,7 +430,11 @@ async fn remote_control_transport_clears_outgoing_buffer_when_client_closes() { .expect("remote control should start"); let enroll_request = accept_http_request(&listener).await; - respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await; + respond_with_json( + enroll_request.stream, + json!({ "server_id": "srv_e_test", "environment_id": "env_test" }), + ) + .await; let mut first_websocket = accept_remote_control_connection(&listener).await; let client_id = ClientId("client-1".to_string()); @@ -436,6 +456,7 @@ async fn remote_control_transport_clears_outgoing_buffer_when_client_closes() { message: initialize_message, }, client_id: client_id.clone(), + stream_id: None, seq_id: Some(0), cursor: None, }, @@ -493,6 +514,7 @@ async fn remote_control_transport_clears_outgoing_buffer_when_client_closes() { ClientEnvelope { event: ClientEvent::ClientClosed, client_id: client_id.clone(), + stream_id: None, seq_id: None, cursor: None, }, @@ -519,6 +541,7 @@ async fn remote_control_transport_clears_outgoing_buffer_when_client_closes() { ClientEnvelope { event: ClientEvent::Ping, client_id, + stream_id: None, seq_id: None, cursor: None, }, @@ -582,7 +605,11 @@ async fn remote_control_http_mode_enrolls_before_connecting() { "app_server_version": env!("CARGO_PKG_VERSION"), }) ); - respond_with_json(enroll_request.stream, json!({ "server_id": "srv_e_test" })).await; + respond_with_json( + enroll_request.stream, + json!({ "server_id": "srv_e_test", "environment_id": "env_test" }), + ) + .await; let (handshake_request, mut websocket) = accept_remote_control_backend_connection(&listener).await; @@ -634,6 +661,7 @@ async fn remote_control_http_mode_enrolls_before_connecting() { message: initialize_message.clone(), }, client_id: backend_client_id.clone(), + stream_id: None, seq_id: Some(0), cursor: None, }, @@ -742,6 +770,7 @@ async fn remote_control_http_mode_reuses_persisted_enrollment_before_reenrolling normalize_remote_control_url(&remote_control_url).expect("target should parse"); let persisted_enrollment = RemoteControlEnrollment { account_id: Some("account_id".to_string()), + environment_id: "env_persisted".to_string(), server_id: "srv_e_persisted".to_string(), server_name: "persisted-server".to_string(), }; @@ -803,11 +832,13 @@ async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404() let expected_server_name = gethostname().to_string_lossy().trim().to_string(); let stale_enrollment = RemoteControlEnrollment { account_id: Some("account_id".to_string()), + environment_id: "env_stale".to_string(), server_id: "srv_e_stale".to_string(), server_name: "stale-server".to_string(), }; let refreshed_enrollment = RemoteControlEnrollment { account_id: Some("account_id".to_string()), + environment_id: "env_refreshed".to_string(), server_id: "srv_e_refreshed".to_string(), server_name: expected_server_name, }; @@ -851,7 +882,10 @@ async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404() ); respond_with_json( enroll_request.stream, - json!({ "server_id": refreshed_enrollment.server_id }), + json!({ + "server_id": refreshed_enrollment.server_id, + "environment_id": refreshed_enrollment.environment_id, + }), ) .await; @@ -1048,8 +1082,15 @@ async fn read_server_event(websocket: &mut WebSocketStream) -> serde_ .expect("websocket frame should be readable"); match frame { tungstenite::Message::Text(text) => { - return serde_json::from_str(text.as_ref()) - .expect("server event should deserialize"); + let mut event: serde_json::Value = + serde_json::from_str(text.as_ref()).expect("server event should deserialize"); + if let Some(stream_id) = event + .as_object_mut() + .and_then(|event| event.remove("stream_id")) + { + assert!(stream_id.is_string(), "stream_id should be a string"); + } + return event; } tungstenite::Message::Ping(payload) => { websocket diff --git a/codex-rs/app-server/src/transport/remote_control/websocket.rs b/codex-rs/app-server/src/transport/remote_control/websocket.rs index 74416642af0e..52cc9c14debc 100644 --- a/codex-rs/app-server/src/transport/remote_control/websocket.rs +++ b/codex-rs/app-server/src/transport/remote_control/websocket.rs @@ -14,6 +14,7 @@ use super::protocol::ClientEvent; use super::protocol::ClientId; use super::protocol::RemoteControlTarget; use super::protocol::ServerEnvelope; +use super::protocol::StreamId; use axum::http::HeaderValue; use base64::Engine; use codex_core::AuthManager; @@ -50,7 +51,7 @@ pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id"; const REMOTE_CONTROL_SUBSCRIBE_CURSOR_HEADER: &str = "x-codex-subscribe-cursor"; struct BoundedOutboundBuffer { - buffer_by_client: HashMap>, + buffer_by_client: HashMap<(ClientId, StreamId), BTreeMap>, used_tx: watch::Sender, } @@ -66,20 +67,29 @@ impl BoundedOutboundBuffer { fn insert(&mut self, server_envelope: &ServerEnvelope) { self.buffer_by_client - .entry(server_envelope.client_id.clone()) + .entry(( + server_envelope.client_id.clone(), + server_envelope.stream_id.clone(), + )) .or_default() .insert(server_envelope.seq_id, server_envelope.clone()); self.used_tx.send_modify(|used| *used += 1); } - fn remove(&mut self, client_id: &ClientId) { - if let Some(buffer) = self.buffer_by_client.remove(client_id) { + fn remove(&mut self, client_id: &ClientId, stream_id: &StreamId) { + if let Some(buffer) = self + .buffer_by_client + .remove(&(client_id.clone(), stream_id.clone())) + { self.used_tx.send_modify(|used| *used -= buffer.len()); } } - fn ack(&mut self, client_id: &ClientId, acked_seq_id: u64) { - let Some(buffer) = self.buffer_by_client.get_mut(client_id) else { + fn ack(&mut self, client_id: &ClientId, stream_id: &StreamId, acked_seq_id: u64) { + let Some(buffer) = self + .buffer_by_client + .get_mut(&(client_id.clone(), stream_id.clone())) + else { return; }; while let Some(seq_id) = buffer.first_key_value().map(|(seq_id, _)| seq_id) @@ -89,7 +99,8 @@ impl BoundedOutboundBuffer { self.used_tx.send_modify(|used| *used -= 1); } if buffer.is_empty() { - self.buffer_by_client.remove(client_id); + self.buffer_by_client + .remove(&(client_id.clone(), stream_id.clone())); } } @@ -320,6 +331,7 @@ impl RemoteControlWebsocket { event: queued_server_envelope.event, client_id: queued_server_envelope.client_id, seq_id, + stream_id: queued_server_envelope.stream_id, }; state.outbound_buffer.insert(&server_envelope); @@ -391,14 +403,14 @@ impl RemoteControlWebsocket { continue; } _ = idle_sweep_interval.tick() => { - let expired_client_ids = match client_tracker.close_expired_clients().await { - Ok(expired_client_ids) => expired_client_ids, + let expired_client_keys = match client_tracker.close_expired_clients().await { + Ok(expired_client_keys) => expired_client_keys, Err(_) => return Ok(()), }; - if !expired_client_ids.is_empty() { + if !expired_client_keys.is_empty() { let mut state = state.lock().await; - for client_id in expired_client_ids { - state.outbound_buffer.remove(&client_id); + for (client_id, stream_id) in expired_client_keys { + state.outbound_buffer.remove(&client_id, &stream_id); } } continue; @@ -441,22 +453,29 @@ impl RemoteControlWebsocket { } }; + let resolved_stream_id = client_envelope + .stream_id + .clone() + .or_else(|| client_tracker.legacy_stream_id(&client_envelope.client_id)); let mut state = state.lock().await; if let Some(cursor) = client_envelope.cursor.as_deref() { state.subscribe_cursor = Some(cursor.to_string()); } if let ClientEvent::Ack = &client_envelope.event && let Some(acked_seq_id) = client_envelope.seq_id + && let Some(stream_id) = resolved_stream_id.as_ref() { state .outbound_buffer - .ack(&client_envelope.client_id, acked_seq_id); - } - if matches!(&client_envelope.event, ClientEvent::ClientClosed) - || remote_control_message_starts_connection(&client_envelope.event) - { - state.outbound_buffer.remove(&client_envelope.client_id); + .ack(&client_envelope.client_id, stream_id, acked_seq_id); } + if (matches!(&client_envelope.event, ClientEvent::ClientClosed) + || remote_control_message_starts_connection(&client_envelope.event)) + && let Some(stream_id) = resolved_stream_id.as_ref() { + state + .outbound_buffer + .remove(&client_envelope.client_id, stream_id); + } drop(state); if client_tracker @@ -834,6 +853,7 @@ mod tests { let mut auth_recovery = auth_manager.unauthorized_recovery(); let mut enrollment = Some(RemoteControlEnrollment { account_id: Some("account_id".to_string()), + environment_id: "env_test".to_string(), server_id: "srv_e_test".to_string(), server_name: "test-server".to_string(), }); @@ -888,6 +908,7 @@ mod tests { let mut auth_recovery = auth_manager.unauthorized_recovery(); let mut enrollment = Some(RemoteControlEnrollment { account_id: Some("account_id".to_string()), + environment_id: "env_test".to_string(), server_id: "srv_e_test".to_string(), server_name: "test-server".to_string(), }); diff --git a/codex-rs/state/migrations/0023_remote_control_enrollments.sql b/codex-rs/state/migrations/0023_remote_control_enrollments.sql index 9a2081dd8f38..247b8d419253 100644 --- a/codex-rs/state/migrations/0023_remote_control_enrollments.sql +++ b/codex-rs/state/migrations/0023_remote_control_enrollments.sql @@ -2,6 +2,7 @@ CREATE TABLE remote_control_enrollments ( websocket_url TEXT NOT NULL, account_id TEXT NOT NULL, server_id TEXT NOT NULL, + environment_id TEXT NOT NULL, server_name TEXT NOT NULL, updated_at INTEGER NOT NULL, PRIMARY KEY (websocket_url, account_id) diff --git a/codex-rs/state/src/runtime/remote_control.rs b/codex-rs/state/src/runtime/remote_control.rs index 307dac8184ba..4ac1f81872cd 100644 --- a/codex-rs/state/src/runtime/remote_control.rs +++ b/codex-rs/state/src/runtime/remote_control.rs @@ -11,10 +11,10 @@ impl StateRuntime { &self, websocket_url: &str, account_id: Option<&str>, - ) -> anyhow::Result> { + ) -> anyhow::Result> { let row = sqlx::query( r#" -SELECT server_id, server_name +SELECT server_id, environment_id, server_name FROM remote_control_enrollments WHERE websocket_url = ? AND account_id = ? "#, @@ -24,8 +24,14 @@ WHERE websocket_url = ? AND account_id = ? .fetch_optional(self.pool.as_ref()) .await?; - row.map(|row| Ok((row.try_get("server_id")?, row.try_get("server_name")?))) - .transpose() + row.map(|row| { + Ok(( + row.try_get("server_id")?, + row.try_get("environment_id")?, + row.try_get("server_name")?, + )) + }) + .transpose() } pub async fn upsert_remote_control_enrollment( @@ -33,6 +39,7 @@ WHERE websocket_url = ? AND account_id = ? websocket_url: &str, account_id: Option<&str>, server_id: &str, + environment_id: &str, server_name: &str, ) -> anyhow::Result<()> { sqlx::query( @@ -41,11 +48,13 @@ INSERT INTO remote_control_enrollments ( websocket_url, account_id, server_id, + environment_id, server_name, updated_at -) VALUES (?, ?, ?, ?, ?) +) VALUES (?, ?, ?, ?, ?, ?) ON CONFLICT(websocket_url, account_id) DO UPDATE SET server_id = excluded.server_id, + environment_id = excluded.environment_id, server_name = excluded.server_name, updated_at = excluded.updated_at "#, @@ -53,6 +62,7 @@ ON CONFLICT(websocket_url, account_id) DO UPDATE SET .bind(websocket_url) .bind(remote_control_account_id_key(account_id)) .bind(server_id) + .bind(environment_id) .bind(server_name) .bind(Utc::now().timestamp()) .execute(self.pool.as_ref()) @@ -97,6 +107,7 @@ mod tests { "wss://example.com/backend-api/wham/remote/control/server", Some("account-a"), "srv_e_first", + "env_first", "first-server", ) .await @@ -106,6 +117,7 @@ mod tests { "wss://example.com/backend-api/wham/remote/control/server", Some("account-b"), "srv_e_second", + "env_second", "second-server", ) .await @@ -119,7 +131,11 @@ mod tests { ) .await .expect("load first enrollment"), - Some(("srv_e_first".to_string(), "first-server".to_string())) + Some(( + "srv_e_first".to_string(), + "env_first".to_string(), + "first-server".to_string() + )) ); assert_eq!( runtime @@ -147,6 +163,7 @@ mod tests { "wss://example.com/backend-api/wham/remote/control/server", None, "srv_e_first", + "env_first", "first-server", ) .await @@ -156,6 +173,7 @@ mod tests { "wss://example.com/backend-api/wham/remote/control/server", Some("account-a"), "srv_e_second", + "env_second", "second-server", ) .await @@ -189,7 +207,11 @@ mod tests { ) .await .expect("load retained enrollment"), - Some(("srv_e_second".to_string(), "second-server".to_string())) + Some(( + "srv_e_second".to_string(), + "env_second".to_string(), + "second-server".to_string() + )) ); let _ = tokio::fs::remove_dir_all(codex_home).await; From 93672f05dfbb0402a2129abed253eecd133a91bc Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Tue, 31 Mar 2026 21:48:43 -0700 Subject: [PATCH 11/17] share auth manager init code --- codex-rs/app-server/src/auth_manager.rs | 17 +++++++++++++++++ codex-rs/app-server/src/in_process.rs | 10 ++-------- codex-rs/app-server/src/lib.rs | 17 +++++------------ codex-rs/app-server/src/message_processor.rs | 1 - .../src/message_processor/tracing_tests.rs | 9 ++------- 5 files changed, 26 insertions(+), 28 deletions(-) create mode 100644 codex-rs/app-server/src/auth_manager.rs diff --git a/codex-rs/app-server/src/auth_manager.rs b/codex-rs/app-server/src/auth_manager.rs new file mode 100644 index 000000000000..321c1deed5af --- /dev/null +++ b/codex-rs/app-server/src/auth_manager.rs @@ -0,0 +1,17 @@ +use std::sync::Arc; + +use codex_core::AuthManager; +use codex_core::config::Config; + +pub(crate) fn auth_manager_from_config( + config: &Config, + enable_codex_api_key_env: bool, +) -> Arc { + let auth_manager = AuthManager::shared( + config.codex_home.clone(), + enable_codex_api_key_env, + config.cli_auth_credentials_store_mode, + ); + auth_manager.set_forced_chatgpt_workspace_id(config.forced_chatgpt_workspace_id.clone()); + auth_manager +} diff --git a/codex-rs/app-server/src/in_process.rs b/codex-rs/app-server/src/in_process.rs index 2aef7897e1fb..849984060e58 100644 --- a/codex-rs/app-server/src/in_process.rs +++ b/codex-rs/app-server/src/in_process.rs @@ -50,6 +50,7 @@ use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; use std::time::Duration; +use crate::auth_manager::auth_manager_from_config; use crate::error_code::INTERNAL_ERROR_CODE; use crate::error_code::INVALID_REQUEST_ERROR_CODE; use crate::error_code::OVERLOADED_ERROR_CODE; @@ -75,7 +76,6 @@ use codex_app_server_protocol::ServerNotification; use codex_app_server_protocol::ServerRequest; use codex_arg0::Arg0DispatchPaths; use codex_core::AppServerRpcTransport; -use codex_core::AuthManager; use codex_core::config::Config; use codex_core::config_loader::CloudRequirementsLoader; use codex_core::config_loader::LoaderOverrides; @@ -379,13 +379,7 @@ fn start_uninitialized(args: InProcessStartArgs) -> InProcessClientHandle { } }); - let auth_manager = AuthManager::shared( - args.config.codex_home.clone(), - args.enable_codex_api_key_env, - args.config.cli_auth_credentials_store_mode, - ); - auth_manager - .set_forced_chatgpt_workspace_id(args.config.forced_chatgpt_workspace_id.clone()); + let auth_manager = auth_manager_from_config(&args.config, args.enable_codex_api_key_env); let processor_outgoing = Arc::clone(&outgoing_message_sender); let (processor_tx, mut processor_rx) = mpsc::channel::(channel_capacity); diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 406da2e4bfc4..ece52ac13717 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -2,7 +2,6 @@ use codex_arg0::Arg0DispatchPaths; use codex_cloud_requirements::cloud_requirements_loader; -use codex_core::AuthManager; use codex_core::config::Config; use codex_core::config::ConfigBuilder; use codex_core::config_loader::CloudRequirementsLoader; @@ -18,6 +17,7 @@ use std::sync::Arc; use std::sync::RwLock; use std::sync::atomic::AtomicBool; +use crate::auth_manager::auth_manager_from_config; use crate::message_processor::MessageProcessor; use crate::message_processor::MessageProcessorArgs; use crate::outgoing_message::ConnectionId; @@ -63,6 +63,7 @@ use tracing_subscriber::registry::Registry; use tracing_subscriber::util::SubscriberInitExt; mod app_server_tracing; +mod auth_manager; mod bespoke_event_handling; mod codex_message_processor; mod command_exec; @@ -398,11 +399,8 @@ pub async fn run_main_with_transport( } } - let auth_manager = AuthManager::shared( - config.codex_home.clone(), - /*enable_codex_api_key_env*/ false, - config.cli_auth_credentials_store_mode, - ); + let auth_manager = + auth_manager_from_config(&config, /*enable_codex_api_key_env*/ false); cloud_requirements_loader( auth_manager, config.chatgpt_base_url, @@ -554,12 +552,7 @@ pub async fn run_main_with_transport( AppServerTransport::Off => {} } - let auth_manager = AuthManager::shared( - config.codex_home.clone(), - /*enable_codex_api_key_env*/ false, - config.cli_auth_credentials_store_mode, - ); - auth_manager.set_forced_chatgpt_workspace_id(config.forced_chatgpt_workspace_id.clone()); + let auth_manager = auth_manager_from_config(&config, /*enable_codex_api_key_env*/ false); if config.features.enabled(Feature::RemoteControl) { let accept_handle = start_remote_control( diff --git a/codex-rs/app-server/src/message_processor.rs b/codex-rs/app-server/src/message_processor.rs index 6f966e737bee..03c8ad4d9f1a 100644 --- a/codex-rs/app-server/src/message_processor.rs +++ b/codex-rs/app-server/src/message_processor.rs @@ -214,7 +214,6 @@ impl MessageProcessor { auth_manager.set_external_chatgpt_auth_refresher(Arc::new(ExternalAuthRefreshBridge { outgoing: outgoing.clone(), })); - auth_manager.set_forced_chatgpt_workspace_id(config.forced_chatgpt_workspace_id.clone()); let thread_manager = Arc::new(ThreadManager::new( config.as_ref(), auth_manager.clone(), diff --git a/codex-rs/app-server/src/message_processor/tracing_tests.rs b/codex-rs/app-server/src/message_processor/tracing_tests.rs index 899df49463d2..42173fc54223 100644 --- a/codex-rs/app-server/src/message_processor/tracing_tests.rs +++ b/codex-rs/app-server/src/message_processor/tracing_tests.rs @@ -1,6 +1,7 @@ use super::ConnectionSessionState; use super::MessageProcessor; use super::MessageProcessorArgs; +use crate::auth_manager::auth_manager_from_config; use crate::outgoing_message::ConnectionId; use crate::outgoing_message::OutgoingMessageSender; use crate::transport::AppServerTransport; @@ -21,7 +22,6 @@ use codex_app_server_protocol::TurnStartResponse; use codex_app_server_protocol::UserInput; use codex_arg0::Arg0DispatchPaths; use codex_core::AppServerRpcTransport; -use codex_core::AuthManager; use codex_core::config::Config; use codex_core::config::ConfigBuilder; use codex_core::config_loader::CloudRequirementsLoader; @@ -233,12 +233,7 @@ fn build_test_processor( MessageProcessor, mpsc::Receiver, ) { - let auth_manager = AuthManager::shared( - config.codex_home.clone(), - /*enable_codex_api_key_env*/ false, - config.cli_auth_credentials_store_mode, - ); - auth_manager.set_forced_chatgpt_workspace_id(config.forced_chatgpt_workspace_id.clone()); + let auth_manager = auth_manager_from_config(&config, /*enable_codex_api_key_env*/ false); let (outgoing_tx, outgoing_rx) = mpsc::channel(16); let outgoing = Arc::new(OutgoingMessageSender::new(outgoing_tx)); From cb3931b0addc3f474afef99edc7af8f72167158d Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Fri, 3 Apr 2026 22:54:06 -0700 Subject: [PATCH 12/17] add logging & add regular pings --- .../src/transport/remote_control/enroll.rs | 69 ++++-- .../src/transport/remote_control/websocket.rs | 199 ++++++++++++++++-- 2 files changed, 241 insertions(+), 27 deletions(-) diff --git a/codex-rs/app-server/src/transport/remote_control/enroll.rs b/codex-rs/app-server/src/transport/remote_control/enroll.rs index b3648e9bc51a..787ece764686 100644 --- a/codex-rs/app-server/src/transport/remote_control/enroll.rs +++ b/codex-rs/app-server/src/transport/remote_control/enroll.rs @@ -7,6 +7,7 @@ use codex_state::StateRuntime; use gethostname::gethostname; use std::io; use std::io::ErrorKind; +use tracing::info; use tracing::warn; const REMOTE_CONTROL_ENROLL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30); @@ -36,26 +37,48 @@ pub(super) async fn load_persisted_remote_control_enrollment( remote_control_target: &RemoteControlTarget, account_id: Option<&str>, ) -> Option { - let state_db = state_db?; + let Some(state_db) = state_db else { + info!( + "remote control enrollment cache unavailable because sqlite state db is disabled: websocket_url={}, account_id={:?}", + remote_control_target.websocket_url, account_id + ); + return None; + }; let enrollment = match state_db .get_remote_control_enrollment(&remote_control_target.websocket_url, account_id) .await { Ok(enrollment) => enrollment, Err(err) => { - warn!("{err}"); + warn!( + "failed to load persisted remote control enrollment: websocket_url={}, account_id={:?}, err={err}", + remote_control_target.websocket_url, account_id + ); return None; } }; - enrollment.map( - |(server_id, environment_id, server_name)| RemoteControlEnrollment { - account_id: account_id.map(&str::to_string), - environment_id, - server_id, - server_name, - }, - ) + match enrollment { + Some((server_id, environment_id, server_name)) => { + info!( + "reusing persisted remote control enrollment: websocket_url={}, account_id={:?}, server_id={}, environment_id={}", + remote_control_target.websocket_url, account_id, server_id, environment_id + ); + Some(RemoteControlEnrollment { + account_id: account_id.map(&str::to_string), + environment_id, + server_id, + server_name, + }) + } + None => { + info!( + "no persisted remote control enrollment found: websocket_url={}, account_id={:?}", + remote_control_target.websocket_url, account_id + ); + None + } + } } pub(super) async fn update_persisted_remote_control_enrollment( @@ -65,6 +88,12 @@ pub(super) async fn update_persisted_remote_control_enrollment( enrollment: Option<&RemoteControlEnrollment>, ) -> io::Result<()> { let Some(state_db) = state_db else { + info!( + "skipping remote control enrollment persistence because sqlite state db is disabled: websocket_url={}, account_id={:?}, has_enrollment={}", + remote_control_target.websocket_url, + account_id, + enrollment.is_some() + ); return Ok(()); }; if let &Some(enrollment) = &enrollment @@ -85,13 +114,25 @@ pub(super) async fn update_persisted_remote_control_enrollment( &enrollment.server_name, ) .await - .map_err(io::Error::other) + .map_err(io::Error::other)?; + info!( + "persisted remote control enrollment: websocket_url={}, account_id={:?}, server_id={}, environment_id={}", + remote_control_target.websocket_url, + account_id, + enrollment.server_id, + enrollment.environment_id + ); + Ok(()) } else { - state_db + let rows_affected = state_db .delete_remote_control_enrollment(&remote_control_target.websocket_url, account_id) .await - .map(|_| ()) - .map_err(io::Error::other) + .map_err(io::Error::other)?; + info!( + "cleared persisted remote control enrollment: websocket_url={}, account_id={:?}, rows_affected={rows_affected}", + remote_control_target.websocket_url, account_id + ); + Ok(()) } } diff --git a/codex-rs/app-server/src/transport/remote_control/websocket.rs b/codex-rs/app-server/src/transport/remote_control/websocket.rs index d30fbbb51394..42004799cc47 100644 --- a/codex-rs/app-server/src/transport/remote_control/websocket.rs +++ b/codex-rs/app-server/src/transport/remote_control/websocket.rs @@ -49,6 +49,10 @@ use tracing::warn; pub(super) const REMOTE_CONTROL_PROTOCOL_VERSION: &str = "2"; pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id"; const REMOTE_CONTROL_SUBSCRIBE_CURSOR_HEADER: &str = "x-codex-subscribe-cursor"; +const REMOTE_CONTROL_WEBSOCKET_PING_INTERVAL: std::time::Duration = + std::time::Duration::from_secs(10); +const REMOTE_CONTROL_WEBSOCKET_PONG_TIMEOUT: std::time::Duration = + std::time::Duration::from_secs(60); struct BoundedOutboundBuffer { buffer_by_client: HashMap<(ClientId, StreamId), BTreeMap>, @@ -235,12 +239,14 @@ impl RemoteControlWebsocket { self.server_event_rx.clone(), self.used_rx.clone(), websocket_writer, + REMOTE_CONTROL_WEBSOCKET_PING_INTERVAL, shutdown_token.clone(), )); join_set.spawn(Self::run_websocket_reader( self.client_tracker.clone(), self.state.clone(), websocket_reader, + REMOTE_CONTROL_WEBSOCKET_PONG_TIMEOUT, shutdown_token.clone(), )); @@ -260,6 +266,7 @@ impl RemoteControlWebsocket { WebSocketStream>, tungstenite::Message, >, + ping_interval: std::time::Duration, shutdown_token: CancellationToken, ) { let result = Self::run_server_writer_inner( @@ -267,6 +274,7 @@ impl RemoteControlWebsocket { server_event_rx, used_rx, websocket_writer, + ping_interval, shutdown_token, ) .await; @@ -285,6 +293,7 @@ impl RemoteControlWebsocket { WebSocketStream>, tungstenite::Message, >, + ping_interval: std::time::Duration, shutdown_token: CancellationToken, ) -> io::Result<()> { for server_envelope in state.lock().await.outbound_buffer.server_envelopes() { @@ -305,15 +314,35 @@ impl RemoteControlWebsocket { }; } + let mut ping_interval = + tokio::time::interval_at(tokio::time::Instant::now() + ping_interval, ping_interval); + ping_interval.set_missed_tick_behavior(MissedTickBehavior::Skip); + let mut server_event_rx = server_event_rx.lock().await; loop { - tokio::select! { - _ = shutdown_token.cancelled() => return Ok(()), - _ = used_rx.wait_for(|used| *used < super::CHANNEL_CAPACITY) => {} - }; + let outbound_has_capacity = *used_rx.borrow() < super::CHANNEL_CAPACITY; let queued_server_envelope = tokio::select! { _ = shutdown_token.cancelled() => return Ok(()), - recv_result = server_event_rx.recv() => { + _ = ping_interval.tick() => { + if let Err(err) = websocket_writer + .send(tungstenite::Message::Ping(Vec::new().into())) + .await + { + return Err(io::Error::other(err)); + } + continue; + } + wait_result = used_rx.changed(), if !outbound_has_capacity => + { + if wait_result.is_err() { + return Err(io::Error::new( + ErrorKind::UnexpectedEof, + "outbound buffer usage channel closed", + )); + } + continue; + } + recv_result = server_event_rx.recv(), if outbound_has_capacity => { match recv_result { Some(queued_server_envelope) => queued_server_envelope, None => { @@ -364,12 +393,14 @@ impl RemoteControlWebsocket { client_tracker: Arc>, state: Arc>, websocket_reader: SplitStream>>, + pong_timeout: std::time::Duration, shutdown_token: CancellationToken, ) { let result = Self::run_websocket_reader_inner( client_tracker, state, websocket_reader, + pong_timeout, shutdown_token, ) .await; @@ -384,15 +415,24 @@ impl RemoteControlWebsocket { client_tracker: Arc>, state: Arc>, mut websocket_reader: SplitStream>>, + pong_timeout: std::time::Duration, shutdown_token: CancellationToken, ) -> io::Result<()> { let mut client_tracker = client_tracker.lock().await; let mut idle_sweep_interval = tokio::time::interval(REMOTE_CONTROL_IDLE_SWEEP_INTERVAL); idle_sweep_interval.set_missed_tick_behavior(MissedTickBehavior::Skip); + let pong_deadline = tokio::time::sleep(pong_timeout); + tokio::pin!(pong_deadline); loop { let incoming_message = tokio::select! { _ = shutdown_token.cancelled() => return Ok(()), + _ = &mut pong_deadline => { + return Err(io::Error::new( + ErrorKind::TimedOut, + "remote control websocket pong timeout", + )); + } client_id = client_tracker.bookkeep_join_set() => { let Some(client_id) = client_id else { continue; @@ -432,9 +472,13 @@ impl RemoteControlWebsocket { } } } - Ok(tungstenite::Message::Ping(_)) - | Ok(tungstenite::Message::Pong(_)) - | Ok(tungstenite::Message::Frame(_)) => continue, + Ok(tungstenite::Message::Pong(_)) => { + pong_deadline + .as_mut() + .reset(tokio::time::Instant::now() + pong_timeout); + continue; + } + Ok(tungstenite::Message::Ping(_)) | Ok(tungstenite::Message::Frame(_)) => continue, Ok(tungstenite::Message::Binary(_)) => { warn!("dropping unsupported binary remote-control websocket message"); continue; @@ -601,11 +645,16 @@ pub(super) async fn connect_remote_control_websocket( ensure_rustls_crypto_provider(); let auth = load_remote_control_auth(auth_manager).await?; - if auth.account_id.as_ref() - != enrollment - .as_ref() - .and_then(|enrollment| enrollment.account_id.as_ref()) - { + let enrollment_account_id = enrollment + .as_ref() + .and_then(|enrollment| enrollment.account_id.clone()); + if auth.account_id.as_deref() != enrollment_account_id.as_deref() { + info!( + "clearing in-memory remote control enrollment because account id changed: websocket_url={}, previous_account_id={:?}, current_account_id={:?}", + remote_control_target.websocket_url, + enrollment_account_id.as_deref(), + auth.account_id.as_deref() + ); *enrollment = None; } @@ -619,6 +668,12 @@ pub(super) async fn connect_remote_control_websocket( } if enrollment.is_none() { + info!( + "creating new remote control enrollment: websocket_url={}, enroll_url={}, account_id={:?}", + remote_control_target.websocket_url, + remote_control_target.enroll_url, + auth.account_id.as_deref() + ); let new_enrollment = match enroll_remote_control_server(remote_control_target, &auth).await { Ok(new_enrollment) => new_enrollment, @@ -642,6 +697,13 @@ pub(super) async fn connect_remote_control_websocket( { warn!("failed to persist remote control enrollment in sqlite state db: {err}"); } + info!( + "created new remote control enrollment: websocket_url={}, account_id={:?}, server_id={}, environment_id={}", + remote_control_target.websocket_url, + new_enrollment.account_id.as_deref(), + new_enrollment.server_id, + new_enrollment.environment_id + ); *enrollment = Some(new_enrollment); } @@ -660,6 +722,13 @@ pub(super) async fn connect_remote_control_websocket( Err(err) => { match &err { tungstenite::Error::Http(response) if response.status().as_u16() == 404 => { + info!( + "remote control websocket returned HTTP 404; clearing stale enrollment before re-enrolling: websocket_url={}, account_id={:?}, server_id={}, environment_id={}", + remote_control_target.websocket_url, + auth.account_id.as_deref(), + enrollment_ref.server_id, + enrollment_ref.environment_id + ); if let Err(clear_err) = update_persisted_remote_control_enrollment( state_db, remote_control_target, @@ -753,6 +822,7 @@ mod tests { use codex_login::token_data::TokenData; use codex_login::token_data::parse_chatgpt_jwt_claims; use codex_state::StateRuntime; + use futures::StreamExt; use pretty_assertions::assert_eq; use std::sync::Arc; use tempfile::TempDir; @@ -764,6 +834,7 @@ mod tests { use tokio::sync::mpsc; use tokio::time::Duration; use tokio::time::timeout; + use tokio_tungstenite::accept_async; async fn remote_control_state_runtime(codex_home: &TempDir) -> Arc { StateRuntime::init(codex_home.path().to_path_buf(), "test-provider".to_string()) @@ -1052,6 +1123,80 @@ mod tests { .expect("websocket task should join"); } + #[tokio::test] + async fn run_server_writer_inner_sends_periodic_ping_frames() { + let (client_stream, mut server_stream) = connected_websocket_pair().await; + let (websocket_writer, _websocket_reader) = client_stream.split(); + let (outbound_buffer, used_rx) = BoundedOutboundBuffer::new(); + let state = Arc::new(Mutex::new(WebsocketState { + outbound_buffer, + subscribe_cursor: None, + next_seq_id: 0, + })); + let (_server_event_tx, server_event_rx) = mpsc::channel(super::super::CHANNEL_CAPACITY); + let server_event_rx = Arc::new(Mutex::new(server_event_rx)); + let shutdown_token = CancellationToken::new(); + let writer_task = tokio::spawn(RemoteControlWebsocket::run_server_writer_inner( + state, + server_event_rx, + used_rx, + websocket_writer, + Duration::from_millis(20), + shutdown_token.clone(), + )); + + let message = timeout(Duration::from_secs(5), server_stream.next()) + .await + .expect("ping frame should arrive in time") + .expect("server websocket should stay open") + .expect("ping frame should read"); + assert!(matches!(message, tungstenite::Message::Ping(_))); + + shutdown_token.cancel(); + writer_task + .await + .expect("writer task should join") + .expect("writer should stop cleanly"); + } + + #[tokio::test] + async fn run_websocket_reader_inner_times_out_without_pong_frames() { + let (client_stream, _server_stream) = connected_websocket_pair().await; + let (_websocket_writer, websocket_reader) = client_stream.split(); + let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new(); + let state = Arc::new(Mutex::new(WebsocketState { + outbound_buffer, + subscribe_cursor: None, + next_seq_id: 0, + })); + let (server_event_tx, _server_event_rx) = mpsc::channel(super::super::CHANNEL_CAPACITY); + let (transport_event_tx, _transport_event_rx) = + mpsc::channel(super::super::CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let client_tracker = Arc::new(Mutex::new(ClientTracker::new( + server_event_tx, + transport_event_tx, + &shutdown_token, + ))); + + let err = timeout( + Duration::from_secs(5), + RemoteControlWebsocket::run_websocket_reader_inner( + client_tracker, + state, + websocket_reader, + Duration::from_millis(100), + shutdown_token, + ), + ) + .await + .expect("reader should time out waiting for pong") + .expect_err("missing pong should fail the websocket reader"); + + assert_eq!(err.kind(), ErrorKind::TimedOut); + assert_eq!(err.to_string(), "remote control websocket pong timeout"); + } + async fn accept_http_request(listener: &TcpListener) -> (TcpStream, String) { let (stream, _) = timeout(Duration::from_secs(5), listener.accept()) .await @@ -1081,6 +1226,34 @@ mod tests { ) } + async fn connected_websocket_pair() -> ( + WebSocketStream>, + WebSocketStream, + ) { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let connect_task = tokio::spawn(connect_async(format!( + "ws://{}", + listener + .local_addr() + .expect("listener should have a local addr") + ))); + let (server_stream, _) = listener + .accept() + .await + .expect("server should accept client"); + let server_stream = accept_async(server_stream) + .await + .expect("server websocket handshake should succeed"); + let (client_stream, _) = connect_task + .await + .expect("client connect task should join") + .expect("client websocket handshake should succeed"); + + (client_stream, server_stream) + } + async fn respond_with_status_and_headers( mut stream: TcpStream, status: &str, From 5b53648571d8e43c434b2e57afb14ab3e5ca8f78 Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Sat, 4 Apr 2026 01:18:44 -0700 Subject: [PATCH 13/17] fix cleanup on acking --- .../remote_control/client_tracker.rs | 6 +- .../src/transport/remote_control/protocol.rs | 10 +- .../src/transport/remote_control/websocket.rs | 183 ++++++++++++++---- 3 files changed, 157 insertions(+), 42 deletions(-) diff --git a/codex-rs/app-server/src/transport/remote_control/client_tracker.rs b/codex-rs/app-server/src/transport/remote_control/client_tracker.rs index f255eba6fcfd..b98a4caf065f 100644 --- a/codex-rs/app-server/src/transport/remote_control/client_tracker.rs +++ b/codex-rs/app-server/src/transport/remote_control/client_tracker.rs @@ -82,8 +82,10 @@ impl ClientTracker { while self.join_set.join_next().await.is_some() {} } - pub(crate) fn legacy_stream_id(&self, client_id: &ClientId) -> Option { - self.legacy_stream_ids.get(client_id).cloned() + pub(crate) fn client_has_open_stream(&self, client_id: &ClientId) -> bool { + self.clients + .keys() + .any(|(open_client_id, _)| open_client_id == client_id) } pub(crate) async fn handle_message( diff --git a/codex-rs/app-server/src/transport/remote_control/protocol.rs b/codex-rs/app-server/src/transport/remote_control/protocol.rs index 0bedf32dc5c9..857855f2a08d 100644 --- a/codex-rs/app-server/src/transport/remote_control/protocol.rs +++ b/codex-rs/app-server/src/transport/remote_control/protocol.rs @@ -44,7 +44,13 @@ impl StreamId { #[derive(Debug, Clone, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "snake_case")] pub enum ClientEvent { - ClientMessage { message: JSONRPCMessage }, + ClientMessage { + message: JSONRPCMessage, + }, + /// Backend-generated acknowledgement for all server envelopes addressed to + /// `client_id` whose envelope `seq_id` is less than or equal to this ack's + /// `seq_id`. This cursor is client-scoped, not stream-scoped, so receivers + /// must not use `stream_id` to partition acks. Ack, Ping, ClientClosed, @@ -59,6 +65,8 @@ pub(crate) struct ClientEnvelope { pub(crate) client_id: ClientId, #[serde(rename = "stream_id", skip_serializing_if = "Option::is_none")] pub(crate) stream_id: Option, + /// For `Ack`, this is the backend-generated per-client cursor over + /// `ServerEnvelope.seq_id`. #[serde(rename = "seq_id", skip_serializing_if = "Option::is_none")] pub(crate) seq_id: Option, #[serde(skip_serializing_if = "Option::is_none")] diff --git a/codex-rs/app-server/src/transport/remote_control/websocket.rs b/codex-rs/app-server/src/transport/remote_control/websocket.rs index 42004799cc47..9fd095232b33 100644 --- a/codex-rs/app-server/src/transport/remote_control/websocket.rs +++ b/codex-rs/app-server/src/transport/remote_control/websocket.rs @@ -14,7 +14,6 @@ use super::protocol::ClientEvent; use super::protocol::ClientId; use super::protocol::RemoteControlTarget; use super::protocol::ServerEnvelope; -use super::protocol::StreamId; use axum::http::HeaderValue; use base64::Engine; use codex_core::util::backoff; @@ -55,7 +54,10 @@ const REMOTE_CONTROL_WEBSOCKET_PONG_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60); struct BoundedOutboundBuffer { - buffer_by_client: HashMap<(ClientId, StreamId), BTreeMap>, + // Remote-control acks are generated by the backend at client scope, so + // retransmit retention is keyed by client_id only. stream_id stays on each + // envelope for routing, but it is not part of the ack cursor. + buffer_by_client: HashMap>, used_tx: watch::Sender, } @@ -71,29 +73,20 @@ impl BoundedOutboundBuffer { fn insert(&mut self, server_envelope: &ServerEnvelope) { self.buffer_by_client - .entry(( - server_envelope.client_id.clone(), - server_envelope.stream_id.clone(), - )) + .entry(server_envelope.client_id.clone()) .or_default() .insert(server_envelope.seq_id, server_envelope.clone()); self.used_tx.send_modify(|used| *used += 1); } - fn remove(&mut self, client_id: &ClientId, stream_id: &StreamId) { - if let Some(buffer) = self - .buffer_by_client - .remove(&(client_id.clone(), stream_id.clone())) - { + fn remove(&mut self, client_id: &ClientId) { + if let Some(buffer) = self.buffer_by_client.remove(client_id) { self.used_tx.send_modify(|used| *used -= buffer.len()); } } - fn ack(&mut self, client_id: &ClientId, stream_id: &StreamId, acked_seq_id: u64) { - let Some(buffer) = self - .buffer_by_client - .get_mut(&(client_id.clone(), stream_id.clone())) - else { + fn ack(&mut self, client_id: &ClientId, acked_seq_id: u64) { + let Some(buffer) = self.buffer_by_client.get_mut(client_id) else { return; }; while let Some(seq_id) = buffer.first_key_value().map(|(seq_id, _)| seq_id) @@ -103,8 +96,7 @@ impl BoundedOutboundBuffer { self.used_tx.send_modify(|used| *used -= 1); } if buffer.is_empty() { - self.buffer_by_client - .remove(&(client_id.clone(), stream_id.clone())); + self.buffer_by_client.remove(client_id); } } @@ -433,13 +425,20 @@ impl RemoteControlWebsocket { "remote control websocket pong timeout", )); } - client_id = client_tracker.bookkeep_join_set() => { - let Some(client_id) = client_id else { + client_key = client_tracker.bookkeep_join_set() => { + let Some(client_key) = client_key else { continue; }; - if client_tracker.close_client(&client_id).await.is_err() { + if client_tracker.close_client(&client_key).await.is_err() { return Ok(()); } + if !client_tracker.client_has_open_stream(&client_key.0) { + state + .lock() + .await + .outbound_buffer + .remove(&client_key.0); + } continue; } _ = idle_sweep_interval.tick() => { @@ -449,8 +448,10 @@ impl RemoteControlWebsocket { }; if !expired_client_keys.is_empty() { let mut state = state.lock().await; - for (client_id, stream_id) in expired_client_keys { - state.outbound_buffer.remove(&client_id, &stream_id); + for (client_id, _) in expired_client_keys { + if !client_tracker.client_has_open_stream(&client_id) { + state.outbound_buffer.remove(&client_id); + } } } continue; @@ -497,32 +498,26 @@ impl RemoteControlWebsocket { } }; - let resolved_stream_id = client_envelope - .stream_id - .clone() - .or_else(|| client_tracker.legacy_stream_id(&client_envelope.client_id)); - let mut state = state.lock().await; + let mut websocket_state = state.lock().await; if let Some(cursor) = client_envelope.cursor.as_deref() { - state.subscribe_cursor = Some(cursor.to_string()); + websocket_state.subscribe_cursor = Some(cursor.to_string()); } if let ClientEvent::Ack = &client_envelope.event && let Some(acked_seq_id) = client_envelope.seq_id - && let Some(stream_id) = resolved_stream_id.as_ref() { - state + websocket_state .outbound_buffer - .ack(&client_envelope.client_id, stream_id, acked_seq_id); + .ack(&client_envelope.client_id, acked_seq_id); } - if (matches!(&client_envelope.event, ClientEvent::ClientClosed) - || remote_control_message_starts_connection(&client_envelope.event)) - && let Some(stream_id) = resolved_stream_id.as_ref() - { - state + if remote_control_message_starts_connection(&client_envelope.event) { + websocket_state .outbound_buffer - .remove(&client_envelope.client_id, stream_id); + .remove(&client_envelope.client_id); } - drop(state); + drop(websocket_state); + let is_client_closed = matches!(&client_envelope.event, ClientEvent::ClientClosed); + let client_id = client_envelope.client_id.clone(); if client_tracker .handle_message(client_envelope) .await @@ -530,6 +525,9 @@ impl RemoteControlWebsocket { { return Ok(()); } + if is_client_closed && !client_tracker.client_has_open_stream(&client_id) { + state.lock().await.outbound_buffer.remove(&client_id); + } } } } @@ -811,9 +809,14 @@ fn format_remote_control_websocket_connect_error( #[cfg(test)] mod tests { use super::*; + use crate::outgoing_message::OutgoingMessage; + use crate::transport::remote_control::ServerEvent; + use crate::transport::remote_control::protocol::StreamId; use crate::transport::remote_control::protocol::normalize_remote_control_url; use chrono::Utc; use codex_app_server_protocol::AuthMode; + use codex_app_server_protocol::ConfigWarningNotification; + use codex_app_server_protocol::ServerNotification; use codex_core::test_support::auth_manager_from_auth; use codex_login::AuthCredentialsStoreMode; use codex_login::AuthDotJson; @@ -1197,6 +1200,108 @@ mod tests { assert_eq!(err.to_string(), "remote control websocket pong timeout"); } + #[test] + fn outbound_buffer_acks_by_client_id_across_stream_ids() { + let (mut outbound_buffer, used_rx) = BoundedOutboundBuffer::new(); + let client_1 = ClientId("client-1".to_string()); + let client_2 = ClientId("client-2".to_string()); + + outbound_buffer.insert(&server_envelope( + &client_1, + "stream-1", + /*seq_id*/ 0, + "first-client-old-stream", + )); + outbound_buffer.insert(&server_envelope( + &client_2, + "stream-1", + /*seq_id*/ 1, + "second-client", + )); + outbound_buffer.insert(&server_envelope( + &client_1, + "stream-2", + /*seq_id*/ 2, + "first-client-new-stream", + )); + + outbound_buffer.ack(&client_1, /*acked_seq_id*/ 2); + + let retained = outbound_buffer + .server_envelopes() + .map(|server_envelope| { + ( + server_envelope.client_id.0.as_str(), + server_envelope.stream_id.0.as_str(), + server_envelope.seq_id, + ) + }) + .collect::>(); + assert_eq!(retained, vec![("client-2", "stream-1", 1)]); + assert_eq!(*used_rx.borrow(), 1); + } + + #[test] + fn outbound_buffer_remove_drops_all_streams_for_client_id() { + let (mut outbound_buffer, used_rx) = BoundedOutboundBuffer::new(); + let client_1 = ClientId("client-1".to_string()); + let client_2 = ClientId("client-2".to_string()); + + outbound_buffer.insert(&server_envelope( + &client_1, + "stream-1", + /*seq_id*/ 0, + "first-old", + )); + outbound_buffer.insert(&server_envelope( + &client_1, + "stream-2", + /*seq_id*/ 1, + "first-new", + )); + outbound_buffer.insert(&server_envelope( + &client_2, "stream-1", /*seq_id*/ 2, "second", + )); + + outbound_buffer.remove(&client_1); + + let retained = outbound_buffer + .server_envelopes() + .map(|server_envelope| { + ( + server_envelope.client_id.0.as_str(), + server_envelope.stream_id.0.as_str(), + server_envelope.seq_id, + ) + }) + .collect::>(); + assert_eq!(retained, vec![("client-2", "stream-1", 2)]); + assert_eq!(*used_rx.borrow(), 1); + } + + fn server_envelope( + client_id: &ClientId, + stream_id: &str, + seq_id: u64, + summary: &str, + ) -> ServerEnvelope { + ServerEnvelope { + event: ServerEvent::ServerMessage { + message: Box::new(OutgoingMessage::AppServerNotification( + ServerNotification::ConfigWarning(ConfigWarningNotification { + summary: summary.to_string(), + details: None, + path: None, + range: None, + }), + )), + }, + client_id: client_id.clone(), + stream_id: StreamId(stream_id.to_string()), + seq_id, + } + } + async fn accept_http_request(listener: &TcpListener) -> (TcpStream, String) { let (stream, _) = timeout(Duration::from_secs(5), listener.accept()) .await From 9bc01bc57bb3959d681e1777c4852779f64ef07f Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Sat, 4 Apr 2026 01:33:15 -0700 Subject: [PATCH 14/17] simplify cleanup of outbound buffer --- .../remote_control/client_tracker.rs | 6 -- .../src/transport/remote_control/tests.rs | 14 ++++- .../src/transport/remote_control/websocket.rs | 58 +++---------------- 3 files changed, 22 insertions(+), 56 deletions(-) diff --git a/codex-rs/app-server/src/transport/remote_control/client_tracker.rs b/codex-rs/app-server/src/transport/remote_control/client_tracker.rs index b98a4caf065f..fa9a208ade53 100644 --- a/codex-rs/app-server/src/transport/remote_control/client_tracker.rs +++ b/codex-rs/app-server/src/transport/remote_control/client_tracker.rs @@ -82,12 +82,6 @@ impl ClientTracker { while self.join_set.join_next().await.is_some() {} } - pub(crate) fn client_has_open_stream(&self, client_id: &ClientId) -> bool { - self.clients - .keys() - .any(|(open_client_id, _)| open_client_id == client_id) - } - pub(crate) async fn handle_message( &mut self, client_envelope: ClientEnvelope, diff --git a/codex-rs/app-server/src/transport/remote_control/tests.rs b/codex-rs/app-server/src/transport/remote_control/tests.rs index 6e7f5e219a82..77323e847b89 100644 --- a/codex-rs/app-server/src/transport/remote_control/tests.rs +++ b/codex-rs/app-server/src/transport/remote_control/tests.rs @@ -410,7 +410,7 @@ async fn remote_control_transport_reconnects_after_disconnect() { } #[tokio::test] -async fn remote_control_transport_clears_outgoing_buffer_when_client_closes() { +async fn remote_control_transport_clears_outgoing_buffer_when_backend_acks() { let listener = TcpListener::bind("127.0.0.1:0") .await .expect("listener should bind"); @@ -509,6 +509,18 @@ async fn remote_control_transport_clears_outgoing_buffer_when_client_closes() { }) ); + send_client_event( + &mut first_websocket, + ClientEnvelope { + event: ClientEvent::Ack, + client_id: client_id.clone(), + stream_id: None, + seq_id: Some(0), + cursor: None, + }, + ) + .await; + send_client_event( &mut first_websocket, ClientEnvelope { diff --git a/codex-rs/app-server/src/transport/remote_control/websocket.rs b/codex-rs/app-server/src/transport/remote_control/websocket.rs index 9fd095232b33..70ceb6eff2d4 100644 --- a/codex-rs/app-server/src/transport/remote_control/websocket.rs +++ b/codex-rs/app-server/src/transport/remote_control/websocket.rs @@ -79,12 +79,6 @@ impl BoundedOutboundBuffer { self.used_tx.send_modify(|used| *used += 1); } - fn remove(&mut self, client_id: &ClientId) { - if let Some(buffer) = self.buffer_by_client.remove(client_id) { - self.used_tx.send_modify(|used| *used -= buffer.len()); - } - } - fn ack(&mut self, client_id: &ClientId, acked_seq_id: u64) { let Some(buffer) = self.buffer_by_client.get_mut(client_id) else { return; @@ -432,27 +426,11 @@ impl RemoteControlWebsocket { if client_tracker.close_client(&client_key).await.is_err() { return Ok(()); } - if !client_tracker.client_has_open_stream(&client_key.0) { - state - .lock() - .await - .outbound_buffer - .remove(&client_key.0); - } continue; } _ = idle_sweep_interval.tick() => { - let expired_client_keys = match client_tracker.close_expired_clients().await { - Ok(expired_client_keys) => expired_client_keys, - Err(_) => return Ok(()), - }; - if !expired_client_keys.is_empty() { - let mut state = state.lock().await; - for (client_id, _) in expired_client_keys { - if !client_tracker.client_has_open_stream(&client_id) { - state.outbound_buffer.remove(&client_id); - } - } + if client_tracker.close_expired_clients().await.is_err() { + return Ok(()); } continue; } @@ -509,15 +487,8 @@ impl RemoteControlWebsocket { .outbound_buffer .ack(&client_envelope.client_id, acked_seq_id); } - if remote_control_message_starts_connection(&client_envelope.event) { - websocket_state - .outbound_buffer - .remove(&client_envelope.client_id); - } drop(websocket_state); - let is_client_closed = matches!(&client_envelope.event, ClientEvent::ClientClosed); - let client_id = client_envelope.client_id.clone(); if client_tracker .handle_message(client_envelope) .await @@ -525,24 +496,10 @@ impl RemoteControlWebsocket { { return Ok(()); } - if is_client_closed && !client_tracker.client_has_open_stream(&client_id) { - state.lock().await.outbound_buffer.remove(&client_id); - } } } } -fn remote_control_message_starts_connection(event: &ClientEvent) -> bool { - matches!( - event, - ClientEvent::ClientMessage { - message: codex_app_server_protocol::JSONRPCMessage::Request( - codex_app_server_protocol::JSONRPCRequest { method, .. } - ), - } if method == "initialize" - ) -} - fn set_remote_control_header( headers: &mut tungstenite::http::HeaderMap, name: &'static str, @@ -1242,7 +1199,7 @@ mod tests { } #[test] - fn outbound_buffer_remove_drops_all_streams_for_client_id() { + fn outbound_buffer_retains_unacked_messages_until_ack_advances() { let (mut outbound_buffer, used_rx) = BoundedOutboundBuffer::new(); let client_1 = ClientId("client-1".to_string()); let client_2 = ClientId("client-2".to_string()); @@ -1263,7 +1220,7 @@ mod tests { &client_2, "stream-1", /*seq_id*/ 2, "second", )); - outbound_buffer.remove(&client_1); + outbound_buffer.ack(&client_1, /*acked_seq_id*/ 0); let retained = outbound_buffer .server_envelopes() @@ -1275,8 +1232,11 @@ mod tests { ) }) .collect::>(); - assert_eq!(retained, vec![("client-2", "stream-1", 2)]); - assert_eq!(*used_rx.borrow(), 1); + assert_eq!( + retained, + vec![("client-1", "stream-2", 1), ("client-2", "stream-1", 2)] + ); + assert_eq!(*used_rx.borrow(), 2); } fn server_envelope( From 32706e4119691d5f24781e8627efe1939102c751 Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Sat, 4 Apr 2026 01:44:52 -0700 Subject: [PATCH 15/17] fix test --- .../app-server/src/transport/remote_control/websocket.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/codex-rs/app-server/src/transport/remote_control/websocket.rs b/codex-rs/app-server/src/transport/remote_control/websocket.rs index 70ceb6eff2d4..962a046fe953 100644 --- a/codex-rs/app-server/src/transport/remote_control/websocket.rs +++ b/codex-rs/app-server/src/transport/remote_control/websocket.rs @@ -1184,7 +1184,7 @@ mod tests { outbound_buffer.ack(&client_1, /*acked_seq_id*/ 2); - let retained = outbound_buffer + let mut retained = outbound_buffer .server_envelopes() .map(|server_envelope| { ( @@ -1194,6 +1194,7 @@ mod tests { ) }) .collect::>(); + retained.sort_unstable(); assert_eq!(retained, vec![("client-2", "stream-1", 1)]); assert_eq!(*used_rx.borrow(), 1); } @@ -1222,7 +1223,7 @@ mod tests { outbound_buffer.ack(&client_1, /*acked_seq_id*/ 0); - let retained = outbound_buffer + let mut retained = outbound_buffer .server_envelopes() .map(|server_envelope| { ( @@ -1232,6 +1233,7 @@ mod tests { ) }) .collect::>(); + retained.sort_unstable(); assert_eq!( retained, vec![("client-1", "stream-2", 1), ("client-2", "stream-1", 2)] From fe2732cbc94136fafaaf1a85dc2251b6957c6fa6 Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Sat, 4 Apr 2026 02:17:01 -0700 Subject: [PATCH 16/17] fix migrations due to bad merge --- .../state/migrations/{0024_drop_logs.sql => 0023_drop_logs.sql} | 0 ...ontrol_enrollments.sql => 0024_remote_control_enrollments.sql} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename codex-rs/state/migrations/{0024_drop_logs.sql => 0023_drop_logs.sql} (100%) rename codex-rs/state/migrations/{0023_remote_control_enrollments.sql => 0024_remote_control_enrollments.sql} (100%) diff --git a/codex-rs/state/migrations/0024_drop_logs.sql b/codex-rs/state/migrations/0023_drop_logs.sql similarity index 100% rename from codex-rs/state/migrations/0024_drop_logs.sql rename to codex-rs/state/migrations/0023_drop_logs.sql diff --git a/codex-rs/state/migrations/0023_remote_control_enrollments.sql b/codex-rs/state/migrations/0024_remote_control_enrollments.sql similarity index 100% rename from codex-rs/state/migrations/0023_remote_control_enrollments.sql rename to codex-rs/state/migrations/0024_remote_control_enrollments.sql From 194e34febecd1b9830f49ec197c7f04f176fd7e5 Mon Sep 17 00:00:00 2001 From: Ruslan Nigmatullin Date: Mon, 6 Apr 2026 12:59:21 -0700 Subject: [PATCH 17/17] cr for db --- codex-rs/app-server/src/lib.rs | 13 +- .../src/transport/remote_control/enroll.rs | 131 ++++++----- .../src/transport/remote_control/mod.rs | 9 +- .../src/transport/remote_control/tests.rs | 201 ++++++++++++++++- .../src/transport/remote_control/websocket.rs | 139 ++++++++---- codex-rs/app-server/src/transport/stdio.rs | 23 ++ .../0024_remote_control_enrollments.sql | 3 +- codex-rs/state/src/lib.rs | 1 + codex-rs/state/src/runtime.rs | 2 + codex-rs/state/src/runtime/remote_control.rs | 206 ++++++++++++------ 10 files changed, 555 insertions(+), 173 deletions(-) diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 2c4387e4672e..c9cf74eecd48 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -48,6 +48,7 @@ use codex_feedback::CodexFeedback; use codex_protocol::protocol::SessionSource; use codex_state::log_db; use tokio::sync::mpsc; +use tokio::sync::oneshot; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use toml::Value as TomlValue; @@ -535,11 +536,18 @@ pub async fn run_main_with_transport( let single_client_mode = matches!(&transport, AppServerTransport::Stdio); let shutdown_when_no_connections = single_client_mode; let graceful_signal_restart_enabled = !single_client_mode; + let mut app_server_client_name_rx = None; match transport { AppServerTransport::Stdio => { - start_stdio_connection(transport_event_tx.clone(), &mut transport_accept_handles) - .await?; + let (stdio_client_name_tx, stdio_client_name_rx) = oneshot::channel::(); + app_server_client_name_rx = Some(stdio_client_name_rx); + start_stdio_connection( + transport_event_tx.clone(), + &mut transport_accept_handles, + stdio_client_name_tx, + ) + .await?; } AppServerTransport::WebSocket { bind_address } => { let accept_handle = start_websocket_acceptor( @@ -563,6 +571,7 @@ pub async fn run_main_with_transport( auth_manager.clone(), transport_event_tx.clone(), transport_shutdown_token.clone(), + app_server_client_name_rx, ) .await?; transport_accept_handles.push(accept_handle); diff --git a/codex-rs/app-server/src/transport/remote_control/enroll.rs b/codex-rs/app-server/src/transport/remote_control/enroll.rs index 787ece764686..dbe18c8355db 100644 --- a/codex-rs/app-server/src/transport/remote_control/enroll.rs +++ b/codex-rs/app-server/src/transport/remote_control/enroll.rs @@ -3,6 +3,7 @@ use super::protocol::EnrollRemoteServerResponse; use super::protocol::RemoteControlTarget; use axum::http::HeaderMap; use codex_login::default_client::build_reqwest_client; +use codex_state::RemoteControlEnrollmentRecord; use codex_state::StateRuntime; use gethostname::gethostname; use std::io; @@ -20,7 +21,7 @@ pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id"; #[derive(Debug, Clone, PartialEq, Eq)] pub(super) struct RemoteControlEnrollment { - pub(super) account_id: Option, + pub(super) account_id: String, pub(super) environment_id: String, pub(super) server_id: String, pub(super) server_name: String, @@ -29,52 +30,61 @@ pub(super) struct RemoteControlEnrollment { #[derive(Debug, Clone, PartialEq, Eq)] pub(super) struct RemoteControlConnectionAuth { pub(super) bearer_token: String, - pub(super) account_id: Option, + pub(super) account_id: String, } pub(super) async fn load_persisted_remote_control_enrollment( state_db: Option<&StateRuntime>, remote_control_target: &RemoteControlTarget, - account_id: Option<&str>, + account_id: &str, + app_server_client_name: Option<&str>, ) -> Option { let Some(state_db) = state_db else { info!( - "remote control enrollment cache unavailable because sqlite state db is disabled: websocket_url={}, account_id={:?}", - remote_control_target.websocket_url, account_id + "remote control enrollment cache unavailable because sqlite state db is disabled: websocket_url={}, account_id={}, app_server_client_name={:?}", + remote_control_target.websocket_url, account_id, app_server_client_name ); return None; }; let enrollment = match state_db - .get_remote_control_enrollment(&remote_control_target.websocket_url, account_id) + .get_remote_control_enrollment( + &remote_control_target.websocket_url, + account_id, + app_server_client_name, + ) .await { Ok(enrollment) => enrollment, Err(err) => { warn!( - "failed to load persisted remote control enrollment: websocket_url={}, account_id={:?}, err={err}", - remote_control_target.websocket_url, account_id + "failed to load persisted remote control enrollment: websocket_url={}, account_id={}, app_server_client_name={:?}, err={err}", + remote_control_target.websocket_url, account_id, app_server_client_name ); return None; } }; match enrollment { - Some((server_id, environment_id, server_name)) => { + Some(enrollment) => { info!( - "reusing persisted remote control enrollment: websocket_url={}, account_id={:?}, server_id={}, environment_id={}", - remote_control_target.websocket_url, account_id, server_id, environment_id + "reusing persisted remote control enrollment: websocket_url={}, account_id={}, app_server_client_name={:?}, server_id={}, environment_id={}", + remote_control_target.websocket_url, + account_id, + app_server_client_name, + enrollment.server_id, + enrollment.environment_id ); Some(RemoteControlEnrollment { - account_id: account_id.map(&str::to_string), - environment_id, - server_id, - server_name, + account_id: enrollment.account_id, + environment_id: enrollment.environment_id, + server_id: enrollment.server_id, + server_name: enrollment.server_name, }) } None => { info!( - "no persisted remote control enrollment found: websocket_url={}, account_id={:?}", - remote_control_target.websocket_url, account_id + "no persisted remote control enrollment found: websocket_url={}, account_id={}, app_server_client_name={:?}", + remote_control_target.websocket_url, account_id, app_server_client_name ); None } @@ -84,53 +94,61 @@ pub(super) async fn load_persisted_remote_control_enrollment( pub(super) async fn update_persisted_remote_control_enrollment( state_db: Option<&StateRuntime>, remote_control_target: &RemoteControlTarget, - account_id: Option<&str>, + account_id: &str, + app_server_client_name: Option<&str>, enrollment: Option<&RemoteControlEnrollment>, ) -> io::Result<()> { let Some(state_db) = state_db else { info!( - "skipping remote control enrollment persistence because sqlite state db is disabled: websocket_url={}, account_id={:?}, has_enrollment={}", + "skipping remote control enrollment persistence because sqlite state db is disabled: websocket_url={}, account_id={}, app_server_client_name={:?}, has_enrollment={}", remote_control_target.websocket_url, account_id, + app_server_client_name, enrollment.is_some() ); return Ok(()); }; if let &Some(enrollment) = &enrollment - && enrollment.account_id.as_deref() != account_id + && enrollment.account_id != account_id { return Err(io::Error::other(format!( - "enrollment account_id does not match expected account_id `{account_id:?}`" + "enrollment account_id does not match expected account_id `{account_id}`" ))); } if let Some(enrollment) = enrollment { state_db - .upsert_remote_control_enrollment( - &remote_control_target.websocket_url, - account_id, - &enrollment.server_id, - &enrollment.environment_id, - &enrollment.server_name, - ) + .upsert_remote_control_enrollment(&RemoteControlEnrollmentRecord { + websocket_url: remote_control_target.websocket_url.clone(), + account_id: account_id.to_string(), + app_server_client_name: app_server_client_name.map(str::to_string), + server_id: enrollment.server_id.clone(), + environment_id: enrollment.environment_id.clone(), + server_name: enrollment.server_name.clone(), + }) .await .map_err(io::Error::other)?; info!( - "persisted remote control enrollment: websocket_url={}, account_id={:?}, server_id={}, environment_id={}", + "persisted remote control enrollment: websocket_url={}, account_id={}, app_server_client_name={:?}, server_id={}, environment_id={}", remote_control_target.websocket_url, account_id, + app_server_client_name, enrollment.server_id, enrollment.environment_id ); Ok(()) } else { let rows_affected = state_db - .delete_remote_control_enrollment(&remote_control_target.websocket_url, account_id) + .delete_remote_control_enrollment( + &remote_control_target.websocket_url, + account_id, + app_server_client_name, + ) .await .map_err(io::Error::other)?; info!( - "cleared persisted remote control enrollment: websocket_url={}, account_id={:?}, rows_affected={rows_affected}", - remote_control_target.websocket_url, account_id + "cleared persisted remote control enrollment: websocket_url={}, account_id={}, app_server_client_name={:?}, rows_affected={rows_affected}", + remote_control_target.websocket_url, account_id, app_server_client_name ); Ok(()) } @@ -181,15 +199,12 @@ pub(super) async fn enroll_remote_control_server( app_server_version: env!("CARGO_PKG_VERSION"), }; let client = build_reqwest_client(); - let mut http_request = client + let http_request = client .post(enroll_url) .timeout(REMOTE_CONTROL_ENROLL_TIMEOUT) .bearer_auth(&auth.bearer_token) + .header(REMOTE_CONTROL_ACCOUNT_ID_HEADER, &auth.account_id) .json(&request); - let account_id = auth.account_id.as_deref(); - if let Some(account_id) = account_id { - http_request = http_request.header(REMOTE_CONTROL_ACCOUNT_ID_HEADER, account_id); - } let response = http_request.send().await.map_err(|err| { io::Error::other(format!( @@ -227,7 +242,7 @@ pub(super) async fn enroll_remote_control_server( })?; Ok(RemoteControlEnrollment { - account_id: account_id.map(&str::to_string), + account_id: auth.account_id.clone(), environment_id: enrollment.environment_id, server_id: enrollment.server_id, server_name, @@ -267,13 +282,13 @@ mod tests { normalize_remote_control_url("https://api.chatgpt-staging.com/other/control") .expect("second target should parse"); let first_enrollment = RemoteControlEnrollment { - account_id: Some("account-a".to_string()), + account_id: "account-a".to_string(), environment_id: "env_first".to_string(), server_id: "srv_e_first".to_string(), server_name: "first-server".to_string(), }; let second_enrollment = RemoteControlEnrollment { - account_id: Some("account-a".to_string()), + account_id: "account-a".to_string(), environment_id: "env_second".to_string(), server_id: "srv_e_second".to_string(), server_name: "second-server".to_string(), @@ -282,7 +297,8 @@ mod tests { update_persisted_remote_control_enrollment( Some(state_db.as_ref()), &first_target, - Some("account-a"), + "account-a", + Some("desktop-client"), Some(&first_enrollment), ) .await @@ -290,7 +306,8 @@ mod tests { update_persisted_remote_control_enrollment( Some(state_db.as_ref()), &second_target, - Some("account-a"), + "account-a", + Some("desktop-client"), Some(&second_enrollment), ) .await @@ -300,7 +317,8 @@ mod tests { load_persisted_remote_control_enrollment( Some(state_db.as_ref()), &first_target, - Some("account-a"), + "account-a", + Some("desktop-client"), ) .await, Some(first_enrollment.clone()) @@ -309,7 +327,8 @@ mod tests { load_persisted_remote_control_enrollment( Some(state_db.as_ref()), &first_target, - Some("account-b"), + "account-b", + Some("desktop-client"), ) .await, None @@ -318,7 +337,8 @@ mod tests { load_persisted_remote_control_enrollment( Some(state_db.as_ref()), &second_target, - Some("account-a"), + "account-a", + Some("desktop-client"), ) .await, Some(second_enrollment) @@ -335,13 +355,13 @@ mod tests { normalize_remote_control_url("https://api.chatgpt-staging.com/other/control") .expect("second target should parse"); let first_enrollment = RemoteControlEnrollment { - account_id: Some("account-a".to_string()), + account_id: "account-a".to_string(), environment_id: "env_first".to_string(), server_id: "srv_e_first".to_string(), server_name: "first-server".to_string(), }; let second_enrollment = RemoteControlEnrollment { - account_id: Some("account-a".to_string()), + account_id: "account-a".to_string(), environment_id: "env_second".to_string(), server_id: "srv_e_second".to_string(), server_name: "second-server".to_string(), @@ -350,7 +370,8 @@ mod tests { update_persisted_remote_control_enrollment( Some(state_db.as_ref()), &first_target, - Some("account-a"), + "account-a", + /*app_server_client_name*/ None, Some(&first_enrollment), ) .await @@ -358,7 +379,8 @@ mod tests { update_persisted_remote_control_enrollment( Some(state_db.as_ref()), &second_target, - Some("account-a"), + "account-a", + /*app_server_client_name*/ None, Some(&second_enrollment), ) .await @@ -367,7 +389,8 @@ mod tests { update_persisted_remote_control_enrollment( Some(state_db.as_ref()), &first_target, - Some("account-a"), + "account-a", + /*app_server_client_name*/ None, /*enrollment*/ None, ) .await @@ -377,7 +400,8 @@ mod tests { load_persisted_remote_control_enrollment( Some(state_db.as_ref()), &first_target, - Some("account-a"), + "account-a", + /*app_server_client_name*/ None, ) .await, None @@ -386,7 +410,8 @@ mod tests { load_persisted_remote_control_enrollment( Some(state_db.as_ref()), &second_target, - Some("account-a"), + "account-a", + /*app_server_client_name*/ None, ) .await, Some(second_enrollment) @@ -421,7 +446,7 @@ mod tests { &remote_control_target, &RemoteControlConnectionAuth { bearer_token: "Access Token".to_string(), - account_id: Some("account_id".to_string()), + account_id: "account_id".to_string(), }, ) .await diff --git a/codex-rs/app-server/src/transport/remote_control/mod.rs b/codex-rs/app-server/src/transport/remote_control/mod.rs index 361debce5c18..6d9d65e8a313 100644 --- a/codex-rs/app-server/src/transport/remote_control/mod.rs +++ b/codex-rs/app-server/src/transport/remote_control/mod.rs @@ -35,6 +35,7 @@ pub(crate) async fn start_remote_control( auth_manager: Arc, transport_event_tx: mpsc::Sender, shutdown_token: CancellationToken, + app_server_client_name_rx: Option>, ) -> io::Result> { let remote_control_target = normalize_remote_control_url(&remote_control_url)?; validate_remote_control_auth(&auth_manager).await?; @@ -47,7 +48,7 @@ pub(crate) async fn start_remote_control( transport_event_tx, shutdown_token, ) - .run() + .run(app_server_client_name_rx) .await; })) } @@ -55,7 +56,11 @@ pub(crate) async fn start_remote_control( pub(crate) async fn validate_remote_control_auth( auth_manager: &Arc, ) -> io::Result<()> { - load_remote_control_auth(auth_manager).await.map(|_| ()) + match load_remote_control_auth(auth_manager).await { + Ok(_) => Ok(()), + Err(err) if err.kind() == io::ErrorKind::WouldBlock => Ok(()), + Err(err) => Err(err), + } } #[cfg(test)] diff --git a/codex-rs/app-server/src/transport/remote_control/tests.rs b/codex-rs/app-server/src/transport/remote_control/tests.rs index 77323e847b89..280949adfc00 100644 --- a/codex-rs/app-server/src/transport/remote_control/tests.rs +++ b/codex-rs/app-server/src/transport/remote_control/tests.rs @@ -13,13 +13,19 @@ use crate::outgoing_message::QueuedOutgoingMessage; use crate::transport::CHANNEL_CAPACITY; use crate::transport::TransportEvent; use base64::Engine; +use codex_app_server_protocol::AuthMode; use codex_app_server_protocol::ConfigWarningNotification; use codex_app_server_protocol::JSONRPCMessage; use codex_app_server_protocol::ServerNotification; use codex_core::test_support::auth_manager_from_auth; use codex_core::test_support::auth_manager_from_auth_with_home; +use codex_login::AuthCredentialsStoreMode; +use codex_login::AuthDotJson; use codex_login::AuthManager; use codex_login::CodexAuth; +use codex_login::save_auth; +use codex_login::token_data::TokenData; +use codex_login::token_data::parse_chatgpt_jwt_claims; use codex_state::StateRuntime; use futures::SinkExt; use futures::StreamExt; @@ -36,6 +42,7 @@ use tokio::io::BufReader; use tokio::net::TcpListener; use tokio::net::TcpStream; use tokio::sync::mpsc; +use tokio::sync::oneshot; use tokio::time::Duration; use tokio::time::timeout; use tokio_tungstenite::WebSocketStream; @@ -55,6 +62,43 @@ fn remote_control_auth_manager_with_home(codex_home: &TempDir) -> Arc) -> AuthDotJson { + #[derive(serde::Serialize)] + struct Header { + alg: &'static str, + typ: &'static str, + } + + let header = Header { + alg: "none", + typ: "JWT", + }; + let payload = serde_json::json!({ + "email": "user@example.com", + "https://api.openai.com/auth": { + "chatgpt_user_id": "user-12345", + "user_id": "user-12345", + "chatgpt_account_id": "account_id" + } + }); + let b64 = |bytes: &[u8]| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes); + let header_b64 = b64(&serde_json::to_vec(&header).expect("header should serialize")); + let payload_b64 = b64(&serde_json::to_vec(&payload).expect("payload should serialize")); + let fake_jwt = format!("{header_b64}.{payload_b64}.sig"); + + AuthDotJson { + auth_mode: Some(AuthMode::Chatgpt), + openai_api_key: None, + tokens: Some(TokenData { + id_token: parse_chatgpt_jwt_claims(&fake_jwt).expect("fake jwt should parse"), + access_token: "Access Token".to_string(), + refresh_token: "refresh-token".to_string(), + account_id: account_id.map(str::to_string), + }), + last_refresh: Some(chrono::Utc::now()), + } +} + async fn remote_control_state_runtime(codex_home: &TempDir) -> Arc { StateRuntime::init(codex_home.path().to_path_buf(), "test-provider".to_string()) .await @@ -87,6 +131,7 @@ async fn remote_control_transport_manages_virtual_clients_and_routes_messages() remote_control_auth_manager(), transport_event_tx, shutdown_token.clone(), + /*app_server_client_name_rx*/ None, ) .await .expect("remote control should start"); @@ -350,6 +395,7 @@ async fn remote_control_transport_reconnects_after_disconnect() { remote_control_auth_manager(), transport_event_tx, shutdown_token.clone(), + /*app_server_client_name_rx*/ None, ) .await .expect("remote control should start"); @@ -425,6 +471,7 @@ async fn remote_control_transport_clears_outgoing_buffer_when_backend_acks() { remote_control_auth_manager(), transport_event_tx, shutdown_token.clone(), + /*app_server_client_name_rx*/ None, ) .await .expect("remote control should start"); @@ -590,6 +637,7 @@ async fn remote_control_http_mode_enrolls_before_connecting() { remote_control_auth_manager(), transport_event_tx, shutdown_token.clone(), + /*app_server_client_name_rx*/ None, ) .await .expect("remote control should start"); @@ -781,7 +829,7 @@ async fn remote_control_http_mode_reuses_persisted_enrollment_before_reenrolling let remote_control_target = normalize_remote_control_url(&remote_control_url).expect("target should parse"); let persisted_enrollment = RemoteControlEnrollment { - account_id: Some("account_id".to_string()), + account_id: "account_id".to_string(), environment_id: "env_persisted".to_string(), server_id: "srv_e_persisted".to_string(), server_name: "persisted-server".to_string(), @@ -789,7 +837,8 @@ async fn remote_control_http_mode_reuses_persisted_enrollment_before_reenrolling update_persisted_remote_control_enrollment( Some(state_db.as_ref()), &remote_control_target, - Some("account_id"), + "account_id", + /*app_server_client_name*/ None, Some(&persisted_enrollment), ) .await @@ -804,6 +853,7 @@ async fn remote_control_http_mode_reuses_persisted_enrollment_before_reenrolling remote_control_auth_manager_with_home(&codex_home), transport_event_tx, shutdown_token.clone(), + /*app_server_client_name_rx*/ None, ) .await .expect("remote control should start"); @@ -821,7 +871,8 @@ async fn remote_control_http_mode_reuses_persisted_enrollment_before_reenrolling load_persisted_remote_control_enrollment( Some(state_db.as_ref()), &remote_control_target, - Some("account_id"), + "account_id", + /*app_server_client_name*/ None, ) .await, Some(persisted_enrollment) @@ -831,6 +882,139 @@ async fn remote_control_http_mode_reuses_persisted_enrollment_before_reenrolling let _ = remote_handle.await; } +#[tokio::test] +async fn remote_control_stdio_mode_waits_for_client_name_before_connecting() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let codex_home = TempDir::new().expect("temp dir should create"); + let state_db = remote_control_state_runtime(&codex_home).await; + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let app_server_client_name = "stdio-client"; + let persisted_enrollment = RemoteControlEnrollment { + account_id: "account_id".to_string(), + environment_id: "env_persisted".to_string(), + server_id: "srv_e_persisted".to_string(), + server_name: "persisted-server".to_string(), + }; + update_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &remote_control_target, + "account_id", + Some(app_server_client_name), + Some(&persisted_enrollment), + ) + .await + .expect("persisted enrollment should save"); + + let (transport_event_tx, _transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let (app_server_client_name_tx, app_server_client_name_rx) = oneshot::channel::(); + let shutdown_token = CancellationToken::new(); + let remote_handle = start_remote_control( + remote_control_url, + Some(state_db.clone()), + remote_control_auth_manager_with_home(&codex_home), + transport_event_tx, + shutdown_token.clone(), + Some(app_server_client_name_rx), + ) + .await + .expect("remote control should start"); + + timeout(Duration::from_millis(100), listener.accept()) + .await + .expect_err("remote control should wait for the stdio client name"); + + let _ = app_server_client_name_tx.send(app_server_client_name.to_string()); + let (handshake_request, _websocket) = accept_remote_control_backend_connection(&listener).await; + assert_eq!( + handshake_request.headers.get("x-codex-server-id"), + Some(&persisted_enrollment.server_id) + ); + + shutdown_token.cancel(); + let _ = remote_handle.await; +} + +#[tokio::test] +async fn remote_control_waits_for_account_id_before_enrolling() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let codex_home = TempDir::new().expect("temp dir should create"); + save_auth( + codex_home.path(), + &remote_control_auth_dot_json(/*account_id*/ None), + AuthCredentialsStoreMode::File, + ) + .expect("auth without account id should save"); + let state_db = remote_control_state_runtime(&codex_home).await; + let auth_manager = AuthManager::shared( + codex_home.path().to_path_buf(), + /*enable_codex_api_key_env*/ false, + AuthCredentialsStoreMode::File, + ); + let expected_server_name = gethostname().to_string_lossy().trim().to_string(); + let expected_enrollment = RemoteControlEnrollment { + account_id: "account_id".to_string(), + environment_id: "env_ready".to_string(), + server_id: "srv_e_ready".to_string(), + server_name: expected_server_name, + }; + + let (transport_event_tx, _transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let remote_handle = start_remote_control( + remote_control_url, + Some(state_db.clone()), + auth_manager, + transport_event_tx, + shutdown_token.clone(), + /*app_server_client_name_rx*/ None, + ) + .await + .expect("remote control should start before account id is available"); + + timeout(Duration::from_millis(100), listener.accept()) + .await + .expect_err("remote control should wait for account id before enrolling"); + + save_auth( + codex_home.path(), + &remote_control_auth_dot_json(Some("account_id")), + AuthCredentialsStoreMode::File, + ) + .expect("auth with account id should save"); + + let enroll_request = accept_http_request(&listener).await; + assert_eq!( + enroll_request.request_line, + "POST /backend-api/wham/remote/control/server/enroll HTTP/1.1" + ); + respond_with_json( + enroll_request.stream, + json!({ + "server_id": expected_enrollment.server_id, + "environment_id": expected_enrollment.environment_id, + }), + ) + .await; + + let (handshake_request, _websocket) = accept_remote_control_backend_connection(&listener).await; + assert_eq!( + handshake_request.headers.get("x-codex-server-id"), + Some(&expected_enrollment.server_id) + ); + + shutdown_token.cancel(); + let _ = remote_handle.await; +} + #[tokio::test] async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404() { let listener = TcpListener::bind("127.0.0.1:0") @@ -843,13 +1027,13 @@ async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404() normalize_remote_control_url(&remote_control_url).expect("target should parse"); let expected_server_name = gethostname().to_string_lossy().trim().to_string(); let stale_enrollment = RemoteControlEnrollment { - account_id: Some("account_id".to_string()), + account_id: "account_id".to_string(), environment_id: "env_stale".to_string(), server_id: "srv_e_stale".to_string(), server_name: "stale-server".to_string(), }; let refreshed_enrollment = RemoteControlEnrollment { - account_id: Some("account_id".to_string()), + account_id: "account_id".to_string(), environment_id: "env_refreshed".to_string(), server_id: "srv_e_refreshed".to_string(), server_name: expected_server_name, @@ -857,7 +1041,8 @@ async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404() update_persisted_remote_control_enrollment( Some(state_db.as_ref()), &remote_control_target, - Some("account_id"), + "account_id", + /*app_server_client_name*/ None, Some(&stale_enrollment), ) .await @@ -872,6 +1057,7 @@ async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404() remote_control_auth_manager_with_home(&codex_home), transport_event_tx, shutdown_token.clone(), + /*app_server_client_name_rx*/ None, ) .await .expect("remote control should start"); @@ -910,7 +1096,8 @@ async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404() load_persisted_remote_control_enrollment( Some(state_db.as_ref()), &remote_control_target, - Some("account_id"), + "account_id", + /*app_server_client_name*/ None, ) .await, Some(refreshed_enrollment) diff --git a/codex-rs/app-server/src/transport/remote_control/websocket.rs b/codex-rs/app-server/src/transport/remote_control/websocket.rs index 962a046fe953..56bc88cc6f48 100644 --- a/codex-rs/app-server/src/transport/remote_control/websocket.rs +++ b/codex-rs/app-server/src/transport/remote_control/websocket.rs @@ -33,6 +33,7 @@ use std::sync::Arc; use tokio::net::TcpStream; use tokio::sync::Mutex; use tokio::sync::mpsc; +use tokio::sync::oneshot; use tokio::sync::watch; use tokio::time::MissedTickBehavior; use tokio_tungstenite::MaybeTlsStream; @@ -52,6 +53,8 @@ const REMOTE_CONTROL_WEBSOCKET_PING_INTERVAL: std::time::Duration = std::time::Duration::from_secs(10); const REMOTE_CONTROL_WEBSOCKET_PONG_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(60); +const REMOTE_CONTROL_ACCOUNT_ID_RETRY_INTERVAL: std::time::Duration = + std::time::Duration::from_secs(1); struct BoundedOutboundBuffer { // Remote-control acks are generated by the backend at client scope, so @@ -155,10 +158,27 @@ impl RemoteControlWebsocket { } } - pub(crate) async fn run(mut self) { + pub(crate) async fn run( + mut self, + app_server_client_name_rx: Option>, + ) { + let app_server_client_name = match self + .wait_for_app_server_client_name(app_server_client_name_rx) + .await + { + Ok(app_server_client_name) => app_server_client_name, + Err(_) => { + self.client_tracker.lock().await.shutdown().await; + return; + } + }; + loop { let shutdown_token = self.shutdown_token.child_token(); - let websocket_connection = match self.connect(&shutdown_token).await { + let websocket_connection = match self + .connect(&shutdown_token, app_server_client_name.as_deref()) + .await + { Some(websocket_connection) => websocket_connection, None => break, }; @@ -170,9 +190,28 @@ impl RemoteControlWebsocket { self.client_tracker.lock().await.shutdown().await; } + async fn wait_for_app_server_client_name( + &self, + app_server_client_name_rx: Option>, + ) -> Result, ()> { + match app_server_client_name_rx { + Some(app_server_client_name_rx) => { + tokio::select! { + _ = self.shutdown_token.cancelled() => Err(()), + app_server_client_name = app_server_client_name_rx => match app_server_client_name { + Ok(app_server_client_name) => Ok(Some(app_server_client_name)), + Err(_) => Err(()), + }, + } + } + None => Ok(None), + } + } + async fn connect( &mut self, shutdown_token: &CancellationToken, + app_server_client_name: Option<&str>, ) -> Option>> { loop { let subscribe_cursor = self.state.lock().await.subscribe_cursor.clone(); @@ -185,6 +224,7 @@ impl RemoteControlWebsocket { &mut self.auth_recovery, &mut self.enrollment, subscribe_cursor.as_deref(), + app_server_client_name, ) => { match connect_result { Ok((websocket_connection, response)) => { @@ -198,9 +238,15 @@ impl RemoteControlWebsocket { return Some(websocket_connection); } Err(err) => { - warn!("{err}"); - let reconnect_delay = backoff(self.reconnect_attempt); - self.reconnect_attempt += 1; + let reconnect_delay = if err.kind() == ErrorKind::WouldBlock { + info!("{err}"); + REMOTE_CONTROL_ACCOUNT_ID_RETRY_INTERVAL + } else { + warn!("{err}"); + let reconnect_delay = backoff(self.reconnect_attempt); + self.reconnect_attempt += 1; + reconnect_delay + }; tokio::select! { _ = shutdown_token.cancelled() => return None, _ = tokio::time::sleep(reconnect_delay) => {} @@ -544,9 +590,7 @@ fn build_remote_control_websocket_request( "authorization", &format!("Bearer {}", auth.bearer_token), )?; - if let Some(account_id) = auth.account_id.as_deref() { - set_remote_control_header(headers, REMOTE_CONTROL_ACCOUNT_ID_HEADER, account_id)?; - } + set_remote_control_header(headers, REMOTE_CONTROL_ACCOUNT_ID_HEADER, &auth.account_id)?; if let Some(subscribe_cursor) = subscribe_cursor { set_remote_control_header( headers, @@ -560,17 +604,28 @@ fn build_remote_control_websocket_request( pub(crate) async fn load_remote_control_auth( auth_manager: &Arc, ) -> io::Result { - let auth = match auth_manager.auth().await { - Some(auth) => auth, - None => { - auth_manager.reload(); - auth_manager.auth().await.ok_or_else(|| { - io::Error::new( + let mut reloaded = false; + let auth = loop { + let Some(auth) = auth_manager.auth().await else { + if reloaded { + return Err(io::Error::new( ErrorKind::PermissionDenied, "remote control requires ChatGPT authentication", - ) - })? + )); + } + auth_manager.reload(); + reloaded = true; + continue; + }; + if !auth.is_chatgpt_auth() { + break auth; + } + if auth.get_account_id().is_none() && !reloaded { + auth_manager.reload(); + reloaded = true; + continue; } + break auth; }; if !auth.is_chatgpt_auth() { @@ -582,7 +637,12 @@ pub(crate) async fn load_remote_control_auth( Ok(RemoteControlConnectionAuth { bearer_token: auth.get_token().map_err(io::Error::other)?, - account_id: auth.get_account_id(), + account_id: auth.get_account_id().ok_or_else(|| { + io::Error::new( + ErrorKind::WouldBlock, + "remote control enrollment is waiting for a ChatGPT account id", + ) + })?, }) } @@ -593,6 +653,7 @@ pub(super) async fn connect_remote_control_websocket( auth_recovery: &mut UnauthorizedRecovery, enrollment: &mut Option, subscribe_cursor: Option<&str>, + app_server_client_name: Option<&str>, ) -> io::Result<( WebSocketStream>, tungstenite::http::Response<()>, @@ -600,15 +661,15 @@ pub(super) async fn connect_remote_control_websocket( ensure_rustls_crypto_provider(); let auth = load_remote_control_auth(auth_manager).await?; - let enrollment_account_id = enrollment - .as_ref() - .and_then(|enrollment| enrollment.account_id.clone()); - if auth.account_id.as_deref() != enrollment_account_id.as_deref() { + let enrollment_account_id = enrollment.as_ref().map(|enrollment| &enrollment.account_id); + if enrollment_account_id.is_some_and(|account_id| account_id != &auth.account_id) { info!( "clearing in-memory remote control enrollment because account id changed: websocket_url={}, previous_account_id={:?}, current_account_id={:?}", remote_control_target.websocket_url, - enrollment_account_id.as_deref(), - auth.account_id.as_deref() + enrollment + .as_ref() + .map(|enrollment| enrollment.account_id.as_str()), + auth.account_id ); *enrollment = None; } @@ -617,17 +678,16 @@ pub(super) async fn connect_remote_control_websocket( *enrollment = load_persisted_remote_control_enrollment( state_db, remote_control_target, - auth.account_id.as_deref(), + &auth.account_id, + app_server_client_name, ) .await; } if enrollment.is_none() { info!( - "creating new remote control enrollment: websocket_url={}, enroll_url={}, account_id={:?}", - remote_control_target.websocket_url, - remote_control_target.enroll_url, - auth.account_id.as_deref() + "creating new remote control enrollment: websocket_url={}, enroll_url={}, account_id={}", + remote_control_target.websocket_url, remote_control_target.enroll_url, auth.account_id ); let new_enrollment = match enroll_remote_control_server(remote_control_target, &auth).await { @@ -645,7 +705,8 @@ pub(super) async fn connect_remote_control_websocket( if let Err(err) = update_persisted_remote_control_enrollment( state_db, remote_control_target, - auth.account_id.as_deref(), + &auth.account_id, + app_server_client_name, Some(&new_enrollment), ) .await @@ -653,9 +714,9 @@ pub(super) async fn connect_remote_control_websocket( warn!("failed to persist remote control enrollment in sqlite state db: {err}"); } info!( - "created new remote control enrollment: websocket_url={}, account_id={:?}, server_id={}, environment_id={}", + "created new remote control enrollment: websocket_url={}, account_id={}, server_id={}, environment_id={}", remote_control_target.websocket_url, - new_enrollment.account_id.as_deref(), + new_enrollment.account_id, new_enrollment.server_id, new_enrollment.environment_id ); @@ -678,16 +739,17 @@ pub(super) async fn connect_remote_control_websocket( match &err { tungstenite::Error::Http(response) if response.status().as_u16() == 404 => { info!( - "remote control websocket returned HTTP 404; clearing stale enrollment before re-enrolling: websocket_url={}, account_id={:?}, server_id={}, environment_id={}", + "remote control websocket returned HTTP 404; clearing stale enrollment before re-enrolling: websocket_url={}, account_id={}, server_id={}, environment_id={}", remote_control_target.websocket_url, - auth.account_id.as_deref(), + auth.account_id, enrollment_ref.server_id, enrollment_ref.environment_id ); if let Err(clear_err) = update_persisted_remote_control_enrollment( state_db, remote_control_target, - auth.account_id.as_deref(), + &auth.account_id, + app_server_client_name, /*enrollment*/ None, ) .await @@ -884,7 +946,7 @@ mod tests { let auth_manager = remote_control_auth_manager(); let mut auth_recovery = auth_manager.unauthorized_recovery(); let mut enrollment = Some(RemoteControlEnrollment { - account_id: Some("account_id".to_string()), + account_id: "account_id".to_string(), environment_id: "env_test".to_string(), server_id: "srv_e_test".to_string(), server_name: "test-server".to_string(), @@ -897,6 +959,7 @@ mod tests { &mut auth_recovery, &mut enrollment, /*subscribe_cursor*/ None, + /*app_server_client_name*/ None, ) .await { @@ -939,7 +1002,7 @@ mod tests { ); let mut auth_recovery = auth_manager.unauthorized_recovery(); let mut enrollment = Some(RemoteControlEnrollment { - account_id: Some("account_id".to_string()), + account_id: "account_id".to_string(), environment_id: "env_test".to_string(), server_id: "srv_e_test".to_string(), server_name: "test-server".to_string(), @@ -958,6 +1021,7 @@ mod tests { &mut auth_recovery, &mut enrollment, /*subscribe_cursor*/ None, + /*app_server_client_name*/ None, ) .await .expect_err("unauthorized response should fail the websocket connect"); @@ -1024,6 +1088,7 @@ mod tests { &mut auth_recovery, &mut enrollment, /*subscribe_cursor*/ None, + /*app_server_client_name*/ None, ) .await .expect_err("unauthorized enrollment should fail the websocket connect"); @@ -1069,7 +1134,7 @@ mod tests { transport_event_tx, shutdown_token, ) - .run() + .run(/*app_server_client_name_rx*/ None) .await } }); diff --git a/codex-rs/app-server/src/transport/stdio.rs b/codex-rs/app-server/src/transport/stdio.rs index 6d40593a6190..20eab025fb72 100644 --- a/codex-rs/app-server/src/transport/stdio.rs +++ b/codex-rs/app-server/src/transport/stdio.rs @@ -4,6 +4,9 @@ use super::forward_incoming_message; use super::next_connection_id; use super::serialize_outgoing_message; use crate::outgoing_message::QueuedOutgoingMessage; +use codex_app_server_protocol::InitializeParams; +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCRequest; use std::io::ErrorKind; use std::io::Result as IoResult; use tokio::io; @@ -11,6 +14,7 @@ use tokio::io::AsyncBufReadExt; use tokio::io::AsyncWriteExt; use tokio::io::BufReader; use tokio::sync::mpsc; +use tokio::sync::oneshot; use tokio::task::JoinHandle; use tracing::debug; use tracing::error; @@ -19,6 +23,7 @@ use tracing::info; pub(crate) async fn start_stdio_connection( transport_event_tx: mpsc::Sender, stdio_handles: &mut Vec>, + initialize_client_name_tx: oneshot::Sender, ) -> IoResult<()> { let connection_id = next_connection_id(); let (writer_tx, mut writer_rx) = mpsc::channel::(CHANNEL_CAPACITY); @@ -37,10 +42,16 @@ pub(crate) async fn start_stdio_connection( let stdin = io::stdin(); let reader = BufReader::new(stdin); let mut lines = reader.lines(); + let mut initialize_client_name_tx = Some(initialize_client_name_tx); loop { match lines.next_line().await { Ok(Some(line)) => { + if let Some(client_name) = stdio_initialize_client_name(&line) + && let Some(initialize_client_name_tx) = initialize_client_name_tx.take() + { + let _ = initialize_client_name_tx.send(client_name); + } if !forward_incoming_message( &transport_event_tx_for_reader, &writer_tx_for_reader, @@ -86,3 +97,15 @@ pub(crate) async fn start_stdio_connection( Ok(()) } + +fn stdio_initialize_client_name(line: &str) -> Option { + let message = serde_json::from_str::(line).ok()?; + let JSONRPCMessage::Request(JSONRPCRequest { method, params, .. }) = message else { + return None; + }; + if method != "initialize" { + return None; + } + let params = serde_json::from_value::(params?).ok()?; + Some(params.client_info.name) +} diff --git a/codex-rs/state/migrations/0024_remote_control_enrollments.sql b/codex-rs/state/migrations/0024_remote_control_enrollments.sql index 247b8d419253..970db9ef2010 100644 --- a/codex-rs/state/migrations/0024_remote_control_enrollments.sql +++ b/codex-rs/state/migrations/0024_remote_control_enrollments.sql @@ -1,9 +1,10 @@ CREATE TABLE remote_control_enrollments ( websocket_url TEXT NOT NULL, account_id TEXT NOT NULL, + app_server_client_name TEXT NOT NULL, server_id TEXT NOT NULL, environment_id TEXT NOT NULL, server_name TEXT NOT NULL, updated_at INTEGER NOT NULL, - PRIMARY KEY (websocket_url, account_id) + PRIMARY KEY (websocket_url, account_id, app_server_client_name) ); diff --git a/codex-rs/state/src/lib.rs b/codex-rs/state/src/lib.rs index ffaa1637e26a..efad7651f5ef 100644 --- a/codex-rs/state/src/lib.rs +++ b/codex-rs/state/src/lib.rs @@ -46,6 +46,7 @@ pub use model::Stage1StartupClaimParams; pub use model::ThreadMetadata; pub use model::ThreadMetadataBuilder; pub use model::ThreadsPage; +pub use runtime::RemoteControlEnrollmentRecord; pub use runtime::logs_db_filename; pub use runtime::logs_db_path; pub use runtime::state_db_filename; diff --git a/codex-rs/state/src/runtime.rs b/codex-rs/state/src/runtime.rs index 8f28b5ba675d..f0a27b9fb7df 100644 --- a/codex-rs/state/src/runtime.rs +++ b/codex-rs/state/src/runtime.rs @@ -59,6 +59,8 @@ mod remote_control; mod test_support; mod threads; +pub use remote_control::RemoteControlEnrollmentRecord; + // "Partition" is the retained-log-content bucket we cap at 10 MiB: // - one bucket per non-null thread_id // - one bucket per threadless (thread_id IS NULL) non-null process_uuid diff --git a/codex-rs/state/src/runtime/remote_control.rs b/codex-rs/state/src/runtime/remote_control.rs index 12e0e7af9113..fa0b1823f8d1 100644 --- a/codex-rs/state/src/runtime/remote_control.rs +++ b/codex-rs/state/src/runtime/remote_control.rs @@ -1,69 +1,96 @@ use super::*; -const REMOTE_CONTROL_ACCOUNT_ID_NONE: &str = ""; +const REMOTE_CONTROL_APP_SERVER_CLIENT_NAME_NONE: &str = ""; -fn remote_control_account_id_key(account_id: Option<&str>) -> &str { - account_id.unwrap_or(REMOTE_CONTROL_ACCOUNT_ID_NONE) +/// Persisted remote-control server enrollment, including the lookup key. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RemoteControlEnrollmentRecord { + pub websocket_url: String, + pub account_id: String, + pub app_server_client_name: Option, + pub server_id: String, + pub environment_id: String, + pub server_name: String, +} + +fn remote_control_app_server_client_name_key(app_server_client_name: Option<&str>) -> &str { + app_server_client_name.unwrap_or(REMOTE_CONTROL_APP_SERVER_CLIENT_NAME_NONE) +} + +fn app_server_client_name_from_key(app_server_client_name: String) -> Option { + if app_server_client_name.is_empty() { + None + } else { + Some(app_server_client_name) + } } impl StateRuntime { pub async fn get_remote_control_enrollment( &self, websocket_url: &str, - account_id: Option<&str>, - ) -> anyhow::Result> { + account_id: &str, + app_server_client_name: Option<&str>, + ) -> anyhow::Result> { let row = sqlx::query( r#" -SELECT server_id, environment_id, server_name +SELECT websocket_url, account_id, app_server_client_name, server_id, environment_id, server_name FROM remote_control_enrollments -WHERE websocket_url = ? AND account_id = ? +WHERE websocket_url = ? AND account_id = ? AND app_server_client_name = ? "#, ) .bind(websocket_url) - .bind(remote_control_account_id_key(account_id)) + .bind(account_id) + .bind(remote_control_app_server_client_name_key( + app_server_client_name, + )) .fetch_optional(self.pool.as_ref()) .await?; row.map(|row| { - Ok(( - row.try_get("server_id")?, - row.try_get("environment_id")?, - row.try_get("server_name")?, - )) + let app_server_client_name: String = row.try_get("app_server_client_name")?; + Ok(RemoteControlEnrollmentRecord { + websocket_url: row.try_get("websocket_url")?, + account_id: row.try_get("account_id")?, + app_server_client_name: app_server_client_name_from_key(app_server_client_name), + server_id: row.try_get("server_id")?, + environment_id: row.try_get("environment_id")?, + server_name: row.try_get("server_name")?, + }) }) .transpose() } pub async fn upsert_remote_control_enrollment( &self, - websocket_url: &str, - account_id: Option<&str>, - server_id: &str, - environment_id: &str, - server_name: &str, + enrollment: &RemoteControlEnrollmentRecord, ) -> anyhow::Result<()> { sqlx::query( r#" INSERT INTO remote_control_enrollments ( websocket_url, account_id, + app_server_client_name, server_id, environment_id, server_name, updated_at -) VALUES (?, ?, ?, ?, ?, ?) -ON CONFLICT(websocket_url, account_id) DO UPDATE SET +) VALUES (?, ?, ?, ?, ?, ?, ?) +ON CONFLICT(websocket_url, account_id, app_server_client_name) DO UPDATE SET server_id = excluded.server_id, environment_id = excluded.environment_id, server_name = excluded.server_name, updated_at = excluded.updated_at "#, ) - .bind(websocket_url) - .bind(remote_control_account_id_key(account_id)) - .bind(server_id) - .bind(environment_id) - .bind(server_name) + .bind(&enrollment.websocket_url) + .bind(&enrollment.account_id) + .bind(remote_control_app_server_client_name_key( + enrollment.app_server_client_name.as_deref(), + )) + .bind(&enrollment.server_id) + .bind(&enrollment.environment_id) + .bind(&enrollment.server_name) .bind(Utc::now().timestamp()) .execute(self.pool.as_ref()) .await?; @@ -73,16 +100,20 @@ ON CONFLICT(websocket_url, account_id) DO UPDATE SET pub async fn delete_remote_control_enrollment( &self, websocket_url: &str, - account_id: Option<&str>, + account_id: &str, + app_server_client_name: Option<&str>, ) -> anyhow::Result { let result = sqlx::query( r#" DELETE FROM remote_control_enrollments -WHERE websocket_url = ? AND account_id = ? +WHERE websocket_url = ? AND account_id = ? AND app_server_client_name = ? "#, ) .bind(websocket_url) - .bind(remote_control_account_id_key(account_id)) + .bind(account_id) + .bind(remote_control_app_server_client_name_key( + app_server_client_name, + )) .execute(self.pool.as_ref()) .await?; Ok(result.rows_affected()) @@ -91,6 +122,7 @@ WHERE websocket_url = ? AND account_id = ? #[cfg(test)] mod tests { + use super::RemoteControlEnrollmentRecord; use super::StateRuntime; use super::test_support::unique_temp_dir; use pretty_assertions::assert_eq; @@ -103,23 +135,27 @@ mod tests { .expect("initialize runtime"); runtime - .upsert_remote_control_enrollment( - "wss://example.com/backend-api/wham/remote/control/server", - Some("account-a"), - "srv_e_first", - "env_first", - "first-server", - ) + .upsert_remote_control_enrollment(&RemoteControlEnrollmentRecord { + websocket_url: "wss://example.com/backend-api/wham/remote/control/server" + .to_string(), + account_id: "account-a".to_string(), + app_server_client_name: Some("desktop-client".to_string()), + server_id: "srv_e_first".to_string(), + environment_id: "env_first".to_string(), + server_name: "first-server".to_string(), + }) .await .expect("insert first enrollment"); runtime - .upsert_remote_control_enrollment( - "wss://example.com/backend-api/wham/remote/control/server", - Some("account-b"), - "srv_e_second", - "env_second", - "second-server", - ) + .upsert_remote_control_enrollment(&RemoteControlEnrollmentRecord { + websocket_url: "wss://example.com/backend-api/wham/remote/control/server" + .to_string(), + account_id: "account-b".to_string(), + app_server_client_name: Some("desktop-client".to_string()), + server_id: "srv_e_second".to_string(), + environment_id: "env_second".to_string(), + server_name: "second-server".to_string(), + }) .await .expect("insert second enrollment"); @@ -127,26 +163,43 @@ mod tests { runtime .get_remote_control_enrollment( "wss://example.com/backend-api/wham/remote/control/server", - Some("account-a"), + "account-a", + Some("desktop-client"), ) .await .expect("load first enrollment"), - Some(( - "srv_e_first".to_string(), - "env_first".to_string(), - "first-server".to_string() - )) + Some(RemoteControlEnrollmentRecord { + websocket_url: "wss://example.com/backend-api/wham/remote/control/server" + .to_string(), + account_id: "account-a".to_string(), + app_server_client_name: Some("desktop-client".to_string()), + server_id: "srv_e_first".to_string(), + environment_id: "env_first".to_string(), + server_name: "first-server".to_string(), + }) ); assert_eq!( runtime .get_remote_control_enrollment( "wss://example.com/backend-api/wham/remote/control/server", - /*account_id*/ None, + "account-missing", + Some("desktop-client"), ) .await .expect("load missing enrollment"), None ); + assert_eq!( + runtime + .get_remote_control_enrollment( + "wss://example.com/backend-api/wham/remote/control/server", + "account-a", + Some("other-client"), + ) + .await + .expect("load wrong client enrollment"), + None + ); let _ = tokio::fs::remove_dir_all(codex_home).await; } @@ -159,23 +212,27 @@ mod tests { .expect("initialize runtime"); runtime - .upsert_remote_control_enrollment( - "wss://example.com/backend-api/wham/remote/control/server", - /*account_id*/ None, - "srv_e_first", - "env_first", - "first-server", - ) + .upsert_remote_control_enrollment(&RemoteControlEnrollmentRecord { + websocket_url: "wss://example.com/backend-api/wham/remote/control/server" + .to_string(), + account_id: "account-a".to_string(), + app_server_client_name: None, + server_id: "srv_e_first".to_string(), + environment_id: "env_first".to_string(), + server_name: "first-server".to_string(), + }) .await .expect("insert first enrollment"); runtime - .upsert_remote_control_enrollment( - "wss://example.com/backend-api/wham/remote/control/server", - Some("account-a"), - "srv_e_second", - "env_second", - "second-server", - ) + .upsert_remote_control_enrollment(&RemoteControlEnrollmentRecord { + websocket_url: "wss://example.com/backend-api/wham/remote/control/server" + .to_string(), + account_id: "account-b".to_string(), + app_server_client_name: None, + server_id: "srv_e_second".to_string(), + environment_id: "env_second".to_string(), + server_name: "second-server".to_string(), + }) .await .expect("insert second enrollment"); @@ -183,7 +240,8 @@ mod tests { runtime .delete_remote_control_enrollment( "wss://example.com/backend-api/wham/remote/control/server", - /*account_id*/ None, + "account-a", + /*app_server_client_name*/ None, ) .await .expect("delete first enrollment"), @@ -193,7 +251,8 @@ mod tests { runtime .get_remote_control_enrollment( "wss://example.com/backend-api/wham/remote/control/server", - /*account_id*/ None, + "account-a", + /*app_server_client_name*/ None, ) .await .expect("load deleted enrollment"), @@ -203,15 +262,20 @@ mod tests { runtime .get_remote_control_enrollment( "wss://example.com/backend-api/wham/remote/control/server", - Some("account-a"), + "account-b", + /*app_server_client_name*/ None, ) .await .expect("load retained enrollment"), - Some(( - "srv_e_second".to_string(), - "env_second".to_string(), - "second-server".to_string() - )) + Some(RemoteControlEnrollmentRecord { + websocket_url: "wss://example.com/backend-api/wham/remote/control/server" + .to_string(), + account_id: "account-b".to_string(), + app_server_client_name: None, + server_id: "srv_e_second".to_string(), + environment_id: "env_second".to_string(), + server_name: "second-server".to_string(), + }) ); let _ = tokio::fs::remove_dir_all(codex_home).await;