diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 18e9aa434cab..976485c5c5f5 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -1433,9 +1433,11 @@ dependencies = [ "codex-utils-cli", "codex-utils-json-to-toml", "codex-utils-pty", + "codex-utils-rustls-provider", "constant_time_eq", "core_test_support", "futures", + "gethostname", "hmac", "jsonwebtoken", "opentelemetry", @@ -1458,6 +1460,7 @@ dependencies = [ "tracing", "tracing-opentelemetry", "tracing-subscriber", + "url", "uuid", "wiremock", ] diff --git a/codex-rs/app-server/Cargo.toml b/codex-rs/app-server/Cargo.toml index 4db6b81a21bc..5dc3a31485bb 100644 --- a/codex-rs/app-server/Cargo.toml +++ b/codex-rs/app-server/Cargo.toml @@ -57,10 +57,12 @@ codex-state = { workspace = true } codex-tools = { workspace = true } codex-utils-absolute-path = { workspace = true } codex-utils-json-to-toml = { workspace = true } +codex-utils-rustls-provider = { workspace = true } chrono = { workspace = true } clap = { workspace = true, features = ["derive"] } constant_time_eq = { workspace = true } futures = { workspace = true } +gethostname = { workspace = true } hmac = { workspace = true } jsonwebtoken = { workspace = true } owo-colors = { workspace = true, features = ["supports-colors"] } @@ -81,6 +83,7 @@ tokio-util = { workspace = true } tokio-tungstenite = { workspace = true } tracing = { workspace = true, features = ["log"] } tracing-subscriber = { workspace = true, features = ["env-filter", "fmt", "json"] } +url = { workspace = true } uuid = { workspace = true, features = ["serde", "v7"] } [dev-dependencies] diff --git a/codex-rs/app-server/README.md b/codex-rs/app-server/README.md index c24c3fd7e860..3660ea02011b 100644 --- a/codex-rs/app-server/README.md +++ b/codex-rs/app-server/README.md @@ -25,6 +25,7 @@ Supported transports: - stdio (`--listen stdio://`, default): newline-delimited JSON (JSONL) - websocket (`--listen ws://IP:PORT`): one JSON-RPC message per websocket text frame (**experimental / unsupported**) +- off (`--listen off`): do not expose a local transport When running with `--listen ws://IP:PORT`, the same listener also serves basic HTTP health probes: diff --git a/codex-rs/app-server/src/app_server_tracing.rs b/codex-rs/app-server/src/app_server_tracing.rs index 26fe8ca99971..b06a8e52c48d 100644 --- a/codex-rs/app-server/src/app_server_tracing.rs +++ b/codex-rs/app-server/src/app_server_tracing.rs @@ -86,6 +86,7 @@ fn transport_name(transport: AppServerTransport) -> &'static str { match transport { AppServerTransport::Stdio => "stdio", AppServerTransport::WebSocket { .. } => "websocket", + AppServerTransport::Off => "off", } } diff --git a/codex-rs/app-server/src/lib.rs b/codex-rs/app-server/src/lib.rs index 14250c205ebd..5e02cdb94881 100644 --- a/codex-rs/app-server/src/lib.rs +++ b/codex-rs/app-server/src/lib.rs @@ -7,6 +7,8 @@ use codex_core::config::ConfigBuilder; use codex_core::config_loader::CloudRequirementsLoader; use codex_core::config_loader::ConfigLayerStackOrdering; use codex_core::config_loader::LoaderOverrides; +use codex_features::Feature; +use codex_login::AuthManager; use codex_utils_cli::CliConfigOverrides; use std::collections::HashMap; use std::collections::HashSet; @@ -28,6 +30,7 @@ use crate::transport::OutboundConnectionState; use crate::transport::TransportEvent; use crate::transport::auth::policy_from_settings; use crate::transport::route_outgoing_envelope; +use crate::transport::start_remote_control; use crate::transport::start_stdio_connection; use crate::transport::start_websocket_acceptor; use codex_analytics::AppServerRpcTransport; @@ -42,10 +45,10 @@ use codex_core::config_loader::ConfigLoadError; use codex_core::config_loader::TextRange as CoreTextRange; use codex_exec_server::EnvironmentManager; use codex_feedback::CodexFeedback; -use codex_login::AuthManager; use codex_protocol::protocol::SessionSource; use codex_state::log_db; use tokio::sync::mpsc; +use tokio::sync::oneshot; use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use toml::Value as TomlValue; @@ -499,13 +502,13 @@ pub async fn run_main_with_transport( let feedback_layer = feedback.logger_layer(); let feedback_metadata_layer = feedback.metadata_layer(); - let log_db = codex_state::StateRuntime::init( + let state_db = codex_state::StateRuntime::init( config.sqlite_home.clone(), config.model_provider_id.clone(), ) .await - .ok() - .map(log_db::start); + .ok(); + let log_db = state_db.clone().map(log_db::start); let log_db_layer = log_db .clone() .map(|layer| layer.with_filter(Targets::new().with_default(Level::TRACE))); @@ -532,11 +535,18 @@ pub async fn run_main_with_transport( let single_client_mode = matches!(&transport, AppServerTransport::Stdio); let shutdown_when_no_connections = single_client_mode; let graceful_signal_restart_enabled = !single_client_mode; + let mut app_server_client_name_rx = None; match transport { AppServerTransport::Stdio => { - start_stdio_connection(transport_event_tx.clone(), &mut transport_accept_handles) - .await?; + let (stdio_client_name_tx, stdio_client_name_rx) = oneshot::channel::(); + app_server_client_name_rx = Some(stdio_client_name_rx); + start_stdio_connection( + transport_event_tx.clone(), + &mut transport_accept_handles, + stdio_client_name_tx, + ) + .await?; } AppServerTransport::WebSocket { bind_address } => { let accept_handle = start_websocket_acceptor( @@ -548,6 +558,29 @@ pub async fn run_main_with_transport( .await?; transport_accept_handles.push(accept_handle); } + AppServerTransport::Off => {} + } + + let auth_manager = + AuthManager::shared_from_config(&config, /*enable_codex_api_key_env*/ false); + + if config.features.enabled(Feature::RemoteControl) { + let accept_handle = start_remote_control( + config.chatgpt_base_url.clone(), + state_db.clone(), + auth_manager.clone(), + transport_event_tx.clone(), + transport_shutdown_token.clone(), + app_server_client_name_rx, + ) + .await?; + transport_accept_handles.push(accept_handle); + } + if transport_accept_handles.is_empty() { + return Err(std::io::Error::new( + ErrorKind::InvalidInput, + "no transport configured; use --listen or enable remote control", + )); } let outbound_handle = tokio::spawn(async move { @@ -852,7 +885,9 @@ pub async fn run_main_with_transport( fn analytics_rpc_transport(transport: AppServerTransport) -> AppServerRpcTransport { match transport { AppServerTransport::Stdio => AppServerRpcTransport::Stdio, - AppServerTransport::WebSocket { .. } => AppServerRpcTransport::Websocket, + AppServerTransport::WebSocket { .. } | AppServerTransport::Off => { + AppServerRpcTransport::Websocket + } } } diff --git a/codex-rs/app-server/src/main.rs b/codex-rs/app-server/src/main.rs index fa95f973ea59..9a23680fb9ce 100644 --- a/codex-rs/app-server/src/main.rs +++ b/codex-rs/app-server/src/main.rs @@ -16,7 +16,7 @@ const MANAGED_CONFIG_PATH_ENV_VAR: &str = "CODEX_APP_SERVER_MANAGED_CONFIG_PATH" #[derive(Debug, Parser)] struct AppServerArgs { /// Transport endpoint URL. Supported values: `stdio://` (default), - /// `ws://IP:PORT`. + /// `ws://IP:PORT`, `off`. #[arg( long = "listen", value_name = "URL", diff --git a/codex-rs/app-server/src/transport/mod.rs b/codex-rs/app-server/src/transport/mod.rs index fa744a1af552..7e1512a79419 100644 --- a/codex-rs/app-server/src/transport/mod.rs +++ b/codex-rs/app-server/src/transport/mod.rs @@ -17,6 +17,7 @@ use std::str::FromStr; use std::sync::Arc; use std::sync::RwLock; use std::sync::atomic::AtomicBool; +use std::sync::atomic::AtomicU64; use std::sync::atomic::Ordering; use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; @@ -28,9 +29,11 @@ use tracing::warn; /// plenty for an interactive CLI. pub(crate) const CHANNEL_CAPACITY: usize = 128; +mod remote_control; mod stdio; mod websocket; +pub(crate) use remote_control::start_remote_control; pub(crate) use stdio::start_stdio_connection; pub(crate) use websocket::start_websocket_acceptor; @@ -38,6 +41,7 @@ pub(crate) use websocket::start_websocket_acceptor; pub enum AppServerTransport { Stdio, WebSocket { bind_address: SocketAddr }, + Off, } #[derive(Debug, Clone, Eq, PartialEq)] @@ -51,7 +55,7 @@ impl std::fmt::Display for AppServerTransportParseError { match self { AppServerTransportParseError::UnsupportedListenUrl(listen_url) => write!( f, - "unsupported --listen URL `{listen_url}`; expected `stdio://` or `ws://IP:PORT`" + "unsupported --listen URL `{listen_url}`; expected `stdio://`, `ws://IP:PORT`, or `off`" ), AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url) => write!( f, @@ -71,6 +75,10 @@ impl AppServerTransport { return Ok(Self::Stdio); } + if listen_url == "off" { + return Ok(Self::Off); + } + if let Some(socket_addr) = listen_url.strip_prefix("ws://") { let bind_address = socket_addr.parse::().map_err(|_| { AppServerTransportParseError::InvalidWebSocketListenUrl(listen_url.to_string()) @@ -166,6 +174,12 @@ impl OutboundConnectionState { } } +static CONNECTION_ID_COUNTER: AtomicU64 = AtomicU64::new(0); + +fn next_connection_id() -> ConnectionId { + ConnectionId(CONNECTION_ID_COUNTER.fetch_add(1, Ordering::Relaxed)) +} + async fn forward_incoming_message( transport_event_tx: &mpsc::Sender, writer: &mpsc::Sender, @@ -378,8 +392,11 @@ pub(crate) async fn route_outgoing_envelope( #[cfg(test)] mod tests { use super::*; - use crate::error_code::OVERLOADED_ERROR_CODE; use codex_app_server_protocol::ConfigWarningNotification; + use codex_app_server_protocol::JSONRPCNotification; + use codex_app_server_protocol::JSONRPCRequest; + use codex_app_server_protocol::JSONRPCResponse; + use codex_app_server_protocol::RequestId; use codex_app_server_protocol::ServerNotification; use codex_utils_absolute_path::AbsolutePathBuf; use pretty_assertions::assert_eq; @@ -393,41 +410,10 @@ mod tests { } #[test] - fn app_server_transport_parses_stdio_listen_url() { - let transport = AppServerTransport::from_listen_url(AppServerTransport::DEFAULT_LISTEN_URL) - .expect("stdio listen URL should parse"); - assert_eq!(transport, AppServerTransport::Stdio); - } - - #[test] - fn app_server_transport_parses_websocket_listen_url() { - let transport = AppServerTransport::from_listen_url("ws://127.0.0.1:1234") - .expect("websocket listen URL should parse"); - assert_eq!( - transport, - AppServerTransport::WebSocket { - bind_address: "127.0.0.1:1234".parse().expect("valid socket address"), - } - ); - } - - #[test] - fn app_server_transport_rejects_invalid_websocket_listen_url() { - let err = AppServerTransport::from_listen_url("ws://localhost:1234") - .expect_err("hostname bind address should be rejected"); + fn listen_off_parses_as_off_transport() { assert_eq!( - err.to_string(), - "invalid websocket --listen URL `ws://localhost:1234`; expected `ws://IP:PORT`" - ); - } - - #[test] - fn app_server_transport_rejects_unsupported_listen_url() { - let err = AppServerTransport::from_listen_url("http://127.0.0.1:1234") - .expect_err("unsupported scheme should fail"); - assert_eq!( - err.to_string(), - "unsupported --listen URL `http://127.0.0.1:1234`; expected `stdio://` or `ws://IP:PORT`" + AppServerTransport::from_listen_url("off"), + Ok(AppServerTransport::Off) ); } @@ -437,11 +423,10 @@ mod tests { let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1); let (writer_tx, mut writer_rx) = mpsc::channel(1); - let first_message = - JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification { - method: "initialized".to_string(), - params: None, - }); + let first_message = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); transport_event_tx .send(TransportEvent::IncomingMessage { connection_id, @@ -450,8 +435,8 @@ mod tests { .await .expect("queue should accept first message"); - let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { - id: codex_app_server_protocol::RequestId::Integer(7), + let request = JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(7), method: "config/read".to_string(), params: Some(json!({ "includeLayers": false })), trace: None, @@ -499,11 +484,10 @@ mod tests { let (transport_event_tx, mut transport_event_rx) = mpsc::channel(1); let (writer_tx, _writer_rx) = mpsc::channel(1); - let first_message = - JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification { - method: "initialized".to_string(), - params: None, - }); + let first_message = JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); transport_event_tx .send(TransportEvent::IncomingMessage { connection_id, @@ -512,8 +496,8 @@ mod tests { .await .expect("queue should accept first message"); - let response = JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse { - id: codex_app_server_protocol::RequestId::Integer(7), + let response = JSONRPCMessage::Response(JSONRPCResponse { + id: RequestId::Integer(7), result: json!({"ok": true}), }); let transport_event_tx_for_enqueue = transport_event_tx.clone(); @@ -553,11 +537,10 @@ mod tests { match forwarded_event { TransportEvent::IncomingMessage { connection_id: queued_connection_id, - message: - JSONRPCMessage::Response(codex_app_server_protocol::JSONRPCResponse { id, result }), + message: JSONRPCMessage::Response(JSONRPCResponse { id, result }), } => { assert_eq!(queued_connection_id, connection_id); - assert_eq!(id, codex_app_server_protocol::RequestId::Integer(7)); + assert_eq!(id, RequestId::Integer(7)); assert_eq!(result, json!({"ok": true})); } _ => panic!("expected forwarded response message"), @@ -573,12 +556,10 @@ mod tests { transport_event_tx .send(TransportEvent::IncomingMessage { connection_id, - message: JSONRPCMessage::Notification( - codex_app_server_protocol::JSONRPCNotification { - method: "initialized".to_string(), - params: None, - }, - ), + message: JSONRPCMessage::Notification(JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }), }) .await .expect("transport queue should accept first message"); @@ -597,15 +578,15 @@ mod tests { .await .expect("writer queue should accept first message"); - let request = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { - id: codex_app_server_protocol::RequestId::Integer(7), + let request = JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(7), method: "config/read".to_string(), params: Some(json!({ "includeLayers": false })), trace: None, }); - let enqueue_result = tokio::time::timeout( - std::time::Duration::from_millis(100), + let enqueue_result = timeout( + Duration::from_millis(100), enqueue_incoming_message(&transport_event_tx, &writer_tx, connection_id, request), ) .await @@ -781,7 +762,7 @@ mod tests { OutgoingEnvelope::ToConnection { connection_id, message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval { - request_id: codex_app_server_protocol::RequestId::Integer(1), + request_id: RequestId::Integer(1), params: codex_app_server_protocol::CommandExecutionRequestApprovalParams { thread_id: "thr_123".to_string(), turn_id: "turn_123".to_string(), @@ -843,7 +824,7 @@ mod tests { OutgoingEnvelope::ToConnection { connection_id, message: OutgoingMessage::Request(ServerRequest::CommandExecutionRequestApproval { - request_id: codex_app_server_protocol::RequestId::Integer(1), + request_id: RequestId::Integer(1), params: codex_app_server_protocol::CommandExecutionRequestApprovalParams { thread_id: "thr_123".to_string(), turn_id: "turn_123".to_string(), diff --git a/codex-rs/app-server/src/transport/remote_control/client_tracker.rs b/codex-rs/app-server/src/transport/remote_control/client_tracker.rs new file mode 100644 index 000000000000..fa9a208ade53 --- /dev/null +++ b/codex-rs/app-server/src/transport/remote_control/client_tracker.rs @@ -0,0 +1,568 @@ +use super::CHANNEL_CAPACITY; +use super::TransportEvent; +use super::next_connection_id; +use super::protocol::ClientEnvelope; +pub use super::protocol::ClientEvent; +pub use super::protocol::ClientId; +use super::protocol::PongStatus; +use super::protocol::ServerEvent; +use super::protocol::StreamId; +use crate::outgoing_message::ConnectionId; +use crate::outgoing_message::QueuedOutgoingMessage; +use crate::transport::remote_control::QueuedServerEnvelope; +use codex_app_server_protocol::JSONRPCMessage; +use std::collections::HashMap; +use tokio::sync::mpsc; +use tokio::sync::watch; +use tokio::task::JoinSet; +use tokio::time::Duration; +use tokio::time::Instant; +use tokio_util::sync::CancellationToken; + +const REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT: Duration = Duration::from_secs(10 * 60); +pub(crate) const REMOTE_CONTROL_IDLE_SWEEP_INTERVAL: Duration = Duration::from_secs(30); + +#[derive(Debug)] +pub(crate) struct Stopped; + +struct ClientState { + connection_id: ConnectionId, + disconnect_token: CancellationToken, + last_activity_at: Instant, + last_inbound_seq_id: Option, + status_tx: watch::Sender, +} + +pub(crate) struct ClientTracker { + clients: HashMap<(ClientId, StreamId), ClientState>, + legacy_stream_ids: HashMap, + join_set: JoinSet<(ClientId, StreamId)>, + server_event_tx: mpsc::Sender, + transport_event_tx: mpsc::Sender, + shutdown_token: CancellationToken, +} + +impl ClientTracker { + pub(crate) fn new( + server_event_tx: mpsc::Sender, + transport_event_tx: mpsc::Sender, + shutdown_token: &CancellationToken, + ) -> Self { + Self { + clients: HashMap::new(), + legacy_stream_ids: HashMap::new(), + join_set: JoinSet::new(), + server_event_tx, + transport_event_tx, + shutdown_token: shutdown_token.child_token(), + } + } + + pub(crate) async fn bookkeep_join_set(&mut self) -> Option<(ClientId, StreamId)> { + while let Some(join_result) = self.join_set.join_next().await { + let Ok(client_key) = join_result else { + continue; + }; + return Some(client_key); + } + futures::future::pending().await + } + + pub(crate) async fn shutdown(&mut self) { + self.shutdown_token.cancel(); + + while let Some(client_key) = self.clients.keys().next().cloned() { + let _ = self.close_client(&client_key).await; + } + + self.drain_join_set().await; + } + + async fn drain_join_set(&mut self) { + while self.join_set.join_next().await.is_some() {} + } + + pub(crate) async fn handle_message( + &mut self, + client_envelope: ClientEnvelope, + ) -> Result<(), Stopped> { + let ClientEnvelope { + client_id, + event, + stream_id, + seq_id, + cursor: _, + } = client_envelope; + let is_legacy_stream_id = stream_id.is_none(); + let is_initialize = matches!(&event, ClientEvent::ClientMessage { message } if remote_control_message_starts_connection(message)); + let stream_id = match stream_id { + Some(stream_id) => stream_id, + None if is_initialize => { + // TODO(ruslan): delete this fallback once all clients are updated to send stream_id. + self.legacy_stream_ids + .remove(&client_id) + .unwrap_or_else(StreamId::new_random) + } + None => self + .legacy_stream_ids + .get(&client_id) + .cloned() + .unwrap_or_else(|| { + if matches!(&event, ClientEvent::Ping) { + StreamId::new_random() + } else { + StreamId(String::new()) + } + }), + }; + if stream_id.0.is_empty() { + return Ok(()); + } + let client_key = (client_id.clone(), stream_id.clone()); + match event { + ClientEvent::ClientMessage { message } => { + if let Some(seq_id) = seq_id + && let Some(client) = self.clients.get(&client_key) + && client + .last_inbound_seq_id + .is_some_and(|last_seq_id| last_seq_id >= seq_id) + && !is_initialize + { + return Ok(()); + } + + if is_initialize && self.clients.contains_key(&client_key) { + self.close_client(&client_key).await?; + } + + if let Some(connection_id) = self.clients.get_mut(&client_key).map(|client| { + client.last_activity_at = Instant::now(); + if let Some(seq_id) = seq_id { + client.last_inbound_seq_id = Some(seq_id); + } + client.connection_id + }) { + self.send_transport_event(TransportEvent::IncomingMessage { + connection_id, + message, + }) + .await?; + return Ok(()); + } + + if !is_initialize { + return Ok(()); + } + + let connection_id = next_connection_id(); + let (writer_tx, writer_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let disconnect_token = self.shutdown_token.child_token(); + self.send_transport_event(TransportEvent::ConnectionOpened { + connection_id, + writer: writer_tx, + disconnect_sender: Some(disconnect_token.clone()), + }) + .await?; + + let (status_tx, status_rx) = watch::channel(PongStatus::Active); + self.join_set.spawn(Self::run_client_outbound( + client_id.clone(), + stream_id.clone(), + self.server_event_tx.clone(), + writer_rx, + status_rx, + disconnect_token.clone(), + )); + self.clients.insert( + client_key, + ClientState { + connection_id, + disconnect_token, + last_activity_at: Instant::now(), + last_inbound_seq_id: if is_legacy_stream_id { None } else { seq_id }, + status_tx, + }, + ); + if is_legacy_stream_id { + self.legacy_stream_ids.insert(client_id.clone(), stream_id); + } + self.send_transport_event(TransportEvent::IncomingMessage { + connection_id, + message, + }) + .await + } + ClientEvent::Ack => Ok(()), + ClientEvent::Ping => { + if let Some(client) = self.clients.get_mut(&client_key) { + client.last_activity_at = Instant::now(); + let _ = client.status_tx.send(PongStatus::Active); + return Ok(()); + } + + let server_event_tx = self.server_event_tx.clone(); + tokio::spawn(async move { + let server_envelope = QueuedServerEnvelope { + event: ServerEvent::Pong { + status: PongStatus::Unknown, + }, + client_id, + stream_id, + write_complete_tx: None, + }; + let _ = server_event_tx.send(server_envelope).await; + }); + Ok(()) + } + ClientEvent::ClientClosed => self.close_client(&client_key).await, + } + } + + async fn run_client_outbound( + client_id: ClientId, + stream_id: StreamId, + server_event_tx: mpsc::Sender, + mut writer_rx: mpsc::Receiver, + mut status_rx: watch::Receiver, + disconnect_token: CancellationToken, + ) -> (ClientId, StreamId) { + loop { + let (event, write_complete_tx) = tokio::select! { + _ = disconnect_token.cancelled() => { + break; + } + queued_message = writer_rx.recv() => { + let Some(queued_message) = queued_message else { + break; + }; + let event = ServerEvent::ServerMessage { + message: Box::new(queued_message.message), + }; + (event, queued_message.write_complete_tx) + } + changed = status_rx.changed() => { + if changed.is_err() { + break; + } + let event = ServerEvent::Pong { status: status_rx.borrow().clone() }; + (event, None) + } + }; + let send_result = tokio::select! { + _ = disconnect_token.cancelled() => { + break; + } + send_result = server_event_tx.send(QueuedServerEnvelope { + event, + client_id: client_id.clone(), + stream_id: stream_id.clone(), + write_complete_tx, + }) => send_result, + }; + if send_result.is_err() { + break; + } + } + (client_id, stream_id) + } + + pub(crate) async fn close_expired_clients( + &mut self, + ) -> Result, Stopped> { + let now = Instant::now(); + let expired_client_ids: Vec<(ClientId, StreamId)> = self + .clients + .iter() + .filter_map(|(client_key, client)| { + (!remote_control_client_is_alive(client, now)).then_some(client_key.clone()) + }) + .collect(); + for client_key in &expired_client_ids { + self.close_client(client_key).await?; + } + Ok(expired_client_ids) + } + + pub(super) async fn close_client( + &mut self, + client_key: &(ClientId, StreamId), + ) -> Result<(), Stopped> { + let Some(client) = self.clients.remove(client_key) else { + return Ok(()); + }; + if self + .legacy_stream_ids + .get(&client_key.0) + .is_some_and(|stream_id| stream_id == &client_key.1) + { + self.legacy_stream_ids.remove(&client_key.0); + } + client.disconnect_token.cancel(); + self.send_transport_event(TransportEvent::ConnectionClosed { + connection_id: client.connection_id, + }) + .await + } + + async fn send_transport_event(&self, event: TransportEvent) -> Result<(), Stopped> { + self.transport_event_tx + .send(event) + .await + .map_err(|_| Stopped) + } +} + +fn remote_control_message_starts_connection(message: &JSONRPCMessage) -> bool { + matches!( + message, + JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { method, .. }) + if method == "initialize" + ) +} + +fn remote_control_client_is_alive(client: &ClientState, now: Instant) -> bool { + now.duration_since(client.last_activity_at) < REMOTE_CONTROL_CLIENT_IDLE_TIMEOUT +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::outgoing_message::OutgoingMessage; + use crate::transport::remote_control::protocol::ClientEnvelope; + use crate::transport::remote_control::protocol::ClientEvent; + use codex_app_server_protocol::ConfigWarningNotification; + use codex_app_server_protocol::JSONRPCRequest; + use codex_app_server_protocol::RequestId; + use codex_app_server_protocol::ServerNotification; + use pretty_assertions::assert_eq; + use serde_json::json; + use tokio::time::timeout; + + fn initialize_envelope(client_id: &str) -> ClientEnvelope { + initialize_envelope_with_stream_id(client_id, /*stream_id*/ None) + } + + fn initialize_envelope_with_stream_id( + client_id: &str, + stream_id: Option<&str>, + ) -> ClientEnvelope { + ClientEnvelope { + event: ClientEvent::ClientMessage { + message: JSONRPCMessage::Request(JSONRPCRequest { + id: RequestId::Integer(1), + method: "initialize".to_string(), + params: Some(json!({ + "clientInfo": { + "name": "remote-test-client", + "version": "0.1.0" + } + })), + trace: None, + }), + }, + client_id: ClientId(client_id.to_string()), + stream_id: stream_id.map(|stream_id| StreamId(stream_id.to_string())), + seq_id: Some(0), + cursor: None, + } + } + + #[tokio::test] + async fn cancelled_outbound_task_emits_connection_closed() { + let (server_event_tx, _server_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let mut client_tracker = + ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token); + + client_tracker + .handle_message(initialize_envelope("client-1")) + .await + .expect("initialize should open client"); + + let (connection_id, disconnect_sender) = match transport_event_rx + .recv() + .await + .expect("connection opened should be sent") + { + TransportEvent::ConnectionOpened { + connection_id, + disconnect_sender: Some(disconnect_sender), + .. + } => (connection_id, disconnect_sender), + other => panic!("expected connection opened, got {other:?}"), + }; + match transport_event_rx + .recv() + .await + .expect("initialize should be forwarded") + { + TransportEvent::IncomingMessage { + connection_id: incoming_connection_id, + .. + } => assert_eq!(incoming_connection_id, connection_id), + other => panic!("expected incoming initialize, got {other:?}"), + } + + disconnect_sender.cancel(); + let closed_client_id = timeout(Duration::from_secs(1), client_tracker.bookkeep_join_set()) + .await + .expect("bookkeeping should process the closed task") + .expect("closed task should return client id"); + assert_eq!(closed_client_id.0, ClientId("client-1".to_string())); + client_tracker + .close_client(&closed_client_id) + .await + .expect("closed client should emit connection closed"); + + match transport_event_rx + .recv() + .await + .expect("connection closed should be sent") + { + TransportEvent::ConnectionClosed { + connection_id: closed_connection_id, + } => assert_eq!(closed_connection_id, connection_id), + other => panic!("expected connection closed, got {other:?}"), + } + } + + #[tokio::test] + async fn shutdown_cancels_blocked_outbound_forwarding() { + let (server_event_tx, _server_event_rx) = mpsc::channel(1); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let mut client_tracker = + ClientTracker::new(server_event_tx.clone(), transport_event_tx, &shutdown_token); + + server_event_tx + .send(QueuedServerEnvelope { + event: ServerEvent::Pong { + status: PongStatus::Unknown, + }, + client_id: ClientId("queued-client".to_string()), + stream_id: StreamId("queued-stream".to_string()), + write_complete_tx: None, + }) + .await + .expect("server event queue should accept prefill"); + + client_tracker + .handle_message(initialize_envelope("client-1")) + .await + .expect("initialize should open client"); + + let writer = match transport_event_rx + .recv() + .await + .expect("connection opened should be sent") + { + TransportEvent::ConnectionOpened { writer, .. } => writer, + other => panic!("expected connection opened, got {other:?}"), + }; + let _ = transport_event_rx + .recv() + .await + .expect("initialize should be forwarded"); + + writer + .send(QueuedOutgoingMessage::new( + OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { + summary: "test".to_string(), + details: None, + path: None, + range: None, + }, + )), + )) + .await + .expect("writer should accept queued message"); + + timeout(Duration::from_secs(1), client_tracker.shutdown()) + .await + .expect("shutdown should not hang on blocked server forwarding"); + } + + #[tokio::test] + async fn initialize_with_new_stream_id_opens_new_connection_for_same_client() { + let (server_event_tx, _server_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let mut client_tracker = + ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token); + + client_tracker + .handle_message(initialize_envelope_with_stream_id( + "client-1", + Some("stream-1"), + )) + .await + .expect("first initialize should open client"); + let first_connection_id = match transport_event_rx.recv().await.expect("open event") { + TransportEvent::ConnectionOpened { connection_id, .. } => connection_id, + other => panic!("expected connection opened, got {other:?}"), + }; + let _ = transport_event_rx.recv().await.expect("initialize event"); + + client_tracker + .handle_message(initialize_envelope_with_stream_id( + "client-1", + Some("stream-2"), + )) + .await + .expect("second initialize should open client"); + let second_connection_id = match transport_event_rx.recv().await.expect("open event") { + TransportEvent::ConnectionOpened { connection_id, .. } => connection_id, + other => panic!("expected connection opened, got {other:?}"), + }; + + assert_ne!(first_connection_id, second_connection_id); + } + + #[tokio::test] + async fn legacy_initialize_without_stream_id_resets_inbound_seq_id() { + let (server_event_tx, _server_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let (transport_event_tx, mut transport_event_rx) = mpsc::channel(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let mut client_tracker = + ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token); + + client_tracker + .handle_message(initialize_envelope("client-1")) + .await + .expect("initialize should open client"); + let connection_id = match transport_event_rx.recv().await.expect("open event") { + TransportEvent::ConnectionOpened { connection_id, .. } => connection_id, + other => panic!("expected connection opened, got {other:?}"), + }; + let _ = transport_event_rx.recv().await.expect("initialize event"); + + client_tracker + .handle_message(ClientEnvelope { + event: ClientEvent::ClientMessage { + message: JSONRPCMessage::Notification( + codex_app_server_protocol::JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }, + ), + }, + client_id: ClientId("client-1".to_string()), + stream_id: None, + seq_id: Some(0), + cursor: None, + }) + .await + .expect("legacy followup should be forwarded"); + + match transport_event_rx.recv().await.expect("followup event") { + TransportEvent::IncomingMessage { + connection_id: incoming_connection_id, + .. + } => assert_eq!(incoming_connection_id, connection_id), + other => panic!("expected incoming message, got {other:?}"), + } + } +} diff --git a/codex-rs/app-server/src/transport/remote_control/enroll.rs b/codex-rs/app-server/src/transport/remote_control/enroll.rs new file mode 100644 index 000000000000..dbe18c8355db --- /dev/null +++ b/codex-rs/app-server/src/transport/remote_control/enroll.rs @@ -0,0 +1,503 @@ +use super::protocol::EnrollRemoteServerRequest; +use super::protocol::EnrollRemoteServerResponse; +use super::protocol::RemoteControlTarget; +use axum::http::HeaderMap; +use codex_login::default_client::build_reqwest_client; +use codex_state::RemoteControlEnrollmentRecord; +use codex_state::StateRuntime; +use gethostname::gethostname; +use std::io; +use std::io::ErrorKind; +use tracing::info; +use tracing::warn; + +const REMOTE_CONTROL_ENROLL_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(30); +const REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES: usize = 4096; + +const REQUEST_ID_HEADER: &str = "x-request-id"; +const OAI_REQUEST_ID_HEADER: &str = "x-oai-request-id"; +const CF_RAY_HEADER: &str = "cf-ray"; +pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id"; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) struct RemoteControlEnrollment { + pub(super) account_id: String, + pub(super) environment_id: String, + pub(super) server_id: String, + pub(super) server_name: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) struct RemoteControlConnectionAuth { + pub(super) bearer_token: String, + pub(super) account_id: String, +} + +pub(super) async fn load_persisted_remote_control_enrollment( + state_db: Option<&StateRuntime>, + remote_control_target: &RemoteControlTarget, + account_id: &str, + app_server_client_name: Option<&str>, +) -> Option { + let Some(state_db) = state_db else { + info!( + "remote control enrollment cache unavailable because sqlite state db is disabled: websocket_url={}, account_id={}, app_server_client_name={:?}", + remote_control_target.websocket_url, account_id, app_server_client_name + ); + return None; + }; + let enrollment = match state_db + .get_remote_control_enrollment( + &remote_control_target.websocket_url, + account_id, + app_server_client_name, + ) + .await + { + Ok(enrollment) => enrollment, + Err(err) => { + warn!( + "failed to load persisted remote control enrollment: websocket_url={}, account_id={}, app_server_client_name={:?}, err={err}", + remote_control_target.websocket_url, account_id, app_server_client_name + ); + return None; + } + }; + + match enrollment { + Some(enrollment) => { + info!( + "reusing persisted remote control enrollment: websocket_url={}, account_id={}, app_server_client_name={:?}, server_id={}, environment_id={}", + remote_control_target.websocket_url, + account_id, + app_server_client_name, + enrollment.server_id, + enrollment.environment_id + ); + Some(RemoteControlEnrollment { + account_id: enrollment.account_id, + environment_id: enrollment.environment_id, + server_id: enrollment.server_id, + server_name: enrollment.server_name, + }) + } + None => { + info!( + "no persisted remote control enrollment found: websocket_url={}, account_id={}, app_server_client_name={:?}", + remote_control_target.websocket_url, account_id, app_server_client_name + ); + None + } + } +} + +pub(super) async fn update_persisted_remote_control_enrollment( + state_db: Option<&StateRuntime>, + remote_control_target: &RemoteControlTarget, + account_id: &str, + app_server_client_name: Option<&str>, + enrollment: Option<&RemoteControlEnrollment>, +) -> io::Result<()> { + let Some(state_db) = state_db else { + info!( + "skipping remote control enrollment persistence because sqlite state db is disabled: websocket_url={}, account_id={}, app_server_client_name={:?}, has_enrollment={}", + remote_control_target.websocket_url, + account_id, + app_server_client_name, + enrollment.is_some() + ); + return Ok(()); + }; + if let &Some(enrollment) = &enrollment + && enrollment.account_id != account_id + { + return Err(io::Error::other(format!( + "enrollment account_id does not match expected account_id `{account_id}`" + ))); + } + + if let Some(enrollment) = enrollment { + state_db + .upsert_remote_control_enrollment(&RemoteControlEnrollmentRecord { + websocket_url: remote_control_target.websocket_url.clone(), + account_id: account_id.to_string(), + app_server_client_name: app_server_client_name.map(str::to_string), + server_id: enrollment.server_id.clone(), + environment_id: enrollment.environment_id.clone(), + server_name: enrollment.server_name.clone(), + }) + .await + .map_err(io::Error::other)?; + info!( + "persisted remote control enrollment: websocket_url={}, account_id={}, app_server_client_name={:?}, server_id={}, environment_id={}", + remote_control_target.websocket_url, + account_id, + app_server_client_name, + enrollment.server_id, + enrollment.environment_id + ); + Ok(()) + } else { + let rows_affected = state_db + .delete_remote_control_enrollment( + &remote_control_target.websocket_url, + account_id, + app_server_client_name, + ) + .await + .map_err(io::Error::other)?; + info!( + "cleared persisted remote control enrollment: websocket_url={}, account_id={}, app_server_client_name={:?}, rows_affected={rows_affected}", + remote_control_target.websocket_url, account_id, app_server_client_name + ); + Ok(()) + } +} + +pub(crate) fn preview_remote_control_response_body(body: &[u8]) -> String { + let body = String::from_utf8_lossy(body); + let trimmed = body.trim(); + if trimmed.is_empty() { + return "".to_string(); + } + if trimmed.len() <= REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES { + return trimmed.to_string(); + } + + let mut cut = REMOTE_CONTROL_RESPONSE_BODY_MAX_BYTES; + while !trimmed.is_char_boundary(cut) { + cut = cut.saturating_sub(1); + } + let mut truncated = trimmed[..cut].to_string(); + truncated.push_str("..."); + truncated +} + +pub(crate) fn format_headers(headers: &HeaderMap) -> String { + let request_id_str = headers + .get(REQUEST_ID_HEADER) + .or_else(|| headers.get(OAI_REQUEST_ID_HEADER)) + .map(|value| value.to_str().unwrap_or("").to_owned()) + .unwrap_or_else(|| "".to_owned()); + let cf_ray_str = headers + .get(CF_RAY_HEADER) + .map(|value| value.to_str().unwrap_or("").to_owned()) + .unwrap_or_else(|| "".to_owned()); + format!("request-id: {request_id_str}, cf-ray: {cf_ray_str}") +} + +pub(super) async fn enroll_remote_control_server( + remote_control_target: &RemoteControlTarget, + auth: &RemoteControlConnectionAuth, +) -> io::Result { + let enroll_url = &remote_control_target.enroll_url; + let server_name = gethostname().to_string_lossy().trim().to_string(); + let request = EnrollRemoteServerRequest { + name: server_name.clone(), + os: std::env::consts::OS, + arch: std::env::consts::ARCH, + app_server_version: env!("CARGO_PKG_VERSION"), + }; + let client = build_reqwest_client(); + let http_request = client + .post(enroll_url) + .timeout(REMOTE_CONTROL_ENROLL_TIMEOUT) + .bearer_auth(&auth.bearer_token) + .header(REMOTE_CONTROL_ACCOUNT_ID_HEADER, &auth.account_id) + .json(&request); + + let response = http_request.send().await.map_err(|err| { + io::Error::other(format!( + "failed to enroll remote control server at `{enroll_url}`: {err}" + )) + })?; + let headers = response.headers().clone(); + let status = response.status(); + let body = response.bytes().await.map_err(|err| { + io::Error::other(format!( + "failed to read remote control enrollment response from `{enroll_url}`: {err}" + )) + })?; + let body_preview = preview_remote_control_response_body(&body); + if !status.is_success() { + let headers_str = format_headers(&headers); + let error_kind = if matches!(status.as_u16(), 401 | 403) { + ErrorKind::PermissionDenied + } else { + ErrorKind::Other + }; + return Err(io::Error::new( + error_kind, + format!( + "remote control server enrollment failed at `{enroll_url}`: HTTP {status}, {headers_str}, body: {body_preview}" + ), + )); + } + + let enrollment = serde_json::from_slice::(&body).map_err(|err| { + let headers_str = format_headers(&headers); + io::Error::other(format!( + "failed to parse remote control enrollment response from `{enroll_url}`: HTTP {status}, {headers_str}, body: {body_preview}, decode error: {err}" + )) + })?; + + Ok(RemoteControlEnrollment { + account_id: auth.account_id.clone(), + environment_id: enrollment.environment_id, + server_id: enrollment.server_id, + server_name, + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::transport::remote_control::protocol::normalize_remote_control_url; + use codex_state::StateRuntime; + use pretty_assertions::assert_eq; + use serde_json::json; + use std::sync::Arc; + use tempfile::TempDir; + use tokio::io::AsyncBufReadExt; + use tokio::io::AsyncWriteExt; + use tokio::io::BufReader; + use tokio::net::TcpListener; + use tokio::net::TcpStream; + use tokio::time::Duration; + use tokio::time::timeout; + + async fn remote_control_state_runtime(codex_home: &TempDir) -> Arc { + StateRuntime::init(codex_home.path().to_path_buf(), "test-provider".to_string()) + .await + .expect("state runtime should initialize") + } + + #[tokio::test] + async fn persisted_remote_control_enrollment_round_trips_by_target_and_account() { + let codex_home = TempDir::new().expect("temp dir should create"); + let state_db = remote_control_state_runtime(&codex_home).await; + let first_target = normalize_remote_control_url("https://chatgpt.com/remote/control") + .expect("first target should parse"); + let second_target = + normalize_remote_control_url("https://api.chatgpt-staging.com/other/control") + .expect("second target should parse"); + let first_enrollment = RemoteControlEnrollment { + account_id: "account-a".to_string(), + environment_id: "env_first".to_string(), + server_id: "srv_e_first".to_string(), + server_name: "first-server".to_string(), + }; + let second_enrollment = RemoteControlEnrollment { + account_id: "account-a".to_string(), + environment_id: "env_second".to_string(), + server_id: "srv_e_second".to_string(), + server_name: "second-server".to_string(), + }; + + update_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &first_target, + "account-a", + Some("desktop-client"), + Some(&first_enrollment), + ) + .await + .expect("first enrollment should persist"); + update_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &second_target, + "account-a", + Some("desktop-client"), + Some(&second_enrollment), + ) + .await + .expect("second enrollment should persist"); + + assert_eq!( + load_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &first_target, + "account-a", + Some("desktop-client"), + ) + .await, + Some(first_enrollment.clone()) + ); + assert_eq!( + load_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &first_target, + "account-b", + Some("desktop-client"), + ) + .await, + None + ); + assert_eq!( + load_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &second_target, + "account-a", + Some("desktop-client"), + ) + .await, + Some(second_enrollment) + ); + } + + #[tokio::test] + async fn clearing_persisted_remote_control_enrollment_removes_only_matching_entry() { + let codex_home = TempDir::new().expect("temp dir should create"); + let state_db = remote_control_state_runtime(&codex_home).await; + let first_target = normalize_remote_control_url("https://chatgpt.com/remote/control") + .expect("first target should parse"); + let second_target = + normalize_remote_control_url("https://api.chatgpt-staging.com/other/control") + .expect("second target should parse"); + let first_enrollment = RemoteControlEnrollment { + account_id: "account-a".to_string(), + environment_id: "env_first".to_string(), + server_id: "srv_e_first".to_string(), + server_name: "first-server".to_string(), + }; + let second_enrollment = RemoteControlEnrollment { + account_id: "account-a".to_string(), + environment_id: "env_second".to_string(), + server_id: "srv_e_second".to_string(), + server_name: "second-server".to_string(), + }; + + update_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &first_target, + "account-a", + /*app_server_client_name*/ None, + Some(&first_enrollment), + ) + .await + .expect("first enrollment should persist"); + update_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &second_target, + "account-a", + /*app_server_client_name*/ None, + Some(&second_enrollment), + ) + .await + .expect("second enrollment should persist"); + + update_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &first_target, + "account-a", + /*app_server_client_name*/ None, + /*enrollment*/ None, + ) + .await + .expect("matching enrollment should clear"); + + assert_eq!( + load_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &first_target, + "account-a", + /*app_server_client_name*/ None, + ) + .await, + None + ); + assert_eq!( + load_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &second_target, + "account-a", + /*app_server_client_name*/ None, + ) + .await, + Some(second_enrollment) + ); + } + + #[tokio::test] + async fn enroll_remote_control_server_parse_failure_includes_response_body() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = format!( + "http://127.0.0.1:{}/backend-api/", + listener + .local_addr() + .expect("listener should have a local addr") + .port() + ); + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let enroll_url = remote_control_target.enroll_url.clone(); + let response_body = json!({ + "error": "not enrolled", + }); + let expected_body = response_body.to_string(); + let server_task = tokio::spawn(async move { + let stream = accept_http_request(&listener).await; + respond_with_json(stream, response_body).await; + }); + + let err = enroll_remote_control_server( + &remote_control_target, + &RemoteControlConnectionAuth { + bearer_token: "Access Token".to_string(), + account_id: "account_id".to_string(), + }, + ) + .await + .expect_err("invalid response should fail to parse"); + + server_task.await.expect("server task should succeed"); + assert_eq!( + err.to_string(), + format!( + "failed to parse remote control enrollment response from `{enroll_url}`: HTTP 200 OK, request-id: , cf-ray: , body: {expected_body}, decode error: missing field `server_id` at line 1 column {}", + expected_body.len() + ) + ); + } + + async fn accept_http_request(listener: &TcpListener) -> TcpStream { + let (stream, _) = timeout(Duration::from_secs(5), listener.accept()) + .await + .expect("HTTP request should arrive in time") + .expect("listener accept should succeed"); + let mut reader = BufReader::new(stream); + + let mut request_line = String::new(); + reader + .read_line(&mut request_line) + .await + .expect("request line should read"); + loop { + let mut line = String::new(); + reader + .read_line(&mut line) + .await + .expect("header line should read"); + if line == "\r\n" { + break; + } + } + + reader.into_inner() + } + + async fn respond_with_json(mut stream: TcpStream, body: serde_json::Value) { + let body = body.to_string(); + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}", + body.len() + ); + stream + .write_all(response.as_bytes()) + .await + .expect("response should write"); + stream.flush().await.expect("response should flush"); + } +} diff --git a/codex-rs/app-server/src/transport/remote_control/mod.rs b/codex-rs/app-server/src/transport/remote_control/mod.rs new file mode 100644 index 000000000000..6d9d65e8a313 --- /dev/null +++ b/codex-rs/app-server/src/transport/remote_control/mod.rs @@ -0,0 +1,67 @@ +mod client_tracker; +mod enroll; +mod protocol; +mod websocket; + +use crate::transport::remote_control::websocket::RemoteControlWebsocket; +use crate::transport::remote_control::websocket::load_remote_control_auth; + +pub use self::protocol::ClientId; +use self::protocol::ServerEvent; +use self::protocol::StreamId; +use self::protocol::normalize_remote_control_url; +use super::CHANNEL_CAPACITY; +use super::TransportEvent; +use super::next_connection_id; +use codex_login::AuthManager; +use codex_state::StateRuntime; +use std::io; +use std::sync::Arc; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::task::JoinHandle; +use tokio_util::sync::CancellationToken; + +pub(super) struct QueuedServerEnvelope { + pub(super) event: ServerEvent, + pub(super) client_id: ClientId, + pub(super) stream_id: StreamId, + pub(super) write_complete_tx: Option>, +} + +pub(crate) async fn start_remote_control( + remote_control_url: String, + state_db: Option>, + auth_manager: Arc, + transport_event_tx: mpsc::Sender, + shutdown_token: CancellationToken, + app_server_client_name_rx: Option>, +) -> io::Result> { + let remote_control_target = normalize_remote_control_url(&remote_control_url)?; + validate_remote_control_auth(&auth_manager).await?; + + Ok(tokio::spawn(async move { + RemoteControlWebsocket::new( + remote_control_target, + state_db, + auth_manager, + transport_event_tx, + shutdown_token, + ) + .run(app_server_client_name_rx) + .await; + })) +} + +pub(crate) async fn validate_remote_control_auth( + auth_manager: &Arc, +) -> io::Result<()> { + match load_remote_control_auth(auth_manager).await { + Ok(_) => Ok(()), + Err(err) if err.kind() == io::ErrorKind::WouldBlock => Ok(()), + Err(err) => Err(err), + } +} + +#[cfg(test)] +mod tests; diff --git a/codex-rs/app-server/src/transport/remote_control/protocol.rs b/codex-rs/app-server/src/transport/remote_control/protocol.rs new file mode 100644 index 000000000000..857855f2a08d --- /dev/null +++ b/codex-rs/app-server/src/transport/remote_control/protocol.rs @@ -0,0 +1,252 @@ +use crate::outgoing_message::OutgoingMessage; +use codex_app_server_protocol::JSONRPCMessage; +use serde::Deserialize; +use serde::Serialize; +use std::io; +use std::io::ErrorKind; +use url::Host; +use url::Url; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub(super) struct RemoteControlTarget { + pub(super) websocket_url: String, + pub(super) enroll_url: String, +} + +#[derive(Debug, Serialize)] +pub(super) struct EnrollRemoteServerRequest { + pub(super) name: String, + pub(super) os: &'static str, + pub(super) arch: &'static str, + pub(super) app_server_version: &'static str, +} + +#[derive(Debug, Deserialize)] +pub(super) struct EnrollRemoteServerResponse { + pub(super) server_id: String, + pub(super) environment_id: String, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(transparent)] +pub struct ClientId(pub String); + +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(transparent)] +pub struct StreamId(pub String); + +impl StreamId { + pub fn new_random() -> Self { + Self(uuid::Uuid::now_v7().to_string()) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ClientEvent { + ClientMessage { + message: JSONRPCMessage, + }, + /// Backend-generated acknowledgement for all server envelopes addressed to + /// `client_id` whose envelope `seq_id` is less than or equal to this ack's + /// `seq_id`. This cursor is client-scoped, not stream-scoped, so receivers + /// must not use `stream_id` to partition acks. + Ack, + Ping, + ClientClosed, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub(crate) struct ClientEnvelope { + #[serde(flatten)] + pub(crate) event: ClientEvent, + #[serde(rename = "client_id")] + pub(crate) client_id: ClientId, + #[serde(rename = "stream_id", skip_serializing_if = "Option::is_none")] + pub(crate) stream_id: Option, + /// For `Ack`, this is the backend-generated per-client cursor over + /// `ServerEnvelope.seq_id`. + #[serde(rename = "seq_id", skip_serializing_if = "Option::is_none")] + pub(crate) seq_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) cursor: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum PongStatus { + Active, + Unknown, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ServerEvent { + ServerMessage { + message: Box, + }, + #[allow(dead_code)] + Ack, + Pong { + status: PongStatus, + }, +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "snake_case")] +pub(crate) struct ServerEnvelope { + #[serde(flatten)] + pub(crate) event: ServerEvent, + #[serde(rename = "client_id")] + pub(crate) client_id: ClientId, + #[serde(rename = "stream_id")] + pub(crate) stream_id: StreamId, + #[serde(rename = "seq_id")] + pub(crate) seq_id: u64, +} + +fn is_allowed_chatgpt_host(host: &Option>) -> bool { + let Some(Host::Domain(host)) = *host else { + return false; + }; + host == "chatgpt.com" + || host == "chatgpt-staging.com" + || host.ends_with(".chatgpt.com") + || host.ends_with(".chatgpt-staging.com") +} + +fn is_localhost(host: &Option>) -> bool { + match host { + Some(Host::Domain("localhost")) => true, + Some(Host::Ipv4(ip)) => ip.is_loopback(), + Some(Host::Ipv6(ip)) => ip.is_loopback(), + _ => false, + } +} + +pub(super) fn normalize_remote_control_url( + remote_control_url: &str, +) -> io::Result { + let map_url_parse_error = |err: url::ParseError| -> io::Error { + io::Error::new( + ErrorKind::InvalidInput, + format!("invalid remote control URL `{remote_control_url}`: {err}"), + ) + }; + let map_scheme_error = |_: ()| -> io::Error { + io::Error::new( + ErrorKind::InvalidInput, + format!( + "invalid remote control URL `{remote_control_url}`; expected HTTPS URL for chatgpt.com or chatgpt-staging.com, or HTTP/HTTPS URL for localhost" + ), + ) + }; + + let mut remote_control_url = Url::parse(remote_control_url).map_err(map_url_parse_error)?; + if !remote_control_url.path().ends_with('/') { + let normalized_path = format!("{}/", remote_control_url.path()); + remote_control_url.set_path(&normalized_path); + } + + let enroll_url = remote_control_url + .join("wham/remote/control/server/enroll") + .map_err(map_url_parse_error)?; + let mut websocket_url = remote_control_url + .join("wham/remote/control/server") + .map_err(map_url_parse_error)?; + let host = enroll_url.host(); + match enroll_url.scheme() { + "https" if is_localhost(&host) || is_allowed_chatgpt_host(&host) => { + websocket_url.set_scheme("wss").map_err(map_scheme_error)?; + } + "http" if is_localhost(&host) => { + websocket_url.set_scheme("ws").map_err(map_scheme_error)?; + } + _ => return Err(map_scheme_error(())), + } + + Ok(RemoteControlTarget { + websocket_url: websocket_url.to_string(), + enroll_url: enroll_url.to_string(), + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use pretty_assertions::assert_eq; + + #[test] + fn normalize_remote_control_url_accepts_chatgpt_https_urls() { + assert_eq!( + normalize_remote_control_url("https://chatgpt.com/backend-api") + .expect("chatgpt.com URL should normalize"), + RemoteControlTarget { + websocket_url: "wss://chatgpt.com/backend-api/wham/remote/control/server" + .to_string(), + enroll_url: "https://chatgpt.com/backend-api/wham/remote/control/server/enroll" + .to_string(), + } + ); + assert_eq!( + normalize_remote_control_url("https://api.chatgpt-staging.com/backend-api") + .expect("chatgpt-staging.com subdomain URL should normalize"), + RemoteControlTarget { + websocket_url: + "wss://api.chatgpt-staging.com/backend-api/wham/remote/control/server" + .to_string(), + enroll_url: + "https://api.chatgpt-staging.com/backend-api/wham/remote/control/server/enroll" + .to_string(), + } + ); + } + + #[test] + fn normalize_remote_control_url_accepts_localhost_urls() { + assert_eq!( + normalize_remote_control_url("http://localhost:8080/backend-api") + .expect("localhost http URL should normalize"), + RemoteControlTarget { + websocket_url: "ws://localhost:8080/backend-api/wham/remote/control/server" + .to_string(), + enroll_url: "http://localhost:8080/backend-api/wham/remote/control/server/enroll" + .to_string(), + } + ); + assert_eq!( + normalize_remote_control_url("https://localhost:8443/backend-api") + .expect("localhost https URL should normalize"), + RemoteControlTarget { + websocket_url: "wss://localhost:8443/backend-api/wham/remote/control/server" + .to_string(), + enroll_url: "https://localhost:8443/backend-api/wham/remote/control/server/enroll" + .to_string(), + } + ); + } + + #[test] + fn normalize_remote_control_url_rejects_unsupported_urls() { + for remote_control_url in [ + "http://chatgpt.com/backend-api", + "http://example.com/backend-api", + "https://example.com/backend-api", + "https://chatgpt.com.evil.com/backend-api", + "https://evilchatgpt.com/backend-api", + "https://foo.localhost/backend-api", + ] { + let err = normalize_remote_control_url(remote_control_url) + .expect_err("unsupported URL should be rejected"); + + assert_eq!(err.kind(), ErrorKind::InvalidInput); + assert_eq!( + err.to_string(), + format!( + "invalid remote control URL `{remote_control_url}`; expected HTTPS URL for chatgpt.com or chatgpt-staging.com, or HTTP/HTTPS URL for localhost" + ) + ); + } + } +} diff --git a/codex-rs/app-server/src/transport/remote_control/tests.rs b/codex-rs/app-server/src/transport/remote_control/tests.rs new file mode 100644 index 000000000000..280949adfc00 --- /dev/null +++ b/codex-rs/app-server/src/transport/remote_control/tests.rs @@ -0,0 +1,1310 @@ +use super::enroll::REMOTE_CONTROL_ACCOUNT_ID_HEADER; +use super::enroll::RemoteControlEnrollment; +use super::enroll::load_persisted_remote_control_enrollment; +use super::enroll::update_persisted_remote_control_enrollment; +use super::protocol::ClientEnvelope; +use super::protocol::ClientEvent; +use super::protocol::ClientId; +use super::protocol::normalize_remote_control_url; +use super::websocket::REMOTE_CONTROL_PROTOCOL_VERSION; +use super::*; +use crate::outgoing_message::OutgoingMessage; +use crate::outgoing_message::QueuedOutgoingMessage; +use crate::transport::CHANNEL_CAPACITY; +use crate::transport::TransportEvent; +use base64::Engine; +use codex_app_server_protocol::AuthMode; +use codex_app_server_protocol::ConfigWarningNotification; +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::ServerNotification; +use codex_core::test_support::auth_manager_from_auth; +use codex_core::test_support::auth_manager_from_auth_with_home; +use codex_login::AuthCredentialsStoreMode; +use codex_login::AuthDotJson; +use codex_login::AuthManager; +use codex_login::CodexAuth; +use codex_login::save_auth; +use codex_login::token_data::TokenData; +use codex_login::token_data::parse_chatgpt_jwt_claims; +use codex_state::StateRuntime; +use futures::SinkExt; +use futures::StreamExt; +use gethostname::gethostname; +use pretty_assertions::assert_eq; +use serde_json::json; +use std::collections::BTreeMap; +use std::sync::Arc; +use tempfile::TempDir; +use tokio::io::AsyncBufReadExt; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; +use tokio::io::BufReader; +use tokio::net::TcpListener; +use tokio::net::TcpStream; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::time::Duration; +use tokio::time::timeout; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::accept_async; +use tokio_tungstenite::accept_hdr_async; +use tokio_tungstenite::tungstenite; +use tokio_util::sync::CancellationToken; + +fn remote_control_auth_manager() -> Arc { + auth_manager_from_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing()) +} + +fn remote_control_auth_manager_with_home(codex_home: &TempDir) -> Arc { + auth_manager_from_auth_with_home( + CodexAuth::create_dummy_chatgpt_auth_for_testing(), + codex_home.path().to_path_buf(), + ) +} + +fn remote_control_auth_dot_json(account_id: Option<&str>) -> AuthDotJson { + #[derive(serde::Serialize)] + struct Header { + alg: &'static str, + typ: &'static str, + } + + let header = Header { + alg: "none", + typ: "JWT", + }; + let payload = serde_json::json!({ + "email": "user@example.com", + "https://api.openai.com/auth": { + "chatgpt_user_id": "user-12345", + "user_id": "user-12345", + "chatgpt_account_id": "account_id" + } + }); + let b64 = |bytes: &[u8]| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes); + let header_b64 = b64(&serde_json::to_vec(&header).expect("header should serialize")); + let payload_b64 = b64(&serde_json::to_vec(&payload).expect("payload should serialize")); + let fake_jwt = format!("{header_b64}.{payload_b64}.sig"); + + AuthDotJson { + auth_mode: Some(AuthMode::Chatgpt), + openai_api_key: None, + tokens: Some(TokenData { + id_token: parse_chatgpt_jwt_claims(&fake_jwt).expect("fake jwt should parse"), + access_token: "Access Token".to_string(), + refresh_token: "refresh-token".to_string(), + account_id: account_id.map(str::to_string), + }), + last_refresh: Some(chrono::Utc::now()), + } +} + +async fn remote_control_state_runtime(codex_home: &TempDir) -> Arc { + StateRuntime::init(codex_home.path().to_path_buf(), "test-provider".to_string()) + .await + .expect("state runtime should initialize") +} + +fn remote_control_url_for_listener(listener: &TcpListener) -> String { + format!( + "http://localhost:{}/backend-api/", + listener + .local_addr() + .expect("listener should have a local addr") + .port() + ) +} + +#[tokio::test] +async fn remote_control_transport_manages_virtual_clients_and_routes_messages() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let codex_home = TempDir::new().expect("temp dir should create"); + let (transport_event_tx, mut transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let remote_handle = start_remote_control( + remote_control_url, + Some(remote_control_state_runtime(&codex_home).await), + remote_control_auth_manager(), + transport_event_tx, + shutdown_token.clone(), + /*app_server_client_name_rx*/ None, + ) + .await + .expect("remote control should start"); + let enroll_request = accept_http_request(&listener).await; + assert_eq!( + enroll_request.request_line, + "POST /backend-api/wham/remote/control/server/enroll HTTP/1.1" + ); + respond_with_json( + enroll_request.stream, + json!({ "server_id": "srv_e_test", "environment_id": "env_test" }), + ) + .await; + let mut websocket = accept_remote_control_connection(&listener).await; + + let client_id = ClientId("client-1".to_string()); + send_client_event( + &mut websocket, + ClientEnvelope { + event: ClientEvent::Ping, + client_id: client_id.clone(), + stream_id: None, + seq_id: None, + cursor: None, + }, + ) + .await; + assert_eq!( + read_server_event(&mut websocket).await, + json!({ + "type": "pong", + "client_id": "client-1", + "seq_id": 0, + "status": "unknown", + }) + ); + + send_client_event( + &mut websocket, + ClientEnvelope { + event: ClientEvent::ClientMessage { + message: JSONRPCMessage::Notification( + codex_app_server_protocol::JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }, + ), + }, + client_id: client_id.clone(), + stream_id: None, + seq_id: Some(0), + cursor: None, + }, + ) + .await; + assert!( + timeout(Duration::from_millis(100), transport_event_rx.recv()) + .await + .is_err(), + "non-initialize client messages should be ignored before connection creation" + ); + + let initialize_message = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { + id: codex_app_server_protocol::RequestId::Integer(1), + method: "initialize".to_string(), + params: Some(json!({ + "clientInfo": { + "name": "remote-test-client", + "version": "0.1.0" + } + })), + trace: None, + }); + send_client_event( + &mut websocket, + ClientEnvelope { + event: ClientEvent::ClientMessage { + message: initialize_message.clone(), + }, + client_id: client_id.clone(), + stream_id: None, + seq_id: Some(1), + cursor: None, + }, + ) + .await; + + let (connection_id, writer) = match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("connection open should arrive in time") + .expect("connection open should exist") + { + TransportEvent::ConnectionOpened { + connection_id, + writer, + .. + } => (connection_id, writer), + other => panic!("expected connection open event, got {other:?}"), + }; + + match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("initialize message should arrive in time") + .expect("initialize message should exist") + { + TransportEvent::IncomingMessage { + connection_id: incoming_connection_id, + message, + } => { + assert_eq!(incoming_connection_id, connection_id); + assert_eq!(message, initialize_message); + } + other => panic!("expected initialize incoming message, got {other:?}"), + } + + let followup_message = + JSONRPCMessage::Notification(codex_app_server_protocol::JSONRPCNotification { + method: "initialized".to_string(), + params: None, + }); + send_client_event( + &mut websocket, + ClientEnvelope { + event: ClientEvent::ClientMessage { + message: followup_message.clone(), + }, + client_id: client_id.clone(), + stream_id: None, + seq_id: Some(2), + cursor: None, + }, + ) + .await; + match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("followup message should arrive in time") + .expect("followup message should exist") + { + TransportEvent::IncomingMessage { + connection_id: incoming_connection_id, + message, + } => { + assert_eq!(incoming_connection_id, connection_id); + assert_eq!(message, followup_message); + } + other => panic!("expected followup incoming message, got {other:?}"), + } + + send_client_event( + &mut websocket, + ClientEnvelope { + event: ClientEvent::Ping, + client_id: client_id.clone(), + stream_id: None, + seq_id: None, + cursor: None, + }, + ) + .await; + assert_eq!( + read_server_event(&mut websocket).await, + json!({ + "type": "pong", + "client_id": "client-1", + "seq_id": 1, + "status": "active", + }) + ); + + writer + .send(QueuedOutgoingMessage::new( + OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { + summary: "test".to_string(), + details: None, + path: None, + range: None, + }, + )), + )) + .await + .expect("remote writer should accept outgoing message"); + assert_eq!( + read_server_event(&mut websocket).await, + json!({ + "type": "server_message", + "client_id": "client-1", + "seq_id": 2, + "message": { + "method": "configWarning", + "params": { + "summary": "test", + "details": null, + } + } + }) + ); + + send_client_event( + &mut websocket, + ClientEnvelope { + event: ClientEvent::ClientClosed, + client_id: client_id.clone(), + stream_id: None, + seq_id: None, + cursor: None, + }, + ) + .await; + match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("connection close should arrive in time") + .expect("connection close should exist") + { + TransportEvent::ConnectionClosed { + connection_id: closed_connection_id, + } => { + assert_eq!(closed_connection_id, connection_id); + } + other => panic!("expected connection close event, got {other:?}"), + } + + send_client_event( + &mut websocket, + ClientEnvelope { + event: ClientEvent::Ping, + client_id, + stream_id: None, + seq_id: None, + cursor: None, + }, + ) + .await; + assert_eq!( + read_server_event(&mut websocket).await, + json!({ + "type": "pong", + "client_id": "client-1", + "seq_id": 3, + "status": "unknown", + }) + ); + + shutdown_token.cancel(); + let _ = remote_handle.await; +} + +#[tokio::test] +async fn remote_control_transport_reconnects_after_disconnect() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let codex_home = TempDir::new().expect("temp dir should create"); + let (transport_event_tx, mut transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let remote_handle = start_remote_control( + remote_control_url, + Some(remote_control_state_runtime(&codex_home).await), + remote_control_auth_manager(), + transport_event_tx, + shutdown_token.clone(), + /*app_server_client_name_rx*/ None, + ) + .await + .expect("remote control should start"); + + let enroll_request = accept_http_request(&listener).await; + assert_eq!( + enroll_request.request_line, + "POST /backend-api/wham/remote/control/server/enroll HTTP/1.1" + ); + respond_with_json( + enroll_request.stream, + json!({ "server_id": "srv_e_test", "environment_id": "env_test" }), + ) + .await; + let mut first_websocket = accept_remote_control_connection(&listener).await; + first_websocket + .close(None) + .await + .expect("first websocket should close"); + drop(first_websocket); + + let mut second_websocket = accept_remote_control_connection(&listener).await; + send_client_event( + &mut second_websocket, + ClientEnvelope { + event: ClientEvent::ClientMessage { + message: JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { + id: codex_app_server_protocol::RequestId::Integer(2), + method: "initialize".to_string(), + params: Some(json!({ + "clientInfo": { + "name": "remote-test-client", + "version": "0.1.0" + } + })), + trace: None, + }), + }, + client_id: ClientId("client-2".to_string()), + stream_id: None, + seq_id: Some(0), + cursor: None, + }, + ) + .await; + + match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("reconnected initialize should arrive in time") + .expect("reconnected initialize should exist") + { + TransportEvent::ConnectionOpened { .. } => {} + other => panic!("expected connection open after reconnect, got {other:?}"), + } + + shutdown_token.cancel(); + let _ = remote_handle.await; +} + +#[tokio::test] +async fn remote_control_transport_clears_outgoing_buffer_when_backend_acks() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let codex_home = TempDir::new().expect("temp dir should create"); + let (transport_event_tx, mut transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let remote_handle = start_remote_control( + remote_control_url, + Some(remote_control_state_runtime(&codex_home).await), + remote_control_auth_manager(), + transport_event_tx, + shutdown_token.clone(), + /*app_server_client_name_rx*/ None, + ) + .await + .expect("remote control should start"); + + let enroll_request = accept_http_request(&listener).await; + respond_with_json( + enroll_request.stream, + json!({ "server_id": "srv_e_test", "environment_id": "env_test" }), + ) + .await; + let mut first_websocket = accept_remote_control_connection(&listener).await; + + let client_id = ClientId("client-1".to_string()); + let initialize_message = JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { + id: codex_app_server_protocol::RequestId::Integer(1), + method: "initialize".to_string(), + params: Some(json!({ + "clientInfo": { + "name": "remote-test-client", + "version": "0.1.0" + } + })), + trace: None, + }); + send_client_event( + &mut first_websocket, + ClientEnvelope { + event: ClientEvent::ClientMessage { + message: initialize_message, + }, + client_id: client_id.clone(), + stream_id: None, + seq_id: Some(0), + cursor: None, + }, + ) + .await; + + let writer = match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("connection open should arrive in time") + .expect("connection open should exist") + { + TransportEvent::ConnectionOpened { writer, .. } => writer, + other => panic!("expected connection open event, got {other:?}"), + }; + match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("initialize message should arrive in time") + .expect("initialize message should exist") + { + TransportEvent::IncomingMessage { .. } => {} + other => panic!("expected initialize incoming message, got {other:?}"), + } + + writer + .send(QueuedOutgoingMessage::new( + OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { + summary: "stale".to_string(), + details: None, + path: None, + range: None, + }, + )), + )) + .await + .expect("remote writer should accept outgoing message"); + assert_eq!( + read_server_event(&mut first_websocket).await, + json!({ + "type": "server_message", + "client_id": "client-1", + "seq_id": 0, + "message": { + "method": "configWarning", + "params": { + "summary": "stale", + "details": null, + } + } + }) + ); + + send_client_event( + &mut first_websocket, + ClientEnvelope { + event: ClientEvent::Ack, + client_id: client_id.clone(), + stream_id: None, + seq_id: Some(0), + cursor: None, + }, + ) + .await; + + send_client_event( + &mut first_websocket, + ClientEnvelope { + event: ClientEvent::ClientClosed, + client_id: client_id.clone(), + stream_id: None, + seq_id: None, + cursor: None, + }, + ) + .await; + match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("connection close should arrive in time") + .expect("connection close should exist") + { + TransportEvent::ConnectionClosed { .. } => {} + other => panic!("expected connection close event, got {other:?}"), + } + + first_websocket + .close(None) + .await + .expect("first websocket should close"); + drop(first_websocket); + + let mut second_websocket = accept_remote_control_connection(&listener).await; + send_client_event( + &mut second_websocket, + ClientEnvelope { + event: ClientEvent::Ping, + client_id, + stream_id: None, + seq_id: None, + cursor: None, + }, + ) + .await; + assert_eq!( + read_server_event(&mut second_websocket).await, + json!({ + "type": "pong", + "client_id": "client-1", + "seq_id": 1, + "status": "unknown", + }) + ); + + shutdown_token.cancel(); + let _ = remote_handle.await; +} + +#[tokio::test] +async fn remote_control_http_mode_enrolls_before_connecting() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let codex_home = TempDir::new().expect("temp dir should create"); + let (transport_event_tx, mut transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let expected_server_name = gethostname().to_string_lossy().trim().to_string(); + let shutdown_token = CancellationToken::new(); + let remote_handle = start_remote_control( + remote_control_url, + Some(remote_control_state_runtime(&codex_home).await), + remote_control_auth_manager(), + transport_event_tx, + shutdown_token.clone(), + /*app_server_client_name_rx*/ None, + ) + .await + .expect("remote control should start"); + + let enroll_request = accept_http_request(&listener).await; + assert_eq!( + enroll_request.request_line, + "POST /backend-api/wham/remote/control/server/enroll HTTP/1.1" + ); + assert_eq!( + enroll_request.headers.get("authorization"), + Some(&"Bearer Access Token".to_string()) + ); + assert_eq!( + enroll_request.headers.get(REMOTE_CONTROL_ACCOUNT_ID_HEADER), + Some(&"account_id".to_string()) + ); + assert_eq!( + serde_json::from_str::(&enroll_request.body) + .expect("enroll body should deserialize"), + json!({ + "name": expected_server_name, + "os": std::env::consts::OS, + "arch": std::env::consts::ARCH, + "app_server_version": env!("CARGO_PKG_VERSION"), + }) + ); + respond_with_json( + enroll_request.stream, + json!({ "server_id": "srv_e_test", "environment_id": "env_test" }), + ) + .await; + + let (handshake_request, mut websocket) = + accept_remote_control_backend_connection(&listener).await; + assert_eq!( + handshake_request.path, + "/backend-api/wham/remote/control/server" + ); + assert_eq!( + handshake_request.headers.get("authorization"), + Some(&"Bearer Access Token".to_string()) + ); + assert_eq!( + handshake_request + .headers + .get(REMOTE_CONTROL_ACCOUNT_ID_HEADER), + Some(&"account_id".to_string()) + ); + assert_eq!( + handshake_request.headers.get("x-codex-server-id"), + Some(&"srv_e_test".to_string()) + ); + assert_eq!( + handshake_request.headers.get("x-codex-name"), + Some(&base64::engine::general_purpose::STANDARD.encode(&expected_server_name)) + ); + assert_eq!( + handshake_request.headers.get("x-codex-protocol-version"), + Some(&REMOTE_CONTROL_PROTOCOL_VERSION.to_string()) + ); + + let backend_client_id = ClientId("backend-test-client".to_string()); + let writer = { + let initialize_message = + JSONRPCMessage::Request(codex_app_server_protocol::JSONRPCRequest { + id: codex_app_server_protocol::RequestId::Integer(11), + method: "initialize".to_string(), + params: Some(json!({ + "clientInfo": { + "name": "remote-backend-client", + "version": "0.1.0" + } + })), + trace: None, + }); + send_client_event( + &mut websocket, + ClientEnvelope { + event: ClientEvent::ClientMessage { + message: initialize_message.clone(), + }, + client_id: backend_client_id.clone(), + stream_id: None, + seq_id: Some(0), + cursor: None, + }, + ) + .await; + + let (connection_id, writer) = + match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("connection open should arrive in time") + .expect("connection open should exist") + { + TransportEvent::ConnectionOpened { + connection_id, + writer, + .. + } => (connection_id, writer), + other => panic!("expected connection open event, got {other:?}"), + }; + + match timeout(Duration::from_secs(5), transport_event_rx.recv()) + .await + .expect("initialize message should arrive in time") + .expect("initialize message should exist") + { + TransportEvent::IncomingMessage { + connection_id: incoming_connection_id, + message, + } => { + assert_eq!(incoming_connection_id, connection_id); + assert_eq!(message, initialize_message); + } + other => panic!("expected initialize incoming message, got {other:?}"), + } + writer + }; + + writer + .send(QueuedOutgoingMessage::new(OutgoingMessage::Response( + crate::outgoing_message::OutgoingResponse { + id: codex_app_server_protocol::RequestId::Integer(11), + result: json!({ + "userAgent": "codex-test-agent" + }), + }, + ))) + .await + .expect("remote writer should accept initialize response"); + assert_eq!( + read_server_event(&mut websocket).await, + json!({ + "type": "server_message", + "client_id": backend_client_id.0.clone(), + "seq_id": 0, + "message": { + "id": 11, + "result": { + "userAgent": "codex-test-agent", + } + } + }) + ); + + writer + .send(QueuedOutgoingMessage::new( + OutgoingMessage::AppServerNotification(ServerNotification::ConfigWarning( + ConfigWarningNotification { + summary: "backend".to_string(), + details: None, + path: None, + range: None, + }, + )), + )) + .await + .expect("remote writer should accept outgoing message"); + assert_eq!( + read_server_event(&mut websocket).await, + json!({ + "type": "server_message", + "client_id": backend_client_id.0.clone(), + "seq_id": 1, + "message": { + "method": "configWarning", + "params": { + "summary": "backend", + "details": null, + } + } + }) + ); + + shutdown_token.cancel(); + let _ = remote_handle.await; +} + +#[tokio::test] +async fn remote_control_http_mode_reuses_persisted_enrollment_before_reenrolling() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let codex_home = TempDir::new().expect("temp dir should create"); + let state_db = remote_control_state_runtime(&codex_home).await; + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let persisted_enrollment = RemoteControlEnrollment { + account_id: "account_id".to_string(), + environment_id: "env_persisted".to_string(), + server_id: "srv_e_persisted".to_string(), + server_name: "persisted-server".to_string(), + }; + update_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &remote_control_target, + "account_id", + /*app_server_client_name*/ None, + Some(&persisted_enrollment), + ) + .await + .expect("persisted enrollment should save"); + + let (transport_event_tx, _transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let remote_handle = start_remote_control( + remote_control_url, + Some(state_db.clone()), + remote_control_auth_manager_with_home(&codex_home), + transport_event_tx, + shutdown_token.clone(), + /*app_server_client_name_rx*/ None, + ) + .await + .expect("remote control should start"); + + let (handshake_request, _websocket) = accept_remote_control_backend_connection(&listener).await; + assert_eq!( + handshake_request.path, + "/backend-api/wham/remote/control/server" + ); + assert_eq!( + handshake_request.headers.get("x-codex-server-id"), + Some(&persisted_enrollment.server_id) + ); + assert_eq!( + load_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &remote_control_target, + "account_id", + /*app_server_client_name*/ None, + ) + .await, + Some(persisted_enrollment) + ); + + shutdown_token.cancel(); + let _ = remote_handle.await; +} + +#[tokio::test] +async fn remote_control_stdio_mode_waits_for_client_name_before_connecting() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let codex_home = TempDir::new().expect("temp dir should create"); + let state_db = remote_control_state_runtime(&codex_home).await; + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let app_server_client_name = "stdio-client"; + let persisted_enrollment = RemoteControlEnrollment { + account_id: "account_id".to_string(), + environment_id: "env_persisted".to_string(), + server_id: "srv_e_persisted".to_string(), + server_name: "persisted-server".to_string(), + }; + update_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &remote_control_target, + "account_id", + Some(app_server_client_name), + Some(&persisted_enrollment), + ) + .await + .expect("persisted enrollment should save"); + + let (transport_event_tx, _transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let (app_server_client_name_tx, app_server_client_name_rx) = oneshot::channel::(); + let shutdown_token = CancellationToken::new(); + let remote_handle = start_remote_control( + remote_control_url, + Some(state_db.clone()), + remote_control_auth_manager_with_home(&codex_home), + transport_event_tx, + shutdown_token.clone(), + Some(app_server_client_name_rx), + ) + .await + .expect("remote control should start"); + + timeout(Duration::from_millis(100), listener.accept()) + .await + .expect_err("remote control should wait for the stdio client name"); + + let _ = app_server_client_name_tx.send(app_server_client_name.to_string()); + let (handshake_request, _websocket) = accept_remote_control_backend_connection(&listener).await; + assert_eq!( + handshake_request.headers.get("x-codex-server-id"), + Some(&persisted_enrollment.server_id) + ); + + shutdown_token.cancel(); + let _ = remote_handle.await; +} + +#[tokio::test] +async fn remote_control_waits_for_account_id_before_enrolling() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let codex_home = TempDir::new().expect("temp dir should create"); + save_auth( + codex_home.path(), + &remote_control_auth_dot_json(/*account_id*/ None), + AuthCredentialsStoreMode::File, + ) + .expect("auth without account id should save"); + let state_db = remote_control_state_runtime(&codex_home).await; + let auth_manager = AuthManager::shared( + codex_home.path().to_path_buf(), + /*enable_codex_api_key_env*/ false, + AuthCredentialsStoreMode::File, + ); + let expected_server_name = gethostname().to_string_lossy().trim().to_string(); + let expected_enrollment = RemoteControlEnrollment { + account_id: "account_id".to_string(), + environment_id: "env_ready".to_string(), + server_id: "srv_e_ready".to_string(), + server_name: expected_server_name, + }; + + let (transport_event_tx, _transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let remote_handle = start_remote_control( + remote_control_url, + Some(state_db.clone()), + auth_manager, + transport_event_tx, + shutdown_token.clone(), + /*app_server_client_name_rx*/ None, + ) + .await + .expect("remote control should start before account id is available"); + + timeout(Duration::from_millis(100), listener.accept()) + .await + .expect_err("remote control should wait for account id before enrolling"); + + save_auth( + codex_home.path(), + &remote_control_auth_dot_json(Some("account_id")), + AuthCredentialsStoreMode::File, + ) + .expect("auth with account id should save"); + + let enroll_request = accept_http_request(&listener).await; + assert_eq!( + enroll_request.request_line, + "POST /backend-api/wham/remote/control/server/enroll HTTP/1.1" + ); + respond_with_json( + enroll_request.stream, + json!({ + "server_id": expected_enrollment.server_id, + "environment_id": expected_enrollment.environment_id, + }), + ) + .await; + + let (handshake_request, _websocket) = accept_remote_control_backend_connection(&listener).await; + assert_eq!( + handshake_request.headers.get("x-codex-server-id"), + Some(&expected_enrollment.server_id) + ); + + shutdown_token.cancel(); + let _ = remote_handle.await; +} + +#[tokio::test] +async fn remote_control_http_mode_clears_stale_persisted_enrollment_after_404() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let codex_home = TempDir::new().expect("temp dir should create"); + let state_db = remote_control_state_runtime(&codex_home).await; + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let expected_server_name = gethostname().to_string_lossy().trim().to_string(); + let stale_enrollment = RemoteControlEnrollment { + account_id: "account_id".to_string(), + environment_id: "env_stale".to_string(), + server_id: "srv_e_stale".to_string(), + server_name: "stale-server".to_string(), + }; + let refreshed_enrollment = RemoteControlEnrollment { + account_id: "account_id".to_string(), + environment_id: "env_refreshed".to_string(), + server_id: "srv_e_refreshed".to_string(), + server_name: expected_server_name, + }; + update_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &remote_control_target, + "account_id", + /*app_server_client_name*/ None, + Some(&stale_enrollment), + ) + .await + .expect("stale enrollment should save"); + + let (transport_event_tx, _transport_event_rx) = + mpsc::channel::(CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let remote_handle = start_remote_control( + remote_control_url, + Some(state_db.clone()), + remote_control_auth_manager_with_home(&codex_home), + transport_event_tx, + shutdown_token.clone(), + /*app_server_client_name_rx*/ None, + ) + .await + .expect("remote control should start"); + + let websocket_request = accept_http_request(&listener).await; + assert_eq!( + websocket_request.request_line, + "GET /backend-api/wham/remote/control/server HTTP/1.1" + ); + assert_eq!( + websocket_request.headers.get("x-codex-server-id"), + Some(&stale_enrollment.server_id) + ); + respond_with_status(websocket_request.stream, "404 Not Found", "").await; + + let enroll_request = accept_http_request(&listener).await; + assert_eq!( + enroll_request.request_line, + "POST /backend-api/wham/remote/control/server/enroll HTTP/1.1" + ); + respond_with_json( + enroll_request.stream, + json!({ + "server_id": refreshed_enrollment.server_id, + "environment_id": refreshed_enrollment.environment_id, + }), + ) + .await; + + let (handshake_request, _websocket) = accept_remote_control_backend_connection(&listener).await; + assert_eq!( + handshake_request.headers.get("x-codex-server-id"), + Some(&refreshed_enrollment.server_id) + ); + assert_eq!( + load_persisted_remote_control_enrollment( + Some(state_db.as_ref()), + &remote_control_target, + "account_id", + /*app_server_client_name*/ None, + ) + .await, + Some(refreshed_enrollment) + ); + + shutdown_token.cancel(); + let _ = remote_handle.await; +} + +#[derive(Debug)] +struct CapturedHttpRequest { + stream: TcpStream, + request_line: String, + headers: BTreeMap, + body: String, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +struct CapturedWebSocketRequest { + path: String, + headers: BTreeMap, +} + +async fn accept_remote_control_connection(listener: &TcpListener) -> WebSocketStream { + let (stream, _) = timeout(Duration::from_secs(5), listener.accept()) + .await + .expect("remote control should connect in time") + .expect("listener accept should succeed"); + accept_async(stream) + .await + .expect("websocket handshake should succeed") +} + +async fn accept_http_request(listener: &TcpListener) -> CapturedHttpRequest { + let (stream, _) = timeout(Duration::from_secs(5), listener.accept()) + .await + .expect("HTTP request should arrive in time") + .expect("listener accept should succeed"); + let mut reader = BufReader::new(stream); + + let mut request_line = String::new(); + reader + .read_line(&mut request_line) + .await + .expect("request line should read"); + let request_line = request_line.trim_end_matches("\r\n").to_string(); + + let mut headers = BTreeMap::new(); + loop { + let mut line = String::new(); + reader + .read_line(&mut line) + .await + .expect("header line should read"); + if line == "\r\n" { + break; + } + let line = line.trim_end_matches("\r\n"); + let (name, value) = line.split_once(':').expect("header should contain colon"); + headers.insert(name.to_ascii_lowercase(), value.trim().to_string()); + } + + let content_length = headers + .get("content-length") + .and_then(|value| value.parse::().ok()) + .unwrap_or(0); + let mut body = vec![0; content_length]; + reader + .read_exact(&mut body) + .await + .expect("request body should read"); + + CapturedHttpRequest { + stream: reader.into_inner(), + request_line, + headers, + body: String::from_utf8(body).expect("body should be utf-8"), + } +} + +async fn respond_with_json(mut stream: TcpStream, body: serde_json::Value) { + let body = body.to_string(); + let response = format!( + "HTTP/1.1 200 OK\r\ncontent-type: application/json\r\ncontent-length: {}\r\nconnection: close\r\n\r\n{body}", + body.len() + ); + stream + .write_all(response.as_bytes()) + .await + .expect("response should write"); + stream.flush().await.expect("response should flush"); +} + +async fn respond_with_status(stream: TcpStream, status: &str, body: &str) { + respond_with_status_and_headers(stream, status, &[], body).await; +} + +async fn respond_with_status_and_headers( + mut stream: TcpStream, + status: &str, + headers: &[(&str, &str)], + body: &str, +) { + let extra_headers = headers + .iter() + .map(|(name, value)| format!("{name}: {value}\r\n")) + .collect::(); + let response = format!( + "HTTP/1.1 {status}\r\ncontent-type: text/plain\r\ncontent-length: {}\r\nconnection: close\r\n{extra_headers}\r\n{body}", + body.len(), + ); + stream + .write_all(response.as_bytes()) + .await + .expect("response should write"); + stream.flush().await.expect("response should flush"); +} + +async fn accept_remote_control_backend_connection( + listener: &TcpListener, +) -> (CapturedWebSocketRequest, WebSocketStream) { + let (stream, _) = timeout(Duration::from_secs(5), listener.accept()) + .await + .expect("websocket request should arrive in time") + .expect("listener accept should succeed"); + let captured_request = Arc::new(std::sync::Mutex::new(None::)); + let captured_request_for_callback = captured_request.clone(); + let websocket = accept_hdr_async( + stream, + move |request: &tungstenite::handshake::server::Request, + response: tungstenite::handshake::server::Response| { + let headers = request + .headers() + .iter() + .map(|(name, value)| { + ( + name.as_str().to_ascii_lowercase(), + value + .to_str() + .expect("header should be valid utf-8") + .to_string(), + ) + }) + .collect::>(); + *captured_request_for_callback + .lock() + .expect("capture lock should acquire") = Some(CapturedWebSocketRequest { + path: request.uri().path().to_string(), + headers, + }); + Ok(response) + }, + ) + .await + .expect("websocket handshake should succeed"); + let captured_request = captured_request + .lock() + .expect("capture lock should acquire") + .clone() + .expect("websocket request should be captured"); + (captured_request, websocket) +} + +async fn send_client_event( + websocket: &mut WebSocketStream, + client_envelope: ClientEnvelope, +) { + let payload = serde_json::to_string(&client_envelope).expect("client event should serialize"); + websocket + .send(tungstenite::Message::Text(payload.into())) + .await + .expect("client event should send"); +} + +async fn read_server_event(websocket: &mut WebSocketStream) -> serde_json::Value { + loop { + let frame = timeout(Duration::from_secs(5), websocket.next()) + .await + .expect("server event should arrive in time") + .expect("websocket should stay open") + .expect("websocket frame should be readable"); + match frame { + tungstenite::Message::Text(text) => { + let mut event: serde_json::Value = + serde_json::from_str(text.as_ref()).expect("server event should deserialize"); + if let Some(stream_id) = event + .as_object_mut() + .and_then(|event| event.remove("stream_id")) + { + assert!(stream_id.is_string(), "stream_id should be a string"); + } + return event; + } + tungstenite::Message::Ping(payload) => { + websocket + .send(tungstenite::Message::Pong(payload)) + .await + .expect("websocket pong should send"); + } + tungstenite::Message::Pong(_) => {} + tungstenite::Message::Close(frame) => { + panic!("unexpected websocket close frame: {frame:?}"); + } + tungstenite::Message::Binary(_) => { + panic!("unexpected binary websocket frame"); + } + tungstenite::Message::Frame(_) => {} + } + } +} diff --git a/codex-rs/app-server/src/transport/remote_control/websocket.rs b/codex-rs/app-server/src/transport/remote_control/websocket.rs new file mode 100644 index 000000000000..56bc88cc6f48 --- /dev/null +++ b/codex-rs/app-server/src/transport/remote_control/websocket.rs @@ -0,0 +1,1409 @@ +use crate::transport::TransportEvent; +use crate::transport::remote_control::client_tracker::ClientTracker; +use crate::transport::remote_control::client_tracker::REMOTE_CONTROL_IDLE_SWEEP_INTERVAL; +use crate::transport::remote_control::enroll::RemoteControlConnectionAuth; +use crate::transport::remote_control::enroll::RemoteControlEnrollment; +use crate::transport::remote_control::enroll::enroll_remote_control_server; +use crate::transport::remote_control::enroll::format_headers; +use crate::transport::remote_control::enroll::load_persisted_remote_control_enrollment; +use crate::transport::remote_control::enroll::preview_remote_control_response_body; +use crate::transport::remote_control::enroll::update_persisted_remote_control_enrollment; + +use super::protocol::ClientEnvelope; +use super::protocol::ClientEvent; +use super::protocol::ClientId; +use super::protocol::RemoteControlTarget; +use super::protocol::ServerEnvelope; +use axum::http::HeaderValue; +use base64::Engine; +use codex_core::util::backoff; +use codex_login::AuthManager; +use codex_login::UnauthorizedRecovery; +use codex_state::StateRuntime; +use codex_utils_rustls_provider::ensure_rustls_crypto_provider; +use futures::SinkExt; +use futures::StreamExt; +use futures::stream::SplitSink; +use futures::stream::SplitStream; +use std::collections::BTreeMap; +use std::collections::HashMap; +use std::io; +use std::io::ErrorKind; +use std::sync::Arc; +use tokio::net::TcpStream; +use tokio::sync::Mutex; +use tokio::sync::mpsc; +use tokio::sync::oneshot; +use tokio::sync::watch; +use tokio::time::MissedTickBehavior; +use tokio_tungstenite::MaybeTlsStream; +use tokio_tungstenite::WebSocketStream; +use tokio_tungstenite::connect_async; +use tokio_tungstenite::tungstenite; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tokio_util::sync::CancellationToken; +use tracing::error; +use tracing::info; +use tracing::warn; + +pub(super) const REMOTE_CONTROL_PROTOCOL_VERSION: &str = "2"; +pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id"; +const REMOTE_CONTROL_SUBSCRIBE_CURSOR_HEADER: &str = "x-codex-subscribe-cursor"; +const REMOTE_CONTROL_WEBSOCKET_PING_INTERVAL: std::time::Duration = + std::time::Duration::from_secs(10); +const REMOTE_CONTROL_WEBSOCKET_PONG_TIMEOUT: std::time::Duration = + std::time::Duration::from_secs(60); +const REMOTE_CONTROL_ACCOUNT_ID_RETRY_INTERVAL: std::time::Duration = + std::time::Duration::from_secs(1); + +struct BoundedOutboundBuffer { + // Remote-control acks are generated by the backend at client scope, so + // retransmit retention is keyed by client_id only. stream_id stays on each + // envelope for routing, but it is not part of the ack cursor. + buffer_by_client: HashMap>, + used_tx: watch::Sender, +} + +impl BoundedOutboundBuffer { + fn new() -> (Self, watch::Receiver) { + let (used_tx, used_rx) = watch::channel(0); + let buffer = Self { + buffer_by_client: HashMap::new(), + used_tx, + }; + (buffer, used_rx) + } + + fn insert(&mut self, server_envelope: &ServerEnvelope) { + self.buffer_by_client + .entry(server_envelope.client_id.clone()) + .or_default() + .insert(server_envelope.seq_id, server_envelope.clone()); + self.used_tx.send_modify(|used| *used += 1); + } + + fn ack(&mut self, client_id: &ClientId, acked_seq_id: u64) { + let Some(buffer) = self.buffer_by_client.get_mut(client_id) else { + return; + }; + while let Some(seq_id) = buffer.first_key_value().map(|(seq_id, _)| seq_id) + && *seq_id <= acked_seq_id + { + buffer.pop_first(); + self.used_tx.send_modify(|used| *used -= 1); + } + if buffer.is_empty() { + self.buffer_by_client.remove(client_id); + } + } + + fn server_envelopes(&self) -> impl Iterator { + self.buffer_by_client + .values() + .flat_map(|buffer| buffer.values()) + } +} + +struct WebsocketState { + outbound_buffer: BoundedOutboundBuffer, + subscribe_cursor: Option, + next_seq_id: u64, +} + +pub(crate) struct RemoteControlWebsocket { + remote_control_target: RemoteControlTarget, + state_db: Option>, + auth_manager: Arc, + shutdown_token: CancellationToken, + reconnect_attempt: u64, + enrollment: Option, + auth_recovery: UnauthorizedRecovery, + client_tracker: Arc>, + state: Arc>, + server_event_rx: Arc>>, + used_rx: watch::Receiver, +} + +impl RemoteControlWebsocket { + pub(crate) fn new( + remote_control_target: RemoteControlTarget, + state_db: Option>, + auth_manager: Arc, + transport_event_tx: mpsc::Sender, + shutdown_token: CancellationToken, + ) -> Self { + let shutdown_token = shutdown_token.child_token(); + let (server_event_tx, server_event_rx) = mpsc::channel(super::CHANNEL_CAPACITY); + let client_tracker = + ClientTracker::new(server_event_tx, transport_event_tx, &shutdown_token); + let (outbound_buffer, used_rx) = BoundedOutboundBuffer::new(); + let auth_recovery = auth_manager.unauthorized_recovery(); + + Self { + remote_control_target, + state_db, + auth_manager, + shutdown_token, + reconnect_attempt: 0, + enrollment: None, + auth_recovery, + client_tracker: Arc::new(Mutex::new(client_tracker)), + state: Arc::new(Mutex::new(WebsocketState { + outbound_buffer, + subscribe_cursor: None, + next_seq_id: 0, + })), + server_event_rx: Arc::new(Mutex::new(server_event_rx)), + used_rx, + } + } + + pub(crate) async fn run( + mut self, + app_server_client_name_rx: Option>, + ) { + let app_server_client_name = match self + .wait_for_app_server_client_name(app_server_client_name_rx) + .await + { + Ok(app_server_client_name) => app_server_client_name, + Err(_) => { + self.client_tracker.lock().await.shutdown().await; + return; + } + }; + + loop { + let shutdown_token = self.shutdown_token.child_token(); + let websocket_connection = match self + .connect(&shutdown_token, app_server_client_name.as_deref()) + .await + { + Some(websocket_connection) => websocket_connection, + None => break, + }; + + self.run_connection(websocket_connection, shutdown_token) + .await; + } + + self.client_tracker.lock().await.shutdown().await; + } + + async fn wait_for_app_server_client_name( + &self, + app_server_client_name_rx: Option>, + ) -> Result, ()> { + match app_server_client_name_rx { + Some(app_server_client_name_rx) => { + tokio::select! { + _ = self.shutdown_token.cancelled() => Err(()), + app_server_client_name = app_server_client_name_rx => match app_server_client_name { + Ok(app_server_client_name) => Ok(Some(app_server_client_name)), + Err(_) => Err(()), + }, + } + } + None => Ok(None), + } + } + + async fn connect( + &mut self, + shutdown_token: &CancellationToken, + app_server_client_name: Option<&str>, + ) -> Option>> { + loop { + let subscribe_cursor = self.state.lock().await.subscribe_cursor.clone(); + tokio::select! { + _ = shutdown_token.cancelled() => return None, + connect_result = connect_remote_control_websocket( + &self.remote_control_target, + self.state_db.as_deref(), + &self.auth_manager, + &mut self.auth_recovery, + &mut self.enrollment, + subscribe_cursor.as_deref(), + app_server_client_name, + ) => { + match connect_result { + Ok((websocket_connection, response)) => { + self.reconnect_attempt = 0; + self.auth_recovery = self.auth_manager.unauthorized_recovery(); + info!( + "connected to app-server remote control websocket: {}, {}", + self.remote_control_target.websocket_url, + format_headers(response.headers()) + ); + return Some(websocket_connection); + } + Err(err) => { + let reconnect_delay = if err.kind() == ErrorKind::WouldBlock { + info!("{err}"); + REMOTE_CONTROL_ACCOUNT_ID_RETRY_INTERVAL + } else { + warn!("{err}"); + let reconnect_delay = backoff(self.reconnect_attempt); + self.reconnect_attempt += 1; + reconnect_delay + }; + tokio::select! { + _ = shutdown_token.cancelled() => return None, + _ = tokio::time::sleep(reconnect_delay) => {} + } + } + } + } + } + } + } + + async fn run_connection( + &self, + websocket_connection: WebSocketStream>, + shutdown_token: CancellationToken, + ) { + let (websocket_writer, websocket_reader) = websocket_connection.split(); + let mut join_set = tokio::task::JoinSet::new(); + + join_set.spawn(Self::run_server_writer( + self.state.clone(), + self.server_event_rx.clone(), + self.used_rx.clone(), + websocket_writer, + REMOTE_CONTROL_WEBSOCKET_PING_INTERVAL, + shutdown_token.clone(), + )); + join_set.spawn(Self::run_websocket_reader( + self.client_tracker.clone(), + self.state.clone(), + websocket_reader, + REMOTE_CONTROL_WEBSOCKET_PONG_TIMEOUT, + shutdown_token.clone(), + )); + + tokio::select! { + _ = shutdown_token.cancelled() => {} + _ = join_set.join_next() => shutdown_token.cancel(), + } + + join_set.join_all().await; + } + + async fn run_server_writer( + state: Arc>, + server_event_rx: Arc>>, + used_rx: watch::Receiver, + websocket_writer: SplitSink< + WebSocketStream>, + tungstenite::Message, + >, + ping_interval: std::time::Duration, + shutdown_token: CancellationToken, + ) { + let result = Self::run_server_writer_inner( + state, + server_event_rx, + used_rx, + websocket_writer, + ping_interval, + shutdown_token, + ) + .await; + if let Err(err) = result { + warn!("remote control websocket writer disconnected, err: {err}"); + } else { + warn!("remote control websocket writer was stopped"); + } + } + + async fn run_server_writer_inner( + state: Arc>, + server_event_rx: Arc>>, + mut used_rx: watch::Receiver, + mut websocket_writer: SplitSink< + WebSocketStream>, + tungstenite::Message, + >, + ping_interval: std::time::Duration, + shutdown_token: CancellationToken, + ) -> io::Result<()> { + for server_envelope in state.lock().await.outbound_buffer.server_envelopes() { + let payload = match serde_json::to_string(&server_envelope) { + Ok(payload) => payload, + Err(err) => { + error!("failed to serialize remote-control server event: {err}"); + continue; + } + }; + tokio::select! { + _ = shutdown_token.cancelled() => return Ok(()), + send_result = websocket_writer.send(tungstenite::Message::Text(payload.into())) => { + if let Err(err) = send_result { + return Err(io::Error::other(err)); + } + } + }; + } + + let mut ping_interval = + tokio::time::interval_at(tokio::time::Instant::now() + ping_interval, ping_interval); + ping_interval.set_missed_tick_behavior(MissedTickBehavior::Skip); + + let mut server_event_rx = server_event_rx.lock().await; + loop { + let outbound_has_capacity = *used_rx.borrow() < super::CHANNEL_CAPACITY; + let queued_server_envelope = tokio::select! { + _ = shutdown_token.cancelled() => return Ok(()), + _ = ping_interval.tick() => { + if let Err(err) = websocket_writer + .send(tungstenite::Message::Ping(Vec::new().into())) + .await + { + return Err(io::Error::other(err)); + } + continue; + } + wait_result = used_rx.changed(), if !outbound_has_capacity => + { + if wait_result.is_err() { + return Err(io::Error::new( + ErrorKind::UnexpectedEof, + "outbound buffer usage channel closed", + )); + } + continue; + } + recv_result = server_event_rx.recv(), if outbound_has_capacity => { + match recv_result { + Some(queued_server_envelope) => queued_server_envelope, + None => { + return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "server event channel closed")); + } + } + } + }; + let (server_envelope, write_complete_tx) = { + let mut state = state.lock().await; + let seq_id = state.next_seq_id; + state.next_seq_id = state.next_seq_id.saturating_add(1); + + let server_envelope = ServerEnvelope { + event: queued_server_envelope.event, + client_id: queued_server_envelope.client_id, + seq_id, + stream_id: queued_server_envelope.stream_id, + }; + state.outbound_buffer.insert(&server_envelope); + + (server_envelope, queued_server_envelope.write_complete_tx) + }; + + let payload = match serde_json::to_string(&server_envelope) { + Ok(payload) => payload, + Err(err) => { + error!("failed to serialize remote-control server event: {err}"); + continue; + } + }; + + tokio::select! { + _ = shutdown_token.cancelled() => return Ok(()), + send_result = websocket_writer.send(tungstenite::Message::Text(payload.into())) => { + if let Err(err) = send_result { + return Err(io::Error::other(err)); + } + } + }; + if let Some(write_complete_tx) = write_complete_tx { + let _ = write_complete_tx.send(()); + } + } + } + + async fn run_websocket_reader( + client_tracker: Arc>, + state: Arc>, + websocket_reader: SplitStream>>, + pong_timeout: std::time::Duration, + shutdown_token: CancellationToken, + ) { + let result = Self::run_websocket_reader_inner( + client_tracker, + state, + websocket_reader, + pong_timeout, + shutdown_token, + ) + .await; + if let Err(err) = result { + warn!("remote control websocket reader disconnected, err: {err}"); + } else { + warn!("remote control websocket reader was stopped"); + } + } + + async fn run_websocket_reader_inner( + client_tracker: Arc>, + state: Arc>, + mut websocket_reader: SplitStream>>, + pong_timeout: std::time::Duration, + shutdown_token: CancellationToken, + ) -> io::Result<()> { + let mut client_tracker = client_tracker.lock().await; + let mut idle_sweep_interval = tokio::time::interval(REMOTE_CONTROL_IDLE_SWEEP_INTERVAL); + idle_sweep_interval.set_missed_tick_behavior(MissedTickBehavior::Skip); + let pong_deadline = tokio::time::sleep(pong_timeout); + tokio::pin!(pong_deadline); + + loop { + let incoming_message = tokio::select! { + _ = shutdown_token.cancelled() => return Ok(()), + _ = &mut pong_deadline => { + return Err(io::Error::new( + ErrorKind::TimedOut, + "remote control websocket pong timeout", + )); + } + client_key = client_tracker.bookkeep_join_set() => { + let Some(client_key) = client_key else { + continue; + }; + if client_tracker.close_client(&client_key).await.is_err() { + return Ok(()); + } + continue; + } + _ = idle_sweep_interval.tick() => { + if client_tracker.close_expired_clients().await.is_err() { + return Ok(()); + } + continue; + } + incoming_message = websocket_reader.next() => { + match incoming_message { + Some(incoming_message) => incoming_message, + None => return Err(io::Error::new(ErrorKind::UnexpectedEof, "websocket stream ended")), + } + } + }; + let client_envelope = match incoming_message { + Ok(tungstenite::Message::Text(text)) => { + match serde_json::from_str::(&text) { + Ok(client_envelope) => client_envelope, + Err(err) => { + warn!("failed to deserialize remote-control client event: {err}"); + continue; + } + } + } + Ok(tungstenite::Message::Pong(_)) => { + pong_deadline + .as_mut() + .reset(tokio::time::Instant::now() + pong_timeout); + continue; + } + Ok(tungstenite::Message::Ping(_)) | Ok(tungstenite::Message::Frame(_)) => continue, + Ok(tungstenite::Message::Binary(_)) => { + warn!("dropping unsupported binary remote-control websocket message"); + continue; + } + Ok(tungstenite::Message::Close(_)) => { + return Err(io::Error::new( + ErrorKind::ConnectionAborted, + "websocket disconnected", + )); + } + Err(err) => { + return Err(io::Error::new( + ErrorKind::InvalidData, + format!("failed to read from websocket: {err}"), + )); + } + }; + + let mut websocket_state = state.lock().await; + if let Some(cursor) = client_envelope.cursor.as_deref() { + websocket_state.subscribe_cursor = Some(cursor.to_string()); + } + if let ClientEvent::Ack = &client_envelope.event + && let Some(acked_seq_id) = client_envelope.seq_id + { + websocket_state + .outbound_buffer + .ack(&client_envelope.client_id, acked_seq_id); + } + drop(websocket_state); + + if client_tracker + .handle_message(client_envelope) + .await + .is_err() + { + return Ok(()); + } + } + } +} + +fn set_remote_control_header( + headers: &mut tungstenite::http::HeaderMap, + name: &'static str, + value: &str, +) -> io::Result<()> { + let header_value = HeaderValue::from_str(value).map_err(|err| { + io::Error::new( + ErrorKind::InvalidInput, + format!("invalid remote control header `{name}`: {err}"), + ) + })?; + headers.insert(name, header_value); + Ok(()) +} + +fn build_remote_control_websocket_request( + websocket_url: &str, + enrollment: &RemoteControlEnrollment, + auth: &RemoteControlConnectionAuth, + subscribe_cursor: Option<&str>, +) -> io::Result> { + let mut request = websocket_url.into_client_request().map_err(|err| { + io::Error::new( + ErrorKind::InvalidInput, + format!("invalid remote control websocket URL `{websocket_url}`: {err}"), + ) + })?; + let headers = request.headers_mut(); + set_remote_control_header(headers, "x-codex-server-id", &enrollment.server_id)?; + set_remote_control_header( + headers, + "x-codex-name", + &base64::engine::general_purpose::STANDARD.encode(&enrollment.server_name), + )?; + set_remote_control_header( + headers, + "x-codex-protocol-version", + REMOTE_CONTROL_PROTOCOL_VERSION, + )?; + set_remote_control_header( + headers, + "authorization", + &format!("Bearer {}", auth.bearer_token), + )?; + set_remote_control_header(headers, REMOTE_CONTROL_ACCOUNT_ID_HEADER, &auth.account_id)?; + if let Some(subscribe_cursor) = subscribe_cursor { + set_remote_control_header( + headers, + REMOTE_CONTROL_SUBSCRIBE_CURSOR_HEADER, + subscribe_cursor, + )?; + } + Ok(request) +} + +pub(crate) async fn load_remote_control_auth( + auth_manager: &Arc, +) -> io::Result { + let mut reloaded = false; + let auth = loop { + let Some(auth) = auth_manager.auth().await else { + if reloaded { + return Err(io::Error::new( + ErrorKind::PermissionDenied, + "remote control requires ChatGPT authentication", + )); + } + auth_manager.reload(); + reloaded = true; + continue; + }; + if !auth.is_chatgpt_auth() { + break auth; + } + if auth.get_account_id().is_none() && !reloaded { + auth_manager.reload(); + reloaded = true; + continue; + } + break auth; + }; + + if !auth.is_chatgpt_auth() { + return Err(io::Error::new( + ErrorKind::PermissionDenied, + "remote control requires ChatGPT authentication; API key auth is not supported", + )); + } + + Ok(RemoteControlConnectionAuth { + bearer_token: auth.get_token().map_err(io::Error::other)?, + account_id: auth.get_account_id().ok_or_else(|| { + io::Error::new( + ErrorKind::WouldBlock, + "remote control enrollment is waiting for a ChatGPT account id", + ) + })?, + }) +} + +pub(super) async fn connect_remote_control_websocket( + remote_control_target: &RemoteControlTarget, + state_db: Option<&StateRuntime>, + auth_manager: &Arc, + auth_recovery: &mut UnauthorizedRecovery, + enrollment: &mut Option, + subscribe_cursor: Option<&str>, + app_server_client_name: Option<&str>, +) -> io::Result<( + WebSocketStream>, + tungstenite::http::Response<()>, +)> { + ensure_rustls_crypto_provider(); + + let auth = load_remote_control_auth(auth_manager).await?; + let enrollment_account_id = enrollment.as_ref().map(|enrollment| &enrollment.account_id); + if enrollment_account_id.is_some_and(|account_id| account_id != &auth.account_id) { + info!( + "clearing in-memory remote control enrollment because account id changed: websocket_url={}, previous_account_id={:?}, current_account_id={:?}", + remote_control_target.websocket_url, + enrollment + .as_ref() + .map(|enrollment| enrollment.account_id.as_str()), + auth.account_id + ); + *enrollment = None; + } + + if enrollment.is_none() { + *enrollment = load_persisted_remote_control_enrollment( + state_db, + remote_control_target, + &auth.account_id, + app_server_client_name, + ) + .await; + } + + if enrollment.is_none() { + info!( + "creating new remote control enrollment: websocket_url={}, enroll_url={}, account_id={}", + remote_control_target.websocket_url, remote_control_target.enroll_url, auth.account_id + ); + let new_enrollment = match enroll_remote_control_server(remote_control_target, &auth).await + { + Ok(new_enrollment) => new_enrollment, + Err(err) + if err.kind() == ErrorKind::PermissionDenied + && recover_remote_control_auth(auth_recovery).await => + { + return Err(io::Error::other(format!( + "{err}; retrying after auth recovery" + ))); + } + Err(err) => return Err(err), + }; + if let Err(err) = update_persisted_remote_control_enrollment( + state_db, + remote_control_target, + &auth.account_id, + app_server_client_name, + Some(&new_enrollment), + ) + .await + { + warn!("failed to persist remote control enrollment in sqlite state db: {err}"); + } + info!( + "created new remote control enrollment: websocket_url={}, account_id={}, server_id={}, environment_id={}", + remote_control_target.websocket_url, + new_enrollment.account_id, + new_enrollment.server_id, + new_enrollment.environment_id + ); + *enrollment = Some(new_enrollment); + } + + let enrollment_ref = enrollment.as_ref().ok_or_else(|| { + io::Error::other("missing remote control enrollment after enrollment step") + })?; + let request = build_remote_control_websocket_request( + &remote_control_target.websocket_url, + enrollment_ref, + &auth, + subscribe_cursor, + )?; + + match connect_async(request).await { + Ok((websocket_stream, response)) => Ok((websocket_stream, response.map(|_| ()))), + Err(err) => { + match &err { + tungstenite::Error::Http(response) if response.status().as_u16() == 404 => { + info!( + "remote control websocket returned HTTP 404; clearing stale enrollment before re-enrolling: websocket_url={}, account_id={}, server_id={}, environment_id={}", + remote_control_target.websocket_url, + auth.account_id, + enrollment_ref.server_id, + enrollment_ref.environment_id + ); + if let Err(clear_err) = update_persisted_remote_control_enrollment( + state_db, + remote_control_target, + &auth.account_id, + app_server_client_name, + /*enrollment*/ None, + ) + .await + { + warn!( + "failed to clear stale remote control enrollment in sqlite state db: {clear_err}" + ); + } + *enrollment = None; + } + tungstenite::Error::Http(response) + if matches!(response.status().as_u16(), 401 | 403) => + { + if recover_remote_control_auth(auth_recovery).await { + return Err(io::Error::other(format!( + "remote control websocket auth failed with HTTP {}; retrying after auth recovery", + response.status() + ))); + } + } + _ => {} + } + Err(io::Error::other( + format_remote_control_websocket_connect_error( + &remote_control_target.websocket_url, + &err, + ), + )) + } + } +} + +async fn recover_remote_control_auth(auth_recovery: &mut UnauthorizedRecovery) -> bool { + if !auth_recovery.has_next() { + return false; + } + + let mode = auth_recovery.mode_name(); + let step = auth_recovery.step_name(); + match auth_recovery.next().await { + Ok(step_result) => { + info!( + "remote control websocket auth recovery succeeded: mode={mode}, step={step}, auth_state_changed={:?}", + step_result.auth_state_changed() + ); + true + } + Err(err) => { + warn!("remote control websocket auth recovery failed: mode={mode}, step={step}: {err}"); + false + } + } +} + +fn format_remote_control_websocket_connect_error( + websocket_url: &str, + err: &tungstenite::Error, +) -> String { + let mut message = + format!("failed to connect app-server remote control websocket `{websocket_url}`: {err}"); + let tungstenite::Error::Http(response) = err else { + return message; + }; + + message.push_str(&format!(", {}", format_headers(response.headers()))); + if let Some(body) = response.body().as_ref() + && !body.is_empty() + { + let body_preview = preview_remote_control_response_body(body); + message.push_str(&format!(", body: {body_preview}")); + } + + message +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::outgoing_message::OutgoingMessage; + use crate::transport::remote_control::ServerEvent; + use crate::transport::remote_control::protocol::StreamId; + use crate::transport::remote_control::protocol::normalize_remote_control_url; + use chrono::Utc; + use codex_app_server_protocol::AuthMode; + use codex_app_server_protocol::ConfigWarningNotification; + use codex_app_server_protocol::ServerNotification; + use codex_core::test_support::auth_manager_from_auth; + use codex_login::AuthCredentialsStoreMode; + use codex_login::AuthDotJson; + use codex_login::CodexAuth; + use codex_login::save_auth; + use codex_login::token_data::TokenData; + use codex_login::token_data::parse_chatgpt_jwt_claims; + use codex_state::StateRuntime; + use futures::StreamExt; + use pretty_assertions::assert_eq; + use std::sync::Arc; + use tempfile::TempDir; + use tokio::io::AsyncBufReadExt; + use tokio::io::AsyncWriteExt; + use tokio::io::BufReader; + use tokio::net::TcpListener; + use tokio::net::TcpStream; + use tokio::sync::mpsc; + use tokio::time::Duration; + use tokio::time::timeout; + use tokio_tungstenite::accept_async; + + async fn remote_control_state_runtime(codex_home: &TempDir) -> Arc { + StateRuntime::init(codex_home.path().to_path_buf(), "test-provider".to_string()) + .await + .expect("state runtime should initialize") + } + + fn remote_control_auth_manager() -> Arc { + auth_manager_from_auth(CodexAuth::create_dummy_chatgpt_auth_for_testing()) + } + + fn remote_control_url_for_listener(listener: &TcpListener) -> String { + format!( + "http://localhost:{}/backend-api/", + listener + .local_addr() + .expect("listener should have a local addr") + .port() + ) + } + + fn remote_control_auth_dot_json(access_token: &str) -> AuthDotJson { + #[derive(serde::Serialize)] + struct Header { + alg: &'static str, + typ: &'static str, + } + + let header = Header { + alg: "none", + typ: "JWT", + }; + let payload = serde_json::json!({ + "email": "user@example.com", + "https://api.openai.com/auth": { + "chatgpt_user_id": "user-12345", + "user_id": "user-12345", + "chatgpt_account_id": "account_id" + } + }); + let b64 = |bytes: &[u8]| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes); + let header_b64 = b64(&serde_json::to_vec(&header).expect("header should serialize")); + let payload_b64 = b64(&serde_json::to_vec(&payload).expect("payload should serialize")); + let fake_jwt = format!("{header_b64}.{payload_b64}.sig"); + + AuthDotJson { + auth_mode: Some(AuthMode::Chatgpt), + openai_api_key: None, + tokens: Some(TokenData { + id_token: parse_chatgpt_jwt_claims(&fake_jwt).expect("fake jwt should parse"), + access_token: access_token.to_string(), + refresh_token: "refresh-token".to_string(), + account_id: Some("account_id".to_string()), + }), + last_refresh: Some(Utc::now()), + } + } + + #[tokio::test] + async fn connect_remote_control_websocket_includes_http_error_details() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let expected_error = format!( + "failed to connect app-server remote control websocket `{}`: HTTP error: 503 Service Unavailable, request-id: , cf-ray: , body: upstream unavailable", + remote_control_target.websocket_url + ); + let server_task = tokio::spawn(async move { + let (stream, request_line) = accept_http_request(&listener).await; + assert_eq!( + request_line, + "GET /backend-api/wham/remote/control/server HTTP/1.1" + ); + respond_with_status_and_headers( + stream, + "503 Service Unavailable", + &[("x-trace-id", "trace-503"), ("x-region", "us-east-1")], + "upstream unavailable", + ) + .await; + }); + let codex_home = TempDir::new().expect("temp dir should create"); + let state_db = remote_control_state_runtime(&codex_home).await; + let auth_manager = remote_control_auth_manager(); + let mut auth_recovery = auth_manager.unauthorized_recovery(); + let mut enrollment = Some(RemoteControlEnrollment { + account_id: "account_id".to_string(), + environment_id: "env_test".to_string(), + server_id: "srv_e_test".to_string(), + server_name: "test-server".to_string(), + }); + + let err = match connect_remote_control_websocket( + &remote_control_target, + Some(state_db.as_ref()), + &auth_manager, + &mut auth_recovery, + &mut enrollment, + /*subscribe_cursor*/ None, + /*app_server_client_name*/ None, + ) + .await + { + Ok(_) => panic!("http error response should fail the websocket connect"), + Err(err) => err, + }; + + server_task.await.expect("server task should succeed"); + assert_eq!(err.to_string(), expected_error); + } + + #[tokio::test] + async fn connect_remote_control_websocket_recovers_after_unauthorized_reload() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let server_task = tokio::spawn(async move { + let (stream, request_line) = accept_http_request(&listener).await; + assert_eq!( + request_line, + "GET /backend-api/wham/remote/control/server HTTP/1.1" + ); + respond_with_status_and_headers(stream, "401 Unauthorized", &[], "unauthorized").await; + }); + let codex_home = TempDir::new().expect("temp dir should create"); + save_auth( + codex_home.path(), + &remote_control_auth_dot_json("stale-token"), + AuthCredentialsStoreMode::File, + ) + .expect("stale auth should save"); + let state_db = remote_control_state_runtime(&codex_home).await; + let auth_manager = AuthManager::shared( + codex_home.path().to_path_buf(), + /*enable_codex_api_key_env*/ false, + AuthCredentialsStoreMode::File, + ); + let mut auth_recovery = auth_manager.unauthorized_recovery(); + let mut enrollment = Some(RemoteControlEnrollment { + account_id: "account_id".to_string(), + environment_id: "env_test".to_string(), + server_id: "srv_e_test".to_string(), + server_name: "test-server".to_string(), + }); + save_auth( + codex_home.path(), + &remote_control_auth_dot_json("fresh-token"), + AuthCredentialsStoreMode::File, + ) + .expect("fresh auth should save"); + + let err = connect_remote_control_websocket( + &remote_control_target, + Some(state_db.as_ref()), + &auth_manager, + &mut auth_recovery, + &mut enrollment, + /*subscribe_cursor*/ None, + /*app_server_client_name*/ None, + ) + .await + .expect_err("unauthorized response should fail the websocket connect"); + + server_task.await.expect("server task should succeed"); + assert_eq!( + err.to_string(), + "remote control websocket auth failed with HTTP 401 Unauthorized; retrying after auth recovery" + ); + assert_eq!( + auth_manager + .auth() + .await + .expect("auth should remain available") + .get_token() + .expect("token should be readable"), + "fresh-token" + ); + } + + #[tokio::test] + async fn connect_remote_control_websocket_recovers_after_unauthorized_enrollment() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let enroll_url = remote_control_target.enroll_url.clone(); + let server_task = tokio::spawn(async move { + let (stream, request_line) = accept_http_request(&listener).await; + assert_eq!( + request_line, + "POST /backend-api/wham/remote/control/server/enroll HTTP/1.1" + ); + respond_with_status_and_headers(stream, "401 Unauthorized", &[], "unauthorized").await; + }); + let codex_home = TempDir::new().expect("temp dir should create"); + save_auth( + codex_home.path(), + &remote_control_auth_dot_json("stale-token"), + AuthCredentialsStoreMode::File, + ) + .expect("stale auth should save"); + let state_db = remote_control_state_runtime(&codex_home).await; + let auth_manager = AuthManager::shared( + codex_home.path().to_path_buf(), + /*enable_codex_api_key_env*/ false, + AuthCredentialsStoreMode::File, + ); + let mut auth_recovery = auth_manager.unauthorized_recovery(); + let mut enrollment = None; + save_auth( + codex_home.path(), + &remote_control_auth_dot_json("fresh-token"), + AuthCredentialsStoreMode::File, + ) + .expect("fresh auth should save"); + + let err = connect_remote_control_websocket( + &remote_control_target, + Some(state_db.as_ref()), + &auth_manager, + &mut auth_recovery, + &mut enrollment, + /*subscribe_cursor*/ None, + /*app_server_client_name*/ None, + ) + .await + .expect_err("unauthorized enrollment should fail the websocket connect"); + + server_task.await.expect("server task should succeed"); + assert_eq!( + err.to_string(), + format!( + "remote control server enrollment failed at `{enroll_url}`: HTTP 401 Unauthorized, request-id: , cf-ray: , body: unauthorized; retrying after auth recovery" + ) + ); + assert_eq!( + auth_manager + .auth() + .await + .expect("auth should remain available") + .get_token() + .expect("token should be readable"), + "fresh-token" + ); + } + + #[tokio::test] + async fn run_remote_control_websocket_loop_shutdown_cancels_reconnect_backoff() { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let remote_control_url = remote_control_url_for_listener(&listener); + drop(listener); + + let remote_control_target = + normalize_remote_control_url(&remote_control_url).expect("target should parse"); + let (transport_event_tx, transport_event_rx) = mpsc::channel(1); + drop(transport_event_rx); + let shutdown_token = CancellationToken::new(); + let websocket_task = tokio::spawn({ + let shutdown_token = shutdown_token.clone(); + async move { + RemoteControlWebsocket::new( + remote_control_target, + /*state_db*/ None, + remote_control_auth_manager(), + transport_event_tx, + shutdown_token, + ) + .run(/*app_server_client_name_rx*/ None) + .await + } + }); + + tokio::time::sleep(Duration::from_millis(50)).await; + shutdown_token.cancel(); + + timeout(Duration::from_millis(100), websocket_task) + .await + .expect("shutdown should cancel reconnect backoff") + .expect("websocket task should join"); + } + + #[tokio::test] + async fn run_server_writer_inner_sends_periodic_ping_frames() { + let (client_stream, mut server_stream) = connected_websocket_pair().await; + let (websocket_writer, _websocket_reader) = client_stream.split(); + let (outbound_buffer, used_rx) = BoundedOutboundBuffer::new(); + let state = Arc::new(Mutex::new(WebsocketState { + outbound_buffer, + subscribe_cursor: None, + next_seq_id: 0, + })); + let (_server_event_tx, server_event_rx) = mpsc::channel(super::super::CHANNEL_CAPACITY); + let server_event_rx = Arc::new(Mutex::new(server_event_rx)); + let shutdown_token = CancellationToken::new(); + let writer_task = tokio::spawn(RemoteControlWebsocket::run_server_writer_inner( + state, + server_event_rx, + used_rx, + websocket_writer, + Duration::from_millis(20), + shutdown_token.clone(), + )); + + let message = timeout(Duration::from_secs(5), server_stream.next()) + .await + .expect("ping frame should arrive in time") + .expect("server websocket should stay open") + .expect("ping frame should read"); + assert!(matches!(message, tungstenite::Message::Ping(_))); + + shutdown_token.cancel(); + writer_task + .await + .expect("writer task should join") + .expect("writer should stop cleanly"); + } + + #[tokio::test] + async fn run_websocket_reader_inner_times_out_without_pong_frames() { + let (client_stream, _server_stream) = connected_websocket_pair().await; + let (_websocket_writer, websocket_reader) = client_stream.split(); + let (outbound_buffer, _used_rx) = BoundedOutboundBuffer::new(); + let state = Arc::new(Mutex::new(WebsocketState { + outbound_buffer, + subscribe_cursor: None, + next_seq_id: 0, + })); + let (server_event_tx, _server_event_rx) = mpsc::channel(super::super::CHANNEL_CAPACITY); + let (transport_event_tx, _transport_event_rx) = + mpsc::channel(super::super::CHANNEL_CAPACITY); + let shutdown_token = CancellationToken::new(); + let client_tracker = Arc::new(Mutex::new(ClientTracker::new( + server_event_tx, + transport_event_tx, + &shutdown_token, + ))); + + let err = timeout( + Duration::from_secs(5), + RemoteControlWebsocket::run_websocket_reader_inner( + client_tracker, + state, + websocket_reader, + Duration::from_millis(100), + shutdown_token, + ), + ) + .await + .expect("reader should time out waiting for pong") + .expect_err("missing pong should fail the websocket reader"); + + assert_eq!(err.kind(), ErrorKind::TimedOut); + assert_eq!(err.to_string(), "remote control websocket pong timeout"); + } + + #[test] + fn outbound_buffer_acks_by_client_id_across_stream_ids() { + let (mut outbound_buffer, used_rx) = BoundedOutboundBuffer::new(); + let client_1 = ClientId("client-1".to_string()); + let client_2 = ClientId("client-2".to_string()); + + outbound_buffer.insert(&server_envelope( + &client_1, + "stream-1", + /*seq_id*/ 0, + "first-client-old-stream", + )); + outbound_buffer.insert(&server_envelope( + &client_2, + "stream-1", + /*seq_id*/ 1, + "second-client", + )); + outbound_buffer.insert(&server_envelope( + &client_1, + "stream-2", + /*seq_id*/ 2, + "first-client-new-stream", + )); + + outbound_buffer.ack(&client_1, /*acked_seq_id*/ 2); + + let mut retained = outbound_buffer + .server_envelopes() + .map(|server_envelope| { + ( + server_envelope.client_id.0.as_str(), + server_envelope.stream_id.0.as_str(), + server_envelope.seq_id, + ) + }) + .collect::>(); + retained.sort_unstable(); + assert_eq!(retained, vec![("client-2", "stream-1", 1)]); + assert_eq!(*used_rx.borrow(), 1); + } + + #[test] + fn outbound_buffer_retains_unacked_messages_until_ack_advances() { + let (mut outbound_buffer, used_rx) = BoundedOutboundBuffer::new(); + let client_1 = ClientId("client-1".to_string()); + let client_2 = ClientId("client-2".to_string()); + + outbound_buffer.insert(&server_envelope( + &client_1, + "stream-1", + /*seq_id*/ 0, + "first-old", + )); + outbound_buffer.insert(&server_envelope( + &client_1, + "stream-2", + /*seq_id*/ 1, + "first-new", + )); + outbound_buffer.insert(&server_envelope( + &client_2, "stream-1", /*seq_id*/ 2, "second", + )); + + outbound_buffer.ack(&client_1, /*acked_seq_id*/ 0); + + let mut retained = outbound_buffer + .server_envelopes() + .map(|server_envelope| { + ( + server_envelope.client_id.0.as_str(), + server_envelope.stream_id.0.as_str(), + server_envelope.seq_id, + ) + }) + .collect::>(); + retained.sort_unstable(); + assert_eq!( + retained, + vec![("client-1", "stream-2", 1), ("client-2", "stream-1", 2)] + ); + assert_eq!(*used_rx.borrow(), 2); + } + + fn server_envelope( + client_id: &ClientId, + stream_id: &str, + seq_id: u64, + summary: &str, + ) -> ServerEnvelope { + ServerEnvelope { + event: ServerEvent::ServerMessage { + message: Box::new(OutgoingMessage::AppServerNotification( + ServerNotification::ConfigWarning(ConfigWarningNotification { + summary: summary.to_string(), + details: None, + path: None, + range: None, + }), + )), + }, + client_id: client_id.clone(), + stream_id: StreamId(stream_id.to_string()), + seq_id, + } + } + + async fn accept_http_request(listener: &TcpListener) -> (TcpStream, String) { + let (stream, _) = timeout(Duration::from_secs(5), listener.accept()) + .await + .expect("HTTP request should arrive in time") + .expect("listener accept should succeed"); + let mut reader = BufReader::new(stream); + + let mut request_line = String::new(); + reader + .read_line(&mut request_line) + .await + .expect("request line should read"); + loop { + let mut line = String::new(); + reader + .read_line(&mut line) + .await + .expect("header line should read"); + if line == "\r\n" { + break; + } + } + + ( + reader.into_inner(), + request_line.trim_end_matches("\r\n").to_string(), + ) + } + + async fn connected_websocket_pair() -> ( + WebSocketStream>, + WebSocketStream, + ) { + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("listener should bind"); + let connect_task = tokio::spawn(connect_async(format!( + "ws://{}", + listener + .local_addr() + .expect("listener should have a local addr") + ))); + let (server_stream, _) = listener + .accept() + .await + .expect("server should accept client"); + let server_stream = accept_async(server_stream) + .await + .expect("server websocket handshake should succeed"); + let (client_stream, _) = connect_task + .await + .expect("client connect task should join") + .expect("client websocket handshake should succeed"); + + (client_stream, server_stream) + } + + async fn respond_with_status_and_headers( + mut stream: TcpStream, + status: &str, + headers: &[(&str, &str)], + body: &str, + ) { + let extra_headers = headers + .iter() + .map(|(name, value)| format!("{name}: {value}\r\n")) + .collect::(); + let response = format!( + "HTTP/1.1 {status}\r\ncontent-type: text/plain\r\ncontent-length: {}\r\nconnection: close\r\n{extra_headers}\r\n{body}", + body.len(), + ); + stream + .write_all(response.as_bytes()) + .await + .expect("response should write"); + stream.flush().await.expect("response should flush"); + } +} diff --git a/codex-rs/app-server/src/transport/stdio.rs b/codex-rs/app-server/src/transport/stdio.rs index 4f2bf267455b..20eab025fb72 100644 --- a/codex-rs/app-server/src/transport/stdio.rs +++ b/codex-rs/app-server/src/transport/stdio.rs @@ -1,9 +1,12 @@ use super::CHANNEL_CAPACITY; use super::TransportEvent; use super::forward_incoming_message; +use super::next_connection_id; use super::serialize_outgoing_message; -use crate::outgoing_message::ConnectionId; use crate::outgoing_message::QueuedOutgoingMessage; +use codex_app_server_protocol::InitializeParams; +use codex_app_server_protocol::JSONRPCMessage; +use codex_app_server_protocol::JSONRPCRequest; use std::io::ErrorKind; use std::io::Result as IoResult; use tokio::io; @@ -11,6 +14,7 @@ use tokio::io::AsyncBufReadExt; use tokio::io::AsyncWriteExt; use tokio::io::BufReader; use tokio::sync::mpsc; +use tokio::sync::oneshot; use tokio::task::JoinHandle; use tracing::debug; use tracing::error; @@ -19,8 +23,9 @@ use tracing::info; pub(crate) async fn start_stdio_connection( transport_event_tx: mpsc::Sender, stdio_handles: &mut Vec>, + initialize_client_name_tx: oneshot::Sender, ) -> IoResult<()> { - let connection_id = ConnectionId(0); + let connection_id = next_connection_id(); let (writer_tx, mut writer_rx) = mpsc::channel::(CHANNEL_CAPACITY); let writer_tx_for_reader = writer_tx.clone(); transport_event_tx @@ -37,10 +42,16 @@ pub(crate) async fn start_stdio_connection( let stdin = io::stdin(); let reader = BufReader::new(stdin); let mut lines = reader.lines(); + let mut initialize_client_name_tx = Some(initialize_client_name_tx); loop { match lines.next_line().await { Ok(Some(line)) => { + if let Some(client_name) = stdio_initialize_client_name(&line) + && let Some(initialize_client_name_tx) = initialize_client_name_tx.take() + { + let _ = initialize_client_name_tx.send(client_name); + } if !forward_incoming_message( &transport_event_tx_for_reader, &writer_tx_for_reader, @@ -86,3 +97,15 @@ pub(crate) async fn start_stdio_connection( Ok(()) } + +fn stdio_initialize_client_name(line: &str) -> Option { + let message = serde_json::from_str::(line).ok()?; + let JSONRPCMessage::Request(JSONRPCRequest { method, params, .. }) = message else { + return None; + }; + if method != "initialize" { + return None; + } + let params = serde_json::from_value::(params?).ok()?; + Some(params.client_info.name) +} diff --git a/codex-rs/app-server/src/transport/websocket.rs b/codex-rs/app-server/src/transport/websocket.rs index 05dfe24b05a6..41b138e216dd 100644 --- a/codex-rs/app-server/src/transport/websocket.rs +++ b/codex-rs/app-server/src/transport/websocket.rs @@ -4,6 +4,7 @@ use super::auth::WebsocketAuthPolicy; use super::auth::authorize_upgrade; use super::auth::should_warn_about_unauthenticated_non_loopback_listener; use super::forward_incoming_message; +use super::next_connection_id; use super::serialize_outgoing_message; use crate::outgoing_message::ConnectionId; use crate::outgoing_message::QueuedOutgoingMessage; @@ -32,8 +33,6 @@ use owo_colors::Style; use std::io::Result as IoResult; use std::net::SocketAddr; use std::sync::Arc; -use std::sync::atomic::AtomicU64; -use std::sync::atomic::Ordering; use tokio::net::TcpListener; use tokio::sync::mpsc; use tokio::task::JoinHandle; @@ -75,7 +74,6 @@ fn print_websocket_startup_banner(addr: SocketAddr) { #[derive(Clone)] struct WebSocketListenerState { transport_event_tx: mpsc::Sender, - connection_counter: Arc, auth_policy: Arc, } @@ -113,7 +111,7 @@ async fn websocket_upgrade_handler( ); return (err.status_code(), err.message()).into_response(); } - let connection_id = ConnectionId(state.connection_counter.fetch_add(1, Ordering::Relaxed)); + let connection_id = next_connection_id(); info!(%peer_addr, "websocket client connected"); websocket .on_upgrade(move |stream| async move { @@ -146,7 +144,6 @@ pub(crate) async fn start_websocket_acceptor( .layer(middleware::from_fn(reject_requests_with_origin_header)) .with_state(WebSocketListenerState { transport_event_tx, - connection_counter: Arc::new(AtomicU64::new(1)), auth_policy: Arc::new(auth_policy), }); let server = axum::serve( diff --git a/codex-rs/cli/src/main.rs b/codex-rs/cli/src/main.rs index 1c4f5068fdef..3007048a2f35 100644 --- a/codex-rs/cli/src/main.rs +++ b/codex-rs/cli/src/main.rs @@ -346,7 +346,7 @@ struct AppServerCommand { subcommand: Option, /// Transport endpoint URL. Supported values: `stdio://` (default), - /// `ws://IP:PORT`. + /// `ws://IP:PORT`, `off`. #[arg( long = "listen", value_name = "URL", @@ -1993,6 +1993,12 @@ mod tests { ); } + #[test] + fn app_server_listen_off_parses() { + let app_server = app_server_from_args(["codex", "app-server", "--listen", "off"].as_ref()); + assert_eq!(app_server.listen, codex_app_server::AppServerTransport::Off); + } + #[test] fn app_server_listen_invalid_url_fails_to_parse() { let parse_result = diff --git a/codex-rs/core/config.schema.json b/codex-rs/core/config.schema.json index 7f5be535eaba..5b8e5dfc1658 100644 --- a/codex-rs/core/config.schema.json +++ b/codex-rs/core/config.schema.json @@ -437,6 +437,9 @@ "realtime_conversation": { "type": "boolean" }, + "remote_control": { + "type": "boolean" + }, "remote_models": { "type": "boolean" }, @@ -2144,6 +2147,9 @@ "realtime_conversation": { "type": "boolean" }, + "remote_control": { + "type": "boolean" + }, "remote_models": { "type": "boolean" }, diff --git a/codex-rs/features/src/lib.rs b/codex-rs/features/src/lib.rs index c87b6dc9ff3e..959268b61cfa 100644 --- a/codex-rs/features/src/lib.rs +++ b/codex-rs/features/src/lib.rs @@ -176,6 +176,8 @@ pub enum Feature { FastMode, /// Enable experimental realtime voice conversation mode in the TUI. RealtimeConversation, + /// Connect app-server to the ChatGPT remote control service. + RemoteControl, /// Removed compatibility flag. The TUI now always uses the app-server implementation. TuiAppServer, /// Prevent idle system sleep while a turn is actively running. @@ -825,6 +827,12 @@ pub const FEATURES: &[FeatureSpec] = &[ stage: Stage::UnderDevelopment, default_enabled: false, }, + FeatureSpec { + id: Feature::RemoteControl, + key: "remote_control", + stage: Stage::UnderDevelopment, + default_enabled: false, + }, FeatureSpec { id: Feature::TuiAppServer, key: "tui_app_server", diff --git a/codex-rs/features/src/tests.rs b/codex-rs/features/src/tests.rs index 0777262065d0..3f4a370cb284 100644 --- a/codex-rs/features/src/tests.rs +++ b/codex-rs/features/src/tests.rs @@ -165,6 +165,12 @@ fn image_detail_original_feature_is_under_development() { assert_eq!(Feature::ImageDetailOriginal.default_enabled(), false); } +#[test] +fn remote_control_is_under_development() { + assert_eq!(Feature::RemoteControl.stage(), Stage::UnderDevelopment); + assert_eq!(Feature::RemoteControl.default_enabled(), false); +} + #[test] fn collab_is_legacy_alias_for_multi_agent() { assert_eq!(feature_for_key("multi_agent"), Some(Feature::Collab)); diff --git a/codex-rs/state/migrations/0024_remote_control_enrollments.sql b/codex-rs/state/migrations/0024_remote_control_enrollments.sql new file mode 100644 index 000000000000..970db9ef2010 --- /dev/null +++ b/codex-rs/state/migrations/0024_remote_control_enrollments.sql @@ -0,0 +1,10 @@ +CREATE TABLE remote_control_enrollments ( + websocket_url TEXT NOT NULL, + account_id TEXT NOT NULL, + app_server_client_name TEXT NOT NULL, + server_id TEXT NOT NULL, + environment_id TEXT NOT NULL, + server_name TEXT NOT NULL, + updated_at INTEGER NOT NULL, + PRIMARY KEY (websocket_url, account_id, app_server_client_name) +); diff --git a/codex-rs/state/src/lib.rs b/codex-rs/state/src/lib.rs index ffaa1637e26a..efad7651f5ef 100644 --- a/codex-rs/state/src/lib.rs +++ b/codex-rs/state/src/lib.rs @@ -46,6 +46,7 @@ pub use model::Stage1StartupClaimParams; pub use model::ThreadMetadata; pub use model::ThreadMetadataBuilder; pub use model::ThreadsPage; +pub use runtime::RemoteControlEnrollmentRecord; pub use runtime::logs_db_filename; pub use runtime::logs_db_path; pub use runtime::state_db_filename; diff --git a/codex-rs/state/src/runtime.rs b/codex-rs/state/src/runtime.rs index d3d81d87d361..f71b6adf0506 100644 --- a/codex-rs/state/src/runtime.rs +++ b/codex-rs/state/src/runtime.rs @@ -54,10 +54,13 @@ mod agent_jobs; mod backfill; mod logs; mod memories; +mod remote_control; #[cfg(test)] mod test_support; mod threads; +pub use remote_control::RemoteControlEnrollmentRecord; + // "Partition" is the retained-log-content bucket we cap at 10 MiB: // - one bucket per non-null thread_id // - one bucket per threadless (thread_id IS NULL) non-null process_uuid diff --git a/codex-rs/state/src/runtime/remote_control.rs b/codex-rs/state/src/runtime/remote_control.rs new file mode 100644 index 000000000000..fa0b1823f8d1 --- /dev/null +++ b/codex-rs/state/src/runtime/remote_control.rs @@ -0,0 +1,283 @@ +use super::*; + +const REMOTE_CONTROL_APP_SERVER_CLIENT_NAME_NONE: &str = ""; + +/// Persisted remote-control server enrollment, including the lookup key. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct RemoteControlEnrollmentRecord { + pub websocket_url: String, + pub account_id: String, + pub app_server_client_name: Option, + pub server_id: String, + pub environment_id: String, + pub server_name: String, +} + +fn remote_control_app_server_client_name_key(app_server_client_name: Option<&str>) -> &str { + app_server_client_name.unwrap_or(REMOTE_CONTROL_APP_SERVER_CLIENT_NAME_NONE) +} + +fn app_server_client_name_from_key(app_server_client_name: String) -> Option { + if app_server_client_name.is_empty() { + None + } else { + Some(app_server_client_name) + } +} + +impl StateRuntime { + pub async fn get_remote_control_enrollment( + &self, + websocket_url: &str, + account_id: &str, + app_server_client_name: Option<&str>, + ) -> anyhow::Result> { + let row = sqlx::query( + r#" +SELECT websocket_url, account_id, app_server_client_name, server_id, environment_id, server_name +FROM remote_control_enrollments +WHERE websocket_url = ? AND account_id = ? AND app_server_client_name = ? + "#, + ) + .bind(websocket_url) + .bind(account_id) + .bind(remote_control_app_server_client_name_key( + app_server_client_name, + )) + .fetch_optional(self.pool.as_ref()) + .await?; + + row.map(|row| { + let app_server_client_name: String = row.try_get("app_server_client_name")?; + Ok(RemoteControlEnrollmentRecord { + websocket_url: row.try_get("websocket_url")?, + account_id: row.try_get("account_id")?, + app_server_client_name: app_server_client_name_from_key(app_server_client_name), + server_id: row.try_get("server_id")?, + environment_id: row.try_get("environment_id")?, + server_name: row.try_get("server_name")?, + }) + }) + .transpose() + } + + pub async fn upsert_remote_control_enrollment( + &self, + enrollment: &RemoteControlEnrollmentRecord, + ) -> anyhow::Result<()> { + sqlx::query( + r#" +INSERT INTO remote_control_enrollments ( + websocket_url, + account_id, + app_server_client_name, + server_id, + environment_id, + server_name, + updated_at +) VALUES (?, ?, ?, ?, ?, ?, ?) +ON CONFLICT(websocket_url, account_id, app_server_client_name) DO UPDATE SET + server_id = excluded.server_id, + environment_id = excluded.environment_id, + server_name = excluded.server_name, + updated_at = excluded.updated_at + "#, + ) + .bind(&enrollment.websocket_url) + .bind(&enrollment.account_id) + .bind(remote_control_app_server_client_name_key( + enrollment.app_server_client_name.as_deref(), + )) + .bind(&enrollment.server_id) + .bind(&enrollment.environment_id) + .bind(&enrollment.server_name) + .bind(Utc::now().timestamp()) + .execute(self.pool.as_ref()) + .await?; + Ok(()) + } + + pub async fn delete_remote_control_enrollment( + &self, + websocket_url: &str, + account_id: &str, + app_server_client_name: Option<&str>, + ) -> anyhow::Result { + let result = sqlx::query( + r#" +DELETE FROM remote_control_enrollments +WHERE websocket_url = ? AND account_id = ? AND app_server_client_name = ? + "#, + ) + .bind(websocket_url) + .bind(account_id) + .bind(remote_control_app_server_client_name_key( + app_server_client_name, + )) + .execute(self.pool.as_ref()) + .await?; + Ok(result.rows_affected()) + } +} + +#[cfg(test)] +mod tests { + use super::RemoteControlEnrollmentRecord; + use super::StateRuntime; + use super::test_support::unique_temp_dir; + use pretty_assertions::assert_eq; + + #[tokio::test] + async fn remote_control_enrollment_round_trips_by_target_and_account() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string()) + .await + .expect("initialize runtime"); + + runtime + .upsert_remote_control_enrollment(&RemoteControlEnrollmentRecord { + websocket_url: "wss://example.com/backend-api/wham/remote/control/server" + .to_string(), + account_id: "account-a".to_string(), + app_server_client_name: Some("desktop-client".to_string()), + server_id: "srv_e_first".to_string(), + environment_id: "env_first".to_string(), + server_name: "first-server".to_string(), + }) + .await + .expect("insert first enrollment"); + runtime + .upsert_remote_control_enrollment(&RemoteControlEnrollmentRecord { + websocket_url: "wss://example.com/backend-api/wham/remote/control/server" + .to_string(), + account_id: "account-b".to_string(), + app_server_client_name: Some("desktop-client".to_string()), + server_id: "srv_e_second".to_string(), + environment_id: "env_second".to_string(), + server_name: "second-server".to_string(), + }) + .await + .expect("insert second enrollment"); + + assert_eq!( + runtime + .get_remote_control_enrollment( + "wss://example.com/backend-api/wham/remote/control/server", + "account-a", + Some("desktop-client"), + ) + .await + .expect("load first enrollment"), + Some(RemoteControlEnrollmentRecord { + websocket_url: "wss://example.com/backend-api/wham/remote/control/server" + .to_string(), + account_id: "account-a".to_string(), + app_server_client_name: Some("desktop-client".to_string()), + server_id: "srv_e_first".to_string(), + environment_id: "env_first".to_string(), + server_name: "first-server".to_string(), + }) + ); + assert_eq!( + runtime + .get_remote_control_enrollment( + "wss://example.com/backend-api/wham/remote/control/server", + "account-missing", + Some("desktop-client"), + ) + .await + .expect("load missing enrollment"), + None + ); + assert_eq!( + runtime + .get_remote_control_enrollment( + "wss://example.com/backend-api/wham/remote/control/server", + "account-a", + Some("other-client"), + ) + .await + .expect("load wrong client enrollment"), + None + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } + + #[tokio::test] + async fn delete_remote_control_enrollment_removes_only_matching_entry() { + let codex_home = unique_temp_dir(); + let runtime = StateRuntime::init(codex_home.clone(), "test-provider".to_string()) + .await + .expect("initialize runtime"); + + runtime + .upsert_remote_control_enrollment(&RemoteControlEnrollmentRecord { + websocket_url: "wss://example.com/backend-api/wham/remote/control/server" + .to_string(), + account_id: "account-a".to_string(), + app_server_client_name: None, + server_id: "srv_e_first".to_string(), + environment_id: "env_first".to_string(), + server_name: "first-server".to_string(), + }) + .await + .expect("insert first enrollment"); + runtime + .upsert_remote_control_enrollment(&RemoteControlEnrollmentRecord { + websocket_url: "wss://example.com/backend-api/wham/remote/control/server" + .to_string(), + account_id: "account-b".to_string(), + app_server_client_name: None, + server_id: "srv_e_second".to_string(), + environment_id: "env_second".to_string(), + server_name: "second-server".to_string(), + }) + .await + .expect("insert second enrollment"); + + assert_eq!( + runtime + .delete_remote_control_enrollment( + "wss://example.com/backend-api/wham/remote/control/server", + "account-a", + /*app_server_client_name*/ None, + ) + .await + .expect("delete first enrollment"), + 1 + ); + assert_eq!( + runtime + .get_remote_control_enrollment( + "wss://example.com/backend-api/wham/remote/control/server", + "account-a", + /*app_server_client_name*/ None, + ) + .await + .expect("load deleted enrollment"), + None + ); + assert_eq!( + runtime + .get_remote_control_enrollment( + "wss://example.com/backend-api/wham/remote/control/server", + "account-b", + /*app_server_client_name*/ None, + ) + .await + .expect("load retained enrollment"), + Some(RemoteControlEnrollmentRecord { + websocket_url: "wss://example.com/backend-api/wham/remote/control/server" + .to_string(), + account_id: "account-b".to_string(), + app_server_client_name: None, + server_id: "srv_e_second".to_string(), + environment_id: "env_second".to_string(), + server_name: "second-server".to_string(), + }) + ); + + let _ = tokio::fs::remove_dir_all(codex_home).await; + } +}