diff --git a/codex-rs/app-server/src/fs_api.rs b/codex-rs/app-server/src/fs_api.rs index 9baa2b1dcec7..3ce485aaa0b8 100644 --- a/codex-rs/app-server/src/fs_api.rs +++ b/codex-rs/app-server/src/fs_api.rs @@ -20,8 +20,8 @@ use codex_app_server_protocol::FsWriteFileResponse; use codex_app_server_protocol::JSONRPCErrorError; use codex_exec_server::CopyOptions; use codex_exec_server::CreateDirectoryOptions; -use codex_exec_server::Environment; use codex_exec_server::ExecutorFileSystem; +use codex_exec_server::LocalFileSystem; use codex_exec_server::RemoveOptions; use std::io; use std::sync::Arc; @@ -34,7 +34,7 @@ pub(crate) struct FsApi { impl Default for FsApi { fn default() -> Self { Self { - file_system: Arc::new(Environment::default().get_filesystem()), + file_system: Arc::new(LocalFileSystem), } } } diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index a7680e73e8db..7b6a2810ebda 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -25,9 +25,9 @@ use tracing::debug; use tracing::warn; use crate::client_api::ExecServerClientConnectOptions; -use crate::client_api::ExecServerEvent; use crate::client_api::RemoteExecServerConnectArgs; use crate::connection::JsonRpcConnection; +use crate::process::ExecServerEvent; use crate::protocol::EXEC_EXITED_METHOD; use crate::protocol::EXEC_METHOD; use crate::protocol::EXEC_OUTPUT_DELTA_METHOD; diff --git a/codex-rs/exec-server/src/client_api.rs b/codex-rs/exec-server/src/client_api.rs index 962d3ba36483..6e89763416f3 100644 --- a/codex-rs/exec-server/src/client_api.rs +++ b/codex-rs/exec-server/src/client_api.rs @@ -1,8 +1,5 @@ use std::time::Duration; -use crate::protocol::ExecExitedNotification; -use crate::protocol::ExecOutputDeltaNotification; - /// Connection options for any exec-server client transport. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ExecServerClientConnectOptions { @@ -18,10 +15,3 @@ pub struct RemoteExecServerConnectArgs { pub connect_timeout: Duration, pub initialize_timeout: Duration, } - -/// Connection-level server events. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum ExecServerEvent { - OutputDelta(ExecOutputDeltaNotification), - Exited(ExecExitedNotification), -} diff --git a/codex-rs/exec-server/src/environment.rs b/codex-rs/exec-server/src/environment.rs index c8635ec03a0b..641d4ed2622a 100644 --- a/codex-rs/exec-server/src/environment.rs +++ b/codex-rs/exec-server/src/environment.rs @@ -1,13 +1,18 @@ use crate::ExecServerClient; +use crate::ExecServerClientConnectOptions; use crate::ExecServerError; use crate::RemoteExecServerConnectArgs; use crate::fs; use crate::fs::ExecutorFileSystem; +use crate::local_process::LocalExecProcess; +use crate::process::ExecProcess; +use crate::remote_process::RemoteExecProcess; +use std::sync::Arc; -#[derive(Clone, Default)] +#[derive(Clone)] pub struct Environment { experimental_exec_server_url: Option, - remote_exec_server_client: Option, + executor: Arc, } impl std::fmt::Debug for Environment { @@ -19,7 +24,7 @@ impl std::fmt::Debug for Environment { ) .field( "has_remote_exec_server_client", - &self.remote_exec_server_client.is_some(), + &self.experimental_exec_server_url.is_some(), ) .finish() } @@ -29,22 +34,21 @@ impl Environment { pub async fn create( experimental_exec_server_url: Option, ) -> Result { - let remote_exec_server_client = + let executor: Arc = if let Some(websocket_url) = experimental_exec_server_url.as_deref() { - Some( - ExecServerClient::connect_websocket(RemoteExecServerConnectArgs::new( - websocket_url.to_string(), - "codex-core".to_string(), - )) - .await?, - ) + let client = ExecServerClient::connect_websocket(RemoteExecServerConnectArgs::new( + websocket_url.to_string(), + "codex-core".to_string(), + )) + .await?; + Arc::new(RemoteExecProcess::new(client)) } else { - None + Arc::new(LocalExecProcess::new()) }; Ok(Self { experimental_exec_server_url, - remote_exec_server_client, + executor, }) } @@ -52,8 +56,8 @@ impl Environment { self.experimental_exec_server_url.as_deref() } - pub fn remote_exec_server_client(&self) -> Option<&ExecServerClient> { - self.remote_exec_server_client.as_ref() + pub fn get_executor(&self) -> Arc { + Arc::clone(&self.executor) } pub fn get_filesystem(&self) -> impl ExecutorFileSystem + use<> { @@ -61,6 +65,12 @@ impl Environment { } } +impl crate::ExecutorEnvironment for Environment { + fn get_executor(&self) -> Arc { + self.get_executor() + } +} + #[cfg(test)] mod tests { use super::Environment; @@ -71,6 +81,5 @@ mod tests { let environment = Environment::create(None).await.expect("create environment"); assert_eq!(environment.experimental_exec_server_url(), None); - assert!(environment.remote_exec_server_client().is_none()); } } diff --git a/codex-rs/exec-server/src/fs.rs b/codex-rs/exec-server/src/fs.rs index 82e0b8e6e6bc..c5999c1489f8 100644 --- a/codex-rs/exec-server/src/fs.rs +++ b/codex-rs/exec-server/src/fs.rs @@ -72,7 +72,7 @@ pub trait ExecutorFileSystem: Send + Sync { } #[derive(Clone, Default)] -pub(crate) struct LocalFileSystem; +pub struct LocalFileSystem; #[async_trait] impl ExecutorFileSystem for LocalFileSystem { diff --git a/codex-rs/exec-server/src/lib.rs b/codex-rs/exec-server/src/lib.rs index 3c50d0ec5911..fa20ffd18e28 100644 --- a/codex-rs/exec-server/src/lib.rs +++ b/codex-rs/exec-server/src/lib.rs @@ -3,14 +3,16 @@ mod client_api; mod connection; mod environment; mod fs; +mod local_process; +mod process; mod protocol; +mod remote_process; mod rpc; mod server; pub use client::ExecServerClient; pub use client::ExecServerError; pub use client_api::ExecServerClientConnectOptions; -pub use client_api::ExecServerEvent; pub use client_api::RemoteExecServerConnectArgs; pub use codex_app_server_protocol::FsCopyParams; pub use codex_app_server_protocol::FsCopyResponse; @@ -33,8 +35,11 @@ pub use fs::CreateDirectoryOptions; pub use fs::ExecutorFileSystem; pub use fs::FileMetadata; pub use fs::FileSystemResult; +pub use fs::LocalFileSystem; pub use fs::ReadDirectoryEntry; pub use fs::RemoveOptions; +pub use process::ExecProcess; +pub use process::ExecServerEvent; pub use protocol::ExecExitedNotification; pub use protocol::ExecOutputDeltaNotification; pub use protocol::ExecOutputStream; @@ -50,5 +55,8 @@ pub use protocol::WriteParams; pub use protocol::WriteResponse; pub use server::DEFAULT_LISTEN_URL; pub use server::ExecServerListenUrlParseError; +pub trait ExecutorEnvironment: Send + Sync { + fn get_executor(&self) -> std::sync::Arc; +} pub use server::run_main; pub use server::run_main_with_listen_url; diff --git a/codex-rs/exec-server/src/local_process.rs b/codex-rs/exec-server/src/local_process.rs new file mode 100644 index 000000000000..d78725af5e98 --- /dev/null +++ b/codex-rs/exec-server/src/local_process.rs @@ -0,0 +1,146 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use codex_app_server_protocol::JSONRPCErrorError; +use serde_json::Value; +use tokio::sync::broadcast; +use tokio::sync::mpsc; + +use crate::ExecProcess; +use crate::ExecServerError; +use crate::ExecServerEvent; +use crate::process::ExecServerEvent::Exited; +use crate::process::ExecServerEvent::OutputDelta; +use crate::protocol::EXEC_EXITED_METHOD; +use crate::protocol::EXEC_OUTPUT_DELTA_METHOD; +use crate::protocol::ExecExitedNotification; +use crate::protocol::ExecOutputDeltaNotification; +use crate::protocol::ExecParams; +use crate::protocol::ExecResponse; +use crate::protocol::ReadParams; +use crate::protocol::ReadResponse; +use crate::protocol::TerminateParams; +use crate::protocol::TerminateResponse; +use crate::protocol::WriteParams; +use crate::protocol::WriteResponse; +use crate::rpc::RpcNotificationSender; +use crate::rpc::RpcServerOutboundMessage; +use crate::server::ProcessHandler; + +#[derive(Clone)] +pub(crate) struct LocalExecProcess { + inner: Arc, +} + +struct Inner { + process_handler: ProcessHandler, + events_tx: broadcast::Sender, + reader_task: tokio::task::JoinHandle<()>, +} + +impl Drop for Inner { + fn drop(&mut self) { + if let Ok(handle) = tokio::runtime::Handle::try_current() { + let process_handler = self.process_handler.clone(); + handle.spawn(async move { + process_handler.shutdown().await; + }); + } + self.reader_task.abort(); + } +} + +impl LocalExecProcess { + pub(crate) fn new() -> Self { + let (outgoing_tx, mut outgoing_rx) = mpsc::channel::(256); + let process_handler = ProcessHandler::new(RpcNotificationSender::new(outgoing_tx)); + let events_tx = broadcast::channel(256).0; + let events_tx_for_task = events_tx.clone(); + let reader_task = tokio::spawn(async move { + while let Some(message) = outgoing_rx.recv().await { + if let RpcServerOutboundMessage::Notification(notification) = message { + match notification.method.as_str() { + EXEC_OUTPUT_DELTA_METHOD => { + if let Ok(params) = serde_json::from_value::( + notification.params.unwrap_or(Value::Null), + ) { + let _ = events_tx_for_task.send(OutputDelta(params)); + } + } + EXEC_EXITED_METHOD => { + if let Ok(params) = serde_json::from_value::( + notification.params.unwrap_or(Value::Null), + ) { + let _ = events_tx_for_task.send(Exited(params)); + } + } + _ => {} + } + } + } + }); + + Self { + inner: Arc::new(Inner { + process_handler, + events_tx, + reader_task, + }), + } + } +} + +#[async_trait] +impl ExecProcess for LocalExecProcess { + async fn start(&self, params: ExecParams) -> Result { + self.inner + .process_handler + .exec(params) + .await + .map_err(map_local_error) + } + + async fn read(&self, params: ReadParams) -> Result { + self.inner + .process_handler + .exec_read(params) + .await + .map_err(map_local_error) + } + + async fn write( + &self, + process_id: &str, + chunk: Vec, + ) -> Result { + self.inner + .process_handler + .exec_write(WriteParams { + process_id: process_id.to_string(), + chunk: chunk.into(), + }) + .await + .map_err(map_local_error) + } + + async fn terminate(&self, process_id: &str) -> Result { + self.inner + .process_handler + .terminate(TerminateParams { + process_id: process_id.to_string(), + }) + .await + .map_err(map_local_error) + } + + fn subscribe_events(&self) -> broadcast::Receiver { + self.inner.events_tx.subscribe() + } +} + +fn map_local_error(error: JSONRPCErrorError) -> ExecServerError { + ExecServerError::Server { + code: error.code, + message: error.message, + } +} diff --git a/codex-rs/exec-server/src/process.rs b/codex-rs/exec-server/src/process.rs new file mode 100644 index 000000000000..b2d743c329c7 --- /dev/null +++ b/codex-rs/exec-server/src/process.rs @@ -0,0 +1,35 @@ +use async_trait::async_trait; +use tokio::sync::broadcast; + +use crate::ExecServerError; +use crate::protocol::ExecExitedNotification; +use crate::protocol::ExecOutputDeltaNotification; +use crate::protocol::ExecParams; +use crate::protocol::ExecResponse; +use crate::protocol::ReadParams; +use crate::protocol::ReadResponse; +use crate::protocol::TerminateResponse; +use crate::protocol::WriteResponse; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ExecServerEvent { + OutputDelta(ExecOutputDeltaNotification), + Exited(ExecExitedNotification), +} + +#[async_trait] +pub trait ExecProcess: Send + Sync { + async fn start(&self, params: ExecParams) -> Result; + + async fn read(&self, params: ReadParams) -> Result; + + async fn write( + &self, + process_id: &str, + chunk: Vec, + ) -> Result; + + async fn terminate(&self, process_id: &str) -> Result; + + fn subscribe_events(&self) -> broadcast::Receiver; +} diff --git a/codex-rs/exec-server/src/remote_process.rs b/codex-rs/exec-server/src/remote_process.rs new file mode 100644 index 000000000000..2e6829923452 --- /dev/null +++ b/codex-rs/exec-server/src/remote_process.rs @@ -0,0 +1,51 @@ +use async_trait::async_trait; +use tokio::sync::broadcast; + +use crate::ExecProcess; +use crate::ExecServerClient; +use crate::ExecServerError; +use crate::ExecServerEvent; +use crate::protocol::ExecParams; +use crate::protocol::ExecResponse; +use crate::protocol::ReadParams; +use crate::protocol::ReadResponse; +use crate::protocol::TerminateResponse; +use crate::protocol::WriteResponse; + +#[derive(Clone)] +pub(crate) struct RemoteExecProcess { + client: ExecServerClient, +} + +impl RemoteExecProcess { + pub(crate) fn new(client: ExecServerClient) -> Self { + Self { client } + } +} + +#[async_trait] +impl ExecProcess for RemoteExecProcess { + async fn start(&self, params: ExecParams) -> Result { + self.client.exec(params).await + } + + async fn read(&self, params: ReadParams) -> Result { + self.client.read(params).await + } + + async fn write( + &self, + process_id: &str, + chunk: Vec, + ) -> Result { + self.client.write(process_id, chunk).await + } + + async fn terminate(&self, process_id: &str) -> Result { + self.client.terminate(process_id).await + } + + fn subscribe_events(&self) -> broadcast::Receiver { + self.client.event_receiver() + } +} diff --git a/codex-rs/exec-server/src/server.rs b/codex-rs/exec-server/src/server.rs index c403b029d702..368407ec8c75 100644 --- a/codex-rs/exec-server/src/server.rs +++ b/codex-rs/exec-server/src/server.rs @@ -1,10 +1,12 @@ mod filesystem; mod handler; +mod process_handler; mod processor; mod registry; mod transport; pub(crate) use handler::ExecServerHandler; +pub(crate) use process_handler::ProcessHandler; pub use transport::DEFAULT_LISTEN_URL; pub use transport::ExecServerListenUrlParseError; diff --git a/codex-rs/exec-server/src/server/filesystem.rs b/codex-rs/exec-server/src/server/filesystem.rs index bc3d22a4da3b..cfd28be7135b 100644 --- a/codex-rs/exec-server/src/server/filesystem.rs +++ b/codex-rs/exec-server/src/server/filesystem.rs @@ -22,8 +22,8 @@ use codex_app_server_protocol::JSONRPCErrorError; use crate::CopyOptions; use crate::CreateDirectoryOptions; -use crate::Environment; use crate::ExecutorFileSystem; +use crate::LocalFileSystem; use crate::RemoveOptions; use crate::rpc::internal_error; use crate::rpc::invalid_request; @@ -36,7 +36,7 @@ pub(crate) struct ExecServerFileSystem { impl Default for ExecServerFileSystem { fn default() -> Self { Self { - file_system: Arc::new(Environment.get_filesystem()), + file_system: Arc::new(LocalFileSystem), } } } diff --git a/codex-rs/exec-server/src/server/handler.rs b/codex-rs/exec-server/src/server/handler.rs index c21aeecb5c2e..ace15994829e 100644 --- a/codex-rs/exec-server/src/server/handler.rs +++ b/codex-rs/exec-server/src/server/handler.rs @@ -1,10 +1,19 @@ -use std::collections::HashMap; -use std::collections::VecDeque; -use std::sync::Arc; use std::sync::atomic::AtomicBool; use std::sync::atomic::Ordering; -use std::time::Duration; +use crate::protocol::ExecParams; +use crate::protocol::ExecResponse; +use crate::protocol::InitializeResponse; +use crate::protocol::ReadParams; +use crate::protocol::ReadResponse; +use crate::protocol::TerminateParams; +use crate::protocol::TerminateResponse; +use crate::protocol::WriteParams; +use crate::protocol::WriteResponse; +use crate::rpc::RpcNotificationSender; +use crate::rpc::invalid_request; +use crate::server::filesystem::ExecServerFileSystem; +use crate::server::process_handler::ProcessHandler; use codex_app_server_protocol::FsCopyParams; use codex_app_server_protocol::FsCopyResponse; use codex_app_server_protocol::FsCreateDirectoryParams; @@ -20,63 +29,10 @@ use codex_app_server_protocol::FsRemoveResponse; use codex_app_server_protocol::FsWriteFileParams; use codex_app_server_protocol::FsWriteFileResponse; use codex_app_server_protocol::JSONRPCErrorError; -use codex_utils_pty::ExecCommandSession; -use codex_utils_pty::TerminalSize; -use tokio::sync::Mutex; -use tokio::sync::Notify; -use tracing::warn; - -use crate::protocol::ExecExitedNotification; -use crate::protocol::ExecOutputDeltaNotification; -use crate::protocol::ExecOutputStream; -use crate::protocol::ExecParams; -use crate::protocol::ExecResponse; -use crate::protocol::InitializeResponse; -use crate::protocol::ProcessOutputChunk; -use crate::protocol::ReadParams; -use crate::protocol::ReadResponse; -use crate::protocol::TerminateParams; -use crate::protocol::TerminateResponse; -use crate::protocol::WriteParams; -use crate::protocol::WriteResponse; -use crate::rpc::RpcNotificationSender; -use crate::rpc::internal_error; -use crate::rpc::invalid_params; -use crate::rpc::invalid_request; -use crate::server::filesystem::ExecServerFileSystem; - -const RETAINED_OUTPUT_BYTES_PER_PROCESS: usize = 1024 * 1024; -#[cfg(test)] -const EXITED_PROCESS_RETENTION: Duration = Duration::from_millis(25); -#[cfg(not(test))] -const EXITED_PROCESS_RETENTION: Duration = Duration::from_secs(30); - -#[derive(Clone)] -struct RetainedOutputChunk { - seq: u64, - stream: ExecOutputStream, - chunk: Vec, -} - -struct RunningProcess { - session: ExecCommandSession, - tty: bool, - output: VecDeque, - retained_bytes: usize, - next_seq: u64, - exit_code: Option, - output_notify: Arc, -} - -enum ProcessEntry { - Starting, - Running(Box), -} pub(crate) struct ExecServerHandler { - notifications: RpcNotificationSender, file_system: ExecServerFileSystem, - processes: Arc>>, + process_handler: ProcessHandler, initialize_requested: AtomicBool, initialized: AtomicBool, } @@ -84,28 +40,15 @@ pub(crate) struct ExecServerHandler { impl ExecServerHandler { pub(crate) fn new(notifications: RpcNotificationSender) -> Self { Self { - notifications, file_system: ExecServerFileSystem::default(), - processes: Arc::new(Mutex::new(HashMap::new())), + process_handler: ProcessHandler::new(notifications), initialize_requested: AtomicBool::new(false), initialized: AtomicBool::new(false), } } pub(crate) async fn shutdown(&self) { - let remaining = { - let mut processes = self.processes.lock().await; - processes - .drain() - .filter_map(|(_, process)| match process { - ProcessEntry::Starting => None, - ProcessEntry::Running(process) => Some(process), - }) - .collect::>() - }; - for process in remaining { - process.session.terminate(); - } + self.process_handler.shutdown().await; } pub(crate) fn initialize(&self) -> Result { @@ -141,104 +84,7 @@ impl ExecServerHandler { pub(crate) async fn exec(&self, params: ExecParams) -> Result { self.require_initialized_for("exec")?; - let process_id = params.process_id.clone(); - - let (program, args) = params - .argv - .split_first() - .ok_or_else(|| invalid_params("argv must not be empty".to_string()))?; - - { - let mut process_map = self.processes.lock().await; - if process_map.contains_key(&process_id) { - return Err(invalid_request(format!( - "process {process_id} already exists" - ))); - } - process_map.insert(process_id.clone(), ProcessEntry::Starting); - } - - let spawned_result = if params.tty { - codex_utils_pty::spawn_pty_process( - program, - args, - params.cwd.as_path(), - ¶ms.env, - ¶ms.arg0, - TerminalSize::default(), - ) - .await - } else { - codex_utils_pty::spawn_pipe_process_no_stdin( - program, - args, - params.cwd.as_path(), - ¶ms.env, - ¶ms.arg0, - ) - .await - }; - let spawned = match spawned_result { - Ok(spawned) => spawned, - Err(err) => { - let mut process_map = self.processes.lock().await; - if matches!(process_map.get(&process_id), Some(ProcessEntry::Starting)) { - process_map.remove(&process_id); - } - return Err(internal_error(err.to_string())); - } - }; - - let output_notify = Arc::new(Notify::new()); - { - let mut process_map = self.processes.lock().await; - process_map.insert( - process_id.clone(), - ProcessEntry::Running(Box::new(RunningProcess { - session: spawned.session, - tty: params.tty, - output: VecDeque::new(), - retained_bytes: 0, - next_seq: 1, - exit_code: None, - output_notify: Arc::clone(&output_notify), - })), - ); - } - - tokio::spawn(stream_output( - process_id.clone(), - if params.tty { - ExecOutputStream::Pty - } else { - ExecOutputStream::Stdout - }, - spawned.stdout_rx, - self.notifications.clone(), - Arc::clone(&self.processes), - Arc::clone(&output_notify), - )); - tokio::spawn(stream_output( - process_id.clone(), - if params.tty { - ExecOutputStream::Pty - } else { - ExecOutputStream::Stderr - }, - spawned.stderr_rx, - self.notifications.clone(), - Arc::clone(&self.processes), - Arc::clone(&output_notify), - )); - tokio::spawn(watch_exit( - process_id.clone(), - spawned.exit_rx, - self.notifications.clone(), - Arc::clone(&self.processes), - output_notify, - )); - - Ok(ExecResponse { process_id }) + self.process_handler.exec(params).await } pub(crate) async fn exec_read( @@ -246,68 +92,7 @@ impl ExecServerHandler { params: ReadParams, ) -> Result { self.require_initialized_for("exec")?; - let after_seq = params.after_seq.unwrap_or(0); - let max_bytes = params.max_bytes.unwrap_or(usize::MAX); - let wait = Duration::from_millis(params.wait_ms.unwrap_or(0)); - let deadline = tokio::time::Instant::now() + wait; - - loop { - let (response, output_notify) = { - let process_map = self.processes.lock().await; - let process = process_map.get(¶ms.process_id).ok_or_else(|| { - invalid_request(format!("unknown process id {}", params.process_id)) - })?; - let ProcessEntry::Running(process) = process else { - return Err(invalid_request(format!( - "process id {} is starting", - params.process_id - ))); - }; - - let mut chunks = Vec::new(); - let mut total_bytes = 0; - let mut next_seq = process.next_seq; - for retained in process.output.iter().filter(|chunk| chunk.seq > after_seq) { - let chunk_len = retained.chunk.len(); - if !chunks.is_empty() && total_bytes + chunk_len > max_bytes { - break; - } - total_bytes += chunk_len; - chunks.push(ProcessOutputChunk { - seq: retained.seq, - stream: retained.stream, - chunk: retained.chunk.clone().into(), - }); - next_seq = retained.seq + 1; - if total_bytes >= max_bytes { - break; - } - } - - ( - ReadResponse { - chunks, - next_seq, - exited: process.exit_code.is_some(), - exit_code: process.exit_code, - }, - Arc::clone(&process.output_notify), - ) - }; - - if !response.chunks.is_empty() - || response.exited - || tokio::time::Instant::now() >= deadline - { - return Ok(response); - } - - let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); - if remaining.is_zero() { - return Ok(response); - } - let _ = tokio::time::timeout(remaining, output_notify.notified()).await; - } + self.process_handler.exec_read(params).await } pub(crate) async fn exec_write( @@ -315,32 +100,7 @@ impl ExecServerHandler { params: WriteParams, ) -> Result { self.require_initialized_for("exec")?; - let writer_tx = { - let process_map = self.processes.lock().await; - let process = process_map.get(¶ms.process_id).ok_or_else(|| { - invalid_request(format!("unknown process id {}", params.process_id)) - })?; - let ProcessEntry::Running(process) = process else { - return Err(invalid_request(format!( - "process id {} is starting", - params.process_id - ))); - }; - if !process.tty { - return Err(invalid_request(format!( - "stdin is closed for process {}", - params.process_id - ))); - } - process.session.writer_sender() - }; - - writer_tx - .send(params.chunk.into_inner()) - .await - .map_err(|_| internal_error("failed to write to process stdin".to_string()))?; - - Ok(WriteResponse { accepted: true }) + self.process_handler.exec_write(params).await } pub(crate) async fn terminate( @@ -348,21 +108,7 @@ impl ExecServerHandler { params: TerminateParams, ) -> Result { self.require_initialized_for("exec")?; - let running = { - let process_map = self.processes.lock().await; - match process_map.get(¶ms.process_id) { - Some(ProcessEntry::Running(process)) => { - if process.exit_code.is_some() { - return Ok(TerminateResponse { running: false }); - } - process.session.terminate(); - true - } - Some(ProcessEntry::Starting) | None => false, - } - }; - - Ok(TerminateResponse { running }) + self.process_handler.terminate(params).await } pub(crate) async fn fs_read_file( @@ -422,96 +168,5 @@ impl ExecServerHandler { } } -async fn stream_output( - process_id: String, - stream: ExecOutputStream, - mut receiver: tokio::sync::mpsc::Receiver>, - notifications: RpcNotificationSender, - processes: Arc>>, - output_notify: Arc, -) { - while let Some(chunk) = receiver.recv().await { - let notification = { - let mut processes = processes.lock().await; - let Some(entry) = processes.get_mut(&process_id) else { - break; - }; - let ProcessEntry::Running(process) = entry else { - break; - }; - let seq = process.next_seq; - process.next_seq += 1; - process.retained_bytes += chunk.len(); - process.output.push_back(RetainedOutputChunk { - seq, - stream, - chunk: chunk.clone(), - }); - while process.retained_bytes > RETAINED_OUTPUT_BYTES_PER_PROCESS { - let Some(evicted) = process.output.pop_front() else { - break; - }; - process.retained_bytes = process.retained_bytes.saturating_sub(evicted.chunk.len()); - warn!( - "retained output cap exceeded for process {process_id}; dropping oldest output" - ); - } - ExecOutputDeltaNotification { - process_id: process_id.clone(), - stream, - chunk: chunk.into(), - } - }; - output_notify.notify_waiters(); - - if notifications - .notify(crate::protocol::EXEC_OUTPUT_DELTA_METHOD, ¬ification) - .await - .is_err() - { - break; - } - } -} - -async fn watch_exit( - process_id: String, - exit_rx: tokio::sync::oneshot::Receiver, - notifications: RpcNotificationSender, - processes: Arc>>, - output_notify: Arc, -) { - let exit_code = exit_rx.await.unwrap_or(-1); - { - let mut processes = processes.lock().await; - if let Some(ProcessEntry::Running(process)) = processes.get_mut(&process_id) { - process.exit_code = Some(exit_code); - } - } - output_notify.notify_waiters(); - if notifications - .notify( - crate::protocol::EXEC_EXITED_METHOD, - &ExecExitedNotification { - process_id: process_id.clone(), - exit_code, - }, - ) - .await - .is_err() - { - return; - } - - tokio::time::sleep(EXITED_PROCESS_RETENTION).await; - let mut processes = processes.lock().await; - if matches!( - processes.get(&process_id), - Some(ProcessEntry::Running(process)) if process.exit_code == Some(exit_code) - ) { - processes.remove(&process_id); - } -} - #[cfg(test)] mod tests; diff --git a/codex-rs/exec-server/src/server/process_handler.rs b/codex-rs/exec-server/src/server/process_handler.rs new file mode 100644 index 000000000000..eb43093600d5 --- /dev/null +++ b/codex-rs/exec-server/src/server/process_handler.rs @@ -0,0 +1,400 @@ +use std::collections::HashMap; +use std::collections::VecDeque; +use std::sync::Arc; +use std::time::Duration; + +use codex_app_server_protocol::JSONRPCErrorError; +use codex_utils_pty::ExecCommandSession; +use codex_utils_pty::TerminalSize; +use tokio::sync::Mutex; +use tokio::sync::Notify; +use tracing::warn; + +use crate::protocol::ExecExitedNotification; +use crate::protocol::ExecOutputDeltaNotification; +use crate::protocol::ExecOutputStream; +use crate::protocol::ExecParams; +use crate::protocol::ExecResponse; +use crate::protocol::ProcessOutputChunk; +use crate::protocol::ReadParams; +use crate::protocol::ReadResponse; +use crate::protocol::TerminateParams; +use crate::protocol::TerminateResponse; +use crate::protocol::WriteParams; +use crate::protocol::WriteResponse; +use crate::rpc::RpcNotificationSender; +use crate::rpc::internal_error; +use crate::rpc::invalid_params; +use crate::rpc::invalid_request; + +const RETAINED_OUTPUT_BYTES_PER_PROCESS: usize = 1024 * 1024; +#[cfg(test)] +const EXITED_PROCESS_RETENTION: Duration = Duration::from_millis(25); +#[cfg(not(test))] +const EXITED_PROCESS_RETENTION: Duration = Duration::from_secs(30); + +#[derive(Clone)] +struct RetainedOutputChunk { + seq: u64, + stream: ExecOutputStream, + chunk: Vec, +} + +struct RunningProcess { + session: ExecCommandSession, + tty: bool, + output: VecDeque, + retained_bytes: usize, + next_seq: u64, + exit_code: Option, + output_notify: Arc, +} + +enum ProcessEntry { + Starting, + Running(Box), +} + +#[derive(Clone)] +pub(crate) struct ProcessHandler { + notifications: RpcNotificationSender, + processes: Arc>>, +} + +impl ProcessHandler { + pub(crate) fn new(notifications: RpcNotificationSender) -> Self { + Self { + notifications, + processes: Arc::new(Mutex::new(HashMap::new())), + } + } + + pub(crate) async fn shutdown(&self) { + let remaining = { + let mut processes = self.processes.lock().await; + processes + .drain() + .filter_map(|(_, process)| match process { + ProcessEntry::Starting => None, + ProcessEntry::Running(process) => Some(process), + }) + .collect::>() + }; + for process in remaining { + process.session.terminate(); + } + } + + pub(crate) async fn exec(&self, params: ExecParams) -> Result { + let process_id = params.process_id.clone(); + + let (program, args) = params + .argv + .split_first() + .ok_or_else(|| invalid_params("argv must not be empty".to_string()))?; + + { + let mut process_map = self.processes.lock().await; + if process_map.contains_key(&process_id) { + return Err(invalid_request(format!( + "process {process_id} already exists" + ))); + } + process_map.insert(process_id.clone(), ProcessEntry::Starting); + } + + let spawned_result = if params.tty { + codex_utils_pty::spawn_pty_process( + program, + args, + params.cwd.as_path(), + ¶ms.env, + ¶ms.arg0, + TerminalSize::default(), + ) + .await + } else { + codex_utils_pty::spawn_pipe_process_no_stdin( + program, + args, + params.cwd.as_path(), + ¶ms.env, + ¶ms.arg0, + ) + .await + }; + let spawned = match spawned_result { + Ok(spawned) => spawned, + Err(err) => { + let mut process_map = self.processes.lock().await; + if matches!(process_map.get(&process_id), Some(ProcessEntry::Starting)) { + process_map.remove(&process_id); + } + return Err(internal_error(err.to_string())); + } + }; + + let output_notify = Arc::new(Notify::new()); + { + let mut process_map = self.processes.lock().await; + process_map.insert( + process_id.clone(), + ProcessEntry::Running(Box::new(RunningProcess { + session: spawned.session, + tty: params.tty, + output: VecDeque::new(), + retained_bytes: 0, + next_seq: 1, + exit_code: None, + output_notify: Arc::clone(&output_notify), + })), + ); + } + + tokio::spawn(stream_output( + process_id.clone(), + if params.tty { + ExecOutputStream::Pty + } else { + ExecOutputStream::Stdout + }, + spawned.stdout_rx, + self.notifications.clone(), + Arc::clone(&self.processes), + Arc::clone(&output_notify), + )); + tokio::spawn(stream_output( + process_id.clone(), + if params.tty { + ExecOutputStream::Pty + } else { + ExecOutputStream::Stderr + }, + spawned.stderr_rx, + self.notifications.clone(), + Arc::clone(&self.processes), + Arc::clone(&output_notify), + )); + tokio::spawn(watch_exit( + process_id.clone(), + spawned.exit_rx, + self.notifications.clone(), + Arc::clone(&self.processes), + output_notify, + )); + + Ok(ExecResponse { process_id }) + } + + pub(crate) async fn exec_read( + &self, + params: ReadParams, + ) -> Result { + let after_seq = params.after_seq.unwrap_or(0); + let max_bytes = params.max_bytes.unwrap_or(usize::MAX); + let wait = Duration::from_millis(params.wait_ms.unwrap_or(0)); + let deadline = tokio::time::Instant::now() + wait; + + loop { + let (response, output_notify) = { + let process_map = self.processes.lock().await; + let process = process_map.get(¶ms.process_id).ok_or_else(|| { + invalid_request(format!("unknown process id {}", params.process_id)) + })?; + let ProcessEntry::Running(process) = process else { + return Err(invalid_request(format!( + "process id {} is starting", + params.process_id + ))); + }; + + let mut chunks = Vec::new(); + let mut total_bytes = 0; + let mut next_seq = process.next_seq; + for retained in process.output.iter().filter(|chunk| chunk.seq > after_seq) { + let chunk_len = retained.chunk.len(); + if !chunks.is_empty() && total_bytes + chunk_len > max_bytes { + break; + } + total_bytes += chunk_len; + chunks.push(ProcessOutputChunk { + seq: retained.seq, + stream: retained.stream, + chunk: retained.chunk.clone().into(), + }); + next_seq = retained.seq + 1; + if total_bytes >= max_bytes { + break; + } + } + + ( + ReadResponse { + chunks, + next_seq, + exited: process.exit_code.is_some(), + exit_code: process.exit_code, + }, + Arc::clone(&process.output_notify), + ) + }; + + if !response.chunks.is_empty() + || response.exited + || tokio::time::Instant::now() >= deadline + { + return Ok(response); + } + + let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); + if remaining.is_zero() { + return Ok(response); + } + let _ = tokio::time::timeout(remaining, output_notify.notified()).await; + } + } + + pub(crate) async fn exec_write( + &self, + params: WriteParams, + ) -> Result { + let writer_tx = { + let process_map = self.processes.lock().await; + let process = process_map.get(¶ms.process_id).ok_or_else(|| { + invalid_request(format!("unknown process id {}", params.process_id)) + })?; + let ProcessEntry::Running(process) = process else { + return Err(invalid_request(format!( + "process id {} is starting", + params.process_id + ))); + }; + if !process.tty { + return Err(invalid_request(format!( + "stdin is closed for process {}", + params.process_id + ))); + } + process.session.writer_sender() + }; + + writer_tx + .send(params.chunk.into_inner()) + .await + .map_err(|_| internal_error("failed to write to process stdin".to_string()))?; + + Ok(WriteResponse { accepted: true }) + } + + pub(crate) async fn terminate( + &self, + params: TerminateParams, + ) -> Result { + let running = { + let process_map = self.processes.lock().await; + match process_map.get(¶ms.process_id) { + Some(ProcessEntry::Running(process)) => { + if process.exit_code.is_some() { + return Ok(TerminateResponse { running: false }); + } + process.session.terminate(); + true + } + Some(ProcessEntry::Starting) | None => false, + } + }; + + Ok(TerminateResponse { running }) + } +} + +async fn stream_output( + process_id: String, + stream: ExecOutputStream, + mut receiver: tokio::sync::mpsc::Receiver>, + notifications: RpcNotificationSender, + processes: Arc>>, + output_notify: Arc, +) { + while let Some(chunk) = receiver.recv().await { + let notification = { + let mut processes = processes.lock().await; + let Some(entry) = processes.get_mut(&process_id) else { + break; + }; + let ProcessEntry::Running(process) = entry else { + break; + }; + let seq = process.next_seq; + process.next_seq += 1; + process.retained_bytes += chunk.len(); + process.output.push_back(RetainedOutputChunk { + seq, + stream, + chunk: chunk.clone(), + }); + while process.retained_bytes > RETAINED_OUTPUT_BYTES_PER_PROCESS { + let Some(evicted) = process.output.pop_front() else { + break; + }; + process.retained_bytes = process.retained_bytes.saturating_sub(evicted.chunk.len()); + warn!( + "retained output cap exceeded for process {process_id}; dropping oldest output" + ); + } + ExecOutputDeltaNotification { + process_id: process_id.clone(), + stream, + chunk: chunk.into(), + } + }; + output_notify.notify_waiters(); + + if notifications + .notify(crate::protocol::EXEC_OUTPUT_DELTA_METHOD, ¬ification) + .await + .is_err() + { + break; + } + } +} + +async fn watch_exit( + process_id: String, + exit_rx: tokio::sync::oneshot::Receiver, + notifications: RpcNotificationSender, + processes: Arc>>, + output_notify: Arc, +) { + let exit_code = exit_rx.await.unwrap_or(-1); + { + let mut processes = processes.lock().await; + if let Some(ProcessEntry::Running(process)) = processes.get_mut(&process_id) { + process.exit_code = Some(exit_code); + } + } + output_notify.notify_waiters(); + if notifications + .notify( + crate::protocol::EXEC_EXITED_METHOD, + &ExecExitedNotification { + process_id: process_id.clone(), + exit_code, + }, + ) + .await + .is_err() + { + return; + } + + tokio::time::sleep(EXITED_PROCESS_RETENTION).await; + let mut processes = processes.lock().await; + if matches!( + processes.get(&process_id), + Some(ProcessEntry::Running(process)) if process.exit_code == Some(exit_code) + ) { + processes.remove(&process_id); + } +}