diff --git a/codex-rs/app-server/README.md b/codex-rs/app-server/README.md index 55adb794a3ed..0b52d2ce6051 100644 --- a/codex-rs/app-server/README.md +++ b/codex-rs/app-server/README.md @@ -29,7 +29,8 @@ Supported transports: When running with `--listen ws://IP:PORT`, the same listener also serves basic HTTP health probes: - `GET /readyz` returns `200 OK` once the listener is accepting new connections. -- `GET /healthz` currently always returns `200 OK`. +- `GET /healthz` returns `200 OK` when no `Origin` header is present. +- Any request carrying an `Origin` header is rejected with `403 Forbidden`. Websocket transport is currently experimental and unsupported. Do not rely on it for production workloads. diff --git a/codex-rs/app-server/src/transport.rs b/codex-rs/app-server/src/transport.rs index d0aa753358e4..3e24d831ae5a 100644 --- a/codex-rs/app-server/src/transport.rs +++ b/codex-rs/app-server/src/transport.rs @@ -5,13 +5,19 @@ use crate::outgoing_message::OutgoingEnvelope; use crate::outgoing_message::OutgoingError; use crate::outgoing_message::OutgoingMessage; use axum::Router; +use axum::body::Body; use axum::extract::ConnectInfo; use axum::extract::State; use axum::extract::ws::Message as WebSocketMessage; use axum::extract::ws::WebSocket; use axum::extract::ws::WebSocketUpgrade; +use axum::http::Request; use axum::http::StatusCode; +use axum::http::header::ORIGIN; +use axum::middleware; +use axum::middleware::Next; use axum::response::IntoResponse; +use axum::response::Response; use axum::routing::any; use axum::routing::get; use codex_app_server_protocol::JSONRPCErrorError; @@ -91,6 +97,22 @@ async fn health_check_handler() -> StatusCode { StatusCode::OK } +async fn reject_requests_with_origin_header( + request: Request, + next: Next, +) -> Result { + if request.headers().contains_key(ORIGIN) { + warn!( + method = %request.method(), + uri = %request.uri(), + "rejecting websocket listener request with Origin header" + ); + Err(StatusCode::FORBIDDEN) + } else { + Ok(next.run(request).await) + } +} + async fn websocket_upgrade_handler( websocket: WebSocketUpgrade, ConnectInfo(peer_addr): ConnectInfo, @@ -322,6 +344,7 @@ pub(crate) async fn start_websocket_acceptor( .route("/readyz", get(health_check_handler)) .route("/healthz", get(health_check_handler)) .fallback(any(websocket_upgrade_handler)) + .layer(middleware::from_fn(reject_requests_with_origin_header)) .with_state(WebSocketListenerState { transport_event_tx, connection_counter: Arc::new(AtomicU64::new(1)), diff --git a/codex-rs/app-server/tests/suite/v2/connection_handling_websocket.rs b/codex-rs/app-server/tests/suite/v2/connection_handling_websocket.rs index 3a8ae9243047..f0216f6baee0 100644 --- a/codex-rs/app-server/tests/suite/v2/connection_handling_websocket.rs +++ b/codex-rs/app-server/tests/suite/v2/connection_handling_websocket.rs @@ -29,7 +29,11 @@ use tokio::time::timeout; use tokio_tungstenite::MaybeTlsStream; use tokio_tungstenite::WebSocketStream; use tokio_tungstenite::connect_async; +use tokio_tungstenite::tungstenite::Error as WebSocketError; use tokio_tungstenite::tungstenite::Message as WebSocketMessage; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tokio_tungstenite::tungstenite::http::HeaderValue; +use tokio_tungstenite::tungstenite::http::header::ORIGIN; pub(super) const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(5); @@ -107,6 +111,55 @@ async fn websocket_transport_serves_health_endpoints_on_same_listener() -> Resul Ok(()) } +#[tokio::test] +async fn websocket_transport_rejects_requests_with_origin_header() -> Result<()> { + let server = create_mock_responses_server_sequence_unchecked(Vec::new()).await; + let codex_home = TempDir::new()?; + create_config_toml(codex_home.path(), &server.uri(), "never")?; + + let (mut process, bind_addr) = spawn_websocket_server(codex_home.path()).await?; + let client = reqwest::Client::new(); + + let deadline = Instant::now() + Duration::from_secs(10); + let healthz = loop { + match client + .get(format!("http://{bind_addr}/healthz")) + .header(ORIGIN.as_str(), "https://example.com") + .send() + .await + .with_context(|| format!("failed to GET http://{bind_addr}/healthz with Origin header")) + { + Ok(response) => break response, + Err(err) => { + if Instant::now() >= deadline { + bail!("failed to GET http://{bind_addr}/healthz with Origin header: {err}"); + } + sleep(Duration::from_millis(50)).await; + } + } + }; + assert_eq!(healthz.status(), StatusCode::FORBIDDEN); + + let url = format!("ws://{bind_addr}"); + let mut request = url.into_client_request()?; + request + .headers_mut() + .insert(ORIGIN, HeaderValue::from_static("https://example.com")); + match connect_async(request).await { + Err(WebSocketError::Http(response)) => { + assert_eq!(response.status(), StatusCode::FORBIDDEN); + } + Ok(_) => bail!("expected websocket handshake with Origin header to be rejected"), + Err(err) => bail!("expected HTTP rejection for Origin header, got {err}"), + } + + process + .kill() + .await + .context("failed to stop websocket app-server process")?; + Ok(()) +} + pub(super) async fn spawn_websocket_server(codex_home: &Path) -> Result<(Child, SocketAddr)> { let program = codex_utils_cargo_bin::cargo_bin("codex-app-server") .context("should find app-server binary")?;