Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion codex-rs/cloud-requirements/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ impl CloudRequirementsService {
"Cloud requirements request was unauthorized; attempting auth recovery"
);
match auth_recovery.next().await {
Ok(()) => {
Ok(_) => {
let Some(refreshed_auth) = self.auth_manager.auth().await else {
tracing::error!(
"Auth recovery succeeded but no auth is available for cloud requirements"
Expand Down
9 changes: 8 additions & 1 deletion codex-rs/codex-api/src/endpoint/responses_websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ impl ResponsesWebsocketConnection {
pub async fn stream_request(
&self,
request: ResponsesWsRequest,
connection_reused: bool,
) -> Result<ResponseStream, ApiError> {
let (tx_event, rx_event) =
mpsc::channel::<std::result::Result<ResponseEvent, ApiError>>(1600);
Expand Down Expand Up @@ -258,6 +259,7 @@ impl ResponsesWebsocketConnection {
request_body,
idle_timeout,
telemetry,
connection_reused,
)
.await
};
Expand Down Expand Up @@ -534,6 +536,7 @@ async fn run_websocket_response_stream(
request_body: Value,
idle_timeout: Duration,
telemetry: Option<Arc<dyn WebsocketTelemetry>>,
connection_reused: bool,
) -> Result<(), ApiError> {
let mut last_server_model: Option<String> = None;
let request_text = match serde_json::to_string(&request_body) {
Expand All @@ -553,7 +556,11 @@ async fn run_websocket_response_stream(
.map_err(|err| ApiError::Stream(format!("failed to send websocket request: {err}")));

if let Some(t) = telemetry.as_ref() {
t.on_ws_request(request_start.elapsed(), result.as_ref().err());
t.on_ws_request(
request_start.elapsed(),
result.as_ref().err(),
connection_reused,
);
}

result?;
Expand Down
2 changes: 1 addition & 1 deletion codex-rs/codex-api/src/telemetry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub trait SseTelemetry: Send + Sync {

/// Telemetry for Responses WebSocket transport.
pub trait WebsocketTelemetry: Send + Sync {
fn on_ws_request(&self, duration: Duration, error: Option<&ApiError>);
fn on_ws_request(&self, duration: Duration, error: Option<&ApiError>, connection_reused: bool);

fn on_ws_event(
&self,
Expand Down
44 changes: 44 additions & 0 deletions codex-rs/core/src/api_bridge.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use base64::Engine;
use chrono::DateTime;
use chrono::Utc;
use codex_api::AuthProvider as ApiAuthProvider;
Expand All @@ -7,6 +8,7 @@ use codex_api::rate_limits::parse_promo_message;
use codex_api::rate_limits::parse_rate_limit_for_limit;
use http::HeaderMap;
use serde::Deserialize;
use serde_json::Value;

use crate::auth::CodexAuth;
use crate::error::CodexErr;
Expand All @@ -30,6 +32,8 @@ pub(crate) fn map_api_error(err: ApiError) -> CodexErr {
url: None,
cf_ray: None,
request_id: None,
identity_authorization_error: None,
identity_error_code: None,
}),
ApiError::InvalidRequest { message } => CodexErr::InvalidRequest(message),
ApiError::Transport(transport) => match transport {
Expand Down Expand Up @@ -98,6 +102,11 @@ pub(crate) fn map_api_error(err: ApiError) -> CodexErr {
url,
cf_ray: extract_header(headers.as_ref(), CF_RAY_HEADER),
request_id: extract_request_id(headers.as_ref()),
identity_authorization_error: extract_header(
headers.as_ref(),
X_OPENAI_AUTHORIZATION_ERROR_HEADER,
),
identity_error_code: extract_x_error_json_code(headers.as_ref()),
})
}
}
Expand All @@ -118,6 +127,8 @@ const ACTIVE_LIMIT_HEADER: &str = "x-codex-active-limit";
const REQUEST_ID_HEADER: &str = "x-request-id";
const OAI_REQUEST_ID_HEADER: &str = "x-oai-request-id";
const CF_RAY_HEADER: &str = "cf-ray";
const X_OPENAI_AUTHORIZATION_ERROR_HEADER: &str = "x-openai-authorization-error";
const X_ERROR_JSON_HEADER: &str = "x-error-json";

#[cfg(test)]
#[path = "api_bridge_tests.rs"]
Expand All @@ -140,6 +151,19 @@ fn extract_header(headers: Option<&HeaderMap>, name: &str) -> Option<String> {
})
}

fn extract_x_error_json_code(headers: Option<&HeaderMap>) -> Option<String> {
let encoded = extract_header(headers, X_ERROR_JSON_HEADER)?;
let decoded = base64::engine::general_purpose::STANDARD
.decode(encoded)
.ok()?;
let parsed = serde_json::from_slice::<Value>(&decoded).ok()?;
parsed
.get("error")
.and_then(|error| error.get("code"))
.and_then(Value::as_str)
.map(str::to_string)
}

pub(crate) fn auth_provider_from_auth(
auth: Option<CodexAuth>,
provider: &ModelProviderInfo,
Expand Down Expand Up @@ -191,6 +215,26 @@ pub(crate) struct CoreAuthProvider {
account_id: Option<String>,
}

impl CoreAuthProvider {
pub(crate) fn auth_header_attached(&self) -> bool {
self.token
.as_ref()
.is_some_and(|token| http::HeaderValue::from_str(&format!("Bearer {token}")).is_ok())
}

pub(crate) fn auth_header_name(&self) -> Option<&'static str> {
self.auth_header_attached().then_some("authorization")
}

#[cfg(test)]
pub(crate) fn for_test(token: Option<&str>, account_id: Option<&str>) -> Self {
Self {
token: token.map(str::to_string),
account_id: account_id.map(str::to_string),
}
}
}

impl ApiAuthProvider for CoreAuthProvider {
fn bearer_token(&self) -> Option<String> {
self.token.clone()
Expand Down
47 changes: 47 additions & 0 deletions codex-rs/core/src/api_bridge_tests.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use super::*;
use base64::Engine;
use pretty_assertions::assert_eq;

#[test]
Expand Down Expand Up @@ -94,3 +95,49 @@ fn map_api_error_does_not_fallback_limit_name_to_limit_id() {
None
);
}

#[test]
fn map_api_error_extracts_identity_auth_details_from_headers() {
let mut headers = HeaderMap::new();
headers.insert(REQUEST_ID_HEADER, http::HeaderValue::from_static("req-401"));
headers.insert(CF_RAY_HEADER, http::HeaderValue::from_static("ray-401"));
headers.insert(
X_OPENAI_AUTHORIZATION_ERROR_HEADER,
http::HeaderValue::from_static("missing_authorization_header"),
);
let x_error_json =
base64::engine::general_purpose::STANDARD.encode(r#"{"error":{"code":"token_expired"}}"#);
headers.insert(
X_ERROR_JSON_HEADER,
http::HeaderValue::from_str(&x_error_json).expect("valid x-error-json header"),
);

let err = map_api_error(ApiError::Transport(TransportError::Http {
status: http::StatusCode::UNAUTHORIZED,
url: Some("https://chatgpt.com/backend-api/codex/models".to_string()),
headers: Some(headers),
body: Some(r#"{"detail":"Unauthorized"}"#.to_string()),
}));

let CodexErr::UnexpectedStatus(err) = err else {
panic!("expected CodexErr::UnexpectedStatus, got {err:?}");
};
assert_eq!(err.request_id.as_deref(), Some("req-401"));
assert_eq!(err.cf_ray.as_deref(), Some("ray-401"));
assert_eq!(
err.identity_authorization_error.as_deref(),
Some("missing_authorization_header")
);
assert_eq!(err.identity_error_code.as_deref(), Some("token_expired"));
}

#[test]
fn core_auth_provider_reports_when_auth_header_will_attach() {
let auth = CoreAuthProvider {
token: Some("access-token".to_string()),
account_id: None,
};

assert!(auth.auth_header_attached());
assert_eq!(auth.auth_header_name(), Some("authorization"));
}
73 changes: 70 additions & 3 deletions codex-rs/core/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,17 @@ pub struct UnauthorizedRecovery {
mode: UnauthorizedRecoveryMode,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct UnauthorizedRecoveryStepResult {
auth_state_changed: Option<bool>,
}

impl UnauthorizedRecoveryStepResult {
pub fn auth_state_changed(&self) -> Option<bool> {
self.auth_state_changed
}
}

impl UnauthorizedRecovery {
fn new(manager: Arc<AuthManager>) -> Self {
let cached_auth = manager.auth_cached();
Expand Down Expand Up @@ -917,7 +928,46 @@ impl UnauthorizedRecovery {
!matches!(self.step, UnauthorizedRecoveryStep::Done)
}

pub async fn next(&mut self) -> Result<(), RefreshTokenError> {
pub fn unavailable_reason(&self) -> &'static str {
if !self
.manager
.auth_cached()
.as_ref()
.is_some_and(CodexAuth::is_chatgpt_auth)
{
return "not_chatgpt_auth";
}

if self.mode == UnauthorizedRecoveryMode::External
&& !self.manager.has_external_auth_refresher()
{
return "no_external_refresher";
}

if matches!(self.step, UnauthorizedRecoveryStep::Done) {
return "recovery_exhausted";
}

"ready"
}

pub fn mode_name(&self) -> &'static str {
match self.mode {
UnauthorizedRecoveryMode::Managed => "managed",
UnauthorizedRecoveryMode::External => "external",
}
}

pub fn step_name(&self) -> &'static str {
match self.step {
UnauthorizedRecoveryStep::Reload => "reload",
UnauthorizedRecoveryStep::RefreshToken => "refresh_token",
UnauthorizedRecoveryStep::ExternalRefresh => "external_refresh",
UnauthorizedRecoveryStep::Done => "done",
}
}

pub async fn next(&mut self) -> Result<UnauthorizedRecoveryStepResult, RefreshTokenError> {
if !self.has_next() {
return Err(RefreshTokenError::Permanent(RefreshTokenFailedError::new(
RefreshTokenFailedReason::Other,
Expand All @@ -931,8 +981,17 @@ impl UnauthorizedRecovery {
.manager
.reload_if_account_id_matches(self.expected_account_id.as_deref())
{
ReloadOutcome::ReloadedChanged | ReloadOutcome::ReloadedNoChange => {
ReloadOutcome::ReloadedChanged => {
self.step = UnauthorizedRecoveryStep::RefreshToken;
return Ok(UnauthorizedRecoveryStepResult {
auth_state_changed: Some(true),
});
}
ReloadOutcome::ReloadedNoChange => {
self.step = UnauthorizedRecoveryStep::RefreshToken;
return Ok(UnauthorizedRecoveryStepResult {
auth_state_changed: Some(false),
});
}
ReloadOutcome::Skipped => {
self.step = UnauthorizedRecoveryStep::Done;
Expand All @@ -946,16 +1005,24 @@ impl UnauthorizedRecovery {
UnauthorizedRecoveryStep::RefreshToken => {
self.manager.refresh_token_from_authority().await?;
self.step = UnauthorizedRecoveryStep::Done;
return Ok(UnauthorizedRecoveryStepResult {
auth_state_changed: Some(true),
});
}
UnauthorizedRecoveryStep::ExternalRefresh => {
self.manager
.refresh_external_auth(ExternalAuthRefreshReason::Unauthorized)
.await?;
self.step = UnauthorizedRecoveryStep::Done;
return Ok(UnauthorizedRecoveryStepResult {
auth_state_changed: Some(true),
});
}
UnauthorizedRecoveryStep::Done => {}
}
Ok(())
Ok(UnauthorizedRecoveryStepResult {
auth_state_changed: None,
})
}
}

Expand Down
28 changes: 28 additions & 0 deletions codex-rs/core/src/auth_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use codex_protocol::config_types::ForcedLoginMethod;
use pretty_assertions::assert_eq;
use serde::Serialize;
use serde_json::json;
use std::sync::Arc;
use tempfile::tempdir;

#[tokio::test]
Expand Down Expand Up @@ -171,6 +172,33 @@ fn logout_removes_auth_file() -> Result<(), std::io::Error> {
Ok(())
}

#[test]
fn unauthorized_recovery_reports_mode_and_step_names() {
let dir = tempdir().unwrap();
let manager = AuthManager::shared(
dir.path().to_path_buf(),
false,
AuthCredentialsStoreMode::File,
);
let managed = UnauthorizedRecovery {
manager: Arc::clone(&manager),
step: UnauthorizedRecoveryStep::Reload,
expected_account_id: None,
mode: UnauthorizedRecoveryMode::Managed,
};
assert_eq!(managed.mode_name(), "managed");
assert_eq!(managed.step_name(), "reload");

let external = UnauthorizedRecovery {
manager,
step: UnauthorizedRecoveryStep::ExternalRefresh,
expected_account_id: None,
mode: UnauthorizedRecoveryMode::External,
};
assert_eq!(external.mode_name(), "external");
assert_eq!(external.step_name(), "external_refresh");
}

struct AuthFileParams {
openai_api_key: Option<String>,
chatgpt_plan_type: Option<String>,
Expand Down
Loading
Loading