diff --git a/Cargo.lock b/Cargo.lock index afab7b0401..1499dde807 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,6 +114,12 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "ascii" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16" + [[package]] name = "assert-json-diff" version = "2.0.2" @@ -743,6 +749,12 @@ dependencies = [ "windows-link 0.2.1", ] +[[package]] +name = "chunked_transfer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e4de3bc4ea267985becf712dc6d9eed8b04c953b3fcfb339ebc87acd9804901" + [[package]] name = "clap" version = "4.6.0" @@ -2150,6 +2162,7 @@ dependencies = [ "tempfile", "terminal_size", "thiserror 2.0.18", + "tiny_http", "tokio", "tokio-stream", "tracing", @@ -6440,6 +6453,18 @@ dependencies = [ "crunchy", ] +[[package]] +name = "tiny_http" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389915df6413a2e74fb181895f933386023c71110878cd0825588928e64cdc82" +dependencies = [ + "ascii", + "chunked_transfer", + "httpdate", + "log", +] + [[package]] name = "tinystr" version = "0.8.2" diff --git a/Cargo.toml b/Cargo.toml index ecdc37bcb7..622cd79f1a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -89,6 +89,7 @@ syn = { version = "2.0.117", features = ["derive", "parsing"] } sysinfo = "0.38.3" tempfile = "3.27.0" termimad = "0.34.1" +tiny_http = "0.12.0" syntect = { version = "5", default-features = false, features = ["default-syntaxes", "default-themes", "regex-onig"] } thiserror = "2.0.18" toml_edit = { version = "0.22", features = ["serde"] } diff --git a/crates/forge_domain/src/auth/auth_token_response.rs b/crates/forge_domain/src/auth/auth_token_response.rs index 7bea91b8e1..a0b66c713d 100644 --- a/crates/forge_domain/src/auth/auth_token_response.rs +++ b/crates/forge_domain/src/auth/auth_token_response.rs @@ -26,6 +26,10 @@ pub struct OAuthTokenResponse { /// OAuth scopes granted #[serde(skip_serializing_if = "Option::is_none")] pub scope: Option, + + /// ID token containing user identity claims (OpenID Connect) + #[serde(skip_serializing_if = "Option::is_none")] + pub id_token: Option, } fn default_token_type() -> String { diff --git a/crates/forge_infra/src/auth/strategy.rs b/crates/forge_infra/src/auth/strategy.rs index 1062581faf..70a4711976 100644 --- a/crates/forge_infra/src/auth/strategy.rs +++ b/crates/forge_infra/src/auth/strategy.rs @@ -57,6 +57,68 @@ impl AuthStrategy for ApiKeyStrategy { } } +/// Extract the ChatGPT account ID from a JWT token's claims. +/// +/// Checks `chatgpt_account_id`, `https://api.openai.com/auth.chatgpt_account_id`, +/// and `organizations[0].id` in that order, matching the opencode +/// implementation. +fn extract_chatgpt_account_id(token: &str) -> Option { + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 3 { + return None; + } + use base64::Engine; + let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(parts[1]) + .ok()?; + let claims: serde_json::Value = serde_json::from_slice(&payload).ok()?; + + // Try chatgpt_account_id first + if let Some(id) = claims["chatgpt_account_id"].as_str() { + return Some(id.to_string()); + } + // Try nested auth claim + if let Some(id) = claims["https://api.openai.com/auth"]["chatgpt_account_id"].as_str() { + return Some(id.to_string()); + } + // Fall back to organizations[0].id + if let Some(id) = claims["organizations"] + .as_array() + .and_then(|orgs| orgs.first()) + .and_then(|org| org["id"].as_str()) + { + return Some(id.to_string()); + } + None +} + +/// Adds Codex-specific credential metadata derived from OAuth tokens. +/// +/// Tries to extract the account ID from the `id_token` first (which typically +/// contains the user identity claims in OpenID Connect flows), then falls back +/// to the `access_token` if needed. +fn enrich_codex_oauth_credential( + provider_id: &ProviderId, + credential: &mut AuthCredential, + id_token: Option<&str>, + access_token: &str, +) { + if *provider_id != ProviderId::CODEX { + return; + } + + // Try id_token first (preferred for user identity claims) + let account_id = id_token + .and_then(extract_chatgpt_account_id) + .or_else(|| extract_chatgpt_account_id(access_token)); + + if let Some(account_id) = account_id { + credential + .url_params + .insert("chatgpt_account_id".to_string().into(), account_id.into()); + } +} + /// OAuth Code Strategy - Browser redirect flow pub struct OAuthCodeStrategy { provider_id: ProviderId, @@ -96,7 +158,7 @@ impl AuthStrategy for OAuthCodeStrategy { let token_response = self .adapter .exchange_code( - &self.config, + &ctx.request.oauth_config, ctx.response.code.as_str(), ctx.request.pkce_verifier.as_ref().map(|v| v.as_str()), ) @@ -107,12 +169,21 @@ impl AuthStrategy for OAuthCodeStrategy { )) })?; - build_oauth_credential( + let access_token = token_response.access_token.clone(); + let id_token = token_response.id_token.clone(); + let mut credential = build_oauth_credential( self.provider_id.clone(), token_response, - &self.config, + &ctx.request.oauth_config, chrono::Duration::hours(1), // Code flow default - ) + )?; + enrich_codex_oauth_credential( + &self.provider_id, + &mut credential, + id_token.as_deref(), + &access_token, + ); + Ok(credential) } _ => Err(AuthError::InvalidContext("Expected Code context".to_string()).into()), } @@ -479,41 +550,6 @@ struct CodexDeviceTokenResponse { code_verifier: String, } -/// Extract the ChatGPT account ID from a JWT token's claims. -/// -/// Checks `chatgpt_account_id`, `https://api.openai.com/auth.chatgpt_account_id`, -/// and `organizations[0].id` in that order, matching the opencode -/// implementation. -fn extract_chatgpt_account_id(token: &str) -> Option { - let parts: Vec<&str> = token.split('.').collect(); - if parts.len() != 3 { - return None; - } - use base64::Engine; - let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD - .decode(parts[1]) - .ok()?; - let claims: serde_json::Value = serde_json::from_slice(&payload).ok()?; - - // Try chatgpt_account_id first - if let Some(id) = claims["chatgpt_account_id"].as_str() { - return Some(id.to_string()); - } - // Try nested auth claim - if let Some(id) = claims["https://api.openai.com/auth"]["chatgpt_account_id"].as_str() { - return Some(id.to_string()); - } - // Fall back to organizations[0].id - if let Some(id) = claims["organizations"] - .as_array() - .and_then(|orgs| orgs.first()) - .and_then(|org| org["id"].as_str()) - { - return Some(id.to_string()); - } - None -} - #[async_trait::async_trait] impl AuthStrategy for CodexDeviceStrategy { async fn init(&self) -> anyhow::Result { @@ -570,11 +606,8 @@ impl AuthStrategy for CodexDeviceStrategy { // Poll for authorization code using the custom OpenAI endpoint let token_response = codex_poll_for_tokens(&ctx.request, &self.config).await?; - // Extract ChatGPT account ID from the access token JWT. - // This is used for the optional `ChatGPT-Account-Id` request - // header when available. - let account_id = extract_chatgpt_account_id(&token_response.access_token); - + let access_token = token_response.access_token.clone(); + let id_token = token_response.id_token.clone(); let mut credential = build_oauth_credential( self.provider_id.clone(), token_response, @@ -583,12 +616,13 @@ impl AuthStrategy for CodexDeviceStrategy { )?; // Store account_id in url_params so it's persisted and available - // for chat request headers - if let Some(id) = account_id { - credential - .url_params - .insert("chatgpt_account_id".to_string().into(), id.into()); - } + // for chat request headers. + enrich_codex_oauth_credential( + &self.provider_id, + &mut credential, + id_token.as_deref(), + &access_token, + ); Ok(credential) } diff --git a/crates/forge_infra/src/auth/util.rs b/crates/forge_infra/src/auth/util.rs index de652e0f67..90f8bf1d71 100644 --- a/crates/forge_infra/src/auth/util.rs +++ b/crates/forge_infra/src/auth/util.rs @@ -36,6 +36,7 @@ pub(crate) fn into_domain(token: T) -> OAuthTokenRespo .collect::>() .join(" ") }), + id_token: None, // oauth2 crate doesn't provide id_token directly } } @@ -98,6 +99,7 @@ pub(crate) fn build_token_response( expires_at: None, token_type: "Bearer".to_string(), scope: None, + id_token: None, } } diff --git a/crates/forge_main/Cargo.toml b/crates/forge_main/Cargo.toml index 10d4ba741e..9357f0da16 100644 --- a/crates/forge_main/Cargo.toml +++ b/crates/forge_main/Cargo.toml @@ -66,6 +66,7 @@ strip-ansi-escapes.workspace = true terminal_size = "0.4" rustls.workspace = true tempfile.workspace = true +tiny_http.workspace = true [target.'cfg(windows)'.dependencies] enable-ansi-support.workspace = true diff --git a/crates/forge_main/src/lib.rs b/crates/forge_main/src/lib.rs index c5b342df7c..1fc22a116d 100644 --- a/crates/forge_main/src/lib.rs +++ b/crates/forge_main/src/lib.rs @@ -7,6 +7,7 @@ mod editor; mod info; mod input; mod model; +mod oauth_callback; mod porcelain; mod prompt; mod sandbox; diff --git a/crates/forge_main/src/oauth_callback.rs b/crates/forge_main/src/oauth_callback.rs new file mode 100644 index 0000000000..427e2a7e74 --- /dev/null +++ b/crates/forge_main/src/oauth_callback.rs @@ -0,0 +1,526 @@ +use std::collections::HashMap; +use std::net::{IpAddr, SocketAddr, TcpListener}; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::time::{Duration, Instant}; + +use forge_domain::CodeRequest; +use tiny_http::{Header, Method, Request, Response, Server, StatusCode}; +use url::{Host, Url}; + +/// Maximum time to wait for the OAuth browser callback before giving up. +const OAUTH_CALLBACK_TIMEOUT: Duration = Duration::from_secs(300); +const CALLBACK_POLL_INTERVAL: Duration = Duration::from_secs(1); + +#[derive(Debug, Clone, PartialEq, Eq)] +struct OAuthCallbackPayload { + code: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum OAuthCallbackParseResult { + PathMismatch, + InvalidRequest(String), + OAuthError(String), + Success(OAuthCallbackPayload), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum CallbackRequestDisposition { + Continue { + status_code: StatusCode, + message: String, + }, + Fail { + status_code: StatusCode, + message: String, + }, + Complete(OAuthCallbackPayload), +} + +/// Localhost OAuth callback server that waits for a browser redirect and +/// returns the authorization code. +pub(crate) struct LocalhostOAuthCallbackServer { + redirect_uri: Url, + server: Arc, + shutdown: Arc, + task: Option>>, +} + +impl LocalhostOAuthCallbackServer { + /// Starts a localhost OAuth callback server when the request uses a + /// localhost redirect URI. + /// + /// Returns `Ok(None)` when the request is not configured for a localhost + /// callback. + /// + /// # Errors + /// + /// Returns an error if the localhost redirect URI is invalid or the HTTP + /// listener cannot be bound. + pub(crate) fn start(request: &CodeRequest) -> anyhow::Result> { + let Some(redirect_uri) = localhost_oauth_redirect_uri(request) else { + return Ok(None); + }; + + let listener = TcpListener::bind(localhost_oauth_bind_addr(&redirect_uri)?)?; + let callback_path = redirect_uri.path().to_string(); + let expected_state = request.state.to_string(); + let server = Arc::new(Server::from_listener(listener, None).map_err(|e| { + anyhow::anyhow!("Failed to start localhost OAuth callback server: {e}") + })?); + let shutdown = Arc::new(AtomicBool::new(false)); + let task = tokio::task::spawn_blocking({ + let server = Arc::clone(&server); + let shutdown = Arc::clone(&shutdown); + move || { + wait_for_localhost_oauth_callback(server, callback_path, expected_state, shutdown) + } + }); + + Ok(Some(Self { + redirect_uri, + server, + shutdown, + task: Some(task), + })) + } + + /// Returns the redirect URI the callback server is listening on. + pub(crate) fn redirect_uri(&self) -> &Url { + &self.redirect_uri + } + + /// Waits for the browser callback and returns the authorization code. + /// + /// # Errors + /// + /// Returns an error when the background task fails or the callback request + /// is invalid. + pub(crate) async fn wait_for_code(mut self) -> anyhow::Result { + self.task + .take() + .expect("OAuth callback task should exist") + .await + .map_err(|e| anyhow::anyhow!("OAuth callback task failed: {e}"))? + } +} + +impl Drop for LocalhostOAuthCallbackServer { + fn drop(&mut self) { + self.shutdown.store(true, Ordering::Relaxed); + self.server.unblock(); + } +} + +fn escape_html(input: &str) -> String { + input + .replace('&', "&") + .replace('<', "<") + .replace('>', ">") + .replace('"', """) + .replace('\'', "'") +} + +fn localhost_oauth_redirect_uri(request: &CodeRequest) -> Option { + request + .oauth_config + .redirect_uri + .as_ref() + .and_then(|uri| Url::parse(uri).ok()) + .filter(|uri| { + uri.scheme() == "http" + && uri.port().is_some() + && (matches!(uri.host(), Some(Host::Domain("localhost"))) + || uri.host().is_some_and(|host| match host { + Host::Ipv4(ip) => ip.is_loopback(), + Host::Ipv6(ip) => ip.is_loopback(), + Host::Domain(_) => false, + })) + }) +} + +fn localhost_oauth_bind_addr(redirect_uri: &Url) -> anyhow::Result { + let port = redirect_uri + .port() + .ok_or_else(|| anyhow::anyhow!("OAuth redirect URI is missing an explicit port"))?; + + match redirect_uri.host() { + Some(Host::Domain("localhost")) => Ok(SocketAddr::from(([127, 0, 0, 1], port))), + Some(Host::Ipv4(ip)) if ip.is_loopback() => Ok(SocketAddr::new(IpAddr::V4(ip), port)), + Some(Host::Ipv6(ip)) if ip.is_loopback() => Ok(SocketAddr::new(IpAddr::V6(ip), port)), + Some(_) => anyhow::bail!("OAuth redirect URI host must be localhost or loopback"), + None => anyhow::bail!("OAuth redirect URI is missing a host"), + } +} + +fn oauth_callback_success_page() -> String { + "ForgeCode Authorization Successful

Authorization Successful

You can close this window and return to ForgeCode.

".to_string() +} + +fn oauth_callback_error_page(message: &str) -> String { + format!( + "ForgeCode Authorization Failed

Authorization Failed

ForgeCode could not complete sign-in.

{}
", + escape_html(message) + ) +} + +fn html_response(status_code: StatusCode, body: String) -> Response>> { + Response::from_string(body) + .with_status_code(status_code) + .with_header(response_header("Content-Type", "text/html; charset=utf-8")) + .with_header(response_header("Cache-Control", "no-store")) + .with_header(response_header("X-Content-Type-Options", "nosniff")) +} + +fn response_header(name: &str, value: &str) -> Header { + Header::from_bytes(name.as_bytes(), value.as_bytes()) + .expect("static HTTP header should be valid") +} + +fn parse_oauth_callback_target( + request_target: &str, + expected_path: &str, + expected_state: &str, +) -> OAuthCallbackParseResult { + let Some(request_target) = request_target.strip_prefix('/') else { + return OAuthCallbackParseResult::InvalidRequest( + "Malformed OAuth callback request target".to_string(), + ); + }; + + let callback_url = match Url::parse(&format!("http://localhost/{request_target}")) { + Ok(url) => url, + Err(_) => { + return OAuthCallbackParseResult::InvalidRequest( + "Malformed OAuth callback request target".to_string(), + ); + } + }; + + if callback_url.path() != expected_path { + return OAuthCallbackParseResult::PathMismatch; + } + + let params: HashMap = callback_url.query_pairs().into_owned().collect(); + if let Some(error) = params.get("error") { + let detail = params + .get("error_description") + .filter(|value| !value.trim().is_empty()) + .map(|value| format!(": {value}")) + .unwrap_or_default(); + return OAuthCallbackParseResult::OAuthError(format!( + "Authorization failed ({error}{detail})" + )); + } + + let Some(state) = params.get("state").filter(|value| !value.trim().is_empty()) else { + return OAuthCallbackParseResult::InvalidRequest( + "Missing OAuth state in callback".to_string(), + ); + }; + if state != expected_state { + return OAuthCallbackParseResult::InvalidRequest( + "OAuth state mismatch. Please try again.".to_string(), + ); + } + + let Some(code) = params + .get("code") + .filter(|value| !value.trim().is_empty()) + .cloned() + else { + return OAuthCallbackParseResult::InvalidRequest( + "Missing authorization code in callback".to_string(), + ); + }; + + OAuthCallbackParseResult::Success(OAuthCallbackPayload { code }) +} + +fn classify_callback_request( + request: &Request, + expected_path: &str, + expected_state: &str, +) -> CallbackRequestDisposition { + if request + .remote_addr() + .is_some_and(|remote_addr| !remote_addr.ip().is_loopback()) + { + return CallbackRequestDisposition::Continue { + status_code: StatusCode(403), + message: "Only loopback callback requests are accepted.".to_string(), + }; + } + + if request.method() != &Method::Get { + return CallbackRequestDisposition::Continue { + status_code: StatusCode(405), + message: "Only GET requests are supported for OAuth callbacks.".to_string(), + }; + } + + match parse_oauth_callback_target(request.url(), expected_path, expected_state) { + OAuthCallbackParseResult::PathMismatch => CallbackRequestDisposition::Continue { + status_code: StatusCode(404), + message: "Received a request for an unexpected callback path.".to_string(), + }, + OAuthCallbackParseResult::InvalidRequest(message) => { + CallbackRequestDisposition::Continue { status_code: StatusCode(400), message } + } + OAuthCallbackParseResult::OAuthError(message) => { + CallbackRequestDisposition::Fail { status_code: StatusCode(400), message } + } + OAuthCallbackParseResult::Success(payload) => CallbackRequestDisposition::Complete(payload), + } +} + +fn respond_to_callback_request( + request: Request, + disposition: &CallbackRequestDisposition, +) -> anyhow::Result<()> { + let response = match disposition { + CallbackRequestDisposition::Continue { status_code, message } + | CallbackRequestDisposition::Fail { status_code, message } => { + html_response(*status_code, oauth_callback_error_page(message)) + } + CallbackRequestDisposition::Complete(_) => { + html_response(StatusCode(200), oauth_callback_success_page()) + } + }; + + request.respond(response)?; + Ok(()) +} + +fn wait_for_localhost_oauth_callback( + server: Arc, + expected_path: String, + expected_state: String, + shutdown: Arc, +) -> anyhow::Result { + let deadline = Instant::now() + OAUTH_CALLBACK_TIMEOUT; + + loop { + if shutdown.load(Ordering::Relaxed) { + anyhow::bail!("OAuth callback listener was cancelled"); + } + + let remaining = deadline + .checked_duration_since(Instant::now()) + .ok_or_else(|| { + anyhow::anyhow!( + "Timed out waiting for OAuth callback after {} seconds", + OAUTH_CALLBACK_TIMEOUT.as_secs() + ) + })?; + + let timeout = remaining.min(CALLBACK_POLL_INTERVAL); + let Some(request) = server.recv_timeout(timeout)? else { + continue; + }; + + let actual = classify_callback_request(&request, &expected_path, &expected_state); + let _ = respond_to_callback_request(request, &actual); + + match actual { + CallbackRequestDisposition::Continue { .. } => continue, + CallbackRequestDisposition::Fail { message, .. } => anyhow::bail!(message), + CallbackRequestDisposition::Complete(payload) => return Ok(payload.code), + } + } +} + +#[cfg(test)] +mod tests { + use std::io::{Read, Write}; + use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream}; + use std::sync::Arc; + use std::sync::atomic::AtomicBool; + use std::thread; + + use forge_domain::{OAuthConfig, PkceVerifier, State}; + use pretty_assertions::assert_eq; + + use super::*; + + fn sample_code_request(authorization_url: &str) -> CodeRequest { + CodeRequest { + authorization_url: Url::parse(authorization_url).unwrap(), + state: State::from("expected-state".to_string()), + pkce_verifier: Some(PkceVerifier::from("verifier".to_string())), + oauth_config: OAuthConfig { + auth_url: Url::parse("https://auth.openai.com/oauth/authorize").unwrap(), + token_url: Url::parse("https://auth.openai.com/oauth/token").unwrap(), + client_id: "client-id".to_string().into(), + scopes: vec!["openid".to_string()], + redirect_uri: Some("http://localhost:1455/auth/callback".to_string()), + use_pkce: true, + token_refresh_url: None, + custom_headers: None, + extra_auth_params: None, + }, + } + } + + fn sample_callback_server() -> (Arc, SocketAddr, Arc) { + let fixture = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = fixture.local_addr().unwrap(); + let server = Arc::new(Server::from_listener(fixture, None).unwrap()); + let shutdown = Arc::new(AtomicBool::new(false)); + (server, addr, shutdown) + } + + fn send_http_request(addr: SocketAddr, request: &str) -> String { + let mut fixture = TcpStream::connect(addr).unwrap(); + fixture.write_all(request.as_bytes()).unwrap(); + fixture.shutdown(Shutdown::Write).unwrap(); + + let mut actual = String::new(); + fixture.read_to_string(&mut actual).unwrap(); + actual + } + + #[test] + fn extracts_localhost_redirect_uri_from_oauth_request() { + let setup = sample_code_request( + "https://auth.openai.com/oauth/authorize?client_id=test&redirect_uri=http%3A%2F%2Flocalhost%3A1455%2Fauth%2Fcallback&state=expected-state", + ); + + let actual = localhost_oauth_redirect_uri(&setup).unwrap(); + + let expected = "http://localhost:1455/auth/callback"; + assert_eq!(actual.as_str(), expected); + } + + #[test] + fn extracts_ipv6_loopback_redirect_uri_from_oauth_request() { + let mut setup = sample_code_request("https://auth.openai.com/oauth/authorize"); + setup.oauth_config.redirect_uri = Some("http://[::1]:1455/auth/callback".to_string()); + + let actual = localhost_oauth_redirect_uri(&setup).unwrap(); + + let expected = "http://[::1]:1455/auth/callback"; + assert_eq!(actual.as_str(), expected); + } + + #[test] + fn captures_authorization_code_from_localhost_callback() { + let setup = sample_callback_server(); + let server = Arc::clone(&setup.0); + let addr = setup.1; + let shutdown = Arc::clone(&setup.2); + let fixture = thread::spawn(move || { + send_http_request( + addr, + "GET /auth/callback?code=auth-code&state=expected-state HTTP/1.1\r\nHost: localhost\r\n\r\n", + ) + }); + + let actual = wait_for_localhost_oauth_callback( + server, + "/auth/callback".to_string(), + "expected-state".to_string(), + shutdown, + ) + .unwrap(); + + let response = fixture.join().unwrap(); + let expected = "auth-code".to_string(); + assert_eq!(actual, expected); + assert!(response.contains("200 OK")); + } + + #[test] + fn keeps_listening_after_invalid_state_until_valid_callback_arrives() { + let setup = sample_callback_server(); + let server = Arc::clone(&setup.0); + let addr = setup.1; + let shutdown = Arc::clone(&setup.2); + let fixture = thread::spawn(move || { + let first = send_http_request( + addr, + "GET /auth/callback?code=auth-code&state=wrong-state HTTP/1.1\r\nHost: localhost\r\n\r\n", + ); + let second = send_http_request( + addr, + "GET /auth/callback?code=auth-code&state=expected-state HTTP/1.1\r\nHost: localhost\r\n\r\n", + ); + (first, second) + }); + + let actual = wait_for_localhost_oauth_callback( + server, + "/auth/callback".to_string(), + "expected-state".to_string(), + shutdown, + ) + .unwrap(); + + let responses = fixture.join().unwrap(); + let expected = "auth-code".to_string(); + assert_eq!(actual, expected); + assert!(responses.0.contains("400 Bad Request")); + assert!(responses.1.contains("200 OK")); + } + + #[test] + fn keeps_listening_after_invalid_method_until_valid_callback_arrives() { + let setup = sample_callback_server(); + let server = Arc::clone(&setup.0); + let addr = setup.1; + let shutdown = Arc::clone(&setup.2); + let fixture = thread::spawn(move || { + let first = send_http_request( + addr, + "POST /auth/callback?code=auth-code&state=expected-state HTTP/1.1\r\nHost: localhost\r\nContent-Length: 0\r\n\r\n", + ); + let second = send_http_request( + addr, + "GET /auth/callback?code=auth-code&state=expected-state HTTP/1.1\r\nHost: localhost\r\n\r\n", + ); + (first, second) + }); + + let actual = wait_for_localhost_oauth_callback( + server, + "/auth/callback".to_string(), + "expected-state".to_string(), + shutdown, + ) + .unwrap(); + + let responses = fixture.join().unwrap(); + let expected = "auth-code".to_string(); + assert_eq!(actual, expected); + assert!(responses.0.contains("405 Method Not Allowed")); + assert!(responses.1.contains("200 OK")); + } + + #[test] + fn stops_when_provider_returns_terminal_oauth_error() { + let setup = sample_callback_server(); + let server = Arc::clone(&setup.0); + let addr = setup.1; + let shutdown = Arc::clone(&setup.2); + let fixture = thread::spawn(move || { + send_http_request( + addr, + "GET /auth/callback?error=access_denied&error_description=user%20cancelled HTTP/1.1\r\nHost: localhost\r\n\r\n", + ) + }); + + let actual = wait_for_localhost_oauth_callback( + server, + "/auth/callback".to_string(), + "expected-state".to_string(), + shutdown, + ) + .unwrap_err(); + + let response = fixture.join().unwrap(); + let expected = "Authorization failed (access_denied: user cancelled)"; + assert_eq!(actual.to_string(), expected); + assert!(response.contains("400 Bad Request")); + } +} diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 7a0bf8276b..3d6b946bac 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -2433,6 +2433,23 @@ impl A + Send + Sync> UI { format!("Authenticate using your {provider_id} account").dimmed() ))?; + let callback_server = + match crate::oauth_callback::LocalhostOAuthCallbackServer::start(request) { + Ok(Some(server)) => { + self.writeln(format!( + "{} Waiting for browser callback on {}", + "→".blue(), + server.redirect_uri().as_str().blue().underline() + ))?; + Some(server) + } + Ok(None) | Err(_) => { + // Not a localhost callback flow, or the listener could not be + // started — fall back to manual code paste. + None + } + }; + // Display authorization URL self.writeln(format!( "{} Please visit: {}", @@ -2447,14 +2464,20 @@ impl A + Send + Sync> UI { )))?; } - // Prompt user to paste authorization code - let code = ForgeWidget::input("Paste the authorization code") - .prompt()? - .ok_or_else(|| anyhow::anyhow!("Authorization code input cancelled"))?; + let code = if let Some(server) = callback_server { + server.wait_for_code().await? + } else { + // Prompt user to paste authorization code + let code = ForgeWidget::input("Paste the authorization code") + .prompt()? + .ok_or_else(|| anyhow::anyhow!("Authorization code input cancelled"))?; - if code.trim().is_empty() { - anyhow::bail!("Authorization code cannot be empty"); - } + if code.trim().is_empty() { + anyhow::bail!("Authorization code cannot be empty"); + } + + code + }; self.spinner .start(Some("Exchanging authorization code..."))?; diff --git a/crates/forge_repo/src/provider/provider.json b/crates/forge_repo/src/provider/provider.json index 73f59a6d65..926516f694 100644 --- a/crates/forge_repo/src/provider/provider.json +++ b/crates/forge_repo/src/provider/provider.json @@ -2149,6 +2149,23 @@ } ], "auth_methods": [ + { + "oauth_code": { + "auth_url": "https://auth.openai.com/oauth/authorize", + "token_url": "https://auth.openai.com/oauth/token", + "client_id": "app_EMoamEEZ73f0CkXaXp7hrann", + "scopes": ["openid", "profile", "email", "offline_access"], + "redirect_uri": "http://localhost:1455/auth/callback", + "use_pkce": true, + "custom_headers": { + "originator": "forge" + }, + "extra_auth_params": { + "id_token_add_organizations": "true", + "codex_cli_simplified_flow": "true" + } + } + }, { "codex_device": { "auth_url": "https://auth.openai.com/api/accounts/deviceauth/usercode",