From e478ad9ca609c39d1092b3f326a91e61178abdbc Mon Sep 17 00:00:00 2001 From: David Tivris <163727302+tivris@users.noreply.github.com> Date: Thu, 2 Apr 2026 12:35:39 +0300 Subject: [PATCH 01/11] feat(auth): add OAuth authorization code + PKCE for Codex provider The Codex provider previously only supported the device code flow, which requires workspace admin enablement on OpenAI Business/Enterprise plans. Add authorization code + PKCE as the primary login method so non-admin users can authenticate through a standard browser redirect. Provider config (provider.json): - Add oauth_code entry before codex_device with PKCE, localhost redirect on port 1455, and the same client ID / scopes Auth strategy (strategy.rs): - Extract chatgpt_account_id helper to module level so both OAuthCodeStrategy and CodexDeviceStrategy can use it - Enrich OAuth code credentials with the ChatGPT account ID needed for API request headers UI callback server (ui.rs): - Localhost TCP listener with state validation, HTML response pages, 5-minute timeout, and graceful fallback to manual paste if the port is unavailable Closes #2767 --- crates/forge_infra/src/auth/strategy.rs | 108 +++--- crates/forge_main/src/ui.rs | 354 ++++++++++++++++++- crates/forge_repo/src/provider/provider.json | 17 + 3 files changed, 426 insertions(+), 53 deletions(-) diff --git a/crates/forge_infra/src/auth/strategy.rs b/crates/forge_infra/src/auth/strategy.rs index 1062581faf..f6606080d9 100644 --- a/crates/forge_infra/src/auth/strategy.rs +++ b/crates/forge_infra/src/auth/strategy.rs @@ -57,6 +57,58 @@ impl AuthStrategy for ApiKeyStrategy { } } +/// Extract the ChatGPT account ID from a JWT token's claims. +/// +/// Checks `chatgpt_account_id`, `https://api.openai.com/auth.chatgpt_account_id`, +/// and `organizations[0].id` in that order, matching the opencode +/// implementation. +fn extract_chatgpt_account_id(token: &str) -> Option { + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() != 3 { + return None; + } + use base64::Engine; + let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(parts[1]) + .ok()?; + let claims: serde_json::Value = serde_json::from_slice(&payload).ok()?; + + // Try chatgpt_account_id first + if let Some(id) = claims["chatgpt_account_id"].as_str() { + return Some(id.to_string()); + } + // Try nested auth claim + if let Some(id) = claims["https://api.openai.com/auth"]["chatgpt_account_id"].as_str() { + return Some(id.to_string()); + } + // Fall back to organizations[0].id + if let Some(id) = claims["organizations"] + .as_array() + .and_then(|orgs| orgs.first()) + .and_then(|org| org["id"].as_str()) + { + return Some(id.to_string()); + } + None +} + +/// Adds Codex-specific credential metadata derived from OAuth tokens. +fn enrich_codex_oauth_credential( + provider_id: &ProviderId, + credential: &mut AuthCredential, + access_token: &str, +) { + if *provider_id != ProviderId::CODEX { + return; + } + + if let Some(account_id) = extract_chatgpt_account_id(access_token) { + credential + .url_params + .insert("chatgpt_account_id".to_string().into(), account_id.into()); + } +} + /// OAuth Code Strategy - Browser redirect flow pub struct OAuthCodeStrategy { provider_id: ProviderId, @@ -96,7 +148,7 @@ impl AuthStrategy for OAuthCodeStrategy { let token_response = self .adapter .exchange_code( - &self.config, + &ctx.request.oauth_config, ctx.response.code.as_str(), ctx.request.pkce_verifier.as_ref().map(|v| v.as_str()), ) @@ -107,12 +159,19 @@ impl AuthStrategy for OAuthCodeStrategy { )) })?; - build_oauth_credential( + let access_token = token_response.access_token.clone(); + let mut credential = build_oauth_credential( self.provider_id.clone(), token_response, - &self.config, + &ctx.request.oauth_config, chrono::Duration::hours(1), // Code flow default - ) + )?; + enrich_codex_oauth_credential( + &self.provider_id, + &mut credential, + &access_token, + ); + Ok(credential) } _ => Err(AuthError::InvalidContext("Expected Code context".to_string()).into()), } @@ -479,41 +538,6 @@ struct CodexDeviceTokenResponse { code_verifier: String, } -/// Extract the ChatGPT account ID from a JWT token's claims. -/// -/// Checks `chatgpt_account_id`, `https://api.openai.com/auth.chatgpt_account_id`, -/// and `organizations[0].id` in that order, matching the opencode -/// implementation. -fn extract_chatgpt_account_id(token: &str) -> Option { - let parts: Vec<&str> = token.split('.').collect(); - if parts.len() != 3 { - return None; - } - use base64::Engine; - let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD - .decode(parts[1]) - .ok()?; - let claims: serde_json::Value = serde_json::from_slice(&payload).ok()?; - - // Try chatgpt_account_id first - if let Some(id) = claims["chatgpt_account_id"].as_str() { - return Some(id.to_string()); - } - // Try nested auth claim - if let Some(id) = claims["https://api.openai.com/auth"]["chatgpt_account_id"].as_str() { - return Some(id.to_string()); - } - // Fall back to organizations[0].id - if let Some(id) = claims["organizations"] - .as_array() - .and_then(|orgs| orgs.first()) - .and_then(|org| org["id"].as_str()) - { - return Some(id.to_string()); - } - None -} - #[async_trait::async_trait] impl AuthStrategy for CodexDeviceStrategy { async fn init(&self) -> anyhow::Result { @@ -583,11 +607,11 @@ impl AuthStrategy for CodexDeviceStrategy { )?; // Store account_id in url_params so it's persisted and available - // for chat request headers - if let Some(id) = account_id { + // for chat request headers. + if let Some(account_id) = account_id { credential .url_params - .insert("chatgpt_account_id".to_string().into(), id.into()); + .insert("chatgpt_account_id".to_string().into(), account_id.into()); } Ok(credential) diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 7f0f251e22..3dc24371aa 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -1,5 +1,7 @@ use std::collections::HashMap; use std::fmt::Display; +use std::io::{Read, Write}; +use std::net::{TcpListener, TcpStream}; use std::path::PathBuf; use std::str::FromStr; use std::sync::Arc; @@ -59,6 +61,209 @@ struct ConversationDump { related_conversations: Vec, } +#[derive(Debug, Clone, PartialEq, Eq)] +struct OAuthCallbackPayload { + code: String, +} + +fn escape_html(input: &str) -> String { + input + .replace('&', "&") + .replace('<', "<") + .replace('>', ">") + .replace('"', """) + .replace('\'', "'") +} + +fn localhost_oauth_redirect_uri(request: &CodeRequest) -> Option { + request + .oauth_config + .redirect_uri + .as_ref() + .and_then(|uri| Url::parse(uri).ok()) + .filter(|uri| { + uri.scheme() == "http" + && matches!(uri.host_str(), Some("localhost") | Some("127.0.0.1")) + && uri.port().is_some() + }) +} + +fn localhost_oauth_bind_addr(redirect_uri: &Url) -> anyhow::Result { + let host = match redirect_uri.host_str() { + Some("localhost") => "127.0.0.1", + Some(host) => host, + None => anyhow::bail!("OAuth redirect URI is missing a host"), + }; + let port = redirect_uri + .port() + .ok_or_else(|| anyhow::anyhow!("OAuth redirect URI is missing an explicit port"))?; + Ok(format!("{host}:{port}")) +} + +fn oauth_callback_success_page() -> String { + "ForgeCode Authorization Successful

Authorization Successful

You can close this window and return to ForgeCode.

".to_string() +} + +fn oauth_callback_error_page(message: &str) -> String { + format!( + "ForgeCode Authorization Failed

Authorization Failed

ForgeCode could not complete sign-in.

{}
", + escape_html(message) + ) +} + +fn write_http_response( + stream: &mut TcpStream, + status_line: &str, + body: &str, +) -> anyhow::Result<()> { + let response = format!( + "HTTP/1.1 {status_line}\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes())?; + stream.flush()?; + Ok(()) +} + +fn parse_oauth_callback_target( + request_target: &str, + expected_path: &str, + expected_state: &str, +) -> anyhow::Result> { + let callback_url = Url::parse(&format!("http://localhost{request_target}"))?; + if callback_url.path() != expected_path { + return Ok(None); + } + + let params: HashMap = callback_url.query_pairs().into_owned().collect(); + if let Some(error) = params.get("error") { + let detail = params + .get("error_description") + .filter(|value| !value.trim().is_empty()) + .map(|value| format!(": {value}")) + .unwrap_or_default(); + anyhow::bail!("Authorization failed ({error}{detail})"); + } + + let state = params + .get("state") + .filter(|value| !value.trim().is_empty()) + .ok_or_else(|| anyhow::anyhow!("Missing OAuth state in callback"))?; + if state != expected_state { + anyhow::bail!("OAuth state mismatch. Please try again."); + } + + let code = params + .get("code") + .filter(|value| !value.trim().is_empty()) + .ok_or_else(|| anyhow::anyhow!("Missing authorization code in callback"))? + .to_string(); + + Ok(Some(OAuthCallbackPayload { code })) +} + +/// Maximum time to wait for the OAuth browser callback before giving up. +const OAUTH_CALLBACK_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300); + +fn wait_for_localhost_oauth_callback( + listener: TcpListener, + expected_path: String, + expected_state: String, +) -> anyhow::Result { + let deadline = std::time::Instant::now() + OAUTH_CALLBACK_TIMEOUT; + listener.set_nonblocking(false)?; + + loop { + let remaining = deadline + .checked_duration_since(std::time::Instant::now()) + .ok_or_else(|| { + anyhow::anyhow!( + "Timed out waiting for OAuth callback after {} seconds", + OAUTH_CALLBACK_TIMEOUT.as_secs() + ) + })?; + + // Cap each accept at the remaining time so we re-check the deadline + let accept_timeout = remaining.min(std::time::Duration::from_secs(5)); + listener.set_nonblocking(true)?; + let accept_result = loop { + match listener.accept() { + Ok(conn) => break Ok(conn), + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + std::thread::sleep(std::time::Duration::from_millis(100)); + if std::time::Instant::now() + accept_timeout > deadline { + // Will be caught by the remaining check at top of outer loop + break Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "accept timeout", + )); + } + } + Err(e) => break Err(e), + } + }; + + let (mut stream, _) = match accept_result { + Ok(conn) => conn, + Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => continue, + Err(e) => return Err(e.into()), + }; + + let mut buffer = [0u8; 8192]; + let bytes_read = match stream.read(&mut buffer) { + Ok(n) => n, + Err(_) => continue, + }; + if bytes_read == 0 { + continue; + } + + let request = String::from_utf8_lossy(&buffer[..bytes_read]); + let Some(request_line) = request.lines().next() else { + continue; + }; + + let mut parts = request_line.split_whitespace(); + let method = parts.next().unwrap_or_default(); + let request_target = parts.next().unwrap_or_default(); + + if method != "GET" { + let _ = write_http_response( + &mut stream, + "405 Method Not Allowed", + &oauth_callback_error_page("Only GET requests are supported for OAuth callbacks."), + ); + continue; + } + + match parse_oauth_callback_target(request_target, &expected_path, &expected_state) { + Ok(Some(payload)) => { + let _ = write_http_response(&mut stream, "200 OK", &oauth_callback_success_page()); + return Ok(payload.code); + } + Ok(None) => { + let _ = write_http_response( + &mut stream, + "404 Not Found", + &oauth_callback_error_page( + "Received a request for an unexpected callback path.", + ), + ); + } + Err(err) => { + let message = err.to_string(); + let _ = write_http_response( + &mut stream, + "400 Bad Request", + &oauth_callback_error_page(&message), + ); + return Err(err); + } + } + } +} + /// Formats an MCP server config for display, redacting sensitive information. /// Returns the command/URL string only. fn format_mcp_server(server: &forge_domain::McpServerConfig) -> String { @@ -2420,6 +2625,32 @@ impl A + Send + Sync> UI { format!("Authenticate using your {provider_id} account").dimmed() ))?; + let callback_task = if let Some(redirect_uri) = localhost_oauth_redirect_uri(request) { + match localhost_oauth_bind_addr(&redirect_uri) + .and_then(|addr| TcpListener::bind(&addr).map_err(Into::into)) + { + Ok(listener) => { + let callback_path = redirect_uri.path().to_string(); + let expected_state = request.state.to_string(); + self.writeln(format!( + "{} Waiting for browser callback on {}", + "→".blue(), + redirect_uri.as_str().blue().underline() + ))?; + + Some(tokio::task::spawn_blocking(move || { + wait_for_localhost_oauth_callback(listener, callback_path, expected_state) + })) + } + Err(_) => { + // Port in use or bind failed — fall back to manual paste + None + } + } + } else { + None + }; + // Display authorization URL self.writeln(format!( "{} Please visit: {}", @@ -2434,14 +2665,21 @@ impl A + Send + Sync> UI { )))?; } - // Prompt user to paste authorization code - let code = ForgeWidget::input("Paste the authorization code") - .prompt()? - .ok_or_else(|| anyhow::anyhow!("Authorization code input cancelled"))?; + let code = if let Some(task) = callback_task { + task.await + .map_err(|e| anyhow::anyhow!("OAuth callback task failed: {e}"))?? + } else { + // Prompt user to paste authorization code + let code = ForgeWidget::input("Paste the authorization code") + .prompt()? + .ok_or_else(|| anyhow::anyhow!("Authorization code input cancelled"))?; - if code.trim().is_empty() { - anyhow::bail!("Authorization code cannot be empty"); - } + if code.trim().is_empty() { + anyhow::bail!("Authorization code cannot be empty"); + } + + code + }; self.spinner .start(Some("Exchanging authorization code..."))?; @@ -4123,8 +4361,102 @@ impl A + Send + Sync> UI { #[cfg(test)] mod tests { - // Note: Tests for confirm_delete_conversation are disabled because - // ForgeSelect::confirm is not easily mockable in the current - // architecture. The functionality is tested through integration tests - // instead. + use std::io::{Read, Write}; + use std::net::{TcpListener, TcpStream}; + use std::thread; + + use forge_domain::{CodeRequest, OAuthConfig, PkceVerifier, State}; + use url::Url; + + use super::{localhost_oauth_redirect_uri, wait_for_localhost_oauth_callback}; + + fn sample_code_request(authorization_url: &str) -> CodeRequest { + CodeRequest { + authorization_url: Url::parse(authorization_url).unwrap(), + state: State::from("expected-state".to_string()), + pkce_verifier: Some(PkceVerifier::from("verifier".to_string())), + oauth_config: OAuthConfig { + auth_url: Url::parse("https://auth.openai.com/oauth/authorize").unwrap(), + token_url: Url::parse("https://auth.openai.com/oauth/token").unwrap(), + client_id: "client-id".to_string().into(), + scopes: vec!["openid".to_string()], + redirect_uri: Some("http://localhost:1455/auth/callback".to_string()), + use_pkce: true, + token_refresh_url: None, + custom_headers: None, + extra_auth_params: None, + }, + } + } + + #[test] + fn extracts_localhost_redirect_uri_from_oauth_request() { + let request = sample_code_request( + "https://auth.openai.com/oauth/authorize?client_id=test&redirect_uri=http%3A%2F%2Flocalhost%3A1455%2Fauth%2Fcallback&state=expected-state", + ); + + let redirect_uri = localhost_oauth_redirect_uri(&request).unwrap(); + + assert_eq!(redirect_uri.as_str(), "http://localhost:1455/auth/callback"); + } + + #[test] + fn captures_authorization_code_from_localhost_callback() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + let client = thread::spawn(move || { + let mut stream = TcpStream::connect(addr).unwrap(); + stream + .write_all( + b"GET /auth/callback?code=auth-code&state=expected-state HTTP/1.1\r\nHost: localhost\r\n\r\n", + ) + .unwrap(); + let mut response = String::new(); + stream.read_to_string(&mut response).unwrap(); + response + }); + + let code = wait_for_localhost_oauth_callback( + listener, + "/auth/callback".to_string(), + "expected-state".to_string(), + ) + .unwrap(); + + let response = client.join().unwrap(); + + assert_eq!(code, "auth-code"); + assert!(response.contains("200 OK")); + } + + #[test] + fn rejects_localhost_callback_with_mismatched_state() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + let client = thread::spawn(move || { + let mut stream = TcpStream::connect(addr).unwrap(); + stream + .write_all( + b"GET /auth/callback?code=auth-code&state=wrong-state HTTP/1.1\r\nHost: localhost\r\n\r\n", + ) + .unwrap(); + let mut response = String::new(); + stream.read_to_string(&mut response).unwrap(); + response + }); + + let error = wait_for_localhost_oauth_callback( + listener, + "/auth/callback".to_string(), + "expected-state".to_string(), + ) + .unwrap_err(); + + let response = client.join().unwrap(); + + assert!(error.to_string().contains("OAuth state mismatch")); + assert!(response.contains("400 Bad Request")); + } } diff --git a/crates/forge_repo/src/provider/provider.json b/crates/forge_repo/src/provider/provider.json index 2805c1f8a5..af64ccfa48 100644 --- a/crates/forge_repo/src/provider/provider.json +++ b/crates/forge_repo/src/provider/provider.json @@ -2069,6 +2069,23 @@ } ], "auth_methods": [ + { + "oauth_code": { + "auth_url": "https://auth.openai.com/oauth/authorize", + "token_url": "https://auth.openai.com/oauth/token", + "client_id": "app_EMoamEEZ73f0CkXaXp7hrann", + "scopes": ["openid", "profile", "email", "offline_access"], + "redirect_uri": "http://localhost:1455/auth/callback", + "use_pkce": true, + "custom_headers": { + "originator": "forge" + }, + "extra_auth_params": { + "id_token_add_organizations": "true", + "codex_cli_simplified_flow": "true" + } + } + }, { "codex_device": { "auth_url": "https://auth.openai.com/api/accounts/deviceauth/usercode", From afae0b60342135f0787bed3b72b3e3c3ea142983 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Thu, 2 Apr 2026 09:37:44 +0000 Subject: [PATCH 02/11] [autofix.ci] apply automated fixes --- crates/forge_infra/src/auth/strategy.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/crates/forge_infra/src/auth/strategy.rs b/crates/forge_infra/src/auth/strategy.rs index f6606080d9..9e975d7bd7 100644 --- a/crates/forge_infra/src/auth/strategy.rs +++ b/crates/forge_infra/src/auth/strategy.rs @@ -166,11 +166,7 @@ impl AuthStrategy for OAuthCodeStrategy { &ctx.request.oauth_config, chrono::Duration::hours(1), // Code flow default )?; - enrich_codex_oauth_credential( - &self.provider_id, - &mut credential, - &access_token, - ); + enrich_codex_oauth_credential(&self.provider_id, &mut credential, &access_token); Ok(credential) } _ => Err(AuthError::InvalidContext("Expected Code context".to_string()).into()), From 0873133a7222a15e633364c69e34a4c7dca0399c Mon Sep 17 00:00:00 2001 From: David Tivris <163727302+tivris@users.noreply.github.com> Date: Thu, 2 Apr 2026 12:45:32 +0300 Subject: [PATCH 03/11] fix(auth): correct accept timeout tracking in OAuth callback listener Track elapsed time from the start of each accept attempt instead of comparing a fixed duration against the deadline on every poll iteration. --- crates/forge_main/src/ui.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 3dc24371aa..894bbcfbd7 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -187,12 +187,15 @@ fn wait_for_localhost_oauth_callback( // Cap each accept at the remaining time so we re-check the deadline let accept_timeout = remaining.min(std::time::Duration::from_secs(5)); listener.set_nonblocking(true)?; + let accept_start = std::time::Instant::now(); let accept_result = loop { match listener.accept() { Ok(conn) => break Ok(conn), Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { std::thread::sleep(std::time::Duration::from_millis(100)); - if std::time::Instant::now() + accept_timeout > deadline { + if accept_start.elapsed() >= accept_timeout + || std::time::Instant::now() >= deadline + { // Will be caught by the remaining check at top of outer loop break Err(std::io::Error::new( std::io::ErrorKind::TimedOut, From aa18e7e62d2bb4cc860278f44c72d0f9ccf13e37 Mon Sep 17 00:00:00 2001 From: Amit Singh Date: Fri, 3 Apr 2026 23:09:48 +0530 Subject: [PATCH 04/11] Support Codex account extraction from id_token Prefer id_token claims for ChatGPT account ID extraction and fall back to access_token when needed. Co-Authored-By: ForgeCode --- .../src/auth/auth_token_response.rs | 4 +++ crates/forge_infra/src/auth/strategy.rs | 28 +++++++++++-------- crates/forge_infra/src/auth/util.rs | 2 ++ 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/crates/forge_domain/src/auth/auth_token_response.rs b/crates/forge_domain/src/auth/auth_token_response.rs index 7bea91b8e1..a0b66c713d 100644 --- a/crates/forge_domain/src/auth/auth_token_response.rs +++ b/crates/forge_domain/src/auth/auth_token_response.rs @@ -26,6 +26,10 @@ pub struct OAuthTokenResponse { /// OAuth scopes granted #[serde(skip_serializing_if = "Option::is_none")] pub scope: Option, + + /// ID token containing user identity claims (OpenID Connect) + #[serde(skip_serializing_if = "Option::is_none")] + pub id_token: Option, } fn default_token_type() -> String { diff --git a/crates/forge_infra/src/auth/strategy.rs b/crates/forge_infra/src/auth/strategy.rs index 9e975d7bd7..25f2e50600 100644 --- a/crates/forge_infra/src/auth/strategy.rs +++ b/crates/forge_infra/src/auth/strategy.rs @@ -93,16 +93,26 @@ fn extract_chatgpt_account_id(token: &str) -> Option { } /// Adds Codex-specific credential metadata derived from OAuth tokens. +/// +/// Tries to extract the account ID from the `id_token` first (which typically +/// contains the user identity claims in OpenID Connect flows), then falls back +/// to the `access_token` if needed. fn enrich_codex_oauth_credential( provider_id: &ProviderId, credential: &mut AuthCredential, + id_token: Option<&str>, access_token: &str, ) { if *provider_id != ProviderId::CODEX { return; } - if let Some(account_id) = extract_chatgpt_account_id(access_token) { + // Try id_token first (preferred for user identity claims) + let account_id = id_token + .and_then(extract_chatgpt_account_id) + .or_else(|| extract_chatgpt_account_id(access_token)); + + if let Some(account_id) = account_id { credential .url_params .insert("chatgpt_account_id".to_string().into(), account_id.into()); @@ -160,13 +170,14 @@ impl AuthStrategy for OAuthCodeStrategy { })?; let access_token = token_response.access_token.clone(); + let id_token = token_response.id_token.clone(); let mut credential = build_oauth_credential( self.provider_id.clone(), token_response, &ctx.request.oauth_config, chrono::Duration::hours(1), // Code flow default )?; - enrich_codex_oauth_credential(&self.provider_id, &mut credential, &access_token); + enrich_codex_oauth_credential(&self.provider_id, &mut credential, id_token.as_deref(), &access_token); Ok(credential) } _ => Err(AuthError::InvalidContext("Expected Code context".to_string()).into()), @@ -590,11 +601,8 @@ impl AuthStrategy for CodexDeviceStrategy { // Poll for authorization code using the custom OpenAI endpoint let token_response = codex_poll_for_tokens(&ctx.request, &self.config).await?; - // Extract ChatGPT account ID from the access token JWT. - // This is used for the optional `ChatGPT-Account-Id` request - // header when available. - let account_id = extract_chatgpt_account_id(&token_response.access_token); - + let access_token = token_response.access_token.clone(); + let id_token = token_response.id_token.clone(); let mut credential = build_oauth_credential( self.provider_id.clone(), token_response, @@ -604,11 +612,7 @@ impl AuthStrategy for CodexDeviceStrategy { // Store account_id in url_params so it's persisted and available // for chat request headers. - if let Some(account_id) = account_id { - credential - .url_params - .insert("chatgpt_account_id".to_string().into(), account_id.into()); - } + enrich_codex_oauth_credential(&self.provider_id, &mut credential, id_token.as_deref(), &access_token); Ok(credential) } diff --git a/crates/forge_infra/src/auth/util.rs b/crates/forge_infra/src/auth/util.rs index de652e0f67..90f8bf1d71 100644 --- a/crates/forge_infra/src/auth/util.rs +++ b/crates/forge_infra/src/auth/util.rs @@ -36,6 +36,7 @@ pub(crate) fn into_domain(token: T) -> OAuthTokenRespo .collect::>() .join(" ") }), + id_token: None, // oauth2 crate doesn't provide id_token directly } } @@ -98,6 +99,7 @@ pub(crate) fn build_token_response( expires_at: None, token_type: "Bearer".to_string(), scope: None, + id_token: None, } } From 07c890174f787bcf042ab9bf04e07c7dd140ec69 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Fri, 3 Apr 2026 17:41:41 +0000 Subject: [PATCH 05/11] [autofix.ci] apply automated fixes --- crates/forge_infra/src/auth/strategy.rs | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/crates/forge_infra/src/auth/strategy.rs b/crates/forge_infra/src/auth/strategy.rs index 25f2e50600..70a4711976 100644 --- a/crates/forge_infra/src/auth/strategy.rs +++ b/crates/forge_infra/src/auth/strategy.rs @@ -177,7 +177,12 @@ impl AuthStrategy for OAuthCodeStrategy { &ctx.request.oauth_config, chrono::Duration::hours(1), // Code flow default )?; - enrich_codex_oauth_credential(&self.provider_id, &mut credential, id_token.as_deref(), &access_token); + enrich_codex_oauth_credential( + &self.provider_id, + &mut credential, + id_token.as_deref(), + &access_token, + ); Ok(credential) } _ => Err(AuthError::InvalidContext("Expected Code context".to_string()).into()), @@ -612,7 +617,12 @@ impl AuthStrategy for CodexDeviceStrategy { // Store account_id in url_params so it's persisted and available // for chat request headers. - enrich_codex_oauth_credential(&self.provider_id, &mut credential, id_token.as_deref(), &access_token); + enrich_codex_oauth_credential( + &self.provider_id, + &mut credential, + id_token.as_deref(), + &access_token, + ); Ok(credential) } From f9102858025345d48396d4b03b30a4664ced2066 Mon Sep 17 00:00:00 2001 From: Amit Singh Date: Fri, 3 Apr 2026 23:19:08 +0530 Subject: [PATCH 06/11] Extract OAuth callback server from UI Move the localhost OAuth callback server into a dedicated module so ui.rs only orchestrates the browser auth flow. Co-Authored-By: ForgeCode --- crates/forge_main/src/lib.rs | 1 + crates/forge_main/src/oauth_callback.rs | 362 ++++++++++++++++++++++++ crates/forge_main/src/ui.rs | 352 ++--------------------- 3 files changed, 379 insertions(+), 336 deletions(-) create mode 100644 crates/forge_main/src/oauth_callback.rs diff --git a/crates/forge_main/src/lib.rs b/crates/forge_main/src/lib.rs index c5b342df7c..1fc22a116d 100644 --- a/crates/forge_main/src/lib.rs +++ b/crates/forge_main/src/lib.rs @@ -7,6 +7,7 @@ mod editor; mod info; mod input; mod model; +mod oauth_callback; mod porcelain; mod prompt; mod sandbox; diff --git a/crates/forge_main/src/oauth_callback.rs b/crates/forge_main/src/oauth_callback.rs new file mode 100644 index 0000000000..033330990b --- /dev/null +++ b/crates/forge_main/src/oauth_callback.rs @@ -0,0 +1,362 @@ +use std::collections::HashMap; +use std::io::{Read, Write}; +use std::net::{TcpListener, TcpStream}; + +use forge_domain::CodeRequest; +use url::Url; + +/// Maximum time to wait for the OAuth browser callback before giving up. +const OAUTH_CALLBACK_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300); + +#[derive(Debug, Clone, PartialEq, Eq)] +struct OAuthCallbackPayload { + code: String, +} + +/// Localhost OAuth callback server that waits for a browser redirect and +/// returns the authorization code. +pub(crate) struct LocalhostOAuthCallbackServer { + redirect_uri: Url, + task: tokio::task::JoinHandle>, +} + +impl LocalhostOAuthCallbackServer { + /// Starts a localhost OAuth callback server when the request uses a + /// localhost redirect URI. + /// + /// Returns `Ok(None)` when the request is not configured for a localhost + /// callback. + /// + /// # Errors + /// + /// Returns an error if the localhost redirect URI is invalid or the TCP + /// listener cannot be bound. + pub(crate) fn start(request: &CodeRequest) -> anyhow::Result> { + let Some(redirect_uri) = localhost_oauth_redirect_uri(request) else { + return Ok(None); + }; + + let listener = TcpListener::bind(localhost_oauth_bind_addr(&redirect_uri)?)?; + let callback_path = redirect_uri.path().to_string(); + let expected_state = request.state.to_string(); + let task = tokio::task::spawn_blocking(move || { + wait_for_localhost_oauth_callback(listener, callback_path, expected_state) + }); + + Ok(Some(Self { redirect_uri, task })) + } + + /// Returns the redirect URI the callback server is listening on. + pub(crate) fn redirect_uri(&self) -> &Url { + &self.redirect_uri + } + + /// Waits for the browser callback and returns the authorization code. + /// + /// # Errors + /// + /// Returns an error when the background task fails or the callback request + /// is invalid. + pub(crate) async fn wait_for_code(self) -> anyhow::Result { + self.task + .await + .map_err(|e| anyhow::anyhow!("OAuth callback task failed: {e}"))? + } +} + +fn escape_html(input: &str) -> String { + input + .replace('&', "&") + .replace('<', "<") + .replace('>', ">") + .replace('"', """) + .replace('\'', "'") +} + +fn localhost_oauth_redirect_uri(request: &CodeRequest) -> Option { + request + .oauth_config + .redirect_uri + .as_ref() + .and_then(|uri| Url::parse(uri).ok()) + .filter(|uri| { + uri.scheme() == "http" + && matches!(uri.host_str(), Some("localhost") | Some("127.0.0.1")) + && uri.port().is_some() + }) +} + +fn localhost_oauth_bind_addr(redirect_uri: &Url) -> anyhow::Result { + let host = match redirect_uri.host_str() { + Some("localhost") => "127.0.0.1", + Some(host) => host, + None => anyhow::bail!("OAuth redirect URI is missing a host"), + }; + let port = redirect_uri + .port() + .ok_or_else(|| anyhow::anyhow!("OAuth redirect URI is missing an explicit port"))?; + Ok(format!("{host}:{port}")) +} + +fn oauth_callback_success_page() -> String { + "ForgeCode Authorization Successful

Authorization Successful

You can close this window and return to ForgeCode.

".to_string() +} + +fn oauth_callback_error_page(message: &str) -> String { + format!( + "ForgeCode Authorization Failed

Authorization Failed

ForgeCode could not complete sign-in.

{}
", + escape_html(message) + ) +} + +fn write_http_response( + stream: &mut TcpStream, + status_line: &str, + body: &str, +) -> anyhow::Result<()> { + let response = format!( + "HTTP/1.1 {status_line}\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes())?; + stream.flush()?; + Ok(()) +} + +fn parse_oauth_callback_target( + request_target: &str, + expected_path: &str, + expected_state: &str, +) -> anyhow::Result> { + let callback_url = Url::parse(&format!("http://localhost{request_target}"))?; + if callback_url.path() != expected_path { + return Ok(None); + } + + let params: HashMap = callback_url.query_pairs().into_owned().collect(); + if let Some(error) = params.get("error") { + let detail = params + .get("error_description") + .filter(|value| !value.trim().is_empty()) + .map(|value| format!(": {value}")) + .unwrap_or_default(); + anyhow::bail!("Authorization failed ({error}{detail})"); + } + + let state = params + .get("state") + .filter(|value| !value.trim().is_empty()) + .ok_or_else(|| anyhow::anyhow!("Missing OAuth state in callback"))?; + if state != expected_state { + anyhow::bail!("OAuth state mismatch. Please try again."); + } + + let code = params + .get("code") + .filter(|value| !value.trim().is_empty()) + .ok_or_else(|| anyhow::anyhow!("Missing authorization code in callback"))? + .to_string(); + + Ok(Some(OAuthCallbackPayload { code })) +} + +fn wait_for_localhost_oauth_callback( + listener: TcpListener, + expected_path: String, + expected_state: String, +) -> anyhow::Result { + let deadline = std::time::Instant::now() + OAUTH_CALLBACK_TIMEOUT; + listener.set_nonblocking(false)?; + + loop { + let remaining = deadline + .checked_duration_since(std::time::Instant::now()) + .ok_or_else(|| { + anyhow::anyhow!( + "Timed out waiting for OAuth callback after {} seconds", + OAUTH_CALLBACK_TIMEOUT.as_secs() + ) + })?; + + let accept_timeout = remaining.min(std::time::Duration::from_secs(5)); + listener.set_nonblocking(true)?; + let accept_start = std::time::Instant::now(); + let accept_result = loop { + match listener.accept() { + Ok(conn) => break Ok(conn), + Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { + std::thread::sleep(std::time::Duration::from_millis(100)); + if accept_start.elapsed() >= accept_timeout + || std::time::Instant::now() >= deadline + { + break Err(std::io::Error::new( + std::io::ErrorKind::TimedOut, + "accept timeout", + )); + } + } + Err(e) => break Err(e), + } + }; + + let (mut stream, _) = match accept_result { + Ok(conn) => conn, + Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => continue, + Err(e) => return Err(e.into()), + }; + + let mut buffer = [0u8; 8192]; + let bytes_read = match stream.read(&mut buffer) { + Ok(n) => n, + Err(_) => continue, + }; + if bytes_read == 0 { + continue; + } + + let request = String::from_utf8_lossy(&buffer[..bytes_read]); + let Some(request_line) = request.lines().next() else { + continue; + }; + + let mut parts = request_line.split_whitespace(); + let method = parts.next().unwrap_or_default(); + let request_target = parts.next().unwrap_or_default(); + + if method != "GET" { + let _ = write_http_response( + &mut stream, + "405 Method Not Allowed", + &oauth_callback_error_page("Only GET requests are supported for OAuth callbacks."), + ); + continue; + } + + match parse_oauth_callback_target(request_target, &expected_path, &expected_state) { + Ok(Some(payload)) => { + let _ = write_http_response(&mut stream, "200 OK", &oauth_callback_success_page()); + return Ok(payload.code); + } + Ok(None) => { + let _ = write_http_response( + &mut stream, + "404 Not Found", + &oauth_callback_error_page( + "Received a request for an unexpected callback path.", + ), + ); + } + Err(err) => { + let message = err.to_string(); + let _ = write_http_response( + &mut stream, + "400 Bad Request", + &oauth_callback_error_page(&message), + ); + return Err(err); + } + } + } +} + +#[cfg(test)] +mod tests { + use std::io::{Read, Write}; + use std::net::{TcpListener, TcpStream}; + use std::thread; + + use forge_domain::{OAuthConfig, PkceVerifier, State}; + + use super::*; + + fn sample_code_request(authorization_url: &str) -> CodeRequest { + CodeRequest { + authorization_url: Url::parse(authorization_url).unwrap(), + state: State::from("expected-state".to_string()), + pkce_verifier: Some(PkceVerifier::from("verifier".to_string())), + oauth_config: OAuthConfig { + auth_url: Url::parse("https://auth.openai.com/oauth/authorize").unwrap(), + token_url: Url::parse("https://auth.openai.com/oauth/token").unwrap(), + client_id: "client-id".to_string().into(), + scopes: vec!["openid".to_string()], + redirect_uri: Some("http://localhost:1455/auth/callback".to_string()), + use_pkce: true, + token_refresh_url: None, + custom_headers: None, + extra_auth_params: None, + }, + } + } + + #[test] + fn extracts_localhost_redirect_uri_from_oauth_request() { + let request = sample_code_request( + "https://auth.openai.com/oauth/authorize?client_id=test&redirect_uri=http%3A%2F%2Flocalhost%3A1455%2Fauth%2Fcallback&state=expected-state", + ); + + let redirect_uri = localhost_oauth_redirect_uri(&request).unwrap(); + + assert_eq!(redirect_uri.as_str(), "http://localhost:1455/auth/callback"); + } + + #[test] + fn captures_authorization_code_from_localhost_callback() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + let client = thread::spawn(move || { + let mut stream = TcpStream::connect(addr).unwrap(); + stream + .write_all( + b"GET /auth/callback?code=auth-code&state=expected-state HTTP/1.1\r\nHost: localhost\r\n\r\n", + ) + .unwrap(); + let mut response = String::new(); + stream.read_to_string(&mut response).unwrap(); + response + }); + + let code = wait_for_localhost_oauth_callback( + listener, + "/auth/callback".to_string(), + "expected-state".to_string(), + ) + .unwrap(); + + let response = client.join().unwrap(); + + assert_eq!(code, "auth-code"); + assert!(response.contains("200 OK")); + } + + #[test] + fn rejects_localhost_callback_with_mismatched_state() { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = listener.local_addr().unwrap(); + + let client = thread::spawn(move || { + let mut stream = TcpStream::connect(addr).unwrap(); + stream + .write_all( + b"GET /auth/callback?code=auth-code&state=wrong-state HTTP/1.1\r\nHost: localhost\r\n\r\n", + ) + .unwrap(); + let mut response = String::new(); + stream.read_to_string(&mut response).unwrap(); + response + }); + + let error = wait_for_localhost_oauth_callback( + listener, + "/auth/callback".to_string(), + "expected-state".to_string(), + ) + .unwrap_err(); + + let response = client.join().unwrap(); + + assert!(error.to_string().contains("OAuth state mismatch")); + assert!(response.contains("400 Bad Request")); + } +} diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 316026997b..6033c9b816 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -1,7 +1,5 @@ use std::collections::HashMap; use std::fmt::Display; -use std::io::{Read, Write}; -use std::net::{TcpListener, TcpStream}; use std::path::PathBuf; use std::str::FromStr; use std::sync::Arc; @@ -39,6 +37,7 @@ use crate::editor::ReadLineError; use crate::info::Info; use crate::input::Console; use crate::model::{ForgeCommandManager, SlashCommand}; +use crate::oauth_callback::LocalhostOAuthCallbackServer; use crate::porcelain::Porcelain; use crate::prompt::ForgePrompt; use crate::state::UIState; @@ -61,212 +60,6 @@ struct ConversationDump { related_conversations: Vec, } -#[derive(Debug, Clone, PartialEq, Eq)] -struct OAuthCallbackPayload { - code: String, -} - -fn escape_html(input: &str) -> String { - input - .replace('&', "&") - .replace('<', "<") - .replace('>', ">") - .replace('"', """) - .replace('\'', "'") -} - -fn localhost_oauth_redirect_uri(request: &CodeRequest) -> Option { - request - .oauth_config - .redirect_uri - .as_ref() - .and_then(|uri| Url::parse(uri).ok()) - .filter(|uri| { - uri.scheme() == "http" - && matches!(uri.host_str(), Some("localhost") | Some("127.0.0.1")) - && uri.port().is_some() - }) -} - -fn localhost_oauth_bind_addr(redirect_uri: &Url) -> anyhow::Result { - let host = match redirect_uri.host_str() { - Some("localhost") => "127.0.0.1", - Some(host) => host, - None => anyhow::bail!("OAuth redirect URI is missing a host"), - }; - let port = redirect_uri - .port() - .ok_or_else(|| anyhow::anyhow!("OAuth redirect URI is missing an explicit port"))?; - Ok(format!("{host}:{port}")) -} - -fn oauth_callback_success_page() -> String { - "ForgeCode Authorization Successful

Authorization Successful

You can close this window and return to ForgeCode.

".to_string() -} - -fn oauth_callback_error_page(message: &str) -> String { - format!( - "ForgeCode Authorization Failed

Authorization Failed

ForgeCode could not complete sign-in.

{}
", - escape_html(message) - ) -} - -fn write_http_response( - stream: &mut TcpStream, - status_line: &str, - body: &str, -) -> anyhow::Result<()> { - let response = format!( - "HTTP/1.1 {status_line}\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", - body.len(), - body - ); - stream.write_all(response.as_bytes())?; - stream.flush()?; - Ok(()) -} - -fn parse_oauth_callback_target( - request_target: &str, - expected_path: &str, - expected_state: &str, -) -> anyhow::Result> { - let callback_url = Url::parse(&format!("http://localhost{request_target}"))?; - if callback_url.path() != expected_path { - return Ok(None); - } - - let params: HashMap = callback_url.query_pairs().into_owned().collect(); - if let Some(error) = params.get("error") { - let detail = params - .get("error_description") - .filter(|value| !value.trim().is_empty()) - .map(|value| format!(": {value}")) - .unwrap_or_default(); - anyhow::bail!("Authorization failed ({error}{detail})"); - } - - let state = params - .get("state") - .filter(|value| !value.trim().is_empty()) - .ok_or_else(|| anyhow::anyhow!("Missing OAuth state in callback"))?; - if state != expected_state { - anyhow::bail!("OAuth state mismatch. Please try again."); - } - - let code = params - .get("code") - .filter(|value| !value.trim().is_empty()) - .ok_or_else(|| anyhow::anyhow!("Missing authorization code in callback"))? - .to_string(); - - Ok(Some(OAuthCallbackPayload { code })) -} - -/// Maximum time to wait for the OAuth browser callback before giving up. -const OAUTH_CALLBACK_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300); - -fn wait_for_localhost_oauth_callback( - listener: TcpListener, - expected_path: String, - expected_state: String, -) -> anyhow::Result { - let deadline = std::time::Instant::now() + OAUTH_CALLBACK_TIMEOUT; - listener.set_nonblocking(false)?; - - loop { - let remaining = deadline - .checked_duration_since(std::time::Instant::now()) - .ok_or_else(|| { - anyhow::anyhow!( - "Timed out waiting for OAuth callback after {} seconds", - OAUTH_CALLBACK_TIMEOUT.as_secs() - ) - })?; - - // Cap each accept at the remaining time so we re-check the deadline - let accept_timeout = remaining.min(std::time::Duration::from_secs(5)); - listener.set_nonblocking(true)?; - let accept_start = std::time::Instant::now(); - let accept_result = loop { - match listener.accept() { - Ok(conn) => break Ok(conn), - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { - std::thread::sleep(std::time::Duration::from_millis(100)); - if accept_start.elapsed() >= accept_timeout - || std::time::Instant::now() >= deadline - { - // Will be caught by the remaining check at top of outer loop - break Err(std::io::Error::new( - std::io::ErrorKind::TimedOut, - "accept timeout", - )); - } - } - Err(e) => break Err(e), - } - }; - - let (mut stream, _) = match accept_result { - Ok(conn) => conn, - Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => continue, - Err(e) => return Err(e.into()), - }; - - let mut buffer = [0u8; 8192]; - let bytes_read = match stream.read(&mut buffer) { - Ok(n) => n, - Err(_) => continue, - }; - if bytes_read == 0 { - continue; - } - - let request = String::from_utf8_lossy(&buffer[..bytes_read]); - let Some(request_line) = request.lines().next() else { - continue; - }; - - let mut parts = request_line.split_whitespace(); - let method = parts.next().unwrap_or_default(); - let request_target = parts.next().unwrap_or_default(); - - if method != "GET" { - let _ = write_http_response( - &mut stream, - "405 Method Not Allowed", - &oauth_callback_error_page("Only GET requests are supported for OAuth callbacks."), - ); - continue; - } - - match parse_oauth_callback_target(request_target, &expected_path, &expected_state) { - Ok(Some(payload)) => { - let _ = write_http_response(&mut stream, "200 OK", &oauth_callback_success_page()); - return Ok(payload.code); - } - Ok(None) => { - let _ = write_http_response( - &mut stream, - "404 Not Found", - &oauth_callback_error_page( - "Received a request for an unexpected callback path.", - ), - ); - } - Err(err) => { - let message = err.to_string(); - let _ = write_http_response( - &mut stream, - "400 Bad Request", - &oauth_callback_error_page(&message), - ); - return Err(err); - } - } - } -} - /// Formats an MCP server config for display, redacting sensitive information. /// Returns the command/URL string only. fn format_mcp_server(server: &forge_domain::McpServerConfig) -> String { @@ -2641,30 +2434,20 @@ impl A + Send + Sync> UI { format!("Authenticate using your {provider_id} account").dimmed() ))?; - let callback_task = if let Some(redirect_uri) = localhost_oauth_redirect_uri(request) { - match localhost_oauth_bind_addr(&redirect_uri) - .and_then(|addr| TcpListener::bind(&addr).map_err(Into::into)) - { - Ok(listener) => { - let callback_path = redirect_uri.path().to_string(); - let expected_state = request.state.to_string(); - self.writeln(format!( - "{} Waiting for browser callback on {}", - "→".blue(), - redirect_uri.as_str().blue().underline() - ))?; - - Some(tokio::task::spawn_blocking(move || { - wait_for_localhost_oauth_callback(listener, callback_path, expected_state) - })) - } - Err(_) => { - // Port in use or bind failed — fall back to manual paste - None - } + let callback_server = match LocalhostOAuthCallbackServer::start(request) { + Ok(Some(server)) => { + self.writeln(format!( + "{} Waiting for browser callback on {}", + "→".blue(), + server.redirect_uri().as_str().blue().underline() + ))?; + Some(server) + } + Ok(None) | Err(_) => { + // Not a localhost callback flow, or the listener could not be + // started — fall back to manual code paste. + None } - } else { - None }; // Display authorization URL @@ -2681,9 +2464,8 @@ impl A + Send + Sync> UI { )))?; } - let code = if let Some(task) = callback_task { - task.await - .map_err(|e| anyhow::anyhow!("OAuth callback task failed: {e}"))?? + let code = if let Some(server) = callback_server { + server.wait_for_code().await? } else { // Prompt user to paste authorization code let code = ForgeWidget::input("Paste the authorization code") @@ -4409,105 +4191,3 @@ impl A + Send + Sync> UI { }); } } - -#[cfg(test)] -mod tests { - use std::io::{Read, Write}; - use std::net::{TcpListener, TcpStream}; - use std::thread; - - use forge_domain::{CodeRequest, OAuthConfig, PkceVerifier, State}; - use url::Url; - - use super::{localhost_oauth_redirect_uri, wait_for_localhost_oauth_callback}; - - fn sample_code_request(authorization_url: &str) -> CodeRequest { - CodeRequest { - authorization_url: Url::parse(authorization_url).unwrap(), - state: State::from("expected-state".to_string()), - pkce_verifier: Some(PkceVerifier::from("verifier".to_string())), - oauth_config: OAuthConfig { - auth_url: Url::parse("https://auth.openai.com/oauth/authorize").unwrap(), - token_url: Url::parse("https://auth.openai.com/oauth/token").unwrap(), - client_id: "client-id".to_string().into(), - scopes: vec!["openid".to_string()], - redirect_uri: Some("http://localhost:1455/auth/callback".to_string()), - use_pkce: true, - token_refresh_url: None, - custom_headers: None, - extra_auth_params: None, - }, - } - } - - #[test] - fn extracts_localhost_redirect_uri_from_oauth_request() { - let request = sample_code_request( - "https://auth.openai.com/oauth/authorize?client_id=test&redirect_uri=http%3A%2F%2Flocalhost%3A1455%2Fauth%2Fcallback&state=expected-state", - ); - - let redirect_uri = localhost_oauth_redirect_uri(&request).unwrap(); - - assert_eq!(redirect_uri.as_str(), "http://localhost:1455/auth/callback"); - } - - #[test] - fn captures_authorization_code_from_localhost_callback() { - let listener = TcpListener::bind("127.0.0.1:0").unwrap(); - let addr = listener.local_addr().unwrap(); - - let client = thread::spawn(move || { - let mut stream = TcpStream::connect(addr).unwrap(); - stream - .write_all( - b"GET /auth/callback?code=auth-code&state=expected-state HTTP/1.1\r\nHost: localhost\r\n\r\n", - ) - .unwrap(); - let mut response = String::new(); - stream.read_to_string(&mut response).unwrap(); - response - }); - - let code = wait_for_localhost_oauth_callback( - listener, - "/auth/callback".to_string(), - "expected-state".to_string(), - ) - .unwrap(); - - let response = client.join().unwrap(); - - assert_eq!(code, "auth-code"); - assert!(response.contains("200 OK")); - } - - #[test] - fn rejects_localhost_callback_with_mismatched_state() { - let listener = TcpListener::bind("127.0.0.1:0").unwrap(); - let addr = listener.local_addr().unwrap(); - - let client = thread::spawn(move || { - let mut stream = TcpStream::connect(addr).unwrap(); - stream - .write_all( - b"GET /auth/callback?code=auth-code&state=wrong-state HTTP/1.1\r\nHost: localhost\r\n\r\n", - ) - .unwrap(); - let mut response = String::new(); - stream.read_to_string(&mut response).unwrap(); - response - }); - - let error = wait_for_localhost_oauth_callback( - listener, - "/auth/callback".to_string(), - "expected-state".to_string(), - ) - .unwrap_err(); - - let response = client.join().unwrap(); - - assert!(error.to_string().contains("OAuth state mismatch")); - assert!(response.contains("400 Bad Request")); - } -} From 59030e4682665976b259db1e9f033adc9dd878ce Mon Sep 17 00:00:00 2001 From: Amit Singh Date: Fri, 3 Apr 2026 23:24:29 +0530 Subject: [PATCH 07/11] Minimize remaining ui.rs refactor changes Use the OAuth callback server with a fully qualified path so the UI file keeps fewer incidental changes. Co-Authored-By: ForgeCode --- crates/forge_main/src/ui.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 6033c9b816..3c49e14bb9 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -37,7 +37,6 @@ use crate::editor::ReadLineError; use crate::info::Info; use crate::input::Console; use crate::model::{ForgeCommandManager, SlashCommand}; -use crate::oauth_callback::LocalhostOAuthCallbackServer; use crate::porcelain::Porcelain; use crate::prompt::ForgePrompt; use crate::state::UIState; @@ -2434,7 +2433,7 @@ impl A + Send + Sync> UI { format!("Authenticate using your {provider_id} account").dimmed() ))?; - let callback_server = match LocalhostOAuthCallbackServer::start(request) { + let callback_server = match crate::oauth_callback::LocalhostOAuthCallbackServer::start(request) { Ok(Some(server)) => { self.writeln(format!( "{} Waiting for browser callback on {}", From 3f197c3503d4a3000bbe647a673e554af36cd3b1 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Fri, 3 Apr 2026 17:56:22 +0000 Subject: [PATCH 08/11] [autofix.ci] apply automated fixes --- crates/forge_main/src/ui.rs | 31 ++++++++++++++++--------------- 1 file changed, 16 insertions(+), 15 deletions(-) diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 3c49e14bb9..9fb7b35958 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -2433,21 +2433,22 @@ impl A + Send + Sync> UI { format!("Authenticate using your {provider_id} account").dimmed() ))?; - let callback_server = match crate::oauth_callback::LocalhostOAuthCallbackServer::start(request) { - Ok(Some(server)) => { - self.writeln(format!( - "{} Waiting for browser callback on {}", - "→".blue(), - server.redirect_uri().as_str().blue().underline() - ))?; - Some(server) - } - Ok(None) | Err(_) => { - // Not a localhost callback flow, or the listener could not be - // started — fall back to manual code paste. - None - } - }; + let callback_server = + match crate::oauth_callback::LocalhostOAuthCallbackServer::start(request) { + Ok(Some(server)) => { + self.writeln(format!( + "{} Waiting for browser callback on {}", + "→".blue(), + server.redirect_uri().as_str().blue().underline() + ))?; + Some(server) + } + Ok(None) | Err(_) => { + // Not a localhost callback flow, or the listener could not be + // started — fall back to manual code paste. + None + } + }; // Display authorization URL self.writeln(format!( From c8ff443314d173b1367cbfc72f233169224c8088 Mon Sep 17 00:00:00 2001 From: Amit Singh Date: Fri, 3 Apr 2026 23:28:11 +0530 Subject: [PATCH 09/11] Restore ui.rs test module after refactor Keep the existing placeholder test module in ui.rs so the refactor does not remove unrelated test scaffolding. Co-Authored-By: ForgeCode --- crates/forge_main/src/ui.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/crates/forge_main/src/ui.rs b/crates/forge_main/src/ui.rs index 9fb7b35958..3d6b946bac 100644 --- a/crates/forge_main/src/ui.rs +++ b/crates/forge_main/src/ui.rs @@ -4191,3 +4191,11 @@ impl A + Send + Sync> UI { }); } } + +#[cfg(test)] +mod tests { + // Note: Tests for confirm_delete_conversation are disabled because + // ForgeSelect::confirm is not easily mockable in the current + // architecture. The functionality is tested through integration tests + // instead. +} From ca76e48c4acc531b854986957d2a3e45a92a15f1 Mon Sep 17 00:00:00 2001 From: Amit Singh Date: Sat, 4 Apr 2026 07:27:15 +0530 Subject: [PATCH 10/11] fix(oauth): replace manual TCP handling with tiny_http for callback server --- Cargo.lock | 25 ++ Cargo.toml | 1 + crates/forge_main/Cargo.toml | 1 + crates/forge_main/src/info.rs | 2 +- crates/forge_main/src/oauth_callback.rs | 493 ++++++++++++++++-------- 5 files changed, 359 insertions(+), 163 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index afab7b0401..1499dde807 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -114,6 +114,12 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "ascii" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d92bec98840b8f03a5ff5413de5293bfcd8bf96467cf5452609f939ec6f5de16" + [[package]] name = "assert-json-diff" version = "2.0.2" @@ -743,6 +749,12 @@ dependencies = [ "windows-link 0.2.1", ] +[[package]] +name = "chunked_transfer" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e4de3bc4ea267985becf712dc6d9eed8b04c953b3fcfb339ebc87acd9804901" + [[package]] name = "clap" version = "4.6.0" @@ -2150,6 +2162,7 @@ dependencies = [ "tempfile", "terminal_size", "thiserror 2.0.18", + "tiny_http", "tokio", "tokio-stream", "tracing", @@ -6440,6 +6453,18 @@ dependencies = [ "crunchy", ] +[[package]] +name = "tiny_http" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "389915df6413a2e74fb181895f933386023c71110878cd0825588928e64cdc82" +dependencies = [ + "ascii", + "chunked_transfer", + "httpdate", + "log", +] + [[package]] name = "tinystr" version = "0.8.2" diff --git a/Cargo.toml b/Cargo.toml index ecdc37bcb7..622cd79f1a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -89,6 +89,7 @@ syn = { version = "2.0.117", features = ["derive", "parsing"] } sysinfo = "0.38.3" tempfile = "3.27.0" termimad = "0.34.1" +tiny_http = "0.12.0" syntect = { version = "5", default-features = false, features = ["default-syntaxes", "default-themes", "regex-onig"] } thiserror = "2.0.18" toml_edit = { version = "0.22", features = ["serde"] } diff --git a/crates/forge_main/Cargo.toml b/crates/forge_main/Cargo.toml index 10d4ba741e..9357f0da16 100644 --- a/crates/forge_main/Cargo.toml +++ b/crates/forge_main/Cargo.toml @@ -66,6 +66,7 @@ strip-ansi-escapes.workspace = true terminal_size = "0.4" rustls.workspace = true tempfile.workspace = true +tiny_http.workspace = true [target.'cfg(windows)'.dependencies] enable-ansi-support.workspace = true diff --git a/crates/forge_main/src/info.rs b/crates/forge_main/src/info.rs index b0815a8799..074e8e9711 100644 --- a/crates/forge_main/src/info.rs +++ b/crates/forge_main/src/info.rs @@ -75,7 +75,7 @@ impl Section { /// # Output Format /// /// ```text -/// +/// /// CONFIGURATION /// model gpt-4 /// provider openai diff --git a/crates/forge_main/src/oauth_callback.rs b/crates/forge_main/src/oauth_callback.rs index 033330990b..50bf6aaba0 100644 --- a/crates/forge_main/src/oauth_callback.rs +++ b/crates/forge_main/src/oauth_callback.rs @@ -1,23 +1,52 @@ use std::collections::HashMap; -use std::io::{Read, Write}; -use std::net::{TcpListener, TcpStream}; +use std::net::{IpAddr, SocketAddr, TcpListener}; +use std::sync::{ + Arc, + atomic::{AtomicBool, Ordering}, +}; +use std::time::{Duration, Instant}; use forge_domain::CodeRequest; -use url::Url; +use tiny_http::{Header, Method, Request, Response, Server, StatusCode}; +use url::{Host, Url}; /// Maximum time to wait for the OAuth browser callback before giving up. -const OAUTH_CALLBACK_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(300); +const OAUTH_CALLBACK_TIMEOUT: Duration = Duration::from_secs(300); +const CALLBACK_POLL_INTERVAL: Duration = Duration::from_secs(1); #[derive(Debug, Clone, PartialEq, Eq)] struct OAuthCallbackPayload { code: String, } +#[derive(Debug, Clone, PartialEq, Eq)] +enum OAuthCallbackParseResult { + PathMismatch, + InvalidRequest(String), + OAuthError(String), + Success(OAuthCallbackPayload), +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum CallbackRequestDisposition { + Continue { + status_code: StatusCode, + message: String, + }, + Fail { + status_code: StatusCode, + message: String, + }, + Complete(OAuthCallbackPayload), +} + /// Localhost OAuth callback server that waits for a browser redirect and /// returns the authorization code. pub(crate) struct LocalhostOAuthCallbackServer { redirect_uri: Url, - task: tokio::task::JoinHandle>, + server: Arc, + shutdown: Arc, + task: Option>>, } impl LocalhostOAuthCallbackServer { @@ -29,7 +58,7 @@ impl LocalhostOAuthCallbackServer { /// /// # Errors /// - /// Returns an error if the localhost redirect URI is invalid or the TCP + /// Returns an error if the localhost redirect URI is invalid or the HTTP /// listener cannot be bound. pub(crate) fn start(request: &CodeRequest) -> anyhow::Result> { let Some(redirect_uri) = localhost_oauth_redirect_uri(request) else { @@ -39,11 +68,24 @@ impl LocalhostOAuthCallbackServer { let listener = TcpListener::bind(localhost_oauth_bind_addr(&redirect_uri)?)?; let callback_path = redirect_uri.path().to_string(); let expected_state = request.state.to_string(); - let task = tokio::task::spawn_blocking(move || { - wait_for_localhost_oauth_callback(listener, callback_path, expected_state) + let server = Arc::new(Server::from_listener(listener, None).map_err(|e| { + anyhow::anyhow!("Failed to start localhost OAuth callback server: {e}") + })?); + let shutdown = Arc::new(AtomicBool::new(false)); + let task = tokio::task::spawn_blocking({ + let server = Arc::clone(&server); + let shutdown = Arc::clone(&shutdown); + move || { + wait_for_localhost_oauth_callback(server, callback_path, expected_state, shutdown) + } }); - Ok(Some(Self { redirect_uri, task })) + Ok(Some(Self { + redirect_uri, + server, + shutdown, + task: Some(task), + })) } /// Returns the redirect URI the callback server is listening on. @@ -57,13 +99,22 @@ impl LocalhostOAuthCallbackServer { /// /// Returns an error when the background task fails or the callback request /// is invalid. - pub(crate) async fn wait_for_code(self) -> anyhow::Result { + pub(crate) async fn wait_for_code(mut self) -> anyhow::Result { self.task + .take() + .expect("OAuth callback task should exist") .await .map_err(|e| anyhow::anyhow!("OAuth callback task failed: {e}"))? } } +impl Drop for LocalhostOAuthCallbackServer { + fn drop(&mut self) { + self.shutdown.store(true, Ordering::Relaxed); + self.server.unblock(); + } +} + fn escape_html(input: &str) -> String { input .replace('&', "&") @@ -81,21 +132,28 @@ fn localhost_oauth_redirect_uri(request: &CodeRequest) -> Option { .and_then(|uri| Url::parse(uri).ok()) .filter(|uri| { uri.scheme() == "http" - && matches!(uri.host_str(), Some("localhost") | Some("127.0.0.1")) && uri.port().is_some() + && (matches!(uri.host(), Some(Host::Domain("localhost"))) + || uri.host().is_some_and(|host| match host { + Host::Ipv4(ip) => ip.is_loopback(), + Host::Ipv6(ip) => ip.is_loopback(), + Host::Domain(_) => false, + })) }) } -fn localhost_oauth_bind_addr(redirect_uri: &Url) -> anyhow::Result { - let host = match redirect_uri.host_str() { - Some("localhost") => "127.0.0.1", - Some(host) => host, - None => anyhow::bail!("OAuth redirect URI is missing a host"), - }; +fn localhost_oauth_bind_addr(redirect_uri: &Url) -> anyhow::Result { let port = redirect_uri .port() .ok_or_else(|| anyhow::anyhow!("OAuth redirect URI is missing an explicit port"))?; - Ok(format!("{host}:{port}")) + + match redirect_uri.host() { + Some(Host::Domain("localhost")) => Ok(SocketAddr::from(([127, 0, 0, 1], port))), + Some(Host::Ipv4(ip)) if ip.is_loopback() => Ok(SocketAddr::new(IpAddr::V4(ip), port)), + Some(Host::Ipv6(ip)) if ip.is_loopback() => Ok(SocketAddr::new(IpAddr::V6(ip), port)), + Some(_) => anyhow::bail!("OAuth redirect URI host must be localhost or loopback"), + None => anyhow::bail!("OAuth redirect URI is missing a host"), + } } fn oauth_callback_success_page() -> String { @@ -109,29 +167,41 @@ fn oauth_callback_error_page(message: &str) -> String { ) } -fn write_http_response( - stream: &mut TcpStream, - status_line: &str, - body: &str, -) -> anyhow::Result<()> { - let response = format!( - "HTTP/1.1 {status_line}\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", - body.len(), - body - ); - stream.write_all(response.as_bytes())?; - stream.flush()?; - Ok(()) +fn html_response(status_code: StatusCode, body: String) -> Response>> { + Response::from_string(body) + .with_status_code(status_code) + .with_header(response_header("Content-Type", "text/html; charset=utf-8")) + .with_header(response_header("Cache-Control", "no-store")) + .with_header(response_header("X-Content-Type-Options", "nosniff")) +} + +fn response_header(name: &str, value: &str) -> Header { + Header::from_bytes(name.as_bytes(), value.as_bytes()) + .expect("static HTTP header should be valid") } fn parse_oauth_callback_target( request_target: &str, expected_path: &str, expected_state: &str, -) -> anyhow::Result> { - let callback_url = Url::parse(&format!("http://localhost{request_target}"))?; +) -> OAuthCallbackParseResult { + let Some(request_target) = request_target.strip_prefix('/') else { + return OAuthCallbackParseResult::InvalidRequest( + "Malformed OAuth callback request target".to_string(), + ); + }; + + let callback_url = match Url::parse(&format!("http://localhost/{request_target}")) { + Ok(url) => url, + Err(_) => { + return OAuthCallbackParseResult::InvalidRequest( + "Malformed OAuth callback request target".to_string(), + ); + } + }; + if callback_url.path() != expected_path { - return Ok(None); + return OAuthCallbackParseResult::PathMismatch; } let params: HashMap = callback_url.query_pairs().into_owned().collect(); @@ -141,37 +211,105 @@ fn parse_oauth_callback_target( .filter(|value| !value.trim().is_empty()) .map(|value| format!(": {value}")) .unwrap_or_default(); - anyhow::bail!("Authorization failed ({error}{detail})"); + return OAuthCallbackParseResult::OAuthError(format!( + "Authorization failed ({error}{detail})" + )); } - let state = params - .get("state") - .filter(|value| !value.trim().is_empty()) - .ok_or_else(|| anyhow::anyhow!("Missing OAuth state in callback"))?; + let Some(state) = params.get("state").filter(|value| !value.trim().is_empty()) else { + return OAuthCallbackParseResult::InvalidRequest( + "Missing OAuth state in callback".to_string(), + ); + }; if state != expected_state { - anyhow::bail!("OAuth state mismatch. Please try again."); + return OAuthCallbackParseResult::InvalidRequest( + "OAuth state mismatch. Please try again.".to_string(), + ); } - let code = params + let Some(code) = params .get("code") .filter(|value| !value.trim().is_empty()) - .ok_or_else(|| anyhow::anyhow!("Missing authorization code in callback"))? - .to_string(); + .cloned() + else { + return OAuthCallbackParseResult::InvalidRequest( + "Missing authorization code in callback".to_string(), + ); + }; + + OAuthCallbackParseResult::Success(OAuthCallbackPayload { code }) +} + +fn classify_callback_request( + request: &Request, + expected_path: &str, + expected_state: &str, +) -> CallbackRequestDisposition { + if request + .remote_addr() + .is_some_and(|remote_addr| !remote_addr.ip().is_loopback()) + { + return CallbackRequestDisposition::Continue { + status_code: StatusCode(403), + message: "Only loopback callback requests are accepted.".to_string(), + }; + } + + if request.method() != &Method::Get { + return CallbackRequestDisposition::Continue { + status_code: StatusCode(405), + message: "Only GET requests are supported for OAuth callbacks.".to_string(), + }; + } - Ok(Some(OAuthCallbackPayload { code })) + match parse_oauth_callback_target(request.url(), expected_path, expected_state) { + OAuthCallbackParseResult::PathMismatch => CallbackRequestDisposition::Continue { + status_code: StatusCode(404), + message: "Received a request for an unexpected callback path.".to_string(), + }, + OAuthCallbackParseResult::InvalidRequest(message) => { + CallbackRequestDisposition::Continue { status_code: StatusCode(400), message } + } + OAuthCallbackParseResult::OAuthError(message) => { + CallbackRequestDisposition::Fail { status_code: StatusCode(400), message } + } + OAuthCallbackParseResult::Success(payload) => CallbackRequestDisposition::Complete(payload), + } +} + +fn respond_to_callback_request( + request: Request, + disposition: &CallbackRequestDisposition, +) -> anyhow::Result<()> { + let response = match disposition { + CallbackRequestDisposition::Continue { status_code, message } + | CallbackRequestDisposition::Fail { status_code, message } => { + html_response(*status_code, oauth_callback_error_page(message)) + } + CallbackRequestDisposition::Complete(_) => { + html_response(StatusCode(200), oauth_callback_success_page()) + } + }; + + request.respond(response)?; + Ok(()) } fn wait_for_localhost_oauth_callback( - listener: TcpListener, + server: Arc, expected_path: String, expected_state: String, + shutdown: Arc, ) -> anyhow::Result { - let deadline = std::time::Instant::now() + OAUTH_CALLBACK_TIMEOUT; - listener.set_nonblocking(false)?; + let deadline = Instant::now() + OAUTH_CALLBACK_TIMEOUT; loop { + if shutdown.load(Ordering::Relaxed) { + anyhow::bail!("OAuth callback listener was cancelled"); + } + let remaining = deadline - .checked_duration_since(std::time::Instant::now()) + .checked_duration_since(Instant::now()) .ok_or_else(|| { anyhow::anyhow!( "Timed out waiting for OAuth callback after {} seconds", @@ -179,83 +317,18 @@ fn wait_for_localhost_oauth_callback( ) })?; - let accept_timeout = remaining.min(std::time::Duration::from_secs(5)); - listener.set_nonblocking(true)?; - let accept_start = std::time::Instant::now(); - let accept_result = loop { - match listener.accept() { - Ok(conn) => break Ok(conn), - Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => { - std::thread::sleep(std::time::Duration::from_millis(100)); - if accept_start.elapsed() >= accept_timeout - || std::time::Instant::now() >= deadline - { - break Err(std::io::Error::new( - std::io::ErrorKind::TimedOut, - "accept timeout", - )); - } - } - Err(e) => break Err(e), - } - }; - - let (mut stream, _) = match accept_result { - Ok(conn) => conn, - Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => continue, - Err(e) => return Err(e.into()), - }; - - let mut buffer = [0u8; 8192]; - let bytes_read = match stream.read(&mut buffer) { - Ok(n) => n, - Err(_) => continue, - }; - if bytes_read == 0 { - continue; - } - - let request = String::from_utf8_lossy(&buffer[..bytes_read]); - let Some(request_line) = request.lines().next() else { + let timeout = remaining.min(CALLBACK_POLL_INTERVAL); + let Some(request) = server.recv_timeout(timeout)? else { continue; }; - let mut parts = request_line.split_whitespace(); - let method = parts.next().unwrap_or_default(); - let request_target = parts.next().unwrap_or_default(); + let actual = classify_callback_request(&request, &expected_path, &expected_state); + let _ = respond_to_callback_request(request, &actual); - if method != "GET" { - let _ = write_http_response( - &mut stream, - "405 Method Not Allowed", - &oauth_callback_error_page("Only GET requests are supported for OAuth callbacks."), - ); - continue; - } - - match parse_oauth_callback_target(request_target, &expected_path, &expected_state) { - Ok(Some(payload)) => { - let _ = write_http_response(&mut stream, "200 OK", &oauth_callback_success_page()); - return Ok(payload.code); - } - Ok(None) => { - let _ = write_http_response( - &mut stream, - "404 Not Found", - &oauth_callback_error_page( - "Received a request for an unexpected callback path.", - ), - ); - } - Err(err) => { - let message = err.to_string(); - let _ = write_http_response( - &mut stream, - "400 Bad Request", - &oauth_callback_error_page(&message), - ); - return Err(err); - } + match actual { + CallbackRequestDisposition::Continue { .. } => continue, + CallbackRequestDisposition::Fail { message, .. } => anyhow::bail!(message), + CallbackRequestDisposition::Complete(payload) => return Ok(payload.code), } } } @@ -263,10 +336,15 @@ fn wait_for_localhost_oauth_callback( #[cfg(test)] mod tests { use std::io::{Read, Write}; - use std::net::{TcpListener, TcpStream}; + use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream}; + use std::sync::{ + Arc, + atomic::AtomicBool, + }; use std::thread; use forge_domain::{OAuthConfig, PkceVerifier, State}; + use pretty_assertions::assert_eq; use super::*; @@ -289,74 +367,165 @@ mod tests { } } + fn sample_callback_server() -> (Arc, SocketAddr, Arc) { + let fixture = TcpListener::bind("127.0.0.1:0").unwrap(); + let addr = fixture.local_addr().unwrap(); + let server = Arc::new(Server::from_listener(fixture, None).unwrap()); + let shutdown = Arc::new(AtomicBool::new(false)); + (server, addr, shutdown) + } + + fn send_http_request(addr: SocketAddr, request: &str) -> String { + let mut fixture = TcpStream::connect(addr).unwrap(); + fixture.write_all(request.as_bytes()).unwrap(); + fixture.shutdown(Shutdown::Write).unwrap(); + + let mut actual = String::new(); + fixture.read_to_string(&mut actual).unwrap(); + actual + } + #[test] fn extracts_localhost_redirect_uri_from_oauth_request() { - let request = sample_code_request( + let setup = sample_code_request( "https://auth.openai.com/oauth/authorize?client_id=test&redirect_uri=http%3A%2F%2Flocalhost%3A1455%2Fauth%2Fcallback&state=expected-state", ); - let redirect_uri = localhost_oauth_redirect_uri(&request).unwrap(); + let actual = localhost_oauth_redirect_uri(&setup).unwrap(); - assert_eq!(redirect_uri.as_str(), "http://localhost:1455/auth/callback"); + let expected = "http://localhost:1455/auth/callback"; + assert_eq!(actual.as_str(), expected); + } + + #[test] + fn extracts_ipv6_loopback_redirect_uri_from_oauth_request() { + let mut setup = sample_code_request("https://auth.openai.com/oauth/authorize"); + setup.oauth_config.redirect_uri = Some("http://[::1]:1455/auth/callback".to_string()); + + let actual = localhost_oauth_redirect_uri(&setup).unwrap(); + + let expected = "http://[::1]:1455/auth/callback"; + assert_eq!(actual.as_str(), expected); } #[test] fn captures_authorization_code_from_localhost_callback() { - let listener = TcpListener::bind("127.0.0.1:0").unwrap(); - let addr = listener.local_addr().unwrap(); - - let client = thread::spawn(move || { - let mut stream = TcpStream::connect(addr).unwrap(); - stream - .write_all( - b"GET /auth/callback?code=auth-code&state=expected-state HTTP/1.1\r\nHost: localhost\r\n\r\n", - ) - .unwrap(); - let mut response = String::new(); - stream.read_to_string(&mut response).unwrap(); - response + let setup = sample_callback_server(); + let server = Arc::clone(&setup.0); + let addr = setup.1; + let shutdown = Arc::clone(&setup.2); + let fixture = thread::spawn(move || { + send_http_request( + addr, + "GET /auth/callback?code=auth-code&state=expected-state HTTP/1.1\r\nHost: localhost\r\n\r\n", + ) }); - let code = wait_for_localhost_oauth_callback( - listener, + let actual = wait_for_localhost_oauth_callback( + server, "/auth/callback".to_string(), "expected-state".to_string(), + shutdown, ) .unwrap(); - let response = client.join().unwrap(); - - assert_eq!(code, "auth-code"); + let response = fixture.join().unwrap(); + let expected = "auth-code".to_string(); + assert_eq!(actual, expected); assert!(response.contains("200 OK")); } #[test] - fn rejects_localhost_callback_with_mismatched_state() { - let listener = TcpListener::bind("127.0.0.1:0").unwrap(); - let addr = listener.local_addr().unwrap(); - - let client = thread::spawn(move || { - let mut stream = TcpStream::connect(addr).unwrap(); - stream - .write_all( - b"GET /auth/callback?code=auth-code&state=wrong-state HTTP/1.1\r\nHost: localhost\r\n\r\n", - ) - .unwrap(); - let mut response = String::new(); - stream.read_to_string(&mut response).unwrap(); - response + fn keeps_listening_after_invalid_state_until_valid_callback_arrives() { + let setup = sample_callback_server(); + let server = Arc::clone(&setup.0); + let addr = setup.1; + let shutdown = Arc::clone(&setup.2); + let fixture = thread::spawn(move || { + let first = send_http_request( + addr, + "GET /auth/callback?code=auth-code&state=wrong-state HTTP/1.1\r\nHost: localhost\r\n\r\n", + ); + let second = send_http_request( + addr, + "GET /auth/callback?code=auth-code&state=expected-state HTTP/1.1\r\nHost: localhost\r\n\r\n", + ); + (first, second) }); - let error = wait_for_localhost_oauth_callback( - listener, + let actual = wait_for_localhost_oauth_callback( + server, "/auth/callback".to_string(), "expected-state".to_string(), + shutdown, ) - .unwrap_err(); + .unwrap(); + + let responses = fixture.join().unwrap(); + let expected = "auth-code".to_string(); + assert_eq!(actual, expected); + assert!(responses.0.contains("400 Bad Request")); + assert!(responses.1.contains("200 OK")); + } - let response = client.join().unwrap(); + #[test] + fn keeps_listening_after_invalid_method_until_valid_callback_arrives() { + let setup = sample_callback_server(); + let server = Arc::clone(&setup.0); + let addr = setup.1; + let shutdown = Arc::clone(&setup.2); + let fixture = thread::spawn(move || { + let first = send_http_request( + addr, + "POST /auth/callback?code=auth-code&state=expected-state HTTP/1.1\r\nHost: localhost\r\nContent-Length: 0\r\n\r\n", + ); + let second = send_http_request( + addr, + "GET /auth/callback?code=auth-code&state=expected-state HTTP/1.1\r\nHost: localhost\r\n\r\n", + ); + (first, second) + }); - assert!(error.to_string().contains("OAuth state mismatch")); + let actual = wait_for_localhost_oauth_callback( + server, + "/auth/callback".to_string(), + "expected-state".to_string(), + shutdown, + ) + .unwrap(); + + let responses = fixture.join().unwrap(); + let expected = "auth-code".to_string(); + assert_eq!(actual, expected); + assert!(responses.0.contains("405 Method Not Allowed")); + assert!(responses.1.contains("200 OK")); + } + + #[test] + fn stops_when_provider_returns_terminal_oauth_error() { + let setup = sample_callback_server(); + let server = Arc::clone(&setup.0); + let addr = setup.1; + let shutdown = Arc::clone(&setup.2); + let fixture = thread::spawn(move || { + send_http_request( + addr, + "GET /auth/callback?error=access_denied&error_description=user%20cancelled HTTP/1.1\r\nHost: localhost\r\n\r\n", + ) + }); + + let actual = wait_for_localhost_oauth_callback( + server, + "/auth/callback".to_string(), + "expected-state".to_string(), + shutdown, + ) + .unwrap_err(); + + let response = fixture.join().unwrap(); + let expected = "Authorization failed (access_denied: user cancelled)"; + assert_eq!(actual.to_string(), expected); assert!(response.contains("400 Bad Request")); } } + From bd5e3c9f07fabbb49196e7380e6e897369b77995 Mon Sep 17 00:00:00 2001 From: "autofix-ci[bot]" <114827586+autofix-ci[bot]@users.noreply.github.com> Date: Sat, 4 Apr 2026 02:00:41 +0000 Subject: [PATCH 11/11] [autofix.ci] apply automated fixes --- crates/forge_main/src/info.rs | 2 +- crates/forge_main/src/oauth_callback.rs | 13 ++++--------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/crates/forge_main/src/info.rs b/crates/forge_main/src/info.rs index 074e8e9711..b0815a8799 100644 --- a/crates/forge_main/src/info.rs +++ b/crates/forge_main/src/info.rs @@ -75,7 +75,7 @@ impl Section { /// # Output Format /// /// ```text -/// +/// /// CONFIGURATION /// model gpt-4 /// provider openai diff --git a/crates/forge_main/src/oauth_callback.rs b/crates/forge_main/src/oauth_callback.rs index 50bf6aaba0..427e2a7e74 100644 --- a/crates/forge_main/src/oauth_callback.rs +++ b/crates/forge_main/src/oauth_callback.rs @@ -1,9 +1,7 @@ use std::collections::HashMap; use std::net::{IpAddr, SocketAddr, TcpListener}; -use std::sync::{ - Arc, - atomic::{AtomicBool, Ordering}, -}; +use std::sync::Arc; +use std::sync::atomic::{AtomicBool, Ordering}; use std::time::{Duration, Instant}; use forge_domain::CodeRequest; @@ -337,10 +335,8 @@ fn wait_for_localhost_oauth_callback( mod tests { use std::io::{Read, Write}; use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream}; - use std::sync::{ - Arc, - atomic::AtomicBool, - }; + use std::sync::Arc; + use std::sync::atomic::AtomicBool; use std::thread; use forge_domain::{OAuthConfig, PkceVerifier, State}; @@ -528,4 +524,3 @@ mod tests { assert!(response.contains("400 Bad Request")); } } -