Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 79 additions & 19 deletions codex-rs/core/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ const REFRESH_TOKEN_REUSED_MESSAGE: &str = "Your access token could not be refre
const REFRESH_TOKEN_INVALIDATED_MESSAGE: &str = "Your access token could not be refreshed because your refresh token was revoked. Please log out and sign in again.";
const REFRESH_TOKEN_UNKNOWN_MESSAGE: &str =
"Your access token could not be refreshed. Please log out and sign in again.";
const REFRESH_TOKEN_ACCOUNT_MISMATCH_MESSAGE: &str = "Your access token could not be refreshed because you have since logged out or signed in to another account. Please sign in again.";
const REFRESH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
pub const REFRESH_TOKEN_URL_OVERRIDE_ENV_VAR: &str = "CODEX_REFRESH_TOKEN_URL_OVERRIDE";

Expand Down Expand Up @@ -584,7 +585,8 @@ fn load_auth(
Ok(Some(auth))
}

fn update_tokens(
// Persist refreshed tokens into auth storage and update last_refresh.
fn persist_tokens(
storage: &Arc<dyn AuthStorageBackend>,
id_token: Option<String>,
access_token: Option<String>,
Expand All @@ -609,7 +611,9 @@ fn update_tokens(
Ok(auth_dot_json)
}

async fn try_refresh_token(
// Requests refreshed ChatGPT OAuth tokens from the auth service using a refresh token.
// The caller is responsible for persisting any returned tokens.
async fn request_chatgpt_token_refresh(
refresh_token: String,
client: &CodexHttpClient,
) -> Result<RefreshResponse, RefreshTokenError> {
Expand Down Expand Up @@ -823,7 +827,11 @@ enum UnauthorizedRecoveryStep {
}

enum ReloadOutcome {
Reloaded,
/// Reload was performed and the cached auth changed
ReloadedChanged,
/// Reload was performed and the cached auth remained the same
ReloadedNoChange,
/// Reload was skipped (missing or mismatched account id)
Skipped,
}

Expand Down Expand Up @@ -910,17 +918,20 @@ impl UnauthorizedRecovery {
.manager
.reload_if_account_id_matches(self.expected_account_id.as_deref())
{
ReloadOutcome::Reloaded => {
ReloadOutcome::ReloadedChanged | ReloadOutcome::ReloadedNoChange => {
self.step = UnauthorizedRecoveryStep::RefreshToken;
}
ReloadOutcome::Skipped => {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Super nit: if we are being precise with ReloadOutcome maybe ReloadOutcome::SkippedAccountMismatch?

self.manager.refresh_token().await?;
self.step = UnauthorizedRecoveryStep::Done;
return Err(RefreshTokenError::Permanent(RefreshTokenFailedError::new(
RefreshTokenFailedReason::Other,
REFRESH_TOKEN_ACCOUNT_MISMATCH_MESSAGE.to_string(),
)));
}
}
}
UnauthorizedRecoveryStep::RefreshToken => {
self.manager.refresh_token().await?;
self.manager.refresh_token_from_authority().await?;
self.step = UnauthorizedRecoveryStep::Done;
}
UnauthorizedRecoveryStep::ExternalRefresh => {
Expand Down Expand Up @@ -1060,8 +1071,30 @@ impl AuthManager {
}

tracing::info!("Reloading auth for account {expected_account_id}");
let cached_before_reload = self.auth_cached();
let auth_changed =
!Self::auths_equal_for_refresh(cached_before_reload.as_ref(), new_auth.as_ref());
self.set_cached_auth(new_auth);
ReloadOutcome::Reloaded
if auth_changed {
ReloadOutcome::ReloadedChanged
} else {
ReloadOutcome::ReloadedNoChange
}
}

fn auths_equal_for_refresh(a: Option<&CodexAuth>, b: Option<&CodexAuth>) -> bool {
match (a, b) {
(None, None) => true,
(Some(a), Some(b)) => match (a.api_auth_mode(), b.api_auth_mode()) {
(ApiAuthMode::ApiKey, ApiAuthMode::ApiKey) => a.api_key() == b.api_key(),
(ApiAuthMode::Chatgpt, ApiAuthMode::Chatgpt)
| (ApiAuthMode::ChatgptAuthTokens, ApiAuthMode::ChatgptAuthTokens) => {
a.get_current_auth_json() == b.get_current_auth_json()
}
_ => false,
},
_ => false,
}
}

fn auths_equal(a: Option<&CodexAuth>, b: Option<&CodexAuth>) -> bool {
Expand Down Expand Up @@ -1144,10 +1177,37 @@ impl AuthManager {
UnauthorizedRecovery::new(Arc::clone(self))
}

/// Attempt to refresh the current auth token (if any). On success, reload
/// the auth state from disk so other components observe refreshed token.
/// If the token refresh fails, returns the error to the caller.
/// Attempt to refresh the token by first performing a guarded reload. Auth
/// is reloaded from storage only when the account id matches the currently
/// cached account id. If the persisted token differs from the cached token, we
/// can assume that some other instance already refreshed it. If the persisted
/// token is the same as the cached, then ask the token authority to refresh.
pub async fn refresh_token(&self) -> Result<(), RefreshTokenError> {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it a problem that we modified a method that UnauthorizedRecovery calls?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will now attempt to reload multiple times.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see you've split it.

let auth_before_reload = self.auth_cached();
let expected_account_id = auth_before_reload
.as_ref()
.and_then(CodexAuth::get_account_id);

match self.reload_if_account_id_matches(expected_account_id.as_deref()) {
ReloadOutcome::ReloadedChanged => {
tracing::info!("Skipping token refresh because auth changed after guarded reload.");
Ok(())
}
ReloadOutcome::ReloadedNoChange => self.refresh_token_from_authority().await,
ReloadOutcome::Skipped => {
Err(RefreshTokenError::Permanent(RefreshTokenFailedError::new(
RefreshTokenFailedReason::Other,
REFRESH_TOKEN_ACCOUNT_MISMATCH_MESSAGE.to_string(),
)))
Comment thread
etraut-openai marked this conversation as resolved.
}
}
}

/// Attempt to refresh the current auth token from the authority that issued
/// the token. On success, reloads the auth state from disk so other components
/// observe refreshed token. If the token refresh fails, returns the error to
/// the caller.
pub async fn refresh_token_from_authority(&self) -> Result<(), RefreshTokenError> {
tracing::info!("Refreshing token");

let auth = match self.auth_cached() {
Expand All @@ -1165,10 +1225,8 @@ impl AuthManager {
"Token data is not available.",
))
})?;
self.refresh_tokens(&chatgpt_auth, token_data.refresh_token)
self.refresh_and_persist_chatgpt_token(&chatgpt_auth, token_data.refresh_token)
.await?;
// Reload to pick up persisted changes.
self.reload();
Ok(())
}
CodexAuth::ApiKey(_) => Ok(()),
Expand Down Expand Up @@ -1215,9 +1273,8 @@ impl AuthManager {
if last_refresh >= Utc::now() - chrono::Duration::days(TOKEN_REFRESH_INTERVAL) {
return Ok(false);
}
self.refresh_tokens(chatgpt_auth, tokens.refresh_token)
self.refresh_and_persist_chatgpt_token(chatgpt_auth, tokens.refresh_token)
.await?;
self.reload();
Ok(true)
}

Expand Down Expand Up @@ -1273,20 +1330,23 @@ impl AuthManager {
Ok(())
}

async fn refresh_tokens(
// Refreshes ChatGPT OAuth tokens, persists the updated auth state, and
// reloads the in-memory cache so callers immediately observe new tokens.
async fn refresh_and_persist_chatgpt_token(
&self,
auth: &ChatgptAuth,
refresh_token: String,
) -> Result<(), RefreshTokenError> {
let refresh_response = try_refresh_token(refresh_token, auth.client()).await?;
let refresh_response = request_chatgpt_token_refresh(refresh_token, auth.client()).await?;

update_tokens(
persist_tokens(
auth.storage(),
refresh_response.id_token,
refresh_response.access_token,
refresh_response.refresh_token,
)
.map_err(RefreshTokenError::from)?;
self.reload();

Ok(())
}
Expand Down Expand Up @@ -1328,7 +1388,7 @@ mod tests {
codex_home.path().to_path_buf(),
AuthCredentialsStoreMode::File,
);
let updated = super::update_tokens(
let updated = super::persist_tokens(
&storage,
None,
Some("new-access-token".to_string()),
Expand Down
Loading
Loading