From 8e5aeddda4c73fbb009a77a7b659be21da341185 Mon Sep 17 00:00:00 2001 From: celia-oai Date: Mon, 23 Mar 2026 12:35:02 -0700 Subject: [PATCH] changes --- codex-rs/core/tests/suite/auth_refresh.rs | 29 +++++--- codex-rs/login/src/auth/manager.rs | 7 +- codex-rs/login/src/token_data.rs | 25 ++++++- codex-rs/login/src/token_data_tests.rs | 89 +++++++++++------------ 4 files changed, 89 insertions(+), 61 deletions(-) diff --git a/codex-rs/core/tests/suite/auth_refresh.rs b/codex-rs/core/tests/suite/auth_refresh.rs index 278124c63a4..680295b026c 100644 --- a/codex-rs/core/tests/suite/auth_refresh.rs +++ b/codex-rs/core/tests/suite/auth_refresh.rs @@ -294,13 +294,14 @@ async fn returns_fresh_tokens_as_is() -> Result<()> { .await; let ctx = RefreshTokenTestContext::new(&server)?; - let initial_last_refresh = Utc::now() - Duration::days(1); - let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN); + let stale_refresh = Utc::now() - Duration::days(9); + let fresh_access_token = access_token_with_expiration(Utc::now() + Duration::hours(1)); + let initial_tokens = build_tokens(&fresh_access_token, INITIAL_REFRESH_TOKEN); let initial_auth = AuthDotJson { auth_mode: Some(AuthMode::Chatgpt), openai_api_key: None, tokens: Some(initial_tokens.clone()), - last_refresh: Some(initial_last_refresh), + last_refresh: Some(stale_refresh), }; ctx.write_auth(&initial_auth)?; @@ -325,7 +326,7 @@ async fn returns_fresh_tokens_as_is() -> Result<()> { #[serial_test::serial(auth_refresh)] #[tokio::test] -async fn refreshes_token_when_last_refresh_is_stale() -> Result<()> { +async fn refreshes_token_when_access_token_is_expired() -> Result<()> { skip_if_no_network!(Ok(())); let server = MockServer::start().await; @@ -340,13 +341,14 @@ async fn refreshes_token_when_last_refresh_is_stale() -> Result<()> { .await; let ctx = RefreshTokenTestContext::new(&server)?; - let stale_refresh = Utc::now() - Duration::days(9); - let initial_tokens = build_tokens(INITIAL_ACCESS_TOKEN, INITIAL_REFRESH_TOKEN); + let fresh_refresh = Utc::now() - Duration::days(1); + let expired_access_token = access_token_with_expiration(Utc::now() - Duration::hours(1)); + let initial_tokens = build_tokens(&expired_access_token, INITIAL_REFRESH_TOKEN); let initial_auth = AuthDotJson { auth_mode: Some(AuthMode::Chatgpt), openai_api_key: None, tokens: Some(initial_tokens.clone()), - last_refresh: Some(stale_refresh), + last_refresh: Some(fresh_refresh), }; ctx.write_auth(&initial_auth)?; @@ -373,7 +375,7 @@ async fn refreshes_token_when_last_refresh_is_stale() -> Result<()> { .as_ref() .context("last_refresh should be recorded")?; assert!( - *refreshed_at >= stale_refresh, + *refreshed_at >= fresh_refresh, "last_refresh should advance" ); @@ -867,7 +869,7 @@ impl Drop for EnvGuard { } } -fn minimal_jwt() -> String { +fn jwt_with_payload(payload: serde_json::Value) -> String { #[derive(Serialize)] struct Header { alg: &'static str, @@ -878,7 +880,6 @@ fn minimal_jwt() -> String { alg: "none", typ: "JWT", }; - let payload = json!({ "sub": "user-123" }); fn b64(data: &[u8]) -> String { base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(data) @@ -898,6 +899,14 @@ fn minimal_jwt() -> String { format!("{header_b64}.{payload_b64}.{signature_b64}") } +fn minimal_jwt() -> String { + jwt_with_payload(json!({ "sub": "user-123" })) +} + +fn access_token_with_expiration(expires_at: chrono::DateTime) -> String { + jwt_with_payload(json!({ "sub": "user-123", "exp": expires_at.timestamp() })) +} + fn build_tokens(access_token: &str, refresh_token: &str) -> TokenData { let id_token = IdTokenInfo { raw_jwt: minimal_jwt(), diff --git a/codex-rs/login/src/auth/manager.rs b/codex-rs/login/src/auth/manager.rs index 31860cd5857..d9d520b2ffa 100644 --- a/codex-rs/login/src/auth/manager.rs +++ b/codex-rs/login/src/auth/manager.rs @@ -28,6 +28,7 @@ use crate::token_data::KnownPlan as InternalKnownPlan; use crate::token_data::PlanType as InternalPlanType; use crate::token_data::TokenData; use crate::token_data::parse_chatgpt_jwt_claims; +use crate::token_data::parse_jwt_expiration; use codex_client::CodexHttpClient; use codex_protocol::account::PlanType as AccountPlanType; use serde_json::Value; @@ -69,7 +70,6 @@ impl PartialEq for CodexAuth { } } -// TODO(pakrym): use token exp field to check for expiration instead const TOKEN_REFRESH_INTERVAL: i64 = 8; const REFRESH_TOKEN_EXPIRED_MESSAGE: &str = "Your access token could not be refreshed because your refresh token has expired. Please log out and sign in again."; @@ -1333,6 +1333,11 @@ impl AuthManager { Some(auth_dot_json) => auth_dot_json, None => return false, }; + if let Some(tokens) = auth_dot_json.tokens.as_ref() + && let Ok(Some(expires_at)) = parse_jwt_expiration(&tokens.access_token) + { + return expires_at <= Utc::now(); + } let last_refresh = match auth_dot_json.last_refresh { Some(last_refresh) => last_refresh, None => return false, diff --git a/codex-rs/login/src/token_data.rs b/codex-rs/login/src/token_data.rs index 304bf765f41..1056115a37f 100644 --- a/codex-rs/login/src/token_data.rs +++ b/codex-rs/login/src/token_data.rs @@ -1,6 +1,9 @@ use base64::Engine; +use chrono::DateTime; +use chrono::Utc; use serde::Deserialize; use serde::Serialize; +use serde::de::DeserializeOwned; use thiserror::Error; #[derive(Deserialize, Serialize, Clone, Debug, PartialEq, Default)] @@ -117,6 +120,12 @@ struct AuthClaims { chatgpt_account_id: Option, } +#[derive(Deserialize)] +struct StandardJwtClaims { + #[serde(default)] + exp: Option, +} + #[derive(Debug, Error)] pub enum IdTokenInfoError { #[error("invalid ID token format")] @@ -127,7 +136,7 @@ pub enum IdTokenInfoError { Json(#[from] serde_json::Error), } -pub fn parse_chatgpt_jwt_claims(jwt: &str) -> Result { +fn decode_jwt_payload(jwt: &str) -> Result { // JWT format: header.payload.signature let mut parts = jwt.split('.'); let (_header_b64, payload_b64, _sig_b64) = match (parts.next(), parts.next(), parts.next()) { @@ -136,7 +145,19 @@ pub fn parse_chatgpt_jwt_claims(jwt: &str) -> Result Result>, IdTokenInfoError> { + let claims: StandardJwtClaims = decode_jwt_payload(jwt)?; + Ok(claims + .exp + .and_then(|exp| DateTime::::from_timestamp(exp, 0))) +} + +pub fn parse_chatgpt_jwt_claims(jwt: &str) -> Result { + let claims: IdClaims = decode_jwt_payload(jwt)?; let email = claims .email .or_else(|| claims.profile.and_then(|profile| profile.email)); diff --git a/codex-rs/login/src/token_data_tests.rs b/codex-rs/login/src/token_data_tests.rs index e599379c18f..7f77ee5eaf1 100644 --- a/codex-rs/login/src/token_data_tests.rs +++ b/codex-rs/login/src/token_data_tests.rs @@ -1,9 +1,10 @@ use super::*; +use chrono::TimeZone; +use chrono::Utc; use pretty_assertions::assert_eq; use serde::Serialize; -#[test] -fn id_token_info_parses_email_and_plan() { +fn fake_jwt(payload: serde_json::Value) -> String { #[derive(Serialize)] struct Header { alg: &'static str, @@ -13,12 +14,6 @@ fn id_token_info_parses_email_and_plan() { alg: "none", typ: "JWT", }; - let payload = serde_json::json!({ - "email": "user@example.com", - "https://api.openai.com/auth": { - "chatgpt_plan_type": "pro" - } - }); fn b64url_no_pad(bytes: &[u8]) -> String { base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) @@ -27,7 +22,17 @@ fn id_token_info_parses_email_and_plan() { let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap()); let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap()); let signature_b64 = b64url_no_pad(b"sig"); - let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}"); + format!("{header_b64}.{payload_b64}.{signature_b64}") +} + +#[test] +fn id_token_info_parses_email_and_plan() { + let fake_jwt = fake_jwt(serde_json::json!({ + "email": "user@example.com", + "https://api.openai.com/auth": { + "chatgpt_plan_type": "pro" + } + })); let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse"); assert_eq!(info.email.as_deref(), Some("user@example.com")); @@ -36,30 +41,12 @@ fn id_token_info_parses_email_and_plan() { #[test] fn id_token_info_parses_go_plan() { - #[derive(Serialize)] - struct Header { - alg: &'static str, - typ: &'static str, - } - let header = Header { - alg: "none", - typ: "JWT", - }; - let payload = serde_json::json!({ + let fake_jwt = fake_jwt(serde_json::json!({ "email": "user@example.com", "https://api.openai.com/auth": { "chatgpt_plan_type": "go" } - }); - - fn b64url_no_pad(bytes: &[u8]) -> String { - base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) - } - - let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap()); - let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap()); - let signature_b64 = b64url_no_pad(b"sig"); - let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}"); + })); let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse"); assert_eq!(info.email.as_deref(), Some("user@example.com")); @@ -68,31 +55,37 @@ fn id_token_info_parses_go_plan() { #[test] fn id_token_info_handles_missing_fields() { - #[derive(Serialize)] - struct Header { - alg: &'static str, - typ: &'static str, - } - let header = Header { - alg: "none", - typ: "JWT", - }; - let payload = serde_json::json!({ "sub": "123" }); - - fn b64url_no_pad(bytes: &[u8]) -> String { - base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) - } - - let header_b64 = b64url_no_pad(&serde_json::to_vec(&header).unwrap()); - let payload_b64 = b64url_no_pad(&serde_json::to_vec(&payload).unwrap()); - let signature_b64 = b64url_no_pad(b"sig"); - let fake_jwt = format!("{header_b64}.{payload_b64}.{signature_b64}"); + let fake_jwt = fake_jwt(serde_json::json!({ "sub": "123" })); let info = parse_chatgpt_jwt_claims(&fake_jwt).expect("should parse"); assert!(info.email.is_none()); assert!(info.get_chatgpt_plan_type().is_none()); } +#[test] +fn jwt_expiration_parses_exp_claim() { + let fake_jwt = fake_jwt(serde_json::json!({ + "exp": 1_700_000_000_i64, + })); + + let expires_at = parse_jwt_expiration(&fake_jwt).expect("should parse"); + assert_eq!(expires_at, Utc.timestamp_opt(1_700_000_000, 0).single()); +} + +#[test] +fn jwt_expiration_handles_missing_exp() { + let fake_jwt = fake_jwt(serde_json::json!({ "sub": "123" })); + + let expires_at = parse_jwt_expiration(&fake_jwt).expect("should parse"); + assert_eq!(expires_at, None); +} + +#[test] +fn jwt_expiration_rejects_malformed_jwt() { + let err = parse_jwt_expiration("not-a-jwt").expect_err("should fail"); + assert_eq!(err.to_string(), "invalid ID token format"); +} + #[test] fn workspace_account_detection_matches_workspace_plans() { let workspace = IdTokenInfo {