Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions codex-rs/core/tests/suite/auth_refresh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -294,13 +294,14 @@ async fn returns_fresh_tokens_as_is() -> Result<()> {
.await;

let ctx = RefreshTokenTestContext::new(&server)?;
let initial_last_refresh = Utc::now() - Duration::days(1);
let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN);
let stale_refresh = Utc::now() - Duration::days(9);
let fresh_access_token = access_token_with_expiration(Utc::now() + Duration::hours(1));
let initial_tokens = build_tokens(&fresh_access_token, INITIAL_REFRESH_TOKEN);
let initial_auth = AuthDotJson {
auth_mode: Some(AuthMode::Chatgpt),
openai_api_key: None,
tokens: Some(initial_tokens.clone()),
last_refresh: Some(initial_last_refresh),
last_refresh: Some(stale_refresh),
};
ctx.write_auth(&initial_auth)?;

Expand All @@ -325,7 +326,7 @@ async fn returns_fresh_tokens_as_is() -> Result<()> {

#[serial_test::serial(auth_refresh)]
#[tokio::test]
async fn refreshes_token_when_last_refresh_is_stale() -> Result<()> {
async fn refreshes_token_when_access_token_is_expired() -> Result<()> {
skip_if_no_network!(Ok(()));

let server = MockServer::start().await;
Expand All @@ -340,13 +341,14 @@ async fn refreshes_token_when_last_refresh_is_stale() -> Result<()> {
.await;

let ctx = RefreshTokenTestContext::new(&server)?;
let stale_refresh = Utc::now() - Duration::days(9);
let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN);
let fresh_refresh = Utc::now() - Duration::days(1);
let expired_access_token = access_token_with_expiration(Utc::now() - Duration::hours(1));
let initial_tokens = build_tokens(&expired_access_token, INITIAL_REFRESH_TOKEN);
let initial_auth = AuthDotJson {
auth_mode: Some(AuthMode::Chatgpt),
openai_api_key: None,
tokens: Some(initial_tokens.clone()),
last_refresh: Some(stale_refresh),
last_refresh: Some(fresh_refresh),
};
ctx.write_auth(&initial_auth)?;

Expand All @@ -373,7 +375,7 @@ async fn refreshes_token_when_last_refresh_is_stale() -> Result<()> {
.as_ref()
.context("last_refresh should be recorded")?;
assert!(
*refreshed_at >= stale_refresh,
*refreshed_at >= fresh_refresh,
"last_refresh should advance"
);

Expand Down Expand Up @@ -867,7 +869,7 @@ impl Drop for EnvGuard {
}
}

fn minimal_jwt() -> String {
fn jwt_with_payload(payload: serde_json::Value) -> String {
#[derive(Serialize)]
struct Header {
alg: &'static str,
Expand All @@ -878,7 +880,6 @@ fn minimal_jwt() -> String {
alg: "none",
typ: "JWT",
};
let payload = json!({ "sub": "user-123" });

fn b64(data: &[u8]) -> String {
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data)
Expand All @@ -898,6 +899,14 @@ fn minimal_jwt() -> String {
format!("{header_b64}.{payload_b64}.{signature_b64}")
}

fn minimal_jwt() -> String {
jwt_with_payload(json!({ "sub": "user-123" }))
}

fn access_token_with_expiration(expires_at: chrono::DateTime<Utc>) -> String {
jwt_with_payload(json!({ "sub": "user-123", "exp": expires_at.timestamp() }))
}

fn build_tokens(access_token: &str, refresh_token: &str) -> TokenData {
let id_token = IdTokenInfo {
raw_jwt: minimal_jwt(),
Expand Down
7 changes: 6 additions & 1 deletion codex-rs/login/src/auth/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ use crate::token_data::KnownPlan as InternalKnownPlan;
use crate::token_data::PlanType as InternalPlanType;
use crate::token_data::TokenData;
use crate::token_data::parse_chatgpt_jwt_claims;
use crate::token_data::parse_jwt_expiration;
use codex_client::CodexHttpClient;
use codex_protocol::account::PlanType as AccountPlanType;
use serde_json::Value;
Expand Down Expand Up @@ -69,7 +70,6 @@ impl PartialEq for CodexAuth {
}
}

// TODO(pakrym): use token exp field to check for expiration instead
const TOKEN_REFRESH_INTERVAL: i64 = 8;

const REFRESH_TOKEN_EXPIRED_MESSAGE: &str = "Your access token could not be refreshed because your refresh token has expired. Please log out and sign in again.";
Expand Down Expand Up @@ -1333,6 +1333,11 @@ impl AuthManager {
Some(auth_dot_json) => auth_dot_json,
None => return false,
};
if let Some(tokens) = auth_dot_json.tokens.as_ref()
&& let Ok(Some(expires_at)) = parse_jwt_expiration(&tokens.access_token)
{
return expires_at <= Utc::now();
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we want a buffer? TOKEN_REFRESH_INTERVAL?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we need a buffer? expired_at is a timestamp? we still have the fallback code path that uses Token_refresh_interval after this in case exp field doesn't exist in token?

Copy link
Copy Markdown
Collaborator

@pakrym-oai pakrym-oai Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we still need that codepath, can "exp" not be there?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

technically, jwt token doesn't necessarily have the 'exp' field. In reality, the auth token returned should always have this field. We can also throw an error here if exp doesn't exist, but I think having this silent fallback is safer

}
let last_refresh = match auth_dot_json.last_refresh {
Some(last_refresh) => last_refresh,
None => return false,
Expand Down
25 changes: 23 additions & 2 deletions codex-rs/login/src/token_data.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use base64::Engine;
use chrono::DateTime;
use chrono::Utc;
use serde::Deserialize;
use serde::Serialize;
use serde::de::DeserializeOwned;
use thiserror::Error;

#[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Default)]
Expand Down Expand Up @@ -117,6 +120,12 @@ struct AuthClaims {
chatgpt_account_id: Option<String>,
}

#[derive(Deserialize)]
struct StandardJwtClaims {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why a separate struct? can we put this onto AuthClaims?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

or IdClaims ?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this function is supposed to be generic so it can parse any jwt, not just id / access tokens I think

#[serde(default)]
exp: Option<i64>,
}

#[derive(Debug, Error)]
pub enum IdTokenInfoError {
#[error("invalid ID token format")]
Expand All @@ -127,7 +136,7 @@ pub enum IdTokenInfoError {
Json(#[from] serde_json::Error),
}

pub fn parse_chatgpt_jwt_claims(jwt: &str) -> Result<IdTokenInfo, IdTokenInfoError> {
fn decode_jwt_payload<T: DeserializeOwned>(jwt: &str) -> Result<T, IdTokenInfoError> {
// JWT format: header.payload.signature
let mut parts = jwt.split('.');
let (_header_b64, payload_b64, _sig_b64) = match (parts.next(), parts.next(), parts.next()) {
Expand All @@ -136,7 +145,19 @@ pub fn parse_chatgpt_jwt_claims(jwt: &str) -> Result<IdTokenInfo, IdTokenInfoErr
};

let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload_b64)?;
let claims: IdClaims = serde_json::from_slice(&payload_bytes)?;
let claims = serde_json::from_slice(&payload_bytes)?;
Ok(claims)
}

pub fn parse_jwt_expiration(jwt: &str) -> Result<Option<DateTime<Utc>>, IdTokenInfoError> {
let claims: StandardJwtClaims = decode_jwt_payload(jwt)?;
Ok(claims
.exp
.and_then(|exp| DateTime::<Utc>::from_timestamp(exp, 0)))
}

pub fn parse_chatgpt_jwt_claims(jwt: &str) -> Result<IdTokenInfo, IdTokenInfoError> {
let claims: IdClaims = decode_jwt_payload(jwt)?;
let email = claims
.email
.or_else(|| claims.profile.and_then(|profile| profile.email));
Expand Down
89 changes: 41 additions & 48 deletions codex-rs/login/src/token_data_tests.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use super::*;
use chrono::TimeZone;
use chrono::Utc;
use pretty_assertions::assert_eq;
use serde::Serialize;

#[test]
fn id_token_info_parses_email_and_plan() {
fn fake_jwt(payload: serde_json::Value) -> String {
#[derive(Serialize)]
struct Header {
alg: &'static str,
Expand All @@ -13,12 +14,6 @@ fn id_token_info_parses_email_and_plan() {
alg: "none",
typ: "JWT",
};
let payload = serde_json::json!({
"email": "user@example.com",
"https://api.openai.com/auth": {
"chatgpt_plan_type": "pro"
}
});

fn b64url_no_pad(bytes: &[u8]) -> String {
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
Expand All @@ -27,7 +22,17 @@ fn id_token_info_parses_email_and_plan() {
let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
let signature_b64 = b64url_no_pad(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
format!("{header_b64}.{payload_b64}.{signature_b64}")
}

#[test]
fn id_token_info_parses_email_and_plan() {
let fake_jwt = fake_jwt(serde_json::json!({
"email": "user@example.com",
"https://api.openai.com/auth": {
"chatgpt_plan_type": "pro"
}
}));

let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse");
assert_eq!(info.email.as_deref(), Some("user@example.com"));
Expand All @@ -36,30 +41,12 @@ fn id_token_info_parses_email_and_plan() {

#[test]
fn id_token_info_parses_go_plan() {
#[derive(Serialize)]
struct Header {
alg: &'static str,
typ: &'static str,
}
let header = Header {
alg: "none",
typ: "JWT",
};
let payload = serde_json::json!({
let fake_jwt = fake_jwt(serde_json::json!({
"email": "user@example.com",
"https://api.openai.com/auth": {
"chatgpt_plan_type": "go"
}
});

fn b64url_no_pad(bytes: &[u8]) -> String {
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}

let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
let signature_b64 = b64url_no_pad(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
}));

let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse");
assert_eq!(info.email.as_deref(), Some("user@example.com"));
Expand All @@ -68,31 +55,37 @@ fn id_token_info_parses_go_plan() {

#[test]
fn id_token_info_handles_missing_fields() {
#[derive(Serialize)]
struct Header {
alg: &'static str,
typ: &'static str,
}
let header = Header {
alg: "none",
typ: "JWT",
};
let payload = serde_json::json!({ "sub": "123" });

fn b64url_no_pad(bytes: &[u8]) -> String {
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}

let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap());
let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap());
let signature_b64 = b64url_no_pad(b"sig");
let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}");
let fake_jwt = fake_jwt(serde_json::json!({ "sub": "123" }));

let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse");
assert!(info.email.is_none());
assert!(info.get_chatgpt_plan_type().is_none());
}

#[test]
fn jwt_expiration_parses_exp_claim() {
let fake_jwt = fake_jwt(serde_json::json!({
"exp": 1_700_000_000_i64,
}));

let expires_at = parse_jwt_expiration(&fake_jwt).expect("should parse");
assert_eq!(expires_at, Utc.timestamp_opt(1_700_000_000, 0).single());
}

#[test]
fn jwt_expiration_handles_missing_exp() {
let fake_jwt = fake_jwt(serde_json::json!({ "sub": "123" }));

let expires_at = parse_jwt_expiration(&fake_jwt).expect("should parse");
assert_eq!(expires_at, None);
}

#[test]
fn jwt_expiration_rejects_malformed_jwt() {
let err = parse_jwt_expiration("not-a-jwt").expect_err("should fail");
assert_eq!(err.to_string(), "invalid ID token format");
}

#[test]
fn workspace_account_detection_matches_workspace_plans() {
let workspace = IdTokenInfo {
Expand Down
Loading