diff --git a/.changeset/refresh-long-running-helper-tokens.md b/.changeset/refresh-long-running-helper-tokens.md new file mode 100644 index 00000000..aab7eba7 --- /dev/null +++ b/.changeset/refresh-long-running-helper-tokens.md @@ -0,0 +1,5 @@ +--- +"@googleworkspace/cli": patch +--- + +Refresh OAuth access tokens for long-running Gmail watch and Workspace Events subscribe helpers before each Pub/Sub and Gmail request. diff --git a/src/auth.rs b/src/auth.rs index 2b6716f1..3825b7ad 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -77,6 +77,73 @@ enum Credential { ServiceAccount(yup_oauth2::ServiceAccountKey), } +/// Fetches access tokens for a fixed set of scopes. +/// +/// Long-running helpers use this trait so they can request a fresh token before +/// each API call instead of holding a single token string until it expires. +#[async_trait::async_trait] +pub trait AccessTokenProvider: Send + Sync { + async fn access_token(&self) -> anyhow::Result; +} + +/// A token provider backed by [`get_token`]. +/// +/// This keeps the scope list in one place so call sites can ask for a fresh +/// token whenever they need to make another request. +#[derive(Debug, Clone)] +pub struct ScopedTokenProvider { + scopes: Vec, +} + +impl ScopedTokenProvider { + pub fn new(scopes: &[&str]) -> Self { + Self { + scopes: scopes.iter().map(|scope| (*scope).to_string()).collect(), + } + } +} + +#[async_trait::async_trait] +impl AccessTokenProvider for ScopedTokenProvider { + async fn access_token(&self) -> anyhow::Result { + let scopes: Vec<&str> = self.scopes.iter().map(String::as_str).collect(); + get_token(&scopes).await + } +} + +pub fn token_provider(scopes: &[&str]) -> ScopedTokenProvider { + ScopedTokenProvider::new(scopes) +} + +/// A fake [`AccessTokenProvider`] for tests that returns tokens from a queue. +#[cfg(test)] +pub struct FakeTokenProvider { + tokens: std::sync::Arc>>, +} + +#[cfg(test)] +impl FakeTokenProvider { + pub fn new(tokens: impl IntoIterator) -> Self { + Self { + tokens: std::sync::Arc::new(tokio::sync::Mutex::new( + tokens.into_iter().map(|t| t.to_string()).collect(), + )), + } + } +} + +#[cfg(test)] +#[async_trait::async_trait] +impl AccessTokenProvider for FakeTokenProvider { + async fn access_token(&self) -> anyhow::Result { + self.tokens + .lock() + .await + .pop_front() + .ok_or_else(|| anyhow::anyhow!("no test token remaining")) + } +} + /// Builds an OAuth2 authenticator and returns an access token. /// /// Tries credentials in order: @@ -544,6 +611,19 @@ mod tests { assert_eq!(result.unwrap(), "my-test-token"); } + #[tokio::test] + #[serial_test::serial] + async fn test_scoped_token_provider_uses_get_token() { + let _token_guard = EnvVarGuard::set("GOOGLE_WORKSPACE_CLI_TOKEN", "provider-token"); + let provider = token_provider(&["https://www.googleapis.com/auth/drive"]); + + let first = provider.access_token().await.unwrap(); + let second = provider.access_token().await.unwrap(); + + assert_eq!(first, "provider-token"); + assert_eq!(second, "provider-token"); + } + #[tokio::test] async fn test_load_credentials_encrypted_file() { // Simulate an encrypted credentials file diff --git a/src/helpers/events/subscribe.rs b/src/helpers/events/subscribe.rs index dbe6047c..c1df3667 100644 --- a/src/helpers/events/subscribe.rs +++ b/src/helpers/events/subscribe.rs @@ -1,6 +1,9 @@ use super::*; +use crate::auth::AccessTokenProvider; use std::path::PathBuf; +const PUBSUB_API_BASE: &str = "https://pubsub.googleapis.com/v1"; + #[derive(Debug, Clone, Default, Builder)] #[builder(setter(into))] pub struct SubscribeConfig { @@ -110,6 +113,7 @@ pub(super) async fn handle_subscribe( } let client = crate::client::build_client()?; + let pubsub_token_provider = auth::token_provider(&[PUBSUB_SCOPE]); // Get Pub/Sub token let pubsub_token = auth::get_token(&[PUBSUB_SCOPE]) @@ -248,29 +252,38 @@ pub(super) async fn handle_subscribe( }; // Pull loop - let result = pull_loop(&client, &pubsub_token, &pubsub_subscription, config.clone()).await; + let result = pull_loop( + &client, + &pubsub_token_provider, + &pubsub_subscription, + config.clone(), + PUBSUB_API_BASE, + ) + .await; // On exit, print reconnection info or cleanup if created_resources { if config.cleanup { eprintln!("\nCleaning up Pub/Sub resources..."); // Delete Pub/Sub subscription - let _ = client - .delete(format!( - "https://pubsub.googleapis.com/v1/{pubsub_subscription}" - )) - .bearer_auth(&pubsub_token) - .send() - .await; - // Delete Pub/Sub topic - if let Some(ref topic) = topic_name { + if let Ok(pubsub_token) = pubsub_token_provider.access_token().await { let _ = client - .delete(format!("https://pubsub.googleapis.com/v1/{topic}")) + .delete(format!("{PUBSUB_API_BASE}/{pubsub_subscription}")) .bearer_auth(&pubsub_token) .send() .await; + // Delete Pub/Sub topic + if let Some(ref topic) = topic_name { + let _ = client + .delete(format!("{PUBSUB_API_BASE}/{topic}")) + .bearer_auth(&pubsub_token) + .send() + .await; + } + eprintln!("Cleanup complete."); + } else { + eprintln!("Warning: failed to refresh token for cleanup. Resources may need manual deletion."); } - eprintln!("Cleanup complete."); } else { eprintln!("\n--- Reconnection Info ---"); eprintln!( @@ -301,21 +314,24 @@ pub(super) async fn handle_subscribe( /// Pulls messages from a Pub/Sub subscription in a loop. async fn pull_loop( client: &reqwest::Client, - token: &str, + token_provider: &dyn auth::AccessTokenProvider, subscription: &str, config: SubscribeConfig, + pubsub_api_base: &str, ) -> Result<(), GwsError> { let mut file_counter: u64 = 0; loop { + let token = token_provider + .access_token() + .await + .map_err(|e| GwsError::Auth(format!("Failed to get Pub/Sub token: {e}")))?; let pull_body = json!({ "maxMessages": config.max_messages, }); let pull_future = client - .post(format!( - "https://pubsub.googleapis.com/v1/{subscription}:pull" - )) - .bearer_auth(token) + .post(format!("{pubsub_api_base}/{subscription}:pull")) + .bearer_auth(&token) .header("Content-Type", "application/json") .json(&pull_body) .timeout(std::time::Duration::from_secs(config.poll_interval.max(10))) @@ -379,10 +395,8 @@ async fn pull_loop( }); let _ = client - .post(format!( - "https://pubsub.googleapis.com/v1/{subscription}:acknowledge" - )) - .bearer_auth(token) + .post(format!("{pubsub_api_base}/{subscription}:acknowledge")) + .bearer_auth(&token) .header("Content-Type", "application/json") .json(&ack_body) .send() @@ -526,6 +540,76 @@ fn derive_slug_from_event_types(event_types: &[&str]) -> String { #[cfg(test)] mod tests { use super::*; + use crate::auth::FakeTokenProvider; + use base64::Engine as _; + use std::sync::Arc; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + use tokio::sync::Mutex; + + async fn spawn_subscribe_server() -> ( + String, + Arc>>, + tokio::task::JoinHandle<()>, + ) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let requests = Arc::new(Mutex::new(Vec::new())); + let recorded_requests = Arc::clone(&requests); + + let handle = tokio::spawn(async move { + for _ in 0..2 { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut buf = [0_u8; 8192]; + let bytes_read = stream.read(&mut buf).await.unwrap(); + let request = String::from_utf8_lossy(&buf[..bytes_read]); + let path = request + .lines() + .next() + .and_then(|line| line.split_whitespace().nth(1)) + .unwrap_or("") + .to_string(); + let auth_header = request + .lines() + .find(|line| line.to_ascii_lowercase().starts_with("authorization:")) + .unwrap_or("") + .trim() + .to_string(); + recorded_requests + .lock() + .await + .push((path.clone(), auth_header)); + + let body = match path.as_str() { + "/v1/projects/test/subscriptions/demo:pull" => json!({ + "receivedMessages": [{ + "ackId": "ack-1", + "message": { + "attributes": { + "type": "google.workspace.chat.message.v1.created", + "source": "//chat/spaces/A" + }, + "data": base64::engine::general_purpose::STANDARD + .encode(json!({ "id": "evt-1" }).to_string()) + } + }] + }) + .to_string(), + "/v1/projects/test/subscriptions/demo:acknowledge" => json!({}).to_string(), + other => panic!("unexpected request path: {other}"), + }; + + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nConnection: close\r\nContent-Length: {}\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes()).await.unwrap(); + } + }); + + (format!("http://{addr}/v1"), requests, handle) + } fn make_matches_subscribe(args: &[&str]) -> ArgMatches { let cmd = Command::new("test") @@ -753,4 +837,42 @@ mod tests { let err_msg = result.unwrap_err().to_string(); assert!(err_msg.contains("--project is required")); } + + #[tokio::test] + async fn test_pull_loop_refreshes_pubsub_token_between_requests() { + let client = reqwest::Client::new(); + let token_provider = FakeTokenProvider::new(["pubsub-token"]); + let (pubsub_base, requests, server) = spawn_subscribe_server().await; + let config = SubscribeConfigBuilder::default() + .subscription(Some(SubscriptionName( + "projects/test/subscriptions/demo".to_string(), + ))) + .max_messages(1_u32) + .poll_interval(1_u64) + .once(true) + .build() + .unwrap(); + + pull_loop( + &client, + &token_provider, + "projects/test/subscriptions/demo", + config, + &pubsub_base, + ) + .await + .unwrap(); + + server.await.unwrap(); + + let requests = requests.lock().await; + assert_eq!(requests.len(), 2); + assert_eq!(requests[0].0, "/v1/projects/test/subscriptions/demo:pull"); + assert_eq!(requests[0].1, "authorization: Bearer pubsub-token"); + assert_eq!( + requests[1].0, + "/v1/projects/test/subscriptions/demo:acknowledge" + ); + assert_eq!(requests[1].1, "authorization: Bearer pubsub-token"); + } } diff --git a/src/helpers/gmail/watch.rs b/src/helpers/gmail/watch.rs index 86401799..4e811952 100644 --- a/src/helpers/gmail/watch.rs +++ b/src/helpers/gmail/watch.rs @@ -1,4 +1,8 @@ use super::*; +use crate::auth::AccessTokenProvider; + +const PUBSUB_API_BASE: &str = "https://pubsub.googleapis.com/v1"; +const GMAIL_API_BASE: &str = "https://gmail.googleapis.com/gmail/v1"; /// Handles the `+watch` command — Gmail push notifications via Pub/Sub. pub(super) async fn handle_watch( @@ -12,6 +16,8 @@ pub(super) async fn handle_watch( } let client = crate::client::build_client()?; + let gmail_token_provider = auth::token_provider(&[GMAIL_SCOPE]); + let pubsub_token_provider = auth::token_provider(&[PUBSUB_SCOPE]); // Get tokens let gmail_token = auth::get_token(&[GMAIL_SCOPE]) @@ -195,14 +201,19 @@ pub(super) async fn handle_watch( .unwrap_or(0); // Pull loop + let runtime = WatchRuntime { + client: &client, + pubsub_token_provider: &pubsub_token_provider, + gmail_token_provider: &gmail_token_provider, + sanitize_config, + pubsub_api_base: PUBSUB_API_BASE, + gmail_api_base: GMAIL_API_BASE, + }; let result = watch_pull_loop( - &client, - &pubsub_token, - &gmail_token, + &runtime, &pubsub_subscription, &mut last_history_id, config.clone(), - sanitize_config, ) .await; @@ -210,22 +221,23 @@ pub(super) async fn handle_watch( if created_resources { if config.cleanup { eprintln!("\nCleaning up Pub/Sub resources..."); - let _ = client - .delete(format!( - "https://pubsub.googleapis.com/v1/{}", - pubsub_subscription - )) - .bearer_auth(&pubsub_token) - .send() - .await; - if let Some(ref topic) = topic_name { + if let Ok(pubsub_token) = pubsub_token_provider.access_token().await { let _ = client - .delete(format!("https://pubsub.googleapis.com/v1/{}", topic)) + .delete(format!("{PUBSUB_API_BASE}/{}", pubsub_subscription)) .bearer_auth(&pubsub_token) .send() .await; + if let Some(ref topic) = topic_name { + let _ = client + .delete(format!("{PUBSUB_API_BASE}/{}", topic)) + .bearer_auth(&pubsub_token) + .send() + .await; + } + eprintln!("Cleanup complete."); + } else { + eprintln!("Warning: failed to refresh token for cleanup. Resources may need manual deletion."); } - eprintln!("Cleanup complete."); } else { eprintln!("\n--- Reconnection Info ---"); eprintln!( @@ -245,21 +257,22 @@ pub(super) async fn handle_watch( /// Pull loop for Gmail watch — polls Pub/Sub, fetches messages via history API. async fn watch_pull_loop( - client: &reqwest::Client, - pubsub_token: &str, - gmail_token: &str, + runtime: &WatchRuntime<'_>, subscription: &str, last_history_id: &mut u64, config: WatchConfig, - sanitize_config: &crate::helpers::modelarmor::SanitizeConfig, ) -> Result<(), GwsError> { loop { + let pubsub_token = runtime + .pubsub_token_provider + .access_token() + .await + .context("Failed to get Pub/Sub token")?; let pull_body = json!({ "maxMessages": config.max_messages }); - let pull_future = client - .post(format!( - "https://pubsub.googleapis.com/v1/{subscription}:pull" - )) - .bearer_auth(pubsub_token) + let pull_future = runtime + .client + .post(format!("{}/{subscription}:pull", runtime.pubsub_api_base)) + .bearer_auth(&pubsub_token) .header("Content-Type", "application/json") .json(&pull_body) .timeout(std::time::Duration::from_secs(config.poll_interval.max(10))) @@ -296,12 +309,13 @@ async fn watch_pull_loop( if max_history_id > *last_history_id && *last_history_id > 0 { // Fetch new messages via history API fetch_and_output_messages( - client, - gmail_token, + runtime.client, + runtime.gmail_token_provider, *last_history_id, &config.format, config.output_dir.as_ref(), - sanitize_config, + runtime.sanitize_config, + runtime.gmail_api_base, ) .await?; } @@ -313,11 +327,13 @@ async fn watch_pull_loop( // Acknowledge messages if !ack_ids.is_empty() { let ack_body = json!({ "ackIds": ack_ids }); - let _ = client + let _ = runtime + .client .post(format!( - "https://pubsub.googleapis.com/v1/{subscription}:acknowledge" + "{}/{subscription}:acknowledge", + runtime.pubsub_api_base )) - .bearer_auth(pubsub_token) + .bearer_auth(&pubsub_token) .header("Content-Type", "application/json") .json(&ack_body) .send() @@ -379,19 +395,24 @@ fn process_pull_response(response: &Value) -> (Vec, u64) { /// Fetches new messages since `start_history_id` and outputs them as NDJSON. async fn fetch_and_output_messages( client: &reqwest::Client, - gmail_token: &str, + gmail_token_provider: &dyn auth::AccessTokenProvider, start_history_id: u64, msg_format: &str, output_dir: Option<&std::path::PathBuf>, sanitize_config: &crate::helpers::modelarmor::SanitizeConfig, + gmail_api_base: &str, ) -> Result<(), GwsError> { + let gmail_token = gmail_token_provider + .access_token() + .await + .context("Failed to get Gmail token")?; let resp = client - .get("https://gmail.googleapis.com/gmail/v1/users/me/history") + .get(format!("{gmail_api_base}/users/me/history")) .query(&[ ("startHistoryId", &start_history_id.to_string()), ("historyTypes", &"messageAdded".to_string()), ]) - .bearer_auth(gmail_token) + .bearer_auth(&gmail_token) .send() .await .context("Failed to get history")?; @@ -401,15 +422,14 @@ async fn fetch_and_output_messages( let msg_ids = extract_message_ids_from_history(&body); for msg_id in msg_ids { - // Fetch full message let msg_url = format!( - "https://gmail.googleapis.com/gmail/v1/users/me/messages/{}", + "{gmail_api_base}/users/me/messages/{}", crate::validate::encode_path_segment(&msg_id), ); let msg_resp = client .get(&msg_url) .query(&[("format", msg_format)]) - .bearer_auth(gmail_token) + .bearer_auth(&gmail_token) .send() .await; @@ -527,6 +547,15 @@ struct WatchConfig { output_dir: Option, } +struct WatchRuntime<'a> { + client: &'a reqwest::Client, + pubsub_token_provider: &'a dyn auth::AccessTokenProvider, + gmail_token_provider: &'a dyn auth::AccessTokenProvider, + sanitize_config: &'a crate::helpers::modelarmor::SanitizeConfig, + pubsub_api_base: &'a str, + gmail_api_base: &'a str, +} + fn parse_watch_args(matches: &ArgMatches) -> Result { let format_str = matches .get_one::("msg-format") @@ -562,6 +591,94 @@ fn parse_watch_args(matches: &ArgMatches) -> Result { #[cfg(test)] mod tests { use super::*; + use crate::auth::FakeTokenProvider; + use std::sync::Arc; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio::net::TcpListener; + use tokio::sync::Mutex; + + async fn spawn_watch_server() -> ( + String, + String, + Arc>>, + tokio::task::JoinHandle<()>, + ) { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let requests = Arc::new(Mutex::new(Vec::new())); + let recorded_requests = Arc::clone(&requests); + + let handle = tokio::spawn(async move { + for _ in 0..4 { + let (mut stream, _) = listener.accept().await.unwrap(); + let mut buf = [0_u8; 8192]; + let bytes_read = stream.read(&mut buf).await.unwrap(); + let request = String::from_utf8_lossy(&buf[..bytes_read]); + let path = request + .lines() + .next() + .and_then(|line| line.split_whitespace().nth(1)) + .unwrap_or("") + .to_string(); + let auth_header = request + .lines() + .find(|line| line.to_ascii_lowercase().starts_with("authorization:")) + .unwrap_or("") + .trim() + .to_string(); + recorded_requests + .lock() + .await + .push((path.clone(), auth_header)); + + let body = match path.as_str() { + "/v1/projects/test/subscriptions/demo:pull" => { + let payload = base64::engine::general_purpose::STANDARD + .encode(json!({ "historyId": 2 }).to_string()); + json!({ + "receivedMessages": [{ + "ackId": "ack-1", + "message": { + "data": payload, + "messageId": "msg-1" + } + }] + }) + .to_string() + } + "/gmail/v1/users/me/history?startHistoryId=1&historyTypes=messageAdded" => { + json!({ + "history": [{ + "messagesAdded": [{ + "message": { "id": "msg-1" } + }] + }] + }) + .to_string() + } + "/gmail/v1/users/me/messages/msg%2D1?format=full" => { + json!({ "id": "msg-1" }).to_string() + } + "/v1/projects/test/subscriptions/demo:acknowledge" => json!({}).to_string(), + other => panic!("unexpected request path: {other}"), + }; + + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: application/json\r\nConnection: close\r\nContent-Length: {}\r\n\r\n{}", + body.len(), + body + ); + stream.write_all(response.as_bytes()).await.unwrap(); + } + }); + + ( + format!("http://{addr}/v1"), + format!("http://{addr}/gmail/v1"), + requests, + handle, + ) + } #[test] fn test_extract_message_ids_from_history() { @@ -778,4 +895,70 @@ mod tests { assert_eq!(output, msg); assert!(output.get("_sanitization").is_none()); } + + #[tokio::test] + async fn test_watch_pull_loop_refreshes_tokens_for_each_request() { + let client = reqwest::Client::new(); + let pubsub_provider = FakeTokenProvider::new(["pubsub-token"]); + let gmail_provider = FakeTokenProvider::new(["gmail-token"]); + let (pubsub_base, gmail_base, requests, server) = spawn_watch_server().await; + let mut last_history_id = 1; + let config = WatchConfig { + project: None, + subscription: None, + topic: None, + label_ids: None, + max_messages: 10, + poll_interval: 1, + format: "full".to_string(), + once: true, + cleanup: false, + output_dir: None, + }; + let sanitize_config = crate::helpers::modelarmor::SanitizeConfig { + template: None, + mode: crate::helpers::modelarmor::SanitizeMode::Warn, + }; + + let runtime = WatchRuntime { + client: &client, + pubsub_token_provider: &pubsub_provider, + gmail_token_provider: &gmail_provider, + sanitize_config: &sanitize_config, + pubsub_api_base: &pubsub_base, + gmail_api_base: &gmail_base, + }; + + watch_pull_loop( + &runtime, + "projects/test/subscriptions/demo", + &mut last_history_id, + config, + ) + .await + .unwrap(); + + server.await.unwrap(); + + let requests = requests.lock().await; + assert_eq!(requests.len(), 4); + assert_eq!(requests[0].0, "/v1/projects/test/subscriptions/demo:pull"); + assert_eq!(requests[0].1, "authorization: Bearer pubsub-token"); + assert_eq!( + requests[1].0, + "/gmail/v1/users/me/history?startHistoryId=1&historyTypes=messageAdded" + ); + assert_eq!(requests[1].1, "authorization: Bearer gmail-token"); + assert_eq!( + requests[2].0, + "/gmail/v1/users/me/messages/msg%2D1?format=full" + ); + assert_eq!(requests[2].1, "authorization: Bearer gmail-token"); + assert_eq!( + requests[3].0, + "/v1/projects/test/subscriptions/demo:acknowledge" + ); + assert_eq!(requests[3].1, "authorization: Bearer pubsub-token"); + assert_eq!(last_history_id, 2); + } }